示例#1
0
    def _validate_args(self, field_args, param_args, domain, origin) -> None:
        """Validate input arguments to _call_run.

        Parameters
        ----------
            field_args: `dict`
                Mapping from field names to the passed data arrays.

            param_args: `dict`
                Mapping from parameter names to the passed parameter values.

            domain: `Sequence` of `int`
                Requested compute domain, with one entry per domain dimension.

            origin: `dict`
                Origins of the fields, forwarded to ``_get_max_domain``.

        Raises
        -------
            ValueError
                If invalid data or inconsistent options are specified.

            TypeError
                If an incorrect field or parameter data type is passed.
        """

        assert isinstance(field_args, dict) and isinstance(param_args, dict)

        # validate domain sizes
        domain_ndim = self.domain_info.ndim
        if len(domain) != domain_ndim:
            raise ValueError(f"Invalid 'domain' value '{domain}'")

        try:
            domain = Shape(domain)
        except Exception as err:
            # A bare ``except:`` would also convert KeyboardInterrupt/SystemExit into a
            # ValueError; chaining keeps the original conversion error visible.
            raise ValueError("Invalid 'domain' value ({})".format(domain)) from err

        if not domain > Shape.zeros(domain_ndim):
            raise ValueError(f"Compute domain contains zero sizes '{domain}')")

        # ``squeeze=False``: axes no field constrains are not collapsed to size 1,
        # so they cannot spuriously limit the requested domain.
        max_domain = self._get_max_domain(field_args, origin, squeeze=False)
        if not domain <= max_domain:
            raise ValueError(
                f"Compute domain too large (provided: {domain}, maximum: {max_domain})"
            )
示例#2
0
    def _get_max_domain(self, field_args, origin):
        """Compute the largest compute domain compatible with the passed fields.

        Parameters
        ----------
            field_args: `dict`
                Mapping from field names to actually passed data arrays.

            origin: `{'field_name': [int * ndims]}`
                The origin for each field.

        Returns
        -------
            `Shape`: the maximum domain size. The bound starts at
            ``np.iinfo(np.uintc).max`` on every axis and is intersected with the
            room each field leaves after its origin and upper boundary.

        Raises
        ------
            ValueError
                If a storage's mask differs from the mask the API signature expects.
        """
        sentinel = np.iinfo(np.uintc).max
        result = Shape([sentinel] * self.domain_info.ndims)
        for field_name, array in field_args.items():
            expected_mask = self._get_field_mask(field_name)
            if isinstance(array, gt_storage.storage.Storage):
                actual_mask = tuple(array.mask)
                if actual_mask != expected_mask:
                    raise ValueError(
                        f"The storage for '{field_name}' has mask '{actual_mask}', but the API signature expects '{expected_mask}'"
                    )
            boundary = self.field_info[field_name].boundary
            upper = boundary.upper_indices.filter_mask(expected_mask)
            # Size left along the field's axes once origin and upper halo are removed.
            available = Shape(array.shape) - (origin[field_name] + upper)
            result &= Shape.from_mask(available, expected_mask, default=sentinel)
        return result
示例#3
0
    def _get_max_domain(
        self,
        field_args: Dict[str, Any],
        origin: Dict[str, Tuple[int, ...]],
        *,
        squeeze: bool = True,
    ) -> Shape:
        """Compute the largest compute domain that fits every field with field info.

        Parameters
        ----------
            field_args:
                Mapping from field names to actually passed data arrays.
            origin:
                The origin for each field.
            squeeze:
                Convert non-used domain dimensions to singleton dimensions.

        Returns
        -------
            `Shape`: the maximum domain size. The bound starts at ``sys.maxsize``
            on every axis; when ``squeeze`` is true, entries still equal to
            ``sys.maxsize`` are replaced by 1.
        """
        ndim = self.domain_info.ndim
        sentinel = sys.maxsize
        result = Shape([sentinel] * ndim)

        for field_name, info in self.field_info.items():
            if info is None:
                # No field info for this name: it does not constrain the domain.
                continue
            assert field_args.get(field_name, None) is not None, f"Invalid value for '{field_name}' field."
            array = field_args[field_name]
            mask = info.domain_mask
            n_axes = info.domain_ndim
            assert (
                not isinstance(array, gt_storage.storage.Storage)
                or tuple(array.mask)[:ndim] == mask
            ), (
                f"Storage for '{field_name}' has domain mask '{array.mask}' but the API signature "
                f"expects '[{', '.join(info.axes)}]'"
            )
            upper = info.boundary.upper_indices.filter_mask(mask)
            start = Index.from_value(origin[field_name])
            extent = tuple(
                array.shape[axis] - (start[axis] + upper[axis]) for axis in range(n_axes)
            )
            result &= Shape.from_mask(extent, mask, default=sentinel)

        if not squeeze:
            return result
        return Shape([1 if size == sentinel else size for size in result])
