示例#1
0
def estimate_stats(model,
                   state_dict,
                   data,
                   num_epoch=10,
                   path_save='modeling/data_dependent_QuantConv2dAdd.pth'):
    """Run forward passes over the VOC train split, then save the model state.

    The model is moved to CUDA, put in train mode and fed every batch of
    ``VOCSegmentation(split='train')`` for ``num_epoch`` epochs under
    ``torch.no_grad()`` while ``replace_op()`` is active. Afterwards every
    entry of ``state_dict`` whose key contains ``'running'`` is loaded back
    into the model, and the resulting state dict is written to ``path_save``.

    Args:
        model: network to run; it is moved to the GPU.
        state_dict: pre-trained parameters; its ``'running*'`` entries
            overwrite the ones accumulated during the forward passes.
        data: unused; kept so existing callers keep working.
        num_epoch: number of passes over the dataset.
        path_save: destination passed to ``torch.save``.

    Returns:
        The model (on CUDA, still in train mode).
    """
    from types import SimpleNamespace

    model = model.cuda()

    # Minimal stand-in for the argparse namespace the dataset receives.
    args = SimpleNamespace(base_size=513, crop_size=513)
    voc_train = VOCSegmentation(args, split='train')
    dataloader = DataLoader(voc_train,
                            batch_size=32,
                            shuffle=True,
                            num_workers=0)
    model.train()

    replace_op()
    ss = time.time()
    try:
        with torch.no_grad():
            for epoch in range(num_epoch):
                start = time.time()
                for sample in dataloader:
                    # Only the image is needed; outputs are discarded.
                    model(sample['image'].cuda())

                end = time.time()
                print("epoch {}: {} sec.".format(epoch, end - start))
        print('total time: {} sec'.format(time.time() - ss))
    finally:
        # Always undo replace_op(), even if a forward pass fails.
        restore_op()

    # load 'running_mean' and 'running_var' of batchnorm back from pre-trained parameters
    bn_dict = {
        key: value
        for key, value in state_dict.items() if 'running' in key
    }

    state = model.state_dict()
    state.update(bn_dict)
    model.load_state_dict(state)

    torch.save(model.state_dict(), path_save)

    return model
示例#2
0
文件: improve_dfq.py 项目: wtsitp/DFQ
def update_quant_range(model, data, graph, bottoms, is_detection=False):
    """Run ``data`` through ``model`` and then pin the input layers' quant range.

    Every batch in ``data`` is moved to CUDA and forwarded through ``model``
    under ``torch.no_grad()`` while ``replace_op()`` is active. Afterwards,
    for each graph node whose first bottom is ``"Data"``, the node's
    ``quant.running_max`` / ``quant.running_min`` are overwritten with fixed
    bounds: ``[-1, 1]`` when ``is_detection`` is true, otherwise
    ``[-2.11790393, 2.64]``.

    Args:
        model: network to run; expected to already live on the GPU.
        data: iterable of input batches (tensors).
        graph: mapping from node index to layer object.
        bottoms: mapping from node index to that node's list of inputs
            (or ``None``).
        is_detection: selects the ``[-1, 1]`` input range.

    Returns:
        The same ``model`` object.
    """
    with torch.no_grad():
        replace_op()
        try:
            for batch_data in data:
                model(batch_data.cuda())
        finally:
            # Always undo replace_op(), even if a forward pass fails.
            restore_op()

        if is_detection:
            input_max, input_min = 1, -1
        else:
            input_max, input_min = 2.64, -2.11790393

        for idx in graph:
            node_bottoms = bottoms[idx]
            if node_bottoms is None or node_bottoms[0] != "Data":
                continue
            graph[idx].quant.running_max.fill_(input_max)
            graph[idx].quant.running_min.fill_(input_min)
    return model
