Example #1
0
 def model():
     """Mixture model: each datum draws a component index from a fixed
     Categorical, then a unit-variance Normal centered at that component."""
     # Learnable component locations; mixture weights are fixed constants.
     component_locs = pyro.param("locs", torch.tensor([0.2, 0.3, 0.5]))
     mixture_weights = torch.tensor([0.2, 0.3, 0.5])
     with pyro.plate("plate", len(data), dim=-1):
         assignment = pyro.sample("x", dist.Categorical(mixture_weights))
         pyro.sample("obs", dist.Normal(component_locs[assignment], 1.), obs=data)
Example #2
0
 def guide():
     """Mean-field guide: one shared Categorical over assignments,
     replicated across the data plate."""
     probs = pyro.param("p", torch.tensor([0.5, 0.3, 0.2]))
     with pyro.plate("plate", len(data), dim=-1):
         pyro.sample("x", dist.Categorical(probs))
Example #3
0
 def model(data):
     """Bernoulli likelihood with a single learnable success probability."""
     success_prob = pyro.param("p", torch.tensor(0.5))
     pyro.sample("x", dist.Bernoulli(success_prob), obs=data)
Example #4
0
 def model(data):
     """Unit-variance Normal likelihood with a learnable mean."""
     mean = pyro.param("loc", torch.tensor(0.0))
     pyro.sample("x", dist.Normal(mean, 1.), obs=data)
Example #5
0
 def guide():
     """Hierarchical guide: latent x depends on an intermediate latent y."""
     mean = pyro.param("loc", torch.tensor(0.))
     intermediate = pyro.sample("y", dist.Normal(mean, 1.))
     pyro.sample("x", dist.Normal(intermediate, 1.))
