def main():
    if args.gpu:
        os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu

    if not os.path.exists(args.outdir):
        os.makedirs(args.outdir)

    train_dataset = get_dataset(args.dataset, 'train')
    test_dataset = get_dataset(args.dataset, 'test')
    pin_memory = (args.dataset == "imagenet")
    train_loader = DataLoader(train_dataset,
                              shuffle=True,
                              batch_size=args.batch,
                              num_workers=args.workers,
                              pin_memory=pin_memory)
    test_loader = DataLoader(test_dataset,
                             shuffle=False,
                             batch_size=args.batch,
                             num_workers=args.workers,
                             pin_memory=pin_memory)

    model = get_architecture(args.arch, args.dataset)

    logfilename = os.path.join(args.outdir, 'log.txt')
    init_logfile(logfilename,
                 "epoch\ttime\tlr\ttrain loss\ttrain acc\ttestloss\ttest acc")

    criterion = CrossEntropyLoss().cuda()
    optimizer = SGD(model.parameters(),
                    lr=args.lr,
                    momentum=args.momentum,
                    weight_decay=args.weight_decay)
    scheduler = StepLR(optimizer,
                       step_size=args.lr_step_size,
                       gamma=args.gamma)

    for epoch in range(args.epochs):
        scheduler.step(epoch)
        before = time.time()
        train_loss, train_acc = train(train_loader, model, criterion,
                                      optimizer, epoch, args.noise_sd)
        test_loss, test_acc = test(test_loader, model, criterion,
                                   args.noise_sd)
        after = time.time()

        log(
            logfilename, "{}\t{}\t{:.3}\t{:.3}\t{:.3}\t{:.3}\t{:.3}".format(
                epoch, str(datetime.timedelta(seconds=(after - before))),
                scheduler.get_lr()[0], train_loss, train_acc, test_loss,
                test_acc))

        torch.save(
            {
                'epoch': epoch + 1,
                'arch': args.arch,
                'state_dict': model.state_dict(),
                'optimizer': optimizer.state_dict(),
            }, os.path.join(args.outdir, 'checkpoint.pth.tar'))
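
The train and test helpers called above are defined elsewhere in the repository. Below is a minimal sketch of the noise-augmented training epoch they imply, assuming the Gaussian data augmentation used for randomized smoothing; everything beyond the call signature seen above is an assumption, not the repo's exact helper.

import torch

def train(loader, model, criterion, optimizer, epoch, noise_sd):
    # One epoch on Gaussian-noise-augmented inputs (sketch).
    model.train()
    total_loss, correct, seen = 0.0, 0, 0
    for inputs, targets in loader:
        inputs, targets = inputs.cuda(), targets.cuda()
        # Perturb each input with isotropic Gaussian noise N(0, noise_sd^2).
        inputs = inputs + torch.randn_like(inputs) * noise_sd
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss += loss.item() * targets.size(0)
        correct += (outputs.argmax(1) == targets).sum().item()
        seen += targets.size(0)
    return total_loss / seen, correct / seen
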
Example No. 2
        @trainer.on(Events.EPOCH_COMPLETED)
        def log_epoch(engine):
            scheduler.step()
            evaluator.run(test_loader)

            metrics = evaluator.state.metrics
            avg_accuracy = metrics['accuracy']
            avg_nll = metrics['nll']

            pbar.log_message(
                "Validation Results - Epoch: {}  Avg accuracy: {:.5f} Avg loss: {:.3f}"
                .format(engine.state.epoch, avg_accuracy, avg_nll))

            log(
                logfilename,
                "Validation  - Epoch: {}  Avg accuracy: {:.5f} Avg loss: {:.3f}"
                .format(engine.state.epoch, avg_accuracy, avg_nll))

            test_acc.append(avg_accuracy)

            if avg_accuracy >= max(test_acc):
                print("Saving the model at Epoch {:}".format(
                    engine.state.epoch))
                torch.save(
                    {
                        'arch': args.arch,
                        'state_dict': base_classifier.state_dict(),
                        'optimizer': optimizer.state_dict(),
                    }, os.path.join(args.outdir, 'checkpoint.pth.tar'))

            if engine.state.epoch == args.epochs:
                test_acc_list.append(max(test_acc))
                log(logfilename,
                    "Finetuned Test Accuracy: {:.5f}".format(max(test_acc)))
                print("Finetuned Test Accuracy: ", max(test_acc))
Example No. 3
def main():
    (train_loader, test_loader, criterion, model, optimizer, scheduler,
     starting_epoch, logfilename, model_path, device, writer) = prologue(args)

    for epoch in range(starting_epoch, args.epochs):
        before = time.time()
        train_loss, train_acc = train(train_loader, model, criterion,
                                      optimizer, epoch, args.noise_sd, device,
                                      writer)
        test_loss, test_acc = test(test_loader, model, criterion, epoch,
                                   args.noise_sd, device, writer,
                                   args.print_freq)
        after = time.time()

        log(
            logfilename, "{}\t{:.3}\t{:.3}\t{:.3}\t{:.3}\t{:.3}\t{:.3}".format(
                epoch, after - before,
                scheduler.get_lr()[0], train_loss, train_acc, test_loss,
                test_acc))

        # In PyTorch 1.1.0 and later, you should call `optimizer.step()` before `lr_scheduler.step()`.
        # See more details at https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate
        scheduler.step(epoch)

        torch.save(
            {
                'epoch': epoch + 1,
                'arch': args.arch,
                'state_dict': model.state_dict(),
                'optimizer': optimizer.state_dict(),
            }, model_path)
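
prologue(args) is not shown on this page. Below is a plausible sketch of what it returns, assembled from the setup code of the other examples here and using the same imports; the resume logic is elided and every internal detail is an assumption.

def prologue(args):
    # Shared setup (sketch): data, model, optimizer, scheduler, logging.
    pin_memory = (args.dataset == "imagenet")
    train_loader = DataLoader(get_dataset(args.dataset, 'train'), shuffle=True,
                              batch_size=args.batch, num_workers=args.workers,
                              pin_memory=pin_memory)
    test_loader = DataLoader(get_dataset(args.dataset, 'test'), shuffle=False,
                             batch_size=args.batch, num_workers=args.workers,
                             pin_memory=pin_memory)
    model = get_architecture(args.arch, args.dataset)
    criterion = CrossEntropyLoss().cuda()
    optimizer = SGD(model.parameters(), lr=args.lr, momentum=args.momentum,
                    weight_decay=args.weight_decay)
    scheduler = StepLR(optimizer, step_size=args.lr_step_size, gamma=args.gamma)
    device = torch.device('cuda')
    writer = SummaryWriter(args.outdir)
    logfilename = os.path.join(args.outdir, 'log.txt')
    model_path = os.path.join(args.outdir, 'checkpoint.pth.tar')
    starting_epoch = 0  # a real implementation would resume from model_path here
    return (train_loader, test_loader, criterion, model, optimizer, scheduler,
            starting_epoch, logfilename, model_path, device, writer)
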
Example No. 4
def main():
    (train_loader, test_loader, criterion, model, optimizer, scheduler,
     starting_epoch, logfilename, model_path, device, writer) = prologue(args)

    if args.attack == 'PGD':
        print('Attacker is PGD')
        attacker = PGD_L2(steps=args.num_steps,
                          device=device,
                          max_norm=args.epsilon)
    elif args.attack == 'DDN':
        print('Attacker is DDN')
        attacker = DDN(steps=args.num_steps,
                       device=device,
                       max_norm=args.epsilon,
                       init_norm=args.init_norm_DDN,
                       gamma=args.gamma_DDN)
    else:
        raise Exception('Unknown attack')

    for epoch in range(starting_epoch, args.epochs):
        attacker.max_norm = np.min(
            [args.epsilon, (epoch + 1) * args.epsilon / args.warmup])
        attacker.init_norm = np.min(
            [args.epsilon, (epoch + 1) * args.epsilon / args.warmup])

        before = time.time()
        train_loss, train_acc = train(train_loader, model, criterion,
                                      optimizer, epoch, args.noise_sd,
                                      attacker, device, writer)
        test_loss, test_acc = test(test_loader, model, criterion, epoch,
                                   args.noise_sd, device, writer,
                                   args.print_freq)
        after = time.time()

        log(
            logfilename, "{}\t{:.3}\t{:.3}\t{:.3}\t{:.3}\t{:.3}\t{:.3}".format(
                epoch, after - before,
                scheduler.get_lr()[0], train_loss, train_acc, test_loss,
                test_acc))

        # In PyTorch 1.1.0 and later, you should call `optimizer.step()` before `lr_scheduler.step()`.
        # See more details at https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate
        scheduler.step(epoch)

        torch.save(
            {
                'epoch': epoch + 1,
                'arch': args.arch,
                'state_dict': model.state_dict(),
                'optimizer': optimizer.state_dict(),
            }, model_path)
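
The two np.min calls above linearly ramp the attack radius from epsilon/warmup up to epsilon over the first args.warmup epochs. A tiny standalone check of that schedule (the epsilon and warmup values are made up):

import numpy as np

epsilon, warmup = 1.0, 4
radii = [np.min([epsilon, (epoch + 1) * epsilon / warmup]) for epoch in range(6)]
print(radii)  # [0.25, 0.5, 0.75, 1.0, 1.0, 1.0]
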
Example No. 5
def run(args, model, train_loader, test_loader):
    if not os.path.exists(args.outdir):
        os.mkdir(args.outdir)

    logfilename = os.path.join(args.outdir, 'log.txt')
    init_logfile(logfilename,
                 "epoch\ttime\tlr\ttrain loss\ttrain acc\ttestloss\ttest acc")

    criterion = CrossEntropyLoss().cuda()
    # criterion = BCELoss().cuda()
    optimizer = SGD(model.parameters(),
                    lr=args.lr,
                    momentum=args.momentum,
                    weight_decay=args.weight_decay)
    scheduler = StepLR(optimizer,
                       step_size=args.lr_step_size,
                       gamma=args.gamma)
    # optimizer = Adam(model.parameters(), lr=args.lr, betas=(0.99, 0.999), weight_decay=0.2)

    best_acc = 0.0
    # best_loss = sys.maxsize
    for epoch in range(args.epochs):
        scheduler.step(epoch)
        before = time.time()
        train_loss, train_acc = train(args, train_loader, model, criterion,
                                      optimizer, epoch, args.noise_sd)
        test_loss, test_acc = test(args, test_loader, model, criterion,
                                   args.noise_sd)
        after = time.time()

        log(
            logfilename, "{}\t{}\t{:.3}\t{:.3}\t{:.3}\t{:.3}\t{:.3}".format(
                epoch, str(datetime.timedelta(seconds=(after - before))), 0.0,
                train_loss, train_acc, test_loss, test_acc))

        if test_acc >= best_acc:
            torch.save({
                'epoch': epoch + 1,
                'state_dict': model.state_dict(),
            }, os.path.join(args.outdir, 'best.pth.tar'))
            best_acc = test_acc

        torch.save({
            'epoch': epoch + 1,
            'state_dict': model.state_dict(),
        }, os.path.join(args.outdir, 'checkpoint.pth.tar'))
Example No. 6
def main():

    #     writer = SummaryWriter()
    test_acc_list = []
    logfilename = os.path.join('.', 'log3.txt')
    init_logfile(logfilename,
                 "epoch\ttime\tlr\ttrain loss\ttrain acc\ttestloss\ttest acc")

    #     net = ResNet(BasicBlock, [3, 3, 3]).to(device)

    net = VGG_SNIP('D').to(device)
    #     criterion = nn.CrossEntropyLoss().to(device)
    criterion = nn.NLLLoss().to(device)
    optimizer = SGD(net.parameters(), lr=0.1, momentum=0.9, weight_decay=1e-4)
    scheduler = MultiStepLR(optimizer, milestones=[80, 120], last_epoch=-1)
    train_loader, test_loader = get_cifar10_dataloaders(128, 128)

    keep_masks = SNIP(net, 0.05, train_loader, device)  # TODO: shuffle?
    apply_prune_mask(net, keep_masks)

    for epoch in range(160):
        before = time.time()
        train_loss, train_acc = train(train_loader,
                                      net,
                                      criterion,
                                      optimizer,
                                      epoch,
                                      device,
                                      100,
                                      display=True)
        test_loss, test_acc = test(test_loader,
                                   net,
                                   criterion,
                                   device,
                                   100,
                                   display=True)

        scheduler.step(epoch)
        after = time.time()

        log(
            logfilename, "{}\t{}\t{:.3}\t{:.3}\t{:.3}\t{:.3}\t{:.3}".format(
                epoch, str(datetime.timedelta(seconds=(after - before))),
                scheduler.get_lr()[0], train_loss, train_acc, test_loss,
                test_acc))

        print("test_acc: ", test_acc)
    test_acc_list.append(test_acc)
    log(logfilename, "This is the test accuracy list for args.round.")
    log(logfilename, str(test_acc_list))
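
apply_prune_mask comes from the SNIP codebase and is not shown here. A hedged sketch of the usual implementation: zero the pruned weights once, then keep them zero by masking their gradients with a hook (the details are assumptions).

import torch.nn as nn

def apply_prune_mask(net, keep_masks):
    # Pair each prunable layer with its 0/1 mask, in traversal order (sketch).
    prunable = [m for m in net.modules()
                if isinstance(m, (nn.Linear, nn.Conv2d))]
    for layer, mask in zip(prunable, keep_masks):
        assert layer.weight.shape == mask.shape
        layer.weight.data.mul_(mask)  # zero the pruned weights now
        # Mask gradients on every backward pass so pruned weights stay zero.
        layer.weight.register_hook(lambda grad, mask=mask: grad * mask)
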
Example No. 7
def main():
    if args.gpu:
        os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu

    if not os.path.exists(args.outdir):
        os.makedirs(args.outdir)

    # Copy code to output directory
    copy_code(args.outdir)

    train_dataset = get_dataset(args.dataset, 'train')
    test_dataset = get_dataset(args.dataset, 'test')
    pin_memory = (args.dataset == "imagenet")
    train_loader = DataLoader(train_dataset,
                              shuffle=True,
                              batch_size=args.batch,
                              num_workers=args.workers,
                              pin_memory=pin_memory)
    test_loader = DataLoader(test_dataset,
                             shuffle=False,
                             batch_size=args.batch,
                             num_workers=args.workers,
                             pin_memory=pin_memory)
    ## This is used to test the performance of the denoiser attached to a cifar10 classifier
    cifar10_test_loader = DataLoader(get_dataset('cifar10', 'test'),
                                     shuffle=False,
                                     batch_size=args.batch,
                                     num_workers=args.workers,
                                     pin_memory=pin_memory)

    if args.pretrained_denoiser:
        checkpoint = torch.load(args.pretrained_denoiser)
        assert checkpoint['arch'] == args.arch
        denoiser = get_architecture(checkpoint['arch'], args.dataset)
        denoiser.load_state_dict(checkpoint['state_dict'])
    else:
        denoiser = get_architecture(args.arch, args.dataset)

    if args.optimizer == 'Adam':
        optimizer = Adam(denoiser.parameters(),
                         lr=args.lr,
                         weight_decay=args.weight_decay)
    elif args.optimizer == 'SGD':
        optimizer = SGD(denoiser.parameters(),
                        lr=args.lr,
                        momentum=args.momentum,
                        weight_decay=args.weight_decay)
    elif args.optimizer == 'AdamThenSGD':
        optimizer = Adam(denoiser.parameters(),
                         lr=args.lr,
                         weight_decay=args.weight_decay)
    else:
        raise ValueError('Unknown optimizer: {}'.format(args.optimizer))
    scheduler = StepLR(optimizer,
                       step_size=args.lr_step_size,
                       gamma=args.gamma)

    starting_epoch = 0
    logfilename = os.path.join(args.outdir, 'log.txt')

    ## Resume from the checkpoint if it exists and the resume flag is set
    denoiser_path = os.path.join(args.outdir, 'checkpoint.pth.tar')
    if args.resume and os.path.isfile(denoiser_path):
        print("=> loading checkpoint '{}'".format(denoiser_path))
        checkpoint = torch.load(denoiser_path,
                                map_location=lambda storage, loc: storage)
        assert checkpoint['arch'] == args.arch
        starting_epoch = checkpoint['epoch']
        denoiser.load_state_dict(checkpoint['state_dict'])
        if starting_epoch >= args.start_sgd_epoch and args.optimizer == 'AdamThenSGD':  # Do Adam for a few epochs, then continue with SGD
            print("-->[Switching from Adam to SGD.]")
            args.lr = args.start_sgd_lr
            optimizer = SGD(denoiser.parameters(),
                            lr=args.lr,
                            momentum=args.momentum,
                            weight_decay=args.weight_decay)
            scheduler = StepLR(optimizer,
                               step_size=args.lr_step_size,
                               gamma=args.gamma)

        optimizer.load_state_dict(checkpoint['optimizer'])
        print("=> loaded checkpoint '{}' (epoch {})".format(
            denoiser_path, checkpoint['epoch']))
    else:
        if args.resume:
            print("=> no checkpoint found at '{}'".format(args.outdir))
        init_logfile(logfilename,
                     "epoch\ttime\tlr\ttrainloss\ttestloss\ttestAcc")

    if args.objective == 'denoising':
        criterion = MSELoss(reduction='mean').cuda()
        best_loss = 1e6

    elif args.objective in ['classification', 'stability']:
        assert args.classifier != '', "Please specify a path to the classifier you want to attach the denoiser to."

        if args.classifier in IMAGENET_CLASSIFIERS:
            assert args.dataset == 'imagenet'
            # loading pretrained imagenet architectures
            clf = get_architecture(args.classifier,
                                   args.dataset,
                                   pytorch_pretrained=True)
        else:
            checkpoint = torch.load(args.classifier)
            clf = get_architecture(checkpoint['arch'], 'cifar10')
            clf.load_state_dict(checkpoint['state_dict'])
        clf.cuda().eval()
        requires_grad_(clf, False)
        criterion = CrossEntropyLoss(reduction='mean').cuda()
        best_acc = 0

    for epoch in range(starting_epoch, args.epochs):
        before = time.time()
        if args.objective == 'denoising':
            train_loss = train(train_loader, denoiser, criterion, optimizer,
                               epoch, args.noise_sd)
            test_loss = test(test_loader, denoiser, criterion, args.noise_sd,
                             args.print_freq, args.outdir)
            test_acc = 'NA'
        elif args.objective in ['classification', 'stability']:
            train_loss = train(train_loader, denoiser, criterion, optimizer,
                               epoch, args.noise_sd, clf)
            if args.dataset == 'imagenet':
                test_loss, test_acc = test_with_classifier(
                    test_loader, denoiser, criterion, args.noise_sd,
                    args.print_freq, clf)
            else:
                # This is needed so that cifar10 denoisers trained using imagenet32 are still evaluated on the cifar10 testset
                test_loss, test_acc = test_with_classifier(
                    cifar10_test_loader, denoiser, criterion, args.noise_sd,
                    args.print_freq, clf)

        after = time.time()

        log(
            logfilename, "{}\t{:.3}\t{:.3}\t{:.3}\t{:.3}\t{:.3}".format(
                epoch, after - before, args.lr, train_loss, test_loss,
                test_acc))

        scheduler.step(epoch)
        args.lr = scheduler.get_lr()[0]

        # Switch from Adam to SGD
        if epoch == args.start_sgd_epoch and args.optimizer == 'AdamThenSGD':  # Do Adam for a few epochs, then continue with SGD
            print("-->[Switching from Adam to SGD.]")
            args.lr = args.start_sgd_lr
            optimizer = SGD(denoiser.parameters(),
                            lr=args.lr,
                            momentum=args.momentum,
                            weight_decay=args.weight_decay)
            scheduler = StepLR(optimizer,
                               step_size=args.lr_step_size,
                               gamma=args.gamma)

        torch.save(
            {
                'epoch': epoch + 1,
                'arch': args.arch,
                'state_dict': denoiser.state_dict(),
                'optimizer': optimizer.state_dict(),
            }, os.path.join(args.outdir, 'checkpoint.pth.tar'))

        if args.objective == 'denoising' and test_loss < best_loss:
            best_loss = test_loss
        elif args.objective in ['classification', 'stability'
                                ] and test_acc > best_acc:
            best_acc = test_acc
        else:
            continue

        torch.save(
            {
                'epoch': epoch + 1,
                'arch': args.arch,
                'state_dict': denoiser.state_dict(),
                'optimizer': optimizer.state_dict(),
            }, os.path.join(args.outdir, 'best.pth.tar'))
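
For the 'stability' objective, the frozen classifier clf sits behind the denoiser and supplies its own clean-input predictions as targets (the denoised smoothing setup). A minimal per-batch sketch, hedged, since the repo's train helper also handles the 'classification' variant and logging:

import torch

def stability_step(denoiser, clf, criterion, inputs, noise_sd):
    # Pseudo-labels: the classifier's predictions on the clean inputs.
    with torch.no_grad():
        targets = clf(inputs).argmax(1)
    noisy = inputs + torch.randn_like(inputs) * noise_sd
    # Cross-entropy between classifier-on-denoised-noisy and the pseudo-labels.
    return criterion(clf(denoiser(noisy)), targets)
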
Example No. 8
def main():
    if args.gpu:
        os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu

    # args.outdir = os.path.join(os.getenv('PT_DATA_DIR', './'), args.outdir)
    
    if not os.path.exists(args.outdir):
        os.makedirs(args.outdir)

    # Copy code to output directory
    copy_code(args.outdir)
    
    train_dataset = get_dataset(args.dataset, 'train')
    test_dataset = get_dataset(args.dataset, 'test')
    pin_memory = (args.dataset == "imagenet")
    train_loader = DataLoader(train_dataset, shuffle=True, batch_size=args.batch,
                              num_workers=args.workers, pin_memory=pin_memory)
    test_loader = DataLoader(test_dataset, shuffle=False, batch_size=args.batch,
                             num_workers=args.workers, pin_memory=pin_memory)
    ## This is used to test the performance of the denoiser attached to a cifar10 classifier
    cifar10_test_loader = DataLoader(get_dataset('cifar10', 'test'), shuffle=False, batch_size=args.batch,
                                     num_workers=args.workers, pin_memory=pin_memory)

    if args.pretrained_denoiser:
        checkpoint = torch.load(args.pretrained_denoiser)
        assert checkpoint['arch'] == args.arch
        denoiser = get_architecture(checkpoint['arch'], args.dataset)
        denoiser.load_state_dict(checkpoint['state_dict'])
    else:
        denoiser = get_architecture(args.arch, args.dataset)

    if args.optimizer == 'Adam':
        optimizer = Adam(denoiser.parameters(), lr=args.lr, weight_decay=args.weight_decay)
    elif args.optimizer == 'SGD':
        optimizer = SGD(denoiser.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)
    elif args.optimizer == 'AdamThenSGD':
        optimizer = Adam(denoiser.parameters(), lr=args.lr, weight_decay=args.weight_decay)
    else:
        raise ValueError('Unknown optimizer: {}'.format(args.optimizer))
    scheduler = StepLR(optimizer, step_size=args.lr_step_size, gamma=args.gamma)

    starting_epoch = 0
    logfilename = os.path.join(args.outdir, 'log.txt')

    ## Resume from the checkpoint if it exists and the resume flag is set
    denoiser_path = os.path.join(args.outdir, 'checkpoint.pth.tar')
    if args.resume and os.path.isfile(denoiser_path):
        print("=> loading checkpoint '{}'".format(denoiser_path))
        checkpoint = torch.load(denoiser_path,
                                map_location=lambda storage, loc: storage)
        assert checkpoint['arch'] == args.arch
        starting_epoch = checkpoint['epoch']
        denoiser.load_state_dict(checkpoint['state_dict'])
        if starting_epoch >= args.start_sgd_epoch and args.optimizer == 'AdamThenSGD': # Do Adam for a few epochs, then continue with SGD
            print("-->[Switching from Adam to SGD.]")
            args.lr = args.start_sgd_lr
            optimizer = SGD(denoiser.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)
            scheduler = StepLR(optimizer, step_size=args.lr_step_size, gamma=args.gamma)
        
        optimizer.load_state_dict(checkpoint['optimizer'])
        print("=> loaded checkpoint '{}' (epoch {})"
                        .format(denoiser_path, checkpoint['epoch']))
    else:
        if args.resume:
            print("=> no checkpoint found at '{}'".format(args.outdir))
        init_logfile(logfilename, "epoch\ttime\tlr\ttrainloss\ttestloss\ttestAcc")

    # The first model is the one we test on; the other 14 are surrogate models
    # (see the paper for more detail).
    models = [
        "ResNet110",  # fixed model
        # surrogate models
        "WRN", "WRN40", "VGG16", "VGG19", "ResNet18", "PreActResNet18", "ResNeXt29_2x64d",
        "MobileNet", "MobileNetV2", "SENet18", "ShuffleNetV2", "EfficientNetB0", "GoogLeNet", "DenseNet121",
    ]

    path = os.path.join(args.classifiers_path, '{}/noise_0.00/checkpoint.pth.tar')

    base_classifiers_paths = [
        path.format(model) for model in models
    ]
    base_classifiers = []
    if args.classifier_idx == -1:
        for base_classifier in base_classifiers_paths:
            # load the base classifier
            checkpoint = torch.load(base_classifier)
            base_classifier = get_architecture(checkpoint["arch"], args.dataset)
            base_classifier.load_state_dict(checkpoint['state_dict'])

            requires_grad_(base_classifier, False)
            base_classifiers.append(base_classifier.eval().cuda())
    else:
        if args.classifier_idx not in range(len(models)):
            raise Exception("Unknown model")
        print("Model name: {}".format(base_classifiers_paths[args.classifier_idx]))
        checkpoint = torch.load(base_classifiers_paths[args.classifier_idx])
        base_classifier = get_architecture(checkpoint["arch"], args.dataset)
        base_classifier.load_state_dict(checkpoint['state_dict'])

        requires_grad_(base_classifier, False)
        base_classifiers.append(base_classifier.eval().cuda())

    criterion = CrossEntropyLoss(reduction='mean').cuda()
    best_acc = 0

    for epoch in range(starting_epoch, args.epochs):
        before = time.time()
        train_loss = train(train_loader, denoiser, criterion, optimizer, epoch, args.noise_sd, base_classifiers[1:])
        test_loss, test_acc = test_with_classifier(cifar10_test_loader, denoiser, criterion, args.noise_sd, args.print_freq, base_classifiers[0])

        after = time.time()

        log(logfilename, "{}\t{:.3}\t{:.3}\t{:.3}\t{:.3}\t{:.3}".format(
            epoch, after - before,
            args.lr, train_loss, test_loss, test_acc))

        scheduler.step(epoch)
        args.lr = scheduler.get_lr()[0]

        # Switch from Adam to SGD
        if epoch == args.start_sgd_epoch and args.optimizer == 'AdamThenSGD': # Do Adam for a few epochs, then continue with SGD
            print("-->[Switching from Adam to SGD.]")
            args.lr = args.start_sgd_lr
            optimizer = SGD(denoiser.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)
            scheduler = StepLR(optimizer, step_size=args.lr_step_size, gamma=args.gamma)

        torch.save({
            'epoch': epoch + 1,
            'arch': args.arch,
            'state_dict': denoiser.state_dict(),
            'optimizer': optimizer.state_dict(),
        }, os.path.join(args.outdir, 'checkpoint.pth.tar'))

        if args.objective in ['classification', 'stability'] and test_acc > best_acc:
            best_acc = test_acc
        else:
            continue

        torch.save({
            'epoch': epoch + 1,
            'arch': args.arch,
            'state_dict': denoiser.state_dict(),
            'optimizer': optimizer.state_dict(),
        }, os.path.join(args.outdir, 'best.pth.tar'))
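
Here train receives the surrogate classifiers (base_classifiers[1:]) while test_with_classifier uses the fixed model (base_classifiers[0]). A sketch of one way the per-batch surrogate loss could be formed; averaging over surrogates is an assumption, not necessarily what the repo does.

import torch

def surrogate_loss(denoiser, surrogates, criterion, inputs, noise_sd):
    noisy = inputs + torch.randn_like(inputs) * noise_sd
    denoised = denoiser(noisy)
    losses = []
    for clf in surrogates:  # frozen, eval-mode surrogate classifiers
        with torch.no_grad():
            targets = clf(inputs).argmax(1)
        losses.append(criterion(clf(denoised), targets))
    return torch.stack(losses).mean()
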
Example No. 9
def main():
    if not os.path.exists(args.outdir):
        os.mkdir(args.outdir)

    device = torch.device("cuda")
    torch.cuda.set_device(args.gpu)

    logfilename = os.path.join(args.outdir, args.logname)

    log(logfilename, "Hyperparameter List")
    log(logfilename, "Epochs: {:}".format(args.epochs))
    log(logfilename, "Learning Rate: {:}".format(args.lr))
    log(logfilename, "Alpha: {:}".format(args.alpha))
    log(logfilename, "Keep ratio: {:}".format(args.keep_ratio))

    test_acc_list = []
    for _ in range(args.round):
        train_dataset = get_dataset(args.dataset, 'train')
        test_dataset = get_dataset(args.dataset, 'test')
        pin_memory = (args.dataset == "imagenet")
        train_loader = DataLoader(train_dataset,
                                  shuffle=True,
                                  batch_size=args.batch,
                                  num_workers=args.workers,
                                  pin_memory=pin_memory)
        test_loader = DataLoader(test_dataset,
                                 shuffle=False,
                                 batch_size=args.batch,
                                 num_workers=args.workers,
                                 pin_memory=pin_memory)

        # Loading the base_classifier
        base_classifier = get_architecture(args.arch, args.dataset, device)
        checkpoint = torch.load(args.savedir)
        base_classifier.load_state_dict(checkpoint['state_dict'])
        base_classifier.eval()
        print("Loaded the base_classifier")

        original_acc = model_inference(base_classifier,
                                       test_loader,
                                       device,
                                       display=True)

        log(logfilename,
            "Original Model Test Accuracy: {:.5}".format(original_acc))
        print("Original Model Test Accuracy, ", original_acc)

        # Create a fresh copy of the network so the original is unaffected.
        net = copy.deepcopy(base_classifier)
        net = net.to(device)

        # Generating the mask 'm'
        for layer in net.modules():
            if isinstance(layer, nn.Linear) or isinstance(layer, nn.Conv2d):
                layer.weight_mask = nn.Parameter(torch.ones_like(layer.weight))

                layer.weight.requires_grad = True
                layer.weight_mask.requires_grad = True

            # This monkey-patch overrides layer.forward with a custom function:
            # the patched forward uses the elementwise product of weight 'w' and mask 'm'.
            if isinstance(layer, nn.Linear):
                layer.forward = types.MethodType(mask_forward_linear, layer)

            if isinstance(layer, nn.Conv2d):
                layer.forward = types.MethodType(mask_forward_conv2d, layer)

        criterion = nn.NLLLoss().to(device)  # A LogSoftmax layer was added to every architecture.
        optimizer = SGD(
            net.parameters(),
            lr=args.lr,
            momentum=args.momentum,
            weight_decay=0)  # weight_decay = 0 for training the mask.

        sparsity, total = 0, 0
        for layer in net.modules():
            if isinstance(layer, nn.Linear) or isinstance(layer, nn.Conv2d):
                boolean_list = layer.weight_mask.data > args.threshold
                sparsity += (boolean_list == 1).sum()
                total += layer.weight.numel()

        # Train the mask on the training set. The loop is capped at 300 epochs
        # in case the sparsity of the auxiliary parameters never drops below
        # the target.
        for epoch in range(300):
            if epoch % 5 == 0:
                print("Current epoch: ", epoch)
                print("Sparsity: {:}".format(sparsity))
            train_loss = mask_train(train_loader,
                                    net,
                                    criterion,
                                    optimizer,
                                    epoch,
                                    device,
                                    alpha=args.alpha,
                                    display=False)
            acc = model_inference(net, test_loader, device, display=False)
            log(logfilename,
                "Epoch {:}, Mask Update Test Acc: {:.5}".format(epoch, acc))

            sparsity, total = 0, 0
            for layer in net.modules():
                if isinstance(layer, nn.Linear) or isinstance(
                        layer, nn.Conv2d):
                    boolean_list = layer.weight_mask.data > args.threshold
                    sparsity += (boolean_list == 1).sum()
                    total += layer.weight.numel()

            if sparsity <= total * args.keep_ratio:
                print("Breaking out of mask training at epoch {:}".format(epoch))
                break

        mask_update_acc = model_inference(net,
                                          test_loader,
                                          device,
                                          display=True)
        log(logfilename,
            "Mask Update Test Accuracy: {:.5}".format(mask_update_acc))

        # Compute the threshold that satisfies the keep_ratio.
        c_abs = []
        for layer in net.modules():
            if isinstance(layer, nn.Linear) or isinstance(layer, nn.Conv2d):
                c_abs.append(torch.abs(layer.weight_mask))

        all_scores = torch.cat([torch.flatten(x) for x in c_abs])
        num_params_to_keep = int(len(all_scores) * args.keep_ratio)
        threshold, _ = torch.topk(all_scores, num_params_to_keep, sorted=True)
        threshold = threshold[-1]

        keep_masks = []
        for c in c_abs:
            keep_masks.append((c >= threshold).float())
        print(
            "Number of ones.",
            torch.sum(torch.cat([torch.flatten(x == 1) for x in keep_masks])))

        # Update each weight with the elementwise product of weight 'w' and mask 'm'.
        for layer in net.modules():
            if isinstance(layer, nn.Linear) or isinstance(layer, nn.Conv2d):
                layer.weight.data = layer.weight.data * layer.weight_mask.data
                layer.zeros = nn.Parameter(
                    torch.zeros_like(layer.weight))  # Dummy parameter.
                layer.ones = nn.Parameter(
                    torch.ones_like(layer.weight))  # Dummy parameter.
                # weight_mask becomes a hard 0/1 mask again.
                layer.weight_mask.data = torch.where(
                    torch.abs(layer.weight_mask) <= threshold, layer.zeros,
                    layer.ones)

                # Temporarily disabling the backprop for both 'w' and 'm'.
                layer.weight.requires_grad = False
                layer.weight_mask.requires_grad = False

            if isinstance(layer, nn.Linear):
                layer.forward = types.MethodType(mask_forward_linear, layer)

            if isinstance(layer, nn.Conv2d):
                layer.forward = types.MethodType(mask_forward_conv2d, layer)

        weight_update_acc = model_inference(net,
                                            test_loader,
                                            device,
                                            display=True)
        log(logfilename,
            "Weight Update Test Accuracy: {:.5}".format(weight_update_acc))

        # Calculating the sparsity of the network.
        remain = 0
        total = 0
        for layer in net.modules():
            if isinstance(layer, nn.Linear) or isinstance(layer, nn.Conv2d):
                total += torch.norm(torch.ones_like(layer.weight),
                                    p=1)  # Count the total number of parameters.
                remain += torch.norm(layer.weight_mask.data,
                                     p=1)  # Count the ones in the mask.

                # Disabling backprop except weight 'w' for the finetuning.
                layer.zeros.requires_grad = False
                layer.ones.requires_grad = False
                layer.weight_mask.requires_grad = False
                layer.weight.requires_grad = True

            if isinstance(layer, nn.Linear):
                layer.forward = types.MethodType(mask_forward_linear, layer)

            if isinstance(layer, nn.Conv2d):
                layer.forward = types.MethodType(mask_forward_conv2d, layer)

        log(logfilename, "Sparsity: {:.3}".format(remain / total))
        print("Sparsity: ", remain / total)

        #        --------------------------------
        # We need to transfer the weight we learned from "net" to "base_classifier".
        for (layer1, layer2) in zip(base_classifier.modules(), net.modules()):
            if isinstance(layer1, (nn.Linear, nn.Conv2d)) or isinstance(
                    layer2, (nn.Linear, nn.Conv2d)):
                layer1.weight.data = layer2.weight.data
                if layer1.bias is not None:
                    layer1.bias.data = layer2.bias.data
                    layer1.bias.requires_grad = True

                layer1.weight.requires_grad = True

        # Applying the mask to the base_classifier.
        apply_prune_mask(base_classifier, keep_masks)
        #        --------------------------------

        optimizer = SGD(base_classifier.parameters(),
                        lr=1e-3,
                        momentum=args.momentum,
                        weight_decay=args.weight_decay)
        loss = nn.NLLLoss()
        scheduler = StepLR(optimizer,
                           step_size=args.lr_step_size,
                           gamma=args.gamma)

        test_acc = []
        # Finetuning via ignite
        trainer = create_supervised_trainer(base_classifier, optimizer,
                                            nn.NLLLoss(), device)
        evaluator = create_supervised_evaluator(base_classifier, {
            'accuracy': Accuracy(),
            'nll': Loss(loss)
        }, device)

        pbar = ProgressBar()
        pbar.attach(trainer)

        @trainer.on(Events.ITERATION_COMPLETED)
        def log_training_loss(engine):
            iter_in_epoch = (engine.state.iteration -
                             1) % len(train_loader) + 1
            if engine.state.iteration % args.print_freq == 0:
                pbar.log_message("Epoch[{}] Iteration[{}/{}] Loss: {:.2f}"
                                 "".format(engine.state.epoch, iter_in_epoch,
                                           len(train_loader),
                                           engine.state.output))

        @trainer.on(Events.EPOCH_COMPLETED)
        def log_epoch(engine):
            scheduler.step()
            evaluator.run(test_loader)

            metrics = evaluator.state.metrics
            avg_accuracy = metrics['accuracy']
            avg_nll = metrics['nll']

            pbar.log_message(
                "Validation Results - Epoch: {}  Avg accuracy: {:.5f} Avg loss: {:.3f}"
                .format(engine.state.epoch, avg_accuracy, avg_nll))

            log(
                logfilename,
                "Validation  - Epoch: {}  Avg accuracy: {:.5f} Avg loss: {:.3f}"
                .format(engine.state.epoch, avg_accuracy, avg_nll))

            test_acc.append(avg_accuracy)

            if avg_accuracy >= max(test_acc):
                print("Saving the model at Epoch {:}".format(
                    engine.state.epoch))
                torch.save(
                    {
                        'arch': args.arch,
                        'state_dict': base_classifier.state_dict(),
                        'optimizer': optimizer.state_dict(),
                    }, os.path.join(args.outdir, 'checkpoint.pth.tar'))

            if engine.state.epoch == args.epochs:
                test_acc_list.append(max(test_acc))
                log(logfilename,
                    "Finetuned Test Accuracy: {:.5f}".format(max(test_acc)))
                print("Finetuned Test Accuracy: ", max(test_acc))

        trainer.run(train_loader, args.epochs)

    log(logfilename, "This is the test accuracy list for args.round.")
    log(logfilename, str(test_acc_list))
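
mask_forward_linear and mask_forward_conv2d are defined elsewhere; per the comments above, the patched forward uses the elementwise product of weight 'w' and mask 'm'. A plausible sketch (the exact bodies are assumptions):

import torch.nn.functional as F

def mask_forward_linear(self, x):
    # nn.Linear forward with the effective weight w * m.
    return F.linear(x, self.weight * self.weight_mask, self.bias)

def mask_forward_conv2d(self, x):
    # nn.Conv2d forward with the effective weight w * m.
    return F.conv2d(x, self.weight * self.weight_mask, self.bias,
                    self.stride, self.padding, self.dilation, self.groups)
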
Example No. 10
def main():

    if not os.path.exists(args.outdir):
        os.mkdir(args.outdir)

    device = torch.device("cuda")
    torch.cuda.set_device(args.gpu)

    logfilename = os.path.join(args.outdir, args.logname)

    log(logfilename, "Hyperparameter List")
    log(logfilename, "Epochs: {:}".format(args.epochs))
    log(logfilename, "Learning Rate: {:}".format(args.lr))
    log(logfilename, "Alpha: {:}".format(args.alpha))
    log(logfilename, "Keep ratio: {:}".format(args.keep_ratio))
    log(logfilename, "Warmup Epochs: {:}".format(args.epochs_warmup))

    test_acc_list = []
    for _ in range(args.round):
        train_dataset = get_dataset(args.dataset, 'train')
        test_dataset = get_dataset(args.dataset, 'test')
        pin_memory = (args.dataset == "imagenet")
        train_loader = DataLoader(train_dataset,
                                  shuffle=True,
                                  batch_size=args.batch,
                                  num_workers=args.workers,
                                  pin_memory=pin_memory)
        test_loader = DataLoader(test_dataset,
                                 shuffle=False,
                                 batch_size=args.batch,
                                 num_workers=args.workers,
                                 pin_memory=pin_memory)

        # Loading the base_classifier
        base_classifier = get_architecture(args.arch, args.dataset, device)
        print("Loaded the base_classifier")

        criterion = nn.NLLLoss().to(device)
        optimizer = SGD(base_classifier.parameters(),
                        lr=args.lr,
                        momentum=args.momentum,
                        weight_decay=args.weight_decay)

        # Warmup training for the rewinding.
        for epoch in range(args.epochs_warmup):
            print("Warmup Training Epochs: {:}".format(epoch))
            train_loss, train_top1, train_top5 = utils.train(train_loader,
                                                             base_classifier,
                                                             criterion,
                                                             optimizer,
                                                             epoch,
                                                             device,
                                                             print_freq=100,
                                                             display=False)

        original_acc = model_inference(base_classifier,
                                       test_loader,
                                       device,
                                       display=True)
        log(logfilename,
            "Warmup Model Test Accuracy: {:.5}".format(original_acc))
        print("Warmup Model Test Accuracy, ", original_acc)

        # Create a fresh copy of the network so the original is unaffected.
        # The goal is to find the supermask.

        net = copy.deepcopy(base_classifier)
        net = net.to(device)

        # Generating the mask 'm'
        for layer in net.modules():
            if isinstance(layer, nn.Linear) or isinstance(layer, nn.Conv2d):
                layer.weight_mask = nn.Parameter(torch.ones_like(layer.weight))

                layer.weight.requires_grad = True
                layer.weight_mask.requires_grad = True

            # This monkey-patch overrides layer.forward with a custom function:
            # the patched forward uses the elementwise product of weight 'w' and mask 'm'.
            if isinstance(layer, nn.Linear):
                layer.forward = types.MethodType(mask_forward_linear, layer)

            if isinstance(layer, nn.Conv2d):
                layer.forward = types.MethodType(mask_forward_conv2d, layer)

        criterion = nn.NLLLoss().to(device)
        optimizer = SGD(net.parameters(),
                        lr=args.lr,
                        momentum=args.momentum,
                        weight_decay=0)
        # weight_decay = 0 for training the mask.

        sparsity, total = 0, 0
        for layer in net.modules():
            if isinstance(layer, nn.Linear) or isinstance(layer, nn.Conv2d):
                boolean_list = layer.weight_mask.data > args.threshold
                sparsity += (boolean_list == 1).sum()
                total += layer.weight.numel()

        # Training the mask with the training set.
        for epoch in range(300):
            if epoch % 5 == 0:
                print("Current epoch: ", epoch)
                print("Sparsity: {:}".format(sparsity))
            before = time.time()
            train_loss = mask_train(train_loader,
                                    net,
                                    criterion,
                                    optimizer,
                                    epoch,
                                    device,
                                    alpha=args.alpha,
                                    display=False)
            acc = model_inference(net, test_loader, device, display=False)
            log(logfilename,
                "Epoch {:}, Mask Update Test Acc: {:.5}".format(epoch, acc))

            sparsity = 0
            total = 0
            for layer in net.modules():
                if isinstance(layer, nn.Linear) or isinstance(
                        layer, nn.Conv2d):
                    boolean_list = layer.weight_mask.data > args.threshold
                    sparsity += (boolean_list == 1).sum()
                    total += layer.weight.numel()

            if sparsity <= total * args.keep_ratio:
                print("Breaking out of mask training at epoch {:}".format(epoch))
                break

        # Compute the threshold that satisfies the keep_ratio.
        c_abs = []
        for layer in net.modules():
            if isinstance(layer, nn.Linear) or isinstance(layer, nn.Conv2d):
                c_abs.append(torch.abs(layer.weight_mask))

        all_scores = torch.cat([torch.flatten(x) for x in c_abs])
        num_params_to_keep = int(len(all_scores) * args.keep_ratio)
        threshold, _ = torch.topk(all_scores, num_params_to_keep, sorted=True)
        threshold = threshold[-1]

        keep_masks = []
        for c in c_abs:
            keep_masks.append((c >= threshold).float())
        print(
            "Number of ones.",
            torch.sum(torch.cat([torch.flatten(x == 1) for x in keep_masks])))

        # Applying the mask to the original network.
        apply_prune_mask(base_classifier, keep_masks)

        mask_update_acc = model_inference(base_classifier,
                                          test_loader,
                                          device,
                                          display=True)
        log(logfilename,
            "Untrained Network Test Accuracy: {:.5}".format(mask_update_acc))
        print("Untrained Network Test Accuracy: {:.5}".format(mask_update_acc))

        optimizer = SGD(base_classifier.parameters(),
                        lr=args.lr,
                        momentum=args.momentum,
                        weight_decay=args.weight_decay)
        loss = nn.NLLLoss()
        scheduler = MultiStepLR(
            optimizer,
            milestones=[
                int(args.epochs * 0.5) - args.epochs_warmup,
                int(args.epochs * 0.75) - args.epochs_warmup
            ],
            last_epoch=-1)

        test_acc = []  # Collecting the test accuracy

        # Finetuning via ignite
        trainer = create_supervised_trainer(base_classifier, optimizer,
                                            nn.NLLLoss(), device)
        evaluator = create_supervised_evaluator(base_classifier, {
            'accuracy': Accuracy(),
            'nll': Loss(loss)
        }, device)

        pbar = ProgressBar()
        pbar.attach(trainer)

        @trainer.on(Events.ITERATION_COMPLETED)
        def log_training_loss(engine):
            iter_in_epoch = (engine.state.iteration -
                             1) % len(train_loader) + 1
            if engine.state.iteration % args.print_freq == 0:
                pbar.log_message("Epoch[{}] Iteration[{}/{}] Loss: {:.2f}"
                                 "".format(engine.state.epoch, iter_in_epoch,
                                           len(train_loader),
                                           engine.state.output))

        @trainer.on(Events.EPOCH_COMPLETED)
        def log_epoch(engine):
            scheduler.step()
            evaluator.run(test_loader)

            metrics = evaluator.state.metrics
            avg_accuracy = metrics['accuracy']
            avg_nll = metrics['nll']

            pbar.log_message(
                "Validation Results - Epoch: {}  Avg accuracy: {:.3f} Avg loss: {:.3f}"
                .format(engine.state.epoch + args.epochs_warmup, avg_accuracy,
                        avg_nll))

            log(
                logfilename,
                "Validation  - Epoch: {}  Avg accuracy: {:.3f} Avg loss: {:.3f}"
                .format(engine.state.epoch + args.epochs_warmup, avg_accuracy,
                        avg_nll))

            test_acc.append(avg_accuracy)

            if avg_accuracy >= max(test_acc):
                print(
                    "Saving the model at Epoch {:}".format(engine.state.epoch +
                                                           args.epochs_warmup))
                torch.save(
                    {
                        'arch': args.arch,
                        'state_dict': base_classifier.state_dict(),
                        'optimizer': optimizer.state_dict(),
                    }, os.path.join(args.outdir, 'checkpoint.pth.tar'))

            if engine.state.epoch + args.epochs_warmup == args.epochs:
                test_acc_list.append(max(test_acc))
                log(logfilename,
                    "Finetuned Test Accuracy: {:.5f}".format(max(test_acc)))
                print("Finetuned Test Accuracy: ", max(test_acc))

        trainer.run(train_loader, args.epochs - args.epochs_warmup)

    log(logfilename, "This is the test accuracy list for args.round.")
    log(logfilename, str(test_acc_list))
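
Both pruning examples pick the threshold as the k-th largest mask magnitude, so that exactly keep_ratio of the weights survive. A toy check of that logic (the numbers are made up):

import torch

scores = torch.tensor([0.9, 0.1, 0.5, 0.7])
k = int(len(scores) * 0.5)                 # keep_ratio = 0.5 -> keep 2 weights
threshold = torch.topk(scores, k, sorted=True).values[-1]  # tensor(0.7000)
mask = (scores >= threshold).float()       # tensor([1., 0., 0., 1.])
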
Example No. 11
checkpoint = torch.load(args.base_classifier)
model = get_architecture(checkpoint["arch"], args.dataset)
model.load_state_dict(checkpoint['state_dict'])
model.eval()
criterion = CrossEntropyLoss().cuda()

test_dataset = get_dataset(args.dataset, 'test')
test_loader = DataLoader(test_dataset,
                         shuffle=False,
                         batch_size=args.batch,
                         num_workers=4,
                         pin_memory=False)

logfilename = os.path.join(args.outdir, 'log.txt')
log(logfilename, "{0}".format(args))

correct = 0
total = 0
torch.manual_seed(12345)
for i, (images, labels) in enumerate(test_loader):

    images = images.cuda()
    labels = labels.cuda()
    roa = ROA(model, 32)

    learning_rate = args.attlr
    iterations = args.attiters
    ROAwidth = args.ROAwidth
    ROAheight = args.ROAheight
    skip_in_x = args.skip_in_x
Example No. 12
def main():
    if args.gpu:
        os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu

    if not os.path.exists(args.outdir):
        os.makedirs(args.outdir)

    train_dataset = get_dataset(args.dataset, 'train')
    test_dataset = get_dataset(args.dataset, 'test')
    pin_memory = (args.dataset == "imagenet")
    train_loader = DataLoader(train_dataset,
                              shuffle=True,
                              batch_size=args.batch,
                              num_workers=args.workers,
                              pin_memory=pin_memory)
    test_loader = DataLoader(test_dataset,
                             shuffle=False,
                             batch_size=args.batch,
                             num_workers=args.workers,
                             pin_memory=pin_memory)

    model = get_architecture(args.arch, args.dataset)

    if args.pretrain is not None:
        if args.pretrain == 'torchvision':
            # load a pretrained model from torchvision
            if args.dataset == 'imagenet' and args.arch == 'resnet50':
                model = torchvision.models.resnet50(pretrained=True).cuda()

                # prepend the ImageNet normalization layer
                normalize_layer = get_normalize_layer('imagenet').cuda()
                model = torch.nn.Sequential(normalize_layer, model)

                print('loaded from torchvision for imagenet resnet50')
            else:
                raise Exception(f'Unsupported pretrain arg {args.pretrain}')
        else:
            # load the base classifier
            checkpoint = torch.load(args.pretrain)
            model.load_state_dict(checkpoint['state_dict'])
            print(f'loaded from {args.pretrain}')

    logfilename = os.path.join(args.outdir, 'log.txt')
    init_logfile(logfilename,
                 "epoch\ttime\tlr\ttrain loss\ttrain acc\ttestloss\ttest acc")
    writer = SummaryWriter(args.outdir)

    # Grab one training image as a template for the transformer.
    inputs, _ = next(iter(train_loader))
    canopy = inputs[0]
    transformer = gen_transformer(args, canopy)

    criterion = CrossEntropyLoss().cuda()
    optimizer = SGD(model.parameters(),
                    lr=args.lr,
                    momentum=args.momentum,
                    weight_decay=args.weight_decay)
    scheduler = StepLR(optimizer,
                       step_size=args.lr_step_size,
                       gamma=args.gamma)

    for epoch in range(args.epochs):
        before = time.time()
        train_loss, train_acc = train(train_loader, model, criterion,
                                      optimizer, epoch, transformer, writer)
        test_loss, test_acc = test(test_loader, model, criterion, epoch,
                                   transformer, writer, args.print_freq)
        after = time.time()

        scheduler.step(epoch)

        log(
            logfilename, "{}\t{}\t{:.3}\t{:.3}\t{:.3}\t{:.3}\t{:.3}".format(
                epoch, str(datetime.timedelta(seconds=(after - before))),
                scheduler.get_lr()[0], train_loss, train_acc, test_loss,
                test_acc))

        torch.save(
            {
                'epoch': epoch + 1,
                'arch': args.arch,
                'state_dict': model.state_dict(),
                'optimizer': optimizer.state_dict(),
            }, os.path.join(args.outdir, 'checkpoint.pth.tar'))
Example No. 13
def main():
    torch.manual_seed(0)
    if args.gpu:
        os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu

    if not os.path.exists(args.outdir):
        os.mkdir(args.outdir)

    train_loader, test_loader = get_dataset(args.dataset,
                                            'train',
                                            args.batch,
                                            num_workers=args.workers)

    model = get_architecture(args.arch, args.dataset)

    logfilename = os.path.join(args.outdir, 'log.txt')
    init_logfile(logfilename,
                 "epoch\ttime\tlr\ttrain loss\ttrain acc\ttestloss\ttest acc")

    criterion = CrossEntropyLoss().cuda()
    if args.optimizer == 'momentum':
        optimizer = SGD(model.parameters(),
                        lr=args.lr,
                        momentum=args.momentum,
                        weight_decay=args.weight_decay)
    elif args.optimizer == 'nesterov':
        optimizer = SGD(model.parameters(),
                        lr=args.lr,
                        momentum=args.momentum,
                        weight_decay=args.weight_decay,
                        nesterov=True)
    elif args.optimizer == 'amsgrad':
        optimizer = Adam(model.parameters(),
                         lr=args.lr,
                         weight_decay=args.weight_decay,
                         amsgrad=True)
    else:
        raise ValueError('Unknown optimizer: {}'.format(args.optimizer))

    scheduler = StepLR(optimizer,
                       step_size=args.lr_step_size,
                       gamma=args.gamma)

    best_acc = 0.
    for epoch in range(args.epochs):
        scheduler.step(epoch)
        before = time.time()
        train_loss, train_acc = train(train_loader, model, criterion,
                                      optimizer, epoch, args)
        test_loss, test_acc = test(test_loader, model, criterion, epoch, args)
        after = time.time()

        log(
            logfilename, "{}\t{}\t{:.3}\t{:.3}\t{:.3}\t{:.3}\t{:.3}".format(
                epoch, str(datetime.timedelta(seconds=(after - before))),
                scheduler.get_lr()[0], train_loss, train_acc, test_loss,
                test_acc))

        if args.tune and best_acc < test_acc:
            best_acc = test_acc
            print('saving best model...')
            torch.save(
                {
                    'epoch': epoch + 1,
                    'arch': args.arch,
                    'state_dict': model.state_dict(),
                    'optimizer': optimizer.state_dict(),
                }, os.path.join(args.outdir, 'checkpoint.best.tar'))

        torch.save(
            {
                'epoch': epoch + 1,
                'arch': args.arch,
                'state_dict': model.state_dict(),
                'optimizer': optimizer.state_dict(),
            }, os.path.join(args.outdir, 'checkpoint.pth.tar'))
Example No. 14
def main():
    if args.gpu:
        os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu

    if not os.path.exists(args.outdir):
        os.makedirs(args.outdir)

    # Copies files to the outdir to store complete script with each experiment
    copy_code(args.outdir)

    train_dataset = get_dataset(args.dataset, 'train')
    test_dataset = get_dataset(args.dataset, 'test')
    pin_memory = (args.dataset == "imagenet")
    train_loader = DataLoader(train_dataset,
                              shuffle=True,
                              batch_size=args.batch,
                              num_workers=args.workers,
                              pin_memory=pin_memory)
    test_loader = DataLoader(test_dataset,
                             shuffle=False,
                             batch_size=args.batch,
                             num_workers=args.workers,
                             pin_memory=pin_memory)

    model = get_architecture(args.arch, args.dataset)
    if args.attack == 'PGD':
        print('Attacker is PGD')
        attacker = PGD_L2(steps=args.num_steps,
                          device='cuda',
                          max_norm=args.epsilon)
    elif args.attack == 'DDN':
        print('Attacker is DDN')
        attacker = DDN(steps=args.num_steps,
                       device='cuda',
                       max_norm=args.epsilon,
                       init_norm=args.init_norm_DDN,
                       gamma=args.gamma_DDN)
    else:
        raise ValueError('Unknown attack: {}'.format(args.attack))

    criterion = CrossEntropyLoss().cuda()
    optimizer = SGD(model.parameters(),
                    lr=args.lr,
                    momentum=args.momentum,
                    weight_decay=args.weight_decay)
    scheduler = StepLR(optimizer,
                       step_size=args.lr_step_size,
                       gamma=args.gamma)

    starting_epoch = 0
    logfilename = os.path.join(args.outdir, 'log.txt')

    # Load latest checkpoint if exists (to handle philly failures)
    model_path = os.path.join(args.outdir, 'checkpoint.pth.tar')
    if args.resume and os.path.isfile(model_path):
        print("=> loading checkpoint '{}'".format(model_path))
        checkpoint = torch.load(model_path,
                                map_location=lambda storage, loc: storage)
        starting_epoch = checkpoint['epoch']
        model.load_state_dict(checkpoint['state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        print("=> loaded checkpoint '{}' (epoch {})".format(
            model_path, checkpoint['epoch']))
    else:
        if args.resume:
            print("=> no checkpoint found at '{}'".format(model_path))
        if args.adv_training:
            init_logfile(
                logfilename,
                "epoch\ttime\tlr\ttrainloss\ttestloss\ttrainacc\ttestacc\ttestaccNor"
            )
        else:
            init_logfile(
                logfilename,
                "epoch\ttime\tlr\ttrainloss\ttestloss\ttrainacc\ttestacc")

    for epoch in range(starting_epoch, args.epochs):
        scheduler.step(epoch)
        # Linearly ramp the attack radius up to args.epsilon over args.warmup epochs.
        attacker.max_norm = min(args.epsilon, (epoch + 1) * args.epsilon / args.warmup)
        attacker.init_norm = min(args.epsilon, (epoch + 1) * args.epsilon / args.warmup)

        before = time.time()
        train_loss, train_acc = train(train_loader, model, criterion,
                                      optimizer, epoch, args.noise_sd,
                                      attacker)
        test_loss, test_acc, test_acc_normal = test(test_loader, model,
                                                    criterion, args.noise_sd,
                                                    attacker)
        after = time.time()

        if args.adv_training:
            log(
                logfilename,
                "{}\t{:.2f}\t{:.3}\t{:.3}\t{:.3}\t{:.3}\t{:.3}\t{:.3}".format(
                    epoch, after - before,
                    scheduler.get_lr()[0], train_loss, test_loss, train_acc,
                    test_acc, test_acc_normal))
        else:
            log(
                logfilename,
                "{}\t{:.2f}\t{:.3}\t{:.3}\t{:.3}\t{:.3}\t{:.3}".format(
                    epoch, after - before,
                    scheduler.get_lr()[0], train_loss, test_loss, train_acc,
                    test_acc))

        torch.save(
            {
                'epoch': epoch + 1,
                'arch': args.arch,
                'state_dict': model.state_dict(),
                'optimizer': optimizer.state_dict(),
            }, model_path)
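
# The adversarial example above ramps the attack radius linearly over
# args.warmup epochs. A minimal sketch of that schedule as a standalone
# helper; epsilon_schedule is a hypothetical name, not part of the original
# code:
def epsilon_schedule(epoch, epsilon, warmup_epochs):
    """Linear warmup: reach the full radius epsilon after warmup_epochs."""
    return min(epsilon, (epoch + 1) * epsilon / warmup_epochs)

# Usage inside the epoch loop, mirroring the example:
#     attacker.max_norm = epsilon_schedule(epoch, args.epsilon, args.warmup)
#     attacker.init_norm = epsilon_schedule(epoch, args.epsilon, args.warmup)
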
Example No. 15
def main():
    if args.gpu:
        os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu

    if not os.path.exists(args.outdir):
        os.makedirs(args.outdir)

    # Copy code to output directory
    copy_code(args.outdir)

    train_dataset = get_dataset(args.dataset, 'train')
    test_dataset = get_dataset(args.dataset, 'test')
    pin_memory = (args.dataset == "imagenet")
    train_loader = DataLoader(train_dataset,
                              shuffle=True,
                              batch_size=args.batch,
                              num_workers=args.workers,
                              pin_memory=pin_memory)
    test_loader = DataLoader(test_dataset,
                             shuffle=False,
                             batch_size=args.batch,
                             num_workers=args.workers,
                             pin_memory=pin_memory)

    model = get_architecture(args.arch, args.dataset)

    criterion = CrossEntropyLoss().cuda()
    optimizer = SGD(model.parameters(),
                    lr=args.lr,
                    momentum=args.momentum,
                    weight_decay=args.weight_decay)
    scheduler = StepLR(optimizer,
                       step_size=args.lr_step_size,
                       gamma=args.gamma)

    starting_epoch = 0
    logfilename = os.path.join(args.outdir, 'log.txt')

    # Resume from a checkpoint if one exists and the resume flag is set.
    model_path = os.path.join(args.outdir, 'checkpoint.pth.tar')
    if args.resume and os.path.isfile(model_path):
        print("=> loading checkpoint '{}'".format(model_path))
        checkpoint = torch.load(model_path,
                                map_location=lambda storage, loc: storage)
        starting_epoch = checkpoint['epoch']
        model.load_state_dict(checkpoint['state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        print("=> loaded checkpoint '{}' (epoch {})".format(
            model_path, checkpoint['epoch']))
    else:
        if args.resume:
            print("=> no checkpoint found at '{}'".format(args.outdir))
        init_logfile(
            logfilename,
            "epoch\ttime\tlr\ttrainloss\ttestloss\ttrainAcc\ttestAcc")

    for epoch in range(starting_epoch, args.epochs):
        before = time.time()
        train_loss, train_acc = train(train_loader, model, criterion,
                                      optimizer, epoch, args.noise_sd)
        test_loss, test_acc = test(test_loader, model, criterion,
                                   args.noise_sd)
        after = time.time()

        log(
            logfilename, "{}\t{:.3}\t{:.3}\t{:.3}\t{:.3}\t{:.3}\t{:.3}".format(
                epoch, after - before,
                scheduler.get_lr()[0], train_loss, test_loss, train_acc,
                test_acc))

        scheduler.step(epoch)

        torch.save(
            {
                'epoch': epoch + 1,
                'arch': args.arch,
                'state_dict': model.state_dict(),
                'optimizer': optimizer.state_dict(),
            }, os.path.join(args.outdir, 'checkpoint.pth.tar'))
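
# The save/resume pattern above recurs across these examples. A minimal
# sketch of it as a pair of helpers; save_checkpoint and load_checkpoint are
# hypothetical names, and the checkpoint keys mirror the ones written above:
import os

import torch


def save_checkpoint(path, model, optimizer, epoch, arch):
    torch.save({
        'epoch': epoch + 1,  # the epoch to resume from
        'arch': arch,
        'state_dict': model.state_dict(),
        'optimizer': optimizer.state_dict(),
    }, path)


def load_checkpoint(path, model, optimizer):
    """Restore state in place; return the epoch to resume from (0 if none)."""
    if not os.path.isfile(path):
        return 0
    checkpoint = torch.load(path, map_location=lambda storage, loc: storage)
    model.load_state_dict(checkpoint['state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer'])
    return checkpoint['epoch']
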
Example No. 16
def main():
    if not os.path.exists(args.outdir):
        os.mkdir(args.outdir)

    device = torch.device("cuda")
    torch.cuda.set_device(args.gpu)

    logfilename = os.path.join(args.outdir, args.logname)

    init_logfile(logfilename, "epoch\ttime\tlr\ttrain loss\ttrain acc\ttest loss\ttest acc")
    log(logfilename, "Hyperparameter List")
    log(logfilename, "Epochs: {:}".format(args.epochs))
    log(logfilename, "Learning Rate: {:}".format(args.lr))
    log(logfilename, "Alpha: {:}".format(args.alpha))
    log(logfilename, "Keep ratio: {:}".format(args.keep_ratio))

    test_acc_list = []
    for _ in range(args.round):
        traindir = os.path.join(args.data_train, 'train')
        valdir = os.path.join(args.data_val, 'val')
        normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                         std=[0.229, 0.224, 0.225])

        train_dataset = datasets.ImageFolder(
            traindir,
            transforms.Compose([
                transforms.RandomResizedCrop(224),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                normalize,
            ]))

        train_sampler = None

        train_loader = torch.utils.data.DataLoader(
            train_dataset, batch_size=args.batch, shuffle=(train_sampler is None),
            num_workers=args.workers, pin_memory=True, sampler=train_sampler)

        test_loader = torch.utils.data.DataLoader(
            datasets.ImageFolder(valdir, transforms.Compose([
                transforms.Resize(256),
                transforms.CenterCrop(224),
                transforms.ToTensor(),
                normalize,
            ])),
            batch_size=args.batch, shuffle=False,
            num_workers=args.workers, pin_memory=True)


        base_classifier = models.__dict__[args.arch](pretrained=True).cuda()
        print("Loaded the base_classifier")

        original_acc = model_inference(base_classifier, test_loader,
                                       device, display=True)
        log(logfilename, "Original Model Test Accuracy: {:.5}".format(original_acc))
        print("Original Model Test Accuracy, ", original_acc)

        # Create a fresh copy of the network so the original is not affected.
        net = copy.deepcopy(base_classifier)
        net = net.to(device)


        # Generating the mask 'm'
        for layer in net.modules():
            if isinstance(layer, nn.Linear) or isinstance(layer, nn.Conv2d):
                layer.weight_mask = nn.Parameter(torch.ones_like(layer.weight))

                layer.weight.requires_grad = True
                layer.weight_mask.requires_grad = True

            # Monkey-patch layer.forward so the layer computes with the
            # elementwise product of weight 'w' and mask 'm'.
            if isinstance(layer, nn.Linear):
                layer.forward = types.MethodType(mask_forward_linear, layer)

            if isinstance(layer, nn.Conv2d):
                layer.forward = types.MethodType(mask_forward_conv2d, layer)


        criterion = nn.CrossEntropyLoss().to(device)    # A log-softmax layer was added to all architectures.
        optimizer = SGD(net.parameters(), lr=args.lr, momentum=args.momentum,
                        weight_decay=0) # weight_decay = 0 for training the mask.
 
        sparsity, total = 0, 0
        breakFlag = False
        net.train()
        # Training the mask with the training set.
        for epoch in range(100000):  # effectively "until the sparsity target is hit"
            print("Current epoch: ", epoch)
            print("Sparsity: {:}".format(sparsity))
            log(logfilename, "Current epoch: {}".format(epoch))
            log(logfilename, "Sparsity: {:}".format(sparsity))

            for i, (inputs, targets) in enumerate(train_loader):
                inputs = inputs.cuda()
                targets = targets.cuda()

                reg_loss = 0
                for layer in net.modules():
                    if isinstance(layer, nn.Conv2d) or isinstance(layer, nn.Linear):
                        reg_loss += torch.norm(layer.weight_mask, p=1)
                outputs = net(inputs)
                loss = criterion(outputs, targets) + args.alpha * reg_loss
                
                # Compute gradients and take an SGD step.
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
            

                sparsity, total = 0, 0
                for layer in net.modules():
                    if isinstance(layer, nn.Linear) or isinstance(layer, nn.Conv2d):
                        boolean_list = layer.weight_mask.data > 1e-3
                        sparsity += (boolean_list == 1).sum()
                        total += layer.weight.numel()
                
                if i % 50 == 0:
                    print("Current Epochs: {}, Current i: {}, Current Sparsity: {}".format(epoch, i, sparsity))
                
                if sparsity <= total*args.keep_ratio:
                    print("Current epochs breaking loop at {:}".format(epoch))
                    log(logfilename, "Current epochs breaking loop at {:}".format(epoch))
                    breakFlag = True
                    break
            if breakFlag:
                break
            

        # Compute the threshold that satisfies the keep_ratio.
        c_abs = []
        for layer in net.modules():
            if isinstance(layer, nn.Linear) or isinstance(layer, nn.Conv2d):
                c_abs.append(torch.abs(layer.weight_mask))
        
        all_scores = torch.cat([torch.flatten(x) for x in c_abs])
        num_params_to_keep = int(len(all_scores) * args.keep_ratio)
        threshold, _ = torch.topk(all_scores, num_params_to_keep, sorted=True)
        threshold = threshold[-1]
        
        print("Threshold found: ", threshold)
        
        keep_masks = []
        for c in c_abs:
            keep_masks.append((c >= threshold).float())
        print("Number of ones.", torch.sum(torch.cat([torch.flatten(x == 1) for x in keep_masks])))
        
        # Update the weights with the elementwise product of weight 'w' and mask 'm'.
        for layer in net.modules():
            if isinstance(layer, nn.Linear) or isinstance(layer, nn.Conv2d):
                # We update the weight by elementwise multiplication between
                # weight 'w' and mask 'm'.
                layer.weight.data = layer.weight.data * layer.weight_mask.data
                layer.zeros = nn.Parameter(torch.zeros_like(layer.weight))    # Dummy parameter.
                layer.ones = nn.Parameter(torch.ones_like(layer.weight))      # Dummy parameter.
                # Re-binarize weight_mask so its elements are 0 and 1 again.
                layer.weight_mask.data = torch.where(
                    torch.abs(layer.weight_mask) <= threshold,
                    layer.zeros, layer.ones)

                # Temporarily disabling the backprop for both 'w' and 'm'.
                layer.weight.requires_grad = False
                layer.weight_mask.requires_grad = False

            if isinstance(layer, nn.Linear):
                layer.forward = types.MethodType(mask_forward_linear, layer)

            if isinstance(layer, nn.Conv2d):
                layer.forward = types.MethodType(mask_forward_conv2d, layer)

        # Transfer the learned weights from "net" back to "base_classifier".
        for (layer1, layer2) in zip(base_classifier.modules(), net.modules()):
            if isinstance(layer1, (nn.Linear, nn.Conv2d)) and isinstance(layer2, (nn.Linear, nn.Conv2d)):
                layer1.weight.data = layer2.weight.data
                if layer1.bias is not None:
                    layer1.bias.data = layer2.bias.data
                    layer1.bias.requires_grad = True

                layer1.weight.requires_grad = True
                
        

        torch.save(base_classifier.state_dict(), os.path.join(args.outdir, args.save_model))
        base_classifier_acc = model_inference(base_classifier, test_loader, device, display=True)
        log(logfilename, "Weight Update Test Accuracy: {:.5}".format(base_classifier_acc))
        print("Saved the finetune model.")
        # Strip autograd history from the masks before saving.
        keep_masks = [mask.data for mask in keep_masks]
            
        torch.save(keep_masks, os.path.join(args.outdir, args.keep_mask))
        print("Saved the masking function.")
        log(logfilename, "Finished finding the mask. (FINETUNE)")
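
# The mask-finding examples in this listing (FINETUNE above, REWIND below)
# monkey-patch layer.forward with the external helpers mask_forward_linear
# and mask_forward_conv2d, which are not shown here. A minimal sketch
# consistent with how they are used (forward with the elementwise product of
# weight 'w' and mask 'm'); treat it as an assumption, not the original
# implementation:
import torch.nn.functional as F


def mask_forward_linear(self, x):
    # nn.Linear forward pass using the masked weight w * m.
    return F.linear(x, self.weight * self.weight_mask, self.bias)


def mask_forward_conv2d(self, x):
    # nn.Conv2d forward pass using the masked weight w * m.
    return F.conv2d(x, self.weight * self.weight_mask, self.bias,
                    self.stride, self.padding, self.dilation, self.groups)
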
Example No. 17
def main():

    if not os.path.exists(args.outdir):
        os.mkdir(args.outdir)

    device = torch.device("cuda")
    torch.cuda.set_device(args.gpu)

    logfilename = os.path.join(args.outdir, args.logname)

    init_logfile(logfilename,
                 "epoch\ttime\tlr\ttrain loss\ttrain acc\ttestloss\ttest acc")
    log(logfilename, "Hyperparameter List")
    log(logfilename, "Epochs: {:}".format(args.epochs))
    log(logfilename, "Learning Rate: {:}".format(args.lr))
    log(logfilename, "Alpha: {:}".format(args.alpha))
    log(logfilename, "Keep ratio: {:}".format(args.keep_ratio))
    log(logfilename, "Warmup Epochs: {:}".format(args.epochs_warmup))

    test_acc_list = []
    for _ in range(args.round):
        traindir = os.path.join(args.data_train, 'train')
        valdir = os.path.join(args.data_val, 'val')
        normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                         std=[0.229, 0.224, 0.225])

        train_dataset = datasets.ImageFolder(
            traindir,
            transforms.Compose([
                transforms.RandomResizedCrop(224),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                normalize,
            ]))

        train_sampler = None

        train_loader = torch.utils.data.DataLoader(
            train_dataset,
            batch_size=args.batch,
            shuffle=(train_sampler is None),
            num_workers=args.workers,
            pin_memory=True,
            sampler=train_sampler)

        test_loader = torch.utils.data.DataLoader(
            datasets.ImageFolder(
                valdir,
                transforms.Compose([
                    transforms.Resize(256),
                    transforms.CenterCrop(224),
                    transforms.ToTensor(),
                    normalize,
                ])),
            batch_size=args.batch,
            shuffle=False,
            num_workers=args.workers,
            pin_memory=True)

        base_classifier = models.__dict__[args.arch](pretrained=False).cuda()
        print("Loaded the base_classifier")

        criterion = nn.CrossEntropyLoss().to(device)
        optimizer = SGD(base_classifier.parameters(),
                        lr=args.lr,
                        momentum=args.momentum,
                        weight_decay=args.weight_decay)

        # Warmup training for the rewinding.
        for epoch in range(args.epochs_warmup):
            print("Warmup Training Epochs: {:}".format(epoch))
            log(logfilename, "Warmup current epochs: {}".format(epoch))
            train_loss, train_top1, train_top5 = utils.train(train_loader,
                                                             base_classifier,
                                                             criterion,
                                                             optimizer,
                                                             epoch,
                                                             device,
                                                             print_freq=100,
                                                             display=True)

        original_acc = model_inference(base_classifier,
                                       test_loader,
                                       device,
                                       display=True)
        log(logfilename,
            "Warmup Model Test Accuracy: {:.5}".format(original_acc))
        print("Warmup Model Test Accuracy, ", original_acc)

        # Create a fresh copy of the network so the original is not affected.
        # The goal is to find the supermask.

        net = copy.deepcopy(base_classifier)
        net = net.to(device)

        # Generating the mask 'm'
        for layer in net.modules():
            if isinstance(layer, nn.Linear) or isinstance(layer, nn.Conv2d):
                layer.weight_mask = nn.Parameter(torch.ones_like(layer.weight))

                layer.weight.requires_grad = True
                layer.weight_mask.requires_grad = True

            # Monkey-patch layer.forward so the layer computes with the
            # elementwise product of weight 'w' and mask 'm'.
            if isinstance(layer, nn.Linear):
                layer.forward = types.MethodType(mask_forward_linear, layer)

            if isinstance(layer, nn.Conv2d):
                layer.forward = types.MethodType(mask_forward_conv2d, layer)

        criterion = nn.CrossEntropyLoss().to(
            device)  # Criterion for training the mask.
        optimizer = SGD(net.parameters(),
                        lr=args.lr,
                        momentum=args.momentum,
                        weight_decay=0)
        # weight_decay = 0 for training the mask.

        sparsity, total = 0, 0
        breakFlag = False
        net.train()
        # Training the mask with the training set.
        for epoch in range(100000):  # effectively "until the sparsity target is hit"
            print("Current epoch: ", epoch)
            print("Sparsity: {:}".format(sparsity))
            log(logfilename, "Current epoch: {}".format(epoch))
            log(logfilename, "Sparsity: {:}".format(sparsity))

            for i, (inputs, targets) in enumerate(train_loader):
                inputs = inputs.cuda()
                targets = targets.cuda()

                reg_loss = 0
                for layer in net.modules():
                    if isinstance(layer, nn.Conv2d) or isinstance(
                            layer, nn.Linear):
                        reg_loss += torch.norm(layer.weight_mask, p=1)
                outputs = net(inputs)
                loss = criterion(outputs, targets) + args.alpha * reg_loss

                # Compute gradients and take an SGD step.
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                sparsity, total = 0, 0
                for layer in net.modules():
                    if isinstance(layer, nn.Linear) or isinstance(
                            layer, nn.Conv2d):
                        boolean_list = layer.weight_mask.data > 1e-3
                        sparsity += (boolean_list == 1).sum()
                        total += layer.weight.numel()

                if i % 50 == 0:
                    print(
                        "Current Epochs: {}, Current i: {}, Current Sparsity: {}"
                        .format(epoch, i, sparsity))

                if sparsity <= total * args.keep_ratio:
                    print("Current epochs breaking loop at {:}".format(epoch))
                    log(logfilename,
                        "Current epochs breaking loop at {:}".format(epoch))
                    breakFlag = True
                    break
            if breakFlag:
                break


#                     print("W 1-norm: ", torch.norm(layer.weight_mask, p=1))

# Just checking the 1-norm of weights in each layer.
# Approximates how sparse the mask is..

# This line allows to calculate the threshold to satisfy the keep_ratio.
        c_abs = []
        for layer in net.modules():
            if isinstance(layer, nn.Linear) or isinstance(layer, nn.Conv2d):
                c_abs.append(torch.abs(layer.weight_mask))

        all_scores = torch.cat([torch.flatten(x) for x in c_abs])
        num_params_to_keep = int(len(all_scores) * args.keep_ratio)
        threshold, _ = torch.topk(all_scores, num_params_to_keep, sorted=True)
        threshold = threshold[-1]

        print("Threshold found: ", threshold)

        keep_masks = []
        for c in c_abs:
            keep_masks.append((c >= threshold).float())
        print(
            "Number of ones: ",
            torch.sum(torch.cat([torch.flatten(x == 1) for x in keep_masks])))

        torch.save(base_classifier.state_dict(),
                   os.path.join(args.outdir, args.save_model))
        base_classifier_acc = model_inference(base_classifier,
                                              test_loader,
                                              device,
                                              display=True)
        log(logfilename,
            "Weight Update Test Accuracy: {:.5}".format(base_classifier_acc))
        print("Saved the rewind model.")
        # Strip autograd history from the masks before saving.
        keep_masks = [mask.data for mask in keep_masks]

        torch.save(keep_masks, os.path.join(args.outdir, args.keep_mask))
        print("Saved the masking function.")
        log(logfilename, "Finished finding the mask. (REWIND)")
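
# Both mask-finding examples end by keeping the largest keep_ratio fraction
# of mask magnitudes. A minimal sketch of that thresholding step as a
# standalone helper; compute_keep_masks is a hypothetical name, not part of
# the original code:
import torch


def compute_keep_masks(mask_tensors, keep_ratio):
    """Binarize masks, keeping the top keep_ratio fraction by magnitude."""
    scores = [t.abs() for t in mask_tensors]
    all_scores = torch.cat([s.flatten() for s in scores])
    num_to_keep = int(len(all_scores) * keep_ratio)
    threshold = torch.topk(all_scores, num_to_keep, sorted=True).values[-1]
    return [(s >= threshold).float() for s in scores]
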
Example No. 18
def main():
    if args.gpu:
        os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu

    if not os.path.exists(args.outdir):
        os.mkdir(args.outdir)
    if args.scale_down == 1 or args.dataset == "imagenet":
        train_dataset = get_dataset(args.dataset, 'train')
        test_dataset = get_dataset(args.dataset, 'test')
    else:
        train_dataset = datasets.CIFAR10(
            "./dataset_cache",
            train=True,
            download=True,
            transform=transforms.Compose([
                transforms.Resize(int(32 / args.scale_down)),
                transforms.RandomCrop(int(32 / args.scale_down),
                                      padding=int(4 / args.scale_down)),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor()
            ]))
        test_dataset = datasets.CIFAR10("./dataset_cache",
                                        train=False,
                                        download=True,
                                        transform=transforms.Compose([
                                            transforms.Resize(
                                                int(32 / args.scale_down)),
                                            transforms.ToTensor()
                                        ]))
    pin_memory = (args.dataset == "imagenet")
    train_loader = DataLoader(train_dataset,
                              shuffle=True,
                              batch_size=args.batch,
                              num_workers=args.workers,
                              pin_memory=pin_memory)
    test_loader = DataLoader(test_dataset,
                             shuffle=False,
                             batch_size=args.batch,
                             num_workers=args.workers,
                             pin_memory=pin_memory)

    model = get_architecture(args.arch, args.dataset)
    #model = torch.nn.DataParallel(model)
    logfilename = os.path.join(args.outdir, 'log.txt')
    init_logfile(logfilename,
                 "epoch\ttime\tlr\ttrain loss\ttrain acc\ttestloss\ttest acc")

    criterion = CrossEntropyLoss()
    optimizer = SGD(model.parameters(),
                    lr=args.lr,
                    momentum=args.momentum,
                    weight_decay=args.weight_decay)
    scheduler = StepLR(optimizer,
                       step_size=args.lr_step_size,
                       gamma=args.gamma)

    for epoch in range(args.epochs):
        before = time.time()
        train_loss, train_acc = train(train_loader, model, criterion,
                                      optimizer, epoch, args.noise_sd)
        test_loss, test_acc = test(test_loader, model, criterion,
                                   args.noise_sd)
        after = time.time()

        # Note: the time column uses "{}" rather than "{:.3}"; applying a
        # precision spec to the timedelta string would truncate it to 3 chars.
        log(
            logfilename, "{}\t{}\t{:.3}\t{:.3}\t{:.3}\t{:.3}\t{:.3}".format(
                epoch, str(datetime.timedelta(seconds=(after - before))),
                scheduler.get_lr()[0], train_loss, train_acc, test_loss,
                test_acc))
        scheduler.step(epoch)

        torch.save(
            {
                'epoch': epoch + 1,
                'arch': args.arch,
                'state_dict': model.state_dict(),
                'optimizer': optimizer.state_dict(),
            }, os.path.join(args.outdir, 'checkpoint.pth.tar'))
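
# Every training script in this listing threads args.noise_sd into its
# train()/test() helpers, which are not shown. In Gaussian-augmentation
# (randomized-smoothing-style) training they typically add N(0, noise_sd^2)
# noise to each batch; a minimal, hedged sketch of such a step, with
# noisy_train_step as a hypothetical name:
import torch


def noisy_train_step(model, criterion, optimizer, inputs, targets, noise_sd):
    # Augment the batch with isotropic Gaussian noise of std noise_sd.
    inputs = inputs + torch.randn_like(inputs) * noise_sd
    outputs = model(inputs)
    loss = criterion(outputs, targets)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()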