Пример #1
0
    def set_asr_model(self, asr_model):
        """
        Load the requested pretrained ASR model and configure decoding parameters for it.

        The model family is chosen by substring match on the model name, checked in this order:
            'QuartzNet'                             -> EncDecCTCModel,     run_ASR_QuartzNet_CTC
            'conformer_ctc'                         -> EncDecCTCModelBPE,  run_ASR_BPE_CTC  (stride 0.04 s)
            'citrinet'                              -> EncDecCTCModelBPE,  run_ASR_BPE_CTC  (stride 0.08 s)
            'conformer_transducer' / 'contextnet'   -> EncDecRNNTBPEModel, run_ASR_BPE_RNNT (stride 0.04 s)
        Example names: stt_en_conformer_ctc_large, stt_en_conformer_ctc_medium,
        stt_en_conformer_ctc_small, QuartzNet15x5Base-En.

        Side effects: sets self.run_ASR, self.params['offset'], self.params['time_stride'],
        self.model_stride_in_secs, self.asr_delay_sec, self.asr_batch_size and
        self.audio_file_list; the non-QuartzNet families also set self.chunk_len_in_sec
        and self.total_buffer_in_secs.

        Args:
            asr_model (str): pretrained model name passed to ``from_pretrained``.

        Returns:
            The loaded model, already switched to eval mode.

        Raises:
            ValueError: if the name matches none of the supported families.
        """
        model_name = asr_model

        if 'QuartzNet' in model_name:
            self.run_ASR = self.run_ASR_QuartzNet_CTC
            model = EncDecCTCModel.from_pretrained(model_name=model_name, strict=False)
            self.params['offset'] = -0.18
            self.model_stride_in_secs = 0.02
            # The ASR delay is the negated timestamp offset for this model.
            self.asr_delay_sec = -1 * self.params['offset']
        else:
            # The BPE families differ only in model class, run_ASR routine and stride.
            if 'conformer_ctc' in model_name:
                self.run_ASR = self.run_ASR_BPE_CTC
                model_cls, stride_in_secs = EncDecCTCModelBPE, 0.04
            elif 'citrinet' in model_name:
                self.run_ASR = self.run_ASR_BPE_CTC
                model_cls, stride_in_secs = EncDecCTCModelBPE, 0.08
            elif 'conformer_transducer' in model_name or 'contextnet' in model_name:
                self.run_ASR = self.run_ASR_BPE_RNNT
                model_cls, stride_in_secs = EncDecRNNTBPEModel, 0.04
            else:
                raise ValueError(f"ASR model name not found: {model_name}")

            model = model_cls.from_pretrained(model_name=model_name, strict=False)
            self.model_stride_in_secs = stride_in_secs
            self.asr_delay_sec = 0.0
            self.params['offset'] = 0
            self.chunk_len_in_sec = 1.6
            self.total_buffer_in_secs = 4

        self.params['time_stride'] = self.model_stride_in_secs
        self.asr_batch_size = 16
        model.eval()

        self.audio_file_list = [
            value['audio_filepath'] for value in self.AUDIO_RTTM_MAP.values()
        ]

        return model
Пример #2
0
 def test_EncDecCTCModelBPE_v2(self):
     """Run ``__test_restore_elsewhere`` on the pretrained ``stt_en_conformer_ctc_small`` model.

     ``attr_for_eq_check`` is set to the decoder's ``_feat_in`` and ``_num_classes``.
     """
     # TODO: Switch to using named configs because here we don't really care about weights
     checked_attrs = {"decoder._feat_in", "decoder._num_classes"}
     pretrained = EncDecCTCModelBPE.from_pretrained(model_name="stt_en_conformer_ctc_small")
     self.__test_restore_elsewhere(model=pretrained, attr_for_eq_check=checked_attrs)