Example #6
0
def test_elbo_enumerate_plate_7(backend):
    """Check that the vectorized (plated) ELBO equals the hand-unrolled one
    when discrete sites are enumerated in BOTH the model and the guide.
    """
    #  Guide    Model
    #    a -----> b
    #    |        |
    #  +-|--------|----------------+
    #  | V        V                |
    #  | c -----> d -----> e   N=2 |
    #  +---------------------------+
    # This tests a mixture of model and guide enumeration.
    with pyro_backend(backend):
        # Register all parameters up front so model and guide share them.
        pyro.param("model_probs_a",
                   torch.tensor([0.45, 0.55]),
                   constraint=constraints.simplex)
        pyro.param("model_probs_b",
                   torch.tensor([[0.6, 0.4], [0.4, 0.6]]),
                   constraint=constraints.simplex)
        pyro.param("model_probs_c",
                   torch.tensor([[0.75, 0.25], [0.55, 0.45]]),
                   constraint=constraints.simplex)
        pyro.param("model_probs_d",
                   torch.tensor([[[0.4, 0.6], [0.3, 0.7]], [[0.3, 0.7], [0.2, 0.8]]]),
                   constraint=constraints.simplex)
        pyro.param("model_probs_e",
                   torch.tensor([[0.75, 0.25], [0.55, 0.45]]),
                   constraint=constraints.simplex)
        # NOTE(review): [0.35, 0.64] sums to 0.99, not 1; presumably the
        # simplex constraint renormalizes the init -- confirm intended.
        pyro.param("guide_probs_a",
                   torch.tensor([0.35, 0.64]),
                   constraint=constraints.simplex)
        pyro.param("guide_probs_c",
                   torch.tensor([[0., 1.], [1., 0.]]),  # deterministic
                   constraint=constraints.simplex)

        # Vectorized model: "data" plate of size 2; b and d are enumerated.
        def auto_model(data):
            probs_a = pyro.param("model_probs_a")
            probs_b = pyro.param("model_probs_b")
            probs_c = pyro.param("model_probs_c")
            probs_d = pyro.param("model_probs_d")
            probs_e = pyro.param("model_probs_e")
            a = pyro.sample("a", dist.Categorical(probs_a))
            b = pyro.sample("b", dist.Categorical(probs_b[a]),
                            infer={"enumerate": "parallel"})
            with pyro.plate("data", 2, dim=-1):
                c = pyro.sample("c", dist.Categorical(probs_c[a]))
                d = pyro.sample("d", dist.Categorical(Vindex(probs_d)[b, c]),
                                infer={"enumerate": "parallel"})
                pyro.sample("obs", dist.Categorical(probs_e[d]), obs=data)

        # Vectorized guide: a is enumerated on the guide side.
        def auto_guide(data):
            probs_a = pyro.param("guide_probs_a")
            probs_c = pyro.param("guide_probs_c")
            a = pyro.sample("a", dist.Categorical(probs_a),
                            infer={"enumerate": "parallel"})
            with pyro.plate("data", 2, dim=-1):
                pyro.sample("c", dist.Categorical(probs_c[a]))

        # Sequentially-unrolled equivalent of auto_model (no plates).
        def hand_model(data):
            probs_a = pyro.param("model_probs_a")
            probs_b = pyro.param("model_probs_b")
            probs_c = pyro.param("model_probs_c")
            probs_d = pyro.param("model_probs_d")
            probs_e = pyro.param("model_probs_e")
            a = pyro.sample("a", dist.Categorical(probs_a))
            b = pyro.sample("b", dist.Categorical(probs_b[a]),
                            infer={"enumerate": "parallel"})
            for i in range(2):
                c = pyro.sample("c_{}".format(i), dist.Categorical(probs_c[a]))
                d = pyro.sample("d_{}".format(i),
                                dist.Categorical(Vindex(probs_d)[b, c]),
                                infer={"enumerate": "parallel"})
                pyro.sample("obs_{}".format(i), dist.Categorical(probs_e[d]), obs=data[i])

        # Sequentially-unrolled equivalent of auto_guide.
        def hand_guide(data):
            probs_a = pyro.param("guide_probs_a")
            probs_c = pyro.param("guide_probs_c")
            a = pyro.sample("a", dist.Categorical(probs_a),
                            infer={"enumerate": "parallel"})
            for i in range(2):
                pyro.sample("c_{}".format(i), dist.Categorical(probs_c[a]))

        data = torch.tensor([0, 0])
        # Vectorized version needs max_plate_nesting=1; unrolled version
        # has no plates, hence max_plate_nesting=0.
        elbo = infer.TraceEnum_ELBO(max_plate_nesting=1)
        elbo = elbo.differentiable_loss if backend == "pyro" else elbo
        auto_loss = elbo(auto_model, auto_guide, data)
        elbo = infer.TraceEnum_ELBO(max_plate_nesting=0)
        elbo = elbo.differentiable_loss if backend == "pyro" else elbo
        hand_loss = elbo(hand_model, hand_guide, data)
        _check_loss_and_grads(hand_loss, auto_loss)
