Example #1
0
def test_camera():
    """Check camera center and angles while switching 2D -> 3D -> 2D."""
    viewer = ViewerModel()
    np.random.seed(0)
    data = np.random.random((10, 15, 20))
    viewer.add_image(data)
    assert len(viewer.layers) == 1
    assert np.all(viewer.layers[0].data == data)
    assert viewer.dims.ndim == 3

    expected_center = {2: (0, 7, 9.5), 3: (4.5, 7, 9.5)}
    # The first state is the default display; later states are set explicitly.
    for step, ndisplay in enumerate((2, 3, 2)):
        if step:
            viewer.dims.ndisplay = ndisplay
        assert viewer.dims.ndisplay == ndisplay
        assert viewer.camera.center == expected_center[ndisplay]
        assert viewer.camera.angles == (0, 0, 90)
Example #2
0
def test_selection():
    """Adding a layer selects it and deselects every layer added before it."""
    viewer = ViewerModel()

    def add_random_image():
        viewer.add_image(np.random.random((10, 10)))

    add_random_image()
    assert viewer.layers[0].selected is True

    for n_layers in (2, 3):
        add_random_image()
        expected = [False] * (n_layers - 1) + [True]
        assert [layer.selected for layer in viewer.layers] == expected

    # Even when everything is selected beforehand, only the newcomer stays selected.
    for layer in viewer.layers:
        layer.selected = True
    add_random_image()
    assert [layer.selected for layer in viewer.layers] == [False] * 3 + [True]
Example #3
0
def test_sliced_world_extent():
    """Sliced world extent follows added layers and the displayed dims order."""
    np.random.seed(0)
    viewer = ViewerModel()
    close = np.testing.assert_allclose

    def check_sliced(lower, upper):
        close(viewer._sliced_extent_world[0], lower)
        close(viewer._sliced_extent_world[1], upper)

    # Without layers the world defaults to a 512 x 512 canvas.
    check_sliced((-0.5, -0.5), (511.5, 511.5))

    viewer.add_image(np.random.random((6, 10, 15)),
                     scale=(3, 1, 1),
                     translate=(10, 20, 5))
    close(viewer.layers.extent.world[0], (8.5, 19.5, 4.5))
    close(viewer.layers.extent.world[1], (26.5, 29.5, 19.5))
    check_sliced((19.5, 4.5), (29.5, 19.5))

    # Reordering dims changes which two axes form the sliced extent.
    viewer.dims.order = (1, 2, 0)
    check_sliced((4.5, 8.5), (19.5, 26.5))
Example #4
0
def test_swappable_dims():
    """Changing dims.order re-slices image and labels layers accordingly."""
    viewer = ViewerModel()
    np.random.seed(0)

    image_data = np.random.random((7, 12, 10, 15))
    viewer.add_image(image_data)
    assert np.all(viewer.layers['Image']._data_view == image_data[0, 0, :, :])

    # Points and vectors are only present so swapping works with mixed layers.
    viewer.add_points(np.random.randint(6, size=(10, 4)))
    viewer.add_vectors(np.random.randint(6, size=(10, 2, 4)))

    labels_data = np.random.randint(20, size=(7, 12, 10, 15))
    viewer.add_labels(labels_data)
    assert np.all(viewer.layers['Labels']._data_raw == labels_data[0, 0, :, :])

    # Exchange axes 1 and 2: the displayed slice now runs along axis 1.
    viewer.dims.order = [0, 2, 1, 3]
    assert viewer.dims.order == [0, 2, 1, 3]
    image_view = viewer.layers['Image']._data_view
    assert np.all(image_view == image_data[0, :, 0, :])
    labels_view = viewer.layers['Labels']._data_raw
    assert np.all(labels_view == labels_data[0, :, 0, :])
Example #5
0
def test_dask_local_unoptimized_slicing(delayed_dask_stack, monkeypatch):
    """Prove that the dask_configure function works with a counterexample."""
    # make sure we are not caching for this test, which also tests that we
    # can turn off caching
    resize_dask_cache(0)
    assert _dask_utils._DASK_CACHE.cache.available_bytes == 0

    # Replace configure_dask so that slicing runs inside a no-op context
    # instead of the optimized dask configuration.
    monkeypatch.setattr(layers.base.base, 'configure_dask',
                        lambda *_: nullcontext)

    # add dask stack to viewer.
    v = ViewerModel()
    dask_stack = delayed_dask_stack['stack']
    v.add_image(dask_stack, cache=False)
    # the first and the middle stack will be loaded
    assert delayed_dask_stack['calls'] == 2

    # without optimized dask slicing, we get a new call to the get_array func
    # (which "re-reads" the full z stack) EVERY time we change the Z plane
    # even though we've already read this full timepoint.
    for i in range(3):
        v.dims.set_point(1, i)
        assert delayed_dask_stack['calls'] == 2 + 1 + i  # 😞

    # of course we still incur calls when moving to a new timepoint...
    v.dims.set_point(0, 1)
    v.dims.set_point(0, 2)
    assert delayed_dask_stack['calls'] == 7

    # without the cache we ALSO incur calls when returning to previously loaded
    # timepoints 😭
    v.dims.set_point(0, 1)
    v.dims.set_point(0, 0)
    v.dims.set_point(0, 3)
    # all told, we have ~2x as many calls as the optimized version.
    # Expected: 2 initial + 3 Z changes + 2 new timepoints + 3 timepoint
    # changes = 10 calls. ">=" tolerates the extra calls occasionally seen
    # on CI.
    assert delayed_dask_stack['calls'] >= 10
