コード例 #1
0
ファイル: test_asr_exportables.py プロジェクト: NVIDIA/NeMo
    def test_EncDecCTCModel_adapted_export_to_onnx(self):
        """Export an adapter-enabled EncDecCTCModel to ONNX and validate the graph.

        The encoder class named in ``self.encoder_dict`` is replaced by its
        ``...Adapter`` variant. A linear adapter named ``'temp'`` is then
        attached, and the model is moved to CUDA. It is exported with
        ``check_trace=True`` into a temporary directory. The saved file must
        pass ``onnx.checker`` and expose ``audio_signal`` as its first graph
        input and ``logprobs`` as its first graph output.

        NOTE(review): ``model.cuda()`` is called unconditionally, so this test
        needs a CUDA device; any GPU-only marker lives outside this view.
        """
        sections = {
            'preprocessor': self.preprocessor,
            'encoder': self.encoder_dict,
            'decoder': self.decoder_dict,
        }
        model_config = DictConfig({key: DictConfig(value) for key, value in sections.items()})

        # Switch to the adapter-capable encoder class (e.g. ConvASREncoderAdapter).
        model_config.encoder.cls += 'Adapter'

        model = EncDecCTCModel(cfg=model_config)

        # Size the adapter's input to the filter count of the first Jasper block.
        first_block_filters = model_config.encoder.params.jasper[0].filters
        linear_adapter = OmegaConf.structured(LinearAdapterConfig(in_features=first_block_filters, dim=32))
        model.add_adapter('temp', cfg=linear_adapter)

        model = model.cuda()

        with tempfile.TemporaryDirectory() as export_dir:
            onnx_path = os.path.join(export_dir, 'qn.onnx')
            model.export(output=onnx_path, check_trace=True)

            exported = onnx.load(onnx_path)
            # Raises if the exported model fails validation.
            onnx.checker.check_model(exported, full_check=True)

            graph = exported.graph
            assert graph.input[0].name == 'audio_signal'
            assert graph.output[0].name == 'logprobs'
コード例 #2
0
 def test_EncDecCTCModel_export_to_onnx(self):
     """Export a plain EncDecCTCModel to ONNX and validate the graph.

     The model is built from the preprocessor, encoder and decoder configs held
     on ``self`` and exported with default options into a temporary directory.
     The saved file must pass ``onnx.checker`` and expose ``audio_signal`` as
     its first graph input and ``logprobs`` as its first graph output.
     """
     sub_configs = {
         'preprocessor': self.preprocessor,
         'encoder': self.encoder_dict,
         'decoder': self.decoder_dict,
     }
     full_config = DictConfig({name: DictConfig(section) for name, section in sub_configs.items()})
     model = EncDecCTCModel(cfg=full_config)

     with tempfile.TemporaryDirectory() as workdir:
         target = os.path.join(workdir, 'qn.onnx')
         model.export(output=target)

         loaded = onnx.load(target)
         # Raises if the exported model fails validation.
         onnx.checker.check_model(loaded, full_check=True)

         graph = loaded.graph
         assert graph.input[0].name == 'audio_signal'
         assert graph.output[0].name == 'logprobs'