def test_run_all(self):
        """Exercise ``ExperimentRunner.run_all`` against the fixture files next to this test.

        Checks, in order: the captured log lines, the returned experiment
        cache, the trainer/reporter call counters, the per-experiment
        configuration objects, and the files written to the reports directory.
        """
        fixtures = Path(__file__).parent
        reports_root = self.test_dir + '/reports'

        experiment_cache = ExperimentRunner.run_all(
            experiment=fixtures / 'test_experiment.yml',
            experiment_config=fixtures / 'test_experiment.toml',
            report_dir=reports_root,
            trainer_config_name='the_trainer',
            reporter_config_name='the_reporter',
            ENV_PARAM='my_env_param',
            experiment_cache=fixtures / 'test_read_only.json')
        log_contents = log_capture_string.getvalue()
        log_capture_string.close()

        # The first two captured lines belong to the two experiments, the
        # third one to the global reporting step.
        log_lines = log_contents.split("\n")
        exp1_raw = log_lines[0]
        exp2_raw = log_lines[1]
        global_reporting_logs = log_lines[2]

        # Turn '<', '>' and single quotes into double quotes so that each
        # logged line can be parsed as JSON.
        to_double_quotes = str.maketrans({"<": "\"", ">": "\"", "\'": "\""})
        exp1_logs = json.loads(exp1_raw.translate(to_double_quotes))
        exp2_logs = json.loads(exp2_raw.translate(to_double_quotes))

        # Check that the reference values we've put in the config file have been
        # replaced in the experiment file

        # The first experiment logged exactly two objects
        exp1_objects = exp1_logs['lobjects']
        self.assertEqual(len(exp1_objects), 2)
        self.assertIn('MockTrainer', exp1_logs['lobjects'][0])
        self.assertIn('MockReporter', exp1_logs['lobjects'][1])

        # The second experiment logged exactly one object
        exp2_objects = exp2_logs['lobjects']
        self.assertEqual(len(exp2_objects), 1)
        self.assertIn('MockTrainer', exp2_logs['lobjects'][0])

        # Check the global reporting
        self.assertEqual(global_reporting_logs, "global reporting message")

        # Check the object exposed by the returned cache
        self.assertIsInstance(experiment_cache['another_trainer'], MockTrainer)
        self.assertEqual(experiment_cache['another_trainer'].int_param, 1)
        self.assertEqual(experiment_cache['another_trainer'].bool_param, True)

        self.assertEqual(2, ExperimentRunnerTest._reporter_calls)
        self.assertEqual(2, ExperimentRunnerTest._trainer_calls)

        self.assertEqual(2, len(ExperimentRunnerTest._configs))

        expected_rows = (
            ('config1', True, 1, 1.5, 'hello'),
            ('config2', False, 2, 2.5, 'world'),
        )
        for name, bparam, iparam, fparam, sparam in expected_rows:
            # The parameters must have been substituted into the experiment
            cfg = ExperimentRunnerTest._configs[name]['the_trainer']
            self.assertEqual(bparam, cfg.bool_param)
            self.assertEqual(iparam, cfg.int_param)
            self.assertEqual(fparam, cfg.float_param)
            self.assertEqual(sparam, cfg.str_param)
            self.assertEqual('my_env_param', cfg.env_param)

            # The parameters must have been recorded in the reports directory
            config = toml.load(
                f'{self.test_dir}/reports/{name}/experiment_config.toml')
            self.assertEqual(1, len(config))
            self.assertEqual(bparam, config[name]['bparam'])
            self.assertEqual(iparam, config[name]['iparam'])
            self.assertEqual(fparam, config[name]['fparam'])
            self.assertEqual(sparam, config[name]['sparam'])
            self.assertEqual('my_env_param', config[name]['ENV_PARAM'])

            # The experiment file saved with the global reports must load to
            # the same content as the fixture it was copied from
            source_experiment = ExperimentConfig.load_experiment_config(
                fixtures / 'test_experiment.yml')
            saved_experiment = ExperimentConfig.load_experiment_config(
                f'{self.test_dir}/reports/global-reporting/test_experiment.yml'
            )
            self.assertEqual(source_experiment, saved_experiment)
