| | import future |
| | import builtins |
| | import past |
| | import six |
| |
|
| | import torch |
| | import torch.nn as nn |
| | import torch.nn.functional as F |
| | import torch.optim as optim |
| | import torch.autograd |
| | import components as comp |
| | from torch.distributions import multinomial, categorical |
| |
|
| | import math |
| | import numpy as np |
| |
|
| | try: |
| | from . import helpers as h |
| | from . import ai |
| | from . import scheduling as S |
| | except: |
| | import helpers as h |
| | import ai |
| | import scheduling as S |
| |
|
| |
|
| |
|
class WrapDom(object):
    """Adapter that post-processes the results of an inner domain factory.

    Each factory call (``box``, ``boxBetween``, ``line``) is forwarded to the
    wrapped object ``self.a`` and the result is passed through ``self.Domain``.
    ``Domain`` is not defined on this class; subclasses are expected to
    provide it.
    """

    def __init__(self, a):
        # A string spec is turned into an object with eval(); only pass
        # trusted specs.
        if type(a) is str:
            a = eval(a)
        self.a = a

    def box(self, *args, **kargs):
        """Forward to ``self.a.box`` and wrap the result with ``self.Domain``."""
        inner = self.a.box(*args, **kargs)
        return self.Domain(inner)

    def boxBetween(self, *args, **kargs):
        """Forward to ``self.a.boxBetween`` and wrap the result with ``self.Domain``."""
        inner = self.a.boxBetween(*args, **kargs)
        return self.Domain(inner)

    def line(self, *args, **kargs):
        """Forward to ``self.a.line`` and wrap the result with ``self.Domain``."""
        inner = self.a.line(*args, **kargs)
        return self.Domain(inner)
| |
|
class DList(object):
    """Weighted collection of domain factories evaluated together.

    Every factory call builds one element per member, tags it with an
    ``MLoss`` carrying that member's weight divided by the sum of all weights,
    and hands the resulting generator to ``ai.ListDomain``.
    """

    Domain = ai.ListDomain

    class MLoss():
        """Loss tag that scales a member's loss (and its ``lr``) by a weight."""

        def __init__(self, aw):
            self.aw = aw

        def loss(self, dom, *args, lr = 1, **kargs):
            weight = self.aw
            # Members with a non-positive weight contribute a plain 0 and the
            # inner loss is never evaluated.
            if weight <= 0.0:
                return 0
            return weight * dom.loss(*args, lr = lr * weight, **kargs)

    def __init__(self, *al):
        """Store ``(domain, weight)`` pairs; with no arguments use Point/Box."""
        if not al:
            al = [("Point()", 1.0), ("Box()", 0.1)]

        members = []
        for spec, weight in al:
            # String specs are evaluated with eval(); only pass trusted specs.
            dom = eval(spec) if type(spec) is str else spec
            members.append((dom, S.Const.initConst(weight)))
        self.al = members

    def getDiv(self, **kargs):
        """Return the reciprocal of the summed member weights."""
        total = sum(aw.getVal(**kargs) for _, aw in self.al)
        return 1.0 / total

    def _weighted(self, factory, args, kargs):
        # Shared body of box / boxBetween / line: call `factory` on every
        # member and tag each result with its normalised weight.
        scale = self.getDiv(**kargs)
        return self.Domain(
            ai.TaggedDomain(getattr(a, factory)(*args, **kargs),
                            DList.MLoss(aw.getVal(**kargs) * scale))
            for a, aw in self.al)

    def box(self, *args, **kargs):
        """Build the weighted list from every member's ``box``."""
        return self._weighted("box", args, kargs)

    def boxBetween(self, *args, **kargs):
        """Build the weighted list from every member's ``boxBetween``."""
        return self._weighted("boxBetween", args, kargs)

    def line(self, *args, **kargs):
        """Build the weighted list from every member's ``line``."""
        return self._weighted("line", args, kargs)

    def __str__(self):
        pairs = ("("+str(a)+","+str(w)+")" for a,w in self.al)
        return "DList(%s)" % h.sumStr(pairs)
