Exemplo n.º 1
0
 def test_batch_properties(self, batch_size):
     """Check that ``AlfEnvironmentBaseWrapper`` exposes the same ``batched``
     and ``batch_size`` values as the environment it wraps."""
     observation_spec = ts.BoundedTensorSpec((2, 3), torch.int32, -10, 10)
     action_spec = ts.BoundedTensorSpec((1, ), torch.int64, -10, 10)

     def constant_reward(*_):
         # Ignores all of its arguments and always returns a reward of 1.0.
         return torch.tensor([1.0], dtype=torch.float32)

     env = RandomAlfEnvironment(observation_spec,
                                action_spec,
                                reward_fn=constant_reward,
                                batch_size=batch_size)
     wrapped = alf_wrappers.AlfEnvironmentBaseWrapper(env)
     # The wrapper's value is compared against the wrapped env's value for
     # each batching-related attribute.
     for attr_name in ("batched", "batch_size"):
         self.assertEqual(getattr(wrapped, attr_name), getattr(env, attr_name))
Exemplo n.º 2
0
 def _set_default_specs(self):
     """Populate ``observation_spec``, ``action_spec`` and ``time_step_spec``.

     The observation is an unbounded 3x3 float32 spec, the action is a
     7-element float32 spec bounded to [-1, 1], and the reward passed to
     ``ds.time_step_spec`` is a ``ts.TensorSpec(())`` built with the
     constructor's defaults.
     """
     self.observation_spec = ts.TensorSpec((3, 3), torch.float32)
     self.action_spec = ts.BoundedTensorSpec(
         [7], dtype=torch.float32, minimum=-1.0, maximum=1.0)
     reward_spec = ts.TensorSpec(())
     self.time_step_spec = ds.time_step_spec(
         self.observation_spec, self.action_spec, reward_spec)
Exemplo n.º 3
0
 def test_close_no_hang_after_init(self):
     """Start a ``ProcessEnvironment`` and close it immediately.

     Per the test name, ``close()`` must not hang when called right after
     ``start()``.
     """
     obs_spec = ts.TensorSpec((3, 3), torch.float32)
     act_spec = ts.BoundedTensorSpec([1],
                                     torch.float32,
                                     minimum=-1.0,
                                     maximum=1.0)
     # A zero end probability with min == max duration presumably pins every
     # episode to exactly two steps; confirm against RandomAlfEnvironment.
     env_ctor = functools.partial(RandomAlfEnvironment,
                                  obs_spec,
                                  act_spec,
                                  episode_end_probability=0,
                                  min_duration=2,
                                  max_duration=2)
     process_env = ProcessEnvironment(env_ctor)
     process_env.start()
     process_env.close()
Exemplo n.º 4
0
 def __init__(self, crash_at_step, env_id=None):
     """Build the mock environment.

     Args:
         crash_at_step (int): step count at which this mock is presumably
             meant to crash (per the class name). Stored as
             ``self._crash_at_step``.
         env_id: forwarded unchanged to the parent constructor.
     """
     # Both duration bounds are set one past the crash step, with a zero
     # random-end probability. Presumably this lets an episode reach the
     # crash step before it ends.
     duration = crash_at_step + 1
     observation_spec = ts.TensorSpec((3, 3), torch.float32)
     action_spec = ts.BoundedTensorSpec([1],
                                        torch.float32,
                                        minimum=-1.0,
                                        maximum=1.0)
     super(MockEnvironmentCrashInStep, self).__init__(
         observation_spec=observation_spec,
         action_spec=action_spec,
         env_id=env_id,
         episode_end_probability=0,
         min_duration=duration,
         max_duration=duration)
     self._crash_at_step = crash_at_step
     self._steps = 0
Exemplo n.º 5
0
def time_step_spec(observation_spec, action_spec, reward_spec):
    """Return a ``TimeStep`` spec built from the given component specs.

    Args:
        observation_spec (nested TensorSpec): spec of the observation. Every
            leaf (as returned by ``nest.flatten``) must be a ``ts.TensorSpec``.
        action_spec (nested TensorSpec): spec used for the ``prev_action``
            field. Every leaf must be a ``ts.TensorSpec``.
        reward_spec (TensorSpec): spec used as-is for the ``reward`` field.
            It is not validated here.

    Returns:
        TimeStep: a spec whose ``step_type`` and ``env_id`` are scalar
        ``torch.int32`` specs, whose ``discount`` is a scalar ``torch.float32``
        spec bounded to [0, 1], and whose ``reward``, ``observation`` and
        ``prev_action`` are the given specs.

    Raises:
        AssertionError: if any leaf of ``observation_spec`` or ``action_spec``
            is not a ``ts.TensorSpec``.
    """
    for arg_name, spec in (("observation_spec", observation_spec),
                           ("action_spec", action_spec)):
        for leaf in nest.flatten(spec):
            # Raised explicitly rather than via ``assert`` so the check is not
            # stripped under ``python -O``. AssertionError is kept so callers
            # see the same exception type as before.
            if not isinstance(leaf, ts.TensorSpec):
                raise AssertionError(
                    "%s must contain only ts.TensorSpec leaves, got %r" %
                    (arg_name, leaf))

    return TimeStep(step_type=ts.TensorSpec([], torch.int32),
                    reward=reward_spec,
                    discount=ts.BoundedTensorSpec([],
                                                  torch.float32,
                                                  minimum=0.0,
                                                  maximum=1.0),
                    observation=observation_spec,
                    prev_action=action_spec,
                    env_id=ts.TensorSpec([], torch.int32))
Exemplo n.º 6
0
 def _sample_spec(self, spec, outer_dims):
     """Draw a sample from ``spec`` with the given ``outer_dims``.

     A spec that is not already a ``ts.BoundedTensorSpec`` is first rebuilt
     as one with the same shape and dtype. No explicit bounds are passed, so
     the bounded spec's constructor defaults apply. Then ``sample`` is called
     on the result.
     """
     bounded = (spec if isinstance(spec, ts.BoundedTensorSpec) else
                ts.BoundedTensorSpec(spec.shape, spec.dtype))
     return bounded.sample(outer_dims=outer_dims)