def visit_If(self, node: ast.If) -> cppast.IfStmt:
    """
    Translate a Python ``if`` statement into a ``cppast.IfStmt``.

    :param node: the ``ast.If`` node to translate
    :returns: an if statement built from the translated test expression,
        the translated body, and the translated else branch (``None`` when
        the Python statement has no ``else``/``elif`` part)
    """

    test_expr: cppast.Expr = self.visit(node.test)

    then_stmts = [self.visit(stmt) for stmt in node.body]
    then_block = cppast.CompoundStmt(then_stmts)

    # An empty orelse list means there is no else branch at all
    else_block = None
    if node.orelse:
        else_stmts = [self.visit(stmt) for stmt in node.orelse]
        else_block = cppast.CompoundStmt(else_stmts)

    return cppast.IfStmt(test_expr, then_block, else_block)
def visit_FunctionDef(
    self, node: ast.FunctionDef
) -> Union[cppast.ConstructorDecl, cppast.MethodDecl]:
    """
    Translate a Python method into a C++ constructor or method declaration.

    ``__init__`` is renamed to the name of ``node.parent`` and gets no return
    type. Any other function is given ``void`` or the type resolved from its
    return annotation. A declaration without a return type is emitted as a
    constructor, everything else as a method.

    :param node: the ``ast.FunctionDef`` node to translate
    :returns: a ``cppast.ConstructorDecl`` when no return type was
        determined, otherwise a ``cppast.MethodDecl``
    """

    decl_name: str = node.name
    return_type: Optional[cppast.ClassType]

    if decl_name == "__init__":
        # Constructors take the name of the enclosing class
        decl_name = node.parent.name
        return_type = None
    elif self.is_void_function(node):
        return_type = cppast.ClassType("void")
    else:
        return_type = visitors_util.get_type(node.returns, self.pk_import)

    # Only instance methods (first positional parameter named "self") are accepted
    positional = node.args.args
    if not positional or positional[0].arg != "self":
        self.error(node, "Static functions are not supported")

    params: List[cppast.ParmVarDecl] = self.visit(node.args)
    body = cppast.CompoundStmt([self.visit(stmt) for stmt in node.body])
    attributes: str = "KOKKOS_FUNCTION"

    if return_type is not None:
        return cppast.MethodDecl(attributes, return_type, decl_name, params, body)

    return cppast.ConstructorDecl(attributes, decl_name, params, body)
def visit_While(self, node: ast.While) -> cppast.WhileStmt:
    """
    Translate a Python ``while`` loop into a ``cppast.WhileStmt``.

    A ``while ... else`` construct is reported through ``self.error`` because
    it is not supported for translation.

    :param node: the ``ast.While`` node to translate
    :returns: a while statement built from the translated test expression
        and the translated body statements
    """

    if node.orelse:
        # ``orelse`` is a list of statements; report its first statement so
        # that the error is attached to an actual AST node rather than to a
        # Python list.
        self.error(node.orelse[0], "Else clause not supported for translation")

    condition: cppast.Expr = self.visit(node.test)
    body = cppast.CompoundStmt([self.visit(b) for b in node.body])

    return cppast.WhileStmt(condition, body)
