def _federated_zip_at_clients(self, arg):
  """Zips several CLIENTS-placed values into one CLIENTS-placed tuple value.

  Args:
    arg: A value whose `type_signature` is a `computation_types.NamedTupleType`
      of federated types placed at `placements.CLIENTS` with all_equal of
      `False`, and whose `value` is an `anonymous_tuple.AnonymousTuple` whose
      elements are per-client Python lists.

  Returns:
    A `ComputedValue` holding a list with one entry per zipped group of
    client values (each built with `anonymous_tuple.from_container`), typed
    as a CLIENTS-placed `NamedTupleType` of the element member types.
  """
  py_typecheck.check_type(arg.type_signature, computation_types.NamedTupleType)
  py_typecheck.check_type(arg.value, anonymous_tuple.AnonymousTuple)
  per_client_lists = []
  member_types = []
  element_count = len(arg.type_signature)
  for position in range(element_count):
    client_list = arg.value[position]
    py_typecheck.check_type(client_list, list)
    per_client_lists.append(client_list)
    element_type = arg.type_signature[position]
    type_utils.check_federated_type(element_type, None, placements.CLIENTS,
                                    False)
    member_types.append(element_type.member)
  # Transpose: one tuple per client, gathering that client's entry from
  # every element list.
  client_tuples = [
      anonymous_tuple.from_container(group)
      for group in zip(*per_client_lists)
  ]
  result_type = type_factory.at_clients(
      computation_types.NamedTupleType(member_types))
  return ComputedValue(client_tuples, result_type)
async def _compute_intrinsic_federated_zip_at_server(self, arg):
  """Zips a pair of SERVER-placed, all-equal values into one SERVER value.

  Args:
    arg: A value whose `type_signature` is a two-element
      `computation_types.NamedTupleType` of federated types placed at
      `placement_literals.SERVER` with all_equal of `True`, and whose
      `internal_representation` is a two-element
      `anonymous_tuple.AnonymousTuple`.

  Returns:
    A `CompositeValue` wrapping the tuple created in the parent executor,
    typed as a SERVER-placed `NamedTupleType` of the two member types.
  """
  py_typecheck.check_type(arg.type_signature, computation_types.NamedTupleType)
  py_typecheck.check_len(arg.type_signature, 2)
  pair = arg.internal_representation
  py_typecheck.check_type(pair, anonymous_tuple.AnonymousTuple)
  py_typecheck.check_len(pair, 2)
  member_types = []
  for element_type in (arg.type_signature[0], arg.type_signature[1]):
    type_utils.check_federated_type(
        element_type, placement=placement_literals.SERVER, all_equal=True)
    member_types.append(element_type.member)
  server_tuple = await self._parent_executor.create_tuple([pair[0], pair[1]])
  return CompositeValue(
      server_tuple,
      type_factory.at_server(computation_types.NamedTupleType(member_types)))
async def _compute_intrinsic_federated_collect(self, arg):
  """Collects CLIENTS-placed values into a single sequence at the server.

  Args:
    arg: A value with a `computation_types.FederatedType` placed at
      `placement_literals.CLIENTS`, whose `internal_representation` is a list
      of values that each expose an awaitable `compute()`.

  Returns:
    A `FederatingExecutorValue` wrapping a one-element list with the sequence
    created on the first SERVER target executor, typed as an all-equal
    SERVER-placed `SequenceType` of the original member type.
  """
  py_typecheck.check_type(arg.type_signature, computation_types.FederatedType)
  type_utils.check_federated_type(
      arg.type_signature, placement=placement_literals.CLIENTS)
  client_values = arg.internal_representation
  py_typecheck.check_type(client_values, list)
  server_child = self._target_executors[placement_literals.SERVER][0]
  computed_items = await asyncio.gather(
      *[client_value.compute() for client_value in client_values])
  sequence_type = computation_types.SequenceType(arg.type_signature.member)
  collected = await server_child.create_value(computed_items, sequence_type)
  result_type = computation_types.FederatedType(
      sequence_type, placement_literals.SERVER, all_equal=True)
  return FederatingExecutorValue([collected], result_type)
