def quantize_model(model_orig,
                   quant_method,
                   param_bits,
                   batch_norm_bits,
                   layer_output_bits,
                   overflow_rate=0.0,
                   n_sample=10):
    """Return a quantized deep copy of ``model_orig``.

    Every tensor in the model's ``state_dict`` is quantized with its own
    bit-width, the quantized tensors are loaded back into the copy, and the
    copy is then passed through ``quant.quantize_model_per_layer_output``.
    ``model_orig`` itself is not modified.

    Args:
        model_orig: Model to quantize; it is deep-copied first.
        quant_method: ``'linear'``, ``'log'`` or ``'minmax'``. Any other
            value falls through to tanh quantization.
        param_bits: Indexable sequence of bit-widths. One entry is consumed,
            in ``state_dict`` order, for every entry whose key does not
            contain ``'running'``.
        batch_norm_bits: Indexable sequence of bit-widths. One entry is
            consumed, in ``state_dict`` order, for every entry whose key
            contains ``'running'``. A value of 32 or more leaves that tensor
            unquantized.
        layer_output_bits: Forwarded unchanged to
            ``quant.quantize_model_per_layer_output``.
        overflow_rate: Forwarded to ``quant.compute_integral_part`` (linear
            method) and to ``quant.quantize_model_per_layer_output``.
        n_sample: Forwarded as ``counter`` to
            ``quant.quantize_model_per_layer_output``.

    Returns:
        The value returned by ``quant.quantize_model_per_layer_output`` for
        the quantized copy.
    """
    model = copy.deepcopy(model_orig)
    state_dict = model.state_dict()
    state_dict_quant = OrderedDict()

    p_idx = 0
    b_idx = 0

    # NOTE(review): keys such as ``num_batches_tracked`` do not contain
    # 'running', so they would consume a ``param_bits`` entry and be
    # quantized like a weight -- confirm against the PyTorch version in use.
    for k, v in state_dict.items():
        if 'running' in k:
            # Always advance the batch-norm index, including for entries that
            # are skipped. Otherwise the first >= 32 entry would be reused for
            # every following running statistic.
            bits = batch_norm_bits[b_idx]
            b_idx += 1
            if bits >= 32:
                state_dict_quant[k] = v
                continue
        else:
            bits = param_bits[p_idx]
            p_idx += 1

        if quant_method == 'linear':
            # Fractional bits = total bits - sign bit - integral bits.
            sf = bits - 1. - quant.compute_integral_part(
                v, overflow_rate=overflow_rate)
            v_quant = quant.linear_quantize(v, sf, bits=bits)
        elif quant_method == 'log':
            v_quant = quant.log_minmax_quantize(v, bits=bits)
        elif quant_method == 'minmax':
            v_quant = quant.min_max_quantize(v, bits=bits)
        else:
            v_quant = quant.tanh_quantize(v, bits=bits)

        state_dict_quant[k] = v_quant

    model.load_state_dict(state_dict_quant)

    model = quant.quantize_model_per_layer_output(
        model,
        layer_output_bits=layer_output_bits,
        overflow_rate=overflow_rate,
        counter=n_sample,
        type=quant_method)

    # To actually reduce param size by changing data type to float16
    # if layer_output_bits <= 16:
    #     model.half()

    return model
예제 #2
0
                continue
            else:
                bits = args.bn_bits
        else:
            bits = args.param_bits

        if args.quant_method == 'linear':
            sf = bits - 1. - quant.compute_integral_part(
                v, overflow_rate=args.overflow_rate)
            v_quant = quant.linear_quantize(v, sf, bits=bits)
        elif args.quant_method == 'log':
            v_quant = quant.log_minmax_quantize(v, bits=bits)
        elif args.quant_method == 'minmax':
            v_quant = quant.min_max_quantize(v, bits=bits)
        else:
            v_quant = quant.tanh_quantize(v, bits=bits)
        state_dict_quant[k] = v_quant

    model_quant.load_state_dict(state_dict_quant)