| |
|
class Mix(DList):
    """Two-member ``DList`` with independently chosen weights."""

    def __init__(self, a="Point()", b="Box()", aw = 1.0, bw = 0.1):
        first = (a, aw)
        second = (b, bw)
        super().__init__(first, second)
| |
|
class LinMix(DList):
    """Two-member ``DList`` where the first weight is ``S.Complement(bw)``."""

    def __init__(self, a="Point()", b="Box()", bw = 0.1):
        first = (a, S.Complement(bw))
        second = (b, bw)
        super().__init__(first, second)
| |
|
class DProb(object):
    """Domain factory that delegates each call to one randomly chosen member.

    Members are given as ``(domain, weight)`` pairs; weights are normalised so
    they sum to one and used as sampling probabilities.
    """

    def __init__(self, *doms):
        """Store the member domains and their normalised probabilities.

        With no arguments the members default to ``Point()`` (0.8) and
        ``Box()`` (0.2). String specs are evaluated with ``eval``, so only
        trusted specs may be passed.
        """
        if len(doms) == 0:
            doms = [("Point()", 0.8), ("Box()", 0.2)]
        div = 1.0 / sum(float(aw) for _, aw in doms)
        self.domains = [eval(a) if isinstance(a, str) else a for a, _ in doms]
        self.probs = [div * float(aw) for _, aw in doms]

    def chooseDom(self):
        """Return one member, sampled according to ``self.probs``."""
        # A single member needs no sampling.
        if len(self.domains) > 1:
            return self.domains[np.random.choice(len(self.domains), p = self.probs)]
        return self.domains[0]

    def box(self, *args, **kargs):
        """Call ``box`` on a randomly chosen member."""
        domain = self.chooseDom()
        return domain.box(*args, **kargs)

    def boxBetween(self, *args, **kargs):
        """Call ``boxBetween`` on a randomly chosen member.

        Provided so this composite exposes the same factory methods
        (``box`` / ``boxBetween`` / ``line``) as the other domains in this
        module.
        """
        domain = self.chooseDom()
        return domain.boxBetween(*args, **kargs)

    def line(self, *args, **kargs):
        """Call ``line`` on a randomly chosen member."""
        domain = self.chooseDom()
        return domain.line(*args, **kargs)

    def __str__(self):
        return "DProb(%s)" % h.sumStr("("+str(a)+","+str(w)+")" for a,w in zip(self.domains, self.probs))
| |
|
class Coin(DProb):
    """Two-member ``DProb``: ``a`` with weight ``ap`` and ``b`` with weight ``bp``."""

    def __init__(self, a="Point()", b="Box()", ap = 0.8, bp = 0.2):
        heads = (a, ap)
        tails = (b, bp)
        super().__init__(heads, tails)
| |
|
class Point(object):
    """Concrete (non-abstract) domain: every factory returns a single point."""

    Domain = h.dten

    def __init__(self, **kargs):
        # Accepts and ignores any keyword configuration.
        pass

    def box(self, original, *args, **kargs):
        """Return ``original`` unchanged; all other arguments are ignored."""
        return original

    def line(self, original, other, *args, **kargs):
        """Return the midpoint of ``original`` and ``other``."""
        total = original + other
        return total / 2

    def boxBetween(self, o1, o2, *args, **kargs):
        """Return the midpoint of ``o1`` and ``o2``."""
        total = o1 + o2
        return total / 2

    def __str__(self):
        return "Point()"
| |
|
class PointA(Point):
    """``Point`` variant whose ``boxBetween`` returns the first endpoint."""

    def boxBetween(self, o1, o2, *args, **kargs):
        """Return ``o1``; ``o2`` and all other arguments are ignored."""
        first = o1
        return first

    def __str__(self):
        name = "PointA()"
        return name
| |
|
class PointB(Point):
    """``Point`` variant whose ``boxBetween`` returns the second endpoint."""

    def boxBetween(self, o1, o2, *args, **kargs):
        """Return ``o2``; ``o1`` and all other arguments are ignored."""
        second = o2
        return second

    def __str__(self):
        name = "PointB()"
        return name
