def test_log_prob_gradient(self):
    """
    Check gradients of TanhNormal.log_prob w.r.t. the Gaussian parameters.

    The tanh correction term drops out: tanh has no parameters, so only the
    Gaussian log-density of the pre-tanh value z depends on mu and sigma:

        d/d mu    log f(z) = (z - mu) / sigma^2
        d/d sigma log f(z) = (z - mu)^2 / sigma^3 - 1 / sigma

    With z = 1, mu = 0 and sigma = 1/4 these evaluate to 16 and 4^3 - 4 = 60.
    :return:
    """
    mean_var = ptu.from_numpy(np.array([0]), requires_grad=True)
    std_var = ptu.from_numpy(np.array([0.25]), requires_grad=True)
    tanh_normal = TanhNormal(mean_var, std_var)
    z = ptu.from_numpy(np.array([1]))
    x = torch.tanh(z)
    log_prob = tanh_normal.log_prob(x, pre_tanh_value=z)

    gradient = ptu.from_numpy(np.array([1]))
    log_prob.backward(gradient)

    # (z - mu) / sigma^2 = 1 / (1/4)^2
    self.assertNpArraysEqual(
        ptu.get_numpy(mean_var.grad),
        np.array([16]),
    )
    # (z - mu)^2 / sigma^3 - 1 / sigma = 4^3 - 4
    self.assertNpArraysEqual(
        ptu.get_numpy(std_var.grad),
        np.array([4**3 - 4]),
    )
def forward(
        self,
        obs,
        reparameterize=True,
        deterministic=False,
        return_log_prob=False,
):
    """
    Run the tanh-Gaussian policy head on a batch of observations.

    :param obs: Observation
    :param reparameterize: If exactly ``True``, draw with ``rsample``; any
        other value draws with ``sample``.
    :param deterministic: If True, do not sample; the action is ``tanh(mean)``.
    :param return_log_prob: If True, also return the sample's log probability
        (summed over dim 1, keepdim) and its pre-tanh value.
    :return: 8-tuple ``(action, mean, log_std, log_prob, entropy, std,
        mean_action_log_prob, pre_tanh_value)``; ``entropy`` and
        ``mean_action_log_prob`` are always None here.
    """
    hidden = obs
    for layer in self.fcs:
        hidden = self.hidden_activation(layer(hidden))
    mean = self.last_fc(hidden)

    if self.std is not None:
        # Fixed (non-learned) standard deviation.
        std, log_std = self.std, self.log_std
    else:
        log_std = torch.clamp(
            self.last_fc_log_std(hidden), LOG_SIG_MIN, LOG_SIG_MAX)
        std = torch.exp(log_std)

    log_prob = None
    pre_tanh_value = None
    if deterministic:
        action = torch.tanh(mean)
    else:
        tanh_normal = TanhNormal(mean, std)
        # Choose the sampler once; only the literal True selects rsample.
        draw = tanh_normal.rsample if reparameterize is True else tanh_normal.sample
        if return_log_prob:
            action, pre_tanh_value = draw(return_pretanh_value=True)
            log_prob = tanh_normal.log_prob(
                action, pre_tanh_value=pre_tanh_value
            ).sum(dim=1, keepdim=True)
        else:
            action = draw()

    entropy = None
    mean_action_log_prob = None
    return (
        action,
        mean,
        log_std,
        log_prob,
        entropy,
        std,
        mean_action_log_prob,
        pre_tanh_value,
    )
