示例#1
0
def test_compute_relative_mask():
    """Check ``compute_relative_mask`` for two masks evaluated at a zero extent.

    A compute-domain i-interval with offsets (-1, 1) is expected to map to the
    plain compute domain.  A START endpoint i-interval with offsets (-2, 3) is
    expected to map to the same endpoint with offsets (0, 3).  The full
    j-interval maps to the compute domain in both cases.
    """
    expected_domain = common.HorizontalInterval.compute_domain()

    # Case 1: compute-domain interval on i, widened by one point on each side.
    widened = common.HorizontalMask(
        i=common.HorizontalInterval.compute_domain(start_offset=-1,
                                                   end_offset=1),
        j=common.HorizontalInterval.full(),
    )
    result = compute_relative_mask(Extent.zeros(ndims=2), widened)
    assert result.i == expected_domain
    assert result.j == expected_domain

    # Case 2: interval anchored at the START endpoint.
    anchored = common.HorizontalMask(
        i=common.HorizontalInterval.at_endpt(
            level=common.LevelMarker.START, start_offset=-2, end_offset=3),
        j=common.HorizontalInterval.full(),
    )
    result = compute_relative_mask(Extent.zeros(ndims=2), anchored)
    expected_i = common.HorizontalInterval.at_endpt(
        level=common.LevelMarker.START, start_offset=0, end_offset=3)
    assert result.i == expected_i
    assert result.j == expected_domain
示例#2
0
    def apply(self, transform_data: TransformData):
        """Compute per-field access extents and store them on the implementation IR.

        Blocks and their ij-blocks are walked last-to-first.  Each ij-block's
        ``compute_extent`` is set to the union of the access extents already
        accumulated for its outputs.  Every input of its interval blocks then
        accumulates ``compute_extent + <input extent>``.  Finally the entry of
        the sequential axis is replaced by ``(0, 0)``, and the result is stored
        in ``implementation_ir.fields_extents``.

        Returns ``transform_data``, which is modified in place.
        """
        domain = transform_data.definition_ir.domain
        seq_axis = domain.index(domain.sequential_axis)
        access_extents = {name: Extent.zeros() for name in transform_data.symbols}

        for block in reversed(transform_data.blocks):
            for ij_block in reversed(block.ij_blocks):
                ij_block.compute_extent = Extent.zeros()
                for output_name in ij_block.outputs:
                    ij_block.compute_extent |= access_extents[output_name]
                for int_block in ij_block.interval_blocks:
                    for input_name, read_extent in int_block.inputs.items():
                        access_extents[input_name] |= ij_block.compute_extent + read_extent

        def _drop_sequential_axis(extent):
            # The sequential axis never contributes to the stored extent.
            bounds = list(extent)
            bounds[seq_axis] = (0, 0)
            return Extent(bounds)

        transform_data.implementation_ir.fields_extents = {
            name: Extent(_drop_sequential_axis(extent))
            for name, extent in access_extents.items()
        }
        return transform_data
示例#3
0
    def _make_stage(self, ij_block):
        """Build a ``gt_ir.Stage`` from an ij-block.

        Each interval block becomes one ``gt_ir.ApplyBlock``:

        - declarations of names contained in ``self.data.symbols`` are left out
          of the body;
        - any other declaration must be a ``gt_ir.VarDecl`` and is collected
          into the apply block's ``local_symbols``;
        - all remaining statements form the body.

        One accessor is created per input field.  It is read-write when the
        field is also an output, and then its extent also includes the zero
        extent.  Outputs that are never read get a read-write accessor with a
        zero extent.
        """
        # One zero extent with the data's dimensionality.  It is used both for
        # widening read-write inputs and for write-only outputs, so the two
        # paths always agree on the number of dimensions.
        zero_extent = Extent.zeros(self.data.ndims)

        apply_blocks = []
        for int_block in ij_block.interval_blocks:
            stmts = []
            local_symbols = {}
            for stmt_info in int_block.stmts:
                stmt = stmt_info.stmt
                if not isinstance(stmt, gt_ir.Decl):
                    stmts.append(stmt)
                elif stmt.name in self.data.symbols:
                    # Declarations of names in self.data.symbols are not part
                    # of the stage body.
                    continue
                else:
                    assert isinstance(stmt, gt_ir.VarDecl)
                    local_symbols[stmt.name] = stmt

            apply_blocks.append(
                gt_ir.ApplyBlock(
                    interval=self._make_axis_interval(int_block.interval),
                    local_symbols=local_symbols,
                    body=gt_ir.BlockStmt(stmts=stmts),
                )
            )

        accessors = []
        remaining_outputs = set(ij_block.outputs)
        for name, extent in ij_block.inputs.items():
            read_write = name in remaining_outputs
            if read_write:
                remaining_outputs.remove(name)
                # Make sure the extent of a read-write field contains the zero offset.
                extent |= zero_extent
            accessors.append(self._make_accessor(name, extent, read_write))
        for name in remaining_outputs:
            accessors.append(self._make_accessor(name, zero_extent, True))

        return gt_ir.Stage(
            name="stage__{}".format(ij_block.id),
            accessors=accessors,
            apply_blocks=apply_blocks,
            compute_extent=ij_block.compute_extent,
        )
