Example #1
def fit_gaia_lim_sgd(datafile, output_prefix, K, batch_size, epochs, lr,
                     w_reg, k_means_iters, lr_step, lr_gamma,
                     use_cuda):
    data = np.load(datafile)

    if use_cuda:
        device = torch.device('cuda')
    else:
        device = torch.device('cpu')

    train_data = DeconvDataset(
        torch.Tensor(data['X_train']),
        torch.Tensor(data['C_train'])
    )

    val_data = DeconvDataset(
        torch.Tensor(data['X_val']),
        torch.Tensor(data['C_val'])
    )

    gmm = SGDDeconvGMM(
        K,
        7,
        device=device,
        batch_size=batch_size,
        epochs=epochs,
        w=w_reg,
        k_means_iters=k_means_iters,
        lr=lr,
        lr_step=lr_step,
        lr_gamma=lr_gamma
    )
    start_time = time.time()
    gmm.fit(train_data, val_data=val_data, verbose=True)
    end_time = time.time()

    train_score = gmm.score_batch(train_data)
    val_score = gmm.score_batch(val_data)

    print('Training score: {}'.format(train_score))
    print('Val score: {}'.format(val_score))

    results = {
        'start_time': start_time,
        'end_time': end_time,
        'train_score': train_score,
        'val_score': val_score,
        'train_curve': gmm.train_loss_curve,
        'val_curve': gmm.val_loss_curve
    }
    with open(str(output_prefix) + '_results.json', mode='w') as f:
        json.dump(results, f)
    torch.save(gmm.module.state_dict(), str(output_prefix) + '_params.pkl')
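A minimal driver sketch for the function above (an assumption; the flag names and hyperparameter values are illustrative, not taken from the repo):

import argparse

# Hypothetical CLI wiring for fit_gaia_lim_sgd; only the function signature
# comes from the example above.
parser = argparse.ArgumentParser()
parser.add_argument('datafile')
parser.add_argument('output_prefix')
parser.add_argument('--components', type=int, default=64)
parser.add_argument('--use-cuda', action='store_true')
args = parser.parse_args()

fit_gaia_lim_sgd(args.datafile, args.output_prefix, K=args.components,
                 batch_size=512, epochs=50, lr=1e-3, w_reg=1e-6,
                 k_means_iters=50, lr_step=10, lr_gamma=0.1,
                 use_cuda=args.use_cuda)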
Example #2

def check_minibatch_k_means(plot=True):
    x = np.random.randn(200, 2)
    x[:100, :] += np.array([-5, 0])
    x[100:, :] += np.array([5, 0])

    noise_covars = np.zeros((200, 2, 2))

    data = DeconvDataset(torch.Tensor(x.astype(np.float32)),
                         torch.Tensor(noise_covars.astype(np.float32)))

    loader = data_utils.DataLoader(data,
                                   batch_size=20,
                                   num_workers=4,
                                   shuffle=True)

    counts, centroids = minibatch_k_means(loader, 2)

    print(centroids)
    print(counts)

    if plot:
        fig, ax = plt.subplots()
        ax.scatter(x[:, 0], x[:, 1])
        # Overlay the fitted centroids so the check can be done visually.
        ax.scatter(centroids[:, 0], centroids[:, 1], marker='+', c='red')
        plt.show()
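minibatch_k_means itself is not shown on this page; below is a self-contained sketch of the standard minibatch k-means update (Sculley, 2010) that is consistent with the (loader, K) -> (counts, centroids) call above. It is an assumption, not the repo's implementation:

import torch

def minibatch_k_means(loader, k, iters=10):
    # Hypothetical sketch: centroids are seeded from the first batch; each
    # minibatch point pulls its nearest centroid with a per-centroid
    # learning rate of 1/count.
    x0, _ = next(iter(loader))
    centroids = x0[:k].clone()
    counts = torch.zeros(k)
    for _ in range(iters):
        for x, _ in loader:
            d = torch.cdist(x, centroids)   # (batch, k) distances
            assign = d.argmin(dim=1)        # nearest centroid per point
            for j, xi in zip(assign, x):
                counts[j] += 1
                eta = 1.0 / counts[j]
                centroids[j] = (1 - eta) * centroids[j] + eta * xi
    return counts, centroids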
Example #3

def fit_gaia_lim_sgd(datafile, use_cuda=False):
    data = np.load(datafile)

    if use_cuda:
        device = torch.device('cuda')
    else:
        device = torch.device('cpu')

    train_data = DeconvDataset(torch.Tensor(data['X_train']),
                               torch.Tensor(data['L_train']))

    val_data = DeconvDataset(torch.Tensor(data['X_val']),
                             torch.Tensor(data['L_val']))

    svi = SVIFlow(7, 5, device=device, batch_size=512, epochs=40, lr=1e-4)
    svi.fit(train_data, val_data=val_data)

    val_log_prob = svi.score_batch(val_data, log_prob=True)

    print('Val log prob: {}'.format(val_log_prob / len(val_data)))
Example #4

fig, ax = plt.subplots()
ax.scatter(X_noisy[:, 0], X_noisy[:, 1])
ax.scatter(X[:, 0], X[:, 1])
ax.set_xlim(-20, 20)
ax.set_ylim(-20, 20)
plt.show()

X_train = X_noisy[:(2 * N), :]
X_test = X_noisy[(2 * N):, :]

nc_train = S[:(2 * N), :, :]
nc_test = S[(2 * N):, :, :]

train_data = DeconvDataset(
    torch.Tensor(X_train.reshape(-1, D).astype(np.float32)),
    torch.Tensor(nc_train.reshape(-1, D, D).astype(np.float32)))

test_data = DeconvDataset(
    torch.Tensor(X_test.reshape(-1, D).astype(np.float32)),
    torch.Tensor(nc_test.reshape(-1, D, D).astype(np.float32)))

svi = SVIFlow(D, 5, device=device, batch_size=512, epochs=50, lr=1e-4)
svi.fit(train_data, val_data=None)

test_log_prob = svi.score_batch(test_data, log_prob=True)

print('Test log prob: {}'.format(test_log_prob / len(test_data)))

gmm = SGDDeconvGMM(K, D, device=device, batch_size=256, epochs=50, lr=1e-1)
gmm.fit(train_data, val_data=test_data, verbose=True)
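A natural follow-up (assumed here; not part of the original snippet) is to score the fitted GMM with the same score_batch call used in Example #1:

# Assumed follow-up, mirroring the SVI scoring above.
gmm_test_score = gmm.score_batch(test_data)
print('GMM test score: {}'.format(gmm_test_score))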
Example #5
    iw_params.append(path)

svi = SVIFlow(2,
              5,
              device=torch.device('cuda'),
              batch_size=512,
              epochs=100,
              lr=1e-4,
              n_samples=50,
              use_iwae=False,
              context_size=64,
              hidden_features=128)

results = []

test_data = DeconvDataset(x_test.squeeze(), torch.cholesky(S.repeat(N, 1, 1)))

torch.set_default_tensor_type(torch.cuda.FloatTensor)

for p in pretrained_params:
    svi.model.load_state_dict(torch.load(p))
    with torch.no_grad():
        logv = svi.model._prior.log_prob(z_test[0].to(
            torch.device('cuda'))).mean().item()
    elbo = svi.score_batch(test_data, num_samples=100) / N
    logp = svi.score_batch(test_data, num_samples=100, log_prob=True) / N

    results.append({
        'i': 50,
        'model': 'pretrained',
        'elbo': elbo,
        # Assumption: the source snippet is cut off here; logp and logv
        # computed above are the natural remaining entries.
        'log_prob': logp,
        'log_v': logv,
    })
Example #6

def main():
    if args.data == 'boston':
        data = np.load('data_small/boston_no_discrete.npy')
    elif args.data == 'white_wine':
        data = np.load('data_small/white_no_discrete_no_corr_0.98.npy')
    elif args.data == 'red_wine':
        data = np.load('data_small/red_no_discrete_no_corr_0.98.npy')
    n_features = data.shape[1]
    n_train = int(data.shape[0] * 0.9)
    train_data_clean = util_shuffle(data[:n_train])
    test_data = data[n_train:]
    kf = KFold(n_splits=10)

    covar = np.diag(args.covar * np.ones((n_features, )))

    train_data = train_data_clean + \
        np.random.multivariate_normal(mean=np.zeros(
            (n_features,)), cov=covar, size=n_train)

    # train_covars = np.repeat(
    #     covar[np.newaxis, :, :], n_train, axis=0)

    # train_dataset = DeconvDataset(train_data, train_covars)
    for i, (train_index, eval_index) in enumerate(kf.split(train_data)):
        X_train, X_eval = train_data[train_index], train_data[eval_index]
        train_covars = np.repeat(covar[np.newaxis, :, :],
                                 X_train.shape[0],
                                 axis=0)
        eval_covars = np.repeat(covar[np.newaxis, :, :],
                                X_eval.shape[0],
                                axis=0)

        train_dataset = DeconvDataset(X_train, train_covars)
        train_loader = DataLoader(train_dataset,
                                  batch_size=args.batch_size,
                                  shuffle=True)

        eval_dataset = DeconvDataset(X_eval, eval_covars)
        eval_loader = DataLoader(eval_dataset,
                                 batch_size=args.batch_size,
                                 shuffle=False)

        model = SVIFlowToy(dimensions=n_features,
                           objective=args.objective,
                           posterior_context_size=n_features,
                           batch_size=args.batch_size,
                           device=device,
                           maf_steps_prior=args.flow_steps_prior,
                           maf_steps_posterior=args.flow_steps_posterior,
                           maf_features=args.maf_features,
                           maf_hidden_blocks=args.maf_hidden_blocks,
                           K=args.K)

        message = 'Total number of parameters: %s' % (sum(
            p.numel() for p in model.parameters()))
        logger.info(message)

        optimizer = torch.optim.Adam(params=model.parameters(), lr=args.lr)

        # training
        scheduler = [30]
        epoch = 0
        best_model = copy.deepcopy(model.state_dict())

        best_eval_loss = compute_eval_loss(model, eval_loader, device,
                                           len(eval_index))
        n_epochs_not_improved = 0

        model.train()
        while n_epochs_not_improved < scheduler[-1] and epoch < args.n_epochs:
            for batch_idx, batch in enumerate(train_loader):
                batch[0] = batch[0].to(device)
                batch[1] = batch[1].to(device)

                # Alternate optimization: update the prior with the
                # posterior frozen, then vice versa.
                for prior_params in model.model._prior.parameters():
                    prior_params.requires_grad = True

                for post_params in model.model._approximate_posterior.parameters():
                    post_params.requires_grad = False

                for it in range(args.prior_iter):
                    loss = -model.score(batch).mean()
                    message = 'Loss prior %s: %f' % (it, loss)
                    logger.info(message)
                    optimizer.zero_grad()
                    loss.backward(retain_graph=True)
                    optimizer.step()

                for prior_params in model.model._prior.parameters():
                    prior_params.requires_grad = False

                for post_params in model.model._approximate_posterior.parameters():
                    post_params.requires_grad = True

                for it in range(args.posterior_iter):
                    loss = -model.score(batch).mean()
                    message = 'Loss posterior %s: %f' % (it, loss)
                    logger.info(message)
                    optimizer.zero_grad()
                    loss.backward(retain_graph=True)
                    optimizer.step()

            model.eval()
            test_loss_clean = -model.model._prior.log_prob(
                torch.from_numpy(test_data).to(device)).mean()
            message = 'Test loss (clean) = %.5f' % test_loss_clean
            logger.info(message)
            eval_loss = compute_eval_loss(model, eval_loader, device,
                                          len(eval_index))

            if eval_loss < best_eval_loss:
                best_model = copy.deepcopy(model.state_dict())
                best_eval_loss = eval_loss
                n_epochs_not_improved = 0

            else:
                n_epochs_not_improved += 1

            model.train()
            epoch += 1
        model.load_state_dict(best_model)
        test_loss_clean = -model.model._prior.log_prob(
            torch.from_numpy(test_data).to(device)).mean()
        message = 'Final test loss (clean) = %.5f' % test_loss_clean
        logger.info(message)

        # Only the first fold is used; drop this break to run all folds.
        break
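The helper compute_eval_loss is not shown on this page; a minimal sketch consistent with how it is called (model, loader, device, number of points; lower is better) might look like this. This is an assumption, not the repo's actual code:

import torch

def compute_eval_loss(model, eval_loader, device, n_points):
    # Hypothetical reimplementation: total negative score (ELBO) over the
    # evaluation set, averaged per data point. Only the call signature is
    # taken from the examples above.
    total = 0.0
    with torch.no_grad():
        for batch in eval_loader:
            batch = [t.to(device) for t in batch]
            total += -model.score(batch).sum().item()
    return total / n_points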
Example #7
def check_online_deconv_gmm(D, K, N, plot=False, device=None, verbose=False):

    if not device:
        device = torch.device('cpu')

    data, params = generate_data(D, K, N)
    X_train, nc_train, X_test, nc_test = data
    means, covars = params

    train_data = DeconvDataset(
        torch.Tensor(X_train.reshape(-1, D).astype(np.float32)),
        torch.Tensor(nc_train.reshape(-1, D, D).astype(np.float32)))

    test_data = DeconvDataset(
        torch.Tensor(X_test.reshape(-1, D).astype(np.float32)),
        torch.Tensor(nc_test.reshape(-1, D, D).astype(np.float32)))

    gmm = OnlineDeconvGMM(K,
                          D,
                          device=device,
                          batch_size=500,
                          step_size=1e-1,
                          restarts=1,
                          epochs=20,
                          k_means_iters=20,
                          lr_step=10,
                          w=1e-3)
    gmm.fit(train_data, val_data=test_data, verbose=verbose)

    train_score = gmm.score_batch(train_data)
    test_score = gmm.score_batch(test_data)

    print('Training score: {}'.format(train_score))
    print('Test score: {}'.format(test_score))

    if plot:
        fig, ax = plt.subplots()

        ax.plot(gmm.train_ll_curve, label='Training LL')
        ax.plot(gmm.val_ll_curve, label='Validation LL')
        ax.legend()

        plt.show()

        fig, ax = plt.subplots()

        for i in range(K):
            sc = ax.scatter(X_train[:, i, 0],
                            X_train[:, i, 1],
                            alpha=0.2,
                            marker='x',
                            label='Cluster {}'.format(i))
            plot_covariance(means[i, :],
                            covars[i, :, :],
                            ax,
                            color=sc.get_facecolor()[0])

        sc = ax.scatter(gmm.means[:, 0],
                        gmm.means[:, 1],
                        marker='+',
                        label='Fitted Gaussians')

        for i in range(K):
            plot_covariance(gmm.means[i, :],
                            gmm.covars[i, :, :],
                            ax,
                            color=sc.get_facecolor()[0])

        ax.legend()
        plt.show()
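plot_covariance is another helper not shown here; below is a self-contained sketch of the standard covariance ellipse it presumably draws, matching the (mean, covar, ax, color=...) call above. Assumed, not the repo's code:

import numpy as np
from matplotlib.patches import Ellipse

def plot_covariance(mean, covar, ax, color=None, n_std=1.0):
    # Hypothetical sketch: draw an n_std covariance ellipse from the
    # eigendecomposition of a 2x2 covariance matrix.
    vals, vecs = np.linalg.eigh(covar)
    angle = np.degrees(np.arctan2(vecs[1, -1], vecs[0, -1]))
    width = 2 * n_std * np.sqrt(vals[-1])
    height = 2 * n_std * np.sqrt(vals[0])
    ellipse = Ellipse(mean, width, height, angle=angle,
                      facecolor='none', edgecolor=color)
    ax.add_patch(ellipse)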
Example #8
def main():
    train_covar = covar_gen(args.covar, args.n_train_points).astype(np.float32)
    train_data_clean = data_gen(args.data,
                                args.n_train_points)[0].astype(np.float32)

    # plt.scatter(train_data_clean[:, 0], train_data_clean[:, 1])

    train_data = np.zeros_like(train_data_clean)
    for i in range(args.n_train_points):
        train_data[i] = train_data_clean[i] + np.random.multivariate_normal(
            mean=np.zeros((2, )), cov=train_covar[i])

    # plt.scatter(train_data[:, 0], train_data[:, 1])
    # plt.show()

    train_covar = torch.from_numpy(train_covar)
    train_data = torch.from_numpy(train_data.astype(np.float32))

    train_dataset = DeconvDataset(train_data, train_covar)
    train_loader = DataLoader(train_dataset,
                              batch_size=args.batch_size,
                              shuffle=True)

    test_data_clean = torch.from_numpy(
        data_gen(args.data, args.n_test_points)[0].astype(np.float32))

    eval_covar = covar_gen(args.covar, args.n_eval_points).astype(np.float32)
    eval_data_clean = data_gen(args.data,
                               args.n_eval_points)[0].astype(np.float32)

    eval_data = np.zeros_like(eval_data_clean)
    for i in range(args.n_eval_points):
        eval_data[i] = eval_data_clean[i] + np.random.multivariate_normal(
            mean=np.zeros((2, )), cov=eval_covar[i])

    eval_covar = torch.from_numpy(eval_covar)
    eval_data = torch.from_numpy(eval_data.astype(np.float32))

    eval_dataset = DeconvDataset(eval_data, eval_covar)
    eval_loader = DataLoader(eval_dataset,
                             batch_size=args.test_batch_size,
                             shuffle=False)

    if args.infer == 'true_data':
        model = SVIFlowToy(dimensions=2,
                           objective=args.objective,
                           posterior_context_size=args.posterior_context_size,
                           batch_size=args.batch_size,
                           device=device,
                           maf_steps_prior=args.flow_steps_prior,
                           maf_steps_posterior=args.flow_steps_posterior,
                           maf_features=args.maf_features,
                           maf_hidden_blocks=args.maf_hidden_blocks,
                           K=args.K)

    else:
        model = SVIFlowToyNoise(
            dimensions=2,
            objective=args.objective,
            posterior_context_size=args.posterior_context_size,
            batch_size=args.batch_size,
            device=device,
            maf_steps_prior=args.flow_steps_prior,
            maf_steps_posterior=args.flow_steps_posterior,
            maf_features=args.maf_features,
            maf_hidden_blocks=args.maf_hidden_blocks,
            K=args.K)

    message = 'Total number of parameters: %s' % (sum(
        p.numel() for p in model.parameters()))
    logger.info(message)

    optimizer = torch.optim.Adam(params=model.parameters(), lr=args.lr)

    # training
    scheduler = list(map(int, args.eval_based_scheduler.split(',')))
    epoch = 0
    best_model = copy.deepcopy(model.state_dict())

    best_eval_loss = compute_eval_loss(model, eval_loader, device,
                                       args.n_eval_points)
    n_epochs_not_improved = 0

    model.train()
    while n_epochs_not_improved < scheduler[-1] and epoch < args.n_epochs:
        for batch_idx, data in enumerate(train_loader):
            data[0] = data[0].to(device)
            data[1] = data[1].to(device)

            loss = -model.score(data).mean()
            optimizer.zero_grad()
            loss.backward(retain_graph=True)
            optimizer.step()

        model.eval()
        eval_loss = compute_eval_loss(model, eval_loader, device,
                                      args.n_eval_points)

        if eval_loss < best_eval_loss:
            best_model = copy.deepcopy(model.state_dict())
            best_eval_loss = eval_loss
            n_epochs_not_improved = 0

        else:
            n_epochs_not_improved += 1

        lr_scheduler(n_epochs_not_improved, optimizer, scheduler, logger)

        if (epoch + 1) % args.test_freq == 0:
            if args.infer == 'true_data':
                test_loss_clean = -model.model._prior.log_prob(
                    test_data_clean.to(device)).mean()

            else:
                test_loss_clean = -model.model._likelihood.log_prob(
                    test_data_clean.to(device)).mean()

            message = ('Epoch %s: train loss = %.5f, eval loss = %.5f, '
                       'test loss (clean) = %.5f'
                       % (epoch + 1, loss, eval_loss, test_loss_clean))
            logger.info(message)

        else:
            message = ('Epoch %s: train loss = %.5f, eval loss = %.5f'
                       % (epoch + 1, loss, eval_loss))
            logger.info(message)

        if (epoch + 1) % args.viz_freq == 0:
            if args.infer == 'true_data':
                samples = model.model._prior.sample(
                    1000).detach().cpu().numpy()

            else:
                samples = model.model._likelihood.sample(
                    1000).detach().cpu().numpy()

            corner.hist2d(samples[:, 0], samples[:, 1])
            fig_filename = args.dir + 'out/' + name + '_corner_fig_' + str(
                epoch + 1) + '.png'
            plt.savefig(fig_filename)
            plt.close()

            plt.scatter(samples[:, 0], samples[:, 1])
            fig_filename = args.dir + 'out/' + name + '_scatter_fig_' + str(
                epoch + 1) + '.png'
            plt.savefig(fig_filename)
            plt.close()

        model.train()
        epoch += 1

    model.load_state_dict(best_model)
    model.eval()

    if args.infer == 'true_data':
        test_loss_clean = -model.model._prior.log_prob(
            test_data_clean.to(device)).mean()

    else:
        test_loss_clean = -model.model._likelihood.log_prob(
            test_data_clean.to(device)).mean()

    message = 'Final test loss (clean) = %.5f' % test_loss_clean
    logger.info(message)

    torch.save(model.state_dict(), args.dir + 'models/' + name + '.model')
    logger.info('Training has finished.')

    if args.data.split('_')[0] == 'mixture' or args.data.split(
            '_')[0] == 'gaussian':
        kl_points = data_gen(args.data, args.n_kl_points)[0].astype(np.float32)

        if args.infer == 'true_data':
            model_log_prob = model.model._prior.log_prob(
                torch.from_numpy(kl_points.astype(
                    np.float32)).to(device)).mean()

        else:
            model_log_prob = model.model._likelihood.log_prob(
                torch.from_numpy(kl_points.astype(
                    np.float32)).to(device)).mean()

        data_log_prob = compute_data_ll(args.data, kl_points).mean()

        approximate_KL = data_log_prob - model_log_prob
        message = 'KL div: %.5f' % approximate_KL
        logger.info(message)
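lr_scheduler above is a custom helper driven by the comma-separated eval_based_scheduler thresholds; only its call site is given on this page. A plausible sketch (an assumption) decays the learning rate whenever the no-improvement counter reaches an intermediate threshold:

def lr_scheduler(n_epochs_not_improved, optimizer, scheduler, logger):
    # Hypothetical sketch: all but the last threshold trigger a 10x
    # learning-rate decay; the last one is the early-stopping patience
    # used in the while-loop condition above.
    if n_epochs_not_improved in scheduler[:-1]:
        for group in optimizer.param_groups:
            group['lr'] *= 0.1
            logger.info('Learning rate decayed to %s' % group['lr'])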
Example #9

def main():
    data = np.load('data_small/boston_no_discrete.npy')
    n_features = data.shape[1]
    n_train = int(data.shape[0] * 0.9)
    train_data_clean = data[:n_train]

    covar = np.diag(args.covar * np.ones((n_features, )))

    train_data = train_data_clean + \
        np.random.multivariate_normal(mean=np.zeros(
            (n_features,)), cov=covar, size=n_train)

    train_covars = np.repeat(covar[np.newaxis, :, :], n_train, axis=0)

    train_dataset = DeconvDataset(train_data, train_covars)
    train_loader = DataLoader(train_dataset,
                              batch_size=args.batch_size,
                              shuffle=True)

    model = SVIFlowToy(dimensions=n_features,
                       objective=args.objective,
                       posterior_context_size=n_features,
                       batch_size=args.batch_size,
                       device=device,
                       maf_steps_prior=args.flow_steps_prior,
                       maf_steps_posterior=args.flow_steps_posterior,
                       maf_features=args.maf_features,
                       maf_hidden_blocks=args.maf_hidden_blocks,
                       K=args.K)

    message = 'Total number of parameters: %s' % (sum(
        p.numel() for p in model.parameters()))
    logger.info(message)

    optimizer = torch.optim.Adam(params=model.parameters(), lr=args.lr)

    # training
    epoch = 0
    while epoch < args.n_epochs:
        train_loss = 0
        for batch_idx, batch in enumerate(train_loader):
            # 'batch' rather than 'data': the loaded 'data' array is still
            # needed for the clean test evaluation below.
            batch[0] = batch[0].to(device)
            batch[1] = batch[1].to(device)

            log_prob = model.score(batch)
            loss = -log_prob.mean()
            train_loss += -torch.sum(log_prob).item()
            optimizer.zero_grad()
            loss.backward(retain_graph=True)
            optimizer.step()

        train_loss /= n_train
        message = 'Train loss %.5f' % train_loss
        logger.info(message)

        if train_loss < 9.02486:  # boston housing
            break

    test_loss_clean = -model.model._prior.log_prob(
        torch.from_numpy(data[n_train:]).to(device)).mean()

    message = 'Test loss (clean) = %.5f' % test_loss_clean
    logger.info(message)
Example #10

def main():
    if args.data == 'boston':
        data = np.load('data_small/boston_no_discrete.npy')
    elif args.data == 'white_wine':
        data = np.load('data_small/white_no_discrete_no_corr_0.98.npy')
    elif args.data == 'red_wine':
        data = np.load('data_small/red_no_discrete_no_corr_0.98.npy')
    elif args.data == 'ionosphere':
        data = np.load('data_small/ionosphere_no_discrete_no_corr_0.98.npy')

    n_features = data.shape[1]
    n_train = int(data.shape[0] * 0.9)
    train_data_clean = data[:n_train]

    covar = np.diag(args.covar * np.ones((n_features, )))

    train_data = train_data_clean + \
        np.random.multivariate_normal(mean=np.zeros(
            (n_features,)), cov=covar, size=n_train)

    kf = KFold(n_splits=5)

    # 15 combinations (3 learning rates x 5 component counts)
    lr_list = [1e-2, 5e-3, 1e-3]
    K_list = [20, 50, 100, 200, 500]

    n_combs = len(lr_list) * len(K_list)

    best_eval = np.zeros((n_combs, 5))

    counter = 0
    for lr, K in product(lr_list, K_list):
        logger.info((lr, K))

        for i, (train_index, eval_index) in enumerate(kf.split(train_data)):
            X_train, X_eval = train_data[train_index], train_data[eval_index]
            train_covars = np.repeat(covar[np.newaxis, :, :],
                                     X_train.shape[0],
                                     axis=0)
            eval_covars = np.repeat(covar[np.newaxis, :, :],
                                    X_eval.shape[0],
                                    axis=0)

            train_dataset = DeconvDataset(X_train, train_covars)
            train_loader = DataLoader(train_dataset,
                                      batch_size=args.batch_size,
                                      shuffle=True)

            eval_dataset = DeconvDataset(X_eval, eval_covars)
            eval_loader = DataLoader(eval_dataset,
                                     batch_size=args.batch_size,
                                     shuffle=False)

            model = SGDDeconvGMM(K,
                                 n_features,
                                 device=device,
                                 batch_size=args.batch_size,
                                 epochs=args.n_epochs,
                                 lr=lr)

            message = 'Total number of parameters: %s' % (sum(
                p.numel() for p in model.module.parameters()))
            logger.info(message)

            best_eval_loss = model.fit(train_dataset,
                                       logger,
                                       val_data=eval_dataset,
                                       verbose=True)

            best_eval[counter, i] = best_eval_loss
            np.save(args.data + '_gmm_hypertuning_results_tmp', best_eval)

        counter += 1

    np.save(args.data + '_gmm_hypertuning_results', best_eval)
Example #11

parser.add_argument('-c', '--grad_clip_norm', type=float)
parser.add_argument('-m', '--hidden-features', type=int)
parser.add_argument('output_prefix')

args = parser.parse_args()

K = 3
D = 2
N = 50000
N_val = int(0.25 * N)

ref_gmm, S, (z_train, x_train), (z_val, x_val), _ = generate_mixture_data()

if args.gmm:
    if args.svi_gmm:
        train_data = DeconvDataset(x_train.squeeze(), torch.cholesky(S.repeat(N, 1, 1)))
        val_data = DeconvDataset(x_val.squeeze(), torch.cholesky(S.repeat(N_val, 1, 1)))
        if args.svi_exact_gmm:
            svi_gmm = SVIGMMExact(
                2,
                5,
                device=torch.device('cuda'),
                batch_size=512,
                epochs=args.epochs,
                lr=args.learning_rate,
                n_samples=args.samples,
                use_iwae=args.use_iwae,
                context_size=64,
                hidden_features=args.hidden_features
            )
        else:
Example #12
def main():
    if args.data == 'boston':
        data = np.load('data_small/boston_no_discrete.npy')
    elif args.data == 'white_wine':
        data = np.load('data_small/white_no_discrete_no_corr_0.98.npy')
    elif args.data == 'red_wine':
        data = np.load('data_small/red_no_discrete_no_corr_0.98.npy')
    elif args.data == 'ionosphere':
        data = np.load('data_small/ionosphere_no_discrete_no_corr_0.98.npy')

    n_features = data.shape[1]
    n_train = int(data.shape[0] * 0.9)
    train_data_clean = data[:n_train]

    covar = np.diag(args.covar * np.ones((n_features, )))

    train_data = train_data_clean + \
        np.random.multivariate_normal(mean=np.zeros(
            (n_features,)), cov=covar, size=n_train)

    kf = KFold(n_splits=5)

    # 108 combinations (3 x 3 x 3 x 2 x 2)
    lr_list = [1e-3, 5e-4, 1e-4]
    flow_steps_prior_list = [3, 4, 5]
    flow_steps_posterior_list = [3, 4, 5]
    maf_features_list = [64, 128]
    maf_hidden_blocks_list = [1, 2]

    n_combs = 0
    for lr, fspr, fspo, maf_f, maf_h in product(lr_list,
                                                flow_steps_prior_list,
                                                flow_steps_posterior_list,
                                                maf_features_list,
                                                maf_hidden_blocks_list):
        n_combs += 1
        print(n_combs, (lr, fspr, fspo, maf_f, maf_h))

    best_eval = np.zeros((n_combs, 5))

    counter = 0
    for lr, fspr, fspo, maf_f, maf_h in product(lr_list,
                                                flow_steps_prior_list,
                                                flow_steps_posterior_list,
                                                maf_features_list,
                                                maf_hidden_blocks_list):
        logger.info((lr, fspr, fspo, maf_f, maf_h))

        for i, (train_index, eval_index) in enumerate(kf.split(train_data)):
            X_train, X_eval = train_data[train_index], train_data[eval_index]
            train_covars = np.repeat(covar[np.newaxis, :, :],
                                     X_train.shape[0],
                                     axis=0)
            eval_covars = np.repeat(covar[np.newaxis, :, :],
                                    X_eval.shape[0],
                                    axis=0)

            train_dataset = DeconvDataset(X_train, train_covars)
            train_loader = DataLoader(train_dataset,
                                      batch_size=args.batch_size,
                                      shuffle=True)

            eval_dataset = DeconvDataset(X_eval, eval_covars)
            eval_loader = DataLoader(eval_dataset,
                                     batch_size=args.batch_size,
                                     shuffle=False)

            model = SVIFlowToy(dimensions=n_features,
                               objective=args.objective,
                               posterior_context_size=n_features,
                               batch_size=args.batch_size,
                               device=device,
                               maf_steps_prior=fspr,
                               maf_steps_posterior=fspo,
                               maf_features=maf_f,
                               maf_hidden_blocks=maf_h,
                               K=args.K)

            message = 'Total number of parameters: %s' % (sum(
                p.numel() for p in model.parameters()))
            logger.info(message)

            optimizer = torch.optim.Adam(params=model.parameters(), lr=lr)

            # training
            scheduler = [30]  # stop after 30 epochs of no improvement
            epoch = 0

            model.eval()
            with torch.no_grad():
                best_eval_loss = compute_eval_loss(model, eval_loader, device,
                                                   X_eval.shape[0])

                best_model = copy.deepcopy(model.state_dict())

            n_epochs_not_improved = 0

            model.train()

            while (n_epochs_not_improved < scheduler[-1]
                   and epoch < args.n_epochs):
                for batch_idx, data in enumerate(train_loader):
                    data[0] = data[0].to(device)
                    data[1] = data[1].to(device)

                    loss = -model.score(data).mean()
                    optimizer.zero_grad()
                    loss.backward(retain_graph=True)
                    optimizer.step()

                model.eval()
                with torch.no_grad():
                    eval_loss = compute_eval_loss(model, eval_loader, device,
                                                  X_eval.shape[0])

                    if eval_loss < best_eval_loss:
                        best_model = copy.deepcopy(model.state_dict())
                        best_eval_loss = eval_loss
                        n_epochs_not_improved = 0

                    else:
                        n_epochs_not_improved += 1

                    message = ('Epoch %s: train loss = %.5f, eval loss = %.5f'
                               % (epoch + 1, loss, eval_loss))
                    logger.info(message)

                model.train()
                epoch += 1

            best_eval[counter, i] = best_eval_loss
            np.save(args.data + '_hypertuning_results_tmp', best_eval)

        counter += 1

    np.save(args.data + '_hypertuning_results', best_eval)
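A small post-processing sketch (assumed, not part of the source) for reducing the saved (n_combs, 5) grid to the best hyperparameter setting by mean eval loss across folds:

import numpy as np

# Hypothetical follow-up: load the saved grid and pick the combination
# with the lowest mean eval loss across the 5 folds. The filename is an
# example; np.save appends '.npy' to the prefix used above.
results = np.load('boston_hypertuning_results.npy')
mean_loss = results.mean(axis=1)
best_idx = int(mean_loss.argmin())
print('Best combination index: %d (mean eval loss %.5f)'
      % (best_idx, mean_loss[best_idx]))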