Example #1

# Standard and third-party imports used below; Vae, HemoDataset, Trainer and
# pushover_notification are project-specific and come from the surrounding code base.
import gc
import itertools
import os
import pickle

import click
import h5py
import numpy as np
import torch
from torch.utils import data
from tqdm import tqdm

def my_embed(folder, n):
    click.echo('computing embeddings')

    vae = Vae()
    ds_train = HemoDataset(root_folder=folder, training=True)
    ds_test = HemoDataset(root_folder=folder, training=False)
    batch_size = 512
    train_loader = data.DataLoader(ds_train, batch_size=batch_size, num_workers=8)
    test_loader = data.DataLoader(ds_test, batch_size=batch_size, num_workers=8)

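    # Point the trainer at the run directory, load the trained VAE checkpoint
    # and switch to evaluation mode on the GPU.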
    trainer = Trainer(vae)
    trainer.save_to_directory(folder)
    trainer.load()
    trainer.cuda()
    trainer.eval_mode()
    i = 0

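    # Helper: run the VAE over every batch of a loader and collect the latent
    # means (mu) together with the matching input filenames.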
    def compute_embedding_for_loader(loader, training_set: bool):
        with torch.no_grad():
            all_mu = []
            all_filenames = []
            with tqdm(total=len(loader)) as bar:
                for k, x in enumerate(loader):
                    nonlocal i
                    i += 1
                    if n != 0 and i > n:
                        break
                    if training_set:
                        image, targets = x
                    else:
                        image = x
                    image = trainer.to_device(image)
                    _, x_reconstructed_batch, mu_batch, _ = trainer.apply_model(image)
                    image = image.detach().cpu().numpy()
                    x_reconstructed_batch = x_reconstructed_batch.detach().cpu().numpy()
                    mu_batch = mu_batch.detach().cpu().numpy()
                    all_mu.append(mu_batch)
                    # Map the batch index back to the dataset file paths; slicing
                    # keeps the last (possibly smaller) batch in bounds.
                    start = k * loader.batch_size
                    if training_set:
                        paths = loader.dataset.training_file_paths
                    else:
                        paths = loader.dataset.testing_file_paths
                    filenames = [os.path.basename(p) for p in paths[start:start + loader.batch_size]]
                    all_filenames.append(filenames)
                    bar.update(1)

            all_mu = np.concatenate(all_mu, axis=0)
            all_filenames = list(itertools.chain.from_iterable(all_filenames))
            torch.cuda.empty_cache()
            gc.collect()
            return all_mu, all_filenames

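    # Store the embeddings of both splits in a single HDF5 file; the row names
    # (source image filenames) are pickled alongside it.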
    h5_path = os.path.join(folder, 'embeddings.h5')
    with h5py.File(h5_path, 'w') as f5:
        embeddings, filenames = compute_embedding_for_loader(train_loader, training_set=True)
        f5['training_data/embeddings'] = embeddings
        pickle_path = os.path.join(folder, 'embeddings_training_data_row_names.pickle')
        with open(pickle_path, 'wb') as pf:
            pickle.dump(filenames, pf)
        pushover_notification.send('training embeddings generated')

        embeddings, filenames = compute_embedding_for_loader(test_loader, training_set=False)
        f5['testing_data/embeddings'] = embeddings
        pickle_path = os.path.join(folder, 'embeddings_testing_data_row_names.pickle')
        with open(pickle_path, 'wb') as pf:
            pickle.dump(filenames, pf)
        pushover_notification.send('testing embeddings generated')
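
A minimal sketch of reading the stored embeddings back, assuming the file layout written by my_embed above (the run directory is a placeholder):

import os
import pickle

import h5py

run_dir = '/path/to/run'  # placeholder for the folder passed to my_embed
with h5py.File(os.path.join(run_dir, 'embeddings.h5'), 'r') as f5:
    train_embeddings = f5['training_data/embeddings'][...]
    test_embeddings = f5['testing_data/embeddings'][...]
with open(os.path.join(run_dir, 'embeddings_training_data_row_names.pickle'), 'rb') as pf:
    train_row_names = pickle.load(pf)
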
Example #2

# Standard and third-party imports used below; Model, SpatialTranscriptomicsDs,
# Trainer, LayerViewerWidget, MultiChannelImageLayer, ObjectLayer and list_umap
# are project-specific and come from the surrounding code base.
import gc
import os

import click
import h5py
import matplotlib.colors
import numpy
import pylab
import pyqtgraph as pg
import torch
from pyqtgraph.Qt import QtGui
from torch.utils import data

def predict(folder, n, show_umap, out_dir):
    if not os.path.exists(out_dir):
        os.makedirs(out_dir)

    LOG_DIRECTORY = folder
    SAVE_DIRECTORY = folder
    DATASET_DIRECTORY = folder
    click.echo('Predict')

    model = Model(in_channels=38, k=3)
    ds_train = SpatialTranscriptomicsDs(root_folder=DATASET_DIRECTORY, k=model.k, training=False,
                                        divideable_by=model.divideable_by)
    train_loader = data.DataLoader(ds_train, batch_size=1,
                                   num_workers=0)
    if n == 0:
        n = len(ds_train)
    print('N', n)

    trainer = Trainer(model)
    trainer.save_to_directory(SAVE_DIRECTORY)
    trainer.load()
    trainer.eval_mode()
    with torch.no_grad():
        all_res = []
        all_muh = []
        all_targets = []
        all_ids = []
        num_cells = []
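        # Process the images one at a time: apply the model, dump the per-image
        # results to HDF5 and keep the latent codes for the optional UMAP below.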
        for i in range(n):

            torch.cuda.empty_cache()
            gc.collect()
            print(i)
            img, mask, neighbours, targets = ds_train[i]
            n_cells = neighbours.shape[0]
            img = trainer.to_device(torch.from_numpy(img))
            mask = trainer.to_device(torch.from_numpy(mask))
            neighbours = trainer.to_device(torch.from_numpy(neighbours))

            unetres, rec, muh, logvar, nn = trainer.apply_model(img[None, ...], mask[None, ...], neighbours[None, ...])

            mask = mask.cpu().numpy()
            # Everything runs under torch.no_grad(), so the unused outputs
            # (nn, logvar) carry no gradient history and need no handling here.
            rec = rec.detach().cpu().numpy()
            image = img.detach().cpu().numpy()
            unetres = unetres.detach().cpu().numpy()

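            # Optionally inspect the input channels, the U-Net output and the
            # cell masks in an interactive LayerViewer window.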
            if show_umap:
                app = pg.mkQApp()
                viewer = LayerViewerWidget()
                viewer.setWindowTitle('LayerViewer')
                viewer.show()
                # Use viewer-local copies so the arrays written to HDF5 below keep
                # their original layout.
                image_vis = numpy.moveaxis(image, 0, 2)
                image_vis = numpy.swapaxes(image_vis, 0, 1)

                unetres_vis = numpy.moveaxis(unetres[0, ...], 0, 2)
                unetres_vis = numpy.swapaxes(unetres_vis, 0, 1)
                mask_vis = numpy.swapaxes(mask[0, ...], 0, 1)

                layer = MultiChannelImageLayer(name='img', data=image_vis)
                viewer.addLayer(layer=layer)

                layer = MultiChannelImageLayer(name='umap', data=unetres_vis)
                viewer.addLayer(layer=layer)

                label_layer = ObjectLayer(name='labels', data=mask_vis)
                viewer.addLayer(layer=label_layer)
                QtGui.QApplication.instance().exec_()

            muh = muh.detach().cpu().numpy()
            res = muh
            all_res.append(res)
            all_ids.append(numpy.ones(n_cells) * i)

            # num_cells.append(n_cells)
            img_name = ds_train.image_filenames[i]
            img_name = os.path.basename(img_name)
            print(i, img_name)

            fname = os.path.join(out_dir, f'{img_name}.h5')
            print(fname)
            with h5py.File(fname, 'w') as f5:
                # The file is opened in 'w' mode, so it starts out empty and the
                # datasets can be written directly.
                f5['labels'] = numpy.swapaxes(mask[0, ...], 0, 1)
                f5['vae_embedding'] = res
                f5['targets'] = targets
                f5['masks'] = mask
                f5['rec'] = rec

        # res = numpy.concatenate(all_res, axis=0)
        all_ids = numpy.concatenate(all_ids, axis=0)

        # Disabled block: project all latent codes with UMAP, write the projection
        # back into the per-image files and show a scatter plot coloured by image id.
        if False:

            # import seaborn
            # import pandas as pd
            # seaborn.pairplot(pd.DataFrame(res[:,0:5]))
            # pylab.show()

            embedding, embedding_list = list_umap(x_list=all_res, n_neighbors=30, min_dist=0.0)

            for i in range(n):
                img_name = ds_train.image_filenames[i]
                img_name = os.path.basename(img_name)
                print(i, img_name)

                fname = os.path.join(out_dir, f'{img_name}.h5')
                with h5py.File(fname, 'r+') as f5file:
                    # 'vae_embedding' was already written above, so replace it here.
                    del f5file['vae_embedding']
                    f5file['vae_embedding'] = all_res[i]
                    f5file['umap_embedding'] = embedding_list[i]

            # A random colormap for matplotlib
            cmap = matplotlib.colors.ListedColormap(numpy.random.rand(1000000, 3))

            n_points = len(all_ids)
            perm = numpy.random.permutation(numpy.arange(n_points))

            pylab.scatter(x=embedding[perm, 0], y=embedding[perm, 1], c=all_ids[perm], s=12, edgecolors='face',
                          cmap=cmap)
            pylab.show()
    import pushover_notification
    pushover_notification.send('predictions generated')
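
The disabled block above projects all latent codes to 2D before plotting. A minimal sketch of that projection with the umap-learn package (UMAP here is the umap-learn class, an assumption, not the list_umap helper used above):

import numpy
import umap

# all_res stands in for the list of per-image latent-code arrays collected in predict()
latent_codes = numpy.concatenate(all_res, axis=0)
reducer = umap.UMAP(n_neighbors=30, min_dist=0.0)
embedding_2d = reducer.fit_transform(latent_codes)
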
Example #3
if USE_CUDA:
    trainer.cuda()

# Go!
trainer.fit()


##############################################################################
# Predict with the trained network
# and visualize the results

# predict:
#trainer.load(best=True)
trainer.bind_loader('train', train_loader)
trainer.bind_loader('validate', validate_loader)
trainer.eval_mode()

if USE_CUDA:
    trainer.cuda()

# look at an example
for img, target in test_loader:
    if USE_CUDA:
        img = img.cuda()

    # softmax on each of the predictions
    preds = trainer.apply_model(img)
    preds = [nn.functional.softmax(pred, dim=1) for pred in preds]
    preds = [unwrap(pred, as_numpy=True, to_cpu=True) for pred in preds]
    img = unwrap(img, as_numpy=True, to_cpu=True)
    target = unwrap(target, as_numpy=True, to_cpu=True)
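    # Sketch of a possible next step: reduce the soft predictions to hard
    # per-pixel label maps via argmax over the class axis.
    pred_labels = [pred.argmax(axis=1) for pred in preds]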