Example #1
0
def main(args):
    """Train and evaluate a RealNVP model on CelebA images resized to 64x64.

    Builds CelebA train/test loaders, constructs the model (wrapped in
    ``torch.nn.DataParallel`` when running on CUDA), optionally resumes from
    ``ckpts/best.pth.tar``, then alternates ``train`` and ``test`` for
    ``args.num_epochs`` epochs.

    Args:
        args: Parsed command-line options. Reads ``gpu_ids``, ``batch_size``,
            ``num_workers``, ``benchmark``, ``resume``, ``weight_decay``,
            ``lr``, ``num_epochs``, ``max_grad_norm`` and ``num_samples``.

    Raises:
        FileNotFoundError: If ``args.resume`` is set but the ``ckpts``
            directory does not exist.
    """
    # Module-level best test loss; overwritten from the checkpoint on resume.
    global best_loss

    device = 'cuda' if torch.cuda.is_available() and len(args.gpu_ids) > 0 else 'cpu'
    start_epoch = 0

    # Note: No normalization applied, since RealNVP expects inputs in (0, 1).
    transform = transforms.Compose([
        transforms.Resize((64, 64)),
        transforms.ToTensor(),
    ])

    # Honour the --num_workers option instead of a hard-coded worker count.
    loader_kwargs = {'num_workers': args.num_workers, 'pin_memory': False}
    trainset = datasets.CelebA('./data', split='train', download=True, transform=transform)
    testset = datasets.CelebA('./data', split='test', transform=transform)
    trainloader = torch.utils.data.DataLoader(
        trainset, batch_size=args.batch_size, shuffle=True, **loader_kwargs)
    # Evaluation needs no shuffling; a fixed order keeps test passes reproducible.
    testloader = torch.utils.data.DataLoader(
        testset, batch_size=args.batch_size, shuffle=False, **loader_kwargs)

    # Model
    print('Building model..')
    net = RealNVP(num_scales=2, in_channels=3, mid_channels=64, num_blocks=8)
    net = net.to(device)
    if device == 'cuda':
        net = torch.nn.DataParallel(net, args.gpu_ids)
        cudnn.benchmark = args.benchmark

    if args.resume:
        # Load checkpoint.
        print('Resuming from checkpoint at ckpts/best.pth.tar...')
        # Explicit check rather than `assert`, which is stripped under `python -O`.
        if not os.path.isdir('ckpts'):
            raise FileNotFoundError('Error: no checkpoint directory found!')
        checkpoint = torch.load('ckpts/best.pth.tar')
        net.load_state_dict(checkpoint['net'])
        best_loss = checkpoint['test_loss']
        start_epoch = checkpoint['epoch']

    loss_fn = RealNVPLoss()
    # Parameters ending in 'weight_g' are grouped separately by
    # util.get_param_groups (presumably to exempt weight-norm gains from decay).
    param_groups = util.get_param_groups(net, args.weight_decay, norm_suffix='weight_g')
    optimizer = optim.Adam(param_groups, lr=args.lr)

    for epoch in range(start_epoch, start_epoch + args.num_epochs):
        train(epoch, net, trainloader, device, optimizer, loss_fn, args.max_grad_norm)
        test(epoch, net, testloader, device, loss_fn, args.num_samples)
Example #2
0
 def test_invertibility_real_nvp(self):
     """RealNVP must invert its forward pass and be unaffected by weight norm.

     A forward pass followed by ``reverse`` has to reconstruct the input and
     produce a non-zero total log-determinant. Applying weight normalisation
     for a forward pass and removing it again before reversing must leave
     both directions numerically unchanged.
     """
     # Dummy batch: 2 samples, 8 channels, 32x32 pixels.
     inputs = torch.rand(2, 8, 32, 32)

     # Layer under test with its default initialisation, in eval mode.
     flow = RealNVP(context_blocks=3,
                    input_channels=8,
                    hidden_channels=32,
                    quantization=65536)
     flow.eval()

     # Forward then inverse: the reconstruction must match the input.
     z, log_det = flow(inputs)
     reconstructed = flow.reverse(z)
     reco_err = torch.mean(torch.abs(reconstructed - inputs)).item()
     self.assertLessEqual(reco_err, 1e-5)
     self.assertNotEqual(log_det.sum().item(), 0)

     # Forward with weight norm applied, reverse after removing it; both must
     # agree with the plain-weights results computed above.
     flow.apply_weight_norm()
     z_wn, _ = flow(inputs)
     flow.remove_weight_norm()
     reconstructed_wn = flow.reverse(z)
     self.assertLessEqual(torch.mean(torch.abs(z_wn - z)).item(), 5e-5)
     self.assertLessEqual(
         torch.mean(torch.abs(reconstructed_wn - reconstructed)).item(), 1e-8)
