def _flush_stmts(
    cls,
    loop_order: common.LoopOrder,
    section: oir.VerticalLoopSection,
    flushing_fields: Dict[str, str],
    symtable: Dict[str, Any],
) -> List[oir.AssignStmt]:
    """Build the statements that copy cached values back into their fields.

    A flush statement is emitted only for fields that are written inside
    ``section``; all other entries of ``flushing_fields`` are skipped.

    Args:
        loop_order: forward or backward order (not consulted by this method).
        section: loop section whose written fields are collected.
        flushing_fields: mapping from field names to cache names.
        symtable: lookup from field name to an object exposing ``dtype``.

    Returns:
        One ``field = cache`` assignment per flushed field, both sides at
        zero offset, in the iteration order of ``flushing_fields``.
    """
    written = AccessCollector.apply(section).write_fields()

    def copy_back(field_name: str, cache_name: str) -> oir.AssignStmt:
        # Both accesses use the dtype registered for the *field*.
        target = oir.FieldAccess(
            name=field_name,
            dtype=symtable[field_name].dtype,
            offset=common.CartesianOffset.zero(),
        )
        source = oir.FieldAccess(
            name=cache_name,
            dtype=symtable[field_name].dtype,
            offset=common.CartesianOffset.zero(),
        )
        return oir.AssignStmt(left=target, right=source)

    return [
        copy_back(field_name, cache_name)
        for field_name, cache_name in flushing_fields.items()
        if field_name in written
    ]
def visit_ParAssignStmt(
    self, node: gtir.ParAssignStmt, *, mask: oir.Expr = None, **kwargs: Any
) -> oir.AssignStmt:
    """Lower a parallel assignment, optionally guarded by ``mask``.

    Args:
        node: the assignment to lower; its ``left`` and ``right`` are visited.
        mask: when not ``None``, the condition under which the assignment runs.
        **kwargs: accepted for visitor compatibility; not used here.

    Returns:
        The plain ``oir.AssignStmt`` when no mask is given. Otherwise an
        ``oir.MaskStmt`` wrapping that assignment (note that this differs
        from the annotated return type).
    """
    assignment = oir.AssignStmt(left=self.visit(node.left), right=self.visit(node.right))
    if mask is None:
        return assignment
    # Masked case: the assignment becomes the sole body statement of a MaskStmt.
    return oir.MaskStmt(body=[assignment], mask=mask, loc=node.loc)
def visit_FieldIfStmt(
    self, node: gtir.FieldIfStmt, *, mask: oir.Expr = None, ctx: Context, **kwargs: Any
) -> List[oir.Stmt]:
    """Lower a field-valued ``if`` into a mask temporary plus masked branches.

    The condition is stored into a freshly declared boolean temporary
    (registered in ``ctx.temp_fields``). The true branch is visited with that
    temporary as mask, the false branch (if present) with its negation. An
    enclosing ``mask`` is AND-combined with the branch mask.

    Args:
        node: the ``if`` statement to lower.
        mask: mask inherited from an enclosing conditional, or ``None``.
        ctx: lowering context; receives the new mask temporary declaration.
        **kwargs: forwarded unchanged to the visits of the branch bodies.

    Returns:
        The mask-filling assignment followed by the statements produced for
        the true branch and, if present, the false branch.
    """
    # The temporary's name is derived from the node's identity.
    mask_field_decl = oir.Temporary(
        name=f"mask_{id(node)}", dtype=DataType.BOOL, dimensions=(True, True, True)
    )
    ctx.temp_fields.append(mask_field_decl)

    # First statement: evaluate the condition into the mask temporary.
    stmts = [
        oir.AssignStmt(
            left=oir.FieldAccess(
                name=mask_field_decl.name,
                offset=CartesianOffset.zero(),
                dtype=DataType.BOOL,
                loc=node.loc,
            ),
            right=self.visit(node.cond),
        )
    ]

    current_mask = oir.FieldAccess(
        name=mask_field_decl.name,
        offset=CartesianOffset.zero(),
        dtype=mask_field_decl.dtype,
        loc=node.loc,
    )

    combined_mask = current_mask
    # Explicit None check: an absent mask is signalled by None, and the result
    # must not depend on the truth value of the expression node itself.
    if mask is not None:
        combined_mask = oir.BinaryOp(
            op=LogicalOperator.AND, left=mask, right=combined_mask, loc=node.loc
        )
    stmts.extend(self.visit(node.true_branch.body, mask=combined_mask, ctx=ctx, **kwargs))

    if node.false_branch:
        # The false branch executes where the condition mask does not hold.
        combined_mask = oir.UnaryOp(op=UnaryOperator.NOT, expr=current_mask)
        if mask is not None:
            combined_mask = oir.BinaryOp(
                op=LogicalOperator.AND, left=mask, right=combined_mask, loc=node.loc
            )
        stmts.extend(self.visit(node.false_branch.body, mask=combined_mask, ctx=ctx, **kwargs))

    return stmts
