Ejemplo n.º 1
0
def test_compute_relative_mask():
    """``compute_relative_mask`` re-expresses a mask relative to a zero 2-D extent."""

    def relative_to_zero_extent(i_interval):
        # A fresh zero extent and full J interval per call; only the I interval varies.
        return compute_relative_mask(
            Extent.zeros(ndims=2),
            common.HorizontalMask(i=i_interval, j=common.HorizontalInterval.full()),
        )

    # Widening the compute domain by one point on each side of I yields the plain
    # compute domain; a full J interval also maps to the compute domain.
    widened = relative_to_zero_extent(
        common.HorizontalInterval.compute_domain(start_offset=-1, end_offset=1))
    assert widened.i == common.HorizontalInterval.compute_domain()
    assert widened.j == common.HorizontalInterval.compute_domain()

    # At the START endpoint, the negative start offset becomes 0 and the end
    # offset is kept.
    at_start = relative_to_zero_extent(
        common.HorizontalInterval.at_endpt(
            level=common.LevelMarker.START, start_offset=-2, end_offset=3))
    assert at_start.i == common.HorizontalInterval.at_endpt(
        level=common.LevelMarker.START, start_offset=0, end_offset=3)
    assert at_start.j == common.HorizontalInterval.compute_domain()
Ejemplo n.º 2
0
    def apply(self, transform_data: TransformData):
        """Propagate field access extents backwards through all blocks.

        Domain blocks and their IJ blocks are walked in reverse order. Each IJ
        block's compute extent becomes the union of the extents accumulated so
        far for its outputs. Every field read by its interval blocks is then
        widened by that compute extent plus the read's own extent. Finally the
        sequential axis is reset to ``(0, 0)`` and the result is stored on
        ``transform_data.implementation_ir.fields_extents``.

        Returns:
            ``transform_data`` (modified in place).
        """
        domain = transform_data.definition_ir.domain
        seq_axis = domain.index(domain.sequential_axis)
        access_extents = {name: Extent.zeros() for name in transform_data.symbols}

        # Reverse order: extents required by later blocks are already known when
        # earlier blocks are processed.
        for block in reversed(transform_data.blocks):
            for ij_block in reversed(block.ij_blocks):
                compute_extent = Extent.zeros()
                for name in ij_block.outputs:
                    compute_extent = compute_extent | access_extents[name]
                ij_block.compute_extent = compute_extent
                for int_block in ij_block.interval_blocks:
                    for name, read_extent in int_block.inputs.items():
                        access_extents[name] = access_extents[name] | (
                            compute_extent + read_extent)

        # The sequential axis is never extended.
        fields_extents = {}
        for name, extent in access_extents.items():
            bounds = list(extent)
            bounds[seq_axis] = (0, 0)
            without_seq = Extent(bounds)
            fields_extents[name] = Extent(without_seq)
        transform_data.implementation_ir.fields_extents = fields_extents

        return transform_data
