def _get_action(self,
                actor_network,
                time_step: TimeStep,
                state: SarsaState,
                epsilon_greedy=1.0):
    action_distribution, actor_state = actor_network(
        time_step.observation, state=state.actor)
    if actor_network.is_distribution_output:
        if epsilon_greedy == 1.0:
            action = dist_utils.rsample_action_distribution(
                action_distribution)
        else:
            action = dist_utils.epsilon_greedy_sample(
                action_distribution, epsilon_greedy)
        noise_state = ()
    else:

        def _sample(a, noise):
            if epsilon_greedy >= 1.0:
                return a + noise
            else:
                choose_random_action = (torch.rand(a.shape[:1]) <
                                        epsilon_greedy)
                return torch.where(
                    common.expand_dims_as(choose_random_action, a),
                    a + noise, a)

        noise, noise_state = self._noise_process(state.noise)
        action = nest_map(_sample, action_distribution, noise)
    return action_distribution, action, actor_state, noise_state
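# A minimal standalone sketch (not ALF code) of the mixing done by ``_sample``
# above: with probability ``epsilon_greedy`` a batch element gets exploration
# noise added to its deterministic action, otherwise the action is kept as-is.
# The name ``mix_noise`` and the shapes are illustrative assumptions only.
import torch


def mix_noise(action: torch.Tensor, noise: torch.Tensor,
              epsilon_greedy: float) -> torch.Tensor:
    """Add ``noise`` to ``action`` for a random ``epsilon_greedy`` fraction of the batch."""
    if epsilon_greedy >= 1.0:
        return action + noise
    # One Bernoulli draw per batch element, broadcast over the action dims.
    choose_random = torch.rand(action.shape[:1]) < epsilon_greedy
    choose_random = choose_random.reshape(-1, *([1] * (action.dim() - 1)))
    return torch.where(choose_random, action + noise, action)


# Example: a batch of 4 two-dimensional actions, noisy for ~30% of the batch.
actions = torch.zeros(4, 2)
noise = 0.1 * torch.randn(4, 2)
noisy_actions = mix_noise(actions, noise, epsilon_greedy=0.3)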
def test_action_sampling_transformed_normal(self):
    def _get_transformed_normal(means, stds):
        normal_dist = td.Independent(td.Normal(loc=means, scale=stds), 1)
        transforms = [
            dist_utils.StableTanh(),
            dist_utils.AffineTransform(
                loc=torch.tensor(0.), scale=torch.tensor(5.0))
        ]
        squashed_dist = td.TransformedDistribution(
            base_distribution=normal_dist, transforms=transforms)
        return squashed_dist, transforms

    means = torch.Tensor([0.3, 0.7])
    dist, transforms = _get_transformed_normal(
        means=means, stds=torch.Tensor([1.0, 1.0]))

    mode = dist_utils.get_mode(dist)
    transformed_mode = means
    for transform in transforms:
        transformed_mode = transform(transformed_mode)
    self.assertTrue((transformed_mode == mode).all())

    epsilon = 0.0
    action_obtained = dist_utils.epsilon_greedy_sample(dist, epsilon)
    self.assertTrue((transformed_mode == action_obtained).all())
def test_action_sampling_normal(self):
    m = torch.distributions.normal.Normal(
        torch.Tensor([0.3, 0.7]), torch.Tensor([1.0, 1.0]))
    M = m.expand([10, 2])
    epsilon = 0.0
    action_expected = torch.Tensor([0.3, 0.7]).repeat(10, 1)
    action_obtained = dist_utils.epsilon_greedy_sample(M, epsilon)
    self.assertTrue((action_expected == action_obtained).all())
def test_action_sampling_categorical(self):
    m = torch.distributions.categorical.Categorical(
        torch.Tensor([0.25, 0.75]))
    M = m.expand([10])
    epsilon = 0.0
    action_expected = torch.Tensor([1]).repeat(10)
    action_obtained = dist_utils.epsilon_greedy_sample(M, epsilon)
    self.assertTrue((action_expected == action_obtained).all())
def predict_step(self, time_step: TimeStep, state: ActorCriticState,
                 epsilon_greedy):
    """Predict for one step."""
    action_dist, actor_state = self._actor_network(
        time_step.observation, state=state.actor)
    action = dist_utils.epsilon_greedy_sample(action_dist, epsilon_greedy)
    return AlgStep(
        output=action,
        state=ActorCriticState(actor=actor_state),
        info=ActorCriticInfo(action_distribution=action_dist))
def test_action_sampling_transformed_categorical(self):
    def _get_transformed_categorical(probs):
        categorical_dist = td.Independent(td.Categorical(probs=probs), 1)
        return categorical_dist

    probs = torch.Tensor([[0.3, 0.5, 0.2], [0.6, 0.4, 0.0]])
    dist = _get_transformed_categorical(probs=probs)

    mode = dist_utils.get_mode(dist)
    expected_mode = torch.argmax(probs, dim=1)
    self.assertTensorEqual(expected_mode, mode)

    epsilon = 0.0
    action_obtained = dist_utils.epsilon_greedy_sample(dist, epsilon)
    self.assertTensorEqual(expected_mode, action_obtained)
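# The tests above rely on ``epsilon_greedy_sample(dist, 0.0)`` returning the
# greedy action, i.e. the mode of the distribution. Below is a standalone
# sketch of that epsilon == 0 behavior under that assumption; it is not ALF's
# implementation, and ``greedy_or_sample`` is a hypothetical helper.
import torch
import torch.distributions as td


def greedy_or_sample(dist: td.Distribution, epsilon: float) -> torch.Tensor:
    """Return the mode when epsilon == 0, otherwise draw a random sample."""
    if epsilon == 0.0:
        if isinstance(dist, td.Categorical):
            return torch.argmax(dist.probs, dim=-1)
        if isinstance(dist, td.Normal):
            return dist.mean
        raise NotImplementedError(type(dist))
    return dist.sample()


# Mirrors the expectations checked in the tests above.
normal = td.Normal(torch.tensor([0.3, 0.7]), torch.tensor([1.0, 1.0]))
assert (greedy_or_sample(normal, 0.0) == torch.tensor([0.3, 0.7])).all()

categorical = td.Categorical(torch.tensor([0.25, 0.75]))
assert greedy_or_sample(categorical, 0.0) == torch.tensor(1)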
def _predict_action(self,
                    observation,
                    state: SacActionState,
                    epsilon_greedy=None,
                    eps_greedy_sampling=False):
    """We sample the action inside this function rather than outside because,
    in the mixed case, once a continuous action is sampled here it must be
    paired with the discrete action sampled from the Q values. If we only
    returned the two distributions and sampled outside, the actions would
    not match.
    """
    new_state = SacActionState()
    if self._act_type != ActionType.Discrete:
        continuous_action_dist, actor_network_state = self._actor_network(
            observation, state=state.actor_network)
        new_state = new_state._replace(actor_network=actor_network_state)
        if eps_greedy_sampling:
            continuous_action = dist_utils.epsilon_greedy_sample(
                continuous_action_dist, epsilon_greedy)
        else:
            continuous_action = dist_utils.rsample_action_distribution(
                continuous_action_dist)

    critic_network_inputs = observation
    if self._act_type == ActionType.Mixed:
        critic_network_inputs = (observation, continuous_action)

    q_values = None
    if self._act_type != ActionType.Continuous:
        q_values, critic_state = self._critic_networks(
            critic_network_inputs, state=state.critic)
        new_state = new_state._replace(critic=critic_state)
        if self._act_type == ActionType.Discrete:
            alpha = torch.exp(self._log_alpha).detach()
        else:
            alpha = torch.exp(self._log_alpha[0]).detach()
        # p(a|s) = exp(Q(s,a)/alpha) / Z
        q_values = q_values.min(dim=1)[0]
        logits = q_values / alpha
        discrete_action_dist = td.Categorical(logits=logits)
        if eps_greedy_sampling:
            discrete_action = dist_utils.epsilon_greedy_sample(
                discrete_action_dist, epsilon_greedy)
        else:
            discrete_action = dist_utils.sample_action_distribution(
                discrete_action_dist)

    if self._act_type == ActionType.Mixed:
        # Note that in this case ``action_dist`` is not the valid joint
        # action distribution, because ``discrete_action_dist`` is
        # conditioned on the particular continuous action sampled above.
        # So DO NOT use this ``action_dist`` to directly sample an action
        # pair with an arbitrary continuous action anywhere else!
        # However, it is still valid for computing the log probability of
        # *this* sampled ``action``. It can also be used for summary
        # purposes because of the expectation taken over the continuous
        # action when summarizing.
        action_dist = type(self._action_spec)(
            (discrete_action_dist, continuous_action_dist))
        action = type(self._action_spec)(
            (discrete_action, continuous_action))
    elif self._act_type == ActionType.Discrete:
        action_dist = discrete_action_dist
        action = discrete_action
    else:
        action_dist = continuous_action_dist
        action = continuous_action

    return action_dist, action, q_values, new_state
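# Standalone sketch (not ALF code) of why the discrete action must be sampled
# inside ``_predict_action``: the categorical distribution is built from Q
# values conditioned on the *particular* continuous action sampled above, so
# sampling the two parts separately outside would break that pairing.
# ``dummy_critic``, the tensor shapes, and ``alpha`` are illustrative
# assumptions only.
import torch
import torch.distributions as td

batch, obs_dim, action_dim, num_discrete = 4, 3, 2, 5
observation = torch.randn(batch, obs_dim)

# Continuous part: sample from the (here: fixed) actor distribution first.
continuous_dist = td.Normal(
    torch.zeros(batch, action_dim), torch.ones(batch, action_dim))
continuous_action = continuous_dist.rsample()


# Discrete part: the Q values depend on the sampled continuous action ...
def dummy_critic(obs, cont_action):
    return (obs.sum(dim=-1, keepdim=True) +
            cont_action.sum(dim=-1, keepdim=True)) * torch.ones(
                batch, num_discrete)


alpha = 0.1
q_values = dummy_critic(observation, continuous_action)
# ... so the categorical p(a_d | s, a_c) ~ exp(Q(s, a_c, a_d) / alpha) is only
# meaningful together with that continuous action.
discrete_dist = td.Categorical(logits=q_values / alpha)
discrete_action = discrete_dist.sample()

# The pair (discrete_action, continuous_action) is consistent; re-sampling the
# continuous action afterwards would invalidate ``discrete_dist``.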
def predict_step(self, time_step: TimeStep, state, epsilon_greedy):
    action_distribution, state = self._get_action(time_step.observation,
                                                  state)
    action = dist_utils.epsilon_greedy_sample(action_distribution,
                                              epsilon_greedy)
    return AlgStep(output=action, state=state, info=())