def test_that_export_fn_is_called(self):
  """run_evaluator invokes the EvalSpec's export strategy on the estimator.

  export_fn records the call by setting a flag on the estimator it receives,
  so the final assertion also checks that export_fn was given `mock_est`.
  """
  mock_est = test.mock.Mock(spec=estimator_lib.Estimator)
  mock_train_spec = test.mock.Mock(spec=training.TrainSpec)
  self._set_up_mock_est_to_train_and_evaluate_once(
      mock_est, mock_train_spec)
  # `mock_est` is spec'd, so reading a never-assigned attribute raises
  # AttributeError. Seeding the flag turns "export_fn never ran" into an
  # ordinary assertion failure.
  mock_est.export_fn_was_called = False

  def export_fn(estimator, *args, **kwargs):
    del args, kwargs
    estimator.export_fn_was_called = True

  export_strategy = export_strategy_lib.ExportStrategy(
      name='see_whether_export_fn_is_called', export_fn=export_fn)

  eval_spec = training.EvalSpec(
      input_fn=lambda: 1,
      steps=2,
      delay_secs=0,
      throttle_secs=0,
      export_strategies=export_strategy)

  executor = training._TrainingExecutor(mock_est, mock_train_spec, eval_spec)
  executor.run_evaluator()

  # Verify that export_fn was called on the right estimator.
  self.assertTrue(mock_est.export_fn_was_called)
def test_that_export_fn_is_called_with_run_local(self):
  """run_local invokes the EvalSpec's export strategy on the estimator.

  evaluate() is stubbed to report a global step equal to max_steps (200).
  export_fn flags the estimator it receives, and the test checks the flag.
  """
  mock_est = test.mock.Mock(spec=estimator_lib.Estimator)
  mock_train_spec = test.mock.Mock(spec=training.TrainSpec)
  mock_train_spec.max_steps = 200
  mock_est.evaluate.return_value = {
      _GLOBAL_STEP_KEY: mock_train_spec.max_steps
  }
  # _validate_hooks would have made sure that train_spec.hooks is [] when
  # None were passed.
  mock_train_spec.hooks = []
  # `mock_est` is spec'd, so reading a never-assigned attribute raises
  # AttributeError. Seeding the flag turns "export_fn never ran" into an
  # ordinary assertion failure.
  mock_est.export_fn_was_called = False

  def export_fn(estimator, *args, **kwargs):
    del args, kwargs
    estimator.export_fn_was_called = True

  export_strategy = export_strategy_lib.ExportStrategy(
      name='see_whether_export_fn_is_called', export_fn=export_fn)

  eval_spec = training.EvalSpec(
      input_fn=lambda: 1,
      steps=2,
      delay_secs=0,
      throttle_secs=213,  # NOTE(review): significance of 213 not visible here.
      export_strategies=export_strategy)

  executor = training._TrainingExecutor(mock_est, mock_train_spec, eval_spec)
  executor.run_local()

  self.assertTrue(mock_est.export_fn_was_called)
def testCallsExportFnThatAcceptsEvalResultButNotCheckpoint(self):
  """An export_fn taking eval_result but not checkpoint_path is rejected.

  export() must raise ValueError before reaching export_fn, whichever of the
  optional checkpoint_path / eval_result arguments the caller supplies.
  """
  expected_estimator = {}

  def export_fn(estimator, export_path, eval_result):
    del estimator, export_path, eval_result
    # Getting here means export() skipped its signature validation.
    raise RuntimeError("Should raise ValueError before this.")

  export_strategy = export_strategy_lib.ExportStrategy(
      name="test", export_fn=export_fn)
  expected_error_message = (
      "An export_fn accepting eval_result must also accept checkpoint_path"
  )

  # assertRaisesRegexp is a deprecated alias on Python 3 and was removed in
  # Python 3.12. Prefer assertRaisesRegex, and fall back to the old name for
  # interpreters that only provide that one.
  assert_raises_regex = getattr(self, "assertRaisesRegex", None)
  if assert_raises_regex is None:
    assert_raises_regex = self.assertRaisesRegexp

  # Every combination of the optional arguments.
  optional_kwargs_cases = (
      {},
      {"checkpoint_path": "unexpected_checkpoint_path"},
      {"eval_result": ()},
      {"checkpoint_path": "unexpected_checkpoint_path", "eval_result": ()},
  )
  for optional_kwargs in optional_kwargs_cases:
    with assert_raises_regex(ValueError, expected_error_message):
      export_strategy.export(
          estimator=expected_estimator,
          export_path="expected_path",
          **optional_kwargs)
