コード例 #1
0
    def test_with_federated_map(self):
        """Maps `add_one` over a server-placed int32 and expects 11 for input 10."""
        tf_executor = eager_tf_executor.EagerTFExecutor()
        strategy_factory = (
            federated_resolving_strategy.FederatedResolvingStrategy.factory(
                {placements.SERVER: tf_executor}))
        ex = reference_resolving_executor.ReferenceResolvingExecutor(
            federating_executor.FederatingExecutor(strategy_factory,
                                                   tf_executor))
        run = asyncio.get_event_loop().run_until_complete

        @computations.tf_computation(tf.int32)
        def add_one(x):
            return x + 1

        @computations.federated_computation(
            computation_types.at_server(tf.int32))
        def comp(x):
            return intrinsics.federated_map(add_one, x)

        # Embed the computation and its argument, invoke, then materialize.
        fn_value = run(ex.create_value(comp))
        arg_value = run(
            ex.create_value(10, computation_types.at_server(tf.int32)))
        call_value = run(ex.create_call(fn_value, arg_value))
        materialized = run(call_value.compute())
        self.assertEqual(materialized.numpy(), 11)
コード例 #2
0
    def test_type_properties(self, value_type, factory_cons):
        """Checks the `initialize` and `next` signatures of the created process."""
        agg_factory = factory_cons()
        value_type = computation_types.to_type(value_type)
        process = agg_factory.create(value_type)
        self.assertIsInstance(process, aggregation_process.AggregationProcess)

        # Expected state: a pair of an empty struct and the two inner
        # sum-process entries, placed at the server.
        inner_sum_states = collections.OrderedDict([
            ('value_sum_process', ()),
            ('weight_sum_process', ()),
        ])
        state_at_server = computation_types.at_server(((), inner_sum_states))
        init_signature = computation_types.FunctionType(
            parameter=None, result=state_at_server)
        self.assertTrue(
            process.initialize.type_signature.is_equivalent_to(init_signature))

        measurements_at_server = computation_types.at_server(
            collections.OrderedDict([
                ('value_sum_process', ()),
                ('weight_sum_process', ()),
            ]))
        next_parameter = collections.OrderedDict([
            ('state', state_at_server),
            ('value', computation_types.at_clients(value_type)),
            ('weight', computation_types.at_clients(tf.float32)),
        ])
        next_result = measured_process.MeasuredProcessOutput(
            state=state_at_server,
            result=computation_types.at_server(value_type),
            measurements=measurements_at_server)
        next_signature = computation_types.FunctionType(
            parameter=next_parameter, result=next_result)
        self.assertTrue(
            process.next.type_signature.is_equivalent_to(next_signature))
コード例 #3
0
    def test_raises_with_closure(self):
        """Expects a RuntimeError when the mapped function captures an outer variable."""
        tf_executor = eager_tf_executor.EagerTFExecutor()
        placement_map = {placement_literals.SERVER: tf_executor}
        strategy_factory = (
            federated_resolving_strategy.FederatedResolvingStrategy.factory(
                placement_map))
        ex = reference_resolving_executor.ReferenceResolvingExecutor(
            federating_executor.FederatingExecutor(strategy_factory,
                                                   tf_executor))
        loop = asyncio.get_event_loop()

        @computations.federated_computation(
            tf.int32, computation_types.at_server(tf.int32))
        def foo(x, y):

            # `bar` ignores its own argument and returns `x` from the
            # enclosing scope, i.e. it is a closure over `foo`'s parameter.
            @computations.federated_computation(tf.int32)
            def bar(z):
                del z
                return x

            return intrinsics.federated_map(bar, y)

        fn_value = loop.run_until_complete(ex.create_value(foo))
        arg_struct = structure.Struct([('x', 0), ('y', 0)])
        arg_type = [tf.int32, computation_types.at_server(tf.int32)]
        arg_value = loop.run_until_complete(
            ex.create_value(arg_struct, arg_type))
        expected_message = (
            'lambda passed to intrinsic contains references to captured variables'
        )
        with self.assertRaisesRegex(RuntimeError, expected_message):
            loop.run_until_complete(ex.create_call(fn_value, arg_value))
