def visit_VerticalLoop(self, node: gtir.VerticalLoop, *, ctx: Context) -> oir.VerticalLoop:
    """Lower a GTIR vertical loop into an OIR vertical loop with exactly one section.

    Every top-level statement of ``node.body`` is lowered into its own
    ``oir.HorizontalExecution``. The local scalars that ``ctx`` holds right after
    that statement has been visited become the execution's declarations.
    The loop's temporaries are appended to ``ctx.temp_fields`` as ``oir.Temporary``.

    Returns:
        An ``oir.VerticalLoop`` with the same loop order and a single section that
        spans the visited ``node.interval``.
    """
    executions: List[oir.HorizontalExecution] = []
    for body_stmt in node.body:
        # Each horizontal execution starts from an empty set of local scalars.
        ctx.reset_local_scalars()
        lowered = self.visit(body_stmt, ctx=ctx)
        if isinstance(lowered, oir.Stmt):
            # A single statement is wrapped so flatten_list always receives a sequence.
            lowered = [lowered]
        executions.append(
            oir.HorizontalExecution(
                body=utils.flatten_list(lowered),
                declarations=ctx.local_scalars,
            )
        )

    new_temporaries = [
        oir.Temporary(name=tmp.name, dtype=tmp.dtype, dimensions=tmp.dimensions)
        for tmp in node.temporaries
    ]
    ctx.temp_fields += new_temporaries

    only_section = oir.VerticalLoopSection(
        interval=self.visit(node.interval),
        horizontal_executions=executions,
        loc=node.loc,
    )
    return oir.VerticalLoop(loop_order=node.loop_order, sections=[only_section])
def visit_FieldIfStmt(
    self,
    node: gtir.FieldIfStmt,
    *,
    mask: oir.Expr = None,
    ctx: Context,
    **kwargs: Any,
) -> List[oir.Stmt]:
    """Lower a field-conditional ``if`` into a mask temporary plus masked statements.

    The condition is evaluated once into a new 3D boolean temporary named
    ``mask_<id(node)>``, which is registered in ``ctx.temp_fields``. The true
    branch is visited under ``mask AND cond_mask`` and the optional false branch
    under ``mask AND NOT cond_mask``. When no enclosing ``mask`` is given, the
    ``AND`` with it is omitted.

    NOTE(review): the temporary name is derived from ``id(node)``, which is not
    stable across interpreter runs.

    Returns:
        The mask assignment followed by whatever visiting both branch bodies returned.
    """
    mask_field_decl = oir.Temporary(
        name=f"mask_{id(node)}", dtype=DataType.BOOL, dimensions=(True, True, True)
    )
    ctx.temp_fields.append(mask_field_decl)

    def mask_access() -> oir.FieldAccess:
        # Fresh zero-offset access to the mask temporary.
        return oir.FieldAccess(
            name=mask_field_decl.name,
            offset=CartesianOffset.zero(),
            dtype=mask_field_decl.dtype,
            loc=node.loc,
        )

    def restrict_to_outer(branch_mask: oir.Expr) -> oir.Expr:
        # Nested conditionals may only execute where the enclosing mask also holds.
        if mask is None:
            return branch_mask
        return oir.BinaryOp(op=LogicalOperator.AND, left=mask, right=branch_mask, loc=node.loc)

    stmts = [oir.AssignStmt(left=mask_access(), right=self.visit(node.cond), loc=node.loc)]
    current_mask = mask_access()

    stmts.extend(
        self.visit(
            node.true_branch.body, mask=restrict_to_outer(current_mask), ctx=ctx, **kwargs
        )
    )
    if node.false_branch:
        negated_mask = oir.UnaryOp(op=UnaryOperator.NOT, expr=current_mask, loc=node.loc)
        stmts.extend(
            self.visit(
                node.false_branch.body, mask=restrict_to_outer(negated_mask), ctx=ctx, **kwargs
            )
        )
    return stmts
