Пример #1
0
    def __init__(self, params):
        '''
        Replace each prior specification with a numpyro draw from that prior.

        For every parameter listed below, a non-``None`` value on ``params``
        first overrides the specification already stored on ``self``. The
        specification is then unpacked into the listed distribution and the
        attribute is rebound to the value sampled at the listed site.

        ``R0`` and ``INCUBATION_PERIOD`` are not sampled here.
        '''
        # (attribute, numpyro site name, distribution family), in the order
        # in which the sites are sampled.
        sampled_priors = (
            ('PROPORTION_PRESYMPTOMATIC_TRANSMISSION', 'pP',
             dist.TruncatedNormal),
            ('SYMPTOMATIC', 'rS', dist.Beta),
            ('PRESYMPTOMATIC_CONTAGIOUS_PERIOD', 'cP', dist.Gamma),
            ('ASYMPTOMATIC_CONTAGIOUS_PERIOD', 'cA', dist.Gamma),
            ('SYMPTOMATIC_CONTAGIOUS_PERIOD', 'cS', dist.Gamma),
            ('SYMPTOM_TO_HOSP_PERIOD', 'pH', dist.TruncatedNormal),
            ('HOSP_DEATH_PERIOD', 'hD', dist.Gamma),
            ('HOSP_RECOVERY_PERIOD', 'hR', dist.Gamma),
        )
        for attr, site, family in sampled_priors:
            override = getattr(params, attr)
            if override is not None:
                setattr(self, attr, override)
            prior_args = getattr(self, attr)
            setattr(self, attr, ny.sample(site, family(*prior_args)))

        # Derived variables
        # NOTE(review): INCUBATION_PERIOD is read here without being sampled
        # in this method, so it must already be available on the instance or
        # class in a form that supports arithmetic -- confirm.
        self.EXPOSED_PERIOD = (self.INCUBATION_PERIOD -
                               self.PRESYMPTOMATIC_CONTAGIOUS_PERIOD)
        contagious_period = np.where(self.SYMPTOMATIC > 0,
                                     self.SYMPTOMATIC_CONTAGIOUS_PERIOD,
                                     self.ASYMPTOMATIC_CONTAGIOUS_PERIOD)
        self.RECOVERY_PERIOD = self.INCUBATION_PERIOD + contagious_period
Пример #2
0
def model(N, y=None):
    """
    Probabilistic model whose latent trajectory is the ODE solution of ``dz_dt``.

    :param int N: number of measurement times
    :param numpy.ndarray y: measured populations with shape (N, 2)
    """
    # prior over the two initial populations
    init_prior = dist.LogNormal(jnp.log(10), 1).expand([2])
    z_init = numpyro.sample("z_init", init_prior)

    # measurement times 0, 1, ..., N - 1
    ts = jnp.arange(float(N))

    # alpha, beta, gamma, delta of dz_dt, each truncated below at zero
    theta_prior = dist.TruncatedNormal(
        low=0.0,
        loc=jnp.array([1.0, 0.05, 1.0, 0.05]),
        scale=jnp.array([0.5, 0.05, 0.5, 0.05]),
    )
    theta = numpyro.sample("theta", theta_prior)

    # solve dz/dt on the measurement grid; the result has shape N x 2
    z = odeint(dz_dt, z_init, ts, theta, rtol=1e-6, atol=1e-5, mxstep=1000)

    # per-population measurement error
    noise_prior = dist.LogNormal(-1, 1).expand([2])
    sigma = numpyro.sample("sigma", noise_prior)

    # likelihood of the measured populations
    likelihood = dist.LogNormal(jnp.log(z), sigma)
    numpyro.sample("y", likelihood, obs=y)
def sample_basic_R(nRs, basic_r_prior=None):
    """
    Draw the per-region basic R at the ``basic_R`` sample site.

    :param nRs: number of regions
    :param basic_r_prior: dict with keys ``type``, ``mean`` and
        ``variability``; defaults to a truncated normal with mean 1.35 and
        variability 0.3
    :return: basic R
    :raises ValueError: if ``basic_r_prior["type"]`` is not ``trunc_normal``
    """
    prior = basic_r_prior
    if prior is None:
        prior = {"mean": 1.35, "type": "trunc_normal", "variability": 0.3}

    # only one prior family is supported; reject everything else up front
    if prior["type"] != "trunc_normal":
        raise ValueError("Basic R prior type must be in [trunc_normal]")

    loc = prior["mean"]
    # one scale entry per region, which gives the draw its length-nRs shape
    scale = prior["variability"] * jnp.ones(nRs)
    return numpyro.sample(
        "basic_R", dist.TruncatedNormal(low=0.1, loc=loc, scale=scale))
