Ejemplo n.º 1
0
 def testRaisesWithNonCallableObject(self):
     """Passing a non-callable (None) to get_func_name must raise ValueError."""
     self.assertRaises(ValueError, function_utils.get_func_name, None)
Ejemplo n.º 2
0
  def _create_definition_if_needed_impl(self):
    """This is not what you want, see _create_definition_if_needed.

    Traces `self._func` into a temporary function graph and turns it into a
    function definition, caching the results on `self`. Returns immediately
    if `self._definition` or `self._c_func` is already set.

    On return, `_extra_inputs`, `_sub_functions`, `_op_def`, `_func_name`
    (derived if it was unset) and `_stateful_ops` are populated, plus either
    `_definition` and `_hash_str` (graph without a C graph) or `_c_func`
    (C API path).
    """
    if self._definition is not None or self._c_func is not None:
      return

    # Trace the Python callable into a fresh function graph.
    temp_graph = func_graph_from_py_func(
        self._func, self._arg_names, self._arg_types, self._func_name,
        self._capture_by_value, self._caller_device)

    self._extra_inputs = temp_graph.extra_inputs
    # pylint: disable=protected-access
    self._sub_functions = temp_graph._functions
    # pylint: enable=protected-access

    # Extra kwargs are treated as attrs on the function def.
    # base_func_name is the explicit name if one was given; otherwise it is
    # derived from the callable, suffixed with the gradient function's name.
    if self._func_name:
      base_func_name = self._func_name
    else:
      base_func_name = function_utils.get_func_name(self._func)
      if self._grad_func:
        base_func_name += ("_%s" % self._grad_func.name)
    kwargs_attr = _parse_kwargs_as_attrs(base_func_name, **self._extra_kwargs)

    if not temp_graph._c_graph:  # pylint: disable=protected-access
      # Build the FunctionDef
      self._definition = graph_to_function_def.graph_to_function_def(
          temp_graph,
          temp_graph.get_operations(),
          temp_graph.inputs,
          temp_graph.outputs,
          out_names=self._out_names)

      for k in kwargs_attr:
        self._definition.attr[k].CopyFrom(kwargs_attr[k])

      # Hash the definition and its dependencies.
      self._hash_str = self._create_hash_str(
          self._definition.signature.input_arg,
          self._definition.signature.output_arg, self._definition.node_def)

      # Finally, we decide the function name to use.  If not specified,
      # make up something which is almost certainly unique (but deterministic).
      if not self._func_name:
        self._func_name = "_".join([base_func_name, self._hash_str])
      self._definition.signature.name = self._func_name
      if self._func.__doc__:
        self._definition.signature.description = self._func.__doc__

      self._op_def = self._definition.signature
    else:  # C API is enabled
      output_names = ([compat.as_bytes(x) for x in self._out_names]
                      if self._out_names else [])
      description = self._func.__doc__ or None
      # Here the C API computes the final name: it appends a hash to
      # base_func_name only when no explicit name was given.
      # NOTE(review): this tests `is None`, while the code above tests
      # truthiness; an empty-string name would get no hash here. Confirm ""
      # is never passed as a function name.
      # pylint: disable=protected-access
      c_func = c_api.TF_GraphToFunction_wrapper(
          temp_graph._c_graph,
          base_func_name,
          self._func_name is None,  # append_hash_to_fn_name
          None,  # opers
          [t._as_tf_output() for t in temp_graph.inputs],
          [t._as_tf_output() for t in temp_graph.outputs],
          output_names,
          None,  # opts
          description)
      self._c_func = c_api_util.ScopedTFFunction(c_func)
      # pylint: enable=protected-access
      self._set_c_attrs(kwargs_attr)

      # Set cached fields: _op_def and _func_name (if not already set)
      # `self.definition` is a property (defined outside this view); it is
      # read only after `_c_func` has been assigned above.
      self._op_def = self.definition.signature
      if self._func_name:
        assert self._func_name == self._op_def.name
      else:
        self._func_name = compat.as_str(self._op_def.name)

    # Record (name, type) of every stateful op in the traced graph.
    self._stateful_ops = [(op.name, op.type)
                          for op in temp_graph.get_operations()
                          if op.op_def.is_stateful]
Ejemplo n.º 3
0
 def testWithFunctoolsPartial(self):
     """A functools.partial is named with a '<...functools.partial...>' string."""
     wrapped = functools.partial(silly_example_function)
     derived_name = function_utils.get_func_name(wrapped)
     self.assertRegex(derived_name, '<.*functools.partial.*>')
