def federated_map(self, fn, arg):
  """Implements `federated_map` as defined in `api/intrinsics.py`.

  Args:
    fn: The mapping function, or anything `value_impl.to_value` can convert.
      During conversion its parameter type is hinted with the member type of
      `arg`.
    arg: The federated value to map, or anything convertible to one.

  Returns:
    A `value_impl.ValueImpl` wrapping a building block created by
    `building_block_factory.create_federated_apply` when `arg` is placed at
    SERVER, or by `building_block_factory.create_federated_map` when `arg` is
    placed at CLIENTS.

  Raises:
    TypeError: If the parameter type of `fn` cannot accept the member type of
      `arg`, if a SERVER-placed `arg` does not have `all_equal` set, or if
      `arg` is placed somewhere other than SERVER or CLIENTS.
  """
  # TODO(b/113112108): Possibly lift the restriction that the mapped value
  # must be placed at the server or clients. Would occur after adding support
  # for placement labels in the federated types, and expanding the type
  # specification of the intrinsic this is based on to work with federated
  # values of arbitrary placement.
  arg = value_utils.ensure_federated_value(
      value_impl.to_value(arg, None, self._context_stack),
      label='value to be mapped')
  arg_type = arg.type_signature
  fn = value_impl.to_value(
      fn, None, self._context_stack, parameter_type_hint=arg_type.member)

  py_typecheck.check_type(fn, value_base.Value)
  py_typecheck.check_type(fn.type_signature, computation_types.FunctionType)
  expected_parameter = fn.type_signature.parameter
  if not type_utils.is_assignable_from(expected_parameter, arg_type.member):
    raise TypeError(
        'The mapping function expects a parameter of type {}, but member '
        'constituents of the mapped value are of incompatible type {}.'
        .format(expected_parameter, arg_type.member))

  # TODO(b/144384398): Change structure to one that maps the placement type
  # to the building_block function that fits it, in a way that allows the
  # appropriate type checks.
  placement = arg_type.placement
  if placement is placements.SERVER:
    # The SERVER path goes through `federated_apply`, which is only accepted
    # here for values that are equal at all locations.
    if not arg_type.all_equal:
      raise TypeError(
          'Arguments placed at {} should be equal at all locations.'.format(
              placements.SERVER))
    build = building_block_factory.create_federated_apply
  elif placement is placements.CLIENTS:
    build = building_block_factory.create_federated_map
  else:
    raise TypeError(
        'The argument should be placed at {} or {}, placed at {} instead.'
        .format(placements.SERVER, placements.CLIENTS, placement))

  comp = build(
      value_impl.ValueImpl.get_comp(fn), value_impl.ValueImpl.get_comp(arg))
  return value_impl.ValueImpl(comp, self._context_stack)
