def test_indicator_column(self, sparse_input_args_a, sparse_input_args_b,
                          expected_input_layer, expected_sequence_length):
  """Checks SequenceFeatures output for two sequence indicator columns.

  The columns are handed to the layer as ``[bbb, aaa]`` on purpose; the
  parameterized expectations rely on the layer reordering them by key.
  """
  features = {
      'aaa': sparse_tensor.SparseTensorValue(**sparse_input_args_a),
      'bbb': sparse_tensor.SparseTensorValue(**sparse_input_args_b),
  }
  bucket_counts = {'aaa': 3, 'bbb': 2}

  indicator_a = fc.indicator_column(
      sfc.sequence_categorical_column_with_identity(
          key='aaa', num_buckets=bucket_counts['aaa']))
  indicator_b = fc.indicator_column(
      sfc.sequence_categorical_column_with_identity(
          key='bbb', num_buckets=bucket_counts['bbb']))

  # Listed out of order deliberately: the layer should sort columns by name.
  layer = ksfc.SequenceFeatures([indicator_b, indicator_a])
  dense_output, seq_lengths = layer(features)

  self.assertAllEqual(expected_input_layer, self.evaluate(dense_output))
  self.assertAllEqual(expected_sequence_length, self.evaluate(seq_lengths))
def test_embedding_column(
    self, sparse_input_args_a, sparse_input_args_b, expected_input_layer,
    expected_sequence_length):
  """Checks SequenceFeatures output and weights for two embedding columns.

  Each column gets a custom initializer that also asserts the requested
  table shape/dtype. Columns are passed as ``[bbb, aaa]`` to exercise the
  layer's alphabetical reordering.
  """
  sparse_input_a = sparse_tensor.SparseTensorValue(**sparse_input_args_a)
  sparse_input_b = sparse_tensor.SparseTensorValue(**sparse_input_args_b)
  vocabulary_size = 3
  embedding_dimension_a = 2
  embedding_values_a = (
      (1., 2.),  # id 0
      (3., 4.),  # id 1
      (5., 6.)  # id 2
  )
  embedding_dimension_b = 3
  embedding_values_b = (
      (11., 12., 13.),  # id 0
      (14., 15., 16.),  # id 1
      (17., 18., 19.)  # id 2
  )

  def _get_initializer(embedding_dimension, embedding_values):
    """Returns an initializer that validates its request and yields a table."""

    def _initializer(shape, dtype, partition_info=None):
      self.assertAllEqual((vocabulary_size, embedding_dimension), shape)
      self.assertEqual(dtypes.float32, dtype)
      self.assertIsNone(partition_info)
      return embedding_values

    return _initializer

  categorical_column_a = sfc.sequence_categorical_column_with_identity(
      key='aaa', num_buckets=vocabulary_size)
  embedding_column_a = fc.embedding_column(
      categorical_column_a,
      dimension=embedding_dimension_a,
      initializer=_get_initializer(embedding_dimension_a, embedding_values_a))
  categorical_column_b = sfc.sequence_categorical_column_with_identity(
      key='bbb', num_buckets=vocabulary_size)
  embedding_column_b = fc.embedding_column(
      categorical_column_b,
      dimension=embedding_dimension_b,
      initializer=_get_initializer(embedding_dimension_b, embedding_values_b))

  # Test that columns are reordered alphabetically.
  sequence_input_layer = ksfc.SequenceFeatures(
      [embedding_column_b, embedding_column_a])
  input_layer, sequence_length = sequence_input_layer({
      'aaa': sparse_input_a, 'bbb': sparse_input_b,})

  self.evaluate(variables_lib.global_variables_initializer())
  weights = sequence_input_layer.weights
  weight_name_a = 'sequence_features/aaa_embedding/embedding_weights:0'
  weight_name_b = 'sequence_features/bbb_embedding/embedding_weights:0'
  self.assertCountEqual(
      (weight_name_a, weight_name_b),
      tuple([v.name for v in weights]))
  # assertCountEqual above is order-insensitive, so look the variables up
  # by name instead of relying on their position in `weights`.
  weights_by_name = {v.name: v for v in weights}
  self.assertAllEqual(
      embedding_values_a, self.evaluate(weights_by_name[weight_name_a]))
  self.assertAllEqual(
      embedding_values_b, self.evaluate(weights_by_name[weight_name_b]))
  self.assertAllEqual(expected_input_layer, self.evaluate(input_layer))
  self.assertAllEqual(
      expected_sequence_length, self.evaluate(sequence_length))
