def test_type_properties(self, value_type):
    """Verifies the type signatures of a `SumPlusOneFactory` process.

    Asserts that the factory is an `UnweightedAggregationFactory`, that
    `create` yields an `AggregationProcess`, and that the `initialize` and
    `next` type signatures are equivalent to the ones constructed here.

    Args:
      value_type: Type spec of the aggregated value; it is passed through
        `computation_types.to_type` before use.
    """
    plus_one_factory = aggregator_test_utils.SumPlusOneFactory()
    self.assertIsInstance(plus_one_factory,
                          factory.UnweightedAggregationFactory)

    member_type = computation_types.to_type(value_type)
    agg_process = plus_one_factory.create(member_type)
    self.assertIsInstance(agg_process, aggregation_process.AggregationProcess)

    # Values are placed at the clients on input and at the server on output.
    clients_value = computation_types.FederatedType(
        member_type, placements.CLIENTS)
    server_value = computation_types.FederatedType(
        member_type, placements.SERVER)
    # State and measurements are each expected to be an int32 at the server.
    state_type = computation_types.FederatedType(tf.int32, placements.SERVER)
    measurements_type = computation_types.FederatedType(
        tf.int32, placements.SERVER)

    init_signature = computation_types.FunctionType(
        parameter=None, result=state_type)
    self.assertTrue(
        agg_process.initialize.type_signature.is_equivalent_to(
            init_signature))

    next_parameter = collections.OrderedDict(
        state=state_type, value=clients_value)
    next_result = measured_process.MeasuredProcessOutput(
        state_type, server_value, measurements_type)
    next_signature = computation_types.FunctionType(
        parameter=next_parameter, result=next_result)
    self.assertTrue(
        agg_process.next.type_signature.is_equivalent_to(next_signature))
def test_inner_value_and_weight_sum_factory(self):
    """Runs a `MeanFactory` whose inner sum factories are `SumPlusOneFactory`.

    Asserts the initial state, the state after one call to `next`, the
    `11 / 7` result, and that both measurements equal `M_CONST`.
    """
    inner_sum = aggregator_test_utils.SumPlusOneFactory()
    mean_factory = mean.MeanFactory(
        value_sum_factory=inner_sum, weight_sum_factory=inner_sum)
    float_value_type = computation_types.to_type(tf.float32)
    float_weight_type = computation_types.to_type(tf.float32)
    mean_process = mean_factory.create(float_value_type, float_weight_type)

    initial_state = mean_process.initialize()
    expected_initial_state = collections.OrderedDict(
        [('value_sum_process', 0), ('weight_sum_process', 0)])
    self.assertAllEqual(expected_initial_state, initial_state)

    # Weighted values will be summed to 11.0 and weights will be summed to
    # 7.0.
    values = [1.0, 2.0, 3.0]
    value_weights = [3.0, 2.0, 1.0]
    round_output = mean_process.next(initial_state, values, value_weights)

    expected_next_state = collections.OrderedDict(
        [('value_sum_process', 1), ('weight_sum_process', 1)])
    self.assertAllEqual(expected_next_state, round_output.state)
    self.assertAllClose(11 / 7, round_output.result)
    expected_measurements = collections.OrderedDict(
        [('mean_value', M_CONST), ('mean_weight', M_CONST)])
    self.assertEqual(expected_measurements, round_output.measurements)