| |
|
| |
|
class NormalPoint(Point):
    """Point domain that perturbs its input with scaled standard-normal noise."""

    def __init__(self, w = None, **kargs):
        # When set, this fixed scale overrides the `w` passed to box().
        self.epsilon = w

    def box(self, original, w, *args, **kargs):
        """Return ``original`` plus ``randn_like(original) * scale``.

        ``scale`` is ``self.epsilon`` when one was configured, otherwise the
        ``w`` argument. The noise is allocated on ``h.device``.
        """
        scale = w if self.epsilon is None else self.epsilon
        noise = torch.randn_like(original, device = h.device) * scale
        return original + noise

    def __str__(self):
        if self.epsilon is None:
            shown = ""
        else:
            shown = str(self.epsilon)
        return "NormalPoint(%s)" % shown
| |
|
| |
|
| |
|
class MI_FGSM(Point):
    """Iterative gradient-sign attack with momentum, exposed as a point domain.

    ``box`` returns a perturbed copy of the input that stays within ``w`` of
    it (element-wise), found by ``k`` signed-gradient steps of size
    ``r * w / k`` using a momentum factor ``mu``.
    """

    def __init__(self, w = None, r = 20.0, k = 100, mu = 0.8, should_end = True, restart = None, searchable=False, **kargs):
        """Store the attack configuration.

        Args:
            w: perturbation bound, wrapped with ``S.Const.initConst``.
            r: step-size multiplier (stored as float).
            k: number of iterations.
            mu: momentum decay applied to the accumulated gradient.
            should_end: when true, each step is masked by ``ai.mulIfEq``.
            restart: if not None, the iterate is re-randomised every
                ``int(k / restart)`` iterations.
            searchable: stored as-is; not read in this class.
        """
        self.epsilon = S.Const.initConst(w)
        self.k = k
        self.mu = mu
        self.r = float(r)
        self.should_end = should_end
        self.restart = restart
        self.searchable = searchable

    def box(self, original, model, target = None, untargeted = False, **kargs):
        """Run the attack around ``original``.

        With no ``target`` the attack is untargeted against the model's own
        arg-max prediction for ``original``.
        """
        if target is None:
            untargeted = True
            with torch.no_grad():
                target = model(original).max(1)[1]
        return self.attack(model, original, untargeted, target, **kargs)

    def attack(self, model, xo, untargeted, target, w, loss_function=ai.stdLoss, **kargs):
        """Return the adversarial input found for ``xo``.

        Args:
            model: callable whose output provides ``vanillaTensorPart()`` and
                which owns an ``optimizer`` with ``zero_grad()``.
            xo: clean input; the result is clamped to ``[xo - w, xo + w]``.
            untargeted: step up the loss when true, down when false.
            target: labels handed to ``loss_function``.
            w: perturbation bound, resolved through ``self.epsilon.getVal``.
            loss_function: loss evaluated on ``(output, target)``.
        """
        w = self.epsilon.getVal(c = w, **kargs)

        x = nn.Parameter(xo.clone(), requires_grad=True)
        gradorg = h.zeros(x.shape)
        is_eq = 1

        w = h.ones(x.shape) * w
        for i in range(self.k):
            # int() truncates to 0 when restart > k; fall back to 1 so the
            # modulo never divides by zero.
            if self.restart is not None and i % (int(self.k / self.restart) or 1) == 0:
                # NOTE(review): rand_like is in [0, 1), so restarts only move
                # in the positive direction from xo - confirm this is intended.
                x = is_eq * (torch.rand_like(xo) * w + xo) + (1 - is_eq) * x
                x = nn.Parameter(x, requires_grad = True)

            model.optimizer.zero_grad()

            out = model(x).vanillaTensorPart()
            loss = loss_function(out, target)

            loss.sum().backward(retain_graph=True)
            with torch.no_grad():
                # L1-normalised gradient accumulated with momentum mu.
                oth = x.grad / torch.norm(x.grad, p=1)
                gradorg *= self.mu
                gradorg += oth
                grad = (self.r * w / self.k) * ai.mysign(gradorg)
                if self.should_end:
                    is_eq = ai.mulIfEq(grad, out, target)
                x = (x + grad * is_eq) if untargeted else (x - grad * is_eq)

                # Project back into the w-ball around xo.
                x = xo + torch.min(torch.max(x - xo, -w), w)
                x.requires_grad_()

        model.optimizer.zero_grad()

        return x

    def boxBetween(self, o1, o2, model, target, *args, **kargs):
        """Not supported; always raises ``NotImplementedError``."""
        # Raising a plain string is itself a TypeError in Python 3.
        raise NotImplementedError("boxBetween is not yet supported by MI_FGSM")

    def __str__(self):
        # Only parameters that differ from the constructor defaults
        # (k=100, r=20.0, mu=0.8, should_end=True) are shown.
        return "MI_FGSM(%s)" % (("" if self.epsilon is None else "w="+str(self.epsilon)+",")
                                + ("" if self.k == 100 else "k="+str(self.k)+",")
                                + ("" if self.r == 20.0 else "r="+str(self.r)+",")
                                + ("" if self.mu == 0.8 else "mu="+str(self.mu)+",")
                                + ("" if self.should_end else "should_end=False"))