def test_get_sequence_dense_tensor(self, inputs_args, expected):
  """Looks up a sequence embedding column and checks its table and output."""
  features = {'aaa': sparse_tensor.SparseTensorValue(**inputs_args)}
  vocab_size, dim = 3, 2
  table = (
      (1., 2.),  # id 0
      (3., 5.),  # id 1
      (7., 11.)  # id 2
  )

  def _initializer(shape, dtype, partition_info=None):
    # The column must request an unpartitioned float32 (vocab_size, dim) table.
    self.assertAllEqual((vocab_size, dim), shape)
    self.assertEqual(dtypes.float32, dtype)
    self.assertIsNone(partition_info)
    return table

  column = fc.embedding_column(
      sfc.sequence_categorical_column_with_identity(
          key='aaa', num_buckets=vocab_size),
      dimension=dim,
      initializer=_initializer)

  lookup, _, state_manager = _get_sequence_dense_tensor_state(
      column, features)
  layer_weights = state_manager._layer.weights
  self.evaluate(variables_lib.global_variables_initializer())

  self.assertCountEqual(
      ('embedding_weights:0', ), tuple(w.name for w in layer_weights))
  self.assertAllEqual(table, self.evaluate(layer_weights[0]))
  self.assertAllEqual(expected, self.evaluate(lookup))
def test_shared_sequence_non_sequence_into_input_layer(self):
  """Feeds one shared embedding to both a sequence and a non-sequence layer.

  ``seq`` and ``non_seq`` share a 4-dim embedding initialized to ones. The
  sequence layer emits one embedding per step, while the non-sequence layer
  combines each example's ids with ``combiner='sum'``.
  """
  # Keep the categorical columns under their own names; the original code
  # rebound `seq`/`non_seq` to input tensors, which hid the columns.
  non_seq_column = fc.categorical_column_with_identity(
      'non_seq', num_buckets=10)
  seq_column = sfc.sequence_categorical_column_with_identity(
      'seq', num_buckets=10)
  shared_non_seq, shared_seq = fc.shared_embedding_columns_v2(
      [non_seq_column, seq_column],
      dimension=4,
      combiner='sum',
      initializer=init_ops_v2.Ones(),
      shared_embedding_collection_name='shared')

  # Both features carry the same ids: example 0 -> [0, 1], example 1 -> [2].
  seq_ids = sparse_tensor.SparseTensor(
      indices=[[0, 0], [0, 1], [1, 0]],
      values=[0, 1, 2],
      dense_shape=[2, 2])
  non_seq_ids = sparse_tensor.SparseTensor(
      indices=[[0, 0], [0, 1], [1, 0]],
      values=[0, 1, 2],
      dense_shape=[2, 2])
  features = {'seq': seq_ids, 'non_seq': non_seq_ids}

  # Run the same feature dict through a sequence layer and a dense
  # (non-sequence) layer that share one embedding table.
  seq_input, seq_length = ksfc.SequenceFeatures([shared_seq])(features)
  non_seq_input = dense_features.DenseFeatures([shared_non_seq])(features)

  with self.cached_session() as sess:
    sess.run(variables.global_variables_initializer())
    output_seq, output_seq_length, output_non_seq = sess.run(
        [seq_input, seq_length, non_seq_input])
    # One all-ones vector per present step; the missing step is zero-padded.
    self.assertAllEqual(
        output_seq,
        [[[1, 1, 1, 1], [1, 1, 1, 1]], [[1, 1, 1, 1], [0, 0, 0, 0]]])
    self.assertAllEqual(output_seq_length, [2, 1])
    # Summed over each example's ids: two ids, then one.
    self.assertAllEqual(output_non_seq, [[2, 2, 2, 2], [1, 1, 1, 1]])
