Пример #1
0
def opmaxmin(cla, gan, eps, im_size=784, embed_feats=256, num_images=50, z_lr=5e-3, lambda_lr=1e-4,num_steps=1000, batch_num=None, ind=None):
    """Search for pairs of generated images that are eps-close but classified differently.

    For each of ``num_images`` slots, two latent codes ``z1``/``z2`` are drawn
    with ``ch.rand`` (uniform on [0, 1)) and optimised with a Lagrangian
    primal/dual scheme:

    * primal step (``opt1``): move ``z1``/``z2`` to increase the ``ce`` term
      between ``cla(gan(z1))`` and ``cla(gan(z2))``, plus
      ``lambda_ * (||gan(z1) - gan(z2)||_2 - eps)`` and a 1e-4 latent-norm
      penalty. ``ce`` is defined elsewhere — presumably a per-sample
      cross-entropy between the two classifier outputs (TODO confirm).
    * dual step (``opt2``): move ``lambda_`` to maximise the penalty
      ``lambda_ * (distance - eps)``.

    A slot's search stops as soon as the pair is both adversarial (different
    argmax predictions) and feasible (L2 distance <= eps).

    Gradients for the latent codes are taken with ``torch.autograd.grad``, and
    the multiplier update uses a detached distance. As a result the search
    does not accumulate ``.grad`` into the parameters of ``cla``/``gan``. The
    previous ``.backward()`` calls did, which polluted the gradients of a
    caller that zeroes its optimiser before calling this function.

    Args:
        cla: classifier, called on generator outputs; the argmax of its last
            dimension is used as the prediction.
        gan: generator called on a (1, embed_feats) latent batch. Its output
            is written into (1, 1, 28, 28) slots, so it presumably returns
            that shape — TODO confirm.
        eps: L2 radius of the feasibility constraint between the two images.
        im_size: unused; kept for interface compatibility.
        embed_feats: latent dimensionality.
        num_images: number of pair slots to fill.
        z_lr: learning rate of the YFOptimizer over ``z1``/``z2``.
        lambda_lr: learning rate of the YFOptimizer over the multiplier.
        num_steps: maximum optimisation steps per slot.
        batch_num, ind: unused; kept for interface compatibility.

    Returns:
        ``(batch1, batch2, is_valid)``:
        * ``batch1``, ``batch2``: detached (num_images, 1, 28, 28) tensors
          holding the last pair evaluated for each slot.
        * ``is_valid``: (num_images,) float mask, 1 where that pair is
          adversarial and feasible. Slots stay zero when ``num_steps <= 0``.
    """
    BATCH_SIZE = 1

    batch1 = ch.zeros((num_images, 1, 28, 28)).cuda()
    batch2 = ch.zeros((num_images, 1, 28, 28)).cuda()
    is_valid = ch.zeros(num_images).cuda()

    for i in range(num_images // BATCH_SIZE):
        rows = slice(i * BATCH_SIZE, (i + 1) * BATCH_SIZE)

        # Fresh random latent pair for this slot (drawn on the CPU RNG, then moved).
        z1 = ch.rand(BATCH_SIZE, embed_feats).cuda()
        z1.requires_grad_()
        z2 = ch.rand(z1.shape).cuda()
        z2.requires_grad_()

        # One Lagrange multiplier per pair, initialised to 1.
        lambda_ = ch.ones(z1.shape[0], 1).cuda()
        lambda_.requires_grad_()

        opt1 = YFOptimizer([{'params': z1}, {'params': z2}], lr=z_lr, clip_thresh=None, adapt_clip=False)
        opt2 = YFOptimizer([{'params': lambda_}], lr=lambda_lr, clip_thresh=None, adapt_clip=False)

        # Stay None if num_steps <= 0, so the slot is left zeroed instead of
        # raising NameError below.
        x1 = x2 = found = None
        for _ in range(num_steps):
            x1 = gan(z1)
            x2 = gan(z2)
            # Constraint value: <= 0 means the pair lies within eps (L2).
            # reshape (not view) tolerates non-contiguous generator outputs.
            distance_mat = ch.norm((x1 - x2).reshape(x1.shape[0], -1), dim=-1, keepdim=False) - eps

            # One classifier pass per image, reused for the prediction check
            # and for the loss (previously cla/gan were re-run for the loss).
            logits1 = cla(x1)
            logits2 = cla(x2)

            is_adv = 1 - (logits1.argmax(dim=-1) == logits2.argmax(dim=-1)).float()
            is_feasible = (distance_mat <= 0).float()
            found = is_adv * is_feasible
            if ch.sum(found) == BATCH_SIZE:
                break
            not_valid = 1 - found

            # Primal step on the latent codes. autograd.grad restricts the
            # gradient to z1/z2, so nothing is accumulated into cla/gan params.
            opt1.zero_grad()
            loss1 = (-1. * ch.sum(ce(logits1, logits2, reduction=None) * not_valid)
                     + ch.sum(lambda_ * distance_mat * not_valid)
                     + 1e-4 * ch.sum(ch.norm(z1, dim=-1) * not_valid)
                     + 1e-4 * ch.sum(ch.norm(z2, dim=-1) * not_valid)) / ch.sum(not_valid)
            z1.grad, z2.grad = ch.autograd.grad(loss1, [z1, z2])
            opt1.step()

            # Dual step on lambda_. The multiplier only needs the constraint
            # value, so the distance is detached: there is no backprop through
            # gan/cla and no need for retain_graph.
            opt2.zero_grad()
            loss2 = -1. * ch.mean(lambda_ * distance_mat.detach() * not_valid)
            loss2.backward()
            opt2.step()

        if x1 is not None:
            # Store detached copies so each slot's autograd graph is freed now
            # rather than kept alive inside batch1/batch2 until return.
            batch1[rows] = x1.detach()
            batch2[rows] = x2.detach()
            is_valid[rows] = found

    print('number of adversarial pairs found:%d\n' % int(is_valid.sum().item()))

    return batch1.detach(), batch2.detach(), is_valid
Пример #2
0

def se(x1, x2, reduction='mean'):
    """Squared Euclidean (L2) distance between paired samples.

    Everything after dim 0 is flattened per sample, and the squared L2 norm of
    ``x1 - x2`` is computed for each row.

    Args:
        x1, x2: tensors of the same shape whose first dimension indexes
            samples. They may be non-contiguous (e.g. permuted).
        reduction: ``'sum'`` returns the total over samples and ``'mean'``
            returns the average. Any other value (e.g. ``None``) returns the
            per-sample vector.

    Returns:
        A 0-dim tensor for ``'sum'``/``'mean'``; otherwise a 1-D tensor of
        length ``x1.shape[0]``.
    """
    diff = x1 - x2
    # reshape (not view): view raises when the difference is non-contiguous.
    # Summing squares directly avoids the sqrt-then-square round trip of
    # norm(...)**2.
    y = diff.reshape(x1.shape[0], -1).pow(2).sum(dim=-1)
    if reduction == 'sum':
        return ch.sum(y)
    elif reduction == 'mean':
        return ch.mean(y)
    else:
        return y

best_adv_acc = 0.
for ep in range(1,args.num_epochs+1):
    if args.op_attack and (ep-1) % args.op_iter ==0:
        net.train()
        opt.zero_grad()
        batch1, batch2, is_adv = opmaxmin(net,generator.decode,args.op_eps,num_images=50,\
                num_steps=500,embed_feats=args.op_embed_feats,z_lr=1e-4,lambda_lr=1e-4,ind=ep)
        if ch.sum(is_adv) > 0:
            loss = args.op_weight*ch.sum(ce(net(batch1), net(batch2),reduction=None)*is_adv)/ch.sum(is_adv)
            loss.backward()
            opt.step()
        else:
            pass
    
    total_ims_seen = 0
    val_num_correct = 0
    val_num_total = 0
    for i, (images, labels) in enumerate(trainloader):

        if args.dataset_size is not None and total_ims_seen > args.dataset_size: