def test_field_access_to_field_slice(parallel_k):
    """Lower an OIR field access and check the resulting slice's offsets.

    `parallel_k` is presumably supplied by a pytest fixture/parametrization
    defined elsewhere in this test module.
    """
    access = oir.FieldAccess(
        name="a",
        offset=common.CartesianOffset(i=-1, j=2, k=0),
        dtype=common.DataType.FLOAT64,
    )
    context = OirToNpir.ComputationContext()
    field_slice = OirToNpir().visit(access, ctx=context, parallel_k=parallel_k)

    # The k-axis parallelism flag must be propagated verbatim (identity check).
    assert field_slice.k_offset.parallel is parallel_k
    # The i-offset of the original access must survive the lowering.
    assert field_slice.i_offset.offset.value == -1
def transform_offset(
    self, offset: Dict[str, Union[int, Expr]], **kwargs: Any
) -> Union[common.CartesianOffset, gtir.VariableKOffset]:
    """Convert an axis-name -> offset mapping into an offset node.

    Args:
        offset: mapping whose optional keys "I", "J", "K" hold the per-axis
            offsets; missing axes default to 0.
        **kwargs: forwarded to `self.visit` when the vertical offset is an
            expression.

    Returns:
        A `common.CartesianOffset` when the "K" entry is an integral value,
        or a `gtir.VariableKOffset` wrapping the visited expression when it
        is an `Expr`.

    Raises:
        TypeError: if the "K" entry is neither integral nor an `Expr`.
    """
    vertical = offset.get("K", 0)

    if isinstance(vertical, numbers.Integral):
        return common.CartesianOffset(
            i=offset.get("I", 0),
            j=offset.get("J", 0),
            k=vertical,
        )

    if not isinstance(vertical, Expr):
        raise TypeError("Unrecognized vertical offset type")

    # NOTE(review): any "I"/"J" entries are not used on this path -- confirm
    # that a variable K offset never needs a horizontal component.
    return gtir.VariableKOffset(k=self.visit(vertical, **kwargs))
def visit_CartesianOffset(
    self,
    node: common.CartesianOffset,
    *,
    shift: Optional[Tuple[int, int, int]] = None,
    **kwargs: Any,
) -> common.CartesianOffset:
    """Return `node`, optionally translated by a per-axis shift.

    Args:
        node: the offset being visited.
        shift: optional `(di, dj, dk)` added component-wise to `node`.
        **kwargs: forwarded to `generic_visit` only when no shift applies.

    Returns:
        A new `common.CartesianOffset` with the shift applied, or the result
        of `generic_visit` when `shift` is falsy (None or empty).
    """
    if not shift:
        return self.generic_visit(node, **kwargs)

    delta_i, delta_j, delta_k = shift
    return common.CartesianOffset(
        i=node.i + delta_i,
        j=node.j + delta_j,
        k=node.k + delta_k,
    )
def test_assign_stmt_to_vector_assign(parallel_k):
    """Lower an OIR assignment and check it becomes an NPIR vector assign.

    Both sides of the resulting assignment must carry the requested k-axis
    parallelism flag. `parallel_k` presumably comes from a pytest
    fixture/parametrization defined elsewhere in this module.
    """
    f64 = common.DataType.FLOAT64
    stmt = oir.AssignStmt(
        left=oir.FieldAccess(
            name="a", offset=common.CartesianOffset.zero(), dtype=f64
        ),
        right=oir.FieldAccess(
            name="b", offset=common.CartesianOffset(i=-1, j=22, k=0), dtype=f64
        ),
    )
    context = OirToNpir.ComputationContext()
    lowered = OirToNpir().visit(stmt, ctx=context, parallel_k=parallel_k, mask=None)

    assert isinstance(lowered, npir.VectorAssign)
    for side in (lowered.left, lowered.right):
        assert side.k_offset.parallel is parallel_k
def _fill_stmts(
    cls,
    loop_order: common.LoopOrder,
    section: oir.VerticalLoopSection,
    filling_fields: Dict[str, str],
    first_unfilled: Dict[str, int],
    symtable: Dict[str, Any],
) -> Tuple[List[oir.AssignStmt], Dict[str, int]]:
    """Build the cache-fill assignments needed by one vertical loop section.

    For every (field, cache) pair, one `cache[k+o] = field[k+o]` assignment
    is emitted per offset `o` in the field's fill range, skipping offsets
    already recorded as filled in `first_unfilled`. Offsets are stored
    direction-normalized and negated on output for non-forward loops.

    Args:
        loop_order: forward or backward order.
        section: loop section whose fill limits are computed.
        filling_fields: mapping from field names to cache names.
        first_unfilled: direction-normalized offset of the first unfilled
            cache entry for each field. Updated in place.
        symtable: symbol table used to look up each field's dtype.

    Returns:
        The list of fill statements, and the (same, mutated) `first_unfilled`
        mapping.
    """
    limits_by_field = cls._fill_limits(loop_order, section)
    statements = []

    for field_name, cache_name in filling_fields.items():
        low, high = limits_by_field.get(field_name, (0, 0))
        # Do not re-fill entries an earlier section already filled.
        low = max(low, first_unfilled.get(field_name, low))

        for rel in range(low, high + 1):
            # Normalized offset -> actual k offset for this loop direction.
            k = rel if loop_order == common.LoopOrder.FORWARD else -rel
            shifted = common.CartesianOffset(i=0, j=0, k=k)
            # dtype looked up per statement so an empty range never touches
            # the symbol table.
            dtype = symtable[field_name].dtype
            statements.append(
                oir.AssignStmt(
                    left=oir.FieldAccess(name=cache_name, dtype=dtype, offset=shifted),
                    right=oir.FieldAccess(name=field_name, dtype=dtype, offset=shifted),
                )
            )

        # NOTE(review): records `high` rather than `high + 1`; presumably this
        # accounts for the one-level shift before the next section -- confirm.
        first_unfilled[field_name] = high

    return statements, first_unfilled
def test_temp_assign(parallel_k):
    """Assigning to a temporary records exactly one temp definition.

    The recorded definition for "a" must pair a `npir.VectorTemp` target with
    an `npir.EmptyTemp` initializer. `parallel_k` presumably comes from a
    pytest fixture/parametrization defined elsewhere in this module.
    """
    f64 = common.DataType.FLOAT64
    stmt = oir.AssignStmt(
        left=oir.FieldAccess(
            name="a",
            offset=common.CartesianOffset.zero(),
            dtype=f64,
        ),
        right=oir.FieldAccess(
            name="b",
            offset=common.CartesianOffset(i=-1, j=22, k=0),
            dtype=f64,
        ),
    )
    context = OirToNpir.ComputationContext()
    # Only the side effect on `context` matters; the return value is unused.
    OirToNpir().visit(
        stmt,
        ctx=context,
        parallel_k=parallel_k,
        mask=None,
        symtable={"a": TemporaryFactory(name="a")},
    )

    assert len(context.temp_defs) == 1
    temp_def = context.temp_defs["a"]
    assert isinstance(temp_def.left, npir.VectorTemp)
    assert isinstance(temp_def.right, npir.EmptyTemp)