Exemplo n.º 1
0
    def _call_run(self,
                  field_args,
                  parameter_args,
                  domain,
                  origin,
                  exec_info=None):
        """Check and preprocess the provided arguments (called by :class:`StencilObject` subclasses).

        Note that this function will always try to expand simple parameter values to
        complete data structures by repeating the same value as many times as needed.

        Parameters
        ----------
            field_args: `dict`
                Mapping from field names to actually passed data arrays.
                This parameter encapsulates `*args` in the actual stencil subclass
                by doing: `{input_name[i]: arg for i, arg in enumerate(args)}`

            parameter_args: `dict`
                Mapping from parameter names to actually passed parameter values.
                This parameter encapsulates `**kwargs` in the actual stencil subclass
                by doing: `{name: value for name, value in kwargs.items()}`

            domain : `Sequence` of `int`, optional
                Shape of the computation domain. If `None`, the largest feasible
                domain according to the provided input fields and origin values
                is used (`None` by default).

            origin :  `[int * ndims]` or `{'field_name': [int * ndims]}`, optional
                If a single offset is passed, it will be used for all fields.
                If a `dict` is passed, there could be an entry for each field.
                A special key *'_all_'* will represent the value to be used for all
                the fields not explicitly defined. If `None` is passed or it is
                not possible to assign a value to some field according to the
                previous rule, the `default_origin` of the field is used.
                Note that the function checks, axis by axis, that the origin
                values are at least equal to the lower boundary of that field
                (`field_info[name].boundary.lower_indices`).

            exec_info : `dict`, optional
                Dictionary used to store information about the stencil execution.
                The entry `"call_run_start_time"` is written on entry
                (`None` by default).

        Returns
        -------
            `None`

        Raises
        ------
            ValueError
                If invalid data or inconsistent options are specified
                (`None` passed for a used field or parameter, incompatible
                layout/type/view, mismatching shapes, invalid domain or origin).

            TypeError
                If the dtype of a field or the type of a parameter does not
                match the one declared in the stencil.
        """

        if exec_info is not None:
            exec_info["call_run_start_time"] = time.perf_counter()

        # Both mappings are iterated below, so verify the calling contract
        # before touching them (internal invariant of the generated subclasses).
        assert isinstance(field_args, dict) and isinstance(
            parameter_args, dict)

        # Only the arguments actually used by the stencil are validated.
        used_arg_fields = {
            name: field
            for name, field in field_args.items()
            if name in self.field_info and self.field_info[name] is not None
        }
        used_arg_params = {
            name: param
            for name, param in parameter_args.items()
            if name in self.parameter_info
            and self.parameter_info[name] is not None
        }
        for name, field_info in self.field_info.items():
            if field_info is not None and field_args[name] is None:
                raise ValueError(f"Field '{name}' is None.")
        for name, parameter_info in self.parameter_info.items():
            if parameter_info is not None and parameter_args[name] is None:
                raise ValueError(f"Parameter '{name}' is None.")

        # assert compatibility of fields with stencil
        # The backend does not change between fields: look it up once.
        storage_info = gt_backend.from_name(self.backend).storage_info
        is_compatible_layout = storage_info["is_compatible_layout"]
        is_compatible_type = storage_info["is_compatible_type"]
        for name, field in used_arg_fields.items():
            if not is_compatible_layout(field):
                raise ValueError(
                    f"The layout of the field {name} is not compatible with the backend."
                )

            if not is_compatible_type(field):
                raise ValueError(
                    f"Field '{name}' has type '{type(field)}', which is not compatible with the '{self.backend}' backend."
                )
            elif type(field) is np.ndarray:
                warnings.warn(
                    "NumPy ndarray passed as field. This is discouraged and only works with constraints and only for certain backends.",
                    RuntimeWarning,
                )
            if not field.dtype == self.field_info[name].dtype:
                raise TypeError(
                    f"The dtype of field '{name}' is '{field.dtype}' instead of '{self.field_info[name].dtype}'"
                )
            # ToDo: check if mask is correct: need mask info in stencil object.

            if isinstance(field, gt_storage.storage.Storage):
                if not field.is_stencil_view:
                    raise ValueError(
                        f"An incompatible view was passed for field {name} to the stencil. "
                    )
                for name_other, field_other in used_arg_fields.items():
                    # Plain arrays accepted above carry no mask information,
                    # so there is nothing to compare against for them.
                    if not hasattr(field_other, "mask"):
                        continue
                    if field_other.mask == field.mask:
                        if not field_other.shape == field.shape:
                            raise ValueError(
                                f"The fields {name} and {name_other} have the same mask but different shapes."
                            )

        # assert compatibility of parameters with stencil
        for name, parameter in used_arg_params.items():
            if not type(parameter) == self.parameter_info[name].dtype:
                raise TypeError(
                    f"The type of parameter '{name}' is '{type(parameter)}' instead of '{self.parameter_info[name].dtype}'"
                )

        # Shapes
        shapes = {
            name: Shape(field.shape)
            for name, field in used_arg_fields.items()
        }

        # Origins
        if origin is None:
            origin = {}
        else:
            origin = normalize_origin_mapping(origin)
        for name, field in used_arg_fields.items():
            origin.setdefault(
                name,
                origin["_all_"] if "_all_" in origin else field.default_origin)

        # Domain
        max_domain = Shape([sys.maxsize] * self.domain_info.ndims)
        for name, shape in shapes.items():
            upper_boundary = Index(
                self.field_info[name].boundary.upper_indices)
            max_domain &= shape - (Index(origin[name]) + upper_boundary)

        if domain is None:
            domain = max_domain
        else:
            domain = normalize_domain(domain)
            if len(domain) != self.domain_info.ndims:
                raise ValueError(f"Invalid 'domain' value '{domain}'")

        # check domain+halo vs field size
        if not domain > Shape.zeros(self.domain_info.ndims):
            raise ValueError(f"Compute domain contains zero sizes '{domain}'")

        if not domain <= max_domain:
            raise ValueError(
                f"Compute domain too large (provided: {domain}, maximum: {max_domain})"
            )
        for name, field in used_arg_fields.items():
            boundary = self.field_info[name].boundary
            min_origin = boundary.lower_indices
            # Compare axis by axis: ordering two plain tuples with `<` is
            # lexicographic and would overlook a too-small origin on a later
            # axis whenever an earlier axis is large enough.
            if any(value < lower
                   for value, lower in zip(origin[name], min_origin)):
                raise ValueError(
                    f"Origin for field {name} too small. Must be at least {min_origin}, is {origin[name]}"
                )
            min_shape = tuple(o + d + h for o, d, h in zip(
                origin[name], domain, boundary.upper_indices))
            # Same reasoning as above: every axis must be large enough.
            if any(required > actual
                   for required, actual in zip(min_shape, field.shape)):
                raise ValueError(
                    f"Shape of field {name} is {field.shape} but must be at least {min_shape} for given domain and origin."
                )

        self.run(**field_args,
                 **parameter_args,
                 _domain_=domain,
                 _origin_=origin,
                 exec_info=exec_info)
