Example #1
    def _get_action(self,
                    actor_network,
                    time_step: TimeStep,
                    state: SarsaState,
                    epsilon_greedy=1.0):
        action_distribution, actor_state = actor_network(time_step.observation,
                                                         state=state.actor)
        if actor_network.is_distribution_output:
            # Stochastic actor: sample from the action distribution,
            # using epsilon-greedy sampling when epsilon_greedy < 1.
            if epsilon_greedy == 1.0:
                action = dist_utils.rsample_action_distribution(
                    action_distribution)
            else:
                action = dist_utils.epsilon_greedy_sample(
                    action_distribution, epsilon_greedy)
            noise_state = ()
        else:
            # Deterministic actor: add exploration noise to the network output,
            # applied per batch element with probability epsilon_greedy when
            # epsilon_greedy < 1.

            def _sample(a, noise):
                if epsilon_greedy >= 1.0:
                    return a + noise
                else:
                    choose_random_action = (torch.rand(a.shape[:1]) <
                                            epsilon_greedy)
                    return torch.where(
                        common.expand_dims_as(choose_random_action, a),
                        a + noise, a)

            noise, noise_state = self._noise_process(state.noise)
            action = nest_map(_sample, action_distribution, noise)
        return action_distribution, action, actor_state, noise_state
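
In the deterministic branch above, exploration noise is applied to only an ``epsilon_greedy`` fraction of the batch. Below is a minimal standalone sketch of that mixing rule using plain PyTorch; the helper name ``epsilon_mix`` is hypothetical and not part of the source.

import torch


def epsilon_mix(a, noise, epsilon):
    # With probability ``epsilon`` per batch element, use the noisy action
    # ``a + noise``; otherwise keep the deterministic output ``a``.
    if epsilon >= 1.0:
        return a + noise
    choose_random_action = torch.rand(a.shape[:1]) < epsilon
    # Reshape the [batch] mask so it broadcasts against [batch, action_dim].
    mask = choose_random_action.view(-1, *([1] * (a.dim() - 1)))
    return torch.where(mask, a + noise, a)


actions = torch.zeros(4, 2)
noise = 0.1 * torch.randn(4, 2)
print(epsilon_mix(actions, noise, epsilon=0.5))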
Example #2
    def test_action_sampling_transformed_normal(self):
        def _get_transformed_normal(means, stds):
            normal_dist = td.Independent(td.Normal(loc=means, scale=stds), 1)
            transforms = [
                dist_utils.StableTanh(),
                dist_utils.AffineTransform(loc=torch.tensor(0.),
                                           scale=torch.tensor(5.0))
            ]
            squashed_dist = td.TransformedDistribution(
                base_distribution=normal_dist, transforms=transforms)
            return squashed_dist, transforms

        means = torch.Tensor([0.3, 0.7])
        dist, transforms = _get_transformed_normal(
            means=means, stds=torch.Tensor([1.0, 1.0]))

        mode = dist_utils.get_mode(dist)

        transformed_mode = means
        for transform in transforms:
            transformed_mode = transform(transformed_mode)

        self.assertTrue((transformed_mode == mode).all())

        epsilon = 0.0
        action_obtained = dist_utils.epsilon_greedy_sample(dist, epsilon)
        self.assertTrue((transformed_mode == action_obtained).all())
Example #3
    def test_action_sampling_normal(self):
        m = torch.distributions.normal.Normal(
            torch.Tensor([0.3, 0.7]), torch.Tensor([1.0, 1.0]))
        M = m.expand([10, 2])
        epsilon = 0.0
        action_expected = torch.Tensor([0.3, 0.7]).repeat(10, 1)
        action_obtained = dist_utils.epsilon_greedy_sample(M, epsilon)
        self.assertTrue((action_expected == action_obtained).all())
Example #4
    def test_action_sampling_categorical(self):
        m = torch.distributions.categorical.Categorical(
            torch.Tensor([0.25, 0.75]))
        M = m.expand([10])
        epsilon = 0.0
        action_expected = torch.Tensor([1]).repeat(10)
        action_obtained = dist_utils.epsilon_greedy_sample(M, epsilon)
        self.assertTrue((action_expected == action_obtained).all())
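
The tests above rely on the convention that ``epsilon = 0`` makes ``epsilon_greedy_sample`` return the mode of the distribution. Here is a rough sketch of that behavior for a ``Categorical`` written with plain ``torch.distributions``; the helper ``epsilon_greedy_categorical`` is illustrative and not the ``dist_utils`` implementation.

import torch
import torch.distributions as td


def epsilon_greedy_categorical(dist, epsilon):
    # Take the mode with probability 1 - epsilon, otherwise draw a sample.
    greedy = torch.argmax(dist.probs, dim=-1)
    sampled = dist.sample()
    use_sample = torch.rand(greedy.shape) < epsilon
    return torch.where(use_sample, sampled, greedy)


m = td.Categorical(torch.tensor([0.25, 0.75])).expand([10])
print(epsilon_greedy_categorical(m, epsilon=0.0))  # all ones, as in the test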
Example #5
    def predict_step(self, time_step: TimeStep, state: ActorCriticState,
                     epsilon_greedy):
        """Predict for one step."""
        action_dist, actor_state = self._actor_network(time_step.observation,
                                                       state=state.actor)

        action = dist_utils.epsilon_greedy_sample(action_dist, epsilon_greedy)
        return AlgStep(output=action,
                       state=ActorCriticState(actor=actor_state),
                       info=ActorCriticInfo(action_distribution=action_dist))
Example #6
    def test_action_sampling_transformed_categorical(self):
        def _get_transformed_categorical(probs):
            categorical_dist = td.Independent(td.Categorical(probs=probs), 1)
            return categorical_dist

        probs = torch.Tensor([[0.3, 0.5, 0.2], [0.6, 0.4, 0.0]])
        dist = _get_transformed_categorical(probs=probs)
        mode = dist_utils.get_mode(dist)
        expected_mode = torch.argmax(probs, dim=1)

        self.assertTensorEqual(expected_mode, mode)

        epsilon = 0.0
        action_obtained = dist_utils.epsilon_greedy_sample(dist, epsilon)
        self.assertTensorEqual(expected_mode, action_obtained)
Example #7
    def _predict_action(self,
                        observation,
                        state: SacActionState,
                        epsilon_greedy=None,
                        eps_greedy_sampling=False):
        """The reason why we want to do action sampling inside this function
        instead of outside is that for the mixed case, once a continuous action
        is sampled here, we should pair it with the discrete action sampled from
        the Q value. If we just return two distributions and sample outside, then
        the actions will not match.
        """
        new_state = SacActionState()
        if self._act_type != ActionType.Discrete:
            continuous_action_dist, actor_network_state = self._actor_network(
                observation, state=state.actor_network)
            new_state = new_state._replace(actor_network=actor_network_state)
            if eps_greedy_sampling:
                continuous_action = dist_utils.epsilon_greedy_sample(
                    continuous_action_dist, epsilon_greedy)
            else:
                continuous_action = dist_utils.rsample_action_distribution(
                    continuous_action_dist)

        critic_network_inputs = observation
        if self._act_type == ActionType.Mixed:
            critic_network_inputs = (observation, continuous_action)

        q_values = None
        if self._act_type != ActionType.Continuous:
            q_values, critic_state = self._critic_networks(
                critic_network_inputs, state=state.critic)
            new_state = new_state._replace(critic=critic_state)
            if self._act_type == ActionType.Discrete:
                alpha = torch.exp(self._log_alpha).detach()
            else:
                alpha = torch.exp(self._log_alpha[0]).detach()
            # p(a|s) = exp(Q(s,a)/alpha) / Z;
            # use the minimum over the critic replicas as the Q-value estimate
            q_values = q_values.min(dim=1)[0]
            logits = q_values / alpha
            discrete_action_dist = td.Categorical(logits=logits)
            if eps_greedy_sampling:
                discrete_action = dist_utils.epsilon_greedy_sample(
                    discrete_action_dist, epsilon_greedy)
            else:
                discrete_action = dist_utils.sample_action_distribution(
                    discrete_action_dist)

        if self._act_type == ActionType.Mixed:
            # Note that in this case ``action_dist`` is not the valid joint
            # action distribution because ``discrete_action_dist`` is conditioned
            # on a particular continuous action sampled above. So DO NOT use this
            # ``action_dist`` to directly sample an action pair with an arbitrary
            # continuous action anywhere else!
            # However, it is still valid for computing the log probability of
            # *this* sampled ``action``. It can also be used for summary
            # purposes because an expectation is taken over the continuous
            # action when summarizing.
            action_dist = type(self._action_spec)(
                (discrete_action_dist, continuous_action_dist))
            action = type(self._action_spec)(
                (discrete_action, continuous_action))
        elif self._act_type == ActionType.Discrete:
            action_dist = discrete_action_dist
            action = discrete_action
        else:
            action_dist = continuous_action_dist
            action = continuous_action

        return action_dist, action, q_values, new_state
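
The docstring and the comment above stress that, in the mixed case, the discrete distribution is conditioned on the particular continuous action sampled inside the function, so the two actions must be paired there. A toy sketch of that pairing follows, assuming stand-in actor and critic callables; all names below are illustrative and not from the source.

import torch
import torch.distributions as td


def sample_mixed(observation, actor, critic, alpha):
    cont_dist = actor(observation)                # distribution over continuous actions
    cont_action = cont_dist.rsample()             # sample the continuous action once ...
    q_values = critic(observation, cont_action)   # ... and condition the Q-values on it
    disc_dist = td.Categorical(logits=q_values / alpha)
    disc_action = disc_dist.sample()
    # The pair is consistent because both parts come from the same sample.
    return disc_action, cont_action


# Toy usage with stand-in networks.
actor = lambda obs: td.Normal(torch.zeros(obs.shape[0], 2), torch.ones(obs.shape[0], 2))
critic = lambda obs, a: torch.randn(obs.shape[0], 3)  # Q-values for 3 discrete actions
print(sample_mixed(torch.zeros(4, 5), actor, critic, alpha=1.0))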
Example #8
    def predict_step(self, time_step: TimeStep, state, epsilon_greedy):
        action_distribution, state = self._get_action(
            time_step.observation, state)
        action = dist_utils.epsilon_greedy_sample(
            action_distribution, epsilon_greedy)
        return AlgStep(output=action, state=state, info=())