예제 #1
0
파일: handlers.py 프로젝트: xidulu/numpyro
    def __exit__(self, *args, **kwargs):
        """Leave the collapse context and fold all recorded sites into one factor.

        Every site recorded in ``self.trace`` is converted into a funsor
        log-density term evaluated at the site's value. Unobserved sites and
        plates not listed in ``self.preserved_plates`` are eliminated with
        ``funsor.sum_product.sum_product`` (logaddexp / add semiring). The
        resulting log-probability is registered via ``numpyro.factor`` under
        the name of the first unobserved site.
        """
        import funsor

        # Pop the coercion presumably pushed on entry; it must be this handler's.
        _coerce = COERCIONS.pop()
        assert _coerce is self._coerce
        super().__exit__(*args, **kwargs)

        # Convert delayed statements to pyro.factor()
        reduced_vars = []
        log_prob_terms = []
        plates = frozenset()
        for name, site in self.trace.items():
            # Unobserved (latent) sites are marginalized out below.
            if not site["is_observed"]:
                reduced_vars.append(name)
            # Map each plate's batch dim to the plate name so funsor can name dims.
            dim_to_name = {f.dim: f.name for f in site["cond_indep_stack"]}
            fn = funsor.to_funsor(site["fn"], funsor.Real, dim_to_name)
            value = site["value"]
            # String values are passed through unchanged (presumably a funsor
            # variable name for a delayed sample — TODO confirm).
            if not isinstance(value, str):
                value = funsor.to_funsor(site["value"], fn.inputs["value"],
                                         dim_to_name)
            log_prob_terms.append(fn(value=value))
            plates |= frozenset(f.name for f in site["cond_indep_stack"])
        assert log_prob_terms, "nothing to collapse"
        # Plates the caller did not ask to keep are reduced along with latents.
        reduced_plates = plates - self.preserved_plates
        log_prob = funsor.sum_product.sum_product(
            funsor.ops.logaddexp,
            funsor.ops.add,
            log_prob_terms,
            eliminate=frozenset(reduced_vars) | reduced_plates,
            plates=plates,
        )
        # NOTE(review): raises IndexError if every collapsed site was observed.
        name = reduced_vars[0]
        numpyro.factor(name, log_prob.data)
예제 #2
0
 def guide():
     """Seeded guide: sample "x", and add factor "g" unless the mask is False.

     Records which branches ran in the enclosing ``called`` set.
     """
     with numpyro.handlers.seed(rng_seed=1):
         latent = numpyro.sample("x", dist.Normal(0, 1))
         called.add("guide-always")
         # Skip the factor entirely when the current mask is literally False.
         if numpyro.get_mask() is False:
             return
         called.add("guide-sometimes")
         numpyro.factor("g", 2 - latent)
예제 #3
0
파일: util.py 프로젝트: jatentaki/numpyro
def _unconstrain_reparam(params, site):
    """Substitute an unconstrained parameter value into ``site``.

    Returns None (no substitution) when the site's name is not in ``params``.
    Otherwise the stored unconstrained value is mapped through
    ``biject_to(site["fn"].support)``. For supports other than ``real`` /
    ``real_vector``, the log-abs-det Jacobian of that map is summed over the
    event dims, multiplied by ``site["scale"]`` if set, and added as a
    ``numpyro.factor``. The constrained value is returned.
    """
    name = site["name"]
    if name not in params:
        return None

    unconstrained = params[name]
    support = site["fn"].support
    transform = biject_to(support)

    # Inside scan, params may hold the whole sequence; take only step `step`
    # when the value has an extra leading (time) dimension.
    step = site["infer"].get("_scan_current_index", None)
    if step is not None:
        shift = transform.codomain.event_dim - transform.domain.event_dim
        expected_ndim = len(site["fn"].shape()) - shift
        if jnp.ndim(unconstrained) > expected_ndim:
            unconstrained = unconstrained[step]

    # Real supports need no Jacobian correction: return the raw value.
    if support in [constraints.real, constraints.real_vector]:
        return unconstrained

    constrained = transform(unconstrained)
    log_det = transform.log_abs_det_jacobian(unconstrained, constrained)
    n_reduce = jnp.ndim(log_det) - jnp.ndim(constrained) + len(site["fn"].event_shape)
    log_det = sum_rightmost(log_det, n_reduce)
    scale = site["scale"]
    if scale is not None:
        log_det = scale * log_det
    numpyro.factor("_{}_log_det".format(name), log_det)
    return constrained
