Ejemplo n.º 1
0
 def visit_KCache(self, node: oir.KCache, *, pruneable: Set[str],
                  **kwargs: Any) -> oir.KCache:
     """Switch off flushing for k-caches whose name is listed in ``pruneable``.

     Args:
         node: The k-cache declaration being visited.
         pruneable: Names of the caches whose flush is to be disabled.
         **kwargs: Remaining visitor arguments, forwarded to ``generic_visit``.

     Returns:
         A new ``oir.KCache`` with the same name, fill setting and location
         but ``flush=False`` if ``node.name`` is in ``pruneable``; otherwise
         the result of ``self.generic_visit``.
     """
     should_prune = node.name in pruneable
     if not should_prune:
         # ``pruneable`` is keyword-only and therefore not part of ``kwargs``;
         # it is not forwarded to the generic visit.
         return self.generic_visit(node, **kwargs)

     return oir.KCache(
         name=node.name,
         fill=node.fill,
         flush=False,
         loc=node.loc,
     )
Ejemplo n.º 2
0
    def visit_VerticalLoop(self, node: oir.VerticalLoop,
                           **kwargs: Any) -> oir.VerticalLoop:
        """Add fill-and-flush k-caches for all eligible fields of a vertical loop.

        Loops with ``PARALLEL`` loop order, or with any section that does not
        contain exactly one horizontal execution, are passed to
        ``generic_visit`` and get no new caches.

        Otherwise a field receives a new ``oir.KCache(fill=True, flush=True)``
        when all of the following hold:

        * it has no cache in ``node.caches`` yet,
        * none of its collected offsets has ``None`` as third component
          (treated here as a variable-offset read),
        * its cartesian accesses use more than one distinct offset,
        * all of those offsets have ``(0, 0)`` as their first two components,
        * the absolute value of every third component is at most
          ``self.max_cacheable_offset``.

        Args:
            node: The vertical loop to analyze.
            **kwargs: Visitor arguments, forwarded to the visit of the existing
                caches (or to ``generic_visit`` on the early-exit path).

        Returns:
            A new ``oir.VerticalLoop`` with the original loop order, sections
            and location, and with the visited existing caches followed by the
            newly detected ones (sorted by field name).
        """
        # k-caches are restricted to loops with a single horizontal region as all regions without
        # horizontal offsets should be merged before anyway and this restriction allows for easier
        # conversion of fill and flush caches to local caches later
        if node.loop_order == common.LoopOrder.PARALLEL or any(
                len(section.horizontal_executions) != 1
                for section in node.sections):
            return self.generic_visit(node, **kwargs)

        all_accesses = AccessCollector.apply(node)
        fields_with_variable_reads = {
            field
            for field, offsets in all_accesses.offsets().items()
            if any(off[2] is None for off in offsets)
        }
        # Built once up front; the names of the existing caches do not change
        # while the candidate fields are being filtered.
        cached_fields = {c.name for c in node.caches}

        def accessed_more_than_once(offsets: Set[Any]) -> bool:
            return len(offsets) > 1

        def already_cached(field: str) -> bool:
            return field in cached_fields

        # TODO(fthaler): k-caches with non-zero ij offsets?
        def has_horizontal_offset(offsets: Set[Tuple[int, int, int]]) -> bool:
            return any(offset[:2] != (0, 0) for offset in offsets)

        def offsets_within_limits(offsets: Set[Tuple[int, int, int]]) -> bool:
            return all(
                abs(offset[2]) <= self.max_cacheable_offset
                for offset in offsets)

        def has_variable_offset_reads(field: str) -> bool:
            return field in fields_with_variable_reads

        accesses = all_accesses.cartesian_accesses().offsets()
        # Sorted so that the order of the generated caches does not depend on
        # set iteration order (which varies with string hash randomization).
        # The checks keep their original short-circuit order.
        cacheable = sorted(
            field
            for field, offsets in accesses.items()
            if not already_cached(field)
            and not has_variable_offset_reads(field)
            and accessed_more_than_once(offsets)
            and not has_horizontal_offset(offsets)
            and offsets_within_limits(offsets)
        )
        caches = self.visit(node.caches, **kwargs) + [
            oir.KCache(name=field, fill=True, flush=True)
            for field in cacheable
        ]
        return oir.VerticalLoop(
            loop_order=node.loop_order,
            sections=node.sections,
            caches=caches,
            loc=node.loc,
        )
