Ejemplo n.º 1
0
    def apply(self, transform_data: TransformData):
        """Compute per-field access extents and store them on the implementation IR.

        Blocks and their IJ blocks are walked in reverse order. Each IJ
        block's ``compute_extent`` becomes the union of the extents
        accumulated so far for its outputs; every input of its interval
        blocks then gets ``compute_extent + input extent`` merged into its
        accumulated access extent. Finally the entry at the sequential axis
        index is reset to ``(0, 0)`` and the result is written to
        ``transform_data.implementation_ir.fields_extents``.

        Args:
            transform_data: Object read for ``definition_ir``, ``symbols`` and
                ``blocks``. Its IJ blocks and ``implementation_ir`` are
                updated in place.

        Returns:
            The same ``transform_data`` object that was passed in.
        """
        domain = transform_data.definition_ir.domain
        seq_axis = domain.index(domain.sequential_axis)

        # Every known symbol starts from a zero extent.
        access_extents = {
            symbol: Extent.zeros() for symbol in transform_data.symbols
        }

        # Walk consumers before producers so that each block sees the extent
        # requested by everything that runs after it.
        for block in reversed(transform_data.blocks):
            for ij_block in reversed(block.ij_blocks):
                ij_block.compute_extent = Extent.zeros()
                for output_name in ij_block.outputs:
                    ij_block.compute_extent |= access_extents[output_name]

                for interval_block in ij_block.interval_blocks:
                    for input_name, input_extent in interval_block.inputs.items():
                        access_extents[input_name] |= (
                            ij_block.compute_extent + input_extent
                        )

        def _without_sequential_axis(extent):
            # Reset the sequential axis entry; the other axes are kept as-is.
            bounds = list(extent)
            bounds[seq_axis] = (0, 0)
            return Extent(bounds)

        transform_data.implementation_ir.fields_extents = {
            field_name: Extent(_without_sequential_axis(field_extent))
            for field_name, field_extent in access_extents.items()
        }

        return transform_data
Ejemplo n.º 2
0
def _ext_from_off(
        offset: Union[gtir.CartesianOffset, gtir.VariableKOffset]) -> Extent:
    """Return the extent spanned between the origin and ``offset`` in i and j.

    The ``"i"`` and ``"j"`` entries of ``offset.to_dict()`` are each turned
    into a ``(low, high)`` pair that always contains zero. The third pair is
    always ``(0, 0)``.
    """
    offsets_by_axis = offset.to_dict()
    horizontal_bounds = tuple(
        (min(offsets_by_axis[axis], 0), max(offsets_by_axis[axis], 0))
        for axis in ("i", "j")
    )
    return Extent((*horizontal_bounds, (0, 0)))
Ejemplo n.º 3
0
def mask_overlap_with_extent(mask: common.HorizontalMask,
                             horizontal_extent: Extent) -> Optional[Extent]:
    """Compute an overlap extent between a mask and horizontal extent.

    Each axis of ``horizontal_extent`` is paired with the corresponding entry
    of ``mask.intervals`` and passed to ``_overlap_along_axis``.

    Returns:
        An ``Extent`` built from the first two per-axis overlaps, or ``None``
        if any per-axis overlap is ``None``.
    """
    per_axis_overlaps = []
    for axis_extent, axis_interval in zip(horizontal_extent, mask.intervals):
        per_axis_overlaps.append(_overlap_along_axis(axis_extent, axis_interval))

    if any(overlap is None for overlap in per_axis_overlaps):
        return None
    return Extent(per_axis_overlaps[0], per_axis_overlaps[1])
Ejemplo n.º 4
0
def test_temp_with_extent_definition() -> None:
    """A temporary initialised with a non-zero extent is emitted as a padded view."""
    expected = (
        "a_ = ShimmedView(np.zeros((_dI_ + 1, _dJ_ + 5, _dK_), dtype=np.int64), [0, 2, 0])"
    )

    generator = npir_gen.NpirGen()
    assign = VectorAssignFactory(temp_init=True, temp_name="a")
    extents = {"a": Extent((0, 1), (-2, 3))}

    generated = generator.visit(assign, field_extents=extents)

    assert generated == expected
