def __init__(self, latent_dim, nets, **kwargs):
    """Meta-RL agent holding the task encoder, policy, and value networks.

    :param latent_dim: dimensionality of the latent task variable z
    :param nets: iterable of exactly six networks, in this order:
        (task_enc, cnn_enc, policy, qf1, qf2, vf)
    :param kwargs: must contain 'recurrent', 'reparameterize',
        'use_information_bottleneck', 'soft_target_tau', 'reward_scale',
        'sparse_rewards', and 'obs_emb_dim'
    """
    super().__init__()
    self.latent_dim = latent_dim
    # unpack the six networks in the fixed order the caller provides
    self.task_enc, self.cnn_enc, self.policy, self.qf1, self.qf2, self.vf = nets
    # target value network starts as a copy of vf; softly updated with tau
    self.target_vf = self.vf.copy()
    self.recurrent = kwargs['recurrent']
    self.reparam = kwargs['reparameterize']
    self.use_ib = kwargs['use_information_bottleneck']
    self.tau = kwargs['soft_target_tau']
    self.reward_scale = kwargs['reward_scale']
    self.sparse_rewards = kwargs['sparse_rewards']
    # when True, z is treated deterministically (no sampling)
    self.det_z = False
    self.obs_emb_dim = kwargs['obs_emb_dim']
    self.q1_buff = []
    self.n_updates = 0
    # initialize task embedding to zero
    # (task, latent dim)
    self.register_buffer('z', torch.zeros(1, latent_dim))
    # for incremental update, must keep track of number of datapoints accumulated
    self.register_buffer('num_z', torch.zeros(1))
    # initialize posterior to the prior
    if self.use_ib:
        self.z_dists = [torch.distributions.Normal(ptu.zeros(self.latent_dim), ptu.ones(self.latent_dim))]
def dump_mixed_latents(self, epoch):
    """Save an n-by-n image grid mixing the two latent halves.

    Row 0 shows the raw x_t inputs, column 0 the raw env inputs, the
    top-left cell is blank, and interior cell (i, j) decodes position
    latent j combined with object latent i.

    :param epoch: used only to name the output file
    """
    n = 8
    # `reconstructions` is unpacked but unused here
    batch, reconstructions = self.eval_data["test/last_batch"]
    x_t, env = batch["x_t"][:n], batch["env"][:n]
    z_comb = self.model.encode(x_t, env)
    # split the combined latent into its position / object halves
    z_pos = z_comb[:, :self.model.latent_sizes[0]]
    z_obj = z_comb[:, self.model.latent_sizes[0]:]
    grid = []
    for i in range(n):
        for j in range(n):
            if i + j == 0:
                # top-left corner: blank image
                grid.append(
                    ptu.zeros(1, self.input_channels, self.imsize, self.imsize))
            elif i == 0:
                # first row: show the raw x_t inputs
                #grid.append(self.model.decode(torch.cat([z_pos[j], z_obj[i]], dim=1)))
                grid.append(x_t[j].reshape(1, self.input_channels, self.imsize, self.imsize))
            elif j == 0:
                # first column: show the raw env inputs
                #grid.append(self.model.decode(torch.cat([z_pos[j], z_obj[i]], dim=1)))
                grid.append(env[i].reshape(1, self.input_channels, self.imsize, self.imsize))
            else:
                # interior: decode the cross-combination of latents
                z, z_c = z_pos[j].reshape(1, -1), z_obj[i].reshape(1, -1)
                grid.append(self.model.decode(torch.cat([z, z_c], dim=1)))
    samples = torch.cat(grid)
    save_dir = osp.join(self.log_dir, 'mixed_latents_%d.png' % epoch)
    # transpose(2, 3) swaps H and W before saving
    save_image(samples.data.cpu().transpose(2, 3), save_dir, nrow=n)
