Example #1
0
    def forward(self,
                x,
                target=None,
                mixup=False,
                mixup_hidden=False,
                args=None,
                grad=None,
                noise=None,
                adv_mask1=0,
                adv_mask2=0,
                mp=None):
        """Run the network, optionally applying mixup at the input or a hidden layer.

        Args:
            x: Input batch fed to ``self.conv1`` (after optional input mixup).
            target: Optional class-index labels. When given they are one-hot
                encoded via ``to_one_hot(target, self.num_classes)`` and the
                (possibly mixed) soft targets are returned with the output.
            mixup: Mix at the input (layer 0). Ignored when ``mixup_hidden``
                is set.
            mixup_hidden: Mix at a layer drawn with ``random.randint(0, 2)``
                (0 = input, 1 = after ``layer1``, 2 = after ``layer2``).
            args, grad, noise, adv_mask1, adv_mask2, mp: Forwarded to
                ``mixup_process`` for input-level mixup; hidden-level mixup
                forwards only ``args``.

        Returns:
            ``(out, target_reweighted)`` when ``target`` is given, otherwise
            ``out`` -- the output of ``self.linear``.

        Raises:
            ValueError: if ``mixup`` or ``mixup_hidden`` is requested without
                ``target`` (the mixed targets cannot be computed).
        """
        # Fail early with a clear message instead of an UnboundLocalError on
        # target_reweighted further down.
        if (mixup or mixup_hidden) and target is None:
            raise ValueError("mixup requires a target")

        if mixup_hidden:
            layer_mix = random.randint(0, 2)
        elif mixup:
            layer_mix = 0
        else:
            layer_mix = None

        out = x

        if target is not None:
            target_reweighted = to_one_hot(target, self.num_classes)

        if layer_mix == 0:
            out, target_reweighted = mixup_process(out,
                                                   target_reweighted,
                                                   args=args,
                                                   grad=grad,
                                                   noise=noise,
                                                   adv_mask1=adv_mask1,
                                                   adv_mask2=adv_mask2,
                                                   mp=mp)

        out = self.conv1(out)
        out = self.layer1(out)

        if layer_mix == 1:
            out, target_reweighted = mixup_process(out, target_reweighted, args=args, hidden=True)

        out = self.layer2(out)

        if layer_mix == 2:
            out, target_reweighted = mixup_process(out, target_reweighted, args=args, hidden=True)
        out = self.layer3(out)

        # NOTE(review): unreachable via mixup_hidden -- random.randint(0, 2)
        # is inclusive on both ends and never yields 3.
        if layer_mix == 3:
            out, target_reweighted = mixup_process(out, target_reweighted, args=args, hidden=True)

        out = act(self.bn1(out))
        # Fixed 8x8 pooling window; presumably the final feature map is 8x8
        # for the expected input size -- TODO confirm.
        out = F.avg_pool2d(out, 8)
        out = out.reshape(out.size(0), -1)
        out = self.linear(out)

        if target is not None:
            return out, target_reweighted
        else:
            return out
