def try_ckpt(self, agent, env):
    '''
    Check then run checkpoint log/eval.

    Two independent checkpoints are considered on every call:

    - 'log': when `self.to_ckpt(env, 'log')` is truthy and the episode count
      read from `self.env.clock` exceeds `self.warmup_epi`, the body runs its
      training checkpoint and logs the 'train' summary.
    - 'eval': when `self.to_ckpt(env, 'eval')` is truthy, the agent is
      evaluated on `self.eval_env` for `self.num_eval` runs, the body records
      the eval checkpoint and logs the 'eval' summary, the agent is saved as
      'best' if its eval reward moving average did not drop below the best
      so far, and (past warmup) the train/eval dataframes are analyzed.

    @param agent: object exposing `.body` and `.save(ckpt=...)`
    @param env: environment forwarded to `self.to_ckpt`
    @returns None
    '''
    body = agent.body

    # NOTE(review): the warmup check reads self.env.clock rather than the
    # `env` argument - presumably they are the same object; confirm with caller.
    if self.to_ckpt(env, 'log') and self.env.clock.get('epi') > self.warmup_epi:
        body.train_ckpt()
        body.log_summary('train')

    if not self.to_ckpt(env, 'eval'):
        return

    # Only the first three averages are consumed here; the remaining four are
    # unpacked (to keep the 7-value arity check) but intentionally unused.
    avg_return, avg_len, avg_success, _avg_p, _avg_r, _avg_f1, _avg_book_rate = analysis.gen_avg_result(
        agent, self.eval_env, self.num_eval)
    body.eval_ckpt(self.eval_env, avg_return, avg_len, avg_success)
    body.log_summary('eval')

    # `>=` means a tie with the current best also overwrites the 'best' checkpoint.
    if body.eval_reward_ma >= body.best_reward_ma:
        body.best_reward_ma = body.eval_reward_ma
        agent.save(ckpt='best')

    if self.env.clock.get('epi') > self.warmup_epi:
        # analyze_session is called for its side effects only; its return
        # value was previously bound to an unused local and is now discarded.
        if len(body.train_df) > 1:  # need > 1 row to calculate stability
            analysis.analyze_session(self.spec, body.train_df, 'train')
        if len(body.eval_df) > 1:  # need > 1 row to calculate stability
            analysis.analyze_session(self.spec, body.eval_df, 'eval')
def _retro_analyze_session(session_spec_path):
    '''
    Retro-analyze one session when only the path to its spec is known.

    The spec is loaded with `util.read`; its `meta.info_prepath` entry is the
    prefix used to locate the per-mode session dataframes
    (`<info_prepath>_session_df_<mode>.csv`). Each dataframe is loaded and
    handed to `analysis.analyze_session`, 'eval' first and then 'train'.

    @param session_spec_path: path to the session spec file
    @returns None
    '''
    spec = util.read(session_spec_path)
    prepath = spec['meta']['info_prepath']
    for mode in ('eval', 'train'):
        df_path = f'{prepath}_session_df_{mode}.csv'
        mode_df = util.read(df_path)
        analysis.analyze_session(spec, mode_df, mode)
def run(self):
    '''
    Run the session to completion, close it, and return its metrics.

    In eval lab modes (per `util.in_eval_lab_modes()`) only `self.run_eval()`
    is executed and no metrics are produced. Otherwise `self.run_rl()` is
    executed and the agent body's eval dataframe is analyzed.

    @returns the result of `analysis.analyze_session(..., 'eval')`, or None
             when running in an eval lab mode
    '''
    session_metrics = None  # stays None in eval lab modes
    if util.in_eval_lab_modes():
        self.run_eval()
    else:
        self.run_rl()
        session_metrics = analysis.analyze_session(
            self.spec, self.agent.body.eval_df, 'eval')
    # close() is reached only when the run above finished without raising
    self.close()
    return session_metrics