Beispiel #1
0
    def test_that_export_fn_is_called(self):
        """run_evaluator() invokes the export strategy's export_fn on the estimator."""
        mock_est = test.mock.Mock(spec=estimator_lib.Estimator)
        mock_train_spec = test.mock.Mock(spec=training.TrainSpec)
        self._set_up_mock_est_to_train_and_evaluate_once(
            mock_est, mock_train_spec)
        # Initialize the flag up front. Reading an attribute that is not part of
        # the Estimator spec raises AttributeError, so without this a missing
        # export call would surface as a test error rather than a clean failure.
        mock_est.export_fn_was_called = False

        def export_fn(estimator, *args, **kwargs):
            # Only the estimator matters here; record that we ran on it.
            del args, kwargs
            estimator.export_fn_was_called = True

        export_strategy = export_strategy_lib.ExportStrategy(
            name='see_whether_export_fn_is_called', export_fn=export_fn)

        eval_spec = training.EvalSpec(input_fn=lambda: 1,
                                      steps=2,
                                      delay_secs=0,
                                      throttle_secs=0,
                                      export_strategies=export_strategy)

        executor = training._TrainingExecutor(mock_est, mock_train_spec,
                                              eval_spec)
        executor.run_evaluator()

        # Verify that export_fn was called on the right estimator.
        self.assertTrue(mock_est.export_fn_was_called)
Beispiel #2
0
    def test_that_export_fn_is_called_with_run_local(self):
        """run_local() invokes the export strategy's export_fn on the estimator."""
        mock_est = test.mock.Mock(spec=estimator_lib.Estimator)
        mock_train_spec = test.mock.Mock(spec=training.TrainSpec)
        mock_train_spec.max_steps = 200
        # Every evaluation reports that max_steps has already been reached.
        mock_est.evaluate.return_value = {
            _GLOBAL_STEP_KEY: mock_train_spec.max_steps
        }
        # _validate_hooks would have made sure that train_spec.hooks is [] when
        # None was passed.
        mock_train_spec.hooks = []
        # Initialize the flag up front. Reading an attribute that is not part of
        # the Estimator spec raises AttributeError, so without this a missing
        # export call would surface as a test error rather than a clean failure.
        mock_est.export_fn_was_called = False

        def export_fn(estimator, *args, **kwargs):
            # Only the estimator matters here; record that we ran on it.
            del args, kwargs
            estimator.export_fn_was_called = True

        export_strategy = export_strategy_lib.ExportStrategy(
            name='see_whether_export_fn_is_called', export_fn=export_fn)

        eval_spec = training.EvalSpec(input_fn=lambda: 1,
                                      steps=2,
                                      delay_secs=0,
                                      throttle_secs=213,
                                      export_strategies=export_strategy)

        executor = training._TrainingExecutor(mock_est, mock_train_spec,
                                              eval_spec)
        executor.run_local()

        self.assertTrue(mock_est.export_fn_was_called)
    def testCallsExportFnThatAcceptsEvalResultButNotCheckpoint(self):
        """An export_fn taking eval_result but not checkpoint_path is rejected.

        export() must raise ValueError before calling export_fn, for every
        combination of the optional checkpoint_path / eval_result arguments.
        """
        expected_estimator = {}

        def export_fn(estimator, export_path, eval_result):
            del estimator, export_path, eval_result
            raise RuntimeError("Should raise ValueError before this.")

        export_strategy = export_strategy_lib.ExportStrategy(
            name="test", export_fn=export_fn)

        expected_error_message = (
            "An export_fn accepting eval_result must also accept checkpoint_path"
        )

        # Each entry is the set of optional keyword arguments given to export().
        optional_kwargs_cases = (
            {},
            {"checkpoint_path": "unexpected_checkpoint_path"},
            {"eval_result": ()},
            {"checkpoint_path": "unexpected_checkpoint_path", "eval_result": ()},
        )
        for optional_kwargs in optional_kwargs_cases:
            # assertRaisesRegex replaces the deprecated assertRaisesRegexp alias,
            # which was removed from unittest in Python 3.12.
            with self.assertRaisesRegex(ValueError, expected_error_message):
                export_strategy.export(estimator=expected_estimator,
                                       export_path="expected_path",
                                       **optional_kwargs)
    def testCallsExportFnThatDoesntKnowExtraArguments(self):
        """export() drops optional arguments that export_fn does not declare.

        export_fn only accepts (estimator, export_path). checkpoint_path and
        eval_result given to export() must not be forwarded to it, because
        they would not match its signature.
        """
        expected_estimator = {}
        # Records every invocation of export_fn (see the final assertion).
        calls = []

        def export_fn(estimator, export_path):
            calls.append(export_path)
            self.assertEqual(expected_estimator, estimator)
            self.assertEqual("expected_path", export_path)

        export_strategy = export_strategy_lib.ExportStrategy(
            name="test", export_fn=export_fn)

        export_strategy.export(estimator=expected_estimator,
                               export_path="expected_path")

        # Also works with additional arguments that `export_fn` doesn't support.
        # The lack of support is detected and the arguments aren't passed.
        export_strategy.export(estimator=expected_estimator,
                               export_path="expected_path",
                               checkpoint_path="unexpected_checkpoint_path")
        export_strategy.export(estimator=expected_estimator,
                               export_path="expected_path",
                               eval_result=())
        export_strategy.export(estimator=expected_estimator,
                               export_path="expected_path",
                               checkpoint_path="unexpected_checkpoint_path",
                               eval_result=())

        # The checks above live inside export_fn, so also verify that it ran
        # once per export() call; otherwise this test could pass vacuously.
        self.assertEqual(4, len(calls))
