def _create_new_tf_function(func_graph):
  """Builds a TF_Function from `func_graph` and registers it in its outer graph.

  An `_EagerDefinedFunction` is constructed from the graph's own inputs and
  outputs (with an empty attribute dict) and then added to
  `func_graph.outer_graph`.

  Args:
    func_graph: function.FuncGraph to convert.

  Returns:
    The name of the new TF_Function, which is `func_graph.name`.
  """
  fn_name = func_graph.name
  defined_fn = function._EagerDefinedFunction(
      fn_name,
      func_graph,
      func_graph.inputs,
      func_graph.outputs,
      {})
  defined_fn.add_to_graph(func_graph.outer_graph)
  return fn_name
def create_new_tf_function(func_graph):
  """Builds a TF_Function from `func_graph` and registers it in its outer graph.

  The graph's own inputs and outputs become the signature of an
  `_EagerDefinedFunction` (created with an empty attribute dict). That
  function is then added to `func_graph.outer_graph`.

  Args:
    func_graph: FuncGraph to convert.

  Returns:
    The name of the new TF_Function, which is `func_graph.name`.
  """
  fn_name = func_graph.name
  defined_fn = function._EagerDefinedFunction(  # pylint: disable=protected-access
      fn_name,
      func_graph,
      func_graph.inputs,
      func_graph.outputs,
      {})
  defined_fn.add_to_graph(func_graph.outer_graph)
  return fn_name
def __call__(self, *args, **kwds):
  """Calls the graph function.

  Dispatches between the traced variants of the wrapped function:

  * If `self._created_variables` is non-empty, runs `self._stateless_fn`.
  * Otherwise, if `self._stateful_fn` already exists, runs it and raises if
    that call caused variables to be created.
  * Otherwise (first call), calls `self._initialize(args, kwds)`, then takes
    one of three paths:
    - runs the lifted initializer graph and then `self._stateless_fn`;
    - calls `self._concrete_stateful_fn` directly when no variables were
      created;
    - wraps a `cond` that picks the stateless or stateful function depending
      on whether every created variable is already initialized.

  Args:
    *args: Positional arguments forwarded to the selected function.
    **kwds: Keyword arguments forwarded to the selected function.

  Returns:
    The result of whichever function variant was invoked.

  Raises:
    ValueError: If a call after the first creates variables. Also raised if,
      on the cond-guarded path, an entry of `self._created_variables` no
      longer refers to a live variable.
  """
  if self._created_variables:
    # In this case we have created variables on the first call, so we run the
    # defunned version which is guaranteed to never create variables.
    return self._stateless_fn(*args, **kwds)  # pylint: disable=not-callable
  elif self._stateful_fn is not None:
    # In this case we have not created variables on the first call. So we can
    # run the first trace but we should fail if variables are created.
    results = self._stateful_fn(*args, **kwds)
    if self._created_variables:
      raise ValueError("Creating variables on a non-first call to a function"
                       " decorated with tf.function.")
    return results

  # This is the first call of __call__, so we have to initialize.
  self._initialize(args, kwds)
  if self._lifted_all_initializers and self._lifted_placeholders:
    # Presumably `_initialize` managed to lift every initializer into
    # `self._lifted_initializer_graph` (confirm in `_initialize`).
    # Run that graph once here instead of guarding each call with a cond.
    with ops.init_scope():
      # Each entry pairs a value fed at call time with a placeholder of the
      # initializer graph. NOTE(review): the fed values are presumably the
      # created variables' resource handles; confirm.
      handles, placeholders = zip(*self._lifted_placeholders)
      if context.executing_eagerly():
        # One-off function with a unique name, the placeholders as inputs
        # and no outputs; it is invoked purely for its effects.
        lifted_fn = function_lib._EagerDefinedFunction(  # pylint: disable=protected-access
            "initializer" + str(ops.uid()),
            self._lifted_initializer_graph,
            placeholders, [], {})
        # Keep the initializer run off any active gradient tape.
        with tape.stop_recording():
          lifted_fn.call(context.context(), list(handles))
      # NOTE(review): when not executing eagerly, the lifted initializer is
      # not run here, yet the stateless function is still called below.
      # Confirm that initialization happens elsewhere in that mode.
    return self._stateless_fn(*args, **kwds)
  canon_args, canon_kwds = self._canonicalize_function_inputs(args, kwds)

  if not self._created_variables:
    # If we did not create any variables the trace we have is good enough.
    return self._concrete_stateful_fn._filtered_call(canon_args, canon_kwds)  # pylint: disable=protected-access

  def fn_with_cond(*inner_args, **inner_kwds):
    """Conditionally runs initialization if it's needed.

    Builds `condition` as the logical AND of `var_is_initialized_op` over
    every created variable. Then uses `control_flow_ops.cond` to call the
    stateless function when the condition holds and the concrete stateful
    function otherwise.

    Raises:
      ValueError: If an entry of `self._created_variables` dereferences to
        None, i.e. the variable has been garbage-collected.
    """
    condition = True
    for wr in self._created_variables:
      # Entries are dereferenced by calling them (weakref-style); None means
      # the variable object no longer exists.
      variable = wr()
      if variable is None:
        raise ValueError(
            "A tf.Variable created inside your tf.function has been"
            " garbage-collected. Your code needs to keep Python references"
            " to variables created inside `tf.function`s.\n"
            "\n"
            "A common way to raise this error is to create and return a"
            " variable only referenced inside your function:\n"
            "\n"
            "@tf.function\n"
            "def f():\n"
            " v = tf.Variable(1.0)\n"
            " return v\n"
            "\n"
            "v = f() # Crashes with this error message!\n"
            "\n"
            "The reason this crashes is that @tf.function annotated"
            " function returns a **`tf.Tensor`** with the **value** of the"
            " variable when the function is called rather than the"
            " variable instance itself. As such there is no code holding a"
            " reference to the `v` created inside the function and Python"
            " garbage collects it.\n"
            "\n"
            "The simplest way to fix this issue is to create variables"
            " outside the function and capture them:\n"
            "\n"
            "v = tf.Variable(1.0)\n"
            "\n"
            "@tf.function\n"
            "def f():\n"
            " return v\n"
            "\n"
            "f() # <tf.Tensor: ... numpy=1.>\n"
            "v.assign_add(1.)\n"
            "f() # <tf.Tensor: ... numpy=2.>")
      condition = math_ops.logical_and(
          condition, resource_variable_ops.var_is_initialized_op(
              variable.handle))
    # We want to call stateless_fn if possible because it avoids recomputing
    # potentially expensive initializers.
    # `_filtered_call` takes the args tuple and kwargs dict positionally,
    # hence the partial over (inner_args, inner_kwds).
    return control_flow_ops.cond(
        condition,
        lambda: self._stateless_fn(*inner_args, **inner_kwds),
        functools.partial(self._concrete_stateful_fn._filtered_call,  # pylint: disable=protected-access
                          inner_args, inner_kwds))

  # Wrap the conditional dispatcher with `defun` and invoke it on the
  # canonicalized inputs.
  return function_lib.defun(fn_with_cond)(*canon_args, **canon_kwds)