Example No. 1
def run(args):
    if not args.save:
        print("This experiment requires an expid.")
        quit()

    ## Random Seed and Device ##
    torch.manual_seed(args.seed)
    device = load.device(args.gpu)

    ## Data ##
    input_shape, num_classes = load.dimension(args.dataset) 
    data_loader = load.dataloader(args.dataset, args.prune_batch_size, True, args.workers, args.prune_dataset_ratio * num_classes)

    ## Model, Loss, Optimizer ##
    model = load.model(args.model, args.model_class)(input_shape, 
                                                     num_classes, 
                                                     args.dense_classifier, 
                                                     args.pretrained).to(device)
    loss = nn.CrossEntropyLoss()

    ## Compute Layer Name and Inv Size ##
    def layer_names(model):
        names = []
        inv_size = []
        for name, module in model.named_modules():
            if isinstance(module, (layers.Linear, layers.Conv2d)):
                num_elements = np.prod(module.weight.shape)
                if module.bias is not None:
                    num_elements += np.prod(module.bias.shape)
                names.append(name)
                inv_size.append(1.0/num_elements)
        return names, inv_size

    ## Compute Average Layer Score ##
    def average_layer_score(model, scores):
        average_scores = []
        for name, module in model.named_modules():
            if isinstance(module, (layers.Linear, layers.Conv2d)):
                W = module.weight
                W_score = scores[id(W)].detach().cpu().numpy()
                score_sum = W_score.sum()
                num_elements = np.prod(W.shape)

                if module.bias is not None:
                    b = module.bias
                    b_score = scores[id(b)].detach().cpu().numpy()
                    score_sum += b_score.sum()
                    num_elements += np.prod(b.shape)

                average_scores.append(np.abs(score_sum / num_elements))
        return average_scores


    ## Loop through Pruners and Save Data ##
    names, inv_size = layer_names(model)
    average_scores = []
    unit_scores = []
    for i, p in enumerate(args.pruner_list):
        pruner = load.pruner(p)(generator.masked_parameters(model, args.prune_bias, args.prune_batchnorm, args.prune_residual))
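        # args.compression is the log10 of the inverse sparsity (cf. the comment in Example No. 4),
        # so 10**(-compression) below is the target fraction of weights kept by prune_loop.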
        sparsity = 10**(-float(args.compression))
        prune_loop(model, loss, pruner, data_loader, device, sparsity,
                   args.compression_schedule, args.mask_scope, args.prune_epochs, args.reinitialize)
        average_score = average_layer_score(model, pruner.scores)
        average_scores.append(average_score)
        np.save('{}/{}'.format(args.result_dir, p), np.array(average_score))
    np.save('{}/{}'.format(args.result_dir,'inv-size'), inv_size)
Example No. 2
def run(args):
    if not args.save:
        print("This experiment requires an expid.")
        quit()

    ## Random Seed and Device ##
    torch.manual_seed(args.seed)
    device = load.device(args.gpu)

    ## Data ##
    input_shape, num_classes = load.dimension(args.dataset)
    data_loader = load.dataloader(args.dataset, args.prune_batch_size, True,
                                  args.workers,
                                  args.prune_dataset_ratio * num_classes)

    ## Model, Loss, Optimizer ##
    model = load.model(args.model,
                       args.model_class)(input_shape, num_classes,
                                         args.dense_classifier,
                                         args.pretrained).to(device)
    loss = nn.CrossEntropyLoss()

    ## Compute per Neuron Score ##
    def unit_score_sum(model, scores):
        in_scores = []
        out_scores = []
        for name, module in model.named_modules():
            # # Only plot hidden units between convolutions
            # if isinstance(module, layers.Linear):
            #     W = module.weight
            #     b = module.bias

            #     W_score = scores[id(W)].detach().cpu().numpy()
            #     b_score = scores[id(b)].detach().cpu().numpy()

            #     in_scores.append(W_score.sum(axis=1) + b_score)
            #     out_scores.append(W_score.sum(axis=0))
            if isinstance(module, layers.Conv2d):
                W = module.weight
                W_score = scores[id(W)].detach().cpu().numpy()
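                # Conv2d weights have shape (out_channels, in_channels, kH, kW), so summing
                # over axes (1, 2, 3) gives one score per output unit and summing over
                # axes (0, 2, 3) gives one score per input unit.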
                in_score = W_score.sum(axis=(1, 2, 3))
                out_score = W_score.sum(axis=(0, 2, 3))

                if module.bias is not None:
                    b = module.bias
                    b_score = scores[id(b)].detach().cpu().numpy()
                    in_score += b_score

                in_scores.append(in_score)
                out_scores.append(out_score)

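        # in_scores[i] holds per-output-unit scores of conv layer i and out_scores[i] holds
        # per-input-unit scores of conv layer i; dropping the last and first entries aligns
        # both lists on the hidden units between consecutive convolutions (see the
        # commented-out Linear branch above).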
        in_scores = np.concatenate(in_scores[:-1])
        out_scores = np.concatenate(out_scores[1:])
        return in_scores, out_scores

    ## Loop through Pruners and Save Data ##
    unit_scores = []
    for i, p in enumerate(args.pruner_list):
        pruner = load.pruner(p)(generator.masked_parameters(
            model, args.prune_bias, args.prune_batchnorm, args.prune_residual))
        sparsity = 10**(-float(args.compression))
        prune_loop(model, loss, pruner, data_loader, device, sparsity,
                   args.compression_schedule, args.mask_scope,
                   args.prune_epochs, args.reinitialize)
        unit_score = unit_score_sum(model, pruner.scores)
        unit_scores.append(unit_score)
        np.save('{}/{}'.format(args.result_dir, p), unit_score)
Example No. 3
def run(args):
    if not args.save:
        print("This experiment requires an expid.")
        quit()

    ## Random Seed and Device ##
    torch.manual_seed(args.seed)
    device = load.device(args.gpu)

    ## Data ##
    print('Loading {} dataset.'.format(args.dataset))
    input_shape, num_classes = load.dimension(args.dataset)
    prune_loader = load.dataloader(args.dataset, args.prune_batch_size, True,
                                   args.workers,
                                   args.prune_dataset_ratio * num_classes)
    train_loader = load.dataloader(args.dataset, args.train_batch_size, True,
                                   args.workers)
    test_loader = load.dataloader(args.dataset, args.test_batch_size, False,
                                  args.workers)

    ## Model ##
    print('Creating {} model.'.format(args.model))
    model = load.model(args.model,
                       args.model_class)(input_shape, num_classes,
                                         args.dense_classifier,
                                         args.pretrained).to(device)
    loss = nn.CrossEntropyLoss()
    opt_class, opt_kwargs = load.optimizer(args.optimizer)
    optimizer = opt_class(generator.parameters(model),
                          lr=args.lr,
                          weight_decay=args.weight_decay,
                          **opt_kwargs)
    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer,
                                                     milestones=args.lr_drops,
                                                     gamma=args.lr_drop_rate)

    ## Pre-Train ##
    print('Pre-Train for {} epochs.'.format(args.pre_epochs))
    pre_result = train_eval_loop(model, loss, optimizer, scheduler,
                                 train_loader, test_loader, device,
                                 args.pre_epochs, args.verbose)
    pre_result.to_pickle("{}/pre-train.pkl".format(args.result_dir))

    ## Save Original ##
    torch.save(model.state_dict(), "{}/model.pt".format(args.result_dir))
    torch.save(optimizer.state_dict(),
               "{}/optimizer.pt".format(args.result_dir))
    torch.save(scheduler.state_dict(),
               "{}/scheduler.pt".format(args.result_dir))

    ## Prune and Fine-Tune ##
    for compression in args.compression_list:
        for p, p_epochs in zip(args.pruner_list, args.prune_epoch_list):
            print('{} compression ratio, {} pruner'.format(compression, p))

            # Reset Model, Optimizer, and Scheduler
            model.load_state_dict(
                torch.load("{}/model.pt".format(args.result_dir),
                           map_location=device))
            optimizer.load_state_dict(
                torch.load("{}/optimizer.pt".format(args.result_dir),
                           map_location=device))
            scheduler.load_state_dict(
                torch.load("{}/scheduler.pt".format(args.result_dir),
                           map_location=device))

            # Prune Model
            pruner = load.pruner(p)(generator.masked_parameters(
                model, args.prune_bias, args.prune_batchnorm,
                args.prune_residual))
            sparsity = 10**(-float(compression))
            prune_loop(model, loss, pruner, prune_loader, device, sparsity,
                       args.compression_schedule, args.mask_scope, p_epochs,
                       args.reinitialize)

            # Prune Result
            prune_result = metrics.summary(
                model, pruner.scores, metrics.flop(model, input_shape, device),
                lambda p: generator.prunable(p, args.prune_batchnorm, args.
                                             prune_residual))

            # Train Model
            post_result = train_eval_loop(model, loss, optimizer, scheduler,
                                          train_loader, test_loader, device,
                                          args.post_epochs, args.verbose)

            # Save Data
            post_result.to_pickle("{}/post-train-{}-{}-{}.pkl".format(
                args.result_dir, p, str(compression), p_epochs))
            prune_result.to_pickle("{}/compression-{}-{}-{}.pkl".format(
                args.result_dir, p, str(compression), p_epochs))
Example No. 4
def run(args):
    if not args.save:
        print("This experiment requires an expid.")
        quit()

    ## Random Seed and Device ##
    torch.manual_seed(args.seed)
    device = load.device(args.gpu)

    ## Data ##
    print('Loading {} dataset.'.format(args.dataset))
    input_shape, num_classes = load.dimension(args.dataset) 
    prune_loader = load.dataloader(args.dataset, args.prune_batch_size, True, args.workers, args.prune_dataset_ratio * num_classes)
    train_loader = load.dataloader(args.dataset, args.train_batch_size, True, args.workers)
    test_loader = load.dataloader(args.dataset, args.test_batch_size, False, args.workers)

    ## Model ##
    print('Creating {} model.'.format(args.model))
    model = load.model(args.model, args.model_class)(input_shape, 
                                                     num_classes, 
                                                     args.dense_classifier,
                                                     args.pretrained).to(device)
    loss = nn.CrossEntropyLoss()
    opt_class, opt_kwargs = load.optimizer(args.optimizer)
    optimizer = opt_class(generator.parameters(model), lr=args.lr, weight_decay=args.weight_decay, **opt_kwargs)
    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=args.lr_drops, gamma=args.lr_drop_rate)

    ## Save Original ##
    torch.save(model.state_dict(),"{}/model.pt".format(args.result_dir))
    torch.save(optimizer.state_dict(),"{}/optimizer.pt".format(args.result_dir))
    torch.save(scheduler.state_dict(),"{}/scheduler.pt".format(args.result_dir))

    ## Train-Prune Loop ##
    comp_exponents = np.arange(0, 6, 1) # log 10 of inverse sparsity
    train_iterations = [100] # number of epochs between prune periods
    prune_iterations = np.arange(1, 6, 1) # number of prune periods
    for i, comp_exp in enumerate(comp_exponents[::-1]):
        for j, train_iters in enumerate(train_iterations):
            for k, prune_iters in enumerate(prune_iterations):
                print('{} compression ratio, {} train epochs, {} prune iterations'.format(comp_exp, train_iters, prune_iters))
                
                # Reset Model, Optimizer, and Scheduler
                model.load_state_dict(torch.load("{}/model.pt".format(args.result_dir), map_location=device))
                optimizer.load_state_dict(torch.load("{}/optimizer.pt".format(args.result_dir), map_location=device))
                scheduler.load_state_dict(torch.load("{}/scheduler.pt".format(args.result_dir), map_location=device))
                
                for l in range(prune_iters):

                    # Pre Train Model
                    train_eval_loop(model, loss, optimizer, scheduler, train_loader, 
                                    test_loader, device, train_iters, args.verbose)

                    # Prune Model
                    pruner = load.pruner('mag')(generator.masked_parameters(model, args.prune_bias, args.prune_batchnorm, args.prune_residual))
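                    # Geometric sparsity schedule: after the last of prune_iters rounds the
                    # kept fraction reaches 10**(-comp_exp) (iterative magnitude pruning).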
                    sparsity = (10**(-float(comp_exp)))**((l + 1) / prune_iters)
                    prune_loop(model, loss, pruner, prune_loader, device, sparsity,
                               args.linear_compression_schedule, args.mask_scope, 1, args.reinitialize)

                    # Reset Model's Weights
                    original_dict = torch.load("{}/model.pt".format(args.result_dir), map_location=device)
                    original_weights = dict(filter(lambda v: (v[1].requires_grad == True), original_dict.items()))
                    model_dict = model.state_dict()
                    model_dict.update(original_weights)
                    model.load_state_dict(model_dict)
                    
                    # Reset Optimizer and Scheduler
                    optimizer.load_state_dict(torch.load("{}/optimizer.pt".format(args.result_dir), map_location=device))
                    scheduler.load_state_dict(torch.load("{}/scheduler.pt".format(args.result_dir), map_location=device))

                # Prune Result
                prune_result = metrics.summary(model, 
                                               pruner.scores,
                                               metrics.flop(model, input_shape, device),
                                               lambda p: generator.prunable(p, args.prune_batchnorm, args.prune_residual))
                # Train Model
                post_result = train_eval_loop(model, loss, optimizer, scheduler, train_loader, 
                                              test_loader, device, args.post_epochs, args.verbose)
                
                # Save Data
                prune_result.to_pickle("{}/compression-{}-{}-{}.pkl".format(args.result_dir, str(comp_exp), str(train_iters), str(prune_iters)))
                post_result.to_pickle("{}/performance-{}-{}-{}.pkl".format(args.result_dir, str(comp_exp), str(train_iters), str(prune_iters)))
Example No. 5
def run(args):
    ## Random Seed and Device ##
    torch.manual_seed(args.seed)
    device = load.device(args.gpu)

    ## Data ##
    print('Loading {} dataset.'.format(args.dataset))
    input_shape, num_classes = load.dimension(args.dataset)
    prune_loader = load.dataloader(args.dataset, args.prune_batch_size, True,
                                   args.workers,
                                   args.prune_dataset_ratio * num_classes)
    train_loader = load.dataloader(args.dataset, args.train_batch_size, True,
                                   args.workers)
    test_loader = load.dataloader(args.dataset, args.test_batch_size, False,
                                  args.workers)

    log_filename = '{}/{}'.format(args.result_dir, 'result.log')
    fout = open(log_filename, 'w')
    fout.write('start!\n')

    if args.compression_list == []:
        args.compression_list.append(args.compression)
    if args.pruner_list == []:
        args.pruner_list.append(args.pruner)

    ## Model, Loss, Optimizer ##
    print('Creating {}-{} model.'.format(args.model_class, args.model))
    model = load.model(args.model,
                       args.model_class)(input_shape, num_classes,
                                         args.dense_classifier,
                                         args.pretrained).to(device)
    loss = nn.CrossEntropyLoss()
    opt_class, opt_kwargs = load.optimizer(args.optimizer)
    optimizer = opt_class(generator.parameters(model),
                          lr=args.lr,
                          weight_decay=args.weight_decay,
                          **opt_kwargs)
    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer,
                                                     milestones=args.lr_drops,
                                                     gamma=args.lr_drop_rate)

    ## Pre-Train ##
    print('Pre-Train for {} epochs.'.format(args.pre_epochs))
    pre_result = train_eval_loop(model, loss, optimizer, scheduler,
                                 train_loader, test_loader, device,
                                 args.pre_epochs, args.verbose)

    ## Save Original ##
    torch.save(model.state_dict(),
               "{}/pre_train_model.pt".format(args.result_dir))
    torch.save(optimizer.state_dict(),
               "{}/pre_train_optimizer.pt".format(args.result_dir))
    torch.save(scheduler.state_dict(),
               "{}/pre_train_scheduler.pt".format(args.result_dir))

    for compression in args.compression_list:
        for p in args.pruner_list:
            # Reset Model, Optimizer, and Scheduler
            print('compression ratio: {} ::: pruner: {}'.format(
                compression, p))
            model.load_state_dict(
                torch.load("{}/pre_train_model.pt".format(args.result_dir),
                           map_location=device))
            optimizer.load_state_dict(
                torch.load("{}/pre_train_optimizer.pt".format(args.result_dir),
                           map_location=device))
            scheduler.load_state_dict(
                torch.load("{}/pre_train_scheduler.pt".format(args.result_dir),
                           map_location=device))

            ## Prune ##
            print('Pruning with {} for {} epochs.'.format(
                p, args.prune_epochs))
            pruner = load.pruner(p)(generator.masked_parameters(
                model, args.prune_bias, args.prune_batchnorm,
                args.prune_residual))
            sparsity = 10**(-float(compression))
            prune_loop(model, loss, pruner, prune_loader, device, sparsity,
                       args.compression_schedule, args.mask_scope,
                       args.prune_epochs, args.reinitialize,
                       args.prune_train_mode, args.shuffle, args.invert)

            ## Post-Train ##
            print('Post-Training for {} epochs.'.format(args.post_epochs))
            post_result = train_eval_loop(model, loss, optimizer, scheduler,
                                          train_loader, test_loader, device,
                                          args.post_epochs, args.verbose)

            ## Display Results ##
            frames = [
                pre_result.head(1),
                pre_result.tail(1),
                post_result.head(1),
                post_result.tail(1)
            ]
            train_result = pd.concat(
                frames, keys=['Init.', 'Pre-Prune', 'Post-Prune', 'Final'])
            prune_result = metrics.summary(
                model, pruner.scores, metrics.flop(model, input_shape, device),
                lambda p: generator.prunable(p, args.prune_batchnorm, args.
                                             prune_residual))
            total_params = int(
                (prune_result['sparsity'] * prune_result['size']).sum())
            possible_params = prune_result['size'].sum()
            total_flops = int(
                (prune_result['sparsity'] * prune_result['flops']).sum())
            possible_flops = prune_result['flops'].sum()
            print("Train results:\n", train_result)
            print("Prune results:\n", prune_result)
            print("Parameter Sparsity: {}/{} ({:.4f})".format(
                total_params, possible_params, total_params / possible_params))
            print("FLOP Sparsity: {}/{} ({:.4f})".format(
                total_flops, possible_flops, total_flops / possible_flops))

            ## recording testing time for task 2 ##
            start_time = timeit.default_timer()
            # evaluating the model, including some data gathering overhead
            # eval(model, loss, test_loader, device, args.verbose)
            model.eval()
            with torch.no_grad():
                for data, target in test_loader:
                    data, target = data.to(device), target.to(device)
                    model(data)
                    break
            end_time = timeit.default_timer()
            print("Testing time: {:.4f}".format(end_time - start_time))

            fout.write('compression ratio: {} ::: pruner: {}\n'.format(
                compression, p))
            fout.write('Train results:\n {}\n'.format(train_result))
            fout.write('Prune results:\n {}\n'.format(prune_result))
            fout.write('Parameter Sparsity: {}/{} ({:.4f})\n'.format(
                total_params, possible_params, total_params / possible_params))
            fout.write("FLOP Sparsity: {}/{} ({:.4f})\n".format(
                total_flops, possible_flops, total_flops / possible_flops))
            fout.write("Testing time: {}\n".format(end_time - start_time))
            fout.write("remaining weights: \n{}\n".format(
                (prune_result['sparsity'] * prune_result['size'])))
            fout.write('flop each layer: {}\n'.format(
                (prune_result['sparsity'] *
                 prune_result['flops']).values.tolist()))
            ## Save Results and Model ##
            if args.save:
                print('Saving results.')
                if not os.path.exists('{}/{}'.format(args.result_dir,
                                                     compression)):
                    os.makedirs('{}/{}'.format(args.result_dir, compression))
                # pre_result.to_pickle("{}/{}/pre-train.pkl".format(args.result_dir, compression))
                # post_result.to_pickle("{}/{}/post-train.pkl".format(args.result_dir, compression))
                # prune_result.to_pickle("{}/{}/compression.pkl".format(args.result_dir, compression))
                # torch.save(model.state_dict(), "{}/{}/model.pt".format(args.result_dir, compression))
                # torch.save(optimizer.state_dict(),
                #         "{}/{}/optimizer.pt".format(args.result_dir, compression))
                # torch.save(scheduler.state_dict(),
                #         "{}/{}/scheduler.pt".format(args.result_dir, compression))

    fout.close()
Example No. 6
def run(args):
    print(args)
    ## Random Seed and Device ##
    torch.manual_seed(args.seed)
    device = load.device(args.gpu)

    ## Data ##
    print('Loading {} dataset.'.format(args.dataset))
    input_shape, num_classes = load.dimension(args.dataset)
    prune_loader = load.dataloader(args.dataset,
                                   args.prune_batch_size,
                                   True,
                                   args.workers,
                                   args.prune_dataset_ratio * num_classes,
                                   prune_loader=True)
    train_loader = load.dataloader(args.dataset, args.train_batch_size, True,
                                   args.workers)
    test_loader = load.dataloader(args.dataset, args.test_batch_size, False,
                                  args.workers)

    ## Model, Loss, Optimizer ##
    print('Creating {}-{} model.'.format(args.model_class, args.model))
    model = load.model(args.model,
                       args.model_class)(input_shape, num_classes,
                                         args.dense_classifier,
                                         args.pretrained).to(device)
    loss = nn.CrossEntropyLoss()
    opt_class, opt_kwargs = load.optimizer(args.optimizer)
    optimizer = opt_class(generator.parameters(model),
                          lr=args.lr,
                          weight_decay=args.weight_decay,
                          **opt_kwargs)
    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer,
                                                     milestones=args.lr_drops,
                                                     gamma=args.lr_drop_rate)

    ## Pre-Train ##
    # print('Pre-Train for {} epochs.'.format(args.pre_epochs))
    pre_result = train_eval_loop(model, loss, optimizer, scheduler,
                                 train_loader, test_loader, device, 0,
                                 args.verbose)

    ## Prune ##
    print('Pruning with {} for {} epochs.'.format(args.pruner,
                                                  args.prune_epochs))
    pruner = load.pruner(args.pruner)(generator.masked_parameters(
        model, args.prune_bias, args.prune_batchnorm, args.prune_residual))
    sparsity = 10**(-float(args.compression))
    print("Sparsity: {}".format(sparsity))
    save_pruned_path = args.save_pruned_path + "/%s/%s/%s" % (
        args.model_class,
        args.model,
        args.pruner,
    )
    if (args.save_pruned):
        print("Saving pruned models to: %s" % (save_pruned_path, ))
        if not os.path.exists(save_pruned_path):
            os.makedirs(save_pruned_path)
    prune_loop(model, loss, pruner, prune_loader, device, sparsity,
               args.compression_schedule, args.mask_scope, args.prune_epochs,
               args.reinitialize, args.save_pruned, save_pruned_path)

    save_batch_output_path = args.save_pruned_path + "/%s/%s/%s/output_%s" % (
        args.model_class, args.model, args.pruner,
        (args.dataset + "_" + str(args.seed) + "_" + str(args.compression)))
    save_init_path_kernel_output_path = args.save_pruned_path + "/init-path-kernel-values.csv"
    row_name = f"{args.model}_{args.dataset}_{args.pruner}_{str(args.seed)}_{str(args.compression)}"

    print(save_init_path_kernel_output_path)
    print(row_name)

    if not os.path.exists(save_batch_output_path):
        os.makedirs(save_batch_output_path)
    ## Post-Train ##
    # print('Post-Training for {} epochs.'.format(args.post_epochs))
    post_result = train_eval_loop(
        model, loss, optimizer, scheduler, train_loader, test_loader, device,
        args.post_epochs, args.verbose, args.compute_path_kernel,
        args.track_weight_movement, save_batch_output_path, True,
        save_init_path_kernel_output_path, row_name, args.compute_init_outputs,
        args.compute_init_grads)

    if (args.save_result):
        save_result_path = args.save_pruned_path + "/%s/%s/%s" % (
            args.model_class,
            args.model,
            args.pruner,
        )
        if not os.path.exists(save_result_path):
            os.makedirs(save_result_path)

        print(f"Saving results to {save_result_path}")
        post_result.to_csv(save_result_path + "/%s" %
                           (args.dataset + "_" + str(args.seed) + "_" +
                            str(args.compression) + ".csv"))

    ## Display Results ##
    frames = [pre_result.head(1), post_result.head(1), post_result.tail(1)]
    train_result = pd.concat(frames, keys=['Init.', 'Post-Prune', "Final"])
    prune_result = metrics.summary(
        model, pruner.scores,
        metrics.flop(model, input_shape, device), lambda p: generator.prunable(
            p, args.prune_batchnorm, args.prune_residual))
    total_params = int((prune_result['sparsity'] * prune_result['size']).sum())
    possible_params = prune_result['size'].sum()
    total_flops = int((prune_result['sparsity'] * prune_result['flops']).sum())
    possible_flops = prune_result['flops'].sum()
    print("Train results:\n", train_result)
    print("Prune results:\n", prune_result)
    print("Parameter Sparsity: {}/{} ({:.4f})".format(
        total_params, possible_params, total_params / possible_params))
    print("FLOP Sparsity: {}/{} ({:.4f})".format(total_flops, possible_flops,
                                                 total_flops / possible_flops))

    ## Save Results and Model ##
    if args.save:
        print('Saving results.')
        pre_result.to_pickle("{}/pre-train.pkl".format(args.result_dir))
        post_result.to_pickle("{}/post-train.pkl".format(args.result_dir))
        prune_result.to_pickle("{}/compression.pkl".format(args.result_dir))
        torch.save(model.state_dict(), "{}/model.pt".format(args.result_dir))
        torch.save(optimizer.state_dict(),
                   "{}/optimizer.pt".format(args.result_dir))
        torch.save(pruner.state_dict(), "{}/pruner.pt".format(args.result_dir))
Example No. 7
def run(args):
    if not args.save:
        print("This experiment requires an expid.")
        quit()

    ## Random Seed and Device ##
    torch.manual_seed(args.seed)
    device = load.device(args.gpu, args.seed)

    ## Data ##
    print('Loading {} dataset.'.format(args.dataset))
    input_shape, num_classes = load.dimension(args.dataset)

    if args.validation:
        trainset = 'train'
        evalset = 'val'
    else:
        trainset = 'trainval'
        evalset = 'test'

    prune_loader = load.dataloader(args.dataset,
                                   args.train_batch_size,
                                   trainset,
                                   args.workers,
                                   corrupt_prob=args.prune_corrupt,
                                   seed=args.split_seed)
    train_loader = load.dataloader(args.dataset,
                                   args.train_batch_size,
                                   trainset,
                                   args.workers,
                                   corrupt_prob=args.train_corrupt,
                                   seed=args.split_seed)
    test_loader = load.dataloader(args.dataset,
                                  args.test_batch_size,
                                  evalset,
                                  args.workers,
                                  seed=args.split_seed)

    ## Model ##
    print('Creating {} model.'.format(args.model))
    model = load.model(args.model,
                       args.model_class)(input_shape, num_classes,
                                         args.dense_classifier,
                                         args.pretrained).to(device)
    loss = nn.CrossEntropyLoss()
    opt_class, opt_kwargs = load.optimizer(args.optimizer)
    optimizer = opt_class(generator.parameters(model),
                          lr=args.lr,
                          weight_decay=args.weight_decay,
                          **opt_kwargs)
    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer,
                                                     milestones=args.lr_drops,
                                                     gamma=args.lr_drop_rate)

    ## Save Original ##
    torch.save(model.state_dict(), "{}/model.pt".format(args.result_dir))
    torch.save(optimizer.state_dict(),
               "{}/optimizer.pt".format(args.result_dir))
    torch.save(scheduler.state_dict(),
               "{}/scheduler.pt".format(args.result_dir))

    ## Train-Prune Loop ##
    sparsity = args.sparsity
    try:
        level = args.level_list[0]
    except IndexError:
        raise ValueError("'--level-list' must have size >= 1.")

    print('{} compression ratio, {} train-prune levels'.format(
        sparsity, level))

    # Reset Model, Optimizer, and Scheduler
    model.load_state_dict(
        torch.load("{}/model.pt".format(args.result_dir), map_location=device))
    optimizer.load_state_dict(
        torch.load("{}/optimizer.pt".format(args.result_dir),
                   map_location=device))
    scheduler.load_state_dict(
        torch.load("{}/scheduler.pt".format(args.result_dir),
                   map_location=device))

    for l in range(level):

        # Pre Train Model
        if args.rewind_epochs > 0:
            level_train_result = train_eval_loop_midsave(
                model, loss, optimizer, scheduler, prune_loader, test_loader,
                device, args.pre_epochs, args.verbose, args.rewind_epochs)
        else:
            level_train_result = train_eval_loop(model, loss, optimizer,
                                                 scheduler, prune_loader,
                                                 test_loader, device,
                                                 args.pre_epochs, args.verbose)

        level_train_result.to_pickle(
            f'{args.result_dir}/train-level-{l}-metrics.pkl')

        torch.save(model.state_dict(),
                   "{}/train-level-{}.pt".format(args.result_dir, l))

        # Prune Model
        pruner = load.pruner(args.pruner)(generator.masked_parameters(
            model, args.prune_bias, args.prune_batchnorm, args.prune_residual))
        #sparsity = (10**(-float(compression)))**((l + 1) / level)
        prune_loop(model, loss, pruner, prune_loader, device, sparsity,
                   args.compression_schedule, args.mask_scope,
                   args.prune_epochs, args.reinitialize, args.prune_train_mode,
                   args.shuffle, args.invert)

        # Prune Result
        prune_result = metrics.summary(
            model, pruner.scores, metrics.flop(model, input_shape, device),
            lambda p: generator.prunable(p, args.prune_batchnorm, args.
                                         prune_residual))
        prune_result.to_pickle("{}/sparsity-{}-{}-{}.pkl".format(
            args.result_dir, args.pruner, str(sparsity), str(l + 1)))

        # Reset Model's Weights
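        # If rewinding (rewind_epochs > 0), restore the weights saved partway through
        # pre-training (presumably written by train_eval_loop_midsave); otherwise rewind to
        # the original initialization. Keys containing 'mask' are excluded below so the
        # pruned mask is kept.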
        if args.rewind_epochs > 0:
            original_dict = torch.load("model_pretrain_midway.pt",
                                       map_location=device)
        else:
            original_dict = torch.load("{}/model.pt".format(args.result_dir),
                                       map_location=device)
        original_weights = dict(
            filter(lambda v: 'mask' not in v[0], original_dict.items()))
        model_dict = model.state_dict()
        model_dict.update(original_weights)
        model.load_state_dict(model_dict)

        # Reset Optimizer and Scheduler
        optimizer.load_state_dict(
            torch.load("{}/optimizer.pt".format(args.result_dir),
                       map_location=device))
        scheduler.load_state_dict(
            torch.load("{}/scheduler.pt".format(args.result_dir),
                       map_location=device))

    torch.save(model.state_dict(),
               "{}/post-prune-model.pt".format(args.result_dir))

    ## Compute Path Count ##
    # print("Number of paths", get_path_count(mdl=model, arch=args.model))
    print("Number of active filters",
          number_active_filters(mdl=model, arch=args.model))

    # Train Model
    post_result = train_eval_loop(model, loss, optimizer, scheduler,
                                  train_loader, test_loader, device,
                                  args.post_epochs, args.verbose)

    # Save Data
    post_result.to_pickle("{}/post-train-{}-{}-{}.pkl".format(
        args.result_dir, args.pruner, str(sparsity), str(level)))

    # Save final model
    torch.save(model.state_dict(), f'{args.result_dir}/post-final-train.pt')
    torch.save(optimizer.state_dict(), f'{args.result_dir}/optimizer.pt')
    torch.save(scheduler.state_dict(), f'{args.result_dir}/scheduler.pt')
Example No. 8
def run(args):
    ## Random Seed and Device ##
    torch.manual_seed(args.seed)
    device = load.device(args.gpu, args.seed)

    ## Data ##
    print('Loading {} dataset.'.format(args.dataset))
    input_shape, num_classes = load.dimension(args.dataset)

    if args.validation:
        trainset = 'train'
        evalset = 'val'
    else:
        trainset = 'trainval'
        evalset = 'test'

    prune_loader = load.dataloader(args.dataset, args.prune_batch_size, trainset, args.workers, corrupt_prob=args.prune_corrupt)
    train_loader = load.dataloader(args.dataset, args.train_batch_size, trainset, args.workers, corrupt_prob=args.train_corrupt)
    test_loader = load.dataloader(args.dataset, args.test_batch_size, evalset, args.workers)

    ## Model, Loss, Optimizer ##
    print('Creating {}-{} model.'.format(args.model_class, args.model))
    model = load.model(args.model, args.model_class)(input_shape, 
                                                     num_classes, 
                                                     args.dense_classifier, 
                                                     args.pretrained).to(device)

    loss = nn.CrossEntropyLoss()
    opt_class, opt_kwargs = load.optimizer(args.optimizer)
    optimizer = opt_class(generator.parameters(model), lr=args.lr, weight_decay=args.weight_decay, **opt_kwargs)
    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=args.lr_drops, gamma=args.lr_drop_rate)
    torch.save(model.state_dict(),"{}/init-model.pt".format(args.result_dir))

    ## Pre-Train ##
    assert(args.pre_epochs == 0)
    print('Pre-Train for {} epochs.'.format(args.pre_epochs))
    pre_result = train_eval_loop(model, loss, optimizer, scheduler, train_loader, 
                                 test_loader, device, args.pre_epochs, args.verbose)

    torch.save(model.state_dict(),"{}/pre-trained.pt".format(args.result_dir))

    ## Load in the model ##
    # maskref_dict = torch.load(args.mask_file, map_location=device)
    # model_dict = model.state_dict()

    # mask_dict = dict(filter(lambda v: 'mask' in v[0], maskref_dict.items()))
    # print("Keys being loaded\n", '\n'.join(mask_dict.keys()))
    # model_dict.update(mask_dict)

    model.load_state_dict(torch.load(args.model_file, map_location=device))

    # sanity check part 1
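    # Record (weight sum, mask sum, masked-weight sum) for each masked conv/linear weight so
    # that sanity check part 2 can compare these values after the weight-shuffling step.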
    dict_init = {}
    for name, param in model.state_dict().items():
        print(name)
        if name.endswith('weight') and 'shortcut' not in name and 'bn' not in name:
            mask = model.state_dict()[name + '_mask']
            dict_init[name] = (param.sum().item(), mask.sum().item(), (param * mask).sum().item())

    ## This uses a pruner but only for the purpose of shuffling weights ##
    assert(args.prune_epochs == 0 and args.weightshuffle)
    if args.pruner in ["synflow", "altsynflow", "synflowmag", "rsfgrad"]:
        model.double() # to address exploding/vanishing gradients in SynFlow for deep models
    print('Pruning with {} for {} epochs.'.format(args.pruner, args.prune_epochs))
    pruner = load.pruner(args.pruner)(generator.masked_parameters(model, args.prune_bias, args.prune_batchnorm, args.prune_residual))
    sparsity = args.sparsity # 10**(-float(args.compression))
    prune_loop(model, loss, pruner, prune_loader, device, sparsity, 
               args.compression_schedule, args.mask_scope, args.prune_epochs, args.reinitialize, args.prune_train_mode, args.shuffle, args.invert, args.weightshuffle)
    if args.pruner in ["synflow", "altsynflow", "synflowmag", "rsfgrad"]:
        model.float() # convert back to single precision after pruning
    torch.save(model.state_dict(),"{}/post-prune-model.pt".format(args.result_dir))

    # sanity check part 2
    for name, param in model.state_dict().items():
        print(name)
        if name.endswith('weight') and 'shortcut' not in name and 'bn' not in name:
            mask = model.state_dict()[name + '_mask']
            print(name, dict_init[name], param.sum().item(), mask.sum().item(), (param * mask).sum().item())

    ## Compute Path Count ##
    print("Number of paths", get_path_count(mdl=model, arch=args.model))
    print("Number of active filters", number_active_filters(mdl=model, arch=args.model))

    ## Post-Train ##
    print('Post-Training for {} epochs.'.format(args.post_epochs))
    post_result = train_eval_loop(model, loss, optimizer, scheduler, train_loader, 
                                  test_loader, device, args.post_epochs, args.verbose) 

    ## Display Results ##
    frames = [pre_result.head(1), pre_result.tail(1), post_result.head(1), post_result.tail(1)]
    train_result = pd.concat(frames, keys=['Init.', 'Pre-Prune', 'Post-Prune', 'Final'])
    prune_result = metrics.summary(model, 
                                   pruner.scores,
                                   metrics.flop(model, input_shape, device),
                                   lambda p: generator.prunable(p, args.prune_batchnorm, args.prune_residual))
    total_params = int((prune_result['sparsity'] * prune_result['size']).sum())
    possible_params = prune_result['size'].sum()
    total_flops = int((prune_result['sparsity'] * prune_result['flops']).sum())
    possible_flops = prune_result['flops'].sum()
    print("Train results:\n", train_result)
    print("Prune results:\n", prune_result)
    print("Parameter Sparsity: {}/{} ({:.4f})".format(total_params, possible_params, total_params / possible_params))
    print("FLOP Sparsity: {}/{} ({:.4f})".format(total_flops, possible_flops, total_flops / possible_flops))

    ## Save Results and Model ##
    if args.save:
        print('Saving results.')
        pre_result.to_pickle("{}/pre-train.pkl".format(args.result_dir))
        post_result.to_pickle("{}/post-train.pkl".format(args.result_dir))
        plot.get_plots(post_result, save_path="{}/plots.png".format(args.result_dir))
        prune_result.to_pickle("{}/compression.pkl".format(args.result_dir))
        torch.save(model.state_dict(),"{}/post-train-model.pt".format(args.result_dir))
        torch.save(optimizer.state_dict(),"{}/optimizer.pt".format(args.result_dir))
        torch.save(scheduler.state_dict(),"{}/scheduler.pt".format(args.result_dir))
Example No. 9
def run(args):
    if not args.save:
        print("This experiment requires an expid.")
        quit()

    ## Random Seed and Device ##
    torch.manual_seed(args.seed)
    device = load.device(args.gpu)

    ## Data ##
    print("Loading {} dataset.".format(args.dataset))
    input_shape, num_classes = load.dimension(args.dataset)
    prune_loader = load.dataloader(
        args.dataset,
        args.prune_batch_size,
        True,
        args.workers,
        args.prune_dataset_ratio * num_classes,
    )
    train_loader = load.dataloader(
        args.dataset, args.train_batch_size, True, args.workers
    )
    test_loader = load.dataloader(
        args.dataset, args.test_batch_size, False, args.workers
    )

    ## Model ##
    print("Creating {} model.".format(args.model))
    model = load.model(args.model, args.model_class)(
        input_shape, num_classes, args.dense_classifier, args.pretrained
    ).to(device)
    loss = nn.CrossEntropyLoss()
    opt_class, opt_kwargs = load.optimizer(args.optimizer)
    optimizer = opt_class(
        generator.parameters(model),
        lr=args.lr,
        weight_decay=args.weight_decay,
        **opt_kwargs
    )
    scheduler = torch.optim.lr_scheduler.MultiStepLR(
        optimizer, milestones=args.lr_drops, gamma=args.lr_drop_rate
    )

    ## Save Original ##
    torch.save(model.state_dict(), "{}/model.pt".format(args.result_dir))
    torch.save(optimizer.state_dict(), "{}/optimizer.pt".format(args.result_dir))
    torch.save(scheduler.state_dict(), "{}/scheduler.pt".format(args.result_dir))

    ## Train-Prune Loop ##
    for compression in args.compression_list:
        for level in args.level_list:
            print(
                "{} compression ratio, {} train-prune levels".format(compression, level)
            )

            # Reset Model, Optimizer, and Scheduler
            model.load_state_dict(
                torch.load("{}/model.pt".format(args.result_dir), map_location=device)
            )
            optimizer.load_state_dict(
                torch.load(
                    "{}/optimizer.pt".format(args.result_dir), map_location=device
                )
            )
            scheduler.load_state_dict(
                torch.load(
                    "{}/scheduler.pt".format(args.result_dir), map_location=device
                )
            )

            for l in range(level):

                # Pre Train Model
                train_eval_loop(
                    model,
                    loss,
                    optimizer,
                    scheduler,
                    train_loader,
                    test_loader,
                    device,
                    args.pre_epochs,
                    args.verbose,
                )

                # Prune Model
                pruner = load.pruner(args.pruner)(
                    generator.masked_parameters(
                        model,
                        args.prune_bias,
                        args.prune_batchnorm,
                        args.prune_residual,
                    )
                )
                sparsity = (10 ** (-float(compression))) ** ((l + 1) / level)
                prune_loop(
                    model,
                    loss,
                    pruner,
                    prune_loader,
                    device,
                    sparsity,
                    args.compression_schedule,
                    args.mask_scope,
                    args.prune_epochs,
                    args.reinitialize,
                )

                # Reset Model's Weights
                original_dict = torch.load(
                    "{}/model.pt".format(args.result_dir), map_location=device
                )
                original_weights = dict(
                    filter(
                        lambda v: (v[1].requires_grad == True), original_dict.items()
                    )
                )
                model_dict = model.state_dict()
                model_dict.update(original_weights)
                model.load_state_dict(model_dict)

                # Reset Optimizer and Scheduler
                optimizer.load_state_dict(
                    torch.load(
                        "{}/optimizer.pt".format(args.result_dir), map_location=device
                    )
                )
                scheduler.load_state_dict(
                    torch.load(
                        "{}/scheduler.pt".format(args.result_dir), map_location=device
                    )
                )

            # Prune Result
            prune_result = metrics.summary(
                model,
                pruner.scores,
                metrics.flop(model, input_shape, device),
                lambda p: generator.prunable(
                    p, args.prune_batchnorm, args.prune_residual
                ),
            )
            # Train Model
            post_result = train_eval_loop(
                model,
                loss,
                optimizer,
                scheduler,
                train_loader,
                test_loader,
                device,
                args.post_epochs,
                args.verbose,
            )

            # Save Data
            post_result.to_pickle(
                "{}/post-train-{}-{}-{}.pkl".format(
                    args.result_dir, args.pruner, str(compression), str(level)
                )
            )
            prune_result.to_pickle(
                "{}/compression-{}-{}-{}.pkl".format(
                    args.result_dir, args.pruner, str(compression), str(level)
                )
            )
Example No. 10
    return cutsize

@torch.no_grad()
def apply_mask(parameters):
    for mask, param in parameters:
        param.mul_(mask)

results = []
for style in ['linear', 'exponential']:
    print(style)
    sparsity_ratios = []
    for i, exp in enumerate(exponents):
        max_ratios = []
        for j, epochs in enumerate(iterations):
            model.load_state_dict(torch.load("{}/model.pt".format(directory), map_location=device))
            parameters = list(generator.masked_parameters(model, False, False, False))
            model.eval()
            ratios = []
            for epoch in tqdm(range(epochs)):
                apply_mask(parameters)
                scores, maxflow = score(parameters, model, loss, data_loader, device)
                sparsity = 10**(-float(exp))
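                # 'linear' interpolates the kept fraction toward the target over the epochs;
                # 'exponential' tightens it geometrically; both reach 10**(-exp) at the last epoch.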
                if style == 'linear':
                    sparse = 1.0 - (1.0 - sparsity)*((epoch + 1) / epochs)
                if style == 'exponential':
                    sparse = sparsity**((epoch + 1) / epochs)
                cutsize = mask(parameters, scores, sparse)
                ratios.append(cutsize / maxflow)
            max_ratios.append(max(ratios))
        sparsity_ratios.append(max_ratios)
    results.append(sparsity_ratios)
Example No. 11
def run(args):
    if not args.save:
        print("This experiment requires an expid.")
        quit()

    ## Random Seed and Device ##
    torch.manual_seed(args.seed)
    device = load.device(args.gpu)

    ## Data ##
    print('Loading {} dataset.'.format(args.dataset))
    input_shape, num_classes = load.dimension(args.dataset)
    prune_loader = load.dataloader(args.dataset, args.prune_batch_size, True,
                                   args.workers,
                                   args.prune_dataset_ratio * num_classes)
    train_loader = load.dataloader(args.dataset, args.train_batch_size, True,
                                   args.workers)
    test_loader = load.dataloader(args.dataset, args.test_batch_size, False,
                                  args.workers)

    ## Model ##
    print('Creating {} model.'.format(args.model))
    model = load.model(args.model,
                       args.model_class)(input_shape,
                                         num_classes,
                                         args.dense_classifier,
                                         args.pretrained,
                                         L=args.num_layers,
                                         N=args.layer_width).to(device)
    loss = nn.CrossEntropyLoss()
    opt_class, opt_kwargs = load.optimizer(args.optimizer)
    optimizer = opt_class(generator.parameters(model),
                          lr=args.lr,
                          weight_decay=args.weight_decay,
                          **opt_kwargs)
    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer,
                                                     milestones=args.lr_drops,
                                                     gamma=args.lr_drop_rate)

    ## Pre-Train ##
    print('Pre-Train for {} epochs.'.format(args.pre_epochs))
    pre_result = train_eval_loop(model, loss, optimizer, scheduler,
                                 train_loader, test_loader, device,
                                 args.pre_epochs, args.verbose)

    pre_result.to_pickle("{}/pre-train.pkl".format(args.result_dir))
    torch.save(model.state_dict(), "{}/model.pt".format(args.result_dir))
    torch.save(optimizer.state_dict(),
               "{}/optimizer.pt".format(args.result_dir))
    torch.save(scheduler.state_dict(),
               "{}/scheduler.pt".format(args.result_dir))

    ## Prune and Fine-Tune ##
    pruners = args.pruner_list
    comp_exponents = args.compression_list
    if args.prune_loop == 'rand':
        prune_loop_fun = rand_prune_loop
    elif args.prune_loop == 'approx':
        prune_loop_fun = approx_prune_loop
    else:
        prune_loop_fun = prune_loop

    for j, exp in enumerate(comp_exponents[::-1]):
        for i, p in enumerate(pruners):
            print(p, exp)

            model.load_state_dict(
                torch.load("{}/model.pt".format(args.result_dir),
                           map_location=device))
            optimizer.load_state_dict(
                torch.load("{}/optimizer.pt".format(args.result_dir),
                           map_location=device))
            scheduler.load_state_dict(
                torch.load("{}/scheduler.pt".format(args.result_dir),
                           map_location=device))

            pruner = load.pruner(p)(generator.masked_parameters(
                model, args.prune_bias, args.prune_batchnorm,
                args.prune_residual))
            sparsity = 10**(-float(exp))

            if p == 'rand_weighted':
                prune_loop_fun(model, loss, pruner, prune_loader, device,
                               sparsity, args.linear_compression_schedule,
                               args.mask_scope, args.prune_epochs, args,
                               args.reinitialize)
            elif p == 'sf' or p == 'synflow' or p == 'snip':
                prune_loop_fun(model, loss, pruner, prune_loader, device,
                               sparsity, args.linear_compression_schedule,
                               args.mask_scope, args.prune_epochs,
                               args.reinitialize)
            else:
                prune_loop_fun(model, loss, pruner, prune_loader, device,
                               sparsity, args.linear_compression_schedule,
                               args.mask_scope, 1, args.reinitialize)

            prune_result = metrics.summary(
                model, pruner.scores, metrics.flop(model, input_shape, device),
                lambda p: generator.prunable(p, args.prune_batchnorm, args.
                                             prune_residual))

            post_result = train_eval_loop(model, loss, optimizer, scheduler,
                                          train_loader, test_loader, device,
                                          args.post_epochs, args.verbose)

            post_result.to_pickle("{}/post-train-{}-{}-{}.pkl".format(
                args.result_dir, p, str(exp), args.prune_epochs))
            prune_result.to_pickle("{}/compression-{}-{}-{}.pkl".format(
                args.result_dir, p, str(exp), args.prune_epochs))
Example No. 12
            W = module.weight
            W_score = scores[id(W)].detach().cpu().numpy()
            score_sum = W_score.sum()
            num_elements = np.prod(W.shape)

            if module.bias is not None:
                b = module.bias
                b_score = scores[id(b)].detach().cpu().numpy()
                score_sum += b_score.sum()
                num_elements += np.prod(b.shape)

            average_scores.append(np.abs(score_sum / num_elements))
    return average_scores


names, inv_size = layer_names(model)
average_scores = []
unit_scores = []
for i, p in enumerate(pruners):
    pruner = load.pruner(p)(generator.masked_parameters(
        model, True, False, False))
    prune_loop(model, loss, pruner, data_loader, device, 1.0, False, 'global',
               1)
    average_score = average_layer_score(model, pruner.scores)
    average_scores.append(average_score)
    np.save(
        '{}/{}-{}'.format(directory, p,
                          'pretrained' if pretrained else 'initialization'),
        np.array(average_score))
np.save('{}/{}'.format(directory, 'inv-size'), inv_size)
Example No. 13
def rand_prune_loop(unpruned_model,
                    loss,
                    main_pruner,
                    dataloader,
                    device,
                    sparsity,
                    linear_schedule,
                    scope,
                    epochs,
                    args,
                    reinitialize=False,
                    sample_number=None,
                    epsilon=None,
                    jitter=None):
    r"""Applies score mask loop iteratively to a final sparsity level.
    """
    unpruned_model.eval()
    if sample_number is None:
        sample_number = args.max_samples
    if epsilon is None:
        epsilon = args.epsilon
    if jitter is None:
        jitter = args.jitter

    main_pruner.apply_mask()
    # zero = torch.tensor([0.]).cuda()
    # one = torch.tensor([1.]).cuda()

    last_loss = eval(unpruned_model, loss, dataloader, device, 1,
                     early_stop=5)[0]
    sparsity_graph = [1]
    loss_graph = [last_loss]
    n, N = main_pruner.stats()
    k = ticket_size = sparsity * N

    epoch = -1
    while True:
        epoch += 1
        if linear_schedule:
            # assume final sparsity is ticket size
            if n == k:
                break
            n, _ = main_pruner.stats()
            prune_num = np.log(2 / sample_number) / (np.log(n - k) - np.log(n))
            sparse = (n - prune_num) / N
            sparsity_graph += [sparse]
            if round(sparse * N) == n:
                #parse = (n-1)/N
                break
            n = round(sparse * N)
            #sparse = 1.0 - (1.0 - sparsity)*((epoch + 1) / epochs) # Linear
        else:
            if epoch == epochs:
                break
            sparse = sparsity**((epoch + 1) / epochs)  # Exponential

        num_samples = 0
        best_so_far = []
        best_loss_so_far = float('Inf')
        for _ in range(sample_number):
            num_samples += 1
            model = copy.deepcopy(unpruned_model)
            pruner = load.pruner(main_pruner.name)(generator.masked_parameters(
                model, args.prune_bias, args.prune_batchnorm,
                args.prune_residual))
            pruner.apply_mask()
            pruner.jitter = jitter
            pruner.score(model, loss, dataloader, device)
            pruner.mask(sparse, scope)
            remaining_params, total_params = pruner.stats()
            if remaining_params < total_params * sparse - 5:
                continue
            pruner.apply_mask()

            eval_loss = eval(model, loss, dataloader, device, 0,
                             early_stop=5)[0]
            if (eval_loss / last_loss - 1) < epsilon / epochs:
                last_loss = eval_loss
                loss_graph += [last_loss]
                for i, (mask, p) in enumerate(pruner.masked_parameters):
                    main_pruner.masked_parameters[i][0].copy_(mask)
                main_pruner.apply_mask()
                # param_sampled_count[i] = pruner.scores[id(p)]
                break
            if num_samples == sample_number:
                last_loss = best_loss_so_far
                loss_graph += [last_loss]
                for i, mask in enumerate(best_so_far):
                    main_pruner.masked_parameters[i][0].copy_(mask)
                main_pruner.apply_mask()
                break

            if eval_loss < best_loss_so_far:
                best_loss_so_far = eval_loss
                best_so_far = [mask for mask, p in pruner.masked_parameters]

        # mse = 0
        # for i, (mask, p) in enumerate(pruner.masked_parameters):
        #     mse += ((param_sampled_count[i]/(sample_iteration+1) - pruner.scores[id(p)])**2).mean()
        #     # param_sampled_count[i] += pruner.scores[id(p)]

        # total_mse = (sample_iteration/(sample_iteration+1)) * total_mse + 1/(sample_iteration+1)*mse
        remaining_params, total_params = main_pruner.stats()
        print('sparsity={}, E[n]={}, n={}, num_samples={}, loss={}'.format(
            round(sparse, 3), sparse * total_params, remaining_params,
            num_samples, round(last_loss, 7)))

    # for i, (m, p) in enumerate(main_pruner.masked_parameters):
    #     main_pruner.scores[id(p)] = param_sampled_count[i]

    # main_pruner.mask(sparsity, scope)
    # main_pruner.apply_mask()

    if reinitialize:
        unpruned_model._initialize_weights()  # reinitialize the pruned model itself, not the sampling copy
    # Confirm sparsity level
    remaining_params, total_params = main_pruner.stats()
    if np.abs(remaining_params - total_params * sparsity) >= 1:
        print("ERROR: {} prunable parameters remaining, expected {}".format(
            remaining_params, total_params * sparsity))

    plt.plot(loss_graph)
    return loss_graph
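
For orientation, a rough sketch of how `rand_prune_loop` might be driven. `load`, `generator`, and the `args` fields mirror the other run() snippets in this document and are assumptions, as are `device`, `input_shape`, and `num_classes`:

model = load.model(args.model, args.model_class)(input_shape, num_classes,
                                                 args.dense_classifier,
                                                 args.pretrained).to(device)
loss = nn.CrossEntropyLoss()
prune_loader = load.dataloader(args.dataset, args.prune_batch_size, True,
                               args.workers, args.prune_dataset_ratio * num_classes)
main_pruner = load.pruner(args.pruner)(generator.masked_parameters(
    model, args.prune_bias, args.prune_batchnorm, args.prune_residual))
sparsity = 10**(-float(args.compression))   # fraction of prunable weights kept
loss_graph = rand_prune_loop(model, loss, main_pruner, prune_loader, device,
                             sparsity, linear_schedule=True, scope='global',
                             epochs=args.prune_epochs, args=args)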
Ejemplo n.º 14
0
def run(args):
    ## Random Seed and Device ##
    torch.manual_seed(args.seed)
    device = load.device(args.gpu)

    ## Data ##
    print("Loading {} dataset.".format(args.dataset))
    input_shape, num_classes = load.dimension(args.dataset)
    prune_loader = load.dataloader(
        dataset=args.dataset,
        batch_size=args.prune_batch_size,
        train=True,
        workers=args.workers,
        length=args.prune_dataset_ratio * num_classes,
        datadir=args.data_dir,
    )
    train_loader = load.dataloader(
        dataset=args.dataset,
        batch_size=args.train_batch_size,
        train=True,
        workers=args.workers,
        datadir=args.data_dir,
    )
    test_loader = load.dataloader(
        dataset=args.dataset,
        batch_size=args.test_batch_size,
        train=False,
        workers=args.workers,
        datadir=args.data_dir,
    )

    ## Model, Loss, Optimizer ##
    print("Creating {}-{} model.".format(args.model_class, args.model))
    if args.model in ["fc", "conv"]:
        norm_layer = load.norm_layer(args.norm_layer)
        print(f"Applying {args.norm_layer} normalization: {norm_layer}")
        model = load.model(args.model, args.model_class)(
            input_shape=input_shape,
            num_classes=num_classes,
            dense_classifier=args.dense_classifier,
            pretrained=args.pretrained,
            norm_layer=norm_layer,
        ).to(device)
    else:
        model = load.model(args.model, args.model_class)(
            input_shape=input_shape,
            num_classes=num_classes,
            dense_classifier=args.dense_classifier,
            pretrained=args.pretrained,
        ).to(device)
    loss = nn.CrossEntropyLoss()
    opt_class, opt_kwargs = load.optimizer(args.optimizer)
    optimizer = opt_class(
        generator.parameters(model),
        lr=args.lr,
        weight_decay=args.weight_decay,
        **opt_kwargs,
    )
    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer,
                                                     milestones=args.lr_drops,
                                                     gamma=args.lr_drop_rate)

    ## checkpointing setup ##
    if args.tk_steps_file is not None:
        save_steps = load.save_steps_file(args.tk_steps_file)
        steps_per_epoch = len(train_loader)
        max_epochs = int(save_steps[-1] / steps_per_epoch)
        print(f"Overriding train epochs to last step in file ")
        print(f"    pre_epochs set to 0, post_epochs set to {max_epochs}")
        setattr(args, "pre_epochs", 0)
        setattr(args, "post_epochs", max_epochs)
    else:
        save_steps = None

    ## Pre-Train ##
    print("Pre-Train for {} epochs.".format(args.pre_epochs))
    pre_result = train_eval_loop(
        model,
        loss,
        optimizer,
        scheduler,
        train_loader,
        test_loader,
        device,
        args.pre_epochs,
        args.verbose,
        save_steps=save_steps,
        save_freq=args.save_freq,
        save_path=args.save_path,
    )

    ## Prune ##
    if args.prune_epochs > 0:
        print("Pruning with {} for {} epochs.".format(args.pruner,
                                                      args.prune_epochs))
        pruner = load.pruner(args.pruner)(generator.masked_parameters(
            model, args.prune_bias, args.prune_batchnorm, args.prune_residual))
        sparsity = 10**(-float(args.compression))
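        # e.g. compression=1.0 keeps 10% of prunable weights, compression=2.0 keeps 1%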
        prune_loop(
            model,
            loss,
            pruner,
            prune_loader,
            device,
            sparsity,
            args.compression_schedule,
            args.mask_scope,
            args.prune_epochs,
            args.reinitialize,
        )

    ## Post-Train ##
    print("Post-Training for {} epochs.".format(args.post_epochs))
    post_result = train_eval_loop(
        model,
        loss,
        optimizer,
        scheduler,
        train_loader,
        test_loader,
        device,
        args.post_epochs,
        args.verbose,
        save_steps=save_steps,
        save_freq=args.save_freq,
        save_path=args.save_path,
    )

    ## Display Results ##
    frames = [
        pre_result.head(1),
        pre_result.tail(1),
        post_result.head(1),
        post_result.tail(1),
    ]
    train_result = pd.concat(
        frames, keys=["Init.", "Pre-Prune", "Post-Prune", "Final"])
    print("Train results:\n", train_result)
    if args.prune_epochs > 0:
        prune_result = metrics.summary(
            model,
            pruner.scores,
            metrics.flop(model, input_shape, device),
            lambda p: generator.prunable(p, args.prune_batchnorm,
                                         args.prune_residual),
        )
        total_params = int(
            (prune_result["sparsity"] * prune_result["size"]).sum())
        possible_params = prune_result["size"].sum()
        total_flops = int(
            (prune_result["sparsity"] * prune_result["flops"]).sum())
        possible_flops = prune_result["flops"].sum()

        print("Prune results:\n", prune_result)
        print("Parameter Sparsity: {}/{} ({:.4f})".format(
            total_params, possible_params, total_params / possible_params))
        print("FLOP Sparsity: {}/{} ({:.4f})".format(
            total_flops, possible_flops, total_flops / possible_flops))

    ## Save Results and Model ##
    if args.save:
        print("Saving results.")
        pre_result.to_pickle("{}/pre-train.pkl".format(args.result_dir))
        post_result.to_pickle("{}/post-train.pkl".format(args.result_dir))
        torch.save(model.state_dict(), "{}/model.pt".format(args.result_dir))
        torch.save(optimizer.state_dict(),
                   "{}/optimizer.pt".format(args.result_dir))
        if args.prune_epochs > 0:
            prune_result.to_pickle("{}/compression.pkl".format(
                args.result_dir))
            torch.save(pruner.state_dict(),
                       "{}/pruner.pt".format(args.result_dir))
Ejemplo n.º 15
0
def run(args):
    ## Random Seed and Device ##
    torch.manual_seed(args.seed)
    device = load.device(args.gpu)

    ## Data ##
    print('Loading {} dataset.'.format(args.dataset))
    input_shape, num_classes = load.dimension(args.dataset) 
    prune_loader = load.dataloader(args.dataset, args.prune_batch_size, True, args.workers, args.prune_dataset_ratio * num_classes)
    train_loader = load.dataloader(args.dataset, args.train_batch_size, True, args.workers)
    test_loader = load.dataloader(args.dataset, args.test_batch_size, False, args.workers)

    ## Model, Loss, Optimizer ##
    print('Creating {}-{} model.'.format(args.model_class, args.model))
    model = load.model(args.model, args.model_class)(input_shape, 
                                                     num_classes, 
                                                     args.dense_classifier, 
                                                     args.pretrained).to(device)
    loss = nn.CrossEntropyLoss()
    opt_class, opt_kwargs = load.optimizer(args.optimizer)
    optimizer = opt_class(generator.parameters(model), lr=args.lr, weight_decay=args.weight_decay, **opt_kwargs)
    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=args.lr_drops, gamma=args.lr_drop_rate)


    ## Pre-Train ##
    print('Pre-Train for {} epochs.'.format(args.pre_epochs))
    pre_result = train_eval_loop(model, loss, optimizer, scheduler, train_loader, 
                                 test_loader, device, args.pre_epochs, args.verbose)

    ## Prune ##
    print('Pruning with {} for {} epochs.'.format(args.pruner, args.prune_epochs))
    pruner = load.pruner(args.pruner)(generator.masked_parameters(model, args.prune_bias, args.prune_batchnorm, args.prune_residual))
    sparsity = 10**(-float(args.compression))
    prune_loop(model, loss, pruner, prune_loader, device, sparsity, 
               args.compression_schedule, args.mask_scope, args.prune_epochs, args.reinitialize, args.prune_train_mode)

    
    ## Post-Train ##
    print('Post-Training for {} epochs.'.format(args.post_epochs))
    post_result = train_eval_loop(model, loss, optimizer, scheduler, train_loader, 
                                  test_loader, device, args.post_epochs, args.verbose) 

    ## Display Results ##
    frames = [pre_result.head(1), pre_result.tail(1), post_result.head(1), post_result.tail(1)]
    train_result = pd.concat(frames, keys=['Init.', 'Pre-Prune', 'Post-Prune', 'Final'])
    prune_result = metrics.summary(model, 
                                   pruner.scores,
                                   metrics.flop(model, input_shape, device),
                                   lambda p: generator.prunable(p, args.prune_batchnorm, args.prune_residual))
    total_params = int((prune_result['sparsity'] * prune_result['size']).sum())
    possible_params = prune_result['size'].sum()
    total_flops = int((prune_result['sparsity'] * prune_result['flops']).sum())
    possible_flops = prune_result['flops'].sum()
    print("Train results:\n", train_result)
    print("Prune results:\n", prune_result)
    print("Parameter Sparsity: {}/{} ({:.4f})".format(total_params, possible_params, total_params / possible_params))
    print("FLOP Sparsity: {}/{} ({:.4f})".format(total_flops, possible_flops, total_flops / possible_flops))

    ## Save Results and Model ##
    if args.save:
        print('Saving results.')
        pre_result.to_pickle("{}/pre-train.pkl".format(args.result_dir))
        post_result.to_pickle("{}/post-train.pkl".format(args.result_dir))
        prune_result.to_pickle("{}/compression.pkl".format(args.result_dir))
        torch.save(model.state_dict(),"{}/model.pt".format(args.result_dir))
        torch.save(optimizer.state_dict(),"{}/optimizer.pt".format(args.result_dir))
        torch.save(scheduler.state_dict(),"{}/scheduler.pt".format(args.result_dir))
Ejemplo n.º 16
0
def run(args):
    if not args.save:
        print("This experiment requires an expid.")
        quit()

    ## Random Seed and Device ##
    torch.manual_seed(args.seed)
    device = load.device(args.gpu)

    ## Data ##
    print('Loading {} dataset.'.format(args.dataset))
    input_shape, num_classes = load.dimension(args.dataset)
    prune_loader = load.dataloader(args.dataset, args.prune_batch_size, True,
                                   args.workers,
                                   args.prune_dataset_ratio * num_classes)
    train_loader = load.dataloader(args.dataset, args.train_batch_size, True,
                                   args.workers)
    test_loader = load.dataloader(args.dataset, args.test_batch_size, False,
                                  args.workers)

    loss = nn.CrossEntropyLoss()

    for compression in args.compression_list:
        for level in args.level_list:

            ## Models ##
            #models = []
            rows = []
            for i in range(args.models_num):
                print('Creating {} model ({}/{}).'.format(
                    args.model, str(i + 1), str(args.models_num)))
                model = load.model(args.model, args.model_class)(
                    input_shape, num_classes, args.dense_classifier,
                    args.pretrained).to(device)

                opt_class, opt_kwargs = load.optimizer(args.optimizer)
                optimizer = opt_class(generator.parameters(model),
                                      lr=args.lr,
                                      weight_decay=args.weight_decay,
                                      **opt_kwargs)
                scheduler = torch.optim.lr_scheduler.MultiStepLR(
                    optimizer,
                    milestones=args.lr_drops,
                    gamma=args.lr_drop_rate)

                ## Save Original ##
                torch.save(model.state_dict(),
                           "{}/model.pt".format(args.result_dir))
                torch.save(optimizer.state_dict(),
                           "{}/optimizer.pt".format(args.result_dir))
                torch.save(scheduler.state_dict(),
                           "{}/scheduler.pt".format(args.result_dir))

                ## Train-Prune Loop ##
                print('{} compression ratio, {} train-prune levels'.format(
                    compression, level))

                for l in range(level):
                    # Pre Train Model
                    train_eval_loop(model, loss, optimizer, scheduler,
                                    train_loader, test_loader, device,
                                    args.pre_epochs, args.verbose)

                    # Prune Model
                    pruner = load.pruner(args.pruner)(
                        generator.masked_parameters(model, args.prune_bias,
                                                    args.prune_batchnorm,
                                                    args.prune_residual))
                    if args.pruner == 'synflown':
                        pruner.set_input(train_loader, num_classes, device)
                    sparsity = (10**(-float(compression)))**((l + 1) / level)
                    prune_loop(model, loss, pruner, prune_loader, device,
                               sparsity, args.compression_schedule,
                               args.mask_scope, args.prune_epochs, False,
                               args.prune_train_mode)

                    # Reset Model's Weights
                    original_dict = torch.load("{}/model.pt".format(
                        args.result_dir),
                                               map_location=device)
                    original_weights = dict(
                        filter(lambda v: (v[1].requires_grad == True),
                               original_dict.items()))
                    model_dict = model.state_dict()
                    model_dict.update(original_weights)
                    model.load_state_dict(model_dict)

                    # Reset Optimizer and Scheduler
                    optimizer.load_state_dict(
                        torch.load("{}/optimizer.pt".format(args.result_dir),
                                   map_location=device))
                    scheduler.load_state_dict(
                        torch.load("{}/scheduler.pt".format(args.result_dir),
                                   map_location=device))

                # Prune Result
                prune_result = metrics.summary(
                    model, pruner.scores,
                    metrics.flop(model, input_shape, device),
                    lambda p: generator.prunable(p, args.prune_batchnorm,
                                                 args.prune_residual))
                # Train Model
                post_result = train_eval_loop(model, loss, optimizer,
                                              scheduler, train_loader,
                                              test_loader, device,
                                              args.post_epochs, args.verbose)

                # Save Data
                post_result.to_pickle("{}/post-train-{}-{}-{}-{}.pkl".format(
                    args.result_dir, args.pruner, str(compression), str(level),
                    str(i + 1)))
                prune_result.to_pickle("{}/compression-{}-{}-{}-{}.pkl".format(
                    args.result_dir, args.pruner, str(compression), str(level),
                    str(i + 1)))
                #torch.save(model.state_dict(),"{}/model{}.pt".format(args.result_dir, str(i+1)))
                #models.append(model)
                min_vals = post_result.min().values
                max_vals = post_result.max().values
                row = [min_vals[0], min_vals[1], max_vals[2], max_vals[3]]
                rows.append(row)

            # Evaluate the combined model
            columns = [
                'train_loss', 'test_loss', 'top1_accuracy', 'top5_accuracy'
            ]
            df = pd.DataFrame(rows, columns=columns)
            df.to_pickle("{}/stats-{}-{}-{}-{}.pkl".format(
                args.result_dir, args.pruner, str(compression), str(level),
                str(args.models_num)))

            print('Mean values:')
            print(df.mean())

            print('Standard deviations:')
            print(df.std())

    print('Done!')
Ejemplo n.º 17
0
def run(args):
    if not args.save:
        print("This experiment requires an expid.")
        quit()

    ## Random Seed and Device ##
    torch.manual_seed(args.seed)
    device = load.device(args.gpu)

    ## Data ##
    input_shape, num_classes = load.dimension(args.dataset)
    data_loader = load.dataloader(args.dataset, args.prune_batch_size, True,
                                  args.workers,
                                  args.prune_dataset_ratio * num_classes)

    ## Model, Loss, Optimizer ##
    model = load.model(args.model,
                       args.model_class)(input_shape, num_classes,
                                         args.dense_classifier,
                                         args.pretrained).to(device)
    loss = nn.CrossEntropyLoss()
    torch.save(model.state_dict(), "{}/model.pt".format(args.result_dir))

    def score(parameters, model, loss, dataloader, device):
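        # A SynFlow-style synaptic saliency pass (as written below): parameters are
        # temporarily replaced by their absolute values, an all-ones input is propagated,
        # and the summed output ("maxflow") is backpropagated so each score is |dR/dw * w|.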
        @torch.no_grad()
        def linearize(model):
            signs = {}
            for name, param in model.state_dict().items():
                signs[name] = torch.sign(param)
                param.abs_()
            return signs

        @torch.no_grad()
        def nonlinearize(model, signs):
            for name, param in model.state_dict().items():
                param.mul_(signs[name])

        signs = linearize(model)
        (data, _) = next(iter(dataloader))
        input_dim = list(data[0, :].shape)
        input = torch.ones([1] + input_dim).to(device)
        output = model(input)
        maxflow = torch.sum(output)
        maxflow.backward()
        scores = {}
        for _, p in parameters:
            scores[id(p)] = torch.clone(p.grad * p).detach().abs_()
            p.grad.data.zero_()
        nonlinearize(model, signs)

        return scores, maxflow.item()
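
    # mask() below keeps roughly sparsity * numel of the highest-scoring weights globally
    # and returns the total score mass of the pruned weights ("cutsize"), which run()
    # divides by maxflow to record the fraction of total flow removed at each step.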

    def mask(parameters, scores, sparsity):
        global_scores = torch.cat([torch.flatten(v) for v in scores.values()])
        k = int((1.0 - sparsity) * global_scores.numel())
        cutsize = 0
        if k >= 1:
            cutsize = torch.sum(
                torch.topk(global_scores, k, largest=False).values).item()
            threshold, _ = torch.kthvalue(global_scores, k)
            for mask, param in parameters:
                score = scores[id(param)]
                zero = torch.tensor([0.]).to(mask.device)
                one = torch.tensor([1.]).to(mask.device)
                mask.copy_(torch.where(score <= threshold, zero, one))
        return cutsize

    @torch.no_grad()
    def apply_mask(parameters):
        for mask, param in parameters:
            param.mul_(mask)

    results = []
    for style in ['linear', 'exponential']:
        print(style)
        sparsity_ratios = []
        for i, exp in enumerate(args.compression_list):
            max_ratios = []
            for j, epochs in enumerate(args.prune_epoch_list):
                model.load_state_dict(
                    torch.load("{}/model.pt".format(args.result_dir),
                               map_location=device))
                parameters = list(
                    generator.masked_parameters(model, args.prune_bias,
                                                args.prune_batchnorm,
                                                args.prune_residual))
                model.eval()
                ratios = []
                for epoch in tqdm(range(epochs)):
                    apply_mask(parameters)
                    scores, maxflow = score(parameters, model, loss,
                                            data_loader, device)
                    sparsity = 10**(-float(exp))
                    if style == 'linear':
                        sparse = 1.0 - (1.0 - sparsity) * (
                            (epoch + 1) / epochs)
                    if style == 'exponential':
                        sparse = sparsity**((epoch + 1) / epochs)
                    cutsize = mask(parameters, scores, sparse)
                    ratios.append(cutsize / maxflow)
                max_ratios.append(max(ratios))
            sparsity_ratios.append(max_ratios)
        results.append(sparsity_ratios)
    np.save('{}/ratios'.format(args.result_dir), np.array(results))
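
The linear and exponential schedules explored above reach the same final sparsity but take very different paths. A small standalone check of the two formulas (illustrative values, not results from the experiment):

sparsity, epochs = 0.01, 5   # keep 1% of the weights after 5 pruning epochs
for epoch in range(epochs):
    linear = 1.0 - (1.0 - sparsity) * ((epoch + 1) / epochs)
    exponential = sparsity ** ((epoch + 1) / epochs)
    print('epoch {}: linear={:.3f}, exponential={:.3f}'.format(epoch + 1, linear, exponential))
# linear: 0.802, 0.604, 0.406, 0.208, 0.010  (equal absolute steps)
# exponential: 0.398, 0.158, 0.063, 0.025, 0.010  (a constant fraction removed each epoch)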
Ejemplo n.º 18
0
def run(gpu_id, args):

    ## parameters for multi-processing
    print('using gpu', gpu_id)
    dist.init_process_group(backend='nccl',
                            init_method='env://',
                            world_size=args.world_size,
                            rank=gpu_id)

    ## Random Seed and Device ##
    torch.manual_seed(args.seed)
    # device = load.device(args.gpu)
    device = torch.device(gpu_id)

    args.gpu_id = gpu_id

    ## Data ##
    print('Loading {} dataset.'.format(args.dataset))
    input_shape, num_classes = load.dimension(args.dataset)

    ## split the dataloader workers across the per-GPU processes (rounded up, hard-coded to 4)
    args.workers = int((args.workers + 4 - 1) / 4)
    print('Adjusted dataloader worker count is', args.workers)

    # prune_loader = load.dataloader(args.dataset, args.prune_batch_size, True, args.workers,
    #                 args.prune_dataset_ratio * num_classes, world_size=args.world_size, rank=gpu_id)
    prune_loader, _ = load.dataloader(args.dataset, args.prune_batch_size,
                                      True, args.workers,
                                      args.prune_dataset_ratio * num_classes)

    ## need to divide the training batch size for each GPU
    args.train_batch_size = int(args.train_batch_size / args.gpu_count)
    train_loader, train_sampler = load.dataloader(args.dataset,
                                                  args.train_batch_size,
                                                  True,
                                                  args.workers,
                                                  args=args)
    # args.test_batch_size = int(args.test_batch_size/args.gpu_count)
    test_loader, _ = load.dataloader(args.dataset, args.test_batch_size, False,
                                     args.workers)

    print("data loader batch size (prune::train::test) is {}::{}::{}".format(
        prune_loader.batch_size, train_loader.batch_size,
        test_loader.batch_size))

    log_filename = '{}/{}'.format(args.result_dir, 'result.log')
    fout = open(log_filename, 'w')
    fout.write('start!\n')

    if not args.compression_list:
        args.compression_list.append(args.compression)
    if not args.pruner_list:
        args.pruner_list.append(args.pruner)

    ## Model, Loss, Optimizer ##
    print('Creating {}-{} model.'.format(args.model_class, args.model))

    # load the pre-defined model from the utils
    model = load.model(args.model, args.model_class)(input_shape, num_classes,
                                                     args.dense_classifier,
                                                     args.pretrained)

    ## wrap model with distributed dataparallel module
    torch.cuda.set_device(gpu_id)
    # model = model.to(device)
    model.cuda(gpu_id)
    model = ddp(model, device_ids=[gpu_id])

    ## don't need to move the loss to the GPU as it contains no parameters
    loss = nn.CrossEntropyLoss()
    opt_class, opt_kwargs = load.optimizer(args.optimizer)
    optimizer = opt_class(generator.parameters(model),
                          lr=args.lr,
                          weight_decay=args.weight_decay,
                          **opt_kwargs)
    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer,
                                                     milestones=args.lr_drops,
                                                     gamma=args.lr_drop_rate)

    ## Pre-Train ##
    print('Pre-Train for {} epochs.'.format(args.pre_epochs))
    pre_result = train_eval_loop(model,
                                 loss,
                                 optimizer,
                                 scheduler,
                                 train_loader,
                                 test_loader,
                                 device,
                                 args.pre_epochs,
                                 args.verbose,
                                 train_sampler=train_sampler)
    print('Pre-Train finished!')

    ## Save Original ##
    torch.save(model.state_dict(),
               "{}/pre_train_model_{}.pt".format(args.result_dir, gpu_id))
    torch.save(optimizer.state_dict(),
               "{}/pre_train_optimizer_{}.pt".format(args.result_dir, gpu_id))
    torch.save(scheduler.state_dict(),
               "{}/pre_train_scheduler_{}.pt".format(args.result_dir, gpu_id))

    if not args.unpruned:
        for compression in args.compression_list:
            for p in args.pruner_list:
                # Reset Model, Optimizer, and Scheduler
                print('compression ratio: {} ::: pruner: {}'.format(
                    compression, p))
                model.load_state_dict(
                    torch.load("{}/pre_train_model_{}.pt".format(
                        args.result_dir, gpu_id),
                               map_location=device))
                optimizer.load_state_dict(
                    torch.load("{}/pre_train_optimizer_{}.pt".format(
                        args.result_dir, gpu_id),
                               map_location=device))
                scheduler.load_state_dict(
                    torch.load("{}/pre_train_scheduler_{}.pt".format(
                        args.result_dir, gpu_id),
                               map_location=device))

                ## Prune ##
                print('Pruning with {} for {} epochs.'.format(
                    p, args.prune_epochs))
                pruner = load.pruner(p)(generator.masked_parameters(
                    model, args.prune_bias, args.prune_batchnorm,
                    args.prune_residual))
                sparsity = 10**(-float(compression))
                prune_loop(model, loss, pruner, prune_loader, device, sparsity,
                           args.compression_schedule, args.mask_scope,
                           args.prune_epochs, args.reinitialize,
                           args.prune_train_mode, args.shuffle, args.invert)

                ## Post-Train ##
                print('Post-Training for {} epochs.'.format(args.post_epochs))
                post_train_start_time = timeit.default_timer()
                post_result = train_eval_loop(model,
                                              loss,
                                              optimizer,
                                              scheduler,
                                              train_loader,
                                              test_loader,
                                              device,
                                              args.post_epochs,
                                              args.verbose,
                                              train_sampler=train_sampler)
                post_train_end_time = timeit.default_timer()
                print("Post Training time: {:.4f}s".format(
                    post_train_end_time - post_train_start_time))

                ## Display Results ##
                frames = [
                    pre_result.head(1),
                    pre_result.tail(1),
                    post_result.head(1),
                    post_result.tail(1)
                ]
                train_result = pd.concat(
                    frames, keys=['Init.', 'Pre-Prune', 'Post-Prune', 'Final'])
                prune_result = metrics.summary(
                    model, pruner.scores,
                    metrics.flop(model, input_shape, device),
                    lambda p: generator.prunable(p, args.prune_batchnorm,
                                                 args.prune_residual))
                total_params = int(
                    (prune_result['sparsity'] * prune_result['size']).sum())
                possible_params = prune_result['size'].sum()
                total_flops = int(
                    (prune_result['sparsity'] * prune_result['flops']).sum())
                possible_flops = prune_result['flops'].sum()
                print("Train results:\n", train_result)
                # print("Prune results:\n", prune_result)
                # print("Parameter Sparsity: {}/{} ({:.4f})".format(total_params, possible_params, total_params / possible_params))
                # print("FLOP Sparsity: {}/{} ({:.4f})".format(total_flops, possible_flops, total_flops / possible_flops))

                ## recording testing time for task 2 ##
                # evaluating the model, including some data gathering overhead
                # eval(model, loss, test_loader, device, args.verbose)
                model.eval()
                start_time = timeit.default_timer()
                with torch.no_grad():
                    for data, target in test_loader:
                        data, target = data.to(device), target.to(device)
                        temp_eval_out = model(data)
                end_time = timeit.default_timer()
                print("Testing time: {:.4f}s".format(end_time - start_time))

                fout.write('compression ratio: {} ::: pruner: {}\n'.format(
                    compression, p))
                fout.write('Train results:\n {}\n'.format(train_result))
                fout.write('Prune results:\n {}\n'.format(prune_result))
                fout.write('Parameter Sparsity: {}/{} ({:.4f})\n'.format(
                    total_params, possible_params,
                    total_params / possible_params))
                fout.write("FLOP Sparsity: {}/{} ({:.4f})\n".format(
                    total_flops, possible_flops, total_flops / possible_flops))
                fout.write("Testing time: {}s\n".format(end_time - start_time))
                fout.write("remaining weights: \n{}\n".format(
                    (prune_result['sparsity'] * prune_result['size'])))
                fout.write('flop each layer: {}\n'.format(
                    (prune_result['sparsity'] *
                     prune_result['flops']).values.tolist()))
                ## Save Results and Model ##
                if args.save:
                    print('Saving results.')
                    if not os.path.exists('{}/{}'.format(
                            args.result_dir, compression)):
                        os.makedirs('{}/{}'.format(args.result_dir,
                                                   compression))
                    # pre_result.to_pickle("{}/{}/pre-train.pkl".format(args.result_dir, compression))
                    # post_result.to_pickle("{}/{}/post-train.pkl".format(args.result_dir, compression))
                    # prune_result.to_pickle("{}/{}/compression.pkl".format(args.result_dir, compression))
                    # torch.save(model.state_dict(), "{}/{}/model.pt".format(args.result_dir, compression))
                    # torch.save(optimizer.state_dict(),
                    #         "{}/{}/optimizer.pt".format(args.result_dir, compression))
                    # torch.save(scheduler.state_dict(),
                    #         "{}/{}/scheduler.pt".format(args.result_dir, compression))

    else:
        print('Starting unpruned NN training.')
        print('Training for {} epochs.'.format(args.post_epochs))
        # load the per-GPU checkpoints saved after pre-training (same "_{gpu_id}" suffix)
        model.load_state_dict(
            torch.load("{}/pre_train_model_{}.pt".format(args.result_dir, gpu_id),
                       map_location=device))
        optimizer.load_state_dict(
            torch.load("{}/pre_train_optimizer_{}.pt".format(args.result_dir, gpu_id),
                       map_location=device))
        scheduler.load_state_dict(
            torch.load("{}/pre_train_scheduler_{}.pt".format(args.result_dir, gpu_id),
                       map_location=device))

        train_start_time = timeit.default_timer()
        result = train_eval_loop(model,
                                 loss,
                                 optimizer,
                                 scheduler,
                                 train_loader,
                                 test_loader,
                                 device,
                                 args.post_epochs,
                                 args.verbose,
                                 train_sampler=train_sampler)
        train_end_time = timeit.default_timer()
        frames = [result.head(1), result.tail(1)]
        train_result = pd.concat(frames, keys=['Init.', 'Final'])
        print('Train results:\n', train_result)

    fout.close()

    dist.destroy_process_group()
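
Because `run` initializes the process group with init_method='env://', it has to be spawned once per GPU with MASTER_ADDR and MASTER_PORT set in the environment. A minimal launcher sketch (the helper `main` and its defaults are assumptions; only `run` and the `args` fields come from the snippet):

import os
import torch
import torch.multiprocessing as mp

def main(args):
    os.environ.setdefault('MASTER_ADDR', 'localhost')
    os.environ.setdefault('MASTER_PORT', '29500')
    args.gpu_count = torch.cuda.device_count()
    args.world_size = args.gpu_count
    # one process per GPU; mp.spawn passes the process index as the first argument (gpu_id)
    mp.spawn(run, nprocs=args.gpu_count, args=(args,))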