def visit_For(self, node: ast.For) -> cppast.ForStmt:
    """
    Translate a Python ``for ... in range(...)`` loop into a ``cppast.ForStmt``.

    The loop variable is declared as an ``int`` initialised with the range
    start, the condition compares it against the range end, and the increment
    is ``index += step``. The comparison is ``<`` unless the step is written
    with a leading minus sign, in which case it is ``>``.

    Errors are reported through ``self.error`` when the loop target is not a
    single name, when the loop has an ``else`` clause, or when the iterator
    is anything other than a call to ``range``.

    :param node: the ``ast.For`` node to translate
    :returns: the translated for statement
    """

    if not isinstance(node.target, ast.Name):
        self.error(node.target, "Must use single loop variable")

    if node.orelse:
        # ``orelse`` is a list of statements; report its first statement so
        # that the error is attached to an actual AST node.
        self.error(node.orelse[0], "Else clause not supported for translation")

    # The callee must be a plain name before ``.id`` can be read: a method
    # call such as ``obj.items()`` has an ast.Attribute callee without ``id``.
    iterates_over_range: bool = (
        isinstance(node.iter, ast.Call)
        and isinstance(node.iter.func, ast.Name)
        and node.iter.func.id == "range"
    )
    if not iterates_over_range:  # TODO: support other iterators?
        self.error(
            node.iter, "Only range() iterator is supported for translation")

    index: cppast.DeclRefExpr = self.visit(node.target)
    start: cppast.Expr
    end: cppast.Expr
    step: cppast.Expr = cppast.IntegerLiteral(1)
    op = cppast.BinaryOperatorKind.LT

    args = node.iter.args
    if len(args) == 1:
        # range(end)
        start = cppast.IntegerLiteral(0)
        end = self.visit(args[0])
    else:
        # range(start, end[, step])
        start = self.visit(args[0])
        end = self.visit(args[1])

        if len(args) == 3:
            step = self.visit(args[2])

            # Negative step sizes are only handled correctly if they're
            # written with a preceding minus sign
            if (
                isinstance(args[2], ast.UnaryOp)
                and isinstance(args[2].op, ast.USub)
            ):
                op = cppast.BinaryOperatorKind.GT

    body = cppast.CompoundStmt([self.visit(b) for b in node.body])

    init = cppast.DeclStmt(cppast.VarDecl(
        cppast.PrimitiveType(cppast.BuiltinType.INT), index, start))
    condition = cppast.BinaryOperator(index, end, op)
    increment = cppast.BinaryOperator(
        index, step, cppast.BinaryOperatorKind.AddAssign)

    return cppast.ForStmt(init, condition, increment, body)
def visit_FunctionDef(
    self, node: ast.FunctionDef
) -> Union[str, Tuple[str, cppast.MethodDecl]]:
    """
    Translate a workunit definition.

    A function for which ``self.is_nested_call`` is true is turned into a
    ``[&]`` lambda that is stored in ``self.nested_work_units`` under the
    function's name, and an empty string is returned. Any other function
    becomes a const ``operator()`` method whose first parameter is an unnamed
    ``const <function name>`` reference used as a tag.

    :param node: the ``ast.FunctionDef`` node to translate
    :returns: ``""`` for a nested workunit, otherwise a tuple of the
        operation type and the generated method
    """

    if self.is_nested_call(node):
        lambda_params: List[cppast.ParmVarDecl] = list(self.visit(node.args))
        lambda_body = cppast.CompoundStmt(
            [self.visit(stmt) for stmt in node.body])

        self.nested_work_units[node.name] = cppast.LambdaExpr(
            "[&]", lambda_params, lambda_body)

        return ""

    operation: Optional[str] = self.get_operation_type(node)
    if operation is None:
        self.error(node.args, "Incorrect types in workunit definition")

    # Unnamed tag parameter: a const reference to a type named after the workunit
    tag_type = cppast.ClassType(f"const {node.name}")
    tag_type.is_reference = True
    tag_param = cppast.ParmVarDecl(tag_type, cppast.DeclRefExpr(""))

    method_params: List[cppast.ParmVarDecl] = [tag_param, *self.visit(node.args)]
    method_body = cppast.CompoundStmt([self.visit(stmt) for stmt in node.body])

    method = cppast.MethodDecl(
        "KOKKOS_FUNCTION",
        cppast.ClassType("void"),
        "operator()",
        method_params,
        method_body,
    )
    method.is_const = True

    return (operation, method)
