def testRunFunctionsEagerly(self):
  """Forward-mode JVP of `y = 3 * x` with tangent 2 is 6 under eager tf.function."""
  # Capture the global setting *before* entering `try`: if the read happened
  # inside it and failed, the `finally` clause would reference an unbound name
  # and an UnboundLocalError would mask the original failure.
  original_setting = def_function.functions_run_eagerly()
  try:
    def_function.run_functions_eagerly(True)
    x = constant_op.constant(1.)
    with forwardprop.ForwardAccumulator(x, 2.) as acc:
      y = x * 3.
    self.assertAllClose(6., acc.jvp(y))
  finally:
    # Always restore the process-global flag so other tests are unaffected.
    def_function.run_functions_eagerly(original_setting)
def pfor(loop_fn, iters, fallback_to_while_loop=True, parallel_iterations=None):
  """Equivalent to running `loop_fn` `iters` times and stacking the outputs.

  `pfor` has functionality similar to `for_loop`, i.e. running `loop_fn`
  `iters` times, with input from 0 to `iters - 1`, and stacking corresponding
  output of each iteration. However the implementation does not use a
  `tf.while_loop`. Instead it adds new operations to the graph that
  collectively compute the same value as what running `loop_fn` in a loop
  would compute.

  This is an experimental feature and currently has a lot of limitations:
    - There should be no data dependency between the different iterations. For
      example, a future iteration should not depend on a value or side-effect
      of a previous iteration.
    - Stateful kernels may mostly not be supported since these often imply a
      data dependency or ordering of the iterations. We do support a limited
      set of such stateful kernels though (like RandomFoo, Variable operations
      like reads, etc).
    - Conversion works only on a limited set of kernels for which a converter
      has been registered.
    - `loop_fn` has limited support for control flow operations. `tf.cond` in
      particular is not supported.
    - `loop_fn` should return nested structure of Tensors or Operations.
      However if an Operation is returned, it should have zero outputs.
    - The shape and dtype of `loop_fn` outputs should not depend on the input
      to loop_fn.

  When running eagerly or under XLA compilation, the work is wrapped in a
  `tf.function`. If `tf.config.run_functions_eagerly` was enabled, it is
  temporarily disabled for the call and the previous setting is restored
  afterwards, even if tracing or execution raises.

  Args:
    loop_fn: A function that takes an int32 scalar tf.Tensor object
      representing the iteration number, and optionally a keyword argument
      `pfor_config` set to a PForConfig object. It returns a possibly nested
      structure of Tensor or Operation objects. Note that if setting
      `parallel_iterations` argument to something other than None, `loop_fn`
      may be called more than once during graph construction. So it may need
      to avoid mutating global state.
    iters: Number of iterations for which to run `loop_fn`.
    fallback_to_while_loop: If true, on failing to vectorize an operation, pfor
      falls back to using a `tf.while_loop` to dispatch the iterations.
    parallel_iterations: A knob to control how many iterations are vectorized
      and dispatched in parallel. The default value of None corresponds to
      vectorizing all the iterations. If `parallel_iterations` is smaller than
      `iters`, then chunks of at most that many iterations are dispatched in
      sequence. This knob can be used to control the total memory usage.

  Returns:
    Returns a nested structure of stacked tensor objects with the same nested
    structure as the output of `loop_fn`.

  Raises:
    ValueError: If parallel_iterations is not None and not an integer > 1.
  """
  def f():
    return _pfor_impl(loop_fn,
                      iters,
                      fallback_to_while_loop=fallback_to_while_loop,
                      parallel_iterations=parallel_iterations)

  # Note that we wrap into a tf.function if in eager execution mode or under
  # XLA compilation. The latter is so that we don't compile operations like
  # tf.placeholder that are created by the loop body.
  functions_run_eagerly = None
  try:
    if context.executing_eagerly() or _is_under_xla_context():
      functions_run_eagerly = def_function.functions_run_eagerly()
      if functions_run_eagerly:
        logging.warning(
            "It looks like tf.function behavior was disabled, perhaps using "
            "tf.config.run_functions_eagerly. Vectorization "
            "primitives (e.g. tf.vectorized_map) require tf.function to work. "
            "These primitives will override the disable.")
        def_function.run_functions_eagerly(False)
      f = def_function.function(f)
    return f()
  finally:
    # Restore the process-global flag on every exit path; previously an
    # exception from `f()` left the caller's setting silently overridden.
    if functions_run_eagerly is not None:
      def_function.run_functions_eagerly(functions_run_eagerly)
