Пример #1
0
def make_function_def(name, graph, operations, inputs, outputs):
    """Builds a TF_Function from a graph and returns it with its FunctionDef.

    Args:
      name: the name to give the function.
      graph: the graph whose operations make up the function.
      operations: the operations to place in the function body.
      inputs: tensors that become the function's arguments.
      outputs: tensors that the function returns.

    Returns:
      A `(fdef, fn)` pair, where `fdef` is the `FunctionDef` protocol buffer
      parsed from the C function and `fn` is the wrapped TF_Function itself.
    """
    with errors.raise_exception_on_not_ok_status() as build_status:
        # Gather the C-level handles in the same order the call consumes them.
        # pylint: disable=protected-access
        c_graph = graph._c_graph
        c_name = compat.as_str(name)
        c_ops = [op._c_op for op in operations]
        c_inputs = [tensor._as_tf_output() for tensor in inputs]
        c_outputs = [tensor._as_tf_output() for tensor in outputs]
        # pylint: enable=protected-access
        c_function = pywrap_tensorflow.TF_GraphToFunction_wrapper(
            c_graph, c_name, False, c_ops, c_inputs, c_outputs, [], None,
            compat.as_str(""), build_status)

    # TODO(apassos): avoid creating a FunctionDef (especially just to grab the
    # signature); in general it is nicer not to depend on it.
    with c_api_util.tf_buffer() as proto_buffer:
        with errors.raise_exception_on_not_ok_status() as export_status:
            pywrap_tensorflow.TF_FunctionToFunctionDef(
                c_function, proto_buffer, export_status)
        serialized = pywrap_tensorflow.TF_GetBuffer(proto_buffer)
    function_def = function_pb2.FunctionDef()
    function_def.ParseFromString(compat.as_bytes(serialized))
    return function_def, c_function
Пример #2
0
def function_def_from_tf_function(c_func):
    """Serializes a SWIG-wrapped TF_Function* and parses it as a FunctionDef.

    Args:
      c_func: the SWIG-wrapped TF_Function* to convert.

    Returns:
      A `function_pb2.FunctionDef` parsed from the serialized function.
    """
    parsed = function_pb2.FunctionDef()
    with c_api_util.tf_buffer() as scratch:
        c_api.TF_FunctionToFunctionDef(c_func, scratch)
        serialized = c_api.TF_GetBuffer(scratch)
    parsed.ParseFromString(compat.as_bytes(serialized))
    return parsed
Пример #3
0
 def definition(self):
     """Returns the function's `FunctionDef` proto.

     Calls `_create_definition_if_needed()` first. If `self._c_func` is unset
     (falsy), the stored `self._definition` is returned unchanged. Otherwise
     `self._c_func.func` is serialized through the C API and parsed into a
     new `FunctionDef`.
     """
     self._create_definition_if_needed()
     if not self._c_func:
         return self._definition
     with c_api_util.tf_buffer() as scratch:
         c_api.TF_FunctionToFunctionDef(self._c_func.func, scratch)
         result = function_pb2.FunctionDef()
         serialized = c_api.TF_GetBuffer(scratch)
         result.ParseFromString(compat.as_bytes(serialized))
     return result
Пример #4
0
 def definition(self):
   """Returns the function's `FunctionDef` proto.

   Calls `_create_definition_if_needed()` first. If `self._c_func` is unset
   (falsy), the stored `self._definition` is returned unchanged. Otherwise
   `self._c_func` is serialized through the C API, with the call's status
   checked by `errors.raise_exception_on_not_ok_status()`, and parsed into a
   new `FunctionDef`.
   """
   self._create_definition_if_needed()
   if not self._c_func:
     return self._definition
   with c_api_util.tf_buffer() as scratch:
     with errors.raise_exception_on_not_ok_status() as export_status:
       c_api.TF_FunctionToFunctionDef(self._c_func, scratch, export_status)
     result = function_pb2.FunctionDef()
     result.ParseFromString(compat.as_bytes(c_api.TF_GetBuffer(scratch)))
   return result
Пример #5
0
 def definition(self):
   """Returns the function's `FunctionDef` proto.

   Calls `_create_definition_if_needed()` first. If `self._c_func` is unset
   (falsy), the stored `self._definition` is returned unchanged. Otherwise
   `self._c_func.func` is serialized through the C API and parsed into a new
   `FunctionDef`. If `context.executing_eagerly()` is true inside
   `ops.init_scope()`, the C function is also added to the eager context.
   A `_DefinedFunctionDeleter` for its signature name is then stored on
   `self._function_deleter`. This registration happens on every call.
   """
   self._create_definition_if_needed()
   if not self._c_func:
     return self._definition
   with c_api_util.tf_buffer() as scratch:
     c_api.TF_FunctionToFunctionDef(self._c_func.func, scratch)
     result = function_pb2.FunctionDef()
     result.ParseFromString(compat.as_bytes(c_api.TF_GetBuffer(scratch)))
     with ops.init_scope():
       if context.executing_eagerly():
         context.add_function(self._c_func.func)
         # Presumably unregisters the function by name once released --
         # confirm against _DefinedFunctionDeleter.
         self._function_deleter = _DefinedFunctionDeleter(
             result.signature.name)
   return result
