def test_type_properties(self, name, value_type):
  """Verifies the type signatures of a rotation-based sum process.

  Args:
    name: 'hd' selects `_hadamard_sum()`, any other value selects
      `_dft_sum()`. It is also the expected measurement key.
    value_type: A spec convertible with `computation_types.to_type`.
  """
  if name == 'hd':
    agg_factory = _hadamard_sum()
  else:
    agg_factory = _dft_sum()
  tff_value_type = computation_types.to_type(value_type)
  process = agg_factory.create(tff_value_type)
  self.assertIsInstance(process, aggregation_process.AggregationProcess)

  # Expected server state: a pair of `()` and the rotation seed type.
  state_at_server = computation_types.at_server(((), rotation.SEED_TFF_TYPE))
  type_test_utils.assert_types_equivalent(
      process.initialize.type_signature,
      computation_types.FunctionType(parameter=None, result=state_at_server))

  measurements_at_server = computation_types.at_server(
      collections.OrderedDict([(name, ())]))
  next_signature = computation_types.FunctionType(
      parameter=collections.OrderedDict(
          state=state_at_server,
          value=computation_types.at_clients(tff_value_type)),
      result=measured_process.MeasuredProcessOutput(
          state=state_at_server,
          result=computation_types.at_server(tff_value_type),
          measurements=measurements_at_server))
  type_test_utils.assert_types_equivalent(process.next.type_signature,
                                          next_signature)
def test_clip(self, clip_mechanism):
  """Checks the process built by HistogramClippingSumFactory over int32[2]."""
  hist_factory = clipping_factory.HistogramClippingSumFactory(
      clip_mechanism, 1)
  self.assertIsInstance(hist_factory, factory.UnweightedAggregationFactory)

  int_vector_type = computation_types.to_type((tf.int32, (2,)))
  process = hist_factory.create(int_vector_type)
  self.assertIsInstance(process, aggregation_process.AggregationProcess)

  # Both the state and the measurements are expected to be empty.
  server_empty = computation_types.at_server(())
  init_signature = computation_types.FunctionType(
      parameter=None, result=server_empty)
  self.assertTrue(
      process.initialize.type_signature.is_equivalent_to(init_signature))

  clients_value = computation_types.FederatedType(int_vector_type,
                                                  placements.CLIENTS)
  server_value = computation_types.FederatedType(int_vector_type,
                                                 placements.SERVER)
  next_signature = computation_types.FunctionType(
      parameter=collections.OrderedDict(
          state=server_empty, value=clients_value),
      result=measured_process.MeasuredProcessOutput(server_empty,
                                                    server_value,
                                                    server_empty))
  self.assertTrue(process.next.type_signature.is_equivalent_to(next_signature))
def test_type_properties(self, value_type):
  """Checks the type signatures of the discretization sum process."""
  discretize_factory = _discretization_sum()
  tff_value_type = computation_types.to_type(value_type)
  process = discretize_factory.create(tff_value_type)
  self.assertIsInstance(process, aggregation_process.AggregationProcess)

  state_fields = collections.OrderedDict(
      scale_factor=tf.float32,
      prior_norm_bound=tf.float32,
      inner_agg_process=())
  state_at_server = computation_types.at_server(state_fields)
  type_test_utils.assert_types_equivalent(
      process.initialize.type_signature,
      computation_types.FunctionType(parameter=None, result=state_at_server))

  measurements_at_server = computation_types.at_server(
      collections.OrderedDict(discretize=()))
  clients_value = computation_types.at_clients(tff_value_type)
  server_value = computation_types.at_server(tff_value_type)
  next_signature = computation_types.FunctionType(
      parameter=collections.OrderedDict(
          state=state_at_server, value=clients_value),
      result=measured_process.MeasuredProcessOutput(
          state=state_at_server,
          result=server_value,
          measurements=measurements_at_server))
  type_test_utils.assert_types_equivalent(process.next.type_signature,
                                          next_signature)
def test_type_properties_simple(self):
  """Checks a ModularClippingSumFactory process over int32[2] in [-2, 2]."""
  int_pair_type = computation_types.to_type((tf.int32, (2,)))
  modclip_factory = modular_clipping_factory.ModularClippingSumFactory(
      clip_range_lower=-2, clip_range_upper=2)
  process = modclip_factory.create(int_pair_type)
  self.assertIsInstance(process, aggregation_process.AggregationProcess)

  # Inner SumFactory has no state.
  empty_state = computation_types.at_server(())
  init_signature = computation_types.FunctionType(
      parameter=None, result=empty_state)
  self.assertTrue(
      process.initialize.type_signature.is_equivalent_to(init_signature))

  measurement_fields = collections.OrderedDict(modclip=())
  next_signature = computation_types.FunctionType(
      parameter=collections.OrderedDict(
          state=empty_state,
          value=computation_types.at_clients(int_pair_type)),
      result=measured_process.MeasuredProcessOutput(
          state=empty_state,
          result=computation_types.at_server(int_pair_type),
          measurements=computation_types.at_server(measurement_fields)))
  self.assertTrue(process.next.type_signature.is_equivalent_to(next_signature))
