def _pack_binary_operator_args(x, y):
  """Packs the two arguments of a binary operator into a single argument.

  Two situations are supported:

  * The type trees of both `x` and `y` contain only named tuples and tensors;
    the result is then the two-element tuple `<x, y>`.
  * Both `x` and `y` are of federated type with equal placement, and their
    members pass `type_utils.is_binary_op_with_upcast_compatible_pair`; the
    result is then `intrinsics.federated_zip` applied to the tuple `<x, y>`.

  Args:
    x: The first operand, a value exposing a `type_signature`.
    y: The second operand, a value exposing a `type_signature`.

  Returns:
    The packed argument, built in the enclosing `context_stack`.

  Raises:
    TypeError: If the members of two equally placed federated operands are not
      compatible, or if the operands match neither supported situation.
  """

  def _only_tuple_or_tensor(value):
    return type_utils.type_tree_contains_only(
        value.type_signature,
        (computation_types.NamedTupleType, computation_types.TensorType))

  def _as_pair():
    # Both supported situations start from the same `<x, y>` tuple.
    return value_impl.ValueImpl(
        computation_building_blocks.Tuple([
            value_impl.ValueImpl.get_comp(x),
            value_impl.ValueImpl.get_comp(y)
        ]), context_stack)

  if _only_tuple_or_tensor(x) and _only_tuple_or_tensor(y):
    return _as_pair()

  if (isinstance(x.type_signature, computation_types.FederatedType) and
      isinstance(y.type_signature, computation_types.FederatedType) and
      x.type_signature.placement == y.type_signature.placement):
    if not type_utils.is_binary_op_with_upcast_compatible_pair(
        x.type_signature.member, y.type_signature.member):
      raise TypeError(
          'The members of the federated types {} and {} are not division '
          'compatible; see `type_utils.is_binary_op_with_upcast_compatible_pair` '
          'for more details.'.format(x.type_signature, y.type_signature))
    return intrinsics.federated_zip(_as_pair())

  # Previously a bare `raise TypeError`; say what was received and expected.
  raise TypeError(
      'Cannot pack binary operator arguments of types {} and {}; expected '
      'either two values containing only tuples and tensors, or two federated '
      'values with the same placement.'.format(x.type_signature,
                                               y.type_signature))
def test_basic_functionality_of_tuple_class(self):
  # Covers construction, type signature, string forms, element access,
  # iteration and proto serialization of a two-element `Tuple` whose first
  # element is unnamed and whose second element is named 'y'.
  foo_ref = computation_building_blocks.Reference('foo', tf.int32)
  bar_ref = computation_building_blocks.Reference('bar', tf.bool)
  tup = computation_building_blocks.Tuple([foo_ref, ('y', bar_ref)])

  # An empty string is rejected as an element name.
  with self.assertRaises(ValueError):
    _ = computation_building_blocks.Tuple([('', bar_ref)])

  # Type and string representations.
  self.assertIsInstance(tup, anonymous_tuple.AnonymousTuple)
  self.assertEqual(str(tup.type_signature), '<int32,y=bool>')
  expected_repr = (
      'Tuple([(None, Reference(\'foo\', TensorType(tf.int32))), (\'y\', '
      'Reference(\'bar\', TensorType(tf.bool)))])')
  self.assertEqual(repr(tup), expected_repr)
  self.assertEqual(tup.tff_repr, '<foo,y=bar>')

  # Access by attribute, by position, and by iteration.
  self.assertEqual(dir(tup), ['y'])
  self.assertIs(tup.y, bar_ref)
  self.assertLen(tup, 2)
  self.assertIs(tup[0], foo_ref)
  self.assertIs(tup[1], bar_ref)
  element_reprs = [element.tff_repr for element in iter(tup)]
  self.assertEqual(','.join(element_reprs), 'foo,bar')

  # Proto form: type, which-oneof and element names.
  tup_proto = tup.proto
  self.assertEqual(
      type_serialization.deserialize_type(tup_proto.type), tup.type_signature)
  self.assertEqual(tup_proto.WhichOneof('computation'), 'tuple')
  element_names = [element.name for element in tup_proto.tuple.element]
  self.assertEqual(element_names, ['', 'y'])
  self._serialize_deserialize_roundtrip_test(tup)
def create_computation_appending(comp1, comp2):
  r"""Returns a block whose result is the elements of `comp1` plus `comp2`.

  Both computations are bound once to the local `comps`, and the result tuple
  is assembled from selections into that local:

                 Block
                /     \
    [comps=Tuple]       Tuple
           |            |
      [Comp, Comp]      [Sel(0), ...,  Sel(n),   Sel(1)]
                              \             \         \
                               Sel(0)        Sel(0)    Ref(comps)
                                     \             \
                                      Ref(comps)    Ref(comps)

  Args:
    comp1: A `computation_building_blocks.ComputationBuildingBlock` with a
      `type_signature` of type `computation_types.NamedTupleType`.
    comp2: A `computation_building_blocks.ComputationBuildingBlock`, or a
      (name, computation) pair, representing the single element to append.

  Returns:
    A `computation_building_blocks.Block`.

  Raises:
    TypeError: If any of the types do not match.
  """
  building_block = computation_building_blocks.ComputationBuildingBlock
  py_typecheck.check_type(comp1, building_block)

  # Normalize `comp2` into an (optional name, computation) pair.
  if isinstance(comp2, building_block):
    appended_name = None
  elif py_typecheck.is_name_value_pair(
      comp2, name_required=False, value_type=building_block):
    appended_name, comp2 = comp2
  else:
    raise TypeError('Unexpected tuple element: {}.'.format(comp2))

  pair = computation_building_blocks.Tuple((comp1, comp2))
  pair_ref = computation_building_blocks.Reference('comps',
                                                   pair.type_signature)

  # Re-select every element of `comp1`, keeping its name, then add `comp2`.
  head = computation_building_blocks.Selection(pair_ref, index=0)
  elements = [
      (name, computation_building_blocks.Selection(head, index=position))
      for position, (name, _) in enumerate(
          anonymous_tuple.to_elements(comp1.type_signature))
  ]
  elements.append(
      (appended_name, computation_building_blocks.Selection(pair_ref,
                                                            index=1)))

  result = computation_building_blocks.Tuple(elements)
  return computation_building_blocks.Block(((pair_ref.name, pair),), result)