예제 #4
0
    def __call__(self, name, fn, obs):
        """Reparameterize site ``name`` through the guide's shared latent.

        On a call where ``self._x_unconstrained`` is empty, a shared latent is
        sampled from the guide's base distribution (masked out of the density
        via ``.mask(False)``). It is pushed through ``self.transform``, and the
        transform's log-abs-det Jacobian is kept. The result is unpacked into
        per-site unconstrained values. Each call then pops its own site's value
        and maps it to ``fn.support`` with ``biject_to``. It adds
        ``fn.log_prob(value)`` plus the Jacobian terms via ``numpyro.factor``.

        Returns:
            ``(fn, obs)`` unchanged for sites not in the guide's prototype
            trace; otherwise ``(None, value)`` with the constrained value.

        Raises:
            AssertionError: if the site is observed (``obs`` is not None).
        """
        if name not in self.guide.prototype_trace:
            return fn, obs
        assert obs is None, "NeuTraReparam does not support observe statements"

        log_density = 0.
        # The cache is emptied by pop() below, so a fresh shared latent is drawn
        # once every site from the previous draw has been consumed.
        if not self._x_unconstrained:  # On first sample site.
            # Sample a shared latent.
            # mask(False) keeps the base distribution's own log-prob out of the
            # model density.
            z_unconstrained = numpyro.sample(
                "{}_shared_latent".format(self.guide.prefix),
                self.guide.get_base_dist().mask(False))

            # Differentiably transform.
            x_unconstrained = self.transform(z_unconstrained)
            # TODO: find a way to only compute those log_prob terms when needed
            # The shared-transform Jacobian is attributed to the first site only.
            log_density = self.transform.log_abs_det_jacobian(
                z_unconstrained, x_unconstrained)
            self._x_unconstrained = self.guide._unpack_latent(x_unconstrained)

        # Extract a single site's value from the shared latent.
        unconstrained_value = self._x_unconstrained.pop(name)
        transform = biject_to(fn.support)
        value = transform(unconstrained_value)
        logdet = transform.log_abs_det_jacobian(unconstrained_value, value)
        # Reduce the per-element Jacobian over the site's event dims.
        logdet = sum_rightmost(
            logdet,
            jnp.ndim(logdet) - jnp.ndim(value) + len(fn.event_shape))
        log_density = log_density + fn.log_prob(value) + logdet
        numpyro.factor("_{}_log_prob".format(name), log_density)
        # Returning fn=None makes the site deterministic with this value.
        return None, value
