def forward(self, obs, deterministic=False, return_log_prob=False, return_tanh_normal=False):
    """
    Compute an action for a goal-conditioned observation batch.

    The last two columns of ``obs`` are treated as the goal and the rest as
    the raw observation (assumes goal dim == 2 -- TODO confirm with caller).
    Goal and observation are embedded separately and fused by the modules in
    ``self.mod_list`` before the mean / log-std heads.

    :param obs: Batch tensor with the goal appended as the final two columns.
    :param deterministic: If True, return ``tanh(mean)`` instead of sampling.
    :param return_log_prob: If True, also compute the log-probability of the
        sampled action (only meaningful when not deterministic).
    :param return_tanh_normal: If True, append the distribution object to the
        returned tuple (``None`` when deterministic, since no distribution
        is constructed on that path).
    :return: Tuple ``(action, mean, log_std, log_prob, expected_log_prob,
        std, mean_action_log_prob, pre_tanh_value[, tanh_normal])``.
    """
    goal = obs[:, -2:]
    obs = obs[:, :-2]

    extra_info = self.first_extra_layer(goal)
    hid = self.first_hid_layer(obs)
    for h_mod in self.mod_list:
        extra_info, hid = h_mod(extra_info, hid)
    hid_extra_concat = torch.cat([hid, extra_info], dim=-1)

    mean = self.last_fc_mean(hid_extra_concat)
    log_std = self.last_fc_log_std(hid_extra_concat)
    # Clamp log-std for numerical stability of exp / sampling.
    log_std = torch.clamp(log_std, LOG_SIG_MIN, LOG_SIG_MAX)
    std = torch.exp(log_std)

    log_prob = None
    expected_log_prob = None
    mean_action_log_prob = None
    pre_tanh_value = None
    # BUG FIX: previously tanh_normal was only bound on the stochastic path,
    # so deterministic=True together with return_tanh_normal=True raised
    # NameError. Initialize to None so that combination returns cleanly.
    tanh_normal = None

    if deterministic:
        action = torch.tanh(mean)
    else:
        # NOTE: this project's ReparamTanhMultivariateNormal takes log_std
        # (not std) as its second argument; consistent with the other
        # policies in this file.
        tanh_normal = ReparamTanhMultivariateNormal(mean, log_std)
        if return_log_prob:
            action, pre_tanh_value = tanh_normal.sample(
                return_pretanh_value=True
            )
            log_prob = tanh_normal.log_prob(
                action, pre_tanh_value=pre_tanh_value
            )
        else:
            action = tanh_normal.sample()

    # I'm doing it like this for now for backwards compatibility, sorry!
    if return_tanh_normal:
        return (
            action,
            mean,
            log_std,
            log_prob,
            expected_log_prob,
            std,
            mean_action_log_prob,
            pre_tanh_value,
            tanh_normal,
        )
    return (
        action,
        mean,
        log_std,
        log_prob,
        expected_log_prob,
        std,
        mean_action_log_prob,
        pre_tanh_value,
    )
def get_log_prob(self, obs, acts, return_normal_params=False):
    """
    Log-probability of ``acts`` under the goal-conditioned tanh policy.

    Mirrors the forward pass: the last two columns of ``obs`` are the goal
    (assumes goal dim == 2 -- TODO confirm), embedded separately and fused
    with the observation embedding through ``self.mod_list``.

    :param obs: Batch tensor with the goal appended as the final two columns.
    :param acts: Actions (post-tanh) whose log-probability is evaluated.
    :param return_normal_params: If True, also return the distribution's
        mean and (clamped) log-std.
    :return: ``log_prob`` or ``(log_prob, mean, log_std)``.
    """
    goal = obs[:, -2:]
    obs = obs[:, :-2]

    extra_info = self.first_extra_layer(goal)
    hid = self.first_hid_layer(obs)
    for h_mod in self.mod_list:
        extra_info, hid = h_mod(extra_info, hid)
    hid_extra_concat = torch.cat([hid, extra_info], dim=-1)

    mean = self.last_fc_mean(hid_extra_concat)
    log_std = self.last_fc_log_std(hid_extra_concat)
    # Clamp for numerical stability; dead "std = exp(log_std)" local removed.
    log_std = torch.clamp(log_std, LOG_SIG_MIN, LOG_SIG_MAX)

    tanh_normal = ReparamTanhMultivariateNormal(mean, log_std)
    log_prob = tanh_normal.log_prob(acts)

    if return_normal_params:
        return log_prob, mean, log_std
    return log_prob
