def test_invalid_cases(self, shared):
  """Checks that unsupported embedding_lookup_device setups raise ValueError.

  Two cases are exercised:
    * embedding_lookup_device='cpu' with no inference context entered.
    * embedding_lookup_device='tpu_embedding_core' while a
      tpu._TPUInferenceContext is entered.

  Args:
    shared: If truthy, the column is built with
      tpu_fc.shared_embedding_columns_v2; otherwise with
      tpu_fc.embedding_column_v2.
  """
  # Inputs.
  input_sparse_tensor = sparse_tensor.SparseTensorValue(
      indices=((0, 0), (1, 0), (1, 1), (1, 4)),
      values=(2, 0, 1, 3),
      dense_shape=(2, 5))
  input_features = {'inp': input_sparse_tensor}

  # Build columns.
  categorical_column_input = fc_lib.categorical_column_with_identity(
      key='inp', num_buckets=3)

  def _make_embedding_column(lookup_device):
    """Builds the shared or plain embedding column for `lookup_device`."""
    if shared:
      return tpu_fc.shared_embedding_columns_v2(
          [categorical_column_input],
          dimension=2,
          embedding_lookup_device=lookup_device,
          tensor_core_shape=[None, 3])
    return tpu_fc.embedding_column_v2(
        categorical_column_input,
        dimension=2,
        embedding_lookup_device=lookup_device,
        tensor_core_shape=[None, 3])

  # Training on TPU with cpu embedding lookups is not supported.
  embedding_column = _make_embedding_column('cpu')
  dense_features = fc_lib.DenseFeatures(embedding_column)
  with self.assertRaisesRegex(
      ValueError,
      r'.*embedding_lookup_device=\"cpu\" during training is not'):
    dense_features(input_features)

  # Inference with TPU Embedding Hardware is not supported.
  embedding_column = _make_embedding_column('tpu_embedding_core')
  context = tpu._TPUInferenceContext('tpu_inference')
  context.Enter()
  # Always leave the inference context, even when the assertion below fails,
  # so that the entered context does not outlive this test.
  try:
    dense_features = fc_lib.DenseFeatures(embedding_column)
    with self.assertRaisesRegex(
        ValueError,
        r'Using embedding_lookup_device=tpu_embedding_core during inference is '
    ):
      dense_features(input_features)
  finally:
    context.Exit()
def test_dense_embedding_lookup(self, shared, combiner):
  """Checks the 'tpu_tensor_core' embedding lookup inside an inference context.

  For the 'mean' and 'sum' combiners the created variable and the looked-up
  embeddings are compared against hard-coded expectations. For 'sqrtn' the
  lookup is expected to raise ValueError.

  Args:
    shared: If truthy, the column is built with
      tpu_fc.shared_embedding_columns_v2; otherwise with
      tpu_fc.embedding_column_v2.
    combiner: Combiner name forwarded to the embedding column; 'mean', 'sum'
      and 'sqrtn' are handled here.
  """
  # Inputs.
  vocabulary_size = 3
  input_sparse_tensor = sparse_tensor.SparseTensorValue(
      # example 0, ids [2]
      # example 1, ids [0, 1, 3]
      indices=((0, 0), (1, 0), (1, 1), (1, 4)),
      values=(2, 0, 1, 3),
      dense_shape=(2, 5))
  input_features = {'inp': input_sparse_tensor}

  # Embedding variable.
  embedding_dimension = 2
  embedding_values = (
      (1., 2.),  # id 0
      (3., 5.),  # id 1
      (7., 11.),  # id 2
      (13., 17.)  # id 3
  )

  def _initializer(shape, dtype, partition_info=None):
    """Validates the requested variable spec and returns the fixed values."""
    self.assertAllEqual((vocabulary_size, embedding_dimension), shape)
    self.assertEqual(dtypes.float32, dtype)
    self.assertIsNone(partition_info)
    return embedding_values

  # Build columns.
  categorical_column_input = fc_lib.categorical_column_with_identity(
      key='inp', num_buckets=vocabulary_size)

  # tensor_core_shape is [None, 3]; the leading None leaves the batch
  # dimension unspecified.
  if shared:
    embedding_column = tpu_fc.shared_embedding_columns_v2(
        [categorical_column_input],
        dimension=embedding_dimension,
        initializer=_initializer,
        combiner=combiner,
        embedding_lookup_device='tpu_tensor_core',
        tensor_core_shape=[None, 3])
  else:
    embedding_column = tpu_fc.embedding_column_v2(
        categorical_column_input,
        dimension=embedding_dimension,
        initializer=_initializer,
        combiner=combiner,
        embedding_lookup_device='tpu_tensor_core',
        tensor_core_shape=[None, 3])

  # Run in TPUInferenceContext so that we hit the intended densification case.
  context = tpu._TPUInferenceContext('tpu_inference')
  context.Enter()
  # The try/finally guarantees context.Exit() on every path, including the
  # early return for 'sqrtn' and any failing assertion.
  try:
    dense_features = fc_lib.DenseFeatures(embedding_column)

    # Sqrtn combiner not supported for now.
    if combiner == 'sqrtn':
      with self.assertRaisesRegex(
          ValueError, 'Dense TPU Embedding does not support combiner'):
        dense_features(input_features)
      return

    # NOTE(review): the id stored at column 4 of example 1 does not contribute
    # to the expectations below, presumably because tensor_core_shape=[None, 3]
    # truncates the row -- confirm.
    if combiner == 'mean':
      expected_lookups = (
          # example 0:
          (7., 11.),  # ids [2], embedding = [7, 11]
          # example 1:
          (2., 3.5),  # ids [0, 1], embedding = mean([1, 2] + [3, 5]) = [2, 3.5]
      )
    elif combiner == 'sum':
      expected_lookups = (
          # example 0:
          (7., 11.),  # ids [2], embedding = [7, 11]
          # example 1:
          (4., 7),  # ids [0, 1], embedding = sum([1, 2] + [3, 5]) = [4, 7]
      )
    else:
      # Fail with a clear message instead of a NameError further down.
      self.fail('Unexpected combiner: %r' % (combiner,))

    embedding_lookup = dense_features(input_features)

    # Assert expected embedding variable and lookups.
    global_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
    if shared:
      self.assertCountEqual(('inp_shared_embedding:0',),
                            tuple(v.name for v in global_vars))
    else:
      self.assertCountEqual(
          ('dense_features/inp_embedding/embedding_weights:0',),
          tuple(v.name for v in global_vars))

    embedding_var = global_vars[0]
    with _initialized_session():
      self.assertAllEqual(embedding_values, embedding_var.eval())
      eval_res = embedding_lookup.eval()
      self.assertAllEqual(expected_lookups, eval_res)
  finally:
    context.Exit()
