Exemple #1
0
    def objective(self, x):
        '''
        Build the ELBO graph for one batch.

        x: [B,X]
        Returns a scalar to maximize.  Also stores scalar means
        self.log_px/log_pz/log_qz/log_pW/log_qW and self.z_elbo
        for printing.
        '''

        encoder = NN(self.encoder_net, self.encoder_act_func, self.batch_size)
        decoder = BNN(self.decoder_net, self.decoder_act_func, self.batch_size)

        # Per-weight-sample accumulators.
        px_terms, pz_terms, qz_terms = [], [], []
        pW_terms, qW_terms = [], []

        for _ in range(self.n_W_particles):

            # Sample decoder weights  __, [1], [1]
            W, log_pW, log_qW = decoder.sample_weights()

            # Sample z   [P,B,Z], [P,B], [P,B]
            z, log_pz, log_qz = self.sample_z(x, encoder, decoder, W)
            # Fold particles into the batch axis: [PB,Z]
            flat_z = tf.reshape(z, [self.n_z_particles * self.batch_size, self.z_size])

            # Decode [PB,X], then restore the particle axis: [P,B,X]
            y = tf.reshape(decoder.feedforward(W, flat_z),
                           [self.n_z_particles, self.batch_size, self.x_size])

            # Likelihood p(x|z)  [P,B]
            log_px = log_bern(x, y)

            # Keep every term for stacking below.
            px_terms.append(log_px)
            pz_terms.append(log_pz)
            qz_terms.append(log_qz)
            pW_terms.append(log_pW)
            qW_terms.append(log_qW)

        log_px = tf.stack(px_terms)  #[S,P,B]
        log_pz = tf.stack(pz_terms)  #[S,P,B]
        log_qz = tf.stack(qz_terms)  #[S,P,B]
        log_pW = tf.stack(pW_terms)  #[S]
        log_qW = tf.stack(qW_terms)  #[S]

        # Scalar means, kept on self for printing.
        self.log_px = tf.reduce_mean(log_px)
        self.log_pz = tf.reduce_mean(log_pz)
        self.log_qz = tf.reduce_mean(log_qz)
        self.log_pW = tf.reduce_mean(log_pW)
        self.log_qW = tf.reduce_mean(log_qW)
        self.z_elbo = self.log_px + self.log_pz - self.log_qz

        # ELBO with the weight term scaled by the batch fraction.
        elbo = self.log_px + self.log_pz - self.log_qz + self.batch_frac*(self.log_pW - self.log_qW)

        return elbo
Exemple #2
0
    def log_probs2(self, x, encoder, decoder, sample_z):
        '''
        Single-weight-sample log probabilities (no loop over W samples).

        x: [B,X]
        Returns [log_px, log_pz, log_qz, log_pW, log_qW]; the z/x terms
        are reshaped to [1,P,B], and when self.scale_log_probs is set
        each term is divided by its dimensionality.
        '''

        # Sample decoder weights  __, [1], [1]
        W, log_pW, log_qW = decoder.sample_weights(self.scale_log_probs)

        # Sample z   [P,B,Z], [P,B], [P,B]
        z, log_pz, log_qz = sample_z.sample_z(x, encoder, decoder, W)
        self.z = tf.reshape(z, [self.n_z_particles*self.batch_size, self.z_size]) #[PB,Z]

        # Decode [PB,X]
        y = decoder.feedforward(W, self.z)

        # Likelihood p(x|z)
        if self.likelihood == 'Bernoulli':
            y = tf.reshape(y, [self.n_z_particles, self.batch_size, self.x_size]) #[P,B,X]
            # BUG FIX: previously only the Gaussian branch defined
            # self.log_px_, so the reshape below crashed for Bernoulli.
            self.log_px_ = log_bern(x,y) #[P,B]

        elif self.likelihood == 'Gaussian':
            self.x_mean, self.x_logvar = split_mean_logvar(y) #[PB,X]
            x_ = tf.tile(x, [self.n_z_particles, 1])
            self.log_px_ = log_norm4(x_, self.x_mean, self.x_logvar) #[PB]

        log_pz = tf.reshape(log_pz, [1, self.n_z_particles, self.batch_size])
        log_qz = tf.reshape(log_qz, [1, self.n_z_particles, self.batch_size])
        log_px = tf.reshape(self.log_px_, [1, self.n_z_particles, self.batch_size])

        if self.scale_log_probs:
            # Normalize each term by the dimensionality it sums over.
            return [log_px/self.x_size, log_pz/self.z_size, log_qz/self.z_size, log_pW, log_qW]
        else:
            return [log_px, log_pz, log_qz, log_pW, log_qW]
