Example #1
0
 def __array_finalize__(self, obj):
     """Propagate Storage meta information onto a newly created array.

     Numpy invokes this hook when a new array of this class comes into
     existence.

     Parameters
     ----------
         obj:
             The array this one was derived from. It is ``None`` when the
             hook runs right after the explicit constructor.

     Raises
     ------
         RuntimeError
             If the new array does not reference another array's memory
             (a copy rather than a view). Copies must go through
             ``Storage.copy()`` or ``deepcopy``.
         RuntimeError
             If a view is taken from something that is neither a ``Storage``
             nor a ``_ViewableNdarray``, since the meta information could not
             be inferred.
     """
     if obj is None:
         # constructor called previously
         # (nothing to inherit: meta information is set by the constructor)
         return
     else:
         if self.base is None:
             # case np.array
             # The new array does not reference another array's memory, i.e.
             # it is a copy; copies are only allowed through the explicit APIs.
             raise RuntimeError(
                 "Copying storages is only possible through Storage.copy() or deepcopy."
             )
         else:
             # The new array is a view: meta information must come from `obj`.
             if not isinstance(obj, Storage) and not isinstance(
                     obj, _ViewableNdarray):
                 raise RuntimeError(
                     "Meta information can not be inferred when creating Storage views from other classes than Storage."
                 )
             # Inherit all of obj's attributes; entries already present on
             # self win over obj's (right-hand side of the merge).
             # NOTE(review): this also copies obj._new_index (if set) into
             # self before it is deleted from obj below, so self may keep a
             # stale _new_index -- confirm this is intended.
             self.__dict__ = {**obj.__dict__, **self.__dict__}
             # Pessimistic default; may be upgraded below.
             self.is_stencil_view = False
             if hasattr(obj, "_new_index"):
                 # Pad the recorded index with full slices up to one entry per
                 # mask axis, then map "is this entry a slice?" onto the axes
                 # enabled in obj.mask (via gt_utils.interpolate_mask, default
                 # False). Axes whose index entry is not a slice are cleared
                 # from the mask of the view.
                 index_iter = itertools.chain(
                     obj._new_index, [slice(None, None)] *
                     (len(obj.mask) - len(obj._new_index)))
                 interpolated_mask = gt_utils.interpolate_mask(
                     (isinstance(x, slice) for x in index_iter), obj.mask,
                     False)
                 self._mask = tuple(
                     x & y for x, y in zip(obj.mask, interpolated_mask))
                 # The index has been consumed; remove it from the parent.
                 delattr(obj, "_new_index")
             if not hasattr(obj, "default_origin"):
                 # Parent carries no origin information: treat the view as
                 # usable by stencils.
                 self.is_stencil_view = True
             elif self._is_consistent(obj):
                 # Parent has meta information and the view is consistent
                 # with it: inherit the parent's flag.
                 self.is_stencil_view = obj.is_stencil_view
             # Subclass-specific hook for any remaining view setup.
             self._finalize_view(obj)
Example #2
0
 def from_mask(cls, seq, mask, default=None):
     """Build an instance from ``seq`` laid out according to ``mask``.

     The work is delegated to ``gt_utils.interpolate_mask``; positions that
     are not filled from ``seq`` receive the fill value. When ``default`` is
     ``None`` the class-level ``_DEFAULT`` is used as fill value.
     """
     fill_value = cls._DEFAULT if default is None else default
     expanded = gt_utils.interpolate_mask(seq, mask, fill_value)
     return cls(expanded)
