def test_with_block(self):
  """Evaluates `(a -> let f=a.f, x=a.x in f(f(x)))` on <add_one, 10>."""
  executor = lambda_executor.LambdaExecutor(eager_executor.EagerExecutor())
  event_loop = asyncio.get_event_loop()
  run = event_loop.run_until_complete
  int_to_int = computation_types.FunctionType(tf.int32, tf.int32)
  param = computation_building_blocks.Reference(
      'a',
      computation_types.NamedTupleType([('f', int_to_int), ('x', tf.int32)]))
  bound_locals = [
      ('f', computation_building_blocks.Selection(param, name='f')),
      ('x', computation_building_blocks.Selection(param, name='x')),
  ]
  # Apply the locally bound `f` twice to the locally bound `x`.
  twice_applied = computation_building_blocks.Call(
      computation_building_blocks.Reference('f', int_to_int),
      computation_building_blocks.Call(
          computation_building_blocks.Reference('f', int_to_int),
          computation_building_blocks.Reference('x', tf.int32)))
  body = computation_building_blocks.Block(bound_locals, twice_applied)
  comp = computation_building_blocks.Lambda(param.name, param.type_signature,
                                            body)

  @computations.tf_computation(tf.int32)
  def add_one(x):
    return x + 1

  fn_value = run(executor.create_value(comp.proto, comp.type_signature))
  add_one_value = run(executor.create_value(add_one))
  ten_value = run(executor.create_value(10, tf.int32))
  arg_value = run(
      executor.create_tuple(
          anonymous_tuple.AnonymousTuple([('f', add_one_value),
                                          ('x', ten_value)])))
  call_value = run(executor.create_call(fn_value, arg_value))
  result = run(call_value.compute())
  # add_one(add_one(10)) == 12.
  self.assertEqual(result.numpy(), 12)
def test_basic_functionality_of_lambda_class(self):
  """Checks type, reprs and proto of `(arg -> arg.f(arg.f(arg.x)))`.

  The parameter `arg` has type `<f=(int32 -> int32),x=int32>`; the test checks
  the type signature, parameter accessors, `repr`, `tff_repr`, the serialized
  proto and a serialize/deserialize round trip.
  """
  arg_name = 'arg'
  arg_type = [('f', computation_types.FunctionType(tf.int32, tf.int32)),
              ('x', tf.int32)]
  arg = computation_building_blocks.Reference(arg_name, arg_type)
  arg_f = computation_building_blocks.Selection(arg, name='f')
  arg_x = computation_building_blocks.Selection(arg, name='x')
  x = computation_building_blocks.Lambda(
      arg_name, arg_type,
      computation_building_blocks.Call(
          arg_f, computation_building_blocks.Call(arg_f, arg_x)))
  self.assertEqual(
      str(x.type_signature), '(<f=(int32 -> int32),x=int32> -> int32)')
  self.assertEqual(x.parameter_name, arg_name)
  self.assertEqual(str(x.parameter_type), '<f=(int32 -> int32),x=int32>')
  self.assertEqual(x.result.tff_repr, 'arg.f(arg.f(arg.x))')
  # Expected repr of the parameter type; substituted for every `{0}` below.
  arg_type_repr = (
      'NamedTupleType(['
      '(\'f\', FunctionType(TensorType(tf.int32), TensorType(tf.int32))), '
      '(\'x\', TensorType(tf.int32))])')
  # The adjacent literals are concatenated first, then `.format` fills `{0}`.
  self.assertEqual(
      repr(x), 'Lambda(\'arg\', {0}, '
      'Call(Selection(Reference(\'arg\', {0}), name=\'f\'), '
      'Call(Selection(Reference(\'arg\', {0}), name=\'f\'), '
      'Selection(Reference(\'arg\', {0}), name=\'x\'))))'.format(
          arg_type_repr))
  self.assertEqual(x.tff_repr, '(arg -> arg.f(arg.f(arg.x)))')
  # The serialized form must agree with the in-memory building block.
  x_proto = x.proto
  self.assertEqual(
      type_serialization.deserialize_type(x_proto.type), x.type_signature)
  self.assertEqual(x_proto.WhichOneof('computation'), 'lambda')
  # `lambda` is a Python keyword, so the proto field is read via getattr().
  self.assertEqual(getattr(x_proto, 'lambda').parameter_name, arg_name)
  self.assertEqual(
      str(getattr(x_proto, 'lambda').result), str(x.result.proto))
  self._serialize_deserialize_roundtrip_test(x)
def construct_binary_operator_with_upcast(type_signature, operator):
  """Builds a lambda that upcasts its second argument and applies `operator`.

  Upcasting is described in more detail in the docstring of
  `apply_binary_operator_with_upcast`. The constructed function may serve,
  e.g., as the body of an intrinsic, so it must be reducible to TensorFlow.
  Consequently `type_signature` may only contain named tuples and tensors;
  federated types cannot be handled generically here.

  Args:
    type_signature: Value convertible to `computation_types.NamedTupleType`
      with two elements. Both elements have the same type, or the second can be
      upcast to the first as explained in `apply_binary_operator_with_upcast`.
      Both contain only tuples and tensors in their type tree.
    operator: Callable defining the operator.

  Returns:
    A `computation_building_blocks.Lambda` whose body upcasts the second
    element of its argument (if needed) and applies the binary operator.
  """
  py_typecheck.check_callable(operator)
  type_signature = computation_types.to_type(type_signature)
  _check_generic_operator_type(type_signature)
  arg_ref = computation_building_blocks.Reference('binary_operator_arg',
                                                  type_signature)

  def _pack_into_type(to_pack, type_spec):
    """Broadcasts tensor `to_pack` into every leaf of the structure `type_spec`."""
    if isinstance(type_spec, computation_types.NamedTupleType):
      return computation_building_blocks.Tuple([
          (elem_name, _pack_into_type(to_pack, elem_type))
          for elem_name, elem_type in anonymous_tuple.to_elements(type_spec)
      ])
    if isinstance(type_spec, computation_types.TensorType):
      broadcast_fn = (
          computation_constructing_utils
          .construct_tensorflow_to_broadcast_scalar(
              to_pack.type_signature.dtype, type_spec.shape))
      return computation_building_blocks.Call(broadcast_fn, to_pack)
    # Any other type falls through to None.
    return None

  rhs = computation_building_blocks.Selection(arg_ref, index=1)
  lhs = computation_building_blocks.Selection(arg_ref, index=0)
  # Only broadcast the right-hand side when its type differs from the left.
  if not type_utils.are_equivalent_types(lhs.type_signature,
                                         rhs.type_signature):
    rhs = _pack_into_type(rhs, lhs.type_signature)
  binary_fn = (
      computation_constructing_utils.construct_tensorflow_binary_operator(
          lhs.type_signature, operator))
  operands = computation_building_blocks.Tuple([lhs, rhs])
  return computation_building_blocks.Lambda(
      arg_ref.name, arg_ref.type_signature,
      computation_building_blocks.Call(binary_fn, operands))
