def graph_to_function_def(graph, name, inputs, outputs):
    """Returns `graph` as a `FunctionDef` protocol buffer.

    This method creates a [`FunctionDef`](
    https://www.tensorflow.org/code/tensorflow/core/framework/function.proto)
    protocol buffer that contains all the ops present in the graph. The
    graph effectively becomes the body of the function.

    The arguments `inputs` and `outputs` will be listed as the inputs
    and outputs tensors of the function. They must be lists of
    tensors present in the graph. The lists can optionally be empty.

    The returned protocol buffer can be passed to the
    [`Graph.add_function()`](#Graph.add_function) method of a
    different graph to make it available there.

    Ops whose first output is one of `inputs` are treated as argument
    placeholders and left out of the body. Ops that produce no outputs
    can never be placeholders and are always added to the body.

    Args:
      graph: Graph.
      name: string. The name to use for the function.
      inputs: List of tensors. Inputs to the function.
      outputs: List of tensors. Outputs of the function.

    Returns:
      A FunctionDef protocol buffer.
    """
    func = function_pb2.FunctionDef()
    func.signature.name = name
    func.signature.input_arg.extend([_tensor_to_argdef(i) for i in inputs])
    func.signature.output_arg.extend([_tensor_to_argdef(o) for o in outputs])
    func_arg_placeholders = {i.name for i in inputs}
    for op in graph.get_operations():
        op_outputs = op.values()
        # Guard against ops with no outputs: indexing their empty output
        # sequence would raise IndexError instead of adding the op.
        if op_outputs and op_outputs[0].name in func_arg_placeholders:
            continue
        _add_op_node(graph, op, func)
    return func
def testGraphDefsWithPermutedNodesInFunctionsCompareEqual(self):
    """Checks that `graph_defs_equal` ignores node order inside functions."""

    @function.Defun(dtypes.float32)
    def F1(x):
        return math_ops.exp(x) - math_ops.exp(-x)

    def graph_def_with(fn_def):
        # Wrap a single FunctionDef in a library and attach it to a GraphDef.
        fn_library = function_pb2.FunctionDefLibrary()
        fn_library.function.extend([fn_def])
        wrapped = graph_pb2.GraphDef()
        wrapped.library.CopyFrom(fn_library)
        return wrapped

    original_fn = F1.definition
    expected_graph_def = graph_def_with(original_fn)

    # Same function, but with the body nodes listed in reverse order.
    permuted_fn = function_pb2.FunctionDef()
    permuted_fn.CopyFrom(original_fn)
    del permuted_fn.node_def[:]
    permuted_fn.node_def.extend(reversed(original_fn.node_def))
    permuted_graph_def = graph_def_with(permuted_fn)

    self.assertTrue(
        graph_util.graph_defs_equal(expected_graph_def, permuted_graph_def))
def _fix_fdef(orig_fdef, functions, shared_name_suffix):
    """Returns a copy of `orig_fdef` fixed up for loading in this context.

    When a function library is loaded into an eager context, functions
    have to be renamed so they do not clash with functions that already
    exist. Every node of the copy is passed through `fix_node_def`, and the
    function name in the signature is cleaned.

    Args:
      orig_fdef: FunctionDef proto to fix. It is not modified.
      functions: map from function name to a ConcreteFunction instance.
      shared_name_suffix: A unique string for this load which helps to avoid
        `shared_name` collisions across loads. Two functions from the same
        load using the same `shared_name` still need to share, but functions
        from different loads with the same `shared_name` should not.

    Returns:
      A fixed copy of the original FunctionDef.
    """
    fixed = function_pb2.FunctionDef()
    fixed.CopyFrom(orig_fdef)

    found_custom_gradients = False
    for node in fixed.node_def:
        fix_node_def(node, functions, shared_name_suffix)
        # Short-circuits: once one op is flagged, later ops are not probed.
        found_custom_gradients = (
            found_custom_gradients or _check_op_has_custom_gradients(node))

    if found_custom_gradients:
        logging.warning(
            "Importing a function (%s) with ops with custom gradients. Will likely "
            "fail if a gradient is requested.", fixed.signature.name)

    fixed.signature.name = _clean_function_name(fixed.signature.name)
    return fixed
