Ejemplo n.º 1
0
def standardize_dtype_dict(dtypes):
    """Standardizes the dtype dict as it can be specified for the stencil test suites.

    In the input dictionary, a selection of possible dtypes or just a single dtype can be
    specified for a set of fields or a single field. This function makes sure that all keys
    are tuples of field names (a single field name is wrapped as a 1-tuple) and that all
    values are lists of ``numpy.dtype`` objects (a single dtype is wrapped as a 1-element list).

    Parameters
    ----------
    dtypes : Mapping
        Keys are a field name (str) or an iterable of field names. Values are a single type,
        a single ``numpy.dtype`` instance, or an iterable of types / ``numpy.dtype`` objects.

    Returns
    -------
    dict
        Mapping ``tuple_of_field_names -> [numpy.dtype, ...]``.

    Raises
    ------
    AssertionError
        If ``dtypes`` is not a mapping or contains invalid keys/values (these checks are
        ``assert`` statements and are therefore skipped under ``python -O``).
    ValueError
        If a field name appears more than once, within one group or across groups
        (including the same field given once as ``"a"`` and once as ``("a",)``).
    """
    assert isinstance(dtypes, collections.abc.Mapping)
    assert all((isinstance(k, str) or gt_utils.is_iterable_of(k, str))
               for k in dtypes.keys()), "Invalid key in 'dtypes'."
    assert all((isinstance(v, (type, np.dtype)) or gt_utils.is_iterable_of(v, type)
                or gt_utils.is_iterable_of(v, np.dtype))
               for v in dtypes.values()), "Invalid dtype in 'dtypes'"

    result = {}
    field_names = []  # every field name of every group, duplicates included
    for key, value in dtypes.items():
        fields = (key, ) if isinstance(key, str) else tuple(key)
        # A single dtype (a type or a numpy.dtype instance) forms a one-element selection.
        choices = (value, ) if isinstance(value, (type, np.dtype)) else tuple(value)
        result[fields] = [np.dtype(dt) for dt in choices]
        # Collected from the input items rather than from ``result``: "a" and ("a",)
        # map to the same ``result`` key and would otherwise hide the duplicate.
        field_names.extend(fields)

    if len(field_names) != len(set(field_names)):
        raise ValueError("Any field can be in only one group.")
    return result
Ejemplo n.º 2
0
def test_normalize_shape():
    """``normalize_shape`` passes ``None`` through, yields int tuples and rejects bad input."""
    from gt4py.storage.utils import normalize_shape

    assert normalize_shape(None) is None
    assert gt_utils.is_iterable_of(
        normalize_shape([1, 2, 3]), iterable_class=tuple, item_class=int
    )

    # Each invalid input must be rejected with exactly the listed exception type.
    invalid_inputs = (
        ("1", TypeError),  # a string is not an iterable of ints
        (1, TypeError),  # a scalar is not an iterable at all
        ((0, ), ValueError),  # sizes must be strictly positive
    )
    for bad_shape, expected_error in invalid_inputs:
        try:
            normalize_shape(bad_shape)
        except expected_error:
            continue
        assert False
Ejemplo n.º 3
0
def normalize_default_origin(default_origin, mask=None):
    """Validate ``default_origin`` and drop the entries of masked-out dimensions.

    Parameters
    ----------
    default_origin : iterable of ints or None
        Either one entry per dimension of ``mask`` or one entry per ``True`` entry of ``mask``.
    mask : iterable of bools, optional
        Dimension mask (validated with ``check_mask``); ``None`` keeps every dimension.

    Returns
    -------
    tuple of int or None
        ``None`` if ``default_origin`` is ``None``; otherwise the non-negative origin values
        of the non-masked dimensions, always as plain Python ints.

    Raises
    ------
    TypeError
        If ``mask`` is rejected by ``check_mask``, ``default_origin`` has no length, or
        ``default_origin`` is not an iterable of ints.
    ValueError
        If ``len(default_origin)`` matches neither ``len(mask)`` nor the number of ``True``
        entries in ``mask``, or if any value is negative.
    """
    check_mask(mask)

    if default_origin is None:
        return None
    if mask is None:
        mask = (True, ) * len(default_origin)

    if sum(mask) != len(default_origin) and len(mask) != len(default_origin):
        raise ValueError(
            "len(default_origin) must be equal to len(mask) or the number of 'True' entries in mask."
        )

    if not gt_util.is_iterable_of(default_origin, numbers.Integral):
        raise TypeError("default_origin must be an iterable of ints.")
    if any(o < 0 for o in default_origin):
        raise ValueError("default_origin ({}) contains negative value.".format(
            default_origin))

    if sum(mask) < len(default_origin):
        # One entry per mask dimension was given (the length check above guarantees
        # len(default_origin) == len(mask) here): keep only the selected dimensions.
        default_origin = [o for o, keep in zip(default_origin, mask) if keep]

    # Convert on every path, not only after filtering, so the result is always plain ints.
    return tuple(int(o) for o in default_origin)
