예제 #1
0
def show_plane_segmentation_2d(plane_seg: PlaneSegmentation,
                               color_wheel=color_wheel,
                               color_by=None,
                               threshold=.01,
                               fig=None,
                               width=600):
    """Plot the outline of every ROI in a PlaneSegmentation as a 2D figure.

    Each row of ``plane_seg['image_mask']`` is converted to an outline with
    ``compute_outline`` and drawn as a filled ``go.Scatter`` trace. When
    ``color_by`` names a column of ``plane_seg``, ROIs are colored and
    grouped in the legend by the values of that column.

    Parameters
    ----------
    plane_seg: PlaneSegmentation
        Table holding an ``'image_mask'`` column; row ``i`` of that column
        is outlined for ROI ``i``.
    color_wheel: list
        Colors assigned to the categories of ``color_by``, in order of the
        sorted unique category values (``np.unique``). Cycled when there are
        more categories than colors.
    color_by: str, optional
        Name of a column of ``plane_seg`` used to color ROIs and title the
        figure. When omitted, traces use plotly's default colors and no
        legend entries are shown.
    threshold: float
        Passed to ``compute_outline`` for each image mask.
    fig: plotly.graph_objects.Figure, optional
        Figure to draw into. A new ``go.FigureWidget`` is created if None.
    width: int, optional
        Width of the figure in pixels (default 600).

    Returns
    -------
    plotly.graph_objects.Figure
        ``fig`` (or the newly created widget) with one trace per ROI.

    Raises
    ------
    ValueError
        If ``color_by`` is given but is not a column of ``plane_seg``.
    """
    layout_kwargs = dict()
    if color_by:
        if color_by not in plane_seg:
            raise ValueError(
                'specified color_by parameter, {}, not in plane_seg object'.
                format(color_by))
        cats = np.unique(plane_seg[color_by][:])
        # Title the figure with the column used for coloring.
        layout_kwargs.update(title=color_by)

    image_mask = plane_seg['image_mask']
    n_units = image_mask.data.shape[0]
    if fig is None:
        fig = go.FigureWidget()

    seen_categories = []
    all_hover = df_to_hover_text(plane_seg.to_dataframe())

    for i in range(n_units):
        kwargs = dict(showlegend=False)

        if color_by:
            category = plane_seg[color_by][i]
            # Only the first ROI of each category gets a legend entry; the
            # others join it through the shared legendgroup.
            if category not in seen_categories:
                seen_categories.append(category)
                kwargs.update(showlegend=True)
            cat_index = np.where(cats == category)[0][0]
            kwargs.update(
                # Cycle through the palette instead of raising IndexError
                # when there are more categories than colors.
                line_color=color_wheel[cat_index % len(color_wheel)],
                name=str(category),
                legendgroup=str(category),
            )

        # form cell borders
        x, y = compute_outline(image_mask[i], threshold)

        fig.add_trace(
            go.Scatter(x=x,
                       y=y,
                       fill='toself',
                       mode='lines',
                       text=all_hover[i],
                       # Show the per-ROI ``text`` on hover; setting
                       # ``hovertext='text'`` would display the literal
                       # string "text" instead.
                       hoverinfo='text',
                       line=dict(width=.5),
                       **kwargs))

    fig.update_layout(width=width,
                      yaxis=dict(mirror=True,
                                 scaleanchor="x",
                                 scaleratio=1,
                                 range=[0, image_mask.shape[2]],
                                 constrain='domain'),
                      xaxis=dict(mirror=True,
                                 range=[0, image_mask.shape[1]],
                                 constrain='domain'),
                      # Leave room at the top for the title when one is set.
                      margin=dict(t=30 if color_by else 10, b=10),
                      **layout_kwargs)

    return fig