示例#3
0
文件: improve_dfq.py 项目: wtsitp/DFQ
def bias_correction_distill(qmodel, model_original, data, targ_type,
                            targ_type_original):
    """Shift biases of ``qmodel`` layers by their mean output error on ``data``.

    Forward hooks (``ModuleHook``) are attached to every module of ``qmodel``
    whose exact type is in ``targ_type`` and to every module of
    ``model_original`` whose exact type is in ``targ_type_original``; the two
    hook lists are paired by position. For each batch both models are run
    (``qmodel`` with ``replace_op()`` active) and the batch-mean of each
    hooked output is accumulated. The difference of the accumulated means,
    divided by ``len(data)`` and summed over all trailing dimensions, is
    subtracted in place from the hooked ``qmodel`` module's bias. A module
    without a bias gets a new zero bias first.

    Args:
        qmodel: model whose biases are corrected in place; moved to CUDA.
        model_original: reference model; moved to CUDA.
        data: sized sequence of input batches (``len(data)`` is used).
        targ_type: container of module types to hook in ``qmodel``.
        targ_type_original: container of module types to hook in
            ``model_original``.

    Returns:
        None. If no batch was processed, no bias is modified.

    Raises:
        AssertionError: if the two models yield a different number of hooks.
    """
    qmodel = qmodel.cuda().eval()
    model_original = model_original.cuda().eval()
    hooks = []
    hooks_original = []
    hook_handles = []

    try:
        # NOTE(review): exact type match is kept from the original code;
        # isinstance() would additionally match subclasses.
        for _, module in qmodel.named_modules():
            if type(module) in targ_type:
                hook = ModuleHook()
                hooks.append(hook)
                hook_handles.append(module.register_forward_hook(hook.hook))

        for _, module in model_original.named_modules():
            if type(module) in targ_type_original:
                hook = ModuleHook()
                hooks_original.append(hook)
                hook_handles.append(module.register_forward_hook(hook.hook))

        assert len(hooks) == len(
            hooks_original), "len of hooks in 2 models must be the same"

        # idx -> [sum of quantized batch means, sum of original batch means]
        error_list = {}
        with torch.no_grad():
            for batch_data in data:
                for hook in hooks:
                    hook.clear()
                for hook in hooks_original:
                    hook.clear()

                batch_data = batch_data.cuda()
                replace_op()
                try:
                    qmodel(batch_data)
                finally:
                    # Always undo replace_op(), even if the forward fails.
                    restore_op()

                model_original(batch_data)

                for idx, (hook, hook_original) in enumerate(
                        zip(hooks, hooks_original)):
                    quant_mean = hook.outputs.mean(0).cpu()
                    orig_mean = hook_original.outputs.mean(0).cpu()
                    if idx in error_list:
                        error_list[idx][0] += quant_mean
                        error_list[idx][1] += orig_mean
                    else:
                        error_list[idx] = [quant_mean, orig_mean]

            if not error_list:
                # No batch was processed (or nothing is hooked).
                return

            for idx, hook in enumerate(hooks):
                module = hook.module
                error = (error_list[idx][0] - error_list[idx][1]) / len(data)
                # Collapse everything but the leading dimension.
                # NOTE(review): this sums (not averages) over the trailing
                # dims, as in the original code - confirm that is intended.
                error = error.view(error.size(0), -1).sum(-1)
                correction = error.cuda()
                if getattr(module, "bias", None) is None:
                    # Allocate the new bias on the correction's device so the
                    # in-place add below does not mix CPU and CUDA tensors.
                    module.bias = torch.nn.Parameter(
                        torch.zeros_like(correction), requires_grad=False)
                module.bias.add_(-correction)
    finally:
        # Detach the hooks on every exit path.
        for handle in hook_handles:
            handle.remove()
示例#4
0
文件: improve_dfq.py 项目: wtsitp/DFQ
def _sync_bn_with_scale(graph, graph_original, res):
    """Rescale each relation's BN fake weight/bias by its first layer's scale."""
    for rr in res:
        layer_first, _, bn_idx = rr.get_idxs()
        scale = graph[layer_first].scale.detach().data.view(-1)
        graph[bn_idx].fake_weight.copy_(
            graph_original[bn_idx].fake_weight * scale)
        graph[bn_idx].fake_bias.copy_(
            graph_original[bn_idx].fake_bias * scale)


