Пример #1
0
def run_trial_test_dist(spec_file, spec_name=False):
    """Smoke-test a trial run in 'synced' distributed mode with two sessions.

    The global nets are built explicitly via ``trial.init_global_nets()`` and
    handed to ``trial.parallelize_sessions``, mirroring the distributed code
    path. The returned session metrics are then analyzed at the trial level.

    Args:
        spec_file: spec file passed to ``spec_util.get``.
        spec_name: spec name within ``spec_file``, passed to ``spec_util.get``.

    Raises:
        AssertionError: if ``analysis.analyze_trial`` does not return a dict.
    """
    spec = spec_util.get(spec_file, spec_name)
    spec = spec_util.override_spec(spec, 'test')
    spec_util.tick(spec, 'trial')
    spec['meta']['distributed'] = 'synced'
    spec['meta']['max_session'] = 2

    trial = Trial(spec)
    try:
        # Build the global nets here (instead of letting the trial do it) so the
        # exact objects shared with the sessions are created by the test.
        global_nets = trial.init_global_nets()
        session_metrics_list = trial.parallelize_sessions(global_nets)
        trial_metrics = analysis.analyze_trial(spec, session_metrics_list)
    finally:
        # Always release trial resources, even if a session or the analysis fails.
        trial.close()
    assert isinstance(trial_metrics, dict)
Пример #2
0
def run_trial_test_dist(spec_file, spec_name=False):
    """Smoke-test a trial run in distributed mode with two sessions.

    The global nets are built explicitly via ``trial.init_global_nets()`` and
    handed to ``trial.parallelize_sessions``. The per-session data it returns
    is stored on ``trial.session_data_dict`` before running the trial-level
    analysis.

    Args:
        spec_file: spec file passed to ``spec_util.get``.
        spec_name: spec name within ``spec_file``, passed to ``spec_util.get``.

    Raises:
        AssertionError: if ``analysis.analyze_trial`` does not return a
            ``pd.DataFrame``.
    """
    spec = spec_util.get(spec_file, spec_name)
    spec = spec_util.override_test_spec(spec)
    info_space = InfoSpace()
    info_space.tick('trial')
    spec['meta']['distributed'] = True
    spec['meta']['max_session'] = 2

    trial = Trial(spec, info_space)
    try:
        # Build the global nets here (instead of letting the trial do it) so the
        # exact objects shared with the sessions are created by the test.
        global_nets = trial.init_global_nets()
        session_datas = trial.parallelize_sessions(global_nets)
        # Key each session's data by the first value of its index; presumably
        # that is the session index -- TODO confirm against parallelize_sessions.
        trial.session_data_dict = {data.index[0]: data for data in session_datas}
        trial_data = analysis.analyze_trial(trial)
    finally:
        # Always release trial resources, even if a session or the analysis fails.
        trial.close()
    assert isinstance(trial_data, pd.DataFrame)