def _transform_functional_args(comps):
  r"""Combines the functional computations `comps` into a single block.

  The functions are bound to a freshly named local; the block's result is a
  lambda that calls the i-th function on the i-th element of its argument:

                      Block
                     /     \
           [fn=Tuple]       Lambda(arg)
               |                       \
   [Comp(f1), Comp(f2), ...]            Tuple
                                        |
                                   [Call,                  Call, ...]
                                   /    \                 /    \
                             Sel(0)      Sel(0)     Sel(1)      Sel(1)
                            /           /          /           /
                     Ref(fn)    Ref(arg)    Ref(fn)    Ref(arg)

  Both local names are drawn from the enclosing `name_generator`.

  Args:
    comps: A Python list of functional computations.

  Returns:
    A `computation_building_blocks.Block`.
  """
  functions = computation_building_blocks.Tuple(comps)
  # The function name is drawn before the argument name.
  functions_ref = computation_building_blocks.Reference(
      six.next(name_generator), functions.type_signature)
  parameter_types = [comp.type_signature.parameter for comp in comps]
  arg_ref = computation_building_blocks.Reference(
      six.next(name_generator), parameter_types)

  def _call_at(position):
    # Calls the function at `position` on the argument at `position`.
    selected_fn = computation_building_blocks.Selection(
        functions_ref, index=position)
    selected_arg = computation_building_blocks.Selection(
        arg_ref, index=position)
    return computation_building_blocks.Call(selected_fn, selected_arg)

  calls = computation_building_blocks.Tuple(
      [_call_at(position) for position in range(len(comps))])
  fn = computation_building_blocks.Lambda(arg_ref.name,
                                          arg_ref.type_signature, calls)
  return computation_building_blocks.Block(
      ((functions_ref.name, functions),), fn)
def __setattr__(self, name, value):
  """Replaces the element `name` of this named tuple value with `value`.

  The underlying computation is not mutated. A new
  `computation_building_blocks.Tuple` is built in which the element `name`
  holds the converted `value` and every other element is a selection from the
  current computation. `_comp` is then rebound to that tuple.

  Args:
    name: The name of the element to assign, a string.
    value: The new element; it is converted with `to_value` against the
      element's existing type.

  Raises:
    TypeError: If this value is not of a named tuple type, or if `value`
      cannot be converted to the type of the element being assigned.
    AttributeError: If the tuple type has no attribute called `name`.
  """
  py_typecheck.check_type(name, six.string_types)
  tuple_type = self._comp.type_signature
  if not isinstance(tuple_type, computation_types.NamedTupleType):
    raise TypeError(
        'Operator setattr() is only supported for named tuples, but the '
        'object on which it has been invoked is of type {}.'.format(
            str(self._comp.type_signature)))
  if name not in dir(tuple_type):
    raise AttributeError(
        'There is no such attribute as \'{}\' in this tuple. '
        'TFF does not allow for assigning to a nonexistent attribute. '
        'If you want to assign to \'{}\', you must create a new named tuple '
        'containing this attribute.'.format(name, name))

  elements = []
  for index, (elem_name, elem_type) in enumerate(
      anonymous_tuple.to_elements(tuple_type)):
    if elem_name == name:
      try:
        value = to_value(value, elem_type, self._context_stack)
      except TypeError:
        raise TypeError(
            'Setattr has attempted to set element {} of type {} '
            'with incompatible item {}.'.format(elem_name, elem_type, value))
      elements.append((elem_name, ValueImpl.get_comp(value)))
    elif elem_name is not None:
      elements.append(
          (elem_name,
           computation_building_blocks.Selection(self._comp, name=elem_name)))
    else:
      # Unnamed elements of a mixed tuple cannot be selected by name, so they
      # are carried over by position.
      elements.append(
          (None, computation_building_blocks.Selection(self._comp,
                                                       index=index)))

  new_comp = computation_building_blocks.Tuple(elements)
  # Bypass this override when rebinding the private attribute.
  super(ValueImpl, self).__setattr__('_comp', new_comp)
def construct_federated_getitem_call(arg, idx):
  """Builds a call that passes `__getitem__` through to a federated value.

  The selecting computation comes from `construct_federated_getitem_comp`,
  and the callee from `construct_map_or_apply`. The two are combined into a
  call on the tuple `<getitem_comp, arg>`.

  Args:
    arg: Instance of `computation_building_blocks.ComputationBuildingBlock` of
      `computation_types.FederatedType` with member of type
      `computation_types.NamedTupleType` from which to pick out item `idx`.
    idx: Index, instance of `int` or `slice`, used to address the
      `computation_types.NamedTupleType` underlying `arg`.

  Returns:
    A `computation_building_blocks.Call`.

  Raises:
    TypeError: If `arg` or `idx` fail the type checks described above.
  """
  py_typecheck.check_type(
      arg, computation_building_blocks.ComputationBuildingBlock)
  py_typecheck.check_type(idx, (int, slice))
  arg_type = arg.type_signature
  py_typecheck.check_type(arg_type, computation_types.FederatedType)
  py_typecheck.check_type(arg_type.member, computation_types.NamedTupleType)

  selector = construct_federated_getitem_comp(arg, idx)
  return computation_building_blocks.Call(
      construct_map_or_apply(selector, arg),
      computation_building_blocks.Tuple([selector, arg]))
