def _checked_stamp_parameter(self, name, spec, graph=None):
  """Returns object stamped in the graph after verifying its bindings.

  Args:
    name: The name to use for the parameter stamped into the graph.
    spec: The type specification for the parameter; anything convertible
      to a type by `computation_types.to_type`.
    graph: Optional `tf.Graph` to stamp the parameter into. Defaults to
      the current default graph.

  Returns:
    The value object stamped into the graph by
    `graph_utils.stamp_parameter_in_graph`.
  """
  if graph is None:
    # Use the compat.v1 namespace for consistency with the rest of this
    # module (e.g. `tf.compat.v1.global_variables`); the bare
    # `tf.get_default_graph` is absent from the TF2 API surface.
    graph = tf.compat.v1.get_default_graph()
  val, binding = graph_utils.stamp_parameter_in_graph(name, spec, graph)
  # Verify the resulting binding is consistent with the declared type
  # before handing the stamped value back to the caller.
  self._assert_binding_matches_type_and_value(
      binding, computation_types.to_type(spec), val, graph)
  return val
def serialize_py_fn_as_tf_computation(target, parameter_type, context_stack):
  """Serializes the 'target' as a TF computation with a given parameter type.

  See also `serialize_tf2_as_tf_computation` for TensorFlow 2
  serialization.

  Args:
    target: The entity to convert into and serialize as a TF computation. This
      can currently only be a Python function. In the future, we will add here
      support for serializing the various kinds of non-eager and eager
      functions, and eventually aim at full support for and compliance with TF
      2.0. This function is currently required to declare either zero
      parameters if `parameter_type` is `None`, or exactly one parameter if
      it's not `None`.  The nested structure of this parameter must correspond
      to the structure of the 'parameter_type'. In the future, we may support
      targets with multiple args/keyword args (to be documented in the API and
      referenced from here).
    parameter_type: The parameter type specification if the target accepts a
      parameter, or `None` if the target doesn't declare any parameters. Either
      an instance of `types.Type`, or something that's convertible to it by
      `types.to_type()`.
    context_stack: The context stack to use.

  Returns:
    A tuple of (`pb.Computation`, `tff.Type`), where the computation contains
    the instance with the `pb.TensorFlow` variant set, and the type is an
    instance of `tff.Type`, potentially including Python container
    annotations, for use by TensorFlow computation wrappers.

  Raises:
    TypeError: If the arguments are of the wrong types.
    ValueError: If the signature of the target is not compatible with the
      given parameter type.
  """
  # TODO(b/113112108): Support a greater variety of target type signatures,
  # with keyword args or multiple args corresponding to elements of a tuple.
  # Document all accepted forms with examples in the API, and point to there
  # from here.

  py_typecheck.check_type(target, types.FunctionType)
  py_typecheck.check_type(context_stack, context_stack_base.ContextStack)
  parameter_type = computation_types.to_type(parameter_type)
  # `inspect.getargspec` was deprecated and is removed in Python 3.11;
  # `getfullargspec` exposes the same `.args` attribute for plain functions.
  argspec = inspect.getfullargspec(target)

  with tf.Graph().as_default() as graph:
    args = []
    if parameter_type is not None:
      # A non-None parameter type requires the target to take exactly one
      # positional argument, which we stamp into the graph as a placeholder
      # structure matching `parameter_type`.
      if len(argspec.args) != 1:
        raise ValueError(
            'Expected the target to declare exactly one parameter, '
            'found {}.'.format(repr(argspec.args)))
      parameter_name = argspec.args[0]
      parameter_value, parameter_binding = graph_utils.stamp_parameter_in_graph(
          parameter_name, parameter_type, graph)
      args.append(parameter_value)
    else:
      if argspec.args:
        raise ValueError(
            'Expected the target to declare no parameters, found {}.'.format(
                repr(argspec.args)))
      parameter_binding = None
    context = tf_computation_context.TensorFlowComputationContext(graph)
    with context_stack.install(context):
      result = target(*args)

    # TODO(b/122081673): This needs to change for TF 2.0. We may also
    # want to allow the person creating a tff.tf_computation to specify
    # a different initializer; e.g., if it is known that certain
    # variables will be assigned immediately to arguments of the function,
    # then it is wasteful to initialize them before this.
    #
    # The following is a bit of a work around: the collections below may
    # contain variables more than once, hence we throw into a set. TFF needs
    # to ensure all variables are initialized, but not all variables are
    # always in the collections we expect. tff.learning._KerasModel tries to
    # pull Keras variables (that may or may not be in GLOBAL_VARIABLES) into
    # TFF_MODEL_VARIABLES for now.
    all_variables = set(
        tf.compat.v1.global_variables() + tf.compat.v1.local_variables() +
        tf.compat.v1.get_collection(
            graph_keys.GraphKeys.VARS_FOR_TFF_TO_INITIALIZE))
    if all_variables:
      # Use a readable but not-too-long name for the init_op.
      name = 'init_op_for_' + '_'.join(
          [v.name.replace(':0', '') for v in all_variables])
      if len(name) > 50:
        name = 'init_op_for_{}_variables'.format(len(all_variables))
      with tf.control_dependencies(context.init_ops):
        # Before running the main new init op, run any initializers for sub-
        # computations from context.init_ops. Variables from import_graph_def
        # will not make it into the global collections, and so will not be
        # initialized without this code path.
        init_op_name = tf.compat.v1.initializers.variables(
            all_variables, name=name).name
    elif context.init_ops:
      init_op_name = tf.group(
          *context.init_ops, name='subcomputation_init_ops').name
    else:
      init_op_name = None

    result_type, result_binding = graph_utils.capture_result_from_graph(
        result, graph)

  # The annotated type retains any Python container annotations from the
  # captured result for use by TensorFlow computation wrappers.
  annotated_type = computation_types.FunctionType(parameter_type, result_type)

  return pb.Computation(
      type=pb.Type(
          function=pb.FunctionType(
              parameter=type_serialization.serialize_type(parameter_type),
              result=type_serialization.serialize_type(result_type))),
      tensorflow=pb.TensorFlow(
          graph_def=serialization_utils.pack_graph_def(graph.as_graph_def()),
          parameter=parameter_binding,
          result=result_binding,
          initialize_op=init_op_name)), annotated_type
