def test_a2ctrainer_cartpole(self):
  """Test-runs a2c on cartpole."""
  task = rl_task.RLTask('CartPole-v0', initial_trajectories=1, max_steps=2)
  policy_model = lambda mode: tl.Serial(  # pylint: disable=g-long-lambda
      tl.Dense(64), tl.Relu(), tl.Dense(2), tl.LogSoftmax())
  value_model = lambda mode: tl.Serial(  # pylint: disable=g-long-lambda
      tl.Dense(64), tl.Relu(), tl.Dense(1))
  lr = lambda h: lr_schedules.MultifactorSchedule(  # pylint: disable=g-long-lambda
      h, constant=1e-4, warmup_steps=100, factors='constant * linear_warmup')
  trainer = actor_critic.AdvantageActorCriticTrainer(
      task,
      n_shared_layers=1,
      value_model=value_model,
      value_optimizer=opt.Adam,
      value_lr_schedule=lr,
      value_batch_size=2,
      value_train_steps_per_epoch=2,
      policy_model=policy_model,
      policy_optimizer=opt.Adam,
      policy_lr_schedule=lr,
      policy_batch_size=2,
      policy_train_steps_per_epoch=2,
      collect_per_epoch=2)
  trainer.run(2)
  self.assertEqual(2, trainer.current_epoch)
def test_trajectory_batch_stream_shape(self):
  """Test the shape of batches yielded by the trajectory batch stream."""
  task = rl_task.RLTask(DummyEnv(), initial_trajectories=1, max_steps=10)
  batch_stream = task.trajectory_batch_stream(
      batch_size=3, min_slice_length=4, max_slice_length=4)
  batch = next(batch_stream)
  self.assertEqual(batch.observations.shape, (3, 4, 2))
def test_policytrainer_cartpole(self):
  """Trains a policy on cartpole."""
  task = rl_task.RLTask('CartPole-v0', initial_trajectories=1, max_steps=200)
  model = functools.partial(
      models.Policy,
      body=lambda mode: tl.Serial(  # pylint: disable=g-long-lambda
          tl.Dense(64), tl.Relu(), tl.Dense(64), tl.Relu()
      ),
  )
  lr = lambda: lr_schedules.multifactor(  # pylint: disable=g-long-lambda
      constant=1e-2, warmup_steps=100, factors='constant * linear_warmup')
  max_avg_returns = -math.inf
  for _ in range(5):
    trainer = training.PolicyGradient(
        task,
        policy_model=model,
        policy_optimizer=opt.Adam,
        policy_lr_schedule=lr,
        policy_batch_size=128,
        policy_train_steps_per_epoch=1,
        n_trajectories_per_epoch=2)
    # Assert that we get to 200 at some point and then exit so the test is as
    # fast as possible.
    for ep in range(200):
      trainer.run(1)
      self.assertEqual(trainer.current_epoch, ep + 1)
      if trainer.avg_returns[-1] == 200.0:
        return
    max_avg_returns = max(max_avg_returns, trainer.avg_returns[-1])
  self.fail(
      'The expected score of 200 has not been reached. '
      'Maximum at end was {}.'.format(max_avg_returns)
  )
def test_trajectory_stream_sampling_uniform(self):
  """Test if the trajectory stream samples uniformly."""
  # Long trajectory of 0s.
  tr1 = rl_task.Trajectory(0)
  for _ in range(100):
    tr1.extend(
        action=0, dist_inputs=0, reward=0, done=False, new_observation=0
    )
  tr1.extend(
      action=0, dist_inputs=0, reward=0, done=True, new_observation=200
  )
  # Short trajectory of 101.
  tr2 = rl_task.Trajectory(101)
  tr2.extend(
      action=0, dist_inputs=0, reward=0, done=True, new_observation=200
  )
  task = rl_task.RLTask(
      DummyEnv(), initial_trajectories=[tr1, tr2], max_steps=9)
  # Stream of both. Check that we're sampling by slice, not by trajectory.
  stream = task.trajectory_stream(max_slice_length=1)
  slices = []
  for _ in range(10):
    next_slice = next(stream)
    assert len(next_slice) == 1
    slices.append(next_slice.last_observation)
  mean_obs = sum(slices) / float(len(slices))
  # The average should be around 1 when sampling uniformly from a hundred 0s
  # and a single 101.
  self.assertLess(mean_obs, 31)  # Sampling 101 even 3 times is unlikely.
  self.assertLen(slices, 10)
