def construct_map_or_apply(func, arg):
  """Constructs the intrinsic that applies `func` to the federated `arg`.

  The `federated_apply` intrinsic is selected when `arg` is placed at
  `SERVER`, and the `federated_map` intrinsic when it is placed at `CLIENTS`.
  Only the intrinsic itself is constructed; it is not called on `func` and
  `arg`.

  Args:
    func: A `computation_building_blocks.ComputationBuildingBlock` of
      functional type to be wrapped with the intrinsic in order to call it on
      `arg`.
    arg: A `computation_building_blocks.ComputationBuildingBlock` of federated
      type for which to construct the intrinsic in order to call `func` on it.

  Returns:
    A `computation_building_blocks.Intrinsic` of functional type whose
    parameter is the pair of the types of `func` and `arg`, and whose result
    is the result of `func` federated with the placement and `all_equal` bit
    of `arg`.

  Raises:
    TypeError: If `arg` is placed somewhere other than `SERVER` or `CLIENTS`.
      Failures of the `py_typecheck.check_type` argument checks also
      propagate to the caller.
  """
  py_typecheck.check_type(func,
                          computation_building_blocks.ComputationBuildingBlock)
  py_typecheck.check_type(arg,
                          computation_building_blocks.ComputationBuildingBlock)
  # Fail with an explicit type check rather than an attribute error on
  # `func.type_signature.result` below.
  py_typecheck.check_type(func.type_signature, computation_types.FunctionType)
  py_typecheck.check_type(arg.type_signature, computation_types.FederatedType)
  placement = arg.type_signature.placement
  if placement == placement_literals.SERVER:
    uri = intrinsic_defs.FEDERATED_APPLY.uri
  elif placement == placement_literals.CLIENTS:
    uri = intrinsic_defs.FEDERATED_MAP.uri
  else:
    # Previously this case fell through and failed on an unbound local.
    raise TypeError('Unsupported placement {}.'.format(placement))
  result_type = computation_types.FederatedType(func.type_signature.result,
                                                placement,
                                                arg.type_signature.all_equal)
  intrinsic_type = computation_types.FunctionType(
      [func.type_signature, arg.type_signature], result_type)
  return computation_building_blocks.Intrinsic(uri, intrinsic_type)
def construct_map_or_apply(fn, arg):
  """Constructs a call applying `fn` to the federated `arg`.

  For an `arg` placed at `SERVER` a call to the `federated_apply` intrinsic is
  built here; for an `arg` placed at `CLIENTS` the construction is delegated
  to `create_federated_map`.

  Args:
    fn: An instance of `computation_building_blocks.ComputationBuildingBlock`
      of functional type to be wrapped with the intrinsic in order to call it
      on `arg`.
    arg: An instance of `computation_building_blocks.ComputationBuildingBlock`
      of federated type for which to construct the intrinsic in order to call
      `fn` on `arg`. The `member` of the `type_signature` of `arg` must be
      assignable to the `parameter` of the `type_signature` of `fn`.

  Returns:
    For a server-placed `arg`, a `computation_building_blocks.Call` of the
    `federated_apply` intrinsic on the tuple `<fn, arg>`; for a client-placed
    `arg`, the result of `create_federated_map(fn, arg)`.

  Raises:
    TypeError: If `arg` is placed somewhere other than `SERVER` or `CLIENTS`.
      Failures of the argument type checks and of the assignability check
      also propagate to the caller.
  """
  py_typecheck.check_type(fn,
                          computation_building_blocks.ComputationBuildingBlock)
  py_typecheck.check_type(arg,
                          computation_building_blocks.ComputationBuildingBlock)
  py_typecheck.check_type(fn.type_signature, computation_types.FunctionType)
  py_typecheck.check_type(arg.type_signature, computation_types.FederatedType)
  type_utils.check_assignable_from(fn.type_signature.parameter,
                                   arg.type_signature.member)
  placement = arg.type_signature.placement
  if placement == placement_literals.SERVER:
    result_type = computation_types.FederatedType(fn.type_signature.result,
                                                  placement,
                                                  arg.type_signature.all_equal)
    intrinsic_type = computation_types.FunctionType(
        [fn.type_signature, arg.type_signature], result_type)
    intrinsic = computation_building_blocks.Intrinsic(
        intrinsic_defs.FEDERATED_APPLY.uri, intrinsic_type)
    tup = computation_building_blocks.Tuple((fn, arg))
    return computation_building_blocks.Call(intrinsic, tup)
  if placement == placement_literals.CLIENTS:
    return create_federated_map(fn, arg)
  # Previously this case silently returned `None`.
  raise TypeError('Unsupported placement {}.'.format(placement))
def test_propogates_dependence_up_through_lambda(self):
  """Checks a lambda whose result is the dummy intrinsic is extracted."""
  intrinsic = computation_building_blocks.Intrinsic('dummy_intrinsic',
                                                    tf.int32)
  fn = computation_building_blocks.Lambda('x', tf.int32, intrinsic)

  consumers = tree_analysis.extract_nodes_consuming(
      fn, dummy_intrinsic_predicate)

  self.assertIn(fn, consumers)
def test_propogates_dependence_up_through_selection(self):
  """Checks a selection from the dummy intrinsic is extracted."""
  intrinsic = computation_building_blocks.Intrinsic('dummy_intrinsic',
                                                    [tf.int32])
  first_element = computation_building_blocks.Selection(intrinsic, index=0)

  consumers = tree_analysis.extract_nodes_consuming(
      first_element, dummy_intrinsic_predicate)

  self.assertIn(first_element, consumers)
