Esempio n. 1
0
 def test_fail_invalid_hparams_type(self):
     # A plain list passed as `hparams` is expected to be rejected with a
     # ValueError matching _INVALID_HPARAMS_ERR_MSG.
     config = run_config_lib.RunConfig(model_dir=_MODIR_DIR)
     bad_hparams = ["hparams"]
     expected_failure = self.assertRaisesRegexp(ValueError,
                                                _INVALID_HPARAMS_ERR_MSG)
     with expected_failure:
         learn_runner.run(
             build_experiment_for_run_config,
             run_config=config,
             schedule="local_run",
             hparams=bad_hparams,
         )
Esempio n. 2
0
 def test_fail_output_dir_and_run_config_are_both_set(self):
     # Supplying `output_dir` and `run_config` in the same call is expected
     # to raise a ValueError.
     expected_failure = self.assertRaisesRegexp(
         ValueError, _CANNOT_SET_BOTH_OUTPUT_DIR_AND_CONFIG_MSG)
     with expected_failure:
         # The RunConfig is constructed inside the context on purpose, so
         # the set of statements covered by the assertion is unchanged.
         conflicting_config = run_config_lib.RunConfig()
         learn_runner.run(
             build_experiment,
             output_dir=_MODIR_DIR,
             schedule="simple_task",
             run_config=conflicting_config,
         )
Esempio n. 3
0
 def test_fail_hparams_are_set(self):
     # Passing `hparams` together with the positional second argument
     # (presumably the output directory) is expected to raise a ValueError.
     expected_failure = self.assertRaisesRegexp(
         ValueError, _HPARAMS_CANNOT_BE_SET_FOR_OUTPUT_DIR_MSG)
     with expected_failure:
         learn_runner.run(
             build_experiment,
             _MODIR_DIR,
             schedule="simple_task",
             hparams=_HPARAMS,
         )
Esempio n. 4
0
    def test_fail_not_experiment(self):
        # An experiment function that returns a plain string is expected to
        # make learn_runner.run raise a TypeError.
        def _returns_a_string(run_config, hparams):
            # Both arguments are intentionally ignored.
            del run_config
            del hparams
            return "not experiment"

        config = run_config_lib.RunConfig(model_dir=_MODIR_DIR)
        run_options = {"run_config": config, "schedule": "simple_task"}
        with self.assertRaisesRegexp(TypeError, _NOT_EXP_TYPE_MSG):
            learn_runner.run(_returns_a_string, **run_options)
Esempio n. 5
0
    def test_basic_run_config_uid_check(self):
        # The experiment is built with a RunConfig other than the one handed
        # to learn_runner.run (its model_dir has "/123" appended); a
        # RuntimeError matching _RUN_CONFIG_UID_CHECK_ERR_MSG is expected.
        def _experiment_with_other_config(run_config, hparams):
            # Both arguments are intentionally ignored.
            del run_config
            del hparams
            # Explicitly build a fresh config instead of reusing the input.
            other_model_dir = _MODIR_DIR + "/123"
            other_config = run_config_lib.RunConfig(model_dir=other_model_dir)
            return TestExperiment(config=other_config)

        original_config = run_config_lib.RunConfig(model_dir=_MODIR_DIR)
        expected_failure = self.assertRaisesRegexp(
            RuntimeError, _RUN_CONFIG_UID_CHECK_ERR_MSG)
        with expected_failure:
            learn_runner.run(
                experiment_fn=_experiment_with_other_config,
                run_config=original_config,
            )
Esempio n. 6
0
    def test_fail_invalid_experiment_config_type(self):
        # The experiment is built with a `core_run_config_lib.RunConfig`
        # rather than the `run_config_lib.RunConfig` given to
        # learn_runner.run; a RuntimeError matching
        # _MISSING_RUN_CONFIG_UID_ERR_MSG is expected.
        def _experiment_with_core_config(run_config, hparams):
            # Both arguments are intentionally ignored.
            del run_config
            del hparams
            # Explicitly use a new run_config without `uid` method.
            other_model_dir = _MODIR_DIR + "/123"
            core_config = core_run_config_lib.RunConfig(
                model_dir=other_model_dir)
            return TestExperiment(config=core_config)

        original_config = run_config_lib.RunConfig(model_dir=_MODIR_DIR)
        expected_failure = self.assertRaisesRegexp(
            RuntimeError, _MISSING_RUN_CONFIG_UID_ERR_MSG)
        with expected_failure:
            learn_runner.run(
                experiment_fn=_experiment_with_core_config,
                run_config=original_config,
            )
Esempio n. 7
0
 def test_run_with_custom_schedule(self):
     # Running with schedule "simple_task" is expected to return the string
     # "simple_task, default=None.".
     config = run_config_lib.RunConfig(model_dir=_MODIR_DIR)
     outcome = learn_runner.run(
         build_experiment_for_run_config,
         run_config=config,
         schedule="simple_task",
     )
     self.assertEqual("simple_task, default=None.", outcome)
