Exemplo n.º 1
0
def get_model(dataset, model_class, model_name, pruner="synflow", epoch=0, custom_path="", dense_classifier=False):
    """Rebuild a model architecture and load pruned weights from disk.

    The checkpoint is looked up under ../Results/pruned/ by model class,
    model name, pruner and epoch unless an explicit ``custom_path`` is given.
    Returns the model with the pruned state dict loaded.
    """
    ## Data ##
    print('Loading {} dataset.'.format(dataset))
    input_shape, num_classes = load.dimension(dataset)

    ## Model ##
    print('Creating {}-{} model.'.format(model_class, model_name))
    builder = load.model(model_name, model_class)
    model = builder(input_shape, num_classes, dense_classifier, pretrained=False)

    # An explicit path overrides the conventional checkpoint location.
    if custom_path:
        pruned_path = custom_path
    else:
        pruned_path = "../Results/pruned/%s/%s/%s/%s_prune.pth" % (model_class, model_name, pruner, epoch,)
    print("Loading model from: %s" % (pruned_path))

    model.load_state_dict(torch.load(pruned_path))
    return model
Exemplo n.º 2
0
def run(args):
    """Prune the model once per pruner in ``args.pruner_list`` and save the
    per-unit (channel) score sums of each pruner to ``args.result_dir``.

    Requires ``args.save`` (an expid) so there is a result directory to
    write into.
    """
    if not args.save:
        print("This experiment requires an expid.")
        quit()

    ## Random Seed and Device ##
    torch.manual_seed(args.seed)
    device = load.device(args.gpu)

    ## Data ##
    input_shape, num_classes = load.dimension(args.dataset)
    data_loader = load.dataloader(args.dataset, args.prune_batch_size, True,
                                  args.workers,
                                  args.prune_dataset_ratio * num_classes)

    ## Model, Loss, Optimizer ##
    model = load.model(args.model,
                       args.model_class)(input_shape, num_classes,
                                         args.dense_classifier,
                                         args.pretrained).to(device)
    loss = nn.CrossEntropyLoss()

    ## Compute per Neuron Score ##
    def unit_score_sum(model, scores):
        """Sum pruning scores per channel for every Conv2d layer.

        Returns a tuple ``(in_scores, out_scores)``: concatenated per-channel
        score sums, excluding the last layer's incoming units and the first
        layer's outgoing units so only hidden units remain.
        """
        in_scores = []
        out_scores = []
        for module in model.modules():
            if isinstance(module, layers.Conv2d):
                W = module.weight
                W_score = scores[id(W)].detach().cpu().numpy()
                # Sum over (in_channels, kH, kW) -> one score per output channel;
                # sum over (out_channels, kH, kW) -> one score per input channel.
                in_score = W_score.sum(axis=(1, 2, 3))
                out_score = W_score.sum(axis=(0, 2, 3))

                # Bias scores attach to the output channels of this layer.
                if module.bias is not None:
                    b = module.bias
                    b_score = scores[id(b)].detach().cpu().numpy()
                    in_score += b_score

                in_scores.append(in_score)
                out_scores.append(out_score)

        # Drop boundary layers so only units between two convolutions remain.
        in_scores = np.concatenate(in_scores[:-1])
        out_scores = np.concatenate(out_scores[1:])
        return in_scores, out_scores

    ## Loop through Pruners and Save Data ##
    # Note: removed an unused enumerate index and a dead `unit_scores`
    # accumulator that was appended to but never read.
    for p in args.pruner_list:
        pruner = load.pruner(p)(generator.masked_parameters(
            model, args.prune_bias, args.prune_batchnorm, args.prune_residual))
        # Keep 10^(-compression) of the prunable weights.
        sparsity = 10**(-float(args.compression))
        prune_loop(model, loss, pruner, data_loader, device, sparsity,
                   args.compression_schedule, args.mask_scope,
                   args.prune_epochs, args.reinitialize)
        unit_score = unit_score_sum(model, pruner.scores)
        np.save('{}/{}'.format(args.result_dir, p), unit_score)
Exemplo n.º 3
0
def run(args):
    """Sweep (compression, pruner) pairs: pre-train once, then for each pair
    restore the pre-trained state, prune, post-train, and append accuracy,
    sparsity, FLOP and timing summaries to ``result.log`` in
    ``args.result_dir``.
    """
    ## Random Seed and Device ##
    torch.manual_seed(args.seed)
    device = load.device(args.gpu)

    ## Data ##
    print('Loading {} dataset.'.format(args.dataset))
    input_shape, num_classes = load.dimension(args.dataset)
    prune_loader = load.dataloader(args.dataset, args.prune_batch_size, True,
                                   args.workers,
                                   args.prune_dataset_ratio * num_classes)
    train_loader = load.dataloader(args.dataset, args.train_batch_size, True,
                                   args.workers)
    test_loader = load.dataloader(args.dataset, args.test_batch_size, False,
                                  args.workers)

    log_filename = '{}/{}'.format(args.result_dir, 'result.log')
    fout = open(log_filename, 'w')
    fout.write('start!\n')

    # Fall back to the single --compression / --pruner values when the list
    # arguments were not supplied.
    if args.compression_list == []:
        args.compression_list.append(args.compression)
    if args.pruner_list == []:
        args.pruner_list.append(args.pruner)

    ## Model, Loss, Optimizer ##
    print('Creating {}-{} model.'.format(args.model_class, args.model))
    model = load.model(args.model,
                       args.model_class)(input_shape, num_classes,
                                         args.dense_classifier,
                                         args.pretrained).to(device)
    loss = nn.CrossEntropyLoss()
    opt_class, opt_kwargs = load.optimizer(args.optimizer)
    optimizer = opt_class(generator.parameters(model),
                          lr=args.lr,
                          weight_decay=args.weight_decay,
                          **opt_kwargs)
    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer,
                                                     milestones=args.lr_drops,
                                                     gamma=args.lr_drop_rate)

    ## Pre-Train ##
    print('Pre-Train for {} epochs.'.format(args.pre_epochs))
    pre_result = train_eval_loop(model, loss, optimizer, scheduler,
                                 train_loader, test_loader, device,
                                 args.pre_epochs, args.verbose)

    ## Save Original ##
    # Checkpoint the pre-trained state so every (compression, pruner) pair
    # starts from identical weights, optimizer state and schedule.
    torch.save(model.state_dict(),
               "{}/pre_train_model.pt".format(args.result_dir))
    torch.save(optimizer.state_dict(),
               "{}/pre_train_optimizer.pt".format(args.result_dir))
    torch.save(scheduler.state_dict(),
               "{}/pre_train_scheduler.pt".format(args.result_dir))

    for compression in args.compression_list:
        for p in args.pruner_list:
            # Reset Model, Optimizer, and Scheduler
            print('compression ratio: {} ::: pruner: {}'.format(
                compression, p))
            model.load_state_dict(
                torch.load("{}/pre_train_model.pt".format(args.result_dir),
                           map_location=device))
            optimizer.load_state_dict(
                torch.load("{}/pre_train_optimizer.pt".format(args.result_dir),
                           map_location=device))
            scheduler.load_state_dict(
                torch.load("{}/pre_train_scheduler.pt".format(args.result_dir),
                           map_location=device))

            ## Prune ##
            print('Pruning with {} for {} epochs.'.format(
                p, args.prune_epochs))
            pruner = load.pruner(p)(generator.masked_parameters(
                model, args.prune_bias, args.prune_batchnorm,
                args.prune_residual))
            # Keep 10^(-compression) of the prunable weights.
            sparsity = 10**(-float(compression))
            prune_loop(model, loss, pruner, prune_loader, device, sparsity,
                       args.compression_schedule, args.mask_scope,
                       args.prune_epochs, args.reinitialize,
                       args.prune_train_mode, args.shuffle, args.invert)

            ## Post-Train ##
            print('Post-Training for {} epochs.'.format(args.post_epochs))
            post_result = train_eval_loop(model, loss, optimizer, scheduler,
                                          train_loader, test_loader, device,
                                          args.post_epochs, args.verbose)

            ## Display Results ##
            # First/last epoch of each phase: initial, end of pre-train,
            # start and end of post-train.
            frames = [
                pre_result.head(1),
                pre_result.tail(1),
                post_result.head(1),
                post_result.tail(1)
            ]
            train_result = pd.concat(
                frames, keys=['Init.', 'Pre-Prune', 'Post-Prune', 'Final'])
            # Lambda parameter renamed to avoid shadowing the loop variable `p`.
            prune_result = metrics.summary(
                model, pruner.scores, metrics.flop(model, input_shape, device),
                lambda param: generator.prunable(param, args.prune_batchnorm,
                                                 args.prune_residual))
            # Per-layer sparsity * size/flops summed = remaining params/FLOPs.
            total_params = int(
                (prune_result['sparsity'] * prune_result['size']).sum())
            possible_params = prune_result['size'].sum()
            total_flops = int(
                (prune_result['sparsity'] * prune_result['flops']).sum())
            possible_flops = prune_result['flops'].sum()
            print("Train results:\n", train_result)
            print("Prune results:\n", prune_result)
            print("Parameter Sparsity: {}/{} ({:.4f})".format(
                total_params, possible_params, total_params / possible_params))
            print("FLOP Sparsity: {}/{} ({:.4f})".format(
                total_flops, possible_flops, total_flops / possible_flops))

            ## recording testing time for task 2 ##
            # Times a single forward pass over one test batch (note the
            # break): a latency probe, not a full evaluation.
            start_time = timeit.default_timer()
            model.eval()
            with torch.no_grad():
                for data, target in test_loader:
                    data, target = data.to(device), target.to(device)
                    model(data)
                    break
            end_time = timeit.default_timer()
            print("Testing time: {:.4f}".format(end_time - start_time))

            # BUGFIX: terminate the header line with '\n' so the following
            # 'Train results' entry does not run onto the same log line.
            fout.write('compression ratio: {} ::: pruner: {}\n'.format(
                compression, p))
            fout.write('Train results:\n {}\n'.format(train_result))
            fout.write('Prune results:\n {}\n'.format(prune_result))
            fout.write('Parameter Sparsity: {}/{} ({:.4f})\n'.format(
                total_params, possible_params, total_params / possible_params))
            fout.write("FLOP Sparsity: {}/{} ({:.4f})\n".format(
                total_flops, possible_flops, total_flops / possible_flops))
            fout.write("Testing time: {}\n".format(end_time - start_time))
            fout.write("remaining weights: \n{}\n".format(
                (prune_result['sparsity'] * prune_result['size'])))
            fout.write('flop each layer: {}\n'.format(
                (prune_result['sparsity'] *
                 prune_result['flops']).values.tolist()))
            ## Save Results and Model ##
            if args.save:
                print('Saving results.')
                # Only the per-compression directory is created here; the
                # per-pair pickles/checkpoints are intentionally not written.
                if not os.path.exists('{}/{}'.format(args.result_dir,
                                                     compression)):
                    os.makedirs('{}/{}'.format(args.result_dir, compression))

    fout.close()
Exemplo n.º 4
0
def run(args):
    """Lottery-ticket style iterative magnitude pruning sweep.

    For every (compression exponent, train length, number of prune periods)
    configuration: repeatedly train, magnitude-prune to an interpolated
    sparsity, and rewind the surviving weights to their original values;
    then post-train and pickle the compression and performance results.
    """
    if not args.save:
        print("This experiment requires an expid.")
        quit()

    ## Random Seed and Device ##
    torch.manual_seed(args.seed)
    device = load.device(args.gpu)

    ## Data ##
    print('Loading {} dataset.'.format(args.dataset))
    input_shape, num_classes = load.dimension(args.dataset)
    prune_loader = load.dataloader(args.dataset, args.prune_batch_size, True, args.workers, args.prune_dataset_ratio * num_classes)
    train_loader = load.dataloader(args.dataset, args.train_batch_size, True, args.workers)
    test_loader = load.dataloader(args.dataset, args.test_batch_size, False, args.workers)

    ## Model ##
    print('Creating {} model.'.format(args.model))
    model = load.model(args.model, args.model_class)(input_shape,
                                                     num_classes,
                                                     args.dense_classifier,
                                                     args.pretrained).to(device)
    loss = nn.CrossEntropyLoss()
    opt_class, opt_kwargs = load.optimizer(args.optimizer)
    optimizer = opt_class(generator.parameters(model), lr=args.lr, weight_decay=args.weight_decay, **opt_kwargs)
    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=args.lr_drops, gamma=args.lr_drop_rate)

    ## Save Original ##
    # Initial state: the rewind target for every configuration below.
    torch.save(model.state_dict(), "{}/model.pt".format(args.result_dir))
    torch.save(optimizer.state_dict(), "{}/optimizer.pt".format(args.result_dir))
    torch.save(scheduler.state_dict(), "{}/scheduler.pt".format(args.result_dir))

    ## Train-Prune Loop ##
    comp_exponents = np.arange(0, 6, 1)  # log 10 of inverse sparsity
    train_iterations = [100]  # number of epochs between prune periods
    prune_iterations = np.arange(1, 6, 1)  # number of prune periods
    for comp_exp in comp_exponents[::-1]:
        for train_iters in train_iterations:
            for prune_iters in prune_iterations:
                print('{} compression ratio, {} train epochs, {} prune iterations'.format(comp_exp, train_iters, prune_iters))

                # Reset Model, Optimizer, and Scheduler
                model.load_state_dict(torch.load("{}/model.pt".format(args.result_dir), map_location=device))
                optimizer.load_state_dict(torch.load("{}/optimizer.pt".format(args.result_dir), map_location=device))
                scheduler.load_state_dict(torch.load("{}/scheduler.pt".format(args.result_dir), map_location=device))

                for period in range(prune_iters):

                    # Pre Train Model
                    train_eval_loop(model, loss, optimizer, scheduler, train_loader,
                                    test_loader, device, train_iters, args.verbose)

                    # Prune Model: magnitude pruning with a geometric sparsity
                    # schedule that reaches 10^(-comp_exp) at the last period.
                    pruner = load.pruner('mag')(generator.masked_parameters(model, args.prune_bias, args.prune_batchnorm, args.prune_residual))
                    sparsity = (10**(-float(comp_exp)))**((period + 1) / prune_iters)
                    prune_loop(model, loss, pruner, prune_loader, device, sparsity,
                               args.linear_compression_schedule, args.mask_scope, 1, args.reinitialize)

                    # Reset Model's Weights (rewind): restore trainable weights
                    # from the original checkpoint while keeping the new masks.
                    # BUGFIX: filter the *original* checkpoint's entries; the
                    # previous code filtered model.state_dict() and updated the
                    # model with its own weights, so the rewind was a no-op.
                    original_dict = torch.load("{}/model.pt".format(args.result_dir), map_location=device)
                    original_weights = dict(filter(lambda v: v[1].requires_grad, original_dict.items()))
                    model_dict = model.state_dict()
                    model_dict.update(original_weights)
                    model.load_state_dict(model_dict)

                    # Reset Optimizer and Scheduler
                    optimizer.load_state_dict(torch.load("{}/optimizer.pt".format(args.result_dir), map_location=device))
                    scheduler.load_state_dict(torch.load("{}/scheduler.pt".format(args.result_dir), map_location=device))

                # Prune Result
                prune_result = metrics.summary(model,
                                               pruner.scores,
                                               metrics.flop(model, input_shape, device),
                                               lambda p: generator.prunable(p, args.prune_batchnorm, args.prune_residual))
                # Train Model
                post_result = train_eval_loop(model, loss, optimizer, scheduler, train_loader,
                                              test_loader, device, args.post_epochs, args.verbose)

                # Save Data
                prune_result.to_pickle("{}/compression-{}-{}-{}.pkl".format(args.result_dir, str(comp_exp), str(train_iters), str(prune_iters)))
                post_result.to_pickle("{}/performance-{}-{}-{}.pkl".format(args.result_dir, str(comp_exp), str(train_iters), str(prune_iters)))
