コード例 #1
0
ファイル: rotation_test.py プロジェクト: tensorflow/federated
  def test_type_properties(self, name, value_type):
    """Checks the initialize/next type signatures of a rotation aggregator.

    Args:
      name: Selects the factory under test: 'hd' picks `_hadamard_sum()`, any
        other value picks `_dft_sum()`. Also used as the measurements key.
      value_type: Type spec of the aggregated value; normalized with
        `computation_types.to_type` before use.
    """
    if name == 'hd':
      factory = _hadamard_sum()
    else:
      factory = _dft_sum()
    member_type = computation_types.to_type(value_type)
    process = factory.create(member_type)
    self.assertIsInstance(process, aggregation_process.AggregationProcess)

    # The state is a pair: an empty struct and `rotation.SEED_TFF_TYPE`.
    state_type = computation_types.at_server(((), rotation.SEED_TFF_TYPE))
    init_signature = computation_types.FunctionType(
        parameter=None, result=state_type)
    type_test_utils.assert_types_equivalent(process.initialize.type_signature,
                                            init_signature)

    # Measurements hold a single empty entry keyed by `name`.
    measurements_type = computation_types.at_server(
        collections.OrderedDict([(name, ())]))
    next_parameter = collections.OrderedDict(
        state=state_type, value=computation_types.at_clients(member_type))
    next_result = measured_process.MeasuredProcessOutput(
        state=state_type,
        result=computation_types.at_server(member_type),
        measurements=measurements_type)
    next_signature = computation_types.FunctionType(
        parameter=next_parameter, result=next_result)
    type_test_utils.assert_types_equivalent(process.next.type_signature,
                                            next_signature)
コード例 #2
0
  def test_clip(self, clip_mechanism):
    """Checks the type signatures of a histogram clipping sum process.

    Args:
      clip_mechanism: Forwarded as the first argument of
        `HistogramClippingSumFactory`; presumably supplied by a parameterized
        decorator outside this view.
    """
    histogram_factory = clipping_factory.HistogramClippingSumFactory(
        clip_mechanism, 1)
    self.assertIsInstance(histogram_factory,
                          factory.UnweightedAggregationFactory)
    member_type = computation_types.to_type((tf.int32, (2,)))
    process = histogram_factory.create(member_type)
    self.assertIsInstance(process, aggregation_process.AggregationProcess)

    client_values = computation_types.FederatedType(member_type,
                                                    placements.CLIENTS)
    server_result = computation_types.FederatedType(member_type,
                                                    placements.SERVER)
    # Both the state and the measurements are empty structs at the server.
    empty_state = computation_types.at_server(())
    empty_measurements = computation_types.at_server(())

    init_signature = computation_types.FunctionType(
        parameter=None, result=empty_state)
    init_matches = process.initialize.type_signature.is_equivalent_to(
        init_signature)
    self.assertTrue(init_matches)

    next_parameter = collections.OrderedDict(
        state=empty_state, value=client_values)
    next_result = measured_process.MeasuredProcessOutput(
        empty_state, server_result, empty_measurements)
    next_signature = computation_types.FunctionType(
        parameter=next_parameter, result=next_result)
    next_matches = process.next.type_signature.is_equivalent_to(next_signature)
    self.assertTrue(next_matches)
コード例 #3
0
    def test_type_properties(self, value_type):
        """Checks the type signatures of the `_discretization_sum()` process.

        Args:
            value_type: Type spec of the aggregated value; normalized with
                `computation_types.to_type` before use.
        """
        aggregator = _discretization_sum()
        member_type = computation_types.to_type(value_type)
        process = aggregator.create(member_type)
        self.assertIsInstance(process, aggregation_process.AggregationProcess)

        # Two float32 entries plus an empty inner-process state.
        state_fields = collections.OrderedDict([
            ('scale_factor', tf.float32),
            ('prior_norm_bound', tf.float32),
            ('inner_agg_process', ()),
        ])
        state_type = computation_types.at_server(state_fields)

        init_signature = computation_types.FunctionType(parameter=None,
                                                        result=state_type)
        type_test_utils.assert_types_equivalent(
            process.initialize.type_signature, init_signature)

        measurements_type = computation_types.at_server(
            collections.OrderedDict(discretize=()))
        next_parameter = collections.OrderedDict(
            state=state_type,
            value=computation_types.at_clients(member_type))
        next_result = measured_process.MeasuredProcessOutput(
            state=state_type,
            result=computation_types.at_server(member_type),
            measurements=measurements_type)
        next_signature = computation_types.FunctionType(
            parameter=next_parameter, result=next_result)
        type_test_utils.assert_types_equivalent(process.next.type_signature,
                                                next_signature)
コード例 #4
0
  def test_type_properties_simple(self):
    """Checks the type signatures of a modular clipping sum over int32[2]."""
    member_type = computation_types.to_type((tf.int32, (2,)))
    clipping_sum = modular_clipping_factory.ModularClippingSumFactory(
        clip_range_lower=-2, clip_range_upper=2)
    process = clipping_sum.create(member_type)
    self.assertIsInstance(process, aggregation_process.AggregationProcess)

    # The expected state is an empty struct placed at the server.
    state_type = computation_types.at_server(())

    init_signature = computation_types.FunctionType(
        parameter=None, result=state_type)
    init_matches = process.initialize.type_signature.is_equivalent_to(
        init_signature)
    self.assertTrue(init_matches)

    measurement_fields = collections.OrderedDict(modclip=())

    next_parameter = collections.OrderedDict(
        state=state_type, value=computation_types.at_clients(member_type))
    next_result = measured_process.MeasuredProcessOutput(
        state=state_type,
        result=computation_types.at_server(member_type),
        measurements=computation_types.at_server(measurement_fields))
    next_signature = computation_types.FunctionType(
        parameter=next_parameter, result=next_result)
    next_matches = process.next.type_signature.is_equivalent_to(next_signature)
    self.assertTrue(next_matches)