Example #6
0
def test_cursor_ndim_matches_layer():
    """The cursor position length tracks viewer ndim when layer data changes."""
    viewer = ViewerModel()
    np.random.seed(0)
    im = viewer.add_image(np.random.random((10, 10)))

    def assert_ndim(expected):
        assert viewer.dims.ndim == expected
        assert len(viewer.cursor.position) == expected

    assert_ndim(2)
    # Grow to 3D data, then shrink back to 2D.
    for shape in ((10, 10, 10), (10, 10)):
        im.data = np.random.random(shape)
        assert_ndim(len(shape))
Example #7
0
def test_qt_viewer_data_integrity(qtbot, dtype):
    """Test that the viewer doesn't change the underlying array."""
    image = np.random.rand(10, 32, 32)
    # spread the values over a range suited to the dtype before casting
    image *= 200 if dtype.endswith('8') else 2**14
    image = image.astype(dtype)
    imean = image.mean()

    viewer = ViewerModel()
    view = QtViewer(viewer)
    qtbot.addWidget(view)
    try:
        viewer.add_image(image.copy())
        assert viewer.layers[0].data.mean() == imean
        # toggling to 3D and back to 2D must leave the layer data untouched
        for ndisplay in (3, 2):
            viewer.dims.ndisplay = ndisplay
            assert viewer.layers[0].data.mean() == imean
    finally:
        # release the Qt viewer even when an assertion above fails
        view.shutdown()
Example #8
0
def test_add_remove_layer_dims_change():
    """Viewer ndim grows with a 3D layer and falls back to 2 once it is removed."""
    np.random.seed(0)
    viewer = ViewerModel()

    # A fresh viewer is two-dimensional.
    assert viewer.dims.ndim == 2

    volume = np.random.random((10, 15, 20))
    layer = viewer.add_image(volume)
    assert len(viewer.layers) == 1
    assert np.all(viewer.layers[0].data == volume)
    assert viewer.dims.ndim == 3

    viewer.layers.remove(layer)
    remaining = len(viewer.layers)
    assert remaining == 0
    assert viewer.dims.ndim == 2
Example #9
0
def test_dask_global_optimized_slicing(delayed_dask_stack, monkeypatch):
    """Test that dask_configure reduces compute with dask stacks."""

    # add dask stack to the viewer with the default (cache-enabled) settings
    v = ViewerModel()
    dask_stack = delayed_dask_stack['stack']
    layer = v.add_image(dask_stack)
    # the first and the middle stack will be loaded
    assert delayed_dask_stack['calls'] == 2

    with layer.dask_optimized_slicing() as (_, cache):
        assert cache.cache.available_bytes > 0
        assert cache.active
        # make sure the cache actually has been populated
        assert len(cache.cache.heap.heap) > 0

    assert not cache.active  # only active inside of the context

    # changing the Z plane should never incur calls
    # since the stack has already been loaded (& it is chunked as a 3D array)
    current_z = v.dims.point[1]
    for i in range(3):
        v.dims.set_point(1, current_z + i)
        assert delayed_dask_stack['calls'] == 2  # still only the initial loads

    # changing the timepoint will, of course, incur some compute calls
    initial_t = v.dims.point[0]
    v.dims.set_point(0, initial_t + 1)
    assert delayed_dask_stack['calls'] == 3
    v.dims.set_point(0, initial_t + 2)
    assert delayed_dask_stack['calls'] == 4

    # but going back to previous timepoints should not, since they are cached
    v.dims.set_point(0, initial_t + 1)
    v.dims.set_point(0, initial_t + 0)
    assert delayed_dask_stack['calls'] == 4
    # again, visiting a new point will increment the counter
    v.dims.set_point(0, initial_t + 3)
    assert delayed_dask_stack['calls'] == 5
Example #10
0
def test_dask_unoptimized_slicing(delayed_dask_stack, monkeypatch):
    """Counterexample: a per-layer ``cache=False`` bypasses the global dask cache."""
    # A global cache exists, but the layer below opts out of it.
    resize_dask_cache(10000)
    assert _dask_utils._DASK_CACHE.cache.available_bytes == 10000

    def n_calls():
        return delayed_dask_stack['calls']

    v = ViewerModel()
    layer = v.add_image(delayed_dask_stack['stack'], cache=False)
    # loading touches the first and the middle stack
    assert n_calls() == 2

    with layer.dask_optimized_slicing() as (_, cache):
        assert cache is None

    # Each new Z plane re-reads the whole stack, even inside one timepoint.
    z0 = v.dims.point[1]
    for step in range(3):
        v.dims.set_point(1, z0 + step)
        assert n_calls() == 2 + step  # 😞

    # New timepoints cost calls as well...
    t0 = v.dims.point[0]
    for dt in (1, 2):
        v.dims.set_point(0, t0 + dt)
    assert n_calls() == 6

    # ...and, without caching, so does revisiting earlier ones 😭
    for dt in (1, 0, 3):
        v.dims.set_point(0, t0 + dt)
    # Roughly twice the optimized count; 9 are expected, and ">=" tolerates
    # the occasional extra calls seen on CI.
    assert n_calls() >= 9
