def collect_trajectories(self, train=True, n_trajectories=1,
                         n_observations=None, temperature=1.0, abort_fn=None,
                         raw_trajectory=False):
  """Collects trajectories from either the train or the eval environment.

  A fresh RNG key is always drawn via `self._get_rng()`. The key is only
  consumed by the synchronous path.

  Args:
    train: if True, collect from `self.train_env` using `self._max_timestep`
      and `self._should_reset_train_env`; otherwise collect from
      `self.eval_env` using `self._max_timestep_eval`, always resetting.
    n_trajectories: number of trajectories requested from the collector.
    n_observations: observation budget forwarded to the collector.
    temperature: forwarded to the collector (presumably the policy sampling
      temperature -- confirm against `ppo.collect_trajectories`).
    abort_fn: forwarded to `ppo.collect_trajectories` only (synchronous path).
    raw_trajectory: forwarded to `ppo.collect_trajectories` only
      (synchronous path).

  Returns:
    A tuple `(trajs, n_done, timing_info, model_state)`, as produced by the
    chosen collector. `model_state` is also stored on `self._model_state`.
  """
  rng_key = self._get_rng()

  # Start from the training configuration; eval overrides all three knobs.
  settings = (self.train_env, self._max_timestep,
              self._should_reset_train_env)
  if not train:
    settings = (self.eval_env, self._max_timestep_eval, True)
  target_env, step_limit, reset_env = settings

  if self._async_mode:
    # The async reader receives neither `abort_fn`, `raw_trajectory`,
    # nor the step limit.
    outcome = self.collect_trajectories_async(
        target_env,
        train=train,
        n_trajectories=n_trajectories,
        n_observations=n_observations,
        temperature=temperature)
  else:
    outcome = ppo.collect_trajectories(
        target_env,
        policy_fn=self._policy_fun,
        n_trajectories=n_trajectories,
        n_observations=n_observations,
        max_timestep=step_limit,
        state=self._model_state,
        rng=rng_key,
        len_history_for_policy=self._len_history_for_policy,
        boundary=self._boundary,
        reset=reset_env,
        temperature=temperature,
        abort_fn=abort_fn,
        raw_trajectory=raw_trajectory,
    )
  trajs, n_done, timing_info, self._model_state = outcome

  # Only training rollouts count toward the checkpointing counter.
  if train:
    self._n_trajectories_done_since_last_save += n_done
  return trajs, n_done, timing_info, self._model_state
def collect_trajectories(self, train=True, temperature=1.0, abort_fn=None,
                         raw_trajectory=False):
  """Collects one batch of trajectories from the train or eval environment.

  `self._rng` is split on every call. The first half replaces `self._rng`,
  and the second half becomes the key used by the synchronous path. The
  number of trajectories requested equals the chosen environment's
  `batch_size`.

  Args:
    train: if True, collect from `self.train_env` using `self._max_timestep`
      and `self._should_reset`; otherwise collect from `self.eval_env` using
      `self._max_timestep_eval`, always resetting.
    temperature: forwarded to the collector (presumably the policy sampling
      temperature -- confirm against `ppo.collect_trajectories`).
    abort_fn: forwarded to `ppo.collect_trajectories` only (synchronous path).
    raw_trajectory: forwarded to `ppo.collect_trajectories` only
      (synchronous path).

  Returns:
    A tuple `(trajs, n_done, timing_info, model_state)`, as produced by the
    chosen collector. `model_state` is also stored on `self._model_state`.
  """
  self._rng, rng_key = jax_random.split(self._rng)

  # Start from the training configuration; eval overrides all three knobs.
  settings = (self.train_env, self._max_timestep, self._should_reset)
  if not train:
    settings = (self.eval_env, self._max_timestep_eval, True)
  target_env, step_limit, reset_env = settings

  # Request as many trajectories as the environment batch holds.
  batch = target_env.batch_size

  if self._async_mode:
    # The async reader receives neither `abort_fn`, `raw_trajectory`,
    # nor the step limit.
    outcome = self.collect_trajectories_async(
        target_env,
        train=train,
        n_trajectories=batch,
        temperature=temperature)
  else:
    outcome = ppo.collect_trajectories(
        target_env,
        policy_fn=self._policy_fun,
        n_trajectories=batch,
        max_timestep=step_limit,
        state=self._model_state,
        rng=rng_key,
        len_history_for_policy=self._len_history_for_policy,
        boundary=self._boundary,
        reset=reset_env,
        temperature=temperature,
        abort_fn=abort_fn,
        raw_trajectory=raw_trajectory,
    )
  trajs, n_done, timing_info, self._model_state = outcome

  # Only training rollouts count toward the running total.
  if train:
    self._n_trajectories_done += n_done
  return trajs, n_done, timing_info, self._model_state