def test_with_federated_map(self):
  """Runs `federated_map(add_one, ...)` on a server-placed `10`, expecting 11."""
  server_executor = eager_tf_executor.EagerTFExecutor()
  strategy_factory = (
      federated_resolving_strategy.FederatedResolvingStrategy.factory(
          {placements.SERVER: server_executor}))
  executor = reference_resolving_executor.ReferenceResolvingExecutor(
      federating_executor.FederatingExecutor(strategy_factory,
                                             server_executor))
  event_loop = asyncio.get_event_loop()

  @computations.tf_computation(tf.int32)
  def add_one(x):
    return x + 1

  @computations.federated_computation(computation_types.at_server(tf.int32))
  def comp(x):
    return intrinsics.federated_map(add_one, x)

  run = event_loop.run_until_complete
  comp_value = run(executor.create_value(comp))
  arg_value = run(
      executor.create_value(10, computation_types.at_server(tf.int32)))
  call_value = run(executor.create_call(comp_value, arg_value))
  result = run(call_value.compute())
  self.assertEqual(result.numpy(), 11)
def test_type_properties(self, value_type, factory_cons):
  """Checks the `initialize` and `next` type signatures of the built process."""
  agg_factory = factory_cons()
  value_type = computation_types.to_type(value_type)
  process = agg_factory.create(value_type)
  self.assertIsInstance(process, aggregation_process.AggregationProcess)

  state_type = computation_types.at_server(
      ((), collections.OrderedDict(value_sum_process=(),
                                   weight_sum_process=())))
  init_type = computation_types.FunctionType(parameter=None, result=state_type)
  self.assertTrue(process.initialize.type_signature.is_equivalent_to(init_type))

  measurements_type = computation_types.at_server(
      collections.OrderedDict(value_sum_process=(), weight_sum_process=()))
  next_parameter = collections.OrderedDict(
      state=state_type,
      value=computation_types.at_clients(value_type),
      weight=computation_types.at_clients(tf.float32))
  next_result = measured_process.MeasuredProcessOutput(
      state=state_type,
      result=computation_types.at_server(value_type),
      measurements=measurements_type)
  next_type = computation_types.FunctionType(
      parameter=next_parameter, result=next_result)
  self.assertTrue(process.next.type_signature.is_equivalent_to(next_type))
def test_raises_with_closure(self):
  """Calling a computation whose mapped lambda captures `x` must fail."""
  eager_ex = eager_tf_executor.EagerTFExecutor()
  strategy_factory = (
      federated_resolving_strategy.FederatedResolvingStrategy.factory({
          placement_literals.SERVER: eager_ex,
      }))
  executor = reference_resolving_executor.ReferenceResolvingExecutor(
      federating_executor.FederatingExecutor(strategy_factory, eager_ex))
  event_loop = asyncio.get_event_loop()

  @computations.federated_computation(
      tf.int32, computation_types.at_server(tf.int32))
  def foo(x, y):

    # `bar` closes over `x` from the enclosing computation; this capture is
    # what the executor is expected to reject.
    @computations.federated_computation(tf.int32)
    def bar(z):
      del z
      return x

    return intrinsics.federated_map(bar, y)

  fn_value = event_loop.run_until_complete(executor.create_value(foo))
  call_arg = structure.Struct([('x', 0), ('y', 0)])
  call_arg_type = [tf.int32, computation_types.at_server(tf.int32)]
  arg_value = event_loop.run_until_complete(
      executor.create_value(call_arg, call_arg_type))
  with self.assertRaisesRegex(
      RuntimeError,
      'lambda passed to intrinsic contains references to captured variables'
  ):
    event_loop.run_until_complete(executor.create_call(fn_value, arg_value))
