Exemplo n.º 1
0
  def test_parse_input_fn_result_invalid(self):
    """A Dataset yielding 3-element tuples must be rejected with ValueError."""

    def _input_fn():
      features = np.expand_dims(np.arange(100), 0)
      labels = np.expand_dims(np.arange(100, 200), 0)
      # Deliberately three components instead of (features, labels).
      return dataset_ops.Dataset.from_tensor_slices((features, labels, labels))

    with self.assertRaisesRegex(ValueError, 'input_fn should return'):
      util.parse_input_fn_result(_input_fn())
Exemplo n.º 2
0
    def test_parse_input_fn_result_invalid(self):
        """A Dataset yielding 3-element tuples must be rejected with ValueError."""

        def _input_fn():
            features = np.expand_dims(np.arange(100), 0)
            labels = np.expand_dims(np.arange(100, 200), 0)
            # Deliberately three components instead of (features, labels).
            return dataset_ops.Dataset.from_tensor_slices(
                (features, labels, labels))

        with self.assertRaisesRegex(ValueError, 'input_fn should return'):
            util.parse_input_fn_result(_input_fn())
Exemplo n.º 3
0
def build_supervised_input_receiver_fn_from_input_fn(input_fn, **input_fn_args):
  """Build a SupervisedInputReceiver fn whose inputs mirror an input_fn.

  The input_fn is invoked inside a fresh, temporary graph so that the ops it
  creates do not end up in the caller's default graph. Its result is parsed
  into features and labels. Those are then handed to
  `build_raw_supervised_input_receiver_fn`, and placeholders based on them
  are created in the default graph.

  Args:
    input_fn: An Estimator input_fn, i.e. a function that returns one of:

      * A `tf.data.Dataset` whose outputs must be a `(features, labels)`
          tuple obeying the same constraints as the tuple form below.
      * A `(features, labels)` tuple, where `features` is a `Tensor` or a
        dict mapping string feature names to `Tensor`s, and `labels` is a
        `Tensor` or a dict mapping string label names to `Tensor`s. Both are
        consumed by `model_fn` and should match what it expects as inputs.

    **input_fn_args: Keyword arguments forwarded as-is to `input_fn`. They
      are not checked or validated here; any error `input_fn` raises
      propagates to the caller.

  Returns:
    A zero-argument function that, when called, returns a
    SupervisedInputReceiver. It can be supplied as part of the
    input_receiver_map when exporting SavedModels from an Estimator with
    multiple modes.
  """
  # A scratch graph keeps input_fn's ops out of the caller's default graph.
  scratch_graph = ops.Graph()
  with scratch_graph.as_default():
    # The third parsed element (hooks) is not needed here.
    features, labels, _ = util.parse_input_fn_result(
        input_fn(**input_fn_args))
  # Outside the scratch graph again: placeholders go into the default graph.
  return build_raw_supervised_input_receiver_fn(features, labels)
Exemplo n.º 4
0
def build_supervised_input_receiver_fn_from_input_fn(input_fn,
                                                     **input_fn_args):
    """Get a function that returns a SupervisedInputReceiver matching an input_fn.

    Note that this function calls the input_fn in a local graph in order to
    extract features and labels. Placeholders are then created from those
    features and labels in the default graph.

    Args:
        input_fn: An Estimator input_fn, which is a function that returns one
            of:

            * A `tf.data.Dataset` object: Outputs of the `Dataset` object must
              be a tuple (features, labels) with the same constraints as below.
            * A tuple (features, labels): Where `features` is a `Tensor` or a
              dictionary of string feature name to `Tensor` and `labels` is a
              `Tensor` or a dictionary of string label name to `Tensor`. Both
              `features` and `labels` are consumed by `model_fn`. They should
              satisfy the expectation of `model_fn` from inputs.

        **input_fn_args: Set of kwargs to be passed to the input_fn. Note that
            these will not be checked or validated here, and any errors raised
            by the input_fn will be thrown to the top.

    Returns:
        A function taking no arguments that, when called, returns a
        SupervisedInputReceiver. This function can be passed in as part of the
        input_receiver_map when exporting SavedModels from Estimator with
        multiple modes.
    """
    # Wrap the input_fn call in a fresh graph so its ops do not pollute the
    # caller's default graph.
    with ops.Graph().as_default():
        result = input_fn(**input_fn_args)
        # Only features and labels are needed; the third element (hooks) is
        # discarded.
        features, labels, _ = util.parse_input_fn_result(result)
    # Placeholders are created back in the default graph.
    return build_raw_supervised_input_receiver_fn(features, labels)