Example #3
0
def main(args):
    """Train and evaluate RealNVP on the loaders produced by ``get_datasets``.

    All run artefacts live under the directory returned by
    ``util.make_directory(args.name, args.ow)``. When ``args.resume`` is set,
    the model weights, ``best_loss`` and the starting epoch are restored from
    ``<run dir>/ckpts/best.pth.tar``.

    Args:
        args: Parsed command-line options (``name``, ``ow``, ``gpu_ids``,
            ``num_scales``, ``num_blocks``, ``benchmark``, ``resume``,
            ``weight_decay``, ``lr``, ``num_epochs``, ``max_grad_norm``,
            ``num_samples``, plus whatever ``get_datasets`` reads).
    """
    global best_loss

    base_path = util.make_directory(args.name, args.ow)
    use_gpu = torch.cuda.is_available() and len(args.gpu_ids) > 0
    device = 'cuda' if use_gpu else 'cpu'
    start_epoch = 0

    # Note: No normalization applied, since RealNVP expects inputs in (0, 1).
    trainloader, testloader, in_channels = get_datasets(args)

    print('Building model..')
    model = RealNVP(num_scales=args.num_scales, in_channels=in_channels,
                    mid_channels=64, num_blocks=args.num_blocks)
    net = model.to(device)
    if device == 'cuda':
        net = torch.nn.DataParallel(net, args.gpu_ids)
        cudnn.benchmark = args.benchmark

    if args.resume:
        best_path_ckpt = base_path / 'ckpts' / 'best.pth.tar'
        print(f'Resuming from checkpoint at {best_path_ckpt}')
        state = torch.load(best_path_ckpt)
        net.load_state_dict(state['net'])
        best_loss = state['test_loss']
        start_epoch = state['epoch']

    loss_fn = RealNVPLoss()
    groups = util.get_param_groups(net, args.weight_decay, norm_suffix='weight_g')
    optimizer = optim.Adam(groups, lr=args.lr)

    last_epoch = start_epoch + args.num_epochs
    for epoch in range(start_epoch, last_epoch):
        train(epoch, net, trainloader, device, optimizer, loss_fn, args.max_grad_norm)
        test(epoch, net, testloader, device, loss_fn, args.num_samples, in_channels, base_path)
Example #4
0
 def test_real_nvp_save_load(self):
     """A saved-then-reloaded RealNVP must keep its config and its outputs.

     The configuration (``config.json``) and state dict (``<name>.pt``) are
     written to a throw-away temporary directory and loaded back with
     ``load_model``. The original version wrote them next to this test file
     and never removed them.
     """
     import tempfile

     # ----------------------------------------------------------------------
     # Data preparation and Run through dummy network
     # ----------------------------------------------------------------------
     data = torch.rand(2, 3, 32, 32)
     net = RealNVP(context_blocks=2,
                   input_channels=3,
                   hidden_channels=64,
                   quantization=65536)
     output, _ = net(data)
     # ----------------------------------------------------------------------
     # Save and load back model (inside a directory removed on exit)
     # ----------------------------------------------------------------------
     with tempfile.TemporaryDirectory() as model_dir:
         # Save the model configuration
         with open(os.path.join(model_dir, "config.json"), 'w') as file:
             json.dump(net.config, file)
         # Save the model state dictionary
         filename = os.path.join(model_dir, net.config["name"] + ".pt")
         torch.save(net.state_dict(), filename)
         # Load it back; presumably load_model reads both files from model_dir.
         loaded_model, config = load_model(model_dir)
         # Run the reloaded model while the directory still exists, in case
         # loading is lazy.
         new_output, _ = loaded_model(data)
     # ----------------------------------------------------------------------
     # Assert the output is as expected
     # ----------------------------------------------------------------------
     error = torch.mean(torch.abs(new_output - output)).item()
     self.assertLessEqual(error, 5e-5)
     self.assertEqual(config, net.config)