def test_zero_type_properties_weighted(self, value_type, weight_type):
  """Checks type signatures of the weighted, zeroing-wrapped mean process."""
  zeroed_mean_f = _zeroed_mean()
  value_type = computation_types.to_type(value_type)
  weight_type = computation_types.to_type(weight_type)
  process = zeroed_mean_f.create(value_type, weight_type)
  self.assertIsInstance(process, aggregation_process.AggregationProcess)

  inner_state = collections.OrderedDict(
      value_sum_process=(), weight_sum_process=())
  state_type = computation_types.at_server(
      collections.OrderedDict(
          zeroing_norm=(), inner_agg=inner_state, zeroed_count_agg=()))
  self.assertTrue(
      process.initialize.type_signature.is_equivalent_to(
          computation_types.FunctionType(parameter=None, result=state_type)))

  measurements_type = computation_types.at_server(
      collections.OrderedDict(
          zeroing=collections.OrderedDict(mean_value=(), mean_weight=()),
          zeroing_norm=robust_factory.NORM_TF_TYPE,
          zeroed_count=robust_factory.COUNT_TF_TYPE))
  next_type = computation_types.FunctionType(
      parameter=collections.OrderedDict(
          state=state_type,
          value=computation_types.at_clients(value_type),
          weight=computation_types.at_clients(weight_type)),
      result=measured_process.MeasuredProcessOutput(
          state=state_type,
          result=computation_types.at_server(value_type),
          measurements=measurements_type))
  self.assertTrue(process.next.type_signature.is_equivalent_to(next_type))
def test_type_properties_constant_bounds(self, value_type, upper_bound,
                                         lower_bound, measurements_dtype):
  """Checks type signatures of a `SecureSumFactory` with constant bounds."""
  secure_sum_f = secure_factory.SecureSumFactory(
      upper_bound_threshold=upper_bound, lower_bound_threshold=lower_bound)
  self.assertIsInstance(secure_sum_f, factory.UnweightedAggregationFactory)
  value_type = computation_types.to_type(value_type)
  process = secure_sum_f.create(value_type)
  self.assertIsInstance(process, aggregation_process.AggregationProcess)

  # The process keeps no state: the server state is the empty struct.
  state_type = computation_types.at_server(computation_types.to_type(()))
  measurements_type = _measurements_type(measurements_dtype)
  self.assertTrue(
      process.initialize.type_signature.is_equivalent_to(
          computation_types.FunctionType(parameter=None, result=state_type)))

  next_type = computation_types.FunctionType(
      parameter=collections.OrderedDict(
          state=state_type, value=computation_types.at_clients(value_type)),
      result=measured_process.MeasuredProcessOutput(
          state=state_type,
          result=computation_types.at_server(value_type),
          measurements=measurements_type))
  self.assertTrue(process.next.type_signature.is_equivalent_to(next_type))
def test_clip_type_properties_simple(self, value_type):
  """Checks type signatures of the unweighted clipped-sum process."""
  clipped_sum_f = _clipped_sum()
  value_type = computation_types.to_type(value_type)
  process = clipped_sum_f.create_unweighted(value_type)
  self.assertIsInstance(process, aggregation_process.AggregationProcess)

  state_type = computation_types.at_server(
      collections.OrderedDict(
          clipping_norm=(), inner_agg=(), clipped_count_agg=()))
  self.assertTrue(
      process.initialize.type_signature.is_equivalent_to(
          computation_types.FunctionType(parameter=None, result=state_type)))

  measurements_type = computation_types.at_server(
      collections.OrderedDict(
          agg_process=(),
          clipping_norm=clipping_factory.NORM_TF_TYPE,
          clipped_count=clipping_factory.COUNT_TF_TYPE))
  next_type = computation_types.FunctionType(
      parameter=collections.OrderedDict(
          state=state_type, value=computation_types.at_clients(value_type)),
      result=measured_process.MeasuredProcessOutput(
          state=state_type,
          result=computation_types.at_server(value_type),
          measurements=measurements_type))
  self.assertTrue(process.next.type_signature.is_equivalent_to(next_type))
