Ejemplo n.º 1
0
def _layer_quantization_config(sizes, args):
    """Derive the product-quantization settings for one weight tensor.

    Shared by the quantization loop and the final export loop of ``main`` so
    that both always agree on the block size and number of centroids.

    Args:
        sizes: shape of the layer weight. A 4-entry shape is treated as a
            convolution ``(out, in, k, k)``; anything else is unpacked as a
            2-entry ``(out, in)`` fully connected weight.
        args: namespace providing ``block_size_cv``, ``block_size_pw``,
            ``block_size_fc``, ``n_centroids_cv``, ``n_centroids_pw``,
            ``n_centroids_fc`` and ``n_centroids_threshold``.

    Returns:
        Tuple ``(is_conv, k, block_size, n_centroids, n_blocks)``.
    """
    is_conv = len(sizes) == 4

    # block size, distinguish between fully connected and convolutional case
    if is_conv:
        _, in_features, k, _ = sizes
        # kernels larger than 1x1 use the "cv" settings, 1x1 (pointwise) the "pw" ones
        block_size = args.block_size_cv if k > 1 else args.block_size_pw
        n_centroids = args.n_centroids_cv if k > 1 else args.n_centroids_pw
        n_blocks = in_features * k * k // block_size
    else:
        k = 1
        _, in_features = sizes
        block_size = args.block_size_fc
        n_centroids = args.n_centroids_fc
        n_blocks = in_features // block_size

    # clamp number of centroids for stability: cap it at the largest power of
    # two strictly below n_vectors / n_centroids_threshold
    # NOTE(review): when that ratio is <= 1, idx_power is 0 and the index -1
    # wraps around to the largest power (2 ** 15), i.e. no clamping happens.
    powers = 2 ** np.arange(0, 16, 1)
    n_vectors = np.prod(sizes) / block_size
    idx_power = bisect_left(powers, n_vectors / args.n_centroids_threshold)
    n_centroids = min(n_centroids, powers[idx_power - 1])

    return is_conv, k, block_size, n_centroids, n_blocks