def test_sequence_length_with_empty_rows(self):
  """Sequence length is 0 for examples that have no ids at all."""
  # ids per example: 0 -> [], 1 -> [2], 2 -> [0, 1], 3 -> [], 4 -> [1], 5 -> []
  sparse_input = sparse_tensor.SparseTensorValue(
      indices=((1, 0), (2, 0), (2, 1), (4, 0)),
      values=(2, 0, 1, 1),
      dense_shape=(6, 2))
  column = fc.embedding_column(
      sfc.sequence_categorical_column_with_identity(key='aaa', num_buckets=3),
      dimension=2)

  lengths = _get_sequence_dense_tensor_state(column, {'aaa': sparse_input})[1]

  self.assertAllEqual([0, 1, 2, 0, 1, 0], self.evaluate(lengths))
def test_serialization(self):
  """Round-trips a sequence indicator column through get_config/from_config."""
  parent = sfc.sequence_categorical_column_with_identity(
      'animal', num_buckets=4)
  animal = fc.indicator_column(parent)
  config = animal.get_config()

  identity_config = {
      'class_name': 'IdentityCategoricalColumn',
      'config': {
          'default_value': None,
          'key': 'animal',
          'number_buckets': 4
      }
  }
  expected_config = {
      'categorical_column': {
          'class_name': 'SequenceCategoricalColumn',
          'config': {
              'categorical_column': identity_config
          }
      }
  }
  self.assertEqual(expected_config, config)

  # Without a name registry, an equal but distinct parent column is rebuilt.
  rebuilt = fc.IndicatorColumn.from_config(config)
  self.assertEqual(animal, rebuilt)
  self.assertIsNot(parent, rebuilt.categorical_column)

  # With the parent registered under its serialized name, it is reused as-is.
  registry = {serialization._column_name_with_class_name(parent): parent}
  reused = fc.IndicatorColumn.from_config(config, columns_by_name=registry)
  self.assertEqual(animal, reused)
  self.assertIs(parent, reused.categorical_column)
def test_sequence_length_with_empty_rows(self):
  """Shared embedding columns report 0 length for examples without ids."""
  with ops.Graph().as_default():
    vocabulary_size = 3
    # ids per example: [], [2], [0, 1], [], [1], []
    sparse_input_a = sparse_tensor.SparseTensorValue(
        indices=((1, 0), (2, 0), (2, 1), (4, 0)),
        values=(2, 0, 1, 1),
        dense_shape=(6, 2))
    categorical_column_a = sfc.sequence_categorical_column_with_identity(
        key='aaa', num_buckets=vocabulary_size)
    # ids per example: [2], [], [], [], [1], [0, 1]
    sparse_input_b = sparse_tensor.SparseTensorValue(
        indices=((0, 0), (4, 0), (5, 0), (5, 1)),
        values=(2, 1, 0, 1),
        dense_shape=(6, 2))
    categorical_column_b = sfc.sequence_categorical_column_with_identity(
        key='bbb', num_buckets=vocabulary_size)

    shared_a, shared_b = fc.shared_embedding_columns_v2(
        [categorical_column_a, categorical_column_b], dimension=2)
    cases = (
        (shared_a, {'aaa': sparse_input_a}, [0, 1, 2, 0, 1, 0]),
        (shared_b, {'bbb': sparse_input_b}, [1, 0, 0, 0, 1, 2]),
    )
    # Build every length tensor before the session is opened.
    checks = [(_get_sequence_dense_tensor(column, feats)[1], want)
              for column, feats, want in cases]

    with _initialized_session() as sess:
      for length_tensor, want in checks:
        self.assertAllEqual(want, length_tensor.eval(session=sess))
def test_get_sparse_tensors(self, inputs_args, expected_args):
  """Compares an identity sequence column's id tensor with `expected_args`."""
  features = {'aaa': sparse_tensor.SparseTensorValue(**inputs_args)}
  want = sparse_tensor.SparseTensorValue(**expected_args)
  column = sfc.sequence_categorical_column_with_identity('aaa', num_buckets=9)

  pair = _get_sparse_tensors(column, features)

  # Identity columns carry no weights.
  self.assertIsNone(pair.weight_tensor)
  _assert_sparse_tensor_value(self, want, self.evaluate(pair.id_tensor))
def test_static_shape_from_tensors_indicator(self, sparse_input_args,
                                             expected_shape):
  """The indicator output keeps a static shape when one is inferable."""
  features = {'aaa': sparse_tensor.SparseTensorValue(**sparse_input_args)}
  layer = ksfc.SequenceFeatures([
      fc.indicator_column(
          sfc.sequence_categorical_column_with_identity(
              key='aaa', num_buckets=3))
  ])

  dense, _ = layer(features)

  self.assertEqual(dense.get_shape(), expected_shape)
