def test_are_equivalent_types(self):
  """Checks `are_equivalent_types` on tensors with unknown vs. fixed dims."""
  unknown_len = computation_types.TensorType(tf.int32, [None])
  len_ten_a = computation_types.TensorType(tf.int32, [10])
  len_ten_b = computation_types.TensorType(tf.int32, [10])
  # A type is equivalent to itself, and two distinct but structurally
  # identical type objects are equivalent in either order.
  for lhs, rhs in ((unknown_len, unknown_len), (len_ten_a, len_ten_b),
                   (len_ten_b, len_ten_a)):
    self.assertTrue(type_utils.are_equivalent_types(lhs, rhs))
  # An unknown dimension is not equivalent to a fixed one, in either order.
  for lhs, rhs in ((unknown_len, len_ten_a), (len_ten_a, unknown_len)):
    self.assertFalse(type_utils.are_equivalent_types(lhs, rhs))
def test_unpack_args_from_tuple_type(self, tuple_with_args, expected_args,
                                     expected_kwargs):
  """Checks that `unpack_args_from_tuple` splits a tuple into args/kwargs.

  Actual and expected types are both normalized with
  `computation_types.to_type` (a no-op on `computation_types.Type`
  instances), so test cases may specify expectations either as types or as
  anything convertible to one, for positional and keyword entries alike.

  Args:
    tuple_with_args: The tuple passed to `unpack_args_from_tuple`.
    expected_args: Expected positional argument types, in order.
    expected_kwargs: Mapping from keyword name to expected type.
  """
  args, kwargs = function_utils.unpack_args_from_tuple(tuple_with_args)
  self.assertEqual(len(args), len(expected_args))
  for actual, expected in zip(args, expected_args):
    self.assertTrue(
        type_utils.are_equivalent_types(
            computation_types.to_type(actual),
            computation_types.to_type(expected)))
  self.assertEqual(set(kwargs.keys()), set(expected_kwargs.keys()))
  for name, actual in six.iteritems(kwargs):
    # Normalize the expectation too, mirroring the positional loop above;
    # previously only the actual value was converted here.
    self.assertTrue(
        type_utils.are_equivalent_types(
            computation_types.to_type(actual),
            computation_types.to_type(expected_kwargs[name])))
def from_proto(cls, computation_proto):
  """Deserializes `computation_proto` into a building block.

  Args:
    computation_proto: An instance of `pb.Computation`.

  Returns:
    An instance of a class implementing `ComputationBuildingBlock` holding the
    logic deserialized from `computation_proto`.

  Raises:
    NotImplementedError: If `computation_proto` holds a kind of computation
      for which deserialization has not been implemented yet.
    ValueError: If deserialization fails because the argument is invalid,
      e.g. the type derived from the deserialized structure is not equivalent
      to the type declared in `computation_proto`.
  """
  py_typecheck.check_type(computation_proto, pb.Computation)
  which_oneof = computation_proto.WhichOneof('computation')
  deserializer_fn = cls._deserializer_dict.get(which_oneof)
  if deserializer_fn is None:
    raise NotImplementedError(
        'Deserialization for computations of type {} has not been '
        'implemented yet.'.format(which_oneof))
  result = deserializer_fn(computation_proto)
  declared_type = type_serialization.deserialize_type(computation_proto.type)
  derived_type = result.type_signature
  if type_utils.are_equivalent_types(derived_type, declared_type):
    return result
  raise ValueError(
      'The type {} derived from the computation structure does not '
      'match the type {} declared in its signature'.format(
          str(derived_type), str(declared_type)))
def construct_binary_operator_with_upcast(type_signature, operator):
  """Constructs lambda upcasting its argument and applying `operator`.

  The concept of upcasting is explained further in the docstring for
  `apply_binary_operator_with_upcast`.

  Notice that since we are constructing a function here, e.g. for the body
  of an intrinsic, the function we are constructing must be reducible to
  TensorFlow. Therefore `type_signature` can only have named tuple or tensor
  type elements; that is, we cannot handle federated types here in a generic
  way.

  Args:
    type_signature: Value convertible to `computation_types.NamedTupleType`,
      with two elements, both of the same type or the second able to be upcast
      to the first, as explained in `apply_binary_operator_with_upcast`, and
      both containing only tuples and tensors in their type tree.
    operator: Callable defining the operator.

  Returns:
    A `computation_building_blocks.Lambda` encapsulating a function which
    upcasts the second element of its argument and applies the binary
    operator.

  Raises:
    TypeError: If the second element must be packed into a first-element type
      that contains anything other than named tuples and tensors.
  """
  py_typecheck.check_callable(operator)
  type_signature = computation_types.to_type(type_signature)
  _check_generic_operator_type(type_signature)
  ref_to_arg = computation_building_blocks.Reference('binary_operator_arg',
                                                     type_signature)

  def _pack_into_type(to_pack, type_spec):
    """Pack Tensor value `to_pack` into the nested structure `type_spec`.

    Raises:
      TypeError: If `type_spec` contains anything other than named tuples
        and tensors.
    """
    if isinstance(type_spec, computation_types.NamedTupleType):
      elems = anonymous_tuple.to_elements(type_spec)
      packed_elems = [(elem_name, _pack_into_type(to_pack, elem_type))
                      for elem_name, elem_type in elems]
      return computation_building_blocks.Tuple(packed_elems)
    elif isinstance(type_spec, computation_types.TensorType):
      expand_fn = computation_constructing_utils.construct_tensorflow_to_broadcast_scalar(
          to_pack.type_signature.dtype, type_spec.shape)
      return computation_building_blocks.Call(expand_fn, to_pack)
    # Fail here rather than implicitly returning `None`, which would only
    # surface later as an obscure failure while building the argument tuple.
    raise TypeError(
        'Cannot upcast a value to type {}; only named tuples and tensors '
        'are supported.'.format(type_spec))

  y_ref = computation_building_blocks.Selection(ref_to_arg, index=1)
  first_arg = computation_building_blocks.Selection(ref_to_arg, index=0)

  # The second element is only packed (broadcast) when its type differs from
  # the first element's type; otherwise it is used as-is.
  if type_utils.are_equivalent_types(first_arg.type_signature,
                                     y_ref.type_signature):
    second_arg = y_ref
  else:
    second_arg = _pack_into_type(y_ref, first_arg.type_signature)

  fn = computation_constructing_utils.construct_tensorflow_binary_operator(
      first_arg.type_signature, operator)
  packed = computation_building_blocks.Tuple([first_arg, second_arg])
  operated = computation_building_blocks.Call(fn, packed)
  lambda_encapsulating_op = computation_building_blocks.Lambda(
      ref_to_arg.name, ref_to_arg.type_signature, operated)
  return lambda_encapsulating_op