Example #7
0
    def _init_parameters(self):
        """
        Parameters shared between different models.

        Every shared parameter is registered in the Pyro param store with a
        lazy initializer (evaluated only if the parameter is absent) and an
        appropriate constraint.
        """
        device = self.device
        data = self.data
        eps = torch.finfo(self.dtype).eps
        spot_shape = (self.K, data.Nt, data.F, self.Q)
        half_window = (data.P + 1) / 2

        # (name, lazy initializer, constraint) for every shared parameter.
        specs = [
            # global scalar parameters
            (
                "proximity_loc",
                lambda: torch.tensor(0.5, device=device),
                constraints.interval(
                    0, (data.P + 1) / math.sqrt(12) - eps
                ),
            ),
            (
                "proximity_size",
                lambda: torch.tensor(100, device=device),
                constraints.greater_than(2.0),
            ),
            (
                "lamda_loc",
                lambda: torch.full((self.Q,), 0.5, device=device),
                constraints.positive,
            ),
            (
                "lamda_beta",
                lambda: torch.full((self.Q,), 100, device=device),
                constraints.positive,
            ),
            (
                "gain_loc",
                lambda: torch.tensor(5, device=device),
                constraints.positive,
            ),
            (
                "gain_beta",
                lambda: torch.tensor(100, device=device),
                constraints.positive,
            ),
            # per-AOI background statistics (frame dim collapsed to 1)
            (
                "background_mean_loc",
                lambda: (data.median.to(device) - data.offset.mean).expand(
                    data.Nt, 1, data.C
                ),
                constraints.positive,
            ),
            (
                "background_std_loc",
                lambda: torch.ones(data.Nt, 1, data.C, device=device),
                constraints.positive,
            ),
            # per-AOI, per-frame background intensity
            (
                "b_loc",
                lambda: (data.median.to(device) - data.offset.mean).expand(
                    data.Nt, data.F, data.C
                ),
                constraints.positive,
            ),
            (
                "b_beta",
                lambda: torch.ones(data.Nt, data.F, data.C, device=device),
                constraints.positive,
            ),
            # per-spot variables, shape (K, Nt, F, Q)
            (
                "h_loc",
                lambda: torch.full(spot_shape, 2000, device=device),
                constraints.positive,
            ),
            (
                "h_beta",
                lambda: torch.full(spot_shape, 0.001, device=device),
                constraints.positive,
            ),
            (
                "w_mean",
                lambda: torch.full(spot_shape, 1.5, device=device),
                constraints.interval(0.75 + eps, 2.25 - eps),
            ),
            (
                "w_size",
                lambda: torch.full(spot_shape, 100, device=device),
                constraints.greater_than(2.0),
            ),
            (
                "x_mean",
                lambda: torch.zeros(spot_shape, device=device),
                constraints.interval(-half_window + eps, half_window - eps),
            ),
            (
                "y_mean",
                lambda: torch.zeros(spot_shape, device=device),
                constraints.interval(-half_window + eps, half_window - eps),
            ),
            (
                "size",
                lambda: torch.full(spot_shape, 200, device=device),
                constraints.greater_than(2.0),
            ),
        ]
        for name, init, constraint in specs:
            pyro.param(name, init, constraint=constraint)
Example #8
0
 def guide():
     """Guide whose parameter carries event_dim=1 (one simplex per datum)."""
     with pyro.plate("plate", len(data), dim=-1):
         # Uniform initialization: each row is a 3-way uniform distribution.
         probs = pyro.param("p", torch.ones(len(data), 3) / 3, event_dim=1)
         pyro.sample("x", dist.Categorical(probs))
     return probs
Example #9
0
 def model(z=None):
     """Global Categorical latent z (optionally observed), then a masked
     Normal likelihood over a plate of 3 data points."""
     probs = pyro.param("p", torch.tensor([0.75, 0.25]))
     z = pyro.sample("z", dist.Categorical(probs), obs=z)
     logger.info("z.shape = {}".format(z.shape))
     with pyro.plate("data", 3), handlers.mask(mask=mask):
         pyro.sample("x", dist.Normal(z.type_as(data), 1.0), obs=data)
