def resolve_generator(self, lambda_body: ast.AST, generators: List[ast.comprehension], node: ast.AST) -> ast.AST:
    """Rewrite a list comprehension / generator expression as Select calls.

    `[j.pt() for j in jets]` becomes `jets.Select(lambda j: j.pt())`.

    The `for` clauses are processed from last to first. Each clause's `if`
    conditions become chained `Where` calls on its iterable. The expression
    built so far becomes the body of the next (outer) `Select` lambda, so
    chained generators produce nested `Select` calls.

    Args:
        lambda_body (ast.AST): The per-element expression of the comprehension.
        generators (List[ast.comprehension]): The comprehension's `for` clauses.
        node (ast.AST): The original comprehension node. It is used in error
            messages and returned unchanged when `generators` is empty.

    Returns:
        ast.AST: The outermost `Select` call, or `node` if there are no generators.

    Raises:
        ValueError: If a loop target is not a plain name, or a clause is async.
    """
    def method_call(source: ast.AST, method: str, function: ast.AST) -> ast.Call:
        # Build the AST for `source.method(function)`.
        return ast.Call(
            func=ast.Attribute(attr=method, value=source, ctx=ast.Load()),
            args=[function],
            keywords=[],
        )

    result = node
    body = lambda_body
    for comp in reversed(generators):
        loop_var = comp.target
        if not isinstance(loop_var, ast.Name):
            raise ValueError(
                f"Comprehension variable must be a name, but found {loop_var}"
                f" - {unparse_ast(node)}."
            )
        if comp.is_async:
            raise ValueError(f"Comprehension can't be async - {unparse_ast(node)}.")

        # Every `if` clause filters the iterable before the Select is applied.
        source = comp.iter
        for condition in comp.ifs:
            source = method_call(source, "Where", lambda_build(loop_var.id, condition))

        result = method_call(source, "Select", lambda_build(loop_var.id, body))
        # The next (outer) generator maps over the Select just built.
        body = result
    return result
def visit_SelectMany_of_SelectMany(self, parent: ast.Call, selection: ast.Lambda):
    '''
    Re-associate two chained SelectMany calls (Transformation #1):

        seq.SelectMany(x: f(x)).SelectMany(y: g(y))
        => SelectMany(SelectMany(seq, x: f(x)), y: g(y))

    is turned into:

        seq.SelectMany(x: f(x).SelectMany(y: g(y)))
        => SelectMany(seq, x: SelectMany(f(x), y: g(y)))

    Args:
        parent: The inner `SelectMany(seq, x: f(x))` call; it must have exactly
            two arguments, the second being a lambda.
        selection: The outer lambda `y: g(y)`.

    Returns:
        The rebuilt `SelectMany(seq, x: SelectMany(f(x), y: g(y)))` call.
    '''
    _, inner_args = unpack_Call(parent)
    assert (inner_args is not None) and len(inner_args) == 2
    seq = inner_args[0]
    outer_lambda = inner_args[1]
    assert isinstance(outer_lambda, ast.Lambda)

    bound_name = outer_lambda.args.args[0].arg
    # NOTE(review): `selection` is nested inside a lambda that binds
    # `bound_name` without renaming; a free variable in `selection` with that
    # same name would be captured. Confirm that arg names are unique upstream.
    nested = function_call(
        'SelectMany', [cast(ast.AST, outer_lambda.body), cast(ast.AST, selection)])
    wrapper = lambda_build(bound_name, nested)
    return function_call('SelectMany', [seq, cast(ast.AST, wrapper)])
def select_method_call_on_first(self, node: ast.Call):
    """Turn First(seq).method(args) into First(Select(seq, s: s.method(args))).

    The method call is moved inside a Select over the sequence that First was
    applied to. The rebuilt First(...) call is then visited again.

    Args:
        node (ast.Call): A call whose ``func`` is an attribute access on a
            ``First(seq)`` call, i.e. ``First(seq).method(...)``.

    Returns:
        The result of visiting the rebuilt ``First(Select(...))`` call.
    """
    # Extract the call info
    assert isinstance(node.func, ast.Attribute)
    method_name = node.func.attr
    method_args = node.args
    # Fallback for ast.Call nodes without `keywords` (presumably legacy
    # Python AST layouts); kept for compatibility.
    method_keywords = (
        node.keywords if hasattr(node, "keywords") else node.kwargs  # type: ignore
    )
    assert isinstance(node.func.value, ast.Call)
    # seq is the first positional argument of the First(...) call.
    seq = node.func.value.args[0]

    # Now rebuild the call as `a.method(args)` on a fresh lambda argument `a`.
    a = arg_name()
    call_args = {
        # ctx is required on Attribute for the AST to compile before Python
        # 3.13; this also matches how the other rewrites build Attribute nodes.
        "func": ast.Attribute(
            value=ast.Name(a, ast.Load()), attr=method_name, ctx=ast.Load()
        ),
        "args": method_args,
    }
    if hasattr(node, "keywords"):
        call_args["keywords"] = method_keywords
    else:
        call_args["kwargs"] = method_keywords
    seq_a_call = ast.Call(**call_args)
    select = make_Select(seq, lambda_build(a, seq_a_call))
    return self.visit(function_call("First", [cast(ast.AST, select)]))