async def create_value(self, value, type_spec=None):
  """Embeds `value` in the target executor, reusing a cached copy if possible.

  Values are cached under a key derived from `value` and `type_spec`. A cache
  hit is reused only if its cached value exists and its type is equivalent
  to `type_spec`; otherwise a fresh value is created in the target executor.

  Args:
    value: The payload to embed. A `computation_impl.ComputationImpl` is
      replaced by its proto before embedding.
    type_spec: An optional `computation_types.Type`, or something convertible
      to it, describing `value`.

  Returns:
    The `CachedValue` wrapping the target executor's value.

  Raises:
    RuntimeError: If the cache lookup fails with a `TypeError` (e.g. the key
      derived from `value` and `type_spec` is unhashable).
    Exception: Whatever the target executor raised while creating the value;
      in that case the entire cache is cleared before re-raising.
  """
  type_spec = computation_types.to_type(type_spec)
  if isinstance(value, computation_impl.ComputationImpl):
    return await self.create_value(
        computation_impl.ComputationImpl.get_proto(value),
        type_utils.reconcile_value_with_type_spec(value, type_spec))
  py_typecheck.check_type(type_spec, computation_types.Type)
  hashable_key = _get_hashable_key(value, type_spec)
  try:
    identifier = self._cache[hashable_key]
  except KeyError:
    identifier = None
  except TypeError as err:
    raise RuntimeError(
        'Failed to perform a hash table lookup with a value of Python '
        'type {} and TFF type {}, and payload {}: {}'.format(
            py_typecheck.type_string(type(value)), type_spec, value,
            err)) from err

  cached_value = None
  if isinstance(identifier, CachedValueIdentifier):
    cached_value = self._cache.get(identifier)
    # It may be that the same payload appeared with a mismatching type spec,
    # which may be a legitimate use case if (as it happens) the payload alone
    # does not uniquely determine the type, so we simply opt not to reuse the
    # cached value and fall back on the regular behavior. An identifier with
    # no value behind it is likewise treated as a miss: there is no target
    # future to await for it, so a fresh one must be created.
    if cached_value is None or (
        type_spec is not None and not type_utils.are_equivalent_types(
            cached_value.type_signature, type_spec)):
      cached_value = None

  if cached_value is None:
    # Cache miss: allocate a new identifier and start embedding the value in
    # the target executor.
    self._num_values_created = self._num_values_created + 1
    identifier = CachedValueIdentifier(str(self._num_values_created))
    self._cache[hashable_key] = identifier
    target_future = asyncio.ensure_future(
        self._target_executor.create_value(value, type_spec))
    cached_value = CachedValue(identifier, hashable_key, type_spec,
                               target_future)
    self._cache[identifier] = cached_value

  try:
    await cached_value.target_future
  except Exception:
    # Invalidate the entire cache if the inner executor had an exception.
    # TODO(b/145514490): This is a bit heavy handed, there maybe caches where
    # only the current cache item needs to be invalidated; however this
    # currently only occurs when an inner RemoteExecutor has the backend go
    # down.
    self._cache = {}
    raise
  # No type check is necessary here; we have either checked
  # `type_utils.are_equivalent_types` or just constructed `target_value`
  # explicitly with `type_spec`.
  return cached_value
def test_construct_setattr_named_tuple_type_leaves_type_signature_unchanged(
    self):
  """The setattr lambda's parameter and result types must be equivalent."""
  tuple_type = computation_types.NamedTupleType([('a', tf.int32),
                                                 (None, tf.float32),
                                                 ('b', tf.bool)])
  new_value = computation_building_blocks.Data('x', tf.int32)
  setattr_lambda = (
      computation_constructing_utils.construct_named_tuple_setattr_lambda(
          tuple_type, 'a', new_value))
  fn_type = setattr_lambda.type_signature
  self.assertTrue(
      type_utils.are_equivalent_types(fn_type.parameter, fn_type.result))
