def extract_save_quant_state_dict(model, all_names, filename='int_state_dict.pth.tar'):
    """Quantize a model's weight/bias entries and save the resulting state dict.

    Every state-dict entry whose key contains ``'weight'`` is quantized with
    4 bits when its layer name (the key without the trailing ``'.weight'``)
    is listed in ``all_names`` and with 8 bits otherwise. Entries whose key
    contains ``'bias'`` are quantized with twice the bit-width of their layer.
    All other entries (e.g. normalization statistics) are stored unchanged.

    Args:
        model: Object exposing ``state_dict()``.
        all_names: Collection of layer names that use 4-bit weights.
        filename: Destination path handed to ``torch.save``.

    Returns:
        The state dict with weight/bias entries replaced by quantized values.
    """
    state_dict = model.state_dict()

    # Iterate over a snapshot so entries can be reassigned inside the loop.
    for key, val in list(state_dict.items()):
        val_q = None

        if 'weight' in key:
            # key[:-7] strips the trailing '.weight' to get the layer name.
            num_bits = 4 if key[:-7] in all_names else 8
            weight_qparams = calculate_qparams(
                val, num_bits=num_bits, flatten_dims=(1, -1), reduce_dim=None)
            val_q = quantize(val, qparams=weight_qparams, dequantize=False)
            # NOTE(review): presumably qparams is (range, zero_point, num_bits);
            # confirm against calculate_qparams.
            zero_point = (-weight_qparams[1] / weight_qparams[0]
                          * (2**weight_qparams[2] - 1)).round()
            val_q = val_q - zero_point
            # Report the fraction of quantized weights that are exactly zero.
            print(val_q.eq(0).sum().float().div(val_q.numel()))

        if 'bias' in key:
            # Derive the bit-width from this bias's own layer name instead of
            # relying on whatever a previous loop iteration left in num_bits.
            num_bits = 4 if key[:-5] in all_names else 8
            val_q = quantize(val, num_bits=num_bits * 2, flatten_dims=(0, -1))

        # Only overwrite entries that were actually quantized; writing back
        # unconditionally would clobber other entries with a stale value.
        if val_q is not None:
            state_dict[key] = val_q

    torch.save(state_dict, filename)
    return state_dict
def sketch_layer(model, index, compre_ratio=0.8):
    """Replace one layer's weights with their sketched 8-bit approximation.

    The weight of layer ``index`` (as returned by ``get_layer_list``) is
    quantized to 8 bits over its own min/max range, multiplied by the layer's
    ``mask``, flattened to 2-D (first dimension kept), passed through
    ``sketch_transform`` and written back into ``weight.data`` on the device
    the weight originally lived on. The mean quantized magnitude and the
    mean/std of the absolute difference introduced by sketching are logged.

    NOTE(review): another ``sketch_layer`` is defined later in this source;
    if both live in the same module the later one shadows this one.

    Args:
        model: Model passed to ``get_layer_list``.
        index: Position of the target layer in the layer list.
        compre_ratio: Ratio forwarded to ``sketch_transform``.
    """
    layer = get_layer_list(model)[index]
    wei = layer.weight
    device = wei.device

    w_min = float(wei.min())
    w_max = float(wei.max())
    # Width of one quantization bin for an 8-bit grid over [w_min, w_max].
    float_step = (w_max - w_min) / (2.**8 - 1.)

    quan_wei = quantize(wei, num_bits=8, min_value=w_min, max_value=w_max) * layer.mask
    quan_data = quan_wei.data
    shape = tuple(quan_data.shape)

    logging.info('Magnitude of weights: %s', quan_data.abs().mean().item())

    # Keep the first dimension and flatten the rest, so weights of any rank
    # >= 1 are accepted rather than only 4-D ones.
    flat = quan_data.cpu().numpy().reshape((shape[0], -1))
    sketched = sketch_transform(flat, 8, compre_ratio, float_step).reshape(shape)

    # Tensor.cuda() returns a copy and does not move the tensor in place, so
    # put the new data directly on the weight's original device instead.
    wei.data = torch.Tensor(sketched).to(device)

    distance = torch.abs(wei.data - quan_data.to(device))
    logging.info('Magnitude of distance: %s', distance.mean().item())
    logging.info('STD of distance: %s', distance.std().item())
def sketch_layer(model, index, compre_ratio=0.6, dump_path=None):
    """Replace one layer's weights with their sketched 8-bit approximation.

    The weight of layer ``index`` (as returned by ``get_layer_list``) is
    quantized to 8 bits over its own min/max range, flattened to 2-D (first
    dimension kept), passed through ``sketch_transform`` and written back
    into ``weight.data`` on the device the weight originally lived on.

    NOTE(review): an earlier ``sketch_layer`` in this source also applies the
    layer's ``mask`` and logs statistics; if both live in the same module this
    definition shadows it. Confirm which variant is intended.

    Args:
        model: Model passed to ``get_layer_list``.
        index: Position of the target layer in the layer list.
        compre_ratio: Ratio forwarded to ``sketch_transform``; defaults to
            the previously hard-coded 0.6.
        dump_path: Optional path; when given, the flattened quantized weights
            are written there with ``np.save`` (formerly an unconditional
            debug dump to ``'to_yang'``).
    """
    wei = get_layer_list(model)[index].weight
    device = wei.device

    w_min = float(wei.min())
    w_max = float(wei.max())
    # Width of one quantization bin for an 8-bit grid over [w_min, w_max].
    float_step = (w_max - w_min) / (2.**8 - 1.)

    quan_wei = quantize(wei, num_bits=8, min_value=w_min, max_value=w_max)
    shape = tuple(quan_wei.shape)

    # Keep the first dimension and flatten the rest, so weights of any rank
    # >= 1 are accepted rather than only 4-D ones.
    flat = quan_wei.data.cpu().numpy().reshape((shape[0], -1))
    if dump_path is not None:
        np.save(dump_path, flat)

    sketched = sketch_transform(flat, 8, compre_ratio, float_step).reshape(shape)

    # Tensor.cuda() returns a copy and does not move the tensor in place, so
    # put the new data directly on the weight's original device instead.
    wei.data = torch.Tensor(sketched).to(device)