Example 1
  def testRunFunctionsEagerly(self):
    try:
      original_setting = def_function.functions_run_eagerly()
      def_function.run_functions_eagerly(True)
      x = constant_op.constant(1.)
      with forwardprop.ForwardAccumulator(x, 2.) as acc:
        y = x * 3.
      self.assertAllClose(6., acc.jvp(y))
    finally:
      def_function.run_functions_eagerly(original_setting)
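
The behavior this test checks can be reproduced with the public TF 2.x API. Below is a minimal hedged sketch; it assumes only `tf.config.run_functions_eagerly` and `tf.autodiff.ForwardAccumulator`, the public counterparts of the internal modules used above:

import tensorflow as tf

# Force tf.functions to execute eagerly, as the test does via the internal
# def_function module.
tf.config.run_functions_eagerly(True)
try:
    x = tf.constant(1.)
    # Seed the accumulator with tangent 2.0 for x; since y = 3 * x, the
    # Jacobian-vector product is 3 * 2.0 = 6.0.
    with tf.autodiff.ForwardAccumulator(primals=x, tangents=2.) as acc:
        y = x * 3.
    assert abs(float(acc.jvp(y)) - 6.) < 1e-6
finally:
    # Restore the default so later tf.functions are traced as graphs again.
    tf.config.run_functions_eagerly(False)
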
Example 2
def pfor(loop_fn,
         iters,
         fallback_to_while_loop=True,
         parallel_iterations=None):
    """Equivalent to running `loop_fn` `iters` times and stacking the outputs.

  `pfor` has functionality similar to `for_loop`, i.e. running `loop_fn` `iters`
  times, with input from 0 to `iters - 1`, and stacking corresponding output of
  each iteration. However the implementation does not use a `tf.while_loop`.
  Instead it adds new operations to the graph that collectively compute the same
  value as what running `loop_fn` in a loop would compute.


  This is an experimental feature and currently has a lot of limitations:
    - There should be no data dependency between the different iterations. For
      example, a future iteration should not depend on a value or side-effect of
      a previous iteration.
    - Stateful kernels may mostly not be supported since these often imply a
      data dependency or ordering of the iterations. We do support a limited set
      of such stateful kernels though (like RandomFoo, Variable operations like
      reads, etc).
    - Conversion works only on a limited set of kernels for which a converter
      has been registered.
    - `loop_fn` has limited support for control flow operations. `tf.cond` in
      particular is not supported.
    - `loop_fn` should return nested structure of Tensors or Operations. However
      if an Operation is returned, it should have zero outputs.
    - The shape and dtype of `loop_fn` outputs should not depend on the input
      to loop_fn.

  Args:
    loop_fn: A function that takes an int32 scalar tf.Tensor object representing
      the iteration number, and optionally a keyword argument `pfor_config` set
      to a PForConfig object. It returns a possibly nested structure of Tensor
      or Operation objects. Note that if setting `parallel_iterations` argument
      to something other than None, `loop_fn` may be called more than once
      during graph construction. So it may need to avoid mutating global state.
    iters: Number of iterations for which to run `loop_fn`.
    fallback_to_while_loop: If true, on failing to vectorize an operation, pfor
      fallbacks to using a `tf.while_loop` to dispatch the iterations.
    parallel_iterations: A knob to control how many iterations are vectorized
      and dispatched in parallel. The default value of None corresponds to
      vectorizing all the iterations.  If `parallel_iterations` is smaller than
      `iters`, then chunks of at most that many iterations are dispatched in
      sequence. This knob can be used to control the total memory usage.

  Returns:
    Returns a nested structure of stacked tensor objects with the same nested
    structure as the output of `loop_fn`.
  Raises:
    ValueError: If parallel_iterations is not None and not an integer > 1.
  """
    def f():
        return _pfor_impl(loop_fn,
                          iters,
                          fallback_to_while_loop=fallback_to_while_loop,
                          parallel_iterations=parallel_iterations)

    # Note that we wrap into a tf.function if in eager execution mode or under
    # XLA compilation. The latter is so that we don't compile operations like
    # tf.placeholder that are created by the loop body.
    functions_run_eagerly = None
    if context.executing_eagerly() or _is_under_xla_context():
        functions_run_eagerly = def_function.functions_run_eagerly()
        if functions_run_eagerly:
            logging.warning(
                "It looks like tf.function behavior was disabled, perhaps using "
                "tf.config.run_functions_eagerly. Vectorization "
                "primitives (e.g. tf.vectorized_map) require tf.function to work. "
                "These primitives will override the disable.")
            def_function.run_functions_eagerly(False)
        f = def_function.function(f)
    outputs = f()
    if functions_run_eagerly is not None:
        def_function.run_functions_eagerly(functions_run_eagerly)
    return outputs
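
For reference, a hedged usage sketch of `pfor`. The import path below is the internal TensorFlow module this function lives in, so treat it as an assumption tied to this TF version; `tf.vectorized_map` is the supported public wrapper:

import tensorflow as tf
from tensorflow.python.ops.parallel_for.control_flow_ops import pfor

x = tf.random.uniform([8, 3])

def loop_fn(i):
    # Each iteration only reads row i of x: no cross-iteration dependency.
    return tf.reduce_sum(tf.gather(x, i)) * 2.

out = pfor(loop_fn, 8)  # shape [8]; same as stacking the 8 per-row results
# parallel_iterations dispatches the 8 iterations in chunks of at most 4,
# trading peak memory for some parallelism.
out_chunked = pfor(loop_fn, 8, parallel_iterations=4)
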
Example 3
    def __init__(self,
                 func,
                 transformation_name,
                 dataset=None,
                 input_classes=None,
                 input_shapes=None,
                 input_types=None,
                 input_structure=None,
                 add_to_graph=True,
                 use_legacy_function=False,
                 defun_kwargs=None):
        """Creates a new `StructuredFunctionWrapper` for the given function.

    Args:
      func: A function from a (nested) structure to another (nested) structure.
      transformation_name: Human-readable name of the transformation in which
        this function is being instantiated, for error messages.
      dataset: (Optional.) A `tf.data.Dataset`. If given, the structure of this
        dataset will be assumed as the structure for `func` arguments; otherwise
        `input_classes`, `input_shapes`, and `input_types` must be defined.
      input_classes: (Optional.) A (nested) structure of `type`. If given, this
        argument defines the Python types for `func` arguments.
      input_shapes: (Optional.) A (nested) structure of `tf.TensorShape`. If
        given, this argument defines the shapes and structure for `func`
        arguments.
      input_types: (Optional.) A (nested) structure of `tf.DType`. If given,
        this argument defines the element types and structure for `func`
        arguments.
      input_structure: (Optional.) A `Structure` object. If given, this argument
        defines the element types and structure for `func` arguments.
      add_to_graph: (Optional.) If `True`, the function will be added to the
        default graph, if it exists.
      use_legacy_function: (Optional.) A boolean that determines whether the
        function be created using `tensorflow.python.eager.function.defun`
        (default behavior) or `tensorflow.python.framework.function.Defun`
        (legacy behavior).
      defun_kwargs: (Optional.) A dictionary mapping string argument names to
        values. If supplied, will be passed to `function` as keyword arguments.

    Raises:
      ValueError: If an invalid combination of `dataset`, `input_classes`,
        `input_shapes`, and `input_types` is passed.
    """
        # pylint: disable=protected-access
        if input_structure is None:
            if dataset is None:
                if input_classes is None or input_shapes is None or input_types is None:
                    raise ValueError(
                        "Either `dataset`, `input_structure` or all of "
                        "`input_classes`, `input_shapes`, and `input_types` "
                        "must be specified.")
                self._input_structure = structure.convert_legacy_structure(
                    input_types, input_shapes, input_classes)
            else:
                if not (input_classes is None and input_shapes is None
                        and input_types is None):
                    raise ValueError(
                        "Either `dataset`, `input_structure` or all of "
                        "`input_classes`, `input_shapes`, and `input_types` "
                        "must be specified.")
                self._input_structure = dataset.element_spec
        else:
            if not (dataset is None and input_classes is None
                    and input_shapes is None and input_types is None):
                raise ValueError(
                    "Either `dataset`, `input_structure`, or all of "
                    "`input_classes`, `input_shapes`, and `input_types` "
                    "must be specified.")
            self._input_structure = input_structure

        self._func = func

        if defun_kwargs is None:
            defun_kwargs = {}

        readable_transformation_name = transformation_name.replace(
            ".", "_")[:-2] if len(transformation_name) > 2 else ""

        func_name = "_".join(
            [readable_transformation_name,
             function_utils.get_func_name(func)])
        # Sanitize function name to remove symbols that interfere with graph
        # construction.
        for symbol in ["<", ">", "\\", "'", " "]:
            func_name = func_name.replace(symbol, "")

        # Capture the caller's autograph status so `wrapper_helper` applies the
        # same conversion decision when invoking `func`.
        ag_ctx = autograph_ctx.control_status_ctx()

        def wrapper_helper(*args):
            """Wrapper for passing nested structures to and from tf.data functions."""
            nested_args = structure.from_compatible_tensor_list(
                self._input_structure, args)
            if not _should_unpack(nested_args):
                nested_args = (nested_args,)
            ret = autograph.tf_convert(self._func, ag_ctx)(*nested_args)
            if _should_pack(ret):
                ret = tuple(ret)

            try:
                self._output_structure = structure.type_spec_from_value(ret)
            except (ValueError, TypeError):
                six.reraise(
                    TypeError,
                    TypeError(
                        f"Unsupported return value from function passed to "
                        f"{transformation_name}: {ret}."),
                    sys.exc_info()[2])
            return ret

        def trace_legacy_function(defun_kwargs):
            @function.Defun(*structure.get_flat_tensor_types(
                self._input_structure), **defun_kwargs)
            def wrapped_fn(*args):
                ret = wrapper_helper(*args)
                return structure.to_tensor_list(self._output_structure, ret)

            # `function.Defun` creates the function object at decoration time,
            # so the factory simply closes over the existing `wrapped_fn`.
            return lambda: wrapped_fn

        def trace_py_function(defun_kwargs):
            # First we trace the function to infer the output structure.
            @eager_function.defun_with_attributes(
                input_signature=structure.get_flat_tensor_specs(
                    self._input_structure),
                autograph=False,
                attributes=defun_kwargs)
            def unused(*args):  # pylint: disable=missing-docstring,unused-variable
                ret = wrapper_helper(*args)
                ret = structure.to_tensor_list(self._output_structure, ret)
                return [ops.convert_to_tensor(t) for t in ret]

            _ = unused.get_concrete_function()

            def py_function_wrapper(*args):
                nested_args = structure.from_compatible_tensor_list(
                    self._input_structure, args)
                if not _should_unpack(nested_args):
                    nested_args = (nested_args,)
                ret = self._func(*nested_args)
                if _should_pack(ret):
                    ret = tuple(ret)
                ret = structure.to_tensor_list(self._output_structure, ret)
                return [ops.convert_to_tensor(t) for t in ret]

            # Next we trace the function wrapped in `eager_py_func` to force eager
            # execution.
            @eager_function.defun_with_attributes(
                input_signature=structure.get_flat_tensor_specs(
                    self._input_structure),
                autograph=False,
                attributes=defun_kwargs)
            def wrapped_fn(*args):  # pylint: disable=missing-docstring
                return script_ops.eager_py_func(
                    py_function_wrapper, args,
                    structure.get_flat_tensor_types(self._output_structure))

            return wrapped_fn.get_concrete_function

        def trace_tf_function(defun_kwargs):
            # Note: wrapper_helper will apply autograph based on context.
            @eager_function.defun_with_attributes(
                input_signature=structure.get_flat_tensor_specs(
                    self._input_structure),
                autograph=False,
                attributes=defun_kwargs)
            def wrapped_fn(*args):  # pylint: disable=missing-docstring
                ret = wrapper_helper(*args)
                ret = structure.to_tensor_list(self._output_structure, ret)
                return [ops.convert_to_tensor(t) for t in ret]

            return wrapped_fn.get_concrete_function

        if use_legacy_function:
            defun_kwargs.update(
                {"func_name": func_name + "_" + str(ops.uid())})
            fn_factory = trace_legacy_function(defun_kwargs)
        else:
            defun_kwargs.update({"func_name": func_name})
            defun_kwargs.update({"_tf_data_function": True})
            if dataset_ops.DEBUG_MODE:
                fn_factory = trace_py_function(defun_kwargs)
            else:
                if def_function.functions_run_eagerly():
                    warnings.warn(
                        "Even though the `tf.config.experimental_run_functions_eagerly` "
                        "option is set, this option does not apply to tf.data functions. "
                        "To force eager execution of tf.data functions, please use "
                        "`tf.data.experimental.enable_debug_mode()`.")
                fn_factory = trace_tf_function(defun_kwargs)

        self._function = fn_factory()
        # There is no graph to add in eager mode.
        add_to_graph &= not context.executing_eagerly()
        # There are some lifetime issues when a legacy function is not added to
        # an out-living graph. Legacy functions are already deprecated, so the
        # fix has been deprioritized.
        add_to_graph |= use_legacy_function
        if add_to_graph:
            self._function.add_to_graph(ops.get_default_graph())

        if not use_legacy_function:
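            # Warn if random ops inside the traced function may have silently
            # inherited the outer graph's seed instead of an explicit one.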
            outer_graph_seed = ops.get_default_graph().seed
            if outer_graph_seed and self._function.graph.seed == outer_graph_seed:
                if self._function.graph._seed_used:
                    warnings.warn(
                        "Seed %s from outer graph might be getting used by function %s, "
                        "if the random op has not been provided any seed. Explicitly set "
                        "the seed in the function if this is not the intended behavior."
                        % (outer_graph_seed, func_name),
                        stacklevel=4)
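
`StructuredFunctionWrapper` is internal to tf.data; users reach it indirectly through transformations such as `Dataset.map`. A minimal sketch of the observable behavior using only public TF 2.x symbols:

import tensorflow as tf

# Dataset.map wraps the lambda in a StructuredFunctionWrapper: the function is
# traced with an input signature derived from the dataset's element_spec, and
# the inferred output structure becomes the element_spec of the mapped dataset.
ds = tf.data.Dataset.range(4).map(lambda x: (x * 2, tf.cast(x, tf.float32)))
print(ds.element_spec)
# To force eager execution of tf.data functions (see the warning above), use
# tf.data.experimental.enable_debug_mode() rather than run_functions_eagerly.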