Exemple #3
0
        def foo(i, log_pxi, log_pzi, log_qzi, log_pWi, log_qWi):
            '''
            Loop body (presumably for tf.while_loop — TODO confirm caller):
            draws one decoder-weight sample, scores one z sample under it,
            and concatenates the new log-probs onto the accumulators.

            i: loop counter.
            log_pxi, log_pzi, log_qzi: accumulators, [i,P,B].
            log_pWi, log_qWi: accumulators, [i].
            Returns [i+1, log_px, log_pz, log_qz, log_pW, log_qW] with one
            new slice appended to each accumulator.
            '''

            # Sample decoder weights  __, [1], [1]
            W, log_pW, log_qW = decoder.sample_weights()

            # Sample z   [P,B,Z], [P,B], [P,B]
            z, log_pz, log_qz = sample_z.sample_z(x, encoder, decoder, W)
            z = tf.reshape(z, [self.n_z_particles*self.batch_size, self.z_size]) #[PB,Z]

            # Decode [PB,X]
            y = decoder.feedforward(W, z)
            y = tf.reshape(y, [self.n_z_particles, self.batch_size, self.x_size]) #[P,B,X]

            # Likelihood p(x|z)  [P,B]
            log_px = log_bern(x,y)


            # Reshape and concat results
            # The scalar W terms become rank-1 so they can be concatenated.
            log_pW = tf.reshape(log_pW, [1])
            log_qW = tf.reshape(log_qW, [1])
            log_pz = tf.reshape(log_pz, [1, self.n_z_particles, self.batch_size])
            log_qz = tf.reshape(log_qz, [1, self.n_z_particles, self.batch_size])
            log_px = tf.reshape(log_px, [1, self.n_z_particles, self.batch_size])

            log_px = tf.concat([log_pxi, log_px], axis=0)
            log_pz = tf.concat([log_pzi, log_pz], axis=0)
            log_qz = tf.concat([log_qzi, log_qz], axis=0)
            log_pW = tf.concat([log_pWi, log_pW], axis=0)
            log_qW = tf.concat([log_qWi, log_qW], axis=0)

            return [i+1, log_px, log_pz, log_qz, log_pW, log_qW]
Exemple #4
0
        def foo(i, log_pxi, log_pzi, log_qzi, log_pWi, log_qWi):
            # One iteration of the W-sample loop: draw decoder weights,
            # score a z sample under them, and append every log-prob
            # to its running accumulator.

            # Sample decoder weights  __, [1], [1]
            W, pW_new, qW_new = decoder.sample_weights()
            pW_new = tf.reshape(pW_new, [1])
            qW_new = tf.reshape(qW_new, [1])

            # Sample z   [P,B,Z], [P,B], [P,B]
            z, pz_new, qz_new = sample_z.sample_z(x, encoder, decoder, W)
            pz_new = tf.reshape(pz_new, [1, self.n_z_particles, self.batch_size])
            qz_new = tf.reshape(qz_new, [1, self.n_z_particles, self.batch_size])
            # Fold particles into the batch axis: [PB,Z]
            z = tf.reshape(z, [self.n_z_particles * self.batch_size, self.z_size])

            # Decode [PB,X] and restore the particle axis: [P,B,X]
            y = decoder.feedforward(W, z)
            y = tf.reshape(y, [self.n_z_particles, self.batch_size, self.x_size])

            # Likelihood p(x|z)  [P,B] -> [1,P,B]
            px_new = tf.reshape(log_bern(x, y),
                                [1, self.n_z_particles, self.batch_size])

            # Extend each accumulator by one entry.
            return [i + 1,
                    tf.concat([log_pxi, px_new], axis=0),
                    tf.concat([log_pzi, pz_new], axis=0),
                    tf.concat([log_qzi, qz_new], axis=0),
                    tf.concat([log_pWi, pW_new], axis=0),
                    tf.concat([log_qWi, qW_new], axis=0)]
