def train(epoch):
    """Run one epoch of (optionally adversarial) contrastive pre-training.

    Uses module-level state: ``args``, ``model``, ``projector``,
    ``trainloader``, ``train_sampler``, ``optimizer``, ``Rep``,
    ``scheduler_warmup``, ``ngpus_per_node``, ``multi_gpu``.

    Returns:
        tuple: (mean total loss, mean contrastive loss) over the epoch.
    """
    # Only the first rank on each node prints, to avoid duplicated output.
    if args.local_rank % ngpus_per_node == 0:
        print('\nEpoch: %d' % epoch)
    model.train()
    projector.train()
    # Reshuffle the distributed sampler so each epoch sees a new ordering.
    train_sampler.set_epoch(epoch)

    total_loss = 0
    reg_simloss = 0
    reg_loss = 0

    for batch_idx, (ori, inputs_1, inputs_2, label) in enumerate(trainloader):
        ori, inputs_1, inputs_2 = ori.cuda(), inputs_1.cuda(), inputs_2.cuda()

        # The adversarial attack targets either the first or second view.
        if args.attack_to == 'original':
            attack_target = inputs_1
        else:
            attack_target = inputs_2

        if 'Rep' in args.advtrain_type:
            # Representation-space attack: returns adversarial images plus
            # the attack's regularization loss term.
            advinputs, adv_loss = Rep.get_loss(original_images=inputs_1,
                                               target=attack_target,
                                               optimizer=optimizer,
                                               weight=args.lamda,
                                               random_start=args.random_start)
            reg_loss += adv_loss.data

        # NOTE(review): if advtrain_type is neither 'None' nor contains
        # 'Rep', `advinputs`/`adv_loss` below are unbound — presumably the
        # supported values are exactly 'None' and 'Rep*'; confirm upstream.
        if not (args.advtrain_type == 'None'):
            inputs = torch.cat((inputs_1, inputs_2, advinputs))
        else:
            inputs = torch.cat((inputs_1, inputs_2))

        outputs = projector(model(inputs))
        similarity, gathered_outputs = pairwise_similarity(
            outputs,
            temperature=args.temperature,
            multi_gpu=multi_gpu,
            adv_type=args.advtrain_type)
        simloss = NT_xent(similarity, args.advtrain_type)

        if not (args.advtrain_type == 'None'):
            loss = simloss + adv_loss
        else:
            loss = simloss

        optimizer.zero_grad()
        loss.backward()
        total_loss += loss.data
        reg_simloss += simloss.data
        optimizer.step()

        if (args.local_rank % ngpus_per_node == 0):
            if 'Rep' in args.advtrain_type:
                progress_bar(batch_idx, len(trainloader),
                             'Loss: %.3f | SimLoss: %.3f | Adv: %.2f'
                             % (total_loss / (batch_idx + 1),
                                reg_simloss / (batch_idx + 1),
                                reg_loss / (batch_idx + 1)))
            else:
                progress_bar(batch_idx, len(trainloader),
                             'Loss: %.3f | Adv: %.3f'
                             % (total_loss / (batch_idx + 1),
                                reg_simloss / (batch_idx + 1)))

    scheduler_warmup.step()
    # BUGFIX: previously divided by `batch_idx` (the last 0-based index),
    # which biases the average and raises ZeroDivisionError for a
    # single-batch loader; the batch count is `batch_idx + 1`, matching the
    # in-loop progress-bar averages.
    return (total_loss / (batch_idx + 1), reg_simloss / (batch_idx + 1))
def linear_train(epoch, model, Linear, projector, loptim, attacker=None):
    """Train the linear evaluation head (optionally fine-tuning the encoder).

    Args:
        epoch: current epoch number (for logging only).
        model: feature encoder.
        Linear: linear classification head being trained.
        projector: projection head, used only when ``args.ss`` adds a
            self-supervised loss term.
        loptim: optimizer over the trainable parameters.
        attacker: adversarial example generator, required when
            ``args.adv_img`` is set.

    Returns:
        tuple: (final train accuracy, model, Linear, projector, loptim).

    Uses module-level state: ``args``, ``trainloader``, ``criterion``.
    """
    Linear.train()
    if args.finetune:
        model.train()
        if args.ss:
            projector.train()
    else:
        # Frozen-encoder linear evaluation.
        model.eval()

    total_loss = 0
    correct = 0
    total = 0

    for batch_idx, (ori, inputs, inputs_2, target) in enumerate(trainloader):
        ori, inputs_1, inputs_2, target = (
            ori.cuda(), inputs.cuda(), inputs_2.cuda(), target.cuda())
        # Tracks whether any data-type flag produced a usable batch.
        input_flag = False

        # Choose augmented or original images as the "clean" input.
        if args.trans:
            inputs = inputs_1
        else:
            inputs = ori

        if args.adv_img:
            advinputs = attacker.perturb(original_images=inputs,
                                         labels=target,
                                         random_start=args.random_start)

        if args.clean:
            total_inputs = inputs
            total_targets = target
            input_flag = True
        if args.ss:
            # Two views concatenated for the contrastive term; targets are
            # duplicated to stay aligned with the doubled batch.
            total_inputs = torch.cat((inputs, inputs_2))
            total_targets = torch.cat((target, target))
            input_flag = True
        if args.adv_img:
            if input_flag:
                total_inputs = torch.cat((total_inputs, advinputs))
                total_targets = torch.cat((total_targets, target))
            else:
                total_inputs = advinputs
                total_targets = target
                input_flag = True

        # BUGFIX: the original `assert ('choose ...')` asserted a non-empty
        # string literal, which is always true — the guard never fired and a
        # mis-configured run crashed later with a NameError on total_inputs.
        assert input_flag, \
            'choose the linear evaluation data type (clean, adv_img)'

        feat = model(total_inputs)

        if args.ss:
            output_p = projector(feat)
            B = ori.size(0)
            # Contrastive loss over the first two views (2*B rows).
            similarity, _ = pairwise_similarity(output_p[:2 * B, :2 * B],
                                                temperature=args.temperature,
                                                multi_gpu=False,
                                                adv_type='None')
            simloss = NT_xent(similarity, 'None')

        output = Linear(feat)
        _, predx = torch.max(output.data, 1)
        loss = criterion(output, total_targets)
        if args.ss:
            loss += simloss

        correct += predx.eq(total_targets.data).cpu().sum().item()
        total += total_targets.size(0)
        acc = 100. * correct / total

        total_loss += loss.data
        loptim.zero_grad()
        loss.backward()
        loptim.step()

        progress_bar(
            batch_idx, len(trainloader),
            'Loss: {:.4f} | Acc: {:.2f}'.format(total_loss / (batch_idx + 1),
                                                acc))

    print("Epoch: {}, train accuracy: {}".format(epoch, acc))
    return acc, model, Linear, projector, loptim