コード例 #4
0
    def test_zero_type_properties_weighted(self, value_type, weight_type):
        """Checks `initialize`/`next` signatures of the zeroed mean with weights."""
        zeroed_mean_factory = _zeroed_mean()
        value_type = computation_types.to_type(value_type)
        weight_type = computation_types.to_type(weight_type)
        process = zeroed_mean_factory.create(value_type, weight_type)
        self.assertIsInstance(process, aggregation_process.AggregationProcess)

        inner_mean_state = collections.OrderedDict([
            ('value_sum_process', ()),
            ('weight_sum_process', ()),
        ])
        state_at_server = computation_types.at_server(
            collections.OrderedDict([
                ('zeroing_norm', ()),
                ('inner_agg', inner_mean_state),
                ('zeroed_count_agg', ()),
            ]))
        init_signature = computation_types.FunctionType(
            parameter=None, result=state_at_server)
        self.assertTrue(
            process.initialize.type_signature.is_equivalent_to(init_signature))

        inner_mean_measurements = collections.OrderedDict([
            ('mean_value', ()),
            ('mean_weight', ()),
        ])
        measurements_at_server = computation_types.at_server(
            collections.OrderedDict([
                ('zeroing', inner_mean_measurements),
                ('zeroing_norm', robust_factory.NORM_TF_TYPE),
                ('zeroed_count', robust_factory.COUNT_TF_TYPE),
            ]))
        next_parameter = collections.OrderedDict([
            ('state', state_at_server),
            ('value', computation_types.at_clients(value_type)),
            ('weight', computation_types.at_clients(weight_type)),
        ])
        next_result = measured_process.MeasuredProcessOutput(
            state=state_at_server,
            result=computation_types.at_server(value_type),
            measurements=measurements_at_server)
        next_signature = computation_types.FunctionType(
            parameter=next_parameter, result=next_result)
        self.assertTrue(
            process.next.type_signature.is_equivalent_to(next_signature))
コード例 #5
0
    def test_type_properties_constant_bounds(self, value_type, upper_bound,
                                             lower_bound, measurements_dtype):
        """Checks process signatures of a secure sum built with constant bounds."""
        sum_factory = secure_factory.SecureSumFactory(
            upper_bound_threshold=upper_bound,
            lower_bound_threshold=lower_bound)
        self.assertIsInstance(sum_factory,
                              factory.UnweightedAggregationFactory)
        value_type = computation_types.to_type(value_type)
        process = sum_factory.create(value_type)
        self.assertIsInstance(process, aggregation_process.AggregationProcess)

        # The expected state is an empty struct placed at the server.
        empty_struct_type = computation_types.to_type(())
        state_at_server = computation_types.at_server(empty_struct_type)
        measurements_at_server = _measurements_type(measurements_dtype)

        init_signature = computation_types.FunctionType(
            parameter=None, result=state_at_server)
        self.assertTrue(
            process.initialize.type_signature.is_equivalent_to(init_signature))

        next_parameter = collections.OrderedDict([
            ('state', state_at_server),
            ('value', computation_types.at_clients(value_type)),
        ])
        next_result = measured_process.MeasuredProcessOutput(
            state=state_at_server,
            result=computation_types.at_server(value_type),
            measurements=measurements_at_server)
        next_signature = computation_types.FunctionType(
            parameter=next_parameter, result=next_result)
        self.assertTrue(
            process.next.type_signature.is_equivalent_to(next_signature))
コード例 #6
0
  def test_clip_type_properties_simple(self, value_type):
    """Checks `initialize`/`next` signatures of the unweighted clipped sum."""
    clipped_sum_factory = _clipped_sum()
    value_type = computation_types.to_type(value_type)
    process = clipped_sum_factory.create_unweighted(value_type)
    self.assertIsInstance(process, aggregation_process.AggregationProcess)

    state_at_server = computation_types.at_server(
        collections.OrderedDict([
            ('clipping_norm', ()),
            ('inner_agg', ()),
            ('clipped_count_agg', ()),
        ]))
    init_signature = computation_types.FunctionType(
        parameter=None, result=state_at_server)
    self.assertTrue(
        process.initialize.type_signature.is_equivalent_to(init_signature))

    measurements_at_server = computation_types.at_server(
        collections.OrderedDict([
            ('agg_process', ()),
            ('clipping_norm', clipping_factory.NORM_TF_TYPE),
            ('clipped_count', clipping_factory.COUNT_TF_TYPE),
        ]))
    next_parameter = collections.OrderedDict([
        ('state', state_at_server),
        ('value', computation_types.at_clients(value_type)),
    ])
    next_result = measured_process.MeasuredProcessOutput(
        state=state_at_server,
        result=computation_types.at_server(value_type),
        measurements=measurements_at_server)
    next_signature = computation_types.FunctionType(
        parameter=next_parameter, result=next_result)
    self.assertTrue(
        process.next.type_signature.is_equivalent_to(next_signature))
