def correlate(self, cc_indx_batch_beta):  # indices are given in terms of the flattened matrix.
    num_correlate = h.product(cc_indx_batch_beta.shape[1:])

    beta = h.zeros(self.head.shape).to_dtype() if self.beta is None else self.beta
    errors = h.zeros([0] + list(self.head.shape)).to_dtype() if self.errors is None else self.errors

    batch_size = beta.shape[0]
    new_errors = h.zeros([num_correlate] + list(self.head.shape)).to_dtype()

    inds_i = torch.arange(batch_size, device=h.device).unsqueeze(1).long()
    nc = torch.arange(num_correlate, device=h.device).unsqueeze(1).long()

    # Bring the error-term dimension next to the batch dimension so the selected beta
    # entries can be scattered into their own fresh error terms.
    new_errors = new_errors.permute(1, 0, *list(range(len(new_errors.shape)))[2:]) \
                           .contiguous().view(batch_size, num_correlate, -1)
    new_errors[inds_i,
               nc.unsqueeze(0).expand([batch_size] + list(nc.shape)).squeeze(2),
               cc_indx_batch_beta] = beta.view(batch_size, -1)[inds_i, cc_indx_batch_beta]
    new_errors = new_errors.permute(1, 0, *list(range(len(new_errors.shape)))[2:]) \
                           .contiguous().view(num_correlate, batch_size, *beta.shape[1:])

    errors = torch.cat((errors, new_errors), dim=0)

    # The moved entries are now tracked as correlated error terms, so they are
    # removed from the uncorrelated beta component.
    beta.view(batch_size, -1)[inds_i, cc_indx_batch_beta] = 0

    return self.new(self.head, beta, errors)


def decorrelate(self, cc_indx_batch_err):  # keep these error terms, fold the rest into beta
    if self.errors is None:
        return self

    batch_size = self.head.shape[0]
    num_error_terms = self.errors.shape[0]

    beta = h.zeros(self.head.shape).to_dtype() if self.beta is None else self.beta
    errors = h.zeros([0] + list(self.head.shape)).to_dtype() if self.errors is None else self.errors

    inds_i = torch.arange(self.head.shape[0], device=h.device).unsqueeze(1).long()

    # Bring the batch dimension to the front so error terms can be indexed per example.
    errors = errors.to_dtype().permute(1, 0, *list(range(len(self.errors.shape)))[2:])

    # Every error term that is not kept gets absorbed into the uncorrelated beta component.
    sm = errors.clone()
    sm[inds_i, cc_indx_batch_err] = 0
    beta = beta.to_dtype() + sm.abs().sum(dim=1)

    # Keep only the selected error terms and restore the original layout.
    errors = errors[inds_i, cc_indx_batch_err]
    errors = errors.permute(1, 0, *list(range(len(self.errors.shape)))[2:]).contiguous()

    return self.new(self.head, beta, errors)


def creluNIPS(dom):
    if dom.errors is None:
        if dom.beta is None:
            return dom.new(F.relu(dom.head), None, None)
        # Pure interval case: ReLU is monotone, so push both endpoints through it.
        er = dom.beta
        mx = F.relu(dom.head + er)
        mn = F.relu(dom.head - er)
        return dom.new((mn + mx) / 2, (mx - mn) / 2, None)

    # Concrete lower/upper bounds of every neuron.
    sm = torch.sum(torch.abs(dom.errors), 0)
    if dom.beta is not None:
        sm += dom.beta
    mn = dom.head - sm
    mx = dom.head + sm

    mngz = mn >= 0.0
    zs = h.zeros(dom.head.shape)

    # Zonotope ReLU relaxation: slope lam = mx / (mx - mn), offset mu = -lam * mn / 2.
    diff = mx - mn
    lam = torch.where((mx > 0) & (diff > 0.0), mx / diff, zs)
    mu = lam * mn * (-0.5)
    betaz = zs if dom.beta is None else dom.beta

    newhead = torch.where(mngz, dom.head, lam * dom.head + mu)
    # Stable (certainly non-negative) or zero-width neurons keep their beta and error terms.
    mngz = mngz | (diff <= 0.0)
    newbeta = torch.where(mngz, betaz, lam * betaz + mu)  # mu is always positive on this side
    newerr = torch.where(mngz, dom.errors, lam * dom.errors)
    return dom.new(newhead, newbeta, newerr)

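
# A minimal numeric check, standalone and not part of the codebase above, of the linear
# ReLU relaxation used in creluNIPS: for a crossing interval [l, u] with l < 0 < u, the
# line lam * x + mu with lam = u / (u - l) and mu = -lam * l / 2 stays within mu of
# relu(x) for every x in [l, u]. The values l and u below are arbitrary test inputs.
import torch
import torch.nn.functional as F

l, u = -1.5, 2.0
lam = u / (u - l)
mu = -lam * l / 2
x = torch.linspace(l, u, steps=1001)
gap = (F.relu(x) - (lam * x + mu)).abs().max()
assert gap <= mu + 1e-6, gap
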

def doop(er1, er2):
    erS, erL = (er1, er2)
    sS, sL = (erS.size()[0], erL.size()[0])

    if sS == sL:
        # TODO: here we know we used transformers on either side which didn't introduce
        # new error terms (this is a hack for hybrid zonotopes and doesn't work with
        # adaptive error term adding).
        return op(erS, erL)

    # The first sz error terms are shared; each remainder is paired with zeros so that
    # no correlation is introduced between unrelated terms.
    if ref_errs is not None:
        sz = ref_errs.size()[0]
    else:
        sz = min(sS, sL)

    p1 = op(erS[:sz], erL[:sz])
    erSrem = erS[sz:]
    erLrem = erL[sz:]
    p2 = op(erSrem, h.zeros(erSrem.shape))
    p3 = op(h.zeros(erLrem.shape), erLrem)
    return torch.cat((p1, p2, p3), dim=0)

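
# A simplified, self-contained sketch of the error-term alignment idea in doop, with the
# hypothetical helper align_and_add specialized to op = addition (so pairing a remainder
# with zeros is the identity). It is not part of the codebase above; it only illustrates
# how two error matrices with different numbers of terms are combined.
import torch

def align_and_add(erS, erL):
    sz = min(erS.shape[0], erL.shape[0])
    p1 = erS[:sz] + erL[:sz]   # shared leading error terms combine term-wise
    p2 = erS[sz:]              # leftover terms of the first operand, kept as-is
    p3 = erL[sz:]              # leftover terms of the second operand, kept as-is
    return torch.cat((p1, p2, p3), dim=0)

# e.g. 3 error terms combined with 5 error terms over a batch of 2 scalars:
assert align_and_add(torch.randn(3, 2, 1), torch.randn(5, 2, 1)).shape == (5, 2, 1)
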

def slidingMax(a):  # running maximum along dim 1, computed with a single max pool
    batch_size, k = a.shape[0], a.shape[1]
    ml = a.min(dim=1)[0].unsqueeze(1)
    # Shift each row to be non-negative and left-pad with k zeros; a max pool of width k
    # then yields, at position j, the maximum over the first j entries (or 0 for j = 0).
    inp = torch.cat((h.zeros([batch_size, k]), a - ml), dim=1)
    mpl = F.max_pool1d(inp.unsqueeze(1), kernel_size=k, stride=1, padding=0,
                       return_indices=False).squeeze(1)
    return mpl[:, :-1] + ml

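
# Standalone sketch of what slidingMax computes, using plain torch.zeros in place of the
# h.zeros helper: out[:, 0] is the per-row minimum and out[:, j] = max(a[:, :j]) for
# j >= 1, i.e. an exclusive running maximum obtained with one max_pool1d call.
import torch
import torch.nn.functional as F

def sliding_max(a):
    k = a.shape[1]
    ml = a.min(dim=1)[0].unsqueeze(1)
    inp = torch.cat((torch.zeros(a.shape[0], k), a - ml), dim=1)
    mpl = F.max_pool1d(inp.unsqueeze(1), kernel_size=k, stride=1).squeeze(1)
    return mpl[:, :-1] + ml

a = torch.randn(4, 7)
ref = torch.stack([a.min(dim=1)[0]] + [a[:, :j].max(dim=1)[0] for j in range(1, a.shape[1])],
                  dim=1)
assert torch.allclose(sliding_max(a), ref)
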

def attack(self, model, xo, untargeted, target, w, loss_function=ai.stdLoss, **kargs):
    w = self.epsilon.getVal(c=w, **kargs)
    x = nn.Parameter(xo.clone(), requires_grad=True)
    gradorg = h.zeros(x.shape)
    is_eq = 1
    w = h.ones(x.shape) * w

    for i in range(self.k):
        # Optionally restart from a fresh random point inside the epsilon ball.
        if self.restart is not None and i % int(self.k / self.restart) == 0:
            x = is_eq * (torch.rand_like(xo) * w + xo) + (1 - is_eq) * x
            x = nn.Parameter(x, requires_grad=True)

        model.optimizer.zero_grad()
        out = model(x).vanillaTensorPart()
        loss = loss_function(out, target)
        loss.sum().backward(retain_graph=True)

        with torch.no_grad():
            # Accumulate the L1-normalized gradient with momentum mu, take a sign step,
            # and project back into the L-infinity ball of radius w around xo.
            oth = x.grad / torch.norm(x.grad, p=1)
            gradorg *= self.mu
            gradorg += oth
            grad = (self.r * w / self.k) * ai.mysign(gradorg)
            if self.should_end:
                is_eq = ai.mulIfEq(grad, out, target)
            x = (x + grad * is_eq) if untargeted else (x - grad * is_eq)
            x = xo + torch.min(torch.max(x - xo, -w), w)
        x.requires_grad_()

    model.optimizer.zero_grad()
    return x

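
# A minimal, standalone sketch of the same momentum-sign update on a toy classifier,
# using plain torch in place of the h/ai helpers above. The model, radius w, step count
# k and momentum mu below are assumed illustrative values, not the attack's defaults.
import torch
import torch.nn as nn
import torch.nn.functional as F

def momentum_sign_attack(model, xo, target, w=0.1, k=20, mu=1.0):
    x = xo.clone().detach().requires_grad_(True)
    g = torch.zeros_like(xo)
    for _ in range(k):
        loss = F.cross_entropy(model(x), target)
        grad, = torch.autograd.grad(loss, x)
        g = mu * g + grad / grad.abs().sum()          # momentum over L1-normalized gradients
        with torch.no_grad():
            # sign step, then projection onto the L-infinity ball around xo
            x = xo + (x + (w / k) * g.sign() - xo).clamp(-w, w)
        x.requires_grad_(True)
    return x.detach()

toy_model = nn.Linear(10, 3)
xo = torch.randn(8, 10)
target = torch.randint(0, 3, (8,))
x_adv = momentum_sign_attack(toy_model, xo, target)
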

def softplus(self):
    if self.errors is None:
        if self.beta is None:
            return self.new(F.softplus(self.head), None, None)
        # Interval case: softplus is monotone, so the endpoints give exact bounds.
        tp = F.softplus(self.head + self.beta)
        bt = F.softplus(self.head - self.beta)
        return self.new((tp + bt) / 2, (tp - bt) / 2, None)

    errors = self.concreteErrors()
    o = h.ones(self.head.size())

    def sp(hd):  # softplus
        return F.softplus(hd)  # torch.log(o + torch.exp(hd)) is not very stable

    def spp(hd):  # first derivative of softplus (the sigmoid)
        ehd = torch.exp(hd)
        return ehd.div(ehd + o)

    def sppp(hd):  # second derivative of softplus
        ehd = torch.exp(hd)
        md = ehd + o
        return ehd.div(md.mul(md))

    fa = sp(self.head)
    fpa = spp(self.head)

    a = self.head
    k = torch.sum(errors.abs(), 0)

    # Bound the second-order remainder of the tangent-line approximation at the head.
    def evalG(r):
        return r.mul(r).mul(sppp(a + r))

    m = torch.max(evalG(h.zeros(k.size())), torch.max(evalG(k), evalG(-k)))
    m = h.ifThenElse(a.abs().lt(k), torch.max(m, torch.max(evalG(a), evalG(-a))), m)
    m /= 2

    return self.new(fa,
                    m if self.beta is None else m + self.beta.mul(fpa),
                    None if self.errors is None else self.errors.mul(fpa))

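
# Quick standalone check, not part of the codebase, of the interval branch above: since
# softplus is monotone increasing, the returned midpoint/half-width pair is exactly the
# image of [h - b, h + b]. The center h0 and radius b below are arbitrary test values.
import torch
import torch.nn.functional as F

h0, b = torch.tensor(0.3), torch.tensor(0.8)
mid = (F.softplus(h0 + b) + F.softplus(h0 - b)) / 2
rad = (F.softplus(h0 + b) - F.softplus(h0 - b)) / 2
x = torch.linspace((h0 - b).item(), (h0 + b).item(), 1001)
y = F.softplus(x)
assert y.min() >= mid - rad - 1e-6 and y.max() <= mid + rad + 1e-6
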

def train_epoch(epoch, model, victim_model, attack, args, train_loader):
    vargs = vars(args)
    model.train()
    print("Cur ratio: {}".format(S.TrainInfo.cur_ratio))

    # Rebalance the mixing weights between the concrete (Point) goal and the abstract
    # goal according to the current ratio.
    assert isinstance(model.ty, goals.DList) and len(model.ty.al) == 2
    for (i, a) in enumerate(model.ty.al):
        if not isinstance(a[0], goals.Point):
            model.ty.al[i] = (a[0], S.Const(args.train_lambda * S.TrainInfo.cur_ratio))
        else:
            model.ty.al[i] = (a[0], S.Const(1 - args.train_lambda * S.TrainInfo.cur_ratio))

    for batch_idx, (data, target) in enumerate(train_loader):
        S.TrainInfo.total_batches_seen += 1
        time = float(S.TrainInfo.total_batches_seen) / len(train_loader)
        data, target = data.to(h.device), target.to(h.device)

        model.global_num += data.size()[0]
        lossy = 0

        # Optionally replace part of the batch with adversarial examples.
        adv_time = sys_time.time()
        if args.adv_train_num > 0:
            data, target = adv_batch(victim_model, attack, data, target, args.adv_train_num)
        adv_time = sys_time.time() - adv_time

        timer = Timer("train a sample from " + model.name + " with " + model.ty.name,
                      data.size()[0], False)
        with timer:
            for s in model.boxSpec(data.to_dtype(), target, time=time):
                model.optimizer.zero_grad()
                loss = model.aiLoss(*s, time=time, **vargs).mean(dim=0)
                lossy += loss.detach().item()
                loss.backward()

                torch.nn.utils.clip_grad_norm_(model.parameters(), 5)

                # Replace any NaN gradients with small random values so the step can proceed.
                for p in model.parameters():
                    if not p.requires_grad:
                        continue
                    if p is not None and torch.isnan(p).any():
                        print("Such nan in vals")
                    if p is not None and p.grad is not None and torch.isnan(p.grad).any():
                        print("Such nan in postmagic")
                        stdv = 1 / math.sqrt(h.product(p.data.shape))
                        p.grad = torch.where(torch.isnan(p.grad),
                                             torch.normal(mean=h.zeros(p.grad.shape), std=stdv),
                                             p.grad)

                model.optimizer.step()

                # Replace any NaN parameters produced by the step as well.
                for p in model.parameters():
                    if not p.requires_grad:
                        continue
                    if p is not None and torch.isnan(p).any():
                        print("Such nan in vals after grad")
                        stdv = 1 / math.sqrt(h.product(p.data.shape))
                        p.data = torch.where(torch.isnan(p.data),
                                             torch.normal(mean=h.zeros(p.data.shape), std=stdv),
                                             p.data)

                if args.clip_norm:
                    model.clip_norm()

                for p in model.parameters():
                    if not p.requires_grad:
                        continue
                    if p is not None and torch.isnan(p).any():
                        raise Exception("Such nan in vals after clip")

        model.addSpeed(timer.getUnitTime() + adv_time / len(data))

        if batch_idx % args.log_interval == 0:
            print(('Train Epoch {:12} Mix(a=Point(),b=Box(),aw=1,bw=0) {:3} '
                   '[{:7}/{} ({:.0f}%)] \tAvg sec/ex {:1.8f}\tLoss: {:.6f}').format(
                model.name, epoch,
                batch_idx * len(data) // (args.adv_train_num + 1),
                len(train_loader.dataset),
                100. * batch_idx / len(train_loader),
                model.speed, lossy))