def test_log_prob_gradient(self):
    """
    As in test_log_prob_value, but check the gradients. The tanh term
    drops out of the gradients since tanh has no parameters.

    For z ~ N(mu, sigma) and x = tanh(z):
        d/d mu    log f_X(x) = (z - mu) / sigma^2
        d/d sigma log f_X(x) = (z - mu)^2 / sigma^3 - 1 / sigma
    """
    mean_var = ptu.from_numpy(np.array([0]), requires_grad=True)
    std_var = ptu.from_numpy(np.array([0.25]), requires_grad=True)
    tanh_normal = TanhNormal(mean_var, std_var)
    z = ptu.from_numpy(np.array([1]))
    x = torch.tanh(z)
    log_prob = tanh_normal.log_prob(x, pre_tanh_value=z)
    gradient = ptu.from_numpy(np.array([1]))

    log_prob.backward(gradient)

    self.assertNpArraysEqual(
        ptu.get_numpy(mean_var.grad),
        # (z - mu) / sigma^2 = (1 - 0) / 0.25^2 = 16
        np.array([16]),
    )
    self.assertNpArraysEqual(
        ptu.get_numpy(std_var.grad),
        # (z - mu)^2 / sigma^3 - 1 / sigma = 64 - 4 = 60
        np.array([4**3 - 4]),
    )
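
# Standalone check of the docstring formulas above with torch.autograd on a
# plain Normal log-density (a minimal sketch; the tanh correction term has
# no parameters, so it contributes nothing to these gradients):
import torch

mu = torch.tensor(0.0, requires_grad=True)
sigma = torch.tensor(0.25, requires_grad=True)
z = torch.tensor(1.0)

torch.distributions.Normal(mu, sigma).log_prob(z).backward()

assert torch.allclose(mu.grad, torch.tensor(16.0))     # (z - mu) / sigma^2
assert torch.allclose(sigma.grad, torch.tensor(60.0))  # (z - mu)^2 / sigma^3 - 1 / sigma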
def forward(
        self,
        obs,
        deterministic=False,
        return_log_prob=False,
):
    """
    :param obs: Observation
    :param deterministic: If True, do not sample
    :param return_log_prob: If True, return a sample and its log probability
    """
    h = obs
    for i, fc in enumerate(self.fcs):
        h = self.hidden_activation(fc(h))
    mean = self.last_fc(h)
    if self.std is None:
        log_std = self.last_fc_log_std(h)
        log_std = torch.clamp(log_std, LOG_SIG_MIN, LOG_SIG_MAX)
        std = torch.exp(log_std)
    else:
        std = self.std
        log_std = self.log_std

    log_prob = None
    expected_log_prob = None
    mean_action_log_prob = None
    pre_tanh_value = None
    if deterministic:
        action = torch.tanh(mean)
    else:
        tanh_normal = TanhNormal(mean, std)
        if return_log_prob:
            action, pre_tanh_value = tanh_normal.sample(
                return_pretanh_value=True)
            log_prob = tanh_normal.log_prob(
                action,
                pre_tanh_value=pre_tanh_value,
            )
            log_prob = log_prob.sum(dim=1, keepdim=True)
        else:
            action = tanh_normal.sample()
    return (
        action,
        mean,
        log_std,
        log_prob,
        expected_log_prob,
        std,
        mean_action_log_prob,
        pre_tanh_value,
    )
def test_log_prob_value(self):
    tanh_normal = TanhNormal(0, 1)
    z = np.array([1])
    x_np = np.tanh(z)
    x = ptu.from_numpy(x_np)

    log_prob = tanh_normal.log_prob(x)

    log_prob_np = ptu.get_numpy(log_prob)
    log_prob_expected = (
        np.log(np.array([1 / np.sqrt(2 * np.pi)])) - 0.5  # from Normal
        - np.log(1 - x_np**2)  # tanh change-of-variables correction
    )
    self.assertNpArraysEqual(
        log_prob_expected,
        log_prob_np,
    )
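
# The tests above pin down the log-density the policies rely on:
#     log p_X(x) = log N(z; mu, sigma) - log(1 - tanh(z)^2),  x = tanh(z).
# A minimal sketch of the TanhNormal interface assumed throughout (not the
# original implementation; the epsilon inside the log is a common stabilizer):
import torch
from torch.distributions import Normal


class TanhNormalSketch:
    def __init__(self, mean, std, epsilon=1e-6):
        self.normal = Normal(mean, std)
        self.epsilon = epsilon

    def sample(self, return_pretanh_value=False):
        z = self.normal.sample()
        return (torch.tanh(z), z) if return_pretanh_value else torch.tanh(z)

    def rsample(self, return_pretanh_value=False):
        # Reparameterized sample: gradients flow back to mean and std.
        z = self.normal.rsample()
        return (torch.tanh(z), z) if return_pretanh_value else torch.tanh(z)

    def log_prob(self, value, pre_tanh_value=None):
        if pre_tanh_value is None:
            # atanh(value); numerically safer to pass pre_tanh_value in.
            pre_tanh_value = 0.5 * (torch.log1p(value) - torch.log1p(-value))
        return (
            self.normal.log_prob(pre_tanh_value)
            - torch.log(1 - value ** 2 + self.epsilon)
        )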
def forward(self, obs):
    h = self.obs_processor(obs)
    h = self.mean_and_log_std_net(h)
    mean, log_std = torch.split(h, self.action_dim, dim=1)
    log_std = torch.clamp(log_std, LOG_SIG_MIN, LOG_SIG_MAX)
    std = torch.exp(log_std)
    tanh_normal = TanhNormal(mean, std)
    return tanh_normal
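
# Forwards like the one above return the distribution itself and leave
# sampling to the caller. A self-contained sketch of the usual consumption
# pattern, using torch.distributions.Normal as a stand-in for TanhNormal
# (whose extra tanh correction is covered by the tests above):
import torch
from torch.distributions import Normal

dist = Normal(torch.zeros(4, 2), torch.ones(4, 2))    # stand-in for TanhNormal
action = dist.rsample()                               # differentiable sample
log_pi = dist.log_prob(action).sum(-1, keepdim=True)  # (4, 1) log-probs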
def log_prob(self, obs, actions):
    raw_actions = atanh(actions)
    h = obs
    for i, fc in enumerate(self.fcs):
        h = self.hidden_activation(fc(h))
    mean = self.last_fc(h)
    mean = torch.clamp(mean, MEAN_MIN, MEAN_MAX)
    if self.std is None:
        log_std = self.last_fc_log_std(h)
        log_std = torch.clamp(log_std, LOG_SIG_MIN, LOG_SIG_MAX)
        std = torch.exp(log_std)
    else:
        std = self.std
        log_std = self.log_std
    tanh_normal = TanhNormal(mean, std)
    log_prob = tanh_normal.log_prob(value=actions,
                                    pre_tanh_value=raw_actions)
    return log_prob.sum(-1)
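
# `atanh` above is an assumed helper. For |x| -> 1 the naive
# 0.5 * log((1 + x) / (1 - x)) overflows, so clamping (or log1p) is the
# usual guard before recovering pre-tanh actions; a minimal sketch:
import torch

def atanh_stable(x, eps=1e-6):
    # Clamp away from +/-1 so the log stays finite for saturated actions.
    x = x.clamp(-1 + eps, 1 - eps)
    return 0.5 * (torch.log1p(x) - torch.log1p(-x))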
def forward(
        self,
        obs,
        reparameterize=True,
        deterministic=False,
        return_log_prob=False,
):
    """
    :param obs: Observation
    :param deterministic: If True, do not sample
    :param return_log_prob: If True, return a sample and its log probability
    """
    h = super().forward(obs, return_features=True)
    mean = self.last_fc(h)
    if self.std is None:
        log_std = self.last_fc_log_std(h)
        log_std = torch.clamp(log_std, LOG_SIG_MIN, LOG_SIG_MAX)
        std = torch.exp(log_std)
    else:
        std = self.std
        log_std = self.log_std

    log_prob = None
    entropy = None
    mean_action_log_prob = None
    pre_tanh_value = None
    if deterministic:
        action = torch.tanh(mean)
    else:
        tanh_normal = TanhNormal(mean, std)
        if return_log_prob:
            if reparameterize is True:
                action, pre_tanh_value = tanh_normal.rsample(
                    return_pretanh_value=True)
            else:
                action, pre_tanh_value = tanh_normal.sample(
                    return_pretanh_value=True)
            log_prob = tanh_normal.log_prob(
                action,
                pre_tanh_value=pre_tanh_value,
            )
            log_prob = log_prob.sum(dim=1, keepdim=True)
        else:
            if reparameterize is True:
                action = tanh_normal.rsample()
            else:
                action = tanh_normal.sample()
    return (
        action,
        mean,
        log_std,
        log_prob,
        entropy,
        std,
        mean_action_log_prob,
        pre_tanh_value,
    )
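
# The rsample/sample split above is the reparameterization trick: rsample
# draws z = mu + std * eps with eps ~ N(0, 1), so gradients flow to mu and
# std; sample detaches them. A minimal sketch in plain torch:
import torch

mu = torch.zeros(3, requires_grad=True)
std = torch.ones(3, requires_grad=True)
eps = torch.randn(3)

action = torch.tanh(mu + std * eps)  # what rsample(...) computes, in effect
action.sum().backward()
assert mu.grad is not None           # gradients reach the policy parameters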
def forward(self, obs):
    h, h_aux = super().forward(obs, return_last_main_activations=True)
    mean = self.last_fc_main(h)
    if self.std is None:
        log_std = self.last_fc_log_std(h)
        log_std = torch.clamp(log_std, LOG_SIG_MIN, LOG_SIG_MAX)
        std = torch.exp(log_std)
    else:
        std = self.std
    tanh_normal = TanhNormal(mean, std)
    return tanh_normal, h_aux
def forward(self, data):
    x = self.gat_net(data)
    mean = self.last_mean_layer(x)
    if self.std is None:
        log_std = self.last_log_std(x)
        log_std = torch.clamp(log_std, LOG_SIG_MIN, LOG_SIG_MAX)
        std = torch.exp(log_std)
    else:
        std = (
            torch.from_numpy(np.array([self.std]))
            .float()
            .to(ptu.device)
        )
    return TanhNormal(mean, std)
def query_action_logpi(self, obs, action):
    h = obs
    for i, fc in enumerate(self.fcs):
        h = self.hidden_activation(fc(h))
        if self.layer_norm and i < len(self.fcs):
            h = self.layer_norms[i](h)
    mean = self.last_fc(h)
    if self.std is None:
        log_std = self.last_fc_log_std(h)
        log_std = torch.clamp(log_std, LOG_SIG_MIN, LOG_SIG_MAX)
        std = torch.exp(log_std)
    else:
        std = self.std
    # No pre-tanh value is available for an externally supplied action, so
    # log_prob must recover it internally via atanh(action).
    pre_tanh_value = None
    tanh_normal = TanhNormal(mean, std)
    log_prob = tanh_normal.log_prob(
        action,
        pre_tanh_value=pre_tanh_value,
    )
    log_prob = log_prob.sum(dim=1, keepdim=True)
    return log_prob
def forward(self, obs):
    h = obs
    for i, fc in enumerate(self.fcs):
        h = self.hidden_activation(fc(h))
    mean = self.last_fc(h)
    if self.std is None:
        log_std = self.last_fc_log_std(h)
        log_std = torch.clamp(log_std, LOG_SIG_MIN, LOG_SIG_MAX)
        std = torch.exp(log_std)
    else:
        std = (
            torch.from_numpy(np.array([self.std]))
            .float()
            .to(ptu.device)
        )
    return TanhNormal(mean, std)
def forward(
        self,
        obs,
        reparameterize=True,
        deterministic=False,
        return_log_prob=False,
):
    """
    :param obs: Observation
    :param deterministic: If True, do not sample
    :param return_log_prob: If True, return a sample and its log probability
    """
    # This is a bit messed up. TODO: clean it up.
    if obs.shape[0] == 1:
        # obs is a single image: flatten it into a batch of one.
        h = torch.flatten(obs)
        h = h.view(1, -1)
    else:
        # obs comes from the replay buffer: it is already flat and arrives
        # in a batch, so do NOT flatten it.
        h = obs
    h = super().forward(h, None, complete=False)
    mean = self.last_fc(h)
    if self.std is None:
        log_std = self.last_fc_log_std(h)
        log_std = torch.clamp(log_std, LOG_SIG_MIN, LOG_SIG_MAX)
        std = torch.exp(log_std)
    else:
        std = self.std
        log_std = self.log_std

    log_prob = None
    entropy = None
    mean_action_log_prob = None
    pre_tanh_value = None
    if deterministic:
        action = torch.tanh(mean)
    else:
        tanh_normal = TanhNormal(mean, std)
        if return_log_prob:
            if reparameterize is True:
                action, pre_tanh_value = tanh_normal.rsample(
                    return_pretanh_value=True)
            else:
                action, pre_tanh_value = tanh_normal.sample(
                    return_pretanh_value=True)
            log_prob = tanh_normal.log_prob(
                action,
                pre_tanh_value=pre_tanh_value,
            )
            log_prob = log_prob.sum(dim=1, keepdim=True)
        else:
            if reparameterize is True:
                action = tanh_normal.rsample()
            else:
                action = tanh_normal.sample()
    return (
        action,
        mean,
        log_std,
        log_prob,
        entropy,
        std,
        mean_action_log_prob,
        pre_tanh_value,
    )
def forward(
        self,
        obs,
        reparameterize=True,
        deterministic=False,
        return_log_prob=True,
        **kwargs
):
    """
    :param obs: Observation
    :param deterministic: If True, do not sample
    :param return_log_prob: If True, return a sample and its log probability
    """
    # gt.stamp("Tanhnormal_forward")
    h = obs
    for i, fc in enumerate(self.fcs):
        h = self.hidden_activation(fc(h))
        if self.layer_norm and i < len(self.fcs):
            h = self.layer_norms[i](h)
    mean = self.last_fc(h)
    if self.std is None:
        log_std = self.last_fc_log_std(h)
        log_std = torch.clamp(log_std, LOG_SIG_MIN, LOG_SIG_MAX)
        std = torch.exp(log_std)
    else:
        std = self.std
        log_std = self.log_std

    log_prob = None
    entropy = None
    mean_action_log_prob = None
    pre_tanh_value = None
    if deterministic:
        action = torch.tanh(mean)
    else:
        tanh_normal = TanhNormal(mean, std)
        # gt.stamp("tanhnormal_pre")
        if return_log_prob:
            if reparameterize is True:
                action, pre_tanh_value = tanh_normal.rsample(
                    return_pretanh_value=True
                )
            else:
                action, pre_tanh_value = tanh_normal.sample(
                    return_pretanh_value=True
                )
            log_prob = tanh_normal.log_prob(
                action,
                pre_tanh_value=pre_tanh_value,
            )
            log_prob = log_prob.sum(dim=1, keepdim=True)
        else:
            if reparameterize is True:
                action = tanh_normal.rsample()
            else:
                action = tanh_normal.sample()
        # gt.stamp("tanhnormal_post")
    log_std = log_std.sum(dim=1, keepdim=True)
    return (
        action,
        mean,
        log_std,
        log_prob,
        entropy,
        std,
        mean_action_log_prob,
        pre_tanh_value,
    )
def forward(
        self,
        obs,
        deterministic=False,
        return_log_prob=False,
        return_entropy=False,
        return_log_prob_of_mean=False,
):
    """
    :param obs: Observation
    :param deterministic: If True, do not sample
    :param return_log_prob: If True, return a sample and its log probability
    :param return_entropy: If True, also return the entropy of the
        underlying Gaussian. Will not need to be differentiated through,
        so this can be a number.
    :param return_log_prob_of_mean: If True, also return the log
        probability of the mean action. Will not need to be differentiated
        through, so this can be a number.
    """
    obs, taus = split_tau(obs)
    batch_size = obs.size()[0]
    tau_vector = Variable(
        torch.zeros((batch_size, self.tau_vector_len)) + taus.data)
    h = obs
    h1 = self.hidden_activation(self.first_input(h))
    h2 = self.hidden_activation(self.second_input(tau_vector))
    h = torch.cat((h1, h2), dim=1)
    for i, fc in enumerate(self.fcs):
        h = self.hidden_activation(fc(h))
    mean = self.last_fc(h)
    if self.std is None:
        log_std = self.last_fc_log_std(h)
        log_std = torch.clamp(log_std, LOG_SIG_MIN, LOG_SIG_MAX)
        std = torch.exp(log_std)
    else:
        std = self.std
        log_std = self.log_std

    log_prob = None
    entropy = None
    mean_action_log_prob = None
    pre_tanh_value = None
    if deterministic:
        action = torch.tanh(mean)
    else:
        tanh_normal = TanhNormal(mean, std)
        if return_log_prob:
            action, pre_tanh_value = tanh_normal.sample(
                return_pretanh_value=True)
            log_prob = tanh_normal.log_prob(
                action,
                pre_tanh_value=pre_tanh_value,
            )
            log_prob = log_prob.sum(dim=1, keepdim=True)
        else:
            action = tanh_normal.sample()
    if return_entropy:
        entropy = log_std + 0.5 + np.log(2 * np.pi) / 2
        # Note: this is the entropy of the pre-tanh Gaussian, not of the
        # squashed action. Differential entropy is *not* invariant under
        # invertible maps: H[tanh(Z)] = H[Z] + E[log(1 - tanh(Z)^2)].
        entropy = entropy.sum(dim=1, keepdim=True)
    if return_log_prob_of_mean:
        tanh_normal = TanhNormal(mean, std)
        mean_action_log_prob = tanh_normal.log_prob(
            torch.tanh(mean),
            pre_tanh_value=mean,
        )
        mean_action_log_prob = mean_action_log_prob.sum(dim=1,
                                                        keepdim=True)
    return (
        action,
        mean,
        log_std,
        log_prob,
        entropy,
        std,
        mean_action_log_prob,
        pre_tanh_value,
    )
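
# The entropy caveat above, numerically: for Y = tanh(Z) with Z Gaussian,
#     H[Y] = H[Z] + E[log(1 - tanh(Z)^2)],
# and the expectation term is strictly negative, so squashing reduces
# entropy. A Monte Carlo sketch of the gap:
import torch
from torch.distributions import Normal

normal = Normal(torch.tensor(0.0), torch.tensor(1.0))
z = normal.sample((100_000,))
gap = torch.log(1 - torch.tanh(z) ** 2).mean()  # E[log(1 - tanh(Z)^2)] < 0
print(float(normal.entropy()), float(normal.entropy() + gap))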
def forward(
        self,
        obs,
        reparameterize=True,
        deterministic=False,
        return_log_prob=False,
):
    """
    :param obs: Observation
    :param deterministic: If True, do not sample
    :param return_log_prob: If True, return a sample and its log probability
    """
    h = obs
    for i, fc in enumerate(self.fcs):
        h = self.hidden_activation(fc(h))
    out = self.last_fc(h).view(-1, self.k, 2 * self.action_dim + 1)
    log_w = out[..., 0]
    mean = out[..., 1:1 + self.action_dim]
    log_std = out[..., 1 + self.action_dim:]
    log_std = torch.clamp(log_std, LOG_SIG_MIN, LOG_SIG_MAX)
    log_w = torch.clamp(log_w, min=LOG_W_MIN)
    self.log_w = log_w
    self.mean = mean
    self.log_std = log_std
    std = torch.exp(log_std)

    log_prob = None
    entropy = None
    mean_action_log_prob = None
    pre_tanh_value = None
    arange = torch.arange(out.shape[0])
    if deterministic:
        ks = log_w.view(-1, self.k).argmax(1)
        action = torch.tanh(mean[arange, ks])
    else:
        sample_ks = Categorical(logits=log_w.view(-1, self.k)).sample()
        tanh_normal = TanhNormal(mean[arange, sample_ks],
                                 std[arange, sample_ks])
        if return_log_prob:
            action, pre_tanh_value = tanh_normal.sample(
                return_pretanh_value=True)
            # (N x K x A), (N x K x A), (N x 1 x A) => (N x K)
            log_p_xz_t = log_gaussian(mean, log_std,
                                      pre_tanh_value[:, None, :].data)
            log_p_x_t = (torch.logsumexp(log_p_xz_t + log_w, 1)
                         - torch.logsumexp(log_w, 1))
            # squash correction
            log_prob = log_p_x_t - torch.log(1 - action**2 + 1e-6).sum(1)
            log_prob = log_prob[:, None]
        else:
            if reparameterize is True:
                action = tanh_normal.rsample()
            else:
                action = tanh_normal.sample()
    return (
        action,
        mean,
        log_std,
        log_prob,
        entropy,
        std,
        mean_action_log_prob,
        pre_tanh_value,
    )
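
# The mixture density above marginalizes the component index in log space:
#     log p(x) = logsumexp_k(log w_k + log N(x; mu_k, sigma_k))
#                - logsumexp_k(log w_k),
# where the second term normalizes the unnormalized log weights. A tiny
# numeric sketch of that identity for a two-component mixture:
import torch
from torch.distributions import Normal

log_w = torch.log(torch.tensor([0.3, 0.7]))
comps = Normal(torch.tensor([-1.0, 2.0]), torch.tensor([0.5, 1.0]))
x = torch.tensor(0.5)

log_p = torch.logsumexp(log_w + comps.log_prob(x), 0) - torch.logsumexp(log_w, 0)
direct = torch.log(0.3 * comps.log_prob(x)[0].exp()
                   + 0.7 * comps.log_prob(x)[1].exp())
assert torch.allclose(log_p, direct)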
def test_log_prob_type(self):
    tanh_normal = TanhNormal(0, 1)
    x = ptu.from_numpy(np.array([0]))

    log_prob = tanh_normal.log_prob(x)

    self.assertIsInstance(log_prob, torch.autograd.Variable)
def forward(
        self,
        obs,
        reparameterize=True,
        deterministic=False,
        return_log_prob=False,
):
    """
    :param obs: Observation
    :param deterministic: If True, do not sample
    :param return_log_prob: If True, return a sample and its log probability
    """
    h = obs[:, :self.obs_preprocess_size]
    for i, fc in enumerate(self.obs_preprocess_fcs):
        h = fc(h)
        h = self.hidden_activation(h)
    obs_processed = self.obs_preprocess_last_fc(h)
    h = torch.cat([obs[:, self.obs_preprocess_size:], obs_processed],
                  dim=1)
    for i, fc in enumerate(self.fcs):
        h = fc(h)
        h = self.hidden_activation(h)
    mean = self.last_fc(h)
    if self.std is None:
        log_std = self.last_fc_log_std(h)
        log_std = torch.clamp(log_std, LOG_SIG_MIN, LOG_SIG_MAX)
        std = torch.exp(log_std)
    else:
        std = self.std
        log_std = self.log_std

    log_prob = None
    entropy = None
    mean_action_log_prob = None
    pre_tanh_value = None
    if deterministic:
        action = torch.tanh(mean)
    else:
        tanh_normal = TanhNormal(mean, std)
        if return_log_prob:
            if reparameterize is True:
                action, pre_tanh_value = tanh_normal.rsample(
                    return_pretanh_value=True)
            else:
                action, pre_tanh_value = tanh_normal.sample(
                    return_pretanh_value=True)
            log_prob = tanh_normal.log_prob(
                action,
                pre_tanh_value=pre_tanh_value,
            )
            log_prob = log_prob.sum(dim=1, keepdim=True)
        else:
            if reparameterize is True:
                action = tanh_normal.rsample()
            else:
                action = tanh_normal.sample()
    return (
        action,
        mean,
        log_std,
        log_prob,
        entropy,
        std,
        mean_action_log_prob,
        pre_tanh_value,
    )
def logprob(self, action, mean, std):
    tanh_normal = TanhNormal(mean, std)
    log_prob = tanh_normal.log_prob(action)
    log_prob = log_prob.sum(dim=1, keepdim=True)
    return log_prob
def forward(self, x):
    mean, std, h_aux, _ = super().forward(x)
    tanh_normal = TanhNormal(mean, std)
    return tanh_normal, h_aux
def forward(self, *input):
    mean, log_std = super().forward(*input)
    std = log_std.exp()
    return TanhNormal(mean, std)
def forward(
        self,
        obs,
        reparameterize=True,
        deterministic=False,
        return_log_prob=False,
        return_entropy=False,
        return_log_prob_of_mean=False,
):
    """
    :param obs: Observation
    :param deterministic: If True, do not sample
    :param return_log_prob: If True, return a sample and its log probability
    :param return_entropy: If True, also return the entropy. Will not need
        to be differentiated through, so this can be a number.
    :param return_log_prob_of_mean: If True, also return the log
        probability of the mean action. Will not need to be differentiated
        through, so this can be a number.
    """
    h = self.obs_processor(obs)
    h = self.mean_and_log_std_net(h)
    mean, log_std = torch.split(h, self.action_dim, dim=1)
    log_std = torch.clamp(log_std, LOG_SIG_MIN, LOG_SIG_MAX)
    std = torch.exp(log_std)

    log_prob = None
    entropy = None
    mean_action_log_prob = None
    pre_tanh_value = None
    if deterministic:
        action = torch.tanh(mean)
    else:
        tanh_normal = TanhNormal(mean, std)
        if return_log_prob:
            if reparameterize is True:
                action, pre_tanh_value = tanh_normal.rsample(
                    return_pretanh_value=True)
            else:
                action, pre_tanh_value = tanh_normal.sample(
                    return_pretanh_value=True)
            log_prob = tanh_normal.log_prob(
                action,
                pre_tanh_value=pre_tanh_value,
            )
            log_prob = log_prob.sum(dim=1, keepdim=True)
        else:
            if reparameterize is True:
                action = tanh_normal.rsample()
            else:
                action = tanh_normal.sample()
    if return_entropy:
        entropy = log_std + 0.5 + np.log(2 * np.pi) / 2
        # I'm not sure how to compute the (differential) entropy for a
        # tanh(Gaussian)
        entropy = entropy.sum(dim=1, keepdim=True)
        raise NotImplementedError()
    if return_log_prob_of_mean:
        tanh_normal = TanhNormal(mean, std)
        mean_action_log_prob = tanh_normal.log_prob(
            torch.tanh(mean),
            pre_tanh_value=mean,
        )
        mean_action_log_prob = mean_action_log_prob.sum(dim=1,
                                                        keepdim=True)
    return (
        action,
        mean,
        log_std,
        log_prob,
        entropy,
        std,
        mean_action_log_prob,
        pre_tanh_value,
    )
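
# Absent a closed form, the usual workaround for the NotImplementedError
# above is a Monte Carlo estimate from the log-probs the policy already
# returns: H[pi] ~= -mean(log pi(a|s)) over sampled actions. A sketch,
# assuming a (batch, 1) tensor of log-probs as produced by the forwards:
def estimate_entropy(log_prob):
    # log_prob: log pi(a|s) for actions sampled from the policy itself.
    return -log_prob.mean()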
def forward(
        self,
        obs,
        reparameterize=True,
        deterministic=False,
        return_log_prob=False,
):
    """
    :param obs: Observation
    :param deterministic: If True, do not sample
    :param return_log_prob: If True, return a sample and its log probability
    """
    h = obs
    for i, fc in enumerate(self.fcs):
        h = self.hidden_activation(fc(h))
    mean = self.last_fc(h)  # actions * heads
    if self.std is None:
        log_std = self.last_fc_log_std(h)
        log_std = torch.clamp(log_std, LOG_SIG_MIN, LOG_SIG_MAX)
        std = torch.exp(log_std)
    else:
        std = self.std
        log_std = self.log_std

    log_prob = None
    entropy = None
    mean_action_log_prob = None
    pre_tanh_value = None
    log_stds = None
    log_probs = None
    if deterministic:
        means = mean.view(-1, self.action_dim)
        actions = torch.tanh(means)
        actions = actions.view(-1, self.heads, self.action_dim)
    else:
        means = mean.view(-1, self.action_dim)
        stds = std.view(-1, self.action_dim)
        log_stds = log_std.view(-1, self.action_dim)
        tanh_normal = TanhNormal(means, stds)
        if return_log_prob:
            if reparameterize is True:
                action, pre_tanh_value = tanh_normal.rsample(
                    return_pretanh_value=True)
            else:
                action, pre_tanh_value = tanh_normal.sample(
                    return_pretanh_value=True)
            log_prob = tanh_normal.log_prob(
                action,
                pre_tanh_value=pre_tanh_value,
            )
            log_prob = log_prob.sum(dim=1, keepdim=True)
            log_probs = log_prob.view(-1, self.heads, self.action_dim)
        else:
            if reparameterize is True:
                action = tanh_normal.rsample()
            else:
                action = tanh_normal.sample()
        actions = action.view(-1, self.heads, self.action_dim)
    return (
        actions,
        means,
        log_stds,
        log_probs,
        entropy,
        std,
        mean_action_log_prob,
        pre_tanh_value,
    )
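
# The multi-head policy above folds K heads into the feature dimension and
# reshapes around a single batched TanhNormal. A shape-only sketch of that
# round trip (N, K, A are hypothetical sizes):
import torch

N, K, A = 5, 3, 2
flat = torch.randn(N, K * A)        # last_fc output: actions * heads
per_head = flat.view(-1, A)         # (N*K, A): one row per head
actions = per_head.view(-1, K, A)   # back to (N, heads, action_dim)
assert actions.shape == (N, K, A)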
def train_from_torch(self, batch):
    obs = batch['observations']
    old_log_pi = batch['log_prob']
    advantage = batch['advantage']
    returns = batch['returns']
    actions = batch['actions']

    """
    Policy Loss
    """
    _, policy_mean, policy_log_std, _, _, policy_std, _, _ = self.policy(
        obs)
    new_log_pi = TanhNormal(policy_mean, policy_std).log_prob(
        actions).sum(1, keepdim=True)

    # Clipped surrogate objective: clip the probability ratio, not the
    # advantage.
    ratio = torch.exp(new_log_pi - old_log_pi)
    left = ratio * advantage
    right = torch.clamp(ratio, 1 - self.epsilon,
                        1 + self.epsilon) * advantage
    policy_loss = (-1 * torch.min(left, right)).mean()

    """
    VF Loss
    """
    v_pred = self.vf(obs)
    v_target = returns
    vf_loss = self.vf_criterion(v_pred, v_target)

    """
    Update networks
    """
    loss = policy_loss + vf_loss
    self.optimizer.zero_grad()
    loss.backward()
    self.optimizer.step()

    if self.last_approx_kl is None or not self._need_to_update_eval_statistics:
        self.last_approx_kl = (old_log_pi - new_log_pi).detach()
    approx_ent = -new_log_pi

    """
    Save some statistics for eval
    """
    if self._need_to_update_eval_statistics:
        policy_grads = torch.cat(
            [p.grad.flatten() for p in self.policy.parameters()])
        value_grads = torch.cat(
            [p.grad.flatten() for p in self.vf.parameters()])
        self._need_to_update_eval_statistics = False
        """
        Eval should set this to None.
        This way, these statistics are only computed for one batch.
        """
        self.eval_statistics['VF Loss'] = np.mean(ptu.get_numpy(vf_loss))
        self.eval_statistics['Policy Loss'] = np.mean(
            ptu.get_numpy(policy_loss))
        self.eval_statistics.update(create_stats_ordered_dict(
            'V Predictions',
            ptu.get_numpy(v_pred),
        ))
        self.eval_statistics.update(create_stats_ordered_dict(
            'V Target',
            ptu.get_numpy(v_target),
        ))
        self.eval_statistics.update(create_stats_ordered_dict(
            'Policy Gradients',
            ptu.get_numpy(policy_grads),
        ))
        self.eval_statistics.update(create_stats_ordered_dict(
            'Value Gradients',
            ptu.get_numpy(value_grads),
        ))
        self.eval_statistics.update(create_stats_ordered_dict(
            'Policy KL',
            ptu.get_numpy(self.last_approx_kl),
        ))
        self.eval_statistics.update(create_stats_ordered_dict(
            'Policy Entropy',
            ptu.get_numpy(approx_ent),
        ))
        self.eval_statistics.update(create_stats_ordered_dict(
            'New Log Pis',
            ptu.get_numpy(new_log_pi),
        ))
        self.eval_statistics.update(create_stats_ordered_dict(
            'Old Log Pis',
            ptu.get_numpy(old_log_pi),
        ))
        self.eval_statistics.update(create_stats_ordered_dict(
            'Policy mu',
            ptu.get_numpy(policy_mean),
        ))
        self.eval_statistics.update(create_stats_ordered_dict(
            'Policy log std',
            ptu.get_numpy(policy_log_std),
        ))
    self._n_train_steps_total += 1
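
# The policy loss above is PPO's clipped surrogate:
#     L = -E[min(r * A, clip(r, 1 - eps, 1 + eps) * A)],  r = pi_new / pi_old.
# A tiny numeric sketch showing the clip capping the incentive once the
# ratio leaves [1 - eps, 1 + eps] (eps = 0.2 is just an example value):
import torch

eps = 0.2
advantage = torch.tensor(1.0)
for ratio in (0.5, 1.0, 1.5):
    r = torch.tensor(ratio)
    surrogate = torch.min(r * advantage,
                          torch.clamp(r, 1 - eps, 1 + eps) * advantage)
    print(ratio, float(surrogate))  # 1.5 is capped at 1.2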