Beispiel #5
0
  def test_runs_in_a_loop_until_max_steps(self):
    """run_local() repeats train/evaluate/export rounds until max_steps."""
    mock_est = test.mock.Mock(spec=estimator_lib.Estimator, model_dir='path/')
    mock_est.latest_checkpoint = self.unique_checkpoint_every_time_fn

    # Counter bumped by export_fn; it must exist before export_fn reads it.
    mock_est.times_export_fn_was_called = 0
    def export_fn(estimator, *args, **kwargs):
      del args, kwargs  # Only the number of invocations matters.
      estimator.times_export_fn_was_called += 1

    export_strategy = export_strategy_lib.ExportStrategy(
        name='see_whether_export_fn_is_called', export_fn=export_fn)

    train_spec = training.TrainSpec(
        input_fn=lambda: 1, max_steps=300, hooks=[_FakeHook()])
    eval_spec = training.EvalSpec(
        input_fn=lambda: 1,
        hooks=[_FakeHook()],
        throttle_secs=100,
        export_strategies=export_strategy)

    # Successive evaluations report global steps 200, 250 and finally 300
    # (== max_steps), so three train/evaluate/export rounds are expected.
    final_step = train_spec.max_steps
    mock_est.evaluate.side_effect = [
        {_GLOBAL_STEP_KEY: step}
        for step in (final_step - 100, final_step - 50, final_step)
    ]

    training._TrainingExecutor(mock_est, train_spec, eval_spec).run_local()

    expected_rounds = 3
    self.assertEqual(expected_rounds, mock_est.train.call_count)
    self.assertEqual(expected_rounds, mock_est.evaluate.call_count)
    self.assertEqual(expected_rounds, mock_est.times_export_fn_was_called)