def convolute(ast_g: ast.Lambda, ast_f: ast.Lambda):
    """Compose two lambda ASTs into one representing `lambda x: g(f(x))`.

    Both lambdas are passed through `lambda_unwrap` and `make_args_unique`
    (presumably to avoid argument-name clashes) before nesting. The new
    outer lambda takes a fresh argument name from `arg_name()`.
    """
    outer = make_args_unique(lambda_unwrap(ast_g))
    inner = make_args_unique(lambda_unwrap(ast_f))
    var = arg_name()
    # f(x), then g(f(x)), wrapped as `lambda x: ...`.
    inner_call = ast.Call(inner, [ast.Name(var, ast.Load())], [])
    outer_call = ast.Call(outer, [inner_call], [])
    return lambda_build(var, outer_call)
def visit_Attribute_Of_First(self, first: ast.AST, attr: str):
    """
    Convert a seq.First().attr ==> seq.Select(l: l.attr).First()

    Other work will do the conversion as needed.

    Args:
        first: The sequence `seq` that `First` was applied to. It becomes the
            source of the new `Select`.
        attr: Name of the attribute read from the first element.

    Returns:
        The result of visiting the rebuilt `First(Select(seq, l: l.attr))` call.
    """
    # Build the select that starts from the source and reads the attribute.
    a = arg_name()
    select = make_Select(
        first,
        lambda_build(
            a,
            # ctx is required on Attribute for the AST to compile before
            # Python 3.13; matches the Subscript rewrite, which passes Load().
            ast.Attribute(value=ast.Name(a, ast.Load()), attr=attr, ctx=ast.Load()),
        ),
    )
    return self.visit(function_call("First", [cast(ast.AST, select)]))
def visit_Subscript_Of_First(self, first: ast.AST, s):
    '''
    Convert a seq.First()[0] ==> seq.Select(l: l[0]).First()

    Other work will do the conversion as needed.

    Args:
        first: The sequence `seq` that `First` was applied to. It becomes the
            source of the new `Select`.
        s: The slice/index expression applied to the first element.

    Returns:
        The result of visiting the rebuilt `First(Select(seq, l: l[s]))` call.
    '''
    var = arg_name()
    # `l[s]` for a fresh lambda argument `l`.
    indexed = ast.Subscript(ast.Name(var, ast.Load()), s, ast.Load())
    per_element = lambda_build(var, indexed)
    select = make_Select(first, per_element)
    return self.visit(function_call('First', [cast(ast.AST, select)]))
def visit_Where_of_Where(self, parent, filter):
    '''
    Merge two chained Where filters into one:

        seq.Where(x: f(x)).Where(y: g(y))
        => Where(Where(seq, x: f(x)), y: g(y))

    is turned into

        seq.Where(x: f(x) and g(x))
        => Where(seq, x: f(x) and g(x))

    `parent` is the inner Where node (its `.source` and `.filter` are read);
    `filter` is the outer predicate lambda.
    '''
    var = arg_name()
    # Both predicates are applied to the same fresh argument and and-ed together.
    both = ast.BoolOp(
        ast.And(),
        [lambda_call(var, parent.filter), lambda_call(var, filter)])
    return self.visit(Where(parent.source, lambda_build(var, both)))
def visit_Where_of_Where(self, parent: ast.Call, filter: ast.Lambda):
    '''
    Merge two chained Where filters into one:

        seq.Where(x: f(x)).Where(y: g(y))
        => Where(Where(seq, x: f(x)), y: g(y))

    is turned into

        seq.Where(x: f(x) and g(x))
        => Where(seq, x: f(x) and g(x))

    Args:
        parent: The inner `Where(seq, x: f(x))` call.
        filter: The outer predicate lambda `y: g(y)`.

    Returns:
        The result of visiting the merged `Where(seq, x: f(x) and g(x))` call.
    '''
    _, inner_args = unpack_Call(parent)
    seq = inner_args[0]
    first_filter = inner_args[1]
    assert isinstance(first_filter, ast.Lambda)

    var = arg_name()
    # Both predicates are applied to the same fresh argument and and-ed together.
    both = ast.BoolOp(
        ast.And(), [lambda_call(var, first_filter), lambda_call(var, filter)])
    combined: ast.AST = lambda_build(var, both)
    return self.visit(function_call('Where', [seq, combined]))
def test_lambda_build_proper():
    "lambda_build must produce the same AST the running Python's parser produces"
    body = ast.parse("x+1").body[0].value  # type: ignore
    built = lambda_build("x", body)
    expected = ast.parse("lambda x: x+1").body[0].value  # type: ignore
    assert ast.dump(expected) == ast.dump(built)
def test_lambda_build_list_arg():
    "lambda_build accepts the argument names as a list."
    # Use the expression node, not the Module that ast.parse returns, so the
    # lambda body is a valid expression.
    expr = ast.parse("x+1").body[0].value  # type: ignore
    ln = lambda_build(["x"], expr)
    assert isinstance(ln, ast.Lambda)
    # The list of names must actually end up as the lambda's arguments.
    assert [a.arg for a in ln.args.args] == ["x"]
    assert ast.dump(ln.body) == ast.dump(expr)
def test_lambda_build_single_arg():
    "lambda_build accepts a single argument name given as a plain string."
    # Use the expression node, not the Module that ast.parse returns, so the
    # lambda body is a valid expression.
    expr = ast.parse("x+1").body[0].value  # type: ignore
    ln = lambda_build("x", expr)
    assert isinstance(ln, ast.Lambda)
    # The string must become the single lambda argument.
    assert [a.arg for a in ln.args.args] == ["x"]
    assert ast.dump(ln.body) == ast.dump(expr)