def test_trajectory_stream_margin(self):
  """Test trajectory stream with an added margin."""
  tr1 = rl_task.Trajectory(0)
  tr1.extend(
      action=0, dist_inputs=0, reward=0, done=False, new_observation=1
  )
  tr1.extend(
      action=1, dist_inputs=2, reward=3, done=True, new_observation=1
  )
  task = rl_task.RLTask(DummyEnv(), initial_trajectories=[tr1], max_steps=9)
  # Stream of slices without the final state.
  stream1 = task.trajectory_stream(
      max_slice_length=3, margin=2, include_final_state=False)
  got_done = False
  for _ in range(10):
    next_slice = next(stream1)
    self.assertLen(next_slice, 3)
    if next_slice.timesteps[0].done:
      for i in range(1, 3):
        self.assertTrue(next_slice.timesteps[i].done)
        self.assertFalse(next_slice.timesteps[i].mask)
      got_done = True
  # Assert that we got a done somewhere, otherwise the test is not triggered.
  # Not getting a done has low probability (about 1/2^10), but it is possible,
  # so this test can be flaky.
  self.assertTrue(got_done)
def test_trajectory_stream_final_state(self):
  """Test trajectory stream with and without the final state."""
  tr1 = rl_task.Trajectory(0)
  tr1.extend(
      action=0, dist_inputs=0, reward=0, done=True, new_observation=1
  )
  task = rl_task.RLTask(DummyEnv(), initial_trajectories=[tr1], max_steps=9)

  # Stream of slices without the final state.
  stream1 = task.trajectory_stream(
      max_slice_length=1, include_final_state=False)
  for _ in range(10):
    next_slice = next(stream1)
    self.assertLen(next_slice, 1)
    self.assertEqual(next_slice.last_observation, 0)

  # Stream of slices with the final state.
  stream2 = task.trajectory_stream(
      max_slice_length=1, include_final_state=True)
  all_sum = 0
  for _ in range(100):
    next_slice = next(stream2)
    self.assertLen(next_slice, 1)
    all_sum += next_slice.last_observation
  self.assertEqual(min(all_sum, 1), 1)  # We've seen the end at least once.
def test_sanity_a2ctrainer_cartpole(self):
  """Test-runs a2c on cartpole."""
  task = rl_task.RLTask('CartPole-v0', initial_trajectories=0, max_steps=2)
  body = lambda mode: tl.Serial(tl.Dense(64), tl.Relu())
  policy_model = functools.partial(models.Policy, body=body)
  value_model = functools.partial(models.Value, body=body)
  lr = lambda: lr_schedules.multifactor(  # pylint: disable=g-long-lambda
      constant=1e-4, warmup_steps=100, factors='constant * linear_warmup')
  trainer = actor_critic.A2C(
      task,
      n_shared_layers=1,
      value_model=value_model,
      value_optimizer=opt.Adam,
      value_lr_schedule=lr,
      value_batch_size=2,
      value_train_steps_per_epoch=2,
      policy_model=policy_model,
      policy_optimizer=opt.Adam,
      policy_lr_schedule=lr,
      policy_batch_size=2,
      policy_train_steps_per_epoch=2,
      n_trajectories_per_epoch=2)
  trainer.run(2)
  self.assertEqual(2, trainer.current_epoch)
