예제 #1
0
 def test_string_max_length_validation(self):
     """Checks IbltFactory's validation of `string_max_length`.

     The values 0 and -1 must be rejected with a ValueError whose message
     mentions `string_max_length`; the value 1 must be accepted.
     """
     other_kwargs = dict(capacity=10, repetitions=3, seed=0)
     for bad_length in (0, -1):
         with self.assertRaisesRegex(ValueError, 'string_max_length'):
             iblt_factory.IbltFactory(string_max_length=bad_length,
                                      **other_kwargs)
     # A length of 1 is valid; constructing the factory must not raise.
     iblt_factory.IbltFactory(string_max_length=1, **other_kwargs)
예제 #2
0
    def test_iblt_aggregation_as_expected(
        self,
        capacity: int,
        string_max_length: int,
        repetitions: int,
        seed: int,
        sketch_agg_factory: Optional[
            factory.UnweightedAggregationFactory] = None,
        value_tensor_agg_factory: Optional[
            factory.UnweightedAggregationFactory] = None):
        """Checks one round of IBLT aggregation against `AGGREGATED_DATA`.

        Builds an `IbltFactory` from the parameterized arguments, runs a single
        `next` round over `CLIENT_DATA`, and verifies both which strings were
        decoded and the aggregated value associated with each string.

        Args:
          capacity: Passed through to `IbltFactory`.
          string_max_length: Passed through to `IbltFactory`.
          repetitions: Passed through to `IbltFactory`.
          seed: Passed through to `IbltFactory`.
          sketch_agg_factory: Optional inner aggregation factory for the sketch.
          value_tensor_agg_factory: Optional inner aggregation factory for the
            value tensor.
        """
        iblt_agg_factory = iblt_factory.IbltFactory(
            sketch_agg_factory=sketch_agg_factory,
            value_tensor_agg_factory=value_tensor_agg_factory,
            capacity=capacity,
            string_max_length=string_max_length,
            repetitions=repetitions,
            seed=seed)
        iblt_agg_process = iblt_agg_factory.create(VALUE_TYPE)
        process_output = iblt_agg_process.next(iblt_agg_process.initialize(),
                                               CLIENT_DATA)
        output_strings = [
            s.decode('utf-8') for s in process_output.result.output_strings
        ]
        string_values = process_output.result.string_values
        # `zip` silently truncates to the shorter input, so make sure every
        # decoded string has exactly one value row before pairing them up.
        self.assertEqual(len(output_strings), len(string_values))
        result = dict(zip(output_strings, string_values))

        # `assertCountEqual` iterates both mappings, i.e. it compares only the
        # keys: this verifies the set of decoded strings.
        self.assertCountEqual(result, AGGREGATED_DATA)
        # The key-only check above would pass even if every aggregated value
        # were wrong, so compare the value for each expected string as well.
        # NOTE(review): assumes each value is a 1-D sequence (one entry per
        # value-tensor element) - confirm against AGGREGATED_DATA.
        for key, expected_value in AGGREGATED_DATA.items():
            self.assertSequenceEqual(list(result[key]),
                                     list(expected_value),
                                     msg=key)

        expected_measurements = collections.OrderedDict([
            ('num_not_decoded', 0),
            ('sketch', ()),
            ('value_tensor', ()),
        ])
        # Deliberately key-only (see above): this checks which measurements
        # are reported, not their values, which may depend on the inner
        # aggregation factories chosen by the parameterization.
        self.assertCountEqual(process_output.measurements,
                              expected_measurements)
예제 #3
0
 def test_value_type_validation(self, value_type):
     """Checks that `create` raises ValueError for the given `value_type`.

     `value_type` is supplied by the test parameterization (not visible here);
     presumably each case is a type IbltFactory does not accept.
     """
     agg_factory = iblt_factory.IbltFactory(
         capacity=10, string_max_length=5, repetitions=3, seed=0)
     self.assertRaises(ValueError, agg_factory.create, value_type)
예제 #4
0
 def test_repetitions_validation(self):
     """Checks IbltFactory's validation of `repetitions`.

     The values 0 and 2 must raise a ValueError whose message mentions
     `repetitions`; the value 3 must be accepted.
     """
     shared_kwargs = dict(capacity=10, string_max_length=10, seed=0)
     for rejected in (0, 2):
         with self.assertRaisesRegex(ValueError, 'repetitions'):
             iblt_factory.IbltFactory(repetitions=rejected, **shared_kwargs)
     # Three repetitions is accepted; constructing the factory must not raise.
     iblt_factory.IbltFactory(repetitions=3, **shared_kwargs)
예제 #5
0
 def test_string_max_length_error(self):
     """Checks that a client string longer than `string_max_length` fails.

     The factory is built with `string_max_length=5`, while the only client key,
     'thisisalongword', is longer than that, so running one aggregation round
     is expected to raise `tf.errors.InvalidArgumentError`.
     """
     client_tensors = collections.OrderedDict()
     client_tensors[iblt_factory.DATASET_KEY] = tf.constant(
         ['thisisalongword'], dtype=tf.string)
     client_tensors[iblt_factory.DATASET_VALUE] = tf.constant(
         [[1]], dtype=tf.int64)

     element_type = collections.OrderedDict([
         ('key', tf.string),
         ('value', computation_types.TensorType(shape=(1,), dtype=tf.int64)),
     ])
     value_type = computation_types.SequenceType(element_type)
     client_data = [tf.data.Dataset.from_tensor_slices(client_tensors)]

     agg_factory = iblt_factory.IbltFactory(capacity=10,
                                            string_max_length=5,
                                            repetitions=3,
                                            seed=0)
     agg_process = agg_factory.create(value_type)
     # `initialize` stays inside the context manager, as in the original.
     with self.assertRaises(tf.errors.InvalidArgumentError):
         agg_process.next(agg_process.initialize(), client_data)