Пример #3
0
 def test_EncDecCTCModelBPE(self):
     """Run ``__test_restore_elsewhere`` on the pretrained ContextNet CTC BPE model.

     ``attr_for_eq_check`` is set to the decoder's ``_feat_in`` and ``_num_classes``.
     """
     # TODO: Switch to using named configs because here we don't really care about weights
     checked_attrs = {"decoder._feat_in", "decoder._num_classes"}
     pretrained = EncDecCTCModelBPE.from_pretrained(model_name="ContextNet-192-WPE-1024-8x-Stride")
     self.__test_restore_elsewhere(model=pretrained, attr_for_eq_check=checked_attrs)
Пример #4
0
def main(cfg):
    """Train a CTC ASR model from a Hydra config, then optionally test it on ``cfg.model.test_ds``.

    A BPE model (``EncDecCTCModelBPE``) is built when ``cfg.model`` has a ``tokenizer`` entry,
    otherwise a character model (``EncDecCTCModel``).

    Args:
        cfg: Hydra/OmegaConf config providing ``n_gpus``, ``seed``, ``trainer``, ``model`` and
            optionally ``exp_manager``.
    """
    if cfg.n_gpus > 0:
        # Divide the configured batch size across GPUs (presumably the config value is the
        # global batch size — confirm against the config docs).
        cfg.model.train_ds.batch_size //= cfg.n_gpus

    logging.info(f'Hydra config: {OmegaConf.to_yaml(cfg, resolve=True)}')

    pl.utilities.seed.seed_everything(cfg.seed)

    trainer = pl.Trainer(**cfg.trainer)
    exp_manager(trainer, cfg.get("exp_manager", None))
    if "tokenizer" in cfg.model:
        asr_model = EncDecCTCModelBPE(cfg=cfg.model, trainer=trainer)
    else:
        asr_model = EncDecCTCModel(cfg=cfg.model, trainer=trainer)

    # Initialize the weights of the model from another model, if provided via config
    asr_model.maybe_init_from_pretrained_checkpoint(cfg)

    trainer.fit(asr_model)

    if hasattr(cfg.model, 'test_ds') and cfg.model.test_ds.manifest_filepath is not None:
        # Test on at most one device. `gpus` may be None (Lightning's CPU setting) or 0;
        # neither may turn into a GPU request, so use truthiness rather than `!= 0`.
        gpu = 1 if cfg.trainer.gpus else 0
        test_trainer = pl.Trainer(
            gpus=gpu,
            precision=trainer.precision,
            amp_level=trainer.accelerator_connector.amp_level,
            amp_backend=cfg.trainer.get("amp_backend", "native"),
        )
        if asr_model.prepare_test(test_trainer):
            test_trainer.test(asr_model)
Пример #5
0
    def set_asr_model(self, ASR_model_name):
        """
        Load the pretrained ASR model ``ASR_model_name`` and select the matching run_ASR routine.

        Only QuartzNet (character CTC) models are currently supported; conformer and
        citrinet names are recognised but not implemented yet.

        Args:
            ASR_model_name (str): pretrained model name passed to ``from_pretrained``.

        Returns:
            The loaded ``EncDecCTCModel``.

        Raises:
            NotImplementedError: for conformer or citrinet model names.
            ValueError: if the name matches no known model family.
        """
        if 'QuartzNet' in ASR_model_name:
            self.run_ASR = self.run_ASR_QuartzNet_CTC
            return EncDecCTCModel.from_pretrained(model_name=ASR_model_name, strict=False)

        if 'conformer' in ASR_model_name or 'citrinet' in ASR_model_name:
            # Fail fast: downloading a model we cannot use yet only wastes time and bandwidth.
            raise NotImplementedError(
                f"ASR model is not supported yet: {ASR_model_name}")

        # Report the name that was actually requested (the argument, not a config lookup).
        raise ValueError(f"ASR model name not found: {ASR_model_name}")
