Ejemplo n.º 1
0
    def visit_Attribute(self, node: ast.Attribute) -> cppast.DeclRefExpr:
        """
        Translate an attribute access ``<name>.<attr>``.

        ``math.<attr>`` becomes the matching math constant, ``self.<view>``
        becomes a reference to the view, any other ``self.<attr>`` becomes a
        member access through the ``this`` pointer, and every other
        ``<name>.<attr>`` is emitted verbatim.

        :param node: the attribute node
        :returns: the translated expression
        """

        if (
            not isinstance(node.value, ast.Name)
            # TODO: implement proper scope checking
            # or node.value.id not in ("self", "math")
        ):
            self.error(node, "Unrecognized attribute")

        # Math constant
        if node.value.id == "math":
            try:
                constant: str = visitors_util.get_math_constant_str(node.attr)
            except NotImplementedError:
                self.error(node, "Unrecognized math constant")
            return cppast.DeclRefExpr(constant)

        if node.value.id == "self":
            name: str = node.attr
            view_ref = cppast.DeclRefExpr(name)

            # Views are referenced by name. Return the expression node rather
            # than a bare str so callers always receive a cppast object.
            if view_ref in self.views:
                return view_ref

            # if name not in self.env:
            #     self.error(node, "Couldn't find variable")

            field = cppast.MemberExpr(cppast.DeclRefExpr("this"), name)
            field.is_pointer = True

            return field

        return cppast.DeclRefExpr(f"{node.value.id}.{node.attr}")
Ejemplo n.º 2
0
    def visit_Call(self, node: ast.Call) -> cppast.CallExpr:
        """
        Translate a function call.

        Supported calls:
        - math functions, ``printf``, ``abs`` and ``fence`` (emitted as ``Kokkos::fence``);
        - functions listed in ``self.kokkos_functions``;
        - constructors of types in ``visitors_util.allowed_types``;
        - methods of dependency classes, called as ``obj.method(...)``.

        ``print`` and any other call are reported as errors.

        :param node: the call node
        :returns: the translated call expression
        """

        name: str = visitors_util.get_node_name(node.func)

        if name == "print":
            self.error(
                node.func, "Function not supported, did you mean pykokkos.printf()?"
            )
        elif name == "fence":
            name = "Kokkos::" + name

        function = cppast.DeclRefExpr(name)
        args: List[cppast.Expr] = [self.visit(a) for a in node.args]

        if visitors_util.is_math_function(name) or name in ("printf", "abs", "Kokkos::fence"):
            return cppast.CallExpr(function, args)

        if function in self.kokkos_functions:
            return cppast.CallExpr(function, args)

        # Call to a dependency's constructor: use the mapped C++ type name
        if function.declname in visitors_util.allowed_types:
            constructor = cppast.DeclRefExpr(visitors_util.allowed_types[name])
            return cppast.CallExpr(constructor, args)

        # Call to a dependency's method; only possible as obj.method(...),
        # so a plain-name call must not reach node.func.value
        if isinstance(node.func, ast.Attribute):
            for methods in self.dependency_methods.values():
                if function in methods:
                    object_name: cppast.DeclRefExpr = self.visit(node.func.value)
                    return cppast.MemberCallExpr(object_name, function, args)

        self.error(node.func, "Function not supported for translation")
Ejemplo n.º 3
0
def generate_copy_back(members: PyKokkosMembers) -> str:
    """
    Generate the code that copies the device views back into the views

    Subviews (stored with a None type) are skipped. A view that is not
    Unmanaged is first resized to the extents of its device counterpart.
    Every view then gets a ``Kokkos::deep_copy`` from its device view.

    :param members: an object containing the fields and views
    :returns: the C++ source code of the resize and deep_copy statements
    """

    statements: List[str] = []

    device_views: Dict[str, str] = get_device_views(members)
    for v, d_v in device_views.items():
        view_type: cppast.ClassType = members.views[cppast.DeclRefExpr(v)]
        # skip subviews
        if view_type is None:
            continue

        # Need to resize views for binsort. Unmanaged views cannot be resized.
        if cppast.DeclRefExpr("Unmanaged") not in view_type.template_params:
            # The rank is the number in the type name, e.g. View2D -> 2
            rank = int(re.search(r"\d+", view_type.typename).group())
            resize_args: List[str] = [v]
            resize_args.extend(f"{d_v}.extent({i})" for i in range(rank))
            statements.append(f"Kokkos::resize({','.join(resize_args)});")

        statements.append(f"Kokkos::deep_copy({v}, {d_v});")

    return "".join(statements)
Ejemplo n.º 4
0
    def get_classtype_methods(
        self, classtypes: List[PyKokkosEntity]
    ) -> Dict[cppast.DeclRefExpr, List[cppast.DeclRefExpr]]:
        """
        Collect the names of the methods defined by each classtype.

        A constructor (``__init__``) is recorded under the class name.

        :param classtypes: the classtype entities to inspect
        :returns: a dict mapping each class name to the names of its methods,
            in definition order
        """

        result: Dict[cppast.DeclRefExpr, List[cppast.DeclRefExpr]] = {}

        for entity in classtypes:
            class_node: ast.ClassDef = entity.AST
            class_name: str = class_node.name

            result[cppast.DeclRefExpr(class_name)] = [
                cppast.DeclRefExpr(class_name if stmt.name == "__init__" else stmt.name)
                for stmt in class_node.body
                if isinstance(stmt, ast.FunctionDef)
            ]

        return result