Beispiel #2
0
    def run_all(experiment: Union[str, Path],
                experiment_config: Union[str, Path],
                report_dir: Union[str, Path],
                trainer_config_name: str = 'trainer',
                reporter_config_name: str = 'reporter',
                experiment_cache: Union[str, Path, Dict] = None,
                **env_vars) -> ExperimentConfig:
        """
        Run the experiment once per configuration section, write a report for each run, then run the
        global reporting step over all collected reports.

        :param experiment: the experiment config
        :param experiment_config: the experiment config file. The cfg file should be defined in `ConfigParser
               <https://docs.python.org/3/library/configparser.html#module-configparser>`_ format such that
               each section is an experiment configuration.
        :param report_dir: the directory in which to produce the reports. It's recommended to include a timestamp in your report directory so you
               can preserve previous reports across code changes. E.g. $HOME/reports/run_2019_02_22.
        :param trainer_config_name: the name of the trainer configuration object. The referenced object should implement `TrainerABC`.
        :param reporter_config_name: the name of the reporter configuration object. The referenced object should implement `ReporterABC`.
        :param experiment_cache: the experiment config with cached objects. When it is a file path, the file is also
               copied into the global reporting directory.
        :param env_vars: any additional environment variables, like file system paths
        :return: the experiment cache built from `experiment_cache`, or an empty dict when no cache was given
        :raises FileExistsError: if `report_dir` (or a per-experiment sub-directory) already exists
        """

        envs: Dict[str, ConfigEnv] = load_config(Path(experiment_config))

        report_path = Path(report_dir)
        report_path.mkdir(parents=True)

        # Before starting, save the 3 global files: experiment, configs and cache
        global_report_dir = report_path / 'global-reporting'
        global_report_dir.mkdir(parents=True)
        shutil.copy(src=str(experiment),
                    dst=str(global_report_dir / Path(experiment).name))
        # The cache is optional and may be an in-memory dict; only a file path
        # can be copied, and the source must be the cache file itself.
        if experiment_cache and isinstance(experiment_cache, (str, Path)):
            shutil.copy(src=str(experiment_cache),
                        dst=str(global_report_dir /
                                Path(experiment_cache).name))
        shutil.copy(src=str(experiment_config),
                    dst=str(global_report_dir / Path(experiment_config).name))

        experiment_config_cache = {}
        if experiment_cache:
            logging.info(
                "#" * 5 +
                "Building a set of read-only objects and cache them for use in different experiment settings"
                + "#" * 5)
            experiment_config_cache = ExperimentConfig(experiment_cache,
                                                       **env_vars)
            logging.info(
                "#" * 5 +
                "Read-only objects are built and cached for use in different experiment settings"
                + "#" * 5)

        aggregate_reports = {}
        # Built configuration of the most recent run; kept in its own variable
        # so the `experiment_config` path argument is never overwritten.
        last_built_config = None
        for exp_name, env in envs.items():
            exp_report_path = report_path / exp_name
            exp_report_path.mkdir()
            log_handler = ExperimentRunner._capture_logs(exp_report_path)
            try:
                logging.info('running %s', exp_name)
                # Section-specific values take precedence over the shared ones
                all_vars = dict(env_vars)
                all_vars.update(env)

                exp = deepcopy(experiment)
                if experiment_cache:
                    exp = ExperimentConfig.load_experiment_config(exp)
                    exp.update(experiment_config_cache)

                built_config = ExperimentConfig(exp, **all_vars)
                last_built_config = built_config
                trainer: TrainerABC = built_config[trainer_config_name]
                reporter: ReporterABC = built_config[reporter_config_name]
                trainer.train()

                # Save the config for this particular experiment
                exp_config = {exp_name: all_vars}
                with (exp_report_path /
                      'experiment_config.toml').open('w') as expfile:
                    toml.dump(exp_config, expfile)

                # Get this particular config reporting and store it in the
                # aggregated reportings
                report = reporter.report(exp_name, built_config,
                                         exp_report_path)
                aggregate_reports[exp_name] = report
            finally:
                ExperimentRunner._stop_log_capture(log_handler)

        if last_built_config is None:
            # No experiment section was found, so there is no reporter to
            # delegate the global reporting to.
            logging.warning(
                'no experiment configuration found in %s; skipping global reporting',
                experiment_config)
            return experiment_config_cache

        reporter_class = last_built_config[reporter_config_name].__class__
        if issubclass(reporter_class, ReporterABC):
            reporter_class.report_globally(aggregate_reports=aggregate_reports,
                                           report_dir=global_report_dir)

        return experiment_config_cache