def _create_new_tf_function(func_graph):
  """Turns `func_graph` into a TF_Function and adds it to its outer graph.

  The TF_Function is built from the graph's inputs and outputs, converted to
  a Python-side function object, and that object is added to
  `func_graph._outer_graph`.

  Args:
    func_graph: function._FuncGraph

  Returns:
    The name of the new TF_Function, i.e. `func_graph.name`.
  """
  # pylint: disable=protected-access
  graph_handle = func_graph._c_graph
  fn_name = compat.as_str(func_graph.name)
  arg_handles = [tensor._as_tf_output() for tensor in func_graph.inputs]
  ret_handles = [tensor._as_tf_output() for tensor in func_graph.outputs]
  c_func = c_api.TF_GraphToFunction_wrapper(
      graph_handle,
      fn_name,
      False,  # append_hash_to_fn_name
      None,  # opers
      arg_handles,
      ret_handles,
      [],
      None,  # opts
      None)  # description
  # Kept in a local for the rest of this call. Presumably the wrapper takes
  # care of releasing `c_func` -- confirm in c_api_util.ScopedTFFunction.
  scoped_c_func = c_api_util.ScopedTFFunction(c_func)  # pylint: disable=unused-variable

  # TODO(b/109833212): this round trip is wasteful. The TF_Function* is
  # serialized, deserialized into a Python FunctionDef, and then reserialized
  # to create a new TF_Function that we add to the graph.
  defined_func = _function._from_definition(
      _function.function_def_from_tf_function(c_func))
  defined_func._sub_functions = func_graph._functions
  defined_func.add_to_graph(func_graph._outer_graph)
  # pylint: enable=protected-access

  return func_graph.name
def make_function_def(name, graph, operations, inputs, outputs):
  """Creates a TF_Function from `graph` along with its FunctionDef proto.

  Args:
    name: the function name
    graph: the graph from which to build the function
    operations: the operations in the function body
    inputs: tensors to be used as function arguments
    outputs: tensors to be returned from the function

  Returns:
    A `(fdef, fn)` pair:
      fdef: a FunctionDef protocol buffer parsed from the serialized `fn`
      fn: the wrapped TF_Function returned by TF_GraphToFunction_wrapper
  """
  # pylint: disable=protected-access
  with errors.raise_exception_on_not_ok_status() as status:
    graph_handle = graph._c_graph
    fn_name = compat.as_str(name)
    op_handles = [operation._c_op for operation in operations]
    arg_handles = [tensor._as_tf_output() for tensor in inputs]
    ret_handles = [tensor._as_tf_output() for tensor in outputs]
    fn = pywrap_tensorflow.TF_GraphToFunction_wrapper(
        graph_handle, fn_name, False, op_handles, arg_handles, ret_handles,
        [], None, compat.as_str(""), status)
  # pylint: enable=protected-access

  # TODO(apassos) avoid creating a FunctionDef (specially to grab the
  # signature), but also in general it's nice not to depend on it.
  with c_api_util.tf_buffer() as serialized_buffer:
    with errors.raise_exception_on_not_ok_status() as status:
      pywrap_tensorflow.TF_FunctionToFunctionDef(fn, serialized_buffer, status)
    # Read the bytes out while the buffer context is still open.
    serialized_fdef = pywrap_tensorflow.TF_GetBuffer(serialized_buffer)

  fdef = function_pb2.FunctionDef()
  fdef.ParseFromString(compat.as_bytes(serialized_fdef))
  return fdef, fn