Ejemplo n.º 5
0
    def generate_subview(self, node: ast.Assign, view_name: str) -> cppast.DeclStmt:
        """
        Translate ``target = view[...]`` into a ``Kokkos::subview`` declaration.

        Each subscripted dimension becomes one subview argument:
        - a plain index is translated as-is;
        - ``:`` becomes ``Kokkos::ALL``;
        - ``n:m`` becomes ``std::make_pair(n, m)``.

        The target is registered in ``self.views`` with a None type and is
        mapped to ``view_name`` in ``self.subviews``.

        :param node: the assignment whose value is the subscripted view
        :param view_name: the name of the view being sliced
        :returns: the declaration ``auto target = Kokkos::subview(...)``
        """

        subview_args: List[cppast.Expr] = [cppast.DeclRefExpr(view_name)]

        slice_node = node.value
        slice_expr = slice_node.slice
        if isinstance(slice_expr, ast.Tuple):
            # Python >= 3.9: v[a, b:c] stores its dimensions in a Tuple
            dims = slice_expr.elts
        elif hasattr(slice_expr, "dims"):
            # Python <= 3.8: v[a, b:c] stores its dimensions in an ExtSlice
            dims = slice_expr.dims
        else:
            # A single dimension, e.g. v[b:c]
            dims = [slice_expr]

        for dim in dims:
            if not isinstance(dim, ast.Slice):
                # A plain index: ast.Index on Python <= 3.8, the index
                # expression itself on Python >= 3.9
                subview_args.append(self.visit(dim))
            elif dim.step is not None:
                # Only (lower, upper) ranges are generated; a stride would be lost
                self.error(slice_node, "Slice step not supported in subview")
            elif dim.lower is None and dim.upper is None:
                subview_args.append(cppast.DeclRefExpr("Kokkos::ALL"))
            elif dim.lower is not None and dim.upper is not None:
                make_pair = cppast.CallExpr("std::make_pair",
                        [self.visit(dim.lower), self.visit(dim.upper)])
                subview_args.append(make_pair)
            else:
                self.error(
                        slice_node, "Partial slice not supported, use [n:m] or [:]")

        if len(node.targets) > 1:
            self.error(node, "Multiple declarations of subview not supported")

        auto = cppast.ClassType("auto")
        target = node.targets[0]
        target_ref = cppast.DeclRefExpr(target.id)
        if target_ref in self.views:
            self.error(
                node, "Redeclaration of existing subview")
        else:
            # Subviews are tracked with a None type
            self.views[target_ref] = None
            self.subviews[target_ref.declname] = view_name

        call = cppast.CallExpr("Kokkos::subview", subview_args)
        decl = cppast.DeclStmt(cppast.VarDecl(auto, self.visit(target), call))

        return decl
Ejemplo n.º 6
0
def get_kernel_params(
    members: PyKokkosMembers,
    is_hierarchical: bool,
    is_workload: bool,
    real: Optional[str]
) -> Dict[str, str]:
    """
    Get the parameters of the kernel. The parameters include the fields, the views,
    the views holding the reduction results, and the view holding the timer results.
    Also add parameters for the parameters of the execution policy.

    :param members: an object containing the fields and views
    :param is_hierarchical: does the workunit use hierarchical parallelism
    :param is_workload: is the kernel a workload; if so, the kernel name and
        execution policy parameters are not added
    :param real: the precision for which to generate a binding
    :returns: a dict mapping from argument name to type
    """

    s = cppast.Serializer()
    params: Dict[str, str] = {}
    for n, t in members.fields.items():
        params[n.declname] = s.serialize(t)

    layout: str = f"{Keywords.DefaultExecSpace.value}::array_layout"
    for n, t in members.views.items():
        # skip subviews
        if t is None:
            continue
        params[n.declname] = cpp_view_type(t, space=Keywords.ArgMemSpace.value, layout=layout, real=real)

    if not is_workload:
        params[Keywords.KernelName.value] = "const std::string&"

        if is_hierarchical:
            params[Keywords.LeagueSize.value] = "int"
            params[Keywords.TeamSize.value] = "int"
            params[Keywords.VectorLength.value] = "int"
        else:
            params[Keywords.ThreadsBegin.value] = "int"
            params[Keywords.ThreadsEnd.value] = "int"

    # One host-side View1D<double> per queued reduction result, then per
    # queued timer result
    for prefix, queue in (
        ("reduction_result", members.reduction_result_queue),
        ("timer_result", members.timer_result_queue),
    ):
        for result in queue:
            view_type = cppast.ClassType("View1D")
            view_type.add_template_param(cppast.DeclRefExpr("double"))
            view_type.add_template_param(cppast.DeclRefExpr("HostSpace"))
            params[f"{prefix}_{result}"] = cpp_view_type(
                view_type, space="Kokkos::HostSpace", layout="Kokkos::LayoutRight")

    return params