async def compute_federated_aggregate(
    self, arg: FederatedComposingStrategyValue
) -> FederatedComposingStrategyValue:
  """Computes a `federated_aggregate` across the composed target executors.

  Each target executor runs its own `FEDERATED_AGGREGATE` intrinsic over its
  share of the client values. It uses `zero`, `accumulate` and `merge` from
  `arg`, but an identity function in place of `report`, so each child yields
  its un-reported accumulator. Those partial accumulators are embedded in the
  server executor and folded together pairwise with `merge`. `report` is
  applied exactly once, to the final merged value.

  Args:
    arg: A value whose internal representation is a 5-element
      `structure.Struct` of (value, zero, accumulate, merge, report). `value`
      must be a list with one entry per target executor.

  Returns:
    A `FederatedComposingStrategyValue` wrapping the server-side report
    result, typed as `report_type.result` placed at the server.

  Raises:
    Whatever `py_typecheck.check_type` / `check_len` raise when `arg`'s
    internal representation does not have the expected structure.
  """
  value_type, zero_type, accumulate_type, merge_type, report_type = (
      executor_utils.parse_federated_aggregate_argument_types(
          arg.type_signature))
  py_typecheck.check_type(arg.internal_representation, structure.Struct)
  py_typecheck.check_len(arg.internal_representation, 5)
  val = arg.internal_representation[0]
  py_typecheck.check_type(val, list)
  # One client-value shard per target executor.
  py_typecheck.check_len(val, len(self._target_executors))
  # Children report with identity, so they return raw accumulators (of
  # `zero_type`) that can still be merged at the parent.
  identity_report, identity_report_type = tensorflow_computation_factory.create_identity(
      zero_type)
  aggr_type = computation_types.FunctionType(
      computation_types.StructType([
          value_type, zero_type, accumulate_type, merge_type,
          identity_report_type
      ]), computation_types.at_server(zero_type))
  aggr_comp = executor_utils.create_intrinsic_comp(
      intrinsic_defs.FEDERATED_AGGREGATE, aggr_type)
  # Materialize `zero` once here so it can be re-embedded in every child.
  zero = await (await self._executor.create_selection(arg, index=1)).compute()
  accumulate = arg.internal_representation[2]
  merge = arg.internal_representation[3]
  report = arg.internal_representation[4]

  async def _child_fn(ex, v):
    """Runs the identity-report aggregate on child `ex` over shard `v`.

    Returns the child's accumulator, re-embedded in the server executor.
    """
    py_typecheck.check_type(v, executor_value_base.ExecutorValue)
    arg_values = [
        ex.create_value(zero, zero_type),
        ex.create_value(accumulate, accumulate_type),
        ex.create_value(merge, merge_type),
        ex.create_value(identity_report, identity_report_type)
    ]
    # The inner gather (the four operator arguments) finishes before the
    # outer gather starts. The outer gather then embeds the intrinsic
    # concurrently with building the argument struct.
    aggr_func, aggr_args = await asyncio.gather(
        ex.create_value(aggr_comp, aggr_type),
        ex.create_struct([v] + list(await asyncio.gather(*arg_values))))
    child_result = await (await ex.create_call(aggr_func,
                                               aggr_args)).compute()
    result_at_server = await self._server_executor.create_value(
        child_result, zero_type)
    return result_at_server

  # Child aggregations start now and are consumed in completion order (not
  # executor order). NOTE(review): this presumes `merge` is commutative and
  # associative; confirm that this is part of the intrinsic's contract.
  val_futures = asyncio.as_completed(
      [_child_fn(c, v) for c, v in zip(self._target_executors, val)])
  parent_merge, parent_report = await asyncio.gather(
      self._server_executor.create_value(merge, merge_type),
      self._server_executor.create_value(report, report_type))
  # NOTE(review): with zero target executors, `next` raises StopIteration
  # inside this coroutine (surfacing as RuntimeError under PEP 479).
  merge_result = await next(val_futures)
  # Fold the remaining partial accumulators into `merge_result` one at a
  # time on the server.
  for next_val_future in val_futures:
    next_val = await next_val_future
    merge_arg = await self._server_executor.create_struct(
        [merge_result, next_val])
    merge_result = await self._server_executor.create_call(
        parent_merge, merge_arg)
  report_result = await self._server_executor.create_call(
      parent_report, merge_result)
  return FederatedComposingStrategyValue(
      report_result, computation_types.at_server(report_type.result))