def visit_VerticalLoop(
    self, node: gtir.VerticalLoop, *, ctx: Context, **kwargs: Any
) -> oir.VerticalLoop:
    """Lower a GTIR vertical loop into an OIR vertical loop with a single section.

    The horizontal executions are collected in ``ctx.horizontal_executions`` while
    ``node.body`` is visited. Presumably the statement visitors append to it through
    ``ctx``; confirm against the statement visitors. Each temporary of the loop is
    declared through ``ctx.add_decl``, and its dimensions are preserved.

    Returns:
        An ``oir.VerticalLoop`` with the same loop order, one section over the
        visited ``node.interval`` and no caches.
    """
    # The context list is reused per vertical loop, so start from an empty one.
    ctx.horizontal_executions.clear()
    self.visit(node.body, ctx=ctx)

    for temp in node.temporaries:
        # Keep the temporary's dimensionality instead of defaulting to a full 3D field.
        ctx.add_decl(
            oir.Temporary(name=temp.name, dtype=temp.dtype, dimensions=temp.dimensions)
        )

    return oir.VerticalLoop(
        loop_order=node.loop_order,
        sections=[
            oir.VerticalLoopSection(
                interval=self.visit(node.interval, **kwargs),
                # Copy: the context list is cleared again when the next loop is visited.
                horizontal_executions=list(ctx.horizontal_executions),
            )
        ],
        caches=[],
    )
def _create_mask(ctx: "GTIRToOIR.Context", name: str, cond: oir.Expr) -> oir.Temporary:
    """Declare a boolean mask temporary called ``name`` and fill it with ``cond``.

    The new 3D ``DataType.BOOL`` temporary is registered with ``ctx.add_decl``. A
    horizontal execution that assigns ``cond`` to it at zero offset is then
    registered with ``ctx.add_horizontal_execution``.

    Returns:
        The declaration of the new mask temporary.
    """
    mask_decl = oir.Temporary(name=name, dtype=DataType.BOOL, dimensions=(True, True, True))
    ctx.add_decl(mask_decl)

    mask_target = oir.FieldAccess(
        name=mask_decl.name,
        offset=CartesianOffset.zero(),
        dtype=mask_decl.dtype,
    )
    fill_stmt = oir.AssignStmt(left=mask_target, right=cond)
    ctx.add_horizontal_execution(oir.HorizontalExecution(body=[fill_stmt], declarations=[]))

    return mask_decl
def sdfg_arrays_to_oir_decls(
    sdfg: dace.SDFG,
) -> Tuple[List[oir.Decl], List[oir.Temporary]]:
    """Translate the data containers and free symbols of an SDFG into OIR declarations.

    - Non-transient ``dace.data.Array`` → ``oir.FieldDecl`` parameter.
    - Transient ``dace.data.Array`` → ``oir.Temporary`` declaration.
    - ``dace.data.Scalar`` → ``oir.ScalarDecl`` parameter.
    - Every entry of ``sdfg.symbols`` not returned by ``internal_symbols(sdfg)``
      → ``oir.ScalarDecl`` parameter.

    Returns:
        ``(params, decls)``: the stencil parameters and the temporary declarations.

    Raises:
        TypeError: if a data container is neither a ``dace.data.Array`` nor a
            ``dace.data.Scalar``.
    """
    params: List[oir.Decl] = []
    decls: List[oir.Temporary] = []

    array: dace.data.Data
    for name, array in sdfg.arrays.items():
        dtype = common.typestr_to_data_type(dace_dtype_to_typestr(array.dtype))
        if isinstance(array, dace.data.Array):
            dimensions = array_dimensions(array)
            # Shape entries after the first sum(dimensions) ones are passed on as data_dims.
            data_dims = array.shape[sum(dimensions):]
            if not array.transient:
                params.append(
                    oir.FieldDecl(
                        name=name, dtype=dtype, dimensions=dimensions, data_dims=data_dims
                    )
                )
            else:
                decls.append(
                    oir.Temporary(
                        name=name, dtype=dtype, dimensions=dimensions, data_dims=data_dims
                    )
                )
        elif isinstance(array, dace.data.Scalar):
            params.append(oir.ScalarDecl(name=name, dtype=dtype))
        else:
            # An explicit error instead of `assert`: under `python -O` the assert would vanish
            # and unsupported containers would silently become scalar parameters.
            raise TypeError(
                f"Unsupported SDFG data container {name!r} of type {type(array).__name__}"
            )

    reserved_symbols = internal_symbols(sdfg)
    for sym, stype in sdfg.symbols.items():
        if sym not in reserved_symbols:
            params.append(
                oir.ScalarDecl(
                    name=sym,
                    dtype=common.typestr_to_data_type(stype.as_numpy_dtype().str),
                )
            )
    return params, decls
