Example #1
    def __init__(self, cfg: DictConfig, trainer: Trainer = None):
        # Get global rank and total number of GPU workers for IterableDataset partitioning, if applicable
        self.global_rank = 0
        self.world_size = 0
        if trainer is not None:
            self.global_rank = (trainer.node_rank *
                                trainer.num_gpus) + trainer.local_rank
            self.world_size = trainer.num_nodes * trainer.num_gpus

        super().__init__(cfg=cfg, trainer=trainer)
        self.preprocessor = EncDecCTCModel.from_config_dict(
            self._cfg.preprocessor)
        self.encoder = EncDecCTCModel.from_config_dict(self._cfg.encoder)
        self.decoder = EncDecCTCModel.from_config_dict(self._cfg.decoder)
        self.loss = CTCLoss(num_classes=self.decoder.num_classes_with_blank -
                            1,
                            zero_infinity=True)
        if hasattr(self._cfg,
                   'spec_augment') and self._cfg.spec_augment is not None:
            self.spec_augmentation = EncDecCTCModel.from_config_dict(
                self._cfg.spec_augment)
        else:
            self.spec_augmentation = None

        # Setup metric objects
        self._wer = WER(vocabulary=self.decoder.vocabulary,
                        batch_dim_index=0,
                        use_cer=False,
                        ctc_decode=True)
Example #2
    def test_wer_metric_return_hypothesis(self, batch_dim_index, test_wer_bpe):
        wer = WER(vocabulary=self.vocabulary.copy(),
                  batch_dim_index=batch_dim_index,
                  use_cer=False,
                  ctc_decode=True)

        tensor = self.__string_to_ctc_tensor('cat', test_wer_bpe).int()
        if batch_dim_index > 0:
            tensor.transpose_(0, 1)

        # pass batchsize 1 tensor, get back list of length 1 Hypothesis
        hyp = wer.ctc_decoder_predictions_tensor(tensor,
                                                 return_hypotheses=True)
        hyp = hyp[0]
        assert isinstance(hyp, Hypothesis)

        assert hyp.y_sequence is None
        assert hyp.score == -1.0
        assert hyp.text == 'cat'
        assert hyp.alignments == [3, 1, 20]
        assert hyp.length == 0

        length = torch.tensor([tensor.shape[1 - batch_dim_index]],
                              dtype=torch.long)

        # pass batchsize 1 tensor, get back list of length 1 Hypothesis [add length info]
        hyp = wer.ctc_decoder_predictions_tensor(tensor,
                                                 predictions_len=length,
                                                 return_hypotheses=True)
        hyp = hyp[0]
        assert isinstance(hyp, Hypothesis)
        assert hyp.length == 3
Example #3
def inference(config_path: str, test_manifest: str, checkpoint: str = None):
    """
    Inference for formated data
    :param config_path:
    :param test_manifest:
    :param checkpoint:
    :return:
    """
    yaml = YAML(typ='safe')
    with open(config_path) as f:
        params = yaml.load(f)

    params['model']['validation_ds']['manifest_filepath'] = test_manifest
    # Bigger batch-size = bigger throughput
    params['model']['validation_ds']['batch_size'] = 1

    # Setup the test data loader and make sure the model is on GPU
    # asr_model.restore_from(restore_path=str(WORK_DIR / 'checkpoint.nemo'))
    asr_model = nemo_asr.models.EncDecCTCModel.load_from_checkpoint(
        checkpoint_path=checkpoint)
    asr_model.setup_test_data(
        test_data_config=params['model']['validation_ds'])

    asr_model.cuda()

    wer = WER(vocabulary=asr_model.decoder.vocabulary)

    for test_batch in asr_model.test_dataloader():
        test_batch = [x.cuda() for x in test_batch]
        log_probs, encoded_len, greedy_predictions = asr_model(
            input_signal=test_batch[0], input_signal_length=test_batch[1])

        pred = wer.ctc_decoder_predictions_tensor(greedy_predictions)
        trans = wer.ctc_decoder_predictions_tensor(test_batch[2])
        print("original: {}, prediction: {}".format(trans, pred))
Example #4
def predict():
    
    files = ['description.wav']
    '''
    for fname, transcription in zip(files, quartznet.transcribe(paths2audio_files=files)):
        print(f"Audio in {fname} was recognized as: {transcription}")
    '''
    with tempfile.TemporaryDirectory() as tmpdir:
        with open(os.path.join(tmpdir, 'manifest.json'), 'w') as fp:
            for audio_file in files:
                entry = {'audio_filepath': audio_file, 'duration': 100000, 'text': 'nothing'}
                fp.write(json.dumps(entry) + '\n')

        config = {'paths2audio_files': files, 'batch_size': 4, 'temp_dir': tmpdir}
        temporary_datalayer = setup_transcribe_dataloader(config, quartznet.decoder.vocabulary)
        for test_batch in temporary_datalayer:
            processed_signal, processed_signal_len = quartznet.preprocessor(
                input_signal=test_batch[0].to(quartznet.device), length=test_batch[1].to(quartznet.device)
            )
            ort_inputs = {ort_session.get_inputs()[0].name: to_numpy(processed_signal),}
            ologits = ort_session.run(None, ort_inputs)
            alogits = np.asarray(ologits)
            logits = torch.from_numpy(alogits[0])
            greedy_predictions = logits.argmax(dim=-1, keepdim=False)
            wer = WER(vocabulary=quartznet.decoder.vocabulary, batch_dim_index=0, use_cer=False, ctc_decode=True)
            hypotheses = wer.ctc_decoder_predictions_tensor(greedy_predictions)
            print(hypotheses)
            break

    return jsonify({'prediccion' : hypotheses})
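The example above assumes a `to_numpy` helper and an `ort_session` created elsewhere. A minimal sketch of those pieces; the ONNX file name is illustrative:

import onnxruntime

# Sketch only: helpers assumed by the example above ('quartznet.onnx' is illustrative).
def to_numpy(tensor):
    return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy()

ort_session = onnxruntime.InferenceSession('quartznet.onnx')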
Example #5
    def __init__(self, cfg: DictConfig, trainer: Trainer = None):
        # Get global rank and total number of GPU workers for IterableDataset partitioning, if applicable
        self.global_rank = 0
        self.world_size = 1
        self.local_rank = 0
        if trainer is not None:
            self.global_rank = (trainer.node_rank *
                                trainer.num_gpus) + trainer.local_rank
            self.world_size = trainer.num_nodes * trainer.num_gpus
            self.local_rank = trainer.local_rank

        super().__init__(cfg=cfg, trainer=trainer)
        self.preprocessor = EncDecCTCModel.from_config_dict(
            self._cfg.preprocessor)
        self.encoder = EncDecCTCModel.from_config_dict(self._cfg.encoder)

        with open_dict(self._cfg):
            if "params" in self._cfg.decoder:
                if "feat_in" not in self._cfg.decoder.params or (
                        not self._cfg.decoder.params.feat_in
                        and hasattr(self.encoder, '_feat_out')):
                    self._cfg.decoder.params.feat_in = self.encoder._feat_out
                if "feat_in" not in self._cfg.decoder.params or not self._cfg.decoder.params.feat_in:
                    raise ValueError(
                        "param feat_in of the decoder's config is not set!")
            else:
                if "feat_in" not in self._cfg.decoder or (
                        not self._cfg.decoder.feat_in
                        and hasattr(self.encoder, '_feat_out')):
                    self._cfg.decoder.feat_in = self.encoder._feat_out
                if "feat_in" not in self._cfg.decoder or not self._cfg.decoder.feat_in:
                    raise ValueError(
                        "param feat_in of the decoder's config is not set!")

        self.decoder = EncDecCTCModel.from_config_dict(self._cfg.decoder)

        self.loss = CTCLoss(
            num_classes=self.decoder.num_classes_with_blank - 1,
            zero_infinity=True,
            reduction=self._cfg.get("ctc_reduction", "mean_batch"),
        )

        if hasattr(self._cfg,
                   'spec_augment') and self._cfg.spec_augment is not None:
            self.spec_augmentation = EncDecCTCModel.from_config_dict(
                self._cfg.spec_augment)
        else:
            self.spec_augmentation = None

        # Setup metric objects
        self._wer = WER(
            vocabulary=self.decoder.vocabulary,
            batch_dim_index=0,
            use_cer=self._cfg.get('use_cer', False),
            ctc_decode=True,
            dist_sync_on_step=True,
            log_prediction=self._cfg.get("log_prediction", False),
        )
Example #6
    def __init__(self, cfg: DictConfig, trainer: Trainer = None):
        # Get global rank and total number of GPU workers for IterableDataset partitioning, if applicable
        # global_rank and local_rank are set by LightningModule in Lightning 1.2.0
        self.world_size = 1
        if trainer is not None:
            self.world_size = trainer.world_size

        super().__init__(cfg=cfg, trainer=trainer)
        self.preprocessor = EncDecCTCModel.from_config_dict(
            self._cfg.preprocessor)
        self.encoder = EncDecCTCModel.from_config_dict(self._cfg.encoder)

        with open_dict(self._cfg):
            if "feat_in" not in self._cfg.decoder or (
                    not self._cfg.decoder.feat_in
                    and hasattr(self.encoder, '_feat_out')):
                self._cfg.decoder.feat_in = self.encoder._feat_out
            if "feat_in" not in self._cfg.decoder or not self._cfg.decoder.feat_in:
                raise ValueError(
                    "param feat_in of the decoder's config is not set!")

            if self.cfg.decoder.num_classes < 1 and self.cfg.decoder.vocabulary is not None:
                logging.info(
                    "\nReplacing placeholder number of classes ({}) with actual number of classes - {}"
                    .format(self.cfg.decoder.num_classes,
                            len(self.cfg.decoder.vocabulary)))
                cfg.decoder["num_classes"] = len(self.cfg.decoder.vocabulary)

        self.decoder = EncDecCTCModel.from_config_dict(self._cfg.decoder)

        self.loss = CTCLoss(
            num_classes=self.decoder.num_classes_with_blank - 1,
            zero_infinity=True,
            reduction=self._cfg.get("ctc_reduction", "mean_batch"),
        )

        if hasattr(self._cfg,
                   'spec_augment') and self._cfg.spec_augment is not None:
            self.spec_augmentation = EncDecCTCModel.from_config_dict(
                self._cfg.spec_augment)
        else:
            self.spec_augmentation = None

        # Setup metric objects
        self._wer = WER(
            vocabulary=self.decoder.vocabulary,
            batch_dim_index=0,
            use_cer=self._cfg.get('use_cer', False),
            ctc_decode=True,
            dist_sync_on_step=True,
            log_prediction=self._cfg.get("log_prediction", False),
        )

        # Setup optional Optimization flags
        self.setup_optimization_flags()

        # Adapter modules setup (from ASRAdapterModelMixin)
        self.setup_adapters()
Example #7
    def change_vocabulary(self, new_vocabulary: List[str]):
        """
        Changes the vocabulary used during the CTC decoding process. Use this method when fine-tuning from a pre-trained model.
        This method changes only the decoder and leaves the encoder and pre-processing modules unchanged. For example, you would
        use it if you want to reuse a pretrained encoder when fine-tuning on data in another language, or when you need the
        model to learn capitalization, punctuation and/or special characters.

        If new_vocabulary == self.decoder.vocabulary then nothing will be changed.

        Args:

            new_vocabulary: list with new vocabulary. Must contain at least 2 elements. Typically, \
            this is target alphabet.

        Returns: None

        """
        if self.decoder.vocabulary == new_vocabulary:
            logging.warning(
                f"Old {self.decoder.vocabulary} and new {new_vocabulary} match. Not changing anything."
            )
        else:
            if new_vocabulary is None or len(new_vocabulary) == 0:
                raise ValueError(
                    f'New vocabulary must be non-empty list of chars. But I got: {new_vocabulary}'
                )
            decoder_config = self.decoder.to_config_dict()
            new_decoder_config = copy.deepcopy(decoder_config)
            new_decoder_config['vocabulary'] = new_vocabulary
            new_decoder_config['num_classes'] = len(new_vocabulary)

            del self.decoder
            self.decoder = EncDecCTCModel.from_config_dict(new_decoder_config)
            del self.loss
            self.loss = CTCLoss(
                num_classes=self.decoder.num_classes_with_blank - 1,
                zero_infinity=True,
                reduction=self._cfg.get("ctc_reduction", "mean_batch"),
            )
            self._wer = WER(
                vocabulary=self.decoder.vocabulary,
                batch_dim_index=0,
                use_cer=self._cfg.get('use_cer', False),
                ctc_decode=True,
                dist_sync_on_step=True,
                log_prediction=self._cfg.get("log_prediction", False),
            )

            # Update config
            OmegaConf.set_struct(self._cfg.decoder, False)
            self._cfg.decoder = new_decoder_config
            OmegaConf.set_struct(self._cfg.decoder, True)

            logging.info(
                f"Changed decoder to output to {self.decoder.vocabulary} vocabulary."
            )
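A minimal sketch of how `change_vocabulary` might be called when fine-tuning on a new alphabet; the model name and the character list are illustrative:

# Sketch only.
asr_model = EncDecCTCModel.from_pretrained(model_name="QuartzNet15x5Base-En")
asr_model.change_vocabulary(new_vocabulary=[" ", "a", "b", "c", "'"])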
Example #8
    def test_wer_metric_decode(self):
        wer = WER(vocabulary=self.vocabulary, batch_dim_index=0, use_cer=False, ctc_decode=True)

        tokens = self.__string_to_ctc_tensor('cat')[0].int().numpy().tolist()
        assert tokens == [3, 1, 20]

        tokens_decoded = wer.decode_ids_to_tokens(tokens)
        assert tokens_decoded == ['c', 'a', 't']

        str_decoded = wer.decode_tokens_to_str(tokens)
        assert str_decoded == 'cat'
Example #9
def main():
    parser = ArgumentParser()
    parser.add_argument(
        "--asr_model", type=str, default="QuartzNet15x5Base-En", required=True, help="Pass: '******'",
    )
    parser.add_argument("--dataset", type=str, required=True, help="path to evaluation data")
    parser.add_argument("--batch_size", type=int, default=4)
    parser.add_argument("--wer_tolerance", type=float, default=1.0, help="used by test")
    parser.add_argument(
        "--normalize_text", default=True, type=bool, help="Normalize transcripts or not. Set to False for non-English."
    )
    args = parser.parse_args()
    torch.set_grad_enabled(False)

    if args.asr_model.endswith('.nemo'):
        logging.info(f"Using local ASR model from {args.asr_model}")
        asr_model = EncDecCTCModel.restore_from(restore_path=args.asr_model)
    else:
        logging.info(f"Using NGC cloud ASR model {args.asr_model}")
        asr_model = EncDecCTCModel.from_pretrained(model_name=args.asr_model)
    asr_model.setup_test_data(
        test_data_config={
            'sample_rate': 16000,
            'manifest_filepath': args.dataset,
            'labels': asr_model.decoder.vocabulary,
            'batch_size': args.batch_size,
            'normalize_transcripts': args.normalize_text,
        }
    )
    if can_gpu:
        asr_model = asr_model.cuda()
    asr_model.eval()
    labels_map = dict([(i, asr_model.decoder.vocabulary[i]) for i in range(len(asr_model.decoder.vocabulary))])
    wer = WER(vocabulary=asr_model.decoder.vocabulary)
    hypotheses = []
    references = []
    for test_batch in asr_model.test_dataloader():
        if can_gpu:
            test_batch = [x.cuda() for x in test_batch]
        with autocast():
            log_probs, encoded_len, greedy_predictions = asr_model(
                input_signal=test_batch[0], input_signal_length=test_batch[1]
            )
        hypotheses += wer.ctc_decoder_predictions_tensor(greedy_predictions)
        for batch_ind in range(greedy_predictions.shape[0]):
            reference = ''.join([labels_map[c] for c in test_batch[2][batch_ind].cpu().detach().numpy()])
            references.append(reference)
        del test_batch
    wer_value = word_error_rate(hypotheses=hypotheses, references=references)
    if wer_value > args.wer_tolerance:
        raise ValueError(f"Got WER of {wer_value}. It was higher than {args.wer_tolerance}")
    logging.info(f'Got WER of {wer_value}. Tolerance was {args.wer_tolerance}')
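This example relies on `can_gpu` and `autocast` being defined at module level. A sketch of that setup, mirroring the fallback used in Examples #12 and #19:

import torch

can_gpu = torch.cuda.is_available()

try:
    from torch.cuda.amp import autocast
except ImportError:
    from contextlib import contextmanager

    @contextmanager
    def autocast(enabled=None):
        yield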
Example #10
    def test_wer_metric_randomized(self, test_wer_bpe):
        """This test relies on correctness of word_error_rate function."""
        def __random_string(length):
            return ''.join(
                random.choice(''.join(self.vocabulary)) for _ in range(length))

        if test_wer_bpe:
            wer = WERBPE(deepcopy(self.char_tokenizer),
                         batch_dim_index=0,
                         use_cer=False,
                         ctc_decode=True)
        else:
            wer = WER(vocabulary=self.vocabulary,
                      batch_dim_index=0,
                      use_cer=False,
                      ctc_decode=True)

        for test_id in range(256):
            n1 = random.randint(1, 512)
            n2 = random.randint(1, 512)
            s1 = __random_string(n1)
            s2 = __random_string(n2)
            # skip empty strings as reference
            if s2.strip():
                assert (abs(
                    self.get_wer(wer,
                                 prediction=s1,
                                 reference=s2,
                                 use_tokenizer=test_wer_bpe) -
                    word_error_rate(hypotheses=[s1], references=[s2])) < 1e-6)
Example #11
def generate_ref_hyps(asr_model: EncDecCTCModel, search: str, arpa: str):

    if can_gpu:
        asr_model = asr_model.cuda()
        print("USING GPU!")

    asr_model.eval()
    vocabulary = asr_model.decoder.vocabulary
    labels_map = dict([(i, vocabulary[i]) for i in range(len(vocabulary))])
    wer = WER(vocabulary=vocabulary)

    if search == "kenlm" or search == "beamsearch":
        arpa_file = prepare_arpa_file(arpa)
        lm_path = arpa_file if search == "kenlm" else None

        beamsearcher = nemo_asr.modules.BeamSearchDecoderWithLM(
            vocab=list(vocabulary),
            beam_width=16,
            alpha=2,
            beta=1.5,
            lm_path=lm_path,
            num_cpus=max(os.cpu_count(), 1),
            input_tensor=True,
        )

    for batch in asr_model.test_dataloader():
        # TODO(tilo): test_loader should return dict or some typed object not tuple of tensors!!
        if can_gpu:
            batch = [x.cuda() for x in batch]
        input_signal, inpsig_len, transcript, transc_len = batch
        with autocast():
            log_probs, encoded_len, greedy_predictions = asr_model(
                input_signal=input_signal, input_signal_length=inpsig_len)
        if search == "greedy":
            decoded = wer.ctc_decoder_predictions_tensor(greedy_predictions)
        else:
            decoded = beamsearch_forward(beamsearcher,
                                         log_probs=log_probs,
                                         log_probs_length=encoded_len)

        for i, hyp in enumerate(decoded):
            reference = "".join([
                labels_map[c]
                for c in transcript[i].cpu().detach().numpy()[:transc_len[i]]
            ])
            yield reference, hyp
Example #12
def get_transcript(manifest_path: str,
                   asr_model: nemo_asr.models.EncDecCTCModel,
                   batch_size: int) -> List[str]:
    """
    Returns transcripts for audio segments in the batch

    Args:
        manifest_path: path to the manifest for inference
        asr_model: CTC-based ASR model, for example, QuartzNet15x5Base-En
        batch_size: batch size

    Returns: hypotheses: transcripts for the audio segments
    """
    # batch inference
    try:
        from torch.cuda.amp import autocast
    except ImportError:
        from contextlib import contextmanager

        @contextmanager
        def autocast(enabled=None):
            yield

    torch.set_grad_enabled(False)
    asr_model.setup_test_data(
        test_data_config={
            'sample_rate': 16000,
            'manifest_filepath': manifest_path,
            'labels': asr_model.decoder.vocabulary,
            'batch_size': batch_size,
            'normalize_transcripts': False,
        })
    asr_model.eval()
    wer = WER(vocabulary=asr_model.decoder.vocabulary)
    hypotheses = []
    for test_batch in asr_model.test_dataloader():
        if torch.cuda.is_available():
            test_batch = [x.cuda() for x in test_batch]
        with autocast():
            log_probs, encoded_len, greedy_predictions = asr_model(
                input_signal=test_batch[0], input_signal_length=test_batch[1])
        hypotheses += wer.ctc_decoder_predictions_tensor(greedy_predictions)
        del test_batch
        torch.cuda.empty_cache()
    return hypotheses
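A minimal sketch of calling `get_transcript`; the manifest path and model name are illustrative:

# Sketch only.
asr_model = nemo_asr.models.EncDecCTCModel.from_pretrained(model_name="QuartzNet15x5Base-En")
hypotheses = get_transcript(manifest_path="manifest.json", asr_model=asr_model, batch_size=4)
print(hypotheses[:3])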
Example #13
    def test_wer_metric_simple(self):
        wer = WER(vocabulary=self.vocabulary, batch_dim_index=0, use_cer=False, ctc_decode=True)

        assert self.get_wer(wer, 'cat', 'cot') == 1.0
        assert self.get_wer(wer, 'gpu', 'g p u') == 1.0
        assert self.get_wer(wer, 'g p u', 'gpu') == 3.0
        assert self.get_wer(wer, 'ducati motorcycle', 'motorcycle') == 1.0
        assert self.get_wer(wer, 'ducati motorcycle', 'ducuti motorcycle') == 0.5
        assert self.get_wer(wer, 'a f c', 'a b c') == 1.0 / 3.0
Example #14
    def test_wer_metric_randomized(self):
        """This test relies on correctness of word_error_rate function."""

        def __randomString(N):
            return ''.join(random.choice(''.join(self.vocabulary)) for i in range(N))

        wer = WER(vocabulary=self.vocabulary, batch_dim_index=0, use_cer=False, ctc_decode=True)

        for test_id in range(256):
            n1 = random.randint(1, 512)
            n2 = random.randint(1, 512)
            s1 = __randomString(n1)
            s2 = __randomString(n2)
            # Floating-point math doesn't seem to be an issue here. Leaving as ==
            assert self.get_wer(wer, prediction=s1, reference=s2) == word_error_rate(hypotheses=[s1], references=[s2])
Example #15
    def test_wer_metric_decode(self, test_wer_bpe):
        if test_wer_bpe:
            wer = WERBPE(self.char_tokenizer,
                         batch_dim_index=0,
                         use_cer=False,
                         ctc_decode=True)
        else:
            wer = WER(vocabulary=self.vocabulary.copy(),
                      batch_dim_index=0,
                      use_cer=False,
                      ctc_decode=True)

        tokens = self.__string_to_ctc_tensor(
            'cat', use_tokenizer=test_wer_bpe)[0].int().numpy().tolist()
        assert tokens == [3, 1, 20]

        tokens_decoded = wer.decode_ids_to_tokens(tokens)
        assert tokens_decoded == ['c', 'a', 't']

        str_decoded = wer.decode_tokens_to_str(tokens)
        assert str_decoded == 'cat'
Example #16
    def test_wer_metric_simple(self, batch_dim_index, test_wer_bpe):
        if test_wer_bpe:
            wer = WERBPE(self.char_tokenizer,
                         batch_dim_index,
                         use_cer=False,
                         ctc_decode=True)
        else:
            wer = WER(vocabulary=self.vocabulary,
                      batch_dim_index=batch_dim_index,
                      use_cer=False,
                      ctc_decode=True)

        assert self.get_wer(wer, 'cat', 'cot', test_wer_bpe) == 1.0
        assert self.get_wer(wer, 'gpu', 'g p u', test_wer_bpe) == 1.0
        assert self.get_wer(wer, 'g p u', 'gpu', test_wer_bpe) == 3.0
        assert self.get_wer(wer, 'ducati motorcycle', 'motorcycle',
                            test_wer_bpe) == 1.0
        assert self.get_wer(wer, 'ducati motorcycle', 'ducuti motorcycle',
                            test_wer_bpe) == 0.5
        assert abs(
            self.get_wer(wer, 'a f c', 'a b c', test_wer_bpe) -
            1.0 / 3.0) < 1e-6
Example #17
class EncDecCTCModel(ASRModel, ExportableEncDecModel):
    """Base class for encoder decoder CTC-based models."""
    @classmethod
    def list_available_models(cls) -> List[PretrainedModelInfo]:
        """
        This method returns a list of pre-trained models which can be instantiated directly from NVIDIA's NGC cloud.

        Returns:
            List of available pre-trained models.
        """
        result = []
        model = PretrainedModelInfo(
            pretrained_model_name="QuartzNet15x5Base-En",
            location=
            "https://api.ngc.nvidia.com/v2/models/nvidia/nemospeechmodels/versions/1.0.0a5/files/QuartzNet15x5Base-En.nemo",
            description=
            "QuartzNet15x5 model trained on six datasets: LibriSpeech, Mozilla Common Voice (validated clips from en_1488h_2019-12-10), WSJ, Fisher, Switchboard, and NSC Singapore English. It was trained with Apex/Amp optimization level O1 for 600 epochs. The model achieves a WER of 3.79% on LibriSpeech dev-clean, and a WER of 10.05% on dev-other.",
        )
        result.append(model)

        model = PretrainedModelInfo(
            pretrained_model_name="QuartzNet15x5Base-Zh",
            location=
            "https://api.ngc.nvidia.com/v2/models/nvidia/nemospeechmodels/versions/1.0.0a5/files/QuartzNet15x5Base-Zh.nemo",
            description=
            "QuartzNet15x5 model trained on ai-shell2 Mandarin Chinese dataset.",
        )
        result.append(model)

        model = PretrainedModelInfo(
            pretrained_model_name="QuartzNet5x5LS-En",
            location=
            "https://api.ngc.nvidia.com/v2/models/nvidia/nemospeechmodels/versions/1.0.0a5/files/QuartzNet5x5LS-En.nemo",
            description=
            "QuartzNet5x5 model trained on LibriSpeech dataset only. The model achieves a WER of 5.37% on LibriSpeech dev-clean, and a WER of 15.69% on dev-other.",
        )
        result.append(model)

        model = PretrainedModelInfo(
            pretrained_model_name="QuartzNet15x5NR-En",
            location=
            "https://api.ngc.nvidia.com/v2/models/nvidia/nemospeechmodels/versions/1.0.0a5/files/QuartzNet15x5NR-En.nemo",
            description=
            "QuartzNet15x5Base-En was finetuned with RIR and noise augmentation to make it more robust to noise. This model should be preferred for noisy speech transcription. This model achieves a WER of 3.96% on LibriSpeech dev-clean and a WER of 10.14% on dev-other.",
        )
        result.append(model)

        model = PretrainedModelInfo(
            pretrained_model_name="Jasper10x5Dr-En",
            location=
            "https://api.ngc.nvidia.com/v2/models/nvidia/nemospeechmodels/versions/1.0.0a5/files/Jasper10x5Dr-En.nemo",
            description=
            "JasperNet10x5Dr model trained on six datasets: LibriSpeech, Mozilla Common Voice (validated clips from en_1488h_2019-12-10), WSJ, Fisher, Switchboard, and NSC Singapore English. It was trained with Apex/Amp optimization level O1. The model achieves a WER of 3.37% on LibriSpeech dev-clean, 9.81% on dev-other.",
        )
        result.append(model)
        return result

    def __init__(self, cfg: DictConfig, trainer: Trainer = None):
        # Get global rank and total number of GPU workers for IterableDataset partitioning, if applicable
        self.global_rank = 0
        self.world_size = 1
        self.local_rank = 0
        if trainer is not None:
            self.global_rank = (trainer.node_rank *
                                trainer.num_gpus) + trainer.local_rank
            self.world_size = trainer.num_nodes * trainer.num_gpus
            self.local_rank = trainer.local_rank

        super().__init__(cfg=cfg, trainer=trainer)
        self.preprocessor = EncDecCTCModel.from_config_dict(
            self._cfg.preprocessor)
        self.encoder = EncDecCTCModel.from_config_dict(self._cfg.encoder)

        with open_dict(self._cfg):
            if "feat_in" not in self._cfg.decoder or (
                    not self._cfg.decoder.feat_in
                    and hasattr(self.encoder, '_feat_out')):
                self._cfg.decoder.feat_in = self.encoder._feat_out
            if "feat_in" not in self._cfg.decoder or not self._cfg.decoder.feat_in:
                raise ValueError(
                    "param feat_in of the decoder's config is not set!")

        self.decoder = EncDecCTCModel.from_config_dict(self._cfg.decoder)

        self.loss = CTCLoss(
            num_classes=self.decoder.num_classes_with_blank - 1,
            zero_infinity=True,
            reduction=self._cfg.get("ctc_reduction", "mean_batch"),
        )

        if hasattr(self._cfg,
                   'spec_augment') and self._cfg.spec_augment is not None:
            self.spec_augmentation = EncDecCTCModel.from_config_dict(
                self._cfg.spec_augment)
        else:
            self.spec_augmentation = None

        # Setup metric objects
        self._wer = WER(
            vocabulary=self.decoder.vocabulary,
            batch_dim_index=0,
            use_cer=self._cfg.get('use_cer', False),
            ctc_decode=True,
            dist_sync_on_step=True,
            log_prediction=self._cfg.get("log_prediction", False),
        )

    @torch.no_grad()
    def transcribe(self,
                   paths2audio_files: List[str],
                   batch_size: int = 4,
                   logprobs=False) -> List[str]:
        """
        Uses greedy decoding to transcribe audio files. Use this method for debugging and prototyping.

        Args:

            paths2audio_files: (a list) of paths to audio files. \
        Recommended length per file is between 5 and 25 seconds. \
        But it is possible to pass a few hours long file if enough GPU memory is available.
            batch_size: (int) batch size to use during inference. \
        Bigger will result in better throughput performance but would use more memory.
            logprobs: (bool) pass True to get log probabilities instead of transcripts.

        Returns:

            A list of transcriptions (or raw log probabilities if logprobs is True) in the same order as paths2audio_files
        """
        if paths2audio_files is None or len(paths2audio_files) == 0:
            return []
        # We will store transcriptions here
        hypotheses = []
        # Model's mode and device
        mode = self.training
        device = next(self.parameters()).device
        dither_value = self.preprocessor.featurizer.dither
        pad_to_value = self.preprocessor.featurizer.pad_to

        try:
            self.preprocessor.featurizer.dither = 0.0
            self.preprocessor.featurizer.pad_to = 0
            # Switch model to evaluation mode
            self.eval()
            logging_level = logging.get_verbosity()
            logging.set_verbosity(logging.WARNING)
            # Work in tmp directory - will store manifest file there
            with tempfile.TemporaryDirectory() as tmpdir:
                with open(os.path.join(tmpdir, 'manifest.json'), 'w') as fp:
                    for audio_file in paths2audio_files:
                        entry = {
                            'audio_filepath': audio_file,
                            'duration': 100000,
                            'text': 'nothing'
                        }
                        fp.write(json.dumps(entry) + '\n')

                config = {
                    'paths2audio_files': paths2audio_files,
                    'batch_size': batch_size,
                    'temp_dir': tmpdir
                }

                temporary_datalayer = self._setup_transcribe_dataloader(config)
                for test_batch in temporary_datalayer:
                    logits, logits_len, greedy_predictions = self.forward(
                        input_signal=test_batch[0].to(device),
                        input_signal_length=test_batch[1].to(device))
                    if logprobs:
                        # dump log probs per file
                        for idx in range(logits.shape[0]):
                            hypotheses.append(logits[idx][:logits_len[idx]])
                    else:
                        hypotheses += self._wer.ctc_decoder_predictions_tensor(
                            greedy_predictions, predictions_len=logits_len)
                    del test_batch
        finally:
            # set mode back to its original value
            self.train(mode=mode)
            self.preprocessor.featurizer.dither = dither_value
            self.preprocessor.featurizer.pad_to = pad_to_value
            logging.set_verbosity(logging_level)
        return hypotheses

    def change_vocabulary(self, new_vocabulary: List[str]):
        """
        Changes the vocabulary used during the CTC decoding process. Use this method when fine-tuning from a pre-trained model.
        This method changes only the decoder and leaves the encoder and pre-processing modules unchanged. For example, you would
        use it if you want to reuse a pretrained encoder when fine-tuning on data in another language, or when you need the
        model to learn capitalization, punctuation and/or special characters.

        If new_vocabulary == self.decoder.vocabulary then nothing will be changed.

        Args:

            new_vocabulary: list with new vocabulary. Must contain at least 2 elements. Typically, \
            this is target alphabet.

        Returns: None

        """
        if self.decoder.vocabulary == new_vocabulary:
            logging.warning(
                f"Old {self.decoder.vocabulary} and new {new_vocabulary} match. Not changing anything."
            )
        else:
            if new_vocabulary is None or len(new_vocabulary) == 0:
                raise ValueError(
                    f'New vocabulary must be non-empty list of chars. But I got: {new_vocabulary}'
                )
            decoder_config = self.decoder.to_config_dict()
            new_decoder_config = copy.deepcopy(decoder_config)
            new_decoder_config['vocabulary'] = new_vocabulary
            new_decoder_config['num_classes'] = len(new_vocabulary)

            del self.decoder
            self.decoder = EncDecCTCModel.from_config_dict(new_decoder_config)
            del self.loss
            self.loss = CTCLoss(
                num_classes=self.decoder.num_classes_with_blank - 1,
                zero_infinity=True,
                reduction=self._cfg.get("ctc_reduction", "mean_batch"),
            )
            self._wer = WER(
                vocabulary=self.decoder.vocabulary,
                batch_dim_index=0,
                use_cer=self._cfg.get('use_cer', False),
                ctc_decode=True,
                dist_sync_on_step=True,
                log_prediction=self._cfg.get("log_prediction", False),
            )

            # Update config
            OmegaConf.set_struct(self._cfg.decoder, False)
            self._cfg.decoder = new_decoder_config
            OmegaConf.set_struct(self._cfg.decoder, True)

            logging.info(
                f"Changed decoder to output to {self.decoder.vocabulary} vocabulary."
            )

    def _setup_dataloader_from_config(self, config: Optional[Dict]):
        if 'augmentor' in config:
            augmentor = process_augmentations(config['augmentor'])
        else:
            augmentor = None

        shuffle = config['shuffle']
        device = 'gpu' if torch.cuda.is_available() else 'cpu'
        if config.get('use_dali', False):
            device_id = self.local_rank if device == 'gpu' else None
            dataset = audio_to_text_dataset.get_dali_char_dataset(
                config=config,
                shuffle=shuffle,
                device_id=device_id,
                global_rank=self.global_rank,
                world_size=self.world_size,
                preprocessor_cfg=self._cfg.preprocessor,
            )
            return dataset

        # Instantiate tarred dataset loader or normal dataset loader
        if config.get('is_tarred', False):
            if ('tarred_audio_filepaths' in config
                    and config['tarred_audio_filepaths'] is None) or (
                        'manifest_filepath' in config
                        and config['manifest_filepath'] is None):
                logging.warning(
                    "Could not load dataset as `manifest_filepath` was None or "
                    f"`tarred_audio_filepaths` is None. Provided config : {config}"
                )
                return None

            shuffle_n = config.get('shuffle_n', 4 *
                                   config['batch_size']) if shuffle else 0
            dataset = audio_to_text_dataset.get_tarred_char_dataset(
                config=config,
                shuffle_n=shuffle_n,
                global_rank=self.global_rank,
                world_size=self.world_size,
                augmentor=augmentor,
            )
            shuffle = False
        else:
            if 'manifest_filepath' in config and config[
                    'manifest_filepath'] is None:
                logging.warning(
                    f"Could not load dataset as `manifest_filepath` was None. Provided config : {config}"
                )
                return None

            dataset = audio_to_text_dataset.get_char_dataset(
                config=config, augmentor=augmentor)

        return torch.utils.data.DataLoader(
            dataset=dataset,
            batch_size=config['batch_size'],
            collate_fn=dataset.collate_fn,
            drop_last=config.get('drop_last', False),
            shuffle=shuffle,
            num_workers=config.get('num_workers', 0),
            pin_memory=config.get('pin_memory', False),
        )

    def setup_training_data(self, train_data_config: Optional[Union[DictConfig,
                                                                    Dict]]):
        if 'shuffle' not in train_data_config:
            train_data_config['shuffle'] = True

        # preserve config
        self._update_dataset_config(dataset_name='train',
                                    config=train_data_config)

        self._train_dl = self._setup_dataloader_from_config(
            config=train_data_config)

        # Need to set this because if using an IterableDataset, the length of the dataloader is the total number
        # of samples rather than the number of batches, and this messes up the tqdm progress bar.
        # So we set the number of steps manually (to the correct number) to fix this.
        if 'is_tarred' in train_data_config and train_data_config['is_tarred']:
            # We also need to check if limit_train_batches is already set.
            # If it's an int, we assume that the user has set it to something sane, i.e. <= # training batches,
            # and don't change it. Otherwise, adjust batches accordingly if it's a float (including 1.0).
            if isinstance(self._trainer.limit_train_batches, float):
                self._trainer.limit_train_batches = int(
                    self._trainer.limit_train_batches * ceil(
                        (len(self._train_dl.dataset) / self.world_size) /
                        train_data_config['batch_size']))

    def setup_validation_data(self, val_data_config: Optional[Union[DictConfig,
                                                                    Dict]]):
        if 'shuffle' not in val_data_config:
            val_data_config['shuffle'] = False

        # preserve config
        self._update_dataset_config(dataset_name='validation',
                                    config=val_data_config)

        self._validation_dl = self._setup_dataloader_from_config(
            config=val_data_config)

    def setup_test_data(self, test_data_config: Optional[Union[DictConfig,
                                                               Dict]]):
        if 'shuffle' not in test_data_config:
            test_data_config['shuffle'] = False

        # preserve config
        self._update_dataset_config(dataset_name='test',
                                    config=test_data_config)

        self._test_dl = self._setup_dataloader_from_config(
            config=test_data_config)

    @property
    def input_types(self) -> Optional[Dict[str, NeuralType]]:
        if hasattr(self.preprocessor, '_sample_rate'):
            input_signal_eltype = AudioSignal(
                freq=self.preprocessor._sample_rate)
        else:
            input_signal_eltype = AudioSignal()
        return {
            "input_signal":
            NeuralType(('B', 'T'), input_signal_eltype, optional=True),
            "input_signal_length":
            NeuralType(tuple('B'), LengthsType(), optional=True),
            "processed_signal":
            NeuralType(('B', 'D', 'T'), SpectrogramType(), optional=True),
            "processed_signal_length":
            NeuralType(tuple('B'), LengthsType(), optional=True),
        }

    @property
    def output_types(self) -> Optional[Dict[str, NeuralType]]:
        return {
            "outputs": NeuralType(('B', 'T', 'D'), LogprobsType()),
            "encoded_lengths": NeuralType(tuple('B'), LengthsType()),
            "greedy_predictions": NeuralType(('B', 'T'), LabelsType()),
        }

    @typecheck()
    def forward(self,
                input_signal=None,
                input_signal_length=None,
                processed_signal=None,
                processed_signal_length=None):
        has_input_signal = input_signal is not None and input_signal_length is not None
        has_processed_signal = processed_signal is not None and processed_signal_length is not None
        if not (has_input_signal ^ has_processed_signal):
            raise ValueError(
                f"{self} Arguments ``input_signal`` and ``input_signal_length`` are mutually exclusive "
                " with ``processed_signal`` and ``processed_signal_len`` arguments."
            )

        if not has_processed_signal:
            processed_signal, processed_signal_length = self.preprocessor(
                input_signal=input_signal,
                length=input_signal_length,
            )

        if self.spec_augmentation is not None and self.training:
            processed_signal = self.spec_augmentation(
                input_spec=processed_signal)

        encoded, encoded_len = self.encoder(audio_signal=processed_signal,
                                            length=processed_signal_length)
        log_probs = self.decoder(encoder_output=encoded)
        greedy_predictions = log_probs.argmax(dim=-1, keepdim=False)

        return log_probs, encoded_len, greedy_predictions

    # PTL-specific methods
    def training_step(self, batch, batch_nb):
        signal, signal_len, transcript, transcript_len = batch
        if isinstance(batch, DALIOutputs) and batch.has_processed_signal:
            log_probs, encoded_len, predictions = self.forward(
                processed_signal=signal, processed_signal_length=signal_len)
        else:
            log_probs, encoded_len, predictions = self.forward(
                input_signal=signal, input_signal_length=signal_len)

        loss_value = self.loss(log_probs=log_probs,
                               targets=transcript,
                               input_lengths=encoded_len,
                               target_lengths=transcript_len)

        tensorboard_logs = {
            'train_loss': loss_value,
            'learning_rate': self._optimizer.param_groups[0]['lr']
        }

        if hasattr(self, '_trainer') and self._trainer is not None:
            log_every_n_steps = self._trainer.log_every_n_steps
        else:
            log_every_n_steps = 1

        if (batch_nb + 1) % log_every_n_steps == 0:
            self._wer.update(
                predictions=predictions,
                targets=transcript,
                target_lengths=transcript_len,
                predictions_lengths=encoded_len,
            )
            wer, _, _ = self._wer.compute()
            tensorboard_logs.update({'training_batch_wer': wer})

        return {'loss': loss_value, 'log': tensorboard_logs}

    def validation_step(self, batch, batch_idx, dataloader_idx=0):
        signal, signal_len, transcript, transcript_len = batch
        if isinstance(batch, DALIOutputs) and batch.has_processed_signal:
            log_probs, encoded_len, predictions = self.forward(
                processed_signal=signal, processed_signal_length=signal_len)
        else:
            log_probs, encoded_len, predictions = self.forward(
                input_signal=signal, input_signal_length=signal_len)

        loss_value = self.loss(log_probs=log_probs,
                               targets=transcript,
                               input_lengths=encoded_len,
                               target_lengths=transcript_len)
        self._wer.update(predictions=predictions,
                         targets=transcript,
                         target_lengths=transcript_len,
                         predictions_lengths=encoded_len)
        wer, wer_num, wer_denom = self._wer.compute()
        return {
            'val_loss': loss_value,
            'val_wer_num': wer_num,
            'val_wer_denom': wer_denom,
            'val_wer': wer,
        }

    def test_step(self, batch, batch_idx, dataloader_idx=0):
        logs = self.validation_step(batch,
                                    batch_idx,
                                    dataloader_idx=dataloader_idx)
        test_logs = {
            'test_loss': logs['val_loss'],
            'test_wer_num': logs['val_wer_num'],
            'test_wer_denom': logs['val_wer_denom'],
            'test_wer': logs['val_wer'],
        }
        return test_logs

    def test_dataloader(self):
        if self._test_dl is not None:
            return self._test_dl

    def _setup_transcribe_dataloader(
            self, config: Dict) -> 'torch.utils.data.DataLoader':
        """
        Setup function for a temporary data loader which wraps the provided audio file.

        Args:
            config: A python dictionary which contains the following keys:
            paths2audio_files: (a list) of paths to audio files. The files should be relatively short fragments. \
                Recommended length per file is between 5 and 25 seconds.
            batch_size: (int) batch size to use during inference. \
                Bigger will result in better throughput performance but would use more memory.
            temp_dir: (str) A temporary directory where the audio manifest is temporarily
                stored.

        Returns:
            A pytorch DataLoader for the given audio file(s).
        """
        dl_config = {
            'manifest_filepath':
            os.path.join(config['temp_dir'], 'manifest.json'),
            'sample_rate':
            self.preprocessor._sample_rate,
            'labels':
            self.decoder.vocabulary,
            'batch_size':
            min(config['batch_size'], len(config['paths2audio_files'])),
            'trim_silence':
            True,
            'shuffle':
            False,
        }

        temporary_datalayer = self._setup_dataloader_from_config(
            config=DictConfig(dl_config))
        return temporary_datalayer
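A minimal usage sketch for the class above: listing the registered checkpoints and transcribing a file with greedy decoding; the audio filename is illustrative:

# Sketch only.
for model_info in EncDecCTCModel.list_available_models():
    print(model_info.pretrained_model_name)

asr_model = EncDecCTCModel.from_pretrained(model_name="QuartzNet15x5Base-En")
print(asr_model.transcribe(paths2audio_files=["sample.wav"], batch_size=1))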
Example #18
def main():
    parser = ArgumentParser()
    parser.add_argument(
        "--asr_model",
        type=str,
        default="QuartzNet15x5Base-En",
        required=True,
        help="Pass: '******'",
    )
    parser.add_argument("--dataset",
                        type=str,
                        required=True,
                        help="path to evaluation data")
    parser.add_argument("--wer_target",
                        type=float,
                        default=None,
                        help="used by test")
    parser.add_argument("--batch_size", type=int, default=4)
    parser.add_argument("--wer_tolerance",
                        type=float,
                        default=1.0,
                        help="used by test")
    parser.add_argument(
        "--dont_normalize_text",
        default=False,
        action='store_false',
        help="Turn off trasnscript normalization. Recommended for non-English.",
    )
    parser.add_argument(
        "--use_cer",
        default=False,
        action='store_true',
        help="Use Character Error Rate as the evaluation metric")
    parser.add_argument('--sensitivity',
                        action="store_true",
                        help="Perform sensitivity analysis")
    parser.add_argument('--onnx', action="store_true", help="Export to ONNX")
    parser.add_argument('--quant-disable-keyword',
                        type=str,
                        nargs='+',
                        help='disable quantizers by keyword')
    args = parser.parse_args()
    torch.set_grad_enabled(False)

    quant_modules.initialize()

    if args.asr_model.endswith('.nemo'):
        logging.info(f"Using local ASR model from {args.asr_model}")
        asr_model_cfg = EncDecCTCModelBPE.restore_from(
            restore_path=args.asr_model, return_config=True)
        with open_dict(asr_model_cfg):
            asr_model_cfg.encoder.quantize = True
        asr_model = EncDecCTCModelBPE.restore_from(
            restore_path=args.asr_model, override_config_path=asr_model_cfg)

    else:
        logging.info(f"Using NGC cloud ASR model {args.asr_model}")
        asr_model_cfg = EncDecCTCModelBPE.from_pretrained(
            model_name=args.asr_model, return_config=True)
        with open_dict(asr_model_cfg):
            asr_model_cfg.encoder.quantize = True
        asr_model = EncDecCTCModelBPE.from_pretrained(
            model_name=args.asr_model, override_config_path=asr_model_cfg)
    asr_model.setup_test_data(
        test_data_config={
            'sample_rate': 16000,
            'manifest_filepath': args.dataset,
            'labels': asr_model.decoder.vocabulary,
            'batch_size': args.batch_size,
            'normalize_transcripts': args.dont_normalize_text,
        })
    asr_model.preprocessor.featurizer.dither = 0.0
    asr_model.preprocessor.featurizer.pad_to = 0
    if can_gpu:
        asr_model = asr_model.cuda()
    asr_model.eval()

    if args.quant_disable_keyword:
        for name, module in asr_model.named_modules():
            if isinstance(module, quant_nn.TensorQuantizer):
                for keyword in args.quant_disable_keyword:
                    if keyword in name:
                        logging.warning(F"Disable {name}")
                        module.disable()

    labels_map = dict([(i, asr_model.decoder.vocabulary[i])
                       for i in range(len(asr_model.decoder.vocabulary))])
    wer = WER(vocabulary=asr_model.decoder.vocabulary, use_cer=args.use_cer)
    wer_quant = evaluate(asr_model, labels_map, wer)
    logging.info(f'Got WER of {wer_quant}. Tolerance was {args.wer_tolerance}')

    if args.sensitivity:
        if wer_quant < args.wer_tolerance:
            logging.info(
                "Tolerance is already met. Skip sensitivity analyasis.")
            return
        quant_layer_names = []
        for name, module in asr_model.named_modules():
            if isinstance(module, quant_nn.TensorQuantizer):
                module.disable()
                layer_name = name.replace("._input_quantizer",
                                          "").replace("._weight_quantizer", "")
                if layer_name not in quant_layer_names:
                    quant_layer_names.append(layer_name)
        logging.info(F"{len(quant_layer_names)} quantized layers found.")

        # Build sensitivity profile
        quant_layer_sensitivity = {}
        for i, quant_layer in enumerate(quant_layer_names):
            logging.info(F"Enable {quant_layer}")
            for name, module in asr_model.named_modules():
                if isinstance(
                        module,
                        quant_nn.TensorQuantizer) and quant_layer in name:
                    module.enable()
                    logging.info(F"{name:40}: {module}")

            # Eval the model
            wer_value = evaluate(asr_model, labels_map, wer)
            logging.info(F"WER: {wer_value}")
            quant_layer_sensitivity[
                quant_layer] = args.wer_tolerance - wer_value

            for name, module in asr_model.named_modules():
                if isinstance(
                        module,
                        quant_nn.TensorQuantizer) and quant_layer in name:
                    module.disable()
                    logging.info(F"{name:40}: {module}")

        # Skip most sensitive layers until WER target is met
        for name, module in asr_model.named_modules():
            if isinstance(module, quant_nn.TensorQuantizer):
                module.enable()
        quant_layer_sensitivity = collections.OrderedDict(
            sorted(quant_layer_sensitivity.items(), key=lambda x: x[1]))
        pprint(quant_layer_sensitivity)
        skipped_layers = []
        for quant_layer, _ in quant_layer_sensitivity.items():
            for name, module in asr_model.named_modules():
                if isinstance(module, quant_nn.TensorQuantizer):
                    if quant_layer in name:
                        logging.info(F"Disable {name}")
                        if quant_layer not in skipped_layers:
                            skipped_layers.append(quant_layer)
                        module.disable()
            wer_value = evaluate(asr_model, labels_map, wer)
            if wer_value <= args.wer_tolerance:
                logging.info(
                    F"WER tolerance {args.wer_tolerance} is met by skipping {len(skipped_layers)} sensitive layers."
                )
                print(skipped_layers)
                export_onnx(args, asr_model)
                return
        raise ValueError(
            f"WER tolerance {args.wer_tolerance} can not be met with any layer quantized!"
        )

    export_onnx(args, asr_model)
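The example above depends on the pytorch-quantization toolkit; a sketch of the imports it assumes:

# Sketch only: imports assumed by the quantization example above.
from pytorch_quantization import nn as quant_nn
from pytorch_quantization import quant_modules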
Example #19
def ASR_Grade(dataset, id, key):
    try:
        from torch.cuda.amp import autocast
    except ImportError:
        from contextlib import contextmanager

        @contextmanager
        def autocast(enabled=None):
            yield

    can_gpu = torch.cuda.is_available()

    parser = ArgumentParser()
    parser.add_argument(
        "--asr_model",
        type=str,
        default=model_Selected,
        required=True,
        help=f'Pass: {model_Selected}',
    )
    parser.add_argument("--dataset",
                        type=str,
                        required=True,
                        help="path to evaluation data")
    parser.add_argument("--batch_size", type=int, default=4)
    parser.add_argument("--wer_tolerance",
                        type=float,
                        default=1.0,
                        help="used by test")
    parser.add_argument(
        "--normalize_text",
        default=False,  # False <- we're using phonetic references
        type=bool,
        help="Normalize transcripts or not. Set to False for non-English.",
    )
    args = parser.parse_args(
        ["--dataset", dataset, "--asr_model", model_Selected])
    torch.set_grad_enabled(False)

    # Instantiate Jasper/QuartzNet models with the EncDecCTCModel class.
    asr_model = EncDecCTCModel.restore_from(model_Path)

    asr_model.setup_test_data(
        test_data_config={
            "sample_rate": 16000,
            "manifest_filepath": args.dataset,
            "labels": asr_model.decoder.vocabulary,
            "batch_size": args.batch_size,
            "normalize_transcripts": args.normalize_text,
        })
    if can_gpu:  # noqa
        asr_model = asr_model.cuda()
    asr_model.eval()
    labels_map = dict([(i, asr_model.decoder.vocabulary[i])
                       for i in range(len(asr_model.decoder.vocabulary))])
    wer = WER(vocabulary=asr_model.decoder.vocabulary)
    hypotheses = []
    references = []
    for test_batch in asr_model.test_dataloader():
        if can_gpu:
            test_batch = [x.cuda() for x in test_batch]
        with autocast():
            log_probs, encoded_len, greedy_predictions = asr_model(
                input_signal=test_batch[0], input_signal_length=test_batch[1])
        hypotheses = wer.ctc_decoder_predictions_tensor(greedy_predictions)
        for batch_ind in range(greedy_predictions.shape[0]):
            reference = key
            #reference = "".join([labels_map[c] for c in test_batch[2][batch_ind].cpu().detach().numpy()])  #debug
            print(reference)  #debug
            references.append(reference)
        del test_batch
    wer_value = word_error_rate(hypotheses=hypotheses,
                                references=references)  #cer=True

    REC = '.'
    REF = '.'
    for h, r in zip(hypotheses, references):
        print("Recognized:\t{}\nReference:\t{}\n".format(h, r))
        REC = h
        REF = r
    logging.info(f"Got PER of {wer_value}. Tolerance was {args.wer_tolerance}")

    # Score calculation, phoneme conversion
    # Divide wer_value by wer_tolerance to get the error ratio (and round it),
    # then multiply by 100 to express it as a percentage.
    # Since this gives the "% wrong", subtract it from 100 to get "% correct",
    # which is the positive grade returned to the user.
    score = 100.00 - (round((wer_value / args.wer_tolerance), 4) * 100)
    if score < 0.0:
        score = 0.0
    print(score)
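    # Worked example (hypothetical numbers): wer_value = 0.25 with wer_tolerance = 1.0
    # gives round(0.25 / 1.0, 4) * 100 = 25.0 "% wrong", so score = 100.00 - 25.0 = 75.0.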

    # Result file creation, to be accessed by JS via 'app.py'
    with open(datasetPath + id + '_graded.txt', 'w') as results:
        results.write(REC + '\n' + REF + '\n' + str(score))
    return score
Example #20
def main():
    parser = ArgumentParser()
    parser.add_argument(
        "--asr_model",
        type=str,
        default="QuartzNet15x5Base-En",
        required=False,
        help="Pass: '******'",
    )
    parser.add_argument("--dataset",
                        type=str,
                        required=True,
                        help="path to evaluation data")
    parser.add_argument("--batch_size", type=int, default=4)
    parser.add_argument(
        "--normalize_text",
        default=True,
        type=bool,
        help="Normalize transcripts or not. Set to False for non-English.")
    parser.add_argument(
        "--sclite_fmt",
        default="trn",
        type=str,
        help="sclite output format. Only trn and ctm are supported")
    parser.add_argument("--out_dir",
                        type=str,
                        required=True,
                        help="Destination dir for output files")
    parser.add_argument("--sctk_dir",
                        type=str,
                        required=False,
                        default="",
                        help="Path to sctk root dir")
    parser.add_argument("--glm",
                        type=str,
                        required=False,
                        default="",
                        help="Path to glm file")
    parser.add_argument("--ref_stm",
                        type=str,
                        required=False,
                        default="",
                        help="Path to glm file")
    args = parser.parse_args()
    torch.set_grad_enabled(False)

    if not os.path.exists(args.out_dir):
        os.makedirs(args.out_dir)

    use_sctk = os.path.exists(args.sctk_dir)

    if args.asr_model.endswith('.nemo'):
        logging.info(f"Using local ASR model from {args.asr_model}")
        asr_model = EncDecCTCModel.restore_from(restore_path=args.asr_model)
    else:
        logging.info(f"Using NGC cloud ASR model {args.asr_model}")
        asr_model = EncDecCTCModel.from_pretrained(model_name=args.asr_model)
    asr_model.setup_test_data(
        test_data_config={
            'sample_rate': 16000,
            'manifest_filepath': args.dataset,
            'labels': asr_model.decoder.vocabulary,
            'batch_size': args.batch_size,
            'normalize_transcripts': args.normalize_text,
        })
    if can_gpu:
        asr_model = asr_model.cuda()
    asr_model.eval()
    labels_map = dict([(i, asr_model.decoder.vocabulary[i])
                       for i in range(len(asr_model.decoder.vocabulary))])

    wer = WER(vocabulary=asr_model.decoder.vocabulary)
    hypotheses = []
    references = []
    all_log_probs = []
    for test_batch in asr_model.test_dataloader():
        if can_gpu:
            test_batch = [x.cuda() for x in test_batch]
        with autocast():
            log_probs, encoded_len, greedy_predictions = asr_model(
                input_signal=test_batch[0], input_signal_length=test_batch[1])
        for r in log_probs.cpu().numpy():
            all_log_probs.append(r)
        hypotheses += wer.ctc_decoder_predictions_tensor(greedy_predictions)
        for batch_ind in range(greedy_predictions.shape[0]):
            reference = ''.join([
                labels_map[c]
                for c in test_batch[2][batch_ind].cpu().detach().numpy()
            ])
            references.append(reference)
        del test_batch

    info_list = get_utt_info(args.dataset)
    hypfile = os.path.join(args.out_dir, "hyp.trn")
    reffile = os.path.join(args.out_dir, "ref.trn")
    with open(hypfile, "w") as hyp_f, open(reffile, "w") as ref_f:
        for i in range(len(hypotheses)):
            utt_id = os.path.splitext(
                os.path.basename(info_list[i]['audio_filepath']))[0]
            # rfilter in sctk likes each transcript to have a space at the beginning
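            # so a written line looks like (hypothetical utterance id): " the cat sat on the mat (utt_0001)"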
            hyp_f.write(" " + hypotheses[i] + " (" + utt_id + ")" + "\n")
            ref_f.write(" " + references[i] + " (" + utt_id + ")" + "\n")

    if use_sctk:
        score_with_sctk(args.sctk_dir,
                        reffile,
                        hypfile,
                        args.out_dir,
                        glm=args.glm,
                        fmt="trn")
Example #21
def main():
    parser = ArgumentParser()
    parser.add_argument(
        "--asr_model",
        type=str,
        default="QuartzNet15x5Base-En",
        required=True,
        help="Pass: '******'",
    )
    parser.add_argument(
        "--asr_onnx",
        type=str,
        default="./QuartzNet15x5Base-En-max-32.onnx",
        help="Pass: '******'",
    )
    parser.add_argument("--dataset",
                        type=str,
                        required=True,
                        help="path to evaluation data")
    parser.add_argument("--batch_size", type=int, default=4)
    parser.add_argument(
        "--dont_normalize_text",
        default=True,  # normalization stays on unless the flag is passed
        action='store_false',
        help="Turn off transcript normalization. Recommended for non-English.",
    )
    parser.add_argument(
        "--use_cer",
        default=False,
        action='store_true',
        help="Use Character Error Rate as the evaluation metric")
    parser.add_argument('--qat',
                        action="store_true",
                        help="Use onnx file exported from QAT tools")
    args = parser.parse_args()
    torch.set_grad_enabled(False)

    if args.asr_model.endswith('.nemo'):
        logging.info(f"Using local ASR model from {args.asr_model}")
        asr_model_cfg = EncDecCTCModel.restore_from(
            restore_path=args.asr_model, return_config=True)
        with open_dict(asr_model_cfg):
            asr_model_cfg.encoder.quantize = True
        asr_model = EncDecCTCModel.restore_from(
            restore_path=args.asr_model, override_config_path=asr_model_cfg)

    else:
        logging.info(f"Using NGC cloud ASR model {args.asr_model}")
        asr_model_cfg = EncDecCTCModel.from_pretrained(
            model_name=args.asr_model, return_config=True)
        with open_dict(asr_model_cfg):
            asr_model_cfg.encoder.quantize = True
        asr_model = EncDecCTCModel.from_pretrained(
            model_name=args.asr_model, override_config_path=asr_model_cfg)
    asr_model.setup_test_data(
        test_data_config={
            'sample_rate': 16000,
            'manifest_filepath': args.dataset,
            'labels': asr_model.decoder.vocabulary,
            'batch_size': args.batch_size,
            'normalize_transcripts': args.dont_normalize_text,
        })
    asr_model.preprocessor.featurizer.dither = 0.0
    asr_model.preprocessor.featurizer.pad_to = 0
    if can_gpu:
        asr_model = asr_model.cuda()
    asr_model.eval()
    labels_map = dict([(i, asr_model.decoder.vocabulary[i])
                       for i in range(len(asr_model.decoder.vocabulary))])
    wer = WER(vocabulary=asr_model.decoder.vocabulary, use_cer=args.use_cer)
    wer_result = evaluate(asr_model, args.asr_onnx, labels_map, wer, args.qat)
    logging.info(f'Got WER of {wer_result}.')
Example #22
def main():
    parser = ArgumentParser()
    parser.add_argument(
        "--asr_model", type=str, default="QuartzNet15x5Base-En", required=True, help="Pass: '******'",
    )
    parser.add_argument("--dataset", type=str, required=True, help="path to evaluation data")
    parser.add_argument("--wer_target", type=float, default=None, help="used by test")
    parser.add_argument("--batch_size", type=int, default=4)
    parser.add_argument("--wer_tolerance", type=float, default=1.0, help="used by test")
    parser.add_argument(
        "--normalize_text", default=True, type=bool, help="Normalize transcripts or not. Set to False for non-English."
    )
    parser.add_argument('--sensitivity', action="store_true", help="Perform sensitivity analysis")
    parser.add_argument('--onnx', action="store_true", help="Export to ONNX")
    args = parser.parse_args()
    torch.set_grad_enabled(False)

    quant_modules.initialize()

    if args.asr_model.endswith('.nemo'):
        logging.info(f"Using local ASR model from {args.asr_model}")
        asr_model = EncDecCTCModel.restore_from(restore_path=args.asr_model)
    else:
        logging.info(f"Using NGC cloud ASR model {args.asr_model}")
        asr_model = EncDecCTCModel.from_pretrained(model_name=args.asr_model)
    asr_model.setup_test_data(
        test_data_config={
            'sample_rate': 16000,
            'manifest_filepath': args.dataset,
            'labels': asr_model.decoder.vocabulary,
            'batch_size': args.batch_size,
            'normalize_transcripts': args.normalize_text,
        }
    )
    if can_gpu:
        asr_model = asr_model.cuda()
    asr_model.eval()
    labels_map = dict([(i, asr_model.decoder.vocabulary[i]) for i in range(len(asr_model.decoder.vocabulary))])
    wer = WER(vocabulary=asr_model.decoder.vocabulary)
    wer_quant = evaluate(asr_model, labels_map, wer)
    logging.info(f'Got WER of {wer_quant}. Tolerance was {args.wer_tolerance}')

    if args.sensitivity:
        if wer_quant < args.wer_tolerance:
            logging.info("Tolerance is already met. Skip sensitivity analyasis.")
            return
        quant_layer_names = []
        for name, module in asr_model.named_modules():
            if isinstance(module, quant_nn.TensorQuantizer):
                module.disable()
                layer_name = name.replace("._input_quantizer", "").replace("._weight_quantizer", "")
                if layer_name not in quant_layer_names:
                    quant_layer_names.append(layer_name)
        logging.info(F"{len(quant_layer_names)} quantized layers found.")

        # Build sensitivity profile
        quant_layer_sensitivity = {}
        for i, quant_layer in enumerate(quant_layer_names):
            logging.info(F"Enable {quant_layer}")
            for name, module in asr_model.named_modules():
                if isinstance(module, quant_nn.TensorQuantizer) and quant_layer in name:
                    module.enable()
                    logging.info(F"{name:40}: {module}")

            # Eval the model
            wer_value = evaluate(asr_model, labels_map, wer)
            logging.info(F"WER: {wer_value}")
            quant_layer_sensitivity[quant_layer] = args.wer_tolerance - wer_value

            for name, module in asr_model.named_modules():
                if isinstance(module, quant_nn.TensorQuantizer) and quant_layer in name:
                    module.disable()
                    logging.info(F"{name:40}: {module}")

        # Skip most sensitive layers until WER target is met
        for name, module in asr_model.named_modules():
            if isinstance(module, quant_nn.TensorQuantizer):
                module.enable()
        quant_layer_sensitivity = collections.OrderedDict(sorted(quant_layer_sensitivity.items(), key=lambda x: x[1]))
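        # After the ascending sort, the layers whose quantization hurts WER the most
        # (smallest tolerance - WER margin) come first, so the loop below skips them first.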
        pprint(quant_layer_sensitivity)
        skipped_layers = []
        for quant_layer, _ in quant_layer_sensitivity.items():
            for name, module in asr_model.named_modules():
                if isinstance(module, quant_nn.TensorQuantizer):
                    if quant_layer in name:
                        logging.info(F"Disable {name}")
                        if quant_layer not in skipped_layers:
                            skipped_layers.append(quant_layer)
                        module.disable()
            wer_value = evaluate(asr_model, labels_map, wer)
            if wer_value <= args.wer_tolerance:
                logging.info(
                    F"WER tolerance {args.wer_tolerance} is met by skipping {len(skipped_layers)} sensitive layers."
                )
                print(skipped_layers)
                return
        raise ValueError(f"WER tolerance {args.wer_tolerance} can not be met with any layer quantized!")

    if args.onnx:
        if args.asr_model.endswith("nemo"):
            onnx_name = args.asr_model.replace(".nemo", ".onnx")
        else:
            onnx_name = args.asr_model
        logging.info("Export to ", onnx_name)
        quant_nn.TensorQuantizer.use_fb_fake_quant = True
        asr_model.export(onnx_name, onnx_opset_version=13)
        quant_nn.TensorQuantizer.use_fb_fake_quant = False
Example #23
class EncDecCTCModel(ASRModel, ExportableEncDecModel, ASRModuleMixin):
    """Base class for encoder decoder CTC-based models."""

    @classmethod
    def list_available_models(cls) -> List[PretrainedModelInfo]:
        """
        This method returns a list of pre-trained models which can be instantiated directly from NVIDIA's NGC cloud.

        Returns:
            List of available pre-trained models.
        """
        results = []

        model = PretrainedModelInfo(
            pretrained_model_name="QuartzNet15x5Base-En",
            description="QuartzNet15x5 model trained on six datasets: LibriSpeech, Mozilla Common Voice (validated clips from en_1488h_2019-12-10), WSJ, Fisher, Switchboard, and NSC Singapore English. It was trained with Apex/Amp optimization level O1 for 600 epochs. The model achieves a WER of 3.79% on LibriSpeech dev-clean, and a WER of 10.05% on dev-other. Please visit https://ngc.nvidia.com/catalog/models/nvidia:nemospeechmodels for further details.",
            location="https://api.ngc.nvidia.com/v2/models/nvidia/nemospeechmodels/versions/1.0.0a5/files/QuartzNet15x5Base-En.nemo",
        )
        results.append(model)

        model = PretrainedModelInfo(
            pretrained_model_name="stt_en_quartznet15x5",
            description="For details about this model, please visit https://ngc.nvidia.com/catalog/models/nvidia:nemo:stt_en_quartznet15x5",
            location="https://api.ngc.nvidia.com/v2/models/nvidia/nemo/stt_en_quartznet15x5/versions/1.0.0rc1/files/stt_en_quartznet15x5.nemo",
        )
        results.append(model)

        model = PretrainedModelInfo(
            pretrained_model_name="stt_en_jasper10x5dr",
            description="For details about this model, please visit https://ngc.nvidia.com/catalog/models/nvidia:nemo:stt_en_jasper10x5dr",
            location="https://api.ngc.nvidia.com/v2/models/nvidia/nemo/stt_en_jasper10x5dr/versions/1.0.0rc1/files/stt_en_jasper10x5dr.nemo",
        )
        results.append(model)

        model = PretrainedModelInfo(
            pretrained_model_name="stt_ca_quartznet15x5",
            description="For details about this model, please visit https://ngc.nvidia.com/catalog/models/nvidia:nemo:stt_ca_quartznet15x5",
            location="https://api.ngc.nvidia.com/v2/models/nvidia/nemo/stt_ca_quartznet15x5/versions/1.0.0rc1/files/stt_ca_quartznet15x5.nemo",
        )
        results.append(model)

        model = PretrainedModelInfo(
            pretrained_model_name="stt_it_quartznet15x5",
            description="For details about this model, please visit https://ngc.nvidia.com/catalog/models/nvidia:nemo:stt_it_quartznet15x5",
            location="https://api.ngc.nvidia.com/v2/models/nvidia/nemo/stt_it_quartznet15x5/versions/1.0.0rc1/files/stt_it_quartznet15x5.nemo",
        )
        results.append(model)

        model = PretrainedModelInfo(
            pretrained_model_name="stt_fr_quartznet15x5",
            description="For details about this model, please visit https://ngc.nvidia.com/catalog/models/nvidia:nemo:stt_fr_quartznet15x5",
            location="https://api.ngc.nvidia.com/v2/models/nvidia/nemo/stt_fr_quartznet15x5/versions/1.0.0rc1/files/stt_fr_quartznet15x5.nemo",
        )
        results.append(model)

        model = PretrainedModelInfo(
            pretrained_model_name="stt_es_quartznet15x5",
            description="For details about this model, please visit https://ngc.nvidia.com/catalog/models/nvidia:nemo:stt_es_quartznet15x5",
            location="https://api.ngc.nvidia.com/v2/models/nvidia/nemo/stt_es_quartznet15x5/versions/1.0.0rc1/files/stt_es_quartznet15x5.nemo",
        )
        results.append(model)

        model = PretrainedModelInfo(
            pretrained_model_name="stt_de_quartznet15x5",
            description="For details about this model, please visit https://ngc.nvidia.com/catalog/models/nvidia:nemo:stt_de_quartznet15x5",
            location="https://api.ngc.nvidia.com/v2/models/nvidia/nemo/stt_de_quartznet15x5/versions/1.0.0rc1/files/stt_de_quartznet15x5.nemo",
        )
        results.append(model)

        model = PretrainedModelInfo(
            pretrained_model_name="stt_pl_quartznet15x5",
            description="For details about this model, please visit https://ngc.nvidia.com/catalog/models/nvidia:nemo:stt_pl_quartznet15x5",
            location="https://api.ngc.nvidia.com/v2/models/nvidia/nemo/stt_pl_quartznet15x5/versions/1.0.0rc1/files/stt_pl_quartznet15x5.nemo",
        )
        results.append(model)

        model = PretrainedModelInfo(
            pretrained_model_name="stt_ru_quartznet15x5",
            description="For details about this model, please visit https://ngc.nvidia.com/catalog/models/nvidia:nemo:stt_ru_quartznet15x5",
            location="https://api.ngc.nvidia.com/v2/models/nvidia/nemo/stt_ru_quartznet15x5/versions/1.0.0rc1/files/stt_ru_quartznet15x5.nemo",
        )
        results.append(model)

        model = PretrainedModelInfo(
            pretrained_model_name="stt_zh_citrinet_512",
            description="For details about this model, please visit https://ngc.nvidia.com/catalog/models/nvidia:nemo:stt_zh_citrinet_512",
            location="https://api.ngc.nvidia.com/v2/models/nvidia/nemo/stt_zh_citrinet_512/versions/1.0.0rc1/files/stt_zh_citrinet_512.nemo",
        )
        results.append(model)

        model = PretrainedModelInfo(
            pretrained_model_name="stt_zh_citrinet_1024_gamma_0_25",
            description="For details about this model, please visit https://ngc.nvidia.com/catalog/models/nvidia:nemo:stt_zh_citrinet_1024_gamma_0_25",
            location="https://api.ngc.nvidia.com/v2/models/nvidia/nemo/stt_zh_citrinet_1024_gamma_0_25/versions/1.0.0/files/stt_zh_citrinet_1024_gamma_0_25.nemo",
        )

        results.append(model)

        model = PretrainedModelInfo(
            pretrained_model_name="asr_talknet_aligner",
            description="For details about this model, please visit https://ngc.nvidia.com/catalog/models/nvidia:nemo:asr_talknet_aligner",
            location="https://api.ngc.nvidia.com/v2/models/nvidia/nemo/asr_talknet_aligner/versions/1.0.0rc1/files/qn5x5_libri_tts_phonemes.nemo",
        )
        results.append(model)

        return results

    def __init__(self, cfg: DictConfig, trainer: Trainer = None):
        # Get global rank and total number of GPU workers for IterableDataset partitioning, if applicable
        # Global_rank and local_rank is set by LightningModule in Lightning 1.2.0
        self.world_size = 1
        if trainer is not None:
            self.world_size = trainer.world_size

        super().__init__(cfg=cfg, trainer=trainer)
        self.preprocessor = EncDecCTCModel.from_config_dict(self._cfg.preprocessor)
        self.encoder = EncDecCTCModel.from_config_dict(self._cfg.encoder)

        with open_dict(self._cfg):
            if "feat_in" not in self._cfg.decoder or (
                not self._cfg.decoder.feat_in and hasattr(self.encoder, '_feat_out')
            ):
                self._cfg.decoder.feat_in = self.encoder._feat_out
            if "feat_in" not in self._cfg.decoder or not self._cfg.decoder.feat_in:
                raise ValueError("param feat_in of the decoder's config is not set!")

            if self.cfg.decoder.num_classes < 1 and self.cfg.decoder.vocabulary is not None:
                logging.info(
                    "\nReplacing placeholder number of classes ({}) with actual number of classes - {}".format(
                        self.cfg.decoder.num_classes, len(self.cfg.decoder.vocabulary)
                    )
                )
                cfg.decoder["num_classes"] = len(self.cfg.decoder.vocabulary)

        self.decoder = EncDecCTCModel.from_config_dict(self._cfg.decoder)

        self.loss = CTCLoss(
            num_classes=self.decoder.num_classes_with_blank - 1,
            zero_infinity=True,
            reduction=self._cfg.get("ctc_reduction", "mean_batch"),
        )

        if hasattr(self._cfg, 'spec_augment') and self._cfg.spec_augment is not None:
            self.spec_augmentation = EncDecCTCModel.from_config_dict(self._cfg.spec_augment)
        else:
            self.spec_augmentation = None

        # Setup metric objects
        self._wer = WER(
            vocabulary=self.decoder.vocabulary,
            batch_dim_index=0,
            use_cer=self._cfg.get('use_cer', False),
            ctc_decode=True,
            dist_sync_on_step=True,
            log_prediction=self._cfg.get("log_prediction", False),
        )

    @torch.no_grad()
    def transcribe(
        self,
        paths2audio_files: List[str],
        batch_size: int = 4,
        logprobs: bool = False,
        return_hypotheses: bool = False,
        num_workers: int = 0,
    ) -> List[str]:
        """
        Uses greedy decoding to transcribe audio files. Use this method for debugging and prototyping.

        Args:
            paths2audio_files: (a list) of paths to audio files. \
                Recommended length per file is between 5 and 25 seconds. \
                But it is possible to pass a few hours long file if enough GPU memory is available.
            batch_size: (int) batch size to use during inference.
                Bigger will result in better throughput performance but would use more memory.
            logprobs: (bool) pass True to get log probabilities instead of transcripts.
            return_hypotheses: (bool) Either return hypotheses or text
                With hypotheses can do some postprocessing like getting timestamp or rescoring
            num_workers: (int) number of workers for DataLoader

        Returns:
            A list of transcriptions (or raw log probabilities if logprobs is True) in the same order as paths2audio_files
        """
        if paths2audio_files is None or len(paths2audio_files) == 0:
            return {}

        if return_hypotheses and logprobs:
            raise ValueError(
                "Either `return_hypotheses` or `logprobs` can be True at any given time."
                "Returned hypotheses will contain the logprobs."
            )

        if num_workers is None:
            num_workers = min(batch_size, os.cpu_count() - 1)

        # We will store transcriptions here
        hypotheses = []
        # Model's mode and device
        mode = self.training
        device = next(self.parameters()).device
        dither_value = self.preprocessor.featurizer.dither
        pad_to_value = self.preprocessor.featurizer.pad_to

        try:
            self.preprocessor.featurizer.dither = 0.0
            self.preprocessor.featurizer.pad_to = 0
            # Switch model to evaluation mode
            self.eval()
            # Freeze the encoder and decoder modules
            self.encoder.freeze()
            self.decoder.freeze()
            logging_level = logging.get_verbosity()
            logging.set_verbosity(logging.WARNING)
            # Work in tmp directory - will store manifest file there
            with tempfile.TemporaryDirectory() as tmpdir:
                with open(os.path.join(tmpdir, 'manifest.json'), 'w', encoding='utf-8') as fp:
                    for audio_file in paths2audio_files:
                        entry = {'audio_filepath': audio_file, 'duration': 100000, 'text': ''}
                        fp.write(json.dumps(entry) + '\n')

                config = {
                    'paths2audio_files': paths2audio_files,
                    'batch_size': batch_size,
                    'temp_dir': tmpdir,
                    'num_workers': num_workers,
                }

                temporary_datalayer = self._setup_transcribe_dataloader(config)
                for test_batch in tqdm(temporary_datalayer, desc="Transcribing"):
                    logits, logits_len, greedy_predictions = self.forward(
                        input_signal=test_batch[0].to(device), input_signal_length=test_batch[1].to(device)
                    )
                    if logprobs:
                        # dump log probs per file
                        for idx in range(logits.shape[0]):
                            lg = logits[idx][: logits_len[idx]]
                            hypotheses.append(lg.cpu().numpy())
                    else:
                        current_hypotheses = self._wer.ctc_decoder_predictions_tensor(
                            greedy_predictions, predictions_len=logits_len, return_hypotheses=return_hypotheses,
                        )

                        if return_hypotheses:
                            # attach the per-file log probs to each returned hypothesis
                            for idx in range(logits.shape[0]):
                                current_hypotheses[idx].y_sequence = logits[idx][: logits_len[idx]]

                        hypotheses += current_hypotheses

                    del greedy_predictions
                    del logits
                    del test_batch
        finally:
            # set mode back to its original value
            self.train(mode=mode)
            self.preprocessor.featurizer.dither = dither_value
            self.preprocessor.featurizer.pad_to = pad_to_value
            if mode is True:
                self.encoder.unfreeze()
                self.decoder.unfreeze()
            logging.set_verbosity(logging_level)
        return hypotheses
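
    # Usage sketch (hypothetical audio paths; assumes a model restored via from_pretrained or restore_from):
    #   model = EncDecCTCModel.from_pretrained(model_name="QuartzNet15x5Base-En")
    #   transcripts = model.transcribe(paths2audio_files=["utt1.wav", "utt2.wav"], batch_size=2)
    #   print(transcripts)  # list of decoded strings, one per input file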

    def change_vocabulary(self, new_vocabulary: List[str]):
        """
        Changes the vocabulary used during the CTC decoding process. Use this method when fine-tuning from a pre-trained model.
        This method changes only the decoder and leaves the encoder and pre-processing modules unchanged. For example, you would
        use it if you want to reuse a pretrained encoder when fine-tuning on data in another language, or when you need the
        model to learn capitalization, punctuation and/or special characters.

        If new_vocabulary == self.decoder.vocabulary then nothing will be changed.

        Args:

            new_vocabulary: list with new vocabulary. Must contain at least 2 elements. Typically, \
            this is target alphabet.

        Returns: None

        """
        if self.decoder.vocabulary == new_vocabulary:
            logging.warning(f"Old {self.decoder.vocabulary} and new {new_vocabulary} match. Not changing anything.")
        else:
            if new_vocabulary is None or len(new_vocabulary) == 0:
                raise ValueError(f'New vocabulary must be non-empty list of chars. But I got: {new_vocabulary}')
            decoder_config = self.decoder.to_config_dict()
            new_decoder_config = copy.deepcopy(decoder_config)
            new_decoder_config['vocabulary'] = new_vocabulary
            new_decoder_config['num_classes'] = len(new_vocabulary)

            del self.decoder
            self.decoder = EncDecCTCModel.from_config_dict(new_decoder_config)
            del self.loss
            self.loss = CTCLoss(
                num_classes=self.decoder.num_classes_with_blank - 1,
                zero_infinity=True,
                reduction=self._cfg.get("ctc_reduction", "mean_batch"),
            )
            self._wer = WER(
                vocabulary=self.decoder.vocabulary,
                batch_dim_index=0,
                use_cer=self._cfg.get('use_cer', False),
                ctc_decode=True,
                dist_sync_on_step=True,
                log_prediction=self._cfg.get("log_prediction", False),
            )

            # Update config
            OmegaConf.set_struct(self._cfg.decoder, False)
            self._cfg.decoder = new_decoder_config
            OmegaConf.set_struct(self._cfg.decoder, True)

            ds_keys = ['train_ds', 'validation_ds', 'test_ds']
            for key in ds_keys:
                if key in self.cfg:
                    with open_dict(self.cfg[key]):
                        self.cfg[key]['labels'] = OmegaConf.create(new_vocabulary)

            logging.info(f"Changed decoder to output to {self.decoder.vocabulary} vocabulary.")

    def _setup_dataloader_from_config(self, config: Optional[Dict]):
        if 'augmentor' in config:
            augmentor = process_augmentations(config['augmentor'])
        else:
            augmentor = None

        # Automatically inject args from model config to dataloader config
        audio_to_text_dataset.inject_dataloader_value_from_model_config(self.cfg, config, key='sample_rate')
        audio_to_text_dataset.inject_dataloader_value_from_model_config(self.cfg, config, key='labels')

        shuffle = config['shuffle']
        device = 'gpu' if torch.cuda.is_available() else 'cpu'
        if config.get('use_dali', False):
            device_id = self.local_rank if device == 'gpu' else None
            dataset = audio_to_text_dataset.get_dali_char_dataset(
                config=config,
                shuffle=shuffle,
                device_id=device_id,
                global_rank=self.global_rank,
                world_size=self.world_size,
                preprocessor_cfg=self._cfg.preprocessor,
            )
            return dataset

        # Instantiate tarred dataset loader or normal dataset loader
        if config.get('is_tarred', False):
            if ('tarred_audio_filepaths' in config and config['tarred_audio_filepaths'] is None) or (
                'manifest_filepath' in config and config['manifest_filepath'] is None
            ):
                logging.warning(
                    "Could not load dataset as `manifest_filepath` was None or "
                    f"`tarred_audio_filepaths` is None. Provided config : {config}"
                )
                return None

            shuffle_n = config.get('shuffle_n', 4 * config['batch_size']) if shuffle else 0
            dataset = audio_to_text_dataset.get_tarred_dataset(
                config=config,
                shuffle_n=shuffle_n,
                global_rank=self.global_rank,
                world_size=self.world_size,
                augmentor=augmentor,
            )
            shuffle = False
        else:
            if 'manifest_filepath' in config and config['manifest_filepath'] is None:
                logging.warning(f"Could not load dataset as `manifest_filepath` was None. Provided config : {config}")
                return None

            dataset = audio_to_text_dataset.get_char_dataset(config=config, augmentor=augmentor)

        if hasattr(dataset, 'collate_fn'):
            collate_fn = dataset.collate_fn
        else:
            collate_fn = dataset.datasets[0].collate_fn

        return torch.utils.data.DataLoader(
            dataset=dataset,
            batch_size=config['batch_size'],
            collate_fn=collate_fn,
            drop_last=config.get('drop_last', False),
            shuffle=shuffle,
            num_workers=config.get('num_workers', 0),
            pin_memory=config.get('pin_memory', False),
        )

    def setup_training_data(self, train_data_config: Optional[Union[DictConfig, Dict]]):
        """
        Sets up the training data loader via a Dict-like object.

        Args:
            train_data_config: A config that contains the information regarding construction
                of an ASR Training dataset.

        Supported Datasets:
            -   :class:`~nemo.collections.asr.data.audio_to_text.AudioToCharDataset`
            -   :class:`~nemo.collections.asr.data.audio_to_text.AudioToBPEDataset`
            -   :class:`~nemo.collections.asr.data.audio_to_text.TarredAudioToCharDataset`
            -   :class:`~nemo.collections.asr.data.audio_to_text.TarredAudioToBPEDataset`
            -   :class:`~nemo.collections.asr.data.audio_to_text_dali.AudioToCharDALIDataset`
        """
        if 'shuffle' not in train_data_config:
            train_data_config['shuffle'] = True

        # preserve config
        self._update_dataset_config(dataset_name='train', config=train_data_config)

        self._train_dl = self._setup_dataloader_from_config(config=train_data_config)

        # Need to set this because if using an IterableDataset, the length of the dataloader is the total number
        # of samples rather than the number of batches, and this messes up the tqdm progress bar.
        # So we set the number of steps manually (to the correct number) to fix this.
        if 'is_tarred' in train_data_config and train_data_config['is_tarred']:
            # We also need to check if limit_train_batches is already set.
            # If it's an int, we assume that the user has set it to something sane, i.e. <= # training batches,
            # and don't change it. Otherwise, adjust batches accordingly if it's a float (including 1.0).
            if self._trainer is not None and isinstance(self._trainer.limit_train_batches, float):
                self._trainer.limit_train_batches = int(
                    self._trainer.limit_train_batches
                    * ceil((len(self._train_dl.dataset) / self.world_size) / train_data_config['batch_size'])
                )
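                # Numeric sketch (hypothetical sizes): 100_000 tarred samples, world_size 4,
                # batch_size 32 and limit_train_batches 1.0 give int(1.0 * ceil(25_000 / 32)) = 782 steps.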
            elif self._trainer is None:
                logging.warning(
                    "Model Trainer was not set before constructing the dataset, incorrect number of "
                    "training batches will be used. Please set the trainer and rebuild the dataset."
                )

    def setup_validation_data(self, val_data_config: Optional[Union[DictConfig, Dict]]):
        """
        Sets up the validation data loader via a Dict-like object.

        Args:
            val_data_config: A config that contains the information regarding construction
                of an ASR validation dataset.

        Supported Datasets:
            -   :class:`~nemo.collections.asr.data.audio_to_text.AudioToCharDataset`
            -   :class:`~nemo.collections.asr.data.audio_to_text.AudioToBPEDataset`
            -   :class:`~nemo.collections.asr.data.audio_to_text.TarredAudioToCharDataset`
            -   :class:`~nemo.collections.asr.data.audio_to_text.TarredAudioToBPEDataset`
            -   :class:`~nemo.collections.asr.data.audio_to_text_dali.AudioToCharDALIDataset`
        """
        if 'shuffle' not in val_data_config:
            val_data_config['shuffle'] = False

        # preserve config
        self._update_dataset_config(dataset_name='validation', config=val_data_config)

        self._validation_dl = self._setup_dataloader_from_config(config=val_data_config)

    def setup_test_data(self, test_data_config: Optional[Union[DictConfig, Dict]]):
        """
        Sets up the test data loader via a Dict-like object.

        Args:
            test_data_config: A config that contains the information regarding construction
                of an ASR test dataset.

        Supported Datasets:
            -   :class:`~nemo.collections.asr.data.audio_to_text.AudioToCharDataset`
            -   :class:`~nemo.collections.asr.data.audio_to_text.AudioToBPEDataset`
            -   :class:`~nemo.collections.asr.data.audio_to_text.TarredAudioToCharDataset`
            -   :class:`~nemo.collections.asr.data.audio_to_text.TarredAudioToBPEDataset`
            -   :class:`~nemo.collections.asr.data.audio_to_text_dali.AudioToCharDALIDataset`
        """
        if 'shuffle' not in test_data_config:
            test_data_config['shuffle'] = False

        # preserve config
        self._update_dataset_config(dataset_name='test', config=test_data_config)

        self._test_dl = self._setup_dataloader_from_config(config=test_data_config)

    @property
    def input_types(self) -> Optional[Dict[str, NeuralType]]:
        if hasattr(self.preprocessor, '_sample_rate'):
            input_signal_eltype = AudioSignal(freq=self.preprocessor._sample_rate)
        else:
            input_signal_eltype = AudioSignal()
        return {
            "input_signal": NeuralType(('B', 'T'), input_signal_eltype, optional=True),
            "input_signal_length": NeuralType(tuple('B'), LengthsType(), optional=True),
            "processed_signal": NeuralType(('B', 'D', 'T'), SpectrogramType(), optional=True),
            "processed_signal_length": NeuralType(tuple('B'), LengthsType(), optional=True),
            "sample_id": NeuralType(tuple('B'), LengthsType(), optional=True),
        }

    @property
    def output_types(self) -> Optional[Dict[str, NeuralType]]:
        return {
            "outputs": NeuralType(('B', 'T', 'D'), LogprobsType()),
            "encoded_lengths": NeuralType(tuple('B'), LengthsType()),
            "greedy_predictions": NeuralType(('B', 'T'), LabelsType()),
        }

    @typecheck()
    def forward(
        self, input_signal=None, input_signal_length=None, processed_signal=None, processed_signal_length=None
    ):
        """
        Forward pass of the model.

        Args:
            input_signal: Tensor that represents a batch of raw audio signals,
                of shape [B, T]. T here represents timesteps, with 1 second of audio represented as
                `self.sample_rate` number of floating point values.
            input_signal_length: Vector of length B, that contains the individual lengths of the audio
                sequences.
            processed_signal: Tensor that represents a batch of processed audio signals,
                of shape (B, D, T) that has undergone processing via some DALI preprocessor.
            processed_signal_length: Vector of length B, that contains the individual lengths of the
                processed audio sequences.

        Returns:
            A tuple of 3 elements -
            1) The log probabilities tensor of shape [B, T, D].
            2) The lengths of the acoustic sequence after propagation through the encoder, of shape [B].
            3) The greedy token predictions of the model of shape [B, T] (via argmax)
        """
        has_input_signal = input_signal is not None and input_signal_length is not None
        has_processed_signal = processed_signal is not None and processed_signal_length is not None
        if not (has_input_signal ^ has_processed_signal):
            raise ValueError(
                f"{self} Arguments ``input_signal`` and ``input_signal_length`` are mutually exclusive "
                " with ``processed_signal`` and ``processed_signal_len`` arguments."
            )

        if not has_processed_signal:
            processed_signal, processed_signal_length = self.preprocessor(
                input_signal=input_signal, length=input_signal_length,
            )

        if self.spec_augmentation is not None and self.training:
            processed_signal = self.spec_augmentation(input_spec=processed_signal, length=processed_signal_length)

        encoded, encoded_len = self.encoder(audio_signal=processed_signal, length=processed_signal_length)
        log_probs = self.decoder(encoder_output=encoded)
        greedy_predictions = log_probs.argmax(dim=-1, keepdim=False)

        return log_probs, encoded_len, greedy_predictions
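
    # Shape sketch (hypothetical 16 kHz input): input_signal is [B, T] with T = 16000 * seconds of audio,
    # input_signal_length is [B]; forward returns log_probs [B, T_enc, num_classes + 1],
    # encoded_len [B] and greedy_predictions [B, T_enc], where T_enc depends on the encoder stride.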

    # PTL-specific methods
    def training_step(self, batch, batch_nb):
        signal, signal_len, transcript, transcript_len = batch
        if isinstance(batch, DALIOutputs) and batch.has_processed_signal:
            log_probs, encoded_len, predictions = self.forward(
                processed_signal=signal, processed_signal_length=signal_len
            )
        else:
            log_probs, encoded_len, predictions = self.forward(input_signal=signal, input_signal_length=signal_len)

        loss_value = self.loss(
            log_probs=log_probs, targets=transcript, input_lengths=encoded_len, target_lengths=transcript_len
        )

        tensorboard_logs = {'train_loss': loss_value, 'learning_rate': self._optimizer.param_groups[0]['lr']}

        if hasattr(self, '_trainer') and self._trainer is not None:
            log_every_n_steps = self._trainer.log_every_n_steps
        else:
            log_every_n_steps = 1

        if (batch_nb + 1) % log_every_n_steps == 0:
            self._wer.update(
                predictions=predictions,
                targets=transcript,
                target_lengths=transcript_len,
                predictions_lengths=encoded_len,
            )
            wer, _, _ = self._wer.compute()
            self._wer.reset()
            tensorboard_logs.update({'training_batch_wer': wer})

        return {'loss': loss_value, 'log': tensorboard_logs}

    def predict_step(self, batch, batch_idx, dataloader_idx=0):
        signal, signal_len, transcript, transcript_len, sample_id = batch
        if isinstance(batch, DALIOutputs) and batch.has_processed_signal:
            log_probs, encoded_len, predictions = self.forward(
                processed_signal=signal, processed_signal_length=signal_len
            )
        else:
            log_probs, encoded_len, predictions = self.forward(input_signal=signal, input_signal_length=signal_len)

        transcribed_texts = self._wer.ctc_decoder_predictions_tensor(
            predictions=predictions, predictions_len=encoded_len, return_hypotheses=False,
        )

        sample_id = sample_id.cpu().detach().numpy()
        return list(zip(sample_id, transcribed_texts))

    def validation_step(self, batch, batch_idx, dataloader_idx=0):
        signal, signal_len, transcript, transcript_len = batch
        if isinstance(batch, DALIOutputs) and batch.has_processed_signal:
            log_probs, encoded_len, predictions = self.forward(
                processed_signal=signal, processed_signal_length=signal_len
            )
        else:
            log_probs, encoded_len, predictions = self.forward(input_signal=signal, input_signal_length=signal_len)

        loss_value = self.loss(
            log_probs=log_probs, targets=transcript, input_lengths=encoded_len, target_lengths=transcript_len
        )
        self._wer.update(
            predictions=predictions, targets=transcript, target_lengths=transcript_len, predictions_lengths=encoded_len
        )
        wer, wer_num, wer_denom = self._wer.compute()
        self._wer.reset()
        return {
            'val_loss': loss_value,
            'val_wer_num': wer_num,
            'val_wer_denom': wer_denom,
            'val_wer': wer,
        }

    def test_step(self, batch, batch_idx, dataloader_idx=0):
        logs = self.validation_step(batch, batch_idx, dataloader_idx=dataloader_idx)
        test_logs = {
            'test_loss': logs['val_loss'],
            'test_wer_num': logs['val_wer_num'],
            'test_wer_denom': logs['val_wer_denom'],
            'test_wer': logs['val_wer'],
        }
        return test_logs

    def test_dataloader(self):
        if self._test_dl is not None:
            return self._test_dl

    def _setup_transcribe_dataloader(self, config: Dict) -> 'torch.utils.data.DataLoader':
        """
        Setup function for a temporary data loader which wraps the provided audio file.

        Args:
            config: A python dictionary which contains the following keys:
            paths2audio_files: (a list) of paths to audio files. The files should be relatively short fragments. \
                Recommended length per file is between 5 and 25 seconds.
            batch_size: (int) batch size to use during inference. \
                Bigger will result in better throughput performance but would use more memory.
            temp_dir: (str) A temporary directory where the audio manifest is temporarily
                stored.
            num_workers: (int) number of workers. Depends of the batch_size and machine. \
                0 - only the main process will load batches, 1 - one worker (not main process)

        Returns:
            A pytorch DataLoader for the given audio file(s).
        """
        if 'manifest_filepath' in config:
            manifest_filepath = config['manifest_filepath']
            batch_size = config['batch_size']
        else:
            manifest_filepath = os.path.join(config['temp_dir'], 'manifest.json')
            batch_size = min(config['batch_size'], len(config['paths2audio_files']))

        dl_config = {
            'manifest_filepath': manifest_filepath,
            'sample_rate': self.preprocessor._sample_rate,
            'labels': self.decoder.vocabulary,
            'batch_size': batch_size,
            'trim_silence': False,
            'shuffle': False,
            'num_workers': config.get('num_workers', min(batch_size, os.cpu_count() - 1)),
            'pin_memory': True,
        }

        temporary_datalayer = self._setup_dataloader_from_config(config=DictConfig(dl_config))
        return temporary_datalayer
Example #24
def main():
    parser = ArgumentParser()
    """Training arguments"""
    parser.add_argument("--asr_model",
                        type=str,
                        default="QuartzNet15x5Base-En",
                        required=True,
                        help="Pass: '******'")
    parser.add_argument("--dataset",
                        type=str,
                        required=True,
                        help="path to evaluation data")
    parser.add_argument("--batch_size", type=int, default=8)
    parser.add_argument(
        "--normalize_text",
        default=True,
        type=bool,
        help="Normalize transcripts or not. Set to False for non-English.")
    parser.add_argument("--shuffle",
                        action='store_true',
                        help="Shuffle test data.")
    """Calibration arguments"""
    parser.add_argument("--load",
                        type=str,
                        default=None,
                        help="load path for the synthetic data")
    parser.add_argument(
        "--percentile",
        type=float,
        default=None,
        help="Max/min percentile for outlier handling. e.g., 99.9")
    """Quantization arguments"""
    parser.add_argument("--weight_bit",
                        type=int,
                        default=8,
                        help="quantization bit for weights")
    parser.add_argument("--act_bit",
                        type=int,
                        default=8,
                        help="quantization bit for activations")
    parser.add_argument("--dynamic",
                        action='store_true',
                        help="Dynamic quantization mode.")
    parser.add_argument("--no_quant",
                        action='store_true',
                        help="No quantization mode.")
    """Debugging arguments"""
    parser.add_argument("--eval_early_stop",
                        type=int,
                        default=None,
                        help="early stop for debugging")
    parser.add_argument("--calib_early_stop",
                        type=int,
                        default=None,
                        help="early stop calibration")

    args = parser.parse_args()

    torch.set_grad_enabled(False)

    if args.asr_model.endswith('.nemo'):
        logging.info(f"Using local ASR model from {args.asr_model}")
        asr_model = EncDecCTCModel.restore_from(restore_path=args.asr_model)
    else:
        logging.info(f"Using NGC cloud ASR model {args.asr_model}")
        asr_model = EncDecCTCModel.from_pretrained(model_name=args.asr_model)

    asr_model.setup_test_data(
        test_data_config={
            'sample_rate': 16000,
            'manifest_filepath': args.dataset,
            'labels': asr_model.decoder.vocabulary,
            'batch_size': args.batch_size,
            'normalize_transcripts': args.normalize_text,
            'shuffle': args.shuffle,
        })

    if args.load is not None:
        print('Data loaded from %s' % args.load)
        with open(args.load, 'rb') as f:
            distilled_data = pickle.load(f)
        synthetic_batch_size, _, synthetic_seqlen = distilled_data[0].shape
    else:
        assert args.dynamic, \
                "synthetic data must be loaded unless running with the dynamic quantization mode"

    ############################## Calibration #####################################

    torch.set_grad_enabled(False)  # disable backward graph generation
    asr_model.eval()  # evaluation mode
    asr_model.set_quant_bit(args.weight_bit, mode='weight')
    asr_model.set_quant_bit(args.act_bit, mode='act')

    # set percentile
    if args.percentile is not None:
        qm.set_percentile(asr_model, args.percentile)

    if args.no_quant:
        asr_model.set_quant_mode('none')
    else:
        asr_model.encoder.bn_folding()  # BN folding

    # if not dynamic quantization, calibrate min/max/range for the activations using synthetic data
    # if dynamic, we can skip calibration
    if not args.dynamic:
        print('Calibrating...')
        qm.calibrate(asr_model)
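        # Assumption: qm.calibrate switches the quantizers into range-collection mode, so the
        # forward passes over the distilled batches below record activation min/max statistics.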
        length = torch.tensor([synthetic_seqlen] * synthetic_batch_size).cuda()
        for batch_idx, inputs in enumerate(distilled_data):
            if args.calib_early_stop is not None and batch_idx == args.calib_early_stop:
                break
            inputs = inputs.cuda()
            encoded, encoded_len, encoded_scaling_factor = asr_model.encoder(
                audio_signal=inputs, length=length)
            log_probs = asr_model.decoder(
                encoder_output=encoded,
                encoder_output_scaling_factor=encoded_scaling_factor)

    ############################## Evaluation  #####################################

    print('Evaluating...')
    qm.evaluate(asr_model)

    qm.set_dynamic(
        asr_model,
        args.dynamic)  # if dynamic quantization, this will be enabled
    labels_map = dict([(i, asr_model.decoder.vocabulary[i])
                       for i in range(len(asr_model.decoder.vocabulary))])
    wer = WER(vocabulary=asr_model.decoder.vocabulary)
    hypotheses = []
    references = []
    progress_bar = tqdm(asr_model.test_dataloader())

    for i, test_batch in enumerate(progress_bar):
        if i == args.eval_early_stop:
            break
        test_batch = [x.cuda().float() for x in test_batch]
        with autocast():
            log_probs, encoded_len, greedy_predictions = asr_model(
                input_signal=test_batch[0], input_signal_length=test_batch[1])
        hypotheses += wer.ctc_decoder_predictions_tensor(greedy_predictions)
        for batch_ind in range(greedy_predictions.shape[0]):
            reference = ''.join([
                labels_map[c]
                for c in test_batch[2][batch_ind].cpu().detach().numpy()
            ])
            references.append(reference)
        del test_batch
    wer_value = word_error_rate(hypotheses=hypotheses, references=references)
    print('WER:', wer_value)