def test_unstack_actions(self):
  """Checks that a plain np.array batch of actions is split per environment.

  Samples one action per environment from `self.action_spec`. It stacks them
  into a single batched array, unstacks them with
  `batched_py_environment.unstack_actions`, and then verifies two things:
  one action comes back per environment, and each has the spec's shape.
  """
  num_envs = 5
  action_spec = self.action_spec
  # Seeded so that a failure is reproducible across runs.
  rng = np.random.RandomState(0)
  batched_action = np.array([
      array_spec.sample_bounded_spec(action_spec, rng)
      for _ in range(num_envs)
  ])

  # Test that actions are correctly unstacked when just batched in np.array.
  unstacked_actions = batched_py_environment.unstack_actions(batched_action)
  # Without this check, an empty or short result would make the loop below
  # pass vacuously.
  self.assertEqual(num_envs, len(unstacked_actions))
  for action in unstacked_actions:
    self.assertAllEqual(action_spec.shape, action.shape)
def step_adversary(self, actions):
  """Steps the adversary of every wrapped environment with its own action.

  With a single wrapped environment, the batched `actions` are unbatched and
  passed straight to that environment. The result is re-batched before it is
  returned.

  Otherwise `actions` is split into one entry per environment. Each
  environment's `step_adversary` is invoked through `self._execute`, and the
  per-environment results are stacked back into one batch.

  Args:
    actions: Batched (possibly nested) adversary actions whose leading
      dimension matches the number of environments.

  Returns:
    The batched result of the environments' `step_adversary` calls.

  Raises:
    ValueError: If the number of unstacked actions differs from
      `self.batch_size`.
  """
  if self._num_envs == 1:
    lone_action = nest_utils.unbatch_nested_array(actions)
    lone_result = self._envs[0].step_adversary(lone_action)
    return nest_utils.batch_nested_array(lone_result)

  per_env_actions = batched_py_environment.unstack_actions(actions)
  action_count = len(per_env_actions)
  if action_count != self.batch_size:
    raise ValueError(
        'Primary dimension of action items does not match '
        'batch size: %d vs. %d' % (action_count, self.batch_size))

  def _step_single(env_and_action):
    # Each item is an (environment, action) pair produced by zip below.
    env, env_action = env_and_action
    return env.step_adversary(env_action)

  per_env_results = self._execute(
      _step_single, zip(self._envs, per_env_actions))
  return nest_utils.stack_nested_arrays(per_env_results)
def test_unstack_nested_actions(self):
  """Checks that actions nested in a namedtuple are split per environment.

  Builds a namedtuple-subclass nest with two fields:
  - `action`: a batch of sampled actions.
  - `other_var`: a batch of constant scalars.

  It unstacks the nest with `batched_py_environment.unstack_actions` and
  verifies three things:
  - one nest comes back per environment;
  - each `action` has the spec's shape;
  - each `other_var` is the constant.
  """
  num_envs = 5
  action_spec = self.action_spec
  # Seeded so that a failure is reproducible across runs.
  rng = np.random.RandomState(0)
  batched_action = np.array([
      array_spec.sample_bounded_spec(action_spec, rng)
      for _ in range(num_envs)
  ])

  # Test that actions are correctly unstacked when nested in namedtuple.
  # A subclass (not the bare namedtuple) is used on purpose, so that
  # namedtuple subclasses are exercised as nests.
  class NestedAction(
      collections.namedtuple('NestedAction', ['action', 'other_var'])):
    pass

  nested_action = NestedAction(
      action=batched_action, other_var=np.array([13.0] * num_envs))
  unstacked_actions = batched_py_environment.unstack_actions(nested_action)
  # Without this check, an empty or short result would make the loop below
  # pass vacuously.
  self.assertEqual(num_envs, len(unstacked_actions))
  # Distinct loop name so the batched `nested_action` input is not clobbered.
  for single_action in unstacked_actions:
    self.assertAllEqual(action_spec.shape, single_action.action.shape)
    self.assertEqual(13.0, single_action.other_var)