Example #3
0
class StencilObject(abc.ABC):
    """Generic singleton implementation of a stencil callable.

    This class is used as base class for specific subclass generated
    at run-time for any stencil definition and a unique set of external symbols.
    Instances of this class do not contain state and thus it is
    implemented as a singleton: only one instance per subclass is actually
    allocated (and it is immutable).

    The callable interface is the same as that of the stencil definition
    function, with some extra keyword arguments.

    Keyword Arguments
    ------------------
    domain : `Sequence` of `int`, optional
        Shape of the computation domain. If `None`, the largest feasible
        domain according to the provided input fields and origin values
        will be used (`None` by default).

    origin :  `[int * ndims]` or `{'field_name': [int * ndims]}`, optional
        If a single offset is passed, it will be used for all fields.
        If a `dict` is passed, there could be an entry for each field.
        A special key *'_all_'* will represent the value to be used for all
        the fields not explicitly defined. If `None` is passed or it is
        not possible to assign a value to some field according to the
        previous rule, the value will be inferred from the `boundary` attribute
        of the `field_info` dict. Note that the function checks if the origin values
        are at least equal to the `boundary` attribute of that field,
        so a 0-based origin will only be acceptable for fields with
        a 0-area support region.

    exec_info : `dict`, optional
        Dictionary used to store information about the stencil execution.
        (`None` by default). If the dictionary contains the magic key
        '__aggregate_data' and it evaluates to `True`, the dictionary is
        populated with a nested dictionary per class containing different
        performance statistics. These include the stencil calls count, the
        cumulative time spent in all stencil calls, and the actual time spent
        in carrying out the computations.

    """

    # Those attributes are added to the class at loading time:
    _gt_id_: str
    definition_func: Callable[..., Any]

    # Stores domain/origin pairs that have been used, keyed by hash.
    _domain_origin_cache: ClassVar[
        Dict[int, Tuple[Tuple[int, ...], Dict[str, Tuple[int, ...]]]]
    ]

    def __new__(cls, *args, **kwargs):
        # Singleton: the first instantiation creates the instance and the
        # per-class domain/origin cache; later calls return the same object.
        if getattr(cls, "_instance", None) is None:
            cls._instance = object.__new__(cls)
            cls._domain_origin_cache = {}
        return cls._instance

    def __setattr__(self, key, value) -> None:
        raise AttributeError(
            "Attempting a modification of an attribute in a frozen class")

    def __delattr__(self, item) -> None:
        raise AttributeError(
            "Attempting a deletion of an attribute in a frozen class")

    def __eq__(self, other) -> bool:
        # One instance per subclass, so equality reduces to type identity.
        return type(self) == type(other)

    def __str__(self) -> str:
        result = """
<StencilObject: {name}> [backend="{backend}"]
    - I/O fields: {fields}
    - Parameters: {params}
    - Constants: {constants}
    - Version: {version}
    - Definition ({func}):
{source}
        """.format(
            name=self.options["module"] + "." + self.options["name"],
            version=self._gt_id_,
            backend=self.backend,
            fields=self.field_info,
            params=self.parameter_info,
            constants=self.constants,
            func=self.definition_func,
            source=self.source,
        )

        return result

    def __hash__(self) -> int:
        # Derived from the class-level stencil id, consistent with __eq__.
        return int.from_bytes(type(self)._gt_id_.encode(), byteorder="little")

    @property
    @abc.abstractmethod
    def backend(self) -> str:
        pass

    @property
    @abc.abstractmethod
    def source(self) -> str:
        pass

    @property
    @abc.abstractmethod
    def domain_info(self) -> DomainInfo:
        pass

    @property
    @abc.abstractmethod
    def field_info(self) -> Dict[str, FieldInfo]:
        pass

    @property
    @abc.abstractmethod
    def parameter_info(self) -> Dict[str, ParameterInfo]:
        pass

    @property
    @abc.abstractmethod
    def constants(self) -> Dict[str, Any]:
        pass

    @property
    @abc.abstractmethod
    def options(self) -> Dict[str, Any]:
        pass

    @abc.abstractmethod
    def run(self, *args, **kwargs) -> None:
        pass

    @abc.abstractmethod
    def __call__(self, *args, **kwargs) -> None:
        pass

    @staticmethod
    def _make_origin_dict(origin: Any) -> Dict[str, Index]:
        """Normalize the user-facing ``origin`` argument into a dict.

        A dict is copied; ``None`` becomes an empty dict; an iterable or an int
        is stored under the ``"_all_"`` key.

        Raises
        ------
            ValueError
                If ``origin`` has an unsupported type or cannot be converted.
        """
        try:
            if isinstance(origin, dict):
                return dict(origin)
            if origin is None:
                return {}
            # dict is also iterable, so this check must come after the dict one.
            if isinstance(origin, collections.abc.Iterable):
                return {"_all_": Index.from_value(origin)}
            if isinstance(origin, int):
                return {"_all_": Index.from_k(origin)}
        except Exception as err:
            raise ValueError("Invalid 'origin' value ({})".format(origin)) from err

        raise ValueError("Invalid 'origin' value ({})".format(origin))

    def _get_max_domain(
        self,
        field_args: Dict[str, Any],
        origin: Dict[str, Tuple[int, ...]],
        *,
        squeeze: bool = True,
    ) -> Shape:
        """Return the maximum domain size possible.

        Parameters
        ----------
            field_args:
                Mapping from field names to actually passed data arrays.
            origin:
                The origin for each field.
            squeeze:
                Convert non-used domain dimensions to singleton dimensions.

        Returns
        -------
            `Shape`: the maximum domain size.
        """
        domain_ndim = self.domain_info.ndim
        max_size = sys.maxsize
        # Start unbounded; each accessed field narrows the domain on its axes.
        max_domain = Shape([max_size] * domain_ndim)

        for name, field_info in self.field_info.items():
            if field_info.access != AccessKind.NONE:
                assert field_args.get(
                    name, None) is not None, f"Invalid value for '{name}' field."
                field = field_args[name]
                api_domain_mask = field_info.domain_mask
                api_domain_ndim = field_info.domain_ndim
                upper_indices = field_info.boundary.upper_indices.filter_mask(
                    api_domain_mask)
                field_origin = Index.from_value(origin[name])
                field_domain = tuple(
                    field.shape[i] - (field_origin[i] + upper_indices[i])
                    for i in range(api_domain_ndim))
                # Axes the field does not span stay at max_size (unconstrained).
                max_domain &= Shape.from_mask(field_domain,
                                              api_domain_mask,
                                              default=max_size)

        if squeeze:
            return Shape([i if i != max_size else 1 for i in max_domain])
        else:
            return max_domain

    def _validate_args(  # noqa: C901  # Function is too complex
        self,
        field_args: Dict[str, FieldType],
        param_args: Dict[str, Any],
        domain: Tuple[int, ...],
        origin: Dict[str, Tuple[int, ...]],
    ) -> None:
        """
        Validate input arguments to _call_run.

        Parameters
        ----------
            field_args:
                Mapping from field names to the passed data arrays.
            param_args:
                Mapping from parameter names to the passed scalar values.
            domain:
                Requested compute domain, one entry per domain axis.
            origin:
                Origin for each field.

        Raises
        -------
            ValueError
                If invalid data or inconsistent options are specified.

            TypeError
                If an incorrect field or parameter data type is passed.
        """
        assert isinstance(field_args, dict) and isinstance(param_args, dict)

        # validate domain sizes
        domain_ndim = self.domain_info.ndim
        if len(domain) != domain_ndim:
            raise ValueError(f"Invalid 'domain' value '{domain}'")

        try:
            domain = Shape(domain)
        except Exception as err:
            raise ValueError("Invalid 'domain' value ({})".format(domain)) from err

        if not domain > Shape.zeros(domain_ndim):
            raise ValueError(f"Compute domain contains zero sizes '{domain}'")

        max_domain = self._get_max_domain(field_args, origin, squeeze=False)
        if not domain <= max_domain:
            raise ValueError(
                f"Compute domain too large (provided: {domain}, maximum: {max_domain})"
            )

        if domain[2] < self.domain_info.min_sequential_axis_size:
            raise ValueError(
                f"Compute domain too small. Sequential axis is {domain[2]}, but must be at least {self.domain_info.min_sequential_axis_size}."
            )

        # The backend does not change per field: look up its storage checks once.
        storage_info = gt_backend.from_name(self.backend).storage_info

        # assert compatibility of fields with stencil
        for name, field_info in self.field_info.items():
            if field_info.access == AccessKind.NONE:
                continue
            if name not in field_args:
                raise ValueError(f"Missing value for '{name}' field.")
            field = field_args[name]

            if not storage_info["is_compatible_layout"](field):
                raise ValueError(
                    f"The layout of the field {name} is not compatible with the backend."
                )

            if not storage_info["is_compatible_type"](field):
                raise ValueError(
                    f"Field '{name}' has type '{type(field)}', which is not compatible with the '{self.backend}' backend."
                )
            elif type(field) is np.ndarray:
                warnings.warn(
                    "NumPy ndarray passed as field. This is discouraged and only works with constraints and only for certain backends.",
                    RuntimeWarning,
                )

            field_dtype = field_info.dtype
            if field.dtype != field_dtype:
                raise TypeError(
                    f"The dtype of field '{name}' is '{field.dtype}' instead of '{field_dtype}'"
                )

            if isinstance(field, gt_storage.storage.Storage) and not field.is_stencil_view:
                raise ValueError(
                    f"An incompatible view was passed for field {name} to the stencil. "
                )

            # Check: domain + halo vs field size
            field_domain_mask = field_info.domain_mask
            field_domain_ndim = field_info.domain_ndim
            field_domain_origin = Index.from_mask(
                origin[name], field_domain_mask[:domain_ndim])

            if field.ndim != field_domain_ndim + len(field_info.data_dims):
                raise ValueError(
                    f"Storage for '{name}' has {field.ndim} dimensions but the API signature "
                    f"expects {field_domain_ndim + len(field_info.data_dims)} ('{field_info.axes}[{field_info.data_dims}]')"
                )

            if (isinstance(field, gt_storage.storage.Storage) and
                    tuple(field.mask)[:domain_ndim] != field_domain_mask):
                raise ValueError(
                    f"Storage for '{name}' has domain mask '{field.mask}' but the API signature "
                    f"expects '[{', '.join(field_info.axes)}]'")

            # Check: data dimensions shape
            if field.shape[field_domain_ndim:] != field_info.data_dims:
                raise ValueError(
                    f"Field '{name}' expects data dimensions {field_info.data_dims} but got {field.shape[field_domain_ndim:]}"
                )

            min_origin = gt_utils.interpolate_mask(
                field_info.boundary.lower_indices.filter_mask(field_domain_mask),
                field_domain_mask,
                default=0,
            )
            if field_domain_origin < min_origin:
                raise ValueError(
                    f"Origin for field {name} too small. Must be at least {min_origin}, is {field_domain_origin}"
                )

            spatial_domain = typing.cast(Shape, domain).filter_mask(field_domain_mask)
            lower_indices = field_info.boundary.lower_indices.filter_mask(field_domain_mask)
            upper_indices = field_info.boundary.upper_indices.filter_mask(field_domain_mask)
            min_shape = tuple(
                lb + d + ub
                for lb, d, ub in zip(lower_indices, spatial_domain, upper_indices)
            )
            # Compare axis by axis: plain tuple comparison is lexicographic and
            # would miss a too-small trailing axis once a leading axis is large
            # enough. zip() stops at the spatial axes, skipping data dimensions.
            if any(required > actual for required, actual in zip(min_shape, field.shape)):
                raise ValueError(
                    f"Shape of field {name} is {field.shape} but must be at least {min_shape} for given domain and origin."
                )

        # assert compatibility of parameters with stencil
        for name, parameter_info in self.parameter_info.items():
            if parameter_info.access != AccessKind.NONE:
                if name not in param_args:
                    raise ValueError(f"Missing value for '{name}' parameter.")
                elif type(parameter := param_args[name]) != parameter_info.dtype:
                    raise TypeError(
                        f"The type of parameter '{name}' is '{type(parameter)}' instead of '{parameter_info.dtype}'"
                    )
