Пример #1
0
def get_new_scale(dest_pr,
                  input_pr,
                  input_model,
                  input_pix,
                  input_before,
                  input_ratio=1.,
                  flops=False):
    """Search for a scale coefficient whose pruned model hits the target rate.

    Starting from ``input_ratio``, the scale coefficient is nudged up in
    0.005 steps while the model is pruned too aggressively, then nudged down
    in 0.005 steps while it is pruned too little, re-pruning a fresh copy of
    the model after every step.

    Args:
        dest_pr: Target pruning rate (fraction of params/FLOPs removed).
        input_pr: Pruning rate measured for the current scale coefficient.
        input_model: Unpruned model; deep-copied for each trial prune.
        input_pix: Input resolution passed to the FLOPs counter.
        input_before: Baseline param count (or FLOPs when ``flops`` is True).
        input_ratio: Scale coefficient to start the search from.
        flops: When True, match the FLOPs pruning rate instead of params.

    Returns:
        The adjusted scale coefficient, or 1.0 when ``input_pr`` is already
        within 0.002 of the target.
    """
    e_pr = dest_pr

    # Already close enough to the target: reset the scale coefficient.
    if abs(input_pr - e_pr) < 0.002:
        return 1.0

    get_ratio = input_ratio
    v_pr = _trial_prune_rate(input_model, get_ratio, input_pix, input_before,
                             flops)
    # Pruned too much: raise the scale coefficient until at/below target.
    while v_pr > e_pr:
        get_ratio += 0.005
        v_pr = _trial_prune_rate(input_model, get_ratio, input_pix,
                                 input_before, flops)
    # Pruned too little: lower the scale coefficient until at/above target.
    # (The original reached this via a `while ... else:`; since the loop
    # above has no `break`, the else-clause always ran, so the two loops
    # are simply sequential.)
    while v_pr < e_pr:
        get_ratio -= 0.005
        v_pr = _trial_prune_rate(input_model, get_ratio, input_pix,
                                 input_before, flops)

    return get_ratio


def _trial_prune_rate(input_model, ratio, input_pix, input_before, flops):
    """Prune a deep copy of ``input_model`` at ``ratio`` and return the
    resulting pruning rate (params-based, or FLOPs-based when ``flops``)."""
    model_pruning = copy.deepcopy(input_model)
    method1 = DPSS(model_pruning, args.lambda21, args.pr)
    method1.adjust_scale_coe(ratio)
    method1.channel_prune()
    model_pruning = method1.model
    params_pruning = utils.print_model_param_nums(model_pruning.module)
    flops_pruning = utils.count_model_param_flops(model_pruning.module,
                                                  input_pix)
    if flops:
        return 1 - flops_pruning / input_before
    return 1 - params_pruning / input_before