def forward(self, obs, context=None, cal_rew=True):
    '''
    given context, get statistics under the current policy of a set of
    observations

    :param obs: (t, b, obs_dim) observation batch
    :param context: transition context fed to the recurrent encoder;
        only read when cal_rew is True
    :param cal_rew: if True, also compute an intrinsic reward from the
        KL between successive posterior beliefs (VIME-style)
    :return: (policy_outputs, rew); rew is None when cal_rew is False
    '''
    t, b, _ = obs.size()
    in_ = obs
    policy_outputs = self.policy(in_, reparameterize=True, return_log_prob=True)
    rew = None
    #in_=in_.view(t * b, -1)
    if cal_rew:
        # posterior belief *after* seeing each context step
        encoder_output_next = self.context_encoder.forward_seq(context)
        z_mean_next = encoder_output_next[:, :, :self.latent_dim]
        # softplus keeps the predicted variance positive
        z_var_next = F.softplus(encoder_output_next[:, :, self.latent_dim:])
        # prepend the unit-Gaussian prior and drop the final step to get
        # the belief *before* each context step
        var = ptu.ones(context.shape[0], 1, self.latent_dim)
        mean = ptu.zeros(context.shape[0], 1, self.latent_dim)
        z_mean = torch.cat([mean, z_mean_next], dim=1)[:, :-1, :]
        z_var = torch.cat([var, z_var_next], dim=1)[:, :-1, :]
        z_mean, z_var, z_mean_next, z_var_next = z_mean.contiguous(
        ), z_var.contiguous(), z_mean_next.contiguous(
        ), z_var_next.contiguous()
        # flatten (t, b, d) -> (t * b, d) for the elementwise KL
        # NOTE(review): assumes context.shape[0] == t and context time
        # length == b so t * b matches — confirm against the caller
        z_mean, z_var, z_mean_next, z_var_next = z_mean.view(
            t * b, -1), z_var.view(t * b, -1), z_mean_next.view(
                t * b, -1), z_var_next.view(t * b, -1)
        rew = self.compute_kl_div_vime(z_mean, z_var, z_mean_next, z_var_next)
        # intrinsic reward must not backprop into the encoder
        rew = rew.detach()
    return policy_outputs, rew  #, z_mean,z_var,z_mean_next,z_var_next
def from_vae_latents_to_lstm_latents(self, latents, lstm_hidden=None):
    """Run a batch of VAE latents through the LSTM (as a length-1 sequence)
    and project the output with the fully-connected head.

    :param latents: (batch_size, feature_size) latent batch
    :param lstm_hidden: optional (h0, c0) pair; zeros are used when None
    :return: (batch_size, lstm_output_size) projected LSTM features
    """
    batch_size, _ = latents.shape
    # treat the whole batch as a single-step sequence
    seq_input = latents.view((1, batch_size, -1))
    if lstm_hidden is None:
        state_shape = (self.lstm_num_layers, batch_size, self.lstm_hidden_size)
        lstm_hidden = (ptu.zeros(*state_shape), ptu.zeros(*state_shape))
    # output: [seq_len, batch_size, lstm_hidden_size]
    output, _ = self.lstm(seq_input, lstm_hidden)
    projected = self.lstm_fc(output)
    return projected.view((batch_size, -1))
def compute_kl_div(self):
    """Compute the summed KL( q(z|c) || r(z) ) over all task posteriors,
    where r(z) is the unit-Gaussian prior."""
    unit_prior = torch.distributions.Normal(
        ptu.zeros(self.latent_dim), ptu.ones(self.latent_dim))
    # build a Normal posterior per task from (mean, variance) pairs;
    # std = sqrt(var)
    divergences = [
        torch.distributions.kl.kl_divergence(
            torch.distributions.Normal(mean, var.sqrt()), unit_prior)
        for mean, var in zip(torch.unbind(self.z_means),
                             torch.unbind(self.z_vars))
    ]
    return torch.stack(divergences).sum()
def get_action(
    self,
    observation,
    use_raps_obs=False,
    use_true_actions=True,
    use_obs=True,
):
    """
    Step the world model with the new observation and return the actor's
    (mode) action, optionally perturbed for exploration.

    :param observation: raw observation batch (numpy-convertible)
    :param use_raps_obs: unused here — presumably consumed by other
        policy variants sharing this interface; TODO confirm
    :param use_true_actions: unused here (see above)
    :param use_obs: unused here (see above)
    :return: action, debug_dictionary
    """
    observation = ptu.from_numpy(np.array(observation))
    if self.state:
        # continue from the previous (state, action) pair
        prev_state, action = self.state
    else:
        # first step of an episode: world model's initial state + zero action
        prev_state = self.world_model.initial(observation.shape[0])
        action = ptu.zeros((observation.shape[0], self.action_dim))
    embed = self.world_model.encode(observation)
    # posterior update given the previous action and the new embedding
    new_state, _ = self.world_model.obs_step(prev_state, action, embed)
    feat = self.world_model.get_features(new_state)
    dist = self.actor(feat)
    # deterministic action (distribution mode), not a sample
    action = dist.mode()
    if self.exploration:
        action = self.actor.compute_exploration_action(action, self.expl_amount)
    # remember (state, action) so the next call conditions on them
    self.state = (new_state, action)
    return ptu.get_numpy(action), {"state": new_state}
