import math
import os
from collections import OrderedDict
from time import strftime

import numpy as np
import torch
import torch.nn as nn

import quant

# NOTE: `args`, `model`, the data loaders, and helpers such as test(), qtrain(),
# CrossEntropyLossMaybeSmooth, getLayerSize(), setQBN(), test_column_sparsity()
# and test_filter_sparsity() are assumed to be defined in the surrounding
# training script.


def fuse_conv_bn(conv, bn, bits=8):
    """Fold an eval-mode BatchNorm into the preceding conv, then quantize."""
    w = conv.weight
    mean = bn.running_mean
    var_sqrt = torch.sqrt(bn.running_var + bn.eps)
    beta = bn.weight
    gamma = bn.bias
    if conv.bias is not None:
        b = conv.bias
    else:
        b = mean.new_zeros(mean.shape)
    # standard BN folding: scale weights per output channel, shift the bias
    w = w * (beta / var_sqrt).reshape([conv.out_channels, 1, 1, 1])
    b = (b - mean) / var_sqrt * beta + gamma
    fused_conv = nn.Conv2d(conv.in_channels,
                           conv.out_channels,
                           conv.kernel_size,
                           conv.stride,
                           conv.padding,
                           bias=True)
    fused_conv.mask = conv.mask  # carry the pruning mask over to the fused layer
    # linear_quantize expects the shift factor sf itself, not the step size
    # 2**-sf (this matches every other call site in this file)
    sf = bits - 1. - quant.compute_integral_part(w, 0.0)
    w = quant.linear_quantize(w, sf, bits)
    fused_conv.weight = nn.Parameter(w)
    sf = bits - 1. - quant.compute_integral_part(b, 0.0)
    b = quant.linear_quantize(b, sf, bits)
    fused_conv.bias = nn.Parameter(b)
    return fused_conv
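
# Hypothetical usage sketch (not part of the original code): fold an eval-mode
# BatchNorm into the conv that precedes it and check that the two paths agree
# up to 8-bit weight quantization error. The ones mask stands in for whatever
# pruning mask the surrounding code attaches to conv layers.
def _demo_fuse_conv_bn():
    conv = nn.Conv2d(3, 16, 3, stride=1, padding=1, bias=False)
    conv.mask = torch.ones_like(conv.weight)  # fuse_conv_bn expects a pruning mask
    bn = nn.BatchNorm2d(16)
    bn.eval()  # fusion is only valid with fixed running statistics
    fused = fuse_conv_bn(conv, bn, bits=8)
    x = torch.randn(1, 3, 32, 32)
    print((bn(conv(x)) - fused(x)).abs().max())  # small quantization error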
def quant_linear(linear, bits=8):
    """Quantize the weights and bias of a fully connected layer in place."""
    w, b = linear.weight.data, linear.bias.data
    # as above, pass the shift factor sf directly to linear_quantize
    sf = bits - 1. - quant.compute_integral_part(w, 0.0)
    linear.weight.data = quant.linear_quantize(w, sf, bits)
    sf = bits - 1. - quant.compute_integral_part(b, 0.0)
    linear.bias.data = quant.linear_quantize(b, sf, bits)
    return linear
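
# For reference only: a minimal sketch of what the pytorch-playground-style
# `quant.linear_quantize` used above is assumed to compute (the real module may
# differ in details). It takes the shift factor sf, derives the step size
# delta = 2**-sf, then rounds and clamps to the signed `bits`-bit grid, which
# is why the call sites above pass sf rather than the step size.
def _linear_quantize_sketch(t, sf, bits):
    delta = math.pow(2.0, -sf)        # step size implied by the shift factor
    bound = math.pow(2.0, bits - 1)   # representable signed range: [-bound, bound - 1]
    return torch.clamp(torch.round(t / delta), -bound, bound - 1) * delta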
def quantize(model_raw, bn_bits, param_bits, quant_method='log'):
    """Post-training quantization of every tensor in the model's state dict."""
    if param_bits < 32:
        state_dict = model_raw.state_dict()
        state_dict_quant = OrderedDict()
        for k, v in state_dict.items():
            if 'running' in k:  # BatchNorm running statistics
                if bn_bits >= 32:
                    print("Ignoring {}".format(k))
                    state_dict_quant[k] = v
                    continue
                bits = bn_bits
            else:
                bits = param_bits
            if quant_method == 'linear':
                sf = bits - 1. - quant.compute_integral_part(v, overflow_rate=0.0)
                v_quant = quant.linear_quantize(v, sf, bits=bits)
            elif quant_method == 'log':
                v_quant = quant.log_minmax_quantize(v, bits=bits)
            elif quant_method == 'minmax':
                v_quant = quant.min_max_quantize(v, bits=bits)
            else:
                v_quant = quant.tanh_quantize(v, bits=bits)
            state_dict_quant[k] = v_quant
        model_raw.load_state_dict(state_dict_quant)
    return model_raw
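
# Hypothetical usage sketch: 8-bit parameters with BatchNorm statistics left
# untouched (bn_bits >= 32 short-circuits the 'running' branch above). Caveat:
# integer buffers such as BN's num_batches_tracked do not contain 'running' in
# their key, so a model with BatchNorm may need an extra guard there.
def _demo_quantize():
    m = nn.Sequential(nn.Conv2d(3, 8, 3), nn.ReLU(), nn.Conv2d(8, 8, 3))
    return quantize(m, bn_bits=32, param_bits=8, quant_method='minmax')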
def q_main_layerwise():
    """Bag-of-tricks set-ups, then layer-wise post-training quantization."""
    criterion = CrossEntropyLossMaybeSmooth(smooth_eps=args.smooth_eps).cuda()
    args.smooth = args.smooth_eps > 0.0
    args.mixup = args.alpha > 0.0

    model_path = "../cifar100_mobilenetv217_retrained_acc_80.510_config_mobile_v2_0.7_threshold.pt"
    print("\n>_ Loading file... {}".format(model_path))
    model.load_state_dict(torch.load(model_path))
    model.type(torch.cuda.FloatTensor)
    model.cuda()

    # quantize parameters with a per-layer bit width
    state_dict = model.state_dict()
    state_dict_quant = OrderedDict()
    LayerSize = getLayerSize(model)
    QBNs = setQBN(LayerSize, 1)  # per-layer bit-width assignment
    print(QBNs)
    count = -1
    for k, v in state_dict.items():
        sizes = len(list(v.size()))
        if sizes > 1:  # skip biases and other 1-D tensors
            count = count + 1
            if QBNs[count] < 32:
                bits = QBNs[count]
                if args.quant_method == 'linear':
                    sf = bits - 1. - quant.compute_integral_part(
                        v, overflow_rate=args.overflow_rate)
                    v_quant = quant.linear_quantize(v, sf, bits=bits)
                elif args.quant_method == 'log':
                    v_quant = quant.log_minmax_quantize(v, bits=bits)
                elif args.quant_method == 'minmax':
                    v_quant = quant.min_max_quantize(v, bits=bits)
                else:
                    v_quant = quant.tanh_quantize(v, bits=bits)
            else:
                v_quant = v
        else:
            v_quant = v
        state_dict_quant[k] = v_quant
    model.load_state_dict(state_dict_quant)

    # quantize forward activations
    if args.fwd_bits < 32:
        model_new = quant.duplicate_model_with_quant(model,
                                                     bits=args.fwd_bits,
                                                     overflow_rate=args.overflow_rate,
                                                     counter=args.n_sample,
                                                     type=args.quant_method)
        model.load_state_dict(model_new.state_dict())

    t_loss, prec1 = test(model, criterion, test_loader)
    print("\n>_ current quantized model's accuracy {:.3f}% now...\n".format(prec1))
    torch.save(
        model.state_dict(),
        "./model_quantized/cifar100_mobilenetv217_retrained_acc_80.510{}{}_quantized_acc_{:.3f}_{}_{}.pt"
        .format(args.arch, args.depth, prec1, args.config_file, args.sparsity_type))
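
# `getLayerSize` and `setQBN` are helpers not defined in this file; these are
# hypothetical sketches of the interface implied above (one entry per tensor
# with more than one dimension, and a bit-width list where 32 means "leave at
# full precision"; the meaning of setQBN's second argument is not recoverable
# from this file):
def _getLayerSize_sketch(model):
    return [v.numel() for v in model.state_dict().values() if v.dim() > 1]


def _setQBN_sketch(layer_sizes, _mode):
    return [8 for _ in layer_sizes]  # e.g. a flat 8-bit assignment per layer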
def hard_quant(weight, bits):
    """Quantize a single weight tensor with the method selected in args."""
    if args.quant_method == 'linear':
        sf = bits - 1. - quant.compute_integral_part(
            weight, overflow_rate=args.overflow_rate)
        weight_q = quant.linear_quantize(weight, sf, bits=bits)
    elif args.quant_method == 'log':
        weight_q = quant.log_minmax_quantize(weight, bits=bits)
    elif args.quant_method == 'minmax':
        weight_q = quant.min_max_quantize(weight, bits=bits)
    else:
        weight_q = quant.tanh_quantize(weight, bits=bits)
    return weight_q
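
# Hypothetical usage sketch: snap every multi-dimensional parameter back onto
# the quantization grid, e.g. after a full-precision optimizer step.
def _requantize_all(model, bits=8):
    with torch.no_grad():
        for p in model.parameters():
            if p.dim() > 1:  # same "weights only" convention as above
                p.copy_(hard_quant(p.data, bits))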
def quantize_CNN(model_raw, quant_method='log'):
    """Quantize only 4-D (conv) tensors; keep everything else at 32 bits."""
    if args.param_bits < 32:
        state_dict = model_raw.state_dict()
        state_dict_quant = OrderedDict()
        for k, v in state_dict.items():
            if len(v.size()) == 4:  # conv weight: use the requested bit widths
                print('quantize: {}'.format(v.size()))
                bn_bits = args.bn_bits
                param_bits = args.param_bits
            else:  # not a conv weight: effectively left at full precision
                bn_bits = 32
                param_bits = 32
            if 'running' in k:
                if bn_bits >= 32:
                    state_dict_quant[k] = v
                    continue
                bits = bn_bits
            else:
                bits = param_bits
            if quant_method == 'linear':
                sf = bits - 1. - quant.compute_integral_part(v, overflow_rate=0.0)
                v_quant = quant.linear_quantize(v, sf, bits=bits)
            elif quant_method == 'log':
                v_quant = quant.log_minmax_quantize(v, bits=bits)
            elif quant_method == 'minmax':
                v_quant = quant.min_max_quantize(v, bits=bits)
            else:
                v_quant = quant.tanh_quantize(v, bits=bits)
            state_dict_quant[k] = v_quant
        model_raw.load_state_dict(state_dict_quant)
    return model_raw
# -*- coding: utf-8 -*-
"""
Spyder Editor

This is a temporary script file.
"""
import torch
from torch.autograd import Variable

import quant

# toy input covering a wide dynamic range
input = Variable(
    torch.FloatTensor([[12.3000], [23.5000], [129.2000], [-293.1000]]))
bits = 12

# linear_quantize: shift factor from the integral part at 10% overflow rate
sf = quant.compute_integral_part(input, 0.1)
sf_linear = bits - 1 - sf
linear_q_res = quant.linear_quantize(input, sf_linear, bits)

# min_max_quantize
minmax_q_res = quant.min_max_quantize(input, bits)
print("minmax_q_res\n", minmax_q_res)

# log_minmax_quantize
log_q_res = quant.log_minmax_quantize(input, bits)
print(log_q_res)
def quantized_retrain(criterion, optimizer, scheduler):
    """Quantize a pruned model, then retrain it and keep the best checkpoint."""
    model_path = "../cifar100_mobilenetv217_retrained_acc_80.510mobilenetv217_quantized_acc_80.180_config_vgg16_threshold.pt"
    print("\n>_ Loading file... {}".format(model_path))
    model.load_state_dict(torch.load(model_path))
    model.type(torch.cuda.FloatTensor)
    model.cuda()

    best_prec1 = [0]
    epoch_loss_dict = {}
    testAcc = []

    # quantize parameters
    if args.param_bits < 32:
        state_dict = model.state_dict()
        state_dict_quant = OrderedDict()
        for k, v in state_dict.items():
            if 'running' in k:
                if args.bn_bits >= 32:
                    state_dict_quant[k] = v
                    continue
                bits = args.bn_bits
            else:
                bits = args.param_bits
            if args.quant_method == 'linear':
                sf = bits - 1. - quant.compute_integral_part(
                    v, overflow_rate=args.overflow_rate)
                v_quant = quant.linear_quantize(v, sf, bits=bits)
            elif args.quant_method == 'log':
                v_quant = quant.log_minmax_quantize(v, bits=bits)
            elif args.quant_method == 'minmax':
                v_quant = quant.min_max_quantize(v, bits=bits)
            else:
                v_quant = quant.tanh_quantize(v, bits=bits)
            state_dict_quant[k] = v_quant
        model.load_state_dict(state_dict_quant)

    # quantize forward activations
    if args.fwd_bits < 32:
        new_model = quant.duplicate_model_with_quant(model,
                                                     bits=args.fwd_bits,
                                                     overflow_rate=args.overflow_rate,
                                                     counter=args.n_sample,
                                                     type=args.quant_method)
        model.load_state_dict(new_model.state_dict())

    # retrain
    for epoch in range(1, args.epochs + 1):
        idx_loss_dict = qtrain(train_loader, criterion, optimizer, scheduler,
                               epoch, model, args, layers=None,
                               rew_layers=None, eps=None)
        t_loss, prec1 = test(model, criterion, test_loader)
        if prec1 > max(best_prec1):
            print("\n>_ Got better accuracy, saving model with accuracy {:.3f}% now...\n"
                  .format(prec1))
            torch.save(
                model.state_dict(),
                "./model_retrained/quantized_cifar100_{}{}_retrained_acc_{:.3f}_{}_{}.pt"
                .format(args.arch, args.depth, prec1, args.config_file,
                        args.sparsity_type))
            print("\n>_ Deleting previous model file with accuracy {:.3f}% now...\n"
                  .format(max(best_prec1)))
            if len(best_prec1) > 1:
                os.remove(
                    "./model_retrained/quantized_cifar100_{}{}_retrained_acc_{:.3f}_{}_{}.pt"
                    .format(args.arch, args.depth, max(best_prec1),
                            args.config_file, args.sparsity_type))
        epoch_loss_dict[epoch] = idx_loss_dict
        testAcc.append(prec1)
        best_prec1.append(prec1)
        print("current best acc is: {:.4f}".format(max(best_prec1)))
        test_column_sparsity(model)
        test_filter_sparsity(model)

    print("Best Acc: {:.4f}%".format(max(best_prec1)))
    np.save(strftime("./plotable/%m-%d-%Y-%H:%M_plotable_{}.npy"
                     .format(args.sparsity_type)), epoch_loss_dict)
    np.save(strftime("./plotable/%m-%d-%Y-%H:%M_testAcc_{}.npy"
                     .format(args.sparsity_type)), testAcc)
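
# `qtrain` performs full-precision updates, so the retrained weights drift off
# the quantization grid between epochs. A hypothetical per-step variant that
# keeps them on the grid, reusing _requantize_all from the hard_quant sketch:
def _quantized_retrain_step_sketch(model, criterion, optimizer, x, y, bits=8):
    optimizer.zero_grad()
    loss = criterion(model(x), y)
    loss.backward()
    optimizer.step()               # full-precision gradient update
    _requantize_all(model, bits)   # snap weights back onto the grid
    return loss.item()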