def test_export_to_onnx(self):
    """Export a half-precision CUDA WaveGlow via ``model.export`` into a temporary directory.

    Typechecking is disabled for the duration of the test and re-enabled afterwards,
    so the global flag does not leak into other tests.
    """
    model = WaveGlowModel(wcfg)
    model = model.cuda().half()
    typecheck.set_typecheck_enabled(enabled=False)
    try:
        with tempfile.TemporaryDirectory() as tmpdir, model.nemo_infer():
            # Generate filename in the temporary directory (previously the tmpdir was
            # ignored and the file landed in the current working directory).
            # TODO: Change `waveglow.ts` to `waveglow.onnx` for > 21.05
            tmp_file_name = os.path.join(tmpdir, "waveglow.ts")
            n_mels = 80
            # Test export.
            inp = input_example(n_mels)
            inp1 = taco2wg(*inp)
            inp2 = inp1
            # Same inputs twice: this checks that waveglow inference is repeatable.
            res1 = model.waveglow(*inp1)
            res2 = model.waveglow(*inp2)
            assert torch.allclose(res1, res2, rtol=0.01, atol=0.1)
            WaveGlowModel.forward_for_export = forward_wrapper
            model.export(
                tmp_file_name,
                verbose=True,
                input_example=inp,
                output_example=res1,
                try_script=False,
                check_trace=False,
                do_constant_folding=True,
                dynamic_axes={"spec": [0], "z": [0], "audio": [0]},
            )
    finally:
        typecheck.set_typecheck_enabled(enabled=True)
def test_export_to_onnx(self):
    """Export a half-precision CUDA WaveGlow to ``waveglow.onnx`` inside a temp directory.

    Typechecking is disabled while the test runs and always re-enabled afterwards.
    """
    model = WaveGlowModel(wcfg)
    model = model.cuda().half()
    typecheck.set_typecheck_enabled(enabled=False)
    try:
        with tempfile.TemporaryDirectory() as tmpdir, model.nemo_infer():
            tmp_file_name = os.path.join(tmpdir, "waveglow.onnx")
            n_mels = 80
            # Test export.
            inp = input_example(n_mels)
            inp1 = taco2wg(*inp)
            inp2 = inp1
            # Same inputs twice: this checks that waveglow inference is repeatable.
            res1 = model.waveglow(*inp1)
            res2 = model.waveglow(*inp2)
            assert torch.allclose(res1, res2, rtol=0.01, atol=0.1)
            WaveGlowModel.forward_for_export = forward_wrapper
            model.export(
                tmp_file_name,
                verbose=False,
                input_example=inp,
                output_example=res1,
                try_script=False,
                check_trace=False,
                do_constant_folding=True,
            )
    finally:
        typecheck.set_typecheck_enabled(enabled=True)
def test_get_model(self):
    """Load the pretrained 345M uncased BERT and run ``forward`` on its own input example."""
    model = get_pretrained_bert_345m_uncased_model()
    assert isinstance(model, nemo_nlp.modules.MegatronBertEncoder)
    typecheck.set_typecheck_enabled(enabled=False)
    try:
        inp = model.input_example()
        model.forward(*inp)
    finally:
        # Restore the global flag even when forward() raises.
        typecheck.set_typecheck_enabled(enabled=True)
def __init__(self, cfg: DictConfig, trainer: 'Trainer' = None):
    """Create the vocabulary, character embedding, inner model and a 1-channel projection.

    Side effect: globally disables NeMo typechecking.
    """
    super().__init__(cfg=cfg, trainer=trainer)
    typecheck.set_typecheck_enabled(enabled=False)

    model_cfg = self._cfg
    vocab_cfg = model_cfg.train_ds.dataset.vocab
    self.vocab = AudioToCharWithDursF0Dataset.make_vocab(**vocab_cfg)
    self.embed = nn.Embedding(len(self.vocab.labels), model_cfg.d_char)
    self.model = instantiate(model_cfg.model)
    # Input channels of the projection come from the last `jasper` entry's `filters`.
    last_filters = model_cfg.model.jasper[-1].filters
    self.proj = nn.Conv1d(last_filters, 1, kernel_size=1)
def test_export_to_onnx(self):
    """Export WaveGlow to ONNX in a temp dir and, if onnxruntime is installed, run the graph.

    The ONNX output is compared (with loose tolerances) against the PyTorch result.
    Typechecking is disabled during the test and always re-enabled afterwards.
    """
    model = WaveGlowModel(wcfg)
    model = model.cuda().half()
    typecheck.set_typecheck_enabled(enabled=False)
    try:
        with tempfile.TemporaryDirectory() as tmpdir, model.nemo_infer():
            # Generate filename in the temporary directory.
            tmp_file_name = os.path.join(tmpdir, "waveglow.onnx")
            n_mels = 80
            # Test export.
            inp = input_example(n_mels)
            inp1 = taco2wg(**inp)
            inp2 = inp1
            # Same inputs twice: this checks that waveglow inference is repeatable.
            res1 = model.waveglow(*inp1)
            res2 = model.waveglow(*inp2)
            assert torch.allclose(res1, res2, rtol=0.01, atol=0.1)
            model.export(
                tmp_file_name,
                verbose=True,
                input_example=inp,
                output_example=res1,
                try_script=False,
                check_trace=False,
                do_constant_folding=True,
                dynamic_axes={"spec": [0], "z": [0], "audio": [0]},
                forward_method=forward_wrapper,
            )

            # The runtime check must stay inside the `with`: the exported file lives in tmpdir.
            try:
                import onnxruntime
            except ImportError:  # ModuleNotFoundError is a subclass of ImportError
                onnxruntime = None

            if onnxruntime is not None:
                omodel = onnx.load(tmp_file_name)
                sess = onnxruntime.InferenceSession(omodel.SerializeToString())
                output = sess.run(
                    None, {"spec": inp["spec"].cpu().numpy(), "z": inp["z"].cpu().numpy()}
                )[0]
                assert torch.allclose(torch.from_numpy(output), res2.cpu(), rtol=1, atol=100)
    finally:
        typecheck.set_typecheck_enabled(enabled=True)
def __init__(self, cfg: DictConfig, trainer: 'Trainer' = None):
    """Build vocabulary, preprocessor, Gaussian embedding, f0 normalization/residual,
    inner model, and the final projection to ``n_mels`` channels.

    Side effect: globally disables NeMo typechecking.
    """
    super().__init__(cfg=cfg, trainer=trainer)
    typecheck.set_typecheck_enabled(enabled=False)

    conf = self._cfg
    dataset_conf = conf.train_ds.dataset
    self.vocab = AudioToCharWithDursF0Dataset.make_vocab(**dataset_conf.vocab)
    self.blanking = dataset_conf.blanking
    self.preprocessor = instantiate(conf.preprocessor)

    self.embed = GaussianEmbedding(self.vocab, conf.d_char)
    self.norm_f0 = MaskedInstanceNorm1d(1)
    self.res_f0 = StyleResidual(conf.d_char, 1, kernel_size=3)

    self.model = instantiate(conf.model)
    # Input channels of the projection come from the last `jasper` entry's `filters`.
    last_filters = conf.model.jasper[-1].filters
    self.proj = nn.Conv1d(last_filters, conf.n_mels, kernel_size=1)
