コード例 #1
0
def quantize(model_raw, bn_bits, param_bits, quant_method='log'):
    """Quantize every tensor in ``model_raw``'s state dict and reload it.

    Keys containing ``'running'`` use ``bn_bits`` and are left untouched when
    ``bn_bits >= 32``. BatchNorm ``num_batches_tracked`` counters are always
    copied through unchanged. Every other entry is quantized to
    ``param_bits``. Nothing happens at all when ``param_bits >= 32``.

    Args:
        model_raw: module whose ``state_dict()`` is quantized and loaded back
            with ``load_state_dict``.
        bn_bits: bit width for keys containing ``'running'``.
        param_bits: bit width for all other (non-counter) entries.
        quant_method: ``'linear'``, ``'log'`` or ``'minmax'``; any other
            value selects ``quant.tanh_quantize``.

    Returns:
        ``model_raw`` itself.
    """
    if param_bits < 32:
        state_dict = model_raw.state_dict()
        state_dict_quant = OrderedDict()
        for k, v in state_dict.items():
            # ``num_batches_tracked`` is a BatchNorm step counter, not a learned
            # value, so it must not be pushed through a weight quantizer.
            if k.endswith('num_batches_tracked'):
                state_dict_quant[k] = v
                continue
            if 'running' in k:
                if bn_bits >= 32:
                    print("Ignoring {}".format(k))
                    state_dict_quant[k] = v
                    continue
                bits = bn_bits
            else:
                bits = param_bits

            if quant_method == 'linear':
                # Fractional bits = total bits - sign bit - integer bits needed
                # for the largest magnitude in ``v``.
                sf = bits - 1. - quant.compute_integral_part(v,
                                                             overflow_rate=0.0)
                v_quant = quant.linear_quantize(v, sf, bits=bits)
            elif quant_method == 'log':
                v_quant = quant.log_minmax_quantize(v, bits=bits)
            elif quant_method == 'minmax':
                v_quant = quant.min_max_quantize(v, bits=bits)
            else:
                v_quant = quant.tanh_quantize(v, bits=bits)
            state_dict_quant[k] = v_quant

        model_raw.load_state_dict(state_dict_quant)
    return model_raw
コード例 #2
0
def q_main_layerwise():
    """Quantize the global ``model`` with per-layer bit widths and evaluate it.

    Steps:
      1. Load a hard-coded checkpoint into the global ``model`` on the GPU.
      2. Quantize each state-dict tensor with more than one dimension. Each
         such tensor takes the next bit width from
         ``setQBN(getLayerSize(model), 1)``, in state-dict order. Widths of
         32 or more leave the tensor unchanged. 0-/1-D tensors (biases,
         BatchNorm parameters and buffers) are never quantized.
      3. If ``args.fwd_bits < 32``, build an activation-quantized copy with
         ``quant.duplicate_model_with_quant`` and load its state dict back.
      4. Evaluate on ``test_loader`` and save the state dict under
         ``./model_quantized/``.

    Configuration comes from the global ``args``. This function also sets
    ``args.smooth`` and ``args.mixup``.
    """
    # bag of tricks set-ups
    criterion = CrossEntropyLossMaybeSmooth(smooth_eps=args.smooth_eps).cuda()
    args.smooth = args.smooth_eps > 0.0
    args.mixup = args.alpha > 0.0

    model_path = "../cifar100_mobilenetv217_retrained_acc_80.510_config_mobile_v2_0.7_threshold.pt"
    print("\n>_ Loading file... {}".format(model_path))
    model.load_state_dict(torch.load(model_path))
    model.type(torch.cuda.FloatTensor)
    model.cuda()

    # Quantize parameters. QBNs holds one bit width per multi-dimensional
    # tensor and is consumed in state-dict order. This presumably matches the
    # ordering getLayerSize uses; verify against its definition.
    state_dict = model.state_dict()
    state_dict_quant = OrderedDict()
    LayerSize = getLayerSize(model)
    QBNs = setQBN(LayerSize, 1)
    print(QBNs)
    count = -1
    for k, v in state_dict.items():
        v_quant = v
        if len(v.size()) > 1:  # skip biases and other 0-/1-D tensors
            count += 1
            bits = QBNs[count]
            if bits < 32:
                if args.quant_method == 'linear':
                    sf = bits - 1. - quant.compute_integral_part(v, overflow_rate=args.overflow_rate)
                    v_quant = quant.linear_quantize(v, sf, bits=bits)
                elif args.quant_method == 'log':
                    v_quant = quant.log_minmax_quantize(v, bits=bits)
                elif args.quant_method == 'minmax':
                    v_quant = quant.min_max_quantize(v, bits=bits)
                else:
                    v_quant = quant.tanh_quantize(v, bits=bits)
        state_dict_quant[k] = v_quant
    model.load_state_dict(state_dict_quant)

    # Quantize forward activations.
    # NOTE(review): only the *state dict* of the returned model is copied back
    # into ``model``. Confirm that the activation quantizers actually end up in
    # ``model`` (e.g. because the helper mutates it in place). Otherwise this
    # step has no effect on the evaluation below.
    if args.fwd_bits < 32:
        model_new = quant.duplicate_model_with_quant(model, bits=args.fwd_bits, overflow_rate=args.overflow_rate,
                                                     counter=args.n_sample, type=args.quant_method)
        model.load_state_dict(model_new.state_dict())

    _, prec1 = test(model, criterion, test_loader)
    print("\n>_ current quantized model's accuracy{:.3f}% now...\n".format(prec1))
    torch.save(model.state_dict(), "./model_quantized/cifar100_mobilenetv217_retrained_acc_80.510{}{}_quantized_acc_{:.3f}_{}_{}.pt".format(
                args.arch, args.depth, prec1, args.config_file, args.sparsity_type))