Example #5
0
    def setup(self):
        """Build the MNIST loaders, the RealNVP model and its Adam optimizer.

        Both MNIST splits are converted to tensors and passed through
        ``Dequantize(255)``. If the config defines ``subset_size``, each
        split is truncated to at most that many leading samples. When
        ``config.use_cuda`` is set, the loaders use one worker with pinned
        memory and the model is moved to CUDA. A ``checkpoint_start``
        checkpoint is saved at the end.
        """
        self.elog.print("Config:")
        self.elog.print(self.config)

        # Prepare datasets
        transform = transforms.Compose(
            [transforms.ToTensor(), Dequantize(255)])
        self.dataset_train = datasets.MNIST(root="data/",
                                            download=True,
                                            transform=transform,
                                            train=True)
        self.dataset_test = datasets.MNIST(root="data/",
                                           download=True,
                                           transform=transform,
                                           train=False)

        # Optional subset. Look the option up directly instead of wrapping
        # the whole construction in `except AttributeError`, which could mask
        # unrelated errors. Cap at each split's length so a large subset_size
        # cannot yield out-of-range indices.
        subset_size = getattr(self.config, 'subset_size', None)
        if subset_size is not None:
            self.dataset_train = torch.utils.data.Subset(
                self.dataset_train,
                np.arange(min(subset_size, len(self.dataset_train))))
            self.dataset_test = torch.utils.data.Subset(
                self.dataset_test,
                np.arange(min(subset_size, len(self.dataset_test))))

        data_loader_kwargs = {
            'num_workers': 1,
            'pin_memory': True
        } if self.config.use_cuda else {}
        self.train_data_loader = DataLoader(self.dataset_train,
                                            batch_size=self.config.batch_size,
                                            shuffle=True,
                                            **data_loader_kwargs)
        self.test_data_loader = DataLoader(self.dataset_test,
                                           batch_size=self.config.batch_size,
                                           shuffle=True,
                                           **data_loader_kwargs)

        self.device = torch.device("cuda" if self.config.use_cuda else "cpu")

        self.model = RealNVP((1, 28, 28),
                             n_coupling=self.config.n_coupling,
                             n_filters=self.config.n_filters)
        self.model.to(self.device)
        self.optimizer = optim.Adam(self.model.parameters(),
                                    lr=self.config.learning_rate,
                                    weight_decay=self.config.weight_decay)

        self.save_checkpoint(name="checkpoint_start")
        self.elog.print('Experiment set up.')
Example #6
0
def main(args):
    """Train and evaluate RealNVP on the GPR B12 picture dataset.

    Lists the image files under ``data/GPR_Data/B12/B12_Pictures_LS``
    (relative to the working directory) and builds train/test
    ``GPRDataset``s from that list. It then constructs the model, optionally
    resumes from ``ckpts/best.pth.tar``, and alternates ``train``/``test``
    for ``args.num_epochs`` epochs.

    Raises:
        FileNotFoundError: If ``args.resume`` is set but ``ckpts`` is missing.
    """
    # Module-level best test loss; overwritten from the checkpoint on resume.
    global best_loss

    device = 'cuda' if torch.cuda.is_available() and len(args.gpu_ids) > 0 else 'cpu'
    start_epoch = 0

    # Note: No normalization applied, since RealNVP expects inputs in (0, 1).
    transform_train = transforms.Compose([
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor()
    ])

    transform_test = transforms.Compose([
        transforms.ToTensor()
    ])

    # List files from the same relative directory the datasets read from,
    # rather than a machine-specific absolute path. os.listdir order is
    # arbitrary, so sort it to make the file order reproducible (this matters
    # if GPRDataset splits train/test by position -- not visible here).
    root_dir = 'data/GPR_Data/B12/B12_Pictures_LS'
    img_file = sorted(os.listdir(root_dir))

    trainset = GPRDataset(img_file, root_dir=root_dir, train=True, transform=transform_train)
    trainloader = data.DataLoader(trainset, batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers)

    testset = GPRDataset(img_file, root_dir=root_dir, train=False, transform=transform_test)
    testloader = data.DataLoader(testset, batch_size=args.batch_size, shuffle=False, num_workers=args.num_workers)

    # Model
    print('Building model..')
    net = RealNVP(num_scales=2, in_channels=3, mid_channels=64, num_blocks=8)
    net = net.to(device)
    if device == 'cuda':
        net = torch.nn.DataParallel(net, args.gpu_ids)
        cudnn.benchmark = args.benchmark

    if args.resume:
        # Load checkpoint.
        print('Resuming from checkpoint at ckpts/best.pth.tar...')
        # Explicit check rather than `assert`, which is stripped under `python -O`.
        if not os.path.isdir('ckpts'):
            raise FileNotFoundError('Error: no checkpoint directory found!')
        checkpoint = torch.load('ckpts/best.pth.tar')
        net.load_state_dict(checkpoint['net'])
        best_loss = checkpoint['test_loss']
        start_epoch = checkpoint['epoch']

    loss_fn = RealNVPLoss()
    param_groups = util.get_param_groups(net, args.weight_decay, norm_suffix='weight_g')
    optimizer = optim.Adam(param_groups, lr=args.lr)

    for epoch in range(start_epoch, start_epoch + args.num_epochs):
        train(epoch, net, trainloader, device, optimizer, loss_fn, args.max_grad_norm)
        test(epoch, net, testloader, device, loss_fn, args.num_samples)