def create_dummy_intrinsic_def_federated_zip_at_server():
  """Returns a dummy `FEDERATED_ZIP_AT_SERVER` intrinsic and its type.

  The type zips two server-placed `tf.float32` values into one server-placed
  `<float32,float32>` struct.
  """
  parameter = [
      computation_types.at_server(tf.float32),
      computation_types.at_server(tf.float32),
  ]
  result = computation_types.at_server([tf.float32, tf.float32])
  return (intrinsic_defs.FEDERATED_ZIP_AT_SERVER,
          computation_types.FunctionType(parameter, result))
def test_with_federated_map_and_broadcast(self):
  """Broadcasts a server `10` to 3 clients, maps `add_one`, expects 11 each."""
  eager_ex = eager_tf_executor.EagerTFExecutor()
  placement_to_executors = {
      placement_literals.SERVER: eager_ex,
      placement_literals.CLIENTS: [eager_ex for _ in range(3)]
  }
  strategy_factory = (
      federated_resolving_strategy.FederatedResolvingStrategy.factory(
          placement_to_executors))
  executor = reference_resolving_executor.ReferenceResolvingExecutor(
      federating_executor.FederatingExecutor(strategy_factory, eager_ex))
  event_loop = asyncio.get_event_loop()

  @computations.tf_computation(tf.int32)
  def add_one(x):
    return x + 1

  @computations.federated_computation(computation_types.at_server(tf.int32))
  def comp(x):
    return intrinsics.federated_map(add_one,
                                    intrinsics.federated_broadcast(x))

  run = event_loop.run_until_complete
  comp_value = run(executor.create_value(comp))
  arg_value = run(
      executor.create_value(10, computation_types.at_server(tf.int32)))
  call_value = run(executor.create_call(comp_value, arg_value))
  per_client = run(call_value.compute())
  self.assertCountEqual([v.numpy() for v in per_client], [11, 11, 11])
def test_type_properties_unweighted(self, value_type):
  """Checks type signatures of the `UnweightedMeanFactory` process."""
  value_type = computation_types.to_type(value_type)
  unweighted_f = mean.UnweightedMeanFactory()
  self.assertIsInstance(unweighted_f, factory.UnweightedAggregationFactory)
  process = unweighted_f.create(value_type)
  self.assertIsInstance(process, aggregation_process.AggregationProcess)

  state_type = computation_types.at_server(())
  self.assertTrue(
      process.initialize.type_signature.is_equivalent_to(
          computation_types.FunctionType(parameter=None, result=state_type)))

  next_type = computation_types.FunctionType(
      parameter=collections.OrderedDict(
          state=state_type, value=computation_types.at_clients(value_type)),
      result=measured_process.MeasuredProcessOutput(
          state=state_type,
          result=computation_types.at_server(value_type),
          measurements=computation_types.at_server(
              collections.OrderedDict(mean_value=()))))
  self.assertTrue(process.next.type_signature.is_equivalent_to(next_type))
def test_type_properties(self, value_type, weight_type):
  """Checks type signatures of the weighted `MeanFactory` process."""
  mean_f = mean_factory.MeanFactory()
  self.assertIsInstance(mean_f, factory.WeightedAggregationFactory)
  value_type = computation_types.to_type(value_type)
  weight_type = computation_types.to_type(weight_type)
  process = mean_f.create_weighted(value_type, weight_type)
  self.assertIsInstance(process, aggregation_process.AggregationProcess)

  # State and measurements share the same (empty per-sub-process) structure.
  state_type = computation_types.at_server(
      collections.OrderedDict(value_sum_process=(), weight_sum_process=()))
  measurements_type = computation_types.at_server(
      collections.OrderedDict(value_sum_process=(), weight_sum_process=()))
  self.assertTrue(
      process.initialize.type_signature.is_equivalent_to(
          computation_types.FunctionType(parameter=None, result=state_type)))

  next_type = computation_types.FunctionType(
      parameter=collections.OrderedDict(
          state=state_type,
          value=computation_types.at_clients(value_type),
          weight=computation_types.at_clients(weight_type)),
      result=measured_process.MeasuredProcessOutput(
          state=state_type,
          result=computation_types.at_server(value_type),
          measurements=measurements_type))
  self.assertTrue(process.next.type_signature.is_equivalent_to(next_type))