def update_scale(qmodel,
                 model,
                 data_distill,
                 graph,
                 bottoms,
                 res,
                 targ_layer,
                 num_epoch=1000):
    """Optimise the ``scale`` parameters of ``qmodel`` on distilled data.

    An Adam optimiser (lr=0.001) is built over every parameter of ``qmodel``
    whose name contains ``'scale'``. For each batch the loss is
    ``kl_categorical(qmodel(data), model(data))``; the ``norm2`` term is
    only logged, it is not part of the loss. After each step the BN fake
    weight/bias of every relation in ``res`` is recomputed from a deep copy
    of the initial ``graph`` times the current scale, and
    ``set_quant_minmax`` is re-run. Training stops after ``num_epoch``
    epochs, as soon as the loss drops below 0.02, or on Ctrl-C (in which
    case the BN values are re-synced once more).

    Args:
        qmodel: model being tuned; moved to CUDA and set to eval mode.
        model: reference model; moved to CUDA and set to eval mode.
        data_distill: sequence of input batches (tensors); their
            ``requires_grad`` flag is cleared.
        graph: mapping from node index to layer object of ``qmodel``.
        bottoms: mapping from node index to the node's inputs.
        res: relations providing ``get_idxs()``.
        targ_layer: container of module types whose ``scale`` gets a
            gradient hook.
        num_epoch: maximum number of passes over ``data_distill``.

    Returns:
        ``qmodel``.
    """
    print("Start updating scale")
    writer = SummaryWriter("./tensorboard/exp_{}/".format(round(time.time())))
    qmodel = qmodel.eval().cuda()
    model = model.eval().cuda()
    for batch in data_distill:
        batch.requires_grad = False

    graph_original = copy.deepcopy(graph)

    optimizer = torch.optim.Adam(
        [p for n, p in qmodel.named_parameters() if 'scale' in n], lr=0.001)
    num_batches = len(data_distill)

    # hook params
    hooks = []
    hook_handle = []
    try:
        for _, module in qmodel.named_modules():
            if type(module) in targ_layer and hasattr(module, 'scale'):
                has_prev = hasattr(module, 'scale_prev')
                grad_hook = GradHook(
                    module.weight,
                    module.scale,
                    module.scale_prev if has_prev else None,
                    module.merge_scale,
                    module.merge_scale_prev if has_prev else None)
                hooks.append(grad_hook)
                hook_handle.append(
                    module.scale.register_hook(
                        grad_hook.hook_mask_grad_tensor))

        try:
            # TODO: check if graph and model contains same module parameters!!!
            terminate = False
            for epoch in range(num_epoch):
                for it in range(num_batches):
                    step = epoch * num_batches + it + 1
                    data = data_distill[it].cuda()
                    with torch.no_grad():
                        logit = model(data)
                    replace_op()
                    try:
                        qlogit = qmodel(data)
                    finally:
                        # Always undo replace_op(), even if forward fails.
                        restore_op()
                    klloss = kl_categorical(qlogit, logit)

                    normloss = 0
                    for idx, hook in enumerate(hooks):
                        normloss += norm2(hook.get_weight_scaled(), idx,
                                          writer, step)
                    # With no hooks registered normloss is still the int 0.
                    norm_value = (normloss.data
                                  if torch.is_tensor(normloss) else normloss)

                    loss = klloss
                    writer.add_scalar("loss", loss.data, step)
                    writer.add_scalar("norm", norm_value, step)
                    writer.add_scalar("kldiv", klloss.data, step)
                    print(
                        "loss: {}, klloss: {}, norm: {}, iter: {}, epoch: {}".
                        format(loss.data, klloss.data, norm_value, it + 1,
                               epoch + 1))
                    optimizer.zero_grad()
                    loss.backward()
                    optimizer.step()

                    _sync_bn_with_scale(graph, graph_original, res)
                    set_quant_minmax(graph, bottoms, verbose=False)

                    if loss.data < 0.02:
                        terminate = True
                        break

                if terminate:
                    break

        except KeyboardInterrupt:
            # Leave the graph consistent with the latest scale on Ctrl-C.
            _sync_bn_with_scale(graph, graph_original, res)
    finally:
        # Release the gradient hooks and the log writer on every exit path.
        for handle in hook_handle:
            handle.remove()
        writer.close()

    return qmodel