コード例 #5
0
    def test_clip_type_properties_weighted(self, value_type, weight_type):
        """Checks the type signatures of the weighted `_concat_mean()` process.

        Args:
            value_type: Type spec of the aggregated value; normalized with
                `computation_types.to_type` before use.
            weight_type: Type spec of the per-client weight; normalized the
                same way.
        """
        aggregator = _concat_mean()
        member_type = computation_types.to_type(value_type)
        weight_member_type = computation_types.to_type(weight_type)
        process = aggregator.create(member_type, weight_member_type)
        self.assertIsInstance(process, aggregation_process.AggregationProcess)

        # Expected state: two empty entries, one per inner sum process.
        state_type = computation_types.at_server(
            collections.OrderedDict([
                ('value_sum_process', ()),
                ('weight_sum_process', ()),
            ]))

        init_signature = computation_types.FunctionType(parameter=None,
                                                        result=state_type)
        type_test_utils.assert_types_equivalent(
            process.initialize.type_signature, init_signature)

        # Expected measurements: empty entries for the mean value and weight.
        measurements_type = computation_types.at_server(
            collections.OrderedDict([
                ('mean_value', ()),
                ('mean_weight', ()),
            ]))
        next_parameter = collections.OrderedDict(
            state=state_type,
            value=computation_types.at_clients(member_type),
            weight=computation_types.at_clients(weight_member_type))
        next_result = measured_process.MeasuredProcessOutput(
            state=state_type,
            result=computation_types.at_server(member_type),
            measurements=measurements_type)
        next_signature = computation_types.FunctionType(
            parameter=next_parameter, result=next_result)
        type_test_utils.assert_types_equivalent(process.next.type_signature,
                                                next_signature)
コード例 #6
0
    def test_raises_with_closure(self):
        """Expects a `RuntimeError` when a mapped lambda captures a variable.

        The inner computation `bar` returns `x` from the enclosing `foo`
        instead of its own argument; calling `foo` through the executor stack
        must raise with the captured-variables message.
        """
        eager_executor = eager_tf_executor.EagerTFExecutor()
        strategy_factory = (
            federated_resolving_strategy.FederatedResolvingStrategy.factory({
                placements.SERVER: eager_executor,
            }))
        federating = federating_executor.FederatingExecutor(
            strategy_factory, eager_executor)
        ex = reference_resolving_executor.ReferenceResolvingExecutor(
            federating)

        @federated_computation.federated_computation(
            tf.int32, computation_types.at_server(tf.int32))
        def foo(x, y):

            @federated_computation.federated_computation(tf.int32)
            def bar(z):
                # Deliberately ignore the argument and return the captured `x`.
                del z
                return x

            return intrinsics.federated_map(bar, y)

        fn_value = asyncio.run(ex.create_value(foo))
        arg_struct = structure.Struct([('x', 0), ('y', 0)])
        arg_type = [tf.int32, computation_types.at_server(tf.int32)]
        arg_value = asyncio.run(ex.create_value(arg_struct, arg_type))

        expected_message = (
            'lambda passed to intrinsic contains references to captured variables'
        )
        with self.assertRaisesRegex(RuntimeError, expected_message):
            asyncio.run(ex.create_call(fn_value, arg_value))
コード例 #7
0
    def test_type_properties(self, modulus, value_type, symmetric_range):
        """Checks type signatures and security of a secure modular sum.

        Args:
            modulus: Forwarded to `SecureModularSumFactory(modulus=...)`.
            value_type: Type spec of the aggregated value; normalized with
                `computation_types.to_type` before use.
            symmetric_range: Forwarded to
                `SecureModularSumFactory(symmetric_range=...)`.
        """
        factory_ = secure.SecureModularSumFactory(
            modulus=modulus, symmetric_range=symmetric_range)
        self.assertIsInstance(factory_, factory.UnweightedAggregationFactory)
        value_type = computation_types.to_type(value_type)
        process = factory_.create(value_type)
        self.assertIsInstance(process, aggregation_process.AggregationProcess)

        # State and measurements are both an empty struct at the server.
        expected_state_type = computation_types.at_server(
            computation_types.to_type(()))
        expected_measurements_type = expected_state_type

        expected_initialize_type = computation_types.FunctionType(
            parameter=None, result=expected_state_type)
        self.assertTrue(
            process.initialize.type_signature.is_equivalent_to(
                expected_initialize_type))

        expected_next_type = computation_types.FunctionType(
            parameter=collections.OrderedDict(
                state=expected_state_type,
                value=computation_types.at_clients(value_type)),
            result=measured_process.MeasuredProcessOutput(
                state=expected_state_type,
                result=computation_types.at_server(value_type),
                measurements=expected_measurements_type))
        self.assertTrue(
            process.next.type_signature.is_equivalent_to(expected_next_type))
        try:
            static_assert.assert_not_contains_unsecure_aggregation(
                process.next)
        # Catch `Exception` rather than everything so that KeyboardInterrupt
        # and SystemExit still propagate instead of being reported as a test
        # failure. The original error stays attached as the failure's context.
        except Exception:  # pylint: disable=broad-except
            self.fail('Factory returned an AggregationProcess containing '
                      'non-secure aggregation.')