Exemple #5
0
    def objective(self, x):
        '''
        Build the ELBO graph for one batch and return a scalar to maximize.

        x: [B,X] batch of observations.
        Side effects: stores the scalar means self.log_px, self.log_pz,
        self.log_qz, self.log_pW, self.log_qW (for printing) and
        self.z_elbo = log_px + log_pz - log_qz.
        '''

        encoder = NN(self.encoder_net, self.encoder_act_func, self.batch_size)
        decoder = BNN(self.decoder_net, self.decoder_act_func, self.batch_size)

        log_px_list = []
        log_pz_list = []
        log_qz_list = []
        log_pW_list = []
        log_qW_list = []

        # S = n_W_particles decoder-weight samples.
        for W_i in range(self.n_W_particles):

            # Sample decoder weights  __, [1], [1]
            W, log_pW, log_qW = decoder.sample_weights()

            # Sample z   [P,B,Z], [P,B], [P,B]
            z, log_pz, log_qz = self.sample_z(x, encoder, decoder, W)
            # z: [PB,Z]  (particles folded into the batch axis)
            z = tf.reshape(z,
                           [self.n_z_particles * self.batch_size, self.z_size])

            # Decode [PB,X]
            y = decoder.feedforward(W, z)
            # y: [P,B,X]
            y = tf.reshape(y,
                           [self.n_z_particles, self.batch_size, self.x_size])

            # Likelihood p(x|z)  [P,B]
            log_px = log_bern(x, y)

            #Store for later
            log_px_list.append(log_px)
            log_pz_list.append(log_pz)
            log_qz_list.append(log_qz)
            log_pW_list.append(log_pW)
            log_qW_list.append(log_qW)

        log_px = tf.stack(log_px_list)  #[S,P,B]
        log_pz = tf.stack(log_pz_list)  #[S,P,B]
        log_qz = tf.stack(log_qz_list)  #[S,P,B]
        log_pW = tf.stack(log_pW_list)  #[S]
        log_qW = tf.stack(log_qW_list)  #[S]

        # Calculte log probs for printing
        self.log_px = tf.reduce_mean(log_px)
        self.log_pz = tf.reduce_mean(log_pz)
        self.log_qz = tf.reduce_mean(log_qz)
        self.log_pW = tf.reduce_mean(log_pW)
        self.log_qW = tf.reduce_mean(log_qW)
        self.z_elbo = self.log_px + self.log_pz - self.log_qz

        #Calc elbo — weight term scaled by the batch fraction
        elbo = self.log_px + self.log_pz - self.log_qz + self.batch_frac * (
            self.log_pW - self.log_qW)

        return elbo
Exemple #6
0
        def func_for_scan(prev_output, current_element):
            # Body for tf.scan: ignores both inputs and emits the mean
            # log-probabilities for one fresh decoder-weight sample.

            # Sample decoder weights  __, [1], [1]
            W, log_pW, log_qW = decoder.sample_weights()

            # Sample z   [P,B,Z], [P,B], [P,B]
            z, log_pz, log_qz = self.sample_z(x, encoder, decoder, W)
            # Fold particles into the batch axis: [PB,Z]
            z = tf.reshape(z, [self.n_z_particles * self.batch_size, self.z_size])

            # Decode [PB,X] and restore the particle axis: [P,B,X]
            y = decoder.feedforward(W, z)
            y = tf.reshape(y, [self.n_z_particles, self.batch_size, self.x_size])

            # Likelihood p(x|z)  [P,B]
            log_px = log_bern(x, y)

            # Pack the five scalar means into one tensor for scan's output.
            means = [tf.reduce_mean(t)
                     for t in (log_px, log_pz, log_qz, log_pW, log_qW)]
            return tf.stack(means)