예제 #5
0
파일: hmm.py 프로젝트: synabreu/numpyro
def semi_supervised_hmm(transition_prior, emission_prior,
                        supervised_categories, supervised_words,
                        unsupervised_words):
    """Semi-supervised HMM over word sequences.

    Dirichlet priors are placed on the rows of the transition and emission
    matrices. The supervised (category, word) data are observed directly. The
    unsupervised words have their hidden categories marginalized with
    ``forward_log_prob``, and the result is injected with ``numpyro.factor``.
    """
    num_categories = transition_prior.shape[0]
    num_words = emission_prior.shape[0]

    transition_concentration = np.broadcast_to(
        transition_prior, (num_categories, num_categories))
    transition_prob = numpyro.sample(
        'transition_prob', dist.Dirichlet(transition_concentration))

    emission_concentration = np.broadcast_to(
        emission_prior, (num_categories, num_words))
    emission_prob = numpyro.sample(
        'emission_prob', dist.Dirichlet(emission_concentration))

    # Supervised data. The first supervised category has no prior term, which
    # amounts to a flat/uniform prior on it.
    previous_categories = supervised_categories[:-1]
    next_categories = supervised_categories[1:]
    numpyro.sample('supervised_categories',
                   dist.Categorical(transition_prob[previous_categories]),
                   obs=next_categories)
    numpyro.sample('supervised_words',
                   dist.Categorical(emission_prob[supervised_categories]),
                   obs=supervised_words)

    # Unsupervised data: marginal log-likelihood, added to the potential.
    log_transition = np.log(transition_prob)
    log_emission = np.log(emission_prob)
    alpha = forward_log_prob(log_emission[:, unsupervised_words[0]],
                             unsupervised_words[1:],
                             log_transition, log_emission)
    numpyro.factor('forward_log_prob',
                   logsumexp(alpha, axis=0, keepdims=True))
예제 #6
0
 def model():
     """Seeded model: latent "x", observed "y", and factor "f" unless mask is False.

     Records which branches ran in the enclosing ``called`` set.
     """
     with numpyro.handlers.seed(rng_seed=0):
         latent = numpyro.sample("x", dist.Normal(0, 1))
         numpyro.sample("y", dist.Normal(latent, 1), obs=0.)
         called.add("model-always")
         # Skip the factor entirely when the current mask is literally False.
         if numpyro.get_mask() is False:
             return
         called.add("model-sometimes")
         numpyro.factor("f", latent + 1)
예제 #7
0
def _unconstrain_reparam(params, site):
    """Substitute an unconstrained parameter value into ``site``.

    Returns None (no substitution) when the site's name is not in ``params``.
    Otherwise the value is mapped through ``biject_to`` of the site's support.
    The log-abs-det Jacobian is summed over event dims, scaled by
    ``site['scale']`` if set, and added as a ``numpyro.factor``. The
    constrained value is returned.
    """
    name = site['name']
    if name not in params:
        return None

    unconstrained = params[name]
    transform = biject_to(site['fn'].support)
    constrained = transform(unconstrained)

    log_det = transform.log_abs_det_jacobian(unconstrained, constrained)
    n_reduce = jnp.ndim(log_det) - jnp.ndim(constrained) + len(site['fn'].event_shape)
    log_det = sum_rightmost(log_det, n_reduce)
    scale = site['scale']
    if scale is not None:
        log_det = scale * log_det
    numpyro.factor('_{}_log_det'.format(name), log_det)
    return constrained
예제 #8
0
    def __exit__(self, exc_type, exc_value, traceback):
        """Exit the handler and add the bias-corrected log-likelihood factor.

        The parent ``__exit__`` always runs first. Nothing more happens if an
        exception is propagating or ``self.params`` is None. Otherwise,
        unless ``numpyro.get_mask()`` is False, the factor
        ``self.method(self.likelihoods, self.params, self.gibbs_state)`` is
        added. The per-run state is then reset.
        """
        # Exit the parent first so the traceback of any in-flight error stays clean.
        super().__exit__(exc_type, exc_value, traceback)
        if exc_type is not None or self.params is None:
            return

        if numpyro.get_mask() is not False:
            estimate = self.method(self.likelihoods, self.params, self.gibbs_state)
            numpyro.factor("_biased_corrected_log_likelihood", estimate)

        # Reset per-run state.
        self.params = None
        self.likelihoods = {}
        self.subsample_plates = {}
        self.gibbs_state = None
