Esempio n. 1
0
 def test_EncDecCTCModel(self):
     """Run the restore-elsewhere check on the pretrained QuartzNet15x5Base-En CTC model.

     The decoder attributes listed below are handed to the check as the
     attributes to compare for equality.
     """
     # TODO: Switch to using named configs because here we don't really care about weights
     compared_attrs = {"decoder._feat_in", "decoder._num_classes"}
     pretrained = EncDecCTCModel.from_pretrained(model_name="QuartzNet15x5Base-En")
     self.__test_restore_elsewhere(model=pretrained, attr_for_eq_check=compared_attrs)
Esempio n. 2
0
    def set_asr_model(self, asr_model):
        """
        Load a pretrained ASR model and configure this object for it.

        The model family is picked by substring match on the model name,
        tested in this order:
            'QuartzNet'                              -> EncDecCTCModel
            'conformer_ctc'                          -> EncDecCTCModelBPE
            'citrinet'                               -> EncDecCTCModelBPE
            'conformer_transducer' or 'contextnet'   -> EncDecRNNTBPEModel

        Args:
            asr_model: name of the pretrained model to load.

        Returns:
            The loaded model, after ``eval()`` has been called on it.

        Raises:
            ValueError: if the name matches none of the families above.
        """
        model_name = asr_model

        # Step 1: pick the runner, the model class and the timing constants.
        if 'QuartzNet' in model_name:
            self.run_ASR = self.run_ASR_QuartzNet_CTC
            model_cls = EncDecCTCModel
            stride, offset = 0.02, -0.18
            delay = -1 * offset
            buffered = False
        elif 'conformer_ctc' in model_name:
            self.run_ASR = self.run_ASR_BPE_CTC
            model_cls = EncDecCTCModelBPE
            stride, offset, delay = 0.04, 0, 0.0
            buffered = True
        elif 'citrinet' in model_name:
            self.run_ASR = self.run_ASR_BPE_CTC
            model_cls = EncDecCTCModelBPE
            stride, offset, delay = 0.08, 0, 0.0
            buffered = True
        elif 'conformer_transducer' in model_name or 'contextnet' in model_name:
            self.run_ASR = self.run_ASR_BPE_RNNT
            model_cls = EncDecRNNTBPEModel
            stride, offset, delay = 0.04, 0, 0.0
            buffered = True
        else:
            raise ValueError(f"ASR model name not found: {asr_model}")

        # Step 2: load the checkpoint, then record the per-family settings.
        loaded_model = model_cls.from_pretrained(model_name=model_name, strict=False)
        self.params['offset'] = offset
        self.model_stride_in_secs = stride
        self.asr_delay_sec = delay
        if buffered:
            # Only the non-QuartzNet families set the chunk/buffer lengths.
            self.chunk_len_in_sec = 1.6
            self.total_buffer_in_secs = 4

        # Step 3: settings shared by every family.
        self.params['time_stride'] = self.model_stride_in_secs
        self.asr_batch_size = 16
        loaded_model.eval()

        self.audio_file_list = [
            entry['audio_filepath'] for entry in self.AUDIO_RTTM_MAP.values()
        ]

        return loaded_model
Esempio n. 3
0
def main():
    """
    Evaluate a CTC ASR model on a manifest and check its WER against a tolerance.

    Command-line flags:
        --asr_model: path to a local ``.nemo`` file or a pretrained model name.
        --dataset: manifest passed as ``manifest_filepath`` to the test data setup.
        --batch_size: test batch size.
        --wer_tolerance: maximum accepted WER.
        --normalize_text: boolean forwarded as ``normalize_transcripts``.

    Raises:
        ValueError: if the measured WER is higher than ``--wer_tolerance``.
    """

    def _parse_bool(text):
        """Turn a command-line token into a bool.

        ``type=bool`` would map every non-empty string, including 'False',
        to True, so the documented "Set to False" could never take effect.
        """
        return str(text).strip().lower() not in ("", "false", "f", "no", "n", "0")

    parser = ArgumentParser()
    # NOTE: the default below is never used because the flag is required.
    parser.add_argument(
        "--asr_model", type=str, default="QuartzNet15x5Base-En", required=True, help="Pass: '******'",
    )
    parser.add_argument("--dataset", type=str, required=True, help="path to evaluation data")
    parser.add_argument("--batch_size", type=int, default=4)
    parser.add_argument("--wer_tolerance", type=float, default=1.0, help="used by test")
    parser.add_argument(
        "--normalize_text",
        default=True,
        type=_parse_bool,
        help="Normalize transcripts or not. Set to False for non-English.",
    )
    args = parser.parse_args()
    torch.set_grad_enabled(False)

    if args.asr_model.endswith('.nemo'):
        logging.info(f"Using local ASR model from {args.asr_model}")
        asr_model = EncDecCTCModel.restore_from(restore_path=args.asr_model)
    else:
        logging.info(f"Using NGC cloud ASR model {args.asr_model}")
        asr_model = EncDecCTCModel.from_pretrained(model_name=args.asr_model)
    asr_model.setup_test_data(
        test_data_config={
            'sample_rate': 16000,
            'manifest_filepath': args.dataset,
            'labels': asr_model.decoder.vocabulary,
            'batch_size': args.batch_size,
            'normalize_transcripts': args.normalize_text,
        }
    )
    if can_gpu:
        asr_model = asr_model.cuda()
    asr_model.eval()

    vocabulary = asr_model.decoder.vocabulary
    # Maps a label id to its vocabulary entry, used to rebuild reference text.
    labels_map = dict(enumerate(vocabulary))
    wer = WER(vocabulary=vocabulary)
    hypotheses = []
    references = []
    for test_batch in asr_model.test_dataloader():
        if can_gpu:
            test_batch = [x.cuda() for x in test_batch]
        with autocast():
            _, _, greedy_predictions = asr_model(
                input_signal=test_batch[0], input_signal_length=test_batch[1]
            )
        hypotheses += wer.ctc_decoder_predictions_tensor(greedy_predictions)
        targets = test_batch[2]
        references.extend(
            ''.join(labels_map[c] for c in targets[batch_ind].cpu().detach().numpy())
            for batch_ind in range(greedy_predictions.shape[0])
        )
        del test_batch
    wer_value = word_error_rate(hypotheses=hypotheses, references=references)
    if wer_value > args.wer_tolerance:
        raise ValueError(f"Got WER of {wer_value}. It was higher than {args.wer_tolerance}")
    logging.info(f'Got WER of {wer_value}. Tolerance was {args.wer_tolerance}')
