Exemple #1
0
    def create_dataset(
        self,
        config: Optional[Dict[str, Any]] = None,
        inject_fake_data: bool = True,
        patch_checks: Optional[bool] = None,
        **kwargs: Any,
    ) -> Iterator[Tuple[torchvision.datasets.VisionDataset, Dict[str, Any]]]:
        r"""Build the dataset under test inside a temporary directory and yield it.

        The configuration handed to the dataset is assembled in layers. Later layers override earlier ones and
        empty layers are skipped:

        1. :attr:`~DatasetTestCase._KWARG_DEFAULTS` (lowest precedence).
        2. :attr:`~DatasetTestCase.DEFAULT_CONFIG`.
        3. ``config``.
        4. The non-special entries of ``kwargs`` (highest precedence).

        Special keyword arguments, as separated by ``_split_kwargs``, are not merged into the configuration; they
        are forwarded to the dataset constructor separately.

        Args:
            config (Optional[Dict[str, Any]]): Configuration that will be used to create the dataset.
            inject_fake_data (bool): If ``True`` (default) the fake data is injected through
                ``_inject_fake_data`` before the dataset is instantiated.
            patch_checks (Optional[bool]): If ``True`` the patchers returned by ``_patch_checks`` are applied while
                the dataset is created. If omitted it takes the same value as ``inject_fake_data``.
            **kwargs (Any): Additional parameters passed to the dataset. They win over ``config`` on overlap.

        Yields:
            dataset (torchvision.datasets.VisionDataset): Instance of ``DATASET_CLASS``.
            info (Dict[str, Any]): Whatever ``_inject_fake_data`` returned, or ``None`` when ``inject_fake_data``
                is ``False``.
        """
        check_patching = inject_fake_data if patch_checks is None else patch_checks

        special_kwargs, other_kwargs = self._split_kwargs(kwargs)

        # Layer the configuration from lowest to highest precedence; falsy layers contribute nothing.
        complete_config = self._KWARG_DEFAULTS.copy()
        for layer in (self.DEFAULT_CONFIG, config, other_kwargs):
            if layer:
                complete_config.update(layer)

        # A truthy 'download' special kwarg is forced to False before the dataset is constructed.
        download_requested = "download" in self._HAS_SPECIAL_KWARG and special_kwargs.get("download", False)
        if download_requested:
            special_kwargs["download"] = False

        patchers = self._patch_download_extract()
        if check_patching:
            patchers.update(self._patch_checks())

        with get_tmp_dir() as tmpdir:
            args = self.dataset_args(tmpdir, complete_config)

            info = None
            if inject_fake_data:
                info = self._inject_fake_data(tmpdir, complete_config)

            # Patches and console silencing only wrap the constructor call, not the consumer of the dataset.
            with self._maybe_apply_patches(patchers):
                with disable_console_output():
                    dataset = self.DATASET_CLASS(*args, **complete_config, **special_kwargs)

            # Yield inside the tmpdir context so the directory is still open while the caller uses the dataset.
            yield dataset, info
Exemple #2
0
    def create_dataset(
        self,
        config: Optional[Dict[str, Any]] = None,
        inject_fake_data: bool = True,
        patch_checks: Optional[bool] = None,
        **kwargs: Any,
    ) -> Iterator[Tuple[torchvision.datasets.VisionDataset, Dict[str, Any]]]:
        r"""Build the dataset under test inside a temporary directory and yield it.

        The effective configuration starts as a copy of ``_DEFAULT_CONFIG``, is overridden by ``config`` and
        finally by the non-special entries of ``kwargs``. The dictionary passed as ``config`` is not modified.

        Args:
            config (Optional[Dict[str, Any]]): Configuration that will be used to create the dataset. If omitted,
                only the default configuration (plus ``kwargs``) is used.
            inject_fake_data (bool): If ``True`` (default) the fake data is injected through
                ``_inject_fake_data`` before the dataset is instantiated.
            patch_checks (Optional[bool]): If ``True`` the patchers returned by ``_patch_checks`` are applied while
                the dataset is created. If omitted it takes the same value as ``inject_fake_data``.
            **kwargs (Any): Additional parameters passed to the dataset. They win over ``config`` on overlap.

        Yields:
            dataset (torchvision.datasets.VisionDataset): Instance of ``DATASET_CLASS``.
            info (Dict[str, Any]): Whatever ``_inject_fake_data`` returned, or ``None`` when ``inject_fake_data``
                is ``False``.
        """
        effective_config = self._DEFAULT_CONFIG.copy()
        if config is not None:
            effective_config.update(config)

        check_patching = inject_fake_data if patch_checks is None else patch_checks

        special_kwargs, extra_config = self._split_kwargs(kwargs)
        # Whenever the dataset has a 'download' special kwarg it is unconditionally switched off.
        if "download" in self._HAS_SPECIAL_KWARG:
            special_kwargs["download"] = False
        effective_config.update(extra_config)

        patchers = self._patch_download_extract()
        if check_patching:
            patchers.update(self._patch_checks())

        with get_tmp_dir() as tmpdir:
            args = self.dataset_args(tmpdir, effective_config)

            if inject_fake_data:
                info = self._inject_fake_data(tmpdir, effective_config)
            else:
                info = None

            # Patches and console silencing only wrap the constructor call, not the consumer of the dataset.
            with self._maybe_apply_patches(patchers):
                with disable_console_output():
                    dataset = self.DATASET_CLASS(*args, **effective_config, **special_kwargs)

            # Yield inside the tmpdir context so the directory is still open while the caller uses the dataset.
            yield dataset, info