def export(
    self,
    output: str,
    input_example=None,
    verbose=False,
    export_params=True,
    do_constant_folding=True,
    keep_initializers_as_inputs=False,
    onnx_opset_version=None,
    try_script: bool = False,
    set_eval: bool = True,
    check_trace: bool = False,
    use_dynamic_axes: bool = True,
    dynamic_axes=None,
    check_tolerance=0.01,
):
    """Export this module to TorchScript or ONNX; the format is resolved from ``output``.

    Process:
      * Every ``Exportable`` in ``self.modules()`` is prepared first.
      * The (possibly overridden) forward is run on the example input.
      * The model is then traced or scripted for TorchScript, or passed to
        ``torch.onnx.export``.
      * ONNX output is reloaded and validated with ``onnx.checker``.
      * When ``check_trace`` is set, the ONNX output is also compared against PyTorch
        via ``_verify_runtime``.

    Args:
        output: destination path; its format comes from ``self.get_format``.
        input_example: example inputs; ``self._get_input_example()`` is used when None.
        verbose: TorchScript - log the generated code; ONNX - forwarded to torch.onnx.export.
        export_params, do_constant_folding, keep_initializers_as_inputs: forwarded to
            ``torch.onnx.export``.
        onnx_opset_version: ONNX opset; None means 13.
        try_script: try ``torch.jit.script`` before falling back to tracing.
        set_eval: call ``self.eval()`` before exporting.
        check_trace: TorchScript - forwarded to ``trace_module``; ONNX - run ``_verify_runtime``.
        use_dynamic_axes: derive dynamic axes from input/output modules when
            ``dynamic_axes`` is None.
        dynamic_axes: explicit ONNX dynamic-axes mapping.
        check_tolerance: tolerance for trace checking / runtime verification.

    Returns:
        ``([output], [description])``.

    Raises:
        ValueError: if the resolved format is neither TorchScript nor ONNX.
    """
    # Snapshot call arguments. dict() copies: on CPython < 3.13 locals() returns the
    # live frame dict, which a tracer/debugger can refresh with later locals.
    my_args = dict(locals())
    my_args.pop('self')

    # NOTE(review): nn.Module.modules() yields self first, so when self is Exportable
    # its _prepare_for_export runs twice below - confirm this is intended.
    exportables = [m for m in self.modules() if isinstance(m, Exportable)]

    qual_name = self.__module__ + '.' + self.__class__.__qualname__
    format = self.get_format(output)
    output_descr = f"{qual_name} exported to {format}"

    # Pytorch's default for None is too low, can't pass through
    if onnx_opset_version is None:
        onnx_opset_version = 13

    # Pre-bound so `finally` is safe even if _wrap_forward_method raises.
    forward_method = old_forward_method = None
    try:
        # Disable typechecks
        typecheck.set_typecheck_enabled(enabled=False)

        # Allow user to completely override forward method to export
        forward_method, old_forward_method = self._wrap_forward_method()

        # Set module to eval mode
        if set_eval:
            self.eval()

        if input_example is None:
            input_example = self._get_input_example()

        # Remove i/o examples from args we propagate to enclosed Exportables
        my_args.pop('output')
        my_args.pop('input_example')

        # Run (possibly overridden) prepare methods before calling forward()
        for ex in exportables:
            ex._prepare_for_export(**my_args)
        self._prepare_for_export(output=output, input_example=input_example, **my_args)

        input_list, input_dict = self._setup_input_example(input_example)

        input_names = self._process_input_names()
        output_names = self._process_output_names()

        output_example = self.forward(*input_list, **input_dict)
        # A single tensor must become a 1-tuple: tuple(tensor) would split it along dim 0.
        if isinstance(output_example, torch.Tensor):
            output_example = (output_example,)
        else:
            output_example = tuple(output_example)

        with torch.jit.optimized_execution(True), torch.no_grad():
            jitted_model = None
            if try_script:
                try:
                    jitted_model = torch.jit.script(self)
                except Exception as e:
                    logging.error(f"jit.script() failed!\n{e}")

            if format == ExportFormat.TORCHSCRIPT:
                if jitted_model is None:
                    jitted_model = torch.jit.trace_module(
                        self,
                        {"forward": tuple(input_list) + tuple(input_dict.values())},
                        strict=False,
                        check_trace=check_trace,
                        check_tolerance=check_tolerance,
                    )
                if verbose:
                    logging.info(f"JIT code:\n{jitted_model.code}")
                jitted_model.save(output)
                assert os.path.exists(output)

            elif format == ExportFormat.ONNX:
                if jitted_model is None:
                    jitted_model = self

                # dynamic axis is a mapping from input/output_name => list of "dynamic" indices
                if dynamic_axes is None and use_dynamic_axes:
                    dynamic_axes = get_input_dynamic_axes(self.input_module, input_names)
                    dynamic_axes = {
                        **dynamic_axes,
                        **get_output_dynamic_axes(self.output_module, output_names),
                    }

                torch.onnx.export(
                    jitted_model,
                    input_example,
                    output,
                    input_names=input_names,
                    output_names=output_names,
                    verbose=verbose,
                    export_params=export_params,
                    do_constant_folding=do_constant_folding,
                    keep_initializers_as_inputs=keep_initializers_as_inputs,
                    dynamic_axes=dynamic_axes,
                    opset_version=onnx_opset_version,
                )

                # Verify the model can be read, and is valid
                onnx_model = onnx.load(output)
                onnx.checker.check_model(onnx_model, full_check=True)

                if check_trace:
                    self._verify_runtime(
                        onnx_model,
                        input_list,
                        input_dict,
                        input_names,
                        output_names,
                        output_example,
                        output,
                        check_tolerance,
                    )
            else:
                raise ValueError(f'Encountered unknown export format {format}.')
    finally:
        typecheck.set_typecheck_enabled(enabled=True)
        if forward_method:
            type(self).forward = old_forward_method
        self._export_teardown()
    return ([output], [output_descr])
