def _is_generator_decorator(decorator: ast.expr) -> bool:
    """Return whether ``decorator`` is a ``generator`` decorator.

    Recognizes ``@generator``, ``@<module>.generator`` and their called forms
    ``@generator(...)`` / ``@<module>.generator(...)``.
    """
    target = decorator.func if isinstance(decorator, ast.Call) else decorator
    if isinstance(target, ast.Name):
        return target.id == 'generator'
    if isinstance(target, ast.Attribute):
        return target.attr == 'generator'
    return False


def _append_name_keywords(call: ast.Call, *pairs: Tuple[str, str]) -> None:
    """Append ``arg=<name>`` keyword arguments to ``call``, in the given order.

    Each pair is ``(arg, name)``; the keyword's value is a load of the variable ``name``.
    """
    for arg, name in pairs:
        call.keywords.append(ast.keyword(arg=arg, value=ast.Name(name, ctx=ast.Load())))


def _resolve_static_callee(func_node: ast.expr, namespace: dict):
    """Best-effort resolution of a call target against ``namespace``.

    Returns the object ``func_node`` evaluates to in ``namespace``, or None if it cannot be
    resolved there (e.g. it refers to a name not present in ``namespace``).
    """
    # Evaluating a target that itself contains a call (``make_gen()(...)``) would run user
    # code at compile time and freeze its result, so such targets are never resolved.
    if any(isinstance(sub, ast.Call) for sub in ast.walk(func_node)):
        return None
    try:
        return eval(astunparse.unparse(func_node), namespace)
    except Exception:
        # Any failure simply means "not a statically known generator".
        return None


def _rewrite_calls(f_ast: ast.FunctionDef, gen: 'Generator', strategy: Strategy, with_hooks: bool,
                   namespace: dict):
    """Rewrite operator, strategy-method and generator calls inside ``f_ast`` in place.

    Returns:
        A tuple ``(new_globals, delayed)``. ``new_globals`` maps the fresh names referenced by the
        rewritten AST (op infos, handlers, hook wrapper) to their values. ``delayed`` lists
        ``(generator, call_id)`` pairs whose compiled function must be bound to ``call_id`` in the
        compiled function's globals.
    """
    known_ops: Set[str] = strategy.get_known_ops()
    known_methods: Set[str] = strategy.get_known_methods()
    op_info_constructor = OpInfoConstructor()
    new_globals = {}
    delayed: List[Tuple['Generator', str]] = []
    num_ops = 0
    num_compositions = 0

    for node in astutils.preorder_traversal(f_ast):
        if not isinstance(node, ast.Call):
            continue
        called_name = node.func.id if isinstance(node.func, ast.Name) else None

        if called_name is not None and called_name in known_ops:
            # Rename the operator call and assign the function to be called during execution,
            # as determined by the semantics (strategy). Also pass what hooks need, if enabled.
            op_info: OpInfo = op_info_constructor.get(node, gen.name, gen.group)
            handler_name = f"_handler_{num_ops}"
            num_ops += 1
            op_info_name = f"_op_info_{num_ops}"
            _append_name_keywords(node, ('model', _GEN_MODEL_VAR), ('op_info', op_info_name))
            new_globals[op_info_name] = op_info
            _append_name_keywords(node, ('handler', handler_name))
            new_globals[handler_name] = strategy.get_op_handler(op_info)
            if with_hooks:
                _append_name_keywords(node, (_GEN_HOOK_VAR, _GEN_HOOK_VAR),
                                      (_GEN_STRATEGY_VAR, _GEN_STRATEGY_VAR))
                node.func.id = _GEN_HOOK_WRAPPER
                new_globals[_GEN_HOOK_WRAPPER] = hook_wrapper
            else:
                node.func = ast.Attribute(value=ast.Name(_GEN_STRATEGY_VAR, ctx=ast.Load()),
                                          attr='generic_op', ctx=ast.Load())
            ast.fix_missing_locations(node)

        elif called_name is not None and called_name in known_methods:
            # Strategy methods only need dispatching to the strategy object: turn the plain
            # call into a method call on it.
            node.func = ast.Attribute(value=ast.Name(_GEN_STRATEGY_VAR, ctx=ast.Load()),
                                      attr=called_name, ctx=ast.Load())
            ast.fix_missing_locations(node)

        else:
            # Possibly a call to another generator.
            # TODO : Can we be more sophisticated in our static analysis here
            callee = _resolve_static_callee(node.func, namespace)
            if not isinstance(callee, Generator):
                continue
            call_id = f"{_GEN_COMPOSITION_ID}_{num_compositions}"
            num_compositions += 1
            # Replace the whole target with a fresh name. Assigning ``.id`` would be a silent
            # no-op for attribute/subscript targets such as ``module.gen(...)``.
            node.func = ast.copy_location(ast.Name(call_id, ctx=ast.Load()), node.func)
            _append_name_keywords(node,
                                  (_GEN_EXEC_ENV_VAR, _GEN_EXEC_ENV_VAR),
                                  (_GEN_STRATEGY_VAR, _GEN_STRATEGY_VAR),
                                  (_GEN_MODEL_VAR, _GEN_MODEL_VAR),
                                  (_GEN_HOOK_VAR, _GEN_HOOK_VAR))
            ast.fix_missing_locations(node)
            # Compilation is delayed to handle mutually recursive generators.
            delayed.append((callee, call_id))

    return new_globals, delayed