Beispiel #6
0
  def test_evaluate_multiple_times(self):
    """run_evaluator() keeps evaluating and exporting until max_steps."""
    training_max_step = 200

    mock_est = test.mock.Mock(spec=estimator_lib.Estimator)
    mock_est.model_dir = compat.as_bytes(test.get_temp_dir())
    # The first evaluation lands halfway; the second reaches max_steps.
    mock_est.evaluate.side_effect = [
        {_GLOBAL_STEP_KEY: step}
        for step in (training_max_step // 2, training_max_step)
    ]
    # Distinct checkpoint paths, presumably one per evaluation round.
    mock_est.latest_checkpoint.side_effect = ['path_1', 'path_2']

    mock_train_spec = test.mock.Mock(spec=training.TrainSpec)
    mock_train_spec.max_steps = training_max_step

    # Counter bumped by export_fn; it must exist before export_fn reads it.
    mock_est.times_export_fn_was_called = 0
    def export_fn(estimator, *args, **kwargs):
      del args, kwargs  # Only the number of invocations matters.
      estimator.times_export_fn_was_called += 1

    eval_spec = training.EvalSpec(
        input_fn=lambda: 1,
        delay_secs=0,
        throttle_secs=0,
        export_strategies=export_strategy_lib.ExportStrategy(
            name='see_whether_export_fn_is_called', export_fn=export_fn))

    executor = training._TrainingExecutor(mock_est, mock_train_spec, eval_spec)
    executor.run_evaluator()

    expected_evaluations = 2
    self.assertEqual(expected_evaluations, mock_est.evaluate.call_count)
    self.assertEqual(expected_evaluations, mock_est.times_export_fn_was_called)
    def testAcceptsNameAndFn(self):
        """The constructed strategy exposes the name and export_fn it was given."""
        def noop_export_fn(estimator, export_path):
            del estimator, export_path  # Intentionally unused.

        strategy = export_strategy_lib.ExportStrategy(
            name="test", export_fn=noop_export_fn)

        self.assertEqual("test", strategy.name)
        # Functions compare by identity, so assertIs states the intent directly.
        self.assertIs(noop_export_fn, strategy.export_fn)
    def testCallsExportFnThatKnowsAboutEvalResultButItsNotGiven(self):
        """Omitted checkpoint_path/eval_result arrive at export_fn as None."""
        expected_estimator = {}
        # Records every invocation of export_fn (see the final assertion).
        calls = []

        def export_fn(estimator, export_path, checkpoint_path, eval_result):
            calls.append(export_path)
            self.assertEqual(expected_estimator, estimator)
            self.assertEqual("expected_path", export_path)
            self.assertIsNone(checkpoint_path)
            self.assertIsNone(eval_result)

        export_strategy = export_strategy_lib.ExportStrategy(
            name="test", export_fn=export_fn)

        export_strategy.export(estimator=expected_estimator,
                               export_path="expected_path")

        # The checks above only run if export_fn is actually invoked; guard
        # against a vacuous pass.
        self.assertEqual(1, len(calls))
    def testCallsExportFnWithEvalResultAndCheckpointPath(self):
        """checkpoint_path and eval_result are forwarded when export_fn accepts them."""
        # Use distinguishable values: with {} for both the estimator and the
        # eval result (and {} == {}), swapping the two arguments went unnoticed.
        expected_estimator = {"role": "estimator"}
        expected_eval_result = {"role": "eval_result"}
        # Records every invocation of export_fn (see the final assertion).
        calls = []

        def export_fn(estimator, export_path, checkpoint_path, eval_result):
            calls.append(export_path)
            self.assertEqual(expected_estimator, estimator)
            self.assertEqual("expected_path", export_path)
            self.assertEqual("expected_checkpoint_path", checkpoint_path)
            self.assertEqual(expected_eval_result, eval_result)

        export_strategy = export_strategy_lib.ExportStrategy(
            name="test", export_fn=export_fn)

        export_strategy.export(estimator=expected_estimator,
                               export_path="expected_path",
                               checkpoint_path="expected_checkpoint_path",
                               eval_result=expected_eval_result)

        # The checks above only run if export_fn is actually invoked; guard
        # against a vacuous pass.
        self.assertEqual(1, len(calls))
Beispiel #10
0
def _create_fake_export_strategy(name):
    """Build an ExportStrategy named `name` whose export_fn does nothing."""

    def export_fn(estimator, export_path):
        """Accept the (estimator, export_path) signature and ignore it."""
        del estimator, export_path  # Unused by design.

    fake_strategy = export_strategy_lib.ExportStrategy(
        name=name, export_fn=export_fn)
    return fake_strategy