def apply(self, transform_data: TransformData):
    """Compute per-symbol access extents and per-ij-block compute extents.

    Blocks are walked back to front. Each ij-block's ``compute_extent`` is the
    union of the access extents recorded so far for its outputs; every input
    of its interval blocks then has its access extent widened by that compute
    extent plus the input's own extent.

    The entry at the sequential-axis index of every access extent is reset to
    ``(0, 0)`` before the result is stored in
    ``transform_data.implementation_ir.fields_extents``.

    Returns the same ``transform_data`` object that was passed in.
    """
    domain = transform_data.definition_ir.domain
    seq_axis = domain.index(domain.sequential_axis)

    access_extents = {symbol: Extent.zeros() for symbol in transform_data.symbols}

    # Walk consumers before producers so that output extents are known when
    # the block writing them is reached.
    for multi_stage in reversed(transform_data.blocks):
        for ij_block in reversed(multi_stage.ij_blocks):
            compute_extent = Extent.zeros()
            for output_name in ij_block.outputs:
                compute_extent |= access_extents[output_name]
            ij_block.compute_extent = compute_extent

            for interval_block in ij_block.interval_blocks:
                for input_name, input_extent in interval_block.inputs.items():
                    access_extents[input_name] |= compute_extent + input_extent

    # Exclude sequential axis
    for symbol in access_extents:
        per_axis = list(access_extents[symbol])
        per_axis[seq_axis] = (0, 0)
        access_extents[symbol] = Extent(per_axis)

    fields_extents = {}
    for symbol, extent in access_extents.items():
        fields_extents[symbol] = Extent(extent)
    transform_data.implementation_ir.fields_extents = fields_extents

    return transform_data
def _ext_from_off(
    offset: Union[gtir.CartesianOffset, gtir.VariableKOffset],
) -> Extent:
    """Build an extent that spans from zero to the offset's ``i`` and ``j`` values.

    The ``"i"`` and ``"j"`` entries of ``offset.to_dict()`` are each turned into
    a ``(min(value, 0), max(value, 0))`` pair; the third pair is always
    ``(0, 0)``.
    """
    components = offset.to_dict()
    di = components["i"]
    dj = components["j"]
    i_bounds = (min(di, 0), max(di, 0))
    j_bounds = (min(dj, 0), max(dj, 0))
    return Extent((i_bounds, j_bounds, (0, 0)))
def mask_overlap_with_extent(mask: common.HorizontalMask, horizontal_extent: Extent) -> Optional[Extent]:
    """Compute an overlap extent between a mask and horizontal extent.

    Each axis of ``horizontal_extent`` is paired with the matching entry of
    ``mask.intervals`` and passed to ``_overlap_along_axis``. If any axis
    yields ``None`` the result is ``None``; otherwise an ``Extent`` is built
    from the first two per-axis results.
    """
    overlaps = []
    for axis_extent, interval in zip(horizontal_extent, mask.intervals):
        overlaps.append(_overlap_along_axis(axis_extent, interval))

    if any(overlap is None for overlap in overlaps):
        return None
    return Extent(overlaps[0], overlaps[1])
def test_temp_with_extent_definition() -> None:
    """Check the generated definition of temporary ``a`` with a non-zero extent.

    With extent ``(0, 1)`` / ``(-2, 3)`` the expected code allocates
    ``_dI_ + 1`` by ``_dJ_ + 5`` by ``_dK_`` and uses ``[0, 2, 0]`` as the
    second ``ShimmedView`` argument.
    """
    generator = npir_gen.NpirGen()
    assignment = VectorAssignFactory(temp_init=True, temp_name="a")
    extents = {"a": Extent((0, 1), (-2, 3))}

    generated = generator.visit(assignment, field_extents=extents)

    expected = "a_ = ShimmedView(np.zeros((_dI_ + 1, _dJ_ + 5, _dK_), dtype=np.int64), [0, 2, 0])"
    assert generated == expected
