def main(relaxation=None,
         learn_prior=True,
         max_iters=None,
         batch_size=24,
         num_latents=200,
         model_type=None,
         lr=None,
         test_bias=False,
         train_dir=None,
         iwae_samples=100,
         dataset="mnist",
         logf=None,
         var_lr_scale=10.,
         Q_wd=.0001,
         Q_depth=-1,
         checkpoint_path=None):
    valid_batch_size = 100

    if model_type == "L1":
        num_layers = 1
        layer_type = linear_layer
    elif model_type == "L2":
        num_layers = 2
        layer_type = linear_layer
    elif model_type == "NL1":
        num_layers = 1
        layer_type = nonlinear_layer
    else:
        assert False, "bad model type {}".format(model_type)

    sess = tf.Session(config=tf.ConfigProto(gpu_options=tf.GPUOptions(
        allow_growth=True)))
    if dataset == "mnist":
        X_tr, X_va, X_te = datasets.load_mnist()
    elif dataset == "omni":
        X_tr, X_va, X_te = datasets.load_omniglot()
    else:
        assert False

    num_train = X_tr.shape[0]
    num_valid = X_va.shape[0]
    num_test = X_te.shape[0]
    train_mean = np.mean(X_tr, axis=0, keepdims=True)
    train_output_bias = -np.log(1. / np.clip(train_mean, 0.001, 0.999) -
                                1.).astype(np.float32)

    x = tf.placeholder(tf.float32, [None, 784])
    # x_im = tf.reshape(x, [-1, 28, 28, 1])
    # tf.summary.image("x_true", x_im)

    # make prior for top b
    p_prior = tf.Variable(
        tf.zeros([num_latents], dtype=tf.float32),
        trainable=learn_prior,
        name='p_prior',
    )
    # create rebar specific variables temperature and eta
    log_temperatures = [create_log_temp(1) for l in range(num_layers)]
    temperatures = [tf.exp(log_temp) for log_temp in log_temperatures]
    batch_temperatures = [tf.reshape(temp, [1, -1]) for temp in temperatures]
    etas = [create_eta(1) for l in range(num_layers)]
    batch_etas = [tf.reshape(eta, [1, -1]) for eta in etas]

    # random uniform samples
    u = [
        tf.random_uniform([tf.shape(x)[0], num_latents], dtype=tf.float32)
        for l in range(num_layers)
    ]
    # create binary sampler
    b_sampler = BSampler(u, "b_sampler")
    gen_b_sampler = BSampler(u, "gen_b_sampler")
    # generate hard forward pass
    encoder_name = "encoder"
    decoder_name = "decoder"
    inf_la_b, samples_b = inference_network(x, train_mean, layer_type,
                                            num_layers, num_latents,
                                            encoder_name, False, b_sampler)
    gen_la_b = generator_network(samples_b, train_output_bias, layer_type,
                                 num_layers, num_latents, decoder_name, False)
    log_image(gen_la_b[-1], "x_pred")
    # produce samples
    _samples_la_b = generator_network(None,
                                      train_output_bias,
                                      layer_type,
                                      num_layers,
                                      num_latents,
                                      decoder_name,
                                      True,
                                      sampler=gen_b_sampler,
                                      prior=p_prior)
    log_image(_samples_la_b[-1], "x_sample")

    # hard loss evaluation and log probs
    f_b, log_q_bs = neg_elbo(x,
                             samples_b,
                             inf_la_b,
                             gen_la_b,
                             p_prior,
                             log=True)
    batch_f_b = tf.expand_dims(f_b, 1)
    total_loss = tf.reduce_mean(f_b)
    # tf.summary.scalar("fb", total_loss)
    # optimizer for model parameters
    model_opt = tf.train.AdamOptimizer(lr, beta2=.99999)
    # optimizer for variance reducing parameters
    variance_opt = tf.train.AdamOptimizer(var_lr_scale * lr, beta2=.99999)
    # get encoder and decoder variables
    encoder_params = get_variables(encoder_name)
    decoder_params = get_variables(decoder_name)
    if learn_prior:
        decoder_params.append(p_prior)
    # compute and store gradients of hard loss with respect to encoder_parameters
    encoder_loss_grads = {}
    for g, v in model_opt.compute_gradients(total_loss,
                                            var_list=encoder_params):
        encoder_loss_grads[v.name] = g
    # get gradients for decoder parameters
    decoder_gradvars = model_opt.compute_gradients(total_loss,
                                                   var_list=decoder_params)
    # will hold all gradvars for the model (non-variance adjusting variables)
    model_gradvars = [gv for gv in decoder_gradvars]

    # conditional samples
    v = [v_from_u(_u, log_alpha) for _u, log_alpha in zip(u, inf_la_b)]
    # need to create soft samplers
    sig_z_sampler = SIGZSampler(u, batch_temperatures, "sig_z_sampler")
    sig_zt_sampler = SIGZSampler(v, batch_temperatures, "sig_zt_sampler")

    z_sampler = ZSampler(u, "z_sampler")
    zt_sampler = ZSampler(v, "zt_sampler")

    rebars = []
    reinforces = []
    variance_objectives = []
    # have to produce 2 forward passes for each layer for z and zt samples
    for l in range(num_layers):
        cur_la_b = inf_la_b[l]

        # if standard rebar or additive relaxation
        if relaxation == "rebar" or relaxation == "add":
            # compute soft samples and soft passes through model and soft elbos
            cur_z_sample = sig_z_sampler.sample(cur_la_b, l)
            prev_samples_z = samples_b[:l] + [cur_z_sample]

            cur_zt_sample = sig_zt_sampler.sample(cur_la_b, l)
            prev_samples_zt = samples_b[:l] + [cur_zt_sample]

            prev_log_alphas = inf_la_b[:l] + [cur_la_b]

            # soft forward passes
            inf_la_z, samples_z = inference_network(x,
                                                    train_mean,
                                                    layer_type,
                                                    num_layers,
                                                    num_latents,
                                                    encoder_name,
                                                    True,
                                                    sig_z_sampler,
                                                    samples=prev_samples_z,
                                                    log_alphas=prev_log_alphas)
            gen_la_z = generator_network(samples_z, train_output_bias,
                                         layer_type, num_layers, num_latents,
                                         decoder_name, True)
            inf_la_zt, samples_zt = inference_network(
                x,
                train_mean,
                layer_type,
                num_layers,
                num_latents,
                encoder_name,
                True,
                sig_zt_sampler,
                samples=prev_samples_zt,
                log_alphas=prev_log_alphas)
            gen_la_zt = generator_network(samples_zt, train_output_bias,
                                          layer_type, num_layers, num_latents,
                                          decoder_name, True)
            # soft loss evaluataions
            f_z, _ = neg_elbo(x, samples_z, inf_la_z, gen_la_z, p_prior)
            f_zt, _ = neg_elbo(x, samples_zt, inf_la_zt, gen_la_zt, p_prior)

        if relaxation == "add" or relaxation == "all":
            # sample z and zt
            prev_bs = samples_b[:l]
            cur_z_sample = z_sampler.sample(cur_la_b, l)
            cur_zt_sample = zt_sampler.sample(cur_la_b, l)

            q_z = Q_func(x,
                         train_mean,
                         cur_z_sample,
                         prev_bs,
                         Q_name(l),
                         False,
                         depth=Q_depth)
            q_zt = Q_func(x,
                          train_mean,
                          cur_zt_sample,
                          prev_bs,
                          Q_name(l),
                          True,
                          depth=Q_depth)
            # tf.summary.scalar("q_z_{}".format(l), tf.reduce_mean(q_z))
            # tf.summary.scalar("q_zt_{}".format(l), tf.reduce_mean(q_zt))
            if relaxation == "add":
                f_z = f_z + q_z
                f_zt = f_zt + q_zt
            elif relaxation == "all":
                f_z = q_z
                f_zt = q_zt
            else:
                assert False
        # tf.summary.scalar("f_z_{}".format(l), tf.reduce_mean(f_z))
        # tf.summary.scalar("f_zt_{}".format(l), tf.reduce_mean(f_zt))
        cur_samples_b = samples_b[l]
        # get gradient of sample log-likelihood wrt current parameter
        d_log_q_d_la = bernoulli_loglikelihood_derivitive(
            cur_samples_b, cur_la_b)
        # get gradient of soft-losses wrt current parameter
        d_f_z_d_la = tf.gradients(f_z, cur_la_b)[0]
        d_f_zt_d_la = tf.gradients(f_zt, cur_la_b)[0]
        batch_f_zt = tf.expand_dims(f_zt, 1)
        eta = batch_etas[l]
        # compute rebar and reinforce
        # tf.summary.histogram("der_diff_{}".format(l), d_f_z_d_la - d_f_zt_d_la)
        # tf.summary.histogram("d_log_q_d_la_{}".format(l), d_log_q_d_la)
        rebar = ((batch_f_b - eta * batch_f_zt) * d_log_q_d_la + eta *
                 (d_f_z_d_la - d_f_zt_d_la)) / batch_size
        reinforce = batch_f_b * d_log_q_d_la / batch_size
        rebars.append(rebar)
        reinforces.append(reinforce)
        # tf.summary.histogram("rebar_{}".format(l), rebar)
        # tf.summary.histogram("reinforce_{}".format(l), reinforce)
        # backpropogate rebar to individual layer parameters
        layer_params = get_variables(layer_name(l), arr=encoder_params)
        layer_rebar_grads = tf.gradients(cur_la_b, layer_params, grad_ys=rebar)
        # get direct loss grads for each parameter
        layer_loss_grads = [encoder_loss_grads[v.name] for v in layer_params]
        # each param's gradient should be rebar + the direct loss gradient
        layer_grads = [
            rg + lg for rg, lg in zip(layer_rebar_grads, layer_loss_grads)
        ]
        # for rg, lg, v in zip(layer_rebar_grads, layer_loss_grads, layer_params):
        #     tf.summary.histogram(v.name + "_grad_rebar", rg)
        #     tf.summary.histogram(v.name + "_grad_loss", lg)
        layer_gradvars = list(zip(layer_grads, layer_params))
        model_gradvars.extend(layer_gradvars)
        variance_objective = tf.reduce_mean(tf.square(rebar))
        variance_objectives.append(variance_objective)

    variance_objective = tf.add_n(variance_objectives)
    variance_vars = log_temperatures + etas
    if relaxation != "rebar":
        q_vars = get_variables("Q_")
        wd = tf.add_n([Q_wd * tf.nn.l2_loss(v) for v in q_vars])
        # tf.summary.scalar("Q_weight_decay", wd)
        # variance_vars = variance_vars + q_vars
    else:
        wd = 0.0
    variance_gradvars = variance_opt.compute_gradients(variance_objective + wd,
                                                       var_list=variance_vars)
    variance_train_op = variance_opt.apply_gradients(variance_gradvars)
    model_train_op = model_opt.apply_gradients(model_gradvars)
    with tf.control_dependencies([model_train_op, variance_train_op]):
        train_op = tf.no_op()

    # for g, v in model_gradvars + variance_gradvars:
    #     print(g, v.name)
    #     if g is not None:
    #         tf.summary.histogram(v.name, v)
    #         tf.summary.histogram(v.name + "_grad", g)

    val_loss = tf.Variable(1000,
                           trainable=False,
                           name="val_loss",
                           dtype=tf.float32)
    train_loss = tf.Variable(1000,
                             trainable=False,
                             name="train_loss",
                             dtype=tf.float32)
    # tf.summary.scalar("val_loss", val_loss)
    # tf.summary.scalar("train_loss", train_loss)
    # summ_op = tf.summary.merge_all()
    # summary_writer = tf.summary.FileWriter(train_dir)
    sess.run(tf.global_variables_initializer())

    # create savers
    train_saver = tf.train.Saver(tf.global_variables(), max_to_keep=1)
    val_saver = tf.train.Saver(tf.global_variables(), max_to_keep=1)
    iwae_elbo = -(tf.reduce_logsumexp(-f_b) - np.log(valid_batch_size))

    if checkpoint_path is None:
        iters_per_epoch = X_tr.shape[0] // batch_size
        print("Train set has {} examples".format(X_tr.shape[0]))
        if relaxation != "rebar":
            print("Pretraining Q network")
            for i in range(1000):
                if i % 100 == 0:
                    print(i)
                idx = np.random.randint(0, iters_per_epoch - 1)
                batch_xs = X_tr[idx * batch_size:(idx + 1) * batch_size]
                sess.run(variance_train_op, feed_dict={x: batch_xs})
        # t = time.time()
        best_val_loss = np.inf

        # results saving
        if relaxation == 'rebar':
            mode_out = relaxation
        else:
            mode_out = 'RELAX' + relaxation
        result_dir = './Results_MNIST_SBN'
        if not os.path.isdir(result_dir):
            os.mkdir(result_dir)
        shutil.copyfile(
            sys.argv[0], result_dir + '/training_script_' + dataset + '_' +
            mode_out + '_' + model_type + '.py')
        pathsave = result_dir + '/TF_SBN_' + dataset + '_' + mode_out + '_MB[%d]_' % batch_size + model_type + '_LR[%.2e].mat' % lr

        tr_loss_mb_set = []
        tr_timerun_mb_set = []
        tr_iter_mb_set = []

        tr_loss_set = []
        tr_timerun_set = []
        tr_iter_set = []

        val_loss_set = []
        val_timerun_set = []
        val_iter_set = []

        te_loss_set = []
        te_timerun_set = []
        te_iter_set = []

        for epoch in range(10000000):
            # train_losses = []
            for i in range(iters_per_epoch):
                cur_iter = epoch * iters_per_epoch + i

                if cur_iter == 0:
                    time_start = time.clock()

                if cur_iter > max_iters:
                    print("Training Completed")
                    return

                batch_xs = X_tr[i * batch_size:(i + 1) * batch_size]
                loss, _ = sess.run([total_loss, train_op],
                                   feed_dict={x: batch_xs})

                time_run = time.clock() - time_start

                tr_loss_mb_set.append(loss)
                tr_timerun_mb_set.append(time_run)
                tr_iter_mb_set.append(cur_iter + 1)

                if (cur_iter + 1) % 100 == 0:
                    print(
                        'Step: [{:6d}], Loss_mb: [{:10.4f}], time_run: [{:10.4f}]'
                        .format(cur_iter + 1, loss, time_run))

                TestInterval = 5000
                Train_num_mbs = num_train // batch_size
                Valid_num_mbs = num_valid // batch_size
                Test_num_mbs = num_test // batch_size

                # Testing
                if (cur_iter + 1) % TestInterval == 0:

                    # Training
                    loss_train1 = 0
                    for step_train in range(Train_num_mbs):
                        x_train = X_tr[step_train *
                                       batch_size:(step_train + 1) *
                                       batch_size]

                        feed_dict_train = {x: x_train}
                        loss_train_mb1 = sess.run(total_loss,
                                                  feed_dict=feed_dict_train)
                        loss_train1 += loss_train_mb1 * batch_size

                    loss_train1 = loss_train1 / (Train_num_mbs * batch_size)

                    tr_loss_set.append(loss_train1)
                    tr_timerun_set.append(time_run)
                    tr_iter_set.append(cur_iter + 1)

                    # Validation
                    loss_val1 = 0
                    for step_val in range(Valid_num_mbs):
                        x_valid = X_va[step_val * batch_size:(step_val + 1) *
                                       batch_size]

                        feed_dict_val = {x: x_valid}
                        loss_val_mb1 = sess.run(total_loss,
                                                feed_dict=feed_dict_val)
                        loss_val1 += loss_val_mb1 * batch_size

                    loss_val1 = loss_val1 / (Valid_num_mbs * batch_size)

                    val_loss_set.append(loss_val1)
                    val_timerun_set.append(time_run)
                    val_iter_set.append(cur_iter + 1)

                    # Test
                    loss_test1 = 0
                    for step_test in range(Test_num_mbs):
                        x_test = X_te[step_test * batch_size:(step_test + 1) *
                                      batch_size]

                        feed_dict_test = {x: x_test}
                        loss_test_mb1 = sess.run(total_loss,
                                                 feed_dict=feed_dict_test)
                        loss_test1 += loss_test_mb1 * batch_size

                    loss_test1 = loss_test1 / (Test_num_mbs * batch_size)

                    te_loss_set.append(loss_test1)
                    te_timerun_set.append(time_run)
                    te_iter_set.append(cur_iter + 1)

                    print(
                        '============TestInterval: [{:6d}], Loss_train: [{:10.4f}], Loss_val: [{:10.4f}], Loss_test: [{:10.4f}]'
                        .format(TestInterval, loss_train1, loss_val1,
                                loss_test1))

                # Saving
                if (cur_iter + 1) % TestInterval == 0:
                    sio.savemat(
                        pathsave, {
                            'tr_loss_mb_set': tr_loss_mb_set,
                            'tr_timerun_mb_set': tr_timerun_mb_set,
                            'tr_iter_mb_set': tr_iter_mb_set,
                            'tr_loss_set': tr_loss_set,
                            'tr_timerun_set': tr_timerun_set,
                            'tr_iter_set': tr_iter_set,
                            'val_loss_set': val_loss_set,
                            'val_timerun_set': val_timerun_set,
                            'val_iter_set': val_iter_set,
                            'te_loss_set': te_loss_set,
                            'te_timerun_set': te_timerun_set,
                            'te_iter_set': te_iter_set,
                        })