Exemple #7
0
    def objective(self, x):
        '''
        Returns scalar to maximize: the ELBO averaged over the batch,
        W particles and z particles.

        x: [B,X]
        Side effects: stores the scalar means self.log_px/log_pz/
        log_qz/log_pW/log_qW and self.z_elbo for printing.
        '''

        encoder = NN(self.encoder_net, self.encoder_act_func, self.batch_size)
        decoder = BNN(self.decoder_net, self.decoder_act_func, self.batch_size)

        log_pW_list = []
        log_qW_list = []
        log_pz_list = []
        log_qz_list = []
        log_px_list = []

        for W_i in range(self.n_W_particles):

            # Sample decoder weights  __, [1], [1]
            W, log_pW, log_qW = decoder.sample_weights()

            # Sample z   [B,Z], [B], [B]
            z, log_pz, log_qz = self.sample_z(x, encoder, decoder, W)

            # Decode [B,X]
            y = decoder.feedforward(W, z)

            # Likelihood p(x|z)  [B]
            log_px = log_bern(x, y)

            #Store for later
            # BUG FIX: every list previously appended log_pz, so the W
            # and q(z) terms were wrong; append each term to its own list.
            log_pW_list.append(tf.reduce_mean(log_pW))
            log_qW_list.append(tf.reduce_mean(log_qW))
            log_pz_list.append(tf.reduce_mean(log_pz))
            log_qz_list.append(tf.reduce_mean(log_qz))
            log_px_list.append(tf.reduce_mean(log_px))

        # Calculate log probs
        # BUG FIX: stack the accumulated lists, not the last-iteration
        # loop tensors.
        self.log_px = tf.reduce_mean(
            tf.stack(log_px_list))  #over batch + W_particles + z_particles
        self.log_pz = tf.reduce_mean(
            tf.stack(log_pz_list))  #over batch + z_particles
        self.log_qz = tf.reduce_mean(
            tf.stack(log_qz_list))  #over batch + z_particles
        self.log_pW = tf.reduce_mean(tf.stack(log_pW_list))  #W_particles
        self.log_qW = tf.reduce_mean(tf.stack(log_qW_list))  #W_particles

        self.z_elbo = self.log_px + self.log_pz - self.log_qz

        #Calc elbo — weight term scaled by the batch fraction
        elbo = self.log_px + self.log_pz - self.log_qz + self.batch_frac * (
            self.log_pW - self.log_qW)

        return elbo
Exemple #8
0
    def log_probs(self, x, encoder, decoder):
        '''
        Score one batch: returns (log_px, log_pz, log_qz), each [P,B].
        '''

        # Sample z   [P,B,Z], [P,B], [P,B]
        z, log_pz, log_qz = self.sample_z(x, encoder, decoder)
        flat = tf.reshape(z, [self.k * self.batch_size, self.z_size])  #[PB,Z]

        # Decode [PB,X] and restore the particle axis: [P,B,X]
        recon = tf.reshape(decoder.feedforward(flat),
                           [self.k, self.batch_size, self.x_size])

        # Likelihood p(x|z)  [P,B]
        log_px = log_bern(x, recon)

        return log_px, log_pz, log_qz
Exemple #9
0
    def _log_px(self, p_xlz, x, z):
        '''
        Bernoulli log-likelihood of x under the decoder p(x|z).

        x: [B,X]
        z: [P,B,Z]
        output: [P,B]
        '''

        # Fold particles into the batch axis: [PB,Z]
        flat_z = tf.reshape(z, [self.n_z_particles * self.batch_size,
                                self.z_size])
        x_mean = p_xlz.feedforward(flat_z)  #[PB,X]
        # Restore the particle axis: [P,B,X]
        x_mean = tf.reshape(x_mean, [self.n_z_particles, self.batch_size,
                                     self.x_size])

        return log_bern(x, x_mean)  #[P,B]
Exemple #10
0
    def objective(self, x):
        '''
        Returns scalar to maximize: the ELBO averaged over the batch,
        W particles and z particles.

        x: [B,X]
        Side effects: stores the scalar means self.log_px/log_pz/
        log_qz/log_pW/log_qW and self.z_elbo for printing.
        '''

        encoder = NN(self.encoder_net, self.encoder_act_func, self.batch_size)
        decoder = BNN(self.decoder_net, self.decoder_act_func, self.batch_size)

        log_pW_list = []
        log_qW_list = []
        log_pz_list = []
        log_qz_list = []
        log_px_list = []

        for W_i in range(self.n_W_particles):

            # Sample decoder weights  __, [1], [1]
            W, log_pW, log_qW = decoder.sample_weights()

            # Sample z   [B,Z], [B], [B]
            z, log_pz, log_qz = self.sample_z(x, encoder, decoder, W)

            # Decode [B,X]
            y = decoder.feedforward(W, z)

            # Likelihood p(x|z)  [B]
            log_px = log_bern(x,y)

            #Store for later
            # BUG FIX: every list previously appended log_pz, so the W
            # and q(z) terms were wrong; append each term to its own list.
            log_pW_list.append(tf.reduce_mean(log_pW))
            log_qW_list.append(tf.reduce_mean(log_qW))
            log_pz_list.append(tf.reduce_mean(log_pz))
            log_qz_list.append(tf.reduce_mean(log_qz))
            log_px_list.append(tf.reduce_mean(log_px))

        # Calculate log probs
        # BUG FIX: stack the accumulated lists, not the last-iteration
        # loop tensors.
        self.log_px = tf.reduce_mean(tf.stack(log_px_list)) #over batch + W_particles + z_particles
        self.log_pz = tf.reduce_mean(tf.stack(log_pz_list)) #over batch + z_particles
        self.log_qz = tf.reduce_mean(tf.stack(log_qz_list)) #over batch + z_particles
        self.log_pW = tf.reduce_mean(tf.stack(log_pW_list)) #W_particles
        self.log_qW = tf.reduce_mean(tf.stack(log_qW_list)) #W_particles

        self.z_elbo = self.log_px + self.log_pz - self.log_qz

        #Calc elbo — weight term scaled by the batch fraction
        elbo = self.log_px + self.log_pz - self.log_qz + self.batch_frac*(self.log_pW - self.log_qW)

        return elbo
