Example #1
0
    def sample_z(self, x, encoder, decoder, W):
        '''
        Encode x and draw reparameterized latent samples.

        z: [P,B,Z]
        log_pz: [P,B]
        log_qz: [P,B]
        '''

        # Encoder output [B,Z*2], sliced into posterior mean / logvar [B,Z].
        mean_and_logvar = encoder.feedforward(x)
        q_mean = tf.slice(mean_and_logvar, [0, 0],
                          [self.batch_size, self.z_size])  #[B,Z]
        q_logvar = tf.slice(mean_and_logvar, [0, self.z_size],
                            [self.batch_size, self.z_size])  #[B,Z]

        # Reparameterization trick: z = mean + std * eps, broadcast to [P,B,Z].
        noise = tf.random_normal(
            (self.n_z_particles, self.batch_size, self.z_size),
            0, 1, seed=self.rs)
        std = tf.sqrt(tf.exp(q_logvar))
        z = q_mean + std * noise  #[P,B,Z]

        # Log-densities under the N(0,I) prior (logvar = log 1 = 0) and the
        # posterior q(z|x), both [P,B].
        prior_mean = tf.zeros([self.batch_size, self.z_size])
        prior_logvar = tf.log(tf.ones([self.batch_size, self.z_size]))
        log_pz = log_norm(z, prior_mean, prior_logvar)
        log_qz = log_norm(z, q_mean, q_logvar)

        return z, log_pz, log_qz
Example #2
0
    def sample_z(self, x, encoder, decoder):
        '''
        Encode x, then draw k reparameterized samples of z.

        z: [P,B,Z]
        log_pz: [P,B]
        log_qz: [P,B]
        '''

        # Posterior parameters, [B,Z] each.
        q_mean, q_logvar = split_mean_logvar(encoder.feedforward(x))

        # Reparameterization trick, broadcast to [P,B,Z].
        noise = tf.random_normal((self.k, self.batch_size, self.z_size),
                                 0, 1, seed=self.rs)
        z = q_mean + tf.sqrt(tf.exp(q_logvar)) * noise

        # Standard-normal prior: zero mean, zero log-variance.  [P,B]
        prior_zeros = tf.zeros([self.batch_size, self.z_size])
        log_pz = log_norm(z, prior_zeros, prior_zeros)
        log_qz = log_norm(z, q_mean, q_logvar)

        return z, log_pz, log_qz
Example #3
0
    def sample_z(self, x, encoder, decoder, W):
        '''
        Encode x, draw one sample z0, and push it through the HVI transform.

        z: [B,Z]     the transformed sample zT
        log_pz: [B]  prior density of zT
        log_qz: [B]  posterior density of the *initial* sample z0
        '''

        #Encode [B,Z*2] and slice into mean / logvar, [B,Z] each.
        z_mean_logvar = encoder.feedforward(x)  #[B,Z*2]
        z_mean = tf.slice(z_mean_logvar, [0, 0],
                          [self.batch_size, self.z_size])  #[B,Z]
        z_logvar = tf.slice(z_mean_logvar, [0, self.z_size],
                            [self.batch_size, self.z_size])  #[B,Z]

        #Sample z0  [B,Z] via the reparameterization trick.
        eps = tf.random_normal((self.batch_size, self.z_size),
                               0,
                               1,
                               seed=self.rs)
        z0 = tf.add(z_mean, tf.multiply(tf.sqrt(tf.exp(z_logvar)),
                                        eps))  #[B,Z]

        #HVI transform of z0 into zT.
        # NOTE(review): T is never used — presumably intended as the number
        # of transitions for HVI.transform; confirm against its signature.
        T = 3
        # NOTE(review): log_pvT and log_qv0 are discarded here; verify the
        # caller accounts for the auxiliary-variable terms elsewhere.
        zT, log_pvT, log_qv0 = HVI.transform(z0, decoder, W)

        # Calc log probs [B]; prior logvar = log(1) = 0.
        log_pz = log_norm(zT, tf.zeros([self.batch_size, self.z_size]),
                          tf.log(tf.ones([self.batch_size, self.z_size])))
        log_qz = log_norm(z0, z_mean, z_logvar)

        return zT, log_pz, log_qz