def test_clip_type_properties_weighted(self, value_type, weight_type):
  """Checks the weighted process types produced by `_concat_mean()`."""
  concat_factory = _concat_mean()
  tff_value_type = computation_types.to_type(value_type)
  tff_weight_type = computation_types.to_type(weight_type)
  process = concat_factory.create(tff_value_type, tff_weight_type)
  self.assertIsInstance(process, aggregation_process.AggregationProcess)

  # State comes from the inner MeanFactory.
  state_at_server = computation_types.at_server(
      collections.OrderedDict(value_sum_process=(), weight_sum_process=()))
  type_test_utils.assert_types_equivalent(
      process.initialize.type_signature,
      computation_types.FunctionType(parameter=None, result=state_at_server))

  # Measurements come from the inner mean factory.
  measurements_at_server = computation_types.at_server(
      collections.OrderedDict(mean_value=(), mean_weight=()))
  next_parameter = collections.OrderedDict(
      state=state_at_server,
      value=computation_types.at_clients(tff_value_type),
      weight=computation_types.at_clients(tff_weight_type))
  next_result = measured_process.MeasuredProcessOutput(
      state=state_at_server,
      result=computation_types.at_server(tff_value_type),
      measurements=measurements_at_server)
  type_test_utils.assert_types_equivalent(
      process.next.type_signature,
      computation_types.FunctionType(
          parameter=next_parameter, result=next_result))
def test_raises_with_closure(self):
  """Checks that mapping a lambda which captures an outer value is rejected."""
  tf_executor = eager_tf_executor.EagerTFExecutor()
  strategy_factory = (
      federated_resolving_strategy.FederatedResolvingStrategy.factory(
          {placements.SERVER: tf_executor}))
  executor = reference_resolving_executor.ReferenceResolvingExecutor(
      federating_executor.FederatingExecutor(strategy_factory, tf_executor))

  @federated_computation.federated_computation(
      tf.int32, computation_types.at_server(tf.int32))
  def foo(x, y):

    @federated_computation.federated_computation(tf.int32)
    def bar(z):
      del z
      return x  # Refers to `x` from the enclosing computation.

    return intrinsics.federated_map(bar, y)

  comp_value = asyncio.run(executor.create_value(foo))
  arg_value = asyncio.run(
      executor.create_value(
          structure.Struct([('x', 0), ('y', 0)]),
          [tf.int32, computation_types.at_server(tf.int32)]))
  expected_message = (
      'lambda passed to intrinsic contains references to captured variables')
  with self.assertRaisesRegex(RuntimeError, expected_message):
    asyncio.run(executor.create_call(comp_value, arg_value))
def test_type_properties(self, modulus, value_type, symmetric_range):
  """Checks SecureModularSumFactory types and that `next` is secure-only.

  Args:
    modulus: Modulus passed to `secure.SecureModularSumFactory`.
    value_type: A spec convertible with `computation_types.to_type`.
    symmetric_range: Forwarded to `secure.SecureModularSumFactory`.
  """
  factory_ = secure.SecureModularSumFactory(
      modulus=modulus, symmetric_range=symmetric_range)
  self.assertIsInstance(factory_, factory.UnweightedAggregationFactory)
  value_type = computation_types.to_type(value_type)
  process = factory_.create(value_type)
  self.assertIsInstance(process, aggregation_process.AggregationProcess)

  expected_state_type = computation_types.at_server(
      computation_types.to_type(()))
  expected_measurements_type = expected_state_type

  expected_initialize_type = computation_types.FunctionType(
      parameter=None, result=expected_state_type)
  self.assertTrue(
      process.initialize.type_signature.is_equivalent_to(
          expected_initialize_type))

  expected_next_type = computation_types.FunctionType(
      parameter=collections.OrderedDict(
          state=expected_state_type,
          value=computation_types.at_clients(value_type)),
      result=measured_process.MeasuredProcessOutput(
          state=expected_state_type,
          result=computation_types.at_server(value_type),
          measurements=expected_measurements_type))
  self.assertTrue(
      process.next.type_signature.is_equivalent_to(expected_next_type))

  try:
    static_assert.assert_not_contains_unsecure_aggregation(process.next)
  # Catch `Exception` rather than using a bare `except:` so that
  # KeyboardInterrupt/SystemExit still propagate. Report the underlying error
  # so the failure is actionable.
  except Exception as e:  # pylint: disable=broad-except
    self.fail('Factory returned an AggregationProcess containing '
              f'non-secure aggregation: {e}')