def compute_log_p_log_q_log_d(
    model,
    data,
    decoder_distribution='bernoulli',
    num_latents_to_sample=1,
    sampling_method='importance_sampling'
):
    """Compute per-image log prior, log posterior, and log decoder
    likelihood over several latent samples (for IWAE-style estimates).

    :param model: VAE with encode/decode/rsample and representation_size
    :param data: normalized images, numpy float64, shape (batch, ...)
    :param decoder_distribution: 'bernoulli' or 'gaussian_identity_variance'
    :param num_latents_to_sample: number of latent samples per image
    :param sampling_method: 'importance_sampling', 'biased_sampling', or
        'true_prior_sampling'
    :return: (log_p, log_q, log_d), each (batch, num_latents_to_sample)
    """
    assert data.dtype == np.float64, 'images should be normalized'
    imgs = ptu.from_numpy(data)
    latent_distribution_params = model.encode(imgs)
    batch_size = data.shape[0]
    representation_size = model.representation_size
    log_p, log_q, log_d = ptu.zeros((batch_size, num_latents_to_sample)), ptu.zeros(
        (batch_size, num_latents_to_sample)), ptu.zeros((batch_size, num_latents_to_sample))
    true_prior = Normal(ptu.zeros((batch_size, representation_size)),
                        ptu.ones((batch_size, representation_size)))
    mus, logvars = latent_distribution_params
    for i in range(num_latents_to_sample):
        # NOTE(review): 'importance_sampling' and 'biased_sampling' execute
        # identical code here — the weighting difference is presumably
        # applied by the caller; confirm this is intentional.
        if sampling_method == 'importance_sampling':
            latents = model.rsample(latent_distribution_params)
        elif sampling_method == 'biased_sampling':
            latents = model.rsample(latent_distribution_params)
        elif sampling_method == 'true_prior_sampling':
            latents = true_prior.rsample()
        else:
            # NOTE(review): EnvironmentError is an odd choice for argument
            # validation (ValueError would be conventional) but is kept in
            # case callers catch it.
            raise EnvironmentError('Invalid Sampling Method Provided')
        stds = logvars.exp().pow(.5)
        vae_dist = Normal(mus, stds)
        log_p_z = true_prior.log_prob(latents).sum(dim=1)
        log_q_z_given_x = vae_dist.log_prob(latents).sum(dim=1)
        if decoder_distribution == 'bernoulli':
            decoded = model.decode(latents)[0]
            # Bernoulli log-likelihood; 1e-8 guards against log(0)
            log_d_x_given_z = torch.log(imgs * decoded + (1 - imgs) *
                                        (1 - decoded) + 1e-8).sum(dim=1)
        elif decoder_distribution == 'gaussian_identity_variance':
            _, obs_distribution_params = model.decode(latents)
            dec_mu, dec_logvar = obs_distribution_params
            dec_var = dec_logvar.exp()
            decoder_dist = Normal(dec_mu, dec_var.pow(.5))
            log_d_x_given_z = decoder_dist.log_prob(imgs).sum(dim=1)
        else:
            raise EnvironmentError('Invalid Decoder Distribution Provided')
        log_p[:, i] = log_p_z
        log_q[:, i] = log_q_z_given_x
        log_d[:, i] = log_d_x_given_z
    return log_p, log_q, log_d
def initial(self, batch_size):
    """
    Return the zero-initialized recurrent state for a fresh rollout.

    :param batch_size: int
    :return state: Dict
        mean: (batch_size, stoch_size)
        std: (batch_size, stoch_size)
        deter: (batch_size, deter_size)
        stoch: (batch_size, stoch_size)
    """
    stoch_shape = [batch_size, self.stochastic_state_size]
    deter_shape = [batch_size, self.deterministic_state_size]
    return dict(
        mean=ptu.zeros(stoch_shape),
        std=ptu.zeros(stoch_shape),
        stoch=ptu.zeros(stoch_shape),
        deter=ptu.zeros(deter_shape),
    )
def clear_z(self, num_tasks=1):
    """Reset the task belief for num_tasks tasks: resample z from the
    unit-Gaussian prior under the information bottleneck, otherwise zero
    it, and clear any recurrent encoder state."""
    if self.use_ib:
        # one prior distribution per task, then a fresh sample from each
        self.z_dists = [
            torch.distributions.Normal(ptu.zeros(self.latent_dim),
                                       ptu.ones(self.latent_dim))
            for _ in range(num_tasks)
        ]
        samples = [dist.rsample() for dist in self.z_dists]
        self.z = torch.stack(samples)
    else:
        self.z = self.z.new_full((num_tasks, self.latent_dim), 0)
    # clear hidden state in recurrent case
    self.task_enc.reset(num_tasks)
def clear_z(self, num_tasks=1):
    """Reset q(z|c) to the prior (zeros here), drop the collected context,
    and clear the SNAIL encoder's internal state."""
    self.z = ptu.zeros(num_tasks, self.pie_hidden_dim)
    # nothing observed yet for the new task(s)
    self.context = None
    self.pie_snail.reset(num_tasks)