def __init__(self, name, graph, operations, inputs, outputs, attrs):
  """Builds the TF_Function for this object and caches its FunctionDef.

  Args:
    name: str, the name for the created function.
    graph: Graph, the graph containing the operations in the function
    operations: list of Operation; the subset of operations in the graph
      which will be in the function
    inputs: the tensors in the graph to be used as inputs to the function
    outputs: the tensors in the graph which will be outputs to the function
    attrs: dict mapping names of attributes to their AttrValue values
  """
  # pylint: disable=protected-access
  graph_handle = graph._c_graph
  fn_name = compat.as_str(name)
  op_handles = [operation._c_op for operation in operations]
  arg_handles = [tensor._as_tf_output() for tensor in inputs]
  ret_handles = [tensor._as_tf_output() for tensor in outputs]
  # pylint: enable=protected-access
  fn = pywrap_tensorflow.TF_GraphToFunction_wrapper(
      graph_handle, fn_name, False, op_handles, arg_handles, ret_handles,
      [], None, compat.as_str(""))

  # Attach every requested attribute to the freshly created function.
  for attr_name, attr_value in attrs.items():
    serialized_attr = attr_value.SerializeToString()
    # TODO(iga): this creates and deletes a new TF_Status for every attr.
    # It might be worth creating a convenient way to re-use status.
    pywrap_tensorflow.TF_FunctionSetAttrValueProto(
        fn, compat.as_str(attr_name), serialized_attr)

  # TODO(apassos) avoid creating a FunctionDef (specially to grab the
  # signature), but also in general it's nice not to depend on it.
  with c_api_util.tf_buffer() as serialized_buffer:
    pywrap_tensorflow.TF_FunctionToFunctionDef(fn, serialized_buffer)
    serialized_fdef = pywrap_tensorflow.TF_GetBuffer(serialized_buffer)
  function_def = function_pb2.FunctionDef()
  function_def.ParseFromString(compat.as_bytes(serialized_fdef))

  if context.executing_eagerly():
    _register(fn)

  # Name and signature come from the parsed definition, not from `name`.
  self.definition = function_def
  self.name = function_def.signature.name
  self.signature = function_def.signature
  self.grad_func_name = None
  self.python_grad_func = None
  self._c_func = c_api_util.ScopedTFFunction(fn)
  self._grad_func = None
def __init__(self, name, graph, operations, inputs, outputs):
  """Builds the TF_Function for this object and caches its FunctionDef.

  Args:
    name: str, the name for the created function.
    graph: Graph, the graph containing the operations in the function
    operations: list of Operation; the subset of operations in the graph
      which will be in the function
    inputs: the tensors in the graph to be used as inputs to the function
    outputs: the tensors in the graph which will be outputs to the function
  """
  # pylint: disable=protected-access
  with errors.raise_exception_on_not_ok_status() as status:
    graph_handle = graph._c_graph
    fn_name = compat.as_str(name)
    op_handles = [operation._c_op for operation in operations]
    arg_handles = [tensor._as_tf_output() for tensor in inputs]
    ret_handles = [tensor._as_tf_output() for tensor in outputs]
    fn = pywrap_tensorflow.TF_GraphToFunction_wrapper(
        graph_handle, fn_name, False, op_handles, arg_handles, ret_handles,
        [], None, compat.as_str(""), status)
  # pylint: enable=protected-access

  # TODO(apassos) avoid creating a FunctionDef (specially to grab the
  # signature), but also in general it's nice not to depend on it.
  with c_api_util.tf_buffer() as serialized_buffer:
    with errors.raise_exception_on_not_ok_status() as status:
      pywrap_tensorflow.TF_FunctionToFunctionDef(fn, serialized_buffer, status)
    # Read the bytes out while the buffer context is still open.
    serialized_fdef = pywrap_tensorflow.TF_GetBuffer(serialized_buffer)
  function_def = function_pb2.FunctionDef()
  function_def.ParseFromString(compat.as_bytes(serialized_fdef))

  if context.in_eager_mode():
    _register(fn)

  # Name and signature come from the parsed definition, not from `name`.
  self.definition = function_def
  self.name = function_def.signature.name
  self.signature = function_def.signature
  self.grad_func_name = None
  self.python_grad_func = None
  self._c_func = fn
  self._grad_func = None
