def main():
    """Train an MLP on MNIST with SGLD and dump the teacher's predictive samples.

    All artifacts are written under ``--model-path``, which is string-
    concatenated with the file names (so it needs a trailing slash):

    1. Train ``mnist_mlp`` with the ``SGLD`` optimizer for ``--epochs`` epochs,
       overwriting ``sgld-mnist.pt`` after every epoch.
    2. Keep taking SGLD steps on the shuffled training set, deep-copying the
       model's ``state_dict`` after every step until ``--S`` snapshots exist,
       and pickle them to ``sgld_samples.pkl``.
    3. For every snapshot, compute softmax outputs on the MNIST train and test
       sets, Omniglot and SEMEION, and save each result as a CPU tensor of
       shape ``(num_examples, num_snapshots, num_outputs)``.
    """
    import copy
    import pickle as pkl

    # Training settings
    parser = argparse.ArgumentParser(
        description='run approximation to LeNet on Mnist')
    parser.add_argument('--batch-size',
                        type=int,
                        default=256,
                        metavar='N',
                        help='input batch size for training (default: %(default)s)')
    parser.add_argument('--test-batch-size',
                        type=int,
                        default=100,
                        metavar='N',
                        help='input batch size for testing (default: %(default)s)')
    parser.add_argument('--epochs',
                        type=int,
                        default=10,
                        metavar='N',
                        help='number of epochs to train (default: %(default)s)')
    parser.add_argument('--lr',
                        type=float,
                        default=0.001,
                        metavar='LR',
                        help='learning rate (default: %(default)s)')
    parser.add_argument('--momentum',
                        type=float,
                        default=0.5,
                        metavar='M',
                        help='SGD momentum (default: %(default)s)')
    parser.add_argument('--no-cuda',
                        action='store_true',
                        default=False,
                        help='disables CUDA training')
    parser.add_argument('--seed',
                        type=int,
                        default=1,
                        metavar='S',
                        help='random seed (default: %(default)s)')
    parser.add_argument(
        '--log-interval',
        type=int,
        default=10,
        metavar='N',
        help='how many batches to wait before logging training status')
    parser.add_argument('--dropout-rate',
                        type=float,
                        default=0.5,
                        metavar='p_drop',
                        help='dropout rate')
    parser.add_argument(
        '--S',
        type=int,
        default=500,
        metavar='N',
        help='number of posterior samples from the Bayesian model')
    parser.add_argument(
        '--model-path',
        type=str,
        default='../saved_models/mnist_sgld/',
        metavar='N',
        help='directory (with trailing slash) where models and samples are saved')

    args = parser.parse_args()
    if args.S < 1:
        # Downstream torch.stack needs at least one snapshot.
        parser.error('--S must be at least 1')
    use_cuda = not args.no_cuda and torch.cuda.is_available()

    torch.manual_seed(args.seed)

    device = torch.device("cuda" if use_cuda else "cpu")

    kwargs = {'num_workers': 8, 'pin_memory': True} if use_cuda else {}

    def _mnist_loader(train, batch_size, shuffle):
        """Return a DataLoader over MNIST normalized with mean 0.5 / std 0.5."""
        dataset = datasets.MNIST(
            '../data',
            train=train,
            download=True,
            transform=transforms.Compose(
                [transforms.ToTensor(),
                 transforms.Normalize((0.5, ), (0.5, ))]))
        return torch.utils.data.DataLoader(dataset,
                                           batch_size=batch_size,
                                           shuffle=shuffle,
                                           **kwargs)

    def _ood_loader(dataset_cls):
        """Return an unshuffled DataLoader over an OOD dataset resized to 28x28."""
        dataset = dataset_cls(
            '../../data',
            download=True,
            transform=transforms.Compose([
                transforms.Resize((28, 28)),
                transforms.ToTensor(),
                transforms.Normalize((0.5, ), (0.5, )),
            ]))
        return torch.utils.data.DataLoader(dataset,
                                           batch_size=args.batch_size,
                                           shuffle=False,
                                           **kwargs)

    def _ensemble_outputs(net, snapshots, loader):
        """Softmax outputs of every parameter snapshot on ``loader``.

        Loads each snapshot into ``net`` in turn (the last one stays loaded)
        and returns a CPU tensor of shape
        (num_examples, num_snapshots, num_outputs).
        """
        per_snapshot = []
        with torch.no_grad():
            for state in snapshots:
                net.load_state_dict(state)
                outputs = []
                for data, _ in loader:
                    data = data.to(device)
                    # Flatten each image before feeding the model.
                    data = data.view(data.shape[0], -1)
                    outputs.append(F.softmax(net(data), dim=1))
                per_snapshot.append(torch.cat(outputs).cpu())
        return torch.stack(per_snapshot).permute(1, 0, 2)

    train_loader = _mnist_loader(True, args.batch_size, shuffle=True)
    test_loader = _mnist_loader(False, args.test_batch_size, shuffle=False)

    model = mnist_mlp(dropout=False).to(device)
    optimizer = SGLD(model.parameters(), lr=args.lr)

    for epoch in range(1, args.epochs + 1):
        train_bayesian(args, model, device, train_loader, optimizer, epoch)
        print("epoch: {}".format(epoch))
        test(args, model, device, test_loader)

        # Checkpoint after every epoch (overwrites the previous one).
        torch.save(model.state_dict(), args.model_path + 'sgld-mnist.pt')

    # Collect posterior samples: one state_dict snapshot per SGLD step,
    # cycling over the training set as many times as needed to reach --S.
    param_samples = []
    while len(param_samples) < args.S:
        for data, target in train_loader:
            data, target = data.to(device), target.to(device)
            optimizer.zero_grad()
            data = data.view(data.shape[0], -1)
            output = model(data)
            loss = F.nll_loss(F.log_softmax(output, dim=1), target)
            loss.backward()
            optimizer.step()
            param_samples.append(copy.deepcopy(model.state_dict()))
            if len(param_samples) >= args.S:
                break
    print('collected {} SGLD samples'.format(len(param_samples)))
    with open(args.model_path + "sgld_samples.pkl", "wb") as f:
        pkl.dump(param_samples, f)

    test(args, model, device, test_loader)

    # Unshuffled loaders so the saved outputs line up with dataset order.
    train_loader = _mnist_loader(True, args.batch_size, shuffle=False)
    test_loader = _mnist_loader(False, args.test_batch_size, shuffle=False)

    # Teacher outputs over all --S snapshots (the original hard-coded 500,
    # which crashed for S < 500 and ignored samples for S > 500).
    torch.save(_ensemble_outputs(model, param_samples, train_loader),
               args.model_path + 'mnist-sgld-train-samples.pt')
    torch.save(_ensemble_outputs(model, param_samples, test_loader),
               args.model_path + 'mnist-sgld-test-samples.pt')
    torch.save(
        _ensemble_outputs(model, param_samples,
                          _ood_loader(datasets.Omniglot)),
        args.model_path + 'mnist-sgld-omniglot-samples.pt')
    torch.save(
        _ensemble_outputs(model, param_samples,
                          _ood_loader(datasets.SEMEION)),
        args.model_path + 'mnist-sgld-SEMEION-samples.pt')
