def federated_map_all_equal(self, fn, arg):
  """Maps `fn` over a CLIENTS-placed `arg`, keeping the `all_equal` bit.

  This is `federated_map` with the `all_equal` bit set on both the argument
  and the result.

  Args:
    fn: A value, or something convertible to one via `value_impl.to_value`,
      whose type signature must be a `computation_types.FunctionType`.
    arg: A value, or something convertible to one, that is passed through
      `value_utils.ensure_federated_value` with `placements.CLIENTS`.

  Returns:
    A `value_impl.ValueImpl` wrapping the building block produced by
    `building_block_factory.create_federated_map_all_equal`.

  Raises:
    TypeError: If `fn` is not a functional value, or if its parameter type is
      not assignable from the member type of `arg`.
  """
  # TODO(b/113112108): Possibly lift the restriction that the mapped value
  # must be placed at the clients after adding support for placement labels
  # in the federated types, and expanding the type specification of the
  # intrinsic this is based on to work with federated values of arbitrary
  # placement.
  mapped_value = value_impl.to_value(arg, None, self._context_stack)
  mapped_value = value_utils.ensure_federated_value(mapped_value,
                                                    placements.CLIENTS,
                                                    'value to be mapped')

  # TODO(b/113112108): Add support for polymorphic templates auto-instantiated
  # here based on the actual type of the argument.
  mapping_fn = value_impl.to_value(fn, None, self._context_stack)
  py_typecheck.check_type(mapping_fn, value_base.Value)
  py_typecheck.check_type(mapping_fn.type_signature,
                          computation_types.FunctionType)

  expected_type = mapping_fn.type_signature.parameter
  member_type = mapped_value.type_signature.member
  if not type_utils.is_assignable_from(expected_type, member_type):
    raise TypeError(
        'The mapping function expects a parameter of type {}, but member '
        'constituents of the mapped value are of incompatible type {}.'.format(
            expected_type, member_type))

  # Unwrap both values into building blocks before constructing the call.
  fn_comp = value_impl.ValueImpl.get_comp(mapping_fn)
  arg_comp = value_impl.ValueImpl.get_comp(mapped_value)
  call = building_block_factory.create_federated_map_all_equal(
      fn_comp, arg_comp)
  return value_impl.ValueImpl(call, self._context_stack)
def federated_map_all_equal(fn, arg):
  """`federated_map` with the `all_equal` bit set in the `arg` and return.

  Args:
    fn: A value, or something convertible to one via `value_impl.to_value`,
      whose type signature must be a `computation_types.FunctionType`. The
      member type of `arg` is supplied as a `parameter_type_hint` during
      conversion.
    arg: A value, or something convertible to one, that is passed through
      `value_utils.ensure_federated_value` with `placements.CLIENTS`.

  Returns:
    A `value_impl.Value` wrapping the `federated_map_all_equal` call, after it
    has been passed through `_bind_comp_as_reference`.

  Raises:
    TypeError: If `fn` is not a functional value, takes no parameter, or has a
      parameter type that is not assignable from the member type of `arg`.
  """
  # TODO(b/113112108): Possibly lift the restriction that the mapped value
  # must be placed at the clients after adding support for placement labels
  # in the federated types, and expanding the type specification of the
  # intrinsic this is based on to work with federated values of arbitrary
  # placement.
  arg = value_impl.to_value(arg, None)
  arg = value_utils.ensure_federated_value(arg, placements.CLIENTS,
                                           'value to be mapped')

  fn = value_impl.to_value(
      fn, None, parameter_type_hint=arg.type_signature.member)

  py_typecheck.check_type(fn, value_impl.Value)
  py_typecheck.check_type(fn.type_signature, computation_types.FunctionType)

  parameter_type = fn.type_signature.parameter
  member_type = arg.type_signature.member
  # A no-argument function has a `None` parameter type. Without this guard the
  # assignability check below would raise an AttributeError on `None` instead
  # of the TypeError reported for every other incompatible mapping function.
  if parameter_type is None:
    raise TypeError(
        'The mapping function takes no parameter, but member constituents of '
        'the mapped value are of type {}.'.format(member_type))
  if not parameter_type.is_assignable_from(member_type):
    raise TypeError(
        'The mapping function expects a parameter of type {}, but member '
        'constituents of the mapped value are of incompatible type {}.'.format(
            parameter_type, member_type))

  comp = building_block_factory.create_federated_map_all_equal(
      fn.comp, arg.comp)
  comp = _bind_comp_as_reference(comp)
  return value_impl.Value(comp)
