Example #1
File: hmm.py  Project: zeta1999/pyro
def model_4(sequences, lengths, args, batch_size=None, include_prior=True):
    with ignore_jit_warnings():
        num_sequences, max_length, data_dim = map(int, sequences.shape)
        assert lengths.shape == (num_sequences, )
        assert lengths.max() <= max_length
    hidden_dim = int(args.hidden_dim**0.5)  # split between w and x
    with poutine.mask(mask=include_prior):
        probs_w = pyro.sample(
            "probs_w",
            dist.Dirichlet(0.9 * torch.eye(hidden_dim) + 0.1).to_event(1))
        probs_x = pyro.sample(
            "probs_x",
            dist.Dirichlet(0.9 * torch.eye(hidden_dim) + 0.1).expand_by(
                [hidden_dim]).to_event(2))
        probs_y = pyro.sample(
            "probs_y",
            dist.Beta(0.1, 0.9).expand([hidden_dim, hidden_dim,
                                        data_dim]).to_event(3))
    tones_plate = pyro.plate("tones", data_dim, dim=-1)
    with pyro.plate("sequences", num_sequences, batch_size, dim=-2) as batch:
        lengths = lengths[batch]
        # Note the broadcasting tricks here: we declare a hidden torch.arange and
        # ensure that w and x are always tensors so we can unsqueeze them below,
        # thus ensuring that the x sample sites have correct distribution shape.
        w = x = torch.tensor(0, dtype=torch.long)
        for t in pyro.markov(range(max_length if args.jit else lengths.max())):
            with poutine.mask(mask=(t < lengths).unsqueeze(-1)):
                w = pyro.sample("w_{}".format(t),
                                dist.Categorical(probs_w[w]),
                                infer={"enumerate": "parallel"})
                x = pyro.sample("x_{}".format(t),
                                dist.Categorical(Vindex(probs_x)[w, x]),
                                infer={"enumerate": "parallel"})
                with tones_plate as tones:
                    pyro.sample("y_{}".format(t),
                                dist.Bernoulli(probs_y[w, x, tones]),
                                obs=sequences[batch, t])
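
A minimal training sketch in the spirit of Pyro's examples/hmm.py: the discrete w_t/x_t chains are summed out by TraceEnum_ELBO while an AutoDelta guide fits point estimates for the probs_* sites. The learning rate, step count, batch size, and toy data shapes below are illustrative assumptions, not part of the original file.

import torch
from argparse import Namespace
import pyro
import pyro.distributions as dist
from pyro import poutine
from pyro.infer import SVI, TraceEnum_ELBO
from pyro.infer.autoguide import AutoDelta
from pyro.ops.indexing import Vindex           # also needed by model_4 above
from pyro.util import ignore_jit_warnings      # also needed by model_4 above


def train_model_4(sequences, lengths, args, num_steps=100):
    pyro.clear_param_store()
    # Point-estimate guide over the global probs_* sites only; the discrete
    # w_t/x_t sites are enumerated out by the ELBO.
    guide = AutoDelta(
        poutine.block(model_4,
                      expose_fn=lambda msg: msg["name"].startswith("probs_")))
    elbo = TraceEnum_ELBO(max_plate_nesting=2)  # plates: sequences (-2), tones (-1)
    svi = SVI(model_4, guide, pyro.optim.Adam({"lr": 0.05}), elbo)
    for step in range(num_steps):
        loss = svi.step(sequences, lengths, args, batch_size=8)
    return loss


# Hypothetical toy data: 10 binary sequences, 50 time steps, 88 tones.
args = Namespace(hidden_dim=16, jit=False)
sequences = torch.bernoulli(torch.rand(10, 50, 88))
lengths = torch.full((10,), 50, dtype=torch.long)
train_model_4(sequences, lengths, args)
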
Example #2
def model_1(data, history, vectorized):
    x_dim = 3
    init = pyro.param("init",
                      lambda: torch.rand(x_dim),
                      constraint=constraints.simplex)
    trans = pyro.param("trans",
                       lambda: torch.rand((x_dim, x_dim)),
                       constraint=constraints.simplex)
    locs = pyro.param("locs", lambda: torch.rand(x_dim))

    x_prev = None
    markov_loop = \
        pyro.vectorized_markov(name="time", size=len(data), dim=-2, history=history) if vectorized \
        else pyro.markov(range(len(data)), history=history)
    for i in markov_loop:
        x_curr = pyro.sample(
            "x_{}".format(i),
            dist.Categorical(
                init if isinstance(i, int) and i < 1 else trans[x_prev]))
        with pyro.plate("tones", data.shape[-1], dim=-1):
            pyro.sample("y_{}".format(i),
                        dist.Normal(Vindex(locs)[..., x_curr], 1.),
                        obs=data[i])
        x_prev = x_curr
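
model_1 here (like model_7 below) appears to come from Pyro's vectorized-Markov test suite; its only latent sites are discrete, so it can be fit with an empty guide and enumeration. A hedged sketch using the sequential path (vectorized=False); the data shape, optimizer, and step count are assumptions. The vectorized=True path would instead pair with pyro.infer.TraceMarkovEnum_ELBO.

import torch
import pyro
import pyro.distributions as dist
from torch.distributions import constraints    # also needed by model_1 above
from pyro.infer import SVI, TraceEnum_ELBO, config_enumerate
from pyro.ops.indexing import Vindex           # also needed by model_1 above


def empty_guide(data, history, vectorized):
    # All latent sites in model_1 are discrete and enumerated, so the guide
    # samples nothing; only the pyro.param statements are optimized.
    pass


data = torch.randn(8, 4)  # hypothetical (time, tones) observations
enum_model = config_enumerate(model_1, default="parallel")
svi = SVI(enum_model, empty_guide, pyro.optim.Adam({"lr": 0.01}),
          TraceEnum_ELBO(max_plate_nesting=2))
for step in range(100):
    svi.step(data, history=1, vectorized=False)
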
Example #3
def model_7(data, history, vectorized):
    w_dim, x_dim, y_dim = 2, 3, 2
    w_init = pyro.param("w_init",
                        lambda: torch.rand(w_dim),
                        constraint=constraints.simplex)
    w_trans = pyro.param("w_trans",
                         lambda: torch.rand((x_dim, w_dim)),
                         constraint=constraints.simplex)
    x_init = pyro.param("x_init",
                        lambda: torch.rand(x_dim),
                        constraint=constraints.simplex)
    x_trans = pyro.param("x_trans",
                         lambda: torch.rand((w_dim, x_dim)),
                         constraint=constraints.simplex)
    y_probs = pyro.param("y_probs",
                         lambda: torch.rand(w_dim, x_dim, y_dim),
                         constraint=constraints.simplex)

    w_prev = x_prev = None
    markov_loop = \
        pyro.vectorized_markov(name="time", size=len(data), dim=-2, history=history) if vectorized \
        else pyro.markov(range(len(data)), history=history)
    for i in markov_loop:
        w_curr = pyro.sample(
            "w_{}".format(i),
            dist.Categorical(
                w_init if isinstance(i, int) and i < 1 else w_trans[x_prev]))
        x_curr = pyro.sample(
            "x_{}".format(i),
            dist.Categorical(
                x_init if isinstance(i, int) and i < 1 else x_trans[w_prev]))
        with pyro.plate("tones", data.shape[-1], dim=-1):
            pyro.sample("y_{}".format(i),
                        dist.Categorical(Vindex(y_probs)[w_curr, x_curr]),
                        obs=data[i])
        x_prev, w_prev = x_curr, w_curr
Example #4
def torus_dbn(phis=None,
              psis=None,
              lengths=None,
              num_sequences=None,
              num_states=55,
              prior_conc=0.1,
              prior_loc=0.0,
              prior_length_shape=100.,
              prior_length_rate=100.,
              prior_kappa_min=10.,
              prior_kappa_max=1000.):
    # From https://pyro.ai/examples/hmm.html
    with ignore_jit_warnings():
        if lengths is not None:
            assert num_sequences is None
            num_sequences = int(lengths.shape[0])
        else:
            assert num_sequences is not None
    transition_probs = pyro.sample(
        'transition_probs',
        dist.Dirichlet(
            torch.ones(num_states, num_states, dtype=torch.float) *
            num_states).to_event(1))
    length_shape = pyro.sample('length_shape',
                               dist.HalfCauchy(prior_length_shape))
    length_rate = pyro.sample('length_rate',
                              dist.HalfCauchy(prior_length_rate))
    phi_locs = pyro.sample(
        'phi_locs',
        dist.VonMises(
            torch.ones(num_states, dtype=torch.float) * prior_loc,
            torch.ones(num_states, dtype=torch.float) *
            prior_conc).to_event(1))
    phi_kappas = pyro.sample(
        'phi_kappas',
        dist.Uniform(
            torch.ones(num_states, dtype=torch.float) * prior_kappa_min,
            torch.ones(num_states, dtype=torch.float) *
            prior_kappa_max).to_event(1))
    psi_locs = pyro.sample(
        'psi_locs',
        dist.VonMises(
            torch.ones(num_states, dtype=torch.float) * prior_loc,
            torch.ones(num_states, dtype=torch.float) *
            prior_conc).to_event(1))
    psi_kappas = pyro.sample(
        'psi_kappas',
        dist.Uniform(
            torch.ones(num_states, dtype=torch.float) * prior_kappa_min,
            torch.ones(num_states, dtype=torch.float) *
            prior_kappa_max).to_event(1))
    element_plate = pyro.plate('elements', 1, dim=-1)
    with pyro.plate('sequences', num_sequences, dim=-2) as batch:
        if lengths is not None:
            lengths = lengths[batch]
            obs_length = lengths.float().unsqueeze(-1)
        else:
            obs_length = None
        state = 0
        sam_lengths = pyro.sample('length',
                                  dist.TransformedDistribution(
                                      dist.GammaPoisson(
                                          length_shape, length_rate),
                                      AffineTransform(0., 1.)),
                                  obs=obs_length)
        if lengths is None:
            lengths = sam_lengths.squeeze(-1).long()
        for t in pyro.markov(range(lengths.max())):
            with poutine.mask(mask=(t < lengths).unsqueeze(-1)):
                state = pyro.sample(f'state_{t}',
                                    dist.Categorical(transition_probs[state]),
                                    infer={'enumerate': 'parallel'})
                if phis is not None:
                    obs_phi = Vindex(phis)[batch, t].unsqueeze(-1)
                else:
                    obs_phi = None
                if psis is not None:
                    obs_psi = Vindex(psis)[batch, t].unsqueeze(-1)
                else:
                    obs_psi = None
                with element_plate:
                    pyro.sample(f'phi_{t}',
                                dist.VonMises(phi_locs[state],
                                              phi_kappas[state]),
                                obs=obs_phi)
                    pyro.sample(f'psi_{t}',
                                dist.VonMises(psi_locs[state],
                                              psi_kappas[state]),
                                obs=obs_psi)
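
Because torus_dbn enumerates its discrete state_t sites, the remaining continuous sites can be handled by an autoguide. A rough usage sketch with synthetic dihedral-angle data; the data shapes, optimizer settings, and the choice of AutoDelta are assumptions for illustration only.

import math
import torch
import pyro
from pyro import poutine
from pyro.infer import SVI, TraceEnum_ELBO
from pyro.infer.autoguide import AutoDelta

# Hypothetical data: 5 sequences of 20 (phi, psi) angle pairs in (-pi, pi).
phis = (torch.rand(5, 20) * 2 - 1) * math.pi
psis = (torch.rand(5, 20) * 2 - 1) * math.pi
lengths = torch.full((5,), 20, dtype=torch.long)

pyro.clear_param_store()
# Point-estimate guide over the continuous sites; the enumerated state_t
# sites are hidden from the guide and summed out by the ELBO.
guide = AutoDelta(
    poutine.block(torus_dbn,
                  hide_fn=lambda msg: msg["name"].startswith("state")))
svi = SVI(torus_dbn, guide, pyro.optim.Adam({"lr": 0.05}),
          TraceEnum_ELBO(max_plate_nesting=2))
for step in range(100):
    svi.step(phis, psis, lengths)
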
Example #5
File: model.py  Project: youngshingjun/pyro
def model_generic(config):
    """Hierarchical mixed-effects hidden markov model"""

    MISSING = config["MISSING"]
    N_v = config["sizes"]["random"]
    N_state = config["sizes"]["state"]

    # initialize group-level random effect parameters
    if config["group"]["random"] == "discrete":
        probs_e_g = pyro.param("probs_e_group",
                               lambda: torch.randn((N_v, )).abs(),
                               constraint=constraints.simplex)
        theta_g = pyro.param("theta_group", lambda: torch.randn(
            (N_v, N_state**2)))
    elif config["group"]["random"] == "continuous":
        loc_g = torch.zeros((N_state**2, ))
        scale_g = torch.ones((N_state**2, ))

    # initialize individual-level random effect parameters
    N_c = config["sizes"]["group"]
    if config["individual"]["random"] == "discrete":
        probs_e_i = pyro.param("probs_e_individual",
                               lambda: torch.randn((
                                   N_c,
                                   N_v,
                               )).abs(),
                               constraint=constraints.simplex)
        theta_i = pyro.param("theta_individual", lambda: torch.randn(
            (N_c, N_v, N_state**2)))
    elif config["individual"]["random"] == "continuous":
        loc_i = torch.zeros((
            N_c,
            N_state**2,
        ))
        scale_i = torch.ones((
            N_c,
            N_state**2,
        ))

    # initialize likelihood parameters
    # observation 1: step size (step ~ Gamma)
    step_zi_param = pyro.param("step_zi_param", lambda: torch.ones(
        (N_state, 2)))
    step_concentration = pyro.param("step_param_concentration",
                                    lambda: torch.randn((N_state, )).abs(),
                                    constraint=constraints.positive)
    step_rate = pyro.param("step_param_rate",
                           lambda: torch.randn((N_state, )).abs(),
                           constraint=constraints.positive)

    # observation 2: step angle (angle ~ VonMises)
    angle_concentration = pyro.param("angle_param_concentration",
                                     lambda: torch.randn((N_state, )).abs(),
                                     constraint=constraints.positive)
    angle_loc = pyro.param("angle_param_loc", lambda: torch.randn(
        (N_state, )).abs())

    # observation 3: dive activity (omega ~ Beta)
    omega_zi_param = pyro.param("omega_zi_param", lambda: torch.ones(
        (N_state, 2)))
    omega_concentration0 = pyro.param("omega_param_concentration0",
                                      lambda: torch.randn((N_state, )).abs(),
                                      constraint=constraints.positive)
    omega_concentration1 = pyro.param("omega_param_concentration1",
                                      lambda: torch.randn((N_state, )).abs(),
                                      constraint=constraints.positive)

    # initialize gamma to uniform
    gamma = torch.zeros((N_state**2, ))

    N_c = config["sizes"]["group"]
    with pyro.plate("group", N_c, dim=-1):

        # group-level random effects
        if config["group"]["random"] == "discrete":
            # group-level discrete effect
            e_g = pyro.sample("e_g", dist.Categorical(probs_e_g))
            eps_g = Vindex(theta_g)[..., e_g, :]
        elif config["group"]["random"] == "continuous":
            eps_g = pyro.sample(
                "eps_g",
                dist.Normal(loc_g, scale_g).to_event(1),
            )  # infer={"num_samples": 10})
        else:
            eps_g = 0.

        # add group-level random effect to gamma
        gamma = gamma + eps_g

        N_s = config["sizes"]["individual"]
        with pyro.plate(
                "individual", N_s,
                dim=-2), poutine.mask(mask=config["individual"]["mask"]):

            # individual-level random effects
            if config["individual"]["random"] == "discrete":
                # individual-level discrete effect
                e_i = pyro.sample("e_i", dist.Categorical(probs_e_i))
                eps_i = Vindex(theta_i)[..., e_i, :]
                # assert eps_i.shape[-3:] == (1, N_c, N_state ** 2) and eps_i.shape[0] == N_v
            elif config["individual"]["random"] == "continuous":
                eps_i = pyro.sample(
                    "eps_i",
                    dist.Normal(loc_i, scale_i).to_event(1),
                )  # infer={"num_samples": 10})
            else:
                eps_i = 0.

            # add individual-level random effect to gamma
            gamma = gamma + eps_i

            y = torch.tensor(0).long()

            N_t = config["sizes"]["timesteps"]
            for t in pyro.markov(range(N_t)):
                with poutine.mask(mask=config["timestep"]["mask"][..., t]):
                    gamma_t = gamma  # per-timestep variable

                    # finally, reshape gamma as batch of transition matrices
                    gamma_t = gamma_t.reshape(
                        tuple(gamma_t.shape[:-1]) + (N_state, N_state))

                    # we've accounted for all effects, now actually compute gamma_y
                    gamma_y = Vindex(gamma_t)[..., y, :]
                    y = pyro.sample("y_{}".format(t),
                                    dist.Categorical(logits=gamma_y))

                    # observation 1: step size
                    step_dist = dist.Gamma(
                        concentration=Vindex(step_concentration)[..., y],
                        rate=Vindex(step_rate)[..., y])

                    # zero-inflation with MaskedMixture
                    step_zi = Vindex(step_zi_param)[..., y, :]
                    step_zi_mask = pyro.sample(
                        "step_zi_{}".format(t),
                        dist.Categorical(logits=step_zi),
                        obs=(config["observations"]["step"][...,
                                                            t] == MISSING))
                    step_zi_zero_dist = dist.Delta(v=torch.tensor(MISSING))
                    step_zi_dist = dist.MaskedMixture(step_zi_mask, step_dist,
                                                      step_zi_zero_dist)

                    pyro.sample("step_{}".format(t),
                                step_zi_dist,
                                obs=config["observations"]["step"][..., t])

                    # observation 2: step angle
                    angle_dist = dist.VonMises(
                        concentration=Vindex(angle_concentration)[..., y],
                        loc=Vindex(angle_loc)[..., y])
                    pyro.sample("angle_{}".format(t),
                                angle_dist,
                                obs=config["observations"]["angle"][..., t])

                    # observation 3: dive activity
                    omega_dist = dist.Beta(
                        concentration0=Vindex(omega_concentration0)[..., y],
                        concentration1=Vindex(omega_concentration1)[..., y])

                    # zero-inflation with MaskedMixture
                    omega_zi = Vindex(omega_zi_param)[..., y, :]
                    omega_zi_mask = pyro.sample(
                        "omega_zi_{}".format(t),
                        dist.Categorical(logits=omega_zi),
                        obs=(config["observations"]["omega"][...,
                                                             t] == MISSING))

                    omega_zi_zero_dist = dist.Delta(v=torch.tensor(MISSING))
                    omega_zi_dist = dist.MaskedMixture(omega_zi_mask,
                                                       omega_dist,
                                                       omega_zi_zero_dist)

                    pyro.sample("omega_{}".format(t),
                                omega_zi_dist,
                                obs=config["observations"]["omega"][..., t])
Example #6
    def model(self):
        r"""
        Generative Model
        """
        # global parameters
        gain = pyro.sample("gain", dist.HalfNormal(self.priors["gain_std"]))
        alpha = pyro.sample(
            "alpha",
            dist.Dirichlet(
                torch.ones((self.Q, self.data.C)) +
                torch.eye(self.Q) * 9).to_event(1),
        )
        pi = pyro.sample(
            "pi",
            dist.Dirichlet(torch.ones(
                (self.Q, self.S + 1)) / (self.S + 1)).to_event(1),
        )
        pi = expand_offtarget(pi)
        lamda = pyro.sample(
            "lamda",
            dist.Exponential(torch.full(
                (self.Q, ), self.priors["lamda_rate"])).to_event(1),
        )
        proximity = pyro.sample(
            "proximity", dist.Exponential(self.priors["proximity_rate"]))
        size = torch.stack(
            (
                torch.full_like(proximity, 2.0),
                (((self.data.P + 1) / (2 * proximity))**2 - 1),
            ),
            dim=-1,
        )

        # aoi sites
        aois = pyro.plate(
            "aois",
            self.data.Nt,
            subsample=self.n,
            subsample_size=self.nbatch_size,
            dim=-2,
        )
        # time frames
        frames = pyro.plate(
            "frames",
            self.data.F,
            subsample=self.f,
            subsample_size=self.fbatch_size,
            dim=-1,
        )

        with aois as ndx:
            ndx = ndx[:, None]
            mask = Vindex(self.data.mask)[ndx].to(self.device)
            with handlers.mask(mask=mask):
                # background mean and std
                background_mean = pyro.sample(
                    "background_mean",
                    dist.HalfNormal(self.priors["background_mean_std"]).expand(
                        (self.data.C, )).to_event(1),
                )
                background_std = pyro.sample(
                    "background_std",
                    dist.HalfNormal(self.priors["background_std_std"]).expand(
                        (self.data.C, )).to_event(1),
                )
                with frames as fdx:
                    # fetch data
                    obs, target_locs, is_ontarget = self.data.fetch(
                        ndx.unsqueeze(-1), fdx.unsqueeze(-1),
                        torch.arange(self.data.C))
                    # sample background intensity
                    background = pyro.sample(
                        "background",
                        dist.Gamma(
                            (background_mean / background_std)**2,
                            background_mean / background_std**2,
                        ).to_event(1),
                    )

                    ms, heights, widths, xs, ys = [], [], [], [], []
                    is_ontarget = is_ontarget.squeeze(-1)
                    for qdx in range(self.Q):
                        # sample hidden model state (1+S,)
                        z_probs = Vindex(pi)[..., qdx, :, is_ontarget.long()]
                        z = pyro.sample(
                            f"z_q{qdx}",
                            dist.Categorical(z_probs),
                            infer={"enumerate": "parallel"},
                        )
                        theta = pyro.sample(
                            f"theta_q{qdx}",
                            dist.Categorical(
                                Vindex(probs_theta(
                                    self.K, self.device))[torch.clamp(z,
                                                                      min=0,
                                                                      max=1)]),
                            infer={"enumerate": "parallel"},
                        )
                        onehot_theta = one_hot(theta, num_classes=1 + self.K)

                        for kdx in range(self.K):
                            specific = onehot_theta[..., 1 + kdx]
                            # spot presence
                            m = pyro.sample(
                                f"m_k{kdx}_q{qdx}",
                                dist.Bernoulli(
                                    Vindex(probs_m(lamda,
                                                   self.K))[..., qdx, theta,
                                                            kdx]),
                            )
                            with handlers.mask(mask=m > 0):
                                # sample spot variables
                                height = pyro.sample(
                                    f"height_k{kdx}_q{qdx}",
                                    dist.HalfNormal(self.priors["height_std"]),
                                )
                                width = pyro.sample(
                                    f"width_k{kdx}_q{qdx}",
                                    AffineBeta(
                                        1.5,
                                        2,
                                        self.priors["width_min"],
                                        self.priors["width_max"],
                                    ),
                                )
                                x = pyro.sample(
                                    f"x_k{kdx}_q{qdx}",
                                    AffineBeta(
                                        0,
                                        Vindex(size)[..., specific],
                                        -(self.data.P + 1) / 2,
                                        (self.data.P + 1) / 2,
                                    ),
                                )
                                y = pyro.sample(
                                    f"y_k{kdx}_q{qdx}",
                                    AffineBeta(
                                        0,
                                        Vindex(size)[..., specific],
                                        -(self.data.P + 1) / 2,
                                        (self.data.P + 1) / 2,
                                    ),
                                )

                            # append
                            ms.append(m)
                            heights.append(height)
                            widths.append(width)
                            xs.append(x)
                            ys.append(y)

                    heights = torch.stack(
                        [
                            torch.stack(heights[q * self.K:(1 + q) * self.K],
                                        -1) for q in range(self.Q)
                        ],
                        -2,
                    )
                    widths = torch.stack(
                        [
                            torch.stack(widths[q * self.K:(1 + q) * self.K],
                                        -1) for q in range(self.Q)
                        ],
                        -2,
                    )
                    xs = torch.stack(
                        [
                            torch.stack(xs[q * self.K:(1 + q) * self.K], -1)
                            for q in range(self.Q)
                        ],
                        -2,
                    )
                    ys = torch.stack(
                        [
                            torch.stack(ys[q * self.K:(1 + q) * self.K], -1)
                            for q in range(self.Q)
                        ],
                        -2,
                    )
                    ms = torch.broadcast_tensors(*ms)
                    ms = torch.stack(
                        [
                            torch.stack(ms[q * self.K:(1 + q) * self.K], -1)
                            for q in range(self.Q)
                        ],
                        -2,
                    )
                    # observed data
                    pyro.sample(
                        "data",
                        KSMOGN(
                            heights,
                            widths,
                            xs,
                            ys,
                            target_locs,
                            background,
                            gain,
                            self.data.offset.samples,
                            self.data.offset.logits.to(self.dtype),
                            self.data.P,
                            ms,
                            alpha,
                            use_pykeops=self.use_pykeops,
                        ),
                        obs=obs,
                    )
Example #7
    def guide(self):
        r"""
        Variational Distribution
        """
        # global parameters
        pyro.sample(
            "gain",
            dist.Gamma(
                pyro.param("gain_loc") * pyro.param("gain_beta"),
                pyro.param("gain_beta"),
            ),
        )
        pyro.sample(
            "alpha",
            dist.Dirichlet(
                pyro.param("alpha_mean") *
                pyro.param("alpha_size")).to_event(1),
        )
        pyro.sample(
            "pi",
            dist.Dirichlet(pyro.param("pi_mean") *
                           pyro.param("pi_size")).to_event(1),
        )
        pyro.sample(
            "lamda",
            dist.Gamma(
                pyro.param("lamda_loc") * pyro.param("lamda_beta"),
                pyro.param("lamda_beta"),
            ).to_event(1),
        )
        pyro.sample(
            "proximity",
            AffineBeta(
                pyro.param("proximity_loc"),
                pyro.param("proximity_size"),
                0,
                (self.data.P + 1) / math.sqrt(12),
            ),
        )

        # aoi sites
        aois = pyro.plate(
            "aois",
            self.data.Nt,
            subsample=self.n,
            subsample_size=self.nbatch_size,
            dim=-2,
        )
        # time frames
        frames = pyro.plate(
            "frames",
            self.data.F,
            subsample=self.f,
            subsample_size=self.fbatch_size,
            dim=-1,
        )

        with aois as ndx:
            ndx = ndx[:, None]
            mask = Vindex(self.data.mask)[ndx].to(self.device)
            with handlers.mask(mask=mask):
                pyro.sample(
                    "background_mean",
                    dist.Delta(
                        Vindex(
                            pyro.param("background_mean_loc"))[ndx,
                                                               0]).to_event(1),
                )
                pyro.sample(
                    "background_std",
                    dist.Delta(
                        Vindex(
                            pyro.param("background_std_loc"))[ndx,
                                                              0]).to_event(1),
                )
                with frames as fdx:
                    # sample background intensity
                    pyro.sample(
                        "background",
                        dist.Gamma(
                            Vindex(pyro.param("b_loc"))[ndx, fdx] *
                            Vindex(pyro.param("b_beta"))[ndx, fdx],
                            Vindex(pyro.param("b_beta"))[ndx, fdx],
                        ).to_event(1),
                    )

                    for qdx in range(self.Q):
                        for kdx in range(self.K):
                            # sample spot presence m
                            m = pyro.sample(
                                f"m_k{kdx}_q{qdx}",
                                dist.Bernoulli(
                                    Vindex(pyro.param("m_probs"))[kdx, ndx,
                                                                  fdx, qdx]),
                                infer={"enumerate": "parallel"},
                            )
                            with handlers.mask(mask=m > 0):
                                # sample spot variables
                                pyro.sample(
                                    f"height_k{kdx}_q{qdx}",
                                    dist.Gamma(
                                        Vindex(pyro.param("h_loc"))[kdx, ndx,
                                                                    fdx, qdx] *
                                        Vindex(pyro.param("h_beta"))[kdx, ndx,
                                                                     fdx, qdx],
                                        Vindex(pyro.param("h_beta"))[kdx, ndx,
                                                                     fdx, qdx],
                                    ),
                                )
                                pyro.sample(
                                    f"width_k{kdx}_q{qdx}",
                                    AffineBeta(
                                        Vindex(pyro.param("w_mean"))[kdx, ndx,
                                                                     fdx, qdx],
                                        Vindex(pyro.param("w_size"))[kdx, ndx,
                                                                     fdx, qdx],
                                        self.priors["width_min"],
                                        self.priors["width_max"],
                                    ),
                                )
                                pyro.sample(
                                    f"x_k{kdx}_q{qdx}",
                                    AffineBeta(
                                        Vindex(pyro.param("x_mean"))[kdx, ndx,
                                                                     fdx, qdx],
                                        Vindex(pyro.param("size"))[kdx, ndx,
                                                                   fdx, qdx],
                                        -(self.data.P + 1) / 2,
                                        (self.data.P + 1) / 2,
                                    ),
                                )
                                pyro.sample(
                                    f"y_k{kdx}_q{qdx}",
                                    AffineBeta(
                                        Vindex(pyro.param("y_mean"))[kdx, ndx,
                                                                     fdx, qdx],
                                        Vindex(pyro.param("size"))[kdx, ndx,
                                                                   fdx, qdx],
                                        -(self.data.P + 1) / 2,
                                        (self.data.P + 1) / 2,
                                    ),
                                )
Example #8
    def guide(self):
        """
        **Variational Distribution**
        """
        # global parameters
        pyro.sample(
            "gain",
            dist.Gamma(
                pyro.param("gain_loc") * pyro.param("gain_beta"),
                pyro.param("gain_beta"),
            ),
        )
        pyro.sample(
            "init",
            dist.Dirichlet(pyro.param("init_mean") * pyro.param("init_size")))
        pyro.sample(
            "trans",
            dist.Dirichlet(
                pyro.param("trans_mean") *
                pyro.param("trans_size")).to_event(1),
        )
        pyro.sample(
            "lamda",
            dist.Gamma(
                pyro.param("lamda_loc") * pyro.param("lamda_beta"),
                pyro.param("lamda_beta"),
            ),
        )
        pyro.sample(
            "proximity",
            AffineBeta(
                pyro.param("proximity_loc"),
                pyro.param("proximity_size"),
                0,
                (self.data.P + 1) / math.sqrt(12),
            ),
        )

        # spots
        spots = pyro.plate("spots", self.K)
        # aoi sites
        aois = pyro.plate(
            "aois",
            self.data.Nt,
            subsample=self.n,
            subsample_size=self.nbatch_size,
            dim=-2,
        )
        # time frames
        frames = (pyro.vectorized_markov(
            name="frames", size=self.data.F, dim=-1)
                  if self.vectorized else pyro.markov(range(self.data.F)))

        with aois as ndx:
            ndx = ndx[:, None]
            pyro.sample(
                "background_mean",
                dist.Delta(Vindex(pyro.param("background_mean_loc"))[ndx, 0]),
            )
            pyro.sample(
                "background_std",
                dist.Delta(Vindex(pyro.param("background_std_loc"))[ndx, 0]),
            )
            z_prev = None
            for fdx in frames:
                if self.vectorized:
                    fsx, fdx = fdx
                else:
                    fsx = fdx
                # sample background intensity
                pyro.sample(
                    f"background_{fsx}",
                    dist.Gamma(
                        Vindex(pyro.param("b_loc"))[ndx, fdx] *
                        Vindex(pyro.param("b_beta"))[ndx, fdx],
                        Vindex(pyro.param("b_beta"))[ndx, fdx],
                    ),
                )

                # sample hidden model state
                z_probs = (Vindex(pyro.param("z_trans"))[ndx, fdx, 0]
                           if isinstance(fdx, int) and fdx < 1 else Vindex(
                               pyro.param("z_trans"))[ndx, fdx, z_prev])
                z_curr = pyro.sample(
                    f"z_{fsx}",
                    dist.Categorical(z_probs),
                    infer={"enumerate": "parallel"},
                )

                for kdx in spots:
                    # spot presence
                    m_probs = Vindex(pyro.param("m_probs"))[z_curr, kdx, ndx,
                                                            fdx]
                    m = pyro.sample(
                        f"m_{kdx}_{fsx}",
                        dist.Categorical(
                            torch.stack((1 - m_probs, m_probs), -1)),
                        infer={"enumerate": "parallel"},
                    )
                    with handlers.mask(mask=m > 0):
                        # sample spot variables
                        pyro.sample(
                            f"height_{kdx}_{fsx}",
                            dist.Gamma(
                                Vindex(pyro.param("h_loc"))[kdx, ndx, fdx] *
                                Vindex(pyro.param("h_beta"))[kdx, ndx, fdx],
                                Vindex(pyro.param("h_beta"))[kdx, ndx, fdx],
                            ),
                        )
                        pyro.sample(
                            f"width_{kdx}_{fsx}",
                            AffineBeta(
                                Vindex(pyro.param("w_mean"))[kdx, ndx, fdx],
                                Vindex(pyro.param("w_size"))[kdx, ndx, fdx],
                                0.75,
                                2.25,
                            ),
                        )
                        pyro.sample(
                            f"x_{kdx}_{fsx}",
                            AffineBeta(
                                Vindex(pyro.param("x_mean"))[kdx, ndx, fdx],
                                Vindex(pyro.param("size"))[kdx, ndx, fdx],
                                -(self.data.P + 1) / 2,
                                (self.data.P + 1) / 2,
                            ),
                        )
                        pyro.sample(
                            f"y_{kdx}_{fsx}",
                            AffineBeta(
                                Vindex(pyro.param("y_mean"))[kdx, ndx, fdx],
                                Vindex(pyro.param("size"))[kdx, ndx, fdx],
                                -(self.data.P + 1) / 2,
                                (self.data.P + 1) / 2,
                            ),
                        )

                z_prev = z_curr
Example #9
    def model(self):
        """
        **Generative Model**
        """
        # global parameters
        gain = pyro.sample("gain", dist.HalfNormal(self.priors["gain_std"]))
        init = pyro.sample(
            "init",
            dist.Dirichlet(torch.ones(self.Q, self.S + 1) /
                           (self.S + 1)).to_event(1),
        )
        init = expand_offtarget(init)
        trans = pyro.sample(
            "trans",
            dist.Dirichlet(
                torch.ones(self.Q, self.S + 1, self.S + 1) /
                (self.S + 1)).to_event(2),
        )
        trans = expand_offtarget(trans)
        lamda = pyro.sample(
            "lamda",
            dist.Exponential(torch.full(
                (self.Q, ), self.priors["lamda_rate"])).to_event(1),
        )
        proximity = pyro.sample(
            "proximity", dist.Exponential(self.priors["proximity_rate"]))
        size = torch.stack(
            (
                torch.full_like(proximity, 2.0),
                (((self.data.P + 1) / (2 * proximity))**2 - 1),
            ),
            dim=-1,
        )

        # spots
        spots = pyro.plate("spots", self.K)
        # aoi sites
        aois = pyro.plate(
            "aois",
            self.data.Nt,
            subsample=self.n,
            subsample_size=self.nbatch_size,
            dim=-3,
        )
        # time frames
        frames = (pyro.vectorized_markov(
            name="frames", size=self.data.F, dim=-2)
                  if self.vectorized else pyro.markov(range(self.data.F)))
        # color channels
        channels = pyro.plate(
            "channels",
            self.data.C,
            dim=-1,
        )

        with channels as cdx, aois as ndx:
            ndx = ndx[:, None, None]
            mask = Vindex(self.data.mask)[ndx].to(self.device)
            with handlers.mask(mask=mask):
                # background mean and std
                background_mean = pyro.sample(
                    "background_mean",
                    dist.HalfNormal(self.priors["background_mean_std"]),
                )
                background_std = pyro.sample(
                    "background_std",
                    dist.HalfNormal(self.priors["background_std_std"]))
                z_prev = None
                for fdx in frames:
                    if self.vectorized:
                        fsx, fdx = fdx
                        fdx = torch.as_tensor(fdx)
                        fdx = fdx.unsqueeze(-1)
                    else:
                        fsx = fdx
                    # fetch data
                    obs, target_locs, is_ontarget = self.data.fetch(
                        ndx, fdx, cdx)
                    # sample background intensity
                    background = pyro.sample(
                        f"background_f{fsx}",
                        dist.Gamma(
                            (background_mean / background_std)**2,
                            background_mean / background_std**2,
                        ),
                    )

                    # sample hidden model state (1+S,)
                    z_probs = (Vindex(init)[..., cdx, :,
                                            is_ontarget.long()]
                               if z_prev is None else
                               Vindex(trans)[..., cdx, z_prev, :,
                                             is_ontarget.long()])
                    z_curr = pyro.sample(f"z_f{fsx}",
                                         dist.Categorical(z_probs))

                    theta = pyro.sample(
                        f"theta_f{fsx}",
                        dist.Categorical(
                            Vindex(probs_theta(
                                self.K, self.device))[torch.clamp(z_curr,
                                                                  min=0,
                                                                  max=1)]),
                        infer={"enumerate": "parallel"},
                    )
                    onehot_theta = one_hot(theta, num_classes=1 + self.K)

                    ms, heights, widths, xs, ys = [], [], [], [], []
                    for kdx in spots:
                        specific = onehot_theta[..., 1 + kdx]
                        # spot presence
                        m_probs = Vindex(probs_m(lamda, self.K))[..., cdx,
                                                                 theta, kdx]
                        m = pyro.sample(
                            f"m_k{kdx}_f{fsx}",
                            dist.Categorical(
                                torch.stack((1 - m_probs, m_probs), -1)),
                        )
                        with handlers.mask(mask=m > 0):
                            # sample spot variables
                            height = pyro.sample(
                                f"height_k{kdx}_f{fsx}",
                                dist.HalfNormal(self.priors["height_std"]),
                            )
                            width = pyro.sample(
                                f"width_k{kdx}_f{fsx}",
                                AffineBeta(
                                    1.5,
                                    2,
                                    self.priors["width_min"],
                                    self.priors["width_max"],
                                ),
                            )
                            x = pyro.sample(
                                f"x_k{kdx}_f{fsx}",
                                AffineBeta(
                                    0,
                                    Vindex(size)[..., specific],
                                    -(self.data.P + 1) / 2,
                                    (self.data.P + 1) / 2,
                                ),
                            )
                            y = pyro.sample(
                                f"y_k{kdx}_f{fsx}",
                                AffineBeta(
                                    0,
                                    Vindex(size)[..., specific],
                                    -(self.data.P + 1) / 2,
                                    (self.data.P + 1) / 2,
                                ),
                            )

                        # append
                        ms.append(m)
                        heights.append(height)
                        widths.append(width)
                        xs.append(x)
                        ys.append(y)

                    # observed data
                    pyro.sample(
                        f"data_f{fsx}",
                        KSMOGN(
                            torch.stack(heights, -1),
                            torch.stack(widths, -1),
                            torch.stack(xs, -1),
                            torch.stack(ys, -1),
                            target_locs,
                            background,
                            gain,
                            self.data.offset.samples,
                            self.data.offset.logits.to(self.dtype),
                            self.data.P,
                            torch.stack(torch.broadcast_tensors(*ms), -1),
                            use_pykeops=self.use_pykeops,
                        ),
                        obs=obs,
                    )
                    z_prev = z_curr
Example #10
    def guide(self):
        """
        **Variational Distribution**
        """
        # global parameters
        pyro.sample(
            "gain",
            dist.Gamma(
                pyro.param("gain_loc") * pyro.param("gain_beta"),
                pyro.param("gain_beta"),
            ),
        )
        pyro.sample(
            "init",
            dist.Dirichlet(pyro.param("init_mean") *
                           pyro.param("init_size")).to_event(1),
        )
        pyro.sample(
            "trans",
            dist.Dirichlet(
                pyro.param("trans_mean") *
                pyro.param("trans_size")).to_event(2),
        )
        pyro.sample(
            "lamda",
            dist.Gamma(
                pyro.param("lamda_loc") * pyro.param("lamda_beta"),
                pyro.param("lamda_beta"),
            ).to_event(1),
        )
        pyro.sample(
            "proximity",
            AffineBeta(
                pyro.param("proximity_loc"),
                pyro.param("proximity_size"),
                0,
                (self.data.P + 1) / math.sqrt(12),
            ),
        )

        # spots
        spots = pyro.plate("spots", self.K)
        # aoi sites
        aois = pyro.plate(
            "aois",
            self.data.Nt,
            subsample=self.n,
            subsample_size=self.nbatch_size,
            dim=-3,
        )
        # time frames
        frames = (pyro.vectorized_markov(
            name="frames", size=self.data.F, dim=-2)
                  if self.vectorized else pyro.markov(range(self.data.F)))
        # color channels
        channels = pyro.plate(
            "channels",
            self.data.C,
            dim=-1,
        )

        with channels as cdx, aois as ndx:
            ndx = ndx[:, None, None]
            mask = Vindex(self.data.mask)[ndx].to(self.device)
            with handlers.mask(mask=mask):
                pyro.sample(
                    "background_mean",
                    dist.Delta(
                        Vindex(pyro.param("background_mean_loc"))[ndx, 0,
                                                                  cdx]),
                )
                pyro.sample(
                    "background_std",
                    dist.Delta(
                        Vindex(pyro.param("background_std_loc"))[ndx, 0, cdx]),
                )
                z_prev = None
                for fdx in frames:
                    if self.vectorized:
                        fsx, fdx = fdx
                        fdx = torch.as_tensor(fdx)
                        fdx = fdx.unsqueeze(-1)
                    else:
                        fsx = fdx
                    # sample background intensity
                    pyro.sample(
                        f"background_f{fsx}",
                        dist.Gamma(
                            Vindex(pyro.param("b_loc"))[ndx, fdx, cdx] *
                            Vindex(pyro.param("b_beta"))[ndx, fdx, cdx],
                            Vindex(pyro.param("b_beta"))[ndx, fdx, cdx],
                        ),
                    )

                    # sample hidden model state
                    z_probs = (Vindex(pyro.param("z_trans"))[ndx, fdx, cdx, 0]
                               if z_prev is None else Vindex(
                                   pyro.param("z_trans"))[ndx, fdx, cdx,
                                                          z_prev])
                    z_curr = pyro.sample(
                        f"z_f{fsx}",
                        dist.Categorical(z_probs),
                        infer={"enumerate": "parallel"},
                    )

                    for kdx in spots:
                        # spot presence
                        m_probs = Vindex(pyro.param("m_probs"))[z_curr, kdx,
                                                                ndx, fdx, cdx]
                        m = pyro.sample(
                            f"m_k{kdx}_f{fsx}",
                            dist.Categorical(
                                torch.stack((1 - m_probs, m_probs), -1)),
                            infer={"enumerate": "parallel"},
                        )
                        with handlers.mask(mask=m > 0):
                            # sample spot variables
                            pyro.sample(
                                f"height_k{kdx}_f{fsx}",
                                dist.Gamma(
                                    Vindex(pyro.param("h_loc"))[kdx, ndx, fdx,
                                                                cdx] *
                                    Vindex(pyro.param("h_beta"))[kdx, ndx, fdx,
                                                                 cdx],
                                    Vindex(pyro.param("h_beta"))[kdx, ndx, fdx,
                                                                 cdx],
                                ),
                            )
                            pyro.sample(
                                f"width_k{kdx}_f{fsx}",
                                AffineBeta(
                                    Vindex(pyro.param("w_mean"))[kdx, ndx, fdx,
                                                                 cdx],
                                    Vindex(pyro.param("w_size"))[kdx, ndx, fdx,
                                                                 cdx],
                                    self.priors["width_min"],
                                    self.priors["width_max"],
                                ),
                            )
                            pyro.sample(
                                f"x_k{kdx}_f{fsx}",
                                AffineBeta(
                                    Vindex(pyro.param("x_mean"))[kdx, ndx, fdx,
                                                                 cdx],
                                    Vindex(pyro.param("size"))[kdx, ndx, fdx,
                                                               cdx],
                                    -(self.data.P + 1) / 2,
                                    (self.data.P + 1) / 2,
                                ),
                            )
                            pyro.sample(
                                f"y_k{kdx}_f{fsx}",
                                AffineBeta(
                                    Vindex(pyro.param("y_mean"))[kdx, ndx, fdx,
                                                                 cdx],
                                    Vindex(pyro.param("size"))[kdx, ndx, fdx,
                                                               cdx],
                                    -(self.data.P + 1) / 2,
                                    (self.data.P + 1) / 2,
                                ),
                            )

                    z_prev = z_curr
Example #11
    def model(self):
        r"""
        **Generative Model**

        Model parameters:

        +-----------------+-----------+-------------------------------------+
        | Parameter       | Shape     | Description                         |
        +=================+===========+=====================================+
        | |g| - :math:`g` | (1,)      | camera gain                         |
        +-----------------+-----------+-------------------------------------+
        | |sigma| - |prox|| (1,)      | proximity                           |
        +-----------------+-----------+-------------------------------------+
        | ``lamda`` - |ld|| (1,)      | average rate of target-nonspecific  |
        |                 |           | binding                             |
        +-----------------+-----------+-------------------------------------+
        | ``pi`` - |pi|   | (1,)      | average binding probability of      |
        |                 |           | target-specific binding             |
        +-----------------+-----------+-------------------------------------+
        | |bg| - |b|      | (N, F)    | background intensity                |
        +-----------------+-----------+-------------------------------------+
        | |z| - :math:`z` | (N, F)    | target-specific spot presence       |
        +-----------------+-----------+-------------------------------------+
        | |t| - |theta|   | (N, F)    | target-specific spot index          |
        +-----------------+-----------+-------------------------------------+
        | |m| - :math:`m` | (K, N, F) | spot presence indicator             |
        +-----------------+-----------+-------------------------------------+
        | |h| - :math:`h` | (K, N, F) | spot intensity                      |
        +-----------------+-----------+-------------------------------------+
        | |w| - :math:`w` | (K, N, F) | spot width                          |
        +-----------------+-----------+-------------------------------------+
        | |x| - :math:`x` | (K, N, F) | spot position on x-axis             |
        +-----------------+-----------+-------------------------------------+
        | |y| - :math:`y` | (K, N, F) | spot position on y-axis             |
        +-----------------+-----------+-------------------------------------+
        | |D| - :math:`D` | |shape|   | observed images                     |
        +-----------------+-----------+-------------------------------------+

        .. |ps| replace:: :math:`p(\mathsf{specific})`
        .. |theta| replace:: :math:`\theta`
        .. |prox| replace:: :math:`\sigma^{xy}`
        .. |ld| replace:: :math:`\lambda`
        .. |b| replace:: :math:`b`
        .. |shape| replace:: (N, F, P, P)
        .. |sigma| replace:: ``proximity``
        .. |bg| replace:: ``background``
        .. |h| replace:: ``height``
        .. |w| replace:: ``width``
        .. |D| replace:: ``data``
        .. |m| replace:: ``m``
        .. |z| replace:: ``z``
        .. |t| replace:: ``theta``
        .. |x| replace:: ``x``
        .. |y| replace:: ``y``
        .. |pi| replace:: :math:`\pi`
        .. |g| replace:: ``gain``

        Full joint distribution:

        .. math::

            \begin{aligned}
                p(D, \phi) =~&p(g) p(\sigma^{xy}) p(\pi) p(\lambda)
                \prod_{\mathsf{AOI}} \left[ p(\mu^b) p(\sigma^b) \prod_{\mathsf{frame}}
                \left[ \vphantom{\prod_{F}} p(b | \mu^b, \sigma^b) p(z | \pi) p(\theta | z)
                \vphantom{\prod_{\substack{\mathsf{pixelX} \\ \mathsf{pixelY}}}} \cdot \right. \right. \\
                &\prod_{\mathsf{spot}} \left[ \vphantom{\prod_{F}} p(m | \theta, \lambda)
                p(h) p(w) p(x | \sigma^{xy}, \theta) p(y | \sigma^{xy}, \theta) \right] \left. \left.
                \prod_{\substack{\mathsf{pixelX} \\ \mathsf{pixelY}}} \sum_{\delta} p(\delta)
                p(D | \mu^I, g, \delta) \right] \right]
            \end{aligned}

        :math:`z` and :math:`\theta` marginalized joint distribution:

        .. math::

            \begin{aligned}
                \sum_{z, \theta} p(D, \phi) =~&p(g) p(\sigma^{xy}) p(\pi) p(\lambda)
                \prod_{\mathsf{AOI}} \left[ p(\mu^b) p(\sigma^b) \prod_{\mathsf{frame}}
                \left[ \vphantom{\prod_{F}} p(b | \mu^b, \sigma^b) \sum_{z} p(z | \pi) \sum_{\theta} p(\theta | z)
                \vphantom{\prod_{\substack{\mathsf{pixelX} \\ \mathsf{pixelY}}}} \cdot \right. \right. \\
                &\prod_{\mathsf{spot}} \left[ \vphantom{\prod_{F}} p(m | \theta, \lambda)
                p(h) p(w) p(x | \sigma^{xy}, \theta) p(y | \sigma^{xy}, \theta) \right] \left. \left.
                \prod_{\substack{\mathsf{pixelX} \\ \mathsf{pixelY}}} \sum_{\delta} p(\delta)
                p(D | \mu^I, g, \delta) \right] \right]
            \end{aligned}
        """
        # global parameters
        gain = pyro.sample("gain", dist.HalfNormal(self.gain_std))
        pi = pyro.sample("pi",
                         dist.Dirichlet(torch.ones(self.S + 1) / (self.S + 1)))
        pi = expand_offtarget(pi)
        lamda = pyro.sample("lamda", dist.Exponential(self.lamda_rate))
        proximity = pyro.sample("proximity",
                                dist.Exponential(self.proximity_rate))
        size = torch.stack(
            (
                torch.full_like(proximity, 2.0),
                (((self.data.P + 1) / (2 * proximity))**2 - 1),
            ),
            dim=-1,
        )

        # spots
        spots = pyro.plate("spots", self.K)
        # aoi sites
        aois = pyro.plate(
            "aois",
            self.data.Nt,
            subsample=self.n,
            subsample_size=self.nbatch_size,
            dim=-2,
        )
        # time frames
        frames = pyro.plate(
            "frames",
            self.data.F,
            subsample=self.f,
            subsample_size=self.fbatch_size,
            dim=-1,
        )

        with aois as ndx:
            ndx = ndx[:, None]
            # background mean and std
            background_mean = pyro.sample(
                "background_mean", dist.HalfNormal(self.background_mean_std))
            background_std = pyro.sample(
                "background_std", dist.HalfNormal(self.background_std_std))
            with frames as fdx:
                # fetch data
                obs, target_locs, is_ontarget = self.data.fetch(
                    ndx, fdx, self.cdx)
                # sample background intensity
                background = pyro.sample(
                    "background",
                    dist.Gamma(
                        (background_mean / background_std)**2,
                        background_mean / background_std**2,
                    ),
                )

                # sample hidden model state (1+S,)
                z = pyro.sample(
                    "z",
                    dist.Categorical(Vindex(pi)[..., :,
                                                is_ontarget.long()]),
                    infer={"enumerate": "parallel"},
                )
                theta = pyro.sample(
                    "theta",
                    dist.Categorical(
                        Vindex(probs_theta(self.K,
                                           self.device))[torch.clamp(z,
                                                                     min=0,
                                                                     max=1)]),
                    infer={"enumerate": "parallel"},
                )
                onehot_theta = one_hot(theta, num_classes=1 + self.K)

                ms, heights, widths, xs, ys = [], [], [], [], []
                for kdx in spots:
                    specific = onehot_theta[..., 1 + kdx]
                    # spot presence
                    m = pyro.sample(
                        f"m_{kdx}",
                        dist.Bernoulli(
                            Vindex(probs_m(lamda, self.K))[..., theta, kdx]),
                    )
                    with handlers.mask(mask=m > 0):
                        # sample spot variables
                        height = pyro.sample(
                            f"height_{kdx}",
                            dist.HalfNormal(self.height_std),
                        )
                        width = pyro.sample(
                            f"width_{kdx}",
                            AffineBeta(
                                1.5,
                                2,
                                self.width_min,
                                self.width_max,
                            ),
                        )
                        x = pyro.sample(
                            f"x_{kdx}",
                            AffineBeta(
                                0,
                                Vindex(size)[..., specific],
                                -(self.data.P + 1) / 2,
                                (self.data.P + 1) / 2,
                            ),
                        )
                        y = pyro.sample(
                            f"y_{kdx}",
                            AffineBeta(
                                0,
                                Vindex(size)[..., specific],
                                -(self.data.P + 1) / 2,
                                (self.data.P + 1) / 2,
                            ),
                        )

                    # collect per-spot latent variables for the image likelihood below
                    ms.append(m)
                    heights.append(height)
                    widths.append(width)
                    xs.append(x)
                    ys.append(y)

                # observed data
                pyro.sample(
                    "data",
                    KSMOGN(
                        torch.stack(heights, -1),
                        torch.stack(widths, -1),
                        torch.stack(xs, -1),
                        torch.stack(ys, -1),
                        target_locs,
                        background,
                        gain,
                        self.data.offset.samples,
                        self.data.offset.logits.to(self.dtype),
                        self.data.P,
                        torch.stack(torch.broadcast_tensors(*ms), -1),
                        self.use_pykeops,
                    ),
                    obs=obs,
                )
Example #12
    def guide(self):
        r"""
        **Variational Distribution**

        .. math::
            \begin{aligned}
                q(\phi \setminus \{z, \theta\}) =~&q(g) q(\sigma^{xy}) q(\pi) q(\lambda) \cdot \\
                &\prod_{\mathsf{AOI}} \left[ q(\mu^b) q(\sigma^b) \prod_{\mathsf{frame}}
                \left[ \vphantom{\prod_{F}} q(b) \prod_{\mathsf{spot}}
                q(m) q(h | m) q(w | m) q(x | m) q(y | m) \right] \right]
            \end{aligned}
        """
        # global parameters
        pyro.sample(
            "gain",
            dist.Gamma(
                pyro.param("gain_loc") * pyro.param("gain_beta"),
                pyro.param("gain_beta"),
            ),
        )
        pyro.sample(
            "pi",
            dist.Dirichlet(pyro.param("pi_mean") * pyro.param("pi_size")))
        pyro.sample(
            "lamda",
            dist.Gamma(
                pyro.param("lamda_loc") * pyro.param("lamda_beta"),
                pyro.param("lamda_beta"),
            ),
        )
        pyro.sample(
            "proximity",
            AffineBeta(
                pyro.param("proximity_loc"),
                pyro.param("proximity_size"),
                0,
                (self.data.P + 1) / math.sqrt(12),
            ),
        )

        # spots
        spots = pyro.plate("spots", self.K)
        # aoi sites
        aois = pyro.plate(
            "aois",
            self.data.Nt,
            subsample=self.n,
            subsample_size=self.nbatch_size,
            dim=-2,
        )
        # time frames
        frames = pyro.plate(
            "frames",
            self.data.F,
            subsample=self.f,
            subsample_size=self.fbatch_size,
            dim=-1,
        )

        with aois as ndx:
            ndx = ndx[:, None]
            pyro.sample(
                "background_mean",
                dist.Delta(Vindex(pyro.param("background_mean_loc"))[ndx, 0]),
            )
            pyro.sample(
                "background_std",
                dist.Delta(Vindex(pyro.param("background_std_loc"))[ndx, 0]),
            )
            with frames as fdx:
                # sample background intensity
                pyro.sample(
                    "background",
                    dist.Gamma(
                        Vindex(pyro.param("b_loc"))[ndx, fdx] *
                        Vindex(pyro.param("b_beta"))[ndx, fdx],
                        Vindex(pyro.param("b_beta"))[ndx, fdx],
                    ),
                )

                for kdx in spots:
                    # sample spot presence m
                    m = pyro.sample(
                        f"m_{kdx}",
                        dist.Bernoulli(
                            Vindex(pyro.param("m_probs"))[kdx, ndx, fdx]),
                        infer={"enumerate": "parallel"},
                    )
                    with handlers.mask(mask=m > 0):
                        # sample spot variables
                        pyro.sample(
                            f"height_{kdx}",
                            dist.Gamma(
                                Vindex(pyro.param("h_loc"))[kdx, ndx, fdx] *
                                Vindex(pyro.param("h_beta"))[kdx, ndx, fdx],
                                Vindex(pyro.param("h_beta"))[kdx, ndx, fdx],
                            ),
                        )
                        pyro.sample(
                            f"width_{kdx}",
                            AffineBeta(
                                Vindex(pyro.param("w_mean"))[kdx, ndx, fdx],
                                Vindex(pyro.param("w_size"))[kdx, ndx, fdx],
                                0.75,
                                2.25,
                            ),
                        )
                        pyro.sample(
                            f"x_{kdx}",
                            AffineBeta(
                                Vindex(pyro.param("x_mean"))[kdx, ndx, fdx],
                                Vindex(pyro.param("size"))[kdx, ndx, fdx],
                                -(self.data.P + 1) / 2,
                                (self.data.P + 1) / 2,
                            ),
                        )
                        pyro.sample(
                            f"y_{kdx}",
                            AffineBeta(
                                Vindex(pyro.param("y_mean"))[kdx, ndx, fdx],
                                Vindex(pyro.param("size"))[kdx, ndx, fdx],
                                -(self.data.P + 1) / 2,
                                (self.data.P + 1) / 2,
                            ),
                        )
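The model above enumerates ``z`` and ``theta`` in parallel and the guide enumerates ``m``, so training requires an enumeration-aware ELBO. The following is a minimal, hypothetical training sketch (not part of the original source): ``cosmos`` stands in for an instance of the class whose ``model()`` and ``guide()`` methods are shown above, the guide parameters (``gain_loc``, ``b_loc``, etc.) are assumed to have been initialized elsewhere, and the learning rate and step count are placeholders.

import pyro
import pyro.optim
from pyro.infer import SVI, TraceEnum_ELBO

# assumed: `cosmos` exposes the model()/guide() methods shown above and its
# guide parameters have already been registered in the param store
elbo = TraceEnum_ELBO(max_plate_nesting=2)  # plates: "aois" at dim=-2, "frames" at dim=-1
svi = SVI(cosmos.model, cosmos.guide, pyro.optim.Adam({"lr": 0.005}), elbo)
for step in range(1000):
    loss = svi.step()
    if step % 100 == 0:
        print(f"step {step}: loss = {loss:.1f}")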
Example #13
    def m_probs(self) -> torch.Tensor:
        r"""
        Posterior spot presence probability :math:`q(m=1, z=z_\mathsf{MAP})`.
        """
        return Vindex(torch.permute(pyro.param("m_probs").data,
                                    (1, 2, 3, 0)))[..., self.z_map.long()]
Example #14
    def fetch(self, ndx, fdx, cdx):
        return (
            Vindex(self.images)[ndx, fdx, cdx].to(self.device),
            Vindex(self.xy)[ndx, fdx, cdx].to(self.device),
            Vindex(self.is_ontarget)[ndx].to(self.device),
        )
Example #15
def save_stats(model, path, CI=0.95, save_matlab=False):
    # global parameters
    global_params = model._global_params
    summary = pd.DataFrame(
        index=global_params,
        columns=["Mean", f"{int(100*CI)}% LL", f"{int(100*CI)}% UL"],
    )
    # local parameters
    local_params = [
        "height",
        "width",
        "x",
        "y",
        "background",
    ]

    ci_stats = defaultdict(partial(defaultdict, list))
    num_samples = 10000
    for param in global_params:
        if param == "gain":
            fn = dist.Gamma(
                pyro.param("gain_loc") * pyro.param("gain_beta"),
                pyro.param("gain_beta"),
            )
        elif param == "pi":
            fn = dist.Dirichlet(pyro.param("pi_mean") * pyro.param("pi_size"))
        elif param == "lamda":
            fn = dist.Gamma(
                pyro.param("lamda_loc") * pyro.param("lamda_beta"),
                pyro.param("lamda_beta"),
            )
        elif param == "proximity":
            fn = AffineBeta(
                pyro.param("proximity_loc"),
                pyro.param("proximity_size"),
                0,
                (model.data.P + 1) / math.sqrt(12),
            )
        elif param == "trans":
            fn = dist.Dirichlet(
                pyro.param("trans_mean") * pyro.param("trans_size")
            ).to_event(1)
        else:
            raise NotImplementedError
        samples = fn.sample((num_samples,)).data.squeeze()
        ci_stats[param] = {}
        LL, UL = hpdi(
            samples,
            CI,
            dim=0,
        )
        ci_stats[param]["LL"] = LL.cpu()
        ci_stats[param]["UL"] = UL.cpu()
        ci_stats[param]["Mean"] = fn.mean.data.squeeze().cpu()

        # calculate Keq
        if param == "pi":
            ci_stats["Keq"] = {}
            LL, UL = hpdi(samples[:, 1] / (1 - samples[:, 1]), CI, dim=0)
            ci_stats["Keq"]["LL"] = LL.cpu()
            ci_stats["Keq"]["UL"] = UL.cpu()
            ci_stats["Keq"]["Mean"] = (samples[:, 1] / (1 - samples[:, 1])).mean().cpu()

    # this does not need to be very accurate
    num_samples = 1000
    for param in local_params:
        LL, UL, Mean = [], [], []
        for ndx in torch.split(torch.arange(model.data.Nt), model.nbatch_size):
            ndx = ndx[:, None]
            kdx = torch.arange(model.K)[:, None, None]
            ll, ul, mean = [], [], []
            for fdx in torch.split(torch.arange(model.data.F), model.fbatch_size):
                if param == "background":
                    fn = dist.Gamma(
                        Vindex(pyro.param("b_loc"))[ndx, fdx]
                        * Vindex(pyro.param("b_beta"))[ndx, fdx],
                        Vindex(pyro.param("b_beta"))[ndx, fdx],
                    )
                elif param == "height":
                    fn = dist.Gamma(
                        Vindex(pyro.param("h_loc"))[kdx, ndx, fdx]
                        * Vindex(pyro.param("h_beta"))[kdx, ndx, fdx],
                        Vindex(pyro.param("h_beta"))[kdx, ndx, fdx],
                    )
                elif param == "width":
                    fn = AffineBeta(
                        Vindex(pyro.param("w_mean"))[kdx, ndx, fdx],
                        Vindex(pyro.param("w_size"))[kdx, ndx, fdx],
                        0.75,
                        2.25,
                    )
                elif param == "x":
                    fn = AffineBeta(
                        Vindex(pyro.param("x_mean"))[kdx, ndx, fdx],
                        Vindex(pyro.param("size"))[kdx, ndx, fdx],
                        -(model.data.P + 1) / 2,
                        (model.data.P + 1) / 2,
                    )
                elif param == "y":
                    fn = AffineBeta(
                        Vindex(pyro.param("y_mean"))[kdx, ndx, fdx],
                        Vindex(pyro.param("size"))[kdx, ndx, fdx],
                        -(model.data.P + 1) / 2,
                        (model.data.P + 1) / 2,
                    )
                else:
                    raise NotImplementedError
                samples = fn.sample((num_samples,)).data
                l, u = hpdi(
                    samples,
                    CI,
                    dim=0,
                )
                m = fn.mean.data
                ll.append(l)
                ul.append(u)
                mean.append(m)
            else:
                LL.append(torch.cat(ll, -1))
                UL.append(torch.cat(ul, -1))
                Mean.append(torch.cat(mean, -1))
        else:
            ci_stats[param]["LL"] = torch.cat(LL, -2).cpu()
            ci_stats[param]["UL"] = torch.cat(UL, -2).cpu()
            ci_stats[param]["Mean"] = torch.cat(Mean, -2).cpu()

    for param in global_params:
        if param == "pi":
            summary.loc[param, "Mean"] = ci_stats[param]["Mean"][1].item()
            summary.loc[param, "95% LL"] = ci_stats[param]["LL"][1].item()
            summary.loc[param, "95% UL"] = ci_stats[param]["UL"][1].item()
            # Keq
            summary.loc["Keq", "Mean"] = ci_stats["Keq"]["Mean"].item()
            summary.loc["Keq", "95% LL"] = ci_stats["Keq"]["LL"].item()
            summary.loc["Keq", "95% UL"] = ci_stats["Keq"]["UL"].item()
        elif param == "trans":
            summary.loc["kon", "Mean"] = ci_stats[param]["Mean"][0, 1].item()
            summary.loc["kon", "95% LL"] = ci_stats[param]["LL"][0, 1].item()
            summary.loc["kon", "95% UL"] = ci_stats[param]["UL"][0, 1].item()
            summary.loc["koff", "Mean"] = ci_stats[param]["Mean"][1, 0].item()
            summary.loc["koff", "95% LL"] = ci_stats[param]["LL"][1, 0].item()
            summary.loc["koff", "95% UL"] = ci_stats[param]["UL"][1, 0].item()
        else:
            summary.loc[param, "Mean"] = ci_stats[param]["Mean"].item()
            summary.loc[param, "95% LL"] = ci_stats[param]["LL"].item()
            summary.loc[param, "95% UL"] = ci_stats[param]["UL"].item()
    ci_stats["m_probs"] = model.m_probs.data.cpu()
    ci_stats["theta_probs"] = model.theta_probs.data.cpu()
    ci_stats["z_probs"] = model.z_probs.data.cpu()
    ci_stats["z_map"] = model.z_map.data.cpu()

    # timestamps
    if model.data.time1 is not None:
        ci_stats["time1"] = model.data.time1
    if model.data.ttb is not None:
        ci_stats["ttb"] = model.data.ttb

    model.params = ci_stats

    # snr
    summary.loc["SNR", "Mean"] = (
        snr(
            model.data.images[:, :, model.cdx],
            ci_stats["width"]["Mean"],
            ci_stats["x"]["Mean"],
            ci_stats["y"]["Mean"],
            model.data.xy[:, :, model.cdx],
            ci_stats["background"]["Mean"],
            ci_stats["gain"]["Mean"],
            model.data.offset.mean,
            model.data.offset.var,
            model.data.P,
            model.theta_probs,
        )
        .mean()
        .item()
    )

    # classification statistics
    if model.data.labels is not None:
        pred_labels = model.z_map[model.data.is_ontarget].cpu().numpy().ravel()
        true_labels = model.data.labels["z"][: model.data.N, :, model.cdx].ravel()

        with np.errstate(divide="ignore", invalid="ignore"):
            summary.loc["MCC", "Mean"] = matthews_corrcoef(true_labels, pred_labels)
        summary.loc["Recall", "Mean"] = recall_score(
            true_labels, pred_labels, zero_division=0
        )
        summary.loc["Precision", "Mean"] = precision_score(
            true_labels, pred_labels, zero_division=0
        )

        (
            summary.loc["TN", "Mean"],
            summary.loc["FP", "Mean"],
            summary.loc["FN", "Mean"],
            summary.loc["TP", "Mean"],
        ) = confusion_matrix(true_labels, pred_labels, labels=(0, 1)).ravel()

        mask = torch.from_numpy(model.data.labels["z"][: model.data.N, :, model.cdx])
        samples = torch.masked_select(model.z_probs[model.data.is_ontarget].cpu(), mask)
        if len(samples):
            z_ll, z_ul = hpdi(samples, CI)
            summary.loc["p(specific)", "Mean"] = quantile(samples, 0.5).item()
            summary.loc["p(specific)", "95% LL"] = z_ll.item()
            summary.loc["p(specific)", "95% UL"] = z_ul.item()
        else:
            summary.loc["p(specific)", "Mean"] = 0.0
            summary.loc["p(specific)", "95% LL"] = 0.0
            summary.loc["p(specific)", "95% UL"] = 0.0

    model.summary = summary

    if path is not None:
        path = Path(path)
        torch.save(ci_stats, path / f"{model.full_name}-params.tpqr")
        if save_matlab:
            from scipy.io import savemat

            for param, field in ci_stats.items():
                if param in (
                    "m_probs",
                    "theta_probs",
                    "z_probs",
                    "z_map",
                    "time1",
                    "ttb",
                ):
                    ci_stats[param] = field.numpy()
                    continue
                for stat, value in field.items():
                    ci_stats[param][stat] = value.cpu().numpy()
            savemat(path / f"{model.full_name}-params.mat", ci_stats)
        summary.to_csv(
            path / f"{model.full_name}-summary.csv",
        )
Example #16
File: hmm.py Project: pyro-ppl/pyro
def model_6(sequences, lengths, args, batch_size=None, include_prior=False):
    num_sequences, max_length, data_dim = sequences.shape
    assert lengths.shape == (num_sequences, )
    assert lengths.max() <= max_length
    hidden_dim = args.hidden_dim

    if not args.raftery_parameterization:
        # Explicitly parameterize the full tensor of transition probabilities, which
        # has hidden_dim cubed entries.
        probs_x = pyro.param(
            "probs_x",
            torch.rand(hidden_dim, hidden_dim, hidden_dim),
            constraint=constraints.simplex,
        )
    else:
        # Use the more parsimonious "Raftery" parameterization of
        # the tensor of transition probabilities. See reference:
        # Raftery, A. E. A model for high-order Markov chains.
        # Journal of the Royal Statistical Society. 1985.
        probs_x1 = pyro.param(
            "probs_x1",
            torch.rand(hidden_dim, hidden_dim),
            constraint=constraints.simplex,
        )
        probs_x2 = pyro.param(
            "probs_x2",
            torch.rand(hidden_dim, hidden_dim),
            constraint=constraints.simplex,
        )
        mix_lambda = pyro.param("mix_lambda",
                                torch.tensor(0.5),
                                constraint=constraints.unit_interval)
        # we use broadcasting to combine two tensors of shape (hidden_dim, hidden_dim) and
        # (hidden_dim, 1, hidden_dim) to obtain a tensor of shape (hidden_dim, hidden_dim, hidden_dim)
        probs_x = mix_lambda * probs_x1 + (1.0 -
                                           mix_lambda) * probs_x2.unsqueeze(-2)

    probs_y = pyro.param(
        "probs_y",
        torch.rand(hidden_dim, data_dim),
        constraint=constraints.unit_interval,
    )
    tones_plate = pyro.plate("tones", data_dim, dim=-1)
    with pyro.plate("sequences", num_sequences, batch_size, dim=-2) as batch:
        lengths = lengths[batch]
        x_curr, x_prev = torch.tensor(0), torch.tensor(0)
        # we need to pass the argument `history=2` to `pyro.markov()`
        # since our model is now second-order (2-Markov)
        for t in pyro.markov(range(lengths.max()), history=2):
            with poutine.mask(mask=(t < lengths).unsqueeze(-1)):
                probs_x_t = Vindex(probs_x)[x_prev, x_curr]
                x_prev, x_curr = x_curr, pyro.sample(
                    "x_{}".format(t),
                    dist.Categorical(probs_x_t),
                    infer={"enumerate": "parallel"},
                )
                with tones_plate:
                    probs_y_t = probs_y[x_curr.squeeze(-1)]
                    pyro.sample(
                        "y_{}".format(t),
                        dist.Bernoulli(probs_y_t),
                        obs=sequences[batch, t],
                    )
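As a shape check on the Raftery-style mixture used above, the standalone snippet below (hypothetical dimensions, not part of the original source) shows that broadcasting a (hidden_dim, hidden_dim) matrix against a (hidden_dim, 1, hidden_dim) matrix yields the full (hidden_dim, hidden_dim, hidden_dim) second-order transition tensor, and that each row remains a valid probability vector because it is a convex combination of two simplex rows.

import torch

hidden_dim = 4
probs_x1 = torch.softmax(torch.rand(hidden_dim, hidden_dim), dim=-1)  # rows sum to 1
probs_x2 = torch.softmax(torch.rand(hidden_dim, hidden_dim), dim=-1)  # rows sum to 1
mix_lambda = torch.tensor(0.3)
# (hidden_dim, hidden_dim) broadcast against (hidden_dim, 1, hidden_dim)
probs_x = mix_lambda * probs_x1 + (1.0 - mix_lambda) * probs_x2.unsqueeze(-2)
assert probs_x.shape == (hidden_dim, hidden_dim, hidden_dim)
assert torch.allclose(probs_x.sum(-1), torch.ones(hidden_dim, hidden_dim))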
Example #17
def ttfb_model(data, control, Tmax):
    r"""
    Eq. 4 and Eq. 7 in::

      @article{friedman2015multi,
        title={Multi-wavelength single-molecule fluorescence analysis of transcription mechanisms},
        author={Friedman, Larry J and Gelles, Jeff},
        journal={Methods},
        volume={86},
        pages={27--36},
        year={2015},
        publisher={Elsevier}
      }

    :param data: time prior to the first binding at the target location
    :param control: time prior to the first binding at the control location
    :param Tmax: entire observation interval
    """
    ka = pyro.param(
        "ka",
        lambda: torch.full((data.shape[0], 1), 0.001),
        constraint=constraints.positive,
    )
    kns = pyro.param(
        "kns",
        lambda: torch.full((data.shape[0], 1), 0.001),
        constraint=constraints.positive,
    )
    Af = pyro.param(
        "Af",
        lambda: torch.full((data.shape[0], 1), 0.9),
        constraint=constraints.unit_interval,
    )
    k = torch.stack([kns, ka + kns])

    # on-target data
    mask = (data < Tmax) & (data > 0)
    tau = data.masked_fill(~mask, 1.0)
    with pyro.plate("bootstrap", data.shape[0], dim=-2) as bdx:
        with pyro.plate("N", data.shape[1], dim=-1):
            active = pyro.sample(
                "active", dist.Bernoulli(Af), infer={"enumerate": "parallel"}
            )
            with handlers.mask(mask=(data == Tmax)):
                pyro.factor("Tmax", -Vindex(k)[active.long().squeeze(-1), bdx] * Tmax)
                # pyro.factor("Tmax", -k * Tmax)
            with handlers.mask(mask=mask):
                pyro.sample(
                    "tau",
                    dist.Exponential(Vindex(k)[active.long().squeeze(-1), bdx]),
                    obs=tau,
                )
                # pyro.sample("tau", dist.Exponential(k), obs=tau)

    # negative control data
    if control is not None:
        mask = (control < Tmax) & (control > 0)
        tauc = control.masked_fill(~mask, 1.0)
        with pyro.plate("bootstrapc", control.shape[0], dim=-2):
            with pyro.plate("Nc", control.shape[1], dim=-1):
                with handlers.mask(mask=(control == Tmax)):
                    pyro.factor("Tmaxc", -kns * Tmax)
                with handlers.mask(mask=mask):
                    pyro.sample("tauc", dist.Exponential(kns), obs=tauc)
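A minimal, hypothetical sketch of fitting ``ttfb_model`` follows (not part of the original source): the rate and amplitude unknowns are ``pyro.param`` sites, and the only sampled latent, ``active``, is enumerated in parallel, so an empty guide combined with ``TraceEnum_ELBO`` optimizes ``ka``, ``kns``, and ``Af`` by maximum marginal likelihood. ``data``, ``control``, and ``Tmax`` are assumed to be prepared as described in the docstring; the learning rate and step count are placeholders.

import pyro
import pyro.optim
from pyro.infer import SVI, TraceEnum_ELBO

def ttfb_guide(data, control, Tmax):
    pass  # all unknowns are pyro.param sites in the model; nothing to sample here

pyro.clear_param_store()
svi = SVI(
    ttfb_model,
    ttfb_guide,
    pyro.optim.Adam({"lr": 0.01}),
    loss=TraceEnum_ELBO(max_plate_nesting=2),  # plates: "bootstrap" at dim=-2, "N" at dim=-1
)
for step in range(2000):
    svi.step(data, control, Tmax)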