def _fix_fdef(orig_fdef, functions):
    """Returns a copy of `orig_fdef` fixed up for loading in this context.

    When a function library is loaded into an eager context, functions
    have to be renamed so they do not clash with functions that already
    exist. Function references held in node attributes are redirected to
    the names of the corresponding entries in `functions`.

    Args:
      orig_fdef: FunctionDef proto to fix. It is not modified.
      functions: map from function name to a ConcreteFunction instance.

    Returns:
      A fixed copy of the original FunctionDef.
    """
    call_op_types = ["StatefulPartitionedCall", "PartitionedCall"]

    fixed = function_pb2.FunctionDef()
    fixed.CopyFrom(orig_fdef)

    for node in fixed.node_def:
        if "_gradient_op_type" in node.attr:
            if node.op in call_op_types:
                # TODO(andresp): This code assumes that the gradient registered
                # for this function call is the default gradient for the
                # function and not a custom one.
                callee = node.attr["f"].func.name
                gradient_name = functions[callee]._gradient_name  # pylint: disable=protected-access
                node.attr["_gradient_op_type"].s = compat.as_bytes(
                    gradient_name)
            else:
                logging.warning("Importing a function (%s) with ops with custom "
                                "gradients. Will likely fail if a gradient is "
                                "requested.", fixed.signature.name)

        # Point every function-valued attribute at the loaded function's name.
        for _, attr_value in node.attr.items():
            referenced = attr_value.func.name
            if referenced:
                attr_value.func.name = functions[referenced].name

    fixed.signature.name = _clean_function_name(fixed.signature.name)
    return fixed
def _graph_to_function_def(graph, inputs, outputs, out_names=None):
    """Builds a `FunctionDef` protocol buffer whose body is `graph`.

    The resulting [`FunctionDef`](
    https://www.tensorflow.org/code/tensorflow/core/framework/function.proto)
    contains the ops present in the graph, except those recognised as
    placeholders for `inputs`. `inputs` and `outputs` become the input and
    output arguments of the function signature; both must be lists of
    tensors present in the graph and either may be empty.

    Args:
      graph: Graph.
      inputs: List of tensors. Inputs to the function.
      outputs: List of tensors. Outputs of the function.
      out_names: Optional list of string names for the outputs.

    Returns:
      A FunctionDef protocol buffer.

    Raises:
      ValueError: if out_names is specified and has the wrong length or
        contains duplicates.
    """
    func = function_pb2.FunctionDef()
    func.signature.name = "_"

    input_names_seen = set()
    func.signature.input_arg.extend(
        [_tensor_to_argdef(i, used_names=input_names_seen) for i in inputs])

    has_explicit_names = out_names is not None
    if has_explicit_names:
        if len(outputs) != len(out_names):
            raise ValueError(
                "Length of out_names (%d) does not match number of outputs (%d): %s"
                % (len(out_names), len(outputs), ", ".join(out_names)))
        if len(out_names) != len(set(out_names)):
            raise ValueError("Must not have duplicates in out_names: %s" %
                             ", ".join(out_names))
        output_args = [
            _tensor_to_argdef(o, name=n) for o, n in zip(outputs, out_names)
        ]
    else:
        # Output names are uniquified independently of the input names.
        output_names_seen = set()
        output_args = [
            _tensor_to_argdef(o, used_names=output_names_seen) for o in outputs
        ]
    func.signature.output_arg.extend(output_args)

    func_arg_placeholders = {i.name for i in inputs}
    input_dict = _create_input_dict(graph, func_arg_placeholders)

    for op in graph.get_operations():
        if not _is_in_placeholders(op, func_arg_placeholders):
            _add_op_node(op, func, input_dict)

    # Map each signature output name to the tensor that produces it.
    if has_explicit_names:
        ret_keys = out_names
    else:
        ret_keys = [arg.name for arg in func.signature.output_arg]
    for key, tensor in zip(ret_keys, outputs):
        func.ret[key] = input_dict[tensor.name]
    return func
