Exemple #1
0
 def from_tsv(
     cls,
     root: str,
     data_cfg: S2SDataConfig,
     splits: str,
     is_train_split: bool,
     epoch: int,
     seed: int,
     target_is_code: bool = False,
     target_dictionary: Dictionary = None,
     n_frames_per_step: int = 1,
     multitask: Optional[Dict] = None,
 ) -> SpeechToSpeechDataset:
     """Build one dataset per comma-separated split name and merge them.

     Each name in ``splits`` is resolved to a TSV manifest under ``root``
     and turned into a dataset via ``cls._from_list``; a single requested
     split is returned as-is, several are wrapped in a ``ConcatDataset``.
     """
     datasets = [
         cls._from_list(
             name,
             is_train_split,
             SpeechToTextDatasetCreator._load_samples_from_tsv(root, name),
             data_cfg,
             target_is_code,
             target_dictionary,
             n_frames_per_step,
             multitask,
         )
         for name in splits.split(",")
     ]
     if len(datasets) == 1:
         return datasets[0]
     return ConcatDataset(datasets)
Exemple #2
0
    def load_dataset(self, split, epoch=1, combine=False, **kwargs):
        """Create the dataset for ``split`` and register it in ``self.datasets``.

        Training splits use the audio-dictionary creator when the
        ``use_AudioDictDataset`` flag is set; every other case goes through
        the plain speech-to-text creator.
        """
        train = split.startswith("train")
        tokenizer = self.build_tokenizer(self.args)
        bpe = self.build_bpe(self.args)

        shared = dict(
            is_train_split=train,
            epoch=epoch,
            seed=self.args.seed,
        )
        if train and self.args.use_AudioDictDataset:
            creator = AudioDictDatasetCreator
            shared["audio_dict"] = self.audio_dict
        else:
            creator = SpeechToTextDatasetCreator

        self.datasets[split] = creator.from_tsv(
            self.args.data,
            self.data_cfg,
            split,
            self.tgt_dict,
            tokenizer,
            bpe,
            **shared,
        )
Exemple #3
0
 def load_langpair_dataset(self,
                           prepend_tgt_lang_tag=False,
                           sampling_alpha=1.0,
                           epoch=0):
     """Load parallel-text training data for every configured language pair.

     Returns the single dataset when one pair is configured; otherwise a
     concatenation of all pairs, optionally temperature-resampled when
     ``sampling_alpha`` differs from 1.0.
     """
     split = "train"
     datasets = []
     last = None
     for pair in self.args.langpairs.split(","):
         src, tgt = pair.split("-")
         last = load_langpair_dataset(
             self.args.parallel_text_data,
             split,
             src,
             self.src_dict,
             tgt,
             self.tgt_dict,
             combine=True,
             dataset_impl=None,
             upsample_primary=1,
             left_pad_source=False,
             left_pad_target=False,
             max_source_positions=self.args.max_positions_text,
             max_target_positions=self.args.max_target_positions,
             load_alignments=False,
             truncate_source=False,
         )
         if prepend_tgt_lang_tag:
             # TODO
             # 'prev_output_tokens' starts with eos; swap it for the
             # target-language tag token.
             last = TransformEosLangPairDataset(
                 last,
                 src_eos=self.src_dict.eos(),
                 tgt_bos=self.tgt_dict.eos(),
                 new_tgt_bos=self.tgt_dict.index(
                     LANG_TAG_TEMPLATE.format(tgt)),
             )
         datasets.append(last)

     if len(datasets) <= 1:
         return last
     if sampling_alpha != 1.0:
         ratios = SpeechToTextDatasetCreator.get_size_ratios(
             self.args.langpairs.split(","),
             [len(d) for d in datasets],
             alpha=sampling_alpha,
         )
         datasets = [
             ResamplingDataset(d,
                               size_ratio=r,
                               epoch=epoch,
                               replace=(r >= 1.0))
             for d, r in zip(datasets, ratios)
         ]
     return ConcatDataset(datasets)
