def test_second_dr_returns_filtered():
    '''Render a column defined on a filtered collection (filter and value both built
    from a back-end user function), then recursively expand every `ast_Callable`
    found in the rendered expression and check the expansion yields a result.
    '''
    df = DataFrame()

    @user_func
    def DeltaR(p1_eta: float) -> float:
        '''
        Calculate the DeltaR between two particles given their `eta` and `phi` locations.
        Implemented on the back end.
        '''
        assert False, 'This should never be called'

    mc_part = df.TruthParticles('TruthParticles')
    eles = df.Electrons('Electrons')

    def dr(e, mc):
        '''Make calculating DR easier as I have a hard-to-use DR calculation function on
        the back end'''
        return DeltaR(e.eta())

    def very_near2(mcs, e):
        'Return all particles in mcs that are DR less than 0.1'
        return mcs[lambda m: dr(e, m) < 0.1]

    eles['near_mcs'] = lambda reco_e: very_near2(mc_part, reco_e)

    eles['hasMC'] = lambda e: e.near_mcs.Count() > 0
    good_eles_with_mc = eles[eles.hasMC]
    good_eles_with_mc['mc'] = lambda e: e.near_mcs.First().ptgev

    d1 = good_eles_with_mc.mc

    expr_1, context_1 = render(d1)

    class render_in_depth(ast.NodeTransformer):
        '''Replace each call whose function is an `ast_Callable` with its rendered body,
        recursing into that body with the callable's own render context.'''

        def __init__(self, context):
            super().__init__()
            self._context = context

        def visit_Call(self, a: ast.Call):
            if not isinstance(a.func, ast_Callable):
                return self.generic_visit(a)

            # This test only expects single-argument callables. NOTE: the call's
            # argument itself is not visited here.
            assert len(a.args) == 1

            func = cast(ast_Callable, a.func)
            expr, new_context = render_callable(func, self._context, func.dataframe)

            # Visit the rendered body under the callable's context and always restore
            # the outer context, even if the nested visit raises.
            old_context = self._context
            try:
                self._context = new_context
                return self.visit(expr)
            finally:
                self._context = old_context

    assert isinstance(expr_1, ast.Call)

    rendered = render_in_depth(context_1).visit(expr_1)
    assert rendered is not None


def test_render_twice_for_same_results():
    '''Rendering one expression twice must produce identical ASTs and render
    contexts with the same number of resolved entries and seen data sources.
    '''
    frame = DataFrame()
    electrons = frame.Electrons()
    truth = frame.TruthParticles()
    truth_electrons = truth[truth.pdgId == 11]
    good_mc_ele = truth_electrons[truth_electrons.ptgev > 20]

    mapped = electrons.map(lambda reco_e: good_mc_ele)

    first_expr, first_ctx = render(mapped)
    second_expr, second_ctx = render(mapped)

    assert ast.dump(first_expr) == ast.dump(second_expr)
    # Compare context sizes attribute by attribute, in the same order as before.
    for attr_name in ('_resolved', '_seen_datasources'):
        assert len(getattr(first_ctx, attr_name)) == len(getattr(second_ctx, attr_name))


def test_nested_col_access():
    '''A column defined on a filtered frame must come back as an unfiltered expression
    whose `child_expr` is a single-argument call to an `ast_Callable`. The argument is
    an `ast_DataFrame` wrapping the very frame the callable is bound to.
    '''
    df = DataFrame()

    @user_func
    def DeltaR(p1_eta: float) -> float:
        '''
        Calculate the DeltaR between two particles given their `eta` and `phi` locations.
        Implemented on the back end.
        '''
        assert False, 'This should never be called'

    mc_part = df.TruthParticles('TruthParticles')
    eles = df.Electrons('Electrons')

    def dr(e, mc):
        '''Make calculating DR easier as I have a hard-to-use DR calculation function on the
        back end'''
        return DeltaR(e.eta())

    def very_near2(mcs, e):
        'Return all particles in mcs that are DR less than 0.1'
        return mcs[lambda m: dr(e, m) < 0.1]

    eles['near_mcs'] = lambda reco_e: very_near2(mc_part, reco_e)

    eles['hasMC'] = lambda e: e.near_mcs.Count() > 0
    good_eles_with_mc = eles[eles.hasMC]
    good_eles_with_mc['mc'] = lambda e: e.near_mcs.First().ptgev

    d1 = good_eles_with_mc.mc

    assert d1.filter is None
    assert d1.child_expr is not None
    assert isinstance(d1.child_expr, ast.Call)
    assert isinstance(d1.child_expr.func, ast_Callable)

    assert len(d1.child_expr.args) == 1
    assert isinstance(d1.child_expr.args[0], ast_DataFrame)
    p = cast(ast_DataFrame, d1.child_expr.args[0]).dataframe

    # The callable must be bound to the same frame object passed as its argument.
    assert cast(ast_Callable, d1.child_expr.func).dataframe is p


def test_different_callables_look_different():
    '''Two successive rounds of callable expansion must not produce the same AST.

    Regression test: a bug made every `ast_Callable` look the same, which
    sometimes resulted in a recursive reference.
    '''
    df = DataFrame()

    mc_part = df.TruthParticles('TruthParticles')
    eles = df.Electrons('Electrons')

    # Give every electron the full truth-particle collection, then flag electrons
    # whose collection is non-empty.
    eles['near_mcs'] = lambda reco_e: mc_part
    eles['hasMC'] = lambda e: e.near_mcs.Count() > 0

    pt_without_mc = eles[~eles.hasMC].pt
    base_expr, base_ctx = render(pt_without_mc)

    once_expr, once_ctx = find_callable_and_render(base_expr, base_ctx)
    twice_expr, _ = find_callable_and_render(once_expr, once_ctx)

    assert ast.dump(once_expr) != ast.dump(twice_expr)


def test_render_callable_twice_for_same_results():
    '''Rendering the same `ast_Callable` twice from the same context must produce
    identical ASTs and resulting contexts of equal size.
    '''
    df = DataFrame()
    eles = df.Electrons()
    mc_part = df.TruthParticles()
    mc_ele = mc_part[mc_part.pdgId == 11]
    good_mc_ele = mc_ele[mc_ele.ptgev > 20]

    ele_mcs = eles.map(lambda reco_e: good_mc_ele)

    expr, context = render(ele_mcs)

    class find_callable(ast.NodeVisitor):
        '''Locate an `ast_Callable` node in a tree; if several exist, the last one
        visited is returned.'''

        @classmethod
        def findit(cls, a: ast.AST) -> Optional[ast_Callable]:
            finder = cls()
            finder.visit(a)
            return finder._callable

        def __init__(self):
            super().__init__()
            self._callable: Optional[ast_Callable] = None

        def visit_ast_Callable(self, a: ast_Callable):
            # NodeVisitor dispatches on the node's class name (assumed to be
            # `ast_Callable`); the callable's children are not visited.
            self._callable = a

    # Named `found` rather than `callable` to avoid shadowing the builtin.
    found = find_callable.findit(expr)
    assert found is not None

    c_expr1, c_context1 = render_callable(found, context, found.dataframe)
    c_expr2, c_context2 = render_callable(found, context, found.dataframe)

    assert ast.dump(c_expr1) == ast.dump(c_expr2)
    assert len(c_context1._seen_datasources) == len(c_context2._seen_datasources)
    assert len(c_context1._resolved) == len(c_context2._resolved)