def test_accumulate_server_votes_works_as_expected(self):
  """Accumulating a sub-round's votes adds them element-wise to the state.

  The state's initial accumulated votes and the incoming sub-round votes are
  the same 10x5 matrix (only the first two rows are non-zero), so the
  accumulated result is expected to be exactly double that matrix.
  """
  # NOTE(review): a later test in this source has this exact same name; if
  # both live in the same TestCase class, this definition is shadowed and
  # never runs — confirm and rename one of them.
  vote_rows = [[1, 2, 1, 0, 0]] * 2 + [[0, 0, 0, 0, 0]] * 8
  doubled_rows = [[2, 4, 2, 0, 0]] * 2 + [[0, 0, 0, 0, 0]] * 8

  state = triehh_tf.ServerState(
      discovered_heavy_hitters=tf.constant([], dtype=tf.string),
      discovered_prefixes=tf.constant(['a', 'b'], dtype=tf.string),
      possible_prefix_extensions=tf.constant(
          ['a', 'b', 'c', 'd', 'e'], dtype=tf.string),
      round_num=tf.constant(0, dtype=tf.int32),
      accumulated_votes=tf.constant(vote_rows, dtype=tf.int32))

  state = triehh_tf.accumulate_server_votes(
      state, tf.constant(vote_rows, dtype=tf.int32))

  self.assertAllEqual(state.accumulated_votes,
                      tf.constant(doubled_rows, dtype=tf.int32))
def test_accumulate_server_votes_works_as_expected(self):
  """Accumulating a sub-round adds its votes and its weight to the state.

  The state's initial accumulated votes and the sub-round votes are the same
  10x5 matrix (only the first two rows are non-zero), and both the initial
  accumulated weight and the sub-round weight are 10. The accumulated votes
  should therefore double, and the accumulated weight should become 20.
  """
  vote_rows = [[1, 2, 1, 0, 0]] * 2 + [[0, 0, 0, 0, 0]] * 8

  initial_votes = tf.constant(vote_rows, dtype=tf.int32)
  initial_weights = tf.constant(10, dtype=tf.int32)
  server_state = triehh_tf.ServerState(
      discovered_heavy_hitters=tf.constant([], dtype=tf.string),
      heavy_hitter_frequencies=tf.constant([], dtype=tf.float64),
      discovered_prefixes=tf.constant(['a', 'b'], dtype=tf.string),
      round_num=tf.constant(0, dtype=tf.int32),
      accumulated_votes=initial_votes,
      accumulated_weights=initial_weights)

  sub_round_votes = tf.constant(vote_rows, dtype=tf.int32)
  sub_round_weight = tf.constant(10, dtype=tf.int32)
  server_state = triehh_tf.accumulate_server_votes(
      server_state, sub_round_votes, sub_round_weight)

  expected_accumulated_votes = tf.constant(
      [[2, 4, 2, 0, 0]] * 2 + [[0, 0, 0, 0, 0]] * 8, dtype=tf.int32)
  expected_accumulated_weights = tf.constant(20, dtype=tf.int32)

  self.assertAllEqual(server_state.accumulated_votes,
                      expected_accumulated_votes)
  # Use assertAllEqual rather than assertEqual: it evaluates both tensors to
  # values before comparing. Plain assertEqual depends on `Tensor == Tensor`
  # being usable as a Python bool, which raises in graph mode and produces
  # unhelpful failure messages.
  self.assertAllEqual(server_state.accumulated_weights,
                      expected_accumulated_weights)