def _from_definition(fdef, grad_func=None):
  """Builds a _DefinedFunction directly from an existing FunctionDef proto.

  Args:
    fdef: a FunctionDef
    grad_func: a _DefinedFunction or None

  Returns:
    A _DefinedFunction representing fdef
  """
  # TODO(iga): This method does major surgery on _DefinedFunction.
  # Make it a named constructor using @classmethod of _DefinedFunction.

  signature = fdef.signature

  # Collect the argument names and dtypes in one pass over the inputs.
  arg_names = []
  arg_dtypes = []
  for arg in signature.input_arg:
    arg_names.append(arg.name)
    arg_dtypes.append(dtypes.as_dtype(arg.type))
  output_names = [out.name for out in signature.output_arg]

  result = _DefinedFunction(
      # A Python callable is only required to *produce* a FunctionDef. We
      # already hold the FunctionDef (and have no callable available here),
      # so _DefinedFunction._func is left unset.
      None,
      arg_names,
      tuple(arg_dtypes),
      signature.name,
      grad_func,
      # Note: FunctionDefs do not carry python gradient functions, so one
      # attached to the original _DefinedFunction is not reflected here.
      None,
      output_names)

  # pylint: disable=protected-access
  imported = c_api.TF_FunctionImportFunctionDef(fdef.SerializeToString())
  result._c_func = c_api_util.ScopedTFFunction(imported)
  result._extra_inputs = []
  # pylint: enable=protected-access

  return result
def _create_new_tf_function(func_graph):
  """Turns func_graph into a TF_Function and adds it to the outer graph.

  Args:
    func_graph: function._FuncGraph

  Returns:
    The name of the new TF_Function.
  """
  # pylint: disable=protected-access
  input_handles = [tensor._as_tf_output() for tensor in func_graph.inputs]
  output_handles = [tensor._as_tf_output() for tensor in func_graph.outputs]
  c_func = c_api.TF_GraphToFunction_wrapper(
      func_graph._c_graph,
      compat.as_str(func_graph.name),
      False,  # append_hash_to_fn_name
      None,  # opers
      input_handles,
      output_handles,
      [],
      None,  # opts
      None)  # description
  # The wrapper is never read again; it is bound to a local so that it stays
  # referenced until this function returns.
  scoped_c_func = c_api_util.ScopedTFFunction(c_func)

  # TODO(b/109833212): this round trip is wasteful: the TF_Function* is
  # serialized, deserialized into a Python FunctionDef, and then reserialized
  # to create a new TF_Function that is added to the graph.
  function_def = _function.function_def_from_tf_function(c_func)
  defined_func = _function._from_definition(function_def)
  defined_func._sub_functions = func_graph._functions
  defined_func.add_to_graph(func_graph._outer_graph)
  # pylint: enable=protected-access

  return func_graph.name