def test_horiz_exec_extents():
    """A block writing ``tmp`` gets extent ``(0, 1)`` in I when ``tmp`` is read at ``i + 1``."""
    writer = HorizontalExecutionFactory(body__0__left__name="tmp")
    shifted_read = FieldAccessFactory(name="tmp", offset__i=1)
    reader = HorizontalExecutionFactory(body__0__right=shifted_read)
    stencil = StencilFactory(
        vertical_loops__0__sections__0__horizontal_executions=[writer, reader]
    )

    block_extents = compute_horizontal_block_extents(stencil)

    first_hexec = stencil.vertical_loops[0].sections[0].horizontal_executions[0]
    expected = Extent(((0, 1), (0, 0)))
    assert block_extents[id(first_hexec)] == expected
def visit_Stencil(self, node: oir.Stencil) -> "Context":
    """Visit the stencil's vertical loops in reverse order and return the context.

    A fresh ``self.Context()`` is passed to every visit as ``ctx``. When
    ``self.add_k`` is truthy, each extent in ``ctx.fields`` is rebuilt with a
    trailing ``(0, 0)`` entry appended.
    """
    ctx = self.Context()
    for vertical_loop in reversed(node.vertical_loops):
        self.visit(vertical_loop, ctx=ctx)

    if not self.add_k:
        return ctx

    with_k = {}
    for field_name, horizontal_extent in ctx.fields.items():
        with_k[field_name] = Extent(*horizontal_extent, (0, 0))
    ctx.fields = with_k
    return ctx
def test_stencil_extents_simple():
    """Extents accumulate across two chained reads at ``i + 1``.

    ``tmp`` is written from ``input`` at ``i + 1`` and ``output`` is written
    from ``tmp`` at ``i + 1``. The test checks the field and block extents
    returned by ``compute_extents`` for that chain.
    """
    produce_tmp = HorizontalExecutionFactory(
        body=[AssignStmtFactory(left__name="tmp", right__name="input", right__offset__i=1)]
    )
    consume_tmp = HorizontalExecutionFactory(
        body=[AssignStmtFactory(left__name="output", right__name="tmp", right__offset__i=1)]
    )
    testee = StencilFactory(
        vertical_loops__0__sections__0__horizontal_executions=[produce_tmp, consume_tmp],
        declarations=[TemporaryFactory(name="tmp")],
    )

    field_extents, block_extents = compute_extents(testee)

    no_extent = Extent((0, 0), (0, 0))
    assert field_extents["input"] == Extent((1, 2), (0, 0))
    assert field_extents["output"] == no_extent

    first, second = testee.vertical_loops[0].sections[0].horizontal_executions[:2]
    assert block_extents[id(first)] == Extent((0, 1), (0, 0))
    assert block_extents[id(second)] == no_extent
def test_stencil_extents_region(mask, offset, access_extent):
    """Check block and access extents for a masked read inside a horizontal region.

    The middle horizontal execution writes ``tmp`` from ``input`` shifted by
    ``offset`` under ``mask``. Its block extent must be ``((0, 1), (0, 0))``,
    and the extent of the ``input`` access relative to that block must equal
    ``access_extent``, or be ``None`` when ``access_extent`` is ``None``.
    """
    plain_write = HorizontalExecutionFactory(
        body=[AssignStmtFactory(left__name="tmp", right__name="input")]
    )
    masked_write = HorizontalExecutionFactory(
        body=[
            HorizontalRestrictionFactory(
                mask=mask,
                body=[
                    AssignStmtFactory(
                        left__name="tmp", right__name="input", right__offset__i=offset
                    )
                ],
            ),
        ]
    )
    shifted_read = HorizontalExecutionFactory(
        body=[AssignStmtFactory(left__name="output", right__name="tmp", right__offset__i=1)]
    )
    testee = StencilFactory(
        vertical_loops__0__sections__0__horizontal_executions=[
            plain_write,
            masked_write,
            shifted_read,
        ],
        declarations=[TemporaryFactory(name="tmp")],
    )

    block_extents = compute_horizontal_block_extents(testee)
    masked_hexec = testee.vertical_loops[0].sections[0].horizontal_executions[1]

    mask_read_accesses = AccessCollector.apply(masked_hexec.body[0])
    input_access = next(
        access for access in mask_read_accesses.ordered_accesses() if access.field == "input"
    )

    block_extent = ((0, 1), (0, 0))
    assert block_extents[id(masked_hexec)] == block_extent

    computed = input_access.to_extent(Extent(block_extent))
    if access_extent is None:
        assert computed is None
    else:
        assert computed == access_extent
