def create_dataset(
    self,
    config: Optional[Dict[str, Any]] = None,
    inject_fake_data: bool = True,
    patch_checks: Optional[bool] = None,
    **kwargs: Any,
) -> Iterator[Tuple[torchvision.datasets.VisionDataset, Dict[str, Any]]]:
    r"""Build the dataset inside a temporary directory and yield it.

    The configuration handed to the dataset is assembled in layers. Listed from highest to lowest precedence:

    1. Non-special parameters in :attr:`kwargs`.
    2. Configuration in :attr:`config`.
    3. Configuration in :attr:`~DatasetTestCase.DEFAULT_CONFIG`.
    4. The entries of ``_KWARG_DEFAULTS``.

    Empty or missing layers are skipped.

    Args:
        config (Optional[Dict[str, Any]]): Configuration used to create the dataset.
        inject_fake_data (bool): If ``True`` (default) call :meth:`._inject_fake_data` before the dataset is
            instantiated.
        patch_checks (Optional[bool]): If ``True`` additionally apply the patchers returned by
            ``_patch_checks`` (integrity check logic) while the dataset is instantiated. If omitted, the value of
            ``inject_fake_data`` is used.
        **kwargs (Any): Additional parameters passed to the dataset. They override overlapping entries of
            ``config``.

    Yields:
        dataset (torchvision.datasets.VisionDataset): Dataset.
        info (Dict[str, Any]): Whatever :meth:`._inject_fake_data` returned, or ``None`` if
            ``inject_fake_data`` is ``False``.
    """
    special_kwargs, other_kwargs = self._split_kwargs(kwargs)

    # Start from the lowest-precedence layer and let each later layer overwrite the earlier ones.
    merged_config = self._KWARG_DEFAULTS.copy()
    for layer in (self.DEFAULT_CONFIG, config, other_kwargs):
        if layer:
            merged_config.update(layer)

    if "download" in self._HAS_SPECIAL_KWARG:
        # A truthy 'download' (presumably the dataset's own default, as resolved by `_split_kwargs`) is
        # forced off.
        if special_kwargs.get("download", False):
            special_kwargs["download"] = False

    apply_check_patches = inject_fake_data if patch_checks is None else patch_checks

    patchers = self._patch_download_extract()
    if apply_check_patches:
        patchers.update(self._patch_checks())

    with get_tmp_dir() as tmpdir:
        args = self.dataset_args(tmpdir, merged_config)

        if inject_fake_data:
            info = self._inject_fake_data(tmpdir, merged_config)
        else:
            info = None

        # The patches and the silenced console only cover the instantiation itself.
        with self._maybe_apply_patches(patchers):
            with disable_console_output():
                dataset = self.DATASET_CLASS(*args, **merged_config, **special_kwargs)

        # Yield while the temporary directory is still alive.
        yield dataset, info
def create_dataset(
    self,
    config: Optional[Dict[str, Any]] = None,
    inject_fake_data: bool = True,
    patch_checks: Optional[bool] = None,
    **kwargs: Any,
) -> Iterator[Tuple[torchvision.datasets.VisionDataset, Dict[str, Any]]]:
    r"""Build the dataset inside a temporary directory and yield it.

    The effective configuration is a copy of ``_DEFAULT_CONFIG``, overwritten first by ``config`` and then by
    the non-special entries of ``kwargs``.

    Args:
        config (Optional[Dict[str, Any]]): Configuration used to create the dataset. If omitted, only the
            default configuration is used.
        inject_fake_data (bool): If ``True`` (default) call :meth:`._inject_fake_data` before the dataset is
            instantiated.
        patch_checks (Optional[bool]): If ``True`` additionally apply the patchers returned by
            ``_patch_checks`` (integrity check logic) while the dataset is instantiated. If omitted, the value of
            ``inject_fake_data`` is used.
        **kwargs (Any): Additional parameters passed to the dataset. They override overlapping entries of
            ``config``.

    Yields:
        dataset (torchvision.datasets.VisionDataset): Dataset.
        info (Dict[str, Any]): Whatever :meth:`._inject_fake_data` returned, or ``None`` if
            ``inject_fake_data`` is ``False``.
    """
    # Work on a private copy so that neither the defaults nor the caller's dict are modified.
    effective_config = self._DEFAULT_CONFIG.copy()
    if config is not None:
        effective_config.update(config)

    apply_check_patches = inject_fake_data if patch_checks is None else patch_checks

    special_kwargs, other_kwargs = self._split_kwargs(kwargs)
    if "download" in self._HAS_SPECIAL_KWARG:
        # Whenever the dataset has a 'download' parameter it is switched off unconditionally.
        special_kwargs["download"] = False
    effective_config.update(other_kwargs)

    patchers = self._patch_download_extract()
    if apply_check_patches:
        patchers.update(self._patch_checks())

    with get_tmp_dir() as tmpdir:
        args = self.dataset_args(tmpdir, effective_config)

        if inject_fake_data:
            info = self._inject_fake_data(tmpdir, effective_config)
        else:
            info = None

        # The patches and the silenced console only cover the instantiation itself.
        with self._maybe_apply_patches(patchers):
            with disable_console_output():
                dataset = self.DATASET_CLASS(*args, **effective_config, **special_kwargs)

        # Yield while the temporary directory is still alive.
        yield dataset, info
