예제 #1
0
  def test_build_supervised_input_receiver_fn_from_input_fn(self):
    """Receiver fn built from an input_fn that returns a single label tensor.

    `dummy_input_fn` returns a features dict keyed "x" and "y" together with a
    bare (non-dict) label tensor. The receiver produced by the built fn is
    expected to expose exactly those feature keys, a `Tensor` as its labels,
    and receiver tensors covering both feature keys plus the label under the
    key "label".
    """
    def dummy_input_fn():
      return ({"x": constant_op.constant([[1], [1]]),
               "y": constant_op.constant(["hello", "goodbye"])},
              constant_op.constant([[1], [1]]))

    input_receiver_fn = export.build_supervised_input_receiver_fn_from_input_fn(
        dummy_input_fn)

    # Call the receiver fn inside a fresh graph; presumably this keeps the
    # tensors it creates out of the default graph shared with other tests.
    with ops.Graph().as_default():
      input_receiver = input_receiver_fn()
      self.assertEqual({"x", "y"}, set(input_receiver.features))
      # A non-dict label is expected to come back as a single Tensor ...
      self.assertIsInstance(input_receiver.labels, ops.Tensor)
      # ... and to be listed in receiver_tensors under the key "label".
      self.assertEqual({"x", "y", "label"},
                       set(input_receiver.receiver_tensors))
예제 #2
0
  def test_build_supervised_input_receiver_fn_from_input_fn(self):
    """Receiver fn built from an input_fn that returns a single label tensor.

    `dummy_input_fn` returns a features dict keyed "x" and "y" together with a
    bare (non-dict) label tensor. The receiver produced by the built fn is
    expected to expose exactly those feature keys, a `Tensor` as its labels,
    and receiver tensors covering both feature keys plus the label under the
    key "label".
    """
    def dummy_input_fn():
      return ({"x": constant_op.constant([[1], [1]]),
               "y": constant_op.constant(["hello", "goodbye"])},
              constant_op.constant([[1], [1]]))

    input_receiver_fn = export.build_supervised_input_receiver_fn_from_input_fn(
        dummy_input_fn)

    # Call the receiver fn inside a fresh graph; presumably this keeps the
    # tensors it creates out of the default graph shared with other tests.
    with ops.Graph().as_default():
      input_receiver = input_receiver_fn()
      self.assertEqual({"x", "y"}, set(input_receiver.features))
      # A non-dict label is expected to come back as a single Tensor ...
      self.assertIsInstance(input_receiver.labels, ops.Tensor)
      # ... and to be listed in receiver_tensors under the key "label".
      self.assertEqual({"x", "y", "label"},
                       set(input_receiver.receiver_tensors))
예제 #3
0
  def test_build_supervised_input_receiver_fn_from_input_fn_args(self):
    """Extra kwargs to the builder reach the input_fn; dict labels keep keys.

    `dummy_input_fn` names its first feature after its `feature_key` argument
    (default "x") and returns its label as a dict keyed "my_label". The builder
    is given `feature_key="z"`, so the test expects that kwarg to be forwarded
    to `dummy_input_fn`: features keyed "z" and "y", labels keyed "my_label",
    and receiver tensors covering all three keys.
    """
    def dummy_input_fn(feature_key="x"):
      return ({feature_key: constant_op.constant([[1], [1]]),
               "y": constant_op.constant(["hello", "goodbye"])},
              {"my_label": constant_op.constant([[1], [1]])})

    input_receiver_fn = export.build_supervised_input_receiver_fn_from_input_fn(
        dummy_input_fn, feature_key="z")

    # Call the receiver fn inside a fresh graph; presumably this keeps the
    # tensors it creates out of the default graph shared with other tests.
    with ops.Graph().as_default():
      input_receiver = input_receiver_fn()
      # "z" (not the default "x") proves feature_key was forwarded.
      self.assertEqual({"z", "y"}, set(input_receiver.features))
      # Dict labels are expected to keep their original key.
      self.assertEqual({"my_label"}, set(input_receiver.labels))
      self.assertEqual({"z", "y", "my_label"},
                       set(input_receiver.receiver_tensors))
예제 #4
0
  def test_build_supervised_input_receiver_fn_from_input_fn_args(self):
    """Extra kwargs to the builder reach the input_fn; dict labels keep keys.

    `dummy_input_fn` names its first feature after its `feature_key` argument
    (default "x") and returns its label as a dict keyed "my_label". The builder
    is given `feature_key="z"`, so the test expects that kwarg to be forwarded
    to `dummy_input_fn`: features keyed "z" and "y", labels keyed "my_label",
    and receiver tensors covering all three keys.
    """
    def dummy_input_fn(feature_key="x"):
      return ({feature_key: constant_op.constant([[1], [1]]),
               "y": constant_op.constant(["hello", "goodbye"])},
              {"my_label": constant_op.constant([[1], [1]])})

    input_receiver_fn = export.build_supervised_input_receiver_fn_from_input_fn(
        dummy_input_fn, feature_key="z")

    # Call the receiver fn inside a fresh graph; presumably this keeps the
    # tensors it creates out of the default graph shared with other tests.
    with ops.Graph().as_default():
      input_receiver = input_receiver_fn()
      # "z" (not the default "x") proves feature_key was forwarded.
      self.assertEqual({"z", "y"}, set(input_receiver.features))
      # Dict labels are expected to keep their original key.
      self.assertEqual({"my_label"}, set(input_receiver.labels))
      self.assertEqual({"z", "y", "my_label"},
                       set(input_receiver.receiver_tensors))