def visit_Stencil(self, node: gtir.Stencil, *, mask_inwards: bool, **kwargs: Any) -> FIELD_EXT_T:
    """Collect an extent for every field name of the stencil.

    All ``gtir.FieldIfStmt`` nodes are visited first with a shared stencil
    context, then every assignment is visited in reverse order, filling
    ``field_extents``. Fields that received no extent get ``Extent.zeros()``.

    When ``mask_inwards`` is true, each lower bound is clamped to at most 0
    and each upper bound to at least 0.
    """
    ctx = self.StencilContext()
    field_extents: FIELD_EXT_T = {}

    field_ifs = node.iter_tree().if_isinstance(gtir.FieldIfStmt)
    for field_if in field_ifs:
        self.visit(field_if, ctx=ctx)

    assigns = _iter_assigns(node).to_list()
    for assign in reversed(assigns):
        self.visit(assign, ctx=ctx, field_extents=field_extents)

    for name in _iter_field_names(node):
        # Defaults are filled in only now: starting every field at zero would
        # break inward-pointing extents (i.e. negative boundaries).
        field_extents.setdefault(name, Extent.zeros())
        if not mask_inwards:
            continue
        # Set inward-pointing extents to zero.
        clamped = [(min(0, bounds[0]), max(0, bounds[1])) for bounds in field_extents[name]]
        field_extents[name] = Extent(*clamped)

    return field_extents
def _merge_extents(self, refs: list):
    """Merge ``(name, extent)`` pairs into one extent per name.

    Pairs with a real extent are combined with ``|=`` in the order given.
    Pairs whose extent is ``None`` are recorded with an all-zero
    three-axis ``Extent``.

    A name must not appear both with ``None`` and with a real extent; this
    is enforced with assertions.

    Returns a dict mapping each name to its merged extent, in first-seen
    order.
    """
    merged = {}
    extentless = set()
    # Merge offsets for same symbol
    for name, extent in refs:
        if extent is not None:
            assert name not in extentless
            if name not in merged:
                merged[name] = extent
            else:
                merged[name] |= extent
            continue

        assert name in extentless or name not in merged
        extentless.add(name)
        merged.setdefault(name, Extent((0, 0), (0, 0), (0, 0)))
    return merged
def slice_to_extent(acc: npir.FieldSlice) -> Extent:
    """Build the extent covered by a field slice's ``i`` and ``j`` offsets.

    For each of ``i_offset`` and ``j_offset``, a truthy offset gives
    ``[value, value]`` taken from ``offset.value`` and a falsy one gives
    ``[0, 0]``. The third entry is always ``[0, 0]``.
    """

    def _axis_bounds(axis_offset):
        # A falsy offset means the slice is not shifted along this axis.
        if not axis_offset:
            return [0, 0]
        return [axis_offset.offset.value] * 2

    i_bounds = _axis_bounds(acc.i_offset)
    j_bounds = _axis_bounds(acc.j_offset)
    return Extent((i_bounds, j_bounds, [0, 0]))
