Example 1
File: rnnt.py Project: sycomix/NeMo
    def __init__(
        self,
        jointnet: Dict[str, Any],
        num_classes: int,
        vocabulary: Optional[List] = None,
        log_softmax: Optional[bool] = None,
        preserve_memory: bool = False,
        fuse_loss_wer: bool = False,
        fused_batch_size: Optional[int] = None,
        experimental_fuse_loss_wer: Any = None,
    ):
        super().__init__()

        self.vocabulary = vocabulary

        self._vocab_size = num_classes
        self._num_classes = num_classes + 1  # add 1 for blank symbol

        if experimental_fuse_loss_wer is not None:
            # TODO: Deprecate in 1.6
            logging.warning(
                "`experimental_fuse_loss_wer` will be deprecated in NeMo 1.6. Please use `fuse_loss_wer` instead."
            )
            # Override fuse_loss_wer from deprecated argument
            fuse_loss_wer = experimental_fuse_loss_wer

        self._fuse_loss_wer = fuse_loss_wer
        self._fused_batch_size = fused_batch_size

        if fuse_loss_wer and (fused_batch_size is None):
            raise ValueError("If `fuse_loss_wer` is set, then `fused_batch_size` cannot be None!")

        self._loss = None
        self._wer = None

        # Log softmax should be applied explicitly only for CPU
        self.log_softmax = log_softmax
        self.preserve_memory = preserve_memory

        if preserve_memory:
            logging.warning(
                "`preserve_memory` was set for the Joint Model. Please be aware this will severely impact "
                "the forward-backward step time. It also might not solve OOM issues if the GPU simply "
                "does not have enough memory to compute the joint."
            )

        # Required arguments
        self.encoder_hidden = jointnet['encoder_hidden']
        self.pred_hidden = jointnet['pred_hidden']
        self.joint_hidden = jointnet['joint_hidden']
        self.activation = jointnet['activation']

        # Optional arguments
        dropout = jointnet.get('dropout', 0.0)

        self.pred, self.enc, self.joint_net = self._joint_net_modules(
            num_classes=self._num_classes,  # add 1 for blank symbol
            pred_n_hidden=self.pred_hidden,
            enc_n_hidden=self.encoder_hidden,
            joint_n_hidden=self.joint_hidden,
            activation=self.activation,
            dropout=dropout,
        )

        # Flag needed for RNNT export support
        self._rnnt_export = False
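
For context on how the modules built by `_joint_net_modules` are used, an RNNT joint is typically computed in `forward()` as a broadcast sum of the projected encoder and prediction-network outputs. A minimal sketch (tensor names and shapes are assumptions, not the actual NeMo code):

import torch

def joint_step(enc, pred, joint_net):
    # enc:  (B, T, H) projected encoder output
    # pred: (B, U, H) projected prediction-network output
    inp = enc.unsqueeze(2) + pred.unsqueeze(1)  # broadcast to (B, T, U, H)
    return joint_net(inp)  # (B, T, U, num_classes) logits over vocab + blank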
Example 2
def check_resume(
    trainer: 'pytorch_lightning.Trainer',
    log_dir: str,
    resume_past_end: bool = False,
    resume_ignore_no_checkpoint: bool = False,
):
    """Checks that resume=True was used correctly with the arguments pass to exp_manager. Sets
    trainer.resume_from_checkpoint as necessary.

    Returns:
        log_dir (Path): the log_dir
        exp_dir (str): the base exp_dir without name nor version
        name (str): The name of the experiment
        version (str): The version of the experiment

    Raises:
        NotFoundError: If resume is True, resume_ignore_no_checkpoint is False, and checkpoints could not be found.
        ValueError: If resume is True, and there were more than 1 checkpoint could found.
    """
    if not log_dir:
        raise ValueError(
            f"Resuming requires the log_dir {log_dir} to be passed to exp_manager"
        )

    checkpoint_dir = Path(Path(log_dir) / "checkpoints")
    checkpoint = None
    end_checkpoints = list(checkpoint_dir.glob("*end.ckpt"))
    end_checkpoints.extend(list(checkpoint_dir.glob("*.nemo")))
    last_checkpoints = list(checkpoint_dir.glob("*last.ckpt"))
    if not checkpoint_dir.exists():
        if resume_ignore_no_checkpoint:
            logging.warning(
                f"There was no checkpoint folder at checkpoint_dir :{checkpoint_dir}. Training from scratch."
            )
            return
        else:
            raise NotFoundError(
                f"There was no checkpoint folder at checkpoint_dir :{checkpoint_dir}. Cannot resume."
            )
    elif len(end_checkpoints) > 0:
        if resume_past_end:
            if len(end_checkpoints) > 1:
                raise ValueError(
                    f"Multiple checkpoints {end_checkpoints} match *end.ckpt."
                )
            logging.info(f"Resuming from {end_checkpoints[0]}")
            checkpoint = end_checkpoints[0]
        else:
            raise ValueError(
                f"Found {end_checkpoints[0]} indicating that the last training run has already completed."
            )
    elif not len(last_checkpoints) > 0:
        if resume_ignore_no_checkpoint:
            logging.warning(
                f"There were no checkpoints found in {checkpoint_dir}. Training from scratch."
            )
            return
        else:
            raise NotFoundError(
                f"There were no checkpoints found in {checkpoint_dir}. Cannot resume."
            )
    elif len(last_checkpoints) > 1:
        raise ValueError(
            f"Multiple checkpoints {last_checkpoints} match *last.ckpt."
        )
    else:
        logging.info(f"Resuming from {last_checkpoints[0]}")
        checkpoint = last_checkpoints[0]

    trainer.resume_from_checkpoint = str(checkpoint)

    if is_global_rank_zero():
        # Check to see if any files exist that need to be moved
        files_to_move = []
        for child in Path(log_dir).iterdir():
            if child.is_file():
                files_to_move.append(child)

        if len(files_to_move) > 0:
            # Move old files to a new folder
            other_run_dirs = Path(log_dir).glob("run_*")
            run_count = 0
            for fold in other_run_dirs:
                if fold.is_dir():
                    run_count += 1
            new_run_dir = Path(Path(log_dir) / f"run_{run_count}")
            new_run_dir.mkdir()
            for _file in files_to_move:
                move(str(_file), str(new_run_dir))
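
A hedged usage sketch for `check_resume` (the path and trainer settings are illustrative assumptions):

import pytorch_lightning as pl

trainer = pl.Trainer(max_epochs=10)
check_resume(
    trainer,
    log_dir="/results/my_exp/2021-01-01_00-00-00",  # assumed experiment directory
    resume_past_end=False,
    resume_ignore_no_checkpoint=True,  # fall back to training from scratch
)
# If a *last.ckpt was found, trainer.resume_from_checkpoint now points at it.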
Example 3
def get_final_text(pred_text: str,
                   orig_text: str,
                   do_lower_case: bool,
                   verbose_logging: bool = False):
    """Project the tokenized prediction back to the original text.
    When we created the data, we kept track of the alignment between original
    (whitespace tokenized) tokens and our WordPiece tokenized tokens. So
    now `orig_text` contains the span of our original text corresponding to
    the span that we predicted.

    However, `orig_text` may contain extra characters that we don't want in
    our prediction.

    For example, let's say:
      pred_text = steve smith
      orig_text = Steve Smith's

    We don't want to return `orig_text` because it contains the extra "'s".

    We don't want to return `pred_text` because it's already been normalized
    (the SQuAD eval script also does punctuation stripping/lower casing but
    our tokenizer does additional normalization like stripping accent
    characters).

    What we really want to return is "Steve Smith".

    Therefore, we have to apply a semi-complicated alignment heuristic
    between `pred_text` and `orig_text` to get a character-to-character
    alignment. This can fail in certain cases in which case we just return
    `orig_text`."""
    def _strip_spaces(text):
        ns_chars = []
        ns_to_s_map = collections.OrderedDict()
        for (i, c) in enumerate(text):
            if c == " ":
                continue
            ns_to_s_map[len(ns_chars)] = i
            ns_chars.append(c)
        ns_text = "".join(ns_chars)
        return ns_text, ns_to_s_map

    # We first tokenize `orig_text`, strip whitespace from the result
    # and `pred_text`, and check if they are the same length. If they are
    # NOT the same length, the heuristic has failed. If they are the same
    # length, we assume the characters are one-to-one aligned.
    tokenizer = BasicTokenizer(do_lower_case=do_lower_case)

    tok_text = " ".join(tokenizer.tokenize(orig_text))

    start_position = tok_text.find(pred_text)
    if start_position == -1:
        if verbose_logging:
            logging.warning("Unable to find text: '%s' in '%s'" %
                            (pred_text, orig_text))
        return orig_text
    end_position = start_position + len(pred_text) - 1

    (orig_ns_text, orig_ns_to_s_map) = _strip_spaces(orig_text)
    (tok_ns_text, tok_ns_to_s_map) = _strip_spaces(tok_text)

    if len(orig_ns_text) != len(tok_ns_text):
        if verbose_logging:
            logging.warning(
                "Length not equal after stripping spaces: '%s' vs '%s'",
                orig_ns_text,
                tok_ns_text,
            )
        return orig_text

    # We then project the characters in `pred_text` back to `orig_text` using
    # the character-to-character alignment.
    tok_s_to_ns_map = {}
    for (i, tok_index) in tok_ns_to_s_map.items():
        tok_s_to_ns_map[tok_index] = i

    orig_start_position = None
    if start_position in tok_s_to_ns_map:
        ns_start_position = tok_s_to_ns_map[start_position]
        if ns_start_position in orig_ns_to_s_map:
            orig_start_position = orig_ns_to_s_map[ns_start_position]

    if orig_start_position is None:
        if verbose_logging:
            logging.warning("Couldn't map start position")
        return orig_text

    orig_end_position = None
    if end_position in tok_s_to_ns_map:
        ns_end_position = tok_s_to_ns_map[end_position]
        if ns_end_position in orig_ns_to_s_map:
            orig_end_position = orig_ns_to_s_map[ns_end_position]

    if orig_end_position is None:
        if verbose_logging:
            logging.warning("Couldn't map end position")
        return orig_text

    output_text = orig_text[orig_start_position:(orig_end_position + 1)]
    return output_text
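
The "Steve Smith" case from the docstring can be traced by hand through the space-stripping map. The sketch below copies the nested `_strip_spaces` helper out of the function purely for illustration:

import collections

def _strip_spaces(text):
    ns_chars = []
    ns_to_s_map = collections.OrderedDict()
    for i, c in enumerate(text):
        if c == " ":
            continue
        ns_to_s_map[len(ns_chars)] = i
        ns_chars.append(c)
    return "".join(ns_chars), ns_to_s_map

ns_text, ns_map = _strip_spaces("Steve Smith's")
print(ns_text)    # "SteveSmith's"
print(ns_map[5])  # 6 -- stripped index 5 ('S' of "Smith") maps back to original index 6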
Example 4
    def _setup_dataloader_from_config(self, config: Optional[Dict]):
        if 'augmentor' in config:
            augmentor = process_augmentations(config['augmentor'])
        else:
            augmentor = None

        shuffle = config['shuffle']
        device = 'gpu' if torch.cuda.is_available() else 'cpu'
        if config.get('use_dali', False):
            device_id = self.local_rank if device == 'gpu' else None
            dataset = audio_to_text_dataset.get_dali_char_dataset(
                config=config,
                shuffle=shuffle,
                device_id=device_id,
                global_rank=self.global_rank,
                world_size=self.world_size,
                preprocessor_cfg=self._cfg.preprocessor,
            )
            return dataset

        # Instantiate tarred dataset loader or normal dataset loader
        if config.get('is_tarred', False):
            if ('tarred_audio_filepaths' in config
                    and config['tarred_audio_filepaths'] is None) or (
                        'manifest_filepath' in config
                        and config['manifest_filepath'] is None):
                logging.warning(
                    "Could not load dataset as `manifest_filepath` was None or "
                    f"`tarred_audio_filepaths` is None. Provided config : {config}"
                )
                return None

            shuffle_n = config.get('shuffle_n', 4 *
                                   config['batch_size']) if shuffle else 0
            dataset = audio_to_text_dataset.get_tarred_char_dataset(
                config=config,
                shuffle_n=shuffle_n,
                global_rank=self.global_rank,
                world_size=self.world_size,
                augmentor=augmentor,
            )
            shuffle = False
        else:
            if 'manifest_filepath' in config and config[
                    'manifest_filepath'] is None:
                logging.warning(
                    f"Could not load dataset as `manifest_filepath` was None. Provided config : {config}"
                )
                return None

            dataset = audio_to_text_dataset.get_char_dataset(
                config=config, augmentor=augmentor)

        return torch.utils.data.DataLoader(
            dataset=dataset,
            batch_size=config['batch_size'],
            collate_fn=dataset.collate_fn,
            drop_last=config.get('drop_last', False),
            shuffle=shuffle,
            num_workers=config.get('num_workers', 0),
            pin_memory=config.get('pin_memory', False),
        )
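
A minimal sketch of a config this method can consume; the key names are taken from the code above, while the paths and values are illustrative:

train_config = {
    'manifest_filepath': '/data/train_manifest.json',  # assumed path
    'sample_rate': 16000,
    'labels': [' ', 'a', 'b', 'c'],  # truncated character vocabulary
    'batch_size': 32,
    'shuffle': True,      # read unconditionally above, so it must be present
    'is_tarred': False,   # set True (plus tarred_audio_filepaths) for WebDataset shards
    'num_workers': 4,
    'pin_memory': True,
}
# loader = model._setup_dataloader_from_config(config=train_config)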
Example 5
def perform_clustering(embs_and_timestamps, AUDIO_RTTM_MAP, out_rttm_dir,
                       clustering_params):
    """
    Performs spectral clustering on embeddings, with time stamps generated from VAD output.

    Args:
        embs_and_timestamps (dict): This dictionary contains the following items indexed by unique IDs.
            'embeddings' : Embeddings with key as unique_id
            'time_stamps' : Time stamps list for each audio recording
        AUDIO_RTTM_MAP (dict): AUDIO_RTTM_MAP for mapping unique id with audio file path and rttm path
        out_rttm_dir (str): Path to write predicted rttms
        clustering_params (dict): clustering parameters provided through config, containing max_num_speakers (int),
            oracle_num_speakers (bool), max_rp_threshold (float), sparse_search_volume (int) and enhanced_count_thres (int)

    Returns:
        all_reference (list[uniq_name,Annotation]): reference annotations for score calculation
        all_hypothesis (list[uniq_name,Annotation]): hypothesis annotations for score calculation

    """
    all_hypothesis = []
    all_reference = []
    no_references = False
    max_num_speakers = clustering_params['max_num_speakers']
    lines_cluster_labels = []

    cuda = True
    if not torch.cuda.is_available():
        logging.warning(
            "cuda=False, using CPU for Eigen decompostion. This might slow down the clustering process."
        )
        cuda = False

    for uniq_id, value in tqdm(AUDIO_RTTM_MAP.items()):
        if clustering_params.oracle_num_speakers:
            num_speakers = value.get('num_speakers', None)
            if num_speakers is None:
                raise ValueError(
                    "`oracle_num_speakers` was set, but `num_speakers` is missing from the manifest entry"
                )
        else:
            num_speakers = None

        cluster_labels = COSclustering(
            uniq_embs_and_timestamps=embs_and_timestamps[uniq_id],
            oracle_num_speakers=num_speakers,
            max_num_speaker=max_num_speakers,
            enhanced_count_thres=clustering_params.enhanced_count_thres,
            max_rp_threshold=clustering_params.max_rp_threshold,
            sparse_search_volume=clustering_params.sparse_search_volume,
            cuda=cuda,
        )

        base_scale_idx = max(embs_and_timestamps[uniq_id]['scale_dict'].keys())
        lines = embs_and_timestamps[uniq_id]['scale_dict'][base_scale_idx][
            'time_stamps']
        assert len(cluster_labels) == len(lines)
        for idx, label in enumerate(cluster_labels):
            tag = 'speaker_' + str(label)
            lines[idx] += tag

        a = get_contiguous_stamps(lines)
        labels = merge_stamps(a)
        if out_rttm_dir:
            labels_to_rttmfile(labels, uniq_id, out_rttm_dir)
            lines_cluster_labels.extend(
                [f'{uniq_id} {seg_line}\n' for seg_line in lines])
        hypothesis = labels_to_pyannote_object(labels, uniq_name=uniq_id)
        all_hypothesis.append([uniq_id, hypothesis])

        rttm_file = value.get('rttm_filepath', None)
        if rttm_file is not None and os.path.exists(
                rttm_file) and not no_references:
            ref_labels = rttm_to_labels(rttm_file)
            reference = labels_to_pyannote_object(ref_labels,
                                                  uniq_name=uniq_id)
            all_reference.append([uniq_id, reference])
        else:
            no_references = True
            all_reference = []

    if out_rttm_dir:
        write_cluster_labels(base_scale_idx, lines_cluster_labels,
                             out_rttm_dir)

    return all_reference, all_hypothesis
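
The paired (uniq_id, Annotation) lists are shaped for downstream scoring. A hedged sketch using pyannote.metrics, assuming it is installed and reference RTTMs were available for every file:

from pyannote.metrics.diarization import DiarizationErrorRate

all_reference, all_hypothesis = perform_clustering(
    embs_and_timestamps, AUDIO_RTTM_MAP, out_rttm_dir, clustering_params
)

metric = DiarizationErrorRate(collar=0.25)
for (uniq_id, reference), (_, hypothesis) in zip(all_reference, all_hypothesis):
    metric(reference, hypothesis)  # accumulates per-file error components
print(f"DER: {abs(metric):.4f}")  # aggregate diarization error rate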
Example 6
    def __init__(
        self,
        ids: List[int],
        audio_files: List[str],
        durations: List[float],
        texts: List[str],
        offsets: List[str],
        speakers: List[Optional[int]],
        orig_sampling_rates: List[Optional[int]],
        parser: parsers.CharParser,
        min_duration: Optional[float] = None,
        max_duration: Optional[float] = None,
        max_number: Optional[int] = None,
        do_sort_by_duration: bool = False,
        index_by_file_id: bool = False,
    ):
        """Instantiates audio-text manifest with filters and preprocessing.

        Args:
            ids: List of example positions.
            audio_files: List of audio files.
            durations: List of float durations.
            texts: List of raw text transcripts.
            offsets: List of duration offsets or None.
            speakers: List of optional speakers ids.
            orig_sampling_rates: List of original sampling rates of audio files.
            parser: Instance of `CharParser` to convert string to tokens.
            min_duration: Minimum duration to keep entry with (default: None).
            max_duration: Maximum duration to keep entry with (default: None).
            max_number: Maximum number of samples to collect.
            do_sort_by_duration: If True, sort the samples list by duration. Not compatible with index_by_file_id.
            index_by_file_id: If True, saves a mapping from filename base (ID) to index in data.
        """

        output_type = self.OUTPUT_TYPE
        data, duration_filtered, num_filtered, total_duration = [], 0.0, 0, 0.0
        if index_by_file_id:
            self.mapping = {}

        for id_, audio_file, duration, offset, text, speaker, orig_sr in zip(
                ids, audio_files, durations, offsets, texts, speakers,
                orig_sampling_rates):
            # Duration filters.
            if min_duration is not None and duration < min_duration:
                duration_filtered += duration
                num_filtered += 1
                continue

            if max_duration is not None and duration > max_duration:
                duration_filtered += duration
                num_filtered += 1
                continue

            text_tokens = parser(text)
            if text_tokens is None:
                duration_filtered += duration
                num_filtered += 1
                continue

            total_duration += duration

            data.append(
                output_type(id_, audio_file, duration, text_tokens, offset,
                            text, speaker, orig_sr))
            if index_by_file_id:
                file_id, _ = os.path.splitext(os.path.basename(audio_file))
                self.mapping[file_id] = len(data) - 1

            # Max number of entities filter.
            if len(data) == max_number:
                break

        if do_sort_by_duration:
            if index_by_file_id:
                logging.warning(
                    "Tried to sort dataset by duration, but cannot since index_by_file_id is set."
                )
            else:
                data.sort(key=lambda entity: entity.duration)

        logging.info("Dataset loaded with %d files totalling %.2f hours",
                     len(data), total_duration / 3600)
        logging.info("%d files were filtered totalling %.2f hours",
                     num_filtered, duration_filtered / 3600)

        super().__init__(data)
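
The duration gate in the loop above is easy to isolate; a small self-contained helper that mirrors its logic:

def keep_by_duration(duration, min_duration=None, max_duration=None):
    # Mirrors the two `continue` branches above: reject entries that are
    # shorter than min_duration or longer than max_duration.
    if min_duration is not None and duration < min_duration:
        return False
    if max_duration is not None and duration > max_duration:
        return False
    return True

assert keep_by_duration(5.0, min_duration=0.1, max_duration=16.7)
assert not keep_by_duration(20.0, max_duration=16.7)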
Example 7
    def export(
        self,
        output: str,
        input_example=None,
        output_example=None,
        verbose=False,
        export_params=True,
        do_constant_folding=True,
        keep_initializers_as_inputs=False,
        onnx_opset_version: int = 12,
        try_script: bool = False,
        set_eval: bool = True,
        check_trace: bool = True,
        use_dynamic_axes: bool = True,
    ):
        if input_example is not None or output_example is not None:
            logging.warning(
                "Passed input and output examples will be ignored and recomputed since"
                " QAModel consists of two separate models with different"
                " inputs and outputs.")

        qual_name = self.__module__ + '.' + self.__class__.__qualname__
        output1 = os.path.join(os.path.dirname(output),
                               'bert_' + os.path.basename(output))
        output1_descr = qual_name + ' BERT exported to ONNX'
        bert_model_onnx = self.bert_model.export(
            output1,
            None,  # computed by input_example()
            None,
            verbose,
            export_params,
            do_constant_folding,
            keep_initializers_as_inputs,
            onnx_opset_version,
            try_script,
            set_eval,
            check_trace,
            use_dynamic_axes,
        )

        output2 = os.path.join(os.path.dirname(output),
                               'classifier_' + os.path.basename(output))
        output2_descr = qual_name + ' Classifier exported to ONNX'
        classifier_onnx = self.classifier.export(
            output2,
            None,  # computed by input_example()
            None,
            verbose,
            export_params,
            do_constant_folding,
            keep_initializers_as_inputs,
            onnx_opset_version,
            try_script,
            set_eval,
            check_trace,
            use_dynamic_axes,
        )

        output_model = attach_onnx_to_onnx(bert_model_onnx, classifier_onnx,
                                           "QA")
        output_descr = qual_name + ' BERT+Classifier exported to ONNX'
        onnx.save(output_model, output)
        return ([output, output1,
                 output2], [output_descr, output1_descr, output2_descr])
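
A hedged usage sketch: exporting to "qa.onnx" writes the fused model plus the two component models next to it, using the filename prefixes seen above (assuming `qa_model` is a trained QAModel instance):

paths, descriptions = qa_model.export("qa.onnx")
# paths == ['qa.onnx', 'bert_qa.onnx', 'classifier_qa.onnx']
for path, descr in zip(paths, descriptions):
    print(path, '-', descr)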
Example 8
    def __init__(
        self,
        *,
        audio_tar_filepaths: Union[str, List[str]],
        manifest_filepath: str,
        labels: List[str],
        featurizer,
        shuffle_n: int = 0,
        min_duration: Optional[float] = 0.1,
        max_duration: Optional[float] = None,
        trim: bool = False,
        load_audio: bool = True,
        shard_strategy: str = "scatter",
        global_rank: int = 0,
        world_size: int = 0,
    ):
        self.collection = collections.ASRSpeechLabel(
            manifests_files=manifest_filepath.split(','),
            min_duration=min_duration,
            max_duration=max_duration,
            index_by_file_id=True,  # Must set this so the manifest lines can be indexed by file ID
        )

        self.file_occurence = count_occurence(self.collection.mapping)

        self.featurizer = featurizer
        self.trim = trim
        self.load_audio = load_audio

        self.labels = labels if labels else self.collection.uniq_labels
        self.num_classes = len(self.labels)

        self.label2id, self.id2label = {}, {}
        for label_id, label in enumerate(self.labels):
            self.label2id[label] = label_id
            self.id2label[label_id] = label

        for idx in range(len(self.labels[:5])):
            logging.debug(" label id {} and its mapped label {}".format(
                idx, self.id2label[idx]))

        valid_shard_strategies = ['scatter', 'replicate']
        if shard_strategy not in valid_shard_strategies:
            raise ValueError(
                f"`shard_strategy` must be one of {valid_shard_strategies}")

        if isinstance(audio_tar_filepaths, str):
            # Replace '(', '[', '<' and '_OP_' with '{'
            brace_keys_open = ['(', '[', '<', '_OP_']
            for bkey in brace_keys_open:
                if bkey in audio_tar_filepaths:
                    audio_tar_filepaths = audio_tar_filepaths.replace(
                        bkey, "{")

            # Replace ')', ']', '>' and '_CL_' with '}'
            brace_keys_close = [')', ']', '>', '_CL_']
            for bkey in brace_keys_close:
                if bkey in audio_tar_filepaths:
                    audio_tar_filepaths = audio_tar_filepaths.replace(
                        bkey, "}")

        # Check for distributed and partition shards accordingly
        if world_size > 1:
            if isinstance(audio_tar_filepaths, str):
                # Brace expand
                audio_tar_filepaths = list(
                    braceexpand.braceexpand(audio_tar_filepaths))

            if shard_strategy == 'scatter':
                logging.info(
                    "All tarred dataset shards will be scattered evenly across all nodes."
                )

                if len(audio_tar_filepaths) % world_size != 0:
                    logging.warning(
                        f"Number of shards in tarred dataset ({len(audio_tar_filepaths)}) is not divisible "
                        f"by number of distributed workers ({world_size}).")

                begin_idx = (len(audio_tar_filepaths) //
                             world_size) * global_rank
                end_idx = begin_idx + (len(audio_tar_filepaths) // world_size)
                audio_tar_filepaths = audio_tar_filepaths[begin_idx:end_idx]
                logging.info(
                    "Partitioning tarred dataset: process (%d) taking shards [%d, %d)",
                    global_rank, begin_idx, end_idx)

            elif shard_strategy == 'replicate':
                logging.info(
                    "All tarred dataset shards will be replicated across all nodes."
                )

            else:
                raise ValueError(
                    f"Invalid shard strategy! Allowed values are: {valid_shard_strategies}"
                )

        # Put together WebDataset
        self._dataset = (
            wd.Dataset(audio_tar_filepaths).shuffle(shuffle_n).rename(
                audio='wav', key='__key__').to_tuple('audio', 'key').pipe(
                    self._filter).map(f=self._build_sample))
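
The `scatter` strategy above reduces to taking one contiguous slice of the shard list per rank; a minimal standalone sketch:

def scatter_shards(shards, global_rank, world_size):
    # Each rank takes len(shards) // world_size consecutive shards; trailing
    # shards are silently dropped when the count is not evenly divisible.
    per_rank = len(shards) // world_size
    begin = per_rank * global_rank
    return shards[begin:begin + per_rank]

shards = [f"audio_{i}.tar" for i in range(8)]
assert scatter_shards(shards, global_rank=1, world_size=4) == ['audio_2.tar', 'audio_3.tar']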
Example 9
def main():
    parser = argparse.ArgumentParser(
        description='Evaluate an ASR model with beam search decoding and n-gram KenLM language model.'
    )
    parser.add_argument(
        "--nemo_model_file", required=True, type=str, help="The path of the '.nemo' file of the ASR model"
    )
    parser.add_argument(
        "--kenlm_model_file", required=False, default=None, type=str, help="The path of the KenLM binary model file"
    )
    parser.add_argument("--input_manifest", required=True, type=str, help="The manifest file of the evaluation set")
    parser.add_argument(
        "--preds_output_folder", default=None, type=str, help="The optional folder where the predictions are stored"
    )
    parser.add_argument(
        "--probs_cache_file", default=None, type=str, help="The cache file for storing the outputs of the model"
    )
    parser.add_argument(
        "--acoustic_batch_size", default=16, type=int, help="The batch size to calculate log probabilities"
    )
    parser.add_argument(
        "--device", default="cuda", type=str, help="The device to load the model onto to calculate log probabilities"
    )
    parser.add_argument(
        "--use_amp", action="store_true", help="Whether to use AMP if available to calculate log probabilities"
    )
    parser.add_argument(
        "--decoding_mode",
        choices=["greedy", "beamsearch", "beamsearch_ngram"],
        default="beamsearch_ngram",
        type=str,
        help="The decoding scheme to be used for evaluation.",
    )
    parser.add_argument(
        "--beam_width",
        required=True,
        type=int,
        nargs="+",
        help="The width or list of the widths for the beam search decoding",
    )
    parser.add_argument(
        "--beam_alpha",
        required=True,
        type=float,
        nargs="+",
        help="The alpha parameter or list of the alphas for the beam search decoding",
    )
    parser.add_argument(
        "--beam_beta",
        required=True,
        type=float,
        nargs="+",
        help="The beta parameter or list of the betas for the beam search decoding",
    )
    parser.add_argument(
        "--beam_batch_size", default=128, type=int, help="The batch size to be used for beam search decoding"
    )
    args = parser.parse_args()

    asr_model = nemo_asr.models.EncDecCTCModelBPE.restore_from(
        args.nemo_model_file, map_location=torch.device(args.device)
    )

    target_transcripts = []
    with open(args.input_manifest, 'r') as manifest_file:
        audio_file_paths = []
        for line in tqdm(manifest_file, desc=f"Reading Manifest {args.input_manifest} ...", ncols=120):
            data = json.loads(line)
            target_transcripts.append(data['text'])
            audio_file_paths.append(data['audio_filepath'])

    if args.probs_cache_file and os.path.exists(args.probs_cache_file):
        logging.info(f"Found a pickle file of probabilities at '{args.probs_cache_file}'.")
        logging.info(f"Loading the cached pickle file of probabilities from '{args.probs_cache_file}' ...")
        with open(args.probs_cache_file, 'rb') as probs_file:
            all_probs = pickle.load(probs_file)

        if len(all_probs) != len(audio_file_paths):
            raise ValueError(
                f"The number of samples in the probabilities file '{args.probs_cache_file}' does not "
                f"match the manifest file. You may need to delete the probabilities cached file."
            )
    else:
        if args.use_amp and torch.cuda.is_available() and hasattr(torch.cuda, 'amp') and hasattr(torch.cuda.amp, 'autocast'):
            logging.info("AMP is enabled!\n")
            autocast = torch.cuda.amp.autocast
        else:
            # Fall back to a no-op context manager so `autocast` is always
            # defined, even when AMP was requested but is unavailable.
            @contextlib.contextmanager
            def autocast():
                yield

        with autocast():
            with torch.no_grad():
                all_logits = asr_model.transcribe(audio_file_paths, batch_size=args.acoustic_batch_size, logprobs=True)
        all_probs = [kenlm_utils.softmax(logits) for logits in all_logits]
        if args.probs_cache_file:
            logging.info(f"Writing pickle files of probabilities at '{args.probs_cache_file}'...")
            with open(args.probs_cache_file, 'wb') as f_dump:
                pickle.dump(all_probs, f_dump)

    wer_dist_greedy = 0
    cer_dist_greedy = 0
    words_count = 0
    chars_count = 0
    for batch_idx, probs in enumerate(all_probs):
        preds = np.argmax(probs, axis=1)
        preds_tensor = torch.tensor(preds, device='cpu').unsqueeze(0)
        pred_text = asr_model._wer.ctc_decoder_predictions_tensor(preds_tensor)[0]

        pred_split_w = pred_text.split()
        target_split_w = target_transcripts[batch_idx].split()
        pred_split_c = list(pred_text)
        target_split_c = list(target_transcripts[batch_idx])

        wer_dist = editdistance.eval(target_split_w, pred_split_w)
        cer_dist = editdistance.eval(target_split_c, pred_split_c)

        wer_dist_greedy += wer_dist
        cer_dist_greedy += cer_dist
        words_count += len(target_split_w)
        chars_count += len(target_split_c)

    logging.info('Greedy WER/CER = {:.2%}/{:.2%}'.format(wer_dist_greedy / words_count, cer_dist_greedy / chars_count))

    encoding_level = kenlm_utils.SUPPORTED_MODELS.get(type(asr_model).__name__, None)
    if not encoding_level:
        logging.warning(
            f"Model type '{type(asr_model).__name__}' may not be supported. Would try to train a char-level LM."
        )
        encoding_level = 'char'

    vocab = asr_model.decoder.vocabulary
    ids_to_text_func = None
    if encoding_level == "subword":
        vocab = [chr(idx + TOKEN_OFFSET) for idx in range(len(vocab))]
        ids_to_text_func = asr_model.tokenizer.ids_to_text
    # delete the model to free the memory
    del asr_model

    if args.decoding_mode == "beamsearch_ngram":
        if not os.path.exists(args.kenlm_model_file):
            raise FileNotFoundError(f"Could not find the KenLM model file '{args.kenlm_model_file}'.")
        lm_path = args.kenlm_model_file
    else:
        lm_path = None

    # 'greedy' decoding_mode would skip the beam search decoding
    if args.decoding_mode in ["beamsearch_ngram", "beamsearch"]:

        params = {'beam_width': args.beam_width, 'beam_alpha': args.beam_alpha, 'beam_beta': args.beam_beta}
        hp_grid = ParameterGrid(params)
        hp_grid = list(hp_grid)

        logging.info(f"==============================Starting the beam search decoding===============================")
        logging.info(f"Grid search size: {len(hp_grid)}")
        logging.info(f"It may take some time...")
        logging.info(f"==============================================================================================")

        if args.preds_output_folder and not os.path.exists(args.preds_output_folder):
            os.mkdir(args.preds_output_folder)
        for hp in hp_grid:
            if args.preds_output_folder:
                preds_output_file = os.path.join(
                    args.preds_output_folder,
                    f"preds_out_width{hp['beam_width']}_alpha{hp['beam_alpha']}_beta{hp['beam_beta']}.tsv",
                )
            else:
                preds_output_file = None

            beam_search_eval(
                all_probs=all_probs,
                target_transcripts=target_transcripts,
                vocab=vocab,
                ids_to_text_func=ids_to_text_func,
                preds_output_file=preds_output_file,
                lm_path=lm_path,
                beam_width=hp["beam_width"],
                beam_alpha=hp["beam_alpha"],
                beam_beta=hp["beam_beta"],
                beam_batch_size=args.beam_batch_size,
                progress_bar=True,
            )
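
Passing lists for --beam_width, --beam_alpha and --beam_beta turns the evaluation into a grid search, because ParameterGrid expands the cross product of all supplied values, e.g.:

from sklearn.model_selection import ParameterGrid

grid = list(ParameterGrid({'beam_width': [64, 128], 'beam_alpha': [1.0], 'beam_beta': [0.5, 1.5]}))
print(len(grid))  # 4 combinations
print(grid[0])    # e.g. {'beam_alpha': 1.0, 'beam_beta': 0.5, 'beam_width': 64}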
Example 10
    def setup_optimization(self,
                           optim_config: Optional[Union[DictConfig,
                                                        Dict]] = None):
        """
        Prepares an optimizer from a string name and its optional config parameters.

        Args:
            optim_config: A dictionary containing the following keys:

                * "lr": mandatory key for learning rate. Will raise ValueError if not provided.
                * "optimizer": string name pointing to one of the available optimizers in the registry. \
                If not provided, defaults to "adam".
                * "opt_args": Optional list of strings, in the format "arg_name=arg_value". \
                The list of "arg_value" will be parsed and a dictionary of optimizer kwargs \
                will be built and supplied to instantiate the optimizer.
        """
        # If config was not explicitly passed to us
        if optim_config is None:
            # See if internal config has `optim` namespace
            if self._cfg is not None and hasattr(self._cfg, 'optim'):
                optim_config = self._cfg.optim

        # If config is still None, or internal config has no Optim, return without instantiation
        if optim_config is None:
            logging.info(
                'No optimizer config provided, therefore no optimizer was created'
            )
            return

        else:
            # Preserve the configuration
            if not isinstance(optim_config, DictConfig):
                optim_config = OmegaConf.create(optim_config)

            # See if internal config has `optim` namespace before preservation
            if self._cfg is not None and hasattr(self._cfg, 'optim'):
                self._cfg.optim = optim_config

        # Setup optimizer and scheduler
        if optim_config is not None and isinstance(optim_config, DictConfig):
            optim_config = OmegaConf.to_container(optim_config)

        if 'sched' in optim_config and self._trainer is not None:
            if not isinstance(self._trainer.accumulate_grad_batches, int):
                raise ValueError(
                    "We do not currently support gradient accumulation that is not an integer."
                )
            if self._trainer.max_steps is None:
                # Store information needed to calculate max_steps
                optim_config['sched'][
                    't_max_epochs'] = self._trainer.max_epochs
                optim_config['sched'][
                    't_accumulate_grad_batches'] = self._trainer.accumulate_grad_batches
                if self._trainer.distributed_backend is None:
                    optim_config['sched'][
                        't_num_workers'] = self._trainer.num_gpus or 1
                elif self._trainer.distributed_backend is "ddp_cpu":
                    optim_config['sched'][
                        't_num_workers'] = self._trainer.num_processes * self._trainer.num_nodes
                elif self._trainer.distributed_backend is "ddp":
                    optim_config['sched'][
                        't_num_workers'] = self._trainer.num_gpus * self._trainer.num_nodes
                else:
                    logging.warning(
                        f"The lightning trainer received accelerator: {self._trainer.distributed_backend}. We "
                        "recommend using 'ddp' instead.")
                    optim_config['sched'][
                        't_num_workers'] = self._trainer.num_gpus * self._trainer.num_nodes
            else:
                optim_config['sched']['max_steps'] = self._trainer.max_steps

        # Force into DictConfig from nested structure
        optim_config = OmegaConf.create(optim_config)
        # Get back a nested dict so it's mutable
        optim_config = OmegaConf.to_container(optim_config, resolve=True)

        # Extract scheduler config if inside optimizer config
        if 'sched' in optim_config:
            scheduler_config = optim_config.pop('sched')
        else:
            scheduler_config = None

        # Check if caller provided optimizer name, default to Adam otherwise
        optimizer_cls = optim_config.get('cls', None)

        if optimizer_cls is None:
            # Try to get optimizer name for dynamic resolution, defaulting to Adam
            optimizer_name = optim_config.get('name', 'adam')
        else:
            if inspect.isclass(optimizer_cls):
                optimizer_name = optimizer_cls.__name__.lower()
            else:
                # resolve the class name (lowercase) from the class path if not provided
                optimizer_name = optimizer_cls.split(".")[-1].lower()

        # We are guaranteed to have lr since it is required by the argparser
        # But maybe user forgot to pass it to this function
        lr = optim_config.get('lr', None)

        if lr is None:
            raise ValueError(
                '`lr` must be passed to `optimizer_config` when setting up the optimization!'
            )

        # Check if caller has optimizer kwargs, default to empty dictionary
        if 'args' in optim_config:
            optimizer_args = optim_config.pop('args')
            optimizer_args = optim.parse_optimizer_args(
                optimizer_name, optimizer_args)
        else:
            optimizer_args = copy.deepcopy(optim_config)

            # Remove extra parameters from optimizer_args nest
            # Assume all other parameters are to be passed into optimizer constructor
            optimizer_args.pop('name', None)
            optimizer_args.pop('cls', None)
            optimizer_args.pop('lr', None)

        # Actually instantiate the optimizer
        if optimizer_cls is not None:
            if inspect.isclass(optimizer_cls):
                optimizer = optimizer_cls(self.parameters(),
                                          lr=lr,
                                          **optimizer_args)
                logging.info("Optimizer config = %s", str(optimizer))

                self._optimizer = optimizer

            else:
                # Attempt class path resolution
                try:
                    optimizer_cls = OmegaConf.create({'cls': optimizer_cls})
                    optimizer_config = {'lr': lr}
                    optimizer_config.update(optimizer_args)

                    optimizer_instance = hydra.utils.instantiate(
                        optimizer_cls, self.parameters(),
                        **optimizer_config)  # type: DictConfig

                    logging.info("Optimizer config = %s",
                                 str(optimizer_instance))

                    self._optimizer = optimizer_instance

                except Exception as e:
                    logging.error(
                        "Could not instantiate class path - {} with kwargs {}".
                        format(optimizer_cls, str(optimizer_config)))
                    raise e

        else:
            optimizer = optim.get_optimizer(optimizer_name)
            optimizer = optimizer(self.parameters(), lr=lr, **optimizer_args)

            logging.info("Optimizer config = %s", str(optimizer))

            self._optimizer = optimizer

        # Try to instantiate scheduler for optimizer
        self._scheduler = prepare_lr_scheduler(
            optimizer=self._optimizer,
            scheduler_config=scheduler_config,
            train_dataloader=self._train_dl)

        # Return the optimizer with/without scheduler
        # This return allows multiple optimizers or schedulers to be created
        return self._optimizer, self._scheduler
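
A minimal sketch of an optim_config accepted by this method; the keys follow the docstring and the parsing above, while the concrete values (and the scheduler name) are illustrative assumptions:

optim_config = {
    'name': 'adam',          # resolved via optim.get_optimizer when 'cls' is absent
    'lr': 1e-3,              # mandatory: a missing lr raises ValueError above
    'betas': (0.9, 0.999),   # leftover keys are forwarded as optimizer kwargs
    'weight_decay': 1e-4,
    'sched': {
        'name': 'CosineAnnealing',  # assumed scheduler name
        'warmup_steps': 1000,
        'min_lr': 1e-5,
    },
}
# optimizer, scheduler = model.setup_optimization(optim_config)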
Example 11
def configure_checkpointing(
    trainer: 'pytorch_lightning.Trainer',
    log_dir: Path,
    name: str,
    params: 'DictConfig',
):
    """ Adds ModelCheckpoint to trainer. Raises CheckpointMisconfigurationError if trainer already has a ModelCheckpoint
    callback or if trainer.weights_save_path was passed to Trainer.
    """
    for callback in trainer.callbacks:
        if isinstance(callback, ModelCheckpoint):
            raise CheckpointMisconfigurationError(
                "The pytorch lightning trainer that was passed to exp_manager contained a ModelCheckpoint "
                "and create_checkpoint_callback was set to True. Please either set create_checkpoint_callback "
                "to False, or remove ModelCheckpoint from the lightning trainer"
            )
    if Path(trainer.weights_save_path) != Path.cwd():
        raise CheckpointMisconfigurationError(
            "The pytorch lightning was passed weights_save_path. This variable is ignored by exp_manager"
        )

    # Create the callback and attach it to trainer
    if "filepath" in params:
        if params.filepath is not None:
            logging.warning(
                "filepath is deprecated. Please switch to dirpath and filename instead"
            )
            if params.dirpath is None:
                params.dirpath = Path(params.filepath).parent
            if params.filename is None:
                params.filename = Path(params.filepath).name
        with open_dict(params):
            del params["filepath"]
    if params.dirpath is None:
        params.dirpath = Path(log_dir / 'checkpoints')
    if params.filename is None:
        params.filename = f'{name}--{{{params.monitor}:.2f}}-{{epoch}}'
    if params.prefix is None:
        params.prefix = name
    NeMoModelCheckpoint.CHECKPOINT_NAME_LAST = params.filename + '-last'

    logging.debug(params.dirpath)
    logging.debug(params.filename)
    logging.debug(params.prefix)

    if "val" in params.monitor:
        if (trainer.max_epochs is not None and trainer.max_epochs != -1
                and trainer.max_epochs < trainer.check_val_every_n_epoch):
            logging.error(
                "The checkpoint callback was told to monitor a validation value but trainer.max_epochs("
                f"{trainer.max_epochs}) was less than trainer.check_val_every_n_epoch({trainer.check_val_every_n_epoch}"
                f"). It is very likely this run will fail with ModelCheckpoint(monitor='{params.monitor}') not found "
                "in the returned metrics. Please ensure that validation is run within trainer.max_epochs."
            )
        elif trainer.max_steps is not None:
            logging.warning(
                "The checkpoint callback was told to monitor a validation value and trainer's max_steps was set to "
                f"{trainer.max_steps}. Please ensure that max_steps will run for at least "
                f"{trainer.check_val_every_n_epoch} epochs to ensure that checkpointing will not error out."
            )

    checkpoint_callback = NeMoModelCheckpoint(**params)
    trainer.callbacks.append(checkpoint_callback)
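
A hedged sketch of the `params` this function expects: a DictConfig mirroring ModelCheckpoint's arguments, with the None fields filled in above:

from omegaconf import OmegaConf

params = OmegaConf.create({
    'monitor': 'val_loss',
    'save_top_k': 3,
    'dirpath': None,    # filled in above as <log_dir>/checkpoints
    'filename': None,   # filled in above as '<name>--{val_loss:.2f}-{epoch}'
    'prefix': None,     # filled in above as the experiment name
})
# configure_checkpointing(trainer, log_dir, name='my_exp', params=params)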
Example 12
    def __init__(self, cfg: DictConfig, trainer: Trainer = None):
        """
        Base class from which all NeMo models should inherit

        Args:
            cfg (DictConfig):  configuration object.
                The cfg object should have (optionally) the following sub-configs:

                * train_ds - to instantiate training dataset
                * validation_ds - to instantiate validation dataset
                * test_ds - to instantiate testing dataset
                * optim - to instantiate optimizer with learning rate scheduler

            trainer (Optional): Pytorch Lightning Trainer instance
        """
        if not isinstance(cfg, DictConfig):
            raise ValueError(
                f"cfg constructor argument must be of type DictConfig but got {type(cfg)} instead."
            )
        if trainer is not None and not isinstance(trainer, Trainer):
            raise ValueError(
                f"trainer constructor argument must be either None or pytroch_lightning.Trainer. But got {type(trainer)} instead."
            )
        super().__init__()
        if 'target' not in cfg:
            # This is for Jarvis service.
            OmegaConf.set_struct(cfg, False)
            cfg.target = "{0}.{1}".format(self.__class__.__module__,
                                          self.__class__.__name__)
            OmegaConf.set_struct(cfg, True)

        config = OmegaConf.to_container(cfg, resolve=True)
        config = OmegaConf.create(config)
        OmegaConf.set_struct(config, True)

        self._cfg = config

        self.save_hyperparameters(self._cfg)
        self._train_dl = None
        self._validation_dl = None
        self._test_dl = None
        self._optimizer = None
        self._scheduler = None
        self._trainer = trainer

        # Set device_id in AppState
        if torch.cuda.is_available() and torch.cuda.current_device(
        ) is not None:
            app_state = AppState()
            app_state.device_id = torch.cuda.current_device()

        if self._cfg is not None and not self._is_model_being_restored():
            if 'train_ds' in self._cfg and self._cfg.train_ds is not None:
                self.setup_training_data(self._cfg.train_ds)

            if 'validation_ds' in self._cfg and self._cfg.validation_ds is not None:
                self.setup_multiple_validation_data(val_data_config=None)

            if 'test_ds' in self._cfg and self._cfg.test_ds is not None:
                self.setup_multiple_test_data(test_data_config=None)

        else:
            if 'train_ds' in self._cfg and self._cfg.train_ds is not None:
                logging.warning(
                    f"Please call the ModelPT.setup_training_data() method "
                    f"and provide a valid configuration file to setup the train data loader.\n"
                    f"Train config : \n{OmegaConf.to_yaml(self._cfg.train_ds)}"
                )

            if 'validation_ds' in self._cfg and self._cfg.validation_ds is not None:
                logging.warning(
                    f"Please call the ModelPT.setup_validation_data() or ModelPT.setup_multiple_validation_data() method "
                    f"and provide a valid configuration file to setup the validation data loader(s). \n"
                    f"Validation config : \n{OmegaConf.to_yaml(self._cfg.validation_ds)}"
                )

            if 'test_ds' in self._cfg and self._cfg.test_ds is not None:
                logging.warning(
                    f"Please call the ModelPT.setup_test_data() or ModelPT.setup_multiple_test_data() method "
                    f"and provide a valid configuration file to setup the test data loader(s).\n"
                    f"Test config : \n{OmegaConf.to_yaml(self._cfg.test_ds)}")

        # ModelPT wrappers over subclass implementations
        self.training_step = model_utils.wrap_training_step(self.training_step)
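
Subclasses pass a structured config to this constructor; a minimal hedged sketch of the expected shape (the sub-config keys are illustrative):

from omegaconf import OmegaConf

cfg = OmegaConf.create({
    'train_ds': {'manifest_filepath': '/data/train.json', 'batch_size': 32},
    'validation_ds': {'manifest_filepath': '/data/val.json', 'batch_size': 32},
    'test_ds': None,  # optional sub-configs may be None or omitted
    'optim': {'name': 'adam', 'lr': 1e-3},
})
# model = MyNeMoModel(cfg=cfg, trainer=trainer)  # MyNeMoModel is a hypothetical subclass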
Example 13
    def export(
        self,
        output: str,
        input_example=None,
        output_example=None,
        verbose=False,
        export_params=True,
        do_constant_folding=True,
        keep_initializers_as_inputs=False,
        onnx_opset_version: int = 12,
        try_script: bool = False,
        set_eval: bool = True,
        check_trace: bool = True,
        use_dynamic_axes: bool = True,
    ):
        """
        Unlike other models' export() this one creates 5 output files, not 3:
        punct_<output> - fused punctuation model (BERT+PunctuationClassifier)
        capit_<output> - fused capitalization model (BERT+CapitalizationClassifier)
        bert_<output> - common BERT neural net
        punct_classifier_<output> - Punctuation Classifier neural net
        capit_classifier_<output> - Capitalization Classifier neural net
        """
        if input_example is not None or output_example is not None:
            logging.warning(
                "Passed input and output examples will be ignored and recomputed since"
                " PunctuationCapitalizationModel consists of three separate models with different"
                " inputs and outputs.")

        qual_name = self.__module__ + '.' + self.__class__.__qualname__
        output1 = os.path.join(os.path.dirname(output),
                               'bert_' + os.path.basename(output))
        output1_descr = qual_name + ' BERT exported to ONNX'
        bert_model_onnx = self.bert_model.export(
            output1,
            None,  # computed by input_example()
            None,
            verbose,
            export_params,
            do_constant_folding,
            keep_initializers_as_inputs,
            onnx_opset_version,
            try_script,
            set_eval,
            check_trace,
            use_dynamic_axes,
        )

        output2 = os.path.join(os.path.dirname(output),
                               'punct_classifier_' + os.path.basename(output))
        output2_descr = qual_name + ' Punctuation Classifier exported to ONNX'
        punct_classifier_onnx = self.punct_classifier.export(
            output2,
            None,  # computed by input_example()
            None,
            verbose,
            export_params,
            do_constant_folding,
            keep_initializers_as_inputs,
            onnx_opset_version,
            try_script,
            set_eval,
            check_trace,
            use_dynamic_axes,
        )

        output3 = os.path.join(os.path.dirname(output),
                               'capit_classifier_' + os.path.basename(output))
        output3_descr = qual_name + ' Capitalization Classifier exported to ONNX'
        capit_classifier_onnx = self.capit_classifier.export(
            output3,
            None,  # computed by input_example()
            None,
            verbose,
            export_params,
            do_constant_folding,
            keep_initializers_as_inputs,
            onnx_opset_version,
            try_script,
            set_eval,
            check_trace,
            use_dynamic_axes,
        )

        punct_output_model = attach_onnx_to_onnx(bert_model_onnx,
                                                 punct_classifier_onnx, "PTCL")
        output4 = os.path.join(os.path.dirname(output),
                               'punct_' + os.path.basename(output))
        output4_descr = qual_name + ' Punctuation BERT+Classifier exported to ONNX'
        onnx.save(punct_output_model, output4)
        capit_output_model = attach_onnx_to_onnx(bert_model_onnx,
                                                 capit_classifier_onnx, "CPCL")
        output5 = os.path.join(os.path.dirname(output),
                               'capit_' + os.path.basename(output))
        output5_descr = qual_name + ' Capitalization BERT+Classifier exported to ONNX'
        onnx.save(capit_output_model, output5)
        return (
            [output1, output2, output3, output4, output5],
            [
                output1_descr, output2_descr, output3_descr, output4_descr,
                output5_descr
            ],
        )
Example 14
    def __init__(
        self,
        audio_tar_filepaths,
        manifest_filepath,
        labels,
        batch_size,
        sample_rate=16000,
        int_values=False,
        bos_id=None,
        eos_id=None,
        pad_id=None,
        min_duration=0.1,
        max_duration=None,
        normalize_transcripts=True,
        trim_silence=False,
        shuffle_n=0,
        num_workers=0,
        augmentor: Optional[Union[AudioAugmentor,
                                  Dict[str, Dict[str, Any]]]] = None,
    ):
        super().__init__()
        self._sample_rate = sample_rate

        if augmentor is not None:
            augmentor = _process_augmentations(augmentor)

        self.collection = ASRAudioText(
            manifests_files=manifest_filepath.split(','),
            parser=make_parser(labels=labels,
                               name='en',
                               do_normalize=normalize_transcripts),
            min_duration=min_duration,
            max_duration=max_duration,
            index_by_file_id=True,  # Must set this so the manifest lines can be indexed by file ID
        )

        self.featurizer = WaveformFeaturizer(sample_rate=self._sample_rate,
                                             int_values=int_values,
                                             augmentor=augmentor)

        self.trim = trim_silence
        self.eos_id = eos_id
        self.bos_id = bos_id

        # Used in creating a sampler (in Actions).
        self._batch_size = batch_size
        self._num_workers = num_workers
        pad_id = 0 if pad_id is None else pad_id
        self.collate_fn = partial(seq_collate_fn, token_pad_value=pad_id)

        # Check for distributed and partition shards accordingly
        if torch.distributed.is_initialized():
            global_rank = torch.distributed.get_rank()
            world_size = torch.distributed.get_world_size()

            if isinstance(audio_tar_filepaths, str):
                audio_tar_filepaths = list(
                    braceexpand.braceexpand(audio_tar_filepaths))

            if len(audio_tar_filepaths) % world_size != 0:
                logging.warning(
                    f"Number of shards in tarred dataset ({len(audio_tar_filepaths)}) is not divisible "
                    f"by number of distributed workers ({world_size}).")

            begin_idx = (len(audio_tar_filepaths) // world_size) * global_rank
            end_idx = begin_idx + (len(audio_tar_filepaths) // world_size)
            audio_tar_filepaths = audio_tar_filepaths[begin_idx:end_idx]

        # Put together WebDataset
        self._dataset = (
            wd.Dataset(audio_tar_filepaths).shuffle(shuffle_n).rename(
                audio='wav', key='__key__').to_tuple('audio', 'key').pipe(
                    self._filter).map(f=self._build_sample))
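
The tar filepath string is expected in brace notation; braceexpand turns it into the shard list that then gets partitioned across workers, e.g.:

import braceexpand

shards = list(braceexpand.braceexpand("audio_{0..3}.tar"))
print(shards)  # ['audio_0.tar', 'audio_1.tar', 'audio_2.tar', 'audio_3.tar']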
Example 15
    def _setup_dataloader_from_config(self, config: DictConfig):

        OmegaConf.set_struct(config, False)
        config.is_regression_task = self.is_regression_task
        OmegaConf.set_struct(config, True)

        if 'augmentor' in config:
            augmentor = process_augmentations(config['augmentor'])
        else:
            augmentor = None

        featurizer = WaveformFeaturizer(sample_rate=config['sample_rate'],
                                        int_values=config.get(
                                            'int_values', False),
                                        augmentor=augmentor)
        shuffle = config['shuffle']

        # Instantiate tarred dataset loader or normal dataset loader
        if config.get('is_tarred', False):
            if ('tarred_audio_filepaths' in config
                    and config['tarred_audio_filepaths'] is None) or (
                        'manifest_filepath' in config
                        and config['manifest_filepath'] is None):
                logging.warning(
                    "Could not load dataset as `manifest_filepath` is None or "
                    f"`tarred_audio_filepaths` is None. Provided config: {config}"
                )
                return None

            if 'vad_stream' in config and config['vad_stream']:
                logging.warning(
                    "VAD inference does not currently support tarred datasets")
                return None

            shuffle_n = config.get('shuffle_n', 4 * config['batch_size']) if shuffle else 0
            dataset = audio_to_label_dataset.get_tarred_classification_label_dataset(
                featurizer=featurizer,
                config=OmegaConf.to_container(config),
                shuffle_n=shuffle_n,
                global_rank=self.global_rank,
                world_size=self.world_size,
            )
            shuffle = False
            batch_size = config['batch_size']
            collate_func = dataset.collate_fn

        else:
            if 'manifest_filepath' in config and config['manifest_filepath'] is None:
                logging.warning(
                    f"Could not load dataset as `manifest_filepath` is None. Provided config: {config}"
                )
                return None

            if 'vad_stream' in config and config['vad_stream']:
                logging.info("Perform streaming frame-level VAD")
                dataset = audio_to_label_dataset.get_speech_label_dataset(
                    featurizer=featurizer,
                    config=OmegaConf.to_container(config))
                batch_size = 1
                collate_func = dataset.vad_frame_seq_collate_fn
            else:
                dataset = audio_to_label_dataset.get_classification_label_dataset(
                    featurizer=featurizer,
                    config=OmegaConf.to_container(config))
                batch_size = config['batch_size']
                collate_func = dataset.collate_fn

        return torch.utils.data.DataLoader(
            dataset=dataset,
            batch_size=batch_size,
            collate_fn=collate_func,
            drop_last=config.get('drop_last', False),
            shuffle=shuffle,
            num_workers=config.get('num_workers', 0),
            pin_memory=config.get('pin_memory', False),
        )
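
For reference, a hypothetical minimal config for this dataloader setup; the exact set of supported keys depends on the dataset class, so treat the values below as illustrative:

from omegaconf import OmegaConf

# Minimal non-tarred classification config; keys mirror the lookups above.
config = OmegaConf.create({
    'manifest_filepath': '/data/train_manifest.json',  # illustrative path
    'sample_rate': 16000,
    'labels': ['yes', 'no'],
    'batch_size': 32,
    'shuffle': True,
    'is_tarred': False,
    'num_workers': 4,
})
# loader = model._setup_dataloader_from_config(config)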
Example no. 16
def main():
    parser = argparse.ArgumentParser(
        parents=[nm_argparse.NemoArgParser()],
        description='AN4 ASR',
        conflict_handler='resolve',
    )

    # Overwrite default args
    parser.add_argument("--train_dataset",
                        type=str,
                        help="training dataset path")
    parser.add_argument("--eval_datasets",
                        type=str,
                        help="validation dataset path")

    # Create new args
    # parser.add_argument("--lm", default="./an4-lm.3gram.binary", type=str)
    parser.add_argument("--batch_size",
                        default=48,
                        type=int,
                        help="size of the training batch")
    parser.add_argument("--lm", default=None, type=str)
    parser.add_argument("--test_after_training", action='store_true')
    parser.add_argument("--momentum", type=float)
    parser.add_argument("--beta1", default=0.95, type=float)
    parser.add_argument("--beta2", default=0.25, type=float)
    parser.add_argument("--do_not_eval_at_start", action='store_true')
    parser.set_defaults(
        model_config="./configs/jasper_an4.yaml",
        train_dataset="~/TestData/an4_dataset/an4_train.json",
        eval_datasets="~/TestData/an4_dataset/an4_val.json",
        work_dir="./tmp",
        optimizer="novograd",
        num_epochs=50,
        lr=0.02,
        weight_decay=0.005,
        checkpoint_save_freq=1000,
        eval_freq=100,
        amp_opt_level="O1",
    )

    args = parser.parse_args()
    betas = (args.beta1, args.beta2)

    wer_thr = 0.20
    beam_wer_thr = 0.15

    nf = nemo.core.NeuralModuleFactory(
        local_rank=args.local_rank,
        files_to_copy=[__file__],
        optimization_level=args.amp_opt_level,
        random_seed=0,
        log_dir=args.work_dir,
        create_tb_writer=True,
        cudnn_benchmark=args.cudnn_benchmark,
    )
    tb_writer = nf.tb_writer
    checkpoint_dir = nf.checkpoint_dir

    # Load model definition
    yaml = YAML(typ="safe")
    with open(args.model_config) as f:
        jasper_params = yaml.load(f)
    # Get vocabulary.
    vocab = jasper_params['labels']

    (
        loss,
        eval_tensors,
        callbacks,
        total_steps,
        log_probs_e,
        encoded_len_e,
    ) = create_dags(args.model_config, vocab, args, nf)

    nf.train(
        tensors_to_optimize=[loss],
        callbacks=callbacks,
        optimizer=args.optimizer,
        lr_policy=CosineAnnealing(total_steps=total_steps,
                                  min_lr=args.lr / 100),
        optimization_params={
            "num_epochs": args.num_epochs,
            "max_steps": args.max_steps,
            "lr": args.lr,
            "momentum": args.momentum,
            "betas": betas,
            "weight_decay": args.weight_decay,
            "grad_norm_clip": None,
        },
        batches_per_step=args.iter_per_step,
        amp_max_loss_scale=256.0,
        # synced_batchnorm=(nf.global_rank is not None),
    )

    if args.test_after_training:
        logging.info("Testing greedy and beam search with LM WER.")
        # Create BeamSearch NM
        if nf.world_size > 1 or args.lm is None:
            logging.warning(
                "Skipping beam search WER: it does not work with distributed training, and requires an LM."
            )
        else:
            beam_search_with_lm = nemo_asr.BeamSearchDecoderWithLM(
                vocab=vocab,
                beam_width=64,
                alpha=2.0,
                beta=1.5,
                lm_path=args.lm,
                num_cpus=max(os.cpu_count() or 1, 1),
            )
            beam_predictions = beam_search_with_lm(
                log_probs=log_probs_e, log_probs_length=encoded_len_e)
            eval_tensors.append(beam_predictions)

        evaluated_tensors = nf.infer(eval_tensors)
        if nf.global_rank in [0, None]:
            greedy_hypotheses = post_process_predictions(
                evaluated_tensors[1], vocab)
            references = post_process_transcripts(evaluated_tensors[2],
                                                  evaluated_tensors[3], vocab)
            wer = word_error_rate(hypotheses=greedy_hypotheses,
                                  references=references)
            logging.info("Greedy WER: {:.2f}%".format(wer * 100))
            if wer > wer_thr:
                nf.sync_all_processes(False)
                raise ValueError(f"Final eval greedy WER {wer * 100:.2f}% > :"
                                 f"than {wer_thr * 100:.2f}%")
        nf.sync_all_processes()

        if nf.world_size == 1 and args.lm is not None:
            beam_hypotheses = []
            # Over mini-batch
            for i in evaluated_tensors[-1]:
                # Over samples
                for j in i:
                    beam_hypotheses.append(j[0][1])

            beam_wer = word_error_rate(hypotheses=beam_hypotheses,
                                       references=references)
            logging.info("Beam WER {:.2f}%".format(beam_wer * 100))
            assert beam_wer <= beam_wer_thr, "Final eval beam WER {:.2f}% > {:.2f}%".format(
                beam_wer * 100, beam_wer_thr * 100)
            assert beam_wer <= wer, "Final eval beam WER is greater than the greedy WER."

        # Reload model weights and train for extra 10 epochs
        checkpointer_callback = nemo.core.CheckpointCallback(
            folder=checkpoint_dir,
            step_freq=args.checkpoint_save_freq,
            force_load=True,
        )

        # Distributed Data Parallel changes the underlying class so we need
        # to reinstantiate Encoder and Decoder
        args.num_epochs += 10
        previous_step_count = total_steps
        loss, eval_tensors, callbacks, total_steps, _, _ = create_dags(
            args.model_config, vocab, args, nf)

        nf.reset_trainer()
        nf.train(
            tensors_to_optimize=[loss],
            callbacks=callbacks,
            optimizer=args.optimizer,
            lr_policy=CosineAnnealing(warmup_steps=previous_step_count,
                                      total_steps=total_steps),
            optimization_params={
                "num_epochs": args.num_epochs,
                "lr": args.lr / 100,
                "momentum": args.momentum,
                "betas": betas,
                "weight_decay": args.weight_decay,
                "grad_norm_clip": None,
            },
            reset=True,
            amp_max_loss_scale=256.0,
            # synced_batchnorm=(nf.global_rank is not None),
        )

        evaluated_tensors = nf.infer(eval_tensors)
        if nf.global_rank in [0, None]:
            greedy_hypotheses = post_process_predictions(
                evaluated_tensors[1], vocab)
            references = post_process_transcripts(evaluated_tensors[2],
                                                  evaluated_tensors[3], vocab)
            wer_new = word_error_rate(hypotheses=greedy_hypotheses,
                                      references=references)
            logging.info("New greedy WER: {:.2f}%".format(wer_new * 100))
            if wer_new > wer * 1.1:
                nf.sync_all_processes(False)
                raise ValueError(
                    f"Fine tuning: new WER {wer_new * 100:.2f}% is more than 10% worse "
                    f"than the previous WER {wer * 100:.2f}%")
        nf.sync_all_processes()

        # Open the log file and ensure that the epoch numbers are consecutive starting from 0
        if nf._exp_manager.log_file:
            epochs = []
            with open(nf._exp_manager.log_file, "r") as log_file:
                line = log_file.readline()
                while line:
                    index = line.find("Starting epoch")
                    if index != -1:
                        epochs.append(int(line[index + len("Starting epoch"):]))
                    line = log_file.readline()
            for i, e in enumerate(epochs):
                if i != e:
                    raise ValueError("Epochs from logfile was not understood")
Example no. 17
    def __init__(
        self,
        audio_files: List[str],
        durations: List[float],
        labels: List[Union[int, str]],
        offsets: List[Optional[float]],
        min_duration: Optional[float] = None,
        max_duration: Optional[float] = None,
        max_number: Optional[int] = None,
        do_sort_by_duration: bool = False,
        index_by_file_id: bool = False,
    ):
        """Instantiates audio-label manifest with filters and preprocessing.

        Args:
            audio_files: List of audio files.
            durations: List of float durations.
            labels: List of labels.
            offsets: List of offsets or None.
            min_duration: If set, entries shorter than this duration are dropped (default: None).
            max_duration: If set, entries longer than this duration are dropped (default: None).
            max_number: Maximum number of samples to collect.
            do_sort_by_duration: If True, sort the samples list by duration.
            index_by_file_id: If True, saves a mapping from filename base (ID) to index in data.
        """

        if index_by_file_id:
            self.mapping = {}
        output_type = self.OUTPUT_TYPE
        data, duration_filtered = [], 0.0
        for audio_file, duration, command, offset in zip(
                audio_files, durations, labels, offsets):
            # Duration filters.
            if min_duration is not None and duration < min_duration:
                duration_filtered += duration
                continue

            if max_duration is not None and duration > max_duration:
                duration_filtered += duration
                continue

            data.append(output_type(audio_file, duration, command, offset))

            if index_by_file_id:
                file_id, _ = os.path.splitext(os.path.basename(audio_file))
                self.mapping[file_id] = len(data) - 1

            # Max number of entities filter.
            if len(data) == max_number:
                break

        if do_sort_by_duration:
            if index_by_file_id:
                logging.warning(
                    "Tried to sort dataset by duration, but cannot since index_by_file_id is set."
                )
            else:
                data.sort(key=lambda entity: entity.duration)

        logging.info(
            "Filtered duration for loading collection is %f.",
            duration_filtered,
        )
        self.uniq_labels = sorted(set(map(lambda x: x.label, data)))
        logging.info("# {} files loaded accounting to # {} labels".format(
            len(data), len(self.uniq_labels)))

        super().__init__(data)
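
A hypothetical instantiation of this collection, assuming the concrete subclass is NeMo's SpeechLabel (the class that defines OUTPUT_TYPE); paths, durations and labels are illustrative:

# Assumed import; the exact module path varies across NeMo versions:
# from nemo.collections.common.parts.preprocessing.collections import SpeechLabel
collection = SpeechLabel(
    audio_files=['a.wav', 'b.wav', 'c.wav'],
    durations=[0.5, 2.0, 3.5],
    labels=['yes', 'no', 'yes'],
    offsets=[None, None, None],
    min_duration=1.0,  # 'a.wav' is dropped (0.5 s < 1.0 s)
    index_by_file_id=True,
)
print(collection.uniq_labels)  # -> ['no', 'yes']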
Example no. 18
    def __setup_dataloader_from_config(self, config: Optional[Dict]):
        if 'augmentor' in config:
            augmentor = process_augmentations(config['augmentor'])
        else:
            augmentor = None

        featurizer = WaveformFeaturizer(sample_rate=config['sample_rate'],
                                        int_values=config.get(
                                            'int_values', False),
                                        augmentor=augmentor)
        shuffle = config.get('shuffle', False)
        if config.get('is_tarred', False):
            if ('tarred_audio_filepaths' in config
                    and config['tarred_audio_filepaths'] is None) or (
                        'manifest_filepath' in config
                        and config['manifest_filepath'] is None):
                logging.warning(
                    "Could not load dataset as `manifest_filepath` or "
                    f"`tarred_audio_filepaths` was None. Provided config: {config}"
                )
                return None

            shuffle_n = config.get('shuffle_n', 4 * config['batch_size']) if shuffle else 0
            dataset = get_tarred_speech_label_dataset(
                featurizer=featurizer,
                config=config,
                shuffle_n=shuffle_n,
                global_rank=self.global_rank,
                world_size=self.world_size,
            )
            shuffle = False
        else:
            if 'manifest_filepath' in config and config['manifest_filepath'] is None:
                logging.warning(
                    f"Could not load dataset as `manifest_filepath` was None. Provided config: {config}"
                )
                return None

            dataset = AudioToSpeechLabelDataset(
                manifest_filepath=config['manifest_filepath'],
                labels=config['labels'],
                featurizer=featurizer,
                max_duration=config.get('max_duration', None),
                min_duration=config.get('min_duration', None),
                trim=config.get('trim_silence', False),
                time_length=config.get('time_length', 8),
                shift_length=config.get('shift_length', 0.75),
                normalize_audio=config.get('normalize_audio', False),
            )

        if type(dataset) is ChainDataset:
            collate_ds = dataset.datasets[0]
        else:
            collate_ds = dataset

        # self.labels = collate_ds.labels

        if self.task == 'diarization':
            logging.info("Setting up diarization parameters")
            collate_fn = collate_ds.sliced_seq_collate_fn
            shuffle = False
        else:
            logging.info("Setting up identification parameters")
            collate_fn = collate_ds.fixed_seq_collate_fn

        batch_size = config['batch_size']
        return torch.utils.data.DataLoader(
            dataset=dataset,
            batch_size=batch_size,
            collate_fn=collate_fn,
            drop_last=config.get('drop_last', False),
            shuffle=shuffle,
            num_workers=config.get('num_workers', 0),
            pin_memory=config.get('pin_memory', False),
        )
Example no. 19
    def setup(self, stage: str) -> None:
        """ PTL hook that is called after DDP is initialized.
            Called at the beginning of fit and test.

        Args:
            stage (str): either 'fit' or 'test'
        """
        # TODO: implement model parallel for test stage
        if stage == 'fit':
            # adds self.bert_model config to .nemo file
            if hasattr(self, 'bert_model') and self.bert_model is not None:
                self.register_bert_model()

            app_state = AppState()

            if app_state.model_parallel_size is not None:

                if app_state.model_parallel_group is None:
                    self.init_model_parallel(app_state.global_rank,
                                             app_state.world_size)

                # mpu grad clipping needs parameters to have the attribute model_parallel
                parameters = self._trainer.get_model().parameters()
                for p in parameters:
                    if not hasattr(p, 'model_parallel'):
                        p.model_parallel = False

                # Update PTL trainer to use our configure_ddp
                self._trainer.accelerator_backend.ddp_plugin.configure_ddp = self.configure_ddp
                # Update PTL trainer to use our _clip_gradients
                self._trainer.accelerator_backend._clip_gradients = self._clip_gradients
                self._trainer.checkpoint_connector = NLPCheckpointConnector(
                    self._trainer)

                # Configure checkpointing for model parallel
                if app_state.create_checkpoint_callback:
                    # global rank 0 is configured by exp_manager
                    if not is_global_rank_zero() and app_state.data_parallel_rank == 0:
                        configure_checkpointing(
                            self._trainer,
                            app_state.log_dir,
                            app_state.checkpoint_name,
                            app_state.checkpoint_callback_params,
                        )

                if isinstance(self.bert_model, MegatronBertEncoder):
                    self.bert_model.complete_lazy_init()

                    # model parallel checkpoints need to be restored after torch.distributed is initialized
                    if self._trainer.resume_from_checkpoint is not None:
                        # update path based on model parallel rank
                        filepath = self._trainer.resume_from_checkpoint
                        dirname = os.path.dirname(os.path.dirname(filepath))
                        basename = os.path.basename(filepath)
                        filepath = f'{dirname}/mp_rank_{app_state.model_parallel_rank:02d}/{basename}'
                        self._trainer.resume_from_checkpoint = filepath
                        logging.info(
                            f'Resuming training from checkpoint {self._trainer.resume_from_checkpoint}'
                        )
                        # need to set checkpoint version for megatron-lm
                        checkpoint_version = torch.load(
                            self._trainer.resume_from_checkpoint).get(
                                'checkpoint_version', None)
                        if checkpoint_version is not None:
                            set_checkpoint_version(checkpoint_version)
                        else:
                            logging.warning(
                                'Megatron-lm checkpoint version not found. Setting checkpoint_version to 0.'
                            )
                            set_checkpoint_version(0)
                    else:
                        logging.info(
                            f"Restoring from pretrained model parallel checkpoint: {self.bert_model._restore_path}"
                        )
                        self.bert_model.restore_weights(
                            self.bert_model._restore_path)

                    logging.info(
                        "Replacing sampler with model parallel sampler")
                    mp_sampler = torch.utils.data.distributed.DistributedSampler(
                        self._train_dl.dataset,
                        num_replicas=app_state.data_parallel_size,
                        rank=app_state.data_parallel_rank,
                    )
                    mp_dl = self._trainer.replace_sampler(
                        self._train_dl, mp_sampler)
                    self._train_dl = mp_dl
                else:
                    raise NotImplementedError(
                        f'The BERT encoder: {self.bert_model} does not support model parallelism yet.'
                    )
            else:
                # Megatron without model parallelism
                self.complete_megatron_init()
        else:
            # testing stage
            self.complete_megatron_init()
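
A minimal sketch of the checkpoint-path rewrite performed above, where an mp_rank_XX directory is inserted two levels up from the checkpoint file (the path is illustrative):

import os

filepath = '/exp/checkpoints/megatron--last.ckpt'  # illustrative checkpoint path
model_parallel_rank = 1
dirname = os.path.dirname(os.path.dirname(filepath))
basename = os.path.basename(filepath)
print(f'{dirname}/mp_rank_{model_parallel_rank:02d}/{basename}')
# -> /exp/mp_rank_01/megatron--last.ckpt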
Example no. 20
    def __init__(
        self,
        text_tar_filepaths: str,
        metadata_path: str,
        tokenizer,
        max_seq_length: int = 512,
        batch_step: Optional[int] = None,
        shuffle_n: int = 1,
        shard_strategy: str = "scatter",
        global_rank: int = 0,
        world_size: int = 0,
    ):
        super(TarredL2RLanguageModelingDataset, self).__init__()

        self.tokenizer = tokenizer
        self.max_seq_length = max_seq_length
        self.batch_step = batch_step or self.max_seq_length

        valid_shard_strategies = ['scatter', 'replicate']
        if shard_strategy not in valid_shard_strategies:
            raise ValueError(
                f"`shard_strategy` must be one of {valid_shard_strategies}")

        with open(metadata_path, 'r') as f:
            metadata = json.load(f)

        self.metadata = metadata

        if isinstance(text_tar_filepaths, str):
            # Replace '(', '[', '<' and '_OP_' with '{'
            brace_keys_open = ['(', '[', '<', '_OP_']
            for bkey in brace_keys_open:
                if bkey in text_tar_filepaths:
                    text_tar_filepaths = text_tar_filepaths.replace(bkey, "{")

            # Replace ')', ']', '>' and '_CL_' with '}'
            brace_keys_close = [')', ']', '>', '_CL_']
            for bkey in brace_keys_close:
                if bkey in text_tar_filepaths:
                    text_tar_filepaths = text_tar_filepaths.replace(bkey, "}")

        if isinstance(text_tar_filepaths, str):
            # Brace expand
            text_tar_filepaths = list(
                braceexpand.braceexpand(text_tar_filepaths))

        if shard_strategy == 'scatter':
            logging.info(
                "All tarred dataset shards will be scattered evenly across all nodes."
            )

            if len(text_tar_filepaths) % world_size != 0:
                logging.warning(
                    f"Number of shards in tarred dataset ({len(text_tar_filepaths)}) is not divisible "
                    f"by number of distributed workers ({world_size}).")

            begin_idx = (len(text_tar_filepaths) // world_size) * global_rank
            end_idx = begin_idx + (len(text_tar_filepaths) // world_size)
            text_tar_filepaths = text_tar_filepaths[begin_idx:end_idx]
            logging.info(
                "Partitioning tarred dataset: process (%d) taking shards [%d, %d)",
                global_rank, begin_idx, end_idx)

        elif shard_strategy == 'replicate':
            logging.info(
                "All tarred dataset shards will be replicated across all nodes."
            )

        else:
            raise ValueError(
                f"Invalid shard strategy ! Allowed values are : {valid_shard_strategies}"
            )

        self.tarpath = text_tar_filepaths

        # Put together WebDataset
        self._dataset = (
            wd.Dataset(text_tar_filepaths).shuffle(shuffle_n).rename(
                npy='npy',
                key='__key__').to_tuple('npy',
                                        'key').map(f=self._build_sample))
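
A quick illustration of the brace-key normalization plus brace expansion performed above (the shard pattern is illustrative):

import braceexpand

path = "shards/text__OP_0..3_CL_.tar"
for bkey, repl in [('_OP_', '{'), ('_CL_', '}')]:
    path = path.replace(bkey, repl)
print(list(braceexpand.braceexpand(path)))
# -> ['shards/text_0.tar', 'shards/text_1.tar', 'shards/text_2.tar', 'shards/text_3.tar']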
Example no. 21
    def change_vocabulary(self,
                          new_vocabulary: List[str],
                          decoding_cfg: Optional[DictConfig] = None):
        """
        Changes the vocabulary used during RNNT decoding. Use this method when fine-tuning a pre-trained model.
        It changes only the decoder and leaves the encoder and pre-processing modules unchanged. For example,
        you would use it to keep a pretrained encoder when fine-tuning on data in another language, or when the
        model needs to learn capitalization, punctuation and/or special characters.

        Args:
            new_vocabulary: list with new vocabulary. Must contain at least 2 elements. Typically, \
                this is target alphabet.
            decoding_cfg: A config for the decoder, which is optional. If the decoding type
                needs to be changed (from say Greedy to Beam decoding etc), the config can be passed here.

        Returns: None

        """
        if self.joint.vocabulary == new_vocabulary:
            logging.warning(
                f"Old {self.joint.vocabulary} and new {new_vocabulary} match. Not changing anything."
            )
        else:
            if new_vocabulary is None or len(new_vocabulary) == 0:
                raise ValueError(
                    f'New vocabulary must be a non-empty list of characters, but got: {new_vocabulary}'
                )

            joint_config = self.joint.to_config_dict()
            new_joint_config = copy.deepcopy(joint_config)
            new_joint_config['vocabulary'] = new_vocabulary
            new_joint_config['num_classes'] = len(new_vocabulary)
            del self.joint
            self.joint = EncDecRNNTModel.from_config_dict(new_joint_config)

            decoder_config = self.decoder.to_config_dict()
            new_decoder_config = copy.deepcopy(decoder_config)
            new_decoder_config.vocab_size = len(new_vocabulary)
            del self.decoder
            self.decoder = EncDecRNNTModel.from_config_dict(new_decoder_config)

            del self.loss
            loss_name, loss_kwargs = self.extract_rnnt_loss_cfg(
                self.cfg.get('loss', None))
            self.loss = RNNTLoss(
                num_classes=self.joint.num_classes_with_blank - 1,
                loss_name=loss_name,
                loss_kwargs=loss_kwargs)

            if decoding_cfg is None:
                # Assume same decoding config as before
                decoding_cfg = self.cfg.decoding

            self.decoding = RNNTDecoding(
                decoding_cfg=decoding_cfg,
                decoder=self.decoder,
                joint=self.joint,
                vocabulary=self.joint.vocabulary,
            )

            self.wer = RNNTWER(
                decoding=self.decoding,
                batch_dim_index=self.wer.batch_dim_index,
                use_cer=self.wer.use_cer,
                log_prediction=self.wer.log_prediction,
                dist_sync_on_step=True,
            )

            # Setup fused Joint step
            if self.joint.fuse_loss_wer:
                self.joint.set_loss(self.loss)
                self.joint.set_wer(self.wer)

            # Update config
            with open_dict(self.cfg.joint):
                self.cfg.joint = new_joint_config

            with open_dict(self.cfg.decoder):
                self.cfg.decoder = new_decoder_config

            with open_dict(self.cfg.decoding):
                self.cfg.decoding = decoding_cfg

            logging.info(
                f"Changed decoder to output the {self.joint.vocabulary} vocabulary."
            )
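
A hypothetical fine-tuning call for this method; the checkpoint name and target alphabet below are illustrative:

# model = EncDecRNNTModel.restore_from('pretrained_rnnt.nemo')
new_vocab = [' ', 'a', 'b', 'c']  # illustrative target alphabet (at least 2 elements)
# model.change_vocabulary(new_vocabulary=new_vocab)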
Example no. 22
# Copyright (c) 2021, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from nemo_text_processing.inverse_text_normalization.taggers.tokenize_and_classify import ClassifyFst
from nemo_text_processing.inverse_text_normalization.verbalizers.verbalize_final import VerbalizeFinalFst

from nemo.utils import logging

try:
    import pynini

    PYNINI_AVAILABLE = True
except (ModuleNotFoundError, ImportError):
    logging.warning("`pynini` is not installed ! \n"
                    "Please run the `nemo_text_processing/setup.sh` script"
                    "prior to usage of this toolkit.")

    PYNINI_AVAILABLE = False
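
The PYNINI_AVAILABLE flag set above supports the usual guarded-import pattern: check availability at call time and fail with an actionable error. A minimal sketch (the function itself is illustrative):

def build_acceptor(text: str):
    # Fail fast with a pointer to the setup script when pynini is missing.
    if not PYNINI_AVAILABLE:
        raise ImportError("pynini is required; run nemo_text_processing/setup.sh first")
    return pynini.accep(text)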
Example no. 23
    def __init__(
        self,
        audio_tar_filepaths: Union[str, List[str]],
        manifest_filepath: str,
        parser: Callable,
        sample_rate: int,
        int_values: bool = False,
        augmentor: Optional[
            'nemo.collections.asr.parts.perturb.AudioAugmentor'] = None,
        shuffle_n: int = 0,
        min_duration: Optional[float] = None,
        max_duration: Optional[float] = None,
        max_utts: int = 0,
        trim: bool = False,
        bos_id: Optional[int] = None,
        eos_id: Optional[int] = None,
        pad_id: int = 0,
        shard_strategy: str = "scatter",
        global_rank: int = 0,
        world_size: int = 0,
    ):
        self.collection = collections.ASRAudioText(
            manifests_files=manifest_filepath.split(','),
            parser=parser,
            min_duration=min_duration,
            max_duration=max_duration,
            max_number=max_utts,
            index_by_file_id=True,  # Must set this so the manifest lines can be indexed by file ID
        )

        self.featurizer = WaveformFeaturizer(sample_rate=sample_rate,
                                             int_values=int_values,
                                             augmentor=augmentor)
        self.trim = trim
        self.eos_id = eos_id
        self.bos_id = bos_id
        self.pad_id = pad_id

        valid_shard_strategies = ['scatter', 'replicate']
        if shard_strategy not in valid_shard_strategies:
            raise ValueError(
                f"`shard_strategy` must be one of {valid_shard_strategies}")

        if isinstance(audio_tar_filepaths, str):
            # Replace '(', '[', '<' and '_OP_' with '{'
            brace_keys_open = ['(', '[', '<', '_OP_']
            for bkey in brace_keys_open:
                if bkey in audio_tar_filepaths:
                    audio_tar_filepaths = audio_tar_filepaths.replace(
                        bkey, "{")

            # Replace ')', ']', '>' and '_CL_' with '}'
            brace_keys_close = [')', ']', '>', '_CL_']
            for bkey in brace_keys_close:
                if bkey in audio_tar_filepaths:
                    audio_tar_filepaths = audio_tar_filepaths.replace(
                        bkey, "}")

        # Check for distributed and partition shards accordingly
        if world_size > 1:
            if isinstance(audio_tar_filepaths, str):
                # Brace expand
                audio_tar_filepaths = list(
                    braceexpand.braceexpand(audio_tar_filepaths))

            if shard_strategy == 'scatter':
                logging.info(
                    "All tarred dataset shards will be scattered evenly across all nodes."
                )

                if len(audio_tar_filepaths) % world_size != 0:
                    logging.warning(
                        f"Number of shards in tarred dataset ({len(audio_tar_filepaths)}) is not divisible "
                        f"by number of distributed workers ({world_size}).")

                begin_idx = (len(audio_tar_filepaths) //
                             world_size) * global_rank
                end_idx = begin_idx + (len(audio_tar_filepaths) // world_size)
                audio_tar_filepaths = audio_tar_filepaths[begin_idx:end_idx]
                logging.info(
                    "Partitioning tarred dataset: process (%d) taking shards [%d, %d)",
                    global_rank, begin_idx, end_idx)

            elif shard_strategy == 'replicate':
                logging.info(
                    "All tarred dataset shards will be replicated across all nodes."
                )

            else:
                raise ValueError(
                    f"Invalid shard strategy ! Allowed values are : {valid_shard_strategies}"
                )

        # Put together WebDataset
        self._dataset = wd.WebDataset(audio_tar_filepaths)

        if shuffle_n > 0:
            self._dataset = self._dataset.shuffle(shuffle_n)
        else:
            logging.info(
                "WebDataset will not shuffle files within the tar files.")

        self._dataset = (self._dataset.rename(
            audio='wav', key='__key__').to_tuple('audio', 'key').pipe(
                self._filter).map(f=self._build_sample))
Example no. 24
def main(cfg):
    if not cfg.dataset:
        raise ValueError("You must input the path of json file of evaluation data")

    # each line of the dataset should have a different audio_filepath and a unique name to simplify edge cases and conditions
    key_meta_map = {}
    with open(cfg.dataset, 'r') as manifest:
        for line in manifest.readlines():
            audio_filepath = json.loads(line.strip())['audio_filepath']
            uniq_audio_name = audio_filepath.split('/')[-1].rsplit('.', 1)[0]
            if uniq_audio_name in key_meta_map:
                raise ValueError("Please make sure each line is with different audio_filepath! ")
            key_meta_map[uniq_audio_name] = {'audio_filepath': audio_filepath}

    # Prepare manifest for streaming VAD
    manifest_vad_input = cfg.dataset
    if cfg.prepare_manifest.auto_split:
        logging.info("Split long audio file to avoid CUDA memory issue")
        logging.debug("Try smaller split_duration if you still have CUDA memory issue")
        config = {
            'input': manifest_vad_input,
            'window_length_in_sec': cfg.vad.parameters.window_length_in_sec,
            'split_duration': cfg.prepare_manifest.split_duration,
            'num_workers': cfg.num_workers,
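            # NB: the misspelled key below mirrors cfg.prepared_manfiest_vad_input, so both sides must keep the same spelling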
            'prepared_manfiest_vad_input': cfg.prepared_manfiest_vad_input,
        }
        manifest_vad_input = prepare_manifest(config)
    else:
        logging.warning(
            "If you encounter CUDA memory issues, try splitting long manifest entries with split_duration."
        )

    torch.set_grad_enabled(False)
    vad_model = init_vad_model(cfg.vad.model_path)

    # setup_test_data
    vad_model.setup_test_data(
        test_data_config={
            'vad_stream': True,
            'sample_rate': 16000,
            'manifest_filepath': manifest_vad_input,
            'labels': ['infer',],
            'num_workers': cfg.num_workers,
            'shuffle': False,
            'window_length_in_sec': cfg.vad.parameters.window_length_in_sec,
            'shift_length_in_sec': cfg.vad.parameters.shift_length_in_sec,
            'trim_silence': False,
            'normalize_audio': cfg.vad.parameters.normalize_audio,
        }
    )

    vad_model = vad_model.to(device)
    vad_model.eval()

    if not os.path.exists(cfg.frame_out_dir):
        os.mkdir(cfg.frame_out_dir)
    else:
        logging.warning(
            "Note: frame_out_dir already exists. If a new file shares a name with a file inside this "
            "folder, results will be appended to the existing file, which may corrupt later steps."
        )

    logging.info("Generating frame level prediction ")
    pred_dir = generate_vad_frame_pred(
        vad_model=vad_model,
        window_length_in_sec=cfg.vad.parameters.window_length_in_sec,
        shift_length_in_sec=cfg.vad.parameters.shift_length_in_sec,
        manifest_vad_input=manifest_vad_input,
        out_dir=cfg.frame_out_dir,
    )
    logging.info(
        f"Finished generating VAD frame-level predictions with window_length_in_sec={cfg.vad.parameters.window_length_in_sec} and shift_length_in_sec={cfg.vad.parameters.shift_length_in_sec}"
    )

    # overlap smoothing filter
    if cfg.gen_overlap_seq:
        # Generate predictions with overlapping input segments, then apply a smoothing filter to decide the label for a frame spanned by multiple segments.
        # smoothing_method is either majority vote (median) or average (mean)
        logging.info("Generating predictions with overlapping input segments")
        smoothing_pred_dir = generate_overlap_vad_seq(
            frame_pred_dir=pred_dir,
            smoothing_method=cfg.vad.parameters.smoothing,
            overlap=cfg.vad.parameters.overlap,
            window_length_in_sec=cfg.vad.parameters.window_length_in_sec,
            shift_length_in_sec=cfg.vad.parameters.shift_length_in_sec,
            num_workers=cfg.num_workers,
            out_dir=cfg.smoothing_out_dir,
        )
        logging.info(
            f"Finished generating predictions with overlapping input segments, using smoothing_method={cfg.vad.parameters.smoothing} and overlap={cfg.vad.parameters.overlap}"
        )
        pred_dir = smoothing_pred_dir

    # postprocessing and generate speech segments
    if cfg.gen_seg_table:
        logging.info("Converting frame level prediction to speech/no-speech segment in start and end times format.")
        table_out_dir = generate_vad_segment_table(
            vad_pred_dir=pred_dir,
            postprocessing_params=cfg.vad.parameters.postprocessing,
            shift_length_in_sec=cfg.vad.parameters.shift_length_in_sec,
            num_workers=cfg.num_workers,
            out_dir=cfg.table_out_dir,
        )
        logging.info(
            f"Finished generating the speech segments table with postprocessing_params: {cfg.vad.parameters.postprocessing}"
        )

    if cfg.write_to_manifest:
        for i in key_meta_map:
            key_meta_map[i]['rttm_filepath'] = os.path.join(table_out_dir, i + ".txt")

        if not cfg.out_manifest_filepath:
            out_manifest_filepath = "vad_out.json"
        else:
            out_manifest_filepath = cfg.out_manifest_filepath
        out_manifest_filepath = write_rttm2manifest(key_meta_map, out_manifest_filepath)
        logging.info(f"Writing VAD output to manifest: {out_manifest_filepath}")
Example no. 25
def write_rttm2manifest(AUDIO_RTTM_MAP, manifest_file):
    """
    writes manifest file based on rttm files (or vad table out files). This manifest file would be used by 
    speaker diarizer to compute embeddings and cluster them. This function also takes care of overlap time stamps

    Args:
    AUDIO_RTTM_MAP: dict containing keys to uniqnames, that contains audio filepath and rttm_filepath as its contents,
    these are used to extract oracle vad timestamps.
    manifest (str): path to write manifest file

    Returns:
    manifest (str): path to write manifest file
    """

    with open(manifest_file, 'w') as outfile:
        for key in AUDIO_RTTM_MAP:
            rttm_filename = AUDIO_RTTM_MAP[key]['rttm_filepath']
            if rttm_filename and os.path.exists(rttm_filename):
                f = open(rttm_filename, 'r')
            else:
                raise FileNotFoundError(
                    "Requested to construct a manifest from RTTM (oracle VAD) or from NeMo VAD "
                    "output, but received filename: {}".format(rttm_filename))

            audio_path = AUDIO_RTTM_MAP[key]['audio_filepath']
            if AUDIO_RTTM_MAP[key].get('duration', None):
                max_duration = AUDIO_RTTM_MAP[key]['duration']
            else:
                sound = sf.SoundFile(audio_path)
                max_duration = sound.frames / sound.samplerate

            lines = f.readlines()
            time_tup = (-1, -1)
            for line in lines:
                vad_out = line.strip().split()
                if len(vad_out) > 3:
                    start, dur, _ = float(vad_out[3]), float(vad_out[4]), vad_out[7]
                else:
                    start, dur, _ = float(vad_out[0]), float(vad_out[1]), vad_out[2]
                start, dur = float("{:.3f}".format(start)), float("{:.3f}".format(dur))

                if start == 0 and dur == 0:  # No speech segments
                    continue
                else:

                    if time_tup[0] >= 0 and start > time_tup[1]:
                        dur2 = float("{:.3f}".format(time_tup[1] -
                                                     time_tup[0]))
                        if time_tup[0] < max_duration and dur2 > 0:
                            meta = {
                                "audio_filepath": audio_path,
                                "offset": time_tup[0],
                                "duration": dur2,
                                "label": 'UNK',
                            }
                            json.dump(meta, outfile)
                            outfile.write("\n")
                        else:
                            logging.warning(
                                "RTTM label has been truncated since start is greater than duration of audio file"
                            )
                        time_tup = (start, start + dur)
                    else:
                        if time_tup[0] == -1:
                            end_time = start + dur
                            if end_time > max_duration:
                                end_time = max_duration
                            time_tup = (start, end_time)
                        else:
                            end_time = max(time_tup[1], start + dur)
                            if end_time > max_duration:
                                end_time = max_duration
                            time_tup = (min(time_tup[0], start), end_time)
            dur2 = float("{:.3f}".format(time_tup[1] - time_tup[0]))
            if time_tup[0] < max_duration and dur2 > 0:
                meta = {
                    "audio_filepath": audio_path,
                    "offset": time_tup[0],
                    "duration": dur2,
                    "label": 'UNK'
                }
                json.dump(meta, outfile)
                outfile.write("\n")
            else:
                logging.warning(
                    "RTTM label has been truncated since start is greater than duration of audio file"
                )
            f.close()
    return manifest_file
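
The loop above is essentially an interval merge: overlapping or touching RTTM segments are coalesced before being written out. A stripped-down sketch of that merge, with illustrative intervals:

def merge_segments(segments):
    # Coalesce overlapping (start, end) intervals; input order does not matter.
    merged = []
    for start, end in sorted(segments):
        if merged and start <= merged[-1][1]:
            merged[-1][1] = max(merged[-1][1], end)
        else:
            merged.append([start, end])
    return [tuple(m) for m in merged]

print(merge_segments([(0.0, 1.5), (1.2, 2.0), (3.0, 4.0)]))
# -> [(0.0, 2.0), (3.0, 4.0)]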
Example no. 26
def _warn_unused_additional_kwargs(loss_name, kwargs):
    if len(kwargs) > 0:
        logging.warning(
            f"Loss function `{loss_name}` was provided with the following additional kwargs,\n"
            f"which were ignored because they are unused:\n"
            f"{kwargs}")
Example no. 27
def get_log_dir(
    trainer: 'pytorch_lightning.Trainer',
    exp_dir: str = None,
    name: str = None,
    version: str = None,
    explicit_log_dir: str = None,
    use_datetime_version: bool = True,
) -> (Path, str, str, str):
    """
    Obtains the log_dir used for exp_manager.

    Returns:
        log_dir (Path): the log_dir
        exp_dir (str): the base exp_dir without name nor version
        name (str): The name of the experiment
        version (str): The version of the experiment

    Raises:
        LoggerMisconfigurationError: If trainer is incompatible with arguments
        NotFoundError: If resume is True, resume_ignore_no_checkpoint is False, and checkpoints could not be found.
        ValueError: If resume is True and more than one checkpoint was found.
    """
    if explicit_log_dir:  # If explicit log_dir was passed, short circuit
        return check_explicit_log_dir(trainer, explicit_log_dir, exp_dir, name,
                                      version)

    # Default exp_dir to ./nemo_experiments if None was passed
    _exp_dir = exp_dir
    if exp_dir is None:
        _exp_dir = str(Path.cwd() / 'nemo_experiments')

    # If the user has already defined a logger for the trainer, use the logger defaults for logging directory
    if trainer.logger is not None:
        if trainer.logger.save_dir:
            if exp_dir:
                raise LoggerMisconfigurationError(
                    "The pytorch lightning trainer that was passed to exp_manager contained a logger, the logger's "
                    f"save_dir was not None, and exp_dir ({exp_dir}) was not None. If trainer.logger.save_dir "
                    "exists, exp_manager will use trainer.logger.save_dir as the logging directory and exp_dir "
                    "must be None.")
            _exp_dir = trainer.logger.save_dir
        if name:
            raise LoggerMisconfigurationError(
                "The pytorch lightning trainer that was passed to exp_manager contained a logger, and name: "
                f"{name} was also passed to exp_manager. If the trainer contains a "
                "logger, exp_manager will use trainer.logger.name, and name passed to exp_manager must be None."
            )
        name = trainer.logger.name
        version = f"version_{trainer.logger.version}"
    # Use user-defined exp_dir, project_name, exp_name, and versioning options
    else:
        name = name or "default"
        version = version or os.environ.get(NEMO_ENV_VARNAME_VERSION, None)

        if version is None:
            if trainer.is_slurm_managing_tasks:
                logging.warning(
                    "Running on a slurm cluster. exp_manager will not add a version number."
                )
                version = ""
            elif is_global_rank_zero():
                if use_datetime_version:
                    version = time.strftime('%Y-%m-%d_%H-%M-%S')
                else:
                    tensorboard_logger = TensorBoardLogger(
                        save_dir=Path(_exp_dir), name=name, version=version)
                    version = f"version_{tensorboard_logger.version}"
                os.environ[NEMO_ENV_VARNAME_VERSION] = version

    log_dir = Path(_exp_dir) / Path(str(name)) / Path(str(version))
    return log_dir, str(_exp_dir), name, version
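
A quick illustration of the directory layout produced by the last lines of get_log_dir (values are illustrative):

from pathlib import Path

exp_dir, name, version = 'nemo_experiments', 'default', '2021-08-01_12-00-00'
log_dir = Path(exp_dir) / Path(name) / Path(version)
print(log_dir)  # -> nemo_experiments/default/2021-08-01_12-00-00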
Example no. 28
    def __init__(self, cfg: DictConfig, trainer: Trainer = None):
        # Convert to Hydra 1.0 compatible DictConfig
        cfg = model_utils.convert_model_config_to_dict_config(cfg)
        cfg = model_utils.maybe_update_config_version(cfg)

        # Setup normalizer
        self.normalizer = None
        self.text_normalizer_call = None
        self.text_normalizer_call_kwargs = {}
        self._setup_normalizer(cfg)

        self.learn_alignment = cfg.get("learn_alignment", False)

        # Setup vocabulary (=tokenizer) and input_fft_kwargs (supported only with self.learn_alignment=True)
        input_fft_kwargs = {}
        if self.learn_alignment:
            self.vocab = None
            self.ds_class_name = cfg.train_ds.dataset._target_.split(".")[-1]

            if self.ds_class_name == "TTSDataset":
                self._setup_tokenizer(cfg)
                assert self.vocab is not None
                input_fft_kwargs["n_embed"] = len(self.vocab.tokens)
                input_fft_kwargs["padding_idx"] = self.vocab.pad
            elif self.ds_class_name == "AudioToCharWithPriorAndPitchDataset":
                logging.warning(
                    "AudioToCharWithPriorAndPitchDataset class has been deprecated. No support for"
                    " training or finetuning. Only inference is supported."
                )
                tokenizer_conf = self._get_default_text_tokenizer_conf()
                self._setup_tokenizer(tokenizer_conf)
                assert self.vocab is not None
                input_fft_kwargs["n_embed"] = len(self.vocab.tokens)
                input_fft_kwargs["padding_idx"] = self.vocab.pad
            else:
                raise ValueError(f"Unknown dataset class: {self.ds_class_name}")

        self._parser = None
        self._tb_logger = None
        super().__init__(cfg=cfg, trainer=trainer)

        self.bin_loss_warmup_epochs = cfg.get("bin_loss_warmup_epochs", 100)
        self.log_train_images = False

        loss_scale = 0.1 if self.learn_alignment else 1.0
        dur_loss_scale = loss_scale
        pitch_loss_scale = loss_scale
        if "dur_loss_scale" in cfg:
            dur_loss_scale = cfg.dur_loss_scale
        if "pitch_loss_scale" in cfg:
            pitch_loss_scale = cfg.pitch_loss_scale

        self.mel_loss = MelLoss()
        self.pitch_loss = PitchLoss(loss_scale=pitch_loss_scale)
        self.duration_loss = DurationLoss(loss_scale=dur_loss_scale)

        self.aligner = None
        if self.learn_alignment:
            self.aligner = instantiate(self._cfg.alignment_module)
            self.forward_sum_loss = ForwardSumLoss()
            self.bin_loss = BinLoss()

        self.preprocessor = instantiate(self._cfg.preprocessor)
        input_fft = instantiate(self._cfg.input_fft, **input_fft_kwargs)
        output_fft = instantiate(self._cfg.output_fft)
        duration_predictor = instantiate(self._cfg.duration_predictor)
        pitch_predictor = instantiate(self._cfg.pitch_predictor)

        self.fastpitch = FastPitchModule(
            input_fft,
            output_fft,
            duration_predictor,
            pitch_predictor,
            self.aligner,
            cfg.n_speakers,
            cfg.symbols_embedding_dim,
            cfg.pitch_embedding_kernel_size,
            cfg.n_mel_channels,
        )
        self._input_types = self._output_types = None
Example no. 29
import pickle as pkl
from argparse import ArgumentParser
from collections import OrderedDict
from typing import Dict

import numpy as np
import torch
from build_index import load_model
from omegaconf import DictConfig, OmegaConf

from nemo.utils import logging

try:
    import faiss
except ModuleNotFoundError:
    logging.warning(
        "Faiss is required for building the index. Please install faiss-gpu")

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


def get_query_embedding(query, model):
    """Use entity linking encoder to get embedding for index query"""
    model_input = model.tokenizer(
        query,
        add_special_tokens=True,
        padding=True,
        truncation=True,
        max_length=512,
        return_token_type_ids=True,
        return_attention_mask=True,
    )
Example no. 30
    def _setup_dataloader_from_config(self, config: Optional[Dict]):
        if 'augmentor' in config:
            augmentor = process_augmentations(config['augmentor'])
        else:
            augmentor = None

        shuffle = config['shuffle']

        # Instantiate tarred dataset loader or normal dataset loader
        if config.get('is_tarred', False):
            if ('tarred_audio_filepaths' in config
                    and config['tarred_audio_filepaths'] is None) or (
                        'manifest_filepath' in config
                        and config['manifest_filepath'] is None):
                logging.warning(
                    "Could not load dataset as `manifest_filepath` or "
                    f"`tarred_audio_filepaths` was None. Provided config: {config}"
                )
                return None

            shuffle_n = config.get('shuffle_n', 4 * config['batch_size'])
            dataset = TarredAudioToCharDataset(
                audio_tar_filepaths=config['tarred_audio_filepaths'],
                manifest_filepath=config['manifest_filepath'],
                labels=config['labels'],
                sample_rate=config['sample_rate'],
                int_values=config.get('int_values', False),
                augmentor=augmentor,
                shuffle_n=shuffle_n,
                max_duration=config.get('max_duration', None),
                min_duration=config.get('min_duration', None),
                max_utts=config.get('max_utts', 0),
                blank_index=config.get('blank_index', -1),
                unk_index=config.get('unk_index', -1),
                normalize=config.get('normalize_transcripts', False),
                trim=config.get('trim_silence', True),
                parser=config.get('parser', 'en'),
                add_misc=config.get('add_misc', False),
                global_rank=self.global_rank,
                world_size=self.world_size,
            )
            shuffle = False
        else:
            if 'manifest_filepath' in config and config['manifest_filepath'] is None:
                logging.warning(
                    f"Could not load dataset as `manifest_filepath` was None. Provided config: {config}"
                )
                return None

            dataset = AudioToCharDataset(
                manifest_filepath=config['manifest_filepath'],
                labels=config['labels'],
                sample_rate=config['sample_rate'],
                int_values=config.get('int_values', False),
                augmentor=augmentor,
                max_duration=config.get('max_duration', None),
                min_duration=config.get('min_duration', None),
                max_utts=config.get('max_utts', 0),
                blank_index=config.get('blank_index', -1),
                unk_index=config.get('unk_index', -1),
                normalize=config.get('normalize_transcripts', False),
                trim=config.get('trim_silence', True),
                load_audio=config.get('load_audio', True),
                parser=config.get('parser', 'en'),
                add_misc=config.get('add_misc', False),
            )

        return torch.utils.data.DataLoader(
            dataset=dataset,
            batch_size=config['batch_size'],
            collate_fn=dataset.collate_fn,
            drop_last=config.get('drop_last', False),
            shuffle=shuffle,
            num_workers=config.get('num_workers', 0),
            pin_memory=config.get('pin_memory', False),
        )