Example #2
0
    def forward(self, x, target=None, mixup=False, mixup_hidden=False, args=None, grad=None,
                noise=None, adv_mask1=0, adv_mask2=0, profile=None):
        """Run the network, optionally applying mixup at the input or a hidden layer.

        Args:
            x: Input batch fed to ``self.conv1`` (after optional input mixup).
            target: Optional class-index labels, one-hot encoded via
                ``to_one_hot(target, self.num_classes)``.
            mixup: Mix at the input (layer 0). Ignored when ``mixup_hidden``
                is set.
            mixup_hidden: Mix at a layer drawn with ``random.randint(0, 2)``.
            args, grad, noise, adv_mask1, adv_mask2: Forwarded to
                ``mixup_process`` for input-level mixup.
            profile: Optional dict of timing lists. When given and input
                mixup runs, the duration of ``mixup_process`` is appended to
                ``profile['gc']``. The running mean is printed whenever
                ``len(profile['data'])`` is a multiple of 10.

        Returns:
            ``(out, target_reweighted)`` when ``target`` is given, otherwise
            ``out`` -- the output of ``self.linear``.

        Raises:
            ValueError: if ``mixup`` or ``mixup_hidden`` is requested without
                ``target``.
        """
        # Fail early with a clear message instead of an UnboundLocalError on
        # target_reweighted further down.
        if (mixup or mixup_hidden) and target is None:
            raise ValueError("mixup requires a target")

        if mixup_hidden:
            layer_mix = random.randint(0, 2)
        elif mixup:
            layer_mix = 0
        else:
            layer_mix = None

        out = x

        if target is not None:
            target_reweighted = to_one_hot(target, self.num_classes)

        if layer_mix == 0:
            start = time.time()
            out, target_reweighted = mixup_process(out, target_reweighted, args=args, grad=grad,
                                                   noise=noise, adv_mask1=adv_mask1,
                                                   adv_mask2=adv_mask2)
            # profile defaults to None; only record timings when one is supplied.
            if profile is not None:
                profile['gc'].append(time.time() - start)
                if len(profile['data']) % 10 == 0:
                    print('gc  : ', np.mean(profile['gc']))

        out = self.conv1(out)
        out = self.layer1(out)

        if layer_mix == 1:
            out, target_reweighted = mixup_process(out, target_reweighted, args=args, hidden=True)

        out = self.layer2(out)
        if layer_mix == 2:
            out, target_reweighted = mixup_process(out, target_reweighted, args=args, hidden=True)

        out = self.layer3(out)
        # NOTE(review): unreachable via mixup_hidden -- random.randint(0, 2)
        # never yields 3.
        if layer_mix == 3:
            out, target_reweighted = mixup_process(out, target_reweighted, args=args, hidden=True)
        out = self.layer4(out)
        # Fixed 4x4 pooling window; presumably matches the final feature-map
        # size for the expected input -- TODO confirm.
        out = F.avg_pool2d(out, 4)
        out = out.view(out.size(0), -1)
        out = self.linear(out)

        if target is not None:
            return out, target_reweighted
        else:
            return out