示例#4
0
    def _get_max_domain(
        self,
        field_args: Dict[str, Any],
        origin: Dict[str, Tuple[int, ...]],
        *,
        squeeze: bool = True,
    ) -> Shape:
        """Compute the largest compute domain allowed by the accessed fields.

        Parameters
        ----------
            field_args:
                Mapping from field names to actually passed data arrays.
            origin:
                The origin for each field.
            squeeze:
                Convert non-used domain dimensions to singleton dimensions.

        Returns
        -------
            `Shape`: the maximum domain size. Only fields whose ``access`` is not
            ``AccessKind.NONE`` contribute; when ``squeeze`` is true, entries left at
            ``sys.maxsize`` become 1.
        """
        ndim = self.domain_info.ndim
        unbounded = sys.maxsize
        bound = Shape([unbounded] * ndim)

        active = (
            (fname, finfo)
            for fname, finfo in self.field_info.items()
            if finfo.access != AccessKind.NONE
        )
        for fname, finfo in active:
            assert field_args.get(fname, None) is not None, f"Invalid value for '{fname}' field."
            data = field_args[fname]
            dmask = finfo.domain_mask
            dndim = finfo.domain_ndim
            upper = finfo.boundary.upper_indices.filter_mask(dmask)
            org = Index.from_value(origin[fname])
            # Per-axis room left after the origin and the upper halo.
            sizes = tuple(data.shape[d] - (org[d] + upper[d]) for d in range(dndim))
            bound &= Shape.from_mask(sizes, dmask, default=unbounded)

        if squeeze:
            bound = Shape([1 if s == unbounded else s for s in bound])
        return bound
示例#5
0
def normalize_shape(
    shape: Optional[Sequence[int]], mask: Optional[Sequence[bool]] = None
) -> Optional[Shape]:
    """Validate ``shape`` and optionally restrict it to the axes selected by ``mask``.

    ``mask`` is first passed to ``check_mask`` (its errors propagate).

    Parameters
    ----------
        shape: sequence of positive integers, or `None`.
        mask: sequence of booleans, or `None` (every axis of ``shape`` selected).

    Returns
    -------
        `None` if ``shape`` is `None`. If ``shape`` has one entry per mask entry and
        some entries of ``mask`` are false, a `Shape` of the selected entries
        (converted with ``int``); otherwise a `Shape` of ``shape`` as given.

    Raises
    ------
        ValueError
            If ``len(shape)`` equals neither ``len(mask)`` nor the number of true
            entries in ``mask``, or if ``shape`` contains a non-positive value.
        TypeError
            If ``gt_util.is_iterable_of`` rejects ``shape`` as integral values.
    """
    check_mask(mask)

    if shape is None:
        return None

    n = len(shape)
    if mask is None:
        mask = (True,) * n
    n_selected = sum(mask)

    if n not in (n_selected, len(mask)):
        raise ValueError(
            "len(shape) must be equal to len(mask) or the number of 'True' entries in mask."
        )

    if not gt_util.is_iterable_of(shape, numbers.Integral):
        raise TypeError("shape must be a tuple of ints or pairs of ints.")
    if any(extent <= 0 for extent in shape):
        raise ValueError("shape ({}) contains non-positive value.".format(shape))

    if n_selected < n:
        # Full-length shape with some axes masked out: keep only the selected ones.
        return Shape([int(extent) for extent, keep in zip(shape, mask) if keep])
    return Shape(list(shape))
示例#6
0
    def _get_max_domain(self, field_args, origin):
        """Compute the largest compute domain compatible with the passed fields.

        Parameters
        ----------
            field_args: `dict`
                Mapping from field names to actually passed data arrays.

            origin: `{'field_name': [int * ndims]}`
                The origin for each field.

        Returns
        -------
            `Shape`: the maximum domain size, starting from ``np.iinfo(np.uintc).max``
            on every axis and intersected with each field's shape minus its origin
            and upper boundary.
        """
        limit = np.iinfo(np.uintc).max
        result = Shape([limit] * self.domain_info.ndims)

        # First wrap every array shape, then intersect the per-field bounds.
        field_shapes = {}
        for fname, arr in field_args.items():
            field_shapes[fname] = Shape(arr.shape)

        for fname in field_shapes:
            upper = Index(self.field_info[fname].boundary.upper_indices)
            result &= field_shapes[fname] - (Index(origin[fname]) + upper)
        return result
