Example #1
0
def _retro_analyze_trial(trial_spec_path):
    '''Retro-analyze a single trial given only the path to its spec file.'''
    spec = util.read(trial_spec_path)
    meta = spec['meta']
    prepath = meta['info_prepath']
    # load the saved eval metrics of every session belonging to this trial
    metrics_list = []
    for sess_idx in range(meta['max_session']):
        metrics_list.append(util.read(f'{prepath}_s{sess_idx}_session_metrics_eval.pkl'))
    analysis.analyze_trial(spec, metrics_list)
Example #2
0
def analyze_eval_trial(spec, info_space, predir):
    '''Construct a Trial from spec/info_space and run trial analysis for eval data.'''
    from slm_lab.experiment.control import Trial
    eval_trial = Trial(spec, info_space)
    ckpt = ps.get(info_space, 'ckpt')
    eval_trial.session_data_dict = session_data_dict_from_file(predir, eval_trial.index, ckpt)
    # skip zipping during eval analysis; it is slow otherwise
    analysis.analyze_trial(eval_trial, zip=False)
Example #3
0
def retro_analyze_trials(predir):
    '''Retro-analyze all trial-level data files found under `predir`.

    For every `*_trial_df.csv` in `predir`: rebuild the Trial, reload its
    session data from file, run trial analysis, and merge the resulting
    fitness values back into the matching `*_trial_data.json` written by
    ray search (when that file exists).
    '''
    logger.info('Retro-analyzing trials from file')
    from slm_lab.experiment.control import Trial
    filenames = ps.filter_(os.listdir(predir),
                           lambda filename: filename.endswith('_trial_df.csv'))
    for idx, filename in enumerate(filenames):
        # bugfix: interpolate the actual filename (was a literal placeholder)
        filepath = f'{predir}/{filename}'
        prepath = filepath.replace('_trial_df.csv', '')
        spec, info_space = util.prepath_to_spec_info_space(prepath)
        trial_index, _ = util.prepath_to_idxs(prepath)
        trial = Trial(spec, info_space)
        trial.session_data_dict = session_data_dict_from_file(
            predir, trial_index, ps.get(info_space, 'ckpt'))
        # zip only for the last trial; renamed from `zip` to avoid shadowing the builtin
        do_zip = (idx == len(filenames) - 1)
        trial_fitness_df = analysis.analyze_trial(trial, do_zip)

        # write trial_data that was written from ray search
        trial_data_filepath = filepath.replace('_trial_df.csv',
                                               '_trial_data.json')
        if os.path.exists(trial_data_filepath):
            fitness_vec = trial_fitness_df.iloc[0].to_dict()
            fitness = analysis.calc_fitness(trial_fitness_df)
            trial_data = util.read(trial_data_filepath)
            trial_data.update({
                **fitness_vec,
                'fitness': fitness,
                'trial_index': trial_index,
            })
            util.write(trial_data, trial_data_filepath)
Example #4
0
 def run(self):
     '''Run all sessions for this trial (serial or distributed per spec), analyze, and return scalar metrics.'''
     # keep `== False` semantics: only an explicit False selects the serial path
     run_fn = self.run_sessions if self.spec['meta'].get('distributed') == False else self.run_distributed_sessions
     session_metrics_list = run_fn()
     metrics = analysis.analyze_trial(self.spec, session_metrics_list)
     self.close()
     return metrics['scalar']
Example #5
0
 def run(self):
     '''Run every session of this trial, analyze, and return (df, fitness_df).'''
     max_session = _.get(self.spec, 'meta.max_session')
     for sess_idx in range(max_session):
         logger.debug(f'session {sess_idx}')
         session = self.init_session()
         self.session_df_dict[sess_idx], self.session_fitness_df_dict[sess_idx] = session.run()
     self.df, self.fitness_df = analysis.analyze_trial(self)
     self.close()
     return self.df, self.fitness_df