Ejemplo n.º 7
0
    def visit_Attribute(self, node: ast.Attribute) -> cppast.DeclRefExpr:
        """
        Translate an attribute access.

        Work unit names and any ``self.<member>`` (views included) are emitted
        as a plain reference to the member name. Everything else is delegated
        to the parent visitor, which also reports unrecognized attributes.

        :param node: the attribute node
        :returns: the translated expression
        """

        name: str = visitors_util.get_node_name(node)
        if name in self.work_units:
            return cppast.DeclRefExpr(name)

        # Guard the Name check so e.g. a.b.c reaches the parent's error
        # reporting instead of raising AttributeError here
        if isinstance(node.value, ast.Name) and node.value.id == "self":
            # Views and other members alike are referenced by bare name; always
            # return a cppast node rather than a str
            return cppast.DeclRefExpr(name)

        return super().visit_Attribute(node)
Ejemplo n.º 8
0
        def visit_FunctionDef(node: ast.FunctionDef):
            # Record `node` in `functions` (keyed by its name) when the name of
            # its first decorator equals `decorator.value`; `decorator` and
            # `functions` come from the enclosing scope
            if not node.decorator_list:
                return

            first_decorator: str = visitors_util.get_node_name(node.decorator_list[0])
            if decorator.value == first_decorator:
                functions[cppast.DeclRefExpr(node.name)] = node
Ejemplo n.º 9
0
    def get_typeinfo(
        self, node: Union[ast.ClassDef, ast.FunctionDef]
    ) -> Dict[cppast.DeclRefExpr, List[cppast.DeclRefExpr]]:
        """
        Get the view type info from a decorator

        Looks for a ``@<pk>.functor(...)``, ``@<pk>.workload(...)`` or
        ``@<pk>.workunit(...)`` decorator call on ``node``, where ``<pk>`` is
        ``self.pk_import``. If several match, the last one wins. Each keyword
        argument of that call is visited.

        :param node: the decorated class or function definition
        :returns: a dictionary mapping from view name to template params
        """

        decorator: Optional[ast.Call] = None

        for d in node.decorator_list:
            if not isinstance(d, ast.Call):
                continue

            func = d.func
            # func.value may be a nested attribute (a.b.functor), which has no .id
            if (
                isinstance(func, ast.Attribute)
                and isinstance(func.value, ast.Name)
                and func.value.id == self.pk_import
                and func.attr in ("functor", "workload", "workunit")
            ):
                decorator = d

        type_info: Dict[cppast.DeclRefExpr, List[cppast.DeclRefExpr]] = {}

        if decorator is not None:
            for k in decorator.keywords:
                view = cppast.DeclRefExpr(k.arg)
                type_info[view] = self.visit(k)

        return type_info
Ejemplo n.º 10
0
    def visit_Call(self, node: ast.Call) -> cppast.CallExpr:
        """
        Translate Kokkos atomic calls and delegate every other call to the parent.

        ``atomic_fetch_*`` calls take 3 arguments and ``atomic_compare_exchange``
        takes 4. In both, the first argument is the view and the second is a
        list of indices. These two are merged into the address
        ``&view(indices...)``, the remaining arguments are passed through
        unchanged, and the callee is prefixed with ``Kokkos::``.

        :param node: the call node
        :returns: the translated call expression
        """

        # Copied from workunit_visitor.py
        name: str = visitors_util.get_node_name(node.func)
        args: List[cppast.Expr] = [self.visit(a) for a in node.args]

        function = cppast.DeclRefExpr(f"Kokkos::{name}")

        # A prefix test; the regex "atomic_fetch_*" only meant
        # "atomic_fetch" followed by any number of underscores
        is_atomic_fetch_op: bool = name.startswith("atomic_fetch_")
        is_atomic_compare_exchange: bool = name == "atomic_compare_exchange"

        if is_atomic_fetch_op or is_atomic_compare_exchange:
            if is_atomic_fetch_op and len(args) != 3:
                self.error(
                    node, "atomic_fetch_op functions take exactly 3 arguments")
            if is_atomic_compare_exchange and len(args) != 4:
                self.error(
                    node, "atomic_compare_exchange takes exactly 4 arguments")

            # convert indices: view + [i, j, ...] -> view(i, j, ...)
            args[0] = cppast.CallExpr(args[0], args[1].exprs)
            del args[1]

            # if not isinstance(args[0], cppast.CallExpr):
            #     self.error(
            #         node, "atomic_fetch_op functions only support views")

            # atomic_fetch_* operations need to have an address as
            # their first argument
            args[0] = cppast.UnaryOperator(args[0],
                                           cppast.BinaryOperatorKind.AddrOf)
            return cppast.CallExpr(function, args)

        return super().visit_Call(node)
