示例#1
0
    def model():
        """Draw "b", "c" and "d" under nested plates, then check their shapes.

        "b" is sampled inside ``x_plate`` (dim=-1), "c" inside ``y_plate``
        (dim=-2) and "d" inside both, all within the size-50 "num_particles"
        plate at dim=-3. The expected shapes depend on ``enumerate_`` and the
        plate sizes on ``subsampling``, both read from the enclosing scope.
        """
        sub_x = 2 if subsampling else None
        sub_y = 3 if subsampling else None
        plate_x = pyro.plate("x_plate", 5, subsample_size=sub_x, dim=-1)
        plate_y = pyro.plate("y_plate", 6, subsample_size=sub_y, dim=-2)
        with pyro.plate("num_particles", 50, dim=-3):
            with plate_x:
                beta_prior = dist.Beta(torch.tensor(1.1), torch.tensor(1.1))
                b = pyro.sample("b", beta_prior)
            with plate_y:
                c = pyro.sample("c", dist.Bernoulli(0.5))
            with plate_x, plate_y:
                d = pyro.sample("d", dist.Bernoulli(b))

        # check shapes; the expectation for b is the same in every mode
        assert b.shape == (50, 1, plate_x.subsample_size)
        if enumerate_ == "parallel":
            assert c.shape == (2, 1, 1, 1)
            assert d.shape == (2, 1, 1, 1, 1)
        elif enumerate_ == "sequential":
            accepted = ((), (1, 1, 1))  # both are valid
            assert c.shape in accepted
            assert d.shape in accepted
        else:
            rows = plate_y.subsample_size
            cols = plate_x.subsample_size
            assert c.shape == (50, rows, 1)
            assert d.shape == (50, rows, cols)
示例#2
0
    def model():
        """Combine sites that use three different enumeration settings.

        "w" is enumerated in parallel, "a" (plate "non_enum") is explicitly
        not enumerated, and "b" (plate "enum_1") uses the ``enumerate_``
        strategy taken from the enclosing scope.
        """
        pyro.sample("w", dist.Bernoulli(0.5), infer={'enumerate': 'parallel'})

        with pyro.plate("non_enum", 2):
            flips = pyro.sample("a", dist.Bernoulli(0.5),
                                infer={'enumerate': None})

        # Introduce a dependency of b on a through its success probability.
        numerator = 1.0 + flips.sum(-1)
        denominator = 2.0 + flips.shape[0]
        prob_b = numerator / denominator

        with pyro.plate("enum_1", 3):
            pyro.sample("b", dist.Bernoulli(prob_b),
                        infer={'enumerate': enumerate_})
示例#3
0
 def model():
     """Sample Bernoulli sites "x", "y", "z" and "q" at increasing plate depth.

     Each site lives one plate deeper than the previous one; the dim of every
     plate is read from ``plate_dims`` in the enclosing scope.
     """
     prob = torch.tensor(0.5, requires_grad=True)
     with pyro.plate("plate_outer", size=5, dim=plate_dims[0]):
         pyro.sample("x", dist.Bernoulli(prob))
         with pyro.plate("plate_inner_1", size=6, dim=plate_dims[1]):
             pyro.sample("y", dist.Bernoulli(prob))
             with pyro.plate("plate_inner_2", size=7, dim=plate_dims[2]):
                 pyro.sample("z", dist.Bernoulli(prob))
                 with pyro.plate("plate_inner_3", size=8, dim=plate_dims[3]):
                     pyro.sample("q", dist.Bernoulli(prob))
示例#4
0
 def model():
     """Ten-step chain whose transition depends on the two previous states.

     Latent sites "x_t" are Bernoulli draws with probability looked up in the
     2x2 parameter "p" at ``[x_prev, x_curr]``; each "y_t" is observed as 0.
     The markov ``history`` comes from the enclosing scope.
     """
     trans = pyro.param("p", 0.25 * torch.ones(2, 2))
     emit = pyro.param("q", 0.25 * torch.ones(2))
     x_prev, x_curr = torch.tensor(0), torch.tensor(0)
     for t in pyro.markov(range(10), history=history):
         step_probs = trans[x_prev, x_curr]
         x_next = pyro.sample(f"x_{t}", dist.Bernoulli(step_probs)).long()
         # Shift the two-step window forward.
         x_prev, x_curr = x_curr, x_next
         pyro.sample(f"y_{t}",
                     dist.Bernoulli(emit[x_curr]),
                     obs=torch.tensor(0.))
 def model():
     """Mix a sequential plate with one vectorized plate reused per iteration.

     "w" is enumerated in parallel. For each index of the sequential "iplate"
     a site "y_i" is sampled, followed by "x_i" inside the shared vectorized
     plate using the ``enumerate_`` strategy from the enclosing scope.
     """
     pyro.sample("w", dist.Bernoulli(0.5), infer={'enumerate': 'parallel'})
     inner_size = 4 if subsampling else None
     shared_plate = pyro.plate("plate", 10, subsample_size=inner_size)
     outer_indices = torch.arange(3) if subsampling else None
     for i in pyro.plate("iplate", 10, subsample=outer_indices):
         pyro.sample(f"y_{i}", dist.Bernoulli(0.5))
         with shared_plate:
             pyro.sample(f"x_{i}",
                         dist.Bernoulli(0.5),
                         infer={'enumerate': enumerate_})
示例#6
0
    def model():
        """Sample "a", "b" (x_plate), "c" (y_plate) and "d" (both plates).

        When ``reuse_plate`` is truthy, "d" takes its probabilities from the
        "q" parameter indexed by ``b``; otherwise it uses a constant 0.5.
        """
        plate_x = pyro.plate("x_plate", 10, dim=-1)
        plate_y = pyro.plate("y_plate", 11, dim=-2)
        weights = pyro.param("q", torch.tensor([0.5, 0.5]))
        pyro.sample("a", dist.Bernoulli(0.5))
        with plate_x:
            b = pyro.sample("b", dist.Bernoulli(0.5)).long()
        with plate_y:
            # Note that it is difficult to check that c does not depend on b.
            c = pyro.sample("c", dist.Bernoulli(0.5)).long()
        with plate_x, plate_y:
            if reuse_plate:
                d_probs = Vindex(weights)[b]
            else:
                d_probs = 0.5
            pyro.sample("d", dist.Bernoulli(d_probs))

        # Unless enumeration is sequential, b and c must differ in shape.
        assert not (c.shape == b.shape and enumerate_ != "sequential")
示例#7
0
    def model():
        """Sample Bernoulli sites and short Categorical chains in several plates.

        A Bernoulli site and a four-step markov chain are sampled outside any
        plate ("a", "w_i"), inside plate_x ("b", "x_i"), plate_y ("c", "y_i"),
        plate_z ("d", "z_i") and inside plate_x and plate_z together
        ("e", "xz_i"). plate_x and plate_y both use dim=-1.

        :returns: the tuple ``(a, b, c, d, e)`` of Bernoulli draws cast to long.
        """
        trans = pyro.param("p", torch.ones(3, 3))
        probs = pyro.param("q", torch.tensor([0.5, 0.5]))
        sub = 3 if subsampling else None
        plate_x = pyro.plate("plate_x", 4, subsample_size=sub, dim=-1)
        plate_y = pyro.plate("plate_y", 5, subsample_size=sub, dim=-1)
        plate_z = pyro.plate("plate_z", 6, subsample_size=sub, dim=-2)

        def run_chain(prefix):
            # Four Categorical sites "<prefix>_<i>", each indexing `trans`
            # by the previous draw (starting from row 0).
            state = 0
            for i in pyro.markov(range(4)):
                state = pyro.sample("{}_{}".format(prefix, i),
                                    dist.Categorical(trans[state]))

        a = pyro.sample("a", dist.Bernoulli(probs[0])).long()
        run_chain("w")

        with plate_x:
            b = pyro.sample("b", dist.Bernoulli(probs[a])).long()
            run_chain("x")

        with plate_y:
            c = pyro.sample("c", dist.Bernoulli(probs[a])).long()
            run_chain("y")

        with plate_z:
            d = pyro.sample("d", dist.Bernoulli(probs[a])).long()
            run_chain("z")

        with plate_x, plate_z:
            # this part is tricky: how do we know to preserve b's dimension?
            # also, how do we know how to make b and d have different dimensions?
            parent = b if reuse_plate else a
            e = pyro.sample("e", dist.Bernoulli(probs[parent])).long()
            run_chain("xz")

        return a, b, c, d, e
