def test_invalid_eval_spec(self):
    """Constructing the executor with a plain object as eval spec fails.

    The test expects a TypeError whose message matches
    `_INVALID_EVAL_SPEC_MSG`.
    """
    est = estimator_lib.Estimator(model_fn=lambda features: features)
    spec_for_training = training.TrainSpec(input_fn=lambda: 1)
    not_an_eval_spec = object()

    # Only construction is exercised; no task is ever run.
    with self.assertRaisesRegexp(TypeError, _INVALID_EVAL_SPEC_MSG):
        training._TrainingExecutor(est, spec_for_training, not_an_eval_spec)
def test_invalid_estimator(self):
    """Constructing the executor with a plain object as estimator fails.

    The test expects a TypeError whose message matches
    `_INVALID_ESTIMATOR_MSG`.
    """
    not_an_estimator = object()
    spec_for_training = training.TrainSpec(input_fn=lambda: 1)
    spec_for_eval = training.EvalSpec(input_fn=lambda: 1)

    # Only construction is exercised; no task is ever run.
    with self.assertRaisesRegexp(TypeError, _INVALID_ESTIMATOR_MSG):
        training._TrainingExecutor(
            not_an_estimator, spec_for_training, spec_for_eval)
def test_fail_with_none_task_id(self):
    """`run_ps` is expected to raise when the config's `task_id` is None."""
    est = test.mock.Mock(spec=estimator_lib.Estimator)
    train_spec = test.mock.Mock(spec=training.TrainSpec)
    eval_spec = test.mock.Mock(spec=training.EvalSpec)

    # Build the fake run config and fill in every field except a usable
    # task id.
    config = test.mock.PropertyMock(spec=run_config_lib.RunConfig)
    est.config = config
    config.cluster_spec = {'gs': 'dummy'}
    config.master = 'grpc://...'
    config.task_type = 'gs'
    config.task_id = None

    with self.assertRaisesRegexp(RuntimeError,
                                 _INVALID_CONFIG_FOR_STD_SERVER_MSG):
        executor = training._TrainingExecutor(est, train_spec, eval_spec)
        executor.run_ps()
def test_fail_with_empty_task_type(self):
    """`run_ps` is expected to raise when the config's `task_type` is ''."""
    est = test.mock.Mock(spec=estimator_lib.Estimator)
    train_spec = test.mock.Mock(spec=training.TrainSpec)
    eval_spec = test.mock.Mock(spec=training.EvalSpec)

    # Build the fake run config; the task type is deliberately left empty.
    config = test.mock.PropertyMock(spec=run_config_lib.RunConfig)
    est.config = config
    config.cluster_spec = {'gs': 'dummy'}
    config.master = 'grpc://...'
    config.task_type = ''
    config.task_id = 2

    with self.assertRaisesRegexp(RuntimeError,
                                 _INVALID_CONFIG_FOR_STD_SERVER_MSG):
        executor = training._TrainingExecutor(est, train_spec, eval_spec)
        executor.run_ps()
def test_evaluate_with_evaluate_spec(self, mock_latest_ckpt):
    """`run_evaluator` forwards the EvalSpec fields to `evaluate`.

    Args:
      mock_latest_ckpt: injected mock; its `return_value` is the checkpoint
        path the test expects to be passed to `evaluate`.
    """
    expected_ckpt = mock_latest_ckpt.return_value
    est = test.mock.Mock(spec=estimator_lib.Estimator)
    train_spec = test.mock.Mock(spec=training.TrainSpec)
    self._set_up_mock_est_to_train_and_evaluate_once(est, train_spec)

    eval_spec = training.EvalSpec(
        input_fn=lambda: 1,
        steps=2,
        hooks=[_FakeHook()],
        name='cont_eval',
        delay_secs=0,
        throttle_secs=0)

    training._TrainingExecutor(est, train_spec, eval_spec).run_evaluator()

    expected_kwargs = {
        'name': 'cont_eval',
        'input_fn': eval_spec.input_fn,
        'steps': eval_spec.steps,
        'checkpoint_path': expected_ckpt,
        'hooks': eval_spec.hooks,
    }
    est.evaluate.assert_called_with(**expected_kwargs)
    # The evaluator task must never train.
    self.assertFalse(est.train.called)
def test_that_export_fn_is_called_with_run_local(self):
    """`run_local` calls the export strategy's `export_fn` on the estimator."""
    est = test.mock.Mock(spec=estimator_lib.Estimator)
    train_spec = test.mock.Mock(spec=training.TrainSpec)
    train_spec.max_steps = 200
    # Evaluation reports a global step equal to max_steps.
    est.evaluate.return_value = {_GLOBAL_STEP_KEY: train_spec.max_steps}
    # With a real TrainSpec, _validate_hooks would have turned hooks=None
    # into []; the mocked spec needs that set by hand.
    train_spec.hooks = []

    def record_export(estimator, *args, **kwargs):
        # Leave a marker on whichever estimator the export was invoked with.
        del args, kwargs
        estimator.export_fn_was_called = True

    strategy = export_strategy_lib.ExportStrategy(
        name='see_whether_export_fn_is_called', export_fn=record_export)
    eval_spec = training.EvalSpec(
        input_fn=lambda: 1,
        steps=2,
        delay_secs=0,
        throttle_secs=213,
        export_strategies=strategy)

    training._TrainingExecutor(est, train_spec, eval_spec).run_local()

    self.assertTrue(est.export_fn_was_called)