Example #7
0
 def test_training_real_nvp(self):
     """RealNVP must stay invertible after a few optimisation steps.

     Twenty Adam steps on random 2x3x32x32 batches using ``NLLFlowLoss``
     are run. A fresh batch in eval mode must then reconstruct through
     ``reverse`` and produce a non-zero total log-determinant.
     """
     flow = RealNVP(context_blocks=2,
                    input_channels=3,
                    hidden_channels=64,
                    quantization=65536)

     # Short training run on noise.
     opt = torch.optim.Adam(flow.parameters(), 0.0001)
     criterion = NLLFlowLoss(sigma=1.0, quantization=65536, bits_per_dim=True)
     n_steps = 20
     for _step in range(n_steps):
         opt.zero_grad()
         batch = torch.rand(2, 3, 32, 32)
         z, log_det = flow(batch)
         objective = criterion(z, log_det)
         objective.backward()
         opt.step()

     # Invertibility check on an unseen batch.
     flow.eval()
     probe = torch.rand(2, 3, 32, 32)
     z, log_det = flow(probe)
     reconstructed = flow.reverse(z)
     reco_err = torch.mean(torch.abs(reconstructed - probe)).item()
     self.assertLessEqual(reco_err, 1e-5)
     self.assertNotEqual(log_det.sum().item(), 0)
Example #8
0
def test_flow():
    """Check that a fresh RealNVP approximately inverts its own forward pass.

    One MNIST test batch of 5 images (``ToTensor`` + ``Dequantize(255)``) is
    pushed through ``model`` and back through ``model.inv_flow``, followed by
    a sigmoid. The function prints whether the reconstruction matches at
    ``atol=1e-6``. It also prints the largest absolute and relative errors
    and the pixel values at which each occurs.

    Returns:
        tuple: ``(d, z, recon, ld)`` -- the input batch, the first and second
        outputs of ``model(d)``, and the reconstruction.
    """
    # ==== Basic functionality tests ====
    transform = transforms.Compose([transforms.ToTensor(), Dequantize(255)])
    dataset = datasets.MNIST(root="data/",
                             download=True,
                             transform=transform,
                             train=False)
    dataloader = DataLoader(dataset, batch_size=5)

    model = RealNVP((1, 28, 28), n_coupling=1, n_filters=100)

    d, _ = next(iter(dataloader))  # labels are not needed

    # test flows
    with torch.no_grad():
        z, ld = model(d)
        recon = torch.sigmoid(model.inv_flow(z))
        abs_error = torch.abs(recon - d)
        # Epsilon keeps the ratio finite where the input pixel is zero.
        rel_error = abs_error / (torch.abs(d) + 1e-8)
    print(recon.allclose(d, atol=1e-6))

    flat_d = d.flatten()
    flat_recon = recon.flatten()
    for title, err in (('Largest absolute error:', abs_error),
                       ('Largest relative error:', rel_error)):
        worst = err.argmax()
        print(title)
        print(err.max())
        print('between {} and reconstruction {}'.format(
            flat_d[worst], flat_recon[worst]))
    # for atol = 1e-6, this is true, but the relative error is still large. is this ok?

    return d, z, recon, ld