コード例 #7
0
    async def compute_federated_aggregate(
        self, arg: FederatedComposingStrategyValue
    ) -> FederatedComposingStrategyValue:
        """Computes a federated aggregate across the target executors.

        Each target executor runs a `FEDERATED_AGGREGATE` intrinsic over its
        own part of the value, using an identity function in place of `report`.
        The per-child results are then embedded in the server executor, folded
        together with `merge`, and finally passed through `report`.

        Args:
          arg: A value whose internal representation is a 5-element
            `structure.Struct`. Element 0 is a list with one entry per target
            executor; elements 1-4 are the zero, accumulate, merge and report
            arguments respectively.

        Returns:
          A `FederatedComposingStrategyValue` wrapping the result of the
          `report` call, typed as `report_type.result` placed at the server.
        """
        value_type, zero_type, accumulate_type, merge_type, report_type = (
            executor_utils.parse_federated_aggregate_argument_types(
                arg.type_signature))
        py_typecheck.check_type(arg.internal_representation, structure.Struct)
        py_typecheck.check_len(arg.internal_representation, 5)
        val = arg.internal_representation[0]
        py_typecheck.check_type(val, list)
        # One value per target executor is required.
        py_typecheck.check_len(val, len(self._target_executors))
        # Identity over `zero_type`, used as the child-level `report` so each
        # child's aggregate comes back with type `zero_type` (presumably so the
        # partial results can still be merged below).
        identity_report, identity_report_type = tensorflow_computation_factory.create_identity(
            zero_type)
        aggr_type = computation_types.FunctionType(
            computation_types.StructType([
                value_type, zero_type, accumulate_type, merge_type,
                identity_report_type
            ]), computation_types.at_server(zero_type))
        aggr_comp = executor_utils.create_intrinsic_comp(
            intrinsic_defs.FEDERATED_AGGREGATE, aggr_type)
        # `zero` is materialized here so it can be re-embedded in every child.
        zero = await (await
                      self._executor.create_selection(arg, index=1)).compute()
        accumulate = arg.internal_representation[2]
        merge = arg.internal_representation[3]
        report = arg.internal_representation[4]

        async def _child_fn(ex, v):
            # Runs the aggregate intrinsic on child executor `ex` over value
            # `v`, then embeds the computed result in the server executor.
            py_typecheck.check_type(v, executor_value_base.ExecutorValue)
            arg_values = [
                ex.create_value(zero, zero_type),
                ex.create_value(accumulate, accumulate_type),
                ex.create_value(merge, merge_type),
                ex.create_value(identity_report, identity_report_type)
            ]
            aggr_func, aggr_args = await asyncio.gather(
                ex.create_value(aggr_comp, aggr_type),
                ex.create_struct([v] +
                                 list(await asyncio.gather(*arg_values))))
            child_result = await (await ex.create_call(aggr_func,
                                                       aggr_args)).compute()
            result_at_server = await self._server_executor.create_value(
                child_result, zero_type)
            return result_at_server

        # `as_completed` yields awaitables in completion order, so child
        # results are merged as they become available.
        val_futures = asyncio.as_completed(
            [_child_fn(c, v) for c, v in zip(self._target_executors, val)])
        parent_merge, parent_report = await asyncio.gather(
            self._server_executor.create_value(merge, merge_type),
            self._server_executor.create_value(report, report_type))
        # The first completed child result seeds the running merge.
        merge_result = await next(val_futures)
        for next_val_future in val_futures:
            next_val = await next_val_future
            merge_arg = await self._server_executor.create_struct(
                [merge_result, next_val])
            merge_result = await self._server_executor.create_call(
                parent_merge, merge_arg)
        # The caller-supplied `report` is applied once, to the merged result.
        report_result = await self._server_executor.create_call(
            parent_report, merge_result)
        return FederatedComposingStrategyValue(
            report_result, computation_types.at_server(report_type.result))
コード例 #8
0
def create_dummy_intrinsic_def_federated_zip_at_server():
    """Returns `FEDERATED_ZIP_AT_SERVER` paired with a float32 type signature.

    The signature takes two server-placed float32 values and yields a
    server-placed pair of float32.
    """
    parameter_type = [
        computation_types.at_server(tf.float32) for _ in range(2)
    ]
    result_type = computation_types.at_server([tf.float32, tf.float32])
    signature = computation_types.FunctionType(parameter_type, result_type)
    return intrinsic_defs.FEDERATED_ZIP_AT_SERVER, signature