Esempio n. 4
0
    def set_asr_model(self, ASR_model_name):
        """
        Load the pretrained ASR model matching ``ASR_model_name``.

        Only names containing 'QuartzNet' are implemented. Names containing
        'conformer' or 'citrinet' are recognised but rejected before any
        checkpoint is fetched and before ``self.run_ASR`` is touched.

        Args:
            ASR_model_name: pretrained model name; the family is detected by
                substring match.

        Returns:
            The model returned by ``EncDecCTCModel.from_pretrained``.

        Raises:
            NotImplementedError: for 'conformer' and 'citrinet' names.
            ValueError: for any other unrecognised name.
        """
        if 'QuartzNet' in ASR_model_name:
            self.run_ASR = self.run_ASR_QuartzNet_CTC
            return EncDecCTCModel.from_pretrained(
                model_name=ASR_model_name, strict=False)

        if 'conformer' in ASR_model_name or 'citrinet' in ASR_model_name:
            # These families are not implemented yet; fail fast instead of
            # loading a checkpoint that would be discarded immediately.
            raise NotImplementedError(
                f"ASR model is not supported yet: {ASR_model_name}")

        # Report the name that actually failed to match, not a params lookup
        # that could itself raise KeyError and hide this error.
        raise ValueError(f"ASR model name not found: {ASR_model_name}")
Esempio n. 5
0
def batch_inference(args: argparse.Namespace):
    """
    Run batch inference with a CTC ASR model and store references, hypotheses and WER.

    Reads ``asr_model``, ``corpora_dir``, ``limit``, ``batch_size``,
    ``normalize_text``, ``search``, ``arpa`` and ``results_dir`` from ``args``.
    Writes ``refs.txt.gz``, ``hyps.txt.gz`` and ``stats.txt`` into
    ``args.results_dir``.

    Returns:
        dict with the WER under "wer" and the argument dict under "args".
    """
    torch.set_grad_enabled(False)

    model_source = args.asr_model
    if not model_source.endswith(".nemo"):
        print(f"Using NGC cloud ASR model {args.asr_model}")
        asr_model = EncDecCTCModel.from_pretrained(model_name=model_source)
    else:
        print(f"Using local ASR model from {args.asr_model}")
        asr_model = EncDecCTCModel.restore_from(restore_path=model_source)

    test_config = {
        "sample_rate": 16000,
        "manifest_filepath": prepare_manifest(args.corpora_dir, args.limit),
        "labels": asr_model.decoder.vocabulary,
        "batch_size": args.batch_size,
        "normalize_transcripts": args.normalize_text,
    }
    asr_model.setup_test_data(test_data_config=test_config)

    # Each generated item is a (reference, hypothesis) pair; split them into two lists.
    pairs = list(tqdm(generate_ref_hyps(asr_model, args.search, args.arpa)))
    references, hypotheses = map(list, zip(*pairs))

    os.makedirs(args.results_dir, exist_ok=True)
    data_io.write_lines(f"{args.results_dir}/refs.txt.gz", references)
    data_io.write_lines(f"{args.results_dir}/hyps.txt.gz", hypotheses)

    wer_value = word_error_rate(hypotheses=hypotheses, references=references)
    sys.stdout.flush()
    stats = {"wer": wer_value, "args": args.__dict__}
    data_io.write_json(f"{args.results_dir}/stats.txt", stats)
    print(f"Got WER of {wer_value}")
    return stats
