Example #1
def _generate_time_step(batched,
                        observation,
                        step_type,
                        discount,
                        prev_action=None,
                        action_spec=None,
                        reward=None,
                        reward_spec=ts.TensorSpec(()),
                        env_id=None,
                        env_info={}):
    """Build a (possibly batched) TimeStep, filling in defaults for reward,
    discount, prev_action and env_id from the given specs."""

    flat_observation = nest.flatten(observation)

    if all(map(_is_numpy_array, flat_observation)):
        # `md` selects the array module (numpy or torch) used for the code below.
        md = np
        if reward is not None:
            reward = np.float32(reward)
        discount = np.float32(discount)
    else:
        assert all(
            map(torch.is_tensor,
                flat_observation)), ("Elements in observation must be Tensor")
        md = torch
        if reward is not None:
            reward = to_tensor(reward, dtype=torch.float32)
        discount = to_tensor(discount, dtype=torch.float32)

    if batched:
        batch_size = flat_observation[0].shape[0]
        outer_dims = (batch_size, )
        if env_id is None:
            env_id = md.arange(batch_size, dtype=md.int32)
        if reward is not None:
            assert reward.shape[:1] == outer_dims
        if prev_action is not None:
            flat_action = nest.flatten(prev_action)
            assert flat_action[0].shape[:1] == outer_dims
    else:
        outer_dims = ()
        if env_id is None:
            env_id = md.zeros((), dtype=md.int32)

    step_type = md.full(outer_dims, step_type, dtype=md.int32)
    if reward is None:
        reward = md.zeros(outer_dims + reward_spec.shape, dtype=md.float32)
    discount = md.ones(outer_dims, dtype=md.float32) * discount
    if prev_action is None:
        prev_action = nest.map_structure(
            lambda spec: md.zeros(outer_dims + spec.shape,
                                  dtype=getattr(
                                      md, ts.torch_dtype_to_str(spec.dtype))),
            action_spec)

    return TimeStep(step_type,
                    reward,
                    discount,
                    observation,
                    prev_action,
                    env_id,
                    env_info=env_info)
Example #2
def testIntegerSamplesExcludeMaxOfDtype(self, dtype):
    # Exclude non-integer types and uint8 (has special sampling logic).
    if dtype.is_floating_point or dtype == torch.uint8:
        return
    info = np.iinfo(torch_dtype_to_str(dtype))
    spec = BoundedTensorSpec(self._shape, dtype, info.max - 1,
                             info.max - 1)
    sample = spec.sample(outer_dims=(1, ))
    self.assertEqual(sample.shape, (1, ) + self._shape)
    self.assertTrue(torch.all(sample == info.max - 1))
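
The test above relies on `torch_dtype_to_str` returning a bare dtype name (e.g. 'int64') that `np.iinfo` accepts. A minimal sketch of that naming convention, assuming only that torch dtypes print as `torch.<name>`; the helper below is a hypothetical illustration, not the library's implementation:

import numpy as np
import torch

def torch_dtype_to_str_sketch(dtype):
    # Hypothetical helper: drop the "torch." prefix, e.g. torch.int64 -> "int64".
    return str(dtype).rsplit('.', 1)[-1]

# np.iinfo accepts the resulting name for integer dtypes.
assert torch_dtype_to_str_sketch(torch.int64) == 'int64'
assert np.iinfo(torch_dtype_to_str_sketch(torch.int64)).max == 2**63 - 1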
Example #3
def test_obs_dtype(self):
    cartpole_env = gym.spec('CartPole-v1').make()
    env = alf_gym_wrapper.AlfGymWrapper(cartpole_env)
    time_step = env.reset()
    self.assertEqual(torch_dtype_to_str(env.observation_spec().dtype),
                     str(time_step.observation.dtype))
Example #4
def _as_spec_dtype(arr, spec):
    dtype = torch_dtype_to_str(spec.dtype)
    if str(arr.dtype) == dtype:
        return arr
    else:
        return arr.astype(dtype)
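
For reference, a hedged usage sketch of `_as_spec_dtype` with a hypothetical stand-in spec object; only a `.dtype` attribute is assumed, and a real spec with a torch dtype would behave the same way here:

import collections

import numpy as np
import torch

# Hypothetical minimal spec; only the .dtype attribute is used by _as_spec_dtype.
FakeSpec = collections.namedtuple('FakeSpec', ['dtype'])

arr = np.arange(4, dtype=np.int64)
converted = _as_spec_dtype(arr, FakeSpec(dtype=torch.float32))
assert converted.dtype == np.float32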