async def _get_cardinalities_helper():
  """Computes, for each child executor, a federated sum of the constant 1.

  `self` is not a parameter here; it is presumably captured from an enclosing
  method (TODO confirm).

  Returns:
    A list with one entry per element of `self._child_executors`. Each entry
    is the computed result of `federated_sum` over an all-equal clients-placed
    `1` on that child. Eager tensors are unwrapped to numpy values.
  """
  one_type = type_factory.at_clients(tf.int32, all_equal=True)
  sum_type = computation_types.FunctionType(
      type_factory.at_clients(tf.int32), type_factory.at_server(tf.int32))
  sum_comp = executor_utils.create_intrinsic_comp(
      intrinsic_defs.FEDERATED_SUM, sum_type)

  async def _count_leaf_executors(ex):
    """Evaluates `federated_sum(1@CLIENTS)` on executor `ex`."""
    one_val, sum_comp_val = await asyncio.gather(
        ex.create_value(1, one_type), ex.create_value(sum_comp, sum_type))
    call_val = await ex.create_call(sum_comp_val, one_val)
    total = await call_val.compute()
    # Unwrap eager tensors so callers receive plain numpy values.
    return total.numpy() if isinstance(total, tf.Tensor) else total

  counters = [_count_leaf_executors(child) for child in self._child_executors]
  return await asyncio.gather(*counters)
def test_federated_weighted_mean_with_floats(self):
  """Runs FEDERATED_WEIGHTED_MEAN on 4 clients via the lambda executor.

  Values 1..4 with weights 5, 10, 3, 2 give 42 / 20 == 2.1. The type
  signature of every intermediate value is asserted along the way.
  """
  loop, ex = _make_test_runtime(num_clients=4, use_lambda_executor=True)
  run = loop.run_until_complete

  values = run(
      ex.create_value([1.0, 2.0, 3.0, 4.0],
                      type_factory.at_clients(tf.float32)))
  self.assertEqual(str(values.type_signature), '{float32}@CLIENTS')

  weights = run(
      ex.create_value([5.0, 10.0, 3.0, 2.0],
                      type_factory.at_clients(tf.float32)))
  self.assertEqual(str(weights.type_signature), '{float32}@CLIENTS')

  pair = run(
      ex.create_tuple(
          anonymous_tuple.AnonymousTuple([(None, values), (None, weights)])))
  self.assertEqual(
      str(pair.type_signature), '<{float32}@CLIENTS,{float32}@CLIENTS>')

  mean_type = computation_types.FunctionType(
      [type_factory.at_clients(tf.float32),
       type_factory.at_clients(tf.float32)],
      type_factory.at_server(tf.float32))
  mean_fn = run(
      ex.create_value(intrinsic_defs.FEDERATED_WEIGHTED_MEAN, mean_type))
  self.assertEqual(
      str(mean_fn.type_signature),
      '(<{float32}@CLIENTS,{float32}@CLIENTS> -> float32@SERVER)')

  call_result = run(ex.create_call(mean_fn, pair))
  self.assertEqual(str(call_result.type_signature), 'float32@SERVER')

  computed = run(call_result.compute())
  self.assertAlmostEqual(computed.numpy(), 2.1, places=3)
def test_create_selection_by_index_anonymous_tuple_backed(self):
  """Selecting index 0 from a pair of clients values returns the first one."""
  loop = asyncio.get_event_loop()
  ex = _make_test_executor(num_clients=4)
  run = loop.run_until_complete

  first = run(
      ex.create_value([1.0, 2.0, 3.0, 4.0],
                      type_factory.at_clients(tf.float32)))
  self.assertEqual(str(first.type_signature), '{float32}@CLIENTS')

  second = run(
      ex.create_value([5.0, 10.0, 3.0, 2.0],
                      type_factory.at_clients(tf.float32)))
  self.assertEqual(str(second.type_signature), '{float32}@CLIENTS')

  pair = run(
      ex.create_tuple(
          anonymous_tuple.AnonymousTuple([(None, first), (None, second)])))
  self.assertEqual(
      str(pair.type_signature), '<{float32}@CLIENTS,{float32}@CLIENTS>')

  selected = run(ex.create_selection(pair, index=0))
  self.assertEqual(str(selected.type_signature), '{float32}@CLIENTS')

  computed = run(selected.compute())
  as_numpy = tf.nest.map_structure(lambda t: t.numpy(), computed)
  self.assertCountEqual(as_numpy, [1, 2, 3, 4])
