コード例 #1
0
ファイル: hmm.py プロジェクト: gelles-brandeis/tapqir
    def model(self):
        """
        **Generative Model**

        Hidden-Markov variant: for each AOI and color channel the hidden
        state ``z`` evolves across frames, drawn from ``init`` on the first
        frame and from the row of ``trans`` selected by the previous state
        afterwards. Frames are iterated with ``pyro.vectorized_markov`` when
        ``self.vectorized`` is true, otherwise with ``pyro.markov``.

        Per-frame sample sites (suffix = frame label produced by the frame
        iterator): ``background_f*``, ``z_f*``, ``theta_f*``; per spot ``k``:
        ``m_k{k}_f*``, ``height_k{k}_f*``, ``width_k{k}_f*``, ``x_k{k}_f*``,
        ``y_k{k}_f*``; observations are scored at ``data_f*``.
        """
        # global parameters
        gain = pyro.sample("gain", dist.HalfNormal(self.priors["gain_std"]))
        # Initial hidden-state probabilities: a Dirichlet over the 1+S states
        # for each of the Q rows, i.e. a (Q, 1+S) sample.
        init = pyro.sample(
            "init",
            dist.Dirichlet(torch.ones(self.Q, self.S + 1) /
                           (self.S + 1)).to_event(1),
        )
        # NOTE(review): expand_offtarget is defined elsewhere; judging by the
        # indexing below, its result gains a trailing axis that is selected
        # with is_ontarget — confirm against its definition.
        init = expand_offtarget(init)
        # Transition probabilities: a (Q, 1+S, 1+S) sample; the last axis is
        # a Dirichlet over the next state for each previous state.
        trans = pyro.sample(
            "trans",
            dist.Dirichlet(
                torch.ones(self.Q, self.S + 1, self.S + 1) /
                (self.S + 1)).to_event(2),
        )
        trans = expand_offtarget(trans)
        # One Exponential rate per Q row; consumed by probs_m for spot
        # presence probabilities.
        lamda = pyro.sample(
            "lamda",
            dist.Exponential(torch.full(
                (self.Q, ), self.priors["lamda_rate"])).to_event(1),
        )
        proximity = pyro.sample(
            "proximity", dist.Exponential(self.priors["proximity_rate"]))
        # Second AffineBeta argument for x/y, picked with ``specific`` below:
        # index 0 -> constant 2.0, index 1 (spot is the theta-selected one)
        # -> value derived from proximity and the image size P.
        size = torch.stack(
            (
                torch.full_like(proximity, 2.0),
                (((self.data.P + 1) / (2 * proximity))**2 - 1),
            ),
            dim=-1,
        )

        # spots
        spots = pyro.plate("spots", self.K)
        # aoi sites (subsampled via self.n / self.nbatch_size) on dim -3
        aois = pyro.plate(
            "aois",
            self.data.Nt,
            subsample=self.n,
            subsample_size=self.nbatch_size,
            dim=-3,
        )
        # time frames: Markov chain over frames on dim -2, vectorized or
        # sequential depending on self.vectorized
        frames = (pyro.vectorized_markov(
            name="frames", size=self.data.F, dim=-2)
                  if self.vectorized else pyro.markov(range(self.data.F)))
        # color channels on dim -1
        channels = pyro.plate(
            "channels",
            self.data.C,
            dim=-1,
        )

        with channels as cdx, aois as ndx:
            # Reshape AOI indices to (n, 1, 1) so they sit on dim -3 and
            # broadcast against the frame (-2) and channel (-1) dims.
            ndx = ndx[:, None, None]
            mask = Vindex(self.data.mask)[ndx].to(self.device)
            # Per-AOI mask from self.data.mask applied to every site below.
            with handlers.mask(mask=mask):
                # background mean and std
                background_mean = pyro.sample(
                    "background_mean",
                    dist.HalfNormal(self.priors["background_mean_std"]),
                )
                background_std = pyro.sample(
                    "background_std",
                    dist.HalfNormal(self.priors["background_std_std"]))
                # Hidden state of the previous frame; None on the first frame.
                z_prev = None
                for fdx in frames:
                    if self.vectorized:
                        # The vectorized iterator yields a pair: a label used
                        # in site names, and the frame indices, which become a
                        # tensor with a trailing singleton dim so they align
                        # with the channel dim (-1).
                        fsx, fdx = fdx
                        fdx = torch.as_tensor(fdx)
                        fdx = fdx.unsqueeze(-1)
                    else:
                        fsx = fdx
                    # fetch data
                    obs, target_locs, is_ontarget = self.data.fetch(
                        ndx, fdx, cdx)
                    # sample background intensity
                    # (Gamma parameterized so its mean is background_mean and
                    # its std is background_std)
                    background = pyro.sample(
                        f"background_f{fsx}",
                        dist.Gamma(
                            (background_mean / background_std)**2,
                            background_mean / background_std**2,
                        ),
                    )

                    # sample hidden model state (1+S,)
                    # First frame: row of init for this channel/on-target
                    # flag. Later frames: row of trans selected by z_prev.
                    z_probs = (Vindex(init)[..., cdx, :,
                                            is_ontarget.long()]
                               if z_prev is None else
                               Vindex(trans)[..., cdx, z_prev, :,
                                             is_ontarget.long()])
                    z_curr = pyro.sample(f"z_f{fsx}",
                                         dist.Categorical(z_probs))

                    # z is clamped to {0, 1}, so all states >= 1 share the
                    # same probs_theta row.
                    theta = pyro.sample(
                        f"theta_f{fsx}",
                        dist.Categorical(
                            Vindex(probs_theta(
                                self.K, self.device))[torch.clamp(z_curr,
                                                                  min=0,
                                                                  max=1)]),
                        infer={"enumerate": "parallel"},
                    )
                    # onehot_theta[..., 1 + k] is 1 where spot k is the one
                    # selected by theta (index 0 means no spot selected).
                    onehot_theta = one_hot(theta, num_classes=1 + self.K)

                    ms, heights, widths, xs, ys = [], [], [], [], []
                    for kdx in spots:
                        specific = onehot_theta[..., 1 + kdx]
                        # spot presence: Categorical over {0, 1} where
                        # category 1 has probability m_probs
                        m_probs = Vindex(probs_m(lamda, self.K))[..., cdx,
                                                                 theta, kdx]
                        m = pyro.sample(
                            f"m_k{kdx}_f{fsx}",
                            dist.Categorical(
                                torch.stack((1 - m_probs, m_probs), -1)),
                        )
                        # Spot variables only contribute log-probability
                        # where m > 0.
                        with handlers.mask(mask=m > 0):
                            # sample spot variables
                            height = pyro.sample(
                                f"height_k{kdx}_f{fsx}",
                                dist.HalfNormal(self.priors["height_std"]),
                            )
                            width = pyro.sample(
                                f"width_k{kdx}_f{fsx}",
                                AffineBeta(
                                    1.5,
                                    2,
                                    self.priors["width_min"],
                                    self.priors["width_max"],
                                ),
                            )
                            # x/y are bounded to [-(P+1)/2, (P+1)/2]; the
                            # second argument is chosen by whether this spot
                            # is the theta-selected one.
                            x = pyro.sample(
                                f"x_k{kdx}_f{fsx}",
                                AffineBeta(
                                    0,
                                    Vindex(size)[..., specific],
                                    -(self.data.P + 1) / 2,
                                    (self.data.P + 1) / 2,
                                ),
                            )
                            y = pyro.sample(
                                f"y_k{kdx}_f{fsx}",
                                AffineBeta(
                                    0,
                                    Vindex(size)[..., specific],
                                    -(self.data.P + 1) / 2,
                                    (self.data.P + 1) / 2,
                                ),
                            )

                        # append
                        ms.append(m)
                        heights.append(height)
                        widths.append(width)
                        xs.append(x)
                        ys.append(y)

                    # observed data
                    # Spot variables are stacked on a new last (spot) axis;
                    # m samples are broadcast to a common shape first because
                    # torch.stack requires equal shapes.
                    pyro.sample(
                        f"data_f{fsx}",
                        KSMOGN(
                            torch.stack(heights, -1),
                            torch.stack(widths, -1),
                            torch.stack(xs, -1),
                            torch.stack(ys, -1),
                            target_locs,
                            background,
                            gain,
                            self.data.offset.samples,
                            self.data.offset.logits.to(self.dtype),
                            self.data.P,
                            torch.stack(torch.broadcast_tensors(*ms), -1),
                            use_pykeops=self.use_pykeops,
                        ),
                        obs=obs,
                    )
                    # Carry the current state into the next frame.
                    z_prev = z_curr