Example #4
0
    def sample_z(self, x, encoder, decoder, k):
        '''
        Sample z through the normalizing-flow transform.

        x: [B,X]
        z: [P,B,Z]
        log_pz: [P,B]
        log_qz: [P,B]
        '''

        # Posterior parameters from the encoder, [B,Z] each.
        q_mean, q_logvar = split_mean_logvar(encoder.feedforward(x))

        # Reparameterized base sample z0, broadcast to [P,B,Z].
        noise = tf.random_normal((k, self.batch_size, self.z_size),
                                 0, 1, seed=self.rs)
        z0 = q_mean + tf.sqrt(tf.exp(q_logvar)) * noise
        log_qz0 = log_norm(z0, q_mean, q_logvar)  #[P,B]

        # Flow transform: sample [P,B,Z] plus log-determinant [P,B].
        z, logdet = self.transform_sample(z0)

        # Prior density of the transformed sample plus the change-of-variables
        # correction, [P,B].  Prior logvar = log(1) = 0.
        prior_mean = tf.zeros([self.batch_size, self.z_size])
        prior_logvar = tf.log(tf.ones([self.batch_size, self.z_size]))
        log_pz = log_norm(z, prior_mean, prior_logvar) + logdet

        return z, log_pz, log_qz0
Example #5
0
    def log_probs(self, x, q_zlx, q_vlxz, r_vlxz, p_xlz):
        '''
        Hamiltonian-flow style log-probability terms for one batch.

        x: [B,X] input batch
        q_zlx:  encoder q(z|x)
        q_vlxz: auxiliary-momentum encoder q(v|x,z)
        r_vlxz: reverse auxiliary model r(v|x,z)
        p_xlz:  decoder p(x|z)

        Returns [log_px, log_pz, log_qz0, log_qv0, log_rv], each [P,B],
        where P = self.n_z_particles and B = self.batch_size.  The caller
        assembles the ELBO from these five terms.
        '''

        # Sample z0 and calc log_qz0
        z0_mean, z0_logvar = split_mean_logvar(q_zlx.feedforward(x))  #[B,Z]
        z0 = sample_Gaussian(z0_mean, z0_logvar, self.n_z_particles)  #[P,B,Z]
        log_qz0 = log_norm(z0, z0_mean, z0_logvar)  #[P,B]

        # Sample the initial momentum v0 conditioned on (x, z0) and score it.
        # Particles are folded into the batch axis for the feedforward calls.
        z0_reshaped = tf.reshape(
            z0, [self.n_z_particles * self.batch_size, self.z_size])  #[PB,Z]
        x_tiled = tf.tile(x, [self.n_z_particles, 1])  #[PB,X]
        xz = tf.concat([x_tiled, z0_reshaped], axis=1)  #[PB,X+Z]
        v0_mean, v0_logvar = split_mean_logvar(q_vlxz.feedforward(xz))  #[PB,Z]
        v0 = sample_Gaussian(v0_mean, v0_logvar, 1)  #[1,PB,Z]
        v0 = tf.reshape(
            v0, [self.n_z_particles * self.batch_size, self.z_size])  #[PB,Z]
        log_qv0 = log_norm2(v0, v0_mean, v0_logvar)  #[PB]
        log_qv0 = tf.reshape(log_qv0,
                             [self.n_z_particles, self.batch_size])  #[P,B]
        v0 = tf.reshape(
            v0, [self.n_z_particles, self.batch_size, self.z_size])  #[P,B,Z]

        # Leapfrog transform (z0, v0) -> (zT, vT), both [P,B,Z]
        zT, vT = self.leapfrogs(z0, v0, p_xlz, x)

        # Reverse model: score the final momentum vT under r(v|x,zT).
        z_reshaped = tf.reshape(
            zT, [self.n_z_particles * self.batch_size, self.z_size])  #[PB,Z]
        xz = tf.concat([x_tiled, z_reshaped], axis=1)  #[PB,X+Z]
        vt_mean, vt_logvar = split_mean_logvar(r_vlxz.feedforward(xz))  #[PB,Z]
        vT = tf.reshape(
            vT, [self.n_z_particles * self.batch_size, self.z_size])  #[PB,Z]
        log_rv = log_norm2(vT, vt_mean, vt_logvar)  #[PB]
        log_rv = tf.reshape(log_rv,
                            [self.n_z_particles, self.batch_size])  #[P,B]

        # Prior term for zT: N(0, I), logvar = log(1) = 0.  [P,B]
        log_pz = log_norm(zT, tf.zeros([self.batch_size, self.z_size]),
                          tf.log(tf.ones([self.batch_size,
                                          self.z_size])))  #[P,B]

        # Likelihood term [P,B]
        log_px = self._log_px(p_xlz, x, zT)

        return [log_px, log_pz, log_qz0, log_qv0, log_rv]
