Ejemplo n.º 1
0
def run_napari(dataset_name, usage=False):
    """Launch napari preloaded with one of the named demo datasets.

    Args:
        dataset_name: Key into ``DATASETS`` below, or ``None`` to start a
            plain napari session through its normal CLI entry point.
        usage: When True, print the available dataset names instead of
            launching anything.

    Returns:
        2 when ``usage`` is requested.  The string-dataset branch does not
        return normally (it calls ``sys.exit``); the factory branch returns
        ``None`` after the viewer is closed.
    """

    # --- dataset factories: each returns a napari Viewer -------------------

    def none():
        return napari.Viewer()

    def num():
        images = [create_text_array(x) for x in range(20)]
        data = np.stack(images, axis=0)
        return napari.view_image(data, rgb=True, name='numbered slices')

    def num_tiled():
        size = (1000, 1030)
        images = [create_tiled_text_array(x, 16, 16, size) for x in range(5)]
        data = np.stack(images, axis=0)
        return napari.view_image(data, rgb=True, name='numbered slices')

    def num_tiled_1():
        size = (1024, 1024)
        images = [create_tiled_test_1(x, 16, 1, size) for x in range(5)]
        data = np.stack(images, axis=0)
        return napari.view_image(data, rgb=True, name='numbered slices')

    def num_tiny():
        images = [create_text_array(x, size=(16, 16)) for x in range(20)]
        data = np.stack(images, axis=0)
        return napari.view_image(data, rgb=True, name='numbered slices')

    def _grid(rows, cols, *extra):
        """Viewer showing a rows x cols grid stack, one layer per slice.

        ``extra`` is forwarded to ``create_grid_stack`` (the variants only
        differ in grid shape and one optional trailing argument).
        """
        num_slices = 25
        data = create_grid_stack(rows, cols, num_slices, *extra)
        names = [f"layer {n}" for n in range(num_slices)]
        return napari.view_image(data, name=names, channel_axis=0)

    def num_16():
        return _grid(4, 4)

    def num_16_0():
        return _grid(4, 4, 0)

    def num_4():
        return _grid(2, 2)

    def num_1():
        return _grid(1, 1)

    def num_delayed():
        data = create_stack(20, 1)
        return napari.view_image(data, name='delayed (1 second)')

    def num_delayed0():
        data = create_stack(20, 0)
        return napari.view_image(data, name='zero delay')

    def num_mixed():
        data = create_stack_mixed(20)
        return napari.view_image(data, name='zero delay')

    def num_2():
        data = add_delay(create_text_array("one"), 1)
        return napari.view_image(data, name='numbered slices', channel_axis=0)

    def async_3d():
        data = da.random.random((200, 512, 512, 512),
                                chunks=(1, 512, 512, 512))
        return napari.view_image(data, name='async_3d', channel_axis=0)

    def async_3d_small():
        data = da.random.random((5, 512, 512, 512), chunks=(1, 512, 512, 512))
        return napari.view_image(data, name='async_3d_small', channel_axis=0)

    def invisible():
        return napari.view_image(
            np.random.random((5, 1024, 1024)),
            name='five 1k images',
            visible=False,
        )

    def noise():
        return napari.view_image(np.random.random((5, 1024, 1024)),
                                 name='five 1k images')

    def big8():
        return napari.view_image(np.random.random((2, 8192, 8192)),
                                 name='two 8k 2d images')

    def big16():
        return napari.view_image(np.random.random((2, 16384, 16384)),
                                 name='two 16k 2d images')

    def big2d():
        return napari.view_image(np.random.random((21, 8192, 8192)),
                                 name='big 2D timeseries')

    def big3d():
        return napari.view_image(
            np.random.random((6, 256, 512, 512)),
            ndisplay=3,
            name='big 3D timeseries',
        )

    def small3d():
        return napari.view_image(
            np.random.random((3, 64, 64, 64)),
            ndisplay=3,
            name='small 3D timeseries',
        )

    def labels():
        return napari.view_labels(
            np.random.randint(10, size=(20, 2048, 2048)),
            name='big labels timeseries',
        )

    def multi_rand():
        # Multiscale pyramid of random data, each level half the previous.
        shapes = [
            (167424, 79360),
            (83712, 39680),
            (41856, 19840),
            (20928, 9920),
            (10464, 4960),
            (5232, 2480),
            (2616, 1240),
            (1308, 620),
            (654, 310),
            (327, 155),
            (163, 77),
        ]
        pyramid = [da.random.random(s) for s in shapes]
        return napari.view_image(pyramid)

    def multi_zarr():
        path = 'https://s3.embassy.ebi.ac.uk/idr/zarr/v0.1/9822151.zarr'
        resolutions = [
            da.from_zarr(path, component=str(i))[0, 0, 0]
            for i in range(11)
        ]
        return napari.view_image(resolutions)

    def astronaut():
        from skimage import data

        return napari.view_image(data.astronaut(), rgb=True)

    REMOTE_SMALL_URL = (
        "https://s3.embassy.ebi.ac.uk/idr/zarr/v0.1/6001240.zarr")
    # Values are either a viewer factory (callable) or a path/URL string
    # handed straight to napari's CLI.
    DATASETS = {
        "none": none,
        "num": num,
        "num_tiled": num_tiled,
        "num_tiled_1": num_tiled_1,  # was defined but never registered
        "num_tiny": num_tiny,
        "num_16": num_16,
        "num_16_0": num_16_0,
        "num_4": num_4,
        "num_2": num_2,
        "num_1": num_1,
        "num_delayed": num_delayed,
        "num_delayed0": num_delayed0,
        "num_mixed": num_mixed,
        "async_3d": async_3d,
        "async_3d_small": async_3d_small,
        "invisible": invisible,
        "noise": noise,
        "big8": big8,
        "big16": big16,
        "big2d": big2d,  # fixed key typo: was "big2pd"
        "big3d": big3d,
        "small3d": small3d,
        "labels": labels,
        "remote": "https://s3.embassy.ebi.ac.uk/idr/zarr/v0.1/4495402.zarr",
        "remote-small": REMOTE_SMALL_URL,
        "big": "/data-ext/4495402.zarr",
        "small": "/Users/pbw/data/6001240.zarr",
        "multi_zarr": multi_zarr,
        "multi_rand": multi_rand,
        "astronaut": astronaut,
    }

    if usage:
        print('\n'.join(DATASETS.keys()))
        return 2

    if dataset_name is None:
        # Import late so it sees our env vars.
        import napari
        from napari.__main__ import _run

        # Strip any extra CLI arguments before handing off to napari.
        sys.argv = sys.argv[:1]

        with napari.gui_qt():
            _run()
    else:
        data_set = DATASETS[dataset_name]

        if isinstance(data_set, str):
            # Import late so it sees our env vars.
            from napari.__main__ import main as napari_main

            print(f"LOADING {dataset_name}: {data_set}")

            # Rebuild argv instead of assigning sys.argv[1], which raised
            # IndexError whenever no extra argument was passed.
            sys.argv = [sys.argv[0], data_set]
            sys.exit(napari_main())

        else:
            # Import late so it sees our env vars.
            import napari

            # The DATASET entry is a factory that creates a viewer.
            viewer_factory = data_set

            print(f"Starting napari with: {dataset_name}")

            # It's a callable function
            with napari.gui_qt():
                viewer = viewer_factory()
                print(viewer._title)
Ejemplo n.º 2
0
def inspect_loader(data_loader):
    """Open an ``Inspector`` viewer for *data_loader* inside napari's Qt loop."""
    with napari.gui_qt():
        Inspector(data_loader).create_viewer()
Ejemplo n.º 3
0
class Operation(Enum):
    """Arithmetic operations offered by the image-arithmetic widget.

    magicgui renders Enum-typed parameters as dropdown menus, so wrapping
    the numpy ufuncs in an Enum gives the user a nice operation picker.
    Each member's ``value`` is the ufunc to apply to the two layers.
    """

    add = numpy.add
    subtract = numpy.subtract
    multiply = numpy.multiply
    divide = numpy.divide


with gui_qt():
    # create a viewer and add a couple image layers
    viewer = Viewer()
    viewer.add_image(numpy.random.rand(20, 20), name="Layer 1")
    viewer.add_image(numpy.random.rand(20, 20), name="Layer 2")

    # use the magic decorator!  This takes a function, and generates a widget instance
    # using the function signature. Note that we aren't returning a napari Image layer,
    # but instead a numpy array which we want napari to interpret as Image data.
    @magicgui(call_button="execute")
    def image_arithmetic(layerA: Image, operation: Operation,
                         layerB: Image) -> ImageData:
        """Add, subtract, multiply, or divide two image layers of equal shape."""
        return operation.value(layerA.data, layerB.data)

    # add our new magicgui widget to the viewer
    # NOTE(review): the statement that actually adds the widget appears to
    # have been cut off in this snippet — only the comment survives.
Ejemplo n.º 4
0
from skimage import data
import napari
import time

# Time a full napari session WITH the startup splash logo.  The timestamp
# is taken inside the `with` so the measurement includes viewer creation
# but prints before the blocking event loop finishes.
t0 = time.time()
with napari.gui_qt(startup_logo=True):
    napari.view_image(data.astronaut(), rgb=True)
    t1 = time.time()
    print(f'splash took {t1 - t0} seconds')

# Same measurement again with the splash logo disabled, for comparison.
t2 = time.time()
with napari.gui_qt(startup_logo=False):
    napari.view_image(data.astronaut(), rgb=True)
    t3 = time.time()
    print(f'no splash took {t3 - t2} seconds')
Ejemplo n.º 5
0
def view(raw_data, surf_stack, surf_projection):
    """Show raw data, surf stack and surf projection together in napari."""
    named_layers = [
        ("Raw data", raw_data),
        ("Surf stack", surf_stack),
        ("Surf projection", surf_projection),
    ]
    with napari.gui_qt():
        v = napari.Viewer(title="Surfcut")
        for layer_name, layer_data in named_layers:
            v.add_image(layer_data, name=layer_name)
    def correct_segment(self, seg_id, qitem):
        """Interactively proof-read one cilia segment in a napari viewer.

        Opens the raw image plus segmentation/mask label layers and binds
        single-key shortcuts (c/b/m/r/h/q) that record a verdict for
        ``seg_id`` in ``self.processed_id_map``.  The result is persisted
        via ``self.save_result`` when the viewer closes, or immediately on
        'q' (which also exits the process).

        Args:
            seg_id: id of the segment under review.
            qitem: tuple ``(raw, cil_seg, cil_mask, cell_seg)``, or None to
                skip this segment entirely.
        """
        if qitem is None:
            return

        print("Processing cilia:", seg_id)
        raw, cil_seg, cil_mask, cell_seg = qitem

        with napari.gui_qt():
            viewer = napari.Viewer()
            viewer.add_image(raw, name='raw')
            viewer.add_labels(cil_seg, name='cilia-segmentation')
            viewer.add_labels(cil_mask, name='cilia-mask')
            # Cell segmentation is optional and may be absent from qitem.
            if cell_seg is not None:
                viewer.add_labels(cell_seg, name='cell-segmentation')

            @viewer.bind_key('c')
            def confirm(viewer):
                print("Confirming the current id", seg_id, "as correct")
                self.processed_id_map[int(seg_id)] = 'correct'

            @viewer.bind_key('b')
            def background(viewer):
                print("Confirming the current id", seg_id, "into background")
                self.processed_id_map[int(seg_id)] = 'background'

            @viewer.bind_key('m')
            def merge(viewer):
                print("Merging the current id", seg_id, "with other cilia")
                # Prompt on the terminal until the user enters an integer id.
                valid_input = False
                while not valid_input:
                    merge_id = input("Please enter the merge id:")
                    try:
                        merge_id = int(merge_id)
                        valid_input = True
                    except ValueError:
                        valid_input = False
                        print("You have entered an invalid input", merge_id,
                              "please try again")
                # Unlike the other verdicts, a merge stores the target id
                # (an int) rather than a status string.
                self.processed_id_map[int(seg_id)] = merge_id

            @viewer.bind_key('r')
            def revisit(viewer):
                print("Marking the current id", seg_id,
                      "to be revisited because something is off")
                self.processed_id_map[int(seg_id)] = 'revisit'

            @viewer.bind_key('h')
            def print_help(viewer):
                # NOTE(review): help text advertises [d] for revisit, but the
                # actual binding above is 'r' — one of the two looks wrong.
                print("[c] - confirm cilia as correct")
                print("[b] - mark cilia as background")
                print("[m] - merge cilia with other cilia id")
                print("[d] - revisit this cilia")
                print("[q] - quit")

            # save progress and sys.exit
            @viewer.bind_key('q')
            def quit(viewer):
                print("Quit correction tool")
                self.save_result(seg_id)
                sys.exit(0)

        # save the results for this segment
        self.save_result(seg_id)