def run(args):
    """Track each layer's average squared-weight (magnitude) score across
    training epochs, then save the per-epoch scores and each layer's inverse
    parameter count to ``args.result_dir``.
    """
    if not args.save:
        print("This experiment requires an expid.")
        quit()

    ## Random Seed and Device ##
    torch.manual_seed(args.seed)
    device = load.device(args.gpu)

    ## Data ##
    input_shape, num_classes = load.dimension(args.dataset)
    train_loader = load.dataloader(args.dataset, args.prune_batch_size, True,
                                   args.workers,
                                   args.prune_dataset_ratio * num_classes)

    ## Model, Loss, Optimizer ##
    model = load.model(args.model,
                       args.model_class)(input_shape, num_classes,
                                         args.dense_classifier,
                                         args.pretrained).to(device)
    opt_class, opt_kwargs = load.optimizer(args.optimizer)
    optimizer = opt_class(generator.parameters(model),
                          lr=args.lr,
                          weight_decay=args.weight_decay,
                          **opt_kwargs)
    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer,
                                                     milestones=args.lr_drops,
                                                     gamma=args.lr_drop_rate)
    loss = nn.CrossEntropyLoss()

    ## Compute Layer Name and Inv Size ##
    def layer_names(model):
        """Return (names, 1/param_count) for every Linear/Conv2d layer."""
        names = []
        inv_size = []
        for name, module in model.named_modules():
            if isinstance(module, (layers.Linear, layers.Conv2d)):
                num_elements = np.prod(module.weight.shape)
                if module.bias is not None:
                    num_elements += np.prod(module.bias.shape)
                names.append(name)
                inv_size.append(1.0 / num_elements)
        return names, inv_size

    ## Compute Average Mag Score ##
    def average_mag_score(model):
        """Return the mean squared weight (incl. bias) per Linear/Conv2d layer."""
        average_scores = []
        for module in model.modules():
            if isinstance(module, (layers.Linear, layers.Conv2d)):
                W = module.weight.detach().cpu().numpy()
                W_score = W**2
                score_sum = W_score.sum()
                num_elements = np.prod(W.shape)

                if module.bias is not None:
                    b = module.bias.detach().cpu().numpy()
                    b_score = b**2
                    score_sum += b_score.sum()
                    num_elements += np.prod(b.shape)

                average_scores.append(np.abs(score_sum / num_elements))
        return average_scores

    ## Train and Save Data ##
    # BUGFIX: use the local `model` (the original referenced an undefined
    # `init_model`) and call `average_mag_score` (the original called the
    # undefined `average_score`); both raised NameError at runtime.
    _, inv_size = layer_names(model)
    Wscore = []
    for epoch in tqdm(range(args.post_epochs)):
        # Record scores *before* this epoch's training step.
        Wscore.append(average_mag_score(model))
        train(model, loss, optimizer, train_loader, device, epoch,
              args.verbose)
        scheduler.step()

    np.save('{}/{}'.format(args.result_dir, 'inv-size'), inv_size)
    np.save('{}/score'.format(args.result_dir), np.array(Wscore))
Exemplo n.º 6
0
def run(args):
    """Single-shot pruning experiment: pre-train, prune once with
    ``args.pruner`` to sparsity 10^(-args.compression), post-train, then
    print and optionally pickle the train/prune summaries.
    """
    ## Random Seed and Device ##
    torch.manual_seed(args.seed)
    device = load.device(args.gpu)

    ## Data ##
    print('Loading {} dataset.'.format(args.dataset))
    input_shape, num_classes = load.dimension(args.dataset) 
    prune_loader = load.dataloader(args.dataset, args.prune_batch_size, True, args.workers, args.prune_dataset_ratio * num_classes)
    train_loader = load.dataloader(args.dataset, args.train_batch_size, True, args.workers)
    test_loader = load.dataloader(args.dataset, args.test_batch_size, False, args.workers)

    ## Model, Loss, Optimizer ##
    print('Creating {}-{} model.'.format(args.model_class, args.model))
    model = load.model(args.model, args.model_class)(input_shape, 
                                                     num_classes, 
                                                     args.dense_classifier, 
                                                     args.pretrained).to(device)
    loss = nn.CrossEntropyLoss()
    opt_class, opt_kwargs = load.optimizer(args.optimizer)
    optimizer = opt_class(generator.parameters(model), lr=args.lr, weight_decay=args.weight_decay, **opt_kwargs)
    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=args.lr_drops, gamma=args.lr_drop_rate)


    ## Pre-Train ##
    print('Pre-Train for {} epochs.'.format(args.pre_epochs))
    pre_result = train_eval_loop(model, loss, optimizer, scheduler, train_loader, 
                                 test_loader, device, args.pre_epochs, args.verbose)

    ## Prune ##
    print('Pruning with {} for {} epochs.'.format(args.pruner, args.prune_epochs))
    pruner = load.pruner(args.pruner)(generator.masked_parameters(model, args.prune_bias, args.prune_batchnorm, args.prune_residual))
    # Target: keep 10^(-compression) of the prunable weights.
    sparsity = 10**(-float(args.compression))
    prune_loop(model, loss, pruner, prune_loader, device, sparsity, 
               args.compression_schedule, args.mask_scope, args.prune_epochs, args.reinitialize, args.prune_train_mode)

    
    ## Post-Train ##
    print('Post-Training for {} epochs.'.format(args.post_epochs))
    post_result = train_eval_loop(model, loss, optimizer, scheduler, train_loader, 
                                  test_loader, device, args.post_epochs, args.verbose) 

    ## Display Results ##
    # First/last epoch of each phase: initial, end of pre-train, start and
    # end of post-train.
    frames = [pre_result.head(1), pre_result.tail(1), post_result.head(1), post_result.tail(1)]
    train_result = pd.concat(frames, keys=['Init.', 'Pre-Prune', 'Post-Prune', 'Final'])
    prune_result = metrics.summary(model, 
                                   pruner.scores,
                                   metrics.flop(model, input_shape, device),
                                   lambda p: generator.prunable(p, args.prune_batchnorm, args.prune_residual))
    # Per-layer sparsity * size/flops summed = remaining parameters / FLOPs.
    total_params = int((prune_result['sparsity'] * prune_result['size']).sum())
    possible_params = prune_result['size'].sum()
    total_flops = int((prune_result['sparsity'] * prune_result['flops']).sum())
    possible_flops = prune_result['flops'].sum()
    print("Train results:\n", train_result)
    print("Prune results:\n", prune_result)
    print("Parameter Sparsity: {}/{} ({:.4f})".format(total_params, possible_params, total_params / possible_params))
    print("FLOP Sparsity: {}/{} ({:.4f})".format(total_flops, possible_flops, total_flops / possible_flops))

    ## Save Results and Model ##
    if args.save:
        print('Saving results.')
        pre_result.to_pickle("{}/pre-train.pkl".format(args.result_dir))
        post_result.to_pickle("{}/post-train.pkl".format(args.result_dir))
        prune_result.to_pickle("{}/compression.pkl".format(args.result_dir))
        torch.save(model.state_dict(),"{}/model.pt".format(args.result_dir))
        torch.save(optimizer.state_dict(),"{}/optimizer.pt".format(args.result_dir))
        torch.save(scheduler.state_dict(),"{}/scheduler.pt".format(args.result_dir))
Exemplo n.º 7
0
def run(gpu_id, args):
    """Distributed (one process per GPU, NCCL backend) pruning experiment.

    Pre-trains, checkpoints, then either sweeps every (compression, pruner)
    pair — prune, post-train, log results — or trains the unpruned network
    when ``args.unpruned`` is set. Per-process results go to ``result.log``.

    Parameters:
        gpu_id: this process's rank and the index of the GPU it drives.
        args: parsed experiment arguments (argparse.Namespace-like).
    """

    ## parameters for multi-processing
    print('using gpu', gpu_id)
    # Rank == gpu_id; rendezvous via environment variables (env://).
    dist.init_process_group(backend='nccl',
                            init_method='env://',
                            world_size=args.world_size,
                            rank=gpu_id)

    ## Random Seed and Device ##
    torch.manual_seed(args.seed)
    # device = load.device(args.gpu)
    device = torch.device(gpu_id)

    args.gpu_id = gpu_id

    ## Data ##
    print('Loading {} dataset.'.format(args.dataset))
    input_shape, num_classes = load.dimension(args.dataset)

    ## need to change the workers for loading the data
    # Split the dataloader worker budget over 4 processes, rounding up.
    args.workers = int((args.workers + 4 - 1) / 4)
    print('Adjusted dataloader worker number is ', args.workers)

    # prune_loader = load.dataloader(args.dataset, args.prune_batch_size, True, args.workers,
    #                 args.prune_dataset_ratio * num_classes, world_size=args.world_size, rank=gpu_id)
    prune_loader, _ = load.dataloader(args.dataset, args.prune_batch_size,
                                      True, args.workers,
                                      args.prune_dataset_ratio * num_classes)

    ## need to divide the training batch size for each GPU
    args.train_batch_size = int(args.train_batch_size / args.gpu_count)
    # train_sampler is the distributed sampler; its epoch must be advanced by
    # train_eval_loop (passed below) for proper shuffling across processes.
    train_loader, train_sampler = load.dataloader(args.dataset,
                                                  args.train_batch_size,
                                                  True,
                                                  args.workers,
                                                  args=args)
    # args.test_batch_size = int(args.test_batch_size/args.gpu_count)
    test_loader, _ = load.dataloader(args.dataset, args.test_batch_size, False,
                                     args.workers)

    print("data loader batch size (prune::train::test) is {}::{}::{}".format(
        prune_loader.batch_size, train_loader.batch_size,
        test_loader.batch_size))

    # NOTE(review): every process writes the same result.log path — on a
    # shared filesystem ranks will clobber each other; confirm intended.
    log_filename = '{}/{}'.format(args.result_dir, 'result.log')
    fout = open(log_filename, 'w')
    fout.write('start!\n')

    # Fall back to the single --compression / --pruner values when the list
    # arguments were not supplied.
    if args.compression_list == []:
        args.compression_list.append(args.compression)
    if args.pruner_list == []:
        args.pruner_list.append(args.pruner)

    ## Model, Loss, Optimizer ##
    print('Creating {}-{} model.'.format(args.model_class, args.model))

    # load the pre-defined model from the utils
    model = load.model(args.model, args.model_class)(input_shape, num_classes,
                                                     args.dense_classifier,
                                                     args.pretrained)

    ## wrap model with distributed dataparallel module
    torch.cuda.set_device(gpu_id)
    # model = model.to(device)
    model.cuda(gpu_id)
    model = ddp(model, device_ids=[gpu_id])

    ## don't need to move the loss to the GPU as it contains no parameters
    loss = nn.CrossEntropyLoss()
    opt_class, opt_kwargs = load.optimizer(args.optimizer)
    optimizer = opt_class(generator.parameters(model),
                          lr=args.lr,
                          weight_decay=args.weight_decay,
                          **opt_kwargs)
    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer,
                                                     milestones=args.lr_drops,
                                                     gamma=args.lr_drop_rate)

    ## Pre-Train ##
    print('Pre-Train for {} epochs.'.format(args.pre_epochs))
    pre_result = train_eval_loop(model,
                                 loss,
                                 optimizer,
                                 scheduler,
                                 train_loader,
                                 test_loader,
                                 device,
                                 args.pre_epochs,
                                 args.verbose,
                                 train_sampler=train_sampler)
    print('Pre-Train finished!')

    ## Save Original ##
    # Per-rank checkpoints: filenames carry the gpu_id suffix.
    torch.save(model.state_dict(),
               "{}/pre_train_model_{}.pt".format(args.result_dir, gpu_id))
    torch.save(optimizer.state_dict(),
               "{}/pre_train_optimizer_{}.pt".format(args.result_dir, gpu_id))
    torch.save(scheduler.state_dict(),
               "{}/pre_train_scheduler_{}.pt".format(args.result_dir, gpu_id))

    if not args.unpruned:
        for compression in args.compression_list:
            for p in args.pruner_list:
                # Reset Model, Optimizer, and Scheduler
                print('compression ratio: {} ::: pruner: {}'.format(
                    compression, p))
                model.load_state_dict(
                    torch.load("{}/pre_train_model_{}.pt".format(
                        args.result_dir, gpu_id),
                               map_location=device))
                optimizer.load_state_dict(
                    torch.load("{}/pre_train_optimizer_{}.pt".format(
                        args.result_dir, gpu_id),
                               map_location=device))
                scheduler.load_state_dict(
                    torch.load("{}/pre_train_scheduler_{}.pt".format(
                        args.result_dir, gpu_id),
                               map_location=device))

                ## Prune ##
                print('Pruning with {} for {} epochs.'.format(
                    p, args.prune_epochs))
                pruner = load.pruner(p)(generator.masked_parameters(
                    model, args.prune_bias, args.prune_batchnorm,
                    args.prune_residual))
                # Keep 10^(-compression) of the prunable weights.
                sparsity = 10**(-float(compression))
                prune_loop(model, loss, pruner, prune_loader, device, sparsity,
                           args.compression_schedule, args.mask_scope,
                           args.prune_epochs, args.reinitialize,
                           args.prune_train_mode, args.shuffle, args.invert)

                ## Post-Train ##
                print('Post-Training for {} epochs.'.format(args.post_epochs))
                post_train_start_time = timeit.default_timer()
                post_result = train_eval_loop(model,
                                              loss,
                                              optimizer,
                                              scheduler,
                                              train_loader,
                                              test_loader,
                                              device,
                                              args.post_epochs,
                                              args.verbose,
                                              train_sampler=train_sampler)
                post_train_end_time = timeit.default_timer()
                print("Post Training time: {:.4f}s".format(
                    post_train_end_time - post_train_start_time))

                ## Display Results ##
                # First/last epoch of each phase: initial, end of pre-train,
                # start and end of post-train.
                frames = [
                    pre_result.head(1),
                    pre_result.tail(1),
                    post_result.head(1),
                    post_result.tail(1)
                ]
                train_result = pd.concat(
                    frames, keys=['Init.', 'Pre-Prune', 'Post-Prune', 'Final'])
                # NOTE(review): the lambda parameter `p` shadows the pruner
                # loop variable `p` — harmless here but confusing.
                prune_result = metrics.summary(
                    model, pruner.scores,
                    metrics.flop(model, input_shape, device),
                    lambda p: generator.prunable(p, args.prune_batchnorm, args.
                                                 prune_residual))
                total_params = int(
                    (prune_result['sparsity'] * prune_result['size']).sum())
                possible_params = prune_result['size'].sum()
                total_flops = int(
                    (prune_result['sparsity'] * prune_result['flops']).sum())
                possible_flops = prune_result['flops'].sum()
                print("Train results:\n", train_result)
                # print("Prune results:\n", prune_result)
                # print("Parameter Sparsity: {}/{} ({:.4f})".format(total_params, possible_params, total_params / possible_params))
                # print("FLOP Sparsity: {}/{} ({:.4f})".format(total_flops, possible_flops, total_flops / possible_flops))

                ## recording testing time for task 2 ##
                # evaluating the model, including some data gathering overhead
                # eval(model, loss, test_loader, device, args.verbose)
                # Times forward passes over the whole test set (outputs
                # discarded); accuracy is not computed here.
                model.eval()
                start_time = timeit.default_timer()
                with torch.no_grad():
                    for data, target in test_loader:
                        data, target = data.to(device), target.to(device)
                        temp_eval_out = model(data)
                end_time = timeit.default_timer()
                print("Testing time: {:.4f}s".format(end_time - start_time))

                # NOTE(review): this header write has no trailing '\n', so the
                # next log entry runs onto the same line — confirm intended.
                fout.write('compression ratio: {} ::: pruner: {}'.format(
                    compression, p))
                fout.write('Train results:\n {}\n'.format(train_result))
                fout.write('Prune results:\n {}\n'.format(prune_result))
                fout.write('Parameter Sparsity: {}/{} ({:.4f})\n'.format(
                    total_params, possible_params,
                    total_params / possible_params))
                fout.write("FLOP Sparsity: {}/{} ({:.4f})\n".format(
                    total_flops, possible_flops, total_flops / possible_flops))
                fout.write("Testing time: {}s\n".format(end_time - start_time))
                fout.write("remaining weights: \n{}\n".format(
                    (prune_result['sparsity'] * prune_result['size'])))
                fout.write('flop each layer: {}\n'.format(
                    (prune_result['sparsity'] *
                     prune_result['flops']).values.tolist()))
                ## Save Results and Model ##
                if args.save:
                    print('Saving results.')
                    # Only the per-compression directory is created; the
                    # pickles/checkpoints below are currently disabled.
                    if not os.path.exists('{}/{}'.format(
                            args.result_dir, compression)):
                        os.makedirs('{}/{}'.format(args.result_dir,
                                                   compression))
                    # pre_result.to_pickle("{}/{}/pre-train.pkl".format(args.result_dir, compression))
                    # post_result.to_pickle("{}/{}/post-train.pkl".format(args.result_dir, compression))
                    # prune_result.to_pickle("{}/{}/compression.pkl".format(args.result_dir, compression))
                    # torch.save(model.state_dict(), "{}/{}/model.pt".format(args.result_dir, compression))
                    # torch.save(optimizer.state_dict(),
                    #         "{}/{}/optimizer.pt".format(args.result_dir, compression))
                    # torch.save(scheduler.state_dict(),
                    #         "{}/{}/scheduler.pt".format(args.result_dir, compression))

    else:
        print('Staring Unpruned NN training')
        print('Training for {} epochs.'.format(args.post_epochs))
        # NOTE(review): the checkpoints above were saved with a "_{gpu_id}"
        # suffix, but these loads use un-suffixed filenames
        # ("pre_train_model.pt" etc.) — this will fail to find the files
        # saved by this run; confirm intended filenames.
        model.load_state_dict(
            torch.load("{}/pre_train_model.pt".format(args.result_dir),
                       map_location=device))
        optimizer.load_state_dict(
            torch.load("{}/pre_train_optimizer.pt".format(args.result_dir),
                       map_location=device))
        scheduler.load_state_dict(
            torch.load("{}/pre_train_scheduler.pt".format(args.result_dir),
                       map_location=device))

        train_start_time = timeit.default_timer()
        result = train_eval_loop(model,
                                 loss,
                                 optimizer,
                                 scheduler,
                                 train_loader,
                                 test_loader,
                                 device,
                                 args.post_epochs,
                                 args.verbose,
                                 train_sampler=train_sampler)
        train_end_time = timeit.default_timer()
        frames = [result.head(1), result.tail(1)]
        train_result = pd.concat(frames, keys=['Init.', 'Final'])
        print('Train results:\n', train_result)

    fout.close()

    dist.destroy_process_group()