def test_policytrainer_cartpole(self):
  """Trains a policy on cartpole."""
  task = rl_task.RLTask('CartPole-v0', initial_trajectories=1, max_steps=200)
  # TODO(pkozakowski): Use Distribution.n_inputs to initialize the action
  # head.
  model = lambda mode: tl.Serial(  # pylint: disable=g-long-lambda
      tl.Dense(64), tl.Relu(), tl.Dense(64), tl.Relu(),
      tl.Dense(2), tl.LogSoftmax())
  lr = lambda h: lr_schedules.MultifactorSchedule(  # pylint: disable=g-long-lambda
      h, constant=1e-2, warmup_steps=100, factors='constant * linear_warmup')
  trainer = training.PolicyGradientTrainer(
      task,
      policy_model=model,
      policy_optimizer=opt.Adam,
      policy_lr_schedule=lr,
      policy_batch_size=128,
      policy_train_steps_per_epoch=1,
      collect_per_epoch=2)
  # Assert that we get to 200 at some point and then exit so the test is as
  # fast as possible.
  for ep in range(200):
    trainer.run(1)
    self.assertEqual(trainer.current_epoch, ep + 1)
    if trainer.avg_returns[-1] == 200.0:
      return
  self.fail(
      'The expected score of 200 has not been reached. '
      'Maximum was {}.'.format(max(trainer.avg_returns))
  )
def test_sampling_awrtrainer_mountain_car(self):
  """Test-runs Sampling AWR on MountainCarContinuous."""
  task = rl_task.RLTask(
      'MountainCarContinuous-v0', initial_trajectories=0, max_steps=2)
  body = lambda mode: tl.Serial(tl.Dense(2), tl.Relu())
  policy_model = functools.partial(models.Policy, body=body)
  value_model = functools.partial(models.Value, body=body)
  lr = lambda: lr_schedules.multifactor(  # pylint: disable=g-long-lambda
      constant=1e-2, warmup_steps=100, factors='constant * linear_warmup')
  trainer = actor_critic.SamplingAWR(
      task,
      n_shared_layers=0,
      added_policy_slice_length=1,
      value_model=value_model,
      value_optimizer=opt.Adam,
      value_lr_schedule=lr,
      value_batch_size=2,
      value_train_steps_per_epoch=2,
      policy_model=policy_model,
      policy_optimizer=opt.Adam,
      policy_lr_schedule=lr,
      policy_batch_size=2,
      policy_train_steps_per_epoch=2,
      n_trajectories_per_epoch=2,
      advantage_estimator=advantages.monte_carlo,
      advantage_normalization=False,
      q_value_n_samples=3,
  )
  trainer.run(1)
  self.assertEqual(1, trainer.current_epoch)
def test_policy_gradient_cartpole(self):
  """Trains a policy on cartpole."""
  task = rl_task.RLTask('CartPole-v0', max_steps=200)
  lr = lambda: lr_schedules.multifactor(constant=1e-2, factors='constant')
  max_avg_returns = -math.inf
  for _ in range(2):
    agent = training.PolicyGradient(
        task,
        model_fn=self._model_fn,
        optimizer=opt.Adam,
        lr_schedule=lr,
        batch_size=128,
        n_trajectories_per_epoch=2,
    )
    # Assert that we get to 200 at some point and then exit so the test is as
    # fast as possible.
    for ep in range(200):
      agent.run(1)
      self.assertEqual(agent.current_epoch, ep + 1)
      if agent.avg_returns[-1] == 200.0:
        return
    max_avg_returns = max(max_avg_returns, agent.avg_returns[-1])
  self.fail('The expected score of 200 has not been reached. '
            'Maximum at end was {}.'.format(max_avg_returns))
def test_task_random_initial_trajectories_and_max_steps(self):
  """Test generating initial random trajectories, stop at max steps."""
  task = rl_task.RLTask(DummyEnv(), initial_trajectories=1, max_steps=9)
  stream = task.trajectory_stream(max_slice_length=1)
  next_slice = next(stream)
  self.assertLen(next_slice, 1)
  self.assertEqual(next_slice.last_observation.shape, (2,))
def test_awrjoint_save_restore(self):
  """Check save and restore of joint AWR trainer."""
  task = rl_task.RLTask('CartPole-v0', initial_trajectories=100,
                        max_steps=200)
  joint_model = functools.partial(
      models.PolicyAndValue,
      body=lambda mode: tl.Serial(tl.Dense(64), tl.Relu()),
  )
  tmp_dir = self.create_tempdir().full_path
  trainer1 = actor_critic_joint.AWRJointTrainer(
      task,
      joint_model=joint_model,
      optimizer=opt.Adam,
      batch_size=4,
      train_steps_per_epoch=1,
      n_trajectories_per_epoch=2,
      output_dir=tmp_dir)
  trainer1.run(2)
  self.assertEqual(trainer1.current_epoch, 2)
  self.assertEqual(trainer1._trainer.step, 2)
  # Trainer 2 starts where trainer 1 stopped.
  trainer2 = actor_critic_joint.AWRJointTrainer(
      task,
      joint_model=joint_model,
      optimizer=opt.Adam,
      batch_size=4,
      train_steps_per_epoch=1,
      n_trajectories_per_epoch=2,
      output_dir=tmp_dir)
  trainer2.run(1)
  self.assertEqual(trainer2.current_epoch, 3)
  self.assertEqual(trainer2._trainer.step, 3)
  trainer1.close()
  trainer2.close()