コード例 #8
0
    def test_type_properties(self):
        """Checks the type signatures of an apply-optimizer finalizer.

        The finalizer is built from `sgdm.build_sgdm(1.0)` and a model-weights
        type with two float32 trainable entries and one float32 non-trainable
        entry.
        """
        weights_struct = model_utils.ModelWeights(
            trainable=(tf.float32, tf.float32), non_trainable=tf.float32)
        mw_type = computation_types.to_type(weights_struct)

        optimizer = sgdm.build_sgdm(1.0)
        finalizer = finalizers.build_apply_optimizer_finalizer(
            optimizer, mw_type)
        self.assertIsInstance(finalizer, finalizers.FinalizerProcess)

        weights_at_server = computation_types.at_server(mw_type)
        # The update argument covers only the trainable part of the weights.
        update_at_server = computation_types.at_server(mw_type.trainable)
        result_at_server = computation_types.at_server(mw_type)
        state_at_server = computation_types.at_server(())
        measurements_at_server = computation_types.at_server(())

        init_signature = computation_types.FunctionType(
            parameter=None, result=state_at_server)
        init_signature.check_equivalent_to(finalizer.initialize.type_signature)

        next_parameter = collections.OrderedDict(
            state=state_at_server,
            weights=weights_at_server,
            update=update_at_server)
        next_result = MeasuredProcessOutput(state_at_server, result_at_server,
                                            measurements_at_server)
        next_signature = computation_types.FunctionType(
            parameter=next_parameter, result=next_result)
        next_signature.check_equivalent_to(finalizer.next.type_signature)
コード例 #9
0
    def test_zero_type_properties_with_zeroed_count_agg_factory(
            self, value_type):
        """Checks zeroing-factory types with a custom zeroed-count aggregator.

        Args:
            value_type: Type spec of the aggregated value; normalized with
                `computation_types.to_type` before use.
        """
        zeroing = robust.zeroing_factory(
            zeroing_norm=1.0,
            inner_agg_factory=sum_factory.SumFactory(),
            norm_order=2.0,
            zeroed_count_sum_factory=aggregator_test_utils.SumPlusOneFactory())
        member_type = computation_types.to_type(value_type)
        process = zeroing.create(member_type)
        self.assertIsInstance(process, aggregation_process.AggregationProcess)

        # Only the zeroed-count aggregator entry is non-empty (int32).
        state_fields = collections.OrderedDict([
            ('zeroing_norm', ()),
            ('inner_agg', ()),
            ('zeroed_count_agg', tf.int32),
        ])
        state_type = computation_types.at_server(state_fields)
        init_signature = computation_types.FunctionType(parameter=None,
                                                        result=state_type)
        init_matches = process.initialize.type_signature.is_equivalent_to(
            init_signature)
        self.assertTrue(init_matches)

        measurement_fields = collections.OrderedDict([
            ('zeroing', ()),
            ('zeroing_norm', robust.NORM_TF_TYPE),
            ('zeroed_count', robust.COUNT_TF_TYPE),
        ])
        measurements_type = computation_types.at_server(measurement_fields)
        next_parameter = collections.OrderedDict(
            state=state_type,
            value=computation_types.at_clients(member_type))
        next_result = measured_process.MeasuredProcessOutput(
            state=state_type,
            result=computation_types.at_server(member_type),
            measurements=measurements_type)
        next_signature = computation_types.FunctionType(
            parameter=next_parameter, result=next_result)
        next_matches = process.next.type_signature.is_equivalent_to(
            next_signature)
        self.assertTrue(next_matches)
コード例 #10
0
    def test_with_federated_map_and_broadcast(self):
        """Broadcasts a server value to three clients and maps add-one on it.

        Expects each of the three client results to be 11 for a server input
        of 10.
        """
        eager_executor = eager_tf_executor.EagerTFExecutor()
        # The same eager executor serves the server and all three clients.
        placement_targets = {
            placements.SERVER: eager_executor,
            placements.CLIENTS: [eager_executor] * 3,
        }
        strategy_factory = (
            federated_resolving_strategy.FederatedResolvingStrategy.factory(
                placement_targets))
        federating = federating_executor.FederatingExecutor(
            strategy_factory, eager_executor)
        ex = reference_resolving_executor.ReferenceResolvingExecutor(
            federating)

        @tensorflow_computation.tf_computation(tf.int32)
        def add_one(x):
            return x + 1

        @federated_computation.federated_computation(
            computation_types.at_server(tf.int32))
        def comp(x):
            broadcasted = intrinsics.federated_broadcast(x)
            return intrinsics.federated_map(add_one, broadcasted)

        fn_value = asyncio.run(ex.create_value(comp))
        arg_value = asyncio.run(
            ex.create_value(10, computation_types.at_server(tf.int32)))
        call_value = asyncio.run(ex.create_call(fn_value, arg_value))
        client_results = asyncio.run(call_value.compute())
        observed = [item.numpy() for item in client_results]
        self.assertCountEqual(observed, [11, 11, 11])
コード例 #11
0
    def test_federated_evaluation(self):
        """Checks the type signature and the result of federated evaluation.

        Builds the evaluation computation for `TestModel`, verifies its type
        signature, then runs it over three client datasets and expects
        `num_over` to be 9.0.
        """
        evaluate = federated_evaluation.build_federated_evaluation(TestModel)
        model_weights_type = model_utils.weights_type_from_model(TestModel)

        # Each client dataset is a sequence of {'temp': float32[None]} batches.
        weights_entry = ('server_model_weights',
                         computation_types.at_server(model_weights_type))
        batch_type = StructType([('temp',
                                  TensorType(dtype=tf.float32, shape=[None]))])
        dataset_entry = ('federated_dataset',
                         computation_types.at_clients(SequenceType(batch_type)))
        expected_result_type = computation_types.at_server(
            collections.OrderedDict(eval=collections.OrderedDict(
                num_over=tf.float32)))
        expected_signature = FunctionType(
            parameter=StructType([weights_entry, dataset_entry]),
            result=expected_result_type)
        type_test_utils.assert_types_equivalent(evaluate.type_signature,
                                                expected_signature)

        def _batch(temps):
            return {'temp': np.array(temps, dtype=np.float32)}

        server_weights = collections.OrderedDict(trainable=[5.0],
                                                 non_trainable=[])
        client_datasets = [
            [_batch([1.0, 10.0, 2.0, 7.0]), _batch([6.0, 11.0])],
            [_batch([9.0, 12.0, 13.0])],
            [_batch([1.0]), _batch([22.0, 23.0])],
        ]
        result = evaluate(server_weights, client_datasets)

        expected = collections.OrderedDict(
            eval=collections.OrderedDict(num_over=9.0))
        self.assertEqual(result, expected)