def main():
    """Quantize a pretrained network layer by layer and export it.

    Steps, in order:
      1. For every selected layer: gather input activations, quantize the
         weight with ``PQ`` (or reload stored centroids/assignments from
         ``args.restart``), then finetune that layer's centroids against the
         teacher.
      2. Print the compression statistics (sizes in MB).
      3. Finetune the centroids of all quantized layers jointly.
      4. Save the compressed state dict to
         ``<args.save>/state_dict_compressed.pth``.

    Side effects: sets the module-level ``args``, moves models to CUDA,
    overwrites the student's weights in place and writes files under
    ``args.save``.
    """
    # get arguments
    global args
    args = parser.parse_args()
    args.block = '' if args.block == 'all' else args.block

    # student model to quantize
    student = models.__dict__[args.model](pretrained=True).cuda()
    student.eval()
    criterion = nn.CrossEntropyLoss().cuda()
    cudnn.benchmark = True

    # layers to quantize (we do not quantize the first 7x7 convolution layer)
    watcher = ActivationWatcher(student)
    layers = [layer for layer in watcher.layers[1:] if args.block in layer]

    # data loading code
    train_loader, test_loader = load_data(data_path=args.data_path, batch_size=args.batch_size, nb_workers=args.n_workers)

    # parameters for the centroids optimizer
    opt_centroids_params_all = []

    # book-keeping for compression statistics (in MB)
    size_uncompressed = compute_size(student)
    size_index = 0
    size_centroids = 0
    size_other = size_uncompressed

    # teacher model
    teacher = models.__dict__[args.model](pretrained=True).cuda()
    teacher.eval()

    # Step 1: iteratively quantize the network layers (quantization + layer-wise centroids distillation)
    print('Step 1: Quantize network')
    t = time.time()
    top_1 = 0

    # number of batches needed to collect args.n_activations activations
    # (identical for every layer, so computed once)
    n_iter_activations = math.ceil(args.n_activations / args.batch_size)

    for layer in layers:
        #  gather input activations
        watcher = ActivationWatcher(student, layer=layer)
        in_activations_current = watcher.watch(train_loader, criterion, n_iter_activations)
        in_activations_current = in_activations_current[layer]

        # get weight matrix and detach it from the computation graph (.data should be enough, adding .detach() as a safeguard)
        M = attrgetter(layer + '.weight.data')(student).detach()
        sizes = M.size()
        is_conv, k, block_size, n_centroids, n_blocks = _layer_quantization_config(sizes, args)

        # get padding, stride and groups attributes (only meaningful for convolutions)
        module = attrgetter(layer)(student)
        padding = module.padding if is_conv else 0
        stride = module.stride if is_conv else 1
        groups = module.groups if is_conv else 1

        # number of bits per weight: one index of log2(n_centroids) bits per block
        bits_per_weight = np.log2(n_centroids) / block_size

        # size of the indexes of this layer (bits -> bytes -> MB)
        size_index_layer = bits_per_weight * M.numel() / 8 / 1024 / 1024
        size_index += size_index_layer

        # centroids stored in float16 (2 bytes per value)
        size_centroids_layer = n_centroids * block_size * 2 / 1024 / 1024
        size_centroids += size_centroids_layer

        # this layer is no longer stored as float32 (4 bytes per weight); what is
        # left in size_other are non-compressed layers, e.g. BatchNorms or first 7x7 convolution
        size_uncompressed_layer = M.numel() * 4 / 1024 / 1024
        size_other -= size_uncompressed_layer

        # number of samples
        n_samples = dynamic_sampling(layer)

        # print layer size
        print('Quantizing layer: {}, size: {}, n_blocks: {}, block size: {}, ' \
              'centroids: {}, bits/weight: {:.2f}, compressed size: {:.2f} MB'.format(
               layer, list(sizes), n_blocks, block_size, n_centroids,
               bits_per_weight, size_index_layer + size_centroids_layer))

        # quantizer
        quantizer = PQ(in_activations_current, M, n_activations=args.n_activations,
                       n_samples=n_samples, eps=args.eps, n_centroids=n_centroids,
                       n_iter=args.n_iter, n_blocks=n_blocks, k=k,
                       stride=stride, padding=padding, groups=groups)

        if args.restart:
            # do not quantize already quantized layers
            # only the load is guarded: a FileNotFoundError raised while
            # saving must not be mistaken for "layer not quantized yet"
            try:
                # load centroids and assignments if already stored
                quantizer.load(args.restart, layer)
            except FileNotFoundError:
                # otherwise, quantize layer
                print('Quantizing layer')
            else:
                centroids = quantizer.centroids
                assignments = quantizer.assignments

                # quantize weight matrix
                M_hat = weight_from_centroids(centroids, assignments, n_blocks, k, is_conv)
                attrgetter(layer + '.weight')(student).data = M_hat
                quantizer.save(args.save, layer)

                # optimizer for global finetuning
                parameters = [p for (n, p) in student.named_parameters() if layer in n and 'bias' not in n]
                centroids_params = {'params': parameters,
                                    'assignments': assignments,
                                    'kernel_size': k,
                                    'n_centroids': n_centroids,
                                    'n_blocks': n_blocks}
                opt_centroids_params_all.append(centroids_params)

                # proceed to next layer
                print('Layer already quantized, proceeding to next layer\n')
                continue

        # quantize layer
        quantizer.encode()

        # assign quantized weight matrix
        M_hat = quantizer.decode()
        attrgetter(layer + '.weight')(student).data = M_hat

        # top1
        top_1 = evaluate(test_loader, student, criterion).item()

        # book-keeping
        print('Quantizing time: {:.0f}min, Top1 after quantization: {:.2f}\n'.format((time.time() - t) / 60, top_1))
        t = time.time()

        # Step 2: finetune centroids
        print('Finetuning centroids')

        # optimizer for centroids
        parameters = [p for (n, p) in student.named_parameters() if layer in n and 'bias' not in n]
        assignments = quantizer.assignments
        centroids_params = {'params': parameters,
                            'assignments': assignments,
                            'kernel_size': k,
                            'n_centroids': n_centroids,
                            'n_blocks': n_blocks}

        # remember centroids parameters to finetuning at the end
        opt_centroids_params = [centroids_params]
        opt_centroids_params_all.append(centroids_params)

        # custom optimizer
        optimizer_centroids = CentroidSGD(opt_centroids_params, lr=args.lr_centroids,
                                          momentum=args.momentum_centroids,
                                          weight_decay=args.weight_decay_centroids)

        # standard training loop
        n_iter = args.finetune_centroids
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer_centroids, step_size=1, gamma=0.1)

        for epoch in range(1):
            finetune_centroids(train_loader, student, teacher, criterion, optimizer_centroids, n_iter=n_iter)
            top_1 = evaluate(test_loader, student, criterion)
            scheduler.step()
            print('Epoch: {}, Top1: {:.2f}'.format(epoch, top_1))

        print('After {} iterations with learning rate {}, Top1: {:.2f}'.format(n_iter, args.lr_centroids, top_1))

        # book-keeping
        print('Finetuning centroids time: {:.0f}min, Top1 after finetuning centroids: {:.2f}\n'.format((time.time() - t) / 60, top_1))
        t = time.time()

        # saving: recover the finetuned centroids from the updated weight
        M_hat = attrgetter(layer + '.weight')(student).data
        centroids = centroids_from_weights(M_hat, assignments, n_centroids, n_blocks)
        quantizer.centroids = centroids
        quantizer.save(args.save, layer)

    # End of compression + finetuning of centroids
    size_compressed = size_index + size_centroids + size_other
    print('End of compression, non-compressed teacher model: {:.2f}MB, compressed student model ' \
          '(indexing + centroids + other): {:.2f}MB + {:.2f}MB + {:.2f}MB = {:.2f}MB, compression ratio: {:.2f}x\n'.format(
          size_uncompressed, size_index, size_centroids, size_other, size_compressed, size_uncompressed / size_compressed))

    # Step 3: finetune whole network
    print('Step 3: Finetune whole network')
    t = time.time()

    # custom optimizer
    optimizer_centroids_all = CentroidSGD(opt_centroids_params_all, lr=args.lr_whole,
                                      momentum=args.momentum_whole,
                                      weight_decay=args.weight_decay_whole)

    # standard training loop
    n_iter = args.finetune_whole
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer_centroids_all, step_size=args.finetune_whole_step_size, gamma=0.1)

    for epoch in range(args.finetune_whole_epochs):
        student.train()
        finetune_centroids(train_loader, student, teacher, criterion, optimizer_centroids_all, n_iter=n_iter)
        top_1 = evaluate(test_loader, student, criterion)
        scheduler.step()
        print('Epoch: {}, Top1: {:.2f}'.format(epoch, top_1))

    # state dict of compressed model
    state_dict_compressed = {}

    # save conv1 (not quantized)
    state_dict_compressed['conv1'] = student.conv1.state_dict()

    # save biases of the classifier
    state_dict_compressed['fc_bias'] = {'bias': student.fc.bias}

    # save batch norms
    bn_layers = watcher._get_bn_layers()

    for bn_layer in bn_layers:
        state_dict_compressed[bn_layer] = attrgetter(bn_layer)(student).state_dict()

    # save quantized layers
    for layer in layers:
        # same settings as during quantization, so that n_centroids matches
        # the stored assignments (including 1x1 convolutions)
        sizes = attrgetter(layer + '.weight.data')(student).size()
        is_conv, k, _, n_centroids, n_blocks = _layer_quantization_config(sizes, args)

        # save: the quantizer object left over from the last quantized layer is
        # reused only as a loader/saver for this layer's stored assignments
        quantizer.load(args.save, layer)
        assignments = quantizer.assignments
        M_hat = attrgetter(layer + '.weight')(student).data
        centroids = centroids_from_weights(M_hat, assignments, n_centroids, n_blocks)
        quantizer.centroids = centroids
        quantizer.save(args.save, layer)
        state_dict_layer = {
            'centroids': centroids.half(),
            'assignments': assignments.short() if 'fc' in layer else assignments.byte(),
            'n_blocks': n_blocks,
            'is_conv': is_conv,
            'k': k
        }
        state_dict_compressed[layer] = state_dict_layer

    # save model
    torch.save(state_dict_compressed, os.path.join(args.save, 'state_dict_compressed.pth'))

    # book-keeping
    print('Finetuning whole network time: {:.0f}min, Top1 after finetuning centroids: {:.2f}\n'.format((time.time() - t) / 60, top_1))
