示例#1
0
    def visit_Stage(self, node: gt_ir.Stage) -> Dict[str, Any]:
        """Build the template context for a single stage functor.

        Side effect: resets ``self.stage_symbols`` and fills it with one entry
        per accessor of *node* (keyed by ``accessor.symbol``), so later
        reference generation inside this stage can look the accessors up.

        Args:
            node: The stage to translate.

        Returns:
            A dict with two keys:
              - ``"args"``: list of argument descriptors, each a dict with
                ``"name"``, ``"access_type"`` (``"in"`` or ``"inout"``) and
                ``"extent"`` (flattened extent for field accessors, else None).
                If the stage contains an ``AxisPosition`` on a parallel axis,
                one ``domain_size_<axis>`` input argument per parallel axis is
                appended.
              - ``"regions"``: one dict per apply block with
                ``"interval_start"``, ``"interval_end"`` and ``"body"``.
        """
        # Symbols used when generating references inside this stage.
        self.stage_symbols = {}

        # Fields whose offset along the sequential axis is an expression (not a
        # constant) have no statically known vertical extent.
        fields_with_variable_offset = {
            field_ref.name
            for field_ref in gt_ir.iter_nodes_of_type(node, gt_ir.FieldRef)
            if isinstance(
                field_ref.offset.get(self.domain.sequential_axis.name, None), gt_ir.Expr
            )
        }

        args = []
        for accessor in node.accessors:
            self.stage_symbols[accessor.symbol] = accessor
            arg = {"name": accessor.symbol, "access_type": "in", "extent": None}
            if isinstance(accessor, gt_ir.FieldAccessor):
                arg["access_type"] = (
                    "in" if accessor.intent == gt_definitions.AccessKind.READ else "inout"
                )
                if accessor.symbol in fields_with_variable_offset:
                    # Keep every extent entry but the last (vertical) one, and assert
                    # [-1000, 1000] for it instead. 1000 is a guess meant to exceed
                    # any reasonable number of vertical levels.
                    arg["extent"] = gt_utils.flatten(accessor.extent[:-1]) + [-1000, 1000]
                else:
                    arg["extent"] = gt_utils.flatten(accessor.extent)
            args.append(arg)

        # A stage has a horizontal region when any AxisPosition refers to a
        # parallel axis. any() stops at the first match instead of walking every
        # AxisPosition node as a flag-setting loop would.
        parallel_axes_names = [axis.name for axis in self.domain.parallel_axes]
        has_horizontal_region = any(
            pos_node.axis in parallel_axes_names
            for pos_node in gt_ir.iter_nodes_of_type(node, gt_ir.AxisPosition)
        )
        if has_horizontal_region:
            args.extend(
                {"name": f"domain_size_{name}", "access_type": "in", "extent": None}
                for name in parallel_axes_names
            )

        # Create regions and computations: visiting an apply block yields its
        # interval definition (start, end) and the generated body sources.
        regions = []
        for apply_block in node.apply_blocks:
            interval_definition, body_sources = self.visit(apply_block)
            regions.append(
                {
                    "interval_start": interval_definition[0],
                    "interval_end": interval_definition[1],
                    "body": body_sources,
                }
            )

        return {"args": args, "regions": regions}
示例#2
0
    def visit_Stage(self, node: gt_ir.Stage):
        """Translate a stage into the argument/region context of its functor.

        ``self.stage_symbols`` is rebuilt from scratch and maps each accessor
        symbol of *node* to its accessor; it is consulted later when
        references inside this stage are generated.

        Returns a dict with:
          - ``"args"``: one descriptor per accessor with ``"name"``,
            ``"access_type"`` and ``"extent"``. Field accessors that are
            read-only get ``"in"``, all other field accessors get ``"inout"``,
            and their extent is flattened. Non-field accessors are ``"in"``
            with no extent.
          - ``"regions"``: one dict per apply block holding its interval
            bounds and generated body.
        """
        self.stage_symbols = {}

        args = []
        for accessor in node.accessors:
            symbol = accessor.symbol
            self.stage_symbols[symbol] = accessor
            if not isinstance(accessor, gt_ir.FieldAccessor):
                access_type, extent = "in", None
            else:
                read_only = accessor.intent == gt_ir.AccessIntent.READ_ONLY
                access_type = "in" if read_only else "inout"
                extent = gt_utils.flatten(accessor.extent)
            args.append({"name": symbol, "access_type": access_type, "extent": extent})

        # The generator is consumed lazily, so each apply block is visited and
        # turned into its region entry before the next one is visited.
        visited_blocks = (self.visit(block) for block in node.apply_blocks)
        regions = [
            {
                "interval_start": interval[0],
                "interval_end": interval[1],
                "body": body,
            }
            for interval, body in visited_blocks
        ]

        return {"args": args, "regions": regions}
示例#3
0
    def visit_Stage(self, node: gt_ir.Stage) -> Dict[str, Any]:
        """Build the template context for a single stage functor.

        Side effect: resets ``self.stage_symbols`` and fills it with one
        entry per accessor of *node* (keyed by ``accessor.symbol``) for use
        when generating references inside this stage.

        Args:
            node: The stage to translate.

        Returns:
            A dict with two keys:
              - ``"args"``: list of argument descriptors (``"name"``,
                ``"access_type"``, ``"extent"``). If the stage contains any
                ``AxisIndex`` node, one ``domain_size_<axis>`` input argument
                per domain axis is appended.
              - ``"regions"``: one dict per apply block with
                ``"interval_start"``, ``"interval_end"`` and ``"body"``.
        """
        # Symbols used when generating references inside this stage.
        self.stage_symbols = {}

        # Fields indexed with an expression (not a constant) along the
        # sequential axis have no statically known vertical extent.
        fields_with_variable_offset = {
            field_ref.name
            for field_ref in gt_ir.filter_nodes_dfs(node, gt_ir.FieldRef)
            if isinstance(
                field_ref.offset.get(self.domain.sequential_axis.name, None),
                gt_ir.Expr)
        }

        args = []
        for accessor in node.accessors:
            self.stage_symbols[accessor.symbol] = accessor
            arg = {
                "name": accessor.symbol,
                "access_type": "in",
                "extent": None
            }
            if isinstance(accessor, gt_ir.FieldAccessor):
                arg["access_type"] = (
                    "in" if accessor.intent == gt_ir.AccessIntent.READ_ONLY
                    else "inout")
                if accessor.symbol in fields_with_variable_offset:
                    # Keep every extent entry but the last (vertical) one and
                    # assert [-1000, 1000] for it; 1000 is a guess meant to
                    # exceed any reasonable number of vertical levels.
                    arg["extent"] = gt_utils.flatten(
                        accessor.extent[:-1]) + [-1000, 1000]
                else:
                    arg["extent"] = gt_utils.flatten(accessor.extent)
            args.append(arg)

        # Only the presence of an AxisIndex matters, so stop at the first one
        # instead of materializing every match into a tuple just to count it.
        if any(True for _ in gt_ir.filter_nodes_dfs(node, gt_ir.AxisIndex)):
            args.extend({
                "name": f"domain_size_{name}",
                "access_type": "in",
                "extent": None
            } for name in self.domain.axes_names)

        # Create regions and computations: visiting an apply block yields its
        # interval definition (start, end) and the generated body sources.
        regions = []
        for apply_block in node.apply_blocks:
            interval_definition, body_sources = self.visit(apply_block)
            regions.append({
                "interval_start": interval_definition[0],
                "interval_end": interval_definition[1],
                "body": body_sources,
            })

        return {"args": args, "regions": regions}