def sdfg_arrays_to_oir_decls(
    sdfg: dace.SDFG,
) -> Tuple[List[oir.Decl], List[oir.Temporary]]:
    """Translate the data containers and free symbols of ``sdfg`` into OIR declarations.

    Each entry of ``sdfg.arrays`` is mapped as follows:

    * non-transient ``dace.data.Array``  -> ``oir.FieldDecl`` (a parameter)
    * transient ``dace.data.Array``      -> ``oir.Temporary`` (a local declaration)
    * ``dace.data.Scalar``               -> ``oir.ScalarDecl`` (a parameter)

    Afterwards, every symbol in ``sdfg.symbols`` that is not returned by
    ``internal_symbols(sdfg)`` is appended to the parameters as an ``oir.ScalarDecl``.

    Args:
        sdfg: The SDFG whose ``arrays`` and ``symbols`` are translated.

    Returns:
        A pair ``(params, temporaries)``. ``params`` follows the iteration order of
        ``sdfg.arrays`` and is then followed by the non-reserved symbols in the
        iteration order of ``sdfg.symbols``.

    Raises:
        AssertionError: If a container is neither an ``Array`` nor a ``Scalar``
            (only when assertions are enabled).
    """
    parameters = []
    temporaries = []

    container: dace.data.Data
    for container_name, container in sdfg.arrays.items():
        # The element type is resolved for every container, scalars included.
        oir_dtype = common.typestr_to_data_type(dace_dtype_to_typestr(container.dtype))

        if not isinstance(container, dace.data.Array):
            # Anything that is not an array is expected to be a scalar.
            assert isinstance(container, dace.data.Scalar)
            parameters.append(oir.ScalarDecl(name=container_name, dtype=oir_dtype))
            continue

        dims = array_dimensions(container)
        # Transient arrays are local temporaries; all other arrays are field parameters.
        if container.transient:
            decl_type, sink = oir.Temporary, temporaries
        else:
            decl_type, sink = oir.FieldDecl, parameters
        sink.append(
            decl_type(
                name=container_name,
                dtype=oir_dtype,
                dimensions=dims,
                # Shape entries past the first sum(dims) axes are data dimensions
                # (presumably dims holds one 0/1 flag per spatial axis — TODO confirm).
                data_dims=container.shape[sum(dims):],
            )
        )

    reserved = internal_symbols(sdfg)
    parameters.extend(
        oir.ScalarDecl(
            name=symbol_name,
            dtype=common.typestr_to_data_type(symbol_type.as_numpy_dtype().str),
        )
        for symbol_name, symbol_type in sdfg.symbols.items()
        if symbol_name not in reserved
    )
    return parameters, temporaries
def validate_oir_sdfg(sdfg: dace.SDFG):
    """Check that ``sdfg`` has the shape of an OIR-level SDFG.

    The SDFG is first validated by DaCe itself (``sdfg.validate()``). It is then
    accepted only if both of the following hold:

    * every node reported by ``sdfg.all_nodes_recursive()`` is an ``SDFGState``,
      an ``AccessNode`` or a ``VerticalLoopLibraryNode``;
    * every data container in ``sdfg.arrays`` is a ``dace.data.Array`` whose dtype
      does not map to ``DataType.INVALID``.

    NOTE(review): the container check rejects ``dace.data.Scalar`` entries, although
    ``sdfg_arrays_to_oir_decls`` converts scalars to ``oir.ScalarDecl`` — confirm this
    is intended.

    Args:
        sdfg: The SDFG to check.

    Raises:
        ValueError: If either of the conditions above is violated.
    """
    # Imported locally, presumably to avoid an import cycle with gtc.dace.nodes.
    from gtc.dace.nodes import VerticalLoopLibraryNode

    sdfg.validate()

    permitted_nodes = (dace.SDFGState, dace.nodes.AccessNode, VerticalLoopLibraryNode)
    nodes_ok = all(isinstance(node, permitted_nodes) for node, _ in sdfg.all_nodes_recursive())

    def _has_valid_array_dtype(data) -> bool:
        # Non-array containers fail before their dtype is inspected.
        if not isinstance(data, dace.data.Array):
            return False
        return typestr_to_data_type(dace_dtype_to_typestr(data.dtype)) != DataType.INVALID

    arrays_ok = all(_has_valid_array_dtype(data) for data in sdfg.arrays.values())

    # Both checks are evaluated before deciding, so neither is skipped.
    if not (nodes_ok and arrays_ok):
        raise ValueError("Not a valid OIR-level SDFG")
def validate(self, parent_sdfg: dace.SDFG, parent_state: dace.SDFGState, *args, **kwargs):
    """Validate this node's section SDFGs, then defer to the base-class validation.

    For every entry of ``self.sections`` (unpacked as a pair whose second element is
    an SDFG), the section SDFG is validated by DaCe. It must then contain only
    ``SDFGState``, ``AccessNode`` and ``HorizontalExecutionLibraryNode`` nodes, and
    only ``dace.data.Array`` containers whose dtype does not map to
    ``DataType.INVALID``.

    Args:
        parent_sdfg: The SDFG containing this node, forwarded to ``super().validate``.
        parent_state: The state containing this node. It is also passed to
            ``get_node_name_mapping``.
        *args: Forwarded unchanged to ``super().validate``.
        **kwargs: Forwarded unchanged to ``super().validate``.

    Raises:
        ValueError: If a section SDFG contains a disallowed node or data container.
    """
    # The returned mapping is unused here; the call is presumably kept for the
    # checks it performs (it may raise on an inconsistent node/state) — TODO confirm.
    get_node_name_mapping(parent_state, self)

    for _, section_sdfg in self.sections:
        section_sdfg.validate()

        nodes_ok = all(
            isinstance(
                node,
                (dace.SDFGState, dace.nodes.AccessNode, HorizontalExecutionLibraryNode),
            )
            for node, _ in section_sdfg.all_nodes_recursive()
        )
        arrays_ok = all(
            isinstance(data, dace.data.Array)
            and typestr_to_data_type(dace_dtype_to_typestr(data.dtype)) != DataType.INVALID
            for data in section_sdfg.arrays.values()
        )
        # Both checks are evaluated before deciding, so neither is skipped.
        if not (nodes_ok and arrays_ok):
            raise ValueError("Tried to convert incompatible SDFG to OIR.")

    super().validate(parent_sdfg, parent_state, *args, **kwargs)