def visit_VerticalLoop(
    self,
    node: oir.VerticalLoop,
    *,
    new_tmps: List[oir.Temporary],
    symtable: Dict[str, Any],
    new_symbol_name: Callable[[str], str],
    **kwargs: Any,
) -> oir.VerticalLoop:
    """Redirect filling/flushing K-caches of ``node`` to local K-caches on new temporaries.

    Each cached field with ``fill`` or ``flush`` set gets a fresh temporary name from
    ``new_symbol_name``. A field that is both filled and flushed gets a single shared
    name. The new temporaries are appended to ``new_tmps``; the caller adds their
    declarations to the stencil. Sections are split where several fills are required
    on entry, fill/flush statements are generated per section, and the original caches
    of those fields are replaced by non-filling, non-flushing ``oir.KCache`` entries on
    the temporaries.

    Returns:
        ``node`` unchanged if no cache fills or flushes, otherwise the rewritten loop.
    """
    filling_fields: Dict[str, str] = {
        c.name: new_symbol_name(c.name)
        for c in node.caches
        if isinstance(c, oir.KCache) and c.fill
    }
    flushing_fields: Dict[str, str] = {
        c.name: filling_fields[c.name] if c.name in filling_fields else new_symbol_name(c.name)
        for c in node.caches
        if isinstance(c, oir.KCache) and c.flush
    }

    # A dict merge, not a set union: shared keys map to the same name in both dicts,
    # and iterating a set of string tuples depends on the per-process hash seed, which
    # would make the order of new temporaries and caches vary between runs.
    filling_or_flushing_fields: Dict[str, str] = {**filling_fields, **flushing_fields}

    if not filling_or_flushing_fields:
        return node

    # new temporaries used for caches, declarations are later added to stencil
    for field_name, tmp_name in filling_or_flushing_fields.items():
        new_tmps.append(
            oir.Temporary(
                name=tmp_name, dtype=symtable[field_name].dtype, dimensions=(True, True, True)
            )
        )

    if filling_fields:
        # split sections where more than one fill operations are required at the entry level
        first_unfilled: Dict[str, int] = {}
        split_sections: List[oir.VerticalLoopSection] = []
        for section in node.sections:
            # NOTE(review): previous_fills is discarded and first_unfilled is not updated
            # between sections here (unlike the _fill_stmts loop below) — confirm whether
            # _split_section_with_multiple_fills updates it in place or this is intended.
            split_section, previous_fills = self._split_section_with_multiple_fills(
                node.loop_order, section, filling_fields, first_unfilled, new_symbol_name
            )
            split_sections += split_section
    else:
        split_sections = node.sections

    # generate cache fill and flush statements
    first_unfilled = {}
    sections = []
    for section in split_sections:
        # first_unfilled is threaded from one section to the next.
        fills, first_unfilled = self._fill_stmts(
            node.loop_order, section, filling_fields, first_unfilled, symtable
        )
        flushes = self._flush_stmts(node.loop_order, section, flushing_fields, symtable)
        sections.append(
            self.visit(
                section,
                fills=fills,
                flushes=flushes,
                name_map=filling_or_flushing_fields,
                symtable=symtable,
                **kwargs,
            )
        )

    # replace cache declarations
    caches = [c for c in node.caches if c.name not in filling_or_flushing_fields] + [
        oir.KCache(name=f, fill=False, flush=False) for f in filling_or_flushing_fields.values()
    ]
    return oir.VerticalLoop(loop_order=node.loop_order, sections=sections, caches=caches)