def test_horiz_exec_extents():
    """The block writing ``tmp`` gets the extent implied by a later read at i+1."""
    writes_tmp = HorizontalExecutionFactory(body__0__left__name="tmp")
    reads_tmp_shifted = HorizontalExecutionFactory(
        body__0__right=FieldAccessFactory(name="tmp", offset__i=1))
    stencil = StencilFactory(
        vertical_loops__0__sections__0__horizontal_executions=[
            writes_tmp,
            reads_tmp_shifted,
        ])

    section = stencil.vertical_loops[0].sections[0]
    first_hexec = section.horizontal_executions[0]
    block_extents = compute_horizontal_block_extents(stencil)

    # Results are keyed by the identity of each horizontal execution node.
    assert block_extents[id(first_hexec)] == Extent(((0, 1), (0, 0)))
Ejemplo n.º 6
0
    def visit_Stencil(self, node: oir.Stencil) -> "Context":
        """Visit the vertical loops of ``node`` in reverse and return the context.

        When ``self.add_k`` is truthy, every extent in ``ctx.fields`` is
        rebuilt with an extra trailing ``(0, 0)`` pair before returning.
        """
        ctx = self.Context()
        for vertical_loop in reversed(node.vertical_loops):
            self.visit(vertical_loop, ctx=ctx)

        if not self.add_k:
            return ctx

        # Append a zero-width entry to each collected extent.
        fields_with_k = {}
        for field_name, field_extent in ctx.fields.items():
            fields_with_k[field_name] = Extent(*field_extent, (0, 0))
        ctx.fields = fields_with_k

        return ctx
Ejemplo n.º 7
0
def test_stencil_extents_simple():
    """Two chained reads at i+1 accumulate into the field and block extents."""
    write_tmp = HorizontalExecutionFactory(body=[
        AssignStmtFactory(
            left__name="tmp", right__name="input", right__offset__i=1),
    ])
    write_output = HorizontalExecutionFactory(body=[
        AssignStmtFactory(
            left__name="output", right__name="tmp", right__offset__i=1),
    ])
    testee = StencilFactory(
        vertical_loops__0__sections__0__horizontal_executions=[
            write_tmp,
            write_output,
        ],
        declarations=[TemporaryFactory(name="tmp")],
    )

    field_extents, block_extents = compute_extents(testee)

    # "input" is read at i+1 by a block that itself must cover i in [0, 1].
    assert field_extents["input"] == Extent((1, 2), (0, 0))
    assert field_extents["output"] == Extent((0, 0), (0, 0))

    section = testee.vertical_loops[0].sections[0]
    tmp_hexec = section.horizontal_executions[0]
    output_hexec = section.horizontal_executions[1]
    assert block_extents[id(tmp_hexec)] == Extent((0, 1), (0, 0))
    assert block_extents[id(output_hexec)] == Extent((0, 0), (0, 0))
Ejemplo n.º 8
0
def test_stencil_extents_region(mask, offset, access_extent):
    """Check the extent of a masked read of ``input`` against ``access_extent``.

    The stencil has three horizontal executions: an unmasked write to
    ``tmp``, a write to ``tmp`` restricted by ``mask`` that reads ``input``
    at i-offset ``offset``, and a write to ``output`` reading ``tmp`` at i+1.
    ``access_extent`` is the expected result of ``to_extent`` for the masked
    read, or ``None`` when ``to_extent`` is expected to return ``None``.
    """
    init_tmp = HorizontalExecutionFactory(body=[
        AssignStmtFactory(left__name="tmp", right__name="input"),
    ])
    masked_assign = AssignStmtFactory(
        left__name="tmp", right__name="input", right__offset__i=offset)
    masked_update = HorizontalExecutionFactory(body=[
        HorizontalRestrictionFactory(mask=mask, body=[masked_assign]),
    ])
    write_output = HorizontalExecutionFactory(body=[
        AssignStmtFactory(
            left__name="output", right__name="tmp", right__offset__i=1),
    ])
    testee = StencilFactory(
        vertical_loops__0__sections__0__horizontal_executions=[
            init_tmp,
            masked_update,
            write_output,
        ],
        declarations=[TemporaryFactory(name="tmp")],
    )

    block_extents = compute_horizontal_block_extents(testee)
    section = testee.vertical_loops[0].sections[0]
    masked_hexec = section.horizontal_executions[1]
    restriction_accesses = AccessCollector.apply(masked_hexec.body[0])
    input_access = next(
        access for access in restriction_accesses.ordered_accesses()
        if access.field == "input")

    block_extent = ((0, 1), (0, 0))
    assert block_extents[id(masked_hexec)] == block_extent

    actual_extent = input_access.to_extent(Extent(block_extent))
    if access_extent is None:
        assert actual_extent is None
    else:
        assert actual_extent == access_extent