示例#4
0
def nodes_extent_calculation(
    nodes: Collection[Union["VerticalLoopLibraryNode",
                            "HorizontalExecutionLibraryNode"]]
) -> Dict[str, Extent]:
    """Compute the horizontal extent of every field accessed by ``nodes``.

    A vertical-loop node is expanded into all ``dace.nodes.LibraryNode``
    instances found recursively in its section SDFGs.  Horizontal-execution
    nodes are used as they are.

    Processing of each inner node with a non-``None`` ``extent``:

    - every access is converted to an extent relative to that block extent
      via ``acc.to_extent``;
    - the result is unioned with the 2-D zero extent;
    - it is then accumulated per field name.

    Accesses for which ``to_extent`` returns ``None`` are skipped.  This
    happens when a horizontal mask does not overlap the block extent.
    """
    # Local import; presumably avoids a circular import with gtc.dace.nodes.
    from gtc.dace.nodes import HorizontalExecutionLibraryNode, VerticalLoopLibraryNode

    inner_nodes = []
    for node in nodes:
        if isinstance(node, VerticalLoopLibraryNode):
            for _, section_sdfg in node.sections:
                inner_nodes.extend(
                    ln for ln, _ in section_sdfg.all_nodes_recursive()
                    if isinstance(ln, dace.nodes.LibraryNode))
        else:
            assert isinstance(node, HorizontalExecutionLibraryNode)
            inner_nodes.append(node)

    field_extents: Dict[str, Extent] = {}
    for node in inner_nodes:
        block_extent = node.extent
        if block_extent is None:
            continue
        access_collection = AccessCollector.apply(node.oir_node)
        for acc in access_collection.ordered_accesses():
            access_extent = acc.to_extent(block_extent)
            if access_extent is None:
                # Masked access with no overlap: contributes nothing (and
                # ``None | Extent`` would raise).
                continue
            offset_extent = access_extent | Extent.zeros(2)
            if acc.field in field_extents:
                field_extents[acc.field] |= offset_extent
            else:
                field_extents[acc.field] = offset_extent

    return field_extents
示例#5
0
    def make_temporary_field(
        self, name: str, dtype: gt_ir.DataType, extent: gt_definitions.Extent
    ) -> List[str]:
        """Return the parent's source lines for a temporary, followed by its origin lines.

        The origin lines come from ``self._make_field_origin`` applied to the
        lower indices of ``extent.to_boundary()``.
        """
        lines = super().make_temporary_field(name, dtype, extent)
        lower_indices = extent.to_boundary().lower_indices
        origin_lines = self._make_field_origin(name, lower_indices)
        lines.extend(origin_lines)
        return lines
示例#6
0
class HorizontalBlockFactory(factory.Factory):
    """factory_boy factory producing ``npir.HorizontalBlock`` instances for tests.

    By default a block has one ``VectorAssign`` body statement, a 2-D zero
    extent and no local declarations.
    """

    class Meta:
        model = npir.HorizontalBlock

    body = factory.List([factory.SubFactory(VectorAssignFactory)])
    extent: Extent = Extent.zeros(ndims=2)
    # ``factory.List`` builds a fresh list for every generated block.  A bare
    # ``[]`` class attribute would be one list object handed to every instance.
    declarations: List[npir.LocalScalarDecl] = factory.List([])
