def opmaxmin(cla, gan, eps, im_size=784, embed_feats=256, num_images=50, z_lr=5e-3, lambda_lr=1e-4,num_steps=1000, batch_num=None, ind=None):
    """Search for pairs of GAN-generated images that are close in pixel space
    yet classified differently ("adversarial pairs").

    For each pair, two latent codes (z1, z2) are optimized to maximize the
    classifier disagreement ce(cla(x1), cla(x2)) subject to
    ||x1 - x2|| <= eps, enforced via a Lagrange multiplier `lambda_` that is
    trained adversarially by a second optimizer (max-min scheme).

    Args:
        cla: classifier network; called as cla(x) -> logits.
        gan: generator/decoder; called as gan(z) -> image batch.
            NOTE(review): output is assumed to be (B, 1, 28, 28) MNIST-shaped,
            matching the buffers allocated below — confirm against caller.
        eps: pixel-space L2 budget for a pair to count as "feasible".
        im_size: unused in this body (784 = 28*28; presumably legacy).
        embed_feats: dimensionality of the GAN latent space.
        num_images: number of pairs to search for (one per outer iteration,
            since BATCH_SIZE is fixed to 1).
        z_lr, lambda_lr: learning rates for the latent and multiplier optimizers.
        num_steps: inner optimization steps per pair.
        batch_num, ind: unused in this body (accepted for caller compatibility).

    Returns:
        (batch1, batch2, is_valid): detached image batches and a 0/1 mask
        marking which rows are adversarial AND within the eps budget.

    Requires CUDA; `ch` (torch), `YFOptimizer`, and `ce` come from module scope.
    """
    # NOTE(review): softmax/logsoftmax are created but never used below.
    softmax = ch.nn.Softmax()
    logsoftmax = ch.nn.LogSoftmax()
    BATCH_SIZE = 1
    # Output buffers: one row per requested pair, plus a validity mask.
    batch1 = ch.zeros((num_images, 1,28,28)).cuda()
    batch2 = ch.zeros((num_images, 1,28,28)).cuda()
    is_valid = ch.zeros(num_images).cuda()
    count = 0
    EPS = eps
    for i in range(num_images//BATCH_SIZE):
        # Fresh random latents for this pair; both require grad for opt1.
        z1 = ch.Tensor(ch.rand(BATCH_SIZE,embed_feats)).cuda()
        z1.requires_grad = True
        z2 = ch.Tensor(ch.rand(z1.shape)).cuda()
        z2.requires_grad_()
        ones = ch.ones(z1.shape[0]).cuda()
        # Lagrange multiplier for the distance constraint, one per pair row.
        lambda_ = 1e0*ch.ones(z1.shape[0],1).cuda()
        lambda_.requires_grad = True
        # opt1 updates the latents (minimizes loss1); opt2 updates lambda_
        # on the negated constraint term (i.e. ascends, tightening it).
        opt1 = YFOptimizer([{'params':z1},{'params':z2}], lr=z_lr, clip_thresh=None, adapt_clip=False)
        opt2 = YFOptimizer([{'params':lambda_}], lr=lambda_lr, clip_thresh=None, adapt_clip=False)
        for j in range(num_steps):
            x1 = gan(z1)
            x2 = gan(z2)
            # Per-row L2 distance minus budget: <= 0 means the pair is feasible.
            distance_mat = ch.norm((x1-x2).view(x1.shape[0],-1),dim=-1,keepdim=False) - EPS*ones
            cla_res1 = cla(x1).argmax(dim=-1)
            cla_res2 = cla(x2).argmax(dim=-1)
            is_adv = 1 - (cla_res1==cla_res2).float()     # 1 where labels differ
            is_feasible = (distance_mat<=0).float()        # 1 where within budget
            not_valid = 1- (is_adv*is_feasible)            # mask of rows still being optimized
            # Early exit: every row in this (size-1) batch is an adversarial,
            # feasible pair — record it and stop the inner loop.
            if ch.sum(is_adv*is_feasible) == BATCH_SIZE:
                batch1[i*BATCH_SIZE:(i+1)*BATCH_SIZE,...] = x1
                batch2[i*BATCH_SIZE:(i+1)*BATCH_SIZE,...] = x2
                is_valid[i*BATCH_SIZE:(i+1)*BATCH_SIZE] = 1.
                break
            opt1.zero_grad()
            # Latent loss: maximize classifier disagreement (negated ce term),
            # penalize constraint violation via lambda_, plus small L2
            # regularization on both latents; averaged over still-unsolved rows.
            # NOTE(review): divides by ch.sum(not_valid) — would be 0/0 if all
            # rows were already valid, but that case breaks out above.
            loss1 = (-1.* ch.sum(ce(cla(gan(z1)),cla(gan(z2)),reduction=None)*not_valid) + \
                ch.sum(lambda_ * distance_mat*not_valid) + 1e-4*ch.sum(ch.norm(z1,dim=-1)*not_valid) +\
                1e-4*ch.sum(ch.norm(z2,dim=-1)*not_valid))/ch.sum(not_valid)
            # retain_graph=True: the same forward graph is reused by loss2 below.
            loss1.backward(retain_graph=True)
            opt1.step()
            # One ascent step on lambda_ per latent step (loop kept for easy tuning).
            for k in range(1):
                opt2.zero_grad()
                loss2 = -1.*ch.mean(lambda_ * distance_mat*(not_valid))
                loss2.backward()
                opt2.step()
                # Clamping of lambda_ is intentionally disabled:
                #lambda_ = lambda_.clamp(1e-3,1e5)
        # Record whatever the final iterate produced (overwrites identically if
        # the early-exit branch already stored this pair).
        batch1[i*BATCH_SIZE:(i+1)*BATCH_SIZE,...] = x1
        batch2[i*BATCH_SIZE:(i+1)*BATCH_SIZE,...] = x2
        is_valid[i*BATCH_SIZE:(i+1)*BATCH_SIZE] = is_adv * is_feasible
    count = ch.sum(is_valid)
    print('number of adversarial pairs found:%d\n'%(count))
    # Detach images so callers can backprop through a fresh classifier pass only.
    return batch1.detach(), batch2.detach(), is_valid
def se(x1, x2, reduction='mean'):
    """Per-sample squared L2 error between x1 and x2 (flattened per row).

    Args:
        x1, x2: tensors of identical shape; flattened to (batch, -1).
        reduction: 'sum' | 'mean' | anything else returns the per-sample vector
            (callers in this file pass reduction=None for the unreduced form).

    Returns:
        Scalar tensor for 'sum'/'mean', else a (batch,) tensor of squared norms.
    """
    y = ch.norm((x1-x2).view(x1.shape[0],-1),dim=-1,p=2)**2
    if reduction=='sum':
        return ch.sum(y)
    elif reduction=='mean':
        return ch.mean(y)
    else:
        return y

# ---------------------------------------------------------------------------
# Top-level training loop. Relies on module-scope objects defined elsewhere in
# this file: args, net (classifier), generator, opt (optimizer), ce, trainloader.
# ---------------------------------------------------------------------------
best_adv_acc = 0.
for ep in range(1,args.num_epochs+1):
    # Every args.op_iter epochs, mine adversarial pairs with the max-min attack
    # and take one gradient step pushing the classifier to agree on each pair.
    if args.op_attack and (ep-1) % args.op_iter ==0:
        net.train()
        opt.zero_grad()
        batch1, batch2, is_adv = opmaxmin(net,generator.decode,args.op_eps,num_images=50,\
            num_steps=500,embed_feats=args.op_embed_feats,z_lr=1e-4,lambda_lr=1e-4,ind=ep)
        if ch.sum(is_adv) > 0:
            # Mean disagreement loss over the valid pairs only, scaled by op_weight.
            loss = args.op_weight*ch.sum(ce(net(batch1), net(batch2),reduction=None)*is_adv)/ch.sum(is_adv)
            loss.backward()
            opt.step()
        else:
            # No valid pairs found this round; skip the regularization step.
            pass
    total_ims_seen = 0
    val_num_correct = 0
    val_num_total = 0
    for i, (images, labels) in enumerate(trainloader):
        # Optional cap on how many training images are consumed per epoch.
        # (Loop body continues beyond this visible chunk.)
        if args.dataset_size is not None and total_ims_seen > args.dataset_size: