Example 1
import torch
import torchvision
import torchvision.transforms as transforms
import torch.optim as optim

# Repo-local modules; these import paths are assumptions.
import metrics                 # provides the SGVLB loss used below
from logger import Logger
from models import LeNet5

# The leading keys of this dict were truncated in the original snippet;
# only the `fmt = {` opener is restored so the code parses.
fmt = {'te_acc_zero_mean': '.4f',
       'te_acc_perm_sigma_ens': '.4f',
       'te_acc_zero_mean_ens': '.4f',
       'te_nll_ens100': '.4f',
       'te_nll_stoch': '.4f',
       'te_nll_ens10': '.4f',
       'te_nll_perm_sigma': '.4f',
       'te_nll_zero_mean': '.4f',
       'te_nll_perm_sigma_ens': '.4f',
       'te_nll_zero_mean_ens': '.4f',
       'time': '.3f'}
logger = Logger("lenet5-VDO", fmt=fmt)

net = LeNet5()
net.cuda()
logger.print(net)

trainset = torchvision.datasets.MNIST(root='./data', train=True,
                                      download=True, transform=transforms.ToTensor())
trainloader = torch.utils.data.DataLoader(trainset, batch_size=200,
                                          shuffle=True, num_workers=4, pin_memory=True)

testset = torchvision.datasets.MNIST(root='./data', train=False,
                                     download=True, transform=transforms.ToTensor())
testloader = torch.utils.data.DataLoader(testset, batch_size=200,
                                         shuffle=False, num_workers=4, pin_memory=True)

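# SGVLB = stochastic-gradient variational lower bound: a negative-ELBO loss
# whose KL term is rescaled by the dataset size (60000 = MNIST training set).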
criterion = metrics.SGVLB(net, 60000.).cuda()
optimizer = optim.Adam(net.parameters(), lr=0.001)

epochs = 200
Example 2

# (imports as in Example 1)
fmt = {'tr_loss': '3.1e',
       'tr_acc': '.4f',
       'te_acc_det': '.4f',
       'te_acc_stoch': '.4f',
       'te_acc_ens': '.4f',
       'te_nll_det': '.4f',
       'te_nll_stoch': '.4f',
       'te_nll_ens': '.4f',
       'time': '.3f'}
logger = Logger("lenet5-DO", fmt=fmt)

net = LeNet5()
net.cuda()
logger.print(net)

trainset = torchvision.datasets.MNIST(root='./data', train=True,
                                      download=True, transform=transforms.ToTensor())
trainloader = torch.utils.data.DataLoader(trainset, batch_size=200,
                                          shuffle=True, num_workers=4, pin_memory=True)

testset = torchvision.datasets.MNIST(root='./data', train=False,
                                     download=True, transform=transforms.ToTensor())
testloader = torch.utils.data.DataLoader(testset, batch_size=200,
                                         shuffle=False, num_workers=4, pin_memory=True)

criterion = metrics.SGVLB(net, 60000.).cuda()
optimizer = optim.Adam(net.parameters(), lr=0.001)

epochs = 200
Example 3
    # (snippet truncated above: args and fmt are defined earlier in the file)
    args.device = torch.device(
        'cuda:{}'.format(args.device) if torch.cuda.is_available() else 'cpu')

    model = LeNetVDO(args).to(args.device)

    args.batch_size, args.test_batch_size = 32, 32
    train_loader, test_loader = load_mnist(args)
    args.data_size = len(train_loader.dataset)

    # Add a fmt key for every layer that carries a learned log-alpha; the
    # counter is initialised outside the loop so each layer gets its own key.
    i = 0
    for layer in model.children():
        if hasattr(layer, 'log_alpha'):
            fmt.update({'{}log_alpha'.format(i + 1): '3.3e'})
            i += 1

    logger = Logger('lenet-vdo', fmt=fmt)
    logger.print(args)
    logger.print(model)

    criterion = ClassificationLoss(model, args)
    optimizer = torch.optim.Adam(
        [p for p in model.parameters() if p.requires_grad], lr=args.lr)

    for epoch in range(args.epochs):
        t0 = time()

        model.train()
        model.set_flag('zero_mean', False)
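        # Repo-specific hook: criterion.step() presumably advances the loss's
        # KL-weight (annealing) schedule each epoch; it is not a torch API.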
        criterion.step()

        elbo, cat_mean, kls, accuracy = [], [], [], []
        for data, y_train in train_loader:
Example 4
import numpy as np
from time import time

import torch
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from tqdm import tqdm, trange

# Repo-local modules; these import paths are assumptions.
import metrics
import utils
from logger import Logger
from models import LeNet5
from config import get_args  # assumed location of the CLI-argument helper


def main():
    fmt = {
        'tr_loss': '3.1e',
        'tr_acc': '.4f',
        'te_acc_det': '.4f',
        'te_acc_stoch': '.4f',
        'te_acc_ens': '.4f',
        'te_acc_perm_sigma': '.4f',
        'te_acc_zero_mean': '.4f',
        'te_acc_perm_sigma_ens': '.4f',
        'te_acc_zero_mean_ens': '.4f',
        'te_nll_det': '.4f',
        'te_nll_stoch': '.4f',
        'te_nll_ens': '.4f',
        'te_nll_perm_sigma': '.4f',
        'te_nll_zero_mean': '.4f',
        'te_nll_perm_sigma_ens': '.4f',
        'te_nll_zero_mean_ens': '.4f',
        'time': '.3f'
    }
    fmt = {**fmt, **{'la%d' % i: '.4f' for i in range(4)}}
    args = get_args()
    logger = Logger("lenet5-VDO", fmt=fmt)

    trainset = torchvision.datasets.MNIST(root='./data',
                                          train=True,
                                          download=True,
                                          transform=transforms.ToTensor())
    train_sampler = torch.utils.data.BatchSampler(
        torch.utils.data.RandomSampler(trainset),
        batch_size=args.batch_size,
        drop_last=False)
    trainloader = torch.utils.data.DataLoader(trainset,
                                              batch_sampler=train_sampler,
                                              num_workers=args.workers,
                                              pin_memory=True)

    testset = torchvision.datasets.MNIST(root='./data',
                                         train=False,
                                         download=True,
                                         transform=transforms.ToTensor())
    test_sampler = torch.utils.data.BatchSampler(
        torch.utils.data.SequentialSampler(testset),
        batch_size=args.batch_size,
        drop_last=False)
    testloader = torch.utils.data.DataLoader(testset,
                                             batch_sampler=test_sampler,
                                             num_workers=args.workers,
                                             pin_memory=True)

    net = LeNet5()
    net = net.to(device=args.device, dtype=args.dtype)
    if args.print_model:
        logger.print(net)
    criterion = metrics.SGVLB(net, len(trainset)).to(device=args.device,
                                                     dtype=args.dtype)
    optimizer = optim.Adam(net.parameters(), lr=args.learning_rate)

    epochs = args.epochs
    lr_start = args.learning_rate
    for epoch in trange(epochs):  # loop over the dataset multiple times
        t0 = time()
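        # Linearly decay the learning rate from lr_start down to 0 over the
        # full training run (metrics.lr_linear computes the schedule value).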
        utils.adjust_learning_rate(
            optimizer, metrics.lr_linear(epoch, 0, epochs, lr_start))
        net.train()
        training_loss = 0
        accs = []
        steps = 0
        for i, (inputs, labels) in enumerate(tqdm(trainloader), 0):
            steps += 1
            inputs, labels = inputs.to(
                device=args.device,
                dtype=args.dtype), labels.to(device=args.device)

            optimizer.zero_grad()
            outputs = net(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            accs.append(metrics.logit2acc(
                outputs.data,
                labels))  # probably a bad way to calculate accuracy
            training_loss += loss.item()

        logger.add(epoch, tr_loss=training_loss / steps, tr_acc=np.mean(accs))

        # Deterministic test
        net.eval()
        acc, nll = utils.evaluate(net,
                                  testloader,
                                  device=args.device,
                                  num_ens=1)
        logger.add(epoch, te_nll_det=nll, te_acc_det=acc)

        # Stochastic test
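        # net.train() keeps the multiplicative noise active, so this measures
        # a single stochastic forward pass per input.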
        net.train()
        acc, nll = utils.evaluate(net,
                                  testloader,
                                  device=args.device,
                                  num_ens=1)
        logger.add(epoch, te_nll_stoch=nll, te_acc_stoch=acc)

        # Test-time averaging
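        # num_ens=20 averages the predictive distribution over 20 stochastic
        # forward passes per input.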
        net.train()
        acc, nll = utils.evaluate(net,
                                  testloader,
                                  device=args.device,
                                  num_ens=20)
        logger.add(epoch, te_nll_ens=nll, te_acc_ens=acc)

        # Zero-mean
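        # With the 'zero_mean' flag set, dense1 samples its weights with the
        # learned means zeroed out, so predictions rely on the noise alone.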
        net.train()
        net.dense1.set_flag('zero_mean', True)
        acc, nll = utils.evaluate(net,
                                  testloader,
                                  device=args.device,
                                  num_ens=1)
        net.dense1.set_flag('zero_mean', False)
        logger.add(epoch, te_nll_zero_mean=nll, te_acc_zero_mean=acc)

        # Permuted sigmas
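        # 'permute_sigma' shuffles the learned per-weight sigmas inside dense1
        # before sampling, probing whether the noise structure matters.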
        net.train()
        net.dense1.set_flag('permute_sigma', True)
        acc, nll = utils.evaluate(net,
                                  testloader,
                                  device=args.device,
                                  num_ens=1)
        net.dense1.set_flag('permute_sigma', False)
        logger.add(epoch, te_nll_perm_sigma=nll, te_acc_perm_sigma=acc)

        # Zero-mean test-time averaging
        net.train()
        net.dense1.set_flag('zero_mean', True)
        acc, nll = utils.evaluate(net,
                                  testloader,
                                  device=args.device,
                                  num_ens=20)
        net.dense1.set_flag('zero_mean', False)
        logger.add(epoch, te_nll_zero_mean_ens=nll, te_acc_zero_mean_ens=acc)

        # Permuted sigmas test-time averaging
        net.train()
        net.dense1.set_flag('permute_sigma', True)
        acc, nll = utils.evaluate(net,
                                  testloader,
                                  device=args.device,
                                  num_ens=20)
        net.dense1.set_flag('permute_sigma', False)
        logger.add(epoch, te_nll_perm_sigma_ens=nll, te_acc_perm_sigma_ens=acc)

        logger.add(epoch, time=time() - t0)
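        # Mean log-alpha per layer; in variational dropout alpha = sigma^2/mu^2,
        # so a large log-alpha means a weight is effectively dropped.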
        las = [
            np.mean(net.conv1.log_alpha.data.cpu().numpy()),
            np.mean(net.conv2.log_alpha.data.cpu().numpy()),
            np.mean(net.dense1.log_alpha.data.cpu().numpy()),
            np.mean(net.dense2.log_alpha.data.cpu().numpy())
        ]

        logger.add(epoch, **{'la%d' % i: las[i] for i in range(4)})
        logger.iter_info()
        logger.save(silent=True)
        torch.save(net.state_dict(), logger.checkpoint)

    logger.save()