| |
|
| |
|
class PGD(MI_FGSM):
    """``MI_FGSM`` with momentum fixed to 0 and defaults ``r=5.0``, ``k=5``."""

    def __init__(self, r = 5.0, k = 5, **kargs):
        super().__init__(r=r, k = k, mu = 0, **kargs)

    def __str__(self):
        # Only non-default settings are listed.
        pieces = []
        if self.epsilon is not None:
            pieces.append("w="+str(self.epsilon)+",")
        if self.k != 5:
            pieces.append("k="+str(self.k)+",")
        if self.r != 5.0:
            pieces.append("r="+str(self.r)+",")
        if not self.should_end:
            pieces.append("should_end=False")
        return "PGD(%s)" % "".join(pieces)
| |
|
class IFGSM(PGD):
    """``PGD`` with the step multiplier ``r`` fixed to 1."""

    def __init__(self, k = 5, **kargs):
        super().__init__(r = 1, k=k, **kargs)

    def __str__(self):
        # Only non-default settings are listed.
        pieces = []
        if self.epsilon is not None:
            pieces.append("w="+str(self.epsilon)+",")
        if self.k != 5:
            pieces.append("k="+str(self.k)+",")
        if not self.should_end:
            pieces.append("should_end=False")
        return "IFGSM(%s)" % "".join(pieces)
| |
|
class NormalAdv(Point):
    """Runs an inner attack with a randomly scaled perturbation bound."""

    def __init__(self, a="IFGSM()", w = None):
        # String specs are evaluated with eval(); only pass trusted specs.
        if type(a) is str:
            a = eval(a)
        self.a = a
        self.epsilon = S.Const.initConst(w)

    def box(self, original, w, *args, **kargs):
        """Call ``self.a.box`` with ``w`` replaced by a normal sample times the bound.

        The bound comes from ``self.epsilon.getVal`` and must be non-negative.
        Only the first entry of the drawn sample vector is used.
        """
        bound = self.epsilon.getVal(c = w, shape = original.shape[:1], **kargs)
        assert (0 <= h.dten(bound)).all()
        sample = torch.randn(original.size()[0:1], device = h.device)[0]
        scaled = sample * bound
        return self.a.box(original, *args, w = scaled, **kargs)

    def __str__(self):
        suffix = "" if self.epsilon is None else ",w="+str(self.epsilon)
        return "NormalAdv(%s)" % (str(self.a) + suffix)