Esempio n. 6
0
def main():
    """
    Round-trip test: synthesize the test strings with TTS, transcribe them with ASR, check WER.

    Command-line flags:
        --asr_model / --tts_model_spec / --tts_model_vocoder: pretrained model
            names, restricted to what each class lists as available.
        --wer_tolerance: maximum accepted WER.
        --trim: trim the synthesized audio before writing it.
        --debug: switch logging to DEBUG verbosity.

    Side effects:
        Writes ``0.wav``, ``1.wav``, ... into the current working directory.

    Raises:
        ValueError: if the measured WER is higher than ``--wer_tolerance``.
    """
    parser = ArgumentParser()
    parser.add_argument(
        "--asr_model",
        type=str,
        default="QuartzNet15x5Base-En",
        choices=[
            x.pretrained_model_name
            for x in EncDecCTCModel.list_available_models()
        ],
    )
    parser.add_argument(
        "--tts_model_spec",
        type=str,
        default="Tacotron2-22050Hz",
        choices=[
            x.pretrained_model_name
            for x in SpectrogramGenerator.list_available_models()
        ],
    )
    parser.add_argument(
        "--tts_model_vocoder",
        type=str,
        default="WaveGlow-22050Hz",
        choices=[
            x.pretrained_model_name for x in Vocoder.list_available_models()
        ],
    )
    parser.add_argument("--wer_tolerance",
                        type=float,
                        default=1.0,
                        help="used by test")
    parser.add_argument("--trim", action="store_true")
    parser.add_argument("--debug", action="store_true")
    args = parser.parse_args()
    torch.set_grad_enabled(False)

    if args.debug:
        logging.set_verbosity(logging.DEBUG)

    logging.info(f"Using NGC cloud ASR model {args.asr_model}")
    asr_model = EncDecCTCModel.from_pretrained(model_name=args.asr_model)
    logging.info(
        f"Using NGC cloud TTS Spectrogram Generator model {args.tts_model_spec}"
    )
    tts_model_spec = SpectrogramGenerator.from_pretrained(
        model_name=args.tts_model_spec)
    logging.info(f"Using NGC cloud TTS Vocoder model {args.tts_model_vocoder}")
    tts_model_vocoder = Vocoder.from_pretrained(
        model_name=args.tts_model_vocoder)

    use_cuda = torch.cuda.is_available()
    models = [asr_model, tts_model_spec, tts_model_vocoder]
    if use_cuda:
        models = [m.cuda() for m in models]
    for m in models:
        m.eval()
    asr_model, tts_model_spec, tts_model_vocoder = models

    vocabulary = asr_model.decoder.vocabulary
    # Separate name so the argparse parser above is not shadowed.
    text_parser = parsers.make_parser(
        labels=vocabulary,
        name="en",
        unk_id=-1,
        blank_id=-1,
        do_normalize=True,
    )
    labels_map = dict(enumerate(vocabulary))

    tts_input = []
    asr_references = []
    longest_tts_input = 0
    for test_str in LIST_OF_TEST_STRINGS:
        tts_parsed_input = tts_model_spec.parse(test_str)
        longest_tts_input = max(longest_tts_input, len(tts_parsed_input[0]))
        tts_input.append(tts_parsed_input.squeeze())

        # The ASR reference is the test string as the ASR text parser sees it.
        asr_references.append(
            ''.join(labels_map[c] for c in text_parser(test_str)))

    # Pad TTS inputs to a common length so they can be stacked.
    # NOTE(review): 68 is presumably the TTS pad token id - confirm against the TTS vocabulary.
    tts_input = [
        torch.nn.functional.pad(text, (0, longest_tts_input - len(text)), value=68)
        for text in tts_input
    ]

    logging.debug(tts_input)

    # Do TTS
    tts_input = torch.stack(tts_input)
    if use_cuda:
        tts_input = tts_input.cuda()
    specs = tts_model_spec.generate_spectrogram(tokens=tts_input)

    # Vocode in at most four chunks. Stepping over the real length (instead of
    # a fixed range(4)) never hands an empty slice to the vocoder.
    step = max(1, ceil(len(specs) / 4))
    audio = []
    for start in range(0, len(specs), step):
        audio.extend(
            tts_model_vocoder.convert_spectrogram_to_audio(
                spec=specs[start:start + step]))

    audio_file_paths = []
    # Save audio
    logging.debug(f"args.trim: {args.trim}")
    for i, aud in enumerate(audio):
        aud = aud.cpu().numpy()
        if args.trim:
            aud = librosa.effects.trim(aud, top_db=40)[0]
        # NOTE(review): librosa.output was removed in librosa 0.8; this call
        # needs an older librosa - confirm the pinned version.
        librosa.output.write_wav(f"{i}.wav", aud, sr=22050)
        audio_file_paths.append(str(Path(f"{i}.wav")))

    # Do ASR
    hypotheses = asr_model.transcribe(audio_file_paths)
    for i, hyp in enumerate(hypotheses):
        logging.debug(f"{i}")
        logging.debug(f"ref:'{asr_references[i]}'")
        logging.debug(f"hyp:'{hyp}'")
    wer_value = word_error_rate(hypotheses=hypotheses,
                                references=asr_references)
    if wer_value > args.wer_tolerance:
        raise ValueError(
            f"Got WER of {wer_value}. It was higher than {args.wer_tolerance}")
    logging.info(f'Got WER of {wer_value}. Tolerance was {args.wer_tolerance}')
