def test_log_prob_gradient(self):
    """
    Check autograd gradients of TanhNormal.log_prob w.r.t. mu and sigma.

    The tanh correction term -log(1 - tanh(z)^2) depends only on z, not on
    the distribution parameters, so it drops out of both derivatives and
    only the Normal part log N(z; mu, sigma) contributes:

        d/d mu    log f(z) = (z - mu) / sigma^2
        d/d sigma log f(z) = (z - mu)^2 / sigma^3 - 1 / sigma

    With mu = 0, sigma = 0.25 and z = 1 these evaluate to
        1 / 0.0625          = 16
        1 / 0.015625 - 4    = 4**3 - 4 = 60
    """
    mean_var = ptu.from_numpy(np.array([0]), requires_grad=True)
    std_var = ptu.from_numpy(np.array([0.25]), requires_grad=True)
    tanh_normal = TanhNormal(mean_var, std_var)
    z = ptu.from_numpy(np.array([1]))
    x = torch.tanh(z)
    # Pass the pre-tanh value explicitly so no atanh inversion is involved.
    log_prob = tanh_normal.log_prob(x, pre_tanh_value=z)

    # Upstream gradient of ones: the resulting .grad is the plain derivative.
    gradient = ptu.from_numpy(np.array([1]))
    log_prob.backward(gradient)

    self.assertNpArraysEqual(
        ptu.get_numpy(mean_var.grad),
        np.array([16]),
    )
    self.assertNpArraysEqual(
        ptu.get_numpy(std_var.grad),
        np.array([4**3 - 4]),
    )
def log_prob(self, observation, state, actions=None):
    """
    Log-likelihood of ``actions`` under the tanh-Gaussian policy.

    Runs the conv layers over ``observation``, flattens, optionally
    concatenates ``state`` along dim 1, then runs the fully connected stack.
    The output of the last fc layer (after ``output_activation``) is the
    pre-tanh mean. The *input* to that last layer feeds
    ``last_fc_log_std``.

    :param observation: Input to the conv layers.
    :param state: Extra features concatenated along dim 1 after the conv
        trunk, or None to skip the concatenation.
    :param actions: Squashed (post-tanh) actions to score. Required despite
        the ``None`` default, which is kept for interface compatibility.
    :return: Per-element log-probabilities summed over the last dimension
        (presumably the action dimension).
    :raises ValueError: If ``actions`` is None or ``self.fcs`` is empty.
    """
    if actions is None:
        raise ValueError("log_prob requires `actions` to score.")
    if len(self.fcs) == 0:
        # Without fc layers there is no feature vector for the log-std head.
        raise ValueError("log_prob requires at least one fc layer.")

    raw_actions = self.atanh(actions)

    tensor = observation
    for conv in self.convs:
        tensor = self.conv_activation(conv(tensor))
    tensor = batch_flatten(tensor)
    if state is not None:
        tensor = torch.cat((tensor, state), 1)

    last_index = len(self.fcs) - 1
    for index, layer in enumerate(self.fcs):
        # Remember the input of each layer; after the loop this is the
        # input to the final fc layer, which the log-std head reads.
        h = tensor
        tensor = layer(tensor)
        if index == last_index:
            tensor = self.output_activation(tensor)
        else:
            tensor = self.fcs_activation(tensor)

    mean = torch.clamp(tensor, LOG_MEAN_MIN, LOG_MEAN_MAX)
    log_std = torch.clamp(self.last_fc_log_std(h), LOG_SIG_MIN, LOG_SIG_MAX)
    std = torch.exp(log_std)

    tanh_normal = TanhNormal(mean, std)
    log_prob = tanh_normal.log_prob(value=actions, pre_tanh_value=raw_actions)
    return log_prob.sum(-1)
def test_log_prob_value(self):
    """
    Compare TanhNormal(0, 1).log_prob(tanh(1)) with its closed form.

    With z = 1 and x = tanh(z):
        log p(x) = log N(z; 0, 1) - log(1 - x^2)
    where log N(1; 0, 1) = log(1 / sqrt(2 pi)) - 0.5.
    The pre-tanh value is NOT passed, so log_prob must recover it from x.
    """
    distribution = TanhNormal(0, 1)
    pre_tanh = np.array([1])
    squashed_np = np.tanh(pre_tanh)
    squashed = ptu.from_numpy(squashed_np)

    actual = ptu.get_numpy(distribution.log_prob(squashed))

    normal_term = np.log(np.array([1 / np.sqrt(2 * np.pi)])) - 0.5
    jacobian_term = np.log(1 - squashed_np**2)
    expected = normal_term - jacobian_term

    self.assertNpArraysEqual(expected, actual)
