Example #1
def run(args):
    if not args.save:
        print("This experiment requires an expid.")
        quit()

    ## Random Seed and Device ##
    torch.manual_seed(args.seed)
    device = load.device(args.gpu)

    ## Data ##
    print('Loading {} dataset.'.format(args.dataset))
    input_shape, output_shape = load.dimension(args.dataset) 
    prune_loader = load.dataloader(args.dataset, args.prune_batch_size, True, args.workers)
    train_loader = load.dataloader(args.dataset, args.train_batch_size, True, args.workers)
    test_loader = load.dataloader(args.dataset, args.test_batch_size, False, args.workers)

    ## Model ##
    print('Creating {} model.'.format(args.model))
    class Args:
        def __init__(self):
            self.n_resgroups = 10
            self.n_resblocks = 20
            self.n_feats = 32
            self.reduction = 16
            self.data_train = 'DIV2K'
            self.rgb_range = 255
            self.n_colors = 3
            self.res_scale = 1
    init_args = Args()
    model = load.model(args.model, args.model_class)(init_args).to(device)
    loss = nn.L1Loss()
    opt_class, opt_kwargs = load.optimizer(args.optimizer)
    optimizer = opt_class(generator.parameters(model), lr=args.lr, weight_decay=args.weight_decay, **opt_kwargs)
    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=args.lr_drops, gamma=args.lr_drop_rate)

    ## Save Original ##
    torch.save(model.state_dict(),"{}/model.pt".format(args.result_dir))
    torch.save(optimizer.state_dict(),"{}/optimizer.pt".format(args.result_dir))
    torch.save(scheduler.state_dict(),"{}/scheduler.pt".format(args.result_dir))

    ## Train-Prune Loop ##
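    # Lottery-ticket-style loop: for each compression level, repeatedly
    # pre-train, prune to an increasingly strict sparsity, then rewind the
    # surviving weights (and optimizer/scheduler) to their saved originals.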
    for compression in args.compression_list:
        for level in args.level_list:
            print('{} compression ratio, {} train-prune levels'.format(compression, level))
            
            # Reset Model, Optimizer, and Scheduler
            model.load_state_dict(torch.load("{}/model.pt".format(args.result_dir), map_location=device))
            optimizer.load_state_dict(torch.load("{}/optimizer.pt".format(args.result_dir), map_location=device))
            scheduler.load_state_dict(torch.load("{}/scheduler.pt".format(args.result_dir), map_location=device))
            
            for l in range(level):

                # Pre Train Model
                train_eval_loop(model, loss, optimizer, scheduler, train_loader, 
                                test_loader, device, args.pre_epochs, args.verbose)

                # Prune Model
                pruner = load.pruner(args.pruner)(generator.masked_parameters(model, args.prune_bias, args.prune_batchnorm, args.prune_residual))
                sparsity = (10**(-float(compression)))**((l + 1) / level)
                prune_loop(model, loss, pruner, prune_loader, device, sparsity,
                           args.compression_schedule, args.mask_scope, args.prune_epochs, args.reinitialize, args.prune_train_mode, args.shuffle, args.invert)

                # Reset Model's Weights
                original_dict = torch.load("{}/model.pt".format(args.result_dir), map_location=device)
                original_weights = dict(filter(lambda v: v[1].requires_grad, original_dict.items()))
                model_dict = model.state_dict()
                model_dict.update(original_weights)
                model.load_state_dict(model_dict)
                
                # Reset Optimizer and Scheduler
                optimizer.load_state_dict(torch.load("{}/optimizer.pt".format(args.result_dir), map_location=device))
                scheduler.load_state_dict(torch.load("{}/scheduler.pt".format(args.result_dir), map_location=device))

            # Prune Result
            prune_result = metrics.summary(model, 
                                           pruner.scores,
                                           metrics.flop(model, input_shape, device),
                                           lambda p: generator.prunable(p, args.prune_batchnorm, args.prune_residual))
            # Train Model
            post_result = train_eval_loop(model, loss, optimizer, scheduler, train_loader, 
                                          test_loader, device, args.post_epochs, args.verbose)
            
            # Save Data
            post_result.to_pickle("{}/post-train-{}-{}-{}.pkl".format(args.result_dir, args.pruner, str(compression),  str(level)))
            prune_result.to_pickle("{}/compression-{}-{}-{}.pkl".format(args.result_dir, args.pruner, str(compression), str(level)))
Example #2
def run(args):
    ## Random Seed and Device ##
    torch.manual_seed(args.seed)
    device = load.device(args.gpu)

    ## Data ##
    print('Loading {} dataset.'.format(args.dataset))
    input_shape, output_shape = load.dimension(args.dataset)
    prune_loader = load.dataloader(args.dataset, args.prune_batch_size, True,
                                   args.workers)
    train_loader = load.dataloader(args.dataset, args.train_batch_size, True,
                                   args.workers)
    test_loader = load.dataloader(args.dataset, args.test_batch_size, False,
                                  args.workers)

    ## Model, Loss, Optimizer ##
    print('Creating {}-{} model.'.format(args.model_class, args.model))

    class Args:
        def __init__(self):
            self.n_resgroups = 10
            self.n_resblocks = 20
            self.n_feats = 32
            self.reduction = 16
            self.data_train = 'DIV2K'
            self.rgb_range = 255
            self.n_colors = 3
            self.res_scale = 1

    init_args = Args()
    model = load.model(args.model, args.model_class)(init_args).to(device)

    loss = nn.L1Loss()
    opt_class, opt_kwargs = load.optimizer(args.optimizer)
    optimizer = opt_class(generator.parameters(model),
                          lr=args.lr,
                          weight_decay=args.weight_decay,
                          **opt_kwargs)
    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer,
                                                     milestones=args.lr_drops,
                                                     gamma=args.lr_drop_rate)

    ## Pre-Train ##
    print('Pre-Train for {} epochs.'.format(args.pre_epochs))
    pre_result = train_eval_loop(model, loss, optimizer, scheduler,
                                 train_loader, test_loader, device,
                                 args.pre_epochs, args.verbose)

    ## Prune ##
    print('Pruning with {} for {} epochs.'.format(args.pruner,
                                                  args.prune_epochs))
    pruner = load.pruner(args.pruner)(generator.masked_parameters(
        model, args.prune_bias, args.prune_batchnorm, args.prune_residual))
    sparsity = 10**(-float(args.compression))
    prune_loop(model, loss, pruner, prune_loader, device, sparsity,
               args.compression_schedule, args.mask_scope, args.prune_epochs,
               args.reinitialize, args.prune_train_mode, args.shuffle,
               args.invert)

    ## Post-Train ##
    print('Post-Training for {} epochs.'.format(args.post_epochs))
    post_result = train_eval_loop(model, loss, optimizer, scheduler,
                                  train_loader, test_loader, device,
                                  args.post_epochs, args.verbose)

    ## Display Results ##
    frames = [
        pre_result.head(1),
        pre_result.tail(1),
        post_result.head(1),
        post_result.tail(1)
    ]
    train_result = pd.concat(
        frames, keys=['Init.', 'Pre-Prune', 'Post-Prune', 'Final'])
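    # Summarize per-layer pruning results; the 'sparsity' column holds each
    # layer's remaining fraction, so the printed "Parameter Sparsity" and
    # "FLOP Sparsity" ratios below are remaining/total counts.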
    prune_result = metrics.summary(
        model, pruner.scores,
        metrics.flop(model, input_shape, device), lambda p: generator.prunable(
            p, args.prune_batchnorm, args.prune_residual))
    total_params = int((prune_result['sparsity'] * prune_result['size']).sum())
    possible_params = prune_result['size'].sum()
    total_flops = int((prune_result['sparsity'] * prune_result['flops']).sum())
    possible_flops = prune_result['flops'].sum()
    print("Train results:\n", train_result)
    print("Prune results:\n", prune_result)
    print("Parameter Sparsity: {}/{} ({:.4f})".format(
        total_params, possible_params, total_params / possible_params))
    print("FLOP Sparsity: {}/{} ({:.4f})".format(total_flops, possible_flops,
                                                 total_flops / possible_flops))

    ## Save Results and Model ##
    if args.save:
        print('Saving results.')
        pre_result.to_pickle("{}/pre-train.pkl".format(args.result_dir))
        post_result.to_pickle("{}/post-train.pkl".format(args.result_dir))
        prune_result.to_pickle("{}/compression.pkl".format(args.result_dir))
        torch.save(model.state_dict(), "{}/model.pt".format(args.result_dir))
        torch.save(optimizer.state_dict(),
                   "{}/optimizer.pt".format(args.result_dir))
        torch.save(scheduler.state_dict(),
                   "{}/scheduler.pt".format(args.result_dir))
def run(args):
    if not args.save:
        print("This experiment requires an expid.")
        quit()

    ## Random Seed and Device ##
    torch.manual_seed(args.seed)
    device = load.device(args.gpu)

    ## Data ##
    input_shape, num_classes = load.dimension(args.dataset)
    data_loader = load.dataloader(
        args.dataset,
        args.prune_batch_size,
        True,
        args.workers,
        args.prune_dataset_ratio * num_classes,
    )

    ## Model, Loss, Optimizer ##
    model = load.model(args.model,
                       args.model_class)(input_shape, num_classes,
                                         args.dense_classifier,
                                         args.pretrained).to(device)
    loss = nn.CrossEntropyLoss()
    torch.save(model.state_dict(), "{}/model.pt".format(args.result_dir))

    def score(parameters, model, loss, dataloader, device):
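        # SynFlow-style saliency: take |w| for every parameter (linearize),
        # push an all-ones input through the network, backprop the summed
        # output ("maxflow"), and score each parameter by |grad * w|; the
        # original signs are restored afterwards (nonlinearize).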
        @torch.no_grad()
        def linearize(model):
            signs = {}
            for name, param in model.state_dict().items():
                signs[name] = torch.sign(param)
                param.abs_()
            return signs

        @torch.no_grad()
        def nonlinearize(model, signs):
            for name, param in model.state_dict().items():
                param.mul_(signs[name])

        signs = linearize(model)
        (data, _) = next(iter(dataloader))
        input_dim = list(data[0, :].shape)
        input = torch.ones([1] + input_dim).to(device)
        output = model(input)
        maxflow = torch.sum(output)
        maxflow.backward()
        scores = {}
        for _, p in parameters:
            scores[id(p)] = torch.clone(p.grad * p).detach().abs_()
            p.grad.data.zero_()
        nonlinearize(model, signs)

        return scores, maxflow.item()

    def mask(parameters, scores, sparsity):
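        # Globally rank scores and prune the k lowest-scoring parameters;
        # cutsize is the total score mass removed, later compared against
        # maxflow as a ratio.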
        global_scores = torch.cat([torch.flatten(v) for v in scores.values()])
        k = int((1.0 - sparsity) * global_scores.numel())
        cutsize = 0
        if not k < 1:
            cutsize = torch.sum(
                torch.topk(global_scores, k, largest=False).values).item()
            threshold, _ = torch.kthvalue(global_scores, k)
            for mask, param in parameters:
                score = scores[id(param)]
                zero = torch.tensor([0.0]).to(mask.device)
                one = torch.tensor([1.0]).to(mask.device)
                mask.copy_(torch.where(score <= threshold, zero, one))
        return cutsize

    @torch.no_grad()
    def apply_mask(parameters):
        for mask, param in parameters:
            param.mul_(mask)

    results = []
    for style in ["linear", "exponential"]:
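        # Compare linear vs. exponential sparsity schedules: for each
        # compression level and prune-epoch budget, record the largest
        # per-step cutsize/maxflow ratio.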
        print(style)
        sparsity_ratios = []
        for i, exp in enumerate(args.compression_list):
            max_ratios = []
            for j, epochs in enumerate(args.prune_epoch_list):
                model.load_state_dict(
                    torch.load("{}/model.pt".format(args.result_dir),
                               map_location=device))
                parameters = list(
                    generator.masked_parameters(
                        model,
                        args.prune_bias,
                        args.prune_batchnorm,
                        args.prune_residual,
                    ))
                model.eval()
                ratios = []
                for epoch in tqdm(range(epochs)):
                    apply_mask(parameters)
                    scores, maxflow = score(parameters, model, loss,
                                            data_loader, device)
                    sparsity = 10**(-float(exp))
                    if style == "linear":
                        sparse = 1.0 - (1.0 - sparsity) * (
                            (epoch + 1) / epochs)
                    if style == "exponential":
                        sparse = sparsity**((epoch + 1) / epochs)
                    cutsize = mask(parameters, scores, sparse)
                    ratios.append(cutsize / maxflow)
                max_ratios.append(max(ratios))
            sparsity_ratios.append(max_ratios)
        results.append(sparsity_ratios)
    np.save("{}/ratios".format(args.result_dir), np.array(results))
def run(args):
    if not args.save:
        print("This experiment requires an expid.")
        quit()

    ## Random Seed and Device ##
    torch.manual_seed(args.seed)
    device = load.device(args.gpu, args.seed)

    ## Data ##
    input_shape, num_classes = load.dimension(args.dataset)
    train_loader = load.dataloader(args.dataset, args.prune_batch_size, True,
                                   args.workers,
                                   args.prune_dataset_ratio * num_classes)

    ## Model, Loss, Optimizer ##
    model = load.model(args.model,
                       args.model_class)(input_shape, num_classes,
                                         args.dense_classifier,
                                         args.pretrained).to(device)
    opt_class, opt_kwargs = load.optimizer(args.optimizer)
    optimizer = opt_class(generator.parameters(model),
                          lr=args.lr,
                          weight_decay=args.weight_decay,
                          **opt_kwargs)
    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer,
                                                     milestones=args.lr_drops,
                                                     gamma=args.lr_drop_rate)
    loss = nn.CrossEntropyLoss()

    ## Compute Layer Name and Inv Size ##
    def layer_names(model):
        names = []
        inv_size = []
        for name, module in model.named_modules():
            if isinstance(module, (layers.Linear, layers.Conv2d)):
                num_elements = np.prod(module.weight.shape)
                if module.bias is not None:
                    num_elements += np.prod(module.bias.shape)
                names.append(name)
                inv_size.append(1.0 / num_elements)
        return names, inv_size

    ## Compute Average Mag Score ##
    def average_mag_score(model):
        average_scores = []
        for module in model.modules():
            if isinstance(module, (layers.Linear, layers.Conv2d)):
                W = module.weight.detach().cpu().numpy()
                W_score = W**2
                score_sum = W_score.sum()
                num_elements = np.prod(W.shape)

                if module.bias is not None:
                    b = module.bias.detach().cpu().numpy()
                    b_score = b**2
                    score_sum += b_score.sum()
                    num_elements += np.prod(b.shape)

                average_scores.append(np.abs(score_sum / num_elements))
        return average_scores

    ## Train and Save Data ##
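    # Record each layer's mean squared weight (and bias) magnitude before
    # every training epoch, alongside 1/parameter-count per layer.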
    _, inv_size = layer_names(model)
    Wscore = []
    for epoch in tqdm(range(args.post_epochs)):
        Wscore.append(average_mag_score(model))
        train(model, loss, optimizer, train_loader, device, epoch,
              args.verbose)
        scheduler.step()

    np.save('{}/{}'.format(args.result_dir, 'inv-size'), inv_size)
    np.save('{}/score'.format(args.result_dir), np.array(Wscore))
Example #5
def run(args):
    if not args.save:
        print("This experiment requires an expid.")
        quit()

    ## Random Seed and Device ##
    torch.manual_seed(args.seed)
    device = load.device(args.gpu)

    ## Data ##
    print('Loading {} dataset.'.format(args.dataset))
    input_shape, num_classes = load.dimension(args.dataset)
    prune_loader = load.dataloader(args.dataset, args.prune_batch_size, True,
                                   args.workers,
                                   args.prune_dataset_ratio * num_classes)
    train_loader = load.dataloader(args.dataset, args.train_batch_size, True,
                                   args.workers)
    test_loader = load.dataloader(args.dataset, args.test_batch_size, False,
                                  args.workers)

    ## Model ##
    print('Creating {} model.'.format(args.model))
    model = load.model(args.model,
                       args.model_class)(input_shape, num_classes,
                                         args.dense_classifier,
                                         args.pretrained).to(device)
    loss = nn.CrossEntropyLoss()
    opt_class, opt_kwargs = load.optimizer(args.optimizer)
    optimizer = opt_class(generator.parameters(model),
                          lr=args.lr,
                          weight_decay=args.weight_decay,
                          **opt_kwargs)
    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer,
                                                     milestones=args.lr_drops,
                                                     gamma=args.lr_drop_rate)

    ## Save Original ##
    torch.save(model.state_dict(), "{}/model.pt".format(args.result_dir))
    torch.save(optimizer.state_dict(),
               "{}/optimizer.pt".format(args.result_dir))
    torch.save(scheduler.state_dict(),
               "{}/scheduler.pt".format(args.result_dir))

    ## Train-Prune Loop ##
    for compression in args.compression_list:
        for level in args.level_list:
            print('{} compression ratio, {} train-prune levels'.format(
                compression, level))

            # Reset Model, Optimizer, and Scheduler
            original_dict = torch.load("{}/model.pt".format(args.result_dir),
                                       map_location=device)
            model.load_state_dict(original_dict)
            optimizer.load_state_dict(
                torch.load("{}/optimizer.pt".format(args.result_dir),
                           map_location=device))
            scheduler.load_state_dict(
                torch.load("{}/scheduler.pt".format(args.result_dir),
                           map_location=device))

            original_dict = torch.load("{}/model.pt".format(args.result_dir),
                                       map_location=device)
            original_weights = dict(
                filter(lambda v: v[1].requires_grad,
                       original_dict.items()))

            for l in range(level):

                # Pre Train Model
                train_eval_loop(model, loss, optimizer, scheduler,
                                train_loader, test_loader, device,
                                args.pre_epochs, args.verbose)

                # Prune Model
                pruner = load.pruner(args.pruner)(generator.masked_parameters(
                    model, args.prune_bias, args.prune_batchnorm,
                    args.prune_residual))
                if args.pruner == 'synflown':
                    pruner.set_input(train_loader, num_classes, device)
                sparsity = 1 - (10**(-float(compression)))**(
                    (l + 1) / level)  # Note the (1 - old_sparsity) here
                prune_loop(model, loss, pruner, prune_loader, device, sparsity,
                           args.compression_schedule, args.mask_scope,
                           args.prune_epochs, False, args.prune_train_mode)

                # Reset Model's Weights
                if args.reinitialize:
                    model_dict = model.state_dict()
                    model_dict.update(original_weights)
                    model.load_state_dict(model_dict)

                # Reinitialize masked parameters
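                # (the freshly built dummy model below presumably supplies
                # new random weights for the pruned entries)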
                dummy_model = load.model(args.model, args.model_class)(
                    input_shape, num_classes, args.dense_classifier,
                    False).to(device)
                pruner.reinitialize_masked_parameters(
                    generator.masked_parameters(dummy_model, args.prune_bias,
                                                args.prune_batchnorm,
                                                args.prune_residual))
                pruner.alpha_mask(1.0)
                del dummy_model

                # Save new model's weights again
                if args.reinitialize:
                    torch.save(model.state_dict(),
                               "{}/temp.pt".format(args.result_dir))
                    original_dict = torch.load("{}/temp.pt".format(
                        args.result_dir),
                                               map_location=device)
                    original_weights = dict(
                        filter(lambda v: v[1].requires_grad,
                               original_dict.items()))

                # Reset Optimizer and Scheduler
                optimizer.load_state_dict(
                    torch.load("{}/optimizer.pt".format(args.result_dir),
                               map_location=device))
                scheduler.load_state_dict(
                    torch.load("{}/scheduler.pt".format(args.result_dir),
                               map_location=device))

            # Prune Result
            prune_result = metrics.summary(
                model, pruner.scores, metrics.flop(model, input_shape, device),
                lambda p: generator.prunable(p, args.prune_batchnorm, args.
                                             prune_residual))
            # Train Model
            post_result = train_eval_loop(model, loss, optimizer, scheduler,
                                          train_loader, test_loader, device,
                                          args.post_epochs, args.verbose)

            # Save Data
            post_result.to_pickle("{}/post-train-{}-{}-{}.pkl".format(
                args.result_dir, args.pruner, str(compression), str(level)))
            prune_result.to_pickle("{}/compression-{}-{}-{}.pkl".format(
                args.result_dir, args.pruner, str(compression), str(level)))
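Example #6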
def run(args):
    if not args.save:
        print("This experiment requires an expid.")
        quit()

    ## Random Seed and Device ##
    torch.manual_seed(args.seed)
    device = load.device(args.gpu)

    ## Data ##
    input_shape, num_classes = load.dimension(args.dataset)
    data_loader = load.dataloader(args.dataset, args.prune_batch_size, True,
                                  args.workers,
                                  args.prune_dataset_ratio * num_classes)

    ## Model, Loss, Optimizer ##
    model = load.model(args.model,
                       args.model_class)(input_shape, num_classes,
                                         args.dense_classifier,
                                         args.pretrained).to(device)
    loss = nn.CrossEntropyLoss()

    ## Compute per Neuron Score ##
    def unit_score_sum(model, scores):
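        # Sum pruning scores per conv channel: axes (1, 2, 3) give each
        # unit's incoming score, axes (0, 2, 3) its outgoing score; the last
        # layer's in-scores and the first layer's out-scores are dropped so
        # only hidden units are kept.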
        in_scores = []
        out_scores = []
        for name, module in model.named_modules():
            # # Only plot hidden units between convolutions
            # if isinstance(module, layers.Linear):
            #     W = module.weight
            #     b = module.bias

            #     W_score = scores[id(W)].detach().cpu().numpy()
            #     b_score = scores[id(b)].detach().cpu().numpy()

            #     in_scores.append(W_score.sum(axis=1) + b_score)
            #     out_scores.append(W_score.sum(axis=0))
            if isinstance(module, layers.Conv2d):
                W = module.weight
                W_score = scores[id(W)].detach().cpu().numpy()
                in_score = W_score.sum(axis=(1, 2, 3))
                out_score = W_score.sum(axis=(0, 2, 3))

                if module.bias is not None:
                    b = module.bias
                    b_score = scores[id(b)].detach().cpu().numpy()
                    in_score += b_score

                in_scores.append(in_score)
                out_scores.append(out_score)

        in_scores = np.concatenate(in_scores[:-1])
        out_scores = np.concatenate(out_scores[1:])
        return in_scores, out_scores

    ## Loop through Pruners and Save Data ##
    unit_scores = []
    for i, p in enumerate(args.pruner_list):
        pruner = load.pruner(p)(generator.masked_parameters(
            model, args.prune_bias, args.prune_batchnorm, args.prune_residual))
        sparsity = 10**(-float(args.compression))
        prune_loop(model, loss, pruner, data_loader, device, sparsity,
                   args.compression_schedule, args.mask_scope,
                   args.prune_epochs, args.reinitialize)
        unit_score = unit_score_sum(model, pruner.scores)
        unit_scores.append(unit_score)
        np.save('{}/{}'.format(args.result_dir, p), unit_score)
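Example #7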
def run(args):
    if not args.save:
        print("This experiment requires an expid.")
        quit()

    ## Random Seed and Device ##
    torch.manual_seed(args.seed)
    device = load.device(args.gpu)

    ## Data ##
    input_shape, num_classes = load.dimension(args.dataset)
    data_loader = load.dataloader(
        args.dataset,
        args.prune_batch_size,
        True,
        args.workers,
        args.prune_dataset_ratio * num_classes,
    )

    ## Model, Loss, Optimizer ##
    model = load.model(args.model, args.model_class)(
        input_shape, num_classes, args.dense_classifier, args.pretrained
    ).to(device)
    loss = nn.CrossEntropyLoss()

    ## Compute Layer Name and Inv Size ##
    def layer_names(model):
        names = []
        inv_size = []
        for name, module in model.named_modules():
            if isinstance(module, (layers.Linear, layers.Conv2d)):
                num_elements = np.prod(module.weight.shape)
                if module.bias is not None:
                    num_elements += np.prod(module.bias.shape)
                names.append(name)
                inv_size.append(1.0 / num_elements)
        return names, inv_size

    ## Compute Average Layer Score ##
    def average_layer_score(model, scores):
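        # Mean pruning score per layer, averaging over weight and (if
        # present) bias elements.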
        average_scores = []
        for name, module in model.named_modules():
            if isinstance(module, (layers.Linear, layers.Conv2d)):
                W = module.weight
                W_score = scores[id(W)].detach().cpu().numpy()
                score_sum = W_score.sum()
                num_elements = np.prod(W.shape)

                if module.bias is not None:
                    b = module.bias
                    b_score = scores[id(b)].detach().cpu().numpy()
                    score_sum += b_score.sum()
                    num_elements += np.prod(b.shape)

                average_scores.append(np.abs(score_sum / num_elements))
        return average_scores

    ## Loop through Pruners and Save Data ##
    names, inv_size = layer_names(model)
    average_scores = []
    unit_scores = []
    for i, p in enumerate(args.pruner_list):
        pruner = load.pruner(p)(
            generator.masked_parameters(
                model, args.prune_bias, args.prune_batchnorm, args.prune_residual
            )
        )
        sparsity = 10 ** (-float(args.compression))
        prune_loop(
            model,
            loss,
            pruner,
            data_loader,
            device,
            sparsity,
            args.compression_schedule,
            args.mask_scope,
            args.prune_epochs,
            args.reinitialize,
        )
        average_score = average_layer_score(model, pruner.scores)
        average_scores.append(average_score)
        np.save("{}/{}".format(args.result_dir, p), np.array(average_score))
    np.save("{}/{}".format(args.result_dir, "inv-size"), inv_size)
Example #8
def main():
    global args, best_prec1
    args = parser.parse_args()
    args.sparse_lvl = 0.8**args.sparse_iter
    print(args.sparse_lvl)

    # Check the save_dir exists or not
    if not os.path.exists(args.save_dir):
        os.makedirs(args.save_dir)

    torch.manual_seed(args.seed)

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            best_prec1 = checkpoint['best_prec1']
            model.load_state_dict(checkpoint['state_dict'])
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.evaluate, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    cudnn.benchmark = True

    if args.dataset == 'cifar10':
        # num_classes = 10
        # print('Loading {} dataset.'.format(args.dataset))
        # normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
        #                                  std=[0.229, 0.224, 0.225])
        # train_dataset = datasets.CIFAR10(root='./data', train=True, transform=transforms.Compose([
        #         transforms.RandomHorizontalFlip(),
        #         transforms.RandomCrop(32, 4),
        #         transforms.ToTensor(),
        #         normalize,
        #     ]), download=True)

        # train_loader = torch.utils.data.DataLoader(
        #     train_dataset,
        #     batch_size=args.batch_size, shuffle=True,
        #     num_workers=args.workers, pin_memory=True)

        # val_loader = torch.utils.data.DataLoader(
        #     datasets.CIFAR10(root='./data', train=False, transform=transforms.Compose([
        #         transforms.ToTensor(),
        #         normalize,
        #     ])),
        #     batch_size=128, shuffle=False,
        #     num_workers=args.workers, pin_memory=True)
        print('Loading {} dataset.'.format(args.dataset))
        input_shape, num_classes = load.dimension(args.dataset)
        train_dataset, train_loader = load.dataloader(args.dataset,
                                                      args.batch_size, True,
                                                      args.workers)
        _, val_loader = load.dataloader(args.dataset, 128, False, args.workers)

    elif args.dataset == 'tiny-imagenet':
        args.batch_size = 256
        args.lr = 0.2
        args.epochs = 200
        print('Loading {} dataset.'.format(args.dataset))
        input_shape, num_classes = load.dimension(args.dataset)
        train_dataset, train_loader = load.dataloader(args.dataset,
                                                      args.batch_size, True,
                                                      args.workers)
        _, val_loader = load.dataloader(args.dataset, 128, False, args.workers)

    elif args.dataset == 'cifar100':
        args.batch_size = 128
        # args.lr = 0.01
        args.epochs = 160
        # args.weight_decay = 5e-4
        input_shape, num_classes = load.dimension(args.dataset)
        train_dataset, train_loader = load.dataloader(args.dataset,
                                                      args.batch_size, True,
                                                      args.workers)
        _, val_loader = load.dataloader(args.dataset, 128, False, args.workers)

    if args.arch == 'resnet20':
        print('Creating {} model.'.format(args.arch))
        model = torch.nn.DataParallel(resnet.__dict__[args.arch](
            ONI=args.ONI, T_iter=args.T_iter))
        model.cuda()
    elif args.arch == 'resnet18':
        print('Creating {} model.'.format(args.arch))
        # Using resnet18 from Synflow
        # model = load.model(args.arch, 'tinyimagenet')(input_shape,
        #                                              num_classes,
        #                                              dense_classifier = True).cuda()
        # Using resnet18 from torchvision
        model = models.resnet18()
        model.fc = nn.Linear(512, num_classes)
        model.cuda()
        utils.kaiming_initialize(model)
    elif args.arch == 'resnet110' or args.arch == 'resnet110full':
        # Using resnet110 from Apollo
        # model = apolo_resnet.ResNet(110, num_classes=num_classes)
        model = load.model(args.arch, 'lottery')(input_shape,
                                                 num_classes,
                                                 dense_classifier=True).cuda()
    elif args.arch in [
            'vgg16full', 'vgg16full-bn', 'vgg11full', 'vgg11full-bn'
    ]:
        if args.dataset == 'tiny-imagenet':
            modeltype = 'tinyimagenet'
        else:
            modeltype = 'lottery'
        # Using resnet110 from Apollo
        # model = apolo_resnet.ResNet(110, num_classes=num_classes)
        model = load.model(args.arch, modeltype)(input_shape,
                                                 num_classes,
                                                 dense_classifier=True).cuda()

    # for layer in model.modules():
    #     if isinstance(layer, nn.Linear):
    #         init.orthogonal_(layer.weight.data)
    #     elif isinstance(layer, (nn.Conv2d, nn.ConvTranspose2d)):
    #         special_init.DeltaOrthogonal_init(layer.weight.data)

    print('Number of parameters of model: {}.'.format(count_parameters(model)))
    # define loss function (criterion) and optimizer
    criterion = nn.CrossEntropyLoss().cuda()

    if args.compute_sv:
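        # Track singular-value statistics from utils.get_sv through training:
        # the full spectra plus max and percentile summaries, appended and
        # re-saved at each save interval.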
        print('[*] Will compute singular values throught training.')
        size_hook = utils.get_hook(model,
                                   (nn.Linear, nn.Conv2d, nn.ConvTranspose2d))
        utils.run_once(train_loader, model)
        utils.detach_hook([size_hook])
        training_sv = []
        # training_sv_avg = []
        # training_sv_std = []
        training_svmax = []
        training_sv20 = []  # 20% singular value
        training_sv50 = []  # 50% singular value
        training_sv80 = []  # 80% singular value
        # training_kclip12 = [] # singular values larger than 1e-12
        training_sv50p = []  # 50% non-zero singular value
        training_sv80p = []  # 80% non-zero singular value
        # training_kavg = [] # max condition number/average condition number
        sv, svmax, sv20, sv50, sv80, sv50p, sv80p = utils.get_sv(
            model, size_hook)
        training_sv.append(sv)
        # training_sv_avg.append(sv_avg)
        # training_sv_std.append(sv_std)
        training_svmax.append(svmax)
        training_sv20.append(sv20)
        training_sv50.append(sv50)
        training_sv80.append(sv80)
        # training_kclip12.append(kclip12)
        training_sv50p.append(sv50p)
        training_sv80p.append(sv80p)
        # training_kavg.append(kavg)

    if args.half:
        model.half()
        criterion.half()

    optimizer = torch.optim.SGD(model.parameters(),
                                args.lr,
                                momentum=args.momentum,
                                nesterov=True,
                                weight_decay=args.weight_decay)
    if args.dataset == 'tiny-imagenet':
        lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
            optimizer, milestones=[100, 150], last_epoch=args.start_epoch - 1)
        # milestones=[30, 60, 80], last_epoch=args.start_epoch - 1)
    elif args.dataset == 'cifar100':
        lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
            optimizer,
            milestones=[60, 120],
            gamma=0.2,
            last_epoch=args.start_epoch - 1)
    else:
        lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
            optimizer, milestones=[80, 120], last_epoch=args.start_epoch - 1)

    # if args.arch in ['resnet1202', 'resnet110']:
    #     # for resnet1202 original paper uses lr=0.01 for first 400 minibatches for warm-up
    #     # then switch back. In this setup it will correspond for first epoch.
    #     for param_group in optimizer.param_groups:
    #         param_group['lr'] = args.lr*0.1
    for epoch in range(args.pre_epochs):

        # train for one epoch
        print('current lr {:.5e}'.format(optimizer.param_groups[0]['lr']))
        train(train_loader, model, criterion, optimizer, epoch)
        lr_scheduler.step()

        # evaluate on validation set
        prec1 = validate(val_loader, model, criterion)

        # remember best prec@1 and save checkpoint
        is_best = prec1 > best_prec1
        best_prec1 = max(prec1, best_prec1)

        if epoch > 0 and epoch % args.save_every == 0:
            save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'state_dict': model.state_dict(),
                    'best_prec1': best_prec1,
                },
                is_best,
                filename=os.path.join(args.save_dir, 'checkpoint.th'))

        # save_checkpoint({
        #     'state_dict': model.state_dict(),
        #     'best_prec1': best_prec1,
        # }, is_best, filename=os.path.join(args.save_dir, 'model.th'))

        if args.compute_sv and epoch % args.save_every == 0:
            sv, svmax, sv20, sv50, sv80, sv50p, sv80p = utils.get_sv(
                model, size_hook)
            training_sv.append(sv)
            # training_sv_avg.append(sv_avg)
            # training_sv_std.append(sv_std)
            training_svmax.append(svmax)
            training_sv20.append(sv20)
            training_sv50.append(sv50)
            training_sv80.append(sv80)
            # training_kclip12.append(kclip12)
            training_sv50p.append(sv50p)
            training_sv80p.append(sv80p)
            # training_kavg.append(kavg)
            np.save(os.path.join(args.save_dir, 'sv.npy'), training_sv)
            # np.save(os.path.join(args.save_dir, 'sv_avg.npy'), training_sv_avg)
            # np.save(os.path.join(args.save_dir, 'sv_std.npy'), training_sv_std)
            np.save(os.path.join(args.save_dir, 'sv_svmax.npy'),
                    training_svmax)
            np.save(os.path.join(args.save_dir, 'sv_sv20.npy'), training_sv20)
            np.save(os.path.join(args.save_dir, 'sv_sv50.npy'), training_sv50)
            np.save(os.path.join(args.save_dir, 'sv_sv80.npy'), training_sv80)
            # np.save(os.path.join(args.save_dir, 'sv_kclip12.npy'), training_kclip12)
            np.save(os.path.join(args.save_dir, 'sv_sv50p.npy'),
                    training_sv50p)
            np.save(os.path.join(args.save_dir, 'sv_sv80p.npy'),
                    training_sv80p)
            # np.save(os.path.join(args.save_dir, 'sv_kavg.npy'), training_kavg)

    print('[*] {} pre-training epochs done'.format(args.pre_epochs))

    # checkpoint = torch.load('preprune.th')
    # model.load_state_dict(checkpoint['state_dict'])

    if args.prune_method != 'NONE':
        nets = [model]
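        # Dispatch to the selected pruning method; each branch below masks
        # the model in place before the post-prune training loop.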
        # snip_loader = torch.utils.data.DataLoader(
        # train_dataset,
        # batch_size=num_classes, shuffle=False,
        # num_workers=args.workers, pin_memory=True, sampler=sampler.BalancedBatchSampler(train_dataset))

        if args.prune_method == 'SNIP':
            # for layer in model.modules():
            #     snip.add_mask_ones(layer)
            # # svfp.svip_reinit(model)
            # # if args.compute_sv:
            # #     sv, sv_avg, sv_std = utils.get_sv(model, size_hook)
            # #     training_sv.append(sv)
            # #     training_sv_avg.append(sv_avg)
            # #     training_sv_std.append(sv_std)
            # utils.save_sparsity(model, args.save_dir)
            # save_checkpoint({
            #     'state_dict': model.state_dict(),
            #     'best_prec1': 0,
            # }, 0, filename=os.path.join(args.save_dir, 'preprune.th'))
            if args.adv:
                snip.apply_advsnip(args,
                                   nets,
                                   train_loader,
                                   criterion,
                                   num_classes=num_classes)
            else:
                snip.apply_snip(args,
                                nets,
                                train_loader,
                                criterion,
                                num_classes=num_classes)
            # snip.apply_snip(args, nets, snip_loader, criterion)
        elif args.prune_method == 'SNIPRES':
            snipres.apply_snipres(args,
                                  nets,
                                  train_loader,
                                  criterion,
                                  input_shape,
                                  num_classes=num_classes)
        elif args.prune_method == 'TEST':
            # checkpoint = torch.load('preprune.th')
            # model.load_state_dict(checkpoint['state_dict'])
            # svip.apply_svip(args, nets)
            # svfp.apply_svip(args, nets)
            # given_sparsity = np.load('saved_sparsity.npy')
            # svfp.apply_svip_givensparsity(args, nets, given_sparsity)
            num_apply_layer = utils.get_apply_layer(model)
            given_sparsity = np.ones(num_apply_layer, )
            # given_sparsity[-4] = args.s_value
            # given_sparsity[-5] = args.s_value
            for layer in args.layer:
                given_sparsity[layer] = args.s_value
            print(given_sparsity)
            # given_sparsity = np.load(args.s_name+'.npy')
            # snip.apply_rand_prune_givensparsity(nets, given_sparsity)
            ftprune.apply_specprune(nets, given_sparsity)
        elif args.prune_method == 'RAND':
            # utils.save_sparsity(model, args.save_dir)
            # checkpoint = torch.load('preprune.th')
            # model.load_state_dict(checkpoint['state_dict'])
            # snip.apply_rand_prune(nets, args.sparse_lvl)
            # given_sparsity = np.load(args.save_dir+'/saved_sparsity.npy')
            num_apply_layer = utils.get_apply_layer(model)
            given_sparsity = np.ones(num_apply_layer, )
            # given_sparsity[-4] = args.s_value
            # given_sparsity[-5] = args.s_value
            for layer in args.layer:
                given_sparsity[layer] = args.s_value
            print(given_sparsity)
            # given_sparsity = np.load(args.s_name+'.npy')
            snip.apply_rand_prune_givensparsity(nets, given_sparsity)
            # svfp.apply_rand_prune_givensparsity_var(nets, given_sparsity, args.reduce_ratio, args.structured, args)
            # svfp.svip_reinit_givenlayer(nets, args.layer)
        elif args.prune_method == 'Zen':
            zenprune.apply_zenprune(args, nets, train_loader)
            # zenprune.apply_nsprune(args, nets, train_loader, num_classes=num_classes)
            # zenprune.apply_SAP(args, nets, train_loader, criterion, num_classes=num_classes)
        elif args.prune_method == 'OrthoSNIP':
            snip.apply_orthosnip(args,
                                 nets,
                                 train_loader,
                                 criterion,
                                 num_classes=num_classes)
        elif args.prune_method == 'Delta':
            snip.apply_prune_active(nets)
        elif args.prune_method == 'FTP':
            ftprune.apply_fpt(args,
                              nets,
                              train_loader,
                              num_classes=num_classes)
        elif args.prune_method == 'NTK':
            _, ntk_loader = load.dataloader(args.dataset, 10, True,
                                            args.workers)
            if args.adv:
                ntkprune.ntk_prune_adv(args,
                                       nets,
                                       ntk_loader,
                                       num_classes=num_classes)
            else:
                # ntkprune.ntk_prune(args, nets, ntk_loader, num_classes=num_classes)
                ntkprune.ntk_ep_prune(args,
                                      nets,
                                      ntk_loader,
                                      num_classes=num_classes)

        # utils.save_sparsity(model, args.save_dir)
        if args.compute_sv:
            if args.rescale:
                ###################################
                _, svmax, _, _, _, _, _ = utils.get_sv(model, size_hook)
                applied_layer = 0
                for net in nets:
                    for layer in net.modules():
                        if isinstance(
                                layer,
                            (nn.Linear, nn.Conv2d, nn.ConvTranspose2d)):
                            if given_sparsity[applied_layer] != 1:
                                # add_mask_rand_basedonchannel(layer, sparsity[applied_layer], ratio, True, structured)
                                with torch.no_grad():
                                    layer.weight /= svmax[applied_layer]
                            applied_layer += 1
                ###################################

            sv, svmax, sv20, sv50, sv80, sv50p, sv80p = utils.get_sv(
                model, size_hook)
            training_sv.append(sv)
            # training_sv_avg.append(sv_avg)
            # training_sv_std.append(sv_std)
            training_svmax.append(svmax)
            training_sv20.append(sv20)
            training_sv50.append(sv50)
            training_sv80.append(sv80)
            # training_kclip12.append(kclip12)
            training_sv50p.append(sv50p)
            training_sv80p.append(sv80p)
            # training_kavg.append(kavg)
            np.save(os.path.join(args.save_dir, 'sv.npy'), training_sv)
            # np.save(os.path.join(args.save_dir, 'sv_avg.npy'), training_sv_avg)
            # np.save(os.path.join(args.save_dir, 'sv_std.npy'), training_sv_std)
            np.save(os.path.join(args.save_dir, 'sv_svmax.npy'),
                    training_svmax)
            np.save(os.path.join(args.save_dir, 'sv_sv20.npy'), training_sv20)
            np.save(os.path.join(args.save_dir, 'sv_sv50.npy'), training_sv50)
            np.save(os.path.join(args.save_dir, 'sv_sv80.npy'), training_sv80)
            # np.save(os.path.join(args.save_dir, 'sv_kclip12.npy'), training_kclip12)
            np.save(os.path.join(args.save_dir, 'sv_sv50p.npy'),
                    training_sv50p)
            np.save(os.path.join(args.save_dir, 'sv_sv80p.npy'),
                    training_sv80p)
            # np.save(os.path.join(args.save_dir, 'sv_kavg.npy'), training_kavg)

        print('[*] Sparsity after pruning: ', utils.check_sparsity(model))

    if args.evaluate:
        validate(val_loader, model, criterion)
        return

    # reinitialize
    optimizer = torch.optim.SGD(model.parameters(),
                                args.lr,
                                momentum=args.momentum,
                                nesterov=True,
                                weight_decay=args.weight_decay)
    if args.dataset == 'tiny-imagenet':
        lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
            optimizer, milestones=[100, 150], last_epoch=args.start_epoch - 1)
        # milestones=[30, 60, 80], last_epoch=args.start_epoch - 1)
    elif args.dataset == 'cifar100':
        lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
            optimizer,
            milestones=[60, 120],
            gamma=0.2,
            last_epoch=args.start_epoch - 1)
    else:
        lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
            optimizer, milestones=[80, 120], last_epoch=args.start_epoch - 1)
    for epoch in range(args.start_epoch, args.epochs):
        # for epoch in range(args.pre_epochs, args.epochs):

        # train for one epoch
        print('current lr {:.5e}'.format(optimizer.param_groups[0]['lr']))
        train(train_loader,
              model,
              criterion,
              optimizer,
              epoch,
              ortho=args.ortho)
        lr_scheduler.step()

        # evaluate on validation set
        prec1 = validate(val_loader, model, criterion)

        # remember best prec@1 and save checkpoint
        is_best = prec1 > best_prec1
        best_prec1 = max(prec1, best_prec1)

        if epoch > 0 and epoch % args.save_every == 0:
            save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'state_dict': model.state_dict(),
                    'best_prec1': best_prec1,
                },
                is_best,
                filename=os.path.join(args.save_dir, 'checkpoint.th'))
            # if args.prune_method !='NONE':
            #     print(utils.check_layer_sparsity(model))

        # save_checkpoint({
        #     'state_dict': model.state_dict(),
        #     'best_prec1': best_prec1,
        # }, is_best, filename=os.path.join(args.save_dir, 'model.th'))

        if args.compute_sv and epoch % args.save_every == 0:
            sv, svmax, sv20, sv50, sv80, sv50p, sv80p = utils.get_sv(
                model, size_hook)
            training_sv.append(sv)
            # training_sv_avg.append(sv_avg)
            # training_sv_std.append(sv_std)
            training_svmax.append(svmax)
            training_sv20.append(sv20)
            training_sv50.append(sv50)
            training_sv80.append(sv80)
            # training_kclip12.append(kclip12)
            training_sv50p.append(sv50p)
            training_sv80p.append(sv80p)
            # training_kavg.append(kavg)
            np.save(os.path.join(args.save_dir, 'sv.npy'), training_sv)
            # np.save(os.path.join(args.save_dir, 'sv_avg.npy'), training_sv_avg)
            # np.save(os.path.join(args.save_dir, 'sv_std.npy'), training_sv_std)
            np.save(os.path.join(args.save_dir, 'sv_svmax.npy'),
                    training_svmax)
            np.save(os.path.join(args.save_dir, 'sv_sv20.npy'), training_sv20)
            np.save(os.path.join(args.save_dir, 'sv_sv50.npy'), training_sv50)
            np.save(os.path.join(args.save_dir, 'sv_sv80.npy'), training_sv80)
            # np.save(os.path.join(args.save_dir, 'sv_kclip12.npy'), training_kclip12)
            np.save(os.path.join(args.save_dir, 'sv_sv50p.npy'),
                    training_sv50p)
            np.save(os.path.join(args.save_dir, 'sv_sv80p.npy'),
                    training_sv80p)
Example #9
def run(args):
    print(args)
    ## Random Seed and Device ##
    torch.manual_seed(args.seed)
    device = load.device(args.gpu)

    ## Data ##
    print('Loading {} dataset.'.format(args.dataset))
    input_shape, num_classes = load.dimension(args.dataset)
    prune_loader = load.dataloader(args.dataset,
                                   args.prune_batch_size,
                                   True,
                                   args.workers,
                                   args.prune_dataset_ratio * num_classes,
                                   prune_loader=True)
    train_loader = load.dataloader(args.dataset, args.train_batch_size, True,
                                   args.workers)
    test_loader = load.dataloader(args.dataset, args.test_batch_size, False,
                                  args.workers)

    ## Model, Loss, Optimizer ##
    print('Creating {}-{} model.'.format(args.model_class, args.model))
    model = load.model(args.model,
                       args.model_class)(input_shape, num_classes,
                                         args.dense_classifier,
                                         args.pretrained).to(device)
    loss = nn.CrossEntropyLoss()
    opt_class, opt_kwargs = load.optimizer(args.optimizer)
    optimizer = opt_class(generator.parameters(model),
                          lr=args.lr,
                          weight_decay=args.weight_decay,
                          **opt_kwargs)
    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer,
                                                     milestones=args.lr_drops,
                                                     gamma=args.lr_drop_rate)

    ## Pre-Train ##
    # print('Pre-Train for {} epochs.'.format(args.pre_epochs))
    pre_result = train_eval_loop(model, loss, optimizer, scheduler,
                                 train_loader, test_loader, device, 0,
                                 args.verbose)

    ## Prune ##
    print('Pruning with {} for {} epochs.'.format(args.pruner,
                                                  args.prune_epochs))
    pruner = load.pruner(args.pruner)(generator.masked_parameters(
        model, args.prune_bias, args.prune_batchnorm, args.prune_residual))
    sparsity = 10**(-float(args.compression))
    print("Sparsity: {}".format(sparsity))
    save_pruned_path = args.save_pruned_path + "/%s/%s/%s" % (
        args.model_class,
        args.model,
        args.pruner,
    )
    if args.save_pruned:
        print("Saving pruned models to: %s" % (save_pruned_path, ))
        if not os.path.exists(save_pruned_path):
            os.makedirs(save_pruned_path)
    prune_loop(model, loss, pruner, prune_loader, device, sparsity,
               args.compression_schedule, args.mask_scope, args.prune_epochs,
               args.reinitialize, args.save_pruned, save_pruned_path)

    save_batch_output_path = args.save_pruned_path + "/%s/%s/%s/output_%s" % (
        args.model_class, args.model, args.pruner,
        (args.dataset + "_" + str(args.seed) + "_" + str(args.compression)))
    save_init_path_kernel_output_path = args.save_pruned_path + "/init-path-kernel-values.csv"
    row_name = f"{args.model}_{args.dataset}_{args.pruner}_{str(args.seed)}_{str(args.compression)}"

    print(save_init_path_kernel_output_path)
    print(row_name)

    if not os.path.exists(save_batch_output_path):
        os.makedirs(save_batch_output_path)
    ## Post-Train ##
    # print('Post-Training for {} epochs.'.format(args.post_epochs))
    post_result = train_eval_loop(
        model, loss, optimizer, scheduler, train_loader, test_loader, device,
        args.post_epochs, args.verbose, args.compute_path_kernel,
        args.track_weight_movement, save_batch_output_path, True,
        save_init_path_kernel_output_path, row_name, args.compute_init_outputs,
        args.compute_init_grads)

    if args.save_result:
        save_result_path = args.save_pruned_path + "/%s/%s/%s" % (
            args.model_class,
            args.model,
            args.pruner,
        )
        if not os.path.exists(save_result_path):
            os.makedirs(save_result_path)

        print(f"Saving results to {save_result_path}")
        post_result.to_csv(save_result_path + "/%s" %
                           (args.dataset + "_" + str(args.seed) + "_" +
                            str(args.compression) + ".csv"))

    ## Display Results ##
    frames = [pre_result.head(1), post_result.head(1), post_result.tail(1)]
    train_result = pd.concat(frames, keys=['Init.', 'Post-Prune', "Final"])
    prune_result = metrics.summary(
        model, pruner.scores,
        metrics.flop(model, input_shape, device), lambda p: generator.prunable(
            p, args.prune_batchnorm, args.prune_residual))
    total_params = int((prune_result['sparsity'] * prune_result['size']).sum())
    possible_params = prune_result['size'].sum()
    total_flops = int((prune_result['sparsity'] * prune_result['flops']).sum())
    possible_flops = prune_result['flops'].sum()
    print("Train results:\n", train_result)
    print("Prune results:\n", prune_result)
    print("Parameter Sparsity: {}/{} ({:.4f})".format(
        total_params, possible_params, total_params / possible_params))
    print("FLOP Sparsity: {}/{} ({:.4f})".format(total_flops, possible_flops,
                                                 total_flops / possible_flops))

    ## Save Results and Model ##
    if args.save:
        print('Saving results.')
        pre_result.to_pickle("{}/pre-train.pkl".format(args.result_dir))
        post_result.to_pickle("{}/post-train.pkl".format(args.result_dir))
        prune_result.to_pickle("{}/compression.pkl".format(args.result_dir))
        torch.save(model.state_dict(), "{}/model.pt".format(args.result_dir))
        torch.save(optimizer.state_dict(),
                   "{}/optimizer.pt".format(args.result_dir))
        torch.save(pruner.state_dict(), "{}/pruner.pt".format(args.result_dir))
Exemple #10
def run(gpu_id, args):

    ## parameters for multi-processing
    print('using gpu', gpu_id)
    dist.init_process_group(backend='nccl',
                            init_method='env://',
                            world_size=args.world_size,
                            rank=gpu_id)

    ## Random Seed and Device ##
    torch.manual_seed(args.seed)
    # device = load.device(args.gpu)
    device = torch.device(gpu_id)

    args.gpu_id = gpu_id

    ## Data ##
    print('Loading {} dataset.'.format(args.dataset))
    input_shape, num_classes = load.dimension(args.dataset)

    ## reduce the dataloader workers per process (ceil division by a hard-coded 4)
    args.workers = int((args.workers + 4 - 1) / 4)
    print('Adjusted dataloader worker number is', args.workers)

    # prune_loader = load.dataloader(args.dataset, args.prune_batch_size, True, args.workers,
    #                 args.prune_dataset_ratio * num_classes, world_size=args.world_size, rank=gpu_id)
    prune_loader, _ = load.dataloader(args.dataset, args.prune_batch_size,
                                      True, args.workers,
                                      args.prune_dataset_ratio * num_classes)

    ## need to divide the training batch size for each GPU
    args.train_batch_size = int(args.train_batch_size / args.gpu_count)
    train_loader, train_sampler = load.dataloader(args.dataset,
                                                  args.train_batch_size,
                                                  True,
                                                  args.workers,
                                                  args=args)
    # args.test_batch_size = int(args.test_batch_size/args.gpu_count)
    test_loader, _ = load.dataloader(args.dataset, args.test_batch_size, False,
                                     args.workers)

    print("data loader batch size (prune::train::test) is {}::{}::{}".format(
        prune_loader.batch_size, train_loader.batch_size,
        test_loader.batch_size))

    log_filename = '{}/{}'.format(args.result_dir, 'result.log')
    fout = open(log_filename, 'w')
    fout.write('start!\n')

    if args.compression_list == []:
        args.compression_list.append(args.compression)
    if args.pruner_list == []:
        args.pruner_list.append(args.pruner)

    ## Model, Loss, Optimizer ##
    print('Creating {}-{} model.'.format(args.model_class, args.model))

    # load the pre-defined model from the utils
    model = load.model(args.model, args.model_class)(input_shape, num_classes,
                                                     args.dense_classifier,
                                                     args.pretrained)

    ## wrap model with distributed dataparallel module
    torch.cuda.set_device(gpu_id)
    # model = model.to(device)
    model.cuda(gpu_id)
    model = ddp(model, device_ids=[gpu_id])

    ## don't need to move the loss to the GPU as it contains no parameters
    loss = nn.CrossEntropyLoss()
    opt_class, opt_kwargs = load.optimizer(args.optimizer)
    optimizer = opt_class(generator.parameters(model),
                          lr=args.lr,
                          weight_decay=args.weight_decay,
                          **opt_kwargs)
    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer,
                                                     milestones=args.lr_drops,
                                                     gamma=args.lr_drop_rate)

    ## Pre-Train ##
    print('Pre-Train for {} epochs.'.format(args.pre_epochs))
    pre_result = train_eval_loop(model,
                                 loss,
                                 optimizer,
                                 scheduler,
                                 train_loader,
                                 test_loader,
                                 device,
                                 args.pre_epochs,
                                 args.verbose,
                                 train_sampler=train_sampler)
    print('Pre-Train finished!')

    ## Save Original ##
    torch.save(model.state_dict(),
               "{}/pre_train_model_{}.pt".format(args.result_dir, gpu_id))
    torch.save(optimizer.state_dict(),
               "{}/pre_train_optimizer_{}.pt".format(args.result_dir, gpu_id))
    torch.save(scheduler.state_dict(),
               "{}/pre_train_scheduler_{}.pt".format(args.result_dir, gpu_id))

    if not args.unpruned:
        for compression in args.compression_list:
            for p in args.pruner_list:
                # Reset Model, Optimizer, and Scheduler
                print('compression ratio: {} ::: pruner: {}'.format(
                    compression, p))
                model.load_state_dict(
                    torch.load("{}/pre_train_model_{}.pt".format(
                        args.result_dir, gpu_id),
                               map_location=device))
                optimizer.load_state_dict(
                    torch.load("{}/pre_train_optimizer_{}.pt".format(
                        args.result_dir, gpu_id),
                               map_location=device))
                scheduler.load_state_dict(
                    torch.load("{}/pre_train_scheduler_{}.pt".format(
                        args.result_dir, gpu_id),
                               map_location=device))

                ## Prune ##
                print('Pruning with {} for {} epochs.'.format(
                    p, args.prune_epochs))
                pruner = load.pruner(p)(generator.masked_parameters(
                    model, args.prune_bias, args.prune_batchnorm,
                    args.prune_residual))
                sparsity = 10**(-float(compression))
                prune_loop(model, loss, pruner, prune_loader, device, sparsity,
                           args.compression_schedule, args.mask_scope,
                           args.prune_epochs, args.reinitialize,
                           args.prune_train_mode, args.shuffle, args.invert)

                ## Post-Train ##
                print('Post-Training for {} epochs.'.format(args.post_epochs))
                post_train_start_time = timeit.default_timer()
                post_result = train_eval_loop(model,
                                              loss,
                                              optimizer,
                                              scheduler,
                                              train_loader,
                                              test_loader,
                                              device,
                                              args.post_epochs,
                                              args.verbose,
                                              train_sampler=train_sampler)
                post_train_end_time = timeit.default_timer()
                print("Post Training time: {:.4f}s".format(
                    post_train_end_time - post_train_start_time))

                ## Display Results ##
                frames = [
                    pre_result.head(1),
                    pre_result.tail(1),
                    post_result.head(1),
                    post_result.tail(1)
                ]
                train_result = pd.concat(
                    frames, keys=['Init.', 'Pre-Prune', 'Post-Prune', 'Final'])
                prune_result = metrics.summary(
                    model, pruner.scores,
                    metrics.flop(model, input_shape, device),
                    lambda p: generator.prunable(p, args.prune_batchnorm, args.
                                                 prune_residual))
                total_params = int(
                    (prune_result['sparsity'] * prune_result['size']).sum())
                possible_params = prune_result['size'].sum()
                total_flops = int(
                    (prune_result['sparsity'] * prune_result['flops']).sum())
                possible_flops = prune_result['flops'].sum()
                print("Train results:\n", train_result)
                # print("Prune results:\n", prune_result)
                # print("Parameter Sparsity: {}/{} ({:.4f})".format(total_params, possible_params, total_params / possible_params))
                # print("FLOP Sparsity: {}/{} ({:.4f})".format(total_flops, possible_flops, total_flops / possible_flops))

                ## recording testing time for task 2 ##
                # evaluating the model, including some data gathering overhead
                # eval(model, loss, test_loader, device, args.verbose)
                model.eval()
                start_time = timeit.default_timer()
                with torch.no_grad():
                    for data, target in test_loader:
                        data, target = data.to(device), target.to(device)
                        temp_eval_out = model(data)
                end_time = timeit.default_timer()
                print("Testing time: {:.4f}s".format(end_time - start_time))

                fout.write('compression ratio: {} ::: pruner: {}\n'.format(
                    compression, p))
                fout.write('Train results:\n {}\n'.format(train_result))
                fout.write('Prune results:\n {}\n'.format(prune_result))
                fout.write('Parameter Sparsity: {}/{} ({:.4f})\n'.format(
                    total_params, possible_params,
                    total_params / possible_params))
                fout.write("FLOP Sparsity: {}/{} ({:.4f})\n".format(
                    total_flops, possible_flops, total_flops / possible_flops))
                fout.write("Testing time: {}s\n".format(end_time - start_time))
                fout.write("remaining weights: \n{}\n".format(
                    (prune_result['sparsity'] * prune_result['size'])))
                fout.write('flop each layer: {}\n'.format(
                    (prune_result['sparsity'] *
                     prune_result['flops']).values.tolist()))
                ## Save Results and Model ##
                if args.save:
                    print('Saving results.')
                    if not os.path.exists('{}/{}'.format(
                            args.result_dir, compression)):
                        os.makedirs('{}/{}'.format(args.result_dir,
                                                   compression))
                    # pre_result.to_pickle("{}/{}/pre-train.pkl".format(args.result_dir, compression))
                    # post_result.to_pickle("{}/{}/post-train.pkl".format(args.result_dir, compression))
                    # prune_result.to_pickle("{}/{}/compression.pkl".format(args.result_dir, compression))
                    # torch.save(model.state_dict(), "{}/{}/model.pt".format(args.result_dir, compression))
                    # torch.save(optimizer.state_dict(),
                    #         "{}/{}/optimizer.pt".format(args.result_dir, compression))
                    # torch.save(scheduler.state_dict(),
                    #         "{}/{}/scheduler.pt".format(args.result_dir, compression))

    else:
        print('Starting Unpruned NN training')
        print('Training for {} epochs.'.format(args.post_epochs))
        # load the checkpoints saved by this rank above (note the gpu_id suffix)
        model.load_state_dict(
            torch.load("{}/pre_train_model_{}.pt".format(args.result_dir, gpu_id),
                       map_location=device))
        optimizer.load_state_dict(
            torch.load("{}/pre_train_optimizer_{}.pt".format(args.result_dir, gpu_id),
                       map_location=device))
        scheduler.load_state_dict(
            torch.load("{}/pre_train_scheduler_{}.pt".format(args.result_dir, gpu_id),
                       map_location=device))

        train_start_time = timeit.default_timer()
        result = train_eval_loop(model,
                                 loss,
                                 optimizer,
                                 scheduler,
                                 train_loader,
                                 test_loader,
                                 device,
                                 args.post_epochs,
                                 args.verbose,
                                 train_sampler=train_sampler)
        train_end_time = timeit.default_timer()
        frames = [result.head(1), result.tail(1)]
        train_result = pd.concat(frames, keys=['Init.', 'Final'])
        print('Train results:\n', train_result)

    fout.close()

    dist.destroy_process_group()
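
Exemple #10 follows the standard one-process-per-GPU pattern: each rank joins an NCCL process group, takes its share of the global batch via a DistributedSampler, and wraps the model in DistributedDataParallel. A stripped-down, single-node sketch with a dummy dataset (worker, global_batch and the dataset are illustrative; this does not use the project's load/generator helpers):

import os
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data.distributed import DistributedSampler


def worker(gpu_id, world_size):
    # One process per GPU: join the NCCL group, shard the data, wrap the model.
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")
    dist.init_process_group(backend='nccl', init_method='env://',
                            world_size=world_size, rank=gpu_id)
    torch.cuda.set_device(gpu_id)

    dataset = TensorDataset(torch.randn(1024, 10), torch.randint(0, 2, (1024,)))
    sampler = DistributedSampler(dataset, num_replicas=world_size, rank=gpu_id)
    global_batch = 256
    loader = DataLoader(dataset, batch_size=global_batch // world_size, sampler=sampler)

    model = torch.nn.Linear(10, 2).cuda(gpu_id)
    model = DDP(model, device_ids=[gpu_id])
    # ... training loop; call sampler.set_epoch(epoch) at the start of each epoch ...
    dist.destroy_process_group()


if __name__ == "__main__":
    mp.spawn(worker, args=(torch.cuda.device_count(),), nprocs=torch.cuda.device_count())
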
Exemple #11
def run(args):
    if not args.save:
        print("This experiment requires an expid.")
        quit()

    ## Random Seed and Device ##
    torch.manual_seed(args.seed)
    device = load.device(args.gpu)

    ## Data ##
    print('Loading {} dataset.'.format(args.dataset))
    input_shape, num_classes = load.dimension(args.dataset)
    prune_loader = load.dataloader(args.dataset, args.prune_batch_size, True,
                                   args.workers,
                                   args.prune_dataset_ratio * num_classes)
    train_loader = load.dataloader(args.dataset, args.train_batch_size, True,
                                   args.workers)
    test_loader = load.dataloader(args.dataset, args.test_batch_size, False,
                                  args.workers)

    loss = nn.CrossEntropyLoss()

    for compression in args.compression_list:
        for level in args.level_list:

            ## Models ##
            models = []
            for i in range(args.models_num):
                print('Creating {} model ({}/{}).'.format(
                    args.model, str(i + 1), str(args.models_num)))
                model = load.model(args.model, args.model_class)(
                    input_shape, num_classes, args.dense_classifier,
                    args.pretrained).to(device)

                opt_class, opt_kwargs = load.optimizer(args.optimizer)
                optimizer = opt_class(generator.parameters(model),
                                      lr=args.lr,
                                      weight_decay=args.weight_decay,
                                      **opt_kwargs)
                scheduler = torch.optim.lr_scheduler.MultiStepLR(
                    optimizer,
                    milestones=args.lr_drops,
                    gamma=args.lr_drop_rate)

                ## Save Original ##
                torch.save(model.state_dict(),
                           "{}/model.pt".format(args.result_dir))
                torch.save(optimizer.state_dict(),
                           "{}/optimizer.pt".format(args.result_dir))
                torch.save(scheduler.state_dict(),
                           "{}/scheduler.pt".format(args.result_dir))

                ## Train-Prune Loop ##
                print('{} compression ratio, {} train-prune levels'.format(
                    compression, level))

                for l in range(level):
                    # Pre Train Model
                    train_eval_loop(model, loss, optimizer, scheduler,
                                    train_loader, test_loader, device,
                                    args.pre_epochs, args.verbose)

                    # Prune Model
                    pruner = load.pruner(args.pruner)(
                        generator.masked_parameters(model, args.prune_bias,
                                                    args.prune_batchnorm,
                                                    args.prune_residual))
                    if args.pruner == 'synflown':
                        pruner.set_input(train_loader, num_classes, device)
                    sparsity = (10**(-float(compression)))**((l + 1) / level)
                    prune_loop(model, loss, pruner, prune_loader, device,
                               sparsity, args.compression_schedule,
                               args.mask_scope, args.prune_epochs, False,
                               args.prune_train_mode)

                    # Reset Model's Weights
                    original_dict = torch.load("{}/model.pt".format(
                        args.result_dir),
                                               map_location=device)
                    original_weights = dict(
                        filter(lambda v: (v[1].requires_grad == True),
                               original_dict.items()))
                    model_dict = model.state_dict()
                    model_dict.update(original_weights)
                    model.load_state_dict(model_dict)

                    # Reset Optimizer and Scheduler
                    optimizer.load_state_dict(
                        torch.load("{}/optimizer.pt".format(args.result_dir),
                                   map_location=device))
                    scheduler.load_state_dict(
                        torch.load("{}/scheduler.pt".format(args.result_dir),
                                   map_location=device))

                # Prune Result
                prune_result = metrics.summary(
                    model, pruner.scores,
                    metrics.flop(model, input_shape, device),
                    lambda p: generator.prunable(p, args.prune_batchnorm, args.
                                                 prune_residual))
                # Train Model
                post_result = train_eval_loop(model, loss, optimizer,
                                              scheduler, train_loader,
                                              test_loader, device,
                                              args.post_epochs, args.verbose)

                # Save Data
                post_result.to_pickle("{}/post-train-{}-{}-{}-{}.pkl".format(
                    args.result_dir, args.pruner, str(compression), str(level),
                    str(i + 1)))
                prune_result.to_pickle("{}/compression-{}-{}-{}-{}.pkl".format(
                    args.result_dir, args.pruner, str(compression), str(level),
                    str(i + 1)))
                torch.save(model.state_dict(),
                           "{}/model{}.pt".format(args.result_dir, str(i + 1)))
                models.append(model)

            # Evaluate the combined model
            train_average_loss, train_accuracy1, train_accuracy5 = eval_multi_models(
                models, loss, train_loader, device, args.verbose)
            print(
                "Training loss: {}, Training 1-accuracy: {}, Training 5-accuracy: {}"
                .format(train_average_loss, train_accuracy1, train_accuracy5))
            test_average_loss, test_accuracy1, test_accuracy5 = eval_multi_models(
                models, loss, test_loader, device, args.verbose)
            print("Test loss: {}, Test 1-accuracy: {}, Test 5-accuracy: {}".
                  format(test_average_loss, test_accuracy1, test_accuracy5))

            results = [
                train_average_loss, train_accuracy1, train_accuracy5,
                test_average_loss, test_accuracy1, test_accuracy5
            ]
            pickle.dump(
                results,
                open(
                    "{}/results-{}-{}-{}-{}.pkl".format(
                        args.result_dir, args.pruner, str(compression),
                        str(level), str(i + 1)), "wb"))

    print('Done!')
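
The per-level target in the loop above, sparsity = (10**(-compression))**((l + 1) / level), reaches the final density 10**(-compression) geometrically over the train-prune levels. A quick worked sketch of that schedule (plain Python, not project code):

def level_densities(compression, levels):
    # Fraction of weights kept after each train-prune level; the last entry
    # equals the final density 10**(-compression).
    final = 10 ** (-float(compression))
    return [final ** ((l + 1) / levels) for l in range(levels)]

# compression=1 (keep 10% of weights overall) spread over 5 levels:
print(level_densities(1, 5))
# -> [0.631, 0.398, 0.251, 0.158, 0.100] (rounded)
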
Exemple #12
def run(args):
    ## Random Seed and Device ##
    torch.manual_seed(args.seed)
    device = load.device(args.gpu)

    ## Data ##
    print('Loading {} dataset.'.format(args.dataset))
    input_shape, num_classes = load.dimension(args.dataset)
    prune_loader = load.dataloader(args.dataset, args.prune_batch_size, True,
                                   args.workers,
                                   args.prune_dataset_ratio * num_classes)
    train_loader = load.dataloader(args.dataset, args.train_batch_size, True,
                                   args.workers)
    test_loader = load.dataloader(args.dataset, args.test_batch_size, False,
                                  args.workers)

    if args.compression_list == []:
        args.compression_list.append(args.compression)

    for compression in args.compression_list:
        ## Model, Loss, Optimizer ##
        print('Creating {}-{} model.'.format(args.model_class, args.model))
        if args.model != 'fc':
            model = load.model(args.model,
                               args.model_class)(input_shape, num_classes,
                                                 args.dense_classifier,
                                                 args.pretrained).to(device)
        else:
            model = load.model(args.model,
                               args.model_class)(input_shape,
                                                 num_classes,
                                                 args.dense_classifier,
                                                 args.pretrained,
                                                 N=args.hidden).to(device)
        loss = nn.CrossEntropyLoss()
        opt_class, opt_kwargs = load.optimizer(args.optimizer)
        optimizer = opt_class(generator.parameters(model),
                              lr=args.lr,
                              weight_decay=args.weight_decay,
                              **opt_kwargs)
        scheduler = torch.optim.lr_scheduler.MultiStepLR(
            optimizer, milestones=args.lr_drops, gamma=args.lr_drop_rate)

        ## Pre-Train ##
        print('Pre-Train for {} epochs.'.format(args.pre_epochs))
        pre_result = train_eval_loop(model, loss, optimizer, scheduler,
                                     train_loader, test_loader, device,
                                     args.pre_epochs, args.verbose)

        ## Prune ##
        print('Pruning with {} for {} epochs.'.format(args.pruner,
                                                      args.prune_epochs))
        pruner = load.pruner(args.pruner)(generator.masked_parameters(
            model, args.prune_bias, args.prune_batchnorm, args.prune_residual))
        sparsity = 10**(-float(compression))
        prune_loop(model, loss, pruner, prune_loader, device, sparsity,
                   args.compression_schedule, args.mask_scope,
                   args.prune_epochs, args.reinitialize, args.prune_train_mode)

        ## Post-Train ##
        print('Post-Training for {} epochs.'.format(args.post_epochs))
        post_result = train_eval_loop(model, loss, optimizer, scheduler,
                                      train_loader, test_loader, device,
                                      args.post_epochs, args.verbose)

        ## Display Results ##
        frames = [
            pre_result.head(1),
            pre_result.tail(1),
            post_result.head(1),
            post_result.tail(1)
        ]
        train_result = pd.concat(
            frames, keys=['Init.', 'Pre-Prune', 'Post-Prune', 'Final'])
        prune_result = metrics.summary(
            model, pruner.scores, metrics.flop(model, input_shape, device),
            lambda p: generator.prunable(p, args.prune_batchnorm, args.
                                         prune_residual))
        total_params = int(
            (prune_result['sparsity'] * prune_result['size']).sum())
        possible_params = prune_result['size'].sum()
        total_flops = int(
            (prune_result['sparsity'] * prune_result['flops']).sum())
        possible_flops = prune_result['flops'].sum()
        print("Train results:\n", train_result)
        print("Prune results:\n", prune_result)
        print("Parameter Sparsity: {}/{} ({:.4f})".format(
            total_params, possible_params, total_params / possible_params))
        print("FLOP Sparsity: {}/{} ({:.4f})".format(
            total_flops, possible_flops, total_flops / possible_flops))

        ## Save Results and Model ##
        if args.save:
            print('Saving results.')
            if not os.path.exists('{}/{}'.format(args.result_dir,
                                                 compression)):
                os.makedirs('{}/{}'.format(args.result_dir, compression))
            pre_result.to_pickle("{}/{}/pre-train.pkl".format(
                args.result_dir, compression))
            post_result.to_pickle("{}/{}/post-train.pkl".format(
                args.result_dir, compression))
            prune_result.to_pickle("{}/{}/compression.pkl".format(
                args.result_dir, compression))
            torch.save(model.state_dict(),
                       "{}/{}/model.pt".format(args.result_dir, compression))
            torch.save(
                optimizer.state_dict(),
                "{}/{}/optimizer.pt".format(args.result_dir, compression))
            torch.save(
                scheduler.state_dict(),
                "{}/{}/scheduler.pt".format(args.result_dir, compression))
Exemple #13
def run(args):
    if not args.save:
        print("This experiment requires an expid.")
        quit()

    ## Random Seed and Device ##
    torch.manual_seed(args.seed)
    device = load.device(args.gpu)

    ## Data ##
    print('Loading {} dataset.'.format(args.dataset))
    input_shape, num_classes = load.dimension(args.dataset)
    prune_loader = load.dataloader(args.dataset, args.prune_batch_size, True,
                                   args.workers,
                                   args.prune_dataset_ratio * num_classes)
    train_loader = load.dataloader(args.dataset, args.train_batch_size, True,
                                   args.workers)
    test_loader = load.dataloader(args.dataset, args.test_batch_size, False,
                                  args.workers)

    ## Model ##
    print('Creating {} model.'.format(args.model))
    model = load.model(args.model,
                       args.model_class)(input_shape, num_classes,
                                         args.dense_classifier,
                                         args.pretrained).to(device)
    loss = nn.CrossEntropyLoss()
    opt_class, opt_kwargs = load.optimizer(args.optimizer)
    optimizer = opt_class(generator.parameters(model),
                          lr=args.lr,
                          weight_decay=args.weight_decay,
                          **opt_kwargs)
    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer,
                                                     milestones=args.lr_drops,
                                                     gamma=args.lr_drop_rate)

    ## Pre-Train ##
    print('Pre-Train for {} epochs.'.format(args.pre_epochs))
    pre_result = train_eval_loop(model, loss, optimizer, scheduler,
                                 train_loader, test_loader, device,
                                 args.pre_epochs, args.verbose)

    pre_result.to_pickle("{}/pre-train.pkl".format(args.result_dir))
    torch.save(model.state_dict(), "{}/model.pt".format(args.result_dir))
    torch.save(optimizer.state_dict(),
               "{}/optimizer.pt".format(args.result_dir))
    torch.save(scheduler.state_dict(),
               "{}/scheduler.pt".format(args.result_dir))

    ## Prune and Fine-Tune##
    pruners = args.pruner_list
    comp_exponents = args.compression_list
    for j, exp in enumerate(comp_exponents[::-1]):
        for i, p in enumerate(pruners):
            print(p, exp)

            model.load_state_dict(
                torch.load("{}/model.pt".format(args.result_dir),
                           map_location=device))
            optimizer.load_state_dict(
                torch.load("{}/optimizer.pt".format(args.result_dir),
                           map_location=device))
            scheduler.load_state_dict(
                torch.load("{}/scheduler.pt".format(args.result_dir),
                           map_location=device))

            pruner = load.pruner(p)(generator.masked_parameters(
                model, args.prune_bias, args.prune_batchnorm,
                args.prune_residual))
            sparsity = 10**(-float(exp))

            if p == 'sf':
                prune_loop(model, loss, pruner, prune_loader, device, sparsity,
                           args.normalize_score, args.mask_scope,
                           args.prune_epochs, args.reinitialize)
            else:
                prune_loop(model, loss, pruner, prune_loader, device, sparsity,
                           args.normalize_score, args.mask_scope, 1,
                           args.reinitialize)

            prune_result = metrics.summary(
                model, pruner.scores, metrics.flop(model, input_shape, device),
                lambda p: generator.prunable(p, args.prune_batchnorm, args.
                                             prune_residual))

            post_result = train_eval_loop(model, loss, optimizer, scheduler,
                                          train_loader, test_loader, device,
                                          args.post_epochs, args.verbose)

            post_result.to_pickle("{}/post-train-{}-{}-{}.pkl".format(
                args.result_dir, p, str(exp), args.prune_epochs))
            prune_result.to_pickle("{}/compression-{}-{}-{}.pkl".format(
                args.result_dir, p, str(exp), args.prune_epochs))
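
Here, as in the other scripts, the compression exponent is the base-10 log of the inverse density: despite its name, the value passed to prune_loop as sparsity is the fraction of weights that remain, 10**(-exp). For reference:

for exp in [0, 0.5, 1, 2, 3]:
    print(exp, "->", 10 ** (-float(exp)))
# 0 -> 1.0 (dense), 0.5 -> ~0.316, 1 -> 0.1, 2 -> 0.01, 3 -> 0.001
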
Exemple #14
def run(args):
    if not args.save:
        print("This experiment requires an expid.")
        quit()

    ## Random Seed and Device ##
    torch.manual_seed(args.seed)
    device = load.device(args.gpu)

    ## Data ##
    print('Loading {} dataset.'.format(args.dataset))
    input_shape, num_classes = load.dimension(args.dataset)
    prune_loader = load.dataloader(args.dataset, args.prune_batch_size, True,
                                   args.workers,
                                   args.prune_dataset_ratio * num_classes)
    train_loader = load.dataloader(args.dataset, args.train_batch_size, True,
                                   args.workers)
    test_loader = load.dataloader(args.dataset, args.test_batch_size, False,
                                  args.workers)

    ## Model ##
    print('Creating {} model.'.format(args.model))
    model = load.model(args.model,
                       args.model_class)(input_shape, num_classes,
                                         args.dense_classifier,
                                         args.pretrained).to(device)
    loss = nn.CrossEntropyLoss()
    opt_class, opt_kwargs = load.optimizer(args.optimizer)
    optimizer = opt_class(generator.parameters(model),
                          lr=args.lr,
                          weight_decay=args.weight_decay,
                          **opt_kwargs)
    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer,
                                                     milestones=args.lr_drops,
                                                     gamma=args.lr_drop_rate)

    ## Save Original ##
    torch.save(model.state_dict(), "{}/model.pt".format(args.result_dir))
    torch.save(optimizer.state_dict(),
               "{}/optimizer.pt".format(args.result_dir))
    torch.save(scheduler.state_dict(),
               "{}/scheduler.pt".format(args.result_dir))

    ## Train-Prune Loop ##
    comp_exponents = np.arange(0, 6, 1)  # log 10 of inverse sparsity
    train_iterations = [100]  # number of epochs between prune periods
    prune_iterations = np.arange(1, 6, 1)  # number of prune periods
    for i, comp_exp in enumerate(comp_exponents[::-1]):
        for j, train_iters in enumerate(train_iterations):
            for k, prune_iters in enumerate(prune_iterations):
                print(
                    '{} compression ratio, {} train epochs, {} prune iterations'
                    .format(comp_exp, train_iters, prune_iters))

                # Reset Model, Optimizer, and Scheduler
                model.load_state_dict(
                    torch.load("{}/model.pt".format(args.result_dir),
                               map_location=device))
                optimizer.load_state_dict(
                    torch.load("{}/optimizer.pt".format(args.result_dir),
                               map_location=device))
                scheduler.load_state_dict(
                    torch.load("{}/scheduler.pt".format(args.result_dir),
                               map_location=device))

                for l in range(prune_iters):

                    # Pre Train Model
                    train_eval_loop(model, loss, optimizer, scheduler,
                                    train_loader, test_loader, device,
                                    train_iters, args.verbose)

                    # Prune Model
                    pruner = load.pruner('mag')(generator.masked_parameters(
                        model, args.prune_bias, args.prune_batchnorm,
                        args.prune_residual))
                    sparsity = (10**(-float(comp_exp)))**((l + 1) /
                                                          prune_iters)
                    prune_loop(model, loss, pruner, prune_loader, device,
                               sparsity, args.normalize_score, args.mask_scope,
                               1, args.reinitialize)

                    # Reset Model's Weights
                    original_dict = torch.load("{}/model.pt".format(
                        args.result_dir),
                                               map_location=device)
                    original_weights = dict(
                        filter(lambda v: (v[1].requires_grad == True),
                               original_dict.items()))
                    model_dict = model.state_dict()
                    model_dict.update(original_weights)
                    model.load_state_dict(model_dict)

                    # Reset Optimizer and Scheduler
                    optimizer.load_state_dict(
                        torch.load("{}/optimizer.pt".format(args.result_dir),
                                   map_location=device))
                    scheduler.load_state_dict(
                        torch.load("{}/scheduler.pt".format(args.result_dir),
                                   map_location=device))

                # Prune Result
                prune_result = metrics.summary(
                    model, pruner.scores,
                    metrics.flop(model, input_shape, device),
                    lambda p: generator.prunable(p, args.prune_batchnorm, args.
                                                 prune_residual))
                # Train Model
                post_result = train_eval_loop(model, loss, optimizer,
                                              scheduler, train_loader,
                                              test_loader, device,
                                              args.post_epochs, args.verbose)

                # Save Data
                prune_result.to_pickle("{}/compression-{}-{}-{}.pkl".format(
                    args.result_dir, str(comp_exp), str(train_iters),
                    str(prune_iters)))
                post_result.to_pickle("{}/performance-{}-{}-{}.pkl".format(
                    args.result_dir, str(comp_exp), str(train_iters),
                    str(prune_iters)))
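
The weight-reset step above is meant to put the saved original weights back while keeping whatever the pruner has since added to the state dict; it does so by updating the current state_dict with the saved entries and reloading it. A small sketch of that dict-merge pattern (the file name and the key-based filter are illustrative, not the script's exact filter):

import torch
import torch.nn as nn

model = nn.Linear(10, 2)
torch.save(model.state_dict(), "original.pt")     # snapshot taken before training

# ... train and prune; the model's state_dict may now contain extra buffers ...

original_dict = torch.load("original.pt")
model_dict = model.state_dict()
# Overwrite only the entries that existed in the snapshot; anything added since
# (e.g. pruning masks registered as buffers) is left untouched.
model_dict.update({k: v for k, v in original_dict.items() if k in model_dict})
model.load_state_dict(model_dict)
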
Exemple #15
def run(args):
    ## Random Seed and Device ##
    torch.manual_seed(args.seed)
    device = load.device(args.gpu)

    ## Data ##
    print('Loading {} dataset.'.format(args.dataset))
    input_shape, num_classes = load.dimension(args.dataset)
    prune_loader, _ = load.dataloader(args.dataset, args.prune_batch_size,
                                      True, args.workers,
                                      args.prune_dataset_ratio * num_classes)
    train_loader, _ = load.dataloader(args.dataset, args.train_batch_size,
                                      True, args.workers)
    test_loader, _ = load.dataloader(args.dataset, args.test_batch_size, False,
                                     args.workers)

    log_filename = '{}/{}'.format(args.result_dir, 'result.log')
    fout = open(log_filename, 'w')
    fout.write('start!\n')

    if args.compression_list == []:
        args.compression_list.append(args.compression)
    if args.pruner_list == []:
        args.pruner_list.append(args.pruner)

    ## Model, Loss, Optimizer ##
    print('Creating {}-{} model.'.format(args.model_class, args.model))

    # load the pre-defined model from the utils
    model = load.model(args.model, args.model_class)(input_shape, num_classes,
                                                     args.dense_classifier,
                                                     args.pretrained)

    #########################################
    # enable multi-GPU data parallelism (single-process nn.DataParallel)
    #########################################
    if (args.data_parallel):
        if torch.cuda.device_count() > 1:
            print("let's use", torch.cuda.device_count(), "GPUs!")
            model = nn.DataParallel(model)

    model = model.to(device)

    loss = nn.CrossEntropyLoss()
    opt_class, opt_kwargs = load.optimizer(args.optimizer)
    optimizer = opt_class(generator.parameters(model),
                          lr=args.lr,
                          weight_decay=args.weight_decay,
                          **opt_kwargs)
    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer,
                                                     milestones=args.lr_drops,
                                                     gamma=args.lr_drop_rate)

    ## Pre-Train ##
    if args.unpruned:
        print(
            'Pruning is disabled. Total training epochs = Pre-Train epochs + Post-Train epochs'
        )
    print('Pre-Train for {} epochs.'.format(args.pre_epochs))
    pre_result = train_eval_loop(model, loss, optimizer, scheduler,
                                 train_loader, test_loader, device,
                                 args.pre_epochs, args.verbose)
    print('Pre-Train finished!')

    ## Save Original ##
    torch.save(model.state_dict(),
               "{}/pre_train_model.pt".format(args.result_dir))
    torch.save(optimizer.state_dict(),
               "{}/pre_train_optimizer.pt".format(args.result_dir))
    torch.save(scheduler.state_dict(),
               "{}/pre_train_scheduler.pt".format(args.result_dir))

    if not args.unpruned:
        for compression in args.compression_list:
            for p in args.pruner_list:
                # Reset Model, Optimizer, and Scheduler
                print('compression ratio: {} ::: pruner: {}'.format(
                    compression, p))
                model.load_state_dict(
                    torch.load("{}/pre_train_model.pt".format(args.result_dir),
                               map_location=device))
                optimizer.load_state_dict(
                    torch.load("{}/pre_train_optimizer.pt".format(
                        args.result_dir),
                               map_location=device))
                scheduler.load_state_dict(
                    torch.load("{}/pre_train_scheduler.pt".format(
                        args.result_dir),
                               map_location=device))

                ## Prune ##
                print('Pruning with {} for {} epochs.'.format(
                    p, args.prune_epochs))
                pruner = load.pruner(p)(generator.masked_parameters(
                    model, args.prune_bias, args.prune_batchnorm,
                    args.prune_residual))
                sparsity = 10**(-float(compression))
                prune_loop(model, loss, pruner, prune_loader, device, sparsity,
                           args.compression_schedule, args.mask_scope,
                           args.prune_epochs, args.reinitialize,
                           args.prune_train_mode, args.shuffle, args.invert)

                ## Post-Train ##
                print('Post-Training for {} epochs.'.format(args.post_epochs))
                post_train_start_time = timeit.default_timer()
                post_result = train_eval_loop(model, loss, optimizer,
                                              scheduler, train_loader,
                                              test_loader, device,
                                              args.post_epochs, args.verbose)
                post_train_end_time = timeit.default_timer()
                print("Post Training time: {:.4f}s".format(
                    post_train_end_time - post_train_start_time))

                ## Display Results ##
                frames = [
                    pre_result.head(1),
                    pre_result.tail(1),
                    post_result.head(1),
                    post_result.tail(1)
                ]
                train_result = pd.concat(
                    frames, keys=['Init.', 'Pre-Prune', 'Post-Prune', 'Final'])
                prune_result = metrics.summary(
                    model, pruner.scores,
                    metrics.flop(model, input_shape, device),
                    lambda p: generator.prunable(p, args.prune_batchnorm, args.
                                                 prune_residual))
                total_params = int(
                    (prune_result['sparsity'] * prune_result['size']).sum())
                possible_params = prune_result['size'].sum()
                total_flops = int(
                    (prune_result['sparsity'] * prune_result['flops']).sum())
                possible_flops = prune_result['flops'].sum()
                print("Train results:\n", train_result)
                # print("Prune results:\n", prune_result)
                # print("Parameter Sparsity: {}/{} ({:.4f})".format(total_params, possible_params, total_params / possible_params))
                # print("FLOP Sparsity: {}/{} ({:.4f})".format(total_flops, possible_flops, total_flops / possible_flops))

                ## recording testing time for task 2 ##
                # evaluating the model, including some data gathering overhead
                # eval(model, loss, test_loader, device, args.verbose)
                model.eval()
                start_time = timeit.default_timer()
                with torch.no_grad():
                    for data, target in test_loader:
                        data, target = data.to(device), target.to(device)
                        temp_eval_out = model(data)
                        # print("Test eval data size: input: {}; output: {}".format(data.size(), temp_eval_out.size()))
                end_time = timeit.default_timer()
                print("Testing time: {:.4f}s".format(end_time - start_time))

                fout.write('compression ratio: {} ::: pruner: {}\n'.format(
                    compression, p))
                fout.write('Train results:\n {}\n'.format(train_result))
                fout.write('Prune results:\n {}\n'.format(prune_result))
                fout.write('Parameter Sparsity: {}/{} ({:.4f})\n'.format(
                    total_params, possible_params,
                    total_params / possible_params))
                fout.write("FLOP Sparsity: {}/{} ({:.4f})\n".format(
                    total_flops, possible_flops, total_flops / possible_flops))
                fout.write("Testing time: {}s\n".format(end_time - start_time))
                fout.write("remaining weights: \n{}\n".format(
                    (prune_result['sparsity'] * prune_result['size'])))
                fout.write('flop each layer: {}\n'.format(
                    (prune_result['sparsity'] *
                     prune_result['flops']).values.tolist()))
                ## Save Results and Model ##
                if args.save:
                    print('Saving results.')
                    if not os.path.exists('{}/{}'.format(
                            args.result_dir, compression)):
                        os.makedirs('{}/{}'.format(args.result_dir,
                                                   compression))
                    # pre_result.to_pickle("{}/{}/pre-train.pkl".format(args.result_dir, compression))
                    # post_result.to_pickle("{}/{}/post-train.pkl".format(args.result_dir, compression))
                    # prune_result.to_pickle("{}/{}/compression.pkl".format(args.result_dir, compression))
                    # torch.save(model.state_dict(), "{}/{}/model.pt".format(args.result_dir, compression))
                    # torch.save(optimizer.state_dict(),
                    #         "{}/{}/optimizer.pt".format(args.result_dir, compression))
                    # torch.save(scheduler.state_dict(),
                    #         "{}/{}/scheduler.pt".format(args.result_dir, compression))

    else:
        print('Starting Unpruned NN training')
        print('Training for {} epochs.'.format(args.post_epochs))
        model.load_state_dict(
            torch.load("{}/pre_train_model.pt".format(args.result_dir),
                       map_location=device))
        optimizer.load_state_dict(
            torch.load("{}/pre_train_optimizer.pt".format(args.result_dir),
                       map_location=device))
        scheduler.load_state_dict(
            torch.load("{}/pre_train_scheduler.pt".format(args.result_dir),
                       map_location=device))

        train_start_time = timeit.default_timer()
        result = train_eval_loop(model, loss, optimizer, scheduler,
                                 train_loader, test_loader, device,
                                 args.post_epochs, args.verbose)
        train_end_time = timeit.default_timer()
        frames = [result.head(1), result.tail(1)]
        train_result = pd.concat(frames, keys=['Init.', 'Final'])
        print('Train results:\n', train_result)
    fout.close()
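
The timing block in Exemple #15 (and #10) wraps the test pass in torch.no_grad() and brackets it with timeit.default_timer(). One caveat when timing on GPU: CUDA kernels launch asynchronously, so for an accurate wall-clock number the device should be synchronized before reading the timer. A hedged sketch of the same measurement with that fix:

import timeit
import torch

def timed_eval(model, loader, device):
    model.eval()
    if device.type == "cuda":
        torch.cuda.synchronize(device)
    start = timeit.default_timer()
    with torch.no_grad():
        for data, target in loader:
            data = data.to(device)
            model(data)
    if device.type == "cuda":
        torch.cuda.synchronize(device)   # wait for queued kernels before stopping the clock
    return timeit.default_timer() - start
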
Exemple #16
def main():
    global args, best_prec1
    args = parser.parse_args()
    args.sparse_lvl = 0.8 ** args.sparse_iter
    print(args.sparse_lvl)
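    # Each --sparse_iter step keeps 80% of the previously surviving weights, so
    # sparse_lvl = 0.8**sparse_iter is (presumably) the fraction of weights that
    # remain; e.g. sparse_iter=10 keeps about 10.7% of them.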

    # Check the save_dir exists or not
    if not os.path.exists(args.save_dir):
        os.makedirs(args.save_dir)
    
    torch.manual_seed(args.seed)

    cudnn.benchmark = True

    if args.dataset =='cifar10':
        print('Loading {} dataset.'.format(args.dataset))
        input_shape, num_classes = load.dimension(args.dataset) 
        train_dataset, train_loader = load.dataloader(args.dataset, args.batch_size, True, args.workers)
        _, val_loader = load.dataloader(args.dataset, 128, False, args.workers)

    elif args.dataset == 'tiny-imagenet':
        args.batch_size = 256
        args.lr = 0.2
        args.epochs = 200
        print('Loading {} dataset.'.format(args.dataset))
        input_shape, num_classes = load.dimension(args.dataset) 
        train_dataset, train_loader = load.dataloader(args.dataset, args.batch_size, True, args.workers)
        _, val_loader = load.dataloader(args.dataset, 128, False, args.workers)

    elif args.dataset == 'cifar100':
        args.batch_size = 128
        # args.lr = 0.01
        args.epochs = 160
        # args.weight_decay = 5e-4
        input_shape, num_classes = load.dimension(args.dataset) 
        train_dataset, train_loader = load.dataloader(args.dataset, args.batch_size, True, args.workers)
        _, val_loader = load.dataloader(args.dataset, 128, False, args.workers)

    if args.arch == 'resnet20':
        print('Creating {} model.'.format(args.arch))
        # model = torch.nn.DataParallel(resnet.__dict__[args.arch](ONI=args.ONI, T_iter=args.T_iter))
        model = resnet.__dict__[args.arch](ONI=args.ONI, T_iter=args.T_iter)
        model.cuda()
    elif args.arch == 'resnet18':
        print('Creating {} model.'.format(args.arch))
        # Using resnet18 from Synflow
        # model = load.model(args.arch, 'tinyimagenet')(input_shape, 
        #                                              num_classes,
        #                                              dense_classifier = True).cuda()
        # Using resnet18 from torchvision
        model = models.resnet18()
        model.fc = nn.Linear(512, num_classes)
        model.cuda()
        utils.kaiming_initialize(model)
    elif args.arch == 'resnet110' or args.arch == 'resnet110full':
        # Using resnet110 from Apollo
        # model = apolo_resnet.ResNet(110, num_classes=num_classes)
        model = load.model(args.arch, 'lottery')(input_shape, 
                                             num_classes,
                                             dense_classifier = True).cuda()
    elif args.arch in ['vgg16full', 'vgg16full-bn', 'vgg11full', 'vgg11full-bn'] :
        if args.dataset == 'tiny-imagenet':
            modeltype = 'tinyimagenet'
        else:
            modeltype = 'lottery'
        # Using the VGG variants via load.model with the modeltype picked above
        model = load.model(args.arch, modeltype)(input_shape, 
                                             num_classes,
                                             dense_classifier = True).cuda()
    
    # for layer in model.modules():
    #     if isinstance(layer, nn.Linear):
    #         init.orthogonal_(layer.weight.data)
    #     elif isinstance(layer, (nn.Conv2d, nn.ConvTranspose2d)):
    #         special_init.DeltaOrthogonal_init(layer.weight.data)

    print('Number of parameters of model: {}.'.format(count_parameters(model)))

    # define loss function (criterion) and optimizer
    criterion = nn.CrossEntropyLoss().cuda()

    if args.compute_sv:
        print('[*] Will compute singular values throughout training.')
        size_hook = utils.get_hook(model, (nn.Linear, nn.Conv2d, nn.ConvTranspose2d))
        utils.run_once(train_loader, model)
        utils.detach_hook([size_hook])
        training_sv = []
        training_svmax = []
        training_sv20 = [] # 20% singular value
        training_sv50 = [] # 50% singular value
        training_sv80 = [] # 80% singular value
        training_kclip = [] # singular values larger than 1e-12
        sv, svmax, sv20, sv50, sv80, kclip = utils.get_sv(model, size_hook)
        training_sv.append(sv)
        training_svmax.append(svmax)
        training_sv20.append(sv20)
        training_sv50.append(sv50)
        training_sv80.append(sv80)
        training_kclip.append(kclip)
    
    if args.compute_ntk:
        training_ntk_eig = []
        if num_classes>=32:
            _, ntk_loader = load.dataloader(args.dataset, 32, True, args.workers)
            grasp_fetch = False
        else:
            ntk_loader = train_loader
            grasp_fetch = True
        training_ntk_eig.append(utils.get_ntk_eig(ntk_loader, [model], train_mode = True, num_batch=1, num_classes=num_classes, samples_per_class=1, grasp_fetch=grasp_fetch))
    
    if args.compute_lrs:
        # training_lrs = []
        # lrc_model = utils.Linear_Region_Collector(train_loader, input_size=(args.batch_size,*input_shape), sample_batch=300)
        # lrc_model.reinit(models=[model])
        # lrs = lrc_model.forward_batch_sample()[0]
        # training_lrs.append(lrs)
        # lrc_model.clear_hooks()
        # print('[*] Current number of linear regions:{}'.format(lrs))
        GAP_zen, output_zen = utils.get_zenscore(model, train_loader, args.arch, num_classes)
        print('[*] Before pruning: GAP_zen:{:e}, output_zen:{:e}'.format(GAP_zen,output_zen))
    
    if args.half:
        model.half()
        criterion.half()

    optimizer = torch.optim.SGD(model.parameters(), args.lr,
                                momentum=args.momentum,
                                nesterov = True,
                                weight_decay=args.weight_decay)

    if args.dataset ==  'tiny-imagenet':
        lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer,
                                                        milestones=[100, 150], last_epoch=args.start_epoch - 1)
                                                        # milestones=[30, 60, 80], last_epoch=args.start_epoch - 1)
    elif args.dataset ==  'cifar100':
        lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer,
                                                        milestones=[60, 120], gamma = 0.2, last_epoch=args.start_epoch - 1)
    else:
        lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer,
                                                        milestones=[80, 120], last_epoch=args.start_epoch - 1)

    # This part is for training full NN model to obtain Lottery ticket

    # # First save original network:
    init_path = os.path.join(args.save_dir, 'init_checkpoint.th')
    save_checkpoint({
                'state_dict': model.state_dict()
            }, False, filename=init_path)

    if args.prune_method == 'NONE':
        pre_epochs = args.epochs
    else:
        pre_epochs = 0

    training_loss = []
    for epoch in range(pre_epochs):

        # train for one epoch
        print('current lr {:.5e}'.format(optimizer.param_groups[0]['lr']))
        train(train_loader, model, criterion, optimizer, epoch, track = training_loss)
        lr_scheduler.step()

        # evaluate on validation set
        prec1 = validate(val_loader, model, criterion)

        # remember best prec@1 and save checkpoint
        is_best = prec1 > best_prec1
        best_prec1 = max(prec1, best_prec1)

        if epoch > 0 and epoch % args.save_every == 0:
            save_checkpoint({
                'epoch': epoch + 1,
                'state_dict': model.state_dict(),
                'best_prec1': best_prec1,
            }, is_best, filename=os.path.join(args.save_dir, 'densenet_checkpoint.th'))

        if args.compute_sv and epoch % args.save_every == 0:
            sv,  svmax, sv20, sv50, sv80, kclip= utils.get_sv(model, size_hook)
            training_sv.append(sv)
            training_svmax.append(svmax)
            training_sv20.append(sv20)
            training_sv50.append(sv50)
            training_sv80.append(sv80)
            training_kclip.append(kclip)
            np.save(os.path.join(args.save_dir, 'sv.npy'), training_sv)
            np.save(os.path.join(args.save_dir, 'sv_svmax.npy'), training_svmax)
            np.save(os.path.join(args.save_dir, 'sv_sv20.npy'), training_sv20)
            np.save(os.path.join(args.save_dir, 'sv_sv50.npy'), training_sv50)
            np.save(os.path.join(args.save_dir, 'sv_sv80.npy'), training_sv80)
            np.save(os.path.join(args.save_dir, 'sv_kclip.npy'), training_kclip)

        if args.compute_ntk and epoch % args.save_every == 0:
            training_ntk_eig.append(utils.get_ntk_eig(ntk_loader, [model], train_mode = True, num_batch=1, num_classes=num_classes, samples_per_class=1, grasp_fetch=grasp_fetch))
            np.save(os.path.join(args.save_dir, 'ntk_eig.npy'), training_ntk_eig)
        


    print('[*] {} epochs of dense network pre-training done'.format(pre_epochs))
    np.save(os.path.join(args.save_dir, 'trainloss.npy'), training_loss)

    # densenet_checkpoint = torch.load(os.path.join(args.save_dir, 'densenet_checkpoint.th'))
    # model.load_state_dict(densenet_checkpoint['state_dict'])
    # print('Model loaded!')
    # Obtain lottery ticket by magnitude pruning
    if args.prune_method == 'NONE':
        snip.apply_mag_prune(args, model)
        # reinitialize
        init_checkpoint = torch.load(init_path)
        model.load_state_dict(init_checkpoint['state_dict'])
        print('Model reinitialized!')
    elif args.prune_method == 'SNIP':
        init_checkpoint = torch.load(init_path)
        model.load_state_dict(init_checkpoint['state_dict'])
        print('Model reinitialized!')
        snip.apply_snip(args, [model], train_loader, criterion, num_classes=num_classes)
        # attack.shuffle_mask(model)
    elif args.prune_method == 'RAND':
        init_checkpoint = torch.load(init_path)
        model.load_state_dict(init_checkpoint['state_dict'])
        print('Model reinitialized!')
        snip.apply_rand_prune([model], args.sparse_lvl)
    elif args.prune_method == 'GRASP':
        init_checkpoint = torch.load(init_path)
        model.load_state_dict(init_checkpoint['state_dict'])
        print('Model reinitialized!')
        snip.apply_grasp(args, [model], train_loader, criterion, num_classes=num_classes)
    elif args.prune_method == 'Zen':
        zenprune.apply_zenprune(args, [model], train_loader)
        # zenprune.apply_cont_zenprune(args, [model], train_loader)
        # zenprune.apply_zentransfer(args, [model], train_loader)
        # init_checkpoint = torch.load(init_path)
        # model.load_state_dict(init_checkpoint['state_dict'])
        # print('Model reinitialized!')
    elif args.prune_method == 'Mag':
        snip.apply_mag_prune(args, model)
        init_checkpoint = torch.load(init_path)
        model.load_state_dict(init_checkpoint['state_dict'])
        print('Model reinitialized!')
    elif args.prune_method == 'Synflow':
        synflow.apply_synflow(args, model)

    print('{} done, sparsity of the current model: {}.'.format(args.prune_method, utils.check_sparsity(model)))
    
    if args.compute_lrs:
        # training_lrs = []
        # lrc_model = utils.Linear_Region_Collector(train_loader, input_size=(args.batch_size,*input_shape), sample_batch=300)
        # lrc_model.reinit(models=[model])
        # lrs = lrc_model.forward_batch_sample()[0]
        # training_lrs.append(lrs)
        # lrc_model.clear_hooks()
        # print('[*] Current number of linear regions:{}'.format(lrs))
        GAP_zen, output_zen = utils.get_zenscore(model, train_loader, args.arch, num_classes)
        print('[*] After pruning: GAP_zen:{:e}, output_zen:{:e}'.format(GAP_zen, output_zen))

    if args.evaluate:
        validate(val_loader, model, criterion)
        return

    # Recreate optimizer and learning scheduler
    optimizer = torch.optim.SGD(model.parameters(), args.lr,
                                momentum=args.momentum,
                                nesterov=True,
                                weight_decay=args.weight_decay)

    if args.dataset == 'tiny-imagenet':
        lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer,
                                                        milestones=[100, 150], last_epoch=args.start_epoch - 1)
                                                        # milestones=[30, 60, 80], last_epoch=args.start_epoch - 1)
    elif args.dataset == 'cifar100':
        lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer,
                                                        milestones=[60, 120], gamma=0.2, last_epoch=args.start_epoch - 1)
    else:
        lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer,
                                                        milestones=[80, 120], last_epoch=args.start_epoch - 1)
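    # last_epoch=args.start_epoch - 1 resumes the milestone schedule mid-stream when restarting from a checkpoint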

    for epoch in range(args.epochs):
    # for epoch in range(args.pre_epochs, args.epochs):

        # train for one epoch
        print('current lr {:.5e}'.format(optimizer.param_groups[0]['lr']))
        train(train_loader, model, criterion, optimizer, epoch, track=training_loss, ortho=args.ortho)
        lr_scheduler.step()

        # evaluate on validation set
        prec1 = validate(val_loader, model, criterion)

        # remember best prec@1 and save checkpoint
        is_best = prec1 > best_prec1
        best_prec1 = max(prec1, best_prec1)

        if epoch > 0 and epoch % args.save_every == 0:
            save_checkpoint({
                'epoch': epoch + 1,
                'state_dict': model.state_dict(),
                'best_prec1': best_prec1,
            }, is_best, filename=os.path.join(args.save_dir, 'sparsenet_checkpoint.th'))
            # if args.prune_method !='NONE':
            #     print(utils.check_layer_sparsity(model))   

        # save_checkpoint({
        #     'state_dict': model.state_dict(),
        #     'best_prec1': best_prec1,
        # }, is_best, filename=os.path.join(args.save_dir, 'model.th'))

        if args.compute_sv and epoch % args.save_every == 0:
            sv, svmax, sv20, sv50, sv80, kclip = utils.get_sv(model, size_hook)
            training_sv.append(sv)
            training_svmax.append(svmax)
            training_sv20.append(sv20)
            training_sv50.append(sv50)
            training_sv80.append(sv80)
            training_kclip.append(kclip)
            np.save(os.path.join(args.save_dir, 'sv.npy'), training_sv)
            np.save(os.path.join(args.save_dir, 'sv_svmax.npy'), training_svmax)
            np.save(os.path.join(args.save_dir, 'sv_sv20.npy'), training_sv20)
            np.save(os.path.join(args.save_dir, 'sv_sv50.npy'), training_sv50)
            np.save(os.path.join(args.save_dir, 'sv_sv80.npy'), training_sv80)
            np.save(os.path.join(args.save_dir, 'sv_kclip.npy'), training_kclip)

        if args.compute_ntk and epoch % args.save_every == 0:
            training_ntk_eig.append(utils.get_ntk_eig(ntk_loader, [model], train_mode=True, num_batch=1, num_classes=num_classes, samples_per_class=1, grasp_fetch=grasp_fetch))
            np.save(os.path.join(args.save_dir, 'ntk_eig.npy'), training_ntk_eig)

        # if args.compute_lrs and epoch % args.save_every == 0:
        #     lrc_model.reinit(models=[model])
        #     lrs = lrc_model.forward_batch_sample()[0]
        #     training_lrs.append(lrs)
        #     lrc_model.clear_hooks()
        #     print('[*] Current number of linear regions:{}'.format(lrs))
    np.save(os.path.join(args.save_dir, 'trainloss.npy'), training_loss)