def test_type_properties_with_inner_factory_unweighted(self, value_type):
    """Verifies type signatures of `UnweightedMeanFactory` with inner factories.

    A single `SumPlusOneFactory` is used as both the value-sum and count-sum
    factory. The test asserts the factory and process classes and that the
    `initialize` and `next` type signatures are equivalent to the ones
    constructed here.

    Args:
      value_type: Type spec of the aggregated value; it is passed through
        `computation_types.to_type` before use.
    """
    member_type = computation_types.to_type(value_type)
    inner_sum = aggregator_test_utils.SumPlusOneFactory()
    mean_factory = mean.UnweightedMeanFactory(
        value_sum_factory=inner_sum, count_sum_factory=inner_sum)
    self.assertIsInstance(mean_factory, factory.UnweightedAggregationFactory)

    mean_process = mean_factory.create(member_type)
    self.assertIsInstance(mean_process,
                          aggregation_process.AggregationProcess)

    clients_value = computation_types.at_clients(member_type)
    server_value = computation_types.at_server(member_type)
    # The expected state is a pair of int32 values placed at the server.
    state_members = (tf.int32, tf.int32)
    state_type = computation_types.at_server(state_members)
    measurement_members = collections.OrderedDict(
        mean_value=tf.int32, mean_count=tf.int32)
    measurements_type = computation_types.at_server(measurement_members)

    init_signature = computation_types.FunctionType(
        parameter=None, result=state_type)
    self.assertTrue(
        mean_process.initialize.type_signature.is_equivalent_to(
            init_signature))

    next_parameter = collections.OrderedDict(
        state=state_type, value=clients_value)
    next_result = measured_process.MeasuredProcessOutput(
        state_type, server_value, measurements_type)
    next_signature = computation_types.FunctionType(
        parameter=next_parameter, result=next_result)
    self.assertTrue(
        mean_process.next.type_signature.is_equivalent_to(next_signature))
def test_clip_type_properties_with_clipped_count_agg_factory(
        self, value_type):
    """Verifies type signatures of a clipping factory with a custom count sum.

    The clipping factory is built with `clipping_norm=1.0`, a `SumFactory` as
    the inner aggregator and a `SumPlusOneFactory` as the clipped-count sum
    factory. The test asserts the process class and that the `initialize` and
    `next` type signatures are equivalent to the ones constructed here.

    Args:
      value_type: Type spec of the aggregated value; it is passed through
        `computation_types.to_type` before use.
    """
    count_sum = aggregator_test_utils.SumPlusOneFactory()
    clip_factory = robust.clipping_factory(
        clipping_norm=1.0,
        inner_agg_factory=sum_factory.SumFactory(),
        clipped_count_sum_factory=count_sum)
    member_type = computation_types.to_type(value_type)
    clip_process = clip_factory.create(member_type)
    self.assertIsInstance(clip_process,
                          aggregation_process.AggregationProcess)

    # Only the clipped-count aggregator is expected to carry non-empty state.
    state_members = collections.OrderedDict(
        clipping_norm=(), inner_agg=(), clipped_count_agg=tf.int32)
    state_type = computation_types.at_server(state_members)
    init_signature = computation_types.FunctionType(
        parameter=None, result=state_type)
    self.assertTrue(
        clip_process.initialize.type_signature.is_equivalent_to(
            init_signature))

    measurement_members = collections.OrderedDict(
        clipping=(),
        clipping_norm=robust.NORM_TF_TYPE,
        clipped_count=robust.COUNT_TF_TYPE)
    measurements_type = computation_types.at_server(measurement_members)
    next_parameter = collections.OrderedDict(
        state=state_type, value=computation_types.at_clients(member_type))
    next_result = measured_process.MeasuredProcessOutput(
        state=state_type,
        result=computation_types.at_server(member_type),
        measurements=measurements_type)
    next_signature = computation_types.FunctionType(
        parameter=next_parameter, result=next_result)
    self.assertTrue(
        clip_process.next.type_signature.is_equivalent_to(next_signature))
def test_sum_structure(self):
    """Runs a `SumPlusOneFactory` process on a structured value type.

    Asserts that the state goes from 0 to 1 after one call to `next`, that
    the result is `((7.0, 8.0), 13)` for the three client values below, and
    that the measurements equal 42.
    """
    struct_type = computation_types.to_type(((tf.float32, (2,)), tf.int32))
    sum_process = aggregator_test_utils.SumPlusOneFactory().create(
        struct_type)

    initial_state = sum_process.initialize()
    self.assertEqual(0, initial_state)

    per_client_values = [
        ((1.0, 2.0), 3),
        ((2.0, 5.0), 4),
        ((3.0, 0.0), 5),
    ]
    round_output = sum_process.next(initial_state, per_client_values)

    self.assertEqual(1, round_output.state)
    self.assertAllClose(((7.0, 8.0), 13), round_output.result)
    self.assertEqual(42, round_output.measurements)