def test_sanity_awrtrainer_transformer_cartpole(self):
  """Test-runs AWR on cartpole with Transformer."""
  task = rl_task.RLTask('CartPole-v0', initial_trajectories=2, max_steps=2)
  body = lambda mode: models.TransformerDecoder(  # pylint: disable=g-long-lambda
      d_model=2, d_ff=2, n_layers=1, n_heads=1, mode=mode)
  policy_model = functools.partial(models.Policy, body=body)
  value_model = functools.partial(models.Value, body=body)
  lr = lambda: lr_schedules.multifactor(  # pylint: disable=g-long-lambda
      constant=1e-2, warmup_steps=100, factors='constant * linear_warmup')
  trainer = actor_critic.AWR(
      task,
      n_shared_layers=0,
      max_slice_length=2,
      added_policy_slice_length=1,
      value_model=value_model,
      value_optimizer=opt.Adam,
      value_lr_schedule=lr,
      value_batch_size=2,
      value_train_steps_per_epoch=2,
      policy_model=policy_model,
      policy_optimizer=opt.Adam,
      policy_lr_schedule=lr,
      policy_batch_size=2,
      policy_train_steps_per_epoch=2,
      n_trajectories_per_epoch=1,
      n_eval_episodes=1)
  trainer.run(2)
  self.assertEqual(2, trainer.current_epoch)
def test_sanity_ppo_cartpole(self):
  """Run PPO and check whether it correctly runs for 2 epochs."""
  task = rl_task.RLTask('CartPole-v1', initial_trajectories=0, max_steps=200)
  lr = lambda: lr_schedules.multifactor(  # pylint: disable=g-long-lambda
      constant=1e-3, warmup_steps=100, factors='constant * linear_warmup')
  body = lambda mode: tl.Serial(tl.Dense(64), tl.Relu())
  policy_model = functools.partial(models.Policy, body=body)
  value_model = functools.partial(models.Value, body=body)
  trainer = actor_critic.PPO(
      task,
      n_shared_layers=1,
      value_model=value_model,
      value_optimizer=opt.Adam,
      value_lr_schedule=lr,
      value_batch_size=128,
      value_train_steps_per_epoch=10,
      policy_model=policy_model,
      policy_optimizer=opt.Adam,
      policy_lr_schedule=lr,
      policy_batch_size=128,
      policy_train_steps_per_epoch=10,
      n_trajectories_per_epoch=10)
  trainer.run(2)
  self.assertEqual(2, trainer.current_epoch)
def test_trajectory_slice_stream_margin(self):
  """Test trajectory stream with an added margin."""
  tr1 = rl_task.Trajectory(0)
  self._extend(tr1, new_observation=1)
  self._extend(tr1, new_observation=1)
  self._extend(
      tr1, new_observation=1, action=1, dist_inputs=2, reward=3, done=True
  )
  task = rl_task.RLTask(DummyEnv(), initial_trajectories=[tr1], max_steps=9)
  # Stream of slices without the final state.
  stream1 = task.trajectory_slice_stream(max_slice_length=4, margin=3)
  got_done = False
  for _ in range(20):
    next_slice = next(stream1)
    self.assertEqual(next_slice.observation.shape, (4,))
    if next_slice.done[0]:
      # In the slice, first we have the last timestep in the actual
      # trajectory, so observation = 1.
      # Then comes the first timestep in the margin, which has the final
      # observation from the trajectory: observation = 1.
      # The remaining timesteps have 0 observations.
      np.testing.assert_array_equal(next_slice.observation, [1, 1, 0, 0])
      # In the margin, done = True and mask = 0.
      for i in range(1, next_slice.observation.shape[0]):
        self.assertTrue(next_slice.done[i])
        self.assertFalse(next_slice.mask[i])
      got_done = True
  # Assert that we got a done somewhere, otherwise the test is not triggered.
  # Not getting a done has low probability (about 1/2^20), but it is possible,
  # so this test can be flaky.
  self.assertTrue(got_done)