def _graph_to_function_def(graph, inputs, outputs):
    """Builds a `FunctionDef` protocol buffer whose body is `graph`.

    The resulting [`FunctionDef`](
    https://www.tensorflow.org/code/tensorflow/core/framework/function.proto)
    contains the ops present in the graph, except ops whose first output
    is one of `inputs`. `inputs` and `outputs` become the input and output
    arguments of the function signature; both must be lists of tensors
    present in the graph and either may be empty.

    Args:
      graph: Graph.
      inputs: List of tensors. Inputs to the function.
      outputs: List of tensors. Outputs of the function.

    Returns:
      A FunctionDef protocol buffer.
    """
    func = function_pb2.FunctionDef()
    func.signature.name = "_"
    func.signature.input_arg.extend([_tensor_to_argdef(t) for t in inputs])
    func.signature.output_arg.extend([_tensor_to_argdef(t) for t in outputs])

    placeholder_names = {t.name for t in inputs}
    for op in graph.get_operations():
        produced = op.values()
        # An op is an argument placeholder when its first output is an input.
        is_placeholder = bool(produced) and (
            produced[0].name in placeholder_names)
        if not is_placeholder:
            _add_op_node(op, func)
    return func
def _fix_fdef(orig_fdef, functions, shared_name_suffix):
    """Returns a copy of `orig_fdef` fixed up for loading in this context.

    When a function library is loaded into an eager context, functions
    have to be renamed so they do not clash with functions that already
    exist. Each node of the copy is handed to `fix_node_def` together with
    the function's (not yet cleaned) signature name.

    Args:
      orig_fdef: FunctionDef proto to fix. It is not modified.
      functions: map from function name to a ConcreteFunction instance.
      shared_name_suffix: A unique string for this load which helps to avoid
        `shared_name` collisions across loads. Two functions from the same
        load using the same `shared_name` still need to share, but functions
        from different loads with the same `shared_name` should not.

    Returns:
      A fixed copy of the original FunctionDef.
    """
    fixed = function_pb2.FunctionDef()
    fixed.CopyFrom(orig_fdef)

    for node in fixed.node_def:
        fix_node_def(node, functions, shared_name_suffix,
                     fixed.signature.name)

    # The signature name is only cleaned after all nodes have been fixed.
    fixed.signature.name = _clean_function_name(fixed.signature.name)
    return fixed
def testNoOutputs(self):
    """Runs a compiled function whose only node's output is ignored."""
    with session_lib.Session() as sess:
        # Hand-build a FunctionDef holding a single Const node and no
        # declared outputs.
        const_node = node_def_pb2.NodeDef()
        const_node.name = "ignored"
        const_node.op = "Const"
        const_node.attr["dtype"].type = dtypes.int32.as_datatype_enum
        const_value = tensor_util.make_tensor_proto(
            [0], dtype=dtypes.int32, shape=[])
        const_node.attr["value"].tensor.CopyFrom(const_value)

        kernel_def = function_pb2.FunctionDef()
        kernel_def.signature.name = "KernelWithNoOutputs"
        kernel_def.node_def.extend([const_node])

        # Check that calling the result as a compiled kernel doesn't crash.
        @function.Defun(compiled=True)
        def KernelWithNoOutputs():
            return constant_op.constant(100)

        # Hack to override the definition. Reading .definition forces the
        # _DefinedFunction to be initialized internally; its internal
        # FunctionDef proto is then swapped for the hand-built one. This is
        # done because one typically can't construct KernelWithNoOutputs
        # directly through the Defun decorator.
        _ = KernelWithNoOutputs.definition
        KernelWithNoOutputs._definition = kernel_def

        sess.run(KernelWithNoOutputs(), {})
