def _clamp_n_centroids(n_centroids, sizes, block_size, threshold):
    """Clamp the number of centroids for stability.

    The cap is the largest power of two (among 2**0 .. 2**15) that is strictly
    smaller than ``n_vectors / threshold``, where ``n_vectors`` is the number
    of weights in the layer divided by ``block_size``.

    Args:
        n_centroids: requested number of centroids.
        sizes: size of the weight matrix (any sequence of ints).
        block_size: number of weights per quantized block.
        threshold: divisor applied to the number of vectors before clamping.

    Returns:
        ``min(n_centroids, cap)``.
    """
    powers = 2 ** np.arange(0, 16, 1)
    n_vectors = np.prod(sizes) / block_size
    idx_power = bisect_left(powers, n_vectors / threshold)
    # bisect_left returns 0 for very small layers; without the max() the index
    # -1 would wrap around to the LARGEST power and disable the clamp entirely
    return min(n_centroids, powers[max(idx_power - 1, 0)])


def _layer_quantization_config(sizes, args):
    """Derive the quantization hyper-parameters of one layer from its weight size.

    Used both while quantizing (Step 1) and while exporting the compressed
    state dict, so that the two places cannot disagree on ``n_centroids``.

    Args:
        sizes: size of the layer weight; 4 dimensions means convolution,
            otherwise the weight is treated as fully connected (2 dimensions).
        args: namespace providing ``block_size_{cv,pw,fc}``,
            ``n_centroids_{cv,pw,fc}`` and ``n_centroids_threshold``.

    Returns:
        Tuple ``(is_conv, k, block_size, n_centroids, n_blocks)``.
    """
    is_conv = len(sizes) == 4

    # block size, distinguish between fully connected and convolutional case
    if is_conv:
        _, in_features, k, _ = sizes
        # k > 1: regular convolution, k == 1: pointwise convolution
        block_size = args.block_size_cv if k > 1 else args.block_size_pw
        n_centroids = args.n_centroids_cv if k > 1 else args.n_centroids_pw
        n_blocks = in_features * k * k // block_size
    else:
        k = 1
        _, in_features = sizes
        block_size = args.block_size_fc
        n_centroids = args.n_centroids_fc
        n_blocks = in_features // block_size

    # clamp number of centroids for stability
    n_centroids = _clamp_n_centroids(n_centroids, sizes, block_size, args.n_centroids_threshold)

    return is_conv, k, block_size, n_centroids, n_blocks