def __init__(self, name, graph, operations, inputs, outputs, attrs):
  """Initializes an eager defined function.

  Builds a TF_Function from the given subset of `graph`, attaches `attrs` to
  it, reads its FunctionDef back, and registers the function when executing
  eagerly.

  Args:
    name: str, the name for the created function.
    graph: Graph, the graph containing the operations in the function
    operations: list of Operation; the subset of operations in the graph
      which will be in the function
    inputs: the tensors in the graph to be used as inputs to the function
    outputs: the tensors in the graph which will be outputs to the function
    attrs: dict mapping names of attributes to their AttrValue values
  """
  fn = pywrap_tensorflow.TF_GraphToFunction_wrapper(
      graph._c_graph,  # pylint: disable=protected-access
      compat.as_str(name),
      False,
      [o._c_op for o in operations],  # pylint: disable=protected-access
      [t._as_tf_output() for t in inputs],  # pylint: disable=protected-access
      [t._as_tf_output() for t in outputs],  # pylint: disable=protected-access
      [],
      None,
      compat.as_str(""))
  # Wrap the raw TF_Function immediately rather than at the end of __init__,
  # so the handle already has a scoped owner if any step below (setting
  # attrs, parsing the proto, registering) raises.
  self._c_func = c_api_util.ScopedTFFunction(fn)

  # `attr_name` (not `name`) so the loop does not clobber the `name` argument.
  for attr_name, attr_value in attrs.items():
    serialized = attr_value.SerializeToString()
    # TODO(iga): this creates and deletes a new TF_Status for every attr.
    # It might be worth creating a convenient way to re-use status.
    pywrap_tensorflow.TF_FunctionSetAttrValueProto(
        fn, compat.as_str(attr_name), serialized)

  # TODO(apassos) avoid creating a FunctionDef (specially to grab the
  # signature, but also in general it's nice not to depend on it.
  with c_api_util.tf_buffer() as buffer_:
    pywrap_tensorflow.TF_FunctionToFunctionDef(fn, buffer_)
    proto_data = pywrap_tensorflow.TF_GetBuffer(buffer_)
  function_def = function_pb2.FunctionDef()
  function_def.ParseFromString(compat.as_bytes(proto_data))

  if context.executing_eagerly():
    _register(fn)

  self.definition = function_def
  # The name and signature are read back from the generated FunctionDef
  # rather than taken from the `name` argument.
  self.name = function_def.signature.name
  self.signature = function_def.signature
  self.grad_func_name = None
  self.python_grad_func = None
  self._grad_func = None
def __init__(self, name, graph, operations, inputs, outputs):
  """Initializes an eager defined function.

  Builds a TF_Function from the given subset of `graph`, reads its
  FunctionDef back, and registers the function when executing eagerly.

  Args:
    name: str, the name for the created function.
    graph: Graph, the graph containing the operations in the function
    operations: list of Operation; the subset of operations in the graph
      which will be in the function
    inputs: the tensors in the graph to be used as inputs to the function
    outputs: the tensors in the graph which will be outputs to the function
  """
  fn = pywrap_tensorflow.TF_GraphToFunction_wrapper(
      graph._c_graph,  # pylint: disable=protected-access
      compat.as_str(name),
      False,
      [o._c_op for o in operations],  # pylint: disable=protected-access
      [t._as_tf_output() for t in inputs],  # pylint: disable=protected-access
      [t._as_tf_output() for t in outputs],  # pylint: disable=protected-access
      [],
      None,
      compat.as_str(""))
  # Wrap the raw TF_Function immediately rather than at the end of __init__,
  # so the handle already has a scoped owner if any step below (reading the
  # FunctionDef, parsing the proto, registering) raises.
  self._c_func = c_api_util.ScopedTFFunction(fn)

  # TODO(apassos) avoid creating a FunctionDef (specially to grab the
  # signature, but also in general it's nice not to depend on it.
  with c_api_util.tf_buffer() as buffer_:
    pywrap_tensorflow.TF_FunctionToFunctionDef(fn, buffer_)
    proto_data = pywrap_tensorflow.TF_GetBuffer(buffer_)
  function_def = function_pb2.FunctionDef()
  function_def.ParseFromString(compat.as_bytes(proto_data))

  if context.executing_eagerly():
    _register(fn)

  self.definition = function_def
  # The name and signature are read back from the generated FunctionDef
  # rather than taken from the `name` argument.
  self.name = function_def.signature.name
  self.signature = function_def.signature
  self.grad_func_name = None
  self.python_grad_func = None
  self._grad_func = None