def test_runs_in_a_loop_until_max_steps(self):
    """`run_local` trains, evaluates and exports three times in this setup."""
    est = test.mock.Mock(spec=estimator_lib.Estimator, model_dir='path/')
    est.latest_checkpoint = self.unique_checkpoint_every_time_fn
    est.times_export_fn_was_called = 0

    def count_export(estimator, *args, **kwargs):
        # Tally every export on the estimator it was invoked with.
        del args, kwargs
        estimator.times_export_fn_was_called += 1

    strategy = export_strategy_lib.ExportStrategy(
        name='see_whether_export_fn_is_called', export_fn=count_export)

    train_spec = training.TrainSpec(
        input_fn=lambda: 1, max_steps=300, hooks=[_FakeHook()])
    eval_spec = training.EvalSpec(
        input_fn=lambda: 1,
        hooks=[_FakeHook()],
        throttle_secs=100,
        export_strategies=strategy)

    # Three evaluation results; only the last one reports max_steps.
    est.evaluate.side_effect = [
        {_GLOBAL_STEP_KEY: train_spec.max_steps - remaining}
        for remaining in (100, 50, 0)
    ]

    training._TrainingExecutor(est, train_spec, eval_spec).run_local()

    self.assertEqual(3, est.train.call_count)
    self.assertEqual(3, est.evaluate.call_count)
    self.assertEqual(3, est.times_export_fn_was_called)
