Example #1
0
    def _get_max_domain(
        self,
        field_args: Dict[str, Any],
        origin: Dict[str, Tuple[int, ...]],
        *,
        squeeze: bool = True,
    ) -> Shape:
        """Return the maximum domain size possible.

        Every entry of ``self.field_info`` that is not ``None`` constrains the
        result: along each of the field's domain axes, the usable extent is the
        array length minus the field origin and the upper boundary indices.

        Parameters
        ----------
            field_args:
                Mapping from field names to actually passed data arrays.
            origin:
                The origin for each field.
            squeeze:
                Convert non-used domain dimensions to singleton dimensions.

        Returns
        -------
            `Shape`: the maximum domain size.

        Raises
        ------
            AssertionError:
                If a constrained field has no (or a ``None``) value in
                ``field_args``, or if a passed ``Storage`` has a domain mask that
                disagrees with the signature (only when assertions are enabled).
        """
        unbounded = sys.maxsize
        ndim = self.domain_info.ndim
        # Start from an "unbounded" domain; each constraining field can only shrink it.
        max_domain = Shape([unbounded] * ndim)

        for name, info in self.field_info.items():
            if info is None:
                # Entries without field info are skipped and do not constrain the domain.
                continue

            assert field_args.get(name, None) is not None, f"Invalid value for '{name}' field."
            array = field_args[name]

            domain_mask = info.domain_mask
            domain_rank = info.domain_ndim
            # A Storage carries its own mask; its first ``ndim`` entries must match
            # the mask declared in the stencil signature.
            assert not isinstance(array, gt_storage.storage.Storage) or (
                tuple(array.mask)[:ndim] == domain_mask
            ), (
                f"Storage for '{name}' has domain mask '{array.mask}' but the API signature "
                f"expects '[{', '.join(info.axes)}]'"
            )

            # Upper boundary indices restricted to the field's masked axes
            # (presumably what ``filter_mask`` does — confirm in ``Index``).
            upper = info.boundary.upper_indices.filter_mask(domain_mask)
            start = Index.from_value(origin[name])
            available = tuple(
                array.shape[axis] - (start[axis] + upper[axis]) for axis in range(domain_rank)
            )
            # ``default=unbounded`` presumably fills the axes outside the field's mask
            # so they place no limit; ``&=`` presumably keeps the element-wise
            # minimum (NOTE(review): confirm both in ``Shape``).
            max_domain &= Shape.from_mask(available, domain_mask, default=unbounded)

        if not squeeze:
            return max_domain
        # Axes that no field constrained are still ``unbounded``; collapse them to 1.
        return Shape([1 if extent == unbounded else extent for extent in max_domain])
Example #2
0
    def _make_origin_dict(origin: Any) -> Dict[str, Index]:
        try:
            if isinstance(origin, dict):
                return dict(origin)
            if origin is None:
                return {}
            if isinstance(origin, collections.abc.Iterable):
                return {"_all_": Index.from_value(origin)}
            if isinstance(origin, int):
                return {"_all_": Index.from_k(origin)}
        except Exception:
            pass

        raise ValueError("Invalid 'origin' value ({})".format(origin))
Example #3
0
    def _get_max_domain(
        self,
        field_args: Dict[str, Any],
        origin: Dict[str, Tuple[int, ...]],
        *,
        squeeze: bool = True,
    ) -> Shape:
        """Return the maximum domain size possible.

        Every field whose ``access`` is not ``AccessKind.NONE`` constrains the
        result: along each of the field's domain axes, the usable extent is the
        array length minus the field origin and the upper boundary indices.

        Parameters
        ----------
            field_args:
                Mapping from field names to actually passed data arrays.
            origin:
                The origin for each field.
            squeeze:
                Convert non-used domain dimensions to singleton dimensions.

        Returns
        -------
            `Shape`: the maximum domain size.

        Raises
        ------
            AssertionError:
                If a constraining field has no (or a ``None``) value in
                ``field_args`` (only when assertions are enabled).
        """
        unbounded = sys.maxsize
        # Start from an "unbounded" domain; each constraining field can only shrink it.
        max_domain = Shape([unbounded] * self.domain_info.ndim)

        for name, info in self.field_info.items():
            if info.access == AccessKind.NONE:
                # Fields with ``AccessKind.NONE`` are skipped and place no limit.
                continue

            assert field_args.get(name, None) is not None, f"Invalid value for '{name}' field."
            array = field_args[name]

            domain_mask = info.domain_mask
            domain_rank = info.domain_ndim
            # Upper boundary indices restricted to the field's masked axes
            # (presumably what ``filter_mask`` does — confirm in ``Index``).
            upper = info.boundary.upper_indices.filter_mask(domain_mask)
            start = Index.from_value(origin[name])
            available = tuple(
                array.shape[axis] - (start[axis] + upper[axis]) for axis in range(domain_rank)
            )
            # ``default=unbounded`` presumably fills the axes outside the field's mask
            # so they place no limit; ``&=`` presumably keeps the element-wise
            # minimum (NOTE(review): confirm both in ``Shape``).
            max_domain &= Shape.from_mask(available, domain_mask, default=unbounded)

        if not squeeze:
            return max_domain
        # Axes that no field constrained are still ``unbounded``; collapse them to 1.
        return Shape([1 if extent == unbounded else extent for extent in max_domain])