Exemplo n.º 1
0
    def test_dnn_classifier(self):
        """In-memory eval summaries must agree with a final `evaluate()` call.

        Trains a DNNClassifier with an `InMemoryEvaluatorHook` (named
        'in-memory') attached. Then checks two things:

        * the hook's eval directory exists;
        * every metric returned by a fresh `evaluate` on the same eval input
          equals the summary value logged there at the same global step.
        """
        categorical = (
            feature_column_lib.categorical_column_with_vocabulary_list(
                'wire_cast', ['kima', 'omar', 'stringer']))
        embedded = feature_column_lib.embedding_column(categorical, 8)
        classifier = estimator_lib.DNNClassifier(
            feature_columns=[embedded], hidden_units=[3, 1])

        def train_input_fn():
            features = {'wire_cast': [['omar'], ['kima']]}
            labels = [[0], [1]]
            return dataset_ops.Dataset.from_tensors(
                (features, labels)).repeat(3)

        def eval_input_fn():
            features = {'wire_cast': [['stringer'], ['kima']]}
            labels = [[0], [1]]
            return dataset_ops.Dataset.from_tensors(
                (features, labels)).repeat(2)

        in_memory_hook = hooks_lib.InMemoryEvaluatorHook(
            classifier, eval_input_fn, name='in-memory')
        classifier.train(train_input_fn, hooks=[in_memory_hook])

        eval_dir = classifier.eval_dir('in-memory')
        self.assertTrue(os.path.isdir(eval_dir))
        summaries_by_step = summary_step_keyword_to_value_mapping(eval_dir)

        final_metrics = classifier.evaluate(eval_input_fn)
        global_step_key = ops.GraphKeys.GLOBAL_STEP
        final_step = final_metrics[global_step_key]
        # The global step is the lookup key, not a metric; compare the rest.
        for tag, value in final_metrics.items():
            if tag == global_step_key:
                continue
            self.assertEqual(value, summaries_by_step[final_step][tag])
Exemplo n.º 2
0
 def test_creates_regular_stop_at_step_hook_for_chief(self):
     """On the chief, a plain `training.StopAtStepHook` is returned.

     The estimator is built without an explicit config. The factory should
     hand back the regular step-based hook carrying the requested last step.
     """
     last_step = 300
     x_column = feature_column_lib.numeric_column('x')
     # No `config=` is passed; by default an estimator is in chief mode.
     chief_estimator = estimator_lib.DNNClassifier(
         feature_columns=[x_column], hidden_units=[3, 1])
     stop_hook = hooks_lib.make_stop_at_checkpoint_step_hook(
         chief_estimator, last_step)
     self.assertIsInstance(stop_hook, training.StopAtStepHook)
     self.assertEqual(last_step, stop_hook._last_step)
Exemplo n.º 3
0
    def test_creates_checkpoint_hook_for_workers(self):
        """On a non-chief worker, a `_StopAtCheckpointStepHook` is returned.

        The returned hook should carry the requested last step and record
        the estimator's `model_dir`.
        """
        class FakeWorkerConfig(estimator_lib.RunConfig):
            # Overrides `is_chief` so the estimator looks like a non-chief
            # worker.
            @property
            def is_chief(self):
                return False

        last_step = 300
        worker_estimator = estimator_lib.DNNClassifier(
            feature_columns=[feature_column_lib.numeric_column('x')],
            hidden_units=[3, 1],
            config=FakeWorkerConfig())
        stop_hook = hooks_lib.make_stop_at_checkpoint_step_hook(
            worker_estimator, last_step)

        self.assertIsInstance(stop_hook, hooks_lib._StopAtCheckpointStepHook)
        self.assertEqual(last_step, stop_hook._last_step)
        self.assertEqual(worker_estimator.model_dir, stop_hook._model_dir)
Exemplo n.º 4
0
  def test_raise_error_with_ps(self):
    """`InMemoryEvaluatorHook` rejects an estimator configured with a PS task.

    TF_CONFIG is patched to describe a cluster with one chief and one
    parameter server. Constructing the hook for an estimator built under
    that config must raise a ValueError mentioning single-machine support.
    """
    def eval_input_fn():
      # Placeholder: the hook is expected to raise during construction.
      pass

    chief = run_config_lib.TaskType.CHIEF
    cluster = {
        chief: ['host0:0'],
        run_config_lib.TaskType.PS: ['host1:1'],
    }
    tf_config = {'cluster': cluster, 'task': {'type': chief, 'index': 0}}
    patched_env = {'TF_CONFIG': json.dumps(tf_config)}

    # The classifier is constructed inside the patch so its config sees this
    # TF_CONFIG; the hook itself is built after the patch is undone.
    with test.mock.patch.dict('os.environ', patched_env):
      classifier = estimator_lib.DNNClassifier(
          feature_columns=[feature_column_lib.numeric_column('x')],
          hidden_units=[3, 1])

    # NOTE(review): assertRaisesRegexp is a deprecated alias of
    # assertRaisesRegex (removed in Python 3.12); switch once Python 2
    # support is no longer required.
    with self.assertRaisesRegexp(ValueError, 'supports only single machine'):
      hooks_lib.InMemoryEvaluatorHook(classifier, eval_input_fn)