def _rename_compilable_function(self, node):
  """Points a call node at the compiled counterpart of its callee.

  The callee is described by the `live_val` and `fqn` annotations on
  `node.func`, both of which must be present. Nodes annotated with
  `is_constructor` are always renamed via the namer's `compiled_class_name`;
  for any other call the namer's `compiled_function_name` decides both the new
  name and whether the rename happens at all.

  Args:
    node: The call node to rewrite (presumably an `ast.Call` - TODO confirm).

  Returns:
    `node`, modified in place when a rename is applied.
  """
  callee = node.func
  assert anno.hasanno(callee, 'live_val')
  assert anno.hasanno(callee, 'fqn')
  live_entity = anno.getanno(callee, 'live_val')
  fqn = anno.getanno(callee, 'fqn')

  if anno.hasanno(node, 'is_constructor'):
    replacement_name = self.ctx.namer.compiled_class_name(
        fqn, live_entity=live_entity)
    should_rename = True
  else:
    if anno.hasanno(callee, 'parent_type'):
      owner = anno.getanno(callee, 'parent_type')
    else:
      # Fallback - not reliable.
      owner = inspect_utils.getmethodclass(live_entity)
    replacement_name, should_rename = self.ctx.namer.compiled_function_name(
        fqn, live_entity=live_entity, owner_type=owner)

  if not should_rename:
    return node

  is_method = live_entity is not None and tf_inspect.ismethod(live_entity)
  if is_method:
    # The renaming process will transform the method into a regular function,
    # so the receiver expression (`node.func.value`) becomes an explicit first
    # positional argument.
    # TODO(mdan): Is this complete? How does it work with nested members?
    node.args = [callee.value] + node.args
  node.func = templates.replace_as_expression(
      'func_name', func_name=replacement_name)
  return node
def test_getmethodclass_callables(self):
  """A callable instance reports the class that defines `__call__`."""

  class TestCallable(object):

    def __call__(self):
      pass

  callable_instance = TestCallable()
  resolved_class = inspect_utils.getmethodclass(callable_instance)
  self.assertEqual(resolved_class, TestCallable)
def test_getmethodclass_callables(self):
  """getmethodclass of a callable object is that object's class."""

  class TestCallable(object):

    def __call__(self):
      pass

  instance = TestCallable()
  self.assertEqual(
      inspect_utils.getmethodclass(instance),
      TestCallable)
def test_getmethodclass_locals(self):
  """Locally defined functions and class members resolve as expected.

  In this variant, members accessed through an instance are expected to
  resolve to the instance itself.
  """

  def local_function():
    pass

  class LocalClass(object):

    def member_function(self):
      pass

    @decorator
    def decorated_member(self):
      pass

    @function_decorator()
    def fn_decorated_member(self):
      pass

    @wrapping_decorator()
    def wrap_decorated_member(self):
      pass

  member_names = ('member_function', 'decorated_member',
                  'fn_decorated_member', 'wrap_decorated_member')

  self.assertEqual(inspect_utils.getmethodclass(local_function), None)

  for member_name in member_names:
    unbound = getattr(LocalClass, member_name)
    self.assertEqual(inspect_utils.getmethodclass(unbound), LocalClass)

  test_obj = LocalClass()
  for member_name in member_names:
    bound = getattr(test_obj, member_name)
    self.assertEqual(inspect_utils.getmethodclass(bound), test_obj)
def test_getmethodclass_locals(self):
  """Locally defined functions and class members resolve as expected.

  In this variant, members accessed through either the class or an instance
  are expected to resolve to the class.
  """

  def local_function():
    pass

  class LocalClass(object):

    def member_function(self):
      pass

    @decorator
    def decorated_member(self):
      pass

    @function_decorator()
    def fn_decorated_member(self):
      pass

    @wrapping_decorator()
    def wrap_decorated_member(self):
      pass

  member_names = ('member_function', 'decorated_member',
                  'fn_decorated_member', 'wrap_decorated_member')

  self.assertEqual(inspect_utils.getmethodclass(local_function), None)

  for member_name in member_names:
    unbound = getattr(LocalClass, member_name)
    self.assertEqual(inspect_utils.getmethodclass(unbound), LocalClass)

  test_obj = LocalClass()
  for member_name in member_names:
    bound = getattr(test_obj, member_name)
    self.assertEqual(inspect_utils.getmethodclass(bound), LocalClass)
def test_getmethodclass_weakref_mechanism(self):
  """A method bound to a TfMethodTarget resolves to the target's class."""
  test_obj = TestClass()

  def test_fn(self):
    return self

  method_target = function.TfMethodTarget(weakref.ref(test_obj),
                                          test_obj.member_function)
  bound_method = types.MethodType(test_fn, method_target)
  resolved = inspect_utils.getmethodclass(bound_method)
  self.assertEqual(resolved, TestClass)
def test_getmethodclass_weakref_mechanism(self):
  """getmethodclass sees through a TfMethodTarget to the original class."""
  test_obj = TestClass()

  def test_fn(self):
    return self

  owner_ref = weakref.ref(test_obj)
  wrapped_target = function.TfMethodTarget(owner_ref, test_obj.member_function)
  self.assertEqual(
      inspect_utils.getmethodclass(types.MethodType(test_fn, wrapped_target)),
      TestClass)
def test_getmethodclass_weakref_mechanism(self):
  """A receiver exposing `ag_self_weakref__` resolves to the referent's class."""
  test_obj = TestClass()

  class WeakrefWrapper(object):

    def __init__(self):
      self.ag_self_weakref__ = weakref.ref(test_obj)

  def test_fn(self):
    return self

  receiver = WeakrefWrapper()
  bound_method = types.MethodType(test_fn, receiver)
  resolved = inspect_utils.getmethodclass(bound_method)
  self.assertEqual(resolved, TestClass)
def test_getmethodclass_weakref_mechanism(self):
  """A receiver exposing `ag_self_weakref__` resolves to the referent itself."""
  test_obj = TestClass()

  class WeakrefWrapper(object):

    def __init__(self):
      self.ag_self_weakref__ = weakref.ref(test_obj)

  def test_fn(self):
    return self

  receiver = WeakrefWrapper()
  bound_method = types.MethodType(test_fn, receiver)
  resolved = inspect_utils.getmethodclass(bound_method)
  self.assertEqual(resolved, test_obj)