def test_skip_evaluation_due_to_ckpt(self, mock_latest_ckpt):
    """Evaluation is skipped for missing or repeated checkpoint paths.

    Args:
      mock_latest_ckpt: injected mock whose successive return values are
        the checkpoint paths seen by the evaluator.
    """
    max_step = 200
    est = test.mock.Mock(spec=estimator_lib.Estimator)
    est.evaluate.side_effect = [
        {_GLOBAL_STEP_KEY: step} for step in (max_step // 2, max_step)
    ]
    train_spec = test.mock.Mock(spec=training.TrainSpec)
    train_spec.max_steps = max_step
    self._set_up_mock_est_to_train_and_evaluate_once(est, train_spec)

    # The first two entries are invalid and the next two are identical.
    mock_latest_ckpt.side_effect = [None, '', 'same', 'same', 'path_2']

    eval_spec = training.EvalSpec(
        input_fn=lambda: 1, delay_secs=0, throttle_secs=0)
    executor = training._TrainingExecutor(est, train_spec, eval_spec)

    with test.mock.patch.object(logging, 'warning') as warning_log:
        executor.run_evaluator()

    # Three of the five checkpoint lookups do not lead to an evaluation.
    self.assertEqual(5, mock_latest_ckpt.call_count)
    self.assertEqual(2, est.evaluate.call_count)

    # Two warnings are expected (the last warning time is reset after a
    # successful evaluation).
    self.assertEqual(2, warning_log.call_count)
def test_that_export_fn_is_called(self):
    """`run_evaluator` calls the export strategy's `export_fn`."""
    est = test.mock.Mock(spec=estimator_lib.Estimator)
    train_spec = test.mock.Mock(spec=training.TrainSpec)
    self._set_up_mock_est_to_train_and_evaluate_once(est, train_spec)

    def record_export(estimator, *args, **kwargs):
        # Leave a marker on whichever estimator the export was invoked with.
        del args, kwargs
        estimator.export_fn_was_called = True

    strategy = export_strategy_lib.ExportStrategy(
        name='see_whether_export_fn_is_called', export_fn=record_export)
    eval_spec = training.EvalSpec(
        input_fn=lambda: 1,
        steps=2,
        delay_secs=0,
        throttle_secs=0,
        export_strategies=strategy)

    training._TrainingExecutor(est, train_spec, eval_spec).run_evaluator()

    # The marker must be on the estimator handed to the executor.
    self.assertTrue(est.export_fn_was_called)
def test_that_export_is_called_with_run_local(self):
    """`run_local` calls the exporter's `export` on the estimator."""
    est = test.mock.Mock(spec=estimator_lib.Estimator)
    train_spec = test.mock.Mock(spec=training.TrainSpec)
    train_spec.max_steps = 200
    # Evaluation reports a global step equal to max_steps.
    est.evaluate.return_value = {_GLOBAL_STEP_KEY: train_spec.max_steps}
    # With a real TrainSpec, _validate_hooks would have turned hooks=None
    # into []; the mocked spec needs that set by hand.
    train_spec.hooks = []

    def record_export(estimator, *args, **kwargs):
        # Leave a marker on whichever estimator the export was invoked with.
        del args, kwargs
        estimator.export_was_called = True

    fake_exporter = test.mock.PropertyMock(spec=exporter_lib.Exporter)
    fake_exporter.name = 'see_whether_export_is_called'
    fake_exporter.export = record_export

    eval_spec = training.EvalSpec(
        input_fn=lambda: 1,
        steps=2,
        start_delay_secs=0,
        throttle_secs=213,
        exporters=fake_exporter)

    training._TrainingExecutor(est, train_spec, eval_spec).run_local()

    self.assertTrue(est.export_was_called)
def test_evaluate_multiple_times(self):
    """`run_evaluator` evaluates and exports twice for two checkpoints."""
    max_step = 200
    est = test.mock.Mock(spec=estimator_lib.Estimator)
    est.model_dir = compat.as_bytes(test.get_temp_dir())
    # Only the second evaluation result reports the maximum step.
    est.evaluate.side_effect = [
        {_GLOBAL_STEP_KEY: step} for step in (max_step // 2, max_step)
    ]
    est.latest_checkpoint.side_effect = ['path_1', 'path_2']

    train_spec = test.mock.Mock(spec=training.TrainSpec)
    train_spec.max_steps = max_step

    fake_exporter = test.mock.PropertyMock(spec=exporter_lib.Exporter)
    fake_exporter.name = 'see_how_many_times_export_is_called'

    eval_spec = training.EvalSpec(
        input_fn=lambda: 1,
        start_delay_secs=0,
        throttle_secs=0,
        exporters=fake_exporter)

    training._TrainingExecutor(est, train_spec, eval_spec).run_evaluator()

    self.assertEqual(2, est.evaluate.call_count)
    self.assertEqual(2, fake_exporter.export.call_count)
def test_evaluate_multiple_times(self):
    """`run_evaluator` evaluates and runs `export_fn` twice."""
    max_step = 200
    est = test.mock.Mock(spec=estimator_lib.Estimator)
    est.model_dir = compat.as_bytes(test.get_temp_dir())
    # Only the second evaluation result reports the maximum step.
    est.evaluate.side_effect = [
        {_GLOBAL_STEP_KEY: step} for step in (max_step // 2, max_step)
    ]
    est.latest_checkpoint.side_effect = ['path_1', 'path_2']

    train_spec = test.mock.Mock(spec=training.TrainSpec)
    train_spec.max_steps = max_step

    est.times_export_fn_was_called = 0

    def count_export(estimator, *args, **kwargs):
        # Tally every export on the estimator it was invoked with.
        del args, kwargs
        estimator.times_export_fn_was_called += 1

    strategy = export_strategy_lib.ExportStrategy(
        name='see_whether_export_fn_is_called', export_fn=count_export)

    eval_spec = training.EvalSpec(
        input_fn=lambda: 1,
        delay_secs=0,
        throttle_secs=0,
        export_strategies=strategy)

    training._TrainingExecutor(est, train_spec, eval_spec).run_evaluator()

    self.assertEqual(2, est.evaluate.call_count)
    self.assertEqual(2, est.times_export_fn_was_called)
def test_train_and_evaluate_args(self):
    """`run_local` passes the spec fields through to train and evaluate."""
    est = test.mock.Mock(spec=estimator_lib.Estimator, model_dir='path/')
    est.latest_checkpoint.return_value = 'checkpoint_path/'

    train_spec = training.TrainSpec(
        input_fn=lambda: 1, max_steps=300, hooks=[_FakeHook()])
    eval_spec = training.EvalSpec(
        input_fn=lambda: 1, steps=2, hooks=[_FakeHook()], name='local_eval')
    # Evaluation reports a global step equal to max_steps.
    est.evaluate.return_value = {_GLOBAL_STEP_KEY: train_spec.max_steps}

    training._TrainingExecutor(est, train_spec, eval_spec).run_local()

    expected_eval_kwargs = {
        'name': eval_spec.name,
        'input_fn': eval_spec.input_fn,
        'steps': eval_spec.steps,
        'checkpoint_path': 'checkpoint_path/',
        'hooks': eval_spec.hooks,
    }
    est.evaluate.assert_called_with(**expected_eval_kwargs)

    train_kwargs = est.train.call_args[1]
    # The last hook passed to train is excluded from the comparison.
    passed_hooks = list(train_kwargs['hooks'][:-1])
    self.assertEqual(list(train_spec.hooks), passed_hooks)
    self.assertEqual(train_spec.input_fn, train_kwargs['input_fn'])
    self.assertEqual(train_spec.max_steps, train_kwargs['max_steps'])
def testRequiredArgumentsSet(self):
    """The executor exposes the estimator it was constructed with."""
    est = estimator_lib.Estimator(model_fn=lambda features: features)
    spec_for_training = training.TrainSpec(input_fn=lambda: 1)
    spec_for_eval = training.EvalSpec(input_fn=lambda: 1)

    executor = training._TrainingExecutor(
        est, spec_for_training, spec_for_eval)

    self.assertEqual(est, executor.estimator)
def test_skip_evaluation_due_to_ckpt(self):
    """Evaluation is skipped for missing or repeated checkpoint paths."""
    max_step = 200
    est = test.mock.Mock(spec=estimator_lib.Estimator)
    est.evaluate.side_effect = [
        {_GLOBAL_STEP_KEY: step} for step in (max_step // 2, max_step)
    ]
    train_spec = test.mock.Mock(spec=training.TrainSpec)
    train_spec.max_steps = max_step
    self._set_up_mock_est_to_train_and_evaluate_once(est, train_spec)

    # The first two entries are invalid and the next two are identical.
    checkpoint_sequence = [None, '', 'same', 'same', 'path_2']
    est.latest_checkpoint.side_effect = checkpoint_sequence

    eval_spec = training.EvalSpec(
        input_fn=lambda: 1, start_delay_secs=0, throttle_secs=0)
    executor = training._TrainingExecutor(est, train_spec, eval_spec)

    with test.mock.patch.object(logging, 'warning') as warning_log:
        executor.run_evaluator()

    # Three of the five checkpoint lookups do not lead to an evaluation.
    self.assertEqual(5, est.latest_checkpoint.call_count)
    self.assertEqual(2, est.evaluate.call_count)

    # Two warnings are expected (the last warning time is reset after a
    # successful evaluation).
    self.assertEqual(2, warning_log.call_count)
def test_that_export_is_called(self):
    """`run_evaluator` calls the exporter's `export` on the estimator."""
    est = test.mock.Mock(spec=estimator_lib.Estimator)
    train_spec = test.mock.Mock(spec=training.TrainSpec)
    self._set_up_mock_est_to_train_and_evaluate_once(est, train_spec)

    def record_export(estimator, *args, **kwargs):
        # Leave a marker on whichever estimator the export was invoked with.
        del args, kwargs
        estimator.export_was_called = True

    fake_exporter = test.mock.PropertyMock(spec=exporter_lib.Exporter)
    fake_exporter.name = 'see_whether_export_is_called'
    fake_exporter.export = record_export

    eval_spec = training.EvalSpec(
        input_fn=lambda: 1,
        steps=2,
        start_delay_secs=0,
        throttle_secs=0,
        exporters=fake_exporter)

    training._TrainingExecutor(est, train_spec, eval_spec).run_evaluator()

    # The marker must be on the estimator handed to the executor.
    self.assertTrue(est.export_was_called)
def test_train_with_train_spec(self, mock_server, unused_mock_sleep):
    """The task starts a server and trains with the TrainSpec fields.

    Args:
      mock_server: injected mock standing in for the server constructor.
      unused_mock_sleep: injected mock, not inspected here.
    """
    run_config = self._run_config
    est = test.mock.Mock(spec=estimator_lib.Estimator)
    est.config = run_config
    train_spec = training.TrainSpec(
        input_fn=lambda: 1, max_steps=2, hooks=[_FakeHook()])
    eval_spec = test.mock.Mock(spec=training.EvalSpec)
    server_instance = mock_server.return_value

    executor = training._TrainingExecutor(est, train_spec, eval_spec)
    self._run_task(executor)

    # The server is built from the run config and started explicitly.
    mock_server.assert_called_with(
        run_config.cluster_spec,
        job_name=run_config.task_type,
        task_index=run_config.task_id,
        config=test.mock.ANY,
        start=False)
    self.assertTrue(server_instance.start.called)

    expected_train_kwargs = {
        'input_fn': train_spec.input_fn,
        'max_steps': train_spec.max_steps,
        'hooks': train_spec.hooks,
    }
    est.train.assert_called_with(**expected_train_kwargs)
    est.evaluate.assert_not_called()
    est.export_savedmodel.assert_not_called()
def test_runs_in_a_loop_until_max_steps(self):
    """`run_local` trains, evaluates and exports three times in this setup."""
    est = test.mock.Mock(spec=estimator_lib.Estimator, model_dir='path/')
    est.latest_checkpoint = self.unique_checkpoint_every_time_fn
    est.times_export_was_called = 0

    def count_export(estimator, *args, **kwargs):
        # Tally every export on the estimator it was invoked with.
        del args, kwargs
        estimator.times_export_was_called += 1

    fake_exporter = test.mock.PropertyMock(spec=exporter_lib.Exporter)
    fake_exporter.name = 'see_how_many_times_export_is_called'
    fake_exporter.export = count_export

    train_spec = training.TrainSpec(
        input_fn=lambda: 1, max_steps=300, hooks=[_FakeHook()])
    eval_spec = training.EvalSpec(
        input_fn=lambda: 1,
        hooks=[_FakeHook()],
        throttle_secs=100,
        exporters=fake_exporter)

    # Three evaluation results; only the last one reports max_steps.
    est.evaluate.side_effect = [
        {_GLOBAL_STEP_KEY: train_spec.max_steps - remaining}
        for remaining in (100, 50, 0)
    ]

    training._TrainingExecutor(est, train_spec, eval_spec).run_local()

    self.assertEqual(3, est.train.call_count)
    self.assertEqual(3, est.evaluate.call_count)
    self.assertEqual(3, est.times_export_was_called)
def test_train_with_train_spec(self, mock_server, unused_mock_sleep):
    """The task starts a server and trains with the TrainSpec fields.

    Args:
      mock_server: injected mock standing in for the server constructor.
      unused_mock_sleep: injected mock, not inspected here.
    """
    run_config = self._run_config
    est = test.mock.Mock(spec=estimator_lib.Estimator)
    est.config = run_config
    train_spec = training.TrainSpec(
        input_fn=lambda: 1, max_steps=2, hooks=[_FakeHook()])
    eval_spec = test.mock.Mock(spec=training.EvalSpec)
    server_instance = mock_server.return_value

    executor = training._TrainingExecutor(est, train_spec, eval_spec)
    self._run_task(executor)

    # The server is built from the run config and started explicitly.
    mock_server.assert_called_with(
        run_config.cluster_spec,
        job_name=run_config.task_type,
        task_index=run_config.task_id,
        config=test.mock.ANY,
        start=False)
    self.assertTrue(server_instance.start.called)

    # Any value is accepted for `saving_listeners`.
    expected_train_kwargs = {
        'input_fn': train_spec.input_fn,
        'max_steps': train_spec.max_steps,
        'hooks': train_spec.hooks,
        'saving_listeners': test.mock.ANY,
    }
    est.train.assert_called_with(**expected_train_kwargs)
    est.evaluate.assert_not_called()
    est.export_savedmodel.assert_not_called()
def test_errors_out_if_throttle_secs_is_zero(self):
    """`run_local` is expected to raise ValueError for `throttle_secs=0`."""
    est = test.mock.Mock(spec=estimator_lib.Estimator)
    spec_for_training = training.TrainSpec(input_fn=lambda: 1)
    spec_for_eval = training.EvalSpec(input_fn=lambda: 1, throttle_secs=0)
    executor = training._TrainingExecutor(
        est, spec_for_training, spec_for_eval)

    # The error message is expected to mention the offending argument.
    with self.assertRaisesRegexp(ValueError, 'throttle_secs'):
        executor.run_local()
def test_errors_out_if_evaluate_returns_non_dict(self):
    """`run_local` is expected to raise TypeError for a non-dict result."""
    est = test.mock.Mock(spec=estimator_lib.Estimator)
    spec_for_training = training.TrainSpec(input_fn=lambda: 1)
    spec_for_eval = training.EvalSpec(
        input_fn=(lambda: 1), throttle_secs=123)
    # `evaluate` hands back an int rather than a dict.
    est.evaluate.return_value = 123

    executor = training._TrainingExecutor(
        est, spec_for_training, spec_for_eval)

    with self.assertRaisesRegexp(TypeError, _INVALID_EVAL_RESULT_TYPE_ERR):
        executor.run_local()
def test_errors_out_if_evaluate_returns_empty_dict(self):
    """`run_evaluator` is expected to raise for an empty evaluation dict."""
    est = test.mock.Mock(spec=estimator_lib.Estimator)
    spec_for_training = training.TrainSpec(input_fn=lambda: 1)
    spec_for_eval = training.EvalSpec(
        input_fn=(lambda: 1), start_delay_secs=0, throttle_secs=0)
    # `evaluate` hands back a dict with no entries at all.
    est.evaluate.return_value = {}

    executor = training._TrainingExecutor(
        est, spec_for_training, spec_for_eval)

    with self.assertRaisesRegexp(RuntimeError,
                                 _INVALID_EMPTY_EVAL_RESULT_ERR):
        executor.run_evaluator()
def test_errors_out_if_evaluate_returns_dict_without_global_step(self):
    """`run_local` is expected to raise when the result lacks global step."""
    est = test.mock.Mock(spec=estimator_lib.Estimator)
    spec_for_training = training.TrainSpec(input_fn=lambda: 1)
    spec_for_eval = training.EvalSpec(
        input_fn=(lambda: 1), throttle_secs=123)
    # The result has a metric but no global-step entry.
    est.evaluate.return_value = {'loss': 123}

    executor = training._TrainingExecutor(
        est, spec_for_training, spec_for_eval)

    with self.assertRaisesRegexp(RuntimeError,
                                 _MISSING_GLOBAL_STEP_IN_EVAL_RESULT_ERR):
        executor.run_local()
def test_no_delay_for_master(self, _):
    """Running the task under this fixture's config never calls sleep."""
    est = test.mock.Mock(spec=estimator_lib.Estimator)
    est.config = self._run_config
    train_spec = test.mock.Mock(spec=training.TrainSpec)
    eval_spec = test.mock.Mock(spec=training.EvalSpec)

    executor = training._TrainingExecutor(est, train_spec, eval_spec)

    # Replace time.sleep so any start-up delay would be observable.
    with test.mock.patch.object(time, 'sleep') as sleep_spy:
        self._run_task(executor)
        sleep_spy.assert_not_called()
def test_no_server_startup_in_google(self, mock_server, unused_mock_sleep):
    """No server is constructed when TF_CONFIG is the Google variant.

    Args:
      mock_server: injected mock standing in for the server constructor.
      unused_mock_sleep: injected mock, not inspected here.
    """
    est = test.mock.Mock(spec=estimator_lib.Estimator)
    est.config = self._run_config
    train_spec = test.mock.Mock(spec=training.TrainSpec)
    eval_spec = test.mock.Mock(spec=training.EvalSpec)

    executor = training._TrainingExecutor(est, train_spec, eval_spec)

    # Patch the environment only for the duration of the task run.
    env_overrides = {'TF_CONFIG': json.dumps(_TF_CONFIG_FOR_GOOGLE)}
    with test.mock.patch.dict('os.environ', env_overrides):
        self._run_task(executor)
        mock_server.assert_not_called()
def test_send_stop_at_secs_to_train(self):
    """`run_local` appends a `_StopAtSecsHook` set to `throttle_secs`."""
    est = test.mock.Mock(spec=estimator_lib.Estimator)
    train_spec = training.TrainSpec(
        input_fn=lambda: 1, max_steps=2, hooks=[_FakeHook()])
    eval_spec = training.EvalSpec(
        input_fn=lambda: 1, hooks=[_FakeHook()], throttle_secs=100)
    # Evaluation reports a global step equal to max_steps.
    est.evaluate.return_value = {_GLOBAL_STEP_KEY: train_spec.max_steps}

    training._TrainingExecutor(est, train_spec, eval_spec).run_local()

    # The hook under test is the last one handed to `train`.
    train_kwargs = est.train.call_args[1]
    last_hook = train_kwargs['hooks'][-1]
    self.assertIsInstance(last_hook, training._StopAtSecsHook)
    self.assertEqual(eval_spec.throttle_secs, last_hook._stop_after_secs)
def test_delay_for_worker(self, _):
    """The task sleeps `(task_id + 1) * _DELAY_SECS_PER_WORKER` seconds."""
    est = test.mock.Mock(spec=estimator_lib.Estimator)
    est.config = self._run_config
    train_spec = test.mock.Mock(spec=training.TrainSpec)
    eval_spec = test.mock.Mock(spec=training.EvalSpec)

    executor = training._TrainingExecutor(est, train_spec, eval_spec)

    expected_secs = (self._run_config.task_id + 1) * _DELAY_SECS_PER_WORKER

    def check_sleep_duration(s):
        # Every sleep call must use exactly the expected duration.
        return self.assertEqual(expected_secs, s)

    with test.mock.patch.object(time, 'sleep') as sleep_spy:
        sleep_spy.side_effect = check_sleep_duration
        self._run_task(executor)
        self.assertTrue(sleep_spy.called)
def test_send_stop_at_secs_to_train(self):
    """`run_local` appends a `_StopAtSecsHook` set to `throttle_secs`."""
    est = test.mock.Mock(spec=estimator_lib.Estimator, model_dir='path/')
    est.latest_checkpoint = self.unique_checkpoint_every_time_fn
    train_spec = training.TrainSpec(
        input_fn=lambda: 1, max_steps=2, hooks=[_FakeHook()])
    eval_spec = training.EvalSpec(
        input_fn=lambda: 1, hooks=[_FakeHook()], throttle_secs=100)
    # Evaluation reports a global step equal to max_steps.
    est.evaluate.return_value = {_GLOBAL_STEP_KEY: train_spec.max_steps}

    training._TrainingExecutor(est, train_spec, eval_spec).run_local()

    # The hook under test is the last one handed to `train`.
    train_kwargs = est.train.call_args[1]
    last_hook = train_kwargs['hooks'][-1]
    self.assertIsInstance(last_hook, training._StopAtSecsHook)
    self.assertEqual(eval_spec.throttle_secs, last_hook._stop_after_secs)
def test_fail_with_empty_master(self):
    """Running the task is expected to raise when `master` is ''."""
    est = test.mock.Mock(spec=estimator_lib.Estimator)
    train_spec = test.mock.Mock(spec=training.TrainSpec)
    eval_spec = test.mock.Mock(spec=training.EvalSpec)

    # Build the fake run config; the master address is deliberately empty.
    config = test.mock.PropertyMock(spec=run_config_lib.RunConfig)
    est.config = config
    config.cluster_spec = {'worker': 'dummy'}
    config.master = ''
    config.task_type = 'worker'
    config.task_id = 2

    with self.assertRaisesRegexp(RuntimeError,
                                 _INVALID_CONFIG_FOR_STD_SERVER_MSG):
        executor = training._TrainingExecutor(est, train_spec, eval_spec)
        self._run_task(executor)
def test_sleep_delay_secs(self):
    """`run_evaluator` sleeps for the EvalSpec's `delay_secs`."""
    max_step = 200
    initial_delay = 123
    est = test.mock.Mock(spec=estimator_lib.Estimator)
    # Evaluation reports a global step equal to max_steps.
    est.evaluate.return_value = {_GLOBAL_STEP_KEY: max_step}
    train_spec = test.mock.Mock(spec=training.TrainSpec)
    train_spec.max_steps = max_step

    eval_spec = training.EvalSpec(
        input_fn=lambda: 1,
        steps=2,
        hooks=[_FakeHook()],
        name='cont_eval',
        delay_secs=initial_delay,
        throttle_secs=0)

    executor = training._TrainingExecutor(est, train_spec, eval_spec)

    # Replace time.sleep so the requested delay can be inspected.
    with test.mock.patch.object(time, 'sleep') as sleep_spy:
        executor.run_evaluator()
        sleep_spy.assert_called_with(initial_delay)
        self.assertTrue(est.evaluate.called)
def test_evaluate_multiple_times(self):
    """`run_evaluator` evaluates twice for two distinct checkpoints."""
    max_step = 200
    est = test.mock.Mock(spec=estimator_lib.Estimator)
    # Only the second evaluation result reports the maximum step.
    est.evaluate.side_effect = [
        {_GLOBAL_STEP_KEY: step} for step in (max_step // 2, max_step)
    ]
    est.latest_checkpoint.side_effect = ['path_1', 'path_2']

    train_spec = test.mock.Mock(spec=training.TrainSpec)
    train_spec.max_steps = max_step

    eval_spec = training.EvalSpec(
        input_fn=lambda: 1, delay_secs=0, throttle_secs=0)

    training._TrainingExecutor(est, train_spec, eval_spec).run_evaluator()

    self.assertEqual(2, est.evaluate.call_count)
def test_throttle_secs(self, mock_sleep, mock_time):
    """The evaluator sleeps for `throttle_secs` minus the elapsed time.

    Args:
      mock_sleep: injected mock; the sleep duration is asserted on it.
      mock_time: injected mock; its values fix the elapsed time.
    """
    throttle = 123
    elapsed = 12
    first_timestamp = 921

    est = test.mock.Mock(spec=estimator_lib.Estimator)
    train_spec = test.mock.Mock(spec=training.TrainSpec)
    self._set_up_mock_est_to_train_and_evaluate_once(est, train_spec)

    eval_spec = training.EvalSpec(
        input_fn=lambda: 1, delay_secs=0, throttle_secs=throttle)

    # Two clock readings that differ by `elapsed` seconds.
    mock_time.side_effect = [first_timestamp, first_timestamp + elapsed]

    executor = training._TrainingExecutor(est, train_spec, eval_spec)
    # logging.info also calls time.time, so it is patched out here.
    with test.mock.patch.object(logging, 'info'):
        executor.run_evaluator()

    mock_sleep.assert_called_with(throttle - elapsed)
    self.assertTrue(est.evaluate.called)
def test_throttle_secs(self, mock_sleep, mock_time):
    """The evaluator sleeps for `throttle_secs` minus the elapsed time.

    Args:
      mock_sleep: injected mock; the sleep duration is asserted on it.
      mock_time: injected mock; its values fix the elapsed time.
    """
    throttle = 123
    elapsed = 12
    first_timestamp = 921

    est = test.mock.Mock(spec=estimator_lib.Estimator)
    train_spec = test.mock.Mock(spec=training.TrainSpec)
    self._set_up_mock_est_to_train_and_evaluate_once(est, train_spec)

    eval_spec = training.EvalSpec(
        input_fn=lambda: 1, start_delay_secs=0, throttle_secs=throttle)

    # Two clock readings that differ by `elapsed` seconds.
    mock_time.side_effect = [first_timestamp, first_timestamp + elapsed]

    executor = training._TrainingExecutor(est, train_spec, eval_spec)
    # logging.info also calls time.time, so it is patched out here.
    with test.mock.patch.object(logging, 'info'):
        executor.run_evaluator()

    mock_sleep.assert_called_with(throttle - elapsed)
    self.assertTrue(est.evaluate.called)
def test_sleep_delay_secs(self):
    """`run_evaluator` sleeps for the EvalSpec's `delay_secs`."""
    max_step = 200
    initial_delay = 123
    est = test.mock.Mock(spec=estimator_lib.Estimator)
    # Evaluation reports a global step equal to max_steps.
    est.evaluate.return_value = {_GLOBAL_STEP_KEY: max_step}
    est.model_dir = compat.as_bytes(test.get_temp_dir())
    train_spec = test.mock.Mock(spec=training.TrainSpec)
    train_spec.max_steps = max_step

    eval_spec = training.EvalSpec(
        input_fn=lambda: 1,
        steps=2,
        hooks=[_FakeHook()],
        name='cont_eval',
        delay_secs=initial_delay,
        throttle_secs=0)

    executor = training._TrainingExecutor(est, train_spec, eval_spec)

    # Replace time.sleep so the requested delay can be inspected.
    with test.mock.patch.object(time, 'sleep') as sleep_spy:
        executor.run_evaluator()
        sleep_spy.assert_called_with(initial_delay)
        self.assertTrue(est.evaluate.called)
def test_handles_no_new_checkpoint_found(self):
    """`run_local` is expected to raise when the checkpoint never changes."""
    est = test.mock.Mock(spec=estimator_lib.Estimator, model_dir='path/')
    # Every checkpoint lookup yields the same path.
    est.latest_checkpoint.return_value = (
        'no_new_checkpoints_after_the_first_train_step')

    train_spec = training.TrainSpec(
        input_fn=lambda: 1, max_steps=300, hooks=[_FakeHook()])
    eval_spec = training.EvalSpec(
        input_fn=lambda: 1, hooks=[_FakeHook()], throttle_secs=100)

    # Three evaluation results are queued up.
    est.evaluate.side_effect = [
        {_GLOBAL_STEP_KEY: train_spec.max_steps - remaining}
        for remaining in (100, 50, 0)
    ]

    executor = training._TrainingExecutor(est, train_spec, eval_spec)
    with self.assertRaisesRegexp(RuntimeError, _STALE_CHECKPOINT_MSG):
        executor.run_local()
def test_runs_in_a_loop_until_max_steps(self):
    """`run_local` trains and evaluates three times in this setup."""
    est = test.mock.Mock(spec=estimator_lib.Estimator)
    train_spec = training.TrainSpec(
        input_fn=lambda: 1, max_steps=300, hooks=[_FakeHook()])
    eval_spec = training.EvalSpec(
        input_fn=lambda: 1, hooks=[_FakeHook()], throttle_secs=100)

    # Three evaluation results; only the last one reports max_steps.
    est.evaluate.side_effect = [
        {_GLOBAL_STEP_KEY: train_spec.max_steps - remaining}
        for remaining in (100, 50, 0)
    ]

    training._TrainingExecutor(est, train_spec, eval_spec).run_local()

    self.assertEqual(3, est.train.call_count)
    self.assertEqual(3, est.evaluate.call_count)
def test_evaluate_with_evaluate_spec(self):
    """`run_evaluator` forwards the EvalSpec fields to `evaluate`."""
    est = test.mock.Mock(spec=estimator_lib.Estimator)
    est.latest_checkpoint.return_value = 'latest_it_is'
    train_spec = test.mock.Mock(spec=training.TrainSpec)
    self._set_up_mock_est_to_train_and_evaluate_once(est, train_spec)

    eval_spec = training.EvalSpec(
        input_fn=lambda: 1,
        steps=2,
        hooks=[_FakeHook()],
        name='cont_eval',
        start_delay_secs=0,
        throttle_secs=0)

    training._TrainingExecutor(est, train_spec, eval_spec).run_evaluator()

    expected_kwargs = {
        'name': 'cont_eval',
        'input_fn': eval_spec.input_fn,
        'steps': eval_spec.steps,
        'checkpoint_path': 'latest_it_is',
        'hooks': eval_spec.hooks,
    }
    est.evaluate.assert_called_with(**expected_kwargs)
    # The evaluator task must never train.
    self.assertFalse(est.train.called)
def test_train_and_evaluate_args(self):
    """`run_local` passes the spec fields through to train and evaluate."""
    est = test.mock.Mock(spec=estimator_lib.Estimator)
    train_spec = training.TrainSpec(
        input_fn=lambda: 1, max_steps=300, hooks=[_FakeHook()])
    eval_spec = training.EvalSpec(
        input_fn=lambda: 1, steps=2, hooks=[_FakeHook()], name='local_eval')
    # Evaluation reports a global step equal to max_steps.
    est.evaluate.return_value = {_GLOBAL_STEP_KEY: train_spec.max_steps}

    training._TrainingExecutor(est, train_spec, eval_spec).run_local()

    expected_eval_kwargs = {
        'name': eval_spec.name,
        'input_fn': eval_spec.input_fn,
        'steps': eval_spec.steps,
        'hooks': eval_spec.hooks,
    }
    est.evaluate.assert_called_with(**expected_eval_kwargs)

    train_kwargs = est.train.call_args[1]
    # The last hook passed to train is excluded from the comparison.
    passed_hooks = list(train_kwargs['hooks'][:-1])
    self.assertEqual(list(train_spec.hooks), passed_hooks)
    self.assertEqual(train_spec.input_fn, train_kwargs['input_fn'])
    self.assertEqual(train_spec.max_steps, train_kwargs['max_steps'])
def test_std_server(self, mock_server):
    """`run_ps` builds a server from the run config, starts and joins it.

    Args:
      mock_server: injected mock standing in for the server constructor.
    """
    server_instance = test.mock.Mock()
    mock_server.return_value = server_instance

    ps_config = _create_run_config_with_cluster_spec(_TF_CONFIG_FOR_PS)
    est = test.mock.Mock(spec=estimator_lib.Estimator)
    est.config = ps_config
    train_spec = test.mock.Mock(spec=training.TrainSpec)
    eval_spec = test.mock.Mock(spec=training.EvalSpec)

    training._TrainingExecutor(est, train_spec, eval_spec).run_ps()

    # The server is created un-started, then started and joined.
    mock_server.assert_called_with(
        ps_config.cluster_spec,
        job_name=ps_config.task_type,
        task_index=ps_config.task_id,
        config=test.mock.ANY,
        start=False)
    self.assertTrue(server_instance.start.called)
    self.assertTrue(server_instance.join.called)
def test_std_server(self, mock_server):
    """`run_ps` builds a server from the run config, starts and joins it.

    Args:
      mock_server: injected mock standing in for the server constructor.
    """
    server_instance = test.mock.Mock()
    mock_server.return_value = server_instance

    ps_config = _create_run_config_with_cluster_spec(_TF_CONFIG_FOR_PS)
    est = test.mock.Mock(spec=estimator_lib.Estimator)
    est.config = ps_config
    train_spec = test.mock.Mock(spec=training.TrainSpec)
    eval_spec = test.mock.Mock(spec=training.EvalSpec)

    executor = training._TrainingExecutor(est, train_spec, eval_spec)
    executor.run_ps()

    # The server is created un-started, then started and joined.
    expected_server_kwargs = {
        'job_name': ps_config.task_type,
        'task_index': ps_config.task_id,
        'config': test.mock.ANY,
        'start': False,
    }
    mock_server.assert_called_with(
        ps_config.cluster_spec, **expected_server_kwargs)
    self.assertTrue(server_instance.start.called)
    self.assertTrue(server_instance.join.called)
def test_evaluate_multiple_times(self, mock_latest_ckpt):
    """`run_evaluator` evaluates twice for two distinct checkpoints.

    Args:
      mock_latest_ckpt: injected mock whose successive return values are
        the checkpoint paths seen by the evaluator.
    """
    max_step = 200
    est = test.mock.Mock(spec=estimator_lib.Estimator)
    # Only the second evaluation result reports the maximum step.
    est.evaluate.side_effect = [
        {_GLOBAL_STEP_KEY: step} for step in (max_step // 2, max_step)
    ]
    mock_latest_ckpt.side_effect = ['path_1', 'path_2']

    train_spec = test.mock.Mock(spec=training.TrainSpec)
    train_spec.max_steps = max_step

    eval_spec = training.EvalSpec(
        input_fn=lambda: 1, delay_secs=0, throttle_secs=0)

    training._TrainingExecutor(est, train_spec, eval_spec).run_evaluator()

    self.assertEqual(2, est.evaluate.call_count)
def test_run_master_triggers_evaluate(self, _):
    """`run_master` evaluates and exports once `after_save` fires."""

    def fake_train(saving_listeners, *args, **kwargs):
        # A saving listener must be supplied; this stand-in for
        # Estimator.train calls `after_save` on the first one.
        del args, kwargs
        first_listener = saving_listeners[0]
        first_listener.after_save(session=None, global_step_value=None)

    est = test.mock.Mock(
        spec=estimator_lib.Estimator, model_dir='path/', train=fake_train)
    est.latest_checkpoint.return_value = 'checkpoint_path/'
    est.config = self._run_config

    def record_export(estimator, *args, **kwargs):
        # Leave a marker on whichever estimator the export was invoked with.
        del args, kwargs
        estimator.export_was_called = True

    fake_exporter = test.mock.PropertyMock(spec=exporter_lib.Exporter)
    fake_exporter.name = 'see_whether_export_is_called'
    fake_exporter.export = record_export

    train_spec = training.TrainSpec(input_fn=lambda: 1, max_steps=300)
    eval_spec = training.EvalSpec(
        input_fn=lambda: 1, steps=2, exporters=fake_exporter)
    # Evaluation reports a global step equal to max_steps.
    est.evaluate.return_value = {_GLOBAL_STEP_KEY: train_spec.max_steps}

    training._TrainingExecutor(est, train_spec, eval_spec).run_master()

    expected_eval_kwargs = {
        'name': eval_spec.name,
        'input_fn': eval_spec.input_fn,
        'steps': eval_spec.steps,
        'checkpoint_path': 'checkpoint_path/',
        'hooks': eval_spec.hooks,
    }
    est.evaluate.assert_called_with(**expected_eval_kwargs)
    self.assertTrue(est.export_was_called)
def test_runs_in_a_loop_until_max_steps(self):
    """`run_local` trains and evaluates three times in this setup."""
    est = test.mock.Mock(spec=estimator_lib.Estimator)
    train_spec = training.TrainSpec(
        input_fn=lambda: 1, max_steps=300, hooks=[_FakeHook()])
    eval_spec = training.EvalSpec(
        input_fn=lambda: 1, hooks=[_FakeHook()], throttle_secs=100)

    # Three evaluation results; only the last one reports max_steps.
    steps_remaining = (100, 50, 0)
    est.evaluate.side_effect = [
        {_GLOBAL_STEP_KEY: train_spec.max_steps - remaining}
        for remaining in steps_remaining
    ]

    executor = training._TrainingExecutor(est, train_spec, eval_spec)
    executor.run_local()

    self.assertEqual(3, est.train.call_count)
    self.assertEqual(3, est.evaluate.call_count)