def test_type_properties(self):
  """Checks the types of the apply-optimizer finalizer built from SGDM."""
  weights_type = computation_types.to_type(
      model_utils.ModelWeights(
          trainable=(tf.float32, tf.float32), non_trainable=tf.float32))
  finalizer = finalizers.build_apply_optimizer_finalizer(
      sgdm.build_sgdm(1.0), weights_type)
  self.assertIsInstance(finalizer, finalizers.FinalizerProcess)

  state_at_server = computation_types.at_server(())
  measurements_at_server = computation_types.at_server(())

  init_signature = computation_types.FunctionType(
      parameter=None, result=state_at_server)
  init_signature.check_equivalent_to(finalizer.initialize.type_signature)

  # `next` takes weights plus an update shaped like the trainable weights.
  next_parameter = collections.OrderedDict(
      state=state_at_server,
      weights=computation_types.at_server(weights_type),
      update=computation_types.at_server(weights_type.trainable))
  next_result = MeasuredProcessOutput(state_at_server,
                                      computation_types.at_server(weights_type),
                                      measurements_at_server)
  next_signature = computation_types.FunctionType(
      parameter=next_parameter, result=next_result)
  next_signature.check_equivalent_to(finalizer.next.type_signature)
def test_zero_type_properties_with_zeroed_count_agg_factory(
    self, value_type):
  """Checks zeroing_factory types when a custom zeroed-count factory is used."""
  zeroing = robust.zeroing_factory(
      zeroing_norm=1.0,
      inner_agg_factory=sum_factory.SumFactory(),
      norm_order=2.0,
      zeroed_count_sum_factory=aggregator_test_utils.SumPlusOneFactory())
  tff_value_type = computation_types.to_type(value_type)
  process = zeroing.create(tff_value_type)
  self.assertIsInstance(process, aggregation_process.AggregationProcess)

  # Expected: only `zeroed_count_agg` has non-empty (int32) state.
  state_at_server = computation_types.at_server(
      collections.OrderedDict(
          zeroing_norm=(), inner_agg=(), zeroed_count_agg=tf.int32))
  init_signature = computation_types.FunctionType(
      parameter=None, result=state_at_server)
  self.assertTrue(
      process.initialize.type_signature.is_equivalent_to(init_signature))

  measurements_at_server = computation_types.at_server(
      collections.OrderedDict(
          zeroing=(),
          zeroing_norm=robust.NORM_TF_TYPE,
          zeroed_count=robust.COUNT_TF_TYPE))
  next_signature = computation_types.FunctionType(
      parameter=collections.OrderedDict(
          state=state_at_server,
          value=computation_types.at_clients(tff_value_type)),
      result=measured_process.MeasuredProcessOutput(
          state=state_at_server,
          result=computation_types.at_server(tff_value_type),
          measurements=measurements_at_server))
  self.assertTrue(process.next.type_signature.is_equivalent_to(next_signature))
def test_with_federated_map_and_broadcast(self):
  """Broadcasts a server int32 to three clients and maps `+ 1` over it."""
  num_clients = 3
  tf_executor = eager_tf_executor.EagerTFExecutor()
  strategy_factory = (
      federated_resolving_strategy.FederatedResolvingStrategy.factory({
          placements.SERVER: tf_executor,
          # Every client shares the same eager executor instance.
          placements.CLIENTS: [tf_executor] * num_clients,
      }))
  executor = reference_resolving_executor.ReferenceResolvingExecutor(
      federating_executor.FederatingExecutor(strategy_factory, tf_executor))

  @tensorflow_computation.tf_computation(tf.int32)
  def add_one(x):
    return x + 1

  @federated_computation.federated_computation(
      computation_types.at_server(tf.int32))
  def comp(x):
    broadcast_value = intrinsics.federated_broadcast(x)
    return intrinsics.federated_map(add_one, broadcast_value)

  comp_value = asyncio.run(executor.create_value(comp))
  arg_value = asyncio.run(
      executor.create_value(10, computation_types.at_server(tf.int32)))
  call_value = asyncio.run(executor.create_call(comp_value, arg_value))
  outputs = asyncio.run(call_value.compute())
  self.assertCountEqual([item.numpy() for item in outputs],
                        [11] * num_clients)
def test_federated_evaluation(self):
  """Checks the signature and the output of federated evaluation on TestModel."""
  evaluate = federated_evaluation.build_federated_evaluation(TestModel)
  weights_type = model_utils.weights_type_from_model(TestModel)

  example_type = StructType(
      [('temp', TensorType(dtype=tf.float32, shape=[None]))])
  expected_parameter = StructType([
      ('server_model_weights', computation_types.at_server(weights_type)),
      ('federated_dataset',
       computation_types.at_clients(SequenceType(example_type))),
  ])
  expected_result = computation_types.at_server(
      collections.OrderedDict(
          eval=collections.OrderedDict(num_over=tf.float32)))
  type_test_utils.assert_types_equivalent(
      evaluate.type_signature,
      FunctionType(parameter=expected_parameter, result=expected_result))

  def _example(temps):
    return {'temp': np.array(temps, dtype=np.float32)}

  # Three clients, each holding a list of batches.
  client_data = [
      [_example([1.0, 10.0, 2.0, 7.0]), _example([6.0, 11.0])],
      [_example([9.0, 12.0, 13.0])],
      [_example([1.0]), _example([22.0, 23.0])],
  ]
  server_weights = collections.OrderedDict(trainable=[5.0], non_trainable=[])
  result = evaluate(server_weights, client_data)
  self.assertEqual(
      result,
      collections.OrderedDict(eval=collections.OrderedDict(num_over=9.0)))