Esempio n. 7
0
def main():
    """
    Transcribe a manifest with a CTC ASR model and write sclite ``trn`` files.

    Writes ``hyp.trn`` and ``ref.trn`` into ``--out_dir``. When ``--sctk_dir``
    exists on disk, the two files are additionally scored through
    ``score_with_sctk``.

    Command-line flags:
        --asr_model: path to a local ``.nemo`` file or a pretrained model name.
        --dataset: manifest passed as ``manifest_filepath`` to the test data setup.
        --batch_size: test batch size.
        --normalize_text: boolean forwarded as ``normalize_transcripts``.
        --sclite_fmt, --ref_stm: parsed but not read anywhere in this function.
        --out_dir: destination directory (created if missing).
        --sctk_dir, --glm: forwarded to ``score_with_sctk``.
    """

    def _parse_bool(text):
        """Turn a command-line token into a bool.

        ``type=bool`` would map every non-empty string, including 'False',
        to True, so the documented "Set to False" could never take effect.
        """
        return str(text).strip().lower() not in ("", "false", "f", "no", "n", "0")

    parser = ArgumentParser()
    parser.add_argument(
        "--asr_model",
        type=str,
        default="QuartzNet15x5Base-En",
        required=False,
        help="Pass: '******'",
    )
    parser.add_argument("--dataset",
                        type=str,
                        required=True,
                        help="path to evaluation data")
    parser.add_argument("--batch_size", type=int, default=4)
    parser.add_argument(
        "--normalize_text",
        default=True,
        type=_parse_bool,
        help="Normalize transcripts or not. Set to False for non-English.")
    # NOTE(review): --sclite_fmt is accepted but the output below is always "trn".
    parser.add_argument(
        "--sclite_fmt",
        default="trn",
        type=str,
        help="sclite output format. Only trn and ctm are supported")
    parser.add_argument("--out_dir",
                        type=str,
                        required=True,
                        help="Destination dir for output files")
    parser.add_argument("--sctk_dir",
                        type=str,
                        required=False,
                        default="",
                        help="Path to sctk root dir")
    parser.add_argument("--glm",
                        type=str,
                        required=False,
                        default="",
                        help="Path to glm file")
    parser.add_argument("--ref_stm",
                        type=str,
                        required=False,
                        default="",
                        help="Path to reference stm file")
    args = parser.parse_args()
    torch.set_grad_enabled(False)

    os.makedirs(args.out_dir, exist_ok=True)

    # An empty or missing sctk directory simply disables the scoring step.
    use_sctk = os.path.exists(args.sctk_dir)

    if args.asr_model.endswith('.nemo'):
        logging.info(f"Using local ASR model from {args.asr_model}")
        asr_model = EncDecCTCModel.restore_from(restore_path=args.asr_model)
    else:
        logging.info(f"Using NGC cloud ASR model {args.asr_model}")
        asr_model = EncDecCTCModel.from_pretrained(model_name=args.asr_model)
    asr_model.setup_test_data(
        test_data_config={
            'sample_rate': 16000,
            'manifest_filepath': args.dataset,
            'labels': asr_model.decoder.vocabulary,
            'batch_size': args.batch_size,
            'normalize_transcripts': args.normalize_text,
        })
    if can_gpu:
        asr_model = asr_model.cuda()
    asr_model.eval()

    vocabulary = asr_model.decoder.vocabulary
    # Maps a label id to its vocabulary entry, used to rebuild reference text.
    labels_map = dict(enumerate(vocabulary))

    wer = WER(vocabulary=vocabulary)
    hypotheses = []
    references = []
    for test_batch in asr_model.test_dataloader():
        if can_gpu:
            test_batch = [x.cuda() for x in test_batch]
        with autocast():
            _, _, greedy_predictions = asr_model(
                input_signal=test_batch[0], input_signal_length=test_batch[1])
        hypotheses += wer.ctc_decoder_predictions_tensor(greedy_predictions)
        targets = test_batch[2]
        references.extend(
            ''.join(labels_map[c]
                    for c in targets[batch_ind].cpu().detach().numpy())
            for batch_ind in range(greedy_predictions.shape[0]))
        del test_batch

    info_list = get_utt_info(args.dataset)
    hypfile = os.path.join(args.out_dir, "hyp.trn")
    reffile = os.path.join(args.out_dir, "ref.trn")
    with open(hypfile, "w") as hyp_f, open(reffile, "w") as ref_f:
        for i, (hypothesis, reference) in enumerate(zip(hypotheses, references)):
            utt_id = os.path.splitext(
                os.path.basename(info_list[i]['audio_filepath']))[0]
            # rfilter in sctk likes each transcript to have a space at the beginning
            hyp_f.write(" " + hypothesis + " (" + utt_id + ")" + "\n")
            ref_f.write(" " + reference + " (" + utt_id + ")" + "\n")

    if use_sctk:
        score_with_sctk(args.sctk_dir,
                        reffile,
                        hypfile,
                        args.out_dir,
                        glm=args.glm,
                        fmt="trn")
def main():
    """
    Evaluate a quantized CTC ASR model, optionally run a sensitivity analysis or export to ONNX.

    Command-line flags:
        --asr_model: path to a local ``.nemo`` file or a pretrained model name.
        --dataset: manifest passed as ``manifest_filepath`` to the test data setup.
        --wer_target: parsed but not read anywhere in this function.
        --batch_size: test batch size.
        --wer_tolerance: WER threshold used by the sensitivity analysis.
        --normalize_text: boolean forwarded as ``normalize_transcripts``.
        --sensitivity: disable quantized layers, most sensitive first, until
            the WER tolerance is met.
        --onnx: export the model to ONNX (only reached when ``--sensitivity``
            is not set, since that branch always returns or raises).

    Raises:
        ValueError: if the tolerance cannot be met even with every quantized
            layer disabled.
    """

    def _parse_bool(text):
        """Turn a command-line token into a bool.

        ``type=bool`` would map every non-empty string, including 'False',
        to True, so the documented "Set to False" could never take effect.
        """
        return str(text).strip().lower() not in ("", "false", "f", "no", "n", "0")

    parser = ArgumentParser()
    # NOTE: the default below is never used because the flag is required.
    parser.add_argument(
        "--asr_model", type=str, default="QuartzNet15x5Base-En", required=True, help="Pass: '******'",
    )
    parser.add_argument("--dataset", type=str, required=True, help="path to evaluation data")
    parser.add_argument("--wer_target", type=float, default=None, help="used by test")
    parser.add_argument("--batch_size", type=int, default=4)
    parser.add_argument("--wer_tolerance", type=float, default=1.0, help="used by test")
    parser.add_argument(
        "--normalize_text",
        default=True,
        type=_parse_bool,
        help="Normalize transcripts or not. Set to False for non-English.",
    )
    parser.add_argument('--sensitivity', action="store_true", help="Perform sensitivity analysis")
    parser.add_argument('--onnx', action="store_true", help="Export to ONNX")
    args = parser.parse_args()
    torch.set_grad_enabled(False)

    quant_modules.initialize()

    if args.asr_model.endswith('.nemo'):
        logging.info(f"Using local ASR model from {args.asr_model}")
        asr_model = EncDecCTCModel.restore_from(restore_path=args.asr_model)
    else:
        logging.info(f"Using NGC cloud ASR model {args.asr_model}")
        asr_model = EncDecCTCModel.from_pretrained(model_name=args.asr_model)
    asr_model.setup_test_data(
        test_data_config={
            'sample_rate': 16000,
            'manifest_filepath': args.dataset,
            'labels': asr_model.decoder.vocabulary,
            'batch_size': args.batch_size,
            'normalize_transcripts': args.normalize_text,
        }
    )
    if can_gpu:
        asr_model = asr_model.cuda()
    asr_model.eval()

    vocabulary = asr_model.decoder.vocabulary
    labels_map = dict(enumerate(vocabulary))
    wer = WER(vocabulary=vocabulary)
    wer_quant = evaluate(asr_model, labels_map, wer)
    logging.info(f'Got WER of {wer_quant}. Tolerance was {args.wer_tolerance}')

    def _quantizers(layer=None):
        """Yield (name, module) for every TensorQuantizer, optionally only those whose name contains ``layer``."""
        for name, module in asr_model.named_modules():
            if isinstance(module, quant_nn.TensorQuantizer) and (layer is None or layer in name):
                yield name, module

    if args.sensitivity:
        if wer_quant < args.wer_tolerance:
            logging.info("Tolerance is already met. Skip sensitivity analyasis.")
            return

        # Start from a fully de-quantized model and collect the distinct layer names.
        quant_layer_names = []
        for name, module in _quantizers():
            module.disable()
            layer_name = name.replace("._input_quantizer", "").replace("._weight_quantizer", "")
            if layer_name not in quant_layer_names:
                quant_layer_names.append(layer_name)
        logging.info(F"{len(quant_layer_names)} quantized layers found.")

        # Build sensitivity profile: quantize one layer at a time and measure WER.
        quant_layer_sensitivity = {}
        for quant_layer in quant_layer_names:
            logging.info(F"Enable {quant_layer}")
            for name, module in _quantizers(quant_layer):
                module.enable()
                logging.info(F"{name:40}: {module}")

            # Eval the model
            wer_value = evaluate(asr_model, labels_map, wer)
            logging.info(F"WER: {wer_value}")
            quant_layer_sensitivity[quant_layer] = args.wer_tolerance - wer_value

            for name, module in _quantizers(quant_layer):
                module.disable()
                logging.info(F"{name:40}: {module}")

        # Skip most sensitive layers until WER target is met
        for _, module in _quantizers():
            module.enable()
        # Ascending order puts the layers with the highest WER (smallest margin) first.
        quant_layer_sensitivity = collections.OrderedDict(sorted(quant_layer_sensitivity.items(), key=lambda x: x[1]))
        pprint(quant_layer_sensitivity)
        skipped_layers = []
        for quant_layer in quant_layer_sensitivity:
            for name, module in _quantizers(quant_layer):
                logging.info(F"Disable {name}")
                if quant_layer not in skipped_layers:
                    skipped_layers.append(quant_layer)
                module.disable()
            wer_value = evaluate(asr_model, labels_map, wer)
            if wer_value <= args.wer_tolerance:
                logging.info(
                    F"WER tolerance {args.wer_tolerance} is met by skipping {len(skipped_layers)} sensitive layers."
                )
                print(skipped_layers)
                return
        raise ValueError(f"WER tolerance {args.wer_tolerance} can not be met with any layer quantized!")

    if args.onnx:
        if args.asr_model.endswith(".nemo"):
            onnx_name = args.asr_model.replace(".nemo", ".onnx")
        else:
            onnx_name = args.asr_model
        # The name must be part of the message; a stray positional argument
        # without a placeholder is not rendered by the logger.
        logging.info(f"Export to {onnx_name}")
        quant_nn.TensorQuantizer.use_fb_fake_quant = True
        try:
            asr_model.export(onnx_name, onnx_opset_version=13)
        finally:
            # Always restore the class-level flag, even when export fails.
            quant_nn.TensorQuantizer.use_fb_fake_quant = False
