def model_8(weeks_data, days_data, history, vectorized):
    """Two independent categorical HMMs: a "weeks" chain and a "days" chain.

    The weeks chain has a 3-state hidden variable ``x_t`` emitting a
    2-category observation ``y_t``. The days chain has a 2-state hidden
    variable ``w_t`` emitting a 3-category observation ``z_t``. Every
    probability table is a ``pyro.param`` constrained to the simplex and
    lazily initialised with ``torch.rand``.

    Args:
        weeks_data: observations for the weeks chain, indexed by time step.
        days_data: observations for the days chain, indexed by time step.
        history: Markov history forwarded to ``pyro.markov`` /
            ``pyro.vectorized_markov``.
        vectorized: if truthy, iterate time with ``pyro.vectorized_markov``
            on ``dim=-1``; otherwise use a sequential ``pyro.markov`` loop.
    """
    x_dim, y_dim, w_dim, z_dim = 3, 2, 2, 3
    simplex = constraints.simplex

    # Parameters are registered in a fixed order: weeks chain, then days chain.
    x_init = pyro.param("x_init", lambda: torch.rand(x_dim), constraint=simplex)
    x_trans = pyro.param("x_trans", lambda: torch.rand((x_dim, x_dim)), constraint=simplex)
    y_probs = pyro.param("y_probs", lambda: torch.rand(x_dim, y_dim), constraint=simplex)
    w_init = pyro.param("w_init", lambda: torch.rand(w_dim), constraint=simplex)
    w_trans = pyro.param("w_trans", lambda: torch.rand((w_dim, w_dim)), constraint=simplex)
    z_probs = pyro.param("z_probs", lambda: torch.rand(w_dim, z_dim), constraint=simplex)

    def time_steps(loop_name, length):
        # Both chains use the same iteration strategy and the same time dim.
        if vectorized:
            return pyro.vectorized_markov(name=loop_name, size=length, dim=-1, history=history)
        return pyro.markov(range(length), history=history)

    def run_chain(loop_name, observations, init, trans, emission, hidden, observed):
        # One HMM: draw the hidden state, then condition the emission on it.
        prev_state = None
        for step in time_steps(loop_name, len(observations)):
            if isinstance(step, int) and step == 0:
                state_probs = init
            else:
                state_probs = Vindex(trans)[prev_state]
            state = pyro.sample(f"{hidden}_{step}", dist.Categorical(state_probs))
            pyro.sample(
                f"{observed}_{step}",
                dist.Categorical(Vindex(emission)[state]),
                obs=observations[step],
            )
            prev_state = state

    run_chain("weeks", weeks_data, x_init, x_trans, y_probs, "x", "y")
    run_chain("days", days_data, w_init, w_trans, z_probs, "w", "z")
# Example #2
# 0
def model_2(data, history, vectorized):
    """HMM whose observations also form a Markov chain.

    The hidden state ``x_t`` (3 states) follows a first-order chain. Each
    observation ``y_t`` (2 categories, one per entry of the "tones" plate)
    depends on the current hidden state and, after the first step, on the
    previous observation ``y_{t-1}``.

    Args:
        data: observations indexed as ``data[t]``; its last dimension sizes
            the "tones" plate.
        history: Markov history forwarded to ``pyro.markov`` /
            ``pyro.vectorized_markov``.
        vectorized: if truthy, vectorize time over ``dim=-2``; otherwise
            loop sequentially with ``pyro.markov``.
    """
    x_dim, y_dim = 3, 2
    simplex = constraints.simplex
    x_init = pyro.param("x_init", lambda: torch.rand(x_dim), constraint=simplex)
    x_trans = pyro.param("x_trans", lambda: torch.rand((x_dim, x_dim)), constraint=simplex)
    y_init = pyro.param("y_init", lambda: torch.rand(x_dim, y_dim), constraint=simplex)
    y_trans = pyro.param("y_trans", lambda: torch.rand((x_dim, y_dim, y_dim)), constraint=simplex)

    prev_x, prev_y = None, None
    if vectorized:
        steps = pyro.vectorized_markov(name="time", size=len(data), dim=-2, history=history)
    else:
        steps = pyro.markov(range(len(data)), history=history)

    for t in steps:
        # Only a plain-int step 0 uses the initial tables.
        first_step = isinstance(t, int) and t < 1
        x_dist = dist.Categorical(x_init if first_step else x_trans[prev_x])
        cur_x = pyro.sample(f"x_{t}", x_dist)
        with pyro.plate("tones", data.shape[-1], dim=-1):
            if first_step:
                y_dist = dist.Categorical(y_init[cur_x])
            else:
                y_dist = dist.Categorical(Vindex(y_trans)[cur_x, prev_y])
            cur_y = pyro.sample(f"y_{t}", y_dist, obs=data[t])
        prev_x, prev_y = cur_x, cur_y
