def test_compute_relative_mask():
    """Check ``compute_relative_mask`` for two horizontal masks and a zero 2D extent."""
    # Case 1: ``i`` is a compute-domain interval built with offsets (-1, 1);
    # ``j`` is the full interval. Both are expected to come back as the plain
    # compute-domain interval.
    extent = Extent.zeros(ndims=2)
    mask = common.HorizontalMask(
        i=common.HorizontalInterval.compute_domain(start_offset=-1, end_offset=1),
        j=common.HorizontalInterval.full(),
    )
    result = compute_relative_mask(extent, mask)
    assert result.i == common.HorizontalInterval.compute_domain()
    assert result.j == common.HorizontalInterval.compute_domain()

    # Case 2: ``i`` is anchored at the START level with offsets (-2, 3);
    # the expected result keeps the end offset and has a start offset of 0.
    extent = Extent.zeros(ndims=2)
    mask = common.HorizontalMask(
        i=common.HorizontalInterval.at_endpt(
            level=common.LevelMarker.START, start_offset=-2, end_offset=3
        ),
        j=common.HorizontalInterval.full(),
    )
    result = compute_relative_mask(extent, mask)
    assert result.i == common.HorizontalInterval.at_endpt(
        level=common.LevelMarker.START, start_offset=0, end_offset=3
    )
    assert result.j == common.HorizontalInterval.compute_domain()
def apply(self, transform_data: TransformData):
    """Accumulate access extents per symbol and store them on the implementation IR.

    Blocks and their IJ blocks are walked in reverse order. For each IJ block the
    compute extent is the union of the extents already accumulated for its
    outputs; every input read in its interval blocks then widens the extent
    recorded for that input by ``compute_extent + read extent``. Finally the
    entry for the sequential axis is reset to ``(0, 0)`` for every symbol.

    Returns the same ``transform_data`` object, with
    ``implementation_ir.fields_extents`` set and each ``ij_block.compute_extent``
    updated.
    """
    domain = transform_data.definition_ir.domain
    seq_axis = domain.index(transform_data.definition_ir.domain.sequential_axis)

    access_extents = {name: Extent.zeros() for name in transform_data.symbols}

    for block in reversed(transform_data.blocks):
        for ij_block in reversed(block.ij_blocks):
            ij_block.compute_extent = Extent.zeros()
            for output_name in ij_block.outputs:
                ij_block.compute_extent |= access_extents[output_name]

            for interval_block in ij_block.interval_blocks:
                for input_name, read_extent in interval_block.inputs.items():
                    widened = ij_block.compute_extent + read_extent
                    access_extents[input_name] |= widened

    # Exclude sequential axis
    for name in access_extents:
        per_axis = list(access_extents[name])
        per_axis[seq_axis] = (0, 0)
        access_extents[name] = Extent(per_axis)

    fields_extents = {}
    for name, extent in access_extents.items():
        fields_extents[name] = Extent(extent)
    transform_data.implementation_ir.fields_extents = fields_extents

    return transform_data
def _make_stage(self, ij_block):
    """Build a ``gt_ir.Stage`` from ``ij_block``.

    One ``gt_ir.ApplyBlock`` is created per interval block, and one accessor per
    field the IJ block reads and/or writes. The stage is named
    ``stage__<ij_block.id>`` and carries the IJ block's ``compute_extent``.
    """
    # Apply blocks and decls
    apply_blocks = []
    # Declarations of names found in ``self.data.symbols``; collected here but
    # not read again inside this method.
    decls = []
    for interval_block in ij_block.interval_blocks:
        body_stmts = []
        local_symbols = {}
        for info in interval_block.stmts:
            stmt = info.stmt
            if not isinstance(stmt, gt_ir.Decl):
                body_stmts.append(stmt)
                continue
            if stmt.name in self.data.symbols:
                decls.append(stmt)
            else:
                # Declarations that are not known symbols must be local variables.
                assert isinstance(stmt, gt_ir.VarDecl)
                local_symbols[stmt.name] = stmt

        apply_blocks.append(
            gt_ir.ApplyBlock(
                interval=self._make_axis_interval(interval_block.interval),
                local_symbols=local_symbols,
                body=gt_ir.BlockStmt(stmts=body_stmts),
            )
        )

    # Accessors: inputs first (flagged read-write when the field is also an
    # output), then the outputs that are never read, with a zero extent.
    accessors = []
    write_only = set(ij_block.outputs)
    for name, extent in ij_block.inputs.items():
        read_write = name in write_only
        if read_write:
            write_only.remove(name)
            extent |= Extent.zeros()
        accessors.append(self._make_accessor(name, extent, read_write))

    zero_extent = Extent.zeros(self.data.ndims)
    accessors.extend(self._make_accessor(name, zero_extent, True) for name in write_only)

    return gt_ir.Stage(
        name="stage__{}".format(ij_block.id),
        accessors=accessors,
        apply_blocks=apply_blocks,
        compute_extent=ij_block.compute_extent,
    )