def _create_new_tf_function(func_graph):
  """Turns `func_graph` into a TF_Function and copies it to the default graph.

  Note that `func_graph.name` is modified in place: an underscore is appended
  before the TF_Function is created, and the new name is what gets returned.

  Args:
    func_graph: function._FuncGraph

  Returns:
    The name of the new TF_Function.
  """
  renamed = "%s_" % func_graph.name
  func_graph.name = renamed

  # pylint: disable=protected-access
  graph_handle = func_graph._c_graph
  fn_name = func_graph.name
  arg_handles = [tensor._as_tf_output() for tensor in func_graph.inputs]
  ret_handles = [tensor._as_tf_output() for tensor in func_graph.outputs]
  raw_c_func = c_api.TF_GraphToFunction_wrapper(
      graph_handle,
      fn_name,
      False,  # append_hash_to_fn_name
      None,  # opers
      arg_handles,
      ret_handles,
      [],
      None,  # opts
      None)  # description
  scoped_c_func = c_api_util.ScopedTFFunction(raw_c_func)

  target_graph_handle = ops.get_default_graph()._c_graph
  # pylint: enable=protected-access
  c_api.TF_GraphCopyFunction(target_graph_handle, scoped_c_func.func, None)
  return func_graph.name
def _create_definition_if_needed_impl(self):
  """Traces `self._func` in a temporary graph and stores its definition.

  This is the implementation behind `_create_definition_if_needed`; call that
  instead. Returns immediately when `self._definition` is already set.

  Raises:
    ValueError: if `self._func` returns None, or a list/tuple containing None.
  """
  if self._definition is not None:
    return

  # Trace the Python function inside a scratch graph.
  temp_graph = _FuncGraph(capture_by_value=self._capture_by_value)
  with temp_graph.as_default():
    # One placeholder per declared argument of the function_def.
    inputs = [
        array_ops.placeholder(arg_type, name=arg_name)
        for (arg_name, arg_type) in self._args
    ]
    # Run the user function to obtain its output tensors.
    with vs.variable_scope("", custom_getter=temp_graph.getvar):
      outputs = self._func(*inputs)
    # A lone return value is wrapped so it can be handled like a sequence.
    if not isinstance(outputs, (list, tuple)):
      outputs = (outputs,)
    if any(result is None for result in outputs):
      raise ValueError("Function can not return None.")
    # Make sure every output is a Tensor.
    outputs = [ops.convert_to_tensor(result) for result in outputs]

  self._extra_inputs = temp_graph.extra_inputs
  inputs.extend(temp_graph.extra_args)
  self._sub_functions = temp_graph._functions  # pylint: disable=protected-access

  # Build the FunctionDef.
  self._definition = graph_to_function_def.graph_to_function_def(
      temp_graph, temp_graph.get_operations(), inputs, outputs,
      out_names=self._out_names)

  # Extra kwargs become attrs on the function def.
  if self._func_name:
    name_for_attrs = self._func_name
  else:
    name_for_attrs = _get_func_name(self._func)
  kwargs_attr = _parse_kwargs_as_attrs(name_for_attrs, **self._extra_kwargs)
  for attr_key in kwargs_attr:
    self._definition.attr[attr_key].CopyFrom(kwargs_attr[attr_key])

  # Hash the definition and its dependencies.
  self._hash_str = self._create_hash_str(
      self._definition.signature.input_arg,
      self._definition.signature.output_arg,
      self._definition.node_def)

  # Settle on the function name. When none was given, derive one from the
  # Python function's name plus the hash, so it is almost certainly unique
  # yet still deterministic.
  if not self._func_name:
    self._func_name = "_".join([_get_func_name(self._func), self._hash_str])
  self._definition.signature.name = self._func_name
  docstring = self._func.__doc__
  if docstring:
    self._definition.signature.description = docstring

  # pylint: disable=protected-access
  if not temp_graph._c_graph:
    return

  # The temporary graph has a C graph, so also build the C-level function.
  if self._out_names:
    output_names = [compat.as_bytes(out_name) for out_name in self._out_names]
  else:
    output_names = []
  description = docstring or None
  with errors.raise_exception_on_not_ok_status() as status:
    self._c_func = c_api.TF_GraphToFunction_wrapper(
        temp_graph._c_graph,
        self._func_name,
        False,  # append_hash_to_fn_name
        None,  # opers
        [tensor._as_tf_output() for tensor in inputs],
        [tensor._as_tf_output() for tensor in outputs],
        output_names,
        None,  # opts
        description,
        status)
  # pylint: enable=protected-access
  # NOTE(review): attrs are applied only after the status check above has
  # passed -- confirm this matches the intended ordering.
  self._set_c_attrs(kwargs_attr)