def test_federated_setattr_call_leaves_type_signatures_alone(self, placement):
  """The federated setattr call must keep the federated value's type."""
  member_type = computation_types.NamedTupleType([('a', tf.int32),
                                                  (None, tf.float32),
                                                  ('b', tf.bool)])
  federated_type = computation_types.FederatedType(member_type, placement)
  federated_comp = computation_building_blocks.Data('federated_comp',
                                                    federated_type)
  new_value = computation_building_blocks.Data('x', tf.int32)
  setattr_call = (
      computation_constructing_utils.construct_federated_setattr_call(
          federated_comp, 'a', new_value))
  self.assertTrue(
      type_utils.are_equivalent_types(setattr_call.type_signature,
                                      federated_comp.type_signature))
def _serialize_deserialize_roundtrip_test(self, type_list):
  """Checks that each type survives a serialize/deserialize round trip.

  Each entry is serialized, deserialized, and serialized again. The original
  and round-tripped types must have identical reprs and be equivalent, and
  the two protos must have identical reprs.

  Args:
    type_list: A list of instances of `computation_types.Type` or things
      convertible to it.
  """
  for spec in type_list:
    original_type = computation_types.to_type(spec)
    original_proto = type_serialization.serialize_type(original_type)
    roundtrip_type = type_serialization.deserialize_type(original_proto)
    roundtrip_proto = type_serialization.serialize_type(roundtrip_type)
    self.assertEqual(repr(original_type), repr(roundtrip_type))
    self.assertEqual(repr(original_proto), repr(roundtrip_proto))
    self.assertTrue(
        type_utils.are_equivalent_types(original_type, roundtrip_type))
def __add__(self, other):
  """Adds `other` to this value using the `GENERIC_PLUS` intrinsic.

  Args:
    other: Anything accepted by `to_value`; its type must be equivalent to
      this value's type.

  Returns:
    A `ValueImpl` wrapping a call of `GENERIC_PLUS` on the pair
    `(self, other)`.

  Raises:
    TypeError: If the two operand types are not equivalent.
  """
  other = to_value(other, None, self._context_stack)
  operand_type = self.type_signature
  if not type_utils.are_equivalent_types(operand_type, other.type_signature):
    raise TypeError('Cannot add {} and {}.'.format(
        str(operand_type), str(other.type_signature)))
  plus_type = computation_types.FunctionType([operand_type, operand_type],
                                             operand_type)
  plus_intrinsic = computation_building_blocks.Intrinsic(
      intrinsic_defs.GENERIC_PLUS.uri, plus_type)
  operand_pair = ValueImpl.get_comp(
      to_value([self, other], None, self._context_stack))
  return ValueImpl(
      computation_building_blocks.Call(plus_intrinsic, operand_pair),
      self._context_stack)
async def create_value(self, value, type_spec=None):
  """Embeds `value` in the target executor, reusing a cached copy if possible.

  Values are cached under a key derived from `value` and `type_spec`. A cache
  hit is reused only if its cached value exists and its type is equivalent
  to `type_spec`; otherwise a fresh value is created in the target executor.

  Args:
    value: The payload to embed. A `computation_impl.ComputationImpl` is
      replaced by its proto before embedding.
    type_spec: An optional `computation_types.Type`, or something convertible
      to it, describing `value`.

  Returns:
    The `CachedValue` wrapping the target executor's value.

  Raises:
    RuntimeError: If the cache lookup fails with a `TypeError` (e.g. the key
      derived from `value` and `type_spec` is unhashable).
  """
  type_spec = computation_types.to_type(type_spec)
  if isinstance(value, computation_impl.ComputationImpl):
    return await self.create_value(
        computation_impl.ComputationImpl.get_proto(value),
        type_utils.reconcile_value_with_type_spec(value, type_spec))
  py_typecheck.check_type(type_spec, computation_types.Type)
  hashable_key = _get_hashable_key(value, type_spec)
  try:
    identifier = self._cache[hashable_key]
  except KeyError:
    identifier = None
  except TypeError as err:
    raise RuntimeError(
        'Failed to perform a hash table lookup with a value of Python '
        'type {} and TFF type {}, and payload {}: {}'.format(
            py_typecheck.type_string(type(value)), type_spec, value,
            err)) from err

  cached_value = None
  if isinstance(identifier, CachedValueIdentifier):
    cached_value = self._cache.get(identifier)
    # It may be that the same payload appeared with a mismatching type spec,
    # which may be a legitimate use case if (as it happens) the payload alone
    # does not uniquely determine the type, so we simply opt not to reuse the
    # cached value and fall back on the regular behavior. A missing cached
    # value is checked first so its `type_signature` is never dereferenced,
    # and it is likewise treated as a miss.
    if cached_value is None or (
        type_spec is not None and not type_utils.are_equivalent_types(
            cached_value.type_signature, type_spec)):
      cached_value = None

  if cached_value is None:
    # Cache miss: allocate a new identifier and start embedding the value in
    # the target executor.
    self._num_values_created = self._num_values_created + 1
    identifier = CachedValueIdentifier(str(self._num_values_created))
    self._cache[hashable_key] = identifier
    target_future = asyncio.ensure_future(
        self._target_executor.create_value(value, type_spec))
    cached_value = CachedValue(identifier, hashable_key, type_spec,
                               target_future)
    self._cache[identifier] = cached_value

  target_value = await cached_value.target_future
  type_utils.check_assignable_from(type_spec, target_value.type_signature)
  return cached_value