def testCallsExportFnThatDoesntKnowExtraArguments(self):
  """Unsupported optional arguments are dropped rather than passed along.

  export_fn accepts only (estimator, export_path). Calling export() with any
  mix of checkpoint_path / eval_result must still invoke it, with just those
  two arguments.
  """
  expected_estimator = {}
  # Calls are recorded and checked after the fact. Assertions placed inside
  # export_fn would let the test pass if export_fn were never invoked.
  received_calls = []

  def export_fn(estimator, export_path):
    received_calls.append((estimator, export_path))

  export_strategy = export_strategy_lib.ExportStrategy(
      name="test", export_fn=export_fn)

  export_strategy.export(estimator=expected_estimator,
                         export_path="expected_path")

  # Also works with additional arguments that `export_fn` doesn't support.
  # The lack of support is detected and the arguments aren't passed.
  export_strategy.export(estimator=expected_estimator,
                         export_path="expected_path",
                         checkpoint_path="unexpected_checkpoint_path")
  export_strategy.export(estimator=expected_estimator,
                         export_path="expected_path",
                         eval_result=())
  export_strategy.export(estimator=expected_estimator,
                         export_path="expected_path",
                         checkpoint_path="unexpected_checkpoint_path",
                         eval_result=())

  self.assertEqual(4, len(received_calls))
  for estimator, export_path in received_calls:
    # Identity, not equality: any empty dict would compare equal to {}.
    self.assertIs(expected_estimator, estimator)
    self.assertEqual("expected_path", export_path)
def test_runs_in_a_loop_until_max_steps(self):
  """run_local repeats train/evaluate/export until max_steps is reached.

  The stubbed evaluate() reports global steps of max_steps - 100,
  max_steps - 50 and then max_steps. The test expects three calls each to
  train, evaluate and the export strategy.
  """
  mock_est = test.mock.Mock(spec=estimator_lib.Estimator, model_dir='path/')
  mock_est.latest_checkpoint = self.unique_checkpoint_every_time_fn
  mock_est.times_export_fn_was_called = 0

  def count_exports(estimator, *args, **kwargs):
    del args, kwargs
    estimator.times_export_fn_was_called += 1

  export_strategy = export_strategy_lib.ExportStrategy(
      name='see_whether_export_fn_is_called', export_fn=count_exports)

  train_spec = training.TrainSpec(
      input_fn=lambda: 1, max_steps=300, hooks=[_FakeHook()])
  eval_spec = training.EvalSpec(
      input_fn=lambda: 1,
      hooks=[_FakeHook()],
      throttle_secs=100,
      export_strategies=export_strategy)

  # One evaluation result per round. Only the last one reaches max_steps.
  mock_est.evaluate.side_effect = [
      {_GLOBAL_STEP_KEY: train_spec.max_steps - steps_remaining}
      for steps_remaining in (100, 50, 0)
  ]

  training._TrainingExecutor(mock_est, train_spec, eval_spec).run_local()

  self.assertEqual(3, mock_est.train.call_count)
  self.assertEqual(3, mock_est.evaluate.call_count)
  self.assertEqual(3, mock_est.times_export_fn_was_called)
