def test_batch_properties(self, batch_size):
    """Check that the base wrapper exposes the wrapped env's batching info.

    A ``RandomAlfEnvironment`` with the given ``batch_size`` is wrapped in
    ``alf_wrappers.AlfEnvironmentBaseWrapper``. The wrapper's ``batched`` and
    ``batch_size`` attributes must equal those of the wrapped environment.
    """

    def constant_reward(*_):
        # The reward value does not matter for this test; always yield 1.0.
        return torch.tensor([1.0], dtype=torch.float32)

    observation_spec = ts.BoundedTensorSpec((2, 3), torch.int32, -10, 10)
    action_spec = ts.BoundedTensorSpec((1, ), torch.int64, -10, 10)
    inner_env = RandomAlfEnvironment(
        observation_spec,
        action_spec,
        reward_fn=constant_reward,
        batch_size=batch_size)
    wrapped = alf_wrappers.AlfEnvironmentBaseWrapper(inner_env)

    self.assertEqual(wrapped.batched, inner_env.batched)
    self.assertEqual(wrapped.batch_size, inner_env.batch_size)
def _set_default_specs(self):
    """Populate the default observation, action and time-step specs.

    This sets three attributes on ``self``:

    - ``observation_spec``: a (3, 3) float32 ``TensorSpec``.
    - ``action_spec``: a length-7 float32 ``BoundedTensorSpec`` in [-1, 1].
    - ``time_step_spec``: the spec from ``ds.time_step_spec``, built from the
      two specs above and a shape-``()`` reward spec.
    """
    obs_spec = ts.TensorSpec((3, 3), torch.float32)
    act_spec = ts.BoundedTensorSpec(
        [7], dtype=torch.float32, minimum=-1.0, maximum=1.0)
    reward_spec = ts.TensorSpec(())

    self.observation_spec = obs_spec
    self.action_spec = act_spec
    self.time_step_spec = ds.time_step_spec(obs_spec, act_spec, reward_spec)
def test_close_no_hang_after_init(self):
    """Start a ``ProcessEnvironment`` and close it right away.

    The environment is never stepped or reset. The test passes as long as
    ``close()`` returns.
    """
    obs_spec = ts.TensorSpec((3, 3), torch.float32)
    act_spec = ts.BoundedTensorSpec(
        [1], torch.float32, minimum=-1.0, maximum=1.0)
    # The constructor is invoked later by ProcessEnvironment, so bind the
    # arguments now with functools.partial.
    env_ctor = functools.partial(
        RandomAlfEnvironment,
        obs_spec,
        act_spec,
        episode_end_probability=0,
        min_duration=2,
        max_duration=2)

    process_env = ProcessEnvironment(env_ctor)
    process_env.start()
    process_env.close()
def __init__(self, crash_at_step, env_id=None):
    """Create a mock environment parameterized by a crash step.

    Args:
        crash_at_step (int): step at which this mock is meant to crash. The
            crash itself is presumably raised by the step implementation,
            which is not shown here. Episodes are given a fixed length of
            ``crash_at_step + 1``, because the ``min_duration`` and
            ``max_duration`` passed to the parent are equal and
            ``episode_end_probability`` is 0.
        env_id: forwarded unchanged to the parent constructor.
    """
    episode_length = crash_at_step + 1
    super(MockEnvironmentCrashInStep, self).__init__(
        observation_spec=ts.TensorSpec((3, 3), torch.float32),
        action_spec=ts.BoundedTensorSpec(
            [1], torch.float32, minimum=-1.0, maximum=1.0),
        env_id=env_id,
        episode_end_probability=0,
        min_duration=episode_length,
        max_duration=episode_length)
    self._crash_at_step = crash_at_step
    # Step counter starting at zero; presumably advanced by the step method.
    self._steps = 0
def time_step_spec(observation_spec, action_spec, reward_spec):
    """Returns a ``TimeStep`` spec built from the given component specs.

    Args:
        observation_spec (nested TensorSpec): spec of the observation. Every
            leaf of the nest must be a ``ts.TensorSpec``.
        action_spec (nested TensorSpec): spec of the action, used for the
            ``prev_action`` field. Every leaf of the nest must be a
            ``ts.TensorSpec``.
        reward_spec: spec used as-is for the ``reward`` field. It is not
            validated here.
    Returns:
        TimeStep: ``step_type`` and ``env_id`` are scalar int32 specs.
        ``discount`` is a scalar float32 spec bounded to [0, 1].
        ``reward``, ``observation`` and ``prev_action`` are the specs passed in.
    Raises:
        AssertionError: if a leaf of ``observation_spec`` or ``action_spec``
            is not a ``ts.TensorSpec``. The message names the offending spec
            and leaf.
    """

    def _check_all_tensor_specs(spec, name):
        # Report which argument and which leaf is invalid, instead of a bare
        # AssertionError with no context.
        for leaf in nest.flatten(spec):
            assert isinstance(leaf, ts.TensorSpec), (
                "Every leaf of %s must be a TensorSpec, got %r" % (name, leaf))

    _check_all_tensor_specs(observation_spec, 'observation_spec')
    _check_all_tensor_specs(action_spec, 'action_spec')
    return TimeStep(
        step_type=ts.TensorSpec([], torch.int32),
        reward=reward_spec,
        discount=ts.BoundedTensorSpec([],
                                      torch.float32,
                                      minimum=0.0,
                                      maximum=1.0),
        observation=observation_spec,
        prev_action=action_spec,
        env_id=ts.TensorSpec([], torch.int32))
def _sample_spec(self, spec, outer_dims):
    """Draw a random sample from ``spec``.

    If ``spec`` is not already a ``ts.BoundedTensorSpec``, it is first
    replaced by one with the same shape and dtype. That replacement uses the
    default bounds of ``ts.BoundedTensorSpec``, whatever those are for the
    dtype.

    Args:
        spec: the spec to sample from. Only its ``shape`` and ``dtype`` are
            used when it is not bounded.
        outer_dims: forwarded unchanged to ``sample()``.
    Returns:
        The result of ``BoundedTensorSpec.sample`` (presumably a tensor).
    """
    bounded = (spec if isinstance(spec, ts.BoundedTensorSpec) else
               ts.BoundedTensorSpec(spec.shape, spec.dtype))
    return bounded.sample(outer_dims=outer_dims)