Ejemplo n.º 1
0
    def visit_If(self, node: ast.If) -> cppast.IfStmt:
        """
        Translate a Python ``if`` statement into a C++ ``if`` statement.

        :param node: the Python if node
        :returns: the cppast IfStmt; its else-body is None when the Python
            ``if`` has neither ``else`` nor ``elif``
        """
        test_expr: cppast.Expr = self.visit(node.test)

        then_block = cppast.CompoundStmt(
            [self.visit(stmt) for stmt in node.body])

        else_block = None
        if node.orelse:
            # an ``elif`` is represented by Python's ast as a nested ast.If
            # inside orelse, so it is translated recursively here
            else_block = cppast.CompoundStmt(
                [self.visit(stmt) for stmt in node.orelse])

        return cppast.IfStmt(test_expr, then_block, else_block)
Ejemplo n.º 2
0
    def visit_FunctionDef(
        self, node: ast.FunctionDef
    ) -> Union[cppast.ConstructorDecl, cppast.MethodDecl]:
        """
        Translate a class method into a C++ constructor or member function.

        ``__init__`` becomes a constructor named after the enclosing class;
        every other method becomes a ``KOKKOS_FUNCTION`` member function.
        Methods without a leading ``self`` parameter are rejected.

        :param node: the Python function definition
        :returns: a ConstructorDecl for ``__init__``, otherwise a MethodDecl
        """
        name: str = node.name
        # Track constructor-ness explicitly instead of inferring it from a
        # None return type: get_type may also return None for an unsupported
        # annotation, which previously turned the method into a constructor
        is_constructor: bool = name == "__init__"
        return_type: Optional[cppast.ClassType] = None

        if is_constructor:
            # presumably node.parent is the enclosing ast.ClassDef (attached
            # by an earlier pass not visible here); a C++ constructor takes
            # the class name
            name = node.parent.name
        elif self.is_void_function(node):
            return_type = cppast.ClassType("void")
        else:
            return_type = visitors_util.get_type(node.returns, self.pk_import)
            if return_type is None:
                self.error(node, "Return type is not supported for translation")

        if len(node.args.args) == 0 or node.args.args[0].arg != "self":
            self.error(node, "Static functions are not supported")

        params: List[cppast.ParmVarDecl] = self.visit(node.args)
        body = cppast.CompoundStmt([self.visit(b) for b in node.body])
        attributes: str = "KOKKOS_FUNCTION"

        if is_constructor:
            return cppast.ConstructorDecl(attributes, name, params, body)

        return cppast.MethodDecl(attributes, return_type, name, params, body)
Ejemplo n.º 3
0
    def visit_While(self, node: ast.While) -> cppast.WhileStmt:
        """
        Translate a Python ``while`` loop into a C++ ``while`` loop.

        A ``while ... else`` clause has no C++ equivalent and is reported as
        an error.

        :param node: the Python while node
        :returns: the cppast WhileStmt
        """
        if node.orelse:
            # orelse is a list of statements, not an AST node; report the
            # error at the first statement of the else clause
            self.error(node.orelse[0],
                       "Else clause not supported for translation")

        condition: cppast.Expr = self.visit(node.test)
        body = cppast.CompoundStmt([self.visit(b) for b in node.body])

        return cppast.WhileStmt(condition, body)