def export(
    self,
    output: str,
    input_example=None,
    output_example=None,
    verbose=False,
    export_params=True,
    do_constant_folding=True,
    keep_initializers_as_inputs=False,
    onnx_opset_version: int = 12,
    try_script: bool = False,
    set_eval: bool = True,
    check_trace: bool = True,
    use_dynamic_axes: bool = True,
    dynamic_axes=None,
    check_tolerance=0.01,
    forward_method=None,
):
    """Export this module to TorchScript or ONNX; the format is resolved from ``output``.

    ``forward_method`` (or, if absent, a class-level ``forward_for_export``)
    temporarily replaces ``type(self).forward`` and is restored afterwards.

    For ONNX:
      * The graph is validated with ``onnx.checker``.
      * When ``do_constant_folding`` is set and onnx-graphsurgeon is available, the graph
        is constant-folded, cleaned up, re-checked and saved back to ``output``.

    Returns:
        ``([output], [description])``.

    Raises:
        ValueError: if the resolved format is neither TorchScript nor ONNX.
    """
    qual_name = self.__module__ + '.' + self.__class__.__qualname__
    output_descr = qual_name + ' exported to ONNX'
    exported = ([output], [output_descr])
    # Set only once forward is actually swapped, so `finally` restores exactly that case.
    old_forward_method = None
    try:
        # Disable typechecks
        typecheck.set_typecheck_enabled(enabled=False)

        # Allow user to completely override forward method to export
        if forward_method is None and hasattr(type(self), "forward_for_export"):
            forward_method = type(self).forward_for_export

        if forward_method:
            old_forward_method = type(self).forward
            type(self).forward = forward_method

        # Set module to eval mode
        if set_eval:
            self.eval()

        format = self.get_format(output)
        self._prepare_for_export()

        with torch.jit.optimized_execution(True):
            jitted_model = None
            if try_script:
                try:
                    jitted_model = torch.jit.script(self)
                except Exception as e:
                    logging.error(f"jit.script() failed!\n{e}")

        if input_example is None:
            input_example = self.input_module.input_example()

        with torch.jit.optimized_execution(True):
            if format == ExportFormat.TORCHSCRIPT:
                if isinstance(input_example, Dict):
                    input_example = tuple(input_example.values())
                if jitted_model is None:
                    jitted_model = torch.jit.trace(
                        self,
                        input_example,
                        strict=False,
                        optimize=True,
                        check_trace=check_trace,
                        check_tolerance=check_tolerance,
                    )
                jitted_model.save(output)
                assert os.path.exists(output)

            elif format == ExportFormat.ONNX:
                if jitted_model is None:
                    jitted_model = self
                if output_example is None:
                    if isinstance(input_example, tuple):
                        output_example = self.forward(*input_example)
                    elif isinstance(input_example, Dict):
                        # Same positional order that torch.onnx.export receives below;
                        # passing the dict itself as one argument was a bug.
                        output_example = self.forward(*input_example.values())
                    else:
                        output_example = self.forward(input_example)

                input_names = self.input_module.get_input_names(input_example)
                output_names = self.output_module.get_output_names(output_example)

                # dynamic axis is a mapping from input/output_name => list of "dynamic" indices
                if dynamic_axes is None and use_dynamic_axes:
                    dynamic_axes = self.input_module.get_input_dynamic_axes(input_names)
                    dynamic_axes = {**dynamic_axes, **self.output_module.get_output_dynamic_axes(output_names)}

                if isinstance(input_example, Dict):
                    input_example = tuple(input_example.values())

                torch.onnx.export(
                    jitted_model,
                    input_example,
                    output,
                    input_names=input_names,
                    output_names=output_names,
                    verbose=verbose,
                    export_params=export_params,
                    do_constant_folding=do_constant_folding,
                    keep_initializers_as_inputs=keep_initializers_as_inputs,
                    dynamic_axes=dynamic_axes,
                    opset_version=onnx_opset_version,
                    example_outputs=output_example,
                )

                # Verify the model can be read, and is valid
                onnx_model = onnx.load(output)
                onnx.checker.check_model(onnx_model, full_check=True)

                if do_constant_folding:
                    if not ONNX_GRAPHSURGEON_AVAILABLE:
                        logging.info(
                            "onnx-graphsurgeon module is not installed. "
                            "That may result in suboptimal optimization of exported ONNX graph "
                            "(including unneeded DOUBLE initializers). "
                            "Please follow the instructions available at: "
                            "https://github.com/NVIDIA/TensorRT/tree/master/tools/onnx-graphsurgeon "
                            "to install onnx-graphsurgeon from source to improve exported graph."
                        )
                    else:
                        # This pass is to remove/recast certain constants that are generated as 'double'
                        # Those constants break ONNX -> TRT conversion (TRT does not support 'double' as of 7.2)
                        # Can probably be removed once TRT has automatic downcast for double.
                        # However, it may still be useful even then as it seems to always make the graph shorter.
                        graph = gs.import_onnx(onnx_model)
                        onnx_model = gs.export_onnx(graph.fold_constants().cleanup())
                        onnx.checker.check_model(onnx_model, full_check=True)
                        onnx.save(onnx_model, output)
            else:
                raise ValueError(f'Encountered unknown export format {format}.')
    finally:
        typecheck.set_typecheck_enabled(enabled=True)
        if old_forward_method is not None:
            type(self).forward = old_forward_method
    return exported