def test_intrinsic_class_succeeds_simple_federated_map(self):
  """Builds a concrete `federated_map` intrinsic and checks its properties."""
  fn_type = computation_types.FunctionType(tf.int32, tf.float32)
  arg_type = computation_types.FederatedType(fn_type.parameter,
                                             placements.CLIENTS)
  result_type = computation_types.FederatedType(fn_type.result,
                                                placements.CLIENTS)
  intrinsic_type = computation_types.FunctionType([fn_type, arg_type],
                                                  result_type)
  intrinsic = computation_building_blocks.Intrinsic(
      intrinsic_defs.FEDERATED_MAP.uri, intrinsic_type)

  # Properties of the building block itself.
  self.assertIsInstance(intrinsic, computation_building_blocks.Intrinsic)
  self.assertEqual(
      str(intrinsic.type_signature),
      '(<(int32 -> float32),{int32}@CLIENTS> -> {float32}@CLIENTS)')
  self.assertEqual(intrinsic.uri, 'federated_map')
  self.assertEqual(intrinsic.compact_representation(), 'federated_map')

  # Properties of its serialized form.
  proto = intrinsic.proto
  self.assertEqual(
      type_serialization.deserialize_type(proto.type),
      intrinsic.type_signature)
  self.assertEqual(proto.WhichOneof('computation'), 'intrinsic')
  self.assertEqual(proto.intrinsic.uri, intrinsic.uri)
  self._serialize_deserialize_roundtrip_test(intrinsic)
def _create_call_to_federated_map(fn, arg):
  r"""Builds a call to the `federated_map` intrinsic on `fn` and `arg`.

              Call
             /    \
    Intrinsic      Tuple
                  /     \
       Computation       Computation

  Args:
    fn: A functional `computation_building_blocks.ComputationBuildingBlock`
      used as the mapping function.
    arg: A `computation_building_blocks.ComputationBuildingBlock` used as the
      value being mapped.

  Returns:
    A `computation_building_blocks.Call` of the intrinsic on the tuple
    `<fn, arg>`; the declared result type is the result of `fn` placed at
    `placements.CLIENTS`.
  """
  building_block_cls = computation_building_blocks.ComputationBuildingBlock
  py_typecheck.check_type(fn, building_block_cls)
  py_typecheck.check_type(arg, building_block_cls)
  result_type = computation_types.FederatedType(fn.type_signature.result,
                                                placements.CLIENTS)
  map_type = computation_types.FunctionType(
      [fn.type_signature, arg.type_signature], result_type)
  federated_map = computation_building_blocks.Intrinsic(
      intrinsic_defs.FEDERATED_MAP.uri, map_type)
  return computation_building_blocks.Call(
      federated_map, computation_building_blocks.Tuple((fn, arg)))
def create_sequence_map(fn, arg):
  r"""Builds a call to the `sequence_map` intrinsic on `fn` and `arg`.

              Call
             /    \
    Intrinsic      Tuple
                   |
                   [Comp, Comp]

  Args:
    fn: A `computation_building_blocks.ComputationBuildingBlock` to use as the
      function.
    arg: A `computation_building_blocks.ComputationBuildingBlock` to use as
      the argument.

  Returns:
    A `computation_building_blocks.Call` whose declared result type is a
    sequence of the result type of `fn`.

  Raises:
    TypeError: If any of the types do not match.
  """
  building_block_cls = computation_building_blocks.ComputationBuildingBlock
  py_typecheck.check_type(fn, building_block_cls)
  py_typecheck.check_type(arg, building_block_cls)
  mapped_type = computation_types.SequenceType(fn.type_signature.result)
  sequence_map = computation_building_blocks.Intrinsic(
      intrinsic_defs.SEQUENCE_MAP.uri,
      computation_types.FunctionType((fn.type_signature, arg.type_signature),
                                     mapped_type))
  return computation_building_blocks.Call(
      sequence_map, computation_building_blocks.Tuple((fn, arg)))
def federated_sum(self, value):
  """Implements `federated_sum` as defined in `api/intrinsics.py`.

  The argument is converted to a value, required to be placed at `CLIENTS`
  and to be of a sum-compatible type, and then passed to the `federated_sum`
  intrinsic whose result is placed at `SERVER`.

  Args:
    value: As in `api/intrinsics.py`.

  Returns:
    As in `api/intrinsics.py`.

  Raises:
    TypeError: As in `api/intrinsics.py`.
  """
  value = value_impl.to_value(value, None, self._context_stack)
  type_utils.check_federated_value_placement(value, placements.CLIENTS,
                                             'value to be summed')
  value_type = value.type_signature
  if not type_utils.is_sum_compatible(value_type):
    message = 'The value type {} is not compatible with the sum operator.'
    raise TypeError(message.format(str(value_type)))

  # TODO(b/113112108): Replace this as noted above.
  summed_type = computation_types.FederatedType(value_type.member,
                                                placements.SERVER, True)
  sum_building_block = computation_building_blocks.Intrinsic(
      intrinsic_defs.FEDERATED_SUM.uri,
      computation_types.FunctionType(value_type, summed_type))
  sum_fn = value_impl.ValueImpl(sum_building_block, self._context_stack)
  return sum_fn(value)
