def test_unstack_actions(self):
    """Checks unstacking of actions that are batched in a plain np.array.

    Samples five actions from `self.action_spec`, collects them into a single
    np.array, and asserts that every item produced by
    `batched_py_environment.unstack_actions` has the spec's shape.
    """
    spec = self.action_spec
    env_count = 5
    random_state = np.random.RandomState()

    # Draw one bounded sample per environment and batch them together.
    samples = []
    for _ in range(env_count):
      samples.append(array_spec.sample_bounded_spec(spec, random_state))
    stacked = np.array(samples)

    per_env_actions = batched_py_environment.unstack_actions(stacked)
    for single_action in per_env_actions:
      self.assertAllEqual(spec.shape, single_action.shape)
 def step_adversary(self, actions):
   """Forwards `step_adversary` to every wrapped environment.

   Args:
     actions: Batched actions. With a single environment they are unbatched
       via `nest_utils.unbatch_nested_array`; otherwise they are split with
       `batched_py_environment.unstack_actions`, one item per environment.

   Returns:
     The environments' results, re-batched with
     `nest_utils.batch_nested_array` (single environment) or
     `nest_utils.stack_nested_arrays` (several environments).

   Raises:
     ValueError: If the number of unstacked actions differs from
       `self.batch_size`.
   """
   # Single-environment case: no fan-out needed.
   if self._num_envs == 1:
     single_action = nest_utils.unbatch_nested_array(actions)
     single_result = self._envs[0].step_adversary(single_action)
     return nest_utils.batch_nested_array(single_result)

   per_env_actions = batched_py_environment.unstack_actions(actions)
   action_count = len(per_env_actions)
   if action_count != self.batch_size:
     raise ValueError(
         'Primary dimension of action items does not match '
         'batch size: %d vs. %d' % (action_count, self.batch_size))

   def _step_one(env_and_action):
     # Each item is an (environment, action) pair produced by zip below.
     env, action = env_and_action
     return env.step_adversary(action)

   results = self._execute(_step_one, zip(self._envs, per_env_actions))
   return nest_utils.stack_nested_arrays(results)
  def test_unstack_nested_actions(self):
    """Checks unstacking of actions that are nested inside a namedtuple.

    Builds a batch of five sampled actions, wraps it together with a second
    batched field in a namedtuple subclass, and asserts that each item
    returned by `batched_py_environment.unstack_actions` exposes an `action`
    with the spec's shape and an `other_var` equal to 13.0.
    """
    spec = self.action_spec
    env_count = 5
    random_state = np.random.RandomState()

    samples = []
    for _ in range(env_count):
      samples.append(array_spec.sample_bounded_spec(spec, random_state))
    stacked = np.array(samples)

    # A namedtuple subclass is used as the nested action container.
    nested_action_base = collections.namedtuple(
        'NestedAction', ['action', 'other_var'])

    class NestedAction(nested_action_base):
      pass

    other_values = np.array([13.0] * env_count)
    batched_nested = NestedAction(action=stacked, other_var=other_values)

    per_env_actions = batched_py_environment.unstack_actions(batched_nested)
    for unstacked in per_env_actions:
      self.assertAllEqual(spec.shape, unstacked.action.shape)
      self.assertEqual(13.0, unstacked.other_var)