示例#7
0
 def visit_ParAssignStmt(
     self,
     node: gtir.ParAssignStmt,
     *,
     ctx: StencilContext,
     field_extents: FIELD_EXT_T,
     **kwargs: Any,
 ) -> None:
     """Accumulate field extents produced by one parallel assignment.

     The target field's current extent becomes the ``left_extent`` of a fresh
     ``AssignContext``.  If the target is new, it is first registered in
     ``field_extents`` with a zero extent.

     With that context, the statement's conditions (from
     ``ctx.assign_conditions``, empty if none) are visited, then the
     right-hand side.  The extents they collect are merged (``|``) into
     ``field_extents``.
     """
     target_extent = field_extents.setdefault(node.left.name, Extent.zeros())
     assign_ctx = self.AssignContext(left_extent=target_extent)
     conditions = ctx.assign_conditions.get(id(node), [])
     self.visit(conditions, field_extents=field_extents, pa_ctx=assign_ctx, **kwargs)
     self.visit(node.right, field_extents=field_extents, pa_ctx=assign_ctx, **kwargs)

     for field_name, collected in assign_ctx.assign_extents.items():
         if field_name in field_extents:
             field_extents[field_name] |= collected
         else:
             field_extents[field_name] = collected
示例#8
0
    def visit_HorizontalBlock(self, node: npir.HorizontalBlock, *,
                              ctx: Context):
        """Record the block extent of ``node`` and widen the extents of its field slices.

        The block extent is computed first:

        - it starts as the zero extent;
        - it is unioned with the extent already recorded in
          ``ctx.field_extents`` for every field written by a ``VectorAssign``
          whose target is a ``FieldSlice`` (zero for fields not yet recorded).

        Then every ``FieldSlice`` in the block, assignment targets included,
        has its field extent unioned with ``block extent + slice_to_extent(slice)``.
        """
        written_names = (
            node.iter_tree()
            .if_isinstance(npir.VectorAssign)
            .getattr("left")
            .if_isinstance(npir.FieldSlice)
            .getattr("name")
            .to_set()
        )
        block_extent = Extent.zeros()
        for written in written_names:
            block_extent = block_extent | ctx.field_extents.get(written, Extent.zeros())
        ctx.block_extents[id(node)] = block_extent

        field_slices = node.iter_tree().if_isinstance(npir.FieldSlice).to_list()
        for field_slice in field_slices:
            previous = ctx.field_extents.get(field_slice.name, Extent.zeros())
            ctx.field_extents[field_slice.name] = previous.union(
                block_extent + slice_to_extent(field_slice))
示例#9
0
    def to_extent(self, horizontal_extent: Extent) -> Optional[Extent]:
        """
        Convert the access to an extent provided a horizontal extent for the access.

        For an unmasked access the result is ``horizontal_extent`` plus the
        extent of the access offset (i and j components only).

        For an access under a horizontal mask:

        - ``horizontal_extent`` is first reduced by the overlap computed by
          ``mask_overlap_with_extent``;
        - then the offset extent is added;
        - the result is unioned with the 2-D zero extent.

        This returns None if no overlap exists between the horizontal mask and interval.
        """
        offset_as_extent = Extent.from_offset(
            cast(Tuple[int, int, int], self.offset)[:2])

        if not self.horizontal_mask:
            # Unmasked accesses are evaluated over the whole horizontal extent.
            # Without this branch the method would implicitly return None.
            return horizontal_extent + offset_as_extent

        dist_from_edge = mask_overlap_with_extent(self.horizontal_mask,
                                                  horizontal_extent)
        if dist_from_edge is None:
            return None

        zeros = Extent.zeros(ndims=2)
        return ((horizontal_extent - dist_from_edge) + offset_as_extent) | zeros
