def is_whitelisted(
    o, check_call_override=True, allow_namedtuple_subclass=False):
  """Decides whether an entity is whitelisted for use in graph mode.

  Members of the tensorflow package are a typical example of whitelisted
  entities. The rules below are tried in order and the first one that
  matches determines the result.

  Args:
    o: A Python entity.
    check_call_override: Reserved for internal use. Passing `False` turns off
      the rule that whitelists a callable object whenever its `__call__`
      method is whitelisted.
    allow_namedtuple_subclass: Reserved for internal use. With `True`, the
      namedtuple rule only accepts types that have no namedtuple among their
      direct bases, i.e. namedtuple subclasses are not whitelisted.

  Returns:
    Boolean
  """
  # TODO(b/120224672): Fix this.
  # tf_inspect.getmodule(functools.partial(...)) returns None because partial
  # objects have no __module__ attribute, so they are pinned to functools.
  module = (
      functools
      if isinstance(o, functools.partial)
      else tf_inspect.getmodule(o))

  # Builtins are an example of callables that lack a __module__ property.
  if hasattr(module, '__name__'):
    for conversion_rule in config.CONVERSION_RULES:
      verdict = conversion_rule.get_action(module)
      if verdict == config.Action.CONVERT:
        logging.log(2, 'Not whitelisted: %s: %s', o, conversion_rule)
        return False
      if verdict == config.Action.DO_NOT_CONVERT:
        logging.log(2, 'Whitelisted: %s: %s', o, conversion_rule)
        return True

  if tf_inspect.isgeneratorfunction(o):
    logging.warn(
        'Entity %s appears to be a generator function. It will not be converted'
        ' by AutoGraph.', o)
    logging.log(2, 'Whitelisted: %s: generator functions are not converted', o)
    return True

  if check_call_override and not tf_inspect.isclass(o):
    if hasattr(o, '__call__'):
      # A callable object is whitelisted when its __call__ method is.
      # Comparing the two types keeps us from recursing forever on the
      # __call__ attribute of function objects.
      if (type(o) != type(o.__call__)) and is_whitelisted(o.__call__):  # pylint: disable=unidiomatic-typecheck
        logging.log(2, 'Whitelisted: %s: object __call__ whitelisted', o)
        return True

  if tf_inspect.ismethod(o):
    # A method inherited from a whitelisted class stays whitelisted even when
    # it is reached through a user-defined subclass.
    #
    # Say `tf.Foo` is whitelisted and defines `bar`:
    #
    #   class Custom(tf.Foo):
    #     pass
    #
    #   baz = Custom()
    #
    # Then `baz.bar` is whitelisted as well. Had `Custom` overloaded `bar`,
    # it would no longer be.
    method_owner = inspect_utils.getmethodclass(o)
    if method_owner is function.TfMethodTarget:
      method_owner = o.__self__.target_class
    if method_owner is not None:
      if issubclass(method_owner, unittest.TestCase):
        logging.log(2, 'Whitelisted: %s: method of TestCase subclass', o)
        return True

      defining_class = inspect_utils.getdefiningclass(o, method_owner)
      owner_whitelisted = is_whitelisted(
          defining_class,
          check_call_override=False,
          allow_namedtuple_subclass=True)
      if owner_whitelisted:
        logging.log(2, 'Whitelisted: %s: owner is whitelisted %s', o,
                    defining_class)
        return True

  if inspect_utils.isnamedtuple(o):
    # namedtuple types expose no source code, so they cannot be converted.
    # They are plain containers, hence assumed safe for graph mode.
    if not allow_namedtuple_subclass:
      logging.log(2, 'Whitelisted: %s: named tuple or subclass', o)
      return True
    has_namedtuple_base = any(
        inspect_utils.isnamedtuple(base) for base in o.__bases__)
    if not has_namedtuple_base:
      logging.log(2, 'Whitelisted: %s: named tuple', o)
      return True

  logging.log(2, 'Not whitelisted: %s: default rule', o)
  return False
