def __init__(self, input_size, output_size=1, ):
    """Hold a learnable scale (``sigma``) and shift (``mu``) vector.

    Both parameters are length-``input_size`` and start at 1.
    ``output_size`` is accepted for interface parity with sibling
    modules and is not stored.
    """
    super().__init__()
    # Registration order (sigma, then mu) is preserved so state_dict
    # layouts stay compatible with existing checkpoints.
    for attr_name in ("sigma", "mu"):
        setattr(self, attr_name,
                nn.Parameter(ptu.ones(input_size, requires_grad=True)))
def __init__(
        self,
        hidden_sizes,
        output_size,
        input_size,
        temperature,
        vrnn_latent,
        vrnn_constraint,
        r_alpha,
        r_var,
):
    """Variational-RNN style context encoder.

    :param hidden_sizes: MLP layer widths; the last entry doubles as the
        recurrent hidden dimension.
    :param output_size: output width of the prior/posterior heads.
    :param input_size: per-step input feature width.
    :param temperature: stored for later use (presumably a
        Gumbel-softmax temperature — confirm at call sites).
    :param vrnn_latent: latent spec decoded by ``read_dim`` into
        continuous / categorical / Dirichlet components.
    :param vrnn_constraint: 'logitnormal' or 'dirichlet'; selects the
        prior family for the Dirichlet-like latent part.
    :param r_alpha: Dirichlet concentration for the 'dirichlet' prior.
    :param r_var: variance scale for the 'logitnormal' prior.
    """
    # Must be called first so the complete ctor argument dict is
    # captured for serialization.
    self.save_init_params(locals())
    super().__init__()
    self.input_size = input_size
    self.output_size = output_size
    self.hidden_sizes = hidden_sizes
    self.temperature = temperature
    # Unpack the latent layout: continuous dims, number/size of
    # categorical heads, number/size of Dirichlet heads.
    self.r_cont_dim, self.r_n_cat, self.r_cat_dim, self.r_n_dir, self.r_dir_dim = read_dim(
        vrnn_latent)
    # Total flattened latent width.
    self.z_size = self.r_cont_dim + self.r_n_cat * self.r_cat_dim + self.r_n_dir * self.r_dir_dim
    self.vrnn_constraint = vrnn_constraint
    self.r_alpha = r_alpha
    self.r_var = r_var
    self.hidden_dim = self.hidden_sizes[-1]
    # Recurrent state buffers (saved with the model weights).
    self.register_buffer('hn', torch.zeros(1, self.hidden_dim))
    self.register_buffer('cn', torch.zeros(1, self.hidden_dim))
    # input should be (task, seq, feat) and hidden should be (1, task, feat)
    self.rnn = nn.GRUCell(self.hidden_dim * 2, self.hidden_dim)
    # Prior head: hidden state -> prior statistics.
    self.prior = Mlp(hidden_sizes=[self.hidden_dim],
                     input_size=self.hidden_dim,
                     output_size=self.output_size)
    # Feature extractors for latent sample and raw input, respectively.
    self.phi_z = Mlp(hidden_sizes=[self.hidden_dim],
                     input_size=self.z_size,
                     output_size=self.hidden_dim)
    self.phi_x = Mlp(hidden_sizes=[self.hidden_dim],
                     input_size=self.input_size,
                     output_size=self.hidden_dim)
    # Posterior head: [phi_x, hidden] -> posterior statistics.
    self.encoder = Mlp(hidden_sizes=[self.hidden_dim],
                       input_size=self.hidden_dim * 2,
                       output_size=self.output_size)
    # Fixed prior distributions for each latent component.
    if self.r_cat_dim > 0:
        self.z_cat_prior_dist = torch.distributions.Categorical(
            ptu.ones(self.r_cat_dim) / self.r_cat_dim)
    if self.r_dir_dim > 0:
        if self.vrnn_constraint == 'logitnormal':
            self.z_dir_prior_dist = torch.distributions.Normal(
                ptu.zeros(self.r_dir_dim),
                ptu.ones(self.r_dir_dim) * np.sqrt(self.r_var))
        elif self.vrnn_constraint == 'dirichlet':
            self.z_dir_prior_dist = torch.distributions.Dirichlet(
                ptu.ones(self.r_dir_dim) * self.r_alpha)
    if self.r_cont_dim > 0:
        self.z_cont_prior_dist = torch.distributions.Normal(
            ptu.zeros(self.r_cont_dim), ptu.ones(self.r_cont_dim))
def __init__(self, latent_dim, nets, **kwargs):
    """Meta-RL agent wrapping encoder, policy, and value networks.

    :param latent_dim: dimensionality of the task embedding z.
    :param nets: tuple unpacked as
        (task_enc, cnn_enc, policy, qf1, qf2, vf).
    :param kwargs: configuration flags; required keys: 'recurrent',
        'reparameterize', 'use_information_bottleneck',
        'soft_target_tau', 'reward_scale', 'sparse_rewards',
        'obs_emb_dim'.
    """
    super().__init__()
    self.latent_dim = latent_dim
    self.task_enc, self.cnn_enc, self.policy, self.qf1, self.qf2, self.vf = nets
    # Target value network starts as a copy and is soft-updated with tau.
    self.target_vf = self.vf.copy()
    self.recurrent = kwargs['recurrent']
    self.reparam = kwargs['reparameterize']
    self.use_ib = kwargs['use_information_bottleneck']
    self.tau = kwargs['soft_target_tau']
    self.reward_scale = kwargs['reward_scale']
    self.sparse_rewards = kwargs['sparse_rewards']
    self.det_z = False
    self.obs_emb_dim = kwargs['obs_emb_dim']
    self.q1_buff = []
    self.n_updates = 0
    # initialize task embedding to zero
    # (task, latent dim)
    self.register_buffer('z', torch.zeros(1, latent_dim))
    # for incremental update, must keep track of number of datapoints accumulated
    self.register_buffer('num_z', torch.zeros(1))
    # initialize posterior to the prior
    if self.use_ib:
        self.z_dists = [torch.distributions.Normal(ptu.zeros(self.latent_dim),
                                                   ptu.ones(self.latent_dim))]
def decode(self, latents):
    """Decode latents into image means, fixed unit variances, and mask logits.

    The decoder receives a broadcast tensor of ones alongside the
    latents; the first three output channels are the image mean and the
    fourth holds per-pixel mask logits.
    """
    n_outer, n_inner = latents.shape[0], latents.shape[1]
    ones_grid = ptu.ones(
        (n_outer, n_inner, self.decoder_imsize, self.decoder_imsize))
    raw = self.decoder(latents, ones_grid)
    image_mean = raw[:, :3]
    mask_logits = raw[:, 3]
    # Variance is fixed at 1 everywhere.
    return image_mean, torch.ones_like(image_mean), mask_logits