Пример #4
0
def model_book(leg_left, leg_right, height, br_positive=False):
    """
    Linear regression of ``height`` on both leg lengths.

    :param leg_left: predictor multiplied by the coefficient ``bl``
    :param leg_right: predictor multiplied by the coefficient ``br``
    :param height: observed response, passed as ``obs`` to the "height" site
    :param bool br_positive: when true, ``br`` gets a normal prior truncated
        below at 0 instead of the unconstrained ``Normal(2, 10)``
    """
    a = numpyro.sample("a", dist.Normal(10, 100))
    bl = numpyro.sample("bl", dist.Normal(2, 10))
    if br_positive:
        # Keywords instead of positional (low, loc, scale): the positional
        # order of TruncatedNormal differs between NumPyro releases, while
        # these keyword names are accepted by both signatures.
        br = numpyro.sample("br", dist.TruncatedNormal(low=0, loc=2, scale=10))
    else:
        br = numpyro.sample("br", dist.Normal(2, 10))
    sigma = numpyro.sample("sigma", dist.Exponential(1))
    mu = a + bl * leg_left + br * leg_right
    numpyro.sample("height", dist.Normal(mu, sigma), obs=height)
Пример #5
0
def model_vague_prior(leg_left, leg_right, height, br_positive=False):
    """
    Linear regression of ``height`` on both leg lengths with wide priors.

    The intercept and both slopes use zero-centred priors with scale 100 to
    make them less informative.

    :param leg_left: predictor multiplied by the coefficient ``bl``
    :param leg_right: predictor multiplied by the coefficient ``br``
    :param height: observed response, passed as ``obs`` to the "height" site
    :param bool br_positive: when true, ``br`` gets a normal prior truncated
        below at 0 instead of the unconstrained ``Normal(0, 100)``
    """
    a = numpyro.sample("a", dist.Normal(0, 100))
    bl = numpyro.sample("bl", dist.Normal(0, 100))
    if br_positive:
        # Keywords instead of positional (low, loc, scale): the positional
        # order of TruncatedNormal differs between NumPyro releases, while
        # these keyword names are accepted by both signatures.
        br = numpyro.sample("br",
                            dist.TruncatedNormal(low=0, loc=0, scale=100))
    else:
        br = numpyro.sample("br", dist.Normal(0, 100))
    sigma = numpyro.sample("sigma", dist.Exponential(1))
    mu = a + bl * leg_left + br * leg_right
    numpyro.sample("height", dist.Normal(mu, sigma), obs=height)
Пример #6
0
    def sample(self, key, sample_shape=()):
        """
        Draw a sample through a gamma / truncated-normal / normal hierarchy.

        Reads ``self.df``, ``self.skew``, ``self.loc`` and ``self.scale``.

        :param key: PRNG key, split below into three sub-keys (one per draw)
        :param tuple sample_shape: shape passed to the gamma and normal draws
        :return: ``normal_draw * _scale + _loc``
        """
        # TODO.
        # it is enough to return an arbitrary sample with correct shape
        # return jnp.zeros(sample_shape + self.event_shape)
        key_gamma, key_tn, key_normal = random.split(key, 3)

        k = self.df / 2
        # gamma draw with shape parameter k, rescaled by 1/k
        w = random.gamma(key_gamma, k, sample_shape) / k
        # normal truncated below at 0 whose scale is sqrt(1/w)
        z = (dist.TruncatedNormal(loc=0., scale=jnp.sqrt(1/w), low=0.0)
                 .sample(key_tn))
        delta = self.skew / jnp.sqrt(1 + self.skew ** 2)

        # z shifts the location in proportion to delta
        _loc = self.loc + self.scale * z * delta
        # NOTE(review): _scale does not depend on w, so only the shift term
        # carries the gamma mixing -- confirm this is the intended
        # construction.
        _scale = self.scale * jnp.sqrt(1 - delta ** 2)
        # NOTE(review): the normal noise is drawn with shape `sample_shape`
        # only; any further dimensions come from broadcasting against
        # _scale/_loc -- confirm for non-scalar parameters.
        return random.normal(key_normal, sample_shape) * _scale + _loc
