def main():
    checkpoint_dir = os.path.join(out_dir, 'chkpts')
    batch_size = config['training']['batch_size']

    # Infer which dataset loader to use from the configured training data.
    if 'cifar' in config['data']['train_dir'].lower():
        name = 'cifar10'
    elif config['data']['type'] == 'stacked_mnist':
        name = 'stacked_mnist'
    else:
        name = 'image'

    if os.path.exists(os.path.join(out_dir, 'cluster_preds.npz')):
        # if we've already computed assignments, load them and move on
        with np.load(os.path.join(out_dir, 'cluster_preds.npz')) as f:
            y_reals = f['y_reals']
            y_preds = f['y_preds']
    else:
        train_dataset, _ = get_dataset(name=name,
                                       data_dir=config['data']['train_dir'],
                                       size=config['data']['img_size'])

        train_loader = torch.utils.data.DataLoader(
            train_dataset,
            batch_size=batch_size,
            num_workers=config['training']['nworkers'],
            shuffle=True,
            pin_memory=True,
            sampler=None,
            drop_last=True)

        checkpoint_io = CheckpointIO(checkpoint_dir=checkpoint_dir)

        print('Loading clusterer')
        # Default to the most recent checkpoint when no iteration is given.
        most_recent = (utils.get_most_recent(checkpoint_dir, 'model')
                       if args.model_it is None else args.model_it)
        clusterer = checkpoint_io.load_clusterer(
            most_recent, load_samples=False, pretrained=config['pretrained'])

        # Unwrap DataParallel so the underlying module's attributes are accessible.
        if isinstance(clusterer.discriminator, nn.DataParallel):
            clusterer.discriminator = clusterer.discriminator.module

        y_preds = []
        y_reals = []

        # Label every training image with its predicted cluster.
        for batch_num, (x_real, y_real) in enumerate(
                tqdm(train_loader, total=len(train_loader))):
            y_pred = clusterer.get_labels(x_real.cuda(), None)
            y_preds.append(y_pred.detach().cpu())
            y_reals.append(y_real)

        y_reals = torch.cat(y_reals).numpy()
        y_preds = torch.cat(y_preds).numpy()

        np.savez(os.path.join(out_dir, 'cluster_preds.npz'),
                 y_reals=y_reals,
                 y_preds=y_preds)

    if args.random:
        # Random-assignment baseline: uniform labels over 100 clusters.
        y_preds = np.random.randint(0, 100, size=y_reals.shape)

    nmi_score = nmi(y_preds, y_reals)
    purity = purity_score(y_preds, y_reals)
    print('nmi', nmi_score, 'purity', purity)
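
purity_score is not defined in this snippet; a minimal sketch that matches the call order purity_score(y_preds, y_reals) above, assuming the standard cluster-purity definition (each predicted cluster is credited with its most common ground-truth label):

from sklearn.metrics.cluster import contingency_matrix

def purity_score(y_pred, y_true):
    # Rows index ground-truth labels, columns index predicted clusters.
    cm = contingency_matrix(y_true, y_pred)
    # Credit each predicted cluster with its most common true label,
    # then normalize by the total number of samples.
    return cm.max(axis=0).sum() / cm.sum()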
Example #2
def main():
    checkpoint_dir = os.path.join(out_dir, 'chkpts')

    cluster_path = os.path.join(out_dir, 'clusters')
    print('Saving clusters/samples to', cluster_path)

    os.makedirs(cluster_path, exist_ok=True)

    # Copy the lightbox viewer so the saved image grids can be browsed in a browser.
    shutil.copyfile('seeing/lightbox.html',
                    os.path.join(cluster_path, '+lightbox.html'))

    checkpoint_io = CheckpointIO(checkpoint_dir=checkpoint_dir)

    # Pick the checkpoint iteration to load (latest unless overridden).
    most_recent = (utils.get_most_recent(checkpoint_dir, 'model')
                   if args.model_it is None else args.model_it)
    clusterer = checkpoint_io.load_clusterer(most_recent,
                                             pretrained=config['pretrained'],
                                             load_samples=False)

    if isinstance(clusterer.discriminator, nn.DataParallel):
        clusterer.discriminator = clusterer.discriminator.module

    model_path = os.path.join(checkpoint_dir, 'model_%08d.pt' % most_recent)
    sampler = SeededSampler(args.config,
                            model_path=model_path,
                            clusterer_path=os.path.join(
                                checkpoint_dir, f'clusterer{most_recent}.pkl'),
                            pretrained=config['pretrained'])

    # Optionally group real training images by predicted cluster so they can be
    # compared against the generated samples.
    if args.show_clusters:
        clusters = [[] for _ in range(config['generator']['nlabels'])]
        train_dataset, _ = get_dataset(
            name='cifar10' if 'cifar' in config['data']['train_dir'].lower()
            else 'webp',
            data_dir=config['data']['train_dir'],
            size=config['data']['img_size'])

        train_loader = torch.utils.data.DataLoader(
            train_dataset,
            batch_size=config['training']['batch_size'],
            num_workers=config['training']['nworkers'],
            shuffle=True,
            pin_memory=True,
            sampler=None,
            drop_last=True)

        print('Generating clusters')
        for batch_num, (x_real, y_gt) in enumerate(train_loader):
            x_real = x_real.cuda()
            y_pred = clusterer.get_labels(x_real, y_gt)

            for i, yi in enumerate(y_pred):
                clusters[yi].append(x_real[i].cpu())

            # don't generate too many, we're only visualizing 20 per cluster
            if batch_num * config['training']['batch_size'] >= 10000:
                break
    else:
        clusters = [None] * config['generator']['nlabels']

    nimgs = 20
    nrows = 4

    for i in range(len(clusters)):
        # Save a grid of real images for clusters that collected enough of them.
        if clusters[i] is not None and len(clusters[i]) >= nimgs:
            cluster = torch.stack(clusters[i])[:nimgs]
            torchvision.utils.save_image(cluster * 0.5 + 0.5,
                                         os.path.join(cluster_path,
                                                      f'{i}_real.png'),
                                         nrow=nrows)

        # Generate nimgs seeded samples conditioned on cluster i.
        generated = []
        for seed in range(nimgs):
            img = sampler.conditional_sample(i, seed=seed)
            generated.append(img.detach().cpu())
        generated = torch.cat(generated)

        torchvision.utils.save_image(generated * 0.5 + 0.5,
                                     os.path.join(cluster_path,
                                                  f'{i}_gen.png'),
                                     nrow=nrows)

    print('Clusters/samples can be visualized under', cluster_path)
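
Both examples reference module-level imports and globals (config, args, out_dir) that are defined elsewhere in the script. A minimal sketch of that setup, assuming the gan_training/seeded_sampler import paths and a hypothetical load_config helper (none of which are shown in the source):

import argparse
import os
import shutil

import numpy as np
import torch
import torchvision
from torch import nn
from tqdm import tqdm
from sklearn.metrics import normalized_mutual_info_score as nmi

from gan_training import utils                     # assumed project module
from gan_training.checkpoints import CheckpointIO  # assumed project module
from gan_training.inputs import get_dataset        # assumed project module
from seeded_sampler import SeededSampler           # assumed project module

parser = argparse.ArgumentParser()
parser.add_argument('config', type=str)                      # path to the experiment config
parser.add_argument('--model_it', type=int, default=None)    # checkpoint iteration to load
parser.add_argument('--random', action='store_true')         # random-assignment baseline
parser.add_argument('--show_clusters', action='store_true')  # also save real images per cluster
args = parser.parse_args()

config = load_config(args.config)        # hypothetical: parse the experiment config file
out_dir = config['training']['out_dir']  # assumed location of the output directory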