Пример #6
0
def _average_checkpoints(checkpoint_paths):
    """Load every checkpoint in ``checkpoint_paths`` and return their element-wise average.

    A checkpoint that stores its weights under a ``'state_dict'`` key is unwrapped first.
    Entries whose dtype name starts with ``torch.int`` (e.g. BatchNorm.num_batches_tracked)
    are only summed; every other entry is divided by the number of checkpoints.

    Raises:
        ValueError: if ``checkpoint_paths`` is empty.
    """
    n = len(checkpoint_paths)
    if n == 0:
        raise ValueError("No checkpoints were provided to average.")

    logging.info(f"Averaging {n} checkpoints ...")

    avg_state = None
    for ix, path in enumerate(checkpoint_paths):
        checkpoint = torch.load(path, map_location='cpu')

        if 'state_dict' in checkpoint:
            checkpoint = checkpoint['state_dict']

        if ix == 0:
            # Initial state
            avg_state = checkpoint
            logging.info(f"Initialized average state dict with checkpoint : {path}")
        else:
            # Accumulated state
            for k in avg_state:
                avg_state[k] = avg_state[k] + checkpoint[k]
            logging.info(f"Updated average state dict with state from checkpoint : {path}")

    for k in avg_state:
        if str(avg_state[k].dtype).startswith("torch.int"):
            # For int type, not averaged, but only accumulated.
            continue
        avg_state[k] = avg_state[k] / n

    return avg_state


def main(cfg):
    """Average several checkpoints and save the result.

    ``process_config(cfg)`` supplies the output name prefix, the checkpoint paths and a
    ``save_ckpt_only`` flag. With the flag set, the averaged state dict is written to
    ``<prefix>-averaged.ckpt``; otherwise it is loaded into an ``EncDecCTCModelBPE`` built
    from ``cfg.model`` and saved as ``<prefix>-averaged.nemo``.
    """
    name_prefix, checkpoint_paths, save_ckpt_only = process_config(cfg)

    if not save_ckpt_only:
        trainer = pl.Trainer(**cfg.trainer)

        # <<< Change model class here ! >>>
        # Model architecture which will contain the averaged checkpoints
        # Change the model constructor to the one you would like (if needed)
        model = EncDecCTCModelBPE(cfg=cfg.model, trainer=trainer)

    # Checkpoint averaging logic
    avg_state = _average_checkpoints(checkpoint_paths)

    # Save model
    if save_ckpt_only:
        ckpt_name = name_prefix + '-averaged.ckpt'
        torch.save(avg_state, ckpt_name)

        logging.info(f"Averaged pytorch checkpoint saved as : {ckpt_name}")
    else:
        # Set model state
        logging.info("Loading averaged state dict in provided model")
        model.load_state_dict(avg_state, strict=True)

        ckpt_name = name_prefix + '-averaged.nemo'
        model.save_to(ckpt_name)

        logging.info(f"Averaged model saved as : {ckpt_name}")
Пример #7
0
def _parse_calibration_args(argv=None):
    """Parse the calibration command line (``sys.argv[1:]`` when ``argv`` is None)."""
    parser = ArgumentParser()
    parser.add_argument(
        "--asr_model",
        type=str,
        default="QuartzNet15x5Base-En",
        required=True,
        help="Pass: '******'",
    )
    parser.add_argument("--dataset",
                        type=str,
                        required=True,
                        help="path to evaluation data")
    parser.add_argument("--batch_size", type=int, default=256)
    # store_false needs a True default, otherwise the flag is a no-op.
    # args.dont_normalize_text is True when transcripts SHOULD be normalized.
    parser.add_argument(
        "--dont_normalize_text",
        default=True,
        action='store_false',
        help="Turn off transcript normalization. Recommended for non-English.",
    )
    parser.add_argument('--num_calib_batch',
                        default=1,
                        type=int,
                        help="Number of batches for calibration.")
    parser.add_argument('--calibrator',
                        type=str,
                        choices=["max", "histogram"],
                        default="max")
    parser.add_argument('--percentile',
                        nargs='+',
                        type=float,
                        default=[99.9, 99.99, 99.999, 99.9999])
    parser.add_argument("--amp",
                        action="store_true",
                        help="Use AMP in calibration.")
    parser.set_defaults(amp=False)
    return parser.parse_args(argv)