Exemplo n.º 8
0
def run(args):
    """Weight-shuffle experiment: load a pre-pruned model, use a pruner pass
    purely to shuffle its weights, then post-train and report sparsity/FLOPs.

    Requires ``args.pre_epochs == 0``, ``args.prune_epochs == 0`` and
    ``args.weightshuffle`` set (asserted below).  Results are pickled under
    ``args.result_dir`` when ``args.save`` is set.
    """
    ## Random Seed and Device ##
    torch.manual_seed(args.seed)
    device = load.device(args.gpu, args.seed)

    ## Data ##
    print('Loading {} dataset.'.format(args.dataset))
    input_shape, num_classes = load.dimension(args.dataset)

    # Validation mode holds out 'val' from 'train'; otherwise train on
    # train+val and evaluate on the test split.
    if args.validation:
        trainset = 'train'
        evalset = 'val'
    else:
        trainset = 'trainval'
        evalset = 'test'

    prune_loader = load.dataloader(args.dataset, args.prune_batch_size, trainset, args.workers, corrupt_prob=args.prune_corrupt)
    train_loader = load.dataloader(args.dataset, args.train_batch_size, trainset, args.workers, corrupt_prob=args.train_corrupt)
    test_loader = load.dataloader(args.dataset, args.test_batch_size, evalset, args.workers)

    ## Model, Loss, Optimizer ##
    print('Creating {}-{} model.'.format(args.model_class, args.model))
    model = load.model(args.model, args.model_class)(input_shape,
                                                     num_classes,
                                                     args.dense_classifier,
                                                     args.pretrained).to(device)

    loss = nn.CrossEntropyLoss()
    opt_class, opt_kwargs = load.optimizer(args.optimizer)
    optimizer = opt_class(generator.parameters(model), lr=args.lr, weight_decay=args.weight_decay, **opt_kwargs)
    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=args.lr_drops, gamma=args.lr_drop_rate)
    torch.save(model.state_dict(),"{}/init-model.pt".format(args.result_dir))

    ## Pre-Train ##
    # Pre-training is intentionally disabled for this experiment; the loop
    # still runs (for 0 epochs) so pre_result has the initial-eval row.
    assert(args.pre_epochs == 0)
    print('Pre-Train for {} epochs.'.format(args.pre_epochs))
    pre_result = train_eval_loop(model, loss, optimizer, scheduler, train_loader,
                                 test_loader, device, args.pre_epochs, args.verbose)

    torch.save(model.state_dict(),"{}/pre-trained.pt".format(args.result_dir))

    ## Load in the model ##
    # maskref_dict = torch.load(args.mask_file, map_location=device)
    # model_dict = model.state_dict()

    # mask_dict = dict(filter(lambda v: 'mask' in v[0], maskref_dict.items()))
    # print("Keys being loaded\n", '\n'.join(mask_dict.keys()))
    # model_dict.update(mask_dict)

    # Load the pruned model (weights + masks) saved by an earlier run.
    model.load_state_dict(torch.load(args.model_file, map_location=device))

    # sanity check part 1
    # Record (weight sum, mask sum, masked-weight sum) per prunable layer so
    # we can verify after shuffling that masks were preserved.
    # NOTE(review): assumes every matching '<name>.weight' has a sibling
    # '<name>.weight_mask' entry in the state dict — confirm for this model class.
    dict_init = {}
    for name, param in model.state_dict().items():
        print(name)
        if name.endswith('weight') and 'shortcut' not in name and 'bn' not in name:
            mask = model.state_dict()[name + '_mask']
            dict_init[name] = (param.sum().item(), mask.sum().item(), (param * mask).sum().item())

    ## This uses a pruner but only for the purpose of shuffling weights ##
    assert(args.prune_epochs == 0 and args.weightshuffle)
    if args.pruner in ["synflow", "altsynflow", "synflowmag", "rsfgrad"]:
        model.double() # to address exploding/vanishing gradients in SynFlow for deep models
    print('Pruning with {} for {} epochs.'.format(args.pruner, args.prune_epochs))
    pruner = load.pruner(args.pruner)(generator.masked_parameters(model, args.prune_bias, args.prune_batchnorm, args.prune_residual))
    sparsity = args.sparsity # 10**(-float(args.compression))
    prune_loop(model, loss, pruner, prune_loader, device, sparsity,
               args.compression_schedule, args.mask_scope, args.prune_epochs, args.reinitialize, args.prune_train_mode, args.shuffle, args.invert, args.weightshuffle)
    if args.pruner in ["synflow", "altsynflow", "synflowmag", "rsfgrad"]:
        model.float() # to address exploding/vanishing gradients in SynFlow for deep models
    torch.save(model.state_dict(),"{}/post-prune-model.pt".format(args.result_dir))

    # sanity check part 2
    # Print pre- vs post-shuffle sums: weight sums may change, mask sums
    # should not.
    for name, param in model.state_dict().items():
        print(name)
        if name.endswith('weight') and 'shortcut' not in name and 'bn' not in name:
            mask = model.state_dict()[name + '_mask']
            print(name, dict_init[name], param.sum().item(), mask.sum().item(), (param * mask).sum().item())

    ## Compute Path Count ##
    print("Number of paths", get_path_count(mdl=model, arch=args.model))
    print("Number of active filters", number_active_filters(mdl=model, arch=args.model))

    ## Post-Train ##
    print('Post-Training for {} epochs.'.format(args.post_epochs))
    post_result = train_eval_loop(model, loss, optimizer, scheduler, train_loader,
                                  test_loader, device, args.post_epochs, args.verbose)

    ## Display Results ##
    frames = [pre_result.head(1), pre_result.tail(1), post_result.head(1), post_result.tail(1)]
    train_result = pd.concat(frames, keys=['Init.', 'Pre-Prune', 'Post-Prune', 'Final'])
    prune_result = metrics.summary(model,
                                   pruner.scores,
                                   metrics.flop(model, input_shape, device),
                                   lambda p: generator.prunable(p, args.prune_batchnorm, args.prune_residual))
    # sparsity column is the fraction of surviving parameters per layer, so
    # sparsity * size gives remaining parameter counts.
    total_params = int((prune_result['sparsity'] * prune_result['size']).sum())
    possible_params = prune_result['size'].sum()
    total_flops = int((prune_result['sparsity'] * prune_result['flops']).sum())
    possible_flops = prune_result['flops'].sum()
    print("Train results:\n", train_result)
    print("Prune results:\n", prune_result)
    print("Parameter Sparsity: {}/{} ({:.4f})".format(total_params, possible_params, total_params / possible_params))
    print("FLOP Sparsity: {}/{} ({:.4f})".format(total_flops, possible_flops, total_flops / possible_flops))

    ## Save Results and Model ##
    if args.save:
        print('Saving results.')
        pre_result.to_pickle("{}/pre-train.pkl".format(args.result_dir))
        post_result.to_pickle("{}/post-train.pkl".format(args.result_dir))
        plot.get_plots(post_result, save_path="{}/plots.png".format(args.result_dir))
        prune_result.to_pickle("{}/compression.pkl".format(args.result_dir))
        torch.save(model.state_dict(),"{}/post-train-model.pt".format(args.result_dir))
        torch.save(optimizer.state_dict(),"{}/optimizer.pt".format(args.result_dir))
        torch.save(scheduler.state_dict(),"{}/scheduler.pt".format(args.result_dir))
Exemplo n.º 9
0
def run(args):
    """Iterative train-prune (lottery-ticket style) experiment.

    For each compression ratio and level count: repeatedly pre-train,
    prune a fraction of the remaining weights, rewind weights to the
    original init (keeping masks), then post-train the final ticket and
    pickle the results under ``args.result_dir``.
    """
    # Results are written to disk, so an experiment id/dir is mandatory.
    if not args.save:
        print("This experiment requires an expid.")
        quit()

    ## Random Seed and Device ##
    torch.manual_seed(args.seed)
    device = load.device(args.gpu)

    ## Data ##
    print("Loading {} dataset.".format(args.dataset))
    input_shape, num_classes = load.dimension(args.dataset)
    prune_loader = load.dataloader(
        args.dataset,
        args.prune_batch_size,
        True,
        args.workers,
        args.prune_dataset_ratio * num_classes,
    )
    train_loader = load.dataloader(
        args.dataset, args.train_batch_size, True, args.workers
    )
    test_loader = load.dataloader(
        args.dataset, args.test_batch_size, False, args.workers
    )

    ## Model ##
    print("Creating {} model.".format(args.model))
    model = load.model(args.model, args.model_class)(
        input_shape, num_classes, args.dense_classifier, args.pretrained
    ).to(device)
    loss = nn.CrossEntropyLoss()
    opt_class, opt_kwargs = load.optimizer(args.optimizer)
    optimizer = opt_class(
        generator.parameters(model),
        lr=args.lr,
        weight_decay=args.weight_decay,
        **opt_kwargs
    )
    scheduler = torch.optim.lr_scheduler.MultiStepLR(
        optimizer, milestones=args.lr_drops, gamma=args.lr_drop_rate
    )

    ## Save Original ##
    # Snapshot the untouched initialization; every (compression, level)
    # configuration below rewinds to these checkpoints.
    torch.save(model.state_dict(), "{}/model.pt".format(args.result_dir))
    torch.save(optimizer.state_dict(), "{}/optimizer.pt".format(args.result_dir))
    torch.save(scheduler.state_dict(), "{}/scheduler.pt".format(args.result_dir))

    ## Train-Prune Loop ##
    for compression in args.compression_list:
        for level in args.level_list:
            print(
                "{} compression ratio, {} train-prune levels".format(compression, level)
            )

            # Reset Model, Optimizer, and Scheduler
            model.load_state_dict(
                torch.load("{}/model.pt".format(args.result_dir), map_location=device)
            )
            optimizer.load_state_dict(
                torch.load(
                    "{}/optimizer.pt".format(args.result_dir), map_location=device
                )
            )
            scheduler.load_state_dict(
                torch.load(
                    "{}/scheduler.pt".format(args.result_dir), map_location=device
                )
            )

            for l in range(level):

                # Pre Train Model
                train_eval_loop(
                    model,
                    loss,
                    optimizer,
                    scheduler,
                    train_loader,
                    test_loader,
                    device,
                    args.pre_epochs,
                    args.verbose,
                )

                # Prune Model
                pruner = load.pruner(args.pruner)(
                    generator.masked_parameters(
                        model,
                        args.prune_bias,
                        args.prune_batchnorm,
                        args.prune_residual,
                    )
                )
                # Geometric schedule: after level iterations the overall
                # sparsity reaches 10**(-compression).
                sparsity = (10 ** (-float(compression))) ** ((l + 1) / level)
                prune_loop(
                    model,
                    loss,
                    pruner,
                    prune_loader,
                    device,
                    sparsity,
                    args.compression_schedule,
                    args.mask_scope,
                    args.prune_epochs,
                    args.reinitialize,
                )

                # Reset Model's Weights
                # Rewind trainable parameters to the original init while
                # keeping the newly computed masks/buffers in place.
                # NOTE(review): relies on saved state-dict tensors keeping
                # requires_grad=True for parameters — confirm for this
                # torch version.
                original_dict = torch.load(
                    "{}/model.pt".format(args.result_dir), map_location=device
                )
                original_weights = dict(
                    filter(
                        lambda v: (v[1].requires_grad == True), original_dict.items()
                    )
                )
                model_dict = model.state_dict()
                model_dict.update(original_weights)
                model.load_state_dict(model_dict)

                # Reset Optimizer and Scheduler
                optimizer.load_state_dict(
                    torch.load(
                        "{}/optimizer.pt".format(args.result_dir), map_location=device
                    )
                )
                scheduler.load_state_dict(
                    torch.load(
                        "{}/scheduler.pt".format(args.result_dir), map_location=device
                    )
                )

            # Prune Result
            # NOTE(review): `pruner` is only bound inside the loop above, so
            # a level of 0 would raise NameError here — verify level_list
            # never contains 0.
            prune_result = metrics.summary(
                model,
                pruner.scores,
                metrics.flop(model, input_shape, device),
                lambda p: generator.prunable(
                    p, args.prune_batchnorm, args.prune_residual
                ),
            )
            # Train Model
            post_result = train_eval_loop(
                model,
                loss,
                optimizer,
                scheduler,
                train_loader,
                test_loader,
                device,
                args.post_epochs,
                args.verbose,
            )

            # Save Data
            post_result.to_pickle(
                "{}/post-train-{}-{}-{}.pkl".format(
                    args.result_dir, args.pruner, str(compression), str(level)
                )
            )
            prune_result.to_pickle(
                "{}/compression-{}-{}-{}.pkl".format(
                    args.result_dir, args.pruner, str(compression), str(level)
                )
            )
