예제 #1
0
    def test_accumulate_client_votes_works_as_expected(self):
        """Checks that client votes are only counted for discovered prefixes.

        With ``round_num`` and ``num_sub_rounds`` both 1, client string 'ab'
        adds one vote for extension 'b' of prefix 'a', while client string
        'ea' (prefix 'e' is not a discovered prefix) leaves the votes as-is.
        """
        extension_values = [
            'a', 'b', 'c', 'd', 'e', triehh_tf.DEFAULT_TERMINATOR
        ]
        possible_prefix_extensions = tf.constant(extension_values,
                                                 dtype=tf.string)
        discovered_prefixes = tf.constant(['a', 'b', 'c', 'd'],
                                          dtype=tf.string)
        round_num = tf.constant(1)
        num_sub_rounds = tf.constant(1)
        example1 = tf.constant('ab', dtype=tf.string)

        # Each table maps a key to its position in the list it was built
        # from; keys that are absent map to triehh_tf.DEFAULT_VALUE.
        discovered_prefixes_table = tf.lookup.StaticHashTable(
            tf.lookup.KeyValueTensorInitializer(
                discovered_prefixes,
                tf.range(tf.shape(discovered_prefixes)[0])),
            triehh_tf.DEFAULT_VALUE)

        possible_prefix_extensions_table = tf.lookup.StaticHashTable(
            tf.lookup.KeyValueTensorInitializer(
                possible_prefix_extensions,
                tf.range(tf.shape(possible_prefix_extensions)[0])),
            triehh_tf.DEFAULT_VALUE)

        accumulate_client_votes = triehh_tf.make_accumulate_client_votes_fn(
            round_num, num_sub_rounds, discovered_prefixes_table,
            possible_prefix_extensions_table,
            tf.constant(triehh_tf.DEFAULT_TERMINATOR, dtype=tf.string))

        # Vote matrix. Judging by the expected result below, the row follows
        # the discovered-prefix index and the column the extension index.
        # The column count is taken from ``extension_values`` so every
        # registered extension (including 'e' and the terminator) has a
        # column; a hard-coded width of 5 would leave the last two without one.
        num_rows = 10
        num_columns = len(extension_values)
        initial_rows = [[0] * num_columns for _ in range(num_rows)]
        initial_rows[0][:3] = [1, 2, 1]
        initial_rows[1][0] = 1
        initial_votes = tf.constant(initial_rows, dtype=tf.int32)

        accumulated_votes = accumulate_client_votes(initial_votes, example1)

        # 'ab' -> prefix 'a' (index 0), extension 'b' (index 1): one more vote.
        expected_rows = [list(row) for row in initial_rows]
        expected_rows[0][1] += 1
        expected_accumulated_votes = tf.constant(expected_rows, dtype=tf.int32)

        self.assertAllEqual(accumulated_votes, expected_accumulated_votes)

        # An example whose prefix is not among the discovered prefixes.
        # The expected result is that the vote is not counted.
        example2 = tf.constant('ea', dtype=tf.string)
        accumulated_votes = accumulate_client_votes(initial_votes, example2)
        self.assertAllEqual(accumulated_votes, initial_votes)
예제 #2
0
    def test_accumulate_client_votes_works_as_expected(self):
        """A client string 'ab' adds one vote for extension 'b' of prefix 'a'.

        Both the round number and the number of sub-rounds are 1, so only the
        first character acts as the prefix and the second as its extension.
        """

        def build_index_table(keys):
            # Maps each key to its position in ``keys``; keys that are not
            # present look up to triehh_tf.DEFAULT_VALUE.
            positions = tf.range(tf.shape(keys)[0])
            initializer = tf.lookup.KeyValueTensorInitializer(keys, positions)
            return tf.lookup.StaticHashTable(initializer,
                                             triehh_tf.DEFAULT_VALUE)

        extensions = tf.constant(['a', 'b', 'c', 'd', '$'], dtype=tf.string)
        prefixes = tf.constant(['a', 'b', 'c', 'd'], dtype=tf.string)
        client_string = tf.constant('ab', dtype=tf.string)

        prefixes_table = build_index_table(prefixes)
        extensions_table = build_index_table(extensions)

        vote_fn = triehh_tf.make_accumulate_client_votes_fn(
            tf.constant(1), tf.constant(1), prefixes_table, extensions_table)

        # Ten rows of five counters; only the first two rows start non-zero.
        trailing_zero_rows = [[0] * 5 for _ in range(8)]
        before = [[1, 2, 1, 0, 0], [1, 0, 0, 0, 0]] + trailing_zero_rows
        # Prefix 'a' -> row 0, extension 'b' -> column 1 gains one vote.
        after = [[1, 3, 1, 0, 0], [1, 0, 0, 0, 0]] + [
            list(row) for row in trailing_zero_rows
        ]

        votes = vote_fn(tf.constant(before, dtype=tf.int32), client_string)
        self.assertAllEqual(votes, tf.constant(after, dtype=tf.int32))