예제 #9
0
    def __exit__(self, exc_type, exc_value, traceback):
        """Exit the handler and add the bias-corrected log-likelihood factor.

        The parent ``__exit__`` always runs first. Nothing more happens if an
        exception is propagating or ``self.params`` is None. Otherwise the
        factor ``self.method(self.likelihoods, self.params, self.gibbs_state)``
        is added unconditionally, and the per-run state is reset.
        """
        # Exit the parent first so the traceback of any in-flight error stays clean.
        super().__exit__(exc_type, exc_value, traceback)
        if exc_type is not None or self.params is None:
            return

        # Ideally this would be skipped when making predictions;
        # see: https://github.com/pyro-ppl/pyro/issues/2744
        estimate = self.method(self.likelihoods, self.params, self.gibbs_state)
        numpyro.factor("_biased_corrected_log_likelihood", estimate)

        # Reset per-run state.
        self.params = None
        self.likelihoods = {}
        self.subsample_plates = {}
        self.gibbs_state = None
예제 #10
0
def semi_supervised_hmm(
    num_categories: int,
    num_words: int,
    supervised_categories: jnp.ndarray,
    supervised_words: jnp.ndarray,
    unsupervised_words: jnp.ndarray,
) -> None:
    """Semi-supervised HMM with fixed Dirichlet priors.

    Transition rows get a Dirichlet(1, ..., 1) prior and emission rows a
    Dirichlet(0.1, ..., 0.1) prior. Supervised (category, word) data are
    observed directly. The unsupervised words are marginalized over hidden
    categories with ``forward_log_prob``, and the result is added via
    ``numpyro.factor``.
    """
    transition_concentration = jnp.broadcast_to(
        jnp.ones(num_categories), (num_categories, num_categories))
    emission_concentration = jnp.broadcast_to(
        jnp.repeat(0.1, num_words), (num_categories, num_words))

    transition_prob = numpyro.sample(
        "transition_prob", dist.Dirichlet(transition_concentration))
    emission_prob = numpyro.sample(
        "emission_prob", dist.Dirichlet(emission_concentration))

    # Supervised part: each category is drawn from the previous one's row.
    numpyro.sample(
        "supervised_categories",
        dist.Categorical(transition_prob[supervised_categories[:-1]]),
        obs=supervised_categories[1:],
    )
    numpyro.sample(
        "supervised_words",
        dist.Categorical(emission_prob[supervised_categories]),
        obs=supervised_words,
    )

    # Unsupervised part: marginal log-likelihood via forward_log_prob.
    log_transition = jnp.log(transition_prob)
    log_emission = jnp.log(emission_prob)
    alpha = forward_log_prob(
        log_emission[:, unsupervised_words[0]],
        unsupervised_words[1:],
        log_transition,
        log_emission,
    )
    numpyro.factor("forward_log_prob", logsumexp(alpha, axis=0, keepdims=True))
예제 #11
0
파일: reparam.py 프로젝트: pyro-ppl/numpyro
    def __call__(self, name, fn, obs):
        """Reparameterize a circular site through an unconstrained latent.

        An improper uniform real variable ``"{name}_unwrapped"`` is sampled
        (or observed). It is wrapped onto [-pi, pi) with a differentiable
        remainder, and the original distribution's log-density at the wrapped
        value is added as ``"{name}_factor"``.

        Returns:
            ``(None, value)`` with the wrapped value.

        Raises:
            AssertionError: if ``fn.support`` is not (an independent wrapper
                around) ``constraints.circular``.
        """
        support = fn.support
        # Look through an independent(...) wrapper before checking circularity.
        if isinstance(support, constraints.independent):
            support = fn.support.base_constraint
        assert support is constraints.circular

        # Parameter-free noise on the whole real line.
        unwrapped_fn = dist.ImproperUniform(constraints.real, fn.batch_shape, fn.event_shape)
        unwrapped = numpyro.sample(f"{name}_unwrapped", unwrapped_fn, obs=obs)

        # Differentiably wrap onto [-pi, pi).
        wrapped = jnp.remainder(unwrapped + math.pi, 2 * math.pi) - math.pi

        # The density lives in this factor; returning fn=None makes the
        # original site deterministic.
        numpyro.factor(f"{name}_factor", fn.log_prob(wrapped))
        return None, wrapped