示例#7
0
    def test_implementation(self, test, parameters_dict):
        """Test computed values for implementations generated for all *backends* and *stencil suites*.

        The generated implementations are reused from previous tests by means of a
        :class:`utils.ImplementationsDB` instance shared at module scope.

        For each implementation, every NumPy input array is copied into a larger
        ("patched") array whose halo covers both ``cls.max_boundary`` and the
        boundaries listed in ``implementation.field_info``. The stencil runs on the
        patched storages, ``cls.validation`` runs on copies of the same inputs, and
        both results are compared on the compute domain.
        """
        cls = type(self)
        implementation_list = test["implementations"]
        if not implementation_list:
            pytest.skip(
                "Cannot perform validation tests, since there are no valid implementations."
            )
        for implementation in implementation_list:
            if not isinstance(implementation, StencilObject):
                raise RuntimeError(
                    "Wrong function got from implementations_db cache!")

            fields, exec_info = parameters_dict

            # Domain
            from gt4py.definitions import Shape
            from gt4py.ir.nodes import Index

            origin = cls.origin

            shapes = {
                name: Shape(field.shape)
                for name, field in fields.items()
                if isinstance(field, np.ndarray)
            }
            max_domain = Shape([sys.maxsize] *
                               implementation.domain_info.ndims)
            for name, shape in shapes.items():
                upper_boundary = Index(
                    [b[1] for b in cls.symbols[name].boundary])
                max_domain &= shape - (Index(origin) + upper_boundary)
            domain = max_domain

            # Largest halo required by any field of the generated implementation.
            max_boundary = ((0, 0), (0, 0), (0, 0))
            for info in implementation.field_info.values():
                if isinstance(info, gt_definitions.FieldInfo):
                    max_boundary = tuple(
                        (max(m[0], abs(b[0])), max(m[1], b[1]))
                        for m, b in zip(max_boundary, info.boundary))

            new_boundary = tuple(
                (max(abs(b[0]), abs(mb[0])), max(abs(b[1]), abs(mb[1])))
                for b, mb in zip(cls.max_boundary, max_boundary))

            # All NumPy inputs are expected to share one shape.
            shape = None
            for field in fields.values():
                if isinstance(field, np.ndarray):
                    assert field.shape == (shape if shape is not None else
                                           field.shape)
                    shape = field.shape

            patched_origin = tuple(nb[0] for nb in new_boundary)
            patching_origin = tuple(po - o
                                    for po, o in zip(patched_origin, origin))
            patched_shape = tuple(nb[0] + nb[1] + d
                                  for nb, d in zip(new_boundary, domain))
            # Multidimensional NumPy indexing requires a *tuple* of slices; recent
            # NumPy no longer interprets a list as one index per axis.
            patching_slices = tuple(
                slice(po, po + s) for po, s in zip(patching_origin, shape))

            for k, v in implementation.constants.items():
                sys.modules[self.__module__].__dict__[k] = v

            inputs = {}
            for k, f in fields.items():
                if isinstance(f, np.ndarray):
                    patched_f = np.empty(shape=patched_shape)
                    patched_f[patching_slices] = f
                    inputs[k] = gt_storage.from_array(
                        patched_f,
                        dtype=test["definition"].__annotations__[k],
                        shape=patched_f.shape,
                        default_origin=patched_origin,
                        backend=test["backend"],
                    )

                else:
                    inputs[k] = f
            validation_fields = {
                name: np.array(field, copy=True)
                for name, field in inputs.items()
            }

            implementation(**inputs,
                           origin=patched_origin,
                           exec_info=exec_info)
            domain = exec_info["domain"]

            validation_origins = {
                name: tuple(
                    nb[0] - g[0]
                    for nb, g in zip(new_boundary, cls.symbols[name].boundary))
                for name in implementation.field_info.keys()
            }

            validation_shapes = {
                name:
                tuple(d + g[0] + g[1]
                      for d, g in zip(domain, cls.symbols[name].boundary))
                for name in implementation.field_info.keys()
            }

            validation_field_views = {
                name: field[tuple(
                    slice(o, o + s) for o, s in zip(validation_origins[name],
                                                    validation_shapes[name]))]
                if name in implementation.field_info else field  # parameters
                for name, field in validation_fields.items()
            }
            cls.validation(
                **validation_field_views,
                domain=domain,
                origin={
                    name: tuple(b[0] for b in cls.symbols[name].boundary)
                    for name in validation_fields
                    if name in implementation.field_info
                },
            )

            # Test values: compare each array-valued input on the compute domain only.
            for name, value in inputs.items():
                if isinstance(fields[name], np.ndarray):
                    expected_value = validation_fields[name]
                    domain_slice = tuple(
                        slice(new_boundary[d][0],
                              new_boundary[d][0] + domain[d])
                        for d in range(len(domain)))
                    np.testing.assert_allclose(
                        value.data[domain_slice],
                        expected_value[domain_slice],
                        rtol=RTOL,
                        atol=ATOL,
                        equal_nan=EQUAL_NAN,
                        err_msg="Wrong data in output field '{name}'".format(
                            name=name),
                    )
