def as_oir(self):
    """Convert this library node's sections into an OIR ``VerticalLoop``.

    For every ``(interval, sdfg)`` pair in ``self.sections``, the nested SDFG
    is walked state by state in ``sdfg.topological_sort(sdfg.start_state)``
    order. Within each state, nodes are walked in networkx topological order.
    Each ``HorizontalExecutionLibraryNode`` encountered is converted with its
    own ``as_oir()``, and its fields are renamed by an ``OIRFieldRenamer``
    built from ``get_node_name_mapping(state, node)``.

    The returned ``VerticalLoop`` and each of its ``VerticalLoopSection``s
    carry one shared ``SourceLocation`` built from ``self.debuginfo``. A
    missing start line/column falls back to 1 and a missing filename falls
    back to ``"<unknown>"``.
    """
    debuginfo = self.debuginfo
    loc = SourceLocation(
        debuginfo.start_line or 1,
        debuginfo.start_column or 1,
        debuginfo.filename or "<unknown>",
    )

    def _horizontal_executions_of(section_sdfg):
        # Convert the horizontal executions of one section, in state order and
        # then topological node order within each state.
        collected = []
        for state in section_sdfg.topological_sort(section_sdfg.start_state):
            for node in nx.topological_sort(state.nx):
                if not isinstance(node, HorizontalExecutionLibraryNode):
                    continue
                renamer = OIRFieldRenamer(get_node_name_mapping(state, node))
                collected.append(renamer.visit(node.as_oir()))
        return collected

    loop_sections = [
        VerticalLoopSection(
            interval=interval,
            horizontal_executions=_horizontal_executions_of(section_sdfg),
            loc=loc,
        )
        for interval, section_sdfg in self.sections
    ]
    return VerticalLoop(
        sections=loop_sections,
        loop_order=self.loop_order,
        caches=self.caches,
        loc=loc,
    )
def validate(self, parent_sdfg: dace.SDFG, parent_state: dace.SDFGState, *args, **kwargs):
    """Check that each section SDFG can be converted to OIR, then defer to the base class.

    Each section SDFG in ``self.sections`` must:

    * pass its own ``validate()``;
    * contain only ``dace.SDFGState``, ``dace.nodes.AccessNode`` and
      ``HorizontalExecutionLibraryNode`` nodes (checked via
      ``all_nodes_recursive``);
    * hold only ``dace.data.Array`` descriptors whose dtype maps to an OIR
      ``DataType`` other than ``DataType.INVALID``.

    Raises:
        ValueError: if a section SDFG has a disallowed node type or data
            descriptor.
    """
    # Result is discarded; NOTE(review): presumably called so that it raises
    # when this node's connectors cannot be mapped -- confirm.
    get_node_name_mapping(parent_state, self)

    allowed_node_types = (
        dace.SDFGState,
        dace.nodes.AccessNode,
        HorizontalExecutionLibraryNode,
    )
    for _, section_sdfg in self.sections:
        section_sdfg.validate()
        node_types_ok = all(
            isinstance(node, allowed_node_types)
            for node, _parent in section_sdfg.all_nodes_recursive()
        )
        arrays_ok = all(
            isinstance(descriptor, dace.data.Array)
            and typestr_to_data_type(dace_dtype_to_typestr(descriptor.dtype)) != DataType.INVALID
            for descriptor in section_sdfg.arrays.values()
        )
        if not (node_types_ok and arrays_ok):
            raise ValueError("Tried to convert incompatible SDFG to OIR.")

    super().validate(parent_sdfg, parent_state, *args, **kwargs)
def convert(sdfg: dace.SDFG) -> oir.Stencil:
    """Translate an SDFG made of ``VerticalLoopLibraryNode``s into an ``oir.Stencil``.

    The SDFG is first passed to ``validate_oir_sdfg``. Stencil parameters and
    declarations come from ``sdfg_arrays_to_oir_decls``. Vertical loops are
    collected in state order (``sdfg.topological_sort(sdfg.start_state)``) and,
    within each state, in networkx topological node order. Each one is
    converted via ``as_oir()`` and has its fields renamed by an
    ``OIRFieldRenamer`` built from ``get_node_name_mapping(state, node)``.

    The stencil is named after ``sdfg.name``.
    """
    validate_oir_sdfg(sdfg)
    params, decls = sdfg_arrays_to_oir_decls(sdfg)

    loops = [
        OIRFieldRenamer(get_node_name_mapping(state, node)).visit(node.as_oir())
        for state in sdfg.topological_sort(sdfg.start_state)
        for node in nx.topological_sort(state.nx)
        if isinstance(node, VerticalLoopLibraryNode)
    ]

    return oir.Stencil(
        name=sdfg.name,
        params=params,
        declarations=decls,
        vertical_loops=loops,
    )
def as_oir(self):
    """Convert this library node's sections into an OIR ``VerticalLoop``.

    For each ``(interval, sdfg)`` in ``self.sections``, every
    ``HorizontalExecutionLibraryNode`` is gathered. The walk goes state by
    state (``sdfg.topological_sort(sdfg.start_state)``) and, within a state,
    in networkx topological order. Each node is converted with ``as_oir()``
    and its fields are renamed via ``OIRFieldRenamer``.

    This variant attaches no source location to the produced nodes.
    """

    def _converted(state, node):
        # Lower one node to OIR and rename its fields to the names used at this state.
        return OIRFieldRenamer(get_node_name_mapping(state, node)).visit(node.as_oir())

    loop_sections = []
    for interval, section_sdfg in self.sections:
        executions = [
            _converted(state, node)
            for state in section_sdfg.topological_sort(section_sdfg.start_state)
            for node in nx.topological_sort(state.nx)
            if isinstance(node, HorizontalExecutionLibraryNode)
        ]
        loop_sections.append(
            VerticalLoopSection(interval=interval, horizontal_executions=executions)
        )

    return VerticalLoop(
        sections=loop_sections,
        loop_order=self.loop_order,
        caches=self.caches,
    )
def validate(self, parent_sdfg: dace.SDFG, parent_state: dace.SDFGState, *args, **kwargs):
    """Validate this node by building its name mapping within ``parent_state``.

    ``parent_sdfg`` and any extra arguments are accepted for interface
    compatibility but are not used here.
    """
    # Only the call matters; the mapping itself is discarded.
    # NOTE(review): presumably get_node_name_mapping raises on an invalid
    # connector/name layout -- confirm.
    _ = get_node_name_mapping(parent_state, self)