def nodes_extent_calculation(
    nodes: Collection[Union["VerticalLoopLibraryNode", "HorizontalExecutionLibraryNode"]]
) -> Dict[str, Extent]:
    """Return the union of access extents per field over all given library nodes.

    Vertical-loop nodes are expanded into the ``dace`` library nodes found
    (recursively) in their section SDFGs; any other node must be a
    ``HorizontalExecutionLibraryNode`` and is used directly. Nodes whose
    ``extent`` is ``None`` contribute nothing.
    """
    # Imported inside the function, as in the original code.
    # NOTE(review): presumably to avoid a circular import; confirm.
    from gtc.dace.nodes import HorizontalExecutionLibraryNode, VerticalLoopLibraryNode

    inner_nodes = []
    for node in nodes:
        if not isinstance(node, VerticalLoopLibraryNode):
            assert isinstance(node, HorizontalExecutionLibraryNode)
            inner_nodes.append(node)
            continue
        for _, section_sdfg in node.sections:
            inner_nodes.extend(
                candidate
                for candidate, _ in section_sdfg.all_nodes_recursive()
                if isinstance(candidate, dace.nodes.LibraryNode)
            )

    field_extents: Dict[str, Extent] = {}
    for inner in inner_nodes:
        access_collection = AccessCollector.apply(inner.oir_node)
        block_extent = inner.extent
        if block_extent is None:
            continue
        for acc in access_collection.ordered_accesses():
            # Union with a zero 2D extent before merging into the per-field result.
            offset_extent = acc.to_extent(block_extent) | Extent.zeros(2)
            field_extents.setdefault(acc.field, offset_extent)
            field_extents[acc.field] |= offset_extent

    return field_extents
def visit_HorizontalBlock(self, node: npir.HorizontalBlock, *, ctx: Context):
    """Record the block extent of ``node`` and widen the extents of the fields it touches.

    The block extent is the union of the extents already known for every field
    written by a ``VectorAssign`` in the block (fields without an entry count
    as zero). It is stored in ``ctx.block_extents`` under ``id(node)``. Each
    ``FieldSlice`` in the block then extends its field's entry in
    ``ctx.field_extents`` by ``block extent + slice extent``.
    """
    written_names = (
        node.iter_tree()
        .if_isinstance(npir.VectorAssign)
        .getattr("left")
        .if_isinstance(npir.FieldSlice)
        .getattr("name")
        .to_set()
    )

    block_extent = Extent.zeros()
    for written in written_names:
        block_extent = block_extent | ctx.field_extents.get(written, Extent.zeros())
    ctx.block_extents[id(node)] = block_extent

    for field_slice in node.iter_tree().if_isinstance(npir.FieldSlice).to_list():
        known = ctx.field_extents.get(field_slice.name, Extent.zeros())
        ctx.field_extents[field_slice.name] = known.union(
            block_extent + slice_to_extent(field_slice)
        )
class HorizontalBlockFactory(factory.Factory):
    """Factory for ``npir.HorizontalBlock`` objects with default field values."""

    class Meta:
        model = npir.HorizontalBlock

    # Default body: a list with one statement produced by ``VectorAssignFactory``.
    body = factory.List([factory.SubFactory(VectorAssignFactory)])
    # Default extent: zero extent with two dimensions.
    extent: Extent = Extent.zeros(ndims=2)
    # NOTE(review): this list literal is created once with the class body;
    # confirm it is never mutated through a built object.
    declarations: List[npir.LocalScalarDecl] = []