示例#8
0
    def _call_run(self,
                  field_args,
                  parameter_args,
                  domain,
                  origin,
                  exec_info=None):
        """Check and preprocess the provided arguments (called by :class:`StencilObject` subclasses).

        Note that this function will always try to expand simple parameter values to
        complete data structures by repeating the same value as many times as needed.

        Parameters
        ----------
            field_args: `dict`
                Mapping from field names to actually passed data arrays.
                This parameter encapsulates `*args` in the actual stencil subclass
                by doing: `{input_name[i]: arg for i, arg in enumerate(args)}`

            parameter_args: `dict`
                Mapping from parameter names to actually passed parameter values.
                This parameter encapsulates `**kwargs` in the actual stencil subclass
                by doing: `{name: value for name, value in kwargs.items()}`

            domain : `Sequence` of `int`, optional
                Shape of the computation domain. If `None`, it will be used the
                largest feasible domain according to the provided input fields
                and origin values (`None` by default).

            origin :  `[int * ndims]` or `{'field_name': [int * ndims]}`, optional
                If a single offset is passed, it will be used for all fields.
                If a `dict` is passed, there could be an entry for each field.
                A special key *'_all_'* will represent the value to be used for all
                the fields not explicitly defined. If `None` is passed or it is
                not possible to assign a value to some field according to the
                previous rule, the value will be inferred from the global boundaries
                of the field. Note that the function checks if the origin values
                are at least equal to the `global_border` attribute of that field,
                so a 0-based origin will only be acceptable for fields with
                a 0-area support region.

            exec_info : `dict`, optional
                Dictionary used to store information about the stencil execution.
                (`None` by default).

        Returns
        -------
            `None`

        Raises
        -------
            ValueError
                If invalid data or inconsistent options are specified.

            TypeError
                If a field's dtype or a parameter's type does not match the stencil.
        """

        if exec_info is not None:
            exec_info["call_run_start_time"] = time.perf_counter()

        # Checked before the mappings are first used below.
        assert isinstance(field_args, dict) and isinstance(
            parameter_args, dict)

        used_arg_fields = {
            name: field
            for name, field in field_args.items()
            if name in self.field_info and self.field_info[name] is not None
        }
        used_arg_params = {
            name: param
            for name, param in parameter_args.items()
            if name in self.parameter_info
            and self.parameter_info[name] is not None
        }
        for name, field_info in self.field_info.items():
            if field_info is not None and field_args[name] is None:
                raise ValueError(f"Field '{name}' is None.")
        for name, parameter_info in self.parameter_info.items():
            if parameter_info is not None and parameter_args[name] is None:
                raise ValueError(f"Parameter '{name}' is None.")

        # assert compatibility of fields with stencil
        for name, field in used_arg_fields.items():
            storage_info = gt_backend.from_name(self.backend).storage_info
            if not storage_info["is_compatible_layout"](field):
                raise ValueError(
                    f"The layout of the field {name} is not compatible with the backend."
                )

            if not storage_info["is_compatible_type"](field):
                raise ValueError(
                    f"Field '{name}' has type '{type(field)}', which is not compatible with the '{self.backend}' backend."
                )
            elif type(field) is np.ndarray:
                warnings.warn(
                    "NumPy ndarray passed as field. This is discouraged and only works with constraints and only for certain backends.",
                    RuntimeWarning,
                )
            if not field.dtype == self.field_info[name].dtype:
                raise TypeError(
                    f"The dtype of field '{name}' is '{field.dtype}' instead of '{self.field_info[name].dtype}'"
                )
            # ToDo: check if mask is correct: need mask info in stencil object.

            if isinstance(field, gt_storage.storage.Storage):
                if not field.is_stencil_view:
                    raise ValueError(
                        f"An incompatible view was passed for field {name} to the stencil. "
                    )
                for name_other, field_other in used_arg_fields.items():
                    # Plain NumPy arrays (accepted above with a warning) have no
                    # ``mask`` attribute, so only other storages are compared.
                    if (isinstance(field_other, gt_storage.storage.Storage)
                            and field_other.mask == field.mask
                            and not field_other.shape == field.shape):
                        raise ValueError(
                            f"The fields {name} and {name_other} have the same mask but different shapes."
                        )

        # assert compatibility of parameters with stencil
        for name, parameter in used_arg_params.items():
            if not type(parameter) == self.parameter_info[name].dtype:
                raise TypeError(
                    f"The type of parameter '{name}' is '{type(parameter)}' instead of '{self.parameter_info[name].dtype}'"
                )

        # Shapes
        shapes = {name: Shape(field.shape) for name, field in used_arg_fields.items()}

        # Origins
        if origin is None:
            origin = {}
        else:
            origin = normalize_origin_mapping(origin)
        for name, field in used_arg_fields.items():
            origin.setdefault(
                name,
                origin["_all_"] if "_all_" in origin else field.default_origin)

        # Domain
        max_domain = Shape([sys.maxsize] * self.domain_info.ndims)
        for name, shape in shapes.items():
            upper_boundary = Index(
                self.field_info[name].boundary.upper_indices)
            max_domain &= shape - (Index(origin[name]) + upper_boundary)

        if domain is None:
            domain = max_domain
        else:
            domain = normalize_domain(domain)
            if len(domain) != self.domain_info.ndims:
                raise ValueError(f"Invalid 'domain' value '{domain}'")

        # check domain+halo vs field size
        if not domain > Shape.zeros(self.domain_info.ndims):
            raise ValueError(f"Compute domain contains zero sizes '{domain}')")

        if not domain <= max_domain:
            raise ValueError(
                f"Compute domain too large (provided: {domain}, maximum: {max_domain})"
            )
        for name, field in used_arg_fields.items():
            min_origin = self.field_info[name].boundary.lower_indices
            # Compare axis by axis: ``<``/``>`` on plain tuples is lexicographic and
            # would accept e.g. an origin of (5, 0, 0) against a minimum of (1, 1, 1).
            if any(o < m for o, m in zip(origin[name], min_origin)):
                raise ValueError(
                    f"Origin for field {name} too small. Must be at least {min_origin}, is {origin[name]}"
                )
            min_shape = tuple(o + d + h for o, d, h in zip(
                origin[name], domain,
                self.field_info[name].boundary.upper_indices))
            if any(m > s for m, s in zip(min_shape, field.shape)):
                raise ValueError(
                    f"Shape of field {name} is {field.shape} but must be at least {min_shape} for given domain and origin."
                )

        self.run(**field_args,
                 **parameter_args,
                 _domain_=domain,
                 _origin_=origin,
                 exec_info=exec_info)
