def main(n_images, n_tissues, n_patches, patch_size, model_file_id):
    """Sample patches, build features and retrain classifiers on Lung data.

    Steps, in order:
      1. Build a ``Dataset`` and sample patches plus their GTEx IDs.
      2. Generate (or fetch) features for the patches under ``features_ID``.
      3. Aggregate the features per ``'GTEx_IDs'`` with ``np.mean``.
      4. Sub-select the ``'Lung'`` tissue from the un-aggregated features
         and call ``train_classifiers`` with ``retrain=True``.

    Args:
        n_images: Forwarded to ``Dataset`` and embedded in ``features_ID``.
        n_tissues: Forwarded to ``Dataset``.
        n_patches: Second argument of ``Dataset.sample_data``; embedded in
            ``features_ID``.
        patch_size: First argument of ``Dataset.sample_data``; embedded in
            ``features_ID``.
        model_file_id: String appended to ``features_ID`` and passed to
            ``generate_features``.
    """
    logger.info('Initializing cluster_classify script')

    dataset = Dataset(n_tissues=n_tissues, n_images=n_images)
    patches, GTEx_IDs = dataset.sample_data(patch_size, n_patches)
    image_objs = [Image(x) for x in GTEx_IDs]

    # Keep only the alphanumeric characters of the dataset's string form so
    # the name can be used as part of an identifier.
    dataset_name = ''.join(s for s in str(dataset) if s.isalnum())
    features_ID = (
        dataset_name
        + f'_{n_patches}_{patch_size}_{n_images}'
        + model_file_id
    )

    features = generate_features(features_ID, patches, model_file_id)

    # The aggregated result is not consumed below.  The call is kept because
    # SOURCE does not show whether aggregate_features has side effects.
    # NOTE(review): confirm, and drop the call if it is pure.
    aggregate_features(dataset_name, features, image_objs, 'GTEx_IDs', np.mean)

    # A former lookup into ``aggregated_features[...]`` was removed here: the
    # name is never assigned in this function and its result was unused.

    lung_features, lung_image_objs = subselect_tissue(
        dataset_name, 'Lung', features, image_objs
    )
    train_classifiers(
        dataset_name,
        features_ID,
        lung_features,
        lung_image_objs,
        'GTEx_IDs',
        retrain=True,
    )
def main(n_tissues, n_images, n_patches, patch_size, model_type, param_string):
    """Sample patches and train the model selected by ``model_type``.

    ``'concrete_vae'`` builds the bundled ``VAE`` and calls ``fit`` on the
    sampled patches.  Any other value is resolved to a class by name; that
    class is instantiated with ``inner_dim`` taken from the parsed
    parameters, trained on a shuffled copy of the patches and saved.

    Args:
        n_tissues: Forwarded to ``Dataset``.
        n_images: Forwarded to ``Dataset``.
        n_patches: Converted with ``int`` and passed as the second argument
            of ``Dataset.sample_data``.
        patch_size: First argument of ``Dataset.sample_data``; also stored
            in the parameter dict under ``'patch_size'``.
        model_type: ``'concrete_vae'`` or the name of a model class that is
            resolvable in this module.
        param_string: Raw parameter string parsed by ``extract_params``.
    """
    # Seed NumPy's global RNG so the permutation below is reproducible.
    np.random.seed(42)
    os.makedirs('data/images', exist_ok=True)

    dataset = Dataset(n_tissues=n_tissues, n_images=n_images)
    logger.debug('Initializing download script')

    params = extract_params(param_string)
    params['patch_size'] = patch_size

    patches_data, imageIDs_data = dataset.sample_data(
        patch_size, int(n_patches)
    )

    if model_type == 'concrete_vae':
        # Imported lazily so the dependency is only needed for this branch.
        from dependencies.vae_concrete.vae_concrete import VAE
        m = VAE(latent_cont_dim=256)
        m.fit(patches_data, num_epochs=20)
    else:
        # NOTE(review): the original indentation of this branch was lost.
        # Every statement that followed ``else:`` is kept inside it; confirm
        # whether the shuffle / train / save steps were meant to run for the
        # 'concrete_vae' branch as well.

        # SECURITY: eval() executes arbitrary code.  model_type presumably
        # comes from the command line -- TODO(review): replace with an
        # explicit name -> class mapping.
        Model = eval(model_type)
        m = Model(inner_dim=params['inner_dim'])

        N = patches_data.shape[0]
        assert N == imageIDs_data.shape[0]

        # Apply one permutation to both arrays so patches and IDs stay
        # aligned after shuffling.
        p = np.random.permutation(N)
        patches_data, imageIDs_data = patches_data[p], imageIDs_data[p]

        m.train_on_data(patches_data, params)
        m.save()