Example #6
0
    def sample_z(self, x, encoder, decoder, W):
        '''
        Sample z given x and the decoder weights W, via the flow transform.

        z: [P,B,Z]
        log_pz: [P,B]
        log_qz: [P,B]
        '''

        # Flatten every weight tensor in W into one long vector.
        pieces = [tf.reshape(w, [-1]) for w in W]
        flat_w = tf.concat(pieces, axis=0)

        # Tile the flattened weights across the batch and append to x.
        flat_w = tf.reshape(flat_w, [1, -1])
        w_tiled = tf.tile(flat_w, [self.batch_size, 1])
        encoder_input = tf.concat([x, w_tiled], axis=1)

        # Encode [B,Z*2] and slice into mean / logvar, [B,Z] each.
        mean_and_logvar = encoder.feedforward(encoder_input)
        q_mean = tf.slice(mean_and_logvar, [0, 0],
                          [self.batch_size, self.z_size])
        q_logvar = tf.slice(mean_and_logvar, [0, self.z_size],
                            [self.batch_size, self.z_size])

        # Reparameterized base sample z0, broadcast to [P,B,Z].
        noise = tf.random_normal(
            (self.n_z_particles, self.batch_size, self.z_size),
            0, 1, seed=self.rs)
        z0 = q_mean + tf.sqrt(tf.exp(q_logvar)) * noise
        log_qz0 = log_norm(z0, q_mean, q_logvar)  #[P,B]

        # Flow transform: sample [P,B,Z], log-determinant [P,B].
        z, logdet = self.transform_sample(z0)

        # Prior term plus the change-of-variables correction, [P,B].
        prior_mean = tf.zeros([self.batch_size, self.z_size])
        prior_logvar = tf.log(tf.ones([self.batch_size, self.z_size]))
        log_pz = log_norm(z, prior_mean, prior_logvar) + logdet

        return z, log_pz, log_qz0
