Пример #1
0
    def visit_HorizontalExecution(self, node: oir.HorizontalExecution, *,
                                  ctx: Context) -> None:
        """Record the extent of a horizontal execution and of the fields it reads.

        The block extent is the ``|``-combination of ``self.zero_extent`` with
        the extent already recorded in ``ctx.fields`` for every field the block
        writes. It is stored in ``ctx.blocks`` under ``id(node)``.

        Every field the block reads is then widened by the block extent plus the
        ``|``-combination of the horizontal (first two) components of its read
        offsets. Finally, each written field that has no recorded extent yet is
        registered with the block extent.
        """
        results = AccessCollector.apply(node)
        horizontal_extent = functools.reduce(
            lambda ext, name: ext | ctx.fields.get(name, self.zero_extent),
            results.write_fields(),
            self.zero_extent,
        )
        ctx.blocks[id(node)] = horizontal_extent

        for name, accesses in results.read_offsets().items():
            # Seed the fold with the first offset extent from a generator rather
            # than ``accesses.pop()``, so the collector's offsets are not mutated
            # as a side effect of reading them.
            extent = functools.reduce(
                lambda ext, other: ext | other,
                (Extent.from_offset(off[:2]) for off in accesses),
            )
            total_extent = horizontal_extent + extent
            if name in ctx.fields:
                ctx.fields[name] |= total_extent
            else:
                ctx.fields[name] = total_extent

        # Runs after the read loop, so a field that is both read and written
        # keeps the (wider) extent computed from its reads.
        for name in results.write_fields():
            ctx.fields.setdefault(name, horizontal_extent)
Пример #2
0
    def to_extent(self, horizontal_extent: Extent) -> Optional[Extent]:
        """
        Convert the access to an extent provided a horizontal extent for the access.

        Without a horizontal mask, the result is ``horizontal_extent`` plus the
        extent of the horizontal (first two) components of ``self.offset``.
        With a mask, ``mask_overlap_with_extent`` is subtracted from
        ``horizontal_extent`` first, the offset extent is added, and the result
        is ``|``-combined with the 2-D zero extent.

        This returns None if no overlap exists between the horizontal mask and interval.
        """
        offset_as_extent = Extent.from_offset(
            cast(Tuple[int, int, int], self.offset)[:2])
        if not self.horizontal_mask:
            # Previously this case fell off the end and implicitly returned None,
            # which callers would misread as "no overlap".
            return horizontal_extent + offset_as_extent

        if dist_from_edge := mask_overlap_with_extent(
                self.horizontal_mask, horizontal_extent):
            zeros = Extent.zeros(ndims=2)
            return ((horizontal_extent - dist_from_edge) +
                    offset_as_extent) | zeros
        # A falsy result from mask_overlap_with_extent is treated as "no overlap".
        return None
Пример #3
0
    def visit_HorizontalExecution(self, node: oir.HorizontalExecution, *,
                                  ctx: Context) -> None:
        """Record block and field extents for one horizontal execution.

        The block extent starts at the 2-D zero extent and is ``|``-combined
        with the extent recorded in ``ctx.field_extents`` for every field the
        block writes. It is stored in ``ctx.block_extents`` under ``id(node)``.

        For every field the block reads, the read offsets are folded with ``|``
        (starting from the zero extent, first two components only). The result
        is added to the block extent and merged via ``union`` into that field's
        entry in ``ctx.field_extents``.
        """
        collected = AccessCollector.apply(node).cartesian_accesses()

        block_extent = Extent.zeros(ndims=2)
        for written in collected.write_fields():
            block_extent = block_extent | ctx.field_extents.get(
                written, Extent.zeros(ndims=2))
        ctx.block_extents[id(node)] = block_extent

        for read_name, offsets in collected.read_offsets().items():
            read_extent = Extent.zeros(ndims=2)
            for offset in offsets:
                read_extent = read_extent | Extent.from_offset(offset[:2])
            known = ctx.field_extents.get(read_name, Extent.zeros(ndims=2))
            ctx.field_extents[read_name] = known.union(block_extent + read_extent)
Пример #4
0
 def visit_FieldRef(self, node: gt_ir.FieldRef):
     """Return a one-element list holding ``(node.name, extent)``.

     The extent is built by ``Extent.from_offset`` from the reference's offset
     along each axis in ``self.data.axes_names``. Axes absent from
     ``node.offset`` contribute 0.
     """
     axis_offsets = [node.offset.get(axis, 0) for axis in self.data.axes_names]
     return [(node.name, Extent.from_offset(axis_offsets))]