コード例 #1
0
    def try_ckpt(self, agent, env):
        '''Run whichever of the 'log' and 'eval' checkpoints are currently due.

        'log': when `self.to_ckpt(env, 'log')` holds and
        `self.env.clock.get('epi')` exceeds `self.warmup_epi`, the agent body
        runs its train checkpoint and logs the 'train' summary.

        'eval': when `self.to_ckpt(env, 'eval')` holds, the agent is evaluated
        through `analysis.gen_avg_result` on `self.eval_env`, the body runs its
        eval checkpoint and logs the 'eval' summary, the agent is saved as
        'best' if `body.eval_reward_ma >= body.best_reward_ma`, and (past
        warmup) each of the body's train/eval dataframes holding more than one
        row is passed to `analysis.analyze_session`.
        '''
        agent_body = agent.body

        # Query to_ckpt first: the clock is only read when a log is due.
        log_due = self.to_ckpt(env, 'log')
        if log_due and self.env.clock.get('epi') > self.warmup_epi:
            agent_body.train_ckpt()
            agent_body.log_summary('train')

        if not self.to_ckpt(env, 'eval'):
            return

        # Only the first three averages feed the eval checkpoint; the other
        # four are unpacked so a result of the wrong length still fails here.
        (mean_return, mean_len, mean_success,
         _avg_p, _avg_r, _avg_f1, _avg_book_rate) = analysis.gen_avg_result(
             agent, self.eval_env, self.num_eval)
        agent_body.eval_ckpt(self.eval_env, mean_return, mean_len, mean_success)
        agent_body.log_summary('eval')

        if agent_body.eval_reward_ma >= agent_body.best_reward_ma:
            agent_body.best_reward_ma = agent_body.eval_reward_ma
            agent.save(ckpt='best')

        if self.env.clock.get('epi') > self.warmup_epi:
            # Train dataframe first, then eval, each analyzed under its own mode.
            for df_mode in ('train', 'eval'):
                session_df = getattr(agent_body, f'{df_mode}_df')
                if len(session_df) > 1:  # need > 1 row to calculate stability
                    analysis.analyze_session(self.spec, session_df, df_mode)
コード例 #2
0
def _retro_analyze_session(session_spec_path):
    '''Re-run session analysis using only the path to a saved session spec.

    Reads the spec with `util.read`, takes `spec['meta']['info_prepath']` as
    the file prefix, then for the 'eval' and 'train' modes (in that order)
    reads `{info_prepath}_session_df_{mode}.csv` and passes the spec, the
    loaded dataframe and the mode to `analysis.analyze_session`.
    '''
    spec = util.read(session_spec_path)
    prepath = spec['meta']['info_prepath']
    for mode in ('eval', 'train'):
        csv_path = f'{prepath}_session_df_{mode}.csv'
        mode_df = util.read(csv_path)
        analysis.analyze_session(spec, mode_df, mode)
コード例 #3
0
 def run(self):
     '''Run the session once, close it, and return its metrics.

     When `util.in_eval_lab_modes()` is true only `run_eval` is executed and
     None is returned. Otherwise `run_rl` is executed and the agent body's
     eval dataframe is passed to `analysis.analyze_session` in 'eval' mode;
     that result is returned. `close` is called on both paths before
     returning.
     '''
     if util.in_eval_lab_modes():
         self.run_eval()
         self.close()
         return None

     self.run_rl()
     session_metrics = analysis.analyze_session(
         self.spec, self.agent.body.eval_df, 'eval')
     self.close()
     return session_metrics