예제 #1
0
  def test_inner_value_and_weight_sum_factory(self):
    """Checks MeanFactory with a SumPlusOneFactory for both inner sums.

    Per the expected values asserted below, each inner sum comes out one larger
    than the plain sum and each inner state goes from 0 to 1 after one round,
    so the expected mean is (10 + 1) / (6 + 1).
    """
    inner_sum = aggregators_test_utils.SumPlusOneFactory()
    mean_under_test = mean.MeanFactory(
        value_sum_factory=inner_sum, weight_sum_factory=inner_sum)
    float_value_type = computation_types.to_type(tf.float32)
    float_weight_type = computation_types.to_type(tf.float32)
    process = mean_under_test.create(float_value_type, float_weight_type)

    initial_state = process.initialize()
    self.assertAllEqual(
        collections.OrderedDict(value_sum_process=0, weight_sum_process=0),
        initial_state)

    values = [1.0, 2.0, 3.0]
    value_weights = [3.0, 2.0, 1.0]
    # sum(v * w) = 10.0 and sum(w) = 6.0; with the extra one added by the inner
    # factory these become 11.0 and 7.0.
    step = process.next(initial_state, values, value_weights)

    expected_state = collections.OrderedDict(
        value_sum_process=1, weight_sum_process=1)
    self.assertAllEqual(expected_state, step.state)
    self.assertAllClose(11 / 7, step.result)
    expected_measurements = collections.OrderedDict(
        mean_value=M_CONST, mean_weight=M_CONST)
    self.assertEqual(expected_measurements, step.measurements)
예제 #2
0
    def test_type_properties(self, value_type):
        """Checks the type signatures of a SumPlusOneFactory process."""
        plus_one = test_utils.SumPlusOneFactory()
        self.assertIsInstance(plus_one, factory.AggregationProcessFactory)
        value_type = computation_types.to_type(value_type)
        process = plus_one.create(value_type)
        self.assertIsInstance(process, aggregation_process.AggregationProcess)

        # Values are expected at CLIENTS and the aggregate at SERVER.
        clients_value = computation_types.FederatedType(
            value_type, placements.CLIENTS)
        server_value = computation_types.FederatedType(
            value_type, placements.SERVER)
        # State and measurements are each expected to be a server-placed int32.
        state_type = computation_types.FederatedType(
            tf.int32, placements.SERVER)
        measurements_type = computation_types.FederatedType(
            tf.int32, placements.SERVER)

        init_signature = computation_types.FunctionType(
            parameter=None, result=state_type)
        self.assertTrue(
            process.initialize.type_signature.is_equivalent_to(
                init_signature))

        next_parameter = collections.OrderedDict(
            state=state_type, value=clients_value)
        next_result = measured_process.MeasuredProcessOutput(
            state_type, server_value, measurements_type)
        next_signature = computation_types.FunctionType(
            parameter=next_parameter, result=next_result)
        self.assertTrue(
            process.next.type_signature.is_equivalent_to(next_signature))
예제 #3
0
  def test_type_properties_with_inner_factory_unweighted(self, value_type):
    """Checks UnweightedMeanFactory signatures with an inner sum factory."""
    value_type = computation_types.to_type(value_type)
    inner_sum = aggregators_test_utils.SumPlusOneFactory()

    mean_f = mean.UnweightedMeanFactory(value_sum_factory=inner_sum)
    self.assertIsInstance(mean_f, factory.UnweightedAggregationFactory)
    process = mean_f.create(value_type)

    self.assertIsInstance(process, aggregation_process.AggregationProcess)

    # Expected: a server-placed int32 state, and measurements holding a single
    # int32 under the `mean_value` key.
    state_type = computation_types.at_server(tf.int32)
    measurements_type = computation_types.at_server(
        collections.OrderedDict(mean_value=tf.int32))

    self.assertTrue(
        process.initialize.type_signature.is_equivalent_to(
            computation_types.FunctionType(parameter=None, result=state_type)))

    expected_next = computation_types.FunctionType(
        parameter=collections.OrderedDict(
            state=state_type, value=computation_types.at_clients(value_type)),
        result=measured_process.MeasuredProcessOutput(
            state_type, computation_types.at_server(value_type),
            measurements_type))
    self.assertTrue(
        process.next.type_signature.is_equivalent_to(expected_next))