def get_loss(self, original_images, target, optimizer, weight, random_start=True):
    """Generate adversarial examples via iterative gradient ascent and
    return them together with the adversarial training loss.

    Args:
        original_images: clean input batch to perturb.
        target: reference batch the perturbation is measured against.
        optimizer: training optimizer; its gradients are zeroed before the
            final loss is computed.
        weight: divisor applied to the 'sim' loss (loss is scaled by
            ``1.0 / weight``).
        random_start: if True, initialize the perturbation uniformly in
            ``[-epsilon, epsilon]`` before the attack iterations.

    Returns:
        tuple: (detached adversarial images, loss tensor for backprop).
    """
    if random_start:
        # Random initialization inside the epsilon-ball, clamped to the
        # valid pixel range.
        rand_perturb = torch.FloatTensor(original_images.shape).uniform_(
            -self.epsilon, self.epsilon)
        rand_perturb = rand_perturb.float().cuda()
        x = original_images.float().clone() + rand_perturb
        x = torch.clamp(x, self.min_val, self.max_val)
    else:
        x = original_images.clone()
    # Leaf tensor with gradient tracking so autograd.grad can differentiate
    # the attack loss w.r.t. the perturbed images.
    x.requires_grad = True

    # Attack runs with the networks in eval mode (e.g. frozen BatchNorm).
    self.model.eval()
    self.projector.eval()
    batch_size = len(x)

    with torch.enable_grad():
        for _iter in range(self.max_iters):
            self.model.zero_grad()
            self.projector.zero_grad()
            # NOTE(review): an unrecognized self.loss_type would leave
            # `loss` unbound here — presumably loss_type is validated at
            # construction; confirm.
            if self.loss_type == 'mse':
                loss = F.mse_loss(self.projector(self.model(x)),
                                  self.projector(self.model(target)))
            elif self.loss_type == 'sim':
                # Contrastive (NT-Xent) loss between perturbed and target
                # batches.
                inputs = torch.cat((x, target))
                output = self.projector(self.model(inputs))
                similarity, _ = pairwise_similarity(output,
                                                    temperature=0.5,
                                                    multi_gpu=False,
                                                    adv_type='None')
                loss = NT_xent(similarity, 'None')
            elif self.loss_type == 'l1':
                loss = F.l1_loss(self.projector(self.model(x)),
                                 self.projector(self.model(target)))
            elif self.loss_type == 'cos':
                loss = F.cosine_similarity(
                    self.projector(self.model(x)),
                    self.projector(self.model(target))).mean()

            grads = torch.autograd.grad(loss, x, grad_outputs=None,
                                        only_inputs=True,
                                        retain_graph=False)[0]

            if self._type == 'linf':
                # Signed-gradient (FGSM-style) ascent step of size alpha.
                scaled_g = torch.sign(grads.data)
                x.data += self.alpha * scaled_g

            # Keep iterates in the valid pixel range and inside the
            # epsilon-ball around the original images.
            x = torch.clamp(x, self.min_val, self.max_val)
            x = project(x, original_images, self.epsilon, self._type)

    # Restore training mode before computing the loss used for the
    # model update.
    self.model.train()
    self.projector.train()
    optimizer.zero_grad()

    # Recompute the loss on the final adversarial images; distance losses
    # are normalized by batch size, 'sim' is scaled by 1/weight.
    if self.loss_type == 'mse':
        loss = F.mse_loss(self.projector(self.model(x)), self.projector(
            self.model(target))) * (1.0 / batch_size)
    elif self.loss_type == 'sim':
        if self.regularize == 'original':
            inputs = torch.cat((x, original_images))
        else:
            inputs = torch.cat((x, target))
        output = self.projector(self.model(inputs))
        similarity, _ = pairwise_similarity(output,
                                            temperature=0.5,
                                            multi_gpu=False,
                                            adv_type='None')
        loss = (1.0 / weight) * NT_xent(similarity, 'None')
    elif self.loss_type == 'l1':
        loss = F.l1_loss(self.projector(self.model(x)), self.projector(
            self.model(target))) * (1.0 / batch_size)
    elif self.loss_type == 'cos':
        loss = F.cosine_similarity(
            self.projector(self.model(x)),
            self.projector(self.model(target))).sum() * (1.0 / batch_size)

    # Detach so downstream training does not backprop through the attack.
    return x.detach(), loss