Example #1
def do_masked_retrain(args, model, train_loader, test_loader, sparsity_type, prune_ratios, masks, base_model, masked_path):
    """=============="""
    """masked retrain"""
    """=============="""

    initial_rho = args.rho

    if args.masked_retrain:
        # load admm trained model
        print("Loading: " + base_model)
        model.load_state_dict(torch.load(base_model))
        model.cuda()
            
        if args.optmzr == "adam":
            optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
        if args.optmzr == "sgd":
            optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)

        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=80, gamma=0.1)
        
        ADMM = admm.ADMM(model, sparsity_type, prune_ratios, rho=initial_rho)  # the rho value is irrelevant during masked retraining
        best_prec1 = [0]
        mask = admm.hard_prune(ADMM, model)
        masks.append(mask)
        saved_model_path = ''

        for epoch in range(1, args.epochs * 2 + 1):
            admm.masked_retrain(args, ADMM, model, device, train_loader, optimizer, epoch, masks)
            scheduler.step()  # PyTorch >= 1.1: step the scheduler after the epoch's optimizer updates

            prec1 = test(args, model, device, test_loader)
            if prec1 > max(best_prec1):
                model_file = masked_path + "/cifar10_vgg{}_retrained_acc_{:.3f}_{}rhos_{}.pt".format(
                    args.depth, prec1, args.rho_num, args.config_file)
                print("\n>_ Got better accuracy, saving model with accuracy {:.3f}% now...\n".format(prec1))
                print("Saving model: " + model_file)
                torch.save(model.state_dict(), model_file)
                saved_model_path = model_file
                if len(best_prec1) > 1:  # skip the initial placeholder entry; there is nothing to delete yet
                    print("\n>_ Deleting previous model file with accuracy {:.3f}% now...\n".format(max(best_prec1)))
                    os.remove(masked_path + "/cifar10_vgg{}_retrained_acc_{:.3f}_{}rhos_{}.pt".format(
                        args.depth, max(best_prec1), args.rho_num, args.config_file))

            best_prec1.append(prec1)
                    
        admm.test_sparsity(ADMM, model)

        print("Best Acc: {:.4f}".format(max(best_prec1)))
        return saved_model_path, mask
Example #2
def do_admmtrain(args, model, train_loader, test_loader, sparsity_type, prune_ratios, masks, base_model_path, admm_path):
    """====================="""
    """ multi-rho admm train"""
    """====================="""
    initial_rho = args.rho
    current_rho = initial_rho
    
    if args.admm:
        for i in range(args.rho_num):
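            # each pass multiplies rho by 10, progressively tightening the W ≈ Z consensus penalty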
            current_rho = initial_rho * 10 ** i
            if i == 0:
                print("Loading" + base_model_path)
                model.load_state_dict(torch.load(base_model_path)) # admm train need basline model
                model.cuda()
            else:
                print("Loading: "+admm_path+"/cifar_vgg{}_{}_{}_{}.pt".format(args.depth, current_rho / 10, args.config_file, args.optmzr))
                model.load_state_dict(torch.load(admm_path+"/cifar_vgg{}_{}_{}_{}.pt".format(args.depth, current_rho / 10, args.config_file, args.optmzr)))
                model.cuda()
                
                
            ADMM = admm.ADMM(model, sparsity_type, prune_ratios, rho=current_rho)
            admm.admm_initialization(args, ADMM=ADMM, model=model)  # initialize the Z variables
            
            # admm train
            best_prec1 = 0.
            lr = args.lr / 10
            if args.optmzr == "adam":
                optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=args.weight_decay)
            if args.optmzr == "sgd":
                optimizer = optim.SGD(model.parameters(), lr=lr, momentum=args.momentum, weight_decay=args.weight_decay)
            for epoch in range(1, args.epochs + 1):
                print("current rho: {}".format(current_rho))
                train(args, ADMM, model, device, train_loader, optimizer, epoch, writer,masks)

                prec1 = test(args, model, device, test_loader)
                best_prec1 = max(prec1, best_prec1)
                
            print("Best Acc: {:.4f}".format(best_prec1))
            print("Saving model: " + admm_path+"/cifar_vgg{}_{}_{}_{}.pt".format(args.depth, current_rho, args.config_file, args.optmzr))
            torch.save(model.state_dict(), admm_path+"/cifar_vgg{}_{}_{}_{}.pt".format(args.depth, current_rho, args.config_file, args.optmzr))

    return admm_path+"/cifar_vgg{}_{}_{}_{}.pt".format(args.depth, current_rho, args.config_file, args.optmzr)
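
The quantity `admm.append_admm_loss` adds to the task loss in each pass is the ADMM augmented-Lagrangian penalty. A sketch of that term, with `Z` and `U` as dicts keyed by parameter name (these names are assumptions, not the repo's internals):

import torch

def admm_penalty(model, Z, U, rho):
    # (rho / 2) * sum over layers of ||W - Z + U||_2^2: pulls each layer's
    # weights W toward its auxiliary pruned copy Z; U is the scaled dual.
    penalty = 0.0
    for name, W in model.named_parameters():
        if name in Z:
            penalty = penalty + (rho / 2) * torch.norm(W - Z[name] + U[name], p=2) ** 2
    return penalty

Increasing rho tenfold per pass (the multi-rho loop above) makes this penalty dominate, driving the weights toward the pruned pattern before hard pruning.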
Example #3
def main_worker(gpu, ngpus_per_node, config):
    global best_acc1
    config.gpu = gpu

    if config.gpu is not None:
        print("Use GPU: {} for training".format(config.gpu))

    if config.distributed:
        if config.dist_url == "env://" and config.rank == -1:
            config.rank = int(os.environ["RANK"])
        if config.multiprocessing_distributed:
            # For multiprocessing distributed training, rank needs to be the
            # global rank among all the processes
            config.rank = config.rank * ngpus_per_node + gpu
        dist.init_process_group(backend=config.dist_backend,
                                init_method=config.dist_url,
                                world_size=config.world_size,
                                rank=config.rank)
    # create model
    if config.pretrained:
        print("=> using pre-trained model '{}'".format(config.arch))

        model = models.__dict__[config.arch](pretrained=True)
        print(model)
        module_names = [name for name, _ in model.named_modules()]
        print(module_names)
        param_names = [name for name, _ in model.named_parameters()]
        print(param_names)
    else:
        print("=> creating model '{}'".format(config.arch))
        if config.arch == "alexnet_bn":
            model = AlexNet_BN()
            print(model)
            for name, _ in model.named_parameters():
                print(name)
        else:
            model = models.__dict__[config.arch]()
            print(model)

    if config.distributed:
        # For multiprocessing distributed, DistributedDataParallel constructor
        # should always set the single device scope, otherwise,
        # DistributedDataParallel will use all available devices.
        if config.gpu is not None:
            torch.cuda.set_device(config.gpu)
            model.cuda(config.gpu)
            # When using a single GPU per process and per
            # DistributedDataParallel, we need to divide the batch size
            # ourselves based on the total number of GPUs we have
            config.batch_size = int(config.batch_size / ngpus_per_node)
            config.workers = int(config.workers / ngpus_per_node)
            model = torch.nn.parallel.DistributedDataParallel(
                model, device_ids=[config.gpu])
        else:
            model.cuda()
            # DistributedDataParallel will divide and allocate batch_size to all
            # available GPUs if device_ids are not set
            model = torch.nn.parallel.DistributedDataParallel(model)
    elif config.gpu is not None:
        torch.cuda.set_device(config.gpu)
        model = model.cuda(config.gpu)
    else:
        # DataParallel will divide and allocate batch_size to all available GPUs
        if config.arch.startswith('alexnet') or config.arch.startswith('vgg'):
            model.features = torch.nn.DataParallel(model.features)
            model.cuda()
        else:

            model = torch.nn.DataParallel(model).cuda()
    config.model = model
    # define loss function (criterion) and optimizer

    criterion = CrossEntropyLossMaybeSmooth(smooth_eps=config.smooth_eps).cuda(
        config.gpu)

    config.smooth = config.smooth_eps > 0.0
    config.mixup = config.alpha > 0.0

    # note that loading a pretrain model does not inherit optimizer info
    # will use resume to resume admm training
    if config.load_model:
        if os.path.isfile(config.load_model):
            if config.gpu is not None:  # 'if config.gpu:' would wrongly skip GPU 0
                model.load_state_dict(
                    torch.load(
                        config.load_model,
                        map_location={'cuda:0': 'cuda:{}'.format(config.gpu)}))
            else:
                model.load_state_dict(torch.load(config.load_model))
        else:
            print("=> no checkpoint found at '{}'".format(config.load_model))

    config.prepare_pruning()

    # report the current compression rate over the pruned conv layers
    nonzero = 0
    zero = 0
    for name, W in model.named_parameters():
        if name in config.conv_names:
            W = W.cpu().detach().numpy()
            zero += np.sum(W == 0)
            nonzero += np.sum(W != 0)
    total = nonzero + zero
    print('compression rate is {}'.format(total * 1.0 / nonzero))

    # optionally resume from a checkpoint. The optimizer and ADMM objects are
    # constructed further below, so only the model weights and bookkeeping are
    # restored here; their states are re-applied after construction.
    resume_checkpoint = None
    if config.resume:
        if os.path.isfile(config.resume):
            print("=> loading checkpoint '{}'".format(config.resume))
            resume_checkpoint = torch.load(config.resume)
            config.start_epoch = resume_checkpoint['epoch']
            best_acc1 = resume_checkpoint['best_acc1']
            model.load_state_dict(resume_checkpoint['state_dict'])
            print("=> loaded checkpoint '{}' (epoch {})".format(
                config.resume, resume_checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(config.resume))

    cudnn.benchmark = True

    # Data loading code
    traindir = os.path.join(config.data, 'train')
    valdir = os.path.join(config.data, 'val')
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    train_dataset = datasets.ImageFolder(
        traindir,
        transforms.Compose([
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ]))

    if config.distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(
            train_dataset)
    else:
        train_sampler = None

    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=config.batch_size,
                                               shuffle=(train_sampler is None),
                                               num_workers=config.workers,
                                               pin_memory=True,
                                               sampler=train_sampler)

    val_loader = torch.utils.data.DataLoader(datasets.ImageFolder(
        valdir,
        transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            normalize,
        ])),
                                             batch_size=config.batch_size,
                                             shuffle=False,
                                             num_workers=config.workers,
                                             pin_memory=True)

    config.warmup = (not config.admm) and config.warmup_epochs > 0
    optimizer_init_lr = config.warmup_lr if config.warmup else config.lr

    optimizer = None
    if (config.optimizer == 'sgd'):
        optimizer = torch.optim.SGD(model.parameters(),
                                    optimizer_init_lr,
                                    momentum=config.momentum,
                                    weight_decay=config.weight_decay)
    elif (config.optimizer == 'adam'):
        optimizer = torch.optim.Adam(model.parameters(), optimizer_init_lr)

    if resume_checkpoint is not None:
        optimizer.load_state_dict(resume_checkpoint['optimizer'])

    scheduler = None
    if config.lr_scheduler == 'cosine':
        scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer,
                                                         T_max=config.epochs *
                                                         len(train_loader),
                                                         eta_min=4e-08)
    elif config.lr_scheduler == 'default':
        # decay the LR by gamma every 30 epochs (the scheduler is stepped once per iteration)
        scheduler = optim.lr_scheduler.StepLR(optimizer,
                                              step_size=30 * len(train_loader),
                                              gamma=0.1)
    else:
        raise Exception("unknown lr scheduler")

    if config.warmup:
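        # ramp the LR linearly from warmup_lr up to lr over the warmup
        # iterations, then hand control to the scheduler selected above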
        scheduler = GradualWarmupScheduler(
            optimizer,
            multiplier=config.lr / config.warmup_lr,
            total_iter=config.warmup_epochs * len(train_loader),
            after_scheduler=scheduler)

    ADMM = None

    if config.verify:
        admm.masking(config)
        admm.test_sparsity(config)
        validate(val_loader, criterion, config)
        return
    if config.admm:
        ADMM = admm.ADMM(config)
        if resume_checkpoint is not None and resume_checkpoint.get('admm'):
            ADMM.ADMM_U = resume_checkpoint['admm']['ADMM_U']
            ADMM.ADMM_Z = resume_checkpoint['admm']['ADMM_Z']

    if config.masked_retrain:
        # make sure small weights are pruned and confirm the acc
        admm.masking(config)
        print("before retrain starts")
        admm.test_sparsity(config)
        validate(val_loader, criterion, config)
    if config.masked_progressive:
        admm.zero_masking(config)
    for epoch in range(config.start_epoch, config.epochs):
        if config.distributed:
            train_sampler.set_epoch(epoch)

        # train for one epoch

        train(train_loader, config, ADMM, criterion, optimizer, scheduler,
              epoch)

        # evaluate on validation set
        acc1 = validate(val_loader, criterion, config)

        # remember best acc@1 and save checkpoint
        is_best = acc1 > best_acc1
        best_acc1 = max(acc1, best_acc1)

        if is_best and not config.admm:  # during ADMM training we skip best-accuracy snapshots
            print('saving new best model {}'.format(config.save_model))
            torch.save(model.state_dict(), config.save_model)

        if not config.multiprocessing_distributed or (
                config.multiprocessing_distributed
                and config.rank % ngpus_per_node == 0):
            save_checkpoint(
                config, {
                    'admm': {'ADMM_U': ADMM.ADMM_U, 'ADMM_Z': ADMM.ADMM_Z} if ADMM is not None else {},
                    'epoch': epoch + 1,
                    'arch': config.arch,
                    'state_dict': model.state_dict(),
                    'best_acc1': best_acc1,
                    'optimizer': optimizer.state_dict(),
                }, is_best)
    # save last model for admm, optimizer detail is not necessary
    if config.save_model and config.admm:
        print('saving model {}'.format(config.save_model))
        torch.save(model.state_dict(), config.save_model)
    if config.masked_retrain:
        print("after masked retrain")
        admm.test_sparsity(config)
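
`GradualWarmupScheduler` is defined elsewhere in the repo; a minimal sketch of the warmup behavior it is expected to provide, under that assumption (an illustration, not the repo's implementation; the real class additionally hands off to `after_scheduler` once the ramp finishes):

from torch.optim.lr_scheduler import _LRScheduler

class LinearWarmup(_LRScheduler):
    # Linearly ramp each param group's LR from its initial value up to
    # initial * multiplier over total_iter scheduler steps, then hold.
    def __init__(self, optimizer, multiplier, total_iter, last_epoch=-1):
        self.multiplier = multiplier
        self.total_iter = total_iter
        super().__init__(optimizer, last_epoch)

    def get_lr(self):
        progress = min(self.last_epoch / max(self.total_iter, 1), 1.0)
        return [base * (1.0 + (self.multiplier - 1.0) * progress)
                for base in self.base_lrs]

With the settings above (optimizer initialized at warmup_lr and multiplier = lr / warmup_lr), the ramp ends exactly at config.lr.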
Example #4
def main():
    # Training settings
    parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
    parser.add_argument('--config_file',
                        type=str,
                        default='',
                        help="config file")
    parser.add_argument('--stage',
                        type=str,
                        default='',
                        help="select the pruning stage")

    args = parser.parse_args()

    config = Config(args)

    use_cuda = True

    torch.manual_seed(1)

    device = torch.device("cuda" if use_cuda else "cpu")

    kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {}
    train_loader = torch.utils.data.DataLoader(datasets.MNIST(
        '../data',
        train=True,
        download=True,
        transform=transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307, ), (0.3081, ))
        ])),
                                               batch_size=64,
                                               shuffle=True,
                                               **kwargs)
    test_loader = torch.utils.data.DataLoader(datasets.MNIST(
        '../data',
        train=False,
        transform=transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307, ), (0.3081, ))
        ])),
                                              batch_size=1000,
                                              shuffle=True,
                                              **kwargs)

    model = None
    if config.arch == 'lenet_bn':
        model = LeNet_BN().to(device)
    elif config.arch == 'lenet':
        model = LeNet().to(device)
    elif config.arch == 'lenet_adv':
        model = LeNet_adv(config.width_multiplier).to(device)
    else:
        raise Exception("unknown model architecture")
    torch.cuda.set_device(config.gpu)
    model.cuda(config.gpu)

    config.model = model

    ADMM = None

    config.prepare_pruning()

    if config.admm:
        ADMM = admm.ADMM(config)

    criterion = CrossEntropyLossMaybeSmooth(smooth_eps=config.smooth_eps).cuda(
        config.gpu)
    config.smooth = config.smooth_eps > 0.0
    config.mixup = config.alpha > 0.0

    config.warmup = (not config.admm) and config.warmup_epochs > 0
    optimizer_init_lr = config.warmup_lr if config.warmup else config.lr

    if (config.optimizer == 'sgd'):
        optimizer = torch.optim.SGD(config.model.parameters(),
                                    optimizer_init_lr,
                                    momentum=0.9,
                                    weight_decay=1e-4)
    elif (config.optimizer == 'adam'):
        optimizer = torch.optim.Adam(config.model.parameters(),
                                     optimizer_init_lr)
    else:
        raise Exception("unknown optimizer")

    scheduler = None
    if config.lr_scheduler == 'cosine':
        scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer,
                                                         T_max=config.epochs *
                                                         len(train_loader),
                                                         eta_min=4e-08)
    elif config.lr_scheduler == 'default':
        pass
    else:
        raise Exception("unknown lr scheduler")

    if config.load_model:
        # unlike resume, loading a model restores neither optimizer state nor start_epoch
        print('==> Loading from {}'.format(config.load_model))
        config.model.load_state_dict(
            torch.load(config.load_model,
                       map_location={'cuda:0': 'cuda:{}'.format(config.gpu)}))
        test(config, device, test_loader)

    global best_acc
    if config.resume:
        if os.path.isfile(config.resume):
            checkpoint = torch.load(config.resume)
            config.start_epoch = checkpoint['epoch']
            best_acc = checkpoint['best_acc']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            print("=> loaded checkpoint '{}' (epoch {})".format(
                config.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(config.resume))

    if config.masked_retrain:
        # make sure small weights are pruned and confirm the acc
        print("<============masking both weights and gradients for retrain")
        admm.masking(config)
        print("<============testing sparsity before retrain")
        admm.test_sparsity(config)
        test(config, device, test_loader)
    if config.masked_progressive:
        admm.zero_masking(config)

    for epoch in range(0, config.epochs + 1):

        train(config, ADMM, device, train_loader, criterion, optimizer,
              scheduler, epoch)
        acc = test(config, device, test_loader)
        best_acc = max(best_acc, acc)  # assumes test() returns the top-1 accuracy, matching the resume bookkeeping
        save_checkpoint(
            config, {
                'epoch': epoch + 1,
                'arch': config.arch,
                'state_dict': config.model.state_dict(),
                'best_acc': best_acc,
                'optimizer': optimizer.state_dict()
            })

    print('overall  best_acc is {}'.format(best_acc))

    if (config.save_model and config.admm):
        print('saving model {}'.format(config.save_model))
        torch.save(config.model.state_dict(), config.save_model)
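
`CrossEntropyLossMaybeSmooth` comes from the repo's utilities; the label smoothing it presumably applies when `smooth_eps > 0` can be sketched functionally (an assumption about its behavior, not its actual code):

import torch.nn.functional as F

def smoothed_cross_entropy(logits, target, smooth_eps=0.0):
    # Plain cross entropy when smooth_eps == 0; otherwise mix the one-hot
    # target with a uniform distribution over the classes.
    num_classes = logits.size(-1)
    if smooth_eps == 0.0:
        return F.cross_entropy(logits, target)
    log_probs = F.log_softmax(logits, dim=-1)
    one_hot = F.one_hot(target, num_classes).float()
    soft_target = one_hot * (1.0 - smooth_eps) + smooth_eps / num_classes
    return -(soft_target * log_probs).sum(dim=-1).mean()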
Example #5
def train(hyp):
    # batch_time = AverageMeter()
    # data_time = AverageMeter()
    # losses = AverageMeter()

    cfg = opt.cfg
    data = opt.data
    epochs = opt.epochs  # 500200 batches at bs 64, 117263 images = 273 epochs
    batch_size = opt.batch_size
    accumulate = max(round(64 / batch_size), 1)  # accumulate n times before optimizer update (bs 64)
    weights = opt.weights  # initial training weights
    imgsz_min, imgsz_max, imgsz_test = opt.img_size  # img sizes (min, max, test)

    # Image Sizes
    gs = 32  # (pixels) grid size
    assert math.fmod(imgsz_min, gs) == 0, '--img-size %g must be a %g-multiple' % (imgsz_min, gs)
    opt.multi_scale |= imgsz_min != imgsz_max  # multi if different (min, max)
    if opt.multi_scale:
        if imgsz_min == imgsz_max:
            imgsz_min //= 1.5
            imgsz_max //= 0.667
        grid_min, grid_max = imgsz_min // gs, imgsz_max // gs
        imgsz_min, imgsz_max = int(grid_min * gs), int(grid_max * gs)
    img_size = imgsz_max  # initialize with max size

    # Configure run
    init_seeds()
    data_dict = parse_data_cfg(data)
    train_path = data_dict['train']
    test_path = data_dict['valid']
    nc = 1 if opt.single_cls else int(data_dict['classes'])  # number of classes
    hyp['cls'] *= nc / 80  # update coco-tuned hyp['cls'] to current dataset

    # Remove previous results
    for f in glob.glob('*_batch*.jpg') + glob.glob(results_file):
        os.remove(f)

    # Initialize model
    model = Darknet(cfg).to(device)

    # Optimizer

    pg0, pg1, pg2 = [], [], []  # optimizer parameter groups
    for k, v in dict(model.named_parameters()).items():
        if '.bias' in k:
            pg2 += [v]  # biases
        elif 'Conv2d.weight' in k:
            pg1 += [v]  # apply weight_decay
        else:
            pg0 += [v]  # all else

    if opt.adam:
        # hyp['lr0'] *= 0.1  # reduce lr (i.e. SGD=5E-3, Adam=5E-4)
        optimizer = optim.Adam(pg0, lr=hyp['lr0'])
        # optimizer = AdaBound(pg0, lr=hyp['lr0'], final_lr=0.1)
    else:
        optimizer = optim.SGD(pg0, lr=hyp['lr0'], momentum=hyp['momentum'], nesterov=True)
    optimizer.add_param_group({'params': pg1, 'weight_decay': hyp['weight_decay']})  # add pg1 with weight_decay
    optimizer.add_param_group({'params': pg2})  # add pg2 (biases)
    print('Optimizer groups: %g .bias, %g Conv2d.weight, %g other' % (len(pg2), len(pg1), len(pg0)))
    del pg0, pg1, pg2

    start_epoch = 0
    best_fitness = 0.0
    # attempt_download(weights)

    
    if opt.freeze_layers:
        output_layer_indices = [idx - 1 for idx, module in enumerate(model.module_list)
                                if isinstance(module, YOLOLayer)]
        freeze_layer_indices = [x for x in range(len(model.module_list)) if
                                (x not in output_layer_indices) and
                                (x - 1 not in output_layer_indices)]
        for idx in freeze_layer_indices:
            for parameter in model.module_list[idx].parameters():
                parameter.requires_grad_(False)


    # Mixed precision training https://github.com/NVIDIA/apex
    if mixed_precision:
        model, optimizer = amp.initialize(model, optimizer, opt_level='O1', verbosity=0)

    # Scheduler https://arxiv.org/pdf/1812.01187.pdf
    lf = lambda x: (((1 + math.cos(x * math.pi / epochs)) / 2) ** 1.0) * 0.95 + 0.05  # cosine decay from 1.0x to 0.05x of lr0
    scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lf)
    scheduler.last_epoch = start_epoch - 1  # see link below
    # https://discuss.pytorch.org/t/a-problem-occured-when-resuming-an-optimizer/28822

    # Plot lr schedule
    # y = []
    # for _ in range(epochs):
    #     scheduler.step()
    #     y.append(optimizer.param_groups[0]['lr'])
    # plt.plot(y, '.-', label='LambdaLR')
    # plt.xlabel('epoch')
    # plt.ylabel('LR')
    # plt.tight_layout()
    # plt.savefig('LR.png', dpi=300)

    # Dataset
    dataset = LoadImagesAndLabels(train_path, img_size, batch_size,
                                  augment=True,
                                  hyp=hyp,  # augmentation hyperparameters
                                  rect=opt.rect,  # rectangular training
                                  cache_images=opt.cache_images,
                                  single_cls=opt.single_cls)

    # Dataloader
    batch_size = min(batch_size, len(dataset))
    nw = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 8])  # number of workers
    dataloader = torch.utils.data.DataLoader(dataset,
                                             batch_size=batch_size,
                                             num_workers=nw,
                                             shuffle=not opt.rect,
                                             # Shuffle=True unless rectangular training is used
                                             pin_memory=True,
                                             collate_fn=dataset.collate_fn)

    # Testloader
    testloader = torch.utils.data.DataLoader(LoadImagesAndLabels(test_path, imgsz_test, batch_size,
                                                                 hyp=hyp,
                                                                 rect=True,
                                                                 cache_images=opt.cache_images,
                                                                 single_cls=opt.single_cls),
                                             batch_size=batch_size,
                                             num_workers=nw,
                                             pin_memory=True,
                                             collate_fn=dataset.collate_fn)

    initial_rho = opt.rho
    t0 = time.time()
    """====================="""
    """ multi-rho admm train"""
    """====================="""
    if opt.admm:
        opt.notest = True
        # possible weights are '*.pt', 'yolov3-spp.pt', 'yolov3-tiny.pt' etc.
        chkpt = torch.load(weights, map_location=device)

        # load model
        try:
            # chkpt['model'] = {k: v for k, v in chkpt['model'].items() if model.state_dict()[k].numel() == v.numel()}
            model.load_state_dict(chkpt['model'], strict=False)
        except Exception as e:
            s = "%s is not compatible with %s. Specify --weights '' or specify a --cfg compatible with %s. " \
                "See https://github.com/ultralytics/yolov3/issues/657" % (opt.weights, opt.cfg, opt.weights)
            print(e)
            raise KeyError(s) from e

        del chkpt

        # Initialize distributed training
        if device.type != 'cpu' and torch.cuda.device_count() > 1 and torch.distributed.is_available():
            dist.init_process_group(backend='nccl',  # 'distributed backend'
                                    init_method='tcp://127.0.0.1:9999',  # distributed training init method
                                    world_size=1,  # number of nodes for distributed training
                                    rank=0)  # distributed training node rank
            model = torch.nn.parallel.DistributedDataParallel(model, find_unused_parameters=True)
            model.yolo_layers = model.module.yolo_layers  # move yolo layer indices to top level


        # Model parameters
        model.nc = nc  # attach number of classes to model
        model.hyp = hyp  # attach hyperparameters to model
        model.gr = 1.0  # giou loss ratio (obj_loss = 1.0 or giou)
        model.class_weights = labels_to_class_weights(dataset.labels, nc).to(device)  # attach class weights

        # Model EMA
        ema = torch_utils.ModelEMA(model)

        # Start training
        nb = len(dataloader)  # number of batches
        n_burn = max(int(0.7 * nb), 500)  # burn-in iterations, max(0.7 epochs, 500 iterations)
        maps = np.zeros(nc)  # mAP per class
        # torch.autograd.set_detect_anomaly(True)
        results = (0, 0, 0, 0, 0, 0, 0)  # 'P', 'R', 'mAP', 'F1', 'val GIoU', 'val Objectness', 'val Classification'

        print('Image sizes %g - %g train, %g test' % (imgsz_min, imgsz_max, imgsz_test))
        print('Using %g dataloader workers' % nw)
        print('Starting training for %g epochs...' % epochs)



        for rho_idx in range(opt.rho_num):
            current_rho = initial_rho * 10 ** rho_idx
            ADMM = admm.ADMM(model, file_name="./prune_config/" + opt.config_file + ".yaml", rho=current_rho)
            admm.admm_initialization(opt, ADMM=ADMM, model=model)  # initialize the Z variables

            for epoch in range(start_epoch, epochs):  # epoch ------------------------------------------------------------------
                print("current rho: {}".format(current_rho))

                model.train()
                masks = {}
                if opt.masked_retrain and not opt.combine_progressive:
                    print("full acc re-train masking")

                    for name, W in (model.module.named_parameters() if type(
                            model) is torch.nn.parallel.DistributedDataParallel else model.named_parameters()):
                        if name not in ADMM.prune_ratios:
                            continue
                        above_threshold, pruned_W = admm.weight_pruning(opt, W, ADMM.prune_ratios[name])
                        W.data = pruned_W  # write the pruned weights back into the live parameter
                        masks[name] = above_threshold
                elif opt.combine_progressive:
                    print("progressive admm-train/re-train masking")
                    for name, W in (model.module.named_parameters() if type(
                            model) is torch.nn.parallel.DistributedDataParallel else model.named_parameters()):
                        non_zeros = (W.cpu().detach().numpy() != 0).astype(np.float32)
                        masks[name] = torch.from_numpy(non_zeros).cuda()  # 1.0 wherever the weight is currently nonzero

                # Update image weights (optional)
                if dataset.image_weights:
                    w = model.class_weights.cpu().numpy() * (1 - maps) ** 2  # class weights
                    image_weights = labels_to_image_weights(dataset.labels, nc=nc, class_weights=w)
                    dataset.indices = random.choices(range(dataset.n), weights=image_weights, k=dataset.n)  # rand weighted idx

                mloss = torch.zeros(4).to(device)  # mean losses
                print(('\n' + '%10s' * 8) % ('Epoch', 'gpu_mem', 'GIoU', 'obj', 'cls', 'total', 'targets', 'img_size'))
                pbar = tqdm(enumerate(dataloader), total=nb)  # progress bar
                for i, (imgs, targets, paths, _) in pbar:  # batch -------------------------------------------------------------

                    ni = i + nb * epoch  # number integrated batches (since train start)
                    imgs = imgs.to(device).float() / 255.0  # uint8 to float32, 0 - 255 to 0.0 - 1.0
                    targets = targets.to(device)

                    # Burn-in
                    if ni <= n_burn:
                        xi = [0, n_burn]  # x interp
                        model.gr = np.interp(ni, xi, [0.0, 1.0])  # giou loss ratio (obj_loss = 1.0 or giou)
                        accumulate = max(1, np.interp(ni, xi, [1, 64 / batch_size]).round())
                        for j, x in enumerate(optimizer.param_groups):
                            # bias lr falls from 0.1 to lr0, all other lrs rise from 0.0 to lr0
                            x['lr'] = np.interp(ni, xi, [0.1 if j == 2 else 0.0, x['initial_lr'] * lf(epoch)])
                            x['weight_decay'] = np.interp(ni, xi, [0.0, hyp['weight_decay'] if j == 1 else 0.0])
                            if 'momentum' in x:
                                x['momentum'] = np.interp(ni, xi, [0.9, hyp['momentum']])

                    # Multi-Scale
                    if opt.multi_scale:
                        if ni / accumulate % 1 == 0:  #  adjust img_size (67% - 150%) every 1 batch
                            img_size = random.randrange(grid_min, grid_max + 1) * gs
                        sf = img_size / max(imgs.shape[2:])  # scale factor
                        if sf != 1:
                            ns = [math.ceil(x * sf / gs) * gs for x in imgs.shape[2:]]  # new shape (stretched to 32-multiple)
                            imgs = F.interpolate(imgs, size=ns, mode='bilinear', align_corners=False)

                    # Forward
                    pred = model(imgs)

                    # Loss
                    loss, loss_items = compute_loss(pred, targets, model)
                    if not torch.isfinite(loss):
                        print('WARNING: non-finite loss, ending training ', loss_items)
                        return results

                    # Backward
                    loss *= batch_size / 64  # scale loss


                    # ADMM: update the Z and U variables, then append the augmented-Lagrangian term
                    admm.z_u_update(opt, ADMM, model, device, dataloader, optimizer, epoch, imgs, i, tb_writer)
                    loss, admm_loss, mixed_loss = admm.append_admm_loss(opt, ADMM, model, loss)

                    if mixed_precision:
                        with amp.scale_loss(mixed_loss, optimizer) as scaled_loss:
                            scaled_loss.backward()
                    else:
                        mixed_loss.backward()

                    if opt.combine_progressive:
                        with torch.no_grad():
                            for name, W in (model.module.named_parameters() if type(
                                    model) is torch.nn.parallel.DistributedDataParallel else model.named_parameters()):
                                if name in masks:
                                    W.grad *= masks[name]

                    # Optimize
                    if ni % accumulate == 0:
                        optimizer.step()
                        optimizer.zero_grad()
                        ema.update(model)

                    # Print
                    mloss = (mloss * i + loss_items) / (i + 1)  # update mean losses
                    mem = '%.3gG' % (torch.cuda.memory_reserved() / 1E9 if torch.cuda.is_available() else 0)  # (GB); memory_cached() is deprecated
                    s = ('%10s' * 2 + '%10.3g' * 6) % ('%g/%g' % (epoch, epochs - 1), mem, *mloss, len(targets), img_size)
                    pbar.set_description(s)

                    # Plot
                    # if ni < 1:
                    #     f = 'train_batch%g.jpg' % i  # filename
                        # res = plot_images(images=imgs, targets=targets, paths=paths, fname=f)
                        # if tb_writer:
                        #     tb_writer.add_image(f, res, dataformats='HWC', global_step=epoch)
                        #     # tb_writer.add_graph(model, imgs)  # add model to tensorboard

                    # end batch ------------------------------------------------------------------------------------------------

                # Update scheduler
                if opt.admm:
                    admm.admm_adjust_learning_rate(optimizer, epoch, opt)
                else:
                    scheduler.step()

                # Process epoch results
                ema.update_attr(model)
                final_epoch = epoch + 1 == epochs
                if not opt.notest:  # Calculate mAP  #or final_epoch
                    is_coco = any([x in data for x in ['coco.data', 'coco2014.data', 'coco2017.data']]) and model.nc == 80
                    results, maps = test.test(cfg,
                                              data,
                                              batch_size=batch_size,
                                              imgsz=imgsz_test,
                                              model=ema.ema,
                                              save_json=final_epoch and is_coco,
                                              single_cls=opt.single_cls,
                                              dataloader=testloader,
                                              multi_label=ni > n_burn)

                # Write
                with open(results_file, 'a') as f:
                    f.write(s + '%10.3g' * 7 % results + '\n')  # P, R, mAP, F1, test_losses=(GIoU, obj, cls)
                if len(opt.name) and opt.bucket:
                    os.system('gsutil cp results.txt gs://%s/results/results%s.txt' % (opt.bucket, opt.name))

                # Tensorboard
                if tb_writer:
                    tags = ['train/giou_loss', 'train/obj_loss', 'train/cls_loss',
                            'metrics/precision', 'metrics/recall', 'metrics/mAP_0.5', 'metrics/F1',
                            'val/giou_loss', 'val/obj_loss', 'val/cls_loss']
                    for x, tag in zip(list(mloss[:-1]) + list(results), tags):
                        tb_writer.add_scalar(tag, x, epoch)

                # Update best mAP
                fi = fitness(np.array(results).reshape(1, -1))  # fitness_i = weighted combination of [P, R, mAP, F1]
                if fi > best_fitness:
                    best_fitness = fi

                # end epoch ----------------------------------------------------------------------------------------------------
            # end training

            # final learning-rate adjustment after the last ADMM epoch
            admm.admm_adjust_learning_rate(optimizer, epoch, opt)

            print("Saving model.")
            torch.save(
                model.module.state_dict() if type(model) is nn.parallel.DistributedDataParallel else model.state_dict(),
                "./model_pruned/yolov4_{}_{}_{}.pt".format(
                    current_rho, opt.config_file, opt.sparsity_type))

        if not opt.evolve:
            plot_results()  # save as results.png
        print('%g epochs completed in %.3f hours.\n' % (epoch - start_epoch + 1, (time.time() - t0) / 3600))
        # dist.destroy_process_group() if torch.cuda.device_count() > 1 else None
        # torch.cuda.empty_cache()
        # return results


    """=============="""
    """masked retrain"""
    """=============="""
    if opt.masked_retrain:
        ADMM = admm.ADMM(model, file_name="./prune_config/" + opt.config_file + ".yaml", rho=initial_rho)
        if not opt.resume:
            # possible weights are '*.pt', 'yolov3-spp.pt', 'yolov3-tiny.pt' etc.
            print("\n>_ Loading file: ./model_pruned/yolov4_{}_{}_{}.pt".format(initial_rho * 10 ** (opt.rho_num - 1), opt.config_file, opt.sparsity_type))
            chkpt = torch.load("./model_pruned/yolov4_{}_{}_{}.pt".format(initial_rho * 10 ** (opt.rho_num - 1), opt.config_file, opt.sparsity_type), map_location=device)
            # chkpt = torch.load(weights, map_location=device)
            # load model
            try:
                # chkpt['model'] = {k: v for k, v in chkpt['model'].items() if model.state_dict()[k].numel() == v.numel()}
                model.load_state_dict(chkpt, strict=False)  # the checkpoint saved above is a raw state_dict
            except KeyError as e:
                s = "%s is not compatible with %s. Specify --weights '' or specify a --cfg compatible with %s. " \
                    "See https://github.com/ultralytics/yolov3/issues/657" % (opt.weights, opt.cfg, opt.weights)
                raise KeyError(s) from e
            #----------------------------------------------hard prune------------------------------------------------
            admm.hard_prune(opt, ADMM, model)
            #----------------------------------------------hard prune------------------------------------------------
        else:
            try:
                chkpt = torch.load(weights, map_location=device)
                chkpt['model'] = {k: v for k, v in chkpt['model'].items() if model.state_dict()[k].numel() == v.numel()}
                model.load_state_dict(chkpt['model'], strict=False)
            except KeyError as e:
                s = "%s is not compatible with %s. Specify --weights '' or specify a --cfg compatible with %s. " \
                    "See https://github.com/ultralytics/yolov3/issues/657" % (opt.weights, opt.cfg, opt.weights)
                raise KeyError(s) from e
            # load optimizer
            if chkpt['optimizer'] is not None:
                optimizer.load_state_dict(chkpt['optimizer'])
                best_fitness = chkpt['best_fitness']

            # load results
            if chkpt.get('training_results') is not None:
                with open(results_file, 'w') as file:
                    file.write(chkpt['training_results'])  # write results.txt

            start_epoch = chkpt['epoch'] + 1
        del chkpt

        # Initialize distributed training
        if device.type != 'cpu' and torch.cuda.device_count() > 1 and torch.distributed.is_available():
            dist.init_process_group(backend='nccl',  # 'distributed backend'
                                    init_method='tcp://127.0.0.1:9999',  # distributed training init method
                                    world_size=1,  # number of nodes for distributed training
                                    rank=0)  # distributed training node rank
            model = torch.nn.parallel.DistributedDataParallel(model, find_unused_parameters=True)
            model.yolo_layers = model.module.yolo_layers  # move yolo layer indices to top level

        # Model parameters
        model.nc = nc  # attach number of classes to model
        model.hyp = hyp  # attach hyperparameters to model
        model.gr = 1.0  # giou loss ratio (obj_loss = 1.0 or giou)
        model.class_weights = labels_to_class_weights(dataset.labels, nc).to(device)  # attach class weights

        # Model EMA
        ema = torch_utils.ModelEMA(model)

        # Start training
        nb = len(dataloader)  # number of batches
        n_burn = max(3 * nb, 500)  # burn-in iterations, max(3 epochs, 500 iterations)
        maps = np.zeros(nc)  # mAP per class
        # torch.autograd.set_detect_anomaly(True)
        results = (0, 0, 0, 0, 0, 0, 0)  # 'P', 'R', 'mAP', 'F1', 'val GIoU', 'val Objectness', 'val Classification'
        print('Image sizes %g - %g train, %g test' % (imgsz_min, imgsz_max, imgsz_test))
        print('Using %g dataloader workers' % nw)
        print('Starting training for %g epochs...' % epochs)
        for epoch in range(start_epoch, epochs):  # epoch ------------------------------------------------------------------
            model.train()

            if opt.masked_retrain and not opt.combine_progressive:
                print("full acc re-train masking")
                masks = {}
                for name, W in (model.module.named_parameters() if type(
                        model) is torch.nn.parallel.DistributedDataParallel else model.named_parameters()):
                    if name not in ADMM.prune_ratios:
                        continue
                    above_threshold, pruned_W = admm.weight_pruning(opt, W, ADMM.prune_ratios[name])
                    W.data = pruned_W  # write the pruned weights back into the live parameter
                    masks[name] = above_threshold
            elif opt.combine_progressive:
                print("progressive admm-train/re-train masking")
                masks = {}
                for name, W in (model.module.named_parameters() if type(
                        model) is torch.nn.parallel.DistributedDataParallel else model.named_parameters()):
                    non_zeros = (W.cpu().detach().numpy() != 0).astype(np.float32)
                    masks[name] = torch.from_numpy(non_zeros).cuda()  # 1.0 wherever the weight is currently nonzero

            # Update image weights (optional)
            if dataset.image_weights:
                w = model.class_weights.cpu().numpy() * (1 - maps) ** 2  # class weights
                image_weights = labels_to_image_weights(dataset.labels, nc=nc, class_weights=w)
                dataset.indices = random.choices(range(dataset.n), weights=image_weights,
                                                 k=dataset.n)  # rand weighted idx

            mloss = torch.zeros(4).to(device)  # mean losses
            print(('\n' + '%10s' * 8) % ('Epoch', 'gpu_mem', 'GIoU', 'obj', 'cls', 'total', 'targets', 'img_size'))
            pbar = tqdm(enumerate(dataloader), total=nb)  # progress bar
            for i, (imgs, targets, paths, _) in pbar:  # batch -------------------------------------------------------------
                ni = i + nb * epoch  # number integrated batches (since train start)
                imgs = imgs.to(device).float() / 255.0  # uint8 to float32, 0 - 255 to 0.0 - 1.0
                targets = targets.to(device)

                # Burn-in
                if ni <= n_burn:
                    xi = [0, n_burn]  # x interp
                    model.gr = np.interp(ni, xi, [0.0, 1.0])  # giou loss ratio (obj_loss = 1.0 or giou)
                    accumulate = max(1, np.interp(ni, xi, [1, 64 / batch_size]).round())
                    for j, x in enumerate(optimizer.param_groups):
                        # bias lr falls from 0.1 to lr0, all other lrs rise from 0.0 to lr0
                        x['lr'] = np.interp(ni, xi, [0.1 if j == 2 else 0.0, x['initial_lr'] * lf(epoch)])
                        x['weight_decay'] = np.interp(ni, xi, [0.0, hyp['weight_decay'] if j == 1 else 0.0])
                        if 'momentum' in x:
                            x['momentum'] = np.interp(ni, xi, [0.9, hyp['momentum']])

                # Multi-Scale
                if opt.multi_scale:
                    if ni / accumulate % 1 == 0:  # adjust img_size (67% - 150%) every 1 batch
                        img_size = random.randrange(grid_min, grid_max + 1) * gs
                    sf = img_size / max(imgs.shape[2:])  # scale factor
                    if sf != 1:
                        ns = [math.ceil(x * sf / gs) * gs for x in
                              imgs.shape[2:]]  # new shape (stretched to 32-multiple)
                        imgs = F.interpolate(imgs, size=ns, mode='bilinear', align_corners=False)

                # Forward
                pred = model(imgs)

                # Loss
                loss, loss_items = compute_loss(pred, targets, model)
                if not torch.isfinite(loss):
                    print('WARNING: non-finite loss, ending training ', loss_items)
                    return results

                # Backward
                loss *= batch_size / 64  # scale loss
                if mixed_precision:
                    with amp.scale_loss(loss, optimizer) as scaled_loss:
                        scaled_loss.backward()
                else:
                    loss.backward()

                if opt.combine_progressive or opt.masked_retrain:
                    with torch.no_grad():
                        for name, W in (model.module.named_parameters() if type(
                                model) is torch.nn.parallel.DistributedDataParallel else model.named_parameters()):
                            if name in masks:
                                W.grad *= masks[name]

                # Optimize
                if ni % accumulate == 0:
                    optimizer.step()
                    optimizer.zero_grad()
                    ema.update(model)

                # Print
                mloss = (mloss * i + loss_items) / (i + 1)  # update mean losses
                mem = '%.3gG' % (torch.cuda.memory_reserved() / 1E9 if torch.cuda.is_available() else 0)  # (GB); memory_cached() is deprecated
                s = ('%10s' * 2 + '%10.3g' * 6) % (
                '%g/%g' % (epoch, epochs - 1), mem, *mloss, len(targets), img_size)
                pbar.set_description(s)

                # Plot
                if ni < 1:
                    f = 'train_batch%g.jpg' % i  # filename
                    res = plot_images(images=imgs, targets=targets, paths=paths, fname=f)
                    if tb_writer:
                        tb_writer.add_image(f, res, dataformats='HWC', global_step=epoch)
                        # tb_writer.add_graph(model, imgs)  # add model to tensorboard

                # end batch ------------------------------------------------------------------------------------------------

            # Update scheduler
            scheduler.step()

            # Process epoch results
            ema.update_attr(model)
            final_epoch = epoch + 1 == epochs
            if not opt.notest or final_epoch:  # Calculate mAP
                is_coco = any(
                    [x in data for x in ['coco.data', 'coco2014.data', 'coco2017.data']]) and model.nc == 80
                results, maps = test.test(cfg,
                                          data,
                                          batch_size=batch_size,
                                          imgsz=imgsz_test,
                                          model=ema.ema,
                                          save_json=final_epoch and is_coco,
                                          single_cls=opt.single_cls,
                                          dataloader=testloader,
                                          multi_label=ni > n_burn)

            # Write
            with open(results_file, 'a') as f:
                f.write(s + '%10.3g' * 7 % results + '\n')  # P, R, mAP, F1, test_losses=(GIoU, obj, cls)
            if len(opt.name) and opt.bucket:
                os.system('gsutil cp results.txt gs://%s/results/results%s.txt' % (opt.bucket, opt.name))

            # Tensorboard
            if tb_writer:
                tags = ['train/giou_loss', 'train/obj_loss', 'train/cls_loss',
                        'metrics/precision', 'metrics/recall', 'metrics/mAP_0.5', 'metrics/F1',
                        'val/giou_loss', 'val/obj_loss', 'val/cls_loss']
                for x, tag in zip(list(mloss[:-1]) + list(results), tags):
                    tb_writer.add_scalar(tag, x, epoch)

            # Update best mAP
            fi = fitness(np.array(results).reshape(1, -1))  # fitness_i = weighted combination of [P, R, mAP, F1]
            if fi > best_fitness:
                best_fitness = fi
                print("\n>_ Got better accuracy {:.3f}% now...\n".format(results[2]))
                # torch.save(ema.ema.module.state_dict() if hasattr(model, 'module') else ema.ema.state_dict(),
                #            "./model_retrained/yolov4_retrained_acc_{:.3f}_{}rhos_{}_{}.pt".format(results[2], opt.rho_num, opt.config_file, opt.sparsity_type))

            # Save model
            save = (not opt.nosave) or (final_epoch and not opt.evolve)
            if save:
                with open(results_file, 'r') as f:  # create checkpoint
                    chkpt = {'epoch': epoch,
                             'best_fitness': best_fitness,
                             'training_results': f.read(),
                             'model': ema.ema.module.state_dict() if hasattr(model,
                                                                             'module') else ema.ema.state_dict(),
                             'optimizer': None if final_epoch else optimizer.state_dict()}

                # Save last, best and delete
                torch.save(chkpt, last)
                if (best_fitness == fi) and not final_epoch:
                    torch.save(chkpt, best)
                del chkpt

            # end epoch ----------------------------------------------------------------------------------------------------
        # end training

        test_sparsity(model)
        print("Best Acc: {:.4f}".format(results[2]))
        n = opt.name
        if len(n):
            n = '_' + n if not n.isnumeric() else n
            fresults, flast, fbest = 'results%s.txt' % n, wdir + 'last%s.pt' % n, wdir + 'best%s.pt' % n
            for f1, f2 in zip([wdir + 'last.pt', wdir + 'best.pt', 'results.txt'], [flast, fbest, fresults]):
                if os.path.exists(f1):
                    os.rename(f1, f2)  # rename
                    ispt = f2.endswith('.pt')  # is *.pt
                    strip_optimizer(f2) if ispt else None  # strip optimizer
                    os.system('gsutil cp %s gs://%s/weights' % (
                    f2, opt.bucket)) if opt.bucket and ispt else None  # upload

        if not opt.evolve:
            plot_results()  # save as results.png
        print('%g epochs completed in %.3f hours.\n' % (epoch - start_epoch + 1, (time.time() - t0) / 3600))
        # dist.destroy_process_group() if torch.cuda.device_count() > 1 else None
        # torch.cuda.empty_cache()
        return results
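
`admm.weight_pruning` above returns a 0/1 mask plus the pruned weights; for the irregular (element-wise) sparsity type this reduces to magnitude thresholding, sketched here under that assumption (the names are illustrative, not the repo's API):

import torch

def magnitude_prune(W, prune_ratio):
    # Zero the prune_ratio fraction of smallest-magnitude entries of W;
    # returns (mask, pruned_W), where mask is 1.0 where a weight survives.
    flat = W.detach().abs().flatten()
    k = int(prune_ratio * flat.numel())
    if k == 0:
        return torch.ones_like(W), W.detach().clone()
    threshold = torch.kthvalue(flat, k).values
    mask = (W.detach().abs() > threshold).float()
    return mask, W.detach() * mask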
Example #6
    config.model = torch.nn.DataParallel(model)
    cudnn.benchmark = True

    if config.load_model:
        # unlike resume, loading a model restores neither optimizer state nor start_epoch
        print('==> Loading from {}'.format(config.load_model))

        config.model.load_state_dict(
            torch.load(config.load_model))  # the network ('net' elsewhere) is stored as config.model

    config.prepare_pruning()  # take the model and prepare the pruning

    ADMM = None

    if config.admm:
        ADMM = admm.ADMM(config)

    if config.resume:
        # Load checkpoint.
        print('==> Resuming from checkpoint..')
        assert os.path.isdir(
            'checkpoint'), 'Error: no checkpoint directory found!'
        checkpoint = torch.load('./checkpoint/ckpt.t7')
        config.model.load_state_dict(checkpoint['net'])
        best_acc = checkpoint['acc']
        start_epoch = checkpoint['epoch']
        if ADMM is not None:
            ADMM.ADMM_U = checkpoint['admm']['ADMM_U']
            ADMM.ADMM_Z = checkpoint['admm']['ADMM_Z']

    criterion = CrossEntropyLossMaybeSmooth(smooth_eps=config.smooth_eps).cuda(
        config.gpu)
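
The resume branch above expects the checkpoint to carry the ADMM dual and auxiliary variables under an 'admm' key; the matching save side would look roughly like this (a sketch mirroring the keys the load code reads):

import torch

def save_admm_checkpoint(path, model, optimizer, ADMM, epoch, best_acc):
    # Bundle everything the resume branch reads back: 'net', 'acc',
    # 'epoch', and the ADMM U/Z tensors.
    torch.save({
        'net': model.state_dict(),
        'optimizer': optimizer.state_dict(),
        'acc': best_acc,
        'epoch': epoch,
        'admm': {'ADMM_U': ADMM.ADMM_U, 'ADMM_Z': ADMM.ADMM_Z},
    }, path)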
Example #7
def run_admm(data_name, data_set, data_end_index, fea_dict, lab_dict, arch_dict, cfg_file, processed_first, next_config_file, ADMM, masks, ep, ck):

    # This function processes the current chunk using the information in cfg_file. In parallel, the next chunk is loaded into CPU memory.

    # Reading chunk-specific cfg file (first argument-mandatory file)
    if not os.path.exists(cfg_file):
        sys.stderr.write('ERROR: The config file %s does not exist!\n' % cfg_file)
        sys.exit(1)  # exit with a non-zero status on error
    else:
        config = configparser.ConfigParser()
        config.read(cfg_file)

    # Setting torch seed
    seed=int(config['exp']['seed'])
    torch.manual_seed(seed)
    random.seed(seed)
    np.random.seed(seed)


    # Reading config parameters
    output_folder=config['exp']['out_folder']
    multi_gpu=strtobool(config['exp']['multi_gpu'])

    to_do=config['exp']['to_do']
    info_file=config['exp']['out_info']

    model=config['model']['model'].split('\n')

    forward_outs=config['forward']['forward_out'].split(',')
    forward_normalize_post=list(map(strtobool,config['forward']['normalize_posteriors'].split(',')))
    forward_count_files=config['forward']['normalize_with_counts_from'].split(',')
    require_decodings=list(map(strtobool,config['forward']['require_decoding'].split(',')))

    use_cuda=strtobool(config['exp']['use_cuda'])
    save_gpumem=strtobool(config['exp']['save_gpumem'])
    is_production=strtobool(config['exp']['production'])

    if to_do=='train':
        batch_size=int(config['batches']['batch_size_train'])

    if to_do=='valid':
        batch_size=int(config['batches']['batch_size_valid'])

    if to_do=='forward':
        batch_size=1


    # ***** Reading the Data********
    if processed_first:  # do all of the ADMM initialization work here

        # Reading all the features and labels for this chunk
        shared_list=[]

        p=threading.Thread(target=read_lab_fea, args=(cfg_file,is_production,shared_list,output_folder,))
        p.start()
        p.join()

        data_name=shared_list[0]
        data_end_index=shared_list[1]
        fea_dict=shared_list[2]
        lab_dict=shared_list[3]
        arch_dict=shared_list[4]
        data_set=shared_list[5]



        # converting numpy tensors into pytorch tensors and put them on GPUs if specified
        if not(save_gpumem) and use_cuda:
           data_set=torch.from_numpy(data_set).float().cuda()
        else:
           data_set=torch.from_numpy(data_set).float()




    # Reading all the features and labels for the next chunk
    shared_list=[]
    p=threading.Thread(target=read_lab_fea, args=(next_config_file,is_production,shared_list,output_folder,))
    p.start()

    # Reading model and initialize networks
    inp_out_dict=fea_dict

    [nns,costs]=model_init(inp_out_dict,model,config,arch_dict,use_cuda,multi_gpu,to_do)

    if processed_first:
        ADMM = admm.ADMM(config, nns)

    # optimizers initialization
    optimizers=optimizer_init(nns,config,arch_dict)


    # pre-training and multi-gpu init
    for net in nns.keys():
        pt_file_arch=config[arch_dict[net][0]]['arch_pretrain_file']

        if pt_file_arch!='none':
            checkpoint_load = torch.load(pt_file_arch)
            nns[net].load_state_dict(checkpoint_load['model_par'])
            optimizers[net].load_state_dict(checkpoint_load['optimizer_par'])
            optimizers[net].param_groups[0]['lr']=float(config[arch_dict[net][0]]['arch_lr']) # loading lr of the cfg file for pt

        if multi_gpu:
            nns[net] = torch.nn.DataParallel(nns[net])


    if to_do=='forward':

        post_file={}
        for out_id in range(len(forward_outs)):
            if require_decodings[out_id]:
                out_file=info_file.replace('.info','_'+forward_outs[out_id]+'_to_decode.ark')
            else:
                out_file=info_file.replace('.info','_'+forward_outs[out_id]+'.ark')
            post_file[forward_outs[out_id]]=open_or_fd(out_file,output_folder,'wb')


    if strtobool(config['exp']['retrain']) and processed_first and strtobool(config['exp']['masked_progressive']):
        # make sure small weights are pruned and confirm the acc
        print("<============masking both weights and gradients for retrain")
        masks = admm.masking(config, ADMM, nns)
        print("<============all masking statistics")
        masks = admm.zero_masking(config, nns)  # note: this overwrites the masks from admm.masking above
        print("<============testing sparsity before retrain")
        admm.test_sparsity(config, nns, ADMM)


    if strtobool(config['exp']['masked_progressive']) and processed_first and strtobool(config['exp']['admm']):
        masks = admm.zero_masking(config, nns)


    # check automatically if the model is sequential
    seq_model=is_sequential_dict(config,arch_dict)

    # ***** Minibatch Processing loop********
    if seq_model or to_do=='forward':
        N_snt=len(data_name)
        N_batches=int(N_snt/batch_size)
    else:
        N_ex_tr=data_set.shape[0]
        N_batches=int(N_ex_tr/batch_size)


    beg_batch=0
    end_batch=batch_size

    snt_index=0
    beg_snt=0


    start_time = time.time()

    # array of sentence lengths
    arr_snt_len=shift(shift(data_end_index, -1,0)-data_end_index,1,0)
    arr_snt_len[0]=data_end_index[0]


    loss_sum=0
    err_sum=0

    inp_dim=data_set.shape[1]
    for i in range(N_batches):

        max_len=0

        if seq_model:

         max_len=int(max(arr_snt_len[snt_index:snt_index+batch_size]))
         inp= torch.zeros(max_len,batch_size,inp_dim).contiguous()


         for k in range(batch_size):

                  snt_len=data_end_index[snt_index]-beg_snt
                  N_zeros=max_len-snt_len

                  # Append a random number of initial zeros; the others go at the end.
                  N_zeros_left=random.randint(0,N_zeros)

                  # randomizing could have a regularization effect
                  inp[N_zeros_left:N_zeros_left+snt_len,k,:]=data_set[beg_snt:beg_snt+snt_len,:]

                  beg_snt=data_end_index[snt_index]
                  snt_index=snt_index+1

        else:
            # features and labels for batch i
            if to_do!='forward':
                inp= data_set[beg_batch:end_batch,:].contiguous()
            else:
                snt_len=data_end_index[snt_index]-beg_snt
                inp= data_set[beg_snt:beg_snt+snt_len,:].contiguous()
                beg_snt=data_end_index[snt_index]
                snt_index=snt_index+1

        # use cuda
        if use_cuda:
            inp=inp.cuda()

        if to_do=='train':
            # Forward input, with autograd graph active
            outs_dict=forward_model(fea_dict,lab_dict,arch_dict,model,nns,costs,inp,inp_out_dict,max_len,batch_size,to_do,forward_outs)

            if strtobool(config['exp']['admm']):
                batch_idx = i + ck
                admm.admm_update(config,ADMM,nns, ep,batch_idx)   # update Z and U
                outs_dict['loss_final'],admm_loss,mixed_loss = admm.append_admm_loss(config,ADMM,nns,outs_dict['loss_final']) # append ADMM loss

            for opt in optimizers.keys():
                optimizers[opt].zero_grad()

            if strtobool(config['exp']['admm']):
                mixed_loss.backward()
            else:
                outs_dict['loss_final'].backward()

            # zero out gradients of pruned weights so they remain pruned
            if strtobool(config['exp']['masked_progressive']) or strtobool(config['exp']['retrain']):
                with torch.no_grad():
                    for net in nns.keys():
                        for name, W in nns[net].named_parameters():
                            if name in masks:
                                W.grad *= masks[name]
                        break  # note: the break means only the first network is masked

            # Gradient Clipping (th 0.1)
            #for net in nns.keys():
            #    torch.nn.utils.clip_grad_norm_(nns[net].parameters(), 0.1)


            for opt in optimizers.keys():
                if not(strtobool(config[arch_dict[opt][0]]['arch_freeze'])):
                    optimizers[opt].step()
        else:
            with torch.no_grad(): # Forward input without autograd graph (save memory)
                outs_dict=forward_model(fea_dict,lab_dict,arch_dict,model,nns,costs,inp,inp_out_dict,max_len,batch_size,to_do,forward_outs)


        if to_do=='forward':
            for out_id in range(len(forward_outs)):

                out_save=outs_dict[forward_outs[out_id]].data.cpu().numpy()

                if forward_normalize_post[out_id]:
                    # read the config file
                    counts = load_counts(forward_count_files[out_id])
                    out_save=out_save-np.log(counts/np.sum(counts))

                # save the output
                write_mat(output_folder,post_file[forward_outs[out_id]], out_save, data_name[i])
        else:
            loss_sum=loss_sum+outs_dict['loss_final'].detach()
            err_sum=err_sum+outs_dict['err_final'].detach()

        # update it to the next batch
        beg_batch=end_batch
        end_batch=beg_batch+batch_size

        # Progress bar
        if to_do == 'train':
            status_string = "Training | (Batch " + str(i+1) + "/" + str(N_batches) + ")" + " | L:" + str(round(loss_sum.cpu().item()/(i+1), 3))
            if i == N_batches-1:
                status_string = "Training | (Batch " + str(i+1) + "/" + str(N_batches) + ")"
        elif to_do == 'valid':
            status_string = "Validating | (Batch " + str(i+1) + "/" + str(N_batches) + ")"
        elif to_do == 'forward':
            status_string = "Forwarding | (Batch " + str(i+1) + "/" + str(N_batches) + ")"

        progress(i, N_batches, status=status_string)

    elapsed_time_chunk=time.time() - start_time

    loss_tot=loss_sum/N_batches
    err_tot=err_sum/N_batches

    # clearing memory
    del inp, outs_dict, data_set

    # save the model
    if to_do=='train':


         for net in nns.keys():
             checkpoint={}
             if multi_gpu:
                checkpoint['model_par']=nns[net].module.state_dict()
             else:
                checkpoint['model_par']=nns[net].state_dict()

             checkpoint['optimizer_par']=optimizers[net].state_dict()

             out_file=info_file.replace('.info','_'+arch_dict[net][0]+'.pkl')
             torch.save(checkpoint, out_file)

    if to_do=='forward':
        for out_name in forward_outs:
            post_file[out_name].close()



    # Write info file
    with open(info_file, "w") as text_file:
        text_file.write("[results]\n")
        if to_do!='forward':
            text_file.write("loss=%s\n" % loss_tot.cpu().numpy())
            text_file.write("err=%s\n" % err_tot.cpu().numpy())
        text_file.write("elapsed_time_chunk=%f\n" % elapsed_time_chunk)



    # Getting the data for the next chunk (read in parallel)
    p.join()
    data_name=shared_list[0]
    data_end_index=shared_list[1]
    fea_dict=shared_list[2]
    lab_dict=shared_list[3]
    arch_dict=shared_list[4]
    data_set=shared_list[5]


    # converting numpy tensors into pytorch tensors and put them on GPUs if specified
    if not(save_gpumem) and use_cuda:
       data_set=torch.from_numpy(data_set).float().cuda()
    else:
       data_set=torch.from_numpy(data_set).float()


    return [data_name,data_set,data_end_index,fea_dict,lab_dict,arch_dict,masks,ADMM]
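
The two admm calls in the training branch above are the core of the method. A hedged sketch of the standard formulation they implement (names and signatures here are illustrative, not the repo's): once per ADMM interval the auxiliary variable Z is re-projected onto the sparsity constraint and the dual U is updated; every batch, the task loss is augmented with a quadratic penalty.

import torch

def admm_step_sketch(rho, named_weights, Z, U, prune_fn, task_loss):
    """Z <- project(W + U), U <- U + W - Z, then augment the task loss."""
    named_weights = list(named_weights)
    for name, W in named_weights:  # dual/auxiliary update (once per ADMM interval)
        if name in Z:
            Z[name] = prune_fn((W + U[name]).detach())
            U[name] = U[name] + W.detach() - Z[name]
    admm_loss = 0.0  # quadratic penalty: (rho/2) * ||W - Z + U||^2 per layer
    for name, W in named_weights:
        if name in Z:
            admm_loss = admm_loss + 0.5 * rho * (W - Z[name] + U[name]).pow(2).sum()
    return task_loss, admm_loss, task_loss + admm_loss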
Example #8
    else:
        config.load_model = config.load_model.replace('w', str(config.w))
        prune_alpha = config._prune_ratios['conv1.weight']
        config.load_model = f"{config.load_model.split('.pt')[0]}_{prune_alpha}.pt"
        config.save_model = f"{config.save_model.split('.pt')[0]}_{prune_alpha}.pt"
    print('==> Loading from {}'.format(config.load_model))

config.model.load_state_dict(torch.load(
    config.load_model))  # note: the network called 'net' elsewhere is named 'model' here

config.prepare_pruning()  # take the model and prepare the pruning

ADMM = None

if config.admm:
    ADMM = admm.ADMM(config, device)

if config.resume:
    # Load checkpoint.
    print('==> Resuming from checkpoint..')
    assert os.path.isdir('checkpoint'), 'Error: no checkpoint directory found!'
    checkpoint = torch.load('./checkpoint/ckpt.t7')
    config.model.load_state_dict(checkpoint['net'])
    best_acc = checkpoint['acc']
    start_epoch = checkpoint['epoch']
    ADMM.ADMM_U = checkpoint['admm']['ADMM_U']
    ADMM.ADMM_Z = checkpoint['admm']['ADMM_Z']

criterion = CrossEntropyLossMaybeSmooth(smooth_eps=config.smooth_eps).cuda(
    config.gpu)
config.smooth = config.smooth_eps > 0.0
Example #9
File: bec.py  Project: rodsveiga/decoders
    def __init__(self, p, _code, **kwargs):
        super().__init__(admm.ADMM(_code.parity_mtx, **kwargs))
Example #10
    model = torch.load(args.pretrained, map_location=device)

criterion = nn.CrossEntropyLoss()

ADMM = None
config = None
if args.admm or args.masked_retrain:                
    config = admm.Config(args, model)
    print(config.prune_ratios)
    for name,_ in model.named_parameters():
        if name in config.prune_ratios:
            print('{} will be pruned'.format(name))
        else:
            print('{} will not be pruned'.format(name))
if args.admm:
    ADMM = admm.ADMM(model, config)
    admm.admm_initialization(args, ADMM, model)  # initialize the Z and U variables

###############################################################################
# Training code
###############################################################################

def repackage_hidden(h):
    """Wraps hidden states in new Tensors, to detach them from their history."""
    if isinstance(h, torch.Tensor):
        return h.detach()
    else:
        return tuple(repackage_hidden(v) for v in h)
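
A quick usage sketch: the values survive, but the returned tensors are cut off from the autograd graph, so gradients stop at the chunk boundary.

import torch
h = (torch.zeros(2, 4, 8, requires_grad=True),
     torch.zeros(2, 4, 8, requires_grad=True))
h = repackage_hidden(h)  # same shapes and values, no gradient history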


# get_batch subdivides the source data into chunks of length args.bptt.
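
A minimal sketch of that get_batch, following the standard PyTorch word-language-model example it is taken from (assumes `source` is a 2-D tensor of shape [sequence_length, batch_size] and bptt is the chunk length):

def get_batch(source, i, bptt):
    seq_len = min(bptt, len(source) - 1 - i)
    data = source[i:i + seq_len]
    target = source[i + 1:i + 1 + seq_len].view(-1)  # next-token targets
    return data, target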
Example #11
def main():
    # Training settings
    parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
    parser.add_argument('--config_file', type=str, default='', help ="config file")
    parser.add_argument('--stage', type=str, default='', help ="select the pruning stage")

    
    args = parser.parse_args()

    config = Config(args)
    
    use_cuda = True


    init = Init_Func(config.init_func)
    

    torch.manual_seed(config.random_seed)
    
    device = torch.device("cuda" if use_cuda else "cpu")

    kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {}

    train_loader = torch.utils.data.DataLoader(
        datasets.MNIST('../data', train=True, download=True,
                       transform=transforms.Compose([
                           transforms.ToTensor()
                           #transforms.Normalize((0.1307,), (0.3081,))
                       ])),
        batch_size=64, shuffle=True, **kwargs)
    test_loader = torch.utils.data.DataLoader(
        datasets.MNIST('../data', train=False, transform=transforms.Compose([
                           transforms.ToTensor()
                           #transforms.Normalize((0.1307,), (0.3081,))
                       ])),
        batch_size=1000, shuffle=True, **kwargs)


    

    model = None
    if config.arch == 'lenet_bn':
        model = LeNet_BN().to(device)
    elif config.arch == 'lenet':
        model = LeNet().to(device)
    elif config.arch == 'lenet_adv':
        model = LeNet_adv(w=config.width_multiplier).to(device)
    if config.arch not in model_names:
        raise Exception("unknown model architecture")

    ### for initialization experiments
    for name, W in model.named_parameters():
        if 'conv' in name and 'bias' not in name:
            print('applying custom init ({}) to {}'.format(config.init_func, name))
            # W.data = torch.nn.init.uniform_(W.data)
            W.data = init.init(W.data)
    model = AttackPGD(model,config)
    #### loading initialization
    '''
    ### for lottery tickets experiments
    read_dict = np.load('lenet_adv_retrained_w16_1_cut.pt_init.npy').item()
    for name,W in model.named_parameters():
        if name not in read_dict:
            continue
        print (name)

        #print ('{} has shape {}'.format(name,read_dict[name].shape))
        print (read_dict[name].shape)
        W.data = torch.from_numpy(read_dict[name])
    '''
    config.model = model



    
    if config.load_model:
        # unlike resume, load model does not care optimizer status or start_epoch
        print('==> Loading from {}'.format(config.load_model))
        config.model.load_state_dict(torch.load(config.load_model, map_location=lambda storage, loc: storage))
        #config.model.load_state_dict(torch.load(config.load_model,map_location = {'cuda:0':'cuda:{}'.format(config.gpu)}))
                

    torch.cuda.set_device(config.gpu)
    config.model.cuda(config.gpu)
    test(config,  device, test_loader)    
    ADMM = None

    config.prepare_pruning()
    
    if config.admm:
        ADMM = admm.ADMM(config)

    optimizer = None
    if (config.optimizer == 'sgd'):
        optimizer = torch.optim.SGD(config.model.parameters(), config.lr,
                                momentum=0.9,
                                    weight_decay=1e-6)

    elif (config.optimizer =='adam'):
        optimizer = torch.optim.Adam(config.model.parameters(),config.lr)    

    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=config.epochs*len(train_loader),eta_min=4e-08)

        
        
                      
    if config.resume:
        if os.path.isfile(config.resume):
            checkpoint = torch.load(config.resume)
            config.start_epoch = checkpoint['epoch']
            best_adv_acc = checkpoint['best_adv_acc']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            print("=> loaded checkpoint '{}' (epoch {})"
                  .format(config.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(config.resume))            

        
    if config.masked_retrain:
        # make sure small weights are pruned and confirm the acc
        print ("<============masking both weights and gradients for retrain")    
        admm.masking(config)

        print ("<============testing sparsity before retrain")
        admm.test_sparsity(config)
        test(config,  device, test_loader)        
    if config.masked_progressive:
        admm.zero_masking(config)

        
    for epoch in range(0, config.epochs+1):

        if config.admm:
            admm.admm_adjust_learning_rate(optimizer, epoch, config)
        else:
            if config.lr_scheduler == 'cosine':
                scheduler.step()
            elif config.lr_scheduler == 'sgd':
                if epoch == 20:
                    config.lr/=10
                    for param_group in optimizer.param_groups:
                        param_group['lr'] = config.lr
            else:
                pass # it uses adam
            
        train(config,ADMM,device, train_loader, optimizer, epoch)
        test(config, device, test_loader)
        

    admm.test_sparsity(config)
    test(config,  device, test_loader)    
    if config.save_model and config.admm:
        print ('saving model {}'.format(config.save_model))
        torch.save(config.model.state_dict(),config.save_model)
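
admm.admm_adjust_learning_rate in the training loop above replaces the usual scheduler during ADMM training. A hedged sketch of the restart-style schedule used in reference ADMM pruning code (the constants here are illustrative): the learning rate resets to config.lr at the start of every ADMM interval and decays in stages inside it.

def admm_adjust_learning_rate_sketch(optimizer, epoch, config):
    """Restart the lr each ADMM interval, decaying by 10x in thirds inside it."""
    admm_epoch = config.admm_epoch  # epochs between Z/U updates (assumed attribute)
    offset = epoch % admm_epoch
    lr = config.lr * (0.1 ** (offset // max(1, admm_epoch // 3)))
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr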
Example #12
def masked_retrain(args, pre_mask, task, train_loader):
    """ 
    bag of tricks set-ups
    """
    initial_rho = args.rho
    criterion = CrossEntropyLossMaybeSmooth(smooth_eps=args.smooth_eps).cuda()
    args.smooth = args.smooth_eps > 0.0
    args.mixup = args.alpha > 0.0

    optimizer_init_lr = args.warmup_lr if args.warmup else args.lr
    optimizer = None
    if args.optmzr == 'sgd':
        optimizer = torch.optim.SGD(model.parameters(),
                                    optimizer_init_lr,
                                    momentum=0.9,
                                    weight_decay=1e-4)
    elif args.optmzr == 'adam':
        optimizer = torch.optim.Adam(model.parameters(), optimizer_init_lr)
    '''
    Set learning rate
    '''
    scheduler = None
    if args.lr_scheduler == 'cosine':
        scheduler = optim.lr_scheduler.CosineAnnealingLR(
            optimizer,
            T_max=args.epochs_mask_retrain * len(train_loader),
            eta_min=4e-08)
    elif args.lr_scheduler == 'default':
        # my learning rate scheduler for cifar, following https://github.com/kuangliu/pytorch-cifar
        epoch_milestones = [65, 100, 130, 190, 220, 250, 280]
        """
        Set the learning rate of each parameter task to the initial lr decayed 
        by gamma once the number of epoch reaches one of the milestones
        """
        scheduler = optim.lr_scheduler.MultiStepLR(
            optimizer,
            milestones=[i * len(train_loader) for i in epoch_milestones],
            gamma=0.5)
    else:
        raise Exception("unknown lr scheduler")

    if args.warmup:
        scheduler = GradualWarmupScheduler(optimizer,
                                           multiplier=args.lr / args.warmup_lr,
                                           total_iter=args.warmup_epochs *
                                           len(train_loader),
                                           after_scheduler=scheduler)
    '''
    load admm trained model
    '''
    save_path = os.path.join(args.save_path_exp, 'task' + str(task))
    print("Loading file: " + save_path + "/prunned_{}{}_{}_{}_{}_{}.pt".format(
        args.arch, args.depth, initial_rho * 10**
        (args.rho_num - 1), args.config_file, args.optmzr, args.sparsity_type))
    model.load_state_dict(
        torch.load(save_path + "/prunned_{}{}_{}_{}_{}_{}.pt".format(
            args.arch, args.depth, initial_rho * 10**(args.rho_num - 1),
            args.config_file, args.optmzr, args.sparsity_type)))

    if args.config_file:
        config = "./profile/" + args.config_file + ".yaml"
    elif args.config_setting:
        config = args.prune_ratios
    else:
        raise Exception("must provide a config setting.")
    ADMM = admm.ADMM(args, model, config=config, rho=initial_rho)
    best_prec1 = [0]
    best_mask = ''
    '''
    Deal with masks
    '''
    if args.heritage_weight or args.adaptive_mask:
        model_backup = copy.deepcopy(model.state_dict())

    if pre_mask:
        pre_mask = mask_reverse(args, pre_mask)
        #test_column_sparsity_mask(pre_mask)
        set_model_mask(model, pre_mask)

    # For every task except the last, prune weights to leave capacity for future tasks
    if task != args.tasks - 1:
        admm.hard_prune(args, ADMM, model)  # prune weights

    if args.adaptive_mask and args.mask:
        admm.hard_prune_mask(args, ADMM, model)  #set submasks

    current_trainable_mask = get_model_mask(model=model)
    current_mask = copy.deepcopy(current_trainable_mask)
    submask = {}

    # if heritage, copy weights back to model
    if args.heritage_weight and args.mask:
        with torch.no_grad():
            for name, W in (model.named_parameters()):
                if name in args.pruned_layer:
                    W.data += model_backup[name].data * args.mask[name].cuda()

    # if adaptive learning, copy selected weights back to model
    if args.adaptive_mask and args.mask:
        with torch.no_grad():

            # mask layer: previous tasks part {0,1}; remaining {0}
            for name, M in (model.named_parameters()):
                if 'mask' in name:
                    weight_name = name.replace('w_mask', 'weight')
                    submask[weight_name] = M.cpu().detach()

            # copy selected weights back to model
            for name, W in (model.named_parameters()):
                if name in args.pruned_layer:
                    '''
                    Why args.mask is used instead of submask:
                    1. it is easier to accumulate model weights; with submask we would
                       also have to back up the weights in args.mask minus submask
                    2. weight selection is already handled by the mask layer
                       (which stays fixed during mask retrain)
                    '''
                    W.data += model_backup[name].data * args.mask[name].cuda()

            # combine submask and current trainable mask
            for name in submask:
                current_mask[name] += submask[name]

            # mask layer: previous tasks part {0,1}; remaining {1}
            for name, M in (model.named_parameters()):
                if 'mask' in name:
                    M.data = current_mask[name.replace('w_mask',
                                                       'weight')].cuda()

        set_adaptive_mask(model, requires_grad=False)

    epoch_loss_dict = {}
    testAcc = []
    '''
    Start pruning
    '''
    for epoch in range(1, args.epochs_mask_retrain + 1):
        prune_train(args, current_trainable_mask, ADMM, train_loader,
                    criterion, optimizer, scheduler, epoch)
        prec1 = pipeline.test_model(args, model)

        if prec1 > max(best_prec1):
            #print("\n>_ Got better accuracy, saving model with accuracy {:.3f}% now...\n".format(prec1))
            torch.save(model.state_dict(), save_path + "/retrained.pt")

        testAcc.append(prec1)

        best_prec1.append(prec1)
        #print("current best acc is: {:.4f}".format(max(best_prec1)))

    print("Best Acc: {:.4f}%".format(max(best_prec1)))
    print('Pruned Mask sparsity')
    test_sparsity_mask(args, current_trainable_mask)

    return current_mask
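
Hedged sketches of the mask helpers called above. The names match the calls (mask_reverse, set_model_mask), but the bodies are illustrative: mask_reverse flips a binary mask over the pruned layers, and set_model_mask zeroes the masked weights in place.

import torch

def mask_reverse_sketch(args, mask):
    # 1s become 0s and vice versa, restricted to the layers being pruned
    return {name: 1.0 - m for name, m in mask.items() if name in args.pruned_layer}

def set_model_mask_sketch(model, mask):
    with torch.no_grad():
        for name, W in model.named_parameters():
            if name in mask:
                W.mul_(mask[name].to(W.device))  # zero out pruned positions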
Example #13
def admm_prune(args, pre_mask, task, train_loader):
    """ 
    bag of tricks set-ups
    """
    initial_rho = args.rho
    criterion = CrossEntropyLossMaybeSmooth(smooth_eps=args.smooth_eps).cuda()
    args.smooth = args.smooth_eps > 0.0
    args.mixup = args.alpha > 0.0

    optimizer_init_lr = args.warmup_lr if args.warmup else args.lr
    optimizer = None
    if args.optmzr == 'sgd':
        optimizer = torch.optim.SGD(model.parameters(),
                                    optimizer_init_lr,
                                    momentum=0.9,
                                    weight_decay=1e-4)
    elif args.optmzr == 'adam':
        optimizer = torch.optim.Adam(model.parameters(), optimizer_init_lr)
    '''
    Set learning rate
    '''
    scheduler = None
    if args.lr_scheduler == 'cosine':
        scheduler = optim.lr_scheduler.CosineAnnealingLR(
            optimizer,
            T_max=args.epochs_prune * len(train_loader),
            eta_min=4e-08)
    elif args.lr_scheduler == 'default':
        # my learning rate scheduler for cifar, following https://github.com/kuangliu/pytorch-cifar
        epoch_milestones = [65, 100, 130, 190, 220, 250, 280]
        """
        Set the learning rate of each parameter task to the initial lr decayed 
        by gamma once the number of epoch reaches one of the milestones
        """
        scheduler = optim.lr_scheduler.MultiStepLR(
            optimizer,
            milestones=[i * len(train_loader) for i in epoch_milestones],
            gamma=0.5)
    else:
        raise Exception("unknown lr scheduler")

    if args.warmup:
        scheduler = GradualWarmupScheduler(optimizer,
                                           multiplier=args.lr / args.warmup_lr,
                                           total_iter=args.warmup_epochs *
                                           len(train_loader),
                                           after_scheduler=scheduler)

    # backup model weights
    if args.heritage_weight or args.adaptive_mask:
        model_backup = copy.deepcopy(model.state_dict())

    # get mask for training & set pre-trained (for previous tasks) weights to be zero
    if pre_mask:
        pre_mask = mask_reverse(args, pre_mask)
        set_model_mask(model, pre_mask)
    '''
    If heritage or adaptive, copy weights back to the model
    (skipped for the first task, which has no previous mask).
    '''
    if args.heritage_weight or args.adaptive_mask:
        if args.mask:
            with torch.no_grad():
                for name, W in (model.named_parameters()):
                    if name in args.pruned_layer:
                        W.data += model_backup[name].data * args.mask[
                            name].cuda()
    '''
    Start Pruning...
    '''
    for i in range(args.rho_num):
        current_rho = initial_rho * 10**i

        if args.config_file:
            config = "./profile/" + args.config_file + ".yaml"
        elif args.config_setting:
            config = args.prune_ratios
        else:
            raise Exception("must provide a config setting.")
        ADMM = admm.ADMM(args, model, config=config, rho=current_rho)
        admm.admm_initialization(args, ADMM=ADMM,
                                 model=model)  # initialize the Z variable

        # admm train
        best_prec1 = 0.

        for epoch in range(1, args.epochs_prune + 1):
            print("current rho: {}".format(current_rho))
            prune_train(args, pre_mask, ADMM, train_loader, criterion,
                        optimizer, scheduler, epoch)

            prec1 = pipeline.test_model(args, model)
            best_prec1 = max(prec1, best_prec1)

        print("Best Acc: {:.4f}%".format(best_prec1))
        save_path = os.path.join(args.save_path_exp, 'task' + str(task))
        torch.save(
            model.state_dict(),
            save_path + "/prunned_{}{}_{}_{}_{}_{}.pt".format(
                args.arch, args.depth, current_rho, args.config_file,
                args.optmzr, args.sparsity_type))
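
Finally, a hedged sketch of admm.admm_initialization as called above. This is illustrative: Z starts as the hard-pruned projection of the current weights and the dual U starts at zero; weight_pruning stands in for whatever projection matches args.sparsity_type and is an assumed helper, not a confirmed API.

import torch

def admm_initialization_sketch(args, ADMM, model):
    if not args.admm:
        return
    for name, W in model.named_parameters():
        if name in ADMM.prune_ratios:
            # project W onto the sparsity constraint set (assumed helper below)
            ADMM.ADMM_Z[name] = weight_pruning(args, W.detach().clone(),
                                               ADMM.prune_ratios[name])
            ADMM.ADMM_U[name] = torch.zeros_like(W)  # dual variable starts at zero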