Example #7
0
    def log_probs(self, x, q_zlx, q_vlxz, r_vlxz, p_xlz):
        '''
        Hamiltonian-flow style log-probability terms for one batch.

        x: [B,X] input batch
        q_zlx:  encoder q(z|x)
        q_vlxz: auxiliary-momentum encoder q(v|x,z)
        r_vlxz: reverse auxiliary model r(v|x,z)
        p_xlz:  decoder p(x|z)

        Returns [log_px, log_pz, log_qz0, log_qv0, log_rv], each [P,B],
        where P = self.n_z_particles and B = self.batch_size.  The caller
        assembles the ELBO from these five terms.
        '''

        # Sample z0 and calc log_qz0
        z0_mean, z0_logvar = split_mean_logvar(q_zlx.feedforward(x)) #[B,Z]
        z0 = sample_Gaussian(z0_mean, z0_logvar, self.n_z_particles) #[P,B,Z]
        log_qz0 = log_norm(z0, z0_mean, z0_logvar) #[P,B]

        # Sample the initial momentum v0 conditioned on (x, z0) and score it.
        # Particles are folded into the batch axis for the feedforward calls.
        z0_reshaped = tf.reshape(z0, [self.n_z_particles*self.batch_size, self.z_size]) #[PB,Z]
        x_tiled = tf.tile(x, [self.n_z_particles, 1]) #[PB,X]
        xz = tf.concat([x_tiled,z0_reshaped], axis=1) #[PB,X+Z]
        v0_mean, v0_logvar = split_mean_logvar(q_vlxz.feedforward(xz)) #[PB,Z]
        v0 = sample_Gaussian(v0_mean, v0_logvar, 1) #[1,PB,Z]
        v0 = tf.reshape(v0, [self.n_z_particles*self.batch_size, self.z_size]) #[PB,Z]
        log_qv0 = log_norm2(v0, v0_mean, v0_logvar) #[PB]
        log_qv0 = tf.reshape(log_qv0, [self.n_z_particles, self.batch_size]) #[P,B]
        v0 = tf.reshape(v0, [self.n_z_particles, self.batch_size, self.z_size]) #[P,B,Z]

        # Leapfrog transform (z0, v0) -> (zT, vT), both [P,B,Z]
        zT, vT = self.leapfrogs(z0,v0,p_xlz,x)

        # Reverse model: score the final momentum vT under r(v|x,zT).
        z_reshaped = tf.reshape(zT, [self.n_z_particles*self.batch_size, self.z_size]) #[PB,Z]
        xz = tf.concat([x_tiled,z_reshaped], axis=1) #[PB,X+Z]
        vt_mean, vt_logvar = split_mean_logvar(r_vlxz.feedforward(xz)) #[PB,Z]
        vT = tf.reshape(vT, [self.n_z_particles*self.batch_size, self.z_size]) #[PB,Z]
        log_rv = log_norm2(vT, vt_mean, vt_logvar) #[PB]
        log_rv = tf.reshape(log_rv, [self.n_z_particles, self.batch_size]) #[P,B]

        # Prior term for zT: N(0, I), logvar = log(1) = 0.  [P,B]
        log_pz = log_norm(zT, tf.zeros([self.batch_size, self.z_size]), 
                                tf.log(tf.ones([self.batch_size, self.z_size]))) #[P,B]

        # Likelihood term [P,B]
        log_px = self._log_px(p_xlz, x, zT)

        return [log_px, log_pz, log_qz0, log_qv0, log_rv]
Example #8
0
    def sample_z(self, x, encoder, decoder, W):
        '''
        Sample z given x and the decoder weights W, via the flow transform.

        z: [P,B,Z]
        log_pz: [P,B]
        log_qz: [P,B]
        '''

        # Concatenate every flattened weight tensor in W into one vector,
        # tile it over the batch, and append it to x as encoder input.
        flat_w = tf.concat([tf.reshape(w, [-1]) for w in W], axis=0)
        flat_w = tf.reshape(flat_w, [1, -1])
        w_tiled = tf.tile(flat_w, [self.batch_size, 1])
        encoder_input = tf.concat([x, w_tiled], axis=1)

        # Encode [B,Z*2]; slice out the posterior mean and logvar, [B,Z] each.
        params = encoder.feedforward(encoder_input)
        mu = tf.slice(params, [0, 0], [self.batch_size, self.z_size])
        logvar = tf.slice(params, [0, self.z_size],
                          [self.batch_size, self.z_size])

        # Reparameterized base sample z0, broadcast to [P,B,Z].
        eps = tf.random_normal(
            (self.n_z_particles, self.batch_size, self.z_size),
            0, 1, seed=self.rs)
        z0 = mu + tf.sqrt(tf.exp(logvar)) * eps
        log_qz0 = log_norm(z0, mu, logvar)  #[P,B]

        # Flow transform: sample [P,B,Z], log-determinant [P,B].
        z, logdet = self.transform_sample(z0)

        # Prior density (logvar = log 1 = 0) plus the change-of-variables
        # correction, [P,B].
        zeros = tf.zeros([self.batch_size, self.z_size])
        log_pz = log_norm(z, zeros, tf.log(tf.ones([self.batch_size, self.z_size]))) + logdet

        return z, log_pz, log_qz0
