コード例 #1
0
ファイル: cond_v2_impl.py プロジェクト: zyx910/tensorflow
def _create_new_tf_function(func_graph):
    """Builds a TF_Function from `func_graph` and adds it to its outer graph.

    The graph is converted to a C TF_Function, serialized into a Python
    FunctionDef, re-defined from that FunctionDef, and finally added to
    `func_graph._outer_graph` along with `func_graph._functions` as its
    sub-functions.

    Args:
      func_graph: function._FuncGraph whose `inputs` and `outputs` become the
        arguments and results of the new function.

    Returns:
      The name of the new TF_Function (`func_graph.name`).
    """
    # pylint: disable=protected-access
    fn_name = compat.as_str(func_graph.name)
    arg_outputs = [t._as_tf_output() for t in func_graph.inputs]
    result_outputs = [t._as_tf_output() for t in func_graph.outputs]
    raw_func = c_api.TF_GraphToFunction_wrapper(
        func_graph._c_graph,
        fn_name,
        False,  # append_hash_to_fn_name
        None,  # opers
        arg_outputs,
        result_outputs,
        [],  # output_names
        None,  # opts
        None)  # description
    # Hold the scoped wrapper in a named local for the rest of this function
    # so that (presumably) the wrapped TF_Function is not released before it
    # has been serialized below. Do not remove this "unused" variable.
    scoped_func = c_api_util.ScopedTFFunction(raw_func)  # pylint: disable=unused-variable

    # TODO(b/109833212): this sucks, we're serializing the TF_Function*,
    # deserializing it into a Python FunctionDef, then reserializing it to create
    # a new TF_Function that we add to the graph.
    fdef = _function.function_def_from_tf_function(raw_func)
    redefined = _function._from_definition(fdef)
    redefined._sub_functions = func_graph._functions
    redefined.add_to_graph(func_graph._outer_graph)
    # pylint: enable=protected-access

    return func_graph.name
コード例 #2
0
def make_function_def(name, graph, operations, inputs, outputs):
    """Builds a TF_Function from part of `graph` plus its FunctionDef proto.

    Args:
      name: the function name.
      graph: the graph from which to build the function.
      operations: the operations that make up the function body.
      inputs: tensors to be used as function arguments.
      outputs: tensors to be returned from the function.

    Returns:
      A `(fdef, fn)` tuple, where `fdef` is the FunctionDef protocol buffer
      parsed from the new function and `fn` is the TF_Function handle
      returned by `TF_GraphToFunction_wrapper`.

    Raises:
      Whatever `errors.raise_exception_on_not_ok_status` raises when either
      C API call reports a non-OK status.
    """
    # pylint: disable=protected-access
    with errors.raise_exception_on_not_ok_status() as status:
        body_ops = [op._c_op for op in operations]
        arg_outputs = [t._as_tf_output() for t in inputs]
        ret_outputs = [t._as_tf_output() for t in outputs]
        c_func = pywrap_tensorflow.TF_GraphToFunction_wrapper(
            graph._c_graph,
            compat.as_str(name),
            False,  # append_hash_to_fn_name
            body_ops,
            arg_outputs,
            ret_outputs,
            [],  # output_names
            None,  # opts
            compat.as_str(""),  # description
            status)
    # pylint: enable=protected-access

    # TODO(apassos) avoid creating a FunctionDef (specially to grab the signature,
    # but also in general it's nice not to depend on it.
    with c_api_util.tf_buffer() as buf:
        with errors.raise_exception_on_not_ok_status() as status:
            pywrap_tensorflow.TF_FunctionToFunctionDef(c_func, buf, status)
        serialized = pywrap_tensorflow.TF_GetBuffer(buf)

    fdef = function_pb2.FunctionDef()
    fdef.ParseFromString(compat.as_bytes(serialized))
    return fdef, c_func