Exemple #11
0
    def _log_px(self, z, decoder, W, x):
        '''
        Bernoulli log-likelihood log p(x|z,W).

        z: [P,B,Z]
        output: [P,B]
        '''

        # Fold particles into the batch axis: [PB,Z]
        flat_z = tf.reshape(z, [self.n_z_particles * self.batch_size, self.z_size])
        # Decode [PB,X], then restore the particle axis: [P,B,X]
        y = tf.reshape(decoder.feedforward(W, flat_z),
                       [self.n_z_particles, self.batch_size, self.x_size])
        # Likelihood p(x|z)  [P,B]
        return log_bern(x, y)
Exemple #12
0
    def _log_px(self, z, decoder, W, x):
        '''
        z: [P,B,Z]
        Returns the Bernoulli log-likelihood log p(x|z,W), shape [P,B].
        '''

        P, B = self.n_z_particles, self.batch_size
        # [PB,Z]
        z_flat = tf.reshape(z, [P * B, self.z_size])
        # Decode to [PB,X], then back to [P,B,X]
        y = decoder.feedforward(W, z_flat)
        y = tf.reshape(y, [P, B, self.x_size])
        # [P,B]
        return log_bern(x, y)
Exemple #13
0
    def _log_px(self, p_xlz, x, z):
        '''
        Bernoulli log-likelihood of x under the decoder p(x|z).

        x: [B,X]
        z: [P,B,Z]
        output: [P,B]
        '''

        z_reshaped = tf.reshape(
            z, [self.n_z_particles * self.batch_size, self.z_size])  #[PB,Z]
        x_mean = p_xlz.feedforward(z_reshaped)  #[PB,X]
        x_mean = tf.reshape(
            x_mean,
            [self.n_z_particles, self.batch_size, self.x_size])  #[P,B,X]

        log_px = log_bern(x, x_mean)  #[P,B]

        return log_px
Exemple #14
0
    def _log_px(self, p_xlz, x, z, k):
        '''
        x: [B,X]
        z: [P,B,Z]
        output: [P,B] — the Bernoulli log-likelihood log p(x|z) plus
        the standard-normal prior term log p(z).
        '''

        flat = tf.reshape(z, [k * self.batch_size, self.z_size])  #[PB,Z]
        x_mean = p_xlz.feedforward(flat)  #[PB,X]
        x_mean = tf.reshape(x_mean, [k, self.batch_size, self.x_size])

        # Bernoulli likelihood  [P,B]
        log_px = log_bern(x, x_mean)

        # Prior N(0, I) on z: zero mean, log-variance log(1) = 0.
        log_pz = log_norm(z,
                          tf.zeros([self.batch_size, self.z_size]),
                          tf.log(tf.ones([self.batch_size, self.z_size])))

        return log_px + log_pz