# Example #3
# 0
def model_0(data, history, vectorized):
    """Gaussian-emission HMM batched over independent sequences.

    Hidden state ``x_t`` (3 states) follows a first-order chain per sequence.
    Each tone observation is ``Normal(locs[x_t], 1)``.

    Args:
        data: observations indexed as ``data[sequence, time, tone]``.
            ``data.shape[0]`` sizes the "sequences" plate (dim -3),
            ``data.shape[1]`` the time loop (dim -2 when vectorized) and
            ``data.shape[2]`` the "tones" plate (dim -1).
        history: Markov history forwarded to ``pyro.markov`` /
            ``pyro.vectorized_markov``.
        vectorized: if truthy, vectorize time with ``pyro.vectorized_markov``.
    """
    x_dim = 3
    simplex = constraints.simplex
    init = pyro.param("init", lambda: torch.rand(x_dim), constraint=simplex)
    trans = pyro.param("trans", lambda: torch.rand((x_dim, x_dim)), constraint=simplex)
    locs = pyro.param("locs", lambda: torch.rand(x_dim))

    with pyro.plate("sequences", data.shape[0], dim=-3) as seq_idx:
        # Trailing singleton, presumably so the sequence index broadcasts
        # against the time index when indexing ``data``.
        seq_idx = seq_idx[:, None]
        prev_state = None
        if vectorized:
            steps = pyro.vectorized_markov(name="time", size=data.shape[1], dim=-2, history=history)
        else:
            steps = pyro.markov(range(data.shape[1]), history=history)
        for t in steps:
            state_probs = init if isinstance(t, int) and t < 1 else trans[prev_state]
            state = pyro.sample(f"x_{t}", dist.Categorical(state_probs))
            with pyro.plate("tones", data.shape[2], dim=-1):
                pyro.sample(
                    f"y_{t}",
                    dist.Normal(Vindex(locs)[..., state], 1.),
                    obs=Vindex(data)[seq_idx, t],
                )
            prev_state = state
# Example #4
# 0
def model_6(data, history, vectorized):
    """Gaussian-emission HMM with a distinct transition matrix per time step.

    ``x_trans`` is initialised with shape ``(len(data) - 1, 3, 3)``. Step
    ``t >= 1`` transitions with ``x_trans[t - 1]``, and step 0 draws from
    ``x_init``.

    Args:
        data: observations indexed as ``data[t]``; its last dimension sizes
            the "tones" plate.
        history: Markov history forwarded to ``pyro.markov`` /
            ``pyro.vectorized_markov``.
        vectorized: if truthy, vectorize time over ``dim=-2``.
    """
    x_dim = 3
    simplex = constraints.simplex
    x_init = pyro.param("x_init", lambda: torch.rand(x_dim), constraint=simplex)
    x_trans = pyro.param(
        "x_trans",
        lambda: torch.rand((len(data) - 1, x_dim, x_dim)),
        constraint=simplex,
    )
    locs = pyro.param("locs", lambda: torch.rand(x_dim))

    prev_state = None
    if vectorized:
        steps = pyro.vectorized_markov(name="time", size=len(data), dim=-2, history=history)
    else:
        steps = pyro.markov(range(len(data)), history=history)

    for t in steps:
        if not isinstance(t, int):
            # Tensor of time indices: select each step's own matrix; the
            # [:, None] presumably places the step index on the time dim.
            state_probs = Vindex(x_trans)[(t - 1)[:, None], prev_state]
        elif t < 1:
            state_probs = x_init
        else:
            state_probs = x_trans[t - 1, prev_state]

        state = pyro.sample(f"x_{t}", dist.Categorical(state_probs))
        with pyro.plate("tones", data.shape[-1], dim=-1):
            pyro.sample(
                f"y_{t}",
                dist.Normal(Vindex(locs)[..., state], 1.),
                obs=data[t],
            )
        prev_state = state
def model_1(data, history, vectorized):
    """First-order HMM with unit-variance Gaussian emissions.

    Hidden state ``x_t`` (3 states) starts from ``init`` and transitions with
    ``trans``. Each tone observation is ``Normal(locs[x_t], 1)``.

    Args:
        data: observations indexed as ``data[t]``; its last dimension sizes
            the "tones" plate.
        history: Markov history forwarded to ``pyro.markov`` /
            ``pyro.vectorized_markov``.
        vectorized: if truthy, vectorize time over ``dim=-2``.
    """
    x_dim = 3
    simplex = constraints.simplex
    init = pyro.param("init", lambda: torch.rand(x_dim), constraint=simplex)
    trans = pyro.param("trans", lambda: torch.rand((x_dim, x_dim)), constraint=simplex)
    locs = pyro.param("locs", lambda: torch.rand(x_dim))

    prev_state = None
    if vectorized:
        steps = pyro.vectorized_markov(name="time", size=len(data), dim=-2, history=history)
    else:
        steps = pyro.markov(range(len(data)), history=history)

    for t in steps:
        if isinstance(t, int) and t < 1:
            state_probs = init
        else:
            state_probs = trans[prev_state]
        state = pyro.sample(f"x_{t}", dist.Categorical(state_probs))
        with pyro.plate("tones", data.shape[-1], dim=-1):
            emission = dist.Normal(Vindex(locs)[..., state], 1.0)
            pyro.sample(f"y_{t}", emission, obs=data[t])
        prev_state = state
