def main():
    if args.gpu:
        os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
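        # CUDA_VISIBLE_DEVICES only takes effect if set before the first CUDA
        # call in the process, so this must run before anything touches the GPU.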

    # args.outdir = os.path.join(os.getenv('PT_DATA_DIR', './'), args.outdir)
    
    if not os.path.exists(args.outdir):
        os.makedirs(args.outdir)

    # Copy code to output directory
    copy_code(args.outdir)
    
    train_dataset = get_dataset(args.dataset, 'train')
    test_dataset = get_dataset(args.dataset, 'test')
    pin_memory = (args.dataset == "imagenet")
    train_loader = DataLoader(train_dataset, shuffle=True, batch_size=args.batch,
                              num_workers=args.workers, pin_memory=pin_memory)
    test_loader = DataLoader(test_dataset, shuffle=False, batch_size=args.batch,
                             num_workers=args.workers, pin_memory=pin_memory)
    ## This is used to test the performance of the denoiser attached to a cifar10 classifier
    cifar10_test_loader = DataLoader(get_dataset('cifar10', 'test'), shuffle=False, batch_size=args.batch,
                             num_workers=args.workers, pin_memory=pin_memory)

    if args.pretrained_denoiser:
        checkpoint = torch.load(args.pretrained_denoiser)
        assert checkpoint['arch'] == args.arch
        denoiser = get_architecture(checkpoint['arch'], args.dataset)
        denoiser.load_state_dict(checkpoint['state_dict'])
    else:
        denoiser = get_architecture(args.arch, args.dataset)

    if args.optimizer == 'Adam':
        optimizer = Adam(denoiser.parameters(), lr=args.lr, weight_decay=args.weight_decay)
    elif args.optimizer == 'SGD':
        optimizer = SGD(denoiser.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)
    elif args.optimizer == 'AdamThenSGD':
        optimizer = Adam(denoiser.parameters(), lr=args.lr, weight_decay=args.weight_decay)
    else:
        raise ValueError("Unknown optimizer: {}".format(args.optimizer))
    scheduler = StepLR(optimizer, step_size=args.lr_step_size, gamma=args.gamma)
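    # For 'AdamThenSGD', training starts with Adam; the switch to SGD happens
    # at args.start_sgd_epoch inside the epoch loop below.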

    starting_epoch = 0
    logfilename = os.path.join(args.outdir, 'log.txt')

    ## Resume from a checkpoint if one exists and the resume flag is set
    denoiser_path = os.path.join(args.outdir, 'checkpoint.pth.tar')
    if args.resume and os.path.isfile(denoiser_path):
        print("=> loading checkpoint '{}'".format(denoiser_path))
        checkpoint = torch.load(denoiser_path,
                                map_location=lambda storage, loc: storage)
        assert checkpoint['arch'] == args.arch
        starting_epoch = checkpoint['epoch']
        denoiser.load_state_dict(checkpoint['state_dict'])
        if starting_epoch >= args.start_sgd_epoch and args.optimizer == 'AdamThenSGD':  # run Adam for a few epochs, then continue with SGD
            print("-->[Switching from Adam to SGD.]")
            args.lr = args.start_sgd_lr
            optimizer = SGD(denoiser.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)
            scheduler = StepLR(optimizer, step_size=args.lr_step_size, gamma=args.gamma)
        
        optimizer.load_state_dict(checkpoint['optimizer'])
        print("=> loaded checkpoint '{}' (epoch {})"
                        .format(denoiser_path, checkpoint['epoch']))
    else:
        if args.resume: print("=> no checkpoint found at '{}'".format(args.outdir))
        init_logfile(logfilename, "epoch\ttime\tlr\ttrainloss\ttestloss\ttestAcc")

    # The first model is the one we evaluate on; the other 14 are surrogate
    # models used during training (see the paper for more detail).
    models = [
        "ResNet110",  # fixed model
        # surrogate models
        "WRN", "WRN40", "VGG16", "VGG19", "ResNet18", "PreActResNet18",
        "ResNeXt29_2x64d", "MobileNet", "MobileNetV2", "SENet18",
        "ShuffleNetV2", "EfficientNetB0", "GoogLeNet", "DenseNet121",
    ]
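    # In the training loop below, the denoiser is trained against the
    # surrogate classifiers (base_classifiers[1:]) and evaluated with
    # test_with_classifier against the fixed model (base_classifiers[0]).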

    path = os.path.join(args.classifiers_path, '{}/noise_0.00/checkpoint.pth.tar')

    base_classifiers_paths = [
        path.format(model) for model in models
    ]
    base_classifiers = []
    if args.classifier_idx == -1:
        for base_classifier in base_classifiers_paths:
            # load the base classifier
            checkpoint = torch.load(base_classifier)
            base_classifier = get_architecture(checkpoint["arch"], args.dataset)
            base_classifier.load_state_dict(checkpoint['state_dict'])

            requires_grad_(base_classifier, False)
            base_classifiers.append(base_classifier.eval().cuda())
    else:
        if args.classifier_idx not in range(len(models)):
            raise ValueError("Unknown classifier index: {}".format(args.classifier_idx))
        print("Model name: {}".format(base_classifiers_paths[args.classifier_idx]))
        checkpoint = torch.load(base_classifiers_paths[args.classifier_idx])
        base_classifier = get_architecture(checkpoint["arch"], args.dataset)
        base_classifier.load_state_dict(checkpoint['state_dict'])

        requires_grad_(base_classifier, False)
        base_classifiers.append(base_classifier.eval().cuda())

    criterion = CrossEntropyLoss(reduction='mean').cuda()
    best_acc = 0

    for epoch in range(starting_epoch, args.epochs):
        before = time.time()
        train_loss = train(train_loader, denoiser, criterion, optimizer, epoch, args.noise_sd, base_classifiers[1:])
        test_loss, test_acc = test_with_classifier(cifar10_test_loader, denoiser, criterion, args.noise_sd, args.print_freq, base_classifiers[0])

        after = time.time()

        log(logfilename, "{}\t{:.3}\t{:.3}\t{:.3}\t{:.3}\t{:.3}".format(
            epoch, after - before,
            args.lr, train_loss, test_loss, test_acc))

        scheduler.step(epoch)
        args.lr = scheduler.get_lr()[0]

        # Switch from Adam to SGD
        if epoch == args.start_sgd_epoch and args.optimizer == 'AdamThenSGD':  # run Adam for a few epochs, then continue with SGD
            print("-->[Switching from Adam to SGD.]")
            args.lr = args.start_sgd_lr
            optimizer = SGD(denoiser.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)
            scheduler = StepLR(optimizer, step_size=args.lr_step_size, gamma=args.gamma)

        torch.save({
            'epoch': epoch + 1,
            'arch': args.arch,
            'state_dict': denoiser.state_dict(),
            'optimizer': optimizer.state_dict(),
        }, os.path.join(args.outdir, 'checkpoint.pth.tar'))

        if args.objective in ['classification', 'stability'] and test_acc > best_acc:
            best_acc = test_acc
        else:
            continue

        torch.save({
            'epoch': epoch + 1,
            'arch': args.arch,
            'state_dict': denoiser.state_dict(),
            'optimizer': optimizer.state_dict(),
        }, os.path.join(args.outdir, 'best.pth.tar'))
Example 2

def main():
    if args.gpu:
        os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu

    if not os.path.exists(args.outdir):
        os.makedirs(args.outdir)

    # Copy code to output directory
    copy_code(args.outdir)

    train_dataset = get_dataset(args.dataset, 'train')
    test_dataset = get_dataset(args.dataset, 'test')
    pin_memory = (args.dataset == "imagenet")
    train_loader = DataLoader(train_dataset,
                              shuffle=True,
                              batch_size=args.batch,
                              num_workers=args.workers,
                              pin_memory=pin_memory)
    test_loader = DataLoader(test_dataset,
                             shuffle=False,
                             batch_size=args.batch,
                             num_workers=args.workers,
                             pin_memory=pin_memory)
    ## This is used to test the performance of the denoiser attached to a cifar10 classifier
    cifar10_test_loader = DataLoader(get_dataset('cifar10', 'test'),
                                     shuffle=False,
                                     batch_size=args.batch,
                                     num_workers=args.workers,
                                     pin_memory=pin_memory)

    if args.pretrained_denoiser:
        checkpoint = torch.load(args.pretrained_denoiser)
        assert checkpoint['arch'] == args.arch
        denoiser = get_architecture(checkpoint['arch'], args.dataset)
        denoiser.load_state_dict(checkpoint['state_dict'])
    else:
        denoiser = get_architecture(args.arch, args.dataset)

    if args.optimizer == 'Adam':
        optimizer = Adam(denoiser.parameters(),
                         lr=args.lr,
                         weight_decay=args.weight_decay)
    elif args.optimizer == 'SGD':
        optimizer = SGD(denoiser.parameters(),
                        lr=args.lr,
                        momentum=args.momentum,
                        weight_decay=args.weight_decay)
    elif args.optimizer == 'AdamThenSGD':
        optimizer = Adam(denoiser.parameters(),
                         lr=args.lr,
                         weight_decay=args.weight_decay)
    else:
        raise ValueError("Unknown optimizer: {}".format(args.optimizer))
    scheduler = StepLR(optimizer,
                       step_size=args.lr_step_size,
                       gamma=args.gamma)

    starting_epoch = 0
    logfilename = os.path.join(args.outdir, 'log.txt')

    ## Resume from a checkpoint if one exists and the resume flag is set
    denoiser_path = os.path.join(args.outdir, 'checkpoint.pth.tar')
    if args.resume and os.path.isfile(denoiser_path):
        print("=> loading checkpoint '{}'".format(denoiser_path))
        checkpoint = torch.load(denoiser_path,
                                map_location=lambda storage, loc: storage)
        assert checkpoint['arch'] == args.arch
        starting_epoch = checkpoint['epoch']
        denoiser.load_state_dict(checkpoint['state_dict'])
        if starting_epoch >= args.start_sgd_epoch and args.optimizer == 'AdamThenSGD':  # run Adam for a few epochs, then continue with SGD
            print("-->[Switching from Adam to SGD.]")
            args.lr = args.start_sgd_lr
            optimizer = SGD(denoiser.parameters(),
                            lr=args.lr,
                            momentum=args.momentum,
                            weight_decay=args.weight_decay)
            scheduler = StepLR(optimizer,
                               step_size=args.lr_step_size,
                               gamma=args.gamma)

        optimizer.load_state_dict(checkpoint['optimizer'])
        print("=> loaded checkpoint '{}' (epoch {})".format(
            denoiser_path, checkpoint['epoch']))
    else:
        if args.resume:
            print("=> no checkpoint found at '{}'".format(args.outdir))
        init_logfile(logfilename,
                     "epoch\ttime\tlr\ttrainloss\ttestloss\ttestAcc")

    if args.objective == 'denoising':
        criterion = MSELoss(reduction='mean').cuda()
        best_loss = 1e6

    elif args.objective in ['classification', 'stability']:
        assert args.classifier != '', "Please specify a path to the classifier you want to attach the denoiser to."

        if args.classifier in IMAGENET_CLASSIFIERS:
            assert args.dataset == 'imagenet'
            # loading pretrained imagenet architectures
            clf = get_architecture(args.classifier,
                                   args.dataset,
                                   pytorch_pretrained=True)
        else:
            checkpoint = torch.load(args.classifier)
            clf = get_architecture(checkpoint['arch'], 'cifar10')
            clf.load_state_dict(checkpoint['state_dict'])
        clf.cuda().eval()
        requires_grad_(clf, False)
        criterion = CrossEntropyLoss(reduction='mean').cuda()
        best_acc = 0

    for epoch in range(starting_epoch, args.epochs):
        before = time.time()
        if args.objective == 'denoising':
            train_loss = train(train_loader, denoiser, criterion, optimizer,
                               epoch, args.noise_sd)
            test_loss = test(test_loader, denoiser, criterion, args.noise_sd,
                             args.print_freq, args.outdir)
            test_acc = 'NA'
        elif args.objective in ['classification', 'stability']:
            train_loss = train(train_loader, denoiser, criterion, optimizer,
                               epoch, args.noise_sd, clf)
            if args.dataset == 'imagenet':
                test_loss, test_acc = test_with_classifier(
                    test_loader, denoiser, criterion, args.noise_sd,
                    args.print_freq, clf)
            else:
                # This is needed so that cifar10 denoisers trained using imagenet32 are still evaluated on the cifar10 testset
                test_loss, test_acc = test_with_classifier(
                    cifar10_test_loader, denoiser, criterion, args.noise_sd,
                    args.print_freq, clf)

        after = time.time()

        log(
            logfilename, "{}\t{:.3}\t{:.3}\t{:.3}\t{:.3}\t{:.3}".format(
                epoch, after - before, args.lr, train_loss, test_loss,
                test_acc))

        scheduler.step(epoch)
        args.lr = scheduler.get_lr()[0]

        # Switch from Adam to SGD
        if epoch == args.start_sgd_epoch and args.optimizer == 'AdamThenSGD':  # run Adam for a few epochs, then continue with SGD
            print("-->[Switching from Adam to SGD.]")
            args.lr = args.start_sgd_lr
            optimizer = SGD(denoiser.parameters(),
                            lr=args.lr,
                            momentum=args.momentum,
                            weight_decay=args.weight_decay)
            scheduler = StepLR(optimizer,
                               step_size=args.lr_step_size,
                               gamma=args.gamma)

        torch.save(
            {
                'epoch': epoch + 1,
                'arch': args.arch,
                'state_dict': denoiser.state_dict(),
                'optimizer': optimizer.state_dict(),
            }, os.path.join(args.outdir, 'checkpoint.pth.tar'))

        if args.objective == 'denoising' and test_loss < best_loss:
            best_loss = test_loss
        elif args.objective in ['classification', 'stability'
                                ] and test_acc > best_acc:
            best_acc = test_acc
        else:
            continue

        torch.save(
            {
                'epoch': epoch + 1,
                'arch': args.arch,
                'state_dict': denoiser.state_dict(),
                'optimizer': optimizer.state_dict(),
            }, os.path.join(args.outdir, 'best.pth.tar'))
Example 3
    args.epsilon /= 256.0
    args.init_norm_DDN /= 256.0
    if args.epsilon > 0:
        args.gamma_DDN = 1 - (3 / 510 / args.epsilon)**(1 / args.num_steps)
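        # A sketch of where this constant comes from, assuming DDN shrinks the
        # perturbation norm by a factor of (1 - gamma) per step: after
        # num_steps steps the norm decays from epsilon to
        # epsilon * (1 - gamma_DDN)**num_steps = 3/510, i.e. about half a
        # pixel (1.5/255) in [0, 1] image scale.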

    # load the base classifier
    # checkpoint = torch.load(args.base_classifier)
    # base_classifier = get_architecture(checkpoint["arch"], args.dataset)
    # base_classifier.load_state_dict(checkpoint['state_dict'])
    base_classifier = torch.hub.load('pytorch/vision',
                                     'resnet50',
                                     pretrained=True)
    base_classifier.eval()

    requires_grad_(base_classifier, False)

    # create the smoothed classifier g
    smoothed_classifier = Smooth(base_classifier,
                                 get_num_classes(args.dataset), args.sigma)
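    # Smooth is the randomized-smoothing wrapper of Cohen et al.: g(x)
    # predicts the class the base classifier returns most often under
    # additive N(0, sigma^2 I) noise, estimated by Monte Carlo sampling.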

    # prepare output file
    f = open(outfile, 'w')
    print("idx\tlabel\tpredict\tbasePredict\tcorrect\ttime",
          file=f,
          flush=True)

    if args.attack == 'PGD':
        print('Attacker is PGD')
        attacker = PGD_Linf(steps=args.num_steps,
                            device='cuda',
Example 4
def train(loader: DataLoader,
          model: torch.nn.Module,
          criterion,
          optimizer: Optimizer,
          epoch: int,
          noise_sd: float,
          attacker: Attacker,
          device: torch.device,
          writer=None):
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()
    end = time.time()

    # switch to train mode
    model.train()
    requires_grad_(model, True)

    for i, batch in enumerate(loader):
        # measure data loading time
        data_time.update(time.time() - end)

        mini_batches = _chunk_minibatch(batch, args.num_noise_vec)
        for inputs, targets in mini_batches:
            inputs, targets = inputs.to(device), targets.to(device)
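            # Replicate each image num_noise_vec times: the channel-wise
            # repeat followed by reshape puts the copies of each image on
            # consecutive rows, so each copy receives an independent noise
            # draw below.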
            inputs = inputs.repeat(
                (1, args.num_noise_vec, 1, 1)).reshape(-1, *batch[0].shape[1:])
            batch_size = inputs.size(0)

            # augment inputs with noise
            noise = torch.randn_like(inputs, device=device) * noise_sd

            requires_grad_(model, False)
            model.eval()
            inputs = attacker.attack(model,
                                     inputs,
                                     targets,
                                     noise=noise,
                                     num_noise_vectors=args.num_noise_vec,
                                     no_grad=args.no_grad_attack)
            model.train()
            requires_grad_(model, True)

            noisy_inputs = inputs + noise

            targets = targets.unsqueeze(1).repeat(1,
                                                  args.num_noise_vec).reshape(
                                                      -1, 1).squeeze()
            outputs = model(noisy_inputs)
            loss = criterion(outputs, targets)

            acc1, acc5 = accuracy(outputs, targets, topk=(1, 5))
            losses.update(loss.item(), batch_size)
            top1.update(acc1.item(), batch_size)
            top5.update(acc5.item(), batch_size)

            # compute gradient and do SGD step
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            print('Epoch: [{0}][{1}/{2}]\t'
                  'Time {batch_time.avg:.3f}\t'
                  'Data {data_time.avg:.3f}\t'
                  'Loss {loss.avg:.4f}\t'
                  'Acc@1 {top1.avg:.3f}\t'
                  'Acc@5 {top5.avg:.3f}'.format(epoch,
                                                i,
                                                len(loader),
                                                batch_time=batch_time,
                                                data_time=data_time,
                                                loss=losses,
                                                top1=top1,
                                                top5=top5))

    if writer:
        writer.add_scalar('loss/train', losses.avg, epoch)
        writer.add_scalar('batch_time', batch_time.avg, epoch)
        writer.add_scalar('accuracy/train@1', top1.avg, epoch)
        writer.add_scalar('accuracy/train@5', top5.avg, epoch)

    return (losses.avg, top1.avg)
Example 5

def train(loader: DataLoader,
          model: torch.nn.Module,
          criterion,
          optimizer: Optimizer,
          epoch: int,
          noise_sd: float,
          attacker: Attacker,
          device: torch.device,
          writer=None):
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    losses_reg = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()
    end = time.time()

    # switch to train mode
    model.train()
    requires_grad_(model, True)

    for i, batch in enumerate(loader):
        # measure data loading time
        data_time.update(time.time() - end)

        mini_batches = _chunk_minibatch(batch, args.num_noise_vec)
        for inputs, targets in mini_batches:
            inputs, targets = inputs.to(device), targets.to(device)
            batch_size = inputs.size(0)

            noises = [
                torch.randn_like(inputs, device=device) * noise_sd
                for _ in range(args.num_noise_vec)
            ]

            if args.adv_training:
                requires_grad_(model, False)
                model.eval()
                inputs = attacker.attack(model, inputs, targets, noises=noises)
                model.train()
                requires_grad_(model, True)

            # augment inputs with noise
            inputs_c = torch.cat([inputs + noise for noise in noises], dim=0)
            targets_c = targets.repeat(args.num_noise_vec)
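            # Layout: torch.cat groups the batch by noise draw (all images
            # with noises[0] first, then noises[1], ...) and targets_c tiles
            # the labels to match; torch.chunk below undoes the grouping,
            # yielding one logits tensor per noise draw.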

            logits = model(inputs_c)
            loss_xent = criterion(logits, targets_c)

            logits_chunk = torch.chunk(logits, args.num_noise_vec, dim=0)
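            # Consistency regularization: penalize disagreement among the
            # per-noise-draw predictions for the same image. lbd weights the
            # term and eta is a hyperparameter of the loss; this appears to
            # match the consistency objective of Jeong & Shin (2020).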
            loss_con = consistency_loss(logits_chunk, args.lbd, args.eta)

            loss = loss_xent + loss_con

            acc1, acc5 = accuracy(logits, targets_c, topk=(1, 5))
            losses.update(loss_xent.item(), batch_size)
            losses_reg.update(loss_con.item(), batch_size)
            top1.update(acc1.item(), batch_size)
            top5.update(acc5.item(), batch_size)

            # compute gradient and do SGD step
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            print('Epoch: [{0}][{1}/{2}]\t'
                  'Time {batch_time.avg:.3f}\t'
                  'Data {data_time.avg:.3f}\t'
                  'Loss {loss.avg:.4f}\t'
                  'Acc@1 {top1.avg:.3f}\t'
                  'Acc@5 {top5.avg:.3f}'.format(epoch,
                                                i,
                                                len(loader),
                                                batch_time=batch_time,
                                                data_time=data_time,
                                                loss=losses,
                                                top1=top1,
                                                top5=top5))

    if writer:
        writer.add_scalar('loss/train', losses.avg, epoch)
        writer.add_scalar('loss/consistency', losses_reg.avg, epoch)
        writer.add_scalar('batch_time', batch_time.avg, epoch)
        writer.add_scalar('accuracy/train@1', top1.avg, epoch)
        writer.add_scalar('accuracy/train@5', top5.avg, epoch)

    return (losses.avg, top1.avg)
Example 6
def test(loader: DataLoader,
         model: torch.nn.Module,
         criterion,
         noise_sd: float,
         attacker: Attacker = None):
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()
    top1_normal = AverageMeter()
    end = time.time()

    # switch to eval mode
    model.eval()
    requires_grad_(model, False)

    with torch.no_grad():
        for i, (inputs, targets) in enumerate(loader):
            # measure data loading time
            data_time.update(time.time() - end)

            inputs = inputs.cuda()
            targets = targets.cuda()

            # augment inputs with noise
            noise = torch.randn_like(inputs, device='cuda') * noise_sd
            noisy_inputs = inputs + noise

            # compute output
            if args.adv_training:
                normal_outputs = model(noisy_inputs)
                acc1_normal, _ = accuracy(normal_outputs, targets, topk=(1, 5))
                top1_normal.update(acc1_normal.item(), inputs.size(0))
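                # Accuracy on merely noisy inputs was recorded above as
                # top1_normal; below, the attacker perturbs the clean inputs
                # against the same fixed noise draw, so top1 measures accuracy
                # under attack plus noise.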

                with torch.enable_grad():
                    inputs = attacker.attack(model,
                                             inputs,
                                             targets,
                                             noise=noise)
                # noise = torch.randn_like(inputs, device='cuda') * noise_sd
                noisy_inputs = inputs + noise

            outputs = model(noisy_inputs)
            loss = criterion(outputs, targets)

            # measure accuracy and record loss
            acc1, acc5 = accuracy(outputs, targets, topk=(1, 5))
            losses.update(loss.item(), inputs.size(0))
            top1.update(acc1.item(), inputs.size(0))
            top5.update(acc5.item(), inputs.size(0))

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if i % args.print_freq == 0:
                print('Test: [{0}/{1}]\t'
                      'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                      'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                      'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                      'Acc@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                      'Acc@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
                          i,
                          len(loader),
                          batch_time=batch_time,
                          data_time=data_time,
                          loss=losses,
                          top1=top1,
                          top5=top5))

        if args.adv_training:
            return (losses.avg, top1.avg, top1_normal.avg)
        else:
            return (losses.avg, top1.avg, None)
Example 7
def train(loader: DataLoader,
          model: torch.nn.Module,
          criterion,
          optimizer: Optimizer,
          epoch: int,
          noise_sd: float,
          attacker: Attacker = None):
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()
    end = time.time()

    # switch to train mode
    model.train()
    requires_grad_(model, True)

    for i, batch in enumerate(loader):
        # measure data loading time
        data_time.update(time.time() - end)

        mini_batches = get_minibatches(batch, args.num_noise_vec)
        noisy_inputs_list = []
        for inputs, targets in mini_batches:
            inputs = inputs.cuda()
            targets = targets.cuda()

            inputs = inputs.repeat(
                (1, args.num_noise_vec, 1, 1)).view(batch[0].shape)
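            # The channel-wise repeat followed by .view replicates each image
            # num_noise_vec times on consecutive rows, so each replica picks
            # up an independent noise draw below.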

            # augment inputs with noise
            noise = torch.randn_like(inputs, device='cuda') * noise_sd

            if args.adv_training:
                requires_grad_(model, False)
                model.eval()
                inputs = attacker.attack(model,
                                         inputs,
                                         targets,
                                         noise=noise,
                                         num_noise_vectors=args.num_noise_vec,
                                         no_grad=args.no_grad_attack)
                model.train()
                requires_grad_(model, True)

            if args.train_multi_noise:
                noisy_inputs = inputs + noise
                targets = targets.unsqueeze(1).repeat(
                    1, args.num_noise_vec).reshape(-1, 1).squeeze()
                outputs = model(noisy_inputs)
                loss = criterion(outputs, targets)

                acc1, acc5 = accuracy(outputs, targets, topk=(1, 5))
                losses.update(loss.item(), noisy_inputs.size(0))
                top1.update(acc1.item(), noisy_inputs.size(0))
                top5.update(acc5.item(), noisy_inputs.size(0))

                # compute gradient and do SGD step
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            else:
                inputs = inputs[::args.num_noise_vec]  # subsample the samples
                noise = noise[::args.num_noise_vec]
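                # Keep one replica (and its noise draw) per original image so
                # the effective batch returns to its original size; the attack
                # above still used all num_noise_vec replicas.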
                # noise = torch.randn_like(inputs, device='cuda') * noise_sd
                noisy_inputs_list.append(inputs + noise)

        if not args.train_multi_noise:
            noisy_inputs = torch.cat(noisy_inputs_list)
            targets = batch[1].cuda()
            assert len(targets) == len(noisy_inputs)

            outputs = model(noisy_inputs)
            loss = criterion(outputs, targets)

            # measure accuracy and record loss
            acc1, acc5 = accuracy(outputs, targets, topk=(1, 5))
            losses.update(loss.item(), noisy_inputs.size(0))
            top1.update(acc1.item(), noisy_inputs.size(0))
            top5.update(acc5.item(), noisy_inputs.size(0))

            # compute gradient and do SGD step
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            print('Epoch: [{0}][{1}/{2}]\t'
                  'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                  'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                  'Acc@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                  'Acc@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
                      epoch,
                      i,
                      len(loader),
                      batch_time=batch_time,
                      data_time=data_time,
                      loss=losses,
                      top1=top1,
                      top5=top5))

    return (losses.avg, top1.avg)