Ejemplo n.º 3
0
    def _make_stage(self, ij_block):
        """Build a ``gt_ir.Stage`` from an IJ block.

        One ``gt_ir.ApplyBlock`` is created per interval block. Declarations of
        names that are not in ``self.data.symbols`` become the apply block's
        local symbols (they must be ``gt_ir.VarDecl``). All other non-declaration
        statements form its body.

        Accessors are created for every input. An input is read/write if the
        field is also an output, and read-only otherwise. Outputs that are never
        read get a write-only accessor with a zero extent.

        Args:
            ij_block: The IJ block to convert; its ``id`` names the stage.

        Returns:
            The new ``gt_ir.Stage``.
        """
        # Zero extent with the stencil's dimensionality. It is used for both
        # read/write and write-only accessors, so unions always combine extents
        # of matching dimensionality. ``Extent.zeros()`` with no argument uses the
        # default dimensionality, which may differ from ``self.data.ndims``.
        zero_extent = Extent.zeros(self.data.ndims)

        # Apply blocks and decls
        apply_blocks = []
        # NOTE(review): ``decls`` is collected but not used in the returned stage.
        decls = []
        for int_block in ij_block.interval_blocks:
            # Make apply block
            stmts = []
            local_symbols = {}
            for stmt_info in int_block.stmts:
                if isinstance(stmt_info.stmt, gt_ir.Decl):
                    decl = stmt_info.stmt
                    if decl.name in self.data.symbols:
                        decls.append(stmt_info.stmt)
                    else:
                        assert isinstance(decl, gt_ir.VarDecl)
                        local_symbols[decl.name] = decl
                else:
                    stmts.append(stmt_info.stmt)

            apply_block = gt_ir.ApplyBlock(
                interval=self._make_axis_interval(int_block.interval),
                local_symbols=local_symbols,
                body=gt_ir.BlockStmt(stmts=stmts),
            )
            apply_blocks.append(apply_block)

        # Accessors
        accessors = []
        remaining_outputs = set(ij_block.outputs)
        for name, extent in ij_block.inputs.items():
            if name in remaining_outputs:
                read_write = True
                remaining_outputs.remove(name)
                # A field that is also written must cover at least the zero extent.
                extent |= zero_extent
            else:
                read_write = False
            accessors.append(self._make_accessor(name, extent, read_write))
        for name in remaining_outputs:
            accessors.append(self._make_accessor(name, zero_extent, True))

        stage = gt_ir.Stage(
            name="stage__{}".format(ij_block.id),
            accessors=accessors,
            apply_blocks=apply_blocks,
            compute_extent=ij_block.compute_extent,
        )

        return stage
Ejemplo n.º 4
0
def nodes_extent_calculation(
    nodes: Collection[Union["VerticalLoopLibraryNode",
                            "HorizontalExecutionLibraryNode"]]
) -> Dict[str, Extent]:
    """Compute the horizontal extent with which each field is accessed.

    Args:
        nodes: Vertical-loop and/or horizontal-execution library nodes. For a
            vertical loop, every ``dace.nodes.LibraryNode`` found recursively in
            its section SDFGs is inspected instead of the loop itself.

    Returns:
        Mapping from field name to the union of the extents of all accesses to
        it, taken from nodes whose ``extent`` is not None. Each access extent is
        widened to include the zero 2-D extent.
    """
    # Local import, presumably to avoid an import cycle (TODO confirm).
    from gtc.dace.nodes import HorizontalExecutionLibraryNode, VerticalLoopLibraryNode

    inner_nodes = []
    for node in nodes:
        if isinstance(node, VerticalLoopLibraryNode):
            for _, section_sdfg in node.sections:
                inner_nodes.extend(
                    ln for ln, _ in section_sdfg.all_nodes_recursive()
                    if isinstance(ln, dace.nodes.LibraryNode))
        else:
            assert isinstance(node, HorizontalExecutionLibraryNode)
            inner_nodes.append(node)

    zero_extent = Extent.zeros(2)
    field_extents: Dict[str, Extent] = dict()
    for node in inner_nodes:
        access_collection = AccessCollector.apply(node.oir_node)
        block_extent = node.extent
        if block_extent is None:
            continue
        for acc in access_collection.ordered_accesses():
            acc_extent = acc.to_extent(block_extent)
            # ``to_extent`` may return None, e.g. when the access's horizontal
            # mask does not overlap the block extent. None cannot be unioned with
            # an Extent, so the field is still recorded, with the zero extent.
            if acc_extent is None:
                offset_extent = zero_extent
            else:
                offset_extent = acc_extent | zero_extent
            field_extents.setdefault(acc.field, offset_extent)
            field_extents[acc.field] |= offset_extent

    return field_extents