Ejemplo n.º 4
0
    def visit_For(self, node: ast.For) -> cppast.ForStmt:
        """
        Translate a ``for ... in range(...)`` loop into a C++ counted loop.

        Only a single loop variable and the builtin ``range`` with one to
        three arguments are supported. The generated loop has the form
        ``for (int i = start; i < end; i += step)``; the comparison becomes
        ``>`` when the step argument is written with a leading unary minus.

        :param node: the Python for node
        :returns: the cppast ForStmt
        """
        if not isinstance(node.target, ast.Name):
            self.error(node.target, "Must use single loop variable")

        if node.orelse:
            # orelse is a list of statements, not an AST node; report the
            # error at the first statement of the else clause
            self.error(node.orelse[0],
                       "Else clause not supported for translation")

        # Check the callee's node type before reading ``id``: a call such as
        # ``obj.range(n)`` has an ast.Attribute func, which has no ``id``
        if (
            not isinstance(node.iter, ast.Call)
            or not isinstance(node.iter.func, ast.Name)
            or node.iter.func.id != "range"
        ):
            # TODO: support other iterators?
            self.error(
                node.iter, "Only range() iterator is supported for translation")

        index: cppast.DeclRefExpr = self.visit(node.target)
        start: cppast.Expr
        end: cppast.Expr
        step: cppast.Expr = cppast.IntegerLiteral(1)
        op = cppast.BinaryOperatorKind.LT

        args = node.iter.args
        # range() accepts (stop), (start, stop) or (start, stop, step)
        if not 1 <= len(args) <= 3:
            self.error(node.iter, "range() expects between 1 and 3 arguments")

        if len(args) == 1:
            start = cppast.IntegerLiteral(0)
            end = self.visit(args[0])

        else:
            start = self.visit(args[0])
            end = self.visit(args[1])

            if len(args) == 3:
                step = self.visit(args[2])

                # Negative step sizes are only handled correctly if they're
                # written with a preceding minus sign (e.g. range(n, 0, -1));
                # a negative step held in a variable still produces "<"
                if (
                    isinstance(args[2], ast.UnaryOp)
                    and isinstance(args[2].op, ast.USub)
                ):
                    op = cppast.BinaryOperatorKind.GT

        body = cppast.CompoundStmt([self.visit(b) for b in node.body])

        init = cppast.DeclStmt(cppast.VarDecl(
            cppast.PrimitiveType(cppast.BuiltinType.INT), index, start))
        condition = cppast.BinaryOperator(index, end, op)
        increment = cppast.BinaryOperator(
            index, step, cppast.BinaryOperatorKind.AddAssign)

        return cppast.ForStmt(init, condition, increment, body)
Ejemplo n.º 5
0
    def visit_FunctionDef(
            self, node: ast.FunctionDef
    ) -> Union[str, Tuple[str, cppast.MethodDecl]]:
        """
        Translate a workunit definition.

        A nested workunit becomes a by-reference (``[&]``) C++ lambda. The
        lambda is recorded in ``self.nested_work_units`` under the function
        name, and an empty string is returned.

        Any other workunit becomes a const ``KOKKOS_FUNCTION operator()``.
        Its first parameter is an unnamed ``const <name>&`` tag. The method is
        returned together with the operation type reported by
        ``get_operation_type``.

        :param node: the Python function definition
        :returns: "" for a nested workunit, otherwise (operation, method)
        """
        if self.is_nested_call(node):
            lambda_params: List[cppast.ParmVarDecl] = list(
                self.visit(node.args))
            lambda_body = cppast.CompoundStmt(
                [self.visit(stmt) for stmt in node.body])

            self.nested_work_units[node.name] = cppast.LambdaExpr(
                "[&]", lambda_params, lambda_body)
            return ""

        operation: Optional[str] = self.get_operation_type(node)
        if operation is None:
            self.error(node.args, "Incorrect types in workunit definition")

        # Unnamed "const <workunit>&" parameter; presumably this is the
        # Kokkos work tag that selects this operator() overload
        tag_type = cppast.ClassType(f"const {node.name}")
        tag_type.is_reference = True
        tag = cppast.ParmVarDecl(tag_type, cppast.DeclRefExpr(""))
        method_params: List[cppast.ParmVarDecl] = [tag, *self.visit(node.args)]

        method_body = cppast.CompoundStmt(
            [self.visit(stmt) for stmt in node.body])

        operator = cppast.MethodDecl("KOKKOS_FUNCTION",
                                     cppast.ClassType("void"), "operator()",
                                     method_params, method_body)
        operator.is_const = True

        return operation, operator