示例#10
0
def _ext_from_off(
        offset: Union[gtir.CartesianOffset, gtir.VariableKOffset]) -> Extent:
    """Return the extent covered by ``offset`` in i and j, with k fixed to ``(0, 0)``.

    Each horizontal axis maps to ``(min(off, 0), max(off, 0))``, so the extent
    always contains the origin.  Only the ``"i"`` and ``"j"`` entries of
    ``offset.to_dict()`` are read.
    """
    offsets = offset.to_dict()
    horizontal_bounds = tuple(
        (min(offsets[axis], 0), max(offsets[axis], 0)) for axis in ("i", "j"))
    return Extent(horizontal_bounds + ((0, 0),))
示例#11
0
文件: utils.py 项目: DropD/gt4py
    def visit_HorizontalExecution(self, node: oir.HorizontalExecution, *,
                                  ctx: Context) -> None:
        """Record the block extent of ``node`` and propagate it to the fields it reads.

        Only cartesian accesses are considered.

        The block extent is the 2-D zero extent unioned with the extents
        already in ``ctx.field_extents`` for every written field (zero for
        fields not yet seen).

        For every read field:

        - its read offsets (i/j only) are unioned starting from the zero extent;
        - that union is combined (``+``) with the block extent;
        - the result is merged into ``ctx.field_extents``.
        """
        accesses = AccessCollector.apply(node).cartesian_accesses()

        block_extent = Extent.zeros(ndims=2)
        for written in accesses.write_fields():
            block_extent = block_extent | ctx.field_extents.get(
                written, Extent.zeros(ndims=2))
        ctx.block_extents[id(node)] = block_extent

        for field_name, offsets in accesses.read_offsets().items():
            read_extent = Extent.zeros(ndims=2)
            for off in offsets:
                read_extent = read_extent | Extent.from_offset(off[:2])
            known = ctx.field_extents.get(field_name, Extent.zeros(ndims=2))
            ctx.field_extents[field_name] = known.union(block_extent + read_extent)
示例#12
0
 def visit_Stencil(self, node: gtir.Stencil, *, mask_inwards: bool,
                   **kwargs: Any) -> FIELD_EXT_T:
     """Compute the access extent of every field in ``node``.

     Steps:

     1. All ``FieldIfStmt`` nodes are visited to populate a fresh stencil
        context.
     2. Assignments are visited last-to-first to accumulate extents.
     3. Every field without an extent then receives the zero extent.

     When ``mask_inwards`` is set, each axis is clamped so that its lower bound
     is at most 0 and its upper bound is at least 0.
     """
     ctx = self.StencilContext()
     for field_if in node.iter_tree().if_isinstance(gtir.FieldIfStmt):
         self.visit(field_if, ctx=ctx)

     field_extents: FIELD_EXT_T = {}
     for assign in reversed(_iter_assigns(node).to_list()):
         self.visit(assign, ctx=ctx, field_extents=field_extents)

     for name in _iter_field_names(node):
         # Zero-fill only the fields still missing.  Starting every field at
         # zero would break inward-pointing extents (negative boundaries).
         extent = field_extents.setdefault(name, Extent.zeros())
         if not mask_inwards:
             continue
         clamped = [(min(0, bounds[0]), max(0, bounds[1])) for bounds in extent]
         field_extents[name] = Extent(*clamped)
     return field_extents
示例#13
0
    def make_temporary_field(self, name: str, dtype: gt_ir.DataType,
                             axes: list, extent: gt_definitions.Extent):
        """Return the parent's source lines for a temporary, followed by its accessor lines.

        The accessor lines come from ``self._make_field_accessor`` applied to
        the lower indices of ``extent.to_boundary()``.
        """
        lines = super().make_temporary_field(name, dtype, axes, extent)
        boundary = extent.to_boundary()
        lines.extend(self._make_field_accessor(name, boundary.lower_indices))
        return lines
示例#14
0
def test_temp_with_extent_definition() -> None:
    """Check code generated for a temporary with extent ((0, 1), (-2, 3)).

    The i/j shape grows by the extent widths (+1, +5), and the view origin
    is [0, 2, 0].
    """
    expected = (
        "a_ = ShimmedView(np.zeros((_dI_ + 1, _dJ_ + 5, _dK_), dtype=np.int64), [0, 2, 0])"
    )
    generator = npir_gen.NpirGen()
    generated = generator.visit(
        VectorAssignFactory(temp_init=True, temp_name="a"),
        field_extents={"a": Extent((0, 1), (-2, 3))},
    )
    assert generated == expected