Exemple #3
0
    def create_dataset(
        self,
        config: Optional[Dict[str, Any]] = None,
        inject_fake_data: bool = True,
        disable_download_extract: Optional[bool] = None,
        **kwargs: Any,
    ) -> Iterator[Tuple[torchvision.datasets.VisionDataset, Dict[str, Any]]]:
        r"""Create the dataset in a temporary directory.

        Args:
            config (Optional[Dict[str, Any]]): Configuration that will be used to create the dataset. If omitted, a
                copy of the first entry of ``CONFIGS`` is used. The passed dictionary is never modified.
            inject_fake_data (bool): If ``True`` (default) inject the fake data with :meth:`.inject_fake_data` before
                creating the dataset. If ``False`` nothing is injected, the dataset is created with the temporary
                directory as its only positional argument and ``info`` is ``None``.
            disable_download_extract (Optional[bool]): If ``True`` disable download and extract logic while creating
                the dataset. If ``None`` (default) this takes the same value as ``inject_fake_data``.
            **kwargs (Any): Additional parameters passed to the dataset. These parameters take precedence in case they
                overlap with ``config``.

        Yields:
            dataset (torchvision.datasets.VisionDataset): Dataset.
            info (Dict[str, Any]): Additional information about the injected fake data, always containing a
                ``"num_examples"`` entry. ``None`` if ``inject_fake_data`` is ``False``. See
                :meth:`.inject_fake_data` for details.

        Raises:
            UsageError: If fake data was injected and :meth:`.inject_fake_data` returned ``None``, returned a
                dictionary without a ``"num_examples"`` field, or returned information that is neither an integer
                nor a dictionary.
        """
        # Always work on a private copy so that merging ``kwargs`` below cannot leak into the caller's dictionary
        # (or into the class-level ``CONFIGS`` entry).
        config = self.CONFIGS[0].copy() if config is None else dict(config)

        special_kwargs, other_kwargs = self._split_kwargs(kwargs)
        config.update(other_kwargs)

        if disable_download_extract is None:
            disable_download_extract = inject_fake_data

        with get_tmp_dir() as tmpdir:
            if inject_fake_data:
                output = self.inject_fake_data(tmpdir, config)
                if output is None:
                    raise UsageError(
                        "The method 'inject_fake_data' needs to return at least an integer indicating the number of "
                        "examples for the current configuration.")

                # A two-element sequence is interpreted as ``(args, info)``; anything else is the info alone and
                # the dataset is rooted at the temporary directory.
                if isinstance(output,
                              collections.abc.Sequence) and len(output) == 2:
                    args, info = output
                else:
                    args = (tmpdir, )
                    info = output

                # Normalize the info to a dictionary that carries the number of examples.
                if isinstance(info, int):
                    info = dict(num_examples=info)
                elif isinstance(info, dict):
                    if "num_examples" not in info:
                        raise UsageError(
                            "The information dictionary returned by the method 'inject_fake_data' must contain a "
                            "'num_examples' field that holds the number of examples for the current configuration."
                        )
                else:
                    raise UsageError(
                        f"The additional information returned by the method 'inject_fake_data' must be either an integer "
                        f"indicating the number of examples for the current configuration or a dictionary with the the "
                        f"same content. Got {type(info)} instead.")
            else:
                # Nothing was injected, so there is no output to validate. Previously this path always raised a
                # misleading UsageError about the return value of 'inject_fake_data'.
                args = (tmpdir, )
                info = None

            cm = self._disable_download_extract if disable_download_extract else nullcontext
            with cm(special_kwargs), disable_console_output():
                dataset = self.DATASET_CLASS(*args, **config, **special_kwargs)

            # Yield inside the tmpdir context so the directory is still open while the caller uses the dataset.
            yield dataset, info