Ejemplo n.º 9
0
 def visit_Stencil(self, node: gtir.Stencil, *, mask_inwards: bool,
                   **kwargs: Any) -> FIELD_EXT_T:
     """Collect an extent for every field name of ``node``.

     All ``gtir.FieldIfStmt`` nodes are visited first, then the assignments
     in reverse order, sharing one ``StencilContext``. Field names that did
     not receive an extent get ``Extent.zeros()``.

     Args:
         node: The stencil to analyse.
         mask_inwards: When true, each bound pair ``(low, high)`` is replaced
             by ``(min(0, low), max(0, high))``.
         **kwargs: Accepted but not used by this method.

     Returns:
         The mapping from field name to extent.
     """
     ctx = self.StencilContext()
     extents_by_field: FIELD_EXT_T = {}

     field_ifs = node.iter_tree().if_isinstance(gtir.FieldIfStmt)
     for field_if in field_ifs:
         self.visit(field_if, ctx=ctx)

     assigns_last_to_first = reversed(_iter_assigns(node).to_list())
     for assign in assigns_last_to_first:
         self.visit(assign, ctx=ctx, field_extents=extents_by_field)

     for name in _iter_field_names(node):
         # Make sure every field has an extent. The mapping is deliberately
         # not pre-filled with zeros before the visits, because that would
         # break inward-pointing extents (i.e. negative boundaries).
         extent = extents_by_field.setdefault(name, Extent.zeros())
         if not mask_inwards:
             continue
         # Set inward-pointing extents to zero.
         clamped_bounds = [
             (min(0, bounds[0]), max(0, bounds[1])) for bounds in extent
         ]
         extents_by_field[name] = Extent(*clamped_bounds)

     return extents_by_field
Ejemplo n.º 10
0
        def _merge_extents(self, refs: list):
            result = {}
            params = set()

            # Merge offsets for same symbol
            for name, extent in refs:
                if extent is None:
                    assert name in params or name not in result
                    params |= {name}
                    result.setdefault(name, Extent((0, 0), (0, 0), (0, 0)))
                else:
                    assert name not in params
                    if name in result:
                        result[name] |= extent
                    else:
                        result[name] = extent

            return result
Ejemplo n.º 11
0
def slice_to_extent(acc: npir.FieldSlice) -> Extent:
    """Return the extent touched by a single field slice access.

    For i and j, a truthy offset yields the pair ``[value, value]`` taken
    from ``offset.value``; a falsy offset yields ``[0, 0]``. The third pair is
    always ``[0, 0]``.
    """
    i_bounds = [0, 0]
    if acc.i_offset:
        i_bounds = [acc.i_offset.offset.value] * 2

    j_bounds = [0, 0]
    if acc.j_offset:
        j_bounds = [acc.j_offset.offset.value] * 2

    return Extent((i_bounds, j_bounds, [0, 0]))