Esempio n. 9
0
def main():
    """
    Calibrate a quantized CTC ASR model and save the calibrated checkpoint(s).

    Command-line flags:
        --asr_model: path to a local ``.nemo`` file or a pretrained model name.
        --dataset: manifest passed as ``manifest_filepath`` to the test data setup.
        --batch_size: calibration batch size.
        --normalize_text: boolean forwarded as ``normalize_transcripts``.
        --num_calib_batch: number of batches fed through the model for calibration.
        --calibrator: "max" or "histogram".
        --percentile: percentiles to iterate over for the histogram calibrator.
        --amp: run the calibration forward passes under ``autocast``.

    Side effects:
        Saves one ``.nemo`` file for the "max" calibrator, or one per
        percentile plus one each for "mse" and "entropy" for "histogram".
    """

    def _parse_bool(text):
        """Turn a command-line token into a bool.

        ``type=bool`` would map every non-empty string, including 'False',
        to True, so the documented "Set to False" could never take effect.
        """
        return str(text).strip().lower() not in ("", "false", "f", "no", "n", "0")

    parser = ArgumentParser()
    # NOTE: the default below is never used because the flag is required.
    parser.add_argument(
        "--asr_model",
        type=str,
        default="QuartzNet15x5Base-En",
        required=True,
        help="Pass: '******'",
    )
    parser.add_argument("--dataset",
                        type=str,
                        required=True,
                        help="path to evaluation data")
    parser.add_argument("--batch_size", type=int, default=256)
    parser.add_argument(
        "--normalize_text",
        default=True,
        type=_parse_bool,
        help="Normalize transcripts or not. Set to False for non-English.")
    parser.add_argument('--num_calib_batch',
                        default=1,
                        type=int,
                        help="Number of batches for calibration.")
    parser.add_argument('--calibrator',
                        type=str,
                        choices=["max", "histogram"],
                        default="max")
    parser.add_argument('--percentile',
                        nargs='+',
                        type=float,
                        default=[99.9, 99.99, 99.999, 99.9999])
    parser.add_argument("--amp",
                        action="store_true",
                        help="Use AMP in calibration.")
    parser.set_defaults(amp=False)

    args = parser.parse_args()
    torch.set_grad_enabled(False)

    # Initialize quantization
    quant_desc_input = QuantDescriptor(calib_method=args.calibrator)
    quant_nn.QuantConv2d.set_default_quant_desc_input(quant_desc_input)
    quant_nn.QuantConvTranspose2d.set_default_quant_desc_input(
        quant_desc_input)
    quant_nn.QuantLinear.set_default_quant_desc_input(quant_desc_input)

    # Load the config first, switch on encoder quantization, then load the
    # model itself with the patched config.
    if args.asr_model.endswith('.nemo'):
        logging.info(f"Using local ASR model from {args.asr_model}")
        asr_model_cfg = EncDecCTCModel.restore_from(
            restore_path=args.asr_model, return_config=True)
        with open_dict(asr_model_cfg):
            asr_model_cfg.encoder.quantize = True
        asr_model = EncDecCTCModel.restore_from(
            restore_path=args.asr_model, override_config_path=asr_model_cfg)

    else:
        logging.info(f"Using NGC cloud ASR model {args.asr_model}")
        asr_model_cfg = EncDecCTCModel.from_pretrained(
            model_name=args.asr_model, return_config=True)
        with open_dict(asr_model_cfg):
            asr_model_cfg.encoder.quantize = True
        asr_model = EncDecCTCModel.from_pretrained(
            model_name=args.asr_model, override_config_path=asr_model_cfg)

    asr_model.setup_test_data(
        test_data_config={
            'sample_rate': 16000,
            'manifest_filepath': args.dataset,
            'labels': asr_model.decoder.vocabulary,
            'batch_size': args.batch_size,
            'normalize_transcripts': args.normalize_text,
            'shuffle': True,
        })

    if can_gpu:
        asr_model = asr_model.cuda()
    asr_model.eval()

    # Enable calibrators
    for name, module in asr_model.named_modules():
        if isinstance(module, quant_nn.TensorQuantizer):
            if module._calibrator is not None:
                module.disable_quant()
                module.enable_calib()
            else:
                module.disable()

    for i, test_batch in enumerate(asr_model.test_dataloader()):
        # Stop before running batch number `num_calib_batch`, so exactly that
        # many batches are used. This matches the sample count written into
        # the output file names below.
        if i >= args.num_calib_batch:
            break
        if can_gpu:
            test_batch = [x.cuda() for x in test_batch]
        if args.amp:
            with autocast():
                asr_model(input_signal=test_batch[0],
                          input_signal_length=test_batch[1])
        else:
            asr_model(input_signal=test_batch[0],
                      input_signal_length=test_batch[1])

    # Save calibrated model(s)
    model_name = args.asr_model.replace(
        ".nemo", "") if args.asr_model.endswith(".nemo") else args.asr_model
    calib_samples = args.num_calib_batch * args.batch_size
    if args.calibrator != "histogram":
        compute_amax(asr_model, method="max")
        asr_model.save_to(F"{model_name}-max-{calib_samples}.nemo")
    else:
        for percentile in args.percentile:
            print(F"{percentile} percentile calibration")
            # NOTE(review): the percentile value itself is not forwarded to
            # compute_amax, so every iteration issues the same call - confirm
            # whether compute_amax should receive percentile=percentile.
            compute_amax(asr_model, method="percentile")
            asr_model.save_to(
                F"{model_name}-percentile-{percentile}-{calib_samples}.nemo"
            )

        for method in ["mse", "entropy"]:
            print(F"{method} calibration")
            compute_amax(asr_model, method=method)
            asr_model.save_to(
                F"{model_name}-{method}-{calib_samples}.nemo"
            )