def _create_new_tf_function(func_graph):
  """Turns func_graph into a TF_Function and copies it to the default graph.

  Note that `func_graph.name` is modified in place: an underscore is appended
  before the function is created, and the modified name is what is returned.

  Args:
    func_graph: function._FuncGraph

  Returns:
    The name of the new TF_Function.
  """
  func_graph.name = "%s_" % func_graph.name

  # pylint: disable=protected-access
  input_handles = [tensor._as_tf_output() for tensor in func_graph.inputs]
  output_handles = [tensor._as_tf_output() for tensor in func_graph.outputs]
  raw_func = c_api.TF_GraphToFunction_wrapper(
      func_graph._c_graph,
      func_graph.name,
      False,  # append_hash_to_fn_name
      None,  # opers
      input_handles,
      output_handles,
      [],
      None,  # opts
      None)  # description
  scoped_func = c_api_util.ScopedTFFunction(raw_func)

  default_graph = ops.get_default_graph()
  c_api.TF_GraphCopyFunction(default_graph._c_graph, scoped_func.func, None)
  # pylint: enable=protected-access

  return func_graph.name
def _create_definition_if_needed_impl(self):
  """This is not what you want, see _create_definition_if_needed."""
  # Nothing to do once either representation of the function exists.
  already_defined = not (self._definition is None and self._c_func is None)
  if already_defined:
    return

  # Share the variable collections (by reference) with the parent graph so
  # that name based variable sharing (e.g. via tf.make_template) works between
  # the func graph and the parent graph.
  # pylint: disable=protected-access
  variable_keys = list(ops.GraphKeys._VARIABLE_COLLECTIONS)
  variable_keys.append(vs._VARSTORE_KEY)
  # pylint: enable=protected-access

  parent_graph = ops.get_default_graph()
  collections_ref = {}
  for key in variable_keys:
    collections_ref[key] = parent_graph.get_collection_ref(key)

  temp_graph = func_graph_from_py_func(
      self._func,
      self._arg_names,
      self._arg_types,
      self._func_name,
      self._capture_by_value,
      self._caller_device,
      collections_ref=collections_ref,
      allowlisted_stateful_ops=self._allowlisted_stateful_ops,
      capture_resource_var_by_value=self._capture_resource_var_by_value)

  self._extra_inputs = temp_graph.extra_inputs
  # pylint: disable=protected-access
  self._sub_functions = temp_graph._functions
  # pylint: enable=protected-access

  # Extra kwargs are treated as attrs on the function def.
  base_func_name = (
      self._func_name or function_utils.get_func_name(self._func))
  if self._grad_func:
    base_func_name += ("_%s" % self._grad_func.name)
  kwargs_attr = _parse_kwargs_as_attrs(base_func_name, **self._extra_kwargs)

  if temp_graph._c_graph:  # pylint: disable=protected-access
    # C API is enabled
    if self._out_names:
      output_names = [compat.as_bytes(x) for x in self._out_names]
    else:
      output_names = []
    description = self._func.__doc__ or None
    # pylint: disable=protected-access
    c_func = c_api.TF_GraphToFunction_wrapper(
        temp_graph._c_graph,
        base_func_name,
        self._func_name is None,  # append_hash_to_fn_name
        None,  # opers
        [t._as_tf_output() for t in temp_graph.inputs],
        [t._as_tf_output() for t in temp_graph.outputs],
        output_names,
        [],  # control_outputs
        [],  # control_output_names
        None,  # opts
        description)
    self._c_func = c_api_util.ScopedTFFunction(c_func)
    # pylint: enable=protected-access
    self._set_c_attrs(kwargs_attr)

    # Set cached fields: _op_def and _func_name (if not already set)
    self._op_def = self.definition.signature
    if self._func_name:
      assert self._func_name == self._op_def.name
    else:
      self._func_name = compat.as_str(self._op_def.name)
  else:
    # No C graph: build the FunctionDef in Python.
    self._definition = graph_to_function_def.graph_to_function_def(
        temp_graph,
        temp_graph.get_operations(),
        temp_graph.inputs,
        temp_graph.outputs,
        out_names=self._out_names)

    for attr_key in kwargs_attr:
      self._definition.attr[attr_key].CopyFrom(kwargs_attr[attr_key])

    # Hash the definition and its dependencies.
    signature = self._definition.signature
    self._hash_str = self._create_hash_str(
        signature.input_arg, signature.output_arg, self._definition.node_def)

    # Finally, we decide the function name to use. If not specified,
    # make up something which is almost certainly unique (but deterministic).
    if not self._func_name:
      self._func_name = "_".join([base_func_name, self._hash_str])
    signature.name = self._func_name
    if self._func.__doc__:
      signature.description = self._func.__doc__

    self._op_def = signature

  # Record (name, type) for every stateful op traced into the function graph.
  stateful_ops = []
  for op in temp_graph.get_operations():
    if op._is_stateful:  # pylint: disable=protected-access
      stateful_ops.append((op.name, op.type))
  self._stateful_ops = stateful_ops