예제 #12
0
    def to_numpyro(self, y=None):
        """Register the sparse-GP likelihood with numpyro.

        Adds ``self.trace_term(...)`` as the ``"trace_term"`` factor. It then
        samples ``"y"`` from a low-rank multivariate normal with mean
        ``self.mean(self.X)``, covariance factor ``W`` and diagonal ``D``
        (as returned by ``vfe_precompute``).

        Args:
            y: Observed targets. When given, ``"y"`` is conditioned on them,
                with all leading axes folded into the event shape. When None,
                ``"y"`` is sampled (e.g. for prediction).

        Returns:
            The value of the ``"y"`` sample site.
        """
        f_loc = self.mean(self.X)

        _, W, D = vfe_precompute(
            self.X, self.X_u, self.obs_noise, self.kernel, jitter=self.jitter
        )

        numpyro.factor("trace_term", self.trace_term(self.X, W, self.obs_noise))

        y_dist = dist.LowRankMultivariateNormal(loc=f_loc, cov_factor=W, cov_diag=D)
        if y is None:
            return numpyro.sample("y", y_dist)

        # Condition on the y argument itself (not self.y), so the caller's
        # observations are what the model is conditioned on.
        return numpyro.sample(
            "y",
            y_dist.expand_by(y.shape[:-1]).to_event(y.ndim - 1),
            obs=y,
        )
예제 #13
0
 def model():
     """Add a -inf factor under an all-False mask (presumably to test masking)."""
     all_false = jnp.zeros(10, dtype=bool)
     with handlers.mask(mask=all_false):
         numpyro.factor('inf', -jnp.inf)
예제 #14
0
 def _amplitude_prior(self, he_amp, cz_amp):
     """Add a Gaussian prior on log10(he_amp / cz_amp) as factor "amp".

     The prior is centred on log10(4) ~ 0.6 (He amplitude ~4x the BCZ
     amplitude) with sigma 0.3. Equal amplitudes (log ratio 0) therefore sit
     2 sigma below the mean.
     """
     log_ratio = jnp.log10(he_amp) - jnp.log10(cz_amp)
     delta = 2.0 * (log_ratio - 0.6) / 0.6  # standardized: (log_ratio - 0.6) / 0.3
     numpyro.factor("amp", numpyro.distributions.Normal().log_prob(delta))
