Code example #1 (score: 0)
File: generate.py — Project: Ma-Lab-Berkeley/MCR2
def gen_training_accuracy(args):
    """Evaluate SVM train/test accuracy for saved checkpoints and log to CSV.

    Args:
        args: namespace with at least `model_dir`, a trained model folder
            containing `checkpoints/` and the saved params file.

    Side effects:
        Creates `<model_dir>/accuracy.csv` and appends one row
        (epoch, acc_train, acc_test) per evaluated checkpoint.
    """
    # load hyper-parameters saved during training
    params = utils.load_params(args.model_dir)
    ckpt_dir = os.path.join(args.model_dir, 'checkpoints')
    # parse the epoch number out of each '*.pt' checkpoint filename
    # NOTE(review): assumes an 11-character filename prefix before the epoch
    # digits — confirm against the checkpoint-saving code
    ckpt_epochs = np.sort([int(e[11:-3]) for e in os.listdir(ckpt_dir)
                           if e[-3:] == ".pt"])

    # csv with one row per evaluated epoch
    headers = ["epoch", "acc_train", "acc_test"]
    csv_path = utils.create_csv(args.model_dir, 'accuracy.csv', headers)

    # BUGFIX: the original looped `for epoch, ckpt_paths in enumerate(ckpt_paths)`,
    # shadowing the checkpoint list and using the enumerate index as the epoch,
    # which is wrong whenever checkpoints are not saved at every single epoch.
    # Iterate the actual epoch numbers parsed from the filenames instead.
    for epoch in ckpt_epochs.tolist():
        if epoch % 5 != 0:  # evaluate only every 5th epoch to save time
            continue
        net, epoch = tf.load_checkpoint(args.model_dir, epoch=epoch, eval_=True)
        # extract train-set features with the deterministic 'test' transforms
        # so the evaluation is reproducible
        train_transforms = tf.load_transforms('test')
        trainset = tf.load_trainset(params['data'], train_transforms, train=True)
        trainloader = DataLoader(trainset, batch_size=500, num_workers=4)
        train_features, train_labels = tf.get_features(net, trainloader, verbose=False)

        # extract held-out test-set features
        test_transforms = tf.load_transforms('test')
        testset = tf.load_trainset(params['data'], test_transforms, train=False)
        testloader = DataLoader(testset, batch_size=500, num_workers=4)
        test_features, test_labels = tf.get_features(net, testloader, verbose=False)

        # fit a linear SVM on train features, score both splits, log to CSV
        acc_train, acc_test = svm(args, train_features, train_labels, test_features, test_labels)
        utils.save_state(args.model_dir, epoch, acc_train, acc_test, filename='accuracy.csv')
    print("Finished generating accuracy.")
Code example #2 (score: 0)
File: generate.py — Project: Ma-Lab-Berkeley/MCR2
def gen_testloss(args):
    """Compute MCR2 losses on the test set for every saved checkpoint.

    Args:
        args: namespace with at least `model_dir`, a trained model folder
            containing `checkpoints/` and the saved params file.

    Side effects:
        Creates `<model_dir>/losses_test.csv` and appends one row per
        (epoch, batch step) with the total, empirical and theoretical losses.
    """
    # load hyper-parameters saved during training
    params = utils.load_params(args.model_dir)
    ckpt_dir = os.path.join(args.model_dir, 'checkpoints')
    # parse the epoch number out of each '*.pt' checkpoint filename
    # NOTE(review): assumes an 11-character filename prefix before the epoch
    # digits — confirm against the checkpoint-saving code
    ckpt_epochs = np.sort([int(e[11:-3]) for e in os.listdir(ckpt_dir)
                           if e[-3:] == ".pt"])

    # csv
    headers = ["epoch", "step", "loss", "discrimn_loss_e", "compress_loss_e",
               "discrimn_loss_t", "compress_loss_t"]
    csv_path = utils.create_csv(args.model_dir, 'losses_test.csv', headers)
    print('writing to:', csv_path)

    # load data once; only the network weights change between checkpoints
    test_transforms = tf.load_transforms('test')
    testset = tf.load_trainset(params['data'], test_transforms, train=False)
    testloader = DataLoader(testset, batch_size=params['bs'], shuffle=False, num_workers=4)

    # save loss
    criterion = MaximalCodingRateReduction(gam1=params['gam1'], gam2=params['gam2'], eps=params['eps'])
    # BUGFIX: original used the enumerate index as the epoch, which is wrong
    # when checkpoints are not saved at every epoch; use each checkpoint's own
    # epoch number parsed from its filename.
    for epoch in ckpt_epochs.tolist():
        net, epoch = tf.load_checkpoint(args.model_dir, epoch=epoch, eval_=True)
        for step, (batch_imgs, batch_lbls) in enumerate(testloader):
            features = net(batch_imgs.cuda())
            # BUGFIX: `num_classes` is an int elsewhere in this project
            # (e.g. `trainset.num_classes % args.cpb`), so the original
            # `len(testset.num_classes)` would raise TypeError.
            loss, loss_empi, loss_theo = criterion(features, batch_lbls,
                                                   num_classes=testset.num_classes)
            utils.save_state(args.model_dir, epoch, step, loss.item(),
                             *loss_empi, *loss_theo, filename='losses_test.csv')
    print("Finished generating test loss.")