def _create_transformed_args_from_comps(call_names, elements):
  """Constructs a Python list of transformed computations.

  Each entry of `elements` is a list of computations. The entry is classified
  by the type of its first computation:

  * functional: turned into a block by `_create_block_to_calls`;
  * anything else: turned into a `computation_building_blocks.Tuple` whose
    elements are named by `call_names`.

  The resulting list represents the arguments to pass to the
  `computation_building_blocks.Call` of the "transformed" computation.

  Args:
    call_names: A Python list of names.
    elements: A 2 dimensional Python list of computations.

  Returns:
    A Python list of computations, one per entry of `elements`.
  """

  def _to_arg(comps):
    if isinstance(comps[0].type_signature, computation_types.FunctionType):
      return _create_block_to_calls(call_names, comps)
    return computation_building_blocks.Tuple(zip(call_names, comps))

  return [_to_arg(comps) for comps in elements]
def test_returns_string_for_comp_with_right_overhang(self):
  # Builds `(a -> <a,data,data,data,data>[0])(data)` and checks its compact,
  # formatted and structural string representations.
  # NOTE(review): whitespace inside the expected strings is significant;
  # confirm it against the renderer's actual output.
  arg_ref = computation_building_blocks.Reference('a', tf.int32)
  datum = computation_building_blocks.Data('data', tf.int32)
  five_tuple = computation_building_blocks.Tuple(
      [arg_ref, datum, datum, datum, datum])
  first = computation_building_blocks.Selection(five_tuple, index=0)
  selector_fn = computation_building_blocks.Lambda(
      arg_ref.name, arg_ref.type_signature, first)
  comp = computation_building_blocks.Call(selector_fn, datum)

  self.assertEqual(comp.compact_representation(),
                   '(a -> <a,data,data,data,data>[0])(data)')

  # pyformat: disable
  expected_formatted = (
      '(a -> <\n'
      ' a,\n'
      ' data,\n'
      ' data,\n'
      ' data,\n'
      ' data\n'
      '>[0])(data)')
  # pyformat: enable
  self.assertEqual(comp.formatted_representation(), expected_formatted)

  # pyformat: disable
  expected_structural = (
      ' Call\n'
      ' / \\\n'
      'Lambda(a) data\n'
      '|\n'
      'Sel(0)\n'
      '|\n'
      'Tuple\n'
      '|\n'
      '[Ref(a), data, data, data, data]')
  self.assertEqual(comp.structural_representation(), expected_structural)
def construct_map_or_apply(fn, arg):
  """Builds the call that applies `fn` to the federated value `arg`.

  For an `arg` placed at the server, a call to the `FEDERATED_APPLY` intrinsic
  is constructed on the tuple `<fn, arg>`. For an `arg` placed at the clients,
  construction is delegated to `create_federated_map`.

  Args:
    fn: An instance of `computation_building_blocks.ComputationBuildingBlock`
      of functional type to apply to `arg`.
    arg: A `computation_building_blocks.ComputationBuildingBlock` instance of
      federated type. The `member` of its `type_signature` must be assignable
      to the `parameter` of the `type_signature` of `fn`.

  Returns:
    A `computation_building_blocks.Call` applying `fn` to `arg`.

  Raises:
    TypeError: If the arguments fail the checks above, or if `arg` is placed
      somewhere other than the server or the clients.
  """
  py_typecheck.check_type(
      fn, computation_building_blocks.ComputationBuildingBlock)
  py_typecheck.check_type(
      arg, computation_building_blocks.ComputationBuildingBlock)
  py_typecheck.check_type(fn.type_signature, computation_types.FunctionType)
  py_typecheck.check_type(arg.type_signature, computation_types.FederatedType)
  type_utils.check_assignable_from(fn.type_signature.parameter,
                                   arg.type_signature.member)

  placement = arg.type_signature.placement
  if placement == placement_literals.SERVER:
    # The result keeps the placement and `all_equal` bit of the argument.
    result_type = computation_types.FederatedType(
        fn.type_signature.result, placement, arg.type_signature.all_equal)
    intrinsic_type = computation_types.FunctionType(
        [fn.type_signature, arg.type_signature], result_type)
    intrinsic = computation_building_blocks.Intrinsic(
        intrinsic_defs.FEDERATED_APPLY.uri, intrinsic_type)
    tup = computation_building_blocks.Tuple((fn, arg))
    return computation_building_blocks.Call(intrinsic, tup)
  if placement == placement_literals.CLIENTS:
    return create_federated_map(fn, arg)

  # Previously fell through and returned None, deferring the failure to the
  # caller; fail here with the offending placement instead.
  raise TypeError('Unsupported placement {}.'.format(placement))
def test_execute_with_nested_lambda(self):
  # Curries a TensorFlow int32 addition as `x -> (y -> add(<x, y>))`, applies
  # it to the constant 10, and checks that the resulting computation maps 5
  # to 15.
  from_proto = computation_building_blocks.ComputationBuildingBlock.from_proto
  get_proto = computation_impl.ComputationImpl.get_proto

  int32_add = from_proto(
      get_proto(computations.tf_computation(tf.add, [tf.int32, tf.int32])))

  # Assemble the curried addition from the innermost call outward.
  x_ref = computation_building_blocks.Reference('x', tf.int32)
  y_ref = computation_building_blocks.Reference('y', tf.int32)
  add_call = computation_building_blocks.Call(
      int32_add,
      computation_building_blocks.Tuple([(None, x_ref), (None, y_ref)]))
  inner_lambda = computation_building_blocks.Lambda('y', tf.int32, add_call)
  curried_int32_add = computation_building_blocks.Lambda(
      'x', tf.int32, inner_lambda)

  make_10 = from_proto(
      get_proto(computations.tf_computation(lambda: tf.constant(10))))
  add_10 = computation_building_blocks.Call(
      curried_int32_add, computation_building_blocks.Call(make_10))

  add_10_computation = computation_impl.ComputationImpl(
      add_10.proto, context_stack_impl.context_stack)
  self.assertEqual(add_10_computation(5), 15)