示例#8
0
def model_2(sequences, lengths, args, batch_size=None, include_prior=True):
    """HMM whose tone emissions depend on the hidden state and the previous tones.

    :param sequences: tensor whose shape unpacks as
        ``(num_sequences, max_length, data_dim)``.
    :param lengths: tensor of shape ``(num_sequences,)`` with per-sequence
        lengths, none exceeding ``max_length``.
    :param args: object providing ``hidden_dim`` and ``jit``.
    :param batch_size: subsample size forwarded to the "sequences" plate.
    :param include_prior: mask applied to the "probs_x" and "probs_y" sites.
    """
    with ignore_jit_warnings():
        num_sequences, max_length, data_dim = [int(s) for s in sequences.shape]
        assert lengths.shape == (num_sequences, )
        assert lengths.max() <= max_length
    with handlers.mask(mask=include_prior):
        trans_prior = dist.Dirichlet(0.9 * torch.eye(args.hidden_dim) + 0.1)
        probs_x = pyro.sample("probs_x", trans_prior.to_event(1))
        emit_prior = dist.Beta(0.1, 0.9).expand(
            [args.hidden_dim, 2, data_dim])
        probs_y = pyro.sample("probs_y", emit_prior.to_event(3))
    tones_plate = pyro.plate("tones", data_dim, dim=-1)
    with pyro.plate("sequences", num_sequences, batch_size, dim=-2) as batch:
        lengths = lengths[batch]
        x = 0
        y = 0
        num_steps = max_length if args.jit else lengths.max()
        for t in pyro.markov(range(num_steps)):
            # Mask out steps beyond the end of each sequence.
            in_range = (t < lengths).unsqueeze(-1)
            with handlers.mask(mask=in_range):
                x = pyro.sample(f"x_{t}",
                                dist.Categorical(probs_x[x]),
                                infer={"enumerate": "parallel"})
                # Broadcasting trick: indexing probs_y by the tensors x and y
                # needs a third index tensor along the tones dimension, which
                # the tones plate supplies when entered.
                with tones_plate as tones:
                    emission = dist.Bernoulli(probs_y[x, y, tones])
                    y = pyro.sample(f"y_{t}",
                                    emission,
                                    obs=sequences[batch, t]).long()
示例#9
0
 def model_leaf(data, state=0, address=""):
     """Observe a Bernoulli site named ``"leaf_<address>"``.

     The success probability is entry ``state`` of the "p_leaf" parameter; the
     observed value is 1.0 when ``data`` is truthy and 0.0 otherwise.
     """
     leaf_probs = pyro.param("p_leaf", torch.ones(10))
     if data:
         observed = torch.tensor(1.0)
     else:
         observed = torch.tensor(0.0)
     site_name = f"leaf_{address}"
     pyro.sample(site_name, dist.Bernoulli(leaf_probs[state]), obs=observed)
示例#10
0
def model_0(sequences, lengths, args, batch_size=None, include_prior=True):
    """HMM that iterates sequentially over the sequences of a minibatch.

    :param sequences: tensor whose shape unpacks into three dimensions, the
        last being ``data_dim``.
    :param lengths: indexable by sequence index, giving each sequence length.
    :param args: object providing ``hidden_dim``.
    :param batch_size: subsample size forwarded to the "sequences" plate.
    :param include_prior: mask applied to the "probs_x" and "probs_y" sites.
    """
    assert not torch._C._get_tracing_state()
    _, _, data_dim = sequences.shape
    with handlers.mask(mask=include_prior):
        # Transition prior: stay in the same state with 90% probability and
        # jump uniformly to another state with 10% probability.
        stay_or_jump = 0.9 * torch.eye(args.hidden_dim) + 0.1
        probs_x = pyro.sample("probs_x",
                              dist.Dirichlet(stay_or_jump).to_event(1))
        # Weak prior on a tone sounding: on average about 4 of 88 tones are
        # active, so roughly 10% of the notes are active at any one time.
        tone_prior = dist.Beta(0.1, 0.9).expand([args.hidden_dim, data_dim])
        probs_y = pyro.sample("probs_y", tone_prior.to_event(2))
    # Iterating over sequences one at a time keeps tensor shapes easy to
    # reason about.
    tones_plate = pyro.plate("tones", data_dim, dim=-1)
    for i in pyro.plate("sequences", len(sequences), batch_size):
        num_steps = lengths[i]
        notes = sequences[i, :num_steps]
        x = 0
        for t in pyro.markov(range(num_steps)):
            # x is overwritten at every step; to keep all values one could
            # store them as x[t] = pyro.sample(...x[t-1]...) instead.
            x = pyro.sample(f"x_{i}_{t}",
                            dist.Categorical(probs_x[x]),
                            infer={"enumerate": "parallel"})
            with tones_plate:
                tone_probs = probs_y[x.squeeze(-1)]
                pyro.sample(f"y_{i}_{t}",
                            dist.Bernoulli(tone_probs),
                            obs=notes[t])
示例#11
0
    def model(data):
        """Linear-Gaussian state chain with per-channel Bernoulli observations.

        :param data: tensor whose shape unpacks as
            ``(time steps, individuals, features)``; ``data[t]`` is the
            observation at step ``t``.
        """
        num_steps, num_individuals, num_features = data.shape

        # Parameters of the Gaussian initial distribution.
        init_loc = pyro.param("init_loc", torch.zeros(num_features))
        init_scale = pyro.param("init_scale",
                                1e-2 * torch.eye(num_features),
                                constraint=constraints.lower_cholesky)

        # Parameters of the linear dynamics with Gaussian noise.
        trans_const = pyro.param("trans_const", torch.zeros(num_features))
        trans_coeff = pyro.param("trans_coeff", torch.eye(num_features))
        noise = pyro.param("noise",
                           1e-2 * torch.eye(num_features),
                           constraint=constraints.lower_cholesky)

        obs_plate = pyro.plate("channel", num_features, dim=-1)
        with pyro.plate("data", num_individuals, dim=-2):
            state = None
            for t in range(num_steps):
                # Transition: every step after the first is a linear map of
                # the previous state; the first step uses the initial params.
                if t > 0:
                    loc = trans_const + funsor.torch.torch_tensordot(
                        trans_coeff, state, 1)
                    scale_tril = noise
                else:
                    loc, scale_tril = init_loc, init_scale
                state = pyro.sample(f"state_{t}",
                                    dist.MultivariateNormal(loc, scale_tril),
                                    infer={"exact": exact})

                # Bernoulli likelihood per channel, with logits read from
                # state["channel"].
                with obs_plate:
                    pyro.sample(f"obs_{t}",
                                dist.Bernoulli(logits=state["channel"]),
                                obs=data[t])
示例#12
0
def model_5(sequences, lengths, args, batch_size=None, include_prior=True):
    """HMM whose tone logits come from the module-level ``tones_generator``.

    :param sequences: tensor whose shape unpacks as
        ``(num_sequences, max_length, data_dim)``.
    :param lengths: tensor of shape ``(num_sequences,)`` with per-sequence
        lengths, none exceeding ``max_length``.
    :param args: object providing ``hidden_dim`` and ``jit``; also passed to
        ``TonesGenerator`` when the global instance is first created.
    :param batch_size: subsample size forwarded to the "sequences" plate.
    :param include_prior: mask applied to the "probs_x" site.
    """
    global tones_generator

    with ignore_jit_warnings():
        num_sequences, max_length, data_dim = [int(s) for s in sequences.shape]
        assert lengths.shape == (num_sequences, )
        assert lengths.max() <= max_length

    # Lazily create the shared module instance, then register it with Pyro.
    if tones_generator is None:
        tones_generator = TonesGenerator(args, data_dim)
    pyro.module("tones_generator", tones_generator)

    with handlers.mask(mask=include_prior):
        trans_prior = dist.Dirichlet(0.9 * torch.eye(args.hidden_dim) + 0.1)
        probs_x = pyro.sample("probs_x", trans_prior.to_event(1))
    with pyro.plate("sequences", num_sequences, batch_size, dim=-2) as batch:
        lengths = lengths[batch]
        x = 0
        tones_prev = torch.zeros(data_dim)
        num_steps = max_length if args.jit else lengths.max()
        for t in pyro.markov(range(num_steps)):
            in_range = (t < lengths).unsqueeze(-1)
            with handlers.mask(mask=in_range):
                x = pyro.sample(f"x_{t}",
                                dist.Categorical(probs_x[x]),
                                infer={"enumerate": "parallel"})
                # Each tone depends on all tones of the previous time step,
                # so every time step gets its own tones plate.
                with pyro.plate(f"tones_{t}", data_dim, dim=-1):
                    logits = tones_generator(x, tones_prev)
                    tones_prev = pyro.sample(f"y_{t}",
                                             dist.Bernoulli(logits=logits),
                                             obs=sequences[batch, t])