def forward(
        self,
        obs,
        reparameterize=False,
        deterministic=False,
        return_log_prob=False,
):
    """
    Run the tanh-Gaussian head on a 3-D batch of observations.

    :param obs: Observation tensor; must be 3-D. Presumably laid out as
        ``(T, B, obs_dim)`` (time, batch, features) -- TODO confirm against
        ``inner_forward``.
    :param reparameterize: If truthy, sample with ``rsample``; otherwise with
        ``sample``.
    :param deterministic: If True, do not sample; the action is ``tanh(mean)``.
    :param return_log_prob: If True, return a sample and its log probability,
        summed over the last (action) dimension with ``keepdim=True``.
    :return: 8-tuple ``(action, mean, log_std, log_prob, expected_log_prob,
        std, mean_action_log_prob, pre_tanh_value)``; ``expected_log_prob``
        and ``mean_action_log_prob`` are always None.
    :raises ValueError: if ``obs`` is not 3-D.
    """
    # Input must be 3-D; the individual sizes are not needed below.
    if obs.dim() != 3:
        raise ValueError(
            "expected a 3-D obs tensor, got shape {}".format(tuple(obs.size())))

    h = self.inner_forward(obs)
    mean = self.last_fc(h)
    if self.std is None:
        log_std = self.last_fc_log_std(h)
        log_std = torch.clamp(log_std, LOG_SIG_MIN, LOG_SIG_MAX)
        std = torch.exp(log_std)
    else:
        std = self.std
        log_std = self.log_std

    log_prob = None
    expected_log_prob = None
    mean_action_log_prob = None
    pre_tanh_value = None
    if deterministic:
        action = torch.tanh(mean)
    else:
        tanh_normal = TanhNormal(mean, std)
        if return_log_prob:
            if reparameterize:
                action, pre_tanh_value = tanh_normal.rsample(
                    return_pretanh_value=True)
            else:
                action, pre_tanh_value = tanh_normal.sample(
                    return_pretanh_value=True)
            log_prob = tanh_normal.log_prob(action,
                                            pre_tanh_value=pre_tanh_value)
            # Sum per-dimension terms over the action axis (last; presumably
            # last_fc's output features). With 3-D (T, B, action_dim) outputs,
            # dim=1 would have summed over the batch axis instead; for 2-D
            # outputs dim=-1 and dim=1 are the same axis.
            log_prob = log_prob.sum(dim=-1, keepdim=True)
        else:
            if reparameterize:
                action = tanh_normal.rsample()
            else:
                action = tanh_normal.sample()

    return (
        action,
        mean,
        log_std,
        log_prob,
        expected_log_prob,
        std,
        mean_action_log_prob,
        pre_tanh_value,
    )
def test_log_prob_value(self):
    """
    TanhNormal(0, 1).log_prob at x = tanh(1), with no pre-tanh value given,
    should equal the standard-normal log-density at 1 minus log(1 - x^2).
    """
    distribution = TanhNormal(0, 1)
    squashed_np = np.tanh(np.array([1]))
    actual = ptu.get_numpy(
        distribution.log_prob(ptu.from_numpy(squashed_np)))

    # log N(1; 0, 1) = log(1 / sqrt(2 pi)) - 1^2 / 2
    gaussian_term = np.log(np.array([1 / np.sqrt(2 * np.pi)])) - 0.5
    # Change-of-variables correction for the tanh squashing.
    tanh_correction = np.log(1 - squashed_np**2)
    expected = gaussian_term - tanh_correction

    self.assertNpArraysEqual(expected, actual)
def log_prob(self, obs, actions):
    """
    Log-likelihood of already-squashed ``actions`` under the policy at ``obs``.

    The pre-tanh values are recovered with ``atanh(actions)`` and passed to
    ``TanhNormal.log_prob``. The network mean is clamped to
    [MEAN_MIN, MEAN_MAX] and the learned log-std (if any) to
    [LOG_SIG_MIN, LOG_SIG_MAX]. Per-dimension terms are summed over the last
    axis (no keepdim).
    """
    pre_tanh_actions = atanh(actions)

    features = obs
    for layer in self.fcs:
        features = self.hidden_activation(layer(features))

    mean = torch.clamp(self.last_fc(features), MEAN_MIN, MEAN_MAX)
    if self.std is not None:
        std, log_std = self.std, self.log_std
    else:
        log_std = torch.clamp(
            self.last_fc_log_std(features), LOG_SIG_MIN, LOG_SIG_MAX)
        std = torch.exp(log_std)

    distribution = TanhNormal(mean, std)
    per_dim = distribution.log_prob(value=actions,
                                    pre_tanh_value=pre_tanh_actions)
    return per_dim.sum(-1)
