Example #1
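A `__getitem__` for a text-to-speech dataset (the `TextToSpeechDatasetItem` return type points at fairseq's TTS data loader). `get_features_or_waveform` reads the per-utterance pitch and energy feature files from disk; duration, pitch, and energy each receive a trailing 0 for the EOS position.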
    def __getitem__(self, index: int) -> TextToSpeechDatasetItem:
        s2t_item = super().__getitem__(index)

        duration, pitch, energy = None, None, None
        if self.durations is not None:
            duration = torch.tensor(
                self.durations[index] + [0],  # pad 0 for EOS
                dtype=torch.long,
            )
        if self.pitches is not None:
            pitch = get_features_or_waveform(self.pitches[index])
            pitch = torch.from_numpy(
                np.concatenate((pitch, [0]))  # pad 0 for EOS
            ).float()
        if self.energies is not None:
            energy = get_features_or_waveform(self.energies[index])
            energy = torch.from_numpy(
                np.concatenate((energy, [0]))  # pad 0 for EOS
            ).float()
        return TextToSpeechDatasetItem(
            index=index,
            source=s2t_item.source,
            target=s2t_item.target,
            speaker_id=s2t_item.speaker_id,
            duration=duration,
            pitch=pitch,
            energy=energy,
        )
Example #2
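A `__getitem__` for a speech-to-speech dataset. The target is either continuous audio or features loaded with `get_features_or_waveform`, or a discrete-unit sequence encoded through the target dictionary; an optional target-speaker embedding is read with the same helper.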
    def __getitem__(self, index: int) -> SpeechToSpeechDatasetItem:
        source = self._get_source_audio(index)

        tgt_lang_tag = None
        if self.cfg.prepend_tgt_lang_tag_as_bos:
            # use the target-language tag as the BOS token of the target
            tgt_lang_tag = self.get_lang_tag_idx(
                self.tgt_langs[index], self.tgt_dict
            )

        if not self.target_is_code:
            target = get_features_or_waveform(self.tgt_audio_paths[index])
            target = torch.from_numpy(target).float()
            target = self.pack_frames(target)
        else:
            # with discrete targets, tgt_audio_paths holds the unit-sequence
            # string rather than a file path
            target = self.tgt_dict.encode_line(
                self.tgt_audio_paths[index],
                add_if_not_exist=False,
                append_eos=True,
            ).long()
            if self.n_frames_per_step > 1:
                # trim the unit sequence (excluding <eos>) to a multiple of
                # n_frames_per_step, then re-append <eos>
                n_tgt_frame = target.size(0) - 1  # exclude <eos>
                keep_n_tgt_frame = n_tgt_frame - n_tgt_frame % self.n_frames_per_step
                target = torch.cat(
                    (
                        target[:keep_n_tgt_frame],
                        target.new_full((1, ), self.tgt_dict.eos()),
                    ),
                    dim=0,
                )

        if self.tgt_speakers:
            tgt_spk = get_features_or_waveform(self.tgt_speakers[index])
            tgt_spk = torch.from_numpy(tgt_spk).float()
        else:
            tgt_spk = torch.FloatTensor([])

        return SpeechToSpeechDatasetItem(
            index=index,
            source=source,
            target=target,
            target_speaker=tgt_spk,
            tgt_lang_tag=tgt_lang_tag,
        )
Example #3
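A feature reader's `read_audio`: `get_features_or_waveform` returns the raw waveform at the task's sample rate. Multi-channel audio is averaged down to mono, and a length mismatch of more than 160 samples against the reference only logs a warning rather than failing.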
    def read_audio(self, path, ref_len=None):
        wav = get_features_or_waveform(
            path, need_waveform=True, use_sample_rate=self.task.cfg.sample_rate
        )
        if wav.ndim == 2:
            wav = wav.mean(-1)  # average channels down to mono
        assert wav.ndim == 1, wav.ndim
        if ref_len is not None and abs(ref_len - len(wav)) > 160:
            logging.warning(f"ref {ref_len} != read {len(wav)} ({path})")
        return wav
Example #4
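Loading the source side of a dataset item. With `cfg.use_audio_input`, the helper returns the raw waveform, which can optionally be standardized with layer norm; otherwise it returns pre-computed features, which may pass through feature transforms before conversion to a tensor.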
    def _get_source_audio(self, index: int) -> torch.Tensor:
        source = get_features_or_waveform(
            self.audio_paths[index],
            need_waveform=self.cfg.use_audio_input,
            use_sample_rate=self.cfg.use_sample_rate,
        )
        if self.cfg.use_audio_input:
            # raw waveform input, optionally standardized via layer norm
            source = torch.from_numpy(source).float()
            if self.cfg.standardize_audio:
                with torch.no_grad():
                    source = F.layer_norm(source, source.shape)
        else:
            # pre-computed features, optionally run through feature transforms
            if self.feature_transforms is not None:
                source = self.feature_transforms(source)
            source = torch.from_numpy(source).float()
        return source
Example #5
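A more compact `read_audio` with the same 160-sample tolerance check, but without the mono down-mixing.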
    def read_audio(self, path, ref_len=None):
        wav = get_features_or_waveform(
            path, need_waveform=True, use_sample_rate=self.sample_rate
        )
        if ref_len is not None and abs(ref_len - len(wav)) > 160:
            logging.warning(f"ref {ref_len} != read {len(wav)} ({path})")
        return wav
Example #6
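For interactive inference, each input line is the path to an audio or feature file; `get_features_or_waveform` is called only to count frames, so examples can be sorted and batched by length.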
    def get_interactive_tokens_and_lengths(self, lines, encode_fn):
        # each line is a path to an audio/feature file; its frame count
        # serves as the example's length
        n_frames = [get_features_or_waveform(p).shape[0] for p in lines]
        return lines, n_frames
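All six examples wrap the same helper. As a reference, here is a minimal sketch of calling it directly, assuming it is imported from fairseq's `fairseq.data.audio.audio_utils`; the file paths and sample rate are hypothetical, and the keyword arguments mirror the ones used above.

import torch
from fairseq.data.audio.audio_utils import get_features_or_waveform

# Raw waveform, resampled on load (hypothetical path and rate).
wav = get_features_or_waveform(
    "data/utt1.wav", need_waveform=True, use_sample_rate=16000
)
# Typically a 1-D numpy array; example 3 guards the multi-channel case.
n_samples = len(wav)

# Default call: pre-computed features, e.g. stored as .npy;
# shape[0] is the frame count used for length bucketing in example 6.
feats = get_features_or_waveform("data/utt1.npy")
source = torch.from_numpy(feats).float()  # same conversion as examples 1, 2, 4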