Esempio n. 10
0
def main():
    """
    Calibrate (unless dynamic) and evaluate an integer-quantized CTC ASR model, then print its WER.

    Command-line flags:
        --asr_model: path to a local ``.nemo`` file or a pretrained model name.
        --dataset: manifest passed as ``manifest_filepath`` to the test data setup.
        --batch_size, --shuffle: test data loader settings.
        --normalize_text: boolean forwarded as ``normalize_transcripts``.
        --load: pickle file with the synthetic calibration data.
        --percentile: forwarded to ``qm.set_percentile`` when given.
        --weight_bit / --act_bit: quantization bit widths.
        --dynamic: dynamic quantization mode (calibration is skipped).
        --no_quant: set the quantization mode to 'none'.
        --eval_early_stop / --calib_early_stop: batch index at which the
            evaluation / calibration loop stops.

    Raises:
        ValueError: if neither ``--load`` nor ``--dynamic`` is given.
    """

    def _parse_bool(text):
        """Turn a command-line token into a bool.

        ``type=bool`` would map every non-empty string, including 'False',
        to True, so the documented "Set to False" could never take effect.
        """
        return str(text).strip().lower() not in ("", "false", "f", "no", "n", "0")

    parser = ArgumentParser()
    # Training arguments
    # NOTE: the default below is never used because the flag is required.
    parser.add_argument("--asr_model",
                        type=str,
                        default="QuartzNet15x5Base-En",
                        required=True,
                        help="Pass: '******'")
    parser.add_argument("--dataset",
                        type=str,
                        required=True,
                        help="path to evaluation data")
    parser.add_argument("--batch_size", type=int, default=8)
    parser.add_argument(
        "--normalize_text",
        default=True,
        type=_parse_bool,
        help="Normalize transcripts or not. Set to False for non-English.")
    parser.add_argument("--shuffle",
                        action='store_true',
                        help="Shuffle test data.")
    # Calibration arguments
    parser.add_argument("--load",
                        type=str,
                        default=None,
                        help="load path for the synthetic data")
    parser.add_argument(
        "--percentile",
        type=float,
        default=None,
        help="Max/min percentile for outlier handling. e.g., 99.9")
    # Quantization arguments
    parser.add_argument("--weight_bit",
                        type=int,
                        default=8,
                        help="quantization bit for weights")
    parser.add_argument("--act_bit",
                        type=int,
                        default=8,
                        help="quantization bit for activations")
    parser.add_argument("--dynamic",
                        action='store_true',
                        help="Dynamic quantization mode.")
    parser.add_argument("--no_quant",
                        action='store_true',
                        help="No quantization mode.")
    # Debugging arguments
    parser.add_argument("--eval_early_stop",
                        type=int,
                        default=None,
                        help="early stop for debugging")
    parser.add_argument("--calib_early_stop",
                        type=int,
                        default=None,
                        help="early stop calibration")

    args = parser.parse_args()

    torch.set_grad_enabled(False)

    if args.asr_model.endswith('.nemo'):
        logging.info(f"Using local ASR model from {args.asr_model}")
        asr_model = EncDecCTCModel.restore_from(restore_path=args.asr_model)
    else:
        logging.info(f"Using NGC cloud ASR model {args.asr_model}")
        asr_model = EncDecCTCModel.from_pretrained(model_name=args.asr_model)

    asr_model.setup_test_data(
        test_data_config={
            'sample_rate': 16000,
            'manifest_filepath': args.dataset,
            'labels': asr_model.decoder.vocabulary,
            'batch_size': args.batch_size,
            'normalize_transcripts': args.normalize_text,
            'shuffle': args.shuffle,
        })

    distilled_data = None
    if args.load is not None:
        print(f'Data loaded from {args.load}')
        # SECURITY: pickle.load executes arbitrary code embedded in the file;
        # only pass --load files that come from a trusted source.
        with open(args.load, 'rb') as f:
            distilled_data = pickle.load(f)
        synthetic_batch_size, _, synthetic_seqlen = distilled_data[0].shape
    elif not args.dynamic:
        # Explicit check instead of `assert`, which is stripped under -O and
        # would let the run fail later with an undefined-name error.
        raise ValueError(
            "synthetic data must be loaded unless running with the dynamic quantization mode")

    ############################## Calibration #####################################

    torch.set_grad_enabled(False)  # disable backward graph generation
    asr_model.eval()  # evaluation mode
    asr_model.set_quant_bit(args.weight_bit, mode='weight')
    asr_model.set_quant_bit(args.act_bit, mode='act')

    # set percentile
    if args.percentile is not None:
        qm.set_percentile(asr_model, args.percentile)

    if args.no_quant:
        asr_model.set_quant_mode('none')
    else:
        asr_model.encoder.bn_folding()  # BN folding

    # if not dynamic quantization, calibrate min/max/range for the activations using synthetic data
    # if dynamic, we can skip calibration
    if not args.dynamic:
        print('Calibrating...')
        qm.calibrate(asr_model)
        length = torch.tensor([synthetic_seqlen] * synthetic_batch_size).cuda()
        for batch_idx, inputs in enumerate(distilled_data):
            if args.calib_early_stop is not None and batch_idx == args.calib_early_stop:
                break
            inputs = inputs.cuda()
            encoded, _, encoded_scaling_factor = asr_model.encoder(
                audio_signal=inputs, length=length)
            # The forward pass is run for calibration only; its output is not needed.
            asr_model.decoder(
                encoder_output=encoded,
                encoder_output_scaling_factor=encoded_scaling_factor)

    ############################## Evaluation  #####################################

    print('Evaluating...')
    qm.evaluate(asr_model)

    qm.set_dynamic(
        asr_model,
        args.dynamic)  # if dynamic quantization, this will be enabled
    vocabulary = asr_model.decoder.vocabulary
    labels_map = dict(enumerate(vocabulary))
    wer = WER(vocabulary=vocabulary)
    hypotheses = []
    references = []
    progress_bar = tqdm(asr_model.test_dataloader())

    for i, test_batch in enumerate(progress_bar):
        # With the default of None this never matches, so all batches are evaluated.
        if i == args.eval_early_stop:
            break
        test_batch = [x.cuda().float() for x in test_batch]
        with autocast():
            _, _, greedy_predictions = asr_model(
                input_signal=test_batch[0], input_signal_length=test_batch[1])
        hypotheses += wer.ctc_decoder_predictions_tensor(greedy_predictions)
        targets = test_batch[2]
        references.extend(
            ''.join(labels_map[c]
                    for c in targets[batch_ind].cpu().detach().numpy())
            for batch_ind in range(greedy_predictions.shape[0]))
        del test_batch
    wer_value = word_error_rate(hypotheses=hypotheses, references=references)
    print('WER:', wer_value)