Esempio n. 8
0
 def test_run_with_explicit_local_run(self):
     # Running with schedule "local_run" is expected to return "local_run-"
     # followed by the model directory.
     config = run_config_lib.RunConfig(model_dir=_MODIR_DIR)
     run_options = {"run_config": config, "schedule": "local_run"}
     outcome = learn_runner.run(build_experiment_for_run_config,
                                **run_options)
     self.assertEqual("local_run-" + _MODIR_DIR, outcome)
Esempio n. 9
0
 def test_no_schedule_and_non_distributed_runs_train_and_evaluate(self):
     # With TF_CONFIG describing a non-distributed cluster and no explicit
     # schedule, the run is expected to return "train_and_evaluate-" plus
     # the output directory.
     serialized_tf_config = json.dumps(
         {"cluster": build_non_distributed_cluster_spec()})
     env_overrides = {"TF_CONFIG": serialized_tf_config}
     with patch.dict("os.environ", env_overrides):
         # The RunConfig is created while TF_CONFIG is patched in.
         experiment_fn = build_experiment_fn_for_output_dir(
             run_config_lib.RunConfig())
         outcome = learn_runner.run(experiment_fn, output_dir=_MODIR_DIR)
         self.assertEqual("train_and_evaluate-" + _MODIR_DIR, outcome)
Esempio n. 10
0
 def test_schedule_from_tf_config_runs_serve_on_ps(self):
     # With TF_CONFIG describing a distributed cluster and a PS task type,
     # the run is expected to return "run_std_server-" plus the output
     # directory.
     task_section = {"type": run_config_lib.TaskType.PS}
     serialized_tf_config = json.dumps({
         "cluster": build_distributed_cluster_spec(),
         "task": task_section,
     })
     with patch.dict("os.environ", {"TF_CONFIG": serialized_tf_config}):
         # The RunConfig is created while TF_CONFIG is patched in.
         experiment_fn = build_experiment_fn_for_output_dir(
             run_config_lib.RunConfig())
         outcome = learn_runner.run(experiment_fn, output_dir=_MODIR_DIR)
         self.assertEqual("run_std_server-" + _MODIR_DIR, outcome)
Esempio n. 11
0
 def test_schedule_from_tf_config_runs_train_on_worker(self):
     # With TF_CONFIG describing a distributed cluster and a WORKER task
     # type, the run is expected to return "train-" plus the output
     # directory.
     tf_config = {
         "cluster": build_distributed_cluster_spec(),
         "task": {
             "type": run_config_lib.TaskType.WORKER
         }
     }
     # Remember the previous value so the process environment is restored
     # afterwards; otherwise TF_CONFIG would leak into whichever tests run
     # after this one.
     previous_tf_config = os.environ.get("TF_CONFIG")
     os.environ["TF_CONFIG"] = json.dumps(tf_config)
     try:
         # RunConfig constructor will set job_name from TF_CONFIG.
         config = run_config_lib.RunConfig()
         self.assertEqual(
             "train-" + _MODIR_DIR,
             learn_runner.run(build_experiment_fn_for_output_dir(config),
                              output_dir=_MODIR_DIR))
     finally:
         if previous_tf_config is None:
             # pop() with a default tolerates the key already being gone.
             os.environ.pop("TF_CONFIG", None)
         else:
             os.environ["TF_CONFIG"] = previous_tf_config
Esempio n. 12
0
 def test_fail_empty_output_dir(self):
     # An empty-string `output_dir` is expected to raise a ValueError
     # matching _MUST_SPECIFY_OUTPUT_DIR_MSG.
     run_options = {"output_dir": "", "schedule": "simple_task"}
     expected_failure = self.assertRaisesRegexp(
         ValueError, _MUST_SPECIFY_OUTPUT_DIR_MSG)
     with expected_failure:
         learn_runner.run(build_experiment, **run_options)
Esempio n. 13
0
 def test_fail_no_output_dir(self):
     # A RunConfig created without a `model_dir` argument is expected to
     # make the run fail with a ValueError matching
     # _MISSING_MODEL_DIR_ERR_MSG.
     config_without_model_dir = run_config_lib.RunConfig()
     expected_failure = self.assertRaisesRegexp(ValueError,
                                                _MISSING_MODEL_DIR_ERR_MSG)
     with expected_failure:
         learn_runner.run(
             build_experiment_for_run_config,
             run_config=config_without_model_dir,
             schedule="local_run",
         )
Esempio n. 14
0
 def test_fail_non_callable_task(self):
     # Schedule "default" is expected to be rejected with a TypeError
     # matching _NON_CALLABLE_MSG.
     config = run_config_lib.RunConfig(model_dir=_MODIR_DIR)
     run_options = {"run_config": config, "schedule": "default"}
     with self.assertRaisesRegexp(TypeError, _NON_CALLABLE_MSG):
         learn_runner.run(build_experiment_for_run_config, **run_options)
Esempio n. 15
0
 def test_fail_no_output_dir(self):
     # Passing None as the second positional argument is expected to raise a
     # ValueError matching _MUST_SPECIFY_OUTPUT_DIR_OR_CONFIG_MSG.
     positional_args = (build_experiment, None, "simple_task")
     expected_failure = self.assertRaisesRegexp(
         ValueError, _MUST_SPECIFY_OUTPUT_DIR_OR_CONFIG_MSG)
     with expected_failure:
         learn_runner.run(*positional_args)
