def _load_data(folder, test_frac=0.2):
    """Load features and labels from *folder*, shuffle the rows, and split.

    Reads ``<folder>/data.npy`` (a 2-D feature matrix, one row per example)
    and ``<folder>/labels.npy`` (one label per example, stored either as a
    flat ``(N,)`` vector or an ``(N, 1)`` column). The rows are shuffled with
    NumPy's global RNG, then split into train/test sets.

    Args:
        folder: Directory containing ``data.npy`` and ``labels.npy``.
        test_frac: Fraction of rows assigned to the test split. The train
            split gets the first ``int(N * (1 - test_frac))`` shuffled rows.

    Returns:
        ``(X_train, y_train, X_test, y_test)``. The ``X`` arrays are 2-D and
        the ``y`` arrays are 1-D. Features and labels are stacked into one
        array before shuffling, so all four share their promoted dtype.

    Raises:
        ValueError: If the number of labels does not match the number of
            feature rows.
    """
    import os

    X = np.load(os.path.join(folder, 'data.npy'))
    y = np.load(os.path.join(folder, 'labels.npy'))
    # Labels may be saved as a flat (N,) vector; concatenating that with a
    # 2-D X along axis=1 fails, so force a single column first.
    y = np.reshape(y, (-1, 1))
    n_rows = X.shape[0]
    if y.shape[0] != n_rows:
        raise ValueError(
            'labels.npy has {} entries but data.npy has {} rows'.format(
                y.shape[0], n_rows))

    # Stack features and labels so a single row shuffle keeps them aligned.
    data = np.concatenate([X, y], axis=1)
    np.random.shuffle(data)

    n_train = int(n_rows * (1 - test_frac))
    train_data, test_data = data[:n_train], data[n_train:]
    return (train_data[:, :-1], train_data[:, -1],
            test_data[:, :-1], test_data[:, -1])


def main(args):
    """Build a ``BayesLogRegPost`` target from ``args.data`` and run a sampler on it.

    Loads and splits the dataset in ``args.data``, then wraps it in a
    ``BayesLogRegPost`` with prior ``args.prior``. Next it constructs a ``VS``
    sampler for that target's log-probability inside the ``'sampler'``
    variable scope, and finally creates and runs an ``Experiment`` that logs
    to ``args.logdir``.

    Args:
        args: Parsed command-line namespace. The attributes read here are
            ``data``, ``prior``, ``aux_dim``, ``hidden_units``,
            ``num_layers``, ``train_samples``, ``num_chains``,
            ``activation``, ``num_mix``, ``perturb`` and ``logdir``. The
            whole namespace is also passed to ``Experiment`` as ``params``.
    """
    print('setting up distribution')

    X_train, y_train, X_test, y_test = _load_data(args.data)
    # NOTE(review): the +1 presumably accounts for a bias/intercept weight
    # in BayesLogRegPost's parameter vector; confirm against that class.
    data_dim = X_train.shape[1] + 1
    dist = BayesLogRegPost(X_train, y_train, X_test, y_test, args.prior)

    print('setting up sampler')
    with tf.variable_scope('sampler', reuse=tf.AUTO_REUSE):
        sampler = VS(dist.log_prob_func(), data_dim, args.aux_dim,
                     args.hidden_units, args.num_layers, args.train_samples,
                     args.num_chains, args.activation, args.num_mix,
                     args.perturb)

        print('setting up and running experiment')
        exp = Experiment(log_dir=args.logdir,
                         sampler=sampler,
                         params=vars(args),
                         dist=dist)

        exp.run()
# Example #2
def main(args):
    """Run a ``VS`` sampler experiment against a ``MixtureOfGaussians`` target.

    The target mixture takes its component means and standard deviations
    from ``args.means`` and ``args.stds``, with two fixed, equal mixing
    weights. A ``VS`` sampler is then built for the target's log-probability
    inside the ``'sampler'`` variable scope, and an ``Experiment`` that logs
    to ``args.logdir`` is created and run.

    Args:
        args: Parsed command-line namespace. The attributes read here are
            ``means``, ``stds``, ``data_dim``, ``aux_dim``,
            ``hidden_units``, ``num_layers``, ``train_samples``,
            ``num_chains``, ``activation``, ``num_mix``, ``perturb`` and
            ``logdir``. The whole namespace is also forwarded to
            ``Experiment`` as ``params``.
    """
    print('setting up distribution')
    # NOTE(review): the weights are hard-coded for exactly two components;
    # presumably args.means / args.stds must describe two components too.
    mixture_weights = [0.5, 0.5]
    target = MixtureOfGaussians(means=args.means, stds=args.stds,
                                pis=mixture_weights)

    print('setting up sampler')
    with tf.variable_scope('sampler', reuse=tf.AUTO_REUSE):
        log_prob = target.log_prob_func()
        # Remaining positional VS arguments, in the order VS expects them.
        vs_config = (args.aux_dim, args.hidden_units, args.num_layers,
                     args.train_samples, args.num_chains, args.activation,
                     args.num_mix, args.perturb)
        vs_sampler = VS(log_prob, args.data_dim, *vs_config)

        print('setting up and running experiment')
        experiment = Experiment(
            log_dir=args.logdir,
            sampler=vs_sampler,
            params=vars(args),
            dist=target,
        )
        experiment.run()