def visit_AnnAssign(self, node: ast.AnnAssign) -> cppast.Stmt:
    """
    Translate an annotated assignment whose value is a call.

    Two kinds of calls are handled here: calls whose name appears in
    ``dir(TeamMember)``, which become a plain variable declaration, and
    nested ``parallel_reduce``/``parallel_scan`` calls, which become a
    declaration of the result variable followed by the Kokkos call that
    writes into it. Everything else is delegated to the parent visitor.

    :param node: the ``ast.AnnAssign`` node to translate
    :returns: the translated statement
    """

    if not isinstance(node.value, ast.Call):
        return super().visit_AnnAssign(node)

    var_type: cppast.Type = visitors_util.get_type(
        node.annotation, self.pk_import)
    if var_type is None:
        self.error(node, "Type not supported")

    var_ref: cppast.DeclRefExpr = self.visit(node.target)
    callee: str = visitors_util.get_node_name(node.value.func)

    # Call to a TeamMember method
    if callee in dir(TeamMember):
        return cppast.DeclStmt(
            cppast.VarDecl(var_type, var_ref, self.visit(node.value)))

    # Nested parallelism
    if callee in ("parallel_reduce", "parallel_scan"):
        call_args: List[cppast.Expr] = [
            self.visit(arg) for arg in node.value.args
        ]

        # The optional third argument is the initial value of the result
        initial_value: cppast.Expr = (
            call_args[2] if len(call_args) == 3 else cppast.IntegerLiteral(0)
        )
        declaration = cppast.DeclStmt(
            cppast.VarDecl(var_type, var_ref, initial_value))

        work_unit: str = call_args[1].declname
        kokkos_function = cppast.DeclRefExpr(f"Kokkos::{callee}")

        # Use the stored lambda for a nested workunit, otherwise refer to
        # the workunit through its pk_id_ name
        if work_unit in self.nested_work_units:
            functor_arg = self.nested_work_units[work_unit]
        else:
            functor_arg = f"pk_id_{work_unit}"

        kokkos_call: cppast.CallExpr = cppast.CallExpr(
            kokkos_function, [call_args[0], functor_arg, var_ref])

        return cppast.CompoundStmt(
            [declaration, cppast.CallStmt(kokkos_call)])

    return super().visit_AnnAssign(node)
def generate_constructor(
    name: str,
    fields: Dict[cppast.DeclRefExpr, cppast.PrimitiveType],
    views: Dict[cppast.DeclRefExpr, cppast.ClassType]
) -> cppast.ConstructorDecl:
    """
    Build the constructor of the functor

    Every field and every view that is not a subview becomes a constructor
    parameter, and the constructor body consists of the assignments produced
    by ``generate_assignments`` for those same members.

    :param name: the functor class name
    :param fields: a dict mapping from field name to type
    :param views: a dict mapping from view name to type (subviews map to None)
    :returns: the cppast representation of the constructor
    """

    params: List[cppast.ParmVarDecl] = [
        cppast.ParmVarDecl(field_type, field_name)
        for field_name, field_type in fields.items()
    ]

    # Subviews (type is None) do not get a constructor parameter
    params.extend(
        cppast.ParmVarDecl(get_view_type(view_type), view_name)
        for view_name, view_type in views.items()
        if view_type is not None
    )

    # Kokkos fails to compile a functor if there are no parameters in its constructor
    if not params:
        placeholder_name = cppast.DeclRefExpr("pk_field")
        placeholder_type = cppast.PrimitiveType(cppast.BuiltinType.INT)
        params.append(cppast.ParmVarDecl(placeholder_type, placeholder_name))

    assignments: List[cppast.AssignOperator] = list(generate_assignments(fields))

    # Subviews are skipped for the assignments as well
    non_subviews = {
        view_name: view_type
        for view_name, view_type in views.items()
        if view_type
    }
    assignments.extend(generate_assignments(non_subviews))

    return cppast.ConstructorDecl(
        "", name, params, cppast.CompoundStmt(assignments))
def visit_FunctionDef(self, node: ast.FunctionDef) -> cppast.MethodDecl:
    """
    Translate a Kokkos function into a const ``KOKKOS_FUNCTION`` method.

    Errors are reported through ``self.error`` when
    ``self.is_valid_kokkos_function`` rejects the function or when its
    return annotation cannot be resolved to a type.

    :param node: the ``ast.FunctionDef`` node to translate
    :returns: the translated method declaration, marked const
    """

    if not self.is_valid_kokkos_function(node):
        self.error(node, "Invalid Kokkos function")

    return_type: cppast.ClassType
    if self.is_void_function(node):
        return_type = cppast.ClassType("void")
    else:
        return_type = visitors_util.get_type(node.returns, self.pk_import)
        if return_type is None:
            self.error(node, "Return type is not supported for translation")

    params: List[cppast.ParmVarDecl] = self.visit(node.args)
    translated_body = cppast.CompoundStmt(
        [self.visit(stmt) for stmt in node.body])

    kokkos_method = cppast.MethodDecl(
        "KOKKOS_FUNCTION", return_type, node.name, params, translated_body)
    kokkos_method.is_const = True

    return kokkos_method