Example #9
0
    def _log_pz(self, z):
        '''
        Log density of z under the standard-normal prior N(0, I).

        z: [P,B,Z]
        output: [P,B]
        '''

        # The prior has zero mean and unit variance, i.e. logvar == 0.
        # tf.zeros replaces the original tf.log(tf.ones(...)), which built
        # the same all-zeros tensor via a needless elementwise log(1).
        prior_mean = tf.zeros([self.batch_size, self.z_size])
        prior_logvar = tf.zeros([self.batch_size, self.z_size])
        return log_norm(z, prior_mean, prior_logvar)
Example #10
0
    def _log_pz(self, z):
        '''
        Log density of z under the standard-normal prior N(0, I).

        z: [P,B,Z]
        output: [P,B]
        '''

        # Zero mean and logvar == 0 (unit variance).  tf.zeros replaces the
        # original tf.log(tf.ones(...)), which computed the same all-zeros
        # tensor via a needless elementwise log(1).
        zeros = tf.zeros([self.batch_size, self.z_size])
        return log_norm(z, zeros, zeros)
Example #11
0
    def sample_z(self, x, encoder, decoder, W):
        '''
        Draw reparameterized latent samples and their log-densities.

        z: [P,B,Z]
        log_pz: [P,B]
        log_qz: [P,B]
        '''

        # Encoder output [B,Z*2], sliced into mean and log-variance [B,Z].
        params = encoder.feedforward(x)
        mu = tf.slice(params, [0, 0], [self.batch_size, self.z_size])
        logvar = tf.slice(params, [0, self.z_size],
                          [self.batch_size, self.z_size])

        # z = mu + sigma * eps, broadcast over P particles -> [P,B,Z].
        eps = tf.random_normal(
            (self.n_z_particles, self.batch_size, self.z_size),
            0, 1, seed=self.rs)
        z = mu + tf.sqrt(tf.exp(logvar)) * eps

        # Prior N(0,I) (logvar = log 1 = 0) and posterior densities, [P,B].
        zeros = tf.zeros([self.batch_size, self.z_size])
        log_pz = log_norm(z, zeros, tf.log(tf.ones([self.batch_size, self.z_size])))
        log_qz = log_norm(z, mu, logvar)

        return z, log_pz, log_qz
Example #12
0
    def _log_px(self, p_xlz, x, z, k):
        '''
        Bernoulli log-likelihood log p(x|z) plus the prior term log p(z).

        p_xlz: decoder network
        x: [B,X]
        z: [P,B,Z]  (k == P, the number of particles)
        output: [P,B]
        '''

        # Fold particles into the batch axis for the decoder call.
        z_reshaped = tf.reshape(z, [k * self.batch_size, self.z_size])  #[PB,Z]
        x_mean = p_xlz.feedforward(z_reshaped)  #[PB,X]
        x_mean = tf.reshape(x_mean,
                            [k, self.batch_size, self.x_size])  #[P,B,X]

        log_px = log_bern(x, x_mean)  #[P,B]

        # Standard-normal prior: logvar = log(1) = 0.
        log_pz = log_norm(z, tf.zeros([self.batch_size, self.z_size]),
                          tf.log(tf.ones([self.batch_size, self.z_size])))

        # NOTE(review): despite the name, this returns log p(x|z) + log p(z).
        return log_px + log_pz
