def test_computed_col_def_updated():
    """A computed column can be re-defined and rendered a second time.

    The ``ptgev`` column is assigned one lambda and rendered, then assigned a
    different lambda and rendered again. Each pass checks the shape of the
    rendered call and that the callable itself renders without raising.
    """
    frame = DataFrame()

    # First definition of the computed column.
    frame.jets['ptgev'] = lambda j: j.pt / 1000
    column = frame.jets.ptgev

    call_node, ctx = render(column)

    assert isinstance(call_node, ast.Call)
    assert isinstance(call_node.func, ast_Callable)
    assert len(call_node.args) == 1

    rendered_body, _ = render_callable(
        call_node.func, ctx,
        call_node.func.dataframe)  # type: ignore

    # Re-define the column and go through exactly the same steps again.
    frame.jets['ptgev'] = lambda j: j.pt / 1001
    column = frame.jets.ptgev

    call_node, ctx = render(column)

    assert isinstance(call_node, ast.Call)
    assert isinstance(call_node.func, ast_Callable)
    assert len(call_node.args) == 1

    rendered_body, _ = render_callable(
        call_node.func, ctx,
        call_node.func.dataframe)  # type: ignore
def test_callable_wrong_number_args():
    """Rendering a one-parameter callable with two arguments must raise."""
    frame = DataFrame()
    applied = frame.apply(lambda b: b)
    call_node, ctx = render(applied)

    assert isinstance(call_node, ast.Call)
    fn_node: ast.AST = call_node.args[0]
    assert isinstance(fn_node, ast_Callable)

    # The lambda takes a single parameter, so passing two values is an error.
    with pytest.raises(Exception):
        render_callable(fn_node, ctx, frame, frame)
def test_callable_returns_const():
    """A callable returning a literal renders to a numeric node of that value."""
    frame = DataFrame()
    applied = frame.apply(lambda b: 20)
    call_node, ctx = render(applied)

    assert isinstance(call_node, ast.Call)
    fn_node: ast.AST = call_node.args[0]
    assert isinstance(fn_node, ast_Callable)

    rendered, _updated_ctx = render_callable(fn_node, ctx, frame)
    assert isinstance(rendered, ast.Num)
    assert rendered.n == 20
def test_callable_simple_call():
    """An identity lambda renders back to the dataframe it was handed."""
    frame = DataFrame()
    applied = frame.apply(lambda b: b)
    call_node, ctx = render(applied)

    assert isinstance(call_node, ast.Call)
    fn_node: ast.AST = call_node.args[0]
    assert isinstance(fn_node, ast_Callable)

    rendered, _updated_ctx = render_callable(fn_node, ctx, frame)
    assert isinstance(rendered, ast_DataFrame)
    # Identity, not equality: the very same DataFrame object must come back.
    assert rendered.dataframe is frame
def test_render_callable_twice_for_same_results():
    """Rendering the same callable twice yields the same AST and context sizes."""

    frame = DataFrame()
    electrons = frame.Electrons()
    truth = frame.TruthParticles()
    truth_electrons = truth[truth.pdgId == 11]
    selected_truth = truth_electrons[truth_electrons.ptgev > 20]

    # The lambda ignores its argument and captures `selected_truth` instead.
    mapped = electrons.map(lambda reco_e: selected_truth)

    tree, ctx = render(mapped)

    class _CallableFinder(ast.NodeVisitor):
        """Visitor that remembers the last ``ast_Callable`` node it is shown."""

        def __init__(self):
            super().__init__()
            self._callable: Optional[ast_Callable] = None

        @classmethod
        def findit(cls, a: ast.AST) -> Optional[ast_Callable]:
            """Walk *a* and return the recorded callable, or ``None``."""
            finder = cls()
            finder.visit(a)
            return finder._callable

        # NodeVisitor dispatches on the node's class name, hence this name.
        def visit_ast_Callable(self, a: ast_Callable):
            self._callable = a

    found = _CallableFinder.findit(tree)
    assert found is not None

    first_expr, first_ctx = render_callable(found, ctx, found.dataframe)
    second_expr, second_ctx = render_callable(found, ctx, found.dataframe)

    # Same structure both times ...
    assert ast.dump(first_expr) == ast.dump(second_expr)
    # ... and the returned contexts have the same number of entries.
    assert len(first_ctx._seen_datasources) == len(
        second_ctx._seen_datasources)
    assert len(first_ctx._resolved) == len(second_ctx._resolved)