async def _compute_intrinsic_federated_map(self, arg):
  """Computes FEDERATED_MAP by running it on every child executor.

  Args:
    arg: An executor value whose internal representation is a 2-element
      `AnonymousTuple`. The first element is a `pb.Computation` (the mapping
      function). The second is a `list` with one value per entry in
      `self._child_executors`.

  Returns:
    A `CompositeValue` of the per-child results, typed as the mapping
    function's result type placed at clients.
  """
  py_typecheck.check_type(arg.internal_representation,
                          anonymous_tuple.AnonymousTuple)
  py_typecheck.check_len(arg.internal_representation, 2)
  fn_type = arg.type_signature[0]
  py_typecheck.check_type(fn_type, computation_types.FunctionType)
  val_type = arg.type_signature[1]
  type_utils.check_federated_type(val_type, fn_type.parameter,
                                  placement_literals.CLIENTS)
  fn = arg.internal_representation[0]
  val = arg.internal_representation[1]
  py_typecheck.check_type(fn, pb.Computation)
  py_typecheck.check_type(val, list)
  # Exactly one value per child is required. Without this check, `zip` below
  # would silently drop unmatched children or values.
  py_typecheck.check_len(val, len(self._child_executors))

  map_type = computation_types.FunctionType(
      [fn_type, type_factory.at_clients(fn_type.parameter)],
      type_factory.at_clients(fn_type.result))
  map_comp = executor_utils.create_intrinsic_comp(
      intrinsic_defs.FEDERATED_MAP, map_type)

  async def _child_fn(ex, v):
    """Invokes FEDERATED_MAP(fn, v) on child executor `ex`."""
    py_typecheck.check_type(v, executor_value_base.ExecutorValue)
    fn_val = await ex.create_value(fn, fn_type)
    map_val, map_arg = tuple(await asyncio.gather(
        ex.create_value(map_comp, map_type), ex.create_tuple([fn_val, v])))
    return await ex.create_call(map_val, map_arg)

  result_vals = await asyncio.gather(
      *[_child_fn(c, v) for c, v in zip(self._child_executors, val)])
  return CompositeValue(result_vals, type_factory.at_clients(fn_type.result))
def create_dummy_intrinsic_def_federated_map():
  """Returns `(FEDERATED_MAP, type_signature)` specialized to float32.

  The signature takes a float32 unary op and a clients-placed float32 and
  returns a clients-placed float32.
  """
  param_types = [
      type_factory.unary_op(tf.float32),
      type_factory.at_clients(tf.float32),
  ]
  result_type = type_factory.at_clients(tf.float32)
  return (intrinsic_defs.FEDERATED_MAP,
          computation_types.FunctionType(param_types, result_type))
def create_dummy_intrinsic_def_federated_map_all_equal():
  """Returns `(FEDERATED_MAP_ALL_EQUAL, type_signature)` for float32.

  Both the mapped argument and the result are all-equal clients-placed
  float32 values.
  """
  param_types = [
      type_factory.unary_op(tf.float32),
      type_factory.at_clients(tf.float32, all_equal=True),
  ]
  result_type = type_factory.at_clients(tf.float32, all_equal=True)
  return (intrinsic_defs.FEDERATED_MAP_ALL_EQUAL,
          computation_types.FunctionType(param_types, result_type))