def forward(self, obs, context=None, cal_rew=True):
    ''' given context, get statistics under the current policy of a set of observations '''
    # obs is (t, b, feat) — tasks x batch x features.
    t, b, _ = obs.size()
    in_ = obs
    policy_outputs = self.policy(in_, reparameterize=True, return_log_prob=True)
    rew = None
    #in_=in_.view(t * b, -1)
    if cal_rew:
        # Encode the context sequence into per-step posterior stats.
        encoder_output_next = self.context_encoder.forward_seq(context)
        z_mean_next = encoder_output_next[:, :, :self.latent_dim]
        z_var_next = F.softplus(encoder_output_next[:, :, self.latent_dim:])
        # Prior N(0, 1) prepended so position i compares the belief
        # before step i against the belief after step i.
        var = ptu.ones(context.shape[0], 1, self.latent_dim)
        mean = ptu.zeros(context.shape[0], 1, self.latent_dim)
        z_mean = torch.cat([mean, z_mean_next], dim=1)[:, :-1, :]
        z_var = torch.cat([var, z_var_next], dim=1)[:, :-1, :]
        z_mean, z_var, z_mean_next, z_var_next = (
            z_mean.contiguous(), z_var.contiguous(),
            z_mean_next.contiguous(), z_var_next.contiguous())
        # Flatten (t, b, d) -> (t*b, d) for the KL computation.
        z_mean, z_var, z_mean_next, z_var_next = (
            z_mean.view(t * b, -1), z_var.view(t * b, -1),
            z_mean_next.view(t * b, -1), z_var_next.view(t * b, -1))
        # VIME-style intrinsic reward: information gain of each transition.
        rew = self.compute_kl_div_vime(z_mean, z_var, z_mean_next, z_var_next)
        # Detach so the reward signal does not backprop into the encoder.
        rew = rew.detach()
    return policy_outputs, rew  #, z_mean,z_var,z_mean_next,z_var_next
def compute_kl_div(self):
    """Return the sum of KL( q(z|c) || r(z) ) over all task posteriors,
    where the prior r(z) is a unit Gaussian."""
    unit_prior = torch.distributions.Normal(ptu.zeros(self.latent_dim),
                                            ptu.ones(self.latent_dim))
    divergences = []
    for task_mu, task_var in zip(torch.unbind(self.z_means),
                                 torch.unbind(self.z_vars)):
        # z_vars stores variances; Normal expects a standard deviation.
        posterior = torch.distributions.Normal(task_mu, torch.sqrt(task_var))
        divergences.append(
            torch.distributions.kl.kl_divergence(posterior, unit_prior))
    return torch.sum(torch.stack(divergences))
def rsample(self):
    """Reparameterized sample from the mixture: draw Gaussian noise,
    scale/shift it by the component statistics, then select one
    component via the categorical sample (matmul against a one-hot-like
    index tensor) and drop the trailing singleton dimension."""
    noise = MultivariateDiagonalNormal(
        ptu.zeros(self.normal_means.size()),
        ptu.ones(self.normal_stds.size()),
    ).sample()
    z = self.normal_means + self.normal_stds * noise
    z.requires_grad_()
    selector = self.categorical.sample()[:, :, None]
    selected = torch.matmul(z, selector)
    return torch.squeeze(selected, 2)
def compute_density(self, data):
    """Estimate p(data) under the VAE by Monte-Carlo marginalization.

    Each datapoint is replicated ``n_average`` times; latents are drawn
    according to ``self.mode`` ('biased' from the encoder, 'prior', or
    'importance_sampling' with encoder proposals reweighted toward the
    prior), and the per-sample decoder likelihoods are averaged.
    Returns a numpy array of length ``len(data)``.
    """
    orig_data_length = len(data)
    # Replicate the batch n_average times so all MC samples run in one pass.
    data = np.vstack([data for _ in range(self.n_average)])
    data = ptu.from_numpy(data)
    if self.mode == 'biased':
        # Sample latents from the encoder posterior; weights are uniform
        # (hence "biased" — no correction toward the prior).
        latents, means, log_vars, stds = (
            self.encoder.get_encoding_and_suff_stats(data))
        importance_weights = ptu.ones(data.shape[0])
    elif self.mode == 'prior':
        latents = ptu.randn(len(data), self.z_dim)
        importance_weights = ptu.ones(data.shape[0])
    elif self.mode == 'importance_sampling':
        latents, means, log_vars, stds = (
            self.encoder.get_encoding_and_suff_stats(data))
        # w = p(z) / q(z|x), computed in log space for stability.
        prior = Normal(ptu.zeros(1), ptu.ones(1))
        prior_log_prob = prior.log_prob(latents).sum(dim=1)
        encoder_distrib = Normal(means, stds)
        encoder_log_prob = encoder_distrib.log_prob(latents).sum(dim=1)
        importance_weights = (prior_log_prob - encoder_log_prob).exp()
    else:
        raise NotImplementedError()
    unweighted_data_log_prob = self.compute_log_prob(
        data, self.decoder, latents).squeeze(1)
    unweighted_data_prob = unweighted_data_log_prob.exp()
    unnormalized_data_prob = unweighted_data_prob * importance_weights
    """
    Average over `n_average`
    """
    # Split the replicated axis back out: chunk i holds MC draw i for
    # every original datapoint.
    dp_split = torch.split(unnormalized_data_prob, orig_data_length, dim=0)
    # pre_avg.shape = ORIG_LEN x N_AVERAGE
    dp_stacked = torch.stack(dp_split, dim=1)
    # final.shape = ORIG_LEN
    unnormalized_dp = torch.sum(dp_stacked, dim=1, keepdim=False)
    """
    Compute the importance weight denomintors. This requires summing
    across the `n_average` dimension.
    """
    iw_split = torch.split(importance_weights, orig_data_length, dim=0)
    iw_stacked = torch.stack(iw_split, dim=1)
    iw_denominators = iw_stacked.sum(dim=1, keepdim=False)
    # Self-normalized importance-sampling estimate.
    final = unnormalized_dp / iw_denominators
    return ptu.get_numpy(final)
def rsample(self):
    """Reparameterized sample from the mixture: scale/shift unit
    Gaussian noise by the component statistics, then pick the component
    indexed by the categorical sample via ``torch.gather``."""
    eps = MultivariateDiagonalNormal(
        ptu.zeros(self.normal_means.size()),
        ptu.ones(self.normal_stds.size()),
    ).sample()
    z = self.normal_means + self.normal_stds * eps
    z.requires_grad_()
    component_idx = self.categorical.sample()[:, :, None]
    gathered = torch.gather(z, dim=2, index=component_idx)
    # Drop the singleton component dimension.
    return gathered[:, :, 0]
def clear_z(self, num_tasks=1):
    """Reset the task belief to the prior for ``num_tasks`` tasks.

    With the information bottleneck, rebuild one unit-Gaussian
    posterior per task and draw a reparameterized sample from each;
    otherwise zero out the point-estimate embedding.
    """
    if self.use_ib:
        priors = []
        samples = []
        for _ in range(num_tasks):
            dist = torch.distributions.Normal(ptu.zeros(self.latent_dim),
                                              ptu.ones(self.latent_dim))
            priors.append(dist)
            samples.append(dist.rsample())
        self.z_dists = priors
        self.z = torch.stack(samples)
    else:
        self.z = self.z.new_full((num_tasks, self.latent_dim), 0)
    # clear hidden state in recurrent case
    self.task_enc.reset(num_tasks)