def test_jointppotrainer_cartpole(self):
  """Test-runs joint PPO on CartPole."""
  task = rl_task.RLTask('CartPole-v0', initial_trajectories=0, max_steps=2)
  joint_model = functools.partial(
      models.PolicyAndValue,
      body=lambda mode: tl.Serial(tl.Dense(2), tl.Relu()),
  )
  lr = lambda h: lr_schedules.MultifactorSchedule(  # pylint: disable=g-long-lambda
      h, constant=1e-2, warmup_steps=100, factors='constant * linear_warmup')
  trainer = actor_critic_joint.PPOJointTrainer(
      task,
      joint_model=joint_model,
      optimizer=opt.Adam,
      lr_schedule=lr,
      batch_size=4,
      train_steps_per_epoch=2,
      n_trajectories_per_epoch=5)
  trainer.run(2)
  self.assertEqual(2, trainer.current_epoch)
def test_awrtrainer_cartpole(self):
  """Test-runs AWR on cartpole."""
  task = rl_task.RLTask('CartPole-v0', initial_trajectories=1000,
                        max_steps=200)
  policy_model = lambda mode: tl.Serial(  # pylint: disable=g-long-lambda
      tl.Dense(64), tl.Relu(), tl.Dense(2), tl.LogSoftmax())
  value_model = lambda mode: tl.Serial(  # pylint: disable=g-long-lambda
      tl.Dense(64), tl.Relu(), tl.Dense(1))
  lr = lambda h: lr_schedules.MultifactorSchedule(  # pylint: disable=g-long-lambda
      h, constant=1e-2, warmup_steps=100, factors='constant * linear_warmup')
  trainer = actor_critic.AWRTrainer(
      task,
      n_shared_layers=0,
      value_model=value_model,
      value_optimizer=opt.Adam,
      value_lr_schedule=lr,
      value_batch_size=32,
      value_train_steps_per_epoch=1000,
      policy_model=policy_model,
      policy_optimizer=opt.Adam,
      policy_lr_schedule=lr,
      policy_batch_size=32,
      policy_train_steps_per_epoch=1000,
      collect_per_epoch=10)
  trainer.run(1)
  self.assertEqual(1, trainer.current_epoch)
  self.assertGreater(trainer.avg_returns[-1], 180.0)
def test_awrtrainer_cartpole(self):
  """Test-runs AWR on cartpole."""
  task = rl_task.RLTask('CartPole-v0', initial_trajectories=1000,
                        max_steps=200)
  body = lambda mode: tl.Serial(tl.Dense(64), tl.Relu())
  policy_model = functools.partial(models.Policy, body=body)
  value_model = functools.partial(models.Value, body=body)
  lr = lambda h: lr_schedules.MultifactorSchedule(  # pylint: disable=g-long-lambda
      h, constant=1e-2, warmup_steps=100, factors='constant * linear_warmup')
  trainer = actor_critic.AWRTrainer(
      task,
      n_shared_layers=0,
      added_policy_slice_length=1,
      value_model=value_model,
      value_optimizer=opt.Adam,
      value_lr_schedule=lr,
      value_batch_size=32,
      value_train_steps_per_epoch=200,
      policy_model=policy_model,
      policy_optimizer=opt.Adam,
      policy_lr_schedule=lr,
      policy_batch_size=32,
      policy_train_steps_per_epoch=200,
      n_trajectories_per_epoch=10,
      advantage_estimator=advantages.monte_carlo,
      advantage_normalization=False,
  )
  trainer.run(1)
  self.assertEqual(1, trainer.current_epoch)
  self.assertGreater(trainer.avg_returns[-1], 35.0)