def export(
    self,
    output: str,
    input_example=None,
    output_example=None,
    verbose=False,
    export_params=True,
    do_constant_folding=True,
    keep_initializers_as_inputs=False,
    onnx_opset_version: int = 12,
    try_script: bool = False,
    set_eval: bool = True,
    check_trace: bool = True,
    use_dynamic_axes: bool = True,
):
    """Export this module to TorchScript or ONNX, picking the format from the file extension.

    Input/output names come from ``input_types``/``output_types``, minus the
    ``disabled_deployment_*_names``. Dynamic axes are extracted per port when
    ``use_dynamic_axes`` is set. The model is scripted (if ``try_script``) or traced,
    then either saved as TorchScript or passed to ``torch.onnx.export``.

    Returns:
        The loaded ONNX model for ONNX exports; ``[output]`` for TorchScript.

    Raises:
        ValueError: unknown extension/format, or no example input when tracing is needed.
        NotImplementedError: ``input_types``/``output_types`` are not defined.
    """
    try:
        # Disable typechecks
        typecheck.set_typecheck_enabled(enabled=False)

        # Set module to eval mode
        if set_eval:
            self.eval()

        _, file_extension = os.path.splitext(output)
        if file_extension not in _EXT_DICT:
            raise ValueError(f"Export file {output} extension does not correspond to any export format!")
        format = _EXT_DICT[file_extension]

        self._prepare_for_export()

        _in_example = input_example if input_example is not None else self.input_example()
        # Flatten dict examples once, so tracing, forward() and torch.onnx.export all see
        # the same positional tuple (forward(*dict) would have unpacked the keys).
        if isinstance(_in_example, Dict):
            _in_example = tuple(_in_example.values())
        # Honour a caller-supplied output example; previously _out_example stayed
        # unbound in that case and raised UnboundLocalError.
        _out_example = output_example

        if not (hasattr(self, 'input_types') and hasattr(self, 'output_types')):
            raise NotImplementedError('For export to work you must define input and output types')
        input_names = list(self.input_types.keys())
        output_names = list(self.output_types.keys())
        # dynamic axis is a mapping from input/output_name => list of "dynamic" indices
        dynamic_axes = {}

        # extract dynamic axes and remove unnecessary inputs/outputs
        # for input_ports
        for _name, ntype in self.input_types.items():
            if _name in self.disabled_deployment_input_names:
                input_names.remove(_name)
                continue
            if use_dynamic_axes:
                dynamic_axes.update(self._extract_dynamic_axes(_name, ntype))
        # for output_ports
        for _name, ntype in self.output_types.items():
            if _name in self.disabled_deployment_output_names:
                output_names.remove(_name)
                continue
            if use_dynamic_axes:
                dynamic_axes.update(self._extract_dynamic_axes(_name, ntype))

        if not dynamic_axes:
            dynamic_axes = None

        with torch.jit.optimized_execution(True):
            jitted_model = None
            if try_script:
                try:
                    jitted_model = torch.jit.script(self)
                except Exception as e:
                    print("jit.script() failed!", e)

            if jitted_model is None:
                if _in_example is None:
                    raise ValueError('Example input is None, but jit.script() has failed or not tried')
                jitted_model = torch.jit.trace(self, _in_example, check_trace=check_trace)

            if format == ExportFormat.TORCHSCRIPT:
                jitted_model.save(output)
                assert os.path.exists(output)

            elif format == ExportFormat.ONNX:
                if _out_example is None:
                    if isinstance(_in_example, tuple):
                        _out_example = self.forward(*_in_example)
                    else:
                        _out_example = self.forward(_in_example)

                torch.onnx.export(
                    jitted_model,
                    _in_example,
                    output,
                    input_names=input_names,
                    output_names=output_names,
                    verbose=verbose,
                    export_params=export_params,
                    do_constant_folding=do_constant_folding,
                    keep_initializers_as_inputs=keep_initializers_as_inputs,
                    dynamic_axes=dynamic_axes,
                    opset_version=onnx_opset_version,
                    example_outputs=_out_example,
                )

                # Verify the model can be read, and is valid
                onnx_model = onnx.load(output)
                onnx.checker.check_model(onnx_model, full_check=True)
                return onnx_model

            else:
                raise ValueError(f'Encountered unknown export format {format}.')
    finally:
        typecheck.set_typecheck_enabled(enabled=True)
    return [output]  # Subclasses may create more than one file.
def export(
    self,
    output: str,
    input_example=None,
    output_example=None,
    verbose=False,
    export_params=True,
    do_constant_folding=True,
    keep_initializers_as_inputs=False,
    onnx_opset_version: int = 12,
    try_script: bool = False,
    set_eval: bool = True,
    check_trace: bool = True,
    use_dynamic_axes: bool = True,
):
    """Export this module to TorchScript or ONNX; the format comes from ``self.get_format``.

    Input/output names come from ``input_types``/``output_types``, minus the
    ``disabled_deployment_*_names``. Dynamic axes are extracted per port when
    ``use_dynamic_axes`` is set.

    For ONNX:
      * The graph is validated with ``onnx.checker``.
      * When ``do_constant_folding`` is set and onnx-graphsurgeon is available, the graph
        is constant-folded, re-checked and saved back.

    Returns:
        The (possibly folded) ONNX model for ONNX exports; ``[output]`` for TorchScript.

    Raises:
        ValueError: unknown format, or no example input when tracing is needed.
        NotImplementedError: ``input_types``/``output_types`` are not defined.
    """
    try:
        # Disable typechecks
        typecheck.set_typecheck_enabled(enabled=False)

        # Set module to eval mode
        if set_eval:
            self.eval()

        format = self.get_format(output)

        self._prepare_for_export()

        _in_example = input_example if input_example is not None else self.input_example()
        # Flatten dict examples once, so tracing, forward() and torch.onnx.export all see
        # the same positional tuple (forward(*dict) would have unpacked the keys).
        if isinstance(_in_example, Dict):
            _in_example = tuple(_in_example.values())
        # Honour a caller-supplied output example; previously _out_example stayed
        # unbound in that case and raised UnboundLocalError.
        _out_example = output_example

        if not (hasattr(self, 'input_types') and hasattr(self, 'output_types')):
            raise NotImplementedError('For export to work you must define input and output types')
        input_names = list(self.input_types.keys())
        output_names = list(self.output_types.keys())
        # dynamic axis is a mapping from input/output_name => list of "dynamic" indices
        dynamic_axes = {}

        # extract dynamic axes and remove unnecessary inputs/outputs
        # for input_ports
        for _name, ntype in self.input_types.items():
            if _name in self.disabled_deployment_input_names:
                input_names.remove(_name)
                continue
            if use_dynamic_axes:
                dynamic_axes.update(self._extract_dynamic_axes(_name, ntype))
        # for output_ports
        for _name, ntype in self.output_types.items():
            if _name in self.disabled_deployment_output_names:
                output_names.remove(_name)
                continue
            if use_dynamic_axes:
                dynamic_axes.update(self._extract_dynamic_axes(_name, ntype))

        if not dynamic_axes:
            dynamic_axes = None

        with torch.jit.optimized_execution(True):
            jitted_model = None
            if try_script:
                try:
                    jitted_model = torch.jit.script(self)
                except Exception as e:
                    logging.error(f"jit.script() failed!\n{e}")

            if jitted_model is None:
                if _in_example is None:
                    raise ValueError('Example input is None, but jit.script() has failed or not tried')
                jitted_model = torch.jit.trace(self, _in_example, check_trace=check_trace)

            if format == ExportFormat.TORCHSCRIPT:
                jitted_model.save(output)
                assert os.path.exists(output)

            elif format == ExportFormat.ONNX:
                if _out_example is None:
                    if isinstance(_in_example, tuple):
                        _out_example = self.forward(*_in_example)
                    else:
                        _out_example = self.forward(_in_example)

                torch.onnx.export(
                    jitted_model,
                    _in_example,
                    output,
                    input_names=input_names,
                    output_names=output_names,
                    verbose=verbose,
                    export_params=export_params,
                    do_constant_folding=do_constant_folding,
                    keep_initializers_as_inputs=keep_initializers_as_inputs,
                    dynamic_axes=dynamic_axes,
                    opset_version=onnx_opset_version,
                    example_outputs=_out_example,
                )

                # Verify the model can be read, and is valid
                onnx_model = onnx.load(output)
                onnx.checker.check_model(onnx_model, full_check=True)

                if do_constant_folding:
                    if not ONNX_GRAPHSURGEON_AVAILABLE:
                        logging.info(
                            "onnx-graphsurgeon module is not installed. "
                            "That may result in suboptimal optimization of exported ONNX graph "
                            "(including unneeded DOUBLE initializers). "
                            "Please follow the instructions available at: "
                            "https://github.com/NVIDIA/TensorRT/tree/master/tools/onnx-graphsurgeon "
                            "to install onnx-graphsurgeon from source to improve exported graph."
                        )
                    else:
                        # This pass is to remove/recast certain constants that are generated as 'double'
                        # Those constants break ONNX -> TRT conversion (TRT does not support 'double' as of 7.2)
                        # Can probably be removed once TRT has automatic downcast for double.
                        # However, it may still be useful even then as it seems to always make the graph shorter.
                        graph = gs.import_onnx(onnx_model)
                        onnx_model = gs.export_onnx(graph.fold_constants().cleanup())
                        onnx.checker.check_model(onnx_model, full_check=True)
                        onnx.save(onnx_model, output)
                return onnx_model

            else:
                raise ValueError(f'Encountered unknown export format {format}.')
    finally:
        typecheck.set_typecheck_enabled(enabled=True)
    return [output]  # Subclasses may create more than one file.
