Example #1
0
    def forward3(self, x, k=1):
        """k-sample reconstruction objective (likelihood term only).

        Returns (elbo, logpx, logpz, logqz, x_hat_sigmoid); the scalar
        summaries are batch means intended for logging/printing.
        """
        self.B = x.size(0)

        mu, logvar = self.encode(x)
        z, logpz, logqz = self.sample(mu, logvar, k=k)      # z: [P,B,Z]
        recon = self.decode(z).view(k, self.B, -1)          # [P,B,X]
        target = x.view(self.B, -1)                         # [B,X]

        logpx = log_bernoulli(recon, target)                # [P,B]
        # NOTE(review): the prior/posterior terms are deliberately excluded
        # here; only the reconstruction term enters the objective.
        elbo = logpx

        if k > 1:
            # Numerically-stable log-mean-exp across the k particles.
            shift = torch.max(elbo, 0)[0]                   # [B]
            elbo = torch.log(torch.mean(torch.exp(elbo - shift), 0)) + shift

        elbo = torch.mean(elbo)  # scalar

        # Scalar summaries for monitoring.
        logpx = torch.mean(logpx)
        logpz = torch.mean(logpz)
        logqz = torch.mean(logqz)
        self.x_hat_sigmoid = F.sigmoid(recon)

        return elbo, logpx, logpz, logqz, self.x_hat_sigmoid
Example #2
0
    def forward(self, x, policy, k=1):
        """VAE forward pass regularised by a policy's action distribution.

        x: [B,2,84,84]. Returns (elbo, logpx, logpz, logqz, action_dist_kl)
        where the last four are scalar summaries.

        NOTE(review): logpx is multiplied by 0, so the objective currently
        trains only on the down-weighted z-terms and the action-dist KL.
        """
        self.B = x.size(0)

        mu, logvar = self.encode(x)
        z, logpz, logqz = self.sample(mu, logvar, k=k)   # z: [P,B,Z]
        x_hat = self.decode(z)                           # [P*B, X]
        x_hat_sigmoid = F.sigmoid(x_hat)

        # KL(policy(x) || policy(x_hat)) over the action distribution.
        log_dist_recon = policy.action_logdist(x_hat_sigmoid)
        log_dist_true = policy.action_logdist(x)
        action_dist_kl = torch.sum(
            (log_dist_true - log_dist_recon) * torch.exp(log_dist_true),
            dim=1)                                       # [B]

        # Bernoulli reconstruction likelihood.
        flat_x_hat = x_hat.view(k, self.B, -1)
        flat_x = x.view(self.B, -1)
        logpx = log_bernoulli(flat_x_hat, flat_x)        # [P,B]

        # Reduce to scalars; `weight` down-weights the VAE terms relative
        # to the action-distribution KL.
        action_dist_kl = torch.mean(action_dist_kl)
        weight = .01  # .00001
        logpx = torch.mean(logpx) * weight * 0.  # reconstruction term disabled
        logpz = torch.mean(logpz) * weight
        logqz = torch.mean(logqz) * weight

        elbo = logpx + logpz - logqz - action_dist_kl    # scalar

        return elbo, logpx, logpz, logqz, action_dist_kl
Example #3
0
    def forward(self, x, policies, k=1):
        """VAE forward pass matched against an ensemble of policies.

        x: [B,2,84,84]. For each policy this penalises (1) the KL between the
        policy's action distribution on the real frame vs. the reconstruction
        and (2) the squared difference of an intermediate policy activation.

        Returns (elbo, logpx, logpz, logqz, action_dist_kl, act_dif); the
        last five are scalar summaries.
        """
        self.B = x.size()[0]

        mu, logvar = self.encode(x)
        z, logpz, logqz = self.sample(mu, logvar, k=k)  #[P,B,Z]
        x_hat = self.decode(z)  #[PB,X]
        x_hat_sigmoid = F.sigmoid(x_hat)

        kls = []
        act_difs = []
        for p in range(len(policies)):
            log_dist_recon = policies[p].action_logdist(x_hat_sigmoid)
            log_dist_true = policies[p].action_logdist(x)

            # KL(true || recon) summed over the action dimension.
            action_dist_kl = torch.sum(
                (log_dist_true - log_dist_recon) * torch.exp(log_dist_true),
                dim=1)  #[B]
            kls.append(action_dist_kl)

            # Squared distance between intermediate policy activations.
            act_recon = policies[p].get_intermediate_activation(x_hat_sigmoid)
            act_recon = act_recon.view(self.B, -1)
            act_true = policies[p].get_intermediate_activation(x)
            act_true = act_true.view(self.B, -1)
            act_dif = torch.mean((act_recon - act_true) ** 2, dim=1)  #[B]
            act_difs.append(act_dif)

        # Average over policies, then over the batch.
        kls = torch.stack(kls)  #[policies, B]
        action_dist_kl = torch.mean(kls, dim=0)  #[B]
        action_dist_kl = torch.mean(action_dist_kl)  #[1]

        act_difs = torch.stack(act_difs)  #[policies, B]
        # FIX: average the stacked tensor; previously this read `act_dif`,
        # which only held the LAST policy's values from the loop.
        act_dif = torch.mean(act_difs, dim=0)  #[B]
        act_dif = torch.mean(act_dif)  #[1]

        # Bernoulli reconstruction likelihood.
        flat_x_hat = x_hat.view(k, self.B, -1)
        flat_x = x.view(self.B, -1)
        logpx = log_bernoulli(flat_x_hat, flat_x)  #[P,B]

        # Heavily down-weight the VAE terms relative to the policy terms.
        scale = .00001
        logpx = torch.mean(logpx) * scale
        logpz = torch.mean(logpz) * scale
        logqz = torch.mean(logqz) * scale

        elbo = logpx + logpz - logqz - action_dist_kl - act_dif  #[1]

        return elbo, logpx, logpz, logqz, action_dist_kl, act_dif