def test_federated_map_injected_zip_with_server_int(self):
  """Maps over a Python list of two server-placed int32 mock values."""
  comp = _create_computation_greater_than_10_with_unused_parameter()
  # Two independent server-placed int32 mocks; per the test name, a zip is
  # expected to be injected around them.
  x, y = (
      _mock_data_of_type(computation_types.at_server(tf.int32))
      for _ in range(2))
  mapped = intrinsics.federated_map(comp, [x, y])
  self.assert_value(mapped, 'bool@SERVER')
def create_whimsy_intrinsic_def_federated_zip_at_server():
  """Returns FEDERATED_ZIP_AT_SERVER with a (float32@S, float32@S) signature.

  Returns:
    A `(intrinsic_def, type_signature)` tuple.
  """
  parameter_types = [
      computation_types.at_server(tf.float32) for _ in range(2)
  ]
  result_type = computation_types.at_server([tf.float32, tf.float32])
  return (intrinsic_defs.FEDERATED_ZIP_AT_SERVER,
          computation_types.FunctionType(parameter_types, result_type))
def test_type_properties(self, value_type):
  """Checks the type signatures of the stochastic discretization process."""
  discretization = stochastic_discretization.StochasticDiscretizationFactory(
      step_size=0.1,
      inner_agg_factory=_measurement_aggregator,
      distortion_aggregation_factory=mean.UnweightedMeanFactory())
  tff_value_type = computation_types.to_type(value_type)
  # Same structure and shapes as the value, but every tensor becomes int32.
  int_value_type = type_conversions.structure_from_tensor_type_tree(
      lambda tensor_type: (tf.int32, tensor_type.shape), tff_value_type)
  process = discretization.create(tff_value_type)
  self.assertIsInstance(process, aggregation_process.AggregationProcess)

  state_at_server = computation_types.at_server(
      computation_types.StructType([('step_size', tf.float32),
                                    ('inner_agg_process', ())]))
  type_test_utils.assert_types_equivalent(
      process.initialize.type_signature,
      computation_types.FunctionType(parameter=None, result=state_at_server))

  measurements_at_server = computation_types.at_server(
      computation_types.StructType([
          ('stochastic_discretization', int_value_type),
          ('distortion', tf.float32),
      ]))
  next_signature = computation_types.FunctionType(
      parameter=collections.OrderedDict(
          state=state_at_server,
          value=computation_types.at_clients(tff_value_type)),
      result=measured_process.MeasuredProcessOutput(
          state=state_at_server,
          result=computation_types.at_server(tff_value_type),
          measurements=measurements_at_server))
  type_test_utils.assert_types_equivalent(process.next.type_signature,
                                          next_signature)
def test_type_properties_unweighted(self, value_type):
  """Checks the process types produced by UnweightedMeanFactory."""
  tff_value_type = computation_types.to_type(value_type)
  mean_factory = mean.UnweightedMeanFactory()
  self.assertIsInstance(mean_factory, factory.UnweightedAggregationFactory)
  process = mean_factory.create(tff_value_type)
  self.assertIsInstance(process, aggregation_process.AggregationProcess)

  # Expected state: a pair of empty inner states.
  state_at_server = computation_types.at_server(((), ()))
  measurements_at_server = computation_types.at_server(
      collections.OrderedDict(mean_value=(), mean_count=()))

  init_signature = computation_types.FunctionType(
      parameter=None, result=state_at_server)
  self.assertTrue(
      process.initialize.type_signature.is_equivalent_to(init_signature))

  next_signature = computation_types.FunctionType(
      parameter=collections.OrderedDict(
          state=state_at_server,
          value=computation_types.at_clients(tff_value_type)),
      result=measured_process.MeasuredProcessOutput(
          state_at_server, computation_types.at_server(tff_value_type),
          measurements_at_server))
  self.assertTrue(process.next.type_signature.is_equivalent_to(next_signature))
def test_concat_type_properties_unweighted(self, value_type):
  """Checks the unweighted process types produced by `_concat_sum()`."""
  concat_factory = _concat_sum()
  tff_value_type = computation_types.to_type(value_type)
  process = concat_factory.create(tff_value_type)
  self.assertIsInstance(process, aggregation_process.AggregationProcess)

  # Inner SumFactory has no state and no measurements.
  state_at_server = computation_types.at_server(())
  measurements_at_server = computation_types.at_server(())

  type_test_utils.assert_types_equivalent(
      process.initialize.type_signature,
      computation_types.FunctionType(parameter=None, result=state_at_server))

  next_signature = computation_types.FunctionType(
      parameter=collections.OrderedDict(
          state=state_at_server,
          value=computation_types.at_clients(tff_value_type)),
      result=measured_process.MeasuredProcessOutput(
          state=state_at_server,
          result=computation_types.at_server(tff_value_type),
          measurements=measurements_at_server))
  type_test_utils.assert_types_equivalent(process.next.type_signature,
                                          next_signature)