コード例 #2
0
    def model(self):
        r"""
        Generative Model

        Multi-channel variant. For every ``qdx`` in ``range(self.Q)`` the
        model draws its own hidden state ``z_q{qdx}`` and target-specific
        spot index ``theta_q{qdx}`` (both enumerated in parallel), plus
        ``self.K`` spots (``m``, ``height``, ``width``, ``x``, ``y``). The
        per-spot samples are regrouped into ``(..., Q, K)`` tensors and
        passed, together with ``alpha``, to ``KSMOGN``, which scores the
        images fetched for all ``self.data.C`` channels.
        """

        def stack_by_q(sites):
            # ``sites`` is filled q-major (entry qdx * K + kdx), so each q owns
            # the contiguous slice [q * K, (q + 1) * K). Stack those K entries
            # on a new last axis, then the Q groups on the axis before it,
            # giving a (..., Q, K) tensor.
            return torch.stack(
                [
                    torch.stack(sites[q * self.K:(q + 1) * self.K], -1)
                    for q in range(self.Q)
                ],
                -2,
            )

        # global parameters
        gain = pyro.sample("gain", dist.HalfNormal(self.priors["gain_std"]))
        # Dirichlet prior favouring the diagonal (concentration 10 vs 1).
        # NOTE(review): torch.eye(self.Q) is (Q, Q) while the ones tensor is
        # (Q, C); the sum only broadcasts when C == Q (or one of them is 1).
        alpha = pyro.sample(
            "alpha",
            dist.Dirichlet(
                torch.ones((self.Q, self.data.C)) +
                torch.eye(self.Q) * 9).to_event(1),
        )
        # Hidden-state probabilities: a (Q, 1+S) Dirichlet sample.
        pi = pyro.sample(
            "pi",
            dist.Dirichlet(torch.ones(
                (self.Q, self.S + 1)) / (self.S + 1)).to_event(1),
        )
        # NOTE(review): expand_offtarget is defined elsewhere; the indexing
        # below selects its trailing axis with is_ontarget — confirm.
        pi = expand_offtarget(pi)
        # One Exponential rate per q; consumed by probs_m for spot presence.
        lamda = pyro.sample(
            "lamda",
            dist.Exponential(torch.full(
                (self.Q, ), self.priors["lamda_rate"])).to_event(1),
        )
        proximity = pyro.sample(
            "proximity", dist.Exponential(self.priors["proximity_rate"]))
        # Second AffineBeta argument for x/y, picked with ``specific`` below:
        # index 0 -> constant 2.0, index 1 -> derived from proximity and P.
        size = torch.stack(
            (
                torch.full_like(proximity, 2.0),
                (((self.data.P + 1) / (2 * proximity))**2 - 1),
            ),
            dim=-1,
        )

        # aoi sites
        aois = pyro.plate(
            "aois",
            self.data.Nt,
            subsample=self.n,
            subsample_size=self.nbatch_size,
            dim=-2,
        )
        # time frames
        frames = pyro.plate(
            "frames",
            self.data.F,
            subsample=self.f,
            subsample_size=self.fbatch_size,
            dim=-1,
        )

        with aois as ndx:
            # (n,) -> (n, 1) so AOI indices broadcast against frames (-1).
            ndx = ndx[:, None]
            mask = Vindex(self.data.mask)[ndx].to(self.device)
            with handlers.mask(mask=mask):
                # background mean and std, one per channel (event dim C)
                background_mean = pyro.sample(
                    "background_mean",
                    dist.HalfNormal(self.priors["background_mean_std"]).expand(
                        (self.data.C, )).to_event(1),
                )
                background_std = pyro.sample(
                    "background_std",
                    dist.HalfNormal(self.priors["background_std_std"]).expand(
                        (self.data.C, )).to_event(1),
                )
                with frames as fdx:
                    # fetch data for all channels at once
                    obs, target_locs, is_ontarget = self.data.fetch(
                        ndx.unsqueeze(-1), fdx.unsqueeze(-1),
                        torch.arange(self.data.C))
                    # sample background intensity
                    # (Gamma with mean background_mean, std background_std)
                    background = pyro.sample(
                        "background",
                        dist.Gamma(
                            (background_mean / background_std)**2,
                            background_mean / background_std**2,
                        ).to_event(1),
                    )

                    # Flat per-spot site lists, filled q-major:
                    # entry index = qdx * self.K + kdx.
                    ms, heights, widths, xs, ys = [], [], [], [], []
                    is_ontarget = is_ontarget.squeeze(-1)
                    for qdx in range(self.Q):
                        # sample hidden model state (1+S,)
                        z_probs = Vindex(pi)[..., qdx, :, is_ontarget.long()]
                        z = pyro.sample(
                            f"z_q{qdx}",
                            dist.Categorical(z_probs),
                            infer={"enumerate": "parallel"},
                        )
                        # z is clamped to {0, 1}, so all states >= 1 share
                        # the same probs_theta row.
                        theta = pyro.sample(
                            f"theta_q{qdx}",
                            dist.Categorical(
                                Vindex(probs_theta(
                                    self.K, self.device))[torch.clamp(z,
                                                                      min=0,
                                                                      max=1)]),
                            infer={"enumerate": "parallel"},
                        )
                        onehot_theta = one_hot(theta, num_classes=1 + self.K)

                        for kdx in range(self.K):
                            specific = onehot_theta[..., 1 + kdx]
                            # spot presence
                            m = pyro.sample(
                                f"m_k{kdx}_q{qdx}",
                                dist.Bernoulli(
                                    Vindex(probs_m(lamda,
                                                   self.K))[..., qdx, theta,
                                                            kdx]),
                            )
                            # Spot variables only contribute log-probability
                            # where the spot is present (m > 0).
                            with handlers.mask(mask=m > 0):
                                # sample spot variables
                                height = pyro.sample(
                                    f"height_k{kdx}_q{qdx}",
                                    dist.HalfNormal(self.priors["height_std"]),
                                )
                                width = pyro.sample(
                                    f"width_k{kdx}_q{qdx}",
                                    AffineBeta(
                                        1.5,
                                        2,
                                        self.priors["width_min"],
                                        self.priors["width_max"],
                                    ),
                                )
                                x = pyro.sample(
                                    f"x_k{kdx}_q{qdx}",
                                    AffineBeta(
                                        0,
                                        Vindex(size)[..., specific],
                                        -(self.data.P + 1) / 2,
                                        (self.data.P + 1) / 2,
                                    ),
                                )
                                y = pyro.sample(
                                    f"y_k{kdx}_q{qdx}",
                                    AffineBeta(
                                        0,
                                        Vindex(size)[..., specific],
                                        -(self.data.P + 1) / 2,
                                        (self.data.P + 1) / 2,
                                    ),
                                )

                            # append
                            ms.append(m)
                            heights.append(height)
                            widths.append(width)
                            xs.append(x)
                            ys.append(y)

                    # Regroup every flat list into (..., Q, K) with the same
                    # stride K (previously ys/ms used q * Q as the slice start,
                    # which misaligns the groups whenever Q != K).
                    heights = stack_by_q(heights)
                    widths = stack_by_q(widths)
                    xs = stack_by_q(xs)
                    ys = stack_by_q(ys)
                    # m samples are broadcast to a common shape first because
                    # torch.stack requires equal shapes.
                    ms = stack_by_q(torch.broadcast_tensors(*ms))
                    # observed data
                    pyro.sample(
                        "data",
                        KSMOGN(
                            heights,
                            widths,
                            xs,
                            ys,
                            target_locs,
                            background,
                            gain,
                            self.data.offset.samples,
                            self.data.offset.logits.to(self.dtype),
                            self.data.P,
                            ms,
                            alpha,
                            use_pykeops=self.use_pykeops,
                        ),
                        obs=obs,
                    )