def create_dummy_intrinsic_def_federated_apply():
  """Returns a dummy `FEDERATED_APPLY` intrinsic and its type.

  The type applies a `float32 -> float32` function to a server-placed
  `float32`.
  """
  fn_type = type_factory.unary_op(tf.float32)
  operand_type = computation_types.at_server(tf.float32)
  result_type = computation_types.at_server(tf.float32)
  return (intrinsic_defs.FEDERATED_APPLY,
          computation_types.FunctionType([fn_type, operand_type], result_type))
def test_serialize_deserialize_federated_at_server(self):
  """Round-trips a server-placed int32 through executor serialization."""
  placed_type = computation_types.at_server(tf.int32)
  proto, serialized_type = executor_serialization.serialize_value(
      10, placed_type)
  self.assertIsInstance(proto, executor_pb2.Value)
  self.assert_types_identical(serialized_type,
                              computation_types.at_server(tf.int32))

  round_tripped, round_tripped_type = executor_serialization.deserialize_value(
      proto)
  self.assert_types_identical(round_tripped_type, placed_type)
  self.assertEqual(round_tripped, 10)
def _build_expected_test_quant_model_eval_signature():
  """Returns the expected `build_federated_evaluation` type for TestModelQuant.

  The signature takes server-placed model weights plus a client-placed dataset
  of `{temp: float32[?]}` elements, and returns `num_same` at the server.
  """
  server_weights = computation_types.at_server(
      model_utils.weights_type_from_model(TestModelQuant))
  element_type = collections.OrderedDict(
      temp=computation_types.TensorType(shape=(None,), dtype=tf.float32))
  client_data = computation_types.at_clients(
      computation_types.SequenceType(element_type))
  parameter = collections.OrderedDict(
      server_model_weights=server_weights, federated_dataset=client_data)
  result = collections.OrderedDict(
      num_same=computation_types.at_server(tf.float32))
  return computation_types.FunctionType(parameter=parameter, result=result)
def _temperature_sensor_example_next_fn():
  """Builds a federated computation for the temperature-sensor example.

  The returned computation takes client sequences of float32 readings and a
  server-placed threshold. It returns the weighted mean, over clients, of
  the number of readings above the threshold, weighted by each client's total
  reading count.
  """

  @computations.tf_computation(computation_types.SequenceType(tf.float32),
                               tf.float32)
  def count_over(ds, t):
    # Adds 1.0 for every element strictly greater than `t`.
    return ds.reduce(
        np.float32(0),
        lambda acc, elem: acc + tf.cast(tf.greater(elem, t), tf.float32))

  @computations.tf_computation(computation_types.SequenceType(tf.float32))
  def count_total(ds):
    return ds.reduce(np.float32(0.0), lambda acc, _: acc + 1.0)

  @computations.federated_computation(
      computation_types.at_clients(
          computation_types.SequenceType(tf.float32)),
      computation_types.at_server(tf.float32))
  def comp(temperatures, threshold):
    threshold_at_clients = intrinsics.federated_broadcast(threshold)
    over_counts = intrinsics.federated_map(
        count_over,
        intrinsics.federated_zip([temperatures, threshold_at_clients]))
    totals = intrinsics.federated_map(count_total, temperatures)
    return intrinsics.federated_mean(over_counts, totals)

  return comp
def _measurements_type(bound_type):
  """Returns the expected server-placed secure-sum measurements type.

  Args:
    bound_type: Type used for the upper/lower threshold measurements.
  """
  count_type = secure_factory.COUNT_TF_TYPE
  fields = collections.OrderedDict([
      ('secure_upper_clipped_count', count_type),
      ('secure_lower_clipped_count', count_type),
      ('secure_upper_threshold', bound_type),
      ('secure_lower_threshold', bound_type),
  ])
  return computation_types.at_server(fields)