def __eq__(self, other):
  """Base class equality checks names and values equal.

  Trackers are equal when their names match and either both values are
  building blocks with identical compact representations and equivalent type
  signatures, or the two values are the very same object.
  """
  # TODO(b/130890785): Delegate value-checking to
  # `building_blocks.ComputationBuildingBlock`.
  if self is other:
    return True
  if not isinstance(other, BoundVariableTracker):
    return NotImplemented
  if self.name != other.name:
    return False
  block_cls = building_blocks.ComputationBuildingBlock
  both_blocks = (isinstance(self.value, block_cls) and
                 isinstance(other.value, block_cls))
  if not both_blocks:
    return self.value is other.value
  same_repr = (self.value.compact_representation() ==
               other.value.compact_representation())
  return same_repr and type_utils.are_equivalent_types(
      self.value.type_signature, other.value.type_signature)
def federated_secure_sum(self, value, bitwidth):
  """Implements `federated_secure_sum` as defined in `api/intrinsics.py`.

  Args:
    value: A value to be summed; it is converted and required to be placed at
      `placements.CLIENTS`, with a member type made only of integers.
    bitwidth: A value whose type must be equivalent to the member type of
      `value`.

  Returns:
    A `value_impl.ValueImpl` wrapping the secure-sum building block.

  Raises:
    TypeError: If the member type of `value` and the type of `bitwidth` are
      not equivalent.
  """
  ctx = self._context_stack
  value = value_impl.to_value(value, None, ctx)
  value = value_utils.ensure_federated_value(value, placements.CLIENTS,
                                             'value to be summed')
  type_utils.check_is_structure_of_integers(value.type_signature)
  bitwidth = value_impl.to_value(bitwidth, None, ctx)
  member_type = value.type_signature.member
  bitwidth_type = bitwidth.type_signature
  if not type_utils.are_equivalent_types(member_type, bitwidth_type):
    raise TypeError(
        'Expected `federated_secure_sum` parameters `value` and `bitwidth` '
        'to have the same structure. Found `value` of `{}` and `bitwidth` of `{}`'
        .format(member_type, bitwidth_type))
  secure_sum_comp = building_block_factory.create_federated_secure_sum(
      value_impl.ValueImpl.get_comp(value),
      value_impl.ValueImpl.get_comp(bitwidth))
  return value_impl.ValueImpl(secure_sum_comp, ctx)
def _wrap(func, parameter_type, wrapper_fn):
  """Wrap a given `func` with a given `parameter_type` using `wrapper_fn`.

  This method does not handle the multiple modes of usage as wrapper/decorator,
  as those are handled by ComputationWrapper below. It focuses on the simple
  case with a function/defun (always present) and either a valid parameter
  type or an indication that there's no parameter (None).

  The only ambiguity left to resolve is whether `func` should be immediately
  wrapped, or treated as a polymorphic callable to be wrapped upon invocation
  based on actual parameter types. The determination is based on the presence
  or absence of parameters in the declaration of `func`. In order to be
  treated as a concrete no-argument computation, `func` shouldn't declare any
  arguments (even with default values).

  Args:
    func: The function or defun to wrap as a computation.
    parameter_type: The parameter type accepted by the computation, or None
      if there is no parameter.
    wrapper_fn: The Python callable that performs actual wrapping. It must
      accept two arguments, and optional third `name`. The first argument will
      be a Python function that takes either zero parameters if the
      computation is to be a no-parameter computation, or exactly one
      parameter if the computation does have a parameter. The second argument
      will be either None for a no-parameter computation, or the type of the
      computation's parameter (an instance of types.Type) if the computation
      has one. The third, optional parameter `name` is the optional name of
      the function that is being wrapped (only for debugging purposes). The
      object to be returned by this function should be an instance of a
      ConcreteFunction.

  Returns:
    Either the result of wrapping (an object that represents the computation),
    or a polymorphic callable that performs wrapping upon invocation based on
    argument types.

  Raises:
    TypeError: if the arguments are of the wrong types, or the `wrapper_fn`
      constructs something that isn't a ConcreteFunction.
  """
  try:
    func_name = func.__name__
  except AttributeError:
    func_name = None
  argspec = func_utils.get_argspec(func)
  parameter_type = computation_types.to_type(parameter_type)
  # `None` is the documented "no parameter" marker, so test for it explicitly
  # rather than by truthiness: a type object that happens to be falsy
  # (possibly an empty tuple type, if tuple types support `len()`) is still a
  # concrete parameter type and must not trigger the polymorphic path.
  if parameter_type is None:
    if argspec.args or argspec.varargs or argspec.keywords:
      # There is no TFF type specification, and the function/defun declares
      # parameters. Create a polymorphic template.
      def _wrap_polymorphic(wrapper_fn, func, parameter_type, name=func_name):
        return wrapper_fn(
            func_utils.wrap_as_zero_or_one_arg_callable(
                func, parameter_type, unpack=True),
            parameter_type,
            name=name)

      polymorphic_fn = func_utils.PolymorphicFunction(
          lambda pt: _wrap_polymorphic(wrapper_fn, func, pt))
      # When applying a decorator, the __doc__ attribute with the
      # documentation in triple-quotes is not automatically transferred from
      # the function on which it was applied to the wrapped object, so we must
      # transfer it here explicitly.
      polymorphic_fn.__doc__ = getattr(func, '__doc__', None)
      return polymorphic_fn

  concrete_fn = wrapper_fn(
      func_utils.wrap_as_zero_or_one_arg_callable(func, parameter_type),
      parameter_type,
      name=func_name)
  py_typecheck.check_type(concrete_fn, func_utils.ConcreteFunction,
                          'value returned by the wrapper')
  if not type_utils.are_equivalent_types(
      concrete_fn.type_signature.parameter, parameter_type):
    raise TypeError(
        'Expected a concrete function that takes parameter {}, got one '
        'that takes {}.'.format(
            str(parameter_type), str(concrete_fn.type_signature.parameter)))
  # When applying a decorator, the __doc__ attribute with the documentation
  # in triple-quotes is not automatically transferred from the function on
  # which it was applied to the wrapped object, so we must transfer it here
  # explicitly.
  concrete_fn.__doc__ = getattr(func, '__doc__', None)
  return concrete_fn