def test_sum_scalar(self):
    """Runs a `SumPlusOneFactory` process on a scalar float32 value type.

    Asserts that the state goes from 0 to 1 after one call to `next`, that
    the result is 7.0 for client values 1.0, 2.0 and 3.0, and that the
    measurements equal 42.
    """
    scalar_type = computation_types.to_type(tf.float32)
    sum_process = aggregator_test_utils.SumPlusOneFactory().create(
        scalar_type)

    initial_state = sum_process.initialize()
    self.assertEqual(0, initial_state)

    per_client_values = [1.0, 2.0, 3.0]
    round_output = sum_process.next(initial_state, per_client_values)

    self.assertEqual(1, round_output.state)
    self.assertAllClose(7.0, round_output.result)
    self.assertEqual(42, round_output.measurements)
def test_inner_value_sum_factory_unweighted(self):
    """Runs `UnweightedMeanFactory` with `SumPlusOneFactory` inner factories.

    Asserts the initial state `(0, 0)`, the state `(1, 1)` after one call to
    `next`, the `7 / 4` result, and that both measurements equal `M_CONST`.
    """
    inner_sum = aggregator_test_utils.SumPlusOneFactory()
    mean_factory = mean.UnweightedMeanFactory(
        value_sum_factory=inner_sum, count_sum_factory=inner_sum)
    mean_process = mean_factory.create(computation_types.to_type(tf.float32))

    initial_state = mean_process.initialize()
    self.assertAllEqual((0, 0), initial_state)

    # Values will be summed to 7.0.
    per_client_values = [1.0, 2.0, 3.0]
    round_output = mean_process.next(initial_state, per_client_values)

    self.assertAllEqual((1, 1), round_output.state)
    self.assertAllClose(7 / 4, round_output.result)
    expected_measurements = collections.OrderedDict(
        [('mean_value', M_CONST), ('mean_count', M_CONST)])
    self.assertEqual(expected_measurements, round_output.measurements)
def test_incorrect_value_type_raises(self, bad_value_type):
    """Asserts that `SumPlusOneFactory.create` rejects `bad_value_type`.

    Args:
      bad_value_type: A value type for which `create` is expected to raise
        `TypeError`.
    """
    plus_one_factory = aggregator_test_utils.SumPlusOneFactory()
    # The factory is constructed outside the check so that only `create` is
    # covered by the expected exception.
    self.assertRaises(TypeError, plus_one_factory.create, bad_value_type)
import tensorflow_privacy as tfp from tensorflow_federated.python.aggregators import aggregator_test_utils from tensorflow_federated.python.aggregators import differential_privacy from tensorflow_federated.python.aggregators import factory from tensorflow_federated.python.core.backends.native import execution_contexts from tensorflow_federated.python.core.impl.types import computation_types from tensorflow_federated.python.core.impl.types import placements from tensorflow_federated.python.core.impl.types import type_conversions from tensorflow_federated.python.core.templates import aggregation_process from tensorflow_federated.python.core.templates import measured_process _test_dp_query = tfp.GaussianSumQuery(l2_norm_clip=1.0, stddev=0.0) _test_struct_type = [(tf.float32, (2, )), tf.float32] _test_inner_agg_factory = aggregator_test_utils.SumPlusOneFactory() class DPFactoryComputationTest(tf.test.TestCase, parameterized.TestCase): @parameterized.named_parameters( ('float_simple', tf.float32, None), ('struct_simple', _test_struct_type, None), ('float_inner', tf.float32, _test_inner_agg_factory), ('struct_inner', _test_struct_type, _test_inner_agg_factory)) def test_type_properties(self, value_type, inner_agg_factory): factory_ = differential_privacy.DifferentiallyPrivateFactory( _test_dp_query, inner_agg_factory) self.assertIsInstance(factory_, factory.UnweightedAggregationFactory) value_type = computation_types.to_type(value_type) process = factory_.create(value_type) self.assertIsInstance(process, aggregation_process.AggregationProcess)