Esempio n. 1
0
    def test_contains_static_aggregation(self, dtype):
        """Tests that built computation contains at least one secure sum call.

        Builds two federated computations around
        `federated_aggregations.secure_quantized_sum`, whose positional
        arguments are `(value, lower_bound, upper_bound)`. Each computation is
        then checked with `static_assert.assert_contains_secure_aggregation`:

        * `comp_py_bounds` bakes the bounds in as NumPy constants.
        * `comp_tff_bounds` receives the bounds as SERVER-placed TFF values.

        Args:
          dtype: Element type of the client values and of the bounds. It must
            expose `.as_numpy_dtype`, so it is presumably a `tf.DType`
            supplied by a parameterized decorator outside this view -- confirm.
        """

        # Bounds provided as Python constants: lower = -1.0, upper = 1.0.
        @computations.federated_computation(
            computation_types.FederatedType((dtype, (2,)),
                                            placements.CLIENTS))
        def comp_py_bounds(value):
            return federated_aggregations.secure_quantized_sum(
                value, np.array(-1.0, dtype.as_numpy_dtype),
                np.array(1.0, dtype.as_numpy_dtype))

        static_assert.assert_contains_secure_aggregation(comp_py_bounds)

        # Bounds provided as tff values. The parameters are named in the same
        # (lower, upper) order in which they are forwarded to
        # `secure_quantized_sum`, matching the constant-bounds case above.
        # Both bound types are identical, so only the names tell them apart.
        @computations.federated_computation(
            computation_types.FederatedType((dtype, (2,)),
                                            placements.CLIENTS),
            computation_types.FederatedType(dtype, placements.SERVER),
            computation_types.FederatedType(dtype, placements.SERVER))
        def comp_tff_bounds(value, lower_bound, upper_bound):
            return federated_aggregations.secure_quantized_sum(
                value, lower_bound, upper_bound)

        static_assert.assert_contains_secure_aggregation(comp_tff_bounds)
 def test_passes_on_bothagg(self):
     """Checks the static assertion accepts `secure_and_unsecure_aggregation`.

     The test passes as long as `assert_contains_secure_aggregation` does not
     raise for this computation. The computation is defined outside this view;
     judging by its name, it presumably mixes secure and non-secure
     aggregation -- confirm.
     """
     computation = secure_and_unsecure_aggregation
     # No exception escaping this call is the success condition.
     static_assert.assert_contains_secure_aggregation(computation)
 def test_fails_on_unsecagg(self):
     """Checks the static assertion rejects `unsecure_aggregation`.

     `assert_contains_secure_aggregation` must raise `AssertionError` for this
     computation, which is defined outside this view.
     """
     # Callable form of assertRaises: it invokes the checker with the given
     # argument and fails the test unless AssertionError is raised.
     self.assertRaises(
         AssertionError,
         static_assert.assert_contains_secure_aggregation,
         unsecure_aggregation,
     )