def test_computed_col_def_updated():
    """Render the computed ``ptgev`` column, redefine it, then render it again.

    Each pass checks that the column renders to a one-argument call whose
    function is an ``ast_Callable``, and that the callable itself can be
    rendered with ``render_callable``.
    """
    frame = DataFrame()

    # First definition of the computed column.
    frame.jets['ptgev'] = lambda j: j.pt / 1000
    column = frame.jets.ptgev
    call, call_context = render(column)
    assert isinstance(call, ast.Call)
    assert isinstance(call.func, ast_Callable)
    assert len(call.args) == 1
    body, _ = render_callable(call.func, call_context, call.func.dataframe)  # type: ignore

    # Run the render again.
    frame.jets['ptgev'] = lambda j: j.pt / 1001
    column = frame.jets.ptgev
    call, call_context = render(column)
    assert isinstance(call, ast.Call)
    assert isinstance(call.func, ast_Callable)
    assert len(call.args) == 1
    body, _ = render_callable(call.func, call_context, call.func.dataframe)  # type: ignore
def test_callable_wrong_number_args():
    """Rendering a one-parameter callable with two arguments must raise."""
    frame = DataFrame()
    applied = frame.apply(lambda b: b)
    call, context = render(applied)
    assert isinstance(call, ast.Call)

    func_arg: ast.AST = call.args[0]
    assert isinstance(func_arg, ast_Callable)

    # The lambda accepts a single parameter; two are supplied here.
    with pytest.raises(Exception):
        render_callable(func_arg, context, frame, frame)
def test_callable_returns_const():
    """A callable that returns the literal ``20`` renders to a constant node."""
    d = DataFrame()
    d1 = d.apply(lambda b: 20)
    expr, ctx = render(d1)
    assert isinstance(expr, ast.Call)

    arg1 = expr.args[0]  # type: ast.AST
    assert isinstance(arg1, ast_Callable)

    expr1, _ = render_callable(arg1, ctx, d)

    # ``ast.Num`` and its ``.n`` attribute are deprecated aliases (since
    # Python 3.8) of ``ast.Constant`` / ``.value`` and are removed in
    # Python 3.14, so check the modern node type directly.
    assert isinstance(expr1, ast.Constant)
    assert expr1.value == 20
def test_callable_simple_call():
    """An identity lambda applied to a dataframe renders back to that dataframe."""
    frame = DataFrame()
    applied = frame.apply(lambda b: b)
    call, context = render(applied)
    assert isinstance(call, ast.Call)

    func_arg: ast.AST = call.args[0]
    assert isinstance(func_arg, ast_Callable)

    # Calling the lambda with the frame itself should hand the frame's node back.
    body, _ = render_callable(func_arg, context, frame)
    assert isinstance(body, ast_DataFrame)
    assert body.dataframe is frame
def test_render_callable_twice_for_same_results():
    """Rendering one callable twice from the same context gives matching results.

    The two renders must produce identical AST dumps and contexts whose
    ``_seen_datasources`` and ``_resolved`` collections have equal sizes.
    """
    df = DataFrame()
    eles = df.Electrons()
    mc_part = df.TruthParticles()
    mc_ele = mc_part[mc_part.pdgId == 11]
    good_mc_ele = mc_ele[mc_ele.ptgev > 20]
    ele_mcs = eles.map(lambda reco_e: good_mc_ele)
    rendered, context = render(ele_mcs)

    class find_callable(ast.NodeVisitor):
        """Remember the ``ast_Callable`` node encountered while walking a tree."""

        def __init__(self):
            super().__init__()
            self._callable: Optional[ast_Callable] = None

        def visit_ast_Callable(self, a: ast_Callable):
            self._callable = a

        @classmethod
        def findit(cls, a: ast.AST) -> Optional[ast_Callable]:
            finder = find_callable()
            finder.visit(a)
            return finder._callable

    found = find_callable.findit(rendered)
    assert found is not None

    # Render the very same callable twice, starting from the same context.
    first_expr, first_context = render_callable(found, context, found.dataframe)
    second_expr, second_context = render_callable(found, context, found.dataframe)

    assert ast.dump(first_expr) == ast.dump(second_expr)
    assert len(first_context._seen_datasources) == len(
        second_context._seen_datasources)
    assert len(first_context._resolved) == len(second_context._resolved)
def test_callable_captures_column():
    """A lambda that captures another column of the frame renders to a comparison."""
    d = DataFrame()
    applied = d.jets.apply(lambda b: d.met > 20.0)
    call, context = render(applied)
    assert isinstance(call, ast.Call)

    # The call is made on an attribute whose own value is an attribute node.
    assert isinstance(call.func, ast.Attribute)
    call_root = call.func.value
    assert isinstance(call_root, ast.Attribute)

    func_arg: ast.AST = call.args[0]
    assert isinstance(func_arg, ast_Callable)

    body, _ = render_callable(func_arg, context, d.jets)
    assert isinstance(body, ast.Compare)