コード例 #12
0
 def test_federated_map_injected_zip_with_server_int(self):
      """Maps a computation over a list of two server-placed int32 values.

      Expects the mapped value to be described as 'bool@SERVER'.
      """
      mapped_fn = _create_computation_greater_than_10_with_unused_parameter()
      server_int = computation_types.at_server(tf.int32)
      first = _mock_data_of_type(server_int)
      second = _mock_data_of_type(computation_types.at_server(tf.int32))
      # Passing a list of two federated values is the case under test.
      mapped = intrinsics.federated_map(mapped_fn, [first, second])
      self.assert_value(mapped, 'bool@SERVER')
コード例 #13
0
def create_whimsy_intrinsic_def_federated_zip_at_server():
    """Returns the `FEDERATED_ZIP_AT_SERVER` intrinsic def with a sample type.

    Returns:
        A `(value, type_signature)` pair, where the signature takes two
        server-placed float32 values and returns a server-placed pair of
        float32 values.
    """
    parameter_types = [
        computation_types.at_server(tf.float32),
        computation_types.at_server(tf.float32),
    ]
    result_type = computation_types.at_server([tf.float32, tf.float32])
    signature = computation_types.FunctionType(parameter_types, result_type)
    return intrinsic_defs.FEDERATED_ZIP_AT_SERVER, signature
コード例 #14
0
  def test_type_properties(self, value_type):
    """Checks the type signatures of a stochastic discretization process.

    Args:
      value_type: Type spec of the aggregated value; normalized with
        `computation_types.to_type` before use.
    """
    discretizer = stochastic_discretization.StochasticDiscretizationFactory(
        step_size=0.1,
        inner_agg_factory=_measurement_aggregator,
        distortion_aggregation_factory=mean.UnweightedMeanFactory())
    member_type = computation_types.to_type(value_type)

    def _as_int32(tensor_type):
      # Same shape as the input tensor type, but with an int32 dtype.
      return (tf.int32, tensor_type.shape)

    quantized_type = type_conversions.structure_from_tensor_type_tree(
        _as_int32, member_type)
    process = discretizer.create(member_type)
    self.assertIsInstance(process, aggregation_process.AggregationProcess)

    state_struct = computation_types.StructType([
        ('step_size', tf.float32),
        ('inner_agg_process', ()),
    ])
    state_type = computation_types.at_server(state_struct)
    init_signature = computation_types.FunctionType(
        parameter=None, result=state_type)
    type_test_utils.assert_types_equivalent(process.initialize.type_signature,
                                            init_signature)

    measurements_struct = computation_types.StructType([
        ('stochastic_discretization', quantized_type),
        ('distortion', tf.float32),
    ])
    measurements_type = computation_types.at_server(measurements_struct)
    next_parameter = collections.OrderedDict(
        state=state_type, value=computation_types.at_clients(member_type))
    next_result = measured_process.MeasuredProcessOutput(
        state=state_type,
        result=computation_types.at_server(member_type),
        measurements=measurements_type)
    next_signature = computation_types.FunctionType(
        parameter=next_parameter, result=next_result)
    type_test_utils.assert_types_equivalent(process.next.type_signature,
                                            next_signature)
コード例 #15
0
  def test_type_properties_unweighted(self, value_type):
    """Checks the type signatures of an unweighted mean process.

    Args:
      value_type: Type spec of the aggregated value; normalized with
        `computation_types.to_type` before use.
    """
    member_type = computation_types.to_type(value_type)

    mean_factory = mean.UnweightedMeanFactory()
    self.assertIsInstance(mean_factory, factory.UnweightedAggregationFactory)
    process = mean_factory.create(member_type)

    self.assertIsInstance(process, aggregation_process.AggregationProcess)

    client_values = computation_types.at_clients(member_type)
    server_result = computation_types.at_server(member_type)

    # The expected state is a pair of empty structs.
    state_type = computation_types.at_server(((), ()))
    measurements_type = computation_types.at_server(
        collections.OrderedDict([('mean_value', ()), ('mean_count', ())]))

    init_signature = computation_types.FunctionType(
        parameter=None, result=state_type)
    init_matches = process.initialize.type_signature.is_equivalent_to(
        init_signature)
    self.assertTrue(init_matches)

    next_parameter = collections.OrderedDict(
        state=state_type, value=client_values)
    next_result = measured_process.MeasuredProcessOutput(
        state_type, server_result, measurements_type)
    next_signature = computation_types.FunctionType(
        parameter=next_parameter, result=next_result)
    next_matches = process.next.type_signature.is_equivalent_to(next_signature)
    self.assertTrue(next_matches)