# Quantize forward activations (layer outputs). Skipped entirely when
# fwd_bits is 32 or more.
if args.fwd_bits < 32:
    # Rebind model_quant to whatever quant.quantize_model_layer_output
    # returns for the requested bit-width.
    model_quant = quant.quantize_model_layer_output(
        model_quant,
        bits=args.fwd_bits,
        overflow_rate=args.overflow_rate,
        counter=args.n_sample,
        type=args.quant_method)
    # At 16 bits or fewer, also convert the model with .half().
    if args.fwd_bits <= 16:
        model_quant.half()
def main():
    """Train, or only evaluate, a CIFAR-10/100 classifier.

    Reads its configuration from the module-level ``args`` and also uses the
    module-level ``state``, ``use_cuda`` and ``best_acc``. The steps are:

    1. Build the CIFAR-10 or CIFAR-100 train/test loaders.
    2. Build the architecture named by ``args.arch`` and wrap it in
       ``DataParallel`` on CUDA.
    3. Optionally resume model/optimizer state from ``args.resume``.
    4. If ``args.quant_method`` is set and ``args.param_bits < 32``,
       quantize the state-dict tensors in place.
    5. Either run one evaluation pass (``args.evaluate``) or train for the
       remaining epochs, logging and checkpointing after each one.
    """
    global best_acc
    start_epoch = args.start_epoch  # start from epoch 0 or last checkpoint epoch

    if not os.path.isdir(args.checkpoint):
        mkdir_p(args.checkpoint)

    # Data
    print('==> Preparing dataset %s' % args.dataset)
    transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465),
                             (0.2023, 0.1994, 0.2010)),
    ])

    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465),
                             (0.2023, 0.1994, 0.2010)),
    ])
    # Any dataset name other than 'cifar10' selects CIFAR-100.
    if args.dataset == 'cifar10':
        dataloader = datasets.CIFAR10
        num_classes = 10
    else:
        dataloader = datasets.CIFAR100
        num_classes = 100

    trainset = dataloader(root=args.data_root,
                          train=True,
                          download=True,
                          transform=transform_train)
    trainloader = data.DataLoader(trainset,
                                  batch_size=args.train_batch,
                                  shuffle=True,
                                  num_workers=args.workers)

    testset = dataloader(root=args.data_root,
                         train=False,
                         download=False,
                         transform=transform_test)
    testloader = data.DataLoader(testset,
                                 batch_size=args.test_batch,
                                 shuffle=False,
                                 num_workers=args.workers)

    # Model: each architecture family takes a different set of constructor
    # keyword arguments.
    print("==> creating model '{}'".format(args.arch))
    if args.arch.startswith('resnext'):
        model = models.__dict__[args.arch](
            cardinality=args.cardinality,
            num_classes=num_classes,
            depth=args.depth,
            widen_factor=args.widen_factor,
            dropRate=args.drop,
        )
    elif args.arch.startswith('densenet'):
        model = models.__dict__[args.arch](
            num_classes=num_classes,
            depth=args.depth,
            growthRate=args.growthRate,
            compressionRate=args.compressionRate,
            dropRate=args.drop,
        )
    elif args.arch.startswith('wrn'):
        model = models.__dict__[args.arch](
            num_classes=num_classes,
            depth=args.depth,
            widen_factor=args.widen_factor,
            dropRate=args.drop,
        )
    elif args.arch.endswith('resnet'):
        model = models.__dict__[args.arch](
            num_classes=num_classes,
            depth=args.depth,
        )
    else:
        model = models.__dict__[args.arch](num_classes=num_classes)

    model = torch.nn.DataParallel(model).cuda()
    cudnn.benchmark = True
    print('    Total params: %.2fM' %
          (sum(p.numel() for p in model.parameters()) / 1000000.0))

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(),
                          lr=args.lr,
                          momentum=args.momentum,
                          weight_decay=args.weight_decay)

    # Resume
    title = 'cifar-10-' + args.arch
    if args.resume:
        # Load checkpoint.
        print('==> Resuming from checkpoint..')
        assert os.path.isfile(
            args.resume), 'Error: no checkpoint directory found!'
        # From here on, logs and checkpoints go next to the resumed file.
        args.checkpoint = os.path.dirname(args.resume)
        checkpoint = torch.load(args.resume)
        best_acc = checkpoint['best_acc']
        start_epoch = checkpoint['epoch']
        model.load_state_dict(checkpoint['state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        logger = Logger(os.path.join(args.checkpoint, 'log.txt'),
                        title=title,
                        resume=True)
    else:
        logger = Logger(os.path.join(args.checkpoint, 'log.txt'), title=title)
        logger.set_names([
            'Learning Rate', 'Train Loss', 'Valid Loss', 'Train Acc.',
            'Valid Acc.'
        ])

    # Optional weight quantization of the (possibly resumed) model.
    if args.quant_method:
        if args.param_bits < 32:
            state_dict = model.state_dict()
            state_dict_quant = OrderedDict()
            for k, v in state_dict.items():
                # Entries whose key contains 'running' use bn_bits and are
                # copied through untouched when bn_bits >= 32. All other
                # entries use param_bits.
                if 'running' in k:
                    if args.bn_bits >= 32:
                        print("Ignoring {}".format(k))
                        state_dict_quant[k] = v
                        continue
                    else:
                        bits = args.bn_bits
                else:
                    bits = args.param_bits
                if args.quant_method == 'linear':
                    sf = bits - 1. - quant.compute_integral_part(
                        v, overflow_rate=args.overflow_rate)
                    v_quant = quant.linear_quantize(v, sf, bits=bits)
                elif args.quant_method == 'log':
                    v_quant = quant.log_minmax_quantize(v, bits=bits)
                elif args.quant_method == 'minmax':
                    v_quant = quant.min_max_quantize(v, bits=bits)
                else:
                    # Any other method name falls back to tanh quantization.
                    v_quant = quant.tanh_quantize(v, bits=bits)
                state_dict_quant[k] = v_quant
                print(k, bits)
            model.load_state_dict(state_dict_quant)
        print('model quanted')
    if args.evaluate:
        print('\nEvaluation only')
        test_loss, test_acc = test(testloader, model, criterion, start_epoch,
                                   use_cuda)
        print(' Test Loss:  %.8f, Test Acc:  %.2f' % (test_loss, test_acc))
        # The logger was opened above; close it before returning early.
        logger.close()
        return

    # Train and val
    for epoch in range(start_epoch, args.epochs):
        adjust_learning_rate(optimizer, epoch)

        print('\nEpoch: [%d | %d] LR: %f' %
              (epoch + 1, args.epochs, state['lr']))

        train_loss, train_acc = train(trainloader, model, criterion, optimizer,
                                      epoch, use_cuda)
        test_loss, test_acc = test(testloader, model, criterion, epoch,
                                   use_cuda)

        # append logger file
        logger.append(
            [state['lr'], train_loss, test_loss, train_acc, test_acc])

        # save model
        is_best = test_acc > best_acc
        best_acc = max(test_acc, best_acc)
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'state_dict': model.state_dict(),
                'acc': test_acc,
                'best_acc': best_acc,
                'optimizer': optimizer.state_dict(),
            },
            is_best,
            checkpoint=args.checkpoint)

    logger.close()
    logger.plot()
    savefig(os.path.join(args.checkpoint, 'log.eps'))

    print('Best acc:')
    print(best_acc)
