def _make_origin_dict(origin: Any) -> Dict[str, Index]:
    """Normalize a user-provided ``origin`` argument into a mapping.

    Parameters
    ----------
    origin:
        One of:

        * a ``dict`` mapping field names (or the special key ``"_all_"``) to origins;
          a shallow copy is returned,
        * ``None``, which yields an empty mapping,
        * an iterable, converted with ``Index.from_value`` and stored under ``"_all_"``,
        * an ``int``, converted with ``Index.from_k`` and stored under ``"_all_"``.

    Returns
    -------
    A new dictionary mapping names to origins.

    Raises
    ------
    ValueError
        If ``origin`` is of an unsupported type or cannot be converted. When the
        conversion itself failed, the original exception is chained as ``__cause__``.
    """
    message = "Invalid 'origin' value ({})".format(origin)
    try:
        if isinstance(origin, dict):
            return dict(origin)
        if origin is None:
            return {}
        if isinstance(origin, collections.abc.Iterable):
            return {"_all_": Index.from_value(origin)}
        if isinstance(origin, int):
            return {"_all_": Index.from_k(origin)}
    except Exception as err:
        # Keep the underlying conversion error for diagnostics instead of discarding it.
        raise ValueError(message) from err
    # None of the supported input kinds matched.
    raise ValueError(message)
def normalize_default_origin(
    default_origin: Optional[Sequence[int]], mask: Optional[Sequence[bool]] = None
) -> Optional[Index]:
    """Validate ``default_origin`` against ``mask`` and convert it to an `Index`.

    Parameters
    ----------
    default_origin:
        Non-negative integer offsets. The length must equal either ``len(mask)``
        (one entry per axis) or the number of ``True`` entries in ``mask`` (one
        entry per selected axis). ``None`` is returned unchanged.
    mask:
        Optional axis-selection flags, validated by ``check_mask``. When omitted,
        every entry of ``default_origin`` is treated as selected.

    Returns
    -------
    ``None`` if ``default_origin`` is ``None``. Otherwise an `Index` containing only
    the entries of ``default_origin`` whose axes are selected by ``mask``.

    Raises
    ------
    ValueError
        If the lengths of ``default_origin`` and ``mask`` are incompatible, or if
        ``default_origin`` contains a negative value.
    TypeError
        If ``default_origin`` is not an iterable of integers.
    """
    check_mask(mask)
    if default_origin is None:
        return None

    if mask is None:
        mask = (True,) * len(default_origin)
    n_selected = sum(mask)
    n_origin = len(default_origin)

    # Accept a full-length origin (one value per mask axis) or a compact one
    # (one value per selected axis).
    if not (n_selected == n_origin or len(mask) == n_origin):
        raise ValueError(
            "len(default_origin) must be equal to len(mask) or the number of 'True' entries in mask."
        )
    if not gt_util.is_iterable_of(default_origin, numbers.Integral):
        raise TypeError("default_origin must be a tuple of ints or pairs of ints.")
    if any(value < 0 for value in default_origin):
        raise ValueError("default_origin ({}) contains negative value.".format(default_origin))

    values = list(default_origin)
    if sum(mask) < n_origin:
        # Full-length origin given: keep only the entries of selected axes.
        values = [value for value, selected in zip(values, mask) if selected]
    return Index(values)
def _get_max_domain(
    self,
    field_args: Dict[str, Any],
    origin: Dict[str, Tuple[int, ...]],
    *,
    squeeze: bool = True,
) -> Shape:
    """Return the maximum domain size possible.

    Parameters
    ----------
    field_args:
        Mapping from field names to actually passed data arrays.
    origin:
        The origin for each field.
    squeeze:
        Convert non-used domain dimensions to singleton dimensions.

    Returns
    -------
    `Shape`: the maximum domain size.

    Raises
    ------
    ValueError
        If a field used by the stencil is missing from ``field_args`` (or is
        ``None``), or if a `Storage` argument's domain mask does not match the
        API signature.
    """
    domain_ndim = self.domain_info.ndim
    max_size = sys.maxsize
    max_domain = Shape([max_size] * domain_ndim)

    for name, field_info in self.field_info.items():
        # Fields without info are not used by the stencil and impose no limit.
        if field_info is None:
            continue

        # Explicit checks instead of ``assert``: asserts are stripped under ``python -O``.
        field = field_args.get(name, None)
        if field is None:
            raise ValueError(f"Invalid value for '{name}' field.")

        api_domain_mask = field_info.domain_mask
        api_domain_ndim = field_info.domain_ndim
        if (
            isinstance(field, gt_storage.storage.Storage)
            and tuple(field.mask)[:domain_ndim] != api_domain_mask
        ):
            raise ValueError(
                f"Storage for '{name}' has domain mask '{field.mask}' but the API signature "
                f"expects '[{', '.join(field_info.axes)}]'"
            )

        upper_indices = field_info.boundary.upper_indices.filter_mask(api_domain_mask)
        field_origin = Index.from_value(origin[name])
        # Extent left along each of the field's axes once its origin and upper
        # boundary are subtracted from the array shape.
        field_domain = tuple(
            field.shape[i] - (field_origin[i] + upper_indices[i]) for i in range(api_domain_ndim)
        )
        # Axes outside the field's mask are presumably filled with ``max_size`` by
        # ``from_mask``, so they do not constrain the combined bound.
        max_domain &= Shape.from_mask(field_domain, api_domain_mask, default=max_size)

    if squeeze:
        # Axes still at ``max_size`` were not constrained by any field.
        return Shape([i if i != max_size else 1 for i in max_domain])
    return max_domain