def converted_call(f, owner, options, *args, **kwargs):
  """Compiles a function call inline. For internal use only.

  Args:
    f: The callable to convert and call or, when `owner` is given, the name
      (a string) of the attribute of `owner` to look up and call.
    owner: Optional object on which `f` is looked up by name. May be a `super`
      object, which is wrapped so that dynamic attribute lookups work.
    options: Conversion options. The attributes read here are
      `force_conversion`, `recursive`, `verbose` and `strip_decorators`.
    *args: Positional arguments for the call.
    **kwargs: Keyword arguments for the call.

  Returns:
    The result of the call. Whitelisted entities are called directly,
    builtins go through `py_builtins.overload_of`, and everything else is
    called through the version produced by `to_graph`.

  Raises:
    ValueError: if `owner` is given but `f` is not a string.
    NotImplementedError: if `f` is not a function, method, class or object
      with a `__call__` attribute.
  """
  if owner is not None:
    if not isinstance(f, str):
      raise ValueError(
          'When owner is specified, the function name must be specified as'
          ' a string: {}'.format(f))

    # Special case when the owner is a 'super' object. In that case lookups of
    # dynamic attributes won't work. See
    # inspect_utils.SuperWrapperForDynamicAttrs.
    if isinstance(owner, super):
      owner = inspect_utils.SuperWrapperForDynamicAttrs(owner)

    f = getattr(owner, f)

  # TODO(mdan): This needs cleanup.
  # In particular, we may want to avoid renaming functions altogether.
  if not options.force_conversion and conversion.is_whitelisted_for_graph(f):
    return f(*args, **kwargs)

  if inspect_utils.isbuiltin(f):
    return py_builtins.overload_of(f)(*args, **kwargs)

  if tf_inspect.isfunction(f) or tf_inspect.ismethod(f):
    # Regular functions
    target_entity = f
    arg_map_target = f
    f_class = inspect_utils.getmethodclass(f)

    if f_class is not None:
      # If this is a method call, it may or may not include self.
      #
      # Example when self is included:
      #   converted_call(to_graph(foo.bar), foo)
      #
      # Example when self is not included:
      #   super(...).foo(args)
      #
      if owner is not None and (not args or args[0] is not owner):
        effective_args = (owner,) + args
      else:
        effective_args = args
      partial_types = (f_class,)
    else:
      effective_args = args
      partial_types = ()

  elif tf_inspect.isclass(f):
    # Constructors
    target_entity = f
    arg_map_target = f.__init__
    effective_args = args
    partial_types = ()

  elif hasattr(f, '__call__') and hasattr(f, '__class__'):
    # Callable objects
    target_entity = f.__call__
    arg_map_target = f.__call__
    effective_args = (f,) + args
    partial_types = (f.__class__,)

  else:
    # Previously the exception was constructed but never raised, which let
    # execution reach `getcallargs` with `arg_map_target` unbound.
    raise NotImplementedError('unknown callable type "%s"' % type(f))

  arg_values = tf_inspect.getcallargs(arg_map_target, *args, **kwargs)
  arg_types = {
      name: (arg.__class__.__name__, arg.__class__)
      for name, arg in arg_values.items()
  }

  # When called from within a decorator, this is the only indication that
  # the function is a method - it appears that the decorator is applied
  # before the method is bound.
  if not partial_types:
    if 'self' in arg_values:
      if tf_inspect.isclass(arg_values['self'].__class__):
        partial_types = (arg_values['self'].__class__,)
    elif 'cls' in arg_values:
      if tf_inspect.isclass(arg_values['cls']):
        partial_types = (arg_values['cls'],)

  converted_f = to_graph(
      target_entity,
      recursive=options.recursive,
      verbose=options.verbose,
      arg_values=arg_values,
      arg_types=arg_types,
      partial_types=partial_types,
      strip_decorators=options.strip_decorators)
  return converted_f(*effective_args, **kwargs)
def test_getmethodclass(self):
  """getmethodclass resolves free functions, class members and bound members.

  In this variant, regular members accessed through an instance are expected
  to resolve to the instance, while static and class methods resolve to the
  class.
  """
  self.assertEqual(inspect_utils.getmethodclass(free_function), None)
  self.assertEqual(inspect_utils.getmethodclass(free_factory()), None)

  instance_members = ('member_function', 'decorated_member',
                      'fn_decorated_member', 'wrap_decorated_member')
  class_level_members = ('static_method', 'class_method')

  for member_name in instance_members + class_level_members:
    via_class = getattr(TestClass, member_name)
    self.assertEqual(inspect_utils.getmethodclass(via_class), TestClass)

  test_obj = TestClass()
  for member_name in instance_members:
    via_instance = getattr(test_obj, member_name)
    self.assertEqual(inspect_utils.getmethodclass(via_instance), test_obj)
  for member_name in class_level_members:
    via_instance = getattr(test_obj, member_name)
    self.assertEqual(inspect_utils.getmethodclass(via_instance), TestClass)
def test_getmethodclass_no_bool_conversion(self):
  """getmethodclass of a Tensor's bound `get_shape` is the Tensor's type.

  Judging by the test name, this presumably guards against getmethodclass
  evaluating the owner in a boolean context.
  """
  tensor = constant_op.constant([1])
  resolved = inspect_utils.getmethodclass(tensor.get_shape)
  self.assertEqual(resolved, type(tensor))
def is_whitelisted_for_graph(o):
  """Checks whether an entity is whitelisted for use in graph mode.

  Examples of whitelisted entities include all members of the tensorflow
  package.

  The rules are tried in order:
    - the entity's module name starts with a prefix from
      `config.DEFAULT_UNCOMPILED_MODULES`, or is a tensorboard module;
    - the entity was already converted;
    - the entity is a generator function;
    - the entity's `__call__` is whitelisted;
    - the entity is a method of a `unittest.TestCase` subclass or of a
      whitelisted class;
    - the entity is a named tuple.

  Args:
    o: A Python entity.

  Returns:
    Boolean
  """
  # TODO(b/120224672): Fix this.
  if isinstance(o, functools.partial):
    # tf_inspect.getmodule(functools.partial(...)) otherwise returns None since
    # functools.partial objects do not have a __module__ attribute.
    m = functools
  else:
    m = tf_inspect.getmodule(o)

  if hasattr(m, '__name__'):
    # Builtins typically have unnamed modules.
    for prefix, in config.DEFAULT_UNCOMPILED_MODULES:
      if m.__name__.startswith(prefix):
        logging.log(2, 'Whitelisted: %s: name starts with "%s"', o, prefix)
        return True

    # Temporary -- whitelist tensorboard modules.
    # TODO(b/122731813): Remove.
    if m.__name__ == 'tensorboard' or '.tensorboard' in m.__name__:
      logging.log(2, 'Whitelisted: %s: name contains "tensorboard"', o)
      return True

  if hasattr(o, 'autograph_info__') or hasattr(o, '__ag_compiled'):
    logging.log(2, 'Whitelisted: %s: already converted', o)
    return True

  if tf_inspect.isgeneratorfunction(o):
    # The entity is passed as a lazy %-argument. The previous form passed a
    # stray `1` to a message without placeholders, which makes the logging
    # module fail when formatting the record.
    logging.warn(
        'Entity %s appears to be a generator function. It will not be converted'
        ' by AutoGraph.', o)
    logging.log(2, 'Whitelisted: %s: generator functions are not converted', o)
    return True

  if hasattr(o, '__call__'):
    # Callable objects: whitelisted if their __call__ method is.
    # The type check avoids infinite recursion around the __call__ method
    # of function objects.
    if (type(o) != type(o.__call__)) and is_whitelisted_for_graph(o.__call__):  # pylint: disable=unidiomatic-typecheck
      logging.log(2, 'Whitelisted: %s: object __call__ whitelisted', o)
      return True

  if tf_inspect.ismethod(o):
    # Methods of whitelisted classes are also whitelisted, even if they are
    # bound via user subclasses.
    #
    # For example, suppose `tf.Foo` has a method called `bar`, and `baz` is
    # defined as below. `tf.Foo` is whitelisted. Then `baz.bar` is also
    # whitelisted.
    #
    #   class Custom(tf.Foo):
    #     pass
    #
    #   baz = Custom()
    #
    # For the example above, if `Custom` did overload `bar`, then it would no
    # longer be whitelisted.

    owner_class = inspect_utils.getmethodclass(o)
    if owner_class is not None:
      if issubclass(owner_class, unittest.TestCase):
        logging.log(2, 'Whitelisted: %s: method of TestCase subclass', o)
        return True

      owner_class = inspect_utils.getdefiningclass(o, owner_class)
      if is_whitelisted_for_graph(owner_class):
        logging.log(2, 'Whitelisted: %s: owner is whitelisted %s', o,
                    owner_class)
        return True

  if inspect_utils.isnamedtuple(o):
    # Due to the way they're constructed, namedtuple types cannot be converted
    # because they don't expose source code. But we assume they are safe for
    # graph mode since they are just containers.
    if tf_inspect.isclass(o) and len(o.__bases__) > 1:
      logging.warn(
          'Entity %s looks like a namedtuple subclass. Its constructor will'
          ' not be converted by AutoGraph, but if it has any custom methods,'
          ' those will be.', o)
    logging.log(2, 'Whitelisted: %s: named tuple', o)
    return True

  logging.log(2, 'Not whitelisted: %s: default rule', o)
  return False
