def test_parse_from_example_list_truncate(self):
  """Checks that `list_size=1` keeps only the first example of each list.

  EXAMPLE_LIST_PROTO_1 holds two examples and EXAMPLE_LIST_PROTO_2 holds one.
  With `list_size=1` the first list is truncated and the second is unchanged,
  so the per-example features have a list dimension of 1.
  """
  with tf.Graph().as_default():
    serialized_example_lists = [
        EXAMPLE_LIST_PROTO_1.SerializeToString(),
        EXAMPLE_LIST_PROTO_2.SerializeToString()
    ]
    # Truncate each list to at most 1 example (EXAMPLE_LIST_PROTO_1 has 2).
    features = data_lib.parse_from_example_list(
        serialized_example_lists,
        list_size=1,
        context_feature_spec=CONTEXT_FEATURE_SPEC,
        example_feature_spec=EXAMPLE_FEATURE_SPEC)

    with tf.compat.v1.Session() as sess:
      sess.run(tf.compat.v1.local_variables_initializer())
      features = sess.run(features)
      # Test dense_shape, indices and values for a SparseTensor.
      self.assertAllEqual(features["unigrams"].dense_shape, [2, 1, 1])
      self.assertAllEqual(features["unigrams"].indices,
                          [[0, 0, 0], [1, 0, 0]])
      self.assertAllEqual(features["unigrams"].values,
                          [b"tensorflow", b"gbdt"])
      # For Tensors with dense values, values can be directly checked.
      self.assertAllEqual(features["query_length"], [[3], [2]])
      self.assertAllEqual(features["utility"], [[[0.]], [[0.]]])
def test_parse_from_example_list_shuffle(self):
  """Checks that `shuffle_examples=True` reorders examples before truncation.

  With a fixed `seed=1` the shuffle is reproducible. Truncating to
  `list_size=1` afterwards exposes which example ended up first in each list.
  """
  with tf.Graph().as_default():
    serialized_example_lists = [
        EXAMPLE_LIST_PROTO_1.SerializeToString(),
        EXAMPLE_LIST_PROTO_2.SerializeToString()
    ]
    # Truncate each list to at most 1 example (EXAMPLE_LIST_PROTO_1 has 2),
    # after shuffling the examples within each list.
    features = data_lib.parse_from_example_list(
        serialized_example_lists,
        list_size=1,
        context_feature_spec=CONTEXT_FEATURE_SPEC,
        example_feature_spec=EXAMPLE_FEATURE_SPEC,
        shuffle_examples=True,
        seed=1)

    with tf.compat.v1.Session() as sess:
      sess.run(tf.compat.v1.local_variables_initializer())
      features = sess.run(features)
      # With `shuffle_examples` and seed=1, the example `tensorflow` and the
      # example `learning to rank` in EXAMPLE_LIST_PROTO_1 switch order. After
      # truncation at list_size=1, only `learning to rank` in
      # EXAMPLE_LIST_PROTO_1 and `gbdt` in EXAMPLE_LIST_PROTO_2 are left in
      # serialized features.
      # Test dense_shape, indices and values for a SparseTensor.
      self.assertAllEqual(features["unigrams"].dense_shape, [2, 1, 3])
      self.assertAllEqual(
          features["unigrams"].indices,
          [[0, 0, 0], [0, 0, 1], [0, 0, 2], [1, 0, 0]])
      self.assertAllEqual(features["unigrams"].values,
                          [b"learning", b"to", b"rank", b"gbdt"])
      # For Tensors with dense values, values can be directly checked.
      self.assertAllEqual(features["query_length"], [[3], [2]])
      self.assertAllEqual(features["utility"], [[[1.]], [[0.]]])
