示例#1
0
def _show_in_layer_viewer(image, unetres, mask):
    """Open a blocking LayerViewer window for one sample.

    Args:
        image: input image as a numpy array; axis 0 is moved to the end and the
            first two axes are then swapped (presumably (C, H, W) -> (W, H, C);
            TODO confirm).
        unetres: ``unetres`` model output as a numpy array with a leading batch
            axis; ``unetres[0]`` gets the same axis rearrangement as ``image``.
        mask: mask array; ``mask[0]`` is shown as the label layer with its first
            two axes swapped.

    Only local views are rearranged, so the caller's arrays keep their original
    layout (the caller still needs ``mask[0, ...]`` for the HDF5 export).
    """
    # Keep a reference to the Qt application for the lifetime of the window.
    app = pg.mkQApp()  # noqa: F841
    viewer = LayerViewerWidget()
    viewer.setWindowTitle('LayerViewer')
    viewer.show()

    image_view = numpy.swapaxes(numpy.moveaxis(image, 0, 2), 0, 1)
    unetres_view = numpy.swapaxes(numpy.moveaxis(unetres[0, ...], 0, 2), 0, 1)
    labels_view = numpy.swapaxes(mask[0, ...], 0, 1)

    viewer.addLayer(layer=MultiChannelImageLayer(name='img', data=image_view))
    viewer.addLayer(layer=MultiChannelImageLayer(name='umap', data=unetres_view))
    viewer.addLayer(layer=ObjectLayer(name='labels', data=labels_view))
    # Blocks until the viewer window is closed.
    QtGui.QApplication.instance().exec_()


def _write_prediction_h5(fname, mask, embedding, targets, rec):
    """Write one image's prediction results to a fresh HDF5 file at *fname*.

    Mode ``'w'`` truncates an existing file, so there are never stale datasets
    that would have to be deleted before writing. Datasets written:

    - ``labels``: ``mask[0]`` with its first two axes swapped
    - ``vae_embedding``: the embedding (the model's ``muh`` output)
    - ``targets``: the dataset targets, unchanged
    - ``masks``: the full mask array, unchanged
    - ``rec``: the model's ``rec`` output
    """
    with h5py.File(fname, 'w') as f5:
        f5['labels'] = numpy.swapaxes(mask[0, ...], 0, 1)
        f5['vae_embedding'] = embedding
        f5['targets'] = targets
        f5['masks'] = mask
        f5['rec'] = rec


def predict(folder, n, show_umap, out_dir):
    """Run the trained model over the dataset in *folder* and export per-image results.

    For each processed image, writes ``<out_dir>/<image basename>.h5`` with the
    datasets ``labels``, ``vae_embedding``, ``targets``, ``masks`` and ``rec``.
    Sends a pushover notification when done.

    Args:
        folder: directory used both as the dataset root and as the trainer's
            save directory (``trainer.load()`` restores the model from there).
        n: number of images to process, taken from the start of the dataset.
            ``0`` (or any non-positive value) means all images; values larger
            than the dataset size are clamped to it.
        show_umap: if true, open a blocking LayerViewer window for every image,
            showing the input, the model's ``unetres`` output and the labels.
        out_dir: output directory for the ``.h5`` files; created if missing.
    """
    os.makedirs(out_dir, exist_ok=True)

    save_directory = folder
    dataset_directory = folder
    click.echo('Predict')

    model = Model(in_channels=38, k=3)
    ds_train = SpatialTranscriptomicsDs(root_folder=dataset_directory, k=model.k, training=False,
                                        divideable_by=model.divideable_by)
    total = len(ds_train)
    if n <= 0 or n > total:
        n = total
    print('N', n)

    trainer = Trainer(model)
    trainer.save_to_directory(save_directory)
    trainer.load()
    trainer.eval_mode()
    with torch.no_grad():
        for i in range(n):
            # Release cached GPU blocks and Python garbage between images to
            # keep peak memory down on large samples.
            torch.cuda.empty_cache()
            gc.collect()
            print(i)

            img, mask, neighbours, targets = ds_train[i]
            img = trainer.to_device(torch.from_numpy(img))
            mask = trainer.to_device(torch.from_numpy(mask))
            neighbours = trainer.to_device(torch.from_numpy(neighbours))

            # `[None, ...]` adds a leading batch axis of size 1. logvar and the
            # last output are not used for the export.
            unetres, rec, muh, _logvar, _nn = trainer.apply_model(
                img[None, ...], mask[None, ...], neighbours[None, ...])

            mask = mask.cpu().numpy()
            rec = rec.detach().cpu().numpy()

            if show_umap:
                # Copy image / unetres to host memory only when they are displayed.
                _show_in_layer_viewer(img.detach().cpu().numpy(),
                                      unetres.detach().cpu().numpy(),
                                      mask)

            embedding = muh.detach().cpu().numpy()

            img_name = os.path.basename(ds_train.image_filenames[i])
            print(i, img_name)

            fname = os.path.join(out_dir, f'{img_name}.h5')
            print(fname)
            _write_prediction_h5(fname, mask, embedding, targets, rec)

    import pushover_notification
    pushover_notification.send('predictions generated')
示例#2
0
# predict:
# Prediction set-up for the example loop that follows.
# NOTE(review): the best-checkpoint load below is commented out, so the trainer
# keeps whatever weights it currently holds -- confirm this is intended.
#trainer.load(best=True)
# Attach the loaders under the names 'train' and 'validate'. The example loop
# below iterates test_loader directly, so these bindings are presumably not
# needed for prediction itself -- TODO confirm.
trainer.bind_loader('train', train_loader)
trainer.bind_loader('validate', validate_loader)
# Presumably switches the wrapped model to evaluation/inference mode
# (depends on the Trainer implementation).
trainer.eval_mode()

# Presumably moves the trainer's model to the GPU; controlled by the
# script-level USE_CUDA flag.
if USE_CUDA:
    trainer.cuda()
# look at an example
for img,target in test_loader:
    if USE_CUDA:
        img = img.cuda()

    # softmax on each of the prediction
    preds = trainer.apply_model(img)
    preds = [nn.functional.softmax(pred,dim=1)        for pred in preds]
    preds = [unwrap(pred, as_numpy=True, to_cpu=True) for pred in preds]
    img    = unwrap(img,  as_numpy=True, to_cpu=True)
    target  = unwrap(target, as_numpy=True, to_cpu=True)

    n_plots = len(preds) + 2
    batch_size = preds[0].shape[0]

    for b in range(batch_size):

        fig = pylab.figure()

        ax1 = fig.add_subplot(2,4,1)
        ax1.set_title('image')
        ax1.imshow(img[b,0,...])