async def _compute_intrinsic_federated_apply(self, arg):
  """Applies a function to a SERVER-placed value in the parent executor.

  Args:
    arg: A value whose `internal_representation` is a two-element
      `anonymous_tuple.AnonymousTuple` of (`pb.Computation`,
      `executor_value_base.ExecutorValue`), and whose `type_signature`
      elements are a `computation_types.FunctionType` and an all-equal
      SERVER-placed federated type matching the function's parameter.

  Returns:
    A `CompositeValue` wrapping the call result from the parent executor,
    typed as the function's result placed at the server.
  """
  fn_and_value = arg.internal_representation
  py_typecheck.check_type(fn_and_value, anonymous_tuple.AnonymousTuple)
  py_typecheck.check_len(fn_and_value, 2)
  fn_type = arg.type_signature[0]
  py_typecheck.check_type(fn_type, computation_types.FunctionType)
  value_type = arg.type_signature[1]
  type_utils.check_federated_type(
      value_type, fn_type.parameter, placement_literals.SERVER,
      all_equal=True)
  fn_proto, server_value = fn_and_value[0], fn_and_value[1]
  py_typecheck.check_type(fn_proto, pb.Computation)
  py_typecheck.check_type(server_value, executor_value_base.ExecutorValue)
  embedded_fn = await self._parent_executor.create_value(fn_proto, fn_type)
  applied = await self._parent_executor.create_call(embedded_fn, server_value)
  return CompositeValue(applied, type_factory.at_server(fn_type.result))
def _federated_reduce(self, arg):
  """Sequentially reduces CLIENTS-placed values and places the total at SERVER.

  Args:
    arg: A value whose `type_signature` is a `NamedTupleType` of (federated
      CLIENTS-placed non-all-equal type, zero type, reduction
      `FunctionType`), and whose `value` holds the per-client values, the zero
      value, and a callable that takes a `ComputedValue` pair and returns a
      `ComputedValue`.

  Returns:
    Whatever `self._federated_value_at_server` returns for the final total.
  """
  py_typecheck.check_type(arg.type_signature, computation_types.NamedTupleType)
  federated_type = arg.type_signature[0]
  type_utils.check_federated_type(federated_type, None, placements.CLIENTS,
                                  False)
  zero_type = arg.type_signature[1]
  op_type = arg.type_signature[2]
  py_typecheck.check_type(op_type, computation_types.FunctionType)
  # The reduction operator must accept an (accumulator, member) pair.
  type_utils.check_assignable_from(op_type.parameter,
                                   [zero_type, federated_type.member])
  accumulator = ComputedValue(arg.value[1], zero_type)
  reduce_fn = arg.value[2]
  pair_type = op_type.parameter
  for client_value in arg.value[0]:
    step_input = anonymous_tuple.AnonymousTuple([(None, accumulator.value),
                                                 (None, client_value)])
    accumulator = reduce_fn(ComputedValue(step_input, pair_type))
  return self._federated_value_at_server(accumulator)