Exemple #4
0
 def load_dataset(self, split, epoch=1, combine=False, **kwargs):
     """Build the speech-to-text dataset for ``split`` and store it."""
     # Tokenizers are rebuilt from the task args on every load.
     pre_tok = self.build_tokenizer(self.args)
     bpe_tok = self.build_bpe(self.args)
     dataset = SpeechToTextDatasetCreator.from_tsv(
         self.args.data,
         self.data_cfg,
         split,
         self.tgt_dict,
         pre_tok,
         bpe_tok,
         is_train_split=split.startswith('train'),
         epoch=epoch,
         seed=self.args.seed,
     )
     self.datasets[split] = dataset
    def __init__(
        self,
        split: str,
        is_train_split: bool,
        data_cfg: S2SDataConfig,
        src_audio_paths: List[str],
        src_n_frames: List[int],
        tgt_audio_paths: List[str],
        tgt_n_frames: List[int],
        src_langs: Optional[List[str]] = None,
        tgt_langs: Optional[List[str]] = None,
        ids: Optional[List[str]] = None,
        target_is_code: bool = False,
        tgt_dict: Dictionary = None,
        n_frames_per_step: int = 1,
    ):
        """Speech-to-speech dataset over parallel source/target audio lists.

        All per-sample lists must be index-aligned and of equal length.
        When ``target_is_code`` is set, ``tgt_audio_paths`` is reused as the
        target text field (presumably discrete unit/code strings rather than
        file paths — confirm against the TSV producer) and ``tgt_dict`` is
        required to encode it.
        """
        # Targets double as text only in the discrete-code setting; for
        # waveform targets the parent gets no tgt_texts.
        tgt_texts = tgt_audio_paths if target_is_code else None
        super().__init__(
            split,
            is_train_split,
            data_cfg,
            src_audio_paths,
            src_n_frames,
            ids=ids,
            tgt_dict=tgt_dict,
            tgt_texts=tgt_texts,
            src_langs=src_langs,
            tgt_langs=tgt_langs,
            n_frames_per_step=n_frames_per_step,
        )

        self.tgt_audio_paths = tgt_audio_paths
        # Target lengths in reduced frames: raw frame counts divided by the
        # per-step frame count (floor division drops any remainder).
        self.tgt_lens = [t // self.n_frames_per_step for t in tgt_n_frames]

        # Encoding code targets requires a dictionary.
        assert not target_is_code or tgt_dict is not None
        self.target_is_code = target_is_code

        # self.n_samples is established by the parent __init__ above.
        assert len(tgt_audio_paths) == self.n_samples
        assert len(tgt_n_frames) == self.n_samples

        self.tgt_speakers = None
        if self.cfg.target_speaker_embed:
            # Speaker embeddings come from a separate TSV manifest keyed by
            # sample id; every id in this split must be present there.
            samples = SpeechToTextDatasetCreator._load_samples_from_tsv(
                self.cfg.target_speaker_embed, split)
            spk_emb_dict = {s["id"]: s["speaker_embed"] for s in samples}
            self.tgt_speakers = [spk_emb_dict[id] for id in self.ids]
            assert len(self.tgt_speakers) == self.n_samples

        logger.info(self.__repr__())
    def _load_samples_from_tsv(cls, root: str, split: str, src_lang_map,
                               tgt_lang_map, domain_map):
        """Load samples for ``split`` and tag each with language/domain ids.

        The split name encodes its metadata as
        ``<prefix>_<src_lang>_<tgt_lang>_<domain>``; the three trailing
        fields are mapped to integer ids via the supplied dictionaries and
        attached to every sample row.
        """
        _, src_lang, tgt_lang, domain = split.split("_")
        extra_fields = {
            cls.KEY_SRC_LANG_ID: src_lang_map[src_lang],
            cls.KEY_TGT_LANG_ID: tgt_lang_map[tgt_lang],
            cls.KEY_DOMAIN_ID: domain_map[domain],
        }

        samples = SpeechToTextDatasetCreator._load_samples_from_tsv(
            root, split)
        for sample in samples:
            sample.update(extra_fields)
        return samples
Exemple #7
0
 def __init__(self, args, split, tgt_dict):
     """Index target texts by utterance id for ``split`` from the TSV data.

     ``append_eos`` is disabled only for the "ctc" decoder type.
     """
     rows = SpeechToTextDatasetCreator._load_samples_from_tsv(
         args.data, split)
     self.data = {}
     for row in rows:
         self.data[row[self.KEY_ID]] = row[self.KEY_TEXT]
     self.dict = tgt_dict
     self.append_eos = args.decoder_type != "ctc"