def get_log_prob(self, obs, acts, return_normal_params=False):
    """
    Log-probability of ``acts`` under the feed-forward tanh policy.

    Runs ``obs`` through the hidden stack ``self.fcs``, then evaluates the
    log-probability of ``acts`` under a tanh-squashed Gaussian whose mean
    comes from ``self.last_fc`` and whose log-std is either predicted by
    ``self.last_fc_log_std`` (when ``self.std is None``) or fixed at
    ``self.log_std``.

    :param obs: Batch of observations.
    :param acts: Actions (post-tanh) whose log-probability is evaluated.
    :param return_normal_params: If True, also return the distribution's
        mean and log-std.
    :return: ``log_prob`` or ``(log_prob, mean, log_std)``.
    """
    h = obs
    for fc in self.fcs:
        h = self.hidden_activation(fc(h))
    mean = self.last_fc(h)

    if self.std is None:
        # State-dependent std: predict log-std and clamp for stability.
        log_std = self.last_fc_log_std(h)
        log_std = torch.clamp(log_std, LOG_SIG_MIN, LOG_SIG_MAX)
    else:
        # Fixed std configured at construction time.
        log_std = self.log_std

    # Dead locals removed: std was computed in both branches but never used;
    # commented-out debug prints deleted.
    tanh_normal = ReparamTanhMultivariateNormal(mean, log_std)
    log_prob = tanh_normal.log_prob(acts)

    if return_normal_params:
        return log_prob, mean, log_std
    return log_prob
def forward(
    self, obs, deterministic=False, return_log_prob=False, return_tanh_normal=False
):
    """
    Compute an action for a batch of observations.

    :param obs: Observation batch, fed through the hidden stack ``self.fcs``.
    :param deterministic: If True, return ``tanh(mean)`` instead of sampling.
    :param return_log_prob: If True, also compute the log-probability of the
        sampled action (only meaningful when not deterministic).
    :param return_tanh_normal: If True, insert the distribution object into
        the returned tuple (``None`` when deterministic, since no
        distribution is constructed on that path).
    :return: Tuple ``(action, mean, log_std, log_prob, expected_log_prob,
        std, mean_action_log_prob, pre_tanh_value[, tanh_normal],
        middle_output)`` where ``middle_output`` is the last hidden
        activation (``None`` if ``self.fcs`` is empty).
    """
    h = obs
    middle_output = None
    for fc in self.fcs:
        h = self.hidden_activation(fc(h))
        middle_output = h
    mean = self.last_fc(h)

    if self.std is None:
        # State-dependent std: predict log-std and clamp for stability.
        log_std = self.last_fc_log_std(h)
        log_std = torch.clamp(log_std, LOG_SIG_MIN, LOG_SIG_MAX)
        std = torch.exp(log_std)
    else:
        # Fixed std configured at construction time.
        std = self.std
        log_std = self.log_std

    log_prob = None
    expected_log_prob = None
    mean_action_log_prob = None
    pre_tanh_value = None
    # BUG FIX: previously tanh_normal was only bound on the stochastic path,
    # so deterministic=True together with return_tanh_normal=True raised
    # NameError. Initialize to None so that combination returns cleanly.
    tanh_normal = None

    if deterministic:
        action = torch.tanh(mean)
    else:
        # NOTE: this project's ReparamTanhMultivariateNormal takes log_std
        # (not std) as its second argument; consistent with get_log_prob.
        tanh_normal = ReparamTanhMultivariateNormal(mean, log_std)
        if return_log_prob:
            action, pre_tanh_value = tanh_normal.sample(
                return_pretanh_value=True
            )
            log_prob = tanh_normal.log_prob(
                action, pre_tanh_value=pre_tanh_value
            )
        else:
            action = tanh_normal.sample()

    # I'm doing it like this for now for backwards compatibility, sorry!
    if return_tanh_normal:
        return (
            action,
            mean,
            log_std,
            log_prob,
            expected_log_prob,
            std,
            mean_action_log_prob,
            pre_tanh_value,
            tanh_normal,
            middle_output,
        )
    return (
        action,
        mean,
        log_std,
        log_prob,
        expected_log_prob,
        std,
        mean_action_log_prob,
        pre_tanh_value,
        middle_output,
    )