def _get_diagnostics(self):
    """Assemble an OrderedDict of diagnostics from the replay buffer, trainer,
    exploration and evaluation collectors (and envs), plus epoch timings."""
    algo_log = OrderedDict()
    append_log(algo_log, self.replay_buffer.get_diagnostics(),
               prefix='replay_buffer/')
    append_log(algo_log, self.trainer.get_diagnostics(), prefix='trainer/')

    # Exploration
    append_log(algo_log, self.expl_data_collector.get_diagnostics(),
               prefix='exploration/')
    expl_paths = self.expl_data_collector.get_epoch_paths()
    if hasattr(self.expl_env, 'get_diagnostics'):
        append_log(algo_log, self.expl_env.get_diagnostics(expl_paths),
                   prefix='exploration/')
    append_log(algo_log, eval_util.get_generic_path_information(expl_paths),
               prefix="exploration/")

    # Eval
    append_log(algo_log, self.eval_data_collector.get_diagnostics(),
               prefix='evaluation/')
    eval_paths = self.eval_data_collector.get_epoch_paths()
    if hasattr(self.eval_env, 'get_diagnostics'):
        append_log(algo_log, self.eval_env.get_diagnostics(eval_paths),
                   prefix='evaluation/')
    append_log(algo_log, eval_util.get_generic_path_information(eval_paths),
               prefix="evaluation/")

    timer.stamp('logging')
    append_log(algo_log, _get_epoch_timings())
    algo_log['epoch'] = self.epoch
    return algo_log
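`_get_diagnostics` relies on an `append_log` helper and a module-level `_get_epoch_timings` function defined elsewhere in the codebase. As a rough guide to what the method expects, here is a minimal sketch of the behavior `append_log` is assumed to have (merging one diagnostics dict into another under a key prefix); the project's actual definition may differ:

def append_log(log_dict, to_add, prefix=''):
    # Sketch of the assumed helper, not the project's real implementation:
    # copy every diagnostic into the aggregate log, namespacing each key,
    # e.g. prefix='trainer/' turns 'loss' into 'trainer/loss'.
    for key, value in to_add.items():
        log_dict[prefix + key] = value
    return log_dict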
def _train(self):
    """Run one epoch of online training: evaluation rollouts, then alternating
    single-step exploration and gradient updates."""
    done = (self.epoch == self.num_epochs)
    if done:
        return OrderedDict(), done

    self.training_mode(False)
    # Fill the replay buffer with warm-up exploration before the first epoch.
    if self.min_num_steps_before_training > 0 and self.epoch == 0:
        self.expl_data_collector.collect_new_steps(
            self.max_path_length,
            self.min_num_steps_before_training,
            discard_incomplete_paths=False,
        )
        init_expl_paths = self.expl_data_collector.get_epoch_paths()
        self.replay_buffer.add_paths(init_expl_paths)
        self.expl_data_collector.end_epoch(-1)

    num_trains_per_expl_step = (
        self.num_trains_per_train_loop // self.num_expl_steps_per_train_loop
    )

    self.eval_data_collector.collect_new_paths(
        self.max_path_length,
        self.num_eval_steps_per_epoch,
        discard_incomplete_paths=True,
    )
    timer.stamp('evaluation sampling')

    for _ in range(self.num_train_loops_per_epoch):
        for _ in range(self.num_expl_steps_per_train_loop):
            self.expl_data_collector.collect_new_steps(
                self.max_path_length,
                1,  # num steps
                discard_incomplete_paths=False,
            )
            timer.stamp('exploration sampling', unique=False)

            self.training_mode(True)
            for _ in range(num_trains_per_expl_step):
                train_data = self.replay_buffer.random_batch(self.batch_size)
                self.trainer.train(train_data)
            timer.stamp('training', unique=False)
            self.training_mode(False)

    new_expl_paths = self.expl_data_collector.get_epoch_paths()
    self.replay_buffer.add_paths(new_expl_paths)
    timer.stamp('data storing', unique=False)

    log_stats = self._get_diagnostics()
    return log_stats, False
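A note on the loop structure above: `num_trains_per_expl_step` sets the update-to-data ratio via integer division. For example, with `num_trains_per_train_loop = 1000` and `num_expl_steps_per_train_loop = 1000` the algorithm takes one gradient step per environment step, while `num_trains_per_train_loop = 4000` would take four; any remainder from the division is dropped.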
def _train(self):
    """Run one epoch of (optionally offline) training: evaluation rollouts,
    optional exploration, then gradient updates on replay-buffer batches."""
    done = (self.epoch == self.num_epochs)
    if done:
        return OrderedDict(), done

    # Warm-up exploration is only needed when learning online, i.e. neither
    # pure batch RL nor training from a fixed dataset.
    if self.epoch == 0 and self.min_num_steps_before_training > 0 and \
            not self.batch_rl and not self.dataset:
        init_expl_paths = self.expl_data_collector.collect_new_paths(
            self.max_path_length,
            self.min_num_steps_before_training,
            discard_incomplete_paths=False,
        )
        self.replay_buffer.add_paths(init_expl_paths)
        self.expl_data_collector.end_epoch(-1)

    if self.dataset:
        # No environment evaluation when training purely from a dataset.
        pass
    elif self.q_learning_alg:
        # Q-learning-style evaluation: the collector is given an explicit policy function.
        self.eval_data_collector.collect_new_paths(
            self.policy_fn,
            self.max_path_length,
            self.num_eval_steps_per_epoch,
            discard_incomplete_paths=True,
        )
    else:
        self.eval_data_collector.collect_new_paths(
            self.max_path_length,
            self.num_eval_steps_per_epoch,
            discard_incomplete_paths=True,
        )

    if self.vae_eval_data_collector is not None:
        self.vae_eval_data_collector.collect_new_paths(
            self.policy_fn_vae,
            self.max_path_length,
            self.num_eval_steps_per_epoch,
            discard_incomplete_paths=True,
        )
    timer.stamp('evaluation sampling')

    for _ in range(self.num_train_loops_per_epoch):
        if not self.batch_rl and not self.dataset:
            # Online case: collect fresh exploration paths and add them to the buffer.
            new_expl_paths = self.expl_data_collector.collect_new_paths(
                self.max_path_length,
                self.num_expl_steps_per_train_loop,
                discard_incomplete_paths=False,
            )
            timer.stamp('exploration sampling', unique=False)
            self.replay_buffer.add_paths(new_expl_paths)
            timer.stamp('data storing', unique=False)
        elif self.eval_both:
            # Now evaluate the policy here:
            policy_fn = self.policy_fn
            if self.trainer.discrete:
                policy_fn = self.policy_fn_discrete
            self.expl_data_collector.collect_new_paths(
                policy_fn,
                self.max_path_length,
                self.num_eval_steps_per_epoch,
                discard_incomplete_paths=True,
            )
            timer.stamp('policy fn evaluation')

        for _ in range(self.num_trains_per_train_loop):
            train_data = self.replay_buffer.random_batch(self.batch_size)
            self.trainer.train(train_data)
        timer.stamp('training', unique=False)

    log_stats = self._get_diagnostics()
    return log_stats, False
def _end_epoch(self):
    self._train_vae(self.epoch)
    timer.stamp('vae training')
    super()._end_epoch()
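The `(log_stats, done)` return value of `_train` and the `super()._end_epoch()` call suggest these methods are driven by an epoch loop in the base algorithm class, which is not shown here. Purely as an illustration, a hypothetical driver might look like the following; the method name `train`, the `logger.record_dict` call, and the assumption that `_end_epoch` advances `self.epoch` are guesses, not the actual base-class API:

def train(self):
    # Hypothetical outer loop (not the real base class): run epochs until
    # `_train` reports that `self.epoch` has reached `self.num_epochs`.
    done = False
    while not done:
        log_stats, done = self._train()
        if done:
            break
        logger.record_dict(log_stats)  # assumed logging utility
        self._end_epoch()              # assumed to advance self.epoch and run per-epoch bookkeeping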