def test_jointawrtrainer_cartpole(self):
  """Test-runs joint AWR on cartpole."""
  task = rl_task.RLTask('CartPole-v0', initial_trajectories=1000,
                        max_steps=200)
  shared_model = lambda mode: tl.Serial(tl.Dense(64), tl.Relu())
  policy_top = lambda mode: tl.Serial(tl.Dense(2), tl.LogSoftmax())
  value_top = lambda mode: tl.Dense(1)
  lr = lambda h: lr_schedules.MultifactorSchedule(  # pylint: disable=g-long-lambda
      h, constant=1e-2, warmup_steps=100, factors='constant * linear_warmup')
  trainer = actor_critic_joint.AWRJointTrainer(
      task,
      shared_model=shared_model,
      policy_top=policy_top,
      value_top=value_top,
      optimizer=opt.Adam,
      lr_schedule=lr,
      batch_size=32,
      train_steps_per_epoch=1000,
      collect_per_epoch=10)
  trainer.run(1)
  self.assertEqual(1, trainer.current_epoch)
def test_trajectory_stream_sampling_by_trajectory(self):
  """Test if the trajectory stream samples by trajectory."""
  # Long trajectory of 0s.
  tr1 = rl_task.Trajectory(0)
  for _ in range(100):
    tr1.extend(0, 0, 0, False, 0)
  tr1.extend(0, 0, 0, True, 200)
  # Short trajectory of 101.
  tr2 = rl_task.Trajectory(101)
  tr2.extend(0, 0, 0, True, 200)
  task = rl_task.RLTask(
      DummyEnv(), initial_trajectories=[tr1, tr2], max_steps=9)
  # Stream of both. Check that we're sampling by trajectory.
  stream = task.trajectory_stream(
      max_slice_length=1, sample_trajectories_uniformly=True)
  slices = []
  for _ in range(10):
    next_slice = next(stream)
    assert len(next_slice) == 1
    slices.append(next_slice.last_observation)
  mean_obs = sum(slices) / float(len(slices))
  # Average should be around 50, sampling from {0, 101} uniformly.
  # Sampling 101 fewer than 2 times has low probability, but it is possible,
  # so this test can be flaky.
  self.assertGreater(mean_obs, 20)
  self.assertLen(slices, 10)
def test_collects_specified_number_of_trajectories(self):
  """Test that the specified number of trajectories is collected."""
  task = rl_task.RLTask(
      DummyEnv(), initial_trajectories=0, max_steps=3, time_limit=20
  )
  task.collect_trajectories(policy=(lambda _: (0, 0)), n_trajectories=3)
  trajectories = task.trajectories[1]  # Get trajectories from epoch 1.
  self.assertLen(trajectories, 3)
def test_trajectory_batch_stream_propagates_env_info(self):
  """Test that env info is propagated through the trajectory batch stream."""
  task = rl_task.RLTask(DummyEnv(), initial_trajectories=1, max_steps=4)
  stream = task.trajectory_batch_stream(batch_size=1, max_slice_length=4)
  tr_slice = next(stream)
  # control_mask = step % 2 == 0, discount_mask = step % 3 == 0.
  np.testing.assert_array_equal(
      tr_slice.env_info.control_mask, [[1, 0, 1, 0]])
  np.testing.assert_array_equal(
      tr_slice.env_info.discount_mask, [[1, 0, 0, 1]])
def test_trajectory_slice_stream_shape(self):
  """Test the shape yielded by trajectory stream."""
  obs = np.zeros((12, 13))
  tr1 = rl_task.Trajectory(obs)
  self._extend(tr1, new_observation=obs, done=True)
  task = rl_task.RLTask(DummyEnv(), initial_trajectories=[tr1], max_steps=9)
  stream = task.trajectory_slice_stream(max_slice_length=1)
  next_slice = next(stream)
  self.assertEqual(next_slice.observation.shape, (1, 12, 13))