def test_get_sequence_dense_tensor(self, inputs_args, expected):
  """The dense tensor of a sequence indicator column matches `expected`."""
  features = {'aaa': sparse_tensor.SparseTensorValue(**inputs_args)}
  column = fc.indicator_column(
      sfc.sequence_categorical_column_with_identity(key='aaa', num_buckets=3))

  dense_tensor, _ = _get_sequence_dense_tensor(column, features)

  self.assertAllEqual(expected, self.evaluate(dense_tensor))
def test_sequence_length(self):
  """Shared embedding columns yield int64 per-example sequence lengths."""
  with ops.Graph().as_default():
    vocabulary_size = 3
    # example 0: ids [2]; example 1: ids [0, 1]
    sparse_input_a = sparse_tensor.SparseTensorValue(
        indices=((0, 0), (1, 0), (1, 1)),
        values=(2, 0, 1),
        dense_shape=(2, 2))
    categorical_column_a = sfc.sequence_categorical_column_with_identity(
        key='aaa', num_buckets=vocabulary_size)
    # example 0: ids [0, 2]; example 1: ids [1]
    sparse_input_b = sparse_tensor.SparseTensorValue(
        indices=((0, 0), (0, 1), (1, 0)),
        values=(0, 2, 1),
        dense_shape=(2, 2))
    categorical_column_b = sfc.sequence_categorical_column_with_identity(
        key='bbb', num_buckets=vocabulary_size)

    column_a, column_b = fc.shared_embedding_columns_v2(
        [categorical_column_a, categorical_column_b], dimension=2)
    length_a = _get_sequence_dense_tensor(column_a, {'aaa': sparse_input_a})[1]
    length_b = _get_sequence_dense_tensor(column_b, {'bbb': sparse_input_b})[1]

    with _initialized_session() as sess:
      for tensor, want in ((length_a, [1, 2]), (length_b, [2, 1])):
        value = sess.run(tensor)
        self.assertAllEqual(want, value)
        self.assertEqual(np.int64, value.dtype)
def test_sequence_length(self, inputs_args, expected_sequence_length):
  """An indicator column's sequence lengths are int64 and match expectations."""
  features = {'aaa': sparse_tensor.SparseTensorValue(**inputs_args)}
  column = fc.indicator_column(
      sfc.sequence_categorical_column_with_identity(key='aaa', num_buckets=3))

  _, length_tensor = _get_sequence_dense_tensor(column, features)
  lengths = self.evaluate(length_tensor)

  self.assertAllEqual(expected_sequence_length, lengths)
  self.assertEqual(np.int64, lengths.dtype)
def _build_feature_columns(self):
  """Returns ``(context_columns, sequence_columns)`` used by the tests."""
  int_ctx = fc.categorical_column_with_identity('int_ctx', num_buckets=100)
  context_columns = [
      fc.embedding_column(int_ctx, dimension=10),
      fc.numeric_column('float_ctx'),
  ]

  int_list = sfc.sequence_categorical_column_with_identity(
      'int_list', num_buckets=10)
  bytes_list = sfc.sequence_categorical_column_with_hash_bucket(
      'bytes_list', hash_bucket_size=100)
  sequence_columns = [
      fc.embedding_column(int_list, dimension=10),
      fc.embedding_column(bytes_list, dimension=20),
  ]
  return context_columns, sequence_columns
def _build_feature_columns(self):
  """Builds the context and sequence feature columns shared by the tests.

  Returns:
    ``(ctx_cols, seq_cols)``: an embedded identity column plus a numeric
    column for context, and embedded identity and hash-bucket sequence
    columns (dimensions 10 and 20) for the sequence part.
  """
  ctx_cols = [
      fc.embedding_column(
          fc.categorical_column_with_identity('int_ctx', num_buckets=100),
          dimension=10),
      fc.numeric_column('float_ctx'),
  ]
  # (categorical column, embedding dimension) in output order.
  seq_specs = (
      (sfc.sequence_categorical_column_with_identity(
          'int_list', num_buckets=10), 10),
      (sfc.sequence_categorical_column_with_hash_bucket(
          'bytes_list', hash_bucket_size=100), 20),
  )
  seq_cols = [fc.embedding_column(categorical, dimension=dim)
              for categorical, dim in seq_specs]
  return ctx_cols, seq_cols