示例#9
0
    def _run_test_implementation(cls, parameters_dict, implementation):  # noqa: C901  # too complex
        """Run one generated stencil implementation and check it against ``cls.validation``.

        Steps:
          1. Derive per-field axis masks and the common data shape from the NumPy inputs.
          2. Build storages for the inputs the implementation references and plain
             NumPy copies of the same data for validation.
          3. Call the implementation and assert it ran on the expected domain.
          4. Crop the validation arrays to the region each field touches, run
             ``cls.validation`` on them, and compare every storage with its
             validation array using ``np.testing.assert_allclose``.

        Parameters
        ----------
            parameters_dict: `tuple`
                Pair ``(input_data, exec_info)``: a mapping from argument names to
                values, and the dict passed to the implementation as ``exec_info``.
            implementation:
                The generated stencil object under test.
        """
        input_data, exec_info = parameters_dict

        origin = cls.origin
        max_boundary = Boundary(cls.max_boundary)
        field_params = cls.field_params
        # For each NumPy input: True for every Cartesian axis the field is defined on.
        field_masks = {}
        for name, value in input_data.items():
            if isinstance(value, np.ndarray):
                field_masks[name] = tuple(
                    ax in field_params[name][0] for ax in CartesianSpace.names
                )

        # Common data shape over all fields (3 axes hard-coded). Presumably
        # ``interpolate_mask`` fills masked-out axes with ``sys.maxsize`` so they do not
        # constrain the intersection — verify against its definition.
        data_shape = Shape((sys.maxsize,) * 3)
        for name, data in input_data.items():
            if isinstance(data, np.ndarray):
                data_shape &= Shape(
                    interpolate_mask(data.shape, field_masks[name], default=sys.maxsize)
                )

        # The expected compute domain is the data shape minus the suite's largest halo.
        domain = data_shape - (
            Index(max_boundary.lower_indices) + Index(max_boundary.upper_indices)
        )

        # Fields and parameters the implementation actually references (info is not None).
        referenced_inputs = {
            name: info for name, info in implementation.field_info.items() if info is not None
        }
        referenced_inputs.update(
            {name: info for name, info in implementation.parameter_info.items() if info is not None}
        )

        # set externals for validation method
        for k, v in implementation.constants.items():
            sys.modules[cls.__module__].__dict__[k] = v

        # copy input data
        test_values = {}
        validation_values = {}
        for name, data in input_data.items():
            # NOTE(review): redundant — ``data`` is already bound by the loop header.
            data = input_data[name]
            if name in referenced_inputs:
                info = referenced_inputs[name]
                if isinstance(info, FieldInfo):
                    # Trailing data dimensions move from the shape into the dtype.
                    data_dims = field_params[name][1]
                    if data_dims:
                        dtype = (data.dtype, data_dims)
                        shape = data.shape[: -len(data_dims)]
                    else:
                        dtype = data.dtype
                        shape = data.shape
                    test_values[name] = gt_storage.from_array(
                        data,
                        dtype=dtype,
                        shape=shape,
                        mask=field_masks[name],
                        default_origin=origin,
                        backend=implementation.backend,
                    )
                    validation_values[name] = np.array(data)
                else:
                    # Scalar parameters are passed through unchanged.
                    test_values[name] = data
                    validation_values[name] = data
            else:
                # Inputs the implementation does not reference are passed as None.
                test_values[name] = None
                validation_values[name] = None

        # call implementation
        implementation(**test_values, origin=origin, exec_info=exec_info)
        assert domain == exec_info["domain"]

        # For validation data, data is cropped to the actually touched domain, so that origin
        # offsetting does not have to be implemented for every test suite. This is done based
        # on info specified in the test suite. Slicing with a tuple of slices yields NumPy views,
        # so in-place writes by ``cls.validation`` (presumably how it produces outputs) land in
        # ``validation_values``, which is what gets compared below.
        cropped_validation_values = {}
        for name, data in validation_values.items():
            sym = cls.symbols[name]
            if data is not None and sym.kind == SymbolKind.FIELD:
                field_extent_low = tuple(b[0] for b in sym.boundary)
                offset_low = tuple(b[0] - e for b, e in zip(max_boundary, field_extent_low))
                field_extent_high = tuple(b[1] for b in sym.boundary)
                offset_high = tuple(b[1] - e for b, e in zip(max_boundary, field_extent_high))
                validation_slice = filter_mask(
                    tuple(slice(o, s - h) for o, s, h in zip(offset_low, data_shape, offset_high)),
                    field_masks[name],
                )
                # Data dimensions are kept in full.
                data_dims = field_params[name][1]
                if data_dims:
                    validation_slice = tuple([*validation_slice] + [slice(None)] * len(data_dims))
                cropped_validation_values[name] = data[validation_slice]
            else:
                cropped_validation_values[name] = data

        cls.validation(
            **cropped_validation_values,
            domain=domain,
            origin={
                name: info.boundary.lower_indices
                for name, info in implementation.field_info.items()
                if info is not None
            },
        )

        # Test values
        for name, value in test_values.items():
            if isinstance(value, np.ndarray):
                expected_value = validation_values[name]

                # GPU storages must be synchronized and copied to the host before comparing.
                if gt_backend.from_name(value.backend).storage_info["device"] == "gpu":
                    value.synchronize()
                    value = value.data.get()
                else:
                    value = value.data

                np.testing.assert_allclose(
                    value,
                    expected_value,
                    rtol=RTOL,
                    atol=ATOL,
                    equal_nan=EQUAL_NAN,
                    err_msg="Wrong data in output field '{name}'".format(name=name),
                )