Exemplo n.º 5
0
  def test_parse_input_fn_result_features_only(self):
    """A bare Tensor is parsed as features, with labels None and no hooks."""

    def _input_fn():
      return constant_op.constant(np.arange(100))

    features, labels, hooks = util.parse_input_fn_result(_input_fn())

    with self.cached_session() as sess:
      vals = sess.run([features])

    self.assertAllEqual(vals[0], np.arange(100))
    # Identity check against None rather than equality comparison.
    self.assertIsNone(labels)
    self.assertEqual(hooks, [])
Exemplo n.º 6
0
    def test_parse_input_fn_result_features_only(self):
        """A bare Tensor is parsed as features, with labels None and no hooks."""

        def _input_fn():
            return constant_op.constant(np.arange(100))

        features, labels, hooks = util.parse_input_fn_result(_input_fn())

        # cached_session() replaces the deprecated test_session().
        with self.cached_session() as sess:
            vals = sess.run([features])

        self.assertAllEqual(vals[0], np.arange(100))
        self.assertIsNone(labels)
        self.assertEqual(hooks, [])
Exemplo n.º 7
0
  def test_parse_input_fn_result_features_only_dataset(self):
    """A features-only Dataset gives features, labels None, and a dataset hook."""

    def _input_fn():
      features = np.expand_dims(np.arange(100), 0)
      return dataset_ops.Dataset.from_tensor_slices(features)

    features, labels, hooks = util.parse_input_fn_result(_input_fn())

    # The returned hooks are handed to the session; presumably the
    # _DatasetInitializerHook must run before `features` can be evaluated.
    with training.MonitoredSession(hooks=hooks) as sess:
      vals = sess.run([features])

    self.assertAllEqual(vals[0], np.arange(100))
    self.assertIsNone(labels)
    self.assertIsInstance(hooks[0], util._DatasetInitializerHook)
Exemplo n.º 8
0
    def test_parse_input_fn_result_features_only_dataset(self):
        """A features-only Dataset gives features, labels None, and a hook."""

        def _input_fn():
            features = np.expand_dims(np.arange(100), 0)
            return dataset_ops.Dataset.from_tensor_slices(features)

        features, labels, hooks = util.parse_input_fn_result(_input_fn())

        # The returned hooks are handed to the session; presumably the
        # _DatasetInitializerHook must run before `features` is evaluated.
        with training.MonitoredSession(hooks=hooks) as sess:
            vals = sess.run([features])

        self.assertAllEqual(vals[0], np.arange(100))
        self.assertIsNone(labels)
        self.assertIsInstance(hooks[0], util._DatasetInitializerHook)
Exemplo n.º 9
0
  def test_parse_input_fn_result_tuple(self):
    """A (features, labels) Tensor tuple parses with values intact, no hooks."""

    def _input_fn():
      # Features are built before labels, exactly as in a two-statement form.
      return (constant_op.constant(np.arange(100)),
              constant_op.constant(np.arange(100, 200)))

    features, labels, hooks = util.parse_input_fn_result(_input_fn())

    with self.cached_session() as sess:
      feature_vals, label_vals = sess.run([features, labels])

    self.assertAllEqual(feature_vals, np.arange(100))
    self.assertAllEqual(label_vals, np.arange(100, 200))
    self.assertEqual(hooks, [])
Exemplo n.º 10
0
    def test_parse_input_fn_result_tuple(self):
        """A (features, labels) Tensor tuple parses with values intact."""

        def _input_fn():
            features = constant_op.constant(np.arange(100))
            labels = constant_op.constant(np.arange(100, 200))
            return features, labels

        features, labels, hooks = util.parse_input_fn_result(_input_fn())

        # cached_session() replaces the deprecated test_session().
        with self.cached_session() as sess:
            vals = sess.run([features, labels])

        self.assertAllEqual(vals[0], np.arange(100))
        self.assertAllEqual(vals[1], np.arange(100, 200))
        self.assertEqual(hooks, [])