def sequence_reduce(self, value, zero, op):
  """Implements `sequence_reduce` as defined in `api/intrinsics.py`.

  A plain sequence is reduced directly. A federated sequence is reduced by
  wrapping the reduction in a lambda and applying it with `federated_apply`
  (server placement) or `federated_map` (clients placement).

  Args:
    value: As in `api/intrinsics.py`.
    zero: As in `api/intrinsics.py`.
    op: As in `api/intrinsics.py`.

  Returns:
    As in `api/intrinsics.py`.

  Raises:
    TypeError: As in `api/intrinsics.py`.
  """
  context_stack = self._context_stack
  value = value_impl.to_value(value, None, context_stack)
  zero = value_impl.to_value(zero, None, context_stack)
  op = value_impl.to_value(op, None, context_stack)

  # Determine the element type of the (possibly federated) sequence.
  if isinstance(value.type_signature, computation_types.SequenceType):
    element_type = value.type_signature.element
  else:
    py_typecheck.check_type(value.type_signature,
                            computation_types.FederatedType)
    py_typecheck.check_type(value.type_signature.member,
                            computation_types.SequenceType)
    element_type = value.type_signature.member.element

  expected_op_type = type_constructors.reduction_op(zero.type_signature,
                                                    element_type)
  if not type_utils.is_assignable_from(expected_op_type, op.type_signature):
    raise TypeError('Expected an operator of type {}, got {}.'.format(
        expected_op_type, op.type_signature))

  # From here on work with the underlying building blocks.
  value_comp = value_impl.ValueImpl.get_comp(value)
  zero_comp = value_impl.ValueImpl.get_comp(zero)
  op_comp = value_impl.ValueImpl.get_comp(op)

  if isinstance(value_comp.type_signature, computation_types.SequenceType):
    return computation_constructing_utils.create_sequence_reduce(
        value_comp, zero_comp, op_comp)

  # Federated case: build `arg -> sequence_reduce(<arg, zero, op>)`.
  sequence_type = computation_types.SequenceType(element_type)
  reduce_type = computation_types.FunctionType(
      (sequence_type, zero_comp.type_signature, op_comp.type_signature),
      op_comp.type_signature.result)
  reduce_intrinsic = computation_building_blocks.Intrinsic(
      intrinsic_defs.SEQUENCE_REDUCE.uri, reduce_type)
  arg_ref = computation_building_blocks.Reference('arg', sequence_type)
  reduce_args = computation_building_blocks.Tuple(
      (arg_ref, zero_comp, op_comp))
  reduce_call = computation_building_blocks.Call(reduce_intrinsic,
                                                 reduce_args)
  reduce_fn = value_impl.ValueImpl(
      computation_building_blocks.Lambda(arg_ref.name, arg_ref.type_signature,
                                         reduce_call), context_stack)

  placement = value_comp.type_signature.placement
  if placement is placements.SERVER:
    return self.federated_apply(reduce_fn, value_comp)
  if placement is placements.CLIENTS:
    return self.federated_map(reduce_fn, value_comp)
  raise TypeError('Unsupported placement {}.'.format(
      value_comp.type_signature.placement))
def sequence_reduce(self, value, zero, op):
  """Implements `sequence_reduce` as defined in `api/intrinsics.py`.

  A plain sequence is passed straight to the `sequence_reduce` intrinsic. A
  federated sequence is reduced by wrapping the intrinsic call in a lambda
  and applying it with `federated_apply` (server placement) or
  `federated_map` (clients placement).

  Args:
    value: As in `api/intrinsics.py`.
    zero: As in `api/intrinsics.py`.
    op: As in `api/intrinsics.py`.

  Returns:
    As in `api/intrinsics.py`.

  Raises:
    TypeError: As in `api/intrinsics.py`.
  """
  context_stack = self._context_stack
  value = value_impl.to_value(value, None, context_stack)
  zero = value_impl.to_value(zero, None, context_stack)
  op = value_impl.to_value(op, None, context_stack)

  is_plain_sequence = isinstance(value.type_signature,
                                 computation_types.SequenceType)
  if is_plain_sequence:
    element_type = value.type_signature.element
  else:
    py_typecheck.check_type(value.type_signature,
                            computation_types.FederatedType)
    py_typecheck.check_type(value.type_signature.member,
                            computation_types.SequenceType)
    element_type = value.type_signature.member.element

  expected_op_type = type_constructors.reduction_op(zero.type_signature,
                                                    element_type)
  if not type_utils.is_assignable_from(expected_op_type, op.type_signature):
    raise TypeError('Expected an operator of type {}, got {}.'.format(
        str(expected_op_type), str(op.type_signature)))

  reduce_type = computation_types.FunctionType([
      computation_types.SequenceType(element_type),
      zero.type_signature,
      op.type_signature,
  ], zero.type_signature)
  reduce_building_block = computation_building_blocks.Intrinsic(
      intrinsic_defs.SEQUENCE_REDUCE.uri, reduce_type)

  if isinstance(value.type_signature, computation_types.SequenceType):
    reduce_fn = value_impl.ValueImpl(reduce_building_block, context_stack)
    return reduce_fn(value, zero, op)

  # Federated case: build `arg -> sequence_reduce(<arg, zero, op>)`.
  lambda_parameter_type = computation_types.SequenceType(element_type)
  arg_ref = computation_building_blocks.Reference(
      'arg', computation_types.SequenceType(element_type))
  reduce_args = computation_building_blocks.Tuple([
      arg_ref,
      value_impl.ValueImpl.get_comp(zero),
      value_impl.ValueImpl.get_comp(op),
  ])
  reduce_call = computation_building_blocks.Call(reduce_building_block,
                                                 reduce_args)
  mapping_lambda = computation_building_blocks.Lambda(
      'arg', lambda_parameter_type, reduce_call)
  mapping_fn = value_impl.ValueImpl(mapping_lambda, context_stack)

  if value.type_signature.placement is placements.SERVER:
    return self.federated_apply(mapping_fn, value)
  if value.type_signature.placement is placements.CLIENTS:
    return self.federated_map(mapping_fn, value)
  raise TypeError('Unsupported placement {}.'.format(
      str(value.type_signature.placement)))