コード例 #16
0
    def test_concat_type_properties_unweighted(self, value_type):
        """Checks the type signatures of the `_concat_sum()` process.

        Args:
            value_type: Type spec of the aggregated value; normalized with
                `computation_types.to_type` before use.
        """
        aggregator = _concat_sum()
        member_type = computation_types.to_type(value_type)
        process = aggregator.create(member_type)
        self.assertIsInstance(process, aggregation_process.AggregationProcess)

        # The expected state is an empty struct at the server.
        state_type = computation_types.at_server(())

        init_signature = computation_types.FunctionType(parameter=None,
                                                        result=state_type)
        type_test_utils.assert_types_equivalent(
            process.initialize.type_signature, init_signature)

        # The expected measurements are likewise empty.
        measurements_type = computation_types.at_server(())
        next_parameter = collections.OrderedDict(
            state=state_type,
            value=computation_types.at_clients(member_type))
        next_result = measured_process.MeasuredProcessOutput(
            state=state_type,
            result=computation_types.at_server(member_type),
            measurements=measurements_type)
        next_signature = computation_types.FunctionType(
            parameter=next_parameter, result=next_result)
        type_test_utils.assert_types_equivalent(process.next.type_signature,
                                                next_signature)
コード例 #17
0
    def test_with_federated_map(self):
        """Maps an add-one computation over a server-placed value.

        Builds a `ReferenceResolvingExecutor` over a `FederatingExecutor`
        whose only placement target is an eager executor at the server, then
        checks that mapping `add_one` over a server-placed 10 yields 11.
        """
        eager_ex = eager_tf_executor.EagerTFExecutor()
        factory = federated_resolving_strategy.FederatedResolvingStrategy.factory(
            {placements.SERVER: eager_ex})
        federated_ex = federating_executor.FederatingExecutor(
            factory, eager_ex)
        ex = reference_resolving_executor.ReferenceResolvingExecutor(
            federated_ex)

        @computations.tf_computation(tf.int32)
        def add_one(x):
            return x + 1

        @computations.federated_computation(
            computation_types.at_server(tf.int32))
        def comp(x):
            return intrinsics.federated_map(add_one, x)

        # `asyncio.get_event_loop()` is deprecated when no loop is running, so
        # each coroutine is driven with `asyncio.run` instead.
        v1 = asyncio.run(ex.create_value(comp))
        v2 = asyncio.run(
            ex.create_value(10, computation_types.at_server(tf.int32)))
        v3 = asyncio.run(ex.create_call(v1, v2))
        result = asyncio.run(v3.compute())
        self.assertEqual(result.numpy(), 11)
コード例 #18
0
    def test_type_properties_simple(self, value_type, estimate_stddev):
        """Checks the type signatures of the `_test_factory` process.

        Args:
            value_type: Type spec of the aggregated value; converted with
                `computation_types.to_type` when creating the process.
            estimate_stddev: Forwarded to `_test_factory`; when truthy, a
                float32 'estimated_stddev' entry is expected in the
                measurements.
        """
        aggregator = _test_factory(estimate_stddev=estimate_stddev)
        process = aggregator.create(computation_types.to_type(value_type))
        self.assertIsInstance(process, aggregation_process.AggregationProcess)

        # The expected state is an empty struct at the server.
        state_type = computation_types.at_server(())

        init_signature = computation_types.FunctionType(parameter=None,
                                                        result=state_type)
        init_matches = process.initialize.type_signature.is_equivalent_to(
            init_signature)
        self.assertTrue(init_matches)

        measurement_fields = collections.OrderedDict([('modclip', ())])
        if estimate_stddev:
            measurement_fields['estimated_stddev'] = tf.float32

        next_parameter = collections.OrderedDict(
            state=state_type,
            value=computation_types.at_clients(value_type))
        next_result = measured_process.MeasuredProcessOutput(
            state=state_type,
            result=computation_types.at_server(value_type),
            measurements=computation_types.at_server(measurement_fields))
        next_signature = computation_types.FunctionType(
            parameter=next_parameter, result=next_result)
        next_matches = process.next.type_signature.is_equivalent_to(
            next_signature)
        self.assertTrue(next_matches)
コード例 #19
0
    def test_clip_type_properties_simple(self, value_type):
        """Checks the type signatures of the `_clipped_sum()` process.

        Args:
            value_type: Type spec of the aggregated value; normalized with
                `computation_types.to_type` before use.
        """
        aggregator = _clipped_sum()
        member_type = computation_types.to_type(value_type)
        process = aggregator.create(member_type)
        self.assertIsInstance(process, aggregation_process.AggregationProcess)

        # All three state entries are expected to be empty.
        state_fields = collections.OrderedDict([
            ('clipping_norm', ()),
            ('inner_agg', ()),
            ('clipped_count_agg', ()),
        ])
        state_type = computation_types.at_server(state_fields)
        init_signature = computation_types.FunctionType(parameter=None,
                                                        result=state_type)
        init_matches = process.initialize.type_signature.is_equivalent_to(
            init_signature)
        self.assertTrue(init_matches)

        measurement_fields = collections.OrderedDict([
            ('clipping', ()),
            ('clipping_norm', robust.NORM_TF_TYPE),
            ('clipped_count', robust.COUNT_TF_TYPE),
        ])
        measurements_type = computation_types.at_server(measurement_fields)
        next_parameter = collections.OrderedDict(
            state=state_type,
            value=computation_types.at_clients(member_type))
        next_result = measured_process.MeasuredProcessOutput(
            state=state_type,
            result=computation_types.at_server(member_type),
            measurements=measurements_type)
        next_signature = computation_types.FunctionType(
            parameter=next_parameter, result=next_result)
        next_matches = process.next.type_signature.is_equivalent_to(
            next_signature)
        self.assertTrue(next_matches)