Esempio n. 11
0
def main():
    """Command-line entry point for scoring a CTC ASR model against an ONNX file.

    Parses the command line, loads an ``EncDecCTCModel`` (from a local
    ``.nemo`` archive when ``--asr_model`` ends with ``.nemo``, otherwise by
    pretrained model name) with ``encoder.quantize`` switched on in its
    config, attaches the evaluation manifest as test data, and logs the value
    returned by ``evaluate``.

    Command-line options:
        --asr_model: model name or path to a ``.nemo`` file.
        --asr_onnx: path of the ONNX file handed to ``evaluate``.
        --dataset: manifest path of the evaluation data (required).
        --batch_size: batch size of the test dataloader.
        --dont_normalize_text: disable transcript normalization.
        --use_cer: have the ``WER`` metric compute character error rate.
        --qat: forwarded to ``evaluate``; marks the ONNX file as exported
            from QAT tools.
    """
    parser = ArgumentParser()
    # ``required=True`` was dropped: argparse never applies the default of a
    # required option, so the declared default could not take effect.
    parser.add_argument(
        "--asr_model",
        type=str,
        default="QuartzNet15x5Base-En",
        help="Pass: '******'",
    )
    parser.add_argument(
        "--asr_onnx",
        type=str,
        default="./QuartzNet15x5Base-En-max-32.onnx",
        help="Pass: '******'",
    )
    parser.add_argument("--dataset",
                        type=str,
                        required=True,
                        help="path to evaluation data")
    parser.add_argument("--batch_size", type=int, default=4)
    # ``store_true`` so that passing the flag actually flips the value; with
    # ``store_false`` and ``default=False`` the option was a no-op.
    parser.add_argument(
        "--dont_normalize_text",
        default=False,
        action='store_true',
        help="Turn off trasnscript normalization. Recommended for non-English.",
    )
    parser.add_argument(
        "--use_cer",
        default=False,
        action='store_true',
        help="Use Character Error Rate as the evaluation metric")
    parser.add_argument('--qat',
                        action="store_true",
                        help="Use onnx file exported from QAT tools")
    args = parser.parse_args()
    torch.set_grad_enabled(False)

    # Both loaders are called the same way and differ only in how the model
    # is identified, so pick the loader once and share the rest.
    if args.asr_model.endswith('.nemo'):
        logging.info(f"Using local ASR model from {args.asr_model}")
        load_model = EncDecCTCModel.restore_from
        model_source = {'restore_path': args.asr_model}
    else:
        logging.info(f"Using NGC cloud ASR model {args.asr_model}")
        load_model = EncDecCTCModel.from_pretrained
        model_source = {'model_name': args.asr_model}

    # First fetch only the config, enable encoder quantization in it, then
    # instantiate the model with the modified config.
    asr_model_cfg = load_model(return_config=True, **model_source)
    with open_dict(asr_model_cfg):
        asr_model_cfg.encoder.quantize = True
    asr_model = load_model(override_config_path=asr_model_cfg, **model_source)

    vocabulary = asr_model.decoder.vocabulary
    asr_model.setup_test_data(
        test_data_config={
            'sample_rate': 16000,
            'manifest_filepath': args.dataset,
            'labels': vocabulary,
            'batch_size': args.batch_size,
            # The flag is phrased negatively, so invert it for the config.
            'normalize_transcripts': not args.dont_normalize_text,
        })
    asr_model.preprocessor.featurizer.dither = 0.0
    asr_model.preprocessor.featurizer.pad_to = 0
    if can_gpu:
        asr_model = asr_model.cuda()
    asr_model.eval()

    # Index -> label lookup table passed to ``evaluate``.
    labels_map = dict(enumerate(vocabulary))
    wer = WER(vocabulary=vocabulary, use_cer=args.use_cer)
    wer_result = evaluate(asr_model, args.asr_onnx, labels_map, wer, args.qat)
    logging.info(f'Got WER of {wer_result}.')