def _get_max_domain(self, field_args, origin):
    """Return the maximum domain size possible.

    Only fields the stencil actually uses (i.e. present in ``self.field_info`` with
    non-``None`` info) contribute. Arguments for unused fields are ignored, so
    they need neither field info nor an entry in ``origin``.

    Parameters
    ----------
    field_args: `dict`
        Mapping from field names to actually passed data arrays.

    origin: `{'field_name': [int * ndims]}`
        The origin for each used field.

    Returns
    -------
    `Shape`: the maximum domain size.
    """
    max_domain = Shape([np.iinfo(np.uintc).max] * self.domain_info.ndims)
    for name, field in field_args.items():
        field_info = self.field_info.get(name)
        if field_info is None:
            # Unused field: it has no boundary information to constrain the domain.
            continue
        upper_boundary = Index(field_info.boundary.upper_indices)
        max_domain &= Shape(field.shape) - (Index(origin[name]) + upper_boundary)
    return max_domain
def _get_max_domain(
    self,
    field_args: Dict[str, Any],
    origin: Dict[str, Tuple[int, ...]],
    *,
    squeeze: bool = True,
) -> Shape:
    """Return the maximum domain size possible.

    Parameters
    ----------
    field_args:
        Mapping from field names to actually passed data arrays.
    origin:
        The origin for each field.
    squeeze:
        Convert non-used domain dimensions to singleton dimensions.

    Returns
    -------
    `Shape`: the maximum domain size.

    Raises
    ------
    ValueError
        If a field accessed by the stencil is missing from ``field_args`` (or is ``None``).
    """
    domain_ndim = self.domain_info.ndim
    max_size = sys.maxsize
    max_domain = Shape([max_size] * domain_ndim)

    for name, field_info in self.field_info.items():
        # Fields the stencil never accesses impose no limit.
        if field_info.access == AccessKind.NONE:
            continue

        # Explicit check instead of ``assert``: asserts are stripped under ``python -O``.
        field = field_args.get(name, None)
        if field is None:
            raise ValueError(f"Invalid value for '{name}' field.")

        api_domain_mask = field_info.domain_mask
        api_domain_ndim = field_info.domain_ndim
        upper_indices = field_info.boundary.upper_indices.filter_mask(api_domain_mask)
        field_origin = Index.from_value(origin[name])
        # Extent left along each of the field's axes once its origin and upper
        # boundary are subtracted from the array shape.
        field_domain = tuple(
            field.shape[i] - (field_origin[i] + upper_indices[i]) for i in range(api_domain_ndim)
        )
        # Axes outside the field's mask are presumably filled with ``max_size`` by
        # ``from_mask``, so they do not constrain the combined bound.
        max_domain &= Shape.from_mask(field_domain, api_domain_mask, default=max_size)

    if squeeze:
        # Axes still at ``max_size`` were not constrained by any field.
        return Shape([i if i != max_size else 1 for i in max_domain])
    return max_domain
