def test(epoch, net):
    net.eval()
    test_loss = 0
    correct = 0
    total = 0
    xce = 0.
    iterator = tqdm(testloader, ncols=0, leave=False)
    for batch_idx, (inputs, targets) in enumerate(iterator):
        start_time = time.time()
        inputs, targets = inputs.to(device), targets.to(device)
        pert_inputs = inputs.detach()
        outputs, _, _, pert_inputs, pert_i = net(pert_inputs, targets, batch_idx=batch_idx)

        # cross entropy against one-hot targets, summed over the batch (for reporting)
        xce_batch = torch.sum(-utils.one_hot_tensor(targets, 10, device) *
                              F.log_softmax(outputs, dim=1)).item()

        loss = criterion(outputs, targets)
        test_loss += loss.item()
        duration = time.time() - start_time

        _, predicted = outputs.max(1)
        batch_size = targets.size(0)
        total += batch_size
        correct_num = predicted.eq(targets).sum().item()
        correct += correct_num
        iterator.set_description(str(correct_num / batch_size))
        xce += xce_batch

        if batch_idx % args.log_step == 0:
            print("step %d, duration %.2f, test acc %.2f, avg-acc %.2f, loss %.2f" %
                  (batch_idx, duration, 100. * correct_num / batch_size,
                   100. * correct / total, test_loss / total))

    acc = 100. * correct / total
    print('Val acc:', acc)
    xce = xce / total
    print('xce : ', xce)
    return acc
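# Sketch (not part of the original script): evaluating a saved checkpoint with
# test(). It relies on the same globals test() uses (testloader, criterion,
# device, args); ckpt_path is whatever the training functions wrote, e.g.
# os.path.join(args.model_dir, 'latest').
def evaluate_checkpoint(ckpt_path, net):
    ckpt = torch.load(ckpt_path, map_location=device)
    net.load_state_dict(ckpt['net'])
    return test(epoch=ckpt.get('epoch', 0), net=net)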
def attack_inversion(model, inputs, targets, config):
    step_size = config['step_size']
    epsilon = config['epsilon']
    num_steps = config['num_steps']
    ls_factor = config['ls_factor']

    model.eval()
    # reverse the batch order: each example is pushed toward the features of
    # its mirrored counterpart in the batch
    inv_idx = torch.arange(inputs.size(0) - 1, -1, -1).long()

    x = inputs.detach()
    x = x + torch.zeros_like(x).uniform_(-epsilon, epsilon)

    logits_pred_nat, fea_nat = model(inputs[inv_idx])
    fea_nat = fea_nat.detach()
    num_classes = logits_pred_nat.size(-1)

    for i in range(num_steps):
        x.requires_grad_()
        zero_gradients(x)
        if x.grad is not None:
            x.grad.data.fill_(0)
        logits_pred, fea = model(x)
        # cosine distance between current features and the mirrored natural features
        inver_loss = ot.pair_cos_dist(fea, fea_nat)
        adv_loss = inver_loss.mean()
        adv_loss.backward()
        # descent step: minimize feature distance, projected onto the epsilon ball
        x_adv = x.data - step_size * torch.sign(x.grad.data)
        x_adv = torch.min(torch.max(x_adv, inputs - epsilon), inputs + epsilon)
        x_adv = torch.clamp(x_adv, -1.0, 1.0)
        x = Variable(x_adv)

    # adaptive label smoothing: smoothing mass is spread uniformly over all
    # classes except the mirrored example's class
    targets_one_hot = one_hot_tensor(targets, num_classes, device)
    targets_one_hot_inv = targets_one_hot[inv_idx]
    soft_targets = (1 - targets_one_hot_inv) / (num_classes - 1)
    soft_targets = (1 - ls_factor) * targets_one_hot + ls_factor * soft_targets
    # non-adaptive alternative:
    # soft_targets = utils.label_smoothing(targets_one_hot, targets_one_hot.size(1), ls_factor)

    return x, soft_targets
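# Sketch of a call to attack_inversion(); the values below are illustrative
# assumptions (inputs are expected in [-1, 1] given the clamp above), not the
# repo's actual configuration.
config_inversion_example = {
    'step_size': 0.007 * 2,   # hypothetical per-step L-inf step size
    'epsilon': 0.031 * 2,     # hypothetical L-inf perturbation budget
    'num_steps': 10,
    'ls_factor': 0.5,
}
# x_inv, y_soft = attack_inversion(basic_net, inputs, targets, config_inversion_example)
# x_inv matches inputs in shape; y_soft is a (batch_size, num_classes) soft-label tensor.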
def forward(self, inputs, targets, attack=True, targeted_label=-1, batch_idx=0):
    if not attack:
        outputs, _ = self.basic_net(inputs)
        return outputs, None

    if self.box_type == 'white':
        # white-box: attack a copy of the model being trained
        aux_net = pickle.loads(pickle.dumps(self.basic_net))
    elif self.box_type == 'black':
        assert self.attack_net is not None, \
            "should provide an additional net in black-box case"
        # (this branch also attacks a copy of basic_net)
        aux_net = pickle.loads(pickle.dumps(self.basic_net))
    aux_net.eval()

    batch_size = inputs.size(0)
    m = batch_size
    n = batch_size

    x = inputs.detach()
    x = x + torch.zeros_like(x).uniform_(-self.epsilon, self.epsilon)

    if self.train_flag:
        self.basic_net.train()
    else:
        self.basic_net.eval()

    logits_pred_nat, fea_nat = aux_net(inputs)
    num_classes = logits_pred_nat.size(1)

    y_gt = one_hot_tensor(targets, num_classes, device)
    loss_ce = softCrossEntropy()

    for i in range(self.num_steps):
        x.requires_grad_()
        zero_gradients(x)
        if x.grad is not None:
            x.grad.data.fill_(0)
        logits_pred, fea = aux_net(x)
        # optimal-transport (Sinkhorn / IPOT) distance between natural and perturbed logits
        ot_loss = ot.sinkhorn_loss_joint_IPOT(1, 0.00, logits_pred_nat,
                                              logits_pred, None, None, 0.01, m, n)
        aux_net.zero_grad()
        adv_loss = ot_loss
        adv_loss.backward(retain_graph=True)
        # ascent step on the OT loss, projected onto the epsilon ball
        x_adv = x.data + self.step_size * torch.sign(x.grad.data)
        x_adv = torch.min(torch.max(x_adv, inputs - self.epsilon), inputs + self.epsilon)
        x_adv = torch.clamp(x_adv, -1.0, 1.0)
        x = Variable(x_adv)

    logits_pred, fea = self.basic_net(x)
    self.basic_net.zero_grad()

    y_sm = utils.label_smoothing(y_gt, y_gt.size(1), self.ls_factor)
    adv_loss = loss_ce(logits_pred, y_sm.detach())

    return logits_pred, adv_loss
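# Usage sketch for the forward() above (the wrapper name AttackWrapper is a
# placeholder; the real class and its constructor arguments are defined
# elsewhere in the repo):
#   net = AttackWrapper(basic_net, config)               # hypothetical constructor
#   logits_adv, adv_loss = net(inputs, targets)          # OT-based inner maximization
#   logits_nat, _ = net(inputs, targets, attack=False)   # clean forward pass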
def train_fun(epoch, net):
    print('\nEpoch: %d' % epoch)
    net.train()
    train_loss = 0
    correct = 0
    total = 0

    # piecewise-constant learning-rate schedule
    if epoch < args.decay_epoch1:
        lr = args.lr
    elif epoch < args.decay_epoch2:
        lr = args.lr * args.decay_rate
    else:
        lr = args.lr * args.decay_rate * args.decay_rate
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr

    def get_acc(outputs, targets):
        _, predicted = outputs.max(1)
        total = targets.size(0)
        correct = predicted.eq(targets).sum().item()
        acc = 1.0 * correct / total
        return acc

    iterator = tqdm(trainloader, ncols=0, leave=False)
    for batch_idx, (inputs, targets) in enumerate(iterator):
        start_time = time.time()

        if args.dataset == 'cifar_aug':
            # replace one fifth of the batch with shuffled augmented samples
            # and give those samples uniform (0.1) soft labels
            inputs_aug, targets_aug = next(iter(trainloader_aug))
            indices = np.random.permutation(targets_aug.size()[0])
            inputs_aug = inputs_aug[indices]
            inputs_orig, targets_orig = inputs.detach(), targets.detach()
            inputs[:args.batch_size_train // 5] = inputs_aug[:args.batch_size_train // 5]
            targets = one_hot_tensor(targets, 10, device)
            targets[:args.batch_size_train // 5, :] = 0.1

        inputs, targets = inputs.to(device), targets.to(device)

        adv_acc = 0
        optimizer.zero_grad()

        adv_mode = args.adv_mode.lower()
        if adv_mode in ('feature_scatter', 'lip_reg', 'trades'):
            outputs, loss_fs, flag_out, _, diff_loss = net(inputs.detach(), targets)
            loss = loss_fs
        elif adv_mode == 'madry':
            outputs, _, _, pert_inputs, pert_i, y_train = net(inputs, targets)
            loss = soft_xent_loss(outputs, y_train)
        elif adv_mode == 'vertex':
            outputs, _, _, _, _, y_vertex = net(inputs, targets)
            loss = soft_xent_loss(outputs, y_vertex)
        elif adv_mode == 'vertex_pgd':
            outputs, _, _, _, _, y_vertex = net(inputs, targets)
            loss = soft_xent_loss(outputs, y_vertex)
        elif adv_mode == 'natural':
            outputs, _, _, _, _ = net(inputs, targets)
            loss = F.cross_entropy(outputs, targets)
        elif adv_mode == 'linear':
            outputs, _, _, x_train, _ = net(inputs, targets)
            outputs, loss_fs, flag_out, _, diff_loss = net(inputs.detach(), targets)
            loss = loss_fs
        else:
            raise ValueError('unknown adv_mode: %s' % args.adv_mode)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        train_loss = loss.item()
        duration = time.time() - start_time

        if batch_idx % args.log_step == 0:
            if args.dataset == 'cifar_aug':
                inputs, targets = inputs_orig.to(device), targets_orig.to(device)
            if adv_acc == 0:
                adv_acc = get_acc(outputs, targets)
            iterator.set_description(str(adv_acc))
            nat_outputs, _, _, _, _ = net(inputs, targets, attack=False)
            nat_acc = get_acc(nat_outputs, targets)
            print(
                "epoch %d, step %d, lr %.4f, duration %.2f, training nat acc %.2f, training adv acc %.2f, training adv loss %.4f"
                % (epoch, batch_idx, lr, duration, 100 * nat_acc,
                   100 * adv_acc, train_loss))

    if epoch % 10 == 0:
        print('Saving..')
        f_path = os.path.join(args.model_dir, ('checkpoint-%s' % epoch))
        state = {
            'net': net.state_dict(),
        }
        if not os.path.isdir(args.model_dir):
            os.mkdir(args.model_dir)
        torch.save(state, f_path)

    if epoch >= 0:
        print('Saving latest @ epoch %s..' % (epoch))
        f_path = os.path.join(args.model_dir, 'latest')
        state = {
            'net': net.state_dict(),
            'epoch': epoch,
            'optimizer': optimizer.state_dict()
        }
        if not os.path.isdir(args.model_dir):
            os.mkdir(args.model_dir)
        torch.save(state, f_path)
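# Sketch of a driver loop tying train_fun() and test() together (assumes the
# optimizer, data loaders, and args used above already exist; the best-accuracy
# tracking is an illustrative addition, not part of the original script).
def run_training(net, start_epoch=0):
    best_acc = 0.0
    for epoch in range(start_epoch, args.max_epoch):
        train_fun(epoch, net)
        val_acc = test(epoch, net)
        best_acc = max(best_acc, val_acc)
    return best_acc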
def train_one_epoch(epoch, net):
    print('\n Training for Epoch: %d' % epoch)
    net.train()

    # learning-rate schedule
    if epoch < args.decay_epoch1:
        lr = args.lr
    elif epoch < args.decay_epoch2:
        lr = args.lr * args.decay_rate
    else:
        lr = args.lr * args.decay_rate * args.decay_rate
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr

    iterator = tqdm(trainloader, ncols=0, leave=False)
    for batch_idx, (inputs, targets) in enumerate(iterator):
        start_time = time.time()
        inputs, targets = inputs.to(device), targets.to(device)
        targets_onehot = one_hot_tensor(targets, args.num_classes, device)

        # adversarial interpolation: generate interpolated inputs and soft labels
        x_tilde, y_tilde = adv_interp(inputs, targets_onehot, net,
                                      args.num_classes,
                                      config_adv_interp['epsilon'],
                                      config_adv_interp['label_adv_delta'],
                                      config_adv_interp['v_min'],
                                      config_adv_interp['v_max'])

        outputs = net(x_tilde, mode='logits')
        loss = soft_xent_loss(outputs, y_tilde)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        train_loss = loss.detach().item()
        duration = time.time() - start_time

        if batch_idx % args.log_step == 0:
            adv_acc = utils.get_acc(outputs, targets)
            # natural accuracy, computed on a deep copy so this extra forward
            # pass does not touch the training net's batch-norm statistics
            net_cp = copy.deepcopy(net)
            nat_outputs = net_cp(inputs, mode='logits')
            nat_acc = utils.get_acc(nat_outputs, targets)
            print(
                "Epoch %d, Step %d, lr %.4f, Duration %.2f, Training nat acc %.2f, Training adv acc %.2f, Training adv loss %.4f"
                % (epoch, batch_idx, lr, duration, 100 * nat_acc,
                   100 * adv_acc, train_loss))

    if epoch % args.save_epochs == 0 or epoch >= args.max_epoch - 2:
        print('Saving..')
        f_path = os.path.join(args.model_dir, ('checkpoint-%s' % epoch))
        state = {
            'net': net.state_dict(),
            'epoch': epoch,
        }
        if not os.path.isdir(args.model_dir):
            os.makedirs(args.model_dir)
        torch.save(state, f_path)

    if epoch >= 1:
        print('Saving latest model for epoch %s..' % (epoch))
        f_path = os.path.join(args.model_dir, 'latest')
        state = {
            'net': net.state_dict(),
            'epoch': epoch,
        }
        if not os.path.isdir(args.model_dir):
            os.mkdir(args.model_dir)
        torch.save(state, f_path)
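# For reference, a sketch of the config_adv_interp dictionary consumed by
# adv_interp() in train_one_epoch(). The keys are the ones read above; the
# values are illustrative assumptions, not the repo's defaults.
config_adv_interp_example = {
    'epsilon': 8.0 / 255 * 2,   # hypothetical L-inf budget (inputs scaled to [-1, 1])
    'label_adv_delta': 0.5,     # hypothetical label-interpolation strength
    'v_min': -1.0,              # hypothetical lower clamp for interpolated inputs
    'v_max': 1.0,               # hypothetical upper clamp
}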