def export(
    self,
    output: str,
    input_example=None,
    output_example=None,
    verbose=False,
    export_params=True,
    do_constant_folding=True,
    keep_initializers_as_inputs=False,
    onnx_opset_version: int = 12,
    try_script: bool = False,
    set_eval: bool = True,
    check_trace: bool = True,
    use_dynamic_axes: bool = True,
    dynamic_axes=None,
    check_tolerance=0.01,
    forward_method=None,
):
    """Export this module to TorchScript or ONNX; the format is resolved from ``output``.

    ``forward_method`` (or a class-level ``forward_for_export``) temporarily replaces
    ``type(self).forward`` and is restored afterwards. ``_prepare_for_export`` receives
    all call arguments, with ``input_example`` replaced by the resolved (tuple) example.

    Returns:
        ``([output], [description])``.

    Raises:
        ValueError: if the resolved format is neither TorchScript nor ONNX.
    """
    # Snapshot call arguments. dict() copies: on CPython < 3.13 locals() returns the
    # live frame dict, which a tracer/debugger can refresh with later locals, and
    # those would then leak into _prepare_for_export(**my_args).
    my_args = dict(locals())
    del my_args['self']

    qual_name = self.__module__ + '.' + self.__class__.__qualname__
    output_descr = qual_name + ' exported to ONNX'
    # Set only once forward is actually swapped, so `finally` restores exactly that case.
    old_forward_method = None

    try:
        # Disable typechecks
        typecheck.set_typecheck_enabled(enabled=False)

        # Allow user to completely override forward method to export
        if forward_method is None and hasattr(type(self), "forward_for_export"):
            forward_method = type(self).forward_for_export

        if forward_method:
            old_forward_method = type(self).forward
            type(self).forward = forward_method

        # Set module to eval mode
        if set_eval:
            self.eval()

        format = self.get_format(output)

        if input_example is None:
            input_example = self.input_module.input_example()

        if isinstance(input_example, Dict):
            input_example = tuple(input_example.values())

        my_args['input_example'] = input_example
        self._prepare_for_export(**my_args)

        if output_example is None:
            if isinstance(input_example, tuple):
                output_example = self.forward(*input_example)
            else:
                output_example = self.forward(input_example)

        input_names = self.input_module.get_input_names(input_example)
        output_names = self.output_module.get_output_names(output_example)

        with torch.jit.optimized_execution(True), torch.no_grad():
            jitted_model = None
            if try_script:
                try:
                    jitted_model = torch.jit.script(self)
                except Exception as e:
                    print("jit.script() failed!", e)

        with torch.jit.optimized_execution(True), torch.no_grad():
            if format == ExportFormat.TORCHSCRIPT:
                if jitted_model is None:
                    jitted_model = torch.jit.trace(
                        self,
                        input_example,
                        strict=False,
                        optimize=True,
                        check_trace=check_trace,
                        check_tolerance=check_tolerance,
                    )
                if verbose:
                    print(jitted_model.code)
                jitted_model.save(output)
                assert os.path.exists(output)

            elif format == ExportFormat.ONNX:
                if jitted_model is None:
                    jitted_model = self

                # dynamic axis is a mapping from input/output_name => list of "dynamic" indices
                if dynamic_axes is None and use_dynamic_axes:
                    dynamic_axes = self.input_module.get_input_dynamic_axes(input_names)
                    dynamic_axes = {**dynamic_axes, **self.output_module.get_output_dynamic_axes(output_names)}

                torch.onnx.export(
                    jitted_model,
                    input_example,
                    output,
                    input_names=input_names,
                    output_names=output_names,
                    verbose=verbose,
                    export_params=export_params,
                    do_constant_folding=do_constant_folding,
                    keep_initializers_as_inputs=keep_initializers_as_inputs,
                    dynamic_axes=dynamic_axes,
                    opset_version=onnx_opset_version,
                    example_outputs=output_example,
                )

                # Verify the model can be read, and is valid
                onnx_model = onnx.load(output)
                onnx.checker.check_model(onnx_model, full_check=True)
            else:
                raise ValueError(f'Encountered unknown export format {format}.')
    finally:
        typecheck.set_typecheck_enabled(enabled=True)
        if old_forward_method is not None:
            type(self).forward = old_forward_method
    return ([output], [output_descr])