def embed_tensorflow_computation(comp, type_spec=None, device=None):
  """Embeds a TensorFlow computation for use in the eager context.

  Args:
    comp: An instance of `pb.Computation`.
    type_spec: An optional `tff.Type` instance or something convertible to it.
      If given, it must be equivalent to the type declared in `comp`.
    device: An optional device name. Placement on a specific device is not
      supported yet: any value other than `None` raises
      `NotImplementedError`.

  Returns:
    Either a one-argument or a zero-argument callable that executes the
    computation in eager mode.

  Raises:
    TypeError: If arguments are of the wrong types, e.g., if `comp` is not a
      TensorFlow computation, or if `type_spec` is not equivalent to the type
      declared in `comp`.
    NotImplementedError: If `device` is not `None`.
  """
  # TODO(b/134543154): Decide whether this belongs in `graph_utils.py` since
  # it deals exclusively with eager mode. Incubate here, and potentially move
  # there, once stable.
  if device is not None:
    raise NotImplementedError(
        'Unable to embed TF code on a specific device.')
  py_typecheck.check_type(comp, pb.Computation)
  comp_type = type_serialization.deserialize_type(comp.type)
  type_spec = computation_types.to_type(type_spec)
  if type_spec is not None:
    if not type_utils.are_equivalent_types(type_spec, comp_type):
      raise TypeError(
          'Expected a computation of type {}, got {}.'.format(
              str(type_spec), str(comp_type)))
  else:
    type_spec = comp_type
  which_computation = comp.WhichOneof('computation')
  if which_computation != 'tensorflow':
    raise TypeError('Expected a TensorFlow computation, found {}.'.format(
        which_computation))

  # A non-function type is treated as a no-parameter computation whose result
  # type is `type_spec` itself.
  if isinstance(type_spec, computation_types.FunctionType):
    param_type = type_spec.parameter
    result_type = type_spec.result
  else:
    param_type = None
    result_type = type_spec

  if param_type is not None:
    input_tensor_names = graph_utils.extract_tensor_names_from_binding(
        comp.tensorflow.parameter)
  else:
    input_tensor_names = []

  output_tensor_names = graph_utils.extract_tensor_names_from_binding(
      comp.tensorflow.result)

  # Imports the serialized graph, feeding `args` positionally into the
  # parameter tensors and returning the result tensors.
  def function_to_wrap(*args):  # pylint: disable=missing-docstring
    if len(args) != len(input_tensor_names):
      raise RuntimeError('Expected {} arguments, found {}.'.format(
          str(len(input_tensor_names)), str(len(args))))
    graph_def = serialization_utils.unpack_graph_def(
        comp.tensorflow.graph_def)
    init_op = comp.tensorflow.initialize_op
    init_names = [init_op] if init_op else []
    returned_elements = tf.import_graph_def(
        graph_merge.uniquify_shared_names(graph_def),
        input_map=dict(zip(input_tensor_names, args)),
        return_elements=output_tensor_names + init_names)
    if init_names:
      # The init op, if any, was requested last; route every output through
      # `tf.identity` under a control dependency on it so that initialization
      # runs before any result is produced.
      with tf.control_dependencies([returned_elements[-1]]):
        return [tf.identity(x) for x in returned_elements[0:-1]]
    else:
      return returned_elements

  # Flat input signature for `wrap_function`, plus one converter per
  # flattened parameter: tensors pass through unchanged, while sequences are
  # passed as variant tensors.
  signature = []
  param_fns = []
  if param_type is not None:
    for spec in anonymous_tuple.flatten(type_spec.parameter):
      if isinstance(spec, computation_types.TensorType):
        signature.append(tf.TensorSpec(spec.shape, spec.dtype))
        param_fns.append(lambda x: x)
      else:
        py_typecheck.check_type(spec, computation_types.SequenceType)
        signature.append(tf.TensorSpec([], tf.variant))
        param_fns.append(tf.data.experimental.to_variant)

  wrapped_fn = tf.compat.v1.wrap_function(function_to_wrap, signature)

  # One converter per flattened result element: tensors pass through, and
  # variant tensors are turned back into datasets using the element structure
  # derived from the sequence type.
  result_fns = []
  for spec in anonymous_tuple.flatten(result_type):
    if isinstance(spec, computation_types.TensorType):
      result_fns.append(lambda x: x)
    else:
      py_typecheck.check_type(spec, computation_types.SequenceType)
      structure = type_utils.type_to_tf_structure(spec.element)

      # `structure` is bound as a default argument so each converter keeps
      # the value from its own loop iteration.
      def fn(x, structure=structure):
        return tf.data.experimental.from_variant(x, structure)

      result_fns.append(fn)

  # Flattens `arg` (if any), converts each part, calls the wrapped function,
  # then converts the results and repacks them into `result_type`.
  def _fn_to_return(arg, param_fns, wrapped_fn):  # pylint:disable=missing-docstring
    param_elements = []
    if arg is not None:
      arg_parts = anonymous_tuple.flatten(arg)
      if len(arg_parts) != len(param_fns):
        raise RuntimeError('Expected {} arguments, found {}.'.format(
            str(len(param_fns)), str(len(arg_parts))))
      for arg_part, param_fn in zip(arg_parts, param_fns):
        param_elements.append(param_fn(arg_part))
    result_parts = wrapped_fn(*param_elements)
    result_elements = []
    for result_part, result_fn in zip(result_parts, result_fns):
      result_elements.append(result_fn(result_part))
    return anonymous_tuple.pack_sequence_as(result_type, result_elements)

  fn_to_return = lambda arg, p=param_fns, w=wrapped_fn: _fn_to_return(
      arg, p, w)
  # Expose a one-argument callable for computations with a parameter and a
  # zero-argument callable otherwise.
  if param_type is not None:
    return lambda arg: fn_to_return(arg)  # pylint: disable=unnecessary-lambda
  else:
    return lambda: fn_to_return(None)
