Пример #1
0
 def test_make_parsing_fn_exception(self):
     """make_parsing_fn must reject an unrecognized data format with ValueError."""
     unknown_format = "non_existing_format"
     with tf.Graph().as_default(), self.assertRaises(ValueError):
         data_lib.make_parsing_fn(
             unknown_format,
             context_feature_spec=CONTEXT_FEATURE_SPEC,
             example_feature_spec=EXAMPLE_FEATURE_SPEC)
Пример #2
0
    def test_make_parsing_fn_eie(self):
        """Parses two serialized example-in-example records and checks features."""
        with tf.Graph().as_default():
            eie_parser = data_lib.make_parsing_fn(
                data_lib.EIE,
                context_feature_spec=CONTEXT_FEATURE_SPEC,
                example_feature_spec=EXAMPLE_FEATURE_SPEC)
            serialized_batch = [
                _example_in_example(context, examples).SerializeToString()
                for context, examples in ((CONTEXT_1, EXAMPLES_1),
                                          (CONTEXT_2, EXAMPLES_2))
            ]
            feature_tensors = eie_parser(serialized_batch)

            with tf.compat.v1.Session() as sess:
                sess.run(tf.compat.v1.local_variables_initializer())
                parsed = sess.run(feature_tensors)

                # Sparse feature: compare its shape, coordinates and payload.
                unigrams = parsed["unigrams"]
                self.assertAllEqual(unigrams.dense_shape, [2, 2, 3])
                self.assertAllEqual(
                    unigrams.indices,
                    [[0, 0, 0], [0, 1, 0], [0, 1, 1], [0, 1, 2], [1, 0, 0]])
                self.assertAllEqual(
                    unigrams.values,
                    [b"tensorflow", b"learning", b"to", b"rank", b"gbdt"])

                # Dense features compare directly against nested lists.
                self.assertAllEqual(parsed["query_length"], [[3], [2]])
                self.assertAllEqual(parsed["utility"],
                                    [[[0.], [1.]], [[0.], [-1.]]])
Пример #3
0
    def test_make_parsing_fn_seq(self):
        """Parses two serialized sequence-example protos and checks features."""
        with tf.Graph().as_default():
            seq_parser = data_lib.make_parsing_fn(
                data_lib.SEQ,
                context_feature_spec=CONTEXT_FEATURE_SPEC,
                example_feature_spec=EXAMPLE_FEATURE_SPEC)
            serialized_batch = [
                proto.SerializeToString()
                for proto in (SEQ_EXAMPLE_PROTO_1, SEQ_EXAMPLE_PROTO_2)
            ]
            feature_tensors = seq_parser(serialized_batch)

            with tf.compat.v1.Session() as sess:
                sess.run(tf.compat.v1.local_variables_initializer())
                parsed = sess.run(feature_tensors)

                # Exactly these three features are expected in the output.
                self.assertCountEqual(parsed,
                                      ["query_length", "unigrams", "utility"])

                # Sparse feature: compare its shape, coordinates and payload.
                unigrams = parsed["unigrams"]
                self.assertAllEqual(unigrams.dense_shape, [2, 2, 3])
                self.assertAllEqual(
                    unigrams.indices,
                    [[0, 0, 0], [0, 1, 0], [0, 1, 1], [0, 1, 2], [1, 0, 0]])
                self.assertAllEqual(
                    unigrams.values,
                    [b"tensorflow", b"learning", b"to", b"rank", b"gbdt"])

                # Dense features compare directly against nested lists.
                self.assertAllEqual(parsed["query_length"], [[3], [2]])
                self.assertAllEqual(parsed["utility"],
                                    [[[0.], [1.]], [[0.], [-1.]]])
Пример #4
0
    def _decode(self, record: tf.Tensor) -> Dict[str, tf.Tensor]:
        """Decodes a single serialized ELWC record into a feature dictionary.

        The record is parsed with a TF-Ranking parsing function configured from
        `self._params` (data format, list size, mask feature name, example
        shuffling and seed). When a label spec is configured, the label feature
        is added to the per-example feature spec used for parsing.

        Args:
          record: A serialized record. It is reshaped to `[1]`, so it must
            contain exactly one element.

        Returns:
          A dict mapping feature names to tensors with the leading batch
          dimension removed. Every `tf.int64` tensor is cast to `tf.int32`.
        """
        # Build the parsing spec on a copy. Updating `self._example_feature_spec`
        # in place would permanently leak the label feature into the instance's
        # example spec every time a record is decoded.
        if self._label_spec:
            parsing_example_feature_spec = dict(self._example_feature_spec)
            parsing_example_feature_spec.update(dict([self._label_spec]))
        else:
            parsing_example_feature_spec = self._example_feature_spec

        parsing_fn = tfr_data.make_parsing_fn(
            self._params.data_format,
            self._params.list_size,
            self._context_feature_spec,
            parsing_example_feature_spec,
            mask_feature_name=self._params.mask_feature_name,
            shuffle_examples=self._params.shuffle_examples,
            seed=self._params.seed)

        # The TF-Ranking parsing functions only take batched ELWCs as input and
        # return a dictionary from feature names to Tensors with the shape of
        # (batch_size, list_size, feature_length), so wrap the single record in
        # a batch of one.
        features = parsing_fn(tf.reshape(record, [1]))

        # Remove the batch_size dimension again and leave batching to the
        # DataLoader class when the distributed data set is constructed.
        # NOTE(review): tf.squeeze expects dense tensors; confirm the configured
        # parsing function returns no SparseTensors on this path.
        squeezed = {
            name: tf.squeeze(tensor, 0) for name, tensor in features.items()
        }

        # ELWC only supports tf.int64, but the TPU only supports tf.int32, so
        # cast every int64 tensor to int32 and leave other dtypes unchanged.
        return {
            name: (tf.cast(tensor, tf.int32)
                   if tensor.dtype == tf.int64 else tensor)
            for name, tensor in squeezed.items()
        }