Ejemplo n.º 4
0
def normalize_shape(shape, mask=None):
    """Validate ``shape`` and drop the entries of masked-out dimensions.

    Parameters
    ----------
    shape : iterable of ints or None
        Either one entry per dimension of ``mask`` or one entry per ``True`` entry of ``mask``.
    mask : iterable of bools, optional
        Dimension mask (validated with ``check_mask``); ``None`` keeps every dimension.

    Returns
    -------
    tuple of int or None
        ``None`` if ``shape`` is ``None``; otherwise the strictly positive sizes of the
        non-masked dimensions, always as plain Python ints.

    Raises
    ------
    TypeError
        If ``mask`` is rejected by ``check_mask``, ``shape`` has no length, or ``shape`` is
        not an iterable of ints.
    ValueError
        If ``len(shape)`` matches neither ``len(mask)`` nor the number of ``True`` entries in
        ``mask``, or if any size is not positive.
    """
    check_mask(mask)

    if shape is None:
        return None
    if mask is None:
        mask = (True, ) * len(shape)

    if sum(mask) != len(shape) and len(mask) != len(shape):
        raise ValueError(
            "len(shape) must be equal to len(mask) or the number of 'True' entries in mask."
        )

    if not gt_util.is_iterable_of(shape, numbers.Integral):
        raise TypeError("shape must be an iterable of ints.")
    if any(o <= 0 for o in shape):
        raise ValueError(
            "shape ({}) contains non-positive value.".format(shape))

    if sum(mask) < len(shape):
        # One entry per mask dimension was given (the length check above guarantees
        # len(shape) == len(mask) here): keep only the selected dimensions.
        shape = [s for s, keep in zip(shape, mask) if keep]

    # Convert on every path, not only after filtering, so the result is always plain ints.
    return tuple(int(s) for s in shape)
Ejemplo n.º 5
0
def test_normalize_default_origin():
    """``normalize_default_origin`` passes ``None`` through, yields indices and rejects bad input."""
    # ``None`` must be returned unchanged by the function under test itself.
    assert gt_storage_utils.normalize_default_origin(None) is None
    assert gt_utils.is_iterable_of(
        gt_storage_utils.normalize_default_origin([1, 2, 3]),
        iterable_class=gt_ir.Index,
        item_class=int,
    )

    # test that exceptions are raised for invalid inputs.
    try:
        gt_storage_utils.normalize_default_origin("1")
    except TypeError:
        pass
    else:
        assert False

    try:
        gt_storage_utils.normalize_default_origin(1)
    except TypeError:
        pass
    else:
        assert False

    try:
        gt_storage_utils.normalize_default_origin((-1, ))
    except ValueError:
        pass
    else:
        assert False
Ejemplo n.º 6
0
    def _validate_new_args(cls, cls_name, bases, cls_dict):
        """Check that the class dict of a new suite definition provides valid members.

        Missing required members and non-function ``definition`` / ``validation``
        attributes raise ``TypeError``; every other check is an ``assert`` statement and
        is therefore skipped when Python runs with ``-O``.
        """
        missing_members = cls.required_members - cls_dict.keys()
        if missing_members:
            raise TypeError(
                "Missing {missing} required members in '{name}' definition".format(
                    missing=missing_members, name=cls_name))

        # Read up front, before any of the checks below run.
        domain_range = cls_dict["domain_range"]
        backends = cls_dict["backends"]

        # 'symbols' must be a mapping.
        assert isinstance(cls_dict["symbols"],
                          collections.abc.Mapping), "Invalid 'symbols' mapping"

        # 1 to 3 dimensions, each described by a 2-element range.
        assert 1 <= len(domain_range) <= 3 and all(
            len(axis_range) == 2 for axis_range in domain_range
        ), "Invalid 'domain_range' definition"

        # A "1D"/"2D"/"3D" suffix in the suite name must agree with 'ndims'.
        if cls_name[-2:] in ("1D", "2D", "3D"):
            assert cls_dict["ndims"] == int(cls_name[-2]), (
                "Suite name does not match the actual 'ndims'")

        # 'dtypes' is either a sequence or a mapping.
        assert isinstance(
            cls_dict["dtypes"], (collections.abc.Sequence, collections.abc.Mapping)
        ), "'dtypes' must be a sequence or a mapping object"

        # 'backends' holds names, each of which must be a registered backend.
        assert gt_utils.is_iterable_of(
            backends, str), "'backends' must be a sequence of strings"
        for backend_name in backends:
            assert backend_name in gt.backend.REGISTRY.names, (
                "backend '{backend}' not supported".format(backend=backend_name))

        # 'definition' and 'validation' must both be plain Python functions.
        function_members = (
            ("definition", "The 'definition' attribute must be a stencil definition function"),
            ("validation", "The 'validation' attribute must be a validation function"),
        )
        for member_name, error_message in function_members:
            if not isinstance(cls_dict[member_name], types.FunctionType):
                raise TypeError(error_message)
