def get_overload_annotations(mod, jit_ignored_properties):
    """Collect the overload declarations registered for the methods of ``mod``.

    Every attribute name found on ``type(mod)`` is visited, except names in
    ``jit_ignored_properties``. Those are skipped before ``getattr`` is ever
    called on them, so their getters never run. For each callable attribute
    whose ``__module__`` is set, ``_jit_internal._get_overloaded_methods`` is
    asked for its overloads. Attributes without overloads are ignored.

    Args:
        mod: the module instance to inspect.
        jit_ignored_properties: attribute names to skip entirely.

    Returns:
        dict mapping each overloaded (bound) method to a list of
        ``(mangled overload name, overload function)`` tuples. Mangled names
        are ``<name>__0``, ``<name>__1``, ... in overload order. The dict is
        empty when nothing is overloaded.

    Raises:
        RuntimeError: if a method's own ``__func__`` is one of its overload
            declarations, i.e. no implementation was provided.
    """
    result = {}
    for attr_name in dir(type(mod)):
        if attr_name in jit_ignored_properties:
            continue
        candidate = getattr(mod, attr_name, None)
        if not callable(candidate):
            continue
        # Some builtins have no __module__ (or have it set to None); they are
        # skipped without an overload lookup.
        if not hasattr(candidate, "__module__") or candidate.__module__ is None:
            continue
        declared = _jit_internal._get_overloaded_methods(candidate, mod.__class__)
        if declared is None:
            continue
        if candidate.__func__ in declared:
            raise RuntimeError(_jit_internal.get_overload_no_implementation_error_message(
                'method', candidate.__func__))
        result[candidate] = [
            (attr_name + "__" + str(index), overload_fn)
            for index, overload_fn in enumerate(declared)
        ]
    return result
def get_overload_annotations(mod, jit_ignored_properties=()):
    """Collect the overload declarations registered for the methods of ``mod``.

    Every attribute name found on ``type(mod)`` is visited, except names in
    ``jit_ignored_properties``. Those are skipped before ``getattr`` is called
    on them, so their getters never run. For each callable attribute whose
    ``__module__`` is set, ``_jit_internal._get_overloaded_methods`` is asked
    for its overloads. Attributes without overloads are ignored.

    ``jit_ignored_properties`` is optional so that both the one-argument form
    ``get_overload_annotations(mod)`` and the two-argument form
    ``get_overload_annotations(mod, jit_ignored_properties)`` work with this
    definition.

    Args:
        mod: the module instance to inspect.
        jit_ignored_properties: attribute names to skip entirely. Defaults to
            no names, which means every attribute is visited.

    Returns:
        dict mapping each overloaded (bound) method to a list of
        ``(mangled overload name, overload function)`` tuples. Mangled names
        are ``<name>__0``, ``<name>__1``, ... in overload order.

    Raises:
        RuntimeError: if a method's own ``__func__`` is one of its overload
            declarations, i.e. no implementation was provided.
    """
    # original function => [(mangled overload name, overload function)]
    overloads = {}
    for name in dir(type(mod)):
        if name in jit_ignored_properties:
            continue
        item = getattr(mod, name, None)
        if not callable(item):
            continue
        # builtin functions like repr() in python 2 do not have __module__ defined
        if not hasattr(item, "__module__") or item.__module__ is None:
            continue
        method_overloads = _jit_internal._get_overloaded_methods(item, mod.__class__)
        if method_overloads is None:
            continue
        # The method itself being one of its overload declarations means only
        # declarations exist and there is no implementation to compile.
        if item.__func__ in method_overloads:
            raise RuntimeError(_jit_internal.get_overload_no_implementation_error_message(
                'method', item.__func__))
        overloads[item] = [
            (name + "__" + str(i), overload_fn)
            for i, overload_fn in enumerate(method_overloads)
        ]
    return overloads