コード例 #20
0
ファイル: secure_test.py プロジェクト: xingzhis/federated
    def test_type_properties_constant_bounds(self, value_type, upper_bound,
                                             lower_bound, measurements_dtype):
        """Checks the type signatures of a secure sum with the given bounds.

        Args:
            value_type: Type spec of the aggregated value; normalized with
                `computation_types.to_type` before use.
            upper_bound: Passed as `upper_bound_threshold` to
                `SecureSumFactory`.
            lower_bound: Passed as `lower_bound_threshold` to
                `SecureSumFactory`.
            measurements_dtype: Passed to `_measurements_type` to build the
                expected measurements type.
        """
        secure_factory = secure.SecureSumFactory(
            upper_bound_threshold=upper_bound,
            lower_bound_threshold=lower_bound)
        self.assertIsInstance(secure_factory,
                              factory.UnweightedAggregationFactory)
        member_type = computation_types.to_type(value_type)
        process = secure_factory.create(member_type)
        self.assertIsInstance(process, aggregation_process.AggregationProcess)

        # The expected state is an empty struct at the server.
        empty_struct = computation_types.to_type(())
        state_type = computation_types.at_server(empty_struct)
        measurements_type = _measurements_type(measurements_dtype)

        init_signature = computation_types.FunctionType(parameter=None,
                                                        result=state_type)
        init_matches = process.initialize.type_signature.is_equivalent_to(
            init_signature)
        self.assertTrue(init_matches)

        next_parameter = collections.OrderedDict(
            state=state_type,
            value=computation_types.at_clients(member_type))
        next_result = measured_process.MeasuredProcessOutput(
            state=state_type,
            result=computation_types.at_server(member_type),
            measurements=measurements_type)
        next_signature = computation_types.FunctionType(
            parameter=next_parameter, result=next_result)
        next_matches = process.next.type_signature.is_equivalent_to(
            next_signature)
        self.assertTrue(next_matches)
コード例 #21
0
def create_whimsy_intrinsic_def_federated_apply():
    """Returns the `FEDERATED_APPLY` intrinsic def with a sample type.

    Returns:
        A `(value, type_signature)` pair, where the signature takes a float32
        unary op and a server-placed float32 and returns a server-placed
        float32.
    """
    parameter_types = [
        type_factory.unary_op(tf.float32),
        computation_types.at_server(tf.float32),
    ]
    result_type = computation_types.at_server(tf.float32)
    signature = computation_types.FunctionType(parameter_types, result_type)
    return intrinsic_defs.FEDERATED_APPLY, signature
コード例 #22
0
 def test_federated_secure_select(self):
      """Checks how the secure-select intrinsic is reduced to its body.

      The plain reduction must leave the intrinsic untouched; the
      secure-to-insecure reduction must report a modification, keep the type
      signature and introduce at least one `FEDERATED_SELECT` intrinsic.
      """
      uri = intrinsic_defs.FEDERATED_SECURE_SELECT.uri
      parameter_types = [
          computation_types.at_clients(tf.int32),  # client_keys
          computation_types.at_server(tf.int32),  # max_key
          computation_types.at_server(tf.float32),  # server_state
          computation_types.FunctionType([tf.float32, tf.int32],
                                         tf.float32),  # select_fn
      ]
      result_type = computation_types.at_clients(
          computation_types.SequenceType(tf.float32))
      comp = building_blocks.Intrinsic(
          uri, computation_types.FunctionType(parameter_types, result_type))
      self.assertGreater(_count_intrinsics(comp, uri), 0)

      # Pass 1: the non-secure reduction must not touch a secure intrinsic.
      plain_result = intrinsic_reductions.replace_intrinsics_with_bodies(comp)
      reduced, modified = plain_result
      self.assertFalse(modified)
      self.assertGreater(_count_intrinsics(comp, uri), 0)
      self.assert_types_identical(comp.type_signature,
                                  reduced.type_signature)

      # Pass 2: the secure-to-insecure reduction replaces the intrinsic.
      secure_result = (
          intrinsic_reductions.replace_secure_intrinsics_with_insecure_bodies(
              comp))
      reduced, modified = secure_result
      self.assertTrue(modified)
      self.assert_types_identical(comp.type_signature,
                                  reduced.type_signature)
      select_count = _count_intrinsics(reduced,
                                       intrinsic_defs.FEDERATED_SELECT.uri)
      self.assertGreater(select_count, 0)
コード例 #23
0
def _run_in_federated_computation(optimizer, spec):
    """Drives `optimizer` for three steps through federated computations.

    Weights and gradients both start as all-ones structures matching `spec`.
    The optimizer state is initialized at the server, and each step maps
    `optimizer.next` over the server-placed `(state, weights, gradients)`.

    Args:
      optimizer: Object exposing `initialize(spec)` and `next`.
      spec: Structure of items with `shape` and `dtype` attributes.

    Returns:
      A `(state_history, weights_history)` pair of lists, each holding the
      initial value followed by the value after each of the three steps.
    """

    def _ones_matching_spec():
        return tf.nest.map_structure(lambda s: tf.ones(s.shape, s.dtype),
                                     spec)

    weights = _ones_matching_spec()
    gradients = _ones_matching_spec()

    @federated_computation.federated_computation()
    def init_fn():
        initialize_comp = tensorflow_computation.tf_computation(
            lambda: optimizer.initialize(spec))
        return intrinsics.federated_eval(initialize_comp, placements.SERVER)

    @federated_computation.federated_computation(
        init_fn.type_signature.result,
        computation_types.at_server(computation_types.to_type(spec)),
        computation_types.at_server(computation_types.to_type(spec)))
    def next_fn(state, weights, gradients):
        step_comp = tensorflow_computation.tf_computation(optimizer.next)
        return intrinsics.federated_map(step_comp,
                                        (state, weights, gradients))

    state = init_fn()
    state_history, weights_history = [state], [weights]
    for _ in range(3):
        state, weights = next_fn(state, weights, gradients)
        state_history.append(state)
        weights_history.append(weights)

    return state_history, weights_history