Exemple #15
0
    def log_probs(self, x, encoder, decoder):
        '''
        Collect log-probabilities over S = n_W_particles weight samples.

        x: [B,X]
        Returns [log_px, log_pz, log_qz, log_pW, log_qW] with shapes
        [S,P,B] for the first three and [S] for the weight terms.
        '''

        log_px_list = []
        log_pz_list = []
        log_qz_list = []
        log_pW_list = []
        log_qW_list = []

        for W_i in range(self.n_W_particles):

            # Sample decoder weights  __, [1], [1]
            W, log_pW, log_qW = decoder.sample_weights()

            # Sample z   [P,B,Z], [P,B], [P,B]
            z, log_pz, log_qz = self.sample_z(x, encoder, decoder, W)
            # z: [PB,Z]  (particles folded into the batch axis)
            z = tf.reshape(z,
                           [self.n_z_particles * self.batch_size, self.z_size])

            # Decode [PB,X]
            y = decoder.feedforward(W, z)
            # y: [P,B,X]
            y = tf.reshape(y,
                           [self.n_z_particles, self.batch_size, self.x_size])

            # Likelihood p(x|z)  [P,B]
            log_px = log_bern(x, y)

            #Store for later
            log_px_list.append(log_px)
            log_pz_list.append(log_pz)
            log_qz_list.append(log_qz)
            log_pW_list.append(log_pW)
            log_qW_list.append(log_qW)

        log_px = tf.stack(log_px_list)  #[S,P,B]
        log_pz = tf.stack(log_pz_list)  #[S,P,B]
        log_qz = tf.stack(log_qz_list)  #[S,P,B]
        log_pW = tf.stack(log_pW_list)  #[S]
        log_qW = tf.stack(log_qW_list)  #[S]

        return [log_px, log_pz, log_qz, log_pW, log_qW]
Exemple #16
0
    def log_probs(self, x, encoder, decoder):
        '''
        Collect log-probabilities over S = n_W_particles weight samples.

        x: [B,X]
        Returns [log_px, log_pz, log_qz, log_pW, log_qW] with shapes
        [S,P,B], [S,P,B], [S,P,B], [S], [S].
        '''

        px_acc, pz_acc, qz_acc, pW_acc, qW_acc = [], [], [], [], []

        for _ in range(self.n_W_particles):

            # Sample decoder weights  __, [1], [1]
            W, log_pW, log_qW = decoder.sample_weights()

            # Sample z   [P,B,Z], [P,B], [P,B]
            z, log_pz, log_qz = self.sample_z(x, encoder, decoder, W)
            # Fold particles into the batch axis: [PB,Z]
            z_flat = tf.reshape(z, [self.n_z_particles * self.batch_size,
                                    self.z_size])

            # Decode [PB,X] and restore the particle axis: [P,B,X]
            y = tf.reshape(decoder.feedforward(W, z_flat),
                           [self.n_z_particles, self.batch_size, self.x_size])

            # Likelihood p(x|z)  [P,B]
            log_px = log_bern(x, y)

            # Append every term to its accumulator.
            for acc, term in zip((px_acc, pz_acc, qz_acc, pW_acc, qW_acc),
                                 (log_px, log_pz, log_qz, log_pW, log_qW)):
                acc.append(term)

        # Stack: [S,P,B] for the z terms, [S] for the W terms.
        return [tf.stack(acc) for acc in
                (px_acc, pz_acc, qz_acc, pW_acc, qW_acc)]