コード例 #3
0
ファイル: function.py プロジェクト: sgcm520/tensorflow2
    def __init__(self, name, graph, operations, inputs, outputs, attrs):
        """Initializes an eager defined function.

        Converts the given subset of `graph` into a TF_Function and sets
        `attrs` on it. It then caches the resulting FunctionDef, its name and
        its signature on `self`. When executing eagerly, the function is also
        passed to `_register`.

        Args:
          name: str, the name for the created function.
          graph: Graph, the graph containing the operations in the function
          operations: list of Operation; the subset of operations in the graph
            which will be in the function
          inputs: the tensors in the graph to be used as inputs to the function
          outputs: the tensors in the graph which will be outputs to the function
          attrs: dict mapping names of attributes to their AttrValue values
        """
        # pylint: disable=protected-access
        fn = pywrap_tensorflow.TF_GraphToFunction_wrapper(
            graph._c_graph,
            compat.as_str(name),
            False,  # append_hash_to_fn_name
            [o._c_op for o in operations],
            [t._as_tf_output() for t in inputs],
            [t._as_tf_output() for t in outputs],
            [],  # output_names
            None,  # opts
            compat.as_str(""))  # description
        # pylint: enable=protected-access

        # Hand `fn` to its scoped owner immediately, not at the end of
        # __init__. If setting an attr or converting to a FunctionDef below
        # raises, the wrapper can still release the TF_Function instead of
        # leaking it. (Assumes ScopedTFFunction frees the function when it is
        # garbage-collected.)
        scoped_fn = c_api_util.ScopedTFFunction(fn)

        # Use `attr_name` rather than `name` so the constructor argument is
        # not shadowed/rebound by the loop.
        for attr_name, attr_value in attrs.items():
            serialized = attr_value.SerializeToString()
            # TODO(iga): this creates and deletes a new TF_Status for every attr.
            # It might be worth creating a convenient way to re-use status.
            pywrap_tensorflow.TF_FunctionSetAttrValueProto(
                fn, compat.as_str(attr_name), serialized)

        # TODO(apassos) avoid creating a FunctionDef (specially to grab the
        # signature, but also in general it's nice not to depend on it.
        with c_api_util.tf_buffer() as buffer_:
            pywrap_tensorflow.TF_FunctionToFunctionDef(fn, buffer_)
            proto_data = pywrap_tensorflow.TF_GetBuffer(buffer_)
        function_def = function_pb2.FunctionDef()
        function_def.ParseFromString(compat.as_bytes(proto_data))
        if context.executing_eagerly():
            _register(fn)
        self.definition = function_def
        self.name = function_def.signature.name
        self.signature = function_def.signature
        self.grad_func_name = None
        self.python_grad_func = None
        self._c_func = scoped_fn
        self._grad_func = None
コード例 #4
0
ファイル: function.py プロジェクト: tejamukka/tensorflow-1
    def __init__(self, name, graph, operations, inputs, outputs):
        """Initializes an eager defined function.

        Converts the selected `operations` of `graph` into a TF_Function and
        reads back its FunctionDef. When in eager mode the function is passed
        to `_register`. The definition, name and signature are cached on
        `self`.

        Args:
          name: str, the name for the created function.
          graph: Graph, the graph containing the operations in the function
          operations: list of Operation; the subset of operations in the graph
            which will be in the function
          inputs: the tensors in the graph to be used as inputs to the function
          outputs: the tensors in the graph which will be outputs to the function
        """
        # pylint: disable=protected-access
        with errors.raise_exception_on_not_ok_status() as status:
            body_ops = [o._c_op for o in operations]
            arg_outputs = [t._as_tf_output() for t in inputs]
            ret_outputs = [t._as_tf_output() for t in outputs]
            c_func = pywrap_tensorflow.TF_GraphToFunction_wrapper(
                graph._c_graph,
                compat.as_str(name),
                False,  # append_hash_to_fn_name
                body_ops,
                arg_outputs,
                ret_outputs,
                [],  # output_names
                None,  # opts
                compat.as_str(""),  # description
                status)
        # pylint: enable=protected-access

        # TODO(apassos) avoid creating a FunctionDef (specially to grab the
        # signature, but also in general it's nice not to depend on it.
        with c_api_util.tf_buffer() as buf:
            with errors.raise_exception_on_not_ok_status() as status:
                pywrap_tensorflow.TF_FunctionToFunctionDef(c_func, buf, status)
            proto_bytes = pywrap_tensorflow.TF_GetBuffer(buf)
        definition = function_pb2.FunctionDef()
        definition.ParseFromString(compat.as_bytes(proto_bytes))

        if context.in_eager_mode():
            _register(c_func)

        signature = definition.signature
        self.definition = definition
        self.name = signature.name
        self.signature = signature
        self.grad_func_name = None
        self.python_grad_func = None
        # The raw TF_Function handle is stored as-is (not wrapped).
        self._c_func = c_func
        self._grad_func = None