def test_indicator_column(self):
  """DenseFeatures must reject an indicator over a sequence column."""
  # example 0: ids [2]; example 1: ids [0, 1]
  sparse_input = sparse_tensor.SparseTensorValue(
      indices=((0, 0), (1, 0), (1, 1)),
      values=(2, 0, 1),
      dense_shape=(2, 2))
  indicator = fc.indicator_column(
      sfc.sequence_categorical_column_with_identity(key='aaa', num_buckets=3))
  dense_layer = df.DenseFeatures([indicator])

  expected_error = (
      r'In indicator_column: aaa_indicator\. categorical_column must not be '
      r'of type SequenceCategoricalColumn\.')
  with self.assertRaisesRegex(ValueError, expected_error):
    _ = dense_layer({'aaa': sparse_input})
def _get_sequence_categorical_column(params: dict) -> fc.SequenceCategoricalColumn:
    """Build a sequence categorical column from a parameter dict.

    The column type is chosen by the first selector key present, checked in
    this order: ``vocabulary``, ``bucket_size``, ``file``, ``num_buckets``.

    Args:
        params: Must contain ``key`` (the feature name) plus at least one of
            the selector keys above.

    Returns:
        The sequence categorical column described by ``params``.

    Raises:
        KeyError: If ``params`` has no ``key`` entry.
        ValueError: If none of the selector keys is present.
    """
    key = params['key']
    if 'vocabulary' in params:
        # NOTE(review): default_value=0 sends out-of-vocabulary values to
        # id 0, which is shared with the first vocabulary entry — confirm
        # this is intended.
        return sfc.sequence_categorical_column_with_vocabulary_list(
            key,
            vocabulary_list=_parse_vocabulary(params['vocabulary']),
            default_value=0)
    if 'bucket_size' in params:
        return sfc.sequence_categorical_column_with_hash_bucket(
            key, hash_bucket_size=params['bucket_size'])
    if 'file' in params:
        return sfc.sequence_categorical_column_with_vocabulary_file(
            key, vocabulary_file=params['file'], default_value=0)
    if 'num_buckets' in params:
        return sfc.sequence_categorical_column_with_identity(
            key, num_buckets=params['num_buckets'])
    # ValueError subclasses Exception, so existing `except Exception` callers
    # still catch this; the message now says what was expected.
    raise ValueError(
        "params error: expected one of 'vocabulary', 'bucket_size', 'file', "
        "'num_buckets' for key %r; got keys %r" % (key, list(params)))
