Code example #1
def stft(x):
    comp = stft_patch(x.squeeze(1),
                      n_fft=1024,
                      hop_length=256,
                      win_length=1024)
    real, imag = comp[..., 0], comp[..., 1]
    mags = torch.sqrt(real**2 + imag**2)
    phase = torch.atan2(imag, real)
    return mags, phase
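For reference, the magnitude/phase split above can be reproduced with plain torch.stft, which stft_patch wraps in NeMo; a minimal sketch, assuming a (batch, 1, samples) input as in the snippet:

import torch

# Minimal sketch: torch.stft stands in for stft_patch (an assumption).
x = torch.randn(2, 1, 16000)              # (batch, 1, samples) dummy waveform
spec = torch.stft(x.squeeze(1), n_fft=1024, hop_length=256, win_length=1024,
                  window=torch.hann_window(1024), return_complex=True)
mags, phase = spec.abs(), spec.angle()    # same as the sqrt/atan2 above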
Code example #2
def yet_another_patch(audio, n_fft, hop_length, win_length, window):
    spec = stft_patch(audio,
                      n_fft=n_fft,
                      hop_length=hop_length,
                      win_length=win_length,
                      window=window)
    if spec.dtype in [torch.cfloat, torch.cdouble]:
        spec = torch.view_as_real(spec)
    mags = torch.sqrt(spec.pow(2).sum(-1))
    phase = torch.atan2(spec[..., -1], spec[..., 0])
    return mags, phase
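A hypothetical call sketch for the helper above, with torch.stft standing in for stft_patch (an assumption; the real stft_patch is NeMo's compatibility wrapper):

import torch

# Hypothetical stand-in so the sketch is self-contained.
def stft_patch(x, **kwargs):
    return torch.stft(x, return_complex=True, **kwargs)

audio = torch.randn(4, 22050)             # (batch, samples) dummy input
mags, phase = yet_another_patch(audio, n_fft=1024, hop_length=256,
                                win_length=1024, window=torch.hann_window(1024))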
Code example #3
File: stftlosses.py Project: zt706/NeMo
def stft(x, fft_size, hop_size, win_length, window):
    """Perform STFT and convert to magnitude spectrogram.
    Args:
        x (Tensor): Input signal tensor (B, T).
        fft_size (int): FFT size.
        hop_size (int): Hop size.
        win_length (int): Window length.
        window (str): Window function type.
    Returns:
        Tensor: Magnitude spectrogram (B, #frames, fft_size // 2 + 1).
    """
    x_stft = stft_patch(x,
                        fft_size,
                        hop_size,
                        win_length,
                        window,
                        return_complex=False)
    real = x_stft[..., 0]
    imag = x_stft[..., 1]

    # NOTE(kan-bayashi): clamp is needed to avoid nan or inf
    return torch.sqrt(torch.clamp(real**2 + imag**2, min=1e-7)).transpose(2, 1)
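A minimal usage sketch for the loss helper above, assuming stft_patch accepts a window tensor positionally like torch.stft:

# Minimal usage sketch: magnitudes for one resolution of an STFT loss.
x = torch.randn(8, 16000)                 # (B, T) batch of waveforms
mag = stft(x, fft_size=1024, hop_size=120, win_length=600,
           window=torch.hann_window(600))
# mag has shape (B, #frames, 1024 // 2 + 1) == (8, #frames, 513)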
Code example #4
File: features.py Project: mousebaiker/NeMo
    def __init__(
        self,
        sample_rate=16000,
        n_window_size=320,
        n_window_stride=160,
        window="hann",
        normalize="per_feature",
        n_fft=None,
        preemph=0.97,
        nfilt=64,
        lowfreq=0,
        highfreq=None,
        log=True,
        log_zero_guard_type="add",
        log_zero_guard_value=2**-24,
        dither=CONSTANT,
        pad_to=16,
        max_duration=16.7,
        frame_splicing=1,
        exact_pad=False,
        stft_exact_pad=False,  # TODO: Remove this in 1.1.0
        stft_conv=False,  # TODO: Remove this in 1.1.0
        pad_value=0,
        mag_power=2.0,
        use_grads=False,
    ):
        super().__init__()
        if stft_conv or stft_exact_pad:
            logging.warning(
                "Using torch_stft is deprecated and will be removed in 1.1.0. Please set stft_conv and stft_exact_pad "
                "to False for FilterbankFeatures and AudioToMelSpectrogramPreprocessor. Please set exact_pad to True "
                "as needed.")
        if (exact_pad or stft_exact_pad) and n_window_stride % 2 == 1:
            raise NotImplementedError(
                f"{self} received exact_pad == True, but hop_size was odd. If audio_length % hop_size == 0. Then the "
                "returned spectrogram would not be of length audio_length // hop_size. Please use an even hop_size."
            )
        self.log_zero_guard_value = log_zero_guard_value
        if (n_window_size is None or n_window_stride is None
                or not isinstance(n_window_size, int)
                or not isinstance(n_window_stride, int) or n_window_size <= 0
                or n_window_stride <= 0):
            raise ValueError(
                f"{self} got an invalid value for either n_window_size or "
                f"n_window_stride. Both must be positive ints.")
        logging.info(f"PADDING: {pad_to}")

        self.win_length = n_window_size
        self.hop_length = n_window_stride
        self.n_fft = n_fft or 2**math.ceil(math.log2(self.win_length))
        self.stft_pad_amount = (self.n_fft -
                                self.hop_length) // 2 if exact_pad else None
        self.stft_exact_pad = stft_exact_pad
        self.stft_conv = stft_conv

        if stft_conv:
            logging.info("STFT using conv")
            if stft_exact_pad:
                logging.info("STFT using exact pad")
                self.stft = STFTExactPad(self.n_fft, self.hop_length,
                                         self.win_length, window)
            else:
                self.stft = STFTPatch(self.n_fft, self.hop_length,
                                      self.win_length, window)
        else:
            logging.info("STFT using torch")
            if exact_pad:
                logging.info("STFT using exact pad")
            torch_windows = {
                'hann': torch.hann_window,
                'hamming': torch.hamming_window,
                'blackman': torch.blackman_window,
                'bartlett': torch.bartlett_window,
                'none': None,
            }
            window_fn = torch_windows.get(window, None)
            window_tensor = window_fn(self.win_length,
                                      periodic=False) if window_fn else None
            self.register_buffer("window", window_tensor)
            self.stft = lambda x: stft_patch(
                x,
                n_fft=self.n_fft,
                hop_length=self.hop_length,
                win_length=self.win_length,
                center=not exact_pad,
                window=self.window.to(dtype=torch.float),
                return_complex=False,
            )

        self.normalize = normalize
        self.log = log
        self.dither = dither
        self.frame_splicing = frame_splicing
        self.nfilt = nfilt
        self.preemph = preemph
        self.pad_to = pad_to
        highfreq = highfreq or sample_rate / 2

        filterbanks = torch.tensor(librosa.filters.mel(sample_rate,
                                                       self.n_fft,
                                                       n_mels=nfilt,
                                                       fmin=lowfreq,
                                                       fmax=highfreq),
                                   dtype=torch.float).unsqueeze(0)
        self.register_buffer("fb", filterbanks)

        # Calculate maximum sequence length
        max_length = self.get_seq_len(
            torch.tensor(max_duration * sample_rate, dtype=torch.float))
        max_pad = pad_to - (max_length % pad_to) if pad_to > 0 else 0
        self.max_length = max_length + max_pad
        self.pad_value = pad_value
        self.mag_power = mag_power

        # We want to avoid taking the log of zero
        # There are two options: either adding or clamping to a small value
        if log_zero_guard_type not in ["add", "clamp"]:
            raise ValueError(
                f"{self} received {log_zero_guard_type} for the "
                f"log_zero_guard_type parameter. It must be either 'add' or "
                f"'clamp'.")

        self.use_grads = use_grads
        if not use_grads:
            self.forward = torch.no_grad()(self.forward)

        # log_zero_guard_value is the small value we want to use; we support
        # an actual number, or "tiny", or "eps"
        self.log_zero_guard_type = log_zero_guard_type
        logging.debug(f"sr: {sample_rate}")
        logging.debug(f"n_fft: {self.n_fft}")
        logging.debug(f"win_length: {self.win_length}")
        logging.debug(f"hop_length: {self.hop_length}")
        logging.debug(f"n_mels: {nfilt}")
        logging.debug(f"fmin: {lowfreq}")
        logging.debug(f"fmax: {highfreq}")
        logging.debug(f"using grads: {use_grads}")
Code example #5
File: features.py Project: abestern/NeMo
    def __init__(
        self,
        sample_rate=16000,
        n_window_size=320,
        n_window_stride=160,
        window="hann",
        normalize="per_feature",
        n_fft=None,
        preemph=0.97,
        nfilt=64,
        lowfreq=0,
        highfreq=None,
        log=True,
        log_zero_guard_type="add",
        log_zero_guard_value=2**-24,
        dither=CONSTANT,
        pad_to=16,
        max_duration=16.7,
        frame_splicing=1,
        stft_exact_pad=False,
        stft_conv=False,
        pad_value=0,
        mag_power=2.0,
    ):
        super().__init__()
        self.log_zero_guard_value = log_zero_guard_value
        if (n_window_size is None or n_window_stride is None
                or not isinstance(n_window_size, int)
                or not isinstance(n_window_stride, int) or n_window_size <= 0
                or n_window_stride <= 0):
            raise ValueError(
                f"{self} got an invalid value for either n_window_size or "
                f"n_window_stride. Both must be positive ints.")
        logging.info(f"PADDING: {pad_to}")

        self.win_length = n_window_size
        self.hop_length = n_window_stride
        self.n_fft = n_fft or 2**math.ceil(math.log2(self.win_length))
        self.stft_exact_pad = stft_exact_pad
        self.stft_conv = stft_conv

        if stft_conv:
            logging.info("STFT using conv")
            if stft_exact_pad:
                logging.info("STFT using exact pad")
                self.stft = STFTExactPad(self.n_fft, self.hop_length,
                                         self.win_length, window)
            else:
                self.stft = STFTPatch(self.n_fft, self.hop_length,
                                      self.win_length, window)
        else:
            logging.info("STFT using torch")
            torch_windows = {
                'hann': torch.hann_window,
                'hamming': torch.hamming_window,
                'blackman': torch.blackman_window,
                'bartlett': torch.bartlett_window,
                'none': None,
            }
            window_fn = torch_windows.get(window, None)
            window_tensor = window_fn(self.win_length,
                                      periodic=False) if window_fn else None
            self.register_buffer("window", window_tensor)
            self.stft = lambda x: stft_patch(
                x,
                n_fft=self.n_fft,
                hop_length=self.hop_length,
                win_length=self.win_length,
                center=not stft_exact_pad,
                window=self.window.to(dtype=torch.float),
                return_complex=False,
            )

        self.normalize = normalize
        self.log = log
        self.dither = dither
        self.frame_splicing = frame_splicing
        self.nfilt = nfilt
        self.preemph = preemph
        self.pad_to = pad_to
        highfreq = highfreq or sample_rate / 2

        filterbanks = torch.tensor(librosa.filters.mel(sample_rate,
                                                       self.n_fft,
                                                       n_mels=nfilt,
                                                       fmin=lowfreq,
                                                       fmax=highfreq),
                                   dtype=torch.float).unsqueeze(0)
        self.register_buffer("fb", filterbanks)

        # Calculate maximum sequence length
        max_length = self.get_seq_len(
            torch.tensor(max_duration * sample_rate, dtype=torch.float))
        max_pad = pad_to - (max_length % pad_to) if pad_to > 0 else 0
        self.max_length = max_length + max_pad
        self.pad_value = pad_value
        self.mag_power = mag_power

        # We want to avoid taking the log of zero
        # There are two options: either adding or clamping to a small value
        if log_zero_guard_type not in ["add", "clamp"]:
            raise ValueError(
                f"{self} received {log_zero_guard_type} for the "
                f"log_zero_guard_type parameter. It must be either 'add' or "
                f"'clamp'.")
        # log_zero_guard_value is the small value we want to use; we support
        # an actual number, or "tiny", or "eps"
        self.log_zero_guard_type = log_zero_guard_type
        logging.debug(f"sr: {sample_rate}")
        logging.debug(f"n_fft: {self.n_fft}")
        logging.debug(f"win_length: {self.win_length}")
        logging.debug(f"hop_length: {self.hop_length}")
        logging.debug(f"n_mels: {nfilt}")
        logging.debug(f"fmin: {lowfreq}")
        logging.debug(f"fmax: {highfreq}")
Code example #6
File: data.py Project: manneh/NeMo
    def __init__(
        self,
        manifest_filepath: str,
        sample_rate: int,
        text_tokenizer: Union[BaseTokenizer, Callable[[str], List[int]]],
        tokens: Optional[List[str]] = None,
        text_normalizer: Optional[Union[Normalizer, Callable[[str], str]]] = None,
        text_normalizer_call_args: Optional[Dict] = None,
        text_tokenizer_pad_id: Optional[int] = None,
        sup_data_types: Optional[List[str]] = None,
        sup_data_path: Optional[Union[Path, str]] = None,
        max_duration: Optional[float] = None,
        min_duration: Optional[float] = None,
        ignore_file: Optional[str] = None,
        trim: bool = False,
        n_fft=1024,
        win_length=None,
        hop_length=None,
        window="hann",
        n_mels=80,
        lowfreq=0,
        highfreq=None,
        **kwargs,
    ):
        """Dataset that loads main data types (audio and text) and specified supplementary data types (e.g. log mel, durations, pitch).
        Most supplementary data types will be computed on the fly and saved in the supplementary_folder if they did not exist before.
        Arguments for supplementary data should be also specified in this class and they will be used from kwargs (see keyword args section).
        Args:
            manifest_filepath (str, Path, List[str, Path]): Path(s) to the .json manifests containing information on the
                dataset. Each line in the .json file should be valid json. Note: the .json file itself is not valid
                json. Each line should contain the following:
                    "audio_filepath": <PATH_TO_WAV>
                    "mel_filepath": <PATH_TO_LOG_MEL_PT> (Optional)
                    "duration": <Duration of audio clip in seconds> (Optional)
                    "text": <THE_TRANSCRIPT> (Optional)
            sample_rate (int): The sample rate of the audio. Or the sample rate that we will resample all files to.
            text_tokenizer (Union[BaseTokenizer, Callable[[str], List[int]]]): BaseTokenizer or callable which represents the text tokenizer.
            tokens (Optional[List[str]]): Tokens from text_tokenizer. Should be specified if text_tokenizer is not BaseTokenizer.
            text_normalizer (Optional[Union[Normalizer, Callable[[str], str]]]): Normalizer or callable which represents text normalizer.
            text_normalizer_call_args (Optional[Dict]): Additional arguments for text_normalizer function.
            text_tokenizer_pad_id (Optional[int]): Index of padding. Should be specified if text_tokenizer is not BaseTokenizer.
            sup_data_types (Optional[List[str]]): List of supplementary data types.
            sup_data_path (Optional[Union[Path, str]]): A folder that contains or will contain supplementary data (e.g. pitch).
            max_duration (Optional[float]): Max duration of audio clips in seconds. All samples exceeding this will be
                pruned prior to training. Note: Requires "duration" to be set in the manifest file. It does not load
                audio to compute duration. Defaults to None which does not prune.
            min_duration (Optional[float]): Min duration of audio clips in seconds. All samples lower than this will be
                pruned prior to training. Note: Requires "duration" to be set in the manifest file. It does not load
                audio to compute duration. Defaults to None which does not prune.
            ignore_file (Optional[str, Path]): The location of a pickle-saved list of audio_ids (the stem of the audio
                files) that will be pruned prior to training. Defaults to None which does not prune.
            trim (Optional[bool]): Whether to apply librosa.effects.trim to the audio file. Defaults to False.
            n_fft (Optional[int]): The number of fft samples. Defaults to 1024
            win_length (Optional[int]): The length of the stft windows. Defaults to None which uses n_fft.
            hop_length (Optional[int]): The hop length between fft computations. Defaults to None which uses n_fft//4.
            window (Optional[str]): One of 'hann', 'hamming', 'blackman', 'bartlett', 'none', corresponding to the
                equivalent torch window function.
            n_mels (Optional[int]): The number of mel filters. Defaults to 80.
            lowfreq (Optional[int]): The lowfreq input to the mel filter calculation. Defaults to 0.
            highfreq (Optional[int]): The highfreq input to the mel filter calculation. Defaults to None.
        Keyword Args:
            durs_file (Optional[str]): String path to pickled durations location.
            durs_type (Optional[str]): Type of durations. Currently supported only "aligned-based".
            pitch_fmin (Optional[float]): The fmin input to librosa.pyin. Defaults to librosa.note_to_hz('C2').
            pitch_fmax (Optional[float]): The fmax input to librosa.pyin. Defaults to librosa.note_to_hz('C7').
            pitch_avg (Optional[float]): The mean that we use to normalize the pitch.
            pitch_std (Optional[float]): The std that we use to normalize the pitch.
            pitch_norm (Optional[bool]): Whether to normalize pitch (via pitch_avg and pitch_std) or not.
        """
        super().__init__()

        self.text_normalizer = text_normalizer
        self.text_normalizer_call = (
            self.text_normalizer.normalize if isinstance(
                self.text_normalizer, Normalizer) else self.text_normalizer)
        self.text_normalizer_call_args = text_normalizer_call_args

        self.text_tokenizer = text_tokenizer

        if isinstance(self.text_tokenizer, BaseTokenizer):
            self.text_tokenizer_pad_id = text_tokenizer.pad
            self.tokens = text_tokenizer.tokens
        else:
            if text_tokenizer_pad_id is None:
                raise ValueError(
                    f"text_tokenizer_pad_id must be specified if text_tokenizer is not BaseTokenizer"
                )

            if tokens is None:
                raise ValueError(
                    f"tokens must be specified if text_tokenizer is not BaseTokenizer"
                )

            self.text_tokenizer_pad_id = text_tokenizer_pad_id
            self.tokens = tokens

        if isinstance(manifest_filepath, str):
            manifest_filepath = [manifest_filepath]
        self.manifest_filepath = manifest_filepath

        if sup_data_path is not None:
            Path(sup_data_path).mkdir(parents=True, exist_ok=True)
            self.sup_data_path = sup_data_path

        self.sup_data_types = ([
            DATA_STR2DATA_CLASS[d_as_str] for d_as_str in sup_data_types
        ] if sup_data_types is not None else [])
        self.sup_data_types_set = set(self.sup_data_types)

        self.data = []
        audio_files = []
        total_duration = 0
        for manifest_file in self.manifest_filepath:
            with open(Path(manifest_file).expanduser(), 'r') as f:
                logging.info(f"Loading dataset from {manifest_file}.")
                for line in tqdm(f):
                    item = json.loads(line)

                    file_info = {
                        "audio_filepath": item["audio_filepath"],
                        "mel_filepath": item.get("mel_filepath"),
                        "duration": item.get("duration"),
                        "text_tokens": None,
                    }

                    if "text" in item:
                        text = item["text"]

                        if self.text_normalizer is not None:
                            text = self.text_normalizer_call(
                                text, **self.text_normalizer_call_args)

                        text_tokens = self.text_tokenizer(text)
                        file_info["raw_text"] = item["text"]
                        file_info["text_tokens"] = text_tokens

                    audio_files.append(file_info)

                    if file_info["duration"] is None:
                        logging.info(
                            "Not all audio files have duration information. Duration logging will be disabled."
                        )
                        total_duration = None

                    if total_duration is not None:
                        total_duration += item["duration"]

        logging.info(f"Loaded dataset with {len(audio_files)} files.")
        if total_duration is not None:
            logging.info(
                f"Dataset contains {total_duration / 3600:.2f} hours.")

        if ignore_file:
            logging.info(f"using {ignore_file} to prune dataset.")
            with open(Path(ignore_file).expanduser(), "rb") as f:
                wavs_to_ignore = set(pickle.load(f))

        pruned_duration = 0 if total_duration is not None else None
        pruned_items = 0
        for item in audio_files:
            audio_path = item['audio_filepath']
            audio_id = Path(audio_path).stem

            # Prune data according to min/max_duration & the ignore file
            if total_duration is not None:
                if (min_duration and item["duration"] < min_duration) or (
                        max_duration and item["duration"] > max_duration):
                    pruned_duration += item["duration"]
                    pruned_items += 1
                    continue

            if ignore_file and (audio_id in wavs_to_ignore):
                pruned_items += 1
                pruned_duration += item["duration"]
                wavs_to_ignore.remove(audio_id)
                continue

            self.data.append(item)

        logging.info(
            f"Pruned {pruned_items} files. Final dataset contains {len(self.data)} files"
        )
        if pruned_duration is not None:
            logging.info(
                f"Pruned {pruned_duration / 3600:.2f} hours. Final dataset contains "
                f"{(total_duration - pruned_duration) / 3600:.2f} hours.")

        self.sample_rate = sample_rate
        self.featurizer = WaveformFeaturizer(sample_rate=self.sample_rate)
        self.trim = trim

        self.n_fft = n_fft
        self.n_mels = n_mels
        self.lowfreq = lowfreq
        self.highfreq = highfreq
        self.window = window
        self.win_length = win_length or self.n_fft
        self.hop_length = hop_length
        self.hop_len = self.hop_length or self.n_fft // 4
        self.fb = torch.tensor(
            librosa.filters.mel(self.sample_rate,
                                self.n_fft,
                                n_mels=self.n_mels,
                                fmin=self.lowfreq,
                                fmax=self.highfreq),
            dtype=torch.float,
        ).unsqueeze(0)

        window_fn = {
            'hann': torch.hann_window,
            'hamming': torch.hamming_window,
            'blackman': torch.blackman_window,
            'bartlett': torch.bartlett_window,
            'none': None,
        }.get(self.window, None)

        self.stft = lambda x: stft_patch(
            input=x,
            n_fft=self.n_fft,
            hop_length=self.hop_len,
            win_length=self.win_length,
            window=window_fn(self.win_length, periodic=False).to(torch.float)
            if window_fn else None,
        )

        for data_type in self.sup_data_types:
            if data_type not in VALID_SUPPLEMENTARY_DATA_TYPES:
                raise NotImplementedError(
                    f"Current implementation of TTSDataset doesn't support {data_type} type."
                )

            getattr(self, f"add_{data_type.name}")(**kwargs)
Code example #7
File: data.py Project: mousebaiker/NeMo
    def __init__(
        self,
        manifest_filepath: str,
        sample_rate: int,
        supplementary_folder: Path,
        max_duration: Optional[float] = None,
        min_duration: Optional[float] = None,
        ignore_file: Optional[str] = None,
        trim: bool = False,
        n_fft=1024,
        win_length=None,
        hop_length=None,
        window="hann",
        n_mels=64,
        lowfreq=0,
        highfreq=None,
        pitch_fmin=80,
        pitch_fmax=640,
        pitch_avg=0,
        pitch_std=1,
        tokenize_text=True,
    ):
        """Dataset that loads audio, log mel specs, text tokens, duration / attention priors, pitches, and energies.
        Log mels, priors, pitches, and energies will be computed on the fly and saved in the supplementary_folder if
        they did not exist before.

        Args:
            manifest_filepath (str, Path, List[str, Path]): Path(s) to the .json manifests containing information on the
                dataset. Each line in the .json file should be valid json. Note: the .json file itself is not valid
                json. Each line should contain the following:
                    "audio_filepath": <PATH_TO_WAV>
                    "mel_filepath": <PATH_TO_LOG_MEL_PT> (Optional)
                    "duration": <Duration of audio clip in seconds> (Optional)
                    "text": <THE_TRANSCRIPT> (Optional)
            sample_rate (int): The sample rate of the audio. Or the sample rate that we will resample all files to.
            supplementary_folder (Path): A folder that contains or will contain extra information such as log_mel if not
                specified in the manifest .json file. It will also contain priors, pitches, and energies
            max_duration (Optional[float]): Max duration of audio clips in seconds. All samples exceeding this will be
                pruned prior to training. Note: Requires "duration" to be set in the manifest file. It does not load
                audio to compute duration. Defaults to None which does not prune.
            min_duration (Optional[float]): Min duration of audio clips in seconds. All samples lower than this will be
                pruned prior to training. Note: Requires "duration" to be set in the manifest file. It does not load
                audio to compute duration. Defaults to None which does not prune.
            ignore_file (Optional[str, Path]): The location of a pickle-saved list of audio_ids (the stem of the audio
                files) that will be pruned prior to training. Defaults to None which does not prune.
            trim (Optional[bool]): Whether to apply librosa.effects.trim to the audio file. Defaults to False.
            n_fft (Optional[int]): The number of fft samples. Defaults to 1024
            win_length (Optional[int]): The length of the stft windows. Defaults to None which uses n_fft.
            hop_length (Optional[int]): The hop length between fft computations. Defaults to None which uses n_fft//4.
            window (Optional[str]): One of 'hann', 'hamming', 'blackman', 'bartlett', 'none', corresponding to the
                equivalent torch window function.
            n_mels (Optional[int]): The number of mel filters. Defaults to 64.
            lowfreq (Optional[int]): The lowfreq input to the mel filter calculation. Defaults to 0.
            highfreq (Optional[int]): The highfreq input to the mel filter calculation. Defaults to None.
            pitch_fmin (Optional[int]): The fmin input to librosa.pyin. Defaults to 80.
            pitch_fmax (Optional[int]): The fmax input to librosa.pyin. Defaults to 640.
            pitch_avg (Optional[float]): The mean that we use to normalize the pitch. Defaults to 0.
            pitch_std (Optional[float]): The std that we use to normalize the pitch. Defaults to 1.
            tokenize_text (Optional[bool]): Whether to tokenize (turn chars into ints). Defaults to True.
        """
        super().__init__()

        self.pitch_fmin = pitch_fmin
        self.pitch_fmax = pitch_fmax
        self.pitch_avg = pitch_avg
        self.pitch_std = pitch_std
        self.win_length = win_length or n_fft
        self.sample_rate = sample_rate
        self.hop_len = hop_length or n_fft // 4

        self.parser = make_parser(name="en", do_tokenize=tokenize_text)
        self.pad_id = self.parser._blank_id
        Path(supplementary_folder).mkdir(parents=True, exist_ok=True)
        self.supplementary_folder = supplementary_folder

        audio_files = []
        total_duration = 0
        # Load data from manifests
        # Note: audio is always required, even for text -> mel_spectrogram models, due to the fact that most models
        # extract pitch from the audio
        # Note: mel_filepath is not required and if not present, we then check the supplementary folder. If we fail, we
        # compute the mel on the fly and save it to the supplementary folder
        # Note: text is not required. Any models that rely on text (spectrogram generators, end-to-end models) will
        # fail if it is not set. However, vocoders (mel -> audio) will be able to work without text
        if isinstance(manifest_filepath, str):
            manifest_filepath = [manifest_filepath]
        for manifest_file in manifest_filepath:
            with open(Path(manifest_file).expanduser(), 'r') as f:
                logging.info(f"Loading dataset from {manifest_file}.")
                for line in f:
                    item = json.loads(line)
                    # Grab audio, text, mel if they exist
                    file_info = {}
                    file_info["audio_filepath"] = item["audio_filepath"]
                    file_info["mel_filepath"] = item.get("mel_filepath")
                    file_info["duration"] = item.get("duration")
                    # Parse text
                    file_info["text_tokens"] = None
                    if "text" in item:
                        text = item["text"]
                        text_tokens = self.parser(text)
                        file_info["text_tokens"] = text_tokens
                    audio_files.append(file_info)
                    if file_info["duration"] is None:
                        logging.info(
                            "Not all audio files have duration information. Duration logging will be disabled."
                        )
                        total_duration = None
                    if total_duration is not None:
                        total_duration += item["duration"]

        logging.info(f"Loaded dataset with {len(audio_files)} files.")
        if total_duration is not None:
            logging.info(f"Dataset contains {total_duration/3600:.2f} hours.")

        self.data = []

        if ignore_file:
            logging.info(f"using {ignore_file} to prune dataset.")
            with open(Path(ignore_file).expanduser(), "rb") as f:
                wavs_to_ignore = set(pickle.load(f))

        pruned_duration = 0 if total_duration is not None else None
        pruned_items = 0
        for item in audio_files:
            audio_path = item['audio_filepath']
            audio_id = Path(audio_path).stem

            # Prune data according to min/max_duration & the ignore file
            if total_duration is not None:
                if (min_duration and item["duration"] < min_duration) or (
                        max_duration and item["duration"] > max_duration):
                    pruned_duration += item["duration"]
                    pruned_items += 1
                    continue

            if ignore_file and (audio_id in wavs_to_ignore):
                pruned_items += 1
                pruned_duration += item["duration"]
                wavs_to_ignore.remove(audio_id)
                continue

            self.data.append(item)

        logging.info(
            f"Pruned {pruned_items} files. Final dataset contains {len(self.data)} files"
        )
        if pruned_duration is not None:
            logging.info(
                f"Pruned {pruned_duration/3600:.2f} hours. Final dataset contains "
                f"{(total_duration-pruned_duration)/3600:.2f} hours.")

        self.featurizer = WaveformFeaturizer(sample_rate=sample_rate)
        self.trim = trim

        filterbanks = torch.tensor(librosa.filters.mel(sample_rate,
                                                       n_fft,
                                                       n_mels=n_mels,
                                                       fmin=lowfreq,
                                                       fmax=highfreq),
                                   dtype=torch.float).unsqueeze(0)
        self.fb = filterbanks

        torch_windows = {
            'hann': torch.hann_window,
            'hamming': torch.hamming_window,
            'blackman': torch.blackman_window,
            'bartlett': torch.bartlett_window,
            'none': None,
        }
        window_fn = torch_windows.get(window, None)
        window_tensor = window_fn(self.win_length,
                                  periodic=False) if window_fn else None

        self.stft = lambda x: stft_patch(
            input=x,
            n_fft=n_fft,
            hop_length=self.hop_len,
            win_length=self.win_length,
            window=window_tensor.to(torch.float) if window_tensor is not None else None,
        )
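The buffers set up above (self.stft and self.fb) are typically combined downstream into a log-mel computation; a minimal sketch of that step, assuming the patched STFT returns a complex spectrogram:

import torch

# Minimal sketch of the implied downstream log-mel step (an assumption about
# how self.stft and self.fb are consumed, not code from this project).
def log_mel(x, stft, fb, guard=2**-24):
    spec = stft(x)                                 # complex (B, n_fft//2 + 1, frames)
    power = spec.abs() ** 2                        # power spectrogram
    mel = torch.matmul(fb.to(power.dtype), power)  # (1, n_mels, F) @ (B, F, T)
    return torch.log(mel + guard)                  # "add"-style log-zero guard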