def _create_chain_zipped_values(value):
  r"""Creates a chain of called federated zip with two values.

                    Block--------
                   /             \
  [value=Tuple]                   Call
         |                       /    \
  [Comp1,                Intrinsic    Tuple
   Comp2,                              |
   ...]                         [Call,  Sel(n)]
                                /    \        \
                        Intrinsic    Tuple    Ref(value)
                                     |
                             [Sel(0),       Sel(1)]
                                    \             \
                                    Ref(value)    Ref(value)

  NOTE: This function is intended to be used in conjunction with
  `_create_fn_to_append_chain_zipped_values` and will drop the tuple names. The
  names will be added back to the resulting computation when the zipped values
  are mapped to a function that flattens the chain. This nested zip -> flatten
  structure must be used since length of a named tuple type in the TFF type
  system is an element of the type proper. That is, a named tuple type of
  length 2 is a different type than a named tuple type of length 3, they are
  not simply items with the same type and different values, as would be the
  case if you were thinking of these as Python `list`s. It may be better to
  think of named tuple types in TFF as more like `struct`s.

  Args:
    value: A `computation_building_blocks.ComputationBuildingBlock` with a
      `type_signature` of type `computation_types.NamedTupleType` containing at
      least two elements.

  Returns:
    A `computation_building_blocks.Block` that binds `value` to the local name
    `value` and whose result is the value returned by the last call to
    `_create_zip_two_values`.

  Raises:
    TypeError: If any of the types do not match.
    ValueError: If `value` does not contain at least two elements.
  """
  py_typecheck.check_type(value,
                          computation_building_blocks.ComputationBuildingBlock)
  named_type_signatures = anonymous_tuple.to_elements(value.type_signature)
  length = len(named_type_signatures)
  if length < 2:
    # Report the element count, as the message promises, not the element list.
    raise ValueError(
        'Expected a value with at least two elements, received {} elements.'
        .format(length))
  ref = computation_building_blocks.Reference('value', value.type_signature)
  symbols = ((ref.name, value),)
  # Left fold: zip(zip(zip(v0, v1), v2), ...). Names are intentionally dropped.
  result = computation_building_blocks.Selection(ref, index=0)
  for i in range(1, length):
    sel = computation_building_blocks.Selection(ref, index=i)
    values = computation_building_blocks.Tuple((result, sel))
    result = _create_zip_two_values(values)
  return computation_building_blocks.Block(symbols, result)
def _create_chain_zipped_values(value):
  r"""Creates a chain of called federated zip with two values.

                    Block--------
                   /             \
  [value=Tuple]                   Call
         |                       /    \
  [Comp1,                Intrinsic    Tuple
   Comp2,                              |
   ...]                         [Call,  Sel(n)]
                                /    \        \
                        Intrinsic    Tuple    Ref(value)
                                     |
                             [Sel(0),       Sel(1)]
                                    \             \
                                    Ref(value)    Ref(value)

  NOTE: This function is intended to be used in conjunction with
  `_create_fn_to_append_chain_zipped_values`. The element names of `value` are
  attached to each pair handed to `_create_zip_two_values`. After the first zip,
  the accumulated left operand carries no name.
  NOTE(review): whether the zipped result keeps those names depends on
  `_create_zip_two_values`; confirm before relying on it.

  Args:
    value: A `computation_building_blocks.ComputationBuildingBlock` with a
      `type_signature` of type `computation_types.NamedTupleType` containing at
      least two elements.

  Returns:
    A `computation_building_blocks.Block` that binds `value` to the local name
    `value` and whose result is the value returned by the last call to
    `_create_zip_two_values`.

  Raises:
    TypeError: If any of the types do not match.
    ValueError: If `value` does not contain at least two elements.
  """
  py_typecheck.check_type(value,
                          computation_building_blocks.ComputationBuildingBlock)
  named_type_signatures = anonymous_tuple.to_elements(value.type_signature)
  length = len(named_type_signatures)
  if length < 2:
    # Report the element count, as the message promises, not the element list.
    raise ValueError(
        'Expected a value with at least two elements, received {} elements.'
        .format(length))
  first_name, _ = named_type_signatures[0]
  ref = computation_building_blocks.Reference('value', value.type_signature)
  symbols = ((ref.name, value),)
  # On the first iteration the left operand is a (name, Selection) pair; later
  # it is the unnamed result of the previous zip.
  result = (first_name, computation_building_blocks.Selection(ref, index=0))
  for i, (name, _) in enumerate(named_type_signatures[1:], start=1):
    sel = computation_building_blocks.Selection(ref, index=i)
    values = computation_building_blocks.Tuple((result, (name, sel)))
    result = _create_zip_two_values(values)
  return computation_building_blocks.Block(symbols, result)
