Example #1
def run(dataset, net_type):

    # Hyper Parameter settings
    layer_type = cfg.layer_type
    activation_type = cfg.activation_type

    train_ens = cfg.train_ens
    valid_ens = cfg.valid_ens
    n_epochs = cfg.n_epochs
    lr_start = cfg.lr_start
    num_workers = cfg.num_workers
    valid_size = cfg.valid_size
    batch_size = cfg.batch_size
    beta_type = cfg.beta_type

    trainset, testset, inputs, outputs = data.getDataset(dataset)
    train_loader, valid_loader, test_loader = data.getDataloader(
        trainset, testset, valid_size, batch_size, num_workers)
    net = getModel(net_type, inputs, outputs, layer_type,
                   activation_type).to(device)

    ckpt_dir = f'checkpoints/{dataset}/bayesian'
    ckpt_name = f'checkpoints/{dataset}/bayesian/model_{net_type}_{layer_type}.pt'

    if not os.path.exists(ckpt_dir):
        os.makedirs(ckpt_dir, exist_ok=True)

    criterion = metrics.ELBO(len(trainset)).to(device)
    optimizer = Adam(net.parameters(), lr=lr_start)
    lr_sched = lr_scheduler.ReduceLROnPlateau(optimizer,
                                              patience=6,
                                              verbose=True)
    valid_loss_max = np.inf
    for epoch in range(n_epochs):  # loop over the dataset multiple times
        cfg.curr_epoch_no = epoch

        train_loss, train_acc, train_kl = train_model(net,
                                                      optimizer,
                                                      criterion,
                                                      train_loader,
                                                      num_ens=train_ens,
                                                      beta_type=beta_type)
        valid_loss, valid_acc = validate_model(net,
                                               criterion,
                                               valid_loader,
                                               num_ens=valid_ens)
        lr_sched.step(valid_loss)

        print(
            'Epoch: {} \tTraining Loss: {:.4f} \tTraining Accuracy: {:.4f} \tValidation Loss: {:.4f} \tValidation Accuracy: {:.4f} \ttrain_kl_div: {:.4f}'
            .format(epoch, train_loss, train_acc, valid_loss, valid_acc,
                    train_kl))

        # save model if validation loss has decreased
        if valid_loss <= valid_loss_max:
            print(
                'Validation loss decreased ({:.6f} --> {:.6f}).  Saving model ...'
                .format(valid_loss_max, valid_loss))
            torch.save(net.state_dict(), ckpt_name)
            valid_loss_max = valid_loss
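
The `metrics.ELBO` criterion used above is defined elsewhere in the repository. As an illustration only, a minimal sketch of an ELBO-style objective for Bayes-by-Backprop is shown below; it assumes the criterion is called with log-softmax outputs, targets, the network's accumulated KL divergence, and a KL weight `beta`, which may not match the actual signature in `metrics`.

# Hypothetical sketch, not the repository's metrics.ELBO: negative log-likelihood
# scaled by the training-set size plus a beta-weighted KL term.
import torch.nn as nn
import torch.nn.functional as F


class ELBOSketch(nn.Module):
    def __init__(self, train_size):
        super().__init__()
        self.train_size = train_size

    def forward(self, log_probs, target, kl, beta):
        # log_probs: log-softmax outputs; kl: summed KL divergence of all Bayesian layers.
        nll = F.nll_loss(log_probs, target, reduction='mean')
        return self.train_size * nll + beta * kl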
Example #2
def train_splitted(num_tasks, bayesian=True, net_type='lenet'):
    assert 10 % num_tasks == 0

    # Hyper Parameter settings
    train_ens = cfg.train_ens
    valid_ens = cfg.valid_ens
    n_epochs = cfg.n_epochs
    lr_start = cfg.lr_start

    if bayesian:
        ckpt_dir = f"checkpoints/MNIST/bayesian/splitted/{num_tasks}-tasks/"
    else:
        ckpt_dir = f"checkpoints/MNIST/frequentist/splitted/{num_tasks}-tasks/"
    if not os.path.exists(ckpt_dir):
        os.makedirs(ckpt_dir, exist_ok=True)

    loaders, datasets = mix_utils.get_splitmnist_dataloaders(
        num_tasks, return_datasets=True)
    models = mix_utils.get_splitmnist_models(num_tasks,
                                             bayesian=bayesian,
                                             pretrained=False,
                                             net_type=net_type)

    for task in range(1, num_tasks + 1):
        print(f"Training task-{task}..")
        trainset, testset, _, _ = datasets[task - 1]
        train_loader, valid_loader, _ = loaders[task - 1]
        net = models[task - 1]
        net = net.to(device)
        ckpt_name = ckpt_dir + f"model_{net_type}_{num_tasks}.{task}.pt"

        criterion = (metrics.ELBO(len(trainset))
                     if bayesian else nn.CrossEntropyLoss()).to(device)
        optimizer = Adam(net.parameters(), lr=lr_start)
        valid_loss_max = np.inf
        for epoch in range(n_epochs):  # loop over the dataset multiple times
            utils.adjust_learning_rate(
                optimizer, metrics.lr_linear(epoch, 0, n_epochs, lr_start))

            if bayesian:
                train_loss, train_acc, train_kl = train_bayesian(
                    net, optimizer, criterion, train_loader, num_ens=train_ens)
                valid_loss, valid_acc = validate_bayesian(net,
                                                          criterion,
                                                          valid_loader,
                                                          num_ens=valid_ens)
                print(
                    'Epoch: {} \tTraining Loss: {:.4f} \tTraining Accuracy: {:.4f} \tValidation Loss: {:.4f} \tValidation Accuracy: {:.4f} \ttrain_kl_div: {:.4f}'
                    .format(epoch, train_loss, train_acc, valid_loss,
                            valid_acc, train_kl))
            else:
                train_loss, train_acc = train_frequentist(
                    net, optimizer, criterion, train_loader)
                valid_loss, valid_acc = validate_frequentist(
                    net, criterion, valid_loader)
                print(
                    'Epoch: {} \tTraining Loss: {:.4f} \tTraining Accuracy: {:.4f} \tValidation Loss: {:.4f} \tValidation Accuracy: {:.4f}'
                    .format(epoch, train_loss, train_acc, valid_loss,
                            valid_acc))

            # save model if validation loss has decreased
            if valid_loss <= valid_loss_max:
                print(
                    'Validation loss decreased ({:.6f} --> {:.6f}).  Saving model ...'
                    .format(valid_loss_max, valid_loss))
                torch.save(net.state_dict(), ckpt_name)
                valid_loss_max = valid_loss

        print(f"Done training task-{task}")
Example #3
def run(dataset, net_type, train=True):

    # Hyper Parameter settings
    train_ens = cfg.train_ens
    valid_ens = cfg.valid_ens
    test_ens = cfg.test_ens
    n_epochs = cfg.n_epochs
    lr_start = cfg.lr_start
    num_workers = cfg.num_workers
    valid_size = cfg.valid_size
    batch_size = cfg.batch_size

    trainset, testset, inputs, outputs = data.getDataset(dataset)
    train_loader, valid_loader, test_loader = data.getDataloader(
        trainset, testset, valid_size, batch_size, num_workers)
    net = getModel(net_type, inputs, outputs).to(device)

    ckpt_dir = f'checkpoints/{dataset}/bayesian'
    ckpt_name = f'checkpoints/{dataset}/bayesian/model_{net_type}.pt'

    if not os.path.exists(ckpt_dir):
        os.makedirs(ckpt_dir, exist_ok=True)

    criterion = metrics.ELBO(len(trainset)).to(device)

    if train:
        optimizer = Adam(net.parameters(), lr=lr_start)
        valid_loss_max = np.inf
        for epoch in range(n_epochs):  # loop over the dataset multiple times
            cfg.curr_epoch_no = epoch
            utils.adjust_learning_rate(
                optimizer, metrics.lr_linear(epoch, 0, n_epochs, lr_start))

            train_loss, train_acc, train_kl = train_model(net,
                                                          optimizer,
                                                          criterion,
                                                          train_loader,
                                                          num_ens=train_ens)
            valid_loss, valid_acc = validate_model(net,
                                                   criterion,
                                                   valid_loader,
                                                   num_ens=valid_ens)

            print(
                'Epoch: {} \tTraining Loss: {:.4f} \tTraining Accuracy: {:.4f} \tValidation Loss: {:.4f} \tValidation Accuracy: {:.4f} \ttrain_kl_div: {:.4f}'
                .format(epoch, train_loss, train_acc, valid_loss, valid_acc,
                        train_kl))
            print(
                'Training Loss: {:.4f} \tTraining Likelihood Loss: {:.4f} \tTraining KL Loss: {:.4f}'
                .format(train_loss, train_loss - train_kl, train_kl))

            # save model if validation loss has decreased
            if valid_loss <= valid_loss_max:
                print(
                    'Validation loss decreased ({:.6f} --> {:.6f}).  Saving model ...'
                    .format(valid_loss_max, valid_loss))
                torch.save(net.state_dict(), ckpt_name)
                valid_loss_max = valid_loss

    # test saved model
    best_model = getModel(net_type, inputs, outputs).to(device)
    best_model.load_state_dict(torch.load(ckpt_name))
    test_loss, test_acc = test_model(best_model,
                                     criterion,
                                     test_loader,
                                     num_ens=test_ens)
    print('Test Loss: {:.4f} \tTest Accuracy: {:.4f} '.format(
        test_loss, test_acc))
    print('Test uncertainties:')
    test_uncertainities(best_model, test_loader, num_ens=10)
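
`test_uncertainities` is defined elsewhere in the repository. For illustration only, uncertainty for a network like this is commonly estimated from several stochastic forward passes; the sketch below assumes it is given a stack of per-pass softmax outputs and splits predictive entropy into aleatoric and epistemic parts.

# Hypothetical sketch of ensemble-based uncertainty decomposition; this is not the
# repository's test_uncertainities.
import torch


def predictive_uncertainty_sketch(prob_samples):
    # prob_samples: tensor of shape (num_ens, batch, classes), one slice per
    # stochastic forward pass through the Bayesian network.
    mean_probs = prob_samples.mean(dim=0)
    # Total (predictive) entropy of the averaged prediction.
    total = -(mean_probs * mean_probs.clamp_min(1e-12).log()).sum(dim=-1)
    # Expected per-pass entropy (aleatoric part).
    aleatoric = -(prob_samples * prob_samples.clamp_min(1e-12).log()).sum(dim=-1).mean(dim=0)
    # Mutual information between prediction and weights (epistemic part).
    epistemic = total - aleatoric
    return total, aleatoric, epistemic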
Example #4
def run(dataset,
        net_type,
        checkpoint='None',
        prune_criterion='EmptyCrit',
        pruning_limit=0.0,
        lower_limit=0.5,
        local_pruning=False):
    # Hyper Parameter settings
    layer_type = cfg.layer_type
    activation_type = cfg.activation_type
    priors = cfg.priors

    train_ens = cfg.train_ens
    valid_ens = cfg.valid_ens
    n_epochs = cfg.n_epochs
    lr_start = cfg.lr_start
    num_workers = cfg.num_workers
    valid_size = cfg.valid_size
    batch_size = cfg.batch_size
    beta_type = cfg.beta_type

    # LOAD STRUCTURED PRUNED MODEL
    if net_type == 'customconv6':
        import pickle
        with open('/nfs/homedirs/ayle/model_conv6_0.5.pickle', 'rb') as f:
            pre_pruned_model = pickle.load(f)
    else:
        pre_pruned_model = None

    trainset, testset, inputs, outputs = data.getDataset(dataset)
    train_loader, valid_loader, test_loader = data.getDataloader(
        trainset, testset, valid_size, batch_size, num_workers)
    net = getModel(net_type, inputs, outputs, priors, layer_type,
                   activation_type, pre_pruned_model).to(device)

    # LOAD PRUNED UNSTRUCTURED MASK
    # import pickle
    # with open('/nfs/homedirs/ayle/mask.pickle', 'rb') as f:
    #     mask = pickle.load(f)
    #
    # mask_keys = list(mask.keys())
    #
    # count = 0
    # for name, module in net.named_modules():
    #     if name.startswith('conv') or name.startswith('fc'):
    #         module.mask = mask[mask_keys[count]]
    #         count += 1
    #         print(module.mask.sum().float() / torch.numel(module.mask))

    ckpt_dir = f'checkpoints/{dataset}/bayesian'
    ckpt_name = f'checkpoints/{dataset}/bayesian/model_{net_type}_{layer_type}_{activation_type}_{prune_criterion}_{pruning_limit}_after.pt'

    if not os.path.exists(ckpt_dir):
        os.makedirs(ckpt_dir, exist_ok=True)

    if checkpoint != 'None':
        net.load_state_dict(torch.load(checkpoint))

    if layer_type == 'mgp':
        criterion = metrics.ELBO2(len(trainset)).to(device)
    else:
        criterion = metrics.ELBO(len(trainset)).to(device)
    optimizer = Adam(net.parameters(), lr=lr_start)
    lr_sched = lr_scheduler.ReduceLROnPlateau(optimizer,
                                              patience=6,
                                              verbose=True)
    valid_loss_max = np.inf

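    # Select the pruning criterion; SNIPit and SNR prune the network right away,
    # while StructuredSNR is only instantiated here (its prune call is commented out below).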
    if prune_criterion == 'SNIPit':
        pruning_criterion = SNIPit(limit=pruning_limit,
                                   model=net,
                                   lower_limit=lower_limit)
        pruning_criterion.prune(pruning_limit,
                                train_loader=train_loader,
                                local=local_pruning)
    elif prune_criterion == 'SNR':
        pruning_criterion = SNR(limit=pruning_limit,
                                model=net,
                                lower_limit=lower_limit)
        pruning_criterion.prune(pruning_limit,
                                train_loader=train_loader,
                                local=local_pruning)
    elif prune_criterion == 'StructuredSNR':
        pruning_criterion = StructuredSNR(limit=pruning_limit,
                                          model=net,
                                          lower_limit=lower_limit)
        # pruning_criterion.prune(pruning_limit, train_loader=train_loader, local=local_pruning)

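    # Number of weight-mean (W_mu) parameters before pruning, used to track overall
    # sparsity if pruning is re-applied during training (see the commented block below).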
    init_num_params = sum([
        np.prod(x.shape) for name, x in net.named_parameters()
        if "W_mu" in name
    ])
    new_num_params = init_num_params

    for epoch in range(n_epochs):  # loop over the dataset multiple times

        train_loss, train_acc, train_kl = train_model(net,
                                                      optimizer,
                                                      criterion,
                                                      train_loader,
                                                      num_ens=train_ens,
                                                      beta_type=beta_type,
                                                      epoch=epoch,
                                                      num_epochs=n_epochs,
                                                      layer_type=layer_type)
        valid_loss, valid_acc, _ = validate_model(net,
                                                  criterion,
                                                  valid_loader,
                                                  num_ens=valid_ens,
                                                  beta_type=beta_type,
                                                  epoch=epoch,
                                                  num_epochs=n_epochs,
                                                  layer_type=layer_type)
        lr_sched.step(valid_loss)

        print(
            'Epoch: {} \tTraining Loss: {:.4f} \tTraining Accuracy: {:.4f} \tValidation Loss: {:.4f} \tValidation Accuracy: {:.4f} \ttrain_kl_div: {:.4f}'
            .format(epoch, train_loss, train_acc, valid_loss, valid_acc,
                    train_kl))

        # save model if validation loss has decreased
        if valid_loss <= valid_loss_max:
            print(
                'Validation loss decreased ({:.6f} --> {:.6f}).  Saving model ...'
                .format(valid_loss_max, valid_loss))
            torch.save(net.state_dict(), ckpt_name)
            valid_loss_max = valid_loss

        # if epoch == 0 or epoch == 1:
        # if (epoch % 40 == 0) and (epoch > 1) and (epoch < 200) and (1 - new_num_params / init_num_params) < pruning_limit:
        #     net.zero_grad()
        #     optimizer.zero_grad()
        #
        #     with torch.no_grad():
        #         pruning_criterion.prune(0.1, train_loader=train_loader, local=local_pruning)
        #
        #     import pickle
        #     with open('testt', 'wb') as f:
        #         pickle.dump(net, f)
        #
        #     with open('testt', 'rb') as f:
        #         net = pickle.load(f).to(device)
        #
        #     net.post_init_implementation()
        #     criterion = metrics.ELBO(len(trainset)).to(device)
        #     optimizer = Adam(net.parameters(), lr=lr_start)
        #     lr_sched = lr_scheduler.ReduceLROnPlateau(optimizer, patience=6, verbose=True)
        #     valid_loss_max = np.Inf
        #     pruning_criterion = StructuredSNR(limit=pruning_limit, model=net, lower_limit=lower_limit)
        #
        #     new_num_params = sum([np.prod(x.shape) for name, x in net.named_parameters() if "W_mu" in name])
        #     print('Overall sparsity', 1 - new_num_params / init_num_params)

    # import pickle
    # with open(ckpt_name, 'wb') as f:
    #     pickle.dump(net, f)
    torch.save(net.state_dict(), ckpt_name)
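
The `SNIPit`, `SNR` and `StructuredSNR` criteria above come from the pruning code and are not shown here. As an illustration only, a signal-to-noise-ratio criterion for a mean-field Bayesian layer typically ranks weights by |W_mu| / sigma and masks the lowest fraction; the sketch below assumes parameters named `W_mu` and `W_rho` with sigma = softplus(rho), which may not match the repository's parameterisation.

# Hypothetical sketch of a global SNR pruning step (assumed parameter names W_mu and
# W_rho; this is NOT the repository's SNR class).
import torch
import torch.nn.functional as F


def snr_prune_sketch(model, prune_fraction):
    with torch.no_grad():
        bayesian_modules = [m for m in model.modules()
                            if hasattr(m, 'W_mu') and hasattr(m, 'W_rho')]
        # Signal-to-noise ratio |mu| / sigma of every weight, gathered globally.
        snrs = [m.W_mu.abs() / F.softplus(m.W_rho) for m in bayesian_modules]
        all_scores = torch.cat([s.flatten() for s in snrs])
        k = int(prune_fraction * all_scores.numel())
        if k == 0:
            return
        threshold = all_scores.kthvalue(k).values
        for module, snr in zip(bayesian_modules, snrs):
            # Keep only weights whose SNR exceeds the global cutoff; the mask is
            # assumed to be consumed by the layer's forward pass (as in the
            # commented-out unstructured-mask block above).
            module.mask = (snr > threshold).float()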