예제 #15
0
    def __call__(
        self,
        n: ArrayLike,
        nu: Optional[ArrayLike] = None,
        nu_err: Optional[ArrayLike] = None,
        n_pred: Optional[ArrayLike] = None,
    ):
        """Sample the model for given observables.

        Builds two Gaussian-process models of the mode frequencies that share
        one kernel:

        * a null model, background only, with its sites under the
          ``self._prefix`` scope;
        * a full model, background plus the ``he_glitch`` and ``cz_glitch``
          terms.

        Both models' log-likelihoods are added as the ``"obs"`` factor. Their
        difference is recorded as the log10 Bayes factor ``"log_k"``.

        Args:
            n (:term:`array_like`): Mode indices at which ``nu`` is evaluated
                (presumably the radial order); used as the GP inputs.
            nu (:term:`array_like`, optional): Observed radial mode
                frequencies. If None, they are sampled.
            nu_err (:term:`array_like`, optional): Gaussian observational
                uncertainties (sigma) for nu. If None, a 1e-6 diagonal is
                used instead.
            n_pred (:term:`array_like`, optional): If given, additionally
                sample GP-conditioned predictions ``"nu"`` at ``n`` and
                ``"nu_pred"`` at ``n_pred`` for both models.
        """
        # Same kernel function for both models
        var = numpyro.param("kernel_var", self._kernel_var)
        length = numpyro.param("kernel_length", self._kernel_length)
        kernel = var * kernels.ExpSquared(length)

        # GP diagonal: observational variance when nu_err is given, otherwise
        # a small constant.
        diag = 1e-6 if nu_err is None else nu_err**2  # No need for jitter

        args = ("models", 2)
        with dimension(*args):
            with numpyro.plate(*args):
                # Broadcast background function to both models
                # (index 0 is used by the null model, index 1 by the full model).
                bkg_func = self.background()

        # MODEL 0
        with numpyro.handlers.scope(prefix=self._prefix, divider="."):
            # Contain null model parameters in the null scope
            def mean0(n):
                return jnp.squeeze(bkg_func(n)[0])

            # gp0 = GP(kernel, mean=mean0)
            # dist0 = gp0.distribution(n, noise=nu_err)
            gp0 = GaussianProcess(kernel, n, mean=mean0, diag=diag)
            dist0 = gp0.numpyro_dist()

            with dimension("n", n.shape[-1], coords=n):
                nu0 = numpyro.sample("nu_obs", dist0, obs=nu)
                numpyro.deterministic("nu_bkg", bkg_func(n)[0])

            if n_pred is not None:
                with dimension("n", n.shape[-1], coords=n):
                    # gp0.predict("nu", n)
                    numpyro.sample(
                        "nu",
                        gp0.condition(nu0, n).gp.numpyro_dist(),
                    )
                with dimension("n_pred", n_pred.shape[-1], coords=n_pred):
                    # gp0.predict("nu_pred", n_pred)
                    numpyro.sample(
                        "nu_pred",
                        gp0.condition(nu0, n_pred).gp.numpyro_dist())
                    numpyro.deterministic("nu_bkg_pred", bkg_func(n_pred)[0])

        # MODEL 1
        he_glitch_func = self.he_glitch()
        cz_glitch_func = self.cz_glitch()

        def mean(n):
            nu_bkg = jnp.squeeze(bkg_func(n)[1])
            return nu_bkg + he_glitch_func(nu_bkg) + cz_glitch_func(nu_bkg)

        # gp = GP(kernel, mean=mean)
        # dist = gp.distribution(n, noise=nu_err)
        # NOTE(review): this local ``dist`` shadows any module-level ``dist``
        # alias for the rest of this method.
        gp = GaussianProcess(kernel, n, mean=mean, diag=diag)
        dist = gp.numpyro_dist()

        with dimension("n", n.shape[-1], coords=n):

            nu = numpyro.sample("nu_obs", dist, obs=nu)  # redefines nu!
            nu_bkg = numpyro.deterministic("nu_bkg", bkg_func(n)[1])
            numpyro.deterministic("dnu_he", he_glitch_func(nu_bkg))
            numpyro.deterministic("dnu_cz", cz_glitch_func(nu_bkg))

        if n_pred is not None:
            with dimension("n", n.shape[-1], coords=n):
                # gp.predict("nu", n)
                numpyro.sample("nu", gp.condition(nu, n).gp.numpyro_dist())
            with dimension("n_pred", n_pred.shape[-1], coords=n_pred):
                # gp.predict("nu_pred", n_pred)
                numpyro.sample("nu_pred",
                               gp.condition(nu, n_pred).gp.numpyro_dist())

                nu_bkg = numpyro.deterministic("nu_bkg_pred",
                                               bkg_func(n_pred)[1])
                numpyro.deterministic("dnu_he_pred", he_glitch_func(nu_bkg))
                numpyro.deterministic("dnu_cz_pred", cz_glitch_func(nu_bkg))

        # Other deterministics and priors
        self._amplitude_prior(*self._glitch_amplitudes(nu))

        # LIKELIHOOD
        # Model comparison - if nu is not None, then nu0 == nu
        # NOTE(review): when nu is observed, both "nu_obs" sample sites above
        # already contribute these log-probs, so the "obs" factor appears to
        # add them a second time — confirm this is intended.
        logL0 = dist0.log_prob(nu0)
        logL = dist.log_prob(nu)

        numpyro.factor("obs", (logL0 + logL).sum())

        # Log10 Bayes factor
        numpyro.deterministic("log_k", (logL - logL0).sum() / np.log(10.0))
