Example #1
def activation_pruning_experiments(args: ExperimentArgs, protocol: PruneProtocol,
                                   train_loader: DataLoader, test_loader: DataLoader,
                                   device=0):
    protocol.prune_by = 'online'
    model = load_model(args.arch, args.final_model_path, device=device)
    model.eval()

    # accuracy before pruning (commented out in this example run)
    # print('Testing final model accuracy before pruning')
    # correct, total = test(model, test_loader, device=device)
    # acc_no_prune = correct.sum() / total.sum() * 100.
    # print('Model accuracy before pruning: %.2f' % acc_no_prune)

    pruner = ModulePruner(protocol,
                          device=device,
                          dataloader=test_loader,
                          network=model)

    print('Testing final model accuracy with real-time activation pruning')
    with pruner.prune(clear_on_exit=True):
        retrain_bn(model, train_loader, device=device)
        model.eval()
        correct, total = test(model, test_loader, device=device)
    prune_acc = correct.sum() / total.sum() * 100.
    print('Model accuracy with pruning: %.2f' % prune_acc)
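
For readers without the project code: "online" activation pruning of this kind can be pictured with plain PyTorch forward hooks. The sketch below only illustrates the idea and is not the project's ModulePruner or PruneProtocol; the 0.5 ratio and the toy network are made up.

import torch
import torch.nn as nn

def activation_pruning_hook(ratio):
    # zero the `ratio` fraction of smallest-magnitude activations in each sample
    def hook(module, inputs, output):
        flat = output.flatten(start_dim=1).abs()
        k = int(flat.shape[1] * ratio)
        if k == 0:
            return output
        thresholds = flat.kthvalue(k, dim=1, keepdim=True).values
        mask = (flat > thresholds).view_as(output)
        return output * mask
    return hook

toy_net = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 4))
handles = [m.register_forward_hook(activation_pruning_hook(0.5))
           for m in toy_net.modules() if isinstance(m, nn.ReLU)]
out = toy_net(torch.randn(2, 8))
for h in handles:
    h.remove()  # analogous to clear_on_exit=True above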
Example #2
def subnetwork_experiments(args: ExperimentArgs,
                           init_protocol: PruneProtocol, final_protocol: PruneProtocol,
                           train_loader: DataLoader, test_loader: DataLoader,
                           device=0):
    # load final model
    final_model = load_model(args.arch, args.final_model_path, device=device)
    final_model.eval()
    print('Loaded final model from %s' % args.final_model_path)

    # load initial model
    init_model = load_model(args.arch, args.init_model_path, device=device)
    print('Loaded initialized model from %s' % args.init_model_path)

    # test final model accuracy before pruning
    # (commented out; the hard-coded acc_no_prune below is presumably from an earlier run)
    """
    print('Testing final model accuracy before pruning')
    correct, total = test(final_model, test_loader, device=device)
    acc_no_prune = correct.sum() / total.sum() * 100.
    print('Model accuracy before pruning: %.2f' % acc_no_prune)
    """
    acc_no_prune = 66.76

    # get pruners
    final_pruner, init_pruner = get_pruners(final_protocol, init_protocol,
                                            device=device,
                                            networks=(final_model, init_model))

    # build a pruner for the final model (this overrides the one returned by get_pruners above)
    final_pruner = ModulePruner(final_protocol,
                                device=device,
                                network=final_model)

    print('Computing prune masks for final model...')
    final_masks = final_pruner.compute_prune_masks(reset=False)

    # test final model performance when final prune mask is used
    print('Testing final model performance after pruning from final model...')
    with final_pruner.prune(clear_on_exit=True):
        retrain_bn(final_model, train_loader, device=device)
        final_model.eval()
        correct, total = test(final_model, test_loader, device=device)
    final_acc = correct.sum() / total.sum() * 100.
    print('Model accuracy using pruning on final model: %.2f' % final_acc)

    # compute initial prune masks
    print('Computing prune masks for initialized model...')
    init_masks = init_pruner.compute_prune_masks(reset=not args.retrain)

    # test final model performance when initial prune mask is used
    print('Testing final model performance after pruning from model at initialization...')
    final_pruner.set_prune_masks(**init_masks)
    with final_pruner.prune(clear_on_exit=False):
        retrain_bn(final_model, train_loader, device=device)
        final_model.eval()
        correct, total = test(final_model, test_loader, device=device)
    init_acc = correct.sum() / total.sum() * 100.
    print('Model accuracy using pruning at model initialization: %.2f' % init_acc)

    # compute overlap between prune masks
    print('Computing overlap between prune masks of initialized model and final model...')
    mask_accuracy_dict = compute_mask_accuracy(init_masks, final_masks)
    print('Mean mask accuracy: %.2f' % mask_accuracy_dict['mean_accuracy'])
    print('Mean mask retained recall: %.2f' % mask_accuracy_dict['mean_retained_recall'])
    print('Mean mask pruned recall: %.2f' % mask_accuracy_dict['mean_pruned_recall'])

    # test final model performance with a random prune mask
    # (the evaluation below is commented out, so random_acc stays None)
    # print('Testing final model performance after random pruning...')

    def make_random_mask(mask):
        # boolean mask of the same shape with a random `prune_ratio` fraction set to True
        rand_mask = torch.zeros_like(mask, dtype=torch.bool)
        flat_mask = rand_mask.view(-1)  # a view, so the writes below modify rand_mask
        length = flat_mask.shape[0]
        flat_mask[np.random.choice(length, int(length * init_protocol.prune_ratio), replace=False)] = True
        return rand_mask

    # test final model with random pruning
    """
    random_masks = {name: make_random_mask(mask) for name, mask in init_masks.items()}
    final_pruner.set_prune_masks(**random_masks)
    with final_pruner.prune(clear_on_exit=True):
        correct, total = test(final_model, test_loader, device=device)
    random_acc = correct.sum() / total.sum() * 100.
    print('Model accuracy using random pruning: %.2f' % random_acc)
    """
    random_acc = None

    if args.save_results:
        print('Saving experiment results to %s' % args.results_filepath)
        np.savez(args.results_filepath,
                 final_accuracy=acc_no_prune,
                 init_subnet_accuracy=init_acc,
                 final_subnet_accuracy=final_acc,
                 random_subnet_accuracy=random_acc,
                 **mask_accuracy_dict)
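
compute_mask_accuracy above is project code; a minimal, self-contained sketch of the kind of overlap metrics it reports (overall agreement plus recall of pruned and retained entries) could look like the following, assuming each mask marks pruned entries with True.

import torch

def mask_overlap_metrics(init_mask, final_mask):
    # treat the final mask as ground truth and measure agreement of the init mask
    accuracy = (init_mask == final_mask).float().mean().item()
    pruned_recall = (init_mask & final_mask).sum().item() / max(final_mask.sum().item(), 1)
    retained_recall = (~init_mask & ~final_mask).sum().item() / max((~final_mask).sum().item(), 1)
    return {'accuracy': accuracy,
            'pruned_recall': pruned_recall,
            'retained_recall': retained_recall}

# toy example with two random boolean masks
a = torch.rand(4, 4) < 0.5
b = torch.rand(4, 4) < 0.5
print(mask_overlap_metrics(a, b))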
Example #3
def sparse_grad_training(init_model_args: ModelInitArgs, train_args: SparseTrainingArgs,
                         train_loader: DataLoader, test_loader: DataLoader,
                         device=0):
    network = initialize_model(init_model_args, device=device)
    network.train()

    def get_optim(lr):
        return torch.optim.SGD(network.parameters(),
                               lr=lr,
                               nesterov=train_args.nesterov,
                               momentum=train_args.momentum,
                               weight_decay=train_args.weight_decay)

    lr = train_args.lr
    optim = get_optim(lr)
    loss_fn = torch.nn.CrossEntropyLoss()
    total, correct = [], []
    torch.manual_seed(train_args.seed)  # seed dataloader shuffling

    for e in range(train_args.epochs):
        # check for lr decay
        if e in train_args.decay_epochs:
            lr /= train_args.lr_decay
            optim = get_optim(lr)

        print('Beginning epoch %d/%d' % (e + 1, train_args.epochs))
        losses = []
        perc_grad_pruned = []

        for i, x, y in tqdm(train_loader):
            x, y = x.to(device), y.to(device)
            out = network(x)
            loss = loss_fn(out, y)
            loss.backward()

            # threshold weight gradients
            total_grad_pruned = 0
            total_grad = 0
            for module in get_named_modules_from_network(network).values():
                with torch.no_grad():
                    abs_grad = module.weight.grad.abs()
                    mean_grad = abs_grad.mean()
                    max_grad = abs_grad.max()
                    threshold = (mean_grad * train_args.mean_max_coef + max_grad) / (train_args.mean_max_coef + 1)
                    grad_mask = abs_grad < threshold
                    module.weight.grad[grad_mask] = 0.0

                    total_grad += grad_mask.numel()
                    total_grad_pruned += grad_mask.sum().item()

            perc_grad_pruned += [total_grad_pruned / total_grad]
            #print(perc_grad_pruned[-1])

            optim.step()
            optim.zero_grad()
            losses += [loss.item()]

        print('Mean loss for epoch %d: %.4f' % (e + 1, sum(losses) / len(losses)))
        print('Average percent of gradients pruned: %.2f' % (sum(perc_grad_pruned) / len(perc_grad_pruned) * 100.))
        print('Test accuracy for epoch %d:' % (e + 1), end=' ')

        network.eval()
        correct_, total_ = test(network, test_loader, device=device)
        network.train()
        total += [total_]
        correct += [correct_]
        if train_args.save_acc:
            np.savez(train_args.acc_save_path, correct=np.stack(correct, axis=0), total=np.stack(total, axis=0))
        save_model(network, train_args.model_save_path, device=device)
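
As a standalone illustration of the thresholding step inside the loop above, here is the same mean/max interpolation applied to a single toy linear layer; the value mean_max_coef = 4.0 and the layer sizes are arbitrary choices for the sketch, not values from the project.

import torch

layer = torch.nn.Linear(10, 10)
loss = layer(torch.randn(4, 10)).sum()
loss.backward()

mean_max_coef = 4.0
with torch.no_grad():
    abs_grad = layer.weight.grad.abs()
    # threshold interpolates between the mean and the max gradient magnitude
    threshold = (abs_grad.mean() * mean_max_coef + abs_grad.max()) / (mean_max_coef + 1)
    pruned = abs_grad < threshold
    layer.weight.grad[pruned] = 0.0
    print('fraction of gradients pruned:', pruned.float().mean().item())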
Example #4
        cifar100_test_loader = get_test_dataloader(CIFAR100_TRAIN_MEAN,
                                                   CIFAR100_TRAIN_STD,
                                                   num_workers=1,
                                                   batch_size=args.batch_size,
                                                   shuffle=True,
                                                   root="../data/test/",
                                                   classes=task)

        print("Evaluating on task {}...".format(task_number))

        for task_trained_on in range(task_number, args.n_tasks):
            for model_name in ["naive", "foolish", "EWC"]:
                ewc_lambda = 0.4 if model_name == "EWC" else None

                model = Net()
                model.load_state_dict(
                    torch.load("models/{}_{}_task{}.pt".format(
                        args.exp_name, model_name, task_trained_on)))
                accuracy = test(model, device, cifar100_test_loader)

                results_df.loc["{}_train{}_eval{}".format(
                    model_name, task_trained_on, task_number), :] = [
                        args.exp_name, task_number, task_trained_on,
                        model_name, ewc_lambda, accuracy
                    ]

        print("Runtime until task {}: {}".format(task_trained_on,
                                                 time.time() - start_time))

    results_df.to_csv("results_{}.csv".format(args.exp_name))
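
results_df and start_time come from earlier code that is not shown in this excerpt. A minimal pandas sketch of the same row-insertion pattern, with assumed column names and made-up numbers, would be:

import pandas as pd

results_df = pd.DataFrame(columns=['exp_name', 'eval_task', 'trained_on',
                                   'model', 'ewc_lambda', 'accuracy'])
# each row is keyed by a label that encodes model, training task, and evaluation task
results_df.loc['EWC_train0_eval0', :] = ['demo', 0, 0, 'EWC', 0.4, 0.61]
results_df.loc['naive_train0_eval0', :] = ['demo', 0, 0, 'naive', None, 0.58]
results_df.to_csv('results_demo.csv')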
Example #5
    help="Test one of the models from bipart_fer13, bipart_ck, dcnn_fer13")

args = parser.parse_args()

if args.save_data:
    option = args.save_data
    if option == 'fer2013':
        save_data.save_fer2013()
    elif option == 'ck_plus':
        save_data.save_ck_plus()
    elif option == 'hog_bipart':
        save_data.save_hog_bipart()
    elif option == 'fer2013_bipart':
        save_data.save_fer2013_bipart()
elif args.train_model:
    option = args.train_model
    if option == 'bipart_cnn':
        train_models.train_bipart_cnn()
    elif option == 'dcnn_fer13':
        train_models.train_dcnn_fer()
elif args.test_model:
    option = args.test_model
    if option == 'bipart_fer13':
        train_models.test('model_data/cnn_by_parts_fer13.pt')
    elif option == 'dcnn_fer13':
        train_models.test('model_data/cnn_fer2013.pt')
    elif option == 'bipart_ck+':
        train_models.test('model_data/cnn_by_parts_CK+.pt')
else:
    print('No argument found!')
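
The if/elif chains above can also be written as lookup tables. This is only an alternative sketch; it assumes the same save_data and train_models modules and the args object that the script already defines.

SAVE_ACTIONS = {
    'fer2013': save_data.save_fer2013,
    'ck_plus': save_data.save_ck_plus,
    'hog_bipart': save_data.save_hog_bipart,
    'fer2013_bipart': save_data.save_fer2013_bipart,
}
TEST_MODELS = {
    'bipart_fer13': 'model_data/cnn_by_parts_fer13.pt',
    'dcnn_fer13': 'model_data/cnn_fer2013.pt',
    'bipart_ck+': 'model_data/cnn_by_parts_CK+.pt',
}

if args.save_data in SAVE_ACTIONS:
    SAVE_ACTIONS[args.save_data]()
elif args.test_model in TEST_MODELS:
    train_models.test(TEST_MODELS[args.test_model])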
Example #6
def grad_alignment_study(init_model_args: ModelInitArgs,
                         train_args: TrainingArgs,
                         train_loader: DataLoader,
                         test_loader: DataLoader,
                         device=0):
    network = initialize_model(init_model_args, device=device)
    network.train()

    def get_optim(lr):
        return torch.optim.SGD(network.parameters(),
                               lr=lr,
                               nesterov=train_args.nesterov,
                               momentum=train_args.momentum,
                               weight_decay=train_args.weight_decay)

    lr = train_args.lr
    optim = get_optim(lr)
    loss_fn = torch.nn.CrossEntropyLoss(reduction='none')
    total, correct = [], []
    torch.manual_seed(train_args.seed)  # seed dataloader shuffling

    for e in range(train_args.epochs):
        # check for lr decay
        if e in train_args.decay_epochs:
            lr /= train_args.lr_decay
            optim = get_optim(lr)

        print('Beginning epoch %d/%d' % (e + 1, train_args.epochs))
        losses = []

        for idx, (i, x, y) in enumerate(tqdm(train_loader)):
            x, y = x.to(device), y.to(device)
            out = network(x)
            loss = loss_fn(out, y)

            if e in train_args.test_epochs and idx == 0:
                # explore sample gradients
                print(
                    '\nComputing gradient alignment across network modules...')
                modules = find_network_modules_by_name(network,
                                                       train_args.test_layers)
                grads = {n: [] for n in train_args.test_layers}
                for sample_idx in range(100):  # assumes a batch size of at least 100
                    loss[sample_idx].backward(retain_graph=True)

                    for n, m in zip(train_args.test_layers, modules):
                        # clone so the stored copy is unaffected by the zero_grad() below
                        grads[n] += [m.weight.grad.detach().cpu().clone()]

                    optim.zero_grad()

                mean = {}
                magnitude = {}
                variance = {}
                for n, grad in grads.items():
                    mean[n] = torch.stack(grad).mean(dim=0).numpy()
                    magnitude[n] = torch.stack(grad).abs().mean(dim=0).numpy()
                    # summed (not averaged) squared deviation from the per-sample mean
                    variance[n] = ((torch.stack(grad).numpy() -
                                    mean[n][None])**2).sum(axis=0)

                np.savez(
                    'gradient-study/alignment/metrics-epoch_%d.npz' % e, **{
                        n: [
                            np.array(('mean', mean[n]), dtype=object),
                            np.array(('magnitude', magnitude[n]), dtype=object),
                            np.array(('variance', variance[n]), dtype=object)
                        ]
                        for n in train_args.test_layers
                    })

            loss.mean().backward()

            optim.step()
            optim.zero_grad()
            losses += [loss.mean().item()]

        print('Mean loss for epoch %d: %.4f' % (e + 1, sum(losses) / len(losses)))
        print('Test accuracy for epoch %d:' % (e + 1), end=' ')

        network.eval()
        correct_, total_ = test(network, test_loader, device=device)
        network.train()
        total += [total_]
        correct += [correct_]
        if train_args.save_acc:
            np.savez(train_args.acc_save_path,
                     correct=np.stack(correct, axis=0),
                     total=np.stack(total, axis=0))
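
A self-contained sketch of the per-sample gradient statistics gathered above, using one toy linear layer in place of the network's test_layers (all sizes and the batch of random data are made up):

import torch

layer = torch.nn.Linear(6, 3)
x, y = torch.randn(16, 6), torch.randint(0, 3, (16,))
loss = torch.nn.CrossEntropyLoss(reduction='none')(layer(x), y)

per_sample_grads = []
for i in range(x.shape[0]):
    layer.zero_grad()
    loss[i].backward(retain_graph=True)  # one backward pass per sample
    per_sample_grads.append(layer.weight.grad.detach().clone())

stacked = torch.stack(per_sample_grads)
mean = stacked.mean(dim=0)
magnitude = stacked.abs().mean(dim=0)
variance = ((stacked - mean[None]) ** 2).sum(dim=0)  # summed squared deviation, as above
print(mean.shape, magnitude.shape, variance.shape)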