コード例 #9
0
    def test_with_federated_map_and_broadcast(self):
        """Broadcasts a server int32 to three clients and maps `add_one` over it."""
        tf_executor = eager_tf_executor.EagerTFExecutor()
        placement_map = {
            placement_literals.SERVER: tf_executor,
            placement_literals.CLIENTS: [tf_executor] * 3,
        }
        strategy_factory = (
            federated_resolving_strategy.FederatedResolvingStrategy.factory(
                placement_map))
        ex = reference_resolving_executor.ReferenceResolvingExecutor(
            federating_executor.FederatingExecutor(strategy_factory,
                                                   tf_executor))
        run = asyncio.get_event_loop().run_until_complete

        @computations.tf_computation(tf.int32)
        def add_one(x):
            return x + 1

        @computations.federated_computation(
            computation_types.at_server(tf.int32))
        def comp(x):
            broadcast_x = intrinsics.federated_broadcast(x)
            return intrinsics.federated_map(add_one, broadcast_x)

        fn_value = run(ex.create_value(comp))
        arg_value = run(
            ex.create_value(10, computation_types.at_server(tf.int32)))
        call_value = run(ex.create_call(fn_value, arg_value))
        per_client = run(call_value.compute())
        client_results = [member.numpy() for member in per_client]
        self.assertCountEqual(client_results, [11, 11, 11])
コード例 #10
0
  def test_type_properties_unweighted(self, value_type):
    """Checks process signatures produced by `UnweightedMeanFactory`."""
    value_type = computation_types.to_type(value_type)

    mean_f = mean.UnweightedMeanFactory()
    self.assertIsInstance(mean_f, factory.UnweightedAggregationFactory)
    process = mean_f.create(value_type)

    self.assertIsInstance(process, aggregation_process.AggregationProcess)

    clients_value = computation_types.at_clients(value_type)
    server_value = computation_types.at_server(value_type)

    state_at_server = computation_types.at_server(())
    measurements_at_server = computation_types.at_server(
        collections.OrderedDict([('mean_value', ())]))

    init_signature = computation_types.FunctionType(
        parameter=None, result=state_at_server)
    self.assertTrue(
        process.initialize.type_signature.is_equivalent_to(init_signature))

    next_parameter = collections.OrderedDict([
        ('state', state_at_server),
        ('value', clients_value),
    ])
    next_result = measured_process.MeasuredProcessOutput(
        state_at_server, server_value, measurements_at_server)
    next_signature = computation_types.FunctionType(
        parameter=next_parameter, result=next_result)
    self.assertTrue(
        process.next.type_signature.is_equivalent_to(next_signature))
コード例 #11
0
    def test_type_properties(self, value_type, weight_type):
        """Checks process signatures produced by `MeanFactory.create_weighted`."""
        weighted_mean = mean_factory.MeanFactory()
        self.assertIsInstance(weighted_mean,
                              factory.WeightedAggregationFactory)
        value_type = computation_types.to_type(value_type)
        weight_type = computation_types.to_type(weight_type)
        process = weighted_mean.create_weighted(value_type, weight_type)
        self.assertIsInstance(process, aggregation_process.AggregationProcess)

        clients_value = computation_types.FederatedType(
            value_type, placements.CLIENTS)
        server_value = computation_types.FederatedType(
            value_type, placements.SERVER)
        state_at_server = computation_types.at_server(
            collections.OrderedDict([
                ('value_sum_process', ()),
                ('weight_sum_process', ()),
            ]))
        measurements_at_server = computation_types.at_server(
            collections.OrderedDict([
                ('value_sum_process', ()),
                ('weight_sum_process', ()),
            ]))

        init_signature = computation_types.FunctionType(
            parameter=None, result=state_at_server)
        self.assertTrue(
            process.initialize.type_signature.is_equivalent_to(init_signature))

        next_parameter = collections.OrderedDict([
            ('state', state_at_server),
            ('value', clients_value),
            ('weight', computation_types.at_clients(weight_type)),
        ])
        next_result = measured_process.MeasuredProcessOutput(
            state_at_server, server_value, measurements_at_server)
        next_signature = computation_types.FunctionType(
            parameter=next_parameter, result=next_result)
        self.assertTrue(
            process.next.type_signature.is_equivalent_to(next_signature))