def test_task_save_init(self):
  """Test saving and re-initialization."""
  task1 = rl_task.RLTask(DummyEnv(), initial_trajectories=13, max_steps=9,
                         gamma=0.9)
  self.assertLen(task1.trajectories[0], 13)
  self.assertEqual(task1.max_steps, 9)
  self.assertEqual(task1.gamma, 0.9)

  temp_file = os.path.join(self.create_tempdir().full_path, 'task.pkl')
  task1.save_to_file(temp_file)

  task2 = rl_task.RLTask(DummyEnv(), initial_trajectories=3, max_steps=19,
                         gamma=1.0)
  self.assertLen(task2.trajectories[0], 3)
  self.assertEqual(task2.max_steps, 19)
  self.assertEqual(task2.gamma, 1.0)

  task2.init_from_file(temp_file)
  self.assertLen(task2.trajectories[0], 13)
  self.assertEqual(task2.max_steps, 9)
  self.assertEqual(task2.gamma, 0.9)
def setUp(self):
  super().setUp()
  self._model_fn = lambda mode: tl.Serial(  # pylint: disable=g-long-lambda
      tl.Dense(64), tl.Relu(), tl.Dense(1))
  self._task = rl_task.RLTask('CartPole-v0', gamma=0.5, max_steps=10,
                              initial_trajectories=100)
  self._trajectory_batch_stream = self._task.trajectory_batch_stream(
      batch_size=256, epochs=[-1], max_slice_length=2)
def test_trajectory_stream_shape(self):
  """Test the shape yielded by trajectory stream."""
  elem = np.zeros((12, 13))
  tr1 = rl_task.Trajectory(elem)
  tr1.extend(0, 0, 0, True, elem)
  task = rl_task.RLTask(DummyEnv(), initial_trajectories=[tr1], max_steps=9)
  stream = task.trajectory_stream(max_slice_length=1)
  next_slice = next(stream)
  self.assertLen(next_slice, 1)
  self.assertEqual(next_slice.last_observation.shape, (12, 13))
def test_time_limit_terminates_episodes(self):
  """Test that episodes are terminated upon reaching `time_limit` steps."""
  task = rl_task.RLTask(
      DummyEnv(), initial_trajectories=3, max_steps=10, time_limit=10
  )
  trajectories = task.trajectories[0]  # Get trajectories from epoch 0.
  self.assertLen(trajectories, 3)
  for trajectory in trajectories:
    self.assertTrue(trajectory.done)
    # max_steps + 1 (the initial observation doesn't count).
    self.assertLen(trajectory, 11)
def test_policytrainer_cartpole(self):
  """Trains a policy on cartpole."""
  task = rl_task.RLTask('CartPole-v0', initial_trajectories=100,
                        max_steps=200)
  # CartPole-v0 has 2 discrete actions, so the policy head outputs 2 logits.
  model = lambda mode: tl.Serial(  # pylint: disable=g-long-lambda
      tl.Dense(32), tl.Relu(), tl.Dense(2), tl.LogSoftmax())
  lr = lambda h: lr_schedules.MultifactorSchedule(  # pylint: disable=g-long-lambda
      h, constant=1e-3, warmup_steps=100, factors='constant * linear_warmup')
  trainer = training.ExamplePolicyTrainer(task, model, opt.Adam, lr)
  trainer.run(1)
  self.assertEqual(1, trainer.current_epoch)
def test_trajectory_stream_long_slice(self):
  """Test trajectory stream with slices of longer length."""
  elem = np.zeros((12, 13))
  tr1 = rl_task.Trajectory(elem)
  tr1.extend(0, 0, 0, False, elem)
  tr1.extend(0, 0, 0, True, elem)
  task = rl_task.RLTask(DummyEnv(), initial_trajectories=[tr1], max_steps=9)
  stream = task.trajectory_stream(max_slice_length=2)
  next_slice = next(stream)
  self.assertLen(next_slice, 2)
  self.assertEqual(next_slice.last_observation.shape, (12, 13))
def test_trajectory_stream_shape(self):
  """Test the shape yielded by trajectory stream."""
  obs = np.zeros((12, 13))
  tr1 = rl_task.Trajectory(obs)
  tr1.extend(
      action=0, dist_inputs=0, reward=0, done=True, new_observation=obs
  )
  task = rl_task.RLTask(DummyEnv(), initial_trajectories=[tr1], max_steps=9)
  stream = task.trajectory_stream(max_slice_length=1)
  next_slice = next(stream)
  self.assertLen(next_slice, 1)
  self.assertEqual(next_slice.last_observation.shape, (12, 13))