def visit_ParAssignStmt(
    self,
    node: gtir.ParAssignStmt,
    *,
    ctx: StencilContext,
    field_extents: FIELD_EXT_T,
    **kwargs: Any,
) -> None:
    """Merge the extents required by one assignment into ``field_extents``.

    The extent of the assigned field (zero if not yet present) seeds an
    ``AssignContext``. The conditions registered for this node in
    ``ctx.assign_conditions`` and then the right-hand side are visited with
    that context. The extents collected in ``pa_ctx.assign_extents`` are
    finally merged into ``field_extents``.
    """
    target_extent = field_extents.setdefault(node.left.name, Extent.zeros())
    pa_ctx = self.AssignContext(left_extent=target_extent)

    conditions = ctx.assign_conditions.get(id(node), [])
    self.visit(conditions, field_extents=field_extents, pa_ctx=pa_ctx, **kwargs)
    self.visit(node.right, field_extents=field_extents, pa_ctx=pa_ctx, **kwargs)

    for name, collected in pa_ctx.assign_extents.items():
        if name in field_extents:
            field_extents[name] |= collected
        else:
            field_extents[name] = collected
def visit_HorizontalExecution(self, node: oir.HorizontalExecution, *, ctx: Context) -> None:
    """Record the horizontal extent of ``node`` and widen the extents of the fields it reads.

    The horizontal extent is the union of the known 2D extents of all written
    fields (zero if unknown) and is stored in ``ctx.block_extents`` under
    ``id(node)``. For every read field, the union of the extents built from
    the first two components of its read offsets is added to the horizontal
    extent and merged into ``ctx.field_extents``.
    """
    accesses = AccessCollector.apply(node).cartesian_accesses()

    horizontal_extent = Extent.zeros(ndims=2)
    for written in accesses.write_fields():
        horizontal_extent = horizontal_extent | ctx.field_extents.get(
            written, Extent.zeros(ndims=2)
        )
    ctx.block_extents[id(node)] = horizontal_extent

    for name, offsets in accesses.read_offsets().items():
        read_extent = Extent.zeros(ndims=2)
        for offset in offsets:
            # Only the first two offset components are used.
            read_extent = read_extent | Extent.from_offset(offset[:2])
        known = ctx.field_extents.get(name, Extent.zeros(ndims=2))
        ctx.field_extents[name] = known.union(horizontal_extent + read_extent)
def visit(self, transform_data: TransformData) -> List[DomainBlockInfo]:
    """Visit every domain block of ``transform_data`` and return ``self._split_blocks``.

    A new context dictionary is built for each block, holding a zero extent of
    ``transform_data.ndims`` dimensions and the shared id generator.
    """
    for domain_block in transform_data.blocks:
        zero_extent = Extent.zeros(transform_data.ndims)
        id_generator = transform_data.id_generator
        self.visit_DomainBlockInfo(
            domain_block,
            {"zero_extent": zero_extent, "id_generator": id_generator},
        )

    return self._split_blocks
class FieldDeclFactory(factory.Factory):
    """Factory for ``npir.FieldDecl`` objects with default field values."""

    class Meta:
        model = npir.FieldDecl

    # Name produced by the ``identifier`` helper for ``npir.FieldDecl``.
    name = identifier(npir.FieldDecl)
    dimensions = (True, True, True)
    # Empty tuple by default; ``cast`` only informs the type checker.
    data_dims: Tuple[int] = cast(Tuple[int], tuple())
    # Default extent: zero extent with two dimensions.
    extent: Extent = Extent.zeros(ndims=2)
    dtype = common.DataType.FLOAT32