def to_representation_for_type(value, type_spec=None, device=None):
  """Verifies or converts the `value` to an eager object matching `type_spec`.

  WARNING: This function is only partially implemented. It does not support
  data sets at this point.

  The output of this function is always an eager tensor, eager dataset, a
  representation of a TensorFlow computation, or a nested structure of those
  that matches `type_spec`, and when `device` has been specified, everything
  is placed on that device on a best-effort basis.

  TensorFlow computations are represented here as zero- or one-argument
  Python callables that accept their entire argument bundle as a single
  Python object.

  Args:
    value: The raw representation of a value to compare against `type_spec`
      and potentially to be converted.
    type_spec: An instance of `tff.Type`, can be `None` for values that
      derive from `typed_object.TypedObject`.
    device: The optional device to place the value on (for tensor-level
      values).

  Returns:
    Either `value` itself, or a modified version of it.

  Raises:
    TypeError: If the `value` is not compatible with `type_spec`.
    ValueError: If no type can be determined for `value`, i.e. `type_spec` is
      `None` and `value` is not a `typed_object.TypedObject`.
  """
  if device is not None:
    # Re-enter without a device, inside the device scope.
    py_typecheck.check_type(device, six.string_types)
    with tf.device(device):
      return to_representation_for_type(value, type_spec=type_spec, device=None)

  type_spec = computation_types.to_type(type_spec)
  if isinstance(value, typed_object.TypedObject):
    if type_spec is None:
      type_spec = value.type_signature
    elif not type_utils.are_equivalent_types(value.type_signature, type_spec):
      raise TypeError('Expected a value of type {}, found {}.'.format(
          str(type_spec), str(value.type_signature)))
  if type_spec is None:
    raise ValueError(
        'Cannot derive an eager representation for a value of an unknown type.'
    )

  # Values that are already eager, foreign-executor values, and computations.
  if isinstance(value, EagerValue):
    return value.internal_representation
  if isinstance(value, executor_value_base.ExecutorValue):
    raise TypeError(
        'Cannot accept a value embedded within a non-eager executor.')
  if isinstance(value, computation_base.Computation):
    computation_proto = computation_impl.ComputationImpl.get_proto(value)
    return to_representation_for_type(computation_proto, type_spec, device)
  if isinstance(value, pb.Computation):
    return embed_tensorflow_computation(value, type_spec, device)

  # Otherwise dispatch on the kind of type expected.
  if isinstance(type_spec, computation_types.TensorType):
    if isinstance(value, tf.Tensor):
      tensor = value
    else:
      tensor = tf.constant(value, type_spec.dtype, type_spec.shape)
    tensor_type = computation_types.TensorType(tensor.dtype.base_dtype,
                                               tensor.shape)
    if not type_utils.is_assignable_from(type_spec, tensor_type):
      raise TypeError(
          'The apparent type {} of a tensor {} does not match the expected '
          'type {}.'.format(str(tensor_type), str(tensor), str(type_spec)))
    return tensor

  if isinstance(type_spec, computation_types.NamedTupleType):
    type_elements = anonymous_tuple.to_elements(type_spec)
    value_elements = anonymous_tuple.to_elements(
        anonymous_tuple.from_container(value))
    if len(type_elements) != len(value_elements):
      raise TypeError(
          'Expected a {}-element tuple, found {} elements.'.format(
              str(len(type_elements)), str(len(value_elements))))
    converted = []
    for (type_name, member_type), (value_name, member_value) in zip(
        type_elements, value_elements):
      if type_name != value_name:
        raise TypeError(
            'Mismatching element names in type vs. value: {} vs. {}.'.format(
                type_name, value_name))
      member_repr = to_representation_for_type(member_value, member_type,
                                               device)
      converted.append((type_name, member_repr))
    return anonymous_tuple.AnonymousTuple(converted)

  if isinstance(type_spec, computation_types.SequenceType):
    py_typecheck.check_type(value,
                            (tf.data.Dataset, tf.compat.v1.data.Dataset,
                             tf.compat.v2.data.Dataset))
    dataset_element_type = type_utils.tf_dtypes_and_shapes_to_type(
        tf.compat.v1.data.get_output_types(value),
        tf.compat.v1.data.get_output_shapes(value))
    dataset_type = computation_types.SequenceType(dataset_element_type)
    if not type_utils.are_equivalent_types(dataset_type, type_spec):
      raise TypeError('Expected a value of type {}, found {}.'.format(
          str(type_spec), str(dataset_type)))
    return value

  raise TypeError('Unexpected type {}.'.format(str(type_spec)))