def converted_call(f, owner, options, *args, **kwargs):
  """Compiles a function call inline. For internal use only.

  Args:
    f: The callable to convert and call or, when `owner` is given, the name
      (a string) of the attribute of `owner` to look up and call.
    owner: Optional object on which `f` is looked up by name. May be a `super`
      object, which is wrapped so that dynamic attribute lookups work.
    options: Conversion options. The attributes read here are
      `force_conversion`, `internal_convert_user_code`, `recursive`,
      `optional_features`, `strip_decorators` and `verbose`.
    *args: Positional arguments for the call.
    **kwargs: Keyword arguments for the call.

  Returns:
    The result of the call. Builtins go through `py_builtins.overload_of`,
    whitelisted entities (and all entities when
    `options.internal_convert_user_code` is false) are called directly, and
    everything else is called through the version produced by `to_graph`.

  Raises:
    ValueError: if `owner` is given but `f` is not a string.
    NotImplementedError: if `f` is not a function, method, class or object
      with a `__call__` attribute.
  """
  logging.vlog(logging.DEBUG, 'Converted call: %s; owner: %s', f, owner)

  if owner is not None:
    if not isinstance(f, str):
      raise ValueError(
          'When owner is specified, the function name must be specified as'
          ' a string: {}'.format(f))

    # Special case when the owner is a 'super' object. In that case lookups of
    # dynamic attributes won't work. See
    # inspect_utils.SuperWrapperForDynamicAttrs.
    if isinstance(owner, super):
      owner = inspect_utils.SuperWrapperForDynamicAttrs(owner)

    f = getattr(owner, f)

  if inspect_utils.isbuiltin(f):
    return py_builtins.overload_of(f)(*args, **kwargs)

  # TODO(mdan): This needs cleanup.
  # In particular, we may want to avoid renaming functions altogether.
  if not options.force_conversion and conversion.is_whitelisted_for_graph(f):

    # Args typically include `self`, as required by the conversion process.
    # When conversion is skipped, `self` is not necessary, because the
    # original bound method is being executed. This code removes it.
    if tf_inspect.ismethod(f) and args:
      f_class = inspect_utils.getmethodclass(f)
      if args[0] is f_class:
        args = args[1:]

    return f(*args, **kwargs)

  # internal_convert_user_code is for example turned off when issuing a dynamic
  # call conversion from generated code while in nonrecursive mode. In that
  # case we evidently don't want to recurse, but we still have to convert
  # things like builtins.
  if not options.internal_convert_user_code:
    return f(*args, **kwargs)

  # Unwrap functools.partial objects. Explicit call-site kwargs take
  # precedence over the keywords stored in the partial.
  # TODO(mdan): Consider sharing unwrapping logic with tf_inspect.
  while isinstance(f, functools.partial):
    args = f.args + args
    new_kwargs = dict(f.keywords or {})
    new_kwargs.update(kwargs)
    kwargs = new_kwargs
    f = f.func

  if tf_inspect.isfunction(f) or tf_inspect.ismethod(f):
    # Regular functions
    target_entity = f
    arg_map_target = f
    f_class = inspect_utils.getmethodclass(f)

    # TODO(b/119246461): This may be more elegantly handled using __get__?
    if f_class is not None:
      # If this is a method call, it may or may not include self.
      #
      # Example when self is included:
      #   converted_call(to_graph(foo.bar), foo)
      #
      # Example when self is not included:
      #   super(...).foo(args)
      #
      if owner is not None and (not args or args[0] is not owner):
        effective_args = (owner,) + args
      else:
        # When the owner is not specified, use the result of
        # inspect_utils.getmethodclass.
        # TODO(b/119246461): Make sure an owner is always specified.
        if not args or args[0] is not f_class:
          effective_args = (f_class,) + args
        else:
          effective_args = (f_class,) + args[1:]
      partial_types = (f_class,)
    else:
      effective_args = args
      partial_types = ()

  elif tf_inspect.isclass(f):
    # Constructors
    target_entity = f
    arg_map_target = f.__init__
    effective_args = args
    partial_types = ()

  elif hasattr(f, '__call__') and hasattr(f, '__class__'):
    # Callable objects
    target_entity = f.__call__
    arg_map_target = f.__call__
    effective_args = (f,) + args
    partial_types = (f.__class__,)

  else:
    # Previously the exception was constructed but never raised, which let
    # execution reach `getcallargs` with `arg_map_target` unbound.
    raise NotImplementedError('unknown callable type "%s"' % type(f))

  arg_values = tf_inspect.getcallargs(arg_map_target, *args, **kwargs)
  arg_types = {
      name: (arg.__class__.__name__, arg.__class__)
      for name, arg in arg_values.items()
  }

  # When called from within a decorator, this is the only indication that
  # the function is a method - it appears that the decorator is applied
  # before the method is bound.
  if not partial_types:
    if 'self' in arg_values:
      if tf_inspect.isclass(arg_values['self'].__class__):
        partial_types = (arg_values['self'].__class__,)
    elif 'cls' in arg_values:
      if tf_inspect.isclass(arg_values['cls']):
        partial_types = (arg_values['cls'],)

  converted_f = to_graph(
      target_entity,
      recursive=options.recursive,
      arg_values=arg_values,
      arg_types=arg_types,
      experimental_optional_features=options.optional_features,
      experimental_strip_decorators=options.strip_decorators,
      experimental_verbose=options.verbose,
      experimental_partial_types=partial_types)

  result = converted_f(*effective_args, **kwargs)

  # The converted function's closure is simply inserted into the function's
  # module __dict__. Since modules are permanently cached, that results in
  # leaking the entire closure.
  # Normally, it's not safe to delete the module because that may release said
  # closure as well. However, in the case of converted_call we are certain the
  # function will not be executed again, so the closure should no longer be
  # needed so long as the function doesn't return any executable code.
  # TODO(mdan): Attach the closure properly, using cells.
  if all(map(_is_not_callable, nest.flatten(result))):
    del sys.modules[converted_f.__module__]

  return result