示例#15
0
class FieldDeclFactory(factory.Factory):
    """factory_boy factory producing ``npir.FieldDecl`` instances for tests.

    By default it produces a FLOAT32 field with all three dimensions enabled,
    no data dimensions and a 2-D zero extent.
    """

    class Meta:
        model = npir.FieldDecl

    name = identifier(npir.FieldDecl)
    dimensions = (True, True, True)
    # Variable-length tuple of data dimensions; empty by default.  ``Tuple[int]``
    # would denote exactly one int and needed a ``cast`` to accept ``()``.
    data_dims: Tuple[int, ...] = ()
    extent: Extent = Extent.zeros(ndims=2)
    dtype = common.DataType.FLOAT32
示例#16
0
def mask_overlap_with_extent(mask: common.HorizontalMask,
                             horizontal_extent: Extent) -> Optional[Extent]:
    """Compute an overlap extent between a mask and horizontal extent.

    ``_overlap_along_axis`` is evaluated for each pair of extent axis and mask
    interval.  If any axis yields ``None``, there is no overlap and ``None`` is
    returned.  Otherwise the first two per-axis results form the returned
    ``Extent``.
    """
    per_axis = []
    for axis_extent, axis_interval in zip(horizontal_extent, mask.intervals):
        per_axis.append(_overlap_along_axis(axis_extent, axis_interval))
    if any(overlap is None for overlap in per_axis):
        return None
    return Extent(per_axis[0], per_axis[1])
示例#17
0
 def has_reads_with_offset(self, *,
                           restrict_to: Optional[Set[str]] = None) -> bool:
     """Return whether any checked input field is read with a non-zero extent.

     The last axis (presumably k) is only compared when
     ``self.k_offset_extends_domain`` is set.

     Which inputs are checked depends on ``restrict_to``:

     - ``None`` (the default): all inputs are checked;
     - a set: only inputs whose names are in it are checked.  An empty set
       checks nothing and returns False (previously it was treated like
       ``None``).
     """
     checked_axes = slice(None) if self.k_offset_extends_domain else slice(None, -1)
     if restrict_to is None:
         fields = set(self.inputs)
     else:
         fields = restrict_to.intersection(self.inputs)
     zero = Extent.zeros()[checked_axes]
     return any(self.inputs[name][checked_axes] != zero for name in fields)
示例#18
0
        def visit(self, transform_data: TransformData) -> List[DomainBlockInfo]:
            """Visit every domain block of ``transform_data`` and return ``self._split_blocks``.

            Each block is visited with its own context dict.  The dict holds a
            zero extent with ``transform_data.ndims`` dimensions and
            ``transform_data.id_generator``.
            """
            for domain_block in transform_data.blocks:
                block_context = dict(
                    zero_extent=Extent.zeros(transform_data.ndims),
                    id_generator=transform_data.id_generator,
                )
                self.visit_DomainBlockInfo(domain_block, block_context)

            return self._split_blocks
示例#19
0
    def get_field_extents(self, node):
        """Return ``(input_extents, output_extents)`` for a horizontal-execution node.

        Each read or write access is:

        - converted to an extent relative to the node's block extent via
          ``acc.to_extent``;
        - unioned with the 2-D zero extent;
        - accumulated per field name into the input or output dict.

        Accesses for which ``to_extent`` returns ``None`` contribute nothing.
        This happens when a horizontal mask has no overlap with the block
        extent.
        """
        assert isinstance(node, HorizontalExecutionLibraryNode)
        block_extent: Extent = node.extent
        assert block_extent is not None
        collection = self._get_access_collection(node)

        def _accumulate(accesses):
            # Per-field union of the zero-padded extents of ``accesses``.
            extents = {}
            for acc in accesses:
                access_extent = acc.to_extent(block_extent)
                if access_extent is None:
                    # ``None | Extent`` would raise; a non-overlapping masked
                    # access simply has no extent.
                    continue
                extent = access_extent | Extent.zeros(2)
                if acc.field in extents:
                    extents[acc.field] |= extent
                else:
                    extents[acc.field] = extent
            return extents

        input_extents = _accumulate(collection.read_accesses())
        output_extents = _accumulate(collection.write_accesses())
        return input_extents, output_extents
