Example #1
    def rollout_step(self, time_step: TimeStep, state):
        """Train one step.

        Args:
            time_step (TimeStep): ``time_step.observation`` should be the latent
                vector.
            state (nested Tensor): state of the model
        """
        latent_vector = time_step.observation
        action_distribution, state = self._get_action(latent_vector, state)
        value, _ = self._value_net(latent_vector)
        action = dist_utils.sample_action_distribution(action_distribution)

        info = ActorCriticInfo(
            action_distribution=action_distribution, value=value)
        return AlgStep(output=action, state=state, info=info)
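
For context, the sketch below shows one way the containers used above could be defined and how ``rollout_step`` might be driven in a collection loop. The namedtuple fields and the ``collect_rollout`` helper are assumptions for illustration, not ALF's actual definitions.

from collections import namedtuple

# Hypothetical minimal stand-ins for the containers used above; the real
# classes carry more fields.
TimeStep = namedtuple('TimeStep', ['observation'])
AlgStep = namedtuple('AlgStep', ['output', 'state', 'info'])
ActorCriticInfo = namedtuple(
    'ActorCriticInfo', ['action_distribution', 'value'])


def collect_rollout(algorithm, env, initial_state, num_steps):
    """Drive ``rollout_step`` for ``num_steps`` steps, threading the state."""
    state = initial_state
    time_step = env.reset()
    infos = []
    for _ in range(num_steps):
        alg_step = algorithm.rollout_step(time_step, state)
        time_step = env.step(alg_step.output)
        state = alg_step.state
        infos.append(alg_step.info)
    return infos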
Example #2
    def rollout_step(self, time_step: TimeStep, state: ActorCriticState):
        """Rollout for one step."""
        value, value_state = self._value_network(time_step.observation,
                                                 state=state.value)

        # We detach time_step.observation here so that, if the observation is
        # produced by some other trainable module, the training of that module
        # will not be affected by the gradient back-propagated from the actor.
        # The gradient from the critic, however, will still affect the training
        # of that module.
        action_distribution, actor_state = self._actor_network(
            common.detach(time_step.observation), state=state.actor)

        action = dist_utils.sample_action_distribution(action_distribution)
        return AlgStep(output=action,
                       state=ActorCriticState(actor=actor_state,
                                              value=value_state),
                       info=ActorCriticInfo(
                           value=value,
                           action_distribution=action_distribution))
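
``common.detach`` above is applied to a possibly nested observation. A minimal sketch of such a nested detach follows; the helper name ``detach_nested`` and its handling of containers are assumptions, not ALF's implementation.

import torch


def detach_nested(nested):
    """Recursively detach every tensor in a possibly nested structure.

    Detaching the observation before the actor network means the actor loss
    does not back-propagate into whatever module produced the observation,
    while the critic branch, which sees the un-detached observation, still
    does.
    """
    if isinstance(nested, torch.Tensor):
        return nested.detach()
    if isinstance(nested, dict):
        return {k: detach_nested(v) for k, v in nested.items()}
    if isinstance(nested, (list, tuple)):
        values = [detach_nested(v) for v in nested]
        # namedtuples need positional expansion
        if hasattr(nested, '_fields'):
            return type(nested)(*values)
        return type(nested)(values)
    return nested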
Example #3
    def _predict_action(self,
                        observation,
                        state: SacActionState,
                        epsilon_greedy=None,
                        eps_greedy_sampling=False):
        """The reason why we want to do action sampling inside this function
        instead of outside is that for the mixed case, once a continuous action
        is sampled here, we should pair it with the discrete action sampled from
        the Q value. If we just return two distributions and sample outside, then
        the actions will not match.
        """
        new_state = SacActionState()
        if self._act_type != ActionType.Discrete:
            continuous_action_dist, actor_network_state = self._actor_network(
                observation, state=state.actor_network)
            new_state = new_state._replace(actor_network=actor_network_state)
            if eps_greedy_sampling:
                continuous_action = dist_utils.epsilon_greedy_sample(
                    continuous_action_dist, epsilon_greedy)
            else:
                continuous_action = dist_utils.rsample_action_distribution(
                    continuous_action_dist)

        critic_network_inputs = observation
        if self._act_type == ActionType.Mixed:
            critic_network_inputs = (observation, continuous_action)

        q_values = None
        if self._act_type != ActionType.Continuous:
            q_values, critic_state = self._critic_networks(
                critic_network_inputs, state=state.critic)
            new_state = new_state._replace(critic=critic_state)
            if self._act_type == ActionType.Discrete:
                alpha = torch.exp(self._log_alpha).detach()
            else:
                alpha = torch.exp(self._log_alpha[0]).detach()
            # p(a|s) = exp(Q(s,a)/alpha) / Z;
            q_values = q_values.min(dim=1)[0]
            logits = q_values / alpha
            discrete_action_dist = td.Categorical(logits=logits)
            if eps_greedy_sampling:
                discrete_action = dist_utils.epsilon_greedy_sample(
                    discrete_action_dist, epsilon_greedy)
            else:
                discrete_action = dist_utils.sample_action_distribution(
                    discrete_action_dist)

        if self._act_type == ActionType.Mixed:
            # Note that in this case ``action_dist`` is not the valid joint
            # action distribution, because ``discrete_action_dist`` is
            # conditioned on the particular continuous action sampled above.
            # So DO NOT use this ``action_dist`` to sample an action pair with
            # an arbitrary continuous action anywhere else!
            # It is still valid for computing the log probability of *this*
            # sampled ``action``, and it can also be used for summary purposes
            # because the expectation over the continuous action is taken when
            # summarizing.
            action_dist = type(self._action_spec)(
                (discrete_action_dist, continuous_action_dist))
            action = type(self._action_spec)(
                (discrete_action, continuous_action))
        elif self._act_type == ActionType.Discrete:
            action_dist = discrete_action_dist
            action = discrete_action
        else:
            action_dist = continuous_action_dist
            action = continuous_action

        return action_dist, action, q_values, new_state
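
The discrete branch above is a Boltzmann policy, p(a|s) proportional to exp(Q(s,a)/alpha), with Q taken as the per-action minimum over the critic ensemble. The standalone sketch below reproduces just that step; the [batch, num_critics, num_actions] shape of ``q_values`` is an assumption.

import torch
import torch.distributions as td


def boltzmann_policy_from_q(q_values, log_alpha):
    """Turn ensemble Q-values into the softmax policy used for discrete actions.

    Args:
        q_values: tensor of shape [batch, num_critics, num_actions].
        log_alpha: scalar tensor, log of the entropy temperature.
    Returns:
        ``td.Categorical`` with probabilities proportional to
        exp(min_i Q_i(s, a) / alpha).
    """
    alpha = torch.exp(log_alpha).detach()
    pessimistic_q = q_values.min(dim=1)[0]   # [batch, num_actions]
    return td.Categorical(logits=pessimistic_q / alpha)


# Example: batch of 2, 3 critics, 4 actions.
dist = boltzmann_policy_from_q(torch.randn(2, 3, 4), torch.tensor(0.0))
action = dist.sample()   # shape: [2]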