def step_epochs(self):
    """Step through each epoch.

    This function returns a generator. When iterated through, this
    generator automatically performs services such as snapshotting and
    log management. It is used inside train() in each algorithm.

    The generator initializes two variables: `self.step_itr` and
    `self.step_path`. To use the generator, these two have to be
    updated manually in each epoch, as the example shows below.

    Yields:
        int: The next training epoch.

    Examples:
        for epoch in runner.step_epochs():
            runner.step_path = runner.obtain_samples(...)
            self.train_once(...)
            runner.step_itr += 1

    """
    try:
        self._start_worker()
        self._start_time = time.time()
        self.step_itr = (self.train_args.start_epoch *
                         self.train_args.n_epoch_cycles)
        self.step_path = None

        for epoch in range(self.train_args.start_epoch,
                           self.train_args.n_epochs):
            self._itr_start_time = time.time()
            with logger.prefix('epoch #%d | ' % epoch):
                yield epoch
                save_path = (self.step_path
                             if self.train_args.store_paths else None)
                self.save(epoch, save_path)
                self.log_diagnostics(self.train_args.pause_for_plot)
                logger.dump_all(self.step_itr)
                tabular.clear()
    finally:
        self._shutdown_worker()
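# A minimal usage sketch for step_epochs(), following the docstring example
# above. The algorithm-side names used here (`train_once`, the shape of
# `obtain_samples`) are assumptions drawn from that example, not a confirmed
# part of this library's API.
def _example_algo_train(self, runner):
    """Sketch: an algorithm's train() driving runner.step_epochs()."""
    last_return = None
    for epoch in runner.step_epochs():
        # Collect a batch for this epoch and store it on the runner so the
        # generator can include it in the snapshot (when store_paths=True).
        runner.step_path = runner.obtain_samples(runner.step_itr)
        last_return = self.train_once(runner.step_itr, runner.step_path)
        # The algorithm is responsible for advancing the iteration counter.
        runner.step_itr += 1
    return last_return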
def _train(self,
           n_epochs,
           n_epoch_cycles,
           batch_size,
           plot,
           store_paths,
           pause_for_plot,
           start_epoch=0):
    """Start actual training.

    Args:
        n_epochs (int): Number of epochs.
        n_epoch_cycles (int): Number of batches of samples in each
            epoch. This is only useful for off-policy algorithms. For
            on-policy algorithms this value should always be 1.
        batch_size (int): Number of steps in each batch.
        plot (bool): Visualize the policy by doing a rollout after each
            epoch.
        store_paths (bool): Save paths in snapshot.
        pause_for_plot (bool): Pause for plot.
        start_epoch (int): (internal) The starting epoch. Used for
            experiment resuming.

    Returns:
        The average return in the last epoch cycle.

    """
    assert self.has_setup, ('Use Runner.setup() to setup runner before '
                            'training.')

    # Save arguments for restore
    self.train_args = SimpleNamespace(n_epochs=n_epochs,
                                      n_epoch_cycles=n_epoch_cycles,
                                      batch_size=batch_size,
                                      plot=plot,
                                      store_paths=store_paths,
                                      pause_for_plot=pause_for_plot,
                                      start_epoch=start_epoch)

    self.start_worker()
    self.start_time = time.time()
    itr = start_epoch * n_epoch_cycles

    last_return = None
    for epoch in range(start_epoch, n_epochs):
        self.itr_start_time = time.time()
        paths = None
        with logger.prefix('epoch #%d | ' % epoch):
            for cycle in range(n_epoch_cycles):
                paths = self.obtain_samples(itr, batch_size)
                paths = self.sampler.process_samples(itr, paths)
                last_return = self.algo.train_once(itr, paths)
                itr += 1
            self.save(epoch, paths if store_paths else None)
            self.log_diagnostics(pause_for_plot)
            logger.dump_all(itr)
            tabular.clear()

    self.shutdown_worker()
    return last_return
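# Hedged sketch: because _train() saves its arguments in self.train_args
# "for restore", a resume path can re-invoke it with the saved values and a
# later start_epoch. The `_example_resume` name and the idea that
# `start_epoch` has been updated by the restore logic are assumptions.
def _example_resume(self):
    """Sketch: continue training from the saved train_args."""
    args = self.train_args
    return self._train(n_epochs=args.n_epochs,
                       n_epoch_cycles=args.n_epoch_cycles,
                       batch_size=args.batch_size,
                       plot=args.plot,
                       store_paths=args.store_paths,
                       pause_for_plot=args.pause_for_plot,
                       start_epoch=args.start_epoch)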
def train(self,
          n_epochs,
          batch_size=None,
          plot=False,
          store_episodes=False,
          pause_for_plot=False):
    """Start training.

    Args:
        n_epochs (int): Number of epochs.
        batch_size (int or None): Number of environment steps in one
            batch.
        plot (bool): Visualize an episode from the policy after each
            epoch.
        store_episodes (bool): Save episodes in snapshot.
        pause_for_plot (bool): Pause for plot.

    Raises:
        NotSetupError: If train() is called before setup().

    Returns:
        float: The average return from the last iteration.

    """
    if not self._has_setup:
        raise NotSetupError(
            'Use setup() to setup trainer before training.')

    self.batch_size = batch_size
    self.store_episodes = store_episodes
    self.pause_for_plot = pause_for_plot
    self._plot = plot

    self._start_time = time.time()
    last_return = None
    for itr in range(self.start_itr, self.n_itr):
        self._itr_start_time = time.time()
        with logger.prefix(f'itr #{itr} | '):
            # Train the policy.
            logger.log('Optimizing policy...')
            last_return = self._algo.train(self)

            # Compute IRL and update the reward function.
            logger.log('Obtaining paths...')
            paths = self.obtain_samples(itr)
            logger.log('Processing paths...')
            paths = self._train_irl(paths, itr=itr)
            samples_data = self.process_samples(itr, paths)

            logger.log('Logging diagnostics...')
            logger.log('Time %.2f s' % (time.time() - self._start_time))
            logger.log('EpochTime %.2f s' %
                       (time.time() - self._itr_start_time))
            tabular.record('TotalEnvSteps', self._stats.total_env_steps)
            self.log_diagnostics(paths)

            logger.log('Saving snapshot...')
            self.save(itr, paths=paths)
            logger.log('Saved')

            tabular.record('Time', time.time() - self._start_time)
            tabular.record('ItrTime', time.time() - self._itr_start_time)
            logger.dump_all(itr)
            tabular.clear()

    self._shutdown_worker()
    return last_return
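# Example driver for train(). This is a sketch only: `IRLTrainer`,
# `snapshot_config`, `env`, `algo`, and the setup() signature are
# hypothetical placeholders for however this trainer is constructed.
#
#     trainer = IRLTrainer(snapshot_config)
#     trainer.setup(algo=algo, env=env)
#     last_return = trainer.train(n_epochs=100, batch_size=4000)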