def make_function_def(name, graph, operations, inputs, outputs):
    """Builds a TF_Function from part of a graph, plus its FunctionDef proto.

    Args:
      name: the function name
      graph: the graph from which to build the function
      operations: the operations in the function body
      inputs: tensors to be used as function arguments
      outputs: tensors to be returned from the function

    Returns:
      fdef: a FunctionDef protocol buffer for the function
      fn: a wrapped TF_Function for the function
    """
    with errors.raise_exception_on_not_ok_status() as status:
        c_graph = graph._c_graph  # pylint: disable=protected-access
        fn_name = compat.as_str(name)
        body_ops = [o._c_op for o in operations]  # pylint: disable=protected-access
        input_handles = [t._as_tf_output() for t in inputs]  # pylint: disable=protected-access
        output_handles = [t._as_tf_output() for t in outputs]  # pylint: disable=protected-access
        fn = pywrap_tensorflow.TF_GraphToFunction_wrapper(
            c_graph,
            fn_name,
            False,
            body_ops,
            input_handles,
            output_handles,
            [],
            None,
            compat.as_str(""),
            status)

    # TODO(apassos) avoid creating a FunctionDef (specially to grab the
    # signature, but also in general it's nice not to depend on it.
    with c_api_util.tf_buffer() as serialized:
        with errors.raise_exception_on_not_ok_status() as status:
            pywrap_tensorflow.TF_FunctionToFunctionDef(fn, serialized, status)
        # Read the bytes out while the buffer is still open.
        proto_data = pywrap_tensorflow.TF_GetBuffer(serialized)

    fdef = function_pb2.FunctionDef()
    fdef.ParseFromString(compat.as_bytes(proto_data))
    return fdef, fn
def test_create_nullary(self):
    """Creates a zero-argument function in the global eager context."""
    nullary_fn_text = """
      signature {
        name: 'NullaryFunction'
        output_arg { name: 'o' type: DT_INT32 }
      }
      node_def {
        name: 'retval'
        op: 'Const'
        attr { key: 'dtype' value { type: DT_INT32 } }
        attr {
          key: 'value'
          value { tensor { dtype: DT_INT32 tensor_shape {} int_val: 1 } }
        }
      }
      ret { key: 'o' value: 'retval:output' }
    """
    fndef = text_format.Parse(nullary_fn_text, function_pb2.FunctionDef())

    runtime = runtime_client.Runtime(runtime_client.GlobalEagerContext())
    runtime.CreateFunction(fndef)
def function_def_from_tf_function(c_func):
    """Converts a SWIG-wrapped TF_Function* to a FunctionDef proto."""
    with c_api_util.tf_buffer() as serialized:
        c_api.TF_FunctionToFunctionDef(c_func, serialized)
        # Read the bytes out while the buffer is still open.
        raw_proto = c_api.TF_GetBuffer(serialized)

    result = function_pb2.FunctionDef()
    result.ParseFromString(compat.as_bytes(raw_proto))
    return result
def _fix_fdef(orig_fdef, name_map):
    """Returns a renamed copy of `orig_fdef`.

    The copy's signature name is cleaned, and every function reference held
    in a node attribute is replaced by its entry in `name_map`.

    Args:
      orig_fdef: FunctionDef proto to copy. It is not modified.
      name_map: mapping from referenced function name to replacement name.

    Returns:
      The fixed copy of the FunctionDef.
    """
    renamed = function_pb2.FunctionDef()
    renamed.CopyFrom(orig_fdef)
    renamed.signature.name = _clean_function_name(renamed.signature.name)

    for node in renamed.node_def:
        for _, attr_value in node.attr.items():
            referenced = attr_value.func.name
            # Attributes that are not function references have an empty name.
            if referenced:
                attr_value.func.name = name_map[referenced]
    return renamed