def compute_kl_div(self):
    """Sum KL(q(z) || N(0, I)) over the per-task posteriors stored in
    ``self.z_dists``."""
    unit_prior = torch.distributions.Normal(ptu.zeros(self.latent_dim),
                                            ptu.ones(self.latent_dim))
    per_task_kl = [
        torch.distributions.kl.kl_divergence(dist, unit_prior)
        for dist in self.z_dists
    ]
    return torch.stack(per_task_kl).sum()
def rsample(self, return_pretanh_value=False):
    """Reparameterized sample of tanh(Normal).

    :param return_pretanh_value: if True, also return the pre-tanh
        Gaussian sample (useful for exact log-prob computation).
    :return: ``tanh(z)``, or ``(tanh(z), z)`` when
        ``return_pretanh_value`` is True.
    """
    # torch.autograd.Variable has been a deprecated no-op wrapper since
    # PyTorch 0.4 — the sampled noise tensor is used directly.
    z = (self.normal_mean + self.normal_std *
         Normal(ptu.zeros(self.normal_mean.size()),
                ptu.ones(self.normal_std.size())).sample())
    if return_pretanh_value:
        return torch.tanh(z), z
    else:
        return torch.tanh(z)
def clear_z(self, num_tasks=1):
    """Reset q(z|c) to the unit-Gaussian prior for ``num_tasks`` tasks:
    zero means and unit variances."""
    self.z_means = ptu.zeros(num_tasks, self.latent_dim)
    self.z_vars = ptu.ones(num_tasks, self.latent_dim)
def rsample_with_pretanh(self):
    """Reparameterized sample; returns both the squashed action
    ``tanh(z)`` and the pre-tanh Gaussian value ``z``."""
    noise_dist = MultivariateDiagonalNormal(
        ptu.zeros(self.normal_mean.size()),
        ptu.ones(self.normal_std.size()),
    )
    pre_tanh = self.normal_mean + self.normal_std * noise_dist.sample()
    return torch.tanh(pre_tanh), pre_tanh
def compute_log_p_log_q_log_d(model, batch, decoder_distribution='bernoulli',
                              num_latents_to_sample=1,
                              sampling_method='importance_sampling'):
    """Compute per-sample log p(z), log q(z|x), and log p(x|z) for a
    (conditional) VAE.

    :param model: VAE exposing encode/rsample/decode and latent_sizes.
    :param batch: dict with numpy images under "x_0" (conditioning) and
        "x_t" (target).
    :param decoder_distribution: 'bernoulli' or
        'gaussian_identity_variance'.
    :param num_latents_to_sample: number of MC latent draws per input.
    :param sampling_method: 'importance_sampling', 'biased_sampling',
        or 'true_prior_sampling'.
    :return: (log_p, log_q, log_d), each of shape
        (batch_size, num_latents_to_sample).
    """
    x_0 = ptu.from_numpy(batch["x_0"])
    data = batch["x_t"]
    imgs = ptu.from_numpy(data)
    latent_distribution_params = model.encode(imgs, x_0)
    r1 = model.latent_sizes[0]
    batch_size = data.shape[0]
    log_p, log_q, log_d = ptu.zeros(
        (batch_size, num_latents_to_sample)), ptu.zeros(
        (batch_size, num_latents_to_sample)), ptu.zeros(
        (batch_size, num_latents_to_sample))
    true_prior = Normal(ptu.zeros((batch_size, r1)),
                        ptu.ones((batch_size, r1)))
    mus, logvars = latent_distribution_params[:2]
    for i in range(num_latents_to_sample):
        # NOTE(review): 'importance_sampling' and 'biased_sampling' draw
        # identically here — presumably the caller applies the
        # importance weights from the returned log terms; confirm.
        if sampling_method == 'importance_sampling':
            latents = model.rsample(latent_distribution_params[:2])
        elif sampling_method == 'biased_sampling':
            latents = model.rsample(latent_distribution_params[:2])
        elif sampling_method == 'true_prior_sampling':
            latents = true_prior.rsample()
        else:
            raise EnvironmentError('Invalid Sampling Method Provided')
        stds = logvars.exp().pow(.5)
        vae_dist = Normal(mus, stds)
        log_p_z = true_prior.log_prob(latents).sum(dim=1)
        log_q_z_given_x = vae_dist.log_prob(latents).sum(dim=1)
        if len(latent_distribution_params) == 3:
            # add conditioning for CVAEs
            latents = torch.cat((latents, latent_distribution_params[2]),
                                dim=1)
        if decoder_distribution == 'bernoulli':
            decoded = model.decode(latents)[0]
            # Bernoulli log-likelihood with epsilon for numerical safety.
            log_d_x_given_z = torch.log(imgs * decoded + (1 - imgs) *
                                        (1 - decoded) + 1e-8).sum(dim=1)
        elif decoder_distribution == 'gaussian_identity_variance':
            _, obs_distribution_params = model.decode(latents)
            dec_mu, dec_logvar = obs_distribution_params
            dec_var = dec_logvar.exp()
            decoder_dist = Normal(dec_mu, dec_var.pow(.5))
            log_d_x_given_z = decoder_dist.log_prob(imgs).sum(dim=1)
        else:
            raise EnvironmentError('Invalid Decoder Distribution Provided')
        log_p[:, i] = log_p_z
        log_q[:, i] = log_q_z_given_x
        log_d[:, i] = log_d_x_given_z
    return log_p, log_q, log_d
