def _disallow_inside_tf_function(method_name: str) -> None:
    """Raise if `Model.<method_name>` is being called inside a `tf.function`.

    Args:
        method_name: Name of the `Model` method being guarded; it is
            interpolated into the error message.

    Raises:
        RuntimeError: If `tf.inside_function()` reports that execution is
            currently inside a function graph.
    """
    if not tf.inside_function():
        return
    error_msg = (
        "Detected a call to `Model.{method_name}` inside a `tf.function`. "
        # Closing backtick restored after `{method_name}` so the inline-code
        # span does not swallow the rest of the sentence.
        "`Model.{method_name}` is a high-level endpoint that manages its own "
        "`tf.function`. Please move the call to `Model.{method_name}` outside "
        "of all enclosing `tf.function`s. Note that you can call a `Model` "
        "directly on `Tensor`s inside a `tf.function` like: `model(x)`."
    ).format(method_name=method_name)
    raise RuntimeError(error_msg)
def _disallow_inside_tf_function(method_name):
    """Disallow calling a method inside a `tf.function`.

    Args:
        method_name: Name of the `PreprocessingLayer` method being guarded;
            it is interpolated into the error message.

    Raises:
        RuntimeError: If `tf.inside_function()` reports that execution is
            currently inside a function graph.
    """
    if not tf.inside_function():
        return
    error_msg = (
        'Detected a call to `PreprocessingLayer.{method_name}` inside a '
        # Closing backtick restored after `{method_name}` so the inline-code
        # span does not swallow the rest of the sentence.
        '`tf.function`. `PreprocessingLayer.{method_name}` is a high-level '
        'endpoint that manages its own `tf.function`. Please move the call '
        'to `PreprocessingLayer.{method_name}` outside of all enclosing '
        '`tf.function`s. Note that you can call a `PreprocessingLayer` '
        'directly on `Tensor`s inside a `tf.function` like: `layer(x)`, '
        'or update its state like: `layer.update_state(x)`.').format(
            method_name=method_name)
    raise RuntimeError(error_msg)
def is_in_tf_function():
    """Report whether the caller is inside a user-level `tf.function`.

    Returns:
        False when running in V1 graph mode, when not inside any function
        graph, when inside a Keras FuncGraph (per `is_in_keras_graph()`), or
        when the default graph's name starts with 'wrapped_function' (a v1
        `wrap_function` FuncGraph). True otherwise.
    """
    # Reject V1 graph mode first, then the case of not being in a function.
    if not (tf.compat.v1.executing_eagerly_outside_functions()
            and tf.inside_function()):
        return False
    # A Keras-owned FuncGraph does not count.
    if is_in_keras_graph():
        return False
    # A v1 `wrap_function` FuncGraph is recognised by its name prefix.
    default_graph = tf.compat.v1.get_default_graph()
    graph_name = getattr(default_graph, 'name', False)
    return not (graph_name and graph_name.startswith('wrapped_function'))