def definition(self):
    """Function definition proto.

    When a C function handle exists, it is serialized and parsed into a
    fresh FunctionDef; otherwise the stored `_definition` is returned.
    """
    self._create_definition_if_needed()
    if not self._c_func:
        return self._definition

    with c_api_util.tf_buffer() as serialized:
        c_api.TF_FunctionToFunctionDef(self._c_func.func, serialized)
        raw_proto = c_api.TF_GetBuffer(serialized)
        parsed = function_pb2.FunctionDef()
        parsed.ParseFromString(compat.as_bytes(raw_proto))
    return parsed
def to_function_graph_def(self, add_shapes=True):
    # type: (bool) -> function_pb2.FunctionDef
    """Serializes this function's graph in its current form.

    The signature of the stored FunctionDef is kept as is, while all of its
    node_defs are regenerated from `self.nodes`. Node inputs are rewritten:
    function arguments are referenced by bare name, every other input as
    `<op name>:<output arg name>:<index>`.

    Args:
      add_shapes: If True, add the special "_output_shapes" attribute with
        output shape information from this Node's output metadata.

    Returns:
      The `function_pb2.FunctionDef` serialization of this function's graph.
    """
    result = function_pb2.FunctionDef()
    result.CopyFrom(self._func_graph_def)
    # Leave signature as is, but replace all node_defs
    del result.node_def[:]
    result.signature.CopyFrom(self._func_graph_def.signature)

    arg_names = [arg.name for arg in result.signature.input_arg]

    for op in self.nodes:
        if op.op_type == _INPUT_DUMMY_OP_NAME:
            continue

        node_def = result.node_def.add()
        op.to_node_def(node_def, add_shapes)

        # Counts, per "<op>:<output arg>" prefix, how many inputs of this
        # node have already been assigned an index.
        seen_per_output = Counter()
        for i, input_tensor in enumerate(op.inputs):
            producer_name, index_str = input_tensor.name.split(":")
            output_index = int(index_str)

            if producer_name in arg_names:
                # don't add index for function args
                node_def.input[i] = producer_name
                continue

            output_args, has_number_attr = (
                self._get_op_def_denormalized_outputs(input_tensor.op))

            if len(output_args) == 1 and output_args[0].type_list_attr:
                node_def.input[i] = (
                    producer_name + ":" + output_args[0].name + ":" +
                    str(output_index))
                continue

            input_name = producer_name + ":" + output_args[output_index].name
            node_def.input[i] = (
                input_name + ":" + str(seen_per_output[input_name]))
            if has_number_attr:
                # only uniquify input args with var length,
                # otherwise it should be 0
                seen_per_output[input_name] += 1
    return result
def definition(self):
    """Function definition proto.

    When a C function handle exists, it is serialized and parsed into a
    fresh FunctionDef; otherwise the stored `_definition` is returned.
    """
    self._create_definition_if_needed()
    if not self._c_func:
        return self._definition

    with c_api_util.tf_buffer() as serialized:
        with errors.raise_exception_on_not_ok_status() as status:
            c_api.TF_FunctionToFunctionDef(self._c_func, serialized, status)
        raw_proto = c_api.TF_GetBuffer(serialized)
        parsed = function_pb2.FunctionDef()
        parsed.ParseFromString(compat.as_bytes(raw_proto))
    return parsed
def definition(self):
    """Function definition proto.

    When a C function handle exists, it is serialized and parsed into a
    fresh FunctionDef; when executing eagerly the function is also added to
    the context and a deleter keyed on its name is stored. Without a C
    function handle the stored `_definition` is returned.
    """
    self._create_definition_if_needed()
    if not self._c_func:
        return self._definition

    with c_api_util.tf_buffer() as serialized:
        c_api.TF_FunctionToFunctionDef(self._c_func.func, serialized)
        raw_proto = c_api.TF_GetBuffer(serialized)
        parsed = function_pb2.FunctionDef()
        parsed.ParseFromString(compat.as_bytes(raw_proto))
        with ops.init_scope():
            if context.executing_eagerly():
                context.add_function(self._c_func.func)
                self._function_deleter = _DefinedFunctionDeleter(
                    parsed.signature.name)
    return parsed
