Ejemplo n.º 1
0
    def test_accumulate_server_votes_and_decode_threhold_works_as_expected(
            self):
        """Checks threshold-based decoding after accumulating a sub-round.

        The server state starts with votes already accumulated for the
        prefixes 'su' and 'st' (weight 10). A new sub-round's votes (weight
        20) are passed to `triehh_tf.accumulate_server_votes_and_decode`
        together with `max_num_prefixes` and `threshold`. The test expects:
          * the vote and weight accumulators to be reset to zero,
          * the new discovered prefixes to be {'sta', 'sun'} (order-free),
          * no heavy hitters and no heavy-hitter frequencies.
        """
        max_num_prefixes = tf.constant(4)
        threshold = tf.constant(5)
        possible_prefix_extensions = tf.constant(
            ['a', 'n', 's', 't', 'u', triehh_tf.DEFAULT_TERMINATOR],
            dtype=tf.string)
        discovered_prefixes = ['su', 'st']
        # Nothing has been reported as a heavy hitter yet.
        discovered_heavy_hitters = []
        heavy_hitter_frequencies = []
        # Judging by the expected output below, row i holds votes for
        # extensions of discovered_prefixes[i] and column j counts
        # possible_prefix_extensions[j]; rows 2-3 are unused padding.
        initial_votes = tf.constant([[1, 2, 1, 0, 0, 0], [1, 2, 1, 0, 0, 0],
                                     [0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0]],
                                    dtype=tf.int32)
        initial_weights = tf.constant(10, dtype=tf.int32)
        server_state = triehh_tf.ServerState(
            discovered_heavy_hitters=tf.constant(discovered_heavy_hitters,
                                                 dtype=tf.string),
            heavy_hitter_frequencies=tf.constant(heavy_hitter_frequencies,
                                                 dtype=tf.float64),
            discovered_prefixes=tf.constant(discovered_prefixes,
                                            dtype=tf.string),
            round_num=tf.constant(3, dtype=tf.int32),
            accumulated_votes=initial_votes,
            accumulated_weights=initial_weights)

        sub_round_votes = tf.constant([[3, 3, 1, 0, 0, 0], [5, 1, 1, 0, 0, 0],
                                       [0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0]],
                                      dtype=tf.int32)
        sub_round_weight = tf.constant(20, dtype=tf.int32)

        server_state = triehh_tf.accumulate_server_votes_and_decode(
            server_state, possible_prefix_extensions, sub_round_votes,
            sub_round_weight, max_num_prefixes, threshold)

        # After decoding, the accumulators are expected to be cleared.
        expected_accumulated_votes = tf.constant(
            [[0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0],
             [0, 0, 0, 0, 0, 0]],
            dtype=tf.int32)
        expected_accumulated_weights = tf.constant(0, dtype=tf.int32)
        # Summed votes are 'su' + {a, n, s} -> {4, 5, 2} and
        # 'st' + {a, n, s} -> {6, 3, 2}. Only 'sun' (5) and 'sta' (6) are
        # expected to survive a threshold of 5. This presumably means the
        # comparison is `>=` on raw counts; confirm against triehh_tf.
        expected_discovered_prefixes = tf.constant(['sta', 'sun'],
                                                   dtype=tf.string)
        expected_discovered_heavy_hitters = tf.constant([], dtype=tf.string)
        expected_heavy_hitter_frequencies = tf.constant([], dtype=tf.float64)

        self.assertAllEqual(server_state.accumulated_votes,
                            expected_accumulated_votes)
        # assertAllEqual evaluates and compares tensor values. Plain
        # unittest assertEqual would depend on tensor __eq__/__bool__
        # semantics, which differ between eager and graph mode.
        self.assertAllEqual(server_state.accumulated_weights,
                            expected_accumulated_weights)
        # The test does not pin the order of the decoded prefixes.
        self.assertSetAllEqual(server_state.discovered_prefixes,
                               expected_discovered_prefixes)
        self.assertSetAllEqual(server_state.discovered_heavy_hitters,
                               expected_discovered_heavy_hitters)
        self.assertHistogramsEqual(server_state.discovered_heavy_hitters,
                                   server_state.heavy_hitter_frequencies,
                                   expected_discovered_heavy_hitters,
                                   expected_heavy_hitter_frequencies)
Ejemplo n.º 2
0
    def test_accumulate_server_votes_and_decode_works_as_expected(self):
        """Decodes the top prefixes after one more sub-round of votes.

        The server already holds votes for the prefixes 'su' and 'st'. A
        further sub-round of votes is accumulated and decoded with a cap of
        four results. The test expects the vote accumulator to be cleared,
        the four extended prefixes 'sta', 'sun', 'sua', 'stn' (in that
        order), and still no heavy hitters.
        """
        max_num_heavy_hitters = tf.constant(4)
        default_terminator = tf.constant('$', tf.string)
        # Padding row used for unused prefix slots and for the cleared state.
        no_votes = [0, 0, 0, 0, 0]

        server_state = triehh_tf.ServerState(
            discovered_heavy_hitters=tf.constant([], dtype=tf.string),
            discovered_prefixes=tf.constant(['su', 'st'], dtype=tf.string),
            possible_prefix_extensions=tf.constant(
                ['a', 'n', 's', 't', 'u'], dtype=tf.string),
            round_num=tf.constant(3, dtype=tf.int32),
            accumulated_votes=tf.constant(
                [[1, 2, 1, 0, 0], [1, 2, 1, 0, 0], no_votes, no_votes],
                dtype=tf.int32))

        sub_round_votes = tf.constant(
            [[3, 3, 1, 0, 0], [5, 1, 1, 0, 0], no_votes, no_votes],
            dtype=tf.int32)
        server_state = triehh_tf.accumulate_server_votes_and_decode(
            server_state, sub_round_votes, max_num_heavy_hitters,
            default_terminator)

        # Each pair is (observed tensor, expected tensor). The pairs are
        # checked in order with an order-sensitive comparison.
        checks = (
            (server_state.accumulated_votes,
             tf.constant([no_votes, no_votes, no_votes, no_votes],
                         dtype=tf.int32)),
            (server_state.discovered_prefixes,
             tf.constant(['sta', 'sun', 'sua', 'stn'], dtype=tf.string)),
            (server_state.discovered_heavy_hitters,
             tf.constant([], dtype=tf.string)),
        )
        for observed, wanted in checks:
            self.assertAllEqual(observed, wanted)