def test_with_federated_map(self):
  """Maps `add_one` over a server-placed int32 10 and expects 11."""
  eager_ex = eager_tf_executor.EagerTFExecutor()
  factory = federated_resolving_strategy.FederatedResolvingStrategy.factory(
      {placements.SERVER: eager_ex})
  federated_ex = federating_executor.FederatingExecutor(factory, eager_ex)
  ex = reference_resolving_executor.ReferenceResolvingExecutor(federated_ex)
  # Use a dedicated loop rather than `asyncio.get_event_loop()`. That call is
  # deprecated when no loop is running and raises on newer Pythons. Close the
  # loop at teardown so it is not leaked. A single loop is kept for all steps.
  loop = asyncio.new_event_loop()
  self.addCleanup(loop.close)

  @computations.tf_computation(tf.int32)
  def add_one(x):
    return x + 1

  @computations.federated_computation(computation_types.at_server(tf.int32))
  def comp(x):
    return intrinsics.federated_map(add_one, x)

  v1 = loop.run_until_complete(ex.create_value(comp))
  v2 = loop.run_until_complete(
      ex.create_value(10, computation_types.at_server(tf.int32)))
  v3 = loop.run_until_complete(ex.create_call(v1, v2))
  result = loop.run_until_complete(v3.compute())
  self.assertEqual(result.numpy(), 11)
def test_type_properties_simple(self, value_type, estimate_stddev):
  """Checks the process types, optionally with a stddev estimate measurement."""
  process = _test_factory(estimate_stddev=estimate_stddev).create(
      computation_types.to_type(value_type))
  self.assertIsInstance(process, aggregation_process.AggregationProcess)

  # Inner SumFactory has no state.
  empty_state = computation_types.at_server(())
  init_signature = computation_types.FunctionType(
      parameter=None, result=empty_state)
  self.assertTrue(
      process.initialize.type_signature.is_equivalent_to(init_signature))

  measurement_fields = [('modclip', ())]
  if estimate_stddev:
    measurement_fields.append(('estimated_stddev', tf.float32))
  next_signature = computation_types.FunctionType(
      parameter=collections.OrderedDict(
          state=empty_state, value=computation_types.at_clients(value_type)),
      result=measured_process.MeasuredProcessOutput(
          state=empty_state,
          result=computation_types.at_server(value_type),
          measurements=computation_types.at_server(
              collections.OrderedDict(measurement_fields))))
  self.assertTrue(process.next.type_signature.is_equivalent_to(next_signature))
def test_clip_type_properties_simple(self, value_type):
  """Checks the types of the clipped-sum process from `_clipped_sum()`."""
  clipping = _clipped_sum()
  tff_value_type = computation_types.to_type(value_type)
  process = clipping.create(tff_value_type)
  self.assertIsInstance(process, aggregation_process.AggregationProcess)

  state_at_server = computation_types.at_server(
      collections.OrderedDict(
          clipping_norm=(), inner_agg=(), clipped_count_agg=()))
  init_signature = computation_types.FunctionType(
      parameter=None, result=state_at_server)
  self.assertTrue(
      process.initialize.type_signature.is_equivalent_to(init_signature))

  measurements_at_server = computation_types.at_server(
      collections.OrderedDict(
          clipping=(),
          clipping_norm=robust.NORM_TF_TYPE,
          clipped_count=robust.COUNT_TF_TYPE))
  next_signature = computation_types.FunctionType(
      parameter=collections.OrderedDict(
          state=state_at_server,
          value=computation_types.at_clients(tff_value_type)),
      result=measured_process.MeasuredProcessOutput(
          state=state_at_server,
          result=computation_types.at_server(tff_value_type),
          measurements=measurements_at_server))
  self.assertTrue(process.next.type_signature.is_equivalent_to(next_signature))
def test_type_properties_constant_bounds(self, value_type, upper_bound,
                                         lower_bound, measurements_dtype):
  """Checks SecureSumFactory process types for constant clipping bounds."""
  secure_sum = secure.SecureSumFactory(
      upper_bound_threshold=upper_bound, lower_bound_threshold=lower_bound)
  self.assertIsInstance(secure_sum, factory.UnweightedAggregationFactory)
  tff_value_type = computation_types.to_type(value_type)
  process = secure_sum.create(tff_value_type)
  self.assertIsInstance(process, aggregation_process.AggregationProcess)

  empty_state = computation_types.at_server(computation_types.to_type(()))
  measurements_type = _measurements_type(measurements_dtype)

  init_signature = computation_types.FunctionType(
      parameter=None, result=empty_state)
  self.assertTrue(
      process.initialize.type_signature.is_equivalent_to(init_signature))

  next_signature = computation_types.FunctionType(
      parameter=collections.OrderedDict(
          state=empty_state,
          value=computation_types.at_clients(tff_value_type)),
      result=measured_process.MeasuredProcessOutput(
          state=empty_state,
          result=computation_types.at_server(tff_value_type),
          measurements=measurements_type))
  self.assertTrue(process.next.type_signature.is_equivalent_to(next_signature))
