コード例 #1
0
ファイル: generators.py プロジェクト: merrymercy/atlas
def _is_generator_decorator(dec: ast.expr) -> bool:
    """Return True if ``dec`` spells the ``generator`` decorator.

    Recognized forms: ``@generator``, ``@mod.generator``, ``@generator(...)`` and
    ``@mod.generator(...)``.
    """
    #  For the call forms, inspect the callee rather than the call node itself
    target = dec.func if isinstance(dec, ast.Call) else dec
    if isinstance(target, ast.Name):
        return target.id == 'generator'
    if isinstance(target, ast.Attribute):
        return target.attr == 'generator'
    return False


def _add_kwarg(call: ast.Call, arg: str, var_name: str) -> None:
    """Append the keyword argument ``arg=<var_name>`` (a plain name load) to ``call``."""
    call.keywords.append(ast.keyword(arg=arg, value=ast.Name(var_name, ctx=ast.Load())))


def compile_func(gen: 'Generator',
                 func: Callable,
                 strategy: Strategy,
                 with_hooks: bool = False) -> Callable:
    """
    The compilation basically assigns functionality to each of the operator calls as
    governed by the semantics (strategy). Memoization is done with the keys as the `func`,
    the class of the `strategy` and the `with_hooks` argument.

    Concretely, the source of ``func`` is parsed and its AST rewritten:

    * calls to known operators receive ``model``/``op_info``/``handler`` keyword arguments and
      are routed to ``generic_op`` on the strategy variable (or to the hook wrapper when
      ``with_hooks`` is set);
    * calls to known strategy methods become method calls on the strategy variable;
    * calls whose callee statically resolves to another ``Generator`` are redirected to that
      generator's compiled function, which is compiled after this one (to support mutual
      recursion).

    The rewritten definition is then executed in a fresh namespace built from the function's
    closure variables and globals.

    Args:
        gen (Generator): The generator object containing the function to compile
        func (Callable): The function to compile
        strategy (Strategy): The strategy governing the behavior of the operators
        with_hooks (bool): Whether support for hooks is required

    Returns:
        The compiled function (bound to ``func.__self__`` when ``func`` is a bound method)

    Raises:
        OSError, TypeError: propagated from ``inspect.getsourcelines`` when the source of
            ``func`` cannot be retrieved.
    """

    if isinstance(strategy, PartialReplayStrategy):
        strategy = strategy.backup_strategy

    if with_hooks:
        cache = CompilationCache.WITH_HOOKS[strategy.__class__]
    else:
        cache = CompilationCache.WITHOUT_HOOKS[strategy.__class__]

    if func in cache:
        return cache[func]

    #  Reserve the cache slot before compiling. Note that it also counts towards the
    #  ``len(cache)`` used to name the compiled function further below.
    cache[func] = None

    source_lines, start_lineno = inspect.getsourcelines(func)
    f_ast = astutils.parse(textwrap.dedent(''.join(source_lines)))

    # This matches up line numbers with original file and is thus super useful for debugging
    ast.increment_lineno(f_ast, start_lineno - 1)

    #  Remove the ``@generator`` decorator (in any spelling) to avoid recursive compilation
    f_ast.decorator_list = [d for d in f_ast.decorator_list if not _is_generator_decorator(d)]

    #  Get all the external dependencies of this function.
    #  We rely on a modified closure function adopted from the ``inspect`` library.
    closure_vars = getclosurevars_recursive(func, f_ast)
    #  The compiled function runs with ``g`` as its globals, so closure variables become global
    #  lookups. They must win over same-named module globals: ``co_names`` also lists attribute
    #  names (``obj.x``), which can pull an unrelated global ``x`` into ``closure_vars.globals``.
    g = {**closure_vars.globals, **closure_vars.nonlocals}
    known_ops: Set[str] = strategy.get_known_ops()
    known_methods: Set[str] = strategy.get_known_methods()
    op_info_constructor = OpInfoConstructor()
    delayed_compilations: List[Tuple[Generator, str]] = []

    ops = {}
    handlers = {}
    op_infos = {}
    op_idx: int = 0
    composition_cnt: int = 0
    for n in astutils.preorder_traversal(f_ast):
        if not isinstance(n, ast.Call):
            continue

        if isinstance(n.func, ast.Name) and n.func.id in known_ops:
            #  Rename the function call, and assign a new function to be called during execution.
            #  This new function is determined by the semantics (strategy) being used for compilation.
            #  Also determine if there any eligible hooks for this operator call.
            op_idx += 1
            handler_idx = len(handlers)
            op_info: OpInfo = op_info_constructor.get(n, gen.name, gen.group)

            op_info_var = f"_op_info_{op_idx}"
            handler_var = f"_handler_{handler_idx}"

            _add_kwarg(n, 'model', _GEN_MODEL_VAR)
            _add_kwarg(n, 'op_info', op_info_var)
            op_infos[op_info_var] = op_info
            _add_kwarg(n, 'handler', handler_var)
            handlers[handler_var] = strategy.get_op_handler(op_info)

            if not with_hooks:
                n.func = astutils.parse(
                    f"{_GEN_STRATEGY_VAR}.generic_op").value
            else:
                _add_kwarg(n, _GEN_HOOK_VAR, _GEN_HOOK_VAR)
                _add_kwarg(n, _GEN_STRATEGY_VAR, _GEN_STRATEGY_VAR)
                n.func.id = _GEN_HOOK_WRAPPER
                ops[_GEN_HOOK_WRAPPER] = hook_wrapper

            ast.fix_missing_locations(n)

        elif isinstance(n.func, ast.Name) and n.func.id in known_methods:
            #  Similar in spirit to the known_ops case, just much less fancy stuff to do.
            #  Only need to get the right handler which we will achieve by simply making this
            #  a method call instead of a regular call.
            n.func = ast.Attribute(value=ast.Name(_GEN_STRATEGY_VAR, ctx=ast.Load()),
                                   attr=n.func.id,
                                   ctx=ast.Load())
            ast.fix_missing_locations(n)

        elif isinstance(n.func, (ast.Name, ast.Attribute)):
            #  Try to check if it is a call to a Generator. Only plain name/attribute callees are
            #  resolved, so this static check never *executes* calls such as ``make_gen()(...)``.
            #  TODO : Can we be more sophisticated in our static analysis here
            try:
                function = eval(astunparse.unparse(n.func), g)
            except Exception:
                #  Not resolvable from ``g`` (e.g. a local variable or ``self`` attribute)
                function = None

            if isinstance(function, Generator):
                call_id = f"{_GEN_COMPOSITION_ID}_{composition_cnt}"
                composition_cnt += 1
                #  Swap in a fresh Name node (rather than mutating ``.id``) so that attribute
                #  callees like ``module.gen(...)`` are redirected as well.
                n.func = ast.Name(call_id, ctx=ast.Load())
                for var in (_GEN_EXEC_ENV_VAR, _GEN_STRATEGY_VAR, _GEN_MODEL_VAR, _GEN_HOOK_VAR):
                    _add_kwarg(n, var, var)
                ast.fix_missing_locations(n)

                #  We delay compilation to handle mutually recursive generators
                delayed_compilations.append((function, call_id))

    #  Add keyword-only parameters (all defaulting to None) for the execution environment,
    #  the strategy, the model and the hooks, which the rewritten calls above refer to
    for extra_arg in (_GEN_EXEC_ENV_VAR, _GEN_STRATEGY_VAR, _GEN_MODEL_VAR, _GEN_HOOK_VAR):
        f_ast.args.kwonlyargs.append(ast.arg(arg=extra_arg, annotation=None))
        #  ``ast.Constant`` replaces the deprecated ``ast.NameConstant`` (removed in Python 3.14)
        f_ast.args.kw_defaults.append(ast.Constant(value=None))
    ast.fix_missing_locations(f_ast)

    #  New name so it doesn't clash with original
    func_name = f"{_GEN_COMPILED_TARGET_ID}_{len(cache)}"

    g.update(ops)
    g.update(handlers)
    g.update(op_infos)

    module = ast.Module()
    module.body = [f_ast]
    #  ``type_ignores`` is a required ``ast.Module`` field on Python 3.8+ and ``compile`` raises
    #  a TypeError without it; older versions simply do not read it.
    module.type_ignores = []

    #  Passing ``g`` to exec allows us to execute all the new functions
    #  we assigned to every operator call in the previous AST walk
    exec(compile(module, filename=inspect.getabsfile(func), mode="exec"), g)
    result = g[func.__name__]

    if inspect.ismethod(func):
        result = result.__get__(func.__self__, func.__self__.__class__)

    #  Restore the correct namespace so that tracebacks contain actual function names
    g[gen.name] = gen
    g[func_name] = result

    cache[func] = result

    #  Handle the delayed compilations now that we have populated the cache
    for sub_gen, call_id in delayed_compilations:
        compiled_func = compile_func(sub_gen, sub_gen.func, strategy, with_hooks)
        if sub_gen.caching and isinstance(strategy, DfsStrategy):
            #  Add instructions for using cached result if any
            g[call_id] = cache_wrapper(compiled_func)
        else:
            g[call_id] = compiled_func

    return result