Ejemplo n.º 11
0
    def get_member_variables(
            self, node: ast.ClassDef) -> Dict[cppast.DeclRefExpr, str]:
        """
        Get the instance variables declared with annotations in ``__init__``.

        Only annotated assignments of the form ``self.<name>: T = ...`` are
        collected. Annotated local variables in the constructor are skipped.

        :param node: the class definition
        :returns: a dict mapping each member name to its serialized C++ type
        """

        constructor: Optional[ast.FunctionDef] = next(
            (
                f for f in node.body
                if isinstance(f, ast.FunctionDef) and f.name == "__init__"
            ),
            None,
        )

        if constructor is None:
            self.error(node, "Missing constructor")

        member_variables: Dict[cppast.DeclRefExpr, str] = {}
        # A single serializer is reused for every member type
        serializer = cppast.Serializer()

        for b in constructor.body:
            if not isinstance(b, ast.AnnAssign):
                continue

            target = b.target
            # `x: T = ...` has an ast.Name target (which has no .value);
            # only `self.x: T = ...` declares a member variable
            if not (
                isinstance(target, ast.Attribute)
                and isinstance(target.value, ast.Name)
                and target.value.id == "self"
            ):
                continue

            declref = cppast.DeclRefExpr(visitors_util.get_node_name(target))
            typename: cppast.Type = visitors_util.get_type(
                b.annotation, self.pk_import)

            if typename is None:
                self.error(b, "Type not supported")

            member_variables[declref] = serializer.serialize(typename)

        return member_variables
Ejemplo n.º 12
0
    def is_dependency(self, input_type: cppast.Type) -> bool:
        """
        Check whether a type names a dependency class.

        :param input_type: the type to check
        :returns: True if the type is not primitive and its type name is a key
            of ``self.dependency_methods``, False otherwise
        """

        is_primitive: bool = isinstance(input_type, cppast.PrimitiveType)
        return (
            not is_primitive
            and cppast.DeclRefExpr(input_type.typename) in self.dependency_methods
        )
Ejemplo n.º 13
0
    def visit_Subscript(self, node: ast.Subscript) -> Union[cppast.ArraySubscriptExpr, cppast.CallExpr]:
        """
        Translate ``x[i][j]...`` into a list subscript or a view access.

        The chain of nested subscripts is unwound so that ``v[i][j]`` yields
        the indices ``[i, j]``. Lists become array subscripts. Views become
        ``v(i, j, ...)`` calls, after checking that the view's declared type
        is ``View<N>D`` with N equal to the number of indices. Views without
        a recorded type (added in @pk.main) are accepted as-is.

        :param node: the outermost subscript node
        :returns: the translated expression
        """

        current_node: ast.expr = node
        outer_first: List = []

        while isinstance(current_node, ast.Subscript):
            index = current_node.slice

            if sys.version_info < (3, 9):
                # Python <= 3.8 wraps plain indices in ast.Index; anything
                # else (ast.Slice, ast.ExtSlice) is a slice
                is_slice = not isinstance(index, ast.Index)
            else:
                # In Python >= 3.9, ast.Index is deprecated
                # (see https://docs.python.org/3/whatsnew/3.9.html): the
                # index expression is stored directly and slices appear as
                # ast.Slice, possibly inside an ast.Tuple
                is_slice = isinstance(index, ast.Slice) or (
                    isinstance(index, ast.Tuple)
                    and any(isinstance(e, ast.Slice) for e in index.elts)
                )

            if is_slice:
                self.error(
                    current_node, "Slices not supported, use simple indices")

            outer_first.append(index)
            current_node = current_node.value

        # Indices were collected outermost-first; restore source order
        slices: List = outer_first[::-1]
        dim: int = len(slices)

        name: str = visitors_util.get_node_name(current_node)
        ref = cppast.DeclRefExpr(name)

        if ref not in self.views and name not in self.lists:
            self.error(current_node, "Unknown view or list")

        if name in self.lists:
            indices: List[cppast.Expr] = [self.visit(s) for s in slices]
            return cppast.ArraySubscriptExpr(ref, indices)

        if ref in self.views:
            view_type = self.views[ref]
            # view_type is None for views added in @pk.main
            if view_type is None or view_type.typename == f"View{dim}D":
                args: List[cppast.Expr] = [self.visit(s) for s in slices]
                return cppast.CallExpr(ref, args)

        self.error(node, f"'{name}' is not a View{dim}D")
Ejemplo n.º 14
0
    def visit_Return(self, node: ast.Return) -> cppast.ReturnStmt:
        """
        Translate a return statement.

        A returned value is only allowed inside functions listed in
        ``self.kokkos_functions``; a bare ``return`` is allowed in any function.

        :param node: the return node
        :returns: the translated return statement
        """

        parent_function: FunctionDef = self.get_parent_function(node)
        if parent_function is None:
            self.error(node, "Cannot return outside of function")

        if node.value:
            function_ref = cppast.DeclRefExpr(parent_function.name)
            if function_ref not in self.kokkos_functions:
                self.error(
                    node.value, "Cannot return value from translated function")
            else:
                return cppast.ReturnStmt(self.visit(node.value))

        return cppast.ReturnStmt()
