Пример #1
0
    def test_client_update_works_as_expected(self):
        """Checks `client_update` votes against a hand-written 10x5 matrix.

        The same five single-character strings are used both as the already
        discovered prefixes and as the possible prefix extensions.
        """
        alphabet = ['a', 'b', 'c', 'd', 'e']
        extensions = tf.constant(alphabet, dtype=tf.string)
        prefixes = tf.constant(alphabet, dtype=tf.string)
        heavy_hitter_cap = tf.constant(10)
        contribution_cap = tf.constant(10)
        current_round = tf.constant(1)
        sub_rounds = tf.constant(1)

        words = ['a', '', 'abc', 'bac', 'abb', 'aaa', 'acc', 'hi']
        dataset = tf.data.Dataset.from_tensor_slices(words)

        result = triehh_tf.client_update(
            dataset, prefixes, extensions, current_round, sub_rounds,
            heavy_hitter_cap, contribution_cap)

        # Only the first two rows carry votes; the remaining eight rows of the
        # 10x5 matrix are all zeros.
        vote_rows = [[1, 2, 1, 0, 0], [1, 0, 0, 0, 0]]
        vote_rows.extend([0] * 5 for _ in range(8))
        expected_votes = tf.constant(vote_rows, dtype=tf.int32)

        self.assertAllEqual(result.client_votes, expected_votes)
Пример #2
0
    def test_client_update_works_as_expected(self):
        """Checks `client_update` votes when the terminator is an extension.

        The possible prefix extensions are the five letters plus
        `triehh_tf.DEFAULT_TERMINATOR`, so the expected matrix is 10x6.
        """
        letters = ['a', 'b', 'c', 'd', 'e']
        extensions = tf.constant(
            letters + [triehh_tf.DEFAULT_TERMINATOR], dtype=tf.string)
        prefixes = tf.constant(letters, dtype=tf.string)
        prefix_cap = tf.constant(10)
        contribution_cap = tf.constant(10)
        current_round = tf.constant(1)
        sub_rounds = tf.constant(1)

        words = ['a', '', 'abc', 'bac', 'abb', 'aaa', 'acc', 'hi']
        dataset = tf.data.Dataset.from_tensor_slices(words)

        result = triehh_tf.client_update(
            dataset, prefixes, extensions, current_round, sub_rounds,
            prefix_cap, contribution_cap)

        # triehh_tf.DEFAULT_TERMINATOR is appended to each string before the
        # client votes, so 'a' followed by the terminator gets a vote here
        # (the 1 in the last column of the first row).
        vote_rows = [[1, 2, 1, 0, 0, 1], [1, 0, 0, 0, 0, 0]]
        vote_rows.extend([0] * 6 for _ in range(8))
        expected_votes = tf.constant(vote_rows, dtype=tf.int32)

        self.assertAllEqual(result.client_votes, expected_votes)
  def test_client_update_works_on_empty_local_datasets(self):
    """Checks that an empty string dataset yields an all-zero 10x6 matrix."""
    letters = ['a', 'b', 'c', 'd', 'e']
    extensions = tf.constant(
        letters + [triehh_tf.DEFAULT_TERMINATOR], dtype=tf.string)
    prefixes = tf.constant(letters, dtype=tf.string)
    prefix_cap = tf.constant(10)
    contribution_cap = tf.constant(10)
    current_round = tf.constant(1)
    sub_rounds = tf.constant(1)

    def _no_items():
      # An iterator that is exhausted immediately.
      return iter(())

    # Build the empty dataset through a generator so its element type is
    # explicitly tf.string; an empty Python list handed to
    # `from_tensor_slices` would not produce string elements.
    empty_dataset = tf.data.Dataset.from_generator(
        generator=_no_items, output_types=tf.string, output_shapes=())

    result = triehh_tf.client_update(
        empty_dataset, prefixes, extensions, current_round, sub_rounds,
        prefix_cap, contribution_cap)

    # No data means no votes in any cell.
    expected_votes = tf.constant(
        [[0] * 6 for _ in range(10)], dtype=tf.int32)
    self.assertAllEqual(result.client_votes, expected_votes)
Пример #4
0
    def test_client_update_works_on_empty_local_datasets(self):
        """Checks that an empty string dataset yields an all-zero 10x6 matrix."""
        max_num_prefixes = tf.constant(10)
        max_user_contribution = tf.constant(10)
        possible_prefix_extensions = tf.constant(
            ['a', 'b', 'c', 'd', 'e', triehh_tf.DEFAULT_TERMINATOR],
            dtype=tf.string)
        discovered_prefixes = tf.constant(['a', 'b', 'c', 'd', 'e'],
                                          dtype=tf.string)
        round_num = tf.constant(1)
        num_sub_rounds = tf.constant(1)
        # Slice an explicitly typed empty tensor: a bare `[]` would be given a
        # non-string default dtype, so the dataset would not match the
        # tf.string datasets `client_update` receives elsewhere in these tests.
        sample_data = tf.data.Dataset.from_tensor_slices(
            tf.constant([], dtype=tf.string))
        client_output = triehh_tf.client_update(
            sample_data, discovered_prefixes, possible_prefix_extensions,
            round_num, num_sub_rounds, max_num_prefixes, max_user_contribution)

        # No data means no votes in any cell.
        expected_client_votes = tf.constant(
            [[0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0],
             [0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0],
             [0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0],
             [0, 0, 0, 0, 0, 0]],
            dtype=tf.int32)
        self.assertAllEqual(client_output.client_votes, expected_client_votes)
Пример #5
0
 def client_update_fn(tf_dataset, discovered_prefixes, round_num):
     """Forwards one client's inputs to `client_update`.

     The per-call arguments are passed through unchanged. The remaining
     arguments (`possible_prefix_extensions`, `num_sub_rounds`,
     `max_num_prefixes`, `max_user_contribution`) are names resolved from the
     surrounding scope.

     Args:
       tf_dataset: Passed to `client_update` as its first argument.
       discovered_prefixes: Passed to `client_update` as its second argument.
       round_num: Passed to `client_update` as its fourth argument.

     Returns:
       Whatever `client_update` returns.
     """
     extensions = tf.constant(possible_prefix_extensions)
     return client_update(
         tf_dataset,
         discovered_prefixes,
         extensions,
         round_num,
         num_sub_rounds,
         max_num_prefixes,
         max_user_contribution,
     )
Пример #6
0
 def client_update_fn(tf_dataset, discovered_prefixes, round_num):
     """Forwards one client's inputs to `client_update`, with a terminator.

     The per-call arguments are passed through unchanged. The remaining
     arguments (`possible_prefix_extensions`, `num_sub_rounds`,
     `max_num_prefixes`, `max_user_contribution`, `default_terminator`) are
     names resolved from the surrounding scope.

     Args:
       tf_dataset: Passed to `client_update` as its first argument.
       discovered_prefixes: Passed to `client_update` as its second argument.
       round_num: Passed to `client_update` as its fourth argument.

     Returns:
       Whatever `client_update` returns.
     """
     extensions = tf.constant(possible_prefix_extensions)
     # The terminator is handed over as a string tensor.
     terminator = tf.constant(default_terminator, dtype=tf.string)
     return client_update(
         tf_dataset,
         discovered_prefixes,
         extensions,
         round_num,
         num_sub_rounds,
         max_num_prefixes,
         max_user_contribution,
         terminator,
     )