def test_client_update_works_as_expected(self):
    """Checks the per-(prefix, extension) vote counts for a small dataset."""
    terminator = triehh_tf.DEFAULT_TERMINATOR
    max_prefixes = tf.constant(10)
    max_contribution = tf.constant(10)
    extensions = tf.constant(['a', 'b', 'c', 'd', 'e', terminator],
                             dtype=tf.string)
    prefixes = tf.constant(['a', 'b', 'c', 'd', 'e'], dtype=tf.string)
    round_index = tf.constant(1)
    sub_rounds = tf.constant(1)
    words = tf.data.Dataset.from_tensor_slices(
        ['a', '', 'abc', 'bac', 'abb', 'aaa', 'acc', 'hi'])

    output = triehh_tf.client_update(
        words, prefixes, extensions, round_index, sub_rounds, max_prefixes,
        max_contribution, tf.constant(terminator, dtype=tf.string))

    # Every word gets triehh_tf.DEFAULT_TERMINATOR appended before the client
    # votes, which is why 'a$' earns a vote here. The non-zero counts match the
    # two-character prefixes of the suffixed words: row 'a' sees 'aa', 'ab'
    # (twice), 'ac' and 'a$'; row 'b' sees 'ba'. All remaining rows are zero.
    expected_votes = tf.constant(
        [[1, 2, 1, 0, 0, 1],
         [1, 0, 0, 0, 0, 0]] + [[0] * 6] * 8,
        dtype=tf.int32)
    self.assertAllEqual(output.client_votes, expected_votes)
  def test_client_update_works_on_empty_local_datasets(self):
    """An empty local dataset must produce an all-zero vote matrix."""
    max_num_prefixes = tf.constant(10)
    max_user_contribution = tf.constant(10)
    possible_prefix_extensions = tf.constant(
        ['a', 'b', 'c', 'd', 'e', triehh_tf.DEFAULT_TERMINATOR],
        dtype=tf.string)
    discovered_prefixes = tf.constant(['a', 'b', 'c', 'd', 'e'],
                                      dtype=tf.string)
    round_num = tf.constant(1)
    num_sub_rounds = tf.constant(1)
    # Build an empty dataset whose elements are scalar tf.string values. A bare
    # `from_tensor_slices([])` would infer tf.float32 elements, so the empty
    # tensor is given an explicit string dtype. This avoids a Python generator
    # and the deprecated `output_types`/`output_shapes` arguments.
    sample_data = tf.data.Dataset.from_tensor_slices(
        tf.constant([], dtype=tf.string))
    client_output = triehh_tf.client_update(
        sample_data, discovered_prefixes, possible_prefix_extensions, round_num,
        num_sub_rounds, max_num_prefixes, max_user_contribution,
        tf.constant(triehh_tf.DEFAULT_TERMINATOR, dtype=tf.string))

    # 10 rows x 6 columns of zeros. The row count presumably tracks
    # max_num_prefixes and the column count len(possible_prefix_extensions);
    # TODO confirm.
    expected_client_votes = tf.zeros([10, 6], dtype=tf.int32)
    self.assertAllEqual(client_output.client_votes, expected_client_votes)
  def test_client_update_works_on_empty_discovered_prefixes(self):
    """With no discovered prefixes, no vote can be cast."""
    terminator = triehh_tf.DEFAULT_TERMINATOR
    max_prefixes = tf.constant(10)
    max_contribution = tf.constant(10)
    extensions = tf.constant(['a', 'b', 'c', 'd', 'e', terminator],
                             dtype=tf.string)
    no_prefixes = tf.constant([], dtype=tf.string)
    round_index = tf.constant(1)
    sub_rounds = tf.constant(1)
    words = tf.data.Dataset.from_tensor_slices(
        ['a', '', 'abc', 'bac', 'abb', 'aaa', 'acc', 'hi'])

    output = triehh_tf.client_update(
        words, no_prefixes, extensions, round_index, sub_rounds, max_prefixes,
        max_contribution, tf.constant(terminator, dtype=tf.string))

    # The full 10 x 6 vote matrix is expected to be zero.
    self.assertAllEqual(output.client_votes,
                        tf.zeros([10, 6], dtype=tf.int32))
Пример #4
0
 def client_update_fn(tf_dataset, discovered_prefixes, round_num):
     """Runs `client_update` with the remaining arguments bound from the closure.

     `possible_prefix_extensions`, `num_sub_rounds`, `max_num_prefixes`,
     `max_user_contribution` and `default_terminator` are free variables,
     presumably supplied by the enclosing builder function (not visible here).
     """
     extensions = tf.constant(possible_prefix_extensions)
     terminator = tf.constant(default_terminator, dtype=tf.string)
     return client_update(tf_dataset, discovered_prefixes, extensions,
                          round_num, num_sub_rounds, max_num_prefixes,
                          max_user_contribution, terminator)