def test_run_all(self):
    """End-to-end check of ``ExperimentRunner.run_all``.

    Runs the configurations from ``test_experiment.toml`` against
    ``test_experiment.yml`` with a read-only cache from
    ``test_read_only.json``, then verifies the captured log lines, the
    returned cache, the trainer/reporter call counts, the per-experiment
    ``experiment_config.toml`` reports and the global copy of the
    experiment file.
    """
    pkg_dir = Path(__file__).parent
    experiment_cache = ExperimentRunner.run_all(
        experiment=pkg_dir / 'test_experiment.yml',
        experiment_config=pkg_dir / 'test_experiment.toml',
        report_dir=self.test_dir + '/reports',
        trainer_config_name='the_trainer',
        reporter_config_name='the_reporter',
        ENV_PARAM='my_env_param',
        experiment_cache=pkg_dir / 'test_read_only.json')

    def _parse_log_line(line):
        # The mock objects log repr-like text; normalise '<', '>' and single
        # quotes to double quotes so the line can be parsed as JSON.
        normalised = line.replace("<", "\"").replace(">", "\"").replace("\'", "\"")
        return json.loads(normalised)

    # NOTE(review): `log_capture_string` is defined outside this method
    # (presumably a StringIO attached to a logging handler) — confirm.
    log_contents = log_capture_string.getvalue()
    log_capture_string.close()

    # Split once: line 0 -> experiment 1, line 1 -> experiment 2,
    # line 2 -> the global reporting message.
    log_lines = log_contents.split("\n")
    exp1_logs = _parse_log_line(log_lines[0])
    exp2_logs = _parse_log_line(log_lines[1])
    global_reporting_logs = log_lines[2]

    # Check that the reference values we've put in the config file have been
    # replaced in the experiment file.
    # exp1_logs has exactly 2 objects in lobjects
    self.assertEqual(len(exp1_logs['lobjects']), 2)
    self.assertIn('MockTrainer', exp1_logs['lobjects'][0])
    self.assertIn('MockReporter', exp1_logs['lobjects'][1])
    # exp2_logs has exactly one object in lobjects
    self.assertEqual(len(exp2_logs['lobjects']), 1)
    self.assertIn('MockTrainer', exp2_logs['lobjects'][0])

    # Check the global reporting
    self.assertEqual(global_reporting_logs, "global reporting message")

    # The returned cache holds the read-only objects built from the cache file.
    self.assertIsInstance(experiment_cache['another_trainer'], MockTrainer)
    self.assertEqual(experiment_cache['another_trainer'].int_param, 1)
    self.assertEqual(experiment_cache['another_trainer'].bool_param, True)

    self.assertEqual(2, ExperimentRunnerTest._reporter_calls)
    self.assertEqual(2, ExperimentRunnerTest._trainer_calls)
    self.assertEqual(2, len(ExperimentRunnerTest._configs))

    expected = [
        ('config1', True, 1, 1.5, 'hello'),
        ('config2', False, 2, 2.5, 'world'),
    ]
    for name, bparam, iparam, fparam, sparam in expected:
        # assert params were substituted into the experiment properly
        cfg = ExperimentRunnerTest._configs[name]['the_trainer']
        self.assertEqual(bparam, cfg.bool_param)
        self.assertEqual(iparam, cfg.int_param)
        self.assertEqual(fparam, cfg.float_param)
        self.assertEqual(sparam, cfg.str_param)
        self.assertEqual('my_env_param', cfg.env_param)

        # assert params were recorded in the reports directory
        config = toml.load(
            f'{self.test_dir}/reports/{name}/experiment_config.toml')
        self.assertEqual(1, len(config))
        self.assertEqual(bparam, config[name]['bparam'])
        self.assertEqual(iparam, config[name]['iparam'])
        self.assertEqual(fparam, config[name]['fparam'])
        self.assertEqual(sparam, config[name]['sparam'])
        self.assertEqual('my_env_param', config[name]['ENV_PARAM'])

    # The experiment file copied into the global reporting directory must be
    # identical to the source experiment file.
    self.assertEqual(
        ExperimentConfig.load_experiment_config(pkg_dir / 'test_experiment.yml'),
        ExperimentConfig.load_experiment_config(
            f'{self.test_dir}/reports/global-reporting/test_experiment.yml'
        ))