def main():
    """Quantize a pretrained network layer by layer, then finetune and save it.

    Step 1 iterates over the selected layers: it gathers input activations,
    quantizes the layer weight with ``PQ`` (or reloads a stored quantization
    from ``args.restart``), and Step 2 finetunes the centroids of that layer
    against the teacher. Step 3 finetunes all centroids jointly, and the
    result is written to ``state_dict_compressed.pth`` inside ``args.save``.

    Reads its configuration from the module-level ``parser`` and stores the
    parsed namespace in the global ``args``.
    """
    # get arguments
    global args
    args = parser.parse_args()
    args.block = '' if args.block == 'all' else args.block

    # student model to quantize
    student = models.__dict__[args.model](pretrained=True).cuda()
    student.eval()
    criterion = nn.CrossEntropyLoss().cuda()
    cudnn.benchmark = True

    # layers to quantize (we do not quantize the first 7x7 convolution layer)
    watcher = ActivationWatcher(student)
    layers = [layer for layer in watcher.layers[1:] if args.block in layer]

    # data loading code
    train_loader, test_loader = load_data(data_path=args.data_path, batch_size=args.batch_size,
                                          nb_workers=args.n_workers)

    # parameters for the centroids optimizer
    opt_centroids_params_all = []

    # book-keeping for compression statistics (in MB)
    size_uncompressed = compute_size(student)
    size_index = 0
    size_centroids = 0
    size_other = size_uncompressed

    # teacher model
    teacher = models.__dict__[args.model](pretrained=True).cuda()
    teacher.eval()

    # Step 1: iteratively quantize the network layers (quantization + layer-wise centroids distillation)
    print('Step 1: Quantize network')
    t = time.time()
    top_1 = 0

    # number of batches needed to gather the activations (same for every layer)
    n_iter_activations = math.ceil(args.n_activations / args.batch_size)

    for layer in layers:
        # gather input activations
        watcher = ActivationWatcher(student, layer=layer)
        in_activations_current = watcher.watch(train_loader, criterion, n_iter_activations)
        in_activations_current = in_activations_current[layer]

        # get weight matrix and detach it from the computation graph
        # (.data should be enough, adding .detach() as a safeguard)
        M = attrgetter(layer + '.weight.data')(student).detach()
        sizes = M.size()

        # block size, number of centroids and number of blocks for this layer
        is_conv, k, block_size, n_centroids, n_blocks = _layer_quantization_config(sizes, args)

        # get padding and stride attributes
        module = attrgetter(layer)(student)
        padding = module.padding if is_conv else 0
        stride = module.stride if is_conv else 1
        groups = module.groups if is_conv else 1

        # compression ratios
        bits_per_weight = np.log2(n_centroids) / block_size

        # number of bits per weight
        size_index_layer = bits_per_weight * M.numel() / 8 / 1024 / 1024
        size_index += size_index_layer

        # centroids stored in float16
        size_centroids_layer = n_centroids * block_size * 2 / 1024 / 1024
        size_centroids += size_centroids_layer

        # size of non-compressed layers, e.g. BatchNorms or first 7x7 convolution
        size_uncompressed_layer = M.numel() * 4 / 1024 / 1024
        size_other -= size_uncompressed_layer

        # number of samples
        n_samples = dynamic_sampling(layer)

        # print layer size
        print('Quantizing layer: {}, size: {}, n_blocks: {}, block size: {}, '
              'centroids: {}, bits/weight: {:.2f}, compressed size: {:.2f} MB'.format(
                  layer, list(sizes), n_blocks, block_size, n_centroids,
                  bits_per_weight, size_index_layer + size_centroids_layer))

        # quantizer
        quantizer = PQ(in_activations_current, M, n_activations=args.n_activations,
                       n_samples=n_samples, eps=args.eps, n_centroids=n_centroids,
                       n_iter=args.n_iter, n_blocks=n_blocks, k=k,
                       stride=stride, padding=padding, groups=groups)

        if args.restart:
            # do not quantize already quantized layers
            try:
                # load centroids and assignments if already stored
                quantizer.load(args.restart, layer)
                centroids = quantizer.centroids
                assignments = quantizer.assignments

                # quantize weight matrix
                M_hat = weight_from_centroids(centroids, assignments, n_blocks, k, is_conv)
                attrgetter(layer + '.weight')(student).data = M_hat
                quantizer.save(args.save, layer)

                # optimizer for global finetuning
                parameters = [p for (n, p) in student.named_parameters() if layer in n and 'bias' not in n]
                centroids_params = {'params': parameters,
                                    'assignments': assignments,
                                    'kernel_size': k,
                                    'n_centroids': n_centroids,
                                    'n_blocks': n_blocks}
                opt_centroids_params_all.append(centroids_params)

                # proceed to next layer
                print('Layer already quantized, proceeding to next layer\n')
                continue

            # otherwise, quantize layer
            except FileNotFoundError:
                print('Quantizing layer')

        # quantize layer
        quantizer.encode()

        # assign quantized weight matrix
        M_hat = quantizer.decode()
        attrgetter(layer + '.weight')(student).data = M_hat

        # top1
        top_1 = evaluate(test_loader, student, criterion).item()

        # book-keeping
        print('Quantizing time: {:.0f}min, Top1 after quantization: {:.2f}\n'.format((time.time() - t) / 60, top_1))
        t = time.time()

        # Step 2: finetune centroids
        print('Finetuning centroids')

        # optimizer for centroids
        parameters = [p for (n, p) in student.named_parameters() if layer in n and 'bias' not in n]
        assignments = quantizer.assignments
        centroids_params = {'params': parameters,
                            'assignments': assignments,
                            'kernel_size': k,
                            'n_centroids': n_centroids,
                            'n_blocks': n_blocks}

        # remember centroids parameters to finetuning at the end
        opt_centroids_params = [centroids_params]
        opt_centroids_params_all.append(centroids_params)

        # custom optimizer
        optimizer_centroids = CentroidSGD(opt_centroids_params, lr=args.lr_centroids,
                                          momentum=args.momentum_centroids,
                                          weight_decay=args.weight_decay_centroids)

        # standard training loop
        n_iter = args.finetune_centroids
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer_centroids, step_size=1, gamma=0.1)

        for epoch in range(1):
            finetune_centroids(train_loader, student, teacher, criterion, optimizer_centroids, n_iter=n_iter)
            top_1 = evaluate(test_loader, student, criterion)
            scheduler.step()
            print('Epoch: {}, Top1: {:.2f}'.format(epoch, top_1))

        print('After {} iterations with learning rate {}, Top1: {:.2f}'.format(n_iter, args.lr_centroids, top_1))

        # book-keeping
        print('Finetuning centroids time: {:.0f}min, Top1 after finetuning centroids: {:.2f}\n'.format(
            (time.time() - t) / 60, top_1))
        t = time.time()

        # saving: recover the finetuned centroids from the updated weight
        M_hat = attrgetter(layer + '.weight')(student).data
        centroids = centroids_from_weights(M_hat, assignments, n_centroids, n_blocks)
        quantizer.centroids = centroids
        quantizer.save(args.save, layer)

    # End of compression + finetuning of centroids
    size_compressed = size_index + size_centroids + size_other
    print('End of compression, non-compressed teacher model: {:.2f}MB, compressed student model '
          '(indexing + centroids + other): {:.2f}MB + {:.2f}MB + {:.2f}MB = {:.2f}MB, '
          'compression ratio: {:.2f}x\n'.format(
              size_uncompressed, size_index, size_centroids, size_other,
              size_compressed, size_uncompressed / size_compressed))

    # Step 3: finetune whole network
    print('Step 3: Finetune whole network')
    t = time.time()

    # custom optimizer
    optimizer_centroids_all = CentroidSGD(opt_centroids_params_all, lr=args.lr_whole,
                                          momentum=args.momentum_whole,
                                          weight_decay=args.weight_decay_whole)

    # standard training loop
    n_iter = args.finetune_whole
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer_centroids_all,
                                                step_size=args.finetune_whole_step_size, gamma=0.1)

    for epoch in range(args.finetune_whole_epochs):
        student.train()
        finetune_centroids(train_loader, student, teacher, criterion, optimizer_centroids_all, n_iter=n_iter)
        top_1 = evaluate(test_loader, student, criterion)
        scheduler.step()
        print('Epoch: {}, Top1: {:.2f}'.format(epoch, top_1))

    # state dict of compressed model
    state_dict_compressed = {}

    # save conv1 (not quantized)
    state_dict_compressed['conv1'] = student.conv1.state_dict()

    # save biases of the classifier
    state_dict_compressed['fc_bias'] = {'bias': student.fc.bias}

    # save batch norms
    bn_layers = watcher._get_bn_layers()

    for bn_layer in bn_layers:
        state_dict_compressed[bn_layer] = attrgetter(bn_layer)(student).state_dict()

    # save quantized layers
    for layer in layers:
        # stats: must match the configuration used in Step 1, in particular the
        # pointwise/regular distinction for the number of centroids
        sizes = attrgetter(layer + '.weight.data')(student).detach().size()
        is_conv, k, _, n_centroids, n_blocks = _layer_quantization_config(sizes, args)

        # save: the quantizer of the last processed layer is reused as a
        # container, load() replaces its state with the one stored for `layer`
        quantizer.load(args.save, layer)
        assignments = quantizer.assignments
        M_hat = attrgetter(layer + '.weight')(student).data
        centroids = centroids_from_weights(M_hat, assignments, n_centroids, n_blocks)
        quantizer.centroids = centroids
        quantizer.save(args.save, layer)
        state_dict_layer = {
            'centroids': centroids.half(),
            'assignments': assignments.short() if 'fc' in layer else assignments.byte(),
            'n_blocks': n_blocks,
            'is_conv': is_conv,
            'k': k
        }
        state_dict_compressed[layer] = state_dict_layer

    # save model
    torch.save(state_dict_compressed, os.path.join(args.save, 'state_dict_compressed.pth'))

    # book-keeping
    print('Finetuning whole network time: {:.0f}min, Top1 after finetuning centroids: {:.2f}\n'.format(
        (time.time() - t) / 60, top_1))