コード例 #2
0
def getclosurevars_recursive(func, f_ast: Optional[ast.FunctionDef] = None):
    """
    Recursive variant of ``inspect.getclosurevars``.

    The default getclosurevars doesn't go over nested function defs and list comprehensions.
    This version also walks every code object nested in ``func`` (found in ``co_consts``).
    The logic is borrowed from this post - https://bugs.python.org/issue34947

    Names evaluated when the ``def`` statement itself runs (argument/return annotations,
    default values and decorators) do not appear in ``co_names``. They are extracted from the
    AST and resolved too, because the definition may be re-executed from its AST (as
    ``compile_func`` does). Such names that are already closure variables of ``func`` are
    not looked up again as globals.

    Args:
        func (Callable): The function to inspect
        f_ast (Optional[ast.FunctionDef]): The AST of the function if available.
            If not, an attempt will be made to retrieve the AST

    Returns:
        An instance of ClosureVars(nonlocals, globals, builtins, unbound)

    """
    f_code = func.__code__
    # Nonlocal references are named in co_freevars and resolved
    # by looking them up in __closure__ by positional index
    if func.__closure__ is None:
        nonlocal_vars = {}
    else:
        nonlocal_vars = {
            var: cell.cell_contents
            for var, cell in zip(f_code.co_freevars, func.__closure__)
        }

    def _load_names(node):
        #  Every identifier referenced as a plain ``Name`` anywhere inside ``node``
        return [sub.id for sub in ast.walk(node) if isinstance(sub, ast.Name)]

    def_time_names = []
    try:
        if f_ast is None:
            f_ast = astutils.parse(textwrap.dedent(inspect.getsource(func)))

        #  Annotations
        for n in ast.walk(f_ast.args):
            if isinstance(n, ast.arg) and n.annotation is not None:
                def_time_names.extend(astutils.get_all_names(n.annotation))
        if f_ast.returns is not None:
            def_time_names.extend(astutils.get_all_names(f_ast.returns))

        #  Default values (``kw_defaults`` holds None for keyword-only args without a default)
        for default in [*f_ast.args.defaults, *f_ast.args.kw_defaults]:
            if default is not None:
                def_time_names.extend(_load_names(default))

        #  Remaining decorators
        for decorator in f_ast.decorator_list:
            def_time_names.extend(_load_names(decorator))

    except Exception:
        #  Best effort: the source / AST may be unavailable (e.g. builtins, interactive code)
        pass

    # Global and builtin references are named in co_names and resolved
    # by looking them up in __globals__ or __builtins__
    global_ns = func.__globals__
    builtin_ns = global_ns.get("__builtins__", builtins.__dict__)
    if ismodule(builtin_ns):
        builtin_ns = builtin_ns.__dict__
    global_vars = {}
    builtin_vars = {}
    unbound_names = set()

    def _resolve(name):
        #  Module globals first, then builtins; anything else is unbound
        try:
            global_vars[name] = global_ns[name]
        except KeyError:
            try:
                builtin_vars[name] = builtin_ns[name]
            except KeyError:
                unbound_names.add(name)

    codes = [f_code]
    while codes:
        #  The logic is recursive but is implemented iteratively
        code = codes.pop()
        for name in code.co_names:
            if name in ("None", "True", "False"):
                # Because these used to be builtins instead of keywords, they
                # may still show up as name references. We ignore them.
                continue
            _resolve(name)

        #  Add the nested code objects (inner defs, lambdas, comprehensions) to inspect
        codes.extend(const for const in code.co_consts if iscode(const))

    for name in def_time_names:
        #  A closure variable of ``func`` lives in the same enclosing scope the ``def``
        #  was evaluated in, so it is already covered by ``nonlocal_vars``
        if name not in nonlocal_vars:
            _resolve(name)

    return ClosureVars(nonlocal_vars, global_vars, builtin_vars, unbound_names)