def test_callable_captures_column():
    """A lambda that captures another column renders to a comparison node."""
    frame = DataFrame()
    applied = frame.jets.apply(lambda b: frame.met > 20.0)
    call_node, ctx = render(applied)

    # The rendered call is `<attribute>.<attribute>(callable)`.
    assert isinstance(call_node, ast.Call)
    assert isinstance(call_node.func, ast.Attribute)
    call_root = call_node.func.value
    assert isinstance(call_root, ast.Attribute)

    fn_node: ast.AST = call_node.args[0]
    assert isinstance(fn_node, ast_Callable)
    rendered, _ = render_callable(fn_node, ctx, frame.jets)

    assert isinstance(rendered, ast.Compare)
def test_callable_returns_matched_ast():
    """An identity lambda on ``jets`` renders to the node the call hangs off."""
    frame = DataFrame()
    applied = frame.jets.apply(lambda b: b)
    call_node, ctx = render(applied)

    assert isinstance(call_node, ast.Call)
    assert isinstance(call_node.func, ast.Attribute)
    call_root = call_node.func.value
    assert isinstance(call_root, ast.Attribute)

    fn_node: ast.AST = call_node.args[0]
    assert isinstance(fn_node, ast_Callable)
    rendered, _updated_ctx = render_callable(fn_node, ctx, frame.jets)

    # Object identity: the existing AST node is reused, not rebuilt.
    assert call_root is rendered
def test_callable_function():
    """A named function is rendered the same way as an identity lambda."""
    def test_func(b):
        return b

    frame = DataFrame()
    applied = frame.apply(test_func)
    call_node, ctx = render(applied)

    assert isinstance(call_node, ast.Call)
    fn_node: ast.AST = call_node.args[0]
    assert isinstance(fn_node, ast_Callable)

    rendered, _updated_ctx = render_callable(fn_node, ctx, frame)
    assert isinstance(rendered, ast_DataFrame)
    assert rendered.dataframe is frame
def test_callable_context():
    """Rendering against the callable's own dataframe reuses the call's root."""
    frame = DataFrame()
    applied = frame.jets.apply(lambda b: b)
    call_node, ctx = render(applied)

    assert isinstance(call_node, ast.Call)
    fn_node: ast.AST = call_node.args[0]
    assert isinstance(fn_node, ast_Callable)

    # Use the dataframe recorded on the callable rather than a fresh `jets`.
    rendered, _ = render_callable(fn_node, ctx, fn_node.dataframe)

    assert isinstance(call_node.func, ast.Attribute)
    call_root = call_node.func.value
    assert isinstance(call_root, ast.Attribute)

    assert call_root is rendered
        def visit_Call(self, a: ast.Call):
            """Expand a call whose function is an ``ast_Callable``.

            Calls to anything else are handed to ``generic_visit``. For an
            ``ast_Callable`` the callable is rendered against its own
            dataframe and the resulting expression is visited under the
            context that rendering returned.
            """
            func = a.func
            if isinstance(func, ast_Callable):
                assert len(a.args) == 1

                rendered, inner_context = render_callable(
                    cast(ast_Callable, func), self._context,
                    func.dataframe)  # type: ignore

                # Swap in the callable's context for the nested visit and
                # always put ours back, even if the visit raises.
                saved_context, self._context = self._context, inner_context
                try:
                    return self.visit(rendered)
                finally:
                    self._context = saved_context

            return self.generic_visit(a)