def test_getmethodclass(self):
  """getmethodclass resolves free functions, class members and bound members.

  In this variant, every member resolves to the defining class, whether it
  is accessed through the class or through an instance.
  """
  self.assertEqual(inspect_utils.getmethodclass(free_function), None)
  self.assertEqual(inspect_utils.getmethodclass(free_factory()), None)

  member_names = ('member_function', 'decorated_member', 'fn_decorated_member',
                  'wrap_decorated_member', 'static_method', 'class_method')

  for member_name in member_names:
    via_class = getattr(TestClass, member_name)
    self.assertEqual(inspect_utils.getmethodclass(via_class), TestClass)

  test_obj = TestClass()
  for member_name in member_names:
    via_instance = getattr(test_obj, member_name)
    self.assertEqual(inspect_utils.getmethodclass(via_instance), TestClass)
def converted_call(f, recursive, verbose, force_conversion, arg_types, *args,
                   **kwargs):
  """Compiles a function call inline. For internal use only.

  Args:
    f: The callable to convert and call.
    recursive: Forwarded to `to_graph` as `recursive`.
    verbose: Forwarded to `to_graph` as `verbose`.
    force_conversion: When false, whitelisted entities are called directly
      without conversion.
    arg_types: Optional dict of type hints keyed by argument name. Hints take
      precedence over the types inferred from the actual arguments. The dict
      is not modified. `None` is treated as no hints.
    *args: Positional arguments for the call.
    **kwargs: Keyword arguments for the call.

  Returns:
    The result of the call. Whitelisted entities are called directly,
    builtins go through `py_builtins.overload_of`, and everything else is
    called through the version produced by `to_graph`.

  Raises:
    NotImplementedError: if `f` is not a function, method, class or object
      with a `__call__` attribute.
  """
  # TODO(mdan): This needs cleanup.
  # In particular, we may want to avoid renaming functions altogether.
  if not force_conversion and conversion.is_whitelisted_for_graph(f):
    return f(*args, **kwargs)

  if inspect_utils.isbuiltin(f):
    return py_builtins.overload_of(f)(*args, **kwargs)

  if tf_inspect.isfunction(f) or tf_inspect.ismethod(f):
    # Regular functions
    target_entity = f
    arg_map_target = f
    effective_args = args
    f_class = inspect_utils.getmethodclass(f)

    if f_class is not None:
      partial_types = (f_class,)
    else:
      partial_types = ()

  elif tf_inspect.isclass(f):
    # Constructors
    target_entity = f
    arg_map_target = f.__init__
    effective_args = args
    partial_types = ()

  elif hasattr(f, '__call__') and hasattr(f, '__class__'):
    # Callable objects
    target_entity = f.__call__
    arg_map_target = f.__call__
    effective_args = (f,) + args
    partial_types = (f.__class__,)

  else:
    # Previously the exception was constructed but never raised, which let
    # execution reach `getcallargs` with `arg_map_target` unbound.
    raise NotImplementedError('unknown callable type "%s"' % type(f))

  arg_values = tf_inspect.getcallargs(arg_map_target, *args, **kwargs)

  # Work on a copy. Previously the caller's dict was filled in place, so a
  # reused dict carried types inferred during one call into later calls, where
  # they silently overrode the real argument types.
  arg_types = dict(arg_types or {})
  for name, arg in arg_values.items():
    # If the caller-supplied arg_types specifies any name, use that instead.
    if name not in arg_types:
      arg_class = arg.__class__
      arg_types[name] = (arg_class.__name__, arg_class)

  # When called from within a decorator, this is the only indication that
  # the function is a method - it appears that the decorator is applied
  # before the method is bound.
  if not partial_types:
    if 'self' in arg_values:
      if tf_inspect.isclass(arg_values['self'].__class__):
        partial_types = (arg_values['self'].__class__,)
    elif 'cls' in arg_values:
      if tf_inspect.isclass(arg_values['cls']):
        partial_types = (arg_values['cls'],)

  converted_f = to_graph(
      target_entity,
      recursive=recursive,
      verbose=verbose,
      arg_values=arg_values,
      arg_types=arg_types,
      partial_types=partial_types)
  return converted_f(*effective_args, **kwargs)
def test_getmethodclass_no_bool_conversion(self):
  """getmethodclass of a Tensor's bound `get_shape` is the Tensor itself.

  Judging by the test name, this presumably guards against getmethodclass
  evaluating the owner in a boolean context.
  """
  tensor = constant_op.constant([1])
  bound_get_shape = tensor.get_shape
  self.assertEqual(inspect_utils.getmethodclass(bound_get_shape), tensor)
