Example 1
0
    def collect_trajectories_async(self,
                                   env,
                                   train=True,
                                   n_trajectories=1,
                                   n_observations=None,
                                   temperature=1.0):
        """Collects trajectories in an async manner."""

        assert self._async_mode

        # TODO(afrozm): Make this work, should be easy.
        # NOTE: We still collect whole trajectories; however, the async trajectory
        # collectors will then poll not on the number of trajectories collected but
        # on the number of observations in the completed trajectories, and bail out.
        assert n_observations is None

        # trajectories/train and trajectories/eval are the two subdirectories.
        trajectory_dir = os.path.join(self._output_dir, 'trajectories',
                                      'train' if train else 'eval')
        epoch = self.epoch

        logging.info(
            'Loading [%s] trajectories from dir [%s] for epoch [%s] and temperature'
            ' [%s]', n_trajectories, trajectory_dir, epoch, temperature)

        bt = trajectory.BatchTrajectory.load_from_directory(
            trajectory_dir,
            epoch=epoch,
            temperature=temperature,
            wait_forever=True,
            n_trajectories=n_trajectories)

        if bt is None:
            logging.error(
                'Couldn\'t load [%s] trajectories from dir [%s] for epoch [%s] and '
                'temperature [%s]', n_trajectories, trajectory_dir, epoch,
                temperature)
            assert bt

        # Doing this is important, since we want to modify `env` so that it looks
        # like `env` was actually played and the trajectories came from it.
        env.trajectories = bt

        trajs = env_problem_utils.get_completed_trajectories_from_env(
            env, n_trajectories)
        n_done = len(trajs)
        timing_info = {}
        return trajs, n_done, timing_info, self._model_state
    def collect_trajectories_async(self,
                                   env,
                                   train=True,
                                   n_trajectories=1,
                                   temperature=1.0):
        """Collects trajectories in an async manner."""

        assert self._async_mode

        # trajectories/train and trajectories/eval are the two subdirectories.
        trajectory_dir = os.path.join(self._output_dir, "trajectories",
                                      "train" if train else "eval")
        epoch = self.epoch

        logging.info(
            "Loading [%s] trajectories from dir [%s] for epoch [%s] and temperature"
            " [%s]", n_trajectories, trajectory_dir, epoch, temperature)

        bt = trajectory.BatchTrajectory.load_from_directory(
            trajectory_dir,
            epoch=epoch,
            temperature=temperature,
            wait_forever=True,
            n_trajectories=n_trajectories)

        if bt is None:
            logging.error(
                "Couldn't load [%s] trajectories from dir [%s] for epoch [%s] and "
                "temperature [%s]", n_trajectories, trajectory_dir, epoch,
                temperature)
            assert bt

        # Doing this is important, since we want to modify `env` so that it looks
        # like `env` was actually played and the trajectories came from it.
        env.trajectories = bt

        trajs = env_problem_utils.get_completed_trajectories_from_env(
            env, n_trajectories)
        n_done = len(trajs)
        timing_info = {}
        return trajs, n_done, timing_info, self._model_state
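As a side note, the on-disk layout the async collector polls follows directly from the os.path.join call above. A minimal sketch, assuming a hypothetical output_dir standing in for the trainer's self._output_dir:

import os

# Hypothetical output directory (placeholder for self._output_dir).
output_dir = '/tmp/ppo_run'

# Completed trajectories are read from one of these two subdirectories,
# depending on the `train` flag passed to collect_trajectories_async.
train_trajectory_dir = os.path.join(output_dir, 'trajectories', 'train')
eval_trajectory_dir = os.path.join(output_dir, 'trajectories', 'eval')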
Example 3
0
def collect_trajectories(
    env,
    policy_fn,
    n_trajectories=1,
    n_observations=None,
    max_timestep=None,
    reset=True,
    len_history_for_policy=32,
    boundary=32,
    state=None,
    temperature=1.0,
    rng=None,
    abort_fn=None,
    raw_trajectory=False,
):
    """Collect trajectories with the given policy net and behaviour.

  Args:
    env: A gym env interface, for now this is not-batched.
    policy_fn: Callable
      (observations(B, T+1), actions(B, T+1, C)) -> log-probs(B, T+1, C, A).
    n_trajectories: int, number of trajectories.
    n_observations: int, number of non-terminal observations. NOTE: Exactly one
      of `n_trajectories` and `n_observations` should be None.
    max_timestep: int or None, the index of the maximum time-step at which we
      return the trajectory, None for ending a trajectory only when env returns
      done.
    reset: bool, true if we want to reset the envs. The envs are also reset if
      max_timestep is None or <= 0.
    len_history_for_policy: int or None, the maximum history to keep for
      applying the policy on. If None, use the full history.
    boundary: int, pad the sequences to a multiple of this number.
    state: state for `policy_fn`.
    temperature: (float) temperature to sample action from policy_fn.
    rng: jax rng, splittable.
    abort_fn: callable or None. If not None, it is called at every env step;
      if it returns True, trajectory collection is aborted, the env is reset,
      and None is returned.
    raw_trajectory: bool, if True a list of trajectory.Trajectory objects is
      returned, otherwise a list of numpy representations of
      `trajectory.Trajectory` is returned.

  Returns:
    A tuple (trajectories, n_done, timing_info, state) where:
    trajectories: list of (observation, action, reward) tuples, where each
      element `i` is a tuple of numpy arrays with shapes as follows:
      observation[i] = (B, T_i + 1)
      action[i] = (B, T_i)
      reward[i] = (B, T_i)
    n_done: the number of trajectories that ended because the env returned
      done (as opposed to being truncated).
    timing_info: dict of timing metrics, rounded to milliseconds.
    state: the final `policy_fn` state.
  """

    assert isinstance(env, env_problem.EnvProblem)

    # We need to reset all environments if we're coming here for the first time.
    if reset or max_timestep is None or max_timestep <= 0:
        env.reset()
    else:
        # Clear completed trajectories held internally.
        env.trajectories.clear_completed_trajectories()

    num_done_trajectories = 0

    # The stopping criterion, returns True if we should stop.
    def should_stop():
        if n_trajectories is not None:
            assert n_observations is None
            return env.trajectories.num_completed_trajectories >= n_trajectories
        assert n_observations is not None
        # The number of non-terminal observations is what we want.
        return (env.trajectories.num_completed_time_steps -
                env.trajectories.num_completed_trajectories) >= n_observations

    policy_application_total_time = 0
    env_actions_total_time = 0
    bare_env_run_time = 0
    while not should_stop():
        # Check if we should abort and return nothing.
        if abort_fn and abort_fn():
            # We should also reset the environment, since it will have some
            # trajectories (complete and incomplete) that we want to discard.
            env.reset()
            return None, 0, {}, state

        # Get all the observations for all the active trajectories.
        # Shape is (B, T+1) + OBS
        # Bucket on whatever length is needed.
        padded_observations, lengths = env.trajectories.observations_np(
            boundary=boundary, len_history_for_policy=len_history_for_policy)

        B = padded_observations.shape[0]  # pylint: disable=invalid-name

        assert B == env.batch_size
        assert (B, ) == lengths.shape

        t1 = time.time()
        log_probs, value_preds, state, rng = policy_fn(padded_observations,
                                                       lengths,
                                                       state=state,
                                                       rng=rng)
        policy_application_total_time += (time.time() - t1)

        assert B == log_probs.shape[0]

        actions = tl.gumbel_sample(log_probs, temperature)
        if (isinstance(env.action_space, gym.spaces.Discrete)
                and (actions.shape[1] == 1)):
            actions = onp.squeeze(actions, axis=1)

        # Step through the env.
        t1 = time.time()
        _, _, dones, env_infos = env.step(actions,
                                          infos={
                                              'log_prob_actions': log_probs,
                                              'value_predictions': value_preds,
                                          })
        env_actions_total_time += (time.time() - t1)
        bare_env_run_time += sum(info['__bare_env_run_time__']
                                 for info in env_infos)

        # Count the number of done trajectories, the others could just have been
        # truncated.
        num_done_trajectories += onp.sum(dones)

        # Get the indices where we are done ...
        done_idxs = env_problem_utils.done_indices(dones)

        # ... and reset those.
        t1 = time.time()
        if done_idxs.size:
            env.reset(indices=done_idxs)
        env_actions_total_time += (time.time() - t1)

        if max_timestep is None or max_timestep < 1:
            continue

        # Are there any trajectories that have exceeded the time limit we want?
        lengths = env.trajectories.trajectory_lengths
        exceeded_time_limit_idxs = env_problem_utils.done_indices(
            lengths > max_timestep)

        # If so, reset these as well.
        t1 = time.time()
        if exceeded_time_limit_idxs.size:
            # This just cuts the trajectory, doesn't reset the env, so it continues
            # from where it left off.
            env.truncate(indices=exceeded_time_limit_idxs, num_to_keep=1)
        env_actions_total_time += (time.time() - t1)

    # We have the trajectories we need, return a list of triples:
    # (observations, actions, rewards)
    completed_trajectories = (
        env_problem_utils.get_completed_trajectories_from_env(
            env,
            env.trajectories.num_completed_trajectories,
            raw_trajectory=raw_trajectory))

    timing_info = {
        'trajectory_collection/policy_application':
        policy_application_total_time,
        'trajectory_collection/env_actions': env_actions_total_time,
        'trajectory_collection/env_actions/bare_env': bare_env_run_time,
    }
    timing_info = {k: round(1000 * v, 2) for k, v in timing_info.items()}

    return completed_trajectories, num_done_trajectories, timing_info, state
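For orientation, here is a minimal, hypothetical usage sketch of collect_trajectories. The names my_env, my_policy_fn, and my_rng are placeholders, not part of the module: my_env is assumed to be a batched env_problem.EnvProblem instance and my_policy_fn a callable matching the policy_fn(observations, lengths, state=..., rng=...) call made inside the loop above.

# Hypothetical usage sketch; `my_env`, `my_policy_fn`, and `my_rng` are
# assumed to already exist and are not defined in this module.
trajs, n_done, timing_info, state = collect_trajectories(
    my_env,
    my_policy_fn,
    n_trajectories=4,
    max_timestep=200,
    boundary=32,
    temperature=1.0,
    rng=my_rng,
)
# `trajs` is a list of (observation, action, reward) numpy triples,
# `n_done` counts trajectories that terminated with `done`,
# `timing_info` holds the timing metrics in milliseconds, and
# `state` is the final policy state.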