示例#5
0
def main():
    """Evaluate MobileNetV2 after the processing steps selected by the CLI args.

    Steps, each gated by a flag from ``get_argument()``: layer switching
    (quantized layers and/or ReLU6 -> ReLU), batch-norm merging (always),
    cross-layer equalization, bias absorption, weight clipping, bias
    correction and quantization. The accuracy returned by ``inference_all``
    is printed at the end.
    """
    args = get_argument()
    assert args.relu or args.relu == args.equalize, 'must replace relu6 to relu while equalization'
    assert args.equalize or args.absorption == args.equalize, 'must use absorption with equalize'
    dummy_input = torch.ones((4, 3, 224, 224))  #.cuda()

    net = mobilenet_v2(
        'modeling/classification/mobilenetv2_1.0-f2a8633.pth.tar')
    net.eval()

    transformer = TorchTransformer()

    # Key 1 holds the conv/linear replacements, key 0 the activation swap.
    module_dict = {}
    if args.quantize:
        module_dict[1] = ([(nn.Conv2d, QuantConv2d),
                           (nn.Linear, QuantLinear)] if args.trainable else
                          [(nn.Conv2d, QuantNConv2d),
                           (nn.Linear, QuantNLinear)])
    if args.relu:
        module_dict[0] = [(torch.nn.ReLU6, torch.nn.ReLU)]

    net, transformer = switch_layers(net,
                                     transformer,
                                     dummy_input,
                                     module_dict,
                                     ignore_layer=[QuantMeasure],
                                     quant_op=args.quantize)

    log = transformer.log
    graph = log.getGraph()
    bottoms = log.getBottoms()
    output_shape = log.getOutShapes()

    # Layer types the later passes operate on.
    if not args.quantize:
        targ_layer = [nn.Conv2d, nn.Linear]
    elif args.trainable:
        targ_layer = [QuantConv2d, QuantLinear]
    else:
        targ_layer = [QuantNConv2d, QuantNLinear]

    net = merge_batchnorm(net, graph, bottoms, targ_layer)

    # create relations
    if args.equalize:
        res = create_relation(graph, bottoms, targ_layer)
        cross_layer_equalization(graph,
                                 res,
                                 targ_layer,
                                 visualize_state=False,
                                 converge_thres=2e-7)

    if args.absorption:
        bias_absorption(graph, res, bottoms, 3)

    if args.clip_weight:
        clip_weight(graph, range_clip=[-15, 15], targ_type=targ_layer)

    if args.correction:
        bias_correction(graph, bottoms, targ_layer)

    if args.quantize:
        if not args.trainable:
            graph = quantize_targ_layer(graph, targ_layer)
        set_quant_minmax(graph, bottoms, output_shape)

    net = net.cuda()
    net.eval()

    if args.quantize:
        replace_op()
    accuracy = inference_all(net)
    print("Acc: {}".format(accuracy))
    if args.quantize:
        restore_op()