Ejemplo n.º 4
0
 def testWithLambda(self):
     """An anonymous function is reported under the name '<lambda>'."""
     name = function_utils.get_func_name(lambda x: x)
     self.assertEqual('<lambda>', name)
    def __init__(self,
                 func,
                 transformation_name,
                 dataset=None,
                 input_classes=None,
                 input_shapes=None,
                 input_types=None,
                 input_structure=None,
                 add_to_graph=True,
                 use_legacy_function=False,
                 defun_kwargs=None):
        """Creates a new `StructuredFunctionWrapper` for the given function.

        Args:
          func: A function from a (nested) structure to another (nested)
            structure.
          transformation_name: Human-readable name of the transformation in
            which this function is being instantiated, for error messages.
          dataset: (Optional.) A `tf.data.Dataset`. If given, the structure of
            this dataset will be assumed as the structure for `func`
            arguments; otherwise `input_classes`, `input_shapes`, and
            `input_types` must be defined.
          input_classes: (Optional.) A (nested) structure of `type`. If given,
            this argument defines the Python types for `func` arguments.
          input_shapes: (Optional.) A (nested) structure of `tf.TensorShape`.
            If given, this argument defines the shapes and structure for
            `func` arguments.
          input_types: (Optional.) A (nested) structure of `tf.DType`. If
            given, this argument defines the element types and structure for
            `func` arguments.
          input_structure: (Optional.) A `Structure` object. If given, this
            argument defines the element types and structure for `func`
            arguments.
          add_to_graph: (Optional.) If `True`, the function will be added to
            the default graph, if it exists.
          use_legacy_function: (Optional.) A boolean that determines whether
            the function be created using
            `tensorflow.python.eager.function.defun` (default behavior) or
            `tensorflow.python.framework.function.Defun` (legacy behavior).
          defun_kwargs: (Optional.) A dictionary mapping string argument names
            to values. If supplied, will be passed to `function` as keyword
            arguments. The dictionary is copied before internal entries such
            as `func_name` are added, so the caller's object is not modified.

        Raises:
          ValueError: If an invalid combination of `dataset`,
            `input_structure`, `input_classes`, `input_shapes`, and
            `input_types` is passed.
          TypeError: If, when traced, `func` returns a value that
            `structure.type_spec_from_value` rejects.
        """
        # pylint: disable=protected-access
        if input_structure is None:
            if dataset is None:
                if input_classes is None or input_shapes is None or input_types is None:
                    raise ValueError(
                        "Either `dataset`, `input_structure` or all of "
                        "`input_classes`, `input_shapes`, and `input_types` "
                        "must be specified.")
                self._input_structure = structure.convert_legacy_structure(
                    input_types, input_shapes, input_classes)
            else:
                if not (input_classes is None and input_shapes is None
                        and input_types is None):
                    raise ValueError(
                        "Either `dataset`, `input_structure` or all of "
                        "`input_classes`, `input_shapes`, and `input_types` "
                        "must be specified.")
                self._input_structure = dataset.element_spec
        else:
            if not (dataset is None and input_classes is None
                    and input_shapes is None and input_types is None):
                raise ValueError(
                    "Either `dataset`, `input_structure`, or all of "
                    "`input_classes`, `input_shapes`, and `input_types` "
                    "must be specified.")
            self._input_structure = input_structure

        self._func = func

        # Work on a private copy: `func_name` (and `_tf_data_function`) are
        # written into this dict below and must not leak into a dictionary
        # owned by the caller, which may be reused for other wrappers.
        defun_kwargs = {} if defun_kwargs is None else dict(defun_kwargs)

        # Replace dots and drop the last two characters (presumably a trailing
        # "()" as in "Dataset.map()"); short names yield an empty prefix.
        readable_transformation_name = transformation_name.replace(
            ".", "_")[:-2] if len(transformation_name) > 2 else ""

        func_name = "_".join(
            [readable_transformation_name,
             function_utils.get_func_name(func)])
        # Sanitize function name to remove symbols that interfere with graph
        # construction.
        for symbol in ["<", ">", "\\", "'", " "]:
            func_name = func_name.replace(symbol, "")

        ag_ctx = autograph_ctx.control_status_ctx()

        def wrapper_helper(*args):
            """Wrapper for passing nested structures to and from tf.data functions."""
            nested_args = structure.from_compatible_tensor_list(
                self._input_structure, args)
            if not _should_unpack(nested_args):
                nested_args = (nested_args, )
            ret = autograph.tf_convert(self._func, ag_ctx)(*nested_args)
            if _should_pack(ret):
                ret = tuple(ret)

            # Record the output structure as a side effect of tracing; the
            # tracers below read self._output_structure afterwards.
            try:
                self._output_structure = structure.type_spec_from_value(ret)
            except (ValueError, TypeError):
                six.reraise(
                    TypeError,
                    TypeError(
                        f"Unsupported return value from function passed to "
                        f"{transformation_name}: {ret}."),
                    sys.exc_info()[2])
            return ret

        def trace_legacy_function(defun_kwargs):
            @function.Defun(*structure.get_flat_tensor_types(
                self._input_structure), **defun_kwargs)
            def wrapped_fn(*args):
                ret = wrapper_helper(*args)
                return structure.to_tensor_list(self._output_structure, ret)

            return lambda: wrapped_fn

        def trace_py_function(defun_kwargs):
            # First we trace the function to infer the output structure.
            @eager_function.defun_with_attributes(
                input_signature=structure.get_flat_tensor_specs(
                    self._input_structure),
                autograph=False,
                attributes=defun_kwargs)
            def unused(*args):  # pylint: disable=missing-docstring,unused-variable
                ret = wrapper_helper(*args)
                ret = structure.to_tensor_list(self._output_structure, ret)
                return [ops.convert_to_tensor(t) for t in ret]

            _ = unused.get_concrete_function()

            def py_function_wrapper(*args):
                nested_args = structure.from_compatible_tensor_list(
                    self._input_structure, args)
                if not _should_unpack(nested_args):
                    nested_args = (nested_args, )
                ret = self._func(*nested_args)
                if _should_pack(ret):
                    ret = tuple(ret)
                ret = structure.to_tensor_list(self._output_structure, ret)
                return [ops.convert_to_tensor(t) for t in ret]

            # Next we trace the function wrapped in `eager_py_func` to force eager
            # execution.
            @eager_function.defun_with_attributes(
                input_signature=structure.get_flat_tensor_specs(
                    self._input_structure),
                autograph=False,
                attributes=defun_kwargs)
            def wrapped_fn(*args):  # pylint: disable=missing-docstring
                return script_ops.eager_py_func(
                    py_function_wrapper, args,
                    structure.get_flat_tensor_types(self._output_structure))

            return wrapped_fn.get_concrete_function

        def trace_tf_function(defun_kwargs):
            # Note: wrapper_helper will apply autograph based on context.
            @eager_function.defun_with_attributes(
                input_signature=structure.get_flat_tensor_specs(
                    self._input_structure),
                autograph=False,
                attributes=defun_kwargs)
            def wrapped_fn(*args):  # pylint: disable=missing-docstring
                ret = wrapper_helper(*args)
                ret = structure.to_tensor_list(self._output_structure, ret)
                return [ops.convert_to_tensor(t) for t in ret]

            return wrapped_fn.get_concrete_function

        if use_legacy_function:
            # Legacy functions get an ops.uid() suffix, presumably to keep
            # their names unique across wrappers.
            defun_kwargs.update(
                {"func_name": func_name + "_" + str(ops.uid())})
            fn_factory = trace_legacy_function(defun_kwargs)
        else:
            defun_kwargs.update({"func_name": func_name})
            defun_kwargs.update({"_tf_data_function": True})
            if dataset_ops.DEBUG_MODE:
                fn_factory = trace_py_function(defun_kwargs)
            else:
                if def_function.functions_run_eagerly():
                    warnings.warn(
                        "Even though the `tf.config.experimental_run_functions_eagerly` "
                        "option is set, this option does not apply to tf.data functions. "
                        "To force eager execution of tf.data functions, please use "
                        "`tf.data.experimental.enable_debug_mode()`.")
                fn_factory = trace_tf_function(defun_kwargs)

        self._function = fn_factory()
        # There is no graph to add in eager mode.
        add_to_graph &= not context.executing_eagerly()
        # There are some lifetime issues when a legacy function is not added to a
        # out-living graph. It's already deprecated so de-prioritizing the fix.
        add_to_graph |= use_legacy_function
        if add_to_graph:
            self._function.add_to_graph(ops.get_default_graph())

        if not use_legacy_function:
            outer_graph_seed = ops.get_default_graph().seed
            if outer_graph_seed and self._function.graph.seed == outer_graph_seed:
                if self._function.graph._seed_used:
                    warnings.warn(
                        "Seed %s from outer graph might be getting used by function %s, "
                        "if the random op has not been provided any seed. Explicitly set "
                        "the seed in the function if this is not the intended behavior."
                        % (outer_graph_seed, func_name),
                        stacklevel=4)