def test_federated_zip(self):
  # Zips the federated value 10 with itself, once placed at the server and
  # once at the clients, and inspects the value the executor creates.
  loop = asyncio.get_event_loop()
  ex = _make_test_executor(3)

  @computations.federated_computation
  def ten_on_server():
    return intrinsics.federated_value(10, placements.SERVER)

  @computations.federated_computation
  def ten_on_clients():
    return intrinsics.federated_value(10, placements.CLIENTS)

  # (computation, expected type string, expected number of representations)
  cases = [
      (ten_on_server, '<int32,int32>@SERVER', 1),
      (ten_on_clients, '{<int32,int32>}@CLIENTS', 3),
  ]
  for ten, type_string, cardinality in cases:
    ten_block = (
        computation_building_blocks.ComputationBuildingBlock.from_proto(
            computation_impl.ComputationImpl.get_proto(ten)))
    comp = computation_constructing_utils.create_zip_two_values(
        computation_building_blocks.Tuple([ten_block] * 2))
    val = loop.run_until_complete(
        ex.create_value(comp.proto, comp.type_signature))

    self.assertIsInstance(val, federated_executor.FederatedExecutorValue)
    self.assertEqual(str(val.type_signature), type_string)
    representation = val.internal_representation
    self.assertIsInstance(representation, list)
    self.assertLen(representation, cardinality)
    for zipped in representation:
      self.assertIsInstance(zipped, anonymous_tuple.AnonymousTuple)
      self.assertLen(zipped, 2)
      for item in zipped:
        self.assertIsInstance(item, eager_executor.EagerValue)
        self.assertEqual(item.internal_representation.numpy(), 10)
def get_curried(func):
  """Returns a curried version of `func`, which takes a parameter tuple.

  For a `func` of type <T1,T2,....,Tn> -> U, the result is a function of the
  form T1 -> (T2 -> (T3 -> .... (Tn -> U) ... )).

  NOTE: No attempt is made at avoiding naming conflicts in cases where `func`
  contains references. The arguments of the curried function are named `argN`
  with `N` starting at 0.

  Args:
    func: A value of a functional TFF type whose parameter is a named tuple.

  Returns:
    A value that represents the curried form of `func`.

  Raises:
    TypeError: If `func` is not a value of functional type with a named tuple
      parameter.
  """
  py_typecheck.check_type(func, value_base.Value)
  func_type = func.type_signature
  py_typecheck.check_type(func_type, computation_types.FunctionType)
  py_typecheck.check_type(func_type.parameter,
                          computation_types.NamedTupleType)

  # One reference per element of the parameter tuple, in order.
  references = [
      computation_building_blocks.Reference('arg{}'.format(position),
                                            element_type)
      for position, (_, element_type) in enumerate(
          anonymous_tuple.to_elements(func.type_signature.parameter))
  ]

  curried = computation_building_blocks.Call(
      value_impl.ValueImpl.get_comp(func),
      computation_building_blocks.Tuple(references))
  # Wrap from the last argument inward so that `arg0` ends up outermost.
  for reference in reversed(references):
    curried = computation_building_blocks.Lambda(
        reference.name, reference.type_signature, curried)

  return value_impl.ValueImpl(curried,
                              value_impl.ValueImpl.get_context_stack(func))
def _create_call_to_federated_map(fn, arg):
  r"""Creates a computation that calls the federated map intrinsic.

              Call
             /    \
    Intrinsic      Tuple
                  /     \
       Computation       Computation

  The intrinsic is typed as a function from `<fn, arg>` to the result of `fn`
  federated at `placements.CLIENTS`.

  Args:
    fn: An instance of a functional
      `computation_building_blocks.ComputationBuildingBlock` to use as the map
      function.
    arg: An instance of `computation_building_blocks.ComputationBuildingBlock`
      to use as the map argument.

  Returns:
    An instance of `computation_building_blocks.Call` wrapping the federated
    map computation.
  """
  for candidate in (fn, arg):
    py_typecheck.check_type(
        candidate, computation_building_blocks.ComputationBuildingBlock)

  mapped_type = computation_types.FederatedType(fn.type_signature.result,
                                                placements.CLIENTS)
  map_intrinsic = computation_building_blocks.Intrinsic(
      intrinsic_defs.FEDERATED_MAP.uri,
      computation_types.FunctionType([fn.type_signature, arg.type_signature],
                                     mapped_type))
  return computation_building_blocks.Call(
      map_intrinsic, computation_building_blocks.Tuple((fn, arg)))
def create_federated_map(fn, arg):
  r"""Creates a called federated map.

              Call
             /    \
    Intrinsic      Tuple
                   |
                   [Comp, Comp]

  The result type is the result of `fn` federated at
  `placement_literals.CLIENTS`, with the third `FederatedType` argument given
  as `False`.

  Args:
    fn: A `computation_building_blocks.ComputationBuildingBlock` to use as the
      function.
    arg: A `computation_building_blocks.ComputationBuildingBlock` to use as
      the argument.

  Returns:
    A `computation_building_blocks.Call`.

  Raises:
    TypeError: If any of the types do not match.
  """
  for candidate in (fn, arg):
    py_typecheck.check_type(
        candidate, computation_building_blocks.ComputationBuildingBlock)

  mapped_type = computation_types.FederatedType(
      fn.type_signature.result, placement_literals.CLIENTS, False)
  signature = computation_types.FunctionType(
      (fn.type_signature, arg.type_signature), mapped_type)
  map_intrinsic = computation_building_blocks.Intrinsic(
      intrinsic_defs.FEDERATED_MAP.uri, signature)
  return computation_building_blocks.Call(
      map_intrinsic, computation_building_blocks.Tuple((fn, arg)))
def construct_federated_getattr_call(arg, name):
  """Builds a call that passes `__getattr__` through to a federated value.

  The selecting computation comes from `construct_federated_getattr_comp`,
  and the callee from `construct_map_or_apply`. The two are combined into a
  call on the tuple `<getattr_comp, arg>`.

  Args:
    arg: Instance of `computation_building_blocks.ComputationBuildingBlock` of
      `computation_types.FederatedType` with member of type
      `computation_types.NamedTupleType` from which to pick out item `name`.
    name: String name used to address the `computation_types.NamedTupleType`
      underlying `arg`.

  Returns:
    A `computation_building_blocks.Call`.

  Raises:
    TypeError: If `arg` or `name` fail the type checks described above.
  """
  py_typecheck.check_type(
      arg, computation_building_blocks.ComputationBuildingBlock)
  py_typecheck.check_type(name, six.string_types)
  arg_type = arg.type_signature
  py_typecheck.check_type(arg_type, computation_types.FederatedType)
  py_typecheck.check_type(arg_type.member, computation_types.NamedTupleType)

  selector = construct_federated_getattr_comp(arg, name)
  return computation_building_blocks.Call(
      construct_map_or_apply(selector, arg),
      computation_building_blocks.Tuple([selector, arg]))
