Example #1
from magicgui import widgets

# Assumes the ProgressBar was configured elsewhere (e.g. via the @magicgui
# decorator) with step=2 and an initial value of 32.
def t(pbar: widgets.ProgressBar):
    assert pbar.get_value() == 32
    pbar.decrement()  # decrement by the current step (2)
    assert pbar.get_value() == 30
    pbar.step = 5
    pbar.increment()  # increment by the new step (5)
    assert pbar.get_value() == 35
    pbar.decrement(10)  # decrement by an explicit amount
    assert pbar.get_value() == 25
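For context, a minimal sketch of how such a function might be wired up with magicgui; the min/max/step/value options below are assumptions chosen so that the asserts above pass, not values taken from the original test.

from magicgui import magicgui, widgets

# Hypothetical ProgressBar configuration (assumed): the bar starts at 32 with step=2.
@magicgui(pbar={"min": 20, "max": 40, "step": 2, "value": 32})
def t(pbar: widgets.ProgressBar):
    ...  # body as in the example above

t()  # calling the FunctionGui invokes the function with its bound ProgressBar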
Example #2
def make_widget(
    pbar: widgets.ProgressBar,
    image: ImageData,
    min_sigma: Annotated[float, {
        "min": 0.5,
        "max": 15,
        "step": 0.5
    }] = 5,
    max_sigma: Annotated[float, {
        "min": 1,
        "max": 200,
        "step": 0.5
    }] = 30,
    num_sigma: Annotated[int, {
        "min": 1,
        "max": 20
    }] = 10,
    threshold: Annotated[float, {
        "min": 0,
        "max": 1000,
        "step": 0.1
    }] = 6,
) -> FunctionWorker[LayerDataTuple]:

    # @thread_worker creates a worker that runs a function in another thread
    # we connect the "returned" signal to the ProgressBar.hide method
    @thread_worker(connect={'returned': pbar.hide})
    def detect_blobs() -> LayerDataTuple:
        # this is the potentially long-running function
        blobs = blob_log(image, min_sigma, max_sigma, num_sigma, threshold)
        points = blobs[:, :image.ndim]
        meta = dict(
            size=blobs[:, -1],
            edge_color="red",
            edge_width=2,
            face_color="transparent",
        )
        # return a "LayerDataTuple"
        return (points, meta, 'points')

    # show progress bar and return worker
    pbar.show()
    return detect_blobs()
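For context, a minimal sketch of how a widget factory like this is typically used with napari, assuming the standard napari/magicgui integration (napari supplies the data for the `ImageData` annotation and adds the returned worker's `LayerDataTuple` result as a layer); the sample image is just a placeholder.

import napari
from magicgui import magicgui
from skimage import data

viewer = napari.Viewer()
viewer.add_image(data.coins(), name="coins")

# Build a GUI from the function and dock it; pressing the call button starts
# the thread_worker, and the points layer is added when the worker returns.
viewer.window.add_dock_widget(magicgui(make_widget))
napari.run()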
Example #3
    def _get_progressbar(self, **kwargs) -> ProgressBar:
        """Create ProgressBar or get from the parent gui `_tqdm_pbars` deque.

        The deque allows us to create nested iterables inside a magicgui, while
        resetting and reusing progress bars across ``FunctionGui`` calls. The nesting
        depth (into the deque) is reset by :meth:`FunctionGui.__call__`, right before
        the function is called.  Then, as the function encounters `tqdm` instances,
        this method gets or creates a progress bar and increments the
        :attr:`FunctionGui._tqdm_depth` counter on the ``FunctionGui``.
        """
        if self._mgui is None:
            return ProgressBar(**kwargs)

        if len(self._mgui._tqdm_pbars) > self._mgui._tqdm_depth:
            pbar = self._mgui._tqdm_pbars[self._mgui._tqdm_depth]
        else:
            pbar = ProgressBar(**kwargs)
            self._mgui._tqdm_pbars.append(pbar)
            self._mgui.append(pbar)
        self._mgui._tqdm_depth += 1
        return pbar
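The deque/depth bookkeeping described in the docstring is driven from user code roughly as follows; a minimal sketch, assuming `magicgui.tqdm.trange` (the tqdm wrapper that obtains its ProgressBar via `_get_progressbar`).

from magicgui import magicgui
from magicgui.tqdm import trange

@magicgui(call_button="run")
def long_running(outer: int = 2, inner: int = 4):
    # Each trange is a tqdm subclass; inside a FunctionGui it requests a
    # ProgressBar at its current nesting depth and reuses it on later calls.
    for _ in trange(outer):
        for _ in trange(inner):
            pass

long_running.show()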
Example #4
def manual(pbar: ProgressBar, increment: bool = True):
    """Example of manual progress bar control."""
    if increment:
        pbar.increment()
    else:
        pbar.decrement()
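A sketch of how this function might be exposed as a standalone widget; the call-button label and ProgressBar options are assumptions for illustration.

from magicgui import magicgui

widget = magicgui(
    manual,
    call_button="tick",
    # hypothetical ProgressBar configuration, passed via the parameter-name kwarg
    pbar={"min": 0, "max": 20, "step": 2, "value": 0},
)
widget.show(run=True)  # each "tick" press moves the bar up or down by one step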
Example #5
    def plugin(
        viewer: napari.Viewer,
        label_head,
        image: napari.layers.Image,
        axes,
        label_nn,
        model_type,
        model2d,
        model3d,
        model_folder,
        model_axes,
        norm_image,
        perc_low,
        perc_high,
        input_scale,
        label_nms,
        prob_thresh,
        nms_thresh,
        output_type,
        label_adv,
        n_tiles,
        norm_axes,
        timelapse_opts,
        cnn_output,
        set_thresholds,
        defaults_button,
        progress_bar: mw.ProgressBar,
    ) -> List[napari.types.LayerDataTuple]:

        model = get_model(*model_selected)
        if model._is_multiclass():
            warn(
                "multi-class mode not supported yet, ignoring classification output"
            )

        lkwargs = {}
        x = get_data(image)
        axes = axes_check_and_normalize(axes, length=x.ndim)

        if not (input_scale is None
                or isinstance(input_scale, numbers.Number)):
            input_scale = tuple(s for a, s in zip(axes, input_scale)
                                if a not in ("T", ))
            # print(f'scaling by {input_scale}')

        if not axes.replace("T", "").startswith(
                model._axes_out.replace("C", "")):
            warn(
                f"output images have different axes ({model._axes_out.replace('C','')}) than input image ({axes})"
            )
            # TODO: adjust image.scale according to shuffled axes

        if norm_image:
            axes_norm = axes_check_and_normalize(norm_axes)
            axes_norm = "".join(set(axes_norm).intersection(
                set(axes)))  # relevant axes present in input image
            assert len(axes_norm) > 0
            # always jointly normalize channels for RGB images
            if ("C" in axes and image.rgb == True) and ("C" not in axes_norm):
                axes_norm = axes_norm + "C"
                warn("jointly normalizing channels of RGB input image")
            ax = axes_dict(axes)
            _axis = tuple(sorted(ax[a] for a in axes_norm))
            # # TODO: address joint vs. channel/time-separate normalization properly (let user choose)
            # #       also needs to be documented somewhere
            # if 'T' in axes:
            #     if 'C' not in axes or image.rgb == True:
            #          # normalize channels jointly, frames independently
            #          _axis = tuple(i for i in range(x.ndim) if i not in (ax['T'],))
            #     else:
            #         # normalize channels independently, frames independently
            #         _axis = tuple(i for i in range(x.ndim) if i not in (ax['T'],ax['C']))
            # else:
            #     if 'C' not in axes or image.rgb == True:
            #          # normalize channels jointly
            #         _axis = None
            #     else:
            #         # normalize channels independently
            #         _axis = tuple(i for i in range(x.ndim) if i not in (ax['C'],))
            x = normalize(x, perc_low, perc_high, axis=_axis)

        # TODO: progress bar (labels) often don't show up. events not processed?
        if "T" in axes:
            app = use_app()
            t = axes_dict(axes)["T"]
            n_frames = x.shape[t]
            if n_tiles is not None:
                # remove tiling value for time axis
                n_tiles = tuple(v for i, v in enumerate(n_tiles) if i != t)

            def progress(it, **kwargs):
                progress_bar.label = "StarDist Prediction (frames)"
                progress_bar.range = (0, n_frames)
                progress_bar.value = 0
                progress_bar.show()
                app.process_events()
                for item in it:
                    yield item
                    progress_bar.increment()
                    app.process_events()
                app.process_events()

        elif n_tiles is not None and np.prod(n_tiles) > 1:
            n_tiles = tuple(n_tiles)
            app = use_app()

            def progress(it, **kwargs):
                progress_bar.label = "CNN Prediction (tiles)"
                progress_bar.range = (0, kwargs.get("total", 0))
                progress_bar.value = 0
                progress_bar.show()
                app.process_events()
                for item in it:
                    yield item
                    progress_bar.increment()
                    app.process_events()
                #
                progress_bar.label = "NMS Postprocessing"
                progress_bar.range = (0, 0)
                app.process_events()

        else:
            progress = False
            progress_bar.label = "StarDist Prediction"
            progress_bar.range = (0, 0)
            progress_bar.show()
            use_app().process_events()

        # semantic output axes of predictions
        assert model._axes_out[-1] == "C"
        axes_out = list(model._axes_out[:-1])

        if "T" in axes:
            x_reorder = np.moveaxis(x, t, 0)
            axes_reorder = axes.replace("T", "")
            axes_out.insert(t, "T")
            res = tuple(
                zip(*tuple(
                    model.predict_instances(
                        _x,
                        axes=axes_reorder,
                        prob_thresh=prob_thresh,
                        nms_thresh=nms_thresh,
                        n_tiles=n_tiles,
                        scale=input_scale,
                        sparse=(not cnn_output),
                        return_predict=cnn_output,
                    ) for _x in progress(x_reorder))))

            if cnn_output:
                labels, polys = tuple(zip(*res[0]))
                cnn_output = tuple(np.stack(c, t) for c in tuple(zip(*res[1])))
            else:
                labels, polys = res

            labels = np.asarray(labels)

            if len(polys) > 1:
                if timelapse_opts == TimelapseLabels.Match.value:
                    # match labels in consecutive frames (-> simple IoU tracking)
                    labels = group_matching_labels(labels)
                elif timelapse_opts == TimelapseLabels.Unique.value:
                    # make label ids unique (shift by offset)
                    offsets = np.cumsum([len(p["points"]) for p in polys])
                    for y, off in zip(labels[1:], offsets):
                        y[y > 0] += off
                elif timelapse_opts == TimelapseLabels.Separate.value:
                    # each frame processed separately (nothing to do)
                    pass
                else:
                    raise NotImplementedError(
                        f"unknown option '{timelapse_opts}' for time-lapse labels"
                    )

            labels = np.moveaxis(labels, 0, t)

            if isinstance(model, StarDist3D):
                # TODO poly output support for 3D timelapse
                polys = None
            else:
                polys = dict(
                    coord=np.concatenate(
                        tuple(
                            np.insert(p["coord"], t, _t, axis=-2)
                            for _t, p in enumerate(polys)),
                        axis=0,
                    ),
                    points=np.concatenate(
                        tuple(
                            np.insert(p["points"], t, _t, axis=-1)
                            for _t, p in enumerate(polys)),
                        axis=0,
                    ),
                )

            if cnn_output:
                pred = (labels, polys), cnn_output
            else:
                pred = labels, polys

        else:
            # TODO: possible to run this in a way that it can be canceled?
            pred = model.predict_instances(
                x,
                axes=axes,
                prob_thresh=prob_thresh,
                nms_thresh=nms_thresh,
                n_tiles=n_tiles,
                show_tile_progress=progress,
                scale=input_scale,
                sparse=(not cnn_output),
                return_predict=cnn_output,
            )
        progress_bar.hide()

        # determine scale for output axes
        scale_in_dict = dict(zip(axes, image.scale))
        scale_out = [scale_in_dict.get(a, 1.0) for a in axes_out]

        layers = []
        if cnn_output:
            (labels, polys), cnn_out = pred
            prob, dist = cnn_out[:2]
            dist = np.moveaxis(dist, -1, 0)

            assert len(model.config.grid) == len(model.config.axes) - 1
            grid_dict = dict(
                zip(model.config.axes.replace("C", ""), model.config.grid))
            # scale output axes to match input axes
            _scale = [
                s * grid_dict.get(a, 1) for a, s in zip(axes_out, scale_out)
            ]
            # small translation correction if grid > 1 (since napari centers objects)
            _translate = [0.5 * (grid_dict.get(a, 1) - 1) for a in axes_out]

            layers.append((
                dist,
                dict(
                    name="StarDist distances",
                    scale=[1] + _scale,
                    translate=[0] + _translate,
                    **lkwargs,
                ),
                "image",
            ))
            layers.append((
                prob,
                dict(
                    name="StarDist probability",
                    scale=_scale,
                    translate=_translate,
                    **lkwargs,
                ),
                "image",
            ))
        else:
            labels, polys = pred

        if output_type in (Output.Labels.value, Output.Both.value):
            layers.append((
                labels,
                dict(name="StarDist labels",
                     scale=scale_out,
                     opacity=0.5,
                     **lkwargs),
                "labels",
            ))
        if output_type in (Output.Polys.value, Output.Both.value):
            n_objects = len(polys["points"])
            if isinstance(model, StarDist3D):
                surface = surface_from_polys(polys)
                layers.append((
                    surface,
                    dict(
                        name="StarDist polyhedra",
                        contrast_limits=(0, surface[-1].max()),
                        scale=scale_out,
                        colormap=label_colormap(n_objects),
                        **lkwargs,
                    ),
                    "surface",
                ))
            else:
                # TODO: sometimes hangs for long time (indefinitely?) when returning many polygons (?)
                #       seems to be a known issue: https://github.com/napari/napari/issues/2015
                # TODO: coordinates correct or need offset (0.5 or so)?
                shapes = np.moveaxis(polys["coord"], -1, -2)
                layers.append((
                    shapes,
                    dict(
                        name="StarDist polygons",
                        shape_type="polygon",
                        scale=scale_out,
                        edge_width=0.75,
                        edge_color="yellow",
                        face_color=[0, 0, 0, 0],
                        **lkwargs,
                    ),
                    "shapes",
                ))
        return layers