Example No. 1
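All of the snippets on this page come from the TF-Agents driver test suite and are shown without their enclosing test classes, decorators, or imports. As a rough sketch, they assume imports along the following lines (module paths follow the tf_agents package layout and can vary between versions; driver_test_utils and drivers_test_utils are two aliases for the same test-utility module):

import numpy as np
import tensorflow as tf

from tf_agents.drivers import dynamic_episode_driver
from tf_agents.drivers import dynamic_step_driver
from tf_agents.drivers import py_driver
from tf_agents.drivers import tf_driver
from tf_agents.drivers import test_utils as driver_test_utils
from tf_agents.environments import batched_py_environment
from tf_agents.environments import tf_py_environment
from tf_agents.replay_buffers import episodic_replay_buffer
from tf_agents.specs import tensor_spec
from tf_agents.trajectories import time_step as ts
from tf_agents.trajectories import trajectory
from tf_agents.utils import common
from tf_agents.utils import eager_utils
from tf_agents.utils import example_encoding_dataset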
  def testBatchedEnvironment(self, max_steps, max_episodes, expected_length):

    expected_trajectories = [
        trajectory.Trajectory(
            step_type=np.array([0, 0]),
            observation=np.array([0, 0]),
            action=np.array([2, 1]),
            policy_info=np.array([4, 2]),
            next_step_type=np.array([1, 1]),
            reward=np.array([1., 1.]),
            discount=np.array([1., 1.])),
        trajectory.Trajectory(
            step_type=np.array([1, 1]),
            observation=np.array([2, 1]),
            action=np.array([1, 2]),
            policy_info=np.array([2, 4]),
            next_step_type=np.array([2, 1]),
            reward=np.array([1., 1.]),
            discount=np.array([0., 1.])),
        trajectory.Trajectory(
            step_type=np.array([2, 1]),
            observation=np.array([3, 3]),
            action=np.array([2, 1]),
            policy_info=np.array([4, 2]),
            next_step_type=np.array([0, 2]),
            reward=np.array([0., 1.]),
            discount=np.array([1., 0.]))
    ]

    env1 = driver_test_utils.PyEnvironmentMock(final_state=3)
    env2 = driver_test_utils.PyEnvironmentMock(final_state=4)
    env = batched_py_environment.BatchedPyEnvironment([env1, env2])
    tf_env = tf_py_environment.TFPyEnvironment(env)

    policy = driver_test_utils.TFPolicyMock(
        tf_env.time_step_spec(),
        tf_env.action_spec(),
        batch_size=2,
        initial_policy_state=tf.constant([1, 2], dtype=tf.int32))

    replay_buffer_observer = MockReplayBufferObserver()

    driver = tf_driver.TFDriver(
        tf_env,
        policy,
        observers=[replay_buffer_observer],
        max_steps=max_steps,
        max_episodes=max_episodes,
    )
    initial_time_step = tf_env.reset()
    initial_policy_state = tf.constant([1, 2], dtype=tf.int32)
    self.evaluate(driver.run(initial_time_step, initial_policy_state))
    trajectories = replay_buffer_observer.gather_all()

    self.assertEqual(
        len(trajectories), len(expected_trajectories[:expected_length]))

    for t1, t2 in zip(trajectories, expected_trajectories[:expected_length]):
      for t1_field, t2_field in zip(t1, t2):
        self.assertAllEqual(t1_field, t2_field)
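Several of these tests use MockReplayBufferObserver, which is defined locally in the test modules rather than exported by the library. A minimal reconstruction consistent with how it is used here — a callable that records each trajectory it receives and returns them all from gather_all() — could look like this (a sketch, not the original class):

class MockReplayBufferObserver(object):
    """Stores every item it observes, standing in for a real replay buffer."""

    def __init__(self):
        self._trajectories = []

    def __call__(self, traj):
        # Drivers invoke each observer with every new trajectory (or batch).
        self._trajectories.append(traj)

    def gather_all(self):
        # Return everything observed so far, in order.
        return self._trajectories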
Example No. 2
    def testOneStepUpdatesObservers(self):
        if tf.executing_eagerly():
            self.skipTest('b/123880556')

        env = tf_py_environment.TFPyEnvironment(
            driver_test_utils.PyEnvironmentMock())
        policy = driver_test_utils.TFPolicyMock(env.time_step_spec(),
                                                env.action_spec())
        policy_state_ph = tensor_spec.to_nest_placeholder(
            policy.policy_state_spec,
            default=0,
            name_scope='policy_state_ph',
            outer_dims=(1, ))
        num_episodes_observer = driver_test_utils.NumEpisodesObserver()

        driver = dynamic_step_driver.DynamicStepDriver(
            env, policy, observers=[num_episodes_observer])
        run_driver = driver.run(policy_state=policy_state_ph)

        with self.session() as session:
            session.run(tf.compat.v1.global_variables_initializer())
            _, policy_state = session.run(run_driver)
            for _ in range(4):
                _, policy_state = session.run(
                    run_driver, feed_dict={policy_state_ph: policy_state})
            self.assertEqual(self.evaluate(num_episodes_observer.num_episodes),
                             2)
Example No. 3
    def testMultipleRunMaxEpisodes(self):
        num_episodes = 2
        num_expected_steps = 6

        env = driver_test_utils.PyEnvironmentMock()
        tf_env = tf_py_environment.TFPyEnvironment(env)
        policy = driver_test_utils.TFPolicyMock(tf_env.time_step_spec(),
                                                tf_env.action_spec())

        replay_buffer_observer = MockReplayBufferObserver()
        driver = tf_driver.TFDriver(
            tf_env,
            policy,
            observers=[replay_buffer_observer],
            max_steps=None,
            max_episodes=1,
        )

        time_step = tf_env.reset()
        policy_state = policy.get_initial_state(batch_size=1)
        for _ in range(num_episodes):
            time_step, policy_state = self.evaluate(
                driver.run(time_step, policy_state))
        trajectories = replay_buffer_observer.gather_all()
        self.assertEqual(trajectories, self._trajectories[:num_expected_steps])
Example No. 4
    def testOneStepUpdatesObservers(self):
        env = tf_py_environment.TFPyEnvironment(
            driver_test_utils.PyEnvironmentMock())
        policy = driver_test_utils.TFPolicyMock(env.time_step_spec(),
                                                env.action_spec())
        num_episodes_observer = driver_test_utils.NumEpisodesObserver()
        num_steps_observer = driver_test_utils.NumStepsObserver()
        num_steps_transition_observer = (
            driver_test_utils.NumStepsTransitionObserver())

        driver = dynamic_episode_driver.DynamicEpisodeDriver(
            env,
            policy,
            observers=[num_episodes_observer, num_steps_observer],
            transition_observers=[num_steps_transition_observer])

        self.evaluate(tf.compat.v1.global_variables_initializer())
        for _ in range(5):
            self.evaluate(driver.run())

        self.assertEqual(self.evaluate(num_episodes_observer.num_episodes), 5)
        # Two steps per episode.
        self.assertEqual(self.evaluate(num_steps_observer.num_steps), 10)
        self.assertEqual(
            self.evaluate(num_steps_transition_observer.num_steps), 10)
Example No. 5
    def testTwoStepObservers(self):
        env = tf_py_environment.TFPyEnvironment(
            driver_test_utils.PyEnvironmentMock())
        policy = driver_test_utils.TFPolicyMock(env.time_step_spec(),
                                                env.action_spec())
        num_episodes_observer0 = driver_test_utils.NumEpisodesObserver(
            variable_scope='observer0')
        num_episodes_observer1 = driver_test_utils.NumEpisodesObserver(
            variable_scope='observer1')
        num_steps_transition_observer = (
            driver_test_utils.NumStepsTransitionObserver())

        driver = dynamic_episode_driver.DynamicEpisodeDriver(
            env,
            policy,
            num_episodes=5,
            observers=[num_episodes_observer0, num_episodes_observer1],
            transition_observers=[num_steps_transition_observer])
        run_driver = driver.run()

        self.evaluate(tf.compat.v1.global_variables_initializer())
        self.evaluate(run_driver)
        self.assertEqual(self.evaluate(num_episodes_observer0.num_episodes), 5)
        self.assertEqual(self.evaluate(num_episodes_observer1.num_episodes), 5)
        self.assertEqual(
            self.evaluate(num_steps_transition_observer.num_steps), 10)
Example No. 6
    def test_with_dynamic_step_driver(self):
        env = driver_test_utils.PyEnvironmentMock()
        tf_env = tf_py_environment.TFPyEnvironment(env)
        policy = driver_test_utils.TFPolicyMock(tf_env.time_step_spec(),
                                                tf_env.action_spec())

        trajectory_spec = trajectory.from_transition(tf_env.time_step_spec(),
                                                     policy.policy_step_spec,
                                                     tf_env.time_step_spec())

        tfrecord_observer = example_encoding_dataset.TFRecordObserver(
            self.dataset_path, trajectory_spec)
        driver = dynamic_step_driver.DynamicStepDriver(
            tf_env,
            policy,
            observers=[common.function(tfrecord_observer)],
            num_steps=10)
        self.evaluate(tf.compat.v1.global_variables_initializer())

        time_step = self.evaluate(tf_env.reset())
        initial_policy_state = policy.get_initial_state(batch_size=1)
        self.evaluate(
            common.function(driver.run)(time_step, initial_policy_state))
        tfrecord_observer.flush()
        tfrecord_observer.close()

        dataset = example_encoding_dataset.load_tfrecord_dataset(
            [self.dataset_path], buffer_size=2, as_trajectories=True)
        iterator = eager_utils.dataset_iterator(dataset)
        sample = self.evaluate(eager_utils.get_next(iterator))
        self.assertIsInstance(sample, trajectory.Trajectory)
Example No. 7
    def test_with_py_driver(self):
        env = driver_test_utils.PyEnvironmentMock()
        policy = driver_test_utils.PyPolicyMock(env.time_step_spec(),
                                                env.action_spec())
        trajectory_spec = trajectory.from_transition(env.time_step_spec(),
                                                     policy.policy_step_spec,
                                                     env.time_step_spec())
        trajectory_spec = tensor_spec.from_spec(trajectory_spec)

        tfrecord_observer = example_encoding_dataset.TFRecordObserver(
            self.dataset_path, trajectory_spec, py_mode=True)

        driver = py_driver.PyDriver(env,
                                    policy, [tfrecord_observer],
                                    max_steps=10)
        time_step = env.reset()
        driver.run(time_step)
        tfrecord_observer.flush()
        tfrecord_observer.close()

        dataset = example_encoding_dataset.load_tfrecord_dataset(
            [self.dataset_path], buffer_size=2, as_trajectories=True)

        iterator = eager_utils.dataset_iterator(dataset)
        sample = self.evaluate(eager_utils.get_next(iterator))
        self.assertIsInstance(sample, trajectory.Trajectory)
Example No. 8
    def testToTransitionHandlesTrajectoryFromDriverCorrectly(self):
        env = tf_py_environment.TFPyEnvironment(
            drivers_test_utils.PyEnvironmentMock())
        policy = drivers_test_utils.TFPolicyMock(env.time_step_spec(),
                                                 env.action_spec())
        replay_buffer = drivers_test_utils.make_replay_buffer(policy)

        driver = dynamic_episode_driver.DynamicEpisodeDriver(
            env, policy, num_episodes=3, observers=[replay_buffer.add_batch])

        run_driver = driver.run()
        rb_gather_all = replay_buffer.gather_all()

        self.evaluate(tf.compat.v1.global_variables_initializer())
        self.evaluate(run_driver)
        trajectories = self.evaluate(rb_gather_all)

        time_steps, policy_step, next_time_steps = trajectory.to_transition(
            trajectories)

        self.assertAllEqual(time_steps.observation,
                            trajectories.observation[:, :-1])
        self.assertAllEqual(time_steps.step_type,
                            trajectories.step_type[:, :-1])
        self.assertAllEqual(next_time_steps.observation,
                            trajectories.observation[:, 1:])
        self.assertAllEqual(next_time_steps.step_type,
                            trajectories.step_type[:, 1:])
        self.assertAllEqual(next_time_steps.reward,
                            trajectories.reward[:, :-1])
        self.assertAllEqual(next_time_steps.discount,
                            trajectories.discount[:, :-1])

        self.assertAllEqual(policy_step.action, trajectories.action[:, :-1])
        self.assertAllEqual(policy_step.info, trajectories.policy_info[:, :-1])
Example No. 9
    def testMultiStepReplayBufferObservers(self):
        env = tf_py_environment.TFPyEnvironment(
            driver_test_utils.PyEnvironmentMock())
        policy = driver_test_utils.TFPolicyMock(env.time_step_spec(),
                                                env.action_spec())
        replay_buffer = driver_test_utils.make_replay_buffer(policy)

        driver = dynamic_episode_driver.DynamicEpisodeDriver(
            env, policy, num_episodes=3, observers=[replay_buffer.add_batch])

        run_driver = driver.run()
        rb_gather_all = replay_buffer.gather_all()

        self.evaluate(tf.compat.v1.global_variables_initializer())
        self.evaluate(run_driver)
        trajectories = self.evaluate(rb_gather_all)

        self.assertAllEqual(trajectories.step_type,
                            [[0, 1, 2, 0, 1, 2, 0, 1, 2]])
        self.assertAllEqual(trajectories.action, [[1, 2, 1, 1, 2, 1, 1, 2, 1]])
        self.assertAllEqual(trajectories.observation,
                            [[0, 1, 3, 0, 1, 3, 0, 1, 3]])
        self.assertAllEqual(trajectories.policy_info,
                            [[2, 4, 2, 2, 4, 2, 2, 4, 2]])
        self.assertAllEqual(trajectories.next_step_type,
                            [[1, 2, 0, 1, 2, 0, 1, 2, 0]])
        self.assertAllEqual(trajectories.reward,
                            [[1., 1., 0., 1., 1., 0., 1., 1., 0.]])
        self.assertAllEqual(trajectories.discount,
                            [[1., 0., 1., 1., 0., 1., 1., 0., 1.]])
Example No. 10
    def testRunOnce(self, max_steps, max_episodes, expected_steps):
        env = driver_test_utils.PyEnvironmentMock()
        policy = driver_test_utils.PyPolicyMock(env.time_step_spec(),
                                                env.action_spec())
        replay_buffer_observer = MockReplayBufferObserver()
        transition_replay_buffer_observer = MockReplayBufferObserver()
        driver = py_driver.PyDriver(
            env,
            policy,
            observers=[replay_buffer_observer],
            transition_observers=[transition_replay_buffer_observer],
            max_steps=max_steps,
            max_episodes=max_episodes,
        )

        initial_time_step = env.reset()
        initial_policy_state = policy.get_initial_state()
        driver.run(initial_time_step, initial_policy_state)
        trajectories = replay_buffer_observer.gather_all()
        self.assertEqual(trajectories, self._trajectories[:expected_steps])

        transitions = transition_replay_buffer_observer.gather_all()
        self.assertLen(transitions, expected_steps)
        # TimeStep, Action, NextTimeStep
        self.assertLen(transitions[0], 3)
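As the final assertion above indicates, PyDriver calls each transition observer with a 3-tuple of (time_step, action_step, next_time_step). A hypothetical transition observer that records only the rewards of each transition could be wired in the same way as the mock buffer:

def make_reward_logger(rewards):
    """Hypothetical transition observer that appends next-step rewards."""

    def observer(transition):
        # PyDriver passes (TimeStep, PolicyStep, NextTimeStep) tuples.
        _, _, next_time_step = transition
        rewards.append(next_time_step.reward)

    return observer

It would be passed to the driver as transition_observers=[make_reward_logger(rewards)].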
Example No. 11
    def testToTransitionSpec(self):
        env = tf_py_environment.TFPyEnvironment(
            drivers_test_utils.PyEnvironmentMock())
        policy = drivers_test_utils.TFPolicyMock(env.time_step_spec(),
                                                 env.action_spec())
        trajectory_spec = policy.trajectory_spec
        ts_spec, ps_spec, nts_spec = trajectory.to_transition_spec(
            trajectory_spec)

        self.assertAllEqual(ts_spec, env.time_step_spec())
        self.assertAllEqual(ps_spec.action, env.action_spec())
        self.assertAllEqual(nts_spec, env.time_step_spec())
Example No. 12
    def testValueErrorOnInvalidArgs(self, max_steps, max_episodes):
        env = driver_test_utils.PyEnvironmentMock()
        policy = driver_test_utils.PyPolicyMock(env.time_step_spec(),
                                                env.action_spec())
        replay_buffer_observer = MockReplayBufferObserver()
        with self.assertRaises(ValueError):
            py_driver.PyDriver(
                env,
                policy,
                observers=[replay_buffer_observer],
                max_steps=max_steps,
                max_episodes=max_episodes,
            )
Example No. 13
    def testMultiStepUpdatesObservers(self):
        env = tf_py_environment.TFPyEnvironment(
            driver_test_utils.PyEnvironmentMock())
        policy = driver_test_utils.TFPolicyMock(env.time_step_spec(),
                                                env.action_spec())
        num_episodes_observer = driver_test_utils.NumEpisodesObserver()

        driver = dynamic_step_driver.DynamicStepDriver(
            env, policy, num_steps=5, observers=[num_episodes_observer])

        run_driver = driver.run()

        self.evaluate(tf.compat.v1.global_variables_initializer())
        self.evaluate(run_driver)
        self.assertEqual(self.evaluate(num_episodes_observer.num_episodes), 2)
Example No. 14
    def testMultiStepUpdatesObservers(self):
        env = tf_py_environment.TFPyEnvironment(
            driver_test_utils.PyEnvironmentMock())
        policy = driver_test_utils.TFPolicyMock(env.time_step_spec(),
                                                env.action_spec())
        num_episodes_observer = NumEpisodesObserver()
        num_steps_observer = NumStepsObserver()

        driver = dynamic_episode_driver.DynamicEpisodeDriver(
            env, policy, observers=[num_episodes_observer, num_steps_observer])

        run_driver = driver.run(num_episodes=5)

        self.evaluate(tf.compat.v1.global_variables_initializer())
        self.evaluate(run_driver)
        self.assertEqual(self.evaluate(num_episodes_observer.num_episodes), 5)
        # Two steps per episode.
        self.assertEqual(self.evaluate(num_steps_observer.num_steps), 10)
Example No. 15
    def testPolicyState(self):
        env = tf_py_environment.TFPyEnvironment(
            driver_test_utils.PyEnvironmentMock())
        policy = driver_test_utils.TFPolicyMock(env.time_step_spec(),
                                                env.action_spec())

        num_episodes_observer = driver_test_utils.NumEpisodesObserver()
        num_steps_observer = driver_test_utils.NumStepsObserver()

        driver = dynamic_episode_driver.DynamicEpisodeDriver(
            env, policy, observers=[num_episodes_observer, num_steps_observer])
        run_driver = driver.run()

        self.evaluate(tf.compat.v1.global_variables_initializer())

        time_step, policy_state = self.evaluate(run_driver)

        self.assertEqual(time_step.step_type, 0)
        self.assertEqual(policy_state, [3])
Example No. 16
    def testOneStepReplayBufferObservers(self):
        if tf.executing_eagerly():
            self.skipTest('b/123880556')

        env = tf_py_environment.TFPyEnvironment(
            driver_test_utils.PyEnvironmentMock())
        policy = driver_test_utils.TFPolicyMock(env.time_step_spec(),
                                                env.action_spec())
        policy_state_ph = tensor_spec.to_nest_placeholder(
            policy.policy_state_spec,
            default=0,
            name_scope='policy_state_ph',
            outer_dims=(1, ))
        replay_buffer = driver_test_utils.make_replay_buffer(policy)

        driver = dynamic_step_driver.DynamicStepDriver(
            env, policy, num_steps=1, observers=[replay_buffer.add_batch])

        run_driver = driver.run(policy_state=policy_state_ph)
        rb_gather_all = replay_buffer.gather_all()

        with self.session() as session:
            session.run(tf.compat.v1.global_variables_initializer())
            _, policy_state = session.run(run_driver)
            for _ in range(5):
                _, policy_state = session.run(
                    run_driver, feed_dict={policy_state_ph: policy_state})

            trajectories = self.evaluate(rb_gather_all)

        self.assertAllEqual(trajectories.step_type, [[0, 1, 2, 0, 1, 2, 0, 1]])
        self.assertAllEqual(trajectories.observation,
                            [[0, 1, 3, 0, 1, 3, 0, 1]])
        self.assertAllEqual(trajectories.action, [[1, 2, 1, 1, 2, 1, 1, 2]])
        self.assertAllEqual(trajectories.policy_info,
                            [[2, 4, 2, 2, 4, 2, 2, 4]])
        self.assertAllEqual(trajectories.next_step_type,
                            [[1, 2, 0, 1, 2, 0, 1, 2]])
        self.assertAllEqual(trajectories.reward,
                            [[1., 1., 0., 1., 1., 0., 1., 1.]])
        self.assertAllEqual(trajectories.discount,
                            [[1., 0., 1., 1., 0., 1., 1., 0.]])
Example No. 17
    def testOneStepUpdatesObservers(self):
        if tf.executing_eagerly():
            self.skipTest('b/123880410')
        env = tf_py_environment.TFPyEnvironment(
            driver_test_utils.PyEnvironmentMock())
        policy = driver_test_utils.TFPolicyMock(env.time_step_spec(),
                                                env.action_spec())
        num_episodes_observer = driver_test_utils.NumEpisodesObserver()
        num_steps_observer = driver_test_utils.NumStepsObserver()

        driver = dynamic_episode_driver.DynamicEpisodeDriver(
            env, policy, observers=[num_episodes_observer, num_steps_observer])
        run_driver = driver.run()

        self.evaluate(tf.compat.v1.global_variables_initializer())
        for _ in range(5):
            self.evaluate(run_driver)

        self.assertEqual(self.evaluate(num_episodes_observer.num_episodes), 5)
        # Two steps per episode.
        self.assertEqual(self.evaluate(num_steps_observer.num_steps), 10)
Example No. 18
  def testPolicyStateReset(self):
    num_episodes = 2
    num_expected_steps = 6

    env = driver_test_utils.PyEnvironmentMock()
    policy = driver_test_utils.PyPolicyMock(env.time_step_spec(),
                                            env.action_spec())
    replay_buffer_observer = MockReplayBufferObserver()
    driver = py_driver.PyDriver(
        env,
        policy,
        observers=[replay_buffer_observer],
        max_steps=None,
        max_episodes=num_episodes,
    )

    time_step = env.reset()
    policy_state = policy.get_initial_state()
    time_step, policy_state = driver.run(time_step, policy_state)
    trajectories = replay_buffer_observer.gather_all()
    self.assertEqual(trajectories, self._trajectories[:num_expected_steps])
    self.assertEqual(num_episodes, policy.get_initial_state_call_count)
Example No. 19
    def testTwoObservers(self):
        env = tf_py_environment.TFPyEnvironment(
            driver_test_utils.PyEnvironmentMock())
        policy = driver_test_utils.TFPolicyMock(env.time_step_spec(),
                                                env.action_spec())
        policy_state = policy.get_initial_state(1)
        num_episodes_observer0 = NumEpisodesObserver(
            variable_scope='observer0')
        num_episodes_observer1 = NumEpisodesObserver(
            variable_scope='observer1')

        driver = dynamic_step_driver.DynamicStepDriver(
            env,
            policy,
            num_steps=5,
            observers=[num_episodes_observer0, num_episodes_observer1])
        run_driver = driver.run(policy_state=policy_state)

        self.evaluate(tf.compat.v1.global_variables_initializer())
        self.evaluate(run_driver)
        self.assertEqual(self.evaluate(num_episodes_observer0.num_episodes), 2)
        self.assertEqual(self.evaluate(num_episodes_observer1.num_episodes), 2)
Example No. 20
  def testMultipleRunMaxSteps(self):
    num_steps = 3
    num_expected_steps = 4

    env = driver_test_utils.PyEnvironmentMock()
    policy = driver_test_utils.PyPolicyMock(env.time_step_spec(),
                                            env.action_spec())
    replay_buffer_observer = MockReplayBufferObserver()
    driver = py_driver.PyDriver(
        env,
        policy,
        observers=[replay_buffer_observer],
        max_steps=1,
        max_episodes=None,
    )

    time_step = env.reset()
    policy_state = policy.get_initial_state()
    for _ in range(num_steps):
      time_step, policy_state = driver.run(time_step, policy_state)
    trajectories = replay_buffer_observer.gather_all()
    self.assertEqual(trajectories, self._trajectories[:num_expected_steps])
Example No. 21
    def testMultiStepEpisodicReplayBuffer(self):
        num_episodes = 5
        num_driver_episodes = 5

        # Create mock environment.
        py_env = batched_py_environment.BatchedPyEnvironment([
            driver_test_utils.PyEnvironmentMock(final_state=i + 1)
            for i in range(num_episodes)
        ])
        env = tf_py_environment.TFPyEnvironment(py_env)

        # Create mock policy.
        policy = driver_test_utils.TFPolicyMock(env.time_step_spec(),
                                                env.action_spec(),
                                                batch_size=num_episodes)

        # Create replay buffer and driver.
        replay_buffer = self._make_replay_buffer(env)
        stateful_buffer = episodic_replay_buffer.StatefulEpisodicReplayBuffer(
            replay_buffer, num_episodes)
        driver = dynamic_episode_driver.DynamicEpisodeDriver(
            env,
            policy,
            num_episodes=num_driver_episodes,
            observers=[stateful_buffer.add_batch])

        run_driver = driver.run()

        end_episodes = replay_buffer._maybe_end_batch_episodes(
            stateful_buffer.episode_ids, end_episode=True)

        completed_episodes = replay_buffer._completed_episodes()

        self.evaluate([
            tf.compat.v1.local_variables_initializer(),
            tf.compat.v1.global_variables_initializer()
        ])

        self.evaluate(run_driver)

        self.evaluate(end_episodes)
        completed_episodes = self.evaluate(completed_episodes)
        eps = [replay_buffer._get_episode(ep) for ep in completed_episodes]
        eps = self.evaluate(eps)

        episodes_length = [tf.nest.flatten(ep)[0].shape[0] for ep in eps]

        # Compare with expected output.
        self.assertAllEqual(completed_episodes, [3, 4, 5, 6, 7])
        self.assertAllEqual(episodes_length, [4, 4, 2, 1, 1])

        first = ts.StepType.FIRST
        mid = ts.StepType.MID
        last = ts.StepType.LAST

        step_types = [ep.step_type for ep in eps]
        observations = [ep.observation for ep in eps]
        rewards = [ep.reward for ep in eps]
        actions = [ep.action for ep in eps]

        self.assertAllClose([[first, mid, mid, last], [first, mid, mid, mid],
                             [first, last], [first], [first]], step_types)

        self.assertAllClose([
            [0, 1, 3, 4],
            [0, 1, 3, 4],
            [0, 1],
            [0],
            [0],
        ], observations)

        self.assertAllClose([
            [1, 2, 1, 2],
            [1, 2, 1, 2],
            [1, 2],
            [1],
            [1],
        ], actions)

        self.assertAllClose([
            [1, 1, 1, 0],
            [1, 1, 1, 1],
            [1, 0],
            [1],
            [1],
        ], rewards)