def create_dummy_intrinsic_def_federated_zip_at_clients():
  """Returns `(FEDERATED_ZIP_AT_CLIENTS, type_signature)` for two float32s.

  The signature takes two clients-placed float32 values and returns a
  clients-placed pair of float32.
  """
  param_types = [
      type_factory.at_clients(tf.float32),
      type_factory.at_clients(tf.float32),
  ]
  result_type = type_factory.at_clients([tf.float32, tf.float32])
  return (intrinsic_defs.FEDERATED_ZIP_AT_CLIENTS,
          computation_types.FunctionType(param_types, result_type))
def create_dummy_intrinsic_def_federated_weighted_mean():
  """Returns `(FEDERATED_WEIGHTED_MEAN, type_signature)` for float32.

  The signature takes clients-placed float32 values and clients-placed
  float32 weights and returns a server-placed float32.
  """
  param_types = [
      type_factory.at_clients(tf.float32),
      type_factory.at_clients(tf.float32),
  ]
  result_type = type_factory.at_server(tf.float32)
  return (intrinsic_defs.FEDERATED_WEIGHTED_MEAN,
          computation_types.FunctionType(param_types, result_type))
def test_federated_weighted_mean(self):
  """Weighted mean of 1..12 with weights cycling 1, 2, 3 equals 164 / 24."""

  @computations.federated_computation(
      type_factory.at_clients(tf.float32), type_factory.at_clients(tf.float32))
  def comp(x, y):
    return intrinsics.federated_mean(x, y)

  values = [float(v) for v in range(1, 13)]
  weights = [1.0, 2.0, 3.0] * 4
  self.assertAlmostEqual(comp(values, weights), 6.83333333333, places=3)
def test_federated_weighted_mean(self):
  """Checks federated_mean with per-client weights on the test executor.

  Values are 1.0 .. num_clients and weights cycle through 1.0, 2.0, 3.0, so
  both lists always have exactly `num_clients` entries. The expected result
  is derived from the same data rather than hard-coded for 12 clients. For
  12 clients it is still 164 / 24 ~= 6.8333.
  """

  @computations.federated_computation(
      type_factory.at_clients(tf.float32), type_factory.at_clients(tf.float32))
  def comp(x, y):
    return intrinsics.federated_mean(x, y)

  executor, num_clients = _create_test_executor()
  values = [float(i + 1) for i in range(num_clients)]
  # Same 1, 2, 3 cycle as before, but sized to match the values list.
  weights = [float(i % 3 + 1) for i in range(num_clients)]
  expected = sum(v * w for v, w in zip(values, weights)) / sum(weights)
  result = _invoke(executor, comp, (values, weights))
  self.assertAlmostEqual(result, expected, places=3)
def test_conflicting_cardinalities_within_call(self):
  """Two clients-placed arguments of different lengths must be rejected."""

  @computations.federated_computation(
      [type_factory.at_clients(tf.int32), type_factory.at_clients(tf.int32)])
  def comp(x):
    return x

  mismatched_arg = [list(range(5)), list(range(10))]
  with context_stack_impl.context_stack.install(_test_ctx()):
    with self.assertRaisesRegex(ValueError, 'Conflicting cardinalities'):
      comp(mismatched_arg)