def test_full_computation_valid(tmp_path) -> None:
    """Generate a module for a two-pass computation, import it and run it.

    The first vertical pass assigns ``f2[-2, -2, 0] * f3[0, 3, 1]`` to ``f1``
    with ``parallel_k=True``. The second pass runs backward between
    ``start + 1`` and ``end - 3`` and assigns ``f2 + f2[0, 0, 1]`` to ``f2``.
    The generated source is written to ``tmp_path``, imported, executed on
    10x10x10 arrays and the results are checked.
    """
    field_names = ("f1", "f2", "f3")
    field_decls = [FieldDeclFactory(name=field_name) for field_name in field_names]

    multiply = VectorAssignFactory(
        left=FieldSliceFactory(name="f1", parallel_k=True),
        right=npir.VectorArithmetic(
            op=common.ArithmeticOperator.MUL,
            left=FieldSliceFactory(name="f2", parallel_k=True, offsets=(-2, -2, 0)),
            right=FieldSliceFactory(name="f3", parallel_k=True, offsets=(0, 3, 1)),
        ),
    )
    parallel_pass = VerticalPassFactory(
        temp_defs=[],
        body=[npir.HorizontalBlock(body=[multiply])],
    )

    accumulate = VectorAssignFactory(
        left__name="f2",
        right=npir.VectorArithmetic(
            op=common.ArithmeticOperator.ADD,
            left=FieldSliceFactory(name="f2", parallel_k=False),
            right=FieldSliceFactory(name="f2", parallel_k=False, offsets=(0, 0, 1)),
        ),
    )
    backward_pass = VerticalPassFactory(
        lower=common.AxisBound.from_start(offset=1),
        upper=common.AxisBound.from_end(offset=-3),
        direction=common.LoopOrder.BACKWARD,
        temp_defs=[],
        body=[npir.HorizontalBlock(body=[accumulate])],
    )

    computation = npir.Computation(
        params=["f1", "f2", "f3", "s1"],
        field_params=["f1", "f2", "f3"],
        field_decls=field_decls,
        vertical_passes=[parallel_pass, backward_pass],
    )
    result = npir_gen.NpirGen.apply(
        computation,
        field_extents={
            "f1": Extent([(0, 0), (0, 0)]),
            "f2": Extent([(-2, 0), (-2, 0)]),
            "f3": Extent([(0, 0), (0, 3)]),
        },
    )
    print(result)

    # Write the generated source next to the test's temp dir and import it.
    mod_path = tmp_path / "npir_gen_1.py"
    mod_path.write_text(result)
    sys.path.append(str(tmp_path))
    import npir_gen_1 as mod

    f1 = np.zeros((10, 10, 10))
    f2 = np.ones_like(f1) * 3
    f3 = np.ones_like(f1) * 2
    s1 = 5

    mod.run(
        f1=f1,
        f2=f2,
        f3=f3,
        s1=s1,
        _domain_=(8, 5, 9),
        _origin_={field_name: (2, 2, 0) for field_name in field_names},
    )

    # f1 is 3 * 2 inside the checked region and untouched (zero) outside it.
    assert (f1[2:, 2:-3, 0:-1] == 6).all()
    assert (f1[0:2, :, :] == 0).all()
    assert (f1[:, 0:2, :] == 0).all()
    assert (f1[:, -3:, :] == 0).all()
    assert (f1[:, :, -1:] == 0).all()

    exp_f2 = np.ones((10)) * 3
    # A reversed slice includes its first (higher) index and excludes its
    # second, so [-4:0:-1] addresses the same elements as [1:-3], in reverse.
    exp_f2[-4:0:-1] = np.cumsum(exp_f2[1:-3])
    assert (f2[3, 3, :] == exp_f2[:]).all()
def _ext_from_off(offset: gtir.CartesianOffset) -> Extent:
    """Build an extent that spans from zero to the offset's ``i`` and ``j`` values.

    Each of ``offset.i`` and ``offset.j`` becomes a
    ``(min(value, 0), max(value, 0))`` pair; the third pair is always
    ``(0, 0)``.
    """
    horizontal = []
    for component in (offset.i, offset.j):
        horizontal.append((min(component, 0), max(component, 0)))
    i_bounds, j_bounds = horizontal
    return Extent((i_bounds, j_bounds, (0, 0)))
def _ext_from_off(
    offset: Union[gtir.CartesianOffset, gtir.VariableKOffset],
) -> Extent:
    """Build the extent that covers exactly the offset's ``i`` and ``j`` values.

    A ``gtir.VariableKOffset`` yields an all-zero extent. Any other offset
    yields ``(i, i)`` and ``(j, j)`` for the first two pairs. The third pair
    is always ``(0, 0)``.
    """
    if isinstance(offset, gtir.VariableKOffset):
        horizontal = ((0, 0), (0, 0))
    else:
        horizontal = ((offset.i, offset.i), (offset.j, offset.j))
    return Extent((*horizontal, (0, 0)))