Code example #3 (score: 0)
File: evaluate_seq_ce.py — Project: RobinWu218/mcr2
        '--cpb',
        type=int,
        default=10,
        help='number of classes in each learning batch (default: 10)')

    # remaining CLI flags for the evaluation script
    parser.add_argument('--save', action='store_true', help='save labels')
    parser.add_argument('--data_dir',
                        default='./data/',
                        help='path to dataset')
    args = parser.parse_args()

    print("evaluate using label_batch: {}".format(args.label_batch))

    # hyper-parameters saved during training
    params = utils.load_params(args.model_dir)
    # get train features and labels
    # deterministic 'test' transforms so evaluation features are reproducible
    train_transforms = tf.load_transforms('test')
    trainset = tf.load_trainset(params['data'],
                                train_transforms,
                                train=True,
                                path=args.data_dir)
    if 'lcr' in params.keys():  # supervised corruption case
        # replay the label corruption (rate `lcr`, seed `lcs`) applied during
        # training so evaluation sees the same labels the model saw
        trainset = tf.corrupt_labels(trainset, params['lcr'], params['lcs'])
    new_labels = trainset.targets
    # classes are consumed in equal-sized learning batches of `cpb` classes
    assert (trainset.num_classes %
            args.cpb == 0), "Number of classes not divisible by cpb"
    ## load model
    net, epoch = tf.load_checkpoint_ce(args.model_dir,
                                       trainset.num_classes,
                                       args.epoch,
                                       eval_=True,
                                       label_batch_id=args.label_batch)
Code example #4 (score: 0)

## per model functions
def lr_schedule(epoch, optimizer):
    """Step-decay the learning rate.

    Uses the base rate ``args.lr`` before epoch 200, one tenth of it from
    epoch 200, and one hundredth from epoch 400; the chosen rate is written
    into every parameter group of ``optimizer``.
    """
    if epoch >= 400:
        new_lr = args.lr * 0.01
    elif epoch >= 200:
        new_lr = args.lr * 0.1
    else:
        new_lr = args.lr
    for group in optimizer.param_groups:
        group['lr'] = new_lr


## Prepare for Training
transforms = tf.load_transforms(args.transform)
trainset = tf.load_trainset(args.data, transforms, path=args.data_dir)
#trainset = tf.corrupt_labels(trainset, args.lcr, args.lcs)
if args.pretrain_dir is not None:
    # warm-start from an earlier checkpoint and copy its hyper-parameters
    # into the current model directory
    net, _ = tf.load_checkpoint(args.pretrain_dir, args.pretrain_epo)
    utils.update_params(model_dir, args.pretrain_dir)
else:
    # fresh cross-entropy model sized to the dataset's class count
    net = tf.load_architectures_ce(args.arch, trainset.num_classes)
# classes are split into equal learning batches of `cpb` classes each
assert (trainset.num_classes %
        args.cpb == 0), "Number of classes not divisible by cpb"
classes = np.unique(trainset.targets)
class_batch_num = trainset.num_classes // args.cpb
# shape (class_batch_num, cpb): each row is one learning batch of class ids
class_batch_list = classes.reshape(class_batch_num, args.cpb)