Ejemplo n.º 7
0
def run(
    image,
    registration_directory,
    preview=False,
    volumes=False,
    debug=False,
    num_colors=10,
    brush_size=30,
    alpha=0.8,
    shading="flat",
):
    """Run the manual-segmentation GUI on a registered image.

    Transforms *image* into standard space (if not already done), opens a
    napari viewer with label layers for 'colouring in' regions, and binds
    Ctrl+N (new region), Ctrl+S (save + close) and Ctrl+X (close).  After
    the viewer closes, temporary files are removed (unless ``debug``) and
    saved regions can be previewed in brainrender.

    Args:
        image: filename of the image to transform and segment.
        registration_directory: directory containing the registration output.
        preview: show saved regions in brainrender afterwards.
        volumes: also compute per-brain-area volume summaries on save.
        debug: keep temporary files for inspection.
        num_colors: number of colours for new label layers.
        brush_size: initial brush size for new label layers.
        alpha: transparency passed to brainrender.
        shading: shading mode passed to brainrender.
    """
    paths = Paths(registration_directory, image)
    registration_directory = Path(registration_directory)

    # Inverse-transform the image into standard space only once; reuse the
    # cached result on subsequent runs.
    if not paths.tmp__inverse_transformed_image.exists():
        transform_image_to_standard_space(
            registration_directory,
            image_to_transform_fname=image,
            output_fname=paths.tmp__inverse_transformed_image,
            log_file_path=paths.tmp__inverse_transform_log_path,
            error_file_path=paths.tmp__inverse_transform_error_path,
        )
    else:
        print("Registered image exists, skipping")

    registered_image = prepare_load_nii(paths.tmp__inverse_transformed_image)

    print("\nLoading manual segmentation GUI.\n ")
    print("Please 'colour in' the regions you would like to segment. \n "
          "When you are done, press Ctrl+S to save and exit. \n If you have "
          "used the '--preview' flag, \n the region will be shown in 3D in "
          "brainrender\n for you to inspect.")

    with napari.gui_qt():
        viewer = napari.Viewer(title="Manual segmentation")
        display_channel(
            viewer,
            registration_directory,
            paths.tmp__inverse_transformed_image,
        )

        # Module-level list shared with the keybinding closures below.
        global label_layers
        label_layers = []

        # Reload any previously-saved regions; otherwise start with one
        # fresh, empty label layer.
        label_files = glob(str(paths.regions_directory) + "/*.nii")
        if paths.regions_directory.exists() and label_files != []:
            # NOTE(review): redundant re-initialisation — label_layers was
            # already set to [] just above.
            label_layers = []
            for label_file in label_files:
                label_layers.append(
                    add_existing_label_layers(viewer, label_file))
        else:
            label_layers.append(
                add_new_label_layer(
                    viewer,
                    registered_image,
                    brush_size=brush_size,
                    num_colors=num_colors,
                ))

        @viewer.bind_key("Control-N")
        def add_region(viewer):
            print("\nAdding new region")
            label_layers.append(
                add_new_label_layer(
                    viewer,
                    registered_image,
                    name="new_region",
                    brush_size=brush_size,
                    num_colors=num_colors,
                ))

        @viewer.bind_key("Control-X")
        def close_viewer(viewer):
            print("\nClosing viewer")
            QApplication.closeAllWindows()

        @viewer.bind_key("Control-S")
        def save_analyse_regions(viewer):
            # Start from a clean output directory each save.
            ensure_directory_exists(paths.regions_directory)
            delete_directory_contents(str(paths.regions_directory))

            if volumes:
                annotations = load_any(paths.annotations)
                hemispheres = load_any(paths.hemispheres)
                structures_reference_df = load_structures_as_df(
                    get_structures_path())

                print(
                    f"\nSaving summary volumes to: {paths.regions_directory}")
                for label_layer in label_layers:
                    analyse_region_brain_areas(
                        label_layer,
                        paths.regions_directory,
                        annotations,
                        hemispheres,
                        structures_reference_df,
                    )

            print(f"\nSaving regions to: {paths.regions_directory}")
            for label_layer in label_layers:
                save_regions_to_file(
                    label_layer,
                    paths.regions_directory,
                    paths.downsampled_image,
                )
            close_viewer(viewer)

    if not debug:
        print("Deleting temporary files")
        delete_temp(paths.registration_output_folder, paths)

    # Preview the saved .obj meshes in brainrender, if requested.
    obj_files = glob(str(paths.regions_directory) + "/*.obj")
    if obj_files:
        if preview:
            print("\nPreviewing in brainrender")
            load_regions_into_brainrender(obj_files,
                                          alpha=alpha,
                                          shading=shading)
    else:
        # NOTE(review): this message prints whenever no .obj files exist,
        # even if --preview was never selected — confirm intent.
        print("\n'--preview' selected, but no regions to display")
Ejemplo n.º 8
0
# NOTE(review): this is notebook/IPython residue, not a runnable script —
# it uses the %gui magic and references `Path`, `xr` and `viewer` that are
# presumably defined in earlier cells.
import numpy as np
import napari
import pickle
import calendar
from src.preprocess.utils import select_bounding_box
from src.preprocess.utils import Region

data_path = Path('/Users/tommylees/Downloads/static_embeddings.nc')
ds = xr.open_dataset(data_path)
cluster_ds = ds
da = ds.cluster_5

# Flip latitude so the image displays north-up.
array = da.values[:, ::-1, :]  # time, lat, lon

# IPython magic: start the Qt event loop integration.
%gui qt
napari.gui_qt()
# viewer = napari.Viewer()

# One layer per time step, named by month abbreviation (Jan, Feb, ...).
for i in range(array.shape[0]):
    viewer.add_image(
        array[i, :, :],
        contrast_limits=[0, 4],
        colormap='viridis',
        name=calendar.month_abbr[i+1]
    )


# viewer.layers
# Pull out layers by name — presumably renamed interactively in the GUI
# ('kitui', 'victoria'); verify against the session this came from.
[l.name for l in viewer.layers]
kitui = [l.data for l in viewer.layers if l.name == 'kitui']
victoria = [l.data for l in viewer.layers if l.name == 'victoria']
Ejemplo n.º 9
0
    def __call__(self):
        """Open the proof-reading napari session and bind all editing keys.

        Keys: Ctrl+arrows pan the field of view by ``self.stride``;
        S saves, J/K update boundaries/segmentation, M/N merge/split from
        seed points, Ctrl+B undoes, C clears split seeds, O marks a label
        ok, Alt+Up/Down zoom.  Blocks until the viewer is closed.

        NOTE(review): ``segmentation_key``, ``seeds_merge_key``,
        ``seeds_split_key`` and ``zoom_factor`` are resolved from an
        enclosing/module scope not visible here — confirm where they live.
        """
        with napari.gui_qt():
            viewer = napari.Viewer()
            self.init_layers(viewer)

            @viewer.bind_key('Control-Left')
            def move_left(_viewer):
                """move field of view left"""
                self.move(0, -self.stride)
                self.update(_viewer)

            @viewer.bind_key('Control-Right')
            def move_right(_viewer):
                """move field of view right"""
                self.move(0, self.stride)
                self.update(_viewer)

            @viewer.bind_key('Control-Up')
            def move_up(_viewer):
                """move field of view up"""
                self.move(-self.stride, 0)
                self.update(_viewer)

            @viewer.bind_key('Control-Down')
            def move_down(_viewer):
                """move field of view down"""
                self.move(self.stride, 0)
                self.update(_viewer)

            @viewer.bind_key('S')
            def save_current(_viewer):
                """save edits on h5 and create a training ready stack"""
                self.relabel_seg()
                self.save_h5()
                self.crop_update(_viewer)

            @viewer.bind_key('J')
            def _update_boundaries(_viewer):
                """Update boundaries"""
                self.update_boundary()
                self.crop_update(_viewer)

            @viewer.bind_key('K')
            def _update_segmentation(_viewer):
                """Update Segmentation under cursor"""
                # Cursor position in the segmentation layer's coordinates.
                z, x, y = _viewer.layers[segmentation_key].coordinates
                self.update_segmentation(z, x, y)
                self.update_boundary()
                self.crop_update(_viewer)

            @viewer.bind_key('M')
            def _seeds_merge(_viewer):
                """Merge label from seeds"""
                points = _viewer.layers[seeds_merge_key].data
                self.merge_from_seeds(points)
                self.update_boundary()
                self.crop_update(_viewer)
                # Clear the seed points layer after the merge is applied.
                _viewer.layers[seeds_merge_key].data = np.empty((0, 3))

            @viewer.bind_key('N')
            def _seeds_split(_viewer):
                """Split label from seeds"""
                seeds = _viewer.layers[seeds_split_key].data
                self.split_from_seeds(seeds)
                self.update_boundary()
                self.crop_update(_viewer)

            @viewer.bind_key('Control-B')
            def _undo_seeds_split(_viewer):
                """Undo-Split label from seeds or Undo-Merge label from seeds"""
                self.load_old()
                self.update_boundary()
                self.crop_update(_viewer)

            @viewer.bind_key('C')
            def _clean_split_seeds(_viewer):
                """Clean split seeds layer"""
                self.clean_seeds()
                self.crop_update(_viewer)

            @viewer.bind_key('O')
            def _seg_correct(_viewer):
                # Mark the label under the cursor as verified/ok.
                z, x, y = _viewer.layers[segmentation_key].coordinates
                self.mark_label_ok(z, x, y)
                self.crop_update(_viewer)

            @viewer.bind_key('Alt-Up')
            def zoom_in(_viewer):
                """zoom in"""
                self.xy_size = int(self.xy_size * zoom_factor)
                self.crop_update(_viewer)

            @viewer.bind_key('Alt-Down')
            def zoom_out(_viewer):
                """zoom out"""
                self.xy_size = int(self.xy_size / zoom_factor)
                self.crop_update(_viewer)