示例#13
0
def model_3(sequences, lengths, args, batch_size=None, include_prior=True):
    """HMM with two independent hidden chains, "w" and "x", per time step.

    :param sequences: tensor whose shape unpacks as
        ``(num_sequences, max_length, data_dim)``.
    :param lengths: tensor of shape ``(num_sequences,)`` with per-sequence
        lengths, none exceeding ``max_length``.
    :param args: object providing ``hidden_dim`` and ``jit``.
    :param batch_size: subsample size forwarded to the "sequences" plate.
    :param include_prior: mask applied to the three prior sites.
    """
    with ignore_jit_warnings():
        num_sequences, max_length, data_dim = [int(s) for s in sequences.shape]
        assert lengths.shape == (num_sequences, )
        assert lengths.max() <= max_length
    # The hidden dimension is split between w and x.
    hidden_dim = int(args.hidden_dim**0.5)
    with handlers.mask(mask=include_prior):
        w_prior = dist.Dirichlet(0.9 * torch.eye(hidden_dim) + 0.1)
        probs_w = pyro.sample("probs_w", w_prior.to_event(1))
        x_prior = dist.Dirichlet(0.9 * torch.eye(hidden_dim) + 0.1)
        probs_x = pyro.sample("probs_x", x_prior.to_event(1))
        y_prior = dist.Beta(0.1, 0.9).expand(
            [hidden_dim, hidden_dim, data_dim])
        probs_y = pyro.sample("probs_y", y_prior.to_event(3))
    tones_plate = pyro.plate("tones", data_dim, dim=-1)
    with pyro.plate("sequences", num_sequences, batch_size, dim=-2) as batch:
        lengths = lengths[batch]
        w = 0
        x = 0
        num_steps = max_length if args.jit else lengths.max()
        for t in pyro.markov(range(num_steps)):
            in_range = (t < lengths).unsqueeze(-1)
            with handlers.mask(mask=in_range):
                w = pyro.sample(f"w_{t}",
                                dist.Categorical(probs_w[w]),
                                infer={"enumerate": "parallel"})
                x = pyro.sample(f"x_{t}",
                                dist.Categorical(probs_x[x]),
                                infer={"enumerate": "parallel"})
                with tones_plate as tones:
                    tone_probs = probs_y[w, x, tones]
                    pyro.sample(f"y_{t}",
                                dist.Bernoulli(tone_probs),
                                obs=sequences[batch, t])
示例#14
0
def double_exp_model(data):
    """Two-component mixture of Exponential distributions over ``data``.

    Each data point picks one of the two rates "k1" / "k2" through a Bernoulli
    indicator "m" with probability "A", enumerated in parallel.

    :param data: observations for the "obs" site; ``len(data)`` sets the size
        of the "data" plate.
    """
    rate_1 = pyro.param("k1", lambda: torch.tensor(0.01),
                        constraint=constraints.positive)
    rate_2 = pyro.param("k2", lambda: torch.tensor(0.05),
                        constraint=constraints.positive)
    mix_prob = pyro.param("A", lambda: torch.tensor(0.5),
                          constraint=constraints.unit_interval)
    rates = torch.stack([rate_1, rate_2])

    with pyro.plate("data", len(data)):
        indicator = pyro.sample("m", dist.Bernoulli(mix_prob),
                                infer={"enumerate": "parallel"})
        component = indicator.long()
        pyro.sample("obs", dist.Exponential(rates[component]), obs=data)
示例#15
0
 def model():
     """Check ``pyro.subsample`` on a scalar and on a length-10 tensor.

     The scalar must keep shape ``()``; the vector must have length
     ``subsample_size`` when that value is truthy, else 10.
     """
     with pyro.plate("plate", 10, subsample_size=subsample_size, dim=None):
         scalar = pyro.subsample(torch.tensor(0.), event_dim=0)
         assert scalar.shape == ()
         probs = pyro.subsample(0.5 * torch.ones(10), event_dim=0)
         expected_len = subsample_size or 10
         assert len(probs) == expected_len
         pyro.sample("x", dist.Bernoulli(probs))
示例#16
0
 def model():
     """Chain of Categorical sites "x0".."x{depth-1}" with an observed "y".

     "x0" is sampled outside any plate; the remaining chain is sampled inside
     the "local" plate and "y" inside the nested "data" plate. ``depth`` and
     ``data`` come from the enclosing scope.
     """
     x = pyro.sample("x0", dist.Categorical(pyro.param("q0")))
     with pyro.plate("local", 3):
         for i in range(1, depth):
             table = pyro.param(f"q{i}")
             # Select the row of the table that matches the previous draw.
             x = pyro.sample(f"x{i}", dist.Categorical(table[..., x, :]))
         with pyro.plate("data", 4):
             emit = pyro.param("qy")
             pyro.sample("y", dist.Bernoulli(emit[..., x]), obs=data)
示例#17
0
def model_6(sequences, lengths, args, batch_size=None, include_prior=False):
    """Second-order HMM: each hidden state depends on the two previous ones.

    :param sequences: tensor whose shape unpacks as
        ``(num_sequences, max_length, data_dim)``.
    :param lengths: tensor of shape ``(num_sequences,)`` with per-sequence
        lengths, none exceeding ``max_length``.
    :param args: object providing ``hidden_dim`` and
        ``raftery_parameterization``.
    :param batch_size: subsample size forwarded to the "sequences" plate.
    :param include_prior: accepted but not used in this function body.
    """
    num_sequences, max_length, data_dim = sequences.shape
    assert lengths.shape == (num_sequences, )
    assert lengths.max() <= max_length
    hidden_dim = args.hidden_dim

    if args.raftery_parameterization:
        # The more parsimonious "Raftery" parameterization of the tensor of
        # transition probabilities. See reference:
        # Raftery, A. E. A model for high-order markov chains.
        # Journal of the Royal Statistical Society. 1985.
        probs_x1 = pyro.param("probs_x1",
                              torch.rand(hidden_dim, hidden_dim),
                              constraint=constraints.simplex)
        probs_x2 = pyro.param("probs_x2",
                              torch.rand(hidden_dim, hidden_dim),
                              constraint=constraints.simplex)
        mix_lambda = pyro.param("mix_lambda",
                                torch.tensor(0.5),
                                constraint=constraints.unit_interval)
        # Broadcasting (hidden_dim, hidden_dim) against
        # (hidden_dim, 1, hidden_dim) yields a tensor of shape
        # (hidden_dim, hidden_dim, hidden_dim).
        first_term = mix_lambda * probs_x1
        second_term = (1.0 - mix_lambda) * probs_x2.unsqueeze(-2)
        probs_x = first_term + second_term
    else:
        # Parameterize the full transition tensor explicitly; it has
        # hidden_dim cubed entries.
        probs_x = pyro.param("probs_x",
                             torch.rand(hidden_dim, hidden_dim, hidden_dim),
                             constraint=constraints.simplex)

    probs_y = pyro.param("probs_y",
                         torch.rand(hidden_dim, data_dim),
                         constraint=constraints.unit_interval)
    tones_plate = pyro.plate("tones", data_dim, dim=-1)
    with pyro.plate("sequences", num_sequences, batch_size, dim=-2) as batch:
        lengths = lengths[batch]
        x_curr, x_prev = torch.tensor(0), torch.tensor(0)
        # history=2 tells pyro.markov() that the model is 2-markov.
        for t in pyro.markov(range(lengths.max()), history=2):
            in_range = (t < lengths).unsqueeze(-1)
            with handlers.mask(mask=in_range):
                probs_x_t = Vindex(probs_x)[x_prev, x_curr]
                x_next = pyro.sample(f"x_{t}",
                                     dist.Categorical(probs_x_t),
                                     infer={"enumerate": "parallel"})
                x_prev, x_curr = x_curr, x_next
                with tones_plate:
                    probs_y_t = probs_y[x_curr.squeeze(-1)]
                    pyro.sample(f"y_{t}",
                                dist.Bernoulli(probs_y_t),
                                obs=sequences[batch, t])