def is_whitelisted(o, check_call_override=True, allow_namedtuple_subclass=False): """Checks whether an entity is whitelisted for use in graph mode. Examples of whitelisted entities include all members of the tensorflow package. Args: o: A Python entity. check_call_override: Reserved for internal use. When set to `False`, it disables the rule according to which classes are whitelisted if their __call__ method is whitelisted. allow_namedtuple_subclass: Reserved for internal use. When `True`, namedtuple subclasses are not whitelisted. Returns: Boolean """ # TODO(b/120224672): Fix this. if isinstance(o, functools.partial): # tf_inspect.getmodule(functools.partial(...)) otherwise returns None since # functools.partial objects do not have a __module__ attribute. m = functools else: m = tf_inspect.getmodule(o) # Examples of callables that lack a __module__ property include builtins. if hasattr(m, '__name__'): for rule in config.CONVERSION_RULES: action = rule.get_action(m) if action == config.Action.CONVERT: logging.log(2, 'Not whitelisted: %s: %s', o, rule) return False elif action == config.Action.DO_NOT_CONVERT: logging.log(2, 'Whitelisted: %s: %s', o, rule) return True if tf_inspect.isgeneratorfunction(o): logging.warn( 'Entity %s appears to be a generator function. It will not be converted' ' by AutoGraph.', o) logging.log(2, 'Whitelisted: %s: generator functions are not converted', o) return True if (check_call_override and not tf_inspect.isclass(o) and hasattr(o, '__call__')): # Callable objects: whitelisted if their __call__ method is. # The type check avoids infinite recursion around the __call__ method # of function objects. if (type(o) != type(o.__call__)) and is_whitelisted(o.__call__): # pylint: disable=unidiomatic-typecheck logging.log(2, 'Whitelisted: %s: object __call__ whitelisted', o) return True owner_class = None if tf_inspect.ismethod(o): # Methods of whitelisted classes are also whitelisted, even if they are # bound via user subclasses. 
# # For example, suppose `tf.Foo` has a method called `bar`, and `baz` is # defined as below. `tf.Foo` is whitelisted. Then `baz.bar` is also # whitelisted. # # class Custom(tf.Foo): # pass # # baz = Custom() # # For the example above, if `Custom` did overload `bar`, then it would no # longer be whitelisted. owner_class = inspect_utils.getmethodclass(o) if owner_class is function.TfMethodTarget: owner_class = o.__self__.target_class if owner_class is not None: if issubclass(owner_class, unittest.TestCase): logging.log(2, 'Whitelisted: %s: method of TestCase subclass', o) return True owner_class = inspect_utils.getdefiningclass(o, owner_class) if is_whitelisted(owner_class, check_call_override=False, allow_namedtuple_subclass=True): logging.log(2, 'Whitelisted: %s: owner is whitelisted %s', o, owner_class) return True if inspect_utils.isnamedtuple(o): # Due to the way they're constructed, namedtuple types cannot be converted # because they don't expose source code. But we assume they are safe for # graph mode since they are just containers. if allow_namedtuple_subclass: if not any( inspect_utils.isnamedtuple(base) for base in o.__bases__): logging.log(2, 'Whitelisted: %s: named tuple', o) return True else: logging.log(2, 'Whitelisted: %s: named tuple or subclass', o) return True logging.log(2, 'Not whitelisted: %s: default rule', o) return False
def converted_call(f, owner, options, *args, **kwargs):
  """Compiles a function call inline. For internal use only.

  Args:
    f: The callable to convert and call or, when `owner` is given, the name
      (a string) of the attribute of `owner` to look up and call.
    owner: Optional object on which `f` is looked up by name. May be a `super`
      object, which is wrapped so that dynamic attribute lookups work.
    options: Conversion options. The attributes read here are `verbose`,
      `force_conversion`, `internal_convert_user_code`, `recursive`,
      `strip_decorators` and `optional_features`.
    *args: Positional arguments for the call.
    **kwargs: Keyword arguments for the call.

  Returns:
    The result of the call. Builtins go through `py_builtins.overload_of`,
    whitelisted entities (and all entities when
    `options.internal_convert_user_code` is false) are called directly, and
    everything else is called through the version produced by `to_graph`.

  Raises:
    ValueError: if `owner` is given but `f` is not a string.
    NotImplementedError: if `f` is not a function, method, class or object
      with a `__call__` attribute.
  """
  if options.verbose >= converter.Verbosity.VERBOSE:
    logging.info('Converted call: {}; owner: {}'.format(f, owner))

  if owner is not None:
    if not isinstance(f, str):
      raise ValueError(
          'When owner is specified, the function name must be specified as'
          ' a string: {}'.format(f))

    # Special case when the owner is a 'super' object. In that case lookups of
    # dynamic attributes won't work. See
    # inspect_utils.SuperWrapperForDynamicAttrs.
    if isinstance(owner, super):
      owner = inspect_utils.SuperWrapperForDynamicAttrs(owner)

    f = getattr(owner, f)

  if inspect_utils.isbuiltin(f):
    return py_builtins.overload_of(f)(*args, **kwargs)

  # TODO(mdan): This needs cleanup.
  # In particular, we may want to avoid renaming functions altogether.
  if not options.force_conversion and conversion.is_whitelisted_for_graph(f):

    # Args typically include `self`, as required by the conversion process.
    # When conversion is skipped, `self` is not necessary, because the
    # original bound method is being executed. This code removes it.
    if tf_inspect.ismethod(f) and args:
      f_class = inspect_utils.getmethodclass(f)
      if args[0] is f_class:
        args = args[1:]

    return f(*args, **kwargs)

  # internal_convert_user_code is for example turned off when issuing a dynamic
  # call conversion from generated code while in nonrecursive mode. In that
  # case we evidently don't want to recurse, but we still have to convert
  # things like builtins.
  if not options.internal_convert_user_code:
    return f(*args, **kwargs)

  # Unwrap functools.partial objects. Explicit call-site kwargs take
  # precedence over the keywords stored in the partial.
  # TODO(allenl, mdan): Consider sharing unwrapping logic with tf_inspect.
  while isinstance(f, functools.partial):
    args = f.args + args
    new_kwargs = dict(f.keywords or {})
    new_kwargs.update(kwargs)
    kwargs = new_kwargs
    f = f.func

  if tf_inspect.isfunction(f) or tf_inspect.ismethod(f):
    # Regular functions
    target_entity = f
    arg_map_target = f
    f_class = inspect_utils.getmethodclass(f)

    # TODO(b/119246461): This may be more elegantly handled using __get__?
    if f_class is not None:
      # If this is a method call, it may or may not include self.
      #
      # Example when self is included:
      #   converted_call(to_graph(foo.bar), foo)
      #
      # Example when self is not included:
      #   super(...).foo(args)
      #
      if owner is not None and (not args or args[0] is not owner):
        effective_args = (owner,) + args
      else:
        # When the owner is not specified, use the result of
        # inspect_utils.getmethodclass.
        # TODO(b/119246461): Make sure an owner is always specified.
        if not args or args[0] is not f_class:
          effective_args = (f_class,) + args
        else:
          effective_args = (f_class,) + args[1:]
      partial_types = (f_class,)
    else:
      effective_args = args
      partial_types = ()

  elif tf_inspect.isclass(f):
    # Constructors
    target_entity = f
    arg_map_target = f.__init__
    effective_args = args
    partial_types = ()

  elif hasattr(f, '__call__') and hasattr(f, '__class__'):
    # Callable objects
    target_entity = f.__call__
    arg_map_target = f.__call__
    effective_args = (f,) + args
    partial_types = (f.__class__,)

  else:
    # Previously the exception was constructed but never raised, which let
    # execution reach `getcallargs` with `arg_map_target` unbound.
    raise NotImplementedError('unknown callable type "%s"' % type(f))

  arg_values = tf_inspect.getcallargs(arg_map_target, *args, **kwargs)
  arg_types = {
      name: (arg.__class__.__name__, arg.__class__)
      for name, arg in arg_values.items()
  }

  # When called from within a decorator, this is the only indication that
  # the function is a method - it appears that the decorator is applied
  # before the method is bound.
  if not partial_types:
    if 'self' in arg_values:
      if tf_inspect.isclass(arg_values['self'].__class__):
        partial_types = (arg_values['self'].__class__,)
    elif 'cls' in arg_values:
      if tf_inspect.isclass(arg_values['cls']):
        partial_types = (arg_values['cls'],)

  converted_f = to_graph(
      target_entity,
      recursive=options.recursive,
      verbose=options.verbose,
      arg_values=arg_values,
      arg_types=arg_types,
      partial_types=partial_types,
      strip_decorators=options.strip_decorators,
      optional_features=options.optional_features)

  result = converted_f(*effective_args, **kwargs)

  # The converted function's closure is simply inserted into the function's
  # module __dict__. Since modules are permanently cached, that results in
  # leaking the entire closure.
  # Normally, it's not safe to delete the module because that may release said
  # closure as well. However, in the case of converted_call we are certain the
  # function will not be executed again, so the closure should no longer be
  # needed so long as the function doesn't return any executable code.
  # TODO(mdan): Attach the closure properly, using cells.
  if all(map(_is_not_callable, nest.flatten(result))):
    del sys.modules[converted_f.__module__]

  return result