def _create_definition_if_needed_impl(self):
  """This is not what you want, see _create_definition_if_needed."""
  # Nothing to do once either representation of the function exists.
  already_defined = not (self._definition is None and self._c_func is None)
  if already_defined:
    return

  temp_graph = func_graph_from_py_func(
      self._func,
      self._arg_names,
      self._arg_types,
      self._func_name,
      self._capture_by_value,
      self._caller_device,
      whitelisted_stateful_ops=self._whitelisted_stateful_ops)

  self._extra_inputs = temp_graph.extra_inputs
  # pylint: disable=protected-access
  self._sub_functions = temp_graph._functions
  # pylint: enable=protected-access

  # Extra kwargs are treated as attrs on the function def.
  base_func_name = (
      self._func_name or function_utils.get_func_name(self._func))
  if self._grad_func:
    base_func_name += ("_%s" % self._grad_func.name)
  kwargs_attr = _parse_kwargs_as_attrs(base_func_name, **self._extra_kwargs)

  if temp_graph._c_graph:  # pylint: disable=protected-access
    # C API is enabled
    if self._out_names:
      output_names = [compat.as_bytes(x) for x in self._out_names]
    else:
      output_names = []
    description = self._func.__doc__ or None
    # pylint: disable=protected-access
    c_func = c_api.TF_GraphToFunction_wrapper(
        temp_graph._c_graph,
        base_func_name,
        self._func_name is None,  # append_hash_to_fn_name
        None,  # opers
        [t._as_tf_output() for t in temp_graph.inputs],
        [t._as_tf_output() for t in temp_graph.outputs],
        output_names,
        None,  # opts
        description)
    self._c_func = c_api_util.ScopedTFFunction(c_func)
    # pylint: enable=protected-access
    self._set_c_attrs(kwargs_attr)

    # Set cached fields: _op_def and _func_name (if not already set)
    self._op_def = self.definition.signature
    if self._func_name:
      assert self._func_name == self._op_def.name
    else:
      self._func_name = compat.as_str(self._op_def.name)
  else:
    # No C graph: build the FunctionDef in Python.
    self._definition = graph_to_function_def.graph_to_function_def(
        temp_graph,
        temp_graph.get_operations(),
        temp_graph.inputs,
        temp_graph.outputs,
        out_names=self._out_names)

    for attr_key in kwargs_attr:
      self._definition.attr[attr_key].CopyFrom(kwargs_attr[attr_key])

    # Hash the definition and its dependencies.
    signature = self._definition.signature
    self._hash_str = self._create_hash_str(
        signature.input_arg, signature.output_arg, self._definition.node_def)

    # Finally, we decide the function name to use. If not specified,
    # make up something which is almost certainly unique (but deterministic).
    if not self._func_name:
      self._func_name = "_".join([base_func_name, self._hash_str])
    signature.name = self._func_name
    if self._func.__doc__:
      signature.description = self._func.__doc__

    self._op_def = signature

  # Record (name, type) for every op whose op_def is marked stateful.
  stateful_ops = []
  for op in temp_graph.get_operations():
    if op.op_def.is_stateful:
      stateful_ops.append((op.name, op.type))
  self._stateful_ops = stateful_ops
