def testDistributionRaisesError(self):
  """Checks that `distribution` is unsupported on the smoothing wrapper.

  Wraps a `StateIncrementPolicy` in `TemporalActionSmoothing` and asserts
  that calling `distribution` on the wrapper raises `NotImplementedError`.
  """
  # Build the wrapper under test around the base policy.
  base_policy = StateIncrementPolicy(self._time_step_spec, self._action_spec)
  smoother = temporal_action_smoothing.TemporalActionSmoothing(
      base_policy, smoothing_coefficient=0.5)

  # Initial policy state for a batch of one.
  initial_state = smoother.get_initial_state(batch_size=1)

  with self.assertRaises(NotImplementedError):
    smoother.distribution(self._time_step, initial_state)
def testSmoothedActions(self, smoothing_coefficient, expected_actions):
  """Checks the actions produced by the smoothing wrapper over five steps.

  Calls `action` five times in a row, feeding each returned policy state
  into the next call, then compares the evaluated and squeezed actions
  against `expected_actions`.

  Args:
    smoothing_coefficient: Coefficient passed to `TemporalActionSmoothing`.
    expected_actions: Values the squeezed sequence of actions must be close
      to. NOTE(review): presumably supplied, together with
      `smoothing_coefficient`, by a parameterized-test decorator that is not
      visible here -- confirm.
  """
  # Build the wrapper under test around the base policy.
  base_policy = StateIncrementPolicy(self._time_step_spec, self._action_spec)
  smoother = temporal_action_smoothing.TemporalActionSmoothing(
      base_policy, smoothing_coefficient)

  # Step the policy repeatedly, carrying the policy state forward each time.
  num_steps = 5
  state = smoother.get_initial_state(batch_size=1)
  actions_over_time = []
  for _ in range(num_steps):
    step_action, state, unused_info = smoother.action(
        self._time_step, policy_state=state)
    actions_over_time.append(step_action)

  # Evaluate the collected actions and compare them with the expectation.
  evaluated_actions = np.squeeze(self.evaluate(actions_over_time))
  self.assertAllClose(evaluated_actions, expected_actions)