def make_function_def(name, graph, operations, inputs, outputs):
  """Builds a TF_Function from part of a graph, plus its FunctionDef proto.

  Args:
    name: the function name.
    graph: the graph from which to build the function.
    operations: the operations that make up the function body.
    inputs: tensors to be used as the function's arguments.
    outputs: tensors to be returned from the function.

  Returns:
    A `(fdef, fn)` tuple where `fdef` is the FunctionDef protocol buffer
    parsed from the serialized function and `fn` is the wrapped TF_Function
    returned by `TF_GraphToFunction_wrapper`.
  """
  with errors.raise_exception_on_not_ok_status() as status:
    # pylint: disable=protected-access
    c_graph = graph._c_graph
    fn_name = compat.as_str(name)
    body_ops = [op._c_op for op in operations]
    arg_outputs = [tensor._as_tf_output() for tensor in inputs]
    ret_outputs = [tensor._as_tf_output() for tensor in outputs]
    # pylint: enable=protected-access
    fn = pywrap_tensorflow.TF_GraphToFunction_wrapper(
        c_graph,
        fn_name,
        False,
        body_ops,
        arg_outputs,
        ret_outputs,
        [],
        None,
        compat.as_str(""),
        status)

  # TODO(apassos) avoid creating a FunctionDef (specially to grab the
  # signature, but also in general it's nice not to depend on it).
  fdef = function_pb2.FunctionDef()
  with c_api_util.tf_buffer() as serialized:
    # The status is checked when this inner scope exits, i.e. before the
    # buffer contents are read.
    with errors.raise_exception_on_not_ok_status() as status:
      pywrap_tensorflow.TF_FunctionToFunctionDef(fn, serialized, status)
    raw_proto = pywrap_tensorflow.TF_GetBuffer(serialized)
    fdef.ParseFromString(compat.as_bytes(raw_proto))
  return fdef, fn
def function_def_from_tf_function(c_func):
  """Converts a SWIG-wrapped TF_Function* to a FunctionDef proto.

  Args:
    c_func: the SWIG-wrapped TF_Function* to serialize.

  Returns:
    A `function_pb2.FunctionDef` parsed from the bytes that
    `TF_FunctionToFunctionDef` wrote into a temporary buffer.
  """
  fdef = function_pb2.FunctionDef()
  with c_api_util.tf_buffer() as serialized:
    c_api.TF_FunctionToFunctionDef(c_func, serialized)
    raw_proto = c_api.TF_GetBuffer(serialized)
    fdef.ParseFromString(compat.as_bytes(raw_proto))
  return fdef
def definition(self):
  """Function definition proto.

  Calls `self._create_definition_if_needed()` first. When `self._c_func` is
  set, `self._c_func.func` is serialized through a temporary buffer and
  parsed into a fresh FunctionDef; otherwise `self._definition` is returned.

  Returns:
    A `function_pb2.FunctionDef` when `self._c_func` is truthy, else
    `self._definition`.
  """
  self._create_definition_if_needed()
  if not self._c_func:
    return self._definition

  with c_api_util.tf_buffer() as serialized:
    c_api.TF_FunctionToFunctionDef(self._c_func.func, serialized)
    parsed = function_pb2.FunctionDef()
    raw_proto = c_api.TF_GetBuffer(serialized)
    parsed.ParseFromString(compat.as_bytes(raw_proto))
  return parsed
def definition(self):
  """Function definition proto.

  Calls `self._create_definition_if_needed()` first. When `self._c_func` is
  set, it is serialized through a temporary buffer (raising if the reported
  status is not OK) and parsed into a fresh FunctionDef; otherwise
  `self._definition` is returned.

  Returns:
    A `function_pb2.FunctionDef` when `self._c_func` is truthy, else
    `self._definition`.
  """
  self._create_definition_if_needed()
  if not self._c_func:
    return self._definition

  with c_api_util.tf_buffer() as serialized:
    # The status is checked when this inner scope exits, i.e. before the
    # buffer contents are read.
    with errors.raise_exception_on_not_ok_status() as status:
      c_api.TF_FunctionToFunctionDef(self._c_func, serialized, status)
    parsed = function_pb2.FunctionDef()
    raw_proto = c_api.TF_GetBuffer(serialized)
    parsed.ParseFromString(compat.as_bytes(raw_proto))
  return parsed
def definition(self):
  """Function definition proto.

  Calls `self._create_definition_if_needed()` first. When `self._c_func` is
  set, `self._c_func.func` is serialized through a temporary buffer and
  parsed into a fresh FunctionDef. Then, inside `ops.init_scope()`, if
  executing eagerly, the function is added to the context and
  `self._function_deleter` is replaced with a `_DefinedFunctionDeleter` for
  the parsed signature name. Without a `_c_func`, `self._definition` is
  returned and none of that happens.

  Returns:
    A `function_pb2.FunctionDef` when `self._c_func` is truthy, else
    `self._definition`.
  """
  self._create_definition_if_needed()
  if not self._c_func:
    return self._definition

  with c_api_util.tf_buffer() as serialized:
    c_api.TF_FunctionToFunctionDef(self._c_func.func, serialized)
    parsed = function_pb2.FunctionDef()
    raw_proto = c_api.TF_GetBuffer(serialized)
    parsed.ParseFromString(compat.as_bytes(raw_proto))
    with ops.init_scope():
      if context.executing_eagerly():
        context.add_function(self._c_func.func)
        fn_name = parsed.signature.name
        self._function_deleter = _DefinedFunctionDeleter(fn_name)
  return parsed
