| | |
| | |
| | |
| | |
| | |
| |
|
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | import torch |
| | import torch.nn as nn |
| | import torch.nn.functional as F |
| | import torch.optim as optim |
| |
|
| | import helpers as h |
| | import domains |
| | from domains import * |
| | import math |
| |
|
| |
|
# Everything ``h.getMethods(domains)`` yields for which ``h.hasMethod`` reports
# an "attack" member, followed by the plain torch tensor types.
# NOTE(review): presumably these are the domains that represent a single
# concrete point rather than a region -- confirm against ``domains``.
POINT_DOMAINS = [
    candidate
    for candidate in h.getMethods(domains)
    if h.hasMethod(candidate, "attack")
] + [torch.FloatTensor, torch.Tensor, torch.cuda.FloatTensor]

# ``domains.Box`` first, then every point domain.  The spelling of the name is
# kept as-is because other modules may refer to it.
SYMETRIC_DOMAINS = [domains.Box, *POINT_DOMAINS]
| |
|
def domRes(outDom, target, **args):
    """Return per-class lower bounds on ``output[target] - output[class]``.

    Two batched matrices are built from the one-hot encoding of ``target``
    and pushed through ``outDom.bmm``; the difference of the two results is
    reduced with ``.lb()`` and the one-hot encoding is added on top, so the
    target's own column ends up offset by one.

    NOTE(review): presumably ``outDom`` is an abstract-domain element whose
    ``bmm`` applies a batched linear map and whose ``lb`` yields a lower
    bound -- confirm in ``domains``.

    Args:
        outDom: object providing ``size()``, ``bmm(...)`` and, on the
            difference of two ``bmm`` results, ``lb()``.
        target: tensor of class indices (read through ``.data.long()``).
        **args: accepted for call compatibility; not used here.

    Returns:
        ``(outDom.bmm(A) - outDom.bmm(B)).lb() + onehot`` with ``A`` and
        ``B`` as described in the body.
    """
    onehot = h.one_hot(target.data.long(), outDom.size()[1]).to_dense()

    # Batched outer product onehot * onehot^T: a single 1 at (target, target).
    outer = onehot.unsqueeze(2).matmul(onehot.unsqueeze(1))
    batch = outer.size()[0]
    width = outer.size()[1]

    # B: identity with the target's diagonal entry zeroed out.
    drop_target = h.eye(width).expand(batch, -1, -1) - outer

    # A: the one-hot column repeated across every column, then masked by B so
    # that the target's own column becomes zero.
    pick_target = onehot.unsqueeze(2).expand(-1, -1, width).bmm(drop_target)

    target_part = outDom.bmm(pick_target)
    others_part = outDom.bmm(drop_target)
    return (target_part - others_part).lb() + onehot
| |
|
def isSafeDom(outDom, target, **args):
    """Return 1 if every entry of ``domRes(outDom, target)`` is positive, else 0.

    The result of ``domRes`` is reduced with a minimum over dimension 1 and
    compared against zero.  Because the answer is extracted with ``.item()``,
    that minimum must contain exactly one element.

    Args:
        outDom: forwarded unchanged to ``domRes``.
        target: forwarded unchanged to ``domRes``.
        **args: forwarded unchanged to ``domRes``.

    Returns:
        A Python ``int`` (1 or 0).
    """
    margins = domRes(outDom, target, **args)
    # torch.min along a dimension returns (values, indices); only the values
    # are needed.
    smallest = torch.min(margins, 1)[0]
    return smallest.gt(0.0).long().item()
| |
|
| |
|
def isSafeBox(target, net, inp, eps, dom):
    """Check whether ``net`` keeps the class of ``target`` around ``inp``.

    If ``dom`` has an ``attack`` attribute, ``dom.attack(net, eps, inp,
    target)`` is run and the network's prediction on its result is compared
    with the target class.  Otherwise ``net`` is evaluated on
    ``dom.box(inp, eps)`` and the result is handed to ``isSafeDom``.

    Only the first row of ``target`` (and of the network output in the
    attack case) is inspected.

    Args:
        target: 2-D tensor; the target class is ``target.argmax(1)[0]``.
        net: callable applied to the attacked input or to the boxed input.
        inp: input passed to ``dom.attack`` / ``dom.box``.
        eps: radius passed to ``dom.attack`` / ``dom.box``.
        dom: domain object providing ``attack`` or ``box``.

    Returns:
        ``bool`` in the attack case; the ``int`` returned by ``isSafeDom``
        otherwise.
    """
    atarg = target.argmax(1)[0].unsqueeze(0)

    if not hasattr(dom, "attack"):
        # No attack available: propagate the whole box and check the bounds.
        return isSafeDom(net(dom.box(inp, eps)), atarg)

    attacked = dom.attack(net, eps, inp, target)
    predicted = net(attacked).argmax(1)[0].unsqueeze(0)
    return predicted.item() == atarg.item()
| |
|