コード例 #3
0
def hard_quant(weight, bits):
    """Return ``weight`` quantized to ``bits`` bits.

    The scheme is taken from the global ``args.quant_method``:

    * ``'linear'``: fixed point, with the fractional bit count derived from
      ``quant.compute_integral_part`` and ``args.overflow_rate``.
    * ``'log'``: ``quant.log_minmax_quantize``.
    * ``'minmax'``: ``quant.min_max_quantize``.
    * anything else: ``quant.tanh_quantize``.
    """
    method = args.quant_method
    if method == 'log':
        return quant.log_minmax_quantize(weight, bits=bits)
    if method == 'minmax':
        return quant.min_max_quantize(weight, bits=bits)
    if method == 'linear':
        integral_bits = quant.compute_integral_part(weight, overflow_rate=args.overflow_rate)
        return quant.linear_quantize(weight, bits - 1. - integral_bits, bits=bits)
    return quant.tanh_quantize(weight, bits=bits)
コード例 #4
0
def quantize_CNN(model_raw, quant_method='log'):
    """Quantize only the 4-D (convolution-kernel) tensors of ``model_raw``.

    Bit widths come from the global ``args``:

    * ``args.param_bits`` is used for 4-D tensors.
    * ``args.bn_bits`` is used for 4-D tensors whose key contains
      ``'running'``; these are left untouched when ``args.bn_bits >= 32``.

    Every tensor that is not 4-D is copied through with its exact original
    values. Nothing happens at all when ``args.param_bits >= 32``.

    Args:
        model_raw: module whose ``state_dict()`` is quantized and loaded back
            with ``load_state_dict``.
        quant_method: ``'linear'``, ``'log'`` or ``'minmax'``; any other
            value selects ``quant.tanh_quantize``.

    Returns:
        ``model_raw`` itself.
    """
    bn_bits = args.bn_bits
    param_bits = args.param_bits

    if param_bits < 32:
        state_dict = model_raw.state_dict()
        state_dict_quant = OrderedDict()
        for k, v in state_dict.items():
            if len(v.size()) != 4:
                # Not a conv kernel: keep it bit-exact. Sending it through a
                # quantizer "at 32 bits" is not guaranteed to be lossless.
                state_dict_quant[k] = v
                continue

            print('quantize: {}'.format(v.size()))
            if 'running' in k:
                if bn_bits >= 32:
                    state_dict_quant[k] = v
                    continue
                bits = bn_bits
            else:
                bits = param_bits

            if quant_method == 'linear':
                # Fractional bits = total bits - sign bit - integer bits.
                sf = bits - 1. - quant.compute_integral_part(
                    v, overflow_rate=0.0)
                v_quant = quant.linear_quantize(v, sf, bits=bits)
            elif quant_method == 'log':
                v_quant = quant.log_minmax_quantize(v, bits=bits)
            elif quant_method == 'minmax':
                v_quant = quant.min_max_quantize(v, bits=bits)
            else:
                v_quant = quant.tanh_quantize(v, bits=bits)
            state_dict_quant[k] = v_quant

        model_raw.load_state_dict(state_dict_quant)
    return model_raw
