Ejemplo n.º 1
0
def _from_definition(fdef, grad_func=None):
    """Builds a _DefinedFunction directly from an existing FunctionDef proto.

    No Python callable is involved: the returned object gets `None` as its
    function and as its python gradient function, and its C function handle
    is imported from the serialized `fdef`.

    Args:
      fdef: a FunctionDef proto; its signature supplies the function name,
        the input argument names/types and the output argument names.
      grad_func: a _DefinedFunction or None, forwarded unchanged to the
        _DefinedFunction constructor.

    Returns:
      A _DefinedFunction representing `fdef`, with an empty `_extra_inputs`.
    """
    # TODO(iga): This method does major surgery on _DefinedFunction.
    # Make it a named constructor using @classmethod of _DefinedFunction.
    signature = fdef.signature
    input_args = signature.input_arg

    arg_names = [input_arg.name for input_arg in input_args]
    arg_dtypes = tuple(
        dtypes.as_dtype(input_arg.type) for input_arg in input_args)
    name = signature.name
    output_names = [output_arg.name for output_arg in signature.output_arg]

    defined = _DefinedFunction(
        # The Python callable is only needed to create a FunctionDef. We
        # already have the FunctionDef, so there is nothing to pass here (and
        # no such callable is available anyway).
        None,
        arg_names,
        arg_dtypes,
        name,
        grad_func,
        # FunctionDefs do not carry python gradient functions, so one that
        # existed on the original _DefinedFunction is not reflected here.
        None,
        output_names)

    # pylint: disable=protected-access
    imported = c_api.TF_FunctionImportFunctionDef(fdef.SerializeToString())
    defined._c_func = c_api_util.ScopedTFFunction(imported)
    defined._extra_inputs = []
    # pylint: enable=protected-access

    return defined
Ejemplo n.º 2
0
def _create_new_tf_function(func_graph):
    """Turns `func_graph` into a TF_Function and adds it to the outer graph.

    The TF_Function is created from the graph through the C API, converted to
    a Python FunctionDef, rebuilt as a defined function via
    `_function._from_definition`, and that object is added to
    `func_graph._outer_graph`.

    Args:
      func_graph: function._FuncGraph

    Returns:
      The name of the new TF_Function (i.e. `func_graph.name`).
    """
    graph_handle = func_graph._c_graph
    fn_name = compat.as_str(func_graph.name)
    input_handles = [tensor._as_tf_output() for tensor in func_graph.inputs]
    output_handles = [tensor._as_tf_output() for tensor in func_graph.outputs]

    raw_func = c_api.TF_GraphToFunction_wrapper(
        graph_handle,
        fn_name,
        False,  # append_hash_to_fn_name
        None,  # opers
        input_handles,
        output_handles,
        [],
        None,  # opts
        None)  # description
    # Never read again; held in a local only so the scoped wrapper stays alive
    # until this function returns.
    _scoped_owner = c_api_util.ScopedTFFunction(raw_func)

    # TODO(b/109833212): this sucks, we're serializing the TF_Function*,
    # deserializing it into a Python FunctionDef, then reserializing it to create
    # a new TF_Function that we add to the graph.
    rebuilt = _function._from_definition(
        _function.function_def_from_tf_function(raw_func))
    rebuilt._sub_functions = func_graph._functions
    rebuilt.add_to_graph(func_graph._outer_graph)

    return func_graph.name