async def _compute_intrinsic_federated_zip_at_clients(self, arg):
  """Zips a pair of CLIENTS-placed values by zipping on every child executor.

  Each element must be CLIENTS-placed, with an internal representation that
  is a list holding one entry per child executor. The FEDERATED_ZIP_AT_CLIENTS
  intrinsic is then invoked on each child (scheduled with `asyncio.gather`)
  with that child's pair of entries. Element names from `arg.type_signature`
  are preserved in the resulting types.

  Args:
    arg: A value whose `type_signature` is a two-element `NamedTupleType` and
      whose `internal_representation` is a two-element `AnonymousTuple`.

  Returns:
    A `CompositeValue` of the per-child results, typed as a CLIENTS-placed
    `NamedTupleType` of the two member types.
  """
  py_typecheck.check_type(arg.type_signature, computation_types.NamedTupleType)
  py_typecheck.check_len(arg.type_signature, 2)
  py_typecheck.check_type(arg.internal_representation,
                          anonymous_tuple.AnonymousTuple)
  py_typecheck.check_len(arg.internal_representation, 2)
  names = [name for name, _ in anonymous_tuple.to_elements(arg.type_signature)]
  per_child = [arg.internal_representation[i] for i in (0, 1)]
  normalized = []
  for i in (0, 1):
    element_type = arg.type_signature[i]
    type_utils.check_federated_type(
        element_type, placement=placement_literals.CLIENTS)
    # Rebuild the type from its member so both elements use at_clients(...).
    normalized.append(type_factory.at_clients(element_type.member))
    py_typecheck.check_type(per_child[i], list)
    py_typecheck.check_len(per_child[i], len(self._child_executors))

  def _maybe_named(name, value):
    # Unnamed elements stay positional; named ones keep their name.
    return (name, value) if name else value

  item_type = computation_types.NamedTupleType(
      [_maybe_named(names[i], normalized[i].member) for i in (0, 1)])
  result_type = type_factory.at_clients(item_type)
  zip_type = computation_types.FunctionType(
      computation_types.NamedTupleType(
          [_maybe_named(names[i], normalized[i]) for i in (0, 1)]),
      result_type)
  zip_comp = executor_utils.create_intrinsic_comp(
      intrinsic_defs.FEDERATED_ZIP_AT_CLIENTS, zip_type)

  async def _zip_on_child(child, first, second):
    py_typecheck.check_type(first, executor_value_base.ExecutorValue)
    py_typecheck.check_type(second, executor_value_base.ExecutorValue)
    embedded_zip = await child.create_value(zip_comp, zip_type)
    embedded_pair = await child.create_tuple(
        anonymous_tuple.AnonymousTuple([(names[0], first),
                                        (names[1], second)]))
    return await child.create_call(embedded_zip, embedded_pair)

  zipped = await asyncio.gather(*[
      _zip_on_child(child, first, second) for child, first, second in zip(
          self._child_executors, per_child[0], per_child[1])
  ])
  return CompositeValue(zipped, result_type)
def test_check_federated_type(self):
  """Checks member/placement/all_equal matching in check_federated_type."""
  type_spec = computation_types.FederatedType(tf.int32, placements.CLIENTS,
                                              False)
  # (member, placement, all_equal); `None` means "do not check this field".
  accepted = [
      (tf.int32, placements.CLIENTS, False),
      (tf.int32, None, None),
      (None, placements.CLIENTS, None),
      (None, None, False),
  ]
  for member, placement, all_equal in accepted:
    type_utils.check_federated_type(type_spec, member, placement, all_equal)
  rejected = [
      (tf.bool, None, None),
      (None, placements.SERVER, None),
      (None, None, True),
  ]
  for member, placement, all_equal in rejected:
    with self.assertRaises(TypeError):
      type_utils.check_federated_type(type_spec, member, placement, all_equal)
def create_federated_map(fn, arg):
  r"""Creates a called federated map.

            Call
           /    \
  Intrinsic      Tuple
                 |
                 [Comp, Comp]

  Args:
    fn: A functional `computation_building_blocks.ComputationBuildingBlock` to
      use as the function.
    arg: A `computation_building_blocks.ComputationBuildingBlock` to use as the
      argument. Its type must be a federated type placed at `CLIENTS`.

  Returns:
    A `computation_building_blocks.Call`.

  Raises:
    TypeError: If any of the types do not match, including when `arg` is not
      placed at `CLIENTS`.
  """
  py_typecheck.check_type(fn,
                          computation_building_blocks.ComputationBuildingBlock)
  py_typecheck.check_type(fn.type_signature, computation_types.FunctionType)
  py_typecheck.check_type(arg,
                          computation_building_blocks.ComputationBuildingBlock)
  # The intrinsic type built below is always CLIENTS-placed, so reject any
  # other placement here with a clear error rather than building a
  # mismatched intrinsic.
  type_utils.check_federated_type(
      arg.type_signature, placement=placement_literals.CLIENTS)
  parameter_type = computation_types.FederatedType(
      fn.type_signature.parameter, placement_literals.CLIENTS, False)
  result_type = computation_types.FederatedType(fn.type_signature.result,
                                                placement_literals.CLIENTS,
                                                False)
  intrinsic_type = computation_types.FunctionType(
      (fn.type_signature, parameter_type), result_type)
  intrinsic = computation_building_blocks.Intrinsic(
      intrinsic_defs.FEDERATED_MAP.uri, intrinsic_type)
  tup = computation_building_blocks.Tuple((fn, arg))
  return computation_building_blocks.Call(intrinsic, tup)