def clear_z(self, num_tasks=1):
    """
    Reset q(z|c) to the prior, draw a fresh z from it, and forget any
    collected context.
    """
    self.z_means = ptu.zeros(num_tasks, self.latent_dim)
    # unit variance under the information bottleneck, zero (deterministic)
    # otherwise
    if self.use_ib:
        self.z_vars = ptu.ones(num_tasks, self.latent_dim)
    else:
        self.z_vars = ptu.zeros(num_tasks, self.latent_dim)
    # sample a new z from the freshly reset belief
    self.sample_z()
    # reset the context collected so far
    self.context = None
def rsample(self):
    """Reparameterized mixture sample: perturb every component mean with
    scaled unit-Gaussian noise, then pick the component drawn from the
    categorical using gather."""
    noise = MultivariateDiagonalNormal(
        ptu.zeros(self.normal_means.size()),
        ptu.ones(self.normal_stds.size())).sample()
    z = self.normal_means + self.normal_stds * noise
    z.requires_grad_()
    # (batch, dims, 1) index of the chosen component
    component = self.categorical.sample()[:, :, None]
    selected = torch.gather(z, dim=2, index=component)
    return selected[:, :, 0]
def rsample(self):
    # Reparameterized mixture sample: perturb every component mean with
    # scaled unit-Gaussian noise, then combine with the sampled categorical.
    z = (self.normal_means + self.normal_stds * MultivariateDiagonalNormal(
        ptu.zeros(self.normal_means.size()),
        ptu.ones(self.normal_stds.size())).sample())
    z.requires_grad_()
    c = self.categorical.sample()[:, :, None]
    # NOTE(review): matmul against a sampled index tensor looks suspect —
    # a sibling implementation uses torch.gather(z, dim=2, index=c) here.
    # This only works if c is a one-hot float tensor; confirm the dtype of
    # self.categorical.sample() before trusting this path.
    s = torch.matmul(z, c)
    return torch.squeeze(s, 2)
def __init__(
    self,
    env,
    policy,
    qf1,
    qf2,
    target_qf1,
    target_qf2,
    discount=0.99,
    reward_scale=1.0,
    policy_lr=1e-3,
    qf_lr=1e-3,
    optimizer_class=optim.Adam,
    soft_target_tau=1e-2,
    target_update_period=1,
    plotter=None,
    render_eval_paths=False,
    use_automatic_entropy_tuning=True,
    target_entropy=None,
):
    """
    Soft Actor-Critic trainer.

    :param target_entropy: target policy entropy for automatic alpha
        tuning; None selects the -|A| heuristic. A value of 0 is now
        honored (previously discarded by a truthiness check).
    """
    super().__init__()
    self.env = env
    self.policy = policy
    self.qf1 = qf1
    self.qf2 = qf2
    self.target_qf1 = target_qf1
    self.target_qf2 = target_qf2
    self.soft_target_tau = soft_target_tau
    self.target_update_period = target_update_period

    self.use_automatic_entropy_tuning = use_automatic_entropy_tuning
    if self.use_automatic_entropy_tuning:
        # BUG FIX: `if target_entropy:` silently replaced an explicit
        # target entropy of 0 (or 0.0) with the heuristic; compare
        # against None instead.
        if target_entropy is not None:
            self.target_entropy = target_entropy
        else:
            # heuristic value from Tuomas
            self.target_entropy = -np.prod(
                self.env.action_space.shape).item()
        # log(alpha) is the learned temperature parameter
        self.log_alpha = ptu.zeros(1, requires_grad=True)
        self.alpha_optimizer = optimizer_class([self.log_alpha], lr=policy_lr)

    self.plotter = plotter
    self.render_eval_paths = render_eval_paths

    self.qf_criterion = nn.MSELoss()
    self.vf_criterion = nn.MSELoss()

    self.policy_optimizer = optimizer_class(self.policy.parameters(), lr=policy_lr)
    self.qf1_optimizer = optimizer_class(self.qf1.parameters(), lr=qf_lr)
    self.qf2_optimizer = optimizer_class(self.qf2.parameters(), lr=qf_lr)

    self.discount = discount
    self.reward_scale = reward_scale
    self.eval_statistics = OrderedDict()
    self._n_train_steps_total = 0
    self._need_to_update_eval_statistics = True
def compute_kl_div(self):
    """Sum of KL divergences between each stored task posterior in
    self.z_dists and the unit-Gaussian prior."""
    unit_prior = torch.distributions.Normal(
        ptu.zeros(self.latent_dim), ptu.ones(self.latent_dim))
    divergences = [
        torch.distributions.kl.kl_divergence(dist, unit_prior)
        for dist in self.z_dists
    ]
    return torch.stack(divergences).sum()