def create_whimsy_intrinsic_def_federated_apply():
  """Returns FEDERATED_APPLY with a float32 unary-op type signature.

  Returns:
    A `(intrinsic_def, type_signature)` tuple.
  """
  fn_type = type_factory.unary_op(tf.float32)
  arg_type = computation_types.at_server(tf.float32)
  result_type = computation_types.at_server(tf.float32)
  signature = computation_types.FunctionType([fn_type, arg_type], result_type)
  return intrinsic_defs.FEDERATED_APPLY, signature
def test_federated_secure_select(self):
  """Checks FEDERATED_SECURE_SELECT is only reduced by the secure pass."""
  uri = intrinsic_defs.FEDERATED_SECURE_SELECT.uri
  comp = building_blocks.Intrinsic(
      uri,
      computation_types.FunctionType(
          [
              computation_types.at_clients(tf.int32),  # client_keys
              computation_types.at_server(tf.int32),  # max_key
              computation_types.at_server(tf.float32),  # server_state
              computation_types.FunctionType([tf.float32, tf.int32],
                                             tf.float32)  # select_fn
          ],
          computation_types.at_clients(
              computation_types.SequenceType(tf.float32))))
  self.assertGreater(_count_intrinsics(comp, uri), 0)

  # First without secure intrinsics shouldn't modify anything.
  reduced, modified = intrinsic_reductions.replace_intrinsics_with_bodies(comp)
  self.assertFalse(modified)
  # Check the returned computation. Re-checking `comp` here would only repeat
  # the assertion above.
  self.assertGreater(_count_intrinsics(reduced, uri), 0)
  self.assert_types_identical(comp.type_signature, reduced.type_signature)

  # Now replace bodies including secure intrinsics.
  reduced, modified = (
      intrinsic_reductions.replace_secure_intrinsics_with_insecure_bodies(comp))
  self.assertTrue(modified)
  self.assert_types_identical(comp.type_signature, reduced.type_signature)
  self.assertGreater(
      _count_intrinsics(reduced, intrinsic_defs.FEDERATED_SELECT.uri), 0)
def _run_in_federated_computation(optimizer, spec):
  """Runs `optimizer` for three steps inside federated computations.

  Weights and gradients are all-ones tensors matching `spec`. The same
  gradients are applied at every step.

  Args:
    optimizer: Object exposing `initialize(spec)` and `next(state, weights,
      gradients)`.
    spec: A (nested) structure of tensor specs.

  Returns:
    A `(state_history, weights_history)` tuple of lists. Each holds the
    initial value followed by the value after each of the three steps.
  """

  def _ones(tensor_spec):
    return tf.ones(tensor_spec.shape, tensor_spec.dtype)

  weights = tf.nest.map_structure(_ones, spec)
  gradients = tf.nest.map_structure(_ones, spec)

  @federated_computation.federated_computation()
  def init_fn():
    init_tf = tensorflow_computation.tf_computation(
        lambda: optimizer.initialize(spec))
    return intrinsics.federated_eval(init_tf, placements.SERVER)

  spec_type = computation_types.to_type(spec)

  @federated_computation.federated_computation(
      init_fn.type_signature.result, computation_types.at_server(spec_type),
      computation_types.at_server(spec_type))
  def next_fn(state, weights, gradients):
    next_tf = tensorflow_computation.tf_computation(optimizer.next)
    return intrinsics.federated_map(next_tf, (state, weights, gradients))

  state = init_fn()
  state_history, weights_history = [state], [weights]
  for _ in range(3):
    state, weights = next_fn(state, weights, gradients)
    state_history.append(state)
    weights_history.append(weights)
  return state_history, weights_history
