def federated_aggregate(self, value, zero, accumulate, merge, report):
    """Implements `federated_aggregate` as defined in `api/intrinsics.py`.

    In this variant the accumulator type is taken to be the type of `zero`:
    `accumulate` must be assignable to `(<zero_type, member_type> -> zero_type)`,
    `merge` to `(<zero_type, zero_type> -> zero_type)`, and `report` to
    `(zero_type -> report_result_type)`.

    Args:
      value: As in `api/intrinsics.py`.
      zero: As in `api/intrinsics.py`.
      accumulate: As in `api/intrinsics.py`.
      merge: As in `api/intrinsics.py`.
      report: As in `api/intrinsics.py`.

    Returns:
      As in `api/intrinsics.py`.

    Raises:
      TypeError: As in `api/intrinsics.py`.
    """
    stack = self._context_stack

    value = value_impl.to_value(value, None, stack)
    value_utils.check_federated_value_placement(value, placements.CLIENTS,
                                                'value to be aggregated')

    zero = value_impl.to_value(zero, None, stack)
    py_typecheck.check_type(zero, value_base.Value)
    # TODO(b/113112108): We need a check here that zero does not have federated
    # constituents.

    # Convert the three operators in their declared order.
    accumulate, merge, report = (
        value_impl.to_value(fn, None, stack)
        for fn in (accumulate, merge, report))
    for fn in (accumulate, merge, report):
        py_typecheck.check_type(fn, value_base.Value)
        py_typecheck.check_type(fn.type_signature,
                                computation_types.FunctionType)

    accumulator_type = zero.type_signature
    expectations = (
        ('accumulate', accumulate,
         type_factory.reduction_op(accumulator_type,
                                   value.type_signature.member)),
        ('merge', merge,
         type_factory.reduction_op(accumulator_type, accumulator_type)),
        ('report', report,
         computation_types.FunctionType(accumulator_type,
                                        report.type_signature.result)),
    )
    for fn_name, fn, expected_type in expectations:
        actual_type = fn.type_signature
        if not type_utils.is_assignable_from(expected_type, actual_type):
            raise TypeError(
                'Expected parameter `{}` to be of type {}, but received {} instead.'
                .format(fn_name, expected_type, actual_type))

    comp = building_block_factory.create_federated_aggregate(
        *(value_impl.ValueImpl.get_comp(v)
          for v in (value, zero, accumulate, merge, report)))
    return value_impl.ValueImpl(comp, stack)