Пример #2
0
def main():
    """Evaluate a pruned checkpoint against the unpruned architecture.

    Parses CLI args, builds the validation loader for CIFAR-10/100 or an
    ImageFolder dataset, loads the pruned model from ``args.model``, then
    prints accuracy plus before/after params and FLOPs ratios.
    """
    global args
    args = parser.parse_args()
    args.cuda = not args.no_cuda and torch.cuda.is_available()

    # initial: seed RNGs for reproducible evaluation
    torch.manual_seed(args.seed)
    if args.cuda:
        torch.cuda.manual_seed(args.seed)

    # Data loading code
    kwargs = {
        'num_workers': args.workers,
        'pin_memory': True
    } if args.cuda else {}
    if args.dataset == 'cifar10':
        val_loader = torch.utils.data.DataLoader(
            datasets.CIFAR10(args.data_dir,
                             train=False,
                             transform=transforms.Compose([
                                 transforms.ToTensor(),
                                 transforms.Normalize([0.4914, 0.4824, 0.4467],
                                                      [0.2471, 0.2435, 0.2616])
                             ])),
            batch_size=args.test_batch_size,
            shuffle=True,
            **kwargs)
        num_classes = 10
        input_pix = 32

    elif args.dataset == 'cifar100':
        val_loader = torch.utils.data.DataLoader(
            datasets.CIFAR100(
                args.data_dir,
                train=False,
                transform=transforms.Compose([
                    transforms.ToTensor(),
                    transforms.Normalize([0.5071, 0.4867, 0.4408],
                                         [0.2675, 0.2565, 0.2761])
                ])),
            batch_size=args.test_batch_size,
            shuffle=True,
            **kwargs)
        num_classes = 100
        input_pix = 32
    else:
        # NOTE(review): train_loader is built but never used in this
        # evaluation-only entry point; kept for parity with the training
        # script.
        train_loader = torch.utils.data.DataLoader(
            datasets.ImageFolder(
                os.path.join(args.data_dir, 'train'),
                transforms.Compose([
                    # transforms.Scale(256),
                    transforms.RandomResizedCrop(224),
                    transforms.RandomHorizontalFlip(),
                    transforms.ToTensor(),
                    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                         std=[0.229, 0.224, 0.225]),
                ])),
            batch_size=args.batch_size,
            shuffle=True,
            **kwargs)
        val_loader = torch.utils.data.DataLoader(
            datasets.ImageFolder(
                os.path.join(args.data_dir, 'val'),
                transforms.Compose([
                    transforms.Resize(256),
                    transforms.CenterCrop(224),
                    transforms.ToTensor(),
                    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                         std=[0.229, 0.224, 0.225]),
                ])),
            batch_size=args.test_batch_size,
            shuffle=True,
            **kwargs)
        num_classes = 1000
        input_pix = 224

    # create model: unpruned baseline used only for params/FLOPs reference
    model = models.__dict__[args.arch]()
    if args.cuda:
        model.cuda()
    pruned_model = torch.load(args.model)

    # define loss function (criterion); move it to the GPU only when CUDA is
    # actually available -- the original unconditional .cuda() crashed on
    # CPU-only hosts even though args.cuda was computed above.
    criterion = nn.CrossEntropyLoss()
    if args.cuda:
        criterion = criterion.cuda()

    # optionally resume from a checkpoint

    cudnn.benchmark = True
    test_accuracy = validate(val_loader, pruned_model, criterion)

    # Compare the unpruned architecture against the loaded pruned model.
    params_before = utils.print_model_param_nums(model)
    flops_before = utils.count_model_param_flops(model, input_pix)
    params_after = utils.print_model_param_nums(pruned_model.module)
    flops_after = utils.count_model_param_flops(pruned_model.module, input_pix)
    print('Before pruning: \n'
          'params: {}\t'
          'flops: {}\n'
          'After pruning: \n'
          'params: {}\t'
          'flops: {}\n'
          'params_ratio: {pratio:.2f}%\t'
          'flops_ratio: {fratio:.2f}%\n'
          'params_rate: {prate:.2f}\t'
          'flops_rate: {frate:.2f}\n'
          'Prec@1: {top1:.4f}\n'
          'Prec@5: {top5:.4f}'.format(
              params_before,
              flops_before,
              params_after,
              flops_after,
              pratio=(params_before - params_after) * 100. / params_before,
              fratio=(flops_before - flops_after) * 100. / flops_before,
              prate=params_before / params_after,
              frate=flops_before / flops_after,
              top1=test_accuracy[0],
              top5=test_accuracy[1]))