def rsample(self, return_pretanh_value=False):
    """
    Reparameterized sample of the tanh-squashed Gaussian.

    :param return_pretanh_value: if True, also return the pre-tanh sample
        (needed by callers for numerically-stable log-prob computation).
    :return: tanh(z), or (tanh(z), z) when return_pretanh_value is True
    """
    # mean + std * eps with fixed noise eps ~ N(0, I). The deprecated
    # torch.autograd.Variable wrapper has been removed: it has been a
    # no-op (returns the tensor unchanged) since PyTorch 0.4.
    eps = Normal(ptu.zeros(self.normal_mean.size()),
                 ptu.ones(self.normal_std.size())).sample()
    z = self.normal_mean + self.normal_std * eps
    if return_pretanh_value:
        return torch.tanh(z), z
    else:
        return torch.tanh(z)
def get_full_edges(self, num_node):
    """Return the (2, E) edge index of the fully-connected graph over
    num_node nodes, optionally with self-loops removed."""
    total = num_node * num_node
    edges = ptu.zeros(2, total, dtype=int)
    node_ids = torch.arange(num_node)
    # row 0 (sources): 0,0,...,0, 1,1,...,1, ...
    edges[0, :] = node_ids.repeat(num_node, 1).transpose(0, 1).reshape(total)
    # row 1 (targets): 0,1,...,n-1, 0,1,...,n-1, ...
    edges[1, :] = node_ids.repeat(1, num_node).reshape(total)
    if not self.contain_self_loop:
        edges = pyg_utils.remove_self_loops(edges)[0]
    return edges
def clear_z(self, num_tasks=1):
    """
    Reset q(z|c) to the prior, forget the collected context, and clear any
    recurrent encoder state. (Note: this variant does not draw a fresh z
    sample here.)
    """
    self.z_means = ptu.zeros(num_tasks, self.latent_dim)
    # unit variance under the information bottleneck; zero (deterministic z)
    # otherwise
    if self.use_ib:
        self.z_vars = ptu.ones(num_tasks, self.latent_dim)
    else:
        self.z_vars = ptu.zeros(num_tasks, self.latent_dim)
    # reset the context collected so far
    self.context = None
    # reset any hidden state in the encoder network (relevant for RNN)
    self.context_encoder.reset(num_tasks)
def rsample_with_pretanh(self):
    """Reparameterized sample; returns (tanh(z), z) so the caller can
    compute a numerically-stable log-prob from the pre-tanh value."""
    eps = MultivariateDiagonalNormal(
        ptu.zeros(self.normal_mean.size()),
        ptu.ones(self.normal_std.size()),
    ).sample()
    pre_tanh = self.normal_mean + self.normal_std * eps
    return torch.tanh(pre_tanh), pre_tanh
def clear_z(self, num_tasks=1):
    """Reset q(z|c) to the unit-Gaussian prior for num_tasks tasks
    (zero mean, unit variance); no sampling happens here."""
    self.z_means = ptu.zeros(num_tasks, self.latent_dim)
    self.z_vars = ptu.ones(num_tasks, self.latent_dim)
def mle_estimate(self):
    """Return the mean of the most likely component.

    This often computes the mode of the distribution, but not always.
    """
    # one-hot selector over mixture components, from the argmax weights
    c = ptu.zeros(self.weights.shape[:2])
    ind = torch.argmax(self.weights, dim=1)  # [:, 0]
    c.scatter_(1, ind, 1)
    # NOTE(review): the matmul assumes normal_means is (batch, dim, K) and
    # c broadcasts as (batch, K, 1) — confirm shapes against the class
    # constructor; they are not visible here.
    s = torch.matmul(self.normal_means, c[:, :, None])
    return torch.squeeze(s, 2)
def get_tau(self, actions, fp=None):
    """
    Sample quantile fractions for a distributional critic.

    :param actions: batch of actions; only len(actions) is used here.
    :param fp: unused in this variant; kept for interface compatibility
        with fraction-proposal implementations.
    :return: (tau, tau_hat, presum_tau) where tau are cumulative
        fractions tau1..tauN, tau_hat their midpoints, and presum_tau
        the per-quantile masses.
    """
    if self.tau_type == 'fix':
        # uniform fractions: each of the N quantiles gets mass 1/N
        presum_tau = ptu.zeros(len(actions), self.num_quantiles) + 1. / self.num_quantiles
    elif self.tau_type == 'iqn':
        # add 0.1 to prevent tau getting too close
        presum_tau = ptu.rand(len(actions), self.num_quantiles) + 0.1
        # BUG FIX: torch.Tensor.sum takes `keepdim`, not numpy's
        # `keepdims` spelling, which raises a TypeError.
        presum_tau /= presum_tau.sum(dim=-1, keepdim=True)
    else:
        # previously an unknown tau_type fell through to a NameError;
        # fail with an explicit message instead
        raise ValueError('unknown tau_type: %s' % self.tau_type)
    tau = torch.cumsum(presum_tau, dim=1)  # (N, T), note that they are tau1...tauN in the paper
    with torch.no_grad():
        # midpoints of consecutive taus (tau_hat in the IQN paper)
        tau_hat = ptu.zeros_like(tau)
        tau_hat[:, 0:1] = tau[:, 0:1] / 2.
        tau_hat[:, 1:] = (tau[:, 1:] + tau[:, :-1]) / 2.
    return tau, tau_hat, presum_tau