Ejemplo n.º 6
0
 def testWithCallableClass(self):
     """A callable-class instance gets a name mentioning SillyCallableClass."""
     instance = SillyCallableClass()
     derived_name = function_utils.get_func_name(instance)
     self.assertRegex(derived_name, '<.*SillyCallableClass.*>')
 def testRaisesWithNonCallableObject(self):
   """get_func_name rejects None, which is not callable, with ValueError."""
   non_callable = None
   with self.assertRaises(ValueError):
     function_utils.get_func_name(non_callable)
Ejemplo n.º 8
0
  def _create_definition_if_needed_impl(self):
    """This is not what you want, see _create_definition_if_needed.

    Traces `self._func` into a temporary function graph that shares the
    parent graph's variable collections, then turns it into a function
    definition cached on `self`. Returns immediately if `self._definition`
    or `self._c_func` is already set.

    On return, `_extra_inputs`, `_sub_functions`, `_op_def`, `_func_name`
    (derived if it was unset) and `_stateful_ops` are populated, plus either
    `_definition` and `_hash_str` (graph without a C graph) or `_c_func`
    (C API path).
    """
    if self._definition is not None or self._c_func is not None:
      return

    # Copy variable collections (by reference) from the parent graph such that
    # name based variable sharing (e.g. via tf.make_template) works between the
    # func graph and parent graph.
    variable_keys = []
    variable_keys.extend(ops.GraphKeys._VARIABLE_COLLECTIONS)  # pylint: disable=protected-access
    variable_keys.append(vs._VARSTORE_KEY)  # pylint: disable=protected-access

    parent_graph = ops.get_default_graph()
    collections_ref = {
        key: parent_graph.get_collection_ref(key) for key in variable_keys}

    # Trace the Python callable into a fresh function graph.
    temp_graph = func_graph_from_py_func(
        self._func,
        self._arg_names,
        self._arg_types,
        self._func_name,
        self._capture_by_value,
        self._caller_device,
        collections_ref=collections_ref,
        allowlisted_stateful_ops=self._allowlisted_stateful_ops,
        capture_resource_var_by_value=self._capture_resource_var_by_value)

    self._extra_inputs = temp_graph.extra_inputs
    # pylint: disable=protected-access
    self._sub_functions = temp_graph._functions
    # pylint: enable=protected-access

    # Extra kwargs are treated as attrs on the function def.
    # base_func_name is the explicit name if one was given; otherwise it is
    # derived from the callable, suffixed with the gradient function's name.
    if self._func_name:
      base_func_name = self._func_name
    else:
      base_func_name = function_utils.get_func_name(self._func)
      if self._grad_func:
        base_func_name += ("_%s" % self._grad_func.name)
    kwargs_attr = _parse_kwargs_as_attrs(base_func_name, **self._extra_kwargs)

    if not temp_graph._c_graph:  # pylint: disable=protected-access
      # Build the FunctionDef
      self._definition = graph_to_function_def.graph_to_function_def(
          temp_graph,
          temp_graph.get_operations(),
          temp_graph.inputs,
          temp_graph.outputs,
          out_names=self._out_names)

      for k in kwargs_attr:
        self._definition.attr[k].CopyFrom(kwargs_attr[k])

      # Hash the definition and its dependencies.
      self._hash_str = self._create_hash_str(
          self._definition.signature.input_arg,
          self._definition.signature.output_arg, self._definition.node_def)

      # Finally, we decide the function name to use.  If not specified,
      # make up something which is almost certainly unique (but deterministic).
      if not self._func_name:
        self._func_name = "_".join([base_func_name, self._hash_str])
      self._definition.signature.name = self._func_name
      if self._func.__doc__:
        self._definition.signature.description = self._func.__doc__

      self._op_def = self._definition.signature
    else:  # C API is enabled
      output_names = ([compat.as_bytes(x) for x in self._out_names]
                      if self._out_names else [])
      description = self._func.__doc__ or None
      # Here the C API computes the final name: it appends a hash to
      # base_func_name only when no explicit name was given.
      # NOTE(review): this tests `is None`, while the code above tests
      # truthiness; an empty-string name would get no hash here. Confirm ""
      # is never passed as a function name.
      # pylint: disable=protected-access
      c_func = c_api.TF_GraphToFunction_wrapper(
          temp_graph._c_graph,
          base_func_name,
          self._func_name is None,  # append_hash_to_fn_name
          None,  # opers
          [t._as_tf_output() for t in temp_graph.inputs],
          [t._as_tf_output() for t in temp_graph.outputs],
          output_names,
          [], # control_outputs
          [], # control_output_names
          None,  # opts
          description)
      self._c_func = c_api_util.ScopedTFFunction(c_func)
      # pylint: enable=protected-access
      self._set_c_attrs(kwargs_attr)

      # Set cached fields: _op_def and _func_name (if not already set)
      # `self.definition` is a property (defined outside this view); it is
      # read only after `_c_func` has been assigned above.
      self._op_def = self.definition.signature
      if self._func_name:
        assert self._func_name == self._op_def.name
      else:
        self._func_name = compat.as_str(self._op_def.name)

    # Record (name, type) of every stateful op in the traced graph.
    self._stateful_ops = [(op.name, op.type)
                          for op in temp_graph.get_operations()
                          if op._is_stateful]  # pylint: disable=protected-access
 def testWithFunctoolsPartial(self):
   """A functools.partial is named with a '<...functools.partial...>' string."""
   partial = functools.partial(silly_example_function)
   # assertRegexpMatches is a deprecated alias that was removed in Python
   # 3.12; assertRegex is the supported spelling with the same semantics.
   self.assertRegex(
       function_utils.get_func_name(partial),
       '<.*functools.partial.*>')
 def testWithLambda(self):
   """Lambdas are reported under the generic '<lambda>' name."""
   expected = '<lambda>'
   self.assertEqual(expected, function_utils.get_func_name(lambda x: x))
 def testWithCallableClass(self):
   """A callable-class instance gets a name mentioning SillyCallableClass."""
   callable_instance = SillyCallableClass()
   # assertRegexpMatches is a deprecated alias that was removed in Python
   # 3.12; assertRegex is the supported spelling with the same semantics.
   self.assertRegex(
       function_utils.get_func_name(callable_instance),
       '<.*SillyCallableClass.*>')
 def testWithClassMethod(self):
   """A bound method is named '<ClassName>.<methodName>'."""
   bound_method = self.testWithClassMethod
   self.assertEqual('GetFuncNameTest.testWithClassMethod',
                    function_utils.get_func_name(bound_method))
 def testWithSimpleFunction(self):
   """A plain module-level function is reported under its own name."""
   expected = 'silly_example_function'
   actual = function_utils.get_func_name(silly_example_function)
   self.assertEqual(expected, actual)