Example #9
0
    def setup(self):
        """Create the CelebLQ loaders, the RealNVP model and its Adam optimizer.

        Both splits are read from ``hw2_q2.pkl`` and transformed with their
        own ``Dequantize(3)``. With ``config.use_cuda`` the loaders use one
        worker with pinned memory and the model is moved to CUDA. A
        ``checkpoint_start`` checkpoint is written once everything is built.
        """
        for line in ("Config:", self.config):
            self.elog.print(line)

        # Prepare datasets (train first, then test).
        self.dataset_train, self.dataset_test = (
            CelebLQ('hw2_q2.pkl', train=is_train, transform=Dequantize(3))
            for is_train in (True, False))

        loader_opts = {}
        if self.config.use_cuda:
            loader_opts = {'num_workers': 1, 'pin_memory': True}
        batch = self.config.batch_size
        self.train_data_loader = DataLoader(self.dataset_train,
                                            batch_size=batch,
                                            shuffle=True,
                                            **loader_opts)
        self.test_data_loader = DataLoader(self.dataset_test,
                                           batch_size=batch,
                                           shuffle=True,
                                           **loader_opts)

        device_name = "cuda" if self.config.use_cuda else "cpu"
        self.device = torch.device(device_name)

        self.model = RealNVP((3, 32, 32),
                             n_coupling=self.config.n_coupling,
                             n_filters=self.config.n_filters)
        self.model.to(self.device)
        self.optimizer = optim.Adam(self.model.parameters(),
                                    lr=self.config.learning_rate)

        self.save_checkpoint(name="checkpoint_start")
        self.elog.print('Experiment set up.')