def sample_intervention_effects(nCMs, intervention_prior=None):
    """
    Draw intervention effects at the ``alpha_i`` site from the selected prior.

    :param nCMs: number of interventions
    :param intervention_prior: dict with key ``type`` (one of
        ``trunc_normal``, ``half_normal``, ``normal``,
        ``asymmetric_laplace``), key ``scale`` and, for the asymmetric
        Laplace, key ``asymmetry``; defaults to an asymmetric Laplace with
        scale 30 and asymmetry 0.5
    :return: sample parameters
    :raises ValueError: if ``type`` is not one of the supported names
    """
    prior = intervention_prior
    if prior is None:
        prior = {"type": "asymmetric_laplace", "scale": 30, "asymmetry": 0.5}

    # Pick the distribution first; the sample site itself is shared.
    kind = prior["type"]
    if kind == "trunc_normal":
        effect_dist = dist.TruncatedNormal(low=-0.1,
                                           loc=jnp.zeros(nCMs),
                                           scale=prior["scale"])
    elif kind == "half_normal":
        effect_dist = dist.HalfNormal(scale=jnp.ones(nCMs) * prior["scale"])
    elif kind == "normal":
        effect_dist = dist.Normal(loc=jnp.zeros(nCMs), scale=prior["scale"])
    elif kind == "asymmetric_laplace":
        effect_dist = AsymmetricLaplace(
            asymmetry=prior["asymmetry"],
            scale=jnp.ones(nCMs) * prior["scale"],
        )
    else:
        raise ValueError(
            "Intervention effect prior must take a value in [trunc_normal, normal, asymmetric_laplace, half_normal]"
        )

    return numpyro.sample("alpha_i", effect_dist)
Пример #8
0
def observe_nonrandom(name, latent, det_noise_scale, obs=None):
    """
    Observe ``latent`` through a normal truncated below at zero.

    The distribution is centred on ``latent`` with scale
    ``det_noise_scale * latent + 1``. The mean is also recorded as the
    deterministic site ``"mean_" + name``.

    :param str name: name of the sample site
    :param latent: mean of the observation distribution
    :param det_noise_scale: multiplier of the mean in the noise scale
    :param obs: optional observations; entries that are non-finite or
        negative are replaced by 0.0 and masked out of the sample statement
    :return: value of the ``name`` sample site
    """
    mask = True

    if obs is not None:
        mask = np.isfinite(obs) & (obs >= 0)
        obs = np.where(mask, obs, 0.0)

    mean = latent
    scale = det_noise_scale * mean + 1
    # Keywords instead of positional (low, loc, scale): the positional order
    # of TruncatedNormal differs between NumPyro releases.
    d = dist.TruncatedNormal(low=0., loc=mean, scale=scale)

    numpyro.deterministic("mean_" + name, mean)

    # NOTE(review): `mask_array=` is the handler keyword this code was
    # written against; confirm it matches the installed NumPyro version.
    with numpyro.handlers.mask(mask_array=mask):
        y = numpyro.sample(name, d, obs=obs)

    return y
Пример #9
0
def model(N, y=None):
    """
    Probabilistic model whose latent trajectory is the ODE solution of ``dz_dt``.

    Observations are modelled on the log scale.

    :param int N: number of measurement times
    :param numpy.ndarray y: measured populations with shape (N, 2)
    """
    # two draws from the same log-normal prior for the initial populations
    init_prior = dist.LogNormal(jnp.log(10), 1)
    z_init = numpyro.sample("z_init", init_prior, sample_shape=(2,))

    # measurement times 0, 1, ..., N - 1
    ts = jnp.arange(float(N))

    # alpha, beta, gamma, delta of dz_dt, each truncated below at zero
    theta_loc = jnp.array([0.5, 0.05, 1.5, 0.05])
    theta_scale = jnp.array([0.5, 0.05, 0.5, 0.05])
    theta = numpyro.sample(
        "theta",
        dist.TruncatedNormal(low=0., loc=theta_loc, scale=theta_scale))

    # solve dz/dt on the measurement grid; the result has shape N x 2
    z = odeint(dz_dt, z_init, ts, theta, rtol=1e-5, atol=1e-3, mxstep=500)

    # measurement errors: measured hare is expected to have a larger error
    # than measured lynx, hence the two different rates
    noise_rates = jnp.array([1, 2])
    sigma = numpyro.sample("sigma", dist.Exponential(noise_rates))

    # likelihood of the measured populations (in log scale)
    log_z = jnp.log(z)
    numpyro.sample("y", dist.Normal(log_z, sigma), obs=y)