def is_whitelisted_for_graph(o):
  """Check whether an entity is whitelisted for use in graph mode.

  Examples of whitelisted entities include all members of the tensorflow
  package.

  The rules are tried in order:
    - entities whose module has no name are not whitelisted;
    - the module name starts with a prefix from
      `config.DEFAULT_UNCOMPILED_MODULES`;
    - the entity was already converted;
    - the entity is a callable object whose `__call__` is whitelisted;
    - the entity is a method whose defining class is whitelisted;
    - the entity is a named tuple.

  Args:
    o: A Python entity.

  Returns:
    Boolean
  """
  # TODO(b/120224672): Fix this.
  if isinstance(o, functools.partial):
    # tf_inspect.getmodule(functools.partial(...)) otherwise returns None since
    # functools.partial objects do not have a __module__ attribute.
    m = functools
  else:
    m = tf_inspect.getmodule(o)
  if not hasattr(m, '__name__'):
    # Note: typically it's builtins that fall in this category. Builtins will
    # be handled by specific code that follows this screening layer.
    logging.log(2, '%s is NOT whitelisted: unknown module name', o)
    return False

  for prefix, in config.DEFAULT_UNCOMPILED_MODULES:
    if m.__name__.startswith(prefix):
      logging.log(2, '%s is whitelisted: name starts with "%s"', o, prefix)
      return True

  if hasattr(o, 'autograph_info__') or hasattr(o, '__ag_compiled'):
    logging.log(2, '%s is whitelisted: already converted', o)
    return True

  if (not inspect_utils.isweakrefself(o) and not tf_inspect.isclass(o) and
      hasattr(o, '__call__') and hasattr(o, '__class__')):
    # Callable objects: whitelisted if their __call__ method is. Only a
    # positive answer is final here. Bound methods also have `__call__`, so
    # returning False at this point would make the method and named-tuple
    # rules below unreachable.
    if is_whitelisted_for_graph(o.__call__):
      logging.log(2, '%s is whitelisted: object __call__ whitelisted', o)
      return True

  if tf_inspect.ismethod(o):
    # Methods of whitelisted classes are also whitelisted, even if they are
    # bound via user subclasses.
    #
    # For example, suppose `tf.Foo` has a method called `bar`, and `baz` is
    # defined as below. `tf.Foo` is whitelisted. Then `baz.bar` is also
    # whitelisted.
    #
    #   class Custom(tf.Foo):
    #     pass
    #
    #   baz = Custom()
    #
    # For the example above, if `Custom` did overload `bar`, then it would no
    # longer be whitelisted.

    owner_class = inspect_utils.getmethodclass(o)
    if owner_class is not None:
      owner_class = inspect_utils.getdefiningclass(o, owner_class)
      if is_whitelisted_for_graph(owner_class):
        logging.log(2, '%s is whitelisted: owner is whitelisted %s', o,
                    owner_class)
        return True

  if inspect_utils.isnamedtuple(o):
    # Due to the way they're constructed, namedtuple types cannot be converted
    # because they don't expose source code. But we assume they are safe for
    # graph mode since they are just containers.
    if tf_inspect.isclass(o) and len(o.__bases__) > 1:
      logging.warn_first_n(
          'Entity {} looks like a namedtuple subclass. If it has any custom'
          ' methods, they will not be converted by AutoGraph.'.format(o), 1)
    logging.log(2, '%s is whitelisted: named tuple', o)
    return True

  logging.log(2, '%s is NOT whitelisted', o)
  return False
def is_whitelisted_for_graph(o, check_call_override=True):
  """Checks whether an entity is whitelisted for use in graph mode.

  Examples of whitelisted entities include all members of the tensorflow
  package.

  The rules below are tried in order and the first match wins:
    * the entity's module is one of `config.DEFAULT_UNCOMPILED_MODULES`, or a
      submodule of one;
    * the entity has an `autograph_info__` or `__ag_compiled` attribute;
    * the entity is a generator function (these are never converted);
    * the entity's `__call__` is whitelisted (unless `check_call_override` is
      False);
    * the entity is a method of a `unittest.TestCase` subclass, or a method
      whose defining class is whitelisted;
    * the entity is namedtuple-like (`inspect_utils.isnamedtuple`).

  Args:
    o: A Python entity.
    check_call_override: Reserved for internal use. When set to `False`, it
      disables the rule according to which classes are whitelisted if their
      __call__ method is whitelisted.

  Returns:
    Boolean
  """
  # TODO(b/120224672): Fix this.
  if isinstance(o, functools.partial):
    # tf_inspect.getmodule(functools.partial(...)) otherwise returns None since
    # functools.partial objects do not have a __module__ attribute.
    m = functools
  else:
    m = tf_inspect.getmodule(o)

  if hasattr(m, '__name__'):
    # Builtins typically have unnamed modules.
    for prefix, in config.DEFAULT_UNCOMPILED_MODULES:
      if m.__name__.startswith(prefix + '.') or m.__name__ == prefix:
        logging.log(2, 'Whitelisted: %s: name starts with "%s"', o, prefix)
        return True

  if hasattr(o, 'autograph_info__') or hasattr(o, '__ag_compiled'):
    logging.log(2, 'Whitelisted: %s: already converted', o)
    return True

  if tf_inspect.isgeneratorfunction(o):
    # `o` is passed as a lazy %-style argument, like the other log calls here.
    # No extra positional arguments may follow it, because every argument is
    # consumed by the message's %-formatting; pre-formatting repr(o) into the
    # message would also break if the repr contained a '%'.
    logging.warn(
        'Entity %s appears to be a generator function. It will not be converted'
        ' by AutoGraph.', o)
    logging.log(2, 'Whitelisted: %s: generator functions are not converted', o)
    return True

  if check_call_override and hasattr(o, '__call__'):
    # Callable objects: whitelisted if their __call__ method is.
    # The type check avoids infinite recursion around the __call__ method
    # of function objects.
    if (type(o) != type(o.__call__)) and is_whitelisted_for_graph(  # pylint: disable=unidiomatic-typecheck
        o.__call__):
      logging.log(2, 'Whitelisted: %s: object __call__ whitelisted', o)
      return True

  if tf_inspect.ismethod(o):
    # Methods of whitelisted classes are also whitelisted, even if they are
    # bound via user subclasses.
    #
    # For example, suppose `tf.Foo` has a method called `bar`, and `baz` is
    # defined as below. `tf.Foo` is whitelisted. Then `baz.bar` is also
    # whitelisted.
    #
    #   class Custom(tf.Foo):
    #     pass
    #
    #   baz = Custom()
    #
    # For the example above, if `Custom` did overload `bar`, then it would no
    # longer be whitelisted.
    owner_class = inspect_utils.getmethodclass(o)
    if owner_class is not None:
      if issubclass(owner_class, unittest.TestCase):
        logging.log(2, 'Whitelisted: %s: method of TestCase subclass', o)
        return True

      owner_class = inspect_utils.getdefiningclass(o, owner_class)
      # When `o` is itself a __call__ override, the owner must not be
      # whitelisted merely on account of that same __call__.
      is_call_override = (o.__name__ == '__call__')
      if is_whitelisted_for_graph(
          owner_class, check_call_override=not is_call_override):
        logging.log(2, 'Whitelisted: %s: owner is whitelisted %s', o,
                    owner_class)
        return True

  if inspect_utils.isnamedtuple(o):
    # Due to the way they're constructed, namedtuple types cannot be converted
    # because they don't expose source code. But we assume they are safe for
    # graph mode since they are just containers.
    if tf_inspect.isclass(o) and len(o.__bases__) > 1:
      logging.warn(
          'Entity %s looks like a namedtuple subclass. Its constructor will'
          ' not be converted by AutoGraph, but if it has any custom methods,'
          ' those will be.', o)
    logging.log(2, 'Whitelisted: %s: named tuple', o)
    return True

  logging.log(2, 'Not whitelisted: %s: default rule', o)
  return False