def compute_log_p_log_q_log_d(
    model,
    data,
    decoder_distribution="bernoulli",
    num_latents_to_sample=1,
    sampling_method="importance_sampling",
):
    """Compute per-sample log p(z), log q(z|x), and log p(x|z) for a VAE.

    :param model: VAE exposing encode/rsample/decode and
        representation_size.
    :param data: float64 numpy array of normalized images.
    :param decoder_distribution: 'bernoulli' or
        'gaussian_identity_variance'.
    :param num_latents_to_sample: number of MC latent draws per input.
    :param sampling_method: 'importance_sampling', 'biased_sampling',
        or 'true_prior_sampling'.
    :return: (log_p, log_q, log_d), each of shape
        (batch_size, num_latents_to_sample).
    """
    assert data.dtype == np.float64, "images should be normalized"
    imgs = ptu.from_numpy(data)
    latent_distribution_params = model.encode(imgs)
    batch_size = data.shape[0]
    representation_size = model.representation_size
    log_p, log_q, log_d = (
        ptu.zeros((batch_size, num_latents_to_sample)),
        ptu.zeros((batch_size, num_latents_to_sample)),
        ptu.zeros((batch_size, num_latents_to_sample)),
    )
    true_prior = Normal(
        ptu.zeros((batch_size, representation_size)),
        ptu.ones((batch_size, representation_size)),
    )
    mus, logvars = latent_distribution_params
    for i in range(num_latents_to_sample):
        # NOTE(review): 'importance_sampling' and 'biased_sampling' draw
        # identically here — presumably the caller weights the samples
        # using the returned log terms; confirm.
        if sampling_method == "importance_sampling":
            latents = model.rsample(latent_distribution_params)
        elif sampling_method == "biased_sampling":
            latents = model.rsample(latent_distribution_params)
        elif sampling_method == "true_prior_sampling":
            latents = true_prior.rsample()
        else:
            raise EnvironmentError("Invalid Sampling Method Provided")
        stds = logvars.exp().pow(0.5)
        vae_dist = Normal(mus, stds)
        log_p_z = true_prior.log_prob(latents).sum(dim=1)
        log_q_z_given_x = vae_dist.log_prob(latents).sum(dim=1)
        if decoder_distribution == "bernoulli":
            decoded = model.decode(latents)[0]
            # Bernoulli log-likelihood with epsilon for numerical safety.
            log_d_x_given_z = torch.log(
                imgs * decoded + (1 - imgs) * (1 - decoded) + 1e-8
            ).sum(dim=1)
        elif decoder_distribution == "gaussian_identity_variance":
            _, obs_distribution_params = model.decode(latents)
            dec_mu, dec_logvar = obs_distribution_params
            dec_var = dec_logvar.exp()
            decoder_dist = Normal(dec_mu, dec_var.pow(0.5))
            log_d_x_given_z = decoder_dist.log_prob(imgs).sum(dim=1)
        else:
            raise EnvironmentError("Invalid Decoder Distribution Provided")
        log_p[:, i] = log_p_z
        log_q[:, i] = log_q_z_given_x
        log_d[:, i] = log_d_x_given_z
    return log_p, log_q, log_d
def forward(self, obs, context):
    ''' given context, get statistics under the current policy of a set of observations '''
    # obs is (t, b, feat) — tasks x batch x features.
    t, b, _ = obs.size()
    encoder_output_next = self.context_encoder.forward_seq(context)
    z_mean_next = encoder_output_next[:, :, :self.latent_dim]
    z_var_next = F.softplus(encoder_output_next[:, :, self.latent_dim:])
    # Initial belief prepended so position i holds the belief BEFORE
    # step i while *_next holds the belief AFTER step i.
    var = ptu.ones(context.shape[0], 1, self.latent_dim)
    # NOTE(review): the initial mean is ones here (not zeros as in the
    # cal_rew variant elsewhere in this file) — confirm this is intended
    # and not a typo for ptu.zeros.
    mean = ptu.ones(context.shape[0], 1, self.latent_dim)
    z_mean = torch.cat([mean, z_mean_next], dim=1)[:, :-1, :]
    z_var = torch.cat([var, z_var_next], dim=1)[:, :-1, :]
    # Policy conditions on the detached belief statistics.
    in_ = torch.cat([obs, z_mean.detach(), z_var.detach()], dim=2)
    in_ = in_.view(t * b, -1)
    policy_outputs = self.policy(in_, reparameterize=True,
                                 return_log_prob=True)
    # Intrinsic reward: mean log-variance reduction from one belief
    # update (larger when the update shrinks uncertainty).
    rew = torch.mean(torch.log(z_var), dim=2) - torch.mean(
        torch.log(z_var_next), dim=2)
    rew = rew.detach()
    return policy_outputs, z_mean, z_var, z_mean_next, z_var_next, rew
def rsample(self, return_pretanh_value=False):
    """Reparameterized sample of tanh(Normal).

    Draws unit-Gaussian noise, scales/shifts it by the stored mean and
    std, and squashes through tanh. With ``return_pretanh_value`` the
    raw Gaussian value is returned as well.
    """
    eps = Normal(ptu.zeros(self.normal_mean.size()),
                 ptu.ones(self.normal_std.size())).sample()
    pre_tanh = self.normal_mean + self.normal_std * eps
    pre_tanh.requires_grad_()
    squashed = torch.tanh(pre_tanh)
    if return_pretanh_value:
        return squashed, pre_tanh
    return squashed
def clear_sequence_z(self, num_tasks=1, batch_size=1, traj_batch_size=1):
    """Reset the recurrent (per-step) latent belief to its priors,
    resample, and reset the recurrent encoder's hidden state.

    :param num_tasks: number of tasks.
    :param batch_size: per-task batch size for the sequence latents.
    :param traj_batch_size: trajectories per task for the encoder reset.
    """
    assert self.recurrent_context_encoder != None
    if self.r_cat_dim > 0:
        # Uniform categorical prior for every (task, batch, head) row.
        self.seq_z_cat = ptu.ones(num_tasks * batch_size * self.r_n_cat,
                                  self.r_cat_dim) / self.r_cat_dim
        self.seq_z_next_cat = None
    if self.r_cont_dim > 0:
        # Unit Gaussian prior for the continuous part.
        self.seq_z_cont_mean = ptu.zeros(num_tasks * batch_size,
                                         self.r_cont_dim)
        self.seq_z_cont_var = ptu.ones(num_tasks * batch_size,
                                       self.r_cont_dim)
        self.seq_z_next_cont_mean = None
        self.seq_z_next_cont_var = None
    if self.r_dir_dim > 0:
        if self.r_constraint == 'logitnormal':
            # Logit-normal prior: zero mean, variance r_var.
            self.seq_z_dir_mean = ptu.zeros(
                num_tasks * batch_size * self.r_n_dir, self.r_dir_dim)
            self.seq_z_dir_var = ptu.ones(
                num_tasks * batch_size * self.r_n_dir,
                self.r_dir_dim) * self.r_var
            self.seq_z_next_dir_mean = None
            self.seq_z_next_dir_var = None
        elif self.r_constraint == 'dirichlet':
            # Symmetric Dirichlet prior with concentration r_alpha.
            self.seq_z_dir = ptu.ones(
                num_tasks * batch_size * self.r_n_dir,
                self.r_dir_dim) * self.r_alpha
            self.seq_z_next_dir = None
    # Draw fresh sequence latents from the freshly reset priors.
    self.sample_sequence_z()
    self.recurrent_context_encoder.reset(num_tasks * traj_batch_size)