def _transform_non_functional_args(comps):
  r"""Combines the non-functional computations `comps` into one argument.

  The computations are first collected into a tuple. If the first computation
  is of federated type, the tuple is passed through
  `computation_constructing_utils.create_federated_zip`:

    federated_zip(Tuple)
                  |
                  [Comp, Comp, ...]

  otherwise the plain tuple is returned:

    Tuple
    |
    [Comp, Comp, ...]

  The result represents one of the arguments that should be passed to the
  call of the transformed computation.

  Args:
    comps: A Python list of computations.

  Returns:
    The result of `create_federated_zip` on the tuple, or the
    `computation_building_blocks.Tuple` itself.
  """
  values = computation_building_blocks.Tuple(comps)
  leading_type = comps[0].type_signature
  if not isinstance(leading_type, computation_types.FederatedType):
    return values
  return computation_constructing_utils.create_federated_zip(values)
def sequence_reduce(self, value, zero, op):
  """Implements `sequence_reduce` as defined in `api/intrinsics.py`.

  A `value` of sequence type is reduced directly through
  `computation_constructing_utils.create_sequence_reduce`. A `value` of
  federated type with a sequence member is reduced by wrapping the intrinsic
  in a lambda and passing it to `federated_apply` or `federated_map`,
  depending on placement.

  Args:
    value: As in `api/intrinsics.py`.
    zero: As in `api/intrinsics.py`.
    op: As in `api/intrinsics.py`.

  Returns:
    As in `api/intrinsics.py`.

  Raises:
    TypeError: As in `api/intrinsics.py`.
  """
  value = value_impl.to_value(value, None, self._context_stack)
  zero = value_impl.to_value(zero, None, self._context_stack)
  op = value_impl.to_value(op, None, self._context_stack)

  # Find the element type of the (possibly federated) sequence.
  if isinstance(value.type_signature, computation_types.SequenceType):
    element_type = value.type_signature.element
  else:
    py_typecheck.check_type(value.type_signature,
                            computation_types.FederatedType)
    py_typecheck.check_type(value.type_signature.member,
                            computation_types.SequenceType)
    element_type = value.type_signature.member.element

  expected_op_type = type_constructors.reduction_op(zero.type_signature,
                                                    element_type)
  if not type_utils.is_assignable_from(expected_op_type, op.type_signature):
    raise TypeError('Expected an operator of type {}, got {}.'.format(
        expected_op_type, op.type_signature))

  # From here on, work with the underlying building blocks.
  value_comp = value_impl.ValueImpl.get_comp(value)
  zero_comp = value_impl.ValueImpl.get_comp(zero)
  op_comp = value_impl.ValueImpl.get_comp(op)

  if isinstance(value_comp.type_signature, computation_types.SequenceType):
    return computation_constructing_utils.create_sequence_reduce(
        value_comp, zero_comp, op_comp)

  # Federated case: build `arg -> sequence_reduce(<arg, zero, op>)`.
  sequence_type = computation_types.SequenceType(element_type)
  intrinsic_type = computation_types.FunctionType(
      (sequence_type, zero_comp.type_signature, op_comp.type_signature),
      op_comp.type_signature.result)
  intrinsic = computation_building_blocks.Intrinsic(
      intrinsic_defs.SEQUENCE_REDUCE.uri, intrinsic_type)
  arg_ref = computation_building_blocks.Reference('arg', sequence_type)
  reduce_call = computation_building_blocks.Call(
      intrinsic,
      computation_building_blocks.Tuple((arg_ref, zero_comp, op_comp)))
  reduce_fn = computation_building_blocks.Lambda(
      arg_ref.name, arg_ref.type_signature, reduce_call)
  fn_impl = value_impl.ValueImpl(reduce_fn, self._context_stack)

  if value_comp.type_signature.placement is placements.SERVER:
    return self.federated_apply(fn_impl, value_comp)
  if value_comp.type_signature.placement is placements.CLIENTS:
    return self.federated_map(fn_impl, value_comp)
  raise TypeError('Unsupported placement {}.'.format(
      value_comp.type_signature.placement))