Ejemplo n.º 10
0
def showImageNapari(Raw,Image, rgb = False):
    """Open napari showing *Image* (rgb-aware) with *Raw* added on top."""
    with napari.gui_qt():
        napari.view_image(Image, rgb=rgb).add_image(Raw)
Ejemplo n.º 11
0
def view_movie(mov):
    """Display the movie array *mov* in napari (imports napari lazily)."""
    import napari

    with napari.gui_qt():
        napari.view_image(mov)
Ejemplo n.º 12
0
 def create(self, image):
     """Start a napari viewer showing *image*; keep it on ``self.viewer``."""
     with napari.gui_qt():
         viewer = napari.Viewer()
         self.viewer = viewer
         viewer.add_image(image)
Ejemplo n.º 13
0
def napari_process(
    data_channel: mp.Queue,
    initial_data: Dict[str, Dict[str, Any]],
    t_initial: float = None,
    viewer_args: Dict[str, Any] = None,
) -> None:
    """:mod:`multiprocessing.Process` running `napari <https://napari.org>`__

    Args:
        data_channel (:class:`multiprocessing.Queue`):
            queue instance to receive data to view
        initial_data (dict):
            Initial data to be shown by napari. The layers are named according to
            the keys in the dictionary. The associated value needs to be a tuple,
            where the first item is a string indicating the type of the layer and
            the second carries the associated data
        t_initial (float):
            Initial time
        viewer_args (dict):
            Additional arguments passed to the napari viewer
    """
    logger = logging.getLogger(__name__ + ".napari_process")

    # napari is an optional dependency — degrade gracefully if missing.
    try:
        import napari
        from napari.qt import thread_worker
    except ModuleNotFoundError:
        logger.error(
            "The `napari` python module could not be found. This module needs to be "
            "installed to use the interactive tracker.")
        return

    logger.info("Start napari process")

    # ignore keyboard interrupts in this process
    signal.signal(signal.SIGINT, signal.SIG_IGN)

    if viewer_args is None:
        viewer_args = {}

    # start napari Qt GUI
    with napari.gui_qt():

        # create and initialize the viewer
        viewer = napari.Viewer(**viewer_args)
        napari_add_layers(viewer, initial_data)

        # add time if given
        if t_initial is not None:
            from qtpy.QtWidgets import QLabel  # type: ignore

            label = QLabel()
            label.setText(f"Time: {t_initial}")
            viewer.window.add_dock_widget(label)
        else:
            label = None

        def check_signal(msg: Optional[str]):
            """helper function that processes messages by the listener thread"""
            if msg is None:
                return  # do nothing
            elif msg == "close":
                viewer.close()
            else:
                raise RuntimeError(f"Unknown message from listener: {msg}")

        # Values yielded by the worker below are delivered to check_signal
        # on the GUI thread, so viewer.close() is called safely.
        @thread_worker(connect={"yielded": check_signal})
        def update_listener():
            """helper thread that listens to the data_channel"""
            logger.info("Start napari thread to receive data")

            # infinite loop waiting for events in the queue
            while True:
                # get all items from the queue and display the last update
                update_data = None  # nothing to update yet
                while True:
                    time.sleep(0.02)  # read queue with 50 fps
                    try:
                        action, data = data_channel.get(block=False)
                    except queue.Empty:
                        # queue drained — apply whatever update we collected
                        break

                    if action == "close":
                        logger.info("Forced closing of napari...")
                        yield "close"  # signal to napari process to shut down
                        break
                    elif action == "update":
                        update_data = data
                        # continue running until the queue is empty
                    else:
                        logger.warning(f"Unexpected action: {action}")

                # update napari view when there is data
                if update_data is not None:
                    # NOTE(review): f-string below has no placeholders
                    logger.debug(f"Update napari layer...")
                    layer_data, t = update_data
                    if label is not None:
                        label.setText(f"Time: {t}")
                    # NOTE(review): the loop variable shadows `layer_data`
                    # (the dict being iterated) — safe since the iterator is
                    # already created, but a rename would be clearer.
                    for name, layer_data in layer_data.items():
                        viewer.layers[name].data = layer_data["data"]

                yield

        # start worker thread that listens to the data_channel
        update_listener()

    logger.info("Shutting down napari process")
Ejemplo n.º 14
0
def TrackMateLiveTracks(Raw, Seg, Mask, savedir, scale, locationID, RegionID,
                        VolumeID, ID, StartID, Tcalibration):
    """Open a napari session for browsing/saving TrackMate tracks.

    Shows the raw image and segmentation (plus an optional mask boundary),
    a dropdown of track IDs, a matplotlib stats panel, and a save button.
    Selecting a track re-renders it via ``TrackViewer``.

    NOTE(review): several names (``Boxname``, ``on_click``, ``TrackViewer``,
    ``GetBorderMask``) are resolved from module scope not visible here.
    """

    # Broadcast a (T, Y, X) mask across Z so it matches Seg's (T, Z, Y, X).
    if Mask is not None and len(Mask.shape) < len(Seg.shape):
        # T Z Y X
        UpdateMask = np.zeros_like(Seg)
        for i in range(0, UpdateMask.shape[0]):
            for j in range(0, UpdateMask.shape[1]):

                UpdateMask[i, j, :, :] = Mask[i, :, :]
    else:
        UpdateMask = Mask

    Boundary = GetBorderMask(UpdateMask.copy())

    with napari.gui_qt():
        if Raw is not None:

            viewer = napari.view_image(Raw, scale=scale, name='Image')
            Labels = viewer.add_labels(Seg, scale=scale, name='SegImage')
        else:
            viewer = napari.view_image(Seg, scale=scale, name='SegImage')

        if Mask is not None:

            LabelsMask = viewer.add_labels(Boundary, scale=scale, name='Mask')

        trackbox = QComboBox()
        trackbox.addItem(Boxname)

        tracksavebutton = QPushButton('Save Track')
        # NOTE(review): `on_click` is not defined in this scope, and
        # connect() returns a Qt connection object, so `if saveplot:` below
        # is effectively always truthy — verify the intended behaviour.
        saveplot = tracksavebutton.clicked.connect(on_click)

        for i in range(0, len(ID)):
            trackbox.addItem(str(ID[i]))
        try:
            figure = plt.figure(figsize=(5, 5))
            multiplot_widget = FigureCanvas(figure)
            ax = multiplot_widget.figure.subplots(2, 3)
        except:
            # NOTE(review): if figure creation fails, `multiplot_widget`,
            # `ax` and `figure` stay unbound and the lines below raise
            # NameError — the bare except only hides the original error.
            pass
        viewer.window.add_dock_widget(multiplot_widget,
                                      name="TrackStats",
                                      area='right')
        multiplot_widget.figure.tight_layout()
        # Re-render the selected track whenever the dropdown changes.
        trackbox.currentIndexChanged.connect(
            lambda trackid=trackbox: TrackViewer(viewer,
                                                 Raw,
                                                 Seg,
                                                 Mask,
                                                 locationID,
                                                 RegionID,
                                                 VolumeID,
                                                 scale,
                                                 trackbox.currentText(),
                                                 StartID,
                                                 multiplot_widget,
                                                 ax,
                                                 figure,
                                                 savedir,
                                                 saveplot=False,
                                                 Tcalibration=Tcalibration))

        if saveplot:
            # Save-button handler renders the same track with saveplot=True.
            tracksavebutton.clicked.connect(
                lambda trackid=tracksavebutton: TrackViewer(
                    viewer, Raw, Seg, Mask, locationID, RegionID, VolumeID,
                    scale, trackbox.currentText(), StartID, multiplot_widget,
                    ax, figure, savedir, True, Tcalibration))

        viewer.window.add_dock_widget(trackbox, name="TrackID", area='left')
        viewer.window.add_dock_widget(tracksavebutton,
                                      name="Save TrackID",
                                      area='left')
def launch_covid_if_annotation_tool(data_path=None,
                                    annotation_path=None,
                                    saturation_factor=1,
                                    edge_width=1):
    """ Launch the Covid IF annotation tool.

    Based on https://github.com/transformify-plugins/segmentify/blob/master/examples/launch.py

    Args:
        data_path: path to the image data to annotate (optional).
        annotation_path: path to existing annotations; only valid together
            with ``data_path``.
        saturation_factor: saturation used when initializing from file.
        edge_width: edge width used when initializing from file.

    Raises:
        ValueError: if annotations are passed without data.
    """

    with_data = data_path is not None
    with_annotations = annotation_path is not None

    # BUGFIX: the original condition (`with_data and not with_annotations`)
    # was inverted — it rejected the valid data-only case and let
    # annotations-without-data through.  The error message states the real
    # contract: annotations require data.
    if with_annotations and not with_data:
        raise ValueError(
            "If annotations are passed you also need to pass data!")

    with napari.gui_qt():
        viewer = napari.Viewer()

        # the event object will have the following useful things:
        # event.source -> the full viewer.layers object itself
        # event.item -> the specific layer that cause the change
        # event.type -> a string like 'added', 'removed'
        def on_layer_change(event):
            try:
                needs_update = False
                layers = event.source
                layer = event.item

                # If a duplicate of a named image layer was just added,
                # replace the existing one instead of keeping both.
                def replace_image_layer(name, im_layer):
                    im_layers = [ll for ll in layers if name in ll.name]
                    if name in im_layer.name and len(im_layers) > 1:
                        replace_layer(im_layer, layers, name)

                # replace the raw data image layers
                if isinstance(layer, Image) and event.type == 'added':
                    replace_image_layer('raw', layer)
                    replace_image_layer('virus-marker', layer)
                    replace_image_layer('cell-outlines', layer)

                # if we add new labels or new points, we need to replace instead
                # of adding them
                if isinstance(layer, Labels) and event.type == 'added':
                    if len([ll
                            for ll in layers if isinstance(ll, Labels)]) > 1:
                        replace_layer(layer, layers, 'cell-segmentation')

                if isinstance(layer, Points) and event.type == 'added':
                    if len([ll
                            for ll in layers if isinstance(ll, Points)]) > 1:
                        replace_layer(layer, layers, 'infected-vs-control')

                        # select the new points layer
                        layer = viewer.layers['infected-vs-control']
                        viewer.layers.unselect_all()
                        layer.selected = True
                        needs_update = True

                    # modify the new points layer:
                    # set the correct color maps
                    face_color_cycle_map = {
                        0: (1, 1, 1, 1),
                        1: (1, 0, 0, 1),
                        2: (0, 1, 1, 1),
                        3: (1, 1, 0, 1)
                    }

                    viewer.layers[
                        'infected-vs-control'].face_color_cycle_map = face_color_cycle_map
                    viewer.layers['infected-vs-control'].refresh_colors()

                # always modify the points layer to deactivate the buttons we don't need
                if isinstance(event.item, Points):
                    modify_points_layer(viewer)

                if needs_update:
                    update_layers(viewer)

            except AttributeError:
                # Layer events also fire for objects that lack the
                # attributes we inspect; those are deliberately ignored.
                pass

        viewer.layers.events.changed.connect(on_layer_change)
        if with_data:
            initialize_from_file(viewer, data_path, saturation_factor,
                                 edge_width)

        if with_annotations:
            initialize_annotations(viewer, annotation_path)

        # connect the gui elements and modify layer functionality
        connect_to_viewer(viewer)

        if with_annotations:
            update_layers(viewer)