def federated_broadcast(self, value):
  """Implements `federated_broadcast` as defined in `api/intrinsics.py`.

  The argument is converted to a value, required to be placed at `SERVER`
  with the `all_equal` bit set, and then passed to the `federated_broadcast`
  intrinsic whose result is placed at `CLIENTS`.

  Args:
    value: As in `api/intrinsics.py`.

  Returns:
    As in `api/intrinsics.py`.

  Raises:
    TypeError: As in `api/intrinsics.py`.
  """
  value = value_impl.to_value(value, None, self._context_stack)
  type_utils.check_federated_value_placement(value, placements.SERVER,
                                             'value to be broadcasted')
  value_type = value.type_signature
  if not value_type.all_equal:
    raise TypeError('The broadcasted value should be equal at all locations.')

  # TODO(b/113112108): Replace this hand-crafted logic here and below with
  # a call to a helper function that handles it in a uniform manner after
  # implementing support for correctly typechecking federated template types
  # and instantiating template types on concrete arguments.
  broadcast_type = computation_types.FederatedType(value_type.member,
                                                   placements.CLIENTS, True)
  broadcast_building_block = computation_building_blocks.Intrinsic(
      intrinsic_defs.FEDERATED_BROADCAST.uri,
      computation_types.FunctionType(value_type, broadcast_type))
  broadcast_fn = value_impl.ValueImpl(broadcast_building_block,
                                      self._context_stack)
  return broadcast_fn(value)
def _create_lambda_to_add_one(dtype):
  r"""Builds a lambda that adds `1` to its argument via `generic_plus`.

  Lambda
        \
         Call
        /    \
  Intrinsic   Tuple
             /     \
    Reference       Computation

  Args:
    dtype: The type of the argument, either a `tf.dtypes.DType` or a
      `computation_types.TensorType` whose `dtype` is then used.

  Returns:
    An instance of `computation_building_blocks.Lambda` wrapping a function
    that adds 1 to an argument.
  """
  # Unwrap tensor types to the underlying TensorFlow dtype.
  if isinstance(dtype, computation_types.TensorType):
    tf_dtype = dtype.dtype
  else:
    tf_dtype = dtype
  py_typecheck.check_type(tf_dtype, tf.dtypes.DType)

  plus_type = computation_types.FunctionType([tf_dtype, tf_dtype], tf_dtype)
  plus = computation_building_blocks.Intrinsic(
      intrinsic_defs.GENERIC_PLUS.uri, plus_type)
  arg_ref = computation_building_blocks.Reference('arg', tf_dtype)
  one = _create_call_to_py_fn(lambda: tf.cast(tf.constant(1), tf_dtype))
  operands = computation_building_blocks.Tuple([arg_ref, one])
  return computation_building_blocks.Lambda(
      arg_ref.name, arg_ref.type_signature,
      computation_building_blocks.Call(plus, operands))
def create_federated_map(fn, arg):
  r"""Builds a call to the `federated_map` intrinsic on `fn` and `arg`.

              Call
             /    \
    Intrinsic      Tuple
                   |
                   [Comp, Comp]

  Args:
    fn: A `computation_building_blocks.ComputationBuildingBlock` to use as the
      function.
    arg: A `computation_building_blocks.ComputationBuildingBlock` to use as
      the argument.

  Returns:
    A `computation_building_blocks.Call` whose declared result type is the
    result of `fn` placed at `CLIENTS` with the `all_equal` bit unset.

  Raises:
    TypeError: If any of the types do not match.
  """
  building_block_cls = computation_building_blocks.ComputationBuildingBlock
  py_typecheck.check_type(fn, building_block_cls)
  py_typecheck.check_type(arg, building_block_cls)
  mapped_type = computation_types.FederatedType(fn.type_signature.result,
                                                placement_literals.CLIENTS,
                                                False)
  federated_map = computation_building_blocks.Intrinsic(
      intrinsic_defs.FEDERATED_MAP.uri,
      computation_types.FunctionType((fn.type_signature, arg.type_signature),
                                     mapped_type))
  return computation_building_blocks.Call(
      federated_map, computation_building_blocks.Tuple((fn, arg)))
def create_federated_sum(value):
  r"""Builds a call to the `federated_sum` intrinsic on `value`.

              Call
             /    \
    Intrinsic      Comp

  Args:
    value: A `computation_building_blocks.ComputationBuildingBlock` to use as
      the value.

  Returns:
    A `computation_building_blocks.Call` whose declared result type is the
    member type of `value` placed at `SERVER` with the `all_equal` bit set.

  Raises:
    TypeError: If any of the types do not match.
  """
  py_typecheck.check_type(
      value, computation_building_blocks.ComputationBuildingBlock)
  value_type = value.type_signature
  summed_type = computation_types.FederatedType(value_type.member,
                                                placement_literals.SERVER,
                                                True)
  federated_sum = computation_building_blocks.Intrinsic(
      intrinsic_defs.FEDERATED_SUM.uri,
      computation_types.FunctionType(value.type_signature, summed_type))
  return computation_building_blocks.Call(federated_sum, value)