def federated_aggregate(self, value, zero, accumulate, merge, report):
  """Implements `federated_aggregate` as defined in `api/intrinsics.py`.

  Args:
    value: A CLIENTS-placed federated value (or convertible to one) to be
      aggregated.
    zero: The initial accumulator value.
    accumulate: Function folding one member of `value` into an accumulator.
    merge: Function combining two accumulators.
    report: Function applied to the result of `merge`.

  Returns:
    A `value_impl.ValueImpl` wrapping the building block produced by
    `building_block_factory.create_federated_aggregate`.

  Raises:
    TypeError: If `zero` is not assignable to the first parameter type of
      `accumulate`, or if `accumulate`, `merge` or `report` is not assignable
      to the function type expected for it (derived from the result type of
      `accumulate`, the member type of `value`, and the result types of
      `merge` and `report`).
  """
  ctx = self._context_stack
  value = value_utils.ensure_federated_value(
      value_impl.to_value(value, None, ctx), placements.CLIENTS,
      'value to be aggregated')

  zero = value_impl.to_value(zero, None, ctx)
  py_typecheck.check_type(zero, value_base.Value)

  accumulate, merge, report = [
      value_impl.to_value(fn, None, ctx) for fn in (accumulate, merge, report)
  ]
  for fn in (accumulate, merge, report):
    py_typecheck.check_type(fn, value_base.Value)
    py_typecheck.check_type(fn.type_signature, computation_types.FunctionType)

  zero_slot_type = accumulate.type_signature.parameter[0]
  if not type_utils.is_assignable_from(zero_slot_type, zero.type_signature):
    raise TypeError('Expected `zero` to be assignable to type {}, '
                    'but was of incompatible type {}.'.format(
                        zero_slot_type, zero.type_signature))

  # The accumulator type is whatever `accumulate` returns; `merge` must
  # combine two of those, and `report` must consume what `merge` returns.
  accumulator_type = accumulate.type_signature.result
  expectations = [
      ('accumulate', accumulate,
       type_factory.reduction_op(accumulator_type,
                                 value.type_signature.member)),
      ('merge', merge,
       type_factory.reduction_op(accumulator_type, accumulator_type)),
      ('report', report,
       computation_types.FunctionType(merge.type_signature.result,
                                      report.type_signature.result)),
  ]
  for fn_name, fn, wanted_type in expectations:
    if not type_utils.is_assignable_from(wanted_type, fn.type_signature):
      raise TypeError(
          'Expected parameter `{}` to be of type {}, but received {} instead.'
          .format(fn_name, wanted_type, fn.type_signature))

  comp = building_block_factory.create_federated_aggregate(*[
      value_impl.ValueImpl.get_comp(v)
      for v in (value, zero, accumulate, merge, report)
  ])
  return value_impl.ValueImpl(comp, ctx)
def federated_broadcast(self, value):
  """Implements `federated_broadcast` as defined in `api/intrinsics.py`.

  Args:
    value: A SERVER-placed federated value, or anything convertible to one.

  Returns:
    A `value_impl.ValueImpl` wrapping the broadcast building block after it has
    been passed through `self._bind_comp_as_reference`.

  Raises:
    TypeError: If `value` does not have its `all_equal` bit set.
  """
  server_value = value_utils.ensure_federated_value(
      value_impl.to_value(value, None, self._context_stack),
      placement_literals.SERVER, 'value to be broadcasted')
  if not server_value.type_signature.all_equal:
    raise TypeError(
        'The broadcasted value should be equal at all locations.')

  broadcast = building_block_factory.create_federated_broadcast(
      value_impl.ValueImpl.get_comp(server_value))
  bound = self._bind_comp_as_reference(broadcast)
  return value_impl.ValueImpl(bound, self._context_stack)
def federated_reduce(self, value, zero, op):
  """Implements `federated_reduce` as defined in `api/intrinsics.py`.

  NOTE(review): SOURCE also contains a later `federated_reduce` definition;
  if both live in the same class, that later one shadows this one. Confirm
  and drop the dead copy.

  Args:
    value: A CLIENTS-placed federated value (or convertible to one) to reduce.
    zero: The initial value of the reduction.
    op: The reduction operator.

  Returns:
    A `value_impl.ValueImpl` wrapping the building block produced by
    `building_block_factory.create_federated_reduce`.

  Raises:
    TypeError: If the type of `op` is not assignable to
      `type_factory.reduction_op(<type of zero>, <member type of value>)`.
  """
  # TODO(b/113112108): Since in most cases, it can be assumed that CLIENTS is
  # a non-empty collective (or else, the computation fails), specifying zero
  # at this level of the API should probably be optional. TBD.
  ctx = self._context_stack
  value = value_utils.ensure_federated_value(
      value_impl.to_value(value, None, ctx), placements.CLIENTS,
      'value to be reduced')

  zero = value_impl.to_value(zero, None, ctx)
  py_typecheck.check_type(zero, value_base.Value)
  # TODO(b/113112108): We need a check here that zero does not have federated
  # constituents.

  op = value_impl.to_value(op, None, ctx)
  py_typecheck.check_type(op, value_base.Value)
  py_typecheck.check_type(op.type_signature, computation_types.FunctionType)

  wanted_op_type = type_factory.reduction_op(zero.type_signature,
                                             value.type_signature.member)
  if not type_utils.is_assignable_from(wanted_op_type, op.type_signature):
    raise TypeError('Expected an operator of type {}, got {}.'.format(
        wanted_op_type, op.type_signature))

  value_comp, zero_comp, op_comp = [
      value_impl.ValueImpl.get_comp(v) for v in (value, zero, op)
  ]
  reduced = building_block_factory.create_federated_reduce(
      value_comp, zero_comp, op_comp)
  return value_impl.ValueImpl(reduced, ctx)