Ejemplo n.º 14
0
def func_graph_from_py_func(func, arg_names, arg_types, name=None,
                            capture_by_value=False, device=None,
                            colocation_stack=None, container=None,
                            collections_ref=None, arg_shapes=None):
  """Returns a _FuncGraph generated from `func`.

  Args:
    func: A Python callable which constructs a TF function body. The arguments
      must correspond to `arg_types`. Returns a value or list/tuple of values.
      No returned value can be None.
    arg_names: A sequence of strings for the function argument names.
    arg_types: A sequence of the function's argument types.
    name: The function name. If None, the name is derived from `func`.
    capture_by_value: boolean. If True, captured values will be copied into the
      function body.
    device: device name or function.
    colocation_stack: A colocation stack (list) the _FuncGraph should use.
    container: A container name the _FuncGraph should start with.
    collections_ref: A reference to a collections dict the _FuncGraph should
      use internally.
    arg_shapes: A sequence of the function's argument shapes.

  Returns:
    A _FuncGraph whose `inputs` hold one placeholder per argument and whose
    `outputs` hold the tensors returned by `func` (captured into the graph
    when they were created elsewhere).

  Raises:
    ValueError: if `func` returns a list/tuple containing None. A bare `None`
      return is accepted and produces a graph with no outputs.
  """
  if not name:
    name = function_utils.get_func_name(func)
  func_graph = _FuncGraph(name, capture_by_value)

  with func_graph.as_default(), ops.device(device):
    # pylint: disable=protected-access
    if collections_ref is not None:
      func_graph._collections = collections_ref
    if container is not None:
      func_graph._container = container
    if colocation_stack is not None:
      func_graph._colocation_stack = colocation_stack
    # pylint: enable=protected-access

    if arg_shapes is None:
      arg_shapes = [None] * len(arg_types)

    # Create placeholders for the function arguments.
    for (argname, argtype, argshape) in zip(arg_names, arg_types, arg_shapes):
      argholder = array_ops.placeholder(argtype, shape=argshape, name=argname)
      func_graph.inputs.append(argholder)
    # Call func and gather the output tensors.
    with vs.variable_scope("", custom_getter=func_graph.getvar):
      outputs = func(*func_graph.inputs)

    # There is no way of distinguishing between a function not returning
    # anything and a function returning None in Python.
    # We need to allow the former and ideally want to forbid the latter as
    # it is most likely user error.
    # TODO(iga): Consider adding a @NoOutput decorator on top of @Defun to
    # allow users to explicitly mark the function as not returning anything.
    # For now, we allow a single None return and interpret it as a function
    # with no output.
    if outputs is None:
      outputs = []
    else:
      # If func only returned one value, make it a tuple.
      if not isinstance(outputs, (list, tuple)):
        outputs = (outputs,)
      # Generator form: stops at the first None without building a list.
      if any(t is None for t in outputs):
        raise ValueError(f"Function {name} can not return None.")
    # Ensures each output is a Tensor in the function graph.
    outputs = [ops.convert_to_tensor(t) for t in outputs]
    outputs = [func_graph.capture(t) if t.graph is not func_graph else t
               for t in outputs]
    func_graph.outputs = outputs
  return func_graph