Ejemplo n.º 12
0
def test_full_computation_valid(tmp_path) -> None:
    """Generate a two-pass computation, import the result and run it on arrays.

    The generated source is written to ``npir_gen_1.py`` under ``tmp_path``,
    which is appended to ``sys.path`` so the module can be imported. The
    module's ``run`` function is then called on small NumPy arrays and the
    contents of ``f1`` and ``f2`` are checked afterwards.
    """
    result = npir_gen.NpirGen.apply(
        npir.Computation(
            params=["f1", "f2", "f3", "s1"],
            field_params=["f1", "f2", "f3"],
            field_decls=[
                FieldDeclFactory(name="f1"),
                FieldDeclFactory(name="f2"),
                FieldDeclFactory(name="f3"),
            ],
            vertical_passes=[
                # First pass: f1 = f2[-2, -2, 0] * f3[0, 3, 1].
                VerticalPassFactory(
                    temp_defs=[],
                    body=[
                        npir.HorizontalBlock(body=[
                            VectorAssignFactory(
                                left=FieldSliceFactory(name="f1",
                                                       parallel_k=True),
                                right=npir.VectorArithmetic(
                                    op=common.ArithmeticOperator.MUL,
                                    left=FieldSliceFactory(name="f2",
                                                           parallel_k=True,
                                                           offsets=(-2, -2,
                                                                    0)),
                                    right=FieldSliceFactory(name="f3",
                                                            parallel_k=True,
                                                            offsets=(0, 3, 1)),
                                ),
                            ),
                        ], ),
                    ],
                ),
                # Second pass: backward loop over a restricted k range,
                # f2 = f2 + f2[0, 0, 1].
                VerticalPassFactory(
                    lower=common.AxisBound.from_start(offset=1),
                    upper=common.AxisBound.from_end(offset=-3),
                    direction=common.LoopOrder.BACKWARD,
                    temp_defs=[],
                    body=[
                        npir.HorizontalBlock(body=[
                            VectorAssignFactory(
                                left__name="f2",
                                right=npir.VectorArithmetic(
                                    op=common.ArithmeticOperator.ADD,
                                    left=FieldSliceFactory(name="f2",
                                                           parallel_k=False),
                                    right=FieldSliceFactory(name="f2",
                                                            parallel_k=False,
                                                            offsets=(0, 0, 1)),
                                ),
                            ),
                        ], )
                    ],
                ),
            ],
        ),
        # Extents matching the horizontal offsets used in the first pass.
        field_extents={
            "f1": Extent([(0, 0), (0, 0)]),
            "f2": Extent([(-2, 0), (-2, 0)]),
            "f3": Extent([(0, 0), (0, 3)]),
        },
    )
    print(result)
    mod_path = tmp_path / "npir_gen_1.py"
    mod_path.write_text(result)

    # Make the generated module importable by name.
    # NOTE(review): the sys.path entry is never removed after the test.
    sys.path.append(str(tmp_path))
    import npir_gen_1 as mod

    f1 = np.zeros((10, 10, 10))
    f2 = np.ones_like(f1) * 3
    f3 = np.ones_like(f1) * 2
    s1 = 5
    mod.run(
        f1=f1,
        f2=f2,
        f3=f3,
        s1=s1,
        _domain_=(8, 5, 9),
        _origin_={
            "f1": (2, 2, 0),
            "f2": (2, 2, 0),
            "f3": (2, 2, 0)
        },
    )
    # Inside the checked region f1 == f2 * f3 == 3 * 2; elsewhere it keeps
    # its initial zeros.
    assert (f1[2:, 2:-3, 0:-1] == 6).all()
    assert (f1[0:2, :, :] == 0).all()
    assert (f1[:, 0:2, :] == 0).all()
    assert (f1[:, -3:, :] == 0).all()
    assert (f1[:, :, -1:] == 0).all()

    exp_f2 = np.ones((10)) * 3
    # Remember that reversed ranges still include the first (higher) argument and exclude the
    # second. Thus range(-4, 0, -1) contains the same indices as range(1, -3).
    exp_f2[-4:0:-1] = np.cumsum(exp_f2[1:-3])
    # Only a single (i, j) column of f2 is compared against the expectation.
    assert (f2[3, 3, :] == exp_f2[:]).all()
Ejemplo n.º 13
0
def _ext_from_off(offset: gtir.CartesianOffset) -> Extent:
    """Return the extent spanned between the origin and ``offset`` in i and j.

    Each horizontal pair is ``(min(value, 0), max(value, 0))``, so it always
    contains zero. The third pair is always ``(0, 0)``.
    """
    i_bounds = (min(offset.i, 0), max(offset.i, 0))
    j_bounds = (min(offset.j, 0), max(offset.j, 0))
    return Extent((i_bounds, j_bounds, (0, 0)))
Ejemplo n.º 14
0
def _ext_from_off(
        offset: Union[gtir.CartesianOffset, gtir.VariableKOffset]) -> Extent:
    """Return the point extent located at ``offset``.

    A ``gtir.VariableKOffset`` yields an all-zero extent. Any other offset
    yields ``(i, i)`` and ``(j, j)`` for the horizontal pairs and ``(0, 0)``
    for the third pair.
    """
    if not isinstance(offset, gtir.VariableKOffset):
        i_bounds = (offset.i, offset.i)
        j_bounds = (offset.j, offset.j)
        return Extent((i_bounds, j_bounds, (0, 0)))
    return Extent(((0, 0), (0, 0), (0, 0)))