示例#18
0
 def model(data, state=0, address=""):
     """Recursively build a tree-shaped model from nested tuples of bools.

     A ``bool`` is a leaf: it is observed at site ``"leaf_<address>"``. A
     ``tuple`` is a branch: for each child (labelled "a".."g" via ``zip``) a
     Categorical site ``"branch_<address><letter>"`` is enumerated in parallel
     and its value becomes the child's ``state``.

     :param data: a ``bool`` or a ``tuple`` of further ``data`` values.
     :param state: index into the "p_leaf" / "p_branch" parameters.
     :param address: string of letters identifying the path to this node.
     """
     if isinstance(data, bool):
         leaf_probs = pyro.param("p_leaf", torch.ones(10))
         pyro.sample(f"leaf_{address}",
                     dist.Bernoulli(leaf_probs[state]),
                     obs=torch.tensor(1. if data else 0.))
         return

     assert isinstance(data, tuple)
     branch_probs = pyro.param("p_branch", torch.ones(10, 10))
     for child, letter in zip(data, "abcdefg"):
         child_address = address + letter
         next_state = pyro.sample(f"branch_{child_address}",
                                  dist.Categorical(branch_probs[state]),
                                  infer={"enumerate": "parallel"})
         model(child, next_state, child_address)
示例#19
0
文件: hmm.py 项目: pyro-ppl/pyro
def model_4(sequences, lengths, args, batch_size=None, include_prior=True):
    """HMM with two hidden chains where "x" depends on both "w" and "x".

    :param sequences: tensor whose shape unpacks as
        ``(num_sequences, max_length, data_dim)``.
    :param lengths: tensor of shape ``(num_sequences,)`` with per-sequence
        lengths, none exceeding ``max_length``.
    :param args: object providing ``hidden_dim`` and ``jit``.
    :param batch_size: subsample size forwarded to the "sequences" plate.
    :param include_prior: mask applied to the three prior sites.
    """
    with ignore_jit_warnings():
        num_sequences, max_length, data_dim = [int(s) for s in sequences.shape]
        assert lengths.shape == (num_sequences,)
        assert lengths.max() <= max_length
    # The hidden dimension is split between w and x.
    hidden_dim = int(args.hidden_dim**0.5)
    with handlers.mask(mask=include_prior):
        w_prior = dist.Dirichlet(0.9 * torch.eye(hidden_dim) + 0.1)
        probs_w = pyro.sample("probs_w", w_prior.to_event(1))
        x_prior = dist.Dirichlet(0.9 * torch.eye(hidden_dim) + 0.1)
        probs_x = pyro.sample(
            "probs_x", x_prior.expand_by([hidden_dim]).to_event(2)
        )
        y_prior = dist.Beta(0.1, 0.9).expand([hidden_dim, hidden_dim, data_dim])
        probs_y = pyro.sample("probs_y", y_prior.to_event(3))
    tones_plate = pyro.plate("tones", data_dim, dim=-1)
    with pyro.plate("sequences", num_sequences, batch_size, dim=-2) as batch:
        lengths = lengths[batch]
        # w and x start out as (the same) long tensor rather than a Python
        # int, so the indexing below always operates on tensors.
        w = torch.tensor(0, dtype=torch.long)
        x = w
        num_steps = max_length if args.jit else lengths.max()
        for t in pyro.markov(range(num_steps)):
            in_range = (t < lengths).unsqueeze(-1)
            with handlers.mask(mask=in_range):
                w = pyro.sample(
                    f"w_{t}",
                    dist.Categorical(probs_w[w]),
                    infer={"enumerate": "parallel"},
                )
                x_probs = Vindex(probs_x)[w, x]
                x = pyro.sample(
                    f"x_{t}",
                    dist.Categorical(x_probs),
                    infer={"enumerate": "parallel"},
                )
                with tones_plate as tones:
                    tone_probs = probs_y[w, x, tones]
                    pyro.sample(
                        f"y_{t}",
                        dist.Bernoulli(tone_probs),
                        obs=sequences[batch, t],
                    )
示例#20
0
文件: hmm.py 项目: pyro-ppl/pyro
def model_1(sequences, lengths, args, batch_size=None, include_prior=True):
    """HMM vectorized over a minibatch of sequences.

    :param sequences: tensor whose shape unpacks as
        ``(num_sequences, max_length, data_dim)``.
    :param lengths: tensor of shape ``(num_sequences,)`` with per-sequence
        lengths, none exceeding ``max_length``.
    :param args: object providing ``hidden_dim`` and ``jit``.
    :param batch_size: subsample size forwarded to the "sequences" plate.
    :param include_prior: mask applied to the "probs_x" and "probs_y" sites.
    """
    # The jit warnings about conversion to integer are silenced here because
    # all three numbers are the same across every invocation of the model.
    with ignore_jit_warnings():
        num_sequences, max_length, data_dim = [int(s) for s in sequences.shape]
        assert lengths.shape == (num_sequences,)
        assert lengths.max() <= max_length
    with handlers.mask(mask=include_prior):
        trans_prior = dist.Dirichlet(0.9 * torch.eye(args.hidden_dim) + 0.1)
        probs_x = pyro.sample("probs_x", trans_prior.to_event(1))
        tone_prior = dist.Beta(0.1, 0.9).expand([args.hidden_dim, data_dim])
        probs_y = pyro.sample("probs_y", tone_prior.to_event(2))
    tones_plate = pyro.plate("tones", data_dim, dim=-1)
    # batch_size items are subsampled out of num_sequences. The notes plate
    # already occupies dim=-1, so sequences are batched over dim=-2.
    with pyro.plate("sequences", num_sequences, batch_size, dim=-2) as batch:
        lengths = lengths[batch]
        x = 0
        # Without the jit the number of time steps may vary per call
        # (lengths.max()). With the jit a single fixed program structure is
        # kept for all minibatches, since every new structure would trigger
        # another jit compile stage.
        num_steps = max_length if args.jit else lengths.max()
        for t in pyro.markov(range(num_steps)):
            in_range = (t < lengths).unsqueeze(-1)
            with handlers.mask(mask=in_range):
                x = pyro.sample(
                    f"x_{t}",
                    dist.Categorical(probs_x[x]),
                    infer={"enumerate": "parallel"},
                )
                with tones_plate:
                    tone_probs = probs_y[x.squeeze(-1)]
                    pyro.sample(
                        f"y_{t}",
                        dist.Bernoulli(tone_probs),
                        obs=sequences[batch, t],
                    )