def _compile_uncached(gen: 'Generator', func: Callable, strategy: Strategy, with_hooks: bool,
                      cache) -> Callable:
    """Compile ``func``, store the result in ``cache`` and bind its generator dependencies.

    The caller is expected to have already placed the in-progress placeholder in ``cache``.
    """
    source_lines, start_lineno = inspect.getsourcelines(func)
    f_ast = astutils.parse(textwrap.dedent(''.join(source_lines)))
    # This matches up line numbers with the original file and is thus super useful for debugging.
    ast.increment_lineno(f_ast, start_lineno - 1)

    # Remove ``@generator`` decorators (plain, qualified or called) to avoid recursive compilation.
    f_ast.decorator_list = [d for d in f_ast.decorator_list if not _is_generator_decorator(d)]

    # Get all the external dependencies of this function. We rely on a modified closure
    # function adopted from the ``inspect`` library. ``g`` becomes the compiled function's globals.
    closure_vars = getclosurevars_recursive(func, f_ast)
    g = {**closure_vars.nonlocals, **closure_vars.globals}

    new_globals, delayed = _rewrite_calls(f_ast, gen, strategy, with_hooks, g)

    # Add the execution environment, strategy, model and hook keyword-only arguments.
    for runtime_arg in (_GEN_EXEC_ENV_VAR, _GEN_STRATEGY_VAR, _GEN_MODEL_VAR, _GEN_HOOK_VAR):
        f_ast.args.kwonlyargs.append(ast.arg(arg=runtime_arg, annotation=None))
        # ``ast.NameConstant`` is deprecated since 3.8 (and removed in 3.14).
        f_ast.args.kw_defaults.append(ast.Constant(value=None))
    ast.fix_missing_locations(f_ast)

    # New name so it doesn't clash with original
    func_name = f"{_GEN_COMPILED_TARGET_ID}_{len(cache)}"
    g.update(new_globals)

    # ``compile`` rejects an ``ast.Module`` lacking ``type_ignores`` on Python 3.8-3.12.
    module = ast.Module(body=[f_ast], type_ignores=[])
    # Executing in ``g`` makes every name introduced by the rewrite resolvable at run time.
    exec(compile(module, filename=inspect.getabsfile(func), mode="exec"), g)
    result = g[func.__name__]
    if inspect.ismethod(func):
        result = result.__get__(func.__self__, func.__self__.__class__)

    # Restore the correct namespace so that tracebacks contain actual function names
    g[gen.name] = gen
    g[func_name] = result
    cache[func] = result

    # Handle the delayed compilations now that ``func`` is in the cache.
    for callee, call_id in delayed:
        compiled_callee = compile_func(callee, callee.func, strategy, with_hooks)
        if callee.caching and isinstance(strategy, DfsStrategy):
            # Add instructions for using cached result if any
            g[call_id] = cache_wrapper(compiled_callee)
        else:
            g[call_id] = compiled_callee

    return result


