Beispiel #1
0
  def test_gym_registration_with_kwargs(self):
    """Registers EnvWithOptions under two versions with different kwargs.

    For each registration the returned id is checked, the env is reset, and
    the `done` flag from `step` is asserted to be set only for the action
    passed as `done_action`.
    """
    entry_point = "tensor2tensor.rl.gym_utils_test:EnvWithOptions"

    # Each scenario: (extra registration arguments, expected registered id,
    # actions that must not finish the episode, action that must finish it).
    # The second scenario passes an explicit version because the env is being
    # registered again with different kwargs.
    scenarios = (
        ({"kwargs": {"done_action": 2}},
         "T2TEnv-EnvWithOptions-v0", (0, 1), 2),
        ({"version": "v1", "kwargs": {"done_action": 1}},
         "T2TEnv-EnvWithOptions-v1", (0, 2), 1),
    )

    for register_args, expected_id, live_actions, final_action in scenarios:
      reg_id, env = gym_utils.register_gym_env(entry_point, **register_args)
      self.assertEqual(expected_id, reg_id)

      # Reset before the first step.
      env.reset()

      # None of these actions should end the episode...
      for action in live_actions:
        _obs, _reward, done, _info = env.step(action)
        self.assertFalse(done)

      # ...but the configured one should.
      _obs, _reward, done, _info = env.step(final_action)
      self.assertTrue(done)
Beispiel #2
0
    def test_gym_registration(self):
        """Registers SimpleEnv and sanity-checks the environment handed back.

        Checks that the result is a `gym.Env`, that the observation returned
        by `reset` is close to a 3x3 array of zeros, and that stepping with
        action 1 reports `done`.
        """
        # NOTE(review): the value is used directly as the env here, while the
        # other call sites in this file unpack a (reg_id, env) pair -- confirm
        # this matches the gym_utils version the test targets.
        env = gym_utils.register_gym_env(
            "tensor2tensor.rl.gym_utils_test:SimpleEnv")

        # Most basic check. assertIsInstance reports the offending object on
        # failure instead of a bare "False is not true".
        self.assertIsInstance(env, gym.Env)

        # Just make sure we got the same environment.
        self.assertTrue(
            np.allclose(env.reset(), np.zeros(shape=(3, 3), dtype=np.uint8)))

        _, _, done, _ = env.step(1)
        self.assertTrue(done)
Beispiel #3
0
    def test_gym_registration_continuous(self):
        """Registers SimpleContinuousActionsEnv with `dimensions=2`.

        Verifies the registered id, that the returned object is a `gym.Env`,
        that the observation from `reset` is close to a 3x3 array of zeros,
        and that stepping with action 1 reports `done`.
        """
        env_path = "tensor2tensor.rl.gym_utils_test:SimpleContinuousActionsEnv"
        reg_id, env = gym_utils.register_gym_env(
            env_path, kwargs={"dimensions": 2})

        self.assertEqual("T2TEnv-SimpleContinuousActionsEnv-v0", reg_id)

        # The registered object must be a real gym environment.
        self.assertIsInstance(env, gym.Env)

        # The observation after reset should match an all-zero 3x3 array.
        initial_obs = env.reset()
        zeros = np.zeros(shape=(3, 3), dtype=np.uint8)
        self.assertTrue(np.allclose(initial_obs, zeros))

        # One step with action 1 is expected to finish the episode.
        _obs, _reward, done, _info = env.step(1)
        self.assertTrue(done)
Beispiel #4
0
def register():
    """Register TicTacToeEnv with gym as version "v0".

    Only the registration itself matters here; the returned id and
    environment are unpacked and discarded.
    """
    env_path = "tensor2tensor.envs.tic_tac_toe_env:TicTacToeEnv"
    _reg_id, _env = gym_utils.register_gym_env(env_path, version="v0")