Example #3
0
def train(train_loader, model, optimizer, epoch, args, log, mpp=None):
    '''Train ``model`` for one epoch over ``train_loader``.

    When ``args.comix`` is false, the model is trained on clean images with
    ``bce_loss(softmax(output), one_hot_target)``.

    When ``args.comix`` is true, each batch goes through three steps:

    * A saliency map is built from the input gradient of a clean loss.
    * The batch is mixed by ``mixup_process``, or by ``mpp`` on CPU tensors
      when ``args.parallel`` is set.
    * The model is trained on the mixed images and targets.

    Args:
        train_loader: Iterable of ``(input, target)`` batches.
        model: Network being trained; switched between train/eval mode here.
        optimizer: Optimizer stepped once per batch.
        epoch: Not used in this function; kept for interface compatibility.
        args: Namespace providing the settings below, and forwarded to the
            mixup routine:

            * ``comix``: train on Co-Mixup images instead of clean ones.
            * ``clean_lam``: weight of the clean loss. When 0, the clean pass
              is used for saliency only.
            * ``num_classes``: number of classes for the one-hot targets.
            * ``parallel``: mix on CPU with ``mpp`` instead of
              ``mixup_process``.
        log: Handle passed to ``print_log``.
        mpp: Mixup callable used when ``args.parallel`` is set.

    Returns:
        ``(top1.avg, top5.avg, losses.avg)`` for the epoch. Accuracy is
        measured on the clean-image forward pass (``output``), not on the
        mixed batch.
    '''
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()

    # switch to train mode
    model.train()

    end = time.time()
    for inputs, target in train_loader:
        data_time.update(time.time() - end)
        optimizer.zero_grad()

        inputs = inputs.cuda()
        target = target.long().cuda()

        if not args.comix:
            # Train with clean images.
            target_reweighted = to_one_hot(target, args.num_classes)
            output = model(inputs)
            loss = bce_loss(softmax(output), target_reweighted)
        else:
            # Train with Co-Mixup images.
            input_var = Variable(inputs, requires_grad=True)
            target_var = Variable(target)

            # Saliency (unary term): sqrt of the mean squared input gradient
            # over dim 1 (presumably the channel axis).
            if args.clean_lam == 0:
                # Clean pass is only used for the saliency map.
                model.eval()
                output = model(input_var)
                loss_batch = criterion_batch(output, target_var)
            else:
                # Clean loss also contributes to the update: its parameter
                # gradients survive because zero_grad is skipped below.
                model.train()
                output = model(input_var)
                loss_batch = 2 * args.clean_lam * criterion_batch(
                    output, target_var) / args.num_classes
            loss_batch_mean = torch.mean(loss_batch, dim=0)
            loss_batch_mean.backward(retain_graph=True)
            sc = torch.sqrt(torch.mean(input_var.grad**2, dim=1))

            # Distance between the most salient locations (compatibility term).
            with torch.no_grad():
                # Use the real batch size: the loader's last batch may be
                # smaller than args.batch_size, which would break (or silently
                # mis-group) the reshape below.
                batch_size = inputs.size(0)
                z = F.avg_pool2d(sc, kernel_size=8, stride=1)
                z_reshape = z.reshape(batch_size, -1)
                z_idx_1d = torch.argmax(z_reshape, dim=1)
                z_idx_2d = torch.zeros((batch_size, 2), device=z.device)
                # Convert the flat argmax index into (row, col).
                z_idx_2d[:, 0] = z_idx_1d // z.shape[-1]
                z_idx_2d[:, 1] = z_idx_1d % z.shape[-1]
                A_dist = distance(z_idx_2d, dist_type='l1')

            if args.clean_lam == 0:
                # Discard the parameter gradients from the saliency pass.
                model.train()
                optimizer.zero_grad()

            # Perform mixup and calculate loss.
            target_reweighted = to_one_hot(target, args.num_classes)
            if args.parallel:
                device = inputs.device
                out, target_reweighted = mpp(inputs.cpu(),
                                             target_reweighted.cpu(),
                                             args=args,
                                             sc=sc.cpu(),
                                             A_dist=A_dist.cpu())
                out = out.to(device)
                target_reweighted = target_reweighted.to(device)
            else:
                out, target_reweighted = mixup_process(inputs,
                                                       target_reweighted,
                                                       args=args,
                                                       sc=sc,
                                                       A_dist=A_dist)

            out = model(out)
            loss = bce_loss(softmax(out), target_reweighted)

        # measure accuracy and record loss
        prec1, prec5 = accuracy(output, target, topk=(1, 5))
        losses.update(loss.item(), inputs.size(0))
        top1.update(prec1.item(), inputs.size(0))
        top5.update(prec5.item(), inputs.size(0))

        # compute gradient and do SGD step
        loss.backward()
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

    print_log(
        '**Train** Prec@1 {top1.avg:.2f} Prec@5 {top5.avg:.2f} Error@1 {error1:.2f}'
        .format(top1=top1, top5=top5, error1=100 - top1.avg), log)
    return top1.avg, top5.avg, losses.avg