Esempio n. 16
0
 def test_fail_non_callable(self):
     # A string in place of the experiment function is expected to raise a
     # TypeError matching _EXP_NOT_CALLABLE_MSG.
     positional_args = ("not callable", _MODIR_DIR, "simple_test")
     expected_failure = self.assertRaisesRegexp(TypeError,
                                                _EXP_NOT_CALLABLE_MSG)
     with expected_failure:
         learn_runner.run(*positional_args)
Esempio n. 17
0
 def test_fail_non_existent_task(self):
     # Schedule "mirage" is expected to be rejected with a ValueError
     # matching _NON_EXIST_TASK_MSG.
     config = run_config_lib.RunConfig(model_dir=_MODIR_DIR)
     expected_failure = self.assertRaisesRegexp(ValueError,
                                                _NON_EXIST_TASK_MSG)
     with expected_failure:
         learn_runner.run(
             build_experiment_for_run_config,
             run_config=config,
             schedule="mirage",
         )
Esempio n. 18
0
 def test_fail_non_callable(self):
     # A string in place of the experiment function is expected to raise a
     # TypeError matching _EXP_NOT_CALLABLE_MSG.
     config = run_config_lib.RunConfig(model_dir=_MODIR_DIR)
     run_options = {"run_config": config, "schedule": "simple_task"}
     with self.assertRaisesRegexp(TypeError, _EXP_NOT_CALLABLE_MSG):
         learn_runner.run("not callable", **run_options)
Esempio n. 19
0
 def test_fail_not_experiment(self):
     # Running with `build_non_experiment` as the experiment function is
     # expected to raise a TypeError matching _NOT_EXP_TYPE_MSG.
     positional_args = (build_non_experiment, _MODIR_DIR, "simple_test")
     expected_failure = self.assertRaisesRegexp(TypeError, _NOT_EXP_TYPE_MSG)
     with expected_failure:
         learn_runner.run(*positional_args)
Esempio n. 20
0
 def test_run_with_explicit_local_run(self):
     # Running with schedule "local_run" and an explicit output directory is
     # expected to return "local_run-" followed by that directory.
     outcome = learn_runner.run(
         build_experiment,
         output_dir=_MODIR_DIR,
         schedule="local_run",
     )
     self.assertEqual("local_run-" + _MODIR_DIR, outcome)
Esempio n. 21
0
 def test_fail_non_existent_task(self):
     # "mirage" as the third positional argument is expected to raise a
     # ValueError matching _NON_EXIST_TASK_MSG.
     positional_args = (build_experiment, _MODIR_DIR, "mirage")
     expected_failure = self.assertRaisesRegexp(ValueError,
                                                _NON_EXIST_TASK_MSG)
     with expected_failure:
         learn_runner.run(*positional_args)
Esempio n. 22
0
 def test_fail_non_callable_task(self):
     # "default" as the third positional argument is expected to raise a
     # TypeError matching _NON_CALLABLE_MSG.
     positional_args = (build_experiment, _MODIR_DIR, "default")
     expected_failure = self.assertRaisesRegexp(TypeError, _NON_CALLABLE_MSG)
     with expected_failure:
         learn_runner.run(*positional_args)
Esempio n. 23
0
 def test_no_schedule_and_no_config_runs_train_and_evaluate(self):
     # With only an output directory given (no schedule argument), the run
     # is expected to return "train_and_evaluate-" plus that directory.
     outcome = learn_runner.run(build_experiment, output_dir=_MODIR_DIR)
     expected = "train_and_evaluate-" + _MODIR_DIR
     self.assertEqual(expected, outcome)
Esempio n. 24
0
 def test_fail_invalid_run_config_type(self):
     # A string passed as `run_config` is expected to raise a ValueError
     # matching _INVALID_RUN_CONFIG_TYPE_MSG.
     run_options = {
         "run_config": "invalid_run_config",
         "schedule": "local_run",
     }
     expected_failure = self.assertRaisesRegexp(
         ValueError, _INVALID_RUN_CONFIG_TYPE_MSG)
     with expected_failure:
         learn_runner.run(build_experiment_for_run_config, **run_options)
Esempio n. 25
0
def main(unused_argv):
    """Run `_make_experiment_fn` under the "train_and_evaluate" schedule.

    The output directory is read from `FLAGS.output_dir`.

    Args:
        unused_argv: Ignored.
    """
    run_options = {
        "experiment_fn": _make_experiment_fn,
        "output_dir": FLAGS.output_dir,
        "schedule": "train_and_evaluate",
    }
    learn_runner.run(**run_options)
Esempio n. 26
0
 def test_run_with_custom_schedule(self):
     # Running with schedule "simple_task" and an explicit output directory
     # is expected to return the string "simple_task, default=None.".
     run_options = {"output_dir": _MODIR_DIR, "schedule": "simple_task"}
     outcome = learn_runner.run(build_experiment, **run_options)
     self.assertEqual("simple_task, default=None.", outcome)