def federated_map_all_equal(self, fn, arg):
  """Implements `federated_map` for an argument with the `all_equal` bit set.

  This is the `all_equal` variant of `federated_map` from
  `api/intrinsics.py`.

  NOTE(review): this method does not itself check
  `arg.type_signature.all_equal`; presumably the caller or
  `building_block_factory.create_federated_map_all_equal` enforces it. Confirm.

  Args:
    fn: As in `api/intrinsics.py`.
    arg: As in `api/intrinsics.py`, placed at CLIENTS, with the `all_equal`
      bit set.

  Returns:
    A `value_impl.ValueImpl` wrapping the building block produced by
    `building_block_factory.create_federated_map_all_equal`.

  Raises:
    TypeError: If the parameter type of `fn` cannot accept the member type of
      `arg`.
  """
  # TODO(b/113112108): Possibly lift the restriction that the mapped value
  # must be placed at the clients after adding support for placement labels
  # in the federated types, and expanding the type specification of the
  # intrinsic this is based on to work with federated values of arbitrary
  # placement.
  ctx = self._context_stack
  arg = value_utils.ensure_federated_value(
      value_impl.to_value(arg, None, ctx), placements.CLIENTS,
      'value to be mapped')

  # TODO(b/113112108): Add support for polymorphic templates auto-instantiated
  # here based on the actual type of the argument.
  fn = value_impl.to_value(fn, None, ctx)
  py_typecheck.check_type(fn, value_base.Value)
  py_typecheck.check_type(fn.type_signature, computation_types.FunctionType)

  parameter_type = fn.type_signature.parameter
  member_type = arg.type_signature.member
  if not type_utils.is_assignable_from(parameter_type, member_type):
    raise TypeError(
        'The mapping function expects a parameter of type {}, but member '
        'constituents of the mapped value are of incompatible type {}.'
        .format(parameter_type, member_type))

  mapped = building_block_factory.create_federated_map_all_equal(
      value_impl.ValueImpl.get_comp(fn), value_impl.ValueImpl.get_comp(arg))
  return value_impl.ValueImpl(mapped, ctx)
def federated_collect(self, value):
  """Implements `federated_collect` as defined in `api/intrinsics.py`.

  Args:
    value: A CLIENTS-placed federated value, or anything convertible to one.

  Returns:
    A `value_impl.ValueImpl` wrapping the building block produced by
    `building_block_factory.create_federated_collect`.

  Raises:
    TypeError: As in `api/intrinsics.py`.
  """
  clients_value = value_utils.ensure_federated_value(
      value_impl.to_value(value, None, self._context_stack),
      placements.CLIENTS, 'value to be collected')
  collected = building_block_factory.create_federated_collect(
      value_impl.ValueImpl.get_comp(clients_value))
  return value_impl.ValueImpl(collected, self._context_stack)