Example #10
0
    def guide(self):
        r"""
        **Variational Distribution**

        .. math::
            \begin{aligned}
                q(\phi \setminus \{z, \theta\}) =~&q(g) q(\sigma^{xy}) q(\pi) q(\lambda) \cdot \\
                &\prod_{\mathsf{AOI}} \left[ q(\mu^b) q(\sigma^b) \prod_{\mathsf{frame}}
                \left[ \vphantom{\prod_{F}} q(b) \prod_{\mathsf{spot}}
                q(m) q(h | m) q(w | m) q(x | m) q(y | m) \right] \right]
            \end{aligned}
        """
        # global parameters
        # gain ~ Gamma parameterized by mean ("gain_loc") and rate ("gain_beta")
        pyro.sample(
            "gain",
            dist.Gamma(
                pyro.param("gain_loc") * pyro.param("gain_beta"),
                pyro.param("gain_beta"),
            ),
        )
        pyro.sample(
            "pi",
            dist.Dirichlet(pyro.param("pi_mean") * pyro.param("pi_size")).to_event(1),
        )
        pyro.sample(
            "lamda",
            dist.Gamma(
                pyro.param("lamda_loc") * pyro.param("lamda_beta"),
                pyro.param("lamda_beta"),
            ).to_event(1),
        )
        # proximity lives on [0, (P+1)/sqrt(12)] via an AffineBeta
        pyro.sample(
            "proximity",
            AffineBeta(
                pyro.param("proximity_loc"),
                pyro.param("proximity_size"),
                0,
                (self.data.P + 1) / math.sqrt(12),
            ),
        )

        # spots (sequential plate, iterated below)
        spots = pyro.plate("spots", self.K)
        # aoi sites (subsampled; occupies batch dim -3)
        aois = pyro.plate(
            "aois",
            self.data.Nt,
            subsample=self.n,
            subsample_size=self.nbatch_size,
            dim=-3,
        )
        # time frames (subsampled; occupies batch dim -2)
        frames = pyro.plate(
            "frames",
            self.data.F,
            subsample=self.f,
            subsample_size=self.fbatch_size,
            dim=-2,
        )
        # color channels (batch dim -1)
        channels = pyro.plate(
            "channels",
            self.data.C,
            dim=-1,
        )

        with channels as cdx, aois as ndx:
            # reshape AOI indices to broadcast against the frame/channel dims
            ndx = ndx[:, None, None]
            # per-AOI mask; presumably marks AOIs included in the analysis
            # -- confirm against self.data.mask semantics
            mask = Vindex(self.data.mask)[ndx].to(self.device)
            with handlers.mask(mask=mask):
                # point estimates (Delta) for per-AOI background statistics
                pyro.sample(
                    "background_mean",
                    dist.Delta(Vindex(pyro.param("background_mean_loc"))[ndx, 0, cdx]),
                )
                pyro.sample(
                    "background_std",
                    dist.Delta(Vindex(pyro.param("background_std_loc"))[ndx, 0, cdx]),
                )
                with frames as fdx:
                    # add a channel dim so fdx broadcasts with ndx/cdx
                    fdx = fdx[:, None]
                    # sample background intensity
                    pyro.sample(
                        "background",
                        dist.Gamma(
                            Vindex(pyro.param("b_loc"))[ndx, fdx, cdx]
                            * Vindex(pyro.param("b_beta"))[ndx, fdx, cdx],
                            Vindex(pyro.param("b_beta"))[ndx, fdx, cdx],
                        ),
                    )

                    for kdx in spots:
                        # sample spot presence m (enumerated in parallel)
                        m = pyro.sample(
                            f"m_k{kdx}",
                            dist.Bernoulli(
                                Vindex(pyro.param("m_probs"))[kdx, ndx, fdx, cdx]
                            ),
                            infer={"enumerate": "parallel"},
                        )
                        # spot variables contribute only where the spot is present
                        with handlers.mask(mask=m > 0):
                            # sample spot variables
                            pyro.sample(
                                f"height_k{kdx}",
                                dist.Gamma(
                                    Vindex(pyro.param("h_loc"))[kdx, ndx, fdx, cdx]
                                    * Vindex(pyro.param("h_beta"))[kdx, ndx, fdx, cdx],
                                    Vindex(pyro.param("h_beta"))[kdx, ndx, fdx, cdx],
                                ),
                            )
                            pyro.sample(
                                f"width_k{kdx}",
                                AffineBeta(
                                    Vindex(pyro.param("w_mean"))[kdx, ndx, fdx, cdx],
                                    Vindex(pyro.param("w_size"))[kdx, ndx, fdx, cdx],
                                    self.priors["width_min"],
                                    self.priors["width_max"],
                                ),
                            )
                            # x/y offsets share the "size" concentration parameter
                            pyro.sample(
                                f"x_k{kdx}",
                                AffineBeta(
                                    Vindex(pyro.param("x_mean"))[kdx, ndx, fdx, cdx],
                                    Vindex(pyro.param("size"))[kdx, ndx, fdx, cdx],
                                    -(self.data.P + 1) / 2,
                                    (self.data.P + 1) / 2,
                                ),
                            )
                            pyro.sample(
                                f"y_k{kdx}",
                                AffineBeta(
                                    Vindex(pyro.param("y_mean"))[kdx, ndx, fdx, cdx],
                                    Vindex(pyro.param("size"))[kdx, ndx, fdx, cdx],
                                    -(self.data.P + 1) / 2,
                                    (self.data.P + 1) / 2,
                                ),
                            )