Ejemplo n.º 16
0
def demo(image_clipped):
    """Benchmark Lucy-Richardson deconvolution against SSI deconvolution.

    The input image is normalised, degraded with a simulated 3D microscope
    blur plus Poisson/Gaussian noise, restored with both methods, and each
    restoration is scored against the clean image (PSNR, spectral mutual
    information, mutual information, SSIM). If the module-level
    ``use_napari`` flag is truthy, all images are shown in a napari viewer.

    Parameters
    ----------
    image_clipped : array-like
        Ground-truth image; converted to float32 and normalised to [0, 1]
        in place of the original argument.
    """
    image_clipped = normalise(image_clipped.astype(numpy.float32))
    # Degrade the ground truth: simulated 3D PSF blur, then mixed
    # Poisson/Gaussian noise with salt-and-pepper and 10-bit quantisation.
    blurred_image, psf_kernel = add_microscope_blur_3d(image_clipped)
    noisy_blurred_image = add_poisson_gaussian_noise(blurred_image,
                                                     alpha=0.001,
                                                     sigma=0.1,
                                                     sap=0.01,
                                                     quant_bits=10)

    # Baseline: classic Lucy-Richardson deconvolution, 5 iterations
    # (2/10/20-iteration variants are kept below, commented out).
    lr = ImageTranslatorLRDeconv(psf_kernel=psf_kernel, backend="cupy")
    lr.train(noisy_blurred_image)
    # lr.max_num_iterations=2
    # lr_deconvolved_image_2 = lr.translate(noisy_blurred_image)
    lr.max_num_iterations = 5
    lr_deconvolved_image_5 = lr.translate(noisy_blurred_image)
    # lr.max_num_iterations=10
    # lr_deconvolved_image_10 = lr.translate(noisy_blurred_image)
    # lr.max_num_iterations=20
    # lr_deconvolved_image_20 = lr.translate(noisy_blurred_image)

    # Self-supervised deconvolution: trained on the single noisy/blurred
    # image itself, using a masked-pixel l2 loss and a UNet model.
    it_deconv = SSIDeconvolution(
        max_epochs=3000,
        patience=300,
        batch_size=8,
        learning_rate=0.01,
        normaliser_type="identity",
        psf_kernel=psf_kernel,
        model_class=UNet,
        masking=True,
        masking_density=0.01,
        loss="l2",
    )

    start = time.time()
    it_deconv.train(noisy_blurred_image)
    stop = time.time()
    print(f"Training: elapsed time:  {stop - start} ")

    start = time.time()
    deconvolved_image = it_deconv.translate(noisy_blurred_image)
    stop = time.time()
    print(f"inference: elapsed time:  {stop - start} ")

    # Clip every image to [0, 1] so the metrics are comparable.
    image_clipped = numpy.clip(image_clipped, 0, 1)
    # lr_deconvolved_image_2_clipped = numpy.clip(lr_deconvolved_image_2, 0, 1)
    lr_deconvolved_image_5_clipped = numpy.clip(lr_deconvolved_image_5, 0, 1)
    # lr_deconvolved_image_10_clipped = numpy.clip(lr_deconvolved_image_10, 0, 1)
    # lr_deconvolved_image_20_clipped = numpy.clip(lr_deconvolved_image_20, 0, 1)
    deconvolved_image_clipped = numpy.clip(deconvolved_image, 0, 1)

    # Print a score table: one row per restoration stage.
    columns = ["PSNR", "norm spectral mutual info", "norm mutual info", "SSIM"]
    print_header(columns)
    print_score(
        "blurry image",
        psnr(image_clipped, blurred_image),
        spectral_mutual_information(image_clipped, blurred_image),
        mutual_information(image_clipped, blurred_image),
        ssim(image_clipped, blurred_image),
    )

    print_score(
        "noisy and blurry image",
        psnr(image_clipped, noisy_blurred_image),
        spectral_mutual_information(image_clipped, noisy_blurred_image),
        mutual_information(image_clipped, noisy_blurred_image),
        ssim(image_clipped, noisy_blurred_image),
    )

    # print_score(
    #     "lr deconv (n=2)",
    #     psnr(image_clipped, lr_deconvolved_image_2_clipped),
    #     spectral_mutual_information(image_clipped, lr_deconvolved_image_2_clipped),
    #     mutual_information(image_clipped, lr_deconvolved_image_2_clipped),
    #     ssim(image_clipped, lr_deconvolved_image_2_clipped),
    # )

    print_score(
        "lr deconv (n=5)",
        psnr(image_clipped, lr_deconvolved_image_5_clipped),
        spectral_mutual_information(image_clipped,
                                    lr_deconvolved_image_5_clipped),
        mutual_information(image_clipped, lr_deconvolved_image_5_clipped),
        ssim(image_clipped, lr_deconvolved_image_5_clipped),
    )

    # print_score(
    #     "lr deconv (n=10)",
    #     psnr(image_clipped, lr_deconvolved_image_10_clipped),
    #     spectral_mutual_information(image_clipped, lr_deconvolved_image_10_clipped),
    #     mutual_information(image_clipped, lr_deconvolved_image_10_clipped),
    #     ssim(image_clipped, lr_deconvolved_image_10_clipped),
    # )
    #
    # print_score(
    #     "lr deconv (n=20)",
    #     psnr(image_clipped, lr_deconvolved_image_20_clipped),
    #     spectral_mutual_information(image_clipped, lr_deconvolved_image_20_clipped),
    #     mutual_information(image_clipped, lr_deconvolved_image_20_clipped),
    #     ssim(image_clipped, lr_deconvolved_image_20_clipped),
    # )

    print_score(
        "ssi deconv",
        psnr(image_clipped, deconvolved_image_clipped),
        spectral_mutual_information(image_clipped, deconvolved_image_clipped),
        mutual_information(image_clipped, deconvolved_image_clipped),
        ssim(image_clipped, deconvolved_image_clipped),
    )

    print(
        "NOTE: if you get a bad results for ssi, blame stochastic optimisation and retry..."
    )
    print(
        "      The training is done on the same exact image that we infer on, very few pixels..."
    )
    print("      Training should be more stable given more data...")

    # NOTE(review): use_napari is not defined in this function — presumably a
    # module-level flag; confirm it exists at the top of this file.
    if use_napari:
        with napari.gui_qt():
            viewer = napari.Viewer()
            viewer.add_image(image_clipped, name="image")
            viewer.add_image(blurred_image, name="blurred")
            viewer.add_image(noisy_blurred_image, name="noisy_blurred_image")
            # viewer.add_image(lr_deconvolved_image_2_clipped, name='lr_deconvolved_image_2')
            viewer.add_image(lr_deconvolved_image_5_clipped,
                             name="lr_deconvolved_image_5")
            # viewer.add_image(lr_deconvolved_image_10_clipped, name='lr_deconvolved_image_10')
            # viewer.add_image(lr_deconvolved_image_20_clipped, name='lr_deconvolved_image_20')
            viewer.add_image(deconvolved_image_clipped,
                             name="ssi_deconvolved_image")