def test_lambda_for_computed_col():
    """A lambda-defined column renders to a call whose body is a ``BinOp``."""
    frame = DataFrame()
    frame.jets['ptgev'] = lambda j: j.pt / 1000
    column = frame.jets.ptgev

    call_node, ctx = render(column)

    # The column access renders as `callable(<attribute>)`.
    assert isinstance(call_node, ast.Call)
    assert isinstance(call_node.func, ast_Callable)
    assert len(call_node.args) == 1
    call_arg = call_node.args[0]
    assert isinstance(call_arg, ast.Attribute)

    body, _ = render_callable(
        call_node.func, ctx,
        call_node.func.dataframe)  # type: ignore

    # `j.pt / 1000`: the left operand is an attribute hanging off the very
    # node that was passed as the call argument.
    assert isinstance(body, ast.BinOp)
    assert isinstance(body.left, ast.Attribute)
    assert body.left.value is call_arg
def test_callable_context_no_update():
    """Rendering a callable must not modify the context passed in.

    The sizes of the input context's ``_seen_datasources`` and ``_resolved``
    tables are recorded before ``render_callable`` runs. Afterwards the
    returned context must differ from the input in at least one table, and
    the input context must be unchanged in *both* tables.
    """
    d = DataFrame()
    d1 = d.apply(lambda b: b.jets.pt)
    expr, ctx = render(d1)

    assert isinstance(expr, ast.Call)
    arg1: ast.AST = expr.args[0]
    assert isinstance(arg1, ast_Callable)

    # Snapshot the sizes first so that mutation of `ctx` is detectable.
    seen_ds = len(ctx._seen_datasources)
    resolved = len(ctx._resolved)

    _, new_ctx = render_callable(arg1, ctx, d)

    # The returned context picked up something new in at least one table.
    assert len(new_ctx._seen_datasources) != len(ctx._seen_datasources) \
        or len(new_ctx._resolved) != len(ctx._resolved)

    # The input context is untouched. These used to be joined with `or`,
    # which let the test pass when only one of the two tables was preserved.
    assert seen_ds == len(ctx._seen_datasources)
    assert resolved == len(ctx._resolved)
Ejemplo n.º 13
0
    def call_mapseq(self, node: ast.Call, value: ast_awkward) -> ast.AST:
        """Expand a ``mapseq`` call by rendering its callable against *value*.

        If the single argument is not an ``ast_Callable`` the node is returned
        unchanged. Otherwise the callable is rendered against a placeholder
        ``DataFrame``, the placeholder is replaced by *value*, and the result
        is visited by this processor under the context rendering returned.
        """
        assert len(node.args) == 1, 'mapseq takes only one argument'
        func_node = node.args[0]
        if not isinstance(func_node, ast_Callable):
            return node

        # Render against a stand-in dataframe to get the generic behavior,
        # then substitute the awkward array for the stand-in.
        placeholder = DataFrame()
        rendered, inner_context = render_callable(
            cast(ast_Callable, func_node), self._context, placeholder)
        rewritten = _replace_dataframe(rendered, placeholder, value)

        # Continue evaluation through this processor using the callable's
        # context; restore ours afterwards even if the visit raises.
        saved_context = self._context
        self._context = inner_context
        try:
            return self.visit(rewritten)
        finally:
            self._context = saved_context