Ejemplo n.º 6
0
    def visit_AnnAssign(self, node: ast.AnnAssign) -> cppast.Stmt:
        """
        Translate an annotated assignment inside a workunit.

        Two kinds of call on the right-hand side are handled here:

        * a TeamMember method call becomes a typed declaration initialised
          with the translated call;
        * a nested ``parallel_reduce``/``parallel_scan`` becomes a
          declaration of the result variable, followed by the Kokkos call
          that receives that variable as its last argument.

        Everything else is delegated to the base visitor.

        :param node: the Python annotated-assignment node
        :returns: the translated statement
        """
        if not isinstance(node.value, ast.Call):
            return super().visit_AnnAssign(node)

        decltype: cppast.Type = visitors_util.get_type(
            node.annotation, self.pk_import)
        if decltype is None:
            self.error(node, "Type not supported")
        declname: cppast.DeclRefExpr = self.visit(node.target)
        function_name: str = visitors_util.get_node_name(node.value.func)

        # Call to a TeamMember method
        if function_name in dir(TeamMember):
            return cppast.DeclStmt(
                cppast.VarDecl(decltype, declname, self.visit(node.value)))

        if function_name not in ("parallel_reduce", "parallel_scan"):
            return super().visit_AnnAssign(node)

        # Nested parallelism
        translated_args: List[cppast.Expr] = [
            self.visit(arg) for arg in node.value.args
        ]

        # An optional third argument supplies the initial value
        init_expr: cppast.Expr = (translated_args[2]
                                  if len(translated_args) == 3
                                  else cppast.IntegerLiteral(0))
        result_decl = cppast.DeclStmt(
            cppast.VarDecl(decltype, declname, init_expr))

        work_unit: str = translated_args[1].declname
        kokkos_function = cppast.DeclRefExpr(f"Kokkos::{function_name}")

        # A workunit defined inline is passed as its recorded lambda;
        # otherwise it is referenced through the "pk_id_<name>" identifier
        work_arg = (self.nested_work_units[work_unit]
                    if work_unit in self.nested_work_units
                    else f"pk_id_{work_unit}")
        call: cppast.CallExpr = cppast.CallExpr(
            kokkos_function, [translated_args[0], work_arg, declname])

        return cppast.CompoundStmt([result_decl, cppast.CallStmt(call)])
Ejemplo n.º 7
0
def generate_constructor(
    name: str,
    fields: Dict[cppast.DeclRefExpr, cppast.PrimitiveType],
    views: Dict[cppast.DeclRefExpr, cppast.ClassType]
) -> cppast.ConstructorDecl:
    """
    Generate the functor constructor

    The constructor takes one parameter per field, followed by one parameter
    per view. Subviews are stored in ``views`` with a None type and are
    skipped entirely. If no parameters remain, a dummy ``int pk_field``
    parameter is added. The body is built by ``generate_assignments`` from
    the same fields and non-subview views.

    :param name: the functor class name
    :param fields: a dict mapping from field name to type
    :param views: a dict mapping from view name to type (None for subviews)
    :returns: the cppast representation of the constructor
    """

    # Subviews (None type) are neither parameters nor assigned members; use
    # one consistent "is not None" filter for both purposes
    real_views: Dict[cppast.DeclRefExpr, cppast.ClassType] = {
        n: t for n, t in views.items() if t is not None
    }

    params: List[cppast.ParmVarDecl] = [
        cppast.ParmVarDecl(t, n) for n, t in fields.items()
    ]
    params.extend(
        cppast.ParmVarDecl(get_view_type(t), n) for n, t in real_views.items()
    )

    # Kokkos fails to compile a functor if there are no parameters in its
    # constructor, so add a dummy parameter that is never assigned
    if not params:
        dummy_type = cppast.PrimitiveType(cppast.BuiltinType.INT)
        params.append(
            cppast.ParmVarDecl(dummy_type, cppast.DeclRefExpr("pk_field")))

    assignments: List[cppast.AssignOperator] = []
    assignments.extend(generate_assignments(fields))
    assignments.extend(generate_assignments(real_views))

    body = cppast.CompoundStmt(assignments)

    return cppast.ConstructorDecl("", name, params, body)
Ejemplo n.º 8
0
    def visit_FunctionDef(self, node: ast.FunctionDef) -> cppast.MethodDecl:
        """
        Translate a Kokkos function into a const ``KOKKOS_FUNCTION`` member
        function.

        Invalid Kokkos functions and unsupported return types are reported
        through ``self.error``.

        :param node: the Python function definition
        :returns: the cppast MethodDecl
        """
        if not self.is_valid_kokkos_function(node):
            self.error(node, "Invalid Kokkos function")

        return_type: cppast.ClassType = (
            cppast.ClassType("void") if self.is_void_function(node)
            else visitors_util.get_type(node.returns, self.pk_import)
        )
        if return_type is None:
            self.error(node, "Return type is not supported for translation")

        # Parameters are translated before the body, as arguments are
        # evaluated left to right
        method = cppast.MethodDecl(
            "KOKKOS_FUNCTION",
            return_type,
            node.name,
            self.visit(node.args),
            cppast.CompoundStmt([self.visit(stmt) for stmt in node.body]),
        )
        method.is_const = True

        return method