def federated_aggregate(self, value, zero, accumulate, merge, report):
    """Implements `federated_aggregate` as defined in `api/intrinsics.py`.

    The accumulator type is inferred from `accumulate`:
      * `zero` must be assignable to the first element of `accumulate`'s
        parameter tuple;
      * `accumulate` must be assignable to
        `(<accumulator, member> -> accumulator)`, where `accumulator` is
        `accumulate`'s result type;
      * `merge` must be assignable to
        `(<accumulator, accumulator> -> accumulator)`;
      * `report` must accept `merge`'s result type.

    Args:
      value: As in `api/intrinsics.py`. Passed through
        `value_utils.ensure_federated_value` with `placements.CLIENTS`.
      zero: As in `api/intrinsics.py`.
      accumulate: As in `api/intrinsics.py`.
      merge: As in `api/intrinsics.py`.
      report: As in `api/intrinsics.py`.

    Returns:
      As in `api/intrinsics.py`.

    Raises:
      TypeError: As in `api/intrinsics.py`, including when `accumulate` does
        not take a two-element tuple parameter.
    """
    value = value_impl.to_value(value, None, self._context_stack)
    value = value_utils.ensure_federated_value(value, placements.CLIENTS,
                                               'value to be aggregated')

    zero = value_impl.to_value(zero, None, self._context_stack)
    py_typecheck.check_type(zero, value_base.Value)
    accumulate = value_impl.to_value(accumulate, None, self._context_stack)
    merge = value_impl.to_value(merge, None, self._context_stack)
    report = value_impl.to_value(report, None, self._context_stack)
    for op in [accumulate, merge, report]:
        py_typecheck.check_type(op, value_base.Value)
        py_typecheck.check_type(op.type_signature, computation_types.FunctionType)

    # `accumulate` is expected to be a binary reduction operator taking
    # `<accumulator, member>`. Validate the parameter's shape before indexing
    # into it, so that a malformed operator is reported with a descriptive
    # TypeError instead of failing inside the subscript below.
    accumulate_param_type = accumulate.type_signature.parameter
    if (not isinstance(accumulate_param_type, computation_types.NamedTupleType)
            or len(accumulate_param_type) != 2):
        raise TypeError(
            'Expected parameter `accumulate` to take a two-element tuple of '
            '<accumulator, member>, but received {} instead.'.format(
                accumulate.type_signature))

    if not type_utils.is_assignable_from(accumulate_param_type[0],
                                         zero.type_signature):
        raise TypeError('Expected `zero` to be assignable to type {}, '
                        'but was of incompatible type {}.'.format(
                            accumulate_param_type[0], zero.type_signature))

    # The accumulator type is whatever `accumulate` produces; `merge` combines
    # two accumulators and `report` consumes the merged result.
    accumulator_type = accumulate.type_signature.result
    accumulate_type_expected = type_factory.reduction_op(
        accumulator_type, value.type_signature.member)
    merge_type_expected = type_factory.reduction_op(accumulator_type,
                                                    accumulator_type)
    report_type_expected = computation_types.FunctionType(
        merge.type_signature.result, report.type_signature.result)
    for op_name, op, type_expected in [
        ('accumulate', accumulate, accumulate_type_expected),
        ('merge', merge, merge_type_expected),
        ('report', report, report_type_expected),
    ]:
        if not type_utils.is_assignable_from(type_expected, op.type_signature):
            raise TypeError(
                'Expected parameter `{}` to be of type {}, but received {} instead.'
                .format(op_name, type_expected, op.type_signature))

    value = value_impl.ValueImpl.get_comp(value)
    zero = value_impl.ValueImpl.get_comp(zero)
    accumulate = value_impl.ValueImpl.get_comp(accumulate)
    merge = value_impl.ValueImpl.get_comp(merge)
    report = value_impl.ValueImpl.get_comp(report)
    comp = building_block_factory.create_federated_aggregate(
        value, zero, accumulate, merge, report)
    return value_impl.ValueImpl(comp, self._context_stack)
def federated_reduce(self, value, zero, op):
    """Implements `federated_reduce` as defined in `api/intrinsics.py`.

    `value` is passed through `value_utils.ensure_federated_value` with
    `placements.CLIENTS`. `op` must be assignable to
    `(<zero_type, member_type> -> zero_type)`, where `zero_type` is the type
    of `zero` and `member_type` the member type of `value`.

    Raises:
      TypeError: If `op` is not assignable to the expected reduction type.
    """
    # TODO(b/113112108): Since in most cases, it can be assumed that CLIENTS is
    # a non-empty collective (or else, the computation fails), specifying zero
    # at this level of the API should probably be optional. TBD.
    stack = self._context_stack

    clients_value = value_utils.ensure_federated_value(
        value_impl.to_value(value, None, stack), placements.CLIENTS,
        'value to be reduced')

    zero_value = value_impl.to_value(zero, None, stack)
    py_typecheck.check_type(zero_value, value_base.Value)
    # TODO(b/113112108): We need a check here that zero does not have federated
    # constituents.

    op_value = value_impl.to_value(op, None, stack)
    py_typecheck.check_type(op_value, value_base.Value)
    py_typecheck.check_type(op_value.type_signature,
                            computation_types.FunctionType)

    expected_op_type = type_factory.reduction_op(
        zero_value.type_signature, clients_value.type_signature.member)
    actual_op_type = op_value.type_signature
    if not type_utils.is_assignable_from(expected_op_type, actual_op_type):
        raise TypeError('Expected an operator of type {}, got {}.'.format(
            expected_op_type, actual_op_type))

    comp = building_block_factory.create_federated_reduce(
        value_impl.ValueImpl.get_comp(clients_value),
        value_impl.ValueImpl.get_comp(zero_value),
        value_impl.ValueImpl.get_comp(op_value))
    return value_impl.ValueImpl(comp, stack)
