def standardize_dtype_dict(dtypes):
    """Standardize the dtype dict as it can be specified for the stencil test suites.

    In the input mapping, each key is either a single field name or an iterable
    of field names (a group), and each value is either a single dtype or an
    iterable of candidate dtypes for that field/group.

    Parameters
    ----------
    dtypes : Mapping
        ``{field_name_or_names: dtype_or_dtypes}``. A single dtype may be given
        as a type (e.g. ``np.float64``) or as an ``np.dtype`` instance.

    Returns
    -------
    dict
        Mapping from tuples of field names to lists of ``np.dtype`` objects.
        Single field names are wrapped as 1-tuples and single dtypes as
        1-element lists.

    Raises
    ------
    AssertionError
        If ``dtypes`` is not a mapping, or contains an invalid key or value.
    ValueError
        If a field name appears more than once across all keys, including the
        case where ``"a"`` and ``("a",)`` are both given.
    """
    assert isinstance(dtypes, collections.abc.Mapping)
    assert all(
        isinstance(k, str) or gt_utils.is_iterable_of(k, str) for k in dtypes.keys()
    ), "Invalid key in 'dtypes'."
    assert all(
        isinstance(v, (type, np.dtype))
        or gt_utils.is_iterable_of(v, type)
        or gt_utils.is_iterable_of(v, np.dtype)
        for v in dtypes.values()
    ), "Invalid dtype in 'dtypes'"

    result = {}
    # Collect field names from every input key (not from ``result``), so keys
    # that normalize to the same tuple are still reported as duplicates
    # instead of silently overwriting each other.
    all_fields = []
    for key, value in dtypes.items():
        fields = (key,) if isinstance(key, str) else tuple(key)
        candidates = (value,) if isinstance(value, (type, np.dtype)) else tuple(value)
        all_fields.extend(fields)
        result[fields] = [np.dtype(dt) for dt in candidates]

    if len(all_fields) != len(set(all_fields)):
        raise ValueError("Any field can be in only one group.")
    return result
def test_normalize_shape():
    """Check ``normalize_shape`` on ``None``, a valid list, and invalid inputs."""
    from gt4py.storage.utils import normalize_shape

    assert normalize_shape(None) is None
    assert gt_utils.is_iterable_of(
        normalize_shape([1, 2, 3]), iterable_class=tuple, item_class=int
    )

    # test that exceptions are raised for invalid inputs.
    invalid_cases = (
        ("1", TypeError),
        (1, TypeError),
        ((0,), ValueError),
    )
    for bad_shape, expected_error in invalid_cases:
        try:
            normalize_shape(bad_shape)
        except expected_error:
            pass
        else:
            assert False
def normalize_default_origin(default_origin, mask=None):
    """Validate ``default_origin`` and drop the entries of masked-out dimensions.

    Parameters
    ----------
    default_origin : iterable of ints or None
        Non-negative origin values. They are given either for every entry of
        ``mask`` or only for its ``True`` entries.
    mask : iterable of bools or None
        Dimension mask; ``None`` keeps every entry of ``default_origin``.

    Returns
    -------
    tuple of int or None
        ``None`` if ``default_origin`` is None. Otherwise, the origin values of
        the unmasked dimensions as plain Python ints. Numpy integers are
        converted so the item type is the same whether or not filtering happened.

    Raises
    ------
    TypeError
        If ``mask`` is rejected by ``check_mask``, if ``default_origin`` has no
        length, or if it is not an iterable of ints.
    ValueError
        If the lengths of ``default_origin`` and ``mask`` do not match, or if
        ``default_origin`` contains a negative value.
    """
    check_mask(mask)
    if default_origin is None:
        return None
    if mask is None:
        mask = (True,) * len(default_origin)
    if len(default_origin) not in (sum(mask), len(mask)):
        raise ValueError(
            "len(default_origin) must be equal to len(mask) or the number of 'True' entries in mask."
        )
    if not gt_util.is_iterable_of(default_origin, numbers.Integral):
        raise TypeError("default_origin must be a tuple of ints or pairs of ints.")
    if any(o < 0 for o in default_origin):
        raise ValueError("default_origin ({}) contains negative value.".format(default_origin))
    if sum(mask) < len(default_origin):
        # One value per mask entry was given (len == len(mask) here): keep only
        # the values of the unmasked dimensions.
        return tuple(int(o) for o, keep in zip(default_origin, mask) if keep)
    return tuple(int(o) for o in default_origin)