Ejemplo n.º 3
0
    def __init__(self, name, graph, operations, inputs, outputs, attrs):
        """Initializes an eager defined function.

        Creates a TF_Function from the given subset of `graph`, applies
        `attrs` to it, reads its FunctionDef back to cache the signature and
        name, and registers it when executing eagerly.

        Args:
          name: str, the name for the created function.
          graph: Graph, the graph containing the operations in the function
          operations: list of Operation; the subset of operations in the graph
            which will be in the function
          inputs: the tensors in the graph to be used as inputs to the function
          outputs: the tensors in the graph which will be outputs to the
            function
          attrs: dict mapping names of attributes to their AttrValue values
        """
        fn = pywrap_tensorflow.TF_GraphToFunction_wrapper(
            graph._c_graph,  # pylint: disable=protected-access
            compat.as_str(name),
            False,
            [o._c_op for o in operations],  # pylint: disable=protected-access
            [t._as_tf_output() for t in inputs],  # pylint: disable=protected-access
            [t._as_tf_output() for t in outputs],  # pylint: disable=protected-access
            [],
            None,
            compat.as_str(""))
        # Hand the raw handle to its scoped owner right away, so that a failure
        # in any of the steps below does not leave it without an owner.
        # NOTE(review): assumes ScopedTFFunction releases the handle when it is
        # collected -- confirm against c_api_util.
        self._c_func = c_api_util.ScopedTFFunction(fn)

        # `attr_name` rather than `name`, so the constructor argument is not
        # shadowed by the loop variable.
        for attr_name, attr_value in attrs.items():
            serialized = attr_value.SerializeToString()
            # TODO(iga): this creates and deletes a new TF_Status for every attr.
            # It might be worth creating a convenient way to re-use status.
            pywrap_tensorflow.TF_FunctionSetAttrValueProto(
                fn, compat.as_str(attr_name), serialized)

        # TODO(apassos) avoid creating a FunctionDef (specially to grab the
        # signature, but also in general it's nice not to depend on it.
        with c_api_util.tf_buffer() as buffer_:
            pywrap_tensorflow.TF_FunctionToFunctionDef(fn, buffer_)
            proto_data = pywrap_tensorflow.TF_GetBuffer(buffer_)
        function_def = function_pb2.FunctionDef()
        function_def.ParseFromString(compat.as_bytes(proto_data))
        if context.executing_eagerly():
            _register(fn)
        self.definition = function_def
        # The name is taken from the parsed signature, not from the `name`
        # argument.
        self.name = function_def.signature.name
        self.signature = function_def.signature
        self.grad_func_name = None
        self.python_grad_func = None
        self._grad_func = None
Ejemplo n.º 4
0
    def __init__(self, name, graph, operations, inputs, outputs):
        """Initializes an eager defined function.

        Creates a TF_Function from the given subset of `graph`, reads its
        FunctionDef back to cache the signature and name, and registers it
        when executing eagerly.

        Args:
          name: str, the name for the created function.
          graph: Graph, the graph containing the operations in the function
          operations: list of Operation; the subset of operations in the graph
            which will be in the function
          inputs: the tensors in the graph to be used as inputs to the function
          outputs: the tensors in the graph which will be outputs to the
            function
        """
        fn = pywrap_tensorflow.TF_GraphToFunction_wrapper(
            graph._c_graph,  # pylint: disable=protected-access
            compat.as_str(name),
            False,
            [o._c_op for o in operations],  # pylint: disable=protected-access
            [t._as_tf_output() for t in inputs],  # pylint: disable=protected-access
            [t._as_tf_output() for t in outputs],  # pylint: disable=protected-access
            [],
            None,
            compat.as_str(""))
        # Hand the raw handle to its scoped owner right away, so that a failure
        # while serializing or parsing below does not leave it without an
        # owner.
        # NOTE(review): assumes ScopedTFFunction releases the handle when it is
        # collected -- confirm against c_api_util.
        self._c_func = c_api_util.ScopedTFFunction(fn)

        # TODO(apassos) avoid creating a FunctionDef (specially to grab the
        # signature, but also in general it's nice not to depend on it.
        with c_api_util.tf_buffer() as buffer_:
            pywrap_tensorflow.TF_FunctionToFunctionDef(fn, buffer_)
            proto_data = pywrap_tensorflow.TF_GetBuffer(buffer_)
        function_def = function_pb2.FunctionDef()
        function_def.ParseFromString(compat.as_bytes(proto_data))
        if context.executing_eagerly():
            _register(fn)
        self.definition = function_def
        # The name is taken from the parsed signature, not from the `name`
        # argument.
        self.name = function_def.signature.name
        self.signature = function_def.signature
        self.grad_func_name = None
        self.python_grad_func = None
        self._grad_func = None