Ejemplo n.º 5
0
    def visit_HorizontalBlock(self, node: npir.HorizontalBlock, *,
                              ctx: Context):
        """Record the block's extent and widen the extents of its sliced fields.

        The block extent is the union of the already-known extents of every field
        that appears as the ``FieldSlice`` target of a ``VectorAssign``. Unknown
        fields count as ``Extent.zeros()``. The result is stored in
        ``ctx.block_extents[id(node)]``. Every ``FieldSlice`` in the block then
        extends its field's extent by the block extent plus the slice's extent.
        """
        written_names = (
            node.iter_tree()
            .if_isinstance(npir.VectorAssign)
            .getattr("left")
            .if_isinstance(npir.FieldSlice)
            .getattr("name")
            .to_set()
        )
        block_extent = Extent.zeros()
        for name in written_names:
            block_extent = block_extent | ctx.field_extents.get(name, Extent.zeros())
        ctx.block_extents[id(node)] = block_extent

        for field_slice in node.iter_tree().if_isinstance(npir.FieldSlice).to_list():
            known = ctx.field_extents.get(field_slice.name, Extent.zeros())
            ctx.field_extents[field_slice.name] = known.union(
                block_extent + slice_to_extent(field_slice))
Ejemplo n.º 6
0
class HorizontalBlockFactory(factory.Factory):
    """factory_boy factory that builds ``npir.HorizontalBlock`` instances for tests."""

    class Meta:
        model = npir.HorizontalBlock

    # By default the body holds a single statement built by VectorAssignFactory.
    body = factory.List([factory.SubFactory(VectorAssignFactory)])
    # Zero 2-D extent by default.
    extent: Extent = Extent.zeros(ndims=2)
    # No local declarations by default.
    # NOTE(review): a plain list literal is the default value; confirm factory_boy
    # does not hand this same list object to every generated instance.
    declarations: List[npir.LocalScalarDecl] = []
Ejemplo n.º 7
0
 def visit_ParAssignStmt(
     self,
     node: gtir.ParAssignStmt,
     *,
     ctx: StencilContext,
     field_extents: FIELD_EXT_T,
     **kwargs: Any,
 ) -> None:
     """Propagate an assignment target's extent to everything the assignment reads.

     The target field's current extent (registered as ``Extent.zeros()`` if it
     is unseen) seeds a fresh ``AssignContext``. The assignment's conditions
     (from ``ctx.assign_conditions``, keyed by ``id(node)``) and its right-hand
     side are visited with that context. Every extent collected in the context
     is then merged into ``field_extents``.
     """
     target_extent = field_extents.setdefault(node.left.name, Extent.zeros())
     assign_ctx = self.AssignContext(left_extent=target_extent)
     conditions = ctx.assign_conditions.get(id(node), [])
     for subtree in (conditions, node.right):
         self.visit(subtree, field_extents=field_extents, pa_ctx=assign_ctx, **kwargs)
     for name, collected in assign_ctx.assign_extents.items():
         field_extents[name] = (field_extents[name] | collected
                                if name in field_extents else collected)
Ejemplo n.º 8
0
Archivo: utils.py Proyecto: DropD/gt4py
    def visit_HorizontalExecution(self, node: oir.HorizontalExecution, *,
                                  ctx: Context) -> None:
        """Record the execution's horizontal extent and widen the extents of its reads.

        The horizontal extent is the union of the known extents of all fields
        written by ``node``. Unknown fields count as the zero 2-D extent. The
        result is stored in ``ctx.block_extents[id(node)]``. Each read field's
        extent is then unioned with that extent plus the union of its read
        offsets (first two offset components only).
        """
        accesses = AccessCollector.apply(node).cartesian_accesses()

        block_extent = Extent.zeros(ndims=2)
        for written in accesses.write_fields():
            block_extent = block_extent | ctx.field_extents.get(
                written, Extent.zeros(ndims=2))
        ctx.block_extents[id(node)] = block_extent

        for name, offsets in accesses.read_offsets().items():
            read_extent = Extent.zeros(ndims=2)
            for offset in offsets:
                read_extent = read_extent | Extent.from_offset(offset[:2])
            known = ctx.field_extents.get(name, Extent.zeros(ndims=2))
            ctx.field_extents[name] = known.union(block_extent + read_extent)