def is_whitelisted_for_graph(o, check_call_override=True):
  """Checks whether an entity is whitelisted for use in graph mode.

  Examples of whitelisted entities include all members of the tensorflow
  package.

  The rules are tried in order and the first match returns `True`:
    1. The entity's module is, or is nested under, one of the prefixes in
       `config.DEFAULT_UNCOMPILED_MODULES`.
    2. The entity has an `autograph_info__` or `__ag_compiled` attribute.
    3. The entity is a generator function.
    4. The entity is a callable object whose `__call__` is whitelisted.
    5. The entity is a method of a `unittest.TestCase` subclass, or a method
       whose owner (as resolved by `inspect_utils.getdefiningclass`) is
       whitelisted.
    6. The entity is a namedtuple type.

  Args:
    o: A Python entity.
    check_call_override: Reserved for internal use. When set to `False`, it
      disables the rule according to which classes are whitelisted if their
      __call__ method is whitelisted.

  Returns:
    Boolean
  """
  # TODO(b/120224672): Fix this.
  if isinstance(o, functools.partial):
    # tf_inspect.getmodule(functools.partial(...)) otherwise returns None since
    # functools.partial objects do not have a __module__ attribute.
    m = functools
  else:
    m = tf_inspect.getmodule(o)

  if hasattr(m, '__name__'):
    # Builtins typically have unnamed modules.
    for prefix, in config.DEFAULT_UNCOMPILED_MODULES:
      # Match the package itself or a dotted submodule only, so that an
      # unrelated module sharing the same leading characters is not matched.
      if m.__name__.startswith(prefix + '.') or m.__name__ == prefix:
        logging.log(2, 'Whitelisted: %s: name starts with "%s"', o, prefix)
        return True

  if hasattr(o, 'autograph_info__') or hasattr(o, '__ag_compiled'):
    logging.log(2, 'Whitelisted: %s: already converted', o)
    return True

  if tf_inspect.isgeneratorfunction(o):
    # The entity is handed over as a lazy %-style argument, like the
    # `logging.log` calls in this function. Every positional argument after
    # the message must have a matching placeholder.
    logging.warn(
        'Entity %s appears to be a generator function. It will not be converted'
        ' by AutoGraph.', o)
    logging.log(2, 'Whitelisted: %s: generator functions are not converted', o)
    return True

  if check_call_override and hasattr(o, '__call__'):
    # Callable objects: whitelisted if their __call__ method is.
    # The type check avoids infinite recursion around the __call__ method
    # of function objects.
    if (type(o) != type(o.__call__)) and is_whitelisted_for_graph(o.__call__):  # pylint: disable=unidiomatic-typecheck
      logging.log(2, 'Whitelisted: %s: object __call__ whitelisted', o)
      return True

  owner_class = None
  if tf_inspect.ismethod(o):
    # Methods of whitelisted classes are also whitelisted, even if they are
    # bound via user subclasses.
    #
    # For example, suppose `tf.Foo` has a method called `bar`, and `baz` is
    # defined as below. `tf.Foo` is whitelisted. Then `baz.bar` is also
    # whitelisted.
    #
    # class Custom(tf.Foo):
    #   pass
    #
    # baz = Custom()
    #
    # For the example above, if `Custom` did overload `bar`, then it would no
    # longer be whitelisted.

    owner_class = inspect_utils.getmethodclass(o)
    if owner_class is not None:
      if issubclass(owner_class, unittest.TestCase):
        logging.log(2, 'Whitelisted: %s: method of TestCase subclass', o)
        return True

      owner_class = inspect_utils.getdefiningclass(o, owner_class)
      # When `o` is itself a `__call__` method, the owner must not be judged
      # through its `__call__` again; the override rule is switched off for
      # the recursive check in that case.
      is_call_override = (o.__name__ == '__call__')
      if is_whitelisted_for_graph(
          owner_class, check_call_override=not is_call_override):
        logging.log(2, 'Whitelisted: %s: owner is whitelisted %s', o,
                    owner_class)
        return True

  if inspect_utils.isnamedtuple(o):
    # Due to the way they're constructed, namedtuple types cannot be converted
    # because they don't expose source code. But we assume they are safe for
    # graph mode since they are just containers.
    if tf_inspect.isclass(o) and len(o.__bases__) > 1:
      logging.warn(
          'Entity %s looks like a namedtuple subclass. Its constructor will'
          ' not be converted by AutoGraph, but if it has any custom methods,'
          ' those will be.', o)
    logging.log(2, 'Whitelisted: %s: named tuple', o)
    return True

  logging.log(2, 'Not whitelisted: %s: default rule', o)
  return False