def test_callable_returns_matched_ast():
    """An identity lambda on ``d.jets`` renders to the node the call was made on."""
    frame = DataFrame()
    applied = frame.jets.apply(lambda b: b)
    call, context = render(applied)
    assert isinstance(call, ast.Call)

    assert isinstance(call.func, ast.Attribute)
    call_root = call.func.value
    assert isinstance(call_root, ast.Attribute)

    func_arg: ast.AST = call.args[0]
    assert isinstance(func_arg, ast_Callable)

    # The rendered body must be the very same node object, not merely an equal one.
    body, _ = render_callable(func_arg, context, frame.jets)
    assert call_root is body
def test_callable_function():
    """A named function that returns its argument renders to the frame's own node."""

    def test_func(b):
        return b

    frame = DataFrame()
    applied = frame.apply(test_func)
    call, context = render(applied)
    assert isinstance(call, ast.Call)

    func_arg: ast.AST = call.args[0]
    assert isinstance(func_arg, ast_Callable)

    body, _ = render_callable(func_arg, context, frame)
    assert isinstance(body, ast_DataFrame)
    assert body.dataframe is frame
def test_callable_context():
    """Rendering with the callable's own ``dataframe`` gives the call's root node."""
    frame = DataFrame()
    applied = frame.jets.apply(lambda b: b)
    call, context = render(applied)
    assert isinstance(call, ast.Call)

    func_arg: ast.AST = call.args[0]
    assert isinstance(func_arg, ast_Callable)

    # Use the dataframe recorded on the callable rather than ``frame.jets``.
    body, _ = render_callable(func_arg, context, func_arg.dataframe)

    assert isinstance(call.func, ast.Attribute)
    call_root = call.func.value
    assert isinstance(call_root, ast.Attribute)
    assert call_root is body
def visit_Call(self, a: ast.Call):
    """Expand a call whose function is an ``ast_Callable``; defer all other calls.

    Calls on anything else go to ``generic_visit``. For an ``ast_Callable``
    the callable is rendered against its own ``dataframe`` and the resulting
    expression is visited under the context that rendering returned.
    """
    func = a.func
    if not isinstance(func, ast_Callable):
        return self.generic_visit(a)

    assert len(a.args) == 1
    rendered, callable_context = render_callable(
        cast(ast_Callable, func), self._context, func.dataframe)  # type: ignore

    # Swap in the callable's context only for the duration of the visit.
    saved_context = self._context
    try:
        self._context = callable_context
        return self.visit(rendered)
    finally:
        self._context = saved_context
def test_lambda_for_computed_col():
    """A lambda-defined computed column renders to a ``BinOp`` rooted at the call argument."""
    frame = DataFrame()
    frame.jets['ptgev'] = lambda j: j.pt / 1000
    column = frame.jets.ptgev

    call, call_context = render(column)
    assert isinstance(call, ast.Call)
    assert isinstance(call.func, ast_Callable)
    assert len(call.args) == 1
    call_arg = call.args[0]
    assert isinstance(call_arg, ast.Attribute)

    body, _ = render_callable(call.func, call_context, call.func.dataframe)  # type: ignore
    assert isinstance(body, ast.BinOp)
    assert isinstance(body.left, ast.Attribute)
    # The left operand must hang off the exact node passed as the call argument.
    assert body.left.value is call_arg
def test_callable_context_no_update():
    """Rendering a callable must not modify the context passed in.

    The context returned by ``render_callable`` is expected to differ in
    size from the original in at least one of ``_seen_datasources`` or
    ``_resolved``, while both collections on the original context keep the
    sizes they had before the call.
    """
    d = DataFrame()
    d1 = d.apply(lambda b: b.jets.pt)
    expr, ctx = render(d1)
    assert isinstance(expr, ast.Call)

    arg1 = expr.args[0]  # type: ast.AST
    assert isinstance(arg1, ast_Callable)

    # Snapshot the original context's sizes before rendering the callable.
    seen_ds = len(ctx._seen_datasources)
    resolved = len(ctx._resolved)

    expr1, new_ctx = render_callable(arg1, ctx, d)

    # The new context differs from the original in at least one collection.
    assert len(new_ctx._seen_datasources) != len(ctx._seen_datasources) \
        or len(new_ctx._resolved) != len(ctx._resolved)

    # The original context must be untouched in BOTH collections. Joining
    # these with `or` would let a mutation of one of them go unnoticed.
    assert seen_ds == len(ctx._seen_datasources)
    assert resolved == len(ctx._resolved)
def call_mapseq(self, node: ast.Call, value: ast_awkward) -> ast.AST:
    """Evaluate a ``mapseq`` call whose single argument is an ``ast_Callable``.

    If the argument is not an ``ast_Callable`` the node is returned
    unchanged. Otherwise the callable is rendered against a fresh
    ``DataFrame``, that dataframe is replaced by ``value`` in the result,
    and the expression is visited under the context rendering returned.
    """
    assert len(node.args) == 1, 'mapseq takes only one argument'
    mapper = node.args[0]
    if not isinstance(mapper, ast_Callable):
        return node

    # Render against a stand-in dataframe to get generic behavior, then
    # substitute the awkward array for it.
    placeholder = DataFrame()
    df_expr, callable_context = render_callable(
        cast(ast_Callable, mapper), self._context, placeholder)
    awkward_expr = _replace_dataframe(df_expr, placeholder, value)

    # Continue the evaluation through this processor under the new context.
    saved_context = self._context
    self._context = callable_context
    try:
        return self.visit(awkward_expr)
    finally:
        self._context = saved_context