コード例 #24
0
  def test_composition_type_properties(self, last_process):
    """Verifies type signatures of a chain of two measured processes.

    Chains the test "double" process with the process built by the
    `last_process` factory and checks the composite `initialize` and `next`
    signatures against the expected state, value and measurements types.
    """
    state_type = tf.float32
    values_type = tf.int32
    tail_process = last_process(state_type, 0.0, values_type)
    double_process = _create_test_measured_process_double(
        state_type, 1.0, values_type)
    composite_process = measured_process.chain_measured_processes(
        collections.OrderedDict(
            double=double_process, last_process=tail_process))
    self.assertIsInstance(composite_process, measured_process.MeasuredProcess)

    # State: one entry per chained process, placed at the server.
    expected_state_type = computation_types.at_server(
        collections.OrderedDict(double=state_type, last_process=state_type))
    expected_initialize_type = computation_types.FunctionType(
        parameter=None, result=expected_state_type)
    initialize_signature = composite_process.initialize.type_signature
    self.assertTrue(
        initialize_signature.is_equivalent_to(expected_initialize_type))

    # Measurements: the fixed "double" entry plus whatever the last process
    # reports in its own `next` signature.
    tail_measurements_type = (
        tail_process.next.type_signature.result.measurements.member)
    expected_measurements_type = computation_types.at_server(
        collections.OrderedDict(
            double=collections.OrderedDict(a=tf.int32),
            last_process=tail_measurements_type))
    expected_next_type = computation_types.FunctionType(
        parameter=collections.OrderedDict(
            state=expected_state_type,
            values=computation_types.at_clients(values_type)),
        result=measured_process.MeasuredProcessOutput(
            state=expected_state_type,
            result=computation_types.at_server(values_type),
            measurements=expected_measurements_type))
    next_signature = composite_process.next.type_signature
    self.assertTrue(next_signature.is_equivalent_to(expected_next_type))
コード例 #25
0
    async def compute_federated_aggregate(
        self, arg: FederatedComposingStrategyValue
    ) -> FederatedComposingStrategyValue:
        """Computes a federated aggregate across the target executors.

        Each target executor runs `FEDERATED_AGGREGATE` on its own slice of the
        value, using an identity function in place of `report` so that it
        returns a partial result of `zero_type`. The partial results are then
        combined pairwise with `merge` on the server executor, and `report` is
        applied once to the final merged value.

        Args:
          arg: The aggregate argument. Its internal representation must be a
            5-element `structure.Struct` (value, zero, accumulate, merge,
            report) whose first element is a list with one entry per target
            executor.

        Returns:
          A `FederatedComposingStrategyValue` wrapping the reported result,
          typed as the result of `report` placed at the server.
        """
        value_type, zero_type, accumulate_type, merge_type, report_type = (
            executor_utils.parse_federated_aggregate_argument_types(
                arg.type_signature))
        py_typecheck.check_type(arg.internal_representation, structure.Struct)
        py_typecheck.check_len(arg.internal_representation, 5)
        val = arg.internal_representation[0]
        py_typecheck.check_type(val, list)
        py_typecheck.check_len(val, len(self._target_executors))
        # Children aggregate with an identity `report`, so each of them returns
        # a value of `zero_type`; the real `report` is applied only once, below.
        identity_report, identity_report_type = tensorflow_computation_factory.create_identity(
            zero_type)
        aggr_type = computation_types.FunctionType(
            computation_types.StructType([
                value_type, zero_type, accumulate_type, merge_type,
                identity_report_type
            ]), computation_types.at_server(zero_type))
        aggr_comp = executor_utils.create_intrinsic_comp(
            intrinsic_defs.FEDERATED_AGGREGATE, aggr_type)
        # Materialize `zero` (element 1 of the argument) so that it can be
        # re-created as a value in every child executor.
        zero = await (await self._executor.create_selection(arg, 1)).compute()
        accumulate = arg.internal_representation[2]
        merge = arg.internal_representation[3]
        report = arg.internal_representation[4]

        async def _child_fn(ex, v):
            """Runs the aggregate in child `ex` on `v`; returns it on the server."""
            py_typecheck.check_type(v, executor_value_base.ExecutorValue)
            # Coroutines only; they are awaited together via `gather` below.
            arg_values = [
                ex.create_value(zero, zero_type),
                ex.create_value(accumulate, accumulate_type),
                ex.create_value(merge, merge_type),
                ex.create_value(identity_report, identity_report_type)
            ]
            aggr_func, aggr_args = await asyncio.gather(
                ex.create_value(aggr_comp, aggr_type),
                ex.create_struct([v] +
                                 list(await asyncio.gather(*arg_values))))
            child_result = await (await ex.create_call(aggr_func,
                                                       aggr_args)).compute()
            # Re-embed the computed partial result in the server executor so it
            # can be passed to `merge` there.
            result_at_server = await self._server_executor.create_value(
                child_result, zero_type)
            return result_at_server

        # Partial results are consumed in completion order rather than in
        # target-executor order.
        val_futures = asyncio.as_completed(
            [_child_fn(c, v) for c, v in zip(self._target_executors, val)])
        parent_merge, parent_report = await asyncio.gather(
            self._server_executor.create_value(merge, merge_type),
            self._server_executor.create_value(report, report_type))
        # NOTE(review): `next()` assumes at least one target executor; with an
        # empty list there would be nothing to yield here -- confirm callers
        # guarantee this.
        merge_result = await next(val_futures)
        for next_val_future in val_futures:
            next_val = await next_val_future
            merge_arg = await self._server_executor.create_struct(
                [merge_result, next_val])
            merge_result = await self._server_executor.create_call(
                parent_merge, merge_arg)
        report_result = await self._server_executor.create_call(
            parent_report, merge_result)
        return FederatedComposingStrategyValue(
            report_result, computation_types.at_server(report_type.result))