def main():
    """Calibrate the input quantizers of a CTC BPE ASR model and save calibrated ``.nemo`` files.

    Steps: install default input-quantizer descriptors, load the model with
    ``encoder.quantize = True``, run ``--num_calib_batch`` test batches with calibration
    enabled, then compute amax and save one model per calibration method
    (max, or percentile/mse/entropy for the histogram calibrator).
    """
    args = _parse_calibration_args()
    torch.set_grad_enabled(False)

    # Initialize quantization
    quant_desc_input = QuantDescriptor(calib_method=args.calibrator)
    quant_nn.QuantConv2d.set_default_quant_desc_input(quant_desc_input)
    quant_nn.QuantConvTranspose2d.set_default_quant_desc_input(
        quant_desc_input)
    quant_nn.QuantLinear.set_default_quant_desc_input(quant_desc_input)

    if args.asr_model.endswith('.nemo'):
        logging.info(f"Using local ASR model from {args.asr_model}")
        asr_model_cfg = EncDecCTCModelBPE.restore_from(
            restore_path=args.asr_model, return_config=True)
        with open_dict(asr_model_cfg):
            asr_model_cfg.encoder.quantize = True
        asr_model = EncDecCTCModelBPE.restore_from(
            restore_path=args.asr_model, override_config_path=asr_model_cfg)

    else:
        logging.info(f"Using NGC cloud ASR model {args.asr_model}")
        asr_model_cfg = EncDecCTCModelBPE.from_pretrained(
            model_name=args.asr_model, return_config=True)
        with open_dict(asr_model_cfg):
            asr_model_cfg.encoder.quantize = True
        asr_model = EncDecCTCModelBPE.from_pretrained(
            model_name=args.asr_model, override_config_path=asr_model_cfg)

    asr_model.setup_test_data(
        test_data_config={
            'sample_rate': 16000,
            'manifest_filepath': args.dataset,
            'labels': asr_model.decoder.vocabulary,
            'batch_size': args.batch_size,
            'normalize_transcripts': args.dont_normalize_text,
            'shuffle': True,
        })
    asr_model.preprocessor.featurizer.dither = 0.0
    asr_model.preprocessor.featurizer.pad_to = 0
    if can_gpu:
        asr_model = asr_model.cuda()
    asr_model.eval()

    # Enable calibrators
    for name, module in asr_model.named_modules():
        if isinstance(module, quant_nn.TensorQuantizer):
            if module._calibrator is not None:
                module.disable_quant()
                module.enable_calib()
            else:
                module.disable()

    # Feed exactly num_calib_batch batches (the check precedes the forward pass, so no extra
    # batch is consumed).
    for i, test_batch in enumerate(asr_model.test_dataloader()):
        if i >= args.num_calib_batch:
            break
        if can_gpu:
            test_batch = [x.cuda() for x in test_batch]
        if args.amp:
            with autocast():
                _ = asr_model(input_signal=test_batch[0],
                              input_signal_length=test_batch[1])
        else:
            _ = asr_model(input_signal=test_batch[0],
                          input_signal_length=test_batch[1])

    # Save calibrated model(s). Strip only a trailing ".nemo", not every occurrence.
    if args.asr_model.endswith(".nemo"):
        model_name = args.asr_model[:-len(".nemo")]
    else:
        model_name = args.asr_model
    num_calib_samples = args.num_calib_batch * args.batch_size

    if args.calibrator != "histogram":
        compute_amax(asr_model, method="max")
        asr_model.save_to(F"{model_name}-max-{num_calib_samples}.nemo")
    else:
        for percentile in args.percentile:
            print(F"{percentile} percentile calibration")
            # Pass the percentile through, otherwise every saved model uses the same default
            # cut-off. Assumes compute_amax forwards kwargs to the histogram calibrator.
            compute_amax(asr_model, method="percentile", percentile=percentile)
            asr_model.save_to(
                F"{model_name}-percentile-{percentile}-{num_calib_samples}.nemo")

        for method in ["mse", "entropy"]:
            print(F"{method} calibration")
            compute_amax(asr_model, method=method)
            asr_model.save_to(F"{model_name}-{method}-{num_calib_samples}.nemo")