Ejemplo n.º 15
0
    def visit_AnnAssign(self, node: ast.AnnAssign) -> cppast.Stmt:
        """
        Translate an annotated assignment whose value is a call.

        - A call to a ``TeamMember`` method becomes a plain declaration.
        - A nested ``parallel_reduce`` / ``parallel_scan(policy, work_unit[, init])``
          becomes a declaration of the target, initialized to ``init`` or 0,
          followed by ``Kokkos::<name>(policy, work_unit, target)``.

        Everything else is delegated to the parent visitor.

        :param node: the annotated assignment node
        :returns: the translated statement
        """

        if isinstance(node.value, ast.Call):
            decltype: cppast.Type = visitors_util.get_type(
                node.annotation, self.pk_import)
            if decltype is None:
                self.error(node, "Type not supported")
            declname: cppast.DeclRefExpr = self.visit(node.target)
            function_name: str = visitors_util.get_node_name(node.value.func)

            # Call to a TeamMember method
            if function_name in dir(TeamMember):
                vardecl = cppast.VarDecl(decltype, declname,
                                         self.visit(node.value))
                return cppast.DeclStmt(vardecl)

            # Nested parallelism
            if function_name in ("parallel_reduce", "parallel_scan"):
                args: List[cppast.Expr] = [
                    self.visit(a) for a in node.value.args
                ]

                # args[0] is the policy and args[1] the work unit; report a
                # translation error instead of an IndexError when missing
                if len(args) < 2:
                    self.error(
                        node.value,
                        f"{function_name} requires a policy and a work unit")

                initial_value: cppast.Expr
                if len(args) == 3:
                    initial_value = args[2]
                else:
                    initial_value = cppast.IntegerLiteral(0)

                vardecl = cppast.VarDecl(decltype, declname, initial_value)
                declstmt = cppast.DeclStmt(vardecl)

                work_unit: str = args[1].declname
                function = cppast.DeclRefExpr(f"Kokkos::{function_name}")

                # Nested work units are inlined as lambdas; others are
                # referenced through their pk_id_ name
                call: cppast.CallExpr
                if work_unit in self.nested_work_units:
                    call = cppast.CallExpr(
                        function,
                        [args[0], self.nested_work_units[work_unit], declname])
                else:
                    call = cppast.CallExpr(
                        function, [args[0], f"pk_id_{work_unit}", declname])

                callstmt = cppast.CallStmt(call)

                return cppast.CompoundStmt([declstmt, callstmt])

        return super().visit_AnnAssign(node)
Ejemplo n.º 16
0
    def visit_arg(self, node: ast.arg) -> Union[cppast.ParmVarDecl, str]:
        """
        Translate a method parameter.

        ``self`` is dropped (an empty string is returned). Parameters whose
        type is a dependency class are passed by reference.

        :param node: the arg node
        :returns: the parameter declaration, or "" for ``self``
        """

        if node.arg == "self":
            return ""

        if node.annotation is None:
            self.error(node, "Missing type annotation")

        param_type: cppast.Type = visitors_util.get_type(node.annotation, self.pk_import)
        if param_type is None:
            self.error(node, "Type not supported")

        if self.is_dependency(param_type):
            param_type.is_reference = True

        return cppast.ParmVarDecl(param_type, cppast.DeclRefExpr(node.arg))
Ejemplo n.º 17
0
def generate_assignments(members: Dict[cppast.DeclRefExpr, cppast.Type]) -> List[cppast.AssignOperator]:
    """
    Generate the assignments in the constructor

    For each member ``n``, the value ``n`` is assigned to the member ``n``
    accessed through the ``this`` pointer.

    :param members: the members being assigned (only the names are used)
    :returns: the list of assignments, in the iteration order of members
    """

    op = cppast.BinaryOperatorKind.Assign
    assignments: List[cppast.AssignOperator] = []

    for n in members:
        field = cppast.MemberExpr(cppast.DeclRefExpr("this"), n.declname)
        field.is_pointer = True
        assignments.append(cppast.AssignOperator([field], n, op))

    return assignments
Ejemplo n.º 18
0
    def visit_arg(self, node: ast.arg) -> cppast.ParmVarDecl:
        """
        Translate a work unit parameter into a C++ parameter declaration.

        Parameters annotated with an attribute expression (such as
        ``pk.TeamMember``) are declared as ``const T&``.

        :param node: the arg node
        :returns: the parameter declaration
        """

        if node.annotation is None:
            self.error(node, "Missing type annotation")

        param_type: cppast.Type = visitors_util.get_type(node.annotation,
                                                         self.pk_import)
        if param_type is None:
            self.error(node, "Type not supported")

        # If argument is pk.TeamMember (hierarchical parallelism).
        # NOTE(review): any attribute-style annotation takes this path, not
        # only pk.TeamMember — confirm this is intended.
        is_team_member: bool = isinstance(node.annotation, ast.Attribute)
        if is_team_member:
            param_type.typename = f"const {param_type.typename}"
            param_type.is_reference = True

        return cppast.ParmVarDecl(param_type, cppast.DeclRefExpr(node.arg))
Ejemplo n.º 19
0
    def get_trait(self, node: ast.Call) -> Optional[cppast.DeclRefExpr]:
        """
        Get the ``trait`` keyword argument of a call, if any.

        The value must be an attribute expression whose name is a member of
        ``Trait``.

        :param node: the call node
        :returns: a reference to the trait name, or None if there is no
            ``trait`` keyword
        """

        if not hasattr(node, "keywords"):
            return None

        for keyword in node.keywords:
            if keyword.arg != "trait":
                continue

            if not isinstance(keyword.value, ast.Attribute):
                self.error(
                    node,
                    "Trait argument should be of the form pk.Trait.Atomic..."
                )

            trait_name: str = visitors_util.get_node_name(keyword.value)
            if trait_name not in Trait.__members__:
                self.error(node, "Unrecognized trait")
            else:
                return cppast.DeclRefExpr(trait_name)

        return None