Ejemplo n.º 9
0
    def visit_Assign(self, node: ast.Assign) -> cppast.Stmt:
        """
        Translate an assignment in pk.main.

        Calls get special handling depending on the callee name:

        * ``Timer``: declares ``Kokkos::Timer timer``;
        * ``seconds``: stores the timer reading in the target and in element
          0 of the ``timer_result_<target>`` view;
        * ``BinSort``/``BinOp1D``/``BinOp3D``: declares an ``auto`` object
          of the matching C++ sorter / bin-op type;
        * ``get_bin_count``/``get_bin_offsets``/``get_permute_vector``:
          resizes and deep-copies the sorter's result into the ``pk_d_``
          view, then assigns that view to the functor member.

        Other assignments whose target is ``self.<name>`` are treated as
        reduction results. The result is written to the target and to
        element 0 of the ``reduction_result_<target>`` view. Everything else
        is delegated to the base visitor.

        :param node: the Python assign node
        :returns: the translated statement
        """

        def is_self_attribute(expr: ast.expr) -> bool:
            # True only for "self.<name>". Safe for every target kind: an
            # ast.Tuple has no "value", and an ast.Attribute's value need
            # not be an ast.Name
            return (isinstance(expr, ast.Attribute)
                    and isinstance(expr.value, ast.Name)
                    and expr.value.id == "self")

        target = node.targets[0]

        if isinstance(node.value, ast.Call):
            name: str = visitors_util.get_node_name(node.value.func)

            # Create Timer object
            if name == "Timer":
                decltype = cppast.ClassType("Kokkos::Timer")
                declname = cppast.DeclRefExpr("timer")
                return cppast.DeclStmt(cppast.VarDecl(decltype, declname,
                                                      None))

            # Call Timer.seconds()
            if name == "seconds":
                target_name: str = visitors_util.get_node_name(target)
                if target_name not in self.timer_result_queue:
                    self.timer_result_queue.append(target_name)

                call = cppast.CallStmt(self.visit(node.value))
                target_ref = cppast.DeclRefExpr(target_name)
                target_view_ref = cppast.DeclRefExpr(
                    f"timer_result_{target_name}")
                subscript = cppast.ArraySubscriptExpr(
                    target_view_ref, [cppast.IntegerLiteral(0)])
                assign_op = cppast.BinaryOperatorKind.Assign

                # pk_acc is expected to hold the timer reading left by the
                # translated seconds() call
                temp_ref = cppast.DeclRefExpr("pk_acc")
                target_assign = cppast.AssignOperator([target_ref], temp_ref,
                                                      assign_op)
                view_assign = cppast.AssignOperator([subscript], target_ref,
                                                    assign_op)

                return cppast.CompoundStmt([call, target_assign, view_assign])

            if name in ("BinSort", "BinOp1D", "BinOp3D"):
                args: List = node.value.args

                view = cppast.DeclRefExpr(visitors_util.get_node_name(args[0]))
                if view not in self.views:
                    self.error(args[0], "Undefined view")

                # Subviews are stored with a None type; use the parent's type
                view_type: cppast.ClassType = self.views[view]
                is_subview: bool = view_type is None
                if is_subview:
                    parent_view = cppast.DeclRefExpr(
                        self.subviews[view.declname])
                    view_type = self.views[parent_view]

                view_type_str: str = visitors_util.cpp_view_type(view_type)

                if name != "BinSort":
                    dimension: int = 1 if name == "BinOp1D" else 3
                    cpp_type = cppast.DeclRefExpr(
                        BinOp.get_type(dimension, view_type_str))

                    # Do not translate the first argument (view)
                    constructor = cppast.CallExpr(
                        cpp_type, [self.visit(a) for a in args[1:]])

                else:
                    bin_op_type: str = f"decltype({visitors_util.get_node_name(args[1])})"
                    cpp_type = cppast.DeclRefExpr(
                        BinSort.get_type(view_type_str, bin_op_type))

                    binsort_args: List[cppast.DeclRefExpr] = [
                        self.visit(a) for a in args
                    ]
                    constructor = cppast.CallExpr(cpp_type, binsort_args)

                sorter_target: cppast.DeclRefExpr = self.visit(target)
                auto_type = cppast.ClassType("auto")

                return cppast.DeclStmt(
                    cppast.VarDecl(auto_type, sorter_target, constructor))

            if name in ("get_bin_count", "get_bin_offsets",
                        "get_permute_vector"):
                if not is_self_attribute(target):
                    self.error(
                        node,
                        "Views defined in pk.main must be an instance variable"
                    )

                cpp_target: str = visitors_util.get_node_name(target)
                cpp_device_target = f"pk_d_{cpp_target}"
                cpp_target_ref = cppast.DeclRefExpr(cpp_device_target)
                sorter: cppast.DeclRefExpr = self.visit(node.value.func.value)

                initial_target_ref = cppast.DeclRefExpr(
                    f"_pk_{cpp_target_ref.declname}")

                function = cppast.MemberCallExpr(sorter,
                                                 cppast.DeclRefExpr(name), [])

                # Add to the dict of declarations made in pk.main
                if name == "get_permute_vector":
                    # This occurs when a workload is executed multiple times.
                    # Initially the view has not been defined in the workload,
                    # so it needs to be classified as a pkmain_view.
                    # NOTE(review): the membership test uses the bare str
                    # cpp_target, while the lookup below uses the
                    # "pk_d_"-prefixed DeclRefExpr; confirm both refer to the
                    # same self.views key
                    if cpp_target in self.views:
                        self.views[cpp_target_ref].add_template_param(
                            cppast.PrimitiveType(cppast.BuiltinType.INT))

                        return cppast.AssignOperator(
                            [cpp_target_ref], function,
                            cppast.BinaryOperatorKind.Assign)

                    self.pkmain_views[cpp_target_ref] = cppast.ClassType(
                        "View1D")
                else:
                    self.pkmain_views[cpp_target_ref] = None

                auto_type = cppast.ClassType("auto")
                decl = cppast.DeclStmt(
                    cppast.VarDecl(auto_type, initial_target_ref, function))

                # resize the workload's vector to match the generated vector
                resize_call = cppast.CallStmt(
                    cppast.CallExpr(cppast.DeclRefExpr("Kokkos::resize"), [
                        cpp_target_ref,
                        cppast.MemberCallExpr(initial_target_ref,
                                              cppast.DeclRefExpr("extent"),
                                              [cppast.IntegerLiteral(0)])
                    ]))

                copy_call = cppast.CallStmt(
                    cppast.CallExpr(cppast.DeclRefExpr("Kokkos::deep_copy"),
                                    [cpp_target_ref, initial_target_ref]))

                # Assign to the functor after resizing
                functor = cppast.DeclRefExpr("pk_f")
                functor_access = cppast.MemberExpr(functor, cpp_target)
                functor_assign = cppast.AssignOperator(
                    [functor_access], cpp_target_ref,
                    cppast.BinaryOperatorKind.Assign)

                return cppast.CompoundStmt(
                    [decl, resize_call, copy_call, functor_assign])

        # Assign result of parallel_reduce to an instance attribute. The
        # previous check read target.value.id for any non-Name/Subscript
        # target and crashed on tuple targets and nested attributes
        if is_self_attribute(target):
            target_name: str = visitors_util.get_node_name(target)
            if target_name not in self.reduction_result_queue:
                self.reduction_result_queue.append(target_name)

            call = cppast.CallStmt(self.visit(node.value))
            target_ref = cppast.DeclRefExpr(target_name)
            target_view_ref = cppast.DeclRefExpr(
                f"reduction_result_{target_name}")
            subscript = cppast.ArraySubscriptExpr(target_view_ref,
                                                  [cppast.IntegerLiteral(0)])
            assign_op = cppast.BinaryOperatorKind.Assign

            # Holds the result of the reduction temporarily
            temp_ref = cppast.DeclRefExpr("pk_acc")
            target_assign = cppast.AssignOperator([target_ref], temp_ref,
                                                  assign_op)
            view_assign = cppast.AssignOperator([subscript], target_ref,
                                                assign_op)

            return cppast.CompoundStmt([call, target_assign, view_assign])

        return super().visit_Assign(node)