def test_horiz_exec_extents():
    """Check the block extent of a writer whose output is read at i+1.

    The execution that writes ``tmp`` should get a block extent of
    ((0, 1), (0, 0)) when a later execution reads ``tmp`` at i+1.
    """
    writer = HorizontalExecutionFactory(body__0__left__name="tmp")
    reader = HorizontalExecutionFactory(
        body__0__right=FieldAccessFactory(name="tmp", offset__i=1))
    stencil = StencilFactory(
        vertical_loops__0__sections__0__horizontal_executions=[writer, reader])

    first_hexec = stencil.vertical_loops[0].sections[0].horizontal_executions[0]
    block_extents = compute_horizontal_block_extents(stencil)
    assert block_extents[id(first_hexec)] == Extent(((0, 1), (0, 0)))
示例#21
0
 def visit_Stencil(self, node: gtir.Stencil, **kwargs: Any) -> FIELD_EXT_T:
     """Compute the access extent of every field in ``node``.

     Every field starts at the zero extent.  Steps:

     1. All ``FieldIfStmt`` nodes are visited to populate a fresh stencil
        context.
     2. Assignments are visited last-to-first, each widening
        ``field_extents`` in place.
     """
     field_extents = {}
     for field_name in _iter_field_names(node):
         field_extents[field_name] = Extent.zeros()

     ctx = self.StencilContext()
     for field_if in node.iter_tree().if_isinstance(gtir.FieldIfStmt):
         self.visit(field_if, ctx=ctx)

     assigns = _iter_assigns(node).to_list()
     for assign in reversed(assigns):
         self.visit(assign, ctx=ctx, field_extents=field_extents)
     return field_extents
示例#22
0
 def visit_FieldDecl(self, node: oir.FieldDecl, *,
                     field_extents: Dict[str, Extent],
                     **kwargs: Any) -> npir.FieldDecl:
     """Translate an OIR field declaration into an ``npir.FieldDecl``.

     The extent is looked up in ``field_extents`` by field name.  Fields with
     no recorded extent get the 2-D zero extent.  All other attributes are
     copied from ``node``.
     """
     declared_extent = field_extents.get(node.name, Extent.zeros(ndims=2))
     attributes = dict(
         name=node.name,
         dtype=node.dtype,
         dimensions=node.dimensions,
         data_dims=node.data_dims,
         extent=declared_extent,
     )
     return npir.FieldDecl(**attributes)
示例#23
0
文件: utils.py 项目: fthaler/gt4py
    def visit_Stencil(self, node: oir.Stencil) -> "Context":
        """Visit the vertical loops last-to-first and return the populated context.

        When ``self.add_k`` is set, each field extent in ``ctx.fields`` is
        extended with a trailing ``(0, 0)`` axis.
        """
        ctx = self.Context()
        for vertical_loop in reversed(node.vertical_loops):
            self.visit(vertical_loop, ctx=ctx)

        if not self.add_k:
            return ctx

        extended = {}
        for field_name, horizontal in ctx.fields.items():
            extended[field_name] = Extent(*horizontal, (0, 0))
        ctx.fields = extended
        return ctx
示例#24
0
 def visit_FieldAccess(
     self,
     node: gtir.FieldAccess,
     *,
     field_extents: FIELD_EXT_T,
     pa_ctx: AssignContext,
     **kwargs: Any,
 ) -> None:
     """Widen the assignment-local extent of the accessed field.

     Steps:

     1. The field is registered in ``field_extents`` (zero extent if new).
     2. That value seeds ``pa_ctx.assign_extents`` for the field if it is
        absent.
     3. The entry is unioned with ``pa_ctx.left_extent`` plus the extent of
        this access's offset.
     """
     recorded = field_extents.setdefault(node.name, Extent.zeros())
     pa_ctx.assign_extents.setdefault(node.name, recorded)
     pa_ctx.assign_extents[node.name] |= pa_ctx.left_extent + _ext_from_off(node.offset)