def model_5(data, history, vectorized):
    """Second-order HMM with categorical emissions.

    Sequentially: ``x_0 ~ x_init``, ``x_1 ~ x_init_2[x_0]`` and
    ``x_t ~ x_trans[x_{t-2}, x_{t-1}]`` for later steps. Each tone
    observation is ``Categorical(y_probs[x_t])``.

    Args:
        data: observations indexed as ``data[t]``; its last dimension sizes
            the "tones" plate.
        history: Markov history forwarded to ``pyro.markov`` /
            ``pyro.vectorized_markov``.
        vectorized: if truthy, vectorize time over ``dim=-2``.
    """
    x_dim, y_dim = 3, 2
    simplex = constraints.simplex
    x_init = pyro.param("x_init", lambda: torch.rand(x_dim), constraint=simplex)
    x_init_2 = pyro.param("x_init_2", lambda: torch.rand(x_dim, x_dim), constraint=simplex)
    x_trans = pyro.param("x_trans", lambda: torch.rand((x_dim, x_dim, x_dim)), constraint=simplex)
    y_probs = pyro.param("y_probs", lambda: torch.rand(x_dim, y_dim), constraint=simplex)

    # Two-step memory: state at t - 2 and state at t - 1.
    back_two = back_one = None
    if vectorized:
        steps = pyro.vectorized_markov(name="time", size=len(data), dim=-2, history=history)
    else:
        steps = pyro.markov(range(len(data)), history=history)

    for t in steps:
        plain_step = isinstance(t, int)
        if plain_step and t == 0:
            state_probs = x_init
        elif plain_step and t == 1:
            state_probs = Vindex(x_init_2)[back_one]
        else:
            state_probs = Vindex(x_trans)[back_two, back_one]

        state = pyro.sample(f"x_{t}", dist.Categorical(state_probs))
        with pyro.plate("tones", data.shape[-1], dim=-1):
            pyro.sample(
                f"y_{t}",
                dist.Categorical(Vindex(y_probs)[state]),
                obs=data[t],
            )
        back_two, back_one = back_one, state
# Example #7
# 0
def model_10(data, history, vectorized):
    """Two-state HMM with a fixed uniform initial distribution.

    The initial distribution ``[0.5, 0.5]`` is a plain tensor, not a
    parameter. The transition and emission tables are ``pyro.param`` sites
    initialised to ``[[0.75, 0.25], [0.25, 0.75]]`` and constrained to the
    simplex.

    Args:
        data: observations indexed as ``data[t]``.
        history: Markov history forwarded to ``pyro.markov`` /
            ``pyro.vectorized_markov``.
        vectorized: if truthy, iterate with ``pyro.vectorized_markov``
            (default ``dim``); otherwise loop sequentially.
    """
    init_probs = torch.tensor([0.5, 0.5])
    transition_probs = pyro.param(
        "transition_probs",
        torch.tensor([[0.75, 0.25], [0.25, 0.75]]),
        constraint=constraints.simplex,
    )
    emission_probs = pyro.param(
        "emission_probs",
        torch.tensor([[0.75, 0.25], [0.25, 0.75]]),
        constraint=constraints.simplex,
    )

    state = None
    if vectorized:
        steps = pyro.vectorized_markov(name="time", size=len(data), history=history)
    else:
        steps = pyro.markov(range(len(data)), history=history)

    for t in steps:
        # No previous state yet means this is the first step.
        if state is None:
            probs = init_probs
        else:
            probs = transition_probs[state]
        state = pyro.sample(f"x_{t}", dist.Categorical(probs))
        pyro.sample(f"y_{t}", dist.Categorical(emission_probs[state]), obs=data[t])
def model_4(data, history, vectorized):
    """HMM with two coupled hidden chains ``w`` and ``x``.

    ``w_t`` (2 states) follows its own first-order chain. ``x_t`` (3 states)
    transitions with a matrix selected by the current ``w_t``. Each tone
    observation is ``Categorical(y_probs[w_t, x_t])``.

    Args:
        data: observations indexed as ``data[t]``; its last dimension sizes
            the "tones" plate.
        history: Markov history forwarded to ``pyro.markov`` /
            ``pyro.vectorized_markov``.
        vectorized: if truthy, vectorize time over ``dim=-2``.
    """
    w_dim, x_dim, y_dim = 2, 3, 2
    simplex = constraints.simplex
    w_init = pyro.param("w_init", lambda: torch.rand(w_dim), constraint=simplex)
    w_trans = pyro.param("w_trans", lambda: torch.rand((w_dim, w_dim)), constraint=simplex)
    x_init = pyro.param("x_init", lambda: torch.rand(w_dim, x_dim), constraint=simplex)
    x_trans = pyro.param("x_trans", lambda: torch.rand((w_dim, x_dim, x_dim)), constraint=simplex)
    y_probs = pyro.param("y_probs", lambda: torch.rand(w_dim, x_dim, y_dim), constraint=simplex)

    prev_w = prev_x = None
    if vectorized:
        steps = pyro.vectorized_markov(name="time", size=len(data), dim=-2, history=history)
    else:
        steps = pyro.markov(range(len(data)), history=history)

    for t in steps:
        first_step = isinstance(t, int) and t < 1
        w_probs = w_init if first_step else w_trans[prev_w]
        cur_w = pyro.sample(f"w_{t}", dist.Categorical(w_probs))
        # x's distribution is selected by the freshly drawn w.
        x_probs = x_init[cur_w] if first_step else x_trans[cur_w, prev_x]
        cur_x = pyro.sample(f"x_{t}", dist.Categorical(x_probs))
        with pyro.plate("tones", data.shape[-1], dim=-1):
            pyro.sample(
                f"y_{t}",
                dist.Categorical(Vindex(y_probs)[cur_w, cur_x]),
                obs=data[t],
            )
        prev_w, prev_x = cur_w, cur_x
