def wrapper(*args, **kwargs):
  """Forwards the call to an autograph-converted version of `original_func`.

  The conversion is requested with verbose output, recursive conversion, no
  optional features, and `function.defun` / `def_function.function` passed as
  `strip_decorators`. Positional and keyword arguments are forwarded unpacked.
  """
  conversion_options = autograph.ConversionOptions(
      verbose=True,
      recursive=True,
      strip_decorators=(function.defun, def_function.function),
      optional_features=(),
  )
  return autograph.converted_call(
      original_func, None, conversion_options, *args, **kwargs)
def wrapper(*args, **kwargs):
  """Forwards the call to an autograph-converted version of `original_func`.

  The conversion uses `autograph.Verbosity.BRIEF`, converts recursively,
  enables no optional features, and passes only `def_function.function` as
  `strip_decorators`. Positional and keyword arguments are forwarded unpacked.
  """
  conversion_options = autograph.ConversionOptions(
      verbose=autograph.Verbosity.BRIEF,
      recursive=True,
      strip_decorators=(def_function.function,),
      optional_features=(),
  )
  return autograph.converted_call(
      original_func, None, conversion_options, *args, **kwargs)
def wrapper(*args, **kwargs):
  """Forwards the call to a force-converted version of `original_func`.

  Unlike the unpacking variants, the positional tuple and keyword dict are
  handed to `converted_call` as two plain arguments.
  """
  # Note: functions annotated with @tf.function should always be
  # converted even though they would meet autograph's whitelisting
  # criteria.
  # If this assumption is ever broken, converted_call will need to
  # handle the possibility of original_func still being a shim, e.g.
  # bound to WeakrefSelf.
  conversion_options = autograph.ConversionOptions(
      recursive=True,
      optional_features=autograph_options,
      force_conversion=True,
  )
  return autograph.converted_call(
      original_func, None, conversion_options, args, kwargs)
def wrapper(*args, **kwargs):
  """Calls a converted version of original_func.

  Any exception that carries an `ag_error_metadata` attribute is replaced by
  the exception returned from `ag_error_metadata.to_exception(type(e))`; all
  other exceptions propagate unchanged.
  """
  # TODO(mdan): Push this block higher in tf.function's call stack.
  try:
    conversion_options = autograph.ConversionOptions(
        recursive=True,
        optional_features=autograph_options,
        force_conversion=True,
    )
    return autograph.converted_call(
        original_func, None, conversion_options, args, kwargs)
  except Exception as e:  # pylint:disable=broad-except
    if not hasattr(e, "ag_error_metadata"):
      # Not an autograph-annotated error: re-raise it untouched.
      raise
    raise e.ag_error_metadata.to_exception(type(e))