async def compute_federated_mean(
    self, arg: FederatedComposingStrategyValue
) -> FederatedComposingStrategyValue:
  """Computes the unweighted mean of a client-placed value.

  The federated sum (from `compute_federated_sum`) is multiplied on the server
  executor by `1 / total`, where `total` is the sum of the cardinalities
  returned by `_get_cardinalities()`.

  Args:
    arg: A value whose type is federated and placed at CLIENTS.

  Returns:
    A `FederatedComposingStrategyValue` holding the mean, typed as the member
    type placed at the server.
  """
  type_analysis.check_federated_type(
      arg.type_signature, placement=placement_literals.CLIENTS)
  member_type = arg.type_signature.member

  async def _sum_on_server():
    summed = await (await self.compute_federated_sum(arg)).compute()
    return await self._server_executor.create_value(summed, member_type)

  async def _inverse_count_on_server():
    # NOTE(review): raises ZeroDivisionError when the total cardinality is 0.
    num_clients = sum(await self._get_cardinalities())
    return await executor_utils.embed_tf_constant(
        self._server_executor, member_type, float(1.0 / num_clients))

  async def _product_operands():
    summed, inverse_count = await asyncio.gather(_sum_on_server(),
                                                 _inverse_count_on_server())
    return await self._server_executor.create_struct([summed, inverse_count])

  multiply_fn, operands = await asyncio.gather(
      executor_utils.embed_tf_binary_operator(self._server_executor,
                                              member_type, tf.multiply),
      _product_operands())
  product = await self._server_executor.create_call(multiply_fn, operands)
  return FederatedComposingStrategyValue(
      product, computation_types.at_server(member_type))
def create_dummy_intrinsic_def_federated_weighted_mean():
  """Returns a dummy `FEDERATED_WEIGHTED_MEAN` intrinsic and its type.

  The type maps client-placed float32 values and float32 weights to a
  server-placed float32.
  """
  values_type = computation_types.at_clients(tf.float32)
  weights_type = computation_types.at_clients(tf.float32)
  mean_type = computation_types.at_server(tf.float32)
  return (intrinsic_defs.FEDERATED_WEIGHTED_MEAN,
          computation_types.FunctionType([values_type, weights_type],
                                         mean_type))
def test_allows_assignable_but_not_equal_zero_and_reduction_types(self):
  """`federated_reduce` accepts a zero type assignable to the reduced type."""
  intrinsic_f = intrinsic_factory.IntrinsicFactory(
      context_stack_impl.context_stack)
  element_type = tf.string
  # `string[1]` is assignable to, but not equal to, `string[?]`.
  zero_type = computation_types.TensorType(tf.string, [1])
  reduced_type = computation_types.TensorType(tf.string, [None])

  @computations.tf_computation(reduced_type, element_type)
  @computations.check_returns_type(reduced_type)
  def append(accumulator, element):
    return tf.concat([accumulator, [element]], 0)

  @computations.tf_computation
  @computations.check_returns_type(zero_type)
  def zero():
    return tf.convert_to_tensor(['The beginning'])

  @computations.federated_computation(
      computation_types.at_clients(element_type))
  @computations.check_returns_type(computation_types.at_server(reduced_type))
  def collect(client_values):
    return intrinsic_f.federated_reduce(client_values, zero(), append)

  self.assertEqual(collect.type_signature.compact_representation(),
                   '({string}@CLIENTS -> string[?]@SERVER)')
def create(
    self, value_type: factory.ValueType
) -> aggregation_process.AggregationProcess:
  """Builds a stateless process that aggregates a sample of client values.

  `next` runs a `federated_aggregate` whose zero, accumulate, merge and
  report computations come from the module's `_build_*` helpers,
  parameterized by `value_type` and `self._sample_size`. State and
  measurements are both the empty server-placed tuple.

  Args:
    value_type: The type of the client-placed values to sample.

  Returns:
    An `aggregation_process.AggregationProcess`.
  """

  @computations.federated_computation()
  def init_fn():
    # The process is stateless, so the initial state is an empty tuple.
    return intrinsics.federated_value((), placements.SERVER)

  @computations.federated_computation(
      computation_types.at_server(()),
      computation_types.at_clients(value_type))
  def next_fn(unused_state, value):
    # TFF has no `None`; the empty tuple plays that role for both the state
    # and the measurements.
    nothing = intrinsics.federated_value((), placements.SERVER)
    samples = intrinsics.federated_aggregate(
        value,
        zero=_build_initial_sample_reservoir(value_type),
        accumulate=_build_sample_value_computation(value_type,
                                                   self._sample_size),
        merge=_build_merge_samples_computation(value_type,
                                               self._sample_size),
        report=_build_finalize_sample_computation(value_type))
    return measured_process.MeasuredProcessOutput(
        state=nothing, result=samples, measurements=nothing)

  return aggregation_process.AggregationProcess(init_fn, next_fn)
