Example #1
from torch.quantization import get_default_qconfig
from torch.quantization.quantize_fx import prepare_fx, convert_fx

def quantize(model):
    qconfig = get_default_qconfig("fbgemm")
    qconfig_dict = {"": qconfig}
    return convert_fx(prepare_fx(model, qconfig_dict))
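Note that Example #1 never runs calibration data through the prepared model, so the observers inserted by prepare_fx keep their initial state. A minimal sketch of the full post-training flow it compresses, assuming a hypothetical eval-mode model and a calibration data_loader:

# Minimal PTQ sketch; `model` (in eval mode) and `data_loader` are placeholders.
import torch
from torch.quantization import get_default_qconfig
from torch.quantization.quantize_fx import prepare_fx, convert_fx

def quantize_with_calibration(model, data_loader, num_batches=32):
    qconfig_dict = {"": get_default_qconfig("fbgemm")}
    prepared = prepare_fx(model, qconfig_dict)    # insert observers
    with torch.no_grad():
        for i, (x, _) in enumerate(data_loader):  # let observers record ranges
            prepared(x)
            if i + 1 >= num_batches:
                break
    return convert_fx(prepared)                   # swap float ops for int8 ops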
Example #2
    def test_input_weight_equalization_graphs(self):
        """ Tests that the modified model for equalization has the same graph
        structure as the model without equalization (before and after
        quantization).
        """

        linear_node_list = [
            ns.call_function(torch.mul),
            ns.call_function(torch.quantize_per_tensor),
            ns.call_module(nnq.Linear),
            ns.call_method('dequantize')
        ]

        linearAdd_node_list = [
            ns.call_function(torch.mul),
            ns.call_function(torch.quantize_per_tensor),
            ns.call_module(nnq.Linear),
            ns.call_method('dequantize'),
            ns.call_function(torch.add),
            ns.call_function(torch.mul),
            ns.call_function(torch.quantize_per_tensor),
            ns.call_module(nnq.Linear),
            ns.call_method('dequantize')
        ]

        linear2_node_list = [
            ns.call_function(torch.mul),
            ns.call_function(torch.quantize_per_tensor),
            ns.call_module(nnq.Linear),
            ns.call_module(nnq.Linear),
            ns.call_method('dequantize')
        ]

        functionalLinear_node_list = [
            ns.call_function(torch.mul),
            ns.call_function(torch.quantize_per_tensor),
            ns.call_function(torch.ops.quantized.linear),
            ns.call_method('dequantize')
        ]

        functionalLinearAdd_node_list = [
            ns.call_function(torch.mul),
            ns.call_function(torch.quantize_per_tensor),
            ns.call_function(torch.ops.quantized.linear),
            ns.call_method('dequantize'),
            ns.call_function(torch.add),
            ns.call_function(torch.mul),
            ns.call_function(torch.quantize_per_tensor),
            ns.call_function(torch.ops.quantized.linear),
            ns.call_method('dequantize')
        ]

        functionalLinear2_node_list = [
            ns.call_function(torch.mul),
            ns.call_function(torch.quantize_per_tensor),
            ns.call_function(torch.ops.quantized.linear),
            ns.call_function(torch.ops.quantized.linear),
            ns.call_method('dequantize')
        ]

        linearRelu_node_list = [
            ns.call_function(torch.mul),
            ns.call_function(torch.quantize_per_tensor),
            ns.call_module(nniq.LinearReLU),
            ns.call_method('dequantize')
        ]

        linearReluLinear_node_list = [
            ns.call_function(torch.mul),
            ns.call_function(torch.quantize_per_tensor),
            ns.call_module(nniq.LinearReLU),
            ns.call_module(nnq.Linear),
            ns.call_method('dequantize')
        ]

        functionalLinearRelu_node_list = [
            ns.call_function(torch.mul),
            ns.call_function(torch.quantize_per_tensor),
            ns.call_function(torch.ops.quantized.linear_relu),
            ns.call_method('dequantize')
        ]

        functionalLinearReluLinear_node_list = [
            ns.call_function(torch.mul),
            ns.call_function(torch.quantize_per_tensor),
            ns.call_function(torch.ops.quantized.linear_relu),
            ns.call_function(torch.ops.quantized.linear),
            ns.call_method('dequantize')
        ]

        conv_node_list = [
            ns.call_function(torch.mul),
            ns.call_function(torch.quantize_per_tensor),
            ns.call_module(nnq.Conv2d),
            ns.call_method('dequantize')
        ]

        conv2_node_list = [
            ns.call_function(torch.mul),
            ns.call_function(torch.quantize_per_tensor),
            ns.call_module(nnq.Conv2d),
            ns.call_module(nnq.Conv2d),
            ns.call_method('dequantize')
        ]

        functionalConv_node_list = [
            ns.call_function(torch.mul),
            ns.call_function(torch.quantize_per_tensor),
            ns.call_function(torch.ops.quantized.conv2d),
            ns.call_method('dequantize')
        ]

        functionalConv2_node_list = [
            ns.call_function(torch.mul),
            ns.call_function(torch.quantize_per_tensor),
            ns.call_function(torch.ops.quantized.conv2d),
            ns.call_function(torch.ops.quantized.conv2d),
            ns.call_method('dequantize')
        ]

        convRelu_node_list = [
            ns.call_function(torch.mul),
            ns.call_function(torch.quantize_per_tensor),
            ns.call_module(nniq.ConvReLU2d),
            ns.call_method('dequantize')
        ]

        convReluConv_node_list = [
            ns.call_function(torch.mul),
            ns.call_function(torch.quantize_per_tensor),
            ns.call_module(nniq.ConvReLU2d),
            ns.call_module(nnq.Conv2d),
            ns.call_method('dequantize')
        ]

        functionalConvRelu_node_list = [
            ns.call_function(torch.mul),
            ns.call_function(torch.quantize_per_tensor),
            ns.call_function(torch.ops.quantized.conv2d_relu),
            ns.call_method('dequantize')
        ]

        functionalConvReluConv_node_list = [
            ns.call_function(torch.mul),
            ns.call_function(torch.quantize_per_tensor),
            ns.call_function(torch.ops.quantized.conv2d_relu),
            ns.call_function(torch.ops.quantized.conv2d),
            ns.call_method('dequantize')
        ]

        tests = [
            (SingleLayerLinearModel, 2, linear_node_list),
            (LinearAddModel, 2, linearAdd_node_list),
            (TwoLayerLinearModel, 2, linear2_node_list),
            (SingleLayerFunctionalLinearModel, 2, functionalLinear_node_list),
            (FunctionalLinearAddModel, 2, functionalLinearAdd_node_list),
            (TwoLayerFunctionalLinearModel, 2, functionalLinear2_node_list),
            (LinearReluModel, 2, linearRelu_node_list),
            (LinearReluLinearModel, 2, linearReluLinear_node_list),
            (FunctionalLinearReluModel, 2, functionalLinearRelu_node_list),
            (FunctionalLinearReluLinearModel, 2,
             functionalLinearReluLinear_node_list),
            (ConvModel, 4, conv_node_list),
            (TwoLayerConvModel, 4, conv2_node_list),
            (SingleLayerFunctionalConvModel, 4, functionalConv_node_list),
            (TwoLayerFunctionalConvModel, 4, functionalConv2_node_list),
            (ConvReluModel, 4, convRelu_node_list),
            (ConvReluConvModel, 4, convReluConv_node_list),
            (FunctionalConvReluModel, 4, functionalConvRelu_node_list),
            (FunctionalConvReluConvModel, 4, functionalConvReluConv_node_list)
        ]

        for (M, ndim, node_list) in tests:
            m = M().eval()

            if ndim == 2:
                x = torch.rand((5, 5))
            elif ndim == 4:
                x = torch.rand((16, 3, 224, 224))

            prepared = prepare_fx(
                m,
                qconfig_dict,
                equalization_qconfig_dict=default_equalization_qconfig_dict)
            prepared(x)
            equalized_quantized_model = convert_fx(prepared)

            # Check the order of nodes in the graph
            self.checkGraphModuleNodes(equalized_quantized_model,
                                       expected_node_list=node_list)
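The test relies on module-level qconfig_dict and default_equalization_qconfig_dict definitions that the excerpt omits. A plausible reconstruction, hedged because the exact import path of default_equalization_qconfig has moved between PyTorch releases:

# Assumed module-level setup for the test above (not part of the excerpt);
# the import path of default_equalization_qconfig varies across versions.
import torch.nn as nn
import torch.nn.functional as F
from torch.quantization import get_default_qconfig
from torch.ao.quantization.fx._equalize import default_equalization_qconfig

qconfig_dict = {"": get_default_qconfig("fbgemm")}
default_equalization_qconfig_dict = {
    "object_type": [(nn.Linear, default_equalization_qconfig),
                    (F.linear, default_equalization_qconfig),
                    (nn.Conv2d, default_equalization_qconfig),
                    (F.conv2d, default_equalization_qconfig)]
}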
Example #3
def load(checkpoint_dir, model, **kwargs):
    """Execute the quantize process on the specified model.

    Args:
        checkpoint_dir (dir): The folder of checkpoint.
                              'best_configure.yaml' and 'best_model_weights.pt' are needed
                              in This directory. 'checkpoint' dir is under workspace folder
                              and workspace folder is define in configure yaml file.
        model (object): fp32 model need to do quantization.

    Returns:
        (object): quantized model
    """

    tune_cfg_file = os.path.join(
        os.path.abspath(os.path.expanduser(checkpoint_dir)),
        'best_configure.yaml')
    weights_file = os.path.join(
        os.path.abspath(os.path.expanduser(checkpoint_dir)),
        'best_model_weights.pt')
    assert os.path.exists(
        tune_cfg_file), "tuning configuration file %s does not exist" % tune_cfg_file
    assert os.path.exists(
        weights_file), "weights file %s does not exist" % weights_file

    with open(tune_cfg_file, 'r') as f:
        tune_cfg = yaml.safe_load(f)

    version = get_torch_version()
    if tune_cfg['approach'] != "post_training_dynamic_quant":
        if version < '1.7':
            q_mapping = tq.default_mappings.DEFAULT_MODULE_MAPPING
        elif version < '1.8':
            q_mapping = \
                tq.quantization_mappings.get_static_quant_module_mappings()
        else:
            q_mapping = \
                tq.quantization_mappings.get_default_static_quant_module_mappings()
    else:
        if version < '1.7':
            q_mapping = \
                tq.default_mappings.DEFAULT_DYNAMIC_MODULE_MAPPING
        elif version < '1.8':
            q_mapping = \
                tq.quantization_mappings.get_dynamic_quant_module_mappings()
        else:
            q_mapping = \
                tq.quantization_mappings.get_default_dynamic_quant_module_mappings()

    if version < '1.7':
        white_list = \
            tq.default_mappings.DEFAULT_DYNAMIC_MODULE_MAPPING \
            if tune_cfg['approach'] == 'post_training_dynamic_quant' else \
            tq.default_mappings.DEFAULT_QCONFIG_PROPAGATE_WHITE_LIST - \
            {torch.nn.LayerNorm, torch.nn.InstanceNorm3d, torch.nn.Embedding}
    elif version < '1.8':
        white_list = \
            tq.quantization_mappings.get_dynamic_quant_module_mappings() \
            if tune_cfg['approach'] == 'post_training_dynamic_quant' else \
            tq.quantization_mappings.get_qconfig_propagation_list() - \
            {torch.nn.LayerNorm, torch.nn.InstanceNorm3d, torch.nn.Embedding}
    else:
        white_list = \
            tq.quantization_mappings.get_default_dynamic_quant_module_mappings() \
            if tune_cfg['approach'] == 'post_training_dynamic_quant' else \
            tq.quantization_mappings.get_default_qconfig_propagation_list() - \
            {torch.nn.LayerNorm, torch.nn.InstanceNorm3d, torch.nn.Embedding}

    if tune_cfg['approach'] == "post_training_dynamic_quant":
        op_cfgs = _cfg_to_qconfig(tune_cfg, tune_cfg['approach'])
    else:
        op_cfgs = _cfg_to_qconfig(tune_cfg)

    if tune_cfg['framework'] == "pytorch_fx":  # pragma: no cover
        # For torch.fx approach
        assert version >= '1.8', \
            "Please use PyTorch 1.8 or a higher version with the pytorch_fx backend"
        from torch.quantization.quantize_fx import prepare_fx, convert_fx, prepare_qat_fx
        q_model = copy.deepcopy(model.eval())
        fx_op_cfgs = _cfgs_to_fx_cfgs(op_cfgs, tune_cfg['approach'])
        if tune_cfg['approach'] == "quant_aware_training":
            q_model.train()
            q_model = prepare_qat_fx(
                q_model,
                fx_op_cfgs,
                prepare_custom_config_dict=kwargs if kwargs != {} else None)
        else:
            q_model = prepare_fx(
                q_model,
                fx_op_cfgs,
                prepare_custom_config_dict=kwargs if kwargs != {} else None)
        q_model = convert_fx(q_model)
        weights = torch.load(weights_file)
        q_model.load_state_dict(weights)
        return q_model

    q_model = copy.deepcopy(model.eval())
    _propagate_qconfig(q_model,
                       op_cfgs,
                       white_list=white_list,
                       approach=tune_cfg['approach'])
    # sanity check against common API misuse
    if not any(hasattr(m, 'qconfig') and m.qconfig for m in q_model.modules()):
        logger.warning(
            "None of the submodules got a qconfig applied. Make sure you "
            "passed the correct configuration through `qconfig_dict` or "
            "by assigning the `.qconfig` attribute directly to submodules.")
    if tune_cfg['approach'] != "post_training_dynamic_quant":
        add_observer_(q_model)
    q_model = convert(q_model, mapping=q_mapping, inplace=True)
    weights = torch.load(weights_file)
    q_model.load_state_dict(weights)
    return q_model
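A hypothetical call site for load(); the checkpoint path and model factory below are placeholders, not part of the library's API:

# Hypothetical usage: rebuild the tuned int8 model from a workspace checkpoint.
fp32_model = build_model()   # placeholder fp32 model factory
int8_model = load('./workspace/checkpoint', fp32_model)
int8_model.eval()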
Example #4
def prep_qat_eval(self):
    self.model = quantize_fx.convert_fx(self.model)
    if self.jit:
        self.model = torch.jit.script(self.model)
    self.model.eval()
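convert_fx here presupposes that self.model was earlier wrapped by prepare_qat_fx and fine-tuned with fake quantization. A sketch of that missing preparation step; the attribute names mirror the snippet and the 'fbgemm' qconfig choice is an assumption:

# Assumed QAT-preparation counterpart to prep_qat_eval above.
import torch
from torch.quantization import quantize_fx

def prep_qat_train(self):
    qconfig_dict = {"": torch.quantization.get_default_qat_qconfig("fbgemm")}
    self.model.train()
    self.model = quantize_fx.prepare_qat_fx(self.model, qconfig_dict)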
Example #5
    def _test_auto_tracing(
        self,
        m,
        qconfig,
        example_args,
        fuse_modules=True,
        do_fx_comparison=True,
        do_torchscript_checks=True,
    ):
        m_copy = copy.deepcopy(m)

        qconfig_dict = {'': qconfig}

        mp = _quantize_dbr.prepare(m,
                                   qconfig_dict,
                                   example_args,
                                   fuse_modules=fuse_modules)
        out_p = mp(*example_args)
        mq = _quantize_dbr.convert(mp)
        # verify the converted model runs
        out_q = mq(*example_args)

        # compare it against FX
        if do_fx_comparison:
            m_copy_p = prepare_fx(m_copy, {'': qconfig})
            out_m_copy_p = m_copy_p(*example_args)
            m_copy_q = convert_fx(m_copy_p)
            out_q_fx = m_copy_q(*example_args)
            self.assertTrue(_allclose(out_p, out_m_copy_p))
            self.assertTrue(_allclose(out_q, out_q_fx))

        if do_torchscript_checks:
            # verify torch.jit.trace works
            mq_jit_traced = torch.jit.trace(mq,
                                            example_args,
                                            check_trace=False)
            traced_out = mq_jit_traced(*example_args)
            self.assertTrue(_allclose(traced_out, out_q))

            # verify torch.jit.script works
            rewritten = mq.rewrite_for_scripting()
            rewritten_out = rewritten(*example_args)
            self.assertTrue(_allclose(rewritten_out, out_q))

            scripted_rewritten = torch.jit.script(rewritten)
            scripted_rewritten_out = scripted_rewritten(*example_args)
            self.assertTrue(_allclose(scripted_rewritten_out, out_q))

            traced_rewritten = torch.jit.trace(rewritten,
                                               example_args,
                                               check_trace=False)
            traced_rewritten_out = traced_rewritten(*example_args)
            self.assertTrue(_allclose(traced_rewritten_out, out_q))
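For context, a hypothetical invocation of this helper from a test method in the same class; the qconfig is whatever the surrounding suite uses, with torch.quantization.default_qconfig standing in here:

# Hypothetical call of the helper above with a toy module.
m = torch.nn.Sequential(torch.nn.Conv2d(1, 1, 1)).eval()
self._test_auto_tracing(m, torch.quantization.default_qconfig,
                        (torch.randn(1, 1, 2, 2),))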
Example #6
    def test_selective_equalization(self):
        """ Tests that we are able to run numeric suite on the equalized model
        and construct a valid equalization_qconfig_dict equalizing only the top
        4 layers with the highest quantization errors.
        """

        torch.manual_seed(1)

        class M(nn.Module):
            def __init__(self):
                super().__init__()
                self.bot = torch.nn.Sequential(torch.nn.Linear(5, 5))
                self.top = torch.nn.Sequential(torch.nn.Linear(5, 5))

            def forward(self, x):
                x = self.bot(x)
                x = torch.add(x, 5)
                x = self.top(x)
                return x

        float_model = M().eval()
        # Hard coded so that the top layer has a higher quantization error
        x = torch.tensor([[0.0642, 0.7824, 0.4255, 0.7106, 0.5957],
                          [0.8373, 0.8851, 0.8229, 0.0212, 0.8987],
                          [0.9077, 0.7538, 0.4530, 0.5772, 0.1376],
                          [0.0690, 0.9002, 0.7998, 0.2768, 0.8985],
                          [0.0282, 0.5068, 0.6725, 0.1829, 0.5480]])

        # Quantize the float model
        prepared_model = prepare_fx(copy.deepcopy(float_model),
                                    specific_qconfig_dict)
        prepared_model(x)
        quantized_model = convert_fx(copy.deepcopy(prepared_model))

        # Get the SQNR between the float and quantized model
        layer_to_sqnr_dict = get_layer_sqnr_dict(copy.deepcopy(prepared_model),
                                                 quantized_model, x)

        # Construct the equalization_qconfig_dict equalizing layers with the highest
        # quantization errors
        selective_equalization_qconfig_dict = get_equalization_qconfig_dict(
            layer_to_sqnr_dict, 1)

        # Create the selectively equalized model
        prepared_model = prepare_fx(
            copy.deepcopy(float_model),
            specific_qconfig_dict,
            equalization_qconfig_dict=selective_equalization_qconfig_dict,
        )
        prepared_model(x)
        equalized_model = convert_fx(prepared_model)

        node_list = [
            ns.call_function(torch.quantize_per_tensor),
            ns.call_module(nnq.Linear),
            ns.call_method('dequantize'),
            ns.call_function(torch.add),
            ns.call_function(torch.mul),
            ns.call_function(torch.quantize_per_tensor),
            ns.call_module(nnq.Linear),
            ns.call_method('dequantize')
        ]

        # Check the order of nodes in the graph
        self.checkGraphModuleNodes(equalized_model,
                                   expected_node_list=node_list)
float_model.eval()

# deepcopy the model since we need to keep the original model around
model_to_quantize = copy.deepcopy(float_model)

model_to_quantize.eval()

qconfig = get_default_qconfig("fbgemm")
qconfig_dict = {"": qconfig}

prepared_model = prepare_fx(model_to_quantize, qconfig_dict)
print(prepared_model.graph)

calibrate(prepared_model, data_loader_train)

quantized_model = convert_fx(prepared_model)
print(quantized_model)

script_module = torch.jit.trace(quantized_model, torch.randn(1, 3, 224, 224)).eval()
print(script_module)

torch._C._jit_pass_inline(script_module.graph)

print("Size of model before quantization")
print_size_of_model(float_model)
print("Size of model after quantization")
print_size_of_model(script_module)
top1, top5 = evaluate(script_module, criterion, data_loader_test)
print("[before serilaization] Evaluation accuracy on test dataset: %2.2f, %2.2f"%(top1.avg, top5.avg))

top1, top5 = evaluate(float_model, criterion, data_loader_test)
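The script assumes a calibrate helper that is not shown in the excerpt. A minimal sketch consistent with how it is called above:

# Assumed definition of the calibrate() helper used above: feed a few batches
# through the observer-instrumented model so activation ranges get recorded.
def calibrate(model, data_loader, num_batches=32):
    model.eval()
    with torch.no_grad():
        for i, (image, _) in enumerate(data_loader):
            model(image)
            if i + 1 >= num_batches:
                break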
Example #8
class Net(torch.nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        repo = 'alantess/vigilant-driving:main/1.0.75'
        self.model = torch.hub.load(repo, 'segnet', pretrained=True)

    def forward(self, x):
        x = self.model(x).squeeze(0).argmax(0)
        return x.mul(100).clamp(0, 255)


model_fp = Net()
model_fp.eval()

model_to_quantize = copy.deepcopy(model_fp)
model_to_quantize.eval()
qconfig_dict = {"": torch.quantization.get_default_qconfig('qnnpack')}
# prepare
model_prepared = quantize_fx.prepare_fx(model_to_quantize, qconfig_dict)
# calibrate (not shown)
# quantize
model_quantized = quantize_fx.convert_fx(model_prepared)

# Fusion
model_to_quantize = copy.deepcopy(model_fp)
model_fused = quantize_fx.fuse_fx(model_to_quantize)

# Save model
scripted_model = torch.jit.script(model_fused)
scripted_optimized_model = optimize_for_mobile(scripted_model)
torch.jit.save(scripted_optimized_model, "models/segnet_fx_mobile.pt")
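The snippet above leans on several imports it never states; this is the set it appears to assume:

# Imports the snippet above appears to assume (not shown in the original).
import copy
import torch
from torch.quantization import quantize_fx
from torch.utils.mobile_optimizer import optimize_for_mobile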
Example #9
def main(args):
    # data
    train_transform = tv.transforms.Compose([])
    if args.data_augmentation:
        train_transform.transforms.append(
            tv.transforms.RandomCrop(32, padding=4))
        train_transform.transforms.append(tv.transforms.RandomHorizontalFlip())
    train_transform.transforms.append(tv.transforms.ToTensor())
    normalize = tv.transforms.Normalize(
        mean=[x / 255.0 for x in [125.3, 123.0, 113.9]],
        std=[x / 255.0 for x in [63.0, 62.1, 66.7]])
    train_transform.transforms.append(normalize)

    test_transform = tv.transforms.Compose(
        [tv.transforms.ToTensor(), normalize])

    train_dataset = tv.datasets.CIFAR10(root='data/',
                                        train=True,
                                        transform=train_transform,
                                        download=True)

    test_dataset = tv.datasets.CIFAR10(root='data/',
                                       train=False,
                                       transform=test_transform,
                                       download=True)

    train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                               batch_size=args.bs,
                                               shuffle=True,
                                               pin_memory=True,
                                               num_workers=4)

    test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
                                              batch_size=args.bs,
                                              shuffle=False,
                                              pin_memory=True,
                                              num_workers=4)

    # net
    net = tv.models.mobilenet_v2(num_classes=10)
    net.load_state_dict(torch.load('mobilenet_v2.pth', map_location='cpu'))
    net.dropout = torch.nn.Sequential()

    # quantization
    model_to_quantize = copy.deepcopy(net).to(device)
    qconfig_dict = {"": torch.quantization.get_default_qat_qconfig('fbgemm')}
    model_to_quantize.train()
    model_prepared = prepare_qat_fx(model_to_quantize, qconfig_dict)
    # optimizer and loss
    optimizer = torch.optim.SGD(model_prepared.parameters(),
                                lr=0.01,
                                momentum=0.9,
                                weight_decay=5e-4)
    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, [10, 16], 0.1)
    criterion = torch.nn.CrossEntropyLoss().to(device)
    # train
    model_prepared.train()
    for i_epoch in range(20):
        time_s = time.time()
        for i_iter, data in enumerate(train_loader):
            img, label = data
            img, label = img.to(device), label.to(device)

            optimizer.zero_grad()
            feat = model_prepared(img)
            loss = criterion(feat, label)
            loss.backward()
            optimizer.step()
            time_e = time.time()

            print(
                'Epoch:{:3}/20 || Iter: {:4}/{} || '
                'Loss: {:2.4f} '
                'ETA: {:.2f}min'.format(i_epoch + 1, i_iter + 1,
                                        len(train_loader), loss.item(),
                                        (time_e - time_s) * (20 - i_epoch) *
                                        len(train_loader) / (i_iter + 1) / 60))
        scheduler.step()

    # to int8
    model_int8 = convert_fx(model_prepared)
    torch.jit.save(torch.jit.script(model_int8), 'int8-qat.pth')

    # valid
    loaded_quantized_model = torch.jit.load('int8-qat.pth')
    correct = 0.
    total = 0.
    with torch.no_grad():
        loaded_quantized_model.eval()
        for images, labels in tqdm(test_loader):
            # the int8 model runs on CPU, so no device transfer is needed

            pred = loaded_quantized_model(images)

            pred = torch.max(pred.data, 1)[1]
            total += labels.size(0)
            correct += (pred == labels).sum().item()

        val_acc = correct / total
        print(val_acc)
Example #10
def convert(self, model, submodules, attrs):
    model.another_layer = convert_fx(model.another_layer)
    return model
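This hook converts only model.another_layer, which only works if that same submodule is the one that was prepared earlier. A sketch of the matching prepare-side hook, with hypothetical names mirroring the snippet:

# Hypothetical prepare-side counterpart (assumes
# from torch.quantization.quantize_fx import prepare_fx): prepare exactly the
# submodule that the convert hook above later converts.
def prepare(self, model, qconfig_dict):
    model.another_layer = prepare_fx(model.another_layer, qconfig_dict)
    return model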
Example #11
def prepare_for_quant_convert(self, cfg):
    self.avgpool = convert_fx(
        self.avgpool,
        convert_custom_config_dict=self.custom_config_dict)
    return self
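self.custom_config_dict is opaque in this excerpt. For convert_fx it typically maps observed custom module classes to their quantized counterparts; an illustrative shape, with both classes hypothetical and with the caveat that some PyTorch releases nest the mapping under a "static"/"dynamic" key:

# Illustrative convert_custom_config_dict; ObservedPool and QuantizedPool are
# hypothetical classes, not a real API surface.
custom_config_dict = {
    "observed_to_quantized_custom_module_class": {
        ObservedPool: QuantizedPool,
    }
}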
Example #12
def quantize_fx(model, inputs, data_loader, dynamic=True, selective=False):

    if hasattr(model, "encoder") and isinstance(model.encoder, RoBERTaEncoder):

        static = not dynamic

        if dynamic:
            qconfig = per_channel_dynamic_qconfig
        else:
            qconfig = QConfig(
                activation=HistogramObserver.with_args(reduce_range=False),
                weight=default_weight_observer,
            )

        # Only linear layers
        qconfig_dict = {"": None}
        if static and selective:
            qconfig_dict["module_name"] = []
            layers = model.encoder.encoder.transformer.layers.layers.layers
            layers_str = "layers"
            # skip first layer
            for layer_idx in range(1, len(layers)):
                qconfig_dict["module_name"].append((
                    layers_str +
                    ".{}.attention.input_projection".format(layer_idx),
                    qconfig,
                ))
                qconfig_dict["module_name"].append((
                    layers_str +
                    ".{}.attention.output_projection".format(layer_idx),
                    qconfig,
                ))
                for mlp_idx, m in enumerate(
                        layers[layer_idx].residual_mlp.mlp):
                    # Only quantize the first linear layer; otherwise there are
                    # accuracy issues with static quantization
                    if type(m) == torch.nn.Linear and mlp_idx < 1:
                        qconfig_dict["module_name"].append((
                            layers_str + ".{}.residual_mlp.mlp.{}".format(
                                layer_idx, mlp_idx),
                            qconfig,
                        ))
        else:
            qconfig_dict["object_type"] = [(torch.nn.Linear, qconfig)]

        def calibrate(model, loader, max_samples=-1):
            model.eval()
            with torch.no_grad():
                for (idx, d) in enumerate(loader):
                    print("Running sample input #" + str(idx))
                    model(d[1]["tokens"])
                    if idx == max_samples:
                        break

        prepared_model = prepare_fx(
            model.encoder.encoder.transformer.layers.layers,
            qconfig_dict)  # fuse modules and insert observers

        model.encoder.encoder.transformer.layers.layers = prepared_model
        if static:
            calibrate(model, data_loader)  # run calibration on sample data
        model.encoder.encoder.transformer.layers.layers = convert_fx(
            prepared_model)

        # Trace the submodule in order to fix the interface
        if static:
            input1 = torch.randn([2, 1, 1024], dtype=torch.float)
            input2 = torch.randn([1, 2]).bool()
            traced = torch.jit.trace(
                model.encoder.encoder.transformer.layers.layers,
                (input1, input2))
            model.encoder.encoder.transformer.layers.layers = traced

        # Trace the overall module
        trace = model.trace(inputs)

        return trace
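A hypothetical call site for this function; the model, sample inputs, and calibration loader are all placeholders for objects the surrounding codebase would provide:

# Hypothetical usage: dynamic int8 quantization of a RoBERTa-encoder model,
# then saving the traced result; all three arguments are placeholders.
traced = quantize_fx(model, sample_inputs, calib_loader, dynamic=True)
torch.jit.save(traced, "roberta_int8.pt")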