Example #1
0
    def test_encode_pointwise_features(self):
        """Checks `encode_pointwise_features` on a batch of two examples.

        The call must split the inputs into a context dict and an example
        dict keyed by feature name, and the example tensors must come back
        with a list_size dimension of 1 inserted after the batch dimension.
        """
        with tf.Graph().as_default():
            # Inputs in tf.Example format with batch size = 2. The context
            # feature is repeated for every example in the batch.
            query_length = tf.convert_to_tensor(value=[[1], [1]])
            utility = tf.convert_to_tensor(value=[[1.0], [0.0]])
            unigrams = tf.SparseTensor(
                indices=[[0, 0], [1, 0]],
                values=["ranking", "regression"],
                dense_shape=[2, 1])
            features = {
                "query_length": query_length,
                "utility": utility,
                "unigrams": unigrams,
            }

            context_feature_columns = {
                "query_length": tf.feature_column.numeric_column(
                    "query_length",
                    shape=(1, ),
                    default_value=0,
                    dtype=tf.int64),
            }

            unigram_vocabulary = [
                "ranking", "regression", "classification", "ordinal"
            ]
            utility_column = tf.feature_column.numeric_column(
                "utility",
                shape=(1, ),
                default_value=0.0,
                dtype=tf.float32)
            unigram_ids = feature_column.categorical_column_with_vocabulary_list(
                "unigrams", vocabulary_list=unigram_vocabulary)
            example_feature_columns = {
                "utility": utility_column,
                "unigrams": tf.feature_column.embedding_column(
                    unigram_ids, dimension=10),
            }

            encoded = feature_lib.encode_pointwise_features(
                features,
                context_feature_columns=context_feature_columns,
                example_feature_columns=example_feature_columns)
            context_features, example_features = encoded

            # Each output dict holds exactly the features of its own group.
            self.assertAllEqual(["query_length"], sorted(context_features))
            self.assertAllEqual(["unigrams", "utility"],
                                sorted(example_features))

            # Embedded unigrams: [batch_size=2, list_size=1, dim=10].
            unigrams_shape = example_features["unigrams"].get_shape().as_list()
            self.assertAllEqual([2, 1, 10], unigrams_shape)

            with tf.compat.v1.Session() as sess:
                # Variables and lookup tables must be initialized before the
                # encoded tensors can be evaluated.
                sess.run(tf.compat.v1.global_variables_initializer())
                sess.run(tf.compat.v1.tables_initializer())
                context_values, example_values = sess.run(
                    [context_features, example_features])

                self.assertAllEqual([[1], [1]],
                                    context_values["query_length"])
                # Utility values: [batch_size=2, list_size=1, 1].
                self.assertAllEqual([[[1.0]], [[0.0]]],
                                    example_values["utility"])
Example #2
0
    def _transform_fn(self, features, mode):
        """Transforms `features` for the given `mode`.

        A configured custom transform function takes over entirely.
        Otherwise the features are encoded pointwise in PREDICT mode and
        listwise in every other mode, using the stored context and example
        feature columns.

        Args:
            features: Input features, forwarded unchanged to the transform.
            mode: Mode key; compared against `tf.estimator.ModeKeys.PREDICT`.

        Returns:
            Whatever the selected transform returns, passed through as-is.
        """
        custom_transform = self._transform_function
        if custom_transform is not None:
            return custom_transform(features=features, mode=mode)

        # Only the encoder differs between modes; the arguments are shared.
        predicting = mode == tf.estimator.ModeKeys.PREDICT
        encode = (feature.encode_pointwise_features
                  if predicting else feature.encode_listwise_features)
        return encode(
            features=features,
            context_feature_columns=self._context_feature_columns,
            example_feature_columns=self._example_feature_columns,
            mode=mode,
            scope="transform_layer")