Example #11
0
 def model(z2=None):
     """Plate of 2 data points; z2 selects between two learnable locations."""
     weights = pyro.param("p", torch.tensor([0.25, 0.75]))
     locations = pyro.param("loc", torch.tensor([-1.0, 1.0]))
     with pyro.plate("data", 2):
         z2 = pyro.sample("z2", dist.Categorical(weights), obs=z2)
         pyro.sample("x2", dist.Normal(locations[z2], 1.0), obs=data)
Example #12
0
def ttfb_model(data, control, Tmax):
    r"""
    Eq. 4 and Eq. 7 in::

      @article{friedman2015multi,
        title={Multi-wavelength single-molecule fluorescence analysis of transcription mechanisms},
        author={Friedman, Larry J and Gelles, Jeff},
        journal={Methods},
        volume={86},
        pages={27--36},
        year={2015},
        publisher={Elsevier}
      }

    :param data: time prior to the first binding at the target location
    :param control: time prior to the first binding at the control location
    :param Tmax: entire observation interval
    """
    # One value per bootstrap row (shape (data.shape[0], 1)). Names follow
    # Friedman & Gelles (2015): ka presumably the specific binding rate,
    # kns the nonspecific rate, Af the active fraction -- TODO confirm.
    ka = pyro.param(
        "ka",
        lambda: torch.full((data.shape[0], 1), 0.001),
        constraint=constraints.positive,
    )
    kns = pyro.param(
        "kns",
        lambda: torch.full((data.shape[0], 1), 0.001),
        constraint=constraints.positive,
    )
    Af = pyro.param(
        "Af",
        lambda: torch.full((data.shape[0], 1), 0.9),
        constraint=constraints.unit_interval,
    )
    # k[0] = kns (inactive), k[1] = ka + kns (active); selected below by `active`.
    k = torch.stack([kns, ka + kns])

    # on-target data
    # uncensored observations: binding occurred strictly inside (0, Tmax)
    mask = (data < Tmax) & (data > 0)
    # placeholder value for censored entries; they are masked out below
    tau = data.masked_fill(~mask, 1.0)
    with pyro.plate("bootstrap", data.shape[0], dim=-2) as bdx:
        with pyro.plate("N", data.shape[1], dim=-1):
            # whether the molecule is active (enumerated out in parallel)
            active = pyro.sample(
                "active", dist.Bernoulli(Af), infer={"enumerate": "parallel"}
            )
            with handlers.mask(mask=(data == Tmax)):
                # right-censored at Tmax: add exponential log-survival -k * Tmax
                pyro.factor("Tmax", -Vindex(k)[active.long().squeeze(-1), bdx] * Tmax)
                # pyro.factor("Tmax", -k * Tmax)
            with handlers.mask(mask=mask):
                # uncensored waiting times follow an Exponential with the
                # per-sample rate picked by `active`
                pyro.sample(
                    "tau",
                    dist.Exponential(Vindex(k)[active.long().squeeze(-1), bdx]),
                    obs=tau,
                )
                # pyro.sample("tau", dist.Exponential(k), obs=tau)

    # negative control data
    # control locations only see the nonspecific rate kns (no `active` latent)
    if control is not None:
        mask = (control < Tmax) & (control > 0)
        tauc = control.masked_fill(~mask, 1.0)
        with pyro.plate("bootstrapc", control.shape[0], dim=-2):
            with pyro.plate("Nc", control.shape[1], dim=-1):
                with handlers.mask(mask=(control == Tmax)):
                    pyro.factor("Tmaxc", -kns * Tmax)
                with handlers.mask(mask=mask):
                    pyro.sample("tauc", dist.Exponential(kns), obs=tauc)
Example #13
0
 def model():
     """Sample and return a single Normal latent with learnable loc/scale."""
     mean = pyro.param("loc", torch.tensor(2.0))
     stddev = pyro.param("scale", torch.tensor(1.0))
     return pyro.sample("x", dist.Normal(mean, stddev))