Пример #8
0
def _parse_inference_args(argv=None):
    """Parse the quantized-inference command line (``sys.argv[1:]`` when ``argv`` is None)."""
    parser = ArgumentParser()
    parser.add_argument(
        "--asr_model",
        type=str,
        default="QuartzNet15x5Base-En",
        required=True,
        help="Pass: '******'",
    )
    parser.add_argument("--dataset",
                        type=str,
                        required=True,
                        help="path to evaluation data")
    parser.add_argument("--wer_target",
                        type=float,
                        default=None,
                        help="used by test")
    parser.add_argument("--batch_size", type=int, default=4)
    parser.add_argument("--wer_tolerance",
                        type=float,
                        default=1.0,
                        help="used by test")
    # store_false needs a True default, otherwise the flag is a no-op.
    # args.dont_normalize_text is True when transcripts SHOULD be normalized.
    parser.add_argument(
        "--dont_normalize_text",
        default=True,
        action='store_false',
        help="Turn off transcript normalization. Recommended for non-English.",
    )
    parser.add_argument(
        "--use_cer",
        default=False,
        action='store_true',
        help="Use Character Error Rate as the evaluation metric")
    parser.add_argument('--sensitivity',
                        action="store_true",
                        help="Perform sensitivity analysis")
    parser.add_argument('--onnx', action="store_true", help="Export to ONNX")
    parser.add_argument('--quant-disable-keyword',
                        type=str,
                        nargs='+',
                        help='disable quantizers by keyword')
    return parser.parse_args(argv)


def _disable_quantizers_by_keyword(asr_model, keywords):
    """Disable every TensorQuantizer whose module name contains any of ``keywords``.

    ``keywords`` may be None or empty, in which case nothing is changed.
    """
    if not keywords:
        return
    for name, module in asr_model.named_modules():
        if isinstance(module, quant_nn.TensorQuantizer):
            for keyword in keywords:
                if keyword in name:
                    logging.warning(F"Disable {name}")
                    module.disable()


def _set_layer_quantizers(asr_model, quant_layer, enabled):
    """Enable or disable each TensorQuantizer whose name contains ``quant_layer``, logging it."""
    for name, module in asr_model.named_modules():
        if isinstance(module, quant_nn.TensorQuantizer) and quant_layer in name:
            if enabled:
                module.enable()
            else:
                module.disable()
            logging.info(F"{name:40}: {module}")


def _run_sensitivity_analysis(asr_model, labels_map, wer, args):
    """Choose quantized layers to leave unquantized until WER <= ``args.wer_tolerance``.

    Each layer is quantized on its own and scored as ``wer_tolerance - WER``. Starting from
    the fully quantized model (minus user-disabled quantizers), layers are then disabled from
    the lowest score (highest WER) upward until the tolerance is met.

    Returns:
        The list of layer names left unquantized.

    Raises:
        ValueError: if the tolerance cannot be met even after disabling every profiled layer.
    """
    quant_layer_names = []
    for name, module in asr_model.named_modules():
        if isinstance(module, quant_nn.TensorQuantizer):
            module.disable()
            layer_name = name.replace("._input_quantizer",
                                      "").replace("._weight_quantizer", "")
            if layer_name not in quant_layer_names:
                quant_layer_names.append(layer_name)
    logging.info(F"{len(quant_layer_names)} quantized layers found.")

    # Build sensitivity profile: quantize one layer at a time and measure WER.
    quant_layer_sensitivity = {}
    for quant_layer in quant_layer_names:
        logging.info(F"Enable {quant_layer}")
        _set_layer_quantizers(asr_model, quant_layer, enabled=True)

        wer_value = evaluate(asr_model, labels_map, wer)
        logging.info(F"WER: {wer_value}")
        quant_layer_sensitivity[quant_layer] = args.wer_tolerance - wer_value

        _set_layer_quantizers(asr_model, quant_layer, enabled=False)

    # Re-enable all quantizers, then re-apply the user's --quant-disable-keyword choice so
    # the search below does not silently quantize layers the user excluded.
    for _, module in asr_model.named_modules():
        if isinstance(module, quant_nn.TensorQuantizer):
            module.enable()
    _disable_quantizers_by_keyword(asr_model, args.quant_disable_keyword)

    # Skip most sensitive layers until WER target is met
    quant_layer_sensitivity = collections.OrderedDict(
        sorted(quant_layer_sensitivity.items(), key=lambda x: x[1]))
    pprint(quant_layer_sensitivity)
    skipped_layers = []
    for quant_layer in quant_layer_sensitivity:
        for name, module in asr_model.named_modules():
            if isinstance(module, quant_nn.TensorQuantizer) and quant_layer in name:
                logging.info(F"Disable {name}")
                if quant_layer not in skipped_layers:
                    skipped_layers.append(quant_layer)
                module.disable()
        wer_value = evaluate(asr_model, labels_map, wer)
        if wer_value <= args.wer_tolerance:
            logging.info(
                F"WER tolerance {args.wer_tolerance} is met by skipping {len(skipped_layers)} sensitive layers."
            )
            return skipped_layers
    raise ValueError(
        f"WER tolerance {args.wer_tolerance} can not be met with any layer quantized!"
    )