def has_reads_with_offset(self, *, restrict_to: Optional[Set[str]]) -> bool:
    """Return ``True`` if any considered input has a non-zero extent on the checked axes.

    All axes are checked when ``self.k_offset_extends_domain`` is truthy;
    otherwise the last axis is skipped. Only inputs that are also in
    ``restrict_to`` are considered when it is given.
    """
    if self.k_offset_extends_domain:
        checked_axes = slice(None)
    else:
        checked_axes = slice(None, -1)

    # NOTE(review): the truthiness test means an empty ``restrict_to`` set is
    # treated like ``None`` (all inputs are checked); confirm this is intended.
    if restrict_to:
        fields = restrict_to.intersection(self.inputs)
    else:
        fields = set(self.inputs)

    for name in fields:
        if self.inputs[name][checked_axes] != Extent.zeros()[checked_axes]:
            return True
    return False
def get_field_extents(self, node):
    """Return ``(input_extents, output_extents)`` for a horizontal execution library node.

    Each dictionary maps a field name to the union of the extents of its read
    (respectively write) accesses, evaluated against the node's block extent.
    The node must be a ``HorizontalExecutionLibraryNode`` with a non-``None``
    extent.
    """
    assert isinstance(node, HorizontalExecutionLibraryNode)
    block_extent: Extent = node.extent
    assert block_extent is not None

    collection = self._get_access_collection(node)

    def _union_by_field(accesses):
        # Fold the accesses into one extent per field.
        merged = {}
        for acc in accesses:
            extent = acc.to_extent(block_extent) | Extent.zeros(2)
            merged.setdefault(acc.field, extent)
            merged[acc.field] |= extent
        return merged

    input_extents = _union_by_field(collection.read_accesses())
    output_extents = _union_by_field(collection.write_accesses())
    return input_extents, output_extents
def visit_FieldDecl(
    self, node: oir.FieldDecl, *, field_extents: Dict[str, Extent], **kwargs: Any
) -> npir.FieldDecl:
    """Convert an ``oir.FieldDecl`` into an ``npir.FieldDecl``.

    Name, dtype, dimensions and data dimensions are copied from ``node``. The
    extent is looked up in ``field_extents`` and falls back to a zero 2D
    extent when the field has no entry.
    """
    fallback = Extent.zeros(ndims=2)
    declared_extent = field_extents.get(node.name, fallback)
    return npir.FieldDecl(
        name=node.name,
        dtype=node.dtype,
        dimensions=node.dimensions,
        data_dims=node.data_dims,
        extent=declared_extent,
    )
def visit_Stencil(self, node: gtir.Stencil, **kwargs: Any) -> FIELD_EXT_T:
    """Compute the extent of every field in the stencil and return the mapping.

    All field names start with a zero extent. ``FieldIfStmt`` nodes are visited
    first with a fresh ``StencilContext``; the assignments are then visited in
    reverse order, updating the extents.
    """
    field_extents = {}
    for name in _iter_field_names(node):
        field_extents[name] = Extent.zeros()

    ctx = self.StencilContext()
    for field_if in node.iter_tree().if_isinstance(gtir.FieldIfStmt):
        self.visit(field_if, ctx=ctx)

    assigns = _iter_assigns(node).to_list()
    for assign in reversed(assigns):
        self.visit(assign, ctx=ctx, field_extents=field_extents)

    return field_extents
def visit_HorizontalBlock(
    self,
    node: npir.HorizontalBlock,
    *,
    block_extents: Dict[int, HorizontalExtent] = None,
    **kwargs: Any,
) -> Union[str, Collection[str]]:
    """Visit a horizontal block, passing its lower and upper i/j bounds to ``generic_visit``.

    The block's extent is looked up in ``block_extents`` by ``id(node)`` and
    falls back to a zero extent when the mapping is missing or has no entry.
    Only the first two axes of the resulting boundary are used.
    """
    known_extents = block_extents or {}
    ij_extent: Extent = known_extents.get(id(node), Extent.zeros())

    boundary = ij_extent.to_boundary()
    i_bounds = boundary[0]
    j_bounds = boundary[1]
    lower = (i_bounds[0], j_bounds[0])
    upper = (i_bounds[1], j_bounds[1])

    return self.generic_visit(node, lower=lower, upper=upper, **kwargs)
