Example #1
0
    def __init__(self, manifest_path: str, tar_filepaths: Union[str, List[str]], shuffle_n: int = 128):
        """Index the manifest by file id and open a shuffled iterator over tarred audio.

        Args:
            manifest_path: Manifest file loaded into an ``ASRAudioText`` collection (indexed by file id).
            tar_filepaths: Tar path(s). When given as a single string, the aliases '(', '[', '<', '_OP_'
                are rewritten to '{' and ')', ']', '>', '_CL_' to '}'.
            shuffle_n: Value passed to the dataset's ``shuffle``.
        """
        empty_parser = parsers.make_parser([])
        self._manifest = collections.ASRAudioText(manifest_path, parser=empty_parser, index_by_file_id=True)

        if isinstance(tar_filepaths, str):
            # Normalise every alternative opening/closing bracket spelling to curly braces,
            # presumably for brace expansion of shard names by `wd` (confirm).
            for opener in ('(', '[', '<', '_OP_'):
                tar_filepaths = tar_filepaths.replace(opener, "{")
            for closer in (')', ']', '>', '_CL_'):
                tar_filepaths = tar_filepaths.replace(closer, "}")

        dataset = wd.Dataset(tar_filepaths)
        dataset = dataset.shuffle(shuffle_n)
        dataset = dataset.rename(audio='wav', key='__key__')
        self.audio_dataset = dataset.to_tuple('audio', 'key')
        self.audio_iter = iter(self.audio_dataset)
Example #2
0
    def __init__(
        self,
        manifest_filepath,
        labels,
        featurizer,
        max_duration=None,
        min_duration=None,
        max_utts=0,
        blank_index=-1,
        unk_index=-1,
        normalize=True,
        trim=False,
        bos_id=None,
        eos_id=None,
        load_audio=True,
        parser='en',
        add_misc=False,
    ):
        """Read ASR manifest(s) into a collection and store per-item loading options.

        Args:
            manifest_filepath: One manifest path, or several joined with commas.
            labels: Label set handed to ``parsers.make_parser``.
            featurizer: Stored as ``self.featurizer``.
            max_duration, min_duration: Duration filters passed to ``ASRAudioText``.
            max_utts: Passed to ``ASRAudioText`` as ``max_number``.
            blank_index, unk_index: Passed to the parser as ``blank_id`` / ``unk_id``.
            normalize: Passed to the parser as ``do_normalize``.
            trim, bos_id, eos_id, load_audio: Stored on the instance unchanged.
            parser: Parser name handed to ``parsers.make_parser``.
            add_misc: Stored as ``self._add_misc``.
        """
        manifest_files = manifest_filepath.split(',')
        text_parser = parsers.make_parser(
            labels=labels, name=parser, unk_id=unk_index, blank_id=blank_index, do_normalize=normalize,
        )
        self.collection = collections.ASRAudioText(
            manifests_files=manifest_files,
            parser=text_parser,
            min_duration=min_duration,
            max_duration=max_duration,
            max_number=max_utts,
        )

        self.featurizer = featurizer
        self.trim = trim
        self.bos_id, self.eos_id = bos_id, eos_id
        self.load_audio = load_audio
        self._add_misc = add_misc
Example #3
0
    def __init__(self, manifest_path: str, tar_filepaths: Union[str, List[str]], shuffle_n: int = 128):
        """Index the manifest and build an iterator of ``(audio, key)`` tuples from tarred audio.

        Args:
            manifest_path: Manifest loaded into an ``ASRAudioText`` collection indexed by file id.
            tar_filepaths: Tar URL(s). A single string has '(', '[', '<', '_OP_' rewritten to '{'
                and ')', ']', '>', '_CL_' rewritten to '}'.
            shuffle_n: Shuffle buffer size; values <= 0 disable shuffling (an info message is logged).
        """
        self._manifest = collections.ASRAudioText(
            manifest_path, parser=parsers.make_parser([]), index_by_file_id=True
        )

        if isinstance(tar_filepaths, str):
            # All opening aliases first, then all closing aliases.
            brace_aliases = (
                ('(', '{'), ('[', '{'), ('<', '{'), ('_OP_', '{'),
                (')', '}'), (']', '}'), ('>', '}'), ('_CL_', '}'),
            )
            for alias, brace in brace_aliases:
                tar_filepaths = tar_filepaths.replace(alias, brace)

        dataset = wd.WebDataset(urls=tar_filepaths, nodesplitter=None)
        if shuffle_n <= 0:
            logging.info("WebDataset will not shuffle files within the tar files.")
        else:
            dataset = dataset.shuffle(shuffle_n)

        self.audio_dataset = dataset.rename(audio='wav', key='__key__').to_tuple('audio', 'key')
        self.audio_iter = iter(self.audio_dataset)
Example #4
0
    def __init__(self, cfg: DictConfig, trainer: 'Trainer' = None):
        """Set up the FastSpeech2 sub-modules, losses and inference-time text mappings.

        Args:
            cfg: Model config as a DictConfig or a plain dict (a dict is wrapped with ``OmegaConf.create``).
            trainer: Optional trainer forwarded to the parent constructor.

        Raises:
            ValueError: If ``cfg`` is neither a dict nor a DictConfig at the schema check.
        """
        if isinstance(cfg, dict):
            cfg = OmegaConf.create(cfg)
        super().__init__(cfg=cfg, trainer=trainer)

        schema = OmegaConf.structured(FastSpeech2Config)
        # Re-validate the type of the local cfg: the dict branch was already handled above,
        # the elif rejects anything that is neither a dict nor a DictConfig.
        if isinstance(cfg, dict):
            cfg = OmegaConf.create(cfg)
        elif not isinstance(cfg, DictConfig):
            raise ValueError(f"cfg was type: {type(cfg)}. Expected either a dict or a DictConfig")
        # The merged config is discarded; the call is kept purely as a compliance check against the schema.
        OmegaConf.merge(cfg, schema)

        self.pitch = cfg.add_pitch_predictor
        self.energy = cfg.add_energy_predictor
        self.duration_coeff = cfg.duration_coeff

        # Sub-modules come from self._cfg (presumably populated by the parent constructor). Keep this
        # order: instantiation order may affect random weight initialisation.
        model_cfg = self._cfg
        self.audio_to_melspec_preprocessor = instantiate(model_cfg.preprocessor)
        self.encoder = instantiate(model_cfg.encoder)
        self.mel_decoder = instantiate(model_cfg.decoder)
        self.variance_adapter = instantiate(model_cfg.variance_adaptor)

        self.loss = L2MelLoss()
        self.mseloss = torch.nn.MSELoss()
        self.durationloss = DurationLoss()

        self.log_train_images = False

        # Parser and mappings are used for inference only.
        self.parser = parsers.make_parser(name='en')
        with open(cfg.mappings_filepath, 'r') as mappings_file:
            mappings = json.load(mappings_file)
        self.word2phones = mappings['word2phones']
        self.phone2idx = mappings['phone2idx']