Example #13
0
        def foo(i, log_pxi, log_pzi, log_qzi, log_pWi, log_qWi):
            '''
            Loop body (tf.while_loop style): draws one set of decoder
            weights, samples z, scores the likelihood, and appends the new
            log-probability terms onto the running accumulators.

            i: iteration counter.  The remaining args are the accumulators
            built up by previous iterations (concatenated along axis 0).
            NOTE(review): closes over x, encoder, decoder, sample_z and
            self from the enclosing scope — confirm against the caller.
            '''

            # Sample decoder weights  __, [1], [1]
            W, log_pW, log_qW = decoder.sample_weights()

            # Sample z   [P,B,Z], [P,B], [P,B]
            z, log_pz, log_qz = sample_z.sample_z(x, encoder, decoder, W)
            z = tf.reshape(z, [self.n_z_particles*self.batch_size, self.z_size]) #[PB,Z]

            # Decode [PB,X]
            y = decoder.feedforward(W, z)

            # Likelihood p(x|z)  [P,B]
            if self.likelihood == 'Bernoulli':
                y = tf.reshape(y, [self.n_z_particles, self.batch_size, self.x_size]) #[P,B,X]
                log_px = log_bern(x,y)

            elif self.likelihood == 'Gaussian':
                mean, logvar = split_mean_logvar(y) #[PB,Z]
                x_ = tf.reshape(x, [1, self.batch_size, self.x_size])

                log_px = log_norm(x_, mean, logvar)

            # Reshape so every term gains a leading length-1 axis, then
            # concatenate onto the accumulators along that axis.
            log_pW = tf.reshape(log_pW, [1])
            log_qW = tf.reshape(log_qW, [1])
            log_pz = tf.reshape(log_pz, [1, self.n_z_particles, self.batch_size])
            log_qz = tf.reshape(log_qz, [1, self.n_z_particles, self.batch_size])
            log_px = tf.reshape(log_px, [1, self.n_z_particles, self.batch_size])

            log_px = tf.concat([log_pxi, log_px], axis=0)
            log_pz = tf.concat([log_pzi, log_pz], axis=0)
            log_qz = tf.concat([log_qzi, log_qz], axis=0)
            log_pW = tf.concat([log_pWi, log_pW], axis=0)
            log_qW = tf.concat([log_qWi, log_qW], axis=0)

            return [i+1, log_px, log_pz, log_qz, log_pW, log_qW]
Example #14
0
    def __init__(self, n_classes):
        '''
        Build the TF graph for a hybrid classifier: a deterministic NN maps
        the 784-dim input to a 2-dim code z, then a Bayesian NN maps z to
        the class logits.  The cost combines softmax cross-entropy, a prior
        regularizer on z, the BNN weight KL terms, and weight decay.

        n_classes: number of output classes.
        '''

        tf.reset_default_graph()

        #Model hyperparameters
        self.learning_rate = .0001
        self.rs = 0  # random seed
        self.input_size = 784
        self.output_size = n_classes

        #Placeholders - Inputs/Targets [B,X]
        # one_over_N scales the weight KL terms (presumably 1/dataset_size
        # — confirm against the training loop that feeds it).
        self.one_over_N = tf.placeholder(tf.float32, None)
        self.x = tf.placeholder(tf.float32, [None, self.input_size])
        self.y = tf.placeholder(tf.float32, [None, self.output_size])

        # Deterministic first half and Bayesian second half of the model.
        first_half_NN = NN([784, 100, 100, 2], [tf.nn.softplus,tf.nn.softplus, None])
        second_half_BNN = BNN([2,100, 100, n_classes], [tf.nn.softplus,tf.nn.softplus, None])

        # Sample BNN weights and their prior/posterior log-probs.
        Ws, log_p_W_sum, log_q_W_sum = second_half_BNN.sample_weights()

        #Feedforward [B,2]
        self.z = first_half_NN.feedforward(self.x)

        # Mean log-density of z under a standard-normal prior (log 1 = 0).
        log_pz = tf.reduce_mean(log_norm(self.z, tf.zeros([2]), tf.log(tf.ones([2]))))

        self.y2 = second_half_BNN.feedforward(self.z, Ws)

        #Likelihood: mean softmax cross-entropy over the batch.
        softmax_error = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=self.y, logits=self.y2))

        #Objective
        # NOTE(review): the /70. on log_pz and the .00001 weight-decay factor
        # look hand-tuned — confirm before changing.
        self.cost = softmax_error - log_pz/70. + self.one_over_N*(-log_p_W_sum + log_q_W_sum) + .00001*first_half_NN.weight_decay()

        # Minimize the cost (negative ELBO) with Adam.
        self.optimizer = tf.train.AdamOptimizer(learning_rate=self.learning_rate, epsilon=1e-02).minimize(self.cost)

        #For evaluation
        self.prediction = tf.nn.softmax(self.y2) 

        #To init variables
        self.init_vars = tf.global_variables_initializer()

        #For loading/saving variables
        self.saver = tf.train.Saver()

        #Make sure no nodes are added to the graph after this point.
        tf.get_default_graph().finalize()

        #Start session
        self.sess = tf.Session()