Example #14
0
 def guide():
     """Mean-field Normal guide replicated over the outer data plate."""
     mean = pyro.param("loc", torch.tensor(0.))
     stddev = pyro.param("scale", torch.tensor(1.))
     with pyro.plate("plate_outer", data.size(-1), dim=-1):
         pyro.sample("x", dist.Normal(mean, stddev))
Example #15
0
 def m_probs(self) -> torch.Tensor:
     r"""
     Posterior spot presence probability :math:`q(m=1)`.
     """
     # ``.data`` detaches the value from the param store's autograd graph.
     param = pyro.param("m_probs")
     return param.data
Example #16
0
 def model():
     """Uniform three-component mixture with learnable component locations."""
     component_locs = pyro.param("locs", torch.tensor([-1., 0., 1.]))
     with pyro.plate("plate", len(data), dim=-1):
         idx = pyro.sample("x", dist.Categorical(torch.ones(3) / 3))
         pyro.sample("obs", dist.Normal(component_locs[idx], 1.), obs=data)
Example #17
0
    def save_checkpoint(self, writer: SummaryWriter = None):
        """
        Save checkpoint.

        Validates parameters for NaN/Inf, updates rolling convergence
        statistics, saves the model state to disk, and logs global scalar
        parameters to tensorboard when a writer is provided.

        :param writer: SummaryWriter object, or ``None`` to skip tensorboard
            logging.
        :raises ValueError: if any parameter contains NaN or Inf values.
        """
        # refuse to checkpoint corrupted parameter values
        for k, v in pyro.get_param_store().items():
            if torch.isnan(v).any() or torch.isinf(v).any():
                raise ValueError(
                    "Iteration #{}. Detected NaN or Inf values in {}".format(
                        self.iter, k
                    )
                )

        # update convergence criteria parameters
        for name in self.conv_params:
            if name == "-ELBO":
                self._rolling["-ELBO"].append(self.iter_loss)
            elif pyro.param(name).ndim == 1:
                # track each element of a vector parameter separately
                for i in range(len(pyro.param(name))):
                    self._rolling[f"{name}_{i}"].append(pyro.param(name)[i].item())
            else:
                self._rolling[name].append(pyro.param(name).item())

        # check convergence status: once the rolling window is full, declare
        # convergence when every tracked series has stabilized (std over the
        # whole window close to the std of its last 50 entries)
        self.converged = False
        if len(self._rolling["-ELBO"]) == self._rolling["-ELBO"].maxlen:
            crit = all(
                torch.tensor(value).std() / torch.tensor(value)[-50:].std() < 1.05
                for value in self._rolling.values()
            )
            if crit:
                self.converged = True

        # save the model state
        torch.save(
            {
                "iter": self.iter,
                "params": pyro.get_param_store().get_state(),
                "optimizer": self.optim.get_state(),
                "rolling": dict(self._rolling),
                "convergence_status": self.converged,
            },
            self.run_path / f"{self.name}-model.tpqr",
        )

        # save global parameters for tensorboard (writer is optional, so
        # guard against the default None to avoid an AttributeError)
        if writer is not None:
            writer.add_scalar("-ELBO", self.iter_loss, self.iter)
            for name, val in pyro.get_param_store().items():
                if val.dim() == 0:
                    writer.add_scalar(name, val.item(), self.iter)
                elif val.dim() == 1 and len(val) <= self.S + 1:
                    scalars = {str(i): v.item() for i, v in enumerate(val)}
                    writer.add_scalars(name, scalars, self.iter)

        logger.debug(f"Iteration #{self.iter}: Successful.")
Example #18
0
 def guide():
     """Categorical guide with a simplex-constrained probability vector.

     The raw init ``torch.randn(3).exp()`` is positive but unnormalized; the
     simplex constraint handles the mapping to valid probabilities.
     """
     probs = pyro.param("q",
                        torch.randn(3).exp(),
                        constraint=constraints.simplex)
     pyro.sample("x", dist.Categorical(probs))
Example #19
0
 def model(data=None):
     """Observe (or sample, if data is None) 1000 iid Normal draws with
     learnable loc/scale; returns the sampled/observed values."""
     mean = pyro.param("loc", torch.tensor(2.0))
     stddev = pyro.param("scale", torch.tensor(1.0))
     with pyro.plate("data", 1000, dim=-1):
         draws = pyro.sample("x", dist.Normal(mean, stddev), obs=data)
     return draws