Example #11
0
def test_dask_cache_resizing(delayed_dask_stack):
    """Test that we can spin up, resize, and spin down the cache."""

    # make sure we have a cache
    # big enough for 10+ (10, 10, 10) "timepoints"
    resize_dask_cache(100000)

    # add the dask stack to the viewer with default settings (cache enabled)
    v = ViewerModel()
    dask_stack = delayed_dask_stack['stack']

    v.add_image(dask_stack)
    assert _dask_utils._DASK_CACHE.cache.available_bytes > 0
    # make sure the cache actually has been populated
    assert len(_dask_utils._DASK_CACHE.cache.heap.heap) > 0

    # we can resize that cache back to 0 bytes
    resize_dask_cache(0)
    assert _dask_utils._DASK_CACHE.cache.available_bytes == 0

    # adding a 2nd stack should not adjust the cache size once created
    v.add_image(dask_stack)
    assert _dask_utils._DASK_CACHE.cache.available_bytes == 0
    # and the cache will remain empty regardless of what we do
    for i in range(3):
        v.dims.set_point(1, i)
    assert len(_dask_utils._DASK_CACHE.cache.heap.heap) == 0

    # but we can always spin it up again
    resize_dask_cache(1e4)
    assert _dask_utils._DASK_CACHE.cache.available_bytes == 1e4
    # and adding a new image doesn't change the size
    v.add_image(dask_stack)
    assert _dask_utils._DASK_CACHE.cache.available_bytes == 1e4
    # but the cache heap is getting populated again
    for i in range(3):
        v.dims.set_point(0, i)
    assert len(_dask_utils._DASK_CACHE.cache.heap.heap) > 0
Example #12
0
def test_add_image_multichannel_share_memory():
    """Layers split along ``channel_axis`` must share memory with the input array."""
    viewer = ViewerModel()
    image = np.random.random((10, 5, 64, 64))
    channel_layers = viewer.add_image(image, channel_axis=1)
    assert all(np.may_share_memory(image, layer.data) for layer in channel_layers)