Example #6
0
 def run(self):
     '''Run sessions (distributed when configured), collect their data, then analyze the trial.'''
     if not self.spec['meta'].get('distributed'):
         session_datas = self.run_sessions()
     else:
         session_datas = self.run_distributed_sessions()
     # key each session's data by its session index (first index label)
     self.session_data_dict = dict((data.index[0], data) for data in session_datas)
     self.data = analysis.analyze_trial(self)
     self.close()
     return self.data
Example #7
0
 def run(self):
     '''Tick and snapshot info_space per session, run sessions (parallel in train mode), then analyze.'''
     meta = self.spec['meta']
     num_cpus = ps.get(meta, 'resources.num_cpus', util.NUM_CPUS)
     info_spaces = []
     for _s in range(meta['max_session']):
         self.info_space.tick('session')
         info_spaces.append(deepcopy(self.info_space))
     use_parallel = util.get_lab_mode() == 'train' and len(info_spaces) > 1
     if use_parallel:
         session_datas = util.parallelize_fn(self.init_session_and_run, info_spaces, num_cpus)
     else:
         # serial path keeps rendering possible while debugging
         session_datas = [self.init_session_and_run(isp) for isp in info_spaces]
     self.session_data_dict = {data.index[0]: data for data in session_datas}
     self.data = analysis.analyze_trial(self)
     self.close()
     return self.data
Example #8
0
def run_trial_test_dist(spec_file, spec_name=False):
    '''Smoke-test a synced-distributed trial: build global nets, run sessions, check analysis output.'''
    spec = spec_util.get(spec_file, spec_name)
    spec = spec_util.override_spec(spec, 'test')
    spec_util.tick(spec, 'trial')
    spec['meta']['distributed'] = 'synced'
    spec['meta']['max_session'] = 2

    trial = Trial(spec)
    # manually run the logic to obtain global nets for testing to ensure global net gets updated
    global_nets = trial.init_global_nets()
    # only test first network; multiagent gives a list of net dicts
    first_nets = global_nets[0] if ps.is_list(global_nets) else global_nets
    net = list(first_nets.values())[0]
    session_metrics_list = trial.parallelize_sessions(global_nets)
    trial_metrics = analysis.analyze_trial(spec, session_metrics_list)
    trial.close()
    assert isinstance(trial_metrics, dict)
Example #9
0
 def run(self):
     '''Snapshot a ticked info_space per session, run them (parallel in train mode), then analyze.'''
     info_spaces = []
     for _s in range(self.spec['meta']['max_session']):
         self.info_space.tick('session')
         info_spaces.append(deepcopy(self.info_space))
     if self.spec['meta']['train_mode']:
         session_datas = util.parallelize_fn(self.init_session_and_run, info_spaces)
     else:
         # serial path allows rendering while debugging
         session_datas = [self.init_session_and_run(isp) for isp in info_spaces]
     self.session_data_dict = {data.index[0]: data for data in session_datas}
     self.data = analysis.analyze_trial(self)
     self.close()
     return self.data
Example #10
0
def run_trial_test_dist(spec_file, spec_name=False):
    '''Smoke-test a distributed trial end-to-end and assert the analysis returns a DataFrame.'''
    spec = spec_util.get(spec_file, spec_name)
    spec = spec_util.override_test_spec(spec)
    info_space = InfoSpace()
    info_space.tick('trial')
    spec['meta']['distributed'] = True
    spec['meta']['max_session'] = 2

    trial = Trial(spec, info_space)
    # manually run the logic to obtain global nets for testing to ensure global net gets updated
    global_nets = trial.init_global_nets()
    # only test first network; multiagent gives a list of net dicts
    first_nets = global_nets[0] if ps.is_list(global_nets) else global_nets
    net = list(first_nets.values())[0]
    session_datas = trial.parallelize_sessions(global_nets)
    # key each session's data by its session index (first index label)
    trial.session_data_dict = dict((data.index[0], data) for data in session_datas)
    trial_data = analysis.analyze_trial(trial)
    trial.close()
    assert isinstance(trial_data, pd.DataFrame)