def create_computation_appending(comp1, comp2):
  r"""Builds a block whose result is `comp1`'s elements followed by `comp2`.

                Block
               /     \
  [comps=Tuple]       Tuple
        |             |
  [Comp, Comp]        [Sel(0), ..., Sel(0), Sel(1)]
                         \             \       \
                          Sel(0)        Sel(n)  Ref(comps)
                             \             \
                              Ref(comps)    Ref(comps)

  Args:
    comp1: A `computation_building_blocks.ComputationBuildingBlock` whose
      `type_signature` is a `computation_type.NamedTupleType`.
    comp2: Either a `computation_building_blocks.ComputationBuildingBlock`, or
      a `(name, computation)` pair describing one element of an
      `anonymous_tuple.AnonymousTuple`.

  Returns:
    A `computation_building_blocks.Block`.

  Raises:
    TypeError: If any of the types do not match.
  """
  py_typecheck.check_type(comp1,
                          computation_building_blocks.ComputationBuildingBlock)
  if isinstance(comp2, computation_building_blocks.ComputationBuildingBlock):
    appended_name = None
  elif py_typecheck.is_name_value_pair(
      comp2,
      name_required=False,
      value_type=computation_building_blocks.ComputationBuildingBlock):
    appended_name, comp2 = comp2
  else:
    raise TypeError('Unexpected tuple element: {}.'.format(comp2))

  pair = computation_building_blocks.Tuple((comp1, comp2))
  pair_ref = computation_building_blocks.Reference('comps',
                                                   pair.type_signature)
  first = computation_building_blocks.Selection(pair_ref, index=0)
  # Re-select every element of `comp1`, keeping its original name.
  new_elements = [
      (elem_name, computation_building_blocks.Selection(first, index=position))
      for position, (elem_name, _) in enumerate(
          anonymous_tuple.to_elements(comp1.type_signature))
  ]
  new_elements.append(
      (appended_name, computation_building_blocks.Selection(pair_ref,
                                                            index=1)))
  return computation_building_blocks.Block(
      ((pair_ref.name, pair),), computation_building_blocks.Tuple(new_elements))
def _transform_functional_args(comps):
  r"""Combines the functional computations `comps` into one function.

  Given a computation that contains `n` called intrinsics with `m` arguments
  each, this builds the following from their functional arguments:

              Block
             /     \
   [fn=Tuple]       Lambda(arg)
       |                \
   [Comp(f1), Comp(f2), ...]   Tuple
                               |
                          [Call,                  Call, ...]
                          /    \                  /    \
                    Sel(0)      Sel(0)      Sel(1)      Sel(1)
                    /           /           /           /
               Ref(fn)    Ref(arg)     Ref(fn)     Ref(arg)

  There is one `computation_building_blocks.Call` per function. The local
  names shown as `fn` and `arg` are drawn from `name_generator`. The result
  represents one of the `m` arguments to pass to the transformed computation.

  Args:
    comps: A Python list of computations.

  Returns:
    A `computation_building_blocks.Block`.
  """
  functions = computation_building_blocks.Tuple(comps)
  # Two fresh names are drawn in order: first for the functions, then the arg.
  functions_ref = computation_building_blocks.Reference(
      six.next(name_generator), functions.type_signature)
  arg_ref = computation_building_blocks.Reference(
      six.next(name_generator),
      [fn_comp.type_signature.parameter for fn_comp in comps])
  calls = computation_building_blocks.Tuple([
      computation_building_blocks.Call(
          computation_building_blocks.Selection(functions_ref, index=position),
          computation_building_blocks.Selection(arg_ref, index=position))
      for position, _ in enumerate(comps)
  ])
  lam = computation_building_blocks.Lambda(arg_ref.name,
                                           arg_ref.type_signature, calls)
  return computation_building_blocks.Block(((functions_ref.name, functions),),
                                           lam)
def __setattr__(self, name, value):
  """Rebuilds the wrapped named tuple with element `name` replaced by `value`.

  Args:
    name: Name of an existing element of the named tuple.
    value: New value for that element; converted via `to_value` to the
      element's type.

  Raises:
    TypeError: If the wrapped value is not a named tuple, or if `value` cannot
      be converted to the element's type.
    AttributeError: If the tuple has no element called `name`.
  """
  py_typecheck.check_type(name, six.string_types)
  tuple_type = self._comp.type_signature
  if not isinstance(tuple_type, computation_types.NamedTupleType):
    raise TypeError(
        'Operator setattr() is only supported for named tuples, but the '
        'object on which it has been invoked is of type {}.'.format(
            str(tuple_type)))
  if name not in dir(tuple_type):
    raise AttributeError(
        'There is no such attribute as \'{}\' in this tuple. '
        'TFF does not allow for assigning to a nonexistent attribute. '
        'If you want to assign to \'{}\', you must create a new named tuple '
        'containing this attribute.'.format(name, name))
  new_elements = []
  for elem_name, elem_type in anonymous_tuple.to_elements(tuple_type):
    if elem_name != name:
      # Untouched elements are re-selected from the current computation.
      new_elements.append(
          (elem_name,
           computation_building_blocks.Selection(self._comp, name=elem_name)))
      continue
    try:
      value = to_value(value, elem_type, self._context_stack)
    except TypeError:
      raise TypeError(
          'Setattr has attempted to set element {} of type {} '
          'with incompatible item {}.'.format(elem_name, elem_type, value))
    new_elements.append((elem_name, ValueImpl.get_comp(value)))
  # Bypass this override so `_comp` itself is stored on the instance.
  super(ValueImpl, self).__setattr__(
      '_comp', computation_building_blocks.Tuple(new_elements))
def test_returns_string_for_comp_with_right_overhang(self):
  """Checks compact, formatted and structural strings of a called lambda.

  The computation is `(a -> <a,data,data,data,data>[0])(data)`.
  """
  ref = computation_building_blocks.Reference('a', tf.int32)
  data = computation_building_blocks.Data('data', tf.int32)
  tup = computation_building_blocks.Tuple([ref, data, data, data, data])
  sel = computation_building_blocks.Selection(tup, index=0)
  fn = computation_building_blocks.Lambda(ref.name, ref.type_signature, sel)
  comp = computation_building_blocks.Call(fn, data)
  compact_string = comp.compact_representation()
  self.assertEqual(compact_string, '(a -> <a,data,data,data,data>[0])(data)')
  formatted_string = comp.formatted_representation()
  # NOTE(review): the leading whitespace inside these expected literals may
  # have been collapsed in this copy of the file; confirm against the renderer.
  # pyformat: disable
  self.assertEqual(
      formatted_string,
      '(a -> <\n'
      ' a,\n'
      ' data,\n'
      ' data,\n'
      ' data,\n'
      ' data\n'
      '>[0])(data)')
  # pyformat: enable
  structural_string = comp.structural_representation()
  # pyformat: disable
  self.assertEqual(
      structural_string,
      ' Call\n'
      ' / \\\n'
      'Lambda(a) data\n'
      '|\n'
      'Sel(0)\n'
      '|\n'
      'Tuple\n'
      '|\n'
      '[Ref(a), data, data, data, data]')
  # pyformat: enable