Example #5
0
    def __init__(
        self,
        manifest_path=None,
        min_snr_db=10,
        max_snr_db=50,
        max_gain_db=300.0,
        rng=None,
        audio_tar_filepaths=None,
        shuffle_n=100,
        orig_sr=16000,
    ):
        """Load the manifest, optionally attach tarred audio, and store SNR/gain limits.

        Args:
            manifest_path: Manifest loaded into an ``ASRAudioText`` collection indexed by file id.
            min_snr_db, max_snr_db, max_gain_db: Stored as private limits.
            rng: Random generator to use; a fresh ``random.Random()`` when None.
            audio_tar_filepaths: If truthy, an ``AugmentationDataset`` over these tars is built and iterated.
            shuffle_n: Forwarded to ``AugmentationDataset``.
            orig_sr: Stored as ``self._orig_sr``.
        """
        self._manifest = collections.ASRAudioText(manifest_path, parser=parsers.make_parser([]), index_by_file_id=True)
        self._orig_sr = orig_sr

        self._tarred_audio = bool(audio_tar_filepaths)
        if self._tarred_audio:
            self._audiodataset = AugmentationDataset(manifest_path, audio_tar_filepaths, shuffle_n)
            self._data_iterator = iter(self._audiodataset)
        else:
            self._audiodataset = None
            self._data_iterator = None

        self._rng = rng if rng is not None else random.Random()
        self._min_snr_db, self._max_snr_db = min_snr_db, max_snr_db
        self._max_gain_db = max_gain_db
def main():
    """Fetch the Tacotron2 LJSpeech file lists and write NeMo JSON-lines manifests.

    For each split ('train', 'val', 'test') this:
      * downloads ``ljs_audio_text_{split}_filelist.txt`` into ``args.ljspeech_base`` if missing,
      * normalises every transcript with NeMo's English parser,
      * writes the normalised transcript next to its wav as ``wavs/<basename>.txt``,
      * writes one ``{"audio_filepath", "duration", "text"}`` entry per line to ``ljspeech_{split}.json``.

    Relies on a module-level ``args`` namespace providing ``ljspeech_base``.

    Raises:
        FileNotFoundError: If a wav referenced by a file list does not exist.
    """
    # No trailing slash: a single '/' is inserted when the URL is built (a trailing slash here
    # used to produce '.../filelists//ljs_...').
    filelist_base = 'https://raw.githubusercontent.com/NVIDIA/tacotron2/master/filelists'
    filelists = ['train', 'val', 'test']

    # NeMo parser for text normalization
    text_parser = parsers.make_parser(name='en')
    wavs_dir = os.path.join(args.ljspeech_base, 'wavs')

    for split in filelists:
        # Download file list if necessary
        filelist_name = f"ljs_audio_text_{split}_filelist.txt"
        filelist_path = os.path.join(args.ljspeech_base, filelist_name)
        if not os.path.exists(filelist_path):
            wget.download(f"{filelist_base}/{filelist_name}", out=args.ljspeech_base)

        manifest_target = os.path.join(args.ljspeech_base, f"ljspeech_{split}.json")
        with open(manifest_target, 'w', encoding='utf-8') as f_out:
            with open(filelist_path, 'r', encoding='utf-8') as filelist:
                print(f"\nCreating {manifest_target}...")
                for line in filelist:
                    line = line.strip()
                    if not line:
                        continue  # tolerate blank (e.g. trailing) lines
                    # Entries are "<wav path>|<transcript>" (presumably "DUMMY/LJ001-0001.wav|..." in these
                    # lists); split on the first '|' instead of relying on fixed character offsets.
                    audio_ref, _, raw_text = line.partition('|')
                    basename = os.path.splitext(os.path.basename(audio_ref))[0]
                    text = text_parser._normalize(raw_text.strip())

                    # Make sure corresponding wavfile exists and write .txt transcript
                    wav_path = os.path.join(wavs_dir, basename + '.wav')
                    if not os.path.exists(wav_path):
                        raise FileNotFoundError(f"Audio file listed in {filelist_path} not found: {wav_path}")
                    txt_path = os.path.join(wavs_dir, basename + '.txt')
                    with open(txt_path, 'w', encoding='utf-8') as f_txt:
                        f_txt.write(text)

                    # Write manifest entry
                    entry = {
                        'audio_filepath': wav_path,
                        'duration': sox.file_info.duration(wav_path),
                        'text': text,
                    }
                    f_out.write(json.dumps(entry) + '\n')
Example #7
0
    def _loader(self, cfg):
        """Build a DataLoader over a ``FastPitchDataset`` described by ``cfg``.

        Required keys: ``manifest_filepath``, ``sample_rate``, ``batch_size``, ``shuffle``.
        Optional keys (default): ``int_values`` (False), ``max_duration`` (None), ``min_duration`` (None),
        ``max_utts`` (0), ``trim_silence`` (True, forwarded as ``trim``), ``load_audio`` (True),
        ``add_misc`` (False), ``drop_last`` (True), ``num_workers`` (16).
        """
        text_parser = parsers.make_parser(
            labels=self._cfg.labels,
            name='en',
            unk_id=-1,
            blank_id=-1,
            do_normalize=True,
            abbreviation_version="fastpitch",
            make_table=False,
        )

        # dataset keyword -> (config key, default used when the key is absent)
        optional_settings = {
            'int_values': ('int_values', False),
            'max_duration': ('max_duration', None),
            'min_duration': ('min_duration', None),
            'max_utts': ('max_utts', 0),
            'trim': ('trim_silence', True),
            'load_audio': ('load_audio', True),
            'add_misc': ('add_misc', False),
        }
        dataset = FastPitchDataset(
            manifest_filepath=cfg['manifest_filepath'],
            parser=text_parser,
            sample_rate=cfg['sample_rate'],
            **{arg: cfg.get(key, default) for arg, (key, default) in optional_settings.items()},
        )

        loader = torch.utils.data.DataLoader(
            dataset=dataset,
            batch_size=cfg['batch_size'],
            collate_fn=dataset.collate_fn,
            drop_last=cfg.get('drop_last', True),
            shuffle=cfg['shuffle'],
            num_workers=cfg.get('num_workers', 16),
        )
        return loader
Example #8
0
    def __init__(
        self,
        audio_tar_filepaths: Union[str, List[str]],
        manifest_filepath: str,
        labels: List[str],
        sample_rate: int,
        int_values: bool = False,
        augmentor: Optional['nemo.collections.asr.parts.perturb.AudioAugmentor'] = None,
        shuffle_n: int = 0,
        min_duration: Optional[float] = None,
        max_duration: Optional[float] = None,
        max_utts: int = 0,
        blank_index: int = -1,
        unk_index: int = -1,
        normalize: bool = True,
        trim: bool = False,
        bos_id: Optional[int] = None,
        eos_id: Optional[int] = None,
        parser: Optional[str] = 'en',
        add_misc: bool = False,
        pad_id: int = 0,
        shard_strategy: str = "scatter",
        global_rank: int = 0,
        world_size: int = 0,
    ):
        """Tarred dataset that turns ``labels`` into a text parser and defers everything else to the parent.

        ``parser`` (a parser name), ``unk_index``, ``blank_index`` and ``normalize`` configure the parser
        built here; all remaining arguments are forwarded to the parent constructor unchanged.
        """
        self.labels = labels

        text_parser = parsers.make_parser(
            labels=labels, name=parser, unk_id=unk_index, blank_id=blank_index, do_normalize=normalize,
        )

        forwarded = dict(
            audio_tar_filepaths=audio_tar_filepaths,
            manifest_filepath=manifest_filepath,
            parser=text_parser,
            sample_rate=sample_rate,
            int_values=int_values,
            augmentor=augmentor,
            shuffle_n=shuffle_n,
            min_duration=min_duration,
            max_duration=max_duration,
            max_utts=max_utts,
            trim=trim,
            bos_id=bos_id,
            eos_id=eos_id,
            add_misc=add_misc,
            pad_id=pad_id,
            shard_strategy=shard_strategy,
            global_rank=global_rank,
            world_size=world_size,
        )
        super().__init__(**forwarded)
