    def test_indicator_column(self, sparse_input_args_a, sparse_input_args_b,
                              expected_input_layer, expected_sequence_length):
        sparse_input_a = sparse_tensor.SparseTensorValue(**sparse_input_args_a)
        sparse_input_b = sparse_tensor.SparseTensorValue(**sparse_input_args_b)

        vocabulary_size_a = 3
        vocabulary_size_b = 2

        categorical_column_a = sfc.sequence_categorical_column_with_identity(
            key='aaa', num_buckets=vocabulary_size_a)
        indicator_column_a = fc.indicator_column(categorical_column_a)
        categorical_column_b = sfc.sequence_categorical_column_with_identity(
            key='bbb', num_buckets=vocabulary_size_b)
        indicator_column_b = fc.indicator_column(categorical_column_b)
        # Test that columns are reordered alphabetically.
        sequence_input_layer = ksfc.SequenceFeatures(
            [indicator_column_b, indicator_column_a])
        input_layer, sequence_length = sequence_input_layer({
            'aaa': sparse_input_a,
            'bbb': sparse_input_b
        })

        self.assertAllEqual(expected_input_layer, self.evaluate(input_layer))
        self.assertAllEqual(expected_sequence_length,
                            self.evaluate(sequence_length))
Example #2
  def test_embedding_column(
      self, sparse_input_args_a, sparse_input_args_b, expected_input_layer,
      expected_sequence_length):

    sparse_input_a = sparse_tensor.SparseTensorValue(**sparse_input_args_a)
    sparse_input_b = sparse_tensor.SparseTensorValue(**sparse_input_args_b)
    vocabulary_size = 3
    embedding_dimension_a = 2
    embedding_values_a = (
        (1., 2.),  # id 0
        (3., 4.),  # id 1
        (5., 6.)  # id 2
    )
    embedding_dimension_b = 3
    embedding_values_b = (
        (11., 12., 13.),  # id 0
        (14., 15., 16.),  # id 1
        (17., 18., 19.)  # id 2
    )
    def _get_initializer(embedding_dimension, embedding_values):

      def _initializer(shape, dtype, partition_info=None):
        self.assertAllEqual((vocabulary_size, embedding_dimension), shape)
        self.assertEqual(dtypes.float32, dtype)
        self.assertIsNone(partition_info)
        return embedding_values
      return _initializer

    categorical_column_a = sfc.sequence_categorical_column_with_identity(
        key='aaa', num_buckets=vocabulary_size)
    embedding_column_a = fc.embedding_column(
        categorical_column_a,
        dimension=embedding_dimension_a,
        initializer=_get_initializer(embedding_dimension_a, embedding_values_a))
    categorical_column_b = sfc.sequence_categorical_column_with_identity(
        key='bbb', num_buckets=vocabulary_size)
    embedding_column_b = fc.embedding_column(
        categorical_column_b,
        dimension=embedding_dimension_b,
        initializer=_get_initializer(embedding_dimension_b, embedding_values_b))

    # Test that columns are reordered alphabetically.
    sequence_input_layer = ksfc.SequenceFeatures(
        [embedding_column_b, embedding_column_a])
    input_layer, sequence_length = sequence_input_layer({
        'aaa': sparse_input_a,
        'bbb': sparse_input_b
    })

    self.evaluate(variables_lib.global_variables_initializer())
    weights = sequence_input_layer.weights
    self.assertCountEqual(
        ('sequence_features/aaa_embedding/embedding_weights:0',
         'sequence_features/bbb_embedding/embedding_weights:0'),
        tuple([v.name for v in weights]))
    self.assertAllEqual(embedding_values_a, self.evaluate(weights[0]))
    self.assertAllEqual(embedding_values_b, self.evaluate(weights[1]))
    self.assertAllEqual(expected_input_layer, self.evaluate(input_layer))
    self.assertAllEqual(
        expected_sequence_length, self.evaluate(sequence_length))
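A minimal sketch of the same fixed-initializer trick outside the test harness, assuming TF 2.x public APIs (the table values and feature name are illustrative):

import tensorflow as tf

embedding_values = [[1., 2.], [3., 4.], [5., 6.]]  # row i is the vector for id i

def constant_initializer(shape, dtype=None, partition_info=None):
    # The framework asks for a (num_buckets, dimension) == (3, 2) table;
    # returning a constant makes every lookup below predictable.
    del shape, dtype, partition_info
    return tf.constant(embedding_values)

column = tf.feature_column.embedding_column(
    tf.feature_column.sequence_categorical_column_with_identity('ids', num_buckets=3),
    dimension=2, initializer=constant_initializer)

ids = tf.sparse.SparseTensor(indices=[[0, 0], [0, 1]], values=[2, 0], dense_shape=[1, 2])
dense, seq_len = tf.keras.experimental.SequenceFeatures([column])({'ids': ids})
print(dense.numpy())    # [[[5. 6.] [1. 2.]]] -- rows 2 and 0 of the table
print(seq_len.numpy())  # [2]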
Example #3
    def test_get_sequence_dense_tensor(self, inputs_args, expected):
        inputs = sparse_tensor.SparseTensorValue(**inputs_args)
        vocabulary_size = 3
        embedding_dimension = 2
        embedding_values = (
            (1., 2.),  # id 0
            (3., 5.),  # id 1
            (7., 11.)  # id 2
        )

        def _initializer(shape, dtype, partition_info=None):
            self.assertAllEqual((vocabulary_size, embedding_dimension), shape)
            self.assertEqual(dtypes.float32, dtype)
            self.assertIsNone(partition_info)
            return embedding_values

        categorical_column = sfc.sequence_categorical_column_with_identity(
            key='aaa', num_buckets=vocabulary_size)
        embedding_column = fc.embedding_column(categorical_column,
                                               dimension=embedding_dimension,
                                               initializer=_initializer)

        embedding_lookup, _, state_manager = _get_sequence_dense_tensor_state(
            embedding_column, {'aaa': inputs})

        variables = state_manager._layer.weights
        self.evaluate(variables_lib.global_variables_initializer())
        self.assertCountEqual(('embedding_weights:0', ),
                              tuple([v.name for v in variables]))
        self.assertAllEqual(embedding_values, self.evaluate(variables[0]))
        self.assertAllEqual(expected, self.evaluate(embedding_lookup))
    def test_shared_sequence_non_sequence_into_input_layer(self):
        non_seq = fc.categorical_column_with_identity('non_seq',
                                                      num_buckets=10)
        seq = sfc.sequence_categorical_column_with_identity('seq',
                                                            num_buckets=10)
        shared_non_seq, shared_seq = fc.shared_embedding_columns_v2(
            [non_seq, seq],
            dimension=4,
            combiner='sum',
            initializer=init_ops_v2.Ones(),
            shared_embedding_collection_name='shared')

        seq = sparse_tensor.SparseTensor(indices=[[0, 0], [0, 1], [1, 0]],
                                         values=[0, 1, 2],
                                         dense_shape=[2, 2])
        non_seq = sparse_tensor.SparseTensor(indices=[[0, 0], [0, 1], [1, 0]],
                                             values=[0, 1, 2],
                                             dense_shape=[2, 2])
        features = {'seq': seq, 'non_seq': non_seq}

        # Tile the context features across the sequence features
        seq_input, seq_length = ksfc.SequenceFeatures([shared_seq])(features)
        non_seq_input = dense_features.DenseFeatures([shared_non_seq])(features)

        with self.cached_session() as sess:
            sess.run(variables.global_variables_initializer())
            output_seq, output_seq_length, output_non_seq = sess.run(
                [seq_input, seq_length, non_seq_input])
            self.assertAllEqual(
                output_seq,
                [[[1, 1, 1, 1], [1, 1, 1, 1]], [[1, 1, 1, 1], [0, 0, 0, 0]]])
            self.assertAllEqual(output_seq_length, [2, 1])
            self.assertAllEqual(output_non_seq, [[2, 2, 2, 2], [1, 1, 1, 1]])
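The expected numbers follow from the all-ones initializer rather than from anything sequence-specific; a plain-NumPy restatement of that arithmetic (not the TF code path) makes it explicit:

import numpy as np

ones = np.ones(4)                 # every id embeds to [1, 1, 1, 1]
seq_example_0 = [ones, ones]      # two steps present -> both rows are ones
seq_example_1 = [ones, 0 * ones]  # one step present  -> second row is padding
non_seq_example_0 = ones + ones   # 'sum' combiner over two ids -> [2, 2, 2, 2]
non_seq_example_1 = ones          # a single id                 -> [1, 1, 1, 1]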
Example #5
  def test_sequence_length_with_empty_rows(self):
    """Tests _sequence_length when some examples do not have ids."""
    vocabulary_size = 3
    sparse_input = sparse_tensor.SparseTensorValue(
        # example 0, ids []
        # example 1, ids [2]
        # example 2, ids [0, 1]
        # example 3, ids []
        # example 4, ids [1]
        # example 5, ids []
        indices=((1, 0), (2, 0), (2, 1), (4, 0)),
        values=(2, 0, 1, 1),
        dense_shape=(6, 2))
    expected_sequence_length = [0, 1, 2, 0, 1, 0]

    categorical_column = sfc.sequence_categorical_column_with_identity(
        key='aaa', num_buckets=vocabulary_size)
    embedding_column = fc.embedding_column(
        categorical_column, dimension=2)

    _, sequence_length, _ = _get_sequence_dense_tensor_state(
        embedding_column, {'aaa': sparse_input})

    self.assertAllEqual(
        expected_sequence_length, self.evaluate(sequence_length))
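For a rank-2 sparse input, the sequence length is simply how many steps are present in each row, with empty rows getting length 0. A quick standalone check of the [0, 1, 2, 0, 1, 0] expectation above (TF 2.x; this is a restatement of the idea, not the column's internal code path):

import tensorflow as tf

sp = tf.sparse.SparseTensor(indices=[[1, 0], [2, 0], [2, 1], [4, 0]],
                            values=[2, 0, 1, 1], dense_shape=[6, 2])
row_ids = tf.cast(sp.indices[:, 0], tf.int32)
seq_len = tf.math.bincount(row_ids, minlength=6)  # rows with no ids count as 0
print(seq_len.numpy())  # [0 1 2 0 1 0]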
Example #6
    def test_serialization(self):
        """Tests that column can be serialized."""
        parent = sfc.sequence_categorical_column_with_identity('animal',
                                                               num_buckets=4)
        animal = fc.indicator_column(parent)

        config = animal.get_config()
        self.assertEqual(
            {
                'categorical_column': {
                    'class_name': 'SequenceCategoricalColumn',
                    'config': {
                        'categorical_column': {
                            'class_name': 'IdentityCategoricalColumn',
                            'config': {
                                'default_value': None,
                                'key': 'animal',
                                'number_buckets': 4
                            }
                        }
                    }
                }
            }, config)

        new_animal = fc.IndicatorColumn.from_config(config)
        self.assertEqual(animal, new_animal)
        self.assertIsNot(parent, new_animal.categorical_column)

        new_animal = fc.IndicatorColumn.from_config(
            config,
            columns_by_name={
                serialization._column_name_with_class_name(parent): parent
            })
        self.assertEqual(animal, new_animal)
        self.assertIs(parent, new_animal.categorical_column)
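A shorter view of the same round trip, assuming the TF 2.x column objects shown above (get_config/from_config are exactly what the test exercises; the precise config keys are internal details that may change between releases):

import tensorflow as tf

parent = tf.feature_column.sequence_categorical_column_with_identity('animal', num_buckets=4)
animal = tf.feature_column.indicator_column(parent)

config = animal.get_config()                       # nested dict describing both columns
restored = type(animal).from_config(config)
assert restored == animal                          # equal by value...
assert restored.categorical_column is not parent   # ...but rebuilt, not aliased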
Example #7
    def test_sequence_length_with_empty_rows(self):
        """Tests _sequence_length when some examples do not have ids."""
        with ops.Graph().as_default():
            vocabulary_size = 3
            sparse_input_a = sparse_tensor.SparseTensorValue(
                # example 0, ids []
                # example 1, ids [2]
                # example 2, ids [0, 1]
                # example 3, ids []
                # example 4, ids [1]
                # example 5, ids []
                indices=((1, 0), (2, 0), (2, 1), (4, 0)),
                values=(2, 0, 1, 1),
                dense_shape=(6, 2))
            expected_sequence_length_a = [0, 1, 2, 0, 1, 0]
            categorical_column_a = sfc.sequence_categorical_column_with_identity(
                key='aaa', num_buckets=vocabulary_size)

            sparse_input_b = sparse_tensor.SparseTensorValue(
                # example 0, ids [2]
                # example 1, ids []
                # example 2, ids []
                # example 3, ids []
                # example 4, ids [1]
                # example 5, ids [0, 1]
                indices=((0, 0), (4, 0), (5, 0), (5, 1)),
                values=(2, 1, 0, 1),
                dense_shape=(6, 2))
            expected_sequence_length_b = [1, 0, 0, 0, 1, 2]
            categorical_column_b = sfc.sequence_categorical_column_with_identity(
                key='bbb', num_buckets=vocabulary_size)

            shared_embedding_columns = fc.shared_embedding_columns_v2(
                [categorical_column_a, categorical_column_b], dimension=2)

            sequence_length_a = _get_sequence_dense_tensor(
                shared_embedding_columns[0], {'aaa': sparse_input_a})[1]
            sequence_length_b = _get_sequence_dense_tensor(
                shared_embedding_columns[1], {'bbb': sparse_input_b})[1]

            with _initialized_session() as sess:
                self.assertAllEqual(expected_sequence_length_a,
                                    sequence_length_a.eval(session=sess))
                self.assertAllEqual(expected_sequence_length_b,
                                    sequence_length_b.eval(session=sess))
Example #8
  def test_get_sparse_tensors(self, inputs_args, expected_args):
    inputs = sparse_tensor.SparseTensorValue(**inputs_args)
    expected = sparse_tensor.SparseTensorValue(**expected_args)
    column = sfc.sequence_categorical_column_with_identity('aaa', num_buckets=9)

    id_weight_pair = _get_sparse_tensors(column, {'aaa': inputs})

    self.assertIsNone(id_weight_pair.weight_tensor)
    _assert_sparse_tensor_value(
        self, expected, self.evaluate(id_weight_pair.id_tensor))
    def test_static_shape_from_tensors_indicator(self, sparse_input_args,
                                                 expected_shape):
        """Tests that we return a known static shape when we have one."""
        sparse_input = sparse_tensor.SparseTensorValue(**sparse_input_args)
        categorical_column = sfc.sequence_categorical_column_with_identity(
            key='aaa', num_buckets=3)
        indicator_column = fc.indicator_column(categorical_column)

        sequence_input_layer = ksfc.SequenceFeatures([indicator_column])
        input_layer, _ = sequence_input_layer({'aaa': sparse_input})
        shape = input_layer.get_shape()
        self.assertEqual(shape, expected_shape)
Example #10
    def test_get_sequence_dense_tensor(self, inputs_args, expected):
        inputs = sparse_tensor.SparseTensorValue(**inputs_args)
        vocabulary_size = 3

        categorical_column = sfc.sequence_categorical_column_with_identity(
            key='aaa', num_buckets=vocabulary_size)
        indicator_column = fc.indicator_column(categorical_column)

        indicator_tensor, _ = _get_sequence_dense_tensor(
            indicator_column, {'aaa': inputs})

        self.assertAllEqual(expected, self.evaluate(indicator_tensor))
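A concrete instance of what the parameterized `expected` tensor looks like: a sequence indicator column yields a [batch, max_steps, num_buckets] one-hot tensor, with all-zero rows for padded steps (TF 2.x public APIs; the input values are illustrative):

import tensorflow as tf

ids = tf.sparse.SparseTensor(indices=[[0, 0], [1, 0], [1, 1]],
                             values=[2, 0, 1], dense_shape=[2, 2])
column = tf.feature_column.indicator_column(
    tf.feature_column.sequence_categorical_column_with_identity('aaa', num_buckets=3))
dense, seq_len = tf.keras.experimental.SequenceFeatures([column])({'aaa': ids})
print(dense.numpy())
# [[[0. 0. 1.]    example 0, step 0: id 2
#   [0. 0. 0.]]   example 0, step 1: padding
#  [[1. 0. 0.]    example 1, step 0: id 0
#   [0. 1. 0.]]]  example 1, step 1: id 1
print(seq_len.numpy())  # [1 2]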
Example #11
    def test_sequence_length(self):
        with ops.Graph().as_default():
            vocabulary_size = 3

            sparse_input_a = sparse_tensor.SparseTensorValue(
                # example 0, ids [2]
                # example 1, ids [0, 1]
                indices=((0, 0), (1, 0), (1, 1)),
                values=(2, 0, 1),
                dense_shape=(2, 2))
            expected_sequence_length_a = [1, 2]
            categorical_column_a = sfc.sequence_categorical_column_with_identity(
                key='aaa', num_buckets=vocabulary_size)

            sparse_input_b = sparse_tensor.SparseTensorValue(
                # example 0, ids [0, 2]
                # example 1, ids [1]
                indices=((0, 0), (0, 1), (1, 0)),
                values=(0, 2, 1),
                dense_shape=(2, 2))
            expected_sequence_length_b = [2, 1]
            categorical_column_b = sfc.sequence_categorical_column_with_identity(
                key='bbb', num_buckets=vocabulary_size)
            shared_embedding_columns = fc.shared_embedding_columns_v2(
                [categorical_column_a, categorical_column_b], dimension=2)

            sequence_length_a = _get_sequence_dense_tensor(
                shared_embedding_columns[0], {'aaa': sparse_input_a})[1]
            sequence_length_b = _get_sequence_dense_tensor(
                shared_embedding_columns[1], {'bbb': sparse_input_b})[1]

            with _initialized_session() as sess:
                sequence_length_a = sess.run(sequence_length_a)
                self.assertAllEqual(expected_sequence_length_a,
                                    sequence_length_a)
                self.assertEqual(np.int64, sequence_length_a.dtype)
                sequence_length_b = sess.run(sequence_length_b)
                self.assertAllEqual(expected_sequence_length_b,
                                    sequence_length_b)
                self.assertEqual(np.int64, sequence_length_b.dtype)
Example #12
    def test_sequence_length(self, inputs_args, expected_sequence_length):
        inputs = sparse_tensor.SparseTensorValue(**inputs_args)
        vocabulary_size = 3

        categorical_column = sfc.sequence_categorical_column_with_identity(
            key='aaa', num_buckets=vocabulary_size)
        indicator_column = fc.indicator_column(categorical_column)

        _, sequence_length = _get_sequence_dense_tensor(
            indicator_column, {'aaa': inputs})

        sequence_length = self.evaluate(sequence_length)
        self.assertAllEqual(expected_sequence_length, sequence_length)
        self.assertEqual(np.int64, sequence_length.dtype)
  def _build_feature_columns(self):
    col = fc.categorical_column_with_identity('int_ctx', num_buckets=100)
    ctx_cols = [
        fc.embedding_column(col, dimension=10),
        fc.numeric_column('float_ctx')
    ]

    identity_col = sfc.sequence_categorical_column_with_identity(
        'int_list', num_buckets=10)
    bucket_col = sfc.sequence_categorical_column_with_hash_bucket(
        'bytes_list', hash_bucket_size=100)
    seq_cols = [
        fc.embedding_column(identity_col, dimension=10),
        fc.embedding_column(bucket_col, dimension=20)
    ]

    return ctx_cols, seq_cols
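A hedged sketch of how such column groups are typically consumed: the context columns go through a DenseFeatures layer and the sequence columns through SequenceFeatures (TF 2.x public APIs; the feature values below are made up):

import tensorflow as tf
fc = tf.feature_column

ctx_cols = [fc.embedding_column(fc.categorical_column_with_identity('int_ctx', num_buckets=100),
                                dimension=10),
            fc.numeric_column('float_ctx')]
seq_cols = [fc.embedding_column(fc.sequence_categorical_column_with_identity('int_list', num_buckets=10),
                                dimension=10),
            fc.embedding_column(fc.sequence_categorical_column_with_hash_bucket('bytes_list', hash_bucket_size=100),
                                dimension=20)]

features = {
    'int_ctx': tf.sparse.SparseTensor(indices=[[0, 0]], values=[5], dense_shape=[1, 1]),
    'float_ctx': tf.constant([[1.5]]),
    'int_list': tf.sparse.SparseTensor(indices=[[0, 0], [0, 1]], values=[3, 7], dense_shape=[1, 2]),
    'bytes_list': tf.sparse.SparseTensor(indices=[[0, 0], [0, 1]], values=[b'a', b'b'], dense_shape=[1, 2]),
}
context = tf.keras.layers.DenseFeatures(ctx_cols)(features)                      # shape (1, 11)
sequence, seq_len = tf.keras.experimental.SequenceFeatures(seq_cols)(features)   # shape (1, 2, 30)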
Example #15
    def test_indicator_column(self):
        """Tests that error is raised for sequence indicator column."""
        vocabulary_size = 3
        sparse_input = sparse_tensor.SparseTensorValue(
            # example 0, ids [2]
            # example 1, ids [0, 1]
            indices=((0, 0), (1, 0), (1, 1)),
            values=(2, 0, 1),
            dense_shape=(2, 2))

        categorical_column_a = sfc.sequence_categorical_column_with_identity(
            key='aaa', num_buckets=vocabulary_size)
        indicator_column_a = fc.indicator_column(categorical_column_a)

        input_layer = df.DenseFeatures([indicator_column_a])
        with self.assertRaisesRegex(
                ValueError,
                r'In indicator_column: aaa_indicator\. categorical_column must not be '
                r'of type SequenceCategoricalColumn\.'):
            _ = input_layer({'aaa': sparse_input})
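The fix the error message points at is to route sequence columns through SequenceFeatures instead of DenseFeatures; a minimal sketch with TF 2.x public APIs:

import tensorflow as tf

column = tf.feature_column.indicator_column(
    tf.feature_column.sequence_categorical_column_with_identity('aaa', num_buckets=3))
ids = tf.sparse.SparseTensor(indices=[[0, 0], [1, 0], [1, 1]],
                             values=[2, 0, 1], dense_shape=[2, 2])

dense, seq_len = tf.keras.experimental.SequenceFeatures([column])({'aaa': ids})  # works
# tf.keras.layers.DenseFeatures([column])({'aaa': ids})  # raises the ValueError above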
Example #16
def _get_sequence_categorical_column(params: dict) -> fc.SequenceCategoricalColumn:
    key = params['key']
    if 'vocabulary' in params:
        feature = sfc.sequence_categorical_column_with_vocabulary_list(
            key,
            vocabulary_list=_parse_vocabulary(params['vocabulary']),
            default_value=0)
    elif 'bucket_size' in params:
        feature = sfc.sequence_categorical_column_with_hash_bucket(
            key, hash_bucket_size=params['bucket_size'])
    elif 'file' in params:
        feature = sfc.sequence_categorical_column_with_vocabulary_file(
            key, vocabulary_file=params['file'], default_value=0)
    elif 'num_buckets' in params:
        feature = sfc.sequence_categorical_column_with_identity(
            key, num_buckets=params['num_buckets'])
    else:
        raise ValueError(
            "params must contain one of 'vocabulary', 'bucket_size', 'file' or "
            "'num_buckets'; got keys: %s" % sorted(params))

    return feature
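Hypothetical usage of the helper above (the params schema is inferred from the branches; the key names and bucket counts are made up, and _parse_vocabulary is assumed to be defined alongside it):

# identity branch: no vocabulary parsing needed
item_column = _get_sequence_categorical_column({'key': 'item_id', 'num_buckets': 1000})
item_embedding = fc.embedding_column(item_column, dimension=16)

# hash-bucket branch for free-form string sequences
query_column = _get_sequence_categorical_column({'key': 'query_token', 'bucket_size': 5000})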
Example #17
def make_feature_config(num_players):
    return FeatureConfig(
        context_features=[
            fc.numeric_column(
                "public_context__starting_stack_sizes",
                shape=num_players,
                dtype=tf.int64,
            ),
            fc.embedding_column(
                tf.feature_column.categorical_column_with_vocabulary_list(
                    "private_context__hand_encoded", range(1326)),
                dimension=4,
            ),
        ],
        sequence_features=[
            fc.indicator_column(
                sfc.sequence_categorical_column_with_identity(
                    "last_action__action_encoded", 22)),
            fc.indicator_column(
                sfc.sequence_categorical_column_with_identity(
                    "last_action__move", 5)),
            sfc.sequence_numeric_column(
                "last_action__amount_added",
                dtype=tf.int64,
                default_value=-1,
                normalizer_fn=make_float,
            ),
            sfc.sequence_numeric_column(
                "last_action__amount_added_percent_of_remaining",
                dtype=tf.float32,
                default_value=-1,
                normalizer_fn=make_float,
            ),
            sfc.sequence_numeric_column(
                "last_action__amount_raised",
                dtype=tf.int64,
                default_value=-1,
                normalizer_fn=make_float,
            ),
            sfc.sequence_numeric_column(
                "last_action__amount_raised_percent_of_pot",
                dtype=tf.float32,
                default_value=-1,
                normalizer_fn=make_float,
            ),
            sfc.sequence_numeric_column(
                "public_state__all_in_player_mask",
                dtype=tf.int64,
                default_value=-1,
                shape=num_players,
                normalizer_fn=make_float,
            ),
            sfc.sequence_numeric_column(
                "public_state__stack_sizes",
                dtype=tf.int64,
                default_value=-1,
                shape=num_players,
                normalizer_fn=make_float,
            ),
            sfc.sequence_numeric_column(
                "public_state__amount_to_call",
                dtype=tf.int64,
                default_value=-1,
                shape=num_players,
                normalizer_fn=make_float,
            ),
            sfc.sequence_numeric_column(
                "public_state__current_player_mask",
                dtype=tf.int64,
                default_value=-1,
                shape=num_players,
                normalizer_fn=make_float,
            ),
            sfc.sequence_numeric_column(
                "public_state__min_raise_amount",
                dtype=tf.int64,
                default_value=-1,
                shape=1,
                normalizer_fn=make_float,
            ),
            sfc.sequence_numeric_column(
                "public_state__pot_size",
                dtype=tf.int64,
                default_value=-1,
                shape=1,
                normalizer_fn=make_float,
            ),
            sfc.sequence_numeric_column(
                "public_state__street",
                dtype=tf.int64,
                default_value=-1,
                shape=1,
                normalizer_fn=make_float,
            ),
            sfc.sequence_numeric_column(
                "player_state__is_current_player",
                dtype=tf.int64,
                default_value=-1,
                shape=1,
                normalizer_fn=make_float,
            ),
            sfc.sequence_numeric_column(
                "player_state__current_player_offset",
                dtype=tf.int64,
                default_value=-1,
                shape=1,
                normalizer_fn=make_float,
            ),
            fc.indicator_column(
                sfc.sequence_categorical_column_with_identity(
                    "player_state__current_hand_type", 9)),
            sfc.sequence_numeric_column(
                "player_state__win_odds",
                dtype=tf.float32,
                default_value=-1,
                shape=1,
                normalizer_fn=make_float,
            ),
            sfc.sequence_numeric_column(
                "player_state__win_odds_vs_better",
                dtype=tf.float32,
                default_value=-1,
                shape=1,
                normalizer_fn=make_float,
            ),
            sfc.sequence_numeric_column(
                "player_state__win_odds_vs_tied",
                dtype=tf.float32,
                default_value=-1,
                shape=1,
                normalizer_fn=make_float,
            ),
            sfc.sequence_numeric_column(
                "player_state__win_odds_vs_worse",
                dtype=tf.float32,
                default_value=-1,
                shape=1,
                normalizer_fn=make_float,
            ),
            sfc.sequence_numeric_column(
                "player_state__frac_better_hands",
                dtype=tf.float32,
                default_value=-1,
                shape=1,
                normalizer_fn=make_float,
            ),
            sfc.sequence_numeric_column(
                "player_state__frac_tied_hands",
                dtype=tf.float32,
                default_value=-1,
                shape=1,
                normalizer_fn=make_float,
            ),
            sfc.sequence_numeric_column(
                "player_state__frac_worse_hands",
                dtype=tf.float32,
                default_value=-1,
                shape=1,
                normalizer_fn=make_float,
            ),
        ],
        context_targets=[
            fc.numeric_column("public_context__num_players",
                              shape=1,
                              dtype=tf.int64),
        ],
        sequence_targets=[
            sfc.sequence_numeric_column("next_action__action_encoded",
                                        dtype=tf.int64,
                                        default_value=-1),
            sfc.sequence_numeric_column("reward__cumulative_reward",
                                        dtype=tf.int64,
                                        default_value=-1),
            sfc.sequence_numeric_column("public_state__pot_size",
                                        dtype=tf.int64,
                                        default_value=-1),
            sfc.sequence_numeric_column("player_state__is_current_player",
                                        dtype=tf.int64,
                                        default_value=-1),
            sfc.sequence_numeric_column("public_state__num_players_remaining",
                                        dtype=tf.int64,
                                        default_value=-1),
        ],
    )
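FeatureConfig and make_float come from elsewhere in that project and are not shown here. One plausible reading of make_float, given that it is attached as normalizer_fn to integer-typed sequence_numeric_columns, is a simple cast to float32; a stand-in under that assumption only:

import tensorflow as tf

def make_float(x):
    # Assumed behaviour: cast the parsed (mostly int64) sequence values to
    # float32 so they can be combined with the float-typed columns downstream.
    return tf.cast(x, tf.float32)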
    def test_shared_embedding_column(self):
        with ops.Graph().as_default():
            vocabulary_size = 3
            sparse_input_a = sparse_tensor.SparseTensorValue(
                # example 0, ids [2]
                # example 1, ids [0, 1]
                indices=((0, 0), (1, 0), (1, 1)),
                values=(2, 0, 1),
                dense_shape=(2, 2))
            sparse_input_b = sparse_tensor.SparseTensorValue(
                # example 0, ids [1]
                # example 1, ids [2, 0]
                indices=((0, 0), (1, 0), (1, 1)),
                values=(1, 2, 0),
                dense_shape=(2, 2))

            embedding_dimension = 2
            embedding_values = (
                (1., 2.),  # id 0
                (3., 4.),  # id 1
                (5., 6.)  # id 2
            )

            def _get_initializer(embedding_dimension, embedding_values):
                def _initializer(shape, dtype, partition_info=None):
                    self.assertAllEqual((vocabulary_size, embedding_dimension),
                                        shape)
                    self.assertEqual(dtypes.float32, dtype)
                    self.assertIsNone(partition_info)
                    return embedding_values

                return _initializer

            expected_input_layer = [
                # example 0, ids_a [2], ids_b [1]
                [[5., 6., 3., 4.], [0., 0., 0., 0.]],
                # example 1, ids_a [0, 1], ids_b [2, 0]
                [[1., 2., 5., 6.], [3., 4., 1., 2.]],
            ]
            expected_sequence_length = [1, 2]

            categorical_column_a = sfc.sequence_categorical_column_with_identity(
                key='aaa', num_buckets=vocabulary_size)
            categorical_column_b = sfc.sequence_categorical_column_with_identity(
                key='bbb', num_buckets=vocabulary_size)
            # Test that columns are reordered alphabetically.
            shared_embedding_columns = fc.shared_embedding_columns_v2(
                [categorical_column_b, categorical_column_a],
                dimension=embedding_dimension,
                initializer=_get_initializer(embedding_dimension,
                                             embedding_values))

            sequence_input_layer = ksfc.SequenceFeatures(
                shared_embedding_columns)
            input_layer, sequence_length = sequence_input_layer({
                'aaa': sparse_input_a,
                'bbb': sparse_input_b
            })

            global_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
            self.assertCountEqual(('aaa_bbb_shared_embedding:0', ),
                                  tuple([v.name for v in global_vars]))
            with _initialized_session() as sess:
                self.assertAllEqual(embedding_values,
                                    global_vars[0].eval(session=sess))
                self.assertAllEqual(expected_input_layer,
                                    input_layer.eval(session=sess))
                self.assertAllEqual(expected_sequence_length,
                                    sequence_length.eval(session=sess))
Example #19
    def test_get_sequence_dense_tensor(self):
        vocabulary_size = 3
        embedding_dimension = 2
        embedding_values = (
            (1., 2.),  # id 0
            (3., 5.),  # id 1
            (7., 11.)  # id 2
        )

        def _initializer(shape, dtype, partition_info=None):
            self.assertAllEqual((vocabulary_size, embedding_dimension), shape)
            self.assertEqual(dtypes.float32, dtype)
            self.assertIsNone(partition_info)
            return embedding_values

        sparse_input_a = sparse_tensor.SparseTensorValue(
            # example 0, ids [2]
            # example 1, ids [0, 1]
            # example 2, ids []
            # example 3, ids [1]
            indices=((0, 0), (1, 0), (1, 1), (3, 0)),
            values=(2, 0, 1, 1),
            dense_shape=(4, 2))
        sparse_input_b = sparse_tensor.SparseTensorValue(
            # example 0, ids [1]
            # example 1, ids [0, 2]
            # example 2, ids [0]
            # example 3, ids []
            indices=((0, 0), (1, 0), (1, 1), (2, 0)),
            values=(1, 0, 2, 0),
            dense_shape=(4, 2))

        expected_lookups_a = [
            # example 0, ids [2]
            [[7., 11.], [0., 0.]],
            # example 1, ids [0, 1]
            [[1., 2.], [3., 5.]],
            # example 2, ids []
            [[0., 0.], [0., 0.]],
            # example 3, ids [1]
            [[3., 5.], [0., 0.]],
        ]

        expected_lookups_b = [
            # example 0, ids [1]
            [[3., 5.], [0., 0.]],
            # example 1, ids [0, 2]
            [[1., 2.], [7., 11.]],
            # example 2, ids [0]
            [[1., 2.], [0., 0.]],
            # example 3, ids []
            [[0., 0.], [0., 0.]],
        ]

        categorical_column_a = sfc.sequence_categorical_column_with_identity(
            key='aaa', num_buckets=vocabulary_size)
        categorical_column_b = sfc.sequence_categorical_column_with_identity(
            key='bbb', num_buckets=vocabulary_size)
        shared_embedding_columns = fc.shared_embedding_columns_v2(
            [categorical_column_a, categorical_column_b],
            dimension=embedding_dimension,
            initializer=_initializer)

        embedding_lookup_a = _get_sequence_dense_tensor(
            shared_embedding_columns[0], {'aaa': sparse_input_a})[0]
        embedding_lookup_b = _get_sequence_dense_tensor(
            shared_embedding_columns[1], {'bbb': sparse_input_b})[0]

        self.evaluate(variables_lib.global_variables_initializer())
        global_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
        self.assertCountEqual(('aaa_bbb_shared_embedding:0', ),
                              tuple([v.name for v in global_vars]))
        self.assertAllEqual(embedding_values, self.evaluate(global_vars[0]))
        self.assertAllEqual(expected_lookups_a,
                            self.evaluate(embedding_lookup_a))
        self.assertAllEqual(expected_lookups_b,
                            self.evaluate(embedding_lookup_b))
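Finally, the typical downstream use of these sequence columns: the dense tensor and sequence length produced by SequenceFeatures feed an RNN through a mask (TF 2.x public APIs; the column, feature name, and sizes below are illustrative):

import tensorflow as tf

column = tf.feature_column.embedding_column(
    tf.feature_column.sequence_categorical_column_with_identity('ids', num_buckets=3),
    dimension=2)
sequence_features = tf.keras.experimental.SequenceFeatures([column])

ids = tf.sparse.SparseTensor(indices=[[0, 0], [1, 0], [1, 1]],
                             values=[2, 0, 1], dense_shape=[2, 2])
sequence_input, sequence_length = sequence_features({'ids': ids})

mask = tf.sequence_mask(sequence_length)          # (2, 2), True where a step exists
outputs = tf.keras.layers.LSTM(8)(sequence_input, mask=mask)
print(outputs.shape)  # (2, 8)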