Ejemplo n.º 20
0
def generate_constructor(
    name: str,
    fields: Dict[cppast.DeclRefExpr, cppast.PrimitiveType],
    views: Dict[cppast.DeclRefExpr, cppast.ClassType]
) -> cppast.ConstructorDecl:
    """
    Generate the functor constructor

    The constructor takes one parameter per field and per view (subviews,
    which have a None type, are skipped) and assigns each to the member of
    the same name. If there would be no parameters, a dummy ``int pk_field``
    parameter is added.

    :param name: the functor class name
    :param fields: a dict mapping from field name to type
    :param views: a dict mapping from view name to type
    :returns: the cppast representation of the constructor
    """

    # Subviews are stored with a None type; use the same test for the
    # parameters and the assignments
    non_subviews: Dict[cppast.DeclRefExpr, cppast.ClassType] = {
        n: t for n, t in views.items() if t is not None
    }

    params: List[cppast.ParmVarDecl] = [
        cppast.ParmVarDecl(t, n) for n, t in fields.items()
    ]
    params.extend(
        cppast.ParmVarDecl(get_view_type(t), n) for n, t in non_subviews.items()
    )

    # Kokkos fails to compile a functor if there are no parameters in its constructor
    if not params:
        placeholder_type = cppast.PrimitiveType(cppast.BuiltinType.INT)
        params.append(cppast.ParmVarDecl(placeholder_type, cppast.DeclRefExpr("pk_field")))

    assignments: List[cppast.AssignOperator] = [
        *generate_assignments(fields),
        *generate_assignments(non_subviews),
    ]

    body = cppast.CompoundStmt(assignments)

    return cppast.ConstructorDecl("", name, params, body)
Ejemplo n.º 21
0
    def visit_Assign(self, node: ast.Assign) -> Union[cppast.AssignOperator, cppast.DeclStmt]:
        """
        Translate an assignment.

        - Assigning to ``self.<member>`` is reported as an error.
        - ``x = view[...]`` with a slice (Python <= 3.8) or a tuple of
          dimensions (Python >= 3.9) declares a subview of the view.
        - Everything else becomes a plain C++ assignment.

        :param node: the assignment node
        :returns: the subview declaration or the assignment
        """

        for target in node.targets:
            # TODO: check if target.value.id is in scope
            if (
                isinstance(target, ast.Attribute)
                and isinstance(target.value, ast.Name)
                and target.value.id == "self"
            ):
                self.error(
                    target, "Only local variables and views supported for assignment",
                )

        # handle subview
        if isinstance(node.value, ast.Subscript):
            index = node.value.slice
            if sys.version_info < (3, 9):
                is_subview = not isinstance(index, ast.Index)
            else:
                is_subview = isinstance(index, ast.Tuple)

            if is_subview:
                view = node.value.value
                if (
                    isinstance(view, ast.Attribute)
                    and isinstance(view.value, ast.Name)
                    and view.value.id == "self"
                ):
                    # reference view through self
                    view_name = view.attr
                elif isinstance(view, ast.Name):
                    # reference views through params (standalone)
                    view_name = view.id
                else:
                    self.error(view, "View not recognized")

                if cppast.DeclRefExpr(view_name) in self.views:
                    return self.generate_subview(node, view_name)
                self.error(node, "Can only take subview of views")

        targets: List[cppast.DeclRefExpr] = [
            self.visit(t) for t in node.targets]
        value: cppast.Expr = self.visit(node.value)
        op: cppast.BinaryOperatorKind = cppast.BinaryOperatorKind.Assign
        assign = cppast.AssignOperator(targets, value, op)

        return assign
Ejemplo n.º 22
0
    def visit_arg(self, node: ast.arg) -> None:
        """
        Record a parameter as either a field or a view.

        A parameter is a field if it is annotated with a bare name or its
        type resolves to a primitive type. Otherwise it is a view. The entry
        is stored in ``self.fields`` or ``self.views``, keyed by the
        parameter name.

        :param node: the arg node
        """

        annotation: Union[ast.Name, ast.Attribute] = node.annotation
        param_type: Optional[cppast.Type] = visitors_util.get_type(annotation, self.pk_import)

        if param_type is None:
            self.error(node, "Type is not supported")

        # just checking the resolved type might be enough
        is_field: bool = (
            isinstance(annotation, ast.Name)
            or isinstance(param_type, cppast.PrimitiveType)
        )
        destination = self.fields if is_field else self.views
        destination[cppast.DeclRefExpr(node.arg)] = param_type
Ejemplo n.º 23
0
    def get_layout(self, node: ast.Call) -> Optional[cppast.DeclRefExpr]:
        """
        Get the ``layout`` keyword argument of a call, if any.

        The value must be an attribute expression whose name is a member of
        ``Layout``.

        :param node: the call node
        :returns: a reference to the layout name, or None if there is no
            ``layout`` keyword
        """

        if not hasattr(node, "keywords"):
            return None

        args: List[ast.keyword] = node.keywords

        for a in args:
            if a.arg == "layout":
                # Only check the outer node: nodes such as ast.Name have no
                # .value, so inspecting a.value.value could raise AttributeError
                if not isinstance(a.value, ast.Attribute):
                    self.error(
                        node,
                        "Layout argument should be of the form pk.layout.Layout..."
                    )

                layout: str = visitors_util.get_node_name(a.value)

                if layout in Layout.__members__:
                    return cppast.DeclRefExpr(layout)
                else:
                    self.error(node, "Unrecognized layout")

        return None