def test_propogates_dependence_up_through_selection(self):
  """A selection from a dummy intrinsic is reported as consuming it."""
  intrinsic = computation_building_blocks.Intrinsic('dummy_intrinsic',
                                                    [tf.int32])
  selection = computation_building_blocks.Selection(intrinsic, index=0)
  consuming_nodes = tree_analysis.extract_nodes_consuming(
      selection, dummy_intrinsic_predicate)
  self.assertIn(selection, consuming_nodes)
def test_replace_chained_federated_maps_does_not_replace_unchained_federated_maps(
    self):
  """Maps separated by a tuple/selection must not be merged."""
  arg_type = computation_types.FederatedType(tf.int32, placements.CLIENTS)
  arg_ref = computation_building_blocks.Reference('arg', arg_type)
  first_call = _create_call_to_federated_map(
      _create_lambda_to_add_one(arg_ref.type_signature.member), arg_ref)
  # Routing through a tuple and a selection breaks the direct map(map(...))
  # chain the transformation looks for.
  indirection = computation_building_blocks.Selection(
      computation_building_blocks.Tuple([first_call]), index=0)
  second_call = _create_call_to_federated_map(
      _create_lambda_to_add_one(
          first_call.function.type_signature.result.member), indirection)
  comp = computation_building_blocks.Lambda(arg_ref.name,
                                            arg_ref.type_signature,
                                            second_call)
  uri = intrinsic_defs.FEDERATED_MAP.uri

  self.assertEqual(_get_number_of_intrinsics(comp, uri), 2)
  self.assertEqual(_to_comp(comp)([1]), [3])

  transformed = (
      transformations.replace_chained_federated_maps_with_federated_map(comp))
  self.assertEqual(_get_number_of_intrinsics(transformed, uri), 2)
  self.assertEqual(_to_comp(transformed)([1]), [3])
def construct_federated_getattr_comp(comp, name):
  """Builds the function that `federated_apply` uses to implement `__getattr__`.

  The returned lambda takes one argument of type `comp.type_signature.member`,
  which must be a `computation_types.NamedTupleType`, and selects the element
  called `name` from it.

  Args:
    comp: Instance of `ValueImpl` or
      `computation_building_blocks.ComputationBuildingBlock` whose type
      signature is a `computation_types.FederatedType` with a
      `computation_types.NamedTupleType` member.
    name: String name of the element to select.

  Returns:
    Instance of `computation_building_blocks.Lambda` that selects `name` from
    its argument.

  Raises:
    ValueError: If the member type has no element called `name`.
  """
  py_typecheck.check_type(comp.type_signature, computation_types.FederatedType)
  member_type = comp.type_signature.member
  py_typecheck.check_type(member_type, computation_types.NamedTupleType)
  has_element = any(elem_name == name
                    for elem_name, _ in anonymous_tuple.to_elements(member_type))
  if not has_element:
    raise ValueError('The federated value {} has no element of name {}'.format(
        comp, name))
  param = computation_building_blocks.Reference('x', member_type)
  return computation_building_blocks.Lambda(
      'x', param.type_signature,
      computation_building_blocks.Selection(param, name=name))
def __getattr__(self, name):
  """Selects element `name` from a (possibly federated) named tuple.

  Raises:
    TypeError: If the wrapped value is neither a named tuple nor a federated
      named tuple.
    AttributeError: If the named tuple has no element called `name`.
  """
  py_typecheck.check_type(name, six.string_types)
  comp_type = self._comp.type_signature
  is_federated_tuple = (
      isinstance(comp_type, computation_types.FederatedType) and
      isinstance(comp_type.member, computation_types.NamedTupleType))
  if is_federated_tuple:
    # Federated tuples select inside the placement via a dedicated helper.
    return ValueImpl(
        computation_constructing_utils.construct_federated_getattr_call(
            self._comp, name), self._context_stack)
  if not isinstance(comp_type, computation_types.NamedTupleType):
    raise TypeError(
        'Operator getattr() is only supported for named tuples, but the '
        'object on which it has been invoked is of type {}.'.format(
            str(comp_type)))
  if name not in dir(comp_type):
    raise AttributeError(
        'There is no such attribute as \'{}\' in this tuple.'.format(name))
  if isinstance(self._comp, computation_building_blocks.Tuple):
    # A literal tuple can hand back the element directly.
    return ValueImpl(getattr(self._comp, name), self._context_stack)
  return ValueImpl(
      computation_building_blocks.Selection(self._comp, name=name),
      self._context_stack)
def __getitem__(self, key):
  """Indexes or slices a (possibly federated) named tuple.

  Args:
    key: A non-negative `int` index, or a `slice` selecting at least one
      element.

  Raises:
    TypeError: If the wrapped value is neither a named tuple nor a federated
      named tuple.
    IndexError: If the index is out of range or the slice selects nothing.
  """
  py_typecheck.check_type(key, (int, slice))
  comp_type = self._comp.type_signature
  if (isinstance(comp_type, computation_types.FederatedType) and
      isinstance(comp_type.member, computation_types.NamedTupleType)):
    return ValueImpl(
        computation_constructing_utils.construct_federated_getitem_call(
            self._comp, key), self._context_stack)
  if not isinstance(comp_type, computation_types.NamedTupleType):
    raise TypeError(
        'Operator getitem() is only supported for named tuples, but the '
        'object on which it has been invoked is of type {}.'.format(
            str(comp_type)))
  num_elements = len(comp_type)
  if isinstance(key, slice):
    positions = range(*key.indices(num_elements))
    if not positions:
      raise IndexError('Attempted to slice 0 elements, which is not '
                       'currently supported.')
    return to_value([self[pos] for pos in positions], None,
                    self._context_stack)
  # Integer key: negative indices are deliberately rejected.
  if not 0 <= key < num_elements:
    raise IndexError(
        'The index of the selected element {} is out of range.'.format(key))
  if isinstance(self._comp, computation_building_blocks.Tuple):
    return ValueImpl(self._comp[key], self._context_stack)
  return ValueImpl(
      computation_building_blocks.Selection(self._comp, index=key),
      self._context_stack)