def is_whitelisted_for_graph(o):
  """Check whether an entity is whitelisted for use in graph mode.

  Examples of whitelisted entities include all members of the tensorflow
  package.

  The rules below are tried in order and the first match wins:
    * the entity's module is one of `config.DEFAULT_UNCOMPILED_MODULES`, or a
      submodule of one;
    * the entity has an `autograph_info__` or `__ag_compiled` attribute;
    * the entity is a non-class callable whose `__call__` is whitelisted;
    * the entity is a method whose defining class is whitelisted;
    * the entity is namedtuple-like (`inspect_utils.isnamedtuple`).
  Entities whose module has no `__name__` are rejected before any rule runs.

  Args:
    o: A Python entity.

  Returns:
    Boolean
  """
  # TODO(b/120224672): Fix this.
  if isinstance(o, functools.partial):
    # tf_inspect.getmodule(functools.partial(...)) otherwise returns None since
    # functools.partial objects do not have a __module__ attribute.
    m = functools
  else:
    m = tf_inspect.getmodule(o)

  if not hasattr(m, '__name__'):
    # Note: typically it's builtins that fall in this category. Builtins will
    # be handled by specific code that follows this screening layer.
    # NOTE(review): this early exit is presumably also what terminates the
    # `o.__call__` recursion below (method-wrapper objects tend to have no
    # resolvable module); keep it ahead of the callable rule.
    logging.log(2, '%s is NOT whitelisted: unknown module name', o)
    return False

  module_name = m.__name__
  for prefix, in config.DEFAULT_UNCOMPILED_MODULES:
    # Match whole dotted components: prefix 'foo' covers 'foo' and 'foo.bar',
    # but not an unrelated top-level package such as 'foo_extras'.
    if module_name == prefix or module_name.startswith(prefix + '.'):
      logging.log(2, '%s is whitelisted: name starts with "%s"', o, prefix)
      return True

  if hasattr(o, 'autograph_info__') or hasattr(o, '__ag_compiled'):
    logging.log(2, '%s is whitelisted: already converted', o)
    return True

  if (not inspect_utils.isweakrefself(o) and not tf_inspect.isclass(o) and
      hasattr(o, '__call__') and hasattr(o, '__class__')):
    # Callable objects: whitelisted if their __call__ method is. Only a
    # positive answer is final: bound methods are callable as well, so
    # returning False here would make the method rule below unreachable.
    if is_whitelisted_for_graph(o.__call__):
      logging.log(2, '%s is whitelisted: object __call__ whitelisted', o)
      return True

  if tf_inspect.ismethod(o):
    # Methods of whitelisted classes are also whitelisted, even if they are
    # bound via user subclasses.
    #
    # For example, suppose `tf.Foo` has a method called `bar`, and `baz` is
    # defined as below. `tf.Foo` is whitelisted. Then `baz.bar` is also
    # whitelisted.
    #
    #   class Custom(tf.Foo):
    #     pass
    #
    #   baz = Custom()
    #
    # For the example above, if `Custom` did overload `bar`, then it would no
    # longer be whitelisted.
    owner_class = inspect_utils.getmethodclass(o)
    if owner_class is not None:
      if not tf_inspect.isclass(owner_class):
        # For bound methods getmethodclass returns the object the method was
        # retrieved from (the instance); the defining-class lookup is given
        # its class instead, as in the unbound case.
        owner_class = owner_class.__class__
      owner_class = inspect_utils.getdefiningclass(o, owner_class)
      if is_whitelisted_for_graph(owner_class):
        logging.log(2, '%s is whitelisted: owner is whitelisted %s', o,
                    owner_class)
        return True

  if inspect_utils.isnamedtuple(o):
    # Due to the way they're constructed, namedtuple types cannot be converted
    # because they don't expose source code. But we assume they are safe for
    # graph mode since they are just containers.
    if tf_inspect.isclass(o) and len(o.__bases__) > 1:
      logging.warn_first_n(
          'Entity {} looks like a namedtuple subclass. If it has any custom'
          ' methods, they will not be converted by AutoGraph.'.format(o), 1)
    logging.log(2, '%s is whitelisted: named tuple', o)
    return True

  logging.log(2, '%s is NOT whitelisted', o)
  return False