def test_with_temperature_sensor_example(self, executor):
  """Mean count of readings above a threshold, weighted by reading count.

  With sensors reporting 0..9, 0..19 and 0..29 and threshold 15, the counts
  above threshold are 0, 4 and 14 over 10, 20 and 30 readings. That gives
  (0 * 10 + 4 * 20 + 14 * 30) / 60 ~= 8.333.
  """

  @computations.tf_computation(
      computation_types.SequenceType(tf.float32), tf.float32)
  def count_over(ds, t):
    return ds.reduce(
        np.float32(0), lambda acc, x: acc + tf.cast(tf.greater(x, t),
                                                     tf.float32))

  @computations.tf_computation(computation_types.SequenceType(tf.float32))
  def count_total(ds):
    return ds.reduce(np.float32(0.0), lambda acc, _: acc + 1.0)

  @computations.federated_computation(
      type_factory.at_clients(computation_types.SequenceType(tf.float32)),
      type_factory.at_server(tf.float32))
  def comp(temperatures, threshold):
    broadcast_threshold = intrinsics.federated_broadcast(threshold)
    over_counts = intrinsics.federated_map(
        count_over,
        intrinsics.federated_zip([temperatures, broadcast_threshold]))
    total_counts = intrinsics.federated_map(count_total, temperatures)
    return intrinsics.federated_mean(over_counts, total_counts)

  with executor_test_utils.install_executor(executor):

    def _to_float(x):
      return tf.cast(x, tf.float32)

    temperatures = [
        tf.data.Dataset.range(size).map(_to_float) for size in (10, 20, 30)
    ]
    result = comp(temperatures, 15.0)
    self.assertAlmostEqual(result, 8.333, places=3)
def create_dummy_intrinsic_def_federated_secure_sum():
  """Returns `(FEDERATED_SECURE_SUM, type_signature)` for float32.

  The signature takes a clients-placed float32 plus an unplaced float32 and
  returns a server-placed float32. The second parameter is presumably the
  bitwidth argument of secure sum (TODO confirm).
  """
  param_types = [
      type_factory.at_clients(tf.float32),
      tf.float32,
  ]
  result_type = type_factory.at_server(tf.float32)
  return (intrinsic_defs.FEDERATED_SECURE_SUM,
          computation_types.FunctionType(param_types, result_type))
def test_conflicting_cardinalities_within_call(self):
  """The local executor rejects clients arguments of mismatched lengths."""

  @computations.federated_computation([
      type_factory.at_clients(tf.int32),
      type_factory.at_clients(tf.int32),
  ])
  def comp(x):
    return x

  mismatched_arg = [list(range(5)), list(range(10))]
  factory = executor_stacks.local_executor_factory()
  with executor_test_utils.install_executor(factory):
    with self.assertRaisesRegex(ValueError, 'Conflicting cardinalities'):
      comp(mismatched_arg)
def test_get_size_info(self, num_clients):
  """The sizing executor records one broadcast and two uploads per client.

  Note: this installs `factory` as the process-wide default executor and does
  not restore the previous one. `count_over` and `count_total` are defined
  outside this function.
  """

  @computations.federated_computation(
      type_factory.at_clients(computation_types.SequenceType(tf.float32)),
      type_factory.at_server(tf.float32))
  def comp(temperatures, threshold):
    zipped = intrinsics.federated_zip(
        [temperatures, intrinsics.federated_broadcast(threshold)])
    over_counts = intrinsics.federated_map(count_over, zipped)
    total_counts = intrinsics.federated_map(count_total, temperatures)
    return intrinsics.federated_mean(over_counts, total_counts)

  factory = executor_stacks.sizing_executor_factory(num_clients=num_clients)
  default_executor.set_default_executor(factory)

  to_float = lambda x: tf.cast(x, tf.float32)
  temperatures = [tf.data.Dataset.range(10).map(to_float)] * num_clients
  comp(temperatures, 15.0)

  # Each client receives one tf.float32 and uploads two tf.float32 values.
  broadcast_bits = num_clients * 32
  clients_key = (('CLIENTS', num_clients),)
  expected = (
      {clients_key: [[1, tf.float32]] * num_clients},
      {clients_key: [[1, tf.float32]] * (2 * num_clients)},
      [broadcast_bits],
      [broadcast_bits * 2],
  )
  self.assertEqual(expected, factory.get_size_info())
def test_raises_value_error_with_unexpected_federated_type_clients(self):
  """Creating a clients-placed int32 value on this executor raises."""
  test_executor = create_test_executor()
  raw_values = [10, 20]
  clients_int_type = type_factory.at_clients(tf.int32)
  with self.assertRaises(ValueError):
    self.run_sync(test_executor.create_value(raw_values, clients_int_type))