def sequence_reduce(self, value, zero, op):
  """Implements `sequence_reduce` as defined in `api/intrinsics.py`.

  A `value` of sequence type has the `SEQUENCE_REDUCE` intrinsic invoked on it
  directly. A `value` of federated type with a sequence member has the
  intrinsic wrapped in a lambda, which is passed to `federated_apply` or
  `federated_map` depending on placement.

  Args:
    value: As in `api/intrinsics.py`.
    zero: As in `api/intrinsics.py`.
    op: As in `api/intrinsics.py`.

  Returns:
    As in `api/intrinsics.py`.

  Raises:
    TypeError: As in `api/intrinsics.py`.
  """
  value = value_impl.to_value(value, None, self._context_stack)
  zero = value_impl.to_value(zero, None, self._context_stack)
  op = value_impl.to_value(op, None, self._context_stack)

  # Find the element type of the (possibly federated) sequence.
  if isinstance(value.type_signature, computation_types.SequenceType):
    element_type = value.type_signature.element
  else:
    py_typecheck.check_type(value.type_signature,
                            computation_types.FederatedType)
    py_typecheck.check_type(value.type_signature.member,
                            computation_types.SequenceType)
    element_type = value.type_signature.member.element

  expected_op_type = type_constructors.reduction_op(zero.type_signature,
                                                    element_type)
  if not type_utils.is_assignable_from(expected_op_type, op.type_signature):
    raise TypeError('Expected an operator of type {}, got {}.'.format(
        str(expected_op_type), str(op.type_signature)))

  intrinsic_signature = computation_types.FunctionType([
      computation_types.SequenceType(element_type), zero.type_signature,
      op.type_signature
  ], zero.type_signature)
  reduce_intrinsic = computation_building_blocks.Intrinsic(
      intrinsic_defs.SEQUENCE_REDUCE.uri, intrinsic_signature)

  if isinstance(value.type_signature, computation_types.SequenceType):
    # Plain sequence: invoke the intrinsic directly.
    return value_impl.ValueImpl(reduce_intrinsic, self._context_stack)(
        value, zero, op)

  # Federated case: build `arg -> sequence_reduce(<arg, zero, op>)`.
  arg_ref = computation_building_blocks.Reference(
      'arg', computation_types.SequenceType(element_type))
  reduce_call = computation_building_blocks.Call(
      reduce_intrinsic,
      computation_building_blocks.Tuple([
          arg_ref,
          value_impl.ValueImpl.get_comp(zero),
          value_impl.ValueImpl.get_comp(op)
      ]))
  mapping_fn = value_impl.ValueImpl(
      computation_building_blocks.Lambda(
          'arg', computation_types.SequenceType(element_type), reduce_call),
      self._context_stack)

  if value.type_signature.placement is placements.SERVER:
    return self.federated_apply(mapping_fn, value)
  if value.type_signature.placement is placements.CLIENTS:
    return self.federated_map(mapping_fn, value)
  raise TypeError('Unsupported placement {}.'.format(
      str(value.type_signature.placement)))
def create_sequence_map(fn, arg):
  r"""Creates a called sequence map.

              Call
             /    \
    Intrinsic      Tuple
                   |
                   [Comp, Comp]

  The result type is a sequence of the result type of `fn`.

  Args:
    fn: A `computation_building_blocks.ComputationBuildingBlock` to use as the
      function.
    arg: A `computation_building_blocks.ComputationBuildingBlock` to use as
      the argument.

  Returns:
    A `computation_building_blocks.Call`.

  Raises:
    TypeError: If any of the types do not match.
  """
  for candidate in (fn, arg):
    py_typecheck.check_type(
        candidate, computation_building_blocks.ComputationBuildingBlock)

  mapped_type = computation_types.SequenceType(fn.type_signature.result)
  signature = computation_types.FunctionType(
      (fn.type_signature, arg.type_signature), mapped_type)
  map_intrinsic = computation_building_blocks.Intrinsic(
      intrinsic_defs.SEQUENCE_MAP.uri, signature)
  return computation_building_blocks.Call(
      map_intrinsic, computation_building_blocks.Tuple((fn, arg)))
def test_replace_chained_federated_maps_does_not_replace_unchained_federated_maps(
    self):
  # Two federated maps are separated by a tuple-then-selection indirection, so
  # they are not directly chained. The transformation is expected to leave
  # both intrinsics in place and the computed result unchanged.
  arg_type = computation_types.FederatedType(tf.int32, placements.CLIENTS)
  arg_ref = computation_building_blocks.Reference('arg', arg_type)

  first_fn = _create_lambda_to_add_one(arg_ref.type_signature.member)
  first_map = _create_call_to_federated_map(first_fn, arg_ref)

  # The indirection that breaks the chain.
  indirection = computation_building_blocks.Selection(
      computation_building_blocks.Tuple([first_map]), index=0)

  second_fn = _create_lambda_to_add_one(
      first_map.function.type_signature.result.member)
  second_map = _create_call_to_federated_map(second_fn, indirection)
  comp = computation_building_blocks.Lambda(
      arg_ref.name, arg_ref.type_signature, second_map)
  uri = intrinsic_defs.FEDERATED_MAP.uri

  self.assertEqual(_get_number_of_intrinsics(comp, uri), 2)
  self.assertEqual(_to_comp(comp)([1]), [3])

  transformed_comp = (
      transformations.replace_chained_federated_maps_with_federated_map(comp))

  self.assertEqual(_get_number_of_intrinsics(transformed_comp, uri), 2)
  self.assertEqual(_to_comp(transformed_comp)([1]), [3])
def _create_lambda_to_add_one(dtype):
  r"""Creates a computation that adds `1` to its argument.

    Lambda
          \
           Call
          /    \
    Intrinsic   Tuple
               /     \
      Reference       Computation

  The addition uses the `GENERIC_PLUS` intrinsic. The constant `1` comes from
  `_create_call_to_py_fn`, cast to the argument's dtype.

  Args:
    dtype: The type of the argument, either a `tf.dtypes.DType` or a
      `computation_types.TensorType` (whose `dtype` is then used).

  Returns:
    An instance of `computation_building_blocks.Lambda` wrapping a function
    that adds 1 to an argument.
  """
  tf_dtype = (
      dtype.dtype
      if isinstance(dtype, computation_types.TensorType) else dtype)
  py_typecheck.check_type(tf_dtype, tf.dtypes.DType)

  plus = computation_building_blocks.Intrinsic(
      intrinsic_defs.GENERIC_PLUS.uri,
      computation_types.FunctionType([tf_dtype, tf_dtype], tf_dtype))
  arg_ref = computation_building_blocks.Reference('arg', tf_dtype)
  one = _create_call_to_py_fn(lambda: tf.cast(tf.constant(1), tf_dtype))
  add_call = computation_building_blocks.Call(
      plus, computation_building_blocks.Tuple([arg_ref, one]))
  return computation_building_blocks.Lambda(arg_ref.name,
                                            arg_ref.type_signature, add_call)