def __init__(self,
             func,
             transformation_name,
             dataset=None,
             input_classes=None,
             input_shapes=None,
             input_types=None,
             input_structure=None,
             add_to_graph=True,
             use_legacy_function=False,
             defun_kwargs=None):
  """Creates a new `StructuredFunctionWrapper` for the given function.

  Args:
    func: A function from a (nested) structure to another (nested) structure.
    transformation_name: Human-readable name of the transformation in which
      this function is being instantiated, for error messages.
    dataset: (Optional.) A `tf.data.Dataset`. If given, the structure of this
      dataset will be assumed as the structure for `func` arguments; otherwise
      `input_classes`, `input_shapes`, and `input_types` must be defined.
    input_classes: (Optional.) A (nested) structure of `type`. If given, this
      argument defines the Python types for `func` arguments.
    input_shapes: (Optional.) A (nested) structure of `tf.TensorShape`. If
      given, this argument defines the shapes and structure for `func`
      arguments.
    input_types: (Optional.) A (nested) structure of `tf.DType`. If given, this
      argument defines the element types and structure for `func` arguments.
    input_structure: (Optional.) A `Structure` object. If given, this argument
      defines the element types and structure for `func` arguments.
    add_to_graph: (Optional.) If `True`, the function will be added to the
      default graph, if it exists.
    use_legacy_function: (Optional.) A boolean that determines whether the
      function be created using `tensorflow.python.eager.function.defun`
      (default behavior) or `tensorflow.python.framework.function.Defun`
      (legacy behavior).
    defun_kwargs: (Optional.) A dictionary mapping string argument names to
      values. If supplied, will be passed to `function` as keyword arguments.
      The dictionary is copied; the caller's object is never modified.

  Raises:
    ValueError: If an invalid combination of `dataset`, `input_structure`,
      `input_classes`, `input_shapes`, and `input_types` is passed.
    TypeError: If the value returned by `func` cannot be converted into a
      structure (raised while `func` is being traced).
  """
  # pylint: disable=protected-access
  if input_structure is None:
    if dataset is None:
      if input_classes is None or input_shapes is None or input_types is None:
        raise ValueError(
            "Either `dataset`, `input_structure` or all of "
            "`input_classes`, `input_shapes`, and `input_types` "
            "must be specified.")
      self._input_structure = structure.convert_legacy_structure(
          input_types, input_shapes, input_classes)
    else:
      if not (input_classes is None and input_shapes is None and
              input_types is None):
        raise ValueError(
            "Either `dataset`, `input_structure` or all of "
            "`input_classes`, `input_shapes`, and `input_types` "
            "must be specified.")
      self._input_structure = dataset.element_spec
  else:
    if not (dataset is None and input_classes is None and
            input_shapes is None and input_types is None):
      raise ValueError(
          "Either `dataset`, `input_structure`, or all of "
          "`input_classes`, `input_shapes`, and `input_types` "
          "must be specified.")
    self._input_structure = input_structure

  self._func = func

  # Work on a private copy: the `update` calls below would otherwise inject
  # `func_name` / `_tf_data_function` into a dictionary owned by the caller.
  defun_kwargs = {} if defun_kwargs is None else dict(defun_kwargs)

  # Replace "." with "_" and drop the last two characters of the name
  # (presumably a trailing "()", e.g. "Dataset.map()" -- TODO confirm).
  readable_transformation_name = transformation_name.replace(
      ".", "_")[:-2] if len(transformation_name) > 2 else ""

  func_name = "_".join(
      [readable_transformation_name, function_utils.get_func_name(func)])
  # Sanitize function name to remove symbols that interfere with graph
  # construction.
  for symbol in ["<", ">", "\\", "'", " "]:
    func_name = func_name.replace(symbol, "")

  ag_ctx = autograph_ctx.control_status_ctx()

  def wrapper_helper(*args):
    """Wrapper for passing nested structures to and from tf.data functions."""
    nested_args = structure.from_compatible_tensor_list(
        self._input_structure, args)
    if not _should_unpack(nested_args):
      nested_args = (nested_args,)
    ret = autograph.tf_convert(self._func, ag_ctx)(*nested_args)
    if _should_pack(ret):
      ret = tuple(ret)

    try:
      self._output_structure = structure.type_spec_from_value(ret)
    except (ValueError, TypeError) as e:
      # Native exception chaining keeps the original error as `__cause__`.
      raise TypeError(
          f"Unsupported return value from function passed to "
          f"{transformation_name}: {ret}.") from e
    return ret

  def trace_legacy_function(defun_kwargs):

    @function.Defun(*structure.get_flat_tensor_types(self._input_structure),
                    **defun_kwargs)
    def wrapped_fn(*args):
      ret = wrapper_helper(*args)
      return structure.to_tensor_list(self._output_structure, ret)

    return lambda: wrapped_fn

  def trace_py_function(defun_kwargs):
    # First we trace the function to infer the output structure.
    @eager_function.defun_with_attributes(
        input_signature=structure.get_flat_tensor_specs(
            self._input_structure),
        autograph=False,
        attributes=defun_kwargs)
    def unused(*args):  # pylint: disable=missing-docstring,unused-variable
      ret = wrapper_helper(*args)
      ret = structure.to_tensor_list(self._output_structure, ret)
      return [ops.convert_to_tensor(t) for t in ret]

    _ = unused.get_concrete_function()

    def py_function_wrapper(*args):
      nested_args = structure.from_compatible_tensor_list(
          self._input_structure, args)
      if not _should_unpack(nested_args):
        nested_args = (nested_args,)
      ret = self._func(*nested_args)
      if _should_pack(ret):
        ret = tuple(ret)
      ret = structure.to_tensor_list(self._output_structure, ret)
      return [ops.convert_to_tensor(t) for t in ret]

    # Next we trace the function wrapped in `eager_py_func` to force eager
    # execution.
    @eager_function.defun_with_attributes(
        input_signature=structure.get_flat_tensor_specs(
            self._input_structure),
        autograph=False,
        attributes=defun_kwargs)
    def wrapped_fn(*args):  # pylint: disable=missing-docstring
      return script_ops.eager_py_func(
          py_function_wrapper, args,
          structure.get_flat_tensor_types(self._output_structure))

    return wrapped_fn.get_concrete_function

  def trace_tf_function(defun_kwargs):
    # Note: wrapper_helper will apply autograph based on context.
    @eager_function.defun_with_attributes(
        input_signature=structure.get_flat_tensor_specs(
            self._input_structure),
        autograph=False,
        attributes=defun_kwargs)
    def wrapped_fn(*args):  # pylint: disable=missing-docstring
      ret = wrapper_helper(*args)
      ret = structure.to_tensor_list(self._output_structure, ret)
      return [ops.convert_to_tensor(t) for t in ret]

    return wrapped_fn.get_concrete_function

  if use_legacy_function:
    # A uid suffix is appended to the name in the legacy path.
    defun_kwargs.update({"func_name": func_name + "_" + str(ops.uid())})
    fn_factory = trace_legacy_function(defun_kwargs)
  else:
    defun_kwargs.update({"func_name": func_name})
    defun_kwargs.update({"_tf_data_function": True})
    if dataset_ops.DEBUG_MODE:
      fn_factory = trace_py_function(defun_kwargs)
    else:
      if def_function.functions_run_eagerly():
        warnings.warn(
            "Even though the `tf.config.experimental_run_functions_eagerly` "
            "option is set, this option does not apply to tf.data functions. "
            "To force eager execution of tf.data functions, please use "
            "`tf.data.experimental.enable_debug_mode()`.")
      fn_factory = trace_tf_function(defun_kwargs)

  self._function = fn_factory()
  # There is no graph to add in eager mode.
  add_to_graph &= not context.executing_eagerly()
  # There are some lifetime issues when a legacy function is not added to a
  # out-living graph. It's already deprecated so de-prioritizing the fix.
  add_to_graph |= use_legacy_function
  if add_to_graph:
    self._function.add_to_graph(ops.get_default_graph())

  if not use_legacy_function:
    # Warn when the traced function shares the outer graph's seed and used it.
    outer_graph_seed = ops.get_default_graph().seed
    if outer_graph_seed and self._function.graph.seed == outer_graph_seed:
      if self._function.graph._seed_used:
        warnings.warn(
            "Seed %s from outer graph might be getting used by function %s, "
            "if the random op has not been provided any seed. Explicitly set "
            "the seed in the function if this is not the intended behavior." %
            (outer_graph_seed, func_name),
            stacklevel=4)