def visit_FieldAccess(
    self,
    node: gtir.FieldAccess,
    *,
    field_extents: FIELD_EXT_T,
    pa_ctx: AssignContext,
    **kwargs: Any,
) -> None:
    """Widen the accessed field's extent in ``pa_ctx.assign_extents``.

    The entry is seeded with the field's value in ``field_extents`` (zero if
    absent, which also inserts it there) and then merged with the assignment's
    left extent shifted by the access offset.
    """
    name = node.name
    seed_extent = field_extents.setdefault(name, Extent.zeros())
    pa_ctx.assign_extents.setdefault(name, seed_extent)

    shifted = pa_ctx.left_extent + _ext_from_off(node.offset)
    pa_ctx.assign_extents[name] |= shifted
def get_field_extents(self, node):
    """Return ``(input_extents, output_extents)`` for a vertical loop library node.

    Only the ``HorizontalExecutionLibraryNode`` nodes in the first state of
    each section SDFG are considered. Each must have a non-``None`` extent.
    The extents of its read and write accesses are merged per field into the
    two result dictionaries.
    """
    assert isinstance(node, VerticalLoopLibraryNode)

    def _merge_into(target, accesses, block_extent):
        # Union each access extent into ``target`` under the field name.
        for acc in accesses:
            extent = acc.to_extent(block_extent) | Extent.zeros(2)
            target.setdefault(acc.field, extent)
            target[acc.field] |= extent

    input_extents = {}
    output_extents = {}
    for _, sdfg in node.sections:
        for candidate in sdfg.states()[0].nodes():
            if isinstance(candidate, HorizontalExecutionLibraryNode):
                block_extent = candidate.extent
                assert block_extent is not None
                collection = self._get_access_collection(candidate)
                _merge_into(input_extents, collection.read_accesses(), block_extent)
                _merge_into(output_extents, collection.write_accesses(), block_extent)

    return input_extents, output_extents
def visit_Temporary(
    self, node: oir.Temporary, *, field_extents: Dict[str, Extent], **kwargs: Any
) -> npir.TemporaryDecl:
    """Convert an ``oir.Temporary`` into an ``npir.TemporaryDecl``.

    The temporary's extent is taken from ``field_extents`` and unioned with a
    zero 2D extent. Per axis, ``offset`` is the negated lower bound (asserted
    to be non-negative) and ``padding`` is the upper bound minus the lower
    bound.

    Raises ``KeyError`` if ``node.name`` has no entry in ``field_extents``.
    """
    temp_extent = field_extents[node.name] | Extent.zeros(ndims=2)

    offset = [-bounds[0] for bounds in temp_extent]
    assert all(off >= 0 for off in offset)

    padding = [bounds[1] - bounds[0] for bounds in temp_extent]

    return npir.TemporaryDecl(
        name=node.name,
        dtype=node.dtype,
        data_dims=node.data_dims,
        offset=offset,
        padding=padding,
    )
def to_extent(self, horizontal_extent: Extent) -> Optional[Extent]:
    """
    Convert the access to an extent provided a horizontal extent for the access.

    For an access without a horizontal mask, the result is the horizontal
    extent shifted by the access offset (first two offset components only).

    For a masked access, the overlap between the mask and the horizontal
    extent is subtracted from the horizontal extent before the offset is
    applied, and the result is unioned with a zero 2D extent.

    This returns None if no overlap exists between the horizontal mask and interval.
    """
    offset_as_extent = Extent.from_offset(cast(Tuple[int, int, int], self.offset)[:2])

    if not self.horizontal_mask:
        # Previously this path fell off the end of the function and returned
        # None, which is documented to mean "no overlap" and only makes sense
        # for masked accesses.
        return horizontal_extent + offset_as_extent

    dist_from_edge = mask_overlap_with_extent(self.horizontal_mask, horizontal_extent)
    if not dist_from_edge:
        return None

    zeros = Extent.zeros(ndims=2)
    return ((horizontal_extent - dist_from_edge) + offset_as_extent) | zeros
