Example #1
    def forward(self,
                obs,
                deterministic=False,
                return_log_prob=False,
                return_tanh_normal=False):
        # the last two observation dimensions encode the 2-D goal
        goal = obs[:, -2:]
        obs = obs[:, :-2]

        extra_info = self.first_extra_layer(goal)
        hid = self.first_hid_layer(obs)

        for h_mod in self.mod_list:
            extra_info, hid = h_mod(extra_info, hid)

        hid_extra_concat = torch.cat([hid, extra_info], dim=-1)
        mean = self.last_fc_mean(hid_extra_concat)
        log_std = self.last_fc_log_std(hid_extra_concat)
        log_std = torch.clamp(log_std, LOG_SIG_MIN, LOG_SIG_MAX)
        std = torch.exp(log_std)

        log_prob = None
        expected_log_prob = None
        mean_action_log_prob = None
        pre_tanh_value = None
        tanh_normal = None  # stays None on the deterministic path below
        if deterministic:
            action = torch.tanh(mean)
        else:
            tanh_normal = ReparamTanhMultivariateNormal(mean, log_std)
            if return_log_prob:
                action, pre_tanh_value = tanh_normal.sample(
                    return_pretanh_value=True)
                log_prob = tanh_normal.log_prob(action,
                                                pre_tanh_value=pre_tanh_value)
            else:
                action = tanh_normal.sample()

        # I'm doing it like this for now for backwards compatibility, sorry!
        if return_tanh_normal:
            return (
                action,
                mean,
                log_std,
                log_prob,
                expected_log_prob,
                std,
                mean_action_log_prob,
                pre_tanh_value,
                tanh_normal,
            )
        return (
            action,
            mean,
            log_std,
            log_prob,
            expected_log_prob,
            std,
            mean_action_log_prob,
            pre_tanh_value,
        )
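ReparamTanhMultivariateNormal itself is not defined in any of these snippets. Judging only from how it is used above (constructed from mean and log_std, sampled with return_pretanh_value=True, and scored with log_prob(action, pre_tanh_value=...)), it behaves like a reparameterized, tanh-squashed diagonal Gaussian. The following is a minimal sketch under that assumption, not the original implementation:

import torch
from torch.distributions import Normal


class ReparamTanhMultivariateNormal:
    """Sketch: diagonal Gaussian with reparameterized sampling and tanh squashing."""

    def __init__(self, mean, log_std, epsilon=1e-6):
        self.normal = Normal(mean, torch.exp(log_std))
        self.epsilon = epsilon  # numerical floor inside logs and atanh

    def sample(self, return_pretanh_value=False):
        pre_tanh = self.normal.rsample()  # rsample keeps gradients w.r.t. mean/log_std
        action = torch.tanh(pre_tanh)
        if return_pretanh_value:
            return action, pre_tanh
        return action

    def log_prob(self, value, pre_tanh_value=None):
        if pre_tanh_value is None:
            # recover the pre-squash sample via atanh on a clamped value
            clipped = torch.clamp(value, -1 + self.epsilon, 1 - self.epsilon)
            pre_tanh_value = 0.5 * (torch.log1p(clipped) - torch.log1p(-clipped))
        # change-of-variables correction for the tanh squashing
        log_prob = (self.normal.log_prob(pre_tanh_value)
                    - torch.log(1 - value ** 2 + self.epsilon))
        return log_prob.sum(dim=-1, keepdim=True)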
Example #2
    def get_log_prob(self, obs, acts, return_normal_params=False):
        goal = obs[:, -2:]
        obs = obs[:, :-2]

        extra_info = self.first_extra_layer(goal)
        hid = self.first_hid_layer(obs)

        for h_mod in self.mod_list:
            extra_info, hid = h_mod(extra_info, hid)

        hid_extra_concat = torch.cat([hid, extra_info], dim=-1)
        mean = self.last_fc_mean(hid_extra_concat)
        log_std = self.last_fc_log_std(hid_extra_concat)
        log_std = torch.clamp(log_std, LOG_SIG_MIN, LOG_SIG_MAX)
        std = torch.exp(log_std)

        tanh_normal = ReparamTanhMultivariateNormal(mean, log_std)
        log_prob = tanh_normal.log_prob(acts)

        if return_normal_params:
            return log_prob, mean, log_std
        return log_prob
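Both goal-conditioned methods (Example #1 and Example #2) assume the observation ends with a 2-D goal that is split off and routed through a separate input layer. A purely illustrative sketch of that layout (the 8-D state size and batch size are made up for the example):

import torch

obs = torch.randn(32, 10)   # batch of 32: 8 state dims followed by 2 goal dims
goal = obs[:, -2:]          # (32, 2) goal slice, fed to first_extra_layer
state = obs[:, :-2]         # (32, 8) state slice, fed to first_hid_layer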
Example #3
    def get_log_prob(self, obs, acts, return_normal_params=False):
        h = obs
        for fc in self.fcs:
            h = self.hidden_activation(fc(h))
        mean = self.last_fc(h)
        if self.std is None:
            log_std = self.last_fc_log_std(h)
            log_std = torch.clamp(log_std, LOG_SIG_MIN, LOG_SIG_MAX)
            std = torch.exp(log_std)
        else:
            std = self.std
            log_std = self.log_std
        
        tanh_normal = ReparamTanhMultivariateNormal(mean, log_std)
        log_prob = tanh_normal.log_prob(acts)

        if return_normal_params:
            return log_prob, mean, log_std
        return log_prob
Example #4
    def forward(
            self,
            obs,
            deterministic=False,
            return_log_prob=False,
            return_tanh_normal=False
    ):
        """
        :param obs: Observation
        :param deterministic: If True, do not sample
        :param return_log_prob: If True, return a sample and its log probability
        """

        h = obs
        for fc in self.fcs:
            h = self.hidden_activation(fc(h))
        middle_output = h  # expose the final hidden activations to the caller
        
        mean = self.last_fc(h)

        if self.std is None:
            log_std = self.last_fc_log_std(h)
            log_std = torch.clamp(log_std, LOG_SIG_MIN, LOG_SIG_MAX)
            std = torch.exp(log_std)
        else:
            std = self.std
            log_std = self.log_std

        log_prob = None
        expected_log_prob = None
        mean_action_log_prob = None
        pre_tanh_value = None
        tanh_normal = None  # stays None on the deterministic path below
        if deterministic:
            action = torch.tanh(mean)
        else:
            tanh_normal = ReparamTanhMultivariateNormal(mean, log_std)
            if return_log_prob:
                action, pre_tanh_value = tanh_normal.sample(
                    return_pretanh_value=True
                )
                log_prob = tanh_normal.log_prob(
                    action,
                    pre_tanh_value=pre_tanh_value
                )
            else:
                action = tanh_normal.sample()

        # I'm doing it like this for now for backwards compatibility, sorry!
        if return_tanh_normal:
            return (
                action, mean, log_std, log_prob, expected_log_prob, std,
                mean_action_log_prob, pre_tanh_value, tanh_normal, middle_output
            )
        return (
            action, mean, log_std, log_prob, expected_log_prob, std,
            mean_action_log_prob, pre_tanh_value, middle_output
        )
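As a quick, hedged usage check of the ReparamTanhMultivariateNormal sketch given after Example #1, the snippet below mirrors the two branches of forward() above: the deterministic branch squashes the mean, while the stochastic branch draws a reparameterized sample and scores it, either with the stored pre-tanh value (as forward() does) or from the squashed action alone (as get_log_prob() does). All tensor shapes here are made up for illustration.

import torch

mean = torch.zeros(4, 2)
log_std = torch.full((4, 2), -0.5)
dist = ReparamTanhMultivariateNormal(mean, log_std)          # the sketch class from above

deterministic_action = torch.tanh(mean)                      # deterministic branch
action, pre_tanh = dist.sample(return_pretanh_value=True)    # stochastic branch
log_prob = dist.log_prob(action, pre_tanh_value=pre_tanh)

# recomputing the log-probability from the squashed action alone should agree closely
print(torch.allclose(log_prob, dist.log_prob(action), atol=1e-3))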