예제 #16
0
def SparseGP(X, y):
    """Sparse (inducing-point) GP regression model with an RBF kernel.

    Samples the kernel variance, length scale and observation noise. It learns
    the inducing inputs ``x_u`` as a numpyro param initialized from the
    module-level ``X_u_init``. It builds the low-rank approximation
    ``Qff = W W^T`` with ``W = (Luu^{-1} Kuf)^T`` and adds the trace
    correction ``-tr(Kff - Qff) / (2 * obs_noise)`` as a factor. Finally it
    samples ``"y"`` from a low-rank multivariate normal with zero mean,
    covariance factor ``W`` and diagonal ``obs_noise``.

    Args:
        X: Training inputs; the first axis indexes data points.
        y: Observed targets, or None to sample ``"y"`` unconditioned
            (e.g. for prediction).

    Returns:
        The value of the ``"y"`` sample site.
    """
    n_samples = X.shape[0]
    X = numpyro.deterministic("X", X)
    # Set priors on kernel hyperparameters.
    η = numpyro.sample("variance", dist.HalfCauchy(scale=5.0))
    ℓ = numpyro.sample("length_scale", dist.Gamma(2.0, 1.0))
    σ = numpyro.sample("obs_noise", dist.HalfCauchy(scale=5.0))

    x_u = numpyro.param("x_u", init_value=X_u_init)

    # ================================
    # Mean Function
    # ================================
    f_loc = np.zeros(n_samples)

    # ================================
    # Qff Term
    # ================================
    # W   = (inv(Luu) @ Kuf).T
    # Qff = Kfu @ inv(Kuu) @ Kuf
    # Qff = W @ W.T
    # ================================
    Kuu = rbf_kernel(x_u, x_u, η, ℓ)
    Kuf = rbf_kernel(x_u, X, η, ℓ)
    # Add jitter to the diagonal so the Cholesky factorization is stable.
    Kuu = add_to_diagonal(Kuu, jitter)

    # cholesky factorization
    Luu = cholesky(Kuu, lower=True)
    Luu = numpyro.deterministic("Luu", Luu)

    # W matrix (the site "W" records it before the transpose)
    W = solve_triangular(Luu, Kuf, lower=True)
    W = numpyro.deterministic("W", W).T

    # ================================
    # Likelihood Noise Term
    # ================================
    # D = noise
    # ================================
    D = numpyro.deterministic("G", jnp.ones(n_samples) * σ)

    # ================================
    # trace term
    # ================================
    # t = tr(Kff - Qff) / noise
    # t /= - 2.0
    # ================================
    Kffdiag = jnp.diag(rbf_kernel(X, X, η, ℓ))
    Qffdiag = jnp.power(W, 2).sum(axis=1)
    trace_term = (Kffdiag - Qffdiag).sum() / σ
    trace_term = jnp.clip(trace_term, a_min=0.0)  # numerical errors

    # add trace term to the log probability loss
    numpyro.factor("trace_term", -trace_term / 2.0)

    # Sample y according to the SGP
    y_dist = dist.LowRankMultivariateNormal(loc=f_loc, cov_factor=W, cov_diag=D)
    if y is None:
        # Nothing to condition on: sample y with event shape (n_samples,).
        return numpyro.sample("y", y_dist)
    return numpyro.sample(
        "y",
        y_dist.expand_by(y.shape[:-1]).to_event(y.ndim - 1),
        obs=y,
    )
예제 #17
0
 def model():
     """Sample "a" ~ Normal(0, 1), then add factor "b" (constant 0.0) and factor "c" (= a)."""
     a = numpyro.sample("a", dist.Normal(0, 1))
     for site_name, log_factor in (("b", 0.0), ("c", a)):
         numpyro.factor(site_name, log_factor)