# Example #9
# 0
def model_7(sequences, lengths, args, batch_size=None, include_prior=True):
    """First-order HMM over Bernoulli tone vectors with a vectorized time axis.

    The hidden state ``x_t`` selects a row of ``probs_x`` for the next
    transition and a row of ``probs_y`` for the per-tone Bernoulli emission.
    Time is vectorized with ``pyro.vectorized_markov`` on ``dim=-2``.
    Sequences are batched on ``dim=-3`` and tones on ``dim=-1``.

    Args:
        sequences: tensor of shape ``(num_sequences, max_length, data_dim)``.
        lengths: tensor of shape ``(num_sequences,)`` with each sequence's
            length (must not exceed ``max_length``).
        args: namespace providing ``hidden_dim`` and ``jit``.
        batch_size: optional subsample size for the "sequences" plate.
        include_prior: mask value for the global ``probs_x`` / ``probs_y``
            sites; when false, they do not contribute to the log density.
    """
    with ignore_jit_warnings():
        num_sequences, max_length, data_dim = map(int, sequences.shape)
        assert lengths.shape == (num_sequences,)
        assert lengths.max() <= max_length
    with handlers.mask(mask=include_prior):
        # Concentration 1.0 on the diagonal and 0.1 elsewhere favours
        # self-transitions.
        probs_x = pyro.sample(
            "probs_x",
            dist.Dirichlet(0.9 * torch.eye(args.hidden_dim) + 0.1).to_event(1),
        )
        probs_y = pyro.sample(
            "probs_y",
            dist.Beta(0.1, 0.9).expand([args.hidden_dim, data_dim]).to_event(2),
        )
    tones_plate = pyro.plate("tones", data_dim, dim=-1)
    # Note that since we're using dim=-2 for the time dimension, we need
    # to batch sequences over a different dimension, here dim=-3.
    with pyro.plate("sequences", num_sequences, batch_size, dim=-3) as batch:
        lengths = lengths[batch]
        batch = batch[:, None]
        # State 0 doubles as the start state: x_0 ~ probs_x[0].
        x_prev = 0
        # To vectorize time dimension we use pyro.vectorized_markov(name=...).
        # With the help of Vindex and additional unsqueezes we can ensure that
        # dimensions line up properly.
        for t in pyro.vectorized_markov(
            name="time", size=int(max_length if args.jit else lengths.max()), dim=-2
        ):
            # Mask out time steps at or beyond each sequence's length; the
            # final unsqueeze adds room for the tones dim.
            with handlers.mask(mask=(t < lengths.unsqueeze(-1)).unsqueeze(-1)):
                x_curr = pyro.sample(
                    "x_{}".format(t),
                    dist.Categorical(probs_x[x_prev]),
                    infer={"enumerate": "parallel"},
                )
                with tones_plate:
                    # squeeze(-1) presumably drops the tones-plate singleton
                    # so probs_y's data_dim axis lines up with the tones dim.
                    pyro.sample(
                        "y_{}".format(t),
                        dist.Bernoulli(probs_y[x_curr.squeeze(-1)]),
                        obs=Vindex(sequences)[batch, t],
                    )
                # Carry the state forward. Without this, x_prev stays 0, every
                # step is drawn from probs_x[0], and the hidden states are
                # independent instead of forming a Markov chain.
                x_prev = x_curr