def federated_secure_sum(self, value, bitwidth):
  """Implements `federated_secure_sum` as defined in `api/intrinsics.py`.

  NOTE(review): SOURCE also contains a later `federated_secure_sum` definition
  with a different bitwidth check; if both live in the same class, that later
  one shadows this one. Confirm and drop the dead copy.

  Args:
    value: A CLIENTS-placed federated value (or convertible to one) whose type
      is validated with `type_utils.check_is_structure_of_integers`.
    bitwidth: The bitwidth specification, convertible to a TFF value.

  Returns:
    A `value_impl.ValueImpl` wrapping the building block produced by
    `building_block_factory.create_federated_secure_sum`.

  Raises:
    TypeError: If the member type of `value` and the type of `bitwidth` are
      not equivalent according to `type_utils.are_equivalent_types`.
  """
  ctx = self._context_stack
  value = value_utils.ensure_federated_value(
      value_impl.to_value(value, None, ctx), placements.CLIENTS,
      'value to be summed')
  type_utils.check_is_structure_of_integers(value.type_signature)
  bitwidth = value_impl.to_value(bitwidth, None, ctx)

  member_type, bitwidth_type = (value.type_signature.member,
                                bitwidth.type_signature)
  if not type_utils.are_equivalent_types(member_type, bitwidth_type):
    raise TypeError(
        'Expected `federated_secure_sum` parameters `value` and `bitwidth` '
        'to have the same structure. Found `value` of `{}` and `bitwidth` of `{}`'
        .format(member_type, bitwidth_type))

  summed = building_block_factory.create_federated_secure_sum(
      *[value_impl.ValueImpl.get_comp(v) for v in (value, bitwidth)])
  return value_impl.ValueImpl(summed, ctx)
def federated_secure_sum(self, value, bitwidth):
  """Implements `federated_secure_sum` as defined in `api/intrinsics.py`.

  Args:
    value: A CLIENTS-placed federated value (or convertible to one) whose type
      is validated with `type_utils.check_is_structure_of_integers`.
    bitwidth: The bitwidth specification, convertible to a TFF value.

  Returns:
    A `value_impl.ValueImpl` wrapping the building block produced by
    `building_block_factory.create_federated_secure_sum`.

  Raises:
    TypeError: If `type_utils.is_valid_bitwidth_type_for_value_type` rejects
      the type of `bitwidth` for the member type of `value`.
  """
  ctx = self._context_stack
  value = value_utils.ensure_federated_value(
      value_impl.to_value(value, None, ctx), placements.CLIENTS,
      'value to be summed')
  type_utils.check_is_structure_of_integers(value.type_signature)
  bitwidth = value_impl.to_value(bitwidth, None, ctx)

  member_type = value.type_signature.member
  bitwidth_type = bitwidth.type_signature
  bitwidth_is_valid = type_utils.is_valid_bitwidth_type_for_value_type(
      bitwidth_type, member_type)
  if not bitwidth_is_valid:
    raise TypeError(
        'Expected `federated_secure_sum` parameter `bitwidth` to match '
        'the structure of `value`, with one integer bitwidth per tensor in '
        '`value`. Found `value` of `{}` and `bitwidth` of `{}`.'.format(
            member_type, bitwidth_type))

  value_comp, bitwidth_comp = [
      value_impl.ValueImpl.get_comp(v) for v in (value, bitwidth)
  ]
  summed = building_block_factory.create_federated_secure_sum(
      value_comp, bitwidth_comp)
  return value_impl.ValueImpl(summed, ctx)