def forward(
        self,
        obs,
        reparameterize=reparameterize,
        deterministic=False,
):
    """
    Produce a tanh-squashed Gaussian action and its log probability.

    :param obs: Observation
    :param reparameterize: If truthy, sample with ``rsample``; otherwise with
        ``sample``. The default is whatever the enclosing-scope name
        ``reparameterize`` held when this method was defined.
    :param deterministic: If True, do not sample: the action is
        ``tanh(mean)``, its pre-tanh value is ``mean``, and ``log_prob`` is the
        log probability of that mean action.
    :return: ``(action, log_prob, pre_tanh_value, mean, log_std)``. The first
        three are moved with ``.cuda()``; ``log_prob`` is summed over dim 1
        with keepdim.
    """
    h = obs
    for fc in self.fcs:
        h = self.hidden_activation(fc(h))
    mean = self.last_fc(h)
    if self.std is None:
        log_std = self.last_fc_log_std(h)
        log_std = torch.clamp(log_std, LOGMIN, LOGMAX)
        std = torch.exp(log_std)
    else:
        std = self.std
        log_std = self.log_std

    tanh_normal = TanhNormal(mean, std)
    if deterministic:
        # `deterministic` used to be ignored, so callers always got a random
        # sample. Use the mean action and score it the same way as
        # "log prob of mean" elsewhere: log_prob(tanh(mean), pre_tanh=mean).
        pre_tanh_value = mean
        action = torch.tanh(mean)
    elif reparameterize:
        action, pre_tanh_value = tanh_normal.rsample(
            return_pretanh_value=True)
    else:
        action, pre_tanh_value = tanh_normal.sample(
            return_pretanh_value=True)
    log_prob = tanh_normal.log_prob(action, pre_tanh_value=pre_tanh_value)
    log_prob = log_prob.sum(dim=1, keepdim=True)
    # NOTE(review): hard-coded .cuda() requires a CUDA device; mean/log_std
    # are returned as-is.
    return action.cuda(), log_prob.cuda(), pre_tanh_value.cuda(
    ), mean, log_std
def query_action_logpi(self, obs, action):
    """
    Log probability of a given (tanh-squashed) ``action`` under the policy at
    ``obs``, summed over dim 1 with keepdim.

    No pre-tanh value is supplied, so ``TanhNormal.log_prob`` presumably
    inverts tanh on ``action`` internally -- confirm.
    """
    features = obs
    for idx, layer in enumerate(self.fcs):
        features = self.hidden_activation(layer(features))
        # NOTE(review): layer norm (when enabled) follows EVERY hidden layer,
        # after the activation -- a bound of `idx < len(self.fcs)` is always
        # true here. Other policy heads normalise only the first
        # len(self.fcs) - 1 layers, before the activation; confirm this matches
        # the forward pass that produced `action`.
        if self.layer_norm:
            features = self.layer_norms[idx](features)
    mean = self.last_fc(features)

    if self.std is not None:
        std = self.std
    else:
        std = torch.exp(torch.clamp(
            self.last_fc_log_std(features), LOG_SIG_MIN, LOG_SIG_MAX))

    per_dim = TanhNormal(mean, std).log_prob(action, pre_tanh_value=None)
    return per_dim.sum(dim=1, keepdim=True)
def forward(
        self,
        obs,
        reparameterize=True,
        deterministic=False,
        return_log_prob=False,
):
    """
    Multi-head tanh-Gaussian policy: one action per head for each observation.

    :param obs: Observation
    :param reparameterize: If exactly ``True``, sample with ``rsample``;
        otherwise with ``sample``.
    :param deterministic: If True, do not sample; actions are ``tanh(mean)``.
    :param return_log_prob: If True, also compute the log probability of each
        head's sampled action.
    :return: 8-tuple ``(actions, means, log_stds, log_probs, entropy, std,
        mean_action_log_prob, pre_tanh_value)``.
        - ``actions``: ``(-1, self.heads, self.action_dim)``.
        - ``means``: flattened to ``(-1, self.action_dim)``.
        - ``log_stds``: flattened to ``(-1, self.action_dim)``; None when
          deterministic.
        - ``log_probs``: ``(-1, self.heads, 1)`` when requested, else None.
        - ``std``: returned un-reshaped.
        - ``entropy`` and ``mean_action_log_prob``: always None.
    """
    h = obs
    for fc in self.fcs:
        h = self.hidden_activation(fc(h))
    mean = self.last_fc(h)  # actions * heads
    if self.std is None:
        log_std = self.last_fc_log_std(h)
        log_std = torch.clamp(log_std, LOG_SIG_MIN, LOG_SIG_MAX)
        std = torch.exp(log_std)
    else:
        # NOTE(review): the .view calls below assume self.std / self.log_std
        # are tensors -- confirm.
        std = self.std
        log_std = self.log_std

    entropy = None
    mean_action_log_prob = None
    pre_tanh_value = None
    log_stds = None
    log_probs = None

    # One row of size action_dim per (observation, head) pair.
    means = mean.view(-1, self.action_dim)
    if deterministic:
        actions = torch.tanh(means).view(-1, self.heads, self.action_dim)
    else:
        stds = std.view(-1, self.action_dim)
        log_stds = log_std.view(-1, self.action_dim)
        tanh_normal = TanhNormal(means, stds)
        if return_log_prob:
            if reparameterize is True:
                action, pre_tanh_value = tanh_normal.rsample(
                    return_pretanh_value=True)
            else:
                action, pre_tanh_value = tanh_normal.sample(
                    return_pretanh_value=True)
            log_prob = tanh_normal.log_prob(action,
                                            pre_tanh_value=pre_tanh_value)
            log_prob = log_prob.sum(dim=1, keepdim=True)
            # log_prob holds one scalar per (observation, head), shape
            # (N * heads, 1). Viewing it as (-1, heads, action_dim) either
            # failed or silently mixed rows when action_dim > 1.
            log_probs = log_prob.view(-1, self.heads, 1)
        else:
            if reparameterize is True:
                action = tanh_normal.rsample()
            else:
                action = tanh_normal.sample()
        actions = action.view(-1, self.heads, self.action_dim)

    return (
        actions,
        means,
        log_stds,
        log_probs,
        entropy,
        std,
        mean_action_log_prob,
        pre_tanh_value,
    )