def create_dataset(
    self,
    config: Optional[Dict[str, Any]] = None,
    inject_fake_data: bool = True,
    disable_download_extract: Optional[bool] = None,
    **kwargs: Any,
) -> Iterator[Tuple[torchvision.datasets.VisionDataset, Dict[str, Any]]]:
    r"""Create the dataset in a temporary directory.

    Args:
        config (Optional[Dict[str, Any]]): Configuration that will be used to create the dataset. If omitted,
            a copy of the first entry of ``CONFIGS`` is used. The passed dictionary is never modified.
        inject_fake_data (bool): If ``True`` (default) inject the fake data with :meth:`.inject_fake_data` before
            creating the dataset. If ``False`` the dataset is created on the empty temporary directory.
        disable_download_extract (Optional[bool]): If ``True`` disable download and extract logic while creating
            the dataset. If ``None`` (default) this takes the same value as ``inject_fake_data``.
        **kwargs (Any): Additional parameters passed to the dataset. These parameters take precedence in case they
            overlap with ``config``.

    Yields:
        dataset (torchvision.datasets.VisionDataset): Dataset.
        info (Dict[str, Any]): Additional information about the injected fake data. Always contains the key
            ``"num_examples"``. ``None`` if ``inject_fake_data`` is ``False``. See :meth:`.inject_fake_data` for
            details.

    Raises:
        UsageError: If :meth:`.inject_fake_data` was called and returned ``None``, a dictionary without a
            ``"num_examples"`` key, or additional information that is neither an integer nor a dictionary.
    """
    # Always work on a private copy: merging ``kwargs`` below must not leak into the caller's dictionary.
    config = self.CONFIGS[0].copy() if config is None else dict(config)

    special_kwargs, other_kwargs = self._split_kwargs(kwargs)
    config.update(other_kwargs)

    if disable_download_extract is None:
        disable_download_extract = inject_fake_data

    with get_tmp_dir() as tmpdir:
        if inject_fake_data:
            output = self.inject_fake_data(tmpdir, config)
            if output is None:
                raise UsageError(
                    "The method 'inject_fake_data' needs to return at least an integer indicating the number of "
                    "examples for the current configuration."
                )

            # A two-element sequence is interpreted as ``(args, info)``; anything else is the info alone and
            # the dataset is created with the temporary directory as its only positional argument.
            if isinstance(output, collections.abc.Sequence) and len(output) == 2:
                args, info = output
            else:
                args, info = (tmpdir,), output

            if isinstance(info, int):
                info = dict(num_examples=info)
            elif isinstance(info, dict):
                if "num_examples" not in info:
                    raise UsageError(
                        "The information dictionary returned by the method 'inject_fake_data' must contain a "
                        "'num_examples' field that holds the number of examples for the current configuration."
                    )
            else:
                raise UsageError(
                    f"The additional information returned by the method 'inject_fake_data' must be either an integer "
                    f"indicating the number of examples for the current configuration or a dictionary with the the "
                    f"same content. Got {type(info)} instead."
                )
        else:
            # Nothing was injected, so there is no return value to validate and no info to report.
            args, info = (tmpdir,), None

        # ``nullcontext`` accepts (and ignores) the argument, so both branches can be called the same way.
        cm = self._disable_download_extract if disable_download_extract else nullcontext
        with cm(special_kwargs), disable_console_output():
            dataset = self.DATASET_CLASS(*args, **config, **special_kwargs)

        # Yield while the temporary directory is still alive.
        yield dataset, info