예제 #1
0
    def test_run_all(self):
        """Run the two-configuration test sweep end to end and verify its outputs.

        Asserts that the trainer and reporter were each called twice and two
        configs were captured. For each configuration (``config1``,
        ``config2``) it then checks three things:

        - the section parameters and the extra ``ENV_PARAM`` keyword were
          substituted into the trainer config;
        - ``<reports>/<name>/experiment.cfg`` records those parameters in a
          single section named after the configuration;
        - ``<reports>/<name>/experiment.json`` matches the source experiment
          JSON.
        """
        pkg_dir = Path(__file__).parent
        experiment_json = pkg_dir / 'test_experiment.json'
        report_dir = Path(self.test_dir) / 'reports'

        ExperimentRunner.run_all(experiment=experiment_json,
                                 experiment_config=pkg_dir /
                                 'test_experiment.cfg',
                                 report_dir=report_dir,
                                 trainer_config_name='the_trainer',
                                 reporter_config_name='the_reporter',
                                 ENV_PARAM='my_env_param')

        self.assertEqual(2, ExperimentRunnerTest._reporter_calls)
        self.assertEqual(2, ExperimentRunnerTest._trainer_calls)

        self.assertEqual(2, len(ExperimentRunnerTest._configs))

        # The source definition is the same for every configuration, so parse
        # it once rather than on each loop iteration.
        expected_json = ExperimentConfig.load_experiment_json(experiment_json)

        for name, bparam, iparam, fparam, sparam in [
            ('config1', True, 1, 1.5, 'hello'),
            ('config2', False, 2, 2.5, 'world')
        ]:
            exp_report_dir = report_dir / name

            # assert params were substituted into the experiment properly
            cfg = ExperimentRunnerTest._configs[name]['the_trainer']
            self.assertEqual(bparam, cfg.bool_param)
            self.assertEqual(iparam, cfg.int_param)
            self.assertEqual(fparam, cfg.float_param)
            self.assertEqual(sparam, cfg.str_param)
            self.assertEqual('my_env_param', cfg.env_param)

            # assert params were recorded in the reports directory; read()
            # silently skips missing files, so the section-count check below
            # also guards against the file not having been written at all
            cp = configparser.ConfigParser()
            cp.read(exp_report_dir / 'experiment.cfg')
            self.assertEqual(1, len(cp.sections()))
            self.assertEqual(bparam, cp.getboolean(name, 'bparam'))
            self.assertEqual(iparam, cp.getint(name, 'iparam'))
            self.assertEqual(fparam, cp.getfloat(name, 'fparam'))
            self.assertEqual(sparam, cp.get(name, 'sparam'))
            self.assertEqual('my_env_param', cp.get(name, 'ENV_PARAM'))

            # assert the experiment definition was copied into the report
            self.assertEqual(
                expected_json,
                ExperimentConfig.load_experiment_json(
                    exp_report_dir / 'experiment.json'))
예제 #2
0
    def run_all(experiment: Union[str, Path, Dict],
                experiment_config: Union[str, Path],
                report_dir: Union[str, Path],
                trainer_config_name: str = 'trainer',
                reporter_config_name: str = 'reporter',
                **env_vars) -> None:
        """
        Train and report on every experiment configuration in a cfg file.

        Each experiment is processed in turn:

        1. Its variables are merged over ``env_vars``.
        2. The merged variables are applied to ``experiment`` to build an
           ``ExperimentConfig``.
        3. The trainer is trained.
        4. The experiment definition and merged variables are written to
           ``<report_dir>/<experiment name>``.
        5. The reporter is invoked.

        Logs are captured into that per-experiment directory while the
        experiment runs. An exception raised by any experiment propagates
        immediately and the remaining experiments are not run; log capture
        for the failing experiment is still stopped.

        :param experiment: the experiment definition, as a path or a dict. It is passed both to
               ``ExperimentConfig`` and to ``ExperimentConfig.load_experiment_json``.
        :param experiment_config: the experiment config file. The cfg file should be defined in `ConfigParser
               <https://docs.python.org/3/library/configparser.html#module-configparser>`_ format such that
               each section is an experiment configuration.
        :param report_dir: the directory in which to produce the reports. It must not already exist.
               It's recommended to include a timestamp in your report directory so you can preserve
               previous reports across code changes. E.g. $HOME/reports/run_2019_02_22.
        :param trainer_config_name: the name of the trainer configuration object. The referenced object should implement `TrainerABC`.
        :param reporter_config_name: the name of the reporter configuration object. The referenced object should implement `ReporterABC`.
        :param env_vars: any additional environment variables, like file system paths. Values defined in an
               experiment's cfg section take precedence over these.
        :raises FileExistsError: if ``report_dir`` or a per-experiment report directory already exists.
        :return: None
        """

        envs: Dict[str, ConfigEnv] = load_config(Path(experiment_config))

        report_path = Path(report_dir)
        # No exist_ok: refuse to write into (and clobber) a previous run's reports.
        report_path.mkdir(parents=True)

        for exp_name, env in envs.items():
            exp_report_path = report_path / exp_name
            exp_report_path.mkdir()
            log_handler = ExperimentRunner._capture_logs(exp_report_path)
            try:
                logging.info('running %s', exp_name)
                # Section values override caller-supplied env vars.
                all_vars = dict(env_vars)
                all_vars.update(env)
                # Use a distinct local instead of rebinding the
                # `experiment_config` parameter (a cfg path) to a different type.
                exp_config = ExperimentConfig(experiment, **all_vars)
                trainer: TrainerABC = exp_config[trainer_config_name]
                reporter: ReporterABC = exp_config[reporter_config_name]
                trainer.train()
                exp_json = ExperimentConfig.load_experiment_json(experiment)
                ExperimentRunner._write_config(exp_name, exp_json, all_vars,
                                               exp_report_path)
                reporter.report(exp_name, exp_config, exp_report_path)
            finally:
                ExperimentRunner._stop_log_capture(log_handler)