Ejemplo n.º 9
0
        def visit(self, transform_data: TransformData) -> List[DomainBlockInfo]:
            """Visit every domain block and return ``self._split_blocks``.

            Each block gets its own fresh context dict. The dict holds a zero
            extent of ``transform_data.ndims`` dimensions and the shared id
            generator.
            """
            for domain_block in transform_data.blocks:
                block_context = dict(
                    zero_extent=Extent.zeros(transform_data.ndims),
                    id_generator=transform_data.id_generator,
                )
                self.visit_DomainBlockInfo(domain_block, block_context)
            return self._split_blocks
Ejemplo n.º 10
0
class FieldDeclFactory(factory.Factory):
    """factory_boy factory that builds ``npir.FieldDecl`` instances for tests."""

    class Meta:
        model = npir.FieldDecl

    # Name produced by the ``identifier`` helper (defined elsewhere).
    name = identifier(npir.FieldDecl)
    # Presumably flags for the I, J and K dimensions, all present by default (TODO confirm).
    dimensions = (True, True, True)
    # No extra data dimensions by default.
    data_dims: Tuple[int] = cast(Tuple[int], tuple())
    # Zero 2-D extent by default.
    extent: Extent = Extent.zeros(ndims=2)
    dtype = common.DataType.FLOAT32
Ejemplo n.º 11
0
 def has_reads_with_offset(self, *,
                           restrict_to: Optional[Set[str]] = None) -> bool:
     """Return True if any checked input field is read with a non-zero extent.

     Args:
         restrict_to: Names of the input fields to check. ``None`` (the default)
             checks every input. Names not in ``self.inputs`` are ignored, so an
             empty set checks nothing.

     Returns:
         Whether any checked field's entry in ``self.inputs`` differs from
         ``Extent.zeros()`` on the checked axes. The last axis is included only
         when ``self.k_offset_extends_domain`` is true.
     """
     checked_axes = (slice(None) if self.k_offset_extends_domain
                     else slice(None, -1))
     # Compare with None explicitly. With a truthiness test, an empty
     # ``restrict_to`` would silently mean "check every input".
     if restrict_to is None:
         fields = set(self.inputs)
     else:
         fields = restrict_to.intersection(self.inputs)
     zero = Extent.zeros()[checked_axes]
     return any(self.inputs[name][checked_axes] != zero for name in fields)
Ejemplo n.º 12
0
    def get_field_extents(self, node):
        """Return the per-field read and write extents of a horizontal execution.

        Args:
            node: A ``HorizontalExecutionLibraryNode`` whose ``extent`` is set.

        Returns:
            ``(input_extents, output_extents)``. Each is a dict mapping a field
            name to the union of the extents of its read (resp. write) accesses,
            with each access extent widened to include the zero 2-D extent.
        """
        assert isinstance(node, HorizontalExecutionLibraryNode)

        block_extent: Extent = node.extent
        assert block_extent is not None
        collection = self._get_access_collection(node)
        zero_extent = Extent.zeros(2)

        def access_extent(acc):
            # ``to_extent`` may return None, e.g. when the access's horizontal mask
            # does not overlap the block extent. None cannot be unioned with an
            # Extent, so the field is still recorded, with the zero extent.
            extent = acc.to_extent(block_extent)
            return zero_extent if extent is None else extent | zero_extent

        input_extents = dict()
        for acc in collection.read_accesses():
            extent = access_extent(acc)
            input_extents.setdefault(acc.field, extent)
            input_extents[acc.field] |= extent

        output_extents = dict()
        for acc in collection.write_accesses():
            extent = access_extent(acc)
            output_extents.setdefault(acc.field, extent)
            output_extents[acc.field] |= extent

        return input_extents, output_extents
