def __init__(self, root: Path, languages: Iterable[str], balanced=True, split='train'):
    """Build one CommonVoice dataset per language and hand them to the parent class.

    Args:
        root: Directory passed as ``root`` to every ``COMMONVOICE`` instance.
        languages: Language identifiers; each must be a key of
            ``self.supported_languages``. Any iterable is accepted: it is
            materialised into a list, so one-shot iterators such as
            generators work too.
        balanced: When true, every per-language dataset is truncated to the
            length of the shortest one, so each language contributes the same
            number of items.
        split: One of ``'train'``, ``'test'`` or ``'validated'``. The ``.tsv``
            suffix is appended to select the tsv file handed to ``COMMONVOICE``.

    Raises:
        AssertionError: if ``split`` or a language is not supported.
        FileNotFoundError: if ``COMMONVOICE`` cannot find a language's files,
            which presumably means the data was not downloaded first.
    """
    assert split in ('train', 'test', 'validated')
    split += '.tsv'
    self.split = split
    self.root = root
    # Materialise once. The languages are traversed more than once below, and a
    # one-shot iterator (e.g. a generator) would be exhausted by the validation
    # loop, leaving no datasets to build.
    languages = list(languages)
    self.languages = languages
    # NOTE(review): not used in this method; its meaning is not visible here.
    self._item_length = 200000
    for language in languages:
        assert language in self.supported_languages, (
            f"Got {language}, options are {self.supported_languages.keys()}"
        )

    datasets = []
    for language in languages:
        try:
            datasets.append(
                COMMONVOICE(root=root, tsv=split, url=language, download=False, version=self.version)
            )
        except FileNotFoundError as err:
            # With download=False, missing data surfaces as FileNotFoundError.
            # Re-raise the same type with a message that says what to do.
            raise FileNotFoundError(
                f"CommonVoice data for language {language!r} (split {split!r}) "
                f"was not found under {root}; download the dataset first."
            ) from err

    if balanced:
        # Truncate every language to the shortest one so each contributes equally.
        min_length = min(len(dataset) for dataset in datasets)
        datasets = [
            Subset(dataset, list(range(min_length)))
            for dataset in datasets
        ]
    # Explicit two-argument super() form kept from the original.
    super(CommonVoiceDataset, self).__init__(datasets)
def download(cls, root, languages):
    """Download CommonVoice data for each language whose directory does not exist yet.

    For every language, ``cls._get_language_dir(root, language)`` gives the
    target directory. Existing directories are skipped. Otherwise the directory
    is created, and a ``COMMONVOICE(..., download=True)`` object is built only
    for its download side effect; the object itself is discarded.
    """
    for lang in languages:
        target_dir = cls._get_language_dir(root, lang)
        if target_dir.exists():
            continue  # already present; nothing to fetch

        target_dir.mkdir(parents=True, exist_ok=True)
        try:
            print('downloaden')
            COMMONVOICE(
                root=target_dir,
                tsv='',
                url=lang,
                download=True,
                version=cls.version,
            )
        except FileNotFoundError:
            # NOTE(review): ignored on purpose. Presumably raised when COMMONVOICE
            # tries to read the empty tsv path after downloading. Confirm that a
            # failed download cannot also surface here, or the empty directory
            # left behind will make later calls skip this language.
            pass
def load_dataset(fold: str, commonvoice_root: Union[str, Path], commonvoice_version: str, lang: str = 'fr') -> torch.utils.data.Dataset:
    """
    Load one CommonVoice fold from commonvoice_root/commonvoice_version/lang.

    That folder is expected to contain the CommonVoice tsv files
    (for example ``train.tsv``, ``dev.tsv``, ``test.tsv``).

    Args:
        fold (str): the fold to load, e.g. train, dev, test, validated, ...
            A trailing ``.tsv`` is accepted and is not appended a second time.
        commonvoice_root (str or Path): root directory holding the CommonVoice
            versions.
        commonvoice_version (str): sub-directory of ``commonvoice_root`` for the
            CommonVoice release.
        lang (str): sub-directory of the version folder for the language.
            Defaults to ``'fr'``.

    Returns:
        torch.utils.data.Dataset: the ``COMMONVOICE`` dataset built from the
        selected tsv file.
    """
    datasetpath = os.path.join(commonvoice_root, commonvoice_version, lang)
    # Accept both "train" and "train.tsv"; appending the suffix unconditionally
    # would look for "train.tsv.tsv".
    tsv = fold if fold.endswith(".tsv") else fold + ".tsv"
    return COMMONVOICE(root=datasetpath, tsv=tsv)
def test_commonvoice_path(self):
    """Build COMMONVOICE from a ``pathlib.Path`` root and run the shared checks on it."""
    root_as_path = Path(self.root_dir)
    built = COMMONVOICE(root_as_path)
    self._test_commonvoice(built)
def test_commonvoice_str(self):
    """Build COMMONVOICE from ``self.root_dir`` unchanged and run the shared checks.

    This is the counterpart of ``test_commonvoice_path``, which wraps the same
    root in ``Path`` first.
    """
    raw_root = self.root_dir
    built = COMMONVOICE(raw_root)
    self._test_commonvoice(built)