Example #1
    def __init__(
        self,
        manifest_filepath: str,
        parser: Union[str, Callable],
        sample_rate: int,
        int_values: bool = False,
        augmentor: 'nemo.collections.asr.parts.perturb.AudioAugmentor' = None,
        max_duration: Optional[float] = None,
        min_duration: Optional[float] = None,
        max_utts: int = 0,
        trim: bool = False,
        bos_id: Optional[int] = None,
        eos_id: Optional[int] = None,
        pad_id: int = 0,
    ):
        self.parser = parser

        # Manifest paths may be given as a comma-separated list; entries are
        # filtered by duration and optionally capped at `max_utts` items.
        self.collection = collections.ASRAudioText(
            manifests_files=manifest_filepath.split(','),
            parser=parser,
            min_duration=min_duration,
            max_duration=max_duration,
            max_number=max_utts,
        )

        self.featurizer = WaveformFeaturizer(sample_rate=sample_rate,
                                             int_values=int_values,
                                             augmentor=augmentor)
        self.trim = trim
        self.eos_id = eos_id
        self.bos_id = bos_id
        self.pad_id = pad_id
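
A minimal construction sketch for this initializer. The snippet omits the enclosing class, so `AudioTextDataset` and the manifest paths below are hypothetical placeholders:

# Hypothetical usage; class name and paths are placeholders.
dataset = AudioTextDataset(
    manifest_filepath='train_a.json,train_b.json',  # comma-separated manifests
    parser='en',          # a named parser or any Callable
    sample_rate=16000,
    min_duration=0.5,     # drop utterances shorter than 0.5 s
    max_duration=20.0,    # drop utterances longer than 20 s
    bos_id=1,
    eos_id=2,
    pad_id=0,
)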
Example #2
    def _setup_dataloader_from_config(self, config: Optional[Dict]):
        if config.get('manifest_filepath') is None:
            return

        if 'augmentor' in config:
            augmentor = process_augmentations(config['augmentor'])
        else:
            augmentor = None

        featurizer = WaveformFeaturizer(sample_rate=config['sample_rate'],
                                        int_values=config.get(
                                            'int_values', False),
                                        augmentor=augmentor)
        dataset = AudioLabelDataset(
            manifest_filepath=config['manifest_filepath'],
            labels=config['labels'],
            featurizer=featurizer,
            max_duration=config.get('max_duration', None),
            min_duration=config.get('min_duration', None),
            trim=config.get('trim_silence', True),
            load_audio=config.get('load_audio', True),
        )

        return torch.utils.data.DataLoader(
            dataset=dataset,
            batch_size=config['batch_size'],
            collate_fn=dataset.collate_fn,
            drop_last=config.get('drop_last', False),
            shuffle=config['shuffle'],
            num_workers=config.get('num_workers', 0),
            pin_memory=config.get('pin_memory', False),
        )
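
A hedged sketch of the config this method expects; the keys are inferred from the `config[...]` accesses above, and all values are placeholders:

# Placeholder config; required keys are those read without a default above.
example_config = {
    'manifest_filepath': '/data/train_manifest.json',
    'sample_rate': 16000,
    'labels': ['yes', 'no', 'up', 'down'],
    'batch_size': 32,
    'shuffle': True,          # required: read as config['shuffle']
    'num_workers': 4,         # optional, defaults to 0
    'trim_silence': False,    # optional, defaults to True in this method
}
loader = model._setup_dataloader_from_config(example_config)  # `model` is hypothetical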
Example #3
    def __init__(self,
                 *,
                 manifest_filepath: str,
                 labels: List[str],
                 batch_size: int,
                 sample_rate: int = 16000,
                 int_values: bool = False,
                 num_workers: int = 0,
                 shuffle: bool = True,
                 min_duration: Optional[float] = 0.1,
                 max_duration: Optional[float] = None,
                 trim_silence: bool = False,
                 drop_last: bool = False,
                 load_audio: bool = True,
                 augmentor: Optional[Union[AudioAugmentor,
                                           Dict[str, Dict[str, Any]]]] = None,
                 num_classes: int = 35,
                 class_dists=None,
                 class_probs=None,
                 probs_num=0):
        super().__init__()

        self._manifest_filepath = manifest_filepath
        self._labels = labels
        self._sample_rate = sample_rate

        if augmentor is not None:
            augmentor = self._process_augmentations(augmentor)

        self._featurizer = WaveformFeaturizer(sample_rate=sample_rate,
                                              int_values=int_values,
                                              augmentor=augmentor)

        dataset_params = {
            'manifest_filepath': manifest_filepath,
            'labels': labels,
            'featurizer': self._featurizer,
            'max_duration': max_duration,
            'min_duration': min_duration,
            'trim': trim_silence,
            'load_audio': load_audio,
        }
        self._dataset = AudioLabelDataset(**dataset_params)
        # Collect each sample's integer label id for the balanced sampler
        # (kept distinct from the `labels` constructor argument saved above).
        sample_labels = [
            self._dataset.label2id[sample.label]
            for sample in self._dataset.collection
        ]
        self._dataloader = torch.utils.data.DataLoader(
            dataset=self._dataset,
            batch_sampler=BalancedBatchSampler(sample_labels,
                                               n_classes=num_classes,
                                               n_samples=batch_size //
                                               num_classes,
                                               class_dists=class_dists,
                                               class_probs=class_probs,
                                               probs_num=probs_num),
            # TODO replace with kwargs
            collate_fn=partial(seq_collate_fn, token_pad_value=0),
            num_workers=num_workers,
        )
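
One sizing caveat worth noting: the sampler draws `batch_size // num_classes` samples per class, so any remainder is silently dropped. A quick check with the default class count:

# With the default num_classes=35, a batch_size of 64 yields only one
# sample per class, for an effective batch of 35 rather than 64.
batch_size, num_classes = 64, 35
n_samples = batch_size // num_classes
print(n_samples, n_samples * num_classes)  # -> 1 35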
Example #4
    def _setup_dataloader_from_config(self, config: Optional[Dict]):
        if config.get('manifest_filepath') is None:
            return

        if 'augmentor' in config:
            augmentor = process_augmentations(config['augmentor'])
        else:
            augmentor = None

        featurizer = WaveformFeaturizer(sample_rate=config['sample_rate'],
                                        int_values=config.get(
                                            'int_values', False),
                                        augmentor=augmentor)

        if 'vad_stream' in config and config['vad_stream']:
            logging.info("Perform streaming frame-level VAD")
            dataset = AudioToSpeechLabelDataSet(
                manifest_filepath=config['manifest_filepath'],
                labels=config['labels'],
                featurizer=featurizer,
                max_duration=config.get('max_duration', None),
                min_duration=config.get('min_duration', None),
                trim=config.get('trim_silence', True),
                load_audio=config.get('load_audio', True),
                time_length=config.get('time_length', 0.31),
                shift_length=config.get('shift_length', 0.01),
            )
            batch_size = 1
            collate_func = dataset.vad_frame_seq_collate_fn
        else:
            dataset = AudioLabelDataset(
                manifest_filepath=config['manifest_filepath'],
                labels=config['labels'],
                featurizer=featurizer,
                max_duration=config.get('max_duration', None),
                min_duration=config.get('min_duration', None),
                trim=config.get('trim_silence', True),
                load_audio=config.get('load_audio', True),
            )
            batch_size = config['batch_size']
            collate_func = dataset.collate_fn

        return torch.utils.data.DataLoader(
            dataset=dataset,
            batch_size=batch_size,
            collate_fn=collate_func,
            drop_last=config.get('drop_last', False),
            shuffle=config['shuffle'],
            num_workers=config.get('num_workers', 0),
            pin_memory=config.get('pin_memory', False),
        )
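
The streaming branch is driven purely by config; a hedged sketch of the extra keys it reads (window defaults taken from the snippet above, paths are placeholders):

# Setting `vad_stream` forces batch_size=1 and the frame-level collate fn.
vad_config = {
    'manifest_filepath': '/data/vad_manifest.json',
    'sample_rate': 16000,
    'labels': ['background', 'speech'],
    'batch_size': 32,        # ignored when vad_stream is True
    'shuffle': False,
    'vad_stream': True,
    'time_length': 0.31,     # window size in seconds
    'shift_length': 0.01,    # window hop in seconds
}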
Example #5
    def __init__(self,
                 audio_batch: List[Union[str, BytesIO]],
                 sample_rate: int,
                 int_values: bool,
                 trim: bool = False) -> None:
        """Dataset reader for AudioInferDataLayer.

        Args:
            audio_batch: Batch to be read. Elements could be either paths to audio files or Binary I/O objects.
            sample_rate: Audio files sample rate.
            int_values: If True, load samples as 32-bit integers.
            trim: Trim leading and trailing silence from an audio signal if True.

        """
        self.audio_batch = audio_batch
        self.featurizer = WaveformFeaturizer(sample_rate=sample_rate,
                                             int_values=int_values)
        self.trim = trim
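
A usage sketch mixing file paths and in-memory audio; the class name `AudioInferDataset` is assumed from the docstring's "AudioInferDataLayer" and is hypothetical:

from io import BytesIO

# Hypothetical usage: one on-disk file and one in-memory buffer per batch.
with open('utt1.wav', 'rb') as f:
    in_memory = BytesIO(f.read())
reader = AudioInferDataset(audio_batch=['utt0.wav', in_memory],
                           sample_rate=16000,
                           int_values=False,
                           trim=True)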
Example #6
    def __setup_dataloader_from_config(self, config: Optional[Dict]):
        if 'augmentor' in config:
            augmentor = process_augmentations(config['augmentor'])
        else:
            augmentor = None

        featurizer = WaveformFeaturizer(sample_rate=config['sample_rate'],
                                        int_values=config.get(
                                            'int_values', False),
                                        augmentor=augmentor)
        self.dataset = AudioToSpeechLabelDataset(
            manifest_filepath=config['manifest_filepath'],
            labels=config['labels'],
            featurizer=featurizer,
            max_duration=config.get('max_duration', None),
            min_duration=config.get('min_duration', None),
            trim=False,
            load_audio=config.get('load_audio', True),
            time_length=config.get('time_length', 8),
            shift_length=config.get('shift_length', 0.75),
        )

        if self.task == 'diarization':
            logging.info("Setting up diarization parameters")
            _collate_func = self.dataset.sliced_seq_collate_fn
            batch_size = 1
            shuffle = False
        else:
            logging.info("Setting up identification parameters")
            _collate_func = self.dataset.fixed_seq_collate_fn
            batch_size = config['batch_size']
            shuffle = config.get('shuffle', False)

        return torch.utils.data.DataLoader(
            dataset=self.dataset,
            batch_size=batch_size,
            collate_fn=_collate_func,
            drop_last=config.get('drop_last', False),
            shuffle=shuffle,
            num_workers=config.get('num_workers', 0),
            pin_memory=config.get('pin_memory', False),
        )
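
Note the much longer windows here (8 s with a 0.75 s shift) compared with the VAD loader, and that the collate function hinges on `self.task`. A placeholder config sketch:

# Placeholder config for speaker-label inference; keys inferred from above.
spk_config = {
    'manifest_filepath': '/data/spk_manifest.json',
    'sample_rate': 16000,
    'labels': None,          # labels can be inferred from the manifest
    'batch_size': 64,        # used only when task != 'diarization'
    'time_length': 8,        # seconds of audio per slice
    'shift_length': 0.75,    # slice hop in seconds
}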
Example #7
    def test_tarred_dataset_duplicate_name(self, test_data_dir):
        manifest_path = os.path.abspath(
            os.path.join(
                test_data_dir,
                'asr/tarred_an4/tarred_duplicate_audio_manifest.json'))

        # Test braceexpand loading
        tarpath = os.path.abspath(
            os.path.join(test_data_dir, 'asr/tarred_an4/audio_{0..1}.tar'))
        featurizer = WaveformFeaturizer(sample_rate=16000,
                                        int_values=False,
                                        augmentor=None)
        ds_braceexpand = TarredAudioToClassificationLabelDataset(
            audio_tar_filepaths=tarpath,
            manifest_filepath=manifest_path,
            labels=self.labels,
            featurizer=featurizer)

        assert len(ds_braceexpand) == 6
        count = 0
        for _ in ds_braceexpand:
            count += 1
        assert count == 6

        # Test loading via list
        tarpath = [
            os.path.abspath(
                os.path.join(test_data_dir, f'asr/tarred_an4/audio_{i}.tar'))
            for i in range(2)
        ]
        ds_list_load = TarredAudioToClassificationLabelDataset(
            audio_tar_filepaths=tarpath,
            manifest_filepath=manifest_path,
            labels=self.labels,
            featurizer=featurizer)
        count = 0
        for _ in ds_list_load:
            count += 1
        assert count == 6
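
The two loading modes are equivalent because braceexpand turns the pattern into the same explicit list; a quick standalone check:

import braceexpand

# '{0..1}' expands to one path per shard index.
print(list(braceexpand.braceexpand('asr/tarred_an4/audio_{0..1}.tar')))
# -> ['asr/tarred_an4/audio_0.tar', 'asr/tarred_an4/audio_1.tar']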
Example #8
    def _setup_dataloader_from_config(self, config: DictConfig):

        OmegaConf.set_struct(config, False)
        config.is_regression_task = self.is_regression_task
        OmegaConf.set_struct(config, True)

        if 'augmentor' in config:
            augmentor = process_augmentations(config['augmentor'])
        else:
            augmentor = None

        featurizer = WaveformFeaturizer(sample_rate=config['sample_rate'],
                                        int_values=config.get(
                                            'int_values', False),
                                        augmentor=augmentor)
        shuffle = config['shuffle']

        # Instantiate tarred dataset loader or normal dataset loader
        if config.get('is_tarred', False):
            if ('tarred_audio_filepaths' in config
                    and config['tarred_audio_filepaths'] is None) or (
                        'manifest_filepath' in config
                        and config['manifest_filepath'] is None):
                logging.warning(
                    "Could not load dataset as `manifest_filepath` is None or "
                    f"`tarred_audio_filepaths` is None. Provided config : {config}"
                )
                return None

            if 'vad_stream' in config and config['vad_stream']:
                logging.warning(
                    "VAD inference does not support tarred dataset now")
                return None

            shuffle_n = config.get('shuffle_n', 4 *
                                   config['batch_size']) if shuffle else 0
            dataset = audio_to_label_dataset.get_tarred_classification_label_dataset(
                featurizer=featurizer,
                config=OmegaConf.to_container(config),
                shuffle_n=shuffle_n,
                global_rank=self.global_rank,
                world_size=self.world_size,
            )
            shuffle = False
            batch_size = config['batch_size']
            collate_func = dataset.collate_fn

        else:
            if 'manifest_filepath' in config and config[
                    'manifest_filepath'] is None:
                logging.warning(
                    f"Could not load dataset as `manifest_filepath` is None. Provided config : {config}"
                )
                return None

            if 'vad_stream' in config and config['vad_stream']:
                logging.info("Perform streaming frame-level VAD")
                dataset = audio_to_label_dataset.get_speech_label_dataset(
                    featurizer=featurizer,
                    config=OmegaConf.to_container(config))
                batch_size = 1
                collate_func = dataset.vad_frame_seq_collate_fn
            else:
                dataset = audio_to_label_dataset.get_classification_label_dataset(
                    featurizer=featurizer,
                    config=OmegaConf.to_container(config))
                batch_size = config['batch_size']
                collate_func = dataset.collate_fn

        return torch.utils.data.DataLoader(
            dataset=dataset,
            batch_size=batch_size,
            collate_fn=collate_func,
            drop_last=config.get('drop_last', False),
            shuffle=shuffle,
            num_workers=config.get('num_workers', 0),
            pin_memory=config.get('pin_memory', False),
        )
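
A hedged sketch of the extra keys the tarred branch reads; paths are placeholders:

# Placeholder tarred-dataset config.
tarred_config = {
    'manifest_filepath': '/data/tarred/tarred_audio_manifest.json',
    'tarred_audio_filepaths': '/data/tarred/audio_{0..63}.tar',
    'is_tarred': True,
    'sample_rate': 16000,
    'labels': ['background', 'speech'],
    'batch_size': 32,
    'shuffle': True,    # becomes shuffle_n = 4 * batch_size inside the tarred
                        # dataset; the DataLoader itself then gets shuffle=False
}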
Example #9
    def __init__(
        self,
        audio_tar_filepaths: Union[str, List[str]],
        manifest_filepath: str,
        parser: Callable,
        sample_rate: int,
        int_values: bool = False,
        augmentor: Optional[
            'nemo.collections.asr.parts.perturb.AudioAugmentor'] = None,
        shuffle_n: int = 0,
        min_duration: Optional[float] = None,
        max_duration: Optional[float] = None,
        max_utts: int = 0,
        trim: bool = False,
        bos_id: Optional[int] = None,
        eos_id: Optional[int] = None,
        add_misc: bool = False,
        pad_id: int = 0,
        shard_strategy: str = "scatter",
        global_rank: int = 0,
        world_size: int = 0,
    ):
        self.collection = collections.ASRAudioText(
            manifests_files=manifest_filepath.split(','),
            parser=parser,
            min_duration=min_duration,
            max_duration=max_duration,
            max_number=max_utts,
            # Must be set so manifest lines can be indexed by file ID
            index_by_file_id=True,
        )

        self.featurizer = WaveformFeaturizer(sample_rate=sample_rate,
                                             int_values=int_values,
                                             augmentor=augmentor)
        self.trim = trim
        self.eos_id = eos_id
        self.bos_id = bos_id
        self.pad_id = pad_id
        self._add_misc = add_misc

        valid_shard_strategies = ['scatter', 'replicate']
        if shard_strategy not in valid_shard_strategies:
            raise ValueError(
                f"`shard_strategy` must be one of {valid_shard_strategies}")

        if isinstance(audio_tar_filepaths, str):
            # Replace the open-brace aliases '(', '[', '<' and '_OP_' with '{'
            brace_keys_open = ['(', '[', '<', '_OP_']
            for bkey in brace_keys_open:
                if bkey in audio_tar_filepaths:
                    audio_tar_filepaths = audio_tar_filepaths.replace(
                        bkey, "{")

            # Replace the close-brace aliases ')', ']', '>' and '_CL_' with '}'
            brace_keys_close = [')', ']', '>', '_CL_']
            for bkey in brace_keys_close:
                if bkey in audio_tar_filepaths:
                    audio_tar_filepaths = audio_tar_filepaths.replace(
                        bkey, "}")

        # Check for distributed and partition shards accordingly
        if world_size > 1:
            if isinstance(audio_tar_filepaths, str):
                # Brace expand
                audio_tar_filepaths = list(
                    braceexpand.braceexpand(audio_tar_filepaths))

            if shard_strategy == 'scatter':
                logging.info(
                    "All tarred dataset shards will be scattered evenly across all nodes."
                )

                if len(audio_tar_filepaths) % world_size != 0:
                    logging.warning(
                        f"Number of shards in tarred dataset ({len(audio_tar_filepaths)}) is not divisible "
                        f"by number of distributed workers ({world_size}).")

                begin_idx = (len(audio_tar_filepaths) //
                             world_size) * global_rank
                end_idx = begin_idx + (len(audio_tar_filepaths) // world_size)
                audio_tar_filepaths = audio_tar_filepaths[begin_idx:end_idx]
                logging.info(
                    "Partitioning tarred dataset: process (%d) taking shards [%d, %d)",
                    global_rank, begin_idx, end_idx)

            elif shard_strategy == 'replicate':
                logging.info(
                    "All tarred dataset shards will be replicated across all nodes."
                )

            else:
                raise ValueError(
                    f"Invalid shard strategy! Allowed values are: {valid_shard_strategies}"
                )

        # Put together WebDataset
        self._dataset = (
            wd.Dataset(audio_tar_filepaths).shuffle(shuffle_n).rename(
                audio='wav', key='__key__').to_tuple('audio', 'key').pipe(
                    self._filter).map(f=self._build_sample))
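
The alias substitution lets brace patterns survive shells and config parsers that mangle literal braces; a quick trace of the normalization above:

# '_OP_'/'_CL_' (and (), [], <>) are rewritten to braces before expansion.
path = 'audio__OP_0..3_CL_.tar'
for bkey in ['(', '[', '<', '_OP_']:
    path = path.replace(bkey, '{')
for bkey in [')', ']', '>', '_CL_']:
    path = path.replace(bkey, '}')
print(path)  # -> 'audio_{0..3}.tar'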
Example #10
    def __init__(
            self,
            manifest_filepath: str,
            mappings_filepath: str,
            sample_rate: int,
            max_duration: Optional[float] = None,
            min_duration: Optional[float] = None,
            ignore_file: Optional[str] = None,
            trim: bool = False,
            load_supplementary_values=True,  # Set to False for validation
    ):
        """
        Dataset that loads audio, phonemes and their durations, pitches per frame, and energies per frame
        for FastSpeech 2 from paths described in a JSON manifest (see the AudioDataset documentation for details
        on the manifest format), as well as a mappings file for word to phones and phones to indices.
        The text in the manifest is ignored; instead, the phoneme indices for prediction come from the
        duration files.

        For each sample, paths for duration, energy, and pitch files are inferred from the manifest's audio
        filepaths by replacing '/wavs' with '/phoneme_durations', '/pitches', and '/energies', and swapping out
        the file extension to '.pt', '.npy', and '.npy' respectively.
        For example, given the manifest audio path `/data/LJSpeech/wavs/LJ001-0001.wav`, the inferred
        duration/phoneme file path would be `/data/LJSpeech/phoneme_durations/LJ001-0001.pt`.

        Note that validation datasets only need the audio files and phoneme & duration files, so set
        `load_supplementary_values` to False for validation sets.

        Args:
            manifest_filepath (str): Path to the JSON manifest file that lists audio files.
            mappings_filepath (str): Path to a JSON mappings file that contains mappings "word2phones" and
                "phone2idx". The latter is used to determine the padding index.
            sample_rate (int): Target sample rate of the audio.
            max_duration (float): If audio exceeds this length in seconds, it is filtered from the dataset.
                Defaults to None, which does not filter any audio.
            min_duration (float): If audio is shorter than this length in seconds, it is filtered from the dataset.
                Defaults to None, which does not filter any audio.
            ignore_file (str): Optional pickled file which contains a list of files to ignore (e.g. files that
                contain OOV words).
                Defaults to None.
            trim (bool): Whether to use librosa.effects.trim on the audio clip.
                Defaults to False.
            load_supplementary_values (bool): Whether or not to load pitch and energy files. Set this to False for
                validation datasets.
                Defaults to True.
        """
        super().__init__()

        # Retrieve mappings from file
        with open(mappings_filepath, 'r') as f:
            mappings = json.load(f)
            self.word2phones = mappings['word2phones']
            self.phone2idx = mappings['phone2idx']

        # Load data from manifests
        audio_files = []
        total_duration = 0
        if isinstance(manifest_filepath, str):
            manifest_filepath = [manifest_filepath]
        for manifest_file in manifest_filepath:
            with open(expanduser(manifest_file), 'r') as f:
                logging.info(f"Loading dataset from {manifest_file}.")
                for line in f:
                    item = json.loads(line)
                    audio_files.append({
                        "audio_filepath": item["audio_filepath"],
                        "duration": item["duration"]
                    })
                    total_duration += item["duration"]

        total_dataset_len = len(audio_files)
        logging.info(
            f"Loaded dataset with {total_dataset_len} files totalling {total_duration/3600:.2f} hours."
        )
        self.data = []
        if load_supplementary_values:
            dataitem = py_collections.namedtuple(
                typename='AudioTextEntity',
                field_names='audio_file duration text_tokens pitches energies')
        else:
            dataitem = py_collections.namedtuple(
                typename='AudioTextEntity',
                field_names='audio_file duration text_tokens')

        if ignore_file:
            logging.info(f"using {ignore_file} to prune dataset.")
            with open(ignore_file, "rb") as f:
                wavs_to_ignore = set(pickle.load(f))

        pruned_duration = 0
        pruned_items = 0
        for item in audio_files:
            audio_path = item['audio_filepath']
            LJ_id = os.path.splitext(os.path.basename(audio_path))[0]

            # Prune data according to min/max_duration & the ignore file
            if (min_duration and item["duration"] < min_duration) or (
                    max_duration and item["duration"] > max_duration):
                pruned_duration += item["duration"]
                pruned_items += 1
                continue
            if ignore_file and (LJ_id in wavs_to_ignore):
                pruned_items += 1
                pruned_duration += item["duration"]
                wavs_to_ignore.remove(LJ_id)
                continue

            # Else not pruned, load additional info

            # Phoneme durations and text token indices from durations file
            dur_path = audio_path.replace('/wavs/',
                                          '/phoneme_durations/').replace(
                                              '.wav', '.pt')
            duration_info = torch.load(dur_path)
            durs = duration_info['token_duration']
            text_tokens = duration_info['text_encoded']

            if load_supplementary_values:
                # Load pitch file (F0s)
                pitch_path = audio_path.replace('/wavs/', '/pitches/').replace(
                    '.wav', '.npy')
                pitches = torch.from_numpy(
                    np.load(pitch_path).astype(dtype='float32'))

                # Load energy file (L2-norm of the amplitude of each STFT frame of an utterance)
                energies_path = audio_path.replace('/wavs/',
                                                   '/energies/').replace(
                                                       '.wav', '.npy')
                energies = torch.from_numpy(np.load(energies_path))

                self.data.append(
                    dataitem(
                        audio_file=item['audio_filepath'],
                        duration=durs,
                        pitches=torch.clamp(pitches, min=1e-5),
                        energies=energies,
                        text_tokens=text_tokens,
                    ))
            else:
                self.data.append(
                    dataitem(
                        audio_file=item['audio_filepath'],
                        duration=durs,
                        text_tokens=text_tokens,
                    ))

        logging.info(
            f"Pruned {pruned_items} files and {pruned_duration/3600:.2f} hours."
        )
        logging.info(
            f"Final dataset contains {len(self.data)} files and {(total_duration-pruned_duration)/3600:.2f} hours."
        )

        self.featurizer = WaveformFeaturizer(sample_rate=sample_rate)
        self.trim = trim
        self.load_supplementary_values = load_supplementary_values
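
The supplementary-file layout implied by the path rewriting above, traced for the docstring's example file:

# Path inference for the docstring's example audio file.
audio = '/data/LJSpeech/wavs/LJ001-0001.wav'
dur = audio.replace('/wavs/', '/phoneme_durations/').replace('.wav', '.pt')
pitch = audio.replace('/wavs/', '/pitches/').replace('.wav', '.npy')
energy = audio.replace('/wavs/', '/energies/').replace('.wav', '.npy')
print(dur)     # -> /data/LJSpeech/phoneme_durations/LJ001-0001.pt
print(pitch)   # -> /data/LJSpeech/pitches/LJ001-0001.npy
print(energy)  # -> /data/LJSpeech/energies/LJ001-0001.npy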
Example #11
def main(cfg):

    logging.info(f'Hydra config: {OmegaConf.to_yaml(cfg)}')

    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    enrollment_manifest = cfg.data.enrollment_manifest
    test_manifest = cfg.data.test_manifest
    out_manifest = cfg.data.out_manifest
    sample_rate = cfg.data.sample_rate

    backend = cfg.backend.backend_model.lower()

    if backend == 'cosine_similarity':
        model_path = cfg.backend.cosine_similarity.model_path
        batch_size = cfg.backend.cosine_similarity.batch_size
        if model_path.endswith('.nemo'):
            speaker_model = EncDecSpeakerLabelModel.restore_from(model_path)
        else:
            speaker_model = EncDecSpeakerLabelModel.from_pretrained(model_path)

        enroll_embs, _, enroll_truelabels, enroll_id2label = EncDecSpeakerLabelModel.get_batch_embeddings(
            speaker_model,
            enrollment_manifest,
            batch_size,
            sample_rate,
            device=device,
        )

        test_embs, _, _, _ = EncDecSpeakerLabelModel.get_batch_embeddings(
            speaker_model,
            test_manifest,
            batch_size,
            sample_rate,
            device=device,
        )

        # length normalize
        enroll_embs = enroll_embs / (np.linalg.norm(
            enroll_embs, ord=2, axis=-1, keepdims=True))
        test_embs = test_embs / (np.linalg.norm(
            test_embs, ord=2, axis=-1, keepdims=True))

        # Reference embedding: mean of the enrollment embeddings per speaker.
        # np.where returns a tuple, so take its first element; len() of the
        # tuple itself is always 1, which would make this a sum, not a mean.
        reference_embs = []
        keyslist = list(enroll_id2label.keys())
        for label_id in keyslist:
            indices = np.where(enroll_truelabels == label_id)[0]
            embedding = enroll_embs[indices].sum(
                axis=0).squeeze() / len(indices)
            reference_embs.append(embedding)

        reference_embs = np.asarray(reference_embs)

        scores = np.matmul(test_embs, reference_embs.T)
        matched_labels = scores.argmax(axis=-1)

    elif backend == 'neural_classifier':
        model_path = cfg.backend.neural_classifier.model_path
        batch_size = cfg.backend.neural_classifier.batch_size

        if model_path.endswith('.nemo'):
            speaker_model = EncDecSpeakerLabelModel.restore_from(model_path)
        else:
            speaker_model = EncDecSpeakerLabelModel.from_pretrained(model_path)

        featurizer = WaveformFeaturizer(sample_rate=sample_rate)
        dataset = AudioToSpeechLabelDataset(
            manifest_filepath=enrollment_manifest,
            labels=None,
            featurizer=featurizer)
        enroll_id2label = dataset.id2label

        if speaker_model.decoder.final.out_features != len(enroll_id2label):
            raise ValueError(
                "Number of labels does not match. Make sure the neural classifier "
                "was trained or fine-tuned with labels from the enrollment manifest_filepath"
            )

        _, test_logits, _, _ = EncDecSpeakerLabelModel.get_batch_embeddings(
            speaker_model,
            test_manifest,
            batch_size,
            sample_rate,
            device=device,
        )
        matched_labels = test_logits.argmax(axis=-1)

    else:
        raise ValueError(f"Unknown backend: {backend}")

    with open(test_manifest, 'rb') as f1, open(out_manifest,
                                               'w',
                                               encoding='utf-8') as f2:
        lines = f1.readlines()
        for idx, line in enumerate(lines):
            line = line.strip()
            item = json.loads(line)
            item['infer'] = enroll_id2label[matched_labels[idx]]
            json.dump(item, f2)
            f2.write('\n')

    logging.info(
        f"Inference labels have been written to the manifest file {out_manifest}")
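
Because both embedding sets are length-normalized first, the matmul above computes cosine similarity directly; a small self-contained numpy check:

import numpy as np

rng = np.random.default_rng(0)
test = rng.normal(size=(4, 8))        # 4 test embeddings
refs = rng.normal(size=(3, 8))        # 3 reference (enrollment) embeddings
test /= np.linalg.norm(test, ord=2, axis=-1, keepdims=True)
refs /= np.linalg.norm(refs, ord=2, axis=-1, keepdims=True)
scores = np.matmul(test, refs.T)      # cosine similarities in [-1, 1]
print(scores.argmax(axis=-1))         # best-matching reference per test row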