Ejemplo n.º 1
0
    def _convert_train_freq(self) -> None:
        """
        Convert `train_freq` parameter (int or tuple)
        to a TrainFreq object.

        A non-tuple value ``n`` is interpreted as ``(n, "step")``. If
        ``self.train_freq`` is already a ``TrainFreq`` it is left untouched.

        :raises ValueError: if a tuple does not hold exactly two items,
            if the unit is not accepted by ``TrainFrequencyUnit``,
            or if the frequency is not an integer.
        """
        if isinstance(self.train_freq, TrainFreq):
            return

        train_freq = self.train_freq

        # The value of the train frequency will be checked later
        if not isinstance(train_freq, tuple):
            train_freq = (train_freq, "step")

        # Reject malformed tuples up front: a short one would otherwise surface
        # as a bare IndexError and extra items would be silently dropped.
        if len(train_freq) != 2:
            raise ValueError(
                f"`train_freq` must be an integer or a (frequency, unit) tuple and not {train_freq}"
            )

        frequency, raw_unit = train_freq
        try:
            unit = TrainFrequencyUnit(raw_unit)
        except ValueError as err:
            # Chain the lookup failure so the traceback keeps the root cause.
            raise ValueError(
                f"The unit of the `train_freq` must be either 'step' or 'episode' not '{raw_unit}'!"
            ) from err

        if not isinstance(frequency, int):
            raise ValueError(
                f"The frequency of `train_freq` must be an integer and not {frequency}"
            )

        self.train_freq = TrainFreq(frequency, unit)
 def setup_buffer(self, num_samples):
     """
     Swap in a fresh replay buffer and fill it through ``collect_rollouts``.

     The current buffer is kept in ``self._old_buffer``. The environment is
     reset, then ``collect_rollouts`` is called once with a train frequency
     of ``num_samples`` steps and ``learning_starts=0``.

     :param num_samples: passed as the first argument of ``ReplayBuffer``
         (presumably its capacity -- confirm against ``ReplayBuffer``) and
         used as the step count of the rollout's train frequency.
     """
     assert self.n_envs == 1, "I don't think multiple envs works for offline policies, but you can check and make suitable updates"

     # Keep a handle on the buffer being replaced.
     self._old_buffer = self.replay_buffer

     # Built before the buffer is swapped, matching the original ordering.
     noop_callback = DoNothingCallback(self)

     fresh_buffer = ReplayBuffer(num_samples, self.observation_space, self.action_space, self.device)
     self.replay_buffer = fresh_buffer

     self.env.reset()

     rollout_kwargs = {
         "train_freq": TrainFreq(num_samples, TrainFrequencyUnit("step")),
         "action_noise": self.action_noise,
         "callback": noop_callback,
         "learning_starts": 0,
         "replay_buffer": self.replay_buffer,
         "log_interval": 10,
     }
     self.collect_rollouts(self.env, **rollout_kwargs)