Пример #10
0
def observe_normal(name, latent, det_rate, det_noise_scale, obs=None):
    """
    Observe ``det_rate * latent`` through a normal truncated below at zero.

    The distribution is centred on ``det_rate * latent`` with scale
    ``det_noise_scale * mean + 1``. The mean is also recorded as the
    deterministic site ``"mean_" + name``.

    :param str name: name of the sample site
    :param latent: latent quantity being observed
    :param det_rate: multiplier applied to ``latent``; broadcast to its shape
    :param det_noise_scale: multiplier of the mean in the noise scale
    :param obs: optional observations; entries that are non-finite or
        negative are replaced by 0.0 and masked out of the sample statement
    :return: value of the ``name`` sample site
    """
    mask = True

    # Regularisation offset; currently zero, so both additions below leave
    # the values unchanged.
    reg = 0.
    latent = latent + (reg / det_rate)

    if obs is not None:
        mask = np.isfinite(obs) & (obs >= 0)
        obs = np.where(mask, obs, 0.0)
        obs += reg

    det_rate = np.broadcast_to(det_rate, latent.shape)

    mean = det_rate * latent
    scale = det_noise_scale * mean + 1
    # Keywords instead of positional (low, loc, scale): the positional order
    # of TruncatedNormal differs between NumPyro releases.
    d = dist.TruncatedNormal(low=0., loc=mean, scale=scale)

    numpyro.deterministic("mean_" + name, mean)

    # NOTE(review): `mask_array=` is the handler keyword this code was
    # written against; confirm it matches the installed NumPyro version.
    with numpyro.handlers.mask(mask_array=mask):
        y = numpyro.sample(name, d, obs=obs)

    return y