Ejemplo n.º 13
0
 def visit_FieldDecl(self, node: oir.FieldDecl, *,
                     field_extents: Dict[str, Extent],
                     **kwargs: Any) -> npir.FieldDecl:
     """Lower an OIR field declaration to an NPIR one, attaching its extent.

     Fields with no entry in ``field_extents`` get the zero 2-D extent.
     """
     return npir.FieldDecl(
         name=node.name,
         dtype=node.dtype,
         dimensions=node.dimensions,
         data_dims=node.data_dims,
         extent=field_extents.get(node.name, Extent.zeros(ndims=2)),
     )
Ejemplo n.º 14
0
 def visit_Stencil(self, node: gtir.Stencil, **kwargs: Any) -> FIELD_EXT_T:
     """Compute the extent of every field in the stencil.

     Every field starts at ``Extent.zeros()``. Field-if statements are visited
     first to populate a fresh stencil context. Assignments are then visited in
     reverse program order, updating ``field_extents`` in place.
     """
     field_extents = {}
     for name in _iter_field_names(node):
         field_extents[name] = Extent.zeros()
     stencil_ctx = self.StencilContext()
     for field_if in node.iter_tree().if_isinstance(gtir.FieldIfStmt):
         self.visit(field_if, ctx=stencil_ctx)
     assigns = _iter_assigns(node).to_list()
     for assign in assigns[::-1]:
         self.visit(assign, ctx=stencil_ctx, field_extents=field_extents)
     return field_extents
Ejemplo n.º 15
0
 def visit_HorizontalBlock(
     self,
     node: npir.HorizontalBlock,
     *,
     block_extents: Dict[int, HorizontalExtent] = None,
     **kwargs: Any,
 ) -> Union[str, Collection[str]]:
     """Visit the block's children with ``lower``/``upper`` bounds from its I/J extent.

     The block extent is looked up by ``id(node)`` in ``block_extents``. It
     falls back to ``Extent.zeros()`` if missing or if ``block_extents`` is
     falsy. From ``to_boundary()``: ``lower`` is (I lower, J lower) and
     ``upper`` is (I upper, J upper).
     """
     extents_by_block = block_extents or {}
     ij_extent: Extent = extents_by_block.get(id(node), Extent.zeros())
     boundary = ij_extent.to_boundary()
     i_bounds, j_bounds = boundary[0], boundary[1]
     return self.generic_visit(
         node,
         lower=(i_bounds[0], j_bounds[0]),
         upper=(i_bounds[1], j_bounds[1]),
         **kwargs,
     )
Ejemplo n.º 16
0
 def visit_FieldAccess(
     self,
     node: gtir.FieldAccess,
     *,
     field_extents: FIELD_EXT_T,
     pa_ctx: AssignContext,
     **kwargs: Any,
 ) -> None:
     """Widen the extent of the accessed field within the current assignment.

     The field's entry in ``pa_ctx.assign_extents`` is seeded from its current
     value in ``field_extents``, which is registered as ``Extent.zeros()`` if
     unseen. The entry is then unioned with the assignment target's extent
     shifted by this access's offset.
     """
     name = node.name
     seed = field_extents.setdefault(name, Extent.zeros())
     pa_ctx.assign_extents.setdefault(name, seed)
     shifted = pa_ctx.left_extent + _ext_from_off(node.offset)
     pa_ctx.assign_extents[name] = pa_ctx.assign_extents[name] | shifted
