Example #1
def main():
    train_covar = covar_gen(args.covar, args.n_train_points).astype(np.float32)
    train_data_clean = data_gen(args.data,
                                args.n_train_points)[0].astype(np.float32)

    # plt.scatter(train_data_clean[:, 0], train_data_clean[:, 1])

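    # Corrupt each clean training point with zero-mean Gaussian noise drawn from that point's own covariance matrix.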
    train_data = np.zeros_like(train_data_clean)
    for i in range(args.n_train_points):
        train_data[i] = train_data_clean[i] + np.random.multivariate_normal(
            mean=np.zeros((2, )), cov=train_covar[i])

    # plt.scatter(train_data[:, 0], train_data[:, 1])
    # plt.show()

    train_covar = torch.from_numpy(train_covar)
    train_data = torch.from_numpy(train_data.astype(np.float32))

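    # DeconvDataset pairs each noisy observation with the covariance of the noise that corrupted it.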
    train_dataset = DeconvDataset(train_data, train_covar)
    train_loader = DataLoader(train_dataset,
                              batch_size=args.batch_size,
                              shuffle=True)

    test_data_clean = torch.from_numpy(
        data_gen(args.data, args.n_test_points)[0].astype(np.float32))

    eval_covar = covar_gen(args.covar, args.n_eval_points).astype(np.float32)
    eval_data_clean = data_gen(args.data,
                               args.n_eval_points)[0].astype(np.float32)

    eval_data = np.zeros_like(eval_data_clean)
    for i in range(args.n_eval_points):
        eval_data[i] = eval_data_clean[i] + np.random.multivariate_normal(
            mean=np.zeros((2, )), cov=eval_covar[i])

    eval_covar = torch.from_numpy(eval_covar)
    eval_data = torch.from_numpy(eval_data.astype(np.float32))

    eval_dataset = DeconvDataset(eval_data, eval_covar)
    eval_loader = DataLoader(eval_dataset,
                             batch_size=args.test_batch_size,
                             shuffle=False)

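    # With --infer true_data the flow's prior models the underlying clean density (scored on clean test data below);
    # otherwise SVIFlowToyNoise models the noisy observations directly via its likelihood.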
    if args.infer == 'true_data':
        model = SVIFlowToy(dimensions=2,
                           objective=args.objective,
                           posterior_context_size=args.posterior_context_size,
                           batch_size=args.batch_size,
                           device=device,
                           maf_steps_prior=args.flow_steps_prior,
                           maf_steps_posterior=args.flow_steps_posterior,
                           maf_features=args.maf_features,
                           maf_hidden_blocks=args.maf_hidden_blocks,
                           K=args.K)

    else:
        model = SVIFlowToyNoise(
            dimensions=2,
            objective=args.objective,
            posterior_context_size=args.posterior_context_size,
            batch_size=args.batch_size,
            device=device,
            maf_steps_prior=args.flow_steps_prior,
            maf_steps_posterior=args.flow_steps_posterior,
            maf_features=args.maf_features,
            maf_hidden_blocks=args.maf_hidden_blocks,
            K=args.K)

    message = 'Total number of parameters: %s' % (sum(
        p.numel() for p in model.parameters()))
    logger.info(message)

    optimizer = torch.optim.Adam(params=model.parameters(), lr=args.lr)

    # training
    scheduler = list(map(int, args.eval_based_scheduler.split(',')))
    epoch = 0
    best_model = copy.deepcopy(model.state_dict())

    best_eval_loss = compute_eval_loss(model, eval_loader, device,
                                       args.n_eval_points)
    n_epochs_not_improved = 0

    model.train()
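    # Train with early stopping: stop once the eval loss has not improved for scheduler[-1] consecutive epochs (or after n_epochs).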
    while n_epochs_not_improved < scheduler[-1] and epoch < args.n_epochs:
        for batch_idx, data in enumerate(train_loader):
            data[0] = data[0].to(device)
            data[1] = data[1].to(device)

            loss = -model.score(data).mean()
            optimizer.zero_grad()
            loss.backward(retain_graph=True)
            optimizer.step()

        model.eval()
        eval_loss = compute_eval_loss(model, eval_loader, device,
                                      args.n_eval_points)

        if eval_loss < best_eval_loss:
            best_model = copy.deepcopy(model.state_dict())
            best_eval_loss = eval_loss
            n_epochs_not_improved = 0

        else:
            n_epochs_not_improved += 1

        lr_scheduler(n_epochs_not_improved, optimizer, scheduler, logger)

        if (epoch + 1) % args.test_freq == 0:
            if args.infer == 'true_data':
                test_loss_clean = -model.model._prior.log_prob(
                    test_data_clean.to(device)).mean()

            else:
                test_loss_clean = -model.model._likelihood.log_prob(
                    test_data_clean.to(device)).mean()

            message = 'Epoch %s: train loss = %.5f, eval loss = %.5f, test loss (clean) = %.5f' % (
                epoch + 1, loss, eval_loss, test_loss_clean)
            logger.info(message)

        else:
            message = 'Epoch %s: train loss = %.5f, eval loss = %.5f' % (
                epoch + 1, loss, eval_loss)
            logger.info(message)

        if (epoch + 1) % args.viz_freq == 0:
            if args.infer == 'true_data':
                samples = model.model._prior.sample(
                    1000).detach().cpu().numpy()

            else:
                samples = model.model._likelihood.sample(
                    1000).detach().cpu().numpy()

            corner.hist2d(samples[:, 0], samples[:, 1])
            fig_filename = args.dir + 'out/' + name + '_corner_fig_' + str(
                epoch + 1) + '.png'
            plt.savefig(fig_filename)
            plt.close()

            plt.scatter(samples[:, 0], samples[:, 1])
            fig_filename = args.dir + 'out/' + name + '_scatter_fig_' + str(
                epoch + 1) + '.png'
            plt.savefig(fig_filename)
            plt.close()

        model.train()
        epoch += 1

    model.load_state_dict(best_model)
    model.eval()

    if args.infer == 'true_data':
        test_loss_clean = -model.model._prior.log_prob(
            test_data_clean.to(device)).mean()

    else:
        test_loss_clean = -model.model._likelihood.log_prob(
            test_data_clean.to(device)).mean()

    message = 'Final test loss (clean) = %.5f' % test_loss_clean
    logger.info(message)

    torch.save(model.state_dict(), args.dir + 'models/' + name + '.model')
    logger.info('Training has finished.')

    if args.data.split('_')[0] in ('mixture', 'gaussian'):
        kl_points = data_gen(args.data, args.n_kl_points)[0].astype(np.float32)

        if args.infer == 'true_data':
            model_log_prob = model.model._prior.log_prob(
                torch.from_numpy(kl_points.astype(
                    np.float32)).to(device)).mean()

        else:
            model_log_prob = model.model._likelihood.log_prob(
                torch.from_numpy(kl_points.astype(
                    np.float32)).to(device)).mean()

        data_log_prob = compute_data_ll(args.data, kl_points).mean()

        approximate_KL = data_log_prob - model_log_prob
        message = 'KL div %.5f:' % approximate_KL
        logger.info(message)
Example #2
def main():
    if args.data == 'boston':
        data = np.load('data_small/boston_no_discrete.npy')
    elif args.data == 'white_wine':
        data = np.load('data_small/white_no_discrete_no_corr_0.98.npy')
    elif args.data == 'red_wine':
        data = np.load('data_small/red_no_discrete_no_corr_0.98.npy')
    n_features = data.shape[1]
    n_train = int(data.shape[0] * 0.9)
    train_data_clean = util_shuffle(data[:n_train])
    test_data = data[n_train:]  # hold out the remaining 10% for testing
    kf = KFold(n_splits=10)

    covar = np.diag(args.covar * np.ones((n_features, )))

    train_data = train_data_clean + \
        np.random.multivariate_normal(mean=np.zeros(
            (n_features,)), cov=covar, size=n_train)

    # train_covars = np.repeat(
    #     covar[np.newaxis, :, :], n_train, axis=0)

    # train_dataset = DeconvDataset(train_data, train_covars)
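    # 10-fold cross-validation on the noisy training data; every point shares the same diagonal noise covariance.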
    for i, (train_index, eval_index) in enumerate(kf.split(train_data)):
        X_train, X_eval = train_data[train_index], train_data[eval_index]
        train_covars = np.repeat(covar[np.newaxis, :, :],
                                 X_train.shape[0],
                                 axis=0)
        eval_covars = np.repeat(covar[np.newaxis, :, :],
                                X_eval.shape[0],
                                axis=0)

        train_dataset = DeconvDataset(X_train, train_covars)
        train_loader = DataLoader(train_dataset,
                                  batch_size=args.batch_size,
                                  shuffle=True)

        eval_dataset = DeconvDataset(X_eval, eval_covars)
        eval_loader = DataLoader(eval_dataset,
                                 batch_size=args.batch_size,
                                 shuffle=False)

        model = SVIFlowToy(dimensions=n_features,
                           objective=args.objective,
                           posterior_context_size=n_features,
                           batch_size=args.batch_size,
                           device=device,
                           maf_steps_prior=args.flow_steps_prior,
                           maf_steps_posterior=args.flow_steps_posterior,
                           maf_features=args.maf_features,
                           maf_hidden_blocks=args.maf_hidden_blocks,
                           K=args.K)

        message = 'Total number of parameters: %s' % (sum(
            p.numel() for p in model.parameters()))
        logger.info(message)

        optimizer = torch.optim.Adam(params=model.parameters(), lr=args.lr)

        # training
        scheduler = [30]
        epoch = 0
        best_model = copy.deepcopy(model.state_dict())

        best_eval_loss = compute_eval_loss(model, eval_loader, device,
                                           len(eval_index))
        n_epochs_not_improved = 0

        model.train()
        while n_epochs_not_improved < scheduler[-1] and epoch < args.n_epochs:
            for batch_idx, data in enumerate(train_loader):
                data[0] = data[0].to(device)
                data[1] = data[1].to(device)

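                # Alternating optimization: update the prior for prior_iter steps with the posterior frozen,
                # then the posterior for posterior_iter steps with the prior frozen.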
                for prior_params in model.model._prior.parameters():
                    prior_params.requires_grad = True

                for post_params in model.model._approximate_posterior.parameters():
                    post_params.requires_grad = False

                for prior_it in range(args.prior_iter):
                    loss = -model.score(data).mean()
                    message = 'Loss prior %s: %f' % (prior_it, loss)
                    logger.info(message)
                    optimizer.zero_grad()
                    loss.backward(retain_graph=True)
                    optimizer.step()

                for prior_params in model.model._prior.parameters():
                    prior_params.requires_grad = False

                for post_params in model.model._approximate_posterior.parameters():
                    post_params.requires_grad = True

                for post_it in range(args.posterior_iter):
                    loss = -model.score(data).mean()
                    message = 'Loss posterior %s: %f' % (post_it, loss)
                    logger.info(message)
                    optimizer.zero_grad()
                    loss.backward(retain_graph=True)
                    optimizer.step()

            model.eval()
            test_loss_clean = -model.model._prior.log_prob(
                torch.from_numpy(test_data).to(device)).mean()
            message = 'Test loss (clean) = %.5f' % test_loss_clean
            logger.info(message)
            eval_loss = compute_eval_loss(model, eval_loader, device,
                                          len(eval_index))

            if eval_loss < best_eval_loss:
                best_model = copy.deepcopy(model.state_dict())
                best_eval_loss = eval_loss
                n_epochs_not_improved = 0

            else:
                n_epochs_not_improved += 1

            model.train()
            epoch += 1
        model.load_state_dict(best_model)
        model.eval()
        test_loss_clean = -model.model._prior.log_prob(
            torch.from_numpy(test_data).to(device)).mean()
        message = 'Final test loss (clean) = %.5f' % test_loss_clean
        logger.info(message)

        break  # only the first fold is trained in this example
Example #3
def main():
    if args.data == 'boston':
        data = np.load('data_small/boston_no_discrete.npy')
    elif args.data == 'white_wine':
        data = np.load('data_small/white_no_discrete_no_corr_0.98.npy')
    elif args.data == 'red_wine':
        data = np.load('data_small/red_no_discrete_no_corr_0.98.npy')
    elif args.data == 'ionosphere':
        data = np.load('data_small/ionosphere_no_discrete_no_corr_0.98.npy')

    n_features = data.shape[1]
    n_train = int(data.shape[0] * 0.9)
    train_data_clean = data[:n_train]

    covar = np.diag(args.covar * np.ones((n_features, )))

    train_data = train_data_clean + \
        np.random.multivariate_normal(mean=np.zeros(
            (n_features,)), cov=covar, size=n_train)

    kf = KFold(n_splits=5)

    # 3 lrs x 3 prior steps x 3 posterior steps x 2 widths x 2 depths = 108 combinations
    lr_list = [1e-3, 5e-4, 1e-4]
    flow_steps_prior_list = [3, 4, 5]
    flow_steps_posterior_list = [3, 4, 5]
    maf_features_list = [64, 128]
    maf_hidden_blocks_list = [1, 2]

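    # Enumerate the grid once just to count and print the hyper-parameter combinations.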
    n_combs = 0
    for lr, fspr, fspo, maf_f, maf_h in product(lr_list,
                                                flow_steps_prior_list,
                                                flow_steps_posterior_list,
                                                maf_features_list,
                                                maf_hidden_blocks_list):
        n_combs += 1
        print(n_combs, (lr, fspr, fspo, maf_f, maf_h))

    best_eval = np.zeros((n_combs, 5))

    counter = 0
    for lr, fspr, fspo, maf_f, maf_h in product(lr_list,
                                                flow_steps_prior_list,
                                                flow_steps_posterior_list,
                                                maf_features_list,
                                                maf_hidden_blocks_list):
        logger.info((lr, fspr, fspo, maf_f, maf_h))

        for i, (train_index, eval_index) in enumerate(kf.split(train_data)):
            X_train, X_eval = train_data[train_index], train_data[eval_index]
            train_covars = np.repeat(covar[np.newaxis, :, :],
                                     X_train.shape[0],
                                     axis=0)
            eval_covars = np.repeat(covar[np.newaxis, :, :],
                                    X_eval.shape[0],
                                    axis=0)

            train_dataset = DeconvDataset(X_train, train_covars)
            train_loader = DataLoader(train_dataset,
                                      batch_size=args.batch_size,
                                      shuffle=True)

            eval_dataset = DeconvDataset(X_eval, eval_covars)
            eval_loader = DataLoader(eval_dataset,
                                     batch_size=args.batch_size,
                                     shuffle=False)

            model = SVIFlowToy(dimensions=n_features,
                               objective=args.objective,
                               posterior_context_size=n_features,
                               batch_size=args.batch_size,
                               device=device,
                               maf_steps_prior=fspr,
                               maf_steps_posterior=fspo,
                               maf_features=maf_f,
                               maf_hidden_blocks=maf_h,
                               K=args.K)

            message = 'Total number of parameters: %s' % (sum(
                p.numel() for p in model.parameters()))
            logger.info(message)

            optimizer = torch.optim.Adam(params=model.parameters(), lr=lr)

            # training
            scheduler = [30]  # stop after 30 epochs of no improvement
            epoch = 0

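            # Use the untrained model's eval loss as the initial early-stopping baseline.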
            model.eval()
            with torch.no_grad():
                best_eval_loss = compute_eval_loss(model, eval_loader, device,
                                                   X_eval.shape[0])

                best_model = copy.deepcopy(model.state_dict())

            n_epochs_not_improved = 0

            model.train()

            while (n_epochs_not_improved < scheduler[-1]
                   and epoch < args.n_epochs):
                for batch_idx, data in enumerate(train_loader):
                    data[0] = data[0].to(device)
                    data[1] = data[1].to(device)

                    loss = -model.score(data).mean()
                    optimizer.zero_grad()
                    loss.backward(retain_graph=True)
                    optimizer.step()

                model.eval()
                with torch.no_grad():
                    eval_loss = compute_eval_loss(model, eval_loader, device,
                                                  X_eval.shape[0])

                    if eval_loss < best_eval_loss:
                        best_model = copy.deepcopy(model.state_dict())
                        best_eval_loss = eval_loss
                        n_epochs_not_improved = 0

                    else:
                        n_epochs_not_improved += 1

                    message = 'Epoch %s: train loss = %.5f, eval loss = %.5f' % (
                        epoch + 1, loss, eval_loss)
                    logger.info(message)

                model.train()
                epoch += 1

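            # Record the best eval loss for this (hyper-parameter combination, fold) pair and checkpoint the grid after every fold.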
            best_eval[counter, i] = best_eval_loss
            np.save(args.data + '_hypertuning_results_tmp', best_eval)

        counter += 1

    np.save(args.data + '_hypertuning_results', best_eval)