Пример #11
0
    def __call__(self,
                 T=50,
                 N=1e5,
                 T_future=0,
                 E_duration_est=4.0,
                 I_duration_est=2.0,
                 H_duration_est=10.0,
                 R0_est=3.0,
                 beta_shape=1.,
                 sigma_shape=100.,
                 gamma_shape=100.,
                 det_prob_est=0.3,
                 det_prob_conc=50.,
                 confirmed_dispersion=0.3,
                 death_dispersion=0.3,
                 rw_scale=2e-1,
                 forecast_rw_scale=0.,
                 drift_scale=None,
                 num_frozen=0,
                 rw_use_last=1,
                 confirmed=None,
                 death=None):
        '''
        Stochastic SEIR model. Draws random parameters and runs dynamics.

        Parameters, as used in this method:

        T: number of steps passed to ``self.dynamics`` for the fitted period.
        N: scales the upper bounds of the uniform priors on I0/E0/H0/D0 and
            is passed to ``SEIRDModel.seed`` (presumably the population
            size -- confirm).
        T_future: when positive, ``self.dynamics`` is run a second time for
            ``T_future + 1`` steps with suffix "_future" and the results are
            appended.
        E_duration_est, I_duration_est, H_duration_est: multiply the second
            argument of the Gamma priors on sigma, gamma / beta0 and
            death_rate respectively.
        R0_est: divides the second argument of the Gamma prior on beta0.
        beta_shape, sigma_shape, gamma_shape: first argument of the Gamma
            priors on beta0, sigma and gamma.
        det_prob_est, det_prob_conc: Beta prior on det_prob0 with arguments
            ``est * conc`` and ``(1 - est) * conc``.
        confirmed_dispersion, death_dispersion: ``loc`` of the truncated
            normal priors (low=0.1, scale=0.15) on the sampled dispersions.
        rw_scale: forwarded to ``self.dynamics`` inside ``params``;
            ``forecast_rw_scale`` takes its place for the forecast run.
        drift_scale: if not None, drift is sampled from Normal(0,
            drift_scale); otherwise drift is 0.
        num_frozen: forwarded to ``self.dynamics``.
        rw_use_last: number of trailing ``beta`` / ``det_prob`` entries
            averaged to initialise the forecast run.
        confirmed, death: optional observation series; the first entry and
            the cleaned first differences are observed separately
            (presumably cumulative counts -- confirm).

        Returns the tuple ``(beta, x, y, z, det_prob, death_prob)``.
        '''

        # Sample initial number of infected individuals
        I0 = numpyro.sample("I0", dist.Uniform(0, 0.02 * N))
        E0 = numpyro.sample("E0", dist.Uniform(0, 0.02 * N))
        H0 = numpyro.sample("H0", dist.Uniform(0, 1e-3 * N))
        D0 = numpyro.sample("D0", dist.Uniform(0, 1e-3 * N))

        # Sample dispersion parameters around specified values
        # (the argument names are rebound to the sampled values)
        death_dispersion = numpyro.sample(
            "death_dispersion",
            dist.TruncatedNormal(low=0.1, loc=death_dispersion, scale=0.15))

        confirmed_dispersion = numpyro.sample(
            "confirmed_dispersion",
            dist.TruncatedNormal(low=0.1, loc=confirmed_dispersion,
                                 scale=0.15))

        # Sample parameters
        sigma = numpyro.sample(
            "sigma", dist.Gamma(sigma_shape, sigma_shape * E_duration_est))

        gamma = numpyro.sample(
            "gamma", dist.Gamma(gamma_shape, gamma_shape * I_duration_est))

        beta0 = numpyro.sample(
            "beta0",
            dist.Gamma(beta_shape, beta_shape * I_duration_est / R0_est))

        det_prob0 = numpyro.sample(
            "det_prob0",
            dist.Beta(det_prob_est * det_prob_conc,
                      (1 - det_prob_est) * det_prob_conc))

        det_prob_d = numpyro.sample("det_prob_d",
                                    dist.Beta(.9 * 100, (1 - .9) * 100))

        death_prob = numpyro.sample("death_prob",
                                    dist.Beta(0.01 * 100, (1 - 0.01) * 100))
        #dist.Beta(0.02 * 1000, (1-0.02) * 1000))

        death_rate = numpyro.sample("death_rate",
                                    dist.Gamma(10, 10 * H_duration_est))

        if drift_scale is not None:
            drift = numpyro.sample("drift",
                                   dist.Normal(loc=0, scale=drift_scale))
        else:
            drift = 0

        # Initial state vector, also recorded as a deterministic site
        x0 = SEIRDModel.seed(N=N, I=I0, E=E0, H=H0, D=D0)
        numpyro.deterministic("x0", x0)

        # Split observations into first and rest
        if confirmed is None:
            confirmed0, confirmed = (None, None)
        else:
            confirmed0 = confirmed[0]
            confirmed = clean_daily_obs(onp.diff(confirmed))

        if death is None:
            death0, death = (None, None)
        else:
            death0 = death[0]
            death = clean_daily_obs(onp.diff(death))

        # First observation
        # (component 6 of x0 is observed against the first confirmed value,
        # component 5 against the first death value, each inside its own
        # scale handler)
        with numpyro.handlers.scale(scale=0.5):
            y0 = observe_nb2("dy0",
                             x0[6],
                             det_prob0,
                             confirmed_dispersion,
                             obs=confirmed0)

        with numpyro.handlers.scale(scale=2.0):
            z0 = observe_nb2("dz0",
                             x0[5],
                             det_prob_d,
                             death_dispersion,
                             obs=death0)

        # Parameter tuple consumed by self.dynamics
        params = (beta0, sigma, gamma, rw_scale, drift, det_prob0,
                  confirmed_dispersion, death_dispersion, death_prob,
                  death_rate, det_prob_d)

        beta, det_prob, x, y, z = self.dynamics(T,
                                                params,
                                                x0,
                                                num_frozen=num_frozen,
                                                confirmed=confirmed,
                                                death=death)

        # Prepend the initial state and first observations to the trajectories
        x = np.vstack((x0, x))
        y = np.append(y0, y)
        z = np.append(z0, z)

        if T_future > 0:

            # Forecast starts from the mean of the last `rw_use_last` values
            # of beta and det_prob and uses the forecast random-walk scale
            params = (beta[-rw_use_last:].mean(), sigma, gamma,
                      forecast_rw_scale, drift, det_prob[-rw_use_last:].mean(),
                      confirmed_dispersion, death_dispersion, death_prob,
                      death_rate, det_prob_d)

            # beta_f and det_rate_rw_f are not used after this call
            beta_f, det_rate_rw_f, x_f, y_f, z_f = self.dynamics(
                T_future + 1, params, x[-1, :], suffix="_future")

            x = np.vstack((x, x_f))
            y = np.append(y, y_f)
            z = np.append(z, z_f)

        return beta, x, y, z, det_prob, death_prob