示例#21
0
文件: hmm.py 项目: pyro-ppl/pyro
def model_7(sequences, lengths, args, batch_size=None, include_prior=True):
    """HMM with the time dimension vectorized via ``pyro.vectorized_markov``.

    :param sequences: tensor whose shape unpacks as
        ``(num_sequences, max_length, data_dim)``.
    :param lengths: tensor of shape ``(num_sequences,)`` with per-sequence
        lengths, none exceeding ``max_length``.
    :param args: object providing ``hidden_dim`` and ``jit``.
    :param batch_size: subsample size forwarded to the "sequences" plate.
    :param include_prior: mask applied to the "probs_x" and "probs_y" sites.
    """
    with ignore_jit_warnings():
        num_sequences, max_length, data_dim = map(int, sequences.shape)
        assert lengths.shape == (num_sequences,)
        assert lengths.max() <= max_length
    with handlers.mask(mask=include_prior):
        probs_x = pyro.sample(
            "probs_x",
            dist.Dirichlet(0.9 * torch.eye(args.hidden_dim) + 0.1).to_event(1),
        )
        probs_y = pyro.sample(
            "probs_y",
            dist.Beta(0.1, 0.9).expand([args.hidden_dim, data_dim]).to_event(2),
        )
    tones_plate = pyro.plate("tones", data_dim, dim=-1)
    # Note that since we're using dim=-2 for the time dimension, we need
    # to batch sequences over a different dimension, here dim=-3.
    with pyro.plate("sequences", num_sequences, batch_size, dim=-3) as batch:
        lengths = lengths[batch]
        # Add a trailing axis so `batch` broadcasts against `t` when indexing
        # `sequences` below.
        batch = batch[:, None]
        x_prev = 0
        # To vectorize time dimension we use pyro.vectorized_markov(name=...).
        # With the help of Vindex and additional unsqueezes we can ensure that
        # dimensions line up properly.
        for t in pyro.vectorized_markov(
            name="time", size=int(max_length if args.jit else lengths.max()), dim=-2
        ):
            with handlers.mask(mask=(t < lengths.unsqueeze(-1)).unsqueeze(-1)):
                x_curr = pyro.sample(
                    "x_{}".format(t),
                    dist.Categorical(probs_x[x_prev]),
                    infer={"enumerate": "parallel"},
                )
                with tones_plate:
                    pyro.sample(
                        "y_{}".format(t),
                        dist.Bernoulli(probs_y[x_curr.squeeze(-1)]),
                        obs=Vindex(sequences)[batch, t],
                    )
                # Carry the state forward; without this every step would
                # index probs_x[0] and the chain would have no transitions.
                x_prev = x_curr
示例#22
0
    def guide(self):
        r"""
        Variational Distribution

        Samples the global sites "gain", "alpha", "pi", "lamda" and
        "proximity"; then, inside the "aois" plate (dim=-2) and under the data
        mask, "background_mean" and "background_std"; and, inside the nested
        "frames" plate (dim=-1), "background" plus, for every ``qdx`` in
        ``range(self.Q)`` and ``kdx`` in ``range(self.K)``, the spot-presence
        site ``m_k{kdx}_q{qdx}`` (enumerated in parallel) followed by the
        height, width, x and y sites masked by ``m > 0``.

        All distribution parameters are read from the Pyro param store with
        ``pyro.param(name)``; no initial values are supplied here.
        """
        # global parameters
        # Gamma parameterized as (loc * beta, beta).
        pyro.sample(
            "gain",
            dist.Gamma(
                pyro.param("gain_loc") * pyro.param("gain_beta"),
                pyro.param("gain_beta"),
            ),
        )
        # Dirichlet concentrations are built as mean * size.
        pyro.sample(
            "alpha",
            dist.Dirichlet(
                pyro.param("alpha_mean") *
                pyro.param("alpha_size")).to_event(1),
        )
        pyro.sample(
            "pi",
            dist.Dirichlet(pyro.param("pi_mean") *
                           pyro.param("pi_size")).to_event(1),
        )
        pyro.sample(
            "lamda",
            dist.Gamma(
                pyro.param("lamda_loc") * pyro.param("lamda_beta"),
                pyro.param("lamda_beta"),
            ).to_event(1),
        )
        # AffineBeta arguments here are (loc, size, lower bound 0, upper
        # bound (P + 1) / sqrt(12)) -- NOTE(review): argument meaning is
        # inferred from the parameter names; confirm against AffineBeta.
        pyro.sample(
            "proximity",
            AffineBeta(
                pyro.param("proximity_loc"),
                pyro.param("proximity_size"),
                0,
                (self.data.P + 1) / math.sqrt(12),
            ),
        )

        # aoi sites: plate of size self.data.Nt along dim=-2
        aois = pyro.plate(
            "aois",
            self.data.Nt,
            subsample=self.n,
            subsample_size=self.nbatch_size,
            dim=-2,
        )
        # time frames: plate of size self.data.F along dim=-1
        frames = pyro.plate(
            "frames",
            self.data.F,
            subsample=self.f,
            subsample_size=self.fbatch_size,
            dim=-1,
        )

        with aois as ndx:
            # Add a trailing axis so ndx broadcasts against fdx when both are
            # used as indices below.
            ndx = ndx[:, None]
            mask = Vindex(self.data.mask)[ndx].to(self.device)
            with handlers.mask(mask=mask):
                # Delta sites: point estimates read from the param store.
                pyro.sample(
                    "background_mean",
                    dist.Delta(
                        Vindex(
                            pyro.param("background_mean_loc"))[ndx,
                                                               0]).to_event(1),
                )
                pyro.sample(
                    "background_std",
                    dist.Delta(
                        Vindex(
                            pyro.param("background_std_loc"))[ndx,
                                                              0]).to_event(1),
                )
                with frames as fdx:
                    # sample background intensity
                    pyro.sample(
                        "background",
                        dist.Gamma(
                            Vindex(pyro.param("b_loc"))[ndx, fdx] *
                            Vindex(pyro.param("b_beta"))[ndx, fdx],
                            Vindex(pyro.param("b_beta"))[ndx, fdx],
                        ).to_event(1),
                    )

                    # Per-spot parameters are indexed as [kdx, ndx, fdx, qdx].
                    for qdx in range(self.Q):
                        for kdx in range(self.K):
                            # sample spot presence m
                            m = pyro.sample(
                                f"m_k{kdx}_q{qdx}",
                                dist.Bernoulli(
                                    Vindex(pyro.param("m_probs"))[kdx, ndx,
                                                                  fdx, qdx]),
                                infer={"enumerate": "parallel"},
                            )
                            # Spot variables only contribute where m > 0.
                            with handlers.mask(mask=m > 0):
                                # sample spot variables
                                pyro.sample(
                                    f"height_k{kdx}_q{qdx}",
                                    dist.Gamma(
                                        Vindex(pyro.param("h_loc"))[kdx, ndx,
                                                                    fdx, qdx] *
                                        Vindex(pyro.param("h_beta"))[kdx, ndx,
                                                                     fdx, qdx],
                                        Vindex(pyro.param("h_beta"))[kdx, ndx,
                                                                     fdx, qdx],
                                    ),
                                )
                                # Width bounds come from self.priors.
                                pyro.sample(
                                    f"width_k{kdx}_q{qdx}",
                                    AffineBeta(
                                        Vindex(pyro.param("w_mean"))[kdx, ndx,
                                                                     fdx, qdx],
                                        Vindex(pyro.param("w_size"))[kdx, ndx,
                                                                     fdx, qdx],
                                        self.priors["width_min"],
                                        self.priors["width_max"],
                                    ),
                                )
                                # NOTE(review): x and y both read the shared
                                # "size" parameter (width uses "w_size") --
                                # presumably intentional; confirm.
                                pyro.sample(
                                    f"x_k{kdx}_q{qdx}",
                                    AffineBeta(
                                        Vindex(pyro.param("x_mean"))[kdx, ndx,
                                                                     fdx, qdx],
                                        Vindex(pyro.param("size"))[kdx, ndx,
                                                                   fdx, qdx],
                                        -(self.data.P + 1) / 2,
                                        (self.data.P + 1) / 2,
                                    ),
                                )
                                pyro.sample(
                                    f"y_k{kdx}_q{qdx}",
                                    AffineBeta(
                                        Vindex(pyro.param("y_mean"))[kdx, ndx,
                                                                     fdx, qdx],
                                        Vindex(pyro.param("size"))[kdx, ndx,
                                                                   fdx, qdx],
                                        -(self.data.P + 1) / 2,
                                        (self.data.P + 1) / 2,
                                    ),
                                )
