コード例 #1
0
ファイル: dataloader.py プロジェクト: robogast/DL4NLP
    def __init__(self, root: Path, languages: Iterable[str], balanced=True, split='train'):
        """Create one ``COMMONVOICE`` dataset per language and combine them.

        Args:
            root: Directory handed as ``root`` to every ``COMMONVOICE`` dataset.
            languages: Languages to load. Each one must be contained in
                ``self.supported_languages``. Any iterable is accepted,
                including one-shot iterators such as generators.
            balanced: When true, every per-language dataset is truncated to
                the length of the shortest one.
            split: One of ``'train'``, ``'test'`` or ``'validated'``; the
                file ``<split>.tsv`` is the one that gets loaded.

        Raises:
            AssertionError: If ``split`` or one of ``languages`` is not
                supported.
        """
        assert split in ('train', 'test', 'validated')
        split += '.tsv'
        self.split = split

        # ``languages`` is walked twice below (validation, then dataset
        # construction). Materialize it once so that a generator is not
        # exhausted by the first pass, which would silently yield no datasets.
        languages = list(languages)

        self.root = root
        self.languages = languages

        # NOTE(review): meaning of this constant is not visible here -- confirm
        self._item_length = 200000

        for language in languages:
            assert language in self.supported_languages, (
                f"Got {language}, options are {self.supported_languages.keys()}"
            )

        # FIXME: add appropriate message if dataset isn't downloaded first
        datasets = [
            COMMONVOICE(root=root,
            tsv=split, url=language, download=False, version=self.version)
            for language in languages
        ]
        if balanced:
            # Keep only the first ``min_length`` items of every dataset so all
            # languages contribute the same number of samples.
            min_length = min(len(dataset) for dataset in datasets)
            datasets = [
                Subset(dataset, list(range(min_length)))
                for dataset in datasets
            ]
        # The parent initializer receives the list of per-language datasets.
        super(CommonVoiceDataset, self).__init__(datasets)
コード例 #2
0
ファイル: dataloader.py プロジェクト: robogast/DL4NLP
 def download(cls, root, languages):
     """Download CommonVoice data for every language without a directory yet.

     For each language the target directory is obtained from
     ``cls._get_language_dir(root, language)``. Languages whose directory
     already exists are skipped; otherwise the directory is created and a
     ``COMMONVOICE`` dataset is instantiated with ``download=True``.

     Args:
         root: Base location forwarded to ``cls._get_language_dir``.
         languages: Iterable of languages to download.
     """
     for lang in languages:
         target_dir = cls._get_language_dir(root, lang)
         if target_dir.exists():
             # Directory already present: nothing to do for this language.
             continue

         target_dir.mkdir(parents=True, exist_ok=True)
         try:
             print('downloaden')
             COMMONVOICE(
                 root=target_dir,
                 tsv='',
                 url=lang,
                 download=True,
                 version=cls.version,
             )
         except FileNotFoundError:
             # NOTE(review): presumably raised because ``tsv=''`` names no
             # real file once the download is done -- confirm. Deliberately
             # ignored.
             pass
コード例 #3
0
def load_dataset(fold: str,
                 commonvoice_root: Union[str, Path],
                 commonvoice_version: str,
                 lang: str = 'fr') -> torch.utils.data.Dataset:
    """
    Build a CommonVoice dataset for one fold of one language.

    The dataset root is the directory
    ``commonvoice_root/commonvoice_version/lang``, which is expected to hold
    the CommonVoice tsv files.

    Args:
        fold (str): the fold to load, e.g. train, dev, test, validated, ..;
            ``".tsv"`` is appended to obtain the tsv file name
        commonvoice_root (Union[str, Path]): base directory of the
            CommonVoice data
        commonvoice_version (str): name of the version sub-directory below
            ``commonvoice_root``
        lang (str): name of the language sub-directory below the version
            directory (default ``'fr'``)

    Returns:
        torch.utils.data.Dataset: the ``COMMONVOICE`` dataset created for
        that directory and tsv file
    """
    language_dir = os.path.join(commonvoice_root, commonvoice_version, lang)
    tsv_name = fold + ".tsv"
    return COMMONVOICE(root=language_dir, tsv=tsv_name)
コード例 #4
0
ファイル: commonvoice_test.py プロジェクト: fagan2888/audio
 def test_commonvoice_path(self):
     """Run the shared CommonVoice checks with the root given as a ``Path``."""
     root_as_path = Path(self.root_dir)
     cv_dataset = COMMONVOICE(root_as_path)
     self._test_commonvoice(cv_dataset)
コード例 #5
0
ファイル: commonvoice_test.py プロジェクト: fagan2888/audio
 def test_commonvoice_str(self):
     """Run the shared CommonVoice checks with ``self.root_dir`` passed as is."""
     root_dir = self.root_dir
     cv_dataset = COMMONVOICE(root_dir)
     self._test_commonvoice(cv_dataset)