def clear_z(self, num_tasks=1, batch_size=1, traj_batch_size=1):
    '''
    reset q(z|c) to the prior
    sample a new z from the prior
    '''
    if self.glob:
        # --- global latent: reset each component to its prior ---
        if self.g_cat_dim > 0:
            # Uniform categorical prior.
            self.z_means = ptu.ones(num_tasks * self.g_n_cat,
                                    self.g_cat_dim) / self.g_cat_dim
        if self.g_cont_dim > 0:
            # Unit Gaussian prior.
            self.z_c_means = ptu.zeros(num_tasks, self.g_cont_dim)
            self.z_c_vars = ptu.ones(num_tasks, self.g_cont_dim)
        if self.g_dir_dim > 0:
            if self.g_constraint == 'logitnormal':
                self.z_d_means = ptu.zeros(num_tasks * self.g_n_dir,
                                           self.g_dir_dim)
                self.z_d_vars = ptu.ones(num_tasks * self.g_n_dir,
                                         self.g_dir_dim) * self.g_var
            else:
                # Dirichlet prior with symmetric concentration g_alpha.
                self.z_d_means = ptu.ones(num_tasks * self.g_n_dir,
                                          self.g_dir_dim) * self.g_alpha
        self.sample_z()
    if self.recurrent:
        # --- per-step (sequence) latent: reset each component ---
        if self.r_cat_dim > 0:
            self.seq_z_cat = ptu.ones(
                num_tasks * batch_size * self.r_n_cat,
                self.r_cat_dim) / self.r_cat_dim
            self.seq_z_next_cat = None
        if self.r_cont_dim > 0:
            self.seq_z_cont_mean = ptu.zeros(num_tasks * batch_size,
                                             self.r_cont_dim)
            self.seq_z_cont_var = ptu.ones(num_tasks * batch_size,
                                           self.r_cont_dim)
            self.seq_z_next_cont_mean = None
            self.seq_z_next_cont_var = None
        if self.r_dir_dim > 0:
            if self.r_constraint == 'logitnormal':
                self.seq_z_dir_mean = ptu.zeros(
                    num_tasks * batch_size * self.r_n_dir, self.r_dir_dim)
                self.seq_z_dir_var = ptu.ones(
                    num_tasks * batch_size * self.r_n_dir,
                    self.r_dir_dim) * self.r_var
                self.seq_z_next_dir_mean = None
                self.seq_z_next_dir_var = None
            elif self.r_constraint == 'dirichlet':
                self.seq_z_dir = ptu.ones(
                    num_tasks * batch_size * self.r_n_dir,
                    self.r_dir_dim) * self.r_alpha
                self.seq_z_next_dir = None
        self.sample_sequence_z()
    # reset the context collected so far
    self.context = None
    # reset any hidden state in the encoder network (relevant for RNN)
    if self.global_context_encoder != None:
        self.global_context_encoder.reset(num_tasks)
    if self.recurrent_context_encoder != None:
        self.recurrent_context_encoder.reset(num_tasks * traj_batch_size)
def forward(self, *inputs):
    """Run the base network, then sample a tanh-squashed Gaussian action.

    :return: (action, pre_tanh_z, logp, mean, logstd).
        ``logp`` is the per-dimension log-density of the *noise* sample
        under the unit normal (not summed, and not corrected for the
        affine transform or the tanh squash) — see notes below.
    """
    x = super().forward(*inputs)
    mean = self.mean_layer(x)
    logstd = self.logstd_layer(x)
    # Clamp for numerical stability before exponentiating.
    logstd = torch.clamp(logstd, LOGMIN, LOGMAX)
    std = torch.exp(logstd)
    unit_normal = Normal(ptu.zeros(mean.size()), ptu.ones(std.size()))
    eps = unit_normal.sample()
    # NOTE(review): .cpu() forces the sample onto the CPU regardless of
    # where the network lives — confirm this is intended (it will break
    # or slow down GPU runs).
    pre_tanh_z = mean.cpu() + std.cpu() * eps
    action = torch.tanh(pre_tanh_z)
    # NOTE(review): this is log N(eps; 0, 1), i.e. the noise density,
    # not log pi(action); the sum over dimensions is commented out.
    logp = unit_normal.log_prob(eps)
    # logp = logp.sum(dim=1, keepdim=True)
    # logsum = exp mult
    return action, pre_tanh_z, logp, mean, logstd
def __init__(self,
             latent_dim,
             context_encoder,
             policy,
             reward_predictor,
             use_next_obs_in_context=False,
             _debug_ignore_context=False,
             _debug_do_not_sqrt=False,
             _debug_use_ground_truth_context=False):
    """Context-conditioned agent with an auxiliary reward predictor.

    :param latent_dim: dimensionality of the latent context z.
    :param context_encoder: network producing posterior stats over z.
    :param policy: stochastic policy conditioned on (obs, z).
    :param reward_predictor: auxiliary reward model.
    :param use_next_obs_in_context: include next observations in the
        context tuples.
    :param _debug_*: debugging switches (ignore context, skip the sqrt
        on variances, or feed ground-truth context).
    """
    super().__init__()
    self.latent_dim = latent_dim
    self.context_encoder = context_encoder
    self.policy = policy
    self.reward_predictor = reward_predictor
    self.deterministic_policy = MakeDeterministic(self.policy)
    self._debug_ignore_context = _debug_ignore_context
    self._debug_use_ground_truth_context = _debug_use_ground_truth_context
    # self.recurrent = kwargs['recurrent']
    # self.use_ib = kwargs['use_information_bottleneck']
    # self.sparse_rewards = kwargs['sparse_rewards']
    self.use_next_obs_in_context = use_next_obs_in_context
    # initialize buffers for z dist and z
    # use buffers so latent context can be saved along with model weights
    self.register_buffer('z', torch.zeros(1, latent_dim))
    self.register_buffer('z_means', torch.zeros(1, latent_dim))
    self.register_buffer('z_vars', torch.zeros(1, latent_dim))
    # Immediately overwritten with None: the buffers above exist only to
    # reserve the names in the state dict.
    self.z_means = None
    self.z_vars = None
    self.context = None
    self.z = None
    # rp = reward predictor
    # TODO: add back in reward predictor code
    self.z_means_rp = None
    self.z_vars_rp = None
    self.z_rp = None
    # Reward predictor currently shares the same context encoder.
    self.context_encoder_rp = context_encoder
    self._use_context_encoder_snapshot_for_reward_pred = False
    self.latent_prior = torch.distributions.Normal(
        ptu.zeros(self.latent_dim), ptu.ones(self.latent_dim))
    self._debug_do_not_sqrt = _debug_do_not_sqrt
def clear_z(self, num_tasks=1):
    """Reset q(z|c) to the prior and draw a fresh z.

    Without the information bottleneck the variance is set to zero
    instead of one.
    """
    self.z_means = ptu.zeros(num_tasks, self.latent_dim)
    if self.use_ib:
        self.z_vars = ptu.ones(num_tasks, self.latent_dim)
    else:
        self.z_vars = ptu.zeros(num_tasks, self.latent_dim)
    # sample a new z from the (reset) prior
    self.sample_z()
    # drop any context collected so far
    self.context = None