def make_feature_config(num_players):
    """Build the FeatureConfig listing model input features and targets.

    Args:
        num_players: Width of the per-player numeric features (used as
            their ``shape``).

    Returns:
        A ``FeatureConfig`` with context features, sequence features,
        context targets and sequence targets.
    """

    def _seq_input(key, dtype, **shape_kwarg):
        # Every numeric sequence input shares default_value=-1 and the
        # make_float normalizer; `shape` is forwarded only where set.
        return sfc.sequence_numeric_column(
            key,
            dtype=dtype,
            default_value=-1,
            normalizer_fn=make_float,
            **shape_kwarg)

    def _seq_one_hot(key, num_buckets):
        return fc.indicator_column(
            sfc.sequence_categorical_column_with_identity(key, num_buckets))

    context_features = [
        fc.numeric_column(
            "public_context__starting_stack_sizes",
            shape=num_players,
            dtype=tf.int64,
        ),
        fc.embedding_column(
            tf.feature_column.categorical_column_with_vocabulary_list(
                "private_context__hand_encoded", range(1326)),
            dimension=4,
        ),
    ]

    per_player_keys = (
        "public_state__all_in_player_mask",
        "public_state__stack_sizes",
        "public_state__amount_to_call",
        "public_state__current_player_mask",
    )
    scalar_int_keys = (
        "public_state__min_raise_amount",
        "public_state__pot_size",
        "public_state__street",
        "player_state__is_current_player",
        "player_state__current_player_offset",
    )
    scalar_float_keys = (
        "player_state__win_odds",
        "player_state__win_odds_vs_better",
        "player_state__win_odds_vs_tied",
        "player_state__win_odds_vs_worse",
        "player_state__frac_better_hands",
        "player_state__frac_tied_hands",
        "player_state__frac_worse_hands",
    )
    sequence_features = [
        _seq_one_hot("last_action__action_encoded", 22),
        _seq_one_hot("last_action__move", 5),
        _seq_input("last_action__amount_added", tf.int64),
        _seq_input("last_action__amount_added_percent_of_remaining",
                   tf.float32),
        _seq_input("last_action__amount_raised", tf.int64),
        _seq_input("last_action__amount_raised_percent_of_pot", tf.float32),
        *(_seq_input(k, tf.int64, shape=num_players)
          for k in per_player_keys),
        *(_seq_input(k, tf.int64, shape=1) for k in scalar_int_keys),
        _seq_one_hot("player_state__current_hand_type", 9),
        *(_seq_input(k, tf.float32, shape=1) for k in scalar_float_keys),
    ]

    context_targets = [
        fc.numeric_column("public_context__num_players", shape=1,
                          dtype=tf.int64),
    ]

    # Targets are raw (no normalizer_fn), unlike the sequence inputs.
    sequence_targets = [
        sfc.sequence_numeric_column(key, dtype=tf.int64, default_value=-1)
        for key in (
            "next_action__action_encoded",
            "reward__cumulative_reward",
            "public_state__pot_size",
            "player_state__is_current_player",
            "public_state__num_players_remaining",
        )
    ]

    return FeatureConfig(
        context_features=context_features,
        sequence_features=sequence_features,
        context_targets=context_targets,
        sequence_targets=sequence_targets,
    )
def test_shared_embedding_column(self):
  """Checks SequenceFeatures output for two shared sequence embedding columns.

  Both columns share one embedding table (a single global variable is
  expected). Columns are passed as ``[bbb, aaa]`` to exercise alphabetical
  reordering. Each output step concatenates the ``aaa`` embedding followed
  by the ``bbb`` embedding, as the expected values below show.
  """
  with ops.Graph().as_default():
    vocabulary_size = 3
    sparse_input_a = sparse_tensor.SparseTensorValue(
        # example 0, ids [2]
        # example 1, ids [0, 1]
        indices=((0, 0), (1, 0), (1, 1)),
        values=(2, 0, 1),
        dense_shape=(2, 2))
    sparse_input_b = sparse_tensor.SparseTensorValue(
        # example 0, ids [1]
        # example 1, ids [2, 0]
        indices=((0, 0), (1, 0), (1, 1)),
        values=(1, 2, 0),
        dense_shape=(2, 2))

    embedding_dimension = 2
    embedding_values = (
        (1., 2.),  # id 0
        (3., 4.),  # id 1
        (5., 6.)  # id 2
    )

    def _get_initializer(embedding_dimension, embedding_values):
      # Returns an initializer that asserts the requested table is an
      # unpartitioned float32 (vocabulary_size, embedding_dimension) tensor
      # and then returns the fixed `embedding_values`.

      def _initializer(shape, dtype, partition_info=None):
        self.assertAllEqual((vocabulary_size, embedding_dimension), shape)
        self.assertEqual(dtypes.float32, dtype)
        self.assertIsNone(partition_info)
        return embedding_values

      return _initializer

    # Per step: [aaa embedding (2 values), bbb embedding (2 values)];
    # missing steps are zero-filled.
    expected_input_layer = [
        # example 0, ids_a [2], ids_b [1]
        [[5., 6., 3., 4.], [0., 0., 0., 0.]],
        # example 1, ids_a [0, 1], ids_b [2, 0]
        [[1., 2., 5., 6.], [3., 4., 1., 2.]],
    ]
    expected_sequence_length = [1, 2]

    categorical_column_a = sfc.sequence_categorical_column_with_identity(
        key='aaa', num_buckets=vocabulary_size)
    categorical_column_b = sfc.sequence_categorical_column_with_identity(
        key='bbb', num_buckets=vocabulary_size)
    # Test that columns are reordered alphabetically.
    shared_embedding_columns = fc.shared_embedding_columns_v2(
        [categorical_column_b, categorical_column_a],
        dimension=embedding_dimension,
        initializer=_get_initializer(embedding_dimension, embedding_values))

    sequence_input_layer = ksfc.SequenceFeatures(
        shared_embedding_columns)
    input_layer, sequence_length = sequence_input_layer({
        'aaa': sparse_input_a, 'bbb': sparse_input_b
    })

    # Exactly one global variable is expected: the shared table. The name
    # asserted here depends on how the columns were constructed above.
    global_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
    self.assertCountEqual(
        ('aaa_bbb_shared_embedding:0', ),
        tuple([v.name for v in global_vars]))
    with _initialized_session() as sess:
      self.assertAllEqual(embedding_values,
                          global_vars[0].eval(session=sess))
      self.assertAllEqual(expected_input_layer,
                          input_layer.eval(session=sess))
      self.assertAllEqual(
          expected_sequence_length, sequence_length.eval(session=sess))