#trainloader = DataLoader(trainset, batch_size=args.bs, drop_last=True, num_workers=4)
criterion = nn.CrossEntropyLoss()
Code example #5 (score: 0)
File: plot.py — Project: RobinWu218/mcr2
def plot_pca_epoch(args):
    """Plot PCA singular values of train features at several epochs in one figure.

    For each epoch in ``EPOCHS`` the matching checkpoint is loaded (epoch 0
    uses a randomly initialized network), train-set features are extracted,
    and the singular values of their principal components are plotted
    together. The figure (png + pdf) and the raw singular values (.npy) are
    saved under ``<model_dir>/figures/pca/``.

    Args:
        args: namespace with `model_dir`, `comp` (max #components) and
            `class_` (class index, or None to use all features jointly).
    """
    EPOCHS = [0, 10, 100, 500]

    params = utils.load_params(args.model_dir)
    transforms = tf.load_transforms('test')
    trainset = tf.load_trainset(params['data'], transforms)
    trainloader = DataLoader(trainset, batch_size=200, num_workers=4)

    sig_vals = []
    for epoch in EPOCHS:
        epoch_ = epoch - 1  # checkpoint files are 0-indexed
        if epoch_ == -1:  # epoch 0: randomly initialized network, no checkpoint
            net = tf.load_architectures(params['arch'], params['fd'])
        else:
            # BUGFIX: the original bound the returned epoch to `epoch`,
            # clobbering the loop variable; discard it instead.
            net, _ = tf.load_checkpoint(args.model_dir,
                                        epoch=epoch_,
                                        eval_=True)
        features, labels = tf.get_features(net, trainloader)
        if args.class_ is not None:
            # restrict the PCA to the features of a single class
            features_sort, _ = utils.sort_dataset(
                features.numpy(),
                labels.numpy(),
                num_classes=trainset.num_classes,
                stack=False)
            features_ = features_sort[args.class_]
        else:
            features_ = features.numpy()
        # cannot request more components than the feature dimension
        n_comp = np.min([args.comp, features.shape[1]])
        pca = PCA(n_components=n_comp).fit(features_)
        sig_vals.append(pca.singular_values_)

    ## plot singular values
    plt.rc('text', usetex=True)
    plt.rcParams['font.family'] = 'serif'
    plt.rcParams['font.serif'] = ['Times New Roman']
    fig, ax = plt.subplots(1, 1, figsize=(7, 5), dpi=400)
    # truncate every curve to the shortest so all share one x-axis
    x_min = np.min([len(sig_val) for sig_val in sig_vals])
    if args.class_ is not None:
        ax.set_xticks(np.arange(0, x_min, 10))
        ax.set_yticks(np.linspace(0, 40, 9))
        ax.set_ylim(0, 40)
    else:
        ax.set_xticks(np.arange(0, x_min, 10))
        ax.set_yticks(np.linspace(0, 80, 9))
        ax.set_ylim(0, 90)
    for epoch, sig_val in zip(EPOCHS, sig_vals):
        ax.plot(np.arange(x_min),
                sig_val[:x_min],
                marker='',
                markersize=5,
                label=f'epoch - {epoch}',
                alpha=0.6)
    ax.legend(loc='upper right',
              frameon=True,
              fancybox=True,
              prop={"size": 8},
              ncol=1,
              framealpha=0.5)
    ax.set_xlabel("components")
    # BUGFIX: axis label previously read "sigular values" (typo)
    ax.set_ylabel("singular values")
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)
    ax.spines['bottom'].set_visible(False)
    ax.spines['left'].set_visible(False)
    # NOTE(review): `tick.label` is deprecated/removed in recent matplotlib
    # (use `tick.label1`) — confirm the pinned matplotlib version
    [tick.label.set_fontsize(12) for tick in ax.xaxis.get_major_ticks()]
    [tick.label.set_fontsize(12) for tick in ax.yaxis.get_major_ticks()]
    ax.grid(True, color='white')
    ax.set_facecolor('whitesmoke')
    fig.tight_layout()

    ## save
    save_dir = os.path.join(args.model_dir, 'figures', 'pca')
    # BUGFIX: create the directory *before* writing into it; the original
    # called np.save first and crashed whenever `figures/pca` did not exist
    if not os.path.exists(save_dir):
        os.makedirs(save_dir)
    np.save(os.path.join(save_dir, "sig_vals_epoch.npy"), sig_vals)
    file_name = os.path.join(save_dir, f"pca_class{args.class_}.png")
    fig.savefig(file_name)
    print("Plot saved to: {}".format(file_name))
    file_name = os.path.join(save_dir, f"pca_class{args.class_}.pdf")
    fig.savefig(file_name)
    print("Plot saved to: {}".format(file_name))
    plt.close()
Code example #6 (score: 0)
File: plot.py — Project: RobinWu218/mcr2
    if args.traintest:
        # regenerate the test-loss CSV on demand, then plot it
        path = os.path.join(args.model_dir, 'losses_test.csv')
        if not os.path.exists(path):
            gen_testloss(args)
        plot_traintest(args, path)
    if args.acc:
        # regenerate the accuracy CSV on demand, then plot it
        path = os.path.join(args.model_dir, 'accuracy.csv')
        if not os.path.exists(path):
            gen_training_accuracy(args)
        plot_accuracy(args, path)

    # all feature-based plots below share a single feature-extraction pass
    if args.pca or args.hist or args.heat or args.nearcomp_sup or args.nearcomp_unsup or args.nearcomp_class:
        ## load data and model
        params = utils.load_params(args.model_dir)
        net, epoch = tf.load_checkpoint(args.model_dir, args.epoch, eval_=True)
        # deterministic 'test' transforms so features are reproducible
        transforms = tf.load_transforms('test')
        trainset = tf.load_trainset(params['data'], transforms)
        if 'lcr' in params.keys():  # supervised corruption case
            # replay the label corruption used during training
            trainset = tf.corrupt_labels(trainset, params['lcr'],
                                         params['lcs'])
        trainloader = DataLoader(trainset, batch_size=200, num_workers=4)
        features, labels = tf.get_features(net, trainloader)

    if args.pca:
        plot_pca(args, features, labels, epoch)
    if args.nearcomp_sup:
        plot_nearest_component_supervised(args, features, labels, epoch,
                                          trainset)
    if args.nearcomp_unsup:
        plot_nearest_component_unsupervised(args, features, labels, epoch,
                                            trainset)