def visit_Assign(self, node: ast.Assign) -> cppast.Stmt:
    """
    Translate an assignment whose right-hand side is a call that needs
    special handling; every other assignment is delegated to the parent
    visitor.

    Only the first assignment target (``node.targets[0]``) is considered.
    The call name selects one of the following translations:

    - ``Timer``: declaration of a ``Kokkos::Timer`` variable named ``timer``
    - ``seconds``: the translated call followed by copying ``pk_acc`` into
      the target and into element 0 of ``timer_result_<target>``; the target
      name is recorded in ``self.timer_result_queue``
    - ``BinSort``/``BinOp1D``/``BinOp3D``: an ``auto`` declaration
      initialised with the constructed sorter / bin op
    - ``get_bin_count``/``get_bin_offsets``/``get_permute_vector``: the
      member call on the sorter plus resize, deep copy and functor
      assignment of the resulting view (or a plain assignment in one
      ``get_permute_vector`` case)
    - any other call assigned to a ``self`` attribute: treated as the result
      of a reduction; the target name is recorded in
      ``self.reduction_result_queue``

    :param node: the ``ast.Assign`` node to translate
    :returns: the translated statement
    """

    target = node.targets[0]

    if isinstance(node.value, ast.Call):
        name: str = visitors_util.get_node_name(node.value.func)

        # Create Timer object
        if name == "Timer":
            decltype = cppast.ClassType("Kokkos::Timer")
            declname = cppast.DeclRefExpr("timer")
            return cppast.DeclStmt(cppast.VarDecl(decltype, declname, None))

        # Call Timer.seconds()
        if name == "seconds":
            target_name: str = visitors_util.get_node_name(target)
            # Record each timer result target only once
            if target_name not in self.timer_result_queue:
                self.timer_result_queue.append(target_name)

            call = cppast.CallStmt(self.visit(node.value))
            target_ref = cppast.DeclRefExpr(target_name)
            target_view_ref = cppast.DeclRefExpr(
                f"timer_result_{target_name}")
            subscript = cppast.ArraySubscriptExpr(
                target_view_ref, [cppast.IntegerLiteral(0)])
            assign_op = cppast.BinaryOperatorKind.Assign

            # Holds the result of the timer call temporarily
            temp_ref = cppast.DeclRefExpr("pk_acc")
            target_assign = cppast.AssignOperator([target_ref], temp_ref, assign_op)
            view_assign = cppast.AssignOperator([subscript], target_ref, assign_op)

            return cppast.CompoundStmt([call, target_assign, view_assign])

        if name in ("BinSort", "BinOp1D", "BinOp3D"):
            args: List = node.value.args
            # if not isinstance(args[0], ast.Attribute):
            #     self.error(node.value, "First argument has to be a view")

            # The first argument names the view the sorter / bin op works on
            view = cppast.DeclRefExpr(visitors_util.get_node_name(args[0]))

            if view not in self.views:
                self.error(args[0], "Undefined view")

            view_type: cppast.ClassType = self.views[view]
            # A view registered with type None is a subview; use the type
            # of the parent view recorded in self.subviews instead
            is_subview: bool = view_type is None
            if is_subview:
                parent_view = cppast.DeclRefExpr(
                    self.subviews[view.declname])
                view_type = self.views[parent_view]

            view_type_str: str = visitors_util.cpp_view_type(view_type)

            if name != "BinSort":
                # BinOp1D / BinOp3D differ only in their dimension
                dimension: int = 1 if name == "BinOp1D" else 3
                cpp_type = cppast.DeclRefExpr(
                    BinOp.get_type(dimension, view_type_str))

                # Do not translate the first argument (view)
                constructor = cppast.CallExpr(
                    cpp_type, [self.visit(a) for a in args[1:]])
            else:
                # The bin op type is derived from the second argument via decltype
                bin_op_type: str = f"decltype({visitors_util.get_node_name(args[1])})"
                cpp_type = cppast.DeclRefExpr(
                    BinSort.get_type(view_type_str, bin_op_type))

                binsort_args: List[cppast.DeclRefExpr] = [
                    self.visit(a) for a in args
                ]
                constructor = cppast.CallExpr(cpp_type, binsort_args)

            cpp_target: cppast.DeclRefExpr = self.visit(target)
            auto_type = cppast.ClassType("auto")

            return cppast.DeclStmt(
                cppast.VarDecl(auto_type, cpp_target, constructor))

        if name in ("get_bin_count", "get_bin_offsets", "get_permute_vector"):
            # NOTE(review): target.value.id assumes the attribute's base is a
            # plain name; a target such as self.a.b would raise
            # AttributeError here - confirm whether that can occur
            if not isinstance(target, ast.Attribute) or target.value.id != "self":
                self.error(
                    node, "Views defined in pk.main must be an instance variable"
                )

            # NOTE: cpp_target is a str in this branch (it is a DeclRefExpr
            # in the BinSort branch above)
            cpp_target: str = visitors_util.get_node_name(target)
            cpp_device_target = f"pk_d_{cpp_target}"
            cpp_target_ref = cppast.DeclRefExpr(cpp_device_target)
            sorter: cppast.DeclRefExpr = self.visit(node.value.func.value)

            initial_target_ref = cppast.DeclRefExpr(
                f"_pk_{cpp_target_ref.declname}")

            # <sorter>.<name>() with no arguments
            function = cppast.MemberCallExpr(sorter,
                                             cppast.DeclRefExpr(name), [])

            # Add to the dict of declarations made in pk.main
            if name == "get_permute_vector":
                # This occurs when a workload is executed multiple times
                # Initially the view has not been defined in the workload,
                # so it needs to be classified as a pkmain_view.
                # NOTE(review): membership is tested with the plain name but
                # the lookup uses the pk_d_-prefixed reference - confirm that
                # both resolve to the same self.views entry
                if cpp_target in self.views:
                    self.views[cpp_target_ref].add_template_param(
                        cppast.PrimitiveType(cppast.BuiltinType.INT))
                    return cppast.AssignOperator(
                        [cpp_target_ref], function,
                        cppast.BinaryOperatorKind.Assign)
                    # return f"{cpp_target} = {sorter}.{name}();"

                self.pkmain_views[cpp_target_ref] = cppast.ClassType(
                    "View1D")
            else:
                self.pkmain_views[cpp_target_ref] = None

            # Store the call result in a temporary _pk_ variable first
            auto_type = cppast.ClassType("auto")
            decl = cppast.DeclStmt(
                cppast.VarDecl(auto_type, initial_target_ref, function))

            # resize the workload's vector to match the generated vector
            resize_call = cppast.CallStmt(
                cppast.CallExpr(cppast.DeclRefExpr("Kokkos::resize"), [
                    cpp_target_ref,
                    cppast.MemberCallExpr(initial_target_ref,
                                          cppast.DeclRefExpr("extent"),
                                          [cppast.IntegerLiteral(0)])
                ]))

            copy_call = cppast.CallStmt(
                cppast.CallExpr(cppast.DeclRefExpr("Kokkos::deep_copy"),
                                [cpp_target_ref, initial_target_ref]))

            # Assign to the functor after resizing
            functor = cppast.DeclRefExpr("pk_f")
            functor_access = cppast.MemberExpr(functor, cpp_target)
            functor_assign = cppast.AssignOperator(
                [functor_access], cpp_target_ref,
                cppast.BinaryOperatorKind.Assign)

            return cppast.CompoundStmt(
                [decl, resize_call, copy_call, functor_assign])

        # Assign result of parallel_reduce
        # NOTE(review): any target type other than Name/Subscript reaches
        # target.value.id; this presumably expects an ast.Attribute whose
        # base is a plain name - confirm for tuple targets and nested
        # attributes
        if type(target) not in {ast.Name, ast.Subscript} and target.value.id == "self":
            target_name: str = visitors_util.get_node_name(target)
            # Record each reduction result target only once
            if target_name not in self.reduction_result_queue:
                self.reduction_result_queue.append(target_name)

            call = cppast.CallStmt(self.visit(node.value))
            target_ref = cppast.DeclRefExpr(target_name)
            target_view_ref = cppast.DeclRefExpr(
                f"reduction_result_{target_name}")
            subscript = cppast.ArraySubscriptExpr(target_view_ref,
                                                  [cppast.IntegerLiteral(0)])
            assign_op = cppast.BinaryOperatorKind.Assign

            # Holds the result of the reduction temporarily
            temp_ref = cppast.DeclRefExpr("pk_acc")
            target_assign = cppast.AssignOperator([target_ref], temp_ref, assign_op)
            view_assign = cppast.AssignOperator([subscript], target_ref, assign_op)

            return cppast.CompoundStmt([call, target_assign, view_assign])

    return super().visit_Assign(node)