# # GT4Py - GridTools4Py - GridTools for Python # # Copyright (c) 2014-2021, ETH Zurich # All rights reserved. # # This file is part the GT4Py project and the GridTools framework. # GT4Py is free software: you can redistribute it and/or modify it under # the terms of the GNU General Public License as published by the # Free Software Foundation, either version 3 of the License, or any later # version. See the LICENSE.txt file at the top-level directory of this # distribution for a copy of the license or check <https://www.gnu.org/licenses/>. # # SPDX-License-Identifier: GPL-3.0-or-later """ Implementation of the intermediate representations used in GT4Py. ----------- Definitions ----------- Empty Empty node value (`None` is a valid Python value) InvalidBranch Sentinel value for wrongly built conditional expressions Builtin enumeration (:class:`Builtin`) Named Python constants [`NONE`, `FALSE`, `TRUE`]
class StencilObject(abc.ABC):
    """Generic singleton implementation of a stencil callable.

    This class is used as the base class for the specific subclasses generated
    at run-time for any stencil definition and a unique set of external symbols.
    Instances of this class do not contain state and thus it is
    implemented as a singleton: only one instance per subclass is actually
    allocated (and it is immutable).

    The callable interface is the same as that of the stencil definition
    function, with some extra keyword arguments.

    Keyword Arguments
    -----------------
    domain : `Sequence` of `int`, optional
        Shape of the computation domain. If `None`, the largest feasible
        domain according to the provided input fields and origin values
        will be used (`None` by default).

    origin : `[int * ndims]` or `{'field_name': [int * ndims]}`, optional
        If a single offset is passed, it will be used for all fields.
        If a `dict` is passed, there could be an entry for each field.
        A special key *'_all_'* will represent the value to be used for all
        the fields not explicitly defined. If `None` is passed or it is
        not possible to assign a value to some field according to the
        previous rule, the value will be inferred from the `boundary` attribute
        of the `field_info` dict. Note that the function checks if the origin values
        are at least equal to the `boundary` attribute of that field,
        so a 0-based origin will only be acceptable for fields with
        a 0-area support region.

    exec_info : `dict`, optional
        Dictionary used to store information about the stencil execution.
        (`None` by default). If the dictionary contains the magic key
        '__aggregate_data' and it evaluates to `True`, the dictionary is
        populated with a nested dictionary per class containing different
        performance statistics. These include the stencil calls count, the
        cumulative time spent in all stencil calls, and the actual time spent
        in carrying out the computations.
    """

    # Those attributes are added to the class at loading time:
    _gt_id_: str
    definition_func: Callable[..., Any]

    # Stores domain/origin pairs that have been used, keyed by hash.
    _domain_origin_cache: ClassVar[Dict[int, Tuple[Tuple[int, ...], Dict[str, Tuple[int, ...]]]]]

    def __new__(cls, *args, **kwargs):
        # Singleton: the first call allocates the instance and its (class-level) cache.
        if getattr(cls, "_instance", None) is None:
            cls._instance = object.__new__(cls)
            cls._domain_origin_cache = {}
        return cls._instance

    def __setattr__(self, key, value) -> None:
        raise AttributeError("Attempting a modification of an attribute in a frozen class")

    def __delattr__(self, item) -> None:
        raise AttributeError("Attempting a deletion of an attribute in a frozen class")

    def __eq__(self, other) -> bool:
        return type(self) == type(other)

    def __str__(self) -> str:
        result = """
<StencilObject: {name}> [backend="{backend}"]
    - I/O fields: {fields}
    - Parameters: {params}
    - Constants: {constants}
    - Version: {version}
    - Definition ({func}):
{source}
        """.format(
            name=self.options["module"] + "." + self.options["name"],
            version=self._gt_id_,
            backend=self.backend,
            fields=self.field_info,
            params=self.parameter_info,
            constants=self.constants,
            func=self.definition_func,
            source=self.source,
        )

        return result

    def __hash__(self) -> int:
        return int.from_bytes(type(self)._gt_id_.encode(), byteorder="little")

    @property
    @abc.abstractmethod
    def backend(self) -> str:
        pass

    @property
    @abc.abstractmethod
    def source(self) -> str:
        pass

    @property
    @abc.abstractmethod
    def domain_info(self) -> DomainInfo:
        pass

    @property
    @abc.abstractmethod
    def field_info(self) -> Dict[str, FieldInfo]:
        pass

    @property
    @abc.abstractmethod
    def parameter_info(self) -> Dict[str, ParameterInfo]:
        pass

    @property
    @abc.abstractmethod
    def constants(self) -> Dict[str, Any]:
        pass

    @property
    @abc.abstractmethod
    def options(self) -> Dict[str, Any]:
        pass

    @abc.abstractmethod
    def run(self, *args, **kwargs) -> None:
        pass

    @abc.abstractmethod
    def __call__(self, *args, **kwargs) -> None:
        pass

    @staticmethod
    def _make_origin_dict(origin: Any) -> Dict[str, Index]:
        """Normalize a user-provided ``origin`` argument into a new mapping.

        A ``dict`` is shallow-copied, ``None`` gives ``{}``, an iterable is
        converted with ``Index.from_value`` and an ``int`` with ``Index.from_k``
        (both stored under ``"_all_"``).

        Raises
        ------
        ValueError
            If ``origin`` is unsupported or cannot be converted. A failing
            conversion is chained as ``__cause__``.
        """
        message = "Invalid 'origin' value ({})".format(origin)
        try:
            if isinstance(origin, dict):
                return dict(origin)
            if origin is None:
                return {}
            if isinstance(origin, collections.abc.Iterable):
                return {"_all_": Index.from_value(origin)}
            if isinstance(origin, int):
                return {"_all_": Index.from_k(origin)}
        except Exception as err:
            raise ValueError(message) from err
        raise ValueError(message)

    def _get_max_domain(
        self,
        field_args: Dict[str, Any],
        origin: Dict[str, Tuple[int, ...]],
        *,
        squeeze: bool = True,
    ) -> Shape:
        """Return the maximum domain size possible.

        Parameters
        ----------
        field_args:
            Mapping from field names to actually passed data arrays.
        origin:
            The origin for each field.
        squeeze:
            Convert non-used domain dimensions to singleton dimensions.

        Returns
        -------
        `Shape`: the maximum domain size.

        Raises
        ------
        ValueError
            If a field accessed by the stencil is missing from ``field_args`` (or is ``None``).
        """
        domain_ndim = self.domain_info.ndim
        max_size = sys.maxsize
        max_domain = Shape([max_size] * domain_ndim)

        for name, field_info in self.field_info.items():
            if field_info.access == AccessKind.NONE:
                continue
            # Explicit check instead of ``assert``: asserts are stripped under ``python -O``.
            field = field_args.get(name, None)
            if field is None:
                raise ValueError(f"Invalid value for '{name}' field.")

            api_domain_mask = field_info.domain_mask
            api_domain_ndim = field_info.domain_ndim
            upper_indices = field_info.boundary.upper_indices.filter_mask(api_domain_mask)
            field_origin = Index.from_value(origin[name])
            field_domain = tuple(
                field.shape[i] - (field_origin[i] + upper_indices[i])
                for i in range(api_domain_ndim)
            )
            max_domain &= Shape.from_mask(field_domain, api_domain_mask, default=max_size)

        if squeeze:
            # Axes still at ``max_size`` were not constrained by any field.
            return Shape([i if i != max_size else 1 for i in max_domain])
        return max_domain

    def _validate_args(  # noqa: C901  # Function is too complex
        self,
        field_args: Dict[str, FieldType],
        param_args: Dict[str, Any],
        domain: Tuple[int, ...],
        origin: Dict[str, Tuple[int, ...]],
    ) -> None:
        """Validate input arguments to _call_run.

        Raises
        ------
        ValueError
            If invalid data or inconsistent options are specified.
        TypeError
            If an incorrect field or parameter data type is passed.
        """
        assert isinstance(field_args, dict) and isinstance(param_args, dict)

        # validate domain sizes
        domain_ndim = self.domain_info.ndim
        if len(domain) != domain_ndim:
            raise ValueError(f"Invalid 'domain' value '{domain}'")

        try:
            domain = Shape(domain)
        except Exception as err:
            raise ValueError("Invalid 'domain' value ({})".format(domain)) from err

        if not domain > Shape.zeros(domain_ndim):
            raise ValueError(f"Compute domain contains zero sizes '{domain}'")

        if not domain <= (max_domain := self._get_max_domain(field_args, origin, squeeze=False)):
            raise ValueError(
                f"Compute domain too large (provided: {domain}, maximum: {max_domain})"
            )

        # NOTE(review): index 2 assumes the sequential axis is the third domain axis
        # (and that the domain has at least three dimensions) -- confirm with DomainInfo.
        if domain[2] < self.domain_info.min_sequential_axis_size:
            raise ValueError(
                f"Compute domain too small. Sequential axis is {domain[2]}, but must be at least {self.domain_info.min_sequential_axis_size}."
            )

        storage_info = gt_backend.from_name(self.backend).storage_info

        # assert compatibility of fields with stencil
        for name, field_info in self.field_info.items():
            if field_info.access == AccessKind.NONE:
                continue
            if name not in field_args:
                raise ValueError(f"Missing value for '{name}' field.")
            field = field_args[name]

            if not storage_info["is_compatible_layout"](field):
                raise ValueError(
                    f"The layout of the field {name} is not compatible with the backend."
                )
            if not storage_info["is_compatible_type"](field):
                raise ValueError(
                    f"Field '{name}' has type '{type(field)}', which is not compatible with the '{self.backend}' backend."
                )
            elif type(field) is np.ndarray:
                warnings.warn(
                    "NumPy ndarray passed as field. This is discouraged and only works with constraints and only for certain backends.",
                    RuntimeWarning,
                )

            field_dtype = field_info.dtype
            if not field.dtype == field_dtype:
                raise TypeError(
                    f"The dtype of field '{name}' is '{field.dtype}' instead of '{field_dtype}'"
                )
            if isinstance(field, gt_storage.storage.Storage) and not field.is_stencil_view:
                raise ValueError(
                    f"An incompatible view was passed for field {name} to the stencil. "
                )

            # Check: domain + halo vs field size
            field_domain_mask = field_info.domain_mask
            field_domain_ndim = field_info.domain_ndim
            field_domain_origin = Index.from_mask(origin[name], field_domain_mask[:domain_ndim])

            if field.ndim != field_domain_ndim + len(field_info.data_dims):
                raise ValueError(
                    f"Storage for '{name}' has {field.ndim} dimensions but the API signature "
                    f"expects {field_domain_ndim + len(field_info.data_dims)} ('{field_info.axes}[{field_info.data_dims}]')"
                )
            if (
                isinstance(field, gt_storage.storage.Storage)
                and tuple(field.mask)[:domain_ndim] != field_domain_mask
            ):
                raise ValueError(
                    f"Storage for '{name}' has domain mask '{field.mask}' but the API signature "
                    f"expects '[{', '.join(field_info.axes)}]'"
                )

            # Check: data dimensions shape
            if field.shape[field_domain_ndim:] != field_info.data_dims:
                raise ValueError(
                    f"Field '{name}' expects data dimensions {field_info.data_dims} but got {field.shape[field_domain_ndim:]}"
                )

            min_origin = gt_utils.interpolate_mask(
                field_info.boundary.lower_indices.filter_mask(field_domain_mask),
                field_domain_mask,
                default=0,
            )
            if field_domain_origin < min_origin:
                raise ValueError(
                    f"Origin for field {name} too small. Must be at least {min_origin}, is {field_domain_origin}"
                )

            spatial_domain = typing.cast(Shape, domain).filter_mask(field_domain_mask)
            lower_indices = field_info.boundary.lower_indices.filter_mask(field_domain_mask)
            upper_indices = field_info.boundary.upper_indices.filter_mask(field_domain_mask)
            min_shape = tuple(
                lb + d + ub for lb, d, ub in zip(lower_indices, spatial_domain, upper_indices)
            )
            # Element-wise check: a plain tuple ``>`` is lexicographic and would accept
            # a field whose later (spatial) dimensions are too small.
            if any(required > actual for required, actual in zip(min_shape, field.shape)):
                raise ValueError(
                    f"Shape of field {name} is {field.shape} but must be at least {min_shape} for given domain and origin."
                )

        # assert compatibility of parameters with stencil
        for name, parameter_info in self.parameter_info.items():
            if parameter_info.access == AccessKind.NONE:
                continue
            if name not in param_args:
                raise ValueError(f"Missing value for '{name}' parameter.")
            if type(parameter := param_args[name]) != parameter_info.dtype:
                raise TypeError(
                    f"The type of parameter '{name}' is '{type(parameter)}' instead of '{parameter_info.dtype}'"
                )
