Beispiel #1
0
def snr_and_chi2(
    data: torch.Tensor,
    height: torch.Tensor,
    width: torch.Tensor,
    x: torch.Tensor,
    y: torch.Tensor,
    target_locs: torch.Tensor,
    background: torch.Tensor,
    gain: float,
    offset_mean: float,
    offset_var: float,
    P: int,
    theta_probs: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    r"""
    Calculate the signal-to-noise ratio of each spot and a per-image
    :math:`\chi^2` goodness-of-fit statistic.

    Spot weights (spot images :math:`G_{knfij}` from ``gaussian_spots``
    divided by the spot height):

    .. math::
        \omega_{knfij} = G_{knfij} / h_{knf}

    Signal:

    .. math::
        \mu_{knf} = \sum_{ij}
        \left( D_{nfij} - b_{nf} - \mu_{\text{offset}} \right) \omega_{knfij}

    Noise:

    .. math::
        \sigma^2_{nf} = \sigma^2_{\text{offset}} + b_{nf} \, \text{gain}

    Signal-to-noise ratio:

    .. math::
        \text{SNR}_{knf} = \dfrac{\mu_{knf}}{\sigma_{nf}}

    Chi-squared statistic (mean over the two pixel axes):

    .. math::
        \chi^2_{nf} = \operatorname{mean}_{ij}
        \dfrac{\left( D_{nfij} - \hat{I}_{nfij} - \mu_{\text{offset}} \right)^2}
        {\hat{I}_{nfij}},
        \qquad
        \hat{I}_{nfij} = b_{nf} + \sum_k G_{knfij}

    Args:
        data: Observed images; the last two dims are pixels.
        height, width, x, y: Spot parameters passed to ``gaussian_spots``.
        target_locs: Target locations passed to ``gaussian_spots``.
        background: Background per image (broadcast over the pixel dims).
        gain: Camera gain.
        offset_mean: Mean of the camera offset.
        offset_var: Variance of the camera offset.
        P: Image size passed to ``gaussian_spots``.
        theta_probs: Not used by this function; kept for signature
            compatibility with callers.

    Returns:
        ``(snr, chi2)``. ``snr`` holds one SNR per spot, without any
        ``theta`` masking. ``chi2`` is the pixel-averaged chi-squared per image.
    """
    gaussians = gaussian_spots(
        height,
        width,
        x,
        y,
        target_locs,
        P,
    )

    # snr
    # Dividing by the height turns each spot image into a weight map.
    weights = gaussians / height[..., None, None]
    signal = ((data - background[..., None, None] - offset_mean) *
              weights).sum(dim=(-2, -1))
    noise = (offset_var + background * gain).sqrt()
    snr_result = signal / noise

    # chi2 test
    # NOTE(review): sum(-5) presumably collapses the spot index k of
    # ``gaussians``; this depends on the rank of ``height`` -- confirm.
    img_ideal = background[..., None, None] + gaussians.sum(-5)
    chi2_result = (data - img_ideal - offset_mean)**2 / img_ideal

    return snr_result, chi2_result.mean(dim=(-1, -2))
Beispiel #2
0
def snr(
    data: torch.Tensor,
    width: torch.Tensor,
    x: torch.Tensor,
    y: torch.Tensor,
    target_locs: torch.Tensor,
    background: torch.Tensor,
    gain: float,
    offset_mean: float,
    offset_var: float,
    P: int,
    theta_probs: torch.Tensor,
) -> torch.Tensor:
    r"""
    Calculate the signal-to-noise ratio of target-specific spots.

    Spot weights (unit-height spot images from ``gaussian_spots``):

    .. math::
        \omega_{knfij} = \mathcal{N}(i, j \mid x_{knf}, y_{knf}, w_{knf})

    Signal:

    .. math::
        \mu_{knf} = \sum_{ij}
        \left( D_{nfij} - b_{nf} - \mu_{\text{offset}} \right) \omega_{knfij}

    Noise:

    .. math::
        \sigma^2_{nf} = \sigma^2_{\text{offset}} + b_{nf} \, \text{gain}

    Signal-to-noise ratio:

    .. math::
        \text{SNR}_{knf} = \dfrac{\mu_{knf}}{\sigma_{nf}}
        \text{ for } \theta_{nf} = k

    Args:
        data: Observed images; the last two dims are pixels.
        width, x, y: Spot parameters passed to ``gaussian_spots``.
        target_locs: Target locations passed to ``gaussian_spots``.
        background: Background per image (broadcast over the pixel dims).
        gain: Camera gain.
        offset_mean: Mean of the camera offset.
        offset_var: Variance of the camera offset.
        P: Image size passed to ``gaussian_spots``.
        theta_probs: Probabilities used to select spots; an entry is kept
            when it is greater than 0.5. It must be broadcast-compatible with
            the SNR tensor for boolean indexing.

    Returns:
        1-D tensor of the SNR values whose ``theta_probs > 0.5``
        (boolean-mask indexing flattens the result).
    """
    weights = gaussian_spots(
        # Unit height, so the spots are pure weights. Create the tensor on the
        # inputs' device/dtype instead of hard-coding the CPU, which broke
        # GPU inputs.
        torch.ones(1, device=width.device, dtype=width.dtype),
        width,
        x,
        y,
        target_locs,
        P,
    )
    signal = ((data - background[..., None, None] - offset_mean) * weights).sum(
        dim=(-2, -1)
    )
    noise = (offset_var + background * gain).sqrt()
    result = signal / noise
    mask = theta_probs > 0.5
    return result[mask]
Beispiel #3
0
def updateRange(f1, n, model, fig, item, ax, zoom, targets, fov):
    """
    Move the 15-frame viewing window of AOI ``n`` to start at frame ``f1``.

    Steps:

    1. Refresh the raw and ideal image strip. Target markers are also
       refreshed when ``targets.value`` is truthy.
    2. Update the time-series axes. If ``zoom.value`` is falsy, set their
       x-limits to the window. Otherwise, move their highlighted span to it.
    3. If ``fov`` is not None, replot the glimpse for the window's first frame.
    4. Redraw ``fig``.

    ``f1`` and ``n`` are resolved through ``get_value``. ``item`` (artists)
    and ``ax`` (axes) are dicts that are updated in place.
    """
    aoi = get_value(n)
    first = get_value(f1)
    last = first + 15  # exclusive end of the window
    data = model.data
    params = model.params

    # Ideal (noise-free) images: offset + background + sum of the spot images.
    window = torch.arange(first, last)
    baseline = (
        data.offset.mean + params["background"]["Mean"][aoi, window, :, None, None]
    )
    spots = gaussian_spots(
        params["height"]["Mean"][:, aoi, window],
        params["width"]["Mean"][:, aoi, window],
        params["x"]["Mean"][:, aoi, window],
        params["y"]["Mean"][:, aoi, window],
        data.xy[aoi, window],
        data.P,
    )
    # NOTE(review): sum(-5) presumably collapses the spot index -- confirm.
    ideal = baseline + spots.sum(-5)

    for c in range(data.C):
        for col, frame in enumerate(range(first, last)):
            image_ax = ax[f"image_f{col}_c{c}"]
            image_ax.set_title(rf"${frame}$", fontsize=9)
            item[f"image_f{col}_c{c}"].set_data(data.images[aoi, frame, c].numpy())
            item[f"ideal_f{col}_c{c}"].set_data(ideal[col, c].numpy())
            if not targets.value:
                continue
            # Replace the target marker with one at this frame's target position.
            item[f"target_f{col}_c{c}"].remove()
            item[f"target_f{col}_c{c}"] = image_ax.scatter(
                data.x[aoi, frame, c].item(),
                data.y[aoi, frame, c].item(),
                c="C0",
                s=40,
                marker="+",
            )

    if zoom.value:
        # Zoomed view: keep the axis limits and move the highlighted span.
        time_panels = (
            "z_map",
            "p_specific",
            "height",
            "width",
            "x",
            "y",
            "background",
            "chi2",
        )
        for panel in time_panels:
            item[f"{panel}_vspan"].remove()
            item[f"{panel}_vspan"] = ax[panel].axvspan(
                first, last, facecolor="C0", alpha=0.3
            )
    else:
        # Un-zoomed view: restrict every non-image axis to the window.
        for key, axis in ax.items():
            if key.startswith(("image", "ideal", "glimpse")):
                continue
            axis.set_xlim(first - 0.5, last - 0.5)

    if fov is not None:
        # NOTE(review): the same ``fov`` object is plotted on every channel's
        # glimpse axis; confirm that this is intended for multi-channel data.
        for c in range(data.C):
            fov.plot(
                fov.dtypes,
                data.P,
                n=aoi,
                f=first,
                save=False,
                ax=ax[f"glimpse_c{c}"],
                item=item,
            )
    fig.canvas.draw()
Beispiel #4
0
def updateParams(
    n,
    f1,
    model,
    fig,
    item,
    ax,
    targets,
    nonspecific,
    labels,
    fov_controls,
    exclude_aoi,
    show_fov,
):
    """
    Redraw every panel of the viewer for the currently selected AOI.

    Updates the following, then redraws ``fig``:

    - the 15-frame image strip (raw and ideal images, optional target markers);
    - the per-frame traces: p_specific, z_map, h, w, x, y, background, chi2,
      and optionally labels;
    - the AOI highlight in the field-of-view panels, when enabled.

    Args:
        n: AOI selector. ``n.old`` is read as the previously selected AOI and
            ``get_value(n)`` as the new one. Presumably a widget change
            object -- TODO confirm.
        f1: First frame of the image strip, resolved with ``get_value``.
        model: Loaded model providing ``data``, ``params``, ``Q`` and ``K``.
        fig: Matplotlib figure redrawn at the end.
        item: Dict of artists; entries are removed and replaced in place.
        ax: Dict of axes keyed by panel name.
        targets: Control whose ``.value`` toggles redrawing target markers.
        nonspecific: Control whose ``.value`` toggles redrawing the
            target-nonspecific spots.
        labels: Control whose ``.value`` adds the "labels" trace.
        fov_controls: Mapping from "ontarget"/"offtarget" to controls whose
            ``.value`` gives the visibility of those AOIs.
        exclude_aoi: Control whose ``.value`` is set to ``not mask[n]``.
        show_fov: Flag or control, read through ``get_value``, that enables
            the FOV update.
    """
    n_old = n.old
    n = get_value(n)
    f1 = get_value(f1)
    # The image strip always spans 15 frames.
    f2 = f1 + 15
    # One color per q for AOIs kept by the mask; grey ("C7") for masked-out AOIs.
    color = (
        [f"C{2+q}" for q in range(model.Q)] if model.data.mask[n] else ["C7"] * model.Q
    )

    # Keep the "exclude AOI" control in sync with the data mask.
    exclude_aoi.value = not model.data.mask[n]

    # Ideal (noise-free) images: offset + background + sum of the spot images.
    frames = torch.arange(f1, f2)
    img_ideal = (
        model.data.offset.mean
        + model.params["background"]["Mean"][n, frames, :, None, None]
    )
    gaussian = gaussian_spots(
        model.params["height"]["Mean"][:, n, frames],
        model.params["width"]["Mean"][:, n, frames],
        model.params["x"]["Mean"][:, n, frames],
        model.params["y"]["Mean"][:, n, frames],
        model.data.xy[n, frames],
        model.data.P,
    )
    # NOTE(review): sum(-5) presumably collapses the spot index -- confirm.
    img_ideal = img_ideal + gaussian.sum(-5)
    for c in range(model.data.C):
        for i, f in enumerate(range(f1, f2)):
            ax[f"image_f{i}_c{c}"].set_title(rf"${f}$", fontsize=9)
            item[f"image_f{i}_c{c}"].set_data(model.data.images[n, f, c].numpy())
            item[f"ideal_f{i}_c{c}"].set_data(img_ideal[i, c].numpy())
            if targets.value:
                # Replace the marker with one at this frame's target position.
                item[f"target_f{i}_c{c}"].remove()
                item[f"target_f{i}_c{c}"] = ax[f"image_f{i}_c{c}"].scatter(
                    model.data.x[n, f, c].item(),
                    model.data.y[n, f, c].item(),
                    c="C0",
                    s=40,
                    marker="+",
                )

    # Names of the time-series panels to refresh.
    params = [
        "p_specific",
        "z_map",
        "height",
        "width",
        "x",
        "y",
        "background",
        "chi2",
    ]
    if labels.value:
        params += ["labels"]
    for q in range(model.Q):
        # Spots with theta_probs > 0.5 are treated as target-specific.
        theta_mask = model.params["theta_probs"][:, n, :, q] > 0.5
        # Spots with m_probs > 0.5 that are not target-specific are nonspecific.
        j_mask = (model.params["m_probs"][:, n, :, q] > 0.5) & ~theta_mask
        for p in params:
            if p == "p_specific":
                item[f"{p}_q{q}"].set_ydata(model.params[p][n, :, q])
                item[f"{p}_q{q}"].set_color(color[q])
            # z_map is only updated when the model provides it.
            elif p in {"z_map"}.intersection(model.params.keys()):
                item[f"{p}_q{q}"].set_ydata(model.params[p][n, :, q])
            elif p == "labels":
                # A single artist is shared by all q; the last q written wins.
                item[p].set_ydata(model.data.labels["z"][n, :, q])
            elif p in ["height", "width", "x", "y"]:
                # target-nonspecific spots
                if nonspecific.value:
                    for k in range(model.K):
                        f_mask = j_mask[k]
                        mean = model.params[p]["Mean"][k, n, :, q] * f_mask
                        ll = model.params[p]["LL"][k, n, :, q] * f_mask
                        ul = model.params[p]["UL"][k, n, :, q] * f_mask
                        # Remove the old artists and plot new ones (dots only on
                        # frames where spot k is nonspecific).
                        item[f"{p}_nonspecific{k}_mean_q{q}"].remove()
                        (item[f"{p}_nonspecific{k}_mean_q{q}"],) = ax[p].plot(
                            torch.arange(0, model.data.F)[f_mask],
                            mean[f_mask],
                            "o",
                            ms=2,
                            lw=1,
                            color=f"C{k}",
                        )
                        item[f"{p}_nonspecific{k}_fill_q{q}"].remove()
                        item[f"{p}_nonspecific{k}_fill_q{q}"] = ax[p].fill_between(
                            torch.arange(0, model.data.F),
                            ll,
                            ul,
                            where=f_mask,
                            alpha=0.3,
                            color=f"C{k}",
                        )
                # target-specific spots
                # Masking with theta_mask and summing over k collapses the spot
                # axis. This assumes at most one spot per frame passes
                # theta_mask; otherwise their values would be added.
                f_mask = theta_mask.sum(0).bool()
                mean = (model.params[p]["Mean"][:, n, :, q] * theta_mask).sum(0)
                ll = (model.params[p]["LL"][:, n, :, q] * theta_mask).sum(0)
                ul = (model.params[p]["UL"][:, n, :, q] * theta_mask).sum(0)
                item[f"{p}_specific_fill_q{q}"].remove()
                item[f"{p}_specific_fill_q{q}"] = ax[p].fill_between(
                    torch.arange(0, model.data.F),
                    ll,
                    ul,
                    where=f_mask,
                    alpha=0.3,
                    color=color[q],
                )
                # Height is drawn as a line over all frames; the other
                # parameters as dots only on frames with a target-specific spot.
                item[f"{p}_specific_mean_q{q}"].remove()
                (item[f"{p}_specific_mean_q{q}"],) = ax[p].plot(
                    torch.arange(0, model.data.F)
                    if p == "height"
                    else torch.arange(0, model.data.F)[f_mask],
                    mean if p == "height" else mean[f_mask],
                    "-" if p == "height" else "o",
                    ms=2,
                    color=color[q],
                )

    # Per-channel panels.
    # NOTE(review): ``color`` has model.Q entries but is indexed by channel c
    # here; this assumes model.data.C <= model.Q -- confirm.
    for c in range(model.data.C):
        for p in params:
            if p == "chi2":
                item[f"{p}_c{c}"].set_ydata(model.params[p]["values"][n, :, c])
            elif p == "background":
                item[f"{p}_fill_c{c}"].remove()
                item[f"{p}_fill_c{c}"] = ax[p].fill_between(
                    torch.arange(0, model.data.F),
                    model.params[p]["LL"][n, :, c],
                    model.params[p]["UL"][n, :, c],
                    alpha=0.3,
                    color=color[c],
                )
                item[f"{p}_mean_c{c}"].set_ydata(model.params[p]["Mean"][n, :, c])

    if get_value(show_fov):
        ax["glimpse_c0"].set_title(rf"AOI ${n}$, Frame ${f1}$", fontsize=9)
        # AOI indices below model.data.N are treated as on-target, the rest as
        # off-target.
        n_old_dtype = "ontarget" if n_old < model.data.N else "offtarget"
        n_old_visible = fov_controls[n_old_dtype].value
        colors = {"ontarget": "#AA3377", "offtarget": "#CCBB44"}
        for c in range(model.data.C):
            # Highlight the newly selected AOI and draw it above the others.
            item[f"aoi_n{n}_c{c}"].set_edgecolor(color[c])
            item[f"aoi_n{n}_c{c}"].set(zorder=2, visible=True)
            # Restore the previous AOI's default color and visibility.
            item[f"aoi_n{n_old}_c{c}"].set_edgecolor(colors[n_old_dtype])
            item[f"aoi_n{n_old}_c{c}"].set(zorder=1, visible=n_old_visible)
    fig.canvas.draw()
Beispiel #5
0
def show(
    model: avail_models = typer.Option("cosmos",
                                       help="Tapqir model",
                                       prompt="Tapqir model"),
    n: int = typer.Option(0, help="n", prompt="n"),
    f1: Optional[int] = None,
    f2: Optional[int] = None,
    show_fov: bool = True,
    gui=None,
):
    """
    Build the interactive viewer figure for AOI ``n``.

    The figure contains, per channel, a strip of 15 raw images starting at
    frame ``f1`` and the matching ideal (noise-free) images. Below the strips
    are time-series panels for z_map, p_specific, h, w, x, y, background and
    chi2, whose x-range is ``[f1, f2]``. When ``show_fov`` is true, a
    field-of-view glimpse panel is added per channel.

    Args:
        model: Name of the model to load (key into ``tapqir.models.models``).
        n: AOI index.
        f1: First frame of the image strip and of the time axis (default 0).
        f2: End of the time-axis range (default ``f1 + 15``).
        show_fov: Whether to add the field-of-view panels.
        gui: When falsy, ``plt.show()`` is called.

    Returns:
        ``(model, fig, item, ax, fov)``. ``item`` and ``ax`` are dicts of
        artists and axes; ``fov`` is the last channel's ``GlimpseDataset``,
        or None when ``show_fov`` is false. Returns ``1`` if the model files
        fail to load.
    """
    from tapqir.imscroll import GlimpseDataset
    from tapqir.models import models

    logger = logging.getLogger("tapqir")

    global DEFAULTS
    cd = DEFAULTS["cd"]

    model = models[model](device="cpu", dtype="float")
    try:
        model.load(cd, data_only=False)
    except TapqirFileNotFoundError as err:
        logger.exception(f"Failed to load {err.name} file")
        return 1
    if f1 is None:
        f1 = 0
    if f2 is None:
        f2 = f1 + 15

    # Number of frames in the image strip (one gridspec column per frame).
    strip_len = 15

    width, dpi = 6.25, 100
    s = 2 * model.data.C
    height = 4.45 + (1.875 + 0.9) * s if show_fov else 4.45 + 0.9 * s
    fig = plt.figure(figsize=(width, height), dpi=dpi)
    gs = fig.add_gridspec(
        nrows=8 + s,
        ncols=strip_len,
        top=0.96,
        bottom=0.39 if show_fov else 0.02,
        left=0.1,
        right=0.98,
        hspace=0.1,
        height_ratios=[0.9] * s + [1, 1, 1, 1, 1, 1, 1, 1],
    )
    if show_fov:
        gs2 = fig.add_gridspec(
            nrows=model.data.C,
            ncols=1,
            top=0.32,
            bottom=0.02,
            left=0.1,
            right=0.98,
        )
    ax = {}
    item = {}

    # The strip always covers frames f1 .. f1 + strip_len - 1, independently
    # of f2. Previously a window with f2 - f1 < 15 raised IndexError below.
    frames = torch.arange(f1, f1 + strip_len)
    img_ideal = (model.data.offset.mean +
                 model.params["background"]["Mean"][n, frames, :, None, None])
    gaussian = gaussian_spots(
        model.params["height"]["Mean"][:, n, frames],
        model.params["width"]["Mean"][:, n, frames],
        model.params["x"]["Mean"][:, n, frames],
        model.params["y"]["Mean"][:, n, frames],
        model.data.xy[n, frames],
        model.data.P,
    )
    # NOTE(review): sum(-5) presumably collapses the spot index -- confirm.
    img_ideal = img_ideal + gaussian.sum(-5)
    for c in range(model.data.C):
        for f in range(strip_len):
            # Column f shows absolute frame f1 + f. The raw image previously
            # used frame f, which mismatched the ideal image whenever f1 != 0.
            frame = f1 + f
            ax[f"image_f{f}_c{c}"] = fig.add_subplot(gs[0 + c * 2, f])
            item[f"image_f{f}_c{c}"] = ax[f"image_f{f}_c{c}"].imshow(
                model.data.images[n, frame, c].numpy(),
                vmin=model.data.vmin[c] - 20,
                vmax=model.data.vmax[c] + 30,
                cmap="gray",
            )
            if c == 0:
                ax[f"image_f{f}_c{c}"].set_title(rf"${frame}$", fontsize=9)
            ax[f"image_f{f}_c{c}"].axis("off")
            ax[f"ideal_f{f}_c{c}"] = fig.add_subplot(gs[1 + c * 2, f])
            item[f"ideal_f{f}_c{c}"] = ax[f"ideal_f{f}_c{c}"].imshow(
                img_ideal[f, c].numpy(),
                vmin=model.data.vmin[c] - 20,
                vmax=model.data.vmax[c] + 30,
                cmap="gray",
            )
            ax[f"ideal_f{f}_c{c}"].axis("off")

    ax["z_map"] = fig.add_subplot(gs[s, :])
    config_axis(ax["z_map"], r"$z$", f1, f2, -0.1, model.S + 0.1)
    ax["p_specific"] = fig.add_subplot(gs[s + 1, :])
    config_axis(ax["p_specific"], r"$p(\mathsf{specific})$", f1, f2, -0.1, 1.1)

    ax["height"] = fig.add_subplot(gs[s + 2, :])
    config_axis(
        ax["height"],
        r"$h$",
        f1,
        f2,
        model.params["height"]["vmin"],
        model.params["height"]["vmax"],
    )

    ax["width"] = fig.add_subplot(gs[s + 3, :])
    config_axis(
        ax["width"],
        r"$w$",
        f1,
        f2,
        model.params["width"]["vmin"],
        model.params["width"]["vmax"],
    )

    ax["x"] = fig.add_subplot(gs[s + 4, :])
    config_axis(ax["x"], r"$x$", f1, f2, model.params["x"]["vmin"],
                model.params["x"]["vmax"])

    ax["y"] = fig.add_subplot(gs[s + 5, :])
    config_axis(ax["y"], r"$y$", f1, f2, model.params["y"]["vmin"],
                model.params["y"]["vmax"])

    ax["background"] = fig.add_subplot(gs[s + 6, :])
    config_axis(
        ax["background"],
        r"$b$",
        f1,
        f2,
        model.params["background"]["vmin"],
        model.params["background"]["vmax"],
    )

    ax["chi2"] = fig.add_subplot(gs[s + 7, :])
    config_axis(
        ax["chi2"],
        r"$\chi^2 \mathsf{test}$",
        f1,
        f2,
        model.params["chi2"]["vmin"],
        model.params["chi2"]["vmax"],
        True,
    )
    ax["chi2"].set_xlabel("Time (frame)")

    # One color per q for AOIs kept by the mask; grey ("C7") for masked-out AOIs.
    color = ([f"C{2+q}"
              for q in range(model.Q)] if model.data.mask[n] else ["C7"] *
             model.Q)
    for q in range(model.Q):
        (item[f"z_map_q{q}"], ) = ax["z_map"].plot(
            torch.arange(0, model.data.F),
            model.params["z_map"][n, :, q],
            "-",
            lw=1,
            color=color[q],
        )

        (item[f"p_specific_q{q}"], ) = ax["p_specific"].plot(
            torch.arange(0, model.data.F),
            model.params["p_specific"][n, :, q],
            "o-",
            ms=3,
            lw=1,
            color=color[q],
        )
        # Target-specific spots: theta_probs > 0.5. Nonspecific spots:
        # m_probs > 0.5 and not target-specific.
        theta_mask = model.params["theta_probs"][:, n, :, q] > 0.5
        j_mask = (model.params["m_probs"][:, n, :, q] > 0.5) & ~theta_mask
        for p in ["height", "width", "x", "y"]:
            # target-nonspecific spots
            for k in range(model.K):
                f_mask = j_mask[k]
                mean = model.params[p]["Mean"][k, n, :, q] * f_mask
                ll = model.params[p]["LL"][k, n, :, q] * f_mask
                ul = model.params[p]["UL"][k, n, :, q] * f_mask
                (item[f"{p}_nonspecific{k}_mean_q{q}"], ) = ax[p].plot(
                    torch.arange(0, model.data.F)[f_mask],
                    mean[f_mask],
                    "o",
                    ms=2,
                    lw=1,
                    color=f"C{k}",
                )
                item[f"{p}_nonspecific{k}_fill_q{q}"] = ax[p].fill_between(
                    torch.arange(0, model.data.F),
                    ll,
                    ul,
                    where=f_mask,
                    alpha=0.3,
                    color=f"C{k}",
                )
            # target-specific spots
            # Masking and summing over k collapses the spot axis; this assumes
            # at most one spot per frame passes theta_mask.
            f_mask = theta_mask.sum(0).bool()
            mean = (model.params[p]["Mean"][:, n, :, q] * theta_mask).sum(0)
            ll = (model.params[p]["LL"][:, n, :, q] * theta_mask).sum(0)
            ul = (model.params[p]["UL"][:, n, :, q] * theta_mask).sum(0)
            (item[f"{p}_specific_mean_q{q}"], ) = ax[p].plot(
                torch.arange(0, model.data.F)
                if p == "height" else torch.arange(0, model.data.F)[f_mask],
                mean if p == "height" else mean[f_mask],
                "-" if p == "height" else "o",
                ms=2,
                color=color[q],
            )
            item[f"{p}_specific_fill_q{q}"] = ax[p].fill_between(
                torch.arange(0, model.data.F),
                ll,
                ul,
                where=f_mask,
                alpha=0.3,
                color=color[q],
            )
    # NOTE(review): ``color`` has model.Q entries but is indexed by channel c
    # below; this assumes model.data.C <= model.Q -- confirm.
    for c in range(model.data.C):
        (item[f"background_mean_c{c}"], ) = ax["background"].plot(
            torch.arange(0, model.data.F),
            model.params["background"]["Mean"][n, :, c],
            "o-",
            ms=3,
            lw=1,
            color=color[c],
        )
        item[f"background_fill_c{c}"] = ax["background"].fill_between(
            torch.arange(0, model.data.F),
            model.params["background"]["LL"][n, :, c],
            model.params["background"]["UL"][n, :, c],
            alpha=0.3,
            color=color[c],
        )

        (item[f"chi2_c{c}"], ) = ax["chi2"].plot(
            torch.arange(0, model.data.F),
            model.params["chi2"]["values"][n, :, c],
            "-",
            lw=1,
            color=color[c],
        )

    if show_fov:
        # Pop "P" and "channels" from a copy. Popping them from the
        # module-level DEFAULTS made any second call raise KeyError.
        fov_kwargs = dict(DEFAULTS)
        P = fov_kwargs.pop("P")
        channels = fov_kwargs.pop("channels")
        for c in range(model.data.C):
            ax[f"glimpse_c{c}"] = fig.add_subplot(gs2[c])
            fov = GlimpseDataset(**fov_kwargs, **channels[c], c=c)
            fov.plot(
                fov.dtypes,
                P,
                n=0,
                f=0,
                save=False,
                ax=ax[f"glimpse_c{c}"],
                item=item,
            )
        # NOTE(review): only the last channel's ``fov`` is returned below.
    else:
        fov = None

    if not gui:
        plt.show()
    return model, fig, item, ax, fov
Beispiel #6
0
def show(
    model: Model = typer.Option("cosmos", help="Tapqir model", prompt="Tapqir model"),
    channels: List[int] = typer.Option(
        [0],
        help="Color-channel numbers to analyze",
        prompt="Channel numbers (space separated if multiple)",
    ),
    n: int = typer.Option(0, help="n", prompt="n"),
    f1: Optional[int] = None,
    f2: Optional[int] = None,
):
    """
    Build the viewer figure for AOI ``n`` of the selected channel.

    The figure shows a strip of 15 raw images starting at frame ``f1`` and
    the matching ideal (noise-free) images. Below them are time-series panels
    for p(specific) (``z_probs``), h, w, x, y and background, whose x-range
    is ``[f1, f2]``.

    Args:
        model: Name of the model to load (key into ``tapqir.models.models``).
        channels: Channel numbers passed to the model constructor.
        n: AOI index.
        f1: First frame of the image strip and of the time axis (default 0).
        f2: End of the time-axis range (default ``f1 + 15``). Previously this
            value was overwritten with 15 unconditionally.

    Returns:
        ``(model, fig, item, ax)``, where ``item`` and ``ax`` are dicts of
        artists and axes.
    """
    from tapqir.models import models

    global DEFAULTS
    cd = DEFAULTS["cd"]

    model = models[model](1, 2, channels, "cpu", "float")
    model.load(cd, data_only=False)
    if f1 is None:
        f1 = 0
    if f2 is None:
        # Same effective default as before (15 when f1 is 0), but a
        # user-supplied f2 is no longer discarded.
        f2 = f1 + 15
    c = model.cdx
    # Number of frames in the image strip (one gridspec column per frame).
    strip_len = 15

    width, height, dpi = 6.25, 5, 100
    fig = plt.figure(figsize=(width, height), dpi=dpi)
    gs = fig.add_gridspec(
        nrows=8,
        ncols=strip_len,
        top=0.95,
        bottom=0.05,
        left=0.1,
        right=0.98,
        hspace=0.1,
        height_ratios=[0.9, 0.9, 1, 1, 1, 1, 1, 1],
    )
    ax = {}
    item = {}

    # The strip always covers frames f1 .. f1 + strip_len - 1. Previously a
    # window shorter than 15 frames raised IndexError below.
    frames = torch.arange(f1, f1 + strip_len)
    img_ideal = (
        model.data.offset.mean
        + model.params["background"]["Mean"][n, frames, None, None]
    )
    gaussian = gaussian_spots(
        model.params["height"]["Mean"][:, n, frames],
        model.params["width"]["Mean"][:, n, frames],
        model.params["x"]["Mean"][:, n, frames],
        model.params["y"]["Mean"][:, n, frames],
        model.data.xy[n, frames, c],
        model.data.P,
    )
    # NOTE(review): sum(-4) presumably collapses the spot index -- confirm.
    img_ideal = img_ideal + gaussian.sum(-4)
    for f in range(strip_len):
        # Column f shows absolute frame f1 + f. The raw image previously used
        # frame f, which mismatched the ideal image whenever f1 != 0.
        frame = f1 + f
        ax[f"image_{f}"] = fig.add_subplot(gs[0, f])
        item[f"image_{f}"] = ax[f"image_{f}"].imshow(
            model.data.images[n, frame, c].numpy(),
            vmin=model.data.vmin - 50,
            vmax=model.data.vmax + 50,
            cmap="gray",
        )
        ax[f"image_{f}"].set_title(frame)
        ax[f"image_{f}"].axis("off")
        ax[f"ideal_{f}"] = fig.add_subplot(gs[1, f])
        item[f"ideal_{f}"] = ax[f"ideal_{f}"].imshow(
            img_ideal[f].numpy(),
            vmin=model.data.vmin - 50,
            vmax=model.data.vmax + 50,
            cmap="gray",
        )
        ax[f"ideal_{f}"].axis("off")

    ax["pspecific"] = fig.add_subplot(gs[2, :])
    config_axis(ax["pspecific"], r"$p(\mathsf{specific})$", f1, f2, -0.1, 1.1)
    (item["pspecific"],) = ax["pspecific"].plot(
        torch.arange(0, model.data.F),
        model.params["z_probs"][n],
        "o-",
        ms=3,
        lw=1,
        color="C2",
    )

    ax["height"] = fig.add_subplot(gs[3, :])
    config_axis(ax["height"], r"$h$", f1, f2, -100, 8000)

    ax["width"] = fig.add_subplot(gs[4, :])
    config_axis(ax["width"], r"$w$", f1, f2, 0.5, 2.5)

    ax["x"] = fig.add_subplot(gs[5, :])
    config_axis(ax["x"], r"$x$", f1, f2, -9, 9)

    ax["y"] = fig.add_subplot(gs[6, :])
    config_axis(ax["y"], r"$y$", f1, f2, -9, 9)

    ax["background"] = fig.add_subplot(gs[7, :])
    config_axis(ax["background"], r"$b$", f1, f2, 0, 500, True)
    ax["background"].set_xlabel("Time (frame)")

    # One mean line plus a LL..UL band per spot k, for every spot parameter.
    for p in ["height", "width", "x", "y"]:
        for k in range(model.K):
            (item[f"{p}_{k}_mean"],) = ax[p].plot(
                torch.arange(0, model.data.F),
                model.params[p]["Mean"][k, n],
                "o-",
                ms=3,
                lw=1,
                color=f"C{k}",
            )
            item[f"{p}_{k}_fill"] = ax[p].fill_between(
                torch.arange(0, model.data.F),
                model.params[p]["LL"][k, n],
                model.params[p]["UL"][k, n],
                alpha=0.3,
                color=f"C{k}",
            )
    (item["background_mean"],) = ax["background"].plot(
        torch.arange(0, model.data.F),
        model.params["background"]["Mean"][n],
        "o-",
        ms=3,
        lw=1,
        color="k",
    )
    item["background_fill"] = ax["background"].fill_between(
        torch.arange(0, model.data.F),
        model.params["background"]["LL"][n],
        model.params["background"]["UL"][n],
        alpha=0.3,
        color="k",
    )
    plt.show()
    return model, fig, item, ax