Example #1
    output_shape = transformer.log.getOutShapes()
    if args.quantize:
        if args.distill_range:
            targ_layer = [QConv2d, QLinear]
        elif args.trainable:
            targ_layer = [QuantConv2d, QuantLinear]
        else:
            targ_layer = [QuantNConv2d, QuantNLinear]
    else:
        targ_layer = [torch.nn.Conv2d, torch.nn.Linear]

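    # assign the requested weight/activation/bias bit-widths to every target layer in the graph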
    if args.quantize:
        set_layer_bits(graph, args.bits_weight, args.bits_activation,
                       args.bits_bias, targ_layer)

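    # fold BatchNorm parameters into the preceding conv/linear weights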
    net = merge_batchnorm(net, graph, bottoms, targ_layer)

    # create relations between consecutive target layers, then equalize their weight ranges (cross-layer equalization)
    if args.equalize or args.distill_range:
        res = create_relation(graph,
                              bottoms,
                              targ_layer,
                              delete_single=not args.distill_range)
        if args.equalize:
            cross_layer_equalization(graph,
                                     res,
                                     targ_layer,
                                     visualize_state=False,
                                     converge_thres=2e-7,
                                     s_range=(1 / args.equal_range,
                                              args.equal_range))
Example #2
def main():
    args = get_argument()
    assert args.relu or args.relu == args.equalize, 'must replace ReLU6 with ReLU when using equalization'
    assert args.equalize or args.absorption == args.equalize, 'bias absorption requires equalization'
    data = torch.ones((4, 3, 224, 224))  #.cuda()

    model = mobilenet_v2(
        'modeling/classification/mobilenetv2_1.0-f2a8633.pth.tar')
    model.eval()

    transformer = TorchTransformer()
    module_dict = {}
    if args.quantize:
        if args.trainable:
            module_dict[1] = [(nn.Conv2d, QuantConv2d),
                              (nn.Linear, QuantLinear)]
        else:
            module_dict[1] = [(nn.Conv2d, QuantNConv2d),
                              (nn.Linear, QuantNLinear)]

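    # cross-layer equalization relies on a scale-equivariant activation, so ReLU6 is swapped for plain ReLU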
    if args.relu:
        module_dict[0] = [(torch.nn.ReLU6, torch.nn.ReLU)]

    # transformer.summary(model, data)
    # transformer.visualize(model, data, 'graph_cls', graph_size=120)

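    # swap module types according to module_dict and trace the model, so the graph and each layer's inputs can be read back below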
    model, transformer = switch_layers(model,
                                       transformer,
                                       data,
                                       module_dict,
                                       ignore_layer=[QuantMeasure],
                                       quant_op=args.quantize)

    graph = transformer.log.getGraph()
    bottoms = transformer.log.getBottoms()
    output_shape = transformer.log.getOutShapes()
    if args.quantize:
        if args.trainable:
            targ_layer = [QuantConv2d, QuantLinear]
        else:
            targ_layer = [QuantNConv2d, QuantNLinear]
    else:
        targ_layer = [nn.Conv2d, nn.Linear]

    model = merge_batchnorm(model, graph, bottoms, targ_layer)

    # create relations
    if args.equalize:
        res = create_relation(graph, bottoms, targ_layer)
        cross_layer_equalization(graph,
                                 res,
                                 targ_layer,
                                 visualize_state=False,
                                 converge_thres=2e-7)

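    # absorb large post-equalization biases into the following layer; the asserts above guarantee res exists whenever absorption is enabled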
    if args.absorption:
        bias_absorption(graph, res, bottoms, 3)

    if args.clip_weight:
        clip_weight(graph, range_clip=[-15, 15], targ_type=targ_layer)

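    # correct the expected (biased) quantization error on each layer's output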
    if args.correction:
        bias_correction(graph, bottoms, targ_layer)

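    # quantize the target layers' weights, then derive activation min/max analytically from the folded BN statistics (data-free)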
    if args.quantize:
        if not args.trainable:
            graph = quantize_targ_layer(graph, targ_layer)
        set_quant_minmax(graph, bottoms, output_shape)

    model = model.cuda()
    model.eval()

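    # replace_op()/restore_op() swap torch ops for their quantization-aware versions around inference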
    if args.quantize:
        replace_op()
    acc = inference_all(model)
    print("Acc: {}".format(acc))
    if args.quantize:
        restore_op()
Example #3
def main():
    args = get_argument()
    assert args.relu or args.relu == args.equalize, 'must replace ReLU6 with ReLU when using equalization'
    assert args.equalize or args.absorption == args.equalize, 'bias absorption requires equalization'

    data = torch.ones((4, 3, 224, 224))  #.cuda()

    if args.resnet:
        import torchvision.models as models
        model = models.resnet18(pretrained=True)
    else:
        model = mobilenet_v2(
            'modeling/classification/mobilenetv2_1.0-f2a8633.pth.tar')
    model.eval()

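    # distill_range: trace the FP32 model first and synthesize calibration data from its BatchNorm statistics (real ImageNet batches are sampled instead when args.true_data is set)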
    if args.distill_range:
        import copy
        # define FP32 model
        model_original = copy.deepcopy(model)
        model_original.eval()
        transformer = TorchTransformer()
        transformer._build_graph(model_original, data, [QuantMeasure])
        graph = transformer.log.getGraph()
        bottoms = transformer.log.getBottoms()

        if not args.true_data:
            data_distill = getDistilData(
                model_original, 'imagenet', args.dis_batch_size,
                bn_merged=False, num_batch=args.dis_num_batch, gpu=True,
                value_range=[-2.11790393, 2.64], size=[224, 224],
                early_break_factor=1.2 if args.resnet else 0.5)
        else:
            imagenet_dataset = datasets.ImageFolder(
                '/home/jakc4103/windows/Toshiba/workspace/dataset/ILSVRC/Data/CLS-LOC/train',
                transforms.Compose([
                    transforms.Resize(256),
                    transforms.CenterCrop(224),
                    transforms.ToTensor(),
                    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                         std=[0.229, 0.224, 0.225]),
                ]))
            data_distill = []
            dataloader = DataLoader(imagenet_dataset,
                                    batch_size=args.dis_batch_size,
                                    shuffle=True,
                                    num_workers=4,
                                    pin_memory=True)
            for idx, sample in enumerate(dataloader):
                if idx >= args.dis_num_batch:
                    break
                image = sample[0]
                data_distill.append(image)
            del dataloader, imagenet_dataset

    transformer = TorchTransformer()
    module_dict = {}
    if args.quantize:
        if args.distill_range:
            module_dict[1] = [(nn.Conv2d, QConv2d), (nn.Linear, QLinear)]
        elif args.trainable:
            module_dict[1] = [(nn.Conv2d, QuantConv2d),
                              (nn.Linear, QuantLinear)]
        else:
            module_dict[1] = [(nn.Conv2d, QuantNConv2d),
                              (nn.Linear, QuantNLinear)]

    if args.relu:
        module_dict[0] = [(torch.nn.ReLU6, torch.nn.ReLU)]

    # transformer.summary(model, data)
    # transformer.visualize(model, data, 'graph_cls', graph_size=120)

    model, transformer = switch_layers(model,
                                       transformer,
                                       data,
                                       module_dict,
                                       ignore_layer=[QuantMeasure],
                                       quant_op=args.quantize)

    graph = transformer.log.getGraph()
    bottoms = transformer.log.getBottoms()
    if args.quantize:
        if args.distill_range:
            targ_layer = [QConv2d, QLinear]
        elif args.trainable:
            targ_layer = [QuantConv2d, QuantLinear]
        else:
            targ_layer = [QuantNConv2d, QuantNLinear]
    else:
        targ_layer = [nn.Conv2d, nn.Linear]

    if args.quantize:
        set_layer_bits(graph, args.bits_weight, args.bits_activation,
                       args.bits_bias, targ_layer)

    model = merge_batchnorm(model, graph, bottoms, targ_layer)

    # create relations
    if args.equalize or args.distill_range:
        res = create_relation(graph, bottoms, targ_layer, delete_single=False)
        if args.equalize:
            cross_layer_equalization(graph,
                                     res,
                                     targ_layer,
                                     visualize_state=False,
                                     converge_thres=2e-7)

        # if args.distill:
        #     set_scale(res, graph, bottoms, targ_layer)

    if args.absorption:
        bias_absorption(graph, res, bottoms, 3)

    if args.clip_weight:
        clip_weight(graph, range_clip=[-15, 15], targ_type=targ_layer)

    if args.correction:
        # if args.distill:
        #     model_original = copy.deepcopy(model.cpu())
        #     model_original.eval()
        #     transformer = TorchTransformer()
        #     transformer.register(targ_layer[0], nn.Conv2d)
        #     transformer.register(targ_layer[1], nn.Linear)
        #     model_original = transformer.trans_layers(model_original, update=True)

        #     bias_correction_distill(model, model_original, data_distill, targ_layer, [nn.Conv2d, nn.Linear])
        # else:
        bias_correction(graph,
                        bottoms,
                        targ_layer,
                        bits_weight=args.bits_weight)

    if args.quantize:
        if not args.trainable and not args.distill_range:
            graph = quantize_targ_layer(graph, args.bits_weight,
                                        args.bits_bias, targ_layer)

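        # with distilled data, measure activation ranges by running it through the QuantMeasure modules in update mode; otherwise fall back to BN-based min/max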
        if args.distill_range:
            set_update_stat(model, [QuantMeasure], True)
            model = update_quant_range(model.cuda(), data_distill, graph,
                                       bottoms)
            set_update_stat(model, [QuantMeasure], False)
        else:
            set_quant_minmax(graph, bottoms)

        torch.cuda.empty_cache()

    # if args.distill:
    #     model = update_scale(model, model_original, data_distill, graph, bottoms, res, targ_layer, num_epoch=1000)
    #     set_quant_minmax(graph, bottoms)

    model = model.cuda()
    model.eval()

    if args.quantize:
        replace_op()
    acc = inference_all(model)
    print("Acc: {}".format(acc))
    if args.quantize:
        restore_op()
    if args.log:
        with open("cls_result.txt", 'a+') as ww:
            ww.write(
                "resnet: {}, quant: {}, relu: {}, equalize: {}, absorption: {}, correction: {}, clip: {}, distill_range: {}\n"
                .format(args.resnet, args.quantize, args.relu, args.equalize,
                        args.absorption, args.correction, args.clip_weight,
                        args.distill_range))
            ww.write("Acc: {}\n\n".format(acc))
Example #4
def main():
    args = get_argument()
    assert args.relu or args.relu == args.equalize, 'must replace ReLU6 with ReLU when using equalization'
    assert args.equalize or args.absorption == args.equalize, 'bias absorption requires equalization'
    data = torch.ones((4, 3, 513, 513))#.cuda()

    model = DeepLab(sync_bn=False)
    state_dict = torch.load('modeling/segmentation/deeplab-mobilenet.pth.tar')['state_dict']
    model.load_state_dict(state_dict)
    model.eval()
    if args.distill_range:
        import copy
        # define FP32 model 
        model_original = copy.deepcopy(model)
        model_original.eval()
        transformer = TorchTransformer()
        transformer._build_graph(model_original, data, [QuantMeasure])
        graph = transformer.log.getGraph()
        bottoms = transformer.log.getBottoms()
    
        data_distill = getDistilData(
            model_original, 'imagenet', 32, bn_merged=False, num_batch=8,
            gpu=True, value_range=[-2.11790393, 2.64], size=[513, 513],
            early_break_factor=0.2)

    transformer = TorchTransformer()

    module_dict = {}
    if args.quantize:
        if args.distill_range:
            module_dict[1] = [(nn.Conv2d, QConv2d)]
        elif args.trainable:
            module_dict[1] = [(nn.Conv2d, QuantConv2d)]
        else:
            module_dict[1] = [(nn.Conv2d, QuantNConv2d)]
    
    if args.relu:
        module_dict[0] = [(torch.nn.ReLU6, torch.nn.ReLU)]

    # transformer.summary(model, data)
    # transformer.visualize(model, data, 'graph_deeplab', graph_size=120)

    model, transformer = switch_layers(model, transformer, data, module_dict, ignore_layer=[QuantMeasure], quant_op=args.quantize)
    graph = transformer.log.getGraph()
    bottoms = transformer.log.getBottoms()

    if args.quantize:
        if args.distill_range:
            targ_layer = [QConv2d]
        elif args.trainable:
            targ_layer = [QuantConv2d]
        else:
            targ_layer = [QuantNConv2d]
    else:
        targ_layer = [nn.Conv2d]
    if args.quantize:
        set_layer_bits(graph, args.bits_weight, args.bits_activation, args.bits_bias, targ_layer)
    model = merge_batchnorm(model, graph, bottoms, targ_layer)

    # create relations
    if args.equalize or args.distill_range:
        res = create_relation(graph, bottoms, targ_layer)
        if args.equalize:
            cross_layer_equalization(graph, res, targ_layer, visualize_state=False)

        # if args.distill:
        #     set_scale(res, graph, bottoms, targ_layer)

    if args.absorption:
        bias_absorption(graph, res, bottoms, 3)
    
    if args.clip_weight:
        clip_weight(graph, range_clip=[-15, 15], targ_type=targ_layer)

    if args.correction:
        bias_correction(graph, bottoms, targ_layer)

    if args.quantize:
        if not args.trainable and not args.distill_range:
            graph = quantize_targ_layer(graph, args.bits_weight, args.bits_bias, targ_layer)
        
        if args.distill_range:
            set_update_stat(model, [QuantMeasure], True)
            model = update_quant_range(model.cuda(), data_distill, graph, bottoms)
            set_update_stat(model, [QuantMeasure], False)
        else:
            set_quant_minmax(graph, bottoms)

        torch.cuda.empty_cache()
    
    model = model.cuda()
    model.eval()

    if args.quantize:
        replace_op()
    inference_all(model, args.dataset, args if args.log else None)
    if args.quantize:
        restore_op()
Example #5
def main():
    args = get_argument()
    # An instance of your model
    if args.resnet:
        import torchvision.models as models
        model = models.resnet18(pretrained=True)
        model = ProbModel(model)
    else:
        model = mobilenet_v2(
            'modeling/classification/mobilenetv2_1.0-f2a8633.pth.tar')
        model = ProbModel(model)
    model.eval()

    if args.quantize:
        data = torch.ones((4, 3, 224, 224))  #.cuda()

        if args.distill_range:
            import copy
            # define FP32 model
            model_original = copy.deepcopy(model)
            model_original.eval()
            transformer = TorchTransformer()
            transformer._build_graph(model_original, data, [QuantMeasure])
            graph = transformer.log.getGraph()
            bottoms = transformer.log.getBottoms()

            data_distill = getDistilData(
                model_original, 'imagenet', args.dis_batch_size,
                bn_merged=False, num_batch=args.dis_num_batch, gpu=True,
                value_range=[-2.11790393, 2.64], size=[224, 224],
                early_break_factor=1.2 if args.resnet else 0.5)

        transformer = TorchTransformer()
        module_dict = {}

        if args.distill_range:
            module_dict[1] = [(torch.nn.Conv2d, QConv2d),
                              (torch.nn.Linear, QLinear)]
        else:
            module_dict[1] = [(torch.nn.Conv2d, QuantNConv2d),
                              (torch.nn.Linear, QuantNLinear)]

        if args.relu or args.equalize:
            module_dict[0] = [(torch.nn.ReLU6, torch.nn.ReLU)]

        # transformer.summary(model, data)
        # transformer.visualize(model, data, 'graph_cls', graph_size=120)

        model, transformer = switch_layers(model,
                                           transformer,
                                           data,
                                           module_dict,
                                           ignore_layer=[QuantMeasure],
                                           quant_op=True)

        graph = transformer.log.getGraph()
        bottoms = transformer.log.getBottoms()
        if args.distill_range:
            targ_layer = [QConv2d, QLinear]
        else:
            targ_layer = [QuantNConv2d, QuantNLinear]

        set_layer_bits(graph, args.bits_weight, args.bits_activation,
                       args.bits_bias, targ_layer)

        model = merge_batchnorm(model, graph, bottoms, targ_layer)

        # create relations
        if args.equalize or args.distill_range:
            res = create_relation(graph,
                                  bottoms,
                                  targ_layer,
                                  delete_single=False)
            if args.equalize:
                cross_layer_equalization(graph,
                                         res,
                                         targ_layer,
                                         visualize_state=False,
                                         converge_thres=2e-7,
                                         signed=True)

        if args.clip_weight:
            clip_weight(graph, range_clip=[-15, 15], targ_type=targ_layer)

        if args.correction:
            bias_correction(graph,
                            bottoms,
                            targ_layer,
                            bits_weight=args.bits_weight,
                            signed=True)

        if args.distill_range:
            set_update_stat(model, [QuantMeasure], True)
            model = update_quant_range(model.cuda(), data_distill, graph,
                                       bottoms)
            set_update_stat(model, [QuantMeasure], False)
        else:
            set_quant_minmax(graph, bottoms)

        torch.cuda.empty_cache()

        # restore the custom quantized layers to plain torch.nn modules so the model can be exported to ONNX
        module_dict = {}
        if args.distill_range:
            module_dict[1] = [(QConv2d, torch.nn.Conv2d),
                              (QLinear, torch.nn.Linear)]
        else:
            module_dict[1] = [(QuantNConv2d, torch.nn.Conv2d),
                              (QuantNLinear, torch.nn.Linear)]

        model, transformer = switch_layers(model,
                                           transformer,
                                           data,
                                           module_dict,
                                           ignore_layer=[QuantMeasure],
                                           quant_op=False)
        graph = transformer.log.getGraph()
        bottoms = transformer.log.getBottoms()

    # An example input you would normally provide to your model's forward() method
    x = torch.rand(1, 3, 224, 224)

    # Export the onnx model
    torch.onnx.export(model, x, "model.onnx", export_params=True)

    # Simplify model using onnx-simplifier
    os.system("python3 -m onnxsim model.onnx model-sim.onnx")
    os.system("rm model.onnx")

    cur_path = os.path.abspath(os.getcwd())
    os.system("mv model-sim.onnx {}".format(
        os.path.join(args.ncnn_build, 'tools/onnx', 'model-sim.onnx')))
    os.chdir(os.path.join(args.ncnn_build, 'tools/onnx'))

    # Convert onnx to ncnn
    os.system("./onnx2ncnn model-sim.onnx model.param model.bin")

    # Add input image size to .param
    lines = [line.strip() for line in open("model.param", "r")]
    with open("model.param", 'w') as ww:
        for idx, line in enumerate(lines):
            if idx == 2 and 'input' in line.lower():
                line += ' 0=224 1=224 2=3'
            ww.write(line + '\n')

    if not os.path.exists(os.path.join(cur_path, 'modeling/ncnn')):
        os.makedirs(os.path.join(cur_path, 'modeling/ncnn'))

    os.system("rm model-sim.onnx")

    if args.quantize:
        os.system("mv model.param {}".format(
            os.path.join(args.ncnn_build, 'tools/quantize', 'model.param')))
        os.system("mv model.bin {}".format(
            os.path.join(args.ncnn_build, 'tools/quantize', 'model.bin')))
        os.chdir(os.path.join(args.ncnn_build, 'tools/quantize'))

        # Estimate activation range using https://github.com/Tencent/ncnn/tree/master/tools/quantize
        os.system("./ncnn2table --param=model.param --bin=model.bin\
                --images={} --output=model_int8_channel.table\
                --mean={},{},{} --norm={},{},{} --size=224,224 --thread=2".
                  format(args.image_path, 0.485 * 255, 0.456 * 255,
                         0.406 * 255, 1 / (0.229 * 255), 1 / (0.224 * 255),
                         1 / (0.225 * 255)))

        # modify activation min/max range and weight min/max range to values calculated in DFQ
        table_old = [
            line.strip() for line in open("model_int8_channel.table", 'r')
        ]
        table_new = []
        count = 0
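        # the ncnn table lists per-output-channel weight scales first, then one activation scale per layer;
        # both are recomputed here from the DFQ ranges as 128 / max(|min|, |max|)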
        for ii in range(2):
            for idx in graph:
                if type(graph[idx]) in [torch.nn.Conv2d, torch.nn.Linear]:
                    if ii == 0:  # min/max taken from the layer's weights
                        mi = float(torch.min(graph[idx].weight))
                        ma = float(torch.max(graph[idx].weight))
                    else:
                        mi = float(torch.min(graph[idx].quant.running_min))
                        ma = float(torch.max(graph[idx].quant.running_max))
                    scale = 128. / (max(abs(ma), abs(mi)))

                    if ii == 0:  # weights: one scale per output channel
                        table_new.append(
                            ' '.join(table_old[count].split(' ')[0:1] +
                                     [str(scale)] *
                                     graph[idx].weight.shape[0]))
                    else:  # activations: a single scale per layer
                        table_new.append(
                            ' '.join(table_old[count].split(' ')[0:1] +
                                     [str(scale)]))
                    count += 1

        with open("model_int8_tensor.table", 'w') as ww:
            for line in table_new:
                ww.write(line + '\n')

        # Convert to Int8 model
        os.system(
            "./ncnn2int8 model.param model.bin model_int8.param model_int8.bin model_int8_tensor.table"
        )
        lines = [line.strip() for line in open("model_int8.param", "r")]

        os.system("cp model_int8.param {}".format(
            os.path.join(cur_path, args.param)))
        os.system("cp model_int8.bin {}".format(
            os.path.join(cur_path, args.bin)))
        os.system("cp model_int8_tensor.table {}".format(
            os.path.join(cur_path, args.table)))
    else:
        os.system("mv model.param {}".format(os.path.join(
            cur_path, args.param)))
        os.system("mv model.bin {}".format(os.path.join(cur_path, args.bin)))

    os.chdir(cur_path)
    line = [l.strip() for l in open(args.param, 'r')][-1].split()[1]
    print("=" * 100)
    print("Target layer name '{}'".format(line))
    print("=" * 100)
Example #6
def main():
    args = get_argument()

    data = torch.ones((4, 3, 224, 224))  #.cuda()

    if args.model == 'resnet50':
        model = models.resnet50(pretrained=True)
    elif args.model == 'inceptionv3':
        model = models.inception_v3(pretrained=True)
    elif args.model == 'mobilenetv2':
        from modeling.classification import MobileNetV2
        model = MobileNetV2.mobilenet_v2(pretrained=True)
    else:
        assert False, 'Model type not supported'

    model = QuantModel(model, args.bits_activation)

    model.eval()

    transformer = TorchTransformer()
    module_dict = {}
    if args.quantize:
        module_dict[1] = [(nn.Conv2d, QuantNConv2d),
                          (nn.Linear, QuantNLinear),
                          (nn.AdaptiveAvgPool2d, QuantAdaptiveAvgPool2d),
                          (nn.MaxPool2d, QuantMaxPool2d)]

    # transformer.summary(model, data)
    # transformer.visualize(model, data, 'graph_cls', graph_size=120)

    model, transformer = switch_layers(model,
                                       transformer,
                                       data,
                                       module_dict,
                                       ignore_layer=[QuantMeasure],
                                       quant_op=args.quantize)

    graph = transformer.log.getGraph()
    bottoms = transformer.log.getBottoms()
    if args.quantize:
        targ_layer = [QuantNConv2d, QuantNLinear]
    else:
        targ_layer = [nn.Conv2d, nn.Linear]

    model = merge_batchnorm(model, graph, bottoms, targ_layer)

    if args.quantize:
        set_layer_bits(graph,
                       args.bits_weight,
                       args.bits_activation,
                       targ_type=targ_layer)

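    # unlike the data-free examples above, activation ranges here are calibrated on 512 real ImageNet training images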
    if args.quantize:
        print("preparing data for computing activation min/max range")
        trans = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225]),
        ])
        if not os.path.exists("_512_train.txt"):
            print("Creating _512_train.txt, this will take some time...")
            from utils import default_loader
            imagenet_dataset = datasets.ImageFolder(os.path.join(
                args.imagenet_path, 'train'),
                                                    trans,
                                                    loader=default_loader)

            np.random.seed(1000)
            perm_idx = np.random.permutation(len(imagenet_dataset))
            images = []
            for i in range(512):
                images.append(imagenet_dataset[perm_idx[i]][0].unsqueeze(0))

            del imagenet_dataset
        else:
            from PIL import Image
            images = []
            for line in open("_512_train.txt", 'r'):
                line = line.strip()
                with open(line, 'rb') as f:
                    img = Image.open(f)
                    img = img.convert('RGB')

                images.append(trans(img).unsqueeze(0))

        set_update_stat(model, True)
        model = set_quant_minmax_data(model, images, [QuantMeasure])
        set_update_stat(model, False)

        graph = quantize_targ_layer(graph,
                                    args.bits_weight,
                                    targ_type=targ_layer,
                                    quant_type=args.qtype)

    model = model.cuda()
    model.eval()

    acc = inference_all(model, os.path.join(args.imagenet_path, 'val'))
    print("Acc: {}".format(acc))

    if args.log:
        with open("cls_result.txt", 'a+') as ww:
            ww.write(
                "model: {}, quant: {}, qtype: {}, bits_weight: {}, correction: {}\n"
                .format(args.model, args.quantize, args.qtype,
                        args.bits_weight, args.correction))
            ww.write("Acc: {}\n\n".format(acc))
Example #7
def main():
    args = get_argument()
    assert args.relu or args.relu == args.equalize, 'must replace ReLU6 with ReLU when using equalization'
    assert args.equalize or args.absorption == args.equalize, 'bias absorption requires equalization'
    data = torch.ones((4, 3, 513, 513))  #.cuda()

    model = DeepLab(sync_bn=False)
    state_dict = torch.load(
        'modeling/segmentation/deeplab-mobilenet.pth.tar')['state_dict']
    model.load_state_dict(state_dict)
    model.eval()

    transformer = TorchTransformer()

    module_dict = {}
    if args.quantize:
        if args.trainable:
            module_dict[1] = [(nn.Conv2d, QuantConv2d)]
        else:
            module_dict[1] = [(nn.Conv2d, QuantNConv2d)]

    if args.relu:
        module_dict[0] = [(torch.nn.ReLU6, torch.nn.ReLU)]

    # transformer.summary(model, data)
    # transformer.visualize(model, data, 'graph_deeplab', graph_size=120)

    model, transformer = switch_layers(model,
                                       transformer,
                                       data,
                                       module_dict,
                                       ignore_layer=[QuantMeasure],
                                       quant_op=args.quantize)
    graph = transformer.log.getGraph()
    bottoms = transformer.log.getBottoms()
    output_shape = transformer.log.getOutShapes()

    if args.quantize:
        if args.trainable:
            targ_layer = [QuantConv2d]
        else:
            targ_layer = [QuantNConv2d]
    else:
        targ_layer = [nn.Conv2d]
    model = merge_batchnorm(model, graph, bottoms, targ_layer)

    # create relations
    if args.equalize:
        res = create_relation(graph, bottoms, targ_layer)
        cross_layer_equalization(graph, res, targ_layer, visualize_state=False)

    if args.absorption:
        bias_absorption(graph, res, bottoms, 3)

    if args.clip_weight:
        clip_weight(graph, range_clip=[-15, 15], targ_type=targ_layer)

    if args.correction:
        bias_correction(graph, bottoms, targ_layer)

    if args.quantize:
        if not args.trainable:
            graph = quantize_targ_layer(graph, targ_layer)
        set_quant_minmax(graph, bottoms, output_shape)

    model = model.cuda()
    model.eval()

    if args.quantize:
        replace_op()
    inference_all(model, args.dataset)
    if args.quantize:
        restore_op()