Ejemplo n.º 15
0
 def testWithSimpleFunction(self):
     """get_func_name returns the plain name of a module-level function."""
     result = function_utils.get_func_name(silly_example_function)
     self.assertEqual('silly_example_function', result)
Ejemplo n.º 16
0
    def _create_definition_if_needed_impl(self):
        """This is not what you want, see _create_definition_if_needed.

        Traces `self._func` into a temporary function graph and turns it into
        a function definition cached on `self`. Returns immediately if
        `self._definition` or `self._c_func` is already set.

        On return, `_extra_inputs`, `_sub_functions`, `_op_def`, `_func_name`
        (derived if it was unset) and `_stateful_ops` are populated, plus
        either `_definition` and `_hash_str` (graph without a C graph) or
        `_c_func` (C API path).
        """
        if self._definition is not None or self._c_func is not None:
            return

        # Trace the Python callable into a fresh function graph.
        temp_graph = func_graph_from_py_func(
            self._func,
            self._arg_names,
            self._arg_types,
            self._func_name,
            self._capture_by_value,
            self._caller_device,
            whitelisted_stateful_ops=self._whitelisted_stateful_ops)

        self._extra_inputs = temp_graph.extra_inputs
        # pylint: disable=protected-access
        self._sub_functions = temp_graph._functions
        # pylint: enable=protected-access

        # Extra kwargs are treated as attrs on the function def.
        # base_func_name is the explicit name if one was given; otherwise it
        # is derived from the callable, suffixed with the gradient function's
        # name.
        if self._func_name:
            base_func_name = self._func_name
        else:
            base_func_name = function_utils.get_func_name(self._func)
            if self._grad_func:
                base_func_name += ("_%s" % self._grad_func.name)
        kwargs_attr = _parse_kwargs_as_attrs(base_func_name,
                                             **self._extra_kwargs)

        if not temp_graph._c_graph:  # pylint: disable=protected-access
            # Build the FunctionDef
            self._definition = graph_to_function_def.graph_to_function_def(
                temp_graph,
                temp_graph.get_operations(),
                temp_graph.inputs,
                temp_graph.outputs,
                out_names=self._out_names)

            for k in kwargs_attr:
                self._definition.attr[k].CopyFrom(kwargs_attr[k])

            # Hash the definition and its dependencies.
            self._hash_str = self._create_hash_str(
                self._definition.signature.input_arg,
                self._definition.signature.output_arg,
                self._definition.node_def)

            # Finally, we decide the function name to use.  If not specified,
            # make up something which is almost certainly unique (but deterministic).
            if not self._func_name:
                self._func_name = "_".join([base_func_name, self._hash_str])
            self._definition.signature.name = self._func_name
            if self._func.__doc__:
                self._definition.signature.description = self._func.__doc__

            self._op_def = self._definition.signature
        else:  # C API is enabled
            output_names = ([compat.as_bytes(x) for x in self._out_names]
                            if self._out_names else [])
            description = self._func.__doc__ or None
            # Here the C API computes the final name: it appends a hash to
            # base_func_name only when no explicit name was given.
            # NOTE(review): this tests `is None`, while the code above tests
            # truthiness; an empty-string name would get no hash here.
            # Confirm "" is never passed as a function name.
            # pylint: disable=protected-access
            c_func = c_api.TF_GraphToFunction_wrapper(
                temp_graph._c_graph,
                base_func_name,
                self._func_name is None,  # append_hash_to_fn_name
                None,  # opers
                [t._as_tf_output() for t in temp_graph.inputs],
                [t._as_tf_output() for t in temp_graph.outputs],
                output_names,
                None,  # opts
                description)
            self._c_func = c_api_util.ScopedTFFunction(c_func)
            # pylint: enable=protected-access
            self._set_c_attrs(kwargs_attr)

            # Set cached fields: _op_def and _func_name (if not already set)
            # `self.definition` is a property (defined outside this view); it
            # is read only after `_c_func` has been assigned above.
            self._op_def = self.definition.signature
            if self._func_name:
                assert self._func_name == self._op_def.name
            else:
                self._func_name = compat.as_str(self._op_def.name)

        # Record (name, type) of every stateful op in the traced graph.
        self._stateful_ops = [(op.name, op.type)
                              for op in temp_graph.get_operations()
                              if op.op_def.is_stateful]