def is_whitelisted_for_graph(o):
  """Checks whether an entity is whitelisted for use in graph mode.

  Examples of whitelisted entities include all members of the tensorflow
  package.

  The rules are tried in order and the first match returns `True`:
    1. The entity's module is, or is nested under, one of the prefixes in
       `config.DEFAULT_UNCOMPILED_MODULES`.
    2. The entity's module is `tensorboard` or its name contains
       `.tensorboard` (temporary rule).
    3. The entity has an `autograph_info__` or `__ag_compiled` attribute.
    4. The entity is a generator function.
    5. The entity is a callable object whose `__call__` is whitelisted.
    6. The entity is a method of a `unittest.TestCase` subclass, or a method
       whose owner (as resolved by `inspect_utils.getdefiningclass`) is
       whitelisted.
    7. The entity is a namedtuple type.

  Args:
    o: A Python entity.

  Returns:
    Boolean
  """
  # TODO(b/120224672): Fix this.
  if isinstance(o, functools.partial):
    # tf_inspect.getmodule(functools.partial(...)) otherwise returns None since
    # functools.partial objects do not have a __module__ attribute.
    m = functools
  else:
    m = tf_inspect.getmodule(o)

  if hasattr(m, '__name__'):
    # Builtins typically have unnamed modules.
    module_name = m.__name__
    for prefix, in config.DEFAULT_UNCOMPILED_MODULES:
      # Match the package itself or a dotted submodule only. A bare
      # startswith(prefix) would also accept unrelated modules that merely
      # share the leading characters (e.g. `foo_bar` for prefix `foo`).
      if module_name == prefix or module_name.startswith(prefix + '.'):
        logging.log(2, 'Whitelisted: %s: name starts with "%s"', o, prefix)
        return True

    # Temporary -- whitelist tensorboard modules.
    # TODO(b/122731813): Remove.
    if module_name == 'tensorboard' or '.tensorboard' in module_name:
      logging.log(2, 'Whitelisted: %s: name contains "tensorboard"', o)
      return True

  if hasattr(o, 'autograph_info__') or hasattr(o, '__ag_compiled'):
    logging.log(2, 'Whitelisted: %s: already converted', o)
    return True

  if tf_inspect.isgeneratorfunction(o):
    # The entity is handed over as a lazy %-style argument, like the
    # `logging.log` calls in this function. Every positional argument after
    # the message must have a matching placeholder.
    logging.warn(
        'Entity %s appears to be a generator function. It will not be converted'
        ' by AutoGraph.', o)
    logging.log(2, 'Whitelisted: %s: generator functions are not converted', o)
    return True

  if hasattr(o, '__call__'):
    # Callable objects: whitelisted if their __call__ method is.
    # The type check avoids infinite recursion around the __call__ method
    # of function objects.
    if (type(o) != type(o.__call__)) and is_whitelisted_for_graph(o.__call__):  # pylint: disable=unidiomatic-typecheck
      logging.log(2, 'Whitelisted: %s: object __call__ whitelisted', o)
      return True

  owner_class = None
  if tf_inspect.ismethod(o):
    # Methods of whitelisted classes are also whitelisted, even if they are
    # bound via user subclasses.
    #
    # For example, suppose `tf.Foo` has a method called `bar`, and `baz` is
    # defined as below. `tf.Foo` is whitelisted. Then `baz.bar` is also
    # whitelisted.
    #
    # class Custom(tf.Foo):
    #   pass
    #
    # baz = Custom()
    #
    # For the example above, if `Custom` did overload `bar`, then it would no
    # longer be whitelisted.

    owner_class = inspect_utils.getmethodclass(o)
    if owner_class is not None:
      if issubclass(owner_class, unittest.TestCase):
        logging.log(2, 'Whitelisted: %s: method of TestCase subclass', o)
        return True

      owner_class = inspect_utils.getdefiningclass(o, owner_class)
      if is_whitelisted_for_graph(owner_class):
        logging.log(2, 'Whitelisted: %s: owner is whitelisted %s', o,
                    owner_class)
        return True

  if inspect_utils.isnamedtuple(o):
    # Due to the way they're constructed, namedtuple types cannot be converted
    # because they don't expose source code. But we assume they are safe for
    # graph mode since they are just containers.
    if tf_inspect.isclass(o) and len(o.__bases__) > 1:
      logging.warn(
          'Entity %s looks like a namedtuple subclass. Its constructor will'
          ' not be converted by AutoGraph, but if it has any custom methods,'
          ' those will be.', o)
    logging.log(2, 'Whitelisted: %s: named tuple', o)
    return True

  logging.log(2, 'Not whitelisted: %s: default rule', o)
  return False