def test_basic_functionality_of_block_class(self):
  """Checks type, locals, reprs and proto of `(let x=arg,y=x[0] in y)`."""
  x = computation_building_blocks.Block([
      ('x', computation_building_blocks.Reference('arg', (tf.int32, tf.int32))),
      ('y', computation_building_blocks.Selection(
          computation_building_blocks.Reference('x', (tf.int32, tf.int32)),
          index=0))
  ], computation_building_blocks.Reference('y', tf.int32))
  self.assertEqual(str(x.type_signature), 'int32')
  self.assertEqual([(k, v.tff_repr) for k, v in x.locals],
                   [('x', 'arg'), ('y', 'x[0]')])
  self.assertEqual(x.result.tff_repr, 'y')
  self.assertEqual(
      repr(x), 'Block([(\'x\', Reference(\'arg\', '
      'NamedTupleType([TensorType(tf.int32), TensorType(tf.int32)]))), '
      '(\'y\', Selection(Reference(\'x\', '
      'NamedTupleType([TensorType(tf.int32), TensorType(tf.int32)])), '
      'index=0))], '
      'Reference(\'y\', TensorType(tf.int32)))')
  self.assertEqual(x.tff_repr, '(let x=arg,y=x[0] in y)')
  # The serialized form must agree with the in-memory building block.
  x_proto = x.proto
  self.assertEqual(type_serialization.deserialize_type(x_proto.type),
                   x.type_signature)
  self.assertEqual(x_proto.WhichOneof('computation'), 'block')
  self.assertEqual(str(x_proto.block.result), str(x.result.proto))
  # Each serialized local must match the corresponding (name, value) pair.
  for idx, loc_proto in enumerate(x_proto.block.local):
    loc_name, loc_value = x.locals[idx]
    self.assertEqual(loc_proto.name, loc_name)
    self.assertEqual(str(loc_proto.value), str(loc_value.proto))
  self._serialize_deserialize_roundtrip_test(x)
def _create_block_to_calls(call_names, comps):
  r"""Builds a block that applies each of `comps` to its own argument.

  Given the original computation with `n` called intrinsics taking `m`
  arguments, this constructs:

              Block
             /     \
     fn=Tuple       Lambda(arg)
       |                \
   [Comp(f1), Comp(f2), ...]   Tuple
                               |
                          [Call,                  Call, ...]
                          /    \                  /    \
                    Sel(0)      Sel(0)      Sel(1)      Sel(1)
                    /           /           /           /
               Ref(fn)    Ref(arg)     Ref(fn)     Ref(arg)

  There is one `computation_building_blocks.Call` per function, each named
  after the corresponding entry of `call_names`. The result represents one of
  the `m` arguments to pass to the transformed computation.

  Args:
    call_names: A Python list of names.
    comps: A Python list of computations.

  Returns:
    A `computation_building_blocks.Block`.
  """
  functions = computation_building_blocks.Tuple(zip(call_names, comps))
  fn_ref = computation_building_blocks.Reference('fn',
                                                 functions.type_signature)
  arg_ref = computation_building_blocks.Reference(
      'arg', [fn_comp.type_signature.parameter for fn_comp in comps])
  named_calls = [
      (call_name,
       computation_building_blocks.Call(
           computation_building_blocks.Selection(fn_ref, index=position),
           computation_building_blocks.Selection(arg_ref, index=position)))
      for position, call_name in enumerate(call_names)
  ]
  lam = computation_building_blocks.Lambda(
      arg_ref.name, arg_ref.type_signature,
      computation_building_blocks.Tuple(named_calls))
  return computation_building_blocks.Block([('fn', functions)], lam)
def _traverse_selection(comp, transform, context_tree, identifier_seq):
  """Post-order traversal step for a `Selection` node.

  Consumes one identifier for this node, transforms the selection's source,
  rebuilds the selection around it, then applies `transform` to the result.
  """
  six.next(identifier_seq)
  new_source = _transform_postorder_with_symbol_bindings_switch(
      comp.source, transform, context_tree, identifier_seq)
  rebuilt = computation_building_blocks.Selection(new_source, comp.name,
                                                  comp.index)
  return transform(rebuilt, context_tree)
def construct_federated_getitem_comp(comp, key):
  """Builds the function that `federated_apply` uses to implement `__getitem__`.

  The returned lambda takes one argument of type `comp.type_signature.member`,
  which must be a `computation_types.NamedTupleType`. It selects the element
  at index `key`, or, for a slice, builds a tuple of the selected elements that
  keeps their names.

  Args:
    comp: Instance of `computation_building_blocks.ComputationBuildingBlock`
      whose type signature is a `computation_types.FederatedType` with a
      `computation_types.NamedTupleType` member.
    key: An `int` index or a `slice` into the member of `comp`.

  Returns:
    Instance of `computation_building_blocks.Lambda` that selects `key` from
    its argument.
  """
  py_typecheck.check_type(comp,
                          computation_building_blocks.ComputationBuildingBlock)
  py_typecheck.check_type(comp.type_signature, computation_types.FederatedType)
  member_type = comp.type_signature.member
  py_typecheck.check_type(member_type, computation_types.NamedTupleType)
  py_typecheck.check_type(key, (int, slice))
  param = computation_building_blocks.Reference('x', member_type)
  if isinstance(key, slice):
    named_elements = anonymous_tuple.to_elements(member_type)
    selected = computation_building_blocks.Tuple([
        (named_elements[pos][0],
         computation_building_blocks.Selection(param, index=pos))
        for pos in range(*key.indices(len(named_elements)))
    ])
  else:
    selected = computation_building_blocks.Selection(param, index=key)
  return computation_building_blocks.Lambda('x', param.type_signature,
                                            selected)
def _traverse_selection(comp, transform, context_tree, identifier_seq):
  """Post-order traversal step for a `Selection` node.

  Returns:
    A `(computation, modified)` pair. The selection is rebuilt only if its
    source changed, and `modified` is true if either the source or this node
    was transformed.
  """
  six.next(identifier_seq)
  source, source_modified = _transform_postorder_with_symbol_bindings_switch(
      comp.source, transform, context_tree, identifier_seq)
  candidate = (
      computation_building_blocks.Selection(source, comp.name, comp.index)
      if source_modified else comp)
  result, result_modified = transform(candidate, context_tree)
  return result, result_modified or source_modified