def _create_lambda_to_dummy_intrinsic(type_spec, uri='dummy'):
  r"""Builds a lambda that calls an intrinsic named `uri` on its argument.

  Lambda
        \
         Call
        /    \
  Intrinsic   Ref(arg)

  Args:
    type_spec: The type of the argument; also used as the result type of the
      intrinsic.
    uri: The URI of the intrinsic.

  Returns:
    A `computation_building_blocks.Lambda`.

  Raises:
    TypeError: If `type_spec` is not a `tf.dtypes.DType`.
  """
  py_typecheck.check_type(type_spec, tf.dtypes.DType)
  dummy = computation_building_blocks.Intrinsic(
      uri, computation_types.FunctionType(type_spec, type_spec))
  arg_ref = computation_building_blocks.Reference('arg', type_spec)
  body = computation_building_blocks.Call(dummy, arg_ref)
  return computation_building_blocks.Lambda(arg_ref.name,
                                            arg_ref.type_signature, body)
def create_federated_value(value, placement):
  r"""Builds a call to the `federated_value` intrinsic for `placement`.

              Call
             /    \
    Intrinsic      Comp

  Args:
    value: A `computation_building_blocks.ComputationBuildingBlock` to use as
      the value.
    placement: A `placement_literals.PlacementLiteral` to use as the
      placement; either `CLIENTS` or `SERVER`.

  Returns:
    A `computation_building_blocks.Call` whose declared result type is the
    type of `value` placed at `placement` with the `all_equal` bit set.

  Raises:
    TypeError: If any of the types do not match, or if `placement` is neither
      `CLIENTS` nor `SERVER`.
  """
  py_typecheck.check_type(
      value, computation_building_blocks.ComputationBuildingBlock)
  if placement is placement_literals.SERVER:
    intrinsic_uri = intrinsic_defs.FEDERATED_VALUE_AT_SERVER.uri
  elif placement is placement_literals.CLIENTS:
    intrinsic_uri = intrinsic_defs.FEDERATED_VALUE_AT_CLIENTS.uri
  else:
    raise TypeError('Unsupported placement {}.'.format(placement))
  placed_type = computation_types.FederatedType(value.type_signature,
                                                placement, True)
  federated_value = computation_building_blocks.Intrinsic(
      intrinsic_uri,
      computation_types.FunctionType(value.type_signature, placed_type))
  return computation_building_blocks.Call(federated_value, value)
def federated_value(self, value, placement):
  """Implements `federated_value` as defined in `api/intrinsics.py`.

  The argument is converted to a value and passed to the `federated_value`
  intrinsic matching `placement`; the result carries the `all_equal` bit.

  Args:
    value: As in `api/intrinsics.py`.
    placement: As in `api/intrinsics.py`.

  Returns:
    As in `api/intrinsics.py`.

  Raises:
    TypeError: As in `api/intrinsics.py`.
  """
  value = value_impl.to_value(value, None, self._context_stack)

  # TODO(b/113112108): Verify that neither the value, nor any of its parts
  # are of a federated type.

  if placement is placements.SERVER:
    intrinsic_uri = intrinsic_defs.FEDERATED_VALUE_AT_SERVER.uri
  elif placement is placements.CLIENTS:
    intrinsic_uri = intrinsic_defs.FEDERATED_VALUE_AT_CLIENTS.uri
  else:
    raise TypeError('The placement must be either CLIENTS or SERVER.')

  value_type = value.type_signature
  placed_type = computation_types.FederatedType(value_type, placement, True)
  building_block = computation_building_blocks.Intrinsic(
      intrinsic_uri, computation_types.FunctionType(value_type, placed_type))
  placing_fn = value_impl.ValueImpl(building_block, self._context_stack)
  return placing_fn(value)
def federated_collect(self, value):
  """Implements `federated_collect` as defined in `api/intrinsics.py`.

  The argument is converted to a value, required to be placed at `CLIENTS`,
  and passed to the `federated_collect` intrinsic, whose result is a sequence
  of the member type placed at `SERVER`.

  Args:
    value: As in `api/intrinsics.py`.

  Returns:
    As in `api/intrinsics.py`.

  Raises:
    TypeError: As in `api/intrinsics.py`.
  """
  value = value_impl.to_value(value, None, self._context_stack)
  type_utils.check_federated_value_placement(value, placements.CLIENTS,
                                             'value to be collected')
  value_type = value.type_signature
  collected_member_type = computation_types.SequenceType(value_type.member)
  collected_type = computation_types.FederatedType(collected_member_type,
                                                   placements.SERVER, True)
  building_block = computation_building_blocks.Intrinsic(
      intrinsic_defs.FEDERATED_COLLECT.uri,
      computation_types.FunctionType(value_type, collected_type))
  collect_fn = value_impl.ValueImpl(building_block, self._context_stack)
  return collect_fn(value)
