def testDistributionRaisesError(self):
    """Verifies that `distribution` raises NotImplementedError."""
    # Wrap a base policy in the smoothing policy with a fixed coefficient.
    wrapped = temporal_action_smoothing.TemporalActionSmoothing(
        StateIncrementPolicy(self._time_step_spec, self._action_spec),
        smoothing_coefficient=0.5)

    # Fetch the initial state for a batch of one, then check that asking the
    # wrapper for a distribution fails with NotImplementedError.
    initial_state = wrapped.get_initial_state(batch_size=1)
    self.assertRaises(
        NotImplementedError, wrapped.distribution, self._time_step,
        initial_state)
  def testSmoothedActions(self, smoothing_coefficient, expected_actions):
    """Steps the smoothed policy five times and checks the emitted actions.

    Args:
      smoothing_coefficient: Coefficient handed to `TemporalActionSmoothing`.
      expected_actions: Values the five squeezed actions must be close to.
    """
    # Wrap a base policy in the smoothing policy under test.
    wrapped = temporal_action_smoothing.TemporalActionSmoothing(
        StateIncrementPolicy(self._time_step_spec, self._action_spec),
        smoothing_coefficient)

    # Query the policy repeatedly, feeding each returned state into the next
    # call so the actions are produced in time order.
    state = wrapped.get_initial_state(batch_size=1)
    collected = []
    for _ in range(5):
      step_action, state, _unused_info = wrapped.action(
          self._time_step, policy_state=state)
      collected.append(step_action)

    # Evaluate the collected actions and compare against the expectation.
    evaluated = self.evaluate(collected)
    self.assertAllClose(np.squeeze(evaluated), expected_actions)