Example #1
    def torchscript_export(self, model, export_path=None, quantize=False):
        cuda.CUDA_ENABLED = False
        model.cpu()
        optimizer = self.trainer.optimizer
        optimizer.pre_export(model)

        model.eval()
        model.prepare_for_onnx_export_()

        unused_raw_batch, batch = next(
            iter(self.data.batches(Stage.TRAIN, load_early=True))
        )
        inputs = model.onnx_trace_input(batch)
        model(*inputs)
        if quantize:
            model.quantize()
        if self.trace_both_encoders:
            trace = jit.trace(model, inputs)
        else:
            trace = jit.trace(model.encoder1, (inputs[0],))
        if hasattr(model, "torchscriptify"):
            trace = model.torchscriptify(
                self.data.tensorizers, trace, self.trace_both_encoders
            )
        trace.apply(lambda s: s._pack() if s._c._has_method("_pack") else None)
        if export_path is not None:
            print(f"Saving torchscript model to: {export_path}")
            trace.save(export_path)
        return trace
Example #2
    def translate(self):
        translation_plan = self.plan.copy()
        translation_plan.forward = None

        args_shape = translation_plan.get_args_shape()
        args = PlaceHolder.create_placeholders(args_shape)

        # To avoid storing Plan state tensors in torchscript, they will be sent as parameters.
        # We trace a wrapper func, which accepts the state parameters as its last arg
        # and sets them on the Plan before executing the Plan.
        def wrap_stateful_plan(*args):
            role = translation_plan.role
            state = args[-1]
            if 0 < len(role.state.state_placeholders) == len(
                    state) and isinstance(state, (list, tuple)):
                state_placeholders = tuple(
                    role.placeholders[ph.id.value]
                    for ph in role.state.state_placeholders)
                PlaceHolder.instantiate_placeholders(
                    role.state.state_placeholders, state)
                PlaceHolder.instantiate_placeholders(state_placeholders, state)

            return translation_plan(*args[:-1])

        plan_params = translation_plan.parameters()
        if len(plan_params) > 0:
            torchscript_plan = jit.trace(wrap_stateful_plan,
                                         (*args, plan_params))
        else:
            torchscript_plan = jit.trace(translation_plan, args)

        self.plan.torchscript = torchscript_plan
        return self.plan
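A minimal standalone sketch of the idea in the comment above, using plain torch only (no PySyft Plan; plan_body, w, and b are illustrative assumptions, and a PyTorch version whose jit.trace accepts nested tuples of tensors is assumed): tensors passed to jit.trace as inputs stay inputs of the traced graph, while tensors captured from a closure are frozen into it as constants, which is why the wrapper takes the state as an explicit trailing argument.

import torch
from torch import jit

def plan_body(x, state):
    # state is passed explicitly, so its tensors become graph inputs
    # instead of constants baked into the trace
    w, b = state
    return x @ w + b

w, b = torch.randn(3, 2), torch.randn(2)
traced = jit.trace(plan_body, (torch.randn(1, 3), (w, b)))
out = traced(torch.randn(1, 3), (w, b))  # state can be swapped at call time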
Example #3
def main_2():

	orig_model = MyModel2()

	# case 1
	input1 = torch.randn(2, 3)
	input1 = input1.pow(2) # always positive
	traced_model1 = jit.trace(orig_model, (input1))

	# case 2
	input2 = torch.randn(2, 3)
	input2 = -input2.pow(2) # always negative
	traced_model2 = jit.trace(orig_model, (input2))

	print(orig_model(input1).size(), orig_model(input2).size())
	print(traced_model1(input1).size(), traced_model1(input2).size())
	print(traced_model2(input1).size(), traced_model2(input2).size())

	print(traced_model1.code)
	print('=======')
	print(traced_model2.code)
	print('=======')

	# scripting
	scripted_model = jit.script(orig_model)

	print(orig_model(input1).size(), orig_model(input2).size())
	print(scripted_model(input1).size(), scripted_model(input2).size())
	print(scripted_model.code)
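MyModel2 is not defined in this snippet. A minimal sketch of what it could look like (an assumption, for illustration only): a module with input-dependent control flow, which is exactly the case where a trace bakes in the branch taken for its example input while scripting keeps both branches.

import torch
import torch.nn as nn

class MyModel2(nn.Module):
    def forward(self, x):
        # data-dependent branch: a trace records only the branch taken for
        # the example input; scripting preserves the whole if/else
        if (x > 0).all():
            return x.view(-1)  # all-positive input: flatten to shape (6,)
        return x               # otherwise keep shape (2, 3)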
Example #4
    def torchscript_export(self, model, export_path=None, **kwargs):
        # unpack export kwargs
        quantize = kwargs.get("quantize", False)
        accelerate = kwargs.get("accelerate", [])
        padding_control = kwargs.get("padding_control")
        inference_interface = kwargs.get("inference_interface")

        cuda.CUDA_ENABLED = False
        model.cpu()
        optimizer = self.trainer.optimizer
        optimizer.pre_export(model)

        model.eval()
        model.prepare_for_onnx_export_()

        unused_raw_batch, batch = next(
            iter(self.data.batches(Stage.TRAIN, load_early=True))
        )
        inputs = model.onnx_trace_input(batch)
        model(*inputs)
        if quantize:
            model.quantize()
        if "half" in accelerate:
            model.half()
        if self.trace_both_encoders:
            trace = jit.trace(model, inputs)
        else:
            trace = jit.trace(model.encoder1, (inputs[0],))
        if hasattr(model, "torchscriptify"):
            trace = model.torchscriptify(
                self.data.tensorizers, trace, self.trace_both_encoders
            )
        if padding_control is not None:
            if hasattr(trace, "set_padding_control"):
                trace.set_padding_control(padding_control)
            else:
                print(
                    "Padding_control not supported by model. Ignoring padding_control"
                )
        if inference_interface is not None:
            if hasattr(trace, "inference_interface"):
                trace.inference_interface(inference_interface)
            else:
                print(
                    "inference_interface not supported by model. Ignoring inference_interface"
                )
        trace.apply(lambda s: s._pack() if s._c._has_method("_pack") else None)
        if "nnpi" in accelerate:
            trace._c = torch._C._freeze_module(
                trace._c,
                preservedAttrs=["make_prediction", "make_batch", "set_padding_control"],
            )
        if export_path is not None:
            print(f"Saving torchscript model to: {export_path}")
            with PathManager.open(export_path, "wb") as f:
                torch.jit.save(trace, f)
        return trace
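A hypothetical call into the exporter above (the task instance and the concrete kwarg values are assumptions; the keyword names match what the method reads from kwargs):

ts_model = task.torchscript_export(
    model,
    export_path="/tmp/model.pt",
    quantize=True,
    accelerate=["half"],
    padding_control=[0, 32, 64],  # ignored if the traced model has no set_padding_control
)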
Example #5
 def trace(self):
     self.init_block = trace(self.init_block,
                             torch.rand(1, self.input_size))
     for level in range(self.start_level, self.end_level):
         self.level_blocks[level] = trace(
             self.level_blocks[level],
             torch.rand(1, self.channels[level], 2**level, 2**level))
     for level in range(self.start_level, self.end_level + 1):
         self.toRGBs[level] = trace(
             self.toRGBs[level],
             torch.rand(1, self.channels[level], 2**level, 2**level))
Example #6
 def __init__(self, model=None, data=None, graph=None):
     """
     We build the network architecture graph according to the graph
     in the scriptmodule. However, the original graph from jit.trace
     has lots of detailed information which makes the graph complicated
     and hard to understand. So we also store a copy of the network
     architecture in self.forward_edge. We will simplify the network
     architecture (such as unpack_tuple, etc.) stored in self.forward_edge
     to make the graph clearer.
     Parameters
     ----------
     model : torch.nn.Module
         The model to build the network architecture.
     data : torch.Tensor
         The sample input data for the model.
     graph : torch._C.Graph
         Traced graph from jit.trace, if this option is set,
         we do not need to trace the model again.
     """
     self.model = model
     self.data = data
     if graph is not None:
         self.graph = graph
     elif (model is not None) and (data is not None):
         with torch.onnx.set_training(model, False):
             self.traced_model = jit.trace(model, data)
             self.graph = self.traced_model.graph
             torch._C._jit_pass_inline(self.graph)
     else:
         raise Exception('Input parameters invalid!')
     self.forward_edge = {}
     self.c2py = {}
     self.visited = set()
     self.build_graph()
     self.unpack_tuple()
Example #7
def export_model(model, path=None, input_shape=(1, 3, 64, 64)):
    """
    Exports the model. If the model is a `ScriptModule`, it is saved as is. If not,
    it is traced (with the given input_shape) and the resulting ScriptModule is saved
    (this requires the `input_shape`, which defaults to the competition default).

    Parameters
    ----------
    model : torch.nn.Module or torch.jit.ScriptModule
        Pytorch Module or a ScriptModule.
    path : str
        Path to the file where the model is saved. Defaults to the value set by the
        `get_model_path` function above.
    input_shape : tuple or list
        Shape of the input to trace the module with. This is only required if model is not a
        torch.jit.ScriptModule.

    Returns
    -------
    str
        Path to where the model is saved.
    """
    path = get_model_path() if path is None else path
    model = deepcopy(model).cpu().eval()
    if not isinstance(model, torch.jit.ScriptModule):
        assert input_shape is not None, "`input_shape` must be provided since model is not a " \
                                        "`ScriptModule`."
        traced_model = trace(model, torch.zeros(*input_shape))
    else:
        traced_model = model
    torch.jit.save(traced_model, path)
    return path
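A hypothetical usage of export_model above (the tiny Sequential stands in for a real model; get_model_path and the imports used by export_model are assumed to come from the surrounding module):

import torch.nn as nn

toy = nn.Sequential(nn.Conv2d(3, 8, 3), nn.ReLU())
saved_path = export_model(toy, path="toy_model.pt")  # traced with the default (1, 3, 64, 64) input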
Example #8
    def torchscript_export(self, model, export_path=None, quantize=False):
        # Make sure to put the model on CPU and disable CUDA before exporting to
        # ONNX to disable any data_parallel pieces
        cuda.CUDA_ENABLED = False
        model.cpu()
        optimizer = self.trainer.optimizer
        optimizer.pre_export(model)

        # Trace needs eval mode, to disable dropout etc
        model.eval()
        model.prepare_for_onnx_export_()

        unused_raw_batch, batch = next(iter(self.data.batches(Stage.TRAIN)))
        inputs = model.arrange_model_inputs(batch)
        # call model forward to set correct device types
        model(*inputs)
        if quantize:
            model.quantize()
        trace = jit.trace(model, inputs)
        if hasattr(model, "torchscriptify"):
            trace = model.torchscriptify(self.data.tensorizers, trace)
        trace.apply(lambda s: s._pack() if s._c._has_method("_pack") else None)
        if export_path is not None:
            print(f"Saving torchscript model to: {export_path}")
            trace.save(export_path)
        return trace
Example #9
    def parse(model, args, omit_useless_nodes=True):
        with onnx.set_training(model, False):
            trace = jit.trace(model, args)
            graph = trace.graph

        n_inputs = args.shape[0]  # not sure...

        model_graph = ModelGraph()
        for i, node in enumerate(graph.inputs()):
            if omit_useless_nodes:
                # number of uses of the node (= number of outputs / fanout)
                if len(node.uses()) == 0:
                    continue

            if i < n_inputs:
                model_graph.append(NodePyIO(node, 'input'))
            else:
                model_graph.append(NodePyIO(node))  # parameter

        for node in graph.nodes():
            model_graph.append(NodePyOP(node))

        for node in graph.outputs():  # must place last.
            NodePyIO(node, 'output')
        model_graph.find_common_root()
        model_graph.populate_namespace_from_OP_to_IO()

        model_graph.parse_scopes()

        openvino_graph = model_graph.create_openvino_graph(model)

        return model_graph, openvino_graph
Example #10
    def test_trace_bilstm_differ_batch_size(self):
        # BiLSTM torch tracing was using torch.new_zeros for the default input hidden
        # states, which doesn't trace properly. torch.jit traces torch.new_zeros as a
        # constant and therefore locks the traced model into a static batch size.
        # torch.LSTM now uses zeros; adding a test case here to verify the behavior.
        # see https://github.com/pytorch/pytorch/issues/16664

        class Model(nn.Module):
            def __init__(self):
                super().__init__()
                self.embedding = nn.Embedding(VOCAB_SIZE, EMBEDDING_SIZE)
                self.bilstm = bilstm.BiLSTM(bilstm.BiLSTM.Config(), EMBEDDING_SIZE)

            def forward(self, tokens, seq_lengths):
                embeddings = self.embedding(tokens)
                return self.bilstm(embeddings, seq_lengths)

        model = Model()
        trace_inputs = (
            torch.LongTensor([[2, 3, 4], [2, 2, 1]]),
            torch.LongTensor([3, 2]),
        )

        trace = jit.trace(model, trace_inputs)

        test_inputs = (torch.LongTensor([[4, 5, 6]]), torch.LongTensor([3]))

        # we are just testing that this doesn't throw an exception
        trace(*test_inputs)
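A minimal sketch, not part of the test above, of the pitfall described in the comment (the module and sizes are assumptions): a default hidden state built with new_zeros inside forward. On the PyTorch versions referenced in the linked issue, the example batch size was recorded as a constant, so the trace only accepted that batch size; the BiLSTM fix avoids this, which is what the test verifies.

import torch
from torch import jit, nn

class StaticBatchPitfall(nn.Module):
    def __init__(self):
        super().__init__()
        self.lstm = nn.LSTM(4, 8, batch_first=True)

    def forward(self, x):
        # hidden state built from the input; on the affected versions the batch
        # dimension of the example input was baked into the trace as a constant
        h0 = x.new_zeros(1, x.size(0), 8)
        c0 = x.new_zeros(1, x.size(0), 8)
        out, _ = self.lstm(x, (h0, c0))
        return out

traced = jit.trace(StaticBatchPitfall(), torch.randn(2, 5, 4))
traced(torch.randn(3, 5, 4))  # could fail on the affected versions (batch size 2 assumed)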
Example #11
    def __call__(self, *args, **kwargs):
        method_model = _ForwardOverrideModel(self.model, self.method_name)

        try:
            assert len(args) == 0, "only KV support implemented"

            fn = getattr(self.model, self.method_name)
            method_argnames = _get_input_argnames(fn=fn, exclude=["self"])
            method_input = tuple(kwargs[name] for name in method_argnames)

            self.tracing_result = trace(method_model, method_input)
        except Exception:
            # for backward compatibility
            self.tracing_result = trace(method_model, *args, **kwargs)
        output = self.model.forward(*args, **kwargs)

        return output
Example #12
def _maybe_optimize(model):
    try:
        from torch.jit import trace
        model = trace(model, example_inputs=torch.rand(1, 3, 224, 224))
        logger.info('successfully optimized PyTorch model using JIT tracing')
    except ImportError:
        logger.warning('unable to leverage torch.jit.trace optimizations')
        pass
    return model
Example #13
    def save(self, states: Tensor, masks: Tensor):
        if self.save_path is None:
            return

        save_dir = os.path.join(self.save_path, str(int(time())))
        os.makedirs(save_dir, exist_ok=False)

        model = jit.trace(self.network, (states, masks))
        jit.save(model, os.path.join(save_dir, "network.pt"))
Example #14
def main(model_path='model.pt', name='densenet201', out_name='model.trcd'):
    model = torch.load(model_path)
    model, = list(model.children())
    state = model.state_dict()

    base = get_baseline(name=name)
    base.load_state_dict(state)
    base.eval()

    model = jit.trace(base, example_inputs=(torch.rand(4, 3, 256, 256), ))
    jit.save(model, out_name)
Example #15
    def torchscript_export(self, model, export_path):
        # Make sure to put the model on CPU and disable CUDA before exporting to
        # ONNX to disable any data_parallel pieces
        cuda.CUDA_ENABLED = False
        model.cpu()
        precision.deactivate(model)
        # Trace needs eval mode, to disable dropout etc
        model.eval()

        batch = next(iter(self.data.batches(Stage.TEST)))
        inputs = model.arrange_model_inputs(batch)
        trace = jit.trace(model, inputs)
        trace.save(export_path)
Example #16
def _jit_run(
    f: FunctionType,
    compilation_cache: dict,
    jit_kw_args: dict,
    *args: Union[Numeric, TorchRandomState],
):
    if "torch" not in compilation_cache:
        # Run once to populate the control flow cache.
        f(*args)
        # Compile.
        compilation_cache["torch"] = trace(f, args, **jit_kw_args)

    return compilation_cache["torch"](*args)
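A hypothetical usage of _jit_run above (the cache dict and the toy function f are assumptions): the first call runs f eagerly to populate any control-flow caches and then traces it; later calls reuse the cached ScriptFunction.

import torch

cache = {}

def f(x, y):
    return x * y + y

out1 = _jit_run(f, cache, {}, torch.randn(3), torch.randn(3))  # traces f and caches it
out2 = _jit_run(f, cache, {}, torch.randn(3), torch.randn(3))  # reuses cache["torch"]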
Example #17
    def translate(self):
        translation_plan = self.plan.copy()
        translation_plan.forward = None
        # Make sure we're trying to trace Role with pytorch commands
        translation_plan.base_framework = TranslationTarget.PYTORCH.value

        args = translation_plan.create_dummy_args()

        # jit.trace clones input args and can change their type, so we have to skip the type check
        # TODO see if type check can be made less strict,
        #  e.g. tensor/custom tensor/nn.Parameter could be considered same type
        translation_plan.validate_input_types = False

        # To avoid storing Plan state tensors in torchscript, they will be sent as parameters.
        # We trace a wrapper func, which accepts the state parameters as its last arg
        # and sets them on the Plan before executing the Plan.
        def wrap_stateful_plan(*args):
            role = translation_plan.role
            state = args[-1]
            if 0 < len(role.state.state_placeholders) == len(state) and isinstance(
                state, (list, tuple)
            ):
                state_placeholders = tuple(
                    role.placeholders[ph.id.value] for ph in role.state.state_placeholders
                )
                PlaceHolder.instantiate_placeholders(role.state.state_placeholders, state)
                PlaceHolder.instantiate_placeholders(state_placeholders, state)

            return translation_plan(*args[:-1])

        plan_params = translation_plan.parameters()
        if len(plan_params) > 0:
            torchscript_plan = jit.trace(wrap_stateful_plan, (*args, plan_params))
        else:
            torchscript_plan = jit.trace(translation_plan, args)

        self.plan.torchscript = torchscript_plan
        return self.plan
Example #18
    def __init__(self, net: Policy):
        """
        Constructor

        :param net: non-recurrent network to wrap, which must not be a script module
        """
        super().__init__()

        # Setup attributes
        self.input_size = net.env_spec.obs_space.flat_dim
        self.output_size = net.env_spec.act_space.flat_dim

        self.net = trace(
            net, (to.from_numpy(net.env_spec.obs_space.sample_uniform()), ))
Example #19
def trace_model(model, sample_input):
    """Traces the model
    
    Args:
        model (nn.Module): The model to be traced
        sample_input (torch.Tensor): The sample input to the model
    
    Returns:
        The traced model
    """

    if torch.__version__[0] != '1':
        return model
    else:
        return trace(model, sample_input)
Example #20
 def __init__(self):
     super(DenseNetJIT, self).__init__()
     self.add_module(
         'densenet',
         jit.trace(
             DenseNet(growth_rate=12,
                      DenseBlock_layer_num=(40, 40, 40),
                      bottle_neck_size=4,
                      dropout_rate=0.2,
                      compression_rate=0.5,
                      num_init_features=16,
                      num_input_features=3,
                      num_classes=10,
                      bias=False,
                      memory_efficient=False), torch.rand(1, 3, 32, 32)))
Example #21
 def trace(self, device):
     for level in range(self.start_level, self.end_level + 1):
         self.fromRGBs[level] = trace(
             self.fromRGBs[level],
             torch.rand(1,
                        self.input_channels,
                        2**level,
                        2**level,
                        device=device))
     for level in range(self.start_level, self.end_level):
         self.level_blocks[level] = trace(
             self.level_blocks[level],
             torch.rand(1,
                        self.channels[level + 1],
                        2**(level + 1),
                        2**(level + 1),
                        device=device))
     self.final_block = trace(
         self.final_block,
         torch.rand(1,
                    self.channels[self.start_level],
                    2**self.start_level,
                    2**self.start_level,
                    device=device))
Example #22
def script_model(model: nn.Module, sizes: list) -> ScriptModule:
    """
    Generates converts model to the cript model
    Args:
        model: model to convert
        sizes: sizes of input

    Returns:
        graph_model: converted model
    """
    xs = tuple(torch.randn(1, 3, s, s, requires_grad=False) for s in sizes)
    graph_model = trace(model.eval(), xs)
    graph_model.eval()

    return graph_model
Example #23
    def __init__(self, module: Policy):
        """
        Constructor

        :param module: non-recurrent network to wrap, which must not be a script module
        """
        super().__init__()

        # Setup attributes
        self.input_size = module.env_spec.obs_space.flat_dim
        self.output_size = module.env_spec.act_space.flat_dim

        samples = to.from_numpy(module.env_spec.obs_space.sample_uniform()
                                )  # .to(dtype=to.get_default_dtype())
        self.module = trace(module, (samples, ))
Example #24
    def translate(self):
        plan = self.plan

        args_shape = plan.get_args_shape()
        args = PlaceHolder.create_placeholders(args_shape)

        # Temporarily remove reference to original function
        tmp_forward = plan.forward
        plan.forward = None

        # To avoid storing Plan state tensors inside the torchscript,
        # we trace a wrapper func, which accepts the state parameters as its last arg
        # and sets them on the Plan before executing the Plan.
        def wrap_stateful_plan(*args):
            role = plan.role
            state = args[-1]
            if 0 < len(role.state.state_placeholders) == len(state) and isinstance(
                state, (list, tuple)
            ):
                state_placeholders = tuple(
                    role.placeholders[ph.id.value] for ph in role.state.state_placeholders
                )
                PlaceHolder.instantiate_placeholders(role.state.state_placeholders, state)
                PlaceHolder.instantiate_placeholders(state_placeholders, state)

            return plan(*args[:-1])

        plan_params = plan.parameters()
        if len(plan_params) > 0:
            torchscript_plan = jit.trace(wrap_stateful_plan, (*args, plan_params))
        else:
            torchscript_plan = jit.trace(plan, args)
        plan.torchscript = torchscript_plan
        plan.forward = tmp_forward

        return plan
Example #25
    def __init__(self, model, dummy_input):
        with torch.onnx.set_training(model, False):
            trace, _ = jit.trace(model, dummy_input)

            # Let ONNX do the heavy lifting: fusing the convolution nodes; fusing the nodes
            # composing a GEMM operation; etc.
            torch.onnx._optimize_trace(trace, False)
            graph = trace.graph()
            self.ops = []
            self.params = {}
            self.edges = []
            self.temp = {}

            in_out = list(graph.inputs()) + list(graph.outputs())
            for param in in_out:
                self.__add_param(param)

            for node in graph.nodes():
                op = {}
                op['name'] = node.scopeName()
                op['orig-name'] = node.scopeName()
                op['type'] = node.kind()
                op['inputs'] = []
                op['outputs'] = []
                op['params'] = []

                # in-place operators create very confusing graphs (Resnet, for example),
                # so we "unroll" them
                same = [
                    layer for layer in self.ops
                    if layer['orig-name'] == op['orig-name']
                ]
                if len(same) > 0:
                    op['name'] += "." + str(len(same))
                self.ops.append(op)

                for input_ in node.inputs():
                    self.__add_input(op, input_)
                    self.edges.append((input_.uniqueName(), op['name']))

                for output in node.outputs():
                    self.__add_output(op, output)
                    self.edges.append((op['name'], output.uniqueName()))

                op['attrs'] = {
                    attr_name: node[attr_name]
                    for attr_name in node.attributeNames()
                }
Example #26
    def __init__(self, window_size=3):
        super(SSIM, self).__init__()

        gaussian_img_kernel = {
            'weight': create_gaussian_window(window_size, 3).float(),
            'bias': torch.zeros(3)
        }
        gaussian_blur = nn.Conv2d(3,
                                  3,
                                  window_size,
                                  padding=window_size // 2,
                                  groups=3).to(device)
        gaussian_blur.load_state_dict(gaussian_img_kernel)
        self.gaussian_blur = trace(
            gaussian_blur,
            torch.rand(3, 3, 16, 16, dtype=torch.float32, device=device))
Example #27
    def torchscript_export(
        self,
        model,
        export_path=None,
        quantize=False,
        sort_input=False,
        sort_key=1,
        inference_interface=None,
        accelerate=None,
    ):
        # Make sure to put the model on CPU and disable CUDA before exporting to
        # ONNX to disable any data_parallel pieces
        cuda.CUDA_ENABLED = False
        model.cpu()
        optimizer = self.trainer.optimizer
        optimizer.pre_export(model)

        # Trace needs eval mode, to disable dropout etc
        model.eval()
        model.prepare_for_onnx_export_()

        unused_raw_batch, batch = next(
            iter(self.data.batches(Stage.TRAIN, load_early=True)))
        inputs = model.onnx_trace_input(batch)
        # call model forward to set correct device types
        if sort_input:
            _, sorted_indices = sort(inputs[sort_key], descending=True)
            inputs = [i.index_select(0, sorted_indices) for i in inputs]
        model(*inputs)
        if quantize:
            model.quantize()
        if accelerate is not None:
            if "half" in accelerate:
                model.half()
        if inference_interface is not None:
            model.inference_interface(inference_interface)
        trace = jit.trace(model, inputs)
        if accelerate is not None:
            if "nnpi" in accelerate:
                trace._c = torch._C._freeze_module(trace._c)
        if hasattr(model, "torchscriptify"):
            trace = model.torchscriptify(self.data.tensorizers, trace)
        trace.apply(lambda s: s._pack() if s._c._has_method("_pack") else None)
        if export_path is not None:
            print(f"Saving torchscript model to: {export_path}")
            trace.save(export_path)
        return trace
Example #28
    def torchscript_export(self, model, export_path):
        # Make sure to put the model on CPU and disable CUDA before exporting to
        # ONNX to disable any data_parallel pieces
        cuda.CUDA_ENABLED = False
        model.cpu()
        precision.deactivate(model)
        # Trace needs eval mode, to disable dropout etc
        model.eval()
        model.prepare_for_onnx_export_()

        batch = next(iter(self.data.batches(Stage.TEST)))
        inputs = model.arrange_model_inputs(batch)
        trace = jit.trace(model, inputs)
        if hasattr(model, "torchscriptify"):
            trace = model.torchscriptify(self.data.tensorizers, trace)
        print(f"Saving torchscript model to: {export_path}")
        trace.save(export_path)
Example #29
def trace_model(
    model: Model,
    batch: Union[Tuple[torch.Tensor], torch.Tensor],
    method_name: str = "forward",
) -> jit.ScriptModule:
    """Traces model using runner and batch.

    Args:
        model: Model to trace
        batch: Batch to trace the model
        method_name: Model's method name that will be
            used as entrypoint during tracing

    Example:
        .. code-block:: python

           import torch

           from catalyst.utils import trace_model

           class LinModel(torch.nn.Module):
               def __init__(self):
                   super().__init__()
                   self.lin1 = torch.nn.Linear(10, 10)
                   self.lin2 = torch.nn.Linear(2, 10)

               def forward(self, inp_1, inp_2):
                   return self.lin1(inp_1), self.lin2(inp_2)

               def first_only(self, inp_1):
                   return self.lin1(inp_1)

           lin_model = LinModel()
           traced_model = trace_model(
               lin_model, batch=torch.randn(1, 10), method_name="first_only"
           )

    Returns:
        jit.ScriptModule: Traced model
    """
    nn_model = get_nn_from_ddp_module(model)
    wrapped_model = ModelForwardWrapper(model=nn_model,
                                        method_name=method_name)
    traced = jit.trace(wrapped_model, example_inputs=batch)
    return traced
Example #30
    def translate(self):
        plan = self.plan

        args_shape = plan.get_args_shape()
        args = PlaceHolder.create_placeholders(args_shape)

        # Temporarily remove reference to original function
        tmp_forward = plan.forward
        plan.forward = None

        # To avoid storing Plan state tensors in torchscript, they will be sent as parameters
        plan_params = plan.parameters()
        if len(plan_params) > 0:
            args = (*args, plan_params)
        torchscript_plan = jit.trace(plan, args)
        plan.torchscript = torchscript_plan
        plan.forward = tmp_forward

        return plan