def test_encode_pointwise_features(self):
    """Encodes a two-example pointwise (tf.Example-style) batch and checks it.

    Context features are expected back with their input rank
    ([[1], [1]] in, [[1], [1]] out). Each example feature gains a list
    dimension of size 1: "utility" becomes [2, 1, 1] and the embedded
    "unigrams" becomes [2, 1, 10].
    """
    batch_size = 2
    embedding_dim = 10
    vocabulary = ["ranking", "regression", "classification", "ordinal"]

    with tf.Graph().as_default():
        # Raw inputs in tf.Example format, one row per example.
        query_length = tf.convert_to_tensor(value=[[1], [1]])  # Repeated context feature.
        utility = tf.convert_to_tensor(value=[[1.0], [0.0]])
        unigrams = tf.SparseTensor(
            indices=[[0, 0], [1, 0]],
            values=["ranking", "regression"],
            dense_shape=[batch_size, 1])
        features = {
            "query_length": query_length,
            "utility": utility,
            "unigrams": unigrams,
        }

        context_columns = {
            "query_length": tf.feature_column.numeric_column(
                "query_length", shape=(1, ), default_value=0, dtype=tf.int64),
        }
        unigram_ids = feature_column.categorical_column_with_vocabulary_list(
            "unigrams", vocabulary_list=vocabulary)
        example_columns = {
            "utility": tf.feature_column.numeric_column(
                "utility", shape=(1, ), default_value=0.0, dtype=tf.float32),
            "unigrams": tf.feature_column.embedding_column(
                unigram_ids, dimension=embedding_dim),
        }

        context_features, example_features = feature_lib.encode_pointwise_features(
            features,
            context_feature_columns=context_columns,
            example_feature_columns=example_columns)

        self.assertAllEqual(["query_length"], sorted(context_features))
        self.assertAllEqual(["unigrams", "utility"], sorted(example_features))
        # Embedded unigrams: [batch_size=2, list_size=1, dim=10].
        self.assertAllEqual(
            [batch_size, 1, embedding_dim],
            example_features["unigrams"].get_shape().as_list())

        with tf.compat.v1.Session() as sess:
            # The embedding variable and the vocabulary lookup table must be
            # initialized before the encoded tensors can be evaluated.
            sess.run(tf.compat.v1.global_variables_initializer())
            sess.run(tf.compat.v1.tables_initializer())
            context_values, example_values = sess.run(
                [context_features, example_features])
            self.assertAllEqual([[1], [1]], context_values["query_length"])
            # Utility: [batch_size=2, list_size=1, 1].
            self.assertAllEqual([[[1.0]], [[0.0]]], example_values["utility"])
def _transform_fn(self, features, mode):
    """Defines the transform fn.

    If a user-supplied ``self._transform_function`` is set, it is called with
    ``features`` and ``mode`` and its result is returned unchanged. Otherwise
    the built-in encoder for the mode is used: ``PREDICT`` uses
    ``feature.encode_pointwise_features`` and every other mode uses
    ``feature.encode_listwise_features``. Both receive this object's context
    and example feature columns and run under the "transform_layer" scope.

    Args:
      features: Input features forwarded to the transform (presumably a dict
        of feature name to Tensor -- confirm against callers).
      mode: A ``tf.estimator.ModeKeys`` value.

    Returns:
      Whatever the selected transform returns.
    """
    custom_transform = self._transform_function
    if custom_transform is not None:
        return custom_transform(features=features, mode=mode)

    # Only the encoder differs between modes; the arguments are shared.
    encoder = (feature.encode_pointwise_features
               if mode == tf.estimator.ModeKeys.PREDICT
               else feature.encode_listwise_features)
    return encoder(
        features=features,
        context_feature_columns=self._context_feature_columns,
        example_feature_columns=self._example_feature_columns,
        mode=mode,
        scope="transform_layer")