# Example #10
# 0
    def model(self):
        """
        **Generative Model**

        Samples the global sites ``gain``, ``init``, ``trans``, ``lamda`` and
        ``proximity``. Then, for each color channel and each subsampled AOI,
        samples ``background_mean`` / ``background_std``. Frame by frame it
        samples a background intensity, a hidden state ``z``, an enumerated
        ``theta`` and, for each of ``self.K`` spots, a presence indicator
        ``m`` with height, width and x/y position. The stacked spot variables
        condition the observed frame through ``KSMOGN``.

        Frames are iterated with ``pyro.vectorized_markov`` (``dim=-2``) when
        ``self.vectorized`` is set, otherwise with a sequential
        ``pyro.markov``.
        """
        # global parameters
        gain = pyro.sample("gain", dist.HalfNormal(self.priors["gain_std"]))
        # Initial-state probabilities: for each of the Q leading entries, a
        # Dirichlet over S + 1 states with concentration 1 / (S + 1) each.
        init = pyro.sample(
            "init",
            dist.Dirichlet(torch.ones(self.Q, self.S + 1) /
                           (self.S + 1)).to_event(1),
        )
        # NOTE(review): expand_offtarget is defined elsewhere; its output is
        # later indexed with ``is_ontarget`` on the last axis — confirm.
        init = expand_offtarget(init)
        # Transition probabilities: for each of the Q leading entries and each
        # of the S + 1 current states, a Dirichlet over the next state.
        trans = pyro.sample(
            "trans",
            dist.Dirichlet(
                torch.ones(self.Q, self.S + 1, self.S + 1) /
                (self.S + 1)).to_event(2),
        )
        trans = expand_offtarget(trans)
        # Q-vector drawn from Exponential(priors["lamda_rate"]); consumed by
        # probs_m when computing spot-presence probabilities.
        lamda = pyro.sample(
            "lamda",
            dist.Exponential(torch.full(
                (self.Q, ), self.priors["lamda_rate"])).to_event(1),
        )
        proximity = pyro.sample(
            "proximity", dist.Exponential(self.priors["proximity_rate"]))
        # Two AffineBeta "size" values stacked on the last axis: index 0 is a
        # constant 2.0 and index 1 is derived from ``proximity``. Each spot's
        # x/y prior selects one of them via ``specific`` below.
        size = torch.stack(
            (
                torch.full_like(proximity, 2.0),
                (((self.data.P + 1) / (2 * proximity))**2 - 1),
            ),
            dim=-1,
        )

        # spots
        spots = pyro.plate("spots", self.K)
        # aoi sites (indices taken from ``self.n``, batch of
        # ``self.nbatch_size``)
        aois = pyro.plate(
            "aois",
            self.data.Nt,
            subsample=self.n,
            subsample_size=self.nbatch_size,
            dim=-3,
        )
        # time frames
        frames = (pyro.vectorized_markov(
            name="frames", size=self.data.F, dim=-2)
                  if self.vectorized else pyro.markov(range(self.data.F)))
        # color channels
        channels = pyro.plate(
            "channels",
            self.data.C,
            dim=-1,
        )

        with channels as cdx, aois as ndx:
            # AOI indices reshaped to (batch, 1, 1) so they sit on dim=-3.
            ndx = ndx[:, None, None]
            # Per-AOI mask from the data; masked-out entries do not
            # contribute to the log density of the sites below.
            mask = Vindex(self.data.mask)[ndx].to(self.device)
            with handlers.mask(mask=mask):
                # background mean and std
                background_mean = pyro.sample(
                    "background_mean",
                    dist.HalfNormal(self.priors["background_mean_std"]),
                )
                background_std = pyro.sample(
                    "background_std",
                    dist.HalfNormal(self.priors["background_std_std"]))
                z_prev = None
                for fdx in frames:
                    if self.vectorized:
                        # NOTE(review): assumes the vectorized iterator yields
                        # (site-name suffix, frame indices) pairs — confirm
                        # against the vectorized_markov implementation used.
                        fsx, fdx = fdx
                        fdx = torch.as_tensor(fdx)
                        # Trailing singleton for the channels dim (-1), so
                        # frame indices sit on the frames dim (-2).
                        fdx = fdx.unsqueeze(-1)
                    else:
                        # Sequential: the frame index doubles as the suffix.
                        fsx = fdx
                    # fetch data
                    obs, target_locs, is_ontarget = self.data.fetch(
                        ndx, fdx, cdx)
                    # sample background intensity
                    # Gamma(concentration=(m/s)**2, rate=m/s**2) has mean m
                    # and standard deviation s (m = background_mean,
                    # s = background_std).
                    background = pyro.sample(
                        f"background_f{fsx}",
                        dist.Gamma(
                            (background_mean / background_std)**2,
                            background_mean / background_std**2,
                        ),
                    )

                    # sample hidden model state (1+S,)
                    # First frame draws from ``init``; later frames use the
                    # ``trans`` row for the previous state. The last index
                    # selects the on-/off-target variant.
                    z_probs = (Vindex(init)[..., cdx, :,
                                            is_ontarget.long()]
                               if z_prev is None else
                               Vindex(trans)[..., cdx, z_prev, :,
                                             is_ontarget.long()])
                    z_curr = pyro.sample(f"z_f{fsx}",
                                         dist.Categorical(z_probs))

                    # theta takes values in {0, ..., K}: theta == k + 1 marks
                    # spot k as target-specific, and theta == 0 marks none.
                    # Its probabilities depend on z only through
                    # clamp(z, 0, 1).
                    theta = pyro.sample(
                        f"theta_f{fsx}",
                        dist.Categorical(
                            Vindex(probs_theta(
                                self.K, self.device))[torch.clamp(z_curr,
                                                                  min=0,
                                                                  max=1)]),
                        infer={"enumerate": "parallel"},
                    )
                    onehot_theta = one_hot(theta, num_classes=1 + self.K)

                    ms, heights, widths, xs, ys = [], [], [], [], []
                    for kdx in spots:
                        # 1 where this spot is the target-specific one.
                        specific = onehot_theta[..., 1 + kdx]
                        # spot presence
                        m_probs = Vindex(probs_m(lamda, self.K))[..., cdx,
                                                                 theta, kdx]
                        # Binary presence as a 2-class Categorical with
                        # probabilities (1 - p, p).
                        m = pyro.sample(
                            f"m_k{kdx}_f{fsx}",
                            dist.Categorical(
                                torch.stack((1 - m_probs, m_probs), -1)),
                        )
                        # Spot variables only contribute where m > 0.
                        with handlers.mask(mask=m > 0):
                            # sample spot variables
                            height = pyro.sample(
                                f"height_k{kdx}_f{fsx}",
                                dist.HalfNormal(self.priors["height_std"]),
                            )
                            width = pyro.sample(
                                f"width_k{kdx}_f{fsx}",
                                AffineBeta(
                                    1.5,
                                    2,
                                    self.priors["width_min"],
                                    self.priors["width_max"],
                                ),
                            )
                            # NOTE(review): assumes AffineBeta(mean, size,
                            # low, high), i.e. x/y centred at 0 within
                            # [-(P + 1) / 2, (P + 1) / 2] — confirm.
                            x = pyro.sample(
                                f"x_k{kdx}_f{fsx}",
                                AffineBeta(
                                    0,
                                    Vindex(size)[..., specific],
                                    -(self.data.P + 1) / 2,
                                    (self.data.P + 1) / 2,
                                ),
                            )
                            y = pyro.sample(
                                f"y_k{kdx}_f{fsx}",
                                AffineBeta(
                                    0,
                                    Vindex(size)[..., specific],
                                    -(self.data.P + 1) / 2,
                                    (self.data.P + 1) / 2,
                                ),
                            )

                        # append
                        ms.append(m)
                        heights.append(height)
                        widths.append(width)
                        xs.append(x)
                        ys.append(y)

                    # observed data
                    # Spot variables are stacked on a new last axis (one entry
                    # per spot). The m's are broadcast first because their
                    # shapes can differ (e.g. under enumeration).
                    pyro.sample(
                        f"data_f{fsx}",
                        KSMOGN(
                            torch.stack(heights, -1),
                            torch.stack(widths, -1),
                            torch.stack(xs, -1),
                            torch.stack(ys, -1),
                            target_locs,
                            background,
                            gain,
                            self.data.offset.samples,
                            self.data.offset.logits.to(self.dtype),
                            self.data.P,
                            torch.stack(torch.broadcast_tensors(*ms), -1),
                            use_pykeops=self.use_pykeops,
                        ),
                        obs=obs,
                    )
                    z_prev = z_curr