| |
|
| |
|
class InclusionSample(Point):
    """Samples a smaller region inside the requested one and abstracts that.

    The input is shifted by random noise scaled by ``w * (1 - sub)`` and the
    inner domain ``a`` is then asked for a box of width ``w * sub`` around the
    shifted point.
    """

    def __init__(self, sub, a="Box()", normal = False, w = None, **kargs):
        """Store the configuration.

        Args:
            sub: fraction (asserted to be in [0, 1] in ``box``) of the width
                given to the inner domain; wrapped with ``S.Const.initConst``.
            a: inner domain or a string spec for it. Specs are evaluated
                with ``eval``, so only trusted specs may be passed.
            normal: draw the shift from ``randn_like`` instead of uniform
                noise in [-1, 1).
            w: width, wrapped with ``S.Const.initConst``.
        """
        self.sub = S.Const.initConst(sub)
        self.w = S.Const.initConst(w)
        self.normal = normal
        self.a = (eval(a) if type(a) is str else a)

    def box(self, original, w, *args, **kargs):
        """Return ``self.a.box`` around a randomly shifted copy of ``original``."""
        w = self.w.getVal(c = w, shape = original.shape[:1], **kargs)
        sub = self.sub.getVal(c = 1, shape = original.shape[:1], **kargs)

        assert (0 <= h.dten(w)).all()
        assert (h.dten(sub) <= 1).all()
        assert (0 <= h.dten(sub)).all()
        if self.normal:
            inter = torch.randn_like(original, device = h.device)
        else:
            # Uniform noise rescaled from [0, 1) to [-1, 1).
            inter = (torch.rand_like(original, device = h.device) * 2 - 1)

        inter = inter * w * (1 - sub)

        return self.a.box(original + inter, w = w * sub, *args, **kargs)

    def boxBetween(self, o1, o2, *args, **kargs):
        """Call ``box`` centred on the midpoint of ``o1`` and ``o2``."""
        # NOTE(review): w is the full |o2 - o1| extent, not half of it -
        # confirm this over-approximation is intended.
        w = (o2 - o1).abs()
        return self.box( (o2 + o1)/2 , w = w, *args, **kargs)

    def __str__(self):
        # The width lives in self.w; this class never sets self.epsilon.
        return "InclusionSample(%s, %s)" % (str(self.sub), str(self.a) + ("" if self.w is None else ",w="+str(self.w)))


InSamp = InclusionSample
| |
|
| |
|
class AdvInclusion(InclusionSample):
    """Like ``InclusionSample`` but the shift is found by an attack domain.

    Domain ``a`` is run with width ``w * (1 - sub)`` to pick the centre, then
    domain ``b`` builds a box of width ``w * sub`` around that centre.
    """

    def __init__(self, sub, a="IFGSM()", b="Box()", w = None, **kargs):
        """Store the configuration.

        Args:
            sub: fraction (asserted to be in [0, 1] in ``box``) of the width
                given to ``b``; wrapped with ``S.Const.initConst``.
            a: domain (or string spec) that picks the centre.
            b: domain (or string spec) that builds the final box.
            w: width, wrapped with ``S.Const.initConst``.

        String specs are evaluated with ``eval``; only pass trusted specs.
        """
        self.sub = S.Const.initConst(sub)
        self.w = S.Const.initConst(w)
        self.a = (eval(a) if type(a) is str else a)
        self.b = (eval(b) if type(b) is str else b)

    def box(self, original, w, *args, **kargs):
        """Return ``self.b.box`` around the point chosen by ``self.a.box``."""
        w = self.w.getVal(c = w, shape = original.shape, **kargs)
        sub = self.sub.getVal(c = 1, shape = original.shape, **kargs)

        assert (0 <= h.dten(w)).all()
        assert (h.dten(sub) <= 1).all()
        assert (0 <= h.dten(sub)).all()

        # With zero total width the inner attack is skipped.
        if h.dten(w).sum().item() <= 0.0:
            inter = original
        else:
            inter = self.a.box(original, w = w * (1 - sub), *args, **kargs)

        return self.b.box(inter, w = w * sub, *args, **kargs)

    def __str__(self):
        # The width lives in self.w; this class never sets self.epsilon.
        return "AdvInclusion(%s, %s, %s)" % (str(self.sub), str(self.a), str(self.b) + ("" if self.w is None else ",w="+str(self.w)))
