def forward(self, x, target=None, mixup=False, mixup_hidden=False, args=None,
            grad=None, noise=None, adv_mask1=0, adv_mask2=0, mp=None):
    """Forward pass with optional mixup at the input or a hidden layer.

    Args:
        x: input image batch.
        target: optional integer labels. When given, they are one-hot
            encoded and mixed alongside the activations, and the mixed
            targets are returned together with the logits.
        mixup: if True, mix at the input (layer 0).
        mixup_hidden: if True, mix at a random layer drawn from {0, 1, 2}.
        args, grad, noise, adv_mask1, adv_mask2, mp: forwarded verbatim to
            ``mixup_process`` for the input-level mixing step.

    Returns:
        logits, or ``(logits, target_reweighted)`` when ``target`` is given.
    """
    # Decide where mixing happens: a random hidden layer, the input, or not at all.
    if mixup_hidden:
        layer_mix = random.randint(0, 2)
    elif mixup:
        layer_mix = 0
    else:
        layer_mix = None

    out = x
    if target is not None:
        target_reweighted = to_one_hot(target, self.num_classes)

    # NOTE(review): mixing requires `target is not None`; otherwise
    # `target_reweighted` below would be unbound. Callers appear to always
    # pass targets when mixup is enabled — confirm.
    if layer_mix == 0:
        out, target_reweighted = mixup_process(out, target_reweighted,
                                               args=args, grad=grad,
                                               noise=noise,
                                               adv_mask1=adv_mask1,
                                               adv_mask2=adv_mask2, mp=mp)

    out = self.conv1(out)
    out = self.layer1(out)
    if layer_mix == 1:
        out, target_reweighted = mixup_process(out, target_reweighted,
                                               args=args, hidden=True)
    out = self.layer2(out)
    if layer_mix == 2:
        out, target_reweighted = mixup_process(out, target_reweighted,
                                               args=args, hidden=True)
    out = self.layer3(out)
    # A `layer_mix == 3` mixing step existed here but was unreachable:
    # `random.randint(0, 2)` is inclusive and never yields 3, so the dead
    # branch has been removed.

    out = act(self.bn1(out))
    out = F.avg_pool2d(out, 8)
    out = out.reshape(out.size(0), -1)
    out = self.linear(out)

    if target is not None:
        return out, target_reweighted
    return out
def forward(self, x, target=None, mixup=False, mixup_hidden=False, args=None,
            grad=None, noise=None, adv_mask1=0, adv_mask2=0, profile=None):
    """Forward pass with optional mixup, instrumented for timing.

    Args:
        x: input image batch.
        target: optional integer labels; when given, mixed targets are
            returned with the logits.
        mixup: if True, mix at the input (layer 0).
        mixup_hidden: if True, mix at a random layer drawn from {0, 1, 2}.
        args, grad, noise, adv_mask1, adv_mask2: forwarded to ``mixup_process``.
        profile: optional dict with 'gc' and 'data' lists used to record
            the wall-clock cost of the input-level mixing step.

    Returns:
        logits, or ``(logits, target_reweighted)`` when ``target`` is given.
    """
    if mixup_hidden:
        layer_mix = random.randint(0, 2)
    elif mixup:
        layer_mix = 0
    else:
        layer_mix = None

    out = x
    if target is not None:
        target_reweighted = to_one_hot(target, self.num_classes)

    end = time.time()
    if layer_mix == 0:
        out, target_reweighted = mixup_process(out, target_reweighted,
                                               args=args, grad=grad,
                                               noise=noise,
                                               adv_mask1=adv_mask1,
                                               adv_mask2=adv_mask2)
        # Fix: `profile` defaults to None, so dereferencing it unguarded
        # raised TypeError whenever no profiling dict was supplied.
        if profile is not None:
            profile['gc'].append(time.time() - end)
            # Report the running mean of the mixing cost every 10 batches.
            if len(profile['data']) % 10 == 0:
                print('gc : ', np.mean(profile['gc']))

    out = self.conv1(out)
    out = self.layer1(out)
    if layer_mix == 1:
        out, target_reweighted = mixup_process(out, target_reweighted,
                                               args=args, hidden=True)
    out = self.layer2(out)
    if layer_mix == 2:
        out, target_reweighted = mixup_process(out, target_reweighted,
                                               args=args, hidden=True)
    out = self.layer3(out)
    # A `layer_mix == 3` mixing step existed here but was unreachable:
    # `random.randint(0, 2)` never yields 3, so the dead branch was removed.
    out = self.layer4(out)

    out = F.avg_pool2d(out, 4)
    out = out.view(out.size(0), -1)
    out = self.linear(out)

    if target is not None:
        return out, target_reweighted
    return out
