Exemplo n.º 1
0
    def collect_trajectories(self,
                             train=True,
                             n_trajectories=1,
                             n_observations=None,
                             temperature=1.0,
                             abort_fn=None,
                             raw_trajectory=False):
        """Collect trajectories from either the training or the evaluation env.

        Args:
          train: when True, collect from ``self.train_env`` using
            ``self._max_timestep`` and ``self._should_reset_train_env``; when
            False, collect from ``self.eval_env`` using
            ``self._max_timestep_eval`` and always request a reset.
          n_trajectories: number of trajectories requested from the collector.
          n_observations: forwarded unchanged to the collector.
          temperature: forwarded unchanged to the collector.
          abort_fn: forwarded to ``ppo.collect_trajectories`` only; it is not
            passed to ``collect_trajectories_async``.
          raw_trajectory: forwarded to ``ppo.collect_trajectories`` only; it is
            not passed to ``collect_trajectories_async``.

        Returns:
          A 4-tuple ``(trajs, n_done, timing_info, model_state)`` exactly as the
          collector produced it. ``model_state`` is also stored on
          ``self._model_state``. When ``train`` is True, ``n_done`` is added to
          ``self._n_trajectories_done_since_last_save``.
        """
        # Drawn unconditionally, although only the synchronous collector
        # below actually consumes it.
        rng_key = self._get_rng()

        # Start from the training configuration; evaluation overrides it.
        env, horizon, reset_env = (
            self.train_env,
            self._max_timestep,
            self._should_reset_train_env,
        )
        if not train:
            env, horizon, reset_env = (
                self.eval_env, self._max_timestep_eval, True)

        if self._async_mode:
            # Async mode reads the trajectories required for the epoch; the
            # max timestep, rng key, abort_fn and raw_trajectory are not
            # forwarded on this path.
            outcome = self.collect_trajectories_async(
                env,
                train=train,
                n_trajectories=n_trajectories,
                n_observations=n_observations,
                temperature=temperature)
        else:
            outcome = ppo.collect_trajectories(
                env,
                policy_fn=self._policy_fun,
                n_trajectories=n_trajectories,
                n_observations=n_observations,
                max_timestep=horizon,
                state=self._model_state,
                rng=rng_key,
                len_history_for_policy=self._len_history_for_policy,
                boundary=self._boundary,
                reset=reset_env,
                temperature=temperature,
                abort_fn=abort_fn,
                raw_trajectory=raw_trajectory,
            )
        trajs, n_done, timing_info, self._model_state = outcome

        if train:
            # Only training collections update this counter.
            self._n_trajectories_done_since_last_save += n_done

        return trajs, n_done, timing_info, self._model_state
Exemplo n.º 2
0
    def collect_trajectories(self,
                             train=True,
                             temperature=1.0,
                             abort_fn=None,
                             raw_trajectory=False):
        """Collect one env batch of trajectories from the train or eval env.

        The trainer RNG ``self._rng`` is split first. The first half replaces
        ``self._rng`` and the second half seeds the synchronous collector. The
        number of trajectories requested equals the selected env's
        ``batch_size``.

        Args:
          train: when True, use ``self.train_env`` with ``self._max_timestep``
            and ``self._should_reset``; when False, use ``self.eval_env`` with
            ``self._max_timestep_eval`` and always request a reset.
          temperature: forwarded unchanged to the collector.
          abort_fn: forwarded to ``ppo.collect_trajectories`` only.
          raw_trajectory: forwarded to ``ppo.collect_trajectories`` only.

        Returns:
          A 4-tuple ``(trajs, n_done, timing_info, model_state)`` exactly as the
          collector produced it. ``model_state`` is also stored on
          ``self._model_state``. When ``train`` is True, ``n_done`` is added to
          ``self._n_trajectories_done``.
        """
        # Advance the stored RNG and keep a fresh subkey for this call.
        split_keys = jax_random.split(self._rng)
        self._rng, rng_key = split_keys

        # Start from the training configuration; evaluation overrides it.
        env, horizon, reset_env = (
            self.train_env,
            self._max_timestep,
            self._should_reset,
        )
        if not train:
            env, horizon, reset_env = (
                self.eval_env, self._max_timestep_eval, True)

        # One trajectory per env in the batch.
        num_trajs = env.batch_size

        if self._async_mode:
            # Async mode reads the trajectories required for the epoch; the
            # max timestep, rng key, abort_fn and raw_trajectory are not
            # forwarded on this path.
            outcome = self.collect_trajectories_async(
                env,
                train=train,
                n_trajectories=num_trajs,
                temperature=temperature)
        else:
            outcome = ppo.collect_trajectories(
                env,
                policy_fn=self._policy_fun,
                n_trajectories=num_trajs,
                max_timestep=horizon,
                state=self._model_state,
                rng=rng_key,
                len_history_for_policy=self._len_history_for_policy,
                boundary=self._boundary,
                reset=reset_env,
                temperature=temperature,
                abort_fn=abort_fn,
                raw_trajectory=raw_trajectory,
            )
        trajs, n_done, timing_info, self._model_state = outcome

        if train:
            # Only training collections update this counter.
            self._n_trajectories_done += n_done

        return trajs, n_done, timing_info, self._model_state