コード例 #3
0
    def model(self):
        r"""
        **Generative Model**

        Single-channel model: data are fetched for channel ``self.cdx``. The
        hidden state ``z`` and the target-specific spot index ``theta`` are
        sampled with ``infer={"enumerate": "parallel"}``, corresponding to
        the marginalized joint distribution given below.

        Model parameters:

        +-----------------+-----------+-------------------------------------+
        | Parameter       | Shape     | Description                         |
        +=================+===========+=====================================+
        | |g| - :math:`g` | (1,)      | camera gain                         |
        +-----------------+-----------+-------------------------------------+
        | |sigma| - |prox|| (1,)      | proximity                           |
        +-----------------+-----------+-------------------------------------+
        | ``lamda`` - |ld|| (1,)      | average rate of target-nonspecific  |
        |                 |           | binding                             |
        +-----------------+-----------+-------------------------------------+
        | ``pi`` - |pi|   | (S+1,)    | average binding probability of      |
        |                 |           | target-specific binding             |
        +-----------------+-----------+-------------------------------------+
        | |bg| - |b|      | (N, F)    | background intensity                |
        +-----------------+-----------+-------------------------------------+
        | |z| - :math:`z` | (N, F)    | target-specific spot presence       |
        +-----------------+-----------+-------------------------------------+
        | |t| - |theta|   | (N, F)    | target-specific spot index          |
        +-----------------+-----------+-------------------------------------+
        | |m| - :math:`m` | (K, N, F) | spot presence indicator             |
        +-----------------+-----------+-------------------------------------+
        | |h| - :math:`h` | (K, N, F) | spot intensity                      |
        +-----------------+-----------+-------------------------------------+
        | |w| - :math:`w` | (K, N, F) | spot width                          |
        +-----------------+-----------+-------------------------------------+
        | |x| - :math:`x` | (K, N, F) | spot position on x-axis             |
        +-----------------+-----------+-------------------------------------+
        | |y| - :math:`y` | (K, N, F) | spot position on y-axis             |
        +-----------------+-----------+-------------------------------------+
        | |D| - :math:`D` | |shape|   | observed images                     |
        +-----------------+-----------+-------------------------------------+

        .. |ps| replace:: :math:`p(\mathsf{specific})`
        .. |theta| replace:: :math:`\theta`
        .. |prox| replace:: :math:`\sigma^{xy}`
        .. |ld| replace:: :math:`\lambda`
        .. |b| replace:: :math:`b`
        .. |shape| replace:: (N, F, P, P)
        .. |sigma| replace:: ``proximity``
        .. |bg| replace:: ``background``
        .. |h| replace:: ``height``
        .. |w| replace:: ``width``
        .. |D| replace:: ``data``
        .. |m| replace:: ``m``
        .. |z| replace:: ``z``
        .. |t| replace:: ``theta``
        .. |x| replace:: ``x``
        .. |y| replace:: ``y``
        .. |pi| replace:: :math:`\pi`
        .. |g| replace:: ``gain``

        Full joint distribution:

        .. math::

            \begin{aligned}
                p(D, \phi) =~&p(g) p(\sigma^{xy}) p(\pi) p(\lambda)
                \prod_{\mathsf{AOI}} \left[ p(\mu^b) p(\sigma^b) \prod_{\mathsf{frame}}
                \left[ \vphantom{\prod_{F}} p(b | \mu^b, \sigma^b) p(z | \pi) p(\theta | z)
                \vphantom{\prod_{\substack{\mathsf{pixelX} \\ \mathsf{pixelY}}}} \cdot \right. \right. \\
                &\prod_{\mathsf{spot}} \left[ \vphantom{\prod_{F}} p(m | \theta, \lambda)
                p(h) p(w) p(x | \sigma^{xy}, \theta) p(y | \sigma^{xy}, \theta) \right] \left. \left.
                \prod_{\substack{\mathsf{pixelX} \\ \mathsf{pixelY}}} \sum_{\delta} p(\delta)
                p(D | \mu^I, g, \delta) \right] \right]
            \end{aligned}

        :math:`z` and :math:`\theta` marginalized joint distribution:

        .. math::

            \begin{aligned}
                \sum_{z, \theta} p(D, \phi) =~&p(g) p(\sigma^{xy}) p(\pi) p(\lambda)
                \prod_{\mathsf{AOI}} \left[ p(\mu^b) p(\sigma^b) \prod_{\mathsf{frame}}
                \left[ \vphantom{\prod_{F}} p(b | \mu^b, \sigma^b) \sum_{z} p(z | \pi) \sum_{\theta} p(\theta | z)
                \vphantom{\prod_{\substack{\mathsf{pixelX} \\ \mathsf{pixelY}}}} \cdot \right. \right. \\
                &\prod_{\mathsf{spot}} \left[ \vphantom{\prod_{F}} p(m | \theta, \lambda)
                p(h) p(w) p(x | \sigma^{xy}, \theta) p(y | \sigma^{xy}, \theta) \right] \left. \left.
                \prod_{\substack{\mathsf{pixelX} \\ \mathsf{pixelY}}} \sum_{\delta} p(\delta)
                p(D | \mu^I, g, \delta) \right] \right]
            \end{aligned}
        """
        # global parameters
        gain = pyro.sample("gain", dist.HalfNormal(self.gain_std))
        # Hidden-state probabilities: a Dirichlet sample of shape (S+1,).
        pi = pyro.sample("pi",
                         dist.Dirichlet(torch.ones(self.S + 1) / (self.S + 1)))
        # NOTE(review): expand_offtarget is defined elsewhere; the indexing
        # below selects its trailing axis with is_ontarget — confirm.
        pi = expand_offtarget(pi)
        lamda = pyro.sample("lamda", dist.Exponential(self.lamda_rate))
        proximity = pyro.sample("proximity",
                                dist.Exponential(self.proximity_rate))
        # Second AffineBeta argument for x/y, picked with ``specific`` below:
        # index 0 -> constant 2.0, index 1 (spot is the theta-selected one)
        # -> value derived from proximity and the image size P.
        size = torch.stack(
            (
                torch.full_like(proximity, 2.0),
                (((self.data.P + 1) / (2 * proximity))**2 - 1),
            ),
            dim=-1,
        )

        # spots
        spots = pyro.plate("spots", self.K)
        # aoi sites (subsampled via self.n / self.nbatch_size) on dim -2
        aois = pyro.plate(
            "aois",
            self.data.Nt,
            subsample=self.n,
            subsample_size=self.nbatch_size,
            dim=-2,
        )
        # time frames (subsampled via self.f / self.fbatch_size) on dim -1
        frames = pyro.plate(
            "frames",
            self.data.F,
            subsample=self.f,
            subsample_size=self.fbatch_size,
            dim=-1,
        )

        with aois as ndx:
            # (n,) -> (n, 1) so AOI indices broadcast against frames (-1).
            ndx = ndx[:, None]
            # background mean and std
            background_mean = pyro.sample(
                "background_mean", dist.HalfNormal(self.background_mean_std))
            background_std = pyro.sample(
                "background_std", dist.HalfNormal(self.background_std_std))
            with frames as fdx:
                # fetch data
                obs, target_locs, is_ontarget = self.data.fetch(
                    ndx, fdx, self.cdx)
                # sample background intensity
                # (Gamma with mean background_mean, std background_std)
                background = pyro.sample(
                    "background",
                    dist.Gamma(
                        (background_mean / background_std)**2,
                        background_mean / background_std**2,
                    ),
                )

                # sample hidden model state (1+S,)
                z = pyro.sample(
                    "z",
                    dist.Categorical(Vindex(pi)[..., :,
                                                is_ontarget.long()]),
                    infer={"enumerate": "parallel"},
                )
                # z is clamped to {0, 1}, so all states >= 1 share the same
                # probs_theta row.
                theta = pyro.sample(
                    "theta",
                    dist.Categorical(
                        Vindex(probs_theta(self.K,
                                           self.device))[torch.clamp(z,
                                                                     min=0,
                                                                     max=1)]),
                    infer={"enumerate": "parallel"},
                )
                # onehot_theta[..., 1 + k] is 1 where spot k is the one
                # selected by theta (index 0 means no spot selected).
                onehot_theta = one_hot(theta, num_classes=1 + self.K)

                ms, heights, widths, xs, ys = [], [], [], [], []
                for kdx in spots:
                    specific = onehot_theta[..., 1 + kdx]
                    # spot presence
                    m = pyro.sample(
                        f"m_{kdx}",
                        dist.Bernoulli(
                            Vindex(probs_m(lamda, self.K))[..., theta, kdx]),
                    )
                    # Spot variables only contribute log-probability where
                    # the spot is present (m > 0).
                    with handlers.mask(mask=m > 0):
                        # sample spot variables
                        height = pyro.sample(
                            f"height_{kdx}",
                            dist.HalfNormal(self.height_std),
                        )
                        width = pyro.sample(
                            f"width_{kdx}",
                            AffineBeta(
                                1.5,
                                2,
                                self.width_min,
                                self.width_max,
                            ),
                        )
                        # x/y are bounded to [-(P+1)/2, (P+1)/2].
                        x = pyro.sample(
                            f"x_{kdx}",
                            AffineBeta(
                                0,
                                Vindex(size)[..., specific],
                                -(self.data.P + 1) / 2,
                                (self.data.P + 1) / 2,
                            ),
                        )
                        y = pyro.sample(
                            f"y_{kdx}",
                            AffineBeta(
                                0,
                                Vindex(size)[..., specific],
                                -(self.data.P + 1) / 2,
                                (self.data.P + 1) / 2,
                            ),
                        )

                    # append
                    ms.append(m)
                    heights.append(height)
                    widths.append(width)
                    xs.append(x)
                    ys.append(y)

                # observed data
                # Spot variables are stacked on a new last (spot) axis; m
                # samples are broadcast to a common shape first because
                # torch.stack requires equal shapes.
                pyro.sample(
                    "data",
                    KSMOGN(
                        torch.stack(heights, -1),
                        torch.stack(widths, -1),
                        torch.stack(xs, -1),
                        torch.stack(ys, -1),
                        target_locs,
                        background,
                        gain,
                        self.data.offset.samples,
                        self.data.offset.logits.to(self.dtype),
                        self.data.P,
                        torch.stack(torch.broadcast_tensors(*ms), -1),
                        self.use_pykeops,
                    ),
                    obs=obs,
                )