Exemplo n.º 1
0
            break
        model_name = os.path.join(args.load, args.dataset + '_' + args.model + '_' + args.model +
                                  '_baseline_epoch_' + str(i) + '.pt')
        if os.path.isfile(model_name):
            net.load_state_dict(torch.load(model_name))
            print('Model restored! Epoch:', i)
            start_epoch = i + 1
            break
    if start_epoch == 0:
        assert False, "could not resume"

if args.ngpu > 0:
    # With several GPUs requested, replicate the model over devices
    # 0..ngpu-1 first; in every GPU case the (possibly wrapped) model is then
    # moved to CUDA and the CUDA RNG is seeded for reproducibility.
    if args.ngpu > 1:
        net = torch.nn.DataParallel(net, device_ids=[*range(args.ngpu)])
    net.cuda()
    torch.cuda.manual_seed(1)

# Let cuDNN benchmark and pick the fastest convolution algorithms.
cudnn.benchmark = True

# Switch the network to inference mode before evaluation.
net.eval()


def concat(x):
    """Concatenate a sequence of NumPy arrays along axis 0."""
    return np.concatenate(x, axis=0)


def to_np(x):
    """Copy a tensor's data to the CPU and return it as a NumPy array."""
    return x.data.to('cpu').numpy()

def evaluate(loader):
    confidence = []
    correct = []

    num_correct = 0
    with torch.no_grad():
Exemplo n.º 2
0
def main():
    """Train a WideResNet on CIFAR-10 or CIFAR-100, checkpointing each epoch.

    Configuration is read from the module-level ``args`` namespace and the
    ``state`` dict (``learning_rate``, ``momentum``, ``decay``). Each epoch
    sets ``state['epoch']``, calls ``train`` and ``test``, saves a checkpoint
    to ``args.save`` (deleting the previous epoch's one), and prints
    ``state['train_loss']``, ``state['test_loss']`` and
    ``state['test_accuracy']`` -- presumably populated by ``train``/``test``.

    Raises:
        Exception: if ``args.save`` exists but is not a directory.
    """
    # Maps [0, 1] tensors from ToTensor to [-1, 1] per channel.
    normalize = trn.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])

    # Only the training split is augmented (random flip + padded crop).
    train_transform = trn.Compose([
        trn.RandomHorizontalFlip(),
        trn.RandomCrop(32, padding=4),
        trn.ToTensor(),
        normalize,
    ])

    test_transform = trn.Compose([
        trn.ToTensor(),
        normalize,
    ])

    # NOTE(review): CIFAR-100 is loaded from the same 'cifar10-dataset' root
    # as CIFAR-10; confirm that directory really holds both datasets.
    data_root = '/data/sauravkadavath/cifar10-dataset'
    if args.dataset == 'cifar10':
        print("Using CIFAR 10")
        dataset_cls, num_classes = dset.CIFAR10, 10
    else:
        print("Using CIFAR100")
        dataset_cls, num_classes = dset.CIFAR100, 100

    train_data_in = dataset_cls(data_root, train=True,
                                transform=train_transform)
    test_data = dataset_cls(data_root, train=False, transform=test_transform)

    # Settings shared by both loaders; only shuffling differs.
    loader_kwargs = dict(batch_size=args.batch_size,
                         num_workers=args.prefetch,
                         pin_memory=True)
    train_loader_in = torch.utils.data.DataLoader(train_data_in,
                                                  shuffle=True,
                                                  **loader_kwargs)
    test_loader = torch.utils.data.DataLoader(test_data,
                                              shuffle=False,
                                              **loader_kwargs)

    net = WideResNet(args.layers,
                     num_classes,
                     args.widen_factor,
                     dropRate=args.droprate)
    net.cuda()

    optimizer = torch.optim.SGD(net.parameters(),
                                state['learning_rate'],
                                momentum=state['momentum'],
                                weight_decay=state['decay'],
                                nesterov=True)

    def cosine_annealing(step, total_steps, lr_max, lr_min):
        """Cosine interpolation from lr_max (step 0) to lr_min (total_steps)."""
        return lr_min + (lr_max - lr_min) * 0.5 * (
            1 + np.cos(step / total_steps * np.pi))

    # LambdaLR multiplies the base LR by the returned factor, so annealing the
    # factor from 1 to 1e-6 / lr drives the LR down to 1e-6. The schedule spans
    # epochs * len(train_loader_in) scheduler steps -- presumably `train`
    # calls lr_scheduler.step() once per batch; confirm.
    lr_scheduler = torch.optim.lr_scheduler.LambdaLR(
        optimizer,
        lr_lambda=lambda step: cosine_annealing(
            step,
            args.epochs * len(train_loader_in),
            1,  # since lr_lambda computes multiplicative factor
            1e-6 / args.learning_rate))

    def checkpoint_path(ep):
        """Return the checkpoint file path used for epoch ``ep``."""
        return os.path.join(
            args.save,
            '{0}_{1}_layers_{2}_widenfactor_{3}_transform_epoch_{4}.pt'.format(
                args.dataset, args.model, str(args.layers),
                str(args.widen_factor), str(ep)))

    # Make save directory
    if not os.path.exists(args.save):
        os.makedirs(args.save)
    if not os.path.isdir(args.save):
        raise Exception('%s is not a dir' % args.save)

    # Previously nested under the `raise` above and therefore unreachable.
    print('Beginning Training\n')

    # Start a fresh log with a header mirroring the per-epoch printout.
    # (The original called f.write() with no argument, which raised TypeError.)
    # NOTE(review): no per-epoch rows are appended in the visible code.
    with open(os.path.join(args.save, "training_log.csv"), 'w') as f:
        f.write('epoch,time(s),train_loss,test_loss,test_error(%)\n')

    # Main loop
    for epoch in range(0, args.epochs):
        state['epoch'] = epoch

        begin_epoch = time.time()

        train(net, state, train_loader_in, optimizer, lr_scheduler)
        test(net, state, test_loader)

        # Save model
        torch.save(net.state_dict(), checkpoint_path(epoch))

        # Let us not waste space and delete the previous model
        prev_path = checkpoint_path(epoch - 1)
        if os.path.exists(prev_path):
            os.remove(prev_path)

        # Show results
        print(
            'Epoch {0:3d} | Time {1:5d} | Train Loss {2:.4f} | Test Loss {3:.3f} | Test Error {4:.2f}'
            .format((epoch + 1), int(time.time() - begin_epoch),
                    state['train_loss'], state['test_loss'],
                    100 - 100. * state['test_accuracy']))