def rsample(self, return_pretanh_value=False):
    """
    Sampling in the reparameterization case.
    """
    eps = Normal(ptu.zeros(self.normal_mean.size()),
                 ptu.ones(self.normal_std.size())).sample()
    z = self.normal_mean + self.normal_std * eps
    z.requires_grad_()
    squashed = torch.tanh(z)
    if return_pretanh_value:
        return squashed, z
    return squashed
def clear_sequence_z(self, num_tasks=1, batch_size=1, traj_batch_size=1):
    """Reset all sequence-level latent beliefs (categorical, continuous,
    and Dirichlet families) to their priors, resample, and clear the
    recurrent encoder's state.

    :param num_tasks: number of tasks in the batch
    :param batch_size: per-task batch size used to size the belief tensors
    :param traj_batch_size: trajectory batch size for the encoder reset
    """
    assert self.recurrent_context_encoder != None
    if self.r_cat_dim > 0:
        # uniform categorical prior over r_cat_dim classes
        self.seq_z_cat = ptu.ones(num_tasks * batch_size * self.r_n_cat,
                                  self.r_cat_dim) / self.r_cat_dim
        self.seq_z_next_cat = None
    if self.r_cont_dim > 0:
        # unit-Gaussian prior for the continuous latents
        self.seq_z_cont_mean = ptu.zeros(num_tasks * batch_size, self.r_cont_dim)
        self.seq_z_cont_var = ptu.ones(num_tasks * batch_size, self.r_cont_dim)
        self.seq_z_next_cont_mean = None
        self.seq_z_next_cont_var = None
    if self.r_dir_dim > 0:
        if self.r_constraint == 'logitnormal':
            # logit-normal parameterization: Gaussian in logit space with
            # configured variance r_var
            self.seq_z_dir_mean = ptu.zeros(num_tasks * batch_size * self.r_n_dir,
                                            self.r_dir_dim)
            self.seq_z_dir_var = ptu.ones(num_tasks * batch_size * self.r_n_dir,
                                          self.r_dir_dim) * self.r_var
            self.seq_z_next_dir_mean = None
            self.seq_z_next_dir_var = None
        elif self.r_constraint == 'dirichlet':
            # symmetric Dirichlet prior with concentration r_alpha
            self.seq_z_dir = ptu.ones(num_tasks * batch_size * self.r_n_dir,
                                      self.r_dir_dim) * self.r_alpha
            self.seq_z_next_dir = None
    # draw fresh samples from the reset beliefs
    self.sample_sequence_z()
    # clear hidden state in the recurrent encoder
    self.recurrent_context_encoder.reset(num_tasks*traj_batch_size)
def __init__(
    self,
    env,
    dsp,
    policy,
    classifier,
    search_buffer,
    policy_lr=1e-3,
    classifier_lr=1e-3,
    optimizer_class=optim.Adam,
    use_automatic_entropy_tuning=True,
    target_entropy=None,
):
    """
    Trainer for a policy plus a distance/subgoal predictor (dsp) and a
    classifier over a search buffer.

    :param target_entropy: target policy entropy for automatic alpha
        tuning; None selects the -|A| heuristic. A value of 0 is now
        honored (previously discarded by a truthiness check).
    """
    super().__init__()
    self.env = env
    self.dsp = dsp
    self.policy = policy
    self.classifier = classifier
    self.search_buffer = search_buffer

    self.use_automatic_entropy_tuning = use_automatic_entropy_tuning
    if self.use_automatic_entropy_tuning:
        # BUG FIX: `if target_entropy:` silently replaced an explicit
        # target entropy of 0 (or 0.0) with the heuristic; compare
        # against None instead.
        if target_entropy is not None:
            self.target_entropy = target_entropy
        else:
            # heuristic: -|A|
            self.target_entropy = -np.prod(self.env.action_space.shape).item()
        self.log_alpha = ptu.zeros(1, requires_grad=True)
        self.alpha_optimizer = optimizer_class(
            [self.log_alpha],
            lr=policy_lr,
        )

    self.classifier_criterion = nn.MSELoss()
    self.dsp_optimizer = optimizer_class(
        self.dsp.parameters(),
        lr=policy_lr,
    )
    self.policy_optimizer = optimizer_class(
        self.policy.parameters(),
        lr=policy_lr,
    )
    # NOTE: attribute name contains a typo ("classisier") but is kept
    # unchanged because external code may reference it.
    self.classisier_optimizer = optimizer_class(
        self.classifier.parameters(),
        lr=classifier_lr,
    )

    self.eval_statistics = OrderedDict()
    self._n_train_steps_total = 0
    self._need_to_update_eval_statistics = True