Esempio n. 12
0
def main():
    """Command-line entry point for generating and dumping synthetic data.

    Loads a teacher ``EncDecCTCModel`` (from a local ``.nemo`` archive when
    ``--asr_model`` ends with ``.nemo``, otherwise by pretrained model name),
    attaches the dataset manifest as test data, disables the teacher's
    quantization mode, calls ``get_synthetic_data`` with the teacher's
    encoder and decoder, and pickles the result (moved to CPU) to disk.

    Command-line options:
        --asr_model: model name or path to a ``.nemo`` file.
        --dataset: manifest path of the evaluation data (required).
        --num_batch: number of synthetic batches.
        --batch_size: batch size of the synthetic data.
        --seqlen: sequence length of the synthetic data.
        --train_iter: training iterations for the synthetic data generation.
        --dump_path: output directory; created if missing. When omitted the
            file is written relative to the current working directory.
        --dump_prefix: filename prefix of the dumped file.
        --lr: learning rate for the synthetic data generation.
    """
    parser = ArgumentParser()
    # ``required=True`` was dropped: argparse never applies the default of a
    # required option, so the declared default could not take effect.
    parser.add_argument("--asr_model",
                        type=str,
                        default="QuartzNet15x5Base-En",
                        help="Pass: '******'")
    parser.add_argument("--dataset",
                        type=str,
                        required=True,
                        help="path to evaluation data")
    parser.add_argument("--num_batch",
                        type=int,
                        default=50,
                        help="number of batches of the synthetic data")
    parser.add_argument("--batch_size",
                        type=int,
                        default=8,
                        help="batch size of the synthetic data")
    parser.add_argument("--seqlen",
                        type=int,
                        default=500,
                        help="sequence length of the synthetic data")
    parser.add_argument(
        "--train_iter",
        type=int,
        default=200,
        help="training iterations for the synthetic data generation")
    parser.add_argument("--dump_path",
                        type=str,
                        default=None,
                        help="path to dump the synthetic data")
    parser.add_argument(
        "--dump_prefix",
        type=str,
        default='syn',
        help="prefix for the filename of the dumped synthetic data")
    parser.add_argument("--lr",
                        type=float,
                        default=0.01,
                        help="Learning rate for the synthetic data generation")

    args = parser.parse_args()

    torch.set_grad_enabled(False)

    if args.asr_model.endswith('.nemo'):
        logging.info(f"Using local ASR model from {args.asr_model}")
        teacher_model = EncDecCTCModel.restore_from(
            restore_path=args.asr_model)
    else:
        logging.info(f"Using NGC cloud ASR model {args.asr_model}")
        teacher_model = EncDecCTCModel.from_pretrained(
            model_name=args.asr_model)

    teacher_model.setup_test_data(
        test_data_config={
            'sample_rate': 16000,
            'manifest_filepath': args.dataset,
            'labels': teacher_model.decoder.vocabulary,
            'batch_size': 8,
            'normalize_transcripts': True,
            'shuffle': True,
        })

    ############################## Distillation #####################################

    # Disable quantization mode for the teacher model.
    teacher_model.set_quant_mode('none')
    torch.set_grad_enabled(True)  # enable backward graph generation

    print(f"Num batches: {args.num_batch:d}, Batch size: {args.batch_size:d}, "
          f"Training iterations: {args.train_iter:d}, "
          f"Learning rate: {args.lr:.3f} ")
    print('Synthesizing...')

    synthetic_data = get_synthetic_data(teacher_model.encoder,
                                        teacher_model.decoder,
                                        batch_size=args.batch_size,
                                        dim=64,
                                        seqlen=args.seqlen,
                                        num_batch=args.num_batch,
                                        train_iter=args.train_iter,
                                        lr=args.lr)

    file_name = (f"{args.dump_prefix}_nb{args.num_batch:d}"
                 f"_iter{args.train_iter:d}_lr{args.lr:.3f}.pkl")

    if args.dump_path is not None:
        # ``exist_ok`` avoids the check-then-create race of a separate
        # ``os.path.exists`` test.
        os.makedirs(args.dump_path, exist_ok=True)
        file_name = os.path.join(args.dump_path, file_name)

    with open(file_name, 'wb') as f:
        pickle.dump([x.cpu() for x in synthetic_data], f)
    # Report only after the file has been written successfully.
    print('Synthetic data dumped as ', file_name)