示例#6
0
文件: main_ssd.py 项目: wtsitp/DFQ
            net, nms_method=args.nms_method, device=DEVICE)
    elif args.net == 'sq-ssd-lite':
        predictor = create_squeezenet_ssd_lite_predictor(
            net, nms_method=args.nms_method, device=DEVICE)
    elif args.net == 'mb2-ssd-lite':
        predictor = create_mobilenetv2_ssd_lite_predictor(
            net, nms_method=args.nms_method, device=DEVICE)
    else:
        logging.fatal(
            "The net type is wrong. It should be one of vgg16-ssd, mb1-ssd and mb1-ssd-lite."
        )
        parser.print_help(sys.stderr)
        sys.exit(1)

    if args.quantize:
        replace_op()

    results = []
    print("Start Inference")
    for i in range(len(dataset)):
        # print("process image", i)
        # timer.start("Load Image")
        image = dataset.get_image(i)
        # print("Load Image: {:4f} seconds.".format(timer.end("Load Image")))
        # timer.start("Predict")
        boxes, labels, probs = predictor.predict(image)
        # print("Prediction: {:4f} seconds.".format(timer.end("Predict")))
        indexes = torch.ones(labels.size(0), 1, dtype=torch.float32) * i
        results.append(
            torch.cat(
                [
示例#7
0
def main():
    """Evaluate a classification model after the steps selected by CLI args.

    The model is torchvision's pretrained ResNet-18 when ``args.resnet`` is
    set, otherwise MobileNetV2 loaded from a local checkpoint. Depending on
    the flags returned by ``get_argument()`` the function switches layers,
    sets layer bit widths, merges batch norm, equalizes, absorbs bias, clips
    weights, corrects bias and quantizes, then prints the accuracy returned
    by ``inference_all`` and optionally appends it to ``cls_result.txt``.
    """
    args = get_argument()
    # These asserts are the only thing guaranteeing `res` exists when
    # `args.absorption` is used further down.
    assert args.relu or args.relu == args.equalize, 'must replace relu6 to relu while equalization'
    assert args.equalize or args.absorption == args.equalize, 'must use absorption with equalize'

    # Dummy CPU input handed to the graph builder / switch_layers.
    data = torch.ones((4, 3, 224, 224))  #.cuda()

    if args.resnet:
        import torchvision.models as models
        model = models.resnet18(pretrained=True)
    else:
        model = mobilenet_v2(
            'modeling/classification/mobilenetv2_1.0-f2a8633.pth.tar')
    model.eval()

    if args.distill_range:
        import copy
        # define FP32 model
        model_original = copy.deepcopy(model)
        model_original.eval()
        transformer = TorchTransformer()
        transformer._build_graph(model_original, data, [QuantMeasure])
        # NOTE: graph/bottoms read here are overwritten after switch_layers.
        graph = transformer.log.getGraph()
        bottoms = transformer.log.getBottoms()

        if not args.true_data:
            # Generate the calibration batches from the FP32 copy.
            data_distill = getDistilData(model_original, 'imagenet', args.dis_batch_size, bn_merged=False,\
                num_batch=args.dis_num_batch, gpu=True, value_range=[-2.11790393, 2.64], size=[224, 224], early_break_factor=1.2 if args.resnet else 0.5)
        else:
            # Use real images instead.
            # NOTE(review): the dataset path is hard-coded to one machine.
            imagenet_dataset = datasets.ImageFolder(
                '/home/jakc4103/windows/Toshiba/workspace/dataset/ILSVRC/Data/CLS-LOC/train',
                transforms.Compose([
                    transforms.Resize(256),
                    transforms.CenterCrop(224),
                    transforms.ToTensor(),
                    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                         std=[0.229, 0.224, 0.225]),
                ]))
            data_distill = []
            dataloader = DataLoader(imagenet_dataset,
                                    batch_size=args.dis_batch_size,
                                    shuffle=True,
                                    num_workers=4,
                                    pin_memory=True)
            # Keep only the first `dis_num_batch` image batches (labels dropped).
            for idx, sample in enumerate(dataloader):
                if idx >= args.dis_num_batch:
                    break
                image = sample[0]
                data_distill.append(image)
            del dataloader, imagenet_dataset

    transformer = TorchTransformer()
    # Key 1 holds the conv/linear replacements, key 0 the activation swap.
    module_dict = {}
    if args.quantize:
        if args.distill_range:
            module_dict[1] = [(nn.Conv2d, QConv2d), (nn.Linear, QLinear)]
        elif args.trainable:
            module_dict[1] = [(nn.Conv2d, QuantConv2d),
                              (nn.Linear, QuantLinear)]
        else:
            module_dict[1] = [(nn.Conv2d, QuantNConv2d),
                              (nn.Linear, QuantNLinear)]

    if args.relu:
        module_dict[0] = [(torch.nn.ReLU6, torch.nn.ReLU)]

    # transformer.summary(model, data)
    # transformer.visualize(model, data, 'graph_cls', graph_size=120)

    model, transformer = switch_layers(model,
                                       transformer,
                                       data,
                                       module_dict,
                                       ignore_layer=[QuantMeasure],
                                       quant_op=args.quantize)

    graph = transformer.log.getGraph()
    bottoms = transformer.log.getBottoms()
    # Layer types the later passes operate on; mirrors module_dict[1] above.
    if args.quantize:
        if args.distill_range:
            targ_layer = [QConv2d, QLinear]
        elif args.trainable:
            targ_layer = [QuantConv2d, QuantLinear]
        else:
            targ_layer = [QuantNConv2d, QuantNLinear]
    else:
        targ_layer = [nn.Conv2d, nn.Linear]

    if args.quantize:
        set_layer_bits(graph, args.bits_weight, args.bits_activation,
                       args.bits_bias, targ_layer)

    model = merge_batchnorm(model, graph, bottoms, targ_layer)

    #create relations
    if args.equalize or args.distill_range:
        res = create_relation(graph, bottoms, targ_layer, delete_single=False)
        if args.equalize:
            cross_layer_equalization(graph,
                                     res,
                                     targ_layer,
                                     visualize_state=False,
                                     converge_thres=2e-7)

        # if args.distill:
        #     set_scale(res, graph, bottoms, targ_layer)

    if args.absorption:
        bias_absorption(graph, res, bottoms, 3)

    if args.clip_weight:
        clip_weight(graph, range_clip=[-15, 15], targ_type=targ_layer)

    if args.correction:
        # if args.distill:
        #     model_original = copy.deepcopy(model.cpu())
        #     model_original.eval()
        #     transformer = TorchTransformer()
        #     transformer.register(targ_layer[0], nn.Conv2d)
        #     transformer.register(targ_layer[1], nn.Linear)
        #     model_original = transformer.trans_layers(model_original, update=True)

        #     bias_correction_distill(model, model_original, data_distill, targ_layer, [nn.Conv2d, nn.Linear])
        # else:
        bias_correction(graph,
                        bottoms,
                        targ_layer,
                        bits_weight=args.bits_weight)

    if args.quantize:
        if not args.trainable and not args.distill_range:
            graph = quantize_targ_layer(graph, args.bits_weight,
                                        args.bits_bias, targ_layer)

        if args.distill_range:
            # Stat updating is switched on only around the calibration
            # forward passes over data_distill.
            set_update_stat(model, [QuantMeasure], True)
            model = update_quant_range(model.cuda(), data_distill, graph,
                                       bottoms)
            set_update_stat(model, [QuantMeasure], False)
        else:
            set_quant_minmax(graph, bottoms)

        torch.cuda.empty_cache()

    # if args.distill:
    #     model = update_scale(model, model_original, data_distill, graph, bottoms, res, targ_layer, num_epoch=1000)
    #     set_quant_minmax(graph, bottoms)

    model = model.cuda()
    model.eval()

    # replace_op()/restore_op() bracket the evaluation when quantizing.
    if args.quantize:
        replace_op()
    acc = inference_all(model)
    print("Acc: {}".format(acc))
    if args.quantize:
        restore_op()
    if args.log:
        # Append the flag configuration and the accuracy to the result log.
        with open("cls_result.txt", 'a+') as ww:
            ww.write(
                "resnet: {}, quant: {}, relu: {}, equalize: {}, absorption: {}, correction: {}, clip: {}, distill_range: {}\n"
                .format(args.resnet, args.quantize, args.relu, args.equalize,
                        args.absorption, args.correction, args.clip_weight,
                        args.distill_range))
            ww.write("Acc: {}\n\n".format(acc))
示例#8
0
def main():
    """Evaluate DeepLab after the processing steps selected by the CLI args.

    Loads the DeepLab checkpoint, optionally builds distilled calibration
    data from an FP32 copy, then (each step gated by a flag from
    ``get_argument()``) switches layers, sets layer bit widths, merges batch
    norm, equalizes, absorbs bias, clips weights, corrects bias and
    quantizes, and finally runs ``inference_all`` on ``args.dataset``.
    """
    args = get_argument()
    assert args.relu or args.relu == args.equalize, 'must replace relu6 to relu while equalization'
    assert args.equalize or args.absorption == args.equalize, 'must use absorption with equalize'
    dummy_input = torch.ones((4, 3, 513, 513))  #.cuda()

    net = DeepLab(sync_bn=False)
    checkpoint = torch.load('modeling/segmentation/deeplab-mobilenet.pth.tar')
    net.load_state_dict(checkpoint['state_dict'])
    net.eval()

    if args.distill_range:
        import copy
        # FP32 copy used only to generate the distilled data.
        model_original = copy.deepcopy(net)
        model_original.eval()
        transformer = TorchTransformer()
        transformer._build_graph(model_original, dummy_input, [QuantMeasure])
        graph = transformer.log.getGraph()
        bottoms = transformer.log.getBottoms()

        data_distill = getDistilData(model_original,
                                     'imagenet',
                                     32,
                                     bn_merged=False,
                                     num_batch=8,
                                     gpu=True,
                                     value_range=[-2.11790393, 2.64],
                                     size=[513, 513],
                                     early_break_factor=0.2)

    transformer = TorchTransformer()

    # Key 1 holds the conv replacement, key 0 the activation swap.
    module_dict = {}
    if args.quantize:
        if args.distill_range:
            conv_cls = QConv2d
        elif args.trainable:
            conv_cls = QuantConv2d
        else:
            conv_cls = QuantNConv2d
        module_dict[1] = [(nn.Conv2d, conv_cls)]

    if args.relu:
        module_dict[0] = [(torch.nn.ReLU6, torch.nn.ReLU)]

    net, transformer = switch_layers(net,
                                     transformer,
                                     dummy_input,
                                     module_dict,
                                     ignore_layer=[QuantMeasure],
                                     quant_op=args.quantize)
    graph = transformer.log.getGraph()
    bottoms = transformer.log.getBottoms()

    # Layer types the later passes operate on.
    if not args.quantize:
        targ_layer = [nn.Conv2d]
    elif args.distill_range:
        targ_layer = [QConv2d]
    elif args.trainable:
        targ_layer = [QuantConv2d]
    else:
        targ_layer = [QuantNConv2d]

    if args.quantize:
        set_layer_bits(graph, args.bits_weight, args.bits_activation,
                       args.bits_bias, targ_layer)
    net = merge_batchnorm(net, graph, bottoms, targ_layer)

    # create relations
    if args.equalize or args.distill_range:
        res = create_relation(graph, bottoms, targ_layer)
        if args.equalize:
            cross_layer_equalization(graph,
                                     res,
                                     targ_layer,
                                     visualize_state=False)

    if args.absorption:
        bias_absorption(graph, res, bottoms, 3)

    if args.clip_weight:
        clip_weight(graph, range_clip=[-15, 15], targ_type=targ_layer)

    if args.correction:
        bias_correction(graph, bottoms, targ_layer)

    if args.quantize:
        if not args.trainable and not args.distill_range:
            graph = quantize_targ_layer(graph, args.bits_weight,
                                        args.bits_bias, targ_layer)

        if not args.distill_range:
            set_quant_minmax(graph, bottoms)
        else:
            # Stat updating is on only around the calibration passes.
            set_update_stat(net, [QuantMeasure], True)
            net = update_quant_range(net.cuda(), data_distill, graph,
                                     bottoms)
            set_update_stat(net, [QuantMeasure], False)

        torch.cuda.empty_cache()

    net = net.cuda()
    net.eval()

    if args.quantize:
        replace_op()
    inference_all(net, args.dataset, args if args.log else None)
    if args.quantize:
        restore_op()
示例#9
0
def main():
    """Evaluate DeepLab after the processing steps selected by the CLI args.

    Loads the DeepLab checkpoint and then, each step gated by a flag from
    ``get_argument()``, switches layers, merges batch norm (always),
    equalizes, absorbs bias, clips weights, corrects bias and quantizes,
    before running ``inference_all`` on ``args.dataset``.
    """
    args = get_argument()
    assert args.relu or args.relu == args.equalize, 'must replace relu6 to relu while equalization'
    assert args.equalize or args.absorption == args.equalize, 'must use absorption with equalize'
    dummy_input = torch.ones((4, 3, 513, 513))  #.cuda()

    net = DeepLab(sync_bn=False)
    checkpoint = torch.load('modeling/segmentation/deeplab-mobilenet.pth.tar')
    net.load_state_dict(checkpoint['state_dict'])
    net.eval()

    transformer = TorchTransformer()

    # Key 1 holds the conv replacement, key 0 the activation swap.
    module_dict = {}
    if args.quantize:
        conv_cls = QuantConv2d if args.trainable else QuantNConv2d
        module_dict[1] = [(nn.Conv2d, conv_cls)]
    if args.relu:
        module_dict[0] = [(torch.nn.ReLU6, torch.nn.ReLU)]

    net, transformer = switch_layers(net,
                                     transformer,
                                     dummy_input,
                                     module_dict,
                                     ignore_layer=[QuantMeasure],
                                     quant_op=args.quantize)
    log = transformer.log
    graph = log.getGraph()
    bottoms = log.getBottoms()
    output_shape = log.getOutShapes()

    # Layer types the later passes operate on.
    if not args.quantize:
        targ_layer = [nn.Conv2d]
    elif args.trainable:
        targ_layer = [QuantConv2d]
    else:
        targ_layer = [QuantNConv2d]
    net = merge_batchnorm(net, graph, bottoms, targ_layer)

    # create relations
    if args.equalize:
        res = create_relation(graph, bottoms, targ_layer)
        cross_layer_equalization(graph,
                                 res,
                                 targ_layer,
                                 visualize_state=False)

    if args.absorption:
        bias_absorption(graph, res, bottoms, 3)

    if args.clip_weight:
        clip_weight(graph, range_clip=[-15, 15], targ_type=targ_layer)

    if args.correction:
        bias_correction(graph, bottoms, targ_layer)

    if args.quantize:
        if not args.trainable:
            graph = quantize_targ_layer(graph, targ_layer)
        set_quant_minmax(graph, bottoms, output_shape)

    net = net.cuda()
    net.eval()

    if args.quantize:
        replace_op()
    inference_all(net, args.dataset)
    if args.quantize:
        restore_op()