def nemo_export(argv):
    """Convert a .nemo saved model into .riva Riva input format.

    Steps:
      * Parse ``argv`` and configure logging.
      * Restore the model with ``ModelPT.restore_from``.
      * Build an example input and compute an output example, optionally under autocast.
      * Call ``model.export``.
      * Optionally verify the exported file with ``verify_runtime``.

    Exits with status 1 if the restored model is not ``Exportable``.

    Raises:
        ValueError: if ``--verbose`` names an unknown logging level.
    """
    args = get_args(argv)
    loglevel = logging.INFO
    # assuming loglevel is bound to the string value obtained from the
    # command line argument. Convert to upper case to allow the user to
    # specify --log=DEBUG or --log=debug
    if args.verbose is not None:
        numeric_level = getattr(logging, args.verbose.upper(), None)
        if not isinstance(numeric_level, int):
            # Report what the user typed; numeric_level is not a valid level here.
            raise ValueError('Invalid log level: %s' % args.verbose)
        loglevel = numeric_level

    logger = logging.getLogger(__name__)
    # Iterate over a snapshot: removeHandler() mutates logger.handlers in place, and
    # removing while iterating the live list would skip every other handler.
    for handler in list(logger.handlers):
        logger.removeHandler(handler)

    logging.basicConfig(level=loglevel, format='%(asctime)s [%(levelname)s] %(message)s')
    logging.info("Logging level set to {}".format(loglevel))

    nemo_in = args.source
    out = args.out

    # Create a PL trainer object which is required for restoring Megatron models
    cfg_trainer = TrainerConfig(
        gpus=1,
        accelerator="ddp",
        num_nodes=1,
        # Need to set the following two to False as ExpManager will take care of them differently.
        logger=False,
        checkpoint_callback=False,
    )
    trainer = Trainer(cfg_trainer)

    logging.info("Restoring NeMo model from '{}'".format(nemo_in))
    try:
        with torch.inference_mode():
            # Restore instance from .nemo file using generic model restore_from
            model = ModelPT.restore_from(restore_path=nemo_in, trainer=trainer)
    except Exception:
        logging.error(
            "Failed to restore model from NeMo file : {}. Please make sure you have the latest NeMo package installed with [all] dependencies.".format(
                nemo_in
            )
        )
        raise

    logging.info("Model {} restored from '{}'".format(model.cfg.target, nemo_in))

    if not isinstance(model, Exportable):
        logging.error("Your NeMo model class ({}) is not Exportable.".format(model.cfg.target))
        sys.exit(1)
    typecheck.set_typecheck_enabled(enabled=False)

    try:
        #
        #  Add custom export parameters here
        #
        in_args = {}
        if args.max_batch is not None:
            in_args["max_batch"] = args.max_batch
        if args.max_dim is not None:
            in_args["max_dim"] = args.max_dim

        autocast = nullcontext
        model = model.to(device=args.device)
        model.eval()
        with torch.inference_mode():
            input_example = model.input_module.input_example(**in_args)
        if args.autocast:
            autocast = torch.cuda.amp.autocast
        with autocast(), torch.inference_mode():
            logging.info("Getting output example")
            input_list, input_dict = parse_input_example(input_example)
            output_example = forward_method(model)(*input_list, **input_dict)
            logging.info(f"Exporting model with autocast={args.autocast}")
            input_names = model.input_names
            output_names = model.output_names

            _, descriptions = model.export(
                out,
                check_trace=False,
                input_example=input_example,
                onnx_opset_version=args.onnx_opset,
                verbose=args.verbose,
            )

    except Exception:
        logging.error(
            "Export failed. Please make sure your NeMo model class ({}) has working export() and that you have the latest NeMo package installed with [all] dependencies.".format(
                model.cfg.target
            )
        )
        raise

    logging.info("Successfully exported to {}".format(out))

    del model

    if args.runtime_check:
        verify_runtime(out, input_list, input_dict, input_names, output_names, output_example)
def export(
    self,
    output: str,
    input_example=None,
    output_example=None,
    verbose=False,
    export_params=True,
    do_constant_folding=True,
    keep_initializers_as_inputs=False,
    onnx_opset_version: int = 13,
    try_script: bool = False,
    set_eval: bool = True,
    check_trace: bool = False,
    use_dynamic_axes: bool = True,
    dynamic_axes=None,
    check_tolerance=0.01,
):
    """Export this module to TorchScript or ONNX via the ``_export_*`` helpers.

    The format is resolved from ``output`` by ``self.get_format``. ONNX exports are
    verified with ``_verify_onnx_export``.

    Returns:
        ``([output], [description])``.

    Raises:
        ValueError: if the resolved format is neither TorchScript nor ONNX.
    """
    # Snapshot call arguments. dict() copies: on CPython < 3.13 locals() returns the
    # live frame dict, which a tracer/debugger can refresh with later locals.
    my_args = dict(locals())
    del my_args['self']

    qual_name = self.__module__ + '.' + self.__class__.__qualname__
    output_descr = qual_name + ' exported to ONNX'

    # Pre-bound so `finally` is safe even if _wrap_forward_method raises.
    forward_method = old_forward_method = None
    try:
        # Disable typechecks
        typecheck.set_typecheck_enabled(enabled=False)

        # Allow user to completely override forward method to export
        forward_method, old_forward_method = self._wrap_forward_method()

        # Set module to eval mode
        self._set_eval(set_eval)

        format = self.get_format(output)

        if input_example is None:
            input_example = self._get_input_example()

        my_args['input_example'] = input_example

        # Run (possibly overridden) prepare method before calling forward()
        self._prepare_for_export(**my_args)

        input_list, input_dict = self._setup_input_example(input_example)

        input_names = self._process_input_names()
        output_names = self._process_output_names()

        # NOTE(review): the `output_example` argument is not used; the example is
        # always recomputed here.
        output_example = self.forward(*input_list, **input_dict)

        with torch.jit.optimized_execution(True), torch.no_grad():
            jitted_model = self._try_jit_compile_model(self, try_script)

            if format == ExportFormat.TORCHSCRIPT:
                self._export_torchscript(
                    jitted_model, output, input_dict, input_list, check_trace, check_tolerance, verbose
                )

            elif format == ExportFormat.ONNX:
                self._export_onnx(
                    jitted_model,
                    input_example,
                    output_example,
                    input_names,
                    output_names,
                    use_dynamic_axes,
                    do_constant_folding,
                    dynamic_axes,
                    output,
                    export_params,
                    keep_initializers_as_inputs,
                    onnx_opset_version,
                    verbose,
                )

                # Verify the model can be read, and is valid.
                # NOTE(review): the RNNT export in this file also passes output_names to
                # _verify_onnx_export; confirm which signature applies here.
                self._verify_onnx_export(
                    output, output_example, input_list, input_dict, input_names, check_tolerance, check_trace
                )
            else:
                raise ValueError(f'Encountered unknown export format {format}.')
    finally:
        typecheck.set_typecheck_enabled(enabled=True)
        if forward_method:
            type(self).forward = old_forward_method
    return ([output], [output_descr])
