Beispiel #1
0
def sdfg_arrays_to_oir_decls(
        sdfg: dace.SDFG) -> Tuple[List[oir.Decl], List[oir.Temporary]]:
    """Build OIR declarations for the data containers and symbols of an SDFG.

    Each entry of ``sdfg.arrays`` is translated as follows:

    * non-transient ``dace.data.Array`` -> ``oir.FieldDecl`` (parameter),
    * transient ``dace.data.Array`` -> ``oir.Temporary`` (declaration),
    * ``dace.data.Scalar`` -> ``oir.ScalarDecl`` (parameter).

    Every entry of ``sdfg.symbols`` that is not in ``internal_symbols(sdfg)``
    is additionally appended to the parameters as an ``oir.ScalarDecl``.

    Args:
        sdfg: The SDFG whose arrays and symbols are translated.

    Returns:
        ``(params, decls)``: the parameter declarations (array-derived ones
        first, in ``sdfg.arrays`` order, followed by the free symbols) and
        the temporary declarations.

    Raises:
        TypeError: If a data container is neither a ``dace.data.Array`` nor
            a ``dace.data.Scalar``.
    """
    params = []
    decls = []

    array: dace.data.Data
    for name, array in sdfg.arrays.items():
        dtype = common.typestr_to_data_type(dace_dtype_to_typestr(array.dtype))
        if isinstance(array, dace.data.Array):
            dimensions = array_dimensions(array)
            # NOTE(review): assumes array_dimensions returns one boolean flag
            # per spatial axis, so sum() counts them and the remaining shape
            # entries are the data dimensions -- confirm.
            data_dims = array.shape[sum(dimensions):]
            if array.transient:
                decls.append(
                    oir.Temporary(
                        name=name,
                        dtype=dtype,
                        dimensions=dimensions,
                        data_dims=data_dims,
                    ))
            else:
                params.append(
                    oir.FieldDecl(
                        name=name,
                        dtype=dtype,
                        dimensions=dimensions,
                        data_dims=data_dims,
                    ))
        elif isinstance(array, dace.data.Scalar):
            params.append(oir.ScalarDecl(name=name, dtype=dtype))
        else:
            # Explicit raise instead of `assert`: assertions are stripped under
            # `python -O`, which would let any unsupported container fall
            # through and be emitted as a scalar declaration.
            raise TypeError(
                f"Unsupported data container {name!r} of type "
                f"{type(array).__name__}; expected dace.data.Array or "
                "dace.data.Scalar")

    # Symbols reported by internal_symbols() are skipped; every other SDFG
    # symbol becomes a scalar parameter.
    reserved_symbols = internal_symbols(sdfg)
    for sym, stype in sdfg.symbols.items():
        if sym not in reserved_symbols:
            params.append(
                oir.ScalarDecl(name=sym,
                               dtype=common.typestr_to_data_type(
                                   stype.as_numpy_dtype().str)))
    return params, decls
Beispiel #2
0
def validate_oir_sdfg(sdfg: dace.SDFG):
    """Check that ``sdfg`` is an OIR-level SDFG.

    Runs ``sdfg.validate()`` first. The SDFG is then accepted only if both
    conditions hold:

    * every node returned by ``sdfg.all_nodes_recursive()`` is a
      ``dace.SDFGState``, a ``dace.nodes.AccessNode`` or a
      ``VerticalLoopLibraryNode``;
    * every data container is a ``dace.data.Array`` whose dtype maps to a
      ``DataType`` other than ``DataType.INVALID``.

    Raises:
        ValueError: If either condition is violated.
    """
    from gtc.dace.nodes import VerticalLoopLibraryNode

    sdfg.validate()

    def _node_allowed(node):
        return isinstance(
            node,
            (dace.SDFGState, dace.nodes.AccessNode, VerticalLoopLibraryNode))

    def _container_allowed(container):
        if not isinstance(container, dace.data.Array):
            return False
        converted = typestr_to_data_type(dace_dtype_to_typestr(container.dtype))
        return converted != DataType.INVALID

    # Both checks are evaluated before deciding, exactly as before.
    nodes_ok = all(_node_allowed(node) for node, _ in sdfg.all_nodes_recursive())
    containers_ok = all(
        _container_allowed(container) for container in sdfg.arrays.values())

    if not (nodes_ok and containers_ok):
        raise ValueError("Not a valid OIR-level SDFG")
Beispiel #3
0
    def validate(self, parent_sdfg: dace.SDFG, parent_state: dace.SDFGState,
                 *args, **kwargs):
        """Validate this node and each of its section SDFGs.

        Each SDFG in ``self.sections`` is validated with ``validate()``. It
        must then contain only ``dace.SDFGState``, ``dace.nodes.AccessNode``
        and ``HorizontalExecutionLibraryNode`` nodes. Its data containers must
        all be ``dace.data.Array`` instances whose dtype does not map to
        ``DataType.INVALID``. The parent class's ``validate`` runs last.

        Raises:
            ValueError: If a section SDFG violates the node-type or
                data/dtype constraints.
        """
        # NOTE(review): the returned mapping is discarded; presumably called
        # for any checks it performs on (parent_state, self) -- confirm.
        get_node_name_mapping(parent_state, self)

        for _, section_sdfg in self.sections:
            section_sdfg.validate()

            nodes_ok = True
            for node, _ in section_sdfg.all_nodes_recursive():
                if not isinstance(node, (dace.SDFGState, dace.nodes.AccessNode,
                                         HorizontalExecutionLibraryNode)):
                    nodes_ok = False
                    break

            containers_ok = True
            for container in section_sdfg.arrays.values():
                if not (isinstance(container, dace.data.Array)
                        and typestr_to_data_type(
                            dace_dtype_to_typestr(container.dtype))
                        != DataType.INVALID):
                    containers_ok = False
                    break

            if not (nodes_ok and containers_ok):
                raise ValueError("Tried to convert incompatible SDFG to OIR.")

        super().validate(parent_sdfg, parent_state, *args, **kwargs)