コード例 #5
0
def _create_new_tf_function(func_graph):
    """Converts `func_graph` to a TF_Function and copies it into the default graph.

    Side effect: `func_graph.name` is renamed in place by appending "_" before
    the conversion. The renamed value is used as the function name and is
    returned.

    Args:
      func_graph: function._FuncGraph

    Returns:
      The name of the new TF_Function.
    """
    func_graph.name = "%s_" % func_graph.name
    fn_name = func_graph.name
    # pylint: disable=protected-access
    arg_outputs = [t._as_tf_output() for t in func_graph.inputs]
    ret_outputs = [t._as_tf_output() for t in func_graph.outputs]
    raw_func = c_api.TF_GraphToFunction_wrapper(
        func_graph._c_graph,
        fn_name,
        False,  # append_hash_to_fn_name
        None,  # opers
        arg_outputs,
        ret_outputs,
        [],  # output_names
        None,  # opts
        None)  # description
    # Keep the raw handle and its scoped wrapper under distinct names; the
    # wrapper's `.func` is what gets copied into the default graph.
    scoped = c_api_util.ScopedTFFunction(raw_func)
    default_c_graph = ops.get_default_graph()._c_graph
    c_api.TF_GraphCopyFunction(default_c_graph, scoped.func, None)
    # pylint: enable=protected-access
    return func_graph.name
コード例 #6
0
ファイル: function.py プロジェクト: ychen404/TensorFlowPlus
  def _create_definition_if_needed_impl(self):
    """This is not what you want, see _create_definition_if_needed.

    Traces `self._func` into a fresh `_FuncGraph` and builds its Python
    FunctionDef in `self._definition`. If the temporary graph also has a C
    graph (`temp_graph._c_graph` is truthy), a C TF_Function is built as well
    and stored in `self._c_func`. Does nothing when `self._definition` is
    already set.

    Side effects on `self`:
      - sets `_extra_inputs`, `_sub_functions`, `_definition` and `_hash_str`;
      - sets `_func_name` when it was not given;
      - on the C-API path, sets `_c_func` and calls `_set_c_attrs`.

    Raises:
      ValueError: if `self._func` returns None, or a list/tuple containing
        None.
    """
    if self._definition is not None:
      return

    # Create the func_def object.
    temp_graph = _FuncGraph(capture_by_value=self._capture_by_value)
    with temp_graph.as_default():
      # List of placeholders for the function_def.
      inputs = []
      # One placeholder per (argname, argtype) pair in self._args; argtype is
      # used as the placeholder dtype.
      for (argname, argtype) in self._args:
        argholder = array_ops.placeholder(argtype, name=argname)
        inputs.append(argholder)
      # Call func and gather the output tensors. Variable lookups made while
      # tracing are routed through temp_graph.getvar via the custom getter.
      with vs.variable_scope("", custom_getter=temp_graph.getvar):
        outputs = self._func(*inputs)
      # If func only returned one value, make it a tuple.
      if not isinstance(outputs, (list, tuple)):
        outputs = (outputs,)
      # A bare `None` return was wrapped into (None,) above, so it is
      # rejected here as well.
      if any([_ is None for _ in outputs]):
        raise ValueError("Function can not return None.")
      # Ensures each output is a Tensor.
      outputs = [ops.convert_to_tensor(_) for _ in outputs]
    self._extra_inputs = temp_graph.extra_inputs
    # temp_graph.extra_args become trailing inputs of the function
    # (presumably placeholders for values captured during tracing).
    inputs.extend(temp_graph.extra_args)
    # pylint: disable=protected-access
    self._sub_functions = temp_graph._functions
    # pylint: enable=protected-access

    # Build the FunctionDef
    self._definition = graph_to_function_def.graph_to_function_def(
        temp_graph,
        temp_graph.get_operations(),
        inputs,
        outputs,
        out_names=self._out_names)

    # Extra kwargs are treated as attrs on the function def.
    sig_pre_func_name = self._func_name or _get_func_name(self._func)
    kwargs_attr = _parse_kwargs_as_attrs(sig_pre_func_name,
                                         **self._extra_kwargs)
    for k in kwargs_attr:
      self._definition.attr[k].CopyFrom(kwargs_attr[k])

    # Hash the definition and its dependencies.
    self._hash_str = self._create_hash_str(
        self._definition.signature.input_arg,
        self._definition.signature.output_arg, self._definition.node_def)

    # Finally, we decide the function name to use.  If not specified,
    # make up something which is almost certainly unique (but deterministic).
    if not self._func_name:
      self._func_name = "_".join([_get_func_name(self._func), self._hash_str])
    self._definition.signature.name = self._func_name
    if self._func.__doc__:
      self._definition.signature.description = self._func.__doc__

    # pylint: disable=protected-access
    if temp_graph._c_graph:
      # Also build the C TF_Function from the same inputs/outputs.
      output_names = ([compat.as_bytes(x) for x in self._out_names]
                      if self._out_names else [])
      description = self._func.__doc__ or None
      with errors.raise_exception_on_not_ok_status() as status:
        # _func_name already carries the hash suffix (added above when no
        # name was given), so the C API is told not to append another one.
        self._c_func = c_api.TF_GraphToFunction_wrapper(
            temp_graph._c_graph,
            self._func_name,
            False,  # append_hash_to_fn_name
            None,  # opers
            [t._as_tf_output() for t in inputs],
            [t._as_tf_output() for t in outputs],
            output_names,
            None,  # opts
            description,
            status)
      self._set_c_attrs(kwargs_attr)