Example #9
0
    def __init__(
        self,
        manifest_filepath: Union[str, 'pathlib.Path'],
        n_segments: int,
        max_duration: Optional[float] = None,
        min_duration: Optional[float] = None,
        trim: Optional[bool] = False,
        truncate_to: Optional[int] = 1,
    ):
        """
        See AudioDataset for details on dataset and manifest formats.

        Unlike the regular AudioDataset, which samples random segments from each audio array as an example,
        SplicedAudioDataset concatenates all audio arrays together and indexes segments as examples. This way,
        the model sees more data (about 9x for LJSpeech) per epoch.

        Note: this class is not recommended to be used in validation.

        Args:
            manifest_filepath (str, Path): Path to manifest json as described above. Can be comma-separated paths
                such as "train_1.json,train_2.json" which is treated as two separate json files.
            n_segments (int): The length of audio in samples per example. Must be a positive integer: the
                concatenated audio is truncated to a whole multiple of it. (Unlike AudioDataset, -1 is not
                supported here.)
            max_duration (float): If audio exceeds this length in seconds, it is filtered from the dataset.
                Defaults to None, which does not filter any audio.
            min_duration(float): If audio is less than this length in seconds, it is filtered from the dataset.
                Defaults to None, which does not filter any audio.
            trim (bool): Whether to use librosa.effects.trim on the audio clip
            truncate_to (int): Ensures that the audio segment returned is a multiple of truncate_to.
                Defaults to 1, which does no truncating.

        Raises:
            ValueError: If ``n_segments`` is not positive, or if the manifest(s) yield no audio files.
        """
        # Raise (rather than assert) so the check survives ``python -O``; a non-positive value would
        # break the modulo-based truncation below.
        if n_segments <= 0:
            raise ValueError(f"n_segments must be a positive integer, got {n_segments}")

        collection = collections.ASRAudioText(
            # str() so a pathlib.Path (as the annotation allows) works; Path has no .split().
            manifests_files=str(manifest_filepath).split(','),
            parser=parsers.make_parser(),
            min_duration=min_duration,
            max_duration=max_duration,
        )
        self.trim = trim
        self.n_segments = n_segments
        self.truncate_to = truncate_to

        clips = []
        for index in range(len(collection)):
            with sf.SoundFile(collection[index].audio_file, 'r') as f:
                clips.append(f.read(dtype='float32').transpose())
        if not clips:
            raise ValueError(f"No audio files found in manifest(s): {manifest_filepath}")

        # Splice all clips into one array, then drop the tail so it divides evenly into n_segments chunks.
        samples = np.concatenate(clips, axis=0)
        usable_length = (samples.shape[0] // n_segments) * n_segments
        self.samples = samples[:usable_length, ...]
Example #10
0
    def __init__(self, cfg: DictConfig, trainer: 'Trainer' = None):
        """Instantiate sub-modules, discriminators, losses and inference-time text mappings from ``cfg``.

        Args:
            cfg: Model config as a DictConfig or plain dict (a dict is wrapped with ``OmegaConf.create``).
                Must contain ``mappings_filepath``.
            trainer: Optional trainer forwarded to the parent constructor.

        Raises:
            ValueError: If ``cfg`` has no ``mappings_filepath`` entry.
        """
        if isinstance(cfg, dict):
            cfg = OmegaConf.create(cfg)
        super().__init__(cfg=cfg, trainer=trainer)

        # NOTE(review): 'precessor' spelling kept as-is; the attribute name may be referenced elsewhere.
        # Instantiation order is preserved, as it may affect random weight initialisation.
        self.audio_to_melspec_precessor = instantiate(cfg.preprocessor)
        self.encoder = instantiate(cfg.encoder)
        self.variance_adapter = instantiate(cfg.variance_adaptor)

        self.generator = instantiate(cfg.generator)
        self.multiperioddisc = MultiPeriodDiscriminator()
        self.multiscaledisc = MultiScaleDiscriminator()

        # A second preprocessor instance with highfreq=None and use_grads=True overrides.
        self.melspec_fn = instantiate(cfg.preprocessor, highfreq=None, use_grads=True)
        self.mel_val_loss = L1MelLoss()
        self.durationloss = DurationLoss()
        self.feat_matching_loss = FeatureMatchingLoss()
        self.disc_loss = DiscriminatorLoss()
        self.gen_loss = GeneratorLoss()
        self.mseloss = torch.nn.MSELoss()

        self.energy = cfg.add_energy_predictor
        self.pitch = cfg.add_pitch_predictor
        self.mel_loss_coeff = cfg.mel_loss_coeff
        self.pitch_loss_coeff = cfg.pitch_loss_coeff
        self.energy_loss_coeff = cfg.energy_loss_coeff
        self.splice_length = cfg.splice_length

        self.use_energy_pred = False
        self.use_pitch_pred = False
        self.log_train_images = False
        self.logged_real_samples = False
        self._tb_logger = None
        self.sample_rate = cfg.sample_rate
        self.hop_size = cfg.hop_size

        # Parser and mappings are used for inference only.
        self.parser = parsers.make_parser(name='en')
        if 'mappings_filepath' not in cfg:
            message = "ERROR: You must specify a mappings.json file in the config file under model.mappings_filepath."
            logging.error(message)
            # Fail explicitly; previously execution continued into an UnboundLocalError below.
            raise ValueError(message)
        mappings_filepath = self.register_artifact('mappings_filepath', cfg.get('mappings_filepath'))
        # JSON text is UTF-8 by specification; don't depend on the platform's default encoding.
        with open(mappings_filepath, 'r', encoding='utf-8') as f:
            mappings = json.load(f)
        self.word2phones = mappings['word2phones']
        self.phone2idx = mappings['phone2idx']
Example #11
0
    def __init__(self, manifest_path=None, rng=None, audio_tar_filepaths=None, shuffle_n=128, shift_impulse=False):
        """Load the manifest and optionally attach tarred audio for impulse sampling.

        Args:
            manifest_path: Manifest loaded into an ``ASRAudioText`` collection indexed by file id.
            rng: Random generator; a fresh ``random.Random()`` when None.
            audio_tar_filepaths: If truthy, an ``AugmentationDataset`` over these tars is built and iterated.
            shuffle_n: Forwarded to ``AugmentationDataset``.
            shift_impulse: Stored as ``self._shift_impulse``.
        """
        manifest_parser = parsers.make_parser([])
        self._manifest = collections.ASRAudioText(manifest_path, parser=manifest_parser, index_by_file_id=True)
        self._shift_impulse = shift_impulse

        if not audio_tar_filepaths:
            self._tarred_audio = False
            self._audiodataset = None
            self._data_iterator = None
        else:
            self._tarred_audio = True
            self._audiodataset = AugmentationDataset(manifest_path, audio_tar_filepaths, shuffle_n)
            self._data_iterator = iter(self._audiodataset)

        if rng is None:
            rng = random.Random()
        self._rng = rng
Example #12
0
    def parser(self):
        """Return the English text parser, building and caching it on first use.

        The parser uses ``self._cfg.labels`` with unk/blank ids of -1, normalisation on, the
        "fastpitch" abbreviation set and ``make_table=False``.
        """
        if self._parser is None:
            self._parser = parsers.make_parser(
                labels=self._cfg.labels,
                name='en',
                unk_id=-1,
                blank_id=-1,
                do_normalize=True,
                abbreviation_version="fastpitch",
                make_table=False,
            )
        return self._parser
Example #13
0
File: perturb.py Project: vsl9/NeMo
 def __init__(
     self,
     manifest_path=None,
     min_snr_db=40,
     max_snr_db=50,
     max_gain_db=300.0,
     rng=None,
 ):
     """Load the manifest and store SNR / gain limits.

     Args:
         manifest_path: Manifest loaded into an ``ASRAudioText`` collection (presumably of noise
             clips to mix in - confirm against callers).
         min_snr_db, max_snr_db, max_gain_db: Stored as private limits.
         rng: Random generator; a fresh ``random.Random()`` when None.
     """
     self._manifest = collections.ASRAudioText(manifest_path, parser=parsers.make_parser([]))
     if rng is None:
         rng = random.Random()
     self._rng = rng
     self._min_snr_db, self._max_snr_db = min_snr_db, max_snr_db
     self._max_gain_db = max_gain_db
Example #14
0
    def __init__(
        self,
        manifest_filepath: str,
        labels: Union[str, List[str]],
        sample_rate: int,
        int_values: bool = False,
        augmentor: 'nemo.collections.asr.parts.perturb.AudioAugmentor' = None,
        max_duration: Optional[float] = None,
        min_duration: Optional[float] = None,
        max_utts: int = 0,
        blank_index: int = -1,
        unk_index: int = -1,
        normalize: bool = True,
        trim: bool = False,
        bos_id: Optional[int] = None,
        eos_id: Optional[int] = None,
        pad_id: int = 0,
        load_audio: bool = True,
        parser: Union[str, Callable] = 'en',
        add_misc: bool = False,
    ):
        """Dataset that turns ``labels`` into a text parser and forwards everything else to the parent.

        ``parser``, ``unk_index``, ``blank_index`` and ``normalize`` configure the parser built here
        via ``parsers.make_parser``; the remaining arguments are passed to the parent unchanged.
        """
        self.labels = labels

        text_parser = parsers.make_parser(
            labels=labels, name=parser, unk_id=unk_index, blank_id=blank_index, do_normalize=normalize,
        )

        forwarded = dict(
            manifest_filepath=manifest_filepath,
            parser=text_parser,
            sample_rate=sample_rate,
            int_values=int_values,
            augmentor=augmentor,
            max_duration=max_duration,
            min_duration=min_duration,
            max_utts=max_utts,
            trim=trim,
            bos_id=bos_id,
            eos_id=eos_id,
            pad_id=pad_id,
            load_audio=load_audio,
            add_misc=add_misc,
        )
        super().__init__(**forwarded)
Example #15
0
 def __init__(
     self,
     manifest_filepath,
     n_segments=0,
     max_duration=None,
     min_duration=None,
     trim=False,
 ):
     """See AudioDataLayer.

     Loads the comma-separated manifest file(s) into an ``ASRAudioText`` collection filtered by
     ``min_duration`` / ``max_duration``, and stores ``trim`` and ``n_segments`` unchanged.
     """
     manifests = manifest_filepath.split(',')
     self.collection = collections.ASRAudioText(
         manifests_files=manifests,
         parser=parsers.make_parser(),
         min_duration=min_duration,
         max_duration=max_duration,
     )
     self.trim, self.n_segments = trim, n_segments
Example #16
0
    def __init__(
        self,
        manifest_filepath: Union[str, "pathlib.Path"],
        n_segments: int,
        max_duration: Optional[float] = None,
        min_duration: Optional[float] = None,
        trim: Optional[bool] = False,
        truncate_to: Optional[int] = 1,
    ):
        """
        Mostly compliant with nemo.collections.asr.data.datalayers.AudioToTextDataset except it only returns Audio
        without text. Dataset that loads tensors via a json file containing paths to audio files, transcripts, and
        durations (in seconds). Each new line is a different sample. Note that text is required, but is ignored for
        AudioDataset. Example below:

            {"audio_filepath": "/path/to/audio.wav", "text_filepath": "/path/to/audio.txt", "duration": 23.147}
            ...
            {"audio_filepath": "/path/to/audio.wav", "text": "the transcription", "offset": 301.75,
             "duration": 0.82, "utt": "utterance_id", "ctm_utt": "en_4156", "side": "A"}

        Args:
            manifest_filepath (str, Path): Path to manifest json as described above. Can be comma-separated paths
                such as "train_1.json,train_2.json" which is treated as two separate json files.
            n_segments (int): The length of audio in samples to load. For example, given a sample rate of 16kHz, and
                n_segments=16000, a random 1 second section of audio from the clip will be loaded. The section will
                be randomly sampled every time the audio is batched. Can be set to -1 to load the entire audio.
            max_duration (float): If audio exceeds this length in seconds, it is filtered from the dataset.
                Defaults to None, which does not filter any audio.
            min_duration(float): If audio is less than this length in seconds, it is filtered from the dataset.
                Defaults to None, which does not filter any audio.
            trim (bool): Whether to use librosa.effects.trim on the audio clip
            truncate_to (int): Ensures that the audio segment returned is a multiple of truncate_to.
                Defaults to 1, which does no truncating.
        """
        # str() so a pathlib.Path (as the annotation allows) works; Path objects have no .split().
        manifest_files = str(manifest_filepath).split(",")
        self.collection = collections.ASRAudioText(
            manifests_files=manifest_files,
            parser=parsers.make_parser(),
            min_duration=min_duration,
            max_duration=max_duration,
        )
        self.trim = trim
        self.n_segments = n_segments
        self.truncate_to = truncate_to
Example #17
0
    def parser(self):
        """Return a text parser, reusing an existing one when possible.

        Resolution order:
          1. ``self._parser`` if already set.
          2. ``dataset.parser`` of the first set dataloader among validation, test, train
             (this result is not cached in ``self._parser``).
          3. Otherwise build one with ``parsers.make_parser`` from the first non-empty
             ``dataset.params`` among ``validation_ds``, ``test_ds``, ``train_ds`` in ``self._cfg``,
             cache it in ``self._parser`` and return it.

        When a setting is missing or None the defaults are: parser name ``'en'``, ``unk_id=-1``,
        ``blank_id=-1``, ``do_normalize=False``.
        """
        if self._parser is not None:
            return self._parser
        for dl_attr in ('_validation_dl', '_test_dl', '_train_dl'):
            dataloader = getattr(self, dl_attr)
            if dataloader is not None:
                return dataloader.dataset.parser

        # Else construct a parser from the params of the first split that defines non-empty ones.
        params = {}
        for split in ('validation_ds', 'test_ds', 'train_ds'):
            try:
                params = getattr(self._cfg, split).dataset.params
            except ConfigAttributeError:
                continue
            if params != {}:
                break

        name = params.get('parser', None) or 'en'
        # Explicit None checks: `value or -1` silently turned a configured index of 0 into -1.
        unk_id = params.get('unk_index', None)
        if unk_id is None:
            unk_id = -1
        blank_id = params.get('blank_index', None)
        if blank_id is None:
            blank_id = -1
        do_normalize = params.get('normalize', None) or False
        self._parser = parsers.make_parser(
            labels=self._cfg.labels,
            name=name,
            unk_id=unk_id,
            blank_id=blank_id,
            do_normalize=do_normalize,
        )
        return self._parser
Example #18
0
    def __init__(
        self,
        text_batch: List[str],
        labels: List[str],
        bos_id: Optional[int] = None,
        eos_id: Optional[int] = None,
        lowercase: bool = True,
    ) -> None:
        """Wrap a batch of texts for TextDataLayer, parsed with a label-based parser.

        Args:
            text_batch: Texts to be used for speech synthesis.
            labels: Label set used by the parser for string-to-int translation.
            bos_id: Label position of the beginning-of-string symbol (stored only).
            eos_id: Label position of the end-of-string symbol (stored only).
            lowercase: Passed to the parser as ``do_lowercase``.
        """
        text_parser = parsers.make_parser(labels, do_lowercase=lowercase)
        self.texts = collections.Text(text_batch, text_parser)
        self.bos_id, self.eos_id = bos_id, eos_id
Example #19
0
    def __init__(
        self,
        manifest_filepath: str,
        device: str,
        batch_size: int,
        labels: Union[str, List[str]],
        sample_rate: int = 16000,
        num_threads: int = 4,
        max_duration: float = 0.0,
        min_duration: float = 0.0,
        blank_index: int = -1,
        unk_index: int = -1,
        normalize: bool = True,
        bos_id: Optional[int] = None,
        eos_id: Optional[int] = None,
        trim: bool = False,
        shuffle: bool = True,
        drop_last: bool = False,
        parser: Union[str, Callable] = 'en',
        device_id: int = 0,
        global_rank: int = 0,
        world_size: int = 1,
        preprocessor_cfg: DictConfig = None,
    ):
        if not HAVE_DALI:
            raise ModuleNotFoundError(
                f"{self} requires NVIDIA DALI to be installed. "
                f"See: https://docs.nvidia.com/deeplearning/dali/user-guide/docs/installation.html#id1"
            )

        if device not in ('cpu', 'gpu'):
            raise ValueError(
                f"{self} received an unexpected device argument {device}. Supported values are: 'cpu', 'gpu'"
            )

        self.batch_size = batch_size  # Used by NeMo

        self.device = device
        self.device_id = device_id

        if world_size > 1:
            self.shard_id = global_rank
            self.num_shards = world_size
        else:
            self.shard_id = None
            self.num_shards = None

        self.labels = labels
        if self.labels is None or len(self.labels) == 0:
            raise ValueError(f"{self} expects non empty labels list")

        self.parser = parsers.make_parser(
            labels=labels,
            name=parser,
            unk_id=unk_index,
            blank_id=blank_index,
            do_normalize=normalize,
        )

        self.eos_id = eos_id
        self.bos_id = bos_id
        self.sample_rate = sample_rate

        self.pipe = Pipeline(
            batch_size=batch_size,
            num_threads=num_threads,
            device_id=self.device_id,
            exec_async=True,
            exec_pipelined=True,
        )

        has_preprocessor = preprocessor_cfg is not None
        if has_preprocessor:
            if preprocessor_cfg.cls == "nemo.collections.asr.modules.AudioToMelSpectrogramPreprocessor":
                feature_type = "mel_spectrogram"
            elif preprocessor_cfg.cls == "nemo.collections.asr.modules.AudioToMFCCPreprocessor":
                feature_type = "mfcc"
            else:
                raise ValueError(
                    f"{self} received an unexpected preprocessor configuration: {preprocessor_cfg.cls}."
                    f" Supported preprocessors are: AudioToMelSpectrogramPreprocessor, AudioToMFCCPreprocessor"
                )

            # Default values taken from AudioToMelSpectrogramPreprocessor
            params = preprocessor_cfg.params
            self.dither = params['dither'] if 'dither' in params else 0.0
            self.preemph = params['preemph'] if 'preemph' in params else 0.97
            self.window_size_sec = params[
                'window_size'] if 'window_size' in params else 0.02
            self.window_stride_sec = params[
                'window_stride'] if 'window_stride' in params else 0.01
            self.sample_rate = params[
                'sample_rate'] if 'sample_rate' in params else sample_rate
            self.window_size = int(self.window_size_sec * self.sample_rate)
            self.window_stride = int(self.window_size_sec * self.sample_rate)

            normalize = params[
                'normalize'] if 'normalize' in params else 'per_feature'
            if normalize == 'per_feature':  # Each freq channel independently
                self.normalization_axes = (1, )
            elif normalize == 'all_features':
                self.normalization_axes = (0, 1)
            else:
                raise ValueError(
                    f"{self} received {normalize} for the normalize parameter."
                    f" It must be either 'per_feature' or 'all_features'.")

            self.window = None
            window_name = params['window'] if 'window' in params else None
            torch_windows = {
                'hamming': torch.hamming_window,
                'blackman': torch.blackman_window,
                'bartlett': torch.bartlett_window,
            }
            if window_name is None or window_name == 'hann':
                self.window = None  # Hann is DALI's default
            elif window_name == 'ones':
                self.window = torch.ones(self.window_size)
            else:
                try:
                    window_fn = torch_windows.get(window_name, None)
                    self.window = window_fn(self.window_size, periodic=False)
                except:
                    raise ValueError(
                        f"{self} received {window_name} for the window parameter."
                        f" It must be one of: ('hann', 'ones', 'hamming', 'blackman', 'bartlett', None)."
                        f" None is equivalent to 'hann'.")

            self.n_fft = params[
                'n_fft'] if 'n_fft' in params else None  # None means default
            self.n_mels = params['n_mels'] if 'n_mels' in params else 64
            self.n_mfcc = params['n_mfcc'] if 'n_mfcc' in params else 64

            features = params['features'] if 'features' in params else 0
            if features > 0:
                if feature_type == 'mel_spectrogram':
                    self.n_mels = features
                elif feature_type == 'mfcc':
                    self.n_mfcc = features

            # TODO Implement frame splicing
            if 'frame_splicing' in params:
                assert params[
                    'frame_splicing'] == 1, "Frame splicing is not implemented"

            self.freq_low = params['lowfreq'] if 'lowfreq' in params else 0.0
            self.freq_high = params[
                'highfreq'] if 'highfreq' in params else self.sample_rate / 2.0
            self.log_features = params['log'] if 'log' in params else True

            # We want to avoid taking the log of zero
            # There are two options: either adding or clamping to a small value

            self.log_zero_guard_type = params[
                'log_zero_guard_type'] if 'log_zero_guard_type' in params else 'add'
            if self.log_zero_guard_type not in ["add", "clamp"]:
                raise ValueError(
                    f"{self} received {self.log_zero_guard_type} for the "
                    f"log_zero_guard_type parameter. It must be either 'add' or "
                    f"'clamp'.")

            self.log_zero_guard_value = params[
                'log_zero_guard_value'] if 'log_zero_guard_value' in params else 1e-05
            if isinstance(self.log_zero_guard_value, str):
                if self.log_zero_guard_value == "tiny":
                    self.log_zero_guard_value = torch.finfo(torch.float32).tiny
                elif self.log_zero_guard_value == "eps":
                    self.log_zero_guard_value = torch.finfo(torch.float32).eps
                else:
                    raise ValueError(
                        f"{self} received {self.log_zero_guard_value} for the log_zero_guard_type parameter."
                        f"It must be either a number, 'tiny', or 'eps'")

            self.mag_power = params['mag_power'] if 'mag_power' in params else 2
            if self.mag_power != 1.0 and self.mag_power != 2.0:
                raise ValueError(
                    f"{self} received {self.mag_power} for the mag_power parameter."
                    f" It must be either 1.0 or 2.0.")

            self.pad_to = params['pad_to'] if 'pad_to' in params else 16
            self.pad_value = params[
                'pad_value'] if 'pad_value' in params else 0.0

        with self.pipe:
            audio, transcript = dali.fn.nemo_asr_reader(
                name="Reader",
                manifest_filepaths=manifest_filepath.split(','),
                dtype=dali.types.FLOAT,
                downmix=True,
                sample_rate=float(self.sample_rate),
                min_duration=min_duration,
                max_duration=max_duration,
                read_sample_rate=False,
                read_text=True,
                random_shuffle=shuffle,
                shard_id=self.shard_id,
                num_shards=self.num_shards,
                pad_last_batch=True,
            )

            transcript_len = dali.fn.shapes(
                dali.fn.reshape(transcript, shape=[-1]))
            transcript = dali.fn.pad(transcript)

            # Extract nonsilent region, if necessary
            if trim:
                # Need to extract non-silent region before moving to the GPU
                roi_start, roi_len = dali.fn.nonsilent_region(audio,
                                                              cutoff_db=-60)
                audio = audio.gpu() if self.device == 'gpu' else audio
                audio = dali.fn.slice(audio,
                                      roi_start,
                                      roi_len,
                                      normalized_anchor=False,
                                      normalized_shape=False,
                                      axes=[0])
            else:
                audio = audio.gpu() if self.device == 'gpu' else audio

            if not has_preprocessor:
                # No preprocessing, the output is the audio signal
                audio = dali.fn.pad(audio)
                audio_len = dali.fn.shapes(dali.fn.reshape(audio, shape=[-1]))
                self.pipe.set_outputs(audio, audio_len, transcript,
                                      transcript_len)
            else:
                # Additive gaussian noise (dither)
                if self.dither > 0.0:
                    gaussian_noise = dali.fn.normal_distribution(
                        device=self.device)
                    audio = audio + self.dither * gaussian_noise

                # Preemphasis filter
                if self.preemph > 0.0:
                    audio = dali.fn.preemphasis_filter(
                        audio, preemph_coeff=self.preemph)

                # Power spectrogram
                spec = dali.fn.spectrogram(audio,
                                           nfft=self.n_fft,
                                           window_length=self.window_size,
                                           window_step=self.window_stride)

                if feature_type == 'mel_spectrogram' or feature_type == 'mfcc':
                    # Spectrogram to Mel Spectrogram
                    spec = dali.fn.mel_filter_bank(
                        spec,
                        sample_rate=self.sample_rate,
                        nfilter=self.n_mels,
                        normalize=True,
                        freq_low=self.freq_low,
                        freq_high=self.freq_high,
                    )
                    # Mel Spectrogram to MFCC
                    if feature_type == 'mfcc':
                        spec = dali.fn.mfcc(spec, n_mfcc=self.n_mfcc)

                # Logarithm
                if self.log_zero_guard_type == 'add':
                    spec = spec + self.log_zero_guard_value

                spec = dali.fn.to_decibels(spec,
                                           multiplier=math.log(10),
                                           reference=1.0,
                                           cutoff_db=math.log(
                                               self.log_zero_guard_value))

                # Normalization
                spec = dali.fn.normalize(spec, axes=self.normalization_axes)

                # Extracting the length of the spectrogram
                shape_start = dali.types.Constant(np.array([1],
                                                           dtype=np.float32),
                                                  device='cpu')
                shape_len = dali.types.Constant(np.array([1],
                                                         dtype=np.float32),
                                                device='cpu')
                spec_len = dali.fn.slice(
                    dali.fn.shapes(spec),
                    shape_start,
                    shape_len,
                    normalized_anchor=False,
                    normalized_shape=False,
                    axes=(0, ),
                )

                # Pads feature dimension to be a multiple of `pad_to` and the temporal dimension to be as big as the largest sample (shape -1)
                spec = dali.fn.pad(spec,
                                   fill_value=self.pad_value,
                                   axes=(0, 1),
                                   align=(self.pad_to, 1),
                                   shape=(1, -1))
            self.pipe.set_outputs(spec, spec_len, transcript, transcript_len)
        # Building DALI pipeline
        self.pipe.build()

        if has_preprocessor:
            output_names = [
                'processed_signal', 'processed_signal_len', 'transcript_raw',
                'transcript_raw_len'
            ]
        else:
            output_names = [
                'audio', 'audio_len', 'transcript_raw', 'transcript_raw_len'
            ]

        last_batch_policy = LastBatchPolicy.DROP if drop_last else LastBatchPolicy.PARTIAL
        self._iter = DALIPytorchIterator(
            [self.pipe],
            output_map=output_names,
            reader_name="Reader",
            last_batch_policy=last_batch_policy,
            dynamic_shape=True,
            auto_reset=True,
        )

        # TODO come up with a better solution
        class DummyDataset:
            def __init__(self, parent):
                self.parent = parent

            def __len__(self):
                return self.parent.size

        self.dataset = DummyDataset(self)  # Used by NeMo
Example #20
0
    def __init__(self, cfg: DictConfig, trainer: Trainer = None):
        """Build the end-to-end FastPitch + HiFi-GAN model from ``cfg``.

        Creates the text parser, the mel preprocessors, the FastPitch
        encoder/predictors, the HiFi-GAN generator and discriminators, the
        losses, and a pitch-embedding convolution, then copies a few scalar
        settings from ``cfg`` onto the instance.

        Args:
            cfg: model configuration; a plain ``dict`` is converted with
                ``OmegaConf.create``.
            trainer: optional Lightning trainer forwarded to the base class.

        Raises:
            ValueError: if ``cfg`` is neither a dict nor a ``DictConfig``
                after the base-class init.
        """
        # Convert plain dicts up front so attribute access such as
        # ``cfg.labels`` below works.
        if isinstance(cfg, dict):
            cfg = OmegaConf.create(cfg)

        # Built before super().__init__(); presumably because the base-class
        # init may already need the parser (e.g. for dataset setup) — TODO confirm.
        self._parser = parsers.make_parser(
            labels=cfg.labels,
            name='en',
            unk_id=-1,
            blank_id=-1,
            do_normalize=True,
            abbreviation_version="fastpitch",
            make_table=False,
        )

        super().__init__(cfg=cfg, trainer=trainer)

        schema = OmegaConf.structured(FastPitchHifiGanE2EConfig)
        # ModelPT ensures that cfg is a DictConfig, but do this second check in case ModelPT changes
        if isinstance(cfg, dict):
            cfg = OmegaConf.create(cfg)
        elif not isinstance(cfg, DictConfig):
            raise ValueError(
                f"cfg was type: {type(cfg)}. Expected either a dict or a DictConfig"
            )
        # Ensure passed cfg is compliant with schema.
        # NOTE(review): the merged config is discarded, so this call only acts
        # as a validation step; ``cfg`` is not filled with schema defaults.
        OmegaConf.merge(cfg, schema)

        self.preprocessor = instantiate(cfg.preprocessor)
        # Second preprocessor instance with highfreq=None and use_grads=True;
        # presumably used to compute mel spectrograms of generated audio with
        # gradients enabled for the mel loss — TODO confirm.
        self.melspec_fn = instantiate(cfg.preprocessor,
                                      highfreq=None,
                                      use_grads=True)

        # FastPitch components (text encoder, duration and pitch predictors).
        self.encoder = instantiate(cfg.input_fft)
        self.duration_predictor = instantiate(cfg.duration_predictor)
        self.pitch_predictor = instantiate(cfg.pitch_predictor)

        # HiFi-GAN generator, discriminators and adversarial/feature losses.
        self.generator = instantiate(cfg.generator)
        self.multiperioddisc = MultiPeriodDiscriminator()
        self.multiscaledisc = MultiScaleDiscriminator()
        self.mel_val_loss = L1MelLoss()
        self.feat_matching_loss = FeatureMatchingLoss()
        self.disc_loss = DiscriminatorLoss()
        self.gen_loss = GeneratorLoss()

        self.max_token_duration = cfg.max_token_duration

        # Projects a 1-channel pitch contour to the symbol-embedding width.
        # padding=(k - 1) / 2 keeps the sequence length for odd kernel sizes.
        self.pitch_emb = torch.nn.Conv1d(
            1,
            cfg.symbols_embedding_dim,
            kernel_size=cfg.pitch_embedding_kernel_size,
            padding=int((cfg.pitch_embedding_kernel_size - 1) / 2),
        )

        # Store values precomputed from training data for convenience
        # (registered as zero placeholders here; buffers are saved in the state dict).
        self.register_buffer('pitch_mean', torch.zeros(1))
        self.register_buffer('pitch_std', torch.zeros(1))

        self.loss = BaseFastPitchLoss()

        self.mel_loss_coeff = cfg.mel_loss_coeff

        # Logging state and audio-slicing settings.
        self.log_train_images = False
        self.logged_real_samples = False
        self._tb_logger = None
        self.hann_window = None
        self.splice_length = cfg.splice_length
        self.sample_rate = cfg.sample_rate
        self.hop_size = cfg.hop_size
Example #21
0
def main():
    """Round-trip test sentences through TTS and ASR and check the WER.

    Loads pretrained ASR, spectrogram-generator and vocoder models by name,
    synthesizes every string in ``LIST_OF_TEST_STRINGS``, writes the audio to
    ``<index>.wav`` in the working directory, transcribes those files and
    compares the hypotheses with the ASR parser's normalized references.

    Raises:
        ValueError: if the word error rate exceeds ``--wer_tolerance``.
    """
    arg_parser = ArgumentParser()
    arg_parser.add_argument(
        "--asr_model",
        type=str,
        default="QuartzNet15x5Base-En",
        choices=[
            x.pretrained_model_name
            for x in EncDecCTCModel.list_available_models()
        ],
    )
    arg_parser.add_argument(
        "--tts_model_spec",
        type=str,
        default="Tacotron2-22050Hz",
        choices=[
            x.pretrained_model_name
            for x in SpectrogramGenerator.list_available_models()
        ],
    )
    arg_parser.add_argument(
        "--tts_model_vocoder",
        type=str,
        default="WaveGlow-22050Hz",
        choices=[
            x.pretrained_model_name for x in Vocoder.list_available_models()
        ],
    )
    arg_parser.add_argument("--wer_tolerance",
                            type=float,
                            default=1.0,
                            help="used by test")
    arg_parser.add_argument("--trim", action="store_true")
    arg_parser.add_argument("--debug", action="store_true")
    args = arg_parser.parse_args()
    # Inference only: no autograd bookkeeping is needed.
    torch.set_grad_enabled(False)

    if args.debug:
        logging.set_verbosity(logging.DEBUG)

    logging.info(f"Using NGC cloud ASR model {args.asr_model}")
    asr_model = EncDecCTCModel.from_pretrained(model_name=args.asr_model)
    logging.info(
        f"Using NGC cloud TTS Spectrogram Generator model {args.tts_model_spec}"
    )
    tts_model_spec = SpectrogramGenerator.from_pretrained(
        model_name=args.tts_model_spec)
    logging.info(f"Using NGC cloud TTS Vocoder model {args.tts_model_vocoder}")
    tts_model_vocoder = Vocoder.from_pretrained(
        model_name=args.tts_model_vocoder)
    models = [asr_model, tts_model_spec, tts_model_vocoder]

    if torch.cuda.is_available():
        models = [m.cuda() for m in models]
    for m in models:
        m.eval()

    asr_model, tts_model_spec, tts_model_vocoder = models

    # Text parser that builds the ASR reference transcripts. It is kept
    # separate from the command-line parser above.
    text_parser = parsers.make_parser(
        labels=asr_model.decoder.vocabulary,
        name="en",
        unk_id=-1,
        blank_id=-1,
        do_normalize=True,
    )
    labels_map = dict(enumerate(asr_model.decoder.vocabulary))

    tts_input = []
    asr_references = []
    longest_tts_input = 0
    for test_str in LIST_OF_TEST_STRINGS:
        tts_parsed_input = tts_model_spec.parse(test_str)
        longest_tts_input = max(longest_tts_input, len(tts_parsed_input[0]))
        tts_input.append(tts_parsed_input.squeeze())

        # Reference = the ASR parser's normalized token ids mapped back to text.
        asr_parsed = text_parser(test_str)
        asr_references.append(''.join(labels_map[c] for c in asr_parsed))

    # Right-pad every token sequence to the longest one so they can be stacked.
    # NOTE(review): 68 is presumably the spectrogram generator's pad token id — confirm.
    for i, text in enumerate(tts_input):
        pad = (0, longest_tts_input - len(text))
        tts_input[i] = torch.nn.functional.pad(text, pad, value=68)

    logging.debug(tts_input)

    # Do TTS
    tts_input = torch.stack(tts_input)
    if torch.cuda.is_available():
        tts_input = tts_input.cuda()
    specs = tts_model_spec.generate_spectrogram(tokens=tts_input)

    # Vocode in at most four batches. Iterating over the real chunk starts
    # means no empty batch reaches the vocoder. A fixed range(4) produces empty
    # slices, e.g. with 5 spectrograms and a step of 2 the fourth slice is empty.
    step = max(1, ceil(len(specs) / 4))
    audio = []
    for start in range(0, len(specs), step):
        audio.extend(
            tts_model_vocoder.convert_spectrogram_to_audio(
                spec=specs[start:start + step]))

    audio_file_paths = []
    # Save audio
    logging.debug(f"args.trim: {args.trim}")
    for i, aud in enumerate(audio):
        aud = aud.cpu().numpy()
        if args.trim:
            aud = librosa.effects.trim(aud, top_db=40)[0]
        wav_name = f"{i}.wav"
        # NOTE(review): `librosa.output.write_wav` was removed in librosa 0.8.
        # This line needs an older librosa, or a switch to e.g. soundfile.
        librosa.output.write_wav(wav_name, aud, sr=22050)
        audio_file_paths.append(str(Path(wav_name)))

    # Do ASR
    hypotheses = asr_model.transcribe(audio_file_paths)
    for i, (reference, hypothesis) in enumerate(zip(asr_references, hypotheses)):
        logging.debug(f"{i}")
        logging.debug(f"ref:'{reference}'")
        logging.debug(f"hyp:'{hypothesis}'")
    wer_value = word_error_rate(hypotheses=hypotheses,
                                references=asr_references)
    if wer_value > args.wer_tolerance:
        raise ValueError(
            f"Got WER of {wer_value}. It was higher than {args.wer_tolerance}")
    logging.info(f'Got WER of {wer_value}. Tolerance was {args.wer_tolerance}')
Example #22
0
    def test_transcript_normalizers(self):
        """Check that manifest transcripts are normalized by the 'en' parser.

        Each entry in ``test_strings`` is written as the ``text`` field of a
        JSON-lines manifest and loaded through ``collections.ASRAudioText``.
        Its tokens are compared with the tokens from parsing the matching entry
        of ``normalized_strings`` directly.
        """
        # Create test json. Entries are spliced verbatim into a JSON string
        # literal below, which is why quotes and backslashes are pre-escaped.
        test_strings = [
            "TEST CAPITALIZATION",
            '!\\"#$%&\'()*+,-./:;<=>?@[\\\\]^_`{|}~',
            "3+3=10",
            "3 + 3 = 10",
            "why     is \\t whitepsace\\tsuch a problem   why indeed",
            "\\\"Can you handle quotes?,\\\" says the boy",
            "I Jump!!!!With joy?Now.",
            "Maybe I want to learn periods.",
            "$10 10.90 1-800-000-0000",
            "18000000000 one thousand 2020",
            "1 10 100 1000 10000 100000 1000000",
            "Î  ĻƠvɆȩȅĘ ÀÁÃ Ą ÇĊňńŤŧș",
            "‘’“”❛❜❝❞「 」 〈 〉 《 》 【 】 〔 〕 ⦗ ⦘ 😙  👀 🔨",
            "It only costs $1 000 000! Cheap right?",
            "2500, 3000 are separate but 200, 125 is not",
            "1",
            "1 2",
            "1 2 3",
            "10:00pm is 10:00 pm is 22:00 but not 10: 00 pm",
            "10:00 10:01pm 10:10am 10:90pm",
            "Mr. Expand me!",
            "Mr Don't Expand me!",
        ]
        normalized_strings = [
            "test capitalization",
            'percent and \' plus',
            "three plus three ten",
            "three plus three ten",
            "why is whitepsace such a problem why indeed",
            "can you handle quotes says the boy",
            "i jump with joy now",
            "maybe i want to learn periods",
            "ten dollars ten point nine zero one eight hundred zero zero",
            "eighteen billion one thousand two thousand and twenty",
            # Two line string below
            "one ten thousand one hundred one thousand ten thousand one "
            "hundred thousand one million",
            "i loveeee aaa a ccnntts",
            "''",
            "it only costs one million dollars cheap right",
            # Two line string below
            "two thousand five hundred three thousand are separate but two "
            "hundred thousand one hundred and twenty five is not",
            "one",
            "one two",
            "one two three",
            "ten pm is ten pm is twenty two but not ten zero pm",
            "ten ten one pm ten ten am ten ninety pm",
            "mister expand me",
            "mr don't expand me",
        ]
        manifest_paths = os.path.abspath(
            os.path.join(os.path.dirname(__file__),
                         "../data/asr/manifest_test.json"))

        def remove_test_json():
            # The file may not exist if writing it failed; don't mask that error.
            if os.path.exists(manifest_paths):
                os.remove(manifest_paths)

        self.addCleanup(remove_test_json)

        # Write as UTF-8 explicitly. Several strings contain characters
        # (accented letters, CJK brackets, emoji) that a platform-default
        # encoding such as cp1252 cannot encode.
        with open(manifest_paths, "w", encoding="utf-8") as f:
            for s in test_strings:
                f.write('{"audio_filepath": "", "duration": 1.0, "text": '
                        f'"{s}"}}\n')
        parser = parsers.make_parser(self.labels, 'en')
        manifest = collections.ASRAudioText(
            manifests_files=[manifest_paths],
            parser=parser,
        )

        # One subTest per entry: a mismatch reports its index and a token diff,
        # and the remaining entries are still checked.
        for i, s in enumerate(normalized_strings):
            with self.subTest(i=i, expected=s):
                self.assertEqual(manifest[i].text_tokens, parser(s))
Example #23
0
    def __init__(
        self,
        kaldi_dir,
        labels,
        min_duration=None,
        max_duration=None,
        max_utts=0,
        unk_index=-1,
        blank_index=-1,
        normalize=True,
        eos_id=None,
    ):
        """Load Kaldi features and transcripts from a Kaldi data directory.

        Reads features from ``feats.scp`` (via ``kaldi_io.read_mat_scp``) and
        transcripts from ``text`` (one ``<utt_id> <transcript>`` per line).
        If ``utt2dur`` is present, it also reads per-utterance durations in
        seconds from it. Only utterances with both features and a transcript
        are kept.

        Args:
            kaldi_dir: directory containing ``feats.scp``, ``text`` and
                optionally ``utt2dur``.
            labels: label set; used for ``labels_map`` and the text parser.
            min_duration: drop utterances shorter than this (needs ``utt2dur``).
            max_duration: drop utterances longer than this (needs ``utt2dur``).
            max_utts: stop after this many utterances; ``0`` means no limit.
            unk_index: unknown-token id passed to the parser.
            blank_index: blank id passed to the parser.
            normalize: whether the parser normalizes transcript text.
            eos_id: stored on the instance; not used while loading.

        Raises:
            ValueError: if a duration bound is set but ``utt2dur`` is missing.
        """
        self.eos_id = eos_id
        self.unk_index = unk_index
        self.blank_index = blank_index
        self.labels_map = {label: i for i, label in enumerate(labels)}

        data = []
        duration = 0.0
        filtered_duration = 0.0

        # Read Kaldi features (MFCC, PLP) using feats.scp
        feats_path = os.path.join(kaldi_dir, 'feats.scp')
        id2feats = {utt_id: torch.from_numpy(feats) for utt_id, feats in kaldi_io.read_mat_scp(feats_path)}

        # Get durations (in seconds), if utt2dur exists
        utt2dur_path = os.path.join(kaldi_dir, 'utt2dur')
        id2dur = {}
        if os.path.exists(utt2dur_path):
            with open(utt2dur_path, 'r') as f:
                for line in f:
                    if not line.strip():
                        # Tolerate blank (e.g. trailing) lines.
                        continue
                    utt_id, dur = line.split()
                    id2dur[utt_id] = float(dur)
        elif max_duration or min_duration:
            raise ValueError(
                f"KaldiFeatureDataset max_duration or min_duration is set but"
                f" utt2dur file not found in {kaldi_dir}."
            )
        else:
            logging.info(
                f"Did not find utt2dur when loading data from " f"{kaldi_dir}. Skipping dataset duration calculations."
            )

        # Match transcripts to features.
        # Normalization (when requested) is done by the parser via do_normalize.
        text_path = os.path.join(kaldi_dir, 'text')
        parser = parsers.make_parser(labels, 'en', unk_id=unk_index, blank_id=self.blank_index, do_normalize=normalize)
        with open(text_path, 'r') as f:
            for line in f:
                # "<utt_id> <transcript>". Splitting once on whitespace also
                # handles an empty transcript on a final line with no newline.
                fields = line.split(maxsplit=1)
                if not fields:
                    continue
                utt_id = fields[0]

                audio_features = id2feats.get(utt_id)
                if audio_features is None:
                    # No features for this utterance; skip its transcript.
                    continue

                text = fields[1].strip() if len(fields) > 1 else ''

                dur = id2dur[utt_id] if id2dur else None

                # Filter by duration only when durations are available.
                if dur is not None:
                    if min_duration and dur < min_duration:
                        filtered_duration += dur
                        continue
                    if max_duration and dur > max_duration:
                        filtered_duration += dur
                        continue

                sample = {
                    'utt_id': utt_id,
                    'text': text,
                    'tokens': parser(text),
                    'audio': audio_features.t(),
                    'duration': dur,
                }

                data.append(sample)
                # Without utt2dur, dur is None. Skip accumulation instead of
                # raising TypeError.
                if dur is not None:
                    duration += dur

                if max_utts > 0 and len(data) >= max_utts:
                    logging.warning(f"Stop parsing due to max_utts ({max_utts})")
                    break

        if id2dur:
            # utt2dur durations are in seconds
            logging.info(
                f"Dataset loaded with {duration / 3600 : .2f} hours. "
                f"Filtered {filtered_duration / 3600 : .2f} hours."
            )

        self.data = data
Example #24
0
    def __init__(self, path, labels, bos_id=None, eos_id=None, lowercase=True):
        """Read and tokenize the text lines stored at ``path``.

        Args:
            path: file handed to ``collections.FromFileText``.
            labels: label set for the text parser.
            bos_id: beginning-of-sequence id, stored on the instance.
            eos_id: end-of-sequence id, stored on the instance.
            lowercase: forwarded to the parser as ``do_lowercase``.
        """
        self.texts = collections.FromFileText(
            path,
            parser=parsers.make_parser(labels, do_lowercase=lowercase),
        )
        self.bos_id, self.eos_id = bos_id, eos_id
Example #25
0
File: perturb.py Project: vsl9/NeMo
 def __init__(self, manifest_path=None, rng=None):
     self._manifest = collections.ASRAudioText(manifest_path,
                                               parser=parsers.make_parser(
                                                   []))
     self._rng = random.Random() if rng is None else rng