def _render_callable(a: ast.AST, callable: ast_Callable, context: render_context,
                     tracker: _statement_tracker) -> term_info:
    """Render `callable` and fold the result into the tracker's statement sequence.

    The callable is rendered against its own `dataframe`, and the root of the
    resulting expression decides how it is applied:

    * root is the current sequence's AST - the expression is rendered as a
      continuation of that sequence;
    * root is some other expression - that expression is carried forward as a
      monad and a select statement is appended that refers to it;
    * no root (`None`) - the expression is rendered as a transform.

    Args:
        a: AST node handed to `statement_select` for the select that is built
            in the monad case.
        callable: The callable to render.
        context: Render context passed to `render_callable`.
        tracker: Statement tracker; its `statements` and `sequence` are
            updated in place.

    Returns:
        The term produced by `_render_expression` in the first case, otherwise
        a `main_sequence` term carrying the new sequence's result type.
    """
    # And the thing we want to call we can now render.
    expr, new_context = render_callable(callable, context, callable.dataframe)

    # In that expr there may be captured variables, or references to things that
    # are not in `value`. If that is the case, that means we need to add a monad to fetch
    # them from earlier in the process.
    root_expr = _find_root_expr(expr, tracker.sequence._ast)
    if root_expr is tracker.sequence._ast:
        # Just continuing on with the sequence already in place.
        assert _is_list(tracker.sequence.result_type)
        # or isinstance(tracker.sequence, statement_unwrap_list)
        # if _is_list(tracker.sequence.result_type):
        #     s, t = _render_expression(
        #         statement_unwrap_list(tracker.sequence._ast, tracker.sequence.result_type),
        #         expr, new_context, tracker)
        # else:
        #     s, t = _render_expression(tracker.sequence, expr, new_context, tracker)
        seq = tracker.sequence.unwrap_if_possible()
        s, t = _render_expression(seq, expr, new_context, tracker)
        assert t.term == 'main_sequence'
        # The statements were rendered against the unwrapped sequence, so wrap
        # each of them again when the tracked sequence is a list.
        if _is_list(tracker.sequence.result_type):
            s = [smt.wrap() for smt in s]
        # Only advance the tracker when something was actually rendered.
        if len(s) > 0:
            tracker.statements += s
            tracker.sequence = s[-1]
        return t
    elif root_expr is not None:
        monad_index = tracker.carry_monad_forward(root_expr)
        monad_ref = _monad_manager.new_monad_ref()

        # NOTE(review): the nesting of the `with` and `if` bodies below was
        # reconstructed from whitespace-collapsed text - confirm which
        # statements belong inside each block.
        # Create a pointer to the base monad - which is an object
        with tracker.substitute_ast(
                root_expr,
                _ast_VarRef(
                    term_info(f'{monad_ref}[{monad_index}]', object, [monad_ref]))):
            # The var we are going to loop over is a pointer to the sequence.
            seq_as_object = tracker.sequence.unwrap_if_possible()
            select_var = new_term(seq_as_object.result_type)
            select_var_rep_ast = _ast_VarRef(select_var)

            # While resolving, references to the sequence's AST are replaced
            # by the select variable.
            with tracker.substitute_ast(tracker.sequence._ast, select_var_rep_ast):
                trm = _resolve_expr_inline(seq_as_object, expr, new_context, tracker)

        result_type = _type_replace(tracker.sequence.result_type, select_var.type, trm.type)
        st = statement_select(a, tracker.sequence.result_type, result_type, select_var, trm)
        if trm.has_monads():
            st.prev_statement_is_monad()
            st.set_monad_ref(monad_ref)
        tracker.statements.append(st)
        tracker.sequence = st
        return term_info('main_sequence', st.result_type)
    else:
        # If root_expr is none, then whatever it is is a constant. So just select it.
        # NOTE(review): this passes the original `context`, not `new_context`
        # as the other two branches do - confirm that is intended.
        _render_expresion_as_transform(tracker, context, expr)
        return term_info('main_sequence', tracker.sequence.result_type)
def find_callable_and_render(
        expr: ast.AST, ctx: render_context) -> Tuple[ast.AST, render_context]:
    """Find the callable inside `expr` and render it against its own dataframe.

    Args:
        expr: Expression tree to search with `find_callable`.
        ctx: Render context handed to `render_callable`.

    Returns:
        Whatever `render_callable` returns for the callable that was found.
    """
    finder = find_callable()
    finder.visit(expr)

    found = finder.callable
    assert found is not None
    return render_callable(found, ctx, found.dataframe)