def test_composition_type_properties(self, last_process):
  """Checks the types of a `double` process chained with `last_process`.

  Args:
    last_process: Callable `(state_type, init_value, values_type)` that builds
      the second measured process in the chain.
  """
  state_type = tf.float32
  values_type = tf.int32
  tail = last_process(state_type, 0.0, values_type)
  head = _create_test_measured_process_double(state_type, 1.0, values_type)
  composite_process = measured_process.chain_measured_processes(
      collections.OrderedDict(double=head, last_process=tail))
  self.assertIsInstance(composite_process, measured_process.MeasuredProcess)

  state_at_server = computation_types.at_server(
      collections.OrderedDict(double=state_type, last_process=state_type))
  init_signature = computation_types.FunctionType(
      parameter=None, result=state_at_server)
  self.assertTrue(
      composite_process.initialize.type_signature.is_equivalent_to(
          init_signature))

  # Expected measurements are keyed by stage name.
  tail_measurements = tail.next.type_signature.result.measurements.member
  measurements_at_server = computation_types.at_server(
      collections.OrderedDict(
          double=collections.OrderedDict(a=tf.int32),
          last_process=tail_measurements))
  next_signature = computation_types.FunctionType(
      parameter=collections.OrderedDict(
          state=state_at_server,
          values=computation_types.at_clients(values_type)),
      result=measured_process.MeasuredProcessOutput(
          state_at_server, computation_types.at_server(values_type),
          measurements_at_server))
  self.assertTrue(
      composite_process.next.type_signature.is_equivalent_to(next_signature))
async def compute_federated_aggregate(
    self, arg: FederatedComposingStrategyValue
) -> FederatedComposingStrategyValue:
  """Computes FEDERATED_AGGREGATE by composing the target (child) executors.

  Each child executor runs FEDERATED_AGGREGATE over its own slice of the
  values, with `report` replaced by an identity computation. Each child
  therefore yields its un-reported accumulator, of `zero_type`. The
  accumulators are moved to the server executor, folded together with
  `merge`, and the real `report` is applied once at the end.

  Args:
    arg: A value whose `internal_representation` is a 5-element
      `structure.Struct` of (value, zero, accumulate, merge, report). The value
      element must be a list with one entry per target executor.

  Returns:
    A `FederatedComposingStrategyValue` holding the reported result, typed as
    `report_type.result` placed at the server.

  Raises:
    Whatever `py_typecheck.check_type` / `check_len` raise when `arg` does
    not have the structure described above.
  """
  value_type, zero_type, accumulate_type, merge_type, report_type = (
      executor_utils.parse_federated_aggregate_argument_types(
          arg.type_signature))
  py_typecheck.check_type(arg.internal_representation, structure.Struct)
  py_typecheck.check_len(arg.internal_representation, 5)
  val = arg.internal_representation[0]
  py_typecheck.check_type(val, list)
  py_typecheck.check_len(val, len(self._target_executors))
  # Children use an identity report so that they return raw accumulators,
  # which the parent can still merge.
  identity_report, identity_report_type = tensorflow_computation_factory.create_identity(
      zero_type)
  aggr_type = computation_types.FunctionType(
      computation_types.StructType([
          value_type, zero_type, accumulate_type, merge_type,
          identity_report_type
      ]), computation_types.at_server(zero_type))
  aggr_comp = executor_utils.create_intrinsic_comp(
      intrinsic_defs.FEDERATED_AGGREGATE, aggr_type)
  # Materialize `zero` so it can be re-embedded in every child executor.
  zero = await (await self._executor.create_selection(arg, 1)).compute()
  accumulate = arg.internal_representation[2]
  merge = arg.internal_representation[3]
  report = arg.internal_representation[4]

  async def _child_fn(ex, v):
    """Runs the partial aggregate in child `ex` and moves it to the server."""
    py_typecheck.check_type(v, executor_value_base.ExecutorValue)
    arg_values = [
        ex.create_value(zero, zero_type),
        ex.create_value(accumulate, accumulate_type),
        ex.create_value(merge, merge_type),
        ex.create_value(identity_report, identity_report_type)
    ]
    aggr_func, aggr_args = await asyncio.gather(
        ex.create_value(aggr_comp, aggr_type),
        ex.create_struct([v] + list(await asyncio.gather(*arg_values))))
    child_result = await (await ex.create_call(aggr_func,
                                               aggr_args)).compute()
    # Re-embed the child's accumulator on the server executor for merging.
    result_at_server = await self._server_executor.create_value(
        child_result, zero_type)
    return result_at_server

  # Children are merged in completion order, not executor order. This
  # presumably relies on `merge` being commutative and associative.
  val_futures = asyncio.as_completed(
      [_child_fn(c, v) for c, v in zip(self._target_executors, val)])
  parent_merge, parent_report = await asyncio.gather(
      self._server_executor.create_value(merge, merge_type),
      self._server_executor.create_value(report, report_type))
  # NOTE(review): if `val` is empty, `next()` raises StopIteration inside this
  # coroutine, which Python surfaces as RuntimeError.
  merge_result = await next(val_futures)
  for next_val_future in val_futures:
    next_val = await next_val_future
    merge_arg = await self._server_executor.create_struct(
        [merge_result, next_val])
    merge_result = await self._server_executor.create_call(
        parent_merge, merge_arg)
  report_result = await self._server_executor.create_call(
      parent_report, merge_result)
  return FederatedComposingStrategyValue(
      report_result, computation_types.at_server(report_type.result))