def visit_ParAssignStmt(
    self, node: gtir.ParAssignStmt, *, mask: oir.Expr = None, ctx: Context, **kwargs: Any
) -> None:
    """Lower a parallel assignment into its own horizontal execution.

    Args:
        node: the assignment to lower; its ``left`` and ``right`` are visited.
        mask: when not ``None``, the assignment is wrapped in an ``oir.MaskStmt``.
        ctx: lowering context that receives the new horizontal execution.
        **kwargs: accepted for visitor compatibility; not used here.

    Returns:
        Nothing; the result is registered via ``ctx.add_horizontal_execution``.
    """
    assignment = oir.AssignStmt(left=self.visit(node.left), right=self.visit(node.right))
    statement = (
        assignment if mask is None else oir.MaskStmt(body=[assignment], mask=mask)
    )
    execution = oir.HorizontalExecution(body=[statement], declarations=[])
    ctx.add_horizontal_execution(execution)
def test_assign_stmt_to_vector_assign(parallel_k):
    """An OIR assignment lowers to a VectorAssign carrying the ``parallel_k`` flag."""
    float64 = common.DataType.FLOAT64
    target = oir.FieldAccess(name="a", offset=common.CartesianOffset.zero(), dtype=float64)
    source = oir.FieldAccess(
        name="b", offset=common.CartesianOffset(i=-1, j=22, k=0), dtype=float64
    )
    stmt = oir.AssignStmt(left=target, right=source)

    context = OirToNpir.ComputationContext()
    lowered = OirToNpir().visit(stmt, ctx=context, parallel_k=parallel_k, mask=None)

    assert isinstance(lowered, npir.VectorAssign)
    # Both sides of the assignment must reflect the requested k-parallelism.
    assert lowered.left.k_offset.parallel is parallel_k
    assert lowered.right.k_offset.parallel is parallel_k
def _create_mask(ctx: "GTIRToOIR.Context", name: str, cond: oir.Expr) -> oir.Temporary:
    """Declare a boolean mask temporary and emit the statement that fills it.

    Args:
        ctx: lowering context; receives the declaration and then a horizontal
            execution assigning ``cond`` to the temporary.
        name: name of the temporary to declare.
        cond: expression whose value is stored into the mask.

    Returns:
        The declaration of the newly created mask temporary.
    """
    decl = oir.Temporary(name=name, dtype=DataType.BOOL, dimensions=(True, True, True))
    ctx.add_decl(decl)

    # Write the condition into the mask at zero offset.
    mask_access = oir.FieldAccess(
        name=decl.name,
        offset=CartesianOffset.zero(),
        dtype=decl.dtype,
    )
    fill = oir.AssignStmt(left=mask_access, right=cond)
    ctx.add_horizontal_execution(oir.HorizontalExecution(body=[fill], declarations=[]))
    return decl
def _fill_stmts(
    cls,
    loop_order: common.LoopOrder,
    section: oir.VerticalLoopSection,
    filling_fields: Dict[str, str],
    first_unfilled: Dict[str, int],
    symtable: Dict[str, Any],
) -> Tuple[List[oir.AssignStmt], Dict[str, int]]:
    """Build the statements that load field values into their caches.

    For every field, one ``cache = field`` assignment is emitted per
    direction-normalized offset between the effective lower limit and the
    upper fill limit (inclusive). Offsets are negated for non-forward loops.

    Args:
        loop_order: forward or backward order.
        section: loop section passed to ``_fill_limits``.
        filling_fields: mapping from field names to cache names.
        first_unfilled: direction-normalized offset of the first unfilled
            cache entry for each field. Updated in place.
        symtable: lookup from field name to an object exposing ``dtype``.

    Returns:
        The list of fill statements and the (same, mutated) ``first_unfilled``
        mapping.
    """
    limits = cls._fill_limits(loop_order, section)
    forward = loop_order == common.LoopOrder.FORWARD

    statements: List[oir.AssignStmt] = []
    for field, cache in filling_fields.items():
        lower, upper = limits.get(field, (0, 0))
        # Never start below what is already recorded for this field.
        lower = max(lower, first_unfilled.get(field, lower))

        for distance in range(lower, upper + 1):
            shift = common.CartesianOffset(i=0, j=0, k=distance if forward else -distance)
            cache_access = oir.FieldAccess(
                name=cache, dtype=symtable[field].dtype, offset=shift
            )
            field_access = oir.FieldAccess(
                name=field, dtype=symtable[field].dtype, offset=shift
            )
            statements.append(oir.AssignStmt(left=cache_access, right=field_access))

        # NOTE(review): records `upper`, not `upper + 1`, even though `upper`
        # was just filled; presumably the caller shifts the cache between
        # sections -- confirm against the call site.
        first_unfilled[field] = upper

    return statements, first_unfilled
def test_temp_assign(parallel_k):
    """Assigning to a temporary registers exactly one temp definition for it."""
    float64 = common.DataType.FLOAT64
    target = oir.FieldAccess(name="a", offset=common.CartesianOffset.zero(), dtype=float64)
    source = oir.FieldAccess(
        name="b", offset=common.CartesianOffset(i=-1, j=22, k=0), dtype=float64
    )
    stmt = oir.AssignStmt(left=target, right=source)

    context = OirToNpir.ComputationContext()
    # The return value is irrelevant here; only the context side effect is checked.
    OirToNpir().visit(
        stmt,
        ctx=context,
        parallel_k=parallel_k,
        mask=None,
        symtable={"a": TemporaryFactory(name="a")},
    )

    assert len(context.temp_defs) == 1
    temp_def = context.temp_defs["a"]
    assert isinstance(temp_def.left, npir.VectorTemp)
    assert isinstance(temp_def.right, npir.EmptyTemp)