Example #1
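Verifies that `reset()` caches the time step it returns: every field of `current_time_step()` equals the corresponding field of the step returned by `reset()`.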
    def testResetSavesCurrentTimeStep(self):
        obs_spec = BoundedTensorSpec((1, ), torch.int32)
        action_spec = BoundedTensorSpec((1, ), torch.int64)

        random_env = RandomAlfEnvironment(observation_spec=obs_spec,
                                          action_spec=action_spec)

        time_step = random_env.reset()
        current_time_step = random_env.current_time_step()
        nest.map_structure(self.assertEqual, time_step, current_time_step)
Example #2
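Verifies that a batched environment prepends the batch dimension: with `batch_size=3`, the observation shape is `(3, 2, 3)` and reward and discount each have a leading dimension of 3.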
    def testBatchSize(self):
        batch_size = 3
        obs_spec = BoundedTensorSpec((2, 3), torch.int32, -10, 10)
        action_spec = BoundedTensorSpec((1, ), torch.int64)
        env = RandomAlfEnvironment(obs_spec,
                                   action_spec,
                                   batch_size=batch_size)
        time_step = env.step(torch.tensor(0, dtype=torch.int64))
        self.assertEqual(time_step.observation.shape, (3, 2, 3))
        self.assertEqual(time_step.reward.shape[0], batch_size)
        self.assertEqual(time_step.discount.shape[0], batch_size)
Example #3
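Verifies that `render()` produces a `uint8` image of the configured `render_size`, with all pixel values in `[0, 256)`.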
    def testRendersImage(self):
        action_spec = BoundedTensorSpec((1, ), torch.int64, -10, 10)
        observation_spec = BoundedTensorSpec((1, ), torch.int32, -10, 10)
        env = RandomAlfEnvironment(observation_spec,
                                   action_spec,
                                   render_size=(4, 4, 3))

        env.reset()
        img = env.render()

        self.assertTrue(np.all(img < 256))
        self.assertTrue(np.all(img >= 0))
        self.assertEqual((4, 4, 3), img.shape)
        self.assertEqual(np.uint8, img.dtype)
Example #4
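Verifies that a user-supplied `reward_fn` is called: the first step of an episode carries zero reward, and subsequent steps carry the function's return value (here, the action).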
    def testRewardFnCalled(self):
        def reward_fn(unused_step_type, action, unused_observation):
            return action

        action_spec = BoundedTensorSpec((1, ), torch.int64, -10, 10)
        observation_spec = BoundedTensorSpec((1, ), torch.int32, -10, 10)
        env = RandomAlfEnvironment(observation_spec,
                                   action_spec,
                                   reward_fn=reward_fn)

        action = np.array(1, dtype=np.int64)
        time_step = env.step(action)  # No reward in first time_step
        self.assertEqual(np.zeros((), dtype=np.float32), time_step.reward)
        time_step = env.step(action)
        self.assertEqual(np.ones((), dtype=np.float32), time_step.reward)
Example #5
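Verifies that no episode runs longer than `max_duration` steps. The `max_duration` parameter suggests the method is driven by a parameterized-test decorator not shown in the excerpt.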
    def testEnvMaxDuration(self, max_duration):
        obs_spec = BoundedTensorSpec((2, 3), torch.int32, -10, 10)
        action_spec = BoundedTensorSpec([], torch.int64)
        env = RandomAlfEnvironment(obs_spec,
                                   action_spec,
                                   episode_end_probability=0.1,
                                   max_duration=max_duration)
        num_episodes = 100

        action = torch.tensor(0, dtype=torch.int64)
        for _ in range(num_episodes):
            time_step = env.step(action)
            self.assertTrue(time_step.is_first())
            num_steps = 0
            while not time_step.is_last():
                time_step = env.step(action)
                num_steps += 1
            self.assertLessEqual(num_steps, max_duration)
Example #6
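Verifies that the environment resets itself automatically after a terminal step, and that every observation stays within the spec bounds of `[-10, 10]`.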
    def testEnvResetAutomatically(self):
        obs_spec = BoundedTensorSpec((2, 3), torch.int32, -10, 10)
        action_spec = BoundedTensorSpec([], torch.int64)
        env = RandomAlfEnvironment(obs_spec, action_spec)

        action = torch.tensor(0, dtype=torch.int64)
        time_step = env.step(action)
        self.assertTrue(np.all(time_step.observation >= -10))
        self.assertTrue(np.all(time_step.observation <= 10))
        self.assertTrue(time_step.is_first())

        while not time_step.is_last():
            time_step = env.step(action)
            self.assertTrue(np.all(time_step.observation >= -10))
            self.assertTrue(np.all(time_step.observation <= 10))

        time_step = env.step(action)
        self.assertTrue(np.all(time_step.observation >= -10))
        self.assertTrue(np.all(time_step.observation <= 10))
        self.assertTrue(time_step.is_first())
Example #7
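Verifies that `AlfEnvironmentBaseWrapper` forwards the wrapped environment's `batched` and `batch_size` properties. As in Example #5, the extra `batch_size` parameter points to a parameterized test.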
    def test_batch_properties(self, batch_size):
        obs_spec = ts.BoundedTensorSpec((2, 3), torch.int32, -10, 10)
        action_spec = ts.BoundedTensorSpec((1, ), torch.int64, -10, 10)
        env = RandomAlfEnvironment(
            obs_spec,
            action_spec,
            reward_fn=lambda *_: torch.tensor([1.0], dtype=torch.float32),
            batch_size=batch_size)
        wrap_env = alf_wrappers.AlfEnvironmentBaseWrapper(env)
        self.assertEqual(wrap_env.batched, env.batched)
        self.assertEqual(wrap_env.batch_size, env.batch_size)
Example #8
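Verifies that a reward whose size does not match the batch (a length-1 array against `batch_size=5`) makes `step()` raise `ValueError`.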
    def testRewardCheckerSizeMismatch(self):
        # Ensure a custom scalar reward with batch_size greater than 1
        # raises ValueError.
        obs_spec = BoundedTensorSpec((2, 3), torch.int32, -10, 10)
        action_spec = BoundedTensorSpec((1, ), torch.int64)
        env = RandomAlfEnvironment(obs_spec,
                                   action_spec,
                                   reward_fn=lambda *_: np.array([1.0]),
                                   batch_size=5)
        env.reset()
        env._done = False
        action = torch.tensor(0, dtype=torch.int64)
        with self.assertRaises(ValueError):
            env.step(action)
Example #9
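Verifies the complementary case: with `batch_size=1`, a length-1 reward array passes the size check.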
    def testRewardCheckerBatchSizeOne(self):
        # Ensure batch size 1 with a scalar reward works.
        obs_spec = BoundedTensorSpec((2, 3), torch.int32, -10, 10)
        action_spec = BoundedTensorSpec((1, ), torch.int64)
        env = RandomAlfEnvironment(obs_spec,
                                   action_spec,
                                   reward_fn=lambda *_: np.array([1.0]),
                                   batch_size=1)
        env._done = False
        env.reset()
        action = torch.tensor([0], dtype=torch.int64)
        time_step = env.step(action)
        self.assertEqual(time_step.reward, 1.0)
Example #10
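Verifies that a custom `reward_fn` returning a batch-sized reward vector is propagated unchanged into `time_step.reward`.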
    def testCustomRewardFn(self):
        obs_spec = BoundedTensorSpec((2, 3), torch.int32, -10, 10)
        action_spec = BoundedTensorSpec((1, ), torch.int64)
        batch_size = 3
        env = RandomAlfEnvironment(obs_spec,
                                   action_spec,
                                   reward_fn=lambda *_: np.ones(batch_size),
                                   batch_size=batch_size)
        env._done = False
        env.reset()
        action = torch.ones(batch_size)
        time_step = env.step(action)
        self.assertSequenceAlmostEqual([1.0] * batch_size, time_step.reward)
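
All of the snippets above are methods of a test-case class and omit their module preamble. Below is a minimal sketch of the imports they appear to rely on; the module paths are assumptions based on ALF's conventional layout, not something the excerpts confirm.

    # Hypothetical preamble -- module paths are assumed from ALF's
    # conventional layout, not taken from the excerpts themselves.
    import numpy as np
    import torch

    import alf.tensor_specs as ts  # Example #7 uses the `ts` alias
    from alf.tensor_specs import BoundedTensorSpec
    from alf import nest  # Example #1 uses `nest.map_structure`
    from alf.environments import alf_wrappers  # Example #7
    from alf.environments.random_alf_environment import RandomAlfEnvironment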