예제 #4
0
def main():
    """Quantize a pretrained model from the command line and report its accuracy.

    Parses the CLI flags, selects a model and dataset fetcher through
    ``selector.select``, optionally quantizes the state-dict tensors
    (``--param_bits < 32``) and the forward activations (``--fwd_bits < 32``),
    evaluates the result, prints a summary line and appends that line to
    ``acc1_acc5.txt`` in the current directory.
    """
    parser = argparse.ArgumentParser(description='PyTorch SVHN Example')
    parser.add_argument('--type', default='cifar10', help='|'.join(selector.known_models))
    parser.add_argument('--quant_method', default='linear', help='linear|minmax|log|tanh')
    parser.add_argument('--batch_size', type=int, default=100, help='input batch size for evaluation (default: 100)')
    parser.add_argument('--gpu', default=None, help='index of gpus to use')
    parser.add_argument('--ngpu', type=int, default=8, help='number of gpus to use')
    parser.add_argument('--seed', type=int, default=117, help='random seed (default: 117)')
    parser.add_argument('--model_root', default='~/.torch/models/', help='folder to save the model')
    parser.add_argument('--data_root', default='/data/public_dataset/pytorch/', help='folder holding the dataset')
    parser.add_argument('--logdir', default='log/default', help='folder to save to the log')

    parser.add_argument('--input_size', type=int, default=224, help='input size of image')
    parser.add_argument('--n_sample', type=int, default=20, help='number of samples to infer the scaling factor')
    parser.add_argument('--param_bits', type=int, default=8, help='bit-width for parameters')
    parser.add_argument('--bn_bits', type=int, default=32, help='bit-width for running mean and std')
    parser.add_argument('--fwd_bits', type=int, default=8, help='bit-width for layer output')
    parser.add_argument('--overflow_rate', type=float, default=0.0, help='overflow rate')
    args = parser.parse_args()

    # Normalize the parsed flags: resolve GPUs, create the log dir, expand '~'.
    args.gpu = misc.auto_select_gpu(utility_bound=0, num_gpu=args.ngpu, selected_gpus=args.gpu)
    args.ngpu = len(args.gpu)
    misc.ensure_dir(args.logdir)
    args.model_root = misc.expand_user(args.model_root)
    args.data_root = misc.expand_user(args.data_root)
    # Model types with 'inception' in their name always use a 299 input size.
    args.input_size = 299 if 'inception' in args.type else args.input_size
    assert args.quant_method in ['linear', 'minmax', 'log', 'tanh']
    print("=================FLAGS==================")
    for k, v in args.__dict__.items():
        print('{}: {}'.format(k, v))
    print("========================================")

    assert torch.cuda.is_available(), 'no cuda'
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)

    # load model and dataset fetcher
    model_raw, ds_fetcher, is_imagenet = selector.select(args.type, model_root=args.model_root)
    # Only ImageNet models keep the multi-GPU setting; everything else uses 1.
    args.ngpu = args.ngpu if is_imagenet else 1

    # quantize parameters
    if args.param_bits < 32:
        state_dict = model_raw.state_dict()
        state_dict_quant = OrderedDict()
        for k, v in state_dict.items():
            # Entries whose key contains 'running' use bn_bits and are copied
            # through untouched when bn_bits >= 32. All other entries use
            # param_bits.
            if 'running' in k:
                if args.bn_bits >= 32:
                    print("Ignoring {}".format(k))
                    state_dict_quant[k] = v
                    continue
                else:
                    bits = args.bn_bits
            else:
                bits = args.param_bits

            if args.quant_method == 'linear':
                sf = bits - 1. - quant.compute_integral_part(v, overflow_rate=args.overflow_rate)
                v_quant = quant.linear_quantize(v, sf, bits=bits)
            elif args.quant_method == 'log':
                v_quant = quant.log_minmax_quantize(v, bits=bits)
            elif args.quant_method == 'minmax':
                v_quant = quant.min_max_quantize(v, bits=bits)
            else:
                v_quant = quant.tanh_quantize(v, bits=bits)
            state_dict_quant[k] = v_quant
            print(k, bits)
        model_raw.load_state_dict(state_dict_quant)

    # quantize forward activation
    if args.fwd_bits < 32:
        model_raw = quant.duplicate_model_with_quant(model_raw, bits=args.fwd_bits, overflow_rate=args.overflow_rate,
                                                     counter=args.n_sample, type=args.quant_method)
        print(model_raw)
        # Short evaluation pass limited to n_sample. Per the --n_sample help
        # text these samples are used to infer the scaling factor; the result
        # of this pass is discarded.
        val_ds_tmp = ds_fetcher(10, data_root=args.data_root, train=False, input_size=args.input_size)
        misc.eval_model(model_raw, val_ds_tmp, ngpu=1, n_sample=args.n_sample, is_imagenet=is_imagenet)

    # eval model
    val_ds = ds_fetcher(args.batch_size, data_root=args.data_root, train=False, input_size=args.input_size)
    acc1, acc5 = misc.eval_model(model_raw, val_ds, ngpu=args.ngpu, is_imagenet=is_imagenet)

    # print sf
    print(model_raw)
    res_str = "type={}, quant_method={}, param_bits={}, bn_bits={}, fwd_bits={}, overflow_rate={}, acc1={:.4f}, acc5={:.4f}".format(
        args.type, args.quant_method, args.param_bits, args.bn_bits, args.fwd_bits, args.overflow_rate, acc1, acc5)
    print(res_str)
    # Append so results from successive runs accumulate in one file.
    with open('acc1_acc5.txt', 'a') as f:
        f.write(res_str + '\n')