def __init__(self, name, graph, operations, inputs, outputs, attrs):
    """Initializes an eager defined function.

    Args:
      name: str, the name for the created function.
      graph: Graph, the graph containing the operations in the function
      operations: list of Operation; the subset of operations in the graph
        which will be in the function
      inputs: the tensors in the graph to be used as inputs to the function
      outputs: the tensors in the graph which will be outputs to the function
      attrs: dict mapping names of attributes to their AttrValue values
    """
    c_graph = graph._c_graph  # pylint: disable=protected-access
    fn_name = compat.as_str(name)
    body_ops = [o._c_op for o in operations]  # pylint: disable=protected-access
    input_handles = [t._as_tf_output() for t in inputs]  # pylint: disable=protected-access
    output_handles = [t._as_tf_output() for t in outputs]  # pylint: disable=protected-access
    fn = pywrap_tensorflow.TF_GraphToFunction_wrapper(
        c_graph,
        fn_name,
        False,
        body_ops,
        input_handles,
        output_handles,
        [],
        None,
        compat.as_str(""))

    for attr_name, attr_value in attrs.items():
        # TODO(iga): this creates and deletes a new TF_Status for every attr.
        # It might be worth creating a convenient way to re-use status.
        pywrap_tensorflow.TF_FunctionSetAttrValueProto(
            fn, compat.as_str(attr_name), attr_value.SerializeToString())

    # TODO(apassos) avoid creating a FunctionDef (specially to grab the
    # signature, but also in general it's nice not to depend on it.
    with c_api_util.tf_buffer() as serialized:
        pywrap_tensorflow.TF_FunctionToFunctionDef(fn, serialized)
        # Read the bytes out while the buffer is still open.
        proto_data = pywrap_tensorflow.TF_GetBuffer(serialized)
    function_def = function_pb2.FunctionDef()
    function_def.ParseFromString(compat.as_bytes(proto_data))

    if context.executing_eagerly():
        _register(fn)

    self.definition = function_def
    self.name = function_def.signature.name
    self.signature = function_def.signature
    self.grad_func_name = None
    self.python_grad_func = None
    self._c_func = c_api_util.ScopedTFFunction(fn)
    self._grad_func = None
def _fix_fdef(orig_fdef, functions, shared_name_suffix):
    """Returns a copy of `orig_fdef` fixed up for loading in this context.

    When a function library is loaded into an eager context, functions
    have to be renamed so they do not clash with functions that already
    exist. Function references in node attributes are redirected to the
    loaded functions, and any `shared_name` attribute gets
    `shared_name_suffix` appended.

    Args:
      orig_fdef: FunctionDef proto to fix. It is not modified.
      functions: map from function name to a ConcreteFunction instance.
      shared_name_suffix: A unique string for this load which helps to avoid
        `shared_name` collisions across loads. Two functions from the same
        load using the same `shared_name` still need to share, but functions
        from different loads with the same `shared_name` should not.

    Returns:
      A fixed copy of the original FunctionDef.
    """
    call_op_types = ["StatefulPartitionedCall", "PartitionedCall"]

    fixed = function_pb2.FunctionDef()
    fixed.CopyFrom(orig_fdef)

    for node in fixed.node_def:
        if "_gradient_op_type" in node.attr:
            if node.op in call_op_types:
                # TODO(andresp): This code assumes that the gradient registered
                # for this function call is the default gradient for the
                # function and not a custom one.
                callee = node.attr["f"].func.name
                gradient_name = functions[callee]._gradient_name  # pylint: disable=protected-access
                node.attr["_gradient_op_type"].s = compat.as_bytes(
                    gradient_name)
            else:
                logging.warning(
                    "Importing a function (%s) with ops with custom "
                    "gradients. Will likely fail if a gradient is "
                    "requested.", fixed.signature.name)

        # Point every function-valued attribute at the loaded function's name.
        for _, attr_value in node.attr.items():
            referenced = attr_value.func.name
            if referenced:
                attr_value.func.name = functions[referenced].name

        # TODO(b/124205571): Avoid accidental sharing and destruction of
        # restored resources. For now uniquify "shared_name" when loading
        # functions to avoid sharing.
        if "shared_name" in node.attr:
            shared_name_attr = node.attr["shared_name"]
            shared_name_attr.s = shared_name_attr.s + compat.as_bytes(
                shared_name_suffix)

    fixed.signature.name = _clean_function_name(fixed.signature.name)
    return fixed