Ejemplo n.º 17
0
    def get_field_extents(self, node):
        """Return the per-field read and write extents of a vertical loop.

        Only ``HorizontalExecutionLibraryNode`` instances in the first state of
        each section SDFG are inspected, and each must have its ``extent`` set.

        Args:
            node: A ``VerticalLoopLibraryNode``.

        Returns:
            ``(input_extents, output_extents)``. Each is a dict mapping a field
            name to the union of the extents of its read (resp. write) accesses,
            with each access extent widened to include the zero 2-D extent.
        """
        assert isinstance(node, VerticalLoopLibraryNode)
        zero_extent = Extent.zeros(2)

        def access_extent(acc, block_extent):
            # ``to_extent`` may return None, e.g. when the access's horizontal mask
            # does not overlap the block extent. None cannot be unioned with an
            # Extent, so the field is still recorded, with the zero extent.
            extent = acc.to_extent(block_extent)
            return zero_extent if extent is None else extent | zero_extent

        input_extents = dict()
        output_extents = dict()
        for _, sdfg in node.sections:
            for n in sdfg.states()[0].nodes():
                if not isinstance(n, HorizontalExecutionLibraryNode):
                    continue
                block_extent = n.extent
                assert block_extent is not None
                collection = self._get_access_collection(n)
                for acc in collection.read_accesses():
                    extent = access_extent(acc, block_extent)
                    input_extents.setdefault(acc.field, extent)
                    input_extents[acc.field] |= extent
                for acc in collection.write_accesses():
                    extent = access_extent(acc, block_extent)
                    output_extents.setdefault(acc.field, extent)
                    output_extents[acc.field] |= extent
        return input_extents, output_extents
Ejemplo n.º 18
0
 def visit_Temporary(self, node: oir.Temporary, *,
                     field_extents: Dict[str, Extent],
                     **kwargs: Any) -> npir.TemporaryDecl:
     """Lower an OIR temporary to an NPIR temporary declaration.

     The temporary's extent is widened to include the zero 2-D extent. For each
     axis, ``offset`` is the negated lower bound (asserted non-negative) and
     ``padding`` is upper minus lower bound.
     """
     temp_extent = field_extents[node.name] | Extent.zeros(ndims=2)
     offset = []
     padding = []
     for axis_bounds in temp_extent:
         offset.append(-axis_bounds[0])
         padding.append(axis_bounds[1] - axis_bounds[0])
     assert all(value >= 0 for value in offset)
     return npir.TemporaryDecl(
         name=node.name,
         dtype=node.dtype,
         data_dims=node.data_dims,
         offset=offset,
         padding=padding,
     )
Ejemplo n.º 19
0
    def to_extent(self, horizontal_extent: Extent) -> Optional[Extent]:
        """
        Convert the access to an extent provided a horizontal extent for the access.

        This returns None if no overlap exists between the horizontal mask and interval.
        """
        # Only the first two offset components contribute to the extent.
        offset_as_extent = Extent.from_offset(
            cast(Tuple[int, int, int], self.offset)[:2])
        zeros = Extent.zeros(ndims=2)
        if self.horizontal_mask:
            # A falsy overlap result means the mask does not overlap the extent.
            if dist_from_edge := mask_overlap_with_extent(
                    self.horizontal_mask, horizontal_extent):
                # Subtract the mask's distance from the edge, add the access
                # offset, and widen to include the zero extent.
                return ((horizontal_extent - dist_from_edge) +
                        offset_as_extent) | zeros
            else:
                return None
        # NOTE(review): nothing follows the mask branch in this view, so an access
        # without a horizontal mask falls through and returns None. Callers
        # doing ``to_extent(...) | Extent.zeros(2)`` would then receive None.
        # Confirm whether an unmasked return (e.g. ``horizontal_extent +
        # offset_as_extent``) was cut from this excerpt.