def train(train_loader, model, optimizer, epoch, args, log, mpp=None):
    '''train given model and dataloader'''
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()
    mixing_avg = []

    # switch to train mode
    model.train()

    end = time.time()
    for input, target in train_loader:
        data_time.update(time.time() - end)
        optimizer.zero_grad()

        input = input.cuda()
        target = target.long().cuda()
        sc = None

        if not args.comix:
            # Plain training: soft BCE against one-hot targets.
            target_reweighted = to_one_hot(target, args.num_classes)
            output = model(input)
            loss = bce_loss(softmax(output), target_reweighted)
        else:
            # Co-Mixup training path.
            input_var = Variable(input, requires_grad=True)
            target_var = Variable(target)
            A_dist = None

            # Saliency (unary) pass: per-sample loss gradients w.r.t. the input.
            if args.clean_lam == 0:
                model.eval()
                output = model(input_var)
                loss_batch = criterion_batch(output, target_var)
            else:
                model.train()
                output = model(input_var)
                loss_batch = 2 * args.clean_lam * criterion_batch(
                    output, target_var) / args.num_classes
            loss_batch_mean = torch.mean(loss_batch, dim=0)
            loss_batch_mean.backward(retain_graph=True)
            # RMS of the input gradient across channels -> saliency map.
            sc = torch.sqrt(torch.mean(input_var.grad**2, dim=1))

            # Compatibility term: pairwise L1 distance between each image's
            # most-salient pooled location.
            with torch.no_grad():
                z = F.avg_pool2d(sc, kernel_size=8, stride=1)
                z_reshape = z.reshape(args.batch_size, -1)
                z_idx_1d = torch.argmax(z_reshape, dim=1)
                z_idx_2d = torch.zeros((args.batch_size, 2), device=z.device)
                z_idx_2d[:, 0] = z_idx_1d // z.shape[-1]
                z_idx_2d[:, 1] = z_idx_1d % z.shape[-1]
                A_dist = distance(z_idx_2d, dist_type='l1')

            if args.clean_lam == 0:
                model.train()
                optimizer.zero_grad()

            # Mix the batch (optionally via the parallel worker) and compute loss.
            target_reweighted = to_one_hot(target, args.num_classes)
            if args.parallel:
                device = input.device
                out, target_reweighted = mpp(input.cpu(),
                                             target_reweighted.cpu(),
                                             args=args,
                                             sc=sc.cpu(),
                                             A_dist=A_dist.cpu())
                out = out.to(device)
                target_reweighted = target_reweighted.to(device)
            else:
                out, target_reweighted = mixup_process(input,
                                                       target_reweighted,
                                                       args=args,
                                                       sc=sc,
                                                       A_dist=A_dist)
            out = model(out)
            loss = bce_loss(softmax(out), target_reweighted)

        # NOTE(review): accuracy is computed on `output` (the clean/saliency
        # pass), not on the mixed-batch logits — presumably intentional so
        # train accuracy reflects clean images; confirm.
        prec1, prec5 = accuracy(output, target, topk=(1, 5))
        losses.update(loss.item(), input.size(0))
        top1.update(prec1.item(), input.size(0))
        top5.update(prec5.item(), input.size(0))

        # compute gradient and do SGD step
        loss.backward()
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

    print_log(
        '**Train** Prec@1 {top1.avg:.2f} Prec@5 {top5.avg:.2f} Error@1 {error1:.2f}'
        .format(top1=top1, top5=top5, error1=100 - top1.avg), log)
    return top1.avg, top5.avg, losses.avg
def forward(self, x, target=None, mixup=False, mixup_hidden=False, args=None,
            grad=None, noise=None, adv_mask1=0, adv_mask2=0):
    """Forward pass with optional mixup and optional feature normalization.

    Args:
        x: input image batch.
        target: optional integer labels; when given, mixed targets are
            returned with the logits.
        mixup: if True, mix at the input (layer 0).
        mixup_hidden: if True, mix at a random layer drawn from {0, 1, 2}.
        args, grad, noise, adv_mask1, adv_mask2: forwarded to ``mixup_process``.

    Returns:
        logits, or ``(logits, target_reweighted)`` when ``target`` is given.
    """
    if mixup_hidden:
        layer_mix = random.randint(0, 2)
    elif mixup:
        layer_mix = 0
    else:
        layer_mix = None

    out = x
    if target is not None:
        target_reweighted = to_one_hot(target, self.num_classes)

    if layer_mix == 0:
        out, target_reweighted = mixup_process(out, target_reweighted,
                                               args=args, grad=grad,
                                               noise=noise,
                                               adv_mask1=adv_mask1,
                                               adv_mask2=adv_mask2)

    out = self.conv1(out)
    out = self.layer1(out)
    if layer_mix == 1:
        out, target_reweighted = mixup_process(out, target_reweighted,
                                               args=args, hidden=True)
    out = self.layer2(out)
    if layer_mix == 2:
        out, target_reweighted = mixup_process(out, target_reweighted,
                                               args=args, hidden=True)
    out = self.layer3(out)
    # A `layer_mix == 3` mixing step existed here but was unreachable:
    # `random.randint(0, 2)` never yields 3, so the dead branch was removed.
    out = self.layer4(out)

    # Global average pool over whatever spatial size remains.
    out = F.avg_pool2d(out, out.shape[-1])
    out = out.view(out.size(0), -1)
    # Optional L2 normalization of the feature vector before the classifier.
    if self.weight_type is not None:
        out = F.normalize(out, p=2, dim=1)
    out = self.linear(out)

    if target is not None:
        return out, target_reweighted
    return out
# Benchmark + equivalence check: parallel worker (mpp) vs. original
# mixup_process, comparing both against the reference results in `d`.
for _ in tqdm(range(100), desc="parallel"):
    out, target_reweighted = mpp(out0, target_reweighted0, args=args,
                                 sc=sc, A_dist=A_dist, debug=True)
print((d["out"].cpu() == out.cpu()).float().mean())
print((d["target_reweighted"].cpu() == target_reweighted.cpu()).float().mean())

# Original run
out0cuda = out0.cuda()
target_reweighted0cuda = target_reweighted0.cuda()
sccuda = sc.cuda()
A_distcuda = A_dist.cuda()
for _ in tqdm(range(100), desc="original"):
    out, target_reweighted = mixup_process(out0cuda, target_reweighted0cuda,
                                           args=args, sc=sccuda,
                                           A_dist=A_distcuda, debug=True)
print((d["out"].cpu() == out.cpu()).float().mean())
print((d["target_reweighted"].cpu() == target_reweighted.cpu()).float().mean())
print("end")