Example #1
0
    def test_parse_from_example_list_truncate(self):
        """Checks that list_size=1 keeps only the first example of each list.

        The longer serialized list holds two examples; with list_size=1 each
        list contributes exactly one example to the parsed features.
        """
        with tf.Graph().as_default():
            serialized_example_lists = [
                EXAMPLE_LIST_PROTO_1.SerializeToString(),
                EXAMPLE_LIST_PROTO_2.SerializeToString()
            ]
            # Truncate the number of examples per list from 2 to 1.
            features = data_lib.parse_from_example_list(
                serialized_example_lists,
                list_size=1,
                context_feature_spec=CONTEXT_FEATURE_SPEC,
                example_feature_spec=EXAMPLE_FEATURE_SPEC)

            with tf.compat.v1.Session() as sess:
                sess.run(tf.compat.v1.local_variables_initializer())
                # Keep the graph tensors (`features`) and their evaluated
                # values under distinct names instead of rebinding one name.
                feature_values = sess.run(features)
                # Test dense_shape, indices and values for a SparseTensor.
                self.assertAllEqual(feature_values["unigrams"].dense_shape,
                                    [2, 1, 1])
                self.assertAllEqual(feature_values["unigrams"].indices,
                                    [[0, 0, 0], [1, 0, 0]])
                self.assertAllEqual(feature_values["unigrams"].values,
                                    [b"tensorflow", b"gbdt"])
                # For Tensors with dense values, values can be directly checked.
                self.assertAllEqual(feature_values["query_length"], [[3], [2]])
                self.assertAllEqual(feature_values["utility"],
                                    [[[0.]], [[0.]]])
Example #2
0
    def test_parse_from_example_list_shuffle(self):
        """Checks that examples are shuffled (seed=1) before truncation.

        With shuffling enabled, truncating to list_size=1 keeps a different
        example from EXAMPLE_LIST_PROTO_1 than plain truncation would.
        """
        with tf.Graph().as_default():
            serialized_example_lists = [
                EXAMPLE_LIST_PROTO_1.SerializeToString(),
                EXAMPLE_LIST_PROTO_2.SerializeToString()
            ]
            # Shuffle the examples within each list (seed=1), then truncate
            # the number of examples per list from 2 to 1.
            features = data_lib.parse_from_example_list(
                serialized_example_lists,
                list_size=1,
                context_feature_spec=CONTEXT_FEATURE_SPEC,
                example_feature_spec=EXAMPLE_FEATURE_SPEC,
                shuffle_examples=True,
                seed=1)

            with tf.compat.v1.Session() as sess:
                sess.run(tf.compat.v1.local_variables_initializer())
                # Keep the graph tensors (`features`) and their evaluated
                # values under distinct names instead of rebinding one name.
                feature_values = sess.run(features)
                # With `shuffle_examples` and seed=1, the example `tensorflow` and the
                # example `learning to rank` in EXAMPLE_LIST_PROTO_1 switch order. After
                # truncation at list_size=1, only `learning to rank` in
                # EXAMPLE_LIST_PROTO_1 and `gbdt` in EXAMPLE_LIST_PROTO_2 are left in
                # serialized features.
                # Test dense_shape, indices and values for a SparseTensor.
                self.assertAllEqual(feature_values["unigrams"].dense_shape,
                                    [2, 1, 3])
                self.assertAllEqual(
                    feature_values["unigrams"].indices,
                    [[0, 0, 0], [0, 0, 1], [0, 0, 2], [1, 0, 0]])
                self.assertAllEqual(feature_values["unigrams"].values,
                                    [b"learning", b"to", b"rank", b"gbdt"])
                # For Tensors with dense values, values can be directly checked.
                self.assertAllEqual(feature_values["query_length"], [[3], [2]])
                self.assertAllEqual(feature_values["utility"],
                                    [[[1.]], [[0.]]])
