Example #1
 def intermediate_dist(t, z, mean, logvar, zeros, batch):
     logp1 = lognormal(z, mean, logvar)  #[P,B]
     log_prior = lognormal(z, zeros, zeros)  #[P,B]
     log_likelihood = log_bernoulli(model.decode(z), batch)
     logpT = log_prior + log_likelihood
     log_intermediate_2 = (1 - float(t)) * logp1 + float(t) * logpT
     return log_intermediate_2
Example #2
 def intermediate_dist(t, z, mean, logvar, zeros, batch):
     logp1 = lognormal(z, mean, logvar)  #[P,B]
     log_prior = lognormal(z, zeros, zeros)  #[P,B]
     log_likelihood = log_bernoulli(model.decode(z), batch)
     logpT = log_prior + log_likelihood
     log_intermediate_2 = (1-float(t))*logp1 + float(t)*logpT
     return log_intermediate_2
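The examples above (and most snippets that follow) call two helper functions not shown on this page: `lognormal` and `log_bernoulli`. A minimal sketch of what they are assumed to compute, based on the shape comments in the code ([P,B,Z] samples reduced to [P,B] log-densities); the actual helpers in the source repository may differ in details such as numerical stabilization:

    import math
    import torch
    import torch.nn.functional as F

    def lognormal(z, mean, logvar):
        # Assumed diagonal-Gaussian log-density, summed over the last (latent)
        # dimension: z is [P,B,Z], mean/logvar broadcast against it, result is [P,B].
        return -0.5 * torch.sum(
            logvar + math.log(2 * math.pi)
            + (z - mean) ** 2 / torch.exp(logvar), -1)

    def log_bernoulli(logits, x):
        # Assumed Bernoulli log-likelihood of binary x under decoder logits,
        # summed over the observation dimension: logits [P,B,X], x [B,X] -> [P,B].
        return torch.sum(x * logits - F.softplus(logits), -1)

The interpolation `(1 - t) * logp1 + t * logpT` is the standard geometric bridge used in annealed importance sampling: at t = 0 it is the Gaussian defined by `mean` and `logvar`, and at t = 1 it is the unnormalized posterior p(z)p(x|z).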
Example #3
 def sample(self, mu, logvar, k):
     eps = Variable(torch.FloatTensor(k, self.B, self.z_size).normal_()) #[P,B,Z]
     z = eps.mul(torch.exp(.5*logvar)) + mu  #[P,B,Z]
     logpz = lognormal(z, Variable(torch.zeros(self.B, self.z_size)), 
                         Variable(torch.zeros(self.B, self.z_size)))  #[P,B]
     logqz = lognormal(z, mu, logvar)
     return z, logpz, logqz
Example #4
 def sample(self, mu, logvar, k):
     eps = Variable(torch.FloatTensor(k, self.B,
                                      self.z_size).normal_())  #[P,B,Z]
     z = eps.mul(torch.exp(.5 * logvar)) + mu  #[P,B,Z]
     logpz = lognormal(z, Variable(torch.zeros(self.B, self.z_size)),
                       Variable(torch.zeros(self.B, self.z_size)))  #[P,B]
     logqz = lognormal(z, mu, logvar)
     return z, logpz, logqz
Example #5
    def sample_z(self, mu, logvar, k):
        B = mu.size()[0]
        eps = Variable(torch.FloatTensor(k, B, self.z_size).normal_().type(self.dtype)) #[P,B,Z]
        z = eps.mul(torch.exp(.5*logvar)) + mu  #[P,B,Z]
        logpz = lognormal(z, Variable(torch.zeros(B, self.z_size).type(self.dtype)), 
                            Variable(torch.zeros(B, self.z_size)).type(self.dtype))  #[P,B]

        logqz = lognormal(z, mu, logvar)
        return z, logpz, logqz
Example #6
    def sample_z(self, mu, logvar, k):
        B = mu.size()[0]
        eps = Variable(torch.FloatTensor(k, B, self.z_size).normal_().type(self.dtype)) #[P,B,Z]
        z = eps.mul(torch.exp(.5*logvar)) + mu  #[P,B,Z]
        logpz = lognormal(z, Variable(torch.zeros(B, self.z_size).type(self.dtype)), 
                            Variable(torch.zeros(B, self.z_size)).type(self.dtype))  #[P,B]

        logqz = lognormal(z, mu, logvar)
        return z, logpz, logqz
Example #7
    def sample(self, mu, logvar):
        std = logvar.mul(0.5).exp_()
        eps = Variable(torch.FloatTensor(std.size()).normal_())
        z = eps.mul(std).add_(mu)
        logpz = lognormal(z, Variable(torch.zeros(z.size())),
                          Variable(torch.zeros(z.size())))
        # logpz = self.lognormal(z, torch.zeros(z.size()), torch.zeros(z.size()))
        logqz = lognormal(z, mu, logvar)

        return z, logpz, logqz
Example #8
    def sample_weights(self):

        Ws = []

        log_p_W_sum = 0
        log_q_W_sum = 0

        for layer_i in range(len(self.net) - 1):

            input_size_i = self.net[layer_i] + 1  #plus 1 for bias
            output_size_i = self.net[layer_i +
                                     1]  #plus 1 because we want layer i+1

            #Get vars [I,O]
            W_means = self.W_means[layer_i]
            W_logvars = self.W_logvars[layer_i]

            #Sample weights [IS,OS]*[IS,OS]=[IS,OS]
            eps = Variable(
                torch.randn(input_size_i, output_size_i).type(self.dtype))
            # print eps
            # print torch.sqrt(torch.exp(W_logvars))
            # W = torch.add(W_means, torch.sqrt(torch.exp(W_logvars)) * eps)
            W = (torch.sqrt(torch.exp(W_logvars)) * eps) + W_means

            # W = W_means

            #Compute probs of samples  [1]
            flat_w = W.view(input_size_i * output_size_i)  #[IS*OS]
            flat_W_means = W_means.view(input_size_i * output_size_i)  #[IS*OS]
            flat_W_logvars = W_logvars.view(input_size_i *
                                            output_size_i)  #[IS*OS]
            log_p_W_sum += lognormal(
                flat_w,
                Variable(
                    torch.zeros([input_size_i * output_size_i
                                 ]).type(self.dtype)),
                Variable(
                    torch.log(
                        torch.ones([input_size_i * output_size_i
                                    ]).type(self.dtype))))
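            # Note: torch.log(torch.ones(n)) is all zeros, so this prior over the
            # flattened weights is a standard normal N(0, I).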
            # log_p_W_sum += log_normal3(flat_w, tf.zeros([input_size_i*output_size_i]), tf.log(tf.ones([input_size_i*output_size_i])*100.))

            log_q_W_sum += lognormal(flat_w, flat_W_means, flat_W_logvars)

            Ws.append(W)

        return Ws, log_p_W_sum, log_q_W_sum
Example #9
    def forward(self, k, x, logposterior):
        '''
        k: number of samples
        x: [B,X]
        logposterior(z) -> [P,B]
        '''

        self.B = x.size()[0]

        #Encode
        out = x
        for i in range(len(self.encoder_weights) - 1):
            out = self.act_func(self.encoder_weights[i](out))
            # out = self.act_func(self.layer_norms[i].forward(self.encoder_weights[i](out)))

        out = self.encoder_weights[-1](out)
        mean = out[:, :self.z_size]
        logvar = out[:, self.z_size:]

        #Sample
        eps = Variable(
            torch.FloatTensor(k, self.B, self.z_size).normal_().type(
                self.dtype))  #[P,B,Z]
        z = eps.mul(torch.exp(.5 * logvar)) + mean  #[P,B,Z]
        logqz = lognormal(z, mean, logvar)  #[P,B]

        return z, logqz
Example #10
 def return_current_state(self, x, a, k):
     
     self.B = x.size()[0]
     self.T = x.size()[1]
     self.k = k
     a = a.float()
     x = x.float()
     states = []
     prev_z = Variable(torch.zeros(k, self.B, self.z_size))
     # prev_z = torch.zeros(k, self.B, self.z_size)
     for t in range(self.T):
         current_x = x[:,t] #[B,X]
         current_a = a[:,t] #[B,A]
         #Encode
         mu, logvar = self.encode(current_x, current_a, prev_z) #[P,B,Z]
         #Sample
         z, logqz = self.sample(mu, logvar) #[P,B,Z], [P,B]
         #Decode
         x_hat = self.decode(z)  #[P,B,X]
         logpx = log_bernoulli(x_hat, current_x)  #[P,B]
         #Transition/Prior prob
         prior_mean, prior_log_var = self.transition_prior(prev_z, current_a) #[P,B,Z]
         logpz = lognormal(z, prior_mean, prior_log_var) #[P,B]
         prev_z = z
         states.append(z)
     return states
Example #11
    def forward(self, k, x, logposterior):
        '''
        k: number of samples
        x: [B,X]
        logposterior(z) -> [P,B]
        '''

        self.B = x.size()[0]
        self.P = k

        #Encode
        out = self.act_func(self.fc1(x))
        out = self.act_func(self.fc2(out))
        out = self.fc3(out)
        mean = out[:, :self.z_size]
        logvar = out[:, self.z_size:]

        #Sample
        eps = Variable(
            torch.FloatTensor(k, self.B, self.z_size).normal_().type(
                self.dtype))  #[P,B,Z]
        z = eps.mul(torch.exp(.5 * logvar)) + mean  #[P,B,Z]
        logqz = lognormal(z, mean, logvar)  #[P,B]

        logdetsum = 0.
        for i in range(self.n_flows):

            z, logdet = self.norm_flow(self.params[i], z)
            logdetsum += logdet

        return z, logqz - logdetsum
Example #12
    def forward(self, x, k, warmup=1.):

        self.B = x.size()[0]  #batch size
        self.zeros = Variable(
            torch.zeros(self.B, self.z_size).type(self.dtype))

        self.logposterior = lambda aa: lognormal(
            aa, self.zeros, self.zeros) + log_bernoulli(self.decode(aa), x)

        z, logqz = self.q_dist.forward(k, x, self.logposterior)

        logpxz = self.logposterior(z)

        #Compute elbo
        elbo = logpxz - (warmup * logqz)  #[P,B]
        if k > 1:
            max_ = torch.max(elbo, 0)[0]  #[B]
            elbo = torch.log(torch.mean(torch.exp(elbo - max_),
                                        0)) + max_  #[B]

        elbo = torch.mean(elbo)  #[1]
        logpxz = torch.mean(logpxz)  #[1]
        logqz = torch.mean(logqz)

        return elbo, logpxz, logqz
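When k > 1, the two lines under `#Compute elbo` are a numerically stable log-mean-exp over the particle dimension, i.e. the importance-weighted (IWAE-style) bound rather than the plain ELBO. A small stand-alone sketch of the same computation, with `logmeanexp` as a hypothetical helper name:

    import torch

    def logmeanexp(logw, dim=0):
        # Stable log(mean(exp(logw))) along `dim`; subtracting the max before
        # exponentiating avoids overflow, exactly as in the example above.
        max_ = torch.max(logw, dim, keepdim=True)[0]
        return (torch.log(torch.mean(torch.exp(logw - max_), dim, keepdim=True))
                + max_).squeeze(dim)

    # elbo = logmeanexp(logpxz - warmup * logqz, dim=0)  # [B]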
Example #13
    def forward(self, k, x, logposterior):
        '''
        k: number of samples
        x: [B,X]
        logposterior(z) -> [P,B]
        '''

        self.B = x.size()[0]
        self.P = k

        #Encode
        out = x
        for i in range(len(self.encoder_weights)-1):
            out = self.act_func(self.encoder_weights[i](out))
        out = self.encoder_weights[-1](out)
        mean = out[:,:self.z_size]
        logvar = out[:,self.z_size:]

        #Sample
        eps = Variable(torch.FloatTensor(k, self.B, self.z_size).normal_().type(self.dtype)) #[P,B,Z]
        z = eps.mul(torch.exp(.5*logvar)) + mean  #[P,B,Z]
        logqz = lognormal(z, mean, logvar) #[P,B]

        logdetsum = 0.
        for i in range(self.n_flows):

            z, logdet = self.norm_flow(self.params[i],z)
            logdetsum += logdet


        return z, logqz-logdetsum
Example #14
    def sample(self, mu, logvar, k):
        B = mu.size()[0]


        eps = Variable(torch.FloatTensor(k, B, self.z_size).normal_().type(self.dtype)) #[P,B,Z]
        z = eps.mul(torch.exp(.5*logvar)) + mu  #[P,B,Z]
        logqz = lognormal(z, mu, logvar) #[P,B]

        #[P,B,Z], [P,B]
        if self.flow_bool:
            z, logdet = self.q_dist.forward(z)
            logqz = logqz - logdet

        logpz = lognormal(z, Variable(torch.zeros(B, self.z_size).type(self.dtype)), 
                            Variable(torch.zeros(B, self.z_size)).type(self.dtype))  #[P,B]

        return z, logpz, logqz
Example #15
    def logposterior_func(self, x, z):
        self.B = x.size()[0] #batch size
        self.zeros = Variable(torch.zeros(self.B, self.z_size).type(self.dtype))

        # print (x)  #[B,X]
        # print(z)    #[P,Z]
        z = Variable(z).type(self.dtype)
        z = z.view(-1,self.B,self.z_size)
        return lognormal(z, self.zeros, self.zeros) + log_bernoulli(self.generator.decode(z), x)
Example #16
    def logprob(self, z, mean, logvar):

        # self.B = mean.size()[0]

        # eps = Variable(torch.FloatTensor(k, self.B, self.z_size).normal_().type(self.dtype)) #[P,B,Z]
        # z = eps.mul(torch.exp(.5*logvar)) + mean  #[P,B,Z]
        logqz = lognormal(z, mean, logvar)  #[P,B]

        return logqz
Example #17
    def sample(self, mean, logvar, k):

        self.B = mean.size()[0]

        eps = Variable(torch.FloatTensor(k, self.B, self.z_size).normal_().type(self.dtype)) #[P,B,Z]
        z = eps.mul(torch.exp(.5*logvar)) + mean  #[P,B,Z]
        logqz = lognormal(z, mean, logvar) #[P,B]

        return z, logqz
Example #18
    def logposterior_func(self, x, z):
        self.B = x.size()[0] #batch size
        self.zeros = Variable(torch.zeros(self.B, self.z_size).type(self.dtype))

        # print (x)  #[B,X]
        # print(z)    #[P,Z]
        z = Variable(z).type(self.dtype)
        z = z.view(-1,self.B,self.z_size)
        return lognormal(z, self.zeros, self.zeros) + log_bernoulli(self.decode(z), x)
Example #19
    def logprob(self, z, mean, logvar):

        # self.B = mean.size()[0]

        # eps = Variable(torch.FloatTensor(k, self.B, self.z_size).normal_().type(self.dtype)) #[P,B,Z]
        # z = eps.mul(torch.exp(.5*logvar)) + mean  #[P,B,Z]
        logqz = lognormal(z, mean, logvar) #[P,B]

        return logqz
Example #20
    def mh_step(z0, v0, z, v, step_size, intermediate_dist_func):

        logpv0 = lognormal(v0, zeros, zeros)  #[P,B]
        hamil_0 = intermediate_dist_func(z0) + logpv0

        logpvT = lognormal(v, zeros, zeros)  #[P,B]
        hamil_T = intermediate_dist_func(z) + logpvT

        accept_prob = torch.exp(hamil_T - hamil_0)

        if torch.cuda.is_available():
            rand_uni = Variable(torch.FloatTensor(
                accept_prob.size()).uniform_(),
                                volatile=volatile_,
                                requires_grad=requires_grad).cuda()
        else:
            rand_uni = Variable(
                torch.FloatTensor(accept_prob.size()).uniform_())

        accept = accept_prob > rand_uni

        if torch.cuda.is_available():
            accept = accept.type(torch.FloatTensor).cuda()
        else:
            accept = accept.type(torch.FloatTensor)

        accept = accept.view(k, model.B, 1)

        z = (accept * z) + ((1 - accept) * z0)

        #Adapt step size
        avg_acceptance_rate = torch.mean(accept)

        if avg_acceptance_rate.cpu().data.numpy() > .7:
            step_size = 1.02 * step_size
        else:
            step_size = .98 * step_size

        if step_size < 0.0001:
            step_size = 0.0001
        if step_size > 0.5:
            step_size = 0.5

        return z, step_size
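`mh_step` accepts or rejects a proposed (z, v) against the starting (z0, v0) using the Hamiltonian (log target plus Gaussian kinetic energy), then adapts the step size toward a target acceptance rate; note that `zeros`, `k`, `model`, `volatile_`, and `requires_grad` are free variables captured from the enclosing scope. The proposal itself is produced elsewhere; a hypothetical leapfrog integrator of the kind that would generate it, written with current autograd calls rather than the old Variable API, might look like:

    import torch

    def leapfrog(z, v, step_size, logprob_func, n_steps=5):
        # Hypothetical sketch: standard leapfrog updates under the potential
        # U(z) = -logprob_func(z); logprob_func returns a [P,B] log-density.
        z = z.detach().requires_grad_(True)
        for _ in range(n_steps):
            grad = torch.autograd.grad(logprob_func(z).sum(), z)[0]
            v = v + 0.5 * step_size * grad                           # momentum half-step
            z = (z + step_size * v).detach().requires_grad_(True)    # full position step
            grad = torch.autograd.grad(logprob_func(z).sum(), z)[0]
            v = v + 0.5 * step_size * grad                           # second momentum half-step
        return z.detach(), v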
Example #21
    def sample_q(self, x, k):

        self.B = x.size()[0] #batch size
        self.zeros = Variable(torch.zeros(self.B, self.z_size).type(self.dtype))

        self.logposterior = lambda aa: lognormal(aa, self.zeros, self.zeros) + log_bernoulli(self.decode(aa), x)

        z, logqz = self.q_dist.forward(k=k, x=x, logposterior=self.logposterior)

        return z
Example #22
    def sample(self, mu, logvar, k):
        
        # if torch.cuda.is_available():
        #     eps = Variable(torch.FloatTensor(k, self.B, self.z_size).normal_()).cuda() #[P,B,Z]
        # else:
        eps = Variable(torch.FloatTensor(k, self.B, self.z_size).normal_().type(self.dtype)) #[P,B,Z]

        z = eps.mul(torch.exp(.5*logvar)) + mu  #[P,B,Z]

        # if torch.cuda.is_available():
        #     logpz = lognormal(z, Variable(torch.zeros(self.B, self.z_size).cuda()), 
        #                     Variable(torch.zeros(self.B, self.z_size)).cuda())  #[P,B]
        # else:
        logpz = lognormal(z, Variable(torch.zeros(self.B, self.z_size).type(self.dtype)), 
                            Variable(torch.zeros(self.B, self.z_size)).type(self.dtype))  #[P,B]


        logqz = lognormal(z, mu, logvar)
        return z, logpz, logqz
Example #23
    def sample(self, mu, logvar, k):
        B = mu.size()[0]

        eps = Variable(
            torch.FloatTensor(k, B, self.z_size).normal_().type(
                self.dtype))  #[P,B,Z]
        z = eps.mul(torch.exp(.5 * logvar)) + mu  #[P,B,Z]
        logqz = lognormal(z, mu, logvar)  #[P,B]

        #[P,B,Z], [P,B]
        if self.flow_bool:
            z, logdet = self.q_dist.forward(z)
            logqz = logqz - logdet

        logpz = lognormal(
            z, Variable(torch.zeros(B, self.z_size).type(self.dtype)),
            Variable(torch.zeros(B, self.z_size)).type(self.dtype))  #[P,B]

        return z, logpz, logqz
Example #24
    def sample_q(self, x, k):

        self.B = x.size()[0] #batch size
        self.zeros = Variable(torch.zeros(self.B, self.z_size).type(self.dtype))

        self.logposterior = lambda aa: lognormal(aa, self.zeros, self.zeros) + log_bernoulli(self.generator.decode(aa), x)

        z, logqz = self.q_dist.forward(k=k, x=x, logposterior=self.logposterior)

        return z
Example #25
    def sample(self, mean, logvar, k):

        self.B = mean.size()[0]

        eps = Variable(
            torch.FloatTensor(k, self.B, self.z_size).normal_().type(
                self.dtype))  #[P,B,Z]
        z = eps.mul(torch.exp(.5 * logvar)) + mean  #[P,B,Z]
        logqz = lognormal(z, mean, logvar)  #[P,B]

        return z, logqz
Example #26
    def mh_step(z0, v0, z, v, step_size, intermediate_dist_func):

        logpv0 = lognormal(v0, zeros, zeros) #[P,B]
        hamil_0 =  intermediate_dist_func(z0) + logpv0
        
        logpvT = lognormal(v, zeros, zeros) #[P,B]
        hamil_T = intermediate_dist_func(z) + logpvT

        accept_prob = torch.exp(hamil_T - hamil_0)

        if torch.cuda.is_available():
            rand_uni = Variable(torch.FloatTensor(accept_prob.size()).uniform_(), volatile=volatile_, requires_grad=requires_grad).cuda()
        else:
            rand_uni = Variable(torch.FloatTensor(accept_prob.size()).uniform_())

        accept = accept_prob > rand_uni

        if torch.cuda.is_available():
            accept = accept.type(torch.FloatTensor).cuda()
        else:
            accept = accept.type(torch.FloatTensor)
        
        accept = accept.view(k, model.B, 1)

        z = (accept * z) + ((1-accept) * z0)

        #Adapt step size
        avg_acceptance_rate = torch.mean(accept)

        if avg_acceptance_rate.cpu().data.numpy() > .65:
            step_size = 1.02 * step_size
        else:
            step_size = .98 * step_size

        if step_size < 0.0001:
            step_size = 0.0001
        if step_size > 0.5:
            step_size = 0.5

        return z, step_size
Example #27
    def sample(self, mu, logvar, k):

        # if torch.cuda.is_available():
        #     eps = Variable(torch.FloatTensor(k, self.B, self.z_size).normal_()).cuda() #[P,B,Z]
        # else:
        eps = Variable(
            torch.FloatTensor(k, self.B, self.z_size).normal_().type(
                self.dtype))  #[P,B,Z]

        z = eps.mul(torch.exp(.5 * logvar)) + mu  #[P,B,Z]

        # if torch.cuda.is_available():
        #     logpz = lognormal(z, Variable(torch.zeros(self.B, self.z_size).cuda()),
        #                     Variable(torch.zeros(self.B, self.z_size)).cuda())  #[P,B]
        # else:
        logpz = lognormal(
            z, Variable(torch.zeros(self.B, self.z_size).type(self.dtype)),
            Variable(torch.zeros(self.B,
                                 self.z_size)).type(self.dtype))  #[P,B]

        logqz = lognormal(z, mu, logvar)
        return z, logpz, logqz
Example #28
    def sample_weights(self):

        Ws = []

        log_p_W_sum = 0
        log_q_W_sum = 0

        for layer_i in range(len(self.net)-1):

            input_size_i = self.net[layer_i]+1 #plus 1 for bias
            output_size_i = self.net[layer_i+1] #plus 1 because we want layer i+1

            #Get vars [I,O]
            W_means = self.W_means[layer_i]
            W_logvars = self.W_logvars[layer_i]

            #Sample weights [IS,OS]*[IS,OS]=[IS,OS]
            eps = Variable(torch.randn(input_size_i, output_size_i).type(self.dtype))
            # print eps
            # print torch.sqrt(torch.exp(W_logvars))
            # W = torch.add(W_means, torch.sqrt(torch.exp(W_logvars)) * eps)
            W =  (torch.sqrt(torch.exp(W_logvars)) * eps) + W_means 


            # W = W_means

            #Compute probs of samples  [1]
            flat_w = W.view(input_size_i*output_size_i) #[IS*OS]
            flat_W_means = W_means.view(input_size_i*output_size_i) #[IS*OS]
            flat_W_logvars = W_logvars.view(input_size_i*output_size_i) #[IS*OS]
            log_p_W_sum += lognormal(flat_w, Variable(torch.zeros([input_size_i*output_size_i]).type(self.dtype)), Variable(torch.log(torch.ones([input_size_i*output_size_i]).type(self.dtype))))
            # log_p_W_sum += log_normal3(flat_w, tf.zeros([input_size_i*output_size_i]), tf.log(tf.ones([input_size_i*output_size_i])*100.))

            log_q_W_sum += lognormal(flat_w, flat_W_means, flat_W_logvars)

            Ws.append(W)

        return Ws, log_p_W_sum, log_q_W_sum
Example #29
    def forward(self, x, k=1, warmup=1.):

        self.B = x.size()[0]  #batch size
        self.zeros = Variable(
            torch.zeros(self.B, self.z_size).type(self.dtype))  #[B,Z]

        self.logposterior = lambda aa: lognormal(
            aa, self.zeros, self.zeros) + log_bernoulli(self.decode(aa), x)

        z, logqz = self.q_dist.forward(k, x, self.logposterior)

        # [PB,Z]
        # z = z.view(-1,self.z_size)

        logpxz = self.logposterior(z)

        #Compute elbo
        elbo = logpxz - logqz  #[P,B]
        if k > 1:
            max_ = torch.max(elbo, 0)[0]  #[B]
            elbo = torch.log(torch.mean(torch.exp(elbo - max_),
                                        0)) + max_  #[B]

        elbo = torch.mean(elbo)  #[1]
        logpxz = torch.mean(logpxz)  #[1]
        logqz = torch.mean(logqz)

        # mu, logvar = self.encode(x)
        # z, logpz, logqz = self.sample(mu, logvar, k=k)  #[P,B,Z]
        # x_hat = self.decode(z)  #[PB,X]
        # x_hat = x_hat.view(k, self.B, -1)
        # logpx = log_bernoulli(x_hat, x)  #[P,B]

        # elbo = logpx +  warmup*(logpz - logqz)  #[P,B]

        # if k>1:
        #     max_ = torch.max(elbo, 0)[0] #[B]
        #     elbo = torch.log(torch.mean(torch.exp(elbo - max_), 0)) + max_ #[B]

        # elbo = torch.mean(elbo) #[1]

        # #for printing
        # logpx = torch.mean(logpx)
        # logpz = torch.mean(logpz)
        # logqz = torch.mean(logqz)
        # self.x_hat_sigmoid = F.sigmoid(x_hat)

        # return elbo, logpx, logpz, logqz

        return elbo, logpxz, logqz
Example #30
    def sample(k):

        P = k

        #Sample
        eps = Variable(torch.FloatTensor(k, B, model.z_size).normal_().type(model.dtype)) #[P,B,Z]
        z = eps.mul(torch.exp(.5*logvar)) + mean  #[P,B,Z]
        logqz = lognormal(z, mean, logvar) #[P,B]

        logdetsum = 0.
        for i in range(n_flows):

            z, logdet = norm_flow(params[i],z)
            logdetsum += logdet

        logq = logqz - logdetsum
        return z, logq
Example #31
    def sample(k):

        P = k

        #Sample
        eps = Variable(torch.FloatTensor(k, B, model.z_size).normal_().type(model.dtype)) #[P,B,Z]
        z = eps.mul(torch.exp(.5*logvar)) + mean  #[P,B,Z]
        logqz = lognormal(z, mean, logvar) #[P,B]

        logdetsum = 0.
        for i in range(n_flows):

            z, logdet = norm_flow(params[i],z)
            logdetsum += logdet

        logq = logqz - logdetsum
        return z, logq
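Examples #30 and #31 transform the Gaussian sample through a chain of normalizing flows via `norm_flow(params[i], z)`, which is not shown on this page. For orientation, a minimal planar-flow sketch with a hypothetical parameter layout (u, w of shape [Z], scalar b); the repository's actual flow may differ:

    import torch

    def planar_norm_flow(params, z):
        # Planar flow (Rezende & Mohamed, 2015): z' = z + u * tanh(w.z + b).
        # z is [P,B,Z]; returns the transformed z and log|det Jacobian| of shape [P,B].
        u, w, b = params
        a = torch.tanh(torch.sum(z * w, -1, keepdim=True) + b)      # [P,B,1]
        z_new = z + u * a                                           # [P,B,Z]
        psi = (1 - a ** 2) * w                                      # [P,B,Z]
        logdet = torch.log(torch.abs(1 + torch.sum(psi * u, -1)))   # [P,B]
        return z_new, logdet

The `logqz - logdetsum` returned afterwards is the change-of-variables correction: the log-density of the flowed sample under q.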
Example #32
    def forward2(self, x, k):

        self.B = x.size()[0] #batch size
        self.zeros = Variable(torch.zeros(self.B, self.z_size).type(self.dtype))

        self.logposterior = lambda aa: lognormal(aa, self.zeros, self.zeros) + log_bernoulli(self.decode(aa), x)

        z, logqz = self.q_dist.forward(k, x, self.logposterior)

        logpxz = self.logposterior(z)

        #Compute elbo
        elbo = logpxz - logqz #[P,B]
        # if k>1:
        #     max_ = torch.max(elbo, 0)[0] #[B]
        #     elbo = torch.log(torch.mean(torch.exp(elbo - max_), 0)) + max_ #[B]
            
        elbo = torch.mean(elbo) #[1]
        logpxz = torch.mean(logpxz) #[1]
        logqz = torch.mean(logqz)

        return elbo, logpxz, logqz
Example #33
    def forward(self, k, x, logposterior):
        '''
        k: number of samples
        x: [B,X]
        logposterior(z) -> [P,B]
        '''

        self.B = x.size()[0]

        #Encode
        out = self.act_func(self.fc1(x))
        out = self.act_func(self.fc2(out))
        out = self.fc3(out)
        mean = out[:,:self.z_size]
        logvar = out[:,self.z_size:]

        #Sample
        eps = Variable(torch.FloatTensor(k, self.B, self.z_size).normal_().type(self.dtype)) #[P,B,Z]
        z = eps.mul(torch.exp(.5*logvar)) + mean  #[P,B,Z]
        logqz = lognormal(z, mean, logvar) #[P,B]

        return z, logqz
Example #34
    def forward(self, k, x, logposterior):
        '''
        k: number of samples
        x: [B,X]
        logposterior(z) -> [P,B]
        '''

        self.B = x.size()[0]

        # #Encode
        # out = x
        # for i in range(len(self.encoder_weights)-1):
        #     out = self.act_func(self.encoder_weights[i](out))
        # out = self.encoder_weights[-1](out)
        # mean = out[:,:self.z_size]
        # logvar = out[:,self.z_size:]

        x = x.view(-1, 3, 32, 32)
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        # print (x)
        # x = x.view(-1, 1960)
        x = x.view(-1, 250)

        h1 = F.relu(self.fc1(x))
        h2 = self.fc2(h1)
        mean = h2[:, :self.z_size]
        logvar = h2[:, self.z_size:]

        #Sample
        eps = Variable(
            torch.FloatTensor(k, self.B, self.z_size).normal_().type(
                self.dtype))  #[P,B,Z]
        z = eps.mul(torch.exp(.5 * logvar)) + mean  #[P,B,Z]
        logqz = lognormal(z, mean, logvar)  #[P,B]

        return z, logqz
Example #35
    def forward(self, k, x, logposterior):
        '''
        k: number of samples
        x: [B,X]
        logposterior(z) -> [P,B]
        '''

        self.B = x.size()[0]

        #Encode
        out = self.act_func(self.fc1(x))
        out = self.act_func(self.fc2(out))
        out = self.fc3(out)
        mean = out[:, :self.z_size]
        logvar = out[:, self.z_size:]

        #Sample
        eps = Variable(
            torch.FloatTensor(k, self.B, self.z_size).normal_().type(
                self.dtype))  #[P,B,Z]
        z = eps.mul(torch.exp(.5 * logvar)) + mean  #[P,B,Z]
        logqz = lognormal(z, mean, logvar)  #[P,B]

        return z, logqz
Example #36
    def forward(self, x, a, k=1, current_state=None):
        '''
        x: [B,T,X]
        a: [B,T,A]
        output: elbo scalar
        '''
        
        self.B = x.size()[0]
        self.T = x.size()[1]
        self.k = k

        a = a.float()
        x = x.float()

        # log_probs = [[] for i in range(k)]
        # log_probs = []
        logpxs = []
        logpzs = []
        logqzs = []


        weights = Variable(torch.ones(k, self.B)/k)
        # if current_state==None:
        prev_z = Variable(torch.zeros(k, self.B, self.z_size))
        # else:
        #     prev_z = current_state
        for t in range(self.T):
            current_x = x[:,t] #[B,X]
            current_a = a[:,t] #[B,A]

            #Encode
            mu, logvar = self.encode(current_x, current_a, prev_z) #[P,B,Z]
            #Sample
            z, logqz = self.sample(mu, logvar) #[P,B,Z], [P,B]
            #Decode
            x_hat = self.decode(z)  #[P,B,X]
            logpx = log_bernoulli(x_hat, current_x)  #[P,B]
            #Transition/Prior prob
            prior_mean, prior_log_var = self.transition_prior(prev_z, current_a) #[P,B,Z]
            logpz = lognormal(z, prior_mean, prior_log_var) #[P,B]






            log_alpha_t = logpx + logpz - logqz #[P,B]
            log_weights_tmp = torch.log(weights * torch.exp(log_alpha_t))

            max_ = torch.max(log_weights_tmp, 0)[0] #[B]
            log_p_hat = torch.log(torch.sum(torch.exp(log_weights_tmp - max_), 0)) + max_ #[B]

            # p_hat = torch.sum(alpha_t,0)  #[B]
            normalized_alpha_t = log_weights_tmp - log_p_hat  #[P,B]

            weights = torch.exp(normalized_alpha_t) #[P,B]

            #if resample
            if t%2==0:
                # print weights
                #[B,P] indices of the particles for each batch
                sampled_indices = torch.multinomial(torch.t(weights), k, replacement=True).detach()
                new_z = []
                for b in range(self.B):
                    tmp = z[:,b] #[P,Z]
                    z_b = tmp[sampled_indices[b]] #[P,Z]
                    new_z.append(z_b)
                new_z = torch.stack(new_z, 1) #[P,B,Z]
                weights = Variable(torch.ones(k, self.B)/k)
                z = new_z

            logpxs.append(logpx)
            logpzs.append(logpz)
            logqzs.append(logqz)
            # log_probs.append(logpx + logpz - logqz)
            prev_z = z



        logpxs = torch.stack(logpxs) 
        logpzs = torch.stack(logpzs)
        logqzs = torch.stack(logqzs) #[T,P,B]

        logws = logpxs + logpzs - logqzs  #[T,P,B]
        logws = torch.mean(logws, 0)  #[P,B]

        # elbo = logpx + logpz - logqz  #[P,B]

        if k>1:
            max_ = torch.max(logws, 0)[0] #[B]
            elbo = torch.log(torch.mean(torch.exp(logws - max_), 0)) + max_ #[B]
            elbo = torch.mean(elbo) #over batch
        else:
            elbo = torch.mean(logws)

        # print log_probs[0]


        # #for printing
        logpx = torch.mean(logpxs)
        logpz = torch.mean(logpzs)
        logqz = torch.mean(logqzs)
        # self.x_hat_sigmoid = F.sigmoid(x_hat)

        # elbo = torch.mean(torch.stack(log_probs)) #[1]
        # elbo = logpx + logpz - logqz

        return elbo, logpx, logpz, logqz
Example #37
def optimize_local_gaussian_mean_logvar2(logposterior, model, x):

    # B = x.shape[0]
    B = x.size()[0]  #batch size
    # input to log posterior is z, [P,B,Z]
    # I think B will be 1 for now

    mean = Variable(torch.zeros(B, model.z_size).type(model.dtype),
                    requires_grad=True)
    logvar = Variable(torch.zeros(B, model.z_size).type(model.dtype),
                      requires_grad=True)

    optimizer = optim.Adam([mean, logvar], lr=.001)
    # time_ = time.time()
    # n_data = len(train_x)
    # arr = np.array(range(n_data))

    P = 50

    last_100 = []
    best_last_100_avg = -1
    consecutive_worse = 0
    for epoch in range(1, 99999):  # 999999):

        if quick:
            # if 1:

            break

        #Sample
        eps = Variable(
            torch.FloatTensor(P, B, model.z_size).normal_().type(
                model.dtype))  #[P,B,Z]
        z = eps.mul(torch.exp(.5 * logvar)) + mean  #[P,B,Z]
        logqz = lognormal(z, mean, logvar)  #[P,B]
        logpx = logposterior(z)

        loss = -(torch.mean(1.5 * logpx - logqz))

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        loss_np = loss.data.cpu().numpy()
        last_100.append(loss_np)
        if epoch % 100 == 0:

            last_100_avg = np.mean(last_100)
            if last_100_avg < best_last_100_avg or best_last_100_avg == -1:
                consecutive_worse = 0
                best_last_100_avg = last_100_avg
            else:
                consecutive_worse += 1
                # print(consecutive_worse)
                if consecutive_worse > 10:
                    # print ('done')
                    break

            print(epoch, last_100_avg, consecutive_worse, mean)
            # print (torch.mean(logpx))

            last_100 = []

        if epoch % 1000 == 0:
            # print (logpx)
            # print (logqz)
            print(torch.mean(logpx))
            print(torch.mean(logqz))
            print(torch.std(logpx))
            print(torch.std(logqz))

    #Round 2

    last_100 = []
    best_last_100_avg = -1
    consecutive_worse = 0
    for epoch in range(1, 99999):  # 999999):

        if quick:
            # if 1:
            break

        #Sample
        eps = Variable(
            torch.FloatTensor(P, B, model.z_size).normal_().type(
                model.dtype))  #[P,B,Z]
        z = eps.mul(torch.exp(.5 * logvar)) + mean  #[P,B,Z]
        logqz = lognormal(z, mean, logvar)  #[P,B]
        logpx = logposterior(z)

        loss = -(torch.mean(logpx - logqz))

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        loss_np = loss.data.cpu().numpy()
        last_100.append(loss_np)
        if epoch % 100 == 0:

            last_100_avg = np.mean(last_100)
            if last_100_avg < best_last_100_avg or best_last_100_avg == -1:
                consecutive_worse = 0
                best_last_100_avg = last_100_avg
            else:
                consecutive_worse += 1
                # print(consecutive_worse)
                if consecutive_worse > 10:
                    # print ('done')
                    break

            print(epoch, last_100_avg, consecutive_worse, mean, '2')
            # print (torch.mean(logpx))

            last_100 = []

        if epoch % 1000 == 0:
            # print (logpx)
            # print (logqz)
            print(torch.mean(logpx))
            print(torch.mean(logqz))
            print(torch.std(logpx))
            print(torch.std(logqz))

    return mean, logvar, z
Example #38
    def forward(self, k, x, logposterior):
        '''
        k: number of samples
        x: [B,X]
        logposterior(z) -> [P,B]
        '''

        self.B = x.size()[0]
        self.P = k

        # print (self.B, 'B')
        # print (k)
        # fsdaf



        #q(v|x)
        out = x
        for i in range(len(self.qv_weights)-1):
            out = self.act_func(self.qv_weights[i](out))
        out = self.qv_weights[-1](out)
        mean = out[:,:self.z_size]
        logvar = out[:,self.z_size:]

        #Sample v0
        eps = Variable(torch.FloatTensor(k, self.B, self.z_size).normal_().type(self.dtype)) #[P,B,Z]
        v = eps.mul(torch.exp(.5*logvar)) + mean  #[P,B,Z]
        logqv0 = lognormal(v, mean, logvar) #[P,B]

        #[PB,Z]
        v = v.view(-1,self.z_size)
        #[PB,X]
        x_tiled = x.repeat(k,1)
        #[PB,X+Z]
        # print (x_tiled.size())
        # print (v.size())

        xv = torch.cat((x_tiled, v),1)

        #q(z|x,v)
        out = xv
        for i in range(len(self.qz_weights)-1):
            out = self.act_func(self.qz_weights[i](out))
        out = self.qz_weights[-1](out)
        mean = out[:,:self.z_size]
        logvar = out[:,self.z_size:]

        self.B = x.size()[0]
        # print (self.B, 'B')
        #Sample z0
        eps = Variable(torch.FloatTensor(k, self.B, self.z_size).normal_().type(self.dtype)) #[P,B,Z]
        # print (eps.size(),'eps')
        # print (mean.size(),'mean')
        # print (self.P, 'P')

        # print (mean)
        mean = mean.contiguous().view(self.P,self.B,self.z_size)
        logvar = logvar.contiguous().view(self.P,self.B,self.z_size)

        # print (mean)
        # mean = mean.contiguous().view(self.P,1,self.z_size)
        # logvar = logvar.contiguous().view(self.P,1,self.z_size)


        # print (mean.size(),'mean')

        z = eps.mul(torch.exp(.5*logvar)) + mean  #[P,B,Z]
        # print (z.size(),'z')

        # mean = mean.contiguous().view(self.P*self.B,self.z_size)
        # logvar = logvar.contiguous().view(self.P*self.B,self.z_size)

        logqz0 = lognormal333(z, mean, logvar) #[P,B]

        #[PB,Z]
        z = z.view(-1,self.z_size)

        # print (z.size())

        logdetsum = 0.
        for i in range(self.n_flows):

            z, v, logdet = self.norm_flow(self.params[i],z,v)
            logdetsum += logdet


        xz = torch.cat((x_tiled,z),1)

        #r(vT|x,zT)
        out = xz
        for i in range(len(self.rv_weights)-1):
            out = self.act_func(self.rv_weights[i](out))
        out = self.rv_weights[-1](out)
        mean = out[:,:self.z_size]
        logvar = out[:,self.z_size:]

        mean = mean.contiguous().view(self.P,self.B,self.z_size)
        logvar = logvar.contiguous().view(self.P,self.B,self.z_size)

        v = v.view(k,self.B,self.z_size)
        logrvT = lognormal333(v, mean, logvar) #[P,B]

        z = z.view(k,self.B,self.z_size)

        # print(logqz0.size(), 'here')
        # print(logqv0.size())
        # print(logdetsum.size())
        # print(logrvT.size())

        logdetsum = logdetsum.view(k,self.B)

        # print (logqz0+logqv0-logdetsum-logrvT)

        # fadfdsa

        return z, logqz0+logqv0-logdetsum-logrvT
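From this example onward a `lognormal333` helper appears alongside `lognormal`. Based on how it is called, it is assumed to be the same diagonal-Gaussian log-density but with the mean and log-variance already carrying the particle dimension (all three arguments [P,B,Z], output [P,B]); a sketch under that assumption:

    import math
    import torch

    def lognormal333(z, mean, logvar):
        # Assumed: per-particle means/variances, i.e. z, mean, logvar are all
        # [P,B,Z]; the log-density is summed over the latent dimension -> [P,B].
        return -0.5 * torch.sum(
            logvar + math.log(2 * math.pi)
            + (z - mean) ** 2 / torch.exp(logvar), -1)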
Example #39
def optimize_local_gaussian(logposterior, model, x):

    # print_ = 0

    # B = x.shape[0]
    B = x.size()[0] #batch size
    # input to log posterior is z, [P,B,Z]
    # I think B will be 1 for now



        # self.zeros = Variable(torch.zeros(self.B, self.z_size).type(self.dtype))

    mean = Variable(torch.zeros(B, model.z_size).type(model.dtype), requires_grad=True)
    logvar = Variable(torch.zeros(B, model.z_size).type(model.dtype), requires_grad=True)

    optimizer = optim.Adam([mean, logvar], lr=.001)
    # time_ = time.time()
    # n_data = len(train_x)
    # arr = np.array(range(n_data))

    P = 50


    last_100 = []
    best_last_100_avg = -1
    consecutive_worse = 0
    for epoch in range(1, 999999):

        #Sample
        eps = Variable(torch.FloatTensor(P, B, model.z_size).normal_().type(model.dtype)) #[P,B,Z]
        z = eps.mul(torch.exp(.5*logvar)) + mean  #[P,B,Z]
        logqz = lognormal(z, mean, logvar) #[P,B]

        logpx = logposterior(z)

        # print (logpx)
        # print (logqz)

        # fsda

        # data_index= 0
        # for i in range(int(n_data/batch_size)):
            # batch = train_x[data_index:data_index+batch_size]
            # data_index += batch_size

            # batch = Variable(torch.from_numpy(batch)).type(self.dtype)
        optimizer.zero_grad()

        # elbo, logpxz, logqz = self.forward(batch, k=k)

        loss = -(torch.mean(logpx-logqz))

        loss_np = loss.data.cpu().numpy()
        # print (epoch, loss_np)
        # fasfaf

        loss.backward()
        optimizer.step()

        last_100.append(loss_np)
        if epoch % 100 ==0:

            last_100_avg = np.mean(last_100)
            if last_100_avg< best_last_100_avg or best_last_100_avg == -1:
                consecutive_worse=0
                best_last_100_avg = last_100_avg
            else:
                consecutive_worse +=1 
                # print(consecutive_worse)
                if consecutive_worse> 10:
                    # print ('done')
                    break

            if epoch % 2000 ==0:
                print (epoch, last_100_avg, consecutive_worse)#,mean)
            # print (torch.mean(logpx))

            last_100 = []

        # break


        # if epoch%display_epoch==0:
        #     print ('Train Epoch: {}/{}'.format(epoch, epochs),
        #         'LL:{:.3f}'.format(-loss.data[0]),
        #         'logpxz:{:.3f}'.format(logpxz.data[0]),
        #         # 'logpz:{:.3f}'.format(logpz.data[0]),
        #         'logqz:{:.3f}'.format(logqz.data[0]),
        #         'T:{:.2f}'.format(time.time()-time_),
        #         )

        #     time_ = time.time()


    # Compute VAE and IWAE bounds



    #Sample
    eps = Variable(torch.FloatTensor(1000, B, model.z_size).normal_().type(model.dtype)) #[P,B,Z]
    z = eps.mul(torch.exp(.5*logvar)) + mean  #[P,B,Z]
    logqz = lognormal(z, mean, logvar) #[P,B]

    # print (logqz)
    # fad
    logpx = logposterior(z)

    elbo = logpx-logqz #[P,B]
    vae = torch.mean(elbo)

    max_ = torch.max(elbo, 0)[0] #[B]
    elbo_ = torch.log(torch.mean(torch.exp(elbo - max_), 0)) + max_ #[B]
    iwae = torch.mean(elbo_)

    return vae, iwae
Example #40
    def forward(self, x, a, k=1, current_state=None):
        '''
        x: [B,T,X]
        a: [B,T,A]
        output: elbo scalar
        '''
        
        self.B = x.size()[0]
        self.T = x.size()[1]
        self.k = k

        a = a.float()
        x = x.float()

        # log_probs = [[] for i in range(k)]
        # log_probs = []
        logpxs = []
        logpzs = []
        logqzs = []

        # if current_state==None:
        prev_z = Variable(torch.zeros(k, self.B, self.z_size))
        # else:
        #     prev_z = current_state
        for t in range(self.T):
            current_x = x[:,t] #[B,X]
            current_a = a[:,t] #[B,A]

            #Encode
            mu, logvar = self.encode(current_x, current_a, prev_z) #[P,B,Z]
            #Sample
            z, logqz = self.sample(mu, logvar) #[P,B,Z], [P,B]
            #Decode
            x_hat = self.decode(z)  #[P,B,X]
            logpx = log_bernoulli(x_hat, current_x)  #[P,B]
            #Transition/Prior prob
            prior_mean, prior_log_var = self.transition_prior(prev_z, current_a) #[P,B,Z]
            logpz = lognormal(z, prior_mean, prior_log_var) #[P,B]

            logpxs.append(logpx)
            logpzs.append(logpz)
            logqzs.append(logqz)
            # log_probs.append(logpx + logpz - logqz)
            prev_z = z



        logpxs = torch.stack(logpxs) 
        logpzs = torch.stack(logpzs)
        logqzs = torch.stack(logqzs) #[T,P,B]

        logws = logpxs + logpzs - logqzs  #[T,P,B]
        logws = torch.mean(logws, 0)  #[P,B]

        # elbo = logpx + logpz - logqz  #[P,B]

        if k>1:
            max_ = torch.max(logws, 0)[0] #[B]
            elbo = torch.log(torch.mean(torch.exp(logws - max_), 0)) + max_ #[B]
            elbo = torch.mean(elbo) #over batch
        else:
            elbo = torch.mean(logws)

        # print log_probs[0]


        # #for printing
        logpx = torch.mean(logpxs)
        logpz = torch.mean(logpzs)
        logqz = torch.mean(logqzs)
        # self.x_hat_sigmoid = F.sigmoid(x_hat)

        # elbo = torch.mean(torch.stack(log_probs)) #[1]
        # elbo = logpx + logpz - logqz

        return elbo, logpx, logpz, logqz
Example #41
    def sample(k):

        P = k

        # #Sample
        # eps = Variable(torch.FloatTensor(P, B, model.z_size).normal_().type(model.dtype)) #[P,B,Z]
        # z = eps.mul(torch.exp(.5*logvar)) + mean  #[P,B,Z]
        # logqz = lognormal(z, mean, logvar) #[P,B]

        # logpx = logposterior(z)
        # optimizer.zero_grad()


        #q(v|x)
        # out = x
        # for i in range(len(self.qv_weights)-1):
        #     out = self.act_func(self.qv_weights[i](out))
        # out = self.qv_weights[-1](out)
        # mean = out[:,:self.z_size]
        # logvar = out[:,self.z_size:]

        #Sample v0
        eps = Variable(torch.FloatTensor(k, B, z_size).normal_().type(model.dtype)) #[P,B,Z]
        # print (eps)
        v = eps.mul(torch.exp(.5*logvar_v)) + mean_v  #[P,B,Z]
        logqv0 = lognormal(v, mean_v, logvar_v) #[P,B]

        #[PB,Z]
        v = v.view(-1,model.z_size)
        # print (v)
        # fsaf

        # print(v)
        # fasd
        #[PB,X]
        # x_tiled = x.repeat(k,1)
        #[PB,X+Z]
        # xv = torch.cat((x_tiled, v),1)

        #q(z|x,v)
        out = v
        for i in range(len(qz_weights)-1):
            out = act_func(qz_weights[i](out))
        out = qz_weights[-1](out)
        mean = out[:,:z_size]
        logvar = out[:,z_size:] + 5.

        # print (mean)

        # B = x.size()[0]
        # print (self.B, 'B')
        #Sample z0
        eps = Variable(torch.FloatTensor(k, B, z_size).normal_().type(model.dtype)) #[P,B,Z]

        # print (mean)
        mean = mean.contiguous().view(P,B,model.z_size)
        logvar = logvar.contiguous().view(P,B,model.z_size)

        z = eps.mul(torch.exp(.5*logvar)) + mean  #[P,B,Z]
        # print (z.size(),'z')

        # mean = mean.contiguous().view(P*B,model.z_size)
        # logvar = logvar.contiguous().view(P*B,model.z_size)



        # print (z)
        # fad

        logqz0 = lognormal333(z, mean, logvar) #[P,B]

        #[PB,Z]
        z = z.view(-1,z_size)

        logdetsum = 0.
        for i in range(n_flows):

            z, v, logdet = norm_flow(params[i],z,v)
            logdetsum += logdet


        # xz = torch.cat((x_tiled,z),1)

        #r(vT|x,zT)
        out = z
        for i in range(len(rv_weights)-1):
            out = act_func(rv_weights[i](out))
        out = rv_weights[-1](out)
        mean = out[:,:model.z_size]
        logvar = out[:,model.z_size:]

        mean = mean.contiguous().view(P,B,model.z_size)
        logvar = logvar.contiguous().view(P,B,model.z_size)

        v = v.view(k,B,model.z_size)
        logrvT = lognormal333(v, mean, logvar) #[P,B]

        z = z.view(k,B,model.z_size)


        logq = logqz0+logqv0-logdetsum-logrvT

        # print (torch.mean(logqz0),torch.mean(logqv0),torch.mean(logdetsum),torch.mean(logrvT))

        return z, logq
Example #42
    def sample(k):

        P = k

        # #Sample
        # eps = Variable(torch.FloatTensor(P, B, model.z_size).normal_().type(model.dtype)) #[P,B,Z]
        # z = eps.mul(torch.exp(.5*logvar)) + mean  #[P,B,Z]
        # logqz = lognormal(z, mean, logvar) #[P,B]

        # logpx = logposterior(z)
        # optimizer.zero_grad()

        #q(v|x)
        # out = x
        # for i in range(len(self.qv_weights)-1):
        #     out = self.act_func(self.qv_weights[i](out))
        # out = self.qv_weights[-1](out)
        # mean = out[:,:self.z_size]
        # logvar = out[:,self.z_size:]

        #Sample v0
        eps = Variable(
            torch.FloatTensor(k, B,
                              z_size).normal_().type(model.dtype))  #[P,B,Z]
        # print (eps)
        v = eps.mul(torch.exp(.5 * logvar_v)) + mean_v  #[P,B,Z]
        logqv0 = lognormal(v, mean_v, logvar_v)  #[P,B]

        #[PB,Z]
        v = v.view(-1, model.z_size)
        # print (v)
        # fsaf

        # print(v)
        # fasd
        #[PB,X]
        # x_tiled = x.repeat(k,1)
        #[PB,X+Z]
        # xv = torch.cat((x_tiled, v),1)

        #q(z|x,v)
        out = v
        for i in range(len(qz_weights) - 1):
            out = act_func(qz_weights[i](out))
        out = qz_weights[-1](out)
        mean = out[:, :z_size]
        logvar = out[:, z_size:] + 5.

        # print (mean)

        # B = x.size()[0]
        # print (self.B, 'B')
        #Sample z0
        eps = Variable(
            torch.FloatTensor(k, B,
                              z_size).normal_().type(model.dtype))  #[P,B,Z]

        # print (mean)
        mean = mean.contiguous().view(P, B, model.z_size)
        logvar = logvar.contiguous().view(P, B, model.z_size)

        z = eps.mul(torch.exp(.5 * logvar)) + mean  #[P,B,Z]
        # print (z.size(),'z')

        # mean = mean.contiguous().view(P*B,model.z_size)
        # logvar = logvar.contiguous().view(P*B,model.z_size)

        # print (z)
        # fad

        logqz0 = lognormal333(z, mean, logvar)  #[P,B]

        #[PB,Z]
        z = z.view(-1, z_size)

        logdetsum = 0.
        for i in range(n_flows):

            z, v, logdet = norm_flow(params[i], z, v)
            logdetsum += logdet

        # xz = torch.cat((x_tiled,z),1)

        #r(vT|x,zT)
        out = z
        for i in range(len(rv_weights) - 1):
            out = act_func(rv_weights[i](out))
        out = rv_weights[-1](out)
        mean = out[:, :model.z_size]
        logvar = out[:, model.z_size:]

        mean = mean.contiguous().view(P, B, model.z_size)
        logvar = logvar.contiguous().view(P, B, model.z_size)

        v = v.view(k, B, model.z_size)
        logrvT = lognormal333(v, mean, logvar)  #[P,B]

        z = z.view(k, B, model.z_size)

        logq = logqz0 + logqv0 - logdetsum - logrvT

        # print (torch.mean(logqz0),torch.mean(logqv0),torch.mean(logdetsum),torch.mean(logrvT))

        return z, logq
Example #43
def optimize_local_gaussian(logposterior, model, x):

    # B = x.shape[0]
    B = x.size()[0]  #batch size
    # input to log posterior is z, [P,B,Z]
    # I think B will be 1 for now

    # self.zeros = Variable(torch.zeros(self.B, self.z_size).type(self.dtype))

    mean = Variable(torch.zeros(B, model.z_size).type(model.dtype),
                    requires_grad=True)
    logvar = Variable(torch.zeros(B, model.z_size).type(model.dtype),
                      requires_grad=True)

    optimizer = optim.Adam([mean, logvar], lr=.001)
    # time_ = time.time()
    # n_data = len(train_x)
    # arr = np.array(range(n_data))

    P = 50

    last_100 = []
    best_last_100_avg = -1
    consecutive_worse = 0
    for epoch in range(1, 999999):

        #Sample
        eps = Variable(
            torch.FloatTensor(P, B, model.z_size).normal_().type(
                model.dtype))  #[P,B,Z]
        z = eps.mul(torch.exp(.5 * logvar)) + mean  #[P,B,Z]
        logqz = lognormal(z, mean, logvar)  #[P,B]

        logpx = logposterior(z)

        # print (logpx)
        # print (logqz)

        # fsda

        # data_index= 0
        # for i in range(int(n_data/batch_size)):
        # batch = train_x[data_index:data_index+batch_size]
        # data_index += batch_size

        # batch = Variable(torch.from_numpy(batch)).type(self.dtype)
        optimizer.zero_grad()

        # elbo, logpxz, logqz = self.forward(batch, k=k)

        loss = -(torch.mean(logpx - logqz))

        loss_np = loss.data.cpu().numpy()
        # print (epoch, loss_np)
        # fasfaf

        loss.backward()
        optimizer.step()

        last_100.append(loss_np)
        if epoch % 100 == 0:

            last_100_avg = np.mean(last_100)
            if last_100_avg < best_last_100_avg or best_last_100_avg == -1:
                consecutive_worse = 0
                best_last_100_avg = last_100_avg
            else:
                consecutive_worse += 1
                # print(consecutive_worse)
                if consecutive_worse > 10:
                    # print ('done')
                    break

            print(epoch, last_100_avg, consecutive_worse)  #,mean)
            # print (torch.mean(logpx))

            last_100 = []

        # break

        # if epoch%display_epoch==0:
        #     print ('Train Epoch: {}/{}'.format(epoch, epochs),
        #         'LL:{:.3f}'.format(-loss.data[0]),
        #         'logpxz:{:.3f}'.format(logpxz.data[0]),
        #         # 'logpz:{:.3f}'.format(logpz.data[0]),
        #         'logqz:{:.3f}'.format(logqz.data[0]),
        #         'T:{:.2f}'.format(time.time()-time_),
        #         )

        #     time_ = time.time()

    # Compute VAE and IWAE bounds

    #Sample
    eps = Variable(
        torch.FloatTensor(1000, B,
                          model.z_size).normal_().type(model.dtype))  #[P,B,Z]
    z = eps.mul(torch.exp(.5 * logvar)) + mean  #[P,B,Z]
    logqz = lognormal(z, mean, logvar)  #[P,B]

    # print (logqz)
    # fad
    logpx = logposterior(z)

    elbo = logpx - logqz  #[P,B]
    vae = torch.mean(elbo)

    max_ = torch.max(elbo, 0)[0]  #[B]
    elbo_ = torch.log(torch.mean(torch.exp(elbo - max_), 0)) + max_  #[B]
    iwae = torch.mean(elbo_)

    return vae, iwae
Example #44
    def forward(self, k, x, logposterior):
        '''
        k: number of samples
        x: [B,X]
        logposterior(z) -> [P,B]
        '''

        self.B = x.size()[0]
        self.P = k

        #q(v|x)
        out = self.act_func(self.fc1(x))
        out = self.act_func(self.fc2(out))
        out = self.fc3(out)
        mean = out[:,:self.z_size]
        logvar = out[:,self.z_size:]

        #Sample v0
        eps = Variable(torch.FloatTensor(k, self.B, self.z_size).normal_().type(self.dtype)) #[P,B,Z]
        v = eps.mul(torch.exp(.5*logvar)) + mean  #[P,B,Z]
        logqv0 = lognormal(v, mean, logvar) #[P,B]

        #[PB,Z]
        v = v.view(-1,self.z_size)
        #[PB,X]
        x_tiled = x.repeat(k,1)
        #[PB,X+Z]
        # print (x_tiled.size())
        # print (v.size())

        xv = torch.cat((x_tiled, v),1)

        #q(z|x,v)
        out = self.act_func(self.fc4(xv))
        out = self.act_func(self.fc5(out))
        out = self.fc6(out)
        mean = out[:,:self.z_size]
        logvar = out[:,self.z_size:]

        #Sample z0
        eps = Variable(torch.FloatTensor(k, self.B, self.z_size).normal_().type(self.dtype)) #[P,B,Z]
        z = eps.mul(torch.exp(.5*logvar)) + mean  #[P,B,Z]
        logqz0 = lognormal(z, mean, logvar) #[P,B]

        #[PB,Z]
        z = z.view(-1,self.z_size)


        logdetsum = 0.
        for i in range(self.n_flows):

            z, v, logdet = self.norm_flow(self.params[i],z,v)
            logdetsum += logdet


        xz = torch.cat((x_tiled,z),1)
        #r(vT|x,zT)
        out = self.act_func(self.fc7(xz))
        out = self.act_func(self.fc8(out))
        out = self.fc9(out)
        mean = out[:,:self.z_size]
        logvar = out[:,self.z_size:]

        v = v.view(k,self.B,self.z_size)
        logrvT = lognormal(v, mean, logvar) #[P,B]

        z = z.view(k,self.B,self.z_size)

        return z, logqz0+logqv0-logdetsum-logrvT
Example #45
    def forward(self, k, x, logposterior):
        '''
        k: number of samples
        x: [B,X]
        logposterior(z) -> [P,B]
        '''

        self.B = x.size()[0]
        self.P = k


        if torch.cuda.is_available():
            self.grad_outputs = torch.ones(k, self.B).cuda()
        else:
            self.grad_outputs = torch.ones(k, self.B)

        #q(v|x)
        out = x
        for i in range(len(self.qv_weights)-1):
            out = self.act_func(self.qv_weights[i](out))
        out = self.qv_weights[-1](out)
        mean = out[:,:self.z_size]
        logvar = out[:,self.z_size:]

        #Sample v0
        eps = Variable(torch.FloatTensor(k, self.B, self.z_size).normal_().type(self.dtype)) #[P,B,Z]
        v = eps.mul(torch.exp(.5*logvar)) + mean  #[P,B,Z]
        logqv0 = lognormal(v, mean, logvar) #[P,B]

        #[PB,Z]
        v = v.view(-1,self.z_size)
        #[PB,X]
        x_tiled = x.repeat(k,1)
        #[PB,X+Z]
        # print (x_tiled.size())
        # print (v.size())
        xv = torch.cat((x_tiled, v),1)

        #q(z|x,v)
        out = xv
        for i in range(len(self.qz_weights)-1):
            out = self.act_func(self.qz_weights[i](out))
        out = self.qz_weights[-1](out)
        mean = out[:,:self.z_size]
        logvar = out[:,self.z_size:]

        #Sample z0
        eps = Variable(torch.FloatTensor(k, self.B, self.z_size).normal_().type(self.dtype)) #[P,B,Z]

        mean = mean.contiguous().view(self.P,self.B,self.z_size)
        logvar = logvar.contiguous().view(self.P,self.B,self.z_size)

        z = eps.mul(torch.exp(.5*logvar)) + mean  #[P,B,Z]

        mean = mean.contiguous().view(self.P*self.B,self.z_size)
        logvar = logvar.contiguous().view(self.P*self.B,self.z_size)


        logqz0 = lognormal(z, mean, logvar) #[P,B]

        #[PB,Z]
        z = z.view(-1,self.z_size)


        logdetsum = 0.
        for i in range(self.n_flows):

            z, v, logdet = self.norm_flow(self.params[i],z,v,logposterior)
            logdetsum += logdet


        xz = torch.cat((x_tiled,z),1)
        #r(vT|x,zT)
        out = xz
        for i in range(len(self.rv_weights)-1):
            out = self.act_func(self.rv_weights[i](out))
        out = self.rv_weights[-1](out)
        mean = out[:,:self.z_size]
        logvar = out[:,self.z_size:]

        v = v.view(k,self.B,self.z_size)
        logrvT = lognormal(v, mean, logvar) #[P,B]

        z = z.view(k,self.B,self.z_size)

        return z, logqz0+logqv0-logdetsum-logrvT
Example #46
    def forward(self, k, x, logposterior):
        '''
        k: number of samples
        x: [B,X]
        logposterior(z) -> [P,B]
        '''

        self.B = x.size()[0]
        self.P = k

        if torch.cuda.is_available():
            self.grad_outputs = torch.ones(k, self.B).cuda()
        else:
            self.grad_outputs = torch.ones(k, self.B)

        #q(v|x)
        out = self.act_func(self.fc1(x))
        out = self.act_func(self.fc2(out))
        out = self.fc3(out)
        mean = out[:, :self.z_size]
        logvar = out[:, self.z_size:]

        #Sample v0
        eps = Variable(
            torch.FloatTensor(k, self.B, self.z_size).normal_().type(
                self.dtype))  #[P,B,Z]
        v = eps.mul(torch.exp(.5 * logvar)) + mean  #[P,B,Z]
        logqv0 = lognormal(v, mean, logvar)  #[P,B]

        #[PB,Z]
        v = v.view(-1, self.z_size)
        #[PB,X]
        x_tiled = x.repeat(k, 1)
        #[PB,X+Z]
        # print (x_tiled.size())
        # print (v.size())
        xv = torch.cat((x_tiled, v), 1)

        #q(z|x,v)
        out = self.act_func(self.fc4(xv))
        out = self.act_func(self.fc5(out))
        out = self.fc6(out)
        mean = out[:, :self.z_size]
        logvar = out[:, self.z_size:]

        #Sample z0
        eps = Variable(
            torch.FloatTensor(k, self.B, self.z_size).normal_().type(
                self.dtype))  #[P,B,Z]
        z = eps.mul(torch.exp(.5 * logvar)) + mean  #[P,B,Z]
        logqz0 = lognormal(z, mean, logvar)  #[P,B]

        #[PB,Z]
        z = z.view(-1, self.z_size)

        logdetsum = 0.
        for i in range(self.n_flows):

            z, v, logdet = self.norm_flow(self.params[i], z, v, logposterior)
            logdetsum += logdet

        xz = torch.cat((x_tiled, z), 1)
        #r(vT|x,zT)
        out = self.act_func(self.fc7(xz))
        out = self.act_func(self.fc8(out))
        out = self.fc9(out)
        mean = out[:, :self.z_size]
        logvar = out[:, self.z_size:]

        v = v.view(k, self.B, self.z_size)
        logrvT = lognormal(v, mean, logvar)  #[P,B]

        z = z.view(k, self.B, self.z_size)

        return z, logqz0 + logqv0 - logdetsum - logrvT
Example #47
    def forward(self, k, x, logposterior):
        '''
        k: number of samples
        x: [B,X]
        logposterior(z) -> [P,B]
        '''

        self.B = x.size()[0]
        self.P = k

        # print (self.B, 'B')
        # print (k)
        # fsdaf

        #q(v|x)
        out = x
        for i in range(len(self.qv_weights) - 1):
            out = self.act_func(self.qv_weights[i](out))
        out = self.qv_weights[-1](out)
        mean = out[:, :self.z_size]
        logvar = out[:, self.z_size:]

        #Sample v0
        eps = Variable(
            torch.FloatTensor(k, self.B, self.z_size).normal_().type(
                self.dtype))  #[P,B,Z]
        v = eps.mul(torch.exp(.5 * logvar)) + mean  #[P,B,Z]
        logqv0 = lognormal(v, mean, logvar)  #[P,B]

        #[PB,Z]
        v = v.view(-1, self.z_size)
        #[PB,X]
        x_tiled = x.repeat(k, 1)
        #[PB,X+Z]
        xv = torch.cat((x_tiled, v), 1)

        #q(z|x,v)
        out = xv
        for i in range(len(self.qz_weights) - 1):
            out = self.act_func(self.qz_weights[i](out))
        out = self.qz_weights[-1](out)
        mean = out[:, :self.z_size]
        logvar = out[:, self.z_size:]

        #Sample z0
        eps = Variable(
            torch.FloatTensor(k, self.B, self.z_size).normal_().type(
                self.dtype))  #[P,B,Z]
        mean = mean.contiguous().view(self.P, self.B, self.z_size)
        logvar = logvar.contiguous().view(self.P, self.B, self.z_size)

        z = eps.mul(torch.exp(.5 * logvar)) + mean  #[P,B,Z]

        logqz0 = lognormal333(z, mean, logvar)  #[P,B]

        #[PB,Z]
        z = z.view(-1, self.z_size)

        logdetsum = 0.
        for i in range(self.n_flows):

            z, v, logdet = self.norm_flow(self.params[i], z, v)
            logdetsum += logdet

        xz = torch.cat((x_tiled, z), 1)

        #r(vT|x,zT)
        out = xz
        for i in range(len(self.rv_weights) - 1):
            out = self.act_func(self.rv_weights[i](out))
        out = self.rv_weights[-1](out)
        mean = out[:, :self.z_size]
        logvar = out[:, self.z_size:]

        mean = mean.contiguous().view(self.P, self.B, self.z_size)
        logvar = logvar.contiguous().view(self.P, self.B, self.z_size)

        v = v.view(k, self.B, self.z_size)
        logrvT = lognormal333(v, mean, logvar)  #[P,B]

        z = z.view(k, self.B, self.z_size)


        logdetsum = logdetsum.view(k, self.B)

        return z, logqz0 + logqv0 - logdetsum - logrvT
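
This variant calls lognormal333 once the mean and logvar have been reshaped to [P,B,Z] (one Gaussian per sample), whereas the earlier snippets call lognormal with [B,Z] parameters. Neither helper is defined in this listing; below is a minimal sketch of the diagonal-Gaussian log density they presumably compute. The name and the exact broadcasting behavior are assumptions.

import math
import torch

def lognormal_sketch(z, mean, logvar):
    # z: [P,B,Z]; mean/logvar broadcastable to z ([B,Z] or [P,B,Z])
    # returns [P,B]: diagonal Gaussian log density, summed over the Z axis
    return -0.5 * torch.sum(
        logvar + (z - mean) ** 2 / torch.exp(logvar) + math.log(2 * math.pi), -1)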
Ejemplo n.º 48
0
def test_ais(model, data_x, path_to_load_variables='', batch_size=20, display_epoch=4, k=10):
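    # intermediate_dist implements the geometric AIS bridge:
    #   log p_t(z) = (1 - t) * log q(z|x) + t * [log p(z) + log p(x|z)]
    # so t=0 recovers the encoder distribution and t=1 the unnormalized posterior.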

    def intermediate_dist(t, z, mean, logvar, zeros, batch):

        logp1 = lognormal(z, mean, logvar)  #[P,B]
        log_prior = lognormal(z, zeros, zeros)  #[P,B]
        log_likelihood = log_bernoulli(model.decode(z), batch)
        logpT = log_prior + log_likelihood
        log_intermediate_2 = (1-float(t))*logp1 + float(t)*logpT
        return log_intermediate_2



    n_intermediate_dists = 25
    n_HMC_steps = 5
    step_size = .1

    retain_graph = False
    volatile_ = False
    requires_grad = False

    if path_to_load_variables != '':
        model.load_state_dict(torch.load(path_to_load_variables, map_location=lambda storage, loc: storage))
        print('loaded variables ' + path_to_load_variables)


    logws = []
    data_index= 0
    for i in range(len(data_x) // batch_size):

        print(i)

        #AIS

        schedule = np.linspace(0.,1.,n_intermediate_dists)
        model.B = batch_size


        batch = data_x[data_index:data_index+batch_size]
        data_index += batch_size

        if torch.cuda.is_available():
            batch = Variable(batch, volatile=volatile_, requires_grad=requires_grad).cuda()
            zeros = Variable(torch.zeros(model.B, model.z_size), volatile=volatile_, requires_grad=requires_grad).cuda() # [B,Z]
            logw = Variable(torch.zeros(k, model.B), volatile=volatile_, requires_grad=requires_grad).cuda()
            grad_outputs = torch.ones(k, model.B).cuda()
        else:
            batch = Variable(batch)
            zeros = Variable(torch.zeros(model.B, model.z_size)) # [B,Z]
            logw = Variable(torch.zeros(k, model.B))
            grad_outputs = torch.ones(k, model.B)

        #Encode x
        mean, logvar = model.encode(batch) #[B,Z]

        #Init z
        z, logpz, logqz = model.sample(mean, logvar, k=k)  #[P,B,Z], [P,B], [P,B]


        for (t0, t1) in zip(schedule[:-1], schedule[1:]):

            memReport()

            print(t0)

            #Compute intermediate distribution log prob
            # (1-t)*logp1(z) + (t)*logpT(z)
            logp1 = lognormal(z, mean, logvar)  #[P,B]
            log_prior = lognormal(z, zeros, zeros)  #[P,B]
            log_likelihood = log_bernoulli(model.decode(z), batch)
            logpT = log_prior + log_likelihood

            #log pt-1(zt-1)
            log_intermediate_1 = (1-float(t0))*logp1 + float(t0)*logpT
            #log pt(zt-1)
            log_intermediate_2 = (1-float(t1))*logp1 + float(t1)*logpT

            logw += log_intermediate_2 - log_intermediate_1



            #HMC transition targeting the current intermediate distribution:
            #leapfrog steps on (z, v), then a Metropolis accept/reject


            if torch.cuda.is_available():
                v = Variable(torch.FloatTensor(z.size()).normal_(), volatile=volatile_, requires_grad=requires_grad).cuda()
            else:
                v = Variable(torch.FloatTensor(z.size()).normal_()) 

            v0 = v
            z0 = z


            gradients = torch.autograd.grad(outputs=log_intermediate_2, inputs=z,
                              grad_outputs=grad_outputs,
                              create_graph=True, retain_graph=retain_graph, only_inputs=True)[0]

            v = v + .5 *step_size*gradients
            z = z + step_size*v

            for LF_step in range(n_HMC_steps):
                log_intermediate_2 = intermediate_dist(t1, z, mean, logvar, zeros, batch)

                gradients = torch.autograd.grad(outputs=log_intermediate_2, inputs=z,
                                  grad_outputs=grad_outputs,
                                  create_graph=True, retain_graph=retain_graph, only_inputs=True)[0]

                v = v + step_size*gradients
                z = z + step_size*v



            log_intermediate_2 = intermediate_dist(t1, z, mean, logvar, zeros, batch)

            gradients = torch.autograd.grad(outputs=log_intermediate_2, inputs=z,
                              grad_outputs=grad_outputs,
                              create_graph=True, retain_graph=retain_graph, only_inputs=True)[0]

            v = v + .5 *step_size*gradients


            #MH step
            log_intermediate_2 = intermediate_dist(t1, z0, mean, logvar, zeros, batch)

            logpv0 = lognormal(v0, zeros, zeros) #[P,B]
            hamil_0 =  log_intermediate_2 + logpv0

            log_intermediate_2 = intermediate_dist(t1, z, mean, logvar, zeros, batch)
            
            logpvT = lognormal(v, zeros, zeros) #[P,B]

            hamil_T = log_intermediate_2 + logpvT

            accept_prob = torch.exp(hamil_T - hamil_0)

            if torch.cuda.is_available():
                rand_uni = Variable(torch.FloatTensor(accept_prob.size()).uniform_(), volatile=volatile_, requires_grad=requires_grad).cuda()
            else:
                rand_uni = Variable(torch.FloatTensor(accept_prob.size()).uniform_())


            accept = accept_prob > rand_uni

            if torch.cuda.is_available():
                accept = accept.type(torch.FloatTensor).cuda()
            else:
                accept = accept.type(torch.FloatTensor)


            
            accept = accept.view(k, model.B, 1)

            z = (accept * z) + ((1-accept) * z0)

            avg_acceptance_rate = torch.mean(accept)

            if avg_acceptance_rate.cpu().data.numpy() > .7:
                step_size = 1.02 * step_size
            else:
                step_size = .98 * step_size

            if step_size < 0.0001:
                step_size = 0.0001
            if step_size > 0.5:
                step_size = 0.5




        #log-mean-exp of the AIS log-weights over the k chains
        max_ = torch.max(logw,0)[0] #[B]
        logw = torch.log(torch.mean(torch.exp(logw - max_), 0)) + max_ #[B]

        logws.append(torch.mean(logw.cpu()).data.numpy())


        if i % display_epoch == 0:
            print(i, len(data_x) // batch_size, np.mean(logws))

    return np.mean(logws)
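
A hypothetical call of test_ais (the model object and data tensor names below are assumptions, not taken from the listing); it returns the average AIS estimate of log p(x) over the test batches:

# Assumes `vae` exposes encode/decode/sample with the shapes used above and
# that `test_x` is a FloatTensor of binarized inputs, [N, X].
logpx = test_ais(vae, test_x, path_to_load_variables='', batch_size=20,
                 display_epoch=4, k=10)
print('AIS estimate of mean log p(x):', logpx)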