コード例 #1
0
ファイル: util_test.py プロジェクト: tirkarthi/estimator
    def test_parse_input_fn_result_invalid(self, dataset_class):
        """Check that a dataset built from a 3-tuple is rejected.

        Args:
          dataset_class: Class providing `from_tensor_slices`; presumably
            supplied by a parameterized-test decorator outside this view.
        """
        def _input_fn():
            features = np.expand_dims(np.arange(100), 0)
            labels = np.expand_dims(np.arange(100, 200), 0)
            # Three components rather than a (features, labels) pair.
            return dataset_class.from_tensor_slices((features, labels, labels))

        # `assertRaisesRegexp` is a deprecated unittest alias that was removed
        # in Python 3.12; `assertRaisesRegex` is the supported spelling.
        with self.assertRaisesRegex(ValueError, 'input_fn should return'):
            util.parse_input_fn_result(_input_fn())
コード例 #2
0
def build_supervised_input_receiver_fn_from_input_fn(input_fn, **input_fn_args):
  """Build a SupervisedInputReceiver factory that mirrors an input_fn.

  The input_fn is invoked inside a temporary graph purely to discover the
  features and labels it produces. The placeholders matching those features
  and labels are created afterwards, in the default graph.

  Args:
    input_fn: An Estimator input_fn, i.e. a function returning one of:

      * A 'tf.data.Dataset' object: the `Dataset` must yield a tuple
          (features, labels) subject to the constraints listed below.
      * A tuple (features, labels): `features` is a `Tensor` or a dictionary
        mapping string feature names to `Tensor`s, and `labels` is a `Tensor`
        or a dictionary mapping string label names to `Tensor`s. Both are
        consumed by `model_fn` and should satisfy what `model_fn` expects of
        its inputs.

    **input_fn_args: keyword arguments forwarded to the input_fn. They are
      neither checked nor validated here, so any error raised by the input_fn
      propagates to the caller.

  Returns:
    A function taking no arguments that, when called, returns a
    SupervisedInputReceiver. It can be used as an entry of the
    input_receiver_map when exporting SavedModels from an Estimator with
    multiple modes.
  """
  # A throwaway graph keeps whatever input_fn creates out of the default
  # namespace.
  scratch_graph = ops.Graph()
  with scratch_graph.as_default():
    parsed = util.parse_input_fn_result(input_fn(**input_fn_args))
    features, labels, _ = parsed
  # Outside the `with`, so placeholders land in the default graph again.
  return build_raw_supervised_input_receiver_fn(features, labels)
コード例 #3
0
ファイル: util_test.py プロジェクト: tirkarthi/estimator
    def test_parse_input_fn_result_features_only(self):
        """Check parsing of an input_fn result that is a single tensor.

        The features must evaluate to the original values, labels must be
        absent, and no hooks must be returned.
        """
        def _input_fn():
            return tf.constant(np.arange(100))

        features, labels, hooks = util.parse_input_fn_result(_input_fn())

        with self.cached_session() as sess:
            vals = sess.run([features])

        self.assertAllEqual(vals[0], np.arange(100))
        # Identity check: `assertEqual(labels, None)` goes through `==`, which
        # a tensor-like object may overload.
        self.assertIsNone(labels)
        self.assertEqual(hooks, [])
コード例 #4
0
ファイル: util_test.py プロジェクト: tirkarthi/estimator
    def test_parse_input_fn_result_features_only_dataset(self, dataset_class):
        """Check parsing of a dataset that carries features but no labels.

        Args:
          dataset_class: Class providing `from_tensor_slices`; presumably
            supplied by a parameterized-test decorator outside this view.
        """
        def _input_fn():
            features = np.expand_dims(np.arange(100), 0)
            return dataset_class.from_tensor_slices(features)

        features, labels, hooks = util.parse_input_fn_result(_input_fn())

        # The returned hooks are handed to the session so they run before the
        # features are fetched.
        with tf.compat.v1.train.MonitoredSession(hooks=hooks) as sess:
            vals = sess.run([features])

        self.assertAllEqual(vals[0], np.arange(100))
        # Identity check: `assertEqual(labels, None)` goes through `==`, which
        # a tensor-like object may overload.
        self.assertIsNone(labels)
        self.assertIsInstance(hooks[0], util._DatasetInitializerHook)
コード例 #5
0
ファイル: util_test.py プロジェクト: tirkarthi/estimator
    def test_parse_input_fn_result_tuple(self):
        """Check parsing of an input_fn result given as (features, labels)."""
        def _make_pair():
            feature_tensor = tf.constant(np.arange(100))
            label_tensor = tf.constant(np.arange(100, 200))
            return feature_tensor, label_tensor

        parsed_features, parsed_labels, parsed_hooks = (
            util.parse_input_fn_result(_make_pair()))

        with self.cached_session() as session:
            feature_vals, label_vals = session.run(
                [parsed_features, parsed_labels])

        # Both tensors evaluate to their source ranges, and no hooks are
        # returned for this input.
        self.assertAllEqual(feature_vals, np.arange(100))
        self.assertAllEqual(label_vals, np.arange(100, 200))
        self.assertEqual(parsed_hooks, [])
コード例 #6
0
  def test_parse_input_fn_result_dataset(self):
    """Check parsing of an input_fn result that is a (features, labels) Dataset."""
    def _make_dataset():
      # A leading axis of size one is added to each range before slicing.
      feature_row = np.arange(100)[np.newaxis, :]
      label_row = np.arange(100, 200)[np.newaxis, :]
      return dataset_ops.Dataset.from_tensor_slices((feature_row, label_row))

    parsed_features, parsed_labels, init_hooks = util.parse_input_fn_result(
        _make_dataset())

    # The returned hooks are handed to the session before values are fetched.
    with training.MonitoredSession(hooks=init_hooks) as session:
      feature_vals, label_vals = session.run([parsed_features, parsed_labels])

    self.assertAllEqual(feature_vals, np.arange(100))
    self.assertAllEqual(label_vals, np.arange(100, 200))
    first_hook = init_hooks[0]
    self.assertIsInstance(first_hook, util._DatasetInitializerHook)
コード例 #7
0
    def test_parse_input_fn_result_mimic_dataset(self):
        """Check that a duck-typed dataset stand-in is parsed like a dataset."""
        class MimicIterator(object):
            # Stand-in iterator exposing only `initializer` and `get_next`.
            @property
            def initializer(self):
                return {}

            def get_next(self):
                feature_tensor = ops.convert_to_tensor(np.arange(100))
                label_tensor = ops.convert_to_tensor(np.arange(100, 200))
                return feature_tensor, label_tensor

        class MimicDataset(object):
            # Stand-in dataset exposing only `make_initializable_iterator`.
            def make_initializable_iterator(self):
                return MimicIterator()

        parsed_features, parsed_labels, init_hooks = (
            util.parse_input_fn_result(MimicDataset()))

        with training.MonitoredSession(hooks=init_hooks) as session:
            feature_vals, label_vals = session.run(
                [parsed_features, parsed_labels])

        self.assertAllEqual(feature_vals, np.arange(100))
        self.assertAllEqual(label_vals, np.arange(100, 200))
        first_hook = init_hooks[0]
        self.assertIsInstance(first_hook, util._DatasetInitializerHook)