def create_dummy_intrinsic_def_federated_collect():
  """Returns a dummy `FEDERATED_COLLECT` intrinsic and its type.

  The type maps client-placed float32 values to a server-placed float32
  sequence.
  """
  clients_type = computation_types.at_clients(tf.float32)
  collected_type = computation_types.at_server(
      computation_types.SequenceType(tf.float32))
  return (intrinsic_defs.FEDERATED_COLLECT,
          computation_types.FunctionType(clients_type, collected_type))
class CreateIdentityTest(parameterized.TestCase):
  """Tests for `tensorflow_computation_factory.create_identity`."""

  # pyformat: disable
  @parameterized.named_parameters(
      ('int', computation_types.TensorType(tf.int32), 10),
      ('unnamed_tuple',
       computation_types.StructType([tf.int32, tf.float32]),
       structure.Struct([(None, 10), (None, 10.0)])),
      ('named_tuple',
       computation_types.StructType([('a', tf.int32), ('b', tf.float32)]),
       structure.Struct([('a', 10), ('b', 10.0)])),
      ('sequence', computation_types.SequenceType(tf.int32), [10] * 3),
  )
  # pyformat: enable
  def test_returns_computation(self, type_signature, value):
    comp_proto, _ = tensorflow_computation_factory.create_identity(
        type_signature)
    self.assertIsInstance(comp_proto, pb.Computation)
    # The identity has type `T -> T`.
    self.assertEqual(
        type_serialization.deserialize_type(comp_proto.type),
        type_factory.unary_op(type_signature))
    self.assertEqual(test_utils.run_tensorflow(comp_proto, value), value)

  @parameterized.named_parameters(
      ('none', None),
      ('federated_type', computation_types.at_server(tf.int32)),
  )
  def test_raises_type_error(self, type_signature):
    with self.assertRaises(TypeError):
      tensorflow_computation_factory.create_identity(type_signature)
def create_dummy_intrinsic_def_federated_secure_sum():
  """Returns a dummy `FEDERATED_SECURE_SUM` intrinsic and its type.

  The type sums client-placed int32 values, given an unplaced int32 bitwidth,
  into a server-placed int32.
  """
  parameter = [computation_types.at_clients(tf.int32), tf.int32]
  result = computation_types.at_server(tf.int32)
  return (intrinsic_defs.FEDERATED_SECURE_SUM,
          computation_types.FunctionType(parameter, result))
class CreateReplicateInputTest(parameterized.TestCase):
  """Tests for `tensorflow_computation_factory.create_replicate_input`."""

  @parameterized.named_parameters(
      ('int', computation_types.TensorType(tf.int32), 3, 10),
      ('float', computation_types.TensorType(tf.float32), 3, 10.0),
      ('unnamed_tuple',
       computation_types.StructType([tf.int32, tf.float32]), 3,
       structure.Struct([(None, 10), (None, 10.0)])),
      ('named_tuple',
       computation_types.StructType([('a', tf.int32), ('b', tf.float32)]), 3,
       structure.Struct([('a', 10), ('b', 10.0)])),
      ('sequence', computation_types.SequenceType(tf.int32), 3, [10] * 3),
  )
  def test_returns_computation(self, type_signature, count, value):
    comp_proto, _ = tensorflow_computation_factory.create_replicate_input(
        type_signature, count)
    self.assertIsInstance(comp_proto, pb.Computation)
    deserialized = type_serialization.deserialize_type(comp_proto.type)
    # The produced type must fit `T -> <T, ..., T>` (with `count` copies).
    computation_types.FunctionType(
        type_signature,
        [type_signature] * count).check_assignable_from(deserialized)
    self.assertEqual(
        test_utils.run_tensorflow(comp_proto, value),
        structure.Struct([(None, value)] * count))

  @parameterized.named_parameters(
      ('none_type', None, 3),
      ('none_count', computation_types.TensorType(tf.int32), None),
      ('federated_type', computation_types.at_server(tf.int32), 3),
  )
  def test_raises_type_error(self, type_signature, count):
    with self.assertRaises(TypeError):
      tensorflow_computation_factory.create_replicate_input(
          type_signature, count)