コード例 #12
0
def create_dummy_intrinsic_def_federated_apply():
    """Returns `FEDERATED_APPLY` paired with a float32 type signature.

    The signature takes a unary float32 op plus a server-placed float32 and
    yields a server-placed float32.
    """
    parameter_type = [
        type_factory.unary_op(tf.float32),
        computation_types.at_server(tf.float32),
    ]
    result_type = computation_types.at_server(tf.float32)
    signature = computation_types.FunctionType(parameter_type, result_type)
    return intrinsic_defs.FEDERATED_APPLY, signature
コード例 #13
0
 def test_serialize_deserialize_federated_at_server(self):
   """Round-trips a server-placed int32 value through (de)serialization."""
   original = 10
   server_int_type = computation_types.at_server(tf.int32)
   serialized = executor_serialization.serialize_value(original,
                                                       server_int_type)
   value_proto, value_type = serialized
   self.assertIsInstance(value_proto, executor_pb2.Value)
   self.assert_types_identical(value_type,
                               computation_types.at_server(tf.int32))
   restored, restored_type = executor_serialization.deserialize_value(
       value_proto)
   self.assert_types_identical(restored_type, server_int_type)
   self.assertEqual(restored, 10)
コード例 #14
0
def _build_expected_test_quant_model_eval_signature():
    """Returns signature for build_federated_evaluation using TestModelQuant."""
    model_weights_type = model_utils.weights_type_from_model(TestModelQuant)
    weights_at_server = computation_types.at_server(model_weights_type)

    # Client data: sequences of structs with a single float32 `temp` tensor
    # whose one dimension is unspecified.
    temp_tensor_type = computation_types.TensorType(
        shape=(None, ), dtype=tf.float32)
    element_type = collections.OrderedDict([('temp', temp_tensor_type)])
    dataset_at_clients = computation_types.at_clients(
        computation_types.SequenceType(element_type))

    metrics_type = collections.OrderedDict([
        ('num_same', computation_types.at_server(tf.float32)),
    ])
    parameter_type = collections.OrderedDict([
        ('server_model_weights', weights_at_server),
        ('federated_dataset', dataset_at_clients),
    ])
    return computation_types.FunctionType(
        parameter=parameter_type, result=metrics_type)
コード例 #15
0
def _temperature_sensor_example_next_fn():
    """Builds the temperature-sensor example federated computation.

    The returned computation takes client-placed float32 sequences and a
    server-placed float32 threshold. It calls `federated_mean` on the
    per-client count of readings greater than the threshold, passing the
    per-client total reading count as the second argument.
    """

    @computations.tf_computation(computation_types.SequenceType(tf.float32),
                                 tf.float32)
    def count_over(ds, t):

        def _add_if_over(n, x):
            return n + tf.cast(tf.greater(x, t), tf.float32)

        return ds.reduce(np.float32(0), _add_if_over)

    @computations.tf_computation(computation_types.SequenceType(tf.float32))
    def count_total(ds):

        def _add_one(n, _):
            return n + 1.0

        return ds.reduce(np.float32(0.0), _add_one)

    @computations.federated_computation(
        computation_types.at_clients(
            computation_types.SequenceType(tf.float32)),
        computation_types.at_server(tf.float32))
    def comp(temperatures, threshold):
        client_threshold = intrinsics.federated_broadcast(threshold)
        readings_and_threshold = intrinsics.federated_zip(
            [temperatures, client_threshold])
        over_counts = intrinsics.federated_map(count_over,
                                               readings_and_threshold)
        total_counts = intrinsics.federated_map(count_total, temperatures)
        return intrinsics.federated_mean(over_counts, total_counts)

    return comp
コード例 #16
0
def _measurements_type(bound_type):
    """Returns the server-placed measurements struct type.

    Args:
      bound_type: Type used for both the upper and lower threshold fields.
    """
    fields = [
        ('secure_upper_clipped_count', secure_factory.COUNT_TF_TYPE),
        ('secure_lower_clipped_count', secure_factory.COUNT_TF_TYPE),
        ('secure_upper_threshold', bound_type),
        ('secure_lower_threshold', bound_type),
    ]
    return computation_types.at_server(collections.OrderedDict(fields))