def test_scope_snapshot_called_lambdas(self): comp = computation_building_blocks.Tuple( [computation_building_blocks.Data('test', tf.int32)]) input1 = computation_building_blocks.Reference('input1', comp.type_signature) first_level_call = computation_building_blocks.Call( computation_building_blocks.Lambda('input1', input1.type_signature, input1), comp) input2 = computation_building_blocks.Reference( 'input2', first_level_call.type_signature) second_level_call = computation_building_blocks.Call( computation_building_blocks.Lambda('input2', input2.type_signature, input2), first_level_call) self.assertEqual(str(second_level_call), '(input2 -> input2)((input1 -> input1)(<test>))') global_snapshot = transformations.scope_count_snapshot( second_level_call) self.assertEqual( global_snapshot, { '(input2 -> input2)': { 'input2': 1 }, '(input1 -> input1)': { 'input1': 1 } })
def _create_block_to_calls(call_names, comps):
  r"""Constructs a transformed block computation from `comps`.

  The functions are bound, named by `call_names`, to the local `fn`. The
  block's result is a lambda that calls the i-th function on the i-th element
  of its argument and names each call after the matching entry of
  `call_names`:

                      Block
                     /     \
             fn=Tuple       Lambda(arg)
                |                      \
    [Comp(f1), Comp(f2), ...]           Tuple
                                        |
                                   [Call,                  Call, ...]
                                   /    \                 /    \
                             Sel(0)      Sel(0)     Sel(1)      Sel(1)
                            /           /          /           /
                     Ref(fn)    Ref(arg)    Ref(fn)    Ref(arg)

  This computation represents one of the arguments that should be passed to
  the call of the "transformed" computation.

  Args:
    call_names: A Python list of names.
    comps: A Python list of computations.

  Returns:
    A `computation_building_blocks.Block`.
  """
  functions = computation_building_blocks.Tuple(zip(call_names, comps))
  fn_ref = computation_building_blocks.Reference('fn',
                                                 functions.type_signature)
  parameter_types = [comp.type_signature.parameter for comp in comps]
  arg_ref = computation_building_blocks.Reference('arg', parameter_types)

  def _call_at(position):
    # Calls the function at `position` on the argument at `position`.
    return computation_building_blocks.Call(
        computation_building_blocks.Selection(fn_ref, index=position),
        computation_building_blocks.Selection(arg_ref, index=position))

  named_calls = [
      (name, _call_at(position)) for position, name in enumerate(call_names)
  ]
  lam = computation_building_blocks.Lambda(
      arg_ref.name, arg_ref.type_signature,
      computation_building_blocks.Tuple(named_calls))
  return computation_building_blocks.Block([('fn', functions)], lam)
def _transform(comp):
  """Returns a new transformed computation or `comp`.

  When `_should_transform(comp)` is false, `comp` is returned untouched.
  Otherwise the outer call and the call nested in its second argument are
  merged into a single call of the same intrinsic, whose function argument
  composes the inner and outer functions.

  Args:
    comp: The computation to transform. It is presumably a call whose
      `argument[1]` is itself a call with a two-element argument, as ensured
      by `_should_transform`; verify against that helper.

  Returns:
    A pair `(computation, modified)`, where `modified` is True only if a new
    computation was built.
  """
  if not _should_transform(comp):
    return comp, False

  def _create_block_to_chained_calls(comps):
    r"""Constructs a transformed block computation from `comps`.

                       Block
                      /     \
            [fn=Tuple]       Lambda(arg)
                |                       \
        [Comp(y), Comp(x)]               Call
                                        /    \
                                  Sel(1)      Call
                                 /           /    \
                          Ref(fn)      Sel(0)      Ref(arg)
                                      /
                               Ref(fn)

    (let fn=<y, x> in (arg -> fn[1](fn[0](arg)))

    Args:
      comps: a Python list of computations.

    Returns:
      A `computation_building_blocks.Block`.
    """
    functions = computation_building_blocks.Tuple(comps)
    fn_ref = computation_building_blocks.Reference('fn',
                                                   functions.type_signature)
    arg_ref = computation_building_blocks.Reference(
        'arg', comps[0].type_signature.parameter)
    # Feed the result of each function into the next one in order.
    chained = arg_ref
    for position, _ in enumerate(comps):
      chained = computation_building_blocks.Call(
          computation_building_blocks.Selection(fn_ref, index=position),
          chained)
    lam = computation_building_blocks.Lambda(arg_ref.name,
                                             arg_ref.type_signature, chained)
    return computation_building_blocks.Block([('fn', functions)], lam)

  inner_call = comp.argument[1]
  # Inner function first, then the outer one, so they apply in that order.
  block = _create_block_to_chained_calls(
      (inner_call.argument[0], comp.argument[0]))
  merged_arg = computation_building_blocks.Tuple(
      [block, inner_call.argument[1]])
  merged_intrinsic = computation_building_blocks.Intrinsic(
      comp.function.uri,
      computation_types.FunctionType(merged_arg.type_signature,
                                     comp.function.type_signature.result))
  return computation_building_blocks.Call(merged_intrinsic, merged_arg), True
def construct_binary_operator_with_upcast(type_signature, operator):
  """Constructs a lambda that upcasts its argument and applies `operator`.

  The concept of upcasting is explained further in the docstring for
  `apply_binary_operator_with_upcast`.

  Since a function is being constructed here, e.g. for the body of an
  intrinsic, it must be reducible to TensorFlow. `type_signature` can
  therefore only have named tuple or tensor type elements; federated types
  cannot be handled here in a generic way.

  Args:
    type_signature: Value convertible to `computation_types.NamedTupleType`
      with two elements. Both are of the same type, or the second can be
      upcast to the first as explained in
      `apply_binary_operator_with_upcast`. Both contain only tuples and
      tensors in their type tree.
    operator: Callable defining the operator.

  Returns:
    A `computation_building_blocks.Lambda` encapsulating a function which
    upcasts the second element of its argument and applies the binary
    operator.
  """
  py_typecheck.check_callable(operator)
  type_signature = computation_types.to_type(type_signature)
  _check_generic_operator_type(type_signature)
  arg_ref = computation_building_blocks.Reference('binary_operator_arg',
                                                  type_signature)

  def _pack_into_type(to_pack, type_spec):
    """Pack Tensor value `to_pack` into the nested structure `type_spec`."""
    if isinstance(type_spec, computation_types.NamedTupleType):
      # Recurse into every element, keeping the element names.
      return computation_building_blocks.Tuple([
          (elem_name, _pack_into_type(to_pack, elem_type))
          for elem_name, elem_type in anonymous_tuple.to_elements(type_spec)
      ])
    if isinstance(type_spec, computation_types.TensorType):
      broadcast = (
          computation_constructing_utils
          .construct_tensorflow_to_broadcast_scalar(
              to_pack.type_signature.dtype, type_spec.shape))
      return computation_building_blocks.Call(broadcast, to_pack)
    # Any other type falls through and yields None.

  second = computation_building_blocks.Selection(arg_ref, index=1)
  first = computation_building_blocks.Selection(arg_ref, index=0)
  # Upcast the second operand only when its type differs from the first.
  if not type_utils.are_equivalent_types(first.type_signature,
                                         second.type_signature):
    second = _pack_into_type(second, first.type_signature)

  tf_operator = (
      computation_constructing_utils.construct_tensorflow_binary_operator(
          first.type_signature, operator))
  applied = computation_building_blocks.Call(
      tf_operator, computation_building_blocks.Tuple([first, second]))
  return computation_building_blocks.Lambda(arg_ref.name,
                                            arg_ref.type_signature, applied)
