def logprob(self, input, sample_size=128, z=None):
    ''' input: positive samples '''
    # init
    batch_size = input.size(0)
    # assumes square inputs (input_height is used for both spatial dims)
    input = input.view(batch_size, self.input_channels,
                       self.input_height, self.input_height)

    ''' get log q(z|x) '''
    _, mu_qz, logvar_qz = self.encode(input)
    mu_qz = mu_qz.detach().repeat(1, sample_size).view(
        batch_size, sample_size, self.z_dim)
    logvar_qz = logvar_qz.detach().repeat(1, sample_size).view(
        batch_size, sample_size, self.z_dim)
    z = self.encode.sample(mu_qz, logvar_qz)
    logposterior = logprob_gaussian(mu_qz, logvar_qz, z,
                                    do_unsqueeze=False, do_mean=False)
    logposterior = torch.sum(
        logposterior.view(batch_size, sample_size, self.z_dim),
        dim=2)  # bsz x ssz

    ''' get log p(z) '''
    # get prior (as unit normal dist)
    mu_pz = input.new_zeros(batch_size, sample_size, self.z_dim)
    logvar_pz = input.new_zeros(batch_size, sample_size, self.z_dim)
    logprior = logprob_gaussian(mu_pz, logvar_pz, z,
                                do_unsqueeze=False, do_mean=False)
    logprior = torch.sum(
        logprior.view(batch_size, sample_size, self.z_dim),
        dim=2)  # bsz x ssz

    ''' get log p(x|z) '''
    # decode, one batch element at a time
    logit_x = []
    for i in range(batch_size):
        _, _logit_x = self.decode(z[i, :, :])  # input: ssz x zdim
        logit_x += [_logit_x.detach().unsqueeze(0)]
    logit_x = torch.cat(logit_x, dim=0)  # bsz x ssz x C x H x H
    _input = input.unsqueeze(1).expand(
        batch_size, sample_size, self.input_channels,
        self.input_height, self.input_height)  # bsz x ssz x C x H x H
    loglikelihood = -F.binary_cross_entropy_with_logits(
        logit_x, _input, reduction='none')
    loglikelihood = torch.sum(
        loglikelihood.view(batch_size, sample_size, -1),
        dim=2)  # bsz x ssz

    ''' get log p(x|z)p(z)/q(z|x) '''
    logprob = loglikelihood + logprior - logposterior  # bsz x ssz
    logprob_max, _ = torch.max(logprob, dim=1, keepdim=True)
    rprob = (logprob - logprob_max).exp()  # relative prob
    logprob = torch.log(
        torch.mean(rprob, dim=1, keepdim=True) + 1e-10) + logprob_max  # bsz x 1

    # return
    return logprob.mean()
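# NOTE: every estimator in this section relies on helpers defined elsewhere
# in the repo (logprob_gaussian, sample_gaussian, get_covmat, ...). For
# reference, here is a minimal sketch of the elementwise diagonal-Gaussian
# log-density that logprob_gaussian is assumed to compute (the real helper's
# do_unsqueeze/do_mean/do_sum options are not reproduced, and the name
# logprob_gaussian_sketch is ours, not the repo's):
import math

import torch


def logprob_gaussian_sketch(mu, logvar, z):
    # log N(z; mu, diag(exp(logvar))) per element; callers reduce over the
    # latent dimension themselves, e.g. torch.sum(..., dim=2)
    return -0.5 * (math.log(2.0 * math.pi) + logvar
                   + (z - mu).pow(2) / logvar.exp())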
def logprob_w_prior(self, input, sample_size=128, z=None):
    # init
    batch_size = input.size(0)
    input = input.view(batch_size, self.input_dim)

    ''' get z samples from p(z) '''
    # get prior (as unit normal dist)
    if z is None:
        mu_pz = input.new_zeros(batch_size, sample_size, self.z_dim)
        logvar_pz = input.new_zeros(batch_size, sample_size, self.z_dim)
        z = sample_gaussian(mu_pz, logvar_pz)  # sample z

    ''' get log p(x|z) '''
    # decode
    _z = z.view(-1, self.z_dim)  # bsz*ssz x zdim
    _, mu_x, logvar_x = self.decode(_z)
    mu_x = mu_x.view(batch_size, sample_size, self.input_dim)
    logvar_x = logvar_x.view(batch_size, sample_size, self.input_dim)
    _input = input.unsqueeze(1).expand(
        batch_size, sample_size, self.input_dim)  # bsz x ssz x input_dim
    loglikelihood = logprob_gaussian(mu_x, logvar_x, _input,
                                     do_unsqueeze=False, do_mean=False)
    loglikelihood = torch.sum(loglikelihood, dim=2)  # bsz x ssz

    ''' get log p(x) '''
    logprob = loglikelihood  # bsz x ssz
    logprob_max, _ = torch.max(logprob, dim=1, keepdim=True)
    rprob = (logprob - logprob_max).exp()  # relative prob
    logprob = torch.log(
        torch.mean(rprob, dim=1, keepdim=True) + 1e-10) + logprob_max  # bsz x 1

    # return
    return logprob.mean()
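# sample_gaussian is likewise defined elsewhere; presumably the standard
# reparameterized sampler, sketched here under that assumption
# (sample_gaussian_sketch is our name; uses the module's torch import):
def sample_gaussian_sketch(mu, logvar):
    # z = mu + sigma * eps with eps ~ N(0, I); differentiable w.r.t. mu, logvar
    return mu + (0.5 * logvar).exp() * torch.randn_like(mu)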
def logprob_w_cov_gaussian_posterior(self, input, sample_size=128,
                                     z=None, std=None):
    # init
    batch_size = input.size(0)
    input = input.view(batch_size, self.input_dim)
    assert sample_size >= 2 * self.z_dim

    ''' get z and pseudo log q(newz|x) '''
    _, _, _, _, z, _, _, _, _ = self.encode._forward(
        input, std=std, nz=sample_size)  # bsz x ssz x zdim
    newz = []
    logposterior = []
    eye = torch.eye(self.z_dim, device=z.device)
    mu_qz = torch.mean(z, dim=1)  # bsz x zdim
    for i in range(batch_size):
        # fit a full-covariance Gaussian to the encoder samples
        # (jitter on the diagonal keeps the covariance positive definite)
        _cov_qz = get_covmat(z[i, :, :]) + 1e-5 * eye
        _rv_z = MultivariateNormal(mu_qz[i], _cov_qz)
        _newz = _rv_z.rsample(torch.Size([1, sample_size]))  # 1 x ssz x zdim
        _logposterior = _rv_z.log_prob(_newz)  # 1 x ssz
        newz += [_newz]
        logposterior += [_logposterior]
    newz = torch.cat(newz, dim=0)  # bsz x ssz x zdim
    logposterior = torch.cat(logposterior, dim=0)  # bsz x ssz

    ''' get log p(z) '''
    # get prior (as unit normal dist)
    mu_pz = input.new_zeros(batch_size, sample_size, self.z_dim)
    logvar_pz = input.new_zeros(batch_size, sample_size, self.z_dim)
    logprior = logprob_gaussian(mu_pz, logvar_pz, newz,
                                do_unsqueeze=False, do_mean=False)
    logprior = torch.sum(
        logprior.view(batch_size, sample_size, self.z_dim),
        dim=2)  # bsz x ssz

    ''' get log p(x|z) '''
    # decode
    mu_x, logvar_x = [], []
    for i in range(batch_size):
        _, _mu_x, _logvar_x = self.decode(newz[i, :, :])
        mu_x += [_mu_x.detach().unsqueeze(0)]
        logvar_x += [_logvar_x.detach().unsqueeze(0)]
    mu_x = torch.cat(mu_x, dim=0)  # bsz x ssz x input_dim
    logvar_x = torch.cat(logvar_x, dim=0)  # bsz x ssz x input_dim
    _input = input.unsqueeze(1).expand(
        batch_size, sample_size, self.input_dim)  # bsz x ssz x input_dim
    loglikelihood = logprob_gaussian(mu_x, logvar_x, _input,
                                     do_unsqueeze=False, do_mean=False)
    loglikelihood = torch.sum(loglikelihood, dim=2)  # bsz x ssz

    ''' get log p(x|z)p(z)/q(z|x) '''
    logprob = loglikelihood + logprior - logposterior  # bsz x ssz
    logprob_max, _ = torch.max(logprob, dim=1, keepdim=True)
    rprob = (logprob - logprob_max).exp()  # relative prob
    logprob = torch.log(
        torch.mean(rprob, dim=1, keepdim=True) + 1e-10) + logprob_max  # bsz x 1

    # return
    return logprob.mean()
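# get_covmat is assumed to return the sample covariance of the ssz encoder
# samples; a minimal sketch (unbiased normalization assumed here, which the
# real helper may or may not use; get_covmat_sketch is our name):
def get_covmat_sketch(z):
    # z: ssz x zdim -> zdim x zdim sample covariance
    zc = z - z.mean(dim=0, keepdim=True)
    return zc.t().mm(zc) / (z.size(0) - 1)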
def logprob(self, input, sample_size=128, z=None):
    # init
    batch_size = input.size(0)
    input = input.view(batch_size, self.input_dim)

    ''' get log q(z|x) '''
    _, mu_qz, logvar_qz = self.encode(input)
    mu_qz = mu_qz.detach().repeat(1, sample_size).view(
        batch_size, sample_size, self.z_dim)
    logvar_qz = logvar_qz.detach().repeat(1, sample_size).view(
        batch_size, sample_size, self.z_dim)
    z = self.encode.sample(mu_qz, logvar_qz)
    logposterior = logprob_gaussian(mu_qz, logvar_qz, z,
                                    do_unsqueeze=False, do_mean=False)
    logposterior = torch.sum(
        logposterior.view(batch_size, sample_size, self.z_dim),
        dim=2)  # bsz x ssz

    ''' get log p(z) '''
    # get prior (as unit normal dist)
    mu_pz = input.new_zeros(batch_size, sample_size, self.z_dim)
    logvar_pz = input.new_zeros(batch_size, sample_size, self.z_dim)
    logprior = logprob_gaussian(mu_pz, logvar_pz, z,
                                do_unsqueeze=False, do_mean=False)
    logprior = torch.sum(
        logprior.view(batch_size, sample_size, self.z_dim),
        dim=2)  # bsz x ssz

    ''' get log p(x|z) '''
    # decode all bsz*ssz samples in one pass
    _z = z.view(-1, self.z_dim)  # bsz*ssz x zdim
    _, mu_x, logvar_x = self.decode(_z)
    mu_x = mu_x.view(batch_size, sample_size, self.input_dim)
    logvar_x = logvar_x.view(batch_size, sample_size, self.input_dim)
    _input = input.unsqueeze(1).expand(
        batch_size, sample_size, self.input_dim)  # bsz x ssz x input_dim
    loglikelihood = logprob_gaussian(mu_x, logvar_x, _input,
                                     do_unsqueeze=False, do_mean=False)
    loglikelihood = torch.sum(loglikelihood, dim=2)  # bsz x ssz

    ''' get log p(x|z)p(z)/q(z|x) '''
    logprob = loglikelihood + logprior - logposterior  # bsz x ssz
    logprob_max, _ = torch.max(logprob, dim=1, keepdim=True)
    rprob = (logprob - logprob_max).exp()  # relative prob
    logprob = torch.log(
        torch.mean(rprob, dim=1, keepdim=True) + 1e-10) + logprob_max  # bsz x 1

    # return
    return logprob.mean()
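# The last three lines of each estimator are a numerically stable
# log-mean-exp over the K importance log-weights: subtract the per-row max,
# exponentiate, average, then add the max back. Up to the 1e-10 floor, this
# equals the one-liner below (log_mean_exp_sketch is our name):
def log_mean_exp_sketch(logw, dim=1):
    # logw: bsz x K -> bsz x 1; computes log (1/K) sum_k exp(logw_k)
    return torch.logsumexp(logw, dim=dim, keepdim=True) \
        - math.log(logw.size(dim))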
def logprob_w_diag_gaussian_posterior(self, input, sample_size=128,
                                      z=None, std=None):
    # init
    batch_size = input.size(0)
    input = input.view(batch_size, self.input_dim)

    ''' get z '''
    z = []
    for i in range(sample_size):
        _z = self.encode(input, std=std)
        _z_flattened = _z.view(_z.size(1) * _z.size(2), -1)
        z += [_z_flattened.detach().unsqueeze(1)]
    z = torch.cat(z, dim=1)  # bsz x ssz x zdim
    # moment-match a diagonal Gaussian to the encoder samples
    mu_qz = torch.mean(z, dim=1)
    logvar_qz = torch.log(torch.var(z, dim=1) + 1e-10)

    ''' get pseudo log q(z|x) '''
    mu_qz = mu_qz.detach().repeat(1, sample_size).view(
        batch_size, sample_size, self.z_dim)
    logvar_qz = logvar_qz.detach().repeat(1, sample_size).view(
        batch_size, sample_size, self.z_dim)
    newz = sample_gaussian(mu_qz, logvar_qz)
    logposterior = logprob_gaussian(mu_qz, logvar_qz, newz,
                                    do_unsqueeze=False, do_mean=False)
    logposterior = torch.sum(
        logposterior.view(batch_size, sample_size, self.z_dim),
        dim=2)  # bsz x ssz

    ''' get log p(z) '''
    # get prior (as unit normal dist)
    mu_pz = input.new_zeros(batch_size, sample_size, self.z_dim)
    logvar_pz = input.new_zeros(batch_size, sample_size, self.z_dim)
    logprior = logprob_gaussian(mu_pz, logvar_pz, newz,
                                do_unsqueeze=False, do_mean=False)
    logprior = torch.sum(
        logprior.view(batch_size, sample_size, self.z_dim),
        dim=2)  # bsz x ssz

    ''' get log p(x|z) '''
    # decode
    logit_x = []
    for i in range(sample_size):
        _, _logit_x = self.decode(newz[:, i, :])
        logit_x += [_logit_x.detach().unsqueeze(1)]
    logit_x = torch.cat(logit_x, dim=1)  # bsz x ssz x input_dim
    _input = input.unsqueeze(1).expand(
        batch_size, sample_size, self.input_dim)  # bsz x ssz x input_dim
    loglikelihood = -F.binary_cross_entropy_with_logits(
        logit_x, _input, reduction='none')
    loglikelihood = torch.sum(loglikelihood, dim=2)  # bsz x ssz

    ''' get log p(x|z)p(z)/q(z|x) '''
    logprob = loglikelihood + logprior - logposterior  # bsz x ssz
    logprob_max, _ = torch.max(logprob, dim=1, keepdim=True)
    rprob = (logprob - logprob_max).exp()  # relative prob
    logprob = torch.log(
        torch.mean(rprob, dim=1, keepdim=True) + 1e-10) + logprob_max  # bsz x 1

    # return
    return logprob.mean()
def logprob(self, input, sample_size=128, z=None):
    # init
    batch_size = input.size(0)
    sample_size1 = sample_size  # outer samples (z0)
    sample_size2 = 1  # inner samples (z per z0)
    input = input.view(batch_size, self.input_channels,
                       self.input_height, self.input_height)

    ''' get - (log q(z|z0,x) + log q(z0|x) - log p(z0|z,x) - log p(z)) '''

    ''' get log q(z0|x) '''
    _, mu_qz0, logvar_qz0, _ = self.aux_encode(input)
    mu_qz0 = mu_qz0.unsqueeze(1).expand(
        batch_size, sample_size1, self.z0_dim).contiguous().view(
        batch_size * sample_size1, self.z0_dim)  # bsz*ssz1 x z0_dim
    logvar_qz0 = logvar_qz0.unsqueeze(1).expand(
        batch_size, sample_size1, self.z0_dim).contiguous().view(
        batch_size * sample_size1, self.z0_dim)  # bsz*ssz1 x z0_dim
    z0 = self.aux_encode.sample(mu_qz0, logvar_qz0)  # bsz*ssz1 x z0_dim
    log_qz0 = logprob_gaussian(mu_qz0, logvar_qz0, z0,
                               do_unsqueeze=False, do_mean=False)
    log_qz0 = torch.sum(
        log_qz0.view(batch_size, sample_size1, self.z0_dim),
        dim=2)  # bsz x ssz1
    log_qz0 = log_qz0.unsqueeze(2).expand(
        batch_size, sample_size1, sample_size2).contiguous().view(
        batch_size, sample_size1 * sample_size2)  # bsz x ssz1*ssz2

    ''' get log q(z|z0,x) '''
    # forward
    _, mu_qz, logvar_qz, _ = self.encode(
        input, z0, nz=sample_size1)  # bsz*ssz1 x z_dim
    mu_qz = mu_qz.detach().repeat(1, sample_size2).view(
        batch_size * sample_size1, sample_size2, self.z_dim)
    logvar_qz = logvar_qz.detach().repeat(1, sample_size2).view(
        batch_size * sample_size1, sample_size2, self.z_dim)
    z = self.encode.sample(mu_qz, logvar_qz)  # bsz*ssz1 x ssz2 x z_dim
    log_qz = logprob_gaussian(mu_qz, logvar_qz, z,
                              do_unsqueeze=False, do_mean=False)
    log_qz = torch.sum(
        log_qz.view(batch_size, sample_size1 * sample_size2, self.z_dim),
        dim=2)  # bsz x ssz1*ssz2

    ''' get log p(z0|z,x) '''
    # encode
    _z0 = z0.unsqueeze(1).expand(
        batch_size * sample_size1, sample_size2,
        self.z0_dim).contiguous().view(
        batch_size, sample_size1, sample_size2, self.z0_dim).detach()
    _, mu_pz0, logvar_pz0 = self.aux_decode(
        input, z.view(-1, self.z_dim),
        nz=sample_size1 * sample_size2)  # bsz*ssz1*ssz2 x z0_dim
    mu_pz0 = mu_pz0.view(batch_size, sample_size1, sample_size2,
                         self.z0_dim)
    logvar_pz0 = logvar_pz0.view(batch_size, sample_size1, sample_size2,
                                 self.z0_dim)
    log_pz0 = logprob_gaussian(mu_pz0, logvar_pz0, _z0,
                               do_unsqueeze=False,
                               do_mean=False)  # bsz x ssz1 x ssz2 x z0_dim
    log_pz0 = torch.sum(
        log_pz0.view(batch_size, sample_size1 * sample_size2, self.z0_dim),
        dim=2)  # bsz x ssz1*ssz2

    ''' get log p(z) '''
    # get prior (as unit normal dist)
    mu_pz = input.new_zeros(batch_size * sample_size1, sample_size2,
                            self.z_dim)
    logvar_pz = input.new_zeros(batch_size * sample_size1, sample_size2,
                                self.z_dim)
    log_pz = logprob_gaussian(mu_pz, logvar_pz, z,
                              do_unsqueeze=False, do_mean=False)
    log_pz = torch.sum(
        log_pz.view(batch_size, sample_size1 * sample_size2, self.z_dim),
        dim=2)  # bsz x ssz1*ssz2

    ''' get log p(x|z) '''
    # decode
    _input = input.unsqueeze(1).unsqueeze(1).expand(
        batch_size, sample_size1, sample_size2, self.input_channels,
        self.input_height, self.input_height)  # bsz x ssz1 x ssz2 x C x H x H
    _z = z.view(-1, self.z_dim)  # bsz*ssz1*ssz2 x zdim
    _, logit_px = self.decode(_z)
    logit_px = logit_px.view(batch_size, sample_size1, sample_size2,
                             self.input_channels, self.input_height,
                             self.input_height)
    loglikelihood = -F.binary_cross_entropy_with_logits(
        logit_px, _input, reduction='none')
    loglikelihood = torch.sum(
        loglikelihood.view(batch_size, sample_size1 * sample_size2, -1),
        dim=2)  # bsz x ssz1*ssz2

    ''' get log p(x|z)p(z)p(z0|z,x) / (q(z|z0,x)q(z0|x)) '''
    logprob = loglikelihood + log_pz + log_pz0 \
        - log_qz - log_qz0  # bsz x ssz1*ssz2
    logprob_max, _ = torch.max(logprob, dim=1, keepdim=True)
    rprob = (logprob - logprob_max).exp()  # relative prob
    logprob = torch.log(
        torch.mean(rprob, dim=1, keepdim=True) + 1e-10) + logprob_max  # bsz x 1

    # return
    return logprob.mean()
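# Note on the nesting above: sample_size1 z0-samples are drawn per input and
# sample_size2 (= 1 here) z-samples per z0, giving ssz1*ssz2 importance
# samples in total. The per-sample weight for this auxiliary-variable model is
#   w = p(x|z) p(z) p(z0|z,x) / (q(z|z0,x) q(z0|x)),
# so log w = loglikelihood + log_pz + log_pz0 - log_qz - log_qz0, and the
# max-subtracted log-mean-exp of these weights estimates log p(x).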
def logprob_w_cov_gaussian_posterior(self, input, sample_size=128,
                                     z=None, std=None):
    # init
    batch_size = input.size(0)
    input = input.view(batch_size, self.input_dim)
    assert sample_size >= 2 * self.z_dim

    ''' get z and pseudo log q(newz|x) '''
    z, newz = [], []
    logposterior = []
    inp = self.encode._forward_inp(input).detach()
    for i in range(batch_size):
        _inp = inp[i:i + 1, :].expand(sample_size, inp.size(1))
        _nos = self.encode._forward_nos(batch_size=sample_size, std=std,
                                        device=input.device).detach()
        _z = self.encode._forward_all(_inp, _nos)  # ssz x zdim
        z += [_z.detach().unsqueeze(0)]
    z = torch.cat(z, dim=0)  # bsz x ssz x zdim
    mu_qz = torch.mean(z, dim=1)  # bsz x zdim
    for i in range(batch_size):
        _cov_qz = get_covmat(z[i, :, :])
        _rv_z = MultivariateNormal(mu_qz[i], _cov_qz)
        _newz = _rv_z.rsample(torch.Size([1, sample_size]))  # 1 x ssz x zdim
        _logposterior = _rv_z.log_prob(_newz)  # 1 x ssz
        newz += [_newz]
        logposterior += [_logposterior]
    newz = torch.cat(newz, dim=0)  # bsz x ssz x zdim
    logposterior = torch.cat(logposterior, dim=0)  # bsz x ssz

    ''' get log p(z) '''
    # get prior (as unit normal dist)
    mu_pz = input.new_zeros(batch_size, sample_size, self.z_dim)
    logvar_pz = input.new_zeros(batch_size, sample_size, self.z_dim)
    logprior = logprob_gaussian(mu_pz, logvar_pz, newz,
                                do_unsqueeze=False, do_mean=False)
    logprior = torch.sum(
        logprior.view(batch_size, sample_size, self.z_dim),
        dim=2)  # bsz x ssz

    ''' get log p(x|z) '''
    # decode
    logit_x = []
    for i in range(batch_size):
        _, _logit_x = self.decode(newz[i, :, :])  # input: ssz x zdim
        logit_x += [_logit_x.detach().unsqueeze(0)]
    logit_x = torch.cat(logit_x, dim=0)  # bsz x ssz x input_dim
    _input = input.unsqueeze(1).expand(
        batch_size, sample_size, self.input_dim)  # bsz x ssz x input_dim
    loglikelihood = -F.binary_cross_entropy_with_logits(
        logit_x, _input, reduction='none')
    loglikelihood = torch.sum(loglikelihood, dim=2)  # bsz x ssz

    ''' get log p(x|z)p(z)/q(z|x) '''
    logprob = loglikelihood + logprior - logposterior  # bsz x ssz
    logprob_max, _ = torch.max(logprob, dim=1, keepdim=True)
    rprob = (logprob - logprob_max).exp()  # relative prob
    logprob = torch.log(
        torch.mean(rprob, dim=1, keepdim=True) + 1e-10) + logprob_max  # bsz x 1

    # return
    return logprob.mean()
def logprob_w_kde_posterior(self, input, sample_size=128, z=None, std=None):
    # init
    batch_size = input.size(0)
    input = input.view(batch_size, self.input_dim)
    assert sample_size >= 2 * self.z_dim

    ''' get z and pseudo log q(newz|x) '''
    z, newz = [], []
    logposterior = []
    inp = self.encode._forward_inp(input).detach()
    for i in range(batch_size):
        _inp = inp[i:i + 1, :].expand(sample_size, inp.size(1))
        _nos = self.encode._forward_nos(sample_size, std=std,
                                        device=input.device).detach()
        _z = self.encode._forward_all(_inp, _nos)  # ssz x zdim
        z += [_z.detach().unsqueeze(0)]
    z = torch.cat(z, dim=0)  # bsz x ssz x zdim
    for i in range(batch_size):
        # fit a Gaussian KDE to the encoder samples, then resample from it
        _z = z[i, :, :].cpu().numpy().T  # zdim x ssz
        kernel = stats.gaussian_kde(_z)
        _newz = kernel.resample(sample_size)  # zdim x ssz
        _logposterior = kernel.logpdf(_newz)  # ssz
        _newz = torch.from_numpy(_newz.T).float().to(
            input.device)  # ssz x zdim
        _logposterior = torch.from_numpy(_logposterior).float().to(
            input.device)  # ssz
        newz += [_newz.unsqueeze(0)]
        logposterior += [_logposterior.unsqueeze(0)]
    newz = torch.cat(newz, dim=0)  # bsz x ssz x zdim
    logposterior = torch.cat(logposterior, dim=0)  # bsz x ssz

    ''' get log p(z) '''
    # get prior (as unit normal dist)
    mu_pz = input.new_zeros(batch_size, sample_size, self.z_dim)
    logvar_pz = input.new_zeros(batch_size, sample_size, self.z_dim)
    logprior = logprob_gaussian(mu_pz, logvar_pz, newz,
                                do_unsqueeze=False, do_mean=False)
    logprior = torch.sum(
        logprior.view(batch_size, sample_size, self.z_dim),
        dim=2)  # bsz x ssz

    ''' get log p(x|z) '''
    # decode
    logit_x = []
    for i in range(batch_size):
        _, _logit_x = self.decode(newz[i, :, :])  # input: ssz x zdim
        logit_x += [_logit_x.detach().unsqueeze(0)]
    logit_x = torch.cat(logit_x, dim=0)  # bsz x ssz x input_dim
    _input = input.unsqueeze(1).expand(
        batch_size, sample_size, self.input_dim)  # bsz x ssz x input_dim
    loglikelihood = -F.binary_cross_entropy_with_logits(
        logit_x, _input, reduction='none')
    loglikelihood = torch.sum(loglikelihood, dim=2)  # bsz x ssz

    ''' get log p(x|z)p(z)/q(z|x) '''
    logprob = loglikelihood + logprior - logposterior  # bsz x ssz
    logprob_max, _ = torch.max(logprob, dim=1, keepdim=True)
    rprob = (logprob - logprob_max).exp()  # relative prob
    logprob = torch.log(
        torch.mean(rprob, dim=1, keepdim=True) + 1e-10) + logprob_max  # bsz x 1

    # return
    return logprob.mean()
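# Shape conventions for the SciPy KDE used above: gaussian_kde takes data as
# (zdim, n_samples) (hence the transpose), resample(n) returns (zdim, n), and
# logpdf over (zdim, m) points returns (m,). A standalone demo:
import numpy as np
from scipy import stats

_demo_z = np.random.randn(4, 256)        # zdim=4, ssz=256
_kernel = stats.gaussian_kde(_demo_z)    # bandwidth via Scott's rule
_demo_new = _kernel.resample(256)        # 4 x 256 resampled points
_demo_logp = _kernel.logpdf(_demo_new)   # (256,) log-densities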
def infogain(self, reps_context, context_sizes, reps_target, target_sizes,
             input_tuples, num_steps=None, beta=1.0, std=1.0):
    # init
    num_episodes = len(reps_context)
    loss_kl = 0

    ''' forward posterior / prior '''
    # init states
    states_p = self.rnn_p.init_state(num_episodes,
                                     [self.z_height, self.z_width])
    states_q = self.rnn_q.init_state(num_episodes,
                                     [self.z_height, self.z_width])
    hiddens_p = [state_p[0] for state_p in states_p]
    hiddens_q = [state_q[0] for state_q in states_q]
    latents = []
    init_input_q = False
    init_input_p = False
    for i in range(num_steps if num_steps is not None else self.num_steps):
        # aggregate observations (posterior); done once, on the first step
        if not init_input_q:
            reps_context = pad_sequence(reps_context, context_sizes)
            reps_context = torch.sum(reps_context, dim=1)
            reps_context = reps_context.view(-1, self.nc_context,
                                             self.z_height, self.z_width)
            reps_target = pad_sequence(reps_target, target_sizes)
            reps_target = torch.sum(reps_target, dim=1)
            reps_target = reps_target.view(-1, self.nc_context,
                                           self.z_height, self.z_width)
            input_q = torch.cat([reps_target, reps_context], dim=1)
            init_input_q = True

        # forward posterior
        means_q, logvars_q, hiddens_q, states_q = self.rnn_q(
            input_q, states_q, hiddens_p)

        # sample z from posterior
        zs = self.rnn_q.sample(means_q, logvars_q)

        # aggregate observations (prior)
        if not init_input_p:
            input_p = reps_context
            init_input_p = True

        # forward prior
        _, means_p, logvars_p, hiddens_p, states_p = self.rnn_p(
            input_p, states_p, latents_q=zs)

        # append z to latent
        latents += ([torch.cat(zs, dim=1).unsqueeze(1)] if len(zs) > 1
                    else [zs[0].unsqueeze(1)])

        # update accumulated KL
        for j in range(self.num_layers):
            loss_kl += logprob_gaussian(
                means_q[j], logvars_q[j], zs[j], do_sum=False)
            loss_kl += -logprob_gaussian(
                means_p[j], logvars_p[j], zs[j], do_sum=False)

    ''' loss '''
    # additional loss info
    info = {}
    info['kl'] = loss_kl.detach()

    # return
    return None, latents, loss_kl.detach(), info
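# The KL accumulated in infogain is a single-sample Monte Carlo estimate:
#   KL(q || p) = E_{z~q}[log q(z) - log p(z)] ~= log q(zs) - log p(zs)
# with zs drawn from the posterior, which is why logprob_gaussian is
# evaluated at the same sample under both the posterior and prior parameters
# (predict below uses the closed-form loss_kld_gaussian_vs_gaussian instead).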
def predict(self, reps_context, context_sizes, reps_target, target_sizes,
            input_tuples, num_steps=None, beta=1.0, std=1.0,
            is_grayscale=False, use_uint8=True):
    # init
    num_episodes = len(reps_context)
    logprob_kl = 0
    loss_kl = 0

    ''' forward posterior / prior '''
    # init states
    states_p = self.rnn_p.init_state(num_episodes,
                                     [self.z_height, self.z_width])
    states_q = self.rnn_q.init_state(num_episodes,
                                     [self.z_height, self.z_width])
    hiddens_p = [state_p[0] for state_p in states_p]
    hiddens_q = [state_q[0] for state_q in states_q]
    latents = []
    init_input_q = False
    init_input_p = False
    for i in range(num_steps if num_steps is not None else self.num_steps):
        # aggregate observations (posterior); done once, on the first step
        if not init_input_q:
            reps_context = pad_sequence(reps_context, context_sizes)
            reps_context = torch.sum(reps_context, dim=1)
            reps_context = reps_context.view(-1, self.nc_context,
                                             self.z_height, self.z_width)
            reps_target = pad_sequence(reps_target, target_sizes)
            reps_target = torch.sum(reps_target, dim=1)
            reps_target = reps_target.view(-1, self.nc_context,
                                           self.z_height, self.z_width)
            input_q = torch.cat([reps_target, reps_context], dim=1)
            init_input_q = True

        # forward posterior
        means_q, logvars_q, hiddens_q, states_q = self.rnn_q(
            input_q, states_q, hiddens_p)

        # sample z from posterior
        zs = self.rnn_q.sample(means_q, logvars_q)

        # aggregate observations (prior)
        if not init_input_p:
            input_p = reps_context
            init_input_p = True

        # forward prior
        _, means_p, logvars_p, hiddens_p, states_p = self.rnn_p(
            input_p, states_p, latents_q=zs)

        # append z to latent
        latents += ([torch.cat(zs, dim=1).unsqueeze(1)] if len(zs) > 1
                    else [zs[0].unsqueeze(1)])

        # update accumulated KL
        for j in range(self.num_layers):
            loss_kl += loss_kld_gaussian_vs_gaussian(
                means_q[j], logvars_q[j], means_p[j], logvars_p[j])
            logprob_kl += logprob_gaussian(
                means_p[j], logvars_p[j], zs[j], do_sum=False)
            logprob_kl += -logprob_gaussian(
                means_q[j], logvars_q[j], zs[j], do_sum=False)

    ''' likelihood '''
    info = {}
    info['logprob_mod_likelihoods'] = []
    logprob_likelihood = 0
    info['mod_likelihoods'] = []
    loss_likelihood = 0
    mean_recons = []
    for idx, (dim, input_tuple) in enumerate(zip(self.dims, input_tuples)):
        channels, height, width, _, mtype = dim
        mod_target, mod_queries, mod_target_indices, mod_batch_sizes = \
            input_tuple
        if len(mod_queries) > 0:
            num_mod_data = len(mod_target)
            assert sum(mod_batch_sizes) == num_mod_data

            # run renderer (likelihood)
            mod_mean_recon = self._forward_renderer(
                idx, mod_queries, latents, num_episodes,
                mod_batch_sizes, mod_target_indices).detach()

            # convert to grayscale
            if mtype == 'image' and is_grayscale:
                mod_mean_recon = rgb2gray(mod_mean_recon)
                mod_target = rgb2gray(mod_target)
                if not use_uint8:
                    mod_mean_recon = mod_mean_recon / 255
                    mod_target = mod_target / 255
            elif mtype == 'image' and use_uint8:
                mod_mean_recon = 255 * mod_mean_recon
                mod_target = 255 * mod_target

            # estimate recon loss
            loss_mod_likelihood = loss_recon_gaussian_w_fixed_var(
                mod_mean_recon, mod_target, std=std,
                add_logvar=False).detach()
            logprob_mod_likelihood = logprob_gaussian_w_fixed_var(
                mod_mean_recon, mod_target, std=std, do_sum=False).detach()

            # estimate recon loss without std
            loss_mod_likelihood_nostd = loss_recon_gaussian_w_fixed_var(
                mod_mean_recon.detach(), mod_target).detach()

            # sum per episode
            logprob_mod_likelihood = sum_tensor_per_episode(
                logprob_mod_likelihood, mod_batch_sizes,
                mod_target_indices, num_episodes)
        else:
            mod_mean_recon = reps_context.new_zeros(0, channels, height,
                                                    width)
            loss_mod_likelihood = None
            loss_mod_likelihood_nostd = None
            logprob_mod_likelihood = None

        # add to loss_likelihood
        if loss_mod_likelihood is not None:
            loss_likelihood += loss_mod_likelihood
        if logprob_mod_likelihood is not None:
            logprob_likelihood += logprob_mod_likelihood

        # append to list
        mean_recons += [mod_mean_recon]
        info['mod_likelihoods'] += [loss_mod_likelihood]
        info['logprob_mod_likelihoods'] += [logprob_mod_likelihood]

    ''' loss '''
    # sum loss
    loss = loss_likelihood + beta * loss_kl
    logprob = logprob_likelihood + logprob_kl

    # additional loss info
    info['likelihood'] = loss_likelihood.detach() \
        if type(loss_likelihood) is not int else 0
    info['kl'] = loss_kl.detach()

    # return
    return mean_recons, latents, logprob, info
def est_partition_func(self, sample_size=128, next_state_batch=None,
                       mask_batch=None, memory=None, batch_size=None,
                       ptflogvar=-2.):
    if memory is not None:
        assert batch_size is not None
        # sample
        _, _, _, next_state_batch, mask_batch = memory.sample(
            batch_size=batch_size)
        next_state_batch = torch.FloatTensor(next_state_batch).to(
            self.device)
        mask_batch = torch.FloatTensor(mask_batch).to(
            self.device).unsqueeze(1)
    else:
        assert next_state_batch is not None
        assert mask_batch is not None
        batch_size = next_state_batch.size(0)

    # context
    _, nxt_preact_mean, nxt_hidden, _ = self.policy.evaluate(
        next_state_batch, eval=True)
    nxt_preact_mean = nxt_preact_mean.view(batch_size, 1, -1).detach()
    if self.dae_ctx_type == 'state':
        nxt_context = next_state_batch.view(batch_size, 1, -1).detach()
    elif self.dae_ctx_type == 'hidden':
        nxt_context = nxt_hidden.view(batch_size, 1, -1).detach()

    # sample from a Gaussian proposal centered at the policy pre-activation mean
    _nxt_preact_mean = nxt_preact_mean.expand(batch_size, sample_size,
                                              self.num_actions)
    _nxt_preact_logvar = ptflogvar * nxt_preact_mean.new_ones(
        _nxt_preact_mean.size())
    _newz = sample_gaussian(_nxt_preact_mean,
                            _nxt_preact_logvar)  # bsz x ssz x zdim

    # proposal distribution
    logproposal = logprob_gaussian(_nxt_preact_mean, _nxt_preact_logvar,
                                   _newz, do_unsqueeze=False, do_mean=False)
    logproposal = torch.sum(logproposal, dim=2, keepdim=True) \
        - self.num_actions * math.log(self.std_scale)  # bsz x ssz x 1

    # unnormalized distribution
    newz = _newz - nxt_preact_mean
    scaled_newz = self.std_scale * newz
    stdmat = torch.zeros(batch_size, sample_size, 1, device=self.device)
    logp_ptfunc = (self.cdae.logprob(scaled_newz, nxt_context, std=stdmat,
                                     scale=self.std_scale).detach()
                   - logproposal)
    logp_ptfunc_max, _ = torch.max(logp_ptfunc, dim=1, keepdim=True)
    rprob_ptfunc = (logp_ptfunc - logp_ptfunc_max).exp()  # relative prob
    logp_ptfunc = torch.log(
        torch.mean(rprob_ptfunc, dim=1, keepdim=True)
        + 1e-12) + logp_ptfunc_max  # bsz x 1
    return logp_ptfunc.detach()
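# est_partition_func is a self-normalized importance-sampling estimator: with
# z_k drawn from the Gaussian proposal around the policy pre-activation mean,
#   log Z ~= log (1/K) sum_k exp(log p~(z_k) - log proposal(z_k)),
# stabilized with the same max-subtraction trick used throughout this file
# (here with a 1e-12 floor instead of 1e-10).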