コード例 #17
0
    async def compute_federated_mean(
        self, arg: FederatedComposingStrategyValue
    ) -> FederatedComposingStrategyValue:
        """Computes the federated mean as `sum * (1 / count)` on the server.

        The total comes from `compute_federated_sum`, and the count is the sum
        of the values returned by `self._get_cardinalities()`.

        Args:
          arg: A value whose type signature is checked to be a federated type
            placed at `CLIENTS`.

        Returns:
          A `FederatedComposingStrategyValue` holding the product, typed as
          the member type placed at the server.
        """
        type_analysis.check_federated_type(
            arg.type_signature, placement=placement_literals.CLIENTS)
        member_type = arg.type_signature.member

        async def _create_total():
            # Sum across clients, materialize the result, then re-embed it in
            # the server executor.
            total = await self.compute_federated_sum(arg)
            total = await total.compute()
            return await self._server_executor.create_value(total, member_type)

        async def _create_factor():
            # Embeds the constant 1 / count as a value of `member_type`.
            # NOTE(review): a total cardinality of 0 would raise
            # ZeroDivisionError here — confirm callers rule that out.
            cardinalities = await self._get_cardinalities()
            count = sum(cardinalities)
            return await executor_utils.embed_tf_constant(
                self._server_executor, member_type, float(1.0 / count))

        async def _create_multiply_arg():
            # Builds the [total, factor] argument struct for the multiply op.
            total, factor = await asyncio.gather(_create_total(),
                                                 _create_factor())
            return await self._server_executor.create_struct([total, factor])

        # The multiply operator and its argument are prepared concurrently.
        multiply_fn, multiply_arg = await asyncio.gather(
            executor_utils.embed_tf_binary_operator(self._server_executor,
                                                    member_type, tf.multiply),
            _create_multiply_arg())
        result = await self._server_executor.create_call(
            multiply_fn, multiply_arg)
        type_signature = computation_types.at_server(member_type)
        return FederatedComposingStrategyValue(result, type_signature)
コード例 #18
0
def create_dummy_intrinsic_def_federated_weighted_mean():
    """Returns `FEDERATED_WEIGHTED_MEAN` paired with a float32 type signature.

    The signature takes two client-placed float32 values and yields a
    server-placed float32.
    """
    parameter_type = [
        computation_types.at_clients(tf.float32) for _ in range(2)
    ]
    result_type = computation_types.at_server(tf.float32)
    signature = computation_types.FunctionType(parameter_type, result_type)
    return intrinsic_defs.FEDERATED_WEIGHTED_MEAN, signature
コード例 #19
0
    def test_allows_assignable_but_not_equal_zero_and_reduction_types(self):
        """Reduces with a `string[1]` zero into a `string[?]` accumulator."""
        intrinsic_fac = intrinsic_factory.IntrinsicFactory(
            context_stack_impl.context_stack)

        element_type = tf.string
        # The zero has a fixed length of 1, while the accumulator's single
        # dimension is unspecified.
        zero_type = computation_types.TensorType(tf.string, [1])
        reduced_type = computation_types.TensorType(tf.string, [None])

        @computations.tf_computation(reduced_type, element_type)
        @computations.check_returns_type(reduced_type)
        def append(accumulator, element):
            pieces = [accumulator, [element]]
            return tf.concat(pieces, 0)

        @computations.tf_computation
        @computations.check_returns_type(zero_type)
        def zero():
            return tf.convert_to_tensor(['The beginning'])

        @computations.federated_computation(
            computation_types.at_clients(element_type))
        @computations.check_returns_type(
            computation_types.at_server(reduced_type))
        def collect(client_values):
            return intrinsic_fac.federated_reduce(client_values, zero(),
                                                  append)

        signature_text = collect.type_signature.compact_representation()
        self.assertEqual(signature_text,
                         '({string}@CLIENTS -> string[?]@SERVER)')
コード例 #20
0
ファイル: sampling.py プロジェクト: pribanacek/federated
    def create(
        self, value_type: factory.ValueType
    ) -> aggregation_process.AggregationProcess:
        """Builds a stateless aggregation process around `federated_aggregate`.

        Args:
          value_type: Type of the client-placed values fed to `next`.

        Returns:
          An `AggregationProcess` whose state and measurements are both an
          empty tuple placed at the server, and whose result is the output of
          `federated_aggregate` over the client values.
        """

        @computations.federated_computation()
        def init_fn():
            # No state is carried between invocations, so it is just `()`.
            return intrinsics.federated_value((), placements.SERVER)

        @computations.federated_computation(
            computation_types.at_server(()),
            computation_types.at_clients(value_type))
        def next_fn(unused_state, value):
            # TFF uses the empty tuple where Python would use `None`.
            empty_tuple = intrinsics.federated_value((), placements.SERVER)
            samples = intrinsics.federated_aggregate(
                value,
                zero=_build_initial_sample_reservoir(value_type),
                accumulate=_build_sample_value_computation(
                    value_type, self._sample_size),
                merge=_build_merge_samples_computation(
                    value_type, self._sample_size),
                report=_build_finalize_sample_computation(value_type))
            return measured_process.MeasuredProcessOutput(
                state=empty_tuple, result=samples, measurements=empty_tuple)

        return aggregation_process.AggregationProcess(init_fn, next_fn)