def __init__(self, name, graph, operations, inputs, outputs):
    """Initializes an eager defined function.

    Args:
      name: str, the name for the created function.
      graph: Graph, the graph containing the operations in the function
      operations: list of Operation; the subset of operations in the graph
        which will be in the function
      inputs: the tensors in the graph to be used as inputs to the function
      outputs: the tensors in the graph which will be outputs to the function
    """
    with errors.raise_exception_on_not_ok_status() as status:
        c_graph = graph._c_graph  # pylint: disable=protected-access
        fn_name = compat.as_str(name)
        body_ops = [o._c_op for o in operations]  # pylint: disable=protected-access
        input_handles = [t._as_tf_output() for t in inputs]  # pylint: disable=protected-access
        output_handles = [t._as_tf_output() for t in outputs]  # pylint: disable=protected-access
        fn = pywrap_tensorflow.TF_GraphToFunction_wrapper(
            c_graph,
            fn_name,
            False,
            body_ops,
            input_handles,
            output_handles,
            [],
            None,
            compat.as_str(""),
            status)

    # TODO(apassos) avoid creating a FunctionDef (specially to grab the
    # signature, but also in general it's nice not to depend on it.
    with c_api_util.tf_buffer() as serialized:
        with errors.raise_exception_on_not_ok_status() as status:
            pywrap_tensorflow.TF_FunctionToFunctionDef(fn, serialized, status)
        # Read the bytes out while the buffer is still open.
        proto_data = pywrap_tensorflow.TF_GetBuffer(serialized)
    function_def = function_pb2.FunctionDef()
    function_def.ParseFromString(compat.as_bytes(proto_data))

    if context.in_eager_mode():
        _register(fn)

    self.definition = function_def
    self.name = function_def.signature.name
    self.signature = function_def.signature
    self.grad_func_name = None
    self.python_grad_func = None
    self._c_func = fn
    self._grad_func = None
def _fix_fdef(orig_fdef, functions, shared_name_suffix, new_gradient_op_types):
    """Returns a copy of `orig_fdef` fixed up for loading in this context.

    When a function library is loaded into an eager context, functions
    have to be renamed so they do not clash with functions that already
    exist. Every node of the copy is passed through `fix_node_def`; nodes
    whose gradient op type appears in `new_gradient_op_types` are switched
    to the newly generated type.

    Args:
      orig_fdef: FunctionDef proto to fix. It is not modified.
      functions: map from function name to a ConcreteFunction instance.
      shared_name_suffix: A unique string for this load which helps to avoid
        `shared_name` collisions across loads. Two functions from the same
        load using the same `shared_name` still need to share, but functions
        from different loads with the same `shared_name` should not.
      new_gradient_op_types: map from old gradient op type to newly generated
        op type.

    Returns:
      A fixed copy of the original FunctionDef
    """
    fixed = function_pb2.FunctionDef()
    fixed.CopyFrom(orig_fdef)

    has_unsaved_custom_gradients = False
    for node in fixed.node_def:
        fix_node_def(node, functions, shared_name_suffix)
        gradient_op_type = _get_gradient_op_type(node)
        if gradient_op_type is None:
            continue
        if gradient_op_type not in new_gradient_op_types:
            # No replacement was generated for this gradient op type.
            has_unsaved_custom_gradients = True
            continue
        node.attr["_gradient_op_type"].s = compat.as_bytes(
            new_gradient_op_types[gradient_op_type])

    if has_unsaved_custom_gradients:
        logging.warning(
            "Importing a function (%s) with ops with unsaved custom gradients. Will"
            " likely fail if a gradient is requested.", fixed.signature.name)

    fixed.signature.name = _clean_function_name(fixed.signature.name)
    return fixed