def visit_Stencil(
    self, node: gtir.Stencil, *, mask_inwards: bool, **kwargs: Any
) -> FIELD_EXT_T:
    """Compute the extent of every field in the stencil and return the mapping.

    ``FieldIfStmt`` nodes are visited first with a fresh ``StencilContext``;
    the assignments are then visited in reverse order. Fields that did not
    receive an extent get a zero extent afterwards. When ``mask_inwards`` is
    true, every lower bound is clamped to at most 0 and every upper bound to
    at least 0.
    """
    field_extents: FIELD_EXT_T = {}
    ctx = self.StencilContext()

    for field_if in node.iter_tree().if_isinstance(gtir.FieldIfStmt):
        self.visit(field_if, ctx=ctx)

    assigns = _iter_assigns(node).to_list()
    for assign in reversed(assigns):
        self.visit(assign, ctx=ctx, field_extents=field_extents)

    for name in _iter_field_names(node):
        # ensure we have an extent for all fields. note that we do not initialize to zero
        # in the beginning as this breaks inward pointing extents (i.e. negative boundaries).
        field_extents.setdefault(name, Extent.zeros())
        if not mask_inwards:
            continue
        # set inward pointing extents to zero
        clamped = [(min(0, bounds[0]), max(0, bounds[1])) for bounds in field_extents[name]]
        field_extents[name] = Extent(*clamped)

    return field_extents
def visit_HorizontalBlock(
    self, node: npir.HorizontalBlock, **kwargs
) -> Union[str, Collection[str]]:
    """Visit a horizontal block, passing lower and upper i/j bounds to ``generic_visit``.

    When a non-empty ``field_extents`` mapping is supplied in ``kwargs``, the
    bounds are the minimum, over all fields referenced by a ``FieldSlice`` in
    the block, of the corresponding boundary entries. Fields missing from the
    mapping are inserted with a zero extent. Without extents, or when the
    block references no field, both bounds stay ``[0, 0]``.
    """
    lower, upper = [0, 0], [0, 0]

    if extents := kwargs.get("field_extents"):
        fields = set(node.iter_tree().if_isinstance(npir.FieldSlice).getattr("name"))
        for field in fields:
            # The extent of masks has not yet been collected but is always zero.
            extents.setdefault(field, Extent.zeros())

        # ``min`` raises ValueError on an empty iterable, so keep the zero
        # bounds for a block that references no field.
        if fields:
            # Compute each boundary once instead of once per bound.
            boundaries = [extents[field].to_boundary() for field in fields]
            lower[0] = min(boundary[0][0] for boundary in boundaries)
            lower[1] = min(boundary[1][0] for boundary in boundaries)
            # NOTE(review): the upper bounds also use ``min``, as in the
            # original code; confirm this matches the sign convention of
            # ``to_boundary``.
            upper[0] = min(boundary[0][1] for boundary in boundaries)
            upper[1] = min(boundary[1][1] for boundary in boundaries)

    # The computed bounds were previously discarded and the method returned None.
    return self.generic_visit(node, lower=lower, upper=upper, **kwargs)
def visit_HorizontalExecution(
    self,
    node: oir.HorizontalExecution,
    *,
    block_extents: Optional[Dict[int, Extent]] = None,
    **kwargs: Any,
) -> npir.HorizontalBlock:
    """Convert an ``oir.HorizontalExecution`` into an ``npir.HorizontalBlock``.

    The block extent is read from ``block_extents`` under ``id(node)`` when
    that mapping is given and non-empty; otherwise a zero 2D extent is used.
    The body is visited with that extent and flattened; the declarations are
    visited separately.

    Raises ``KeyError`` if a non-empty ``block_extents`` has no entry for the node.
    """
    extent = block_extents[id(node)] if block_extents else Extent.zeros(ndims=2)

    visited_body = self.visit(node.body, extent=extent, **kwargs)
    stmts = utils.flatten_list(visited_body)
    declarations = self.visit(node.declarations, **kwargs)

    return npir.HorizontalBlock(body=stmts, extent=extent, declarations=declarations)