Exemplo n.º 10
0
def run(args):
    """Prune-and-fine-tune sweep over pruners and compression exponents.

    Pre-trains once, checkpoints model/optimizer/scheduler, then for each
    (compression exponent, pruner) pair rewinds to the checkpoint, prunes
    with the selected prune-loop variant, post-trains, and pickles the
    per-configuration results.
    """
    if not args.save:
        print("This experiment requires an expid.")
        quit()

    ## Random Seed and Device ##
    torch.manual_seed(args.seed)
    device = load.device(args.gpu)

    ## Data ##
    print('Loading {} dataset.'.format(args.dataset))
    input_shape, num_classes = load.dimension(args.dataset)
    prune_loader = load.dataloader(args.dataset, args.prune_batch_size, True,
                                   args.workers,
                                   args.prune_dataset_ratio * num_classes)
    train_loader = load.dataloader(args.dataset, args.train_batch_size, True,
                                   args.workers)
    test_loader = load.dataloader(args.dataset, args.test_batch_size, False,
                                  args.workers)

    ## Model ##
    print('Creating {} model.'.format(args.model))
    # L/N configure network depth and width for this model family.
    model = load.model(args.model,
                       args.model_class)(input_shape,
                                         num_classes,
                                         args.dense_classifier,
                                         args.pretrained,
                                         L=args.num_layers,
                                         N=args.layer_width).to(device)
    loss = nn.CrossEntropyLoss()
    opt_class, opt_kwargs = load.optimizer(args.optimizer)
    optimizer = opt_class(generator.parameters(model),
                          lr=args.lr,
                          weight_decay=args.weight_decay,
                          **opt_kwargs)
    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer,
                                                     milestones=args.lr_drops,
                                                     gamma=args.lr_drop_rate)

    ## Pre-Train ##
    print('Pre-Train for {} epochs.'.format(args.pre_epochs))
    pre_result = train_eval_loop(model, loss, optimizer, scheduler,
                                 train_loader, test_loader, device,
                                 args.pre_epochs, args.verbose)

    # Checkpoint the pre-trained state; every sweep iteration rewinds here.
    pre_result.to_pickle("{}/pre-train.pkl".format(args.result_dir))
    torch.save(model.state_dict(), "{}/model.pt".format(args.result_dir))
    torch.save(optimizer.state_dict(),
               "{}/optimizer.pt".format(args.result_dir))
    torch.save(scheduler.state_dict(),
               "{}/scheduler.pt".format(args.result_dir))

    ## Prune and Fine-Tune##
    pruners = args.pruner_list
    comp_exponents = args.compression_list
    # Select which prune-loop implementation to use for the whole sweep.
    if args.prune_loop == 'rand':
        prune_loop_fun = rand_prune_loop
    elif args.prune_loop == 'approx':
        prune_loop_fun = approx_prune_loop
    else:
        prune_loop_fun = prune_loop

    # Iterate compression exponents in reverse list order.
    for j, exp in enumerate(comp_exponents[::-1]):
        for i, p in enumerate(pruners):
            print(p, exp)

            # Rewind to the pre-trained checkpoint before each configuration.
            model.load_state_dict(
                torch.load("{}/model.pt".format(args.result_dir),
                           map_location=device))
            optimizer.load_state_dict(
                torch.load("{}/optimizer.pt".format(args.result_dir),
                           map_location=device))
            scheduler.load_state_dict(
                torch.load("{}/scheduler.pt".format(args.result_dir),
                           map_location=device))

            pruner = load.pruner(p)(generator.masked_parameters(
                model, args.prune_bias, args.prune_batchnorm,
                args.prune_residual))
            # Target fraction of remaining weights: 10^-exp.
            sparsity = 10**(-float(exp))

            # rand_weighted takes the full args namespace; iterative pruners
            # (sf/synflow/snip) honor args.prune_epochs; everything else is
            # single-shot (1 epoch).
            if p == 'rand_weighted':
                prune_loop_fun(model, loss, pruner, prune_loader, device,
                               sparsity, args.linear_compression_schedule,
                               args.mask_scope, args.prune_epochs, args,
                               args.reinitialize)
            elif p == 'sf' or p == 'synflow' or p == 'snip':
                prune_loop_fun(model, loss, pruner, prune_loader, device,
                               sparsity, args.linear_compression_schedule,
                               args.mask_scope, args.prune_epochs,
                               args.reinitialize)
            else:
                prune_loop_fun(model, loss, pruner, prune_loader, device,
                               sparsity, args.linear_compression_schedule,
                               args.mask_scope, 1, args.reinitialize)

            prune_result = metrics.summary(
                model, pruner.scores, metrics.flop(model, input_shape, device),
                lambda p: generator.prunable(p, args.prune_batchnorm, args.
                                             prune_residual))

            post_result = train_eval_loop(model, loss, optimizer, scheduler,
                                          train_loader, test_loader, device,
                                          args.post_epochs, args.verbose)

            post_result.to_pickle("{}/post-train-{}-{}-{}.pkl".format(
                args.result_dir, p, str(exp), args.prune_epochs))
            prune_result.to_pickle("{}/compression-{}-{}-{}.pkl".format(
                args.result_dir, p, str(exp), args.prune_epochs))
# Script-level setup for the lottery_layer_conservation experiment.
# NOTE(review): `dataset`, `model_architecture`, `model_class`, `lr` (and the
# later use of `init_model`/`train_model`) are defined elsewhere in this
# file — confirm before moving this section.
verbose = True
directory = 'Results/lottery_layer_conservation/{}/{}'.format(
    dataset, model_architecture)
# Create the results directory; ignore if it already exists.
try:
    os.makedirs(directory)
except FileExistsError:
    pass

torch.manual_seed(seed=1)
device = load.device(gpu=9)
input_shape, num_classes = load.dimension(dataset=dataset)
train_loader = load.dataloader(dataset=dataset,
                               batch_size=128,
                               train=True,
                               workers=30)
model = load.model(model_architecture=model_architecture,
                   model_class=model_class)
init_model = model(input_shape=input_shape,
                   num_classes=num_classes,
                   dense_classifier=False,
                   pretrained=False).to(device)
# Keep an untouched copy of the init; train a deep copy of it.
train_model = copy.deepcopy(init_model)

opt_class, opt_kwargs = load.optimizer('momentum')
optimizer = opt_class(generator.parameters(train_model),
                      lr=lr,
                      weight_decay=1e-4,
                      **opt_kwargs)
scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer,
                                                 milestones=[30, 60, 80],
                                                 gamma=0.1)
loss = nn.CrossEntropyLoss()
Exemplo n.º 12
0
def run(args):
    """Single prune-then-train experiment with optional step checkpointing.

    Builds the model (with a configurable norm layer for fc/conv models),
    optionally overrides epoch counts from a saved-steps file, pre-trains,
    prunes once if ``args.prune_epochs > 0``, post-trains, and prints /
    optionally pickles train and prune summaries.
    """
    ## Random Seed and Device ##
    torch.manual_seed(args.seed)
    device = load.device(args.gpu)

    ## Data ##
    print("Loading {} dataset.".format(args.dataset))
    input_shape, num_classes = load.dimension(args.dataset)
    prune_loader = load.dataloader(
        dataset=args.dataset,
        batch_size=args.prune_batch_size,
        train=True,
        workers=args.workers,
        length=args.prune_dataset_ratio * num_classes,
        datadir=args.data_dir,
    )
    train_loader = load.dataloader(
        dataset=args.dataset,
        batch_size=args.train_batch_size,
        train=True,
        workers=args.workers,
        datadir=args.data_dir,
    )
    test_loader = load.dataloader(
        dataset=args.dataset,
        batch_size=args.test_batch_size,
        train=False,
        workers=args.workers,
        datadir=args.data_dir,
    )

    ## Model, Loss, Optimizer ##
    print("Creating {}-{} model.".format(args.model_class, args.model))
    # Only the fc/conv constructors accept a norm_layer argument.
    if args.model in ["fc", "conv"]:
        norm_layer = load.norm_layer(args.norm_layer)
        print(f"Applying {args.norm_layer} normalization: {norm_layer}")
        model = load.model(args.model, args.model_class)(
            input_shape=input_shape,
            num_classes=num_classes,
            dense_classifier=args.dense_classifier,
            pretrained=args.pretrained,
            norm_layer=norm_layer,
        ).to(device)
    else:
        model = load.model(args.model, args.model_class)(
            input_shape=input_shape,
            num_classes=num_classes,
            dense_classifier=args.dense_classifier,
            pretrained=args.pretrained,
        ).to(device)
    loss = nn.CrossEntropyLoss()
    opt_class, opt_kwargs = load.optimizer(args.optimizer)
    optimizer = opt_class(
        generator.parameters(model),
        lr=args.lr,
        weight_decay=args.weight_decay,
        **opt_kwargs,
    )
    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer,
                                                     milestones=args.lr_drops,
                                                     gamma=args.lr_drop_rate)

    ## checkpointing setup ##
    # When a steps file is given, derive the number of post-train epochs
    # from its last step and disable pre-training.
    if args.tk_steps_file is not None:
        save_steps = load.save_steps_file(args.tk_steps_file)
        steps_per_epoch = len(train_loader)
        max_epochs = int(save_steps[-1] / steps_per_epoch)
        print(f"Overriding train epochs to last step in file ")
        print(f"    pre_epochs set to 0, post_epochs set to {max_epochs}")
        setattr(args, "pre_epochs", 0)
        setattr(args, "post_epochs", max_epochs)
    else:
        save_steps = None

    ## Pre-Train ##
    print("Pre-Train for {} epochs.".format(args.pre_epochs))
    pre_result = train_eval_loop(
        model,
        loss,
        optimizer,
        scheduler,
        train_loader,
        test_loader,
        device,
        args.pre_epochs,
        args.verbose,
        save_steps=save_steps,
        save_freq=args.save_freq,
        save_path=args.save_path,
    )

    ## Prune ##
    if args.prune_epochs > 0:
        print("Pruning with {} for {} epochs.".format(args.pruner,
                                                      args.prune_epochs))
        pruner = load.pruner(args.pruner)(generator.masked_parameters(
            model, args.prune_bias, args.prune_batchnorm, args.prune_residual))
        # Keep 10^-compression of the weights.
        sparsity = 10**(-float(args.compression))
        prune_loop(
            model,
            loss,
            pruner,
            prune_loader,
            device,
            sparsity,
            args.compression_schedule,
            args.mask_scope,
            args.prune_epochs,
            args.reinitialize,
        )

    ## Post-Train ##
    print("Post-Training for {} epochs.".format(args.post_epochs))
    post_result = train_eval_loop(
        model,
        loss,
        optimizer,
        scheduler,
        train_loader,
        test_loader,
        device,
        args.post_epochs,
        args.verbose,
        save_steps=save_steps,
        save_freq=args.save_freq,
        save_path=args.save_path,
    )

    ## Display Results ##
    frames = [
        pre_result.head(1),
        pre_result.tail(1),
        post_result.head(1),
        post_result.tail(1),
    ]
    train_result = pd.concat(
        frames, keys=["Init.", "Pre-Prune", "Post-Prune", "Final"])
    print("Train results:\n", train_result)
    if args.prune_epochs > 0:
        prune_result = metrics.summary(
            model,
            pruner.scores,
            metrics.flop(model, input_shape, device),
            lambda p: generator.prunable(p, args.prune_batchnorm, args.
                                         prune_residual),
        )
        # sparsity column holds the surviving fraction per layer.
        total_params = int(
            (prune_result["sparsity"] * prune_result["size"]).sum())
        possible_params = prune_result["size"].sum()
        total_flops = int(
            (prune_result["sparsity"] * prune_result["flops"]).sum())
        possible_flops = prune_result["flops"].sum()

        print("Prune results:\n", prune_result)
        print("Parameter Sparsity: {}/{} ({:.4f})".format(
            total_params, possible_params, total_params / possible_params))
        print("FLOP Sparsity: {}/{} ({:.4f})".format(
            total_flops, possible_flops, total_flops / possible_flops))

    ## Save Results and Model ##
    if args.save:
        print("Saving results.")
        pre_result.to_pickle("{}/pre-train.pkl".format(args.result_dir))
        post_result.to_pickle("{}/post-train.pkl".format(args.result_dir))
        torch.save(model.state_dict(), "{}/model.pt".format(args.result_dir))
        torch.save(optimizer.state_dict(),
                   "{}/optimizer.pt".format(args.result_dir))
        if args.prune_epochs > 0:
            prune_result.to_pickle("{}/compression.pkl".format(
                args.result_dir))
            torch.save(pruner.state_dict(),
                       "{}/pruner.pt".format(args.result_dir))
Exemplo n.º 13
0
def run(args):
    """Prune-at-init experiment with path-kernel / output tracking.

    Prunes the freshly initialized model with ``args.pruner`` down to
    ``10**(-args.compression)`` remaining weights, post-trains while
    optionally recording per-batch outputs and path-kernel values, and
    prints/saves train and prune summaries.

    Fix vs. original: the directory-existence check before creating
    ``save_result_path`` tested ``save_pruned_path`` instead (wrong
    variable — it only worked because both paths happen to be built from
    the same format string).
    """
    print(args)
    ## Random Seed and Device ##
    torch.manual_seed(args.seed)
    device = load.device(args.gpu)

    ## Data ##
    print('Loading {} dataset.'.format(args.dataset))
    input_shape, num_classes = load.dimension(args.dataset)
    prune_loader = load.dataloader(args.dataset,
                                   args.prune_batch_size,
                                   True,
                                   args.workers,
                                   args.prune_dataset_ratio * num_classes,
                                   prune_loader=True)
    train_loader = load.dataloader(args.dataset, args.train_batch_size, True,
                                   args.workers)
    test_loader = load.dataloader(args.dataset, args.test_batch_size, False,
                                  args.workers)

    ## Model, Loss, Optimizer ##
    print('Creating {}-{} model.'.format(args.model_class, args.model))
    model = load.model(args.model,
                       args.model_class)(input_shape, num_classes,
                                         args.dense_classifier,
                                         args.pretrained).to(device)
    loss = nn.CrossEntropyLoss()
    opt_class, opt_kwargs = load.optimizer(args.optimizer)
    optimizer = opt_class(generator.parameters(model),
                          lr=args.lr,
                          weight_decay=args.weight_decay,
                          **opt_kwargs)
    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer,
                                                     milestones=args.lr_drops,
                                                     gamma=args.lr_drop_rate)

    ## Pre-Train ##
    # No pre-training: run the loop for 0 epochs just to get the init-eval row.
    pre_result = train_eval_loop(model, loss, optimizer, scheduler,
                                 train_loader, test_loader, device, 0,
                                 args.verbose)

    ## Prune ##
    print('Pruning with {} for {} epochs.'.format(args.pruner,
                                                  args.prune_epochs))
    pruner = load.pruner(args.pruner)(generator.masked_parameters(
        model, args.prune_bias, args.prune_batchnorm, args.prune_residual))
    # Keep 10^-compression of the weights.
    sparsity = 10**(-float(args.compression))
    print("Sparsity: {}".format(sparsity))
    save_pruned_path = args.save_pruned_path + "/%s/%s/%s" % (
        args.model_class,
        args.model,
        args.pruner,
    )
    if (args.save_pruned):
        print("Saving pruned models to: %s" % (save_pruned_path, ))
        if not os.path.exists(save_pruned_path):
            os.makedirs(save_pruned_path)
    prune_loop(model, loss, pruner, prune_loader, device, sparsity,
               args.compression_schedule, args.mask_scope, args.prune_epochs,
               args.reinitialize, args.save_pruned, save_pruned_path)

    # Output directories for per-batch tracking and path-kernel logging.
    save_batch_output_path = args.save_pruned_path + "/%s/%s/%s/output_%s" % (
        args.model_class, args.model, args.pruner,
        (args.dataset + "_" + str(args.seed) + "_" + str(args.compression)))
    save_init_path_kernel_output_path = args.save_pruned_path + "/init-path-kernel-values.csv"
    row_name = f"{args.model}_{args.dataset}_{args.pruner}_{str(args.seed)}_{str(args.compression)}"

    print(save_init_path_kernel_output_path)
    print(row_name)

    if not os.path.exists(save_batch_output_path):
        os.makedirs(save_batch_output_path)
    ## Post-Train ##
    post_result = train_eval_loop(
        model, loss, optimizer, scheduler, train_loader, test_loader, device,
        args.post_epochs, args.verbose, args.compute_path_kernel,
        args.track_weight_movement, save_batch_output_path, True,
        save_init_path_kernel_output_path, row_name, args.compute_init_outputs,
        args.compute_init_grads)

    if (args.save_result):
        save_result_path = args.save_pruned_path + "/%s/%s/%s" % (
            args.model_class,
            args.model,
            args.pruner,
        )
        # BUG FIX: check the directory actually being created (the original
        # tested save_pruned_path here).
        if not os.path.exists(save_result_path):
            os.makedirs(save_result_path)

        print(f"Saving results to {save_result_path}")
        post_result.to_csv(save_result_path + "/%s" %
                           (args.dataset + "_" + str(args.seed) + "_" +
                            str(args.compression) + ".csv"))

    ## Display Results ##
    frames = [pre_result.head(1), post_result.head(1), post_result.tail(1)]
    train_result = pd.concat(frames, keys=['Init.', 'Post-Prune', "Final"])
    prune_result = metrics.summary(
        model, pruner.scores,
        metrics.flop(model, input_shape, device), lambda p: generator.prunable(
            p, args.prune_batchnorm, args.prune_residual))
    # sparsity column holds the surviving fraction per layer.
    total_params = int((prune_result['sparsity'] * prune_result['size']).sum())
    possible_params = prune_result['size'].sum()
    total_flops = int((prune_result['sparsity'] * prune_result['flops']).sum())
    possible_flops = prune_result['flops'].sum()
    print("Train results:\n", train_result)
    print("Prune results:\n", prune_result)
    print("Parameter Sparsity: {}/{} ({:.4f})".format(
        total_params, possible_params, total_params / possible_params))
    print("FLOP Sparsity: {}/{} ({:.4f})".format(total_flops, possible_flops,
                                                 total_flops / possible_flops))

    ## Save Results and Model ##
    if args.save:
        print('Saving results.')
        pre_result.to_pickle("{}/pre-train.pkl".format(args.result_dir))
        post_result.to_pickle("{}/post-train.pkl".format(args.result_dir))
        prune_result.to_pickle("{}/compression.pkl".format(args.result_dir))
        torch.save(model.state_dict(), "{}/model.pt".format(args.result_dir))
        torch.save(optimizer.state_dict(),
                   "{}/optimizer.pt".format(args.result_dir))
        torch.save(pruner.state_dict(), "{}/pruner.pt".format(args.result_dir))