| |
|
| |
|
class AdvDom(Point):
    """Abstracts the region between an input and an adversarial example.

    Domain ``a`` produces the adversarial element; domain ``b`` builds the
    region between the original point and ``adv.ub()``.
    """

    def __init__(self, a="IFGSM()", b="Box()"):
        """Store both domains.

        String specs are evaluated with ``eval``; only pass trusted specs.
        """
        self.a = (eval(a) if type(a) is str else a)
        self.b = (eval(b) if type(b) is str else b)

    def box(self, original,*args, **kargs):
        """Return ``b.boxBetween(original, a.box(original, ...).ub(), ...)``."""
        adv = self.a.box(original, *args, **kargs)
        return self.b.boxBetween(original, adv.ub(), *args, **kargs)

    def boxBetween(self, o1, o2, *args, **kargs):
        """Same as ``box`` but starting from the midpoint of ``o1`` and ``o2``."""
        original = (o1 + o2) / 2
        adv = self.a.boxBetween(o1, o2, *args, **kargs)
        return self.b.boxBetween(original, adv.ub(), *args, **kargs)

    def __str__(self):
        # `width` is never assigned by this class; read it defensively so
        # str() does not raise AttributeError.
        width = getattr(self, "width", None)
        return "AdvDom(%s)" % (("" if width is None else "width="+str(width)+",")
                               + str(self.a) + "," + str(self.b))
| |
|
| |
|
| |
|
class BiAdv(AdvDom):
    """``AdvDom`` variant that mirrors the adversarial offset around the centre."""

    def box(self, original, **kargs):
        """Return ``b.boxBetween`` over ``original -/+ |adv.ub() - original|``."""
        adversarial = self.a.box(original, **kargs)
        reach = (adversarial.ub() - original).abs()
        low = original - reach
        high = original + reach
        return self.b.boxBetween(low, high, **kargs)

    def boxBetween(self, o1, o2, *args, **kargs):
        """Same as ``box`` but centred on the midpoint of ``o1`` and ``o2``."""
        centre = (o1 + o2) / 2
        adversarial = self.a.boxBetween(o1, o2, *args, **kargs)
        reach = (adversarial.ub() - centre).abs()
        low = centre - reach
        high = centre + reach
        return self.b.boxBetween(low, high, *args, **kargs)

    def __str__(self):
        # Drop the six-character "AdvDom" prefix from the parent's text.
        parent_text = AdvDom.__str__(self)
        return "BiAdv" + parent_text[6:]
| |
|
| |
|
class HBox(object):
    """Factory for ``ai.HybridZonotope`` elements with explicit error terms.

    Built elements are wrapped in ``ai.TaggedDomain`` with this factory as the
    tag, so ``loss`` below is what evaluates them.
    """

    Domain = ai.HybridZonotope

    def domain(self, *args, **kargs):
        """Construct ``self.Domain`` and tag it with this factory."""
        element = self.Domain(*args, **kargs)
        return ai.TaggedDomain(element, self)

    def __init__(self, w = None, tot_weight = 1, width_weight = 0, pow_loss = None, log_loss = False, searchable = True, cross_loss = True, **kargs):
        self.w = S.Const.initConst(w)
        self.tot_weight = S.Const.initConst(tot_weight)
        self.width_weight = S.Const.initConst(width_weight)
        self.pow_loss = pow_loss
        self.searchable = searchable
        self.log_loss = log_loss
        self.cross_loss = cross_loss

    def __str__(self):
        shown = "" if self.w is None else "w="+str(self.w)
        return "HBox(%s)" % shown

    def _errorTerms(self, like):
        # Error-term tensor from h.getEi, sized for `like`; presumably one
        # term per input element - confirm against h.getEi.
        shape = like.size()
        batches = shape[0]
        num_elem = h.product(shape[1:])
        ei = h.getEi(batches, num_elem)
        if len(shape) > 2:
            ei = ei.contiguous().view(num_elem, *shape)
        return ei

    def boxBetween(self, o1, o2, *args, **kargs):
        """Element centred between ``o1`` and ``o2`` with half-width error terms."""
        ei = self._errorTerms(o1)
        centre = (o1 + o2) / 2
        return self.domain(centre, None, ei * (o2 - o1).abs() / 2).checkSizes()

    def box(self, original, w, **kargs):
        """Element around ``original`` with radius from ``self.w.getVal(c=w)``.

        Slower than an uncorrelated box, but keeps correlation down the line.
        """
        radius = self.w.getVal(c = w, **kargs)
        ei = self._errorTerms(original)
        return self.domain(original, None, ei * radius).checkSizes()

    def line(self, o1, o2, **kargs):
        """Element for the segment ``o1``-``o2``, optionally widened by ``w``."""
        w = self.w.getVal(c = 0, **kargs)

        ln = ((o2 - o1) / 2).unsqueeze(0)
        if w is not None and w > 0.0:
            # Append the per-element terms after the segment direction.
            ei = self._errorTerms(o1)
            ln = torch.cat([ln, ei * w])
        centre = (o1 + o2) / 2
        return self.domain(centre, None, ln).checkSizes()

    def loss(self, dom, target, *args, **kargs):
        """Weighted mix of the element's loss and its diameter."""
        width_weight = self.width_weight.getVal(**kargs)
        tot_weight = self.tot_weight.getVal(**kargs)

        if not self.cross_loss:
            tot = dom.loss(target, *args, **kargs)
        else:
            # Upper bounds everywhere except the target entry, which takes
            # the lower bound.
            r = dom.ub()
            inds = torch.arange(r.shape[0], device=h.device, dtype=h.ltype)
            r[inds,target] = dom.lb()[inds,target]
            tot = r.loss(target, *args, **kargs)

        if self.log_loss:
            tot = (tot + 1).log()
        power = self.pow_loss
        if power is not None and power > 0 and power != 1:
            tot = tot.pow(power)

        ls = tot * tot_weight
        if width_weight > 0:
            ls += dom.diameter() * width_weight

        return ls / (width_weight + tot_weight)