def normalize_shape(shape, mask=None):
    """Validate ``shape`` and drop the entries of masked-out dimensions.

    Parameters
    ----------
    shape : iterable of ints or None
        Strictly positive extents. They are given either for every entry of
        ``mask`` or only for its ``True`` entries.
    mask : iterable of bools or None
        Dimension mask; ``None`` keeps every entry of ``shape``.

    Returns
    -------
    tuple of int or None
        ``None`` if ``shape`` is None. Otherwise, the extents of the unmasked
        dimensions as plain Python ints. Numpy integers are converted on every
        path, not only when filtering, so the item type is always ``int``.

    Raises
    ------
    TypeError
        If ``mask`` is rejected by ``check_mask``, if ``shape`` has no length,
        or if it is not an iterable of ints.
    ValueError
        If the lengths of ``shape`` and ``mask`` do not match, or if ``shape``
        contains a non-positive value.
    """
    check_mask(mask)
    if shape is None:
        return None
    if mask is None:
        mask = (True,) * len(shape)
    if len(shape) not in (sum(mask), len(mask)):
        raise ValueError(
            "len(shape) must be equal to len(mask) or the number of 'True' entries in mask."
        )
    if not gt_util.is_iterable_of(shape, numbers.Integral):
        raise TypeError("shape must be a tuple of ints or pairs of ints.")
    if any(o <= 0 for o in shape):
        raise ValueError("shape ({}) contains non-positive value.".format(shape))
    if sum(mask) < len(shape):
        # One extent per mask entry was given (len == len(mask) here): keep
        # only the extents of the unmasked dimensions.
        return tuple(int(h) for h, keep in zip(shape, mask) if keep)
    return tuple(int(h) for h in shape)
def test_normalize_default_origin():
    """Check ``normalize_default_origin`` on ``None``, a valid list, and invalid inputs."""
    # Previously this asserted on ``normalize_shape(None)`` (a copy-paste from
    # the shape test), so the ``None`` case of the function under test was
    # never exercised.
    assert gt_storage_utils.normalize_default_origin(None) is None
    assert gt_utils.is_iterable_of(
        gt_storage_utils.normalize_default_origin([1, 2, 3]),
        iterable_class=gt_ir.Index,
        item_class=int,
    )

    # test that exceptions are raised for invalid inputs.
    invalid_cases = (
        ("1", TypeError),
        (1, TypeError),
        ((-1,), ValueError),
    )
    for bad_origin, expected_error in invalid_cases:
        try:
            gt_storage_utils.normalize_default_origin(bad_origin)
        except expected_error:
            pass
        else:
            assert False, f"{bad_origin!r} did not raise {expected_error.__name__}"
def _validate_new_args(cls, cls_name, bases, cls_dict):
    """Validate the members of a stencil test-suite class definition.

    Parameters
    ----------
    cls : object
        Object providing ``required_members`` (a set of member names).
    cls_name : str
        Name of the class being defined; a "1D"/"2D"/"3D" suffix is checked
        against ``cls_dict["ndims"]``.
    bases : tuple
        Base classes (not inspected here).
    cls_dict : Mapping
        Class namespace holding the suite members.

    Raises
    ------
    TypeError
        If required members are missing, or if ``definition``/``validation``
        are not plain Python functions.
    AssertionError
        If ``symbols``, ``domain_range``, ``ndims``, ``dtypes`` or ``backends``
        are invalid. The checks run in that order.
    """
    absent = cls.required_members - cls_dict.keys()
    if absent:
        raise TypeError(
            "Missing {missing} required members in '{name}' definition".format(
                missing=absent, name=cls_name
            )
        )

    domain = cls_dict["domain_range"]
    backend_names = cls_dict["backends"]

    assert isinstance(cls_dict["symbols"], collections.abc.Mapping), "Invalid 'symbols' mapping"

    # Between 1 and 3 dimensions, each described by a 2-element range.
    assert 1 <= len(domain) <= 3 and all(
        len(bounds) == 2 for bounds in domain
    ), "Invalid 'domain_range' definition"

    # A suite named e.g. "Foo2D" must declare the matching number of dimensions.
    if cls_name[-2:] in ("1D", "2D", "3D"):
        assert cls_dict["ndims"] == int(cls_name[-2]), "Suite name does not match the actual 'ndims'"

    assert isinstance(
        cls_dict["dtypes"], (collections.abc.Sequence, collections.abc.Mapping)
    ), "'dtypes' must be a sequence or a mapping object"

    assert gt_utils.is_iterable_of(backend_names, str), "'backends' must be a sequence of strings"
    for name in backend_names:
        assert name in gt.backend.REGISTRY.names, "backend '{backend}' not supported".format(
            backend=name
        )

    function_members = (
        ("definition", "The 'definition' attribute must be a stencil definition function"),
        ("validation", "The 'validation' attribute must be a validation function"),
    )
    for member, message in function_members:
        if not isinstance(cls_dict[member], types.FunctionType):
            raise TypeError(message)
def check_mask(mask):
    """Raise if ``mask`` is neither ``None`` nor an iterable of booleans.

    Raises
    ------
    TypeError
        If ``mask`` is not ``None`` and ``gt_util.is_iterable_of(mask, bool)``
        is false.
    """
    # Test for None first (PEP 8: ``is not``); the helper is then only
    # consulted for real mask values.
    if mask is not None and not gt_util.is_iterable_of(mask, bool):
        raise TypeError("Mask must be an iterable of booleans.")