def log_prob_aviral(self, obs, actions):
    """
    Log-likelihood of squashed ``actions`` given ``obs``.

    The pre-tanh actions are recovered with a clamped atanh. The network
    output is split along dim 1 into ``mean`` and ``log_std`` chunks of size
    ``self.action_dim``. The summed TanhNormal log-probability is returned.
    """
    # Clamp both 1 + a and 1 - a away from zero so actions at exactly +/-1
    # map to a finite pre-tanh value instead of +/-inf.
    numerator = (1 + actions).clamp(min=1e-6)
    denominator = (1 - actions).clamp(min=1e-6)
    pre_tanh_actions = 0.5 * torch.log(numerator / denominator)

    features = self.mean_and_log_std_net(self.obs_processor(obs))
    mean, log_std = torch.split(features, self.action_dim, dim=1)
    std = torch.exp(torch.clamp(log_std, LOG_SIG_MIN, LOG_SIG_MAX))

    distribution = TanhNormal(mean, std)
    per_element = distribution.log_prob(
        value=actions, pre_tanh_value=pre_tanh_actions)
    return per_element.sum(-1)
def log_prob_aviral(self, obs, actions):
    """
    Log-likelihood of squashed ``actions`` given ``obs`` for the MLP policy.

    The pre-tanh actions are recovered with a clamped atanh. ``obs`` is run
    through the hidden fc layers and ``last_fc`` to get the mean. The std is
    either the fixed ``self.std`` or ``exp`` of the clamped
    ``last_fc_log_std`` head.
    """
    # Clamp away from zero so actions at exactly +/-1 stay finite after atanh.
    upper = (1 + actions).clamp(min=1e-6)
    lower = (1 - actions).clamp(min=1e-6)
    pre_tanh_actions = 0.5 * torch.log(upper / lower)

    hidden = obs
    for fc in self.fcs:
        hidden = self.hidden_activation(fc(hidden))
    mean = self.last_fc(hidden)

    if self.std is not None:
        std = self.std
        log_std = self.log_std  # not used below
    else:
        log_std = torch.clamp(
            self.last_fc_log_std(hidden), LOG_SIG_MIN, LOG_SIG_MAX)
        std = torch.exp(log_std)

    distribution = TanhNormal(mean, std)
    per_element = distribution.log_prob(
        value=actions, pre_tanh_value=pre_tanh_actions)
    return per_element.sum(-1)
def forward(
        self,
        observation,
        state,
        actions=None,
        reparameterize=True,
        deterministic=False,
        return_log_prob=False,
):
    """
    Run the conv + fc tanh-Gaussian policy and produce an action.

    :param observation: Input to the conv layers.
    :param state: Extra features concatenated along dim 1 after the conv
        trunk, or None to skip the concatenation.
    :param actions: Unused; accepted for interface compatibility.
    :param reparameterize: If truthy, draw with ``rsample`` (the
        reparameterized draw); otherwise use ``sample``.
    :param deterministic: If True, do not sample; the action is tanh(mean).
    :param return_log_prob: If True (and not deterministic), also return
        the log probability of the sampled action, summed over dim 1 with
        the dimension kept.
    :return: Tuple ``(action, mean, log_std, log_prob, entropy, std,
        mean_action_log_prob, pre_tanh_value)``. ``entropy`` and
        ``mean_action_log_prob`` are always None here. ``log_prob`` and
        ``pre_tanh_value`` are None unless a log prob was requested for a
        sampled action.
    :raises ValueError: If ``self.fcs`` is empty.
    """
    if len(self.fcs) == 0:
        # Without fc layers there is no feature vector for the log-std head.
        raise ValueError("forward requires at least one fc layer.")

    tensor = observation
    for conv in self.convs:
        tensor = self.conv_activation(conv(tensor))
    tensor = batch_flatten(tensor)
    if state is not None:
        tensor = torch.cat((tensor, state), 1)

    last_index = len(self.fcs) - 1
    for index, layer in enumerate(self.fcs):
        # After the loop, h holds the input to the final fc layer, which the
        # log-std head reads.
        h = tensor
        tensor = layer(tensor)
        if index == last_index:
            tensor = self.output_activation(tensor)
        else:
            tensor = self.fcs_activation(tensor)

    log_std = torch.clamp(self.last_fc_log_std(h), LOG_SIG_MIN, LOG_SIG_MAX)
    std = torch.exp(log_std)
    mean = torch.clamp(tensor, LOG_MEAN_MIN, LOG_MEAN_MAX)

    log_prob = None
    entropy = None
    mean_action_log_prob = None
    pre_tanh_value = None

    if deterministic:
        action = torch.tanh(mean)
    else:
        tanh_normal = TanhNormal(mean, std)
        # Truthiness (not `is True`) so numpy bools / 1 still reparameterize.
        draw = tanh_normal.rsample if reparameterize else tanh_normal.sample
        if return_log_prob:
            action, pre_tanh_value = draw(return_pretanh_value=True)
            log_prob = tanh_normal.log_prob(
                action, pre_tanh_value=pre_tanh_value)
            log_prob = log_prob.sum(dim=1, keepdim=True)
        else:
            action = draw()
    return (
        action, mean, log_std, log_prob, entropy, std,
        mean_action_log_prob, pre_tanh_value,
    )