Exemple #17
0
        def foo(i, log_pxi, log_pzi, log_qzi, log_pWi, log_qWi):
            '''
            Loop body (presumably for tf.while_loop — TODO confirm caller):
            draws one decoder-weight sample, scores a z sample under the
            configured likelihood, and concatenates the new log-probs
            onto the accumulators.

            Returns [i+1, log_px, log_pz, log_qz, log_pW, log_qW].
            '''

            # Sample decoder weights  __, [1], [1]
            W, log_pW, log_qW = decoder.sample_weights()

            # Sample z   [P,B,Z], [P,B], [P,B]
            z, log_pz, log_qz = sample_z.sample_z(x, encoder, decoder, W)
            z = tf.reshape(z, [self.n_z_particles*self.batch_size, self.z_size]) #[PB,Z]

            # Decode [PB,X]
            y = decoder.feedforward(W, z)
            

            # Likelihood p(x|z)  [P,B]
            if self.likelihood == 'Bernoulli':
                y = tf.reshape(y, [self.n_z_particles, self.batch_size, self.x_size]) #[P,B,X]
                log_px = log_bern(x,y)

            elif self.likelihood == 'Gaussian':
                # NOTE(review): original comment said [PB,Z]; split of the
                # decoded y is presumably [PB,X]-shaped halves — confirm.
                mean, logvar = split_mean_logvar(y)
                x_ = tf.reshape(x, [1, self.batch_size, self.x_size])

                log_px = log_norm(x_, mean, logvar)


            # Reshape and concat results
            # The scalar W terms become rank-1 so they can be concatenated.
            log_pW = tf.reshape(log_pW, [1])
            log_qW = tf.reshape(log_qW, [1])
            log_pz = tf.reshape(log_pz, [1, self.n_z_particles, self.batch_size])
            log_qz = tf.reshape(log_qz, [1, self.n_z_particles, self.batch_size])
            log_px = tf.reshape(log_px, [1, self.n_z_particles, self.batch_size])

            log_px = tf.concat([log_pxi, log_px], axis=0)
            log_pz = tf.concat([log_pzi, log_pz], axis=0)
            log_qz = tf.concat([log_qzi, log_qz], axis=0)
            log_pW = tf.concat([log_pWi, log_pW], axis=0)
            log_qW = tf.concat([log_qWi, log_qW], axis=0)

            return [i+1, log_px, log_pz, log_qz, log_pW, log_qW]
        def foo(i, log_pxi, log_pzi, log_qzi, log_pWi, log_qWi):
            # One W-sample step: draw decoder weights, score one z sample,
            # and extend each accumulator by a single entry.

            # Sample decoder weights; reshape the scalars to [1] for concat.
            W, pW_step, qW_step = decoder.sample_weights()
            pW_step = tf.reshape(pW_step, [1])
            qW_step = tf.reshape(qW_step, [1])

            # Sample z   [P,B,Z], [P,B], [P,B]
            z, pz_step, qz_step = sample_z.sample_z(x, encoder, decoder, W)
            pz_step = tf.reshape(pz_step, [1, self.n_z_particles, self.batch_size])
            qz_step = tf.reshape(qz_step, [1, self.n_z_particles, self.batch_size])
            # Fold particles into the batch axis: [PB,Z]
            z = tf.reshape(z, [self.n_z_particles * self.batch_size, self.z_size])

            # Decode [PB,X] and restore the particle axis: [P,B,X]
            y = decoder.feedforward(W, z)
            y = tf.reshape(y, [self.n_z_particles, self.batch_size, self.x_size])

            # Bernoulli likelihood p(x|z): [P,B] -> [1,P,B]
            px_step = tf.reshape(log_bern(x, y),
                                 [1, self.n_z_particles, self.batch_size])

            # Append this step's terms onto the running accumulators.
            return [i + 1,
                    tf.concat([log_pxi, px_step], axis=0),
                    tf.concat([log_pzi, pz_step], axis=0),
                    tf.concat([log_qzi, qz_step], axis=0),
                    tf.concat([log_pWi, pW_step], axis=0),
                    tf.concat([log_qWi, qW_step], axis=0)]