def logprob(self, action, mean, std):
    """
    Log probability of a tanh-squashed ``action`` under TanhNormal(mean, std).

    No pre-tanh value is passed, so ``TanhNormal.log_prob`` works from
    ``action`` alone. Per-dimension terms are summed over dim 1 with keepdim.
    """
    per_dim_log_prob = TanhNormal(mean, std).log_prob(action)
    return per_dim_log_prob.sum(dim=1, keepdim=True)
def forward(
        self,
        obs,
        reparameterize=True,
        deterministic=False,
        return_log_prob=False,
):
    """
    Tanh-Gaussian policy head on top of the parent network's features.

    :param obs: Observation. If its first dimension is 1 it is treated as a
        single observation (e.g. one image) and flattened to ``(1, -1)``;
        otherwise it is assumed to be an already-flat batch (e.g. from the
        replay buffer) and passed through unchanged.
    :param reparameterize: If exactly ``True``, sample with ``rsample``;
        otherwise with ``sample``.
    :param deterministic: If True, do not sample; the action is ``tanh(mean)``.
    :param return_log_prob: If True, also return the sample's log probability
        (summed over dim 1, keepdim) and its pre-tanh value.
    :return: 8-tuple ``(action, mean, log_std, log_prob, entropy, std,
        mean_action_log_prob, pre_tanh_value)``; ``entropy`` and
        ``mean_action_log_prob`` are always None.
    """
    # TODO: the size-1 test is a heuristic. A flat batch of one (1, D) is also
    # reshaped, which is harmless since the result is still (1, D).
    if obs.shape[0] == 1:
        flat = torch.flatten(obs).view(1, -1)
    else:
        flat = obs
    features = super().forward(flat, None, complete=False)
    mean = self.last_fc(features)

    if self.std is not None:
        std, log_std = self.std, self.log_std
    else:
        log_std = torch.clamp(
            self.last_fc_log_std(features), LOG_SIG_MIN, LOG_SIG_MAX)
        std = torch.exp(log_std)

    log_prob = None
    pre_tanh_value = None
    if deterministic:
        action = torch.tanh(mean)
    else:
        tanh_normal = TanhNormal(mean, std)
        draw = tanh_normal.rsample if reparameterize is True else tanh_normal.sample
        if return_log_prob:
            action, pre_tanh_value = draw(return_pretanh_value=True)
            log_prob = tanh_normal.log_prob(
                action, pre_tanh_value=pre_tanh_value
            ).sum(dim=1, keepdim=True)
        else:
            action = draw()

    entropy = None
    mean_action_log_prob = None
    return (
        action,
        mean,
        log_std,
        log_prob,
        entropy,
        std,
        mean_action_log_prob,
        pre_tanh_value,
    )