コード例 #26
0
 def test_serialize_deserialize_federated_at_server(self):
   """Round-trips a server-placed int32 through `value_serialization`."""
   server_int_type = computation_types.at_server(tf.int32)
   serialized, serialized_type = value_serialization.serialize_value(
       10, server_int_type)
   type_test_utils.assert_types_identical(
       serialized_type, computation_types.at_server(tf.int32))

   restored, restored_type = value_serialization.deserialize_value(serialized)
   type_test_utils.assert_types_identical(restored_type, server_int_type)
   self.assertEqual(restored, 10)
コード例 #27
0
 def test_serialize_deserialize_federated_at_server(self):
     """Round-trips a server-placed int32 through `executor_serialization`."""
     server_int_type = computation_types.at_server(tf.int32)
     serialized, serialized_type = executor_serialization.serialize_value(
         10, server_int_type)
     self.assertIsInstance(serialized, executor_pb2.Value)
     self.assert_types_identical(serialized_type,
                                 computation_types.at_server(tf.int32))

     restored, restored_type = executor_serialization.deserialize_value(
         serialized)
     self.assert_types_identical(restored_type, server_int_type)
     self.assertEqual(restored, 10)
コード例 #28
0
def _build_expected_test_quant_model_eval_signature():
    """Returns signature for build_federated_evaluation using TestModelQuant.

    The parameter is an ordered pair of server-placed model weights and a
    client-placed sequence of `temp` float32 tensors of shape `(None,)`; the
    result is an ordered dict holding a server-placed float32 `num_same`.
    """
    weights_parameter_type = computation_types.at_server(
        model_utils.weights_type_from_model(TestModelQuant))

    element_type = collections.OrderedDict(
        temp=computation_types.TensorType(shape=(None, ), dtype=tf.float32))
    data_parameter_type = computation_types.at_clients(
        computation_types.SequenceType(element_type))

    return_type = collections.OrderedDict(
        num_same=computation_types.at_server(tf.float32))

    parameter_type = collections.OrderedDict(
        server_model_weights=weights_parameter_type,
        federated_dataset=data_parameter_type)
    return computation_types.FunctionType(
        parameter=parameter_type, result=return_type)
コード例 #29
0
  def test_central_aggregation_with_secure_sum(self, value_shape, arity,
                                               l1_bound):
    """Checks type signatures of the secure-sum central histogram process.

    Creates the central hierarchical histogram factory with
    `secure_sum=True`, builds a process for a float32 vector of length
    `value_shape`, and compares its `initialize` and `next` signatures with
    the expected ones. `l1_bound` is accepted but not used here.
    """
    value_type = computation_types.to_type((tf.float32, (value_shape,)))

    factory_ = hihi_factory.create_central_hierarchical_histogram_factory(
        arity=arity, secure_sum=True)
    self.assertIsInstance(factory_,
                          differential_privacy.DifferentiallyPrivateFactory)

    process = factory_.create(value_type)
    self.assertIsInstance(process, aggregation_process.AggregationProcess)

    # Types derived from the reference central DP query.
    query_state = _test_central_dp_query.initial_global_state()
    query_state_type = type_conversions.type_from_tensors(query_state)
    query_metrics = _test_central_dp_query.derive_metrics(query_state)
    query_metrics_type = type_conversions.type_from_tensors(query_metrics)

    server_state_type = computation_types.at_server((query_state_type, ()))
    expected_initialize_type = computation_types.FunctionType(
        parameter=None, result=server_state_type)
    initialize_signature = process.initialize.type_signature
    self.assertTrue(
        initialize_signature.is_equivalent_to(expected_initialize_type))

    secure_dp_type = collections.OrderedDict([
        ('secure_upper_clipped_count', tf.int32),
        ('secure_lower_clipped_count', tf.int32),
        ('secure_upper_threshold', tf.float32),
        ('secure_lower_threshold', tf.float32),
    ])
    expected_measurements_type = computation_types.at_server(
        collections.OrderedDict(
            dp_query_metrics=query_metrics_type, dp=secure_dp_type))

    flat_values_type = computation_types.TensorType(tf.float32,
                                                    tf.TensorShape(None))
    result_value_type = computation_types.to_type(
        collections.OrderedDict(
            flat_values=flat_values_type,
            nested_row_splits=[(tf.int64, (None,))]))

    expected_next_type = computation_types.FunctionType(
        parameter=collections.OrderedDict(
            state=server_state_type,
            value=computation_types.at_clients(value_type)),
        result=measured_process.MeasuredProcessOutput(
            state=server_state_type,
            result=computation_types.at_server(result_value_type),
            measurements=expected_measurements_type))

    next_signature = process.next.type_signature
    self.assertTrue(next_signature.is_equivalent_to(expected_next_type))
コード例 #30
0
def test_finalizer():
    """Builds a `FinalizerProcess` that adds updates to trainable weights.

    The `next` computation takes the state, server-placed model weights and
    server-placed updates. It returns the state unchanged, new model weights
    whose trainable part is `weights.trainable + updates` and whose second
    component is an empty tuple, and the result of `empty_at_server()` as
    measurements.

    Returns:
      A `finalizers.FinalizerProcess` built from `empty_init_fn` and the
      `next` computation described above.
    """

    @federated_computation.federated_computation(
        empty_init_fn.type_signature.result,
        computation_types.at_server(MODEL_WEIGHTS_TYPE),
        computation_types.at_server(FLOAT_TYPE))
    def next_fn(state, weights, updates):
        add_comp = tensorflow_computation.tf_computation(lambda x, y: x + y)
        updated_trainable = intrinsics.federated_map(
            add_comp, (weights.trainable, updates))
        updated_weights = intrinsics.federated_zip(
            model_utils.ModelWeights(updated_trainable, ()))
        return measured_process.MeasuredProcessOutput(
            state=state,
            result=updated_weights,
            measurements=empty_at_server())

    return finalizers.FinalizerProcess(empty_init_fn, next_fn)