Example #15
0
    def log_probs(self, x, q_zlx, q_zlxz, r_zlxz, p_xlz):
        '''
        ELBO for a Markov-chain variational posterior.

        x: [B,X] input batch
        q_zlx:  initial encoder q(z0|x)
        q_zlxz: transition model q(z_t|x, z_{t-1})
        r_zlxz: reverse model r(z_{t-1}|x, ...)
        p_xlz:  decoder p(x|z)

        Returns a scalar: the ELBO averaged over the batch.
        '''

        log_qz_list = []
        log_rz_list = []

        # Sample z0 and calc log_qz0
        z0_mean, z0_logvar = self.split_mean_logvar(q_zlx.feedforward(x)) #[B,Z]
        z0 = self.sample_Gaussian(z0_mean, z0_logvar) #[B,Z]
        log_qz0 = log_norm(z0, z0_mean, z0_logvar) #[B]

        # BUG FIX: the original referenced the undefined name `z_minus1` and
        # used `z` before assignment on the first iteration (NameError at
        # graph-build time).  The chain now tracks the previous sample
        # explicitly, starting from z0.
        z = z0
        for t in range(self.n_transitions):

            # Reverse model: score the previous sample z_{t-1} under r.
            # NOTE(review): r_zlxz conditions on x only here (not on z) —
            # verify that matches the intended reverse model.
            zt_mean, zt_logvar = self.split_mean_logvar(r_zlxz.feedforward(x)) #[B,Z]
            log_rzt = log_norm(z, zt_mean, zt_logvar) #[B]
            log_rz_list.append(log_rzt)

            # Forward transition: condition on (x, z_{t-1}) and draw z_t.
            xz = tf.concat([x, z], axis=1) #[B,X+Z]
            z_mean, z_logvar = self.split_mean_logvar(q_zlxz.feedforward(xz)) #[B,Z]
            z = self.sample_Gaussian(z_mean, z_logvar) #[B,Z]
            log_qz = log_norm(z, z_mean, z_logvar) #[B]
            log_qz_list.append(log_qz)

        log_rzs = tf.stack(log_rz_list) #[T,B]
        log_qzs = tf.stack(log_qz_list) #[T,B]

        # Prior N(0, I) on the final sample; logvar = log(1) = 0.  [B]
        log_pz = log_norm(z, tf.zeros([self.batch_size, self.z_size]),
                          tf.log(tf.ones([self.batch_size, self.z_size])))

        # Decode [B,X] and score the Bernoulli likelihood.  [B]
        x_mean = p_xlz.feedforward(z)
        log_px = log_bern(x, x_mean)

        #[B]
        elbo = (log_px + log_pz + tf.reduce_sum(log_rzs, axis=0)
                - log_qz0 - tf.reduce_sum(log_qzs, axis=0))

        return tf.reduce_mean(elbo) #over batch