Example 1
    def play_env(self,
                 env=None,
                 nsteps=100,
                 base_env_name=None,
                 batch_size=5,
                 reward_range=None):
        """Creates `EnvProblem` with the given arguments and plays it randomly.

    Args:
      env: optional `EnvProblem`; if None, one is created from the other args.
      nsteps: number of steps to play the env randomly.
      base_env_name: passed to EnvProblem's init.
      batch_size: passed to EnvProblem's init.
      reward_range: passed to EnvProblem's init.

    Returns:
      tuple of env_problem, number of trajectories done, number of trajectories
      done in the last step.
    """

        if env is None:
            env = env_problem.EnvProblem(base_env_name=base_env_name,
                                         batch_size=batch_size,
                                         reward_range=reward_range)
            # Usually done by a registered subclass; we do this manually in the test.
            env.name = base_env_name

        # Reset all environments.
        env.reset()

        # Play for some steps to generate data.
        num_dones = 0
        num_dones_in_last_step = 0
        for _ in range(nsteps):
            # Sample actions.
            actions = np.stack(
                [env.action_space.sample() for _ in range(batch_size)])
            # Step through it.
            _, _, dones, _ = env.step(actions)
            # Get the indices where we are done ...
            done_indices = env_problem_utils.done_indices(dones)
            # ... and reset those.
            env.reset(indices=done_indices)
            # Count the number of dones we got, in this step and overall.
            num_dones_in_last_step = sum(dones)
            num_dones += num_dones_in_last_step

        return env, num_dones, num_dones_in_last_step

    def test_registration_and_interaction_with_env_problem(self):
        batch_size = 5
        # This ensures that registration has occurred.
        ep = registry.env_problem("tic_tac_toe_env_problem",
                                  batch_size=batch_size)
        ep.reset()
        num_done, num_lost, num_won, num_draw = 0, 0, 0, 0
        nsteps = 100
        for _ in range(nsteps):
            actions = np.stack(
                [ep.action_space.sample() for _ in range(batch_size)])
            obs, rewards, dones, infos = ep.step(actions)

            # Assert that things are happening batchwise.
            self.assertEqual(batch_size, len(obs))
            self.assertEqual(batch_size, len(rewards))
            self.assertEqual(batch_size, len(dones))
            self.assertEqual(batch_size, len(infos))

            done_indices = env_problem_utils.done_indices(dones)
            ep.reset(done_indices)
            num_done += sum(dones)
            for r, d in zip(rewards, dones):
                if not d:
                    continue
                if r == -1:
                    num_lost += 1
                elif r == 0:
                    num_draw += 1
                elif r == 1:
                    num_won += 1
                else:
                    raise ValueError(
                        "reward should be -1, 0, 1 but is {}".format(r))

        # Assert that at least something got done; without that, the next assert
        # is meaningless.
        self.assertGreater(num_done, 0)

        # Assert that things are consistent.
        self.assertEqual(num_done, num_won + num_lost + num_draw)
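
These snippets rely on `env_problem_utils.done_indices` to turn the boolean `dones` array into integer indices for `env.reset(indices=...)`. Its implementation is not shown in the listing; a minimal stand-in with the same observable behaviour, assuming it simply converts the boolean mask to a 1-D index array, could look like this:

import numpy as np

def done_indices(dones):
    # Hypothetical equivalent: indices of environments whose episode just ended.
    # np.argwhere on a 1-D boolean array returns shape (N, 1); squeeze to 1-D.
    return np.argwhere(dones).squeeze(axis=-1)

print(done_indices(np.array([False, True, False, True, False])))  # -> [1 3]
print(done_indices(np.array([False, False])).size)                # -> 0, nothing to reset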
Example 3
def play_env_problem(env, policy_fn):
    """Plays an EnvProblem using a given policy function."""
    trajectories = [trajectory.Trajectory() for _ in range(env.batch_size)]
    observations = env.reset()
    for (traj, observation) in zip(trajectories, observations):
        traj.add_time_step(observation=observation)

    done_so_far = np.array([False] * env.batch_size)
    while not np.all(done_so_far):
        padded_observations, _ = env.trajectories.observations_np(
            len_history_for_policy=None)
        actions = policy_fn(padded_observations)
        (observations, rewards, dones, _) = env.step(actions)
        for (traj, observation, action, reward,
             done) in zip(trajectories, observations, actions, rewards, dones):
            if not traj.done:
                traj.change_last_time_step(action=action)
                traj.add_time_step(observation=observation,
                                   raw_reward=reward,
                                   done=done)
        env.reset(indices=env_problem_utils.done_indices(dones))
        done_so_far = np.logical_or(done_so_far, dones)
    return trajectories
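
`play_env_problem` expects `policy_fn` to map the padded batch of observations directly to a batch of actions. A hedged usage sketch with a uniform-random policy, assuming `env` is an `EnvProblem` built as in the other examples (e.g. via `registry.env_problem(...)`):

def random_policy_fn(padded_observations):
    # Ignore the observations and sample one action per environment in the batch.
    batch_size = padded_observations.shape[0]
    return np.stack([env.action_space.sample() for _ in range(batch_size)])

trajectories = play_env_problem(env, random_policy_fn)
assert len(trajectories) == env.batch_size  # one Trajectory per environment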
Example 4
    def test_registration_and_interaction_with_env_problem(self):
        batch_size = 5
        # This ensures that registration has occurred.
        ep = registry.env_problem("reacher_env_problem", batch_size=batch_size)
        ep.reset()
        num_done = 0
        nsteps = 100
        for _ in range(nsteps):
            actions = np.stack(
                [ep.action_space.sample() for _ in range(batch_size)])
            obs, rewards, dones, infos = ep.step(actions)

            # Assert that things are happening batchwise.
            self.assertEqual(batch_size, len(obs))
            self.assertEqual(batch_size, len(rewards))
            self.assertEqual(batch_size, len(dones))
            self.assertEqual(batch_size, len(infos))

            done_indices = env_problem_utils.done_indices(dones)
            ep.reset(done_indices)
            num_done += sum(dones)

        # Assert that at least something got done.
        self.assertGreater(num_done, 0)
Example 5
    def test_interaction_with_env(self):
        batch_size = 5
        reward_range = (-1, 1)
        ep = env_problem.EnvProblem(base_env_name="KellyCoinflip-v0",
                                    batch_size=batch_size,
                                    reward_range=reward_range)

        # Resets all environments.
        ep.reset()

        # Let's play a few steps.
        nsteps = 100
        num_trajectories_completed = 0
        num_timesteps_completed = 0
        # If batch_done_at_step[i] = j, it means that the i-th env last got done
        # at step j.
        batch_done_at_step = np.full(batch_size, -1)
        for i in range(nsteps):
            # Sample batch_size actions from the action space and stack them (since
            # that is the expected type).
            actions = np.stack(
                [ep.action_space.sample() for _ in range(batch_size)])

            _, _, dones, _ = ep.step(actions)

            # Do the book-keeping on number of trajectories completed and expect that
            # it matches ep's completed number.

            num_done = sum(dones)
            num_trajectories_completed += num_done

            self.assertEqual(num_trajectories_completed,
                             len(ep.trajectories.completed_trajectories))

            # Get the indices where we are done ...
            done_indices = env_problem_utils.done_indices(dones)

            # ... and reset those.
            ep.reset(indices=done_indices)

            # If nothing got done, go on to the next step.
            if done_indices.size == 0:
                # i.e. this is an empty array.
                continue

            # See when these indices were last done and calculate how many time-steps
            # each one took to get done.
            num_timesteps_completed += sum(i + 1 -
                                           batch_done_at_step[done_indices])
            batch_done_at_step[done_indices] = i

            # This should also match the number of time-steps completed given by ep.
            num_timesteps_completed_ep = sum(
                ct.num_time_steps
                for ct in ep.trajectories.completed_trajectories)
            self.assertEqual(num_timesteps_completed,
                             num_timesteps_completed_ep)

        # Reset the trajectories.
        ep.trajectories.reset_batch_trajectories()
        self.assertEqual(0, len(ep.trajectories.completed_trajectories))
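
The time-step bookkeeping above hinges on the expression `i + 1 - batch_done_at_step[done_indices]`: each env that finishes at step `i` has been running since the step right after its previous done (with the initial reset encoded as -1). A small standalone check of that arithmetic, with made-up numbers not taken from the test:

import numpy as np

batch_done_at_step = np.full(3, -1)   # no env has finished yet
i = 4                                 # suppose envs 0 and 2 finish at step index 4
done_indices = np.array([0, 2])
lengths = i + 1 - batch_done_at_step[done_indices]
print(lengths)                        # -> [5 5]; five time-steps each since the reset
batch_done_at_step[done_indices] = i  # remember when these envs last finished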
Example 6
def collect_trajectories(
    env,
    policy_fn,
    n_trajectories=1,
    n_observations=None,
    max_timestep=None,
    reset=True,
    len_history_for_policy=32,
    boundary=32,
    state=None,
    temperature=1.0,
    rng=None,
    abort_fn=None,
    raw_trajectory=False,
):
    """Collect trajectories with the given policy net and behaviour.

  Args:
    env: A gym env interface, for now this is not-batched.
    policy_fn: Callable
      (observations(B, T+1), actions(B, T+1, C)) -> log-probabilities(B, T+1, C, A).
    n_trajectories: int, number of trajectories.
    n_observations: int, number of non-terminal observations. NOTE: Exactly one
      of `n_trajectories` and `n_observations` should be None.
    max_timestep: int or None, the index of the maximum time-step at which we
      return the trajectory, None for ending a trajectory only when env returns
      done.
    reset: bool, True if we want to reset the envs. The envs are also reset if
      max_timestep is None or <= 0.
    len_history_for_policy: int or None, the maximum history to keep for
      applying the policy on. If None, use the full history.
    boundary: int, pad the sequences to the multiples of this number.
    state: state for `policy_fn`.
    temperature: (float) temperature to sample action from policy_fn.
    rng: jax rng, splittable.
    abort_fn: callable or None. If not None, it is called at every env step; if
      it returns True, the trajectory collection is aborted, the env is reset
      and None is returned.
    raw_trajectory: bool, if True a list of trajectory.Trajectory objects is
      returned, otherwise a list of numpy representations of
      `trajectory.Trajectory` is returned.

  Returns:
    A tuple (trajectories, num_done_trajectories, timing_info, state), where:
    trajectories: list of (observation, action, reward) tuples, where each
      element `i` is a tuple of numpy arrays with shapes as follows:
      observation[i] = (B, T_i + 1)
      action[i] = (B, T_i)
      reward[i] = (B, T_i)
    num_done_trajectories: number of trajectories that ended with done (as
      opposed to being truncated).
    timing_info: dict of timing measurements, in milliseconds.
    state: the (possibly updated) state for `policy_fn`.
  """

    assert isinstance(env, env_problem.EnvProblem)

    # We need to reset all environments if we're coming here for the first time.
    if reset or max_timestep is None or max_timestep <= 0:
        env.reset()
    else:
        # Clear completed trajectories held internally.
        env.trajectories.clear_completed_trajectories()

    num_done_trajectories = 0

    # The stopping criterion, returns True if we should stop.
    def should_stop():
        if n_trajectories is not None:
            assert n_observations is None
            return env.trajectories.num_completed_trajectories >= n_trajectories
        assert n_observations is not None
        # The number of non-terminal observations is what we want.
        return (env.trajectories.num_completed_time_steps -
                env.trajectories.num_completed_trajectories) >= n_observations

    policy_application_total_time = 0
    env_actions_total_time = 0
    bare_env_run_time = 0
    while not should_stop():
        # Check if we should abort and return nothing.
        if abort_fn and abort_fn():
            # We should also reset the environment, since it will have some
            # trajectories (complete and incomplete) that we want to discard.
            env.reset()
            return None, 0, {}, state

        # Get all the observations for all the active trajectories.
        # Shape is (B, T+1) + OBS
        # Bucket on whatever length is needed.
        padded_observations, lengths = env.trajectories.observations_np(
            boundary=boundary, len_history_for_policy=len_history_for_policy)

        B = padded_observations.shape[0]  # pylint: disable=invalid-name

        assert B == env.batch_size
        assert (B, ) == lengths.shape

        t1 = time.time()
        log_probs, value_preds, state, rng = policy_fn(padded_observations,
                                                       lengths,
                                                       state=state,
                                                       rng=rng)
        policy_application_total_time += (time.time() - t1)

        assert B == log_probs.shape[0]

        actions = tl.gumbel_sample(log_probs, temperature)
        if (isinstance(env.action_space, gym.spaces.Discrete)
                and (actions.shape[1] == 1)):
            actions = onp.squeeze(actions, axis=1)

        # Step through the env.
        t1 = time.time()
        _, _, dones, env_infos = env.step(actions,
                                          infos={
                                              'log_prob_actions': log_probs,
                                              'value_predictions': value_preds,
                                          })
        env_actions_total_time += (time.time() - t1)
        bare_env_run_time += sum(info['__bare_env_run_time__']
                                 for info in env_infos)

        # Count the number of done trajectories; the others could just have been
        # truncated.
        num_done_trajectories += onp.sum(dones)

        # Get the indices where we are done ...
        done_idxs = env_problem_utils.done_indices(dones)

        # ... and reset those.
        t1 = time.time()
        if done_idxs.size:
            env.reset(indices=done_idxs)
        env_actions_total_time += (time.time() - t1)

        if max_timestep is None or max_timestep < 1:
            continue

        # Check whether any trajectories have exceeded the time-limit we want.
        lengths = env.trajectories.trajectory_lengths
        exceeded_time_limit_idxs = env_problem_utils.done_indices(
            lengths > max_timestep)

        # If so, reset these as well.
        t1 = time.time()
        if exceeded_time_limit_idxs.size:
            # This just cuts the trajectory, doesn't reset the env, so it continues
            # from where it left off.
            env.truncate(indices=exceeded_time_limit_idxs, num_to_keep=1)
        env_actions_total_time += (time.time() - t1)

    # We have the trajectories we need, return a list of triples:
    # (observations, actions, rewards)
    completed_trajectories = (
        env_problem_utils.get_completed_trajectories_from_env(
            env,
            env.trajectories.num_completed_trajectories,
            raw_trajectory=raw_trajectory))

    timing_info = {
        'trajectory_collection/policy_application':
        policy_application_total_time,
        'trajectory_collection/env_actions': env_actions_total_time,
        'trajectory_collection/env_actions/bare_env': bare_env_run_time,
    }
    timing_info = {k: round(1000 * v, 2) for k, v in timing_info.items()}

    return completed_trajectories, num_done_trajectories, timing_info, state
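
A hedged sketch of driving `collect_trajectories` with a uniform-random policy. The `policy_fn` signature mirrors the call inside the loop above (padded observations and lengths in; log-probabilities, value predictions, state and rng out); the exact output shapes and the `my_env` / `n_actions` names are illustrative assumptions, not part of the original code:

import numpy as onp

n_actions = 4  # assumed size of a Discrete action space

def uniform_random_policy_fn(padded_observations, lengths, state=None, rng=None):
    b = padded_observations.shape[0]
    # Uniform log-probabilities over the actions for the most recent time-step
    # only, so gumbel_sample yields one action per environment in the batch.
    log_probs = onp.full((b, 1, n_actions), -onp.log(n_actions))
    value_preds = onp.zeros((b, 1))  # dummy value estimates
    return log_probs, value_preds, state, rng

trajectories, n_done, timing_info, state = collect_trajectories(
    my_env, uniform_random_policy_fn, n_trajectories=2, max_timestep=200)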