def main():
    """Train, or only evaluate, a CIFAR-10/100 classifier.

    Reads its configuration from the module-level ``args`` and also uses the
    module-level ``state``, ``use_cuda`` and ``best_acc``. It builds the data
    loaders and the model named by ``args.arch``, optionally resumes from
    ``args.resume``, optionally quantizes the state-dict tensors
    (``args.quant_method`` set and ``args.param_bits < 32``), and then either
    runs a single evaluation pass (``args.evaluate``) or trains for the
    remaining epochs with per-epoch logging and checkpointing.
    """
    global best_acc
    start_epoch = args.start_epoch  # start from epoch 0 or last checkpoint epoch

    if not os.path.isdir(args.checkpoint):
        mkdir_p(args.checkpoint)

    # Data
    print('==> Preparing dataset %s' % args.dataset)
    transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ])

    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ])
    # Any dataset name other than 'cifar10' selects CIFAR-100.
    if args.dataset == 'cifar10':
        dataloader = datasets.CIFAR10
        num_classes = 10
    else:
        dataloader = datasets.CIFAR100
        num_classes = 100

    trainset = dataloader(root=args.data_root, train=True, download=True, transform=transform_train)
    trainloader = data.DataLoader(trainset, batch_size=args.train_batch, shuffle=True, num_workers=args.workers)

    testset = dataloader(root=args.data_root, train=False, download=False, transform=transform_test)
    testloader = data.DataLoader(testset, batch_size=args.test_batch, shuffle=False, num_workers=args.workers)

    # Model: each architecture family takes a different set of constructor
    # keyword arguments.
    print("==> creating model '{}'".format(args.arch))
    if args.arch.startswith('resnext'):
        model = models.__dict__[args.arch](
                    cardinality=args.cardinality,
                    num_classes=num_classes,
                    depth=args.depth,
                    widen_factor=args.widen_factor,
                    dropRate=args.drop,
                )
    elif args.arch.startswith('densenet'):
        model = models.__dict__[args.arch](
                    num_classes=num_classes,
                    depth=args.depth,
                    growthRate=args.growthRate,
                    compressionRate=args.compressionRate,
                    dropRate=args.drop,
                )
    elif args.arch.startswith('wrn'):
        model = models.__dict__[args.arch](
                    num_classes=num_classes,
                    depth=args.depth,
                    widen_factor=args.widen_factor,
                    dropRate=args.drop,
                )
    elif args.arch.endswith('resnet'):
        model = models.__dict__[args.arch](
                    num_classes=num_classes,
                    depth=args.depth,
                )
    else:
        model = models.__dict__[args.arch](num_classes=num_classes)

    model = torch.nn.DataParallel(model).cuda()
    cudnn.benchmark = True
    print('    Total params: %.2fM' % (sum(p.numel() for p in model.parameters())/1000000.0))

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)

    # Resume
    title = 'cifar-10-' + args.arch
    if args.resume:
        # Load checkpoint.
        print('==> Resuming from checkpoint..')
        assert os.path.isfile(args.resume), 'Error: no checkpoint directory found!'
        # From here on, logs and checkpoints go next to the resumed file.
        args.checkpoint = os.path.dirname(args.resume)
        checkpoint = torch.load(args.resume)
        best_acc = checkpoint['best_acc']
        start_epoch = checkpoint['epoch']
        model.load_state_dict(checkpoint['state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        logger = Logger(os.path.join(args.checkpoint, 'log.txt'), title=title, resume=True)
    else:
        logger = Logger(os.path.join(args.checkpoint, 'log.txt'), title=title)
        logger.set_names(['Learning Rate', 'Train Loss', 'Valid Loss', 'Train Acc.', 'Valid Acc.'])

    # Optional weight quantization of the (possibly resumed) model.
    if args.quant_method:
        if args.param_bits < 32:
            state_dict = model.state_dict()
            state_dict_quant = OrderedDict()
            for k, v in state_dict.items():
                # Entries whose key contains 'running' use bn_bits and are
                # copied through untouched when bn_bits >= 32. All other
                # entries use param_bits.
                if 'running' in k:
                    if args.bn_bits >= 32:
                        print("Ignoring {}".format(k))
                        state_dict_quant[k] = v
                        continue
                    else:
                        bits = args.bn_bits
                else:
                    bits = args.param_bits
                if args.quant_method == 'linear':
                    sf = bits - 1. - quant.compute_integral_part(v, overflow_rate=args.overflow_rate)
                    v_quant = quant.linear_quantize(v, sf, bits=bits)
                elif args.quant_method == 'log':
                    v_quant = quant.log_minmax_quantize(v, bits=bits)
                elif args.quant_method == 'minmax':
                    v_quant = quant.min_max_quantize(v, bits=bits)
                else:
                    # Any other method name falls back to tanh quantization.
                    v_quant = quant.tanh_quantize(v, bits=bits)
                state_dict_quant[k] = v_quant
                print(k, bits)
            model.load_state_dict(state_dict_quant)
        print('model quanted')
    if args.evaluate:
        print('\nEvaluation only')
        test_loss, test_acc = test(testloader, model, criterion, start_epoch, use_cuda)
        print(' Test Loss:  %.8f, Test Acc:  %.2f' % (test_loss, test_acc))
        # The logger was opened above; close it before returning early.
        logger.close()
        return

    # Train and val
    for epoch in range(start_epoch, args.epochs):
        adjust_learning_rate(optimizer, epoch)

        print('\nEpoch: [%d | %d] LR: %f' % (epoch + 1, args.epochs, state['lr']))

        train_loss, train_acc = train(trainloader, model, criterion, optimizer, epoch, use_cuda)
        test_loss, test_acc = test(testloader, model, criterion, epoch, use_cuda)

        # append logger file
        logger.append([state['lr'], train_loss, test_loss, train_acc, test_acc])

        # save model
        is_best = test_acc > best_acc
        best_acc = max(test_acc, best_acc)
        save_checkpoint({
                'epoch': epoch + 1,
                'state_dict': model.state_dict(),
                'acc': test_acc,
                'best_acc': best_acc,
                'optimizer' : optimizer.state_dict(),
            }, is_best, checkpoint=args.checkpoint)

    logger.close()
    logger.plot()
    savefig(os.path.join(args.checkpoint, 'log.eps'))

    print('Best acc:')
    print(best_acc)
예제 #6
0
                state_dict_quant[k] = v
                continue
            else:
                bits = args.bn_bits
        else:
            bits = args.param_bits

        if args.quant_method == 'linear':
            sf = bits - 1. - quant.compute_integral_part(v, overflow_rate=args.overflow_rate)
            v_quant  = quant.linear_quantize(v, sf, bits=bits)
        elif args.quant_method == 'log':
            v_quant = quant.log_minmax_quantize(v, bits=bits)
        elif args.quant_method == 'minmax':
            v_quant = quant.min_max_quantize(v, bits=bits)
        else:
            v_quant = quant.tanh_quantize(v, bits=bits)
        state_dict_quant[k] = v_quant
        print(k, bits)
    model_raw.load_state_dict(state_dict_quant)

# Quantize forward activations (layer outputs). Skipped entirely when
# fwd_bits is 32 or more.
if args.fwd_bits < 32:
    # Rebind model_raw to the model returned by
    # quant.duplicate_model_with_quant.
    model_raw = quant.duplicate_model_with_quant(model_raw, bits=args.fwd_bits, overflow_rate=args.overflow_rate,
                                                 counter=args.n_sample, type=args.quant_method)
    print(model_raw)
    # Short evaluation pass with batch size 10, limited to n_sample; its
    # return value is discarded.
    # NOTE(review): presumably a calibration pass for the quantized layers --
    # confirm against quant.duplicate_model_with_quant.
    val_ds_tmp = ds_fetcher(10, data_root=args.data_root, train=False, input_size=args.input_size)
    misc.eval_model(model_raw, val_ds_tmp, ngpu=1, n_sample=args.n_sample, is_imagenet=is_imagenet)

# Full evaluation on the validation split; misc.eval_model's result is
# unpacked into acc1 and acc5.
val_ds = ds_fetcher(args.batch_size, data_root=args.data_root, train=False, input_size=args.input_size)
acc1, acc5 = misc.eval_model(model_raw, val_ds, ngpu=args.ngpu, is_imagenet=is_imagenet)