Exemplo n.º 1
0
def _export(model,
            args,
            strip_large_tensor_data=False,
            large_tensor_threshold=LARGE_TENSOR_DATA_THRESHOLD,
            **kwargs):
    """Export *model* to an ONNX graph and post-process the result.

    The model is traced through ``_export_util`` into an in-memory buffer,
    then reloaded as an ONNX ``ModelProto`` so annotations can be applied
    and extra outputs registered.

    Args:
        model: Module to export; its gradients are zeroed first.
        args: Inputs forwarded to the tracer.
        strip_large_tensor_data: When True, drop raw data from large
            initializers via ``_strip_large_initializer_raw_data``.
        large_tensor_threshold: Size threshold used for the stripping above.
        **kwargs: Forwarded to ``_export_util`` (e.g. ``opset_version``).

    Returns:
        Tuple of the processed ONNX model and the traced outputs.
    """
    model.zero_grad()
    buffer = io.BytesIO()

    opset = kwargs.get('opset_version')
    if opset is None:
        opset = _default_onnx_opset_version

    # Doc strings are stripped here (after annotation), not by the exporter,
    # so the exporter is always called with strip_doc_string=False.
    want_strip_doc = kwargs.pop('strip_doc_string', True)

    with init_annotate(model, opset) as annotator, \
            as_output.trace(model) as (model, outputs):
        raw_outs = _export_util(model,
                                args,
                                buffer,
                                strip_doc_string=False,
                                **kwargs)
        graph = onnx.load(io.BytesIO(buffer.getvalue()))
        graph = annotator.set_annotate(graph)
        graph = annotator.reorg_anchor(graph)
        outputs.add_outputs_to_model(graph)

    if want_strip_doc:
        for node in graph.graph.node:
            node.doc_string = b''
    if strip_large_tensor_data:
        _strip_large_initializer_raw_data(graph, large_tensor_threshold)

    return graph, raw_outs
Exemplo n.º 2
0
        def new_forward(self, *args):
            # Replacement forward: on first call, trace the module with
            # torch.jit.trace_module and remember the names of the extra
            # (intermediate) outputs; on later calls, run the cached trace.
            if hasattr(self, '_traced_mod'):
                # Already traced: run the traced module directly.
                out = self._traced_mod(*args)
                inter_size = len(self._names)
                if inter_size == 0:
                    # No named intermediate outputs were registered.
                    return out
                if not isinstance(out, (tuple, list)):
                    out = [out]
                # The trace appends the named intermediate outputs after the
                # original outputs; split them back apart.  Original outputs
                # are keyed by their position ("0", "1", ...), intermediates
                # by their registered name.
                return dict(
                    **{str(i): x
                       for i, x in enumerate(out[:-inter_size])},
                    **{
                        name: x
                        for name, x in zip(self._names, out[-inter_size:])
                    },
                )

            # First call: temporarily swap in the initialization forward so
            # that tracing goes through _forward_with_init.  The local name
            # 'new_forward' is deliberately rebound to save this function.
            new_forward = self.forward
            self.forward = self._forward_with_init

            with trace(self) as (new_module, outputs):
                with warnings.catch_warnings():
                    # torch.jit.trace_module is noisy; suppress its warnings.
                    warnings.simplefilter("ignore")
                    self._traced_mod = torch.jit.trace_module(
                        new_module, {"forward": args})
                # Record the names of the outputs captured by trace();
                # presumably these are the as_output-registered tensors —
                # TODO(review): confirm against the trace() helper.
                self._names = [out.name for out in outputs.values]

            # Restore this wrapper as the forward and re-dispatch, which now
            # takes the traced fast path above.
            self.forward = new_forward
            return self.forward(*args)
Exemplo n.º 3
0
def _export(
        model: torch.nn.Module,
        args: Sequence[Any],
        strip_large_tensor_data: bool = False,
        large_tensor_threshold: int = LARGE_TENSOR_DATA_THRESHOLD,
        use_pfto: bool = False,
        **kwargs: Any,
) -> Tuple[onnx.ModelProto, Any]:
    """Export *model* to ONNX and post-process the resulting graph.

    Traces the model (with ``pfto_export`` when *use_pfto* is set, otherwise
    ``_export_util``) into an in-memory buffer, reloads it as a ``ModelProto``,
    applies annotations, registers extra outputs, and optionally strips
    doc strings and large initializer data.

    Args:
        model: Module to export; its gradients are zeroed first.
        args: Inputs forwarded to the tracer.
        strip_large_tensor_data: Drop raw data from large initializers.
        large_tensor_threshold: Size threshold for the stripping above.
        use_pfto: Use the pfto exporter instead of ``torch.onnx.export``.
        **kwargs: Forwarded to the chosen exporter.

    Returns:
        Tuple of the processed ONNX model and the traced outputs.
    """
    model.zero_grad()
    buffer = io.BytesIO()

    # Default the opset and record it back into kwargs for the exporter.
    if kwargs.get('opset_version') is None:
        kwargs['opset_version'] = _default_onnx_opset_version
    opset = kwargs['opset_version']

    # Doc strings are stripped here, after annotation.  torch>=1.10 removed
    # the strip_doc_string argument, so the two exporter paths differ in how
    # the option is consumed.
    if use_pfto or not pytorch_pfn_extras.requires('1.10.0'):
        want_strip_doc = kwargs.get('strip_doc_string', True)
        kwargs['strip_doc_string'] = False
    else:
        want_strip_doc = kwargs.pop('strip_doc_string', True)
        # NOTE(review): verbose is forced on for the torch>=1.10 path —
        # presumably needed to retain node doc strings; confirm upstream.
        kwargs['verbose'] = True

    with init_annotate(model, opset) as annotator, \
            as_output.trace(model) as (model, outputs), \
            grad.init_grad_state():
        exporter = pfto_export if use_pfto else _export_util
        raw_outs = exporter(model, args, buffer, **kwargs)
        graph = onnx.load(io.BytesIO(buffer.getvalue()))
        graph = annotator.set_annotate(graph)
        graph = annotator.reorg_anchor(graph)
        outputs.add_outputs_to_model(graph)
        if want_strip_doc:
            for node in graph.graph.node:
                node.doc_string = b''

    if strip_large_tensor_data:
        _strip_large_initializer_raw_data(graph, large_tensor_threshold)

    return graph, raw_outs