Ejemplo n.º 5
0
def _create_new_tf_function(func_graph):
    """Turns `func_graph` into a TF_Function and copies it into the default graph.

    Note that `func_graph.name` is modified in place: an underscore is
    appended before the TF_Function is created, and the modified name is what
    gets returned.

    Args:
      func_graph: function._FuncGraph

    Returns:
      The name of the new TF_Function.
    """
    func_graph.name = "%s_" % func_graph.name

    graph_handle = func_graph._c_graph
    fn_name = func_graph.name
    input_handles = [tensor._as_tf_output() for tensor in func_graph.inputs]
    output_handles = [tensor._as_tf_output() for tensor in func_graph.outputs]

    raw_func = c_api.TF_GraphToFunction_wrapper(
        graph_handle,
        fn_name,
        False,  # append_hash_to_fn_name
        None,  # opers
        input_handles,
        output_handles,
        [],
        None,  # opts
        None)  # description
    scoped_func = c_api_util.ScopedTFFunction(raw_func)

    default_graph_handle = ops.get_default_graph()._c_graph
    c_api.TF_GraphCopyFunction(default_graph_handle, scoped_func.func, None)
    return func_graph.name
Ejemplo n.º 6
0
  def _create_definition_if_needed_impl(self):
    """Builds this function's definition from `self._func`, at most once.

    This is not what you want, see _create_definition_if_needed.

    Returns immediately when `self._definition` or `self._c_func` is already
    set. Otherwise `self._func` is handed to `func_graph_from_py_func` to get
    a temporary graph, and one of two paths is taken depending on whether
    that graph has a `_c_graph`:

      * no `_c_graph`: a FunctionDef is built in Python, the attrs are copied
        onto it, it is hashed, and its signature name/description are set;
      * `_c_graph` present: a TF_Function is created through the C API and
        stored, wrapped in a `ScopedTFFunction`, in `self._c_func`.

    Side effects: sets `_extra_inputs`, `_sub_functions`, `_op_def`,
    `_stateful_ops`, `_func_name` (only if it was unset), and either
    `_definition` + `_hash_str` or `_c_func`.
    """
    if self._definition is not None or self._c_func is not None:
      return

    # Copy variable collections (by reference) from the parent graph such that
    # name based variable sharing (e.g. via tf.make_template) works between the
    # func graph and parent graph.
    variable_keys = []
    variable_keys.extend(ops.GraphKeys._VARIABLE_COLLECTIONS)  # pylint: disable=protected-access
    variable_keys.append(vs._VARSTORE_KEY)  # pylint: disable=protected-access

    # The "parent" graph is whatever graph is the default at call time.
    parent_graph = ops.get_default_graph()
    collections_ref = {
        key: parent_graph.get_collection_ref(key) for key in variable_keys}

    temp_graph = func_graph_from_py_func(
        self._func,
        self._arg_names,
        self._arg_types,
        self._func_name,
        self._capture_by_value,
        self._caller_device,
        collections_ref=collections_ref,
        allowlisted_stateful_ops=self._allowlisted_stateful_ops,
        capture_resource_var_by_value=self._capture_resource_var_by_value)

    self._extra_inputs = temp_graph.extra_inputs
    # pylint: disable=protected-access
    self._sub_functions = temp_graph._functions
    # pylint: enable=protected-access

    # Extra kwargs are treated as attrs on the function def.
    # An explicit name is used verbatim; otherwise the base name is derived
    # from the Python callable, suffixed with the gradient function's name
    # when there is one.
    if self._func_name:
      base_func_name = self._func_name
    else:
      base_func_name = function_utils.get_func_name(self._func)
      if self._grad_func:
        base_func_name += ("_%s" % self._grad_func.name)
    kwargs_attr = _parse_kwargs_as_attrs(base_func_name, **self._extra_kwargs)

    if not temp_graph._c_graph:  # pylint: disable=protected-access
      # Build the FunctionDef
      self._definition = graph_to_function_def.graph_to_function_def(
          temp_graph,
          temp_graph.get_operations(),
          temp_graph.inputs,
          temp_graph.outputs,
          out_names=self._out_names)

      for k in kwargs_attr:
        self._definition.attr[k].CopyFrom(kwargs_attr[k])

      # Hash the definition and its dependencies.
      self._hash_str = self._create_hash_str(
          self._definition.signature.input_arg,
          self._definition.signature.output_arg, self._definition.node_def)

      # Finally, we decide the function name to use.  If not specified,
      # make up something which is almost certainly unique (but deterministic).
      if not self._func_name:
        self._func_name = "_".join([base_func_name, self._hash_str])
      self._definition.signature.name = self._func_name
      # The Python docstring, when present, becomes the signature description.
      if self._func.__doc__:
        self._definition.signature.description = self._func.__doc__

      self._op_def = self._definition.signature
    else:  # C API is enabled
      output_names = ([compat.as_bytes(x) for x in self._out_names]
                      if self._out_names else [])
      # An empty or missing docstring is passed on as None.
      description = self._func.__doc__ or None
      # pylint: disable=protected-access
      c_func = c_api.TF_GraphToFunction_wrapper(
          temp_graph._c_graph,
          base_func_name,
          self._func_name is None,  # append_hash_to_fn_name
          None,  # opers
          [t._as_tf_output() for t in temp_graph.inputs],
          [t._as_tf_output() for t in temp_graph.outputs],
          output_names,
          [], # control_outputs
          [], # control_output_names
          None,  # opts
          description)
      self._c_func = c_api_util.ScopedTFFunction(c_func)
      # pylint: enable=protected-access
      self._set_c_attrs(kwargs_attr)

      # Set cached fields: _op_def and _func_name (if not already set)
      # NOTE(review): `self.definition` (no underscore) is presumably a
      # property that derives the FunctionDef from `_c_func` -- confirm.
      self._op_def = self.definition.signature
      if self._func_name:
        assert self._func_name == self._op_def.name
      else:
        self._func_name = compat.as_str(self._op_def.name)

    # Record (name, type) for every stateful op in the temporary graph; this
    # runs on both paths above.
    self._stateful_ops = [(op.name, op.type)
                          for op in temp_graph.get_operations()
                          if op._is_stateful]  # pylint: disable=protected-access