def forward(
        self,
        obs,
        deterministic=False,
        return_log_prob=False,
        return_entropy=False,
        return_log_prob_of_mean=False,
):
    """
    Run the tau-conditioned tanh-Gaussian policy and produce an action.

    :param obs: Observation with tau values packed in; ``split_tau``
        separates them (assumed layout — see ``split_tau``).
    :param deterministic: If True, do not sample; the action is tanh(mean).
    :param return_log_prob: If True (and not deterministic), return a
        sample (drawn with ``sample``, not ``rsample``) and its log
        probability.
    :param return_entropy: If True, also return the differential entropy of
        the *pre-tanh* Gaussian, summed over dim 1. This is an upper bound
        on, not equal to, the entropy of the squashed action distribution.
        It does not need to be differentiated through.
    :param return_log_prob_of_mean: If True, also return the log probability
        of the deterministic action tanh(mean), summed over dim 1.
    :return: Tuple ``(action, mean, log_std, log_prob, entropy, std,
        mean_action_log_prob, pre_tanh_value)``. Entries that were not
        requested are None.
    """
    obs, taus = split_tau(obs)
    batch_size = obs.size()[0]
    # Broadcast taus across a (batch_size, tau_vector_len) block. type_as
    # puts the zeros on the same device/dtype as taus, so this also works
    # when the inputs live on the GPU (a bare torch.zeros is CPU-only).
    tau_vector = Variable(
        torch.zeros((batch_size, self.tau_vector_len)).type_as(taus.data)
        + taus.data)

    h1 = self.hidden_activation(self.first_input(obs))
    h2 = self.hidden_activation(self.second_input(tau_vector))
    h = torch.cat((h1, h2), dim=1)
    for fc in self.fcs:
        h = self.hidden_activation(fc(h))
    mean = self.last_fc(h)
    if self.std is None:
        log_std = self.last_fc_log_std(h)
        log_std = torch.clamp(log_std, LOG_SIG_MIN, LOG_SIG_MAX)
        std = torch.exp(log_std)
    else:
        std = self.std
        log_std = self.log_std

    log_prob = None
    entropy = None
    mean_action_log_prob = None
    pre_tanh_value = None
    if deterministic:
        action = torch.tanh(mean)
    else:
        tanh_normal = TanhNormal(mean, std)
        if return_log_prob:
            action, pre_tanh_value = tanh_normal.sample(
                return_pretanh_value=True)
            log_prob = tanh_normal.log_prob(
                action, pre_tanh_value=pre_tanh_value)
            log_prob = log_prob.sum(dim=1, keepdim=True)
        else:
            action = tanh_normal.sample()

    if return_entropy:
        # Per-dimension differential entropy of the pre-tanh Gaussian:
        #   log(sigma) + 1/2 + log(2 pi) / 2.
        # NOTE: this is NOT the entropy of tanh(Gaussian). Differential
        # entropy is not invariant under invertible maps:
        #   H(tanh(X)) = H(X) + E[log(1 - tanh(X)^2)] <= H(X),
        # so the value below is an upper bound on the squashed entropy.
        entropy = log_std + 0.5 + np.log(2 * np.pi) / 2
        entropy = entropy.sum(dim=1, keepdim=True)
    if return_log_prob_of_mean:
        # Rebuilt here because tanh_normal is not defined when deterministic.
        tanh_normal = TanhNormal(mean, std)
        mean_action_log_prob = tanh_normal.log_prob(
            torch.tanh(mean),
            pre_tanh_value=mean,
        )
        mean_action_log_prob = mean_action_log_prob.sum(dim=1, keepdim=True)
    return (
        action, mean, log_std, log_prob, entropy, std,
        mean_action_log_prob, pre_tanh_value,
    )