Exemplo n.º 14
0
def main():
    """Entry point: optionally pre-train, prune (per ``args.prune_method``),
    then train a model on CIFAR-10/100 or Tiny-ImageNet.

    When ``args.compute_sv`` is set, singular-value statistics of the
    Linear/Conv layers are recorded throughout training and saved as
    ``.npy`` files in ``args.save_dir``.
    """
    # Relies on module-level `args` (argparse namespace) and `best_prec1`.
    global args, best_prec1
    args = parser.parse_args()
    # Target sparsity decays geometrically: keep 0.8 of weights per
    # pruning iteration.
    args.sparse_lvl = 0.8**args.sparse_iter
    print(args.sparse_lvl)

    # Check the save_dir exists or not
    if not os.path.exists(args.save_dir):
        os.makedirs(args.save_dir)

    torch.manual_seed(args.seed)

    # optionally resume from a checkpoint
    # NOTE(review): `model` is not defined until further down this function,
    # so taking this branch would raise NameError at load_state_dict; the
    # success message also formats `args.evaluate` where `args.resume` was
    # presumably intended. Confirm before relying on --resume.
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            best_prec1 = checkpoint['best_prec1']
            model.load_state_dict(checkpoint['state_dict'])
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.evaluate, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    cudnn.benchmark = True

    ## Dataset selection; tiny-imagenet and cifar100 override some
    ## hyperparameters (batch size, lr, epoch count) in-place on `args`.
    if args.dataset == 'cifar10':
        # num_classes = 10
        # print('Loading {} dataset.'.format(args.dataset))
        # normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
        #                                  std=[0.229, 0.224, 0.225])
        # train_dataset = datasets.CIFAR10(root='./data', train=True, transform=transforms.Compose([
        #         transforms.RandomHorizontalFlip(),
        #         transforms.RandomCrop(32, 4),
        #         transforms.ToTensor(),
        #         normalize,
        #     ]), download=True)

        # train_loader = torch.utils.data.DataLoader(
        #     train_dataset,
        #     batch_size=args.batch_size, shuffle=True,
        #     num_workers=args.workers, pin_memory=True)

        # val_loader = torch.utils.data.DataLoader(
        #     datasets.CIFAR10(root='./data', train=False, transform=transforms.Compose([
        #         transforms.ToTensor(),
        #         normalize,
        #     ])),
        #     batch_size=128, shuffle=False,
        #     num_workers=args.workers, pin_memory=True)
        print('Loading {} dataset.'.format(args.dataset))
        input_shape, num_classes = load.dimension(args.dataset)
        train_dataset, train_loader = load.dataloader(args.dataset,
                                                      args.batch_size, True,
                                                      args.workers)
        _, val_loader = load.dataloader(args.dataset, 128, False, args.workers)

    elif args.dataset == 'tiny-imagenet':
        args.batch_size = 256
        args.lr = 0.2
        args.epochs = 200
        print('Loading {} dataset.'.format(args.dataset))
        input_shape, num_classes = load.dimension(args.dataset)
        train_dataset, train_loader = load.dataloader(args.dataset,
                                                      args.batch_size, True,
                                                      args.workers)
        _, val_loader = load.dataloader(args.dataset, 128, False, args.workers)

    elif args.dataset == 'cifar100':
        args.batch_size = 128
        # args.lr = 0.01
        args.epochs = 160
        # args.weight_decay = 5e-4
        input_shape, num_classes = load.dimension(args.dataset)
        train_dataset, train_loader = load.dataloader(args.dataset,
                                                      args.batch_size, True,
                                                      args.workers)
        _, val_loader = load.dataloader(args.dataset, 128, False, args.workers)

    ## Model creation; each arch branch moves the model to GPU itself.
    if args.arch == 'resnet20':
        print('Creating {} model.'.format(args.arch))
        model = torch.nn.DataParallel(resnet.__dict__[args.arch](
            ONI=args.ONI, T_iter=args.T_iter))
        model.cuda()
    elif args.arch == 'resnet18':
        print('Creating {} model.'.format(args.arch))
        # Using resnet18 from Synflow
        # model = load.model(args.arch, 'tinyimagenet')(input_shape,
        #                                              num_classes,
        #                                              dense_classifier = True).cuda()
        # Using resnet18 from torchvision
        model = models.resnet18()
        model.fc = nn.Linear(512, num_classes)
        model.cuda()
        utils.kaiming_initialize(model)
    elif args.arch == 'resnet110' or args.arch == 'resnet110full':
        # Using resnet110 from Apollo
        # model = apolo_resnet.ResNet(110, num_classes=num_classes)
        model = load.model(args.arch, 'lottery')(input_shape,
                                                 num_classes,
                                                 dense_classifier=True).cuda()
    elif args.arch in [
            'vgg16full', 'vgg16full-bn', 'vgg11full', 'vgg11full-bn'
    ]:
        if args.dataset == 'tiny-imagenet':
            modeltype = 'tinyimagenet'
        else:
            modeltype = 'lottery'
        # Using resnet110 from Apollo
        # model = apolo_resnet.ResNet(110, num_classes=num_classes)
        model = load.model(args.arch, modeltype)(input_shape,
                                                 num_classes,
                                                 dense_classifier=True).cuda()

    # for layer in model.modules():
    #     if isinstance(layer, nn.Linear):
    #         init.orthogonal_(layer.weight.data)
    #     elif isinstance(layer, (nn.Conv2d, nn.ConvTranspose2d)):
    #         special_init.DeltaOrthogonal_init(layer.weight.data)

    print('Number of parameters of model: {}.'.format(count_parameters(model)))
    # define loss function (criterion) and optimizer
    criterion = nn.CrossEntropyLoss().cuda()

    if args.compute_sv:
        print('[*] Will compute singular values throught training.')
        # Hook Linear/Conv layers to record activation sizes, then run a
        # single forward pass so the hooks fire before being detached.
        size_hook = utils.get_hook(model,
                                   (nn.Linear, nn.Conv2d, nn.ConvTranspose2d))
        utils.run_once(train_loader, model)
        utils.detach_hook([size_hook])
        # Histories of singular-value statistics, one entry per snapshot
        # (initial state, post-prune, and every `save_every` epochs).
        training_sv = []
        # training_sv_avg = []
        # training_sv_std = []
        training_svmax = []
        training_sv20 = []  # 50% singular value
        training_sv50 = []  # 50% singular value
        training_sv80 = []  # 80% singular value
        # training_kclip12 = [] # singular values larger than 1e-12
        training_sv50p = []  # 50% non-zero singular value
        training_sv80p = []  # 80% non-zero singular value
        # training_kavg = [] # max condition number/average condition number
        sv, svmax, sv20, sv50, sv80, sv50p, sv80p = utils.get_sv(
            model, size_hook)
        training_sv.append(sv)
        # training_sv_avg.append(sv_avg)
        # training_sv_std.append(sv_std)
        training_svmax.append(svmax)
        training_sv20.append(sv20)
        training_sv50.append(sv50)
        training_sv80.append(sv80)
        # training_kclip12.append(kclip12)
        training_sv50p.append(sv50p)
        training_sv80p.append(sv80p)
        # training_kavg.append(kavg)

    if args.half:
        # Half-precision training: convert both model and loss module.
        model.half()
        criterion.half()

    optimizer = torch.optim.SGD(model.parameters(),
                                args.lr,
                                momentum=args.momentum,
                                nesterov=True,
                                weight_decay=args.weight_decay)
    if args.dataset == 'tiny-imagenet':
        lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
            optimizer, milestones=[100, 150], last_epoch=args.start_epoch - 1)
        # milestones=[30, 60, 80], last_epoch=args.start_epoch - 1)
    elif args.dataset == 'cifar100':
        lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
            optimizer,
            milestones=[60, 120],
            gamma=0.2,
            last_epoch=args.start_epoch - 1)
    else:
        lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
            optimizer, milestones=[80, 120], last_epoch=args.start_epoch - 1)

    # if args.arch in ['resnet1202', 'resnet110']:
    #     # for resnet1202 original paper uses lr=0.01 for first 400 minibatches for warm-up
    #     # then switch back. In this setup it will correspond for first epoch.
    #     for param_group in optimizer.param_groups:
    #         param_group['lr'] = args.lr*0.1
    ## Pre-training phase (dense model, before any pruning). ##
    for epoch in range(args.pre_epochs):

        # train for one epoch
        print('current lr {:.5e}'.format(optimizer.param_groups[0]['lr']))
        train(train_loader, model, criterion, optimizer, epoch)
        lr_scheduler.step()

        # evaluate on validation set
        prec1 = validate(val_loader, model, criterion)

        # remember best prec@1 and save checkpoint
        is_best = prec1 > best_prec1
        best_prec1 = max(prec1, best_prec1)

        if epoch > 0 and epoch % args.save_every == 0:
            save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'state_dict': model.state_dict(),
                    'best_prec1': best_prec1,
                },
                is_best,
                filename=os.path.join(args.save_dir, 'checkpoint.th'))

        # save_checkpoint({
        #     'state_dict': model.state_dict(),
        #     'best_prec1': best_prec1,
        # }, is_best, filename=os.path.join(args.save_dir, 'model.th'))

        # Periodically snapshot singular-value stats and rewrite all
        # history arrays to disk (each save overwrites the previous file).
        if args.compute_sv and epoch % args.save_every == 0:
            sv, svmax, sv20, sv50, sv80, sv50p, sv80p = utils.get_sv(
                model, size_hook)
            training_sv.append(sv)
            # training_sv_avg.append(sv_avg)
            # training_sv_std.append(sv_std)
            training_svmax.append(svmax)
            training_sv20.append(sv20)
            training_sv50.append(sv50)
            training_sv80.append(sv80)
            # training_kclip12.append(kclip12)
            training_sv50p.append(sv50p)
            training_sv80p.append(sv80p)
            # training_kavg.append(kavg)
            np.save(os.path.join(args.save_dir, 'sv.npy'), training_sv)
            # np.save(os.path.join(args.save_dir, 'sv_avg.npy'), training_sv_avg)
            # np.save(os.path.join(args.save_dir, 'sv_std.npy'), training_sv_std)
            np.save(os.path.join(args.save_dir, 'sv_svmax.npy'),
                    training_svmax)
            np.save(os.path.join(args.save_dir, 'sv_sv20.npy'), training_sv20)
            np.save(os.path.join(args.save_dir, 'sv_sv50.npy'), training_sv50)
            np.save(os.path.join(args.save_dir, 'sv_sv80.npy'), training_sv80)
            # np.save(os.path.join(args.save_dir, 'sv_kclip12.npy'), training_kclip12)
            np.save(os.path.join(args.save_dir, 'sv_sv50p.npy'),
                    training_sv50p)
            np.save(os.path.join(args.save_dir, 'sv_sv80p.npy'),
                    training_sv80p)
            # np.save(os.path.join(args.save_dir, 'sv_kavg.npy'), training_kavg)

    print('[*] {} pre-training epochs done'.format(args.pre_epochs))

    # checkpoint = torch.load('preprune.th')
    # model.load_state_dict(checkpoint['state_dict'])

    ## Pruning phase: dispatch on args.prune_method. ##
    if args.prune_method != 'NONE':
        nets = [model]
        # snip_loader = torch.utils.data.DataLoader(
        # train_dataset,
        # batch_size=num_classes, shuffle=False,
        # num_workers=args.workers, pin_memory=True, sampler=sampler.BalancedBatchSampler(train_dataset))

        if args.prune_method == 'SNIP':
            # for layer in model.modules():
            #     snip.add_mask_ones(layer)
            # # svfp.svip_reinit(model)
            # # if args.compute_sv:
            # #     sv, sv_avg, sv_std = utils.get_sv(model, size_hook)
            # #     training_sv.append(sv)
            # #     training_sv_avg.append(sv_avg)
            # #     training_sv_std.append(sv_std)
            # utils.save_sparsity(model, args.save_dir)
            # save_checkpoint({
            #     'state_dict': model.state_dict(),
            #     'best_prec1': 0,
            # }, 0, filename=os.path.join(args.save_dir, 'preprune.th'))
            if args.adv:
                snip.apply_advsnip(args,
                                   nets,
                                   train_loader,
                                   criterion,
                                   num_classes=num_classes)
            else:
                snip.apply_snip(args,
                                nets,
                                train_loader,
                                criterion,
                                num_classes=num_classes)
            # snip.apply_snip(args, nets, snip_loader, criterion)
        elif args.prune_method == 'SNIPRES':
            snipres.apply_snipres(args,
                                  nets,
                                  train_loader,
                                  criterion,
                                  input_shape,
                                  num_classes=num_classes)
        elif args.prune_method == 'TEST':
            # checkpoint = torch.load('preprune.th')
            # model.load_state_dict(checkpoint['state_dict'])
            # svip.apply_svip(args, nets)
            # svfp.apply_svip(args, nets)
            # given_sparsity = np.load('saved_sparsity.npy')
            # svfp.apply_svip_givensparsity(args, nets, given_sparsity)
            # Keep every layer dense except those listed in args.layer,
            # which get sparsity args.s_value.
            num_apply_layer = utils.get_apply_layer(model)
            given_sparsity = np.ones(num_apply_layer, )
            # given_sparsity[-4] = args.s_value
            # given_sparsity[-5] = args.s_value
            for layer in args.layer:
                given_sparsity[layer] = args.s_value
            print(given_sparsity)
            # given_sparsity = np.load(args.s_name+'.npy')
            # snip.apply_rand_prune_givensparsity(nets, given_sparsity)
            ftprune.apply_specprune(nets, given_sparsity)
        elif args.prune_method == 'RAND':
            # utils.save_sparsity(model, args.save_dir)
            # checkpoint = torch.load('preprune.th')
            # model.load_state_dict(checkpoint['state_dict'])
            # snip.apply_rand_prune(nets, args.sparse_lvl)
            # given_sparsity = np.load(args.save_dir+'/saved_sparsity.npy')
            # Same per-layer sparsity vector as 'TEST', but pruned randomly.
            num_apply_layer = utils.get_apply_layer(model)
            given_sparsity = np.ones(num_apply_layer, )
            # given_sparsity[-4] = args.s_value
            # given_sparsity[-5] = args.s_value
            for layer in args.layer:
                given_sparsity[layer] = args.s_value
            print(given_sparsity)
            # given_sparsity = np.load(args.s_name+'.npy')
            snip.apply_rand_prune_givensparsity(nets, given_sparsity)
            # svfp.apply_rand_prune_givensparsity_var(nets, given_sparsity, args.reduce_ratio, args.structured, args)
            # svfp.svip_reinit_givenlayer(nets, args.layer)
        elif args.prune_method == 'Zen':
            zenprune.apply_zenprune(args, nets, train_loader)
            # zenprune.apply_nsprune(args, nets, train_loader, num_classes=num_classes)
            # zenprune.apply_SAP(args, nets, train_loader, criterion, num_classes=num_classes)
        elif args.prune_method == 'OrthoSNIP':
            snip.apply_orthosnip(args,
                                 nets,
                                 train_loader,
                                 criterion,
                                 num_classes=num_classes)
        elif args.prune_method == 'Delta':
            snip.apply_prune_active(nets)
        elif args.prune_method == 'FTP':
            ftprune.apply_fpt(args,
                              nets,
                              train_loader,
                              num_classes=num_classes)
        elif args.prune_method == 'NTK':
            _, ntk_loader = load.dataloader(args.dataset, 10, True,
                                            args.workers)
            if args.adv:
                ntkprune.ntk_prune_adv(args,
                                       nets,
                                       ntk_loader,
                                       num_classes=num_classes)
            else:
                # ntkprune.ntk_prune(args, nets, ntk_loader, num_classes=num_classes)
                ntkprune.ntk_ep_prune(args,
                                      nets,
                                      ntk_loader,
                                      num_classes=num_classes)

        # utils.save_sparsity(model, args.save_dir)
        if args.compute_sv:
            if args.rescale:
                ###################################
                # Rescale each prunable layer's weights by its largest
                # singular value.
                # NOTE(review): `given_sparsity` is only defined on the
                # 'TEST' and 'RAND' branches above — with --rescale and any
                # other prune method this raises NameError. Confirm intended
                # usage is restricted to those methods.
                _, svmax, _, _, _, _, _ = utils.get_sv(model, size_hook)
                applied_layer = 0
                for net in nets:
                    for layer in net.modules():
                        if isinstance(
                                layer,
                            (nn.Linear, nn.Conv2d, nn.ConvTranspose2d)):
                            if given_sparsity[applied_layer] != 1:
                                # add_mask_rand_basedonchannel(layer, sparsity[applied_layer], ratio, True, structured)
                                with torch.no_grad():
                                    layer.weight /= svmax[applied_layer]
                            applied_layer += 1
                ###################################

            # Snapshot singular-value stats immediately after pruning.
            sv, svmax, sv20, sv50, sv80, sv50p, sv80p = utils.get_sv(
                model, size_hook)
            training_sv.append(sv)
            # training_sv_avg.append(sv_avg)
            # training_sv_std.append(sv_std)
            training_svmax.append(svmax)
            training_sv20.append(sv20)
            training_sv50.append(sv50)
            training_sv80.append(sv80)
            # training_kclip12.append(kclip12)
            training_sv50p.append(sv50p)
            training_sv80p.append(sv80p)
            # training_kavg.append(kavg)
            np.save(os.path.join(args.save_dir, 'sv.npy'), training_sv)
            # np.save(os.path.join(args.save_dir, 'sv_avg.npy'), training_sv_avg)
            # np.save(os.path.join(args.save_dir, 'sv_std.npy'), training_sv_std)
            np.save(os.path.join(args.save_dir, 'sv_svmax.npy'),
                    training_svmax)
            np.save(os.path.join(args.save_dir, 'sv_sv20.npy'), training_sv20)
            np.save(os.path.join(args.save_dir, 'sv_sv50.npy'), training_sv50)
            np.save(os.path.join(args.save_dir, 'sv_sv80.npy'), training_sv80)
            # np.save(os.path.join(args.save_dir, 'sv_kclip12.npy'), training_kclip12)
            np.save(os.path.join(args.save_dir, 'sv_sv50p.npy'),
                    training_sv50p)
            np.save(os.path.join(args.save_dir, 'sv_sv80p.npy'),
                    training_sv80p)
            # np.save(os.path.join(args.save_dir, 'sv_kavg.npy'), training_kavg)

        print('[*] Sparsity after pruning: ', utils.check_sparsity(model))

    if args.evaluate:
        # Evaluation-only mode: validate the (possibly pruned) model and exit.
        validate(val_loader, model, criterion)
        return

    # reinitialize
    # Fresh optimizer/scheduler so the post-prune run restarts the LR schedule.
    optimizer = torch.optim.SGD(model.parameters(),
                                args.lr,
                                momentum=args.momentum,
                                nesterov=True,
                                weight_decay=args.weight_decay)
    if args.dataset == 'tiny-imagenet':
        lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
            optimizer, milestones=[100, 150], last_epoch=args.start_epoch - 1)
        # milestones=[30, 60, 80], last_epoch=args.start_epoch - 1)
    elif args.dataset == 'cifar100':
        lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
            optimizer,
            milestones=[60, 120],
            gamma=0.2,
            last_epoch=args.start_epoch - 1)
    else:
        lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
            optimizer, milestones=[80, 120], last_epoch=args.start_epoch - 1)
    ## Main (post-prune) training phase. ##
    for epoch in range(args.start_epoch, args.epochs):
        # for epoch in range(args.pre_epochs, args.epochs):

        # train for one epoch
        print('current lr {:.5e}'.format(optimizer.param_groups[0]['lr']))
        train(train_loader,
              model,
              criterion,
              optimizer,
              epoch,
              ortho=args.ortho)
        lr_scheduler.step()

        # evaluate on validation set
        prec1 = validate(val_loader, model, criterion)

        # remember best prec@1 and save checkpoint
        is_best = prec1 > best_prec1
        best_prec1 = max(prec1, best_prec1)

        if epoch > 0 and epoch % args.save_every == 0:
            save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'state_dict': model.state_dict(),
                    'best_prec1': best_prec1,
                },
                is_best,
                filename=os.path.join(args.save_dir, 'checkpoint.th'))
            # if args.prune_method !='NONE':
            #     print(utils.check_layer_sparsity(model))

        # save_checkpoint({
        #     'state_dict': model.state_dict(),
        #     'best_prec1': best_prec1,
        # }, is_best, filename=os.path.join(args.save_dir, 'model.th'))

        # Periodic singular-value snapshot, same files as during pre-training.
        if args.compute_sv and epoch % args.save_every == 0:
            sv, svmax, sv20, sv50, sv80, sv50p, sv80p = utils.get_sv(
                model, size_hook)
            training_sv.append(sv)
            # training_sv_avg.append(sv_avg)
            # training_sv_std.append(sv_std)
            training_svmax.append(svmax)
            training_sv20.append(sv20)
            training_sv50.append(sv50)
            training_sv80.append(sv80)
            # training_kclip12.append(kclip12)
            training_sv50p.append(sv50p)
            training_sv80p.append(sv80p)
            # training_kavg.append(kavg)
            np.save(os.path.join(args.save_dir, 'sv.npy'), training_sv)
            # np.save(os.path.join(args.save_dir, 'sv_avg.npy'), training_sv_avg)
            # np.save(os.path.join(args.save_dir, 'sv_std.npy'), training_sv_std)
            np.save(os.path.join(args.save_dir, 'sv_svmax.npy'),
                    training_svmax)
            np.save(os.path.join(args.save_dir, 'sv_sv20.npy'), training_sv20)
            np.save(os.path.join(args.save_dir, 'sv_sv50.npy'), training_sv50)
            np.save(os.path.join(args.save_dir, 'sv_sv80.npy'), training_sv80)
            # np.save(os.path.join(args.save_dir, 'sv_kclip12.npy'), training_kclip12)
            np.save(os.path.join(args.save_dir, 'sv_sv50p.npy'),
                    training_sv50p)
            np.save(os.path.join(args.save_dir, 'sv_sv80p.npy'),
                    training_sv80p)