def compute_loss(self, batch, epoch=-1, test=False):
    """Compute conditional-GAN discriminator and generator losses.

    :param batch: dict with image tensors under 'x_t' (targets) and
        'env' (conditioning).
    :param epoch: current epoch, used for noise annealing and logging.
    :param test: selects the "test/" vs "train/" statistics prefix.
    :return: (errD, errG) — discriminator and generator losses.
    """
    prefix = "test/" if test else "train/"
    real_data = batch['x_t'].reshape(-1, self.input_channels,
                                     self.imsize, self.imsize)
    cond = batch['env'].reshape(-1, self.input_channels, self.imsize,
                                self.imsize)
    batch_size = real_data.size(0)
    real_label = ptu.ones(batch_size)
    fake_label = ptu.zeros(batch_size)
    # Encode real images (conditioned) and draw a fake latent.
    real_latent, _, _, _ = self.model.netE(real_data, cond)
    fake_latent = self.fixed_noise(batch_size, real_latent)
    fake_data = self.model.netG(fake_latent)
    real_noise = 0
    fake_noise = 0
    cond_noise = 0
    if self.discriminator_noise:
        # Instance noise, annealed over training, to stabilize D.
        real_noise = self.noise(real_data.size(), self.num_epochs, epoch)
        fake_noise = self.noise(real_data.size(), self.num_epochs, epoch)
        cond_noise = self.noise(real_data.size(), self.num_epochs, epoch)
    real_pred, _ = self.model.netD(real_data + real_noise,
                                   cond + cond_noise, real_latent)
    fake_pred, _ = self.model.netD(fake_data + fake_noise,
                                   cond + cond_noise, fake_latent)
    # Standard GAN losses; G additionally pushes real_pred toward fake.
    errD = self.criterion(real_pred, real_label) + self.criterion(
        fake_pred, fake_label)
    errG = self.criterion(fake_pred, real_label) + self.criterion(
        real_pred, fake_label)
    # Reconstruction quality is tracked for logging only.
    recon = self.model.netG(real_latent)
    recon_error = F.mse_loss(recon, real_data)
    self.eval_statistics['epoch'] = epoch
    self.eval_statistics[prefix + "errD"].append(errD.item())
    self.eval_statistics[prefix + "errG"].append(errG.item())
    self.eval_statistics[prefix + "Recon Error"].append(recon_error.item())
    self.eval_data[prefix + "last_batch"] = (batch,
                                             recon.reshape(batch_size, -1))
    return errD, errG
def clear_z(self, num_tasks=1):
    """Reset q(z|c) to the prior and clear collected context.

    NOTE(review): unlike sibling implementations in this file, no fresh
    z is drawn here even though the original comment said "sample a new
    z from the prior" — confirm whether a self.sample_z() call was
    intended.
    """
    self.z_means = ptu.zeros(num_tasks, self.latent_dim)
    if self.use_ib:
        self.z_vars = ptu.ones(num_tasks, self.latent_dim)
    else:
        self.z_vars = ptu.zeros(num_tasks, self.latent_dim)
    # reset the context collected so far
    self.context = None
    # reset any hidden state in the encoder network (relevant for RNN)
    self.context_encoder.reset(num_tasks)
def compute_kl_div(self):
    """Compute KL( q(z|c) || r(z) ) summed over tasks, with r(z) = N(0, I)
    and per-task posteriors N(z_mean, sqrt(z_var))."""
    prior = torch.distributions.Normal(ptu.zeros(self.latent_dim),
                                       ptu.ones(self.latent_dim))
    kl_terms = []
    for z_mu, z_var in zip(torch.unbind(self.z_means),
                           torch.unbind(self.z_vars)):
        # z_vars holds variances; Normal expects standard deviation.
        posterior = torch.distributions.Normal(z_mu, torch.sqrt(z_var))
        kl_terms.append(
            torch.distributions.kl.kl_divergence(posterior, prior))
    # Total divergence across all tasks and latent dimensions.
    return torch.sum(torch.stack(kl_terms))
def compute_loss(self, batch, epoch=-1, test=False):
    """Compute GAN discriminator and generator losses for one batch.

    :param batch: dict holding images under ``self.key_to_reconstruct``.
    :param epoch: current epoch, used for noise annealing and logging.
    :param test: selects the "test/" vs "train/" statistics prefix.
    :return: (errD, errG) — discriminator and generator losses.
    """
    prefix = "test/" if test else "train/"
    real_data = batch[self.key_to_reconstruct].reshape(
        -1, self.input_channels, self.imsize, self.imsize)
    batch_size = real_data.size(0)
    fake_latent = self.fixed_noise(batch_size)
    # Instance noise, annealed over training, to stabilize D.
    noise1 = self.noise(real_data.size(), self.num_epochs, epoch)
    noise2 = self.noise(real_data.size(), self.num_epochs, epoch)
    real_label = ptu.ones(batch_size)
    fake_label = ptu.zeros(batch_size)
    fake_data = self.model.netG(fake_latent)
    real_latent, _, _, _ = self.model.netE(real_data)
    # Reshape to the (N, C, 1, 1) layout the generator/discriminator expect.
    real_latent = real_latent.view(batch_size, self.representation_size,
                                   1, 1)
    real_pred, _ = self.model.netD(real_data + noise1, real_latent)
    fake_pred, _ = self.model.netD(fake_data + noise2, fake_latent)
    # Standard GAN losses; G additionally pushes real_pred toward fake.
    errD = self.criterion(real_pred, real_label) + self.criterion(
        fake_pred, fake_label)
    errG = self.criterion(fake_pred, real_label) + self.criterion(
        real_pred, fake_label)
    # Reconstruction quality is tracked for logging only.
    recon = self.model.netG(real_latent)
    recon_error = F.mse_loss(recon, real_data)
    self.eval_statistics['epoch'] = epoch
    self.eval_statistics[prefix + "errD"].append(errD.item())
    self.eval_statistics[prefix + "errG"].append(errG.item())
    self.eval_statistics[prefix + "Recon Error"].append(recon_error.item())
    self.eval_data[prefix + "last_batch"] = (real_data.reshape(
        batch_size, -1), recon.reshape(batch_size, -1))
    return errD, errG