def forward(
        self,
        obs,
        reparameterize=True,
        deterministic=False,
        return_log_prob=False,
        return_entropy=False,
        return_log_prob_of_mean=False,
):
    """
    Run the MLP tanh-Gaussian policy and produce an action.

    :param obs: Observation.
    :param reparameterize: If truthy, draw with ``rsample`` (the
        reparameterized draw); otherwise use ``sample``.
    :param deterministic: If True, do not sample; the action is tanh(mean).
    :param return_log_prob: If True (and not deterministic), return a
        sample and its log probability, summed over dim 1.
    :param return_entropy: Not supported; raises NotImplementedError.
    :param return_log_prob_of_mean: If True, also return the log probability
        of the deterministic action tanh(mean), summed over dim 1.
    :return: Tuple ``(action, mean, log_std, log_prob, entropy, std,
        mean_action_log_prob, pre_tanh_value)``. Entries that were not
        requested are None; ``entropy`` is always None.
    :raises NotImplementedError: If ``return_entropy`` is True.
    """
    if return_entropy:
        # Fail fast instead of running the network and sampling first.
        # H(tanh(X)) = H(X) + E[log(1 - tanh(X)^2)], and the expectation has
        # no simple closed form (it would need e.g. a Monte Carlo estimate).
        raise NotImplementedError(
            "Entropy of a tanh-squashed Gaussian is not implemented.")

    h = obs
    for fc in self.fcs:
        h = self.hidden_activation(fc(h))
    mean = self.last_fc(h)
    if self.std is None:
        log_std = self.last_fc_log_std(h)
        log_std = torch.clamp(log_std, LOG_SIG_MIN, LOG_SIG_MAX)
        std = torch.exp(log_std)
    else:
        std = self.std
        log_std = self.log_std

    log_prob = None
    entropy = None
    mean_action_log_prob = None
    pre_tanh_value = None
    if deterministic:
        action = torch.tanh(mean)
    else:
        tanh_normal = TanhNormal(mean, std)
        # Truthiness (not `is True`) so numpy bools / 1 still reparameterize.
        draw = tanh_normal.rsample if reparameterize else tanh_normal.sample
        if return_log_prob:
            action, pre_tanh_value = draw(return_pretanh_value=True)
            log_prob = tanh_normal.log_prob(
                action, pre_tanh_value=pre_tanh_value)
            log_prob = log_prob.sum(dim=1, keepdim=True)
        else:
            action = draw()

    if return_log_prob_of_mean:
        # Rebuilt here because tanh_normal is not defined when deterministic.
        tanh_normal = TanhNormal(mean, std)
        mean_action_log_prob = tanh_normal.log_prob(
            torch.tanh(mean),
            pre_tanh_value=mean,
        )
        mean_action_log_prob = mean_action_log_prob.sum(dim=1, keepdim=True)
    return (
        action, mean, log_std, log_prob, entropy, std,
        mean_action_log_prob, pre_tanh_value,
    )
def logprob(self, action, mean, std, pre_tanh_value=None):
    """
    Log-probability of ``action`` under TanhNormal(mean, std).

    :param action: Squashed (post-tanh) action.
    :param mean: Mean of the pre-tanh Gaussian.
    :param std: Standard deviation of the pre-tanh Gaussian.
    :param pre_tanh_value: Optional pre-tanh action. When given, it is
        forwarded to ``TanhNormal.log_prob``. Otherwise TanhNormal presumably
        recovers it from ``action`` via atanh, which diverges at +/-1.
    :return: Log-probability summed over dim 1 with the dimension kept.
    """
    tanh_normal = TanhNormal(mean, std)
    if pre_tanh_value is None:
        log_prob = tanh_normal.log_prob(action)
    else:
        log_prob = tanh_normal.log_prob(action, pre_tanh_value=pre_tanh_value)
    return log_prob.sum(dim=1, keepdim=True)
def test_log_prob_type(self):
    """TanhNormal.log_prob on a tensor input returns an autograd Variable."""
    distribution = TanhNormal(0, 1)
    value = ptu.from_numpy(np.array([0]))
    result = distribution.log_prob(value)
    self.assertIsInstance(result, torch.autograd.Variable)