Exemplo n.º 15
0
def run(args):
    """Pre-train a model once, then sweep (pruner, compression) settings.

    For every compression ratio and every (pruner, prune-epochs) pair, the
    model/optimizer/scheduler are reset to the shared pre-trained snapshot,
    the model is pruned to sparsity ``10**(-compression)``, fine-tuned, and
    the per-setting results are pickled under ``args.result_dir``.

    Requires ``args.save`` (an experiment id) to be set; exits otherwise.
    """
    if not args.save:
        print("This experiment requires an expid.")
        quit()

    ## Random Seed and Device ##
    torch.manual_seed(args.seed)
    device = load.device(args.gpu)

    ## Data ##
    print('Loading {} dataset.'.format(args.dataset))
    input_shape, num_classes = load.dimension(args.dataset)
    # Pruning scores are computed on a small subset sized proportionally
    # to the number of classes.
    prune_loader = load.dataloader(args.dataset, args.prune_batch_size, True,
                                   args.workers,
                                   args.prune_dataset_ratio * num_classes)
    train_loader = load.dataloader(args.dataset, args.train_batch_size, True,
                                   args.workers)
    test_loader = load.dataloader(args.dataset, args.test_batch_size, False,
                                  args.workers)

    ## Model ##
    print('Creating {} model.'.format(args.model))
    model = load.model(args.model,
                       args.model_class)(input_shape, num_classes,
                                         args.dense_classifier,
                                         args.pretrained).to(device)
    loss = nn.CrossEntropyLoss()
    opt_class, opt_kwargs = load.optimizer(args.optimizer)
    optimizer = opt_class(generator.parameters(model),
                          lr=args.lr,
                          weight_decay=args.weight_decay,
                          **opt_kwargs)
    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer,
                                                     milestones=args.lr_drops,
                                                     gamma=args.lr_drop_rate)

    ## Pre-Train ##
    print('Pre-Train for {} epochs.'.format(args.pre_epochs))
    pre_result = train_eval_loop(model, loss, optimizer, scheduler,
                                 train_loader, test_loader, device,
                                 args.pre_epochs, args.verbose)
    pre_result.to_pickle("{}/pre-train.pkl".format(args.result_dir))

    ## Save Original ##
    # Snapshot of the pre-trained state; every sweep iteration below
    # restarts from these files.
    torch.save(model.state_dict(), "{}/model.pt".format(args.result_dir))
    torch.save(optimizer.state_dict(),
               "{}/optimizer.pt".format(args.result_dir))
    torch.save(scheduler.state_dict(),
               "{}/scheduler.pt".format(args.result_dir))

    ## Prune and Fine-Tune##
    for compression in args.compression_list:
        # Loop variable renamed from `p` (original) to avoid being shadowed
        # by the `metrics.summary` lambda parameter below.
        for pruner_name, p_epochs in zip(args.pruner_list,
                                         args.prune_epoch_list):
            print('{} compression ratio, {} pruners'.format(
                compression, pruner_name))

            # Reset Model, Optimizer, and Scheduler
            model.load_state_dict(
                torch.load("{}/model.pt".format(args.result_dir),
                           map_location=device))
            optimizer.load_state_dict(
                torch.load("{}/optimizer.pt".format(args.result_dir),
                           map_location=device))
            scheduler.load_state_dict(
                torch.load("{}/scheduler.pt".format(args.result_dir),
                           map_location=device))

            # Prune Model: keep a 10**(-compression) fraction of weights.
            pruner = load.pruner(pruner_name)(generator.masked_parameters(
                model, args.prune_bias, args.prune_batchnorm,
                args.prune_residual))
            sparsity = 10**(-float(compression))
            prune_loop(model, loss, pruner, prune_loader, device, sparsity,
                       args.compression_schedule, args.mask_scope, p_epochs,
                       args.reinitialize)

            # Prune Result: per-layer sparsity/FLOP summary of the pruned net.
            prune_result = metrics.summary(
                model, pruner.scores, metrics.flop(model, input_shape, device),
                lambda param: generator.prunable(param, args.prune_batchnorm,
                                                 args.prune_residual))

            # Train Model (fine-tune the pruned network)
            post_result = train_eval_loop(model, loss, optimizer, scheduler,
                                          train_loader, test_loader, device,
                                          args.post_epochs, args.verbose)

            # Save Data
            post_result.to_pickle("{}/post-train-{}-{}-{}.pkl".format(
                args.result_dir, pruner_name, str(compression), p_epochs))
            prune_result.to_pickle("{}/compression-{}-{}-{}.pkl".format(
                args.result_dir, pruner_name, str(compression), p_epochs))
Exemplo n.º 16
0
def run(args):
    """Iterative-magnitude-style train/prune experiment over multiple models.

    For each (compression, level) pair, trains ``args.models_num`` fresh
    models through ``level`` rounds of train -> prune -> rewind-weights,
    fine-tunes each, and pickles per-model and aggregate statistics under
    ``args.result_dir``. Requires ``args.save`` (an expid) to be set.
    """
    if not args.save:
        print("This experiment requires an expid.")
        quit()

    ## Random Seed and Device ##
    torch.manual_seed(args.seed)
    device = load.device(args.gpu)

    ## Data ##
    print('Loading {} dataset.'.format(args.dataset))
    input_shape, num_classes = load.dimension(args.dataset)
    # Pruning scores use a small class-proportional subset of training data.
    prune_loader = load.dataloader(args.dataset, args.prune_batch_size, True,
                                   args.workers,
                                   args.prune_dataset_ratio * num_classes)
    train_loader = load.dataloader(args.dataset, args.train_batch_size, True,
                                   args.workers)
    test_loader = load.dataloader(args.dataset, args.test_batch_size, False,
                                  args.workers)

    loss = nn.CrossEntropyLoss()

    for compression in args.compression_list:
        for level in args.level_list:

            ## Models ##
            #models = []
            # One summary row (best losses / best accuracies) per model.
            rows = []
            for i in range(args.models_num):
                print('Creating {} model ({}/{}).'.format(
                    args.model, str(i + 1), str(args.models_num)))
                model = load.model(args.model, args.model_class)(
                    input_shape, num_classes, args.dense_classifier,
                    args.pretrained).to(device)

                opt_class, opt_kwargs = load.optimizer(args.optimizer)
                optimizer = opt_class(generator.parameters(model),
                                      lr=args.lr,
                                      weight_decay=args.weight_decay,
                                      **opt_kwargs)
                scheduler = torch.optim.lr_scheduler.MultiStepLR(
                    optimizer,
                    milestones=args.lr_drops,
                    gamma=args.lr_drop_rate)

                ## Save Original ##
                # Initial-state snapshot used to rewind weights after each
                # prune round (note: overwritten for every model i).
                torch.save(model.state_dict(),
                           "{}/model.pt".format(args.result_dir))
                torch.save(optimizer.state_dict(),
                           "{}/optimizer.pt".format(args.result_dir))
                torch.save(scheduler.state_dict(),
                           "{}/scheduler.pt".format(args.result_dir))

                ## Train-Prune Loop ##
                print('{} compression ratio, {} train-prune levels'.format(
                    compression, level))

                # NOTE(review): if level == 0 this loop never runs and
                # `pruner` below is undefined (NameError).
                for l in range(level):
                    # Pre Train Model
                    train_eval_loop(model, loss, optimizer, scheduler,
                                    train_loader, test_loader, device,
                                    args.pre_epochs, args.verbose)

                    # Prune Model
                    pruner = load.pruner(args.pruner)(
                        generator.masked_parameters(model, args.prune_bias,
                                                    args.prune_batchnorm,
                                                    args.prune_residual))
                    if args.pruner == 'synflown':
                        pruner.set_input(train_loader, num_classes, device)
                    # Geometric schedule: after round l+1 of `level`, overall
                    # sparsity is (10**-compression)**((l+1)/level).
                    sparsity = (10**(-float(compression)))**((l + 1) / level)
                    prune_loop(model, loss, pruner, prune_loader, device,
                               sparsity, args.compression_schedule,
                               args.mask_scope, args.prune_epochs, False,
                               args.prune_train_mode)

                    # Reset Model's Weights
                    original_dict = torch.load("{}/model.pt".format(
                        args.result_dir),
                                               map_location=device)
                    # Keep only entries that still require grad after load
                    # (the trainable weights); presumably this excludes
                    # buffers such as pruning masks, so masks survive the
                    # rewind — TODO confirm masks are registered as buffers.
                    original_weights = dict(
                        filter(lambda v: (v[1].requires_grad == True),
                               original_dict.items()))
                    model_dict = model.state_dict()
                    model_dict.update(original_weights)
                    model.load_state_dict(model_dict)

                    # Reset Optimizer and Scheduler
                    optimizer.load_state_dict(
                        torch.load("{}/optimizer.pt".format(args.result_dir),
                                   map_location=device))
                    scheduler.load_state_dict(
                        torch.load("{}/scheduler.pt".format(args.result_dir),
                                   map_location=device))

                # Prune Result (uses the pruner from the final round)
                prune_result = metrics.summary(
                    model, pruner.scores,
                    metrics.flop(model, input_shape, device),
                    lambda p: generator.prunable(p, args.prune_batchnorm, args.
                                                 prune_residual))
                # Train Model
                post_result = train_eval_loop(model, loss, optimizer,
                                              scheduler, train_loader,
                                              test_loader, device,
                                              args.post_epochs, args.verbose)

                # Save Data
                post_result.to_pickle("{}/post-train-{}-{}-{}-{}.pkl".format(
                    args.result_dir, args.pruner, str(compression), str(level),
                    str(i + 1)))
                prune_result.to_pickle("{}/compression-{}-{}-{}-{}.pkl".format(
                    args.result_dir, args.pruner, str(compression), str(level),
                    str(i + 1)))
                #torch.save(model.state_dict(),"{}/model{}.pt".format(args.result_dir, str(i+1)))
                #models.append(model)
                # Best-over-epochs summary: min losses, max accuracies
                # (column order assumed: train_loss, test_loss, top1, top5).
                min_vals = post_result.min().values
                max_vals = post_result.max().values
                row = [min_vals[0], min_vals[1], max_vals[2], max_vals[3]]
                rows.append(row)

            # Evaluate the combined model
            columns = [
                'train_loss', 'test_loss', 'top1_accuracy', 'top5_accuracy'
            ]
            df = pd.DataFrame(rows, columns=columns)
            df.to_pickle("{}/stats-{}-{}-{}-{}.pkl".format(
                args.result_dir, args.pruner, str(compression), str(level),
                str(args.models_num)))

            print('Mean values:')
            print(df.mean())

            print('Standard deviations:')
            print(df.std())

    print('Done!')
