def test_continuous_train_and_eval_with_invalid_predicate_fn(self):
    """A non-callable `continuous_eval_predicate_fn` must be rejected.

    For every estimator from `_estimators_for_tests`, passing the string 'fn'
    as the predicate to `continuous_train_and_evaluate` is expected to raise
    a ValueError whose message matches the pattern below.
    """
    expected_message = '`continuous_eval_predicate_fn` must be a callable'
    for estimator in self._estimators_for_tests():
        experiment = Experiment(estimator,
                                train_input_fn='train_input',
                                eval_input_fn='eval_input')
        with self.assertRaisesRegexp(ValueError, expected_message):
            experiment.continuous_train_and_evaluate(
                continuous_eval_predicate_fn='fn')
Example #2
0
 def test_continuous_train_and_eval_with_invalid_predicate_fn(self):
     """Rejects a `continuous_eval_predicate_fn` that is not callable.

     Each test estimator is wrapped in an Experiment. Handing the string 'fn'
     to `continuous_train_and_evaluate` as the predicate must raise
     ValueError with a matching message.
     """
     not_callable = 'fn'
     for estimator in self._estimators_for_tests():
         experiment = Experiment(
             estimator, train_input_fn='train_input', eval_input_fn='eval_input')
         with self.assertRaisesRegexp(
                 ValueError, '`continuous_eval_predicate_fn` must be a callable'):
             experiment.continuous_train_and_evaluate(
                 continuous_eval_predicate_fn=not_callable)
 def test_continuous_train_and_eval(self):
     """Verifies a single train/evaluate/export round with eval hooks.

     Each test estimator is built with `eval_dict={'global_step': 100}`
     (presumably the metrics its evaluation reports — confirm against
     `_estimators_for_tests`). It is then wrapped in an Experiment with
     `train_steps=100`. After `continuous_train_and_evaluate`, fit, eval and
     export must each have run exactly once, and the eval hook must have
     reached the estimator.
     """
     estimators = self._estimators_for_tests(eval_dict={'global_step': 100})
     for estimator in estimators:
         hook = _NoopHook()
         strategy = make_export_strategy(estimator, None, exports_to_keep=None)
         experiment = Experiment(estimator,
                                 train_input_fn='train_input',
                                 eval_input_fn='eval_input',
                                 eval_hooks=[hook],
                                 train_steps=100,
                                 eval_steps=100,
                                 export_strategies=strategy)
         experiment.continuous_train_and_evaluate()
         self.assertEqual(1, estimator.fit_count)
         self.assertEqual(1, estimator.eval_count)
         self.assertEqual(1, estimator.export_count)
         # Compare against a fresh list so the check does not depend on list identity.
         self.assertEqual([hook], estimator.eval_hooks)
