def _disallow_inside_tf_function(method_name):
    """Disallow calling a method inside a `tf.function`.

    Args:
        method_name: Name of the `PreprocessingLayer` method being guarded.
            It is interpolated into the error message shown to the user.

    Raises:
        RuntimeError: If `tf.inside_function()` reports that we are currently
            inside a function (i.e. being traced by an enclosing
            `tf.function`).
    """
    if tf.inside_function():
        # The message steers users toward the in-graph alternatives
        # (`layer(x)` / `layer.update_state(x)`) instead of the high-level
        # endpoint, which manages its own `tf.function`.
        error_msg = (
            "Detected a call to `PreprocessingLayer.{method_name}` inside a "
            "`tf.function`. `PreprocessingLayer.{method_name}` is a high-level "
            "endpoint that manages its own `tf.function`. Please move the call "
            "to `PreprocessingLayer.{method_name}` outside of all enclosing "
            "`tf.function`s. Note that you can call a `PreprocessingLayer` "
            "directly on `Tensor`s inside a `tf.function` like: `layer(x)`, "
            "or update its state like: `layer.update_state(x)`."
        ).format(method_name=method_name)
        raise RuntimeError(error_msg)
def is_in_tf_function():
    """Report whether the caller is running inside a user `tf.function`.

    Returns:
        True only when all of the following hold:
          * TF1 graph mode is not active, i.e.
            `tf.compat.v1.executing_eagerly_outside_functions()` is true;
          * `tf.inside_function()` is true;
          * we are not inside a Keras FuncGraph (per `is_in_keras_graph()`);
          * the default graph's name does not start with
            'wrapped_function', the prefix the original checks for to
            detect a v1 `wrap_function` graph.
        Otherwise False.
    """
    # In V1 graph mode there is no enclosing tf.function to speak of.
    if not tf.compat.v1.executing_eagerly_outside_functions():
        return False

    # Not tracing any function at all, or tracing a Keras-managed FuncGraph:
    # neither counts as a user tf.function.
    if not tf.inside_function() or is_in_keras_graph():
        return False

    # Graphs produced by v1 `wrap_function` are excluded by name prefix.
    default_graph = tf.compat.v1.get_default_graph()
    graph_name = getattr(default_graph, 'name', False)
    from_wrap_function = bool(graph_name) and graph_name.startswith(
        'wrapped_function')
    return not from_wrap_function