# Example #11
# 0
    def guide(self):
        """
        **Variational Distribution**

        Samples the global sites ``gain``, ``init``, ``trans``, ``lamda`` and
        ``proximity``. Per channel and subsampled AOI, it uses point masses
        (``Delta``) for ``background_mean`` / ``background_std``. Per frame
        it samples a background intensity, the hidden state ``z`` and, per
        spot, the presence ``m`` (both enumerated in parallel) plus height,
        width and x/y position.

        Variational parameters are read with ``pyro.param(name)`` and no
        initial value, so they must already exist in the param store.
        """
        # global parameters
        # Gamma(loc * beta, beta) has mean ``loc`` under the (concentration,
        # rate) parameterisation; the same pattern is used below.
        pyro.sample(
            "gain",
            dist.Gamma(
                pyro.param("gain_loc") * pyro.param("gain_beta"),
                pyro.param("gain_beta"),
            ),
        )
        # Dirichlet concentration = mean * size.
        pyro.sample(
            "init",
            dist.Dirichlet(pyro.param("init_mean") *
                           pyro.param("init_size")).to_event(1),
        )
        pyro.sample(
            "trans",
            dist.Dirichlet(
                pyro.param("trans_mean") *
                pyro.param("trans_size")).to_event(2),
        )
        pyro.sample(
            "lamda",
            dist.Gamma(
                pyro.param("lamda_loc") * pyro.param("lamda_beta"),
                pyro.param("lamda_beta"),
            ).to_event(1),
        )
        # NOTE(review): assumes AffineBeta(mean, size, low, high); the support
        # here would be [0, (P + 1) / sqrt(12)] — confirm.
        pyro.sample(
            "proximity",
            AffineBeta(
                pyro.param("proximity_loc"),
                pyro.param("proximity_size"),
                0,
                (self.data.P + 1) / math.sqrt(12),
            ),
        )

        # spots
        spots = pyro.plate("spots", self.K)
        # aoi sites
        aois = pyro.plate(
            "aois",
            self.data.Nt,
            subsample=self.n,
            subsample_size=self.nbatch_size,
            dim=-3,
        )
        # time frames
        frames = (pyro.vectorized_markov(
            name="frames", size=self.data.F, dim=-2)
                  if self.vectorized else pyro.markov(range(self.data.F)))
        # color channels
        channels = pyro.plate(
            "channels",
            self.data.C,
            dim=-1,
        )

        with channels as cdx, aois as ndx:
            # AOI indices reshaped to (batch, 1, 1) so they sit on dim=-3.
            ndx = ndx[:, None, None]
            mask = Vindex(self.data.mask)[ndx].to(self.device)
            with handlers.mask(mask=mask):
                # Point-mass guides for the background hyperparameters.
                # NOTE(review): the fixed 0 on the second axis presumably
                # selects a singleton axis of these params — confirm.
                pyro.sample(
                    "background_mean",
                    dist.Delta(
                        Vindex(pyro.param("background_mean_loc"))[ndx, 0,
                                                                  cdx]),
                )
                pyro.sample(
                    "background_std",
                    dist.Delta(
                        Vindex(pyro.param("background_std_loc"))[ndx, 0, cdx]),
                )
                z_prev = None
                for fdx in frames:
                    if self.vectorized:
                        # NOTE(review): assumes the vectorized iterator yields
                        # (site-name suffix, frame indices) pairs — confirm.
                        fsx, fdx = fdx
                        fdx = torch.as_tensor(fdx)
                        # Trailing singleton for the channels dim (-1).
                        fdx = fdx.unsqueeze(-1)
                    else:
                        fsx = fdx
                    # sample background intensity
                    pyro.sample(
                        f"background_f{fsx}",
                        dist.Gamma(
                            Vindex(pyro.param("b_loc"))[ndx, fdx, cdx] *
                            Vindex(pyro.param("b_beta"))[ndx, fdx, cdx],
                            Vindex(pyro.param("b_beta"))[ndx, fdx, cdx],
                        ),
                    )

                    # sample hidden model state
                    # First frame: row 0 of z_trans serves as the initial
                    # distribution. Later frames: the row for z_prev.
                    z_probs = (Vindex(pyro.param("z_trans"))[ndx, fdx, cdx, 0]
                               if z_prev is None else Vindex(
                                   pyro.param("z_trans"))[ndx, fdx, cdx,
                                                          z_prev])
                    z_curr = pyro.sample(
                        f"z_f{fsx}",
                        dist.Categorical(z_probs),
                        infer={"enumerate": "parallel"},
                    )

                    for kdx in spots:
                        # spot presence
                        # Presence probabilities depend on the enumerated z.
                        m_probs = Vindex(pyro.param("m_probs"))[z_curr, kdx,
                                                                ndx, fdx, cdx]
                        m = pyro.sample(
                            f"m_k{kdx}_f{fsx}",
                            dist.Categorical(
                                torch.stack((1 - m_probs, m_probs), -1)),
                            infer={"enumerate": "parallel"},
                        )
                        # Spot variables only contribute where m > 0.
                        with handlers.mask(mask=m > 0):
                            # sample spot variables
                            pyro.sample(
                                f"height_k{kdx}_f{fsx}",
                                dist.Gamma(
                                    Vindex(pyro.param("h_loc"))[kdx, ndx, fdx,
                                                                cdx] *
                                    Vindex(pyro.param("h_beta"))[kdx, ndx, fdx,
                                                                 cdx],
                                    Vindex(pyro.param("h_beta"))[kdx, ndx, fdx,
                                                                 cdx],
                                ),
                            )
                            pyro.sample(
                                f"width_k{kdx}_f{fsx}",
                                AffineBeta(
                                    Vindex(pyro.param("w_mean"))[kdx, ndx, fdx,
                                                                 cdx],
                                    Vindex(pyro.param("w_size"))[kdx, ndx, fdx,
                                                                 cdx],
                                    self.priors["width_min"],
                                    self.priors["width_max"],
                                ),
                            )
                            # x and y share the "size" parameter.
                            pyro.sample(
                                f"x_k{kdx}_f{fsx}",
                                AffineBeta(
                                    Vindex(pyro.param("x_mean"))[kdx, ndx, fdx,
                                                                 cdx],
                                    Vindex(pyro.param("size"))[kdx, ndx, fdx,
                                                               cdx],
                                    -(self.data.P + 1) / 2,
                                    (self.data.P + 1) / 2,
                                ),
                            )
                            pyro.sample(
                                f"y_k{kdx}_f{fsx}",
                                AffineBeta(
                                    Vindex(pyro.param("y_mean"))[kdx, ndx, fdx,
                                                                 cdx],
                                    Vindex(pyro.param("size"))[kdx, ndx, fdx,
                                                               cdx],
                                    -(self.data.P + 1) / 2,
                                    (self.data.P + 1) / 2,
                                ),
                            )

                    z_prev = z_curr