def test_empty_row(self):
  """Checks the 'mean' dense lookup when one example has no ids.

  Example 0 carries no ids and is expected to produce the embedding (0., 0.);
  example 1 is expected to produce the mean of the embeddings of ids 0 and 1.
  """
  # Inputs.
  vocabulary_size = 3
  input_sparse_tensor = sparse_tensor.SparseTensorValue(
      # example 0, ids []
      # example 1, ids [0, 1, 3]
      indices=((1, 0), (1, 1), (1, 4)),
      values=(0, 1, 3),
      dense_shape=(2, 5))
  input_features = {'inp': input_sparse_tensor}

  # Embedding variable.
  embedding_dimension = 2
  embedding_values = (
      (1., 2.),  # id 0
      (3., 5.),  # id 1
      (7., 11.),  # id 2
      (13., 17.)  # id 3
  )

  def _initializer(shape, dtype, partition_info=None):
    """Validates the requested variable spec and returns the fixed values."""
    self.assertAllEqual((vocabulary_size, embedding_dimension), shape)
    self.assertEqual(dtypes.float32, dtype)
    self.assertIsNone(partition_info)
    return embedding_values

  # Build columns.
  categorical_column_input = fc_lib.categorical_column_with_identity(
      key='inp', num_buckets=vocabulary_size)

  # tensor_core_shape is [None, 3]; the leading None leaves the batch
  # dimension unspecified.
  embedding_column = tpu_fc.embedding_column_v2(
      categorical_column_input,
      dimension=embedding_dimension,
      initializer=_initializer,
      combiner='mean',
      embedding_lookup_device='tpu_tensor_core',
      tensor_core_shape=[None, 3])

  # Run in TPUContexts so that we hit the intended densification case.
  context = tpu._TPUInferenceContext('tpu_inference')
  context.Enter()
  # The try/finally guarantees context.Exit() even when an assertion fails.
  try:
    with tpu_function.tpu_shard_context(1):
      dense_features = fc_lib.DenseFeatures(embedding_column)
      expected_lookups = (
          # example 0:
          (0., 0.),  # ids [], embedding = [0, 0]
          # example 1:
          (2., 3.5),  # ids [0, 1], embedding = mean([1, 2] + [3, 5]) = [2, 3.5]
      )

      embedding_lookup = dense_features(input_features)

      # Assert expected embedding variable and lookups.
      global_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
      self.assertCountEqual(
          ('dense_features/inp_embedding/embedding_weights:0',),
          tuple(v.name for v in global_vars))

      embedding_var = global_vars[0]
      with _initialized_session():
        # Evaluate the variable explicitly in the initialized session, the
        # same way test_dense_embedding_lookup does.
        self.assertAllEqual(embedding_values, embedding_var.eval())
        eval_res = embedding_lookup.eval()
        self.assertAllEqual(expected_lookups, eval_res)
  finally:
    context.Exit()