示例#10
0
    def _validate_args(self, used_field_args, used_param_args, domain, origin):
        """Validate input arguments to _call_run.

        Parameters
        ----------
            used_field_args: `dict`
                Mapping from field names to the passed data arrays. Every name must
                have an entry in ``self.field_info``.
            used_param_args: `dict`
                Mapping from parameter names to the passed values. Every name must
                have an entry in ``self.parameter_info``.
            domain:
                Compute domain; must support ``filter_mask`` (e.g. a ``Shape``).
            origin: `dict`
                Mapping from field names to their origins.

        Raises
        -------
            ValueError
                If invalid data or inconsistent options are specified.

            TypeError
                If an incorrect field or parameter data type is passed.
        """
        # Checked before the mappings are first used below.
        assert isinstance(used_field_args, dict) and isinstance(
            used_param_args, dict)

        # assert compatibility of fields with stencil
        for name, field in used_field_args.items():
            storage_info = gt_backend.from_name(self.backend).storage_info
            if not storage_info["is_compatible_layout"](field):
                raise ValueError(
                    f"The layout of the field {name} is not compatible with the backend."
                )

            if not storage_info["is_compatible_type"](field):
                raise ValueError(
                    f"Field '{name}' has type '{type(field)}', which is not compatible with the '{self.backend}' backend."
                )
            elif type(field) is np.ndarray:
                warnings.warn(
                    "NumPy ndarray passed as field. This is discouraged and only works with constraints and only for certain backends.",
                    RuntimeWarning,
                )

            field_dtype = self.field_info[name].dtype
            if not field.dtype == field_dtype:
                raise TypeError(
                    f"The dtype of field '{name}' is '{field.dtype}' instead of '{field_dtype}'"
                )

            if isinstance(field, gt_storage.storage.Storage):
                field_mask = self._get_field_mask(name)
                storage_mask = tuple(field.mask)
                if storage_mask != field_mask:
                    raise ValueError(
                        f"The storage for '{name}' has mask '{storage_mask}', but the API signature expects '{field_mask}'"
                    )

                if not field.is_stencil_view:
                    raise ValueError(
                        f"An incompatible view was passed for field {name} to the stencil. "
                    )

        # assert compatibility of parameters with stencil
        for name, parameter in used_param_args.items():
            if not type(parameter) == self.parameter_info[name].dtype:
                raise TypeError(
                    f"The type of parameter '{name}' is '{type(parameter)}' instead of '{self.parameter_info[name].dtype}'"
                )

        if len(domain) != self.domain_info.ndims:
            raise ValueError(f"Invalid 'domain' value '{domain}'")

        # check domain+halo vs field size
        if not domain > Shape.zeros(self.domain_info.ndims):
            raise ValueError(f"Compute domain contains zero sizes '{domain}')")

        max_domain = self._get_max_domain(used_field_args, origin)
        if not domain <= max_domain:
            raise ValueError(
                f"Compute domain too large (provided: {domain}, maximum: {max_domain})"
            )
        for name, field in used_field_args.items():
            field_mask = self._get_field_mask(name)
            boundary = self.field_info[name].boundary
            min_origin = boundary.lower_indices.filter_mask(field_mask)
            restricted_domain = domain.filter_mask(field_mask)
            upper_indices = boundary.upper_indices.filter_mask(field_mask)
            # Compare axis by axis: ``<``/``>`` on plain tuples is lexicographic and
            # would accept e.g. an origin of (5, 0, 0) against a minimum of (1, 1, 1).
            if any(o < m for o, m in zip(origin[name], min_origin)):
                raise ValueError(
                    f"Origin for field {name} too small. Must be at least {min_origin}, is {origin[name]}"
                )
            min_shape = tuple(o + d + h for o, d, h in zip(
                origin[name], restricted_domain, upper_indices))
            if any(m > s for m, s in zip(min_shape, field.shape)):
                raise ValueError(
                    f"Shape of field {name} is {field.shape} but must be at least {min_shape} for given domain and origin."
                )