def test_create_function_called_by_py_runtime(self):
    """Creates a function via the runtime client and executes it from Python."""
    if not tf2.enabled():
        self.skipTest("TF2 test")

    nullary_fn_text = """
      signature {
        name: 'NullaryFunction'
        output_arg { name: 'o' type: DT_INT32 }
      }
      node_def {
        name: 'retval'
        op: 'Const'
        attr { key: 'dtype' value { type: DT_INT32 } }
        attr {
          key: 'value'
          value { tensor { dtype: DT_INT32 tensor_shape {} int_val: 1 } }
        }
      }
      ret { key: 'o' value: 'retval:output' }
    """
    fndef = text_format.Parse(nullary_fn_text, function_pb2.FunctionDef())

    runtime = runtime_client.Runtime(
        runtime_client.GlobalPythonEagerContext())
    runtime.CreateFunction(fndef)

    # The function returns exactly one tensor, holding the constant 1.
    results = execute.execute("NullaryFunction", 1, [], (), context.context())
    ret, = results
    self.assertAllEqual(ret, 1)
def graph_to_function_def(graph, operations, inputs, outputs, out_names=None):
    """Builds a `FunctionDef` protocol buffer from a subset of `graph`.

    The resulting [`FunctionDef`](
    https://www.tensorflow.org/code/tensorflow/core/framework/function.proto)
    contains the ops in `operations`, except those recognised as
    placeholders for `inputs`; these ops become the body of the function.
    `inputs` and `outputs` become the input and output arguments of the
    function signature; both must be lists of tensors present in the graph
    and either may be empty.

    Args:
      graph: Graph.
      operations: the operations to put in the function. Must be a subset of
        the operations in the graph.
      inputs: List of tensors. Inputs to the function.
      outputs: List of tensors. Outputs of the function.
      out_names: Optional list of string names for the outputs.

    Returns:
      A FunctionDef protocol buffer.

    Raises:
      errors_impl.InvalidArgumentError: if out_names is specified and its
        length differs from the number of outputs.
      ValueError: if out_names contains duplicates.
    """
    func = function_pb2.FunctionDef()
    func.signature.name = "_"

    input_names_seen = set()
    func.signature.input_arg.extend(
        [_tensor_to_argdef(i, used_names=input_names_seen) for i in inputs])

    # Initializes the input map with all placeholder input tensors.
    initial_dict = {
        tensor.name: arg.name
        for tensor, arg in zip(inputs, func.signature.input_arg)
    }

    has_explicit_names = out_names is not None
    if has_explicit_names:
        if len(outputs) != len(out_names):
            raise errors_impl.InvalidArgumentError(
                None, None,
                "output names must be either empty or equal in size to outputs. "
                "output names size = %d outputs size = %d" %
                (len(out_names), len(outputs)))
        if len(out_names) != len(set(out_names)):
            raise ValueError("Must not have duplicates in out_names: %s" %
                             ", ".join(out_names))
        output_args = [
            _tensor_to_argdef(o, name=n) for o, n in zip(outputs, out_names)
        ]
    else:
        # Output names are uniquified independently of the input names.
        output_names_seen = set()
        output_args = [
            _tensor_to_argdef(o, used_names=output_names_seen) for o in outputs
        ]
    func.signature.output_arg.extend(output_args)

    func_arg_placeholders = set(i.name for i in inputs)
    input_dict = _create_input_dict(
        graph, func_arg_placeholders, initial_value=initial_dict)

    for op in operations:
        if not _is_in_placeholders(op, func_arg_placeholders):
            _add_op_node(op, func, input_dict)

    # Map each signature output name to the tensor that produces it.
    if has_explicit_names:
        ret_keys = out_names
    else:
        ret_keys = [arg.name for arg in func.signature.output_arg]
    for key, tensor in zip(ret_keys, outputs):
        func.ret[key] = input_dict[tensor.name]
    return func