Пример #6
0
    def __init__(self, name, graph, operations, inputs, outputs, attrs):
        """Initializes an eager defined function.

        Args:
          name: str, the name for the created function.
          graph: Graph, the graph containing the operations in the function.
          operations: list of Operation; the subset of operations in the graph
            which will be in the function.
          inputs: the tensors in the graph to be used as inputs to the
            function.
          outputs: the tensors in the graph which will be outputs of the
            function.
          attrs: dict mapping names of attributes to their AttrValue values;
            each value is serialized and set on the created function.
        """
        fn = pywrap_tensorflow.TF_GraphToFunction_wrapper(
            graph._c_graph,  # pylint: disable=protected-access
            compat.as_str(name),
            False,
            [o._c_op for o in operations],  # pylint: disable=protected-access
            [t._as_tf_output() for t in inputs],  # pylint: disable=protected-access
            [t._as_tf_output() for t in outputs],  # pylint: disable=protected-access
            [],
            None,
            compat.as_str(""))
        # Hand the raw TF_Function to its owning wrapper immediately instead of
        # at the end of __init__. Setting an attr or exporting the FunctionDef
        # below can raise. The wrapper (assumed to delete the TF_Function when
        # collected -- see c_api_util.ScopedTFFunction) then still releases it
        # rather than leaking it.
        self._c_func = c_api_util.ScopedTFFunction(fn)

        # Distinct loop variable so the `name` argument is not shadowed.
        for attr_name, attr_value in attrs.items():
            serialized = attr_value.SerializeToString()
            # TODO(iga): this creates and deletes a new TF_Status for every attr.
            # It might be worth creating a convenient way to re-use status.
            pywrap_tensorflow.TF_FunctionSetAttrValueProto(
                fn, compat.as_str(attr_name), serialized)

        # TODO(apassos): avoid creating a FunctionDef (especially just to grab
        # the signature); in general it is nicer not to depend on it.
        with c_api_util.tf_buffer() as buffer_:
            pywrap_tensorflow.TF_FunctionToFunctionDef(fn, buffer_)
            proto_data = pywrap_tensorflow.TF_GetBuffer(buffer_)
        function_def = function_pb2.FunctionDef()
        function_def.ParseFromString(compat.as_bytes(proto_data))
        if context.executing_eagerly():
            _register(fn)
        self.definition = function_def
        self.name = function_def.signature.name
        self.signature = function_def.signature
        self.grad_func_name = None
        self.python_grad_func = None
        self._grad_func = None
Пример #7
0
    def __init__(self, name, graph, operations, inputs, outputs):
        """Creates the C function backing an eager defined function.

        Args:
          name: str, the name for the created function.
          graph: Graph, the graph holding the function's operations.
          operations: list of Operation; the subset of `graph`'s operations
            that make up the function body.
          inputs: tensors in `graph` that become the function's arguments.
          outputs: tensors in `graph` that the function returns.
        """
        with errors.raise_exception_on_not_ok_status() as build_status:
            # Gather the C-level handles in the order the call consumes them.
            # pylint: disable=protected-access
            c_graph = graph._c_graph
            c_name = compat.as_str(name)
            c_ops = [op._c_op for op in operations]
            c_inputs = [tensor._as_tf_output() for tensor in inputs]
            c_outputs = [tensor._as_tf_output() for tensor in outputs]
            # pylint: enable=protected-access
            c_function = pywrap_tensorflow.TF_GraphToFunction_wrapper(
                c_graph, c_name, False, c_ops, c_inputs, c_outputs, [], None,
                compat.as_str(""), build_status)

        # TODO(apassos): avoid creating a FunctionDef (especially just to grab
        # the signature); in general it is nicer not to depend on it.
        with c_api_util.tf_buffer() as proto_buffer:
            with errors.raise_exception_on_not_ok_status() as export_status:
                pywrap_tensorflow.TF_FunctionToFunctionDef(
                    c_function, proto_buffer, export_status)
            serialized = pywrap_tensorflow.TF_GetBuffer(proto_buffer)
        fdef = function_pb2.FunctionDef()
        fdef.ParseFromString(compat.as_bytes(serialized))

        if context.in_eager_mode():
            _register(c_function)

        signature = fdef.signature
        self.definition = fdef
        self.name = signature.name
        self.signature = signature
        self.grad_func_name = None
        self.python_grad_func = None
        self._c_func = c_function
        self._grad_func = None