def _convert_train_freq(self) -> None:
    """
    Convert `train_freq` parameter (int or tuple) to a TrainFreq object.

    A non-tuple value is treated as a frequency expressed in "step" units;
    a tuple is read as ``(frequency, unit)``. Nothing is done when
    ``self.train_freq`` is already a ``TrainFreq``.

    :raises ValueError: if the unit is rejected by ``TrainFrequencyUnit``
        or if the frequency is not an integer.
    """
    if isinstance(self.train_freq, TrainFreq):
        # Already converted, nothing to do
        return

    train_freq = self.train_freq

    # The value of the train frequency will be checked later
    if not isinstance(train_freq, tuple):
        train_freq = (train_freq, "step")

    try:
        train_freq = (train_freq[0], TrainFrequencyUnit(train_freq[1]))
    except ValueError as e:
        # Chain explicitly so the traceback shows the unit conversion failure
        # as the direct cause of this more descriptive error.
        raise ValueError(
            f"The unit of the `train_freq` must be either 'step' or 'episode' not '{train_freq[1]}'!"
        ) from e

    # NOTE: bool is a subclass of int, so True/False pass this check.
    if not isinstance(train_freq[0], int):
        raise ValueError(
            f"The frequency of `train_freq` must be an integer and not {train_freq[0]}"
        )

    self.train_freq = TrainFreq(*train_freq)
def setup_buffer(self, num_samples):
    """
    Swap in a fresh replay buffer and fill it by collecting rollouts.

    The current buffer is kept in ``self._old_buffer``. A new ``ReplayBuffer``
    is created with ``num_samples`` as its first argument (presumably the
    buffer capacity -- confirm against ``ReplayBuffer``), the environment is
    reset, and ``collect_rollouts`` is called with a train frequency of
    ``num_samples`` steps, writing into the new buffer.

    :param num_samples: first argument of the new ``ReplayBuffer`` and the
        number of "step" units requested from ``collect_rollouts``.
    """
    assert self.n_envs == 1, (
        "I don't think multiple envs works for offline policies, "
        "but you can check and make suitable updates"
    )

    # Keep a handle on the current buffer before it is replaced below.
    self._old_buffer = self.replay_buffer
    noop_callback = DoNothingCallback(self)

    buffer_args = (
        num_samples,
        self.observation_space,
        self.action_space,
        self.device,
    )
    self.replay_buffer = ReplayBuffer(*buffer_args)

    # NOTE(review): the value returned by reset() is discarded -- confirm that
    # collect_rollouts does not depend on a cached last observation.
    self.env.reset()

    self.collect_rollouts(
        self.env,
        train_freq=TrainFreq(num_samples, TrainFrequencyUnit("step")),
        action_noise=self.action_noise,
        callback=noop_callback,
        learning_starts=0,
        replay_buffer=self.replay_buffer,
        log_interval=10,
    )