示例#11
0
    def _run_test_implementation(self, parameters_dict, implementation):
        """Run ``implementation`` on the test fields and compare it against ``cls.validation``.

        The input arrays are embedded ("patched") into larger arrays whose halo is wide
        enough for both the suite's ``cls.max_boundary`` and the boundaries reported by the
        implementation's ``field_info``. The stencil runs on storages built from these
        patched arrays. ``cls.validation`` runs on views of copies of the same data. The
        compute-domain region of every array field is then compared.

        Parameters
        ----------
            parameters_dict: `tuple`
                ``(fields, exec_info)``: mapping from argument names to arrays or scalar
                parameters, and the ``exec_info`` dict handed to the implementation (its
                ``"domain"`` entry is read back after the call).

            implementation:
                The generated stencil object under test.

        Raises
        ------
            AssertionError
                If the array inputs have different shapes, or an output field differs from
                the validation result beyond ``RTOL``/``ATOL``.
        """

        cls = type(self)
        fields, exec_info = parameters_dict

        # Domain
        from gt4py.definitions import Shape
        from gt4py.ir.nodes import Index

        origin = cls.origin

        # Largest domain that fits every array field, given the suite origin and the
        # field's upper boundary.
        shapes = {
            name: Shape(field.shape)
            for name, field in fields.items()
            if isinstance(field, np.ndarray)
        }
        max_domain = Shape([sys.maxsize] * implementation.domain_info.ndims)
        for name, shape in shapes.items():
            upper_boundary = Index([b[1] for b in cls.symbols[name].boundary])
            max_domain &= shape - (Index(origin) + upper_boundary)
        domain = max_domain

        # Widest boundary requested by any field of the generated implementation.
        max_boundary = ((0, 0), (0, 0), (0, 0))
        for info in implementation.field_info.values():
            if isinstance(info, gt_definitions.FieldInfo):
                max_boundary = tuple(
                    (max(m[0], abs(b[0])), max(m[1], b[1]))
                    for m, b in zip(max_boundary, info.boundary)
                )

        # Halo of the patched arrays: large enough for both the suite and the implementation.
        new_boundary = tuple(
            (max(abs(b[0]), abs(mb[0])), max(abs(b[1]), abs(mb[1])))
            for b, mb in zip(cls.max_boundary, max_boundary)
        )

        # All array inputs must share a single shape.
        shape = None
        for field in fields.values():
            if isinstance(field, np.ndarray):
                assert field.shape == (shape if shape is not None else field.shape)
                shape = field.shape

        patched_origin = tuple(nb[0] for nb in new_boundary)
        patching_origin = tuple(po - o for po, o in zip(patched_origin, origin))
        patched_shape = tuple(nb[0] + nb[1] + d for nb, d in zip(new_boundary, domain))
        # NumPy needs a *tuple* of slices for multidimensional indexing; a list is rejected
        # by modern NumPy.
        patching_slices = tuple(slice(po, po + s) for po, s in zip(patching_origin, shape))

        # Expose the implementation's constants as module globals for the validation method.
        for k, v in implementation.constants.items():
            sys.modules[self.__module__].__dict__[k] = v

        inputs = {}
        for k, f in fields.items():
            if isinstance(f, np.ndarray):
                # Allocate in the field's own dtype so values are not round-tripped through
                # float64 (which loses precision for large integers). Halo cells are left
                # uninitialized by np.empty.
                patched_f = np.empty(shape=patched_shape, dtype=f.dtype)
                patched_f[patching_slices] = f
                inputs[k] = gt_storage.from_array(
                    patched_f,
                    dtype=f.dtype,
                    shape=patched_f.shape,
                    default_origin=patched_origin,
                    backend=implementation.backend,
                )

            else:
                inputs[k] = f

        # remove unused input parameters
        inputs = {key: value for key, value in inputs.items() if value is not None}

        # Independent copies of the inputs, fed to the reference validation implementation.
        validation_fields = {name: np.array(field, copy=True) for name, field in inputs.items()}

        implementation(**inputs, origin=patched_origin, exec_info=exec_info)
        domain = exec_info["domain"]

        validation_origins = {
            name: tuple(nb[0] - g[0] for nb, g in zip(new_boundary, cls.symbols[name].boundary))
            for name in inputs
            if name in implementation.field_info
        }

        validation_shapes = {
            name: tuple(d + g[0] + g[1] for d, g in zip(domain, cls.symbols[name].boundary))
            for name in inputs
            if name in implementation.field_info
        }

        # Crop each field copy to its own domain-plus-boundary region; scalars pass through.
        validation_field_views = {
            name: field[
                tuple(
                    slice(o, o + s)
                    for o, s in zip(validation_origins[name], validation_shapes[name])
                )
            ]
            if name in implementation.field_info
            else field  # parameters
            for name, field in validation_fields.items()
        }
        cls.validation(
            **validation_field_views,
            domain=domain,
            origin={
                name: tuple(b[0] for b in cls.symbols[name].boundary)
                for name in validation_fields
                if name in implementation.field_info
            },
        )

        # Test values: only the compute domain (inside the patched halo) is compared.
        domain_slice = tuple(slice(nb[0], nb[0] + d) for nb, d in zip(new_boundary, domain))
        for (name, value), expected_value in zip(inputs.items(), validation_fields.values()):
            if isinstance(fields[name], np.ndarray):
                if gt_backend.from_name(value.backend).storage_info["device"] == "gpu":
                    value.synchronize()
                    value = value.data.get()
                else:
                    value = value.data

                np.testing.assert_allclose(
                    value[domain_slice],
                    expected_value[domain_slice],
                    rtol=RTOL,
                    atol=ATOL,
                    equal_nan=EQUAL_NAN,
                    err_msg="Wrong data in output field '{name}'".format(name=name),
                )