Exemplo n.º 2
0
    def _call_run(self,
                  field_args,
                  parameter_args,
                  domain,
                  origin,
                  *,
                  validate_args=True,
                  exec_info=None):
        """Normalize the call arguments and forward them to :meth:`run`.

        Entry point used by :class:`StencilObject` subclasses: it selects the
        arguments the stencil actually uses, completes the origin mapping,
        determines the compute domain, optionally validates everything and
        finally invokes :meth:`run`.

        Parameters
        ----------
            field_args: `dict`
                Mapping from field names to the data arrays passed by the caller.

            parameter_args: `dict`
                Mapping from parameter names to the values passed by the caller.

            domain : `Sequence` of `int`, optional
                Shape of the computation domain. If `None`, the domain returned
                by `_get_max_domain` for the used fields and origins is taken.

            origin :  `[int * ndims]` or {'field_name': [int * ndims]} , optional
                Origin specification; anything other than `None` is passed
                through `normalize_origin_mapping`. Fields without an explicit
                entry get the '_all_' entry filtered with the field mask when
                that key exists, and their own `default_origin` otherwise.

            validate_args : `bool`, optional
                When true (the default), `_validate_args` is called with the
                used fields, used parameters, domain and origin before running.

            exec_info : `dict`, optional
                If given, the entries `"call_run_start_time"` and
                `"call_run_end_time"` are stored in it (`None` by default).

        Returns
        -------
            `None`

        Raises
        -------
            ValueError
                If a used field or parameter is `None`, if the number of
                dimensions of a storage differs from the number of axes in the
                stencil signature, or if the domain cannot be deduced.
        """

        if exec_info is not None:
            exec_info["call_run_start_time"] = time.perf_counter()

        # Fields the stencil really uses (those with a non-None info entry).
        active_fields = {}
        for name, field in field_args.items():
            if self.field_info.get(name) is not None:
                active_fields[name] = field
        for name, info in self.field_info.items():
            if info is None:
                continue
            if active_fields[name] is None:
                raise ValueError(f"Field '{name}' is None.")

        # Same selection and check for the scalar parameters.
        active_params = {}
        for name, param in parameter_args.items():
            if self.parameter_info.get(name) is not None:
                active_params[name] = param
        for name, info in self.parameter_info.items():
            if info is None:
                continue
            if active_params[name] is None:
                raise ValueError(f"Parameter '{name}' is None.")

        # Origins: start from the normalized user mapping (or an empty one)
        # and fill in a fallback for every used field that has no entry yet.
        origin = {} if origin is None else normalize_origin_mapping(origin)

        for name, field in active_fields.items():
            if "_all_" not in origin:
                storage_ndim = len(field.shape)
                api_ndim = len(self.field_info[name].axes)
                if storage_ndim != api_ndim:
                    raise ValueError(
                        f"The storage for '{name}' has {storage_ndim} dimensions, but the API signature expects {api_ndim}"
                    )
                fallback_origin = gt_ir.Index(field.default_origin)
            else:
                mask = self._get_field_mask(name)
                fallback_origin = gt_ir.Index(
                    origin["_all_"].filter_mask(mask))
            origin.setdefault(name, fallback_origin)

        # Domain
        if domain is not None:
            domain = normalize_domain(domain)
        else:
            domain = self._get_max_domain(active_fields, origin)
            # An axis still at this sentinel value means no field bounded it.
            unbounded = np.iinfo(np.uintc).max
            if any(axis_bound == unbounded for axis_bound in domain):
                raise ValueError(
                    f"Compute domain could not be deduced. Specifiy the domain explicitly or ensure you reference at least one field."
                )

        if validate_args:
            self._validate_args(active_fields, active_params, domain, origin)

        # The complete, unfiltered argument mappings are handed to `run`.
        self.run(_domain_=domain,
                 _origin_=origin,
                 exec_info=exec_info,
                 **field_args,
                 **parameter_args)

        if exec_info is not None:
            exec_info["call_run_end_time"] = time.perf_counter()