Example #13
0
class ImageView(QWidget):
    """Qt widget that shows images in an embedded ``Viewer``.

    Each registered image (keyed by file path) gets channel layers plus
    optional ROI and mask layers. Above the viewer sits a control bar with
    reset-view, ndisplay, roll-dimension, channel and mask controls.
    """

    # Overloaded signal: may be emitted with three ints or with two ints.
    position_changed = Signal([int, int, int], [int, int])
    # NOTE(review): not emitted in the methods visible in this chunk.
    component_clicked = Signal(int)
    # Carries the cursor status text built by ``print_info``.
    text_info_change = Signal(str)
    # NOTE(review): not emitted in the methods visible in this chunk.
    hide_signal = Signal(bool)
    # Emitted without arguments by ``_view_changed`` whenever the view changes.
    view_changed = Signal()
    # Emitted at the end of ``_add_image`` once an image's layers are added.
    image_added = Signal()

    def __init__(
        self,
        settings: BaseSettings,
        channel_property: ChannelProperty,
        name: str,
        parent: Optional[QWidget] = None,
        ndisplay=2,
    ):
        """
        Create the viewer, its control bar and all signal connections.

        :param BaseSettings settings: settings object whose change signals
            (mask, mask representation, ROI, image, spacing) drive this view.
        :param ChannelProperty channel_property: passed to the channel
            ``ColorComboBoxGroup``.
        :param str name: name of this view, passed to ``ImageShowState`` and
            ``ColorComboBoxGroup``.
        :param Optional[QWidget] parent: Qt parent widget.
        :param ndisplay: initial number of displayed dimensions of the viewer.
        """
        super().__init__(parent=parent)

        self.settings = settings
        self.channel_property = channel_property
        self.name = name
        # file_path -> ImageInfo for every image added to this view
        self.image_info: Dict[str, ImageInfo] = {}
        self.current_image = ""
        # key into ORDER_DICT describing the current axis order
        self._current_order = "xy"
        self.components = None
        # started workers; each is removed again by ``_remove_worker``
        self.worker_list = []

        self.viewer = Viewer(ndisplay=ndisplay)
        self.viewer.theme = self.settings.theme_name
        self.viewer_widget = NapariQtViewer(self.viewer)
        self.image_state = ImageShowState(settings, name)
        self.channel_control = ColorComboBoxGroup(settings,
                                                  name,
                                                  channel_property,
                                                  height=30)
        self.ndim_btn = QtNDisplayButton(self.viewer)
        self.reset_view_button = QtViewerPushButton(self.viewer, "home",
                                                    "Reset view",
                                                    self._reset_view)
        self.roll_dim_button = QtViewerPushButton(self.viewer, "roll",
                                                  "Roll dimension",
                                                  self._rotate_dim)
        # right-click on the roll button offers an explicit axis-order menu
        self.roll_dim_button.setContextMenuPolicy(Qt.CustomContextMenu)
        self.roll_dim_button.customContextMenuRequested.connect(
            self._dim_order_menu)
        self.mask_chk = QCheckBox()
        self.mask_label = QLabel("Mask:")

        self.btn_layout = QHBoxLayout()
        self.btn_layout.addWidget(self.reset_view_button)
        self.btn_layout.addWidget(self.ndim_btn)
        self.btn_layout.addWidget(self.roll_dim_button)
        self.btn_layout.addWidget(self.channel_control, 1)
        self.btn_layout.addWidget(self.mask_label)
        self.btn_layout.addWidget(self.mask_chk)
        self.btn_layout2 = QHBoxLayout()
        layout = QVBoxLayout()
        layout.addLayout(self.btn_layout)
        layout.addLayout(self.btn_layout2)
        layout.addWidget(self.viewer_widget)

        self.setLayout(layout)

        self.channel_control.change_channel.connect(self.change_visibility)
        self.viewer.events.status.connect(self.print_info)

        settings.mask_changed.connect(self.set_mask)
        settings.mask_representation_changed.connect(
            self.update_mask_parameters)
        settings.roi_changed.connect(self.set_roi)
        settings.roi_clean.connect(self.set_roi)
        settings.image_changed.connect(self.set_image)
        settings.image_spacing_changed.connect(self.update_spacing_info)
        # settings.labels_changed.connect(self.paint_layer)
        # object whose ``transform.changed`` is hooked to ``_view_changed``;
        # ``camera_change`` swaps it for the current camera
        self.old_scene: BaseCamera = self.viewer_widget.view.scene

        self.image_state.coloring_changed.connect(self.update_roi_coloring)
        self.image_state.roi_presented_changed.connect(
            self.update_roi_representation)
        self.image_state.borders_changed.connect(
            self.update_roi_representation)
        self.mask_chk.stateChanged.connect(self.change_mask_visibility)
        self.viewer_widget.view.scene.transform.changed.connect(
            self._view_changed, position="last")
        try:
            self.viewer.dims.events.current_step.connect(self._view_changed,
                                                         position="last")
        except AttributeError:
            # presumably a napari version exposing ``axis`` instead
            self.viewer.dims.events.axis.connect(self._view_changed,
                                                 position="last")
        # ``ndisplay`` is only touched behind this check so that the
        # ``camera`` fallback stays reachable and ``_view_changed`` is
        # connected to it exactly once.
        if hasattr(self.viewer.dims.events, "ndisplay"):
            self.viewer.dims.events.ndisplay.connect(self._view_changed,
                                                     position="last")
            self.viewer.dims.events.ndisplay.connect(self.camera_change,
                                                     position="last")
        else:
            self.viewer.dims.events.camera.connect(self._view_changed,
                                                   position="last")
            self.viewer.dims.events.camera.connect(self.camera_change,
                                                   position="last")
        self.viewer.events.reset_view.connect(self._view_changed,
                                              position="last")

    def _dim_order_menu(self, point: QPoint):
        """Show a popup listing every axis order; the active one is in bold."""
        menu = QMenu()
        for order_name in ORDER_DICT:
            entry = menu.addAction(order_name)
            entry.triggered.connect(partial(self._set_new_order, order_name))
            if order_name != self._current_order:
                continue
            bold_font = entry.font()
            bold_font.setBold(True)
            entry.setFont(bold_font)

        menu.exec_(self.roll_dim_button.mapToGlobal(point))

    def _set_new_order(self, text: str):
        """Apply the axis order named ``text`` and rebuild the ROI layers."""
        self._current_order = text
        new_order = ORDER_DICT[text]
        self.viewer.dims.order = new_order
        self.update_roi_representation()

    def _reset_view(self):
        """Restore the default "xy" axis order and reset the viewer camera.

        ``_set_new_order`` already assigns ``ORDER_DICT["xy"]`` to the dims,
        so no second order assignment is needed here.
        """
        self._set_new_order("xy")
        self.viewer.reset_view()

    def _rotate_dim(self):
        """Switch to the axis order that follows the current one in ``NEXT_ORDER``."""
        following = NEXT_ORDER[self._current_order]
        self._set_new_order(following)

    def camera_change(self, _args):
        """Move the ``_view_changed`` hook onto the current camera's transform.

        Connected to the ndisplay/camera events, after which the view may use a
        different camera; the old transform is disconnected first.
        """
        previous = self.old_scene
        previous.transform.changed.disconnect(self._view_changed)
        self.old_scene: BaseCamera = self.viewer_widget.view.camera
        self.old_scene.transform.changed.connect(
            self._view_changed, position="last")

    def _view_changed(self, _args):
        """Forward any view-related event as the argument-less ``view_changed``."""
        signal = self.view_changed
        signal.emit()

    def get_state(self):
        """Return a dict snapshot of display mode, current point and camera state."""
        dims = self.viewer.dims
        state = {"ndisplay": dims.ndisplay, "point": dims.point}
        state["camera"] = self.viewer_widget.view.camera.get_state()
        return state

    def set_state(self, dkt):
        """Apply a state dict of the form produced by ``get_state``.

        If ``ndisplay`` differs from the current mode, only the mode is switched
        and the call returns; point and camera are not applied in that call.
        A camera state rejected with ``KeyError`` is silently ignored.
        """
        dims = self.viewer.dims
        if "ndisplay" in dkt and dims.ndisplay != dkt["ndisplay"]:
            dims.ndisplay = dkt["ndisplay"]
            return
        if "point" in dkt:
            for axis, value in enumerate(dkt["point"]):
                dims.set_point(axis, value)
        if "camera" not in dkt:
            return
        try:
            self.viewer_widget.view.camera.set_state(dkt["camera"])
        except KeyError:
            pass

    def change_mask_visibility(self):
        """Show or hide every mask layer to match the mask checkbox."""
        mask_layers = (info.mask for info in self.image_info.values())
        for mask_layer in mask_layers:
            if mask_layer is None:
                continue
            mask_layer.visible = self.mask_chk.isChecked()

    def update_spacing_info(self, image: Optional[Image] = None) -> None:
        """
        Apply the image's normalized spacing to every layer belonging to it.

        Channel layers, the ROI layer and the mask layer (when present) all
        receive ``image.normalized_scaling()`` as their ``scale``.

        :param Optional[Image] image: image whose layers should be rescaled;
            when omitted, the image currently held by ``settings`` is used.
        :raises ValueError: if the image has not been added to this view.
        :return: None
        """
        if image is None:
            image = self.settings.image

        if image.file_path not in self.image_info:
            raise ValueError("Image not registered")

        image_info = self.image_info[image.file_path]
        # computed once: every layer of this image uses the same spacing
        scale = image.normalized_scaling()

        for layer in image_info.layers:
            layer.scale = scale

        for extra_layer in (image_info.roi, image_info.mask):
            if extra_layer is not None:
                extra_layer.scale = scale

    def print_info(self, value):
        """Build the status text for the pixel under the cursor and emit it.

        The text holds the integer coordinates, the values of visible channel
        layers of every image containing that point, and the non-zero ROI
        component ids found there. An empty string is emitted when nothing
        was found. ``value`` (the status event payload) is not used.
        """
        if not self.viewer.active_layer:
            return
        cords = np.array(
            [int(coord) for coord in self.viewer.active_layer.coordinates])
        brightness = []
        found_components = []
        for image_info in self.image_info.values():
            if image_info.coords_in(cords):
                index = tuple(image_info.translated_coords(cords))
                brightness.extend(layer.data[index]
                                  for layer in image_info.layers
                                  if layer.visible)
                roi_array = image_info.roi_info.roi
                if roi_array is not None and image_info.roi is not None:
                    component = roi_array[index]
                    if component:
                        found_components.append(component)

        if not (brightness or found_components):
            self.text_info_change.emit("")
            return
        pieces = [f"{cords}: "]
        if brightness:
            pieces.append(
                str(brightness[0] if len(brightness) == 1 else brightness))
        self.components = found_components
        if found_components:
            if len(found_components) == 1:
                pieces.append(f" component: {found_components[0]}")
            else:
                pieces.append(f" components: {found_components}")
        self.text_info_change.emit("".join(pieces))

    def get_control_view(self) -> ImageShowState:
        """Return the ``ImageShowState`` that controls ROI presentation here."""
        state = self.image_state
        return state

    @staticmethod
    def convert_to_vispy_colormap(colormap: ColorMap):
        """Convert ``colormap`` into a vispy ``Colormap``.

        The table from ``create_color_map`` is divided by 255 before wrapping.
        """
        unit_table = create_color_map(colormap) / 255
        return Colormap(ColorArray(unit_table))

    def mask_opacity(self) -> float:
        """Return the mask opacity from the profile, defaulting to ``1``."""
        opacity = self.settings.get_from_profile("mask_presentation_opacity", 1)
        return opacity

    def mask_color(self) -> Colormap:
        """Return a two-entry colormap going from black to the mask color.

        The profile stores the color on a 0-255 scale (default white); it is
        rescaled to 0-1 before building the colormap.
        """
        stored = self.settings.get_from_profile("mask_presentation_color",
                                                [255, 255, 255])
        mask_rgba = Color(np.divide(stored, 255)).rgba
        return Colormap(ColorArray(["black", mask_rgba]))

    def get_image(self, image: Optional[Image]) -> Image:
        """Resolve the image an operation should act on.

        Precedence: the explicit ``image``, then the current image of this
        view if it is registered, then ``settings.image``.
        """
        if image is not None:
            return image
        if self.current_image in self.image_info:
            return self.image_info[self.current_image].image
        return self.settings.image

    def set_roi(self,
                roi_info: Optional[ROIInfo] = None,
                image: Optional[Image] = None) -> None:
        """Replace the ROI layer of an image with one built from ``roi_info``.

        ``image`` is resolved via ``get_image``; ``roi_info`` defaults to
        ``settings.roi_info``. Any existing ROI layer is removed first. If the
        new ROI array is ``None``, no layer is added.
        """
        image = self.get_image(image)
        roi_info = self.settings.roi_info if roi_info is None else roi_info
        image_info = self.image_info[image.file_path]
        old_layer = image_info.roi
        if old_layer is not None:
            self.viewer.layers.unselect_all()
            old_layer.selected = True
            self.viewer.layers.remove_selected()
            image_info.roi = None

        if roi_info.roi is None:
            return

        image_info.roi_info = roi_info
        bounds = roi_info.bound_info
        image_info.roi_count = max(bounds) if bounds else 0
        self.add_roi_layer(image_info)
        image_info.roi.colormap = self.get_roi_view_parameters(image_info)
        image_info.roi.opacity = self.image_state.opacity

    def get_roi_view_parameters(self, image_info: ImageInfo) -> Colormap:
        """Build a step colormap assigning each ROI component its label color.

        Entry 0 (background) is fully transparent. Component ``k``
        (1..roi_count) gets row ``(k - 1) % N`` of ``settings.label_colors``
        (N rows, 0-255 RGB), rescaled to 0-1 and given alpha 1. With
        ``LabelEnum.Show_selected`` the RGBA rows are multiplied by
        ``settings.components_mask()``; NOTE(review): this presumably is a 0/1
        mask so that unselected components become transparent — confirm. When
        labels are hidden, there are no components, or there are no label
        colors, a fully transparent two-entry map is returned.

        :param image_info: image whose ``roi_count`` decides the entry count.
        :return: vispy ``Colormap`` with ``interpolation="zero"``.
        """
        colors = self.settings.label_colors / 255
        if self.image_state.show_label == LabelEnum.Not_show or image_info.roi_count == 0 or colors.size == 0:
            colors = np.array([[0, 0, 0, 0], [0, 0, 0, 0]])
        else:
            # tile the palette enough times to cover every component id
            repeat = int(np.ceil(image_info.roi_count / colors.shape[0]))
            colors = np.concatenate([colors] * repeat)
            # append an opaque alpha column (RGB -> RGBA)
            colors = np.concatenate(
                [colors,
                 np.ones(colors.shape[0]).reshape(colors.shape[0], 1)],
                axis=1)
            # prepend a transparent background entry and trim to roi_count
            colors = np.concatenate([[[0, 0, 0, 0]],
                                     colors[:image_info.roi_count]])
            if self.image_state.show_label == LabelEnum.Show_selected:
                try:
                    colors *= self.settings.components_mask().reshape(
                        (colors.shape[0], 1))
                except ValueError:
                    # mask length does not match the color count; keep all colors
                    pass
        # one more control point than colors, as vispy expects for
        # "zero" (step) interpolation
        control_points = [0] + list(
            np.linspace(1 / (2 * colors.shape[0]),
                        1,
                        endpoint=True,
                        num=colors.shape[0]))
        return Colormap(colors, controls=control_points, interpolation="zero")

    def update_roi_coloring(self):
        """Reapply colormap and opacity to every existing ROI layer."""
        with_roi = (info for info in self.image_info.values()
                    if info.roi is not None)
        for info in with_roi:
            info.roi.colormap = self.get_roi_view_parameters(info)
            info.roi.opacity = self.image_state.opacity

    def remove_all_roi(self):
        """Delete every ROI layer from the viewer and clear the references."""
        layer_list = self.viewer.layers
        layer_list.unselect_all()
        for info in self.image_info.values():
            roi_layer = info.roi
            if roi_layer is not None:
                roi_layer.selected = True
                info.roi = None

        layer_list.remove_selected()

    def add_roi_layer(self, image_info: ImageInfo):
        """Add a layer presenting the ROI of ``image_info`` as ``image_info.roi``.

        The presented array is ``roi_info.alternative[image_state.roi_presented]``
        when that key exists, otherwise ``roi_info.roi``. With ``only_borders``
        enabled, only the borders computed by ``calculate_borders`` are shown.
        Both presentations use the same layer name and blending. Nothing is
        added when ``roi_info.roi`` is ``None``.
        """
        if image_info.roi_info.roi is None:
            return
        try:
            max_num = max(1, image_info.roi_count)
        except ValueError:
            # NOTE(review): presumably guards a non-scalar roi_count; kept as
            # a best-effort fallback
            max_num = 1
        roi = image_info.roi_info.alternative.get(
            self.image_state.roi_presented, image_info.roi_info.roi)
        if self.image_state.only_borders:
            # borders are computed in the displayed axis order, then the
            # result is transposed back to the data axis order
            order = ORDER_DICT[self._current_order]
            data = calculate_borders(
                roi.transpose(order),
                self.image_state.borders_thick // 2,
                self.viewer.dims.ndisplay == 2,
            ).transpose(np.argsort(order))
        else:
            data = roi
        image_info.roi = self.viewer.add_image(
            data,
            scale=image_info.image.normalized_scaling(),
            contrast_limits=[0, max_num],
            name="ROI",
            blending="translucent",
        )
        # nearest-neighbour interpolation keeps component ids crisp
        # (presumably index 3 is the 3D-display slot)
        image_info.roi._interpolation[3] = Interpolation3D.NEAREST

    def update_roi_representation(self):
        """Rebuild every ROI layer from scratch, then reapply its coloring."""
        self.remove_all_roi()
        for info in self.image_info.values():
            self.add_roi_layer(info)
        self.update_roi_coloring()

    def set_mask(self,
                 mask: Optional[np.ndarray] = None,
                 image: Optional[Image] = None) -> None:
        """Replace the mask overlay layer of an image.

        ``image`` is resolved via ``get_image``; ``mask`` defaults to
        ``image.mask``. The overlay marks pixels where the mask equals 0 and
        uses additive blending.

        :raises ValueError: if the image has not been added to this view.
        """
        image = self.get_image(image)
        if image.file_path not in self.image_info:
            raise ValueError("Image not added to viewer")
        mask = image.mask if mask is None else mask

        image_info = self.image_info[image.file_path]
        previous = image_info.mask
        if previous is not None:
            self.viewer.layers.unselect_all()
            previous.selected = True
            self.viewer.layers.remove_selected()
            image_info.mask = None

        if mask is None:
            return

        overlay = self.viewer.add_image(mask == 0,
                                        scale=image.normalized_scaling(),
                                        blending="additive")
        overlay.colormap = self.mask_color()
        overlay.opacity = self.mask_opacity()
        overlay.visible = self.mask_chk.isChecked()
        image_info.mask = overlay

    def update_mask_parameters(self):
        """Push the profile's mask opacity and color onto every mask layer."""
        opacity, colormap = self.mask_opacity(), self.mask_color()
        for mask_layer in (info.mask for info in self.image_info.values()):
            if mask_layer is None:
                continue
            mask_layer.opacity = opacity
            mask_layer.colormap = colormap

    def set_image(self, image: Optional[Image] = None):
        """Forget all registered images and add ``image`` in replace mode.

        ``None`` means the image held by ``settings``.
        """
        self.image_info = dict()
        self.add_image(image, True)

    def has_image(self, image: Image):
        """Return whether an image with the same file path is registered."""
        registered = self.image_info
        return image.file_path in registered

    @staticmethod
    def calculate_filter(
            array: np.ndarray,
            parameters: Tuple[NoiseFilterType, float]) -> Optional[np.ndarray]:
        """Apply the noise filter described by ``parameters`` to ``array``.

        ``parameters`` is ``(filter_type, value)``. ``NoiseFilterType.No`` or
        a zero value returns ``array`` unchanged; ``Gauss`` uses
        ``gaussian(array, value)``; any other type uses
        ``median(array, int(value))``.
        """
        filter_type, strength = parameters[0], parameters[1]
        if filter_type == NoiseFilterType.No or strength == 0:
            return array
        if filter_type != NoiseFilterType.Gauss:
            return median(array, int(strength))
        return gaussian(array, strength)

    def _remove_worker(self, sender):
        """Stop tracking the worker whose signal object is ``sender``.

        A worker exposes its signals as ``_signals`` or ``signals``; whichever
        exists is compared by identity. An unknown sender is printed rather
        than raising.
        """
        not_found = object()
        owner = next(
            (
                worker
                for worker in self.worker_list
                if sender is getattr(
                    worker,
                    "_signals" if hasattr(worker, "_signals") else "signals")
            ),
            not_found,
        )
        if owner is not_found:
            print("[_remove_worker]", sender)
            return
        self.worker_list.remove(owner)

    def _add_layer_util(self, index, layer, filters):
        """Add ``layer`` to the viewer and filter its data in a background worker.

        The worker applies ``filters[index]`` through ``calculate_filter``.
        When it returns, ``set_data`` replaces ``layer.data`` with the result,
        unless no filtering was needed or the layer has meanwhile been removed
        from the viewer.

        :param index: position of ``layer`` within its image; selects the filter.
        :param layer: layer object passed to ``viewer.add_layer``.
        :param filters: indexable of ``(NoiseFilterType, value)`` pairs, presumably
            one per channel as returned by ``channel_control.get_filter()``.
        """
        self.viewer.add_layer(layer)

        def set_data(val):
            # runs on the worker's "returned" signal; the Qt sender identifies
            # the finished worker so it can be dropped from worker_list
            self._remove_worker(self.sender())
            data_, layer_ = val
            if data_ is None:
                return
            if layer_ not in self.viewer.layers:
                # the layer was removed while filtering ran; discard result
                return
            layer_.data = data_

        # NOTE(review): no explicit start() follows; passing ``connect``
        # presumably makes thread_worker start the worker automatically
        @thread_worker(connect={"returned": set_data})
        def calc_filter(j, layer_):
            # ``None`` signals "no filtering needed" to set_data
            if filters[j][0] == NoiseFilterType.No or filters[j][1] == 0:
                return None, layer_
            return self.calculate_filter(layer_.data,
                                         parameters=filters[j]), layer_

        worker = calc_filter(index, layer)
        self.worker_list.append(worker)

    def _add_image(self, image_data: Tuple[ImageInfo, bool]):
        """Finish adding an image once the layer-preparation worker returns.

        Adds the prepared channel layers (each filtered in the background),
        records them in ``self.image_info``, makes the image current, resets the
        view and moves to the middle of each non-channel axis. ROI and mask
        layers are then (re)created, and ``image_added`` is emitted.

        :param image_data: ``(image_info, replace)`` as returned by the worker;
            when ``replace`` is true, all existing viewer layers are removed first.
        """
        # the worker that produced this result is finished; stop tracking it
        self._remove_worker(self.sender())

        image_info, replace = image_data
        image = image_info.image
        if replace:
            self.viewer.layers.select_all()
            self.viewer.layers.remove_selected()

        filters = self.channel_control.get_filter()
        for i, layer in enumerate(image_info.layers):
            self._add_layer_util(i, layer, filters)

        # NOTE(review): raises KeyError if image_info was cleared (e.g. by
        # set_image) while the worker was running — confirm this cannot happen
        self.image_info[image.file_path].filter_info = filters
        self.image_info[image.file_path].layers = image_info.layers
        self.current_image = image.file_path
        self.viewer.reset_view()
        if self.viewer.layers:
            self.viewer.layers[-1].selected = True

        # place the current point at half the scaled extent of every
        # non-channel axis
        for i, axis in enumerate(image.axis_order):
            if axis == "C":
                continue
            self.viewer.dims.set_point(
                i, image.shape[i] * image.normalized_scaling()[i] // 2)
        if self.image_info[image.file_path].roi is not None:
            self.set_roi()
        if image_info.image.mask is not None:
            self.set_mask()
        self.image_added.emit()

    def add_image(self, image: Optional[Image], replace=False):
        """Register ``image`` and start building its layers in the background.

        Layer creation finishes asynchronously in ``_add_image``.

        :param image: image to add; ``settings.image`` is used when ``None``.
        :param replace: when true, existing layers are removed once the new
            layers are ready, and the channel count ignores other images.
        :raises ValueError: if the image has no channels or is already added.
        :return: the image that was registered.
        """
        if image is None:
            image = self.settings.image

        if not image.channels:
            raise ValueError("Need non empty image")

        if image.file_path in self.image_info:
            raise ValueError("Image already added")

        self.image_info[image.file_path] = ImageInfo(image, [])

        channels = image.channels
        # self.image_info already contains the new image, so this is always
        # true when not replacing; the max spans all registered images
        if self.image_info and not replace:
            channels = max(
                channels,
                *[x.image.channels for x in self.image_info.values()])

        self.channel_control.set_channels(channels)
        visibility = self.channel_control.channel_visibility
        limits = self.channel_control.get_limits()
        ranges = image.get_ranges()
        # channels without a user-set limit fall back to the image's range;
        # zip truncates the list to this image's channel count
        limits = [
            ranges[i] if x is None else x
            for i, x in zip(range(image.channels), limits)
        ]
        gamma = self.channel_control.get_gamma()
        colormaps = [
            self.convert_to_vispy_colormap(
                self.channel_control.selected_colormaps[i])
            for i in range(image.channels)
        ]
        # NOTE(review): last argument is the current layer count, presumably
        # used as an offset for the new layers — confirm in ImageParameters
        parameters = ImageParameters(limits, visibility, gamma, colormaps,
                                     image.normalized_scaling(),
                                     len(self.viewer.layers))

        self._prepare_layers(image, parameters, replace)

        return image

    def _prepare_layers(self, image, parameters, replace):
        """Start the worker that builds layers; ``_add_image`` receives its result."""
        worker = prepare_layers(image, parameters, replace)
        self.worker_list.append(worker)
        worker.returned.connect(self._add_image)
        worker.start()

    def images_bounds(self) -> Tuple[np.ndarray, Tuple[float, ...]]:
        """Compute the combined bounds of all loaded images on displayed axes.

        For every image that already has layers, the ``dims.range`` of its
        first layer is merged into a running per-axis ``(start, stop, step)``
        triple: the minimum of starts, the maximum of stops and the minimum
        of steps. Ranges of different lengths are merged axis by axis from
        the first axis.

        NOTE(review): napari right-aligns layers with fewer dimensions. If
        images of different ndim are ever mixed, the front alignment used
        here may merge the wrong axes. Confirm all images share ndim.

        Returns:
            ``(size, min_shape)`` restricted to ``self.viewer.dims.displayed``.
            ``size`` is the numpy array ``stop - start`` per displayed axis.
            ``min_shape`` is the tuple of per-axis starts. When no image has
            layers yet, both are zeros; this happens because layers are built
            asynchronously after ``add_image``.
        """
        ranges = []
        for image_info in self.image_info.values():
            if not image_info.layers:
                continue
            ranges = [
                (min(a, b), max(c, d), min(e, f))
                for (a, c, e), (b, d, f) in itertools.zip_longest(
                    image_info.layers[0].dims.range,
                    ranges,
                    fillvalue=(np.inf, -np.inf, np.inf),
                )
            ]

        displayed = list(self.viewer.dims.displayed)
        if not ranges or not displayed:
            # Nothing to measure yet. Previously this raised IndexError on
            # ranges[i] (or a ValueError unpacking an empty zip).
            return np.zeros(len(displayed)), (0.0,) * len(displayed)

        visible = [ranges[i] for i in displayed]
        min_shape, max_shape, _ = zip(*visible)
        size = np.subtract(max_shape, min_shape)
        return size, min_shape

    @staticmethod
    def _shift_layer(layer: Layer, translate_2d):
        """Place *layer* at a grid offset given on its last two axes.

        ``layer.translate_grid`` is replaced by a new list that is zero on
        every leading axis and ends with the values of *translate_2d*.
        """
        leading = [0] * (layer.ndim - 2)
        layer.translate_grid = leading + list(translate_2d)

    def grid_view(self):
        """Present multiple images in grid view.

        Images are placed row by row on a square grid of
        ``max(1, ceil(sqrt(n_images)))`` cells per side. Each cell is as
        large as the last two axes of the bounds from ``images_bounds``.
        Every layer of an image is shifted by its cell offset, followed by
        its mask and ROI layers when present. The viewer is then reset.
        """
        side = max(1, np.ceil(np.sqrt(len(self.image_info))).astype(int))
        scene_size, _ = self.images_bounds()
        cell_size = scene_size[-2:]
        cells = itertools.product(range(side), repeat=2)
        for info, position in zip(self.image_info.values(), cells):
            offset = np.multiply(cell_size, position)
            for layer in info.layers:
                self._shift_layer(layer, offset)
            for extra_layer in (info.mask, info.roi):
                if extra_layer is not None:
                    self._shift_layer(extra_layer, offset)
        self.viewer.reset_view()

    def change_visibility(self, name: str, index: int):
        """Apply the channel control's settings for channel *index* to all images.

        *name* is not used by this method. Images without a layer at *index*
        are skipped. The layer's ``visible`` flag always follows the channel
        control. Colormap, contrast limits, gamma and filter are refreshed
        only for a visible channel. Contrast limits fall back to the image's
        own range when the control has none. Filtered data is recomputed
        only when the selected filter differs from the last one applied.
        """
        control = self.channel_control
        for info in self.image_info.values():
            if len(info.layers) <= index:
                continue
            layer = info.layers[index]
            layer.visible = control.channel_visibility[index]
            if not control.channel_visibility[index]:
                continue

            layer.colormap = self.convert_to_vispy_colormap(
                control.selected_colormaps[index])

            limits = control.get_limits()[index]
            if limits is None:
                limits = info.image.get_ranges()[index]
            layer.contrast_limits = limits

            layer.gamma = control.get_gamma()[index]

            filter_type = control.get_filter()[index]
            if filter_type == info.filter_info[index]:
                continue
            layer.data = self.calculate_filter(
                info.image.get_channel(index), filter_type)
            info.filter_info[index] = filter_type

    def reset_image_size(self):
        """Reset the view by delegating to ``self.viewer.reset_view()``."""
        viewer = self.viewer
        viewer.reset_view()

    def set_theme(self, theme: str):
        """Switch the viewer to *theme* by assigning ``viewer.theme``."""
        viewer = self.viewer
        viewer.theme = theme

    def closeEvent(self, event):
        """Ask each tracked background worker to quit, then close normally.

        ``quit()`` is called on every entry of ``self.worker_list`` before
        the base-class ``closeEvent`` handles *event*.
        """
        for pending_worker in self.worker_list:
            pending_worker.quit()
        super().closeEvent(event)

    def get_tool_tip_text(self) -> str:
        """Build tooltip text from the ROI annotations of the current image.

        The image selected in ``self.settings`` is looked up in
        ``self.image_info``. For each entry of ``self.components``, its
        annotation (``{}`` when missing) is formatted with ``_print_dict``.
        The pieces are joined with single spaces.
        """
        current = self.image_info[self.settings.image.file_path]
        return " ".join(
            _print_dict(current.roi_info.annotations.get(component, {}))
            for component in self.components
        )

    def event(self, event: QEvent):
        """Show ROI-annotation tooltips, then defer to the base-class handler.

        For a tooltip event, when ``self.components`` is non-empty and
        ``get_tool_tip_text`` yields non-empty text, that text is shown at
        the event's global position. The base implementation always runs
        afterwards, and its result is returned.
        """
        wants_tip = event.type() == QEvent.ToolTip and self.components
        tip = self.get_tool_tip_text() if wants_tip else ""
        if tip:
            QToolTip.showText(event.globalPos(), tip)
        return super().event(event)
Example #14
0
def test_add_empty_shapes_layer():
    viewer = ViewerModel()
    image = np.random.random((8, 64, 64))
    image_layer = viewer.add_image(image)
    shp = viewer.add_shapes()
    assert shp.dims.ndim == 3