def organelle_assignment_module(results_folder,
                                organelle_assignments_filepath,
                                export_name,
                                mem_pred_filepath,
                                raw,
                                mem,
                                result_dtype,
                                export_binary=False,
                                conncomp_on_paintera_export=False,
                                merge_small_segments_on_paintera_export=False,
                                verbose=False):
    """Interactively assign segment ids to organelle categories and export maps.

    Loads the Paintera export from ``<results_folder>/exported_seg.h5``,
    optionally post-processes it, then opens a napari viewer next to a text
    editor sub-process in which the user edits a JSON file mapping organelle
    names to lists of segment labels. Pressing 'u' in napari refreshes the
    layers from the JSON file; 'e' reopens the editor; 'h' prints help.
    After the viewer is closed, one volume per organelle category is written
    to ``<results_folder>/results``.

    Parameters
    ----------
    results_folder : str
        Folder containing 'exported_seg.h5'; outputs go to its 'results'
        subfolder.
    organelle_assignments_filepath : str
        Path of the JSON assignments file (created with a default template
        if missing).
    export_name : str
        Filename prefix for the exported maps.
    mem_pred_filepath : str or None
        When not None, the membrane prediction ``mem`` is shown as a layer.
    raw, mem : array-like
        Raw data and membrane prediction for display.
    result_dtype : numpy dtype or str
        Dtype of the exported label volumes.
    export_binary : bool
        Additionally export a binarised (0/255) version of each map.
    conncomp_on_paintera_export : bool
        Run connected components on the exported segmentation first.
    merge_small_segments_on_paintera_export : bool
        Remove segments smaller than 48 voxels and fill the holes via a
        seeded watershed.
    verbose : bool
        Passed through to the data writer.
    """
    print('\nNapari and editor started in sub-processes.')
    print('\nFill the text file with assignments')

    with open_file(os.path.join(results_folder, 'exported_seg.h5'),
                   mode='r') as f:
        exp_seg = f['data'][:]
    if conncomp_on_paintera_export:
        # Computing connected components
        exp_seg = labelMultiArray(exp_seg.astype('float32'))
    if merge_small_segments_on_paintera_export:
        # Merge small segments
        print('Removing small segments ...')
        # +1 keeps label 0 out of the size filter; the -1 below undoes it
        exp_seg = small_objects_to_zero(exp_seg + 1, 48)
        print('Filling the holes ...')
        print(f'exp_seg.shape = {exp_seg.shape}')
        print(f'exp_seg.dtype = {exp_seg.dtype}')
        # seeded watershed grows the surviving segments into the zeroed-out
        # regions left by the size filter
        exp_seg = watershedsNew(exp_seg.astype('float32'),
                                seeds=exp_seg.astype('uint32'),
                                neighborhood=26)[0] - 1
        print('... done!')
    exp_seg = relabel_consecutive(exp_seg, sort_by_size=True)

    if not os.path.exists(organelle_assignments_filepath):
        # seed the assignments file with an editable default template
        with open(organelle_assignments_filepath, mode='w') as f:
            json.dump(
                dict(
                    CYTO=dict(labels=[0], type='single'),
                    MITO=dict(labels=[], type='multi'),
                    # DMV=dict(labels=[], type='multi'),
                    # ER=dict(labels=[], type='single'),
                    # ENDO=dict(labels=[], type='multi'),
                    # LIPID=dict(labels=[], type='multi'),
                    # NUC=dict(labels=[], type='single'),
                    # EXT=dict(labels=[], type='single')
                ),
                f,
                indent=2)

    all_ids = list(np.unique(exp_seg))

    def _generate_organelle_maps():
        """Build one label volume per organelle category from the JSON file.

        Returns a dict with one entry per category, plus 'MISC' (all
        unassigned ids) and 'SEMANTICS' (one label per category). Returns an
        empty dict when the JSON file cannot be parsed, so the caller can
        keep running while the user fixes the file.
        """
        try:
            # The whole generation is best-effort: the user edits the JSON
            # by hand while napari is open, so it may be invalid at any time.
            with open(organelle_assignments_filepath, mode='r') as f:
                assignments = json.load(f)

            maps = {}
            assigned = []
            for organelle, assignment in assignments.items():
                print('found organelle: {}'.format(organelle))
                maps[organelle] = np.zeros(exp_seg.shape, dtype=exp_seg.dtype)
                val = 1
                for idx in assignment['labels']:
                    maps[organelle][exp_seg == idx] = val
                    # 'multi' categories keep each instance as its own label;
                    # 'single' categories collapse everything to label 1
                    if assignment['type'] == 'multi':
                        val += 1
                    assigned.append(idx)

            # every id not assigned to any category goes into 'MISC'
            unassigned = np.setdiff1d(all_ids, assigned)
            maps['MISC'] = np.zeros(exp_seg.shape, dtype=exp_seg.dtype)
            val = 1
            for idx in unassigned:
                maps['MISC'][exp_seg == idx] = val
                val += 1

            # semantic map: one label per category (alphabetical order)
            map_names = sorted(maps.keys())
            maps['SEMANTICS'] = np.zeros(exp_seg.shape, dtype=exp_seg.dtype)
            for map_idx, map_name in enumerate(map_names):
                maps['SEMANTICS'][maps[map_name] > 0] = map_idx

            return maps
        except Exception:
            # FIX: was a bare 'except:', which would also swallow
            # KeyboardInterrupt / SystemExit; keep the best-effort behaviour
            # but only for ordinary errors (bad JSON, missing keys, ...)
            print(
                'Invalid json syntax!!! Fix the json file, save and update Napari again!'
            )
            return {}

    def _print_help():
        """Print the available napari key bindings."""
        # I don't think we need explicit quit command any more
        # print('            exit / q      -> finish assignments and export')
        print('            update / u    -> updates Napari display')
        print('            editor / e    -> re-opens editor')

    # start the editor in a sub-process
    editor_p = mp.Process(target=_open_editor,
                          args=(organelle_assignments_filepath, ))
    editor_p.start()

    with napari.gui_qt():
        viewer = napari.Viewer()
        # add the initial (static) layers
        viewer.add_image(raw, name='raw')
        if mem_pred_filepath is not None:
            viewer.add_image(mem, name='mem', visible=False)

        # add the initial organelle maps
        organelle_maps = _generate_organelle_maps()
        for name, data in organelle_maps.items():
            is_visible = name == 'MISC'
            viewer.add_labels(data, name=name, visible=is_visible)

        viewer.add_labels(exp_seg,
                          name='from Paintera',
                          visible=True,
                          opacity=0)

        _print_help()

        @viewer.bind_key('h')
        def help(viewer):
            _print_help()

        @viewer.bind_key('u')
        def update(viewer):
            """Re-read the JSON file and refresh / add the organelle layers."""
            print("Updating napari layers from organelle assignments ...")
            new_organelle_maps = _generate_organelle_maps()

            # TODO this does not catch the case where a category is removed
            # yet (the layer will persist); that should be caught and the
            # layer be removed
            layers = viewer.layers
            for name, data in new_organelle_maps.items():
                is_visible = name == 'MISC'
                try:
                    # EAFP: indexing the layer list by name raises KeyError
                    # when no layer of that name exists yet
                    layers[name].data = data
                except KeyError:
                    viewer.add_labels(data, name=name, visible=is_visible)
            print("... done")

        @viewer.bind_key('e')
        def editor(viewer):
            """Restart the external editor sub-process."""
            nonlocal editor_p
            editor_p.terminate()
            editor_p.join()
            editor_p = mp.Process(target=_open_editor,
                                  args=(organelle_assignments_filepath, ))
            editor_p.start()

    # 10. Export organelle maps
    print('Exporting organelle maps ...')
    organelle_maps = _generate_organelle_maps()
    # create the output folder once, before the export loop
    results_dir = os.path.join(results_folder, 'results')
    if not os.path.exists(results_dir):
        os.mkdir(results_dir)
    for map_name, organelle_map in organelle_maps.items():
        # Export labeled result
        organelle_filepath = os.path.join(
            results_dir, export_name + '_{}.h5'.format(map_name))
        _write_data(organelle_filepath,
                    organelle_map.astype(result_dtype),
                    verbose=verbose)

        if export_binary:
            # Export binary result: the truncating uint8 cast makes this
            # 255 wherever the map is non-zero and 0 elsewhere
            organelle_filepath = os.path.join(
                results_dir, export_name + '_{}_bin.h5'.format(map_name))
            binary_map = (
                1 - (1 - organelle_map.astype('float32') /
                     organelle_map.max()).astype('uint8')) * 255
            _write_data(organelle_filepath, binary_map, verbose=verbose)
Ejemplo n.º 18
0
def main():
    """Curate cellfinder cell candidates in a napari viewer.

    Loads cell candidates and their labels from an XML file, overlays them
    as a points layer on the signal images, and binds keys for curation:
    't' toggles the type of the selected points, 'c' confirms it,
    'Control-S' saves the curated cells, and 'Alt-E' extracts training
    cubes from the saved curation and closes the window.

    Curated point indices are accumulated in the module-level
    CURATED_POINTS list; the output file name comes from the module-level
    OUTPUT_NAME constant.
    """
    args = parser().parse_args()
    args = define_pixel_sizes(args)

    # default the output directory to the folder containing the cells XML
    if args.output is None:
        output = Path(args.cells_xml)
        output_directory = output.parent
        print(f"No output directory given, so setting output "
              f"directory to: {output_directory}")
    else:
        output_directory = args.output

    ensure_directory_exists(output_directory)
    output_filename = output_directory / OUTPUT_NAME

    img_paths = get_sorted_file_paths(args.signal_image_paths,
                                      file_extension=".tif")
    cells, labels = get_cell_labels_arrays(args.cells_xml)

    # per-point property used to colour the candidates by label
    properties = {"cell": labels}

    with napari.gui_qt():
        viewer = napari.Viewer(title="Cellfinder cell curation")
        images = magic_imread(img_paths, use_dask=True, stack=True)
        viewer.add_image(images)
        face_color_cycle = ["lightskyblue", "lightgoldenrodyellow"]
        # layer index 1: the candidates (index 0 is the image stack)
        points_layer = viewer.add_points(
            cells,
            properties=properties,
            symbol=args.symbol,
            n_dimensional=True,
            size=args.marker_size,
            face_color="cell",
            face_color_cycle=face_color_cycle,
            name="Cell candidates",
        )

        @viewer.bind_key("t")
        def toggle_point_property(viewer):
            """Toggle point type"""
            selected_points = viewer.layers[1].selected_data
            if selected_points:
                # flip the boolean 'cell' property of every selected point
                selected_properties = viewer.layers[1].properties["cell"][
                    selected_points]
                toggled_properties = np.logical_not(selected_properties)
                viewer.layers[1].properties["cell"][
                    selected_points] = toggled_properties

                # Add curated cells to list
                CURATED_POINTS.extend(selected_points)
                print(f"{len(selected_points)} points "
                      f"toggled and added to the list ")

                # refresh the properties colour
                viewer.layers[1].refresh_colors(update_color_mapping=False)

        @viewer.bind_key("c")
        def confirm_point_property(viewer):
            """Confirm point type"""
            selected_points = viewer.layers[1].selected_data
            if selected_points:
                # Add curated cells to list (label is kept as-is)
                CURATED_POINTS.extend(selected_points)
                print(f"{len(selected_points)} points "
                      f"confirmed and added to the list ")

        @viewer.bind_key("Control-S")
        def save_curation(viewer):
            """Save file"""
            if not CURATED_POINTS:
                print("No cells have been confirmed or toggled, not saving")
            else:
                unique_cells = unique_elements_lists(CURATED_POINTS)
                points = viewer.layers[1].data[unique_cells]
                labels = viewer.layers[1].properties["cell"][unique_cells]
                # shift boolean labels (0/1) to the 1/2 values Cell expects
                labels = labels.astype("int")
                labels = labels + 1

                cells_to_save = []
                for idx, point in enumerate(points):
                    # napari points are (z, y, x); Cell wants (x, y, z)
                    cell = Cell([point[2], point[1], point[0]], labels[idx])
                    cells_to_save.append(cell)

                print(f"Saving results to: {output_filename}")
                save_cells(cells_to_save, output_filename)

        @viewer.bind_key("Alt-E")
        def start_cube_extraction(viewer):
            """Extract cubes for training"""

            # curation must be saved first: extraction reads the saved file
            if not output_filename.exists():
                print("No curation results have been saved. "
                      "Please save before extracting cubes")
            else:
                print(f"Saving cubes to: {output_directory}")
                run_extraction(
                    output_filename,
                    output_directory,
                    args.signal_image_paths,
                    args.background_image_paths,
                    args.cube_depth,
                    args.cube_width,
                    args.cube_height,
                    args.x_pixel_um,
                    args.y_pixel_um,
                    args.z_pixel_um,
                    args.x_pixel_um_network,
                    args.y_pixel_um_network,
                    args.z_pixel_um_network,
                    args.max_ram,
                    args.n_free_cpus,
                    args.save_empty_cubes,
                )

                print("Saving yaml file to use for training")
                save_yaml_file(output_directory)

                print("Closing window")
                QApplication.closeAllWindows()
                print("Finished! You may now annotate more "
                      "datasets, or go straight to training")
Ejemplo n.º 19
0
def load_arrays(images, names):
    """Show each array as a named napari image layer.

    Axes 0 and 2 of every array are swapped before display. One layer is
    added per (image, name) pair; the Qt event loop blocks until the
    viewer is closed.
    """
    with napari.gui_qt():
        viewer = napari.Viewer(title="track viewer")
        for layer_name, stack in zip(names, images):
            viewer.add_image(np.swapaxes(stack, 2, 0), name=layer_name)
