Example #1
    def __init__(
        self,
        manifest_filepath: str,
        parser: Union[str, Callable],
        sample_rate: int,
        int_values: bool = False,
        augmentor: 'nemo.collections.asr.parts.perturb.AudioAugmentor' = None,
        max_duration: Optional[int] = None,
        min_duration: Optional[int] = None,
        max_utts: int = 0,
        trim: bool = False,
        bos_id: Optional[int] = None,
        eos_id: Optional[int] = None,
        pad_id: int = 0,
        return_sample_id: bool = False,
    ):
        if isinstance(manifest_filepath, str):
            manifest_filepath = manifest_filepath.split(",")

        self.manifest_processor = ASRManifestProcessor(
            manifest_filepath=manifest_filepath,
            parser=parser,
            max_duration=max_duration,
            min_duration=min_duration,
            max_utts=max_utts,
            bos_id=bos_id,
            eos_id=eos_id,
            pad_id=pad_id,
        )
        self.featurizer = WaveformFeaturizer(sample_rate=sample_rate,
                                             int_values=int_values,
                                             augmentor=augmentor)
        self.trim = trim
        self.return_sample_id = return_sample_id
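
A minimal usage sketch for this constructor. The class name `AudioTextDataset`, the manifest paths, and the sample manifest line are placeholders, not taken from the snippet; they only illustrate how the parameters above are typically filled in.

# Hedged sketch: `AudioTextDataset` stands in for whatever class owns the __init__ above.
# Each manifest line is typically a JSON object such as:
#   {"audio_filepath": "wavs/utt_001.wav", "duration": 3.2, "text": "hello world"}
dataset = AudioTextDataset(
    manifest_filepath="train_a.json,train_b.json",  # a comma-separated string is split into a list
    parser="en",            # a parser name or a callable mapping transcript text to token ids
    sample_rate=16000,
    max_duration=20.0,
    min_duration=0.1,
    trim=True,
    return_sample_id=False,
)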
Example #2
    def get_batch_embeddings(speaker_model, manifest_filepath, batch_size=32, sample_rate=16000, device='cuda'):

        speaker_model.eval()
        if device == 'cuda':
            speaker_model.to(device)

        featurizer = WaveformFeaturizer(sample_rate=sample_rate)
        dataset = AudioToSpeechLabelDataset(manifest_filepath=manifest_filepath, labels=None, featurizer=featurizer)

        dataloader = torch.utils.data.DataLoader(
            dataset=dataset, batch_size=batch_size, collate_fn=dataset.fixed_seq_collate_fn,
        )

        all_logits = []
        all_labels = []
        all_embs = []

        for test_batch in tqdm(dataloader):
            if device == 'cuda':
                test_batch = [x.to(device) for x in test_batch]
            audio_signal, audio_signal_len, labels, _ = test_batch
            with torch.no_grad():  # inference only; keeps the .cpu().numpy() calls below from hitting tensors that require grad
                logits, embs = speaker_model.forward(input_signal=audio_signal, input_signal_length=audio_signal_len)

            all_logits.extend(logits.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())
            all_embs.extend(embs.cpu().numpy())

        all_logits, true_labels, all_embs = np.asarray(all_logits), np.asarray(all_labels), np.asarray(all_embs)

        return all_embs, all_logits, true_labels, dataset.id2label
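
A hedged usage sketch for `get_batch_embeddings`. It assumes a pretrained NeMo speaker model restored via `EncDecSpeakerLabelModel.from_pretrained`; the model name and manifest path are placeholders, and if the function is actually a static method of the model class it should be called from there instead.

# Hedged usage sketch; pretrained model name and manifest path are placeholders.
import nemo.collections.asr as nemo_asr

speaker_model = nemo_asr.models.EncDecSpeakerLabelModel.from_pretrained("titanet_large")
embs, logits, labels, id2label = get_batch_embeddings(
    speaker_model, manifest_filepath="dev_manifest.json", batch_size=64, device='cuda'
)
print(embs.shape)   # (num_utterances, embedding_dim)
print(id2label)     # mapping from integer label index to speaker name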
Example #3
    def __setup_dataloader_from_config(self, config: Optional[Dict]):
        if 'augmentor' in config:
            augmentor = process_augmentations(config['augmentor'])
        else:
            augmentor = None

        featurizer = WaveformFeaturizer(
            sample_rate=config['sample_rate'], int_values=config.get('int_values', False), augmentor=augmentor
        )
        shuffle = config.get('shuffle', False)
        if config.get('is_tarred', False):
            if ('tarred_audio_filepaths' in config and config['tarred_audio_filepaths'] is None) or (
                'manifest_filepath' in config and config['manifest_filepath'] is None
            ):
                logging.warning(
                    "Could not load dataset as `manifest_filepath` was None or "
                    f"`tarred_audio_filepaths` is None. Provided config : {config}"
                )
                return None

            shuffle_n = config.get('shuffle_n', 4 * config['batch_size']) if shuffle else 0
            dataset = get_tarred_speech_label_dataset(
                featurizer=featurizer,
                config=config,
                shuffle_n=shuffle_n,
                global_rank=self.global_rank,
                world_size=self.world_size,
            )
            shuffle = False
        else:
            if 'manifest_filepath' in config and config['manifest_filepath'] is None:
                logging.warning(f"Could not load dataset as `manifest_filepath` was None. Provided config : {config}")
                return None

            dataset = AudioToSpeechLabelDataset(
                manifest_filepath=config['manifest_filepath'],
                labels=config['labels'],
                featurizer=featurizer,
                max_duration=config.get('max_duration', None),
                min_duration=config.get('min_duration', None),
                trim=config.get('trim_silence', False),
                normalize_audio=config.get('normalize_audio', False),
            )

        if hasattr(dataset, 'fixed_seq_collate_fn'):
            collate_fn = dataset.fixed_seq_collate_fn
        else:
            collate_fn = dataset.datasets[0].fixed_seq_collate_fn

        batch_size = config['batch_size']
        return torch.utils.data.DataLoader(
            dataset=dataset,
            batch_size=batch_size,
            collate_fn=collate_fn,
            drop_last=config.get('drop_last', False),
            shuffle=shuffle,
            num_workers=config.get('num_workers', 0),
            pin_memory=config.get('pin_memory', False),
        )
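
For reference, a config dict in the shape the non-tarred branch above expects. The keys mirror the lookups in the method; the values are placeholders (in practice this usually comes from the model's YAML config).

# Illustrative config for the non-tarred branch; values are placeholders.
train_ds_config = {
    'manifest_filepath': 'train_manifest.json',
    'labels': ['spk_0', 'spk_1', 'spk_2'],
    'sample_rate': 16000,
    'int_values': False,
    'batch_size': 32,
    'shuffle': True,
    'trim_silence': False,
    'normalize_audio': False,
    'num_workers': 4,
    'pin_memory': True,
    'drop_last': False,
}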
Example #4
    def __setup_dataloader_from_config(self, config: Optional[Dict]):
        if 'augmentor' in config:
            augmentor = process_augmentations(config['augmentor'])
        else:
            augmentor = None

        featurizer = WaveformFeaturizer(
            sample_rate=config['sample_rate'], int_values=config.get('int_values', False), augmentor=augmentor
        )
        self.dataset = AudioToSpeechLabelDataset(
            manifest_filepath=config['manifest_filepath'],
            labels=config['labels'],
            featurizer=featurizer,
            max_duration=config.get('max_duration', None),
            min_duration=config.get('min_duration', None),
            trim=False,
            time_length=config.get('time_length', 8),
            shift_length=config.get('shift_length', 0.75),
        )

        if self.task == 'diarization':
            logging.info("Setting up diarization parameters")
            _collate_func = self.dataset.sliced_seq_collate_fn
            batch_size = config['batch_size']
            shuffle = False
        else:
            logging.info("Setting up identification parameters")
            _collate_func = self.dataset.fixed_seq_collate_fn
            batch_size = config['batch_size']
            shuffle = config.get('shuffle', False)

        return torch.utils.data.DataLoader(
            dataset=self.dataset,
            batch_size=batch_size,
            collate_fn=_collate_func,
            drop_last=config.get('drop_last', False),
            shuffle=shuffle,
            num_workers=config.get('num_workers', 0),
            pin_memory=config.get('pin_memory', False),
        )
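
The same idea for this variant, which additionally reads windowing parameters and switches the collate function based on `self.task`. Again, an illustrative dict only; values are placeholders.

# Illustrative config for the diarization/identification dataloader above; placeholders only.
inference_ds_config = {
    'manifest_filepath': 'infer_manifest.json',
    'labels': None,          # labels may be inferred from the manifest at inference time
    'sample_rate': 16000,
    'batch_size': 1,
    'time_length': 1.5,      # window size in seconds (sliced collate for diarization)
    'shift_length': 0.75,    # hop between windows in seconds
    'num_workers': 2,
}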
Example #5
    def test_tarred_dataset_duplicate_name(self, test_data_dir):
        manifest_path = os.path.abspath(
            os.path.join(
                test_data_dir,
                'asr/tarred_an4/tarred_duplicate_audio_manifest.json'))

        # Test braceexpand loading
        tarpath = os.path.abspath(
            os.path.join(test_data_dir, 'asr/tarred_an4/audio_{0..1}.tar'))
        featurizer = WaveformFeaturizer(sample_rate=16000,
                                        int_values=False,
                                        augmentor=None)
        ds_braceexpand = TarredAudioToClassificationLabelDataset(
            audio_tar_filepaths=tarpath,
            manifest_filepath=manifest_path,
            labels=self.labels,
            featurizer=featurizer)

        assert len(ds_braceexpand) == 6
        count = 0
        for _ in ds_braceexpand:
            count += 1
        assert count == 6

        # Test loading via list
        tarpath = [
            os.path.abspath(
                os.path.join(test_data_dir, f'asr/tarred_an4/audio_{i}.tar'))
            for i in range(2)
        ]
        ds_list_load = TarredAudioToClassificationLabelDataset(
            audio_tar_filepaths=tarpath,
            manifest_filepath=manifest_path,
            labels=self.labels,
            featurizer=featurizer)
        count = 0
        for _ in ds_list_load:
            count += 1
        assert count == 6
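
The brace pattern passed as `audio_tar_filepaths` is expanded by the `braceexpand` package inside the tarred dataset. A quick sanity check of what it produces (assuming `braceexpand` is installed, as it is a NeMo dependency):

# Quick check of the brace pattern used in the test above.
from braceexpand import braceexpand

print(list(braceexpand('audio_{0..1}.tar')))  # ['audio_0.tar', 'audio_1.tar']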
Example #6
class VocoderDataset(Dataset):
    def __init__(
        self,
        manifest_filepath: Union[str, Path, List[str], List[Path]],
        sample_rate: int,
        n_segments: Optional[int] = None,
        max_duration: Optional[float] = None,
        min_duration: Optional[float] = None,
        ignore_file: Optional[Union[str, Path]] = None,
        trim: Optional[bool] = False,
        load_precomputed_mel: bool = False,
        hop_length: Optional[int] = None,
    ):
        """Dataset which can be used for training and fine-tuning vocoder with pre-computed mel-spectrograms.
        Args:
            manifest_filepath (Union[str, Path, List[str], List[Path]]): Path(s) to the .json manifests containing information on the
            dataset. Each line in the .json file should be valid json. Note: the .json file itself is not valid
            json. Each line should contain the following:
                "audio_filepath": <PATH_TO_WAV>,
                "duration": <Duration of audio clip in seconds> (Optional),
                "mel_filepath": <PATH_TO_LOG_MEL_PT> (Optional)
            sample_rate (int): The sample rate of the audio, or the sample rate to which all files will be resampled.
            n_segments (int): The length of audio in samples to load. For example, given a sample rate of 16kHz, and
                n_segments=16000, a random 1 second section of audio from the clip will be loaded. The section will
                be randomly sampled every time the audio is batched. Can be set to None to load the entire audio.
                Must be specified if load_precomputed_mel is True.
            max_duration (Optional[float]): Max duration of audio clips in seconds. All samples exceeding this will be
                pruned prior to training. Note: Requires "duration" to be set in the manifest file. It does not load
                audio to compute duration. Defaults to None which does not prune.
            min_duration (Optional[float]): Min duration of audio clips in seconds. All samples lower than this will be
                pruned prior to training. Note: Requires "duration" to be set in the manifest file. It does not load
                audio to compute duration. Defaults to None which does not prune.
            ignore_file (Optional[Union[str, Path]]): The location of a pickle-saved list of audio paths
                that will be pruned prior to training. Defaults to None which does not prune.
            trim (bool): Whether to apply librosa.effects.trim to the audio file. Defaults to False.
            load_precomputed_mel (bool): Whether to load precomputed mel (useful for fine-tuning). Note: Requires "mel_filepath" to be set in the manifest file.
            hop_length (Optional[int]): The hop length between FFT computations. Must be specified if load_precomputed_mel is True.
        """
        super().__init__()

        if load_precomputed_mel:
            if hop_length is None:
                raise ValueError(
                    "hop_length must be specified when load_precomputed_mel is True"
                )

            if n_segments is None:
                raise ValueError(
                    "n_segments must be specified when load_precomputed_mel is True"
                )

        # Initialize and read manifest file(s), filter out data by duration and ignore_file
        if isinstance(manifest_filepath, str):
            manifest_filepath = [manifest_filepath]
        self.manifest_filepath = manifest_filepath

        data = []
        total_duration = 0
        for manifest_file in self.manifest_filepath:
            with open(Path(manifest_file).expanduser(), 'r') as f:
                logging.info(f"Loading dataset from {manifest_file}.")
                for line in tqdm(f):
                    item = json.loads(line)

                    if "mel_filepath" not in item and load_precomputed_mel:
                        raise ValueError(
                            f"mel_filepath is missing in {manifest_file}")

                    file_info = {
                        "audio_filepath": item["audio_filepath"],
                        "mel_filepath": item.get("mel_filepath"),
                        "duration": item.get("duration"),
                    }

                    data.append(file_info)

                    if file_info["duration"] is None:
                        logging.info(
                            "Not all audio files have duration information. Duration logging will be disabled."
                        )
                        total_duration = None

                    if total_duration is not None:
                        total_duration += item["duration"]

        logging.info(f"Loaded dataset with {len(data)} files.")
        if total_duration is not None:
            logging.info(
                f"Dataset contains {total_duration / 3600:.2f} hours.")

        self.data = TTSDataset.filter_files(data, ignore_file, min_duration,
                                            max_duration, total_duration)
        self.base_data_dir = get_base_dir(
            [item["audio_filepath"] for item in self.data])

        # Initialize audio and mel related parameters
        self.load_precomputed_mel = load_precomputed_mel
        self.featurizer = WaveformFeaturizer(sample_rate=sample_rate)
        self.sample_rate = sample_rate
        self.n_segments = n_segments
        self.hop_length = hop_length
        self.trim = trim

    def _collate_fn(self, batch):
        if self.load_precomputed_mel:
            return torch.utils.data.dataloader.default_collate(batch)

        audio_lengths = [audio_len for _, audio_len in batch]
        audio_signal = torch.zeros(len(batch),
                                   max(audio_lengths),
                                   dtype=torch.float)

        for i, sample in enumerate(batch):
            audio_signal[i].narrow(0, 0, sample[0].size(0)).copy_(sample[0])

        return audio_signal, torch.tensor(audio_lengths, dtype=torch.long)

    def __getitem__(self, index):
        sample = self.data[index]

        if not self.load_precomputed_mel:
            features = AudioSegment.segment_from_file(
                sample["audio_filepath"],
                n_segments=self.n_segments
                if self.n_segments is not None else -1,
                trim=self.trim,
            )
            features = torch.tensor(features.samples)
            audio, audio_length = features, torch.tensor(
                features.shape[0]).long()

            return audio, audio_length
        else:
            features = self.featurizer.process(sample["audio_filepath"],
                                               trim=self.trim)
            audio, audio_length = features, torch.tensor(
                features.shape[0]).long()

            mel = torch.load(sample["mel_filepath"])
            frames = math.ceil(self.n_segments / self.hop_length)

            if len(audio) > self.n_segments:
                start = random.randint(0, mel.shape[1] - frames - 2)
                mel = mel[:, start:start + frames]
                audio = audio[start * self.hop_length:(start + frames) *
                              self.hop_length]
            else:
                mel = torch.nn.functional.pad(mel, (0, frames - mel.shape[1]))
                audio = torch.nn.functional.pad(
                    audio, (0, self.n_segments - len(audio)))

            return audio, len(audio), mel

    def __len__(self):
        return len(self.data)
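
A hedged end-to-end sketch for `VocoderDataset`, based on the docstring above: two illustrative manifest lines, an instantiation, and a DataLoader wired to the dataset's `_collate_fn`. Paths and hyperparameters are placeholders.

# Illustrative manifest lines (one JSON object per line; paths are placeholders):
#   {"audio_filepath": "wavs/0001.wav", "duration": 2.41}
#   {"audio_filepath": "wavs/0002.wav", "duration": 3.07, "mel_filepath": "mels/0002.pt"}

ds = VocoderDataset(
    manifest_filepath="train_manifest.json",
    sample_rate=22050,
    n_segments=16384,   # ~0.74 s random crops at 22.05 kHz; None would load full clips
    trim=False,
)
loader = torch.utils.data.DataLoader(ds, batch_size=16, collate_fn=ds._collate_fn, shuffle=True)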
Example #7
    def __init__(
        self,
        manifest_filepath: Union[str, Path, List[str], List[Path]],
        sample_rate: int,
        n_segments: Optional[int] = None,
        max_duration: Optional[float] = None,
        min_duration: Optional[float] = None,
        ignore_file: Optional[Union[str, Path]] = None,
        trim: Optional[bool] = False,
        load_precomputed_mel: bool = False,
        hop_length: Optional[int] = None,
    ):
        """Dataset which can be used for training and fine-tuning vocoder with pre-computed mel-spectrograms.
        Args:
            manifest_filepath (Union[str, Path, List[str], List[Path]]): Path(s) to the .json manifests containing information on the
            dataset. Each line in the .json file should be valid json. Note: the .json file itself is not valid
            json. Each line should contain the following:
                "audio_filepath": <PATH_TO_WAV>,
                "duration": <Duration of audio clip in seconds> (Optional),
                "mel_filepath": <PATH_TO_LOG_MEL_PT> (Optional)
            sample_rate (int): The sample rate of the audio, or the sample rate to which all files will be resampled.
            n_segments (int): The length of audio in samples to load. For example, given a sample rate of 16kHz, and
                n_segments=16000, a random 1 second section of audio from the clip will be loaded. The section will
                be randomly sampled every time the audio is batched. Can be set to None to load the entire audio.
                Must be specified if load_precomputed_mel is True.
            max_duration (Optional[float]): Max duration of audio clips in seconds. All samples exceeding this will be
                pruned prior to training. Note: Requires "duration" to be set in the manifest file. It does not load
                audio to compute duration. Defaults to None which does not prune.
            min_duration (Optional[float]): Min duration of audio clips in seconds. All samples lower than this will be
                pruned prior to training. Note: Requires "duration" to be set in the manifest file. It does not load
                audio to compute duration. Defaults to None which does not prune.
            ignore_file (Optional[Union[str, Path]]): The location of a pickle-saved list of audio paths
                that will be pruned prior to training. Defaults to None which does not prune.
            trim (bool): Whether to apply librosa.effects.trim to the audio file. Defaults to False.
            load_precomputed_mel (bool): Whether to load precomputed mel (useful for fine-tuning). Note: Requires "mel_filepath" to be set in the manifest file.
            hop_length (Optional[int]): The hop length between FFT computations. Must be specified if load_precomputed_mel is True.
        """
        super().__init__()

        if load_precomputed_mel:
            if hop_length is None:
                raise ValueError(
                    "hop_length must be specified when load_precomputed_mel is True"
                )

            if n_segments is None:
                raise ValueError(
                    "n_segments must be specified when load_precomputed_mel is True"
                )

        # Initialize and read manifest file(s), filter out data by duration and ignore_file
        if isinstance(manifest_filepath, str):
            manifest_filepath = [manifest_filepath]
        self.manifest_filepath = manifest_filepath

        data = []
        total_duration = 0
        for manifest_file in self.manifest_filepath:
            with open(Path(manifest_file).expanduser(), 'r') as f:
                logging.info(f"Loading dataset from {manifest_file}.")
                for line in tqdm(f):
                    item = json.loads(line)

                    if "mel_filepath" not in item and load_precomputed_mel:
                        raise ValueError(
                            f"mel_filepath is missing in {manifest_file}")

                    file_info = {
                        "audio_filepath": item["audio_filepath"],
                        "mel_filepath": item.get("mel_filepath"),
                        "duration": item.get("duration"),
                    }

                    data.append(file_info)

                    if file_info["duration"] is None:
                        logging.info(
                            "Not all audio files have duration information. Duration logging will be disabled."
                        )
                        total_duration = None

                    if total_duration is not None:
                        total_duration += item["duration"]

        logging.info(f"Loaded dataset with {len(data)} files.")
        if total_duration is not None:
            logging.info(
                f"Dataset contains {total_duration / 3600:.2f} hours.")

        self.data = TTSDataset.filter_files(data, ignore_file, min_duration,
                                            max_duration, total_duration)
        self.base_data_dir = get_base_dir(
            [item["audio_filepath"] for item in self.data])

        # Initialize audio and mel related parameters
        self.load_precomputed_mel = load_precomputed_mel
        self.featurizer = WaveformFeaturizer(sample_rate=sample_rate)
        self.sample_rate = sample_rate
        self.n_segments = n_segments
        self.hop_length = hop_length
        self.trim = trim
Example #8
class TTSDataset(Dataset):
    def __init__(
        self,
        manifest_filepath: Union[str, Path, List[str], List[Path]],
        sample_rate: int,
        text_tokenizer: Union[BaseTokenizer, Callable[[str], List[int]]],
        tokens: Optional[List[str]] = None,
        text_normalizer: Optional[Union[Normalizer, Callable[[str],
                                                             str]]] = None,
        text_normalizer_call_kwargs: Optional[Dict] = None,
        text_tokenizer_pad_id: Optional[int] = None,
        sup_data_types: Optional[List[str]] = None,
        sup_data_path: Optional[Union[Path, str]] = None,
        max_duration: Optional[float] = None,
        min_duration: Optional[float] = None,
        ignore_file: Optional[Union[str, Path]] = None,
        trim: bool = False,
        n_fft: int = 1024,
        win_length: Optional[int] = None,
        hop_length: Optional[int] = None,
        window: str = "hann",
        n_mels: int = 80,
        lowfreq: int = 0,
        highfreq: Optional[int] = None,
        **kwargs,
    ):
        """Dataset which can be used for training spectrogram generators and end-to-end TTS models.
        It loads main data types (audio, text) and specified supplementary data types (log mel, durations, align prior matrix, pitch, energy, speaker id).
        Some supplementary data types are computed on the fly and cached under sup_data_path if they do not already exist.
        The cache folder can be overridden for some supplementary data types (see the keyword args section).
        Arguments for supplementary data types are also passed to this class and are consumed from kwargs (see the keyword args section).
        Args:
            manifest_filepath (Union[str, Path, List[str], List[Path]]): Path(s) to the .json manifests containing information on the
                dataset. Each line in the .json file should be valid json. Note: the .json file itself is not valid
                json. Each line should contain the following:
                    "audio_filepath": <PATH_TO_WAV>,
                    "text": <THE_TRANSCRIPT>,
                    "normalized_text": <NORMALIZED_TRANSCRIPT> (Optional),
                    "mel_filepath": <PATH_TO_LOG_MEL_PT> (Optional),
                    "duration": <Duration of audio clip in seconds> (Optional)
            sample_rate (int): The sample rate of the audio, or the sample rate to which all files will be resampled.
            text_tokenizer (Optional[Union[BaseTokenizer, Callable[[str], List[int]]]]): BaseTokenizer or callable which represents text tokenizer.
            tokens (Optional[List[str]]): Tokens from text_tokenizer. Should be specified if text_tokenizer is not BaseTokenizer.
            text_normalizer (Optional[Union[Normalizer, Callable[[str], str]]]): Normalizer or callable which represents text normalizer.
            text_normalizer_call_kwargs (Optional[Dict]): Additional arguments for text_normalizer function.
            text_tokenizer_pad_id (Optional[int]): Index of padding. Should be specified if text_tokenizer is not BaseTokenizer.
            sup_data_types (Optional[List[str]]): List of supplementary data types.
            sup_data_path (Optional[Union[Path, str]]): A folder that contains or will contain supplementary data (e.g. pitch).
            max_duration (Optional[float]): Max duration of audio clips in seconds. All samples exceeding this will be
                pruned prior to training. Note: Requires "duration" to be set in the manifest file. It does not load
                audio to compute duration. Defaults to None which does not prune.
            min_duration (Optional[float]): Min duration of audio clips in seconds. All samples lower than this will be
                pruned prior to training. Note: Requires "duration" to be set in the manifest file. It does not load
                audio to compute duration. Defaults to None which does not prune.
            ignore_file (Optional[Union[str, Path]]): The location of a pickle-saved list of audio paths
                that will be pruned prior to training. Defaults to None which does not prune.
            trim (Optional[bool]): Whether to apply librosa.effects.trim to the audio file. Defaults to False.
            n_fft (int): The number of fft samples. Defaults to 1024
            win_length (Optional[int]): The length of the stft windows. Defaults to None which uses n_fft.
            hop_length (Optional[int]): The hop length between FFT computations. Defaults to None which uses n_fft//4.
            window (str): One of 'hann', 'hamming', 'blackman', 'bartlett', or 'none', corresponding to the
                equivalent torch window function.
            n_mels (int): The number of mel filters. Defaults to 80.
            lowfreq (int): The lowfreq input to the mel filter calculation. Defaults to 0.
            highfreq (Optional[int]): The highfreq input to the mel filter calculation. Defaults to None.
        Keyword Args:
            log_mel_folder (Optional[Union[Path, str]]): The folder that contains or will contain log mel spectrograms.
            align_prior_matrix_folder (Optional[Union[Path, str]]): The folder that contains or will contain align prior matrices.
            pitch_folder (Optional[Union[Path, str]]): The folder that contains or will contain pitch.
            energy_folder (Optional[Union[Path, str]]): The folder that contains or will contain energy.
            durs_file (Optional[str]): String path to pickled durations location.
            durs_type (Optional[str]): Type of durations. Currently only "aligner-based" is supported.
            use_beta_binomial_interpolator (Optional[bool]): Whether to use beta-binomial interpolator for calculating alignment prior matrix. Defaults to False.
            pitch_fmin (Optional[float]): The fmin input to librosa.pyin. Defaults to librosa.note_to_hz('C2').
            pitch_fmax (Optional[float]): The fmax input to librosa.pyin. Defaults to librosa.note_to_hz('C7').
            pitch_mean (Optional[float]): The mean that we use to normalize the pitch.
            pitch_std (Optional[float]): The std that we use to normalize the pitch.
            pitch_norm (Optional[bool]): Whether to normalize pitch (via pitch_mean and pitch_std) or not.
        """
        super().__init__()

        # Initialize text tokenizer
        self.text_tokenizer = text_tokenizer
        if isinstance(self.text_tokenizer, BaseTokenizer):
            self.text_tokenizer_pad_id = text_tokenizer.pad
            self.tokens = text_tokenizer.tokens
        else:
            if text_tokenizer_pad_id is None:
                raise ValueError(
                    f"text_tokenizer_pad_id must be specified if text_tokenizer is not BaseTokenizer"
                )

            if tokens is None:
                raise ValueError(
                    f"tokens must be specified if text_tokenizer is not BaseTokenizer"
                )

            self.text_tokenizer_pad_id = text_tokenizer_pad_id
            self.tokens = tokens

        # Initialize text normalizer if specified
        self.text_normalizer = text_normalizer
        self.text_normalizer_call = (
            self.text_normalizer.normalize
            if isinstance(self.text_normalizer, Normalizer)
            else self.text_normalizer
        )
        self.text_normalizer_call_kwargs = (
            text_normalizer_call_kwargs if text_normalizer_call_kwargs is not None else {}
        )

        # Initialize and read manifest file(s), filter out data by duration and ignore_file, compute base dir
        if isinstance(manifest_filepath, str):
            manifest_filepath = [manifest_filepath]
        self.manifest_filepath = manifest_filepath

        data = []
        total_duration = 0
        for manifest_file in self.manifest_filepath:
            with open(Path(manifest_file).expanduser(), 'r') as f:
                logging.info(f"Loading dataset from {manifest_file}.")
                for line in tqdm(f):
                    item = json.loads(line)

                    file_info = {
                        "audio_filepath": item["audio_filepath"],
                        "original_text": item["text"],
                        "mel_filepath": item.get("mel_filepath"),
                        "duration": item.get("duration"),
                        "speaker_id": item.get("speaker"),
                    }

                    if "normalized_text" not in item:
                        text = item["text"]

                        if self.text_normalizer is not None:
                            text = self.text_normalizer_call(
                                text, **self.text_normalizer_call_kwargs)

                        file_info["normalized_text"] = text
                        file_info["text_tokens"] = self.text_tokenizer(text)
                    else:
                        file_info["normalized_text"] = item["normalized_text"]
                        file_info["text_tokens"] = self.text_tokenizer(
                            item["normalized_text"])

                    data.append(file_info)

                    if file_info["duration"] is None:
                        logging.info(
                            "Not all audio files have duration information. Duration logging will be disabled."
                        )
                        total_duration = None

                    if total_duration is not None:
                        total_duration += item["duration"]

        logging.info(f"Loaded dataset with {len(data)} files.")
        if total_duration is not None:
            logging.info(
                f"Dataset contains {total_duration / 3600:.2f} hours.")

        self.data = TTSDataset.filter_files(data, ignore_file, min_duration,
                                            max_duration, total_duration)
        self.base_data_dir = get_base_dir(
            [item["audio_filepath"] for item in self.data])

        # Initialize audio and mel related parameters
        self.sample_rate = sample_rate
        self.featurizer = WaveformFeaturizer(sample_rate=self.sample_rate)
        self.trim = trim

        self.n_fft = n_fft
        self.n_mels = n_mels
        self.lowfreq = lowfreq
        self.highfreq = highfreq
        self.window = window
        self.win_length = win_length or self.n_fft
        self.hop_length = hop_length
        self.hop_len = self.hop_length or self.n_fft // 4
        self.fb = torch.tensor(
            # Keyword arguments keep this call compatible with newer librosa versions,
            # where librosa.filters.mel() takes keyword-only parameters.
            librosa.filters.mel(
                sr=self.sample_rate, n_fft=self.n_fft, n_mels=self.n_mels, fmin=self.lowfreq, fmax=self.highfreq
            ),
            dtype=torch.float,
        ).unsqueeze(0)

        window_fn = {
            'hann': torch.hann_window,
            'hamming': torch.hamming_window,
            'blackman': torch.blackman_window,
            'bartlett': torch.bartlett_window,
            'none': None,
        }.get(self.window, None)

        self.stft = lambda x: torch.stft(
            input=x,
            n_fft=self.n_fft,
            hop_length=self.hop_len,
            win_length=self.win_length,
            window=window_fn(self.win_length, periodic=False).to(torch.float)
            if window_fn else None,
        )

        # Initialize sup_data_path, sup_data_types and run preprocessing methods for every supplementary data type
        if sup_data_path is not None:
            Path(sup_data_path).mkdir(parents=True, exist_ok=True)
            self.sup_data_path = sup_data_path

        self.sup_data_types = ([
            DATA_STR2DATA_CLASS[d_as_str] for d_as_str in sup_data_types
        ] if sup_data_types is not None else [])
        self.sup_data_types_set = set(self.sup_data_types)

        for data_type in self.sup_data_types:
            if data_type not in VALID_SUPPLEMENTARY_DATA_TYPES:
                raise NotImplementedError(
                    f"Current implementation doesn't support {data_type} type."
                )

            getattr(self, f"add_{data_type.name}")(**kwargs)

    @staticmethod
    def filter_files(data, ignore_file, min_duration, max_duration,
                     total_duration):
        if ignore_file:
            logging.info(f"Using {ignore_file} to prune dataset.")
            with open(Path(ignore_file).expanduser(), "rb") as f:
                wavs_to_ignore = set(pickle.load(f))

        filtered_data: List[Dict] = []
        pruned_duration = 0 if total_duration is not None else None
        pruned_items = 0
        for item in data:
            audio_path = item['audio_filepath']

            # Prune data according to min/max_duration & the ignore file
            if total_duration is not None:
                if (min_duration and item["duration"] < min_duration) or (
                        max_duration and item["duration"] > max_duration):
                    pruned_duration += item["duration"]
                    pruned_items += 1
                    continue

            if ignore_file and (audio_path in wavs_to_ignore):
                pruned_items += 1
                pruned_duration += item["duration"]
                wavs_to_ignore.remove(audio_path)
                continue

            filtered_data.append(item)

        logging.info(
            f"Pruned {pruned_items} files. Final dataset contains {len(filtered_data)} files"
        )
        if pruned_duration is not None:
            logging.info(
                f"Pruned {pruned_duration / 3600:.2f} hours. Final dataset contains "
                f"{(total_duration - pruned_duration) / 3600:.2f} hours.")

        return filtered_data

    def add_log_mel(self, **kwargs):
        self.log_mel_folder = kwargs.pop('log_mel_folder', None)

        if self.log_mel_folder is None:
            self.log_mel_folder = Path(self.sup_data_path) / LogMel.name

        self.log_mel_folder.mkdir(exist_ok=True, parents=True)

    def add_durations(self, **kwargs):
        durs_file = kwargs.pop('durs_file')
        durs_type = kwargs.pop('durs_type')

        audio_stem2durs = torch.load(durs_file)
        self.durs = []

        for tag in [Path(d["audio_filepath"]).stem for d in self.data]:
            durs = audio_stem2durs[tag]
            if durs_type == "aligner-based":
                self.durs.append(durs)
            else:
                raise NotImplementedError(
                    f"{durs_type} duration type is not supported. Only aligner-based is supported at this moment."
                )

    def add_align_prior_matrix(self, **kwargs):
        self.align_prior_matrix_folder = kwargs.pop(
            'align_prior_matrix_folder', None)

        if self.align_prior_matrix_folder is None:
            self.align_prior_matrix_folder = Path(
                self.sup_data_path) / AlignPriorMatrix.name

        self.align_prior_matrix_folder.mkdir(exist_ok=True, parents=True)

        self.use_beta_binomial_interpolator = kwargs.pop(
            'use_beta_binomial_interpolator', False)

        if self.use_beta_binomial_interpolator:
            self.beta_binomial_interpolator = BetaBinomialInterpolator()

    def add_pitch(self, **kwargs):
        self.pitch_folder = kwargs.pop('pitch_folder', None)

        if self.pitch_folder is None:
            self.pitch_folder = Path(self.sup_data_path) / Pitch.name

        self.pitch_folder.mkdir(exist_ok=True, parents=True)

        self.pitch_fmin = kwargs.pop("pitch_fmin", librosa.note_to_hz('C2'))
        self.pitch_fmax = kwargs.pop("pitch_fmax", librosa.note_to_hz('C7'))
        self.pitch_mean = kwargs.pop("pitch_mean", None)
        self.pitch_std = kwargs.pop("pitch_std", None)
        self.pitch_norm = kwargs.pop("pitch_norm", False)

    def add_energy(self, **kwargs):
        self.energy_folder = kwargs.pop('energy_folder', None)

        if self.energy_folder is None:
            self.energy_folder = Path(self.sup_data_path) / Energy.name

        self.energy_folder.mkdir(exist_ok=True, parents=True)

    def add_speaker_id(self, **kwargs):
        pass

    def get_spec(self, audio):
        with torch.cuda.amp.autocast(enabled=False):
            spec = self.stft(audio)
            if spec.dtype in [torch.cfloat, torch.cdouble]:
                spec = torch.view_as_real(spec)
            spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-9)
        return spec

    def get_log_mel(self, audio):
        with torch.cuda.amp.autocast(enabled=False):
            spec = self.get_spec(audio)
            mel = torch.matmul(self.fb.to(spec.dtype), spec)
            log_mel = torch.log(
                torch.clamp(mel, min=torch.finfo(mel.dtype).tiny))
        return log_mel

    def __getitem__(self, index):
        sample = self.data[index]

        # Let's keep audio name and all internal directories in rel_audio_path_as_text_id to avoid any collisions
        rel_audio_path = Path(sample["audio_filepath"]).relative_to(
            self.base_data_dir).with_suffix("")
        rel_audio_path_as_text_id = str(rel_audio_path).replace("/", "_")

        # Load audio
        features = self.featurizer.process(sample["audio_filepath"],
                                           trim=self.trim)
        audio, audio_length = features, torch.tensor(features.shape[0]).long()

        # Load text
        text = torch.tensor(sample["text_tokens"]).long()
        text_length = torch.tensor(len(sample["text_tokens"])).long()

        # Load mel if needed
        log_mel, log_mel_length = None, None
        if LogMel in self.sup_data_types_set:
            mel_path = sample["mel_filepath"]

            if mel_path is not None and Path(mel_path).exists():
                log_mel = torch.load(mel_path)
            else:
                mel_path = self.log_mel_folder / f"{rel_audio_path_as_text_id}.pt"

                if mel_path.exists():
                    log_mel = torch.load(mel_path)
                else:
                    log_mel = self.get_log_mel(audio)
                    torch.save(log_mel, mel_path)

            log_mel = log_mel.squeeze(0)
            log_mel_length = torch.tensor(log_mel.shape[1]).long()

        # Load durations if needed
        durations = None
        if Durations in self.sup_data_types_set:
            durations = self.durs[index]

        # Load alignment prior matrix if needed
        align_prior_matrix = None
        if AlignPriorMatrix in self.sup_data_types_set:
            if self.use_beta_binomial_interpolator:
                mel_len = self.get_log_mel(audio).shape[2]
                align_prior_matrix = torch.from_numpy(
                    self.beta_binomial_interpolator(mel_len,
                                                    text_length.item()))
            else:
                prior_path = self.align_prior_matrix_folder / f"{rel_audio_path_as_text_id}.pt"

                if prior_path.exists():
                    align_prior_matrix = torch.load(prior_path)
                else:
                    mel_len = self.get_log_mel(audio).shape[2]
                    align_prior_matrix = beta_binomial_prior_distribution(
                        text_length, mel_len)
                    align_prior_matrix = torch.from_numpy(align_prior_matrix)
                    torch.save(align_prior_matrix, prior_path)

        # Load pitch if needed
        pitch, pitch_length = None, None
        if Pitch in self.sup_data_types_set:
            pitch_path = self.pitch_folder / f"{rel_audio_path_as_text_id}.pt"

            if pitch_path.exists():
                pitch = torch.load(pitch_path).float()
            else:
                pitch, _, _ = librosa.pyin(
                    audio.numpy(),
                    fmin=self.pitch_fmin,
                    fmax=self.pitch_fmax,
                    frame_length=self.win_length,
                    sr=self.sample_rate,
                    fill_na=0.0,
                )
                pitch = torch.from_numpy(pitch).float()
                torch.save(pitch, pitch_path)

            if self.pitch_mean is not None and self.pitch_std is not None and self.pitch_norm:
                pitch -= self.pitch_mean
                pitch[pitch == -self.pitch_mean] = 0.0  # Zero out values that were previously zero
                pitch /= self.pitch_std

            pitch_length = torch.tensor(len(pitch)).long()

        # Load energy if needed
        energy, energy_length = None, None
        if Energy in self.sup_data_types_set:
            energy_path = self.energy_folder / f"{rel_audio_path_as_text_id}.pt"

            if energy_path.exists():
                energy = torch.load(energy_path).float()
            else:
                spec = self.get_spec(audio)
                energy = torch.linalg.norm(spec.squeeze(0), axis=0).float()
                torch.save(energy, energy_path)

            energy_length = torch.tensor(len(energy)).long()

        # Load speaker id if needed
        speaker_id = None
        if SpeakerID in self.sup_data_types_set:
            speaker_id = torch.tensor(sample["speaker_id"]).long()

        return (
            audio,
            audio_length,
            text,
            text_length,
            log_mel,
            log_mel_length,
            durations,
            align_prior_matrix,
            pitch,
            pitch_length,
            energy,
            energy_length,
            speaker_id,
        )

    def __len__(self):
        return len(self.data)

    def join_data(self, data_dict):
        result = []
        for data_type in MAIN_DATA_TYPES + self.sup_data_types:
            result.append(data_dict[data_type.name])

            if issubclass(data_type, TTSDataType) and issubclass(
                    data_type, WithLens):
                result.append(data_dict[f"{data_type.name}_lens"])

        return tuple(result)

    def general_collate_fn(self, batch):
        (
            _,
            audio_lengths,
            _,
            tokens_lengths,
            _,
            log_mel_lengths,
            durations_list,
            align_prior_matrices_list,
            pitches,
            pitches_lengths,
            energies,
            energies_lengths,
            _,
        ) = zip(*batch)

        max_audio_len = max(audio_lengths).item()
        max_tokens_len = max(tokens_lengths).item()
        max_log_mel_len = max(log_mel_lengths) if LogMel in self.sup_data_types_set else None
        max_durations_len = max(len(i) for i in durations_list) if Durations in self.sup_data_types_set else None
        max_pitches_len = max(pitches_lengths).item() if Pitch in self.sup_data_types_set else None
        max_energies_len = max(energies_lengths).item() if Energy in self.sup_data_types_set else None

        if LogMel in self.sup_data_types_set:
            # log_mel is element 4 of each sample tuple; its floating dtype determines the pad value.
            log_mel_pad = torch.finfo(batch[0][4].dtype).tiny

        align_prior_matrices = (
            torch.zeros(
                len(align_prior_matrices_list),
                max(prior_i.shape[0] for prior_i in align_prior_matrices_list),
                max(prior_i.shape[1] for prior_i in align_prior_matrices_list),
            )
            if AlignPriorMatrix in self.sup_data_types_set
            else []
        )
        audios, tokens, log_mels, durations_list, pitches, energies, speaker_ids = [], [], [], [], [], [], []

        for i, sample_tuple in enumerate(batch):
            (
                audio,
                audio_len,
                token,
                token_len,
                log_mel,
                log_mel_len,
                durations,
                align_prior_matrix,
                pitch,
                pitch_length,
                energy,
                energy_length,
                speaker_id,
            ) = sample_tuple

            audio = general_padding(audio, audio_len.item(), max_audio_len)
            audios.append(audio)

            token = general_padding(token,
                                    token_len.item(),
                                    max_tokens_len,
                                    pad_value=self.text_tokenizer_pad_id)
            tokens.append(token)

            if LogMel in self.sup_data_types_set:
                log_mels.append(
                    general_padding(log_mel,
                                    log_mel_len,
                                    max_log_mel_len,
                                    pad_value=log_mel_pad))
            if Durations in self.sup_data_types_set:
                durations_list.append(
                    general_padding(durations, len(durations),
                                    max_durations_len))
            if AlignPriorMatrix in self.sup_data_types_set:
                align_prior_matrices[
                    i, : align_prior_matrix.shape[0], : align_prior_matrix.shape[1]
                ] = align_prior_matrix
            if Pitch in self.sup_data_types_set:
                pitches.append(
                    general_padding(pitch, pitch_length.item(),
                                    max_pitches_len))
            if Energy in self.sup_data_types_set:
                energies.append(
                    general_padding(energy, energy_length.item(),
                                    max_energies_len))
            if SpeakerID in self.sup_data_types_set:
                speaker_ids.append(speaker_id)

        data_dict = {
            "audio": torch.stack(audios),
            "audio_lens": torch.stack(audio_lengths),
            "text": torch.stack(tokens),
            "text_lens": torch.stack(tokens_lengths),
            "log_mel": torch.stack(log_mels) if LogMel in self.sup_data_types_set else None,
            "log_mel_lens": torch.stack(log_mel_lengths) if LogMel in self.sup_data_types_set else None,
            "durations": torch.stack(durations_list) if Durations in self.sup_data_types_set else None,
            "align_prior_matrix": align_prior_matrices if AlignPriorMatrix in self.sup_data_types_set else None,
            "pitch": torch.stack(pitches) if Pitch in self.sup_data_types_set else None,
            "pitch_lens": torch.stack(pitches_lengths) if Pitch in self.sup_data_types_set else None,
            "energy": torch.stack(energies) if Energy in self.sup_data_types_set else None,
            "energy_lens": torch.stack(energies_lengths) if Energy in self.sup_data_types_set else None,
            "speaker_id": torch.stack(speaker_ids) if SpeakerID in self.sup_data_types_set else None,
        }

        return data_dict

    def _collate_fn(self, batch):
        data_dict = self.general_collate_fn(batch)
        joined_data = self.join_data(data_dict)
        return joined_data
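
A hedged usage sketch for `TTSDataset`. The supplementary data type names follow the `add_*` methods defined above; `my_tokenizer`, the paths, and the pitch statistics are placeholders.

# Hedged sketch: `my_tokenizer` is any BaseTokenizer (or a plain callable plus explicit
# `tokens` and `text_tokenizer_pad_id`); paths and pitch statistics are placeholders.
ds = TTSDataset(
    manifest_filepath="tts_train.json",
    sample_rate=22050,
    text_tokenizer=my_tokenizer,
    sup_data_types=["log_mel", "align_prior_matrix", "pitch"],  # names matching add_log_mel / add_align_prior_matrix / add_pitch
    sup_data_path="sup_data/",
    max_duration=15.0,
    n_fft=1024,
    hop_length=256,
    pitch_mean=212.0,   # corpus statistics, used only if pitch_norm=True
    pitch_std=68.0,
    pitch_norm=True,
)
loader = torch.utils.data.DataLoader(ds, batch_size=16, collate_fn=ds._collate_fn, shuffle=True)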
Example #9
    def __init__(
        self,
        manifest_filepath: Union[str, Path, List[str], List[Path]],
        sample_rate: int,
        text_tokenizer: Union[BaseTokenizer, Callable[[str], List[int]]],
        tokens: Optional[List[str]] = None,
        text_normalizer: Optional[Union[Normalizer, Callable[[str],
                                                             str]]] = None,
        text_normalizer_call_kwargs: Optional[Dict] = None,
        text_tokenizer_pad_id: Optional[int] = None,
        sup_data_types: Optional[List[str]] = None,
        sup_data_path: Optional[Union[Path, str]] = None,
        max_duration: Optional[float] = None,
        min_duration: Optional[float] = None,
        ignore_file: Optional[Union[str, Path]] = None,
        trim: bool = False,
        n_fft: int = 1024,
        win_length: Optional[int] = None,
        hop_length: Optional[int] = None,
        window: str = "hann",
        n_mels: int = 80,
        lowfreq: int = 0,
        highfreq: Optional[int] = None,
        **kwargs,
    ):
        """Dataset which can be used for training spectrogram generators and end-to-end TTS models.
        It loads main data types (audio, text) and specified supplementary data types (log mel, durations, align prior matrix, pitch, energy, speaker id).
        Some supplementary data types are computed on the fly and cached under sup_data_path if they do not already exist.
        The cache folder can be overridden for some supplementary data types (see the keyword args section).
        Arguments for supplementary data types are also passed to this class and are consumed from kwargs (see the keyword args section).
        Args:
            manifest_filepath (Union[str, Path, List[str], List[Path]]): Path(s) to the .json manifests containing information on the
                dataset. Each line in the .json file should be valid json. Note: the .json file itself is not valid
                json. Each line should contain the following:
                    "audio_filepath": <PATH_TO_WAV>,
                    "text": <THE_TRANSCRIPT>,
                    "normalized_text": <NORMALIZED_TRANSCRIPT> (Optional),
                    "mel_filepath": <PATH_TO_LOG_MEL_PT> (Optional),
                    "duration": <Duration of audio clip in seconds> (Optional)
            sample_rate (int): The sample rate of the audio, or the sample rate to which all files will be resampled.
            text_tokenizer (Optional[Union[BaseTokenizer, Callable[[str], List[int]]]]): BaseTokenizer or callable which represents text tokenizer.
            tokens (Optional[List[str]]): Tokens from text_tokenizer. Should be specified if text_tokenizer is not BaseTokenizer.
            text_normalizer (Optional[Union[Normalizer, Callable[[str], str]]]): Normalizer or callable which represents text normalizer.
            text_normalizer_call_kwargs (Optional[Dict]): Additional arguments for text_normalizer function.
            text_tokenizer_pad_id (Optional[int]): Index of padding. Should be specified if text_tokenizer is not BaseTokenizer.
            sup_data_types (Optional[List[str]]): List of supplementary data types.
            sup_data_path (Optional[Union[Path, str]]): A folder that contains or will contain supplementary data (e.g. pitch).
            max_duration (Optional[float]): Max duration of audio clips in seconds. All samples exceeding this will be
                pruned prior to training. Note: Requires "duration" to be set in the manifest file. It does not load
                audio to compute duration. Defaults to None which does not prune.
            min_duration (Optional[float]): Min duration of audio clips in seconds. All samples lower than this will be
                pruned prior to training. Note: Requires "duration" to be set in the manifest file. It does not load
                audio to compute duration. Defaults to None which does not prune.
            ignore_file (Optional[Union[str, Path]]): The location of a pickle-saved list of audio paths
                that will be pruned prior to training. Defaults to None which does not prune.
            trim (Optional[bool]): Whether to apply librosa.effects.trim to the audio file. Defaults to False.
            n_fft (int): The number of fft samples. Defaults to 1024
            win_length (Optional[int]): The length of the stft windows. Defaults to None which uses n_fft.
            hop_length (Optional[int]): The hop length between FFT computations. Defaults to None which uses n_fft//4.
            window (str): One of 'hann', 'hamming', 'blackman', 'bartlett', or 'none', corresponding to the
                equivalent torch window function.
            n_mels (int): The number of mel filters. Defaults to 80.
            lowfreq (int): The lowfreq input to the mel filter calculation. Defaults to 0.
            highfreq (Optional[int]): The highfreq input to the mel filter calculation. Defaults to None.
        Keyword Args:
            log_mel_folder (Optional[Union[Path, str]]): The folder that contains or will contain log mel spectrograms.
            align_prior_matrix_folder (Optional[Union[Path, str]]): The folder that contains or will contain align prior matrices.
            pitch_folder (Optional[Union[Path, str]]): The folder that contains or will contain pitch.
            energy_folder (Optional[Union[Path, str]]): The folder that contains or will contain energy.
            durs_file (Optional[str]): String path to pickled durations location.
            durs_type (Optional[str]): Type of durations. Currently only "aligner-based" is supported.
            use_beta_binomial_interpolator (Optional[bool]): Whether to use beta-binomial interpolator for calculating alignment prior matrix. Defaults to False.
            pitch_fmin (Optional[float]): The fmin input to librosa.pyin. Defaults to librosa.note_to_hz('C2').
            pitch_fmax (Optional[float]): The fmax input to librosa.pyin. Defaults to librosa.note_to_hz('C7').
            pitch_mean (Optional[float]): The mean that we use to normalize the pitch.
            pitch_std (Optional[float]): The std that we use to normalize the pitch.
            pitch_norm (Optional[bool]): Whether to normalize pitch (via pitch_mean and pitch_std) or not.
        """
        super().__init__()

        # Initialize text tokenizer
        self.text_tokenizer = text_tokenizer
        if isinstance(self.text_tokenizer, BaseTokenizer):
            self.text_tokenizer_pad_id = text_tokenizer.pad
            self.tokens = text_tokenizer.tokens
        else:
            if text_tokenizer_pad_id is None:
                raise ValueError(
                    f"text_tokenizer_pad_id must be specified if text_tokenizer is not BaseTokenizer"
                )

            if tokens is None:
                raise ValueError(
                    f"tokens must be specified if text_tokenizer is not BaseTokenizer"
                )

            self.text_tokenizer_pad_id = text_tokenizer_pad_id
            self.tokens = tokens

        # Initialize text normalizer if specified
        self.text_normalizer = text_normalizer
        self.text_normalizer_call = (
            self.text_normalizer.normalize
            if isinstance(self.text_normalizer, Normalizer)
            else self.text_normalizer
        )
        self.text_normalizer_call_kwargs = (
            text_normalizer_call_kwargs if text_normalizer_call_kwargs is not None else {}
        )

        # Initialize and read manifest file(s), filter out data by duration and ignore_file, compute base dir
        if isinstance(manifest_filepath, str):
            manifest_filepath = [manifest_filepath]
        self.manifest_filepath = manifest_filepath

        data = []
        total_duration = 0
        for manifest_file in self.manifest_filepath:
            with open(Path(manifest_file).expanduser(), 'r') as f:
                logging.info(f"Loading dataset from {manifest_file}.")
                for line in tqdm(f):
                    item = json.loads(line)

                    file_info = {
                        "audio_filepath":
                        item["audio_filepath"],
                        "original_text":
                        item["text"],
                        "mel_filepath":
                        item["mel_filepath"]
                        if "mel_filepath" in item else None,
                        "duration":
                        item["duration"] if "duration" in item else None,
                        "speaker_id":
                        item["speaker"] if "speaker" in item else None,
                    }

                    if "normalized_text" not in item:
                        text = item["text"]

                        if self.text_normalizer is not None:
                            text = self.text_normalizer_call(
                                text, **self.text_normalizer_call_kwargs)

                        file_info["normalized_text"] = text
                        file_info["text_tokens"] = self.text_tokenizer(text)
                    else:
                        file_info["normalized_text"] = item["normalized_text"]
                        file_info["text_tokens"] = self.text_tokenizer(
                            item["normalized_text"])

                    data.append(file_info)

                    if file_info["duration"] is None:
                        logging.info(
                            "Not all audio files have duration information. Duration logging will be disabled."
                        )
                        total_duration = None

                    if total_duration is not None:
                        total_duration += item["duration"]

        logging.info(f"Loaded dataset with {len(data)} files.")
        if total_duration is not None:
            logging.info(
                f"Dataset contains {total_duration / 3600:.2f} hours.")

        self.data = TTSDataset.filter_files(data, ignore_file, min_duration,
                                            max_duration, total_duration)
        self.base_data_dir = get_base_dir(
            [item["audio_filepath"] for item in self.data])

        # Initialize audio and mel related parameters
        self.sample_rate = sample_rate
        self.featurizer = WaveformFeaturizer(sample_rate=self.sample_rate)
        self.trim = trim

        self.n_fft = n_fft
        self.n_mels = n_mels
        self.lowfreq = lowfreq
        self.highfreq = highfreq
        self.window = window
        self.win_length = win_length or self.n_fft
        self.hop_length = hop_length
        self.hop_len = self.hop_length or self.n_fft // 4
        self.fb = torch.tensor(
            librosa.filters.mel(self.sample_rate,
                                self.n_fft,
                                n_mels=self.n_mels,
                                fmin=self.lowfreq,
                                fmax=self.highfreq),
            dtype=torch.float,
        ).unsqueeze(0)

        window_fn = {
            'hann': torch.hann_window,
            'hamming': torch.hamming_window,
            'blackman': torch.blackman_window,
            'bartlett': torch.bartlett_window,
            'none': None,
        }.get(self.window, None)

        self.stft = lambda x: torch.stft(
            input=x,
            n_fft=self.n_fft,
            hop_length=self.hop_len,
            win_length=self.win_length,
            window=window_fn(self.win_length, periodic=False).to(torch.float)
            if window_fn else None,
        )

        # Initialize sup_data_path, sup_data_types and run preprocessing methods for every supplementary data type
        if sup_data_path is not None:
            Path(sup_data_path).mkdir(parents=True, exist_ok=True)
            self.sup_data_path = sup_data_path

        self.sup_data_types = ([
            DATA_STR2DATA_CLASS[d_as_str] for d_as_str in sup_data_types
        ] if sup_data_types is not None else [])
        self.sup_data_types_set = set(self.sup_data_types)

        for data_type in self.sup_data_types:
            if data_type not in VALID_SUPPLEMENTARY_DATA_TYPES:
                raise NotImplementedError(
                    f"Current implementation doesn't support {data_type} type."
                )

            getattr(self, f"add_{data_type.name}")(**kwargs)
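
The STFT and mel filterbank set up in the example above can be combined into a log-mel spectrogram. The following is a minimal sketch of that computation, not code from the example itself; the audio path and parameter values (sample rate, hop length, etc.) are placeholders.

# Minimal sketch of a log-mel spectrogram using the same STFT / mel-filterbank
# setup as the example above; the audio path and parameter values are placeholders.
import librosa
import torch

sample_rate, n_fft, win_length, hop_length, n_mels = 22050, 1024, 1024, 256, 80

# Mel filterbank, built the same way as self.fb above (keyword args work across librosa versions)
fb = torch.tensor(
    librosa.filters.mel(sr=sample_rate, n_fft=n_fft, n_mels=n_mels, fmin=0, fmax=None),
    dtype=torch.float,
)

audio, _ = librosa.load("sample.wav", sr=sample_rate)  # placeholder path
audio = torch.tensor(audio, dtype=torch.float)

# Magnitude STFT (return_complex=True keeps this compatible with recent torch releases)
spec = torch.stft(
    audio,
    n_fft=n_fft,
    hop_length=hop_length,
    win_length=win_length,
    window=torch.hann_window(win_length, periodic=False),
    return_complex=True,
).abs()

log_mel = torch.log(torch.clamp(fb @ spec, min=1e-5))  # shape: (n_mels, n_frames)
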
Example #10
    def __init__(
        self,
        manifest_filepath: str,
        mappings_filepath: str,
        sample_rate: int,
        max_duration: Optional[float] = None,
        min_duration: Optional[float] = None,
        ignore_file: Optional[str] = None,
        trim: bool = False,
        load_supplementary_values=True,  # Set to False for validation
    ):
        """
        Dataset that loads audio, phonemes and their durations, pitches per frame, and energies per frame
        for FastSpeech 2 from paths described in a JSON manifest (see the AudioDataset documentation for details
        on the manifest format), as well as a mappings file for word to phones and phones to indices.
        The text in the manifest is ignored; instead, the phoneme indices for prediction come from the
        duration files.

        For each sample, paths for duration, energy, and pitch files are inferred from the manifest's audio
        filepaths by replacing '/wavs' with '/phoneme_durations', '/pitches', and '/energies', and swapping out
        the file extension to '.pt', '.npy', and '.npy' respectively.
        For example, given manifest audio path `/data/LJSpeech/wavs/LJ001-0001.wav`, the inferred duration and
        phonemes file path would be `/data/LJSpeech/phoneme_durations/LJ001-0001.pt`.

        Note that validation datasets only need the audio files and phoneme & duration files, set
        `load_supplementary_values` to False for validation sets.

        Args:
            manifest_filepath (str): Path to the JSON manifest file that lists audio files.
            mappings_filepath (str): Path to a JSON mappings file that contains mappings "word2phones" and
                "phone2idx". The latter is used to determine the padding index.
            sample_rate (int): Target sample rate of the audio.
            max_duration (float): If audio exceeds this length in seconds, it is filtered from the dataset.
                Defaults to None, which does not filter any audio.
            min_duration (float): If audio is shorter than this length in seconds, it is filtered from the dataset.
                Defaults to None, which does not filter any audio.
            ignore_file (str): Optional pickled file which contains a list of files to ignore (e.g. files that
                contain OOV words).
                Defaults to None.
            trim (bool): Whether to use librosa.effects.trim on the audio clip.
                Defaults to False.
            load_supplementary_values (bool): Whether or not to load pitch and energy files. Set this to False for
                validation datasets.
                Defaults to True.
        """
        super().__init__()

        # Retrieve mappings from file
        with open(mappings_filepath, 'r') as f:
            mappings = json.load(f)
            self.word2phones = mappings['word2phones']
            self.phone2idx = mappings['phone2idx']

        # Load data from manifests
        audio_files = []
        total_duration = 0
        if isinstance(manifest_filepath, str):
            manifest_filepath = [manifest_filepath]
        for manifest_file in manifest_filepath:
            with open(expanduser(manifest_file), 'r') as f:
                logging.info(f"Loading dataset from {manifest_file}.")
                for line in f:
                    item = json.loads(line)
                    audio_files.append({"audio_filepath": item["audio_filepath"], "duration": item["duration"]})
                    total_duration += item["duration"]

        total_dataset_len = len(audio_files)
        logging.info(f"Loaded dataset with {total_dataset_len} files totalling {total_duration/3600:.2f} hours.")
        self.data = []
        if load_supplementary_values:
            dataitem = py_collections.namedtuple(
                typename='AudioTextEntity', field_names='audio_file duration text_tokens pitches energies'
            )
        else:
            dataitem = py_collections.namedtuple(
                typename='AudioTextEntity', field_names='audio_file duration text_tokens'
            )

        if ignore_file:
            logging.info(f"using {ignore_file} to prune dataset.")
            with open(ignore_file, "rb") as f:
                wavs_to_ignore = set(pickle.load(f))

        pruned_duration = 0
        pruned_items = 0
        for item in audio_files:
            audio_path = item['audio_filepath']
            LJ_id = os.path.splitext(os.path.basename(audio_path))[0]

            # Prune data according to min/max_duration & the ignore file
            if (min_duration and item["duration"] < min_duration) or (
                max_duration and item["duration"] > max_duration
            ):
                pruned_duration += item["duration"]
                pruned_items += 1
                continue
            if ignore_file and (LJ_id in wavs_to_ignore):
                pruned_items += 1
                pruned_duration += item["duration"]
                wavs_to_ignore.remove(LJ_id)
                continue

            # Else not pruned, load additional info

            # Phoneme durations and text token indices from durations file
            dur_path = audio_path.replace('/wavs/', '/phoneme_durations/').replace('.wav', '.pt')
            duration_info = torch.load(dur_path)
            durs = duration_info['token_duration']
            text_tokens = duration_info['text_encoded']

            if load_supplementary_values:
                # Load pitch file (F0s)
                pitch_path = audio_path.replace('/wavs/', '/pitches/').replace('.wav', '.npy')
                pitches = torch.from_numpy(np.load(pitch_path).astype(dtype='float32'))

                # Load energy file (L2-norm of the amplitude of each STFT frame of an utterance)
                energies_path = audio_path.replace('/wavs/', '/energies/').replace('.wav', '.npy')
                energies = torch.from_numpy(np.load(energies_path))

                self.data.append(
                    dataitem(
                        audio_file=item['audio_filepath'],
                        duration=durs,
                        pitches=torch.clamp(pitches, min=1e-5),
                        energies=energies,
                        text_tokens=text_tokens,
                    )
                )
            else:
                self.data.append(dataitem(audio_file=item['audio_filepath'], duration=durs, text_tokens=text_tokens,))

        logging.info(f"Pruned {pruned_items} files and {pruned_duration/3600:.2f} hours.")
        logging.info(
            f"Final dataset contains {len(self.data)} files and {(total_duration-pruned_duration)/3600:.2f} hours."
        )

        self.featurizer = WaveformFeaturizer(sample_rate=sample_rate)
        self.trim = trim
        self.load_supplementary_values = load_supplementary_values
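
As a small aside on the path convention described in the docstring above, here is a hypothetical helper (not part of the dataset) that reproduces the '/wavs' replacement scheme; the directory layout is taken from that description.

# Hypothetical helper illustrating the path-replacement scheme described above;
# the directory layout is an assumption taken from the docstring.
import os


def infer_supplementary_paths(audio_path: str) -> dict:
    base, _ = os.path.splitext(audio_path)
    return {
        "durations": base.replace("/wavs/", "/phoneme_durations/") + ".pt",
        "pitches": base.replace("/wavs/", "/pitches/") + ".npy",
        "energies": base.replace("/wavs/", "/energies/") + ".npy",
    }


# infer_supplementary_paths("/data/LJSpeech/wavs/LJ001-0001.wav")
# -> {"durations": "/data/LJSpeech/phoneme_durations/LJ001-0001.pt",
#     "pitches":   "/data/LJSpeech/pitches/LJ001-0001.npy",
#     "energies":  "/data/LJSpeech/energies/LJ001-0001.npy"}
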
Example #11
    def _setup_dataloader_from_config(self, config: DictConfig):

        OmegaConf.set_struct(config, False)
        config.is_regression_task = self.is_regression_task
        OmegaConf.set_struct(config, True)

        if 'augmentor' in config:
            augmentor = process_augmentations(config['augmentor'])
        else:
            augmentor = None

        featurizer = WaveformFeaturizer(
            sample_rate=config['sample_rate'], int_values=config.get('int_values', False), augmentor=augmentor
        )
        shuffle = config['shuffle']

        # Instantiate tarred dataset loader or normal dataset loader
        if config.get('is_tarred', False):
            if ('tarred_audio_filepaths' in config and config['tarred_audio_filepaths'] is None) or (
                'manifest_filepath' in config and config['manifest_filepath'] is None
            ):
                logging.warning(
                    "Could not load dataset as `manifest_filepath` is None or "
                    f"`tarred_audio_filepaths` is None. Provided config : {config}"
                )
                return None

            if 'vad_stream' in config and config['vad_stream']:
                logging.warning("VAD inference does not support tarred dataset now")
                return None

            shuffle_n = config.get('shuffle_n', 4 * config['batch_size']) if shuffle else 0
            dataset = audio_to_label_dataset.get_tarred_classification_label_dataset(
                featurizer=featurizer,
                config=OmegaConf.to_container(config),
                shuffle_n=shuffle_n,
                global_rank=self.global_rank,
                world_size=self.world_size,
            )
            shuffle = False
            batch_size = config['batch_size']
            collate_func = dataset.collate_fn

        else:
            if 'manifest_filepath' in config and config['manifest_filepath'] is None:
                logging.warning(f"Could not load dataset as `manifest_filepath` is None. Provided config : {config}")
                return None

            if 'vad_stream' in config and config['vad_stream']:
                logging.info("Perform streaming frame-level VAD")
                dataset = audio_to_label_dataset.get_speech_label_dataset(
                    featurizer=featurizer, config=OmegaConf.to_container(config)
                )
                batch_size = 1
                collate_func = dataset.vad_frame_seq_collate_fn
            else:
                dataset = audio_to_label_dataset.get_classification_label_dataset(
                    featurizer=featurizer, config=OmegaConf.to_container(config)
                )
                batch_size = config['batch_size']
                collate_func = dataset.collate_fn

        return torch.utils.data.DataLoader(
            dataset=dataset,
            batch_size=batch_size,
            collate_fn=collate_func,
            drop_last=config.get('drop_last', False),
            shuffle=shuffle,
            num_workers=config.get('num_workers', 0),
            pin_memory=config.get('pin_memory', False),
        )
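
For context, a config of roughly the following shape would drive the method above. The exact keys accepted downstream depend on the concrete NeMo dataset, so treat the keys and values as placeholders rather than a definitive schema.

# Hedged sketch of the kind of config the method above expects; keys such as
# "labels" depend on the downstream dataset and the paths are placeholders.
from omegaconf import OmegaConf

cfg = OmegaConf.create(
    {
        "manifest_filepath": "train_manifest.json",  # placeholder path
        "sample_rate": 16000,
        "labels": ["no", "yes"],                     # assumed label set
        "batch_size": 32,
        "shuffle": True,
        "is_tarred": False,
        "num_workers": 4,
        "pin_memory": True,
    }
)
# dataloader = model._setup_dataloader_from_config(cfg)  # `model` is hypothetical here
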
Example #12
    def __init__(
        self,
        manifest_filepath: str,
        sample_rate: int,
        text_tokenizer: Union[BaseTokenizer, Callable[[str], List[int]]],
        tokens: Optional[List[str]] = None,
        text_normalizer: Optional[Union[Normalizer, Callable[[str],
                                                             str]]] = None,
        text_normalizer_call_args: Optional[Dict] = None,
        text_tokenizer_pad_id: Optional[int] = None,
        sup_data_types: Optional[List[str]] = None,
        sup_data_path: Optional[Union[Path, str]] = None,
        max_duration: Optional[float] = None,
        min_duration: Optional[float] = None,
        ignore_file: Optional[str] = None,
        trim: bool = False,
        n_fft=1024,
        win_length=None,
        hop_length=None,
        window="hann",
        n_mels=80,
        lowfreq=0,
        highfreq=None,
        **kwargs,
    ):
        """Dataset that loads main data types (audio and text) and specified supplementary data types (e.g. log mel, durations, pitch).
        Most supplementary data types will be computed on the fly and saved in the supplementary_folder if they did not exist before.
        Arguments for supplementary data should be also specified in this class and they will be used from kwargs (see keyword args section).
        Args:
            manifest_filepath (str, Path, List[str, Path]): Path(s) to the .json manifests containing information on the
                dataset. Each line in the .json file should be valid json. Note: the .json file itself is not valid
                json. Each line should contain the following:
                    "audio_filepath": <PATH_TO_WAV>
                    "mel_filepath": <PATH_TO_LOG_MEL_PT> (Optional)
                    "duration": <Duration of audio clip in seconds> (Optional)
                    "text": <THE_TRANSCRIPT> (Optional)
            sample_rate (int): The sample rate of the audio, or the sample rate to which all files will be resampled.
            text_tokenizer (Optional[Union[BaseTokenizer, Callable[[str], List[int]]]]): BaseTokenizer or callable which represents text tokenizer.
            tokens (Optional[List[str]]): Tokens from text_tokenizer. Should be specified if text_tokenizer is not BaseTokenizer.
            text_normalizer (Optional[Union[Normalizer, Callable[[str], str]]]): Normalizer or callable which represents text normalizer.
            text_normalizer_call_args (Optional[Dict]): Additional arguments for text_normalizer function.
            text_tokenizer_pad_id (Optional[int]): Index of padding. Should be specified if text_tokenizer is not BaseTokenizer.
            sup_data_types (Optional[List[str]]): List of supplementary data types.
            sup_data_path (Optional[Union[Path, str]]): A folder that contains or will contain supplementary data (e.g. pitch).
            max_duration (Optional[float]): Max duration of audio clips in seconds. All samples exceeding this will be
                pruned prior to training. Note: Requires "duration" to be set in the manifest file. It does not load
                audio to compute duration. Defaults to None which does not prune.
            min_duration (Optional[float]): Min duration of audio clips in seconds. All samples lower than this will be
                pruned prior to training. Note: Requires "duration" to be set in the manifest file. It does not load
                audio to compute duration. Defaults to None which does not prune.
            ignore_file (Optional[str, Path]): The location of a pickle-saved list of audio_ids (the stem of the audio
                files) that will be pruned prior to training. Defaults to None which does not prune.
            trim (Optional[bool]): Whether to apply librosa.effects.trim to the audio file. Defaults to False.
            n_fft (Optional[int]): The number of fft samples. Defaults to 1024
            win_length (Optional[int]): The length of the stft windows. Defaults to None which uses n_fft.
            hop_length (Optional[int]): The hop length between fft computations. Defaults to None which uses n_fft//4.
            window (Optional[str]): One of 'hann', 'hamming', 'blackman', 'bartlett', or 'none', corresponding to the
                equivalent torch window function.
            n_mels (Optional[int]): The number of mel filters. Defaults to 80.
            lowfreq (Optional[int]): The lowfreq input to the mel filter calculation. Defaults to 0.
            highfreq (Optional[int]): The highfreq input to the mel filter calculation. Defaults to None.
        Keyword Args:
            durs_file (Optional[str]): String path to pickled durations location.
            durs_type (Optional[str]): Type of durations. Currently only "aligner-based" is supported.
            use_beta_binomial_interpolator (Optional[bool]): Whether to use beta-binomial interpolator. Defaults to False.
            pitch_fmin (Optional[float]): The fmin input to librosa.pyin. Defaults to librosa.note_to_hz('C2').
            pitch_fmax (Optional[float]): The fmax input to librosa.pyin. Defaults to librosa.note_to_hz('C7').
            pitch_avg (Optional[float]): The mean that we use to normalize the pitch.
            pitch_std (Optional[float]): The std that we use to normalize the pitch.
            pitch_norm (Optional[bool]): Whether to normalize pitch (via pitch_avg and pitch_std) or not.
        """
        super().__init__()

        self.text_normalizer = text_normalizer
        self.text_normalizer_call = (
            self.text_normalizer.normalize if isinstance(
                self.text_normalizer, Normalizer) else self.text_normalizer)
        self.text_normalizer_call_args = text_normalizer_call_args if text_normalizer_call_args is not None else {}

        self.text_tokenizer = text_tokenizer

        if isinstance(self.text_tokenizer, BaseTokenizer):
            self.text_tokenizer_pad_id = text_tokenizer.pad
            self.tokens = text_tokenizer.tokens
        else:
            if text_tokenizer_pad_id is None:
                raise ValueError(
                    f"text_tokenizer_pad_id must be specified if text_tokenizer is not BaseTokenizer"
                )

            if tokens is None:
                raise ValueError(
                    f"tokens must be specified if text_tokenizer is not BaseTokenizer"
                )

            self.text_tokenizer_pad_id = text_tokenizer_pad_id
            self.tokens = tokens

        if isinstance(manifest_filepath, str):
            manifest_filepath = [manifest_filepath]
        self.manifest_filepath = manifest_filepath

        if sup_data_path is not None:
            Path(sup_data_path).mkdir(parents=True, exist_ok=True)
            self.sup_data_path = sup_data_path

        self.sup_data_types = ([
            DATA_STR2DATA_CLASS[d_as_str] for d_as_str in sup_data_types
        ] if sup_data_types is not None else [])
        self.sup_data_types_set = set(self.sup_data_types)

        self.data = []
        audio_files = []
        total_duration = 0
        for manifest_file in self.manifest_filepath:
            with open(Path(manifest_file).expanduser(), 'r') as f:
                logging.info(f"Loading dataset from {manifest_file}.")
                for line in tqdm(f):
                    item = json.loads(line)

                    file_info = {
                        "audio_filepath":
                        item["audio_filepath"],
                        "mel_filepath":
                        item["mel_filepath"]
                        if "mel_filepath" in item else None,
                        "duration":
                        item["duration"] if "duration" in item else None,
                        "text_tokens":
                        None,
                        "speaker_id":
                        item["speaker"] if "speaker" in item else None,
                    }

                    if "text" in item:
                        text = item["text"]

                        if self.text_normalizer is not None:
                            text = self.text_normalizer_call(
                                text, **self.text_normalizer_call_args)

                        text_tokens = self.text_tokenizer(text)
                        file_info["raw_text"] = item["text"]
                        file_info["text_tokens"] = text_tokens

                    audio_files.append(file_info)

                    if file_info["duration"] is None:
                        logging.info(
                            "Not all audio files have duration information. Duration logging will be disabled."
                        )
                        total_duration = None

                    if total_duration is not None:
                        total_duration += item["duration"]

        logging.info(f"Loaded dataset with {len(audio_files)} files.")
        if total_duration is not None:
            logging.info(
                f"Dataset contains {total_duration / 3600:.2f} hours.")

        if ignore_file:
            logging.info(f"using {ignore_file} to prune dataset.")
            with open(Path(ignore_file).expanduser(), "rb") as f:
                wavs_to_ignore = set(pickle.load(f))

        pruned_duration = 0 if total_duration is not None else None
        pruned_items = 0
        for item in audio_files:
            audio_path = item['audio_filepath']
            audio_id = Path(audio_path).stem

            # Prune data according to min/max_duration & the ignore file
            if total_duration is not None:
                if (min_duration and item["duration"] < min_duration) or (
                        max_duration and item["duration"] > max_duration):
                    pruned_duration += item["duration"]
                    pruned_items += 1
                    continue

            if ignore_file and (audio_id in wavs_to_ignore):
                pruned_items += 1
                if pruned_duration is not None:
                    pruned_duration += item["duration"]
                wavs_to_ignore.remove(audio_id)
                continue

            self.data.append(item)

        logging.info(
            f"Pruned {pruned_items} files. Final dataset contains {len(self.data)} files"
        )
        if pruned_duration is not None:
            logging.info(
                f"Pruned {pruned_duration / 3600:.2f} hours. Final dataset contains "
                f"{(total_duration - pruned_duration) / 3600:.2f} hours.")

        self.sample_rate = sample_rate
        self.featurizer = WaveformFeaturizer(sample_rate=self.sample_rate)
        self.trim = trim

        self.n_fft = n_fft
        self.n_mels = n_mels
        self.lowfreq = lowfreq
        self.highfreq = highfreq
        self.window = window
        self.win_length = win_length or self.n_fft
        self.hop_length = hop_length
        self.hop_len = self.hop_length or self.n_fft // 4
        self.fb = torch.tensor(
            librosa.filters.mel(self.sample_rate,
                                self.n_fft,
                                n_mels=self.n_mels,
                                fmin=self.lowfreq,
                                fmax=self.highfreq),
            dtype=torch.float,
        ).unsqueeze(0)

        window_fn = {
            'hann': torch.hann_window,
            'hamming': torch.hamming_window,
            'blackman': torch.blackman_window,
            'bartlett': torch.bartlett_window,
            'none': None,
        }.get(self.window, None)

        self.stft = lambda x: torch.stft(
            input=x,
            n_fft=self.n_fft,
            hop_length=self.hop_len,
            win_length=self.win_length,
            window=window_fn(self.win_length, periodic=False).to(torch.float)
            if window_fn else None,
        )

        for data_type in self.sup_data_types:
            if data_type not in VALID_SUPPLEMENTARY_DATA_TYPES:
                raise NotImplementedError(
                    f"Current implementation of TTSDataset doesn't support {data_type} type."
                )

            getattr(self, f"add_{data_type.name}")(**kwargs)
Example #13
class FastSpeech2Dataset(Dataset):
    @property
    def output_types(self) -> Optional[Dict[str, NeuralType]]:
        """Returns definitions of module output ports."""
        return {
            'audio_signal': NeuralType(('B', 'T'), AudioSignal()),
            'a_sig_length': NeuralType(('B'), LengthsType()),
            'transcripts': NeuralType(('B', 'T'), TokenIndex()),
            'transcript_length': NeuralType(('B'), LengthsType()),
            'durations': NeuralType(('B', 'T'), TokenDurationType()),
            'pitches': NeuralType(('B', 'T'), RegressionValuesType()),
            'energies': NeuralType(('B', 'T'), RegressionValuesType()),
        }

    def __init__(
        self,
        manifest_filepath: str,
        mappings_filepath: str,
        sample_rate: int,
        max_duration: Optional[float] = None,
        min_duration: Optional[float] = None,
        ignore_file: Optional[str] = None,
        trim: bool = False,
        load_supplementary_values=True,  # Set to False for validation
    ):
        """
        Dataset that loads audio, phonemes and their durations, pitches per frame, and energies per frame
        for FastSpeech 2 from paths described in a JSON manifest (see the AudioDataset documentation for details
        on the manifest format), as well as a mappings file for word to phones and phones to indices.
        The text in the manifest is ignored; instead, the phoneme indices for prediction come from the
        duration files.

        For each sample, paths for duration, energy, and pitch files are inferred from the manifest's audio
        filepaths by replacing '/wavs' with '/phoneme_durations', '/pitches', and '/energies', and swapping out
        the file extension to '.pt', '.npy', and '.npy' respectively.
        For example, given manifest audio path `/data/LJSpeech/wavs/LJ001-0001.wav`, the inferred duration and
        phonemes file path would be `/data/LJSpeech/phoneme_durations/LJ001-0001.pt`.

        Note that validation datasets only need the audio files and phoneme & duration files, set
        `load_supplementary_values` to False for validation sets.

        Args:
            manifest_filepath (str): Path to the JSON manifest file that lists audio files.
            mappings_filepath (str): Path to a JSON mappings file that contains mappings "word2phones" and
                "phone2idx". The latter is used to determine the padding index.
            sample_rate (int): Target sample rate of the audio.
            max_duration (float): If audio exceeds this length in seconds, it is filtered from the dataset.
                Defaults to None, which does not filter any audio.
            min_duration (float): If audio is shorter than this length in seconds, it is filtered from the dataset.
                Defaults to None, which does not filter any audio.
            ignore_file (str): Optional pickled file which contains a list of files to ignore (e.g. files that
                contain OOV words).
                Defaults to None.
            trim (bool): Whether to use librosa.effects.trim on the audio clip.
                Defaults to False.
            load_supplementary_values (bool): Whether or not to load pitch and energy files. Set this to False for
                validation datasets.
                Defaults to True.
        """
        super().__init__()

        # Retrieve mappings from file
        with open(mappings_filepath, 'r') as f:
            mappings = json.load(f)
            self.word2phones = mappings['word2phones']
            self.phone2idx = mappings['phone2idx']

        # Load data from manifests
        audio_files = []
        total_duration = 0
        if isinstance(manifest_filepath, str):
            manifest_filepath = [manifest_filepath]
        for manifest_file in manifest_filepath:
            with open(expanduser(manifest_file), 'r') as f:
                logging.info(f"Loading dataset from {manifest_file}.")
                for line in f:
                    item = json.loads(line)
                    audio_files.append({"audio_filepath": item["audio_filepath"], "duration": item["duration"]})
                    total_duration += item["duration"]

        total_dataset_len = len(audio_files)
        logging.info(f"Loaded dataset with {total_dataset_len} files totalling {total_duration/3600:.2f} hours.")
        self.data = []
        if load_supplementary_values:
            dataitem = py_collections.namedtuple(
                typename='AudioTextEntity', field_names='audio_file duration text_tokens pitches energies'
            )
        else:
            dataitem = py_collections.namedtuple(
                typename='AudioTextEntity', field_names='audio_file duration text_tokens'
            )

        if ignore_file:
            logging.info(f"using {ignore_file} to prune dataset.")
            with open(ignore_file, "rb") as f:
                wavs_to_ignore = set(pickle.load(f))

        pruned_duration = 0
        pruned_items = 0
        for item in audio_files:
            audio_path = item['audio_filepath']
            LJ_id = os.path.splitext(os.path.basename(audio_path))[0]

            # Prune data according to min/max_duration & the ignore file
            if (min_duration and item["duration"] < min_duration) or (
                max_duration and item["duration"] > max_duration
            ):
                pruned_duration += item["duration"]
                pruned_items += 1
                continue
            if ignore_file and (LJ_id in wavs_to_ignore):
                pruned_items += 1
                pruned_duration += item["duration"]
                wavs_to_ignore.remove(LJ_id)
                continue

            # Else not pruned, load additional info

            # Phoneme durations and text token indices from durations file
            dur_path = audio_path.replace('/wavs/', '/phoneme_durations/').replace('.wav', '.pt')
            duration_info = torch.load(dur_path)
            durs = duration_info['token_duration']
            text_tokens = duration_info['text_encoded']

            if load_supplementary_values:
                # Load pitch file (F0s)
                pitch_path = audio_path.replace('/wavs/', '/pitches/').replace('.wav', '.npy')
                pitches = torch.from_numpy(np.load(pitch_path).astype(dtype='float32'))

                # Load energy file (L2-norm of the amplitude of each STFT frame of an utterance)
                energies_path = audio_path.replace('/wavs/', '/energies/').replace('.wav', '.npy')
                energies = torch.from_numpy(np.load(energies_path))

                self.data.append(
                    dataitem(
                        audio_file=item['audio_filepath'],
                        duration=durs,
                        pitches=torch.clamp(pitches, min=1e-5),
                        energies=energies,
                        text_tokens=text_tokens,
                    )
                )
            else:
                self.data.append(dataitem(audio_file=item['audio_filepath'], duration=durs, text_tokens=text_tokens,))

        logging.info(f"Pruned {pruned_items} files and {pruned_duration/3600:.2f} hours.")
        logging.info(
            f"Final dataset contains {len(self.data)} files and {(total_duration-pruned_duration)/3600:.2f} hours."
        )

        self.featurizer = WaveformFeaturizer(sample_rate=sample_rate)
        self.trim = trim
        self.load_supplementary_values = load_supplementary_values

    def __getitem__(self, index):
        sample = self.data[index]

        features = self.featurizer.process(sample.audio_file, trim=self.trim)
        f, fl = features, torch.tensor(features.shape[0]).long()
        t, tl = sample.text_tokens.long(), torch.tensor(len(sample.text_tokens)).long()

        if self.load_supplementary_values:
            return f, fl, t, tl, sample.duration, sample.pitches, sample.energies
        else:
            return f, fl, t, tl, sample.duration, None, None

    def __len__(self):
        return len(self.data)

    def _collate_fn(self, batch):
        pad_id = len(self.phone2idx)
        if self.load_supplementary_values:
            _, audio_lengths, _, tokens_lengths, duration, pitches, energies = zip(*batch)
        else:
            _, audio_lengths, _, tokens_lengths, duration, _, _ = zip(*batch)
        max_audio_len = max(audio_lengths).item()
        max_tokens_len = max(tokens_lengths).item()
        max_durations_len = max([len(i) for i in duration])
        max_duration_sum = max([sum(i) for i in duration])
        if self.load_supplementary_values:
            max_pitches_len = max([len(i) for i in pitches])
            max_energies_len = max([len(i) for i in energies])
            if max_pitches_len != max_energies_len or max_pitches_len != max_duration_sum:
                logging.warning(
                    f"max_pitches_len: {max_pitches_len} != max_energies_len: {max_energies_len} != "
                    f"max_duration_sum:{max_duration_sum}. Your training run will error out!"
                )

        # Add padding where necessary
        audio_signal, tokens, duration_batched, pitches_batched, energies_batched = [], [], [], [], []
        for sample_tuple in batch:
            if self.load_supplementary_values:
                sig, sig_len, tokens_i, tokens_i_len, duration, pitch, energy = sample_tuple
            else:
                sig, sig_len, tokens_i, tokens_i_len, duration, _, _ = sample_tuple
            sig_len = sig_len.item()
            if sig_len < max_audio_len:
                pad = (0, max_audio_len - sig_len)
                sig = torch.nn.functional.pad(sig, pad)
            audio_signal.append(sig)
            tokens_i_len = tokens_i_len.item()
            if tokens_i_len < max_tokens_len:
                pad = (0, max_tokens_len - tokens_i_len)
                tokens_i = torch.nn.functional.pad(tokens_i, pad, value=pad_id)
            tokens.append(tokens_i)
            if len(duration) < max_durations_len:
                pad = (0, max_durations_len - len(duration))
                duration = torch.nn.functional.pad(duration, pad)
            duration_batched.append(duration)

            if self.load_supplementary_values:
                pitch = pitch.squeeze(0)
                if len(pitch) < max_pitches_len:
                    pad = (0, max_pitches_len - len(pitch))
                    pitch = torch.nn.functional.pad(pitch, pad)
                pitches_batched.append(pitch)

                if len(energy) < max_energies_len:
                    pad = (0, max_energies_len - len(energy))
                    energy = torch.nn.functional.pad(energy, pad)
                energies_batched.append(energy)

        audio_signal = torch.stack(audio_signal)
        audio_lengths = torch.stack(audio_lengths)
        tokens = torch.stack(tokens)
        tokens_lengths = torch.stack(tokens_lengths)
        duration_batched = torch.stack(duration_batched)

        if self.load_supplementary_values:
            pitches_batched = torch.stack(pitches_batched)
            energies_batched = torch.stack(energies_batched)
            assert pitches_batched.shape == energies_batched.shape

            return (
                audio_signal,
                audio_lengths,
                tokens,
                tokens_lengths,
                duration_batched,
                pitches_batched,
                energies_batched,
            )
        return (audio_signal, audio_lengths, tokens, tokens_lengths, duration_batched, None, None)
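
A hedged usage sketch for the class above: the manifest and mappings paths and the sample rate are placeholders, and load_supplementary_values is disabled as the docstring recommends for validation sets.

# Usage sketch for FastSpeech2Dataset (paths and sample rate are placeholders).
import torch

val_dataset = FastSpeech2Dataset(
    manifest_filepath="val_manifest.json",   # placeholder
    mappings_filepath="mappings.json",       # placeholder
    sample_rate=22050,
    load_supplementary_values=False,         # validation: skip pitch/energy files
)

val_loader = torch.utils.data.DataLoader(
    val_dataset,
    batch_size=16,
    collate_fn=val_dataset._collate_fn,
    shuffle=False,
)
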
Example #14
class _AudioTextDataset(Dataset):
    """
    Dataset that loads tensors via a json file containing paths to audio files, transcripts, and durations (in seconds).
    Each new line is a different sample. Example below:
    {"audio_filepath": "/path/to/audio.wav", "text_filepath": "/path/to/audio.txt", "duration": 23.147}
    ...
    {"audio_filepath": "/path/to/audio.wav", "text": "the transcription", "offset": 301.75, "duration": 0.82, "utt":
    "utterance_id", "ctm_utt": "en_4156", "side": "A"}
    Args:
        manifest_filepath: Path to manifest json as described above. Can be comma-separated paths.
        parser: Str for a language specific preprocessor or a callable.
        sample_rate (int): Sample rate to resample loaded audio to
        int_values (bool): If true, load samples as 32-bit integers. Defaults to False.
        augmentor (nemo.collections.asr.parts.perturb.AudioAugmentor): An AudioAugmentor object used to augment loaded
            audio
        max_duration: If audio exceeds this length, do not include in dataset
        min_duration: If audio is less than this length, do not include in dataset
        max_utts: Limit number of utterances
        trim: whether or not to trim silence. Defaults to False
        bos_id: Id of beginning of sequence symbol to append if not None
        eos_id: Id of end of sequence symbol to append if not None
        pad_id: Id of pad symbol. Defaults to 0
        return_sample_id (bool): whether to return the sample_id as a part of each sample
    """
    @property
    def output_types(self) -> Optional[Dict[str, NeuralType]]:
        """Returns definitions of module output ports.
               """
        return {
            'audio_signal': NeuralType(('B', 'T'), AudioSignal()),
            'a_sig_length': NeuralType(tuple('B'), LengthsType()),
            'transcripts': NeuralType(('B', 'T'), LabelsType()),
            'transcript_length': NeuralType(tuple('B'), LengthsType()),
            'sample_id': NeuralType(tuple('B'), LengthsType(), optional=True),
        }

    def __init__(
        self,
        manifest_filepath: str,
        parser: Union[str, Callable],
        sample_rate: int,
        int_values: bool = False,
        augmentor: 'nemo.collections.asr.parts.perturb.AudioAugmentor' = None,
        max_duration: Optional[int] = None,
        min_duration: Optional[int] = None,
        max_utts: int = 0,
        trim: bool = False,
        bos_id: Optional[int] = None,
        eos_id: Optional[int] = None,
        pad_id: int = 0,
        return_sample_id: bool = False,
    ):
        if type(manifest_filepath) == str:
            manifest_filepath = manifest_filepath.split(",")

        self.manifest_processor = ASRManifestProcessor(
            manifest_filepath=manifest_filepath,
            parser=parser,
            max_duration=max_duration,
            min_duration=min_duration,
            max_utts=max_utts,
            bos_id=bos_id,
            eos_id=eos_id,
            pad_id=pad_id,
        )
        self.featurizer = WaveformFeaturizer(sample_rate=sample_rate,
                                             int_values=int_values,
                                             augmentor=augmentor)
        self.trim = trim
        self.return_sample_id = return_sample_id

    def get_manifest_sample(self, sample_id):
        return self.manifest_processor.collection[sample_id]

    def __getitem__(self, index):
        sample = self.manifest_processor.collection[index]
        offset = sample.offset

        if offset is None:
            offset = 0

        features = self.featurizer.process(sample.audio_file,
                                           offset=offset,
                                           duration=sample.duration,
                                           trim=self.trim,
                                           orig_sr=sample.orig_sr)
        f, fl = features, torch.tensor(features.shape[0]).long()

        t, tl = self.manifest_processor.process_text_by_sample(sample=sample)

        if self.return_sample_id:
            output = f, fl, torch.tensor(t).long(), torch.tensor(
                tl).long(), index
        else:
            output = f, fl, torch.tensor(t).long(), torch.tensor(tl).long()

        return output

    def __len__(self):
        return len(self.manifest_processor.collection)

    def _collate_fn(self, batch):
        return _speech_collate_fn(batch, pad_id=self.manifest_processor.pad_id)
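
A hedged usage sketch for the dataset above; "en" is assumed to be an accepted language-specific parser name, and the manifest path is a placeholder.

# Usage sketch for _AudioTextDataset; "en" is assumed to be an accepted parser
# name (a language-specific preprocessor) and the manifest path is a placeholder.
import torch

dataset = _AudioTextDataset(
    manifest_filepath="train_manifest.json",
    parser="en",
    sample_rate=16000,
    max_duration=20.0,
    trim=True,
)

loader = torch.utils.data.DataLoader(
    dataset,
    batch_size=8,
    collate_fn=dataset._collate_fn,
)
audio, audio_len, tokens, token_len = next(iter(loader))
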
Example #15
class _TarredAudioToTextDataset(IterableDataset):
    """
    A similar Dataset to the AudioToCharDataset/AudioToBPEDataset, but which loads tarred audio files.

    Accepts a single comma-separated JSON manifest file (in the same style as for the AudioToCharDataset/AudioToBPEDataset),
    as well as the path(s) to the tarball(s) containing the wav files. Each line of the manifest should
    contain the information for one audio file, including at least the transcript and name of the audio
    file within the tarball.

    Valid formats for the audio_tar_filepaths argument include:
    (1) a single string that can be brace-expanded, e.g. 'path/to/audio.tar' or 'path/to/audio_{1..100}.tar.gz', or
    (2) a list of file paths that will not be brace-expanded, e.g. ['audio_1.tar', 'audio_2.tar', ...].

    Note: For brace expansion in (1), there may be cases where `{x..y}` syntax cannot be used due to shell interference.
    This occurs most commonly inside SLURM scripts. Therefore we provide a few equivalent replacements.
    Supported opening braces - { <=> (, [, < and the special tag _OP_.
    Supported closing braces - } <=> ), ], > and the special tag _CL_.
    For SLURM based tasks, we suggest the use of the special tags for ease of use.

    See the WebDataset documentation for more information about accepted data and input formats.

    If using multiple workers, the number of shards should be divisible by world_size to ensure an
    even split among workers. If it is not divisible, logging will give a warning but training will proceed.
    In addition, if using multiprocessing, each shard MUST HAVE THE SAME NUMBER OF ENTRIES after filtering
    is applied. We currently do not check for this, but your program may hang if the shards are uneven!

    Notice that a few arguments are different from the AudioToCharDataset; for example, shuffle (bool) has been
    replaced by shuffle_n (int).

    Additionally, please note that the len() of this DataLayer is assumed to be the length of the manifest
    after filtering. An incorrect manifest length may lead to some DataLoader issues down the line.

    Args:
        audio_tar_filepaths: Either a list of audio tarball filepaths, or a
            string (can be brace-expandable).
        manifest_filepath (str): Path to the manifest.
        parser (callable): A callable which is used to pre-process the text output.
        sample_rate (int): Sample rate to resample loaded audio to
        int_values (bool): If true, load samples as 32-bit integers. Defaults to False.
        augmentor (nemo.collections.asr.parts.perturb.AudioAugmentor): An AudioAugmentor
            object used to augment loaded audio
        shuffle_n (int): How many samples to look ahead and load to be shuffled.
            See WebDataset documentation for more details.
            Defaults to 0.
        min_duration (float): Dataset parameter.
            All training files which have a duration less than min_duration
            are dropped. Note: Duration is read from the manifest JSON.
            Defaults to 0.1.
        max_duration (float): Dataset parameter.
            All training files which have a duration more than max_duration
            are dropped. Note: Duration is read from the manifest JSON.
            Defaults to None.
        max_utts (int): Limit number of utterances. 0 means no maximum.
        blank_index (int): Blank character index, defaults to -1.
        unk_index (int): Unknown character index, defaults to -1.
        normalize (bool): Dataset parameter.
            Whether to use automatic text cleaning.
            It is highly recommended to manually clean text for best results.
            Defaults to True.
        trim (bool): Whether to use trim silence from beginning and end
            of audio signal using librosa.effects.trim().
            Defaults to False.
        bos_id (id): Dataset parameter.
            Beginning of string symbol id used for seq2seq models.
            Defaults to None.
        eos_id (id): Dataset parameter.
            End of string symbol id used for seq2seq models.
            Defaults to None.
        pad_id (id): Token used to pad when collating samples in batches.
            If this is None, pads using 0s.
            Defaults to None.
        shard_strategy (str): Tarred dataset shard distribution strategy chosen as a str value during ddp.
            -   `scatter`: The default shard strategy applied by WebDataset, where each node gets
                a unique set of shards, which are permanently pre-allocated and never changed at runtime.
            -   `replicate`: Optional shard strategy, where each node gets all of the set of shards
                available in the tarred dataset, which are permanently pre-allocated and never changed at runtime.
                The benefit of replication is that it allows each node to sample data points from the entire
                dataset independently of other nodes, and reduces dependence on value of `shuffle_n`.

                Note: Replicated strategy allows every node to sample the entire set of available tarfiles,
                and therefore more than one node may sample the same tarfile, and even sample the same
                data points! As such, there is no assured guarantee that all samples in the dataset will be
                sampled at least once during 1 epoch.
        global_rank (int): Worker rank, used for partitioning shards. Defaults to 0.
        world_size (int): Total number of processes, used for partitioning shards. Defaults to 0.
        return_sample_id (bool): whether to return the sample_id as a part of each sample
    """
    def __init__(
        self,
        audio_tar_filepaths: Union[str, List[str]],
        manifest_filepath: str,
        parser: Callable,
        sample_rate: int,
        int_values: bool = False,
        augmentor: Optional[
            'nemo.collections.asr.parts.perturb.AudioAugmentor'] = None,
        shuffle_n: int = 0,
        min_duration: Optional[float] = None,
        max_duration: Optional[float] = None,
        max_utts: int = 0,
        trim: bool = False,
        bos_id: Optional[int] = None,
        eos_id: Optional[int] = None,
        pad_id: int = 0,
        shard_strategy: str = "scatter",
        global_rank: int = 0,
        world_size: int = 0,
        return_sample_id: bool = False,
    ):
        self.manifest_processor = ASRManifestProcessor(
            manifest_filepath=manifest_filepath,
            parser=parser,
            max_duration=max_duration,
            min_duration=min_duration,
            max_utts=max_utts,
            bos_id=bos_id,
            eos_id=eos_id,
            pad_id=pad_id,
            index_by_file_id=True,  # Must be set so the manifest lines can be indexed by file ID
        )

        self.featurizer = WaveformFeaturizer(sample_rate=sample_rate,
                                             int_values=int_values,
                                             augmentor=augmentor)
        self.trim = trim
        self.eos_id = eos_id
        self.bos_id = bos_id
        self.pad_id = pad_id
        self.return_sample_id = return_sample_id

        audio_tar_filepaths = expand_audio_filepaths(
            audio_tar_filepaths=audio_tar_filepaths,
            shard_strategy=shard_strategy,
            world_size=world_size,
            global_rank=global_rank,
        )

        # Put together WebDataset
        self._dataset = wd.WebDataset(urls=audio_tar_filepaths,
                                      nodesplitter=None)

        if shuffle_n > 0:
            self._dataset = self._dataset.shuffle(shuffle_n)
        else:
            logging.info(
                "WebDataset will not shuffle files within the tar files.")

        self._dataset = (self._dataset.rename(
            audio='wav;ogg;flac',
            key='__key__').to_tuple('audio', 'key').pipe(self._filter).pipe(
                self._loop_offsets).map(f=self._build_sample))

    def _filter(self, iterator):
        """This function is used to remove samples that have been filtered out by ASRAudioText already.
        Otherwise, we would get a KeyError as _build_sample attempts to find the manifest entry for a sample
        that was filtered out (e.g. for duration).
        Note that if using multi-GPU training, filtering may lead to an imbalance in samples in each shard,
        which may make your code hang as one process will finish before the other.
        """
        class TarredAudioFilter:
            def __init__(self, collection):
                self.iterator = iterator
                self.collection = collection

            def __iter__(self):
                return self

            def __next__(self):
                while True:
                    audio_bytes, audio_filename = next(self.iterator)
                    file_id, _ = os.path.splitext(
                        os.path.basename(audio_filename))
                    if file_id in self.collection.mapping:
                        return audio_bytes, audio_filename

        return TarredAudioFilter(self.manifest_processor.collection)

    def _loop_offsets(self, iterator):
        """This function is used to iterate through utterances with different offsets for each file.
        """
        class TarredAudioLoopOffsets:
            def __init__(self, collection):
                self.iterator = iterator
                self.collection = collection
                self.current_fn = None
                self.current_bytes = None
                self.offset_id = 0

            def __iter__(self):
                return self

            def __next__(self):
                if self.current_fn is None:
                    self.current_bytes, self.current_fn = next(self.iterator)
                    self.offset_id = 0
                else:
                    offset_list = self.collection.mapping[self.current_fn]
                    if len(offset_list) == self.offset_id + 1:
                        self.current_bytes, self.current_fn = next(
                            self.iterator)
                        self.offset_id = 0
                    else:
                        self.offset_id += 1

                return self.current_bytes, self.current_fn, self.offset_id

        return TarredAudioLoopOffsets(self.manifest_processor.collection)

    def _collate_fn(self, batch):
        return _speech_collate_fn(batch, self.pad_id)

    def _build_sample(self, tup):
        """Builds the training sample by combining the data from the WebDataset with the manifest info.
        """
        audio_bytes, audio_filename, offset_id = tup

        # Grab manifest entry from self.manifest_preprocessor.collection
        file_id, _ = os.path.splitext(os.path.basename(audio_filename))
        manifest_idx = self.manifest_processor.collection.mapping[file_id][
            offset_id]
        manifest_entry = self.manifest_processor.collection[manifest_idx]

        offset = manifest_entry.offset
        if offset is None:
            offset = 0

        # Convert audio bytes to IO stream for processing (for SoundFile to read)
        audio_filestream = io.BytesIO(audio_bytes)
        features = self.featurizer.process(
            audio_filestream,
            offset=offset,
            duration=manifest_entry.duration,
            trim=self.trim,
            orig_sr=manifest_entry.orig_sr,
        )
        audio_filestream.close()

        # Audio features
        f, fl = features, torch.tensor(features.shape[0]).long()

        # Text features
        t, tl = manifest_entry.text_tokens, len(manifest_entry.text_tokens)

        self.manifest_processor.process_text_by_sample(sample=manifest_entry)

        if self.bos_id is not None:
            t = [self.bos_id] + t
            tl += 1
        if self.eos_id is not None:
            t = t + [self.eos_id]
            tl += 1

        if self.return_sample_id:
            return f, fl, torch.tensor(t).long(), torch.tensor(
                tl).long(), manifest_idx
        else:
            return f, fl, torch.tensor(t).long(), torch.tensor(tl).long()

    def get_manifest_sample(self, sample_id):
        return self.manifest_processor.collection[sample_id]

    def __iter__(self):
        return self._dataset.__iter__()

    def __len__(self):
        return len(self.manifest_processor.collection)
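
The two accepted audio_tar_filepaths formats from the docstring above can be illustrated with the standalone braceexpand package; the expansion actually performed inside expand_audio_filepaths may differ in details (e.g. the _OP_/_CL_ tags), so this is only a sketch.

# Illustration of the two audio_tar_filepaths formats using the `braceexpand`
# package; the real expand_audio_filepaths may handle extra cases (_OP_/_CL_ tags).
from braceexpand import braceexpand

# (1) a single brace-expandable string ...
tar_paths = list(braceexpand("path/to/audio_{1..4}.tar"))
# -> ['path/to/audio_1.tar', ..., 'path/to/audio_4.tar']

# (2) ... or an explicit list that is used as-is
tar_paths = ["audio_1.tar", "audio_2.tar"]
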
Example #16
    def __init__(
        self,
        audio_tar_filepaths: Union[str, List[str]],
        manifest_filepath: str,
        parser: Callable,
        sample_rate: int,
        int_values: bool = False,
        augmentor: Optional[
            'nemo.collections.asr.parts.perturb.AudioAugmentor'] = None,
        shuffle_n: int = 0,
        min_duration: Optional[float] = None,
        max_duration: Optional[float] = None,
        max_utts: int = 0,
        trim: bool = False,
        bos_id: Optional[int] = None,
        eos_id: Optional[int] = None,
        pad_id: int = 0,
        shard_strategy: str = "scatter",
        global_rank: int = 0,
        world_size: int = 0,
        return_sample_id: bool = False,
    ):
        self.manifest_processor = ASRManifestProcessor(
            manifest_filepath=manifest_filepath,
            parser=parser,
            max_duration=max_duration,
            min_duration=min_duration,
            max_utts=max_utts,
            bos_id=bos_id,
            eos_id=eos_id,
            pad_id=pad_id,
            index_by_file_id=True,  # Must be set so the manifest lines can be indexed by file ID
        )

        self.featurizer = WaveformFeaturizer(sample_rate=sample_rate,
                                             int_values=int_values,
                                             augmentor=augmentor)
        self.trim = trim
        self.eos_id = eos_id
        self.bos_id = bos_id
        self.pad_id = pad_id
        self.return_sample_id = return_sample_id

        audio_tar_filepaths = expand_audio_filepaths(
            audio_tar_filepaths=audio_tar_filepaths,
            shard_strategy=shard_strategy,
            world_size=world_size,
            global_rank=global_rank,
        )

        # Put together WebDataset
        self._dataset = wd.WebDataset(urls=audio_tar_filepaths,
                                      nodesplitter=None)

        if shuffle_n > 0:
            self._dataset = self._dataset.shuffle(shuffle_n)
        else:
            logging.info(
                "WebDataset will not shuffle files within the tar files.")

        self._dataset = (
            self._dataset.rename(audio='wav;ogg;flac', key='__key__')
            .to_tuple('audio', 'key')
            .pipe(self._filter)
            .pipe(self._loop_offsets)
            .map(f=self._build_sample)
        )
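A minimal construction sketch for the tarred dataset above. The class name (TarredAudioToCharDataset), the shard pattern, the manifest path, and the parser helper are placeholders assumed for illustration; the keyword arguments mirror the signature shown in this example.

# Usage sketch only: class name, paths, and the parser helper are assumptions.
dataset = TarredAudioToCharDataset(
    audio_tar_filepaths="shards/audio_{0..7}.tar",   # placeholder shard pattern
    manifest_filepath="manifests/train.json",        # placeholder manifest path
    parser=make_parser(name="en"),                   # placeholder character parser, as used in the later examples
    sample_rate=16000,
    shuffle_n=2048,
    shard_strategy="scatter",
    global_rank=0,
    world_size=1,
)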
Пример #17
0
    def __init__(
        self,
        manifest_filepath: str,
        sample_rate: int,
        supplementary_folder: Path,
        max_duration: Optional[float] = None,
        min_duration: Optional[float] = None,
        ignore_file: Optional[str] = None,
        trim: bool = False,
        n_fft=1024,
        win_length=None,
        hop_length=None,
        window="hann",
        n_mels=64,
        lowfreq=0,
        highfreq=None,
        pitch_fmin=80,
        pitch_fmax=640,
        pitch_avg=0,
        pitch_std=1,
        tokenize_text=True,
    ):
        """Dataset that loads audio, log mel specs, text tokens, duration / attention priors, pitches, and energies.
        Log mels, priors, pitches, and energies will be computed on the fly and saved in the supplementary_folder if
        they do not already exist.

        Args:
            manifest_filepath (str, Path, List[str, Path]): Path(s) to the .json manifests containing information on the
                dataset. Each line in the .json file should be valid json. Note: the .json file itself is not valid
                json. Each line should contain the following:
                    "audio_filepath": <PATH_TO_WAV>
                    "mel_filepath": <PATH_TO_LOG_MEL_PT> (Optional)
                    "duration": <Duration of audio clip in seconds> (Optional)
                    "text": <THE_TRANSCRIPT> (Optional)
            sample_rate (int): The sample rate of the audio. Or the sample rate that we will resample all files to.
            supplementary_folder (Path): A folder that contains or will contain extra information such as log_mel if not
                specified in the manifest .json file. It will also contain priors, pitches, and energies
            max_duration (Optional[float]): Max duration of audio clips in seconds. All samples exceeding this will be
                pruned prior to training. Note: Requires "duration" to be set in the manifest file. It does not load
                audio to compute duration. Defaults to None which does not prune.
            min_duration (Optional[float]): Min duration of audio clips in seconds. All samples lower than this will be
                pruned prior to training. Note: Requires "duration" to be set in the manifest file. It does not load
                audio to compute duration. Defaults to None which does not prune.
            ignore_file (Optional[str, Path]): The location of a pickle-saved list of audio_ids (the stem of the audio
                files) that will be pruned prior to training. Defaults to None which does not prune.
            trim (Optional[bool]): Whether to apply librosa.effects.trim to the audio file. Defaults to False.
            n_fft (Optional[int]): The number of fft samples. Defaults to 1024
            win_length (Optional[int]): The length of the stft windows. Defaults to None which uses n_fft.
            hop_length (Optional[int]): The hop length between fft computations. Defaults to None which uses n_fft//4.
            window (Optional[str]): One of 'hann', 'hamming', 'blackman', 'bartlett', or 'none', corresponding to the
                equivalent torch window function.
            n_mels (Optional[int]): The number of mel filters. Defaults to 64.
            lowfreq (Optional[int]): The lowfreq input to the mel filter calculation. Defaults to 0.
            highfreq (Optional[int]): The highfreq input to the mel filter calculation. Defaults to None.
            pitch_fmin (Optional[int]): The fmin input to librosa.pyin. Defaults to 80.
            pitch_fmax (Optional[int]): The fmax input to librosa.pyin. Defaults to 640.
            pitch_avg (Optional[float]): The mean that we use to normalize the pitch. Defaults to 0.
            pitch_std (Optional[float]): The std that we use to normalize the pitch. Defaults to 1.
            tokenize_text (Optional[bool]): Whether to tokenize (turn chars into ints). Defaults to True.
        """
        super().__init__()

        self.pitch_fmin = pitch_fmin
        self.pitch_fmax = pitch_fmax
        self.pitch_avg = pitch_avg
        self.pitch_std = pitch_std
        self.win_length = win_length or n_fft
        self.sample_rate = sample_rate
        self.hop_len = hop_length or n_fft // 4

        self.parser = make_parser(name="en", do_tokenize=tokenize_text)
        self.pad_id = self.parser._blank_id
        Path(supplementary_folder).mkdir(parents=True, exist_ok=True)
        self.supplementary_folder = supplementary_folder

        audio_files = []
        total_duration = 0
        # Load data from manifests
        # Note: audio is always required, even for text -> mel_spectrogram models, due to the fact that most models
        # extract pitch from the audio
        # Note: mel_filepath is not required and if not present, we then check the supplementary folder. If we fail, we
        # compute the mel on the fly and save it to the supplementary folder
        # Note: text is not required. Any model that relies on text (spectrogram generators, end-to-end models) will
        # fail if it is not set. However, vocoders (mel -> audio) can work without text
        if isinstance(manifest_filepath, str):
            manifest_filepath = [manifest_filepath]
        for manifest_file in manifest_filepath:
            with open(Path(manifest_file).expanduser(), 'r') as f:
                logging.info(f"Loading dataset from {manifest_file}.")
                for line in f:
                    item = json.loads(line)
                    # Grab audio, text, mel if they exist
                    file_info = {}
                    file_info["audio_filepath"] = item["audio_filepath"]
                    file_info["mel_filepath"] = item[
                        "mel_filepath"] if "mel_filepath" in item else None
                    file_info["duration"] = item[
                        "duration"] if "duration" in item else None
                    # Parse text
                    file_info["text_tokens"] = None
                    if "text" in item:
                        text = item["text"]
                        text_tokens = self.parser(text)
                        file_info["text_tokens"] = text_tokens
                    audio_files.append(file_info)
                    if file_info["duration"] is None:
                        logging.info(
                            "Not all audio files have duration information. Duration logging will be disabled."
                        )
                        total_duration = None
                    if total_duration is not None:
                        total_duration += item["duration"]

        logging.info(f"Loaded dataset with {len(audio_files)} files.")
        if total_duration is not None:
            logging.info(f"Dataset contains {total_duration/3600:.2f} hours.")

        self.data = []

        if ignore_file:
            logging.info(f"using {ignore_file} to prune dataset.")
            with open(Path(ignore_file).expanduser(), "rb") as f:
                wavs_to_ignore = set(pickle.load(f))

        pruned_duration = 0 if total_duration is not None else None
        pruned_items = 0
        for item in audio_files:
            audio_path = item['audio_filepath']
            audio_id = Path(audio_path).stem

            # Prune data according to min/max_duration & the ignore file
            if total_duration is not None:
                if (min_duration and item["duration"] < min_duration) or (
                        max_duration and item["duration"] > max_duration):
                    pruned_duration += item["duration"]
                    pruned_items += 1
                    continue

            if ignore_file and (audio_id in wavs_to_ignore):
                pruned_items += 1
                pruned_duration += item["duration"]
                wavs_to_ignore.remove(audio_id)
                continue

            self.data.append(item)

        logging.info(
            f"Pruned {pruned_items} files. Final dataset contains {len(self.data)} files"
        )
        if pruned_duration is not None:
            logging.info(
                f"Pruned {pruned_duration/3600:.2f} hours. Final dataset contains "
                f"{(total_duration-pruned_duration)/3600:.2f} hours.")

        self.featurizer = WaveformFeaturizer(sample_rate=sample_rate)
        self.trim = trim

        filterbanks = torch.tensor(
            librosa.filters.mel(sr=sample_rate, n_fft=n_fft, n_mels=n_mels, fmin=lowfreq, fmax=highfreq),
            dtype=torch.float,
        ).unsqueeze(0)
        self.fb = filterbanks

        torch_windows = {
            'hann': torch.hann_window,
            'hamming': torch.hamming_window,
            'blackman': torch.blackman_window,
            'bartlett': torch.bartlett_window,
            'none': None,
        }
        window_fn = torch_windows.get(window, None)
        window_tensor = window_fn(self.win_length,
                                  periodic=False) if window_fn else None

        self.stft = lambda x: stft_patch(
            input=x,
            n_fft=n_fft,
            hop_length=self.hop_len,
            win_length=self.win_length,
            window=window_tensor.to(torch.float) if window_tensor is not None else None,
        )
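The docstring above describes the manifest layout: one JSON object per line (JSON Lines), with audio_filepath required and mel_filepath, duration, and text optional. A small, hypothetical illustration of parsing such lines, with invented paths and transcript:

# Hypothetical manifest lines; the paths and transcript are invented for illustration.
import json

example_manifest_lines = [
    '{"audio_filepath": "wavs/utt_0001.wav", "duration": 2.41, "text": "hello world"}',
    '{"audio_filepath": "wavs/utt_0002.wav", "mel_filepath": "mels/utt_0002.pt", "duration": 3.07}',
]
for line in example_manifest_lines:
    item = json.loads(line)  # mirrors the per-line parsing in the loader above
    print(item["audio_filepath"], item.get("duration"), item.get("text"))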
Пример #18
0
class CharMelAudioDataset(Dataset):
    @property
    def output_types(self) -> Optional[Dict[str, NeuralType]]:
        return {
            'transcripts': NeuralType(('B', 'T'), TokenIndex()),
            'transcript_length': NeuralType(('B'), LengthsType()),
            'mels': NeuralType(('B', 'D', 'T'), TokenIndex()),
            'mel_length': NeuralType(('B'), LengthsType()),
            'audio': NeuralType(('B', 'T'), AudioSignal()),
            'audio_length': NeuralType(('B'), LengthsType()),
            'duration_prior': NeuralType(('B', 'T'), TokenDurationType()),
            'pitches': NeuralType(('B', 'T'), RegressionValuesType()),
            'energies': NeuralType(('B', 'T'), RegressionValuesType()),
        }

    def __init__(
        self,
        manifest_filepath: str,
        sample_rate: int,
        supplementary_folder: Path,
        max_duration: Optional[float] = None,
        min_duration: Optional[float] = None,
        ignore_file: Optional[str] = None,
        trim: bool = False,
        n_fft=1024,
        win_length=None,
        hop_length=None,
        window="hann",
        n_mels=64,
        lowfreq=0,
        highfreq=None,
        pitch_fmin=80,
        pitch_fmax=640,
        pitch_avg=0,
        pitch_std=1,
        tokenize_text=True,
    ):
        """Dataset that loads audio, log mel specs, text tokens, duration / attention priors, pitches, and energies.
        Log mels, priors, pitches, and energies will be computed on the fly and saved in the supplementary_folder if
        they do not already exist.

        Args:
            manifest_filepath (str, Path, List[str, Path]): Path(s) to the .json manifests containing information on the
                dataset. Each line in the .json file should be valid json. Note: the .json file itself is not valid
                json. Each line should contain the following:
                    "audio_filepath": <PATH_TO_WAV>
                    "mel_filepath": <PATH_TO_LOG_MEL_PT> (Optional)
                    "duration": <Duration of audio clip in seconds> (Optional)
                    "text": <THE_TRANSCRIPT> (Optional)
            sample_rate (int): The sample rate of the audio. Or the sample rate that we will resample all files to.
            supplementary_folder (Path): A folder that contains or will contain extra information such as log_mel if not
                specified in the manifest .json file. It will also contain priors, pitches, and energies
            max_duration (Optional[float]): Max duration of audio clips in seconds. All samples exceeding this will be
                pruned prior to training. Note: Requires "duration" to be set in the manifest file. It does not load
                audio to compute duration. Defaults to None which does not prune.
            min_duration (Optional[float]): Min duration of audio clips in seconds. All samples lower than this will be
                pruned prior to training. Note: Requires "duration" to be set in the manifest file. It does not load
                audio to compute duration. Defaults to None which does not prune.
            ignore_file (Optional[str, Path]): The location of a pickle-saved list of audio_ids (the stem of the audio
                files) that will be pruned prior to training. Defaults to None which does not prune.
            trim (Optional[bool]): Whether to apply librosa.effects.trim to the audio file. Defaults to False.
            n_fft (Optional[int]): The number of fft samples. Defaults to 1024
            win_length (Optional[int]): The length of the stft windows. Defaults to None which uses n_fft.
            hop_length (Optional[int]): The hop length between fft computations. Defaults to None which uses n_fft//4.
            window (Optional[str]): One of 'hann', 'hamming', 'blackman', 'bartlett', or 'none', corresponding to the
                equivalent torch window function.
            n_mels (Optional[int]): The number of mel filters. Defaults to 64.
            lowfreq (Optional[int]): The lowfreq input to the mel filter calculation. Defaults to 0.
            highfreq (Optional[int]): The highfreq input to the mel filter calculation. Defaults to None.
            pitch_fmin (Optional[int]): The fmin input to librosa.pyin. Defaults to 80.
            pitch_fmax (Optional[int]): The fmax input to librosa.pyin. Defaults to 640.
            pitch_avg (Optional[float]): The mean that we use to normalize the pitch. Defaults to 0.
            pitch_std (Optional[float]): The std that we use to normalize the pitch. Defaults to 1.
            tokenize_text (Optional[bool]): Whether to tokenize (turn chars into ints). Defaults to True.
        """
        super().__init__()

        self.pitch_fmin = pitch_fmin
        self.pitch_fmax = pitch_fmax
        self.pitch_avg = pitch_avg
        self.pitch_std = pitch_std
        self.win_length = win_length or n_fft
        self.sample_rate = sample_rate
        self.hop_len = hop_length or n_fft // 4

        self.parser = make_parser(name="en", do_tokenize=tokenize_text)
        self.pad_id = self.parser._blank_id
        Path(supplementary_folder).mkdir(parents=True, exist_ok=True)
        self.supplementary_folder = supplementary_folder

        audio_files = []
        total_duration = 0
        # Load data from manifests
        # Note: audio is always required, even for text -> mel_spectrogram models, due to the fact that most models
        # extract pitch from the audio
        # Note: mel_filepath is not required and if not present, we then check the supplementary folder. If we fail, we
        # compute the mel on the fly and save it to the supplementary folder
        # Note: text is not required. Any model that relies on text (spectrogram generators, end-to-end models) will
        # fail if it is not set. However, vocoders (mel -> audio) can work without text
        if isinstance(manifest_filepath, str):
            manifest_filepath = [manifest_filepath]
        for manifest_file in manifest_filepath:
            with open(Path(manifest_file).expanduser(), 'r') as f:
                logging.info(f"Loading dataset from {manifest_file}.")
                for line in f:
                    item = json.loads(line)
                    # Grab audio, text, mel if they exist
                    file_info = {}
                    file_info["audio_filepath"] = item["audio_filepath"]
                    file_info["mel_filepath"] = item[
                        "mel_filepath"] if "mel_filepath" in item else None
                    file_info["duration"] = item[
                        "duration"] if "duration" in item else None
                    # Parse text
                    file_info["text_tokens"] = None
                    if "text" in item:
                        text = item["text"]
                        text_tokens = self.parser(text)
                        file_info["text_tokens"] = text_tokens
                    audio_files.append(file_info)
                    if file_info["duration"] is None:
                        logging.info(
                            "Not all audio files have duration information. Duration logging will be disabled."
                        )
                        total_duration = None
                    if total_duration is not None:
                        total_duration += item["duration"]

        logging.info(f"Loaded dataset with {len(audio_files)} files.")
        if total_duration is not None:
            logging.info(f"Dataset contains {total_duration/3600:.2f} hours.")

        self.data = []

        if ignore_file:
            logging.info(f"using {ignore_file} to prune dataset.")
            with open(Path(ignore_file).expanduser(), "rb") as f:
                wavs_to_ignore = set(pickle.load(f))

        pruned_duration = 0 if total_duration is not None else None
        pruned_items = 0
        for item in audio_files:
            audio_path = item['audio_filepath']
            audio_id = Path(audio_path).stem

            # Prune data according to min/max_duration & the ignore file
            if total_duration is not None:
                if (min_duration and item["duration"] < min_duration) or (
                        max_duration and item["duration"] > max_duration):
                    pruned_duration += item["duration"]
                    pruned_items += 1
                    continue

            if ignore_file and (audio_id in wavs_to_ignore):
                pruned_items += 1
                pruned_duration += item["duration"]
                wavs_to_ignore.remove(audio_id)
                continue

            self.data.append(item)

        logging.info(
            f"Pruned {pruned_items} files. Final dataset contains {len(self.data)} files"
        )
        if pruned_duration is not None:
            logging.info(
                f"Pruned {pruned_duration/3600:.2f} hours. Final dataset contains "
                f"{(total_duration-pruned_duration)/3600:.2f} hours.")

        self.featurizer = WaveformFeaturizer(sample_rate=sample_rate)
        self.trim = trim

        filterbanks = torch.tensor(
            librosa.filters.mel(sr=sample_rate, n_fft=n_fft, n_mels=n_mels, fmin=lowfreq, fmax=highfreq),
            dtype=torch.float,
        ).unsqueeze(0)
        self.fb = filterbanks

        torch_windows = {
            'hann': torch.hann_window,
            'hamming': torch.hamming_window,
            'blackman': torch.blackman_window,
            'bartlett': torch.bartlett_window,
            'none': None,
        }
        window_fn = torch_windows.get(window, None)
        window_tensor = window_fn(self.win_length,
                                  periodic=False) if window_fn else None

        self.stft = lambda x: stft_patch(
            input=x,
            n_fft=n_fft,
            hop_length=self.hop_len,
            win_length=self.win_length,
            window=window_tensor.to(torch.float) if window_tensor is not None else None,
        )

    def __getitem__(self, index):
        spec = None
        sample = self.data[index]

        features = self.featurizer.process(sample["audio_filepath"],
                                           trim=self.trim)
        audio, audio_length = features, torch.tensor(features.shape[0]).long()
        if isinstance(sample["text_tokens"], str):
            # If tokenize_text is False for Phone dataset
            text = sample["text_tokens"]
            text_length = None
        else:
            text = torch.tensor(sample["text_tokens"]).long()
            text_length = torch.tensor(len(sample["text_tokens"])).long()
        audio_stem = Path(sample["audio_filepath"]).stem

        # Load mel if it exists
        mel_path = sample["mel_filepath"]
        if mel_path and Path(mel_path).exists():
            log_mel = torch.load(mel_path)
        else:
            mel_path = Path(self.supplementary_folder) / f"mel_{audio_stem}.pt"
            if mel_path.exists():
                log_mel = torch.load(mel_path)
            else:
                # disable autocast to get full range of stft values
                with torch.cuda.amp.autocast(enabled=False):
                    spec = self.stft(audio)

                    # guard is needed for sqrt if grads are passed through
                    guard = CONSTANT  # TODO: Enable 0 if not self.use_grads else CONSTANT
                    if spec.dtype in [torch.cfloat, torch.cdouble]:
                        spec = torch.view_as_real(spec)
                    spec = torch.sqrt(spec.pow(2).sum(-1) + guard)

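                    # Project the magnitude spectrogram onto the mel filterbank, then take a floored log below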
                    mel = torch.matmul(self.fb.to(spec.dtype), spec)

                    log_mel = torch.log(
                        torch.clamp(mel, min=torch.finfo(mel.dtype).tiny))
                    torch.save(log_mel, mel_path)

        log_mel = log_mel.squeeze(0)
        log_mel_length = torch.tensor(log_mel.shape[1]).long()

        duration_prior = None
        if text_length is not None:
            # Make the duration / attention prior if it does not already exist in the supplementary folder
            prior_path = Path(self.supplementary_folder) / f"pr_tl{text_length}_al_{log_mel_length}.pt"
            if prior_path.exists():
                duration_prior = torch.load(prior_path)
            else:
                duration_prior = beta_binomial_prior_distribution(
                    text_length, log_mel_length)
                duration_prior = torch.from_numpy(duration_prior)
                torch.save(duration_prior, prior_path)

        # Load pitch file (F0s)
        pitch_path = (
            Path(self.supplementary_folder) /
            f"{audio_stem}_pitch_pyin_fmin{self.pitch_fmin}_fmax{self.pitch_fmax}_fl{self.win_length}_hs{self.hop_len}.pt"
        )
        if pitch_path.exists():
            pitch = torch.load(pitch_path)
        else:
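            # Estimate the F0 contour with probabilistic YIN; unvoiced frames are filled with 0.0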
            pitch, _, _ = librosa.pyin(
                audio.numpy(),
                fmin=self.pitch_fmin,
                fmax=self.pitch_fmax,
                frame_length=self.win_length,
                sr=self.sample_rate,
                fill_na=0.0,
            )
            pitch = torch.from_numpy(pitch)
            torch.save(pitch, pitch_path)
        # Standardize pitch
        pitch -= self.pitch_avg
        pitch[pitch == -self.pitch_avg] = 0.0  # Zero out values that were previously zero
        pitch /= self.pitch_std

        # Load energy file (L2-norm of the amplitude of each STFT frame of an utterance)
        energy_path = (
            Path(self.supplementary_folder) / f"{audio_stem}_energy_wl{self.win_length}_hs{self.hop_len}.pt"
        )
        if energy_path.exists():
            energy = torch.load(energy_path)
        else:
            if spec is None:
                spec = self.stft(audio)
            energy = torch.linalg.norm(spec.squeeze(0), axis=0)
            # Save to new file
            torch.save(energy, energy_path)

        return text, text_length, log_mel, log_mel_length, audio, audio_length, duration_prior, pitch, energy

    def __len__(self):
        return len(self.data)

    def _collate_fn(self, batch):
        log_mel_pad = torch.finfo(batch[0][2].dtype).tiny

        _, tokens_lengths, _, log_mel_lengths, _, audio_lengths, duration_priors_list, pitches, energies = zip(
            *batch)

        max_tokens_len = max(tokens_lengths).item()
        max_log_mel_len = max(log_mel_lengths)
        max_audio_len = max(audio_lengths).item()
        max_pitches_len = max([len(i) for i in pitches])
        max_energies_len = max([len(i) for i in energies])
        if max_pitches_len != max_energies_len or max_pitches_len != max_log_mel_len:
            logging.warning(
                f"max_pitches_len: {max_pitches_len} != max_energies_len: {max_energies_len} != "
                f"max_mel_len:{max_log_mel_len}. Your training run will error out!"
            )

        # Define empty lists to be batched
        duration_priors = torch.zeros(
            len(duration_priors_list),
            max([prior_i.shape[0] for prior_i in duration_priors_list]),
            max([prior_i.shape[1] for prior_i in duration_priors_list]),
        )
        audios, tokens, log_mels, pitches, energies = [], [], [], [], []
        for i, sample_tuple in enumerate(batch):
            token, token_len, log_mel, log_mel_len, audio, audio_len, duration_prior, pitch, energy = sample_tuple
            # Pad text tokens
            token_len = token_len.item()
            if token_len < max_tokens_len:
                pad = (0, max_tokens_len - token_len)
                token = torch.nn.functional.pad(token, pad, value=self.pad_id)
            tokens.append(token)
            # Pad mel
            if log_mel_len < max_log_mel_len:
                pad = (0, max_log_mel_len - log_mel_len)
                log_mel = torch.nn.functional.pad(log_mel, pad, value=log_mel_pad)
            log_mels.append(log_mel)
            # Pad audio
            audio_len = audio_len.item()
            if audio_len < max_audio_len:
                pad = (0, max_audio_len - audio_len)
                audio = torch.nn.functional.pad(audio, pad)
            audios.append(audio)
            # Pad duration_prior
            duration_priors[i, :duration_prior.shape[0], :duration_prior.shape[1]] = duration_prior
            # Pad pitch
            if len(pitch) < max_pitches_len:
                pad = (0, max_pitches_len - len(pitch))
                pitch = torch.nn.functional.pad(pitch, pad)
            pitches.append(pitch)
            # Pad energy
            if len(energy) < max_energies_len:
                pad = (0, max_energies_len - len(energy))
                energy = torch.nn.functional.pad(energy, pad)
            energies.append(energy)

        audios = torch.stack(audios)
        audio_lengths = torch.stack(audio_lengths)
        tokens = torch.stack(tokens)
        tokens_lengths = torch.stack(tokens_lengths)
        log_mels = torch.stack(log_mels)
        log_mel_lengths = torch.stack(log_mel_lengths)
        pitches = torch.stack(pitches)
        energies = torch.stack(energies)

        logging.debug(f"audios: {audios.shape}")
        logging.debug(f"audio_lengths: {audio_lengths.shape}")
        logging.debug(f"tokens: {tokens.shape}")
        logging.debug(f"tokens_lengths: {tokens_lengths.shape}")
        logging.debug(f"log_mels: {log_mels.shape}")
        logging.debug(f"log_mel_lengths: {log_mel_lengths.shape}")
        logging.debug(f"duration_priors: {duration_priors.shape}")
        logging.debug(f"pitches: {pitches.shape}")
        logging.debug(f"energies: {energies.shape}")

        return (tokens, tokens_lengths, log_mels, log_mel_lengths,
                duration_priors, pitches, energies)

    def decode(self, tokens):
        assert len(tokens.squeeze().shape) in [0, 1]
        return self.parser.decode(tokens)
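Finally, a usage sketch for wiring CharMelAudioDataset into a PyTorch DataLoader through its _collate_fn. The manifest path and supplementary folder below are placeholders, and the sample rate is only an illustrative value; the unpacked tuple mirrors the seven items returned by _collate_fn above.

# Usage sketch only: paths are placeholders; _collate_fn pads tokens, mels, pitches, and energies per batch.
import torch

dataset = CharMelAudioDataset(
    manifest_filepath="manifests/train.json",   # placeholder manifest path
    sample_rate=22050,                          # illustrative sample rate
    supplementary_folder="supplementary/",      # placeholder folder for cached mels, priors, pitches, energies
)
loader = torch.utils.data.DataLoader(dataset, batch_size=16, shuffle=True, collate_fn=dataset._collate_fn)
tokens, token_lens, log_mels, mel_lens, duration_priors, pitches, energies = next(iter(loader))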