def export(
    self,
    output: str,
    input_example=None,
    output_example=None,
    verbose=False,
    export_params=True,
    do_constant_folding=True,
    keep_initializers_as_inputs=False,
    onnx_opset_version: int = 13,
    try_script: bool = False,
    set_eval: bool = True,
    check_trace: bool = False,
    use_dynamic_axes: bool = True,
    dynamic_axes=None,
    check_tolerance=0.01,
):
    """Export an RNNT model as two ONNX graphs: "Encoder" and "Decoder-Joint".

    ``input_example`` (and ``output_example``, if given) must be a 2-element list/tuple:
    encoder examples and decoder examples. The ``_rnnt_export`` flag is set on the
    input, output and joint modules for the duration of the export and reset afterwards.
    TorchScript export is not implemented.

    Returns:
        ``([output], [description])``.

    Raises:
        ValueError: a module lacks ``_rnnt_export``, or the format is unknown.
        NotImplementedError: TorchScript format was requested.
        TypeError: decoder states are present but forward did not flatten them.
    """
    # Snapshot call arguments. dict() copies: on CPython < 3.13 locals() returns the
    # live frame dict, which a tracer/debugger can refresh with later locals.
    my_args = dict(locals())
    del my_args['self']

    qual_name = self.__module__ + '.' + self.__class__.__qualname__
    output_descr = qual_name + ' exported to ONNX'

    # Bound before `try`: early failures (missing `_rnnt_export`, TorchScript
    # NotImplementedError) used to hit an unbound name in `finally`, masking the error.
    original_forward_method = None
    try:
        # Disable typechecks
        typecheck.set_typecheck_enabled(enabled=False)

        # Set module to eval mode
        self._set_eval(set_eval)

        format = self.get_format(output)

        # Assign special flag for RNNT export of encoder
        for module in (self.input_module, self.output_module, self.joint_module):
            if not hasattr(module, '_rnnt_export'):
                raise ValueError(
                    f"{module.__class__.__name__} must have a bool attribute `_rnnt_export`, "
                    f"which is necessary for RNNT export."
                )

        self.input_module._rnnt_export = True
        self.output_module._rnnt_export = True
        self.joint_module._rnnt_export = True

        if input_example is None:
            encoder_examples, decoder_examples = self._get_input_example()
            input_example = [encoder_examples, decoder_examples]
        else:
            assert isinstance(input_example, (list, tuple)) and len(input_example) == 2, (
                "input_example must be a list of two tensors, for encoder and decoder input"
            )
            encoder_examples, decoder_examples = input_example

        if output_example is not None:
            assert isinstance(output_example, (list, tuple)) and len(output_example) == 2, (
                "output_example must be a list of two tensors, for encoder and decoder+joint output"
            )
            encoder_output_example, decoder_joint_output_example = output_example
        else:
            encoder_output_example = None
            decoder_joint_output_example = None

        my_args['input_example'] = input_example

        # Run (possibly overridden) prepare method before calling forward()
        self._prepare_for_export(**my_args)

        encoder_input_list, encoder_input_dict = self._setup_input_example(encoder_examples)
        decoder_input_list, decoder_input_dict = self._setup_input_example(decoder_examples)

        encoder_input_names, decoder_input_names = self._process_input_names()
        encoder_output_names, decoder_output_names, joint_output_names = self._process_output_names()

        # process decoder states; by convention states must be the last in the list and must be wrapped in a tuple
        (
            decoder_input_list,
            decoder_input_names,
            input_state_names,
            num_states,
            output_state_names,
            state_names,
        ) = self._process_states_names(decoder_input_list, decoder_input_names)

        with torch.jit.optimized_execution(True), torch.no_grad():
            # Encoder export
            encoder_jitted_model = self._try_jit_compile_model(self.input_module, try_script)

            if format == exportable.ExportFormat.TORCHSCRIPT:
                # TODO: TorchScript export of the Encoder / Decoder-Joint pair is not implemented.
                raise NotImplementedError()

            elif format == exportable.ExportFormat.ONNX:
                # Allow user to completely override forward method to export
                _, original_forward_method = self._wrap_forward_method('encoder')
                encoder_output_example = self.forward(*encoder_input_list, **encoder_input_dict)

                self._export_flag_module = 'encoder'

                self._export_onnx(
                    encoder_jitted_model,
                    encoder_examples,
                    encoder_output_example,
                    encoder_input_names,
                    encoder_output_names,
                    use_dynamic_axes,
                    do_constant_folding,
                    dynamic_axes,
                    self._augment_output_filename(output, "Encoder"),
                    export_params,
                    keep_initializers_as_inputs,
                    onnx_opset_version,
                    verbose,
                )

                # Verify the model can be read, and is valid
                self._verify_onnx_export(
                    self._augment_output_filename(output, "Encoder"),
                    encoder_output_example,
                    encoder_input_list,
                    encoder_input_dict,
                    encoder_input_names,
                    encoder_output_names,
                    check_tolerance,
                    check_trace,
                )

                self._export_flag_module = 'decoder_joint'

                # Extract just the encoder logits and remove the encoder lengths
                if isinstance(encoder_output_example, (list, tuple)):
                    encoder_output_example = encoder_output_example[0]

                encoder_decoder_input_list = tuple([encoder_output_example] + list(decoder_input_list))
                encoder_decoder_input_dict = decoder_input_dict

                # Allow user to completely override forward method to export
                _, _ = self._wrap_forward_method('decoder_joint')
                decoder_joint_output_example = self.forward(
                    *encoder_decoder_input_list, **encoder_decoder_input_dict
                )
                decoder_joint_output_example = tuple(decoder_joint_output_example)

                # Resolve output states
                if num_states > 0:
                    if isinstance(decoder_joint_output_example[-1], tuple):
                        raise TypeError("Since input states are available, forward must emit flattened states")

                    # remove the name of the states
                    logging.info(
                        f"Replacing output state name {decoder_output_names[-1]} with {str(output_state_names)}"
                    )
                    decoder_output_names = decoder_output_names[:-1]

                self._export_onnx(
                    None,
                    encoder_decoder_input_list,
                    decoder_joint_output_example,
                    self._join_input_output_names(["encoder_outputs"], decoder_input_names, input_state_names),
                    self._join_input_output_names(joint_output_names, decoder_output_names, output_state_names),
                    use_dynamic_axes,
                    do_constant_folding,
                    dynamic_axes,
                    self._augment_output_filename(output, "Decoder-Joint"),
                    export_params,
                    keep_initializers_as_inputs,
                    onnx_opset_version,
                    verbose,
                )

                # Verify the model can be read, and is valid
                self._verify_onnx_export(
                    self._augment_output_filename(output, "Decoder-Joint"),
                    decoder_joint_output_example,
                    encoder_decoder_input_list,
                    encoder_decoder_input_dict,
                    self._join_input_output_names(["encoder_outputs"], decoder_input_names, input_state_names),
                    self._join_input_output_names(joint_output_names, decoder_output_names, output_state_names),
                    check_tolerance,
                    check_trace,
                )

            else:
                raise ValueError(f'Encountered unknown export format {format}.')

    finally:
        typecheck.set_typecheck_enabled(enabled=True)
        logging.warning(
            "PyTorch Model has been significantly modified. In order to utilize model, delete this "
            "instance and create a new model."
        )
        # replace forward method with original forward method (only if it was swapped)
        if original_forward_method is not None:
            type(self).forward = original_forward_method

        # Reset special flag for RNNT export of encoder; skip modules that never had it,
        # so a failed validation does not silently create the attribute.
        for module in (self.input_module, self.output_module, self.joint_module):
            if hasattr(module, '_rnnt_export'):
                module._rnnt_export = False

    return ([output], [output_descr])