def _temperature_sensor_example_next_fn():
  """Builds the temperature-sensor federated computation.

  Returns:
    A federated computation over `({float32*}@CLIENTS, float32@SERVER)`. It
    returns the mean, weighted by each client's reading count, of the number
    of readings above the broadcast threshold.
  """

  @computations.tf_computation(computation_types.SequenceType(tf.float32),
                               tf.float32)
  def count_over(ds, t):
    return ds.reduce(
        np.float32(0), lambda acc, x: acc + tf.cast(tf.greater(x, t),
                                                     tf.float32))

  @computations.tf_computation(computation_types.SequenceType(tf.float32))
  def count_total(ds):
    return ds.reduce(np.float32(0.0), lambda acc, _: acc + 1.0)

  @computations.federated_computation(
      type_factory.at_clients(computation_types.SequenceType(tf.float32)),
      type_factory.at_server(tf.float32))
  def comp(temperatures, threshold):
    broadcast_threshold = intrinsics.federated_broadcast(threshold)
    over_counts = intrinsics.federated_map(
        count_over,
        intrinsics.federated_zip([temperatures, broadcast_threshold]))
    total_counts = intrinsics.federated_map(count_total, temperatures)
    return intrinsics.federated_mean(over_counts, total_counts)

  return comp
def test_federated_mean(self):
  """The unweighted federated mean of 1.0 .. 12.0 is 6.5."""

  @computations.federated_computation(type_factory.at_clients(tf.float32))
  def comp(x):
    return intrinsics.federated_mean(x)

  client_values = [float(v) for v in range(1, 13)]
  self.assertEqual(comp(client_values), 6.5)
def test_raises_type_error_with_unembedded_federated_type(self):
  """compute() on a clients value holding raw Python floats raises."""
  unembedded = federating_executor.FederatingExecutorValue(
      [10.0, 11.0, 12.0], type_factory.at_clients(tf.float32))
  with self.assertRaises(TypeError):
    self.run_sync(unembedded.compute())
def create_dummy_intrinsic_def_federated_reduce():
  """Returns `(FEDERATED_REDUCE, type_signature)` for float32.

  The signature takes a clients-placed float32, an unplaced float32
  (presumably the initial accumulator), and a float32 reduction op. It
  returns a server-placed float32.
  """
  param_types = [
      type_factory.at_clients(tf.float32),
      tf.float32,
      type_factory.reduction_op(tf.float32, tf.float32),
  ]
  result_type = type_factory.at_server(tf.float32)
  return (intrinsics_defs_value() if False else intrinsic_defs.FEDERATED_REDUCE,
          computation_types.FunctionType(param_types, result_type))
def test_execution_with_inferred_clients_larger_than_fanout(
    self, executor_factory_fn):
  """Summing 10 inferred clients works when the fanout is capped at 3."""

  @computations.federated_computation(type_factory.at_clients(tf.int32))
  def foo(x):
    return intrinsics.federated_sum(x)

  capped_factory = executor_factory_fn(max_fanout=3)
  with _execution_context(capped_factory):
    self.assertEqual(foo(list(range(1, 11))), 55)
def test_returns_value_with_federated_type_clients_all_equal(self):
  """compute() on an all-equal clients value yields the single member."""
  embedded = [eager_tf_executor.EagerValue(10.0, None, tf.float32)]
  all_equal_type = type_factory.at_clients(tf.float32, all_equal=True)
  federated = federating_executor.FederatingExecutorValue(
      embedded, all_equal_type)
  self.assertEqual(self.run_sync(federated.compute()), 10.0)
def test_with_federated_value_as_a_non_py_list(self, val):
  """A parameterized `val` is accepted as a clients-placed int32 value.

  Per the test name, `val` is presumably a container other than a Python
  `list` (TODO confirm against the parameterization).
  """
  loop, ex = _make_test_runtime(num_clients=4)
  created = loop.run_until_complete(
      ex.create_value(val, type_factory.at_clients(tf.int32)))
  self.assertEqual(str(created.type_signature), '{int32}@CLIENTS')

  computed = loop.run_until_complete(created.compute())
  as_numpy = tf.nest.map_structure(lambda t: t.numpy(), computed)
  self.assertCountEqual(as_numpy, [1, 2, 3, 4])