def create_federated_unzip(value):
  r"""Creates a tuple of called federated maps or applies.

                Block
               /     \
  [value=Comp]        Tuple
                      |
                      [Call,                        Call, ...]
                      /    \                        /    \
             Intrinsic      Tuple          Intrinsic      Tuple
                            |                             |
                [Lambda(arg), Ref(value)]     [Lambda(arg), Ref(value)]
                            \                             \
                             Sel(0)                        Sel(1)
                                  \                             \
                                   Ref(arg)                      Ref(arg)

  This function returns a tuple of federated values given a `value` with a
  federated tuple type signature.

  Args:
    value: A `computation_building_blocks.ComputationBuildingBlock` whose
      `type_signature` is a federated type (the code reads its `.member`).
      That member is a `computation_types.NamedTupleType` containing at least
      one element.

  Returns:
    A `computation_building_blocks.Block`.

  Raises:
    TypeError: If any of the types do not match.
    ValueError: If `value` does not contain any elements.
  """
  py_typecheck.check_type(value,
                          computation_building_blocks.ComputationBuildingBlock)
  named_type_signatures = anonymous_tuple.to_elements(
      value.type_signature.member)
  if not named_type_signatures:
    # Name the function actually called (previously said `federated_zip`).
    raise ValueError('federated_unzip is only supported on non-empty tuples.')
  value_ref = computation_building_blocks.Reference('value',
                                                    value.type_signature)
  arg_ref = computation_building_blocks.Reference('arg', named_type_signatures)
  elements = []
  for index, (name, _) in enumerate(named_type_signatures):
    # One map/apply per element, each selecting that element from the member.
    sel = computation_building_blocks.Selection(arg_ref, index=index)
    fn = computation_building_blocks.Lambda(arg_ref.name,
                                            arg_ref.type_signature, sel)
    elements.append((name, create_federated_map_or_apply(fn, value_ref)))
  result = computation_building_blocks.Tuple(elements)
  symbols = ((value_ref.name, value),)
  return computation_building_blocks.Block(symbols, result)
def test_returns_string_for_selection_with_index(self):
  """All three string renderings of `a[0]` are as expected."""
  tuple_ref = computation_building_blocks.Reference('a', (('b', tf.int32),
                                                          ('c', tf.bool)))
  selection = computation_building_blocks.Selection(tuple_ref, index=0)
  self.assertEqual(selection.compact_representation(), 'a[0]')
  self.assertEqual(selection.formatted_representation(), 'a[0]')
  # pyformat: disable
  self.assertEqual(
      selection.structural_representation(),
      'Sel(0)\n'
      '|\n'
      'Ref(a)')
  # pyformat: enable
def create_federated_zip(value):
  r"""Creates a called federated zip.

            Call
           /    \
  Intrinsic      Tuple
                 |
                 [Comp, Comp]

  Turns a `value` whose type is a tuple of federated values into a single
  federated tuple. The placement of the first element selects the map
  function: `create_federated_map` for CLIENTS, `create_federated_apply` for
  SERVER.

  Args:
    value: A `computation_building_blocks.ComputationBuildingBlock` with a
      `type_signature` of type `computation_types.NamedTupleType` containing at
      least one element.

  Returns:
    A `computation_building_blocks.Call`.

  Raises:
    TypeError: If any of the types do not match.
    ValueError: If `value` does not contain any elements.
  """
  py_typecheck.check_type(value,
                          computation_building_blocks.ComputationBuildingBlock)
  named_type_signatures = anonymous_tuple.to_elements(value.type_signature)
  if not named_type_signatures:
    raise ValueError('federated_zip is only supported on non-empty tuples.')
  first_name, first_type = named_type_signatures[0]
  placement = first_type.placement
  if placement == placement_literals.CLIENTS:
    map_fn = create_federated_map
  elif placement == placement_literals.SERVER:
    map_fn = create_federated_apply
  else:
    raise TypeError('Unsupported placement {}.'.format(placement))

  if len(named_type_signatures) > 1:
    # Zip pairwise (names dropped), flatten, then restore the names.
    zipped = _create_chain_zipped_values(value)
    append_fn = _create_fn_to_append_chain_zipped_values(value)
    return construct_named_federated_tuple(
        map_fn(append_fn, zipped), [name for name, _ in named_type_signatures])

  # A single element only needs wrapping in a one-element named tuple.
  member_ref = computation_building_blocks.Reference('arg', first_type.member)
  wrap_fn = computation_building_blocks.Lambda(
      member_ref.name, member_ref.type_signature,
      computation_building_blocks.Tuple(((first_name, member_ref),)))
  return map_fn(wrap_fn,
                computation_building_blocks.Selection(value, index=0))
def test_intrinsic_construction_clients(self):
  """A CLIENTS-placed argument is handled with `federated_map`."""
  clients_value = computation_building_blocks.Reference(
      'test',
      computation_types.FederatedType([('a', tf.int32), ('b', tf.bool)],
                                      placement_literals.CLIENTS, True))
  param = computation_building_blocks.Reference('x', [('a', tf.int32),
                                                      ('b', tf.bool)])
  select_a = computation_building_blocks.Lambda(
      'x', param.type_signature,
      computation_building_blocks.Selection(param, name='a'))
  intrinsic = computation_constructing_utils.construct_map_or_apply(
      select_a, clients_value)
  self.assertEqual(str(intrinsic), 'federated_map')
def test_inline_conflicting_locals(self):
  """A local that shadows the lambda parameter is inlined correctly."""
  outer_arg = computation_building_blocks.Reference('arg',
                                                    [tf.int32, tf.int32])
  first_of_outer = computation_building_blocks.Selection(outer_arg, index=0)
  inner_arg = computation_building_blocks.Reference('arg', tf.int32)
  shadowing_block = computation_building_blocks.Block(
      [('arg', first_of_outer)], inner_arg)
  lam = computation_building_blocks.Lambda('arg', outer_arg.type_signature,
                                           shadowing_block)
  self.assertEqual(str(lam), '(arg -> (let arg=arg[0] in arg))')
  inlined = transformations.inline_blocks_with_n_referenced_locals(lam)
  self.assertEqual(str(inlined), '(arg -> (let in arg[0]))')
def _extract_from_selection(comp):
  """Returns a new computation with all intrinsics extracted.

  If the selection's source is a called intrinsic, it is bound to a fresh name
  from `name_generator`. Otherwise the source is treated as a block and its
  locals are reused. Either way, the rebuilt block is passed on to
  `_extract_from_block`.
  """
  source = comp.source
  if _is_called_intrinsic(source):
    local_name = six.next(name_generator)
    variables = ((local_name, source),)
    inner = computation_building_blocks.Reference(local_name,
                                                  source.type_signature)
  else:
    variables = source.locals
    inner = source.result
  rebuilt = computation_building_blocks.Block(
      variables,
      computation_building_blocks.Selection(
          inner, name=comp.name, index=comp.index))
  return _extract_from_block(rebuilt)