def _predict(serialized):
  """Parses serialized example lists and scores them with `dnn_model`.

  The parser is asked to emit a mask feature under the key "mask". The model
  is called in inference mode (`training=False`).

  Args:
    serialized: The serialized example lists to parse.

  Returns:
    A dict with a single "predictions" entry holding the model scores.
  """
  parsed_features = data.parse_from_example_list(
      serialized,
      context_feature_spec=self._context_feature_spec,
      example_feature_spec=self._example_feature_spec,
      mask_feature_name="mask")
  return {"predictions": dnn_model(inputs=parsed_features, training=False)}
def _predict(serialized):
  """Parses serialized example lists and scores them with `ranker`.

  The parser is asked to emit a list-size feature under the key
  "example_list_size". The ranker is called in inference mode
  (`training=False`).

  Args:
    serialized: The serialized example lists to parse.

  Returns:
    A dict with a single "predictions" entry holding the ranker scores.
  """
  parsed_features = data.parse_from_example_list(
      serialized,
      context_feature_spec=context_feature_spec,
      example_feature_spec=example_feature_spec,
      size_feature_name="example_list_size")
  return {"predictions": ranker(inputs=parsed_features, training=False)}
def predict(serialized_elwcs: tf.Tensor) -> Dict[str, tf.Tensor]:
  """Serving function backing the predict signature.

  Args:
    serialized_elwcs: Serialized example lists to score (presumably
      ExampleListWithContext protos, per the name — confirm with the caller).

  Returns:
    The model outputs passed through `_normalize_outputs` under the
    `tf.saved_model.PREDICT_OUTPUTS` key.
  """
  model_inputs = data.parse_from_example_list(
      serialized_elwcs,
      context_feature_spec=self._context_feature_spec,
      example_feature_spec=self._example_feature_spec,
      mask_feature_name=self._mask_feature_name)
  raw_outputs = self._model(inputs=model_inputs, training=False)
  return _normalize_outputs(tf.saved_model.PREDICT_OUTPUTS, raw_outputs)
def test_parse_example_list_with_sizes(self):
  """Checks that `size_feature_name` reports each list's real example count."""
  with tf.Graph().as_default():
    serialized = [
        proto.SerializeToString()
        for proto in (EXAMPLE_LIST_PROTO_1, EXAMPLE_LIST_PROTO_2)
    ]
    # list_size=3 is larger than the longest list (2 examples), so both lists
    # are padded; the size feature must still hold the unpadded counts.
    parsed = data_lib.parse_from_example_list(
        serialized,
        list_size=3,
        context_feature_spec=CONTEXT_FEATURE_SPEC,
        example_feature_spec=EXAMPLE_FEATURE_SPEC,
        size_feature_name=_SIZE)

    with tf.compat.v1.Session() as session:
      session.run(tf.compat.v1.local_variables_initializer())
      evaluated = session.run(parsed)
      self.assertAllEqual(evaluated[_SIZE], [2, 1])
def test_parse_from_example_list_static_shape(self):
  """Checks static shapes of parsed features for several `list_size` values.

  `list_size=None` yields the longest list's length (2 here). An explicit
  `list_size` pads the example dimension (100) or truncates it (1).
  `query_length` keeps static shape [2, 1] in every case.
  """
  with tf.Graph().as_default():
    serialized = [
        proto.SerializeToString()
        for proto in (EXAMPLE_LIST_PROTO_1, EXAMPLE_LIST_PROTO_2)
    ]
    list_sizes = (None, 100, 1)
    expected_utility_shapes = ([2, 2, 1], [2, 100, 1], [2, 1, 1])

    # One parsed feature map per list_size, built in the order listed above.
    parsed_maps = [
        data_lib.parse_from_example_list(
            serialized,
            list_size=size,
            context_feature_spec=CONTEXT_FEATURE_SPEC,
            example_feature_spec=EXAMPLE_FEATURE_SPEC)
        for size in list_sizes
    ]

    for parsed in parsed_maps:
      self.assertAllEqual([2, 1],
                          parsed["query_length"].get_shape().as_list())
    for parsed, expected_shape in zip(parsed_maps, expected_utility_shapes):
      self.assertAllEqual(expected_shape,
                          parsed["utility"].get_shape().as_list())