def main():
    start_epoch = 0
    best_prec1, best_prec5 = 0.0, 0.0

    ckpt = utils.checkpoint(args)
    writer_train = SummaryWriter(args.job_dir + '/run/train')
    writer_test = SummaryWriter(args.job_dir + '/run/test')

    # Data loading
    print('=> Preparing data..')
    logging.info('=> Preparing data..')

    #loader = import_module('data.' + args.dataset).Data(args)

    traindir = os.path.join('/mnt/cephfs_new_wj/cv/ImageNet','ILSVRC2012_img_train')
    valdir = os.path.join('/mnt/cephfs_new_wj/cv/ImageNet','ILSVRC2012_img_val')
    normalize = transforms.Normalize(mean=[0.485,0.456,0.406],std=[0.229,0.224,0.225])
    # train_dataset = datasets.ImageFolder(
    #     traindir,
    #     transforms.Compose([
    #         transforms.RandomResizedCrop(224),
    #         transforms.RandomHorizontalFlip(),
    #         transforms.ToTensor(),
    #         normalize,
    #     ]))

    # train_loader = torch.utils.data.DataLoader(
    #     train_dataset, batch_size=batch_sizes, shuffle=True,
    #     num_workers=8, pin_memory=True, sampler=None)

    val_loader = torch.utils.data.DataLoader(
        datasets.ImageFolder(valdir, transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            normalize,
        ])),
        batch_size=256, shuffle=False,
        num_workers=8, pin_memory=True)

    traindir = os.path.join('/mnt/cephfs_new_wj/cv/ImageNet/', 'ILSVRC2012_img_train_rec')
    valdir = os.path.join('/mnt/cephfs_new_wj/cv/ImageNet/', 'ILSVRC2012_img_val_rec')


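    # Note: batch_size, num_gpu and num_workers are assumed to be module-level globals in the
    # original script; the *_rec directories hold ImageNet record files read by a DALI-style
    # pipeline whose iterators expose reset() at the end of each epoch.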
    train_queue = getTrainValDataset(traindir, valdir, batch_size=batch_size, val_batch_size=batch_size,
                                     num_shards=num_gpu, workers=num_workers)
    valid_queue = getTestDataset(valdir, test_batch_size=batch_size, num_shards=num_gpu,
                                 workers=num_workers)

    #loader = cifar100(args)

    # Create model
    print('=> Building model...')
    logging.info('=> Building model...')
    criterion = nn.CrossEntropyLoss()

    # Fine tune from a checkpoint
    refine = args.refine
    assert refine is not None, 'refine is required'
    checkpoint = torch.load(refine, map_location=torch.device(f"cuda:{args.gpus[0]}"))


    if args.pruned:
        mask = checkpoint['mask']
        model = resnet_56_sparse(has_mask = mask).to(args.gpus[0])
        model.load_state_dict(checkpoint['state_dict_s'])
    else:
        model = prune_resnet(args, checkpoint['state_dict_s'])

    # model = torchvision.models.resnet18()

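    # Measure the pruned model's complexity with ptflops: FLOPs and parameter count for a
    # 3x224x224 input, with a per-layer breakdown printed to stdout.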
    with torch.cuda.device(0):
        flops, params = get_model_complexity_info(model, (3, 224, 224), as_strings=True, print_per_layer_stat=True)
        print('Flops:  ' + flops)
        print('Params: ' + params)
    pruned_dir = args.pruned_dir
    checkpoint_pruned = torch.load(pruned_dir, map_location=torch.device(f"cuda:{args.gpus[0]}"))
    model = torch.nn.DataParallel(model)
    #
    # new_state_dict_pruned = OrderedDict()
    # for k, v in checkpoint_pruned.items():
    #     name = k[7:]
    #     new_state_dict_pruned[name] = v
    # model.load_state_dict(new_state_dict_pruned)

    model.load_state_dict(checkpoint_pruned['state_dict_s'])

    test_prec1, test_prec5 = test(args, valid_queue, model, criterion, writer_test)
    logging.info('Top1 after pruning: %e', test_prec1)
    logging.info('Model size: %e M params', get_parameters_size(model) / 1e6)

    exit()  # stop after the post-prune evaluation; the fine-tuning code below is not reached

    if args.test_only:
        return
    param_s = [param for name, param in model.named_parameters() if 'mask' not in name]
    #optimizer = optim.SGD(model.parameters(), lr=args.lr * 0.00001, momentum=args.momentum,weight_decay=args.weight_decay)
    optimizer = optim.SGD(param_s, lr=1e-5, momentum=args.momentum, weight_decay=args.weight_decay)
    # The StepLR scheduler was immediately overwritten; the cosine schedule below is the one in use.
    # scheduler = StepLR(optimizer, step_size=args.lr_decay_step, gamma=0.1)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, float(args.num_epochs))


    model_kd = None
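    # Optional knowledge-distillation teacher (kd_flag is assumed to be a module-level global):
    # load a ResNet-101 checkpoint, strip the 'module.' prefix that DataParallel adds to every
    # key, and keep the teacher on a second GPU.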
    if kd_flag:
        model_kd = ResNet101()
        ckpt_kd = torch.load('resnet101.t7', map_location=torch.device(f"cuda:{args.gpus[0]}"))
        state_dict_kd = ckpt_kd['net']
        new_state_dict_kd = OrderedDict()
        for k, v in state_dict_kd.items():
            name = k[7:]
            new_state_dict_kd[name] = v
        model_kd.load_state_dict(new_state_dict_kd)
        model_kd = model_kd.to(args.gpus[1])

    resume = args.resume
    if resume:
        print('=> Loading checkpoint {}'.format(resume))
        checkpoint = torch.load(resume, map_location=torch.device(f"cuda:{args.gpus[0]}"))
        start_epoch = checkpoint['epoch']
        best_prec1 = checkpoint['best_prec1']
        model.load_state_dict(checkpoint['state_dict_s'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        scheduler.load_state_dict(checkpoint['scheduler'])
        print('=> Continue from epoch {}...'.format(start_epoch))
    for epoch in range(start_epoch, 60):
        scheduler.step()  # advance the cosine LR schedule once per epoch
        t1 = time.time()
        train(args, train_queue, model, criterion, optimizer, writer_train, epoch, model_kd)
        test_prec1, test_prec5 = test(args, valid_queue, model, criterion, writer_test, epoch)
        t2 = time.time()
        print(f'Epoch {epoch} took {t2 - t1:.1f}s')
        logging.info('TEST Top1: %e Top5: %e ', test_prec1, test_prec5)

        is_best = best_prec1 < test_prec1
        best_prec1 = max(test_prec1, best_prec1)
        best_prec5 = max(test_prec5, best_prec5)

        print(f"=> Best @prec1: {best_prec1:.3f} @prec5: {best_prec5:.3f}")
        logging.info('Best Top1: %e Top5: %e ', best_prec1, best_prec5)

        state = {
            'state_dict_s': model.state_dict(),
            'best_prec1': best_prec1,
            'best_prec5': best_prec5,
            'optimizer': optimizer.state_dict(),
            'scheduler': scheduler.state_dict(),
            'epoch': epoch + 1
        }

        ckpt.save_model(state, epoch + 1, is_best)
        train_queue.reset()
        valid_queue.reset()

    print(f"=> Best @prec1: {best_prec1:.3f} @prec5: {best_prec5:.3f}")
    logging.info('Best Top1: %e Top5: %e ', best_prec1, best_prec5)
Example #2
def prune_resnet(args, state_dict):
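    """Rebuild a resnet_56_sparse student without the blocks whose soft mask is zero.

    Blocks with mask <= args.thre are reported as pruned; blocks whose mask is exactly zero
    are dropped from the state dict and the surviving blocks are re-indexed within each layer.
    The pruned weights and the mask are saved to <job_dir>/pruned.pt.
    """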
    thre = args.thre
    num_layers = int(args.student_model.split('_')[1])  # number of layers, e.g. 56 for resnet_56
    n = (num_layers - 2) // 6  # residual blocks per stage
    layers = np.arange(0, 3 * n, n)  # index of the first block in each of the three stages
 
    mask_block = []
    for name, weight in state_dict.items():
        if 'mask' in name:
            mask_block.append(weight.item())

    pruned_num = sum(m <= thre for m in mask_block)  # number of blocks pruned
    pruned_blocks = [int(m) for m in np.argwhere(np.array(mask_block) <= thre)]  # indices of blocks below the threshold

    old_block = 0
    layer = 'layer1'
    layer_num = int(layer[-1])
    new_block = 0
    new_state_dict = OrderedDict()

    for key, value in state_dict.items():  # keep or drop each weight according to its block's mask
        if 'layer' in key:
            if key.split('.')[0] != layer:
                layer = key.split('.')[0]
                layer_num = int(layer[-1])
                new_block = 0

            if key.split('.')[1] != old_block:
                old_block = key.split('.')[1]

            if mask_block[layers[layer_num-1] + int(old_block)] == 0:  # the block's mask is zero, so it is pruned
                if layer_num != 1 and old_block == '0' and 'mask' in key:
                    new_block = 1
                continue

            new_key = re.sub(r'\.\d+\.', '.{}.'.format(new_block), key, 1)
            if 'mask' in new_key: 
                new_block += 1

            new_state_dict[new_key] = state_dict[key]  # retained blocks are stored under their re-indexed keys

        else:
            new_state_dict[key] = state_dict[key]

    model = resnet_56_sparse(has_mask=mask_block).to(args.gpus[0])  # build the pruned model

    print('\n---- After Prune ----\n')
    print(f"Pruned / Total: {pruned_num} / {len(mask_block)}")
    print("Pruned blocks", pruned_blocks)

    save_dir = f'{args.job_dir}/pruned.pt'
    print(f'Saving pruned model to {save_dir}...')
    
    save_state_dict = {}
    save_state_dict['state_dict_s'] = new_state_dict
    save_state_dict['mask'] = mask_block
    torch.save(save_state_dict, save_dir)

    if not args.random:
        model.load_state_dict(new_state_dict)

    return model
Example #3
def main():
    start_epoch = 0
    best_prec1, best_prec5 = 0.0, 0.0

    ckpt = utils.checkpoint(args)
    writer_train = SummaryWriter(args.job_dir + '/run/train')
    writer_test = SummaryWriter(args.job_dir + '/run/test')

    # Data loading
    print('=> Preparing data..')
    loader = import_module('data.' + args.dataset).Data(args)

    # Create model
    print('=> Building model...')
    criterion = nn.CrossEntropyLoss()

    # Fine tune from a checkpoint
    refine = args.refine
    assert refine is not None, 'refine is required'
    checkpoint = torch.load(refine, map_location=torch.device(f"cuda:{args.gpus[0]}"))
        
    if args.pruned:
        mask = checkpoint['mask']
        pruned = sum(1 for m in mask if m == 0)
        print(f"Pruned / Total: {pruned} / {len(mask)}")
        model = resnet_56_sparse(has_mask = mask).to(args.gpus[0])
        model.load_state_dict(checkpoint['state_dict_s'])
    else:
        model = prune_resnet(args, checkpoint['state_dict_s'])

    test_prec1, test_prec5 = test(args, loader.loader_test, model, criterion, writer_test)
    print(f"Simply test after prune {test_prec1:.3f}")
    
    if args.test_only:
        return 

    if args.keep_grad:
        for name, weight in model.named_parameters():
            if 'mask' in name:
                weight.requires_grad = False

    optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum,weight_decay=args.weight_decay)
    scheduler = StepLR(optimizer, step_size=args.lr_decay_step, gamma=0.1)

    resume = args.resume
    if resume:
        print('=> Loading checkpoint {}'.format(resume))
        checkpoint = torch.load(resume, map_location=torch.device(f"cuda:{args.gpus[0]}"))
        start_epoch = checkpoint['epoch']
        best_prec1 = checkpoint['best_prec1']
        model.load_state_dict(checkpoint['state_dict_s'])  # checkpoints from ckpt.save_model store the weights under 'state_dict_s'
        optimizer.load_state_dict(checkpoint['optimizer'])
        scheduler.load_state_dict(checkpoint['scheduler'])
        print('=> Continue from epoch {}...'.format(start_epoch))

    for epoch in range(start_epoch, args.num_epochs):
        scheduler.step(epoch)

        train(args, loader.loader_train, model, criterion, optimizer, writer_train, epoch)
        test_prec1, test_prec5 = test(args, loader.loader_test, model, criterion, writer_test, epoch)

        is_best_finetune = best_prec1 < test_prec1
        best_prec1 = max(test_prec1, best_prec1)
        best_prec5 = max(test_prec5, best_prec5)

        state = {
            'state_dict_s': model.state_dict(),
            'best_prec1': best_prec1,
            'best_prec5': best_prec5,
            'optimizer': optimizer.state_dict(),
            'scheduler': scheduler.state_dict(),
            'epoch': epoch + 1
        }

        ckpt.save_model(state, epoch + 1, False, is_best_finetune)

    print(f"=> Best @prec1: {best_prec1:.3f} @prec5: {best_prec5:.3f}")
Example #4
File: main.py  Project: yunhengzi/GAL
def main():
    checkpoint = utils.checkpoint(args)
    writer_train = SummaryWriter(args.job_dir + '/run/train')
    writer_test = SummaryWriter(args.job_dir + '/run/test')

    start_epoch = 0
    best_prec1 = 0.0
    best_prec5 = 0.0

    # Data loading
    print('=> Preparing data..')
    loader = cifar10(args)

    # Create model
    print('=> Building model...')
    model_t = resnet_56().to(args.gpus[0])

    # Load teacher model
    ckpt_t = torch.load(args.teacher_dir,
                        map_location=torch.device(f"cuda:{args.gpus[0]}"))
    state_dict_t = ckpt_t['state_dict']
    model_t.load_state_dict(state_dict_t)
    model_t = model_t.to(args.gpus[0])

    for para in list(model_t.parameters())[:-2]:  # freeze the teacher except its final layer's weight and bias
        para.requires_grad = False

    model_s = resnet_56_sparse().to(args.gpus[0])

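    # Initialize the sparse student from the teacher's weights; parameters that exist only in
    # the student (the soft masks) keep their default values.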
    model_dict_s = model_s.state_dict()
    model_dict_s.update(state_dict_t)
    model_s.load_state_dict(model_dict_s)

    if len(args.gpus) != 1:
        model_s = nn.DataParallel(model_s, device_ids=args.gpus)

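    # Discriminator for the adversarial (GAL) objective: it learns to tell the teacher's
    # outputs from the sparse student's, while the student is trained to fool it.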
    model_d = Discriminator().to(args.gpus[0])

    models = [model_t, model_s, model_d]

    optimizer_d = optim.SGD(model_d.parameters(),
                            lr=args.lr,
                            momentum=args.momentum,
                            weight_decay=args.weight_decay)

    param_s = [
        param for name, param in model_s.named_parameters()
        if 'mask' not in name
    ]
    param_m = [
        param for name, param in model_s.named_parameters() if 'mask' in name
    ]

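    # Weights and soft masks get separate optimizers: plain SGD for the weights, and FISTA
    # (a proximal-gradient method) for the masks so the sparsity penalty (sparse_lambda) can
    # drive mask values to exactly zero.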
    optimizer_s = optim.SGD(param_s,
                            lr=args.lr,
                            momentum=args.momentum,
                            weight_decay=args.weight_decay)
    optimizer_m = FISTA(param_m, lr=args.lr, gamma=args.sparse_lambda)

    scheduler_d = StepLR(optimizer_d, step_size=args.lr_decay_step, gamma=0.1)
    scheduler_s = StepLR(optimizer_s, step_size=args.lr_decay_step, gamma=0.1)
    scheduler_m = StepLR(optimizer_m, step_size=args.lr_decay_step, gamma=0.1)

    resume = args.resume
    if resume:
        print('=> Resuming from ckpt {}'.format(resume))
        ckpt = torch.load(resume,
                          map_location=torch.device(f"cuda:{args.gpus[0]}"))
        best_prec1 = ckpt['best_prec1']
        start_epoch = ckpt['epoch']
        model_s.load_state_dict(ckpt['state_dict_s'])
        model_d.load_state_dict(ckpt['state_dict_d'])
        optimizer_d.load_state_dict(ckpt['optimizer_d'])
        optimizer_s.load_state_dict(ckpt['optimizer_s'])
        optimizer_m.load_state_dict(ckpt['optimizer_m'])
        scheduler_d.load_state_dict(ckpt['scheduler_d'])
        scheduler_s.load_state_dict(ckpt['scheduler_s'])
        scheduler_m.load_state_dict(ckpt['scheduler_m'])
        print('=> Continue from epoch {}...'.format(start_epoch))

    optimizers = [optimizer_d, optimizer_s, optimizer_m]
    schedulers = [scheduler_d, scheduler_s, scheduler_m]

    if args.test_only:
        test_prec1, test_prec5 = test(args, loader.loader_test, model_s)
        print('=> Test Prec@1: {:.2f}'.format(test_prec1))
        return

    for epoch in range(start_epoch, args.num_epochs):
        for s in schedulers:
            s.step(epoch)

        train(args, loader.loader_train, models, optimizers, epoch,
              writer_train)
        test_prec1, test_prec5 = test(args, loader.loader_test, model_s)

        is_best = best_prec1 < test_prec1
        best_prec1 = max(test_prec1, best_prec1)
        best_prec5 = max(test_prec5, best_prec5)

        model_state_dict = model_s.module.state_dict() if len(
            args.gpus) > 1 else model_s.state_dict()

        state = {
            'state_dict_s': model_state_dict,
            'state_dict_d': model_d.state_dict(),
            'best_prec1': best_prec1,
            'best_prec5': best_prec5,
            'optimizer_d': optimizer_d.state_dict(),
            'optimizer_s': optimizer_s.state_dict(),
            'optimizer_m': optimizer_m.state_dict(),
            'scheduler_d': scheduler_d.state_dict(),
            'scheduler_s': scheduler_s.state_dict(),
            'scheduler_m': scheduler_m.state_dict(),
            'epoch': epoch + 1
        }
        checkpoint.save_model(state, epoch + 1, is_best)

    print(f"=> Best @prec1: {best_prec1:.3f} @prec5: {best_prec5:.3f}")

    best_model = torch.load(f'{args.job_dir}/checkpoint/model_best.pt',
                            map_location=torch.device(f"cuda:{args.gpus[0]}"))

    model = prune_resnet(args, best_model['state_dict_s'])