Example #4
0
    def _run_test_implementation(cls, parameters_dict, implementation):  # noqa: C901  # too complex
        """Run a stencil implementation and compare it against the suite's reference.

        Parameters
        ----------
            parameters_dict:
                Pair ``(input_data, exec_info)``. ``input_data`` maps argument names
                to values: NumPy arrays are treated as fields, any other value is
                passed through unchanged. ``exec_info`` is passed to the stencil call
                and its ``"domain"`` entry is read back afterwards.
            implementation:
                The stencil object under test.

        Raises
        ------
            AssertionError
                If the executed domain differs from the expected one, or if any
                output field differs from the reference computed by ``cls.validation``.
        """
        input_data, exec_info = parameters_dict

        origin = cls.origin
        max_boundary = Boundary(cls.max_boundary)
        field_params = cls.field_params

        # For every array input: record which Cartesian axes the field spans, and
        # narrow the common data shape with it (axes the field does not span are
        # filled with sys.maxsize so they do not constrain the result).
        field_masks = {}
        data_shape = Shape((sys.maxsize,) * 3)
        for name, data in input_data.items():
            if isinstance(data, np.ndarray):
                field_masks[name] = tuple(
                    ax in field_params[name][0] for ax in CartesianSpace.names
                )
                data_shape &= Shape(
                    interpolate_mask(data.shape, field_masks[name], default=sys.maxsize)
                )

        domain = data_shape - (
            Index(max_boundary.lower_indices) + Index(max_boundary.upper_indices)
        )

        referenced_inputs = {
            name: info for name, info in implementation.field_info.items() if info is not None
        }
        referenced_inputs.update(
            {name: info for name, info in implementation.parameter_info.items() if info is not None}
        )

        # set externals for validation method
        for k, v in implementation.constants.items():
            sys.modules[cls.__module__].__dict__[k] = v

        # copy input data
        test_values = {}
        validation_values = {}
        for name, data in input_data.items():
            if name not in referenced_inputs:
                test_values[name] = None
                validation_values[name] = None
                continue

            info = referenced_inputs[name]
            if isinstance(info, FieldInfo):
                data_dims = field_params[name][1]
                if data_dims:
                    # Trailing data dimensions become part of the storage dtype.
                    dtype = (data.dtype, data_dims)
                    shape = data.shape[: -len(data_dims)]
                else:
                    dtype = data.dtype
                    shape = data.shape
                test_values[name] = gt_storage.from_array(
                    data,
                    dtype=dtype,
                    shape=shape,
                    mask=field_masks[name],
                    default_origin=origin,
                    backend=implementation.backend,
                )
                validation_values[name] = np.array(data)
            else:
                test_values[name] = data
                validation_values[name] = data

        # call implementation
        implementation(**test_values, origin=origin, exec_info=exec_info)
        assert domain == exec_info["domain"]

        # for validation data, data is cropped to actually touched domain, so that origin offseting
        # does not have to be implemented for every test suite. This is done based on info
        # specified in test suite
        cropped_validation_values = {}
        for name, data in validation_values.items():
            sym = cls.symbols[name]
            if data is not None and sym.kind == SymbolKind.FIELD:
                field_extent_low = tuple(b[0] for b in sym.boundary)
                offset_low = tuple(b[0] - e for b, e in zip(max_boundary, field_extent_low))
                field_extent_high = tuple(b[1] for b in sym.boundary)
                offset_high = tuple(b[1] - e for b, e in zip(max_boundary, field_extent_high))
                validation_slice = filter_mask(
                    tuple(slice(o, s - h) for o, s, h in zip(offset_low, data_shape, offset_high)),
                    field_masks[name],
                )
                data_dims = field_params[name][1]
                if data_dims:
                    validation_slice = tuple([*validation_slice] + [slice(None)] * len(data_dims))
                # Slicing yields a view: results written in place by cls.validation
                # (presumably how outputs are produced) land in validation_values.
                cropped_validation_values[name] = data[validation_slice]
            else:
                cropped_validation_values[name] = data

        cls.validation(
            **cropped_validation_values,
            domain=domain,
            origin={
                name: info.boundary.lower_indices
                for name, info in implementation.field_info.items()
                if info is not None
            },
        )

        # Test values
        for name, value in test_values.items():
            if isinstance(value, np.ndarray):
                expected_value = validation_values[name]

                if gt_backend.from_name(value.backend).storage_info["device"] == "gpu":
                    value.synchronize()
                    value = value.data.get()
                else:
                    value = value.data

                np.testing.assert_allclose(
                    value,
                    expected_value,
                    rtol=RTOL,
                    atol=ATOL,
                    equal_nan=EQUAL_NAN,
                    err_msg=f"Wrong data in output field '{name}'",
                )