def converted_call(f, options, *args, **kwargs):
  """Compiles a function call inline. For internal use only.

  Whitelisted entities (unless `options.force_conversion` is set) are called
  directly. Builtins are dispatched to their `py_builtins` overload. Anything
  else is converted with `to_graph`, and the converted version is called with
  the same arguments.

  Args:
    f: The callable to convert and call: a function, method, class or
      callable object.
    options: Conversion options. `force_conversion`, `arg_types`,
      `recursive`, `verbose` and `strip_decorators` are read from it.
      `options.arg_types` holds type hints that take precedence over the
      types inferred from the actual arguments; it is not modified.
    *args: Positional arguments for `f`.
    **kwargs: Keyword arguments for `f`.

  Returns:
    The result of calling `f` (or its converted version) with the arguments.

  Raises:
    NotImplementedError: if `f` is not a recognized kind of callable.
  """
  # TODO(mdan): This needs cleanup.
  # In particular, we may want to avoid renaming functions altogether.
  if not options.force_conversion and conversion.is_whitelisted_for_graph(f):
    return f(*args, **kwargs)

  if inspect_utils.isbuiltin(f):
    return py_builtins.overload_of(f)(*args, **kwargs)

  if tf_inspect.isfunction(f) or tf_inspect.ismethod(f):
    # Regular functions
    target_entity = f
    arg_map_target = f
    effective_args = args
    f_class = inspect_utils.getmethodclass(f)
    partial_types = (f_class,) if f_class is not None else ()
  elif tf_inspect.isclass(f):
    # Constructors
    target_entity = f
    arg_map_target = f.__init__
    effective_args = args
    partial_types = ()
  elif hasattr(f, '__call__') and hasattr(f, '__class__'):
    # Callable objects
    target_entity = f.__call__
    arg_map_target = f.__call__
    effective_args = (f,) + args
    partial_types = (f.__class__,)
  else:
    # Must be raised: otherwise the locals above are unbound and the failure
    # surfaces later as a confusing UnboundLocalError.
    raise NotImplementedError('unknown callable type "%s"' % type(f))

  arg_values = tf_inspect.getcallargs(arg_map_target, *args, **kwargs)

  # Build the type map on a private copy. Writing inferred types back into
  # `options.arg_types` would let them shadow the real argument types of any
  # later call that reuses the same options object. Explicit hints in
  # `options.arg_types` still win over inferred types.
  arg_types = dict(options.arg_types or {})
  for name, arg in arg_values.items():
    if name not in arg_types:
      arg_class = arg.__class__
      arg_types[name] = (arg_class.__name__, arg_class)

  # When called from within a decorator, this is the only indication that
  # the function is a method - it appears that the decorator is applied
  # before the method is bound.
  if not partial_types:
    if 'self' in arg_values:
      if tf_inspect.isclass(arg_values['self'].__class__):
        partial_types = (arg_values['self'].__class__,)
    elif 'cls' in arg_values:
      if tf_inspect.isclass(arg_values['cls']):
        partial_types = (arg_values['cls'],)

  converted_f = to_graph(
      target_entity,
      recursive=options.recursive,
      verbose=options.verbose,
      arg_values=arg_values,
      arg_types=arg_types,
      partial_types=partial_types,
      strip_decorators=options.strip_decorators)
  return converted_f(*effective_args, **kwargs)
def converted_call(f, owner, options, *args, **kwargs):
  """Compiles a function call inline. For internal use only.

  Args:
    f: The callable to convert and call or, when `owner` is given, the name
      (a string) of the attribute of `owner` to call.
    owner: Optional object that `f` is looked up on; may be a `super` object.
    options: Conversion options. `verbose`, `force_conversion`,
      `internal_convert_user_code`, `recursive`, `strip_decorators` and
      `optional_features` are read from it.
    *args: Positional arguments for the call.
    **kwargs: Keyword arguments for the call.

  Returns:
    The result of calling `f` (or its converted version) with the arguments.

  Raises:
    ValueError: if `owner` is given but `f` is not a string.
    NotImplementedError: if `f` is not a recognized kind of callable.
  """
  if options.verbose:
    logging.info('Converted call: {}; owner: {}'.format(f, owner))

  if owner is not None:
    if not isinstance(f, str):
      raise ValueError(
          'When owner is specified, the function name must be specified as'
          ' a string: {}'.format(f))

    # Special case when the owner is a 'super' object. In that case lookups of
    # dynamic attributes won't work. See
    # inspect_utils.SuperWrapperForDynamicAttrs.
    if isinstance(owner, super):
      owner = inspect_utils.SuperWrapperForDynamicAttrs(owner)

    f = getattr(owner, f)

  # TODO(mdan): This needs cleanup.
  # In particular, we may want to avoid renaming functions altogether.
  if not options.force_conversion and conversion.is_whitelisted_for_graph(f):
    return f(*args, **kwargs)

  if inspect_utils.isbuiltin(f):
    return py_builtins.overload_of(f)(*args, **kwargs)

  # internal_convert_user_code is for example turned off when issuing a dynamic
  # call conversion from generated code while in nonrecursive mode. In that
  # case we evidently don't want to recurse, but we still have to convert
  # things like builtins.
  if not options.internal_convert_user_code:
    return f(*args, **kwargs)

  if tf_inspect.isfunction(f) or tf_inspect.ismethod(f):
    # Regular functions
    target_entity = f
    arg_map_target = f
    f_class = inspect_utils.getmethodclass(f)

    if f_class is not None:
      # If this is a method call, it may or may not include self.
      #
      # Example when self is included:
      #   converted_call(to_graph(foo.bar), foo)
      #
      # Example when self is not included:
      #   super(...).foo(args)
      #
      if owner is not None and (not args or args[0] is not owner):
        effective_args = (owner,) + args
      else:
        effective_args = args
      partial_types = (f_class,)
    else:
      effective_args = args
      partial_types = ()

  elif tf_inspect.isclass(f):
    # Constructors
    target_entity = f
    arg_map_target = f.__init__
    effective_args = args
    partial_types = ()

  elif hasattr(f, '__call__') and hasattr(f, '__class__'):
    # Callable objects
    target_entity = f.__call__
    arg_map_target = f.__call__
    effective_args = (f,) + args
    partial_types = (f.__class__,)

  else:
    # Must be raised: otherwise the locals above are unbound and the failure
    # surfaces later as a confusing UnboundLocalError.
    raise NotImplementedError('unknown callable type "%s"' % type(f))

  arg_values = tf_inspect.getcallargs(arg_map_target, *args, **kwargs)
  arg_types = {
      name: (arg.__class__.__name__, arg.__class__)
      for name, arg in arg_values.items()
  }

  # When called from within a decorator, this is the only indication that
  # the function is a method - it appears that the decorator is applied
  # before the method is bound.
  if not partial_types:
    if 'self' in arg_values:
      if tf_inspect.isclass(arg_values['self'].__class__):
        partial_types = (arg_values['self'].__class__,)
    elif 'cls' in arg_values:
      if tf_inspect.isclass(arg_values['cls']):
        partial_types = (arg_values['cls'],)

  converted_f = to_graph(
      target_entity,
      recursive=options.recursive,
      verbose=options.verbose,
      arg_values=arg_values,
      arg_types=arg_types,
      partial_types=partial_types,
      strip_decorators=options.strip_decorators,
      optional_features=options.optional_features)

  result = converted_f(*effective_args, **kwargs)

  # When converting a function, we write a tmp file and import it as a module.
  # This leaks the module's closure. Once we've executed the converted_f module
  # and there is no more code left to be executed, we can clean up the module.

  # TODO(mdan): Look into workarounds that don't suffer from refcount leaks.
  # Possibly attach the closure as a regular closure cell, instead of relying on
  # module globals.

  # If there are callables in the result, they will fail to find their closure
  # when called, so only delete module if all returned types are not callable.
  flat_results = nest.flatten(result)
  if all(map(_is_not_callable, flat_results)):
    # pop() instead of del: this cleanup is best-effort, and a missing entry
    # must not turn an already-successful call into a KeyError.
    sys.modules.pop(converted_f.__module__, None)

  return result