Ejemplo n.º 20
0
def segmentation_correction(raw_path,
                            raw_root,
                            raw_scale,
                            ws_path,
                            ws_root,
                            ws_scale,
                            node_label_path,
                            node_label_key,
                            save_path,
                            save_key,
                            n_scales,
                            seg_scale,
                            seg_scale_factor,
                            graph=None,
                            weights=None):
    """Interactively proof-read a fragment-to-segment labeling in napari.

    Loads multi-scale raw data and watershed fragments, derives the current
    segmentation from ``node_labels`` (fragment id -> segment id), and binds
    editing actions to keys:

    - 'Shift-D': split the fragment under the cursor into a new segment
    - 'Shift-A': merge the segment under the cursor into the selected label
    - 'q':       rebuild the mask layer for the selected segment
    - 'w':       graph watershed inside the mask, seeded from the seeds
                 layer (requires ``graph`` and ``weights``)
    - 'u':       undo the last edit
    - 's':       save the node labels to ``save_path``/``save_key``
    - 'x':       recompute the segments layer from the node labels

    ``graph``/``weights`` are optional; without them the seeds layer and the
    watershed action are not available.
    """

    ds_raw = _load_multiscale_ds(raw_path, raw_root, raw_scale, n_scales)

    ds_ws = _load_multiscale_ds(ws_path, ws_root, ws_scale, n_scales)
    # assert ds_ws[0].shape == ds_raw[0].shape

    node_labels = _load_node_labes(node_label_path, node_label_key, save_path,
                                   save_key)
    # next free segment id for splits / watershed results
    next_id = int(node_labels.max()) + 1
    # segment id the mask layer currently shows (None = no mask yet)
    mask_id = None

    # stack of node-label snapshots, one per edit, for undo
    node_label_history = []
    # fragments at the working scale, fully loaded into memory
    ws_base = ds_ws[seg_scale][:]

    with napari.gui_qt():

        def _seg_from_labels(node_labels):
            # project the fragment->segment assignment onto the fragments
            seg = merge_seg_from_node_labels(ws_base, node_labels)
            return seg

        viewer = napari.Viewer()
        viewer.add_image(ds_raw, name='raw')
        viewer.add_labels(ds_ws, name='fragments', visible=False)
        seg = _seg_from_labels(node_labels)
        viewer.add_labels(seg, name='segments', scale=seg_scale_factor)
        viewer.add_labels(np.zeros_like(ws_base),
                          scale=seg_scale_factor,
                          name='mask')

        # the seeds layer and graph-watershed support need the graph
        if graph is not None:
            assert weights is not None
            assert len(weights) == graph.numberOfEdges
            uv_ids = graph.uvIds()
            viewer.add_labels(np.zeros_like(ws_base),
                              scale=seg_scale_factor,
                              name='seeds')
            # cached sub-graph state, rebuilt when mask_id changes
            ws_id = None
            sub_graph, sub_weights = None, None
            sub_nodes, sub_edges = None, None
            mapping = None

        # split of fragment from segment
        @viewer.bind_key('Shift-D')
        def split(viewer):
            nonlocal next_id
            nonlocal node_labels
            nonlocal node_label_history

            position = _get_cursor_position(viewer, 'fragments')
            if position is None:
                print("No layer was selected, aborting split")
                return

            # get the segmentation value under the cursor
            frag_id = viewer.layers['fragments'].data[0][position]

            if frag_id == 0:
                print("Cannot split background label, aborting split")
                return

            seg_id = node_labels[frag_id]
            print("Splitting fragment", frag_id, "from segment", seg_id,
                  "and assigning segment id", next_id)
            # snapshot for undo before mutating
            node_label_history.append(node_labels.copy())

            node_labels[frag_id] = next_id
            next_id += 1
            seg = _seg_from_labels(node_labels)
            viewer.layers['segments'].data = seg

            print("split done")

        # merge two segments
        @viewer.bind_key('Shift-A')
        def merge(viewer):
            nonlocal node_labels
            nonlocal node_label_history

            position = _get_cursor_position(viewer, 'segments')
            if position is None:
                print("No layer was selected, aborting detach")
                return

            # get the segmentation value under the cursor
            seg_id1 = viewer.layers['segments'].data[position]
            if seg_id1 == 0:
                print("Cannot merge background label, aborting merge")
                return

            # get the selected id in the merged seg layer
            seg_id2 = viewer.layers['segments'].selected_label
            if seg_id2 == 0:
                print("Cannot merge into background value")
                return

            # snapshot for undo before mutating
            node_label_history.append(node_labels.copy())
            print("Merging id", seg_id1, "into id", seg_id2)
            node_labels[node_labels == seg_id1] = seg_id2
            seg = _seg_from_labels(node_labels)
            viewer.layers['segments'].data = seg

            print("Merge done")

        # # toggle hidden mode for the selected segment
        # @viewer.bind_key()
        # def toggle_hidden(viewer):
        #     pass

        # # toggle visibility for hidden segments
        # @viewer.bind_key()
        # def toggle_view_hidden(viewer):
        #     pass

        @viewer.bind_key('q')
        def update_mask(viewer):
            """Rebuild the mask layer for the currently selected segment."""
            nonlocal mask_id
            seg_id = viewer.layers['segments'].selected_label
            # nothing to do if the selection did not change
            if seg_id == mask_id:
                return
            mask_id = seg_id

            print("Updating mask for", mask_id)
            mask = (viewer.layers['segments'].data == mask_id).astype(
                viewer.layers['mask'].data.dtype)
            viewer.layers['mask'].data = mask
            # clear stale seeds that belonged to the previous mask
            if 'seeds' in viewer.layers:
                viewer.layers['seeds'].data = np.zeros_like(ws_base)

        @viewer.bind_key('w')
        def watershed(viewer):
            """Run a seeded graph watershed inside the current mask."""
            nonlocal ws_id, next_id
            nonlocal node_labels
            nonlocal node_label_history
            nonlocal sub_nodes, sub_edges
            nonlocal sub_graph, sub_weights
            nonlocal mapping

            if mask_id is None:
                print("Need to select segment to run watershed")
                return

            # (re-)build the sub-graph only when the masked segment changed
            if ws_id != mask_id or sub_graph is None:
                print("Computing sub-graph for", mask_id, " ...")
                sub_nodes = np.where(
                    node_labels == mask_id)[0].astype('uint64')
                sub_edges, _ = graph.extractSubgraphFromNodes(
                    sub_nodes, allowInvalidNodes=True)
                sub_weights = weights[sub_edges]

                # relabel the sub-nodes consecutively so they index the
                # sub-graph directly; 'mapping' translates global -> local
                nodes_relabeled, max_id, mapping = vigra.analysis.relabelConsecutive(
                    sub_nodes, start_label=0, keep_zeros=False)
                sub_uvs = uv_ids[sub_edges]
                sub_uvs = nt.takeDict(mapping, sub_uvs)

                n_nodes = max_id + 1
                sub_graph = nifty.graph.undirectedGraph(n_nodes)
                sub_graph.insertEdges(sub_uvs)
                ws_id = mask_id

            # restrict the painted seeds to the mask, then translate them
            # from pixels to sub-graph nodes
            mask = viewer.layers['mask'].data
            seeds = viewer.layers['seeds'].data
            seeds[np.logical_not(mask)] = 0
            seed_ids = np.unique(seeds)[1:]

            seed_nodes = np.zeros(sub_graph.numberOfNodes, dtype='uint64')
            for seed_id in seed_ids:
                seeded = np.unique(ws_base[seeds == seed_id])
                if seeded[0] == 0:
                    seeded = seeded[1:]
                seeded = nt.takeDict(mapping, seeded)
                seed_nodes[seeded] = seed_id

            print("Computing graph watershed")
            sub_labels = nifty.graph.edgeWeightedWatershedsSegmentation(
                sub_graph, seed_nodes, sub_weights)
            # snapshot for undo before mutating
            node_label_history.append(node_labels.copy())

            # offset the watershed result so the new ids do not collide
            # with existing segment ids
            node_labels[sub_nodes] = sub_labels + (next_id - 1)

            # show the watershed result in the mask layer
            mask_node_labels = np.zeros_like(node_labels)
            mask_node_labels[sub_nodes] = sub_labels
            mask = _seg_from_labels(mask_node_labels)
            viewer.layers['mask'].data = mask

            # TODO should also update the seg, but for now skip this to speed this up
            # seg = _seg_from_labels(node_labels)
            # viewer.layers['segments'].data = seg

            next_id = int(node_labels.max()) + 1

        # # undo the last split / merge action
        @viewer.bind_key('u')
        def undo(viewer):
            nonlocal node_labels
            nonlocal node_label_history
            if len(node_label_history) == 0:
                return
            print("Undo last action")
            node_labels = node_label_history.pop()
            seg = _seg_from_labels(node_labels)
            viewer.layers['segments'].data = seg

        # save the current node labeling to disc
        @viewer.bind_key('s')
        def save_labels(viewer):
            print("saving node labels")
            with open_file(save_path, 'a') as f:
                ds = f.require_dataset(save_key,
                                       shape=node_labels.shape,
                                       chunks=node_labels.shape,
                                       compression='gzip',
                                       dtype=node_labels.dtype)
                ds[:] = node_labels

        @viewer.bind_key('x')
        def update_seg(viewer):
            # recompute the segments layer from the current node labels
            # (useful after 'w', which skips this for speed)
            seg = _seg_from_labels(node_labels)
            viewer.layers['segments'].data = seg
def check_exported(paintera_path, old_assignment_key, assignment_key,
                   table_path, table_key, scale_factor, raw_path, raw_key,
                   ws_path, ws_key, check_ids):
    """Visually compare old vs. new exported segment labels in napari.

    For each segment id in ``check_ids``, loads the raw data and the
    watershed fragments restricted to the segment's bounding box, maps the
    fragments once through the old and once through the new fragment->segment
    assignment, and shows raw / mask / old / new as napari layers (one
    blocking viewer per id).

    Parameters
    ----------
    paintera_path : str
        File holding both assignment tables, each stored as (2, n_fragments)
        and transposed to (n_fragments, 2) = (fragment_id, segment_id) rows.
    old_assignment_key, assignment_key : str
        Dataset keys of the old and new assignment tables.
    table_path, table_key : str
        Location of the bounding-box table.
    scale_factor : scalar or sequence
        Scale applied to the bounding boxes.
    raw_path, raw_key : str
        Location of the raw data.
    ws_path, ws_key : str
        Location of the watershed fragments (a label multiset).
    check_ids : iterable of int
        Old segment ids to inspect.
    """
    print("Start to check exported node labels")
    # imported lazily so the module does not require a GUI environment
    import napari
    import nifty.tools as nt

    with open_file(paintera_path, 'r') as f:
        ds = f[old_assignment_key]
        ds.n_threads = 8
        old_assignments = ds[:].T

        ds = f[assignment_key]
        ds.n_threads = 8
        assignments = ds[:].T

    fragment_ids, segment_ids = assignments[:, 0], assignments[:, 1]
    old_fragment_ids = old_assignments[:, 0]
    old_segment_ids = old_assignments[:, 1]
    # both tables must describe the same fragments, in the same order
    assert np.array_equal(fragment_ids, old_fragment_ids)

    print("Loading bounding boxes ...")
    bounding_boxes = get_bounding_boxes(table_path, table_key, scale_factor)
    print("... done")
    with open_file(raw_path, 'r') as fraw, open_file(ws_path, 'r') as fws:

        ds_raw = fraw[raw_key]
        # FIX: was 'n_thread' (typo) — the attribute is 'n_threads', as used
        # for the assignment datasets above, so the setting had no effect
        ds_raw.n_threads = 8

        ds_ws = fws[ws_key]
        ds_ws.n_threads = 8
        ds_ws = LabelMultisetWrapper(ds_ws)

        for seg_id in check_ids:
            print("Check object", seg_id)
            bb = bounding_boxes[seg_id]
            print("Within bounding box", bb)

            raw = ds_raw[bb]
            ws = ds_ws[bb]

            # keep only the fragments belonging to this (old) segment
            id_mask = old_segment_ids == seg_id
            ws_ids = fragment_ids[id_mask]
            seg_mask = np.isin(ws, ws_ids)
            ws[~seg_mask] = 0

            # relabel the fragments with the old segment assignment
            ids_old = old_segment_ids[id_mask]
            dict_old = {wid: oid for wid, oid in zip(ws_ids, ids_old)}
            dict_old[0] = 0
            seg_old = nt.takeDict(dict_old, ws)

            # relabel the fragments with the new segment assignment
            ids_new = segment_ids[id_mask]
            dict_new = {wid: oid for wid, oid in zip(ws_ids, ids_new)}
            dict_new[0] = 0
            seg_new = nt.takeDict(dict_new, ws)

            with napari.gui_qt():
                viewer = napari.Viewer()
                viewer.add_image(raw, name='raw')
                viewer.add_labels(seg_mask, name='seg-mask')
                viewer.add_labels(seg_old, name='old-seg')
                viewer.add_labels(seg_new, name='new-seg')