Ejemplo n.º 24
0
    def visit_BinOp(self, node: ast.BinOp) -> Union[cppast.BinaryOperator, cppast.CallExpr, cppast.CastExpr]:
        """
        Translate a binary operation, parenthesizing both operands.

        - ``a ** b`` becomes ``pow(a, b)``.
        - ``a / b`` casts the left operand to double.
        - ``a // b`` casts the result of the division to int.

        :param node: the binary operation node
        :returns: the translated expression
        """

        left = cppast.ParenExpr(self.visit(node.left))
        right = cppast.ParenExpr(self.visit(node.right))

        if isinstance(node.op, ast.Pow):
            return cppast.CallExpr(cppast.DeclRefExpr("pow"), [left, right])

        operator_kind: cppast.BinaryOperatorKind = self.visit(node.op)

        if isinstance(node.op, ast.Div):
            # True division: force floating point by casting one operand
            as_double = cppast.CastExpr(
                cppast.PrimitiveType(cppast.BuiltinType.DOUBLE), left)
            return cppast.BinaryOperator(as_double, right, operator_kind)

        result = cppast.BinaryOperator(left, right, operator_kind)

        if isinstance(node.op, ast.FloorDiv):
            # NOTE(review): an int cast truncates toward zero, whereas Python's
            # // floors; results differ for negative operands — confirm acceptable.
            return cppast.CastExpr(
                cppast.PrimitiveType(cppast.BuiltinType.INT), result)

        return result
Ejemplo n.º 25
0
    def get_memory_space(self, node: ast.Call) -> Optional[cppast.DeclRefExpr]:
        """
        Get the ``space`` keyword argument of a call, if any.

        The value must be an attribute expression whose name is a member of
        ``MemorySpace``.

        :param node: the call node
        :returns: a reference to the memory space name, or None if there is
            no ``space`` keyword
        """

        if not hasattr(node, "keywords"):
            return None

        args: List[ast.keyword] = node.keywords

        for a in args:
            if a.arg == "space":
                # Only check the outer node: nodes such as ast.Name have no
                # .value, so inspecting a.value.value could raise AttributeError
                if not isinstance(a.value, ast.Attribute):
                    self.error(
                        node,
                        "MemorySpace argument should be of the form pk.MemorySpace.HostSpace..."
                    )

                space: str = visitors_util.get_node_name(a.value)

                if space in MemorySpace.__members__:
                    return cppast.DeclRefExpr(space)
                else:
                    self.error(node, "Unrecognized memory space")

        return None
Ejemplo n.º 26
0
    def get_entities(self, style: PyKokkosStyles) -> Dict[str, PyKokkosEntity]:
        """
        Get the entities from path that are of a particular style

        An entity's source lines run from its own ``lineno`` up to the line
        before the next top-level statement, or to the end of the file.

        :param style: the style of the entity to get
        :returns: a dict mapping the name of each entity to a PyKokkosEntity instance
        :raises ValueError: if style is not workload, functor, workunit or classtype
        """

        checkers: Dict[PyKokkosStyles, Callable[[ast.stmt, str], bool]] = {
            PyKokkosStyles.workload: self.is_workload,
            PyKokkosStyles.functor: self.is_functor,
            PyKokkosStyles.workunit: self.is_workunit,
            PyKokkosStyles.classtype: self.is_classtype,
        }
        # Fail clearly instead of hitting an unbound local further down
        if style not in checkers:
            raise ValueError(f"Unsupported entity style: {style}")
        check_entity = checkers[style]

        entities: Dict[str, PyKokkosEntity] = {}
        body: List[ast.stmt] = self.tree.body

        for i, node in enumerate(body):
            if not check_entity(node, self.pk_import):
                continue

            # NOTE(review): since Python 3.8 a decorated definition's lineno is
            # its def/class line, so this range excludes the entity's own
            # decorators and includes those of the next statement — confirm
            # this is intended.
            start: int = node.lineno - 1
            stop: int = body[i + 1].lineno - 1 if i + 1 < len(body) else len(self.lines)

            name: str = node.name
            entity = PyKokkosEntity(style, cppast.DeclRefExpr(name), node,
                                    (self.lines[start:stop], start),
                                    self.path, self.pk_import)
            entities[name] = entity

        return entities
