def test_play_env_problem_randomly(self):
  """Random play on TicTacToe records the expected number of time steps."""
  num_envs = 5
  steps_to_play = 100

  tic_tac_toe = tic_tac_toe_env_problem.TicTacToeEnvProblem()
  tic_tac_toe.initialize(batch_size=num_envs)
  env_problem_utils.play_env_problem_randomly(tic_tac_toe, steps_to_play)

  trajectories = tic_tac_toe.trajectories
  # The recorded time steps are the sum of three parts:
  #   * one per played step in every environment,
  #   * one extra step created each time an episode reports 'done',
  #   * one pending step per environment in the batch.
  played = steps_to_play * num_envs
  extra_after_done = len(trajectories.completed_trajectories)
  pending = num_envs
  self.assertEqual(played + extra_after_done + pending,
                   trajectories.num_time_steps)
def generate_data_for_env_problem(problem_name):
  """Generate data for `EnvProblem`s.

  Looks up the registered env problem named `problem_name` and initializes
  it with `--env_problem_batch_size` environments. It then plays
  `--env_problem_max_env_steps` random steps in them and writes the result
  with `problem.generate_data`.

  Args:
    problem_name: name of an `EnvProblem` registered in `registry`.

  Raises:
    ValueError: if `--env_problem_max_env_steps` or
      `--env_problem_batch_size` is not greater than zero.
  """
  # Validate flags with explicit exceptions instead of `assert`.
  # Assert statements are removed when Python runs with -O, which would let
  # invalid flag values through silently.
  if FLAGS.env_problem_max_env_steps <= 0:
    raise ValueError(
        "--env_problem_max_env_steps should be greater than zero")
  if FLAGS.env_problem_batch_size <= 0:
    raise ValueError("--env_problem_batch_size should be greater than zero")

  problem = registry.env_problem(problem_name)
  # A negative --task_id is mapped to "no task id" (None).
  task_id = None if FLAGS.task_id < 0 else FLAGS.task_id
  data_dir = os.path.expanduser(FLAGS.data_dir)
  tmp_dir = os.path.expanduser(FLAGS.tmp_dir)

  # TODO(msaffar): Handle large values for env_problem_batch_size where we
  # cannot create that many environments within the same process.
  problem.initialize(batch_size=FLAGS.env_problem_batch_size)
  env_problem_utils.play_env_problem_randomly(
      problem, num_steps=FLAGS.env_problem_max_env_steps)
  problem.generate_data(data_dir=data_dir, tmp_dir=tmp_dir, task_id=task_id)
def test_generate_timesteps(self):
  """Generated time steps carry an encoded frame and its metadata fields."""
  steps_per_env = 5
  # Frame numbers wrap around after steps_per_env + 1 (== 6) time steps.
  # Presumably that is the reset observation plus one per played step in
  # each trajectory -- TODO confirm against the trajectory implementation.
  frames_per_trajectory = steps_per_env + 1
  required_fields = (
      rendered_env_problem._IMAGE_ENCODED_FIELD,
      rendered_env_problem._IMAGE_HEIGHT_FIELD,
      rendered_env_problem._IMAGE_WIDTH_FIELD,
      rendered_env_problem._IMAGE_FORMAT_FIELD,
      rendered_env_problem._FRAME_NUMBER_FIELD,
  )

  reacher = ReacherEnvProblem()
  reacher.initialize(batch_size=2)
  env_problem_utils.play_env_problem_randomly(reacher, num_steps=steps_per_env)
  reacher.trajectories.complete_all_trajectories()

  expected_frame = 0
  for step in reacher._generate_time_steps(
      reacher.trajectories.completed_trajectories):
    # The original observation must not be present in the time step.
    self.assertNotIn(env_problem.OBSERVATION_FIELD, step)
    # Every frame-related field must be present.
    for field in required_fields:
      self.assertIn(field, step)

    encoded_png = step[rendered_env_problem._IMAGE_ENCODED_FIELD][0]
    frame = self.evaluate(tf.image.decode_png(encoded_png))
    expected_shape = [
        reacher.frame_height, reacher.frame_width, reacher.num_channels
    ]
    self.assertListEqual(expected_shape, list(frame.shape))

    # Each metadata field holds a single-element list with the expected value.
    expected_metadata = (
        (rendered_env_problem._IMAGE_FORMAT_FIELD,
         rendered_env_problem._FORMAT),
        (rendered_env_problem._FRAME_NUMBER_FIELD, expected_frame),
        (rendered_env_problem._IMAGE_WIDTH_FIELD, reacher.frame_width),
        (rendered_env_problem._IMAGE_HEIGHT_FIELD, reacher.frame_height),
    )
    for field, value in expected_metadata:
      self.assertListEqual([value], step[field])

    expected_frame = (expected_frame + 1) % frames_per_trajectory