def __init__(self, name, graph, operations, inputs, outputs, attrs):
  """Initializes an eager defined function.

  Builds a TF_Function from the given graph subset, applies `attrs` to it,
  parses its FunctionDef, and registers it when executing eagerly.

  Args:
    name: str, the name for the created function.
    graph: Graph, the graph containing the operations in the function.
    operations: list of Operation; the subset of operations in the graph
      which will be in the function.
    inputs: the tensors in the graph to be used as inputs to the function.
    outputs: the tensors in the graph which will be outputs to the function.
    attrs: dict mapping names of attributes to their AttrValue values.
  """
  fn = pywrap_tensorflow.TF_GraphToFunction_wrapper(
      graph._c_graph,  # pylint: disable=protected-access
      compat.as_str(name),
      False,
      [o._c_op for o in operations],  # pylint: disable=protected-access
      [t._as_tf_output() for t in inputs],  # pylint: disable=protected-access
      [t._as_tf_output() for t in outputs],  # pylint: disable=protected-access
      [],
      None,
      compat.as_str(""))

  # The loop variable is `attr_name`, not `name`, so that the `name`
  # parameter is not rebound to the last attribute key.
  for attr_name, attr_value in attrs.items():
    serialized = attr_value.SerializeToString()
    # TODO(iga): this creates and deletes a new TF_Status for every attr.
    # It might be worth creating a convenient way to re-use status.
    pywrap_tensorflow.TF_FunctionSetAttrValueProto(
        fn, compat.as_str(attr_name), serialized)

  # TODO(apassos) avoid creating a FunctionDef (specially to grab the
  # signature, but also in general it's nice not to depend on it).
  with c_api_util.tf_buffer() as buffer_:
    pywrap_tensorflow.TF_FunctionToFunctionDef(fn, buffer_)
    proto_data = pywrap_tensorflow.TF_GetBuffer(buffer_)
    function_def = function_pb2.FunctionDef()
    function_def.ParseFromString(compat.as_bytes(proto_data))

  if context.executing_eagerly():
    _register(fn)

  self.definition = function_def
  # The public name comes from the parsed signature, not the `name` argument.
  self.name = function_def.signature.name
  self.signature = function_def.signature
  self.grad_func_name = None
  self.python_grad_func = None
  self._c_func = c_api_util.ScopedTFFunction(fn)
  self._grad_func = None
def __init__(self, name, graph, operations, inputs, outputs):
  """Initializes an eager defined function.

  Builds a TF_Function from the given graph subset, parses its FunctionDef,
  and registers the function when in eager mode.

  Args:
    name: str, the name for the created function.
    graph: Graph, the graph containing the operations in the function.
    operations: list of Operation; the subset of operations in the graph
      which will be in the function.
    inputs: the tensors in the graph to be used as inputs to the function.
    outputs: the tensors in the graph which will be outputs to the function.
  """
  with errors.raise_exception_on_not_ok_status() as status:
    # pylint: disable=protected-access
    c_graph = graph._c_graph
    fn_name = compat.as_str(name)
    body_ops = [op._c_op for op in operations]
    arg_outputs = [tensor._as_tf_output() for tensor in inputs]
    ret_outputs = [tensor._as_tf_output() for tensor in outputs]
    # pylint: enable=protected-access
    fn = pywrap_tensorflow.TF_GraphToFunction_wrapper(
        c_graph,
        fn_name,
        False,
        body_ops,
        arg_outputs,
        ret_outputs,
        [],
        None,
        compat.as_str(""),
        status)

  # TODO(apassos) avoid creating a FunctionDef (specially to grab the
  # signature, but also in general it's nice not to depend on it).
  function_def = function_pb2.FunctionDef()
  with c_api_util.tf_buffer() as serialized:
    # The status is checked when this inner scope exits, i.e. before the
    # buffer contents are read.
    with errors.raise_exception_on_not_ok_status() as status:
      pywrap_tensorflow.TF_FunctionToFunctionDef(fn, serialized, status)
    raw_proto = pywrap_tensorflow.TF_GetBuffer(serialized)
    function_def.ParseFromString(compat.as_bytes(raw_proto))

  if context.in_eager_mode():
    _register(fn)

  signature = function_def.signature
  self.definition = function_def
  # The public name comes from the parsed signature, not the `name` argument.
  self.name = signature.name
  self.signature = signature
  self.grad_func_name = None
  self.python_grad_func = None
  self._c_func = fn
  self._grad_func = None