def run_all(experiment: Union[str, Path],
            experiment_config: Union[str, Path],
            report_dir: Union[str, Path],
            trainer_config_name: str = 'trainer',
            reporter_config_name: str = 'reporter',
            experiment_cache: Union[str, Path, Dict] = None,
            **env_vars) -> Union[ExperimentConfig, Dict]:
    """Train and report every experiment configuration, then report globally.

    For each section of ``experiment_config``, the section's variables
    (layered over ``env_vars``) are substituted into ``experiment``. The
    trainer is then trained, the variables are saved as
    ``<report_dir>/<section>/experiment_config.toml``, and the reporter
    produces a per-experiment report. Finally, ``report_globally`` is called
    on the reporter class with all per-experiment reports.

    :param experiment: the experiment config
    :param experiment_config: the experiment config file, parsed with
        ``load_config``; each section is one experiment configuration.
        NOTE(review): an earlier docstring described this as ConfigParser
        format while callers pass a ``.toml`` file — confirm the format.
    :param report_dir: the directory in which to produce the reports. It's
        recommended to include a timestamp in your report directory so you
        can preserve previous reports across code changes.
        E.g. $HOME/reports/run_2019_02_22.
    :param trainer_config_name: the name of the trainer configuration object.
        The referenced object should implement `TrainerABC`.
    :param reporter_config_name: the name of the reporter configuration
        object. The referenced object should implement `ReporterABC`.
    :param experiment_cache: the experiment config with cached (read-only)
        objects shared by every experiment; a file path or an in-memory
        mapping. When omitted, no cache is built.
    :param env_vars: any additional environment variables, like file system
        paths
    :return: the experiment cache built from ``experiment_cache``, or an
        empty dict when no cache was given
    :raises FileExistsError: if ``report_dir`` (or one of the
        per-experiment directories inside it) already exists
    """
    envs: Dict[str, ConfigEnv] = load_config(Path(experiment_config))

    report_path = Path(report_dir)
    # No exist_ok: refuse to overwrite a previous run's reports.
    report_path.mkdir(parents=True)

    # Before starting, save the global input files (experiment, configs and,
    # when it is a file, the cache) so the run can be reproduced.
    global_report_dir = report_path / 'global-reporting'
    global_report_dir.mkdir(parents=True)
    shutil.copy(src=str(experiment),
                dst=str(global_report_dir / Path(experiment).name))
    # The cache may be omitted (None) or given as an in-memory mapping; only a
    # file path can be copied. Copy the cache itself, not the experiment.
    if experiment_cache and isinstance(experiment_cache, (str, Path)):
        shutil.copy(src=str(experiment_cache),
                    dst=str(global_report_dir / Path(experiment_cache).name))
    shutil.copy(src=str(experiment_config),
                dst=str(global_report_dir / Path(experiment_config).name))

    experiment_config_cache = {}
    if experiment_cache:
        logging.info(
            "#" * 5
            + "Building a set of read-only objects and cache them for use in different experiment settings"
            + "#" * 5)
        experiment_config_cache = ExperimentConfig(experiment_cache, **env_vars)
        logging.info(
            "#" * 5
            + "Read-only objects are built and cached for use in different experiment settings"
            + "#" * 5)

    aggregate_reports = {}
    # Class of the last experiment's reporter; used for global reporting.
    # Stays None when the config file defines no experiments.
    reporter_class = None
    for exp_name, env in envs.items():
        exp_report_path = report_path / exp_name
        exp_report_path.mkdir()
        log_handler = ExperimentRunner._capture_logs(exp_report_path)
        try:
            logging.info('running %s', exp_name)
            # Per-experiment variables override the shared env_vars.
            all_vars = dict(env_vars)
            all_vars.update(env)

            exp = deepcopy(experiment)
            if experiment_cache:
                # Merge the pre-built read-only objects into this experiment.
                exp = ExperimentConfig.load_experiment_config(exp)
                exp.update(experiment_config_cache)
            # Separate name so the `experiment_config` parameter isn't shadowed.
            exp_objects = ExperimentConfig(exp, **all_vars)
            trainer: TrainerABC = exp_objects[trainer_config_name]
            reporter: ReporterABC = exp_objects[reporter_config_name]
            reporter_class = reporter.__class__

            trainer.train()

            # Save the config for this particular experiment
            exp_config = {exp_name: all_vars}
            config_file = exp_report_path / 'experiment_config.toml'
            with config_file.open('w', encoding='utf-8') as expfile:
                toml.dump(exp_config, expfile)

            # Get this particular config reporting and store it in the
            # aggregated reportings
            report = reporter.report(exp_name, exp_objects, exp_report_path)
            aggregate_reports[exp_name] = report
        finally:
            ExperimentRunner._stop_log_capture(log_handler)

    if reporter_class is not None and issubclass(reporter_class, ReporterABC):
        reporter_class.report_globally(aggregate_reports=aggregate_reports,
                                       report_dir=global_report_dir)
    return experiment_config_cache