Exemplo n.º 17
0
def run(args):
    """Prune a freshly initialized model with every pruner in
    ``args.pruner_list`` and record per-layer score statistics.

    For each pruner, the model is pruned to sparsity
    ``10**(-args.compression)`` and the average absolute parameter score of
    each prunable layer (Linear/Conv2d, weights plus bias) is saved to
    ``<result_dir>/<pruner>.npy``. The inverse layer sizes are saved once
    to ``<result_dir>/inv-size.npy``.
    """
    if not args.save:
        print("This experiment requires an expid.")
        quit()

    ## Random Seed and Device ##
    torch.manual_seed(args.seed)
    device = load.device(args.gpu)

    ## Data ##
    input_shape, num_classes = load.dimension(args.dataset)
    # Fix: this loader was previously bound to ``data_loader`` while
    # ``prune_loop`` below referenced the undefined name ``prune_loader``,
    # which raised a NameError at runtime.
    prune_loader = load.dataloader(args.dataset, args.prune_batch_size, True,
                                   args.workers,
                                   args.prune_dataset_ratio * num_classes)

    ## Model, Loss, Optimizer ##
    model = load.model(args.model,
                       args.model_class)(input_shape, num_classes,
                                         args.dense_classifier,
                                         args.pretrained).to(device)
    loss = nn.CrossEntropyLoss()

    ## Compute Layer Name and Inverse Size ##
    def layer_names(model):
        """Return (names, 1/num_parameters) for every prunable layer."""
        names = []
        inv_size = []
        for name, module in model.named_modules():
            if isinstance(module, (layers.Linear, layers.Conv2d)):
                num_elements = np.prod(module.weight.shape)
                if module.bias is not None:
                    num_elements += np.prod(module.bias.shape)
                names.append(name)
                inv_size.append(1.0 / num_elements)
        return names, inv_size

    ## Compute Average Layer Score ##
    def average_layer_score(model, scores):
        """Return the mean |score| per prunable layer (weights + bias)."""
        average_scores = []
        for name, module in model.named_modules():
            if isinstance(module, (layers.Linear, layers.Conv2d)):
                W = module.weight
                W_score = scores[id(W)].detach().cpu().numpy()
                score_sum = W_score.sum()
                num_elements = np.prod(W.shape)

                if module.bias is not None:
                    b = module.bias
                    b_score = scores[id(b)].detach().cpu().numpy()
                    score_sum += b_score.sum()
                    num_elements += np.prod(b.shape)

                average_scores.append(np.abs(score_sum / num_elements))
        return average_scores

    ## Loop through Pruners and Save Data ##
    names, inv_size = layer_names(model)
    for p in args.pruner_list:
        pruner = load.pruner(p)(generator.masked_parameters(
            model, args.prune_bias, args.prune_batchnorm,
            args.prune_residual))
        sparsity = 10**(-float(args.compression))
        prune_loop(model, loss, pruner, prune_loader, device, sparsity,
                   args.compression_schedule, args.mask_scope,
                   args.prune_epochs, args.reinitialize)
        average_score = average_layer_score(model, pruner.scores)
        np.save('{}/{}'.format(args.result_dir, p), np.array(average_score))
    np.save('{}/{}'.format(args.result_dir, 'inv-size'), inv_size)
Exemplo n.º 18
0
def run(args):
    """Iterative (lottery-ticket style) train-prune-rewind experiment.

    Trains for ``args.pre_epochs``, prunes to a fixed ``args.sparsity``,
    rewinds the surviving weights to their initial (or mid-training)
    values, and repeats for ``args.level_list[0]`` levels before a final
    post-training run. Checkpoints and metric pickles are written to
    ``args.result_dir``.
    """
    if not args.save:
        print("This experiment requires an expid.")
        quit()

    ## Random Seed and Device ##
    torch.manual_seed(args.seed)
    device = load.device(args.gpu, args.seed)

    ## Data ##
    print('Loading {} dataset.'.format(args.dataset))
    input_shape, num_classes = load.dimension(args.dataset)

    # With --validation, hold out a validation split for evaluation;
    # otherwise train on train+val and evaluate on the test split.
    if args.validation:
        trainset = 'train'
        evalset = 'val'
    else:
        trainset = 'trainval'
        evalset = 'test'

    prune_loader = load.dataloader(args.dataset,
                                   args.train_batch_size,
                                   trainset,
                                   args.workers,
                                   corrupt_prob=args.prune_corrupt,
                                   seed=args.split_seed)
    train_loader = load.dataloader(args.dataset,
                                   args.train_batch_size,
                                   trainset,
                                   args.workers,
                                   corrupt_prob=args.train_corrupt,
                                   seed=args.split_seed)
    test_loader = load.dataloader(args.dataset,
                                  args.test_batch_size,
                                  evalset,
                                  args.workers,
                                  seed=args.split_seed)

    ## Model ##
    print('Creating {} model.'.format(args.model))
    model = load.model(args.model,
                       args.model_class)(input_shape, num_classes,
                                         args.dense_classifier,
                                         args.pretrained).to(device)
    loss = nn.CrossEntropyLoss()
    opt_class, opt_kwargs = load.optimizer(args.optimizer)
    optimizer = opt_class(generator.parameters(model),
                          lr=args.lr,
                          weight_decay=args.weight_decay,
                          **opt_kwargs)
    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer,
                                                     milestones=args.lr_drops,
                                                     gamma=args.lr_drop_rate)

    ## Save Original ##
    # Snapshot the untrained state so each level can rewind to it.
    torch.save(model.state_dict(), "{}/model.pt".format(args.result_dir))
    torch.save(optimizer.state_dict(),
               "{}/optimizer.pt".format(args.result_dir))
    torch.save(scheduler.state_dict(),
               "{}/scheduler.pt".format(args.result_dir))

    ## Train-Prune Loop ##
    sparsity = args.sparsity
    try:
        level = args.level_list[0]
    except IndexError:
        raise ValueError("'--level-list' must have size >= 1.")

    print('{} compression ratio, {} train-prune levels'.format(
        sparsity, level))

    # Reset Model, Optimizer, and Scheduler
    model.load_state_dict(
        torch.load("{}/model.pt".format(args.result_dir), map_location=device))
    optimizer.load_state_dict(
        torch.load("{}/optimizer.pt".format(args.result_dir),
                   map_location=device))
    scheduler.load_state_dict(
        torch.load("{}/scheduler.pt".format(args.result_dir),
                   map_location=device))

    for l in range(level):

        # Pre Train Model
        # With rewinding enabled, the midsave variant additionally
        # checkpoints the weights after ``args.rewind_epochs`` epochs so
        # they can be restored below instead of the initial weights.
        if args.rewind_epochs > 0:
            level_train_result = train_eval_loop_midsave(
                model, loss, optimizer, scheduler, prune_loader, test_loader,
                device, args.pre_epochs, args.verbose, args.rewind_epochs)
        else:
            level_train_result = train_eval_loop(model, loss, optimizer,
                                                 scheduler, prune_loader,
                                                 test_loader, device,
                                                 args.pre_epochs, args.verbose)

        level_train_result.to_pickle(
            f'{args.result_dir}/train-level-{l}-metrics.pkl')

        torch.save(model.state_dict(),
                   "{}/train-level-{}.pt".format(args.result_dir, l))

        # Prune Model
        pruner = load.pruner(args.pruner)(generator.masked_parameters(
            model, args.prune_bias, args.prune_batchnorm, args.prune_residual))
        #sparsity = (10**(-float(compression)))**((l + 1) / level)
        prune_loop(model, loss, pruner, prune_loader, device, sparsity,
                   args.compression_schedule, args.mask_scope,
                   args.prune_epochs, args.reinitialize, args.prune_train_mode,
                   args.shuffle, args.invert)

        # Prune Result
        prune_result = metrics.summary(
            model, pruner.scores, metrics.flop(model, input_shape, device),
            lambda p: generator.prunable(p, args.prune_batchnorm, args.
                                         prune_residual))
        prune_result.to_pickle("{}/sparsity-{}-{}-{}.pkl".format(
            args.result_dir, args.pruner, str(sparsity), str(l + 1)))

        # Reset Model's Weights
        # NOTE(review): the midway checkpoint is loaded from the current
        # working directory, not ``args.result_dir`` — presumably matching
        # where ``train_eval_loop_midsave`` writes it; verify against that
        # helper.
        if args.rewind_epochs > 0:
            original_dict = torch.load("model_pretrain_midway.pt",
                                       map_location=device)
        else:
            original_dict = torch.load("{}/model.pt".format(args.result_dir),
                                       map_location=device)
        # Restore everything except the pruning masks, so the mask found at
        # this level stays applied to the rewound weights.
        original_weights = dict(
            filter(lambda v: 'mask' not in v[0], original_dict.items()))
        model_dict = model.state_dict()
        model_dict.update(original_weights)
        model.load_state_dict(model_dict)

        # Reset Optimizer and Scheduler
        optimizer.load_state_dict(
            torch.load("{}/optimizer.pt".format(args.result_dir),
                       map_location=device))
        scheduler.load_state_dict(
            torch.load("{}/scheduler.pt".format(args.result_dir),
                       map_location=device))

    torch.save(model.state_dict(),
               "{}/post-prune-model.pt".format(args.result_dir))

    ## Compute Path Count ##
    # print("Number of paths", get_path_count(mdl=model, arch=args.model))
    print("Number of active filters",
          number_active_filters(mdl=model, arch=args.model))

    # Train Model
    post_result = train_eval_loop(model, loss, optimizer, scheduler,
                                  train_loader, test_loader, device,
                                  args.post_epochs, args.verbose)

    # Save Data
    post_result.to_pickle("{}/post-train-{}-{}-{}.pkl".format(
        args.result_dir, args.pruner, str(sparsity), str(level)))

    # Save final model
    torch.save(model.state_dict(), f'{args.result_dir}/post-final-train.pt')
    torch.save(optimizer.state_dict(), f'{args.result_dir}/optimizer.pt')
    torch.save(scheduler.state_dict(), f'{args.result_dir}/scheduler.pt')
def run(args):
    """Measure cut-size / max-flow ratios of iterative SynFlow-style pruning.

    For every compression exponent in ``args.compression_list`` and every
    epoch count in ``args.prune_epoch_list``, the model is reset to its
    initial weights and pruned iteratively under both a 'linear' and an
    'exponential' sparsity schedule. The worst (max) cutsize/maxflow ratio
    of each run is collected and saved to ``<result_dir>/ratios.npy``.
    """
    if not args.save:
        print("This experiment requires an expid.")
        quit()

    ## Random Seed and Device ##
    torch.manual_seed(args.seed)
    device = load.device(args.gpu)

    ## Data ##
    input_shape, num_classes = load.dimension(args.dataset)
    data_loader = load.dataloader(args.dataset, args.prune_batch_size, True,
                                  args.workers,
                                  args.prune_dataset_ratio * num_classes)

    ## Model, Loss, Optimizer ##
    model = load.model(args.model,
                       args.model_class)(input_shape, num_classes,
                                         args.dense_classifier,
                                         args.pretrained).to(device)
    loss = nn.CrossEntropyLoss()
    # Snapshot initial weights so every (exp, epochs) run starts identically.
    torch.save(model.state_dict(), "{}/model.pt".format(args.result_dir))

    def score(parameters, model, loss, dataloader, device):
        """SynFlow scoring pass.

        Temporarily replaces all parameters with their absolute values,
        feeds a single all-ones input, and scores each parameter as
        |grad * param|. Returns (scores keyed by id(param), total output
        sum interpreted as the network's 'max flow').
        """
        @torch.no_grad()
        def linearize(model):
            # Record signs and make every parameter non-negative so the
            # forward pass computes a purely positive flow.
            signs = {}
            for name, param in model.state_dict().items():
                signs[name] = torch.sign(param)
                param.abs_()
            return signs

        @torch.no_grad()
        def nonlinearize(model, signs):
            # Restore the original parameter signs in place.
            for name, param in model.state_dict().items():
                param.mul_(signs[name])

        signs = linearize(model)
        # Use one batch only to infer the input shape; content is ignored.
        (data, _) = next(iter(dataloader))
        input_dim = list(data[0, :].shape)
        input = torch.ones([1] + input_dim).to(device)
        output = model(input)
        maxflow = torch.sum(output)
        maxflow.backward()
        scores = {}
        for _, p in parameters:
            scores[id(p)] = torch.clone(p.grad * p).detach().abs_()
            p.grad.data.zero_()
        nonlinearize(model, signs)

        return scores, maxflow.item()

    def mask(parameters, scores, sparsity):
        """Globally zero the masks of the k lowest-scoring parameters.

        k = (1 - sparsity) * total parameter count. Returns the 'cut size':
        the summed score of the pruned parameters (0 if nothing is pruned).
        """
        global_scores = torch.cat([torch.flatten(v) for v in scores.values()])
        k = int((1.0 - sparsity) * global_scores.numel())
        cutsize = 0
        if not k < 1:
            cutsize = torch.sum(
                torch.topk(global_scores, k, largest=False).values).item()
            threshold, _ = torch.kthvalue(global_scores, k)
            for mask, param in parameters:
                score = scores[id(param)]
                zero = torch.tensor([0.]).to(mask.device)
                one = torch.tensor([1.]).to(mask.device)
                mask.copy_(torch.where(score <= threshold, zero, one))
        return cutsize

    @torch.no_grad()
    def apply_mask(parameters):
        # Zero out pruned weights in place before the next scoring pass.
        for mask, param in parameters:
            param.mul_(mask)

    results = []
    for style in ['linear', 'exponential']:
        print(style)
        sparsity_ratios = []
        for i, exp in enumerate(args.compression_list):
            max_ratios = []
            for j, epochs in enumerate(args.prune_epoch_list):
                # Reset to the saved initial weights for a fair comparison.
                model.load_state_dict(
                    torch.load("{}/model.pt".format(args.result_dir),
                               map_location=device))
                parameters = list(
                    generator.masked_parameters(model, args.prune_bias,
                                                args.prune_batchnorm,
                                                args.prune_residual))
                model.eval()
                ratios = []
                for epoch in tqdm(range(epochs)):
                    apply_mask(parameters)
                    scores, maxflow = score(parameters, model, loss,
                                            data_loader, device)
                    sparsity = 10**(-float(exp))
                    # Intermediate sparsity target for this pruning epoch.
                    if style == 'linear':
                        sparse = 1.0 - (1.0 - sparsity) * (
                            (epoch + 1) / epochs)
                    if style == 'exponential':
                        sparse = sparsity**((epoch + 1) / epochs)
                    cutsize = mask(parameters, scores, sparse)
                    ratios.append(cutsize / maxflow)
                max_ratios.append(max(ratios))
            sparsity_ratios.append(max_ratios)
        results.append(sparsity_ratios)
    np.save('{}/ratios'.format(args.result_dir), np.array(results))