Ejemplo n.º 7
0
    def _create_definition_if_needed_impl(self):
        """Builds this function's definition from `self._func`, at most once.

        This is not what you want, see _create_definition_if_needed.

        Returns immediately when `self._definition` or `self._c_func` is
        already set. Otherwise `self._func` is handed to
        `func_graph_from_py_func` to get a temporary graph, and one of two
        paths is taken depending on whether that graph has a `_c_graph`:

          * no `_c_graph`: a FunctionDef is built in Python, the attrs are
            copied onto it, it is hashed, and its signature name/description
            are set;
          * `_c_graph` present: a TF_Function is created through the C API and
            stored, wrapped in a `ScopedTFFunction`, in `self._c_func`.

        Side effects: sets `_extra_inputs`, `_sub_functions`, `_op_def`,
        `_stateful_ops`, `_func_name` (only if it was unset), and either
        `_definition` + `_hash_str` or `_c_func`.
        """
        if self._definition is not None or self._c_func is not None:
            return

        temp_graph = func_graph_from_py_func(
            self._func,
            self._arg_names,
            self._arg_types,
            self._func_name,
            self._capture_by_value,
            self._caller_device,
            whitelisted_stateful_ops=self._whitelisted_stateful_ops)

        self._extra_inputs = temp_graph.extra_inputs
        # pylint: disable=protected-access
        self._sub_functions = temp_graph._functions
        # pylint: enable=protected-access

        # Extra kwargs are treated as attrs on the function def.
        # An explicit name is used verbatim; otherwise the base name is
        # derived from the Python callable, suffixed with the gradient
        # function's name when there is one.
        if self._func_name:
            base_func_name = self._func_name
        else:
            base_func_name = function_utils.get_func_name(self._func)
            if self._grad_func:
                base_func_name += ("_%s" % self._grad_func.name)
        kwargs_attr = _parse_kwargs_as_attrs(base_func_name,
                                             **self._extra_kwargs)

        if not temp_graph._c_graph:  # pylint: disable=protected-access
            # Build the FunctionDef
            self._definition = graph_to_function_def.graph_to_function_def(
                temp_graph,
                temp_graph.get_operations(),
                temp_graph.inputs,
                temp_graph.outputs,
                out_names=self._out_names)

            for k in kwargs_attr:
                self._definition.attr[k].CopyFrom(kwargs_attr[k])

            # Hash the definition and its dependencies.
            self._hash_str = self._create_hash_str(
                self._definition.signature.input_arg,
                self._definition.signature.output_arg,
                self._definition.node_def)

            # Finally, we decide the function name to use.  If not specified,
            # make up something which is almost certainly unique (but deterministic).
            if not self._func_name:
                self._func_name = "_".join([base_func_name, self._hash_str])
            self._definition.signature.name = self._func_name
            # The Python docstring, when present, becomes the signature
            # description.
            if self._func.__doc__:
                self._definition.signature.description = self._func.__doc__

            self._op_def = self._definition.signature
        else:  # C API is enabled
            output_names = ([compat.as_bytes(x) for x in self._out_names]
                            if self._out_names else [])
            # An empty or missing docstring is passed on as None.
            description = self._func.__doc__ or None
            # pylint: disable=protected-access
            c_func = c_api.TF_GraphToFunction_wrapper(
                temp_graph._c_graph,
                base_func_name,
                self._func_name is None,  # append_hash_to_fn_name
                None,  # opers
                [t._as_tf_output() for t in temp_graph.inputs],
                [t._as_tf_output() for t in temp_graph.outputs],
                output_names,
                None,  # opts
                description)
            self._c_func = c_api_util.ScopedTFFunction(c_func)
            # pylint: enable=protected-access
            self._set_c_attrs(kwargs_attr)

            # Set cached fields: _op_def and _func_name (if not already set)
            # NOTE(review): `self.definition` (no underscore) is presumably a
            # property that derives the FunctionDef from `_c_func` -- confirm.
            self._op_def = self.definition.signature
            if self._func_name:
                assert self._func_name == self._op_def.name
            else:
                self._func_name = compat.as_str(self._op_def.name)

        # Record (name, type) for every op in the temporary graph whose op_def
        # is marked stateful; this runs on both paths above.
        self._stateful_ops = [(op.name, op.type)
                              for op in temp_graph.get_operations()
                              if op.op_def.is_stateful]