def test_get_sequence_dense_tensor(self):
  """Looks up two shared sequence embedding columns and checks the results.

  Both columns share one table (a single global variable is expected).
  Missing ids and padded steps show up as zero rows in the lookups.
  """
  vocabulary_size = 3
  embedding_dimension = 2
  embedding_values = (
      (1., 2.),  # id 0
      (3., 5.),  # id 1
      (7., 11.)  # id 2
  )

  def _initializer(shape, dtype, partition_info=None):
    # The shared table must be an unpartitioned float32
    # (vocabulary_size, embedding_dimension) tensor.
    self.assertAllEqual((vocabulary_size, embedding_dimension), shape)
    self.assertEqual(dtypes.float32, dtype)
    self.assertIsNone(partition_info)
    return embedding_values

  sparse_input_a = sparse_tensor.SparseTensorValue(
      # example 0, ids [2]
      # example 1, ids [0, 1]
      # example 2, ids []
      # example 3, ids [1]
      indices=((0, 0), (1, 0), (1, 1), (3, 0)),
      values=(2, 0, 1, 1),
      dense_shape=(4, 2))
  sparse_input_b = sparse_tensor.SparseTensorValue(
      # example 0, ids [1]
      # example 1, ids [0, 2]
      # example 2, ids [0]
      # example 3, ids []
      indices=((0, 0), (1, 0), (1, 1), (2, 0)),
      values=(1, 0, 2, 0),
      dense_shape=(4, 2))

  expected_lookups_a = [
      # example 0, ids [2]
      [[7., 11.], [0., 0.]],
      # example 1, ids [0, 1]
      [[1., 2.], [3., 5.]],
      # example 2, ids []
      [[0., 0.], [0., 0.]],
      # example 3, ids [1]
      [[3., 5.], [0., 0.]],
  ]
  expected_lookups_b = [
      # example 0, ids [1]
      [[3., 5.], [0., 0.]],
      # example 1, ids [0, 2]
      [[1., 2.], [7., 11.]],
      # example 2, ids [0]
      [[1., 2.], [0., 0.]],
      # example 3, ids []
      [[0., 0.], [0., 0.]],
  ]

  categorical_column_a = sfc.sequence_categorical_column_with_identity(
      key='aaa', num_buckets=vocabulary_size)
  categorical_column_b = sfc.sequence_categorical_column_with_identity(
      key='bbb', num_buckets=vocabulary_size)
  shared_embedding_columns = fc.shared_embedding_columns_v2(
      [categorical_column_a, categorical_column_b],
      dimension=embedding_dimension,
      initializer=_initializer)

  embedding_lookup_a = _get_sequence_dense_tensor(
      shared_embedding_columns[0], {'aaa': sparse_input_a})[0]
  embedding_lookup_b = _get_sequence_dense_tensor(
      shared_embedding_columns[1], {'bbb': sparse_input_b})[0]

  self.evaluate(variables_lib.global_variables_initializer())
  global_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
  # assertItemsEqual is the legacy Python 2 alias; use assertCountEqual
  # like the sibling tests.
  self.assertCountEqual(('aaa_bbb_shared_embedding:0', ),
                        tuple([v.name for v in global_vars]))
  self.assertAllEqual(embedding_values, self.evaluate(global_vars[0]))
  self.assertAllEqual(expected_lookups_a, self.evaluate(embedding_lookup_a))
  self.assertAllEqual(expected_lookups_b, self.evaluate(embedding_lookup_b))