def test_contains_static_aggregation(self, dtype):
    """Tests that built computation contains at least one secure sum call.

    Two computations wrapping `secure_quantized_sum` are built and checked
    with `static_assert.assert_contains_secure_aggregation`: one whose bounds
    are NumPy scalar constants, and one whose bounds arrive as server-placed
    computation arguments.

    Args:
        dtype: Element type of the summed value. Must expose
            `as_numpy_dtype`; presumably a `tf.DType` supplied by a
            parameterized-test decorator (confirm against the class).
    """
    client_value_type = computation_types.FederatedType(
        (dtype, (2,)), placements.CLIENTS)
    server_bound_type = computation_types.FederatedType(
        dtype, placements.SERVER)

    # Bounds provided as Python constants.
    @computations.federated_computation(client_value_type)
    def comp_py_bounds(value):
        return federated_aggregations.secure_quantized_sum(
            value,
            np.array(-1.0, dtype.as_numpy_dtype),
            np.array(1.0, dtype.as_numpy_dtype))

    static_assert.assert_contains_secure_aggregation(comp_py_bounds)

    # Bounds provided as tff values.
    # The parameter names follow the positional order used for the constant
    # bounds above (-1.0 first, then 1.0), i.e. lower bound before upper.
    @computations.federated_computation(
        client_value_type, server_bound_type, server_bound_type)
    def comp_tff_bounds(value, lower_bound, upper_bound):
        return federated_aggregations.secure_quantized_sum(
            value, lower_bound, upper_bound)

    static_assert.assert_contains_secure_aggregation(comp_tff_bounds)
def test_passes_on_bothagg(self):
    """Checks that the secure-aggregation assertion accepts the mixed case.

    `secure_and_unsecure_aggregation` is handed to
    `static_assert.assert_contains_secure_aggregation`; the test passes when
    that call does not raise.
    """
    computation_under_test = secure_and_unsecure_aggregation
    static_assert.assert_contains_secure_aggregation(computation_under_test)
def test_fails_on_unsecagg(self):
    """Checks that the secure-aggregation assertion rejects `unsecure_aggregation`.

    The test passes only when
    `static_assert.assert_contains_secure_aggregation` raises
    `AssertionError` for that computation.
    """
    # Callable form of assertRaises: invokes the function with the given
    # argument and verifies the expected exception type is raised.
    self.assertRaises(
        AssertionError,
        static_assert.assert_contains_secure_aggregation,
        unsecure_aggregation,
    )