def test_make_parsing_fn_exception(self):
    """An unsupported data-format string must raise ValueError."""
    with tf.Graph().as_default(), self.assertRaises(ValueError):
        data_lib.make_parsing_fn(
            "non_existing_format",
            context_feature_spec=CONTEXT_FEATURE_SPEC,
            example_feature_spec=EXAMPLE_FEATURE_SPEC)
def test_make_parsing_fn_eie(self):
    """Parses batched example-in-example protos and checks the feature map."""
    with tf.Graph().as_default():
        parse = data_lib.make_parsing_fn(
            data_lib.EIE,
            context_feature_spec=CONTEXT_FEATURE_SPEC,
            example_feature_spec=EXAMPLE_FEATURE_SPEC)
        serialized = [
            _example_in_example(CONTEXT_1, EXAMPLES_1).SerializeToString(),
            _example_in_example(CONTEXT_2, EXAMPLES_2).SerializeToString(),
        ]
        parsed = parse(serialized)
        with tf.compat.v1.Session() as sess:
            sess.run(tf.compat.v1.local_variables_initializer())
            parsed = sess.run(parsed)

        # SparseTensor feature: verify dense_shape, indices and values.
        unigrams = parsed["unigrams"]
        self.assertAllEqual(unigrams.dense_shape, [2, 2, 3])
        self.assertAllEqual(
            unigrams.indices,
            [[0, 0, 0], [0, 1, 0], [0, 1, 1], [0, 1, 2], [1, 0, 0]])
        self.assertAllEqual(
            unigrams.values,
            [b"tensorflow", b"learning", b"to", b"rank", b"gbdt"])

        # Dense features: compare values directly.
        self.assertAllEqual(parsed["query_length"], [[3], [2]])
        self.assertAllEqual(parsed["utility"],
                            [[[0.], [1.0]], [[0.], [-1.]]])
def test_make_parsing_fn_seq(self):
    """Parses batched SequenceExample protos and checks the feature map."""
    with tf.Graph().as_default():
        parse = data_lib.make_parsing_fn(
            data_lib.SEQ,
            context_feature_spec=CONTEXT_FEATURE_SPEC,
            example_feature_spec=EXAMPLE_FEATURE_SPEC)
        serialized = [
            SEQ_EXAMPLE_PROTO_1.SerializeToString(),
            SEQ_EXAMPLE_PROTO_2.SerializeToString(),
        ]
        parsed = parse(serialized)
        with tf.compat.v1.Session() as sess:
            sess.run(tf.compat.v1.local_variables_initializer())
            feature_map = sess.run(parsed)

        # Exactly these features come out of the parser, no more.
        self.assertCountEqual(feature_map,
                              ["query_length", "unigrams", "utility"])

        # SparseTensor feature: verify dense_shape, indices and values.
        unigrams = feature_map["unigrams"]
        self.assertAllEqual(unigrams.dense_shape, [2, 2, 3])
        self.assertAllEqual(
            unigrams.indices,
            [[0, 0, 0], [0, 1, 0], [0, 1, 1], [0, 1, 2], [1, 0, 0]])
        self.assertAllEqual(
            unigrams.values,
            [b"tensorflow", b"learning", b"to", b"rank", b"gbdt"])

        # Dense features: compare values directly.
        self.assertAllEqual(feature_map["query_length"], [[3], [2]])
        self.assertAllEqual(feature_map["utility"],
                            [[[0.], [1.]], [[0.], [-1.]]])
def _decode(self, record: tf.Tensor) -> Dict[str, tf.Tensor]:
    """Decodes a serialized ELWC record into a feature dict.

    Args:
      record: A scalar string `tf.Tensor` holding one serialized ELWC.

    Returns:
      A dict mapping feature names to Tensors shaped
      (list_size, feature_length); the leading batch dimension is
      squeezed out so batching is left to the DataLoader. All int64
      values are cast to int32 for TPU compatibility.
    """
    # BUG FIX: copy the spec instead of aliasing it. The original code
    # bound parsing_example_feature_spec directly to
    # self._example_feature_spec, so the .update() below permanently
    # injected the label spec into the instance's example feature spec
    # on the first call.
    parsing_example_feature_spec = dict(self._example_feature_spec)
    if self._label_spec:
        parsing_example_feature_spec.update(dict([self._label_spec]))

    parsing_fn = tfr_data.make_parsing_fn(
        self._params.data_format,
        self._params.list_size,
        self._context_feature_spec,
        parsing_example_feature_spec,
        mask_feature_name=self._params.mask_feature_name,
        shuffle_examples=self._params.shuffle_examples,
        seed=self._params.seed)

    # The TF-Ranking parsing functions only take batched ELWCs as input
    # and output a dictionary from feature names to Tensors with the
    # shape of (batch_size, list_size, feature_length).
    features = parsing_fn(tf.reshape(record, [1]))

    # Remove the leading batch_size dimension (always 1 here) and leave
    # batching to the DataLoader class when the distributed dataset is
    # constructed. ELWC only supports tf.int64, but the TPU only
    # supports tf.int32, so cast all int64 features to int32.
    return {
        name: tf.cast(tf.squeeze(tensor, 0), tf.int32)
        if tensor.dtype == tf.int64 else tf.squeeze(tensor, 0)
        for name, tensor in features.items()
    }