def federated_reduce(self, value, zero, op):
  """Implements `federated_reduce` as defined in `api/intrinsics.py`.

  Args:
    value: A CLIENTS-placed federated value, or anything convertible to one.
    zero: The initial value of the reduction. It must not contain federated
      types and must be assignable to the result type of `op`.
    op: The reduction operator. When it is converted with
      `value_impl.to_value`, its parameter type is hinted as
      `<type of zero, member type of value>`.

  Returns:
    A `value_impl.ValueImpl` wrapping the `federated_reduce` building block
    after it has been passed through `self._bind_comp_as_reference`.

  Raises:
    TypeError: If `zero` contains a federated type, if `zero` is not
      assignable to the result type of `op`, or if `op` is not assignable to
      `type_factory.reduction_op(<result type of op>, <member type of value>)`.
  """
  value = value_impl.to_value(value, None, self._context_stack)
  value = value_utils.ensure_federated_value(value, placement_literals.CLIENTS,
                                             'value to be reduced')

  zero = value_impl.to_value(zero, None, self._context_stack)
  if type_analysis.contains_federated_types(zero.type_signature):
    raise TypeError('`zero` may not contain a federated type, found type:\n' +
                    str(zero.type_signature))

  op = value_impl.to_value(
      op,
      None,
      self._context_stack,
      parameter_type_hint=computation_types.StructType(
          [zero.type_signature, value.type_signature.member]))
  op.type_signature.check_function()
  if not op.type_signature.result.is_assignable_from(zero.type_signature):
    # Build a single message string. Passing the header and the detail as two
    # positional arguments to TypeError makes str(exc) render as a tuple repr
    # with the newline escaped, instead of a readable message.
    raise TypeError(
        '`zero` must be assignable to the result type from `op`:\n' +
        computation_types.type_mismatch_error_message(
            zero.type_signature, op.type_signature.result,
            computation_types.TypeRelation.ASSIGNABLE))

  op_type_expected = type_factory.reduction_op(op.type_signature.result,
                                               value.type_signature.member)
  if not op_type_expected.is_assignable_from(op.type_signature):
    raise TypeError('Expected an operator of type {}, got {}.'.format(
        op_type_expected, op.type_signature))

  value = value_impl.ValueImpl.get_comp(value)
  zero = value_impl.ValueImpl.get_comp(zero)
  op = value_impl.ValueImpl.get_comp(op)
  comp = building_block_factory.create_federated_reduce(value, zero, op)
  comp = self._bind_comp_as_reference(comp)
  return value_impl.ValueImpl(comp, self._context_stack)
def federated_sum(self, value):
  """Implements `federated_sum` as defined in `api/intrinsics.py`.

  Args:
    value: A CLIENTS-placed federated value, or anything convertible to one.

  Returns:
    A `value_impl.ValueImpl` wrapping the building block produced by
    `building_block_factory.create_federated_sum`.

  Raises:
    TypeError: If `type_utils.is_sum_compatible` rejects the type of `value`.
  """
  clients_value = value_utils.ensure_federated_value(
      value_impl.to_value(value, None, self._context_stack),
      placements.CLIENTS, 'value to be summed')
  value_type = clients_value.type_signature
  if not type_utils.is_sum_compatible(value_type):
    raise TypeError(
        'The value type {} is not compatible with the sum operator.'.format(
            value_type))

  summed = building_block_factory.create_federated_sum(
      value_impl.ValueImpl.get_comp(clients_value))
  return value_impl.ValueImpl(summed, self._context_stack)
def _(x):
  # Test body: converts the argument, asserts that `ensure_federated_value`
  # raises TypeError for it, and returns the converted value.
  converted = value_impl.to_value(x, None, _context_stack)
  with self.assertRaises(TypeError):
    value_utils.ensure_federated_value(converted)
  return converted
def _(x):
  # Test body: converts the argument and runs `ensure_federated_value` on it
  # with no placement constraint. That call's result is discarded; the
  # converted input itself is returned.
  converted = value_impl.to_value(x, None, _context_stack)
  value_utils.ensure_federated_value(converted)
  return converted
def _(x):
  # Test body: converts the argument and runs `ensure_federated_value` on it,
  # requiring CLIENTS placement. That call's result is discarded; the
  # converted input itself is returned.
  converted = value_impl.to_value(x, None, _context_stack)
  value_utils.ensure_federated_value(converted, placement_literals.CLIENTS)
  return converted