def _call_run(self, field_args, parameter_args, domain, origin, exec_info=None):
    """Check and preprocess the provided arguments (called by :class:`StencilObject` subclasses).

    Note that this function will always try to expand simple parameter values to
    complete data structures by repeating the same value as many times as needed.

    Parameters
    ----------
        field_args: `dict`
            Mapping from field names to actually passed data arrays.
            This parameter encapsulates `*args` in the actual stencil subclass
            by doing: `{input_name[i]: arg for i, arg in enumerate(args)}`

        parameter_args: `dict`
            Mapping from parameter names to actually passed parameter values.
            This parameter encapsulates `**kwargs` in the actual stencil subclass
            by doing: `{name: value for name, value in kwargs.items()}`

        domain : `Sequence` of `int`, optional
            Shape of the computation domain. If `None`, the largest feasible
            domain according to the provided input fields and origin values
            will be used (`None` by default).

        origin : `[int * ndims]` or `{'field_name': [int * ndims]}`, optional
            If a single offset is passed, it will be used for all fields.
            If a `dict` is passed, there could be an entry for each field.
            A special key *'_all_'* will represent the value to be used for all
            the fields not explicitly defined. If `None` is passed or it is
            not possible to assign a value to some field according to the
            previous rule, the field's `default_origin` is used. Note that
            the function checks that the origin values are at least equal to
            the lower boundary of that field, so a 0-based origin will only
            be acceptable for fields with a 0-area support region.

        exec_info : `dict`, optional
            Dictionary used to store information about the stencil execution.
            (`None` by default).

    Returns
    -------
        `None`

    Raises
    ------
        ValueError
            If invalid data or inconsistent options are specified.
        TypeError
            If a field dtype or parameter type does not match the stencil signature.
    """
    if exec_info is not None:
        exec_info["call_run_start_time"] = time.perf_counter()

    assert isinstance(field_args, dict) and isinstance(parameter_args, dict)

    # Only arguments the stencil actually uses (present with non-None info) are checked.
    used_arg_fields = {
        name: field
        for name, field in field_args.items()
        if name in self.field_info and self.field_info[name] is not None
    }
    used_arg_params = {
        name: param
        for name, param in parameter_args.items()
        if name in self.parameter_info and self.parameter_info[name] is not None
    }
    for name, field_info in self.field_info.items():
        if field_info is not None and field_args[name] is None:
            raise ValueError(f"Field '{name}' is None.")
    for name, parameter_info in self.parameter_info.items():
        if parameter_info is not None and parameter_args[name] is None:
            raise ValueError(f"Parameter '{name}' is None.")

    # assert compatibility of fields with stencil
    storage_info = gt_backend.from_name(self.backend).storage_info
    for name, field in used_arg_fields.items():
        if not storage_info["is_compatible_layout"](field):
            raise ValueError(
                f"The layout of the field {name} is not compatible with the backend."
            )
        if not storage_info["is_compatible_type"](field):
            raise ValueError(
                f"Field '{name}' has type '{type(field)}', which is not compatible with the '{self.backend}' backend."
            )
        elif type(field) is np.ndarray:
            warnings.warn(
                "NumPy ndarray passed as field. This is discouraged and only works with constraints and only for certain backends.",
                RuntimeWarning,
            )
        if not field.dtype == self.field_info[name].dtype:
            raise TypeError(
                f"The dtype of field '{name}' is '{field.dtype}' instead of '{self.field_info[name].dtype}'"
            )
        # ToDo: check if mask is correct: need mask info in stencil object.
        if isinstance(field, gt_storage.storage.Storage):
            if not field.is_stencil_view:
                raise ValueError(
                    f"An incompatible view was passed for field {name} to the stencil. "
                )
            for name_other, field_other in used_arg_fields.items():
                # Plain arrays (e.g. NumPy ndarrays) have no ``mask``; compare storages only.
                if not isinstance(field_other, gt_storage.storage.Storage):
                    continue
                if field_other.mask == field.mask:
                    if not field_other.shape == field.shape:
                        raise ValueError(
                            f"The fields {name} and {name_other} have the same mask but different shapes."
                        )

    # assert compatibility of parameters with stencil
    for name, parameter in used_arg_params.items():
        if not type(parameter) == self.parameter_info[name].dtype:
            raise TypeError(
                f"The type of parameter '{name}' is '{type(parameter)}' instead of '{self.parameter_info[name].dtype}'"
            )

    # Shapes
    shapes = {name: Shape(field.shape) for name, field in used_arg_fields.items()}

    # Origins
    if origin is None:
        origin = {}
    else:
        origin = normalize_origin_mapping(origin)
    for name, field in used_arg_fields.items():
        # Only look up a fallback when needed: ``default_origin`` exists on storages
        # but not on plain arrays that were given an explicit origin.
        if name not in origin:
            origin[name] = origin["_all_"] if "_all_" in origin else field.default_origin

    # Domain
    max_domain = Shape([sys.maxsize] * self.domain_info.ndims)
    for name, shape in shapes.items():
        upper_boundary = Index(self.field_info[name].boundary.upper_indices)
        max_domain &= shape - (Index(origin[name]) + upper_boundary)

    if domain is None:
        domain = max_domain
    else:
        domain = normalize_domain(domain)
        if len(domain) != self.domain_info.ndims:
            raise ValueError(f"Invalid 'domain' value '{domain}'")

    # check domain+halo vs field size
    if not domain > Shape.zeros(self.domain_info.ndims):
        raise ValueError(f"Compute domain contains zero sizes '{domain}'")

    if not domain <= max_domain:
        raise ValueError(
            f"Compute domain too large (provided: {domain}, maximum: {max_domain})"
        )
    for name, field in used_arg_fields.items():
        min_origin = self.field_info[name].boundary.lower_indices
        if origin[name] < min_origin:
            raise ValueError(
                f"Origin for field {name} too small. Must be at least {min_origin}, is {origin[name]}"
            )
        min_shape = tuple(
            o + d + h
            for o, d, h in zip(origin[name], domain, self.field_info[name].boundary.upper_indices)
        )
        # Element-wise check: a plain tuple ``>`` is lexicographic and would accept
        # a field whose later dimensions are too small.
        if any(required > actual for required, actual in zip(min_shape, field.shape)):
            raise ValueError(
                f"Shape of field {name} is {field.shape} but must be at least {min_shape} for given domain and origin."
            )

    self.run(
        **field_args, **parameter_args, _domain_=domain, _origin_=origin, exec_info=exec_info
    )