def parse_federated_aggregate_argument_types(type_spec):
    """Validates the argument type of `federated_aggregate` and splits it up.

    The expected layout is the 5-tuple
    `<value, zero, accumulate, merge, report>`, where `value` is a federated
    type, `accumulate` is equivalent to `reduction_op(zero, value.member)`,
    `merge` is equivalent to `binary_op(zero)`, and `report` is a function
    whose parameter is equivalent to `zero`.

    Args:
      type_spec: An instance of `computation_types.NamedTupleType`.

    Returns:
      A tuple of (value_type, zero_type, accumulate_type, merge_type,
      report_type) for the 5 type constituents.

    Raises:
      Whatever the `py_typecheck` / `type_utils` checks raise when `type_spec`
      does not have the layout described above.
    """
    py_typecheck.check_type(type_spec, computation_types.NamedTupleType)
    py_typecheck.check_len(type_spec, 5)
    (value_type, zero_type, accumulate_type, merge_type,
     report_type) = [type_spec[index] for index in range(5)]

    py_typecheck.check_type(value_type, computation_types.FederatedType)
    type_utils.check_equivalent_types(
        accumulate_type, type_factory.reduction_op(zero_type, value_type.member))
    type_utils.check_equivalent_types(merge_type,
                                      type_factory.binary_op(zero_type))
    py_typecheck.check_type(report_type, computation_types.FunctionType)
    type_utils.check_equivalent_types(report_type.parameter, zero_type)

    return value_type, zero_type, accumulate_type, merge_type, report_type
def sequence_reduce(self, value, zero, op):
    """Implements `sequence_reduce` as defined in `api/intrinsics.py`.

    `value` may be a plain sequence, or a federated value whose member is a
    sequence. In the federated case, a lambda applying the `SEQUENCE_REDUCE`
    intrinsic is built and applied via `federated_apply` (SERVER placement)
    or `federated_map` (CLIENTS placement).

    Args:
      value: As in `api/intrinsics.py`.
      zero: As in `api/intrinsics.py`.
      op: As in `api/intrinsics.py`.

    Returns:
      As in `api/intrinsics.py`. In the non-federated branch the building
      block is wrapped in `value_impl.ValueImpl`, as the other intrinsics here
      (e.g. `federated_reduce`) do.

    Raises:
      TypeError: As in `api/intrinsics.py`, or if a federated `value` has a
        placement other than SERVER or CLIENTS.
    """
    value = value_impl.to_value(value, None, self._context_stack)
    zero = value_impl.to_value(zero, None, self._context_stack)
    op = value_impl.to_value(op, None, self._context_stack)
    if isinstance(value.type_signature, computation_types.SequenceType):
        element_type = value.type_signature.element
    else:
        py_typecheck.check_type(value.type_signature,
                                computation_types.FederatedType)
        py_typecheck.check_type(value.type_signature.member,
                                computation_types.SequenceType)
        element_type = value.type_signature.member.element
    op_type_expected = type_factory.reduction_op(zero.type_signature,
                                                 element_type)
    if not type_utils.is_assignable_from(op_type_expected, op.type_signature):
        raise TypeError('Expected an operator of type {}, got {}.'.format(
            op_type_expected, op.type_signature))

    value = value_impl.ValueImpl.get_comp(value)
    zero = value_impl.ValueImpl.get_comp(zero)
    op = value_impl.ValueImpl.get_comp(op)

    if isinstance(value.type_signature, computation_types.SequenceType):
        # Wrap the building block so callers get a `ValueImpl` here, just as
        # they do from `federated_reduce` / `federated_aggregate`, rather than
        # a raw building block that only this branch would return.
        comp = building_block_factory.create_sequence_reduce(value, zero, op)
        return value_impl.ValueImpl(comp, self._context_stack)

    # Federated case: build `(arg -> SEQUENCE_REDUCE(<arg, zero, op>))` over
    # the member sequence type and apply it at the value's placement.
    value_type = computation_types.SequenceType(element_type)
    intrinsic_type = computation_types.FunctionType((
        value_type,
        zero.type_signature,
        op.type_signature,
    ), op.type_signature.result)
    intrinsic = building_blocks.Intrinsic(intrinsic_defs.SEQUENCE_REDUCE.uri,
                                          intrinsic_type)
    ref = building_blocks.Reference('arg', value_type)
    tup = building_blocks.Tuple((ref, zero, op))
    call = building_blocks.Call(intrinsic, tup)
    fn = building_blocks.Lambda(ref.name, ref.type_signature, call)
    fn_impl = value_impl.ValueImpl(fn, self._context_stack)
    if value.type_signature.placement is placements.SERVER:
        return self.federated_apply(fn_impl, value)
    elif value.type_signature.placement is placements.CLIENTS:
        return self.federated_map(fn_impl, value)
    else:
        raise TypeError('Unsupported placement {}.'.format(
            value.type_signature.placement))