Example #4
0
    def forward(self, x, policy, k=1):
        """Single-policy VAE objective driven by the action-distribution KL.

        x: [B,2,84,84]. Returns (elbo, logpx, logpz, logqz, action_dist_kl),
        all scalars after the first.

        NOTE(review): the reconstruction term is zeroed out (`* 0.`), so the
        gradient comes only from the scaled z-terms and the action KL.
        """
        self.B = x.size(0)

        # Encode, sample latents, decode.
        mu, logvar = self.encode(x)
        z, logpz, logqz = self.sample(mu, logvar, k=k)   # z: [P,B,Z]
        x_hat = self.decode(z)                           # [P*B, X]
        x_hat_sigmoid = F.sigmoid(x_hat)

        # Action-distribution KL between real and reconstructed frames.
        log_dist_recon = policy.action_logdist(x_hat_sigmoid)
        log_dist_true = policy.action_logdist(x)
        action_dist_kl = torch.sum(
            (log_dist_true - log_dist_recon) * torch.exp(log_dist_true),
            dim=1)                                       # [B]

        # Bernoulli reconstruction likelihood.
        flat_x_hat = x_hat.view(k, self.B, -1)
        flat_x = x.view(self.B, -1)
        logpx = log_bernoulli(flat_x_hat, flat_x)        # [P,B]

        # Scalar reductions; `weight` down-weights the VAE terms.
        action_dist_kl = torch.mean(action_dist_kl)
        weight = .01  # .00001
        logpx = torch.mean(logpx) * weight * 0.  # reconstruction disabled
        logpz = torch.mean(logpz) * weight
        logqz = torch.mean(logqz) * weight

        elbo = logpx + logpz - logqz - action_dist_kl    # scalar

        return elbo, logpx, logpz, logqz, action_dist_kl
    def forward(self, x, policies, k=1):
        """Deterministic reconstruction matched to policy outputs and their
        input-gradients.

        x: [B,2,84,84]. Decodes the posterior mean (no sampling), then for
        each policy penalises (1) the action-distribution KL between real and
        reconstructed frames and (2) the squared difference of the gradient
        of one action-logit w.r.t. the input image.

        Returns (loss, logpx, grad_dif, action_dist_kl); logpx is only a
        monitoring value, it is not part of the loss.
        """
        self.B = x.size()[0]

        mu, logvar = self.encode(x)
        # Decode the mean directly; stochastic sampling is bypassed here.
        x_hat = self.decode(mu)  #[PB,X]
        x_hat_sigmoid = F.sigmoid(x_hat)

        kls = []
        grad_difs = []
        for p in range(len(policies)):

            # Re-wrap x so the policy output can be differentiated w.r.t. it.
            x = Variable(x.data, requires_grad=True, volatile=False)

            log_dist_recon = policies[p].action_logdist(x_hat_sigmoid)
            log_dist_true = policies[p].action_logdist(x)

            action_dist_kl = torch.sum(
                (log_dist_true - log_dist_recon) * torch.exp(log_dist_true),
                dim=1)  #[B]
            kls.append(action_dist_kl)

            # Gradient of action-logit 3 w.r.t. the inputs.
            # NOTE(review): action index 3 is hard-coded — confirm intended.
            ent_true = torch.mean(log_dist_true[:, 3])
            grad_true = torch.autograd.grad(ent_true,
                                            x,
                                            create_graph=True,
                                            retain_graph=True)[0]  #[B,2,84,84]

            ent_recon = torch.mean(log_dist_recon[:, 3])
            grad_recon = torch.autograd.grad(
                ent_recon, x_hat_sigmoid, create_graph=True,
                retain_graph=True)[0]  #[B,2,84,84]

            grad_dif = torch.sum((grad_recon - grad_true)**2)  #[1]
            grad_difs.append(grad_dif)

        # Average over policies (and batch).
        kls = torch.stack(kls)  #[policies, B]
        # FIX: average the stacked tensor; previously `action_dist_kl` still
        # held only the last loop iteration's per-batch KL, so every policy
        # but the last was ignored.
        action_dist_kl = torch.mean(kls)  #[1]

        grad_difs = torch.stack(grad_difs)  #[policies]
        grad_dif = torch.mean(grad_difs)  #[1]

        # Reconstruction likelihood (monitoring only).
        flat_x_hat = x_hat.view(k, self.B, -1)
        flat_x = x.view(self.B, -1)
        logpx = log_bernoulli(flat_x_hat, flat_x)  #[P,B]
        logpx = torch.mean(logpx)  #[1]

        loss = grad_dif + action_dist_kl  #[1]

        return loss, logpx, grad_dif, action_dist_kl