Esempio n. 1
0
  def test_accumulate_server_votes_works_as_expected(self):
    """Accumulating a sub-round's votes adds them onto the stored votes.

    The server state is seeded with a vote matrix. The very same matrix is
    then passed in as the sub-round votes, so every entry of the accumulated
    votes must end up being twice its seeded value.
    """
    zero_row = [0, 0, 0, 0, 0]
    # Rows 0 and 1 carry non-zero counts; rows 2..9 are all zeros.
    # NOTE(review): presumably one row per discovered prefix and one column
    # per possible prefix extension -- confirm against triehh_tf.
    seed_rows = [[1, 2, 1, 0, 0]] * 2 + [zero_row] * 8
    doubled_rows = [[2, 4, 2, 0, 0]] * 2 + [zero_row] * 8

    extensions = tf.constant(['a', 'b', 'c', 'd', 'e'], dtype=tf.string)
    state = triehh_tf.ServerState(
        discovered_heavy_hitters=tf.constant([], dtype=tf.string),
        discovered_prefixes=tf.constant(['a', 'b'], dtype=tf.string),
        possible_prefix_extensions=extensions,
        round_num=tf.constant(0, dtype=tf.int32),
        accumulated_votes=tf.constant(seed_rows, dtype=tf.int32))

    state = triehh_tf.accumulate_server_votes(
        state, tf.constant(seed_rows, dtype=tf.int32))

    self.assertAllEqual(state.accumulated_votes,
                        tf.constant(doubled_rows, dtype=tf.int32))
Esempio n. 2
0
    def test_accumulate_server_votes_works_as_expected(self):
        """Checks that `accumulate_server_votes` sums both votes and weights.

        The server state starts with a vote matrix and an accumulated weight
        of 10. Accumulating an identical sub-round vote matrix with a
        sub-round weight of 10 must double every vote entry and bring the
        accumulated weight to 20.
        """
        discovered_prefixes = ['a', 'b']
        discovered_heavy_hitters = []
        heavy_hitter_frequencies = []
        # Only rows 0 and 1 carry non-zero counts; the other eight rows are
        # zeros. NOTE(review): presumably one row per discovered prefix --
        # confirm against triehh_tf.
        initial_votes = tf.constant(
            [[1, 2, 1, 0, 0], [1, 2, 1, 0, 0], [0, 0, 0, 0, 0],
             [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0],
             [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0],
             [0, 0, 0, 0, 0]],
            dtype=tf.int32)
        initial_weights = tf.constant(10, dtype=tf.int32)

        server_state = triehh_tf.ServerState(
            discovered_heavy_hitters=tf.constant(discovered_heavy_hitters,
                                                 dtype=tf.string),
            heavy_hitter_frequencies=tf.constant(heavy_hitter_frequencies,
                                                 dtype=tf.float64),
            discovered_prefixes=tf.constant(discovered_prefixes,
                                            dtype=tf.string),
            round_num=tf.constant(0, dtype=tf.int32),
            accumulated_votes=initial_votes,
            accumulated_weights=initial_weights)

        sub_round_votes = tf.constant(
            [[1, 2, 1, 0, 0], [1, 2, 1, 0, 0], [0, 0, 0, 0, 0],
             [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0],
             [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0],
             [0, 0, 0, 0, 0]],
            dtype=tf.int32)
        sub_round_weight = tf.constant(10, dtype=tf.int32)

        server_state = triehh_tf.accumulate_server_votes(
            server_state, sub_round_votes, sub_round_weight)

        # Identical inputs: the accumulated votes are the element-wise double.
        expected_accumulated_votes = tf.constant(
            [[2, 4, 2, 0, 0], [2, 4, 2, 0, 0], [0, 0, 0, 0, 0],
             [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0],
             [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0],
             [0, 0, 0, 0, 0]],
            dtype=tf.int32)
        expected_accumulated_weights = tf.constant(20, dtype=tf.int32)

        self.assertAllEqual(server_state.accumulated_votes,
                            expected_accumulated_votes)
        # Use assertAllEqual rather than unittest's assertEqual. assertEqual
        # relies on Tensor.__eq__ plus bool() coercion, which only works for
        # scalar eager tensors and yields unhelpful failure messages.
        self.assertAllEqual(server_state.accumulated_weights,
                            expected_accumulated_weights)