示例#23
0
def ttfb_model(data, control, Tmax):
    r"""
    Time-to-first-binding likelihood, Eq. 4 and Eq. 7 in::

      @article{friedman2015multi,
        title={Multi-wavelength single-molecule fluorescence analysis of transcription mechanisms},
        author={Friedman, Larry J and Gelles, Jeff},
        journal={Methods},
        volume={86},
        pages={27--36},
        year={2015},
        publisher={Elsevier}
      }

    Three parameters of shape ``(data.shape[0], 1)`` are registered: the
    positive rates ``ka`` and ``kns`` and the unit-interval fraction ``Af``.
    On-target entries use the rate ``ka + kns`` when the enumerated ``active``
    indicator is 1 and ``kns`` when it is 0; control entries always use
    ``kns``.  Entries equal to ``Tmax`` add ``-rate * Tmax`` to the log
    density, entries strictly inside ``(0, Tmax)`` are scored under an
    Exponential, and all remaining entries are masked out.

    :param data: time prior to the first binding at the target location
    :param control: time prior to the first binding at the control location;
        ``None`` skips the control terms
    :param Tmax: entire observation interval
    """
    n_boot = data.shape[0]
    ka = pyro.param(
        "ka",
        lambda: torch.full((n_boot, 1), 0.001),
        constraint=constraints.positive,
    )
    kns = pyro.param(
        "kns",
        lambda: torch.full((n_boot, 1), 0.001),
        constraint=constraints.positive,
    )
    Af = pyro.param(
        "Af",
        lambda: torch.full((n_boot, 1), 0.9),
        constraint=constraints.unit_interval,
    )
    # index 0 -> rate for active == 0, index 1 -> rate for active == 1
    rates = torch.stack([kns, ka + kns])

    # on-target data
    within = (data < Tmax) & (data > 0)
    at_tmax = data == Tmax
    # masked-out entries get a dummy value of 1.0 so that log_prob stays finite
    tau = data.masked_fill(~within, 1.0)
    bootstrap = pyro.plate("bootstrap", n_boot, dim=-2)
    molecules = pyro.plate("N", data.shape[1], dim=-1)
    with bootstrap as bdx, molecules:
        active = pyro.sample(
            "active", dist.Bernoulli(Af), infer={"enumerate": "parallel"}
        )
        # select the rate for every (enumerated state, bootstrap sample) pair
        rate = Vindex(rates)[active.long().squeeze(-1), bdx]
        with handlers.mask(mask=at_tmax):
            # -rate * Tmax is the log-survival of an Exponential at Tmax
            pyro.factor("Tmax", -rate * Tmax)
        with handlers.mask(mask=within):
            pyro.sample("tau", dist.Exponential(rate), obs=tau)

    # negative control data
    if control is None:
        return
    within_c = (control < Tmax) & (control > 0)
    at_tmax_c = control == Tmax
    tauc = control.masked_fill(~within_c, 1.0)
    bootstrap_c = pyro.plate("bootstrapc", control.shape[0], dim=-2)
    molecules_c = pyro.plate("Nc", control.shape[1], dim=-1)
    with bootstrap_c, molecules_c:
        with handlers.mask(mask=at_tmax_c):
            pyro.factor("Tmaxc", -kns * Tmax)
        with handlers.mask(mask=within_c):
            pyro.sample("tauc", dist.Exponential(kns), obs=tauc)
示例#24
0
    def model(self):
        r"""
        **Generative Model**

        ``z`` and ``theta`` are enumerated in parallel; the ``aois`` and
        ``frames`` plates are subsampled with ``self.n`` / ``self.f``.

        Model parameters:

        +-----------------+-----------+-------------------------------------+
        | Parameter       | Shape     | Description                         |
        +=================+===========+=====================================+
        | |g| - :math:`g` | (1,)      | camera gain                         |
        +-----------------+-----------+-------------------------------------+
        | |sigma| - |prox|| (1,)      | proximity                           |
        +-----------------+-----------+-------------------------------------+
        | ``lamda`` - |ld|| (1,)      | average rate of target-nonspecific  |
        |                 |           | binding                             |
        +-----------------+-----------+-------------------------------------+
        | ``pi`` - |pi|   | (1,)      | average binding probability of      |
        |                 |           | target-specific binding             |
        +-----------------+-----------+-------------------------------------+
        | |bg| - |b|      | (N, F)    | background intensity                |
        +-----------------+-----------+-------------------------------------+
        | |z| - :math:`z` | (N, F)    | target-specific spot presence       |
        +-----------------+-----------+-------------------------------------+
        | |t| - |theta|   | (N, F)    | target-specific spot index          |
        +-----------------+-----------+-------------------------------------+
        | |m| - :math:`m` | (K, N, F) | spot presence indicator             |
        +-----------------+-----------+-------------------------------------+
        | |h| - :math:`h` | (K, N, F) | spot intensity                      |
        +-----------------+-----------+-------------------------------------+
        | |w| - :math:`w` | (K, N, F) | spot width                          |
        +-----------------+-----------+-------------------------------------+
        | |x| - :math:`x` | (K, N, F) | spot position on x-axis             |
        +-----------------+-----------+-------------------------------------+
        | |y| - :math:`y` | (K, N, F) | spot position on y-axis             |
        +-----------------+-----------+-------------------------------------+
        | |D| - :math:`D` | |shape|   | observed images                     |
        +-----------------+-----------+-------------------------------------+

        .. |ps| replace:: :math:`p(\mathsf{specific})`
        .. |theta| replace:: :math:`\theta`
        .. |prox| replace:: :math:`\sigma^{xy}`
        .. |ld| replace:: :math:`\lambda`
        .. |b| replace:: :math:`b`
        .. |shape| replace:: (N, F, P, P)
        .. |sigma| replace:: ``proximity``
        .. |bg| replace:: ``background``
        .. |h| replace:: ``height``
        .. |w| replace:: ``width``
        .. |D| replace:: ``data``
        .. |m| replace:: ``m``
        .. |z| replace:: ``z``
        .. |t| replace:: ``theta``
        .. |x| replace:: ``x``
        .. |y| replace:: ``y``
        .. |pi| replace:: :math:`\pi`
        .. |g| replace:: ``gain``

        Full joint distribution:

        .. math::

            \begin{aligned}
                p(D, \phi) =~&p(g) p(\sigma^{xy}) p(\pi) p(\lambda)
                \prod_{\mathsf{AOI}} \left[ p(\mu^b) p(\sigma^b) \prod_{\mathsf{frame}}
                \left[ \vphantom{\prod_{F}} p(b | \mu^b, \sigma^b) p(z | \pi) p(\theta | z)
                \vphantom{\prod_{\substack{\mathsf{pixelX} \\ \mathsf{pixelY}}}} \cdot \right. \right. \\
                &\prod_{\mathsf{spot}} \left[ \vphantom{\prod_{F}} p(m | \theta, \lambda)
                p(h) p(w) p(x | \sigma^{xy}, \theta) p(y | \sigma^{xy}, \theta) \right] \left. \left.
                \prod_{\substack{\mathsf{pixelX} \\ \mathsf{pixelY}}} \sum_{\delta} p(\delta)
                p(D | \mu^I, g, \delta) \right] \right]
            \end{aligned}

        :math:`z` and :math:`\theta` marginalized joint distribution:

        .. math::

            \begin{aligned}
                \sum_{z, \theta} p(D, \phi) =~&p(g) p(\sigma^{xy}) p(\pi) p(\lambda)
                \prod_{\mathsf{AOI}} \left[ p(\mu^b) p(\sigma^b) \prod_{\mathsf{frame}}
                \left[ \vphantom{\prod_{F}} p(b | \mu^b, \sigma^b) \sum_{z} p(z | \pi) \sum_{\theta} p(\theta | z)
                \vphantom{\prod_{\substack{\mathsf{pixelX} \\ \mathsf{pixelY}}}} \cdot \right. \right. \\
                &\prod_{\mathsf{spot}} \left[ \vphantom{\prod_{F}} p(m | \theta, \lambda)
                p(h) p(w) p(x | \sigma^{xy}, \theta) p(y | \sigma^{xy}, \theta) \right] \left. \left.
                \prod_{\substack{\mathsf{pixelX} \\ \mathsf{pixelY}}} \sum_{\delta} p(\delta)
                p(D | \mu^I, g, \delta) \right] \right]
            \end{aligned}
        """
        # ---- global latent variables ----
        gain = pyro.sample("gain", dist.HalfNormal(self.gain_std))
        pi = expand_offtarget(
            pyro.sample(
                "pi", dist.Dirichlet(torch.ones(self.S + 1) / (self.S + 1))
            )
        )
        lamda = pyro.sample("lamda", dist.Exponential(self.lamda_rate))
        proximity = pyro.sample(
            "proximity", dist.Exponential(self.proximity_rate)
        )
        # last axis: index 0 is used when ``specific`` is 0, index 1 when it is 1
        size = torch.stack(
            (
                torch.full_like(proximity, 2.0),
                (((self.data.P + 1) / (2 * proximity))**2 - 1),
            ),
            dim=-1,
        )
        # shared support bounds of the x / y position priors
        xy_max = (self.data.P + 1) / 2
        xy_min = -xy_max

        # ---- plates (spots is iterated sequentially) ----
        spots = pyro.plate("spots", self.K)
        aois = pyro.plate(
            "aois",
            self.data.Nt,
            subsample=self.n,
            subsample_size=self.nbatch_size,
            dim=-2,
        )
        frames = pyro.plate(
            "frames",
            self.data.F,
            subsample=self.f,
            subsample_size=self.fbatch_size,
            dim=-1,
        )

        with aois as ndx:
            # add a trailing axis so that ndx broadcasts against fdx
            ndx = ndx[:, None]
            # per-AOI background mean and std
            background_mean = pyro.sample(
                "background_mean", dist.HalfNormal(self.background_mean_std)
            )
            background_std = pyro.sample(
                "background_std", dist.HalfNormal(self.background_std_std)
            )
            with frames as fdx:
                # fetch data
                obs, target_locs, is_ontarget = self.data.fetch(
                    ndx, fdx, self.cdx
                )
                # background intensity, Gamma parameterised by mean and std
                bg_concentration = (background_mean / background_std)**2
                bg_rate = background_mean / background_std**2
                background = pyro.sample(
                    "background", dist.Gamma(bg_concentration, bg_rate)
                )

                # hidden model state (1+S,)
                z_probs = Vindex(pi)[..., :, is_ontarget.long()]
                z = pyro.sample(
                    "z",
                    dist.Categorical(z_probs),
                    infer={"enumerate": "parallel"},
                )
                theta_probs = Vindex(probs_theta(self.K, self.device))[
                    torch.clamp(z, min=0, max=1)
                ]
                theta = pyro.sample(
                    "theta",
                    dist.Categorical(theta_probs),
                    infer={"enumerate": "parallel"},
                )
                onehot_theta = one_hot(theta, num_classes=1 + self.K)

                ms, heights, widths, xs, ys = [], [], [], [], []
                for kdx in spots:
                    specific = onehot_theta[..., 1 + kdx]
                    # spot presence
                    m_probs = Vindex(probs_m(lamda, self.K))[..., theta, kdx]
                    m = pyro.sample(f"m_{kdx}", dist.Bernoulli(m_probs))
                    ms.append(m)
                    # spot variables only contribute where the spot is present
                    with handlers.mask(mask=m > 0):
                        heights.append(
                            pyro.sample(
                                f"height_{kdx}",
                                dist.HalfNormal(self.height_std),
                            )
                        )
                        widths.append(
                            pyro.sample(
                                f"width_{kdx}",
                                AffineBeta(
                                    1.5, 2, self.width_min, self.width_max
                                ),
                            )
                        )
                        xy_size = Vindex(size)[..., specific]
                        xs.append(
                            pyro.sample(
                                f"x_{kdx}",
                                AffineBeta(0, xy_size, xy_min, xy_max),
                            )
                        )
                        ys.append(
                            pyro.sample(
                                f"y_{kdx}",
                                AffineBeta(0, xy_size, xy_min, xy_max),
                            )
                        )

                # observed data; per-spot tensors are stacked on a new last axis
                likelihood = KSMOGN(
                    torch.stack(heights, -1),
                    torch.stack(widths, -1),
                    torch.stack(xs, -1),
                    torch.stack(ys, -1),
                    target_locs,
                    background,
                    gain,
                    self.data.offset.samples,
                    self.data.offset.logits.to(self.dtype),
                    self.data.P,
                    torch.stack(torch.broadcast_tensors(*ms), -1),
                    self.use_pykeops,
                )
                pyro.sample("data", likelihood, obs=obs)
