def collect_trajectories_async(self,
                               env,
                               train=True,
                               n_trajectories=1,
                               n_observations=None,
                               temperature=1.0):
  """Collects trajectories in an async manner."""
  assert self._async_mode

  # TODO(afrozm): Make this work, should be easy.
  # NOTE: We still collect whole trajectories; however, the async trajectory
  # collectors should then poll not on the number of trajectories collected,
  # but on the number of observations in the completed trajectories, and bail
  # out once that target is reached.
  assert n_observations is None

  # trajectories/train and trajectories/eval are the two subdirectories.
  trajectory_dir = os.path.join(self._output_dir, 'trajectories',
                                'train' if train else 'eval')
  epoch = self.epoch

  logging.info(
      'Loading [%s] trajectories from dir [%s] for epoch [%s] and temperature'
      ' [%s]', n_trajectories, trajectory_dir, epoch, temperature)

  bt = trajectory.BatchTrajectory.load_from_directory(
      trajectory_dir,
      epoch=epoch,
      temperature=temperature,
      wait_forever=True,
      n_trajectories=n_trajectories)

  if bt is None:
    logging.error(
        'Couldn\'t load [%s] trajectories from dir [%s] for epoch [%s] and '
        'temperature [%s]', n_trajectories, trajectory_dir, epoch,
        temperature)
    assert bt

  # Doing this is important, since we want to modify `env` so that it looks
  # like `env` was actually played and the trajectories came from it.
  env.trajectories = bt

  trajs = env_problem_utils.get_completed_trajectories_from_env(
      env, n_trajectories)
  n_done = len(trajs)
  timing_info = {}
  return trajs, n_done, timing_info, self._model_state
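# A minimal sketch of how the async collector might be driven from a training
# loop. `trainer` stands in for the instance this method is bound to
# (constructed with `_async_mode=True`) and `train_env` for a batched
# EnvProblem whose `trajectories/train` subdirectory is being filled by
# separate collector jobs; both names are hypothetical.
#
#   trajs, n_done, _, model_state = trainer.collect_trajectories_async(
#       train_env,
#       train=True,          # read from trajectories/train, not .../eval
#       n_trajectories=16,   # block until 16 completed trajectories exist
#       temperature=1.0)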
def collect_trajectories(
    env,
    policy_fn,
    n_trajectories=1,
    n_observations=None,
    max_timestep=None,
    reset=True,
    len_history_for_policy=32,
    boundary=32,
    state=None,
    temperature=1.0,
    rng=None,
    abort_fn=None,
    raw_trajectory=False,
):
  """Collect trajectories with the given policy net and behaviour.

  Args:
    env: A gym env interface; for now this is not-batched.
    policy_fn: Callable
      (observations(B, T+1) + OBS, lengths(B,)) ->
      (log-probs(B, T+1, C, A), value-preds(B, T+1), state, rng).
    n_trajectories: int, number of trajectories.
    n_observations: int, number of non-terminal observations. NOTE: Exactly
      one of `n_trajectories` and `n_observations` should be None.
    max_timestep: int or None, the index of the maximum time-step at which we
      return the trajectory, None for ending a trajectory only when the env
      returns done.
    reset: bool, true if we want to reset the envs. The envs are also reset
      if `max_timestep` is None or <= 0.
    len_history_for_policy: int or None, the maximum history to keep for
      applying the policy on. If None, use the full history.
    boundary: int, pad the sequences to multiples of this number.
    state: state for `policy_fn`.
    temperature: (float) temperature to sample actions from `policy_fn`.
    rng: jax rng, splittable.
    abort_fn: callable or None. If not None, it is called at every env step;
      if it returns True, trajectory collection is aborted, the env is reset
      and None is returned instead of trajectories.
    raw_trajectory: bool, if True a list of trajectory.Trajectory objects is
      returned, otherwise a list of numpy representations of
      `trajectory.Trajectory` is returned.

  Returns:
    A tuple (trajectories, n_done, timing_info, state) where:
    trajectories: list of (observation, action, reward) tuples, where each
      element `i` is a tuple of numpy arrays with shapes as follows:
      observation[i] = (B, T_i + 1)
      action[i] = (B, T_i)
      reward[i] = (B, T_i)
    n_done: the number of trajectories that ended because the env returned
      done (as opposed to being truncated).
    timing_info: dict of timing measurements, in milliseconds.
    state: the updated `policy_fn` state.
  """
  assert isinstance(env, env_problem.EnvProblem)

  # We need to reset all environments, if we're coming here the first time.
  if reset or max_timestep is None or max_timestep <= 0:
    env.reset()
  else:
    # Clear completed trajectories held internally.
    env.trajectories.clear_completed_trajectories()

  num_done_trajectories = 0

  # The stopping criterion: returns True if we should stop.
  def should_stop():
    if n_trajectories is not None:
      assert n_observations is None
      return env.trajectories.num_completed_trajectories >= n_trajectories
    assert n_observations is not None
    # The number of non-terminal observations is what we want.
    return (env.trajectories.num_completed_time_steps -
            env.trajectories.num_completed_trajectories) >= n_observations

  policy_application_total_time = 0
  env_actions_total_time = 0
  bare_env_run_time = 0

  while not should_stop():
    # Check if we should abort and return nothing.
    if abort_fn and abort_fn():
      # We should also reset the environment, since it will have some
      # trajectories (complete and incomplete) that we want to discard.
      env.reset()
      return None, 0, {}, state

    # Get all the observations for all the active trajectories.
    # Shape is (B, T+1) + OBS.
    # Bucket on whatever length is needed.
    padded_observations, lengths = env.trajectories.observations_np(
        boundary=boundary, len_history_for_policy=len_history_for_policy)

    B = padded_observations.shape[0]  # pylint: disable=invalid-name

    assert B == env.batch_size
    assert (B,) == lengths.shape

    t1 = time.time()
    log_probs, value_preds, state, rng = policy_fn(
        padded_observations, lengths, state=state, rng=rng)
    policy_application_total_time += (time.time() - t1)

    assert B == log_probs.shape[0]

    actions = tl.gumbel_sample(log_probs, temperature)
    if (isinstance(env.action_space, gym.spaces.Discrete) and
        (actions.shape[1] == 1)):
      actions = onp.squeeze(actions, axis=1)

    # Step through the env.
    t1 = time.time()
    _, _, dones, env_infos = env.step(
        actions,
        infos={
            'log_prob_actions': log_probs,
            'value_predictions': value_preds,
        })
    env_actions_total_time += (time.time() - t1)
    bare_env_run_time += sum(
        info['__bare_env_run_time__'] for info in env_infos)

    # Count the number of done trajectories; the others could just have been
    # truncated.
    num_done_trajectories += onp.sum(dones)

    # Get the indices where we are done ...
    done_idxs = env_problem_utils.done_indices(dones)

    # ... and reset those.
    t1 = time.time()
    if done_idxs.size:
      env.reset(indices=done_idxs)
    env_actions_total_time += (time.time() - t1)

    if max_timestep is None or max_timestep < 1:
      continue

    # Are there any trajectories that have exceeded the time limit we set?
    lengths = env.trajectories.trajectory_lengths
    exceeded_time_limit_idxs = env_problem_utils.done_indices(
        lengths > max_timestep)

    # If so, truncate these as well.
    t1 = time.time()
    if exceeded_time_limit_idxs.size:
      # This just cuts the trajectory, doesn't reset the env, so it continues
      # from where it left off.
      env.truncate(indices=exceeded_time_limit_idxs, num_to_keep=1)
    env_actions_total_time += (time.time() - t1)

  # We have the trajectories we need; return a list of triples:
  # (observations, actions, rewards).
  completed_trajectories = (
      env_problem_utils.get_completed_trajectories_from_env(
          env, env.trajectories.num_completed_trajectories,
          raw_trajectory=raw_trajectory))

  timing_info = {
      'trajectory_collection/policy_application':
          policy_application_total_time,
      'trajectory_collection/env_actions': env_actions_total_time,
      'trajectory_collection/env_actions/bare_env': bare_env_run_time,
  }
  timing_info = {k: round(1000 * v, 2) for k, v in timing_info.items()}

  return completed_trajectories, num_done_trajectories, timing_info, state
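# A minimal sketch of driving `collect_trajectories` with a uniform-random
# policy, wrapped in a function so it does not run on import. The env
# constructor `make_batched_env` and the action count `n_actions=4` are
# hypothetical stand-ins; the sketch also assumes a Discrete action space and
# a `policy_fn` that emits log-probs for just the most recent time-step, so
# that the squeeze in the sampling code above applies.
def _example_collect_with_random_policy():
  """Example only: collect 8 trajectories with a uniform-random policy."""

  def uniform_random_policy_fn(padded_observations, lengths, state=None,
                               rng=None):
    del lengths  # Unused by this toy policy.
    b = padded_observations.shape[0]
    n_actions = 4  # Hypothetical: a Discrete(4) action space.
    # Uniform log-probs, log(1 / n_actions), for the last time-step only.
    log_probs = onp.full((b, 1, n_actions), -onp.log(n_actions))
    value_preds = onp.zeros((b, 1))
    return log_probs, value_preds, state, rng

  env = make_batched_env(batch_size=8)  # Hypothetical constructor.
  trajs, n_done, timing_info, _ = collect_trajectories(
      env,
      uniform_random_policy_fn,
      n_trajectories=8,
      max_timestep=200,
      temperature=1.0)
  return trajs, n_done, timing_info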