示例#25
0
 def visit_HorizontalBlock(
     self,
     node: npir.HorizontalBlock,
     *,
     block_extents: Dict[int, HorizontalExtent] = None,
     **kwargs: Any,
 ) -> Union[str, Collection[str]]:
     """Generate code for a horizontal block, passing its i/j bounds down.

     The block's extent is looked up by ``id(node)`` in ``block_extents``.  It
     falls back to the zero extent when absent or when no mapping is given.
     The extent is converted to a boundary, and the lower and upper components
     of its first two axes are forwarded to ``generic_visit`` as ``lower`` and
     ``upper``.
     """
     known_extents = block_extents or {}
     ij_extent: Extent = known_extents.get(id(node), Extent.zeros())
     boundary = ij_extent.to_boundary()
     i_bounds = boundary[0]
     j_bounds = boundary[1]
     lower = (i_bounds[0], j_bounds[0])
     upper = (i_bounds[1], j_bounds[1])
     return self.generic_visit(node, lower=lower, upper=upper, **kwargs)
示例#26
0
    def get_field_extents(self, node):
        """Return ``(input_extents, output_extents)`` for a vertical-loop node.

        For every section, only ``HorizontalExecutionLibraryNode`` instances in
        the first state of the section SDFG are considered.  Each read or
        write access is:

        - converted to an extent relative to that node's block extent;
        - unioned with the 2-D zero extent;
        - accumulated per field name.

        Accesses for which ``to_extent`` returns ``None`` contribute nothing.
        This happens when a horizontal mask has no overlap with the block
        extent.
        """
        assert isinstance(node, VerticalLoopLibraryNode)
        input_extents = {}
        output_extents = {}

        def _accumulate(target, accesses, block_extent):
            # Merge the zero-padded extents of ``accesses`` into ``target`` per field.
            for acc in accesses:
                access_extent = acc.to_extent(block_extent)
                if access_extent is None:
                    # ``None | Extent`` would raise; skip non-overlapping masked accesses.
                    continue
                extent = access_extent | Extent.zeros(2)
                if acc.field in target:
                    target[acc.field] |= extent
                else:
                    target[acc.field] = extent

        for _, sdfg in node.sections:
            for state_node in sdfg.states()[0].nodes():
                if not isinstance(state_node, HorizontalExecutionLibraryNode):
                    continue
                block_extent = state_node.extent
                assert block_extent is not None
                collection = self._get_access_collection(state_node)
                _accumulate(input_extents, collection.read_accesses(), block_extent)
                _accumulate(output_extents, collection.write_accesses(), block_extent)
        return input_extents, output_extents
示例#27
0
文件: utils.py 项目: fthaler/gt4py
    def visit_HorizontalExecution(self, node: oir.HorizontalExecution, *,
                                  ctx: Context) -> None:
        """Record the block extent of ``node`` and propagate it to the fields it accesses.

        The block extent is ``self.zero_extent`` unioned with the extents
        already in ``ctx.fields`` for every written field.

        For every read field:

        - the extents of its read offsets (i/j only) are unioned, seeded with
          the first offset rather than the zero extent;
        - that union is combined (``+``) with the block extent;
        - the result is merged into ``ctx.fields``.

        Written fields not yet present in ``ctx.fields`` get the block extent.
        """
        results = AccessCollector.apply(node)
        horizontal_extent = functools.reduce(
            lambda ext, name: ext | ctx.fields.get(name, self.zero_extent),
            results.write_fields(),
            self.zero_extent,
        )
        ctx.blocks[id(node)] = horizontal_extent

        for name, accesses in results.read_offsets().items():
            # Build the offset extents without ``accesses.pop()``, which would
            # mutate the collection returned by ``read_offsets()`` and fail on
            # non-poppable iterables.
            offset_extents = [Extent.from_offset(off[:2]) for off in accesses]
            if not offset_extents:
                continue
            extent = functools.reduce(lambda ext, other: ext | other, offset_extents)
            total_extent = horizontal_extent + extent
            if name in ctx.fields:
                ctx.fields[name] |= total_extent
            else:
                ctx.fields[name] = total_extent

        for name in results.write_fields():
            ctx.fields.setdefault(name, horizontal_extent)