# Example #12
# 0
    def guide(self):
        """
        **Variational Distribution**

        Single-channel variant: AOIs are on ``dim=-2`` and frames on
        ``dim=-1``. Samples the global sites ``gain``, ``init``, ``trans``,
        ``lamda`` and ``proximity``. Per subsampled AOI, it uses point masses
        (``Delta``) for ``background_mean`` / ``background_std``. Per frame
        it samples a background intensity, the hidden state ``z`` and, per
        spot, the presence ``m`` (both enumerated in parallel) plus height,
        width and x/y position.

        Variational parameters are read with ``pyro.param(name)`` and no
        initial value, so they must already exist in the param store.
        """
        # global parameters
        # Gamma(loc * beta, beta) has mean ``loc`` under the (concentration,
        # rate) parameterisation; the same pattern is used below.
        pyro.sample(
            "gain",
            dist.Gamma(
                pyro.param("gain_loc") * pyro.param("gain_beta"),
                pyro.param("gain_beta"),
            ),
        )
        # Dirichlet concentration = mean * size.
        pyro.sample(
            "init",
            dist.Dirichlet(pyro.param("init_mean") * pyro.param("init_size")))
        pyro.sample(
            "trans",
            dist.Dirichlet(
                pyro.param("trans_mean") *
                pyro.param("trans_size")).to_event(1),
        )
        pyro.sample(
            "lamda",
            dist.Gamma(
                pyro.param("lamda_loc") * pyro.param("lamda_beta"),
                pyro.param("lamda_beta"),
            ),
        )
        # NOTE(review): assumes AffineBeta(mean, size, low, high); the support
        # here would be [0, (P + 1) / sqrt(12)] — confirm.
        pyro.sample(
            "proximity",
            AffineBeta(
                pyro.param("proximity_loc"),
                pyro.param("proximity_size"),
                0,
                (self.data.P + 1) / math.sqrt(12),
            ),
        )

        # spots
        spots = pyro.plate("spots", self.K)
        # aoi sites
        aois = pyro.plate(
            "aois",
            self.data.Nt,
            subsample=self.n,
            subsample_size=self.nbatch_size,
            dim=-2,
        )
        # time frames
        frames = (pyro.vectorized_markov(
            name="frames", size=self.data.F, dim=-1)
                  if self.vectorized else pyro.markov(range(self.data.F)))

        with aois as ndx:
            # AOI indices reshaped to (batch, 1) so they sit on dim=-2.
            ndx = ndx[:, None]
            # Point-mass guides for the background hyperparameters.
            pyro.sample(
                "background_mean",
                dist.Delta(Vindex(pyro.param("background_mean_loc"))[ndx, 0]),
            )
            pyro.sample(
                "background_std",
                dist.Delta(Vindex(pyro.param("background_std_loc"))[ndx, 0]),
            )
            z_prev = None
            for fdx in frames:
                if self.vectorized:
                    # NOTE(review): assumes the vectorized iterator yields
                    # (site-name suffix, frame indices) pairs — confirm.
                    fsx, fdx = fdx
                else:
                    fsx = fdx
                # sample background intensity
                pyro.sample(
                    f"background_{fsx}",
                    dist.Gamma(
                        Vindex(pyro.param("b_loc"))[ndx, fdx] *
                        Vindex(pyro.param("b_beta"))[ndx, fdx],
                        Vindex(pyro.param("b_beta"))[ndx, fdx],
                    ),
                )

                # sample hidden model state
                # First frame: row 0 of z_trans serves as the initial
                # distribution. Later frames: the row for z_prev.
                # NOTE(review): on the first iteration z_prev is None, so this
                # relies on fdx being a Python int < 1 there. If the first
                # vectorized item is a tensor, z_prev=None would be used as an
                # index — confirm.
                z_probs = (Vindex(pyro.param("z_trans"))[ndx, fdx, 0]
                           if isinstance(fdx, int) and fdx < 1 else Vindex(
                               pyro.param("z_trans"))[ndx, fdx, z_prev])
                z_curr = pyro.sample(
                    f"z_{fsx}",
                    dist.Categorical(z_probs),
                    infer={"enumerate": "parallel"},
                )

                for kdx in spots:
                    # spot presence
                    # Presence probabilities depend on the enumerated z.
                    m_probs = Vindex(pyro.param("m_probs"))[z_curr, kdx, ndx,
                                                            fdx]
                    m = pyro.sample(
                        f"m_{kdx}_{fsx}",
                        dist.Categorical(
                            torch.stack((1 - m_probs, m_probs), -1)),
                        infer={"enumerate": "parallel"},
                    )
                    # Spot variables only contribute where m > 0.
                    with handlers.mask(mask=m > 0):
                        # sample spot variables
                        pyro.sample(
                            f"height_{kdx}_{fsx}",
                            dist.Gamma(
                                Vindex(pyro.param("h_loc"))[kdx, ndx, fdx] *
                                Vindex(pyro.param("h_beta"))[kdx, ndx, fdx],
                                Vindex(pyro.param("h_beta"))[kdx, ndx, fdx],
                            ),
                        )
                        # Width bounds are hard-coded to [0.75, 2.25] here.
                        pyro.sample(
                            f"width_{kdx}_{fsx}",
                            AffineBeta(
                                Vindex(pyro.param("w_mean"))[kdx, ndx, fdx],
                                Vindex(pyro.param("w_size"))[kdx, ndx, fdx],
                                0.75,
                                2.25,
                            ),
                        )
                        # x and y share the "size" parameter.
                        pyro.sample(
                            f"x_{kdx}_{fsx}",
                            AffineBeta(
                                Vindex(pyro.param("x_mean"))[kdx, ndx, fdx],
                                Vindex(pyro.param("size"))[kdx, ndx, fdx],
                                -(self.data.P + 1) / 2,
                                (self.data.P + 1) / 2,
                            ),
                        )
                        pyro.sample(
                            f"y_{kdx}_{fsx}",
                            AffineBeta(
                                Vindex(pyro.param("y_mean"))[kdx, ndx, fdx],
                                Vindex(pyro.param("size"))[kdx, ndx, fdx],
                                -(self.data.P + 1) / 2,
                                (self.data.P + 1) / 2,
                            ),
                        )

                z_prev = z_curr