def func_graph_from_py_func(name,
                            python_func,
                            args,
                            kwargs,
                            signature=None,
                            func_graph=None,
                            experimental_autograph=False,
                            add_control_dependencies=True,
                            arg_names=None,
                            op_return_value=None):
  """Returns a `FuncGraph` generated from `python_func`.

  While tracing, the current variable scope is switched to
  `use_resource=True`; its previous setting is restored on every exit path,
  including when converting the inputs or tracing `python_func` raises.

  Args:
    name: an identifier for the function.
    python_func: the Python function to trace.
    args: the positional args with which the Python function should
      be called; ignored if a signature is provided.
    kwargs: the keyword args with which the Python function should
      be called; ignored if a signature is provided.
    signature: a possibly nested sequence of `TensorSpecs` specifying the shapes
      and dtypes of the arguments. When a signature is provided, `args` and
      `kwargs` are ignored, and `python_func` is traced with Tensors conforming
      to `signature`. If `None`, the shapes and dtypes are inferred from the
      inputs.
    func_graph: Optional. An instance of FuncGraph. If provided, we will use
      this graph else a new one is built and returned.
    experimental_autograph: whether to use autograph to compile `python_func`.
      See https://www.tensorflow.org/guide/autograph for more information.
    add_control_dependencies: If True, automatically adds control dependencies
      to ensure program order matches execution order and stateful ops always
      execute.
    arg_names: Optional list of argument names, used to give input placeholders
      recognizable names.
    op_return_value: Optional. A Tensor. If set and `python_func` returns
      Operations, those return values will be replaced with this value. If not
      set, returning an Operation triggers an error.

  Returns:
    A FuncGraph.

  Raises:
    TypeError: If any of `python_func`'s return values is neither `None` nor
      convertible to a `Tensor`/`IndexedSlices` (an `Operation` is accepted
      only when `op_return_value` is set).
  """
  if op_return_value is not None:
    assert isinstance(op_return_value, ops.Tensor), op_return_value
  if func_graph is None:
    func_graph = FuncGraph(name)
  assert isinstance(func_graph, FuncGraph)
  if add_control_dependencies:
    control_manager = AutomaticControlDependencies
  else:
    control_manager = ops.NullContextmanager
  with func_graph.as_default(), control_manager() as a:
    current_scope = variable_scope.get_variable_scope()
    default_use_resource = current_scope.use_resource
    current_scope.set_use_resource(True)
    # Everything after forcing resource variables on is guarded, so the
    # scope's previous setting is restored even if input conversion (not just
    # tracing) raises.
    try:
      if signature is not None:
        args = signature
        kwargs = {}

      func_args = _get_defun_inputs_from_args(args, arg_names)
      func_kwargs = _get_defun_inputs_from_kwargs(kwargs)

      # Note: `nest.flatten` sorts by keys, as does `_deterministic_dict_values`.
      # Variables to help check whether mutation happens in calling the function
      # Copy the recursive list, tuple and map structure, but not base objects
      func_args_before = nest.pack_sequence_as(
          func_args, nest.flatten(func_args))
      func_kwargs_before = nest.pack_sequence_as(
          func_kwargs, nest.flatten(func_kwargs))

      def convert(x):
        """Converts a function output to a Tensor."""
        if x is None:
          return None
        if op_return_value is not None and isinstance(x, ops.Operation):
          # TODO(b/79881896): we currently can't capture external control deps,
          # so this won't work if x needs to be captured (i.e. if python_func
          # returns captured Operations).
          with ops.control_dependencies([x]):
            x = array_ops.identity(op_return_value)
        else:
          try:
            x = ops.convert_to_tensor_or_indexed_slices(x)
          except (ValueError, TypeError):
            raise TypeError(
                "To be compatible with tf.contrib.eager.defun, Python functions "
                "must return zero or more Tensors; in compilation of %s, found "
                "return value of type %s, which is not a Tensor." %
                (str(python_func), type(x)))
        if add_control_dependencies:
          x = a.mark_as_return(x)
        return x

      this_tape = tape.push_new_tape()
      try:
        if experimental_autograph:
          func_outputs = autograph.converted_call(
              python_func, None,
              autograph.ConversionOptions(
                  verbose=True,
                  recursive=True,
                  strip_decorators=(function.defun,),
                  optional_features=(),
              ), *func_args, **func_kwargs)
        else:
          func_outputs = python_func(*func_args, **func_kwargs)

        # invariant: `func_outputs` contains only Tensors and `None`s.
        func_outputs = nest.map_structure(convert, func_outputs)

        check_mutation(func_args_before, func_args)
        check_mutation(func_kwargs_before, func_kwargs)
      finally:
        tape.pop_tape(this_tape)
    finally:
      current_scope.set_use_resource(default_use_resource)

    # Variables in `func_args`, `func_kwargs` should be explicit inputs
    # to the function, not captured inputs.
    tape_variables = this_tape.watched_variables()
    arg_variables = set()
    inputs = []
    for arg in nest.flatten(func_args) + nest.flatten(func_kwargs):
      if isinstance(arg, resource_variable_ops.ResourceVariable):
        try:
          resource_placeholder = func_graph.captures.pop(arg.handle)
          arg_variables.add(arg)
        except KeyError:
          # This case occurs if a Variable among the inputs is not actually
          # used by the function; we still add an explicit input for it
          # because the user should presumably pass the Variable as an input
          # to the corresponding graph function.
          resource_placeholder = _create_substitute_placeholder(arg.handle)
        inputs.append(resource_placeholder)
      elif isinstance(arg, ops.Tensor):
        inputs.append(arg)
    variables = [v for v in tape_variables if v not in arg_variables]
    func_graph.inputs = inputs + list(func_graph.captures.values())

    func_graph.structured_outputs = func_outputs
    # Returning a closed-over tensor does not trigger convert_to_tensor.
    func_graph.outputs.extend(
        func_graph.capture(x)
        for x in flatten(func_graph.structured_outputs)
        if x is not None)

    func_graph.variables = variables

  # Register any other functions defined in the graph.
  with ops.init_scope():
    if context.executing_eagerly():
      for f in func_graph._functions.values():  # pylint: disable=protected-access
        # TODO(ashankar): What about the gradient registry?
        context.add_function(f._c_func.func)  # pylint: disable=protected-access

  return func_graph