def embed_tensorflow_computation(comp, type_spec=None, device=None):
  """Embeds a TensorFlow computation for use in the eager context.

  Args:
    comp: An instance of `pb.Computation`.
    type_spec: An optional `tff.Type` instance or something convertible to it.
      If given, it must be equivalent to the type declared in `comp`.
    device: An optional device name. If given, both the graph import and each
      invocation of the returned callable run under `tf.device(device)`.

  Returns:
    Either a one-argument or a zero-argument callable that executes the
    computation in eager mode.

  Raises:
    TypeError: If arguments are of the wrong types, e.g., if `comp` is not a
      TensorFlow computation, or if `type_spec` is not equivalent to the type
      declared in `comp`.
  """
  # TODO(b/134543154): Decide whether this belongs in `tensorflow_utils.py`
  # since it deals exclusively with eager mode. Incubate here, and potentially
  # move there, once stable.
  py_typecheck.check_type(comp, pb.Computation)
  comp_type = type_serialization.deserialize_type(comp.type)
  type_spec = computation_types.to_type(type_spec)
  if type_spec is not None:
    if not type_utils.are_equivalent_types(type_spec, comp_type):
      raise TypeError(
          'Expected a computation of type {}, got {}.'.format(
              type_spec, comp_type))
  else:
    type_spec = comp_type
  which_computation = comp.WhichOneof('computation')
  if which_computation != 'tensorflow':
    raise TypeError('Expected a TensorFlow computation, found {}.'.format(
        which_computation))

  # A non-function type is treated as a no-parameter computation whose result
  # type is `type_spec` itself.
  if isinstance(type_spec, computation_types.FunctionType):
    param_type = type_spec.parameter
    result_type = type_spec.result
  else:
    param_type = None
    result_type = type_spec

  if param_type is not None:
    input_tensor_names = tensorflow_utils.extract_tensor_names_from_binding(
        comp.tensorflow.parameter)
  else:
    input_tensor_names = []

  output_tensor_names = tensorflow_utils.extract_tensor_names_from_binding(
      comp.tensorflow.result)

  # Imports the serialized graph, feeding `args` positionally into the
  # parameter tensors and returning the result tensors. When the computation
  # has an init op, `add_control_deps_for_init_op` rewrites the graph_def
  # first (presumably so outputs depend on initialization; see that helper).
  def function_to_wrap(*args):  # pylint: disable=missing-docstring
    if len(args) != len(input_tensor_names):
      raise RuntimeError('Expected {} arguments, found {}.'.format(
          len(input_tensor_names), len(args)))
    graph_def = serialization_utils.unpack_graph_def(
        comp.tensorflow.graph_def)
    init_op = comp.tensorflow.initialize_op
    if init_op:
      graph_def = tensorflow_utils.add_control_deps_for_init_op(
          graph_def, init_op)

    def _import_fn():
      return tf.import_graph_def(
          graph_merge.uniquify_shared_names(graph_def),
          input_map=dict(list(zip(input_tensor_names, args))),
          return_elements=output_tensor_names)

    if device is not None:
      with tf.device(device):
        return _import_fn()
    else:
      return _import_fn()

  # Flat input signature for `wrap_function`, plus one converter per
  # flattened parameter: tensors pass through unchanged, while sequences are
  # passed as variant tensors.
  signature = []
  param_fns = []
  if param_type is not None:
    for spec in anonymous_tuple.flatten(type_spec.parameter):
      if isinstance(spec, computation_types.TensorType):
        signature.append(tf.TensorSpec(spec.shape, spec.dtype))
        param_fns.append(lambda x: x)
      else:
        py_typecheck.check_type(spec, computation_types.SequenceType)
        signature.append(tf.TensorSpec([], tf.variant))
        param_fns.append(tf.data.experimental.to_variant)

  wrapped_fn = tf.compat.v1.wrap_function(function_to_wrap, signature)

  # One converter per flattened result element: tensors pass through, and
  # variant tensors are turned back into datasets using the element structure
  # derived from the sequence type.
  result_fns = []
  for spec in anonymous_tuple.flatten(result_type):
    if isinstance(spec, computation_types.TensorType):
      result_fns.append(lambda x: x)
    else:
      py_typecheck.check_type(spec, computation_types.SequenceType)
      structure = type_utils.type_to_tf_structure(spec.element)

      # `structure` is bound as a default argument so each converter keeps
      # the value from its own loop iteration.
      def fn(x, structure=structure):
        return tf.data.experimental.from_variant(x, structure)

      result_fns.append(fn)

  # Flattens `arg` (if any), converts each part, calls the wrapped function,
  # destroys variable resources it created, then converts the results and
  # repacks them into `result_type`.
  def _fn_to_return(arg, param_fns, wrapped_fn):  # pylint:disable=missing-docstring
    param_elements = []
    if arg is not None:
      arg_parts = anonymous_tuple.flatten(arg)
      if len(arg_parts) != len(param_fns):
        raise RuntimeError('Expected {} arguments, found {}.'.format(
            len(param_fns), len(arg_parts)))
      for arg_part, param_fn in zip(arg_parts, param_fns):
        param_elements.append(param_fn(arg_part))
    result_parts = wrapped_fn(*param_elements)

    # There is a tf.wrap_function(...) issue b/144127474 that variables created
    # from tf.import_graph_def(...) inside tf.wrap_function(...) are not
    # destroyed. So get all the variables from `wrapped_fn` and destroy
    # manually.
    # TODO(b/144127474): Remove this manual cleanup once tf.wrap_function(...)
    # is fixed.
    # NOTE(review): the graph scan and `prune` below run on every invocation
    # of the returned callable.
    resources = []
    for op in wrapped_fn.graph.get_operations():
      if op.type == 'VarHandleOp':
        resources += op.outputs
    if resources:
      for resource in wrapped_fn.prune(feeds={}, fetches=resources)():
        tf.raw_ops.DestroyResourceOp(resource=resource)

    result_elements = []
    for result_part, result_fn in zip(result_parts, result_fns):
      result_elements.append(result_fn(result_part))
    return anonymous_tuple.pack_sequence_as(result_type, result_elements)

  fn_to_return = lambda arg, p=param_fns, w=wrapped_fn: _fn_to_return(
      arg, p, w)

  # When a device is requested, run every invocation under that device scope.
  if device is not None:
    old_fn_to_return = fn_to_return

    # pylint: disable=function-redefined
    def fn_to_return(x):
      with tf.device(device):
        return old_fn_to_return(x)

    # pylint: enable=function-redefined

  # Expose a one-argument callable for computations with a parameter and a
  # zero-argument callable otherwise.
  if param_type is not None:
    return lambda arg: fn_to_return(arg)  # pylint: disable=unnecessary-lambda
  else:
    return lambda: fn_to_return(None)
