Example #1
    def __call__(self, cuts: CutSet) -> Tuple[torch.Tensor, torch.IntTensor]:
        """
        Reads the audio samples from recordings on disk/other storage
        and computes their features.
        The returned shape is ``(B, T, F) => (batch_size, num_frames, num_features)``.

        :return: a tensor with collated features, and a tensor of ``num_frames`` of each cut before padding.
        """
        audio, _ = collate_audio(cuts)

        for tfnm in self.wave_transforms:
            audio = tfnm(audio)

        features_single = []
        for idx, cut in enumerate(cuts):
            samples = audio[idx].numpy()
            try:
                features = self.extractor.extract(samples, cut.sampling_rate)
            except Exception:
                logging.error(
                    f"Error while extracting the features for cut with ID {cut.id} -- details:\n{cut}"
                )
                raise
            features_single.append(torch.from_numpy(features))
        features_batch = torch.stack(features_single)

        feature_lens = torch.tensor(
            [
                compute_num_frames(
                    cut.duration, self.extractor.frame_shift, cut.sampling_rate
                )
                for cut in cuts
            ],
            dtype=torch.int32,
        )

        return features_batch, feature_lens
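As a usage sketch, this ``__call__`` matches the shape of an on-the-fly feature-extraction input strategy in lhotse. Below, the ``OnTheFlyFeatures`` class name, the ``Fbank`` extractor choice, and the manifest path are assumptions for illustration:

from lhotse import CutSet, Fbank
from lhotse.dataset import OnTheFlyFeatures

cuts = CutSet.from_json("data/cuts.json")  # hypothetical manifest path
strategy = OnTheFlyFeatures(extractor=Fbank())  # assumed host class for this __call__
features, feature_lens = strategy(cuts)
# features: (batch_size, num_frames, num_features); feature_lens: (batch_size,)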
Example #2
    def __call__(self, cuts: CutSet) -> Tuple[torch.Tensor, torch.IntTensor]:
        """
        Reads the audio samples from recordings on disk/other storage.
        The returned shape is ``(B, T) => (batch_size, num_samples)``.

        :return: a tensor with collated audio samples, and a tensor of ``num_samples`` of each cut before padding.
        """
        return collate_audio(cuts)
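For intuition, here is a minimal re-implementation sketch of the padding contract that ``collate_audio`` follows (pad every waveform to the batch maximum and report the original lengths). This is illustrative only, not lhotse's actual implementation:

import torch

def collate_vectors_sketch(vectors, padding_value=0.0):
    # Pad 1-D tensors to the longest length and stack into (B, T).
    max_len = max(v.size(0) for v in vectors)
    out = torch.full((len(vectors), max_len), padding_value)
    lens = torch.zeros(len(vectors), dtype=torch.int32)
    for i, v in enumerate(vectors):
        out[i, : v.size(0)] = v
        lens[i] = v.size(0)
    return out, lens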
Example #3
def test_collate_audio_padding():
    cuts = CutSet.from_json("test/fixtures/ljspeech/cuts.json")
    assert len(set(cut.num_samples for cut in cuts)) > 1

    correct_pad = max(cut.num_samples for cut in cuts)
    audio, audio_lens = collate_audio(cuts)

    assert audio.shape[-1] == correct_pad
    assert max(audio_lens).item() == correct_pad
Example #4
def test_collate_audio_padding_fault_tolerant_return_vals():
    cuts = CutSet.from_json("test/fixtures/ljspeech/cuts.json")
    assert len(set(cut.num_samples for cut in cuts)) > 1

    correct_pad = max(cut.num_samples for cut in cuts)
    audio, audio_lens, cuts_ok = collate_audio(cuts, fault_tolerant=True)

    assert len(cuts) == len(cuts_ok)
    assert audio.shape[-1] == correct_pad
    assert max(audio_lens).item() == correct_pad
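A consumption sketch: in fault-tolerant mode the returned CutSet contains only the cuts whose audio loaded successfully, so per-cut metadata must be looked up through ``cuts_ok`` rather than the original ``cuts``. The supervision access below assumes one supervision per cut:

audio, audio_lens, cuts_ok = collate_audio(cuts, fault_tolerant=True)
# Align downstream metadata with the surviving cuts, not the input CutSet.
texts = [cut.supervisions[0].text for cut in cuts_ok]  # assumes 1 supervision/cut
assert audio.shape[0] == len(cuts_ok) == len(texts)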
Example #5
    def __getitem__(self, cuts: CutSet) -> Dict[str, Any]:
        if self.collate:
            audio, audio_lens = collate_audio(cuts)
            return {
                "cuts": cuts,
                "audio": audio,
                "audio_lens": audio_lens,
            }
        else:
            return {"cuts": cuts, "audio": [c.load_audio() for c in cuts]}
Example #6
    def __call__(
        self, cuts: CutSet
    ) -> Union[
        Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor, CutSet]
    ]:
        """
        Reads the audio samples from recordings on disk/other storage.
        The returned shape is ``(B, T) => (batch_size, num_samples)``.

        :return: a tensor with collated audio samples, and a tensor of ``num_samples`` of each cut before padding
            (plus the CutSet of successfully loaded cuts when ``fault_tolerant`` is enabled).
        """
        return collate_audio(
            cuts,
            executor=_get_executor(self.num_workers, executor_type=self._executor_type),
            fault_tolerant=self.fault_tolerant,
        )
Example #7
    def test_fault_tolerant_loading_skips_cut(self, snr):
        sr = 16000
        zero_cut = self.with_cut(sampling_rate=sr,
                                 num_samples=sr,
                                 features=False,
                                 use_zeroes=True)
        rand_cut = self.with_cut(sampling_rate=sr,
                                 num_samples=sr,
                                 features=False)

        # Mixing a zero-energy cut at a given SNR is impossible (its energy
        # is zero), so loading the first mix fails and fault-tolerant
        # collation drops it, keeping only the valid mix.
        zero_mixed = zero_cut.mix(rand_cut, snr=snr)
        rand_mixed = rand_cut.mix(rand_cut, snr=snr)

        cuts_all = CutSet.from_cuts([zero_mixed, rand_mixed])

        audio, audio_lens, cuts_ok = collate_audio(cuts_all, fault_tolerant=True)
        assert len(cuts_ok) == 1
        assert cuts_ok[0] == rand_mixed
Example #8
    def __getitem__(self, cuts: CutSet) -> Dict[str, Any]:
        if self.collate:
            audio, audio_lens = collate_audio(cuts)
            return {
                "cuts": cuts,
                "audio": audio,
                "audio_lens": audio_lens,
            }
        else:
            # Load audio cut-by-cut, silently skipping cuts whose audio
            # fails to load, and return only the cuts that survived.
            remain_cuts = []
            remain_audios = []
            for c in cuts:
                with suppress_audio_loading_errors():
                    remain_audios.append(c.load_audio())
                    remain_cuts.append(c)
            return {"cuts": CutSet.from_cuts(remain_cuts), "audio": remain_audios}
Example #9
    def __getitem__(self, cuts: CutSet) -> Dict[str, Any]:
        if self.collate:
            audio, audio_lens = collate_audio(cuts)
            return {
                "cuts": cuts,
                "audio": audio,
                "audio_lens": audio_lens,
            }
        else:
            # Like the previous example, but suppressing (with a warning)
            # only specific, known error types raised during audio loading.
            remain_cuts = []
            remain_audios = []
            for c in cuts:
                with suppress_and_warn(
                    AudioLoadingError, DurationMismatchError, NonPositiveEnergyError
                ):
                    remain_audios.append(c.load_audio())
                    remain_cuts.append(c)
            return {"cuts": CutSet.from_cuts(remain_cuts), "audio": remain_audios}
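For reference, a context manager with this suppress-and-warn behavior can be sketched in a few lines. lhotse ships its own ``suppress_and_warn``, so this is only an illustration of the pattern:

import logging
from contextlib import contextmanager

@contextmanager
def suppress_and_warn_sketch(*exceptions):
    # Swallow the listed exception types and emit a warning instead of raising.
    try:
        yield
    except exceptions as e:
        logging.warning(f"Skipping item due to error: {e!r}")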
Example #10
    def __call__(
        self, cuts: CutSet, recording_field: Optional[str] = None
    ) -> Union[
        Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor, CutSet]
    ]:
        """
        Reads the audio samples from recordings on disk/other storage.
        The returned shape is ``(B, T) => (batch_size, num_samples)``.

        :param recording_field: when specified, we will try to load recordings from a custom field with this name
            (i.e., ``cut.load_<recording_field>()`` instead of the default ``cut.load_audio()``).
        :return: a tensor with collated audio samples, and a tensor of ``num_samples`` of each cut before padding
            (plus the CutSet of successfully loaded cuts when ``fault_tolerant`` is enabled).
        """
        return collate_audio(
            cuts,
            executor=_get_executor(self.num_workers,
                                   executor_type=self._executor_type),
            fault_tolerant=self.fault_tolerant,
            recording_field=recording_field,
        )
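Usage sketch for the custom field: assuming the cuts carry a custom recording attached as ``reference_audio`` (a hypothetical field name) and that this ``__call__`` belongs to lhotse's ``AudioSamples`` strategy (an assumption), the call would look like:

from lhotse import CutSet
from lhotse.dataset import AudioSamples  # assumed host class for this __call__

cuts = CutSet.from_json("data/cuts.json")  # hypothetical manifest path
strategy = AudioSamples(fault_tolerant=True)
# Loads cut.load_reference_audio() for each cut instead of cut.load_audio().
audio, audio_lens, cuts_ok = strategy(cuts, recording_field="reference_audio")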
Example #11
    def __getitem__(self, cuts: CutSet) -> Dict[str, torch.Tensor]:
        validate_for_tts(cuts)

        for transform in self.cut_transforms:
            cuts = transform(cuts)

        audio, audio_lens = collate_audio(cuts)
        features, features_lens = self.feature_input_strategy(cuts)

        for transform in self.feature_transforms:
            features = transform(features)

        tokens, tokens_lens = self.token_collater(cuts)

        return {
            "audio": audio,
            "features": features,
            "tokens": tokens,
            "audio_lens": audio_lens,
            "features_lens": features_lens,
            "tokens_lens": tokens_lens,
        }
Example #12
    def __getitem__(self, cut_ids: Iterable[str]) -> Dict[str, torch.Tensor]:
        cuts = self.cuts.subset(cut_ids=cut_ids)

        for transform in self.cut_transforms:
            cuts = transform(cuts)

        audio, audio_lens = collate_audio(cuts)
        features, features_lens = self.feature_input_strategy(cuts)

        for transform in self.feature_transforms:
            features = transform(features)

        tokens, tokens_lens = self.token_collater(cuts)

        return {
            "audio": audio,
            "features": features,
            "tokens": tokens,
            "audio_lens": audio_lens,
            "features_lens": features_lens,
            "tokens_lens": tokens_lens,
        }
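For completeness, a consumer of either TTS batch above might unpack it like so (``dataset`` is a hypothetical instance of the class shown; names match the returned dict):

batch = dataset[cut_ids]  # or dataset[cuts] for the previous variant
audio, audio_lens = batch["audio"], batch["audio_lens"]
features, features_lens = batch["features"], batch["features_lens"]
tokens, tokens_lens = batch["tokens"], batch["tokens_lens"]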