Example #10
0
def main(args):
    """Train RealNVP with early stopping, then dump pixel-wise likelihood maps.

    Steps:
      1. Load MNIST or CIFAR-10 (chosen by the module-level ``DATASET``).
         Randomly draw disjoint ``N_TRAIN``/``N_VAL`` samples from the
         training split and ``N_TEST`` from the test split. For MNIST,
         replace all three sets with ``get_noisy_data`` output.
      2. Build RealNVP and optionally resume from ``ckpts/best.pth.tar``.
      3. Train, validating every epoch. Stop once the global ``best_loss``
         (presumably updated by ``test``) fails to improve for ``PATIENCE``
         consecutive epochs.
      4. Reload ``ckpts/best.pth.tar``. Write per-image likelihood maps and
         the matching original images to ``pixelwise_loglikelihood/``, then
         run a final test pass.

    Raises:
        ValueError: If ``DATASET`` is neither ``'mnist'`` nor ``'cifar10'``.
        FileNotFoundError: If ``args.resume`` is set but ``ckpts`` is missing.
    """
    global best_loss
    global cnt_early_stop

    device = 'cuda' if torch.cuda.is_available() and len(args.gpu_ids) > 0 else 'cpu'
    start_epoch = 0

    # Note: No normalization applied, since RealNVP expects inputs in (0, 1).
    transform_train = transforms.Compose([
        transforms.RandomHorizontalFlip(),
        MyToTensor(),
    ])
    transform_test = transforms.Compose([
        MyToTensor(),
    ])

    # Explicit check rather than `assert`, which is stripped under `python -O`.
    if DATASET not in ('mnist', 'cifar10'):
        raise ValueError(f"Unsupported DATASET {DATASET!r}; expected 'mnist' or 'cifar10'")
    if DATASET == 'mnist':
        dataset_picker = torchvision.datasets.MNIST
    else:
        dataset_picker = torchvision.datasets.CIFAR10

    trainset = dataset_picker(root='data', train=True, download=True, transform=transform_train)
    testset = dataset_picker(root='data', train=False, download=True, transform=transform_test)

    # The validation set is carved out of the training split: copy it, then
    # keep disjoint random index sets in each copy.
    valset = copy.deepcopy(trainset)

    train_val_idx = np.random.choice(np.arange(trainset.data.shape[0]), size=N_TRAIN + N_VAL, replace=False)
    train_idx = train_val_idx[:N_TRAIN]
    val_idx = train_val_idx[N_TRAIN:]
    valset.data = valset.data[val_idx]
    trainset.data = trainset.data[train_idx]

    test_idx = np.random.choice(np.arange(testset.data.shape[0]), size=N_TEST, replace=False)
    testset.data = testset.data[test_idx]

    if DATASET == 'mnist':
        trainset.targets = trainset.targets[train_idx]
        valset.targets = valset.targets[val_idx]
        testset.targets = testset.targets[test_idx]
    else:
        # CIFAR-10 targets need converting to an array before fancy indexing.
        trainset.targets = np.array(trainset.targets)[train_idx]
        valset.targets = np.array(valset.targets)[val_idx]
        testset.targets = np.array(testset.targets)[test_idx]

    if DATASET == 'mnist':
        trainset.data, trainset.targets = get_noisy_data(trainset.data, trainset.targets)
        valset.data, valset.targets = get_noisy_data(valset.data, valset.targets)
        testset.data, testset.targets = get_noisy_data(testset.data, testset.targets)
    else:
        # Build the noisy test copy explicitly; without this line the branch
        # referenced an undefined name. NOTE(review): noisytestset is not yet
        # consumed by any loader below.
        noisytestset = copy.deepcopy(testset)
        noisy_samples = np.random.rand(N_TEST * N_NOISY_SAMPLES_PER_TEST_SAMPLE, 32, 32, 3) - 0.5
        noisytestset.data = np.tile(noisytestset.data, (N_NOISY_SAMPLES_PER_TEST_SAMPLE, 1, 1, 1)) + noisy_samples
        noisytestset.targets = np.tile(noisytestset.targets, [N_NOISY_SAMPLES_PER_TEST_SAMPLE])

    trainloader = data.DataLoader(trainset, batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers)
    testloader = data.DataLoader(testset, batch_size=args.batch_size, shuffle=False, num_workers=args.num_workers)
    valloader = data.DataLoader(valset, batch_size=args.batch_size, shuffle=False, num_workers=args.num_workers)

    # Model
    print('Building model..')
    in_channels = 1 if DATASET == 'mnist' else 3
    net = RealNVP(num_scales=2, in_channels=in_channels, mid_channels=64, num_blocks=8)
    net = net.to(device)
    if device == 'cuda':
        net = torch.nn.DataParallel(net, args.gpu_ids)
        cudnn.benchmark = args.benchmark

    if args.resume:
        # Load checkpoint.
        print('Resuming from checkpoint at ckpts/best.pth.tar...')
        if not os.path.isdir('ckpts'):
            raise FileNotFoundError('Error: no checkpoint directory found!')
        checkpoint = torch.load('ckpts/best.pth.tar')
        net.load_state_dict(checkpoint['net'])
        best_loss = checkpoint['val_loss']
        start_epoch = checkpoint['epoch']

    loss_fn = RealNVPLoss()
    param_groups = util.get_param_groups(net, args.weight_decay, norm_suffix='weight_g')
    optimizer = optim.Adam(param_groups, lr=args.lr)

    # Defined up front so the final test() call works even if no epoch runs.
    epoch = start_epoch
    for epoch in range(start_epoch, start_epoch + args.num_epochs):
        train(epoch, net, trainloader, device, optimizer, loss_fn, args.max_grad_norm)
        prev_best_loss = best_loss
        test(epoch, net, valloader, device, loss_fn, args.num_samples, 'val')
        # Early stopping: count consecutive epochs without improvement.
        if best_loss < prev_best_loss:
            cnt_early_stop = 0
        else:
            cnt_early_stop += 1
        if cnt_early_stop >= PATIENCE:
            break

    checkpoint = torch.load('ckpts/best.pth.tar')
    net.load_state_dict(checkpoint['net'])

    # NOTE(review): the 28x28 reshapes below assume MNIST-sized images; the
    # per-pixel output layout of pixelwise_test for CIFAR-10 needs confirming.
    out_dir = 'pixelwise_loglikelihood'
    pixelwise_ll = -pixelwise_test(net, testloader, device, loss_fn, args.num_samples)
    pixelwise_ll = pixelwise_ll.reshape([-1, 28, 28])
    os.makedirs(out_dir, exist_ok=True)
    for i, ll_map in enumerate(pixelwise_ll):
        tmp = np.exp(ll_map)
        # Scale so the most likely pixel maps to 255.
        tmp = 255 * (tmp / np.max(tmp))
        im = Image.fromarray(tmp)
        im.convert('RGB').save(f'{out_dir}/{i}.png')

    # Save originals from the first test batch (testloader is unshuffled,
    # presumably matching the order of the maps above).
    first_batch = next(iter(testloader), None)
    if first_batch is not None:
        x = first_batch[0]
        # np.float was removed in NumPy 1.24; np.float64 is the same dtype.
        x_np = np.array(x.cpu(), dtype=np.float64)
        # Never index past the end of a batch smaller than num_samples.
        for j in range(min(args.num_samples, len(x_np))):
            im = Image.fromarray(255 * x_np[j].reshape([28, 28]))
            im.convert('RGB').save(f'{out_dir}/{j}-orig.png')
    test(epoch, net, testloader, device, loss_fn, args.num_samples, 'test')