async def compute_intrinsic_federated_broadcast(
    executor: executor_base.Executor, arg: executor_value_base.ExecutorValue
) -> executor_value_base.ExecutorValue:
  """Broadcasts a server-placed value to the clients of `executor`.

  Args:
    executor: The executor in which to embed the broadcast result.
    arg: The value to broadcast. It must be embedded in `executor` and have a
      federated type placed at `tff.SERVER` with all_equal of `True`.

  Returns:
    The computed payload of `arg`, re-embedded in `executor` with an
    all-equal CLIENTS-placed federated type of the same member type.

  Raises:
    TypeError: If the arguments are of the wrong types.
  """
  py_typecheck.check_type(executor, executor_base.Executor)
  py_typecheck.check_type(arg, executor_value_base.ExecutorValue)
  type_utils.check_federated_type(
      arg.type_signature, placement=placement_literals.SERVER, all_equal=True)
  payload = await arg.compute()
  broadcast_type = computation_types.FederatedType(
      arg.type_signature.member, placement_literals.CLIENTS, all_equal=True)
  embedded = await executor.create_value(payload, broadcast_type)
  return embedded
def fit_argument(arg, type_spec, context):
  """Fits the given argument `arg` to match the given parameter `type_spec`.

  Args:
    arg: The argument to fit, an instance of `ComputedValue`.
    type_spec: The type of the parameter to fit to, an instance of `tff.Type`
      or something convertible to it.
    context: The context in which to perform the fitting, either an instance
      of `ComputationContext`, or `None` if unspecified.

  Returns:
    An instance of `ComputedValue` with the payload from `arg`, but matching
    the `type_spec` in the given context.

  Raises:
    TypeError: If the types mismatch.
    ValueError: If the value is invalid or does not fit the requested type,
      including when an all-equal federated value must be expanded into a
      per-member list but `context` is `None`, so no cardinality is known.
  """
  py_typecheck.check_type(arg, ComputedValue)
  type_spec = computation_types.to_type(type_spec)
  py_typecheck.check_type(type_spec, computation_types.Type)
  if context is not None:
    py_typecheck.check_type(context, ComputationContext)
  type_utils.check_assignable_from(type_spec, arg.type_signature)
  if arg.type_signature == type_spec:
    return arg
  elif isinstance(type_spec, computation_types.NamedTupleType):
    py_typecheck.check_type(arg.value, anonymous_tuple.AnonymousTuple)
    result_elements = []
    for idx, (elem_name, elem_type) in enumerate(
        anonymous_tuple.to_elements(type_spec)):
      elem_val = ComputedValue(arg.value[idx], arg.type_signature[idx])
      # Recurse only when this element's type differs from the target
      # element type; compare types, not the value object against a type.
      if elem_val.type_signature != elem_type:
        elem_val = fit_argument(elem_val, elem_type, context)
      result_elements.append((elem_name, elem_val.value))
    return ComputedValue(
        anonymous_tuple.AnonymousTuple(result_elements), type_spec)
  elif isinstance(type_spec, computation_types.FederatedType):
    type_utils.check_federated_type(
        arg.type_signature, placement=type_spec.placement)
    if arg.type_signature.all_equal:
      member_val = ComputedValue(arg.value, arg.type_signature.member)
      if type_spec.member != arg.type_signature.member:
        member_val = fit_argument(member_val, type_spec.member, context)
      if type_spec.all_equal:
        return ComputedValue(member_val.value, type_spec)
      # Expanding an all-equal value into one entry per member requires the
      # placement's cardinality, which only the context can supply.
      if context is None:
        raise ValueError(
            'Cannot fit all-equal {} into non-all-equal {} without a '
            'context to supply the cardinality.'.format(
                arg.type_signature, type_spec))
      cardinality = context.get_cardinality(type_spec.placement)
      return ComputedValue([member_val.value for _ in range(cardinality)],
                           type_spec)
    elif type_spec.all_equal:
      raise TypeError('Cannot fit a non all-equal {} into all-equal {}.'.format(
          arg.type_signature, type_spec))
    else:
      py_typecheck.check_type(arg.value, list)

      def _fit_member_val(x):
        x_val = ComputedValue(x, arg.type_signature.member)
        return fit_argument(x_val, type_spec.member, context).value

      return ComputedValue([_fit_member_val(x) for x in arg.value], type_spec)
  else:
    # TODO(b/113123634): Possibly add more conversions, e.g., for tensor types.
    return arg