def main(spec, num_samples, pool):
    """Sample images from a trained VAE or GAN and save their morphometrics.

    The model is restored from ``CHECKPOINT_ROOT/<spec>``. The spec string is
    parsed by ``spec_util.parse_setup_spec`` into ``(model_type, model_args,
    dataset_names)``; only the first two are used here.

    Args:
        spec: Setup spec string; also the checkpoint sub-directory name and
            the prefix of the output CSV file name.
        num_samples: Number of images to draw from the model.
        pool: Worker pool forwarded to ``measure.measure_batch``.

    Raises:
        ValueError: If the model type is not 'VAE', 'GAN' or 'GANmc'.

    Side effects:
        Creates ``METRICS_ROOT`` if needed and writes
        ``METRICS_ROOT/<spec>_metrics.csv`` (index column labelled 'index').
        The trainer is moved to CUDA, so a GPU is required.
    """
    checkpoint_dir = os.path.join(CHECKPOINT_ROOT, spec)
    model_type, model_args, _dataset_names = spec_util.parse_setup_spec(spec)

    def restore(net, trainer):
        # Put the trainer on the GPU, load its saved state, then freeze the
        # network into inference mode.
        trainer.cuda()
        models.load_checkpoint(trainer, checkpoint_dir)
        net.eval()

    if model_type == 'VAE':
        net = vae.VAE(model_args)
        restore(net, vae.Trainer(net, beta=4.))
        # VAE samples: draw latent codes, then decode them into images.
        images = net.dec(net.sample_latent(num_samples))
    elif model_type in ('GAN', 'GANmc'):
        net = gan.GAN(model_args)
        restore(net, gan.Trainer(net))
        # Calling the GAN with a count produces that many generated images.
        images = net(num_samples)
    else:
        raise ValueError(f"Invalid model type: {model_type}")

    print(f"Loaded model {checkpoint_dir}. Measuring samples...")
    # NOTE(review): .squeeze() drops every size-1 axis, which presumably
    # includes the batch axis when num_samples == 1 — confirm whether that
    # case is supported downstream.
    images_np = images.detach().cpu().squeeze().numpy()
    metrics = measure.measure_batch(images_np, pool=pool)

    os.makedirs(METRICS_ROOT, exist_ok=True)
    metrics_path = os.path.join(METRICS_ROOT, f"{spec}_metrics.csv")
    metrics.to_csv(metrics_path, index_label='index')
    print(f"Morphometrics saved to {metrics_path}")
def measure_dir(data_dir, pool):
    """Compute morphometrics for the MNIST-format image files in a directory.

    For each split in ('t10k', 'train'), reads
    ``<data_dir>/<split>-images-idx3-ubyte.gz`` with ``io.load_idx``, measures
    the images via ``measure.measure_batch`` (chunksize 100), and writes the
    result to ``<data_dir>/<split>-morpho.csv`` with the index column
    labelled 'index'.

    Args:
        data_dir: Directory holding the idx3 image files; outputs go here too.
        pool: Worker pool forwarded to ``measure.measure_batch``.
    """
    for split in ('t10k', 'train'):
        idx_path = os.path.join(data_dir, f"{split}-images-idx3-ubyte.gz")
        csv_path = os.path.join(data_dir, f"{split}-morpho.csv")

        print(f"Processing MNIST data file {idx_path}...")
        images = io.load_idx(idx_path)
        morpho = measure.measure_batch(images, pool=pool, chunksize=100)
        morpho.to_csv(csv_path, index_label='index')
        print(f"Morphometrics saved to {csv_path}")
def main(checkpoint_dir, pcorr_dir=None, num_samples=1000):
    """Analyse an InfoGAN checkpoint on shuffled test data and on its own samples.

    The setup spec is the final path component of ``checkpoint_dir``. It is
    parsed by ``spec_util.parse_setup_spec`` to obtain the latent dimensions
    and dataset names.

    ``process`` is run twice:
      1. on the shuffled test images of the matching dataset ('test'),
      2. on ``num_samples`` images generated by the model ('sample'), whose
         morphometrics are measured with a ``multiprocessing.Pool``.

    Args:
        checkpoint_dir: Directory holding the trained checkpoint. A trailing
            path separator is tolerated.
        pcorr_dir: Output directory handed to ``process``. Defaults to
            ``checkpoint_dir`` and is created if missing.
        num_samples: Number of images to generate for the sample analysis.
            Defaults to 1000, the previously hard-coded value.

    Requires a CUDA device (the model and test tensors are moved to 'cuda').
    """
    # normpath strips a trailing separator first: os.path.split("a/spec/")
    # returns an empty tail, which would yield an empty spec string.
    spec = os.path.basename(os.path.normpath(checkpoint_dir))
    _, latent_dims, dataset_names = spec_util.parse_setup_spec(spec)

    device = torch.device('cuda')
    gan = infogan.InfoGAN(*latent_dims)
    trainer = infogan.Trainer(gan).to(device)
    load_checkpoint(trainer, checkpoint_dir)
    gan.eval()

    dataset_name = SPEC_TO_DATASET['+'.join(dataset_names)]
    data_dirs = [os.path.join(DATA_ROOT, dataset_name)]
    test_metrics, test_images, _, _ = load_test_data(data_dirs)
    print(test_metrics.head())

    # Shuffle over the real test-set size instead of a hard-coded 10000,
    # which raised IndexError on smaller sets and truncated larger ones.
    idx = np.random.permutation(len(test_images))
    # Add a channel axis and scale pixel values by 1/255.
    X = torch.from_numpy(test_images[idx]).float().unsqueeze(1).to(device) / 255.

    cols = ['length', 'thickness', 'slant', 'width', 'height']
    test_cols = list(cols)
    test_hrule = None
    if 'swel+frac' in spec:
        add_swel_frac(data_dirs[0], test_metrics)
        test_cols += ['swel', 'frac']
        # Position where the extra swel/frac columns start; passed to
        # process() as its hrule argument.
        test_hrule = len(cols)

    if pcorr_dir is None:
        pcorr_dir = checkpoint_dir
    os.makedirs(pcorr_dir, exist_ok=True)

    # NOTE(review): .loc is label-based; this assumes test_metrics has a
    # default 0..N-1 index aligned with the rows of test_images — confirm.
    process(gan, X, test_metrics.loc[idx], test_cols, pcorr_dir, spec, 'test', test_hrule)

    # The samples are detached immediately, so no autograd graph is needed.
    with torch.no_grad():
        X_ = gan(num_samples).detach()
    with multiprocessing.Pool() as pool:
        # NOTE(review): .squeeze() presumably also drops the batch axis when
        # num_samples == 1 — confirm before relying on that case.
        sample_metrics = measure.measure_batch(X_.cpu().squeeze().numpy(), pool=pool)
    process(gan, X_, sample_metrics, cols, pcorr_dir, spec, 'sample', None)