def _generate_time_step(batched,
                        observation,
                        step_type,
                        discount,
                        prev_action=None,
                        action_spec=None,
                        reward=None,
                        reward_spec=ts.TensorSpec(()),
                        env_id=None,
                        env_info=None):
    """Assemble a ``TimeStep`` from an observation and optional fields.

    The array backend is picked from ``observation``: numpy is used when every
    leaf of the flattened observation satisfies ``_is_numpy_array``; otherwise
    every leaf must be a torch tensor and torch is used. All fields created
    here (``step_type``, default ``reward``, ``discount``, default
    ``prev_action``, default ``env_id``) are built with that same backend.

    Args:
        batched: if true, the size of the first dimension of the first
            flattened observation leaf is used as the batch size and becomes
            the outer dimension of every generated field; otherwise there is
            no outer dimension.
        observation: nest of numpy arrays or of torch tensors.
        step_type: fill value for the int32 ``step_type`` field.
        discount: converted to float32 and multiplied with a float32 array of
            ones of the outer dimensions.
        prev_action: previous action; if ``None``, zeros shaped and typed after
            ``action_spec`` are used. When batched, the first flattened leaf
            must have the batch size as its first dimension.
        action_spec: nest of specs (each with ``shape`` and ``dtype``); only
            read when ``prev_action`` is ``None``.
        reward: converted to float32 when given; if ``None``, float32 zeros of
            shape ``outer_dims + reward_spec.shape`` are used. When batched,
            its first dimension must equal the batch size.
        reward_spec: spec whose ``shape`` is used for the default reward.
        env_id: if ``None``, ``arange(batch_size)`` (batched) or a scalar zero
            (unbatched), both int32.
        env_info: forwarded to ``TimeStep``; if ``None``, a new empty dict is
            created for this call.

    Returns:
        The constructed ``TimeStep``.

    Raises:
        AssertionError: if the observation leaves are neither all numpy arrays
            nor all tensors, or if a batched ``reward``/``prev_action`` does
            not have the batch size as its first dimension.
    """
    # A ``{}`` default would be one dict shared by every TimeStep created
    # without an explicit env_info; build a fresh one per call instead.
    if env_info is None:
        env_info = {}

    flat_observation = nest.flatten(observation)
    if all(map(_is_numpy_array, flat_observation)):
        md = np
        if reward is not None:
            reward = np.float32(reward)
        discount = np.float32(discount)
    else:
        assert all(map(torch.is_tensor, flat_observation)), (
            "Elements in observation must be Tensor")
        md = torch
        if reward is not None:
            reward = to_tensor(reward, dtype=torch.float32)
        discount = to_tensor(discount, dtype=torch.float32)

    if batched:
        batch_size = flat_observation[0].shape[0]
        outer_dims = (batch_size, )
        if env_id is None:
            env_id = md.arange(batch_size, dtype=md.int32)
        if reward is not None:
            assert reward.shape[:1] == outer_dims
        if prev_action is not None:
            flat_action = nest.flatten(prev_action)
            assert flat_action[0].shape[:1] == outer_dims
    else:
        outer_dims = ()
        if env_id is None:
            env_id = md.zeros((), dtype=md.int32)

    step_type = md.full(outer_dims, step_type, dtype=md.int32)
    if reward is None:
        reward = md.zeros(outer_dims + reward_spec.shape, dtype=md.float32)
    # Multiplying by ones expands discount to the outer dimensions.
    discount = md.ones(outer_dims, dtype=md.float32) * discount
    if prev_action is None:
        # The dtype is looked up by name on the chosen backend module so the
        # same spec works for both numpy and torch.
        prev_action = nest.map_structure(
            lambda spec: md.zeros(
                outer_dims + spec.shape,
                dtype=getattr(md, ts.torch_dtype_to_str(spec.dtype))),
            action_spec)

    return TimeStep(
        step_type,
        reward,
        discount,
        observation,
        prev_action,
        env_id,
        env_info=env_info)
def testIntegerSamplesExcludeMaxOfDtype(self, dtype):
    """Check sampling from a spec whose two bounds are both ``max - 1``.

    ``max`` is the largest value of ``dtype`` as reported by ``np.iinfo``.
    The sample must have shape ``(1, ) + self._shape`` and every element must
    equal ``max - 1``.
    """
    # Floating point dtypes are not integer types and uint8 has special
    # sampling logic, so neither is checked here.
    not_applicable = dtype.is_floating_point or dtype == torch.uint8
    if not_applicable:
        return

    bound = np.iinfo(torch_dtype_to_str(dtype)).max - 1
    bounded_spec = BoundedTensorSpec(self._shape, dtype, bound, bound)
    drawn = bounded_spec.sample(outer_dims=(1, ))

    expected_shape = (1, ) + self._shape
    self.assertEqual(drawn.shape, expected_shape)
    self.assertTrue(torch.all(drawn == bound))
def test_obs_dtype(self):
    """Check that the reset observation's dtype matches the observation spec.

    A ``CartPole-v1`` gym environment is wrapped in ``AlfGymWrapper``; the
    string form of the spec dtype must equal the string form of the dtype of
    the observation returned by ``reset()``.
    """
    gym_env = gym.spec('CartPole-v1').make()
    wrapped_env = alf_gym_wrapper.AlfGymWrapper(gym_env)
    first_step = wrapped_env.reset()

    spec_dtype_name = torch_dtype_to_str(wrapped_env.observation_spec().dtype)
    observed_dtype_name = str(first_step.observation.dtype)
    self.assertEqual(spec_dtype_name, observed_dtype_name)
def _as_spec_dtype(arr, spec):
    """Return ``arr`` with the dtype named by ``spec.dtype``.

    The target dtype name is obtained with ``torch_dtype_to_str``. ``arr`` is
    returned unchanged when ``str(arr.dtype)`` already equals that name;
    otherwise the result of ``arr.astype(name)`` is returned.
    """
    target_name = torch_dtype_to_str(spec.dtype)
    if str(arr.dtype) != target_name:
        return arr.astype(target_name)
    return arr