def _create_chain_zipped_values(value):
    r"""Creates a chain of called federated zip with two values.

                  Block--------
                 /             \
    [value=Tuple]               Call
           |                   /    \
    [Comp1,           Intrinsic      Tuple
     Comp2,                          |
     ...]                            [Call,  Sel(n)]
                                     /    \        \
                            Intrinsic      Tuple    Ref(value)
                                           |
                                           [Sel(0),       Sel(1)]
                                                  \             \
                                                   Ref(value)    Ref(value)

    NOTE: This function is intended to be used in conjunction with
    `_create_fn_to_append_chain_zipped_values` and will drop the tuple names.
    The names will be added back to the resulting computation when the zipped
    values are mapped to a function that flattens the chain. This nested
    zip -> flatten structure must be used since the length of a named tuple
    type in the TFF type system is an element of the type proper. That is, a
    named tuple type of length 2 is a different type than a named tuple type of
    length 3; they are not simply items with the same type and different
    values, as would be the case if you were thinking of these as Python
    `list`s. It may be better to think of named tuple types in TFF as more like
    `struct`s.

    Args:
      value: A `computation_building_blocks.ComputationBuildingBlock` with a
        `type_signature` of type `computation_types.NamedTupleType` containing
        at least two elements.

    Returns:
      A `computation_building_blocks.Block` that binds `value` to the name
      'value' and whose result is the chain of zip calls built from
      selections on that reference.

    Raises:
      TypeError: If any of the types do not match.
      ValueError: If `value` does not contain at least two elements.
    """
    py_typecheck.check_type(
        value, computation_building_blocks.ComputationBuildingBlock)
    named_type_signatures = anonymous_tuple.to_elements(value.type_signature)
    length = len(named_type_signatures)
    if length < 2:
        # Report the element count; the message reads "received N elements".
        raise ValueError(
            'Expected a value with at least two elements, received {} elements.'
            .format(length))
    # Bind `value` once in the block and refer to it via `ref` in every
    # selection below.
    ref = computation_building_blocks.Reference('value', value.type_signature)
    symbols = ((ref.name, value),)
    # Fold left: zip(zip(zip(sel_0, sel_1), sel_2), ...).
    result = computation_building_blocks.Selection(ref, index=0)
    for i in range(1, length):
        sel = computation_building_blocks.Selection(ref, index=i)
        values = computation_building_blocks.Tuple((result, sel))
        result = _create_zip_two_values(values)
    return computation_building_blocks.Block(symbols, result)
def test_tensorflow_op_count_doubles_number_of_ops_in_two_tuple(self):
    """Variable count under a 2-tuple is twice that of its repeated element.

    NOTE: despite "op_count" in the name, this test exercises the
    TensorFlow *variable* counters.
    """
    single_comp = _create_two_variable_tensorflow()
    count_in_single = (
        computation_building_block_utils.count_tensorflow_variables_in(
            single_comp))
    # The same computation appears twice in the tuple.
    doubled = computation_building_blocks.Tuple([single_comp, single_comp])
    count_under_tuple = tree_analysis.count_tensorflow_variables_under(doubled)
    self.assertEqual(count_under_tuple, 2 * count_in_single)
def test_propogates_dependence_up_through_tuple(self):
    """A tuple holding the dummy intrinsic is reported as a consuming node."""
    intrinsic = computation_building_blocks.Intrinsic(
        'dummy_intrinsic', tf.int32)
    int_ref = computation_building_blocks.Reference('int', tf.int32)
    enclosing_tuple = computation_building_blocks.Tuple([int_ref, intrinsic])
    consumers = tree_analysis.extract_nodes_consuming(
        enclosing_tuple, dummy_intrinsic_predicate)
    # The tuple itself must be among the extracted nodes.
    self.assertIn(enclosing_tuple, consumers)
def test_simple_reduce_lambda(self): x = computation_building_blocks.Reference('x', [tf.int32]) l = computation_building_blocks.Lambda('x', [tf.int32], x) input_val = computation_building_blocks.Tuple( [computation_building_blocks.Data('test', tf.int32)]) called = computation_building_blocks.Call(l, input_val) self.assertEqual(str(called), '(x -> x)(<test>)') reduced = transformations.replace_called_lambdas_with_block(called) self.assertEqual(str(reduced), '(let x=<test> in x)')
def test_tensorflow_op_count_doubles_number_of_ops_in_two_tuple(self):
    """Op count under a 2-tuple is twice that of its repeated element."""
    identity_comp = computation_constructing_utils.create_compiled_identity(
        tf.int32)
    ops_in_single = computation_building_block_utils.count_tensorflow_ops_in(
        identity_comp)
    # The same compiled identity appears twice in the tuple.
    doubled = computation_building_blocks.Tuple([identity_comp, identity_comp])
    ops_under_tuple = tree_analysis.count_tensorflow_ops_under(doubled)
    self.assertEqual(ops_under_tuple, 2 * ops_in_single)