Ejemplo n.º 2
0
def main():
    """Quantize every other layer of a pretrained RealNVP model on CIFAR-10.

    For each selected layer the function gathers input activations, quantizes
    the weight with ``PQ`` (or reloads stored centroids/assignments), finetunes
    the centroids against the teacher, prints the test metric returned by
    ``evaluate`` and saves the quantizer state.

    All hyper-parameters (batch size, block sizes, number of centroids,
    learning rate, ...) are hard-coded in the body.

    NOTE(review): the compression statistics (``size_index``,
    ``size_centroids``, ``size_other``) and ``opt_centroids_params_all`` are
    accumulated but not consumed in the code shown here.
    """
    torch.cuda.empty_cache()

    # student model to quantize
    student = real_nvp_model(
        pretrained=True)  # resnet.resnet18_1(pretrained=True).cuda()
    student.eval()
    cudnn.benchmark = True

    # loss and CIFAR-10 data loading (random horizontal flip on train only)
    criterion = real_nvp_loss.RealNVPLoss().cuda()
    transform_train = transforms.Compose(
        [transforms.RandomHorizontalFlip(),
         transforms.ToTensor()])
    trainset = torchvision.datasets.CIFAR10(root='data',
                                            train=True,
                                            download=True,
                                            transform=transform_train)
    trainloader = data.DataLoader(trainset,
                                  batch_size=16,
                                  shuffle=True,
                                  num_workers=0)

    transform_test = transforms.Compose([transforms.ToTensor()])
    testset = torchvision.datasets.CIFAR10(root='data',
                                           train=False,
                                           download=True,
                                           transform=transform_test)
    testloader = data.DataLoader(testset,
                                 batch_size=16,
                                 shuffle=False,
                                 num_workers=0)

    # parameters for the centroids optimizer
    opt_centroids_params_all = []

    # book-keeping for compression statistics (in MB)
    size_uncompressed = compute_size(student)  #44.591949462890625 mb
    size_index = 0
    size_centroids = 0
    size_other = size_uncompressed

    # teacher model used as the reference during centroid finetuning
    teacher = real_nvp_model(pretrained=True)
    teacher.eval()

    # layers to quantize: skip the first watched layer, then keep every
    # second one of the remaining layers (those at even 1-based position)
    watcher = ActivationWatcher(student)
    layers = []
    i = 1

    for layer in watcher.layers[1:]:
        if i % 2 == 0:
            layers.append(layer)
        i = i + 1

    # when truthy, try to reload stored centroids/assignments before quantizing
    restart = 1
    for layer in layers[0:50]:
        print(layer)
        torch.cuda.empty_cache()
        # gather input activations
        # NOTE(review): the divisor 32 does not match the DataLoader
        # batch_size of 16 used above -- confirm the intended batch count
        n_iter_activations = math.ceil(1024 / 32)
        watcher = ActivationWatcher(student, layer=layer)
        in_activations_current = watcher.watch(trainloader, criterion,
                                               n_iter_activations)
        in_activations_current = in_activations_current[layer]

        # weight tensor of the layer, detached from the computation graph
        M = attrgetter(layer + '.weight.data')(student).detach()
        sizes = M.size()
        is_conv = len(sizes) == 4

        # padding, stride and groups are only read from convolutions
        padding = attrgetter(layer)(student).padding if is_conv else 0
        stride = attrgetter(layer)(student).stride if is_conv else 1
        groups = attrgetter(layer)(student).groups if is_conv else 1

        # block size, distinguish between fully connected and convolutional case
        if is_conv:

            out_features, in_features, k, _ = sizes
            block_size = 9 if k > 1 else 4
            n_centroids = 128 if k > 1 else 128
            n_blocks = in_features * k * k // block_size

        else:
            k = 1
            out_features, in_features = sizes
            block_size = 4
            n_centroids = 256
            n_blocks = in_features // block_size

        # clamp the number of centroids to the largest power of two strictly
        # below n_vectors / 4
        powers = 2**np.arange(0, 16, 1)
        n_vectors = np.prod(sizes) / block_size  #4096.0
        idx_power = bisect_left(powers, n_vectors / 4)
        n_centroids = min(n_centroids, powers[idx_power - 1])  #128

        # compression ratios
        bits_per_weight = np.log2(n_centroids) / block_size  #0.7778
        # size of the indexes of this layer (bits -> bytes -> MB)
        size_index_layer = bits_per_weight * M.numel() / 8 / 1024 / 1024
        size_index += size_index_layer  #0.00341796875
        # centroids stored in float16
        size_centroids_layer = n_centroids * block_size * 2 / 1024 / 1024
        size_centroids += size_centroids_layer
        # size of non-compressed layers, e.g. BatchNorms or first 7x7 convolution
        size_uncompressed_layer = M.numel() * 4 / 1024 / 1024
        size_other -= size_uncompressed_layer
        n_samples = 1000

        # quantizer
        quantizer = PQ(in_activations_current,
                       M,
                       n_activations=1024,
                       n_samples=n_samples,
                       eps=1e-8,
                       n_centroids=n_centroids,
                       n_iter=100,
                       n_blocks=n_blocks,
                       k=k,
                       stride=stride,
                       padding=padding,
                       groups=groups)

        if restart:

            # do not quantize already quantized layers: a FileNotFoundError
            # raised anywhere in this try body falls through to quantization
            try:
                quantizer.load('', layer)
                centroids = quantizer.centroids
                assignments = quantizer.assignments
                # quantize weight matrix
                M_hat = weight_from_centroids(centroids, assignments, n_blocks,
                                              k, is_conv)
                attrgetter(layer + '.weight')(student).data = M_hat
                quantizer.save('', layer)
                # optimizer for global finetuning
                # NOTE(review): this passes a detached tensor rather than the
                # layer's nn.Parameter -- confirm CentroidSGD sees gradients
                parameters = [
                    attrgetter(layer + '.weight.data')(student).detach()
                ]

                centroids_params = {
                    'params': parameters,
                    'assignments': assignments,
                    'kernel_size': k,
                    'n_centroids': n_centroids,
                    'n_blocks': n_blocks
                }
                opt_centroids_params_all.append(centroids_params)
                # proceed to next layer
                print('Layer already quantized, proceeding to next layer\n')
                continue
            except FileNotFoundError:
                print('Quantizing layer')

        #
        # quantize layer
        quantizer.encode()
        # assign quantized weight matrix
        M_hat = quantizer.decode()
        attrgetter(layer + '.weight')(student).data = M_hat

        # optimizer for centroids (the empty list is overwritten right away)
        # NOTE(review): detached tensor instead of the nn.Parameter, as in
        # the restart branch -- confirm this is intended
        parameters = []
        parameters = [attrgetter(layer + '.weight.data')(student).detach()]
        assignments = quantizer.assignments
        centroids_params = {
            'params': parameters,
            'assignments': assignments,
            'kernel_size': k,
            'n_centroids': n_centroids,
            'n_blocks': n_blocks
        }
        opt_centroids_params_all.append(centroids_params)
        opt_centroids_params = [centroids_params]
        optimizer_centroids = CentroidSGD(opt_centroids_params,
                                          lr=0.01,
                                          momentum=0.9,
                                          weight_decay=0.0001)
        # finetune the centroids of this layer against the teacher
        finetune_centroids(trainloader,
                           student.eval(),
                           teacher,
                           criterion,
                           optimizer_centroids,
                           n_iter=100)

        # evaluate on the test set; the result is printed as bits per dim
        bpd = evaluate(testloader, student, criterion)
        print('bits per dim:{:.4f} '.format(bpd))
        # NOTE(review): the scheduler is created after finetuning and is never
        # stepped in the code shown here
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer_centroids,
                                                    step_size=1,
                                                    gamma=0.1)

        # saving: recover the finetuned centroids from the updated weight
        M_hat = attrgetter(layer + '.weight')(student).data
        centroids = centroids_from_weights(M_hat, assignments, n_centroids,
                                           n_blocks)
        quantizer.centroids = centroids
        quantizer.save('', layer)