def test_EncDecCTCModel_adapted_export_to_onnx(self):
    """Export an adapter-augmented EncDecCTCModel to ONNX and validate it.

    The model config is assembled from the fixture attributes
    ``self.preprocessor``, ``self.encoder_dict`` and ``self.decoder_dict``.
    The encoder class name gets an ``'Adapter'`` suffix, a linear adapter
    named ``'temp'`` is added, the model is moved to CUDA and exported to a
    temporary ``qn.onnx`` file. The exported graph must pass the ONNX
    checker and expose the expected first input / first output names.
    """
    sections = {
        'preprocessor': DictConfig(self.preprocessor),
        'encoder': DictConfig(self.encoder_dict),
        'decoder': DictConfig(self.decoder_dict),
    }
    cfg = DictConfig(sections)

    # Switch the encoder to its adapter-capable variant
    # (e.g. ConvASREncoderAdapter).
    cfg.encoder.cls = cfg.encoder.cls + 'Adapter'

    # Instantiate the model from the patched config.
    asr_model = EncDecCTCModel(cfg=cfg)

    # Attach a linear adapter sized from the first jasper block's filters.
    first_block_filters = cfg.encoder.params.jasper[0].filters
    linear_adapter = LinearAdapterConfig(in_features=first_block_filters, dim=32)
    adapter_conf = OmegaConf.structured(linear_adapter)
    asr_model.add_adapter('temp', cfg=adapter_conf)

    asr_model = asr_model.cuda()

    with tempfile.TemporaryDirectory() as export_dir:
        onnx_path = os.path.join(export_dir, 'qn.onnx')
        asr_model.export(output=onnx_path, check_trace=True)

        exported = onnx.load(onnx_path)
        # check_model throws when validation fails.
        onnx.checker.check_model(exported, full_check=True)

        graph = exported.graph
        assert graph.input[0].name == 'audio_signal'
        assert graph.output[0].name == 'logprobs'
def generate_ref_hyps(asr_model: EncDecCTCModel, search: str, arpa: str):
    """Yield ``(reference, hypothesis)`` pairs for the model's test set.

    The model is put in eval mode (and moved to the GPU when the
    module-level ``can_gpu`` flag is truthy), then every batch from
    ``asr_model.test_dataloader()`` is run through the model and decoded.

    Args:
        asr_model: CTC model to evaluate; its ``decoder.vocabulary`` is used
            to map label ids back to characters for the reference text.
        search: Decoding strategy. ``"greedy"`` decodes the greedy
            predictions with ``WER.ctc_decoder_predictions_tensor``;
            ``"beamsearch"`` uses ``BeamSearchDecoderWithLM`` without a
            language model; ``"kenlm"`` uses the same decoder with the
            prepared ARPA file as ``lm_path``.
        arpa: Argument forwarded to ``prepare_arpa_file`` for the
            ``"kenlm"`` and ``"beamsearch"`` modes; unused for ``"greedy"``.

    Yields:
        Tuples ``(reference, hyp)`` where ``reference`` is the transcript
        rebuilt from label ids and ``hyp`` is one decoded hypothesis.

    Raises:
        ValueError: If ``search`` is not one of the supported strategies.
            Raised on first iteration, since this is a generator.
    """
    supported = ("greedy", "kenlm", "beamsearch")
    if search not in supported:
        # Previously an unknown value fell through to the beam-search branch
        # and failed on an unbound ``beamsearcher`` after a full forward pass.
        raise ValueError(
            f"Unsupported search {search!r}; expected one of {supported}"
        )

    if can_gpu:
        asr_model = asr_model.cuda()
        print("USING GPU!")
    asr_model.eval()

    vocabulary = asr_model.decoder.vocabulary
    labels_map = dict(enumerate(vocabulary))
    wer = WER(vocabulary=vocabulary)

    beamsearcher = None
    if search in ("kenlm", "beamsearch"):
        arpa_file = prepare_arpa_file(arpa)
        lm_path = arpa_file if search == "kenlm" else None
        beamsearcher = nemo_asr.modules.BeamSearchDecoderWithLM(
            vocab=list(vocabulary),
            beam_width=16,
            alpha=2,
            beta=1.5,
            lm_path=lm_path,
            # os.cpu_count() may return None when the count is undetermined;
            # max(None, 1) would raise TypeError, so fall back to 1 instead.
            num_cpus=os.cpu_count() or 1,
            input_tensor=True,
        )

    # TODO(tilo): test_loader should return dict or some typed object
    # not tuple of tensors!!
    for batch in asr_model.test_dataloader():
        if can_gpu:
            batch = [x.cuda() for x in batch]
        input_signal, inpsig_len, transcript, transc_len = batch

        with autocast():
            log_probs, encoded_len, greedy_predictions = asr_model(
                input_signal=input_signal, input_signal_length=inpsig_len
            )

        if search == "greedy":
            decoded = wer.ctc_decoder_predictions_tensor(greedy_predictions)
        else:
            decoded = beamsearch_forward(
                beamsearcher, log_probs=log_probs, log_probs_length=encoded_len
            )

        for i, hyp in enumerate(decoded):
            # Trim the padded transcript to its true length before mapping
            # label ids back to characters.
            label_ids = transcript[i].cpu().detach().numpy()[:transc_len[i]]
            reference = "".join(labels_map[c] for c in label_ids)
            yield reference, hyp