예제 #4
0
  def test_type_properties_with_inner_factory_weighted(self, value_type,
                                                       weight_type):
    """Checks weighted MeanFactory signatures with inner sum factories."""
    inner_sum = aggregators_test_utils.SumPlusOneFactory()
    weighted_mean = mean_factory.MeanFactory(
        value_sum_factory=inner_sum, weight_sum_factory=inner_sum)
    self.assertIsInstance(weighted_mean, factory.WeightedAggregationFactory)
    value_type = computation_types.to_type(value_type)
    weight_type = computation_types.to_type(weight_type)
    process = weighted_mean.create_weighted(value_type, weight_type)
    self.assertIsInstance(process, aggregation_process.AggregationProcess)

    # State and measurements are the very same server-placed structure, with
    # one int32 entry per inner sum process.
    state_type = computation_types.FederatedType(
        collections.OrderedDict(
            value_sum_process=tf.int32, weight_sum_process=tf.int32),
        placements.SERVER)
    measurements_type = state_type

    self.assertTrue(
        process.initialize.type_signature.is_equivalent_to(
            computation_types.FunctionType(parameter=None, result=state_type)))

    next_parameter = collections.OrderedDict(
        state=state_type,
        value=computation_types.FederatedType(value_type, placements.CLIENTS),
        weight=computation_types.at_clients(weight_type))
    next_result = measured_process.MeasuredProcessOutput(
        state_type,
        computation_types.FederatedType(value_type, placements.SERVER),
        measurements_type)
    self.assertTrue(
        process.next.type_signature.is_equivalent_to(
            computation_types.FunctionType(
                parameter=next_parameter, result=next_result)))
예제 #5
0
    def test_sum_structure(self):
        """Sums a (float32[2], int32) structure with SumPlusOneFactory."""
        plus_one_sum = test_utils.SumPlusOneFactory()
        structure_type = computation_types.to_type(
            ((tf.float32, (2, )), tf.int32))
        process = plus_one_sum.create(structure_type)

        initial_state = process.initialize()
        self.assertEqual(0, initial_state)

        # Element-wise client totals are ((6.0, 7.0), 12); the expected result
        # below is one larger in every leaf.
        per_client = [((1.0, 2.0), 3), ((2.0, 5.0), 4), ((3.0, 0.0), 5)]
        step = process.next(initial_state, per_client)
        self.assertEqual(1, step.state)
        self.assertAllClose(((7.0, 8.0), 13), step.result)
        self.assertEqual(42, step.measurements)
예제 #6
0
    def test_sum_scalar(self):
        """Sums float32 scalars with SumPlusOneFactory."""
        plus_one_sum = test_utils.SumPlusOneFactory()
        scalar_type = computation_types.to_type(tf.float32)
        process = plus_one_sum.create(scalar_type)

        initial_state = process.initialize()
        self.assertEqual(0, initial_state)

        # The plain client total is 6.0; the expected result is one larger.
        step = process.next(initial_state, [1.0, 2.0, 3.0])
        self.assertEqual(1, step.state)
        self.assertAllClose(7.0, step.result)
        self.assertEqual(42, step.measurements)
예제 #7
0
  def test_inner_value_sum_factory_unweighted(self):
    """Checks UnweightedMeanFactory with a SumPlusOneFactory value sum."""
    inner_sum = aggregators_test_utils.SumPlusOneFactory()
    mean_f = mean.UnweightedMeanFactory(value_sum_factory=inner_sum)
    process = mean_f.create(computation_types.to_type(tf.float32))

    initial_state = process.initialize()
    self.assertAllEqual(0, initial_state)

    # The plain total is 6.0; the inner factory's result is 7.0, and it is
    # divided by the 3 clients.
    step = process.next(initial_state, [1.0, 2.0, 3.0])
    self.assertAllEqual(1, step.state)
    self.assertAllClose(7 / 3, step.result)
    expected_measurements = collections.OrderedDict(mean_value=M_CONST)
    self.assertEqual(expected_measurements, step.measurements)