Example #11
0
def main(args):
    """Train and evaluate RealNVP on CIFAR-10, printing environment and model size.

    Prints the selected device and the CUDA/cuDNN environment. Builds the
    CIFAR-10 loaders and the model, and optionally resumes from
    ``ckpts/best.pth.tar``. Prints every parameter's shape and element count
    plus the total, then alternates ``train``/``test`` for
    ``args.num_epochs`` epochs.

    Raises:
        FileNotFoundError: If ``args.resume`` is set but ``ckpts`` is missing.
    """
    # Module-level best test loss; overwritten from the checkpoint on resume.
    global best_loss

    device = 'cuda' if torch.cuda.is_available() and len(
        args.gpu_ids) > 0 else 'cpu'
    print(device)
    print(torch.cuda.device_count())
    # torch.cuda.current_device() raises on hosts without CUDA, so only query
    # the device name when CUDA is actually available.
    if torch.cuda.is_available():
        print(torch.cuda.get_device_name(torch.cuda.current_device()))
    print(torch.backends.cudnn.version())
    print(torch.version.cuda)
    start_epoch = 0

    # Note: No normalization applied, since RealNVP expects inputs in (0, 1).
    transform_train = transforms.Compose(
        [transforms.RandomHorizontalFlip(),
         transforms.ToTensor()])

    transform_test = transforms.Compose([transforms.ToTensor()])
    print(os.getcwd())
    trainset = torchvision.datasets.CIFAR10(root='data',
                                            train=True,
                                            download=True,
                                            transform=transform_train)
    trainloader = data.DataLoader(trainset,
                                  batch_size=args.batch_size,
                                  shuffle=True,
                                  num_workers=args.num_workers)

    testset = torchvision.datasets.CIFAR10(root='data',
                                           train=False,
                                           download=True,
                                           transform=transform_test)
    testloader = data.DataLoader(testset,
                                 batch_size=args.batch_size,
                                 shuffle=False,
                                 num_workers=args.num_workers)

    # Model
    print('Building model..')
    net = RealNVP(num_scales=2, in_channels=3, mid_channels=64, num_blocks=8)
    net = net.to(device)
    if device == 'cuda':
        net = torch.nn.DataParallel(net, args.gpu_ids)
        cudnn.benchmark = args.benchmark

    if args.resume:
        # Load checkpoint.
        print('Resuming from checkpoint at ckpts/best.pth.tar...')
        # Explicit check rather than `assert`, which is stripped under `python -O`.
        if not os.path.isdir('ckpts'):
            raise FileNotFoundError('Error: no checkpoint directory found!')
        checkpoint = torch.load('ckpts/best.pth.tar')
        net.load_state_dict(checkpoint['net'])
        best_loss = checkpoint['test_loss']
        start_epoch = checkpoint['epoch']

    # numel() is an exact int even for 0-dim parameters, where
    # np.prod(()) returns the float 1.0.
    total_num_parameters = 0
    for name, value in net.named_parameters():
        count = value.numel()
        print(name, value.shape)
        total_num_parameters += count
        print(name + ' has parameters: ' + str(count))

    print('Total number of parameters', total_num_parameters)

    loss_fn = RealNVPLoss()
    param_groups = util.get_param_groups(net,
                                         args.weight_decay,
                                         norm_suffix='weight_g')
    optimizer = optim.Adam(param_groups, lr=args.lr)

    for epoch in range(start_epoch, start_epoch + args.num_epochs):
        train(epoch, net, trainloader, device, optimizer, loss_fn,
              args.max_grad_norm)
        test(epoch, net, testloader, device, loss_fn, args.num_samples)