Example #4
0
    def forward(self,
                x,
                target=None,
                mixup=False,
                mixup_hidden=False,
                args=None,
                grad=None,
                noise=None,
                adv_mask1=0,
                adv_mask2=0):
        """Run the network, optionally applying mixup at the input or a hidden layer.

        Args:
            x: Input batch fed to ``self.conv1`` (after optional input mixup).
            target: Optional class-index labels, one-hot encoded via
                ``to_one_hot(target, self.num_classes)``.
            mixup: Mix at the input (layer 0). Ignored when ``mixup_hidden``
                is set.
            mixup_hidden: Mix at a layer drawn with ``random.randint(0, 2)``.
            args, grad, noise, adv_mask1, adv_mask2: Forwarded to
                ``mixup_process`` for input-level mixup; hidden-level mixup
                forwards only ``args``.

        Returns:
            ``(out, target_reweighted)`` when ``target`` is given, otherwise
            ``out`` -- the output of ``self.linear``. When ``self.weight_type``
            is not None, features are L2-normalized before ``self.linear``.

        Raises:
            ValueError: if ``mixup`` or ``mixup_hidden`` is requested without
                ``target``.
        """
        # Fail early with a clear message instead of an UnboundLocalError on
        # target_reweighted further down.
        if (mixup or mixup_hidden) and target is None:
            raise ValueError("mixup requires a target")

        if mixup_hidden:
            layer_mix = random.randint(0, 2)
        elif mixup:
            layer_mix = 0
        else:
            layer_mix = None

        out = x

        if target is not None:
            target_reweighted = to_one_hot(target, self.num_classes)

        if layer_mix == 0:
            out, target_reweighted = mixup_process(out,
                                                   target_reweighted,
                                                   args=args,
                                                   grad=grad,
                                                   noise=noise,
                                                   adv_mask1=adv_mask1,
                                                   adv_mask2=adv_mask2)

        out = self.conv1(out)
        out = self.layer1(out)

        if layer_mix == 1:
            out, target_reweighted = mixup_process(out,
                                                   target_reweighted,
                                                   args=args,
                                                   hidden=True)

        out = self.layer2(out)
        if layer_mix == 2:
            out, target_reweighted = mixup_process(out,
                                                   target_reweighted,
                                                   args=args,
                                                   hidden=True)

        out = self.layer3(out)
        # NOTE(review): unreachable via mixup_hidden -- random.randint(0, 2)
        # never yields 3.
        if layer_mix == 3:
            out, target_reweighted = mixup_process(out,
                                                   target_reweighted,
                                                   args=args,
                                                   hidden=True)
        out = self.layer4(out)
        # Pooling window equals the last spatial dim, i.e. global average
        # pooling when the feature map is square.
        out = F.avg_pool2d(out, out.shape[-1])
        out = out.view(out.size(0), -1)
        if self.weight_type is not None:
            out = F.normalize(out, p=2, dim=1)
        out = self.linear(out)

        if target is not None:
            return out, target_reweighted
        else:
            return out
Example #5
0
    # Parallel run: call the mixup wrapper `mpp` 100 times on the same inputs
    # with debug=True (tqdm progress label "parallel"). Only the last call's
    # result is kept in `out` / `target_reweighted`.
    for iter in tqdm(range(100), desc="parallel"):
        out, target_reweighted = mpp(out0,
                                     target_reweighted0,
                                     args=args,
                                     sc=sc,
                                     A_dist=A_dist,
                                     debug=True)

    # Print the fraction of elements that exactly match d["out"] /
    # d["target_reweighted"]. Presumably `d` holds reference outputs captured
    # earlier in this function -- TODO confirm.
    print((d["out"].cpu() == out.cpu()).float().mean())
    print((d["target_reweighted"].cpu() == target_reweighted.cpu()
           ).float().mean())

    # Original run
    # Same inputs moved to the GPU with .cuda() and fed to the single-process
    # mixup_process, 100 times (tqdm progress label "original").
    out0cuda = out0.cuda()
    target_reweighted0cuda = target_reweighted0.cuda()
    sccuda = sc.cuda()
    A_distcuda = A_dist.cuda()
    for iter in tqdm(range(100), desc="original"):
        out, target_reweighted = mixup_process(out0cuda,
                                               target_reweighted0cuda,
                                               args=args,
                                               sc=sccuda,
                                               A_dist=A_distcuda,
                                               debug=True)

    # Same exact-match fractions against `d` for the original implementation.
    print((d["out"].cpu() == out.cpu()).float().mean())
    print((d["target_reweighted"].cpu() == target_reweighted.cpu()
           ).float().mean())

    print("end")