def test_with_num_clients_larger_than_fanout(self):
  """A federated sum over 10 clients with max fanout 3 equals 55.

  Note: this replaces the process-wide default executor and does not
  restore it afterwards.
  """
  capped_executor = executor_stacks.create_local_executor(max_fanout=3)
  set_default_executor.set_default_executor(capped_executor)

  @computations.federated_computation(type_factory.at_clients(tf.int32))
  def foo(x):
    return intrinsics.federated_sum(x)

  self.assertEqual(foo(list(range(1, 11))), 55)
async def _get_cardinalities(self):
  """Computes `federated_sum(1@CLIENTS)` on each child executor.

  Returns:
    A list aligned with `self._child_executors`. Each entry is the computed
    sum for that child, with eager tensors unwrapped to numpy values.
  """
  one_type = type_factory.at_clients(tf.int32, all_equal=True)
  sum_type = computation_types.FunctionType(
      type_factory.at_clients(tf.int32), type_factory.at_server(tf.int32))
  sum_comp = executor_utils.create_intrinsic_comp(
      intrinsic_defs.FEDERATED_SUM, sum_type)

  async def _child_fn(ex):
    """Embeds the sum and the constant on `ex`, then calls and computes."""
    fn_val, arg_val = await asyncio.gather(
        ex.create_value(sum_comp, sum_type), ex.create_value(1, one_type))
    call_val = await ex.create_call(fn_val, arg_val)
    return await call_val.compute()

  raw_counts = await asyncio.gather(
      *[_child_fn(child) for child in self._child_executors])
  return [
      count.numpy() if isinstance(count, tf.Tensor) else count
      for count in raw_counts
  ]
def test_federated_mean(self):
  """The unweighted mean of 1.0 .. num_clients is (num_clients + 1) / 2.

  The expected value is derived from `num_clients` rather than hard-coded
  for 12 clients. For 12 clients it is still 6.5.
  """

  @computations.federated_computation(type_factory.at_clients(tf.float32))
  def comp(x):
    return intrinsics.federated_mean(x)

  executor, num_clients = _create_test_executor()
  arg = [float(x + 1) for x in range(num_clients)]
  expected = (num_clients + 1) / 2.0
  result = _invoke(executor, comp, arg)
  # Use an approximate comparison: the mean is computed in float32.
  self.assertAlmostEqual(result, expected, places=5)
def test_serialize_deserialize_federated_at_clients(self):
  """A clients-placed int32 list survives a serialize/deserialize round trip."""
  original_values = [10, 20]
  clients_int_type = type_factory.at_clients(tf.int32)

  proto, serialized_type = executor_service_utils.serialize_value(
      original_values, clients_int_type)
  self.assertIsInstance(proto, executor_pb2.Value)
  self.assertEqual(str(serialized_type), '{int32}@CLIENTS')

  round_tripped, round_tripped_type = (
      executor_service_utils.deserialize_value(proto))
  self.assertEqual(str(round_tripped_type), str(clients_int_type))
  self.assertEqual(round_tripped, [10, 20])