# Example #2
# 0
def main():
    """Fit (or load) the amortized approximation of the SGLD teacher and evaluate it.

    Loads the teacher ``mnist_mlp`` from ``<model-path>sgld-mnist.pt`` and
    reports its test accuracy.

    With ``--from-approx-model 0``, trains ``mnist_mlp_h`` (f) and
    ``mnist_mlp_g`` (g) via ``train_approx`` against the stored teacher
    train outputs. It keeps the pair saved at the epoch where f's test
    accuracy was best. Otherwise it loads both from disk.

    Teacher outputs on the MNIST test set and the chosen OOD set are either
    loaded from disk (``--test-ood-from-disk 1``) or regenerated from the
    pickled SGLD snapshots. Everything is then passed to ``eval_approx``.
    """
    import pickle as pkl

    # Training settings
    parser = argparse.ArgumentParser(description='Amortized approximation on MNIST')
    parser.add_argument('--batch-size', type=int, default=256, metavar='N',
                        help='input batch size for training (default: %(default)s)')
    parser.add_argument('--test-batch-size', type=int, default=64, metavar='N',
                        help='input batch size for testing (default: %(default)s)')
    parser.add_argument('--approx-epochs', type=int, default=200, metavar='N',
                        help='number of epochs to approx (default: %(default)s)')
    parser.add_argument('--lr', type=float, default=1e-2, metavar='LR',
                        help='learning rate (default: %(default)s)')
    parser.add_argument('--momentum', type=float, default=0.5, metavar='M',
                        help='SGD momentum (default: %(default)s)')
    parser.add_argument('--no-cuda', action='store_true', default=False,
                        help='disables CUDA training')
    parser.add_argument('--seed', type=int, default=1, metavar='S',
                        help='random seed (default: %(default)s)')
    parser.add_argument('--log-interval', type=int, default=10, metavar='N',
                        help='how many batches to wait before logging training status')
    parser.add_argument('--S', type=int, default=100, metavar='N',
                        help='number of posterior samples from the Bayesian model')
    parser.add_argument('--model-path', type=str, default='../saved_models/mnist_sgld/', metavar='N',
                        help='directory (with trailing slash) holding teacher and approximation files')
    parser.add_argument('--from-approx-model', type=int, default=1, metavar='N',
                        help='1: load the saved approximation; 0: train it')
    parser.add_argument('--test-ood-from-disk', type=int, default=1,
                        help='1: load teacher test/OOD outputs from disk; otherwise regenerate them')
    # choices= turns an unknown name into a clear CLI error instead of a
    # NameError on an unbound ood_data further down.
    parser.add_argument('--ood-name', type=str, default='omniglot',
                        choices=['omniglot', 'SEMEION'],
                        help='name of the used ood dataset')

    args = parser.parse_args()
    use_cuda = not args.no_cuda and torch.cuda.is_available()

    torch.manual_seed(args.seed)

    device = torch.device("cuda" if use_cuda else "cpu")

    kwargs = {'num_workers': 8, 'pin_memory': True} if use_cuda else {}

    tr_data = MNIST('../data', train=True, transform=transforms.Compose([
        transforms.Resize((28, 28)),
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,))]), download=True)

    te_data = MNIST('../data', train=False, transform=transforms.Compose([
        transforms.Resize((28, 28)),
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,))]), download=True)

    # Unshuffled so batches line up with the stored per-example teacher outputs.
    train_loader = torch.utils.data.DataLoader(
        tr_data,
        batch_size=args.batch_size, shuffle=False, **kwargs)

    # NOTE(review): uses --batch-size rather than --test-batch-size; left as-is
    # because eval_approx's batching assumptions are not visible here.
    test_loader = torch.utils.data.DataLoader(
        te_data,
        batch_size=args.batch_size, shuffle=False, **kwargs)

    ood_datasets = {'omniglot': datasets.Omniglot, 'SEMEION': datasets.SEMEION}
    ood_data = ood_datasets[args.ood_name]('../../data', download=True, transform=transforms.Compose([
        transforms.Resize((28, 28)),
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,)),
    ]))

    ood_loader = torch.utils.data.DataLoader(
        ood_data,
        batch_size=args.batch_size, shuffle=False, **kwargs)

    model = mnist_mlp(dropout=False).to(device)

    model.load_state_dict(torch.load(args.model_path + 'sgld-mnist.pt'))

    test(args, model, device, test_loader)

    # --------------- training approx ---------
    fmodel = mnist_mlp_h().to(device)
    gmodel = mnist_mlp_g().to(device)

    if args.from_approx_model == 0:
        output_samples = torch.load(args.model_path + 'mnist-sgld-train-samples.pt')
        g_optimizer = optim.SGD(gmodel.parameters(), lr=args.lr)
        f_optimizer = optim.SGD(fmodel.parameters(), lr=args.lr)
        best_acc = 0
        for epoch in range(1, args.approx_epochs + 1):
            train_approx(args, fmodel, gmodel, device, train_loader, f_optimizer, g_optimizer, output_samples, epoch)
            acc = test(args, fmodel, device, test_loader)
            # Keep the (f, g) pair from the epoch with the best f test accuracy.
            if acc > best_acc:
                torch.save(fmodel.state_dict(), args.model_path + 'sgld-mnist-mmd-mean.pt')
                torch.save(gmodel.state_dict(), args.model_path + 'sgld-mnist-mmd-conc.pt')
                best_acc = acc
    else:
        fmodel.load_state_dict(torch.load(args.model_path + 'sgld-mnist-mmd-mean.pt'))
        gmodel.load_state_dict(torch.load(args.model_path + 'sgld-mnist-mmd-conc.pt'))

    def _teacher_outputs(loader, snapshots):
        """Softmax outputs of every SGLD snapshot on ``loader``.

        Loads each snapshot into ``model`` in turn and returns a CPU tensor of
        shape (num_examples, num_snapshots, num_outputs).
        """
        per_snapshot = []
        with torch.no_grad():
            for state in snapshots:
                model.load_state_dict(state)
                outputs = []
                for data, _ in loader:
                    data = data.to(device)
                    data = data.view(data.shape[0], -1)
                    outputs.append(F.softmax(model(data), dim=1))
                per_snapshot.append(torch.cat(outputs).cpu())
        return torch.stack(per_snapshot).permute(1, 0, 2)

    print('generating teacher particles for testing&ood data ...')
    # generate particles for test and ood dataset
    model.train()
    test_path = args.model_path + 'mnist-sgld-test-samples.pt'
    ood_path = args.model_path + 'mnist-sgld-' + args.ood_name + '-samples.pt'
    if args.test_ood_from_disk == 1:
        teacher_test_samples = torch.load(test_path)
        teacher_ood_samples = torch.load(ood_path)
    else:
        # Each particle must come from a different posterior sample. Running
        # the single sgld-mnist.pt model repeatedly (as before) only produced
        # identical copies. The snapshots are those pickled by the SGLD
        # training script; pickle.load runs arbitrary code, so only load
        # trusted files.
        with open(args.model_path + 'sgld_samples.pkl', 'rb') as f:
            param_samples = pkl.load(f)

        teacher_test_samples = _teacher_outputs(test_loader, param_samples)
        torch.save(teacher_test_samples, test_path)

        teacher_ood_samples = _teacher_outputs(ood_loader, param_samples)
        torch.save(teacher_ood_samples, ood_path)

    eval_approx(args, fmodel, gmodel, device, test_loader, ood_loader, teacher_test_samples, teacher_ood_samples)