def _create_definition_if_needed_impl(self):
  """Builds the function graph for `self._func` and stores its definition.

  This is the implementation behind `_create_definition_if_needed`; call that
  instead. Returns immediately when either `self._definition` or
  `self._c_func` is already set.

  Depending on whether the traced graph has a C graph, this sets either
  `self._definition` (a FunctionDef) or `self._c_func`. In both cases it also
  sets `self._op_def`, `self._func_name` (if it was unset) and
  `self._stateful_ops`.
  """
  already_built = self._definition is not None or self._c_func is not None
  if already_built:
    return

  temp_graph = func_graph_from_py_func(
      self._func,
      self._arg_names,
      self._arg_types,
      self._func_name,
      self._capture_by_value,
      self._caller_device,
      whitelisted_stateful_ops=self._whitelisted_stateful_ops,
      capture_resource_var_by_value=self._capture_resource_var_by_value)

  self._extra_inputs = temp_graph.extra_inputs
  self._sub_functions = temp_graph._functions  # pylint: disable=protected-access

  # Extra kwargs become attrs on the function def. An explicit function name
  # is used as-is; otherwise the name is derived from the Python function,
  # suffixed with the gradient function's name when there is one.
  base_func_name = self._func_name
  if not base_func_name:
    base_func_name = function_utils.get_func_name(self._func)
    if self._grad_func:
      base_func_name += ("_%s" % self._grad_func.name)
  kwargs_attr = _parse_kwargs_as_attrs(base_func_name, **self._extra_kwargs)

  # pylint: disable=protected-access
  if temp_graph._c_graph:
    # C API is enabled
    if self._out_names:
      output_names = [compat.as_bytes(out_name) for out_name in self._out_names]
    else:
      output_names = []
    description = self._func.__doc__ or None
    c_func = c_api.TF_GraphToFunction_wrapper(
        temp_graph._c_graph,
        base_func_name,
        self._func_name is None,  # append_hash_to_fn_name
        None,  # opers
        [tensor._as_tf_output() for tensor in temp_graph.inputs],
        [tensor._as_tf_output() for tensor in temp_graph.outputs],
        output_names,
        [],  # control_outputs
        [],  # control_output_names
        None,  # opts
        description)
    self._c_func = c_api_util.ScopedTFFunction(c_func)
    self._set_c_attrs(kwargs_attr)

    # Set cached fields: _op_def and _func_name (if not already set)
    self._op_def = self.definition.signature
    if self._func_name:
      assert self._func_name == self._op_def.name
    else:
      self._func_name = compat.as_str(self._op_def.name)
  else:
    # No C graph: build the FunctionDef in Python.
    self._definition = graph_to_function_def.graph_to_function_def(
        temp_graph, temp_graph.get_operations(), temp_graph.inputs,
        temp_graph.outputs, out_names=self._out_names)

    for attr_key in kwargs_attr:
      self._definition.attr[attr_key].CopyFrom(kwargs_attr[attr_key])

    # Hash the definition and its dependencies.
    self._hash_str = self._create_hash_str(
        self._definition.signature.input_arg,
        self._definition.signature.output_arg,
        self._definition.node_def)

    # Settle on the function name. When none was given, derive one from the
    # base name plus the hash, so it is almost certainly unique yet still
    # deterministic.
    if not self._func_name:
      self._func_name = "_".join([base_func_name, self._hash_str])
    self._definition.signature.name = self._func_name
    docstring = self._func.__doc__
    if docstring:
      self._definition.signature.description = docstring

    self._op_def = self._definition.signature
  # pylint: enable=protected-access

  # Record (name, type) for every op in the traced graph marked stateful.
  self._stateful_ops = [
      (operation.name, operation.type)
      for operation in temp_graph.get_operations()
      if operation.op_def.is_stateful
  ]