def export(
    self,
    output: str,
    input_example=None,
    output_example=None,
    verbose=False,
    export_params=True,
    do_constant_folding=True,
    keep_initializers_as_inputs=False,
    onnx_opset_version: int = 13,
    try_script: bool = False,
    set_eval: bool = True,
    check_trace: bool = False,
    use_dynamic_axes: bool = True,
    dynamic_axes=None,
    check_tolerance=0.01,
):
    """Export this module to TorchScript or ONNX; the format is chosen by ``self.get_format(output)``.

    Typechecks are disabled for the duration of the export and re-enabled afterwards.
    If the class defines ``forward_for_export``, it temporarily replaces the class-level
    ``forward`` and the original is restored when the export finishes (or fails).

    Args:
        output: Destination file path; also used to pick the export format.
        input_example: Example inputs. When None, ``self.input_module.input_example()``
            is used. A trailing ``dict`` element is treated as keyword arguments to forward().
        output_example: Accepted for interface compatibility, but overwritten: the reference
            output is always recomputed by calling ``self.forward`` on ``input_example``.
        verbose: TorchScript: print the generated code. ONNX: forwarded to ``torch.onnx.export``.
        export_params: Forwarded to ``torch.onnx.export``.
        do_constant_folding: Forwarded to ``torch.onnx.export``.
        keep_initializers_as_inputs: Forwarded to ``torch.onnx.export``.
        onnx_opset_version: Forwarded to ``torch.onnx.export`` as ``opset_version``.
        try_script: Try ``torch.jit.script`` first. On failure, fall back to tracing
            (TorchScript) or to the eager module (ONNX).
        set_eval: Call ``freeze()`` on self, ``input_module`` and ``output_module`` before export.
        check_trace: TorchScript: forwarded to ``torch.jit.trace_module``. ONNX: also enables
            verification of the exported file with onnxruntime, if it can be imported.
        use_dynamic_axes: When ``dynamic_axes`` is None, derive it from the input/output modules.
        dynamic_axes: Explicit ONNX mapping of input/output name to its list of dynamic dims.
        check_tolerance: Trace-check tolerance. For the onnxruntime comparison it is ``rtol``,
            with ``atol = 100 * check_tolerance``.

    Returns:
        ``([output], [output_descr])``, where ``output_descr`` is
        ``"<module>.<qualname> exported to ONNX"``.

    Raises:
        ValueError: If ``get_format`` returns a format other than TorchScript or ONNX.
    """
    # Snapshot the call arguments before any other local is bound, so that only the
    # parameters are forwarded to ``_prepare_for_export``.
    my_args = locals()
    del my_args['self']

    qual_name = self.__module__ + '.' + self.__class__.__qualname__
    output_descr = qual_name + ' exported to ONNX'

    # Bound before ``try`` so the ``finally`` clause can always read them, even if
    # the very first statement inside ``try`` raises.
    forward_method = None
    old_forward_method = None
    try:
        # Disable typechecks while exporting.
        typecheck.set_typecheck_enabled(enabled=False)

        # Allow the user to completely override the forward method used for export.
        # This patches the class (not the instance); it is undone in ``finally``.
        if hasattr(type(self), "forward_for_export"):
            forward_method = type(self).forward_for_export
            old_forward_method = type(self).forward
            type(self).forward = forward_method

        # Set modules to eval mode.
        if set_eval:
            self.freeze()
            self.input_module.freeze()
            self.output_module.freeze()

        export_format = self.get_format(output)

        if input_example is None:
            input_example = self.input_module.input_example()

        my_args['input_example'] = input_example

        # Run the (possibly overridden) prepare hook before calling forward().
        self._prepare_for_export(**my_args)

        input_list = list(input_example)
        input_dict = {}
        # A trailing dict in the example carries keyword arguments for forward().
        if input_list and isinstance(input_list[-1], dict):
            input_dict = input_list[-1]
            input_list = input_list[:-1]

        input_names = get_input_names(self.input_module)
        # Remove inputs that are disabled for deployment.
        for name in self.disabled_deployment_input_names:
            input_names.remove(name)

        output_names = get_output_names(self.output_module)
        # Remove outputs that are disabled for deployment.
        for name in self.disabled_deployment_output_names:
            output_names.remove(name)

        # Any caller-supplied ``output_example`` is overwritten: the reference output
        # is recomputed from the (possibly patched) forward.
        output_example = self.forward(*input_list, **input_dict)

        with torch.jit.optimized_execution(True), torch.no_grad():
            jitted_model = None
            if try_script:
                try:
                    jitted_model = torch.jit.script(self)
                except Exception as e:
                    print("jit.script() failed!", e)

            if export_format == ExportFormat.TORCHSCRIPT:
                if jitted_model is None:
                    jitted_model = torch.jit.trace_module(
                        self,
                        {"forward": tuple(input_list) + tuple(input_dict.values())},
                        strict=False,
                        optimize=True,
                        check_trace=check_trace,
                        check_tolerance=check_tolerance,
                    )
                if verbose:
                    print(jitted_model.code)
                jitted_model.save(output)
                assert os.path.exists(output)

            elif export_format == ExportFormat.ONNX:
                if jitted_model is None:
                    jitted_model = self

                # dynamic_axes maps an input/output name to its list of "dynamic" dim indices.
                if dynamic_axes is None and use_dynamic_axes:
                    dynamic_axes = {
                        **get_input_dynamic_axes(self.input_module, input_names),
                        **get_output_dynamic_axes(self.output_module, output_names),
                    }

                torch.onnx.export(
                    jitted_model,
                    input_example,
                    output,
                    input_names=input_names,
                    output_names=output_names,
                    verbose=verbose,
                    export_params=export_params,
                    do_constant_folding=do_constant_folding,
                    keep_initializers_as_inputs=keep_initializers_as_inputs,
                    dynamic_axes=dynamic_axes,
                    opset_version=onnx_opset_version,
                    example_outputs=output_example,
                )

                # Verify the model can be read, and is valid.
                onnx_model = onnx.load(output)
                onnx.checker.check_model(onnx_model, full_check=True)

                # ``check_trace`` doubles as the switch for onnxruntime verification.
                test_runtime = check_trace
                if test_runtime:
                    try:
                        import onnxruntime
                    except (ImportError, ModuleNotFoundError):
                        test_runtime = False
                        logging.warning(
                            f"ONNX generated at {output}, not verified - please install onnxruntime.\n"
                        )

                if test_runtime:
                    sess = onnxruntime.InferenceSession(onnx_model.SerializeToString())
                    ort_out = sess.run(None, to_onnxrt_input(input_names, input_list, input_dict))

                    # forward() may return a single tensor rather than a tuple; indexing a
                    # bare tensor would slice its first dimension instead of selecting
                    # the first output, so normalise to a sequence of outputs.
                    if isinstance(output_example, (list, tuple)):
                        expected_outputs = output_example
                    else:
                        expected_outputs = (output_example,)

                    all_good = True
                    for out_idx, out in enumerate(ort_out):
                        expected = expected_outputs[out_idx].cpu()
                        if not torch.allclose(
                            torch.from_numpy(out), expected, rtol=check_tolerance, atol=100 * check_tolerance
                        ):
                            all_good = False
                            logging.info(
                                f"onnxruntime results mismatch! PyTorch(expected):\n{expected}\nONNXruntime:\n{out}"
                            )
                    status = "SUCCESS" if all_good else "FAIL"
                    logging.info(f"ONNX generated at {output} verified with onnxruntime : " + status)
            else:
                raise ValueError(f'Encountered unknown export format {export_format}.')
    finally:
        typecheck.set_typecheck_enabled(enabled=True)
        # Restore the class-level forward if it was swapped for forward_for_export.
        if forward_method is not None:
            type(self).forward = old_forward_method
    return ([output], [output_descr])