Ejemplo n.º 14
0
def _render_callable(a: ast.AST, callable: ast_Callable,
                     context: render_context,
                     tracker: _statement_tracker) -> term_info:
    """Render `callable` and record the resulting statement(s) on `tracker`.

    The callable is rendered against its own `callable.dataframe`. What happens
    next depends on what `_find_root_expr` reports for the rendered expression
    relative to `tracker.sequence._ast`:

    * it is the current sequence's AST: the expression is rendered on top of
      the (possibly unwrapped) current sequence and the new statements are
      appended to the tracker;
    * it is some other non-None expression: that expression is carried forward
      as a monad and a `statement_select` is appended;
    * it is None: the expression is handed to
      `_render_expresion_as_transform`.

    Args:
        a: AST node passed to `statement_select` in the monad branch; unused
            in the other two branches.
        callable: The callable to render.
        context: Render context passed to `render_callable`.
        tracker: Statement tracker; its `statements` and `sequence` may be
            updated.

    Returns:
        A `term_info`. In the first branch it is the one returned by
        `_render_expression` (asserted to have term 'main_sequence'); in the
        other two it is a new 'main_sequence' term carrying the result type.
    """
    # And the thing we want to call we can now render.
    expr, new_context = render_callable(callable, context, callable.dataframe)

    # In that expr there may be captured variables, or references to things that
    # are not in `value`. If that is the case, that means we need to add a monad to fetch
    # them from earlier in the process.
    root_expr = _find_root_expr(expr, tracker.sequence._ast)
    if root_expr is tracker.sequence._ast:
        # Just continuing on with the sequence already in place.
        assert _is_list(tracker.sequence.result_type)
        # NOTE(review): the commented-out code below looks like an earlier
        # form of the unwrap logic that follows; kept as found.
        # or isinstance(tracker.sequence, statement_unwrap_list)
        # if _is_list(tracker.sequence.result_type):
        #     s, t = _render_expression(
        #         statement_unwrap_list(tracker.sequence._ast, tracker.sequence.result_type),
        #         expr, new_context, tracker)
        # else:
        #     s, t = _render_expression(tracker.sequence, expr, new_context, tracker)
        seq = tracker.sequence.unwrap_if_possible()
        s, t = _render_expression(seq, expr, new_context, tracker)
        assert t.term == 'main_sequence'
        # The expression was rendered against the unwrapped sequence, so each
        # resulting statement is wrapped again when the sequence is a list.
        if _is_list(tracker.sequence.result_type):
            s = [smt.wrap() for smt in s]
        # Only advance the tracker if rendering produced any statements.
        if len(s) > 0:
            tracker.statements += s
            tracker.sequence = s[-1]
        return t

    elif root_expr is not None:
        # The expression is rooted somewhere other than the current sequence:
        # carry that root forward and refer to it through a monad reference.
        monad_index = tracker.carry_monad_forward(root_expr)
        monad_ref = _monad_manager.new_monad_ref()

        # Create a pointer to the base monad - which is an object
        with tracker.substitute_ast(
                root_expr,
                _ast_VarRef(
                    term_info(f'{monad_ref}[{monad_index}]', object,
                              [monad_ref]))):

            # The var we are going to loop over is a pointer to the sequence.
            seq_as_object = tracker.sequence.unwrap_if_possible()
            select_var = new_term(seq_as_object.result_type)
            select_var_rep_ast = _ast_VarRef(select_var)

            # While resolving, references to the current sequence's AST are
            # replaced by the select variable.
            with tracker.substitute_ast(tracker.sequence._ast,
                                        select_var_rep_ast):
                trm = _resolve_expr_inline(seq_as_object, expr, new_context,
                                           tracker)

        # Derive the select's result type from the sequence's type via
        # `_type_replace` using the select variable's and resolved term's types.
        result_type = _type_replace(tracker.sequence.result_type,
                                    select_var.type, trm.type)
        st = statement_select(a, tracker.sequence.result_type, result_type,
                              select_var, trm)
        # If the resolved term uses monads, flag that on the new statement and
        # give it the monad reference.
        if trm.has_monads():
            st.prev_statement_is_monad()
            st.set_monad_ref(monad_ref)

        tracker.statements.append(st)
        tracker.sequence = st
        return term_info('main_sequence', st.result_type)

    else:
        # If root_expr is none, then whatever it is is a constant. So just select it.
        # Note this branch passes the original `context`, not `new_context`.
        _render_expresion_as_transform(tracker, context, expr)
        return term_info('main_sequence', tracker.sequence.result_type)
def find_callable_and_render(
        expr: ast.AST, ctx: render_context) -> Tuple[ast.AST, render_context]:
    """Locate a callable inside *expr* and render it against its own dataframe.

    Runs a ``find_callable`` visitor over *expr*, asserts that it recorded a
    callable, and returns whatever ``render_callable`` returns for it.
    """
    finder = find_callable()
    finder.visit(expr)

    target = finder.callable
    assert target is not None
    return render_callable(target, ctx, target.dataframe)