def __init__(
        self,
        representation_size,
        architecture,
        refinement_net,
        physics_net=None,
        decoder_class=DCNN,
        decoder_output_activation=identity,
        decoder_distribution='bernoulli',
        K=3,
        T=5,
        input_channels=1,
        imsize=48,
        init_w=1e-3,
        min_variance=1e-3,
        hidden_init=ptu.fanin_init,
        beta=5,
        dynamic=False,
        dataparallel=False,
        sigma=0.1,
):
    """
    :param representation_size:
    :param conv_args:
    must be a dictionary specifying the following:
        kernel_sizes
        n_channels
        strides
    :param conv_kwargs:
    a dictionary specifying the following:
        hidden_sizes
        batch_norm
    :param deconv_args:
    must be a dictionary specifying the following:
        hidden_sizes
        deconv_input_width
        deconv_input_height
        deconv_input_channels
        deconv_output_kernel_size
        deconv_output_strides
        deconv_output_channels
        kernel_sizes
        n_channels
        strides
    :param deconv_kwargs:
        batch_norm
    :param encoder_class:
    :param decoder_class:
    :param decoder_output_activation:
    :param decoder_distribution:
    :param input_channels:
    :param imsize:
    :param init_w:
    :param min_variance:
    :param hidden_init:
    """
    super().__init__(representation_size)
    # Floor on decoder variance (stored in log space); None disables it.
    if min_variance is None:
        self.log_min_variance = None
    else:
        self.log_min_variance = float(np.log(min_variance))
    # K: number of object slots; T: number of refinement iterations.
    self.K = K
    self.T = T
    self.input_channels = input_channels
    self.imsize = imsize
    self.imlength = self.imsize * self.imsize * self.input_channels
    self.refinement_net = refinement_net
    self.beta = beta
    self.dynamic = dynamic
    self.physics_net = physics_net
    self.lstm_size = 256
    deconv_args, deconv_kwargs = architecture['deconv_args'], architecture[
        'deconv_kwargs']
    self.decoder_imsize = deconv_args['input_width']
    # NOTE(review): decoder_output_activation is accepted but not used
    # here — confirm whether it should be forwarded to decoder_class.
    self.decoder = decoder_class(**deconv_args,
                                 output_size=self.imlength,
                                 init_w=init_w,
                                 hidden_init=hidden_init,
                                 hidden_activation=nn.ELU(),
                                 **deconv_kwargs)
    # Action pathway: embed the 13-d action, then fuse with each lambda.
    self.action_encoder = Mlp((128, ), 128, 13,
                              hidden_activation=nn.ELU())
    self.action_lambda_encoder = Mlp((256, 256), representation_size,
                                     representation_size + 128,
                                     hidden_activation=nn.ELU())
    l_norm_sizes = [7, 1, 1]
    self.layer_norms = nn.ModuleList(
        [LayerNorm2D(l) for l in l_norm_sizes])
    if dataparallel:
        self.decoder = nn.DataParallel(self.decoder)
        #self.physics_net = nn.DataParallel(self.physics_net)
        self.refinement_net = nn.DataParallel(self.refinement_net)
    self.epoch = 0
    self.decoder_distribution = decoder_distribution
    self.apply(ptu.init_weights)
    # Learnable initial posterior parameters (mean 0, std-like 0.6).
    self.lambdas = nn.ParameterList([
        Parameter(ptu.zeros((1, self.representation_size))),
        Parameter(ptu.ones((1, self.representation_size)) * 0.6)
    ])  #+ torch.exp(ptu.ones((1, self.representation_size)))))]
    self.sigma = from_numpy(np.array([sigma]))
def __init__(self,
             global_context_encoder,
             recurrent_context_encoder,
             global_latent,
             vrnn_latent,
             policy,
             temperature,
             unitkl,
             alpha,
             g_constraint,
             r_constraint,
             var,
             r_alpha,
             r_var,
             rnn,
             temp_res,
             rnn_sample,
             weighted_sample,
             **kwargs):
    """Agent with a global task latent and an optional recurrent
    (per-step) latent, each split into continuous / categorical /
    Dirichlet components by ``read_dim``.

    :param global_latent / vrnn_latent: latent layout specs.
    :param g_constraint / r_constraint: 'logitnormal' or 'dirichlet'
        prior family for the global / recurrent Dirichlet-like part.
    :param kwargs: must contain 'recurrent', 'glob',
        'use_information_bottleneck', 'sparse_rewards', 'use_next_obs'.
    """
    super().__init__()
    self.g_cont_dim, self.g_n_cat, self.g_cat_dim, self.g_n_dir, self.g_dir_dim = read_dim(global_latent)
    if recurrent_context_encoder != None:
        self.r_cont_dim, self.r_n_cat, self.r_cat_dim, self.r_n_dir, self.r_dir_dim = read_dim(vrnn_latent)
    self.global_context_encoder = global_context_encoder
    self.recurrent_context_encoder = recurrent_context_encoder
    self.policy = policy
    self.temperature = temperature
    self.unitkl = unitkl
    self.g_constraint = g_constraint  # global dirichlet type
    self.r_constraint = r_constraint  # local dirichlet type
    self.g_alpha = alpha
    self.g_var = var
    self.r_alpha = r_alpha
    self.r_var = r_var
    self.rnn = rnn
    self.weighted_sample = weighted_sample
    self.temp_res = temp_res
    self.rnn_sample = rnn_sample
    # Counters for global / local / inference passes.
    self.n_global, self.n_local, self.n_infer = 0, 0, 0
    self.recurrent = kwargs['recurrent']
    self.glob = kwargs['glob']
    self.use_ib = kwargs['use_information_bottleneck']
    self.sparse_rewards = kwargs['sparse_rewards']
    self.use_next_obs = kwargs['use_next_obs']
    # initialize buffers for z dist and z
    # use buffers so latent context can be saved along with model weights
    if self.glob:
        self.register_buffer('z', torch.zeros(
            1, self.g_cont_dim + self.g_cat_dim * self.g_n_cat +
            self.g_dir_dim * self.g_n_dir))
        if self.g_cat_dim > 0:
            self.register_buffer('z_means',
                                 torch.zeros(1, self.g_cat_dim))
        if self.g_cont_dim > 0:
            self.register_buffer('z_c_means',
                                 torch.zeros(1, self.g_cont_dim))
            self.register_buffer('z_c_vars',
                                 torch.ones(1, self.g_cont_dim))
        if self.g_dir_dim > 0:
            if self.g_constraint == 'logitnormal':
                self.register_buffer('z_d_means',
                                     torch.zeros(1, self.g_dir_dim))
                self.register_buffer('z_d_vars',
                                     torch.ones(1, self.g_dir_dim))
            elif self.g_constraint == 'dirichlet':
                self.register_buffer('z_d_means',
                                     torch.zeros(1, self.g_dir_dim))
    if self.recurrent:
        self.register_buffer('seq_z', torch.zeros(
            1, self.r_cont_dim + self.r_cat_dim * self.r_n_cat +
            self.r_dir_dim * self.r_n_dir))
        # Empty placeholders so the final concat works even when a
        # component is absent.
        z_cat_prior, z_cont_prior, z_dir_prior = (
            ptu.FloatTensor(), ptu.FloatTensor(), ptu.FloatTensor())
        if self.r_cat_dim > 0:
            self.register_buffer('seq_z_cat',
                                 torch.zeros(1, self.r_cat_dim))
            self.seq_z_next_cat = None
            z_cat_prior = ptu.ones(
                self.r_cat_dim * self.r_n_cat) / self.r_cat_dim
        if self.r_dir_dim > 0:
            if self.r_constraint == 'logitnormal':
                self.register_buffer('seq_z_dir_mean',
                                     torch.zeros(1, self.r_dir_dim))
                self.register_buffer('seq_z_dir_var',
                                     torch.ones(1, self.r_dir_dim))
                self.seq_z_next_dir_mean = None
                self.seq_z_next_dir_var = None
                z_dir_prior_mean = ptu.zeros(self.r_n_dir * self.r_dir_dim)
                z_dir_prior_var = ptu.ones(
                    self.r_n_dir * self.r_dir_dim) * self.r_var
                z_dir_prior = torch.cat([z_dir_prior_mean, z_dir_prior_var])
            elif self.r_constraint == 'dirichlet':
                self.register_buffer('seq_z_dir',
                                     torch.zeros(1, self.r_dir_dim))
                self.seq_z_next_dir = None
                z_dir_prior = ptu.ones(
                    self.r_n_dir * self.r_dir_dim) * self.r_alpha
        if self.r_cont_dim > 0:
            self.register_buffer('seq_z_cont_mean',
                                 torch.zeros(1, self.r_cont_dim))
            # NOTE(review): this variance buffer starts at zeros while
            # sibling code resets it to ones — confirm intended.
            self.register_buffer('seq_z_cont_var',
                                 torch.zeros(1, self.r_cont_dim))
            self.seq_z_next_cont_mean = None
            self.seq_z_next_cont_var = None
            z_cont_prior = torch.cat([ptu.zeros(self.r_cont_dim),
                                      ptu.ones(self.r_cont_dim)])
        # Flattened prior statistics for the whole sequence latent.
        self.seq_z_prior = torch.cat([z_cat_prior, z_cont_prior,
                                      z_dir_prior])
    self.clear_z()