예제 #2
0
def show_plane_segmentation_2d(plane_seg: PlaneSegmentation,
                               color_wheel: list = color_wheel,
                               color_by: str = None,
                               threshold: float = .01,
                               fig: go.Figure = None,
                               width: int = 600,
                               ref_image=None):
    """Plot the outline of every ROI in a PlaneSegmentation as a 2D figure.

    Each row of ``plane_seg['image_mask']`` is converted to an outline with
    ``compute_outline`` and drawn as a filled ``go.Scatter`` trace,
    optionally on top of a grayscale reference image. When ``color_by``
    names a column of ``plane_seg``, ROIs are colored and grouped in the
    legend by the values of that column.

    Parameters
    ----------
    plane_seg: PlaneSegmentation
        Table holding an ``'image_mask'`` column; row ``i`` of that column
        is outlined for ROI ``i``.
    color_wheel: list, optional
        Colors assigned to the categories of ``color_by``, in order of the
        sorted unique category values (``np.unique``). Cycled when there are
        more categories than colors.
    color_by: str, optional
        Name of a column of ``plane_seg`` used to color ROIs and title the
        figure. When omitted, no legend entries are shown.
    threshold: float, optional
        Passed to ``compute_outline`` for each image mask.
    fig: plotly.graph_objects.Figure, optional
        Figure to draw into. A new ``go.FigureWidget`` is created if None.
    width: int, optional
        width of image in pixels. Height is automatically determined
        to be proportional
    ref_image: image, optional
        Drawn as a grayscale ``go.Heatmap`` that is added before the ROI
        traces, so it sits underneath them. Presumably a 2D array matching
        the image-mask dimensions -- TODO confirm.

    Returns
    -------
    plotly.graph_objects.Figure
        ``fig`` (or the newly created widget) with the ROI traces added.

    Raises
    ------
    ValueError
        If ``color_by`` is given but is not a column of ``plane_seg``.
    """
    # One flag drives both the category setup and its use in the loop, so
    # ``cats`` is always defined whenever it is read.
    use_categories = bool(color_by)

    layout_kwargs = dict()
    if use_categories:
        if color_by not in plane_seg:
            raise ValueError(
                'specified color_by parameter, {}, not in plane_seg object'.
                format(color_by))
        cats = np.unique(plane_seg[color_by][:])
        layout_kwargs.update(title=color_by)

    image_mask = plane_seg['image_mask']
    n_units = image_mask.data.shape[0]
    if fig is None:
        fig = go.FigureWidget()

    if ref_image is not None:
        fig.add_trace(
            go.Heatmap(z=ref_image,
                       hoverinfo='skip',
                       showscale=False,
                       colorscale='gray'))

    seen_categories = []
    all_hover = df_to_hover_text(plane_seg.to_dataframe())

    for i in range(n_units):
        kwargs = dict(showlegend=False)
        if use_categories:
            category = plane_seg[color_by][i]
            # Only the first ROI of each category gets a legend entry; the
            # others join it through the shared legendgroup.
            if category not in seen_categories:
                kwargs.update(showlegend=True)
                seen_categories.append(category)
            cat_index = np.where(cats == category)[0][0]
            kwargs.update(
                # Cycle through the palette instead of raising IndexError
                # when there are more categories than colors.
                line_color=color_wheel[cat_index % len(color_wheel)],
                name=str(category),
                legendgroup=str(category),
            )

        # form cell borders
        x, y = compute_outline(image_mask[i], threshold)

        fig.add_trace(
            go.Scatter(x=x,
                       y=y,
                       fill='toself',
                       mode='lines',
                       text=all_hover[i],
                       # Show the per-ROI ``text`` on hover; setting
                       # ``hovertext='text'`` would display the literal
                       # string "text" instead.
                       hoverinfo='text',
                       line=dict(width=.5),
                       **kwargs))

    fig.update_layout(width=width,
                      yaxis=dict(mirror=True,
                                 scaleanchor="x",
                                 scaleratio=1,
                                 range=[0, image_mask.shape[2]],
                                 constrain='domain'),
                      xaxis=dict(mirror=True,
                                 range=[0, image_mask.shape[1]],
                                 constrain='domain'),
                      margin=dict(t=30, b=10),
                      **layout_kwargs)
    return fig