def create_dummy_intrinsic_def_federated_reduce():
    """Returns `FEDERATED_REDUCE` paired with a concrete float32 signature.

    The signature is a function from the parameter list
    `[float32 at clients, float32, (<float32,float32> -> float32)]` to a
    float32 at the server.

    Returns:
      A `(intrinsic_def, type_signature)` tuple.
    """
    float_at_clients = type_factory.at_clients(tf.float32)
    float_reduction = type_factory.reduction_op(tf.float32, tf.float32)
    float_at_server = type_factory.at_server(tf.float32)
    signature = computation_types.FunctionType(
        [float_at_clients, tf.float32, float_reduction], float_at_server)
    return intrinsic_defs.FEDERATED_REDUCE, signature
async def _compute_intrinsic_federated_aggregate(self, arg):
    """Evaluates `federated_aggregate` as a reduce followed by an apply.

    `arg` must carry a 5-element `<value, zero, accumulate, merge, report>`
    tuple. The types are validated, then `<value, zero, accumulate>` is handed
    to `_compute_intrinsic_federated_reduce`. `report` is applied to that
    result through `_compute_intrinsic_federated_apply`. `merge` is
    type-checked but not otherwise used by this implementation.

    Raises:
      ValueError: If the argument tuple does not have exactly 5 elements.
    """
    signature = arg.type_signature
    members = arg.internal_representation
    py_typecheck.check_type(signature, computation_types.NamedTupleType)
    py_typecheck.check_type(members, anonymous_tuple.AnonymousTuple)
    member_count = len(members)
    if member_count != 5:
        raise ValueError(
            'Expected 5 elements in the `federated_aggregate()` argument tuple, '
            'found {}.'.format(member_count))

    value_type = signature[0]
    py_typecheck.check_type(value_type, computation_types.FederatedType)
    zero_type = signature[1]
    accumulate_type = signature[2]
    type_utils.check_equivalent_types(
        accumulate_type,
        type_factory.reduction_op(zero_type, value_type.member))
    merge_type = signature[3]
    type_utils.check_equivalent_types(merge_type,
                                      type_factory.binary_op(zero_type))
    report_type = signature[4]
    py_typecheck.check_type(report_type, computation_types.FunctionType)
    type_utils.check_equivalent_types(report_type.parameter, zero_type)

    # NOTE: This is a simple initial implementation that forwards to
    # `federated_reduce()`. A more complete implementation could exploit the
    # parallelism afforded by `merge` to reduce the cost from linear (with
    # respect to the number of clients) to sub-linear.
    # TODO(b/134543154): Expand this implementation to take advantage of the
    # parallelism afforded by `merge`.
    reduce_arg = FederatedExecutorValue(
        anonymous_tuple.AnonymousTuple(
            [(None, members[index]) for index in range(3)]),
        computation_types.NamedTupleType(
            [value_type, zero_type, accumulate_type]))
    pre_report = await self._compute_intrinsic_federated_reduce(reduce_arg)

    pre_report_type = pre_report.type_signature
    py_typecheck.check_type(pre_report_type, computation_types.FederatedType)
    type_utils.check_equivalent_types(pre_report_type.member,
                                      report_type.parameter)

    apply_arg = FederatedExecutorValue(
        anonymous_tuple.AnonymousTuple([
            (None, members[4]),
            (None, pre_report.internal_representation),
        ]),
        computation_types.NamedTupleType([report_type, pre_report_type]))
    return await self._compute_intrinsic_federated_apply(apply_arg)
