Example #1
0
    def visit_VerticalLoop(
        self, node: oir.VerticalLoop, *, local_tmps: Set[str], **kwargs: Any
    ) -> oir.VerticalLoop:
        """Add IJ-caches for local temporaries of a parallel vertical loop.

        A field in ``local_tmps`` receives a new ``oir.IJCache`` when it is not
        already cached on ``node`` and none of its cartesian accesses has a
        non-zero k offset. Non-parallel loops, or calls without any local
        temporaries, return ``node`` unchanged. Existing caches are visited
        and kept in front of the new ones.

        New caches are appended in sorted field-name order so the result does
        not depend on set iteration order, which varies between interpreter
        runs because of string hash randomization.
        """
        if node.loop_order != common.LoopOrder.PARALLEL or not local_tmps:
            return node

        # names of fields that already carry a cache; computed once
        cached_names = {c.name for c in node.caches}

        def has_vertical_offset(offsets: Set[Tuple[int, int, int]]) -> bool:
            return any(offset[2] != 0 for offset in offsets)

        accesses = AccessCollector.apply(node).cartesian_accesses().offsets()
        cacheable = {
            field
            for field, offsets in accesses.items()
            if field in local_tmps
            and field not in cached_names
            and not has_vertical_offset(offsets)
        }
        caches = self.visit(node.caches, **kwargs) + [
            oir.IJCache(name=field) for field in sorted(cacheable)
        ]
        return oir.VerticalLoop(
            sections=node.sections,
            loop_order=node.loop_order,
            caches=caches,
            # preserve the source location of the original loop
            loc=node.loc,
        )
Example #2
0
 def visit_VerticalLoop(self, node: oir.VerticalLoop) -> Any:
     """Visit the sections of ``node`` and drop the loop if none survive.

     Returns ``NOTHING`` when visiting leaves no sections; otherwise a new
     loop with the visited sections and the original loop order and caches.
     """
     remaining = self.visit(node.sections)
     if remaining:
         return oir.VerticalLoop(
             loop_order=node.loop_order,
             sections=remaining,
             caches=node.caches,
         )
     return NOTHING
Example #3
0
    def visit_VerticalLoop(self, node: gtir.VerticalLoop, *,
                           ctx: Context) -> oir.VerticalLoop:
        """Lower a gtir vertical loop into an oir loop with a single section.

        Each statement of ``node.body`` becomes its own horizontal execution,
        declaring the local scalars gathered in ``ctx`` while that statement
        was visited. The loop's temporaries are appended to ``ctx.temp_fields``.
        """
        executions: List[oir.HorizontalExecution] = []
        for body_stmt in node.body:
            # every statement starts from a fresh set of local scalars
            ctx.reset_local_scalars()
            visited = self.visit(body_stmt, ctx=ctx)
            if isinstance(visited, oir.Stmt):
                visited = [visited]
            executions.append(
                oir.HorizontalExecution(
                    body=utils.flatten_list(visited),
                    declarations=ctx.local_scalars,
                ))

        declared_temporaries = [
            oir.Temporary(name=tmp.name, dtype=tmp.dtype, dimensions=tmp.dimensions)
            for tmp in node.temporaries
        ]
        ctx.temp_fields += declared_temporaries

        only_section = oir.VerticalLoopSection(
            interval=self.visit(node.interval),
            horizontal_executions=executions,
            loc=node.loc,
        )
        return oir.VerticalLoop(loop_order=node.loop_order, sections=[only_section])
Example #4
0
 def _merge(a: oir.VerticalLoop, b: oir.VerticalLoop) -> oir.VerticalLoop:
     """Join two loops into one: ``a``'s sections followed by ``b``'s.

     The merged loop uses ``a``'s loop order and carries no caches; a warning
     is emitted when either input had caches that are thereby discarded.
     """
     merged_sections = a.sections + b.sections
     caches_dropped = bool(a.caches or b.caches)
     if caches_dropped:
         warnings.warn("AdjacentLoopMerging pass removed previously declared caches")
     return oir.VerticalLoop(loop_order=a.loop_order, sections=merged_sections, caches=[])