Ejemplo n.º 3
0
    def visit_VerticalLoop(
        self,
        node: oir.VerticalLoop,
        *,
        new_tmps: List[oir.Temporary],
        symtable: Dict[str, Any],
        new_symbol_name: Callable[[str], str],
        **kwargs: Any,
    ) -> oir.VerticalLoop:
        """Redirect filling/flushing k-caches of a loop to newly created temporaries.

        For every ``oir.KCache`` in ``node.caches`` that has ``fill`` or
        ``flush`` set, a name is obtained from ``new_symbol_name`` and a
        matching ``oir.Temporary`` is appended to ``new_tmps``. Fill and flush
        statements are generated per section and handed to the section visitor,
        and the original cache declarations are replaced by caches on the new
        temporaries with ``fill=False, flush=False``.

        Args:
            node: The vertical loop to transform.
            new_tmps: Output list; the temporaries created here are appended.
            symtable: Maps symbol names to declarations; used to look up the
                ``dtype`` of each cached field.
            new_symbol_name: Callable returning the temporary name to use for
                a given field name.
            **kwargs: Further visitor arguments, forwarded to the section visit.

        Returns:
            ``node`` unchanged if it has no filling or flushing k-cache;
            otherwise a new ``oir.VerticalLoop`` with the same loop order and
            location, the rewritten sections and the replaced caches.
        """
        filling_fields: Dict[str, str] = {
            c.name: new_symbol_name(c.name)
            for c in node.caches
            if isinstance(c, oir.KCache) and c.fill
        }
        # A field that is both filled and flushed must share one temporary, so
        # the name from ``filling_fields`` is reused. The conditional (rather
        # than ``dict.get`` with a default) avoids calling ``new_symbol_name``
        # when no new name is needed.
        flushing_fields: Dict[str, str] = {
            c.name: filling_fields[c.name] if c.name in filling_fields else new_symbol_name(c.name)
            for c in node.caches
            if isinstance(c, oir.KCache) and c.flush
        }

        # Shared keys map to the same value in both dicts, so a plain merge
        # yields the same mapping as a union of the item sets, but in a
        # deterministic (insertion) order.
        filling_or_flushing_fields: Dict[str, str] = {**filling_fields, **flushing_fields}

        if not filling_or_flushing_fields:
            return node

        # new temporaries used for caches, declarations are later added to stencil
        for field_name, tmp_name in filling_or_flushing_fields.items():
            new_tmps.append(
                oir.Temporary(
                    name=tmp_name, dtype=symtable[field_name].dtype, dimensions=(True, True, True)
                )
            )

        if filling_fields:
            # split sections where more than one fill operations are required at the entry level
            first_unfilled: Dict[str, int] = dict()
            split_sections: List[oir.VerticalLoopSection] = []
            for section in node.sections:
                # The second return value is not used here.
                # NOTE(review): ``first_unfilled`` is therefore not updated
                # between iterations, unlike in the fill-statement loop below;
                # confirm against the helper that this is intended.
                split_section, _ = self._split_section_with_multiple_fills(
                    node.loop_order, section, filling_fields, first_unfilled, new_symbol_name
                )
                split_sections += split_section
        else:
            split_sections = node.sections

        # generate cache fill and flush statements
        first_unfilled = dict()
        sections = []
        for section in split_sections:
            fills, first_unfilled = self._fill_stmts(
                node.loop_order, section, filling_fields, first_unfilled, symtable
            )
            flushes = self._flush_stmts(node.loop_order, section, flushing_fields, symtable)
            sections.append(
                self.visit(
                    section,
                    fills=fills,
                    flushes=flushes,
                    name_map=filling_or_flushing_fields,
                    symtable=symtable,
                    **kwargs,
                )
            )

        # replace cache declarations
        caches = [c for c in node.caches if c.name not in filling_or_flushing_fields] + [
            oir.KCache(name=f, fill=False, flush=False) for f in filling_or_flushing_fields.values()
        ]

        # Keep the source location of the original loop on the rewritten node.
        return oir.VerticalLoop(
            loop_order=node.loop_order, sections=sections, caches=caches, loc=node.loc
        )