| |
|
class Box(HBox):
    """``HBox`` variant that puts the radius in the second domain argument.

    The error-term argument passed to ``self.domain`` is always ``None``.
    """

    def __str__(self):
        shown = "" if self.w is None else "w="+str(self.w)
        return "Box(%s)" % shown

    def box(self, original, w, **kargs):
        """Element around ``original`` with a uniform radius.

        This version relies on the per-element radii being uncorrelated, and
        they stay uncorrelated from then on. Tests reportedly show better
        accuracy than the correlated box: the other box creates many zero
        errors that the relu's new-head computation has to account for, which
        is apparently worse than not accounting for errors.
        """
        radius = self.w.getVal(c = w, **kargs)
        spread = h.ones(original.size()) * radius
        return self.domain(original, spread, None).checkSizes()

    def line(self, o1, o2, **kargs):
        """Element covering ``o1``..``o2``, widened by ``self.w.getVal(c=0)``."""
        w = self.w.getVal(c = 0, **kargs)
        centre = (o1 + o2) / 2
        spread = ((o2 - o1) / 2).abs() + h.ones(o2.size()) * w
        return self.domain(centre, spread, None).checkSizes()

    def boxBetween(self, o1, o2, *args, **kargs):
        """Alias for ``line``; extra positional arguments are dropped."""
        return self.line(o1, o2, **kargs)
| |
|
class ZBox(HBox):
    """``HBox`` whose elements are built with ``ai.Zonotope``."""

    def __str__(self):
        shown = "" if self.w is None else "w="+str(self.w)
        return "ZBox(%s)" % shown

    def Domain(self, *args, **kargs):
        """Build an ``ai.Zonotope`` from the given arguments."""
        element = ai.Zonotope(*args, **kargs)
        return element
| |
|
class HSwitch(HBox):
    """``HBox`` building ``ai.HybridZonotope`` with ``ai.creluSwitch`` as relu."""

    def __str__(self):
        shown = "" if self.w is None else "w="+str(self.w)
        return "HSwitch(%s)" % shown

    def Domain(self, *args, **kargs):
        """Build an ``ai.HybridZonotope`` with ``customRelu=ai.creluSwitch``."""
        element = ai.HybridZonotope(*args, customRelu = ai.creluSwitch, **kargs)
        return element