Пример #3
0
def main():
    """Train with DPSS channel pruning, re-tuning the scale coefficient per epoch.

    Parses CLI args, opens four log files under a config-derived snapshot
    directory, builds CIFAR-10/100 or ImageFolder loaders, wraps the model in
    DataParallel, then trains for args.epochs.  Each epoch a copy of the model
    is trial-pruned to measure the achieved params/FLOPs pruning rate, and
    get_new_scale adjusts the DPSS scale coefficient toward args.pr.  After
    training, the best checkpoint is reloaded, pruned for real, validated, and
    before/after params, FLOPs and accuracy are logged.
    """
    global args, best_prec1, log, log1, log2, log3
    args = parser.parse_args()
    # Encode the experiment configuration (arch, dataset, lambda21, flops
    # mode, target pruning rate, epochs, timestamp) into the log file names.
    args.save_time = args.arch + '_experiment_' + args.dataset + '_' + '%.4f' % (
        args.lambda21) + '_' + str(args.flops) + '_flops_' + '%.3f' % (
            args.pr) + '_' + '%d' % (args.epochs)
    args.save_dir = os.path.join(
        './snapshot/', args.dataset + '_experiment_' + args.arch + '_' +
        'lambda21' + '_' + '%d' % (args.epochs))
    args.save_time = args.save_time + '_dpss_sigmoid_filter_sum_context_add_scratch_' + time.strftime(
        "%Y%m%d%H%M", time.localtime())
    if not os.path.isdir(args.save_dir):
        os.makedirs(args.save_dir)
    # log: main training log; log1: per-epoch rates; log2: per-epoch layer
    # sparsity ratios; log3: per-epoch sparsity allocation ratios.
    log = open(os.path.join(args.save_dir, '{}.log'.format(args.save_time)),
               'w')
    log1 = open(
        os.path.join(args.save_dir,
                     '{}_rate_epoch.txt'.format(args.save_time)), 'w')
    log2 = open(
        os.path.join(args.save_dir,
                     '{}_sparsity_ratio_epoch.txt'.format(args.save_time)),
        'w')
    log3 = open(
        os.path.join(
            args.save_dir,
            '{}_sparsity_allocation_ratio_epoch.txt'.format(args.save_time)),
        'w')
    model_save_name = os.path.join(args.save_dir, args.save_time)
    args.cuda = not args.no_cuda and torch.cuda.is_available()
    print_log(
        'Parameters setting: \n'
        'epochs: {}\t'
        'lr: {}\n'.format(args.epochs, args.lr), log)

    # initial
    torch.manual_seed(args.seed)
    if args.cuda:
        torch.cuda.manual_seed(args.seed)

    # Data loading code
    kwargs = {
        'num_workers': args.workers,
        'pin_memory': True
    } if args.cuda else {}
    if args.dataset == 'cifar10':
        train_loader = torch.utils.data.DataLoader(datasets.CIFAR10(
            args.data_dir,
            train=True,
            download=True,
            transform=transforms.Compose([
                transforms.Pad(4),
                transforms.RandomCrop(32),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                transforms.Normalize([0.4914, 0.4824, 0.4467],
                                     [0.2471, 0.2435, 0.2616])
            ])),
                                                   batch_size=args.batch_size,
                                                   shuffle=True,
                                                   **kwargs)
        val_loader = torch.utils.data.DataLoader(
            datasets.CIFAR10(args.data_dir,
                             train=False,
                             transform=transforms.Compose([
                                 transforms.ToTensor(),
                                 transforms.Normalize([0.4914, 0.4824, 0.4467],
                                                      [0.2471, 0.2435, 0.2616])
                             ])),
            batch_size=args.test_batch_size,
            shuffle=True,
            **kwargs)
        num_classes = 10
        input_pix = 32

    elif args.dataset == 'cifar100':
        train_loader = torch.utils.data.DataLoader(datasets.CIFAR100(
            args.data_dir,
            train=True,
            download=True,
            transform=transforms.Compose([
                transforms.Pad(4),
                transforms.RandomCrop(32),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                transforms.Normalize([0.5071, 0.4867, 0.4408],
                                     [0.2675, 0.2565, 0.2761])
            ])),
                                                   batch_size=args.batch_size,
                                                   shuffle=True,
                                                   **kwargs)
        val_loader = torch.utils.data.DataLoader(
            datasets.CIFAR100(
                args.data_dir,
                train=False,
                transform=transforms.Compose([
                    transforms.ToTensor(),
                    transforms.Normalize([0.5071, 0.4867, 0.4408],
                                         [0.2675, 0.2565, 0.2761])
                ])),
            batch_size=args.test_batch_size,
            shuffle=True,
            **kwargs)
        num_classes = 100
        input_pix = 32
    else:
        # Fallback: ImageNet-style directory layout (train/ and val/).
        train_loader = torch.utils.data.DataLoader(
            datasets.ImageFolder(
                os.path.join(args.data_dir, 'train'),
                transforms.Compose([
                    # transforms.Scale(256),
                    transforms.RandomResizedCrop(224),
                    transforms.RandomHorizontalFlip(),
                    transforms.ToTensor(),
                    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                         std=[0.229, 0.224, 0.225]),
                ])),
            batch_size=args.batch_size,
            shuffle=True,
            **kwargs)
        val_loader = torch.utils.data.DataLoader(
            datasets.ImageFolder(
                os.path.join(args.data_dir, 'val'),
                transforms.Compose([
                    transforms.Resize(256),
                    transforms.CenterCrop(224),
                    transforms.ToTensor(),
                    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                         std=[0.229, 0.224, 0.225]),
                ])),
            batch_size=args.test_batch_size,
            shuffle=True,
            **kwargs)
        num_classes = 1000
        input_pix = 224

    # create model
    model = models.__dict__[args.arch]()
    if args.cuda:
        model.cuda()

    # multi-gpu
    gpu_num = torch.cuda.device_count()
    print_log('GPU NUM: {:2d}'.format(gpu_num), log)
    # NOTE(review): DataParallel is applied unconditionally, so the rest of
    # this function assumes model.module exists even on a single GPU.
    model = torch.nn.DataParallel(model, list(range(gpu_num))).cuda()
    # if gpu_num > 1:
    #     model = torch.nn.DataParallel(model, list(range(gpu_num))).cuda()

    # define loss function (criterion) and optimizer
    criterion = nn.CrossEntropyLoss().cuda()

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            model.load_state_dict(torch.load(args.resume))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    cudnn.benchmark = True

    if args.evaluate:
        validate(val_loader, model, criterion)
        return

    # `method` applies DPSS during training; baseline params/FLOPs are taken
    # from an untouched copy of the model.
    method = DPSS(model, args.lambda21, args.pr)
    model_pruning = copy.deepcopy(model)
    params_before = utils.print_model_param_nums(model_pruning)
    flops_before = utils.count_model_param_flops(model_pruning, input_pix)
    validate(val_loader, model, criterion)
    optimizer = torch.optim.SGD(model.parameters(),
                                args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)
    next_ratio = 1.
    for epoch in range(args.start_epoch, args.epochs):
        adjust_learning_rate(optimizer, epoch)
        # Trial-prune a copy of the current model to measure the achieved
        # pruning rate without touching the model being trained.
        model_pruning = copy.deepcopy(model)
        method1 = DPSS(model_pruning, args.lambda21, args.pr)
        # method1.adjust_scale_coe(next_ratio)
        method1.channel_prune()
        print_log(str(method1.layer_sparsity_ratio), log2)
        params_pruning = utils.print_model_param_nums(method1.model.module)
        flops_pruning = utils.count_model_param_flops(method1.model.module,
                                                      input_pix)
        pruned_params = 1 - params_pruning / params_before
        pruned_flops = 1 - flops_pruning / flops_before
        # Search for the scale coefficient that brings the pruning rate
        # (params- or FLOPs-based, per args.flops) back to args.pr.
        if not args.flops:
            next_ratio = get_new_scale(args.pr, pruned_params,
                                       copy.deepcopy(model), input_pix,
                                       params_before, next_ratio)
        else:
            next_ratio = get_new_scale(args.pr, pruned_flops,
                                       copy.deepcopy(model), input_pix,
                                       flops_before, next_ratio, args.flops)
        method.adjust_scale_coe(next_ratio)
        # Second trial prune at the new ratio, only to log the resulting
        # sparsity allocation.
        model_pruning1 = copy.deepcopy(model)
        method2 = DPSS(model_pruning1, args.lambda21, args.pr)
        method2.adjust_scale_coe(next_ratio)
        method2.channel_prune()
        print_log(str(method2.layer_sparsity_allocation_ratio), log3)
        print(next_ratio)
        train_loss = train(train_loader, model, criterion, optimizer, epoch,
                           method, True)
        # evaluate on validation set
        prec1 = validate(val_loader, model, criterion)
        # remember best prec@1 and save checkpoint
        is_best = prec1 > best_prec1
        best_prec1 = max(prec1, best_prec1)
        save_checkpoint(model, model_save_name, is_best)

        print_log(
            '{}\t'
            '{trn_loss:.4f}\t'
            '{tst_acc:.4f}\t'
            '{pratio:.4f}\t'
            '{fratio:.4f}\t'.format(
                epoch,
                trn_loss=train_loss,
                tst_acc=prec1,
                pratio=(params_before - params_pruning) / params_before,
                fratio=(flops_before - flops_pruning) / flops_before), log1)
        print_log(
            '{}\t'
            '{trn_loss:.4f}\t'
            '{tst_acc:.4f}\t'
            '{pratio:.4f}\t'
            '{fratio:.4f}\t'.format(
                epoch,
                trn_loss=train_loss,
                tst_acc=prec1,
                pratio=(params_before - params_pruning) / params_before,
                fratio=(flops_before - flops_pruning) / flops_before), log)

    # Reload the best checkpoint and prune it for real.
    print_log('Get new model\n', log)
    model = torch.load(model_save_name + '_best.pth.tar')
    if args.cuda:
        model.cuda()
    method2 = DPSS(model, args.lambda21, args.pr)
    method2.channel_prune()
    model = method2.model
    if args.cuda:
        model.cuda()
    validate(val_loader, model, criterion)
    save_checkpoint(model, model_save_name + '_pruned', False)
    params_after = utils.print_model_param_nums(model.module)
    flops_after = utils.count_model_param_flops(model.module, input_pix)
    print_log(
        'Before pruning: \n'
        'params: {}\t'
        'flops: {}\n'
        'After pruning: \n'
        'params: {}\t'
        'flops: {}\n'
        'params_ratio: {pratio:.2f}%\t'
        'flops_ratio: {fratio:.2f}%\n'
        'params_rate: {prate:.2f}\t'
        'flops_rate: {frate:.2f}\n'
        'Prec@1: {top1:.4f}'.format(
            params_before,
            flops_before,
            params_after,
            flops_after,
            pratio=(params_before - params_after) * 100. / params_before,
            fratio=(flops_before - flops_after) * 100. / flops_before,
            prate=params_before / params_after,
            frate=flops_before / flops_after,
            top1=best_prec1), log)
    log.close()
    log1.close()
    log2.close()
    log3.close()
Пример #4
0
        # building classifier
        self.classifier = nn.Sequential(
            nn.Dropout(0.2),
            nn.Linear(self.last_channel, n_class),
        )

        self._initialize_weights()

    def forward(self, x):
        x = self.features(x)
        x = x.view(-1, self.classifier[1].in_features)
        x = self.classifier(x)
        return x

    def _initialize_weights(self):
        """Initialize every Conv2d with a zero-mean Gaussian scaled by the
        kernel fan (k_h * k_w * in_channels / groups); zero the biases."""
        for module in self.modules():
            if not isinstance(module, nn.Conv2d):
                continue
            kh, kw = module.kernel_size
            fan = kh * kw * module.in_channels / module.groups
            module.weight.data.normal_(0, math.sqrt(2. / fan))
            if module.bias is not None:
                module.bias.data.zero_()


if __name__ == '__main__':
    # Smoke test: instantiate MobileNetV2 on the GPU and report its cost.
    model = mobilenetv2().cuda()
    print(model)
    total_params = utils.print_model_param_nums(model)
    total_flops = utils.count_model_param_flops(model, 224)
    print(total_params)
    print(total_flops)