Example #5
0
 def build(self):
     """Assemble an ``oir.VerticalLoop`` from the values stored on this builder."""
     bounds = oir.Interval(start=self._start, end=self._end)
     return oir.VerticalLoop(
         interval=bounds,
         horizontal_executions=self._horizontal_executions,
         loop_order=self._loop_order,
         declarations=self._declarations,
     )
 def visit_VerticalLoop(self, node: oir.VerticalLoop,
                        **kwargs: Any) -> oir.VerticalLoop:
     """Drop caches on fields no longer accessed by a parallel loop.

     Sections are visited first, and only caches whose field is still accessed
     in the visited sections are kept. Non-parallel loops are returned as-is.
     """
     if node.loop_order == common.LoopOrder.PARALLEL:
         new_sections = self.visit(node.sections, **kwargs)
         used_fields = AccessCollector.apply(new_sections).fields()
         kept_caches = [cache for cache in node.caches if cache.name in used_fields]
         return oir.VerticalLoop(
             loop_order=node.loop_order,
             sections=new_sections,
             caches=kept_caches,
         )
     return node
Example #7
0
    def visit_VerticalLoop(self, node: oir.VerticalLoop,
                           **kwargs: Any) -> oir.VerticalLoop:
        """Add fill-and-flush k-caches to a sequential vertical loop.

        A field receives a new ``oir.KCache(fill=True, flush=True)`` when it:
        is not already cached; has no access with a variable (``None``) k
        offset; is accessed with more than one distinct offset; has no i/j
        offset; and all of its k offsets are within
        ``self.max_cacheable_offset``. Loops that do not qualify are only
        visited generically.

        New caches are appended in sorted field-name order so the output does
        not depend on set iteration order, which varies between interpreter
        runs because of string hash randomization.
        """
        # k-caches are restricted to loops with a single horizontal region as all regions without
        # horizontal offsets should be merged before anyway and this restriction allows for easier
        # conversion of fill and flush caches to local caches later
        if node.loop_order == common.LoopOrder.PARALLEL or any(
                len(section.horizontal_executions) != 1
                for section in node.sections):
            return self.generic_visit(node, **kwargs)

        all_accesses = AccessCollector.apply(node)
        fields_with_variable_reads = {
            field
            for field, offsets in all_accesses.offsets().items()
            if any(off[2] is None for off in offsets)
        }
        # names of fields that already carry a cache; computed once instead of per field
        cached_names = {c.name for c in node.caches}

        def accessed_more_than_once(offsets: Set[Any]) -> bool:
            return len(offsets) > 1

        def already_cached(field: str) -> bool:
            return field in cached_names

        # TODO(fthaler): k-caches with non-zero ij offsets?
        def has_horizontal_offset(offsets: Set[Tuple[int, int, int]]) -> bool:
            return any(offset[:2] != (0, 0) for offset in offsets)

        def offsets_within_limits(offsets: Set[Tuple[int, int, int]]) -> bool:
            return all(
                abs(offset[2]) <= self.max_cacheable_offset
                for offset in offsets)

        def has_variable_offset_reads(field: str) -> bool:
            return field in fields_with_variable_reads

        accesses = all_accesses.cartesian_accesses().offsets()
        cacheable = {
            field
            for field, offsets in accesses.items()
            if not already_cached(field)
            and not has_variable_offset_reads(field)
            and accessed_more_than_once(offsets)
            and not has_horizontal_offset(offsets)
            and offsets_within_limits(offsets)
        }
        caches = self.visit(node.caches, **kwargs) + [
            oir.KCache(name=field, fill=True, flush=True)
            for field in sorted(cacheable)
        ]
        return oir.VerticalLoop(
            loop_order=node.loop_order,
            sections=node.sections,
            caches=caches,
            loc=node.loc,
        )
Example #8
0
 def visit_VerticalLoop(
     self,
     node: oir.VerticalLoop,
     tmps_to_replace: Set[str],
     **kwargs: Any,
 ) -> oir.VerticalLoop:
     """Visit the sections and drop caches of temporaries being replaced.

     ``tmps_to_replace`` is forwarded to the section visit; any cache whose
     name is in it is removed from the resulting loop.
     """
     visited_sections = self.visit(
         node.sections, tmps_to_replace=tmps_to_replace, **kwargs
     )
     remaining_caches = [
         cache for cache in node.caches if cache.name not in tmps_to_replace
     ]
     return oir.VerticalLoop(
         loop_order=node.loop_order,
         sections=visited_sections,
         caches=remaining_caches,
     )