예제 #8
0
    def test_inner_weight_sum_factory(self):
        """Checks MeanFactory with a SumPlusOneFactory only for the weights."""
        plus_one = aggregators_test_utils.SumPlusOneFactory()
        mean_f = mean_factory.MeanFactory(weight_sum_factory=plus_one)
        process = mean_f.create(computation_types.to_type(tf.float32))

        # No value_sum_factory is passed, and the value sum process is
        # expected to carry an empty `()` state.
        state = process.initialize()
        expected_initial = collections.OrderedDict(
            value_sum_process=(), weight_sum_process=0)
        self.assertAllEqual(expected_initial, state)

        # Weighted values sum to 6.0; the unit weights sum to 3.0, plus one
        # from the inner factory gives 4.0, so the mean is 6.0 / 4.0 = 1.5.
        step = process.next(state, [1.0, 2.0, 3.0], [1.0, 1.0, 1.0])
        expected_next = collections.OrderedDict(
            value_sum_process=(), weight_sum_process=1)
        self.assertAllEqual(expected_next, step.state)
        self.assertAllClose(1.5, step.result)
        expected_measurements = collections.OrderedDict(
            value_sum_process=(), weight_sum_process=M_CONST)
        self.assertEqual(expected_measurements, step.measurements)
예제 #9
0
from tensorflow_federated.python.aggregators import differential_privacy
from tensorflow_federated.python.aggregators import factory
from tensorflow_federated.python.aggregators import test_utils
from tensorflow_federated.python.core.api import computation_types
from tensorflow_federated.python.core.api import placements
from tensorflow_federated.python.core.api import test_case
from tensorflow_federated.python.core.backends.native import execution_contexts
from tensorflow_federated.python.core.impl.types import type_conversions
from tensorflow_federated.python.core.templates import aggregation_process
from tensorflow_federated.python.core.templates import measured_process

# Shared DP query for these tests: L2 clip of 1.0 and a noise stddev of 0.0
# (presumably so that test outputs are noise-free — confirm intent).
_test_dp_query = tfp.GaussianSumQuery(l2_norm_clip=1.0, stddev=0.0)

# Structured value type: a float32 tensor of shape (2,) followed by a float32.
_test_struct_type = [(tf.float32, (2,)), tf.float32]
# Inner aggregation factory used by the '*_inner' parameterized cases.
_test_inner_agg_factory = test_utils.SumPlusOneFactory()


class DPFactoryComputationTest(test_case.TestCase, parameterized.TestCase):
  """Tests for `differential_privacy.DifferentiallyPrivateFactory`."""

  # Each case is (name, value_type, inner_agg_factory). A None inner factory is
  # passed straight through to the DP factory (presumably selecting its default
  # inner aggregation — TODO confirm).
  @parameterized.named_parameters(
      ('float_simple', tf.float32, None),
      ('struct_simple', _test_struct_type, None),
      ('float_inner', tf.float32, _test_inner_agg_factory),
      ('struct_inner', _test_struct_type, _test_inner_agg_factory))
  def test_type_properties(self, value_type, inner_agg_factory):
    """Builds a DP aggregation process for a value type and inner factory.

    Asserts that the DP factory is an `UnweightedAggregationFactory`, then
    creates a process for `value_type`.

    NOTE(review): as visible here the test ends right after `create` and makes
    no assertion about the returned process or its type signatures; the rest
    of the body appears to be missing — confirm against the full file.
    """
    factory_ = differential_privacy.DifferentiallyPrivateFactory(
        _test_dp_query, inner_agg_factory)
    self.assertIsInstance(factory_, factory.UnweightedAggregationFactory)
    value_type = computation_types.to_type(value_type)
    process = factory_.create(value_type)
예제 #10
0
 def test_incorrect_value_type_raises(self, bad_value_type):
     """`create` must reject an unsupported value type with a TypeError."""
     factory_under_test = test_utils.SumPlusOneFactory()
     self.assertRaises(TypeError, factory_under_test.create, bad_value_type)
예제 #11
0
 def test_incorrect_value_type_raises(self, bad_value_type):
     """`create_unweighted` must reject an unsupported type with TypeError."""
     factory_under_test = aggregators_test_utils.SumPlusOneFactory()
     self.assertRaises(TypeError, factory_under_test.create_unweighted,
                       bad_value_type)