コード例 #7
0
ファイル: function.py プロジェクト: zuozi2810/tensorflow
    def _create_definition_if_needed_impl(self):
        """Builds the function definition; see _create_definition_if_needed.

        This is not what you want to call directly. It traces `self._func`
        into a temporary function graph with `func_graph_from_py_func`. With
        the C API (the temporary graph has a `_c_graph`), it builds a C
        TF_Function in `self._c_func`; otherwise it builds a Python
        FunctionDef in `self._definition`. Either way it also caches
        `_op_def`, `_func_name` and `_stateful_ops`. Returns immediately if
        `_definition` or `_c_func` is already set.
        """
        already_built = self._definition is not None or self._c_func is not None
        if already_built:
            return

        traced_graph = func_graph_from_py_func(
            self._func,
            self._arg_names,
            self._arg_types,
            self._func_name,
            self._capture_by_value,
            self._caller_device,
            whitelisted_stateful_ops=self._whitelisted_stateful_ops,
            capture_resource_var_by_value=self._capture_resource_var_by_value)

        self._extra_inputs = traced_graph.extra_inputs
        self._sub_functions = traced_graph._functions  # pylint: disable=protected-access

        # Extra kwargs are treated as attrs on the function def.
        base_func_name = self._func_name
        if not base_func_name:
            base_func_name = function_utils.get_func_name(self._func)
            if self._grad_func:
                base_func_name += ("_%s" % self._grad_func.name)
        kwargs_attr = _parse_kwargs_as_attrs(base_func_name,
                                             **self._extra_kwargs)

        if traced_graph._c_graph:  # pylint: disable=protected-access
            # C API path. A hash is appended to the name only when
            # `self._func_name is None`.
            out_names = ([compat.as_bytes(n) for n in self._out_names]
                         if self._out_names else [])
            doc = self._func.__doc__ or None
            # pylint: disable=protected-access
            raw_func = c_api.TF_GraphToFunction_wrapper(
                traced_graph._c_graph,
                base_func_name,
                self._func_name is None,  # append_hash_to_fn_name
                None,  # opers
                [t._as_tf_output() for t in traced_graph.inputs],
                [t._as_tf_output() for t in traced_graph.outputs],
                out_names,
                [],  # control_outputs
                [],  # control_output_names
                None,  # opts
                doc)
            # pylint: enable=protected-access
            self._c_func = c_api_util.ScopedTFFunction(raw_func)
            self._set_c_attrs(kwargs_attr)

            # Cache _op_def, and _func_name when it was not given.
            self._op_def = self.definition.signature
            if self._func_name:
                assert self._func_name == self._op_def.name
            else:
                self._func_name = compat.as_str(self._op_def.name)
        else:
            # Python path: build the FunctionDef.
            fdef = graph_to_function_def.graph_to_function_def(
                traced_graph,
                traced_graph.get_operations(),
                traced_graph.inputs,
                traced_graph.outputs,
                out_names=self._out_names)
            self._definition = fdef

            for attr_name, attr_value in kwargs_attr.items():
                fdef.attr[attr_name].CopyFrom(attr_value)

            # Hash the signature inputs/outputs and node defs.
            sig = fdef.signature
            self._hash_str = self._create_hash_str(
                sig.input_arg, sig.output_arg, fdef.node_def)

            # If no name was specified, make up one that is almost certainly
            # unique (but deterministic).
            if not self._func_name:
                self._func_name = "_".join([base_func_name, self._hash_str])
            sig.name = self._func_name
            if self._func.__doc__:
                sig.description = self._func.__doc__

            self._op_def = sig

        self._stateful_ops = [
            (op.name, op.type)
            for op in traced_graph.get_operations()
            if op.op_def.is_stateful
        ]