def forward(
        self,
        meta_size,
        batch_size,
        obs,
        reparameterize=False,
        deterministic=False,
        return_log_prob=False,
):
    """
    Tanh-Gaussian policy head with optional layer norm and dropout.

    :param meta_size: Accepted for interface compatibility; unused here.
    :param batch_size: Accepted for interface compatibility; unused here.
    :param obs: Observation
    :param reparameterize: If truthy, sample with ``rsample``; otherwise with
        ``sample``.
    :param deterministic: If True, do not sample; the action is ``tanh(mean)``.
    :param return_log_prob: If True, return a sample and its log probability
        (summed over dim 1, keepdim) plus its pre-tanh value.
    :return: 8-tuple ``(action, mean, log_std, log_prob, expected_log_prob,
        std, mean_action_log_prob, pre_tanh_value)``; ``expected_log_prob``
        and ``mean_action_log_prob`` are always None.
    """
    h = obs
    last_hidden_index = len(self.fcs) - 1
    for index, layer in enumerate(self.fcs):
        h = layer(h)
        # Layer norm (pre-activation) and dropout (post-activation) are
        # skipped on the final hidden layer.
        below_last = index < last_hidden_index
        if self.layer_norm and below_last:
            h = self.layer_norms[index](h)
        h = self.hidden_activation(h)
        if self.use_dropout and below_last:
            h = self.dropouts[index](h)
    mean = self.last_fc(h)

    if self.std is not None:
        std, log_std = self.std, self.log_std
    else:
        log_std = torch.clamp(
            self.last_fc_log_std(h), LOG_SIG_MIN, LOG_SIG_MAX)
        std = torch.exp(log_std)

    log_prob = None
    pre_tanh_value = None
    if deterministic:
        action = torch.tanh(mean)
    else:
        tanh_normal = TanhNormal(mean, std)
        draw = tanh_normal.rsample if reparameterize else tanh_normal.sample
        if return_log_prob:
            action, pre_tanh_value = draw(return_pretanh_value=True)
            log_prob = tanh_normal.log_prob(
                action, pre_tanh_value=pre_tanh_value
            ).sum(dim=1, keepdim=True)
        else:
            action = draw()

    expected_log_prob = None
    mean_action_log_prob = None
    return (
        action,
        mean,
        log_std,
        log_prob,
        expected_log_prob,
        std,
        mean_action_log_prob,
        pre_tanh_value,
    )
def forward(
        self,
        obs,
        reparameterize=True,
        deterministic=False,
        return_log_prob=False,
        return_entropy=False,
        return_log_prob_of_mean=False,
):
    """
    Tanh-Gaussian policy whose network emits mean and log-std together.

    :param obs: Observation
    :param reparameterize: If exactly ``True``, sample with ``rsample``;
        otherwise with ``sample``.
    :param deterministic: If True, do not sample; the action is ``tanh(mean)``.
    :param return_log_prob: If True, return a sample and its log probability
        (summed over dim 1, keepdim) together with its pre-tanh value.
    :param return_entropy: Not supported; the differential entropy of a
        tanh-squashed Gaussian is not implemented, so this raises.
    :param return_log_prob_of_mean: If True, also return the log probability
        of the mean action ``tanh(mean)`` (summed over dim 1, keepdim).
    :return: 8-tuple ``(action, mean, log_std, log_prob, entropy, std,
        mean_action_log_prob, pre_tanh_value)``; ``entropy`` is always None.
    :raises NotImplementedError: if ``return_entropy`` is True.
    """
    h = self.obs_processor(obs)
    h = self.mean_and_log_std_net(h)
    # First self.action_dim columns are the mean, the remainder the log-std.
    mean, log_std = torch.split(h, self.action_dim, dim=1)
    log_std = torch.clamp(log_std, LOG_SIG_MIN, LOG_SIG_MAX)
    std = torch.exp(log_std)

    log_prob = None
    entropy = None
    mean_action_log_prob = None
    pre_tanh_value = None
    if deterministic:
        action = torch.tanh(mean)
    else:
        tanh_normal = TanhNormal(mean, std)
        if return_log_prob:
            if reparameterize is True:
                action, pre_tanh_value = tanh_normal.rsample(
                    return_pretanh_value=True)
            else:
                action, pre_tanh_value = tanh_normal.sample(
                    return_pretanh_value=True)
            log_prob = tanh_normal.log_prob(action,
                                            pre_tanh_value=pre_tanh_value)
            log_prob = log_prob.sum(dim=1, keepdim=True)
        else:
            if reparameterize is True:
                action = tanh_normal.rsample()
            else:
                action = tanh_normal.sample()

    if return_entropy:
        # The Gaussian entropy is not the entropy of tanh(Gaussian): tanh is
        # not volume-preserving. No closed form is implemented, so fail loudly
        # instead of returning a misleading value.
        raise NotImplementedError(
            "entropy of a tanh-squashed Gaussian is not implemented")
    if return_log_prob_of_mean:
        tanh_normal = TanhNormal(mean, std)
        mean_action_log_prob = tanh_normal.log_prob(
            torch.tanh(mean),
            pre_tanh_value=mean,
        )
        mean_action_log_prob = mean_action_log_prob.sum(dim=1, keepdim=True)
    return (
        action,
        mean,
        log_std,
        log_prob,
        entropy,
        std,
        mean_action_log_prob,
        pre_tanh_value,
    )