def _create_zip_two_values(value):
  r"""Creates a called federated zip with two values.

              Call
             /    \
    Intrinsic      Tuple
                   |
                   [Comp1, Comp2]

  Notice that this function will drop any names associated to the two-tuple it
  is processing. This is necessary due to the type signature of the
  underlying federated zip intrinsic, `<T@P,U@P>-><T,U>@P`. Keeping names
  here would violate this type signature. The names are cached at a higher
  level than this function, and appended to the resulting tuple in a single
  call to `federated_map` or `federated_apply` before the resulting structure
  is sent back to the caller.

  Args:
    value: A `computation_building_blocks.ComputationBuildingBlock` with a
      `type_signature` of type `computation_types.NamedTupleType` containing
      exactly two elements.

  Returns:
    A `computation_building_blocks.Call`.

  Raises:
    TypeError: If any of the types do not match, or if the first element is
      placed somewhere other than `CLIENTS` or `SERVER`.
    ValueError: If `value` does not contain exactly two elements.
  """
  py_typecheck.check_type(
      value, computation_building_blocks.ComputationBuildingBlock)
  named_type_signatures = anonymous_tuple.to_elements(value.type_signature)
  length = len(named_type_signatures)
  if length != 2:
    # Report the element count; the message previously interpolated the whole
    # list of elements into "received {} elements".
    raise ValueError(
        'Expected a value with exactly two elements, received {} elements.'
        .format(length))
  # The placement of the first element selects the intrinsic and the
  # `all_equal` bit used for both elements and for the result.
  placement = value[0].type_signature.placement
  if placement is placement_literals.CLIENTS:
    uri = intrinsic_defs.FEDERATED_ZIP_AT_CLIENTS.uri
    all_equal = False
  elif placement is placement_literals.SERVER:
    uri = intrinsic_defs.FEDERATED_ZIP_AT_SERVER.uri
    all_equal = True
  else:
    raise TypeError('Unsupported placement {}.'.format(placement))
  # Names are deliberately replaced with `None` (see the docstring).
  elements = [
      (None,
       computation_types.FederatedType(type_signature.member, placement,
                                       all_equal))
      for _, type_signature in named_type_signatures
  ]
  parameter_type = computation_types.NamedTupleType(elements)
  result_type = computation_types.FederatedType(
      [(None, e.member) for _, e in named_type_signatures], placement,
      all_equal)
  intrinsic_type = computation_types.FunctionType(parameter_type, result_type)
  intrinsic = computation_building_blocks.Intrinsic(uri, intrinsic_type)
  return computation_building_blocks.Call(intrinsic, value)
def zip_two_tuple(input_val, context_stack):
  """Zips a 2-tuple of federated values into a federated 2-tuple.

  The placement of the first element selects the zip intrinsic and the
  `all_equal` bit (unset for `CLIENTS`, set for `SERVER`); every element must
  share that placement.

  Args:
    input_val: 2-tuple TFF `Value` of `NamedTuple` type, whose elements must
      be `FederatedTypes` with the same placement.
    context_stack: The context stack to use, as in
      `impl.value_impl.to_value`.

  Returns:
    TFF `Value` of `FederatedType` with member of 2-tuple `NamedTuple` type.

  Raises:
    TypeError: If the first element is placed at neither `SERVER` nor
      `CLIENTS`.
    ValueError: If `input_val` does not have exactly two elements.
  """
  py_typecheck.check_type(input_val, value_base.Value)
  py_typecheck.check_type(input_val.type_signature,
                          computation_types.NamedTupleType)
  py_typecheck.check_type(input_val[0].type_signature,
                          computation_types.FederatedType)

  # Placement -> (intrinsic URI, `all_equal` bit of the zipped result).
  zip_config = {
      placements.CLIENTS: (intrinsic_defs.FEDERATED_ZIP_AT_CLIENTS.uri, False),
      placements.SERVER: (intrinsic_defs.FEDERATED_ZIP_AT_SERVER.uri, True),
  }
  placement = input_val[0].type_signature.placement
  if placement not in zip_config:
    raise TypeError(
        'The argument must have components placed at SERVER or CLIENTS')
  zip_uri, all_equal = zip_config[placement]

  for element in input_val:
    type_utils.check_federated_value_placement(element, placement)

  named_types = anonymous_tuple.to_elements(input_val.type_signature)
  element_count = len(named_types)
  if element_count != 2:
    raise ValueError('The argument of zip_two_tuple must be a 2-tuple, '
                     'not an {}-tuple'.format(element_count))

  zipped_type = computation_types.FederatedType(
      [(name, element_type.member) for name, element_type in named_types],
      placement, all_equal)

  # Rebuild the parameter type so each element carries the chosen bit;
  # unnamed elements stay unnamed.
  adjusted_elements = []
  for name, element_type in named_types:
    adjusted = computation_types.FederatedType(element_type.member,
                                               element_type.placement,
                                               all_equal)
    adjusted_elements.append((name, adjusted) if name else adjusted)
  parameter_type = computation_types.NamedTupleType(adjusted_elements)

  zip_building_block = computation_building_blocks.Intrinsic(
      zip_uri, computation_types.FunctionType(parameter_type, zipped_type))
  zip_fn = value_impl.ValueImpl(zip_building_block, context_stack)
  return zip_fn(input_val)
def test_passes_with_federated_map(self):
  """Checks the whitelist check accepts a `federated_map` intrinsic."""
  fn_type = computation_types.FunctionType(tf.int32, tf.float32)
  arg_type = computation_types.FederatedType(tf.int32, placements.CLIENTS)
  result_type = computation_types.FederatedType(tf.float32,
                                                placements.CLIENTS)
  federated_map = computation_building_blocks.Intrinsic(
      intrinsic_defs.FEDERATED_MAP.uri,
      computation_types.FunctionType([fn_type, arg_type], result_type))

  tree_analysis.check_intrinsics_whitelisted_for_reduction(federated_map)