def federated_map_all_equal(self, fn, arg):
  """Implements `federated_map` for an argument with the `all_equal` bit set.

  Behaves as `federated_map` defined in `api/intrinsics.py`, but for an
  argument with the `all_equal` bit set.

  Args:
    fn: As in `api/intrinsics.py`.
    arg: As in `api/intrinsics.py`, with the `all_equal` bit set. If its type
      is a `computation_types.NamedTupleType` with two or more elements, it is
      first combined with `self.federated_zip`.

  Returns:
    As in `api/intrinsics.py`.

  Raises:
    TypeError: As in `api/intrinsics.py`.
  """
  # TODO(b/113112108): Possibly lift the restriction that the mapped value
  # must be placed at the clients after adding support for placement labels
  # in the federated types, and expanding the type specification of the
  # intrinsic this is based on to work with federated values of arbitrary
  # placement.
  mapped_value = value_impl.to_value(arg, None, self._context_stack)
  arg_type = mapped_value.type_signature
  needs_zip = (
      isinstance(arg_type, computation_types.NamedTupleType) and
      len(anonymous_tuple.to_elements(arg_type)) >= 2)
  if needs_zip:
    # A tuple of two or more elements is treated as a request to zip it
    # into a single federated value before mapping.
    mapped_value = self.federated_zip(mapped_value)

  value_utils.check_federated_value_placement(mapped_value, placements.CLIENTS,
                                              'value to be mapped')

  # TODO(b/113112108): Add support for polymorphic templates auto-instantiated
  # here based on the actual type of the argument.
  mapping_fn = value_impl.to_value(fn, None, self._context_stack)
  py_typecheck.check_type(mapping_fn, value_base.Value)
  py_typecheck.check_type(mapping_fn.type_signature,
                          computation_types.FunctionType)

  expected_type = mapping_fn.type_signature.parameter
  member_type = mapped_value.type_signature.member
  if not type_utils.is_assignable_from(expected_type, member_type):
    raise TypeError(
        'The mapping function expects a parameter of type {}, but member '
        'constituents of the mapped value are of incompatible type {}.'.format(
            expected_type, member_type))

  call = building_block_factory.create_federated_map_all_equal(
      value_impl.ValueImpl.get_comp(mapping_fn),
      value_impl.ValueImpl.get_comp(mapped_value))
  return value_impl.ValueImpl(call, self._context_stack)
def test_converts_federated_map_all_equal_to_federated_map(self):
  """`normalize_all_equal_bit` turns a map_all_equal call into federated_map.

  The normalized call must use the `FEDERATED_MAP` intrinsic. Its result type
  must be CLIENTS-placed `tf.int32` without the `all_equal` bit.
  """
  all_equal_type = computation_types.FederatedType(
      tf.int32, placements.CLIENTS, all_equal=True)
  expected_type = computation_types.FederatedType(tf.int32, placements.CLIENTS)
  identity = building_blocks.Lambda(
      'x', tf.int32, building_blocks.Reference('x', tf.int32))
  clients_ref = building_blocks.Reference('y', all_equal_type)
  before = building_block_factory.create_federated_map_all_equal(
      identity, clients_ref)

  after = mapreduce_transformations.normalize_all_equal_bit(before)

  # Sanity check: the input really was a federated_map_all_equal call.
  self.assertEqual(before.function.uri,
                   intrinsic_defs.FEDERATED_MAP_ALL_EQUAL.uri)
  self.assertIsInstance(after, building_blocks.Call)
  self.assertIsInstance(after.function, building_blocks.Intrinsic)
  self.assertEqual(after.function.uri, intrinsic_defs.FEDERATED_MAP.uri)
  self.assertEqual(after.type_signature, expected_type)
def create_dummy_called_federated_map_all_equal(parameter_name,
                                                parameter_type=tf.int32):
  r"""Returns a dummy called federated map.

                      Call
                     /    \
  federated_map_all_equal  Tuple
                           |
                           [Lambda(x), data]
                            |
                            Ref(x)

  Args:
    parameter_name: The name of the parameter.
    parameter_type: The type of the parameter.

  Returns:
    The building block returned by
    `building_block_factory.create_federated_map_all_equal`. It maps an
    identity function over a `building_blocks.Data` named 'data', whose type
    is `parameter_type` placed at `placements.CLIENTS` with `all_equal=True`.
  """
  identity_fn = create_identity_function(parameter_name, parameter_type)
  clients_data = building_blocks.Data(
      'data',
      computation_types.FederatedType(
          parameter_type, placements.CLIENTS, all_equal=True))
  return building_block_factory.create_federated_map_all_equal(
      identity_fn, clients_data)