Ejemplo n.º 17
0
 def testWithClassMethod(self):
     """Bound methods are named as 'ClassName.methodName'."""
     method = self.testWithClassMethod
     name = function_utils.get_func_name(method)
     self.assertEqual('GetFuncNameTest.testWithClassMethod', name)
Ejemplo n.º 18
0
def func_graph_from_py_func(func,
                            arg_names,
                            arg_types,
                            name=None,
                            capture_by_value=False,
                            device=None,
                            colocation_stack=None,
                            container=None,
                            collections_ref=None,
                            arg_shapes=None,
                            allowlisted_stateful_ops=None,
                            capture_resource_var_by_value=True):
  """Traces `func` into a new `_FuncGraph` and returns that graph.

  Args:
    func: Python callable that builds the TF function body. It receives one
      placeholder per entry of `arg_types` and returns a value or a
      list/tuple of values, none of which may be None.
    arg_names: Sequence of strings naming the function arguments.
    arg_types: Sequence of the function's argument types.
    name: Function name; when falsy it is derived from `func`.
    capture_by_value: If True, captured values are copied into the body.
    device: Device name or device function applied while tracing.
    colocation_stack: Colocation stack (list) for the new graph to use.
    container: Container name the new graph starts with.
    collections_ref: Collections dict the new graph should use internally.
    arg_shapes: Sequence of argument shapes; unknown shapes when omitted.
    allowlisted_stateful_ops: Set of ops that, if stateful, are ignored and
      re-created.
    capture_resource_var_by_value: Boolean (default True). When False, a
      captured resource variable yields its handle instead of its value.

  Returns:
    The traced `_FuncGraph`, with `inputs` and `outputs` filled in.

  Raises:
    ValueError: if `func` returns a list/tuple that contains None. A bare
      `None` return is accepted and yields a graph with no outputs.
  """
  graph_name = name if name else function_utils.get_func_name(func)
  func_graph = _FuncGraph(graph_name, capture_by_value,
                          allowlisted_stateful_ops,
                          capture_resource_var_by_value)

  with func_graph.as_default(), ops.device(device):
    # Install caller-provided internal state on the graph, skipping any that
    # were not supplied.
    overrides = (("_collections", collections_ref),
                 ("_container", container),
                 ("_colocation_stack", colocation_stack))
    for attr_name, attr_value in overrides:
      if attr_value is not None:
        setattr(func_graph, attr_name, attr_value)

    shapes = [None] * len(arg_types) if arg_shapes is None else arg_shapes

    # One placeholder per declared argument, in declaration order.
    for arg_name, arg_type, arg_shape in zip(arg_names, arg_types, shapes):
      func_graph.inputs.append(
          array_ops.placeholder(arg_type, shape=arg_shape, name=arg_name))

    # Run the Python body; variable lookups go through the graph's getter.
    with vs.variable_scope("", custom_getter=func_graph.getvar):
      returned = func(*func_graph.inputs)

    # Python cannot tell "returned nothing" from "returned None", so a bare
    # None means "no outputs". A None inside a returned sequence is most
    # likely a user error and is rejected.
    if returned is None:
      returned = []
    elif not isinstance(returned, (list, tuple)):
      returned = (returned,)
    if any(value is None for value in returned):
      raise ValueError(f"Function {graph_name} can not return None.")

    # Convert everything to tensors, then pull foreign tensors into the graph.
    tensors = [ops.convert_to_tensor(value) for value in returned]
    func_graph.outputs = [
        t if t.graph is func_graph else func_graph.capture(t) for t in tensors
    ]
  return func_graph