Example #4
0
 def test_continuous_train_and_eval(self):
     """Checks that one continuous run fits, evaluates and exports once each.

     `_estimators_for_tests` is given `eval_dict={'global_step': 100}`, and
     the Experiment is configured with `train_steps=100`. The eval hook
     supplied to the Experiment is expected to be forwarded to the estimator.
     """
     eval_dict = {'global_step': 100}
     for estimator in self._estimators_for_tests(eval_dict=eval_dict):
         noop = _NoopHook()
         exporter = make_export_strategy(estimator, None, exports_to_keep=None)
         experiment = Experiment(
             estimator,
             train_input_fn='train_input',
             eval_input_fn='eval_input',
             eval_hooks=[noop],
             train_steps=100,
             eval_steps=100,
             export_strategies=exporter)

         experiment.continuous_train_and_evaluate()

         self.assertEqual(1, estimator.fit_count)
         self.assertEqual(1, estimator.eval_count)
         self.assertEqual(1, estimator.export_count)
         self.assertEqual([noop], estimator.eval_hooks)
    def test_continuous_train_and_eval_with_adapted_steps_per_iteration(self):
        """Derives the per-iteration step count from `train_steps`.

        With `train_steps_per_iteration=None` and a very large `train_steps`,
        the single `train` call on the mocked estimator is expected to request
        one tenth of `train_steps`, with `max_steps=None` and no hooks.
        """
        estimator_mock = test.mock.Mock(Estimator)
        model_dir_property = test.mock.PropertyMock(return_value='test_dir')
        type(estimator_mock).model_dir = model_dir_property

        total_steps = 100000000000000
        experiment = Experiment(
            estimator_mock,
            train_input_fn='train_input',
            eval_input_fn='eval_input',
            train_steps=total_steps,
            train_steps_per_iteration=None)

        def only_first_call(eval_result):
            # True only while no eval result is available, which is intended
            # to allow exactly one invocation.
            return eval_result is None

        experiment.continuous_train_and_evaluate(
            continuous_eval_predicate_fn=only_first_call)

        expected_steps = int(total_steps / 10)
        estimator_mock.train.assert_called_once_with(
            input_fn='train_input',
            steps=expected_steps,
            max_steps=None,
            hooks=[])
    def test_continuous_train_and_eval_with_predicate_fn(self):
        """A predicate that always returns False prevents any training.

        `train_steps` is set so high that it alone would not end the loop.
        After the run, no fit or evaluation may have happened, and the export
        count is asserted to be 1.
        """
        for estimator in self._estimators_for_tests(
                eval_dict={'global_step': 100}):
            strategy = make_export_strategy(estimator, None, exports_to_keep=None)
            experiment = Experiment(
                estimator,
                train_input_fn='train_input',
                eval_input_fn='eval_input',
                # Large enough that `train_steps` alone never stops the loop.
                train_steps=100000000000,
                eval_steps=100,
                export_strategies=strategy)

            def never_continue(eval_result):
                del eval_result  # Unused; required by the predicate signature.
                return False

            experiment.continuous_train_and_evaluate(
                continuous_eval_predicate_fn=never_continue)
            self.assertEqual(0, estimator.fit_count)
            self.assertEqual(0, estimator.eval_count)
            self.assertEqual(1, estimator.export_count)
Example #7
0
    def test_continuous_train_and_eval_with_default_steps_per_iteration(self):
        """Expects 1000 steps per iteration when no step counts are configured.

        With both `train_steps` and `train_steps_per_iteration` set to None,
        the single `train` call on the mocked estimator must request 1000
        steps. `max_steps` and `hooks` are not checked.
        """
        estimator_mock = test.mock.Mock(Estimator)
        type(estimator_mock).model_dir = test.mock.PropertyMock(
            return_value='test_dir')

        experiment = Experiment(
            estimator_mock,
            train_input_fn='train_input',
            eval_input_fn='eval_input',
            train_steps_per_iteration=None,
            train_steps=None)

        def first_round_only(eval_result):
            # True only while no eval result exists yet, i.e. the first invocation.
            return eval_result is None

        experiment.continuous_train_and_evaluate(
            continuous_eval_predicate_fn=first_round_only)

        any_value = test.mock.ANY
        estimator_mock.train.assert_called_once_with(
            input_fn='train_input',
            steps=1000,
            max_steps=any_value,
            hooks=any_value)
Example #8
0
    def test_continuous_train_and_eval_with_predicate_fn(self):
        """Stops before any training when the predicate rejects the first round.

        The predicate always returns False. `train_steps` is large enough that
        it alone would never end the loop. Fit and eval counts must stay at 0,
        and the export count is asserted to be 1.
        """
        def always_false(eval_result):
            del eval_result  # Unused; kept only to match the predicate signature.
            return False

        # Large enough that training alone would never finish.
        huge_train_steps = 100000000000
        estimators = self._estimators_for_tests(eval_dict={'global_step': 100})
        for estimator in estimators:
            experiment = Experiment(
                estimator,
                train_input_fn='train_input',
                eval_input_fn='eval_input',
                train_steps=huge_train_steps,
                eval_steps=100,
                export_strategies=make_export_strategy(
                    estimator, None, exports_to_keep=None))

            experiment.continuous_train_and_evaluate(
                continuous_eval_predicate_fn=always_false)

            self.assertEqual(0, estimator.fit_count)
            self.assertEqual(0, estimator.eval_count)
            self.assertEqual(1, estimator.export_count)