def test_returns_string_for_intrinsic(self):
  """Checks all three representations of an intrinsic are its URI."""
  intrinsic = computation_building_blocks.Intrinsic('intrinsic', tf.int32)

  compact = intrinsic.compact_representation()
  self.assertEqual(compact, 'intrinsic')

  formatted = intrinsic.formatted_representation()
  self.assertEqual(formatted, 'intrinsic')

  structural = intrinsic.structural_representation()
  self.assertEqual(structural, 'intrinsic')
def _transform(comp):
  """Returns a new transformed computation or `comp`.

  When `_should_transform(comp)` holds, the function
  `comp.argument[1].argument[0]` followed by the function `comp.argument[0]`
  are chained into a single function, and the intrinsic `comp.function` is
  called once on that function and `comp.argument[1].argument[1]`.

  Args:
    comp: The computation to transform.

  Returns:
    A pair of the resulting computation and a `bool` that is `True` only if
    a new computation was constructed.
  """
  if not _should_transform(comp):
    return comp, False

  def _create_block_to_chained_calls(comps):
    r"""Builds a block binding `comps` and a lambda calling them in order.

                     Block
                    /     \
          [fn=Tuple]       Lambda(arg)
              |                       \
     [Comp(y), Comp(x)]                Call
                                      /    \
                                Sel(1)      Call
                               /           /    \
                        Ref(fn)      Sel(0)      Ref(arg)
                                    /
                             Ref(fn)

    (let fn=<y, x> in (arg -> fn[1](fn[0](arg)))

    Args:
      comps: a Python list of computations.

    Returns:
      A `computation_building_blocks.Block`.
    """
    fn_tuple = computation_building_blocks.Tuple(comps)
    fn_ref = computation_building_blocks.Reference('fn',
                                                   fn_tuple.type_signature)
    parameter_ref = computation_building_blocks.Reference(
        'arg', comps[0].type_signature.parameter)
    # Thread the argument through the functions, first to last.
    chained = parameter_ref
    for index in range(len(comps)):
      selected_fn = computation_building_blocks.Selection(fn_ref, index=index)
      chained = computation_building_blocks.Call(selected_fn, chained)
    chained_lambda = computation_building_blocks.Lambda(
        parameter_ref.name, parameter_ref.type_signature, chained)
    return computation_building_blocks.Block([('fn', fn_tuple)],
                                             chained_lambda)

  inner_call = comp.argument[1]
  chained_fn = _create_block_to_chained_calls(
      (inner_call.argument[0], comp.argument[0]))
  merged_arg = computation_building_blocks.Tuple(
      [chained_fn, comp.argument[1].argument[1]])
  merged_intrinsic = computation_building_blocks.Intrinsic(
      comp.function.uri,
      computation_types.FunctionType(merged_arg.type_signature,
                                     comp.function.type_signature.result))
  return computation_building_blocks.Call(merged_intrinsic, merged_arg), True
def test_propogates_dependence_up_through_block_locals(self):
  """Checks a block binding the dummy intrinsic as a local is extracted."""
  intrinsic = computation_building_blocks.Intrinsic('dummy_intrinsic',
                                                    tf.int32)
  int_ref = computation_building_blocks.Reference('int', tf.int32)
  block_with_local = computation_building_blocks.Block([('x', intrinsic)],
                                                       int_ref)

  consumers = tree_analysis.extract_nodes_consuming(
      block_with_local, dummy_intrinsic_predicate)

  self.assertIn(block_with_local, consumers)
def test_propogates_dependence_up_through_tuple(self):
  """Checks a tuple containing the dummy intrinsic is extracted."""
  intrinsic = computation_building_blocks.Intrinsic('dummy_intrinsic',
                                                    tf.int32)
  int_ref = computation_building_blocks.Reference('int', tf.int32)
  pair = computation_building_blocks.Tuple([int_ref, intrinsic])

  consumers = tree_analysis.extract_nodes_consuming(
      pair, dummy_intrinsic_predicate)

  self.assertIn(pair, consumers)
def test_raises_with_federated_mean(self):
  """Checks the whitelist check rejects a `federated_mean` intrinsic."""
  clients_int = computation_types.FederatedType(tf.int32, placements.CLIENTS)
  server_int = computation_types.FederatedType(tf.int32, placements.SERVER)
  federated_mean = computation_building_blocks.Intrinsic(
      intrinsic_defs.FEDERATED_MEAN.uri,
      computation_types.FunctionType(clients_int, server_int))

  # The error is expected to match the intrinsic's compact representation.
  expected_regex = federated_mean.compact_representation()
  with self.assertRaisesRegex(ValueError, expected_regex):
    tree_analysis.check_intrinsics_whitelisted_for_reduction(federated_mean)