Пример #12
0
    def __call__(self,
                 T=50,
                 N=1e5,
                 T_future=0,
                 E_duration_est=4.0,
                 I_duration_est=2.0,
                 R0_est=3.0,
                 beta_shape=1.,
                 sigma_shape=100.,
                 gamma_shape=100.,
                 det_prob_est=0.3,
                 det_prob_conc=50.,
                 confirmed_dispersion=0.3,
                 death_dispersion=0.3,
                 rw_scale=2e-1,
                 forecast_rw_scale=0.,
                 drift_scale=None,
                 num_frozen=0,
                 rw_use_last=1,
                 confirmed=None,
                 death=None,
                 place_data=None):
        '''
        Stochastic SEIR model. Draws random parameters and runs dynamics.

        R0 comes from a GLM over the time index ``t`` and the initial
        transmission rate is derived as ``R0 * gamma``.

        Parameters, as used in this method:

        T: number of steps passed to ``self.dynamics`` for the fitted period;
            also the first time index of the forecast design.
        N: scales the upper bounds of the uniform priors on I0/E0/H0/D0 and
            is passed to ``SEIRDModel.seed`` (presumably the population
            size -- confirm).
        T_future: when positive, R0 is sampled for time indices
            ``T .. T + T_future - 1`` and ``self.dynamics`` is run again for
            ``T_future + 1`` steps with suffix "_future".
        E_duration_est, I_duration_est: multiply the second argument of the
            Gamma priors on sigma and gamma respectively.
        sigma_shape, gamma_shape: first argument of those Gamma priors.
        det_prob_est, det_prob_conc: Beta prior on det_prob0 with arguments
            ``est * conc`` and ``(1 - est) * conc``.
        confirmed_dispersion, death_dispersion: ``loc`` of the truncated
            normal priors (low=0.1, scale=0.15) on the sampled dispersions.
        rw_scale: forwarded to ``self.dynamics`` inside ``params``;
            ``forecast_rw_scale`` takes its place for the forecast run.
        drift_scale: if not None, drift is sampled from Normal(0,
            drift_scale); otherwise drift is 0.
        num_frozen: forwarded to ``self.dynamics``.
        rw_use_last: number of trailing ``det_prob`` entries averaged to
            initialise the forecast run.
        confirmed, death: optional observation series; the first entry and
            the cleaned first differences are used separately (presumably
            cumulative counts -- confirm).
        R0_est, beta_shape, place_data: accepted for interface compatibility
            but not used by this method.

        Returns the tuple ``(beta, x, y, det_prob, death_prob)``.
        '''

        # Sample initial number of infected individuals
        I0 = numpyro.sample("I0", dist.Uniform(0, 0.02 * N))
        E0 = numpyro.sample("E0", dist.Uniform(0, 0.02 * N))
        H0 = numpyro.sample("H0", dist.Uniform(0, 1e-3 * N))
        D0 = numpyro.sample("D0", dist.Uniform(0, 1e-3 * N))

        # Sample dispersion parameters around specified values
        # (the argument names are rebound to the sampled values)
        death_dispersion = numpyro.sample(
            "death_dispersion",
            dist.TruncatedNormal(low=0.1, loc=death_dispersion, scale=0.15))

        confirmed_dispersion = numpyro.sample(
            "confirmed_dispersion",
            dist.TruncatedNormal(low=0.1, loc=confirmed_dispersion,
                                 scale=0.15))

        # Split observations into first and rest, and build the GLM design
        # data: one time index per daily confirmed observation, or a single
        # index when there are no observations.
        if confirmed is None:
            confirmed0, confirmed = (None, None)
            d = {'t': [0]}
        else:
            confirmed0 = confirmed[0]
            confirmed = clean_daily_obs(onp.diff(confirmed))
            d = {'t': onp.arange(len(confirmed))}

        if death is None:
            death0, death = (None, None)
        else:
            death0 = death[0]
            death = clean_daily_obs(onp.diff(death))

        # The former `place_data = pd.DataFrame()` statement was removed: it
        # overwrote the caller's argument with a value that was never read.
        R0_glm = GLM("1 + cr(t,df=3)",
                     d,
                     log_link,
                     partial(Gamma, var=1),
                     prior=dist.Normal(0, 1),
                     guess=3.5,
                     name="R0")
        # NOTE(review): `(-1)` is the integer -1, not a one-element tuple;
        # confirm which form GLM.sample expects for `shape`.
        R0 = R0_glm.sample(shape=(-1))[0]

        # Sample parameters
        sigma = numpyro.sample(
            "sigma", dist.Gamma(sigma_shape, sigma_shape * E_duration_est))

        gamma = numpyro.sample(
            "gamma", dist.Gamma(gamma_shape, gamma_shape * I_duration_est))

        # beta0 is derived from the GLM-based R0 instead of being sampled
        # from its own Gamma prior.
        beta0 = R0 * gamma

        det_prob0 = numpyro.sample(
            "det_prob0",
            dist.Beta(det_prob_est * det_prob_conc,
                      (1 - det_prob_est) * det_prob_conc))

        det_prob_d = numpyro.sample("det_prob_d",
                                    dist.Beta(.9 * 100, (1 - .9) * 100))

        death_prob = numpyro.sample("death_prob",
                                    dist.Beta(.01 * 100, (1 - .01) * 100))

        death_rate = numpyro.sample("death_rate", dist.Gamma(10, 10 * 10))

        if drift_scale is not None:
            drift = numpyro.sample("drift",
                                   dist.Normal(loc=0, scale=drift_scale))
        else:
            drift = 0

        # Initial state vector, also recorded as a deterministic site
        x0 = SEIRDModel.seed(N=N, I=I0, E=E0, H=H0, D=D0)
        numpyro.deterministic("x0", x0)

        # First observation: component 6 of x0 against the first confirmed
        # value, component 5 against the first death value, each inside its
        # own scale handler.
        with numpyro.handlers.scale(scale_factor=0.5):
            y0 = observe_normal("dy0",
                                x0[6],
                                det_prob0,
                                confirmed_dispersion,
                                obs=confirmed0)

        with numpyro.handlers.scale(scale_factor=2.0):
            z0 = observe_normal("dz0",
                                x0[5],
                                det_prob_d,
                                death_dispersion,
                                obs=death0)

        # Parameter tuple consumed by self.dynamics
        params = (beta0, sigma, gamma, rw_scale, drift, det_prob0,
                  confirmed_dispersion, death_dispersion, death_prob,
                  death_rate, det_prob_d)

        beta, det_prob, x, y = self.dynamics(T,
                                             params,
                                             x0,
                                             num_frozen=num_frozen,
                                             confirmed=confirmed,
                                             death=death)

        # Prepend the initial state and first observation to the trajectories
        x = np.vstack((x0, x))
        y = np.append(y0, y)

        if T_future > 0:
            # R0 for the forecast window comes from the same GLM evaluated
            # at the future time indices.
            d_future = {'t': onp.arange(T, T + T_future)}
            R0_future = R0_glm.sample(d_future, name="R0_future",
                                      shape=(-1))[0]
            beta_future = R0_future * gamma
            params = (beta_future, sigma, gamma, forecast_rw_scale, drift,
                      det_prob[-rw_use_last:].mean(), confirmed_dispersion,
                      death_dispersion, death_prob, death_rate, det_prob_d)

            # beta_f and det_rate_rw_f are not used after this call
            beta_f, det_rate_rw_f, x_f, y_f = self.dynamics(T_future + 1,
                                                            params,
                                                            x[-1, :],
                                                            suffix="_future")

            x = np.vstack((x, x_f))
            y = np.append(y, y_f)

        return beta, x, y, det_prob, death_prob