def compute_kl_div(self):
    '''
    compute KL( q(z|c) || r(z) )
    '''
    # Six scalar accumulators: global (cont/disc/dir) and sequence
    # (cont/disc/dir); each defaults to 0 when the component is absent.
    kl_div_cont, kl_div_disc, kl_div_dir = (
        ptu.FloatTensor([0.]).mean(), ptu.FloatTensor([0.]).mean(),
        ptu.FloatTensor([0.]).mean())
    kl_div_seq_cont, kl_div_seq_disc, kl_div_seq_dir = (
        ptu.FloatTensor([0.]).mean(), ptu.FloatTensor([0.]).mean(),
        ptu.FloatTensor([0.]).mean())
    if self.glob:
        if self.g_cat_dim > 0:
            # Categorical KL to the uniform prior:
            # sum q * log(q * K); `eps` guards log(0).
            if self.unitkl:
                kl_div_disc = torch.sum(self.z_means_all * torch.log(
                    (self.z_means_all + eps) * self.g_cat_dim))
            else:
                kl_div_disc = torch.sum(self.z_means * torch.log(
                    (self.z_means + eps) * self.g_cat_dim))
        if self.g_dir_dim > 0:
            if self.g_constraint == 'dirichlet':
                prior = torch.distributions.Dirichlet(
                    ptu.ones(self.g_dir_dim) * self.g_alpha)
                # unitkl selects the per-unit (_all) statistics.
                if self.unitkl:
                    posteriors = torch.distributions.Dirichlet(
                        self.z_d_means_all)
                else:
                    posteriors = torch.distributions.Dirichlet(
                        self.z_d_means)
                kl_div_dir = torch.sum(
                    torch.distributions.kl.kl_divergence(posteriors, prior))
            elif self.g_constraint == 'logitnormal':
                prior = torch.distributions.Normal(
                    ptu.zeros(self.g_dir_dim),
                    ptu.ones(self.g_dir_dim) * np.sqrt(self.g_var))
                if self.unitkl:
                    posteriors = torch.distributions.Normal(
                        self.z_d_means_all, torch.sqrt(self.z_d_vars_all))
                else:
                    posteriors = torch.distributions.Normal(
                        self.z_d_means, torch.sqrt(self.z_d_vars))
                kl_div_dir = torch.sum(
                    torch.distributions.kl.kl_divergence(posteriors, prior))
        if self.g_cont_dim > 0:
            # Closed-form KL( N(mu, var) || N(0, 1) )
            # = 0.5 * (-log var + var + mu^2 - 1), summed.
            if self.unitkl:
                kl_div_cont = torch.sum(0.5 * (
                    -torch.log(self.z_c_vars_all) + self.z_c_vars_all +
                    self.z_c_means_all * self.z_c_means_all - 1))
            else:
                kl_div_cont = torch.sum(0.5 * (
                    -torch.log(self.z_c_vars) + self.z_c_vars +
                    self.z_c_means * self.z_c_means - 1))
    if self.recurrent:
        if self.rnn == 'rnn':
            # Plain-RNN case: KL of both the current and the next-step
            # posteriors against the fixed priors.
            if self.r_cat_dim > 0:
                assert type(self.seq_z_next_cat) != type(None)
                kl_div_seq_disc = torch.sum(
                    self.seq_z_cat * torch.log(
                        (self.seq_z_cat + eps) * self.r_cat_dim)) \
                    + torch.sum(self.seq_z_next_cat * torch.log(
                        (self.seq_z_next_cat + eps) * self.r_cat_dim))
            if self.r_dir_dim > 0:
                if self.r_constraint == 'dirichlet':
                    assert type(self.seq_z_next_dir) != type(None)
                    prior = torch.distributions.Dirichlet(
                        ptu.ones(self.r_dir_dim) * self.r_alpha)
                    posteriors = torch.distributions.Dirichlet(
                        self.seq_z_dir)
                    posteriors_next = torch.distributions.Dirichlet(
                        self.seq_z_next_dir)
                    kl_div_seq_dir = torch.sum(
                        torch.distributions.kl.kl_divergence(
                            posteriors, prior)) \
                        + torch.sum(torch.distributions.kl.kl_divergence(
                            posteriors_next, prior))
                elif self.r_constraint == 'logitnormal':
                    assert type(self.seq_z_next_dir_mean) != type(None)
                    prior = torch.distributions.Normal(
                        ptu.zeros(self.r_dir_dim),
                        ptu.ones(self.r_dir_dim) * np.sqrt(self.r_var))
                    posteriors = torch.distributions.Normal(
                        self.seq_z_dir_mean, torch.sqrt(self.seq_z_dir_var))
                    posteriors_next = torch.distributions.Normal(
                        self.seq_z_next_dir_mean,
                        torch.sqrt(self.seq_z_next_dir_var))
                    kl_div_seq_dir = torch.sum(
                        torch.distributions.kl.kl_divergence(
                            posteriors, prior)) \
                        + torch.sum(torch.distributions.kl.kl_divergence(
                            posteriors_next, prior))
            if self.r_cont_dim > 0:
                # Closed-form Gaussian KL for current and next beliefs.
                kl_div_seq_cont = torch.sum(0.5 * (
                    -torch.log(self.seq_z_cont_var) + self.seq_z_cont_var +
                    self.seq_z_cont_mean * self.seq_z_cont_mean - 1)) \
                    + torch.sum(0.5 * (
                        -torch.log(self.seq_z_next_cont_var) +
                        self.seq_z_next_cont_var +
                        self.seq_z_next_cont_mean *
                        self.seq_z_next_cont_mean - 1))
        elif self.rnn == 'vrnn':
            # VRNN case: the recurrent encoder computes its own KL terms.
            kl_div_seq_disc, kl_div_seq_cont, kl_div_seq_dir = \
                self.recurrent_context_encoder.compute_kl_div()
    return (kl_div_disc, kl_div_cont, kl_div_dir, kl_div_seq_disc,
            kl_div_seq_cont, kl_div_seq_dir)