Exemple #19
0
    def iwae_objective(self, x):
        '''
        Returns scalar to maximize
        x: [B,X]

        IWAE-style bound: a numerically-stable log-mean-exp over the
        S weight samples and P z particles of the per-sample ELBO terms
        (each already averaged over the batch).

        Side effects: stores the scalar means self.log_px/log_pz/log_qz/
        log_pW/log_qW and self.z_elbo for printing.
        '''

        encoder = NN(self.encoder_net, self.encoder_act_func, self.batch_size)
        decoder = BNN(self.decoder_net, self.decoder_act_func, self.batch_size)

        log_px_list = []
        log_pz_list = []
        log_qz_list = []
        log_pW_list = []
        log_qW_list = []

        for W_i in range(self.n_W_particles):

            # Sample decoder weights  __, [1], [1]
            W, log_pW, log_qW = decoder.sample_weights()

            # Sample z   [P,B,Z], [P,B], [P,B]
            z, log_pz, log_qz = self.sample_z(x, encoder, decoder, W)
            # z: [PB,Z]  (particles folded into the batch axis)
            z = tf.reshape(z, [self.n_z_particles*self.batch_size, self.z_size])

            # Decode [PB,X]
            y = decoder.feedforward(W, z)
            # y: [P,B,X]
            y = tf.reshape(y, [self.n_z_particles, self.batch_size, self.x_size])

            # Likelihood p(x|z)  [P,B]
            log_px = log_bern(x,y)

            #Store for later
            log_px_list.append(log_px)
            log_pz_list.append(log_pz)
            log_qz_list.append(log_qz)
            log_pW_list.append(log_pW)
            log_qW_list.append(log_qW)


        log_px = tf.stack(log_px_list) #[S,P,B]
        log_pz = tf.stack(log_pz_list) #[S,P,B]
        log_qz = tf.stack(log_qz_list) #[S,P,B]
        log_pW = tf.stack(log_pW_list) #[S]
        log_qW = tf.stack(log_qW_list) #[S]

        # Calculte log probs for printing
        self.log_px = tf.reduce_mean(log_px)
        self.log_pz = tf.reduce_mean(log_pz)
        self.log_qz = tf.reduce_mean(log_qz)
        self.log_pW = tf.reduce_mean(log_pW)
        self.log_qW = tf.reduce_mean(log_qW)
        self.z_elbo = self.log_px + self.log_pz - self.log_qz 

        # Log mean exp over S and P, mean over B
        temp_elbo = tf.reduce_mean(log_px + log_pz - log_qz, axis=2)   #[S,P]
        log_pW = tf.reshape(log_pW, [self.n_W_particles, 1]) #[S,1]
        log_qW = tf.reshape(log_qW, [self.n_W_particles, 1]) #[S,1]
        temp_elbo = temp_elbo + (self.batch_frac*(log_pW - log_qW)) #broadcast, [S,P]
        temp_elbo = tf.reshape(temp_elbo, [self.n_W_particles*self.n_z_particles]) #[SP]
        # Subtract the max before exponentiating (log-sum-exp trick).
        max_ = tf.reduce_max(temp_elbo, axis=0) #[1]
        iwae_elbo = tf.log(tf.reduce_mean(tf.exp(temp_elbo-max_))) + max_  #[1]

        return iwae_elbo
Exemple #20
0
    def log_probs(self, x, q_zlx, q_zlxz, r_zlxz, p_xlz):
        '''
        Auxiliary-variable ELBO for a chain of T = n_transitions
        transitions.

        x: [B,X]
        Returns the batch mean of
        log p(x|z_T) + log p(z_T) + sum_t log r(z_{t-1})
        - log q(z_0) - sum_t log q(z_t).
        '''

        log_qz_list = []
        log_rz_list = []

        # Sample z0 and calc log_qz0
        z0_mean, z0_logvar = self.split_mean_logvar(q_zlx.feedforward(x)) #[B,Z]
        z0 = self.sample_Gaussian(z0_mean, z0_logvar) #[B,Z]
        log_qz0 = log_norm(z0, z0_mean, z0_logvar) #[B]

        # BUG FIX: the loop previously referenced z_minus1 (undefined)
        # and read z before its first assignment on iteration 0.  Track
        # the previous sample explicitly, starting the chain at z0.
        z = z0

        for t in range(self.n_transitions):

            # Reverse model scores the previous sample  [B]
            zt_mean, zt_logvar = self.split_mean_logvar(r_zlxz.feedforward(x)) #[B,Z]
            log_rzt = log_norm(z, zt_mean, zt_logvar) #[B]
            log_rz_list.append(log_rzt)

            # New sample conditioned on x and the previous z
            xz = tf.concat([x, z], axis=1) #[B,X+Z]
            z_mean, z_logvar = self.split_mean_logvar(q_zlxz.feedforward(xz)) #[B,Z]
            z = self.sample_Gaussian(z_mean, z_logvar) #[B,Z]
            log_qz = log_norm(z, z_mean, z_logvar) #[B]
            log_qz_list.append(log_qz)

        log_rzs = tf.stack(log_rz_list) #[T,B]
        log_qzs = tf.stack(log_qz_list) #[T,B]

        # Standard-normal prior on the final sample  [B]
        log_pz = log_norm(z, tf.zeros([self.batch_size, self.z_size]),
                          tf.log(tf.ones([self.batch_size, self.z_size])))

        # Decode [B,X] and score the Bernoulli likelihood  [B]
        x_mean = p_xlz.feedforward(z)
        log_px = log_bern(x, x_mean)

        #[B]
        elbo = (log_px + log_pz + tf.reduce_sum(log_rzs, axis=0)
                - log_qz0 - tf.reduce_sum(log_qzs, axis=0))

        return tf.reduce_mean(elbo) #over batch