def main():
    """Write a fixed sample of patches as PNG files.

    Samples from a ``Dataset`` built with 6 tissues and 10 images, then
    saves ``255 - patch`` for every sampled patch under
    ``data/cellprofiler/patches/``.  File names carry the running index,
    the ID returned alongside the patch, and the running index modulo 50.
    """
    patch_size = 128
    # The same value is used as sample_data's second argument and as the
    # modulus below -- presumably the number of patches drawn per image;
    # verify against Dataset.sample_data.
    n_patches = 50

    logger.info('Initializing debug script')
    dataset = Dataset(n_tissues=6, n_images=10)
    patches_data, imageIDs_data = dataset.sample_data(patch_size, n_patches)

    # NOTE(review): scipy.misc.imsave is no longer available in current
    # SciPy releases; this call needs an old SciPy -- confirm the pinned
    # version.
    for i, GTEx_ID in enumerate(tqdm(imageIDs_data)):
        idx = i % n_patches
        inverted_patch = 255 - patches_data[i]
        scipy.misc.imsave(
            f'data/cellprofiler/patches/{i:04d}_{GTEx_ID}_{idx}.png',
            inverted_patch)
def main(n_tissues, n_images, n_patches, patch_size, model_file):
    """Save figures comparing sampled patches with their model decodings.

    Five patches are drawn at random (``np.random.choice`` with its default
    settings) from the sampled data.

    * Truthy ``model_file``: load ``MODEL_PATH + f'{model_file}.pkl'``, put
      the originals on the top row and the decodings on the bottom row of a
      single image, and save it as ``figures/{model_file}.png``.
    * Falsy ``model_file``: do the same for every entry of
      ``os.listdir(MODEL_PATH)`` in sorted order, two subplot rows per
      model, and save everything as ``figures/all_models.png``.

    Args:
        n_tissues: Forwarded to ``Dataset``.
        n_images: Forwarded to ``Dataset``.
        n_patches: Not read by this function; ``sample_data`` is always
            called with 15 as its second argument.
        patch_size: First argument of ``Dataset.sample_data``.
        model_file: Model name without the ``.pkl`` suffix, or a falsy value
            to plot every model found in ``MODEL_PATH``.
    """
    logger.info('Initializing inspect script')
    dataset = Dataset(n_tissues=n_tissues, n_images=n_images)
    patches_data, imageIDs_data = dataset.sample_data(patch_size, 15)

    n_cols = 5
    n_available = patches_data.shape[0]
    chosen = np.random.choice(range(n_available), n_cols)
    patches = patches_data[chosen]

    if model_file:
        fig = plt.figure()
        # Edge length of one mosaic cell.  NOTE(review): fixed at 128
        # regardless of patch_size -- presumably patches are 128x128 here.
        tile = 128
        mosaic = np.zeros((2 * tile, n_cols * tile, 3))

        model = load_model(MODEL_PATH + f'{model_file}.pkl')
        decoded_patches = model.predict(patches)
        fig.suptitle(model_file, fontsize=10)

        top, bottom = slice(0, tile), slice(tile, 2 * tile)
        for col in range(n_cols):
            span = slice(col * tile, (col + 1) * tile)
            mosaic[top, span, :] = deprocess(patches[col])
            mosaic[bottom, span, :] = deprocess(decoded_patches[col])

        plt.imshow(mosaic)
        fig.savefig(f'figures/{model_file}.png', bbox_inches='tight')
        return

    # No specific model requested: plot every model found in MODEL_PATH.
    model_files = sorted(os.listdir(MODEL_PATH))
    n = len(model_files)
    fig, ax = plt.subplots(2 * n, n_cols, figsize=(8, 4 * n))
    title_col = int(n_cols / 2)  # the title goes above the middle column

    for k, model_file in enumerate(model_files):
        model_name = model_file.replace('.pkl', '')
        model = load_model(MODEL_PATH + f'{model_name}.pkl')
        logger.debug(f'Generating decodings for {model_file}')
        decoded_patches = model.predict(patches)

        originals_row = ax[2 * k]
        decodings_row = ax[2 * k + 1]
        for i in range(n_cols):
            originals_row[i].imshow(deprocess(patches[i]))
            originals_row[i].axis('off')
            if i == title_col:
                originals_row[i].set_title(model_file)
            decodings_row[i].imshow(deprocess(decoded_patches[i]))
            decodings_row[i].axis('off')

    plt.savefig(f'figures/all_models.png')