示例#28
0
def test_stencil_extents_simple():
    """Check extents for two chained i+1 reads.

    Expected results:

    - ``input`` needs (1, 2) in i;
    - ``output`` has a zero extent;
    - the producer of ``tmp`` runs over (0, 1) in i;
    - the consumer runs over a zero extent.
    """
    producer = HorizontalExecutionFactory(body=[
        AssignStmtFactory(left__name="tmp", right__name="input", right__offset__i=1)
    ])
    consumer = HorizontalExecutionFactory(body=[
        AssignStmtFactory(left__name="output", right__name="tmp", right__offset__i=1)
    ])
    testee = StencilFactory(
        vertical_loops__0__sections__0__horizontal_executions=[producer, consumer],
        declarations=[TemporaryFactory(name="tmp")],
    )

    field_extents, block_extents = compute_extents(testee)

    assert field_extents["input"] == Extent((1, 2), (0, 0))
    assert field_extents["output"] == Extent((0, 0), (0, 0))

    hexecs = testee.vertical_loops[0].sections[0].horizontal_executions
    producer_hexec, consumer_hexec = hexecs[0], hexecs[1]
    assert block_extents[id(producer_hexec)] == Extent((0, 1), (0, 0))
    assert block_extents[id(consumer_hexec)] == Extent((0, 0), (0, 0))
示例#29
0
 def visit_Temporary(self, node: oir.Temporary, *,
                     field_extents: Dict[str, Extent],
                     **kwargs: Any) -> npir.TemporaryDecl:
     """Translate an OIR temporary into an ``npir.TemporaryDecl``.

     The temporary's recorded extent is unioned with the 2-D zero extent.
     Then, per axis:

     - ``offset`` is the negated lower bound (asserted non-negative);
     - ``padding`` is ``upper - lower``.
     """
     extent = field_extents[node.name] | Extent.zeros(ndims=2)
     offset = []
     padding = []
     for axis_bounds in extent:
         offset.append(-axis_bounds[0])
         padding.append(axis_bounds[1] - axis_bounds[0])
     assert all(off >= 0 for off in offset)
     return npir.TemporaryDecl(
         name=node.name,
         dtype=node.dtype,
         data_dims=node.data_dims,
         offset=offset,
         padding=padding,
     )
示例#30
0
def test_stencil_extents_region(mask, offset, access_extent):
    """Check a masked read of ``input`` at i+``offset``.

    The masked execution writes ``tmp``, which a later execution reads at i+1.
    Its block extent is expected to be ((0, 1), (0, 0)), and the masked
    ``input`` access must convert to ``access_extent``, or to ``None`` when
    ``access_extent`` is ``None``.
    """
    plain_write = HorizontalExecutionFactory(body=[
        AssignStmtFactory(left__name="tmp", right__name="input")
    ])
    masked_write = HorizontalExecutionFactory(body=[
        HorizontalRestrictionFactory(
            mask=mask,
            body=[
                AssignStmtFactory(left__name="tmp",
                                  right__name="input",
                                  right__offset__i=offset)
            ],
        ),
    ])
    final_read = HorizontalExecutionFactory(body=[
        AssignStmtFactory(left__name="output", right__name="tmp", right__offset__i=1)
    ])
    testee = StencilFactory(
        vertical_loops__0__sections__0__horizontal_executions=[
            plain_write, masked_write, final_read,
        ],
        declarations=[TemporaryFactory(name="tmp")],
    )

    block_extents = compute_horizontal_block_extents(testee)
    hexecs = testee.vertical_loops[0].sections[0].horizontal_executions
    restriction_accesses = AccessCollector.apply(hexecs[1].body[0])
    input_access = next(
        acc for acc in restriction_accesses.ordered_accesses()
        if acc.field == "input")

    expected_block_extent = ((0, 1), (0, 0))
    assert block_extents[id(hexecs[1])] == expected_block_extent
    computed = input_access.to_extent(Extent(expected_block_extent))
    if access_extent is None:
        assert computed is None
    else:
        assert computed == access_extent