async def _compute_intrinsic_federated_zip_at_clients(self, arg):
  """Zips two clients-placed values by delegating to each child executor.

  Args:
    arg: An executor value. Its type signature is a 2-element
      `NamedTupleType` of clients-placed types. Its internal representation
      is a 2-element `AnonymousTuple` of lists, each holding one value per
      child executor.

  Returns:
    A `CompositeValue` with one zipped value per child executor. It is typed
    as a clients-placed 2-tuple whose element names, if any, come from the
    argument's type signature.
  """
  py_typecheck.check_type(arg.type_signature,
                          computation_types.NamedTupleType)
  py_typecheck.check_len(arg.type_signature, 2)
  py_typecheck.check_type(arg.internal_representation,
                          anonymous_tuple.AnonymousTuple)
  py_typecheck.check_len(arg.internal_representation, 2)

  names = [name for name, _ in anonymous_tuple.to_elements(arg.type_signature)]
  child_lists = [arg.internal_representation[0], arg.internal_representation[1]]
  clients_types = [arg.type_signature[0], arg.type_signature[1]]
  num_children = len(self._child_executors)
  for idx in range(2):
    type_utils.check_federated_type(
        clients_types[idx], placement=placement_literals.CLIENTS)
    # Re-wrap as a plain (non-all-equal) clients type over the same member.
    clients_types[idx] = type_factory.at_clients(clients_types[idx].member)
    py_typecheck.check_type(child_lists[idx], list)
    py_typecheck.check_len(child_lists[idx], num_children)

  def _maybe_named(idx, element_type):
    # Keep the element's name from the argument signature when it has one.
    return (names[idx], element_type) if names[idx] else element_type

  item_type = computation_types.NamedTupleType(
      [_maybe_named(idx, clients_types[idx].member) for idx in range(2)])
  result_type = type_factory.at_clients(item_type)
  zip_type = computation_types.FunctionType(
      computation_types.NamedTupleType(
          [_maybe_named(idx, clients_types[idx]) for idx in range(2)]),
      result_type)
  zip_comp = executor_utils.create_intrinsic_comp(
      intrinsic_defs.FEDERATED_ZIP_AT_CLIENTS, zip_type)

  async def _zip_on_child(child, first, second):
    """Invokes FEDERATED_ZIP_AT_CLIENTS(first, second) on `child`."""
    py_typecheck.check_type(first, executor_value_base.ExecutorValue)
    py_typecheck.check_type(second, executor_value_base.ExecutorValue)
    zip_fn = await child.create_value(zip_comp, zip_type)
    zip_arg = await child.create_tuple(
        anonymous_tuple.AnonymousTuple([(names[0], first),
                                        (names[1], second)]))
    return await child.create_call(zip_fn, zip_arg)

  per_child = await asyncio.gather(*[
      _zip_on_child(child, first, second) for child, first, second in zip(
          self._child_executors, child_lists[0], child_lists[1])
  ])
  return CompositeValue(per_child, result_type)
def _federated_weighted_mean(self, arg):
  """Computes a federated weighted mean as a sum of weight-scaled values.

  Each client's value is multiplied by `weight / total_weight` and the scaled
  values are then summed with `self._federated_sum`.

  Args:
    arg: A computed value whose type signature passes
      `check_valid_federated_weighted_mean_argument_tuple_type`.
      `arg.value[0]` holds the per-client values and `arg.value[1]` the
      per-client weights.

  Returns:
    Whatever `self._federated_sum` returns for the scaled clients values.
  """
  type_utils.check_valid_federated_weighted_mean_argument_tuple_type(
      arg.type_signature)
  member_type = arg.type_signature[0].member
  client_values = arg.value[0]
  client_weights = arg.value[1]
  total_weight = sum(client_weights)
  scaled_values = []
  for value, weight in zip(client_values, client_weights):
    scaled = multiply_by_scalar(
        ComputedValue(value, member_type), weight / total_weight)
    scaled_values.append(scaled.value)
  return self._federated_sum(
      ComputedValue(scaled_values, type_factory.at_clients(member_type)))
def test_execution_with_inferred_clients_larger_than_fanout(
    self, executor_factory_fn):
  """Summing 10 clients through an executor with max fanout 3 yields 55."""

  @computations.federated_computation(type_factory.at_clients(tf.int32))
  def foo(x):
    return intrinsics.federated_sum(x)

  capped_executor = executor_factory_fn(max_fanout=3)
  with executor_test_utils.install_executor(capped_executor):
    total = foo(list(range(1, 11)))
  self.assertEqual(total, 55)