Ejemplo n.º 1
0
  def test_play_env_problem_randomly(self):
    """Plays a batched TicTacToe env problem randomly and checks step count."""
    n_envs = 5
    n_steps = 100

    env = tic_tac_toe_env_problem.TicTacToeEnvProblem()
    env.initialize(batch_size=n_envs)

    env_problem_utils.play_env_problem_randomly(env, n_steps)

    # Expected number of recorded time-steps is the sum of:
    #   * n_steps * n_envs  -- one per environment per played step,
    #   * one extra per completed trajectory (each 'done' creates another step),
    #   * n_envs            -- the pending step of every environment.
    num_completed = len(env.trajectories.completed_trajectories)
    expected_time_steps = n_steps * n_envs + num_completed + n_envs
    self.assertEqual(expected_time_steps, env.trajectories.num_time_steps)
Ejemplo n.º 2
0
def generate_data_for_env_problem(problem_name):
  """Generate data for `EnvProblem`s.

  Looks up the env problem registered under `problem_name`, initializes it with
  `--env_problem_batch_size`, plays it randomly for
  `--env_problem_max_env_steps` steps and then asks the problem to generate
  its data into `--data_dir` (using `--tmp_dir` as scratch space).

  Args:
    problem_name: name handed to `registry.env_problem` to look up the problem.

  Raises:
    AssertionError: if `--env_problem_max_env_steps` or
      `--env_problem_batch_size` is not greater than zero.
  """
  # Validate the flags up front, before the problem is looked up or initialized.
  assert FLAGS.env_problem_max_env_steps > 0, ("--env_problem_max_env_steps "
                                               "should be greater than zero")
  assert FLAGS.env_problem_batch_size > 0, ("--env_problem_batch_size should be"
                                            " greater than zero")
  problem = registry.env_problem(problem_name)
  # A negative --task_id means "no task id"; pass None in that case.
  task_id = None if FLAGS.task_id < 0 else FLAGS.task_id
  # Expand a leading "~" so user-relative paths work.
  data_dir = os.path.expanduser(FLAGS.data_dir)
  tmp_dir = os.path.expanduser(FLAGS.tmp_dir)
  # TODO(msaffar): Handle large values for env_problem_batch_size where we
  #  cannot create that many environments within the same process.
  problem.initialize(batch_size=FLAGS.env_problem_batch_size)
  env_problem_utils.play_env_problem_randomly(
      problem, num_steps=FLAGS.env_problem_max_env_steps)
  problem.generate_data(data_dir=data_dir, tmp_dir=tmp_dir, task_id=task_id)
    def test_generate_timesteps(self):
        """Checks the fields of time-steps generated from completed trajectories."""
        env = ReacherEnvProblem()
        env.initialize(batch_size=2)
        env_problem_utils.play_env_problem_randomly(env, num_steps=5)
        env.trajectories.complete_all_trajectories()

        rep = rendered_env_problem
        required_fields = (
            rep._IMAGE_ENCODED_FIELD,
            rep._IMAGE_HEIGHT_FIELD,
            rep._IMAGE_WIDTH_FIELD,
            rep._IMAGE_FORMAT_FIELD,
            rep._FRAME_NUMBER_FIELD,
        )

        generated = env._generate_time_steps(
            env.trajectories.completed_trajectories)
        for index, time_step in enumerate(generated):
            # The frame counter is expected to wrap every 6 time-steps
            # (presumably the 5 played steps plus the initial one per
            # trajectory -- confirm against the env problem).
            expected_frame_number = index % 6

            # The raw observation must have been replaced by the rendered frame.
            self.assertNotIn(env_problem.OBSERVATION_FIELD, time_step)
            for field in required_fields:
                self.assertIn(field, time_step)

            # Decode the PNG payload and verify its dimensions.
            encoded_png = time_step[rep._IMAGE_ENCODED_FIELD][0]
            frame = self.evaluate(tf.image.decode_png(encoded_png))
            expected_shape = [
                env.frame_height, env.frame_width, env.num_channels
            ]
            self.assertListEqual(expected_shape, list(frame.shape))

            # Every metadata field is a single-element list.
            self.assertListEqual([rep._FORMAT],
                                 time_step[rep._IMAGE_FORMAT_FIELD])
            self.assertListEqual([expected_frame_number],
                                 time_step[rep._FRAME_NUMBER_FIELD])
            self.assertListEqual([env.frame_width],
                                 time_step[rep._IMAGE_WIDTH_FIELD])
            self.assertListEqual([env.frame_height],
                                 time_step[rep._IMAGE_HEIGHT_FIELD])