def test_log_prob_type(self):
    """TanhNormal(0, 1).log_prob on a tensor input yields an autograd Variable."""
    distribution = TanhNormal(0, 1)
    zero_input = ptu.from_numpy(np.array([0]))
    result = distribution.log_prob(zero_input)
    self.assertIsInstance(result, torch.autograd.Variable)
def forward(
        self,
        obs,
        deterministic=False,
        return_log_prob=False,
        return_entropy=False,
        return_log_prob_of_mean=False,
):
    """
    Tau-conditioned tanh-Gaussian policy.

    ``split_tau`` separates ``obs`` into the observation and tau. The
    observation and a tau vector are each passed through their own input
    layer, then concatenated and fed through the shared hidden layers.

    :param obs: Observation with tau packed in (as understood by ``split_tau``)
    :param deterministic: If True, do not sample; the action is ``tanh(mean)``.
    :param return_log_prob: If True, return a sample and its log probability
        (summed over dim 1, keepdim) together with its pre-tanh value.
    :param return_entropy: If True, also return the summed entropy of the
        *pre-tanh* Gaussian. This is not the entropy of the squashed action;
        see the comment in the body.
    :param return_log_prob_of_mean: If True, also return the log probability
        of the mean action ``tanh(mean)`` (summed over dim 1, keepdim).
    :return: 8-tuple ``(action, mean, log_std, log_prob, entropy, std,
        mean_action_log_prob, pre_tanh_value)``.
    """
    obs, taus = split_tau(obs)
    batch_size = obs.size()[0]
    # tau is broadcast (via +) into a (batch_size, tau_vector_len) tensor;
    # presumably taus is (batch_size, 1) -- TODO confirm.
    tau_vector = Variable(
        torch.zeros((batch_size, self.tau_vector_len)) + taus.data)
    h1 = self.hidden_activation(self.first_input(obs))
    h2 = self.hidden_activation(self.second_input(tau_vector))
    h = torch.cat((h1, h2), dim=1)
    for fc in self.fcs:
        h = self.hidden_activation(fc(h))
    mean = self.last_fc(h)
    if self.std is None:
        log_std = self.last_fc_log_std(h)
        log_std = torch.clamp(log_std, LOG_SIG_MIN, LOG_SIG_MAX)
        std = torch.exp(log_std)
    else:
        std = self.std
        log_std = self.log_std

    log_prob = None
    entropy = None
    mean_action_log_prob = None
    pre_tanh_value = None
    if deterministic:
        action = torch.tanh(mean)
    else:
        tanh_normal = TanhNormal(mean, std)
        if return_log_prob:
            action, pre_tanh_value = tanh_normal.sample(
                return_pretanh_value=True)
            log_prob = tanh_normal.log_prob(action,
                                            pre_tanh_value=pre_tanh_value)
            log_prob = log_prob.sum(dim=1, keepdim=True)
        else:
            action = tanh_normal.sample()

    if return_entropy:
        # Entropy of the pre-tanh Gaussian: log(sigma) + 1/2 + log(2*pi)/2 per
        # dimension, summed. tanh is invertible but NOT volume-preserving, so
        # this is not the entropy of the squashed action. By change of
        # variables H[tanh(Z)] = H[Z] + E[log(1 - tanh(Z)^2)] <= H[Z], so the
        # value returned is an upper bound.
        entropy = log_std + 0.5 + np.log(2 * np.pi) / 2
        entropy = entropy.sum(dim=1, keepdim=True)
    if return_log_prob_of_mean:
        tanh_normal = TanhNormal(mean, std)
        mean_action_log_prob = tanh_normal.log_prob(
            torch.tanh(mean),
            pre_tanh_value=mean,
        )
        mean_action_log_prob = mean_action_log_prob.sum(dim=1, keepdim=True)
    return (
        action,
        mean,
        log_std,
        log_prob,
        entropy,
        std,
        mean_action_log_prob,
        pre_tanh_value,
    )