Ejemplo n.º 22
0
def view_volume2(imgvol, name=""):
    """Open a blocking napari viewer showing *imgvol* as one image layer."""
    with napari.gui_qt():
        napari.Viewer().add_image(imgvol, name=name)
Ejemplo n.º 23
0
def run(**kwargs):
    """Launch napari and install the v2 plugin into a fresh viewer.

    All keyword arguments are forwarded to ``build_plugin_v2``; the Qt
    event loop blocks until the viewer is closed.
    """
    with napari.gui_qt():
        build_plugin_v2(napari.Viewer(), **kwargs)
Ejemplo n.º 24
0
def validate_upload(ff, input_root):
    """Visually validate the infected-cell annotations uploaded for one file.

    Loads the serum / marker channels from the matching raw input file, the
    cell segmentation and per-cell infection labels from the annotation file
    *ff*, prints label statistics and opens a napari viewer for manual
    inspection.

    Parameters
    ----------
    ff : str
        Path to the annotated h5 file (name contains '_annotations').
    input_root : str
        Folder containing the matching raw input h5 file.
    """
    file_name = os.path.split(ff)[1].replace('_annotations', '')
    in_file = os.path.join(input_root, file_name)
    assert os.path.exists(in_file), in_file

    # label semantics: 0 = unlabeled, 1 = infected, 2 = control, 3 = uncertain
    exp_labels = np.array([0, 1, 2, 3])

    with h5py.File(in_file, 'r') as f:
        serum = read_image(f, 'serum_IgG')
        marker = read_image(f, 'marker')

    with h5py.File(ff, 'r') as f:
        seg = read_image(f, 'cell_segmentation')
        if not has_table(f, 'infected_cell_labels'):
            warnings.warn(f"{file_name} does not have labels!")
            return
        (label_ids, centroids,
         _, infected_cell_labels) = get_segmentation_data(f, seg, edge_width=1)

    # -1 excludes the background entry from the cell count
    n_labels = len(infected_cell_labels) - 1
    unique_labels = np.unique(infected_cell_labels)
    if not np.array_equal(unique_labels, exp_labels):
        print("Found unexpected labels for", file_name)
        print(unique_labels)

    print("Check annotations for file:", file_name)

    # print the fraction of unlabeled / infected / control / uncertain cells
    # NOTE(review): the printed value is a fraction, not a percentage, despite
    # the trailing "%" - kept as-is to preserve the exact output format
    for label_value, label_name in ((0, "Unlabeled"), (1, "infected"),
                                    (2, "control"), (3, "uncertain")):
        n_this = (infected_cell_labels == label_value).sum()
        if label_value == 0:
            n_this -= 1  # don't count the background entry as unlabeled
        frac = float(n_this) / n_labels
        print(f"{label_name}:", n_this, "/", n_labels, "=", frac, "%")

    # make label mask: paint each cell with its annotation class;
    # unlabeled cells (label 0) get display value 4, background stays 0
    label_mask = np.zeros_like(seg)
    for label_value, mask_value in ((1, 1), (2, 2), (3, 3), (0, 4)):
        cell_ids = label_ids[infected_cell_labels == label_value]
        label_mask[np.isin(seg, cell_ids)] = mask_value
    label_mask[seg == 0] = 0

    # NOTE(review): this overwrites the centroids returned by
    # get_segmentation_data above - confirm which of the two is intended
    centroids = get_centroids(seg)
    ckwargs = get_centroid_kwargs(centroids, infected_cell_labels)

    with napari.gui_qt():
        viewer = napari.Viewer(title=file_name)
        viewer.add_image(serum)
        viewer.add_image(marker)
        viewer.add_labels(seg, visible=False)
        viewer.add_labels(label_mask, visible=True)
        # FIXME something with the points is weird ...
        viewer.add_points(centroids, visible=False, **ckwargs)
Ejemplo n.º 25
0
    def view_napari(self, rgb: bool = False, **kwargs):
        """
        If installed, load the image in a napari viewer.

        Parameters
        ----------
        rgb: bool
            Is the image RGB / RGBA
            Default: False (is not RGB)
        **kwargs
            Extra arguments passed down to the viewer

        Raises
        ------
        ModuleNotFoundError
            If napari is not installed.
        """
        # Guard only the import itself: a ModuleNotFoundError raised by any
        # later code must not be misreported as napari being missing.
        try:
            import napari
        except ModuleNotFoundError as e:
            raise ModuleNotFoundError(
                f"'napari' has not been installed. To use this function install napari with either: "
                f"'pip install napari' or 'pip install aicsimageio[interactive]'"
            ) from e

        # Construct getitem operations tuple to select down the data in the filled dimensions
        ops = []
        selected_dims = []
        for dim in self.dims:
            # size-1 dimensions are squeezed out by indexing with 0
            if self.size(dim)[0] == 1:
                ops.append(0)
            else:
                ops.append(slice(None, None, None))
                selected_dims.append(dim)

        # Actually select the down
        data = self.dask_data[tuple(ops)]

        # Convert selected_dims to string
        dims = "".join(selected_dims)

        # Create name for window
        if isinstance(self.reader, ArrayLikeReader):
            title = f"napari: {self.dask_data.shape}"
        else:
            title = f"napari: {self.reader._file.name}"

        # Handle RGB entirely differently
        if rgb:
            # Swap channel to the last dimension (napari expects RGB(A) last)
            new_dims = f"{dims.replace(Dimensions.Channel, '')}{Dimensions.Channel}"
            data = transforms.transpose_to_dims(data=data,
                                                given_dims=dims,
                                                return_dims=new_dims)

            # Run napari
            with napari.gui_qt():
                napari.view_image(
                    data,
                    is_pyramid=False,
                    ndisplay=3 if Dimensions.SpatialZ in dims else 2,
                    title=title,
                    axis_labels=dims.replace(Dimensions.Channel, ""),
                    rgb=rgb,
                    **kwargs)

        # Handle all other images besides RGB not requested
        else:
            # Channel axis
            c_axis = dims.index(
                Dimensions.Channel) if Dimensions.Channel in dims else None

            # Hide layers by default when there are many channels (> 3)
            if c_axis is not None:
                visible = data.shape[c_axis] <= 3
            else:
                visible = True

            # Drop channel from dims string
            dims = dims.replace(Dimensions.Channel,
                                "") if Dimensions.Channel in dims else dims

            # Run napari
            with napari.gui_qt():
                napari.view_image(
                    data,
                    is_pyramid=False,
                    ndisplay=3 if Dimensions.SpatialZ in dims else 2,
                    channel_axis=c_axis,
                    axis_labels=dims,
                    title=title,
                    visible=visible,
                    **kwargs)
Ejemplo n.º 26
0
def view(array):
    """Display *array* in a napari viewer window titled "Surfcut"."""
    with napari.gui_qt():
        napari.Viewer(title="Surfcut").add_image(array)
Ejemplo n.º 27
0
def demo(image_clipped):
    """Compare Lucy-Richardson and SSI deconvolution on a degraded image.

    The input is normalised, degraded with a simulated 2D microscope blur
    plus Poisson-Gaussian noise, then deconvolved with classical
    Lucy-Richardson (2, 5, 10 and 20 iterations) and with SSIDeconvolution.
    Quality scores (PSNR, mutual information, SSIM) are printed and all
    results are displayed in napari.

    Parameters
    ----------
    image_clipped : numpy.ndarray
        Ground-truth image; converted to float32 and normalised in place.
    """
    image_clipped = normalise(image_clipped.astype(numpy.float32))
    blurred_image, psf_kernel = add_microscope_blur_2d(image_clipped)
    noisy_blurred_image = add_poisson_gaussian_noise(blurred_image,
                                                     alpha=0.001,
                                                     sigma=0.1,
                                                     sap=0.01,
                                                     quant_bits=10)

    # Classical Lucy-Richardson baselines at increasing iteration counts
    lr = ImageTranslatorLRDeconv(psf_kernel=psf_kernel, backend="cupy")
    lr.train(noisy_blurred_image)
    lr.max_num_iterations = 2
    lr_deconvolved_image_2 = lr.translate(noisy_blurred_image)
    lr.max_num_iterations = 5
    lr_deconvolved_image_5 = lr.translate(noisy_blurred_image)
    lr.max_num_iterations = 10
    lr_deconvolved_image_10 = lr.translate(noisy_blurred_image)
    lr.max_num_iterations = 20
    lr_deconvolved_image_20 = lr.translate(noisy_blurred_image)

    it_deconv = SSIDeconvolution(
        max_epochs=3000,
        patience=300,
        batch_size=8,
        learning_rate=0.01,
        normaliser_type='identity',
        psf_kernel=psf_kernel,
        model_class=UNet,
        masking=True,
        masking_density=0.01,
        loss='l2',
    )

    start = time.time()
    it_deconv.train(noisy_blurred_image)
    stop = time.time()
    print(f"Training: elapsed time:  {stop - start} ")

    start = time.time()
    deconvolved_image = it_deconv.translate(noisy_blurred_image)
    stop = time.time()
    print(f"inference: elapsed time:  {stop - start} ")

    # Clip everything to [0, 1] before computing quality metrics
    image_clipped = numpy.clip(image_clipped, 0, 1)
    lr_deconvolved_image_2_clipped = numpy.clip(lr_deconvolved_image_2, 0, 1)
    lr_deconvolved_image_5_clipped = numpy.clip(lr_deconvolved_image_5, 0, 1)
    lr_deconvolved_image_10_clipped = numpy.clip(lr_deconvolved_image_10, 0, 1)
    lr_deconvolved_image_20_clipped = numpy.clip(lr_deconvolved_image_20, 0, 1)
    deconvolved_image_clipped = numpy.clip(deconvolved_image, 0, 1)

    print(
        "Below in order: PSNR, norm spectral mutual info, norm mutual info, SSIM: "
    )
    printscore(
        "blurry image          :   ",
        psnr(image_clipped, blurred_image),
        spectral_mutual_information(image_clipped, blurred_image),
        mutual_information(image_clipped, blurred_image),
        ssim(image_clipped, blurred_image),
    )

    printscore(
        "noisy and blurry image:   ",
        psnr(image_clipped, noisy_blurred_image),
        spectral_mutual_information(image_clipped, noisy_blurred_image),
        mutual_information(image_clipped, noisy_blurred_image),
        ssim(image_clipped, noisy_blurred_image),
    )

    printscore(
        "lr deconv (n=2)       :    ",
        psnr(image_clipped, lr_deconvolved_image_2_clipped),
        spectral_mutual_information(image_clipped,
                                    lr_deconvolved_image_2_clipped),
        mutual_information(image_clipped, lr_deconvolved_image_2_clipped),
        ssim(image_clipped, lr_deconvolved_image_2_clipped),
    )

    printscore(
        "lr deconv (n=5)       :    ",
        psnr(image_clipped, lr_deconvolved_image_5_clipped),
        spectral_mutual_information(image_clipped,
                                    lr_deconvolved_image_5_clipped),
        mutual_information(image_clipped, lr_deconvolved_image_5_clipped),
        ssim(image_clipped, lr_deconvolved_image_5_clipped),
    )

    printscore(
        "lr deconv (n=10)      :    ",
        psnr(image_clipped, lr_deconvolved_image_10_clipped),
        spectral_mutual_information(image_clipped,
                                    lr_deconvolved_image_10_clipped),
        mutual_information(image_clipped, lr_deconvolved_image_10_clipped),
        ssim(image_clipped, lr_deconvolved_image_10_clipped),
    )

    printscore(
        "lr deconv (n=20)      :    ",
        psnr(image_clipped, lr_deconvolved_image_20_clipped),
        spectral_mutual_information(image_clipped,
                                    lr_deconvolved_image_20_clipped),
        mutual_information(image_clipped, lr_deconvolved_image_20_clipped),
        ssim(image_clipped, lr_deconvolved_image_20_clipped),
    )

    printscore(
        "ssi deconv            : ",
        psnr(image_clipped, deconvolved_image_clipped),
        spectral_mutual_information(image_clipped, deconvolved_image_clipped),
        mutual_information(image_clipped, deconvolved_image_clipped),
        ssim(image_clipped, deconvolved_image_clipped),
    )

    print(
        "NOTE: if you get a bad results for ssi, blame stochastic optimisation and retry..."
    )
    print(
        "      The training is done on the same exact image that we infer on, very few pixels..."
    )
    print("      Training should be more stable given more data...")

    with napari.gui_qt():
        viewer = napari.Viewer()
        # BUG FIX: the original referenced the undefined name `image`;
        # the ground truth to display is `image_clipped`.
        viewer.add_image(image_clipped, name='image')
        viewer.add_image(blurred_image, name='blurred')
        viewer.add_image(noisy_blurred_image, name='noisy_blurred_image')
        viewer.add_image(lr_deconvolved_image_2_clipped,
                         name='lr_deconvolved_image_2')
        viewer.add_image(lr_deconvolved_image_5_clipped,
                         name='lr_deconvolved_image_5')
        viewer.add_image(lr_deconvolved_image_10_clipped,
                         name='lr_deconvolved_image_10')
        viewer.add_image(lr_deconvolved_image_20_clipped,
                         name='lr_deconvolved_image_20')
        viewer.add_image(deconvolved_image_clipped,
                         name='ssi_deconvolved_image')