def forward(self, *inputs):
    """Run the base network, then sample a tanh-squashed Gaussian action.

    :return: (action, pre_tanh_z, logp, mean, logstd)
    """
    x = super().forward(*inputs)
    mean = self.mean_layer(x)
    logstd = self.logstd_layer(x)
    # clamp log-std to keep exp() numerically sane
    logstd = torch.clamp(logstd, LOGMIN, LOGMAX)
    std = torch.exp(logstd)
    unit_normal = Normal(ptu.zeros(mean.size()), ptu.ones(std.size()))
    eps = unit_normal.sample()
    # NOTE(review): the .cpu() calls force the sample onto the CPU even
    # when the network runs on the GPU, mixing devices with eps/logp —
    # confirm this is intentional before running on CUDA.
    pre_tanh_z = mean.cpu() + std.cpu() * eps
    action = torch.tanh(pre_tanh_z)
    # log-prob of the raw noise, NOT corrected for the tanh squash
    logp = unit_normal.log_prob(eps)
    # logp = logp.sum(dim=1, keepdim=True)
    # logsum = exp mult
    return action, pre_tanh_z, logp, mean, logstd
def compute_features(self, dataloader):
    """Run the model over the dataloader (no gradients) and collect the
    per-trajectory features into one preallocated tensor.

    :param dataloader: yields (input_tensor,) batches covering
        num_trajectories trajectories in order
    :return: (num_trajectories, episode_length, feature_size) tensor
    """
    self.prepare_for_inference()
    features = ptu.zeros(
        [self.num_trajectories, self.episode_length, self.feature_size],
        dtype=torch.float32)
    num_batches = len(dataloader)
    with torch.no_grad():
        for batch_idx, (input_tensor,) in enumerate(dataloader):
            feature_batch = self.forward(input_tensor)
            start = batch_idx * self.batch_size_trajectory
            if batch_idx < num_batches - 1:
                features[start:start + self.batch_size_trajectory] = feature_batch
            else:
                # final batch may be smaller; fill through the end
                features[start:] = feature_batch
    return features
def __init__(self,
             latent_dim,
             context_encoder,
             policy,
             reward_predictor,
             use_next_obs_in_context=False,
             _debug_ignore_context=False,
             _debug_do_not_sqrt=False,
             _debug_use_ground_truth_context=False):
    """Context-conditioned agent bundling an encoder, policy, and reward
    predictor, with a latent task belief z.

    :param latent_dim: dimensionality of the latent task variable
    :param context_encoder: network mapping context to belief parameters
    :param policy: stochastic policy; also wrapped deterministically
    :param reward_predictor: reward model conditioned on z
    :param use_next_obs_in_context: include next obs in context tuples
    :param _debug_*: debugging switches, off by default
    """
    super().__init__()
    self.latent_dim = latent_dim
    self.context_encoder = context_encoder
    self.policy = policy
    self.reward_predictor = reward_predictor
    self.deterministic_policy = MakeDeterministic(self.policy)
    self._debug_ignore_context = _debug_ignore_context
    self._debug_use_ground_truth_context = _debug_use_ground_truth_context

    # self.recurrent = kwargs['recurrent']
    # self.use_ib = kwargs['use_information_bottleneck']
    # self.sparse_rewards = kwargs['sparse_rewards']

    self.use_next_obs_in_context = use_next_obs_in_context

    # initialize buffers for z dist and z
    # use buffers so latent context can be saved along with model weights
    self.register_buffer('z', torch.zeros(1, latent_dim))
    self.register_buffer('z_means', torch.zeros(1, latent_dim))
    self.register_buffer('z_vars', torch.zeros(1, latent_dim))
    # NOTE(review): the buffers above are immediately replaced with None —
    # the registered tensors are discarded, so saving/loading z through
    # the state dict likely does not work as the comment suggests; verify.
    self.z_means = None
    self.z_vars = None
    self.context = None
    self.z = None

    # rp = reward predictor
    # TODO: add back in reward predictor code
    self.z_means_rp = None
    self.z_vars_rp = None
    self.z_rp = None
    # shares the same encoder instance as the main belief
    self.context_encoder_rp = context_encoder
    self._use_context_encoder_snapshot_for_reward_pred = False

    # unit-Gaussian prior over z
    self.latent_prior = torch.distributions.Normal(
        ptu.zeros(self.latent_dim), ptu.ones(self.latent_dim))

    self._debug_do_not_sqrt = _debug_do_not_sqrt
