コード例 #1
0
    def test_dnn_classifier(self):
        """In-memory evaluation summaries agree with a final `evaluate()`.

        Trains a DNNClassifier with an `InMemoryEvaluatorHook` named
        'in-memory' and checks that the hook's eval directory was created.
        It then evaluates the trained model once more and compares every
        metric except the global step with the summary value recorded in that
        directory for the same global step.
        """
        embedding = feature_column_lib.embedding_column(
            feature_column_lib.categorical_column_with_vocabulary_list(
                'wire_cast', ['kima', 'omar', 'stringer']), 8)
        dnn = estimator_lib.DNNClassifier(feature_columns=[embedding],
                                          hidden_units=[3, 1])

        def train_input_fn():
            # One fixed (features, labels) batch, repeated 3 times.
            return dataset_ops.Dataset.from_tensors(({
                'wire_cast': [['omar'], ['kima']]
            }, [[0], [1]])).repeat(3)

        def eval_input_fn():
            # One fixed (features, labels) batch, repeated 2 times.
            return dataset_ops.Dataset.from_tensors(({
                'wire_cast': [['stringer'], ['kima']]
            }, [[0], [1]])).repeat(2)

        evaluator = hooks_lib.InMemoryEvaluatorHook(dnn,
                                                    eval_input_fn,
                                                    name='in-memory')
        dnn.train(train_input_fn, hooks=[evaluator])
        self.assertTrue(os.path.isdir(dnn.eval_dir('in-memory')))
        # NOTE(review): presumably maps global step -> {summary tag: value};
        # confirm against summary_step_keyword_to_value_mapping.
        step_keyword_to_value = summary_step_keyword_to_value_mapping(
            dnn.eval_dir('in-memory'))

        final_metrics = dnn.evaluate(eval_input_fn)
        step = final_metrics[ops.GraphKeys.GLOBAL_STEP]
        # Fail with a message, rather than erroring with a bare KeyError,
        # when the hook did not record anything at the final step.
        self.assertIn(
            step, step_keyword_to_value,
            'No in-memory evaluation summaries recorded at step %s' % step)
        in_memory_metrics = step_keyword_to_value[step]
        for summary_tag, value in final_metrics.items():
            if summary_tag == ops.GraphKeys.GLOBAL_STEP:
                # The global step is the lookup key, not a metric to compare.
                continue
            self.assertIn(summary_tag, in_memory_metrics)
            self.assertEqual(value, in_memory_metrics[summary_tag])
コード例 #2
0
 def test_creates_regular_stop_at_step_hook_for_chief(self):
     """A chief estimator gets a plain `training.StopAtStepHook`.

     No config is passed, so this relies on the estimator's default config
     being in chief mode.
     """
     columns = [feature_column_lib.numeric_column('x')]
     estimator = estimator_lib.DNNClassifier(
         feature_columns=columns, hidden_units=[3, 1])
     last_step = 300
     stop_hook = hooks_lib.make_stop_at_checkpoint_step_hook(
         estimator, last_step)
     self.assertIsInstance(stop_hook, training.StopAtStepHook)
     self.assertEqual(last_step, stop_hook._last_step)
コード例 #3
0
    def test_creates_checkpoint_hook_for_workers(self):
        """A non-chief estimator gets `hooks_lib._StopAtCheckpointStepHook`.

        The estimator's config reports `is_chief` as False. The resulting hook
        must carry the requested last step and the estimator's model_dir.
        """

        class NonChiefRunConfig(estimator_lib.RunConfig):
            """RunConfig whose `is_chief` property always reports False."""

            @property
            def is_chief(self):
                return False

        estimator = estimator_lib.DNNClassifier(
            feature_columns=[feature_column_lib.numeric_column('x')],
            hidden_units=[3, 1],
            config=NonChiefRunConfig(),
        )
        last_step = 300
        stop_hook = hooks_lib.make_stop_at_checkpoint_step_hook(
            estimator, last_step)
        self.assertIsInstance(stop_hook, hooks_lib._StopAtCheckpointStepHook)
        self.assertEqual(last_step, stop_hook._last_step)
        self.assertEqual(estimator.model_dir, stop_hook._model_dir)
コード例 #4
0
    def test_raise_error_with_ps(self):
        """`InMemoryEvaluatorHook` rejects an estimator built in a PS cluster.

        The estimator is constructed while TF_CONFIG describes a cluster with
        a chief and a parameter server. Constructing the hook for it must
        raise ValueError matching 'supports only single machine'.
        """
        tf_config = {
            'cluster': {
                run_config_lib.TaskType.CHIEF: ['host0:0'],
                run_config_lib.TaskType.PS: ['host1:1'],
            },
            'task': {
                'type': run_config_lib.TaskType.CHIEF,
                'index': 0
            }
        }
        # Only the estimator's construction happens under the patched
        # environment. The cluster setup is presumably captured into its
        # config at construction time.
        with test.mock.patch.dict('os.environ',
                                  {'TF_CONFIG': json.dumps(tf_config)}):
            dnn = estimator_lib.DNNClassifier(
                feature_columns=[feature_column_lib.numeric_column('x')],
                hidden_units=[3, 1])

        def eval_input_fn():
            # Placeholder; the test expects the hook constructor to raise.
            pass

        # assertRaisesRegex: the `assertRaisesRegexp` alias is deprecated
        # since Python 3.2 and was removed in Python 3.12.
        with self.assertRaisesRegex(ValueError,
                                    'supports only single machine'):
            hooks_lib.InMemoryEvaluatorHook(dnn, eval_input_fn)