Example #5
0
class StencilObject(abc.ABC):
    """Generic singleton implementation of a stencil function.

    This class is used as base class for the specific subclass generated
    at run-time for any stencil definition and a unique set of external symbols.
    Instances of this class do not contain any information and thus it is
    implemented as a singleton: only one instance per subclass is actually
    allocated (and it is immutable).
    """

    def __new__(cls, *args, **kwargs):
        # Singleton: create the instance on first use, then always return it.
        if getattr(cls, "_instance", None) is None:
            cls._instance = object.__new__(cls)
        return cls._instance

    def __setattr__(self, key, value) -> None:
        raise AttributeError("Attempting a modification of an attribute in a frozen class")

    def __delattr__(self, item) -> None:
        raise AttributeError("Attempting a deletion of an attribute in a frozen class")

    def __eq__(self, other) -> bool:
        # One instance per subclass, so equality reduces to type identity.
        return type(self) == type(other)

    def __str__(self) -> str:
        result = """
<StencilObject: {name}> [backend="{backend}"]
    - I/O fields: {fields}
    - Parameters: {params}
    - Constants: {constants}
    - Version: {version}
    - Definition ({func}):
{source}
        """.format(
            name=self.options["module"] + "." + self.options["name"],
            version=self._gt_id_,
            backend=self.backend,
            fields=self.field_info,
            params=self.parameter_info,
            constants=self.constants,
            func=self.definition_func,
            source=self.source,
        )

        return result

    def __hash__(self) -> int:
        # Derived from the class-level stencil id, consistent with __eq__.
        return int.from_bytes(type(self)._gt_id_.encode(), byteorder="little")

    # Those attributes are added to the class at loading time:
    #
    #   _gt_id_ (stencil_id.version)
    #   definition_func

    @property
    @abc.abstractmethod
    def backend(self) -> str:
        pass

    @property
    @abc.abstractmethod
    def source(self) -> str:
        pass

    @property
    @abc.abstractmethod
    def domain_info(self) -> DomainInfo:
        pass

    @property
    @abc.abstractmethod
    def field_info(self) -> Dict[str, FieldInfo]:
        pass

    @property
    @abc.abstractmethod
    def parameter_info(self) -> Dict[str, ParameterInfo]:
        pass

    @property
    @abc.abstractmethod
    def constants(self) -> Dict[str, Any]:
        pass

    @property
    @abc.abstractmethod
    def options(self) -> Dict[str, Any]:
        pass

    @abc.abstractmethod
    def run(self, *args, **kwargs) -> None:
        pass

    @abc.abstractmethod
    def __call__(self, *args, **kwargs) -> None:
        pass

    @staticmethod
    def _make_origin_dict(origin: Any) -> Dict[str, Index]:
        """Normalize the user-facing ``origin`` argument into a dict.

        A dict is copied, so later changes do not leak into the caller's
        object. ``None`` becomes an empty dict. An iterable or an int is
        stored under the ``"_all_"`` key.

        Raises
        ------
            ValueError
                If ``origin`` has an unsupported type or cannot be converted.
        """
        try:
            if isinstance(origin, dict):
                return dict(origin)
            if origin is None:
                return {}
            # dict is also iterable, so this check must come after the dict one.
            if isinstance(origin, collections.abc.Iterable):
                return {"_all_": Index.from_value(origin)}
            if isinstance(origin, int):
                return {"_all_": Index.from_k(origin)}
        except Exception as err:
            raise ValueError("Invalid 'origin' value ({})".format(origin)) from err

        raise ValueError("Invalid 'origin' value ({})".format(origin))

    def _get_max_domain(
        self,
        field_args: Dict[str, Any],
        origin: Dict[str, Tuple[int, ...]],
        *,
        squeeze: bool = True,
    ) -> Shape:
        """Return the maximum domain size possible.

        Parameters
        ----------
            field_args:
                Mapping from field names to actually passed data arrays.
            origin:
                The origin for each field.
            squeeze:
                Convert non-used domain dimensions to singleton dimensions.

        Returns
        -------
            `Shape`: the maximum domain size.
        """
        domain_ndim = self.domain_info.ndim
        max_size = sys.maxsize
        # Start unbounded; each used field narrows the domain on its axes.
        max_domain = Shape([max_size] * domain_ndim)

        for name, field_info in self.field_info.items():
            if field_info is not None:
                assert field_args.get(name, None) is not None, f"Invalid value for '{name}' field."
                field = field_args[name]
                api_domain_mask = field_info.domain_mask
                api_domain_ndim = field_info.domain_ndim
                assert (
                    not isinstance(field, gt_storage.storage.Storage)
                    or tuple(field.mask)[:domain_ndim] == api_domain_mask
                ), (
                    f"Storage for '{name}' has domain mask '{field.mask}' but the API signature "
                    f"expects '[{', '.join(field_info.axes)}]'"
                )
                upper_indices = field_info.boundary.upper_indices.filter_mask(api_domain_mask)
                field_origin = Index.from_value(origin[name])
                field_domain = tuple(
                    field.shape[i] - (field_origin[i] + upper_indices[i])
                    for i in range(api_domain_ndim)
                )
                # Axes the field does not span stay at max_size (unconstrained).
                max_domain &= Shape.from_mask(field_domain, api_domain_mask, default=max_size)

        if squeeze:
            return Shape([i if i != max_size else 1 for i in max_domain])
        else:
            return max_domain

    def _validate_args(self, field_args, param_args, domain, origin) -> None:
        """Validate input arguments to _call_run.

        Parameters
        ----------
            field_args:
                Mapping from field names to the passed data arrays.
            param_args:
                Mapping from parameter names to the passed scalar values.
            domain:
                Requested compute domain, one entry per domain axis.
            origin:
                Origin for each field.

        Raises
        -------
            ValueError
                If invalid data or inconsistent options are specified.

            TypeError
                If an incorrect field or parameter data type is passed.
        """

        assert isinstance(field_args, dict) and isinstance(param_args, dict)

        # validate domain sizes
        domain_ndim = self.domain_info.ndim
        if len(domain) != domain_ndim:
            raise ValueError(f"Invalid 'domain' value '{domain}'")

        try:
            domain = Shape(domain)
        except Exception as err:
            raise ValueError("Invalid 'domain' value ({})".format(domain)) from err

        if not domain > Shape.zeros(domain_ndim):
            raise ValueError(f"Compute domain contains zero sizes '{domain}'")

        max_domain = self._get_max_domain(field_args, origin, squeeze=False)
        if not domain <= max_domain:
            raise ValueError(
                f"Compute domain too large (provided: {domain}, maximum: {max_domain})"
            )

        # The backend does not change per field: look up its storage checks once.
        storage_info = gt_backend.from_name(self.backend).storage_info

        # assert compatibility of fields with stencil
        for name, field_info in self.field_info.items():
            if field_info is None:
                continue
            if name not in field_args:
                raise ValueError(f"Missing value for '{name}' field.")
            field = field_args[name]

            if not storage_info["is_compatible_layout"](field):
                raise ValueError(
                    f"The layout of the field {name} is not compatible with the backend."
                )

            if not storage_info["is_compatible_type"](field):
                raise ValueError(
                    f"Field '{name}' has type '{type(field)}', which is not compatible with the '{self.backend}' backend."
                )
            elif type(field) is np.ndarray:
                warnings.warn(
                    "NumPy ndarray passed as field. This is discouraged and only works with constraints and only for certain backends.",
                    RuntimeWarning,
                )

            field_dtype = field_info.dtype
            if field.dtype != field_dtype:
                raise TypeError(
                    f"The dtype of field '{name}' is '{field.dtype}' instead of '{field_dtype}'"
                )

            if isinstance(field, gt_storage.storage.Storage) and not field.is_stencil_view:
                raise ValueError(
                    f"An incompatible view was passed for field {name} to the stencil. "
                )

            # Check: domain + halo vs field size
            field_domain_mask = field_info.domain_mask
            field_domain_ndim = field_info.domain_ndim
            field_domain_origin = Index.from_mask(origin[name], field_domain_mask[:domain_ndim])

            if field.ndim != field_domain_ndim + len(field_info.data_dims):
                raise ValueError(
                    f"Storage for '{name}' has {field.ndim} dimensions but the API signature "
                    f"expects {field_domain_ndim + len(field_info.data_dims)} ('{field_info.axes}[{field_info.data_dims}]')"
                )

            min_origin = gt_utils.interpolate_mask(
                field_info.boundary.lower_indices.filter_mask(field_domain_mask),
                field_domain_mask,
                default=0,
            )
            if field_domain_origin < min_origin:
                raise ValueError(
                    f"Origin for field {name} too small. Must be at least {min_origin}, is {field_domain_origin}"
                )

            spatial_domain = domain.filter_mask(field_domain_mask)
            upper_indices = field_info.boundary.upper_indices.filter_mask(field_domain_mask)
            # Index.from_mask expands the origin to one entry per mask axis, while
            # spatial_domain and upper_indices only hold the axes enabled in the
            # mask: drop the disabled axes from the origin so the sequences line up.
            used_origin = tuple(
                o for o, used in zip(field_domain_origin, field_domain_mask) if used
            )
            min_shape = tuple(
                o + d + h for o, d, h in zip(used_origin, spatial_domain, upper_indices)
            )
            # Compare axis by axis: plain tuple comparison is lexicographic and
            # would miss a too-small trailing axis once a leading axis is large
            # enough. zip() stops at the spatial axes, skipping data dimensions.
            if any(required > actual for required, actual in zip(min_shape, field.shape)):
                raise ValueError(
                    f"Shape of field {name} is {field.shape} but must be at least {min_shape} for given domain and origin."
                )

        # assert compatibility of parameters with stencil
        for name, parameter_info in self.parameter_info.items():
            if parameter_info is not None:
                if name not in param_args:
                    raise ValueError(f"Missing value for '{name}' parameter.")
                if type(parameter := param_args[name]) != parameter_info.dtype:
                    raise TypeError(
                        f"The type of parameter '{name}' is '{type(parameter)}' instead of '{parameter_info.dtype}'"
                    )