def main():
    """Evaluate clustering quality of the trained clusterer.

    Loads (or computes and caches) cluster assignments for the training set,
    then prints NMI and purity against the ground-truth labels.

    Relies on module-level globals: ``config``, ``args``, ``out_dir`` and the
    project helpers ``get_dataset``, ``CheckpointIO``, ``utils``, ``nmi``,
    ``purity_score``.
    """
    checkpoint_dir = os.path.join(out_dir, 'chkpts')
    batch_size = config['training']['batch_size']

    # Pick the dataset loader key from the configured training directory/type.
    if 'cifar' in config['data']['train_dir'].lower():
        name = 'cifar10'
    elif 'stacked_mnist' == config['data']['type']:
        name = 'stacked_mnist'
    else:
        name = 'image'

    # Hoisted: this path was rebuilt three times in the original.
    preds_path = os.path.join(out_dir, 'cluster_preds.npz')

    if os.path.exists(preds_path):
        # Assignments were already computed on a previous run; reuse them.
        with np.load(preds_path) as f:
            y_reals = f['y_reals']
            y_preds = f['y_preds']
    else:
        train_dataset, _ = get_dataset(name=name,
                                       data_dir=config['data']['train_dir'],
                                       size=config['data']['img_size'])
        train_loader = torch.utils.data.DataLoader(
            train_dataset,
            batch_size=batch_size,
            num_workers=config['training']['nworkers'],
            shuffle=True,
            pin_memory=True,
            sampler=None,
            drop_last=True)

        checkpoint_io = CheckpointIO(checkpoint_dir=checkpoint_dir)
        print('Loading clusterer:')
        most_recent = utils.get_most_recent(
            checkpoint_dir, 'model') if args.model_it is None else args.model_it
        clusterer = checkpoint_io.load_clusterer(most_recent,
                                                 load_samples=False,
                                                 pretrained=config['pretrained'])
        # Unwrap DataParallel so attribute access hits the real module.
        if isinstance(clusterer.discriminator, nn.DataParallel):
            clusterer.discriminator = clusterer.discriminator.module

        y_preds = []
        y_reals = []
        # Pure inference: no_grad avoids building autograd graphs per batch
        # (the original detached anyway, so results are unchanged).
        with torch.no_grad():
            for x_real, y_real in tqdm(train_loader, total=len(train_loader)):
                y_pred = clusterer.get_labels(x_real.cuda(), None)
                y_preds.append(y_pred.detach().cpu())
                y_reals.append(y_real)
        y_reals = torch.cat(y_reals).numpy()
        y_preds = torch.cat(y_preds).numpy()
        # Cache so subsequent runs skip the forward passes entirely.
        np.savez(preds_path, y_reals=y_reals, y_preds=y_preds)

    if args.random:
        # Random baseline: uniform assignment over 100 clusters.
        # NOTE(review): 100 is hard-coded — presumably matches the clusterer's
        # k; confirm against config['generator']['nlabels'].
        y_preds = np.random.randint(0, 100, size=y_reals.shape)

    nmi_score = nmi(y_preds, y_reals)
    purity = purity_score(y_preds, y_reals)
    print('nmi', nmi_score, 'purity', purity)
def main():
    """Visualize clusters: save grids of real images per cluster alongside
    conditionally generated samples for the same cluster index.

    Outputs go to ``<out_dir>/clusters`` with a copied lightbox HTML viewer.

    Relies on module-level globals: ``config``, ``args``, ``out_dir`` and the
    project helpers ``get_dataset``, ``CheckpointIO``, ``SeededSampler``,
    ``utils``.
    """
    checkpoint_dir = os.path.join(out_dir, 'chkpts')
    # Fix: the original computed most_recent twice with identical logic;
    # once is enough.
    most_recent = utils.get_most_recent(
        checkpoint_dir, 'model') if args.model_it is None else args.model_it

    cluster_path = os.path.join(out_dir, 'clusters')
    print('Saving clusters/samples to', cluster_path)
    os.makedirs(cluster_path, exist_ok=True)
    # Ship the viewer next to the images ('+' sorts it to the top).
    shutil.copyfile('seeing/lightbox.html',
                    os.path.join(cluster_path, '+lightbox.html'))

    checkpoint_io = CheckpointIO(checkpoint_dir=checkpoint_dir)
    clusterer = checkpoint_io.load_clusterer(most_recent,
                                             pretrained=config['pretrained'],
                                             load_samples=False)
    # Unwrap DataParallel so attribute access hits the real module.
    if isinstance(clusterer.discriminator, nn.DataParallel):
        clusterer.discriminator = clusterer.discriminator.module

    model_path = os.path.join(checkpoint_dir, 'model_%08d.pt' % most_recent)
    sampler = SeededSampler(args.config,
                            model_path=model_path,
                            clusterer_path=os.path.join(
                                checkpoint_dir,
                                f'clusterer{most_recent}.pkl'),
                            pretrained=config['pretrained'])

    if args.show_clusters:
        # Bucket real images by predicted cluster.
        clusters = [[] for _ in range(config['generator']['nlabels'])]
        train_dataset, _ = get_dataset(
            name='webp' if 'cifar' not in config['data']['train_dir'].lower()
            else 'cifar10',
            data_dir=config['data']['train_dir'],
            size=config['data']['img_size'])
        train_loader = torch.utils.data.DataLoader(
            train_dataset,
            batch_size=config['training']['batch_size'],
            num_workers=config['training']['nworkers'],
            shuffle=True,
            pin_memory=True,
            sampler=None,
            drop_last=True)
        print('Generating clusters')
        # Pure inference: no_grad avoids building autograd graphs per batch.
        with torch.no_grad():
            for batch_num, (x_real, y_gt) in enumerate(train_loader):
                x_real = x_real.cuda()
                y_pred = clusterer.get_labels(x_real, y_gt)
                for i, yi in enumerate(y_pred):
                    clusters[yi].append(x_real[i].cpu())
                # don't generate too many, we're only visualizing 20 per cluster
                if batch_num * config['training']['batch_size'] >= 10000:
                    break
    else:
        # NOTE(review): with every entry None the loop below saves nothing at
        # all (generated samples included) — confirm whether samples should
        # still be produced when --show_clusters is off.
        clusters = [None] * config['generator']['nlabels']

    nimgs = 20  # images per saved grid
    nrows = 4
    for i in range(len(clusters)):
        # Guard clause: skip empty slots and clusters too small for a grid.
        if clusters[i] is None or len(clusters[i]) < nimgs:
            continue
        cluster = torch.stack(clusters[i])[:nimgs]
        # Images are in [-1, 1]; map back to [0, 1] for saving.
        torchvision.utils.save_image(
            cluster * 0.5 + 0.5,
            os.path.join(cluster_path, f'{i}_real.png'),
            nrow=nrows)
        generated = []
        for seed in range(nimgs):
            # Fixed seeds make the sample grid reproducible across runs.
            img = sampler.conditional_sample(i, seed=seed)
            generated.append(img.detach().cpu())
        generated = torch.cat(generated)
        torchvision.utils.save_image(
            generated * 0.5 + 0.5,
            os.path.join(cluster_path, f'{i}_gen.png'),
            nrow=nrows)
    print('Clusters/samples can be visualized under', cluster_path)