Example #12
0
class MNISTExperiment(PytorchExperiment):
    """Experiment that fits a RealNVP density model to MNIST.

    ``setup`` builds data, model and optimizer from ``self.config``.
    ``train`` runs one epoch and ``validate`` reports the mean test loss.
    Reported losses are divided by ``log(2)`` and by the number of input
    dimensions, i.e. bits per dimension if the model's log-likelihood is in
    nats (assumed).
    """

    def setup(self):
        """Build the MNIST loaders, the RealNVP model and its Adam optimizer.

        Both splits use ``ToTensor`` + ``Dequantize(255)``. If the config
        defines ``subset_size``, each split is truncated to at most that many
        leading samples. With ``config.use_cuda`` the loaders use one worker
        with pinned memory and the model is moved to CUDA.
        """
        self.elog.print("Config:")
        self.elog.print(self.config)

        # Prepare datasets
        transform = transforms.Compose(
            [transforms.ToTensor(), Dequantize(255)])
        self.dataset_train = datasets.MNIST(root="data/",
                                            download=True,
                                            transform=transform,
                                            train=True)
        self.dataset_test = datasets.MNIST(root="data/",
                                           download=True,
                                           transform=transform,
                                           train=False)

        # Optional subset. Look the option up directly instead of wrapping
        # the construction in `except AttributeError` (which could mask
        # unrelated errors). Cap at each split's length to avoid
        # out-of-range indices.
        subset_size = getattr(self.config, 'subset_size', None)
        if subset_size is not None:
            self.dataset_train = torch.utils.data.Subset(
                self.dataset_train,
                np.arange(min(subset_size, len(self.dataset_train))))
            self.dataset_test = torch.utils.data.Subset(
                self.dataset_test,
                np.arange(min(subset_size, len(self.dataset_test))))

        data_loader_kwargs = {
            'num_workers': 1,
            'pin_memory': True
        } if self.config.use_cuda else {}
        self.train_data_loader = DataLoader(self.dataset_train,
                                            batch_size=self.config.batch_size,
                                            shuffle=True,
                                            **data_loader_kwargs)
        self.test_data_loader = DataLoader(self.dataset_test,
                                           batch_size=self.config.batch_size,
                                           shuffle=True,
                                           **data_loader_kwargs)

        self.device = torch.device("cuda" if self.config.use_cuda else "cpu")

        self.model = RealNVP((1, 28, 28),
                             n_coupling=self.config.n_coupling,
                             n_filters=self.config.n_filters)
        self.model.to(self.device)
        self.optimizer = optim.Adam(self.model.parameters(),
                                    lr=self.config.learning_rate,
                                    weight_decay=self.config.weight_decay)

        self.save_checkpoint(name="checkpoint_start")
        self.elog.print('Experiment set up.')

    def train(self, epoch):
        """Run one training epoch, minimising the mean negative log-likelihood.

        Every ``config.log_interval`` batches, the function records the
        normalised batch loss and logs progress text plus the minibatch
        image grid. It also saves a ``checkpoint``.

        Args:
            epoch: Epoch index; logged points are placed at
                ``epoch + batch_idx / n_batches``.
        """
        self.model.train()
        n_batches = len(self.train_data_loader)
        for batch_idx, (x, _) in enumerate(self.train_data_loader):
            self.optimizer.zero_grad()
            if self.config.use_cuda:
                x = x.cuda()
            _, ll = self.model(x)
            loss = -ll.mean()

            loss.backward()
            self.optimizer.step()

            if batch_idx % self.config.log_interval == 0:
                progress = epoch + batch_idx / n_batches
                # Normalise by log(2) and the per-sample dimensionality.
                reported_loss = loss.item() / np.log(2) / np.prod(x.shape[1:])
                # plot train loss
                self.add_result(value=reported_loss,
                                name='Train_Loss',
                                counter=progress,
                                tag='Loss')
                # log train batch loss and progress
                self.clog.show_text(
                    'Train Epoch: {} [{}/{} samples ({:.0f}%)]\t Batch Loss: {:.6f}'
                    .format(epoch, batch_idx * len(x),
                            len(self.train_data_loader.dataset),
                            100. * batch_idx / n_batches,
                            reported_loss),
                    name="log")

                self.clog.show_image_grid(
                    x,
                    name="training minibatch",
                    n_iter=progress,
                    iter_format="{:0.02f}")
                self.save_checkpoint(name="checkpoint", n_iter=batch_idx)

    def validate(self, epoch):
        """Compute the mean per-sample test loss, record it and checkpoint.

        Args:
            epoch: Epoch index; the result is recorded at ``epoch + 1``.

        Raises:
            ValueError: If the test loader yields no samples (previously this
                crashed with a division by zero / undefined ``batch_idx``).
        """
        self.model.eval()

        total_nll = 0.0
        n_data = 0
        n_batches = 0
        # No optimizer.zero_grad() here: no gradients are built under no_grad.
        with torch.no_grad():
            for x, _ in self.test_data_loader:
                if self.config.use_cuda:
                    x = x.cuda()
                _, ll = self.model(x)
                total_nll += -ll.sum().item()
                n_data += len(x)
                n_batches += 1

        if n_data == 0:
            raise ValueError('test_data_loader yielded no samples')

        reported_loss = total_nll / n_data / np.log(2) / np.prod(
            self.model.img_dim)
        # plot the test loss
        self.add_result(value=reported_loss,
                        name='Validation_Loss',
                        counter=epoch + 1,
                        tag='Loss')

        # log validation loss
        self.elog.print(
            '\nValidation set: Average loss: {:.4f}\n'.format(reported_loss))

        # n_iter is the index of the last validation batch, as before.
        self.save_checkpoint(name="checkpoint", n_iter=n_batches - 1)