Ejemplo n.º 8
0
    def _create_definition_if_needed_impl(self):
        """Builds this function's definition from `self._func`, at most once.

        This is not what you want, see _create_definition_if_needed.

        Returns immediately when `self._definition` or `self._c_func` is
        already set. Otherwise `self._func` is called on one placeholder per
        entry of `self._args` inside a fresh `_FuncGraph`, its outputs are
        normalized to a list of tensors belonging to that graph, and one of
        two paths is taken depending on whether the graph has a `_c_graph`:

          * no `_c_graph`: a FunctionDef is built in Python, the attrs are
            copied onto it, it is hashed, and its signature name/description
            are set;
          * `_c_graph` present: a TF_Function is created through the C API and
            stored, wrapped in a `ScopedTFFunction`, in `self._c_func`.

        Side effects: sets `_extra_inputs`, `_sub_functions`, `_op_def`,
        `_func_name` (only if it was unset), and either `_definition` +
        `_hash_str` or `_c_func`.

        Raises:
          ValueError: if `self._func` returns a list/tuple containing None.
        """
        if self._definition is not None or self._c_func is not None:
            return

        # Create the func_def object.
        temp_graph = _FuncGraph(capture_by_value=self._capture_by_value)
        with temp_graph.as_default():
            # List of placeholders for the function_def.
            inputs = []
            for (argname, argtype) in self._args:
                argholder = array_ops.placeholder(argtype, name=argname)
                inputs.append(argholder)
            # Call func and gather the output tensors.
            with vs.variable_scope("", custom_getter=temp_graph.getvar):
                outputs = self._func(*inputs)

            # There is no way of distinguishing between a function not returning
            # anything and a function returning None in Python.
            # We need to allow the former and ideally want to forbid the latter as
            # it is most likely user error.
            # TODO(iga): Consider adding a @NoOutput decorator on top of @Defun to
            # allow users to explicitly mark the function as not returning anything.
            # For now, we allow a single None return and interpret it as a function
            # with no output.
            if outputs is None:
                outputs = []
            else:
                # If func only returned one value, make it a tuple.
                if not isinstance(outputs, (list, tuple)):
                    outputs = (outputs, )
                # Generator form: no throwaway list, and stops at the first
                # None.
                if any(out is None for out in outputs):
                    raise ValueError("Function can not return None.")
            # Ensures each output is a Tensor in the function graph.
            outputs = [ops.convert_to_tensor(t) for t in outputs]
            outputs = [
                temp_graph.capture(t) if t.graph is not temp_graph else t
                for t in outputs
            ]
        self._extra_inputs = temp_graph.extra_inputs
        # The graph's extra args are appended after the declared placeholders.
        inputs.extend(temp_graph.extra_args)
        # pylint: disable=protected-access
        self._sub_functions = temp_graph._functions
        # pylint: enable=protected-access

        # Extra kwargs are treated as attrs on the function def.
        # An explicit name is used verbatim; otherwise the base name is
        # derived from the Python callable, suffixed with the gradient
        # function's name when there is one.
        if self._func_name:
            base_func_name = self._func_name
        else:
            base_func_name = _get_func_name(self._func)
            if self._grad_func:
                base_func_name += ("_%s" % self._grad_func.name)
        kwargs_attr = _parse_kwargs_as_attrs(base_func_name,
                                             **self._extra_kwargs)

        if not temp_graph._c_graph:  # pylint: disable=protected-access
            # Build the FunctionDef
            self._definition = graph_to_function_def.graph_to_function_def(
                temp_graph,
                temp_graph.get_operations(),
                inputs,
                outputs,
                out_names=self._out_names)

            for k in kwargs_attr:
                self._definition.attr[k].CopyFrom(kwargs_attr[k])

            # Hash the definition and its dependencies.
            self._hash_str = self._create_hash_str(
                self._definition.signature.input_arg,
                self._definition.signature.output_arg,
                self._definition.node_def)

            # Finally, we decide the function name to use.  If not specified,
            # make up something which is almost certainly unique (but deterministic).
            if not self._func_name:
                self._func_name = "_".join([base_func_name, self._hash_str])
            self._definition.signature.name = self._func_name
            # The Python docstring, when present, becomes the signature
            # description.
            if self._func.__doc__:
                self._definition.signature.description = self._func.__doc__

            self._op_def = self._definition.signature
        else:  # C API is enabled
            output_names = ([compat.as_bytes(x) for x in self._out_names]
                            if self._out_names else [])
            # An empty or missing docstring is passed on as None.
            description = self._func.__doc__ or None
            # pylint: disable=protected-access
            c_func = c_api.TF_GraphToFunction_wrapper(
                temp_graph._c_graph,
                base_func_name,
                self._func_name is None,  # append_hash_to_fn_name
                None,  # opers
                [t._as_tf_output() for t in inputs],
                [t._as_tf_output() for t in outputs],
                output_names,
                None,  # opts
                description)
            self._c_func = c_api_util.ScopedTFFunction(c_func)
            # pylint: enable=protected-access
            self._set_c_attrs(kwargs_attr)

            # Set cached fields: _op_def and _func_name (if not already set)
            # NOTE(review): `self.definition` (no underscore) is presumably a
            # property that derives the FunctionDef from `_c_func` -- confirm.
            self._op_def = self.definition.signature
            if self._func_name:
                assert self._func_name == self._op_def.name
            else:
                self._func_name = compat.as_str(self._op_def.name)