def test_type_properties_unweighted(self, value_type):
  """Checks the factory/process classes and type signatures of the unweighted mean."""
  tff_value_type = computation_types.to_type(value_type)
  mean_agg_factory = mean_factory.UnweightedMeanFactory()
  self.assertIsInstance(mean_agg_factory, factory.UnweightedAggregationFactory)

  agg_process = mean_agg_factory.create(tff_value_type)
  self.assertIsInstance(agg_process, aggregation_process.AggregationProcess)

  # Expected placements: values arrive at clients, the result lands on the
  # server, state is an empty server-placed tuple.
  clients_value_type = computation_types.at_clients(tff_value_type)
  server_result_type = computation_types.at_server(tff_value_type)
  server_state_type = computation_types.at_server(())
  server_measurements_type = computation_types.at_server(
      collections.OrderedDict(mean_value=()))

  init_signature = computation_types.FunctionType(
      parameter=None, result=server_state_type)
  self.assertTrue(
      agg_process.initialize.type_signature.is_equivalent_to(init_signature))

  next_signature = computation_types.FunctionType(
      parameter=collections.OrderedDict(
          state=server_state_type, value=clients_value_type),
      result=measured_process.MeasuredProcessOutput(
          server_state_type, server_result_type, server_measurements_type))
  self.assertTrue(
      agg_process.next.type_signature.is_equivalent_to(next_signature))
def test_incorrect_create_type_raises(self, wrong_type):
  """Verifies that `create` raises TypeError when given `wrong_type`."""
  weighted = mean_factory.MeanFactory()
  float_type = computation_types.to_type(tf.float32)
  # `wrong_type` in either argument position of the weighted factory's
  # `create` must be rejected.
  for first_arg, second_arg in ((wrong_type, float_type),
                                (float_type, wrong_type)):
    with self.assertRaises(TypeError):
      weighted.create(first_arg, second_arg)

  unweighted = mean_factory.UnweightedMeanFactory()
  with self.assertRaises(TypeError):
    unweighted.create(wrong_type)
def test_structure_value_unweighted(self):
  """Averages a nested structure of floats over three clients."""
  mean_process = mean_factory.UnweightedMeanFactory().create(
      computation_types.to_type(_test_struct_type))
  empty_state = ()

  initial_state = mean_process.initialize()
  self.assertAllEqual(empty_state, initial_state)

  per_client_values = [
      ((1.0, 2.0), 3.0),
      ((2.0, 5.0), 4.0),
      ((3.0, 0.0), 5.0),
  ]
  step = mean_process.next(initial_state, per_client_values)
  self.assertAllEqual(empty_state, step.state)
  # Element-wise means: (1+2+3)/3, (2+5+0)/3 and (3+4+5)/3.
  self.assertAllClose(((2.0, 7 / 3), 4.0), step.result)
  self.assertEqual(collections.OrderedDict(mean_value=()), step.measurements)
def test_inner_value_sum_factory_unweighted(self):
  """Checks that a custom `value_sum_factory` is used by the unweighted mean."""
  plus_one_sum = aggregators_test_utils.SumPlusOneFactory()
  mean_process = mean_factory.UnweightedMeanFactory(
      value_sum_factory=plus_one_sum).create(
          computation_types.to_type(tf.float32))

  initial_state = mean_process.initialize()
  self.assertAllEqual(0, initial_state)

  # With the inner factory, the values are summed to 7.0 (1 + 2 + 3, plus one),
  # so the mean over three clients is 7 / 3.
  step = mean_process.next(initial_state, [1.0, 2.0, 3.0])
  self.assertAllEqual(1, step.state)
  self.assertAllClose(7 / 3, step.result)
  self.assertEqual(
      collections.OrderedDict(mean_value=M_CONST), step.measurements)
def test_scalar_value_unweighted(self):
  """Averages scalar float32 values over three clients."""
  mean_process = mean_factory.UnweightedMeanFactory().create(
      computation_types.to_type(tf.float32))
  no_state = ()
  no_measurements = collections.OrderedDict(mean_value=())

  current_state = mean_process.initialize()
  self.assertAllEqual(no_state, current_state)

  # Mean of 1.0, 2.0 and 3.0 is 2.0.
  step = mean_process.next(current_state, [1.0, 2.0, 3.0])
  self.assertAllClose(2.0, step.result)
  self.assertAllEqual(no_state, step.state)
  self.assertEqual(no_measurements, step.measurements)