Пример #13
0
def spire_model_CIGALE(priors, sed_prior, params):
    """
    numpyro model for SPIRE maps using cigale emulator

    :param priors: list of xid+ SPIRE prior objects, one per band, in the
        order PSW, PMW, PLW
    :type priors: list
    :param sed_prior: xid+ SED prior class; this model reads ``params_mu``,
        ``params_sig`` and the ``emulator`` dict (keys ``net_apply`` and
        ``params``) from it
    :type sed_prior:
    :param params: hyper-parameters of the redshift-sfr relation, with keys
        ``m_mu``, ``m_sig``, ``c_mu``, ``c_sig`` and ``sfr_disp``
    :type params: dict
    :return: nothing; the model is defined through its numpyro sample sites
    :rtype: None
    """

    # get pointing matrices in usable format
    pointing_matrices = [([p.amat_row, p.amat_col], p.amat_data)
                         for p in priors]

    # get background priors for maps
    bkg_mu = np.asarray([p.bkg[0] for p in priors]).T
    bkg_sig = np.asarray([p.bkg[1] for p in priors]).T

    # background priors
    with numpyro.plate('bands', len(priors)):
        bkg = numpyro.sample('bkg', dist.Normal(bkg_mu, bkg_sig))

    # redshift-sfr relation parameters
    m = numpyro.sample('m', dist.Normal(params['m_mu'], params['m_sig']))
    c = numpyro.sample('c', dist.Normal(params['c_mu'], params['c_sig']))

    # sfr dispersion parameter
    sfr_sig = numpyro.sample('sfr_sig', dist.HalfNormal(params['sfr_disp']))

    # sample parameters for each source (treated as conditionally
    # independent, hence the plate)
    with numpyro.plate('nsrc', priors[0].nsrc):
        # use truncated normal for redshift, with mean and sigma from prior.
        # Keywords instead of positional (low, loc, scale): the positional
        # order of TruncatedNormal differs between NumPyro releases.
        redshift = numpyro.sample(
            'redshift',
            dist.TruncatedNormal(low=0.01,
                                 loc=sed_prior.params_mu[:, 1],
                                 scale=sed_prior.params_sig[:, 1]))
        # use beta distribution for AGN as a fraction
        agn = numpyro.sample('agn', dist.Beta(1.0, 3.0))

        # sfr is sampled directly around the redshift relation; a
        # TransformReparam-based formulation was tried here before and is
        # disabled.
        sfr = numpyro.sample(
            'sfr',
            dist.Normal(redshift * m + c, jnp.full(priors[0].nsrc, sfr_sig)))

    # stack per-source parameters into the vector used by the emulator
    # (kept in a separate name so the `params` argument is not rebound)
    emulator_inputs = jnp.vstack(
        (sfr[None, :], agn[None, :], redshift[None, :])).T
    # Use emulator to get fluxes. As emulator provides log flux, convert.
    src_f = jnp.exp(sed_prior.emulator['net_apply'](
        sed_prior.emulator['params'], emulator_inputs))

    # create model map by multiplying fluxes by pointing matrix and adding
    # background
    db_hat_psw = sp_matmul(pointing_matrices[0], src_f[:, 0][:, None],
                           priors[0].snpix).reshape(-1) + bkg[0]
    db_hat_pmw = sp_matmul(pointing_matrices[1], src_f[:, 1][:, None],
                           priors[1].snpix).reshape(-1) + bkg[1]
    db_hat_plw = sp_matmul(pointing_matrices[2], src_f[:, 2][:, None],
                           priors[2].snpix).reshape(-1) + bkg[2]

    # for each band, condition on data
    with numpyro.plate('psw_pixels', priors[0].snim.size):
        numpyro.sample("obs_psw",
                       dist.Normal(db_hat_psw, priors[0].snim),
                       obs=priors[0].sim)
    with numpyro.plate('pmw_pixels', priors[1].snim.size):
        numpyro.sample("obs_pmw",
                       dist.Normal(db_hat_pmw, priors[1].snim),
                       obs=priors[1].sim)
    with numpyro.plate('plw_pixels', priors[2].snim.size):
        numpyro.sample("obs_plw",
                       dist.Normal(db_hat_plw, priors[2].snim),
                       obs=priors[2].sim)