def test_basic_functionality_of_selection_class(self):
  """Checks name/index selection, reprs, validation and serialization."""
  x = computation_building_blocks.Reference('foo', [('bar', tf.int32),
                                                    ('baz', tf.bool)])
  # repr of `x`, shared by both Selection repr expectations below.
  source_repr = ("Reference('foo', NamedTupleType(["
                 "('bar', TensorType(tf.int32)), "
                 "('baz', TensorType(tf.bool))]))")
  compact = computation_building_blocks.compact_representation

  # Selection by name.
  y = computation_building_blocks.Selection(x, name='bar')
  self.assertEqual(y.name, 'bar')
  self.assertEqual(y.index, None)
  self.assertEqual(str(y.type_signature), 'int32')
  self.assertEqual(repr(y), "Selection({}, name='bar')".format(source_repr))
  self.assertEqual(compact(y), 'foo.bar')
  z = computation_building_blocks.Selection(x, name='baz')
  self.assertEqual(str(z.type_signature), 'bool')
  self.assertEqual(compact(z), 'foo.baz')
  with self.assertRaises(ValueError):
    _ = computation_building_blocks.Selection(x, name='bak')

  # Selection by index.
  x0 = computation_building_blocks.Selection(x, index=0)
  self.assertEqual(x0.name, None)
  self.assertEqual(x0.index, 0)
  self.assertEqual(str(x0.type_signature), 'int32')
  self.assertEqual(repr(x0), 'Selection({}, index=0)'.format(source_repr))
  self.assertEqual(compact(x0), 'foo[0]')
  x1 = computation_building_blocks.Selection(x, index=1)
  self.assertEqual(str(x1.type_signature), 'bool')
  self.assertEqual(compact(x1), 'foo[1]')
  for out_of_range in (2, -1):
    with self.assertRaises(ValueError):
      _ = computation_building_blocks.Selection(x, index=out_of_range)

  # Serialization.
  y_proto = y.proto
  self.assertEqual(type_serialization.deserialize_type(y_proto.type),
                   y.type_signature)
  self.assertEqual(y_proto.WhichOneof('computation'), 'selection')
  self.assertEqual(str(y_proto.selection.source), str(x.proto))
  self.assertEqual(y_proto.selection.name, 'bar')
  for selection in (y, z, x0, x1):
    self._serialize_deserialize_roundtrip_test(selection)
def _create_block_to_chained_calls(comps):
  r"""Constructs a transformed block computation from `comps`.

                     Block
                    /     \
          [fn=Tuple]       Lambda(arg)
              |                       \
    [Comp(y), Comp(x)]                 Call
                                      /    \
                                Sel(1)      Call
                               /           /    \
                        Ref(fn)      Sel(0)      Ref(arg)
                                    /
                             Ref(fn)

  (let fn=<y, x> in (arg -> fn[1](fn[0](arg)))

  Args:
    comps: A non-empty Python list of computations. They are called in order,
      each on the previous result, so each is expected to have a functional
      type; the parameter type of the first one becomes the parameter type of
      the returned lambda.

  Returns:
    A `computation_building_blocks.Block`.

  Raises:
    ValueError: If `comps` is empty.
  """
  # Fail before drawing names from `name_generator`; otherwise an empty list
  # surfaces as an opaque IndexError on `comps[0]`.
  if not comps:
    raise ValueError('Expected at least one computation to chain.')
  functions = computation_building_blocks.Tuple(comps)
  functions_name = six.next(name_generator)
  functions_ref = computation_building_blocks.Reference(
      functions_name, functions.type_signature)
  arg_name = six.next(name_generator)
  arg_type = comps[0].type_signature.parameter
  arg_ref = computation_building_blocks.Reference(arg_name, arg_type)
  # Build fn[n-1](...fn[1](fn[0](arg))...) from the inside out.
  call = arg_ref
  for index in range(len(comps)):
    fn_sel = computation_building_blocks.Selection(functions_ref, index=index)
    call = computation_building_blocks.Call(fn_sel, call)
  fn = computation_building_blocks.Lambda(arg_ref.name, arg_ref.type_signature,
                                          call)
  return computation_building_blocks.Block(((functions_ref.name, functions),),
                                           fn)
def _create_fn_to_append_chain_zipped_values(value):
  r"""Creates a function to append a chain of zipped values.

  Lambda(arg3)
              \
               append([Call,    Sel(1)])
                      /    \          \
          Lambda(arg2)    Sel(0)      Ref(arg3)
                      \        \
                       \        Ref(arg3)
                        \
                         append([Call,    Sel(1)])
                                /    \          \
                    Lambda(arg1)    Sel(0)      Ref(arg2)
                                \        \
                                 \        Ref(arg2)
                                  \
                                   Ref(arg1)

  NOTE: This function is intended to be used in conjunction with
  `_create_chain_zipped_values` and will add back the names that were dropped
  when zipping the values.

  Args:
    value: A `computation_building_blocks.ComputationBuildingBlock` with a
      `type_signature` of type `computation_types.NamedTupleType` containing at
      least two elements.

  Returns:
    A `computation_building_blocks.Lambda`.

  Raises:
    TypeError: If any of the types do not match.
    ValueError: If `value` contains fewer than two elements.
  """
  py_typecheck.check_type(value,
                          computation_building_blocks.ComputationBuildingBlock)
  named_type_signatures = anonymous_tuple.to_elements(value.type_signature)
  length = len(named_type_signatures)
  if length < 2:
    raise ValueError(
        'Expected a value with at least two elements, received {} elements.'
        .format(length))
  first_name, first_type_signature = named_type_signatures[0]
  second_name, second_type_signature = named_type_signatures[1]
  ref_type = computation_types.NamedTupleType((
      (first_name, first_type_signature.member),
      (second_name, second_type_signature.member),
  ))
  ref = computation_building_blocks.Reference('arg', ref_type)
  fn = computation_building_blocks.Lambda(ref.name, ref.type_signature, ref)
  # Each further element wraps the previous function: the new lambda takes
  # <previous parameter, next member>, applies `fn` to the first part and
  # appends the second part under the element's original name.
  for name, type_signature in named_type_signatures[2:]:
    ref_type = computation_types.NamedTupleType((
        fn.type_signature.parameter,
        (name, type_signature.member),
    ))
    ref = computation_building_blocks.Reference('arg', ref_type)
    sel_0 = computation_building_blocks.Selection(ref, index=0)
    call = computation_building_blocks.Call(fn, sel_0)
    sel_1 = computation_building_blocks.Selection(ref, index=1)
    result = create_computation_appending(call, (name, sel_1))
    fn = computation_building_blocks.Lambda(ref.name, ref.type_signature,
                                            result)
  return fn