Ejemplo n.º 27
0
    def visit_FunctionDef(
            self, node: ast.FunctionDef
    ) -> Union[str, Tuple[str, cppast.MethodDecl]]:
        """
        Translate a work unit definition.

        - A nested work unit (per ``self.is_nested_call``) becomes a ``[&]``
          lambda stored in ``self.nested_work_units`` under the function name,
          and an empty string is returned.
        - Any other function becomes a const ``KOKKOS_FUNCTION void operator()``
          whose first parameter is an unnamed ``const <name>&`` (presumably a
          Kokkos work tag). It is returned together with its operation type.

        :param node: the function definition
        :returns: "" for nested work units, else (operation, method)
        """

        params: List[cppast.ParmVarDecl]

        if self.is_nested_call(node):
            params = list(self.visit(node.args))
            body = cppast.CompoundStmt([self.visit(b) for b in node.body])

            workunit = cppast.LambdaExpr("[&]", params, body)
            self.nested_work_units[node.name] = workunit

            return ""

        operation: Optional[str] = self.get_operation_type(node)
        if operation is None:
            self.error(node.args, "Incorrect types in workunit definition")

        tag_type = cppast.ClassType(f"const {node.name}")
        tag_type.is_reference = True
        tag = cppast.ParmVarDecl(tag_type, cppast.DeclRefExpr(""))

        params = [tag]
        params.extend(self.visit(node.args))

        body = cppast.CompoundStmt([self.visit(b) for b in node.body])
        attributes: str = "KOKKOS_FUNCTION"
        decltype = cppast.ClassType("void")
        declname: str = "operator()"

        method = cppast.MethodDecl(attributes, decltype, declname, params,
                                   body)
        method.is_const = True

        return (operation, method)
Ejemplo n.º 28
0
    def visit_Attribute(self, node: ast.Attribute) -> cppast.DeclRefExpr:
        """
        Translate ``self.<name>`` into a reference to ``<name>``.

        Any other attribute access (including non-Name bases such as
        ``a.b.c``) is reported as an error.

        :param node: the attribute node
        :returns: a reference to the instance variable name
        """

        if isinstance(node.value, ast.Name) and node.value.id == "self":
            name: str = visitors_util.get_node_name(node)
            return cppast.DeclRefExpr(name)

        self.error(node, "Can only define instance variables")
Ejemplo n.º 29
0
    def visit_Call(self, node: ast.Call) -> cppast.CallExpr:
        """
        Translate calls used in hierarchical work units.

        Handled here:
        - ``TeamMember`` methods, emitted without arguments;
        - ``view.extent(i)``;
        - Kokkos policies (``TeamThreadRange``, ``ThreadVectorRange``,
          ``PerTeam``, ``PerThread``);
        - nested ``parallel_for`` / ``single``;
        - atomic operations.

        Everything else is delegated to the parent visitor.

        :param node: the call node
        :returns: the translated call expression
        """

        name: str = visitors_util.get_node_name(node.func)
        args: List[cppast.Expr] = [self.visit(a) for a in node.args]

        # Call to a TeamMember method
        if name in dir(TeamMember):
            team_member: str = visitors_util.get_node_name(node.func.value)
            call = cppast.MemberCallExpr(cppast.DeclRefExpr(team_member),
                                         cppast.DeclRefExpr(name), [])

            return call

        # Call to view.extent()
        if name == "extent":
            if len(args) != 1:
                self.error(node, "the extent method takes exactly 1 argument")

            view: str = visitors_util.get_node_name(node.func.value)
            call = cppast.MemberCallExpr(cppast.DeclRefExpr(view),
                                         cppast.DeclRefExpr(name), [args[0]])

            return call

        function = cppast.DeclRefExpr(f"Kokkos::{name}")
        if name in ("TeamThreadRange", "ThreadVectorRange", "PerTeam",
                    "PerThread"):
            return cppast.CallExpr(function, args)

        if name in ("parallel_for", "single"):
            # args[0] is the policy and args[1] the work unit; report a
            # translation error instead of an IndexError when missing
            if len(args) < 2:
                self.error(node, f"{name} requires a policy and a work unit")

            work_unit: str = args[1].declname
            if work_unit in self.nested_work_units:
                return cppast.CallExpr(
                    function, [args[0], self.nested_work_units[work_unit]])
            else:
                return cppast.CallExpr(function,
                                       [args[0], f"pk_id_{work_unit}"])

        # A prefix test; the regex "atomic_fetch_*" only meant
        # "atomic_fetch" followed by any number of underscores
        is_atomic_fetch_op: bool = name.startswith("atomic_fetch_")
        is_atomic_compare_exchange: bool = name == "atomic_compare_exchange"

        if is_atomic_fetch_op or is_atomic_compare_exchange:
            if is_atomic_fetch_op and len(args) != 3:
                self.error(
                    node, "atomic_fetch_op functions take exactly 3 arguments")
            if is_atomic_compare_exchange and len(args) != 4:
                self.error(
                    node, "atomic_compare_exchange takes exactly 4 arguments")

            # convert indices: view + [i, j, ...] -> view(i, j, ...)
            args[0] = cppast.CallExpr(args[0], args[1].exprs)
            del args[1]

            # if not isinstance(args[0], cppast.CallExpr):
            #     self.error(
            #         node, "atomic_fetch_op functions only support views")

            # atomic_fetch_* operations need to have an address as
            # their first argument
            args[0] = cppast.UnaryOperator(args[0],
                                           cppast.BinaryOperatorKind.AddrOf)
            return cppast.CallExpr(function, args)

        return super().visit_Call(node)
Ejemplo n.º 30
0
 def visit_Name(self, node: ast.Name) -> cppast.DeclRefExpr:
     """Translate a bare identifier into a reference to the same name."""
     identifier: str = node.id
     return cppast.DeclRefExpr(identifier)