Ejemplo n.º 20
0
 def visit_Stencil(self, node: gtir.Stencil, *, mask_inwards: bool,
                   **kwargs: Any) -> FIELD_EXT_T:
     """Compute field extents, optionally clamping inward-pointing extents.

     Field-if statements populate a fresh stencil context. Assignments are then
     visited in reverse program order. Afterwards every field is given an entry
     (``Extent.zeros()`` if it was never reached). With ``mask_inwards``, each
     axis is clamped to ``(min(0, lower), max(0, upper))``.
     """
     field_extents: FIELD_EXT_T = {}
     stencil_ctx = self.StencilContext()
     for field_if in node.iter_tree().if_isinstance(gtir.FieldIfStmt):
         self.visit(field_if, ctx=stencil_ctx)
     for assign in reversed(_iter_assigns(node).to_list()):
         self.visit(assign, ctx=stencil_ctx, field_extents=field_extents)

     # Missing fields are filled in only after propagation. Seeding every field
     # with zero up front would break inward-pointing extents (negative
     # boundaries).
     for name in _iter_field_names(node):
         extent = field_extents.setdefault(name, Extent.zeros())
         if not mask_inwards:
             continue
         clamped_axes = ((min(0, axis[0]), max(0, axis[1])) for axis in extent)
         field_extents[name] = Extent(*clamped_axes)
     return field_extents
Ejemplo n.º 21
0
    def visit_HorizontalBlock(self, node: npir.HorizontalBlock,
                              **kwargs) -> Union[str, Collection[str]]:
        """Derive per-axis lower/upper bounds for the block from its fields' extents.

        Both bounds default to ``[0, 0]``. They are only computed when a truthy
        ``field_extents`` mapping is passed in ``kwargs``.
        """
        lower, upper = [0, 0], [0, 0]

        if extents := kwargs.get("field_extents"):
            # Names of every field sliced anywhere inside this block.
            fields = set(node.iter_tree().if_isinstance(
                npir.FieldSlice).getattr("name"))
            for field in fields:
                # The extent of masks has not yet been collected but is always zero.
                # NOTE: this mutates the caller's ``field_extents`` mapping.
                extents.setdefault(field, Extent.zeros())
            # NOTE(review): ``min`` raises ValueError if ``fields`` is empty.
            # NOTE(review): the upper bounds below also use ``min``; confirm this
            # is intended rather than ``max``.
            lower[0] = min(extents[field].to_boundary()[0][0]
                           for field in fields)
            lower[1] = min(extents[field].to_boundary()[1][0]
                           for field in fields)
            upper[0] = min(extents[field].to_boundary()[0][1]
                           for field in fields)
            upper[1] = min(extents[field].to_boundary()[1][1]
                           for field in fields)
        # NOTE(review): the visible body ends here without using ``lower``/``upper``
        # or returning; the remainder appears to be cut from this excerpt.
Ejemplo n.º 22
0
    def visit_HorizontalExecution(
        self,
        node: oir.HorizontalExecution,
        *,
        block_extents: Optional[Dict[int, Extent]] = None,
        **kwargs: Any,
    ) -> npir.HorizontalBlock:
        """Lower an OIR horizontal execution to an NPIR horizontal block.

        The block's extent is ``block_extents[id(node)]`` when ``block_extents``
        is truthy, and otherwise the zero 2-D extent. The same extent is passed
        as ``extent`` while visiting the body, and the visited body statements
        are flattened into a single list.
        """
        extent = block_extents[id(node)] if block_extents else Extent.zeros(ndims=2)
        body = utils.flatten_list(self.visit(node.body, extent=extent, **kwargs))
        declarations = self.visit(node.declarations, **kwargs)
        return npir.HorizontalBlock(body=body, extent=extent, declarations=declarations)