Ejemplo n.º 19
0
  def _create_definition_if_needed_impl(self):
    """This is not what you want, see _create_definition_if_needed.

    Traces `self._func` into a temporary function graph that shares the
    parent graph's variable collections, then turns it into a function
    definition cached on `self`. Returns immediately if `self._definition`
    or `self._c_func` is already set.

    On return, `_extra_inputs`, `_sub_functions`, `_op_def`, `_func_name`
    (derived if it was unset) and `_stateful_ops` are populated, plus either
    `_definition` and `_hash_str` (graph without a C graph) or `_c_func`
    (C API path). Missing variable collections are also created in the
    parent graph as a side effect.
    """
    if self._definition is not None or self._c_func is not None:
      return

    # Copy variable collections (by reference) from the parent graph such that
    # name based variable sharing (e.g. via tf.make_template) works between the
    # func graph and parent graph.
    variable_keys = []
    variable_keys.extend(ops.GraphKeys._VARIABLE_COLLECTIONS)  # pylint: disable=protected-access
    variable_keys.append(vs._VARSTORE_KEY)  # pylint: disable=protected-access

    # For keys the parent lacks, create one empty list and store the same
    # object in both dicts so the two graphs share it.
    collections_ref = {}
    parent_collections_ref = ops.get_default_graph()._collections  # pylint: disable=protected-access
    for key in variable_keys:
      if key not in parent_collections_ref:
        parent_collections_ref[key] = collections_ref[key] = []
      else:
        collections_ref[key] = parent_collections_ref[key]

    # Trace the Python callable into a fresh function graph.
    temp_graph = func_graph_from_py_func(
        self._func,
        self._arg_names,
        self._arg_types,
        self._func_name,
        self._capture_by_value,
        self._caller_device,
        collections_ref=collections_ref,
        whitelisted_stateful_ops=self._whitelisted_stateful_ops,
        capture_resource_var_by_value=self._capture_resource_var_by_value)

    self._extra_inputs = temp_graph.extra_inputs
    # pylint: disable=protected-access
    self._sub_functions = temp_graph._functions
    # pylint: enable=protected-access

    # Extra kwargs are treated as attrs on the function def.
    # base_func_name is the explicit name if one was given; otherwise it is
    # derived from the callable, suffixed with the gradient function's name.
    if self._func_name:
      base_func_name = self._func_name
    else:
      base_func_name = function_utils.get_func_name(self._func)
      if self._grad_func:
        base_func_name += ("_%s" % self._grad_func.name)
    kwargs_attr = _parse_kwargs_as_attrs(base_func_name, **self._extra_kwargs)

    if not temp_graph._c_graph:  # pylint: disable=protected-access
      # Build the FunctionDef
      self._definition = graph_to_function_def.graph_to_function_def(
          temp_graph,
          temp_graph.get_operations(),
          temp_graph.inputs,
          temp_graph.outputs,
          out_names=self._out_names)

      for k in kwargs_attr:
        self._definition.attr[k].CopyFrom(kwargs_attr[k])

      # Hash the definition and its dependencies.
      self._hash_str = self._create_hash_str(
          self._definition.signature.input_arg,
          self._definition.signature.output_arg, self._definition.node_def)

      # Finally, we decide the function name to use.  If not specified,
      # make up something which is almost certainly unique (but deterministic).
      if not self._func_name:
        self._func_name = "_".join([base_func_name, self._hash_str])
      self._definition.signature.name = self._func_name
      if self._func.__doc__:
        self._definition.signature.description = self._func.__doc__

      self._op_def = self._definition.signature
    else:  # C API is enabled
      output_names = ([compat.as_bytes(x) for x in self._out_names]
                      if self._out_names else [])
      description = self._func.__doc__ or None
      # Here the C API computes the final name: it appends a hash to
      # base_func_name only when no explicit name was given.
      # NOTE(review): this tests `is None`, while the code above tests
      # truthiness; an empty-string name would get no hash here. Confirm ""
      # is never passed as a function name.
      # pylint: disable=protected-access
      c_func = c_api.TF_GraphToFunction_wrapper(
          temp_graph._c_graph,
          base_func_name,
          self._func_name is None,  # append_hash_to_fn_name
          None,  # opers
          [t._as_tf_output() for t in temp_graph.inputs],
          [t._as_tf_output() for t in temp_graph.outputs],
          output_names,
          [], # control_outputs
          [], # control_output_names
          None,  # opts
          description)
      self._c_func = c_api_util.ScopedTFFunction(c_func)
      # pylint: enable=protected-access
      self._set_c_attrs(kwargs_attr)

      # Set cached fields: _op_def and _func_name (if not already set)
      # `self.definition` is a property (defined outside this view); it is
      # read only after `_c_func` has been assigned above.
      self._op_def = self.definition.signature
      if self._func_name:
        assert self._func_name == self._op_def.name
      else:
        self._func_name = compat.as_str(self._op_def.name)

    # Record (name, type) of every stateful op in the traced graph.
    self._stateful_ops = [(op.name, op.type)
                          for op in temp_graph.get_operations()
                          if op.op_def.is_stateful]