Example #9
0
    def visit_VerticalLoop(self, node: gtir.VerticalLoop, *, ctx: Context,
                           **kwargs: Any) -> oir.VerticalLoop:
        """Lower a gtir vertical loop into an oir loop with a single section.

        ``ctx.horizontal_executions`` is emptied and the body is visited with
        ``ctx``; the horizontal executions collected there (presumably appended
        by the statement visitors) form the section. Each loop temporary is
        registered through ``ctx.add_decl``. The returned loop has no caches.
        """
        ctx.horizontal_executions.clear()
        self.visit(node.body, ctx=ctx)

        for temp in node.temporaries:
            ctx.add_decl(oir.Temporary(name=temp.name, dtype=temp.dtype))

        return oir.VerticalLoop(
            loop_order=node.loop_order,
            sections=[
                oir.VerticalLoopSection(
                    interval=self.visit(node.interval, **kwargs),
                    # copy: ctx.horizontal_executions is cleared in place on the
                    # next call, so the section must not share that list object
                    horizontal_executions=list(ctx.horizontal_executions),
                )
            ],
            caches=[],
        )
Example #10
0
    def visit_VerticalLoop(
        self,
        node: oir.VerticalLoop,
        *,
        new_tmps: List[oir.Temporary],
        symtable: Dict[str, Any],
        new_symbol_name: Callable[[str], str],
        **kwargs: Any,
    ) -> oir.VerticalLoop:
        """Replace fill/flush k-caches by local k-caches on new temporaries.

        Every field with a filling or flushing ``oir.KCache`` is mapped to a new
        temporary name (from ``new_symbol_name``). The temporary is appended to
        ``new_tmps`` with the field's dtype taken from ``symtable``. Sections
        are split where more than one fill is required at entry. Each section
        is then visited with generated fill and flush statements and the
        field-to-temporary name map. The fill/flush caches are replaced by
        non-filling, non-flushing k-caches on the temporaries; all other caches
        are kept.

        Returns ``node`` unchanged when it has no filling or flushing caches.
        """
        filling_fields: Dict[str, str] = {
            c.name: new_symbol_name(c.name)
            for c in node.caches
            if isinstance(c, oir.KCache) and c.fill
        }
        # a field that is both filled and flushed reuses its filling temporary
        flushing_fields: Dict[str, str] = {
            c.name: filling_fields[c.name] if c.name in filling_fields else new_symbol_name(c.name)
            for c in node.caches
            if isinstance(c, oir.KCache) and c.flush
        }

        # Shared keys map to the same temporary in both dicts, so a plain merge
        # is conflict-free. Unlike a set union of items, it has a deterministic
        # order, which keeps the order of new_tmps and caches reproducible.
        filling_or_flushing_fields: Dict[str, str] = {**filling_fields, **flushing_fields}

        if not filling_or_flushing_fields:
            return node

        # new temporaries used for caches, declarations are later added to stencil
        for field_name, tmp_name in filling_or_flushing_fields.items():
            new_tmps.append(
                oir.Temporary(
                    name=tmp_name, dtype=symtable[field_name].dtype, dimensions=(True, True, True)
                )
            )

        if filling_fields:
            # split sections where more than one fill operations are required at the entry level
            first_unfilled: Dict[str, int] = dict()
            split_sections: List[oir.VerticalLoopSection] = []
            for section in node.sections:
                # NOTE(review): the returned `previous_fills` is unused and
                # `first_unfilled` is never reassigned in this loop, so every
                # section is split against the same mapping unless the helper
                # mutates it in place; confirm this is intended.
                split_section, previous_fills = self._split_section_with_multiple_fills(
                    node.loop_order, section, filling_fields, first_unfilled, new_symbol_name
                )
                split_sections += split_section
        else:
            split_sections = node.sections

        # generate cache fill and flush statements
        first_unfilled = dict()
        sections = []
        for section in split_sections:
            fills, first_unfilled = self._fill_stmts(
                node.loop_order, section, filling_fields, first_unfilled, symtable
            )
            flushes = self._flush_stmts(node.loop_order, section, flushing_fields, symtable)
            sections.append(
                self.visit(
                    section,
                    fills=fills,
                    flushes=flushes,
                    name_map=filling_or_flushing_fields,
                    symtable=symtable,
                    **kwargs,
                )
            )

        # replace cache declarations
        caches = [c for c in node.caches if c.name not in filling_or_flushing_fields] + [
            oir.KCache(name=f, fill=False, flush=False) for f in filling_or_flushing_fields.values()
        ]

        return oir.VerticalLoop(loop_order=node.loop_order, sections=sections, caches=caches)