def predict(folder, n, show_umap, out_dir, *, umap_overview=False):
    """Run the trained model over a dataset and write one HDF5 file per image.

    For each of the first ``n`` images of the dataset in ``folder`` this writes
    ``<out_dir>/<image basename>.h5`` containing the datasets ``labels``,
    ``vae_embedding``, ``targets``, ``masks`` and ``rec``.

    Args:
        folder: Directory holding both the dataset and the trainer state that
            ``trainer.load()`` restores (presumably the saved checkpoint —
            confirm against ``Trainer``).
        n: Number of images to process; ``0`` means every image in the dataset.
        show_umap: If true, open an interactive layer viewer for every image
            showing the input, the U-Net output (layer named ``'umap'``) and
            the labels. Each viewer runs the Qt event loop, so processing
            pauses until it is closed.
        out_dir: Output directory for the ``.h5`` files; created if missing.
        umap_overview: If true, after all images are processed, run UMAP over
            all VAE embeddings, add a ``umap_embedding`` dataset to each
            written file and show a scatter plot coloured by image index.
            (This section was previously disabled with ``if False:``.)

    Sends a pushover notification when finished.
    """
    os.makedirs(out_dir, exist_ok=True)
    click.echo('Predict')

    model = Model(in_channels=38, k=3)
    # The dataset lives in `folder`; the previous code referenced an undefined
    # `root_folder` here.
    ds_train = SpatialTranscriptomicsDs(
        root_folder=folder,
        k=model.k,
        training=False,
        divideable_by=model.divideable_by,
    )
    if n == 0:
        n = len(ds_train)
    print('N', n)

    trainer = Trainer(model)
    trainer.save_to_directory(folder)
    trainer.load()
    trainer.eval_mode()

    # Only needed for the optional UMAP overview; collected lazily so the
    # default path does not keep every embedding in memory.
    all_res = []
    all_ids = []
    with torch.no_grad():
        for i in range(n):
            # Free cached GPU memory and Python garbage between images.
            torch.cuda.empty_cache()
            gc.collect()
            print(i)

            img, mask, neighbours, targets = ds_train[i]
            n_cells = neighbours.shape[0]
            img = trainer.to_device(torch.from_numpy(img))
            mask = trainer.to_device(torch.from_numpy(mask))
            neighbours = trainer.to_device(torch.from_numpy(neighbours))
            # Add a leading batch axis of size 1 to every input.
            unetres, rec, muh, _, _ = trainer.apply_model(
                img[None, ...], mask[None, ...], neighbours[None, ...]
            )

            mask = mask.cpu().numpy()
            rec = rec.detach().cpu().numpy()
            if show_umap:
                # The helper transposes its own copies, so `mask` stays intact
                # for the HDF5 writes below.
                _show_prediction_in_viewer(
                    img.detach().cpu().numpy(),
                    unetres.detach().cpu().numpy(),
                    mask,
                )

            res = muh.detach().cpu().numpy()
            if umap_overview:
                all_res.append(res)
                all_ids.append(numpy.ones(n_cells) * i)

            img_name = os.path.basename(ds_train.image_filenames[i])
            print(i, img_name)
            fname = os.path.join(out_dir, f'{img_name}.h5')
            print(fname)
            _write_prediction_h5(fname, mask, res, targets, rec)

    # Guard against an empty dataset: numpy.concatenate([]) raises.
    if umap_overview and all_res:
        _plot_umap_overview(
            all_res,
            numpy.concatenate(all_ids, axis=0),
            ds_train.image_filenames,
            out_dir,
        )

    import pushover_notification
    pushover_notification.send('predictions generated')


def _show_prediction_in_viewer(image, unetres, mask):
    """Show one image, its U-Net output and its labels in a blocking LayerViewer."""
    # Keep a reference to the QApplication while the viewer is open.
    app = pg.mkQApp()
    viewer = LayerViewerWidget()
    viewer.setWindowTitle('LayerViewer')
    viewer.show()

    # Move axis 0 (presumably channels — confirm) last, then swap the two
    # spatial axes for display.
    image = numpy.swapaxes(numpy.moveaxis(image, 0, 2), 0, 1)
    unetres = numpy.swapaxes(numpy.moveaxis(unetres[0, ...], 0, 2), 0, 1)
    labels = numpy.swapaxes(mask[0, ...], 0, 1)

    viewer.addLayer(layer=MultiChannelImageLayer(name='img', data=image))
    viewer.addLayer(layer=MultiChannelImageLayer(name='umap', data=unetres))
    viewer.addLayer(layer=ObjectLayer(name='labels', data=labels))
    QtGui.QApplication.instance().exec_()
    del app


def _write_prediction_h5(fname, mask, embedding, targets, rec):
    """Write the per-image prediction datasets to a fresh HDF5 file."""
    # Mode 'w' truncates any existing file, so no stale datasets need deleting.
    with h5py.File(fname, 'w') as f5:
        f5['labels'] = numpy.swapaxes(mask[0, ...], 0, 1)
        f5['vae_embedding'] = embedding
        f5['targets'] = targets
        f5['masks'] = mask
        f5['rec'] = rec


def _plot_umap_overview(all_res, all_ids, image_filenames, out_dir):
    """UMAP-embed all VAE embeddings, store them per image and scatter-plot them."""
    embedding, embedding_list = list_umap(x_list=all_res, n_neighbors=30, min_dist=0.0)
    for i in range(len(all_res)):
        img_name = os.path.basename(image_filenames[i])
        print(i, img_name)
        fname = os.path.join(out_dir, f'{img_name}.h5')
        # 'vae_embedding' was already written by predict(); re-assigning an
        # existing dataset name raises in h5py, so only add the new dataset.
        with h5py.File(fname, 'r+') as f5:
            f5['umap_embedding'] = embedding_list[i]

    # A random colormap for matplotlib
    cmap = matplotlib.colors.ListedColormap(numpy.random.rand(1000000, 3))
    # Shuffle the drawing order so no single image's points sit on top.
    perm = numpy.random.permutation(len(all_ids))
    pylab.scatter(
        x=embedding[perm, 0],
        y=embedding[perm, 1],
        c=all_ids[perm],
        s=12,
        edgecolors='face',
        cmap=cmap,
    )
    pylab.show()
# predict: #trainer.load(best=True) trainer.bind_loader('train', train_loader) trainer.bind_loader('validate', validate_loader) trainer.eval_mode() if USE_CUDA: trainer.cuda() # look at an example for img,target in test_loader: if USE_CUDA: img = img.cuda() # softmax on each of the prediction preds = trainer.apply_model(img) preds = [nn.functional.softmax(pred,dim=1) for pred in preds] preds = [unwrap(pred, as_numpy=True, to_cpu=True) for pred in preds] img = unwrap(img, as_numpy=True, to_cpu=True) target = unwrap(target, as_numpy=True, to_cpu=True) n_plots = len(preds) + 2 batch_size = preds[0].shape[0] for b in range(batch_size): fig = pylab.figure() ax1 = fig.add_subplot(2,4,1) ax1.set_title('image') ax1.imshow(img[b,0,...])