def compile_func(gen: 'Generator', func: Callable, strategy: Strategy, with_hooks: bool = False) -> Callable:
    """
    The compilation basically assigns functionality to each of the operator calls as governed by the semantics
    (strategy). Memoization is done with the keys as the `func`, the class of the `strategy` and the `with_hooks`
    argument.

    Calls to other generators found in ``func`` are rerouted to their compiled versions. Those are compiled only
    after ``func`` itself is cached, so mutually recursive generators work.

    Args:
        gen (Generator): The generator object containing the function to compile
        func (Callable): The function to compile
        strategy (Strategy): The strategy governing the behavior of the operators. A ``PartialReplayStrategy``
            is replaced by its ``backup_strategy``.
        with_hooks (bool): Whether support for hooks is required

    Returns:
        The compiled function

    Raises:
        Any exception raised during compilation propagates. The cache entry for ``func`` is removed first,
        so a later call retries instead of returning None.
    """
    if isinstance(strategy, PartialReplayStrategy):
        strategy = strategy.backup_strategy

    if with_hooks:
        cache = CompilationCache.WITH_HOOKS[strategy.__class__]
    else:
        cache = CompilationCache.WITHOUT_HOOKS[strategy.__class__]

    if func in cache:
        return cache[func]

    # Placeholder marking ``func`` as being compiled. It must not outlive a failed compilation,
    # otherwise every later call would silently return None.
    cache[func] = None
    try:
        return _compile_uncached(gen, func, strategy, with_hooks, cache)
    except BaseException:
        cache.pop(func, None)
        raise
def getclosurevars_recursive(func, f_ast: Optional[ast.FunctionDef] = None):
    """
    The default getclosurevars doesn't go over nested function defs and list comprehensions.
    This version also walks every nested code object of ``func``. It also resolves names used
    in the argument and return annotations of ``func``.
    The logic is borrowed from this post - https://bugs.python.org/issue34947

    Args:
        func (Callable): The function to inspect
        f_ast (Optional[ast.FunctionDef]): The AST of the function if available. If not, an attempt
            will be made to retrieve the AST; if that fails, annotations are skipped.

    Returns:
        An instance of ClosureVars. Enclosing variables whose cells are still empty (not yet
        assigned) are omitted from ``nonlocals``.
    """
    f_code = func.__code__

    # Nonlocal references are named in co_freevars and resolved
    # by looking them up in __closure__ by positional index
    nonlocal_vars = {}
    if func.__closure__ is not None:
        for var, cell in zip(f_code.co_freevars, func.__closure__):
            try:
                nonlocal_vars[var] = cell.cell_contents
            except ValueError:
                # Empty cell: the enclosing variable has not been assigned yet.
                pass

    annotation_names = []
    try:
        if f_ast is None:
            f_ast = astutils.parse(textwrap.dedent(inspect.getsource(func)))
        for node in ast.walk(f_ast.args):
            if isinstance(node, ast.arg) and node.annotation is not None:
                annotation_names.extend(astutils.get_all_names(node.annotation))
        if f_ast.returns is not None:
            annotation_names.extend(astutils.get_all_names(f_ast.returns))
    except Exception:
        # Best effort: e.g. the source cannot be retrieved, or the parsed node has no
        # ``args``/``returns``. Annotation names are then simply not collected.
        pass

    # Global and builtin references are named in co_names and resolved
    # by looking them up in __globals__ or __builtins__
    global_ns = func.__globals__
    builtin_ns = global_ns.get("__builtins__", builtins.__dict__)
    if ismodule(builtin_ns):
        builtin_ns = builtin_ns.__dict__

    global_vars = {}
    builtin_vars = {}
    unbound_names = set()

    def _record(name):
        # Globals take precedence over builtins, mirroring runtime name resolution.
        try:
            global_vars[name] = global_ns[name]
        except KeyError:
            try:
                builtin_vars[name] = builtin_ns[name]
            except KeyError:
                unbound_names.add(name)

    # The logic is recursive but is implemented iteratively over a stack of code objects.
    pending = [f_code]
    while pending:
        code = pending.pop()
        for name in code.co_names:
            # Because these used to be builtins instead of keywords, they
            # may still show up as name references. We ignore them.
            if name not in ("None", "True", "False"):
                _record(name)
        # Nested functions, lambdas and comprehensions live in co_consts as code objects.
        pending.extend(const for const in code.co_consts if iscode(const))

    for name in annotation_names:
        _record(name)

    return ClosureVars(nonlocal_vars, global_vars, builtin_vars, unbound_names)