示例#25
0
    def guide(self):
        r"""
        **Variational Distribution**

        .. math::
            \begin{aligned}
                q(\phi \setminus \{z, \theta\}) =~&q(g) q(\sigma^{xy}) q(\pi) q(\lambda) \cdot \\
                &\prod_{\mathsf{AOI}} \left[ q(\mu^b) q(\sigma^b) \prod_{\mathsf{frame}}
                \left[ \vphantom{\prod_{F}} q(b) \prod_{\mathsf{spot}}
                q(m) q(h | m) q(w | m) q(x | m) q(y | m) \right] \right]
            \end{aligned}

        Every factor reads its variational parameters from the Pyro parameter
        store; ``m`` is enumerated in parallel and the remaining spot
        variables are masked by ``m > 0``.
        """
        # ---- global variables ----
        gain_loc = pyro.param("gain_loc")
        gain_beta = pyro.param("gain_beta")
        # Gamma(loc * beta, beta) has mean ``loc``
        pyro.sample("gain", dist.Gamma(gain_loc * gain_beta, gain_beta))

        pi_concentration = pyro.param("pi_mean") * pyro.param("pi_size")
        pyro.sample("pi", dist.Dirichlet(pi_concentration))

        lamda_loc = pyro.param("lamda_loc")
        lamda_beta = pyro.param("lamda_beta")
        pyro.sample("lamda", dist.Gamma(lamda_loc * lamda_beta, lamda_beta))

        pyro.sample(
            "proximity",
            AffineBeta(
                pyro.param("proximity_loc"),
                pyro.param("proximity_size"),
                0,
                (self.data.P + 1) / math.sqrt(12),
            ),
        )

        # shared support bounds of the x / y position factors
        xy_max = (self.data.P + 1) / 2
        xy_min = -xy_max

        # ---- plates (spots is iterated sequentially) ----
        spots = pyro.plate("spots", self.K)
        aois = pyro.plate(
            "aois",
            self.data.Nt,
            subsample=self.n,
            subsample_size=self.nbatch_size,
            dim=-2,
        )
        frames = pyro.plate(
            "frames",
            self.data.F,
            subsample=self.f,
            subsample_size=self.fbatch_size,
            dim=-1,
        )

        with aois as ndx:
            # add a trailing axis so that ndx broadcasts against fdx
            ndx = ndx[:, None]
            # point estimates for the per-AOI background statistics
            bg_mean_loc = Vindex(pyro.param("background_mean_loc"))[ndx, 0]
            pyro.sample("background_mean", dist.Delta(bg_mean_loc))
            bg_std_loc = Vindex(pyro.param("background_std_loc"))[ndx, 0]
            pyro.sample("background_std", dist.Delta(bg_std_loc))

            with frames as fdx:
                # background intensity
                b_loc = Vindex(pyro.param("b_loc"))[ndx, fdx]
                b_beta = Vindex(pyro.param("b_beta"))[ndx, fdx]
                pyro.sample("background", dist.Gamma(b_loc * b_beta, b_beta))

                for kdx in spots:
                    # spot presence m
                    m_probs = Vindex(pyro.param("m_probs"))[kdx, ndx, fdx]
                    m = pyro.sample(
                        f"m_{kdx}",
                        dist.Bernoulli(m_probs),
                        infer={"enumerate": "parallel"},
                    )
                    with handlers.mask(mask=m > 0):
                        # spot variables
                        h_loc = Vindex(pyro.param("h_loc"))[kdx, ndx, fdx]
                        h_beta = Vindex(pyro.param("h_beta"))[kdx, ndx, fdx]
                        pyro.sample(
                            f"height_{kdx}",
                            dist.Gamma(h_loc * h_beta, h_beta),
                        )

                        w_mean = Vindex(pyro.param("w_mean"))[kdx, ndx, fdx]
                        w_size = Vindex(pyro.param("w_size"))[kdx, ndx, fdx]
                        # NOTE(review): width bounds are hard-coded here;
                        # confirm they match the width prior of the model.
                        pyro.sample(
                            f"width_{kdx}",
                            AffineBeta(w_mean, w_size, 0.75, 2.25),
                        )

                        x_mean = Vindex(pyro.param("x_mean"))[kdx, ndx, fdx]
                        # x and y share the same "size" parameter
                        xy_size = Vindex(pyro.param("size"))[kdx, ndx, fdx]
                        pyro.sample(
                            f"x_{kdx}",
                            AffineBeta(x_mean, xy_size, xy_min, xy_max),
                        )

                        y_mean = Vindex(pyro.param("y_mean"))[kdx, ndx, fdx]
                        pyro.sample(
                            f"y_{kdx}",
                            AffineBeta(y_mean, xy_size, xy_min, xy_max),
                        )