def _create_definition_if_needed_impl(self):
  """Traces `self._func` in a temporary graph and stores its definition.

  This is the implementation behind `_create_definition_if_needed`; call that
  instead. Returns immediately when either `self._definition` or
  `self._c_func` is already set.

  A bare `None` returned by `self._func` is treated as "no outputs".

  Raises:
    ValueError: if `self._func` returns a list or tuple containing None.
  """
  already_built = self._definition is not None or self._c_func is not None
  if already_built:
    return

  # Trace the Python function inside a scratch graph.
  temp_graph = _FuncGraph(capture_by_value=self._capture_by_value)
  with temp_graph.as_default():
    # One placeholder per declared argument of the function_def.
    inputs = [
        array_ops.placeholder(arg_type, name=arg_name)
        for (arg_name, arg_type) in self._args
    ]
    # Run the user function to obtain its output tensors.
    with vs.variable_scope("", custom_getter=temp_graph.getvar):
      outputs = self._func(*inputs)

    # Python cannot tell a function that returns nothing apart from one that
    # returns None. The former must be allowed; the latter is most likely a
    # user error and would ideally be rejected.
    # TODO(iga): Consider adding a @NoOutput decorator on top of @Defun to
    # allow users to explicitly mark the function as not returning anything.
    # For now, a single None return is accepted and read as "no output".
    if outputs is None:
      outputs = []
    elif not isinstance(outputs, (list, tuple)):
      # A lone return value is wrapped so it can be handled like a sequence.
      outputs = (outputs,)
    if any(result is None for result in outputs):
      raise ValueError("Function can not return None.")
    # Make sure every output is a Tensor.
    outputs = [ops.convert_to_tensor(result) for result in outputs]

  self._extra_inputs = temp_graph.extra_inputs
  inputs.extend(temp_graph.extra_args)
  self._sub_functions = temp_graph._functions  # pylint: disable=protected-access

  # Extra kwargs become attrs on the function def.
  if self._func_name:
    base_func_name = self._func_name
  else:
    base_func_name = _get_func_name(self._func)
  kwargs_attr = _parse_kwargs_as_attrs(base_func_name, **self._extra_kwargs)

  # pylint: disable=protected-access
  if temp_graph._c_graph:
    # C API is enabled
    if self._out_names:
      output_names = [compat.as_bytes(out_name) for out_name in self._out_names]
    else:
      output_names = []
    description = self._func.__doc__ or None
    with errors.raise_exception_on_not_ok_status() as status:
      self._c_func = c_api.TF_GraphToFunction_wrapper(
          temp_graph._c_graph,
          base_func_name,
          self._func_name is None,  # append_hash_to_fn_name
          None,  # opers
          [tensor._as_tf_output() for tensor in inputs],
          [tensor._as_tf_output() for tensor in outputs],
          output_names,
          None,  # opts
          description,
          status)
    self._set_c_attrs(kwargs_attr)

    # Set cached fields: _op_def and _func_name (if not already set)
    self._op_def = self.definition.signature
    if self._func_name:
      assert self._func_name == self._op_def.name
    else:
      self._func_name = compat.as_str(self._op_def.name)
  else:
    # No C graph: build the FunctionDef in Python.
    self._definition = graph_to_function_def.graph_to_function_def(
        temp_graph, temp_graph.get_operations(), inputs, outputs,
        out_names=self._out_names)

    for attr_key in kwargs_attr:
      self._definition.attr[attr_key].CopyFrom(kwargs_attr[attr_key])

    # Hash the definition and its dependencies.
    self._hash_str = self._create_hash_str(
        self._definition.signature.input_arg,
        self._definition.signature.output_arg,
        self._definition.node_def)

    # Settle on the function name. When none was given, derive one from the
    # base name plus the hash, so it is almost certainly unique yet still
    # deterministic.
    if not self._func_name:
      self._func_name = "_".join([base_func_name, self._hash_str])
    self._definition.signature.name = self._func_name
    docstring = self._func.__doc__
    if docstring:
      self._definition.signature.description = docstring

    self._op_def = self._definition.signature
  # pylint: enable=protected-access