Ejemplo n.º 23
0
    def apply(self, transform_data: TransformData):
        """Split every PARALLEL domain block into one domain block per statement.

        Each statement of a parallel block is wrapped in its own
        ``IntervalBlockInfo``, ``IJBlockInfo`` (with a zero compute extent) and
        ``DomainBlockInfo``. Blocks with any other iteration order are kept as
        they are.

        Returns:
            ``transform_data`` with ``blocks`` replaced.
        """
        zero_extent = Extent.zeros(transform_data.ndims)
        new_blocks = []
        for block in transform_data.blocks:
            if block.iteration_order != gt_ir.IterationOrder.PARALLEL:
                new_blocks.append(block)
                continue

            # Put every statement in a single stage
            id_gen = transform_data.id_generator
            for ij_block in block.ij_blocks:
                for interval_block in ij_block.interval_blocks:
                    for stmt_info in interval_block.stmts:
                        interval = interval_block.interval
                        stmt_block = IntervalBlockInfo(
                            id_gen.new,
                            interval,
                            [stmt_info],
                            stmt_info.inputs,
                            stmt_info.outputs,
                        )
                        ij_wrapper = IJBlockInfo(
                            id_gen.new,
                            {interval},
                            [stmt_block],
                            dict(stmt_block.inputs),
                            set(stmt_block.outputs),
                            compute_extent=zero_extent,
                        )
                        new_blocks.append(
                            DomainBlockInfo(
                                id_gen.new,
                                block.iteration_order,
                                set(ij_wrapper.intervals),
                                [ij_wrapper],
                                dict(ij_wrapper.inputs),
                                set(ij_wrapper.outputs),
                            )
                        )

        transform_data.blocks = new_blocks

        return transform_data
Ejemplo n.º 24
0
def test_mask_stmt_to_assigns() -> None:
    """A mask statement with one assignment lowers to exactly one masked assignment."""
    mask_stmt = MaskStmtFactory(body=[AssignStmtFactory()])
    lowered = OirToNpir().visit(mask_stmt, extent=Extent.zeros(ndims=2))
    first_cond = lowered[0].right.cond
    assert isinstance(first_cond, npir.FieldSlice)
    assert len(lowered) == 1
Ejemplo n.º 25
0
 def accumulate_extents(extents: Sequence[Extent]) -> Extent:
     """Return the union of ``Extent.zeros()`` with every extent in ``extents``."""
     union = Extent.zeros()
     for item in extents:
         union = union | item
     return union
Ejemplo n.º 26
0
def test_mask_propagation() -> None:
    """The lowered assignment's condition equals the lowered mask expression."""
    mask_stmt = MaskStmtFactory()
    lowered_assigns = OirToNpir().visit(mask_stmt, extent=Extent.zeros(ndims=2))
    actual_cond = lowered_assigns[0].right.cond
    assert actual_cond == OirToNpir().visit(mask_stmt.mask)
Ejemplo n.º 27
0
 def __init__(self, add_k: bool = False):
     """Store the ``add_k`` flag and a reusable zero 2-D extent."""
     zero_extent = Extent.zeros(ndims=2)
     self.add_k = add_k
     self.zero_extent = zero_extent
Ejemplo n.º 28
0
 def __init__(self, transform_data: TransformData, computation_intervals: list):
     """Bind the transform data and computation intervals; no block is active yet.

     Also caches a zero extent with ``transform_data.ndims`` dimensions.
     """
     self.data, self.computation_intervals = transform_data, computation_intervals
     self.current_block_info = None
     self.zero_extent = Extent.zeros(transform_data.ndims)
Ejemplo n.º 29
0
class FieldAccessor(Accessor):
    """Accessor describing how a stage accesses a field symbol."""

    # Name of the accessed symbol.
    symbol = attribute(of=str)
    # Kind of access, as an ``AccessKind``.
    intent = attribute(of=AccessKind)
    # Access extent; the ``Extent.zeros()`` default is evaluated once, when the
    # class body runs.
    extent = attribute(of=Extent, default=Extent.zeros())
Ejemplo n.º 30
0
 def has_extended_domain(self) -> bool:
     """Return True if ``self.full_extent`` differs from ``Extent.zeros()``."""
     zero_extent = Extent.zeros()
     return self.full_extent != zero_extent