| | |
class ZSwitch(ZBox):
    """``ZBox`` building ``ai.Zonotope`` with ``ai.creluSwitch`` as relu."""

    def __str__(self):
        shown = "" if self.w is None else "w="+str(self.w)
        return "ZSwitch(%s)" % shown

    def Domain(self, *args, **kargs):
        """Build an ``ai.Zonotope`` with ``customRelu=ai.creluSwitch``."""
        element = ai.Zonotope(*args, customRelu = ai.creluSwitch, **kargs)
        return element
| |
|
| |
|
class ZNIPS(ZBox):
    """``ZBox`` building ``ai.Zonotope`` with ``ai.creluNIPS`` as relu."""

    def __str__(self):
        # Report this class's own name, not the "ZSwitch" label it used to
        # share with ZSwitch.
        return "ZNIPS(%s)" % ("" if self.w is None else "w="+str(self.w))

    def Domain(self, *args, **kargs):
        """Build an ``ai.Zonotope`` with ``customRelu=ai.creluNIPS``."""
        return ai.Zonotope(*args, customRelu = ai.creluNIPS, **kargs)
| | |
class HSmooth(HBox):
    """``HBox`` building ``ai.HybridZonotope`` with ``ai.creluSmooth`` as relu."""

    def __str__(self):
        shown = "" if self.w is None else "w="+str(self.w)
        return "HSmooth(%s)" % shown

    def Domain(self, *args, **kargs):
        """Build an ``ai.HybridZonotope`` with ``customRelu=ai.creluSmooth``."""
        element = ai.HybridZonotope(*args, customRelu = ai.creluSmooth, **kargs)
        return element
| | |
class HNIPS(HBox):
    """``HBox`` building ``ai.HybridZonotope`` with ``ai.creluNIPS`` as relu."""

    def __str__(self):
        # Report this class's own name, not the "HSmooth" label it used to
        # share with HSmooth.
        return "HNIPS(%s)" % ("" if self.w is None else "w="+str(self.w))

    def Domain(self, *args, **kargs):
        """Build an ``ai.HybridZonotope`` with ``customRelu=ai.creluNIPS``."""
        return ai.HybridZonotope(*args, customRelu = ai.creluNIPS, **kargs)
| |
|
class ZSmooth(ZBox):
    """``ZBox`` building ``ai.Zonotope`` with ``ai.creluSmooth`` as relu."""

    def __str__(self):
        shown = "" if self.w is None else "w="+str(self.w)
        return "ZSmooth(%s)" % shown

    def Domain(self, *args, **kargs):
        """Build an ``ai.Zonotope`` with ``customRelu=ai.creluSmooth``."""
        element = ai.Zonotope(*args, customRelu = ai.creluSmooth, **kargs)
        return element
| |
|
| |
|
| |
|
| |
|
| |
|
| | |
class HRand(WrapDom):
    """Wraps ``Box()`` results and rebuilds them in another domain.

    Each element produced by the inner ``Box()`` has ``stochasticCorrelate``
    applied with ``num_correlated`` and is then rebuilt with ``self.dom.Domain``.
    """

    def __init__(self, num_correlated, a = "HSwitch()", **kargs):
        """Store the target domain.

        Args:
            num_correlated: argument forwarded to ``stochasticCorrelate``.
            a: target domain or a string spec for it. Specs are evaluated
                with ``eval``, so only trusted specs may be passed.
        """
        # The wrapped factory (self.a) is always a plain Box(); the
        # configured domain is kept separately in self.dom.
        super(HRand, self).__init__(Box())
        self.num_correlated = num_correlated
        self.dom = eval(a) if type(a) is str else a

    def Domain(self, d):
        """Correlate ``d`` and rebuild it from its head, beta and errors."""
        with torch.no_grad():
            out = d.abstractApplyLeaf('stochasticCorrelate', self.num_correlated)
            out = self.dom.Domain(out.head, out.beta, out.errors)
        return out

    def __str__(self):
        # Show the configured domain; self.a is always the internal Box().
        return "HRand(%s, domain = %s)" % (str(self.num_correlated), str(self.dom))
| |
|