async def _compute_intrinsic_federated_reduce(self, arg):
    """Evaluates `federated_reduce` sequentially on the SERVER executor.

    `arg` must carry a 3-element `<value, zero, op>` tuple whose first
    element's internal representation is a list of per-client values.
    Each client value is computed and recreated on the first executor in
    `self._target_executors[placement_literals.SERVER]`. `op` is then folded
    over them left to right, starting from `zero`.

    Returns:
      A `FederatedExecutorValue` wrapping a one-element list with the reduced
      value, typed as an all-equal SERVER-placed federated value.

    Raises:
      ValueError: If the argument tuple does not have exactly 3 elements.
    """
    signature = arg.type_signature
    members = arg.internal_representation
    py_typecheck.check_type(signature, computation_types.NamedTupleType)
    py_typecheck.check_type(members, anonymous_tuple.AnonymousTuple)
    if len(members) != 3:
        raise ValueError(
            'Expected 3 elements in the `federated_reduce()` argument tuple, '
            'found {}.'.format(len(members)))

    value_type = signature[0]
    py_typecheck.check_type(value_type, computation_types.FederatedType)
    member_type = value_type.member
    zero_type = signature[1]
    op_type = signature[2]
    type_utils.check_equivalent_types(
        op_type, type_factory.reduction_op(zero_type, member_type))

    client_values = members[0]
    py_typecheck.check_type(client_values, list)

    server = self._target_executors[placement_literals.SERVER][0]

    async def _to_server(client_value):
        # Materialize the client value, then re-create it on the server.
        payload = await client_value.compute()
        return await server.create_value(payload, member_type)

    on_server = await asyncio.gather(
        *(_to_server(client_value) for client_value in client_values))

    zero_selection = await self.create_selection(arg, index=1)
    accumulator = await server.create_value(await zero_selection.compute(),
                                            zero_type)
    reducer = await server.create_value(members[2], op_type)

    # Left fold: accumulator = op(<accumulator, item>) for each item in order.
    for item in on_server:
        pair = await server.create_tuple(
            anonymous_tuple.AnonymousTuple([(None, accumulator), (None, item)]))
        accumulator = await server.create_call(reducer, pair)

    result_type = computation_types.FederatedType(
        accumulator.type_signature, placement_literals.SERVER, all_equal=True)
    return FederatedExecutorValue([accumulator], result_type)
# @federated_computation # def federated_aggregate(x, zero, accumulate, merge, report): # a = generic_partial_reduce(x, zero, accumulate, INTERMEDIATE_AGGREGATORS) # b = generic_reduce(a, zero, merge, SERVER) # c = generic_map(report, b) # return c # # Actual implementations might vary. # # Type signature: <{T}@CLIENTS,U,(<U,T>->U),(<U,U>->U),(U->R)> -> R@SERVER FEDERATED_AGGREGATE = IntrinsicDef( 'FEDERATED_AGGREGATE', 'federated_aggregate', computation_types.FunctionType(parameter=[ type_factory.at_clients(computation_types.AbstractType('T')), computation_types.AbstractType('U'), type_factory.reduction_op(computation_types.AbstractType('U'), computation_types.AbstractType('T')), type_factory.binary_op(computation_types.AbstractType('U')), computation_types.FunctionType(computation_types.AbstractType('U'), computation_types.AbstractType('R')) ], result=type_factory.at_server( computation_types.AbstractType('R')))) # Applies a given function to a value on the server. # # Type signature: <(T->U),T@SERVER> -> U@SERVER FEDERATED_APPLY = IntrinsicDef( 'FEDERATED_APPLY', 'federated_apply', computation_types.FunctionType(parameter=[ computation_types.FunctionType(computation_types.AbstractType('T'), computation_types.AbstractType('U')),
def test_reduction_op(self):
    """`reduction_op(A, B)` should render as `(<A,B> -> A)`."""
    op_type = type_factory.reduction_op(tf.float32, tf.int32)
    expected = '(<float32,int32> -> float32)'
    self.assertEqual(str(op_type), expected)