def test_evaluate_multiple_times(self):
  """run_evaluator evaluates and exports once per checkpoint until max_steps.

  latest_checkpoint yields 'path_1' then 'path_2'. evaluate() reports half
  of max_steps and then max_steps. The test expects two evaluations and two
  exports.
  """
  max_steps = 200

  mock_est = test.mock.Mock(spec=estimator_lib.Estimator)
  mock_est.model_dir = compat.as_bytes(test.get_temp_dir())
  mock_est.evaluate.side_effect = [
      {_GLOBAL_STEP_KEY: global_step}
      for global_step in (max_steps // 2, max_steps)
  ]
  mock_est.latest_checkpoint.side_effect = ['path_1', 'path_2']
  mock_est.times_export_fn_was_called = 0

  mock_train_spec = test.mock.Mock(spec=training.TrainSpec)
  mock_train_spec.max_steps = max_steps

  def count_exports(estimator, *args, **kwargs):
    del args, kwargs
    estimator.times_export_fn_was_called += 1

  eval_spec = training.EvalSpec(
      input_fn=lambda: 1,
      delay_secs=0,
      throttle_secs=0,
      export_strategies=export_strategy_lib.ExportStrategy(
          name='see_whether_export_fn_is_called', export_fn=count_exports))

  training._TrainingExecutor(
      mock_est, mock_train_spec, eval_spec).run_evaluator()

  self.assertEqual(2, mock_est.evaluate.call_count)
  self.assertEqual(2, mock_est.times_export_fn_was_called)
def testAcceptsNameAndFn(self):
  """ExportStrategy exposes the name and export_fn it was built with."""
  def export_fn(estimator, export_path):
    del estimator, export_path  # Never called; only identity matters.

  strategy = export_strategy_lib.ExportStrategy(
      name="test", export_fn=export_fn)

  self.assertEqual("test", strategy.name)
  self.assertIs(export_fn, strategy.export_fn)
def testCallsExportFnThatKnowsAboutEvalResultButItsNotGiven(self):
  """Omitted checkpoint_path / eval_result reach export_fn as None.

  export_fn declares both optional parameters, but export() is called with
  only estimator and export_path.
  """
  expected_estimator = {}
  # Calls are recorded and checked after the fact. Assertions placed inside
  # export_fn would let the test pass if export_fn were never invoked.
  received_calls = []

  def export_fn(estimator, export_path, checkpoint_path, eval_result):
    received_calls.append(
        (estimator, export_path, checkpoint_path, eval_result))

  export_strategy = export_strategy_lib.ExportStrategy(
      name="test", export_fn=export_fn)
  export_strategy.export(estimator=expected_estimator,
                         export_path="expected_path")

  self.assertEqual(1, len(received_calls))
  estimator, export_path, checkpoint_path, eval_result = received_calls[0]
  self.assertIs(expected_estimator, estimator)
  self.assertEqual("expected_path", export_path)
  self.assertIsNone(checkpoint_path)
  self.assertIsNone(eval_result)
def testCallsExportFnWithEvalResultAndCheckpointPath(self):
  """checkpoint_path and eval_result are forwarded to export_fn unchanged."""
  expected_estimator = {}
  expected_eval_result = {}
  # Calls are recorded and checked after the fact. Assertions placed inside
  # export_fn would let the test pass if export_fn were never invoked.
  received_calls = []

  def export_fn(estimator, export_path, checkpoint_path, eval_result):
    received_calls.append(
        (estimator, export_path, checkpoint_path, eval_result))

  export_strategy = export_strategy_lib.ExportStrategy(
      name="test", export_fn=export_fn)
  export_strategy.export(estimator=expected_estimator,
                         export_path="expected_path",
                         checkpoint_path="expected_checkpoint_path",
                         eval_result=expected_eval_result)

  self.assertEqual(1, len(received_calls))
  estimator, export_path, checkpoint_path, eval_result = received_calls[0]
  # Both sentinels are `{}`, so equality could not detect them being
  # swapped. Identity can, because they are distinct objects.
  self.assertIs(expected_estimator, estimator)
  self.assertEqual("expected_path", export_path)
  self.assertEqual("expected_checkpoint_path", checkpoint_path)
  self.assertIs(expected_eval_result, eval_result)
def _create_fake_export_strategy(name):
  """Return an ExportStrategy called `name` whose export_fn does nothing."""
  def _noop_export_fn(estimator, export_path):
    del estimator, export_path  # Intentionally ignored.

  fake_strategy = export_strategy_lib.ExportStrategy(
      name=name, export_fn=_noop_export_fn)
  return fake_strategy