def test_analyze_gather_tb(tmpdir: str):
    """Generate test data via the parallel experiment, then gather TB dirs.

    Runs `parallel_ex` with the low-resource config to populate `tmpdir`,
    then runs the `gather_tb_directories` analysis command over it and
    checks that exactly four TensorBoard directories were found.
    """
    updates = dict(local_dir=tmpdir, run_name="test")
    updates.update(PARALLEL_CONFIG_LOW_RESOURCE)
    parallel_run = parallel_ex.run(
        named_configs=["generate_test_data"],
        config_updates=updates,
    )
    assert parallel_run.status == "COMPLETED"

    gather_run = analysis_ex.run(
        command_name="gather_tb_directories",
        config_updates=dict(source_dir=tmpdir),
    )
    assert gather_run.status == "COMPLETED"
    assert isinstance(gather_run.result, dict)
    # The generated test data is expected to yield exactly 4 TB directories.
    assert gather_run.result["n_tb_dirs"] == 4
def test_analyze_imitation(tmpdir: str, run_name: Optional[str], expected_entries: int):
    """Check `analyze_imitation` returns the expected number of rows.

    Runs the analysis command against the checked-in benchmark logs,
    filtered by `run_name`, and asserts the resulting table has
    `expected_entries` rows.
    """
    analysis_run = analysis_ex.run(
        command_name="analyze_imitation",
        config_updates=dict(
            source_dir="tests/data/imit_benchmark",
            run_name=run_name,
            csv_output_path=osp.join(tmpdir, "analysis.csv"),
            verbose=True,
        ),
    )
    assert analysis_run.status == "COMPLETED"
    # Wrap the result so we can assert on row count directly.
    frame = pd.DataFrame(analysis_run.result)
    assert frame.shape[0] == expected_entries
def check(run_name: Optional[str], count: int) -> None:
    """Run `analyze_imitation` filtered by `run_name`; assert `count` rows.

    NOTE(review): `sacred_logs_dir` and `tmpdir` are free variables taken
    from an enclosing scope (this appears to be a nested helper) — confirm
    against the surrounding test function.
    """
    analysis_run = analysis_ex.run(
        command_name="analyze_imitation",
        config_updates=dict(
            source_dir=sacred_logs_dir,
            run_name=run_name,
            csv_output_path=osp.join(tmpdir, "analysis.csv"),
            verbose=True,
        ),
    )
    assert analysis_run.status == "COMPLETED"
    frame = pd.DataFrame(analysis_run.result)
    assert frame.shape[0] == count
def test_analyze_gather_tb(tmpdir: str):
    """Generate test data, then verify `gather_tb_directories` finds 4 TB dirs.

    NOTE(review): this shadows an identically-named `test_analyze_gather_tb`
    defined earlier in the file — pytest will only collect the later
    definition; one of the two should probably be renamed or removed.
    """
    data_run = parallel_ex.run(
        named_configs=["generate_test_data"],
        config_updates=dict(
            local_dir=tmpdir,
            run_name="test",
        ),
    )
    assert data_run.status == "COMPLETED"

    gather_run = analysis_ex.run(
        command_name="gather_tb_directories",
        config_updates=dict(source_dir=tmpdir),
    )
    assert gather_run.status == "COMPLETED"
    assert isinstance(gather_run.result, dict)
    # Expect exactly 4 TensorBoard directories from the generated data.
    assert gather_run.result["n_tb_dirs"] == 4