def federated_map(self, mapping_fn, value):
  """Implements `federated_map` as defined in `api/intrinsics.py`.

  A named-tuple argument with two or more elements is first zipped with
  `federated_zip`. The (possibly zipped) value must be placed at `CLIENTS`
  and its member type must be assignable to the parameter of `mapping_fn`.

  Args:
    mapping_fn: As in `api/intrinsics.py`.
    value: As in `api/intrinsics.py`.

  Returns:
    As in `api/intrinsics.py`.

  Raises:
    TypeError: As in `api/intrinsics.py`.
  """
  # TODO(b/113112108): Possibly lift the restriction that the mapped value
  # must be placed at the clients after adding support for placement labels
  # in the federated types, and expanding the type specification of the
  # intrinsic this is based on to work with federated values of arbitrary
  # placement.

  context_stack = self._context_stack
  value = value_impl.to_value(value, None, context_stack)
  needs_zip = (
      isinstance(value.type_signature, computation_types.NamedTupleType) and
      len(anonymous_tuple.to_elements(value.type_signature)) >= 2)
  if needs_zip:
    # We've been passed a value which the user expects to be zipped.
    value = self.federated_zip(value)
  type_utils.check_federated_value_placement(value, placements.CLIENTS,
                                             'value to be mapped')

  # TODO(b/113112108): Add support for polymorphic templates auto-instantiated
  # here based on the actual type of the argument.
  mapping_fn = value_impl.to_value(mapping_fn, None, context_stack)
  py_typecheck.check_type(mapping_fn, value_base.Value)
  py_typecheck.check_type(mapping_fn.type_signature,
                          computation_types.FunctionType)
  parameter_type = mapping_fn.type_signature.parameter
  member_type = value.type_signature.member
  if not type_utils.is_assignable_from(parameter_type, member_type):
    raise TypeError(
        'The mapping function expects a parameter of type {}, but member '
        'constituents of the mapped value are of incompatible type {}.'
        .format(
            str(mapping_fn.type_signature.parameter),
            str(value.type_signature.member)))

  # TODO(b/113112108): Replace this as noted above.
  mapped_type = computation_types.FederatedType(
      mapping_fn.type_signature.result, placements.CLIENTS,
      value.type_signature.all_equal)
  map_type = computation_types.FunctionType(
      [mapping_fn.type_signature, value.type_signature], mapped_type)
  map_building_block = computation_building_blocks.Intrinsic(
      intrinsic_defs.FEDERATED_MAP.uri, map_type)
  map_fn = value_impl.ValueImpl(map_building_block, context_stack)
  return map_fn(mapping_fn, value)
def create_federated_mean(value, weight):
  r"""Builds a call to the federated (weighted) mean intrinsic.

              Call
             /    \
    Intrinsic      Tuple
                   |
                   [Comp, Comp]

  When `weight` is `None` the `federated_mean` intrinsic is called directly
  on `value`; otherwise the `federated_weighted_mean` intrinsic is called on
  the tuple `<value, weight>`.

  Args:
    value: A `computation_building_blocks.ComputationBuildingBlock` to use as
      the value.
    weight: A `computation_building_blocks.ComputationBuildingBlock` to use as
      the weight or `None`.

  Returns:
    A `computation_building_blocks.Call` whose declared result type is the
    member type of `value` placed at `SERVER` with the `all_equal` bit set.

  Raises:
    TypeError: If any of the types do not match.
  """
  building_block_cls = computation_building_blocks.ComputationBuildingBlock
  py_typecheck.check_type(value, building_block_cls)
  if weight is not None:
    py_typecheck.check_type(weight, building_block_cls)
  mean_type = computation_types.FederatedType(value.type_signature.member,
                                              placement_literals.SERVER, True)

  if weight is None:
    unweighted_mean = computation_building_blocks.Intrinsic(
        intrinsic_defs.FEDERATED_MEAN.uri,
        computation_types.FunctionType(value.type_signature, mean_type))
    return computation_building_blocks.Call(unweighted_mean, value)

  weighted_mean = computation_building_blocks.Intrinsic(
      intrinsic_defs.FEDERATED_WEIGHTED_MEAN.uri,
      computation_types.FunctionType(
          (value.type_signature, weight.type_signature), mean_type))
  return computation_building_blocks.Call(
      weighted_mean, computation_building_blocks.Tuple((value, weight)))
def test_propogates_dependence_up_through_call(self):
  """Checks a call whose argument is the dummy intrinsic is extracted."""
  intrinsic = computation_building_blocks.Intrinsic('dummy_intrinsic',
                                                    tf.int32)
  x_ref = computation_building_blocks.Reference('x', tf.int32)
  identity = computation_building_blocks.Lambda('x', tf.int32, x_ref)
  call = computation_building_blocks.Call(identity, intrinsic)

  consumers = tree_analysis.extract_nodes_consuming(
      call, dummy_intrinsic_predicate)

  self.assertIn(call, consumers)
def _make_sequence_sum_for(type_spec):
  """Builds a `sequence_sum` intrinsic value for the sequence `type_spec`.

  Args:
    type_spec: A `computation_types.SequenceType` whose element type must be
      compatible with the sum operator.

  Returns:
    A `value_impl.ValueImpl` wrapping the `sequence_sum` intrinsic that maps
    `type_spec` to its element type.

  Raises:
    TypeError: If the element type is not compatible with the sum operator.
  """
  py_typecheck.check_type(type_spec, computation_types.SequenceType)
  element_type = type_spec.element
  if not type_utils.is_sum_compatible(element_type):
    message = 'The value type {} is not compatible with the sum operator.'
    raise TypeError(message.format(str(type_spec)))
  sum_type = computation_types.FunctionType(type_spec, type_spec.element)
  sum_building_block = computation_building_blocks.Intrinsic(
      intrinsic_defs.SEQUENCE_SUM.uri, sum_type)
  return value_impl.ValueImpl(sum_building_block, self._context_stack)