Example #1
class ONNXTypeConstraint:
    """ Python representation of an ONNX type constraint. """

    type_str = Property(dtype=str, desc="The type parameter string")
    types = ListProperty(
        element_type=typeclass,
        desc=
        "The possible types. Note that only tensor types are currently supported."
    )

    def __repr__(self):
        return self.type_str
Example #2
class MyListObject(object):
    list_prop = ListProperty(element_type=int)

    def __init__(self, p):
        super().__init__()
        self.list_prop = p

    def to_json(self):
        return all_properties_to_json(self)

    @staticmethod
    def from_json(json_obj, context=None):
        ret = MyListObject([])
        set_properties_from_json(ret, json_obj, context=context)
        return ret
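
A minimal round-trip sketch for the class above; it assumes the `Property` machinery (e.g. the `@make_properties` decorator) and the two serialization helpers are imported as in the original source.

# Hypothetical round trip: the list property survives JSON serialization intact.
obj = MyListObject([1, 2, 3])
blob = obj.to_json()                     # plain JSON-serializable dict of properties
restored = MyListObject.from_json(blob)
assert restored.list_prop == [1, 2, 3]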
Example #3
class Map(object):
    """ A Map is a two-node representation of parametric graphs, containing
        an integer set by which the contents (nodes dominated by an entry
        node and post-dominated by an exit node) are replicated.

        Maps contain a `schedule` property, which specifies how the scope
        should be scheduled (execution order). Code generators can use the
        schedule property to generate appropriate code, e.g., GPU kernels.
    """

    # List of (editable) properties
    label = Property(dtype=str, desc="Label of the map")
    params = ListProperty(element_type=str, desc="Mapped parameters")
    range = RangeProperty(desc="Ranges of map parameters",
                          default=sbs.Range([]))
    schedule = Property(dtype=dtypes.ScheduleType,
                        desc="Map schedule",
                        choices=dtypes.ScheduleType,
                        from_string=lambda x: dtypes.ScheduleType[x],
                        default=dtypes.ScheduleType.Default)
    unroll = Property(dtype=bool, desc="Map unrolling")
    collapse = Property(dtype=int,
                        default=1,
                        desc="How many dimensions to"
                        " collapse into the parallel range")
    debuginfo = DebugInfoProperty()
    is_collapsed = Property(dtype=bool,
                            desc="Show this node/scope/state as collapsed",
                            default=False)

    instrument = Property(choices=dtypes.InstrumentationType,
                          desc="Measure execution statistics with given method",
                          default=dtypes.InstrumentationType.No_Instrumentation)

    def __init__(self,
                 label,
                 params,
                 ndrange,
                 schedule=dtypes.ScheduleType.Default,
                 unroll=False,
                 collapse=1,
                 fence_instrumentation=False,
                 debuginfo=None):
        super(Map, self).__init__()

        # Assign properties
        self.label = label
        self.schedule = schedule
        self.unroll = unroll
        self.collapse = collapse
        self.params = params
        self.range = ndrange
        self.debuginfo = debuginfo
        self._fence_instrumentation = fence_instrumentation

    def __str__(self):
        return self.label + "[" + ", ".join([
            "{}={}".format(i, r) for i, r in zip(
                self._params, [sbs.Range.dim_to_string(d) for d in self._range])
        ]) + "]"

    def validate(self, sdfg, state, node):
        if not dtypes.validate_name(self.label):
            raise NameError('Invalid map name "%s"' % self.label)

    def get_param_num(self):
        """ Returns the number of map dimension parameters/symbols. """
        return len(self.params)
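
A usage sketch for the `Map` scope above, assuming `dace.subsets` is imported as `sbs` and `N` is a previously defined symbolic size; `Range` entries are (begin, end, step) tuples with an inclusive end.

# Hypothetical one-dimensional map over i in [0, N).
m = Map(label='compute',
        params=['i'],
        ndrange=sbs.Range([(0, N - 1, 1)]),
        schedule=dtypes.ScheduleType.Sequential)
print(m)                  # e.g. compute[i=0:N]
print(m.get_param_num())  # 1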
Example #4
class Stream(Data):
    """ Stream (or stream array) data descriptor. """

    # Properties
    offset = ListProperty(element_type=symbolic.pystr_to_symbolic)
    buffer_size = SymbolicProperty(desc="Size of internal buffer.", default=0)

    def __init__(self,
                 dtype,
                 buffer_size,
                 shape=None,
                 transient=False,
                 storage=dtypes.StorageType.Default,
                 location=None,
                 offset=None,
                 lifetime=dtypes.AllocationLifetime.Scope,
                 debuginfo=None):

        if shape is None:
            shape = (1, )

        self.buffer_size = buffer_size

        if offset is not None:
            if len(offset) != len(shape):
                raise TypeError('Offset must be the same size as shape')
            self.offset = cp.copy(offset)
        else:
            self.offset = [0] * len(shape)

        super(Stream, self).__init__(dtype, shape, transient, storage, location,
                                     lifetime, debuginfo)

    def to_json(self):
        attrs = serialize.all_properties_to_json(self)

        retdict = {"type": type(self).__name__, "attributes": attrs}

        return retdict

    @classmethod
    def from_json(cls, json_obj, context=None):
        # Create dummy object
        ret = cls(dtypes.int8, 1)
        serialize.set_properties_from_json(ret, json_obj, context=context)

        # Check validity now
        ret.validate()
        return ret

    def __repr__(self):
        return '%s (dtype=%s, shape=%s)' % (type(self).__name__, self.dtype,
                                            self.shape)

    @property
    def total_size(self):
        return _prod(self.shape)

    @property
    def strides(self):
        return [_prod(self.shape[i + 1:]) for i in range(len(self.shape))]

    def clone(self):
        return type(self)(self.dtype, self.buffer_size, self.shape,
                          self.transient, self.storage, self.location,
                          self.offset, self.lifetime, self.debuginfo)

    # Checks for equivalent shape and type
    def is_equivalent(self, other):
        if not isinstance(other, type(self)):
            return False

        # Test type
        if self.dtype != other.dtype:
            return False

        # Test dimensionality
        if len(self.shape) != len(other.shape):
            return False

        # Test shape
        for dim, otherdim in zip(self.shape, other.shape):
            if dim != otherdim:
                return False
        return True

    def as_arg(self, with_types=True, for_call=False, name=None):
        if not with_types or for_call: return name
        if self.storage in [
                dtypes.StorageType.GPU_Global, dtypes.StorageType.GPU_Shared
        ]:
            return 'dace::GPUStream<%s, %s> %s' % (str(
                self.dtype.ctype), 'true' if sp.log(
                    self.buffer_size, 2).is_Integer else 'false', name)

        return 'dace::Stream<%s> %s' % (str(self.dtype.ctype), name)

    def sizes(self):
        return [
            d.name if isinstance(d, symbolic.symbol) else str(d)
            for d in self.shape
        ]

    def size_string(self):
        return (" * ".join(
            [cppunparse.pyexpr2cpp(symbolic.symstr(s)) for s in self.shape]))

    def is_stream_array(self):
        return _prod(self.shape) != 1

    def covers_range(self, rng):
        if len(rng) != len(self.shape):
            return False

        for s, (rb, re, rs) in zip(self.shape, rng):
            # Shape has to be positive
            if isinstance(s, sp.Basic):
                olds = s
                if 'positive' in s.assumptions0:
                    s = sp.Symbol(str(s), **s.assumptions0)
                else:
                    s = sp.Symbol(str(s), positive=True, **s.assumptions0)
                if isinstance(rb, sp.Basic):
                    rb = rb.subs({olds: s})
                if isinstance(re, sp.Basic):
                    re = re.subs({olds: s})
                if isinstance(rs, sp.Basic):
                    rs = rs.subs({olds: s})

            try:
                if rb < 0:  # Negative offset
                    return False
            except TypeError:  # cannot determine truth value of Relational
                pass
                #print('WARNING: Cannot evaluate relational expression %s, assuming true.' % (rb > 0),
                #      'If this expression is false, please refine symbol definitions in the program.')
            try:
                if re > s:  # Beyond shape
                    return False
            except TypeError:  # cannot determine truth value of Relational
                pass
                #print('WARNING: Cannot evaluate relational expression %s, assuming true.' % (re < s),
                #      'If this expression is false, please refine symbol definitions in the program.')

        return True

    @property
    def free_symbols(self):
        result = super().free_symbols
        if isinstance(self.buffer_size, sp.Expr):
            result |= set(self.buffer_size.free_symbols)
        for o in self.offset:
            if isinstance(o, sp.Expr):
                result |= set(o.free_symbols)

        return result
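
A short sketch of the descriptor above, assuming the module-level `dtypes` import; the buffer size and shape values are illustrative only.

# Hypothetical stream array: four float streams, each with a 32-element internal buffer.
s = Stream(dtype=dtypes.float32, buffer_size=32, shape=(4, ))
print(s.is_stream_array())   # True, since the total size is 4
print(s.strides)             # [1] for a one-dimensional shape
print(s.as_arg(name='q'))    # e.g. dace::Stream<float> q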
Example #5
class Array(Data):
    """ Array/constant descriptor (dimensions, type and other properties). """

    # Properties
    allow_conflicts = Property(
        dtype=bool,
        default=False,
        desc='If enabled, allows more than one '
        'memlet to write to the same memory location without conflict '
        'resolution.')

    strides = ShapeProperty(
        # element_type=symbolic.pystr_to_symbolic,
        desc='For each dimension, the number of elements to '
        'skip in order to obtain the next element in '
        'that dimension.')

    total_size = SymbolicProperty(
        default=1,
        desc='The total allocated size of the array. Can be used for'
        ' padding.')

    offset = ListProperty(element_type=symbolic.pystr_to_symbolic,
                          desc='Initial offset to translate all indices by.')

    may_alias = Property(dtype=bool,
                         default=False,
                         desc='This pointer may alias with other pointers in '
                         'the same function')

    alignment = Property(dtype=int,
                         default=0,
                         desc='Allocation alignment in bytes (0 uses '
                         'compiler-default)')

    def __init__(self,
                 dtype,
                 shape,
                 transient=False,
                 allow_conflicts=False,
                 storage=dtypes.StorageType.Default,
                 location=None,
                 strides=None,
                 offset=None,
                 may_alias=False,
                 lifetime=dtypes.AllocationLifetime.Scope,
                 alignment=0,
                 debuginfo=None,
                 total_size=None):

        super(Array, self).__init__(dtype, shape, transient, storage, location,
                                    lifetime, debuginfo)

        if shape is None:
            raise IndexError('Shape must not be None')

        self.allow_conflicts = allow_conflicts
        self.may_alias = may_alias
        self.alignment = alignment

        if strides is not None:
            self.strides = cp.copy(strides)
        else:
            self.strides = [_prod(shape[i + 1:]) for i in range(len(shape))]

        self.total_size = total_size or _prod(shape)

        if offset is not None:
            self.offset = cp.copy(offset)
        else:
            self.offset = [0] * len(shape)

        self.validate()

    def __repr__(self):
        return '%s (dtype=%s, shape=%s)' % (type(self).__name__, self.dtype,
                                            self.shape)

    def clone(self):
        return type(self)(self.dtype, self.shape, self.transient,
                          self.allow_conflicts, self.storage, self.location,
                          self.strides, self.offset, self.may_alias,
                          self.lifetime, self.alignment, self.debuginfo,
                          self.total_size)

    def to_json(self):
        attrs = serialize.all_properties_to_json(self)

        # Take care of symbolic expressions
        attrs['strides'] = list(map(str, attrs['strides']))

        retdict = {"type": type(self).__name__, "attributes": attrs}

        return retdict

    @classmethod
    def from_json(cls, json_obj, context=None):
        # Create dummy object
        ret = cls(dtypes.int8, ())
        serialize.set_properties_from_json(ret, json_obj, context=context)
        # TODO: This needs to be reworked (i.e. integrated into the list property)
        ret.strides = list(map(symbolic.pystr_to_symbolic, ret.strides))

        # Check validity now
        ret.validate()
        return ret

    def validate(self):
        super(Array, self).validate()
        if len(self.strides) != len(self.shape):
            raise TypeError('Strides must be the same size as shape')

        if any(not isinstance(s, (int, symbolic.SymExpr, symbolic.symbol,
                                  symbolic.sympy.Basic)) for s in self.strides):
            raise TypeError('Strides must be a list or tuple of integer '
                            'values or symbols')

        if len(self.offset) != len(self.shape):
            raise TypeError('Offset must be the same size as shape')

    def covers_range(self, rng):
        if len(rng) != len(self.shape):
            return False

        for s, (rb, re, rs) in zip(self.shape, rng):
            # Shape has to be positive
            if isinstance(s, sp.Basic):
                olds = s
                if 'positive' in s.assumptions0:
                    s = sp.Symbol(str(s), **s.assumptions0)
                else:
                    s = sp.Symbol(str(s), positive=True, **s.assumptions0)
                if isinstance(rb, sp.Basic):
                    rb = rb.subs({olds: s})
                if isinstance(re, sp.Basic):
                    re = re.subs({olds: s})
                if isinstance(rs, sp.Basic):
                    rs = rs.subs({olds: s})

            try:
                if rb < 0:  # Negative offset
                    return False
            except TypeError:  # cannot determine truth value of Relational
                pass
                #print('WARNING: Cannot evaluate relational expression %s, assuming true.' % (rb > 0),
                #      'If this expression is false, please refine symbol definitions in the program.')
            try:
                if re > s:  # Beyond shape
                    return False
            except TypeError:  # cannot determine truth value of Relational
                pass
                #print('WARNING: Cannot evaluate relational expression %s, assuming true.' % (re < s),
                #      'If this expression is false, please refine symbol definitions in the program.')

        return True

    # Checks for equivalent shape and type
    def is_equivalent(self, other):
        if not isinstance(other, type(self)):
            return False

        # Test type
        if self.dtype != other.dtype:
            return False

        # Test dimensionality
        if len(self.shape) != len(other.shape):
            return False

        # Test shape
        for dim, otherdim in zip(self.shape, other.shape):
            # Any other case (constant vs. constant), check for equality
            if otherdim != dim:
                return False
        return True

    def as_arg(self, with_types=True, for_call=False, name=None):
        arrname = name

        if not with_types or for_call:
            return arrname
        if self.may_alias:
            return str(self.dtype.ctype) + ' *' + arrname
        return str(self.dtype.ctype) + ' * __restrict__ ' + arrname

    def sizes(self):
        return [
            d.name if isinstance(d, symbolic.symbol) else str(d)
            for d in self.shape
        ]

    @property
    def free_symbols(self):
        result = super().free_symbols
        for s in self.strides:
            if isinstance(s, sp.Expr):
                result |= set(s.free_symbols)
        if isinstance(self.total_size, sp.Expr):
            result |= set(self.total_size.free_symbols)
        for o in self.offset:
            if isinstance(o, sp.Expr):
                result |= set(o.free_symbols)

        return result
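
A short sketch illustrating the row-major defaults computed in `__init__` above (each stride is the product of the trailing shape entries); the concrete shape is illustrative.

# Hypothetical 3 x 5 array of doubles with default (row-major) strides.
a = Array(dtype=dtypes.float64, shape=(3, 5))
print(a.strides)           # [5, 1]
print(a.total_size)        # 15
print(a.as_arg(name='A'))  # double * __restrict__ A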
Example #6
class Reduce(dace.sdfg.nodes.LibraryNode):
    """ An SDFG node that reduces an N-dimensional array to an
        (N-k)-dimensional array, with a list of axes to reduce and
        a reduction binary function. """

    # Global properties
    implementations = {
        'pure': ExpandReducePure,
        'OpenMP': ExpandReduceOpenMP,
        'CUDA (device)': ExpandReduceCUDADevice,
        'CUDA (block)': ExpandReduceCUDABlock,
        'CUDA (block allreduce)': ExpandReduceCUDABlockAll
        # 'CUDA (warp)': ExpandReduceCUDAWarp,
        # 'CUDA (warp allreduce)': ExpandReduceCUDAWarpAll
    }

    default_implementation = 'pure'

    # Properties
    axes = ListProperty(element_type=int, allow_none=True)
    wcr = LambdaProperty(default='lambda a, b: a')
    identity = Property(allow_none=True)

    def __init__(self,
                 wcr='lambda a, b: a',
                 axes=None,
                 identity=None,
                 schedule=dtypes.ScheduleType.Default,
                 debuginfo=None,
                 **kwargs):
        super().__init__(name='Reduce', **kwargs)
        self.wcr = wcr
        self.axes = axes
        self.identity = identity
        self.debuginfo = debuginfo
        self.schedule = schedule

    @staticmethod
    def from_json(json_obj, context=None):
        ret = Reduce("lambda a, b: a", None)
        dace.serialize.set_properties_from_json(ret, json_obj, context=context)
        return ret

    def __str__(self):
        # Autodetect reduction type
        redtype = detect_reduction_type(self.wcr)
        if redtype == dtypes.ReductionType.Custom:
            wcrstr = unparse(ast.parse(self.wcr).body[0].value.body)
        else:
            wcrstr = str(redtype)
            wcrstr = wcrstr[wcrstr.find('.') + 1:]  # Skip "ReductionType."

        return 'Reduce ({op}), Axes: {axes}'.format(
            axes=('all' if self.axes is None else str(self.axes)), op=wcrstr)

    def __label__(self, sdfg, state):
        return str(self).replace(' Axes', '\nAxes')

    def validate(self, sdfg, state):
        if len(state.in_edges(self)) != 1:
            raise ValueError('Reduce node must have one input')
        if len(state.out_edges(self)) != 1:
            raise ValueError('Reduce node must have one output')
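
A minimal sketch of instantiating the library node above; the sum lambda and axis list are illustrative, and the expansion is selected through the inherited library-node `implementation` property.

# Hypothetical sum-reduction over axis 0, expanded with the OpenMP implementation.
red = Reduce(wcr='lambda a, b: a + b', axes=[0], identity=0)
red.implementation = 'OpenMP'
print(red)  # e.g. Reduce (Sum), Axes: [0]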
Example #7
class Tasklet(CodeNode):
    """ A node that contains a tasklet: a functional computation procedure
        that can only access external data specified using connectors.

        Tasklets may be implemented in Python, C++, or any other language
        supported by the code generator.
    """

    code = CodeProperty(desc="Tasklet code", default=CodeBlock(""))
    state_fields = ListProperty(
        element_type=str, desc="Fields that are added to the global state")
    code_global = CodeProperty(
        desc="Global scope code needed for tasklet execution",
        default=CodeBlock("", dtypes.Language.CPP))
    code_init = CodeProperty(
        desc="Extra code that is called on DaCe runtime initialization",
        default=CodeBlock("", dtypes.Language.CPP))
    code_exit = CodeProperty(
        desc="Extra code that is called on DaCe runtime cleanup",
        default=CodeBlock("", dtypes.Language.CPP))
    debuginfo = DebugInfoProperty()

    instrument = EnumProperty(
        dtype=dtypes.InstrumentationType,
        desc="Measure execution statistics with given method",
        default=dtypes.InstrumentationType.No_Instrumentation)

    def __init__(self,
                 label,
                 inputs=None,
                 outputs=None,
                 code="",
                 language=dtypes.Language.Python,
                 state_fields=None,
                 code_global="",
                 code_init="",
                 code_exit="",
                 location=None,
                 debuginfo=None):
        super(Tasklet, self).__init__(label, location, inputs, outputs)

        self.code = CodeBlock(code, language)

        self.state_fields = state_fields or []
        self.code_global = CodeBlock(code_global, dtypes.Language.CPP)
        self.code_init = CodeBlock(code_init, dtypes.Language.CPP)
        self.code_exit = CodeBlock(code_exit, dtypes.Language.CPP)
        self.debuginfo = debuginfo

    @property
    def language(self):
        return self.code.language

    @staticmethod
    def from_json(json_obj, context=None):
        ret = Tasklet("dummylabel")
        dace.serialize.set_properties_from_json(ret, json_obj, context=context)
        return ret

    @property
    def name(self):
        return self._label

    def validate(self, sdfg, state):
        if not dtypes.validate_name(self.label):
            raise NameError('Invalid tasklet name "%s"' % self.label)
        for in_conn in self.in_connectors:
            if not dtypes.validate_name(in_conn):
                raise NameError('Invalid input connector "%s"' % in_conn)
        for out_conn in self.out_connectors:
            if not dtypes.validate_name(out_conn):
                raise NameError('Invalid output connector "%s"' % out_conn)

    @property
    def free_symbols(self) -> Set[str]:
        return self.code.get_free_symbols(self.in_connectors.keys()
                                          | self.out_connectors.keys())

    def infer_connector_types(self, sdfg, state):
        # If a MLIR tasklet, simply read out the types (it's explicit)
        if self.code.language == dtypes.Language.MLIR:
            # Inline import because mlir.utils depends on pyMLIR which may not be installed
            # Doesn't cause crashes due to missing pyMLIR if a MLIR tasklet is not present
            from dace.codegen.targets.mlir import utils

            mlir_ast = utils.get_ast(self.code.code)
            mlir_is_generic = utils.is_generic(mlir_ast)
            mlir_entry_func = utils.get_entry_func(mlir_ast, mlir_is_generic)

            mlir_result_type = utils.get_entry_result_type(
                mlir_entry_func, mlir_is_generic)
            mlir_out_name = next(iter(self.out_connectors.keys()))

            if self.out_connectors[
                    mlir_out_name] is None or self.out_connectors[
                        mlir_out_name].ctype == "void":
                self.out_connectors[mlir_out_name] = utils.get_dace_type(
                    mlir_result_type)
            elif self.out_connectors[mlir_out_name] != utils.get_dace_type(
                    mlir_result_type):
                warnings.warn(
                    "Type mismatch between MLIR tasklet out connector and MLIR code"
                )

            for mlir_arg in utils.get_entry_args(mlir_entry_func,
                                                 mlir_is_generic):
                if self.in_connectors[
                        mlir_arg[0]] is None or self.in_connectors[
                            mlir_arg[0]].ctype == "void":
                    self.in_connectors[mlir_arg[0]] = utils.get_dace_type(
                        mlir_arg[1])
                elif self.in_connectors[mlir_arg[0]] != utils.get_dace_type(
                        mlir_arg[1]):
                    warnings.warn(
                        "Type mismatch between MLIR tasklet in connector and MLIR code"
                    )

            return

        # If a Python tasklet, use type inference to figure out all None output
        # connectors
        if all(cval.type is not None for cval in self.out_connectors.values()):
            return
        if self.code.language != dtypes.Language.Python:
            return

        if any(cval.type is None for cval in self.in_connectors.values()):
            raise TypeError('Cannot infer output connectors of tasklet "%s", '
                            'not all input connectors have types' % str(self))

        # Avoid import loop
        from dace.codegen.tools.type_inference import infer_types

        # Get symbols defined at beginning of node, and infer all types in
        # tasklet
        syms = state.symbols_defined_at(self)
        syms.update(self.in_connectors)
        new_syms = infer_types(self.code.code, syms)
        for cname, oconn in self.out_connectors.items():
            if oconn.type is None:
                if cname not in new_syms:
                    raise TypeError('Cannot infer type of tasklet %s output '
                                    '"%s", please specify manually.' %
                                    (self.label, cname))
                self.out_connectors[cname] = new_syms[cname]

    def __str__(self):
        if not self.label:
            return "--Empty--"
        else:
            return self.label
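
A usage sketch of the node above; the connector names and code string are illustrative, and connectors are the only identifiers the tasklet body may read or write.

# Hypothetical Python tasklet with one input connector 'a' and one output connector 'b'.
t = Tasklet(label='scale', inputs={'a'}, outputs={'b'}, code='b = 2 * a')
print(t.name)      # scale
print(t.language)  # Language.Python (the default)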
Example #8
File: data.py Project: tbennun/dace
class Stream(Data):
    """ Stream (or stream array) data descriptor. """

    # Properties
    strides = ListProperty(element_type=symbolic.pystr_to_symbolic)
    offset = ListProperty(element_type=symbolic.pystr_to_symbolic)
    buffer_size = SymbolicProperty(desc="Size of internal buffer.")
    veclen = Property(dtype=int,
                      desc="Vector length. Memlets must adhere to this.")

    def __init__(self,
                 dtype,
                 veclen,
                 buffer_size,
                 shape=None,
                 transient=False,
                 storage=dace.dtypes.StorageType.Default,
                 location='',
                 strides=None,
                 offset=None,
                 toplevel=False,
                 debuginfo=None):

        if shape is None:
            shape = (1, )

        self.veclen = veclen
        self.buffer_size = buffer_size

        if strides is not None:
            if len(strides) != len(shape):
                raise TypeError('Strides must be the same size as shape')
            self.strides = cp.copy(strides)
        else:
            self.strides = cp.copy(list(shape))

        if offset is not None:
            if len(offset) != len(shape):
                raise TypeError('Offset must be the same size as shape')
            self.offset = cp.copy(offset)
        else:
            self.offset = [0] * len(shape)

        super(Stream, self).__init__(dtype, shape, transient, storage,
                                     location, toplevel, debuginfo)

    def to_json(self):
        attrs = dace.serialize.all_properties_to_json(self)

        # Take care of symbolic expressions
        attrs['strides'] = list(map(str, attrs['strides']))

        retdict = {"type": type(self).__name__, "attributes": attrs}

        return retdict

    @staticmethod
    def from_json(json_obj, context=None):
        if json_obj['type'] != "Stream":
            raise TypeError("Invalid data type")

        # Create dummy object
        ret = Stream(dace.dtypes.int8, 1, 1)
        dace.serialize.set_properties_from_json(ret, json_obj, context=context)
        # TODO: FIXME:
        # Since the strides are a list-property (normal Property()),
        # loading from/to string (and, consequently, from/to json)
        # leads to validation errors (contains Strings/Integers, not sympy symbols).
        # To fix this, it needs a custom class
        # For now, this is a workaround:
        ret.strides = list(map(symbolic.pystr_to_symbolic, ret.strides))

        # Check validity now
        ret.validate()
        return ret

    def __repr__(self):
        return 'Stream (dtype=%s, shape=%s)' % (self.dtype, self.shape)

    def clone(self):
        return Stream(self.dtype, self.veclen, self.buffer_size, self.shape,
                      self.transient, self.storage, self.location,
                      self.strides, self.offset, self.toplevel, self.debuginfo)

    # Checks for equivalent shape and type
    def is_equivalent(self, other):
        if not isinstance(other, Stream):
            return False

        # Test type
        if self.dtype != other.dtype:
            return False

        # Test dimensionality
        if len(self.shape) != len(other.shape):
            return False

        # Test shape
        for dim, otherdim in zip(self.shape, other.shape):
            # If both are symbols, ensure equality
            if symbolic.issymbolic(dim) and symbolic.issymbolic(otherdim):
                if dim != otherdim:
                    return False

            # If one is a symbol and the other is a constant
            # make sure they are equivalent
            elif symbolic.issymbolic(otherdim):
                if symbolic.eval(otherdim) != dim:
                    return False
            elif symbolic.issymbolic(dim):
                if symbolic.eval(dim) != otherdim:
                    return False
            else:
                # Any other case (constant vs. constant), check for equality
                if otherdim != dim:
                    return False
        return True

    def signature(self, with_types=True, for_call=False, name=None):
        if not with_types or for_call: return name
        if self.storage in [
                dace.dtypes.StorageType.GPU_Global,
                dace.dtypes.StorageType.GPU_Shared,
                dace.dtypes.StorageType.GPU_Stack
        ]:
            return 'dace::GPUStream<%s, %s> %s' % (str(
                self.dtype.ctype), 'true' if sp.log(
                    self.buffer_size, 2).is_Integer else 'false', name)

        return 'dace::Stream<%s> %s' % (str(self.dtype.ctype), name)

    def sizes(self):
        return [
            d.name if isinstance(d, symbolic.symbol) else str(d)
            for d in self.shape
        ]

    def size_string(self):
        return (" * ".join([
            cppunparse.pyexpr2cpp(dace.symbolic.symstr(s))
            for s in self.strides
        ]))

    def is_stream_array(self):
        return functools.reduce(lambda a, b: a * b, self.strides) != 1

    def covers_range(self, rng):
        if len(rng) != len(self.shape):
            return False

        for s, (rb, re, rs) in zip(self.shape, rng):
            # Shape has to be positive
            if isinstance(s, sympy.Basic):
                olds = s
                if 'positive' in s.assumptions0:
                    s = sympy.Symbol(str(s), **s.assumptions0)
                else:
                    s = sympy.Symbol(str(s), positive=True, **s.assumptions0)
                if isinstance(rb, sympy.Basic):
                    rb = rb.subs({olds: s})
                if isinstance(re, sympy.Basic):
                    re = re.subs({olds: s})
                if isinstance(rs, sympy.Basic):
                    rs = rs.subs({olds: s})

            try:
                if rb < 0:  # Negative offset
                    return False
            except TypeError:  # cannot determine truth value of Relational
                pass
                #print('WARNING: Cannot evaluate relational expression %s, assuming true.' % (rb > 0),
                #      'If this expression is false, please refine symbol definitions in the program.')
            try:
                if re > s:  # Beyond shape
                    return False
            except TypeError:  # cannot determine truth value of Relational
                pass
                #print('WARNING: Cannot evaluate relational expression %s, assuming true.' % (re < s),
                #      'If this expression is false, please refine symbol definitions in the program.')

        return True
Example #9
File: data.py Project: tbennun/dace
class Array(Data):
    """ Array/constant descriptor (dimensions, type and other properties). """

    # Properties
    allow_conflicts = Property(dtype=bool)
    # TODO: Should we use a Code property here?
    materialize_func = Property(dtype=str,
                                allow_none=True,
                                setter=set_materialize_func)
    access_order = ListProperty(element_type=int)
    strides = ListProperty(element_type=symbolic.pystr_to_symbolic)
    offset = ListProperty(element_type=symbolic.pystr_to_symbolic)
    may_alias = Property(dtype=bool,
                         default=False,
                         desc='This pointer may alias with other pointers in '
                         'the same function')

    def __init__(self,
                 dtype,
                 shape,
                 materialize_func=None,
                 transient=False,
                 allow_conflicts=False,
                 storage=dace.dtypes.StorageType.Default,
                 location='',
                 access_order=None,
                 strides=None,
                 offset=None,
                 may_alias=False,
                 toplevel=False,
                 debuginfo=None):

        super(Array, self).__init__(dtype, shape, transient, storage, location,
                                    toplevel, debuginfo)

        if shape is None:
            raise IndexError('Shape must not be None')

        self.allow_conflicts = allow_conflicts
        self.materialize_func = materialize_func
        self.may_alias = may_alias

        if access_order is not None:
            self.access_order = cp.copy(access_order)
        else:
            self.access_order = tuple(i for i in range(len(shape)))

        if strides is not None:
            self.strides = cp.copy(strides)
        else:
            self.strides = cp.copy(list(shape))

        if offset is not None:
            self.offset = cp.copy(offset)
        else:
            self.offset = [0] * len(shape)

        self.validate()

    def __repr__(self):
        return 'Array (dtype=%s, shape=%s)' % (self.dtype, self.shape)

    def clone(self):
        return Array(self.dtype, self.shape, self.materialize_func,
                     self.transient, self.allow_conflicts, self.storage,
                     self.location, self.access_order, self.strides,
                     self.offset, self.may_alias, self.toplevel,
                     self.debuginfo)

    def to_json(self):
        attrs = dace.serialize.all_properties_to_json(self)

        # Take care of symbolic expressions
        attrs['strides'] = list(map(str, attrs['strides']))

        retdict = {"type": type(self).__name__, "attributes": attrs}

        return retdict

    @staticmethod
    def from_json(json_obj, context=None):
        if json_obj['type'] != "Array":
            raise TypeError("Invalid data type")

        # Create dummy object
        ret = Array(dace.dtypes.int8, ())
        dace.serialize.set_properties_from_json(ret, json_obj, context=context)
        # TODO: This needs to be reworked (i.e. integrated into the list property)
        ret.strides = list(map(symbolic.pystr_to_symbolic, ret.strides))

        # Check validity now
        ret.validate()
        return ret

    def validate(self):
        super(Array, self).validate()
        if len(self.access_order) != len(self.shape):
            raise TypeError('Access order must be the same size as shape')

        if len(self.strides) != len(self.shape):
            raise TypeError('Strides must be the same size as shape')

        if any(not isinstance(s, (int, symbolic.SymExpr, symbolic.symbol,
                                  symbolic.sympy.Basic))
               for s in self.strides):
            raise TypeError('Strides must be a list or tuple of integer '
                            'values or symbols')

        if len(self.offset) != len(self.shape):
            raise TypeError('Offset must be the same size as shape')

    def covers_range(self, rng):
        if len(rng) != len(self.shape):
            return False

        for s, (rb, re, rs) in zip(self.shape, rng):
            # Shape has to be positive
            if isinstance(s, sympy.Basic):
                olds = s
                if 'positive' in s.assumptions0:
                    s = sympy.Symbol(str(s), **s.assumptions0)
                else:
                    s = sympy.Symbol(str(s), positive=True, **s.assumptions0)
                if isinstance(rb, sympy.Basic):
                    rb = rb.subs({olds: s})
                if isinstance(re, sympy.Basic):
                    re = re.subs({olds: s})
                if isinstance(rs, sympy.Basic):
                    rs = rs.subs({olds: s})

            try:
                if rb < 0:  # Negative offset
                    return False
            except TypeError:  # cannot determine truth value of Relational
                pass
                #print('WARNING: Cannot evaluate relational expression %s, assuming true.' % (rb > 0),
                #      'If this expression is false, please refine symbol definitions in the program.')
            try:
                if re > s:  # Beyond shape
                    return False
            except TypeError:  # cannot determine truth value of Relational
                pass
                #print('WARNING: Cannot evaluate relational expression %s, assuming true.' % (re < s),
                #      'If this expression is false, please refine symbol definitions in the program.')

        return True

    # Checks for equivalent shape and type
    def is_equivalent(self, other):
        if not isinstance(other, Array):
            return False

        # Test type
        if self.dtype != other.dtype:
            return False

        # Test dimensionality
        if len(self.shape) != len(other.shape):
            return False

        # Test shape
        for dim, otherdim in zip(self.shape, other.shape):
            # If both are symbols, ensure equality
            if symbolic.issymbolic(dim) and symbolic.issymbolic(otherdim):
                if dim != otherdim:
                    return False

            # If one is a symbol and the other is a constant
            # make sure they are equivalent
            elif symbolic.issymbolic(otherdim):
                if symbolic.eval(otherdim) != dim:
                    return False
            elif symbolic.issymbolic(dim):
                if symbolic.eval(dim) != otherdim:
                    return False
            else:
                # Any other case (constant vs. constant), check for equality
                if otherdim != dim:
                    return False
        return True

    def signature(self, with_types=True, for_call=False, name=None):
        arrname = name
        if self.materialize_func is not None:
            if for_call:
                return 'nullptr'
            if not with_types:
                return arrname
            arrname = '/* ' + arrname + ' (immaterial) */'

        if not with_types or for_call:
            return arrname
        if self.may_alias:
            return str(self.dtype.ctype) + ' *' + arrname
        return str(self.dtype.ctype) + ' * __restrict__ ' + arrname

    def sizes(self):
        return [
            d.name if isinstance(d, symbolic.symbol) else str(d)
            for d in self.shape
        ]
Example #10
class Reduce(Node):
    """ An SDFG node that reduces an N-dimensional array to an
        (N-k)-dimensional array, with a list of axes to reduce and
        a reduction binary function. """

    # Properties
    axes = ListProperty(element_type=int, allow_none=True)
    wcr = LambdaProperty()
    identity = Property(dtype=object, allow_none=True)
    schedule = Property(dtype=dtypes.ScheduleType,
                        desc="Reduction execution policy",
                        choices=dtypes.ScheduleType,
                        from_string=lambda x: dtypes.ScheduleType[x])
    debuginfo = DebugInfoProperty()

    instrument = Property(
        choices=dtypes.InstrumentationType,
        desc="Measure execution statistics with given method",
        default=dtypes.InstrumentationType.No_Instrumentation)

    def __init__(self,
                 wcr,
                 axes,
                 wcr_identity=None,
                 schedule=dtypes.ScheduleType.Default,
                 debuginfo=None):
        super(Reduce, self).__init__()
        self.wcr = wcr  # type: ast._Lambda
        self.axes = axes
        self.identity = wcr_identity
        self.schedule = schedule
        self.debuginfo = debuginfo

    def draw_node(self, sdfg, state):
        return dot.draw_node(sdfg, state, self, shape="invtriangle")

    @staticmethod
    def from_json(json_obj, context=None):
        ret = Reduce("(lambda a, b: (a + b))", None)
        dace.serialize.set_properties_from_json(ret, json_obj, context=context)
        return ret

    def __str__(self):
        # Autodetect reduction type
        redtype = detect_reduction_type(self.wcr)
        if redtype == dtypes.ReductionType.Custom:
            wcrstr = unparse(ast.parse(self.wcr).body[0].value.body)
        else:
            wcrstr = str(redtype)
            wcrstr = wcrstr[wcrstr.find('.') + 1:]  # Skip "ReductionType."

        return 'Op: {op}, Axes: {axes}'.format(
            axes=('all' if self.axes is None else str(self.axes)), op=wcrstr)

    def __label__(self, sdfg, state):
        # Autodetect reduction type
        redtype = detect_reduction_type(self.wcr)
        if redtype == dtypes.ReductionType.Custom:
            wcrstr = unparse(ast.parse(self.wcr).body[0].value.body)
        else:
            wcrstr = str(redtype)
            wcrstr = wcrstr[wcrstr.find('.') + 1:]  # Skip "ReductionType."

        return 'Op: {op}\nAxes: {axes}'.format(
            axes=('all' if self.axes is None else str(self.axes)), op=wcrstr)
Example #11
class ONNXSchema:
    """Python representation of an ONNX schema"""

    name = Property(dtype=str, desc="The operator name")
    domain = Property(dtype=str, desc="The operator domain")
    doc = Property(dtype=str, desc="The operator's docstring")
    since_version = Property(dtype=int, desc="The version of the operator")
    attributes = DictProperty(key_type=str,
                              value_type=ONNXAttribute,
                              desc="The operator attributes")
    type_constraints = DictProperty(
        key_type=str,
        value_type=ONNXTypeConstraint,
        desc="The type constraints for inputs and outputs")
    inputs = ListProperty(element_type=ONNXParameter,
                          desc="The operator input parameter descriptors")
    outputs = ListProperty(element_type=ONNXParameter,
                           desc="The operator output parameter descriptors")

    def __repr__(self):
        return self.domain + "." + self.name

    def validate(self):
        # check that all parameters with a type string have an entry in the type constraints
        for param in chain(self.inputs, self.outputs):
            if param.type_str not in self.type_constraints:
                # Some operators put a type descriptor here; for those, we will try to insert a new type constraint
                cons_name = param.name + "_constraint"
                if cons_name in self.type_constraints:
                    raise ValueError(
                        "Attempted to insert new type constraint, but the name already existed. Please open an issue."
                    )
                parsed_typeclass = onnx_type_str_to_typeclass(param.type_str)

                if parsed_typeclass is None:
                    print("Could not parse typeStr '{}' for parameter '{}'".
                          format(param.type_str, param.name))

                cons = ONNXTypeConstraint(
                    cons_name,
                    [parsed_typeclass] if parsed_typeclass is not None else [])
                self.type_constraints[cons_name] = cons
                param.type_str = cons_name

        # check for required parameters with no supported type
        for param in chain(self.inputs, self.outputs):
            if ((param.param_type == ONNXParameterType.Single
                 or param.param_type == ONNXParameterType.Variadic)
                    and len(self.type_constraints[param.type_str].types) == 0):
                raise NotImplementedError(
                    "None of the types for parameter '{}' are supported".
                    format(param.name))

        # check that all variadic parameter names do not contain "__"
        for param in chain(self.inputs, self.outputs):
            if param.param_type == ONNXParameterType.Variadic and "__" in param.name:
                raise ValueError(
                    "Unsupported parameter name '{}': variadic parameter names must not contain '__'"
                    .format(param.name))

        # check that all inputs and outputs have unique names
        seen = set()
        for param in self.inputs:
            if param.name in seen:
                raise ValueError(
                    "Got duplicate input parameter name '{}'".format(
                        param.name))
            seen.add(param.name)

        seen = set()
        for param in self.outputs:
            if param.name in seen:
                raise ValueError(
                    "Got duplicate output parameter name '{}'".format(
                        param.name))
            seen.add(param.name)
Example #12
        if attr.type in [
                ONNXAttributeType.Int, ONNXAttributeType.String,
                ONNXAttributeType.Float, ONNXAttributeType.Tensor
        ]:
            attrs[name] = Property(dtype=_ATTR_TYPE_TO_PYTHON_TYPE[attr.type],
                                   desc=attr.description,
                                   allow_none=True,
                                   default=None if attr.default_value is None
                                   else attr.default_value)
        elif attr.type in [
                ONNXAttributeType.Ints, ONNXAttributeType.Strings,
                ONNXAttributeType.Floats
        ]:
            attrs[name] = ListProperty(
                element_type=_ATTR_TYPE_TO_PYTHON_TYPE[attr.type],
                desc=attr.description,
                allow_none=True,
                default=None
                if attr.default_value is None else attr.default_value)
        elif attr.required:
            raise NotImplementedError(
                "Required attribute '{}' has an unsupported type".format(
                    attr.name))

    required_attrs = {
        name
        for name, attr in dace_schema.attributes.items() if attr.required
    }

    def __init__(self, name, *args, location=None, **op_attributes):
        super(ONNXOp, self).__init__(
            name,
Example #13
class SubArray(object):
    """
    Sub-arrays describe subsets of Arrays (see `dace::data::Array`) for purposes of distributed communication. They are
    implemented with [MPI_Type_create_subarray](https://www.mpich.org/static/docs/v3.2/www3/MPI_Type_create_subarray.html).
    Sub-arrays can also be used for collective scatter/gather communication in a process-grid.

    The `shape`, `subshape`, and `dtype` properties correspond to the `array_of_sizes`, `array_of_subsizes`, and
    `oldtype` parameters of `MPI_Type_create_subarray`.

    The following properties are used for collective scatter/gather communication in a process-grid:

    The `pgrid` property is the name of the process-grid where the data will be distributed. The `correspondence`
    property matches the array's dimensions to the process-grid's dimensions. For example, if one wants to distribute
    a matrix to a 2D process-grid, but tile the matrix rows over the grid's columns, then `correspondence = [1, 0]`.
    """

    name = Property(dtype=str, desc="The type's name.")
    dtype = TypeClassProperty(default=dtypes.int32, choices=dtypes.Typeclasses)
    shape = ShapeProperty(default=[], desc="The array's shape.")
    subshape = ShapeProperty(default=[], desc="The sub-array's shape.")
    pgrid = Property(
        dtype=str,
        allow_none=True,
        default=None,
        desc="Name of the process-grid where the data are distributed.")
    correspondence = ListProperty(
        int,
        allow_none=True,
        default=None,
        desc="Correspondence of the array's indices to the process grid's "
        "indices.")

    def __init__(self,
                 name: str,
                 dtype: dtypes.typeclass,
                 shape: ShapeType,
                 subshape: ShapeType,
                 pgrid: str = None,
                 correspondence: Sequence[Integral] = None):
        self.name = name
        self.dtype = dtype
        self.shape = shape
        self.subshape = subshape
        self.pgrid = pgrid
        self.correspondence = correspondence or list(range(len(shape)))
        self._validate()

    def validate(self):
        """ Validate the correctness of this object.
            Raises an exception on error. """
        self._validate()

    # Validation of this class is in a separate function, so that this
    # class can call `_validate()` without calling the subclasses'
    # `validate` function.
    def _validate(self):
        if any(not isinstance(s, (Integral, symbolic.SymExpr, symbolic.symbol,
                                  symbolic.sympy.Basic)) for s in self.shape):
            raise TypeError(
                'Shape must be a list or tuple of integer values or symbols')
        if any(not isinstance(s, (Integral, symbolic.SymExpr, symbolic.symbol,
                                  symbolic.sympy.Basic))
               for s in self.subshape):
            raise TypeError(
                'Sub-shape must be a list or tuple of integer values or symbols'
            )
        if any(not isinstance(i, Integral) for i in self.correspondence):
            raise TypeError(
                'Correspondence must be a list or tuple of integer values')
        if len(self.shape) != len(self.subshape):
            raise ValueError(
                'The dimensionality of the shape and sub-shape must match')
        if len(self.correspondence) != len(self.shape):
            raise ValueError(
                'The dimensionality of the shape and correspondence list must match'
            )
        return True

    def to_json(self):
        attrs = serialize.all_properties_to_json(self)
        retdict = {"type": type(self).__name__, "attributes": attrs}
        return retdict

    @classmethod
    def from_json(cls, json_obj, context=None):
        # Create dummy object
        ret = cls('tmp', dtypes.int8, [], [], 'tmp', [])
        serialize.set_properties_from_json(ret, json_obj, context=context)
        # Check validity now
        ret.validate()
        return ret

    def init_code(self):
        """ Outputs MPI allocation/initialization code for the sub-array MPI datatype ONLY if the process-grid is set.
            It is assumed that the following variables exist in the SDFG program's state:
            - MPI_Datatype {self.name}
            - int* {self.name}_counts
            - int* {self.name}_displs

            These variables are typically added to the program's state through a Tasklet, e.g., the Dummy MPI node (for
            more details, check the DaCe MPI library in `dace/libraries/mpi`).
        """
        from dace.libraries.mpi import utils
        if self.pgrid:
            return f"""
                if (__state->{self.pgrid}_valid) {{
                    int sizes[{len(self.shape)}] = {{{', '.join([str(s) for s in self.shape])}}};
                    int subsizes[{len(self.shape)}] = {{{', '.join([str(s) for s in self.subshape])}}};
                    int corr[{len(self.shape)}] = {{{', '.join([str(i) for i in self.correspondence])}}};

                    int basic_stride = subsizes[{len(self.shape)} - 1];

                    int process_strides[{len(self.shape)}];
                    int block_strides[{len(self.shape)}];
                    int data_strides[{len(self.shape)}];

                    process_strides[{len(self.shape)} - 1] = 1;
                    block_strides[{len(self.shape)} - 1] = subsizes[{len(self.shape)} - 1];
                    data_strides[{len(self.shape)} - 1] = 1;

                    for (auto i = {len(self.shape)} - 2; i >= 0; --i) {{
                        block_strides[i] = block_strides[i+1] * subsizes[i];
                        process_strides[i] = process_strides[i+1] * __state->{self.pgrid}_dims[corr[i+1]];
                        data_strides[i] = block_strides[i] * process_strides[i] / basic_stride;
                    }}

                    MPI_Datatype type;
                    int origin[{len(self.shape)}] = {{{','.join(['0'] * len(self.shape))}}};
                    MPI_Type_create_subarray({len(self.shape)}, sizes, subsizes, origin, MPI_ORDER_C, {utils.MPI_DDT(self.dtype.base_type)}, &type);
                    MPI_Type_create_resized(type, 0, basic_stride*sizeof({self.dtype.ctype}), &__state->{self.name});
                    MPI_Type_commit(&__state->{self.name});

                    __state->{self.name}_counts = new int[__state->{self.pgrid}_size];
                    __state->{self.name}_displs = new int[__state->{self.pgrid}_size];
                    int block_id[{len(self.shape)}] = {{0}};
                    int displ = 0;
                    for (auto i = 0; i < __state->{self.pgrid}_size; ++i) {{
                        __state->{self.name}_counts[i] = 1;
                        __state->{self.name}_displs[i] = displ;
                        int idx = {len(self.shape)} - 1;
                        while (idx >= 0 && block_id[idx] + 1 >= __state->{self.pgrid}_dims[corr[idx]]) {{
                            block_id[idx] = 0;
                            displ -= data_strides[idx] * (__state->{self.pgrid}_dims[corr[idx]] - 1);
                            idx--;
                        }}
                        if (idx >= 0) {{ 
                            block_id[idx] += 1;
                            displ += data_strides[idx];
                        }} else {{
                            assert(i == __state->{self.pgrid}_size - 1);
                        }}
                    }}
                }}
            """
        else:
            return ""

    def exit_code(self):
        """ Outputs MPI deallocation code for the sub-array MPI datatype ONLY if the process-grid is set. """
        if self.pgrid:
            return f"""
                if (__state->{self.pgrid}_valid) {{
                    delete[] __state->{self.name}_counts;
                    delete[] __state->{self.name}_displs;
                    MPI_Type_free(&__state->{self.name});
                }}
            """
        else:
            return ""
Example #14
class ProcessGrid(object):
    """
    Process-grids implement cartesian topologies similarly to cartesian communicators created with [MPI_Cart_create](https://www.mpich.org/static/docs/latest/www3/MPI_Cart_create.html)
    and [MPI_Cart_sub](https://www.mpich.org/static/docs/v3.2/www3/MPI_Cart_sub.html).

    The boolean property `is_subgrid` provides a switch between "parent" process-grids (equivalent to communicators
    created with `MPI_Cart_create`) and sub-grids (equivalent to communicators created with `MPI_Cart_sub`).
    
    If `is_subgrid` is false, a "parent" process-grid is created. The `shape` property is equivalent to the `dims`
    parameter of `MPI_Cart_create`. The other properties are ignored. All "parent" process-grids spawn out of
    `MPI_COMM_WORLD`, while their `periods` and `reorder` parameters are set to False.

    If `is_subgrid` is true, then the `parent_grid` is partitioned into lower-dimensional cartesian sub-grids (for more
    details, see the documentation of `MPI_Cart_sub`). The `parent_grid` property is equivalent to the `comm` parameter
    of `MPI_Cart_sub`. The `color` property corresponds to the `remain_dims` parameter of `MPI_Cart_sub`, i.e., the i-th
    entry specifies whether the i-th dimension is kept in the sub-grid or is dropped.
    
    The following properties store information used in the redistribution of data:

    The `exact_grid` property is either None or the rank of an MPI process in the `parent_grid`. If set, then out of
    all the sub-grids created, only the one that contains this rank is used for collective communication. The `root`
    property is used to select the root rank for purposes of collective communication (by default 0).
    """

    name = Property(dtype=str, desc="The process-grid's name.")
    is_subgrid = Property(
        dtype=bool,
        default=False,
        desc="If true, spanws sub-grids out of the parent process-grid.")
    shape = ShapeProperty(default=[], desc="The process-grid's shape.")
    parent_grid = Property(
        dtype=str,
        allow_none=True,
        default=None,
        desc="Name of the parent process-grid "
        "(mandatory if `is_subgrid` is true, otherwise ignored).")
    color = ListProperty(
        int,
        allow_none=True,
        default=None,
        desc=
        "The i-th entry specifies whether the i-th dimension is kept in the sub-grid or is "
        "dropped (mandatory if `is_subgrid` is true, otherwise ignored).")
    exact_grid = SymbolicProperty(
        allow_none=True,
        default=None,
        desc=
        "If set then, out of all the sub-grids created, only the one that contains the "
        "rank with id `exact_grid` will be utilized for collective communication "
        "(optional if `is_subgrid` is true, otherwise ignored).")
    root = SymbolicProperty(default=0,
                            desc="The root rank for collective communication.")

    def __init__(self,
                 name: str,
                 is_subgrid: bool,
                 shape: ShapeType = None,
                 parent_grid: str = None,
                 color: Sequence[Union[Integral, bool]] = None,
                 exact_grid: RankType = None,
                 root: RankType = 0):
        self.name = name
        self.is_subgrid = is_subgrid
        if is_subgrid:
            self.parent_grid = parent_grid.name
            self.color = color
            self.exact_grid = exact_grid
            self.shape = [
                parent_grid.shape[i] for i, remain in enumerate(color)
                if remain
            ]
        else:
            self.shape = shape
        self.root = root
        self._validate()

    def validate(self):
        """ Validate the correctness of this object.
            Raises an exception on error. """
        self._validate()

    # Validation of this class is in a separate function, so that this
    # class can call `_validate()` without calling the subclasses'
    # `validate` function.
    def _validate(self):
        if self.is_subgrid:
            if not self.parent_grid or len(self.parent_grid) == 0:
                raise ValueError(
                    'Sub-grid misses its corresponding parent process-grid')
        if any(not isinstance(s, (Integral, symbolic.SymExpr, symbolic.symbol,
                                  symbolic.sympy.Basic)) for s in self.shape):
            raise TypeError(
                'Shape must be a list or tuple of integer values or symbols')
        if self.color and any(c < 0 or c > 1 for c in self.color):
            raise ValueError(
                'Color must have only logical true (1) or false (0) values.')
        return True

    def to_json(self):
        attrs = serialize.all_properties_to_json(self)
        retdict = {"type": type(self).__name__, "attributes": attrs}
        return retdict

    @classmethod
    def from_json(cls, json_obj, context=None):
        # Create dummy object
        ret = cls('tmp', False, [])
        serialize.set_properties_from_json(ret, json_obj, context=context)
        # Check validity now
        ret.validate()
        return ret

    def init_code(self):
        """ Outputs MPI allocation/initialization code for the process-grid.
            It is assumed that the following variables exist in the SDFG program's state:
            - MPI_Comm {self.name}_comm
            - MPI_Group {self.name}_group
            - int {self.name}_rank
            - int {self.name}_size
            - int* {self.name}_dims
            - int* {self.name}_remain
            - int* {self.name}_coords
            - bool {self.name}_valid

            These variables are typically added to the program's state through a Tasklet, e.g., the Dummy MPI node (for
            more details, check the DaCe MPI library in `dace/libraries/mpi`).

        """
        if self.is_subgrid:
            tmp = ""
            for i, s in enumerate(self.shape):
                tmp += f"__state->{self.name}_dims[{i}] = {s};\n"
            tmp += f"""
                __state->{self.name}_valid = false;
                if (__state->{self.parent_grid}_valid) {{
                    int {self.name}_remain[{len(self.color)}] = {{{', '.join(['1' if c else '0' for c in self.color])}}};
                    MPI_Cart_sub(__state->{self.parent_grid}_comm, {self.name}_remain, &__state->{self.name}_comm);
                    MPI_Comm_group(__state->{self.name}_comm, &__state->{self.name}_group);
                    MPI_Comm_rank(__state->{self.name}_comm, &__state->{self.name}_rank);
                    MPI_Comm_size(__state->{self.name}_comm, &__state->{self.name}_size);
                    MPI_Cart_coords(__state->{self.name}_comm, __state->{self.name}_rank, {len(self.shape)}, __state->{self.name}_coords);
            """
            if self.exact_grid is not None:
                tmp += f"""
                    int ranks1[1] = {{{self.exact_grid}}};
                    int ranks2[1];
                    MPI_Group_translate_ranks(__state->{self.parent_grid}_group, 1, ranks1, __state->{self.name}_group, ranks2);
                    __state->{self.name}_valid = (ranks2[0] != MPI_PROC_NULL && ranks2[0] != MPI_UNDEFINED);
                }}
                """
            else:
                tmp += f"""
                    __state->{self.name}_valid = true;
                }}
                """
            return tmp
        else:
            tmp = ""
            for i, s in enumerate(self.shape):
                tmp += f"__state->{self.name}_dims[{i}] = {s};\n"
            tmp += f"""
                int {self.name}_periods[{len(self.shape)}] = {{0}};
                MPI_Cart_create(MPI_COMM_WORLD, {len(self.shape)}, __state->{self.name}_dims, {self.name}_periods, 0, &__state->{self.name}_comm);
                if (__state->{self.name}_comm != MPI_COMM_NULL) {{
                    MPI_Comm_group(__state->{self.name}_comm, &__state->{self.name}_group);
                    MPI_Comm_rank(__state->{self.name}_comm, &__state->{self.name}_rank);
                    MPI_Comm_size(__state->{self.name}_comm, &__state->{self.name}_size);
                    MPI_Cart_coords(__state->{self.name}_comm, __state->{self.name}_rank, {len(self.shape)}, __state->{self.name}_coords);
                    __state->{self.name}_valid = true;
                }} else {{
                    __state->{self.name}_group = MPI_GROUP_NULL;
                    __state->{self.name}_rank = MPI_PROC_NULL;
                    __state->{self.name}_size = 0;
                    __state->{self.name}_valid = false;
                }}
            """
            return tmp

    def exit_code(self):
        """ Outputs MPI deallocation code for the process-grid. """
        return f"""