def _wrap(fn, parameter_type, wrapper_fn):
  """Wraps a possibly-polymorphic `fn` in `wrapper_fn`.

  If `parameter_type` is `None` and `fn` takes any arguments (even with
  default values), `fn` is inferred to be polymorphic and won't be passed to
  `wrapper_fn` until invocation time (when concrete parameter types are
  available).

  `wrapper_fn` must accept three positional arguments and one defaulted
  argument `name`:

  * `target_fn`, the Python function to be wrapped.

  * `parameter_type`, the optional type of the computation's
    parameter (an instance of `computation_types.Type`).

  * `unpack`, an argument which will be passed on to
    `function_utils.wrap_as_zero_or_one_arg_callable` when wrapping
    `target_fn`. See that function for details.

  * Optional `name`, the name of the function that is being wrapped (only for
    debugging purposes).

  Args:
    fn: The function or defun to wrap as a computation.
    parameter_type: Optional type of any arguments to `fn`.
    wrapper_fn: The Python callable that performs actual wrapping. The object
      to be returned by this function should be an instance of a
      `ConcreteFunction`.

  Returns:
    Either the result of wrapping (an object that represents the computation),
    or a polymorphic callable that performs wrapping upon invocation based on
    argument types. The returned function still may accept multiple
    arguments (it has not yet had
    `function_utils.wrap_as_zero_or_one_arg_callable` applied to it).

  Raises:
    TypeError: if the arguments are of the wrong types, or the `wrapper_fn`
      constructs something that isn't a ConcreteFunction.
  """
  try:
    fn_name = fn.__name__
  except AttributeError:
    fn_name = None
  signature = function_utils.get_signature(fn)
  parameter_type = computation_types.to_type(parameter_type)
  if parameter_type is None and signature.parameters:
    # There is no TFF type specification, and the function/defun declares
    # parameters. Create a polymorphic template.
    def _wrap_polymorphic(wrapper_fn, fn, parameter_type, name=fn_name):
      return wrapper_fn(fn, parameter_type, unpack=True, name=name)

    polymorphic_fn = function_utils.PolymorphicFunction(
        lambda pt: _wrap_polymorphic(wrapper_fn, fn, pt))

    # When applying a decorator, the __doc__ attribute with the documentation
    # in triple-quotes is not automatically transferred from the function on
    # which it was applied to the wrapped object, so we must transfer it here
    # explicitly.
    polymorphic_fn.__doc__ = getattr(fn, '__doc__', None)
    return polymorphic_fn

  # Either we have a concrete parameter type, or this is no-arg function.
  # `name` is forwarded here as well, matching the polymorphic path, so the
  # wrapped function's name is available for debugging.
  concrete_fn = wrapper_fn(fn, parameter_type, unpack=None, name=fn_name)
  py_typecheck.check_type(concrete_fn, function_utils.ConcreteFunction,
                          'value returned by the wrapper')
  if not type_utils.are_equivalent_types(
      concrete_fn.type_signature.parameter, parameter_type):
    raise TypeError(
        'Expected a concrete function that takes parameter {}, got one '
        'that takes {}.'.format(
            str(parameter_type), str(concrete_fn.type_signature.parameter)))
  # When applying a decorator, the __doc__ attribute with the documentation
  # in triple-quotes is not automatically transferred from the function on
  # which it was applied to the wrapped object, so we must transfer it here
  # explicitly.
  concrete_fn.__doc__ = getattr(fn, '__doc__', None)
  return concrete_fn