コード例 #21
0
def create_dummy_intrinsic_def_federated_collect():
    """Returns `FEDERATED_COLLECT` paired with a float32 type signature.

    The signature takes a client-placed float32 and yields a server-placed
    sequence of float32.
    """
    parameter_type = computation_types.at_clients(tf.float32)
    result_type = computation_types.at_server(
        computation_types.SequenceType(tf.float32))
    signature = computation_types.FunctionType(parameter_type, result_type)
    return intrinsic_defs.FEDERATED_COLLECT, signature
コード例 #22
0
class CreateIdentityTest(parameterized.TestCase):
    """Tests for `tensorflow_computation_factory.create_identity`."""

    # pyformat: disable
    @parameterized.named_parameters(
        ('int', computation_types.TensorType(tf.int32), 10),
        ('unnamed_tuple',
         computation_types.StructType([tf.int32, tf.float32]),
         structure.Struct([(None, 10), (None, 10.0)])),
        ('named_tuple',
         computation_types.StructType([('a', tf.int32), ('b', tf.float32)]),
         structure.Struct([('a', 10), ('b', 10.0)])),
        ('sequence', computation_types.SequenceType(tf.int32), [10, 10, 10]),
    )
    # pyformat: enable
    def test_returns_computation(self, type_signature, value):
        """The proto has a unary-op type and returns its input unchanged."""
        proto, unused_type = tensorflow_computation_factory.create_identity(
            type_signature)

        self.assertIsInstance(proto, pb.Computation)
        deserialized_type = type_serialization.deserialize_type(proto.type)
        self.assertEqual(deserialized_type,
                         type_factory.unary_op(type_signature))
        self.assertEqual(test_utils.run_tensorflow(proto, value), value)

    @parameterized.named_parameters(
        ('none', None),
        ('federated_type', computation_types.at_server(tf.int32)),
    )
    def test_raises_type_error(self, type_signature):
        """`None` and federated types are rejected with `TypeError`."""
        create_identity = tensorflow_computation_factory.create_identity
        with self.assertRaises(TypeError):
            create_identity(type_signature)
コード例 #23
0
def create_dummy_intrinsic_def_federated_secure_sum():
    """Returns `FEDERATED_SECURE_SUM` paired with an int32 type signature.

    The signature takes a client-placed int32 plus an unplaced int32 and
    yields a server-placed int32.
    """
    parameter_type = [computation_types.at_clients(tf.int32), tf.int32]
    result_type = computation_types.at_server(tf.int32)
    signature = computation_types.FunctionType(parameter_type, result_type)
    return intrinsic_defs.FEDERATED_SECURE_SUM, signature
コード例 #24
0
class CreateReplicateInputTest(parameterized.TestCase):
    """Tests for `tensorflow_computation_factory.create_replicate_input`."""

    @parameterized.named_parameters(
        ('int', computation_types.TensorType(tf.int32), 3, 10),
        ('float', computation_types.TensorType(tf.float32), 3, 10.0),
        ('unnamed_tuple',
         computation_types.StructType([tf.int32, tf.float32]), 3,
         structure.Struct([(None, 10), (None, 10.0)])),
        ('named_tuple',
         computation_types.StructType([('a', tf.int32), ('b', tf.float32)]),
         3, structure.Struct([('a', 10), ('b', 10.0)])),
        ('sequence', computation_types.SequenceType(tf.int32), 3,
         [10, 10, 10]),
    )
    def test_returns_computation(self, type_signature, count, value):
        """The proto maps its input to `count` unnamed copies of that input."""
        replicated = tensorflow_computation_factory.create_replicate_input(
            type_signature, count)
        proto, unused_type = replicated

        self.assertIsInstance(proto, pb.Computation)
        deserialized_type = type_serialization.deserialize_type(proto.type)
        result_type = [type_signature] * count
        expected_type = computation_types.FunctionType(type_signature,
                                                       result_type)
        expected_type.check_assignable_from(deserialized_type)
        expected_result = structure.Struct(
            [(None, value) for _ in range(count)])
        self.assertEqual(test_utils.run_tensorflow(proto, value),
                         expected_result)

    @parameterized.named_parameters(
        ('none_type', None, 3),
        ('none_count', computation_types.TensorType(tf.int32), None),
        ('federated_type', computation_types.at_server(tf.int32), 3),
    )
    def test_raises_type_error(self, type_signature, count):
        """A `None` type, `None` count or federated type raises `TypeError`."""
        create_replicate = tensorflow_computation_factory.create_replicate_input
        with self.assertRaises(TypeError):
            create_replicate(type_signature, count)