def apply(self, transform_data: TransformData):
    """Split every PARALLEL domain block into one domain block per statement.

    Each statement of a parallel block gets its own interval block, IJ block
    (with a zero compute extent) and domain block, each receiving a value read
    from ``transform_data.id_generator.new``. Blocks with any other iteration
    order are kept unchanged. ``transform_data.blocks`` is replaced with the
    new list and ``transform_data`` is returned.
    """
    zero_extent = Extent.zeros(transform_data.ndims)
    result_blocks = []

    for block in transform_data.blocks:
        is_parallel = block.iteration_order == gt_ir.IterationOrder.PARALLEL
        if not is_parallel:
            result_blocks.append(block)
            continue

        # Put every statement in a single stage
        statements = (
            (interval_block.interval, stmt_info)
            for ij_block in block.ij_blocks
            for interval_block in ij_block.interval_blocks
            for stmt_info in interval_block.stmts
        )
        for interval, stmt_info in statements:
            # ``id_generator.new`` is read in this order: interval block,
            # IJ block, domain block.
            single_interval_block = IntervalBlockInfo(
                transform_data.id_generator.new,
                interval,
                [stmt_info],
                stmt_info.inputs,
                stmt_info.outputs,
            )
            single_ij_block = IJBlockInfo(
                transform_data.id_generator.new,
                {interval},
                [single_interval_block],
                {**single_interval_block.inputs},
                set(single_interval_block.outputs),
                compute_extent=zero_extent,
            )
            single_domain_block = DomainBlockInfo(
                transform_data.id_generator.new,
                block.iteration_order,
                set(single_ij_block.intervals),
                [single_ij_block],
                {**single_ij_block.inputs},
                set(single_ij_block.outputs),
            )
            result_blocks.append(single_domain_block)

    transform_data.blocks = result_blocks
    return transform_data
def test_mask_stmt_to_assigns() -> None:
    """A mask statement with one assignment is visited into a single statement.

    The right-hand side condition of that statement must be an
    ``npir.FieldSlice``.
    """
    single_assign_body = [AssignStmtFactory()]
    mask_stmt = MaskStmtFactory(body=single_assign_body)

    lowered = OirToNpir().visit(mask_stmt, extent=Extent.zeros(ndims=2))

    first_stmt = lowered[0]
    assert isinstance(first_stmt.right.cond, npir.FieldSlice)
    assert len(lowered) == 1
def accumulate_extents(extents: Sequence[Extent]) -> Extent:
    """Return the union of all ``extents``, starting from a zero extent.

    An empty sequence yields ``Extent.zeros()``.
    """
    combined = Extent.zeros()
    for item in extents:
        combined |= item
    return combined
def test_mask_propagation() -> None:
    """The first visited statement's condition equals the separately visited mask."""
    mask_stmt = MaskStmtFactory()

    lowered = OirToNpir().visit(mask_stmt, extent=Extent.zeros(ndims=2))

    actual_cond = lowered[0].right.cond
    expected_cond = OirToNpir().visit(mask_stmt.mask)
    assert actual_cond == expected_cond
def __init__(self, add_k: bool = False):
    """Store the ``add_k`` flag and a zero 2D extent on the instance."""
    self.zero_extent = Extent.zeros(ndims=2)
    self.add_k = add_k
def __init__(self, transform_data: TransformData, computation_intervals: list):
    """Store the transform data and computation intervals on the instance.

    Also sets ``current_block_info`` to ``None`` and prepares a zero extent
    with ``transform_data.ndims`` dimensions.
    """
    self.data = transform_data
    self.zero_extent = Extent.zeros(transform_data.ndims)
    self.computation_intervals = computation_intervals
    # No block is being processed yet.
    self.current_block_info = None
class FieldAccessor(Accessor):
    """Accessor with a symbol name, an access intent and an extent."""

    symbol = attribute(of=str)
    intent = attribute(of=AccessKind)
    # The default is one ``Extent.zeros()`` object, created when the class body runs.
    extent = attribute(of=Extent, default=Extent.zeros())
def has_extended_domain(self) -> bool:
    """Return ``True`` when ``self.full_extent`` differs from the zero extent."""
    current = self.full_extent
    zero = Extent.zeros()
    return current != zero