def recursive_script(mod, exclude_methods=()):
    """
    Makes a ScriptModule from an nn.Module.

    The methods compiled are ``forward`` (when the module defines it and it
    is not marked ignored) plus every method whose TorchScript modifier is
    EXPORT, minus any name listed in ``exclude_methods``. Each method name is
    stubbed at most once, in first-seen order. Methods registered as
    overloaded are not compiled directly; one stub is built per overload
    instead. Methods accessed in forward are scripted on demand.

    Args:
        mod: the module to convert.
        exclude_methods: names of methods that must not be compiled.

    Returns:
        ``mod`` itself if it is already a ``torch.jit.ScriptModule``; the
        result of ``create_constant_iterable_module`` for ``ModuleList``,
        ``Sequential`` and ``ModuleDict``; otherwise the result of
        ``copy_to_script_module``.

    Raises:
        RuntimeError: if ``mod`` has no ``_parameters`` attribute (it was not
            initialized through ``super().__init__()``), or if its
            ``forward`` is the base ``torch.nn.Module.forward``.

    Side effects:
        Sets ``mod.__overloads__`` to a dict mapping each overloaded method
        name to the mangled names of its overloads. Entries already present
        in ``mod.__overloads__`` are kept.
    """
    if isinstance(mod, torch.jit.ScriptModule):
        return mod

    if isinstance(mod, (torch.nn.ModuleList, torch.nn.Sequential, torch.nn.ModuleDict)):
        # Create constant versions for the iterable modules
        return create_constant_iterable_module(mod)

    if not hasattr(mod, '_parameters'):
        raise RuntimeError("'{}' has not been initialized, did you forget to call 'super()'?"
                           .format(type(mod).__name__))

    methods = ()
    if hasattr(mod, 'forward'):
        if mod.forward.__func__ == torch.nn.Module.forward:
            raise RuntimeError("No forward method was defined on {}".format(mod))
        if not _jit_internal.is_ignored_fn(mod.forward):
            methods = ('forward',)

    exported = []
    overloads = []
    for name in dir(mod):
        item = getattr(mod, name)
        if not callable(item):
            continue
        if _jit_internal.get_torchscript_modifier(item) is _jit_internal.FunctionModifiers.EXPORT:
            exported.append(name)
        # builtin functions like repr() in python 2 do not have __module__ defined
        if hasattr(item, "__module__") and item.__module__ is not None:
            method_overloads = _jit_internal._get_overloaded_methods(item, mod.__class__)
            if method_overloads is not None:
                overloads.append((item, method_overloads))

    # A name can appear twice here (e.g. ``forward`` that is also marked as
    # exported), and it would otherwise be stubbed twice. Keep only its first
    # occurrence. A list plus a set (not a set alone) keeps compile order
    # deterministic.
    seen = set()
    unique_methods = []
    for name in methods + tuple(exported):
        if name in exclude_methods or name in seen:
            continue
        seen.add(name)
        unique_methods.append(name)
    methods = tuple(unique_methods)

    overload_name_mappings = dict(getattr(mod, "__overloads__", {}))
    overload_stubs = []

    for orig_fn, overload_fns in overloads:
        orig_ast = torch.jit.get_jit_def(orig_fn, self_name="ScriptModule")
        orig_name = orig_ast.name().name
        # Each overload is compiled under a mangled name: <method>__0, <method>__1, ...
        mangled_names = [orig_name + "__" + str(i) for i in range(len(overload_fns))]
        overload_name_mappings[orig_name] = mangled_names
        for overload_fn, mangled_name in zip(overload_fns, mangled_names):
            torch.jit._check_no_signature(overload_fn)
            over_ast = torch.jit.get_jit_def(overload_fn, self_name="ScriptModule")
            # Presumably combines the overload's declaration with the original
            # method's definition under the mangled name — confirm against
            # torch._C._replace_overloaded_method_decl.
            new_ast = torch._C._replace_overloaded_method_decl(over_ast.decl(), orig_ast, mangled_name)
            _rcb = _jit_internal.createResolutionCallbackFromClosure(orig_fn)
            overload_stubs.append(torch.jit.ScriptMethodStub(_rcb, new_ast, overload_fn))

    mod.__overloads__ = overload_name_mappings

    def make_stub(method):
        func = get_function_from_type(type(mod), method)
        return torch.jit.script_method(func, _jit_internal.createResolutionCallbackFromClosure(func))

    # we shouldn't directly compile overloaded methods, just its overloads
    stubs = [make_stub(method) for method in methods
             if method not in overload_name_mappings]
    return copy_to_script_module(mod, overload_stubs + stubs)