class StencilObject(abc.ABC):
    """Generic singleton implementation of a stencil function.

    This class is used as the base class for the specific subclass generated
    at run-time for any stencil definition and a unique set of external symbols.
    Instances of this class do not contain any information and thus it is
    implemented as a singleton: only one instance per subclass is actually
    allocated (and it is immutable).
    """

    def __new__(cls, *args, **kwargs):
        # Singleton: allocate the instance only on the first call for each subclass.
        if getattr(cls, "_instance", None) is None:
            cls._instance = object.__new__(cls)
        return cls._instance

    def __setattr__(self, key, value) -> None:
        raise AttributeError("Attempting a modification of an attribute in a frozen class")

    def __delattr__(self, item) -> None:
        raise AttributeError("Attempting a deletion of an attribute in a frozen class")

    def __eq__(self, other) -> bool:
        return type(self) == type(other)

    def __str__(self) -> str:
        result = """
<StencilObject: {name}> [backend="{backend}"]
    - I/O fields: {fields}
    - Parameters: {params}
    - Constants: {constants}
    - Definition ({func}):
{source}
        """.format(
            name=self.options["module"] + "." + self.options["name"],
            version=self._gt_id_,
            backend=self.backend,
            fields=self.field_info,
            params=self.parameter_info,
            constants=self.constants,
            func=self.definition_func,
            source=self.source,
        )

        return result

    def __hash__(self) -> int:
        return int.from_bytes(type(self)._gt_id_.encode(), byteorder="little")

    # Those attributes are added to the class at loading time:
    #
    #   _gt_id_ (stencil_id.version)
    #   definition_func

    @property
    @abc.abstractmethod
    def backend(self) -> str:
        pass

    @property
    @abc.abstractmethod
    def source(self) -> str:
        pass

    @property
    @abc.abstractmethod
    def domain_info(self) -> DomainInfo:
        pass

    @property
    @abc.abstractmethod
    def field_info(self) -> Dict[str, FieldInfo]:
        pass

    @property
    @abc.abstractmethod
    def parameter_info(self) -> Dict[str, ParameterInfo]:
        pass

    @property
    @abc.abstractmethod
    def constants(self) -> Dict[str, Any]:
        pass

    @property
    @abc.abstractmethod
    def options(self) -> Dict[str, Any]:
        pass

    @abc.abstractmethod
    def run(self, *args, **kwargs) -> None:
        pass

    @abc.abstractmethod
    def __call__(self, *args, **kwargs) -> None:
        pass

    @staticmethod
    def _make_origin_dict(origin: Any) -> Dict[str, Index]:
        """Normalize a user-provided ``origin`` argument into a new mapping.

        A ``dict`` is shallow-copied, so later changes to the result do not alter
        the caller's mapping. ``None`` gives ``{}``. An iterable is converted with
        ``Index.from_value`` and an ``int`` with ``Index.from_k``; both are stored
        under ``"_all_"``.

        Raises
        ------
        ValueError
            If ``origin`` is unsupported or cannot be converted. A failing
            conversion is chained as ``__cause__``.
        """
        message = "Invalid 'origin' value ({})".format(origin)
        try:
            if isinstance(origin, dict):
                return dict(origin)
            if origin is None:
                return {}
            if isinstance(origin, collections.abc.Iterable):
                return {"_all_": Index.from_value(origin)}
            if isinstance(origin, int):
                return {"_all_": Index.from_k(origin)}
        except Exception as err:
            # ``except Exception`` (not bare ``except``) so KeyboardInterrupt/SystemExit propagate.
            raise ValueError(message) from err
        raise ValueError(message)

    def _get_max_domain(
        self,
        field_args: Dict[str, Any],
        origin: Dict[str, Tuple[int, ...]],
        *,
        squeeze: bool = True,
    ) -> Shape:
        """Return the maximum domain size possible.

        Parameters
        ----------
        field_args:
            Mapping from field names to actually passed data arrays.
        origin:
            The origin for each field.
        squeeze:
            Convert non-used domain dimensions to singleton dimensions.

        Returns
        -------
        `Shape`: the maximum domain size.

        Raises
        ------
        ValueError
            If a used field is missing from ``field_args`` (or is ``None``), or if a
            `Storage` argument's domain mask does not match the API signature.
        """
        domain_ndim = self.domain_info.ndim
        max_size = sys.maxsize
        max_domain = Shape([max_size] * domain_ndim)

        for name, field_info in self.field_info.items():
            if field_info is None:
                continue
            # Explicit checks instead of ``assert``: asserts are stripped under ``python -O``.
            field = field_args.get(name, None)
            if field is None:
                raise ValueError(f"Invalid value for '{name}' field.")

            api_domain_mask = field_info.domain_mask
            api_domain_ndim = field_info.domain_ndim
            if (
                isinstance(field, gt_storage.storage.Storage)
                and tuple(field.mask)[:domain_ndim] != api_domain_mask
            ):
                raise ValueError(
                    f"Storage for '{name}' has domain mask '{field.mask}' but the API signature "
                    f"expects '[{', '.join(field_info.axes)}]'"
                )

            upper_indices = field_info.boundary.upper_indices.filter_mask(api_domain_mask)
            field_origin = Index.from_value(origin[name])
            field_domain = tuple(
                field.shape[i] - (field_origin[i] + upper_indices[i])
                for i in range(api_domain_ndim)
            )
            max_domain &= Shape.from_mask(field_domain, api_domain_mask, default=max_size)

        if squeeze:
            # Axes still at ``max_size`` were not constrained by any field.
            return Shape([i if i != max_size else 1 for i in max_domain])
        return max_domain

    def _validate_args(self, field_args, param_args, domain, origin) -> None:
        """Validate input arguments to _call_run.

        Raises
        ------
        ValueError
            If invalid data or inconsistent options are specified.
        TypeError
            If an incorrect field or parameter data type is passed.
        """
        assert isinstance(field_args, dict) and isinstance(param_args, dict)

        # validate domain sizes
        domain_ndim = self.domain_info.ndim
        if len(domain) != domain_ndim:
            raise ValueError(f"Invalid 'domain' value '{domain}'")

        try:
            domain = Shape(domain)
        except Exception as err:
            raise ValueError("Invalid 'domain' value ({})".format(domain)) from err

        if not domain > Shape.zeros(domain_ndim):
            raise ValueError(f"Compute domain contains zero sizes '{domain}'")

        if not domain <= (max_domain := self._get_max_domain(field_args, origin, squeeze=False)):
            raise ValueError(
                f"Compute domain too large (provided: {domain}, maximum: {max_domain})"
            )

        storage_info = gt_backend.from_name(self.backend).storage_info

        # assert compatibility of fields with stencil
        for name, field_info in self.field_info.items():
            if field_info is None:
                continue
            if name not in field_args:
                raise ValueError(f"Missing value for '{name}' field.")
            field = field_args[name]

            if not storage_info["is_compatible_layout"](field):
                raise ValueError(
                    f"The layout of the field {name} is not compatible with the backend."
                )
            if not storage_info["is_compatible_type"](field):
                raise ValueError(
                    f"Field '{name}' has type '{type(field)}', which is not compatible with the '{self.backend}' backend."
                )
            elif type(field) is np.ndarray:
                warnings.warn(
                    "NumPy ndarray passed as field. This is discouraged and only works with constraints and only for certain backends.",
                    RuntimeWarning,
                )

            field_dtype = field_info.dtype
            if not field.dtype == field_dtype:
                raise TypeError(
                    f"The dtype of field '{name}' is '{field.dtype}' instead of '{field_dtype}'"
                )
            if isinstance(field, gt_storage.storage.Storage) and not field.is_stencil_view:
                raise ValueError(
                    f"An incompatible view was passed for field {name} to the stencil. "
                )

            # Check: domain + halo vs field size
            field_domain_mask = field_info.domain_mask
            field_domain_ndim = field_info.domain_ndim
            field_domain_origin = Index.from_mask(origin[name], field_domain_mask[:domain_ndim])

            if field.ndim != field_domain_ndim + len(field_info.data_dims):
                raise ValueError(
                    f"Storage for '{name}' has {field.ndim} dimensions but the API signature "
                    f"expects {field_domain_ndim + len(field_info.data_dims)} ('{field_info.axes}[{field_info.data_dims}]')"
                )

            min_origin = gt_utils.interpolate_mask(
                field_info.boundary.lower_indices.filter_mask(field_domain_mask),
                field_domain_mask,
                default=0,
            )
            if field_domain_origin < min_origin:
                raise ValueError(
                    f"Origin for field {name} too small. Must be at least {min_origin}, is {field_domain_origin}"
                )

            spatial_domain = domain.filter_mask(field_domain_mask)
            upper_indices = field_info.boundary.upper_indices.filter_mask(field_domain_mask)
            min_shape = tuple(
                o + d + h for o, d, h in zip(field_domain_origin, spatial_domain, upper_indices)
            )
            # Element-wise check: a plain tuple ``>`` is lexicographic and would accept
            # a field whose later (spatial) dimensions are too small.
            if any(required > actual for required, actual in zip(min_shape, field.shape)):
                raise ValueError(
                    f"Shape of field {name} is {field.shape} but must be at least {min_shape} for given domain and origin."
                )

        # assert compatibility of parameters with stencil
        for name, parameter_info in self.parameter_info.items():
            if parameter_info is None:
                continue
            if name not in param_args:
                raise ValueError(f"Missing value for '{name}' parameter.")
            if not type(parameter := param_args[name]) == parameter_info.dtype:
                raise TypeError(
                    f"The type of parameter '{name}' is '{type(parameter)}' instead of '{parameter_info.dtype}'"
                )