def visit_Call(self, node: ast.Call) -> Union[cppast.Expr, cppast.Stmt]:
    """
    Translate a Python call into the matching cppast node.

    The following call names are handled here; anything else is delegated
    to the parent visitor:

    - ``TeamPolicy``: a ``Kokkos::TeamPolicy`` construction; a third
      argument of ``1`` is appended when only two arguments are given
    - ``RangePolicy``/``MDRangePolicy``: a policy construction; for
      ``MDRangePolicy`` a ``Kokkos::Rank<...>`` template parameter is added
    - ``seconds``: ``Kokkos::fence()`` followed by storing
      ``timer.seconds()`` in ``pk_acc``
    - ``parallel_for``: a ``Kokkos::parallel_for`` call, either with an
      inline ``KOKKOS_LAMBDA`` or with the functor ``pk_f``
    - ``parallel_reduce``/``parallel_scan``: ``pk_acc`` is assigned its
      initial value and the Kokkos call is chained with the comma operator
    - names found in ``dir(BinSort)``: a member call on the sorter object

    Arguments that name a view with a non-None entry in ``self.views`` are
    replaced by their ``pk_d_`` prefixed name before any of the above.

    :param node: the ``ast.Call`` node to translate
    :returns: the translated expression or statement
    """

    name: str = visitors_util.get_node_name(node.func)
    args: List[cppast.Expr] = [self.visit(a) for a in node.args]

    # Add pk_d_ before each view name to match mirror view names
    # (views whose entry in self.views is None are left as they are)
    s = cppast.Serializer()
    for i, arg in enumerate(args):
        if arg in self.views and self.views[arg] is not None:
            view: str = s.serialize(arg)
            args[i] = cppast.DeclRefExpr(f"pk_d_{view}")

    # Nested parallelism
    if name == "TeamPolicy":
        function = cppast.DeclRefExpr(f"Kokkos::{name}")
        if len(args) == 2:
            args.append(cppast.IntegerLiteral(1))

        policy = cppast.ConstructExpr(function, args)
        return policy

    elif name in ["RangePolicy", "MDRangePolicy"]:
        # The rank is the number of elements of the first (begin) argument
        rank = len(node.args[0].elts)
        # Errors are reported on the call node itself: ast.Call has no
        # ``value`` attribute, so ``node.value`` would raise AttributeError
        # instead of producing the intended message.
        if rank == 0:
            self.error(node, "RangePolicy dimension must be greater than 0")
        if rank != len(node.args[1].elts):
            self.error(node, "RangePolicy dimension mismatch")

        iter_outer = Iterate.Default
        iter_inner = Iterate.Default
        for keyword in node.keywords:
            if keyword.arg == "rank":
                # rank=<call>(explicit_rank, <outer iterate>, <inner iterate>)
                explicit_rank = keyword.value.args[0].value
                if explicit_rank != rank:
                    self.error(node, "RangePolicy dimension mismatch")

                iter_outer = getattr(Iterate, keyword.value.args[1].attr)
                iter_inner = getattr(Iterate, keyword.value.args[2].attr)

        policy = cppast.ConstructExpr(
            cppast.DeclRefExpr(f"Kokkos::{name}"), args)
        if name == "MDRangePolicy":
            policy.add_template_param(
                cppast.DeclRefExpr(
                    f"Kokkos::Rank<{rank},{iter_outer.value},{iter_inner.value}>"
                ))

        return policy

    if name == "seconds":
        # Fence first, then read the timer into the temporary pk_acc
        fence = cppast.CallStmt(
            cppast.CallExpr(cppast.DeclRefExpr("Kokkos::fence"), []))
        temp_decl = cppast.DeclRefExpr("pk_acc")
        seconds = cppast.MemberCallExpr(cppast.DeclRefExpr("timer"),
                                        cppast.DeclRefExpr("seconds"), [])
        result = cppast.AssignOperator([temp_decl], seconds,
                                       cppast.BinaryOperatorKind.Assign)

        return cppast.CompoundStmt([fence, result])

    function = cppast.DeclRefExpr(f"Kokkos::{name}")
    call_args: List[cppast.Expr]

    if name == "parallel_for":
        arg_start: int = 0

        # Accounts for the optional kernel name
        kernel_name: Optional[cppast.StringLiteral] = None
        if isinstance(args[0], cppast.StringLiteral):
            kernel_name = args[0]
            arg_start = 1

        policy: cppast.ConstructExpr = args[arg_start]

        # Replace the number of threads with a RangePolicy
        if not isinstance(policy, cppast.ConstructExpr):
            begin = cppast.IntegerLiteral(0)
            end = args[arg_start]
            policy = cppast.ConstructExpr(
                cppast.DeclRefExpr("Kokkos::RangePolicy"), [begin, end])

        space = cppast.DeclRefExpr(Keywords.DefaultExecSpace.value)
        policy.add_template_param(space)

        if isinstance(node.args[arg_start + 1], ast.Lambda):
            decl: str = "KOKKOS_LAMBDA ("
            tid = cppast.DeclRefExpr(node.args[arg_start + 1].args.args[0].arg)

            # if target exists, the lambda result is stored in target[tid]
            if len(args) == arg_start + 3:
                target = cppast.ArraySubscriptExpr(args[arg_start + 2], [tid])
                args[arg_start + 1] = cppast.AssignOperator(
                    [target], args[arg_start + 1],
                    cppast.BinaryOperatorKind.Assign)

            serializer = cppast.Serializer()
            decl += f"int {tid.declname}) {{"
            decl += serializer.serialize(args[arg_start + 1]) + ";}\n"

            call_args = [policy, decl]
        else:
            # Select the functor's operator() through the workunit tag type
            work_unit: str = args[arg_start + 1].declname
            policy.add_template_param(
                cppast.DeclRefExpr(f"{self.functor}::{work_unit}"))
            call_args = [policy, cppast.DeclRefExpr("pk_f")]

        if kernel_name is not None:
            call_args.insert(0, kernel_name)

        return cppast.CallExpr(function, call_args)

    if name in ("parallel_reduce", "parallel_scan"):
        arg_start: int = 0

        # Accounts for the optional kernel name
        kernel_name: Optional[cppast.StringLiteral] = None
        if isinstance(args[0], cppast.StringLiteral):
            kernel_name = args[0]
            arg_start = 1

        initial_value: cppast.Expr
        if len(args) == arg_start + 3:
            initial_value = args[arg_start + 2]
        else:
            initial_value = cppast.IntegerLiteral(0)

        acc_decl = cppast.DeclRefExpr("pk_acc")
        init_var = cppast.BinaryOperator(acc_decl, initial_value,
                                         cppast.BinaryOperatorKind.Assign)

        policy: cppast.ConstructExpr = args[arg_start]

        # Replace the number of threads with a RangePolicy
        if not isinstance(policy, cppast.ConstructExpr):
            begin = cppast.IntegerLiteral(0)
            end = args[arg_start]
            policy = cppast.ConstructExpr(
                cppast.DeclRefExpr("Kokkos::RangePolicy"), [begin, end])

        space = cppast.DeclRefExpr(Keywords.DefaultExecSpace.value)
        policy.add_template_param(space)

        if isinstance(node.args[arg_start + 1], ast.Lambda):
            decl: str = "KOKKOS_LAMBDA ("
            tid = cppast.DeclRefExpr(node.args[arg_start + 1].args.args[0].arg)
            acc = cppast.DeclRefExpr(node.args[arg_start + 1].args.args[1].arg)

            # assign to accumulator
            args[arg_start + 1] = cppast.AssignOperator(
                [acc], args[arg_start + 1],
                cppast.BinaryOperatorKind.Assign)

            serializer = cppast.Serializer()
            decl += f"int {tid.declname}, double& {acc.declname}) {{"
            decl += serializer.serialize(args[arg_start + 1]) + ";}\n"

            call_args = [policy, decl, acc_decl]
        else:
            # Select the functor's operator() through the workunit tag type
            work_unit: str = args[arg_start + 1].declname
            policy.add_template_param(
                cppast.DeclRefExpr(f"{self.functor}::{work_unit}"))
            call_args = [policy, cppast.DeclRefExpr("pk_f"), acc_decl]

        if kernel_name is not None:
            call_args.insert(0, kernel_name)

        call = cppast.CallExpr(function, call_args)

        # Both the lambda and the functor form initialise pk_acc before the
        # Kokkos call, so the initial value is never silently dropped.
        return cppast.BinaryOperator(init_var, call,
                                     cppast.BinaryOperatorKind.Comma)

    if name in dir(BinSort):
        # Method call on a sorter object: <sorter>.<name>(args)
        sorter: str = visitors_util.get_node_name(node.func.value)
        sorter_ref = cppast.DeclRefExpr(sorter)
        function = cppast.DeclRefExpr(name)

        return cppast.MemberCallExpr(sorter_ref, function, args)

    return super().visit_Call(node)