Exemplo n.º 20
0
def main():
    """Prune-at-initialization experiment driver.

    Parses CLI args, loads the dataset and model, optionally pre-trains a
    dense network (only for ``prune_method == 'NONE'``), applies the
    selected pruning method (NONE/SNIP/RAND/GRASP/Zen/Mag/Synflow), then
    trains the resulting sparse network while optionally tracking singular
    values, NTK eigenvalues and Zen-scores. Checkpoints and statistics are
    written to ``args.save_dir``.
    """
    global args, best_prec1
    args = parser.parse_args()
    # Overall sparsity level shrinks geometrically with each prune iteration.
    args.sparse_lvl = 0.8 ** args.sparse_iter
    print(args.sparse_lvl)

    # Check the save_dir exists or not
    if not os.path.exists(args.save_dir):
        os.makedirs(args.save_dir)

    torch.manual_seed(args.seed)

    cudnn.benchmark = True

    # Dataset-specific hyperparameter overrides happen in these branches.
    if args.dataset =='cifar10':
        print('Loading {} dataset.'.format(args.dataset))
        input_shape, num_classes = load.dimension(args.dataset)
        train_dataset, train_loader = load.dataloader(args.dataset, args.batch_size, True, args.workers)
        _, val_loader = load.dataloader(args.dataset, 128, False, args.workers)

    elif args.dataset == 'tiny-imagenet':
        args.batch_size = 256
        args.lr = 0.2
        args.epochs = 200
        print('Loading {} dataset.'.format(args.dataset))
        input_shape, num_classes = load.dimension(args.dataset)
        train_dataset, train_loader = load.dataloader(args.dataset, args.batch_size, True, args.workers)
        _, val_loader = load.dataloader(args.dataset, 128, False, args.workers)

    elif args.dataset == 'cifar100':
        args.batch_size = 128
        # args.lr = 0.01
        args.epochs = 160
        # args.weight_decay = 5e-4
        input_shape, num_classes = load.dimension(args.dataset)
        train_dataset, train_loader = load.dataloader(args.dataset, args.batch_size, True, args.workers)
        _, val_loader = load.dataloader(args.dataset, 128, False, args.workers)

    # Architecture selection; each branch constructs and moves the model to GPU.
    if args.arch == 'resnet20':
        print('Creating {} model.'.format(args.arch))
        # model = torch.nn.DataParallel(resnet.__dict__[args.arch](ONI=args.ONI, T_iter=args.T_iter))
        model = resnet.__dict__[args.arch](ONI=args.ONI, T_iter=args.T_iter)
        model.cuda()
    elif args.arch == 'resnet18':
        print('Creating {} model.'.format(args.arch))
        # Using resnet18 from Synflow
        # model = load.model(args.arch, 'tinyimagenet')(input_shape, 
        #                                              num_classes,
        #                                              dense_classifier = True).cuda()
        # Using resnet18 from torchvision
        model = models.resnet18()
        model.fc = nn.Linear(512, num_classes)
        model.cuda()
        utils.kaiming_initialize(model)
    elif args.arch == 'resnet110' or args.arch == 'resnet110full':
        # Using resnet110 from Apollo
        # model = apolo_resnet.ResNet(110, num_classes=num_classes)
        model = load.model(args.arch, 'lottery')(input_shape, 
                                             num_classes,
                                             dense_classifier = True).cuda()
    elif args.arch in ['vgg16full', 'vgg16full-bn', 'vgg11full', 'vgg11full-bn'] :
        if args.dataset == 'tiny-imagenet':
            modeltype = 'tinyimagenet'
        else:
            modeltype = 'lottery'
        # Using resnet110 from Apollo
        # model = apolo_resnet.ResNet(110, num_classes=num_classes)
        model = load.model(args.arch, modeltype)(input_shape, 
                                             num_classes,
                                             dense_classifier = True).cuda()

    # for layer in model.modules():
    #     if isinstance(layer, nn.Linear):
    #         init.orthogonal_(layer.weight.data)
    #     elif isinstance(layer, (nn.Conv2d, nn.ConvTranspose2d)):
    #         special_init.DeltaOrthogonal_init(layer.weight.data)

    print('Number of parameters of model: {}.'.format(count_parameters(model)))

    # define loss function (criterion) and optimizer
    criterion = nn.CrossEntropyLoss().cuda()

    if args.compute_sv:
        # Track singular-value statistics of Linear/Conv layers during training.
        print('[*] Will compute singular values throught training.')
        size_hook = utils.get_hook(model, (nn.Linear, nn.Conv2d, nn.ConvTranspose2d))
        utils.run_once(train_loader, model)
        utils.detach_hook([size_hook])
        training_sv = []
        training_svmax = []
        training_sv20 = [] # 20% singular value
        training_sv50 = [] # 50% singular value
        training_sv80 = [] # 80% singular value
        training_kclip = [] # singular values larger than 1e-12
        sv, svmax, sv20, sv50, sv80, kclip = utils.get_sv(model, size_hook)
        training_sv.append(sv)
        training_svmax.append(svmax)
        training_sv20.append(sv20)
        training_sv50.append(sv50)
        training_sv80.append(sv80)
        training_kclip.append(kclip)

    if args.compute_ntk:
        # Track the NTK eigenvalue spectrum; use a small dedicated loader
        # when the class count makes per-class sampling too large.
        training_ntk_eig = []
        if num_classes>=32:
            _, ntk_loader = load.dataloader(args.dataset, 32, True, args.workers)
            grasp_fetch = False
        else:
            ntk_loader = train_loader
            grasp_fetch = True
        training_ntk_eig.append(utils.get_ntk_eig(ntk_loader, [model], train_mode = True, num_batch=1, num_classes=num_classes, samples_per_class=1, grasp_fetch=grasp_fetch))

    if args.compute_lrs:
        # Linear-region counting is disabled; report Zen-scores instead.
        # training_lrs = []
        # lrc_model = utils.Linear_Region_Collector(train_loader, input_size=(args.batch_size,*input_shape), sample_batch=300)
        # lrc_model.reinit(models=[model])
        # lrs = lrc_model.forward_batch_sample()[0]
        # training_lrs.append(lrs)
        # lrc_model.clear_hooks()
        # print('[*] Current number of linear regions:{}'.format(lrs))
        GAP_zen, output_zen = utils.get_zenscore(model, train_loader, args.arch, num_classes)
        print('[*] Before pruning: GAP_zen:{:e}, output_zen:{:e}'.format(GAP_zen,output_zen))

    if args.half:
        # Optional half-precision training.
        model.half()
        criterion.half()

    optimizer = torch.optim.SGD(model.parameters(), args.lr,
                                momentum=args.momentum,
                                nesterov = True,
                                weight_decay=args.weight_decay)

    # LR schedule milestones depend on the dataset.
    if args.dataset ==  'tiny-imagenet':
        lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer,
                                                        milestones=[100, 150], last_epoch=args.start_epoch - 1)
                                                        # milestones=[30, 60, 80], last_epoch=args.start_epoch - 1)
    elif args.dataset ==  'cifar100':
        lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer,
                                                        milestones=[60, 120], gamma = 0.2, last_epoch=args.start_epoch - 1)
    else:
        lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer,
                                                        milestones=[80, 120], last_epoch=args.start_epoch - 1)

    # This part is for training full NN model to obtain Lottery ticket

    # # First save original network:
    init_path = os.path.join(args.save_dir, 'init_checkpoint.th')
    save_checkpoint({
                'state_dict': model.state_dict()
            }, False, filename=init_path)

    # Dense pre-training only happens for the magnitude/lottery pipeline
    # (prune_method == 'NONE'); every other method prunes at initialization.
    if args.prune_method == 'NONE':
        pre_epochs = args.epochs
    else:
        pre_epochs = 0

    training_loss = []
    for epoch in range(pre_epochs):

        # train for one epoch
        print('current lr {:.5e}'.format(optimizer.param_groups[0]['lr']))
        train(train_loader, model, criterion, optimizer, epoch, track = training_loss)
        lr_scheduler.step()

        # evaluate on validation set
        prec1 = validate(val_loader, model, criterion)

        # remember best prec@1 and save checkpoint
        is_best = prec1 > best_prec1
        best_prec1 = max(prec1, best_prec1)

        if epoch > 0 and epoch % args.save_every == 0:
            save_checkpoint({
                'epoch': epoch + 1,
                'state_dict': model.state_dict(),
                'best_prec1': best_prec1,
            }, is_best, filename=os.path.join(args.save_dir, 'densenet_checkpoint.th'))

        if args.compute_sv and epoch % args.save_every == 0:
            sv,  svmax, sv20, sv50, sv80, kclip= utils.get_sv(model, size_hook)
            training_sv.append(sv)
            training_svmax.append(svmax)
            training_sv20.append(sv20)
            training_sv50.append(sv50)
            training_sv80.append(sv80)
            training_kclip.append(kclip)
            np.save(os.path.join(args.save_dir, 'sv.npy'), training_sv)
            np.save(os.path.join(args.save_dir, 'sv_svmax.npy'), training_svmax)
            np.save(os.path.join(args.save_dir, 'sv_sv20.npy'), training_sv20)
            np.save(os.path.join(args.save_dir, 'sv_sv50.npy'), training_sv50)
            np.save(os.path.join(args.save_dir, 'sv_sv80.npy'), training_sv80)
            np.save(os.path.join(args.save_dir, 'sv_kclip.npy'), training_kclip)

        if args.compute_ntk and epoch % args.save_every == 0:
            training_ntk_eig.append(utils.get_ntk_eig(ntk_loader, [model], train_mode = True, num_batch=1, num_classes=num_classes, samples_per_class=1, grasp_fetch=grasp_fetch))
            np.save(os.path.join(args.save_dir, 'ntk_eig.npy'), training_ntk_eig)



    print('[*] {} epochs of dense network pre-training done'.format(pre_epochs))
    np.save(os.path.join(args.save_dir, 'trainloss.npy'), training_loss)

    # densenet_checkpoint = torch.load(os.path.join(args.save_dir, 'densenet_checkpoint.th'))
    # model.load_state_dict(densenet_checkpoint['state_dict'])
    # print('Model loaded!')
    # Obtain lottery ticket by magnitude pruning
    # Most methods below rewind to the saved initial weights after computing
    # their mask, so training starts from initialization with the mask applied.
    if args.prune_method == 'NONE':
        snip.apply_mag_prune(args, model)
        # reinitialize
        init_checkpoint = torch.load(init_path)
        model.load_state_dict(init_checkpoint['state_dict'])
        print('Model reinitialized!')
    elif args.prune_method == 'SNIP':
        init_checkpoint = torch.load(init_path)
        model.load_state_dict(init_checkpoint['state_dict'])
        print('Model reinitialized!')
        snip.apply_snip(args, [model], train_loader, criterion, num_classes=num_classes)
        # attack.shuffle_mask(model)
    elif args.prune_method == 'RAND':
        init_checkpoint = torch.load(init_path)
        model.load_state_dict(init_checkpoint['state_dict'])
        print('Model reinitialized!')
        snip.apply_rand_prune([model], args.sparse_lvl)
    elif args.prune_method == 'GRASP':
        init_checkpoint = torch.load(init_path)
        model.load_state_dict(init_checkpoint['state_dict'])
        print('Model reinitialized!')
        snip.apply_grasp(args, [model], train_loader, criterion, num_classes=num_classes)
    elif args.prune_method == 'Zen':
        zenprune.apply_zenprune(args, [model], train_loader)
        # zenprune.apply_cont_zenprune(args, [model], train_loader)
        # zenprune.apply_zentransfer(args, [model], train_loader)
        # init_checkpoint = torch.load(init_path)
        # model.load_state_dict(init_checkpoint['state_dict'])
        # print('Model reinitialized!')
    elif args.prune_method == 'Mag':
        snip.apply_mag_prune(args, model)
        init_checkpoint = torch.load(init_path)
        model.load_state_dict(init_checkpoint['state_dict'])
        print('Model reinitialized!')
    elif args.prune_method == 'Synflow':
        synflow.apply_synflow(args, model)

    print('{} done, sparsity of the current model: {}.'.format(args.prune_method, utils.check_sparsity(model)))

    if args.compute_lrs:
        # Re-measure the Zen-score after pruning for comparison.
        # training_lrs = []
        # lrc_model = utils.Linear_Region_Collector(train_loader, input_size=(args.batch_size,*input_shape), sample_batch=300)
        # lrc_model.reinit(models=[model])
        # lrs = lrc_model.forward_batch_sample()[0]
        # training_lrs.append(lrs)
        # lrc_model.clear_hooks()
        # print('[*] Current number of linear regions:{}'.format(lrs))
        GAP_zen, output_zen = utils.get_zenscore(model, train_loader, args.arch, num_classes)
        print('[*] After pruning: GAP_zen:{:e}, output_zen:{:e}'.format(GAP_zen,output_zen))

    if args.evaluate:
        validate(val_loader, model, criterion)
        return

    # Recreate optimizer and learning scheduler
    optimizer = torch.optim.SGD(model.parameters(), args.lr,
                                momentum=args.momentum,
                                nesterov = True,
                                weight_decay=args.weight_decay)

    if args.dataset ==  'tiny-imagenet':
        lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer,
                                                        milestones=[100, 150], last_epoch=args.start_epoch - 1)
                                                        # milestones=[30, 60, 80], last_epoch=args.start_epoch - 1)
    elif args.dataset ==  'cifar100':
        lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer,
                                                        milestones=[60, 120], gamma = 0.2, last_epoch=args.start_epoch - 1)
    else:
        lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer,
                                                        milestones=[80, 120], last_epoch=args.start_epoch - 1)

    # Sparse-network training loop.
    for epoch in range(args.epochs):
    # for epoch in range(args.pre_epochs, args.epochs):

        # train for one epoch
        print('current lr {:.5e}'.format(optimizer.param_groups[0]['lr']))
        train(train_loader, model, criterion, optimizer, epoch, track=training_loss, ortho=args.ortho)
        lr_scheduler.step()

        # evaluate on validation set
        prec1 = validate(val_loader, model, criterion)

        # remember best prec@1 and save checkpoint
        is_best = prec1 > best_prec1
        best_prec1 = max(prec1, best_prec1)

        if epoch > 0 and epoch % args.save_every == 0:
            save_checkpoint({
                'epoch': epoch + 1,
                'state_dict': model.state_dict(),
                'best_prec1': best_prec1,
            }, is_best, filename=os.path.join(args.save_dir, 'sparsenet_checkpoint.th'))
            # if args.prune_method !='NONE':
            #     print(utils.check_layer_sparsity(model))

        # save_checkpoint({
        #     'state_dict': model.state_dict(),
        #     'best_prec1': best_prec1,
        # }, is_best, filename=os.path.join(args.save_dir, 'model.th'))

        if args.compute_sv and epoch % args.save_every == 0:
            sv,  svmax, sv20, sv50, sv80,  kclip = utils.get_sv(model, size_hook)
            training_sv.append(sv)
            training_svmax.append(svmax)
            training_sv20.append(sv20)
            training_sv50.append(sv50)
            training_sv80.append(sv80)
            training_kclip.append(kclip)
            np.save(os.path.join(args.save_dir, 'sv.npy'), training_sv)
            np.save(os.path.join(args.save_dir, 'sv_svmax.npy'), training_svmax)
            np.save(os.path.join(args.save_dir, 'sv_sv20.npy'), training_sv20)
            np.save(os.path.join(args.save_dir, 'sv_sv50.npy'), training_sv50)
            np.save(os.path.join(args.save_dir, 'sv_sv80.npy'), training_sv80)
            np.save(os.path.join(args.save_dir, 'sv_kclip.npy'), training_kclip)

        if args.compute_ntk and epoch % args.save_every == 0:
            training_ntk_eig.append(utils.get_ntk_eig(ntk_loader, [model], train_mode = True, num_batch=1, num_classes=num_classes, samples_per_class=1, grasp_fetch=grasp_fetch))
            np.save(os.path.join(args.save_dir, 'ntk_eig.npy'), training_ntk_eig)

        # if args.compute_lrs and epoch % args.save_every == 0:
        #     lrc_model.reinit(models=[model])
        #     lrs = lrc_model.forward_batch_sample()[0]
        #     training_lrs.append(lrs)
        #     lrc_model.clear_hooks()
        #     print('[*] Current number of linear regions:{}'.format(lrs))
    np.save(os.path.join(args.save_dir, 'trainloss.npy'), training_loss)