コード例 #8
0
ファイル: function.py プロジェクト: Odegaard11/publicBicycle
  def _create_definition_if_needed_impl(self):
    """This is not what you want, see _create_definition_if_needed.

    Traces `self._func` into a fresh `_FuncGraph`. When the temporary graph
    has no C graph, builds a Python FunctionDef in `self._definition`;
    otherwise builds a C TF_Function in `self._c_func`. Both paths cache
    `self._op_def` and fill in `self._func_name` if it was not given. Does
    nothing when either `_definition` or `_c_func` is already set.

    A function that returns `None` is treated as having no outputs.

    Raises:
      ValueError: if `self._func` returns a list/tuple containing None.
    """
    if self._definition is not None or self._c_func is not None:
      return

    # Create the func_def object.
    temp_graph = _FuncGraph(capture_by_value=self._capture_by_value)
    with temp_graph.as_default():
      # List of placeholders for the function_def.
      inputs = []
      # One placeholder per (argname, argtype) pair in self._args; argtype is
      # used as the placeholder dtype.
      for (argname, argtype) in self._args:
        argholder = array_ops.placeholder(argtype, name=argname)
        inputs.append(argholder)
      # Call func and gather the output tensors. Variable lookups made while
      # tracing are routed through temp_graph.getvar via the custom getter.
      with vs.variable_scope("", custom_getter=temp_graph.getvar):
        outputs = self._func(*inputs)

      # There is no way of distinguishing between a function not returning
      # anything and a function returning None in Python.
      # We need to allow the former and ideally want to forbid the latter as
      # it is most likely user error.
      # TODO(iga): Consider adding a @NoOutput decorator on top of @Defun to
      # allow users to explicitly mark the function as not returning anything.
      # For now, we allow a single None return and interpret it as a function
      # with no output.
      if outputs is None:
        outputs = []
      else:
        # If func only returned one value, make it a tuple.
        if not isinstance(outputs, (list, tuple)):
          outputs = (outputs,)
        if any([_ is None for _ in outputs]):
          raise ValueError("Function can not return None.")
      # Ensures each output is a Tensor.
      outputs = [ops.convert_to_tensor(_) for _ in outputs]
    self._extra_inputs = temp_graph.extra_inputs
    # temp_graph.extra_args become trailing inputs of the function
    # (presumably placeholders for values captured during tracing).
    inputs.extend(temp_graph.extra_args)
    # pylint: disable=protected-access
    self._sub_functions = temp_graph._functions
    # pylint: enable=protected-access

    # Extra kwargs are treated as attrs on the function def.
    base_func_name = self._func_name or _get_func_name(self._func)
    kwargs_attr = _parse_kwargs_as_attrs(base_func_name,
                                         **self._extra_kwargs)

    if not temp_graph._c_graph:  # pylint: disable=protected-access
      # Build the FunctionDef
      self._definition = graph_to_function_def.graph_to_function_def(
          temp_graph,
          temp_graph.get_operations(),
          inputs,
          outputs,
          out_names=self._out_names)

      for k in kwargs_attr:
        self._definition.attr[k].CopyFrom(kwargs_attr[k])

      # Hash the definition and its dependencies.
      self._hash_str = self._create_hash_str(
          self._definition.signature.input_arg,
          self._definition.signature.output_arg, self._definition.node_def)

      # Finally, we decide the function name to use.  If not specified,
      # make up something which is almost certainly unique (but deterministic).
      # NOTE(review): this branch tests falsiness (`not self._func_name`), while
      # the C-API branch below tests `self._func_name is None`. An empty-string
      # name therefore gets a hash suffix here but not there. Confirm intended.
      if not self._func_name:
        self._func_name = "_".join([base_func_name, self._hash_str])
      self._definition.signature.name = self._func_name
      if self._func.__doc__:
        self._definition.signature.description = self._func.__doc__

      self._op_def = self._definition.signature
    else:  # C API is enabled
      output_names = ([compat.as_bytes(x) for x in self._out_names]
                      if self._out_names else [])
      description = self._func.__doc__ or None
      # pylint: disable=protected-access
      with errors.raise_exception_on_not_ok_status() as status:
        # The raw TF_Function handle is stored directly in _c_func (not
        # wrapped here); presumably the owning class is responsible for
        # releasing it — verify.
        self._c_func = c_api.TF_GraphToFunction_wrapper(
            temp_graph._c_graph,
            base_func_name,
            self._func_name is None,  # append_hash_to_fn_name
            None,  # opers
            [t._as_tf_output() for t in inputs],
            [t._as_tf_output() for t in outputs],
            output_names,
            None,  # opts
            description,
            status)
      # pylint: enable=protected-access
      self._set_c_attrs(kwargs_attr)

      # Set cached fields: _op_def and _func_name (if not already set)
      # `self.definition` is read after _c_func is set; presumably it is a
      # property derived from _c_func — confirm against the class.
      self._op_def = self.definition.signature
      if self._func_name:
        # NOTE(review): invariant check only; asserts are stripped under -O.
        assert self._func_name == self._op_def.name
      else:
        self._func_name = compat.as_str(self._op_def.name)