def _create_fn_to_append_chain_zipped_values(value):
  r"""Creates a function to append a chain of zipped values.

  Lambda(arg3)
              \
               append([Call,    Sel(1)])
                      /    \          \
          Lambda(arg2)    Sel(0)      Ref(arg3)
                      \        \
                       \        Ref(arg3)
                        \
                         append([Call,    Sel(1)])
                                /    \          \
                    Lambda(arg1)    Sel(0)      Ref(arg2)
                                \        \
                                 \        Ref(arg2)
                                  \
                                   Ref(arg1)

  Note that this function will not respect any names it is passed; names for
  tuples will be cached at a higher level than this function and added back in
  a single call to federated map or federated apply.

  Args:
    value: A `computation_building_blocks.ComputationBuildingBlock` with a
      `type_signature` of type `computation_types.NamedTupleType` containing at
      least two elements.

  Returns:
    A `computation_building_blocks.Lambda`.

  Raises:
    TypeError: If any of the types do not match.
    ValueError: If `value` contains fewer than two elements.
  """
  py_typecheck.check_type(value,
                          computation_building_blocks.ComputationBuildingBlock)
  named_type_signatures = anonymous_tuple.to_elements(value.type_signature)
  length = len(named_type_signatures)
  if length < 2:
    raise ValueError(
        'Expected a value with at least two elements, received {} elements.'
        .format(length))
  _, first_type_signature = named_type_signatures[0]
  _, second_type_signature = named_type_signatures[1]
  ref_type = computation_types.NamedTupleType((
      first_type_signature.member,
      second_type_signature.member,
  ))
  ref = computation_building_blocks.Reference('arg', ref_type)
  fn = computation_building_blocks.Lambda(ref.name, ref.type_signature, ref)
  # Each further element wraps the previous function: the new lambda takes
  # <previous parameter, next member>, applies `fn` to the first part and
  # appends the second part (unnamed).
  for _, type_signature in named_type_signatures[2:]:
    ref_type = computation_types.NamedTupleType((
        fn.type_signature.parameter,
        type_signature.member,
    ))
    ref = computation_building_blocks.Reference('arg', ref_type)
    sel_0 = computation_building_blocks.Selection(ref, index=0)
    call = computation_building_blocks.Call(fn, sel_0)
    sel_1 = computation_building_blocks.Selection(ref, index=1)
    result = create_computation_appending(call, sel_1)
    fn = computation_building_blocks.Lambda(ref.name, ref.type_signature,
                                            result)
  return fn
def construct_named_tuple_setattr_lambda(named_tuple_signature, name,
                                         value_comp):
  """Constructs a building block for replacing one attribute in a named tuple.

  Returns an instance of `computation_building_blocks.Lambda` which takes an
  argument of type `computation_types.NamedTupleType` and returns a
  `computation_building_blocks.Tuple` which contains all the same elements as
  the argument, except the attribute `name` now has value `value_comp`. The
  Lambda constructed is the analogue of Python's `setattr` for the concrete
  type `named_tuple_signature`.

  Args:
    named_tuple_signature: Instance of `computation_types.NamedTupleType`, the
      type of the argument to the constructed
      `computation_building_blocks.Lambda`.
    name: String name of the attribute in the `named_tuple_signature` to
      replace with `value_comp`. Must be present as a name in
      `named_tuple_signature`; otherwise we will raise an `AttributeError`.
    value_comp: Instance of
      `computation_building_blocks.ComputationBuildingBlock`, the value to
      place as attribute `name` in the argument of the returned function.

  Returns:
    An instance of `computation_building_blocks.Block` of functional type
    representing setting attribute `name` to value `value_comp` in its argument
    of type `named_tuple_signature`.

  Raises:
    TypeError: If the types of the arguments don't match the assumptions above.
    AttributeError: If `name` is not present as a named element in
      `named_tuple_signature`.
  """
  py_typecheck.check_type(named_tuple_signature,
                          computation_types.NamedTupleType)
  py_typecheck.check_type(name, six.string_types)
  py_typecheck.check_type(value_comp,
                          computation_building_blocks.ComputationBuildingBlock)
  # `value_comp` is bound once in the enclosing block and referenced by this
  # placeholder inside the lambda.
  value_comp_placeholder = computation_building_blocks.Reference(
      'value_comp_placeholder', value_comp.type_signature)
  lambda_arg = computation_building_blocks.Reference('lambda_arg',
                                                     named_tuple_signature)
  named_elements = anonymous_tuple.to_elements(named_tuple_signature)
  # Check against the element names themselves: `dir()` is not a reliable
  # proxy for them, and a name that is not an element would otherwise yield an
  # unchanged tuple instead of an error.
  if name not in [key for key, _ in named_elements]:
    raise AttributeError(
        'There is no such attribute as \'{}\' in this federated tuple. '
        'TFF does not allow for assigning to a nonexistent attribute. '
        'If you want to assign to \'{}\', you must create a new named tuple '
        'containing this attribute.'.format(name, name))
  elements = []
  for idx, (key, element_type) in enumerate(named_elements):
    if key == name:
      if not type_utils.is_assignable_from(element_type,
                                           value_comp.type_signature):
        raise TypeError(
            '`setattr` has attempted to set element {} of type {} with '
            'incompatible type {}'.format(key, element_type,
                                          value_comp.type_signature))
      elements.append((key, value_comp_placeholder))
    else:
      elements.append(
          (key, computation_building_blocks.Selection(lambda_arg, index=idx)))
  return_tuple = computation_building_blocks.Tuple(elements)
  lambda_to_return = computation_building_blocks.Lambda(
      lambda_arg.name, named_tuple_signature, return_tuple)
  symbols = ((value_comp_placeholder.name, value_comp),)
  return computation_building_blocks.Block(symbols, lambda_to_return)