コード例 #25
0
 def test_federated_zip_with_single_named_bool_server(self):
     """Zips a one-field named struct of a server-placed bool."""
     server_bool = computation_types.at_server(tf.bool)
     struct_type = computation_types.StructType([('a', server_bool)])
     zipped = intrinsics.federated_zip(_mock_data_of_type(struct_type))
     self.assert_value(zipped, '<a=bool>@SERVER')
コード例 #26
0
def create_intrinsic_def_federated_secure_sum(value_type, bitwidth_type):
    """Returns `FEDERATED_SECURE_SUM` with a signature built from the arguments.

    Args:
      value_type: Member type of the client-placed input and of the
        server-placed result.
      bitwidth_type: Type of the second (unplaced) parameter.

    Returns:
      A `(intrinsic_def, FunctionType)` pair.
    """
    parameter_type = [computation_types.at_clients(value_type), bitwidth_type]
    result_type = computation_types.at_server(value_type)
    signature = computation_types.FunctionType(parameter_type, result_type)
    return intrinsic_defs.FEDERATED_SECURE_SUM, signature
コード例 #27
0
def _build_expected_broadcaster_next_signature():
    """Returns signature of the broadcaster used in multiple tests below."""
    state_struct = computation_types.StructType([
        ('trainable', [()]),
        ('non_trainable', []),
    ])
    server_state = computation_types.at_server(state_struct)
    # Input weights live at the server; the broadcast result at the clients.
    server_weights = computation_types.at_server(
        model_utils.weights_type_from_model(TestModelQuant))
    client_weights = computation_types.at_clients(
        model_utils.weights_type_from_model(TestModelQuant))
    server_measurements = computation_types.at_server(())

    parameter_type = collections.OrderedDict([
        ('state', server_state),
        ('value', server_weights),
    ])
    result_type = collections.OrderedDict([
        ('state', server_state),
        ('result', client_weights),
        ('measurements', server_measurements),
    ])
    return computation_types.FunctionType(
        parameter=parameter_type, result=result_type)
コード例 #28
0
 async def _compute_apply_fn():
     """Embeds a `FEDERATED_APPLY` intrinsic in `executor` and returns the value.

     The intrinsic takes `divide_blk`'s function type together with
     `zip2_type.result`, and yields `divide_blk`'s result type at the server.
     """
     divide_type = divide_blk.type_signature
     parameter_type = computation_types.StructType(
         [divide_type, zip2_type.result])
     result_type = computation_types.at_server(divide_type.result)
     apply_type = computation_types.FunctionType(parameter_type, result_type)
     intrinsic = create_intrinsic_comp(intrinsic_defs.FEDERATED_APPLY,
                                       apply_type)
     return await executor.create_value(intrinsic, apply_type)
コード例 #29
0
def create_dummy_intrinsic_def_federated_reduce():
    """Returns `FEDERATED_REDUCE` paired with a float32 type signature.

    The signature takes a client-placed float32, an unplaced float32 and a
    float32 reduction op, and yields a server-placed float32.
    """
    parameter_type = [
        computation_types.at_clients(tf.float32),
        tf.float32,
        type_factory.reduction_op(tf.float32, tf.float32),
    ]
    result_type = computation_types.at_server(tf.float32)
    signature = computation_types.FunctionType(parameter_type, result_type)
    return intrinsic_defs.FEDERATED_REDUCE, signature
コード例 #30
0
    def test_returns_value_with_federated_type_at_server(self):
        """Computing a server-placed strategy value yields the wrapped float."""
        eager_values = [eager_tf_executor.EagerValue(10.0, None, tf.float32)]
        server_float_type = computation_types.at_server(tf.float32)
        strategy_value = (
            federated_resolving_strategy.FederatedResolvingStrategyValue(
                eager_values, server_float_type))

        computed = self.run_sync(strategy_value.compute())

        self.assertEqual(computed, 10.0)