Example #1
    def test_export_to_onnx(self):
        model = WaveGlowModel(wcfg)
        model = model.cuda().half()
        typecheck.set_typecheck_enabled(enabled=False)
        with tempfile.TemporaryDirectory() as tmpdir, model.nemo_infer():
            # Generate filename in the temporary directory.
            # TODO: Change `waveglow.ts` to `waveglow.onnx` for > 21.05
            tmp_file_name = os.path.join(tmpdir, "waveglow.ts")

            n_mels = 80
            # Test export.
            inp = input_example(n_mels)
            inp1 = taco2wg(*inp)
            inp2 = inp1
            res1 = model.waveglow(*inp1)
            res2 = model.waveglow(*inp2)
            assert torch.allclose(res1, res2, rtol=0.01, atol=0.1)
            WaveGlowModel.forward_for_export = forward_wrapper
            model.export(
                tmp_file_name,
                verbose=True,
                input_example=inp,
                output_example=res1,
                try_script=False,
                check_trace=False,
                do_constant_folding=True,
                dynamic_axes={"spec": [0], "z": [0], "audio": [0]},
            )
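
Note: this test disables type checking for export but never re-enables it. A minimal sketch of a safer wrapper follows; the import path is an assumption and may vary across NeMo versions.

# Sketch only: restore type checks even if export raises.
from nemo.core import typecheck  # assumed import path

def export_without_typechecks(model, path, **export_kwargs):
    typecheck.set_typecheck_enabled(enabled=False)
    try:
        return model.export(path, **export_kwargs)
    finally:
        typecheck.set_typecheck_enabled(enabled=True)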
Example #2
    def test_export_to_onnx(self):
        model = WaveGlowModel(wcfg)
        model = model.cuda().half()
        typecheck.set_typecheck_enabled(enabled=False)
        with tempfile.TemporaryDirectory() as tmpdir, model.nemo_infer():
            tmp_file_name = os.path.join(tmpdir, "waveglow.onnx")

            n_mels = 80
            # Test export.
            inp = input_example(n_mels)
            inp1 = taco2wg(*inp)
            inp2 = inp1
            res1 = model.waveglow(*inp1)
            res2 = model.waveglow(*inp2)
            assert torch.allclose(res1, res2, rtol=0.01, atol=0.1)
            WaveGlowModel.forward_for_export = forward_wrapper
            model.export(
                tmp_file_name,
                verbose=False,
                input_example=inp,
                output_example=res1,
                try_script=False,
                check_trace=False,
                do_constant_folding=True,
            )
Example #3
    def test_get_model(self):
        model = get_pretrained_bert_345m_uncased_model()
        assert isinstance(model, nemo_nlp.modules.MegatronBertEncoder)

        typecheck.set_typecheck_enabled(enabled=False)
        inp = model.input_example()
        out = model.forward(*inp)
        typecheck.set_typecheck_enabled(enabled=True)
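
Recent NeMo builds also expose a context manager that pairs the disable/enable calls above; if your version provides it, the middle of this test can be written as the following sketch (assumes `typecheck.disable_checks()` exists and `model` is the restored encoder from the test):

# Sketch only: context-manager form of the explicit disable/enable pair.
from nemo.core import typecheck  # assumed import path

with typecheck.disable_checks():
    inp = model.input_example()
    out = model.forward(*inp)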
Example #4
    def __init__(self, cfg: DictConfig, trainer: 'Trainer' = None):
        super().__init__(cfg=cfg, trainer=trainer)
        typecheck.set_typecheck_enabled(enabled=False)

        cfg = self._cfg
        self.vocab = AudioToCharWithDursF0Dataset.make_vocab(
            **cfg.train_ds.dataset.vocab)
        self.embed = nn.Embedding(len(self.vocab.labels), cfg.d_char)
        self.model = instantiate(cfg.model)
        d_out = cfg.model.jasper[-1].filters
        self.proj = nn.Conv1d(d_out, 1, kernel_size=1)
Example #5
    def test_export_to_onnx(self):
        model = WaveGlowModel(wcfg)
        # model = WaveGlowModel.restore_from("../WaveGlow-22050Hz-268M.nemo")
        model = model.cuda().half()
        typecheck.set_typecheck_enabled(enabled=False)
        with tempfile.TemporaryDirectory() as tmpdir, model.nemo_infer():
            # Generate filename in the temporary directory.
            tmp_file_name = os.path.join(tmpdir, "waveglow.onnx")

            n_mels = 80
            # Test export.
            inp = input_example(n_mels)
            inp1 = taco2wg(**inp)
            inp2 = inp1
            res1 = model.waveglow(*inp1)
            res2 = model.waveglow(*inp2)
            assert torch.allclose(res1, res2, rtol=0.01, atol=0.1)

            model.export(
                tmp_file_name,
                verbose=True,
                input_example=inp,
                output_example=res1,
                try_script=False,
                check_trace=False,
                do_constant_folding=True,
                dynamic_axes={
                    "spec": [0],
                    "z": [0],
                    "audio": [0]
                },
                forward_method=forward_wrapper,
            )

            try:
                test_runtime = True
                import onnxruntime
            except (ImportError, ModuleNotFoundError):
                test_runtime = False
            if test_runtime:
                omodel = onnx.load(tmp_file_name)
                output_names = ['audio']
                sess = onnxruntime.InferenceSession(omodel.SerializeToString())
                output = sess.run(None, {
                    "spec": inp["spec"].cpu().numpy(),
                    "z": inp["z"].cpu().numpy()
                })[0]
                assert torch.allclose(torch.from_numpy(output),
                                      res2.cpu(),
                                      rtol=1,
                                      atol=100)
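
The runtime check at the end of this example is a recurring pattern; below is a self-contained sketch of it as a helper. The function name and default tolerances are illustrative, not NeMo API.

import torch

def verify_onnx_output(onnx_path, feed, reference, rtol=1.0, atol=100.0):
    # Sketch: run the exported model through onnxruntime, if installed,
    # and compare against the PyTorch reference output. Returns None
    # when the runtime is unavailable, else the allclose verdict.
    try:
        import onnxruntime
    except (ImportError, ModuleNotFoundError):
        return None
    sess = onnxruntime.InferenceSession(onnx_path)
    out = sess.run(None, feed)[0]
    return torch.allclose(torch.from_numpy(out), reference, rtol=rtol, atol=atol)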
Example #6
    def __init__(self, cfg: DictConfig, trainer: 'Trainer' = None):
        super().__init__(cfg=cfg, trainer=trainer)
        typecheck.set_typecheck_enabled(enabled=False)

        cfg = self._cfg
        self.vocab = AudioToCharWithDursF0Dataset.make_vocab(
            **cfg.train_ds.dataset.vocab)
        self.blanking = cfg.train_ds.dataset.blanking
        self.preprocessor = instantiate(cfg.preprocessor)
        self.embed = GaussianEmbedding(self.vocab, cfg.d_char)
        self.norm_f0 = MaskedInstanceNorm1d(1)
        self.res_f0 = StyleResidual(cfg.d_char, 1, kernel_size=3)
        self.model = instantiate(cfg.model)
        d_out = cfg.model.jasper[-1].filters
        self.proj = nn.Conv1d(d_out, cfg.n_mels, kernel_size=1)
Example #7
    def export(
        self,
        output: str,
        input_example=None,
        verbose=False,
        export_params=True,
        do_constant_folding=True,
        keep_initializers_as_inputs=False,
        onnx_opset_version=None,
        try_script: bool = False,
        set_eval: bool = True,
        check_trace: bool = False,
        use_dynamic_axes: bool = True,
        dynamic_axes=None,
        check_tolerance=0.01,
    ):
        my_args = locals()
        my_args.pop('self')

        exportables = []
        for m in self.modules():
            if isinstance(m, Exportable):
                exportables.append(m)

        qual_name = self.__module__ + '.' + self.__class__.__qualname__
        format = self.get_format(output)
        output_descr = f"{qual_name} exported to {format}"

        # Pytorch's default for None is too low, can't pass through
        if onnx_opset_version is None:
            onnx_opset_version = 13

        try:
            # Disable typechecks
            typecheck.set_typecheck_enabled(enabled=False)

            # Allow user to completely override forward method to export
            forward_method, old_forward_method = self._wrap_forward_method()

            # Set module to eval mode
            if set_eval:
                self.eval()

            if input_example is None:
                input_example = self._get_input_example()

            # Remove i/o examples from args we propagate to enclosed Exportables
            my_args.pop('output')
            my_args.pop('input_example')

            # Run (possibly overridden) prepare methods before calling forward()
            for ex in exportables:
                ex._prepare_for_export(**my_args)
            self._prepare_for_export(output=output,
                                     input_example=input_example,
                                     **my_args)

            input_list, input_dict = self._setup_input_example(input_example)
            input_names = self._process_input_names()
            output_names = self._process_output_names()
            output_example = tuple(self.forward(*input_list, **input_dict))

            with torch.jit.optimized_execution(True), torch.no_grad():
                jitted_model = None
                if try_script:
                    try:
                        jitted_model = torch.jit.script(self)
                    except Exception as e:
                        logging.error(f"jit.script() failed!\{e}")

                if format == ExportFormat.TORCHSCRIPT:
                    if jitted_model is None:
                        jitted_model = torch.jit.trace_module(
                            self,
                            {
                                "forward":
                                tuple(input_list) + tuple(input_dict.values())
                            },
                            strict=False,
                            check_trace=check_trace,
                            check_tolerance=check_tolerance,
                        )
                    if verbose:
                        logging.info(f"JIT code:\n{jitted_model.code}")
                    jitted_model.save(output)
                    assert os.path.exists(output)
                elif format == ExportFormat.ONNX:
                    if jitted_model is None:
                        jitted_model = self

                    # dynamic axis is a mapping from input/output_name => list of "dynamic" indices
                    if dynamic_axes is None and use_dynamic_axes:
                        dynamic_axes = get_input_dynamic_axes(
                            self.input_module, input_names)
                        dynamic_axes = {
                            **dynamic_axes,
                            **get_output_dynamic_axes(self.output_module, output_names)
                        }

                    torch.onnx.export(
                        jitted_model,
                        input_example,
                        output,
                        input_names=input_names,
                        output_names=output_names,
                        verbose=verbose,
                        export_params=export_params,
                        do_constant_folding=do_constant_folding,
                        keep_initializers_as_inputs=keep_initializers_as_inputs,
                        dynamic_axes=dynamic_axes,
                        opset_version=onnx_opset_version,
                    )

                    # Verify the model can be read, and is valid
                    onnx_model = onnx.load(output)
                    onnx.checker.check_model(onnx_model, full_check=True)

                    if check_trace:
                        self._verify_runtime(
                            onnx_model,
                            input_list,
                            input_dict,
                            input_names,
                            output_names,
                            output_example,
                            output,
                            check_tolerance,
                        )
                else:
                    raise ValueError(
                        f'Encountered unknown export format {format}.')
        finally:
            typecheck.set_typecheck_enabled(enabled=True)
            if forward_method:
                type(self).forward = old_forward_method
            self._export_teardown()
        return ([output], [output_descr])
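
A hedged usage sketch for this signature (the model variable and filename are placeholders); note the method returns a pair of lists, per the return statement above:

# Sketch: export to ONNX with the defaults above. check_trace=True also
# triggers the onnxruntime verification pass inside export().
paths, descriptions = model.export(
    "model.onnx",
    check_trace=True,
    onnx_opset_version=13,
)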
Example #8
    def export(
        self,
        output: str,
        input_example=None,
        output_example=None,
        verbose=False,
        export_params=True,
        do_constant_folding=True,
        keep_initializers_as_inputs=False,
        onnx_opset_version: int = 12,
        try_script: bool = False,
        set_eval: bool = True,
        check_trace: bool = True,
        use_dynamic_axes: bool = True,
        dynamic_axes=None,
        check_tolerance=0.01,
        forward_method=None,
    ):

        qual_name = self.__module__ + '.' + self.__class__.__qualname__
        output_descr = qual_name + ' exported to ONNX'
        exported = ([output], [output_descr])

        try:
            # Disable typechecks
            typecheck.set_typecheck_enabled(enabled=False)

            # Allow user to completely override forward method to export
            if forward_method is None and hasattr(type(self), "forward_for_export"):
                forward_method = type(self).forward_for_export

            if forward_method:
                old_forward_method = type(self).forward
                type(self).forward = forward_method

            # Set module to eval mode
            if set_eval:
                self.eval()

            format = self.get_format(output)
            self._prepare_for_export()

            with torch.jit.optimized_execution(True):
                jitted_model = None
                if try_script:
                    try:
                        jitted_model = torch.jit.script(self)
                    except Exception as e:
                        print("jit.script() failed!", e)

            if input_example is None:
                input_example = self.input_module.input_example()

            with torch.jit.optimized_execution(True):
                if format == ExportFormat.TORCHSCRIPT:
                    if isinstance(input_example, Dict):
                        input_example = tuple(input_example.values())

                    if jitted_model is None:
                        jitted_model = torch.jit.trace(
                            self,
                            input_example,
                            strict=False,
                            optimize=True,
                            check_trace=check_trace,
                            check_tolerance=check_tolerance,
                        )
                    jitted_model.save(output)
                    assert os.path.exists(output)

                elif format == ExportFormat.ONNX:
                    if jitted_model is None:
                        jitted_model = self
                    if output_example is None:
                        if isinstance(input_example, tuple):
                            output_example = self.forward(*input_example)
                        else:
                            output_example = self.forward(input_example)

                    input_names = self.input_module.get_input_names(input_example)
                    output_names = self.output_module.get_output_names(output_example)

                    # dynamic axis is a mapping from input/output_name => list of "dynamic" indices
                    if dynamic_axes is None and use_dynamic_axes:
                        dynamic_axes = self.input_module.get_input_dynamic_axes(input_names)
                        dynamic_axes = {**dynamic_axes, **self.output_module.get_output_dynamic_axes(output_names)}

                    if isinstance(input_example, Dict):
                        input_example = tuple(input_example.values())

                    torch.onnx.export(
                        jitted_model,
                        input_example,
                        output,
                        input_names=input_names,
                        output_names=output_names,
                        verbose=verbose,
                        export_params=export_params,
                        do_constant_folding=do_constant_folding,
                        keep_initializers_as_inputs=keep_initializers_as_inputs,
                        dynamic_axes=dynamic_axes,
                        opset_version=onnx_opset_version,
                        example_outputs=output_example,
                    )

                    # Verify the model can be read, and is valid
                    onnx_model = onnx.load(output)
                    onnx.checker.check_model(onnx_model, full_check=True)

                    if do_constant_folding:
                        if not ONNX_GRAPHSURGEON_AVAILABLE:
                            logging.info(
                                "onnx-graphsurgeon module is not installed. "
                                "That may result in suboptimal optimization of the exported ONNX graph "
                                "(including unneeded DOUBLE initializers). "
                                "Please follow the instructions available at: "
                                "https://github.com/NVIDIA/TensorRT/tree/master/tools/onnx-graphsurgeon "
                                "to install onnx-graphsurgeon from source and improve the exported graph."
                            )
                        else:
                            # This pass is to remove/recast certain constants that are generated as 'double'
                            # Those constants break ONNX -> TRT conversion (TRT does not support 'double' as of 7.2)
                            # Can probably be removed once TRT has automatic downcast for double.
                            # However, it may still be useful even then as it seems to always make the graph shorter.
                            graph = gs.import_onnx(onnx_model)
                            onnx_model = gs.export_onnx(graph.fold_constants().cleanup())
                            onnx.checker.check_model(onnx_model, full_check=True)
                            onnx.save(onnx_model, output)
                else:
                    raise ValueError(f'Encountered unknown export format {format}.')
        finally:
            typecheck.set_typecheck_enabled(enabled=True)
            if forward_method:
                type(self).forward = old_forward_method
        return exported
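
The constant-folding pass at the end of this example works on any saved ONNX file; here is a standalone sketch of the same onnx-graphsurgeon calls (the helper name is illustrative):

import onnx
import onnx_graphsurgeon as gs

def fold_double_constants(path):
    # Sketch of the post-export pass above: fold constants, clean up the
    # graph, re-validate, and overwrite the file in place.
    model = onnx.load(path)
    graph = gs.import_onnx(model)
    model = gs.export_onnx(graph.fold_constants().cleanup())
    onnx.checker.check_model(model, full_check=True)
    onnx.save(model, path)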
Example #9
    def export(
        self,
        output: str,
        input_example=None,
        output_example=None,
        verbose=False,
        export_params=True,
        do_constant_folding=True,
        keep_initializers_as_inputs=False,
        onnx_opset_version: int = 12,
        try_script: bool = False,
        set_eval: bool = True,
        check_trace: bool = True,
        use_dynamic_axes: bool = True,
    ):
        try:
            # Disable typechecks
            typecheck.set_typecheck_enabled(enabled=False)

            # Set module to eval mode
            if set_eval:
                self.eval()

            filename, file_extension = os.path.splitext(output)
            if file_extension not in _EXT_DICT.keys():
                raise ValueError(
                    f"Export file {output} extension does not correspond to any export format!"
                )

            format = _EXT_DICT[file_extension]

            self._prepare_for_export()

            if input_example is not None:
                _in_example = input_example
            else:
                _in_example = self.input_example()

            if output_example is None:
                _out_example = self.forward(*_in_example)
            else:
                _out_example = output_example

            if not (hasattr(self, 'input_types')
                    and hasattr(self, 'output_types')):
                raise NotImplementedError(
                    'For export to work you must define input and output types'
                )
            input_names = list(self.input_types.keys())
            output_names = list(self.output_types.keys())
            # dynamic axis is a mapping from input/output_name => list of "dynamic" indices
            dynamic_axes = defaultdict(list)

            # extract dynamic axes and remove unnecessary inputs/outputs
            # for input_ports
            for _name, ntype in self.input_types.items():
                if _name in self.disabled_deployment_input_names:
                    input_names.remove(_name)
                    continue
                if use_dynamic_axes:
                    dynamic_axes = {
                        **dynamic_axes,
                        **self._extract_dynamic_axes(_name, ntype)
                    }
            # for output_ports
            for _name, ntype in self.output_types.items():
                if _name in self.disabled_deployment_output_names:
                    output_names.remove(_name)
                    continue
                if use_dynamic_axes:
                    dynamic_axes = {
                        **dynamic_axes,
                        **self._extract_dynamic_axes(_name, ntype)
                    }

            if len(dynamic_axes) == 0:
                dynamic_axes = None

            with torch.jit.optimized_execution(True):
                jitted_model = None
                if try_script:
                    try:
                        jitted_model = torch.jit.script(self)
                    except Exception as e:
                        print("jit.script() failed!", e)
                if _in_example is None:
                    raise ValueError(
                        'Example input is None, and jit.script() failed or was not attempted'
                    )

                if isinstance(_in_example, Dict):
                    _in_example = tuple(_in_example.values())

                if jitted_model is None:
                    jitted_model = torch.jit.trace(self,
                                                   _in_example,
                                                   check_trace=check_trace)

                if format == ExportFormat.TORCHSCRIPT:
                    jitted_model.save(output)
                    assert os.path.exists(output)
                elif format == ExportFormat.ONNX:
                    if _out_example is None:
                        if isinstance(_in_example, tuple):
                            _out_example = self.forward(*_in_example)
                        else:
                            _out_example = self.forward(_in_example)

                    torch.onnx.export(
                        jitted_model,
                        _in_example,
                        output,
                        input_names=input_names,
                        output_names=output_names,
                        verbose=verbose,
                        export_params=export_params,
                        do_constant_folding=do_constant_folding,
                        keep_initializers_as_inputs=keep_initializers_as_inputs,
                        dynamic_axes=dynamic_axes,
                        opset_version=onnx_opset_version,
                        example_outputs=_out_example,
                    )

                    # Verify the model can be read, and is valid
                    onnx_model = onnx.load(output)
                    onnx.checker.check_model(onnx_model, full_check=True)
                    return onnx_model
                else:
                    raise ValueError(
                        f'Encountered unknown export format {format}.')
        finally:
            typecheck.set_typecheck_enabled(enabled=True)
        return [output]  # Subclasses may create more than one file.
Example #10
    def export(
        self,
        output: str,
        input_example=None,
        output_example=None,
        verbose=False,
        export_params=True,
        do_constant_folding=True,
        keep_initializers_as_inputs=False,
        onnx_opset_version: int = 12,
        try_script: bool = False,
        set_eval: bool = True,
        check_trace: bool = True,
        use_dynamic_axes: bool = True,
    ):
        try:
            # Disable typechecks
            typecheck.set_typecheck_enabled(enabled=False)

            # Set module to eval mode
            if set_eval:
                self.eval()

            format = self.get_format(output)
            self._prepare_for_export()

            if input_example is not None:
                _in_example = input_example
            else:
                _in_example = self.input_example()

            if output_example is None:
                _out_example = self.forward(*_in_example)
            else:
                _out_example = output_example

            if not (hasattr(self, 'input_types')
                    and hasattr(self, 'output_types')):
                raise NotImplementedError(
                    'For export to work you must define input and output types'
                )
            input_names = list(self.input_types.keys())
            output_names = list(self.output_types.keys())
            # dynamic axis is a mapping from input/output_name => list of "dynamic" indices
            dynamic_axes = defaultdict(list)

            # extract dynamic axes and remove unnecessary inputs/outputs
            # for input_ports
            for _name, ntype in self.input_types.items():
                if _name in self.disabled_deployment_input_names:
                    input_names.remove(_name)
                    continue
                if use_dynamic_axes:
                    dynamic_axes = {
                        **dynamic_axes,
                        **self._extract_dynamic_axes(_name, ntype)
                    }
            # for output_ports
            for _name, ntype in self.output_types.items():
                if _name in self.disabled_deployment_output_names:
                    output_names.remove(_name)
                    continue
                if use_dynamic_axes:
                    dynamic_axes = {
                        **dynamic_axes,
                        **self._extract_dynamic_axes(_name, ntype)
                    }

            if len(dynamic_axes) == 0:
                dynamic_axes = None

            with torch.jit.optimized_execution(True):
                jitted_model = None
                if try_script:
                    try:
                        jitted_model = torch.jit.script(self)
                    except Exception as e:
                        print("jit.script() failed!", e)
                if _in_example is None:
                    raise ValueError(
                        'Example input is None, and jit.script() failed or was not attempted'
                    )

                if isinstance(_in_example, Dict):
                    _in_example = tuple(_in_example.values())

                if jitted_model is None:
                    jitted_model = torch.jit.trace(self,
                                                   _in_example,
                                                   check_trace=check_trace)

                if format == ExportFormat.TORCHSCRIPT:
                    jitted_model.save(output)
                    assert os.path.exists(output)
                elif format == ExportFormat.ONNX:
                    if _out_example is None:
                        if isinstance(_in_example, tuple):
                            _out_example = self.forward(*_in_example)
                        else:
                            _out_example = self.forward(_in_example)

                    torch.onnx.export(
                        jitted_model,
                        _in_example,
                        output,
                        input_names=input_names,
                        output_names=output_names,
                        verbose=verbose,
                        export_params=export_params,
                        do_constant_folding=do_constant_folding,
                        keep_initializers_as_inputs=keep_initializers_as_inputs,
                        dynamic_axes=dynamic_axes,
                        opset_version=onnx_opset_version,
                        example_outputs=_out_example,
                    )

                    # Verify the model can be read, and is valid
                    onnx_model = onnx.load(output)
                    onnx.checker.check_model(onnx_model, full_check=True)

                    if do_constant_folding:
                        if not ONNX_GRAPHSURGEON_AVAILABLE:
                            logging.info(
                                "onnx-graphsurgeon module is not installed. "
                                "That may result in suboptimal optimization of the exported ONNX graph "
                                "(including unneeded DOUBLE initializers). "
                                "Please follow the instructions available at: "
                                "https://github.com/NVIDIA/TensorRT/tree/master/tools/onnx-graphsurgeon "
                                "to install onnx-graphsurgeon from source and improve the exported graph."
                            )
                        else:
                            # This pass is to remove/recast certain constants that are generated as 'double'
                            # Those constants break ONNX -> TRT conversion (TRT does not support 'double' as of 7.2)
                            # Can probably be removed once TRT has automatic downcast for double.
                            # However, it may still be useful even then as it seems to always make the graph shorter.
                            graph = gs.import_onnx(onnx_model)
                            onnx_model = gs.export_onnx(
                                graph.fold_constants().cleanup())
                            onnx.checker.check_model(onnx_model,
                                                     full_check=True)
                            onnx.save(onnx_model, output)

                    return onnx_model
                else:
                    raise ValueError(
                        f'Encountered unknown export format {format}.')
        finally:
            typecheck.set_typecheck_enabled(enabled=True)
        return [output]  # Subclasses may create more than one file.
Example #11
    def export(
        self,
        output: str,
        input_example=None,
        output_example=None,
        verbose=False,
        export_params=True,
        do_constant_folding=True,
        keep_initializers_as_inputs=False,
        onnx_opset_version: int = 12,
        try_script: bool = False,
        set_eval: bool = True,
        check_trace: bool = True,
        use_dynamic_axes: bool = True,
        dynamic_axes=None,
        check_tolerance=0.01,
        forward_method=None,
    ):
        my_args = locals()
        del my_args['self']

        qual_name = self.__module__ + '.' + self.__class__.__qualname__
        output_descr = qual_name + ' exported to ONNX'

        try:
            # Disable typechecks
            typecheck.set_typecheck_enabled(enabled=False)

            # Allow user to completely override forward method to export
            if forward_method is None and hasattr(type(self),
                                                  "forward_for_export"):
                forward_method = type(self).forward_for_export

            if forward_method:
                old_forward_method = type(self).forward
                type(self).forward = forward_method

            # Set module to eval mode
            if set_eval:
                self.eval()

            format = self.get_format(output)

            if input_example is None:
                input_example = self.input_module.input_example()

            if isinstance(input_example, Dict):
                input_example = tuple(input_example.values())

            my_args['input_example'] = input_example
            self._prepare_for_export(**my_args)

            if output_example is None:
                if isinstance(input_example, tuple):
                    output_example = self.forward(*input_example)
                else:
                    output_example = self.forward(input_example)

            input_names = self.input_module.get_input_names(input_example)
            output_names = self.output_module.get_output_names(output_example)

            with torch.jit.optimized_execution(True), torch.no_grad():
                jitted_model = None
                if try_script:
                    try:
                        jitted_model = torch.jit.script(self)
                    except Exception as e:
                        print("jit.script() failed!", e)

            with torch.jit.optimized_execution(True), torch.no_grad():
                if format == ExportFormat.TORCHSCRIPT:
                    if jitted_model is None:
                        jitted_model = torch.jit.trace(
                            self,
                            input_example,
                            strict=False,
                            optimize=True,
                            check_trace=check_trace,
                            check_tolerance=check_tolerance,
                        )
                    if verbose:
                        print(jitted_model.code)
                    jitted_model.save(output)
                    assert os.path.exists(output)

                elif format == ExportFormat.ONNX:
                    if jitted_model is None:
                        jitted_model = self

                    # dynamic axis is a mapping from input/output_name => list of "dynamic" indices
                    if dynamic_axes is None and use_dynamic_axes:
                        dynamic_axes = self.input_module.get_input_dynamic_axes(
                            input_names)
                        dynamic_axes = {
                            **dynamic_axes,
                            **self.output_module.get_output_dynamic_axes(output_names)
                        }

                    torch.onnx.export(
                        jitted_model,
                        input_example,
                        output,
                        input_names=input_names,
                        output_names=output_names,
                        verbose=verbose,
                        export_params=export_params,
                        do_constant_folding=do_constant_folding,
                        keep_initializers_as_inputs=keep_initializers_as_inputs,
                        dynamic_axes=dynamic_axes,
                        opset_version=onnx_opset_version,
                        example_outputs=output_example,
                    )

                    # Verify the model can be read, and is valid
                    onnx_model = onnx.load(output)
                    onnx.checker.check_model(onnx_model, full_check=True)

                else:
                    raise ValueError(
                        f'Encountered unknown export format {format}.')
        finally:
            typecheck.set_typecheck_enabled(enabled=True)
            if forward_method:
                type(self).forward = old_forward_method
        return ([output], [output_descr])
Example #12
def nemo_export(argv):
    args = get_args(argv)
    loglevel = logging.INFO
    # assuming loglevel is bound to the string value obtained from the
    # command line argument. Convert to upper case to allow the user to
    # specify --log=DEBUG or --log=debug
    if args.verbose is not None:
        numeric_level = getattr(logging, args.verbose.upper(), None)
        if not isinstance(numeric_level, int):
            raise ValueError('Invalid log level: %s' % args.verbose)
        loglevel = numeric_level

    logger = logging.getLogger(__name__)
    if logger.handlers:
        for handler in logger.handlers:
            logger.removeHandler(handler)
    logging.basicConfig(level=loglevel, format='%(asctime)s [%(levelname)s] %(message)s')
    logging.info("Logging level set to {}".format(loglevel))

    """Convert a .nemo saved model into .riva Riva input format."""
    nemo_in = args.source
    out = args.out

    # Create a PL trainer object which is required for restoring Megatron models
    cfg_trainer = TrainerConfig(
        gpus=1,
        accelerator="ddp",
        num_nodes=1,
        # Need to set the following two to False as ExpManager will take care of them differently.
        logger=False,
        checkpoint_callback=False,
    )
    trainer = Trainer(cfg_trainer)

    logging.info("Restoring NeMo model from '{}'".format(nemo_in))
    try:
        with torch.inference_mode():
            # Restore instance from .nemo file using generic model restore_from
            model = ModelPT.restore_from(restore_path=nemo_in, trainer=trainer)
    except Exception as e:
        logging.error(
            "Failed to restore model from NeMo file : {}. Please make sure you have the latest NeMo package installed with [all] dependencies.".format(
                nemo_in
            )
        )
        raise e

    logging.info("Model {} restored from '{}'".format(model.cfg.target, nemo_in))

    if not isinstance(model, Exportable):
        logging.error("Your NeMo model class ({}) is not Exportable.".format(model.cfg.target))
        sys.exit(1)
    typecheck.set_typecheck_enabled(enabled=False)

    try:
        #
        #  Add custom export parameters here
        #
        in_args = {}
        if args.max_batch is not None:
            in_args["max_batch"] = args.max_batch
        if args.max_dim is not None:
            in_args["max_dim"] = args.max_dim

        autocast = nullcontext
        model = model.to(device=args.device)
        model.eval()
        with torch.inference_mode():
            input_example = model.input_module.input_example(**in_args)
        if args.autocast:
            autocast = torch.cuda.amp.autocast
        with autocast(), torch.inference_mode():
            logging.info(f"Getting output example")
            input_list, input_dict = parse_input_example(input_example)
            output_example = forward_method(model)(*input_list, **input_dict)
            logging.info(f"Exporting model with autocast={args.autocast}")
            input_names = model.input_names
            output_names = model.output_names

            _, descriptions = model.export(
                out,
                check_trace=False,
                input_example=input_example,
                onnx_opset_version=args.onnx_opset,
                verbose=args.verbose,
            )

    except Exception as e:
        logging.error(
            "Export failed. Please make sure your NeMo model class ({}) has working export() and that you have the latest NeMo package installed with [all] dependencies.".format(
                model.cfg.target
            )
        )
        raise e

    logging.info("Successfully exported to {}".format(out))

    del model

    if args.runtime_check:
        verify_runtime(out, input_list, input_dict, input_names, output_names, output_example)
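
nemo_export() takes the raw argument vector and parses it itself via get_args(), so a minimal entry point is just the following sketch:

import sys

if __name__ == '__main__':
    nemo_export(sys.argv[1:])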
Example #13
    def export(
        self,
        output: str,
        input_example=None,
        output_example=None,
        verbose=False,
        export_params=True,
        do_constant_folding=True,
        keep_initializers_as_inputs=False,
        onnx_opset_version: int = 13,
        try_script: bool = False,
        set_eval: bool = True,
        check_trace: bool = False,
        use_dynamic_axes: bool = True,
        dynamic_axes=None,
        check_tolerance=0.01,
    ):
        my_args = locals()
        del my_args['self']

        qual_name = self.__module__ + '.' + self.__class__.__qualname__
        output_descr = qual_name + ' exported to ONNX'

        try:
            # Disable typechecks
            typecheck.set_typecheck_enabled(enabled=False)

            # Allow user to completely override forward method to export
            forward_method, old_forward_method = self._wrap_forward_method()

            # Set module to eval mode
            self._set_eval(set_eval)

            format = self.get_format(output)

            if input_example is None:
                input_example = self._get_input_example()

            my_args['input_example'] = input_example

            # Run (possibly overridden) prepare method before calling forward()
            self._prepare_for_export(**my_args)

            input_list, input_dict = self._setup_input_example(input_example)

            input_names = self._process_input_names()
            output_names = self._process_output_names()

            output_example = self.forward(*input_list, **input_dict)

            with torch.jit.optimized_execution(True), torch.no_grad():
                jitted_model = self._try_jit_compile_model(self, try_script)

                if format == ExportFormat.TORCHSCRIPT:
                    self._export_torchscript(jitted_model, output, input_dict,
                                             input_list, check_trace,
                                             check_tolerance, verbose)

                elif format == ExportFormat.ONNX:
                    self._export_onnx(
                        jitted_model,
                        input_example,
                        output_example,
                        input_names,
                        output_names,
                        use_dynamic_axes,
                        do_constant_folding,
                        dynamic_axes,
                        output,
                        export_params,
                        keep_initializers_as_inputs,
                        onnx_opset_version,
                        verbose,
                    )

                    # Verify the model can be read, and is valid
                    self._verify_onnx_export(output, output_example,
                                             input_list, input_dict,
                                             input_names, check_tolerance,
                                             check_trace)
                else:
                    raise ValueError(
                        f'Encountered unknown export format {format}.')
        finally:
            typecheck.set_typecheck_enabled(enabled=True)
            if forward_method:
                type(self).forward = old_forward_method
        return ([output], [output_descr])
Example #14
    def export(
        self,
        output: str,
        input_example=None,
        output_example=None,
        verbose=False,
        export_params=True,
        do_constant_folding=True,
        keep_initializers_as_inputs=False,
        onnx_opset_version: int = 13,
        try_script: bool = False,
        set_eval: bool = True,
        check_trace: bool = False,
        use_dynamic_axes: bool = True,
        dynamic_axes=None,
        check_tolerance=0.01,
    ):
        my_args = locals()
        del my_args['self']

        qual_name = self.__module__ + '.' + self.__class__.__qualname__
        output_descr = qual_name + ' exported to ONNX'

        try:
            # Disable typechecks
            typecheck.set_typecheck_enabled(enabled=False)

            # Set module to eval mode
            self._set_eval(set_eval)

            format = self.get_format(output)

            # Assign special flag for RNNT export of encoder
            if not hasattr(self.input_module, '_rnnt_export'):
                raise ValueError(
                    f"{self.input_module.__class__.__name__} must have a bool attribute `_rnnt_export`, "
                    f"which is necessary for RNNT export.")

            if not hasattr(self.output_module, '_rnnt_export'):
                raise ValueError(
                    f"{self.output_module.__class__.__name__} must have a bool attribute `_rnnt_export`, "
                    f"which is necessary for RNNT export.")

            if not hasattr(self.joint_module, '_rnnt_export'):
                raise ValueError(
                    f"{self.joint_module.__class__.__name__} must have a bool attribute `_rnnt_export`, "
                    f"which is necessary for RNNT export.")

            self.input_module._rnnt_export = True
            self.output_module._rnnt_export = True
            self.joint_module._rnnt_export = True

            if input_example is None:
                encoder_examples, decoder_examples = self._get_input_example()
                input_example = [encoder_examples, decoder_examples]
            else:
                assert type(input_example) in (
                    list, tuple) and len(input_example) == 2, (
                        "input_example must "
                        "be a list of two tensors,"
                        "for encoder and decoder input")
                encoder_examples, decoder_examples = input_example

            if output_example is not None:
                assert type(output_example) in (
                    list, tuple) and len(output_example) == 2, (
                        "output_example must "
                        "be a list of two tensors,"
                        "for encoder and decoder+joint"
                        " output")
                encoder_output_example, decoder_joint_output_example = output_example
            else:
                encoder_output_example = None
                decoder_joint_output_example = None

            my_args['input_example'] = input_example

            # Run (possibly overridden) prepare method before calling forward()
            self._prepare_for_export(**my_args)

            encoder_input_list, encoder_input_dict = self._setup_input_example(
                encoder_examples)
            decoder_input_list, decoder_input_dict = self._setup_input_example(
                decoder_examples)

            encoder_input_names, decoder_input_names = self._process_input_names(
            )
            encoder_output_names, decoder_output_names, joint_output_names = self._process_output_names(
            )

            # process decoder states; by convention, states must be last in the list and wrapped in a tuple
            (
                decoder_input_list,
                decoder_input_names,
                input_state_names,
                num_states,
                output_state_names,
                state_names,
            ) = self._process_states_names(decoder_input_list,
                                           decoder_input_names)

            with torch.jit.optimized_execution(True), torch.no_grad():
                # Encoder export
                encoder_jitted_model = self._try_jit_compile_model(
                    self.input_module, try_script)

                if format == exportable.ExportFormat.TORCHSCRIPT:
                    raise NotImplementedError()
                    # # Allow user to completely override forward method to export
                    # forward_method, original_forward_method = self._wrap_forward_method('encoder')
                    # encoder_output_example = self.forward(*encoder_input_list, **encoder_input_dict)
                    #
                    # self._export_torchscript(
                    #     encoder_jitted_model,
                    #     self._augment_output_filename(output, "Encoder"),
                    #     encoder_input_dict,
                    #     encoder_input_list,
                    #     check_trace,
                    #     check_tolerance,
                    #     verbose,
                    # )
                    #
                    # self._export_flag_module = 'decoder_joint'
                    #
                    # # Extract just the encoder logits and remove the encoder lengths
                    # if type(encoder_output_example) in (list, tuple):
                    #     encoder_output_example = encoder_output_example[0]
                    #
                    # encoder_decoder_input_list = [encoder_output_example] + list(decoder_input_list)
                    # encoder_decoder_input_dict = decoder_input_dict
                    #
                    # encoder_decoder_input_list = tuple(encoder_decoder_input_list)
                    #
                    # # Allow user to completely override forward method to export
                    # forward_method, _ = self._wrap_forward_method('decoder_joint')
                    # decoder_joint_output_example = self.forward(*encoder_decoder_input_list, **encoder_decoder_input_dict)
                    # decoder_joint_output_example = tuple(decoder_joint_output_example)
                    #
                    # # Resolve output states
                    # if num_states > 0:
                    #     if type(decoder_joint_output_example[-1]) == tuple:
                    #         raise TypeError("Since input states are available, forward must emit flattened states")
                    #
                    #     # remove the name of the states
                    #     logging.info(
                    #         f"Replacing output state name {decoder_output_names[-1]} with {str(output_state_names)}"
                    #     )
                    #     decoder_output_names = decoder_output_names[:-1]
                    #
                    # self._export_torchscript(
                    #     None,
                    #     self._augment_output_filename(output, "Decoder-Joint"),
                    #     encoder_decoder_input_dict,
                    #     encoder_decoder_input_list,
                    #     check_trace,
                    #     check_tolerance,
                    #     verbose,
                    # )

                elif format == exportable.ExportFormat.ONNX:
                    # Allow user to completely override forward method to export
                    forward_method, original_forward_method = self._wrap_forward_method(
                        'encoder')
                    encoder_output_example = self.forward(
                        *encoder_input_list, **encoder_input_dict)

                    self._export_flag_module = 'encoder'

                    self._export_onnx(
                        encoder_jitted_model,
                        encoder_examples,
                        encoder_output_example,
                        encoder_input_names,
                        encoder_output_names,
                        use_dynamic_axes,
                        do_constant_folding,
                        dynamic_axes,
                        self._augment_output_filename(output, "Encoder"),
                        export_params,
                        keep_initializers_as_inputs,
                        onnx_opset_version,
                        verbose,
                    )

                    # Verify the model can be read, and is valid
                    self._verify_onnx_export(
                        self._augment_output_filename(output, "Encoder"),
                        encoder_output_example,
                        encoder_input_list,
                        encoder_input_dict,
                        encoder_input_names,
                        encoder_output_names,
                        check_tolerance,
                        check_trace,
                    )

                    self._export_flag_module = 'decoder_joint'

                    # Extract just the encoder logits and remove the encoder lengths
                    if type(encoder_output_example) in (list, tuple):
                        encoder_output_example = encoder_output_example[0]

                    encoder_decoder_input_list = [encoder_output_example
                                                  ] + list(decoder_input_list)
                    encoder_decoder_input_dict = decoder_input_dict

                    encoder_decoder_input_list = tuple(
                        encoder_decoder_input_list)

                    # Allow user to completely override forward method to export
                    forward_method, _ = self._wrap_forward_method(
                        'decoder_joint')
                    decoder_joint_output_example = self.forward(
                        *encoder_decoder_input_list,
                        **encoder_decoder_input_dict)
                    decoder_joint_output_example = tuple(
                        decoder_joint_output_example)

                    # Resolve output states
                    if num_states > 0:
                        if type(decoder_joint_output_example[-1]) == tuple:
                            raise TypeError(
                                "Since input states are available, forward must emit flattened states"
                            )

                        # remove the name of the states
                        logging.info(
                            f"Replacing output state name {decoder_output_names[-1]} with {str(output_state_names)}"
                        )
                        decoder_output_names = decoder_output_names[:-1]

                    self._export_onnx(
                        None,
                        encoder_decoder_input_list,
                        decoder_joint_output_example,
                        self._join_input_output_names(["encoder_outputs"],
                                                      decoder_input_names,
                                                      input_state_names),
                        self._join_input_output_names(joint_output_names,
                                                      decoder_output_names,
                                                      output_state_names),
                        use_dynamic_axes,
                        do_constant_folding,
                        dynamic_axes,
                        self._augment_output_filename(output, "Decoder-Joint"),
                        export_params,
                        keep_initializers_as_inputs,
                        onnx_opset_version,
                        verbose,
                    )

                    # Verify the model can be read, and is valid
                    self._verify_onnx_export(
                        self._augment_output_filename(output, "Decoder-Joint"),
                        decoder_joint_output_example,
                        encoder_decoder_input_list,
                        encoder_decoder_input_dict,
                        self._join_input_output_names(["encoder_outputs"],
                                                      decoder_input_names,
                                                      input_state_names),
                        self._join_input_output_names(joint_output_names,
                                                      decoder_output_names,
                                                      output_state_names),
                        check_tolerance,
                        check_trace,
                    )

                else:
                    raise ValueError(
                        f'Encountered unknown export format {format}.')

        finally:
            typecheck.set_typecheck_enabled(enabled=True)
            logging.warning(
                "PyTorch Model has been significantly modified. In order to utilize model, delete this "
                "instance and create a new model.")

            # replace forward method with original forward method

            type(self).forward = original_forward_method

            # Reset special flag for RNNT export of encoder
            self.input_module._rnnt_export = False
            self.output_module._rnnt_export = False
            self.joint_module._rnnt_export = False

        return ([output], [output_descr])
Example #15
    def export(
        self,
        output: str,
        input_example=None,
        output_example=None,
        verbose=False,
        export_params=True,
        do_constant_folding=True,
        keep_initializers_as_inputs=False,
        onnx_opset_version: int = 13,
        try_script: bool = False,
        set_eval: bool = True,
        check_trace: bool = False,
        use_dynamic_axes: bool = True,
        dynamic_axes=None,
        check_tolerance=0.01,
    ):
        my_args = locals()
        del my_args['self']

        qual_name = self.__module__ + '.' + self.__class__.__qualname__
        output_descr = qual_name + ' exported to ONNX'

        try:
            # Disable typechecks
            typecheck.set_typecheck_enabled(enabled=False)

            # Allow user to completely override forward method to export
            if hasattr(type(self), "forward_for_export"):
                forward_method = type(self).forward_for_export
                old_forward_method = type(self).forward
                type(self).forward = forward_method
            else:
                forward_method = None

            # Set module to eval mode
            if set_eval:
                self.freeze()
                self.input_module.freeze()
                self.output_module.freeze()

            format = self.get_format(output)

            if input_example is None:
                input_example = self.input_module.input_example()

            my_args['input_example'] = input_example

            # Run (possibly overridden) prepare method before calling forward()
            self._prepare_for_export(**my_args)

            input_list = list(input_example)
            input_dict = {}
            # process possible kwargs
            if isinstance(input_list[-1], dict):
                input_dict = input_list[-1]
                input_list = input_list[:-1]

            input_names = get_input_names(self.input_module)
            # remove unnecessary inputs for input_ports
            for name in self.disabled_deployment_input_names:
                input_names.remove(name)
            output_names = get_output_names(self.output_module)
            # remove unnecessary outputs for output_ports
            for name in self.disabled_deployment_output_names:
                output_names.remove(name)

            output_example = self.forward(*input_list, **input_dict)

            with torch.jit.optimized_execution(True), torch.no_grad():
                jitted_model = None
                if try_script:
                    try:
                        jitted_model = torch.jit.script(self)
                    except Exception as e:
                        print("jit.script() failed!", e)

                if format == ExportFormat.TORCHSCRIPT:
                    if jitted_model is None:
                        jitted_model = torch.jit.trace_module(
                            self,
                            {
                                "forward":
                                tuple(input_list) + tuple(input_dict.values())
                            },
                            strict=False,
                            optimize=True,
                            check_trace=check_trace,
                            check_tolerance=check_tolerance,
                        )
                    if verbose:
                        print(jitted_model.code)
                    jitted_model.save(output)
                    assert os.path.exists(output)

                elif format == ExportFormat.ONNX:
                    if jitted_model is None:
                        jitted_model = self

                    # dynamic axis is a mapping from input/output_name => list of "dynamic" indices
                    if dynamic_axes is None and use_dynamic_axes:
                        dynamic_axes = get_input_dynamic_axes(
                            self.input_module, input_names)
                        dynamic_axes = {
                            **dynamic_axes,
                            **get_output_dynamic_axes(self.output_module, output_names)
                        }

                    torch.onnx.export(
                        jitted_model,
                        input_example,
                        output,
                        input_names=input_names,
                        output_names=output_names,
                        verbose=verbose,
                        export_params=export_params,
                        do_constant_folding=do_constant_folding,
                        keep_initializers_as_inputs=keep_initializers_as_inputs,
                        dynamic_axes=dynamic_axes,
                        opset_version=onnx_opset_version,
                        example_outputs=output_example,
                    )

                    # Verify the model can be read, and is valid
                    onnx_model = onnx.load(output)
                    onnx.checker.check_model(onnx_model, full_check=True)
                    test_runtime = check_trace
                    if test_runtime:
                        try:
                            import onnxruntime
                        except (ImportError, ModuleNotFoundError):
                            test_runtime = False
                            logging.warning(
                                f"ONNX generated at {output}, not verified - please install onnxruntime.\n"
                            )

                    if test_runtime:
                        sess = onnxruntime.InferenceSession(
                            onnx_model.SerializeToString())

                        ort_out = sess.run(
                            None,
                            to_onnxrt_input(input_names, input_list,
                                            input_dict))
                        all_good = True
                        for i, out in enumerate(ort_out):
                            expected = output_example[i].cpu()
                            if not torch.allclose(torch.from_numpy(out),
                                                  expected,
                                                  rtol=check_tolerance,
                                                  atol=100 * check_tolerance):
                                all_good = False
                                logging.info(
                                    f"onnxruntime results mismatch! PyTorch(expected):\n{expected}\nONNXruntime:\n{out}"
                                )
                        status = "SUCCESS" if all_good else "FAIL"
                        logging.info(
                            f"ONNX generated at {output} verified with onnxruntime : "
                            + status)
                else:
                    raise ValueError(
                        f'Encountered unknown export format {format}.')
        finally:
            typecheck.set_typecheck_enabled(enabled=True)
            if forward_method:
                type(self).forward = old_forward_method
        return ([output], [output_descr])