Example #1
File: rbm.py Project: xmax1/dvae
    def estimate_log_z(self, sess):
        """
        Estimate log_z using AIS implemented in QuPA
        Args:
            sess: tensorflow session
            
        Returns: 
            log_z: the estimate of log partition function 
        """
        log_z = sess.run(self.log_z_update)
        Print('Estimated log partition function with QuPA: %0.4f' % log_z)

        return log_z
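For orientation, a minimal usage sketch; the `rbm` object and the surrounding graph are assumed to have been built already, and this is not code from the project itself:

    import tensorflow as tf

    # Build the RBM and the rest of the graph first (assumed), then estimate log Z inside a session.
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        log_z = rbm.estimate_log_z(sess)  # runs the log_z_update op and returns the AIS estimate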
Example #2
File: rbm.py Project: xmax1/dvae
    def estimate_log_z(self, sess):
        """
        Estimate log Z using AIS implemented in QuPA. Before log Z estimation, it makes sure that _lambda
        is still greater than the smallest eigenvalue.
        Args:
            sess: tensorflow session

        Returns: 
            log_z: the estimate of log partition function 
        """
        _lambda, eigen_values = sess.run([self._lambda, self.eigen_value_up])
        Print('min/max eigen value = %0.2f/%0.2f, lambda = %0.2f' %
              (np.amin(eigen_values), np.amax(eigen_values), _lambda))
        log_z = RBM.estimate_log_z(self, sess)
        return log_z
Example #3
    def estimate_log_z(self, sess):
        """
        Estimate log_z using AIS implemented in QuPA
        Args:
            sess: tensorflow session
            
        Returns: 
            log_z: the estimate of log partition function 
        """
        import time
        s = time.time()
        log_z = sess.run(self.log_z_update)
        total_time = time.time() - s
        Print(
            'Estimated log partition function with QuPA: %0.4f in %0.2f sec' %
            (log_z, total_time))

        return log_z
Example #4
def run_training(vae, cont_train, config_train, log_dir):
    """ The main function that will derive training of a vae.
    Args:
        vae: is an object from the class VAE. 
        cont_train: a boolean flag indicating whether train should continue from the checkpoint stored in the log_dir.
        config_train: a dictionary containing config. training (hyperparameters).
        log_dir: path to a directory that will used for storing both tensorboard files and checkpoints.

    Returns:
        test_neg_ll_value: the value of test log-likelihood.
    """
    use_iw = config_train['use_iw']
    Print('Starting training.')
    batch_size = config_train['batch_size']
    # Get the train, val, test sets of MNIST.
    data_dir = config_train['data_dir']
    eval_batch_size = config_train['eval_batch_size']
    data_sets = input_data.read_data_set(data_dir,
                                         dataset=config_train['dataset'])

    # placeholder for input.
    input_placeholder = tf.placeholder(tf.float32, shape=(None, vae.num_input))
    # define training graph.
    if use_iw:
        Print('using IW obj. function')
        iw_loss, neg_elbo, sigmoid_output, wd_loss, _ = \
            vae.neg_elbo(input_placeholder, is_training=True, k=config_train['k'], use_iw=use_iw)
        loss = iw_loss + wd_loss
        # create scalar summary for training loss.
        tf.summary.scalar('train/neg_iw_loss', iw_loss)
        sigmoid_output = tf.slice(sigmoid_output, [0, 0], [batch_size, -1])
    else:
        Print('using VAE obj. function')
        _, neg_elbo, sigmoid_output, wd_loss, _ = \
            vae.neg_elbo(input_placeholder, is_training=True, k=config_train['k'], use_iw=use_iw)
        loss = neg_elbo + wd_loss
        # create scalar summary for training loss.
        tf.summary.scalar('train/neg_elbo', neg_elbo)

    train_op = vae.training(loss)

    # create images for reconstruction.
    image = create_reconstruction_image(input_placeholder,
                                        sigmoid_output[:batch_size],
                                        batch_size)
    tf.summary.image('recon', image, max_outputs=1)

    # define graph to generate random samples from model.
    num_samples = 100
    random_samples = vae.generate_samples(num_samples)
    tiled_samples = tile_image_tf(random_samples,
                                  n=int(np.sqrt(num_samples)),
                                  m=int(np.sqrt(num_samples)),
                                  width=28,
                                  height=28)
    tf.summary.image('generated_sample', tiled_samples, max_outputs=1)

    # merge all summary for training graph
    train_summary_op = tf.summary.merge_all()

    # define a parallel graph for evaluation. Enable parameter sharing by setting is_training to False.
    _, neg_elbo_eval, _, _, log_iw_eval = vae.neg_elbo(input_placeholder,
                                                       is_training=False)

    # the following will create summaries that will be used in the evaluation graph.
    val_neg_elbo, test_neg_elbo = tf.placeholder(
        tf.float32, shape=()), tf.placeholder(tf.float32, shape=())
    val_neg_ll, test_neg_ll = tf.placeholder(
        tf.float32, shape=()), tf.placeholder(tf.float32, shape=())
    val_summary = tf.summary.scalar('val/neg_elbo', val_neg_elbo)
    test_summary = tf.summary.scalar('test/neg_elbo', test_neg_elbo)
    val_ll_summary = tf.summary.scalar('val/neg_ll', val_neg_ll)
    test_ll_summary = tf.summary.scalar('test/neg_ll', test_neg_ll)
    eval_summary_op = tf.summary.merge(
        [val_summary, test_summary, val_ll_summary, test_ll_summary])

    # start checkpoint saver.
    saver = tf.train.Saver(max_to_keep=1)
    sess = tf.Session()

    # Run the Op to initialize the variables.
    if cont_train:
        ckpt = tf.train.get_checkpoint_state(log_dir)
        if ckpt and ckpt.model_checkpoint_path:
            saver.restore(sess, ckpt.model_checkpoint_path)
            init_step = int(ckpt.model_checkpoint_path.split('-')[-1]) + 1
            Print('Initializing model from %s from step %d' %
                  (log_dir, init_step))
        else:
            raise IOError('No checkpoint was found at %s' % log_dir)
    else:
        init = tf.global_variables_initializer()
        sess.run(init)
        init_step = 0

    # Instantiate a SummaryWriter to output summaries and the Graph.
    # Create train/validation/test summary directories
    summary_writer = tf.summary.FileWriter(log_dir)

    # And then after everything is built, start the training loop.
    duration = 0.0
    best_val_neg_ll = np.finfo(float).max
    num_iter = config_train['num_iter']
    for step in xrange(init_step, num_iter):
        start_time = time.time()
        # perform one iteration of training.
        feed_dict = fill_feed_dict(data_sets.train, input_placeholder,
                                   batch_size)
        _, neg_elbo_value = sess.run([train_op, neg_elbo], feed_dict=feed_dict)
        duration += time.time() - start_time

        # Save a checkpoint and evaluate the model periodically.
        eval_iter = 20000 if num_iter > 1e5 else 10000
        if (step + 1) % eval_iter == 0 or (step + 1) == num_iter:
            # if vae has rbm in its prior we should update its log Z.
            if vae.should_compute_log_z():
                vae.prior.estimate_log_z(sess)

            # validate on the validation and test set
            val_neg_elbo_value, val_neg_ll_value = evaluate(
                sess,
                neg_elbo_eval,
                log_iw_eval,
                input_placeholder,
                data_sets.validation,
                batch_size=eval_batch_size,
                k_iw=100)
            test_neg_elbo_value, test_neg_ll_value = evaluate(
                sess,
                neg_elbo_eval,
                log_iw_eval,
                input_placeholder,
                data_sets.test,
                batch_size=eval_batch_size,
                k_iw=100)
            summary_str = sess.run(eval_summary_op,
                                   feed_dict={
                                       val_neg_elbo: val_neg_elbo_value,
                                       test_neg_elbo: test_neg_elbo_value,
                                       val_neg_ll: val_neg_ll_value,
                                       test_neg_ll: test_neg_ll_value
                                   })
            summary_writer.add_summary(summary_str, step)

            Print(
                'Step %d: val ELBO = %.2f test ELBO = %.2f, val NLL = %.2f, test NLL = %.2f'
                % (step, val_neg_elbo_value, test_neg_elbo_value,
                   val_neg_ll_value, test_neg_ll_value))
            # save model if it is better on validation set:
            if val_neg_ll_value < best_val_neg_ll:
                best_val_neg_ll = val_neg_ll_value
                saver.save(sess, log_dir + '/', global_step=step)

        # Write the summaries and print an overview fairly often.
        report_iter = 1000
        if step % report_iter == 0 and step > 500:
            # print status to stdout.
            Print('Step %d, %.3f sec per step' %
                  (step, duration / report_iter))
            duration = 0.0
            # Update the events file.
            summary_str = sess.run(train_summary_op, feed_dict=feed_dict)
            summary_writer.add_summary(summary_str, step)

        # in the last iteration, we load the best model based on the validation performance, and evaluate it on test
        if (step + 1) == num_iter:
            Print('Final evaluation using the best saved model')
            # reload the best model; this helps when a model overfits.
            ckpt = tf.train.get_checkpoint_state(log_dir)
            saver.restore(sess, ckpt.model_checkpoint_path)
            Print('Done restoring the model at step: %d' %
                  sess.run(get_global_step_var()))
            if vae.should_compute_log_z():
                vae.prior.estimate_log_z(sess)

            val_neg_elbo_value, val_neg_ll_value = evaluate(
                sess,
                neg_elbo_eval,
                log_iw_eval,
                input_placeholder,
                data_sets.validation,
                eval_batch_size,
                k_iw=100)
            test_neg_elbo_value, test_neg_ll_value = evaluate(
                sess,
                neg_elbo_eval,
                log_iw_eval,
                input_placeholder,
                data_sets.test,
                eval_batch_size,
                k_iw=config_train['k_iw'])
            summary_str = sess.run(eval_summary_op,
                                   feed_dict={
                                       val_neg_elbo: val_neg_elbo_value,
                                       test_neg_elbo: test_neg_elbo_value,
                                       val_neg_ll: val_neg_ll_value,
                                       test_neg_ll: test_neg_ll_value
                                   })
            Print(
                'Step %d: val ELBO = %.2f test ELBO = %.2f, val NLL = %.2f, test NLL = %.2f'
                % (step, val_neg_elbo_value, test_neg_elbo_value,
                   val_neg_ll_value, test_neg_ll_value))
            summary_writer.add_summary(summary_str, step + 1)
            summary_writer.flush()

            sess.close()
            tf.reset_default_graph()
            return test_neg_ll_value
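The function reads several keys from config_train. A hedged sketch of a minimal configuration dictionary; the key names are the ones read above, while the values are illustrative placeholders rather than the project's defaults:

    # Illustrative config_train; values are placeholders, not the project's defaults.
    config_train = {
        'use_iw': False,          # use the importance-weighted objective instead of the ELBO
        'batch_size': 100,        # training batch size
        'eval_batch_size': 100,   # batch size used during evaluation
        'data_dir': '/tmp/data',  # passed to input_data.read_data_set
        'dataset': 'mnist',       # dataset name passed to read_data_set
        'k': 1,                   # number of samples per data point in the training objective
        'k_iw': 100,              # number of importance samples for the final evaluation
        'num_iter': 100000,       # total number of training iterations
    }
    test_neg_ll = run_training(vae, cont_train=False, config_train=config_train, log_dir='/tmp/log')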
Example #5
File: rbm.py Project: xmax1/dvae
    def __init__(self,
                 num_var1,
                 num_var2,
                 weight_decay,
                 name='RBM',
                 num_samples=100,
                 num_gibbs_iter=40,
                 kld_term=None,
                 use_qupa=False):
        """
        Initialize bias and weight parameters and create sampling operations (Gibbs or QuPA). This class implements
        the KL divergence computation for DVAE and DVAE++.

        Args:
            num_var1:               number of variables on the left side of the RBM
            num_var2:               number of variables on the right side of the RBM
            weight_decay:           regularization coefficient for the weight matrix
            name:                   name of the RBM
            num_samples:            number of RBM samples drawn in each iteration (used for computing the log Z gradient)
            num_gibbs_iter:         number of Gibbs steps for PCD, or MCMC sweeps for QuPA
            kld_term:               use 'dvae_spike_exp' for DVAE; 'dvaepp_exp' or 'dvaepp_power' for DVAE++;
                                    'guassian_integral' or 'marginal_type1' for DVAE#.
            use_qupa:               a boolean flag indicating whether QuPA will be used for sampling. Setting this
                                    to False will use PCD for sampling.
        """
        assert kld_term in {'dvae_spike_exp', 'dvaepp_power', 'dvaepp_exp', 'guassian_integral', 'marginal_type1'}, \
            'kld_term %s is not a recognized option.' % kld_term
        self.kld_term = kld_term
        self.num_var1 = num_var1
        self.num_var2 = num_var2
        self.num_var = num_var1 + num_var2
        self.weight_decay = weight_decay
        self.name = name

        # bias on the left side
        self.b1 = tf.Variable(tf.zeros(shape=[self.num_var1, 1],
                                       dtype=tf.float32),
                              name='bias1')
        # bias on the right side
        self.b2 = tf.Variable(tf.zeros(shape=[self.num_var2, 1],
                                       dtype=tf.float32),
                              name='bias2')
        # pairwise weight
        self.w = tf.Variable(tf.zeros(shape=[self.num_var1, self.num_var2],
                                      dtype=tf.float32),
                             name='pairwise')

        # sampling options
        self.num_samples = num_samples
        self.use_qupa = use_qupa

        # concat b
        b = tf.concat(values=[tf.squeeze(self.b1),
                              tf.squeeze(self.b2)],
                      axis=0)

        if not self.use_qupa:
            Print('Using PCD')
            # init pcd class implemented in QuPA
            self.sampler = PCD(left_size=self.num_var1,
                               right_size=self.num_var2,
                               num_samples=self.num_samples,
                               dtype=tf.float32)
        else:
            Print('Using QuPA')
            # init population annealing class in QuPA
            self.sampler = qupa.PopulationAnnealer(
                left_size=self.num_var1,
                right_size=self.num_var2,
                num_samples=self.num_samples,
                dtype=tf.float32)

        # This returns a scalar tensor with the gradient of log z. Don't trust its value.
        self.log_z_train = self.sampler.training_log_z(
            b, self.w, num_mcmc_sweeps=num_gibbs_iter)

        # This returns the internal log z variable in the QuPA sampler. We will use this variable in evaluation.
        self.log_z_value = self.sampler.log_z_var

        # always get the samples after updating train log z
        with tf.control_dependencies([self.log_z_train]):
            self.samples = self.sampler.samples()

        # Define inverse temperatures used for AIS. Increasing the # of betas improves the precision of log z estimates.
        betas = tf.linspace(tf.constant(0.), tf.constant(1.), num=1000)
        # Define log_z estimation for evaluation.
        eval_logz = qupa.ais.evaluation_log_z(b,
                                              self.w,
                                              init_biases=None,
                                              betas=betas,
                                              num_samples=1024)

        # Update QuPA internal log z variable with the eval_logz
        self.log_z_update = self.log_z_value.assign(eval_logz)
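A hedged instantiation sketch based on the constructor signature above; the argument values are illustrative:

    # Hypothetical construction of the RBM prior; values are placeholders.
    rbm_prior = RBM(num_var1=64,
                    num_var2=64,
                    weight_decay=1e-4,
                    num_samples=100,
                    num_gibbs_iter=40,
                    kld_term='dvaepp_exp',  # one of the options listed in the docstring
                    use_qupa=False)         # False -> PCD sampling, True -> QuPA population annealing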
Example #6
def Src(n=5):
    return Iter(range(1, (n + 1)))


# Print info for each demo
def demo(description, ez):
    print("\n")
    print(description)
    print("\n")
    ez.printLayout()
    print(ez.graph())
    # ez.watch(0.1)
    ez.start().join()


demo("Sanity.", EZ(Src(), Print("Serial")))

demo("Generate constant values with Const", EZ(Const(0, 5), Add(1), Print()))

demo(
    "In-Series processing uses [] or Serial() or S().",
    EZ([Src(), Print("Also Serial")]),
)

demo(
    "Parallel Processing uses () or Parallel() or P().",
    EZ(Src(), (Print("A"), Print("B"))),
)

demo(
    "Broadcast Processing uses {} or Broadcast() or B().",
Example #7
    def __init__(self, num_input, config, config_recon, config_train):
        """  This function initializes an instance of the VAE class. 
        Args:
            num_input: the length of observed random variable (x).
            config: a dictionary containing config. for the (hierarchical) posterior distribution and prior over z. 
            config_recon: a dictionary containing config. for the reconstruct function in the decoder p(x | z).
            config_train: a dictionary containing config. training (hyperparameters).
        """
        np.set_printoptions(threshold=10)
        Print(str(config))
        Print(str(config_recon))
        Print(str(config_train))

        self.num_input = num_input
        self.config = config  # configuration dictionary for approx post and prior on z
        self.config_recon = config_recon  # configuration dictionary for reconstruct function p(x | z)
        self.config_train = config_train  # configuration dictionary for training hyper-parameters

        # bias term on the visible node
        self.train_bias = -np.log(
            1. / np.clip(self.config_train['mean_x'], 0.001, 0.999) -
            1.).astype(np.float32)

        self.dist_type = config[
            'dist_type']  # flag indicating whether we have rbm prior.
        tf.summary.scalar('beta', config['beta'])
        # define DistUtil classes that will be used in posterior and prior.
        if self.dist_type == "dvae_spike_exp":  # DVAE (spike-exp)
            dist_util = Spike_and_Exp
            dist_util_param = {'beta': self.config['beta']}
            tf.summary.scalar('posterior/beta', dist_util_param['beta'])
        elif self.dist_type == "dvaepp_exp":  # DVAE++ (exp)
            dist_util = MixtureGeneric
            self.smoothing_dist = Exponential(
                params={'beta': self.config['beta']})
            dist_util_param = {'smoothing_dist': self.smoothing_dist}
            tf.summary.scalar('posterior/beta', self.smoothing_dist.beta)
        elif self.dist_type == "dvaepp_power":  # DVAE++ (power)
            dist_util = MixtureGeneric
            self.smoothing_dist = PowerLaw(
                params={'beta': self.config['beta']})
            dist_util_param = {'smoothing_dist': self.smoothing_dist}
            tf.summary.scalar('posterior/lambda', self.smoothing_dist._lambda)
        elif self.dist_type == "dvaes_gi":  # DVAE# (Gaussian int)
            MixtureNormal.num_param = 2  # more parameters for the Gaussian-integral case
            dist_util = MixtureNormal
            dist_util_param = {'isotropic': False, 'delta_mu_scale': 0.5}
        elif self.dist_type == "dvaes_gauss":  # DVAE# (Gaussian)
            dist_util = MixtureGeneric
            self.smoothing_dist = Normal(params={'beta': self.config['beta']})
            dist_util_param = {'smoothing_dist': self.smoothing_dist}
            tf.summary.scalar('posterior/sigma', self.smoothing_dist.sigma)
        elif self.dist_type == "dvaes_exp":  # DVAE# (exp)
            dist_util = MixtureGeneric
            self.smoothing_dist = Exponential(
                params={'beta': self.config['beta']})
            dist_util_param = {'smoothing_dist': self.smoothing_dist}
            tf.summary.scalar('posterior/beta', self.smoothing_dist.beta)
        elif self.dist_type == "dvaes_unexp":  # DVAE# (uniform+exp)
            dist_util = MixtureGeneric
            self.smoothing_dist = ExponentialUniform(
                params={
                    'beta': self.config['beta'],
                    'eps': 0.05
                })
            dist_util_param = {'smoothing_dist': self.smoothing_dist}
            tf.summary.scalar('posterior/beta', self.smoothing_dist.beta)
            tf.summary.scalar('posterior/eps', self.smoothing_dist.eps)
        elif self.dist_type == "dvaes_power":  # DVAE# (power)
            dist_util = MixtureGeneric
            self.smoothing_dist = PowerLaw(
                params={'beta': self.config['beta']})
            dist_util_param = {'smoothing_dist': self.smoothing_dist}
            tf.summary.scalar('posterior/lambda', self.smoothing_dist._lambda)
        else:
            raise ValueError('self.dist_type=%s is unknown' % self.dist_type)

        # define p(z)
        self.prior = self.define_prior()

        # create encoder for the first level.
        self.encoder = SimpleEncoder(num_input=num_input,
                                     config=config,
                                     dist_util=dist_util,
                                     dist_util_param=dist_util_param)

        # create encoder and decoder for lower layers.
        num_latent_units = self.config['num_latent_units'] * self.config[
            'num_latent_layers']
        self.decoder = SimpleDecoder(num_latent_units=num_latent_units,
                                     num_output=num_input,
                                     config_recon=config_recon)
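The constructor dispatches on config['dist_type'] and reads a handful of other keys. A hedged sketch of the dictionaries it expects; only the keys read above are shown, the values are illustrative, and config_recon for the decoder is assumed to be built elsewhere:

    # Illustrative config; only keys read by this constructor are shown.
    config = {
        'dist_type': 'dvaes_power',  # selects one of the smoothing-distribution branches above
        'beta': 8.0,                 # smoothing parameter passed to the DistUtil class
        'num_latent_units': 64,      # latent units per layer
        'num_latent_layers': 2,      # number of latent layers
    }
    config_train = {'mean_x': 0.5}   # mean of the training data, used for the visible bias
    vae = VAE(num_input=784, config=config, config_recon=config_recon, config_train=config_train)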
Example #8
    def __init__(self, num_input, config, config_recon, config_train):
        """  This function initializes an instance of the VAE class. 
        Args:
            num_input: the length of observed random variable (x).
            config: a dictionary containing config. for the (hierarchical) posterior distribution and prior over z. 
            config_recon: a dictionary containing config. for the reconstruct function in the decoder p(x | z).
            config_train: a dictionary containing config. training (hyperparameters).
        """
        np.set_printoptions(threshold=10)
        Print(str(config))
        Print(str(config_recon))
        Print(str(config_train))

        self.num_input = num_input
        self.config = config  # configuration dictionary for approx post and prior on z
        self.config_recon = config_recon  # configuration dictionary for reconstruct function p(x | z)
        self.config_train = config_train  # configuration dictionary for training hyper-parameters

        # bias term on the visible node
        self.train_bias = -np.log(
            1. / np.clip(self.config_train['mean_x'], 0.001, 0.999) -
            1.).astype(np.float32)
        self.entropy_lower_bound = 0.05

        self.dist_type = config[
            'dist_type']  # flag indicating whether we have rbm prior.
        tf.summary.scalar('beta', config['beta'])
        self.encoder_type = 'hierarchical'
        self.is_struct_pred = config_train['is_struct_pred']

        # define DistUtil classes that will be used in posterior and prior.
        if self.dist_type == "dvaes_power":  # DVAE# (power)
            dist_util = MixtureGeneric
            self.smoothing_dist = PowerLaw(
                params={'beta': self.config['beta']})
            dist_util_param = {'smoothing_dist': self.smoothing_dist}
            tf.summary.scalar('posterior/lambda', self.smoothing_dist._lambda)
        elif self.dist_type == "pwl_relax":  # PWL relaxtion
            dist_util = RelaxDist
            dist_util_param = {
                'beta': self.config['beta'],
                'smoothing_fun': sample_through_pwlinear
            }
            tf.summary.scalar('posterior/beta', dist_util_param['beta'])
        elif self.dist_type == "gsm_relax":  # Gumbel-Softmax relaxtion
            dist_util = RelaxDist
            dist_util_param = {
                'beta': self.config['beta'],
                'smoothing_fun': sample_concrete
            }
            tf.summary.scalar('posterior/beta', dist_util_param['beta'])
        elif self.dist_type == "dvaess_con":
            self.encoder_type = 'rbm'
            num_var = self.config['num_latent_units'] // 2
            self.posterior = ConcreteRBM(
                training_size=config_train['training_size'],
                num_var1=num_var,
                num_var2=num_var,
                num_gibbs_iter=10,
                beta=self.config['beta'],
                num_eval_k=self.config_train['k_iw'],
                num_train_k=self.config_train['k'])
        else:
            raise ValueError('self.dist_type=%s is unknown' % self.dist_type)

        # define p(z)
        self.prior = self.define_prior()

        if self.encoder_type == 'hierarchical':
            # create encoder for the first level.
            num_hidden_pre = [200] * 1 if self.is_struct_pred else [200] * 2
            self.pre_process_net = FeedForwardNetwork(
                num_input,
                num_hiddens=num_hidden_pre,
                num_output=None,
                name='pre_proc',
                weight_decay_coeff=1e-4,
                output_split=1,
                use_batch_norm=True,
                collections='q_collections')
            self.encoder = SimpleEncoder(num_input=200,
                                         config=config,
                                         dist_util=dist_util,
                                         dist_util_param=dist_util_param)
        else:
            self.pre_process_net = None
            self.encoder = RBMEncoder(num_input=num_input,
                                      config=config,
                                      posterior_rbm=self.posterior)

        # create encoder and decoder for lower layers.
        num_latent_units = self.config['num_latent_units'] * self.config[
            'num_latent_layers']
        self.decoder = SimpleDecoder(num_latent_units=num_latent_units,
                                     num_output=num_input,
                                     config_recon=config_recon)
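Relative to Example #7, this variant recognizes a different set of dist_type values and reads a few extra training keys. A hedged sketch of the additional entries, with illustrative values:

    # Extra keys read by this variant; values are placeholders.
    config['dist_type'] = 'gsm_relax'   # also accepts 'dvaes_power', 'pwl_relax', and 'dvaess_con'
    config_train.update({
        'is_struct_pred': False,        # structured-prediction mode (entropy-only posterior term)
        'training_size': 50000,         # used by ConcreteRBM when dist_type == 'dvaess_con'
        'k': 1,                         # number of posterior samples per data point during training
        'k_iw': 100,                    # number of importance samples used for evaluation
    })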
Example #9
    def elbo_terms(self,
                   input,
                   posterior,
                   post_samples,
                   log_z,
                   k,
                   is_training,
                   batch_norm_update,
                   post_samples_mf=None):
        # create features for the likelihood p(x|z)
        output_activations = self.decoder.reconstruct(post_samples,
                                                      is_training,
                                                      batch_norm_update)
        # add data bias
        output_activations[0] = output_activations[0] + self.train_bias
        # form the output dist util.
        output_dist = FactorialBernoulliUtil(output_activations)
        # create the final output
        output = tf.nn.sigmoid(output_dist.logit_mu)
        output = self.mix_output_with_input(input, output)

        # concat all the samples
        post_samples_concat = tf.concat(axis=-1, values=post_samples)
        # post_samples_concat = post_samples_mf  # remove this, it uses MF instead of samples

        kl, log_q, log_p = 0., 0., 0.
        if self.config_train['use_iw'] and is_training and k > 1:
            Print('Using IW Obj.')
            if self.encoder_type == 'hierarchical':
                for samples, factorial in zip(post_samples, posterior):
                    log_q += factorial.log_prob(samples, stop_grad=True)

                log_p = self.prior.log_prob(post_samples_concat,
                                            stop_grad=False,
                                            is_training=is_training)
            else:
                log_q = self.posterior.log_prob(posterior,
                                                post_samples_concat,
                                                log_z,
                                                stop_grad=True)
                log_p = -self.prior.energy_tf(
                    post_samples_concat) - self.prior.log_z_train

            if self.is_struct_pred:
                kl = log_q
                log_p, log_q = 0., 0.
        else:
            Print('Using VAE Obj.')
            if self.is_struct_pred:  # add only the entropy loss to the objective function
                log_q = 0.
                if self.encoder_type == 'hierarchical':
                    for samples, factorial in zip(post_samples, posterior):
                        log_q += factorial.log_prob(samples,
                                                    stop_grad=True,
                                                    is_training=True)
                else:
                    log_q = self.posterior.log_prob(posterior,
                                                    post_samples_concat,
                                                    log_z,
                                                    stop_grad=True)
                kl = log_q
            else:
                # compute KL only for VAE case
                if self.encoder_type == 'hierarchical':
                    kl = self.prior.kl_dist_from(posterior, post_samples,
                                                 is_training)
                elif self.encoder_type == 'rbm':
                    kl = self.posterior.kl_from_this(posterior, self.prior,
                                                     post_samples_mf, log_z,
                                                     is_training)

        # expected log prob p(x| z)
        cost = -output_dist.log_prob_per_var(input, stop_grad=False)
        cost = self.process_decoder_target(cost)
        cost = tf.reduce_sum(cost, axis=1)

        return kl, cost, output, log_p, log_q
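For orientation, a hedged sketch of how a caller would typically fold the returned terms into the two objectives used in run_training; it follows the standard ELBO/IWAE formulas, and the reductions and reshape order are assumptions rather than the project's exact code:

    # Hypothetical combination of the terms returned by elbo_terms.
    neg_elbo = tf.reduce_mean(cost + kl)  # -E_q[log p(x|z)] + KL(q(z|x) || p(z))

    # importance-weighted bound with k samples per data point
    log_iw = -cost + log_p - log_q        # log w_i = log p(x, z_i) - log q(z_i | x)
    log_iw = tf.reshape(log_iw, [k, -1])
    iw_loss = -tf.reduce_mean(
        tf.reduce_logsumexp(log_iw, axis=0) - tf.log(tf.cast(k, tf.float32)))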