def main():
    """Train (or load) an MC-dropout MLP teacher on MNIST and dump training particles.

    With ``--from-model 1``, loads ``<model-path>mcdp-mnist.pt``. Otherwise
    trains ``mnist_mlp(dropout=True)`` with momentum SGD. The learning rate
    is halved at epochs 50/100/150/200/250. The checkpoint with the best test
    accuracy is kept and reloaded once training ends.

    Then runs ``--S`` stochastic forward passes over the unshuffled training
    set. It saves their softmax outputs to ``mnist-mcdp-samples.pt`` as a CPU
    tensor of shape ``(num_examples, S, num_outputs)``.
    """
    # Training settings
    parser = argparse.ArgumentParser(
        description='run approximation to LeNet on Mnist')
    parser.add_argument('--batch-size',
                        type=int,
                        default=512,
                        metavar='N',
                        help='input batch size for training (default: %(default)s)')
    parser.add_argument('--test-batch-size',
                        type=int,
                        default=64,
                        metavar='N',
                        help='input batch size for testing (default: %(default)s)')
    parser.add_argument('--epochs',
                        type=int,
                        default=300,
                        metavar='N',
                        help='number of epochs to train (default: %(default)s)')
    parser.add_argument('--lr',
                        type=float,
                        default=0.01,
                        metavar='LR',
                        help='learning rate (default: %(default)s)')
    parser.add_argument('--momentum',
                        type=float,
                        default=0.5,
                        metavar='M',
                        help='SGD momentum (default: %(default)s)')
    parser.add_argument('--no-cuda',
                        action='store_true',
                        default=False,
                        help='disables CUDA training')
    parser.add_argument('--seed',
                        type=int,
                        default=1,
                        metavar='S',
                        help='random seed (default: %(default)s)')
    parser.add_argument(
        '--log-interval',
        type=int,
        default=10,
        metavar='N',
        help='how many batches to wait before logging training status')
    parser.add_argument(
        '--S',
        type=int,
        default=500,
        metavar='N',
        help='number of posterior samples from the Bayesian model')
    parser.add_argument(
        '--model-path',
        type=str,
        default='../saved_models/mnist_mcdp/',
        metavar='N',
        help='directory (with trailing slash) where the model and samples are saved')
    parser.add_argument('--from-model',
                        type=int,
                        default=0,
                        metavar='N',
                        help='if our model is loaded or trained')

    args = parser.parse_args()
    if args.S < 1:
        # torch.stack below needs at least one particle.
        parser.error('--S must be at least 1')
    use_cuda = not args.no_cuda and torch.cuda.is_available()

    torch.manual_seed(args.seed)

    device = torch.device("cuda" if use_cuda else "cpu")

    kwargs = {'num_workers': 8, 'pin_memory': False} if use_cuda else {}

    tr_data = MNIST('../data',
                    train=True,
                    transform=transforms.Compose([
                        transforms.Resize((28, 28)),
                        transforms.ToTensor(),
                        transforms.Normalize((0.5, ), (0.5, ))
                    ]),
                    download=True)

    te_data = MNIST('../data',
                    train=False,
                    transform=transforms.Compose([
                        transforms.Resize((28, 28)),
                        transforms.ToTensor(),
                        transforms.Normalize((0.5, ), (0.5, ))
                    ]),
                    download=True)

    train_loader = torch.utils.data.DataLoader(tr_data,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               **kwargs)

    test_loader = torch.utils.data.DataLoader(te_data,
                                              batch_size=args.batch_size,
                                              shuffle=True,
                                              **kwargs)

    model = mnist_mlp(dropout=True).to(device)
    optimizer = torch.optim.SGD(model.parameters(),
                                lr=args.lr,
                                momentum=args.momentum)

    checkpoint_path = args.model_path + 'mcdp-mnist.pt'

    # --------------- train or load teacher -----------
    if args.from_model == 1:
        print('loading teacher model ...')
        model.load_state_dict(torch.load(checkpoint_path))
    else:
        print('training teacher model ...')
        schedule = [50, 100, 150, 200, 250]
        best = 0
        for epoch in range(1, args.epochs + 1):
            if epoch in schedule:
                for g in optimizer.param_groups:
                    g['lr'] *= 0.5
            train_bayesian(args, model, device, train_loader, optimizer, epoch)
            print("teacher training epoch: {}".format(epoch))
            test_acc = test(args, model, device, test_loader)
            if test_acc > best:
                torch.save(model.state_dict(), checkpoint_path)
                best = test_acc
        # Training continues past the best epoch. Reload the best checkpoint
        # so particles come from the same weights that --from-model 1 would
        # load, not from whatever the last epoch produced.
        if best > 0:
            model.load_state_dict(torch.load(checkpoint_path))

    # Unshuffled so particles line up with dataset order.
    train_loader = torch.utils.data.DataLoader(tr_data,
                                               batch_size=args.batch_size,
                                               shuffle=False,
                                               **kwargs)

    print('generating particles for training data ...')
    # for an easier training of amortized approximation,
    # instead of sampling param. during approx,
    # get particles on simplex and store them first.
    # Training mode keeps dropout stochastic across passes. A freshly loaded
    # model is already in this mode.
    # NOTE(review): assumes mnist_mlp's dropout is gated on self.training and
    # that test() may have switched the model to eval mode — confirm.
    model.train()
    with torch.no_grad():
        # obtain ensemble outputs
        all_samples = []
        for _ in range(args.S):
            samples_a_round = []
            for data, target in train_loader:
                data = data.to(device)
                data = data.view(data.shape[0], -1)
                samples_a_round.append(F.softmax(model(data), dim=1))
            all_samples.append(torch.cat(samples_a_round).cpu())
        # (S, N, C) -> (N, S, C)
        all_samples = torch.stack(all_samples).permute(1, 0, 2)

        torch.save(all_samples, args.model_path + 'mnist-mcdp-samples.pt')