예제 #2
0
    def __init__(self, conf):
        self.conf = conf

        # determine and create result dir
        i = 1
        log_path = conf.result_path + 'run0'
        while os.path.exists(log_path):
            log_path = '{}run{}'.format(conf.result_path, i)
            i += 1
        os.makedirs(log_path)
        self.log_path = log_path

        if not os.path.exists(conf.checkpoint_dir):
            os.makedirs(conf.checkpoint_dir)

        self.checkpoint_file = os.path.join(self.conf.checkpoint_dir,
                                            "model.ckpt")
        input_shape = [
            conf.batch_size, conf.scene_width, conf.scene_height, conf.channels
        ]
        # build model
        with tf.device(conf.device):
            self.mdl = model.Supair(conf)
            self.in_ph = tf.placeholder(tf.float32, input_shape)
            self.elbo = self.mdl.elbo(self.in_ph)

            self.mdl.num_parameters()

            self.optimizer = tf.train.AdamOptimizer()
            self.train_op = self.optimizer.minimize(-1 * self.elbo)

        self.sess = tf.Session()

        self.saver = tf.train.Saver()
        if self.conf.load_params:
            self.saver.restore(self.sess, self.checkpoint_file)
        else:
            self.sess.run(tf.global_variables_initializer())
            self.sess.run(tf.local_variables_initializer())

        # load data
        bboxes = None
        if conf.dataset == 'MNIST':
            (x, counts, y,
             bboxes), (x_test, c_test, _,
                       _) = datasets.load_mnist(conf.scene_width,
                                                max_digits=2,
                                                path=conf.data_path)
            visualize.store_images(x[0:10], log_path + '/img_raw')
            if conf.noise:
                x = datasets.add_noise(x)
                x_test = datasets.add_noise(x_test)
                visualize.store_images(x[0:10], log_path + '/img_noisy')
            if conf.structured_noise:
                x = datasets.add_structured_noise(x)
                x_test = datasets.add_structured_noise(x_test)
                visualize.store_images(x[0:10], log_path + '/img_struc_noisy')
            x_color = np.squeeze(x)
        elif conf.dataset == 'sprites':
            (x_color, counts,
             _), (x_test, c_test,
                  _) = datasets.make_sprites(50000, path=conf.data_path)
            if conf.noise:
                x_color = datasets.add_noise(x_color)
            x = visualize.rgb2gray(x_color)
            x = np.clip(x, 0.0, 1.0)
            x_test = visualize.rgb2gray(x_test)
            x_test = np.clip(x_test, 0.0, 1.0)
            if conf.noise:
                x = datasets.add_noise(x)
                x_test = datasets.add_noise(x_test)
                x_color = datasets.add_noise(x_color)
        elif conf.dataset == 'omniglot':
            x = 1 - datasets.load_omniglot(path=conf.data_path)
            counts = np.ones(x.shape[0], dtype=np.int32)
            x_color = np.squeeze(x)
        elif conf.dataset == 'svhn':
            x, counts, objects, bgs = datasets.load_svhn(path=conf.data_path)
            self.pretrain(x, objects, bgs)
            x_color = np.squeeze(x)
        else:
            raise ValueError('unknown dataset', conf.dataset)

        self.x, self.x_color, self.counts = x, x_color, counts
        self.x_test, self.c_test = x_test, c_test
        self.bboxes = bboxes

        print('Built model')
        self.obj_reconstructor = SpnReconstructor(self.mdl.obj_spn)
        self.bg_reconstructor = SpnReconstructor(self.mdl.bg_spn)

        tfgraph = tf.get_default_graph()
        self.tensors_of_interest = {
            'z_where': tfgraph.get_tensor_by_name('z_where:0'),
            'z_pres': tfgraph.get_tensor_by_name('z_pres:0'),
            'bg_score': tfgraph.get_tensor_by_name('bg_score:0'),
            'y': tfgraph.get_tensor_by_name('y:0'),
            'obj_vis': tfgraph.get_tensor_by_name('obj_vis:0'),
            'bg_maps': tfgraph.get_tensor_by_name('bg_maps:0')
        }
