## Pipelines Setup
# The run directory name is built from the experiment's hyper-parameters,
# filled into the template in the order listed here.
run_fields = (
    args.arch, args.fd, args.data, args.epo, args.bs, args.aug,
    args.transform, args.lr, args.mom, args.wd, args.gam1, args.gam2,
    args.eps, args.tail,
)
run_name = (
    'selfsup_{}+{}_{}_epo{}_bs{}_aug{}+{}_lr{}_mom{}_wd{}_gam1{}_gam2{}_eps{}{}'
    .format(*run_fields))
model_dir = os.path.join(args.save_dir, run_name)
utils.init_pipeline(model_dir)

## Prepare for Training
# Build the architecture from scratch unless a pretrained directory is given;
# in that case load the requested checkpoint and update this run's params
# from the pretrained run.
if args.pretrain_dir is None:
    net = tf.load_architectures(args.arch, args.fd)
else:
    net, _ = tf.load_checkpoint(args.pretrain_dir, args.pretrain_epo)
    utils.update_params(model_dir, args.pretrain_dir)

# Data pipeline: the raw training set is wrapped by AugmentLoader, which is
# handed the transforms, the sampler, the batch size and `num_aug`.
transforms = tf.load_transforms(args.transform)
trainset = tf.load_trainset(args.data, path=args.data_dir)
loader_options = dict(transforms=transforms,
                      sampler=args.sampler,
                      batch_size=args.bs,
                      num_aug=args.aug)
trainloader = AugmentLoader(trainset, **loader_options)

# Objective and optimizer.
loss_options = dict(gam1=args.gam1, gam2=args.gam2, eps=args.eps)
criterion = MaximalCodingRateReduction(**loss_options)
sgd_options = dict(lr=args.lr, momentum=args.mom, weight_decay=args.wd)
optimizer = optim.SGD(net.parameters(), **sgd_options)
Exemple #2
0
def plot_pca_epoch(args):
    """Plot PCA singular values for several training epochs in one figure.

    For every epoch in ``EPOCHS`` the network is obtained (a freshly built
    architecture for epoch 0, otherwise the checkpoint ``epoch - 1`` from
    ``args.model_dir``), features of the training set are extracted, and a
    PCA is fitted on them. The singular values of all epochs are drawn as
    curves on a single axis.

    Files written to ``<args.model_dir>/figures/pca`` (created if missing):
        * ``sig_vals_epoch.npy`` -- the collected singular values,
        * ``pca_class{args.class_}.png`` and ``.pdf`` -- the figure.

    Args:
        args: namespace with the attributes read here:
            ``model_dir`` -- directory with the saved params/checkpoints,
            ``class_`` -- class index to restrict the features to, or
            ``None`` to use the features of all classes,
            ``comp`` -- upper bound on the number of PCA components.
    """
    EPOCHS = [0, 10, 100, 500]

    params = utils.load_params(args.model_dir)
    transforms = tf.load_transforms('test')
    trainset = tf.load_trainset(params['data'], transforms)
    trainloader = DataLoader(trainset, batch_size=200, num_workers=4)

    sig_vals = []
    for epoch in EPOCHS:
        epoch_ = epoch - 1
        if epoch_ == -1:  # randomly initialized
            net = tf.load_architectures(params['arch'], params['fd'])
        else:
            # The epoch returned by the loader is discarded so the loop
            # variable is not overwritten.
            net, _ = tf.load_checkpoint(args.model_dir,
                                        epoch=epoch_,
                                        eval_=True)
        features, labels = tf.get_features(net, trainloader)
        if args.class_ is not None:
            features_sort, _ = utils.sort_dataset(
                features.numpy(),
                labels.numpy(),
                num_classes=trainset.num_classes,
                stack=False)
            features_ = features_sort[args.class_]
        else:
            features_ = features.numpy()
        # Never request more components than the feature dimension.
        n_comp = int(min(args.comp, features.shape[1]))
        pca = PCA(n_components=n_comp).fit(features_)
        sig_vals.append(pca.singular_values_)

    ## plot singular values
    plt.rc('text', usetex=True)
    plt.rcParams['font.family'] = 'serif'
    plt.rcParams['font.serif'] = ['Times New Roman']
    fig, ax = plt.subplots(1, 1, figsize=(7, 5), dpi=400)
    # Truncate every curve to the shortest one so they share an x-range.
    x_min = min(len(sig_val) for sig_val in sig_vals)
    # Single-class plots use a smaller y-range than all-class plots.
    if args.class_ is not None:
        y_tick_max, y_lim_max = 40, 40
    else:
        y_tick_max, y_lim_max = 80, 90
    ax.set_xticks(np.arange(0, x_min, 10))
    ax.set_yticks(np.linspace(0, y_tick_max, 9))
    ax.set_ylim(0, y_lim_max)
    for epoch, sig_val in zip(EPOCHS, sig_vals):
        ax.plot(np.arange(x_min),
                sig_val[:x_min],
                marker='',
                markersize=5,
                label=f'epoch - {epoch}',
                alpha=0.6)
    ax.legend(loc='upper right',
              frameon=True,
              fancybox=True,
              prop={"size": 8},
              ncol=1,
              framealpha=0.5)
    ax.set_xlabel("components")
    ax.set_ylabel("singular values")
    for side in ('top', 'right', 'bottom', 'left'):
        ax.spines[side].set_visible(False)
    # `Tick.label` was removed from matplotlib; tick_params is the supported
    # way to set the major tick label size on both axes.
    ax.tick_params(axis='both', which='major', labelsize=12)
    ax.grid(True, color='white')
    ax.set_facecolor('whitesmoke')
    fig.tight_layout()

    ## save
    save_dir = os.path.join(args.model_dir, 'figures', 'pca')
    # The directory must exist before anything is written into it.
    os.makedirs(save_dir, exist_ok=True)
    np.save(os.path.join(save_dir, "sig_vals_epoch.npy"), sig_vals)
    for ext in ('png', 'pdf'):
        file_name = os.path.join(save_dir, f"pca_class{args.class_}.{ext}")
        fig.savefig(file_name)
        print("Plot saved to: {}".format(file_name))
    # Close this specific figure rather than whichever one is current.
    plt.close(fig)