def forward3(self, x, k=1):
    self.B = x.size()[0]
    mu, logvar = self.encode(x)
    z, logpz, logqz = self.sample(mu, logvar, k=k)  #[P,B,Z]
    x_hat = self.decode(z)  #[PB,X]
    x_hat = x_hat.view(k, self.B, -1)
    x = x.view(self.B, -1)
    logpx = log_bernoulli(x_hat, x)  #[P,B]

    elbo = logpx  #+ logpz - logqz  #[P,B]
    if k > 1:
        # Importance-weighted bound: log-mean-exp over the k samples, stabilized by the max
        max_ = torch.max(elbo, 0)[0]  #[B]
        elbo = torch.log(torch.mean(torch.exp(elbo - max_), 0)) + max_  #[B]
    elbo = torch.mean(elbo)  #[1]

    # For printing
    logpx = torch.mean(logpx)
    logpz = torch.mean(logpz)
    logqz = torch.mean(logqz)

    self.x_hat_sigmoid = F.sigmoid(x_hat)

    return elbo, logpx, logpz, logqz, self.x_hat_sigmoid
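# The reconstruction term above relies on the repo's log_bernoulli helper, defined
# elsewhere. A minimal sketch of what it is assumed to compute (hypothetical name, not
# the repo's actual implementation; reuses the module's existing torch / F imports):
# the Bernoulli log-likelihood of the flattened targets under the decoder logits,
# summed over pixels and broadcast over the k particles.
def log_bernoulli_reference_sketch(x_logits, x):
    # x_logits: [P,B,X] pre-sigmoid decoder outputs, x: [B,X] targets in [0,1]
    # returns:  [P,B] log p(x|z) per particle and per example
    return torch.sum(x * F.logsigmoid(x_logits) + (1. - x) * F.logsigmoid(-x_logits), dim=2)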
def forward(self, x, policy, k=1):
    # x: [B,2,84,84]
    self.B = x.size()[0]
    mu, logvar = self.encode(x)
    z, logpz, logqz = self.sample(mu, logvar, k=k)  #[P,B,Z]
    x_hat = self.decode(z)  #[PB,X]
    x_hat_sigmoid = F.sigmoid(x_hat)

    # Action distributions of the given policy on the reconstruction and on the real frames
    log_dist_recon = policy.action_logdist(x_hat_sigmoid)
    log_dist_true = policy.action_logdist(x)

    # Reconstruction likelihood
    flat_x_hat = x_hat.view(k, self.B, -1)
    flat_x = x.view(self.B, -1)
    logpx = log_bernoulli(flat_x_hat, flat_x)  #[P,B]

    # KL(true action dist || reconstructed action dist), per example
    action_dist_kl = torch.sum((log_dist_true - log_dist_recon) * torch.exp(log_dist_true), dim=1)  #[B]
    action_dist_kl = torch.mean(action_dist_kl)  #[1]

    # Down-weight the VAE terms; the reconstruction term is currently zeroed out
    weight = .01  # .00001
    logpx = torch.mean(logpx) * weight * 0.
    logpz = torch.mean(logpz) * weight
    logqz = torch.mean(logqz) * weight

    elbo = logpx + logpz - logqz - action_dist_kl  #[1]

    return elbo, logpx, logpz, logqz, action_dist_kl
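# The forward passes above and below assume each policy exposes action_logdist (and, for
# the next version, get_intermediate_activation). A minimal, hypothetical sketch of that
# interface (not the repo's actual policy network; layer sizes are assumptions, and the
# module's existing torch / nn / F imports are reused): the frames go through the policy
# trunk and per-action log-probabilities come back, which is what the KL term consumes.
class _PolicySketch(nn.Module):
    def __init__(self, num_actions=6):
        super(_PolicySketch, self).__init__()
        self.conv1 = nn.Conv2d(2, 16, kernel_size=8, stride=4)   # [B,2,84,84] -> [B,16,20,20]
        self.fc = nn.Linear(16 * 20 * 20, num_actions)

    def action_logdist(self, x):
        # x: [B,2,84,84] -> log pi(a|x): [B,num_actions]
        h = F.relu(self.conv1(x))
        return F.log_softmax(self.fc(h.view(x.size(0), -1)), dim=1)

    def get_intermediate_activation(self, x):
        # Hidden feature map matched by the activation-difference loss below: [B,16,20,20]
        return F.relu(self.conv1(x))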
def forward(self, x, policies, k=1):
    # x: [B,2,84,84]
    self.B = x.size()[0]
    mu, logvar = self.encode(x)
    z, logpz, logqz = self.sample(mu, logvar, k=k)  #[P,B,Z]
    x_hat = self.decode(z)  #[PB,X]
    x_hat_sigmoid = F.sigmoid(x_hat)

    kls = []
    act_difs = []
    for p in range(len(policies)):
        # KL(true action dist || reconstructed action dist), per example
        log_dist_recon = policies[p].action_logdist(x_hat_sigmoid)
        log_dist_true = policies[p].action_logdist(x)
        action_dist_kl = torch.sum((log_dist_true - log_dist_recon) * torch.exp(log_dist_true), dim=1)  #[B]
        kls.append(action_dist_kl)

        # Match intermediate policy activations on real vs reconstructed frames
        act_recon = policies[p].get_intermediate_activation(x_hat_sigmoid)
        act_recon = act_recon.view(self.B, -1)
        act_true = policies[p].get_intermediate_activation(x)
        act_true = act_true.view(self.B, -1)
        act_dif = torch.mean((act_recon - act_true)**2, dim=1)  #[B]
        act_difs.append(act_dif)

    # Average over policies
    kls = torch.stack(kls)  #[policies, B]
    action_dist_kl = torch.mean(kls, dim=0)  #[B]
    action_dist_kl = torch.mean(action_dist_kl)  #[1]

    act_difs = torch.stack(act_difs)  #[policies, B]
    act_dif = torch.mean(act_difs, dim=0)  #[B]
    act_dif = torch.mean(act_dif)  #[1]

    # Likelihood
    flat_x_hat = x_hat.view(k, self.B, -1)
    flat_x = x.view(self.B, -1)
    logpx = log_bernoulli(flat_x_hat, flat_x)  #[P,B]

    # Down-weight the VAE terms relative to the policy-matching terms
    scale = .00001
    logpx = torch.mean(logpx) * scale
    logpz = torch.mean(logpz) * scale
    logqz = torch.mean(logqz) * scale

    elbo = logpx + logpz - logqz - action_dist_kl - act_dif  #[1]

    return elbo, logpx, logpz, logqz, action_dist_kl, act_dif
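# A minimal, hypothetical training-step sketch for the multi-policy forward above.
# The names model, policies, optimizer, and batch are assumptions, not the repo's actual
# training code; the bound is maximized, so the step minimizes its negation.
def _train_step_sketch(model, policies, optimizer, batch):
    # batch: [B,2,84,84] frames on the same device as the model
    elbo, logpx, logpz, logqz, action_dist_kl, act_dif = model.forward(batch, policies, k=1)
    loss = -elbo  # maximizing the bound minimizes the action-KL and activation penalties
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.data[0]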
def forward(self, x, policies, k=1):
    # x: [B,2,84,84]
    self.B = x.size()[0]
    mu, logvar = self.encode(x)
    # Deterministic reconstruction: decode the mean rather than a sampled z
    x_hat = self.decode(mu)  #[PB,X]
    x_hat_sigmoid = F.sigmoid(x_hat)

    kls = []
    grad_difs = []
    for p in range(len(policies)):
        # Re-wrap x so gradients of the policy output w.r.t. the input frames can be taken
        x = Variable(x.data, requires_grad=True, volatile=False)

        log_dist_recon = policies[p].action_logdist(x_hat_sigmoid)
        log_dist_true = policies[p].action_logdist(x)

        # KL(true action dist || reconstructed action dist), per example
        action_dist_kl = torch.sum((log_dist_true - log_dist_recon) * torch.exp(log_dist_true), dim=1)  #[B]
        kls.append(action_dist_kl)

        # Gradient of a single action's log-prob w.r.t. the input frames
        # ent_true = torch.mean(torch.sum(log_dist_true*torch.exp(log_dist_true), dim=1))
        ent_true = torch.mean(log_dist_true[:, 3])
        grad_true = torch.autograd.grad(ent_true, x, create_graph=True, retain_graph=True)[0]  #[B,2,84,84]

        # ent_recon = torch.mean(torch.sum(log_dist_recon*torch.exp(log_dist_recon), dim=1))
        ent_recon = torch.mean(log_dist_recon[:, 3])
        grad_recon = torch.autograd.grad(ent_recon, x_hat_sigmoid, create_graph=True, retain_graph=True)[0]  #[B,2,84,84]

        # Match the two input-gradients
        grad_dif = torch.sum((grad_recon - grad_true)**2)  #[1]
        grad_difs.append(grad_dif)

    # Average over policies
    kls = torch.stack(kls)  #[policies, B]
    action_dist_kl = torch.mean(kls)  #[1]

    grad_difs = torch.stack(grad_difs)  #[policies]
    grad_dif = torch.mean(grad_difs)  #[1]

    # Likelihood (monitored only; not part of the loss)
    flat_x_hat = x_hat.view(k, self.B, -1)
    flat_x = x.view(self.B, -1)
    logpx = log_bernoulli(flat_x_hat, flat_x)  #[P,B]
    logpx = torch.mean(logpx)

    loss = grad_dif + action_dist_kl  #[1]

    return loss, logpx, grad_dif, action_dist_kl
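# The grad_dif term above depends on double backprop: torch.autograd.grad is called with
# create_graph=True so the input-gradients stay differentiable, and the squared difference
# between them can then be backpropagated into the decoder. A tiny standalone illustration
# of that pattern (hypothetical toy tensors, unrelated to the model above):
def _double_backprop_sketch():
    x = Variable(torch.randn(4, 3), requires_grad=True)
    w = Variable(torch.randn(3, 2), requires_grad=True)
    y = torch.sum(torch.mm(x, w) ** 2)
    # First-order gradient w.r.t. the input, kept in the graph
    gx = torch.autograd.grad(y, x, create_graph=True)[0]  # [4,3]
    # A penalty on that gradient is itself differentiable w.r.t. the weights
    penalty = torch.sum(gx ** 2)
    gw = torch.autograd.grad(penalty, w)[0]  # [3,2]
    return gw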