예제 #3
0
LR = 1e-4
MBsize = 24
dim_var = [784, 200]
TestInterval = 5000
max_iters = 1000000

NonLinerNN = True
PreProcess = True
dataset = "mnist"
# dataset = "omni"

if dataset == "mnist":
    X_tr, X_va, X_te = datasets.load_mnist()
elif dataset == "omni":
    X_tr, X_va, X_te = datasets.load_omniglot()
else:
    assert False

num_train = X_tr.shape[0]
num_valid = X_va.shape[0]
num_test = X_te.shape[0]
train_mean = np.mean(X_tr, axis=0, keepdims=True)

tf.reset_default_graph()


def GOBernoulli(Prob):
    zsamp = tf.cast(tf.less_equal(tf.random_uniform(Prob.shape), Prob),
                    tf.float32)
    zout = Prob + tf.stop_gradient(zsamp - Prob)
예제 #4
0
def main(relaxation=None,
         learn_prior=True,
         max_iters=None,
         batch_size=24,
         num_latents=200,
         model_type=None,
         lr=None,
         test_bias=False,
         train_dir=None,
         iwae_samples=100,
         dataset="mnist",
         logf=None,
         var_lr_scale=10.,
         Q_wd=.0001,
         Q_depth=-1,
         checkpoint_path=None):

    valid_batch_size = 100

    if model_type == "L1":
        num_layers = 1
        layer_type = linear_layer
    elif model_type == "L2":
        num_layers = 2
        layer_type = linear_layer
    elif model_type == "NL1":
        num_layers = 1
        layer_type = nonlinear_layer
    else:
        assert False, "bad model type {}".format(model_type)

    sess = tf.Session()
    if dataset == "mnist":
        X_tr, X_va, X_te = datasets.load_mnist()
    elif dataset == "omni":
        X_tr, X_va, X_te = datasets.load_omniglot()
    else:
        assert False
    train_mean = np.mean(X_tr, axis=0, keepdims=True)
    train_output_bias = -np.log(1. / np.clip(train_mean, 0.001, 0.999) -
                                1.).astype(np.float32)

    x = tf.placeholder(tf.float32, [None, 784])
    x_im = tf.reshape(x, [-1, 28, 28, 1])
    tf.summary.image("x_true", x_im)

    # make prior for top b
    p_prior = tf.Variable(
        tf.zeros([num_latents], dtype=tf.float32),
        trainable=learn_prior,
        name='p_prior',
    )
    # create rebar specific variables temperature and eta
    log_temperatures = [create_log_temp(1) for l in range(num_layers)]
    temperatures = [tf.exp(log_temp) for log_temp in log_temperatures]
    batch_temperatures = [tf.reshape(temp, [1, -1]) for temp in temperatures]
    etas = [create_eta(1) for l in range(num_layers)]
    batch_etas = [tf.reshape(eta, [1, -1]) for eta in etas]

    # random uniform samples
    u = [
        tf.random_uniform([tf.shape(x)[0], num_latents], dtype=tf.float32)
        for l in range(num_layers)
    ]
    # create binary sampler
    b_sampler = BSampler(u, "b_sampler")
    gen_b_sampler = BSampler(u, "gen_b_sampler")
    # generate hard forward pass
    encoder_name = "encoder"
    decoder_name = "decoder"
    inf_la_b, samples_b = inference_network(x, train_mean, layer_type,
                                            num_layers, num_latents,
                                            encoder_name, False, b_sampler)
    gen_la_b = generator_network(samples_b, train_output_bias, layer_type,
                                 num_layers, num_latents, decoder_name, False)
    log_image(gen_la_b[-1], "x_pred")
    # produce samples
    _samples_la_b = generator_network(None,
                                      train_output_bias,
                                      layer_type,
                                      num_layers,
                                      num_latents,
                                      decoder_name,
                                      True,
                                      sampler=gen_b_sampler,
                                      prior=p_prior)
    log_image(_samples_la_b[-1], "x_sample")

    # hard loss evaluation and log probs
    f_b, log_q_bs = neg_elbo(x,
                             samples_b,
                             inf_la_b,
                             gen_la_b,
                             p_prior,
                             log=True)
    batch_f_b = tf.expand_dims(f_b, 1)
    total_loss = tf.reduce_mean(f_b)
    tf.summary.scalar("fb", total_loss)
    # optimizer for model parameters
    model_opt = tf.train.AdamOptimizer(lr, beta2=.99999)
    # optimizer for variance reducing parameters
    variance_opt = tf.train.AdamOptimizer(var_lr_scale * lr, beta2=.99999)
    # get encoder and decoder variables
    encoder_params = get_variables(encoder_name)
    decoder_params = get_variables(decoder_name)
    if learn_prior:
        decoder_params.append(p_prior)
    # compute and store gradients of hard loss with respect to encoder_parameters
    encoder_loss_grads = {}
    for g, v in model_opt.compute_gradients(total_loss,
                                            var_list=encoder_params):
        encoder_loss_grads[v.name] = g
    # get gradients for decoder parameters
    decoder_gradvars = model_opt.compute_gradients(total_loss,
                                                   var_list=decoder_params)
    # will hold all gradvars for the model (non-variance adjusting variables)
    model_gradvars = [gv for gv in decoder_gradvars]

    # conditional samples
    v = [v_from_u(_u, log_alpha) for _u, log_alpha in zip(u, inf_la_b)]
    # need to create soft samplers
    sig_z_sampler = SIGZSampler(u, batch_temperatures, "sig_z_sampler")
    sig_zt_sampler = SIGZSampler(v, batch_temperatures, "sig_zt_sampler")

    z_sampler = ZSampler(u, "z_sampler")
    zt_sampler = ZSampler(v, "zt_sampler")

    rebars = []
    reinforces = []
    variance_objectives = []
    # have to produce 2 forward passes for each layer for z and zt samples
    for l in range(num_layers):
        cur_la_b = inf_la_b[l]

        # if standard rebar or additive relaxation
        if relaxation == "rebar" or relaxation == "add":
            # compute soft samples and soft passes through model and soft elbos
            cur_z_sample = sig_z_sampler.sample(cur_la_b, l)
            prev_samples_z = samples_b[:l] + [cur_z_sample]

            cur_zt_sample = sig_zt_sampler.sample(cur_la_b, l)
            prev_samples_zt = samples_b[:l] + [cur_zt_sample]

            prev_log_alphas = inf_la_b[:l] + [cur_la_b]

            # soft forward passes
            inf_la_z, samples_z = inference_network(x,
                                                    train_mean,
                                                    layer_type,
                                                    num_layers,
                                                    num_latents,
                                                    encoder_name,
                                                    True,
                                                    sig_z_sampler,
                                                    samples=prev_samples_z,
                                                    log_alphas=prev_log_alphas)
            gen_la_z = generator_network(samples_z, train_output_bias,
                                         layer_type, num_layers, num_latents,
                                         decoder_name, True)
            inf_la_zt, samples_zt = inference_network(
                x,
                train_mean,
                layer_type,
                num_layers,
                num_latents,
                encoder_name,
                True,
                sig_zt_sampler,
                samples=prev_samples_zt,
                log_alphas=prev_log_alphas)
            gen_la_zt = generator_network(samples_zt, train_output_bias,
                                          layer_type, num_layers, num_latents,
                                          decoder_name, True)
            # soft loss evaluataions
            f_z, _ = neg_elbo(x, samples_z, inf_la_z, gen_la_z, p_prior)
            f_zt, _ = neg_elbo(x, samples_zt, inf_la_zt, gen_la_zt, p_prior)

        if relaxation == "add" or relaxation == "all":
            # sample z and zt
            prev_bs = samples_b[:l]
            cur_z_sample = z_sampler.sample(cur_la_b, l)
            cur_zt_sample = zt_sampler.sample(cur_la_b, l)

            q_z = Q_func(x,
                         train_mean,
                         cur_z_sample,
                         prev_bs,
                         Q_name(l),
                         False,
                         depth=Q_depth)
            q_zt = Q_func(x,
                          train_mean,
                          cur_zt_sample,
                          prev_bs,
                          Q_name(l),
                          True,
                          depth=Q_depth)
            tf.summary.scalar("q_z_{}".format(l), tf.reduce_mean(q_z))
            tf.summary.scalar("q_zt_{}".format(l), tf.reduce_mean(q_zt))
            if relaxation == "add":
                f_z = f_z + q_z
                f_zt = f_zt + q_zt
            elif relaxation == "all":
                f_z = q_z
                f_zt = q_zt
            else:
                assert False
        tf.summary.scalar("f_z_{}".format(l), tf.reduce_mean(f_z))
        tf.summary.scalar("f_zt_{}".format(l), tf.reduce_mean(f_zt))
        cur_samples_b = samples_b[l]
        # get gradient of sample log-likelihood wrt current parameter
        d_log_q_d_la = bernoulli_loglikelihood_derivitive(
            cur_samples_b, cur_la_b)
        # get gradient of soft-losses wrt current parameter
        d_f_z_d_la = tf.gradients(f_z, cur_la_b)[0]
        d_f_zt_d_la = tf.gradients(f_zt, cur_la_b)[0]
        batch_f_zt = tf.expand_dims(f_zt, 1)
        eta = batch_etas[l]
        # compute rebar and reinforce
        tf.summary.histogram("der_diff_{}".format(l), d_f_z_d_la - d_f_zt_d_la)
        tf.summary.histogram("d_log_q_d_la_{}".format(l), d_log_q_d_la)
        rebar = ((batch_f_b - eta * batch_f_zt) * d_log_q_d_la + eta *
                 (d_f_z_d_la - d_f_zt_d_la)) / batch_size
        reinforce = batch_f_b * d_log_q_d_la / batch_size
        rebars.append(rebar)
        reinforces.append(reinforce)
        tf.summary.histogram("rebar_{}".format(l), rebar)
        tf.summary.histogram("reinforce_{}".format(l), reinforce)
        # backpropogate rebar to individual layer parameters
        layer_params = get_variables(layer_name(l), arr=encoder_params)
        layer_rebar_grads = tf.gradients(cur_la_b, layer_params, grad_ys=rebar)
        # get direct loss grads for each parameter
        layer_loss_grads = [encoder_loss_grads[v.name] for v in layer_params]
        # each param's gradient should be rebar + the direct loss gradient
        layer_grads = [
            rg + lg for rg, lg in zip(layer_rebar_grads, layer_loss_grads)
        ]
        for rg, lg, v in zip(layer_rebar_grads, layer_loss_grads,
                             layer_params):
            tf.summary.histogram(v.name + "_grad_rebar", rg)
            tf.summary.histogram(v.name + "_grad_loss", lg)
        layer_gradvars = list(zip(layer_grads, layer_params))
        model_gradvars.extend(layer_gradvars)
        variance_objective = tf.reduce_mean(tf.square(rebar))
        variance_objectives.append(variance_objective)

    variance_objective = tf.add_n(variance_objectives)
    variance_vars = log_temperatures + etas
    if relaxation != "rebar":
        q_vars = get_variables("Q_")
        wd = tf.add_n([Q_wd * tf.nn.l2_loss(v) for v in q_vars])
        tf.summary.scalar("Q_weight_decay", wd)
        variance_vars = variance_vars + q_vars
    else:
        wd = 0.0
    variance_gradvars = variance_opt.compute_gradients(variance_objective + wd,
                                                       var_list=variance_vars)
    variance_train_op = variance_opt.apply_gradients(variance_gradvars)
    model_train_op = model_opt.apply_gradients(model_gradvars)
    with tf.control_dependencies([model_train_op, variance_train_op]):
        train_op = tf.no_op()

    for g, v in model_gradvars + variance_gradvars:
        print(g, v.name)
        if g is not None:
            tf.summary.histogram(v.name, v)
            tf.summary.histogram(v.name + "_grad", g)

    val_loss = tf.Variable(1000,
                           trainable=False,
                           name="val_loss",
                           dtype=tf.float32)
    train_loss = tf.Variable(1000,
                             trainable=False,
                             name="train_loss",
                             dtype=tf.float32)
    tf.summary.scalar("val_loss", val_loss)
    tf.summary.scalar("train_loss", train_loss)
    summ_op = tf.summary.merge_all()
    summary_writer = tf.summary.FileWriter(train_dir)
    sess.run(tf.global_variables_initializer())

    # create savers
    train_saver = tf.train.Saver(tf.global_variables(), max_to_keep=1)
    val_saver = tf.train.Saver(tf.global_variables(), max_to_keep=1)
    iwae_elbo = -(tf.reduce_logsumexp(-f_b) - np.log(valid_batch_size))

    if checkpoint_path is None:
        iters_per_epoch = X_tr.shape[0] // batch_size
        print("Train set has {} examples".format(X_tr.shape[0]))
        if relaxation != "rebar":
            print("Pretraining Q network")
            for i in range(1000):
                if i % 100 == 0:
                    print(i)
                idx = np.random.randint(0, iters_per_epoch - 1)
                batch_xs = X_tr[idx * batch_size:(idx + 1) * batch_size]
                sess.run(variance_train_op, feed_dict={x: batch_xs})
        t = time.time()
        best_val_loss = np.inf
        for epoch in range(10000000):
            train_losses = []
            for i in range(iters_per_epoch):
                cur_iter = epoch * iters_per_epoch + i
                if cur_iter > max_iters:
                    print("Training Completed")
                    return
                batch_xs = X_tr[i * batch_size:(i + 1) * batch_size]
                if i % 1000 == 0:
                    loss, _, = sess.run([total_loss, train_op],
                                        feed_dict={x: batch_xs})
                    #summary_writer.add_summary(sum_str, cur_iter)
                    time_taken = time.time() - t
                    t = time.time()
                    #print(cur_iter, loss, "{} / batch".format(time_taken / 1000))
                    if test_bias:
                        rebs = []
                        refs = []
                        for _i in range(100000):
                            if _i % 1000 == 0:
                                print(_i)
                            rb, re = sess.run([rebars[3], reinforces[3]],
                                              feed_dict={x: batch_xs})
                            rebs.append(rb[:5])
                            refs.append(re[:5])
                        rebs = np.array(rebs)
                        refs = np.array(refs)
                        re_var = np.log(refs.var(axis=0))
                        rb_var = np.log(rebs.var(axis=0))
                        print("rebar variance     = {}".format(rb_var))
                        print("reinforce variance = {}".format(re_var))
                        print("rebar     = {}".format(rebs.mean(axis=0)))
                        print("reinforce = {}\n".format(refs.mean(axis=0)))
                else:
                    loss, _ = sess.run([total_loss, train_op],
                                       feed_dict={x: batch_xs})

                train_losses.append(loss)

            # epoch over, run test data
            iwaes = []
            for x_va in X_va:
                x_va_batch = np.array([x_va for i in range(valid_batch_size)])
                iwae = sess.run(iwae_elbo, feed_dict={x: x_va_batch})
                iwaes.append(iwae)
            trl = np.mean(train_losses)
            val = np.mean(iwaes)
            print("({}) Epoch = {}, Val loss = {}, Train loss = {}".format(
                train_dir, epoch, val, trl))
            logf.write("{}: {} {}\n".format(epoch, val, trl))
            sess.run([val_loss.assign(val), train_loss.assign(trl)])
            if val < best_val_loss:
                print("saving best model")
                best_val_loss = val
                val_saver.save(sess,
                               '{}/best-model'.format(train_dir),
                               global_step=epoch)
            np.random.shuffle(X_tr)
            if epoch % 10 == 0:
                train_saver.save(sess,
                                 '{}/model'.format(train_dir),
                                 global_step=epoch)

    # run iwae elbo on test set
    else:
        val_saver.restore(sess, checkpoint_path)
        iwae_elbo = -(tf.reduce_logsumexp(-f_b) - np.log(valid_batch_size))
        iwaes = []
        elbos = []
        for x_te in X_te:
            x_te_batch = np.array([x_te for i in range(100)])
            iwae, elbo = sess.run([iwae_elbo, f_b], feed_dict={x: x_te_batch})
            iwaes.append(iwae)
            elbos.append(elbo)
        print("MEAN IWAE: {}".format(np.mean(iwaes)))
        print("MEAN ELBO: {}".format(np.mean(elbos)))