Ejemplo n.º 7
0
def check_mask(mask):
    """Validate an optional dimension mask.

    ``None`` is accepted and means "no mask"; any other value must be an iterable of bools.

    Raises
    ------
    TypeError
        If ``mask`` is neither ``None`` nor an iterable of bools.
    """
    if mask is not None and not gt_util.is_iterable_of(mask, bool):
        raise TypeError("Mask must be an iterable of booleans.")
Ejemplo n.º 8
0
def normalize_storage_spec(default_origin, shape, dtype, mask):
    """Normalize the fields of the storage spec in a homogeneous representation.

    Parameters
    ----------
    default_origin : iterable of ints
        One value per dimension of ``mask`` or one value per ``True`` entry of ``mask``.
    shape : iterable of ints
        One value per dimension of ``mask`` or one value per ``True`` entry of ``mask``.
    dtype : anything accepted by ``numpy.dtype``
        A sub-array dtype is split into its base dtype plus extra trailing dimensions.
    mask : iterable of bools, axes specification (e.g. ``"IJK"`` or gtscript.IJK) or None
        If ``None``, the first ``len(shape)`` of at least 3 dimensions are enabled.

    Returns
    -------

    tuple(default_origin, shape, dtype, mask)
        The output tuple fields verify the following semantics:

            - default_origin: tuple of ints with default origin values for the non-masked dimensions
            - shape: tuple of ints with shape values for the non-masked dimensions
            - dtype: numpy.dtype; for a sub-array input dtype this is its base dtype and the
              sub-array dimensions are appended to default_origin, shape and mask
            - mask: a tuple of bools (at least 3d)

    Raises
    ------
    TypeError
        If ``shape`` or ``default_origin`` is ``None`` or not an iterable of ints.
    ValueError
        For an invalid mask, lengths not matching the mask, non-positive shape values or
        negative default origin values.
    """

    if mask is None:
        # The default mask is derived from len(shape), so shape has to be validated first.
        if shape is None or not gt_util.is_iterable_of(shape, numbers.Integral):
            raise TypeError("shape must be an iterable of ints.")
        mask = tuple(i < len(shape) for i in range(max(len(shape), 3)))
    elif not gt_util.is_iterable_of(mask, bool):
        # User-friendly axes specification (e.g. "IJK" or gtscript.IJK)
        str_kind = ("".join(str(i) for i in mask)
                    if isinstance(mask, (tuple, list)) else str(mask))
        axes_set = set(str_kind)
        if axes_set - {"I", "J", "K"}:
            raise ValueError(
                f"Invalid axes names in mask specification: '{mask}'")
        if len(axes_set) != len(str_kind):
            raise ValueError(
                f"Repeated axes names in mask specification: '{mask}'")
        mask = ("I" in axes_set, "J" in axes_set, "K" in axes_set)
    elif len(mask) < 3 or not sum(mask):
        raise ValueError(f"Invalid mask definition: '{mask}'")
    # Boolean masks may be given as lists; the documented result is a tuple.
    mask = tuple(mask)

    assert len(mask) >= 3

    if shape is None or not gt_util.is_iterable_of(shape, numbers.Integral):
        raise TypeError("shape must be an iterable of ints.")
    if len(shape) not in (sum(mask), len(mask)):
        raise ValueError(
            f"Mask ({mask}) and shape ({shape}) have non-matching sizes. "
            f"len(shape)(={len(shape)}) must be equal to len(mask)(={len(mask)}) "
            f"or the number of 'True' entries in mask '{mask}'.")
    if sum(mask) < len(shape):
        # One entry per mask dimension (here len(shape) == len(mask)): keep the enabled ones.
        shape = [d for d, keep in zip(shape, mask) if keep]
    # Convert on every path so the result always holds plain Python ints.
    shape = tuple(int(d) for d in shape)
    if any(i <= 0 for i in shape):
        raise ValueError(f"shape ({shape}) contains non-positive value.")

    if default_origin is None or not gt_util.is_iterable_of(
            default_origin, numbers.Integral):
        raise TypeError("default_origin must be an iterable of ints.")
    if len(default_origin) not in (sum(mask), len(mask)):
        raise ValueError(
            f"Mask ({mask}) and default_origin ({default_origin}) have non-matching sizes. "
            f"len(default_origin)(={len(default_origin)}) must be equal to len(mask)(={len(mask)}) "
            f"or the number of 'True' entries in mask '{mask}'.")
    if sum(mask) < len(default_origin):
        # One entry per mask dimension: keep the enabled ones.
        default_origin = [d for d, keep in zip(default_origin, mask) if keep]
    default_origin = tuple(int(d) for d in default_origin)
    if any(i < 0 for i in default_origin):
        raise ValueError(
            "default_origin ({}) contains negative value.".format(
                default_origin))

    dtype = np.dtype(dtype)
    if dtype.shape:
        # Subarray dtype: its dimensions become extra trailing, always-enabled storage
        # dimensions (origin 0), and only the base dtype is kept.
        default_origin = (*default_origin, *((0, ) * dtype.ndim))
        shape = (*shape, *(dtype.subdtype[1]))
        mask = (*mask, *((True, ) * dtype.ndim))
        dtype = dtype.subdtype[0]

    return default_origin, shape, dtype, mask