Example #20
0
def test_elbo_enumerate_plates_1(backend):
    """Check that the vectorized (plated) ELBO equals the hand-unrolled one
    for two unrelated enumerated plates.
    """
    #  +-----------------+
    #  | a ----> b   M=2 |
    #  +-----------------+
    #  +-----------------+
    #  | c ----> d   N=3 |
    #  +-----------------+
    # This tests two unrelated plates.
    # Each should remain uncontracted.
    with pyro_backend(backend):
        # Register shared parameters up front.
        pyro.param("probs_a",
                   torch.tensor([0.45, 0.55]),
                   constraint=constraints.simplex)
        pyro.param("probs_b",
                   torch.tensor([[0.6, 0.4], [0.4, 0.6]]),
                   constraint=constraints.simplex)
        pyro.param("probs_c",
                   torch.tensor([0.75, 0.25]),
                   constraint=constraints.simplex)
        pyro.param("probs_d",
                   torch.tensor([[0.4, 0.6], [0.3, 0.7]]),
                   constraint=constraints.simplex)
        b_data = torch.tensor([0, 1])
        d_data = torch.tensor([0, 0, 1])

        # Vectorized model: independent plates of size 2 and 3, with a and c
        # enumerated in parallel.
        def auto_model():
            probs_a = pyro.param("probs_a")
            probs_b = pyro.param("probs_b")
            probs_c = pyro.param("probs_c")
            probs_d = pyro.param("probs_d")
            with pyro.plate("a_axis", 2, dim=-1):
                a = pyro.sample("a",
                                dist.Categorical(probs_a),
                                infer={"enumerate": "parallel"})
                pyro.sample("b", dist.Categorical(probs_b[a]), obs=b_data)
            with pyro.plate("c_axis", 3, dim=-1):
                c = pyro.sample("c",
                                dist.Categorical(probs_c),
                                infer={"enumerate": "parallel"})
                pyro.sample("d", dist.Categorical(probs_d[c]), obs=d_data)

        # Sequentially-unrolled equivalent of auto_model (no plates).
        def hand_model():
            probs_a = pyro.param("probs_a")
            probs_b = pyro.param("probs_b")
            probs_c = pyro.param("probs_c")
            probs_d = pyro.param("probs_d")
            for i in range(2):
                a = pyro.sample("a_{}".format(i),
                                dist.Categorical(probs_a),
                                infer={"enumerate": "parallel"})
                pyro.sample("b_{}".format(i),
                            dist.Categorical(probs_b[a]),
                            obs=b_data[i])
            for j in range(3):
                c = pyro.sample("c_{}".format(j),
                                dist.Categorical(probs_c),
                                infer={"enumerate": "parallel"})
                pyro.sample("d_{}".format(j),
                            dist.Categorical(probs_d[c]),
                            obs=d_data[j])

        # Trivial guide: everything is enumerated on the model side.
        def guide():
            pass

        # Vectorized version needs max_plate_nesting=1; unrolled needs 0.
        elbo = infer.TraceEnum_ELBO(max_plate_nesting=1)
        elbo = elbo.differentiable_loss if backend == "pyro" else elbo
        auto_loss = elbo(auto_model, guide)
        elbo = infer.TraceEnum_ELBO(max_plate_nesting=0)
        elbo = elbo.differentiable_loss if backend == "pyro" else elbo
        hand_loss = elbo(hand_model, guide)
        _check_loss_and_grads(hand_loss, auto_loss)
Example #21
0
 def model():
     """Three-component mixture with learnable locations and scales;
     mixture weights are fixed constants."""
     component_locs = pyro.param("locs", torch.randn(3), constraint=constraints.real)
     component_scales = pyro.param("scales", torch.randn(3).exp(), constraint=constraints.positive)
     weights = torch.tensor([0.5, 0.3, 0.2])
     idx = pyro.sample("x", dist.Categorical(weights))
     pyro.sample("obs", dist.Normal(component_locs[idx], component_scales[idx]), obs=data)