Ejemplo n.º 28
0
def demo(image):
    """Train and evaluate a CNN deconvolution model on a degraded image stack.

    Each plane of *image* is degraded with simulated 2D microscope blur and
    Poisson-Gaussian noise (cached to disk on the first run), a
    PTCNNDeconvolution model is trained on the degraded stack, and the
    original, degraded and deconvolved images are displayed in napari.
    The trained model weights are saved to disk.

    NOTE(review): the cache and model output paths are hard-coded to a
    specific local machine - adjust before running elsewhere.
    """
    # Limit to the first 512 planes to bound training time / memory
    image = image[0:512]

    # PSF taken from the first plane only; all planes share the same kernel
    _, psf_kernel = add_microscope_blur_2d(image[0])

    def degrade(image):
        # Blur, then add Poisson-Gaussian noise; fix_seed=False means every
        # plane (and every run) gets a different noise realisation
        blurred_image = add_microscope_blur_2d(image)[0]
        noisy_blurred_image = add_poisson_gaussian_noise(blurred_image,
                                                         alpha=0.001,
                                                         sigma=0.1,
                                                         sap=0.0001,
                                                         quant_bits=10,
                                                         fix_seed=False)
        return noisy_blurred_image

    # Degrading the whole stack is expensive: cache the result on disk
    degraded_stack_filepath = "/media/royer/data1/aydin_datasets/__benchmark_datasets/_DIV2K_train_HR/div2k_degraded.tiff"
    if not exists(degraded_stack_filepath):
        noisy_blurred_image = numpy.stack([degrade(plane) for plane in image])
        imwrite(degraded_stack_filepath, noisy_blurred_image)
    else:
        noisy_blurred_image = imread(degraded_stack_filepath)

    import napari

    # Show the training data before starting the (long) training run
    with napari.gui_qt():
        viewer = napari.Viewer()
        viewer.add_image(image, name='image')
        viewer.add_image(noisy_blurred_image, name='noisy_blurred_image')
        viewer.add_image(psf_kernel, name='psf_kernel')

    it_deconv = PTCNNDeconvolution(max_epochs=2000,
                                   patience=64,
                                   batch_size=1,
                                   learning_rate=0.01,
                                   normaliser_type='identity',
                                   psf_kernel=psf_kernel,
                                   model_class=UNet,
                                   masking=True,
                                   masking_density=0.05,
                                   loss='l1')

    # First axis is the plane/batch dimension, the other two are spatial
    batch_dim = (True, False, False)

    start = time.time()
    it_deconv.train(noisy_blurred_image, batch_dims=batch_dim)
    stop = time.time()
    print(f"Training: elapsed time:  {stop - start} ")

    start = time.time()
    deconvolved_image = it_deconv.translate(noisy_blurred_image,
                                            batch_dims=batch_dim)
    stop = time.time()
    print(f"inference: elapsed time:  {stop - start} ")

    # Clip to [0, 1] for display
    image = numpy.clip(image, 0, 1)
    deconvolved_image = numpy.clip(deconvolved_image, 0, 1)

    # Persist the trained model weights
    torch.save(
        it_deconv.model.state_dict(),
        "/media/royer/data1/aydin_datasets/__benchmark_datasets/_DIV2K_train_HR/div2k.unet.ptm"
    )

    import napari

    with napari.gui_qt():
        viewer = napari.Viewer()
        viewer.add_image(image, name='image')
        viewer.add_image(noisy_blurred_image, name='noisy_blurred_image')
        viewer.add_image(deconvolved_image, name='ssi_deconvolved_image')
Ejemplo n.º 29
0
"""
Displays an 100GB zarr file of lattice light sheet data
"""

import numpy as np
import napari
import dask.array as da
from skimage import filters

file_name = 'data/LLSM/AOLLSM_m4_560nm.zarr'
data = da.from_zarr(file_name)

with napari.gui_qt():
    viewer = napari.Viewer(axis_labels='tzyx')
    viewer.add_image(data,
                     name='AOLLSM_m4_560nm',
                     multiscale=False,
                     scale=[1, 3, 1, 1],
                     contrast_limits=[0, 150_000],
                     colormap='magma')
    viewer.add_image(data.map_blocks(filters.sobel),
                     name='sobel',
                     scale=[1, 3, 1, 1],
                     contrast_limits=[0, 10_000],
                     visible=False)
Ejemplo n.º 30
0
def demo(image_clipped):
    """Compare Lucy-Richardson and CNN (SSI-style) deconvolution.

    The input is normalised, degraded with a simulated 2D microscope blur
    plus Poisson-Gaussian noise, then deconvolved with classical
    Lucy-Richardson (30 iterations) and with a PTCNNDeconvolution model.
    Quality scores (PSNR, mutual information, SSIM) are printed and all
    results are displayed in napari.

    Parameters
    ----------
    image_clipped : numpy.ndarray
        Ground-truth image; converted to float32 and normalised in place.
    """
    image_clipped = normalise(image_clipped.astype(numpy.float32))
    blurred_image, psf_kernel = add_microscope_blur_2d(image_clipped)
    noisy_blurred_image = add_poisson_gaussian_noise(blurred_image,
                                                     alpha=0.001,
                                                     sigma=0.1,
                                                     sap=0.01,
                                                     quant_bits=10)

    # Classical Lucy-Richardson baseline
    lr = ImageTranslatorLRDeconv(psf_kernel=psf_kernel,
                                 max_num_iterations=30,
                                 backend="cupy")
    lr.train(noisy_blurred_image)
    lr_deconvolved_image = lr.translate(noisy_blurred_image)

    it_deconv = PTCNNDeconvolution(max_epochs=3000,
                                   patience=100,
                                   batch_size=8,
                                   learning_rate=0.01,
                                   normaliser_type='identity',
                                   psf_kernel=psf_kernel,
                                   model_class=UNet,
                                   masking=True,
                                   masking_density=0.05,
                                   loss='l2',
                                   bounds_loss=0.1,
                                   sharpening=0,
                                   entropy=0,
                                   broaden_psf=1)

    start = time.time()
    it_deconv.train(noisy_blurred_image)
    stop = time.time()
    print(f"Training: elapsed time:  {stop - start} ")

    start = time.time()
    deconvolved_image = it_deconv.translate(noisy_blurred_image)
    stop = time.time()
    print(f"inference: elapsed time:  {stop - start} ")

    # Clip everything to [0, 1] before computing quality metrics
    image_clipped = numpy.clip(image_clipped, 0, 1)
    lr_deconvolved_image_clipped = numpy.clip(lr_deconvolved_image, 0, 1)
    deconvolved_image_clipped = numpy.clip(deconvolved_image, 0, 1)

    print(
        "Below in order: PSNR, norm spectral mutual info, norm mutual info, SSIM: "
    )
    printscore(
        "blurry image          :   ",
        psnr(image_clipped, blurred_image),
        spectral_mutual_information(image_clipped, blurred_image),
        mutual_information(image_clipped, blurred_image),
        ssim(image_clipped, blurred_image),
    )

    printscore(
        "noisy and blurry image:   ",
        psnr(image_clipped, noisy_blurred_image),
        spectral_mutual_information(image_clipped, noisy_blurred_image),
        mutual_information(image_clipped, noisy_blurred_image),
        ssim(image_clipped, noisy_blurred_image),
    )

    printscore(
        "lr deconv             :    ",
        psnr(image_clipped, lr_deconvolved_image_clipped),
        spectral_mutual_information(image_clipped,
                                    lr_deconvolved_image_clipped),
        mutual_information(image_clipped, lr_deconvolved_image_clipped),
        ssim(image_clipped, lr_deconvolved_image_clipped),
    )

    printscore(
        "ssi deconv            : ",
        psnr(image_clipped, deconvolved_image_clipped),
        spectral_mutual_information(image_clipped, deconvolved_image_clipped),
        mutual_information(image_clipped, deconvolved_image_clipped),
        ssim(image_clipped, deconvolved_image_clipped),
    )

    with napari.gui_qt():
        viewer = napari.Viewer()
        # BUG FIX: the original referenced the undefined name `image`;
        # the ground truth to display is `image_clipped`.
        viewer.add_image(image_clipped, name='image')
        viewer.add_image(blurred_image, name='blurred')
        viewer.add_image(noisy_blurred_image, name='noisy_blurred_image')
        viewer.add_image(lr_deconvolved_image, name='lr_deconvolved_image')
        viewer.add_image(deconvolved_image, name='ssi_deconvolved_image')