def generate_data_for_env_problem(problem_name):
    """Generate data for `EnvProblem`s.

    Plays the named environment problem with random actions for
    `--env_problem_max_env_steps` steps using a batch of
    `--env_problem_batch_size` environments, then writes out the collected
    trajectories via `generate_data`.

    Args:
        problem_name: str, name of an `EnvProblem` registered in `registry`.
    """
    assert FLAGS.env_problem_max_env_steps > 0, ("--env_problem_max_env_steps "
                                                 "should be greater than zero")
    # Fixed typo in the user-facing message: "greather" -> "greater".
    assert FLAGS.env_problem_batch_size > 0, (
        "--env_problem_batch_size should be"
        " greater than zero")
    problem = registry.env_problem(problem_name)
    # A negative --task_id is the sentinel for "no explicit task id".
    task_id = None if FLAGS.task_id < 0 else FLAGS.task_id
    data_dir = os.path.expanduser(FLAGS.data_dir)
    tmp_dir = os.path.expanduser(FLAGS.tmp_dir)
    # TODO(msaffar): Handle large values for env_problem_batch_size where we
    #  cannot create that many environments within the same process.
    problem.initialize(batch_size=FLAGS.env_problem_batch_size)
    env_problem_utils.play_env_problem_randomly(
        problem, num_steps=FLAGS.env_problem_max_env_steps)
    problem.generate_data(data_dir=data_dir, tmp_dir=tmp_dir, task_id=task_id)
    def test_registration_and_interaction_with_env_problem(self):
        """Plays the tic-tac-toe env randomly and tallies game outcomes."""
        batch_size = 5
        # A successful registry lookup proves registration has occurred.
        env = registry.env_problem("tic_tac_toe_env_problem", batch_size)
        env.reset()

        num_done = num_lost = num_won = num_draw = 0
        nsteps = 100
        for _ in range(nsteps):
            action_batch = np.stack(
                [env.action_space.sample() for _ in range(batch_size)])
            observations, rewards, dones, infos = env.step(action_batch)

            # Every returned structure carries one entry per environment.
            for batched in (observations, rewards, dones, infos):
                self.assertEqual(batch_size, len(batched))

            # Restart only the environments that finished on this step.
            env.reset(env.done_indices(dones))
            num_done += sum(dones)

            for reward, done in zip(rewards, dones):
                if not done:
                    continue
                # NOTE: reward is 0, 1, 2 because the default
                # EnvProblem.process_rewards shifts the rewards so that min
                # is 0.
                if reward == 0:
                    num_lost += 1
                elif reward == 1:
                    num_draw += 1
                elif reward == 2:
                    num_won += 1
                else:
                    raise ValueError(
                        "reward should be 0, 1, 2 but is {}".format(reward))

        # Without at least one finished game the consistency check below is
        # meaningless.
        self.assertGreater(num_done, 0)

        # Every finished game is exactly one of: win, loss, draw.
        self.assertEqual(num_done, num_won + num_lost + num_draw)
# Example 3
  def testEnvProblem(self):
    """Registering a class exposes it through registry.env_problem()."""

    @registry.register_env_problem
    class EnvProb(object):

      batch_size = None

      def initialize(self, batch_size):
        self.batch_size = batch_size

    # Look the class up by its snake_case name, passing a batch size.
    requested_batch_size = 100
    problem = registry.env_problem("env_prob", batch_size=requested_batch_size)

    # The name property reflects the registered name.
    self.assertEqual("env_prob", problem.name)

    # The registry must have called initialize(), setting batch_size.
    self.assertEqual(requested_batch_size, problem.batch_size)

    # And the returned object is an instance of the registered class.
    self.assertIsInstance(problem, EnvProb)
# Example 4
    def test_registration_and_interaction_with_env_problem(self):
        """Plays the reacher env randomly and checks batched step outputs."""
        batch_size = 5
        # A successful registry lookup proves registration has occurred.
        env = registry.env_problem("reacher_env_problem", batch_size=batch_size)
        env.reset()

        num_done = 0
        nsteps = 100
        for _ in range(nsteps):
            action_batch = np.stack(
                [env.action_space.sample() for _ in range(batch_size)])
            observations, rewards, dones, infos = env.step(action_batch)

            # Every returned structure carries one entry per environment.
            for batched in (observations, rewards, dones, infos):
                self.assertEqual(batch_size, len(batched))

            # Restart only the environments that finished on this step.
            env.reset(env_problem_utils.done_indices(dones))
            num_done += sum(dones)

        # At least one episode should have terminated during the rollout.
        self.assertGreater(num_done, 0)