コード例 #5
0
def quantized_retrain(criterion, optimizer, scheduler):
    """Load a quantized checkpoint, re-quantize it and retrain it.

    Steps:
      1. Load a hard-coded checkpoint into the global ``model`` on the GPU.
      2. If ``args.param_bits < 32``, quantize the state dict with
         ``args.quant_method``. Keys containing ``'running'`` use
         ``args.bn_bits`` and are skipped when it is >= 32. BatchNorm
         ``num_batches_tracked`` counters are never quantized. All other
         entries use ``args.param_bits``.
      3. If ``args.fwd_bits < 32``, build an activation-quantized copy with
         ``quant.duplicate_model_with_quant`` and load its state dict back.
      4. Train for ``args.epochs`` epochs with ``qtrain`` and evaluate with
         ``test`` after each one. Only the best checkpoint of this run is
         kept under ``./model_retrained/``.
      5. Report sparsity, then save per-epoch losses and test accuracies
         under ``./plotable/``.

    Args:
        criterion: loss passed to ``qtrain`` and ``test``.
        optimizer: optimizer passed to ``qtrain``.
        scheduler: learning-rate scheduler passed to ``qtrain``.
    """
    model_path = "../cifar100_mobilenetv217_retrained_acc_80.510mobilenetv217_quantized_acc_80.180_config_vgg16_threshold.pt"
    print("\n>_ Loading file... {}".format(model_path))
    model.load_state_dict(torch.load(model_path))
    model.type(torch.cuda.FloatTensor)
    model.cuda()
    best_prec1 = [0]

    epoch_loss_dict = {}
    testAcc = []

    # Quantize parameters.
    if args.param_bits < 32:
        state_dict = model.state_dict()
        state_dict_quant = OrderedDict()
        for k, v in state_dict.items():
            # BatchNorm step counter, not a learned value: keep as-is.
            if k.endswith('num_batches_tracked'):
                state_dict_quant[k] = v
                continue
            if 'running' in k:
                if args.bn_bits >= 32:
                    state_dict_quant[k] = v
                    continue
                bits = args.bn_bits
            else:
                bits = args.param_bits

            if args.quant_method == 'linear':
                sf = bits - 1. - quant.compute_integral_part(
                    v, overflow_rate=args.overflow_rate)
                v_quant = quant.linear_quantize(v, sf, bits=bits)
            elif args.quant_method == 'log':
                v_quant = quant.log_minmax_quantize(v, bits=bits)
            elif args.quant_method == 'minmax':
                v_quant = quant.min_max_quantize(v, bits=bits)
            else:
                v_quant = quant.tanh_quantize(v, bits=bits)
            state_dict_quant[k] = v_quant
        model.load_state_dict(state_dict_quant)

    # Quantize forward activations.
    # NOTE(review): only the *state dict* of the returned model is copied back
    # into ``model``. Confirm that the activation quantizers actually end up in
    # ``model`` (e.g. because the helper mutates it in place).
    if args.fwd_bits < 32:
        new_model = quant.duplicate_model_with_quant(
            model,
            bits=args.fwd_bits,
            overflow_rate=args.overflow_rate,
            counter=args.n_sample,
            type=args.quant_method)
        model.load_state_dict(new_model.state_dict())

    # Retrain. ``saved_path`` is the checkpoint this run last wrote, so only
    # that file is ever deleted.
    saved_path = None
    for epoch in range(1, args.epochs + 1):
        idx_loss_dict = qtrain(train_loader,
                               criterion,
                               optimizer,
                               scheduler,
                               epoch,
                               model,
                               args,
                               layers=None,
                               rew_layers=None,
                               eps=None)
        _, prec1 = test(model, criterion, test_loader)

        if prec1 > max(best_prec1):
            print(
                "\n>_ Got better accuracy, saving model with accuracy {:.3f}% now...\n"
                .format(prec1))
            new_path = (
                "./model_retrained/quantized_cifar100_{}{}_retrained_acc_{:.3f}_{}_{}.pt"
                .format(args.arch, args.depth, prec1, args.config_file,
                        args.sparsity_type))
            torch.save(model.state_dict(), new_path)
            # Two accuracies that round to the same 3 decimals produce the same
            # file name. Never delete the checkpoint that was just written.
            if saved_path is not None and saved_path != new_path:
                print(
                    "\n>_ Deleting previous model file with accuracy {:.3f}% now...\n"
                    .format(max(best_prec1)))
                os.remove(saved_path)
            saved_path = new_path

        epoch_loss_dict[epoch] = idx_loss_dict
        testAcc.append(prec1)

        best_prec1.append(prec1)
        print("current best acc is: {:.4f}".format(max(best_prec1)))

    test_column_sparsity(model)
    test_filter_sparsity(model)

    print("Best Acc: {:.4f}%".format(max(best_prec1)))
    np.save(
        strftime("./plotable/%m-%d-%Y-%H:%M_plotable_{}.npy".format(
            args.sparsity_type)), epoch_loss_dict)
    np.save(
        strftime("./plotable/%m-%d-%Y-%H:%M_testAcc_{}.npy".format(
            args.sparsity_type)), testAcc)