def test_federated_zip_with_single_named_bool_server(self):
  """Zipping `<a=bool@SERVER>` should yield `<a=bool>@SERVER`."""
  struct_type = computation_types.StructType(
      [('a', computation_types.at_server(tf.bool))])
  zipped = intrinsics.federated_zip(_mock_data_of_type(struct_type))
  self.assert_value(zipped, '<a=bool>@SERVER')
def create_intrinsic_def_federated_secure_sum(value_type, bitwidth_type):
  """Returns the `FEDERATED_SECURE_SUM` intrinsic and a concrete type for it.

  Args:
    value_type: Member type of the client-placed values being summed. This is
      also the member type of the server-placed result.
    bitwidth_type: Type of the (unplaced) bitwidth argument.
  """
  signature = computation_types.FunctionType(
      [computation_types.at_clients(value_type), bitwidth_type],
      computation_types.at_server(value_type))
  return intrinsic_defs.FEDERATED_SECURE_SUM, signature
def _build_expected_broadcaster_next_signature():
  """Returns the expected `next` signature of the shared test broadcaster.

  The broadcaster takes server-placed TestModelQuant weights and returns them
  placed at clients. Its state is a server-placed
  `<trainable=<()>,non_trainable=<>>` and its measurements are empty.
  """
  state_type = computation_types.at_server(
      computation_types.StructType([('trainable', [()]),
                                    ('non_trainable', [])]))
  server_weights = computation_types.at_server(
      model_utils.weights_type_from_model(TestModelQuant))
  client_weights = computation_types.at_clients(
      model_utils.weights_type_from_model(TestModelQuant))
  parameter = collections.OrderedDict(state=state_type, value=server_weights)
  result = collections.OrderedDict(
      state=state_type,
      result=client_weights,
      measurements=computation_types.at_server(()))
  return computation_types.FunctionType(parameter=parameter, result=result)
async def _compute_apply_fn():
  """Embeds a `FEDERATED_APPLY` of `divide_blk` into `executor`.

  `divide_blk`, `zip2_type`, `executor` and `create_intrinsic_comp` come from
  the enclosing scope, which is not visible here. The intrinsic is typed to
  apply `divide_blk` to a value of type `zip2_type.result`, producing
  `divide_blk`'s result placed at the server.
  """
  divide_type = divide_blk.type_signature
  apply_type = computation_types.FunctionType(
      computation_types.StructType([divide_type, zip2_type.result]),
      computation_types.at_server(divide_type.result))
  apply_comp = create_intrinsic_comp(intrinsic_defs.FEDERATED_APPLY,
                                     apply_type)
  return await executor.create_value(apply_comp, apply_type)
def create_dummy_intrinsic_def_federated_reduce():
  """Returns a dummy `FEDERATED_REDUCE` intrinsic and its type.

  The type reduces client-placed float32 values, using a float32 zero and a
  `(float32, float32) -> float32` reduction op, into a server-placed float32.
  """
  parameter = [
      computation_types.at_clients(tf.float32),
      tf.float32,
      type_factory.reduction_op(tf.float32, tf.float32),
  ]
  reduced = computation_types.at_server(tf.float32)
  return (intrinsic_defs.FEDERATED_REDUCE,
          computation_types.FunctionType(parameter, reduced))
def test_returns_value_with_federated_type_at_server(self):
  """Computing a server-placed strategy value wrapping `10.0` yields `10.0`."""
  members = [eager_tf_executor.EagerValue(10.0, None, tf.float32)]
  strategy_value = (
      federated_resolving_strategy.FederatedResolvingStrategyValue(
          members, computation_types.at_server(tf.float32)))
  self.assertEqual(self.run_sync(strategy_value.compute()), 10.0)