def test_serialize_deserialize_federated_at_server(self):
  """Round-trips a server-placed int32 through `value_serialization`."""
  server_int = computation_types.at_server(tf.int32)
  proto, serialized_type = value_serialization.serialize_value(10, server_int)
  type_test_utils.assert_types_identical(
      serialized_type, computation_types.at_server(tf.int32))

  round_tripped, deserialized_type = value_serialization.deserialize_value(
      proto)
  type_test_utils.assert_types_identical(deserialized_type, server_int)
  self.assertEqual(round_tripped, 10)
def test_serialize_deserialize_federated_at_server(self):
  """Round-trips a server-placed int32 through `executor_serialization`."""
  server_int = computation_types.at_server(tf.int32)
  proto, serialized_type = executor_serialization.serialize_value(
      10, server_int)
  self.assertIsInstance(proto, executor_pb2.Value)
  self.assert_types_identical(serialized_type,
                              computation_types.at_server(tf.int32))

  round_tripped, deserialized_type = (
      executor_serialization.deserialize_value(proto))
  self.assert_types_identical(deserialized_type, server_int)
  self.assertEqual(round_tripped, 10)
def _build_expected_test_quant_model_eval_signature():
  """Returns signature for build_federated_evaluation using TestModelQuant."""
  server_weights = computation_types.at_server(
      model_utils.weights_type_from_model(TestModelQuant))
  temp_tensor = computation_types.TensorType(shape=(None,), dtype=tf.float32)
  client_datasets = computation_types.at_clients(
      computation_types.SequenceType(
          collections.OrderedDict(temp=temp_tensor)))
  parameter = collections.OrderedDict(
      server_model_weights=server_weights, federated_dataset=client_datasets)
  # The result is a struct whose single field is itself placed at the server.
  result = collections.OrderedDict(
      num_same=computation_types.at_server(tf.float32))
  return computation_types.FunctionType(parameter=parameter, result=result)
def test_central_aggregation_with_secure_sum(self, value_shape, arity,
                                             l1_bound):
  """Checks the central-DP hierarchical histogram process with secure sum."""
  vector_type = computation_types.to_type((tf.float32, (value_shape,)))
  hihi = hihi_factory.create_central_hierarchical_histogram_factory(
      arity=arity, secure_sum=True)
  self.assertIsInstance(hihi,
                        differential_privacy.DifferentiallyPrivateFactory)
  process = hihi.create(vector_type)
  self.assertIsInstance(process, aggregation_process.AggregationProcess)

  initial_query_state = _test_central_dp_query.initial_global_state()
  query_state_type = type_conversions.type_from_tensors(initial_query_state)
  query_metrics_type = type_conversions.type_from_tensors(
      _test_central_dp_query.derive_metrics(initial_query_state))

  state_at_server = computation_types.at_server((query_state_type, ()))
  init_signature = computation_types.FunctionType(
      parameter=None, result=state_at_server)
  self.assertTrue(
      process.initialize.type_signature.is_equivalent_to(init_signature))

  secure_sum_metrics = collections.OrderedDict(
      secure_upper_clipped_count=tf.int32,
      secure_lower_clipped_count=tf.int32,
      secure_upper_threshold=tf.float32,
      secure_lower_threshold=tf.float32)
  measurements_at_server = computation_types.at_server(
      collections.OrderedDict(
          dp_query_metrics=query_metrics_type, dp=secure_sum_metrics))
  result_struct_type = computation_types.to_type(
      collections.OrderedDict([
          ('flat_values',
           computation_types.TensorType(tf.float32, tf.TensorShape(None))),
          ('nested_row_splits', [(tf.int64, (None,))]),
      ]))
  next_signature = computation_types.FunctionType(
      parameter=collections.OrderedDict(
          state=state_at_server,
          value=computation_types.at_clients(vector_type)),
      result=measured_process.MeasuredProcessOutput(
          state=state_at_server,
          result=computation_types.at_server(result_struct_type),
          measurements=measurements_at_server))
  self.assertTrue(process.next.type_signature.is_equivalent_to(next_signature))
def test_finalizer():
  """Builds a FinalizerProcess whose `next` adds the update to trainable weights.

  The resulting weights have an empty `()` non-trainable part, and the
  measurements come from `empty_at_server()`.

  NOTE(review): despite the `test_` prefix, this returns a process rather
  than asserting anything. Test runners may still collect it.
  """

  @federated_computation.federated_computation(
      empty_init_fn.type_signature.result,
      computation_types.at_server(MODEL_WEIGHTS_TYPE),
      computation_types.at_server(FLOAT_TYPE))
  def next_fn(state, weights, updates):
    add = tensorflow_computation.tf_computation(lambda x, y: x + y)
    summed_trainable = intrinsics.federated_map(add,
                                                (weights.trainable, updates))
    zipped_weights = intrinsics.federated_zip(
        model_utils.ModelWeights(summed_trainable, ()))
    return measured_process.MeasuredProcessOutput(state, zipped_weights,
                                                  empty_at_server())

  return finalizers.FinalizerProcess(empty_init_fn, next_fn)