def _create_definition_if_needed_impl(self):
  """This is not what you want, see _create_definition_if_needed."""
  # Nothing to do once either representation of the function exists.
  already_defined = not (self._definition is None and self._c_func is None)
  if already_defined:
    return

  # Create the func_def object.
  temp_graph = _FuncGraph(capture_by_value=self._capture_by_value)
  with temp_graph.as_default():
    # One placeholder per declared argument of the function_def.
    inputs = [
        array_ops.placeholder(argtype, name=argname)
        for (argname, argtype) in self._args
    ]

    # Call func and gather the output tensors.
    with vs.variable_scope("", custom_getter=temp_graph.getvar):
      outputs = self._func(*inputs)

    # Python cannot tell "returned nothing" apart from "returned None".
    # The former must be allowed; the latter is most likely user error and
    # would ideally be rejected.
    # TODO(iga): Consider adding a @NoOutput decorator on top of @Defun to
    # allow users to explicitly mark the function as not returning anything.
    # For now, a single None return is accepted and interpreted as a function
    # with no output.
    if outputs is None:
      outputs = []
    else:
      # If func only returned one value, make it a tuple.
      if not isinstance(outputs, (list, tuple)):
        outputs = (outputs,)
      if any(out is None for out in outputs):
        raise ValueError("Function can not return None.")

    # Ensures each output is a Tensor in the function graph. Conversion and
    # capture are kept as two separate passes over the outputs.
    converted = []
    for out in outputs:
      converted.append(ops.convert_to_tensor(out))
    outputs = []
    for tensor in converted:
      if tensor.graph is temp_graph:
        outputs.append(tensor)
      else:
        outputs.append(temp_graph.capture(tensor))

  self._extra_inputs = temp_graph.extra_inputs
  inputs.extend(temp_graph.extra_args)
  # pylint: disable=protected-access
  self._sub_functions = temp_graph._functions
  # pylint: enable=protected-access

  # Extra kwargs are treated as attrs on the function def.
  base_func_name = self._func_name or _get_func_name(self._func)
  if self._grad_func:
    base_func_name += ("_%s" % self._grad_func.name)
  kwargs_attr = _parse_kwargs_as_attrs(base_func_name, **self._extra_kwargs)

  if temp_graph._c_graph:  # pylint: disable=protected-access
    # C API is enabled
    if self._out_names:
      output_names = [compat.as_bytes(x) for x in self._out_names]
    else:
      output_names = []
    description = self._func.__doc__ or None
    # pylint: disable=protected-access
    c_func = c_api.TF_GraphToFunction_wrapper(
        temp_graph._c_graph,
        base_func_name,
        self._func_name is None,  # append_hash_to_fn_name
        None,  # opers
        [t._as_tf_output() for t in inputs],
        [t._as_tf_output() for t in outputs],
        output_names,
        None,  # opts
        description)
    self._c_func = c_api_util.ScopedTFFunction(c_func)
    # pylint: enable=protected-access
    self._set_c_attrs(kwargs_attr)

    # Set cached fields: _op_def and _func_name (if not already set)
    self._op_def = self.definition.signature
    if self._func_name:
      assert self._func_name == self._op_def.name
    else:
      self._func_name = compat.as_str(self._op_def.name)
  else:
    # No C graph: build the FunctionDef in Python.
    self._definition = graph_to_function_def.graph_to_function_def(
        temp_graph,
        temp_graph.get_operations(),
        inputs,
        outputs,
        out_names=self._out_names)

    for attr_key in kwargs_attr:
      self._definition.attr[attr_key].CopyFrom(kwargs_attr[attr_key])

    # Hash the definition and its dependencies.
    signature = self._definition.signature
    self._hash_str = self._create_hash_str(
        signature.input_arg, signature.output_arg, self._definition.node_def)

    # Finally, we decide the function name to use. If not specified,
    # make up something which is almost certainly unique (but deterministic).
    if not self._func_name:
      self._func_name = "_".join([base_func_name, self._hash_str])
    signature.name = self._func_name
    if self._func.__doc__:
      signature.description = self._func.__doc__

    self._op_def = signature