示例#26
0
 def model(data):
     """Score ``data`` at site ``"x"`` under a Bernoulli whose probability is the ``"p"`` parameter."""
     # "p" is registered with an initial value of 0.5
     success_prob = pyro.param("p", torch.tensor(0.5))
     likelihood = dist.Bernoulli(success_prob)
     pyro.sample("x", likelihood, obs=data)
示例#27
0
    def model(self):
        r"""
        Generative Model

        Sample sites, in execution order:

        * global: ``gain``, ``alpha``, ``pi``, ``lamda``, ``proximity``;
        * per AOI (``aois`` plate, masked by ``self.data.mask``):
          ``background_mean``, ``background_std``;
        * per AOI and frame (``frames`` plate): ``background``, then for every
          ``qdx`` in ``range(self.Q)`` the parallel-enumerated ``z_q{qdx}`` and
          ``theta_q{qdx}``, and for every ``kdx`` in ``range(self.K)`` the
          sites ``m_k{kdx}_q{qdx}``, ``height_k{kdx}_q{qdx}``,
          ``width_k{kdx}_q{qdx}``, ``x_k{kdx}_q{qdx}``, ``y_k{kdx}_q{qdx}``;
        * the observed ``data`` site with a ``KSMOGN`` likelihood.

        The per-spot samples are regrouped into tensors whose last two axes
        are ``(Q, K)`` before being passed to the likelihood.
        """

        def _stack_by_q(values):
            """Regroup a flat per-spot sequence into a ``(..., Q, K)`` tensor.

            ``values`` is filled with ``qdx`` as the outer and ``kdx`` as the
            inner loop, so the entries for a given ``q`` occupy the slice
            ``[q * K, (q + 1) * K)``.
            """
            return torch.stack(
                [
                    torch.stack(values[q * self.K:(1 + q) * self.K], -1)
                    for q in range(self.Q)
                ],
                -2,
            )

        # global parameters
        gain = pyro.sample("gain", dist.HalfNormal(self.priors["gain_std"]))
        # NOTE(review): adding eye(Q) to a (Q, C) tensor only broadcasts when
        # the shapes are compatible (e.g. Q == C) - confirm with the caller.
        alpha = pyro.sample(
            "alpha",
            dist.Dirichlet(
                torch.ones((self.Q, self.data.C)) +
                torch.eye(self.Q) * 9).to_event(1),
        )
        pi = pyro.sample(
            "pi",
            dist.Dirichlet(torch.ones(
                (self.Q, self.S + 1)) / (self.S + 1)).to_event(1),
        )
        pi = expand_offtarget(pi)
        lamda = pyro.sample(
            "lamda",
            dist.Exponential(torch.full(
                (self.Q, ), self.priors["lamda_rate"])).to_event(1),
        )
        proximity = pyro.sample(
            "proximity", dist.Exponential(self.priors["proximity_rate"]))
        # last axis: index 0 is used when ``specific`` is 0, index 1 when it is 1
        size = torch.stack(
            (
                torch.full_like(proximity, 2.0),
                (((self.data.P + 1) / (2 * proximity))**2 - 1),
            ),
            dim=-1,
        )

        # aoi sites
        aois = pyro.plate(
            "aois",
            self.data.Nt,
            subsample=self.n,
            subsample_size=self.nbatch_size,
            dim=-2,
        )
        # time frames
        frames = pyro.plate(
            "frames",
            self.data.F,
            subsample=self.f,
            subsample_size=self.fbatch_size,
            dim=-1,
        )

        with aois as ndx:
            # add a trailing axis so that ndx broadcasts against fdx
            ndx = ndx[:, None]
            mask = Vindex(self.data.mask)[ndx].to(self.device)
            with handlers.mask(mask=mask):
                # background mean and std
                background_mean = pyro.sample(
                    "background_mean",
                    dist.HalfNormal(self.priors["background_mean_std"]).expand(
                        (self.data.C, )).to_event(1),
                )
                background_std = pyro.sample(
                    "background_std",
                    dist.HalfNormal(self.priors["background_std_std"]).expand(
                        (self.data.C, )).to_event(1),
                )
                with frames as fdx:
                    # fetch data
                    obs, target_locs, is_ontarget = self.data.fetch(
                        ndx.unsqueeze(-1), fdx.unsqueeze(-1),
                        torch.arange(self.data.C))
                    # sample background intensity
                    background = pyro.sample(
                        "background",
                        dist.Gamma(
                            (background_mean / background_std)**2,
                            background_mean / background_std**2,
                        ).to_event(1),
                    )

                    # flat lists, appended with qdx outer and kdx inner
                    ms, heights, widths, xs, ys = [], [], [], [], []
                    is_ontarget = is_ontarget.squeeze(-1)
                    for qdx in range(self.Q):
                        # sample hidden model state (1+S,)
                        z_probs = Vindex(pi)[..., qdx, :, is_ontarget.long()]
                        z = pyro.sample(
                            f"z_q{qdx}",
                            dist.Categorical(z_probs),
                            infer={"enumerate": "parallel"},
                        )
                        theta = pyro.sample(
                            f"theta_q{qdx}",
                            dist.Categorical(
                                Vindex(probs_theta(
                                    self.K, self.device))[torch.clamp(z,
                                                                      min=0,
                                                                      max=1)]),
                            infer={"enumerate": "parallel"},
                        )
                        onehot_theta = one_hot(theta, num_classes=1 + self.K)

                        for kdx in range(self.K):
                            specific = onehot_theta[..., 1 + kdx]
                            # spot presence
                            m = pyro.sample(
                                f"m_k{kdx}_q{qdx}",
                                dist.Bernoulli(
                                    Vindex(probs_m(lamda,
                                                   self.K))[..., qdx, theta,
                                                            kdx]),
                            )
                            # spot variables only contribute where m > 0
                            with handlers.mask(mask=m > 0):
                                # sample spot variables
                                height = pyro.sample(
                                    f"height_k{kdx}_q{qdx}",
                                    dist.HalfNormal(self.priors["height_std"]),
                                )
                                width = pyro.sample(
                                    f"width_k{kdx}_q{qdx}",
                                    AffineBeta(
                                        1.5,
                                        2,
                                        self.priors["width_min"],
                                        self.priors["width_max"],
                                    ),
                                )
                                x = pyro.sample(
                                    f"x_k{kdx}_q{qdx}",
                                    AffineBeta(
                                        0,
                                        Vindex(size)[..., specific],
                                        -(self.data.P + 1) / 2,
                                        (self.data.P + 1) / 2,
                                    ),
                                )
                                y = pyro.sample(
                                    f"y_k{kdx}_q{qdx}",
                                    AffineBeta(
                                        0,
                                        Vindex(size)[..., specific],
                                        -(self.data.P + 1) / 2,
                                        (self.data.P + 1) / 2,
                                    ),
                                )

                            # append
                            ms.append(m)
                            heights.append(height)
                            widths.append(width)
                            xs.append(x)
                            ys.append(y)

                    # Regroup every flat list into (..., Q, K).  All five use
                    # the same q * K offset; ``ys`` and ``ms`` previously
                    # started their slices at q * Q, which mis-grouped the
                    # spots whenever Q != K.
                    heights = _stack_by_q(heights)
                    widths = _stack_by_q(widths)
                    xs = _stack_by_q(xs)
                    ys = _stack_by_q(ys)
                    # enumerated m's can differ in shape; align before stacking
                    ms = _stack_by_q(torch.broadcast_tensors(*ms))
                    # observed data
                    pyro.sample(
                        "data",
                        KSMOGN(
                            heights,
                            widths,
                            xs,
                            ys,
                            target_locs,
                            background,
                            gain,
                            self.data.offset.samples,
                            self.data.offset.logits.to(self.dtype),
                            self.data.P,
                            ms,
                            alpha,
                            use_pykeops=self.use_pykeops,
                        ),
                        obs=obs,
                    )