Example #3
0
 def _predict(serialized):
   """Parses serialized example lists and scores them with `dnn_model`.

   Parsing uses the context/example feature specs stored on `self` and
   requests a mask feature named "mask".

   Returns:
     A dict with a single "predictions" entry holding the model output.
   """
   parsed = data.parse_from_example_list(
       serialized,
       context_feature_spec=self._context_feature_spec,
       example_feature_spec=self._example_feature_spec,
       mask_feature_name="mask")
   # Inference mode: the model is called with training=False.
   return {"predictions": dnn_model(inputs=parsed, training=False)}
Example #4
0
 def _predict(serialized):
   """Parses serialized example lists and scores them with `ranker`.

   Parsing requests a size feature named "example_list_size".

   Returns:
     A dict with a single "predictions" entry holding the ranker output.
   """
   parsed = data.parse_from_example_list(
       serialized,
       context_feature_spec=context_feature_spec,
       example_feature_spec=example_feature_spec,
       size_feature_name="example_list_size")
   # Inference mode: the ranker is called with training=False.
   return {"predictions": ranker(inputs=parsed, training=False)}
Example #5
0
 def predict(serialized_elwcs: tf.Tensor) -> Dict[str, tf.Tensor]:
     """Implements the predict serving signature.

     Parses `serialized_elwcs` with the feature specs and mask feature name
     held on `self`, runs `self._model` with training=False, and normalizes
     the result under the `tf.saved_model.PREDICT_OUTPUTS` key.
     """
     parsed_inputs = data.parse_from_example_list(
         serialized_elwcs,
         context_feature_spec=self._context_feature_spec,
         example_feature_spec=self._example_feature_spec,
         mask_feature_name=self._mask_feature_name)
     model_outputs = self._model(inputs=parsed_inputs, training=False)
     return _normalize_outputs(
         tf.saved_model.PREDICT_OUTPUTS, model_outputs)
Example #6
0
    def test_parse_example_list_with_sizes(self):
        """Checks the size feature emitted when lists are padded.

        With list_size=3 the lists are padded, yet the feature named `_SIZE`
        is expected to equal [2, 1].
        """
        with tf.Graph().as_default():
            protos = (EXAMPLE_LIST_PROTO_1, EXAMPLE_LIST_PROTO_2)
            serialized = [proto.SerializeToString() for proto in protos]
            # list_size=3 exceeds the longest list (2 examples), so the
            # parsed lists are padded.
            parsed = data_lib.parse_from_example_list(
                serialized,
                list_size=3,
                context_feature_spec=CONTEXT_FEATURE_SPEC,
                example_feature_spec=EXAMPLE_FEATURE_SPEC,
                size_feature_name=_SIZE)

            with tf.compat.v1.Session() as session:
                session.run(tf.compat.v1.local_variables_initializer())
                evaluated = session.run(parsed)
                self.assertAllEqual(evaluated[_SIZE], [2, 1])
Example #7
0
 def test_parse_from_example_list_static_shape(self):
     """Checks static shapes of parsed features for several list sizes.

     The context feature "query_length" must have static shape [2, 1] for
     every list_size. The example feature "utility" must have static shape
     [2, list_size, 1], with list_size=None giving [2, 2, 1].
     """
     # (list_size argument, expected static shape of "utility").
     cases = [
         (None, [2, 2, 1]),
         (100, [2, 100, 1]),
         (1, [2, 1, 1]),
     ]
     with tf.Graph().as_default():
         serialized = [
             proto.SerializeToString()
             for proto in (EXAMPLE_LIST_PROTO_1, EXAMPLE_LIST_PROTO_2)
         ]
         # Build every parse op first; the assertions only inspect static
         # shapes, so no session is needed.
         parsed_per_case = [
             data_lib.parse_from_example_list(
                 serialized,
                 list_size=size,
                 context_feature_spec=CONTEXT_FEATURE_SPEC,
                 example_feature_spec=EXAMPLE_FEATURE_SPEC)
             for size, _ in cases
         ]
         for parsed in parsed_per_case:
             self.assertAllEqual(
                 [2, 1], parsed["query_length"].get_shape().as_list())
         for parsed, (_, expected_shape) in zip(parsed_per_case, cases):
             self.assertAllEqual(expected_shape,
                                 parsed["utility"].get_shape().as_list())