def _clamp_centroid_count(n_centroids, sizes, block_size, threshold=4):
    """Clamp the number of centroids for stability.

    The cap is the largest power of two (among 2**0 .. 2**15) that is strictly
    smaller than ``n_vectors / threshold``, where ``n_vectors`` is the number
    of weights in the layer divided by ``block_size``.

    Args:
        n_centroids: requested number of centroids.
        sizes: size of the weight matrix (any sequence of ints).
        block_size: number of weights per quantized block.
        threshold: divisor applied to the number of vectors before clamping.

    Returns:
        ``min(n_centroids, cap)``.
    """
    powers = 2**np.arange(0, 16, 1)
    n_vectors = np.prod(sizes) / block_size
    idx_power = bisect_left(powers, n_vectors / threshold)
    # bisect_left returns 0 for very small layers; without the max() the index
    # -1 would wrap around to the LARGEST power and disable the clamp entirely
    return min(n_centroids, powers[max(idx_power - 1, 0)])


def main():
    """Quantize a pretrained RealNVP model layer by layer on CIFAR-10.

    Every second layer reported by ``ActivationWatcher`` (skipping the first
    one), up to 50 layers, is quantized with ``PQ``. A layer whose
    quantization is already stored is reloaded instead; otherwise it is
    encoded, its centroids are finetuned against the teacher, and the result
    is saved. All hyper-parameters are hard-coded in the body.
    """
    torch.cuda.empty_cache()
    student = real_nvp_model(pretrained=True)
    student.eval()
    cudnn.benchmark = True
    criterion = real_nvp_loss.RealNVPLoss().cuda()

    # data loading code (CIFAR-10, random horizontal flips on the train split only)
    transform_train = transforms.Compose(
        [transforms.RandomHorizontalFlip(),
         transforms.ToTensor()])
    trainset = torchvision.datasets.CIFAR10(root='data',
                                            train=True,
                                            download=True,
                                            transform=transform_train)
    trainloader = data.DataLoader(trainset,
                                  batch_size=16,
                                  shuffle=True,
                                  num_workers=0)

    transform_test = transforms.Compose([transforms.ToTensor()])
    testset = torchvision.datasets.CIFAR10(root='data',
                                           train=False,
                                           download=True,
                                           transform=transform_test)
    testloader = data.DataLoader(testset,
                                 batch_size=16,
                                 shuffle=False,
                                 num_workers=0)

    # parameters for the centroids optimizer
    opt_centroids_params_all = []

    # book-keeping for compression statistics (in MB)
    size_uncompressed = compute_size(student)
    size_index = 0
    size_centroids = 0
    size_other = size_uncompressed

    # teacher model
    teacher = real_nvp_model(pretrained=True)
    teacher.eval()

    # layers to quantize: skip the first layer, then keep every second one
    watcher = ActivationWatcher(student)
    layers = watcher.layers[1:][1::2]

    # when truthy, try to reload an already stored quantization for each layer
    restart = 1

    for layer in layers[0:50]:
        print(layer)
        torch.cuda.empty_cache()

        # gather input activations
        # NOTE(review): 1024 / 32 assumes a batch size of 32 while the loaders
        # above use batch_size=16 -- confirm how many activations are intended
        n_iter_activations = math.ceil(1024 / 32)
        watcher = ActivationWatcher(student, layer=layer)
        in_activations_current = watcher.watch(trainloader, criterion,
                                               n_iter_activations)
        in_activations_current = in_activations_current[layer]

        # get weight matrix and detach it from the computation graph
        M = attrgetter(layer + '.weight.data')(student).detach()
        sizes = M.size()
        is_conv = len(sizes) == 4

        # get padding and stride attributes
        module = attrgetter(layer)(student)
        padding = module.padding if is_conv else 0
        stride = module.stride if is_conv else 1
        groups = module.groups if is_conv else 1

        # block size, distinguish between fully connected and convolutional case
        if is_conv:
            _, in_features, k, _ = sizes
            block_size = 9 if k > 1 else 4
            n_centroids = 128
            n_blocks = in_features * k * k // block_size
        else:
            k = 1
            _, in_features = sizes
            block_size = 4
            n_centroids = 256
            n_blocks = in_features // block_size

        # clamp number of centroids for stability
        n_centroids = _clamp_centroid_count(n_centroids, sizes, block_size, 4)

        # compression ratios
        bits_per_weight = np.log2(n_centroids) / block_size

        # number of bits per weight
        size_index_layer = bits_per_weight * M.numel() / 8 / 1024 / 1024
        size_index += size_index_layer

        # centroids stored in float16
        size_centroids_layer = n_centroids * block_size * 2 / 1024 / 1024
        size_centroids += size_centroids_layer

        # size of non-compressed layers
        size_uncompressed_layer = M.numel() * 4 / 1024 / 1024
        size_other -= size_uncompressed_layer

        n_samples = 1000

        # quantizer
        quantizer = PQ(in_activations_current,
                       M,
                       n_activations=1024,
                       n_samples=n_samples,
                       eps=1e-8,
                       n_centroids=n_centroids,
                       n_iter=100,
                       n_blocks=n_blocks,
                       k=k,
                       stride=stride,
                       padding=padding,
                       groups=groups)

        if restart:
            # do not quantize already quantized layers
            try:
                # load centroids and assignments if already stored
                quantizer.load('', layer)
                centroids = quantizer.centroids
                assignments = quantizer.assignments

                # quantize weight matrix
                M_hat = weight_from_centroids(centroids, assignments, n_blocks,
                                              k, is_conv)
                attrgetter(layer + '.weight')(student).data = M_hat
                quantizer.save('', layer)

                # optimizer for global finetuning
                parameters = [
                    attrgetter(layer + '.weight.data')(student).detach()
                ]
                centroids_params = {
                    'params': parameters,
                    'assignments': assignments,
                    'kernel_size': k,
                    'n_centroids': n_centroids,
                    'n_blocks': n_blocks
                }
                opt_centroids_params_all.append(centroids_params)

                # proceed to next layer
                print('Layer already quantized, proceeding to next layer\n')
                continue

            # otherwise, quantize layer
            except FileNotFoundError:
                print('Quantizing layer')

        # quantize layer and assign the quantized weight matrix
        quantizer.encode()
        M_hat = quantizer.decode()
        attrgetter(layer + '.weight')(student).data = M_hat

        # optimizer for centroids
        # NOTE(review): the optimizer receives a detached view of the weight
        # rather than the module's parameters -- confirm that CentroidSGD can
        # see gradients for it
        parameters = [attrgetter(layer + '.weight.data')(student).detach()]
        assignments = quantizer.assignments
        centroids_params = {
            'params': parameters,
            'assignments': assignments,
            'kernel_size': k,
            'n_centroids': n_centroids,
            'n_blocks': n_blocks
        }

        # remember centroids parameters for a later global finetuning
        opt_centroids_params_all.append(centroids_params)
        opt_centroids_params = [centroids_params]

        # custom optimizer
        optimizer_centroids = CentroidSGD(opt_centroids_params,
                                          lr=0.01,
                                          momentum=0.9,
                                          weight_decay=0.0001)

        # finetune centroids of the current layer, then evaluate
        finetune_centroids(trainloader,
                           student.eval(),
                           teacher,
                           criterion,
                           optimizer_centroids,
                           n_iter=100)
        bpd = evaluate(testloader, student, criterion)
        print('bits per dim:{:.4f} '.format(bpd))

        # saving: recover the finetuned centroids from the updated weight
        M_hat = attrgetter(layer + '.weight')(student).data
        centroids = centroids_from_weights(M_hat, assignments, n_centroids,
                                           n_blocks)
        quantizer.centroids = centroids
        quantizer.save('', layer)