def _normalize_spec_mask(mask, shape):
    """Return ``mask`` as a tuple of at least 3 bools, deriving it from ``shape`` if None."""
    if mask is None:
        # Validate shape here so a missing/invalid shape gives the same
        # TypeError as below instead of a bare ``len(None)`` failure.
        if shape is None or not gt_util.is_iterable_of(shape, numbers.Integral):
            raise TypeError("shape must be an iterable of ints.")
        # The leading len(shape) dimensions are active; pad with inactive
        # dimensions up to 3.
        ndim = len(shape)
        return tuple(i < ndim for i in range(max(ndim, 3)))

    if not gt_util.is_iterable_of(mask, bool):
        # User-friendly axes specification (e.g. "IJK" or gtscript.IJK)
        if isinstance(mask, (tuple, list)):
            str_kind = "".join(str(axis) for axis in mask)
        else:
            str_kind = str(mask)
        axes_set = set(str_kind)
        if axes_set - {"I", "J", "K"}:
            raise ValueError(f"Invalid axes names in mask specification: '{mask}'")
        if len(axes_set) != len(str_kind):
            raise ValueError(f"Repeated axes names in mask specification: '{mask}'")
        return ("I" in axes_set, "J" in axes_set, "K" in axes_set)

    if len(mask) < 3 or not sum(mask):
        raise ValueError(f"Invalid mask definition: '{mask}'")
    # Always hand back a tuple, even when a list of bools was passed in.
    return tuple(mask)


def _select_spec_values(name, values, mask):
    """Validate an int iterable against ``mask`` and return its unmasked entries as ints."""
    if values is None or not gt_util.is_iterable_of(values, numbers.Integral):
        raise TypeError(f"{name} must be an iterable of ints.")
    n_active = sum(mask)
    if len(values) not in (n_active, len(mask)):
        raise ValueError(
            f"Mask ({mask}) and {name} ({values}) have non-matching sizes. "
            f"len({name})(={len(values)}) must be equal to len(mask)(={len(mask)}) "
            f"or the number of 'True' entries in mask '{mask}'."
        )
    if n_active < len(values):
        # One value per mask entry was given (len == len(mask) here): drop
        # the values of masked-out dimensions.
        return tuple(int(v) for v, keep in zip(values, mask) if keep)
    return tuple(int(v) for v in values)


def normalize_storage_spec(default_origin, shape, dtype, mask):
    """Normalize the fields of the storage spec in a homogeneous representation.

    Parameters
    ----------
    default_origin : iterable of ints
        Non-negative origin values. They are given either for every entry of
        the mask or only for its ``True`` entries.
    shape : iterable of ints
        Strictly positive extents, following the same length rule as ``default_origin``.
    dtype : dtype-like
        Anything accepted by ``numpy.dtype``.
    mask : iterable of bools, axes specification, or None
        One of the following:
        - booleans (at least 3, at least one ``True``);
        - an axes specification such as ``"IJK"`` (names I/J/K only, no
          repetitions);
        - ``None``, which marks the first ``len(shape)`` dimensions as active.

    Returns
    -------
    tuple(default_origin, shape, dtype, mask)
        The output tuple fields verify the following semantics:
            - default_origin: tuple of ints with default origin values for the non-masked dimensions
            - shape: tuple of ints with shape values for the non-masked dimensions
            - dtype: scalar numpy.dtype (non-structured and without subarrays)
            - mask: a tuple of bools (at least 3d)
        For a subarray ``dtype``, the subarray dimensions are appended to
        ``shape`` with origin 0 and mask ``True``.

    Raises
    ------
    TypeError
        If ``shape`` or ``default_origin`` is missing or not an iterable of ints.
    ValueError
        If the mask is invalid, or if a length does not match the mask, or if
        ``shape`` has a non-positive value, or if ``default_origin`` has a
        negative value.
    """
    mask = _normalize_spec_mask(mask, shape)
    assert len(mask) >= 3

    shape = _select_spec_values("shape", shape, mask)
    if any(d <= 0 for d in shape):
        raise ValueError(f"shape ({shape}) contains non-positive value.")

    default_origin = _select_spec_values("default_origin", default_origin, mask)
    if any(o < 0 for o in default_origin):
        raise ValueError("default_origin ({}) contains negative value.".format(default_origin))

    dtype = np.dtype(dtype)
    if dtype.shape:
        # Subarray dtype: turn its dimensions into extra, always-active
        # storage dimensions with origin 0.
        default_origin = (*default_origin, *((0,) * dtype.ndim))
        shape = (*shape, *dtype.subdtype[1])
        mask = (*mask, *((True,) * dtype.ndim))
        dtype = dtype.subdtype[0]

    return default_origin, shape, dtype, mask