Exemplo n.º 1
0
    def test_build_raw_supervised_input_receiver_fn_raw_tensors(self):
        """Checks receiver keys when a bare feature tensor is passed in.

        A bare feature tensor is expected under the receiver key "input".
        A bare label tensor is expected under "label". Dict-valued labels
        keep their own keys ("foo", "bar").
        """
        feature_tensors = dict(
            feature_1=constant_op.constant(["hello"]),
            feature_2=constant_op.constant([42]))
        label_tensors = dict(
            foo=constant_op.constant([5]),
            bar=constant_op.constant([6]))
        single_feature = feature_tensors["feature_1"]

        fn_with_label_dict = export.build_raw_supervised_input_receiver_fn(
            single_feature, label_tensors)
        fn_with_label_tensor = export.build_raw_supervised_input_receiver_fn(
            single_feature, label_tensors["foo"])

        with ops.Graph().as_default():
            # Bare feature tensor combined with a dict of labels.
            receiver = fn_with_label_dict()
            self.assertIsInstance(receiver.features, ops.Tensor)
            self.assertEqual({"foo", "bar"}, set(receiver.labels.keys()))
            self.assertEqual({"input", "foo", "bar"},
                             set(receiver.receiver_tensors.keys()))

            # Bare feature tensor combined with a bare label tensor.
            receiver = fn_with_label_tensor()
            self.assertIsInstance(receiver.features, ops.Tensor)
            self.assertIsInstance(receiver.labels, ops.Tensor)
            self.assertEqual({"input", "label"},
                             set(receiver.receiver_tensors.keys()))
Exemplo n.º 2
0
 def test_build_raw_supervised_input_receiver_fn_overlapping_keys(self):
   """A key present in both the feature and label dicts must be rejected.

   "feature_1" appears in both dicts, so building the receiver fn is
   expected to raise ValueError.
   """
   colliding_key = "feature_1"
   feature_dict = {
       colliding_key: constant_op.constant(["hello"]),
       "feature_2": constant_op.constant([42]),
   }
   label_dict = {
       colliding_key: constant_op.constant([5]),
       "bar": constant_op.constant([6]),
   }
   # Callable form of assertRaises: the builder is invoked with the two dicts.
   self.assertRaises(
       ValueError,
       export.build_raw_supervised_input_receiver_fn,
       feature_dict,
       label_dict)
Exemplo n.º 3
0
 def test_build_raw_supervised_input_receiver_fn(self):
   """Dict inputs yield one receiver tensor per feature key and label key.

   Also checks that the feature receiver tensors keep the dtypes of the
   example tensors: string for "feature_1" and int32 for "feature_2".
   """
   feature_tensors = dict(
       feature_1=constant_op.constant(["hello"]),
       feature_2=constant_op.constant([42]))
   label_tensors = dict(
       foo=constant_op.constant([5]),
       bar=constant_op.constant([6]))
   receiver_fn = export.build_raw_supervised_input_receiver_fn(
       feature_tensors, label_tensors)

   with ops.Graph().as_default():
     receiver = receiver_fn()
     self.assertEqual({"feature_1", "feature_2"},
                      set(receiver.features.keys()))
     self.assertEqual({"foo", "bar"}, set(receiver.labels.keys()))
     self.assertEqual({"feature_1", "feature_2", "foo", "bar"},
                      set(receiver.receiver_tensors.keys()))

     # Table-driven dtype checks, performed in the original order.
     expected_dtypes = (
         ("feature_1", dtypes.string),
         ("feature_2", dtypes.int32),
     )
     for key, expected_dtype in expected_dtypes:
       self.assertEqual(expected_dtype, receiver.receiver_tensors[key].dtype)
Exemplo n.º 4
0
 def test_build_raw_supervised_input_receiver_fn_batch_size(self):
   """Checks the shapes produced when default_batch_size is given.

   With default_batch_size=10, the "feature_1" receiver tensor and the
   "feature_1" feature are both expected to have shape [10].
   """
   hello_tensor = constant_op.constant(["hello"])
   answer_tensor = constant_op.constant([42])
   feature_dict = dict(feature_1=hello_tensor, feature_2=answer_tensor)
   label_dict = dict(foo=constant_op.constant([5]),
                     bar=constant_op.constant([6]))

   receiver_fn = export.build_raw_supervised_input_receiver_fn(
       feature_dict, label_dict, default_batch_size=10)

   with ops.Graph().as_default():
     receiver = receiver_fn()
     expected_shape = [10]
     self.assertEqual(expected_shape,
                      receiver.receiver_tensors["feature_1"].shape)
     self.assertEqual(expected_shape, receiver.features["feature_1"].shape)
Exemplo n.º 5
0
def dummy_supervised_receiver_fn():
    """Return a supervised input receiver fn built from fixed placeholders.

    The example feature 'x' is an int64 placeholder of shape (2, 1) named
    'feature_x'. The example label is a float32 placeholder of shape [2, 1]
    named 'truth'. Both are handed to
    ``export.build_raw_supervised_input_receiver_fn``, and its result is
    returned.
    """
    # Create the feature placeholder before the label placeholder.
    example_x = array_ops.placeholder(
        dtype=dtypes.int64, shape=(2, 1), name='feature_x')
    example_label = array_ops.placeholder(
        dtype=dtypes.float32, shape=[2, 1], name='truth')
    return export.build_raw_supervised_input_receiver_fn(
        {'x': example_x}, example_label)