def main():
    """Evaluate a quantized CTC BPE ASR model, optionally run sensitivity analysis, then export.

    The model is loaded with ``encoder.quantize = True``. Quantizers matching
    ``--quant-disable-keyword`` are disabled, and the WER (or CER) is measured. With
    ``--sensitivity`` and the tolerance not yet met, the most sensitive layers are left
    unquantized until it is. ``export_onnx`` runs in every successful path.
    """
    args = _parse_inference_args()
    torch.set_grad_enabled(False)

    quant_modules.initialize()

    if args.asr_model.endswith('.nemo'):
        logging.info(f"Using local ASR model from {args.asr_model}")
        asr_model_cfg = EncDecCTCModelBPE.restore_from(
            restore_path=args.asr_model, return_config=True)
        with open_dict(asr_model_cfg):
            asr_model_cfg.encoder.quantize = True
        asr_model = EncDecCTCModelBPE.restore_from(
            restore_path=args.asr_model, override_config_path=asr_model_cfg)

    else:
        logging.info(f"Using NGC cloud ASR model {args.asr_model}")
        asr_model_cfg = EncDecCTCModelBPE.from_pretrained(
            model_name=args.asr_model, return_config=True)
        with open_dict(asr_model_cfg):
            asr_model_cfg.encoder.quantize = True
        asr_model = EncDecCTCModelBPE.from_pretrained(
            model_name=args.asr_model, override_config_path=asr_model_cfg)
    asr_model.setup_test_data(
        test_data_config={
            'sample_rate': 16000,
            'manifest_filepath': args.dataset,
            'labels': asr_model.decoder.vocabulary,
            'batch_size': args.batch_size,
            'normalize_transcripts': args.dont_normalize_text,
        })
    asr_model.preprocessor.featurizer.dither = 0.0
    asr_model.preprocessor.featurizer.pad_to = 0
    if can_gpu:
        asr_model = asr_model.cuda()
    asr_model.eval()

    _disable_quantizers_by_keyword(asr_model, args.quant_disable_keyword)

    labels_map = dict(enumerate(asr_model.decoder.vocabulary))
    wer = WER(vocabulary=asr_model.decoder.vocabulary, use_cer=args.use_cer)
    wer_quant = evaluate(asr_model, labels_map, wer)
    logging.info(f'Got WER of {wer_quant}. Tolerance was {args.wer_tolerance}')

    if args.sensitivity:
        # Same `<=` criterion as the search itself; when already met, still export below.
        if wer_quant <= args.wer_tolerance:
            logging.info("Tolerance is already met. Skip sensitivity analysis.")
        else:
            skipped_layers = _run_sensitivity_analysis(asr_model, labels_map, wer, args)
            print(skipped_layers)

    export_onnx(args, asr_model)