Ejemplo n.º 10
0
    def visit_Call(self, node: ast.Call) -> Union[cppast.Expr, cppast.Stmt]:
        """
        Translate a call made in pk.main.

        Handles execution policies (TeamPolicy, RangePolicy, MDRangePolicy),
        Timer ``seconds()``, ``parallel_for``, ``parallel_reduce`` /
        ``parallel_scan`` and BinSort methods. Every other call is delegated
        to the base visitor. Arguments that name a (non-subview) view are
        first renamed to their ``pk_d_`` mirror names.

        :param node: the Python call node
        :returns: the translated expression or statement
        """
        name: str = visitors_util.get_node_name(node.func)
        args: List[cppast.Expr] = [self.visit(a) for a in node.args]

        # Add pk_d_ before each view name to match mirror view names
        # (subviews are stored with a None type and keep their name)
        s = cppast.Serializer()
        for i, arg in enumerate(args):
            if arg in self.views and self.views[arg] is not None:
                view: str = s.serialize(arg)
                args[i] = cppast.DeclRefExpr(f"pk_d_{view}")

        # Nested parallelism
        if name == "TeamPolicy":
            function = cppast.DeclRefExpr(f"Kokkos::{name}")
            if len(args) == 2:
                args.append(cppast.IntegerLiteral(1))

            return cppast.ConstructExpr(function, args)

        if name in ["RangePolicy", "MDRangePolicy"]:
            # ast.Call has no "value" attribute, so errors are reported on
            # the call node itself
            rank = len(node.args[0].elts)
            if rank == 0:
                self.error(node,
                           "RangePolicy dimension must be greater than 0")
            if rank != len(node.args[1].elts):
                self.error(node, "RangePolicy dimension mismatch")

            iter_outer = Iterate.Default
            iter_inner = Iterate.Default
            for keyword in node.keywords:
                if keyword.arg == "rank":
                    explicit_rank = keyword.value.args[0].value
                    if explicit_rank != rank:
                        self.error(node, "RangePolicy dimension mismatch")

                    iter_outer = getattr(Iterate, keyword.value.args[1].attr)
                    iter_inner = getattr(Iterate, keyword.value.args[2].attr)

            policy = cppast.ConstructExpr(
                cppast.DeclRefExpr(f"Kokkos::{name}"), args)
            if name == "MDRangePolicy":
                policy.add_template_param(
                    cppast.DeclRefExpr(
                        f"Kokkos::Rank<{rank},{iter_outer.value},{iter_inner.value}>"
                    ))

            return policy

        if name == "seconds":
            fence = cppast.CallStmt(
                cppast.CallExpr(cppast.DeclRefExpr("Kokkos::fence"), []))
            temp_decl = cppast.DeclRefExpr("pk_acc")
            seconds = cppast.MemberCallExpr(cppast.DeclRefExpr("timer"),
                                            cppast.DeclRefExpr("seconds"), [])
            result = cppast.AssignOperator([temp_decl], seconds,
                                           cppast.BinaryOperatorKind.Assign)

            return cppast.CompoundStmt([fence, result])

        function = cppast.DeclRefExpr(f"Kokkos::{name}")
        if name == "parallel_for":
            arg_start, kernel_name, policy = self._kernel_name_and_policy(args)
            work = node.args[arg_start + 1]

            if isinstance(work, ast.Lambda):
                tid = cppast.DeclRefExpr(work.args.args[0].arg)

                # if target exists
                if len(args) == arg_start + 3:
                    target = cppast.ArraySubscriptExpr(args[arg_start + 2],
                                                       [tid])
                    args[arg_start + 1] = cppast.AssignOperator(
                        [target], args[arg_start + 1],
                        cppast.BinaryOperatorKind.Assign)

                serializer = cppast.Serializer()
                decl: str = f"KOKKOS_LAMBDA (int {tid.declname}) {{"
                decl += serializer.serialize(args[arg_start + 1]) + ";}\n"

                call_args: List[cppast.Expr] = [policy, decl]
            else:
                work_unit: str = args[arg_start + 1].declname
                policy.add_template_param(
                    cppast.DeclRefExpr(f"{self.functor}::{work_unit}"))

                call_args = [policy, cppast.DeclRefExpr("pk_f")]

            if kernel_name is not None:
                call_args.insert(0, kernel_name)

            return cppast.CallExpr(function, call_args)

        if name in ("parallel_reduce", "parallel_scan"):
            arg_start, kernel_name, policy = self._kernel_name_and_policy(args)

            initial_value: cppast.Expr
            if len(args) == arg_start + 3:
                initial_value = args[arg_start + 2]
            else:
                initial_value = cppast.IntegerLiteral(0)

            acc_decl = cppast.DeclRefExpr("pk_acc")
            init_var = cppast.BinaryOperator(acc_decl, initial_value,
                                             cppast.BinaryOperatorKind.Assign)

            work = node.args[arg_start + 1]
            if isinstance(work, ast.Lambda):
                tid = cppast.DeclRefExpr(work.args.args[0].arg)
                acc = cppast.DeclRefExpr(work.args.args[1].arg)

                # assign to accumulator
                args[arg_start + 1] = cppast.AssignOperator(
                    [acc], args[arg_start + 1],
                    cppast.BinaryOperatorKind.Assign)

                serializer = cppast.Serializer()
                decl = f"KOKKOS_LAMBDA (int {tid.declname}, double& {acc.declname}) {{"
                decl += serializer.serialize(args[arg_start + 1]) + ";}\n"

                call_args = [policy, decl, acc_decl]
                if kernel_name is not None:
                    call_args.insert(0, kernel_name)

                call = cppast.CallExpr(function, call_args)

                # "pk_acc = initial_value, Kokkos::parallel_reduce(...)"
                return cppast.BinaryOperator(init_var, call,
                                             cppast.BinaryOperatorKind.Comma)

            # NOTE(review): unlike the lambda path, the functor path does not
            # emit the pk_acc initialisation; confirm whether initial_value
            # is intentionally ignored here
            work_unit = args[arg_start + 1].declname
            policy.add_template_param(
                cppast.DeclRefExpr(f"{self.functor}::{work_unit}"))

            call_args = [policy, cppast.DeclRefExpr("pk_f"), acc_decl]
            if kernel_name is not None:
                call_args.insert(0, kernel_name)

            return cppast.CallExpr(function, call_args)

        if name in dir(BinSort):
            sorter: str = visitors_util.get_node_name(node.func.value)
            sorter_ref = cppast.DeclRefExpr(sorter)
            function = cppast.DeclRefExpr(name)

            return cppast.MemberCallExpr(sorter_ref, function, args)

        return super().visit_Call(node)

    def _kernel_name_and_policy(self, args: List[cppast.Expr]) -> tuple:
        """
        Split the leading arguments of a translated parallel_* call.

        An optional leading string literal is taken as the kernel name. The
        next argument is the policy; a non-ConstructExpr policy (a plain
        thread count) is wrapped in ``Kokkos::RangePolicy(0, count)``. The
        default execution space is added as a template parameter of the
        policy.

        :param args: the translated call arguments
        :returns: a tuple (index of the policy argument, kernel name or None,
            policy ConstructExpr)
        """
        arg_start: int = 0  # Accounts for the optional kernel name
        kernel_name: Optional[cppast.StringLiteral] = None
        if isinstance(args[0], cppast.StringLiteral):
            kernel_name = args[0]
            arg_start = 1

        policy: cppast.ConstructExpr = args[arg_start]

        # Replace the number of threads with a RangePolicy
        if not isinstance(policy, cppast.ConstructExpr):
            begin = cppast.IntegerLiteral(0)
            end = args[arg_start]
            policy = cppast.ConstructExpr(
                cppast.DeclRefExpr("Kokkos::RangePolicy"), [begin, end])

        space = cppast.DeclRefExpr(Keywords.DefaultExecSpace.value)
        policy.add_template_param(space)

        return arg_start, kernel_name, policy