def __init__(self,
             hidden_sizes,
             obs_dim,
             action_dim,
             std=None,
             init_w=1e-3,
             min_log_std=None,
             max_log_std=None,
             num_gaussians=1,
             std_architecture="shared",
             **kwargs):
    """Gaussian-mixture policy head.

    :param hidden_sizes: hidden layer widths of the base network
    :param obs_dim: observation dimensionality
    :param action_dim: action dimensionality (per mixture component)
    :param std: fixed std; when None, std is learned per
        `std_architecture` ("shared" linear head or "values" parameters)
    :param num_gaussians: number of mixture components
    """
    super().__init__(
        hidden_sizes,
        input_size=obs_dim,
        output_size=action_dim * num_gaussians,
        init_w=init_w,
        # output_activation=torch.tanh,
        **kwargs)
    self.action_dim = action_dim
    self.num_gaussians = num_gaussians
    self.min_log_std = min_log_std
    self.max_log_std = max_log_std
    self.log_std = None
    self.std = std
    self.std_architecture = std_architecture

    # BUG FIX: last_hidden_size was previously computed only inside the
    # `std is None` branch, so constructing the policy with a fixed std
    # raised NameError at the last_fc_weights line below.
    last_hidden_size = obs_dim
    if len(hidden_sizes) > 0:
        last_hidden_size = hidden_sizes[-1]

    if std is None:
        if self.std_architecture == "shared":
            # state-dependent log-std head
            self.last_fc_log_std = nn.Linear(last_hidden_size,
                                             action_dim * num_gaussians)
            self.last_fc_log_std.weight.data.uniform_(-init_w, init_w)
            self.last_fc_log_std.bias.data.uniform_(-init_w, init_w)
        elif self.std_architecture == "values":
            # state-independent learned log-std values
            self.log_std_logits = nn.Parameter(
                ptu.zeros(action_dim * num_gaussians, requires_grad=True))
        else:
            raise ValueError(self.std_architecture)
    else:
        self.log_std = np.log(std)
        assert LOG_SIG_MIN <= self.log_std <= LOG_SIG_MAX

    # mixture-weight head
    self.last_fc_weights = nn.Linear(last_hidden_size,
                                     action_dim * num_gaussians)
    self.last_fc_weights.weight.data.uniform_(-init_w, init_w)
    self.last_fc_weights.bias.data.uniform_(-init_w, init_w)
def compute_density(self, data):
    """Estimate p(x) for each row of `data` by Monte-Carlo over latents,
    averaging `n_average` replicas per datapoint (self-normalized
    importance sampling when mode == 'importance_sampling').

    :param data: numpy array of datapoints, one per row
    :return: numpy array of density estimates, one per input row
    """
    orig_data_length = len(data)
    # replicate the batch n_average times so one forward pass covers all
    # Monte-Carlo samples
    data = np.vstack([data for _ in range(self.n_average)])
    data = ptu.from_numpy(data)
    if self.mode == 'biased':
        # sample latents from the encoder posterior, unit weights
        latents, means, log_vars, stds = (
            self.encoder.get_encoding_and_suff_stats(data))
        importance_weights = ptu.ones(data.shape[0])
    elif self.mode == 'prior':
        # sample latents directly from the prior, unit weights
        latents = ptu.randn(len(data), self.z_dim)
        importance_weights = ptu.ones(data.shape[0])
    elif self.mode == 'importance_sampling':
        latents, means, log_vars, stds = (
            self.encoder.get_encoding_and_suff_stats(data))
        # weight = p(z) / q(z|x), computed in log space then exponentiated
        prior = Normal(ptu.zeros(1), ptu.ones(1))
        prior_log_prob = prior.log_prob(latents).sum(dim=1)
        encoder_distrib = Normal(means, stds)
        encoder_log_prob = encoder_distrib.log_prob(latents).sum(dim=1)
        importance_weights = (prior_log_prob - encoder_log_prob).exp()
    else:
        raise NotImplementedError()
    unweighted_data_log_prob = self.compute_log_prob(
        data, self.decoder, latents).squeeze(1)
    unweighted_data_prob = unweighted_data_log_prob.exp()
    unnormalized_data_prob = unweighted_data_prob * importance_weights
    """
    Average over `n_average`
    """
    dp_split = torch.split(unnormalized_data_prob, orig_data_length, dim=0)
    # pre_avg.shape = ORIG_LEN x N_AVERAGE
    dp_stacked = torch.stack(dp_split, dim=1)
    # final.shape = ORIG_LEN
    unnormalized_dp = torch.sum(dp_stacked, dim=1, keepdim=False)
    """
    Compute the importance weight denomintors. This requires
    summing across the `n_average` dimension.
    """
    iw_split = torch.split(importance_weights, orig_data_length, dim=0)
    iw_stacked = torch.stack(iw_split, dim=1)
    iw_denominators = iw_stacked.sum(dim=1, keepdim=False)
    # self-normalized estimate: weighted sum / sum of weights
    final = unnormalized_dp / iw_denominators
    return ptu.get_numpy(final)