def pad_graph_inputs_to_match_type(comp, type_signature):
  r"""Pads the parameter bindings of `comp` to match `type_signature`.

  The padded parameters here are in effect dummy bindings--they are not
  plugged in elsewhere in `comp`. This pattern is necessary to transform TFF
  expressions of the form:

                            Lambda(arg)
                                |
                              Call
                             /     \
          CompiledComputation       Tuple
                                      |
                                  Selection[i]
                                      |
                                    Ref(arg)

  into the form:

                          CompiledComputation

  in the case where arg in the above picture represents an n-tuple, where
  n > 1.

  Notice that some type manipulation must take place to execute the
  transformation outlined above, or anything similar to it, since the Lambda
  we are looking to replace accepts a parameter of an n-tuple, whereas the
  `CompiledComputation` represented above accepts only a 1-tuple.
  `pad_graph_inputs_to_match_type` is intended as an intermediate transform in
  the transformation outlined above, since there may also need to be some
  parameter permutation via `permute_graph_inputs`.

  Notice also that the existing parameter bindings of `comp` must match the
  first elements of `type_signature`. This is to ensure that we are attempting
  to pad only compatible `CompiledComputation`s to a given type signature.

  Args:
    comp: Instance of `computation_building_blocks.CompiledComputation`
      representing the graph whose inputs we want to pad to match
      `type_signature`.
    type_signature: Instance of `computation_types.NamedTupleType` representing
      the type signature we wish to pad `comp` to accept as a parameter.

  Returns:
    A transformed version of `comp`, instance of
    `computation_building_blocks.CompiledComputation` which takes an argument
    of type `type_signature` and executes the same logic as `comp`. In
    particular, this transformed version will have the same return type as
    the original `comp`.

  Raises:
    TypeError: If the proto underlying `comp` has a parameter type which
      is not of `NamedTupleType`, the `type_signature` argument is not of type
      `NamedTupleType`, or there is a type mismatch between the declared
      parameters of `comp` and the requested `type_signature`.
    ValueError: If the requested `type_signature` is shorter than the
      parameter type signature declared by `comp`.
  """
  py_typecheck.check_type(type_signature, computation_types.NamedTupleType)
  py_typecheck.check_type(comp,
                          computation_building_blocks.CompiledComputation)
  # Pull the serialized pieces of the computation apart: the graph itself,
  # its existing parameter binding, and its declared TFF type.
  proto = comp.proto
  graph_def = proto.tensorflow.graph_def
  graph_parameter_binding = proto.tensorflow.parameter
  proto_type = type_serialization.deserialize_type(proto.type)
  binding_oneof = graph_parameter_binding.WhichOneof('binding')
  # Padding only makes sense for tuple-bound parameters; a tensor or
  # sequence binding has no element list to extend.
  if binding_oneof != 'tuple':
    raise TypeError(
        'Can only pad inputs of a CompiledComputation with parameter type '
        'tuple; you have attempted to pad a CompiledComputation '
        'with parameter type {}'.format(binding_oneof))
  # This line provides protection against an improperly serialized proto
  py_typecheck.check_type(proto_type.parameter,
                          computation_types.NamedTupleType)
  parameter_bindings = [x for x in graph_parameter_binding.tuple.element]
  parameter_type_elements = anonymous_tuple.to_elements(proto_type.parameter)
  type_signature_elements = anonymous_tuple.to_elements(type_signature)
  # We can only add dummy bindings, never drop existing ones, so the target
  # signature must be at least as long as the existing parameter list.
  if len(parameter_bindings) > len(type_signature):
    raise ValueError('We can only pad graph input bindings, never mask them. '
                     'This means that a proposed type signature passed to '
                     '`pad_graph_inputs_to_match_type` must have more elements '
                     'than the existing type signature of the compiled '
                     'computation. You have proposed a type signature of '
                     'length {} be assigned to a computation with parameter '
                     'type signature of length {}.'.format(
                         len(type_signature), len(parameter_bindings)))
  # The declared parameter types must be a prefix of the proposed signature;
  # otherwise the padded computation would disagree with its callers.
  if any(x != type_signature_elements[idx]
         for idx, x in enumerate(parameter_type_elements)):
    raise TypeError(
        'The existing elements of the parameter type signature '
        'of the compiled computation in `pad_graph_inputs_to_match_type` '
        'must match the beginning of the proposed new type signature; '
        'you have proposed a parameter type of {} for a computation '
        'with existing parameter type {}.'.format(type_signature,
                                                  proto_type.parameter))
  # Re-import the serialized graph into a fresh graph so we can stamp the
  # extra (dummy) parameters alongside the existing ops.
  g = tf.Graph()
  with g.as_default():
    tf.graph_util.import_graph_def(
        serialization_utils.unpack_graph_def(graph_def), name='')
  # Only the trailing elements of `type_signature`, beyond the existing
  # bindings, need new placeholders stamped in.
  elems_to_stamp = anonymous_tuple.to_elements(
      type_signature)[len(parameter_bindings):]
  for name, type_spec in elems_to_stamp:
    if name is None:
      # Unnamed tuple elements still need some op-name for stamping;
      # 'name' is an arbitrary placeholder label, not a meaningful binding.
      stamp_name = 'name'
    else:
      stamp_name = name
    _, stamped_binding = graph_utils.stamp_parameter_in_graph(
        stamp_name, type_spec, g)
    parameter_bindings.append(stamped_binding)
    parameter_type_elements.append((name, type_spec))
  # Reassemble the proto: new tuple binding over old + dummy elements, the
  # re-serialized graph, and a function type widened to `type_signature`.
  # The result binding and initialize_op are carried over unchanged.
  new_parameter_binding = pb.TensorFlow.Binding(
      tuple=pb.TensorFlow.NamedTupleBinding(element=parameter_bindings))
  new_graph_def = g.as_graph_def()
  new_function_type = computation_types.FunctionType(parameter_type_elements,
                                                     proto_type.result)
  serialized_type = type_serialization.serialize_type(new_function_type)
  input_padded_proto = pb.Computation(
      type=serialized_type,
      tensorflow=pb.TensorFlow(
          graph_def=serialization_utils.pack_graph_def(new_graph_def),
          initialize_op=proto.tensorflow.initialize_op,
          parameter=new_parameter_binding,
          result=proto.tensorflow.result))
  return computation_building_blocks.CompiledComputation(input_padded_proto)