示例#12
0
文件: suites.py 项目: egparedes/gt4py
    def _run_test_implementation(self, parameters_dict, implementation):
        """Run ``implementation`` on the test data and compare it against ``cls.validation``.

        The compute domain is the common shape of all array inputs minus the suite's
        ``max_boundary`` (lower plus upper). Array inputs referenced by the implementation
        are passed as storages with ``default_origin=cls.origin``. Referenced scalars are
        passed as-is. Unreferenced inputs are passed as ``None``. The validation method
        receives copies of the data, cropped per field to the region the field actually
        touches.

        Parameters
        ----------
            parameters_dict: `tuple`
                ``(input_data, exec_info)``: mapping from argument names to arrays or
                scalars, and the ``exec_info`` dict handed to the implementation.

            implementation:
                The generated stencil object under test.

        Raises
        ------
            ValueError
                If ``input_data`` contains no array to derive the shape from.

            AssertionError
                If array inputs have different shapes, the executed domain differs from the
                expected one, or an output differs from the validation result beyond
                ``RTOL``/``ATOL``.
        """

        cls = type(self)
        input_data, exec_info = parameters_dict

        origin = cls.origin
        max_boundary = Boundary(cls.max_boundary)

        shape_iter = (Shape(v.shape) for v in input_data.values()
                      if isinstance(v, np.ndarray))
        # Use a default so an input set without arrays fails with a clear message instead of
        # leaking a bare StopIteration out of the method.
        shape = next(shape_iter, None)
        if shape is None:
            raise ValueError("Test input data contains no array to derive the domain shape from")
        assert all(shape == sh for sh in shape_iter)

        domain = shape - (Index(max_boundary.lower_indices) +
                          Index(max_boundary.upper_indices))

        # Fields and parameters the implementation actually uses (None info means unused).
        referenced_inputs = {
            name: info
            for name, info in implementation.field_info.items()
            if info is not None
        }
        referenced_inputs.update({
            name: info
            for name, info in implementation.parameter_info.items()
            if info is not None
        })

        # set externals for validation method
        for k, v in implementation.constants.items():
            sys.modules[self.__module__].__dict__[k] = v

        # copy input data
        inputs = {}
        validation_inputs = {}
        for name, data in input_data.items():
            if name in referenced_inputs:
                info = referenced_inputs[name]
                if isinstance(info, FieldInfo):
                    inputs[name] = gt_storage.from_array(
                        data,
                        dtype=data.dtype,
                        shape=shape,
                        default_origin=origin,
                        backend=implementation.backend,
                    )
                    # np.array copies, so validation never aliases the stencil's storage.
                    validation_inputs[name] = np.array(data)
                else:
                    inputs[name] = data
                    validation_inputs[name] = data
            else:
                inputs[name] = None
                validation_inputs[name] = None

        # call implementation
        implementation(**inputs, origin=origin, exec_info=exec_info)
        assert domain == exec_info["domain"]

        # for validation data, data is cropped to actually touched domain, so that origin offseting
        # does not have to be implemented for every test suite. This is done based on info
        # specified in test suite
        cropped_validation_inputs = {}
        for name, data in validation_inputs.items():
            sym = cls.symbols[name]
            if data is not None and sym.kind == SymbolKind.FIELD:
                field_extent_low = tuple(b[0] for b in sym.boundary)
                offset_low = tuple(
                    b[0] - e for b, e in zip(max_boundary, field_extent_low))
                field_extent_high = tuple(b[1] for b in sym.boundary)
                offset_high = tuple(
                    b[1] - e for b, e in zip(max_boundary, field_extent_high))
                validation_slice = tuple(
                    slice(o, s - h)
                    for o, s, h in zip(offset_low, shape, offset_high))
                # Basic slicing yields a view, so validation writes land in validation_inputs.
                cropped_validation_inputs[name] = data[validation_slice]
            else:
                cropped_validation_inputs[name] = data

        cls.validation(
            **cropped_validation_inputs,
            domain=domain,
            origin={
                name: info.boundary.lower_indices
                for name, info in implementation.field_info.items()
                if info is not None
            },
        )

        # Test values
        for name, value in inputs.items():
            if isinstance(value, np.ndarray):
                expected_value = validation_inputs[name]

                if gt_backend.from_name(
                        value.backend).storage_info["device"] == "gpu":
                    value.synchronize()
                    value = value.data.get()
                else:
                    value = value.data

                np.testing.assert_allclose(
                    value,
                    expected_value,
                    rtol=RTOL,
                    atol=ATOL,
                    equal_nan=EQUAL_NAN,
                    err_msg="Wrong data in output field '{name}'".format(
                        name=name),
                )