Example #1
import torch
import pyro
import pyro.distributions as dist


def model(y):
    d = y.shape[1]
    N = y.shape[0]
    options = dict(dtype=y.dtype, device=y.device)
    # Vector of variances for each of the d variables
    theta = pyro.sample("theta", dist.HalfCauchy(torch.ones(d, **options)))
    # Lower cholesky factor of a correlation matrix
    # concentration = 1 implies a uniform distribution over correlation matrices
    concentration = torch.ones((), **options)
    L_omega = pyro.sample("L_omega", dist.LKJCholesky(d, concentration))
    # Lower cholesky factor of the covariance matrix
    L_Omega = torch.mm(torch.diag(theta.sqrt()), L_omega)
    # For inference with SVI, one might prefer to use torch.bmm(theta.sqrt().diag_embed(), L_omega)

    # Vector of expectations
    mu = torch.zeros(d, **options)

    with pyro.plate("observations", N):
        obs = pyro.sample("obs",
                          dist.MultivariateNormal(mu, scale_tril=L_Omega),
                          obs=y)
    return obs
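
A minimal usage sketch for the model above (not part of the original snippet); the data shape and MCMC settings are illustrative assumptions:

from pyro.infer import MCMC, NUTS

y = torch.randn(100, 3)  # synthetic data: N=100 observations of d=3 variables
mcmc = MCMC(NUTS(model), num_samples=200, warmup_steps=100)
mcmc.run(y)
theta_samples = mcmc.get_samples()["theta"]  # posterior draws of theta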
Example #2
def model(data, params):
    # initialize data
    N = data["N"]
    encouraged = data["encouraged"]
    setting = data["setting"]
    site = data["site"]
    pretest = data["pretest"]
    watched = data["watched"]
    # initialize transformed data
    site2 = data["site2"]
    site3 = data["site3"]
    site4 = data["site4"]
    site5 = data["site5"]

    # init parameters
    beta = params["beta"]
    sigma = pyro.sample('sigma', dist.HalfCauchy(2.5))
    # initialize transformed parameters
    # model block
    with pyro.plate('data', N):
        mean = (beta[0] + beta[1] * encouraged + beta[2] * pretest +
                beta[3] * site2 + beta[4] * site3 + beta[5] * site4 +
                beta[6] * site5 + beta[7] * setting)
        watched = pyro.sample('obs', dist.Normal(mean, sigma), obs=watched)
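
This snippet follows a Stan-to-Pyro translation pattern: data and fixed parameters arrive as dictionaries. A hypothetical invocation (every size and value below is made up for illustration):

import torch
import pyro.poutine as poutine

N = 50  # hypothetical number of observations
data = {
    "N": N,
    "encouraged": torch.randint(0, 2, (N,)).float(),
    "setting": torch.randint(0, 2, (N,)).float(),
    "site": torch.randint(1, 6, (N,)),
    "pretest": torch.randn(N),
    "watched": torch.rand(N),
    # site indicator columns (the Stan program's transformed data)
    "site2": torch.zeros(N), "site3": torch.zeros(N),
    "site4": torch.zeros(N), "site5": torch.zeros(N),
}
params = {"beta": torch.zeros(8)}
trace = poutine.trace(model).get_trace(data, params)
print(trace.log_prob_sum())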
Example #3
    def __call__(self):
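        # NOTE: a method of a larger time-series model class; configuration
        # (smoothing inputs, regression penalty type, regressor matrices,
        # seasonality, etc.) comes from attributes set on `self`.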
        response = self.response
        num_of_obs = self.num_of_obs
        extra_out = {}

        # smoothing params
        if self.lev_sm_input < 0:
            lev_sm = pyro.sample("lev_sm", dist.Uniform(0, 1))
        else:
            lev_sm = torch.tensor(self.lev_sm_input, dtype=torch.double)
            extra_out['lev_sm'] = lev_sm
        if self.slp_sm_input < 0:
            slp_sm = pyro.sample("slp_sm", dist.Uniform(0, 1))
        else:
            slp_sm = torch.tensor(self.slp_sm_input, dtype=torch.double)
            extra_out['slp_sm'] = slp_sm

        # residual tuning parameters
        nu = pyro.sample("nu", dist.Uniform(self.min_nu, self.max_nu))

        # prior for residuals
        obs_sigma = pyro.sample("obs_sigma", dist.HalfCauchy(self.cauchy_sd))

        # regression parameters
        if self.num_of_pr == 0:
            pr = torch.zeros(num_of_obs)
            pr_beta = pyro.deterministic("pr_beta", torch.zeros(0))
        else:
            with pyro.plate("pr", self.num_of_pr):
                # fixed scale ridge
                if self.reg_penalty_type == 0:
                    pr_sigma = self.pr_sigma_prior
                # auto scale ridge
                elif self.reg_penalty_type == 2:
                    # weak prior for sigma
                    pr_sigma = pyro.sample(
                        "pr_sigma", dist.HalfCauchy(self.auto_ridge_scale))
                # case when it is not lasso
                if self.reg_penalty_type != 1:
                    # weak prior for betas
                    pr_beta = pyro.sample(
                        "pr_beta",
                        dist.FoldedDistribution(
                            dist.Normal(self.pr_beta_prior, pr_sigma)))
                else:
                    pr_beta = pyro.sample(
                        "pr_beta",
                        dist.FoldedDistribution(
                            dist.Laplace(self.pr_beta_prior,
                                         self.lasso_scale)))
            pr = pr_beta @ self.pr_mat.transpose(-1, -2)

        if self.num_of_nr == 0:
            nr = torch.zeros(num_of_obs)
            nr_beta = pyro.deterministic("nr_beta", torch.zeros(0))
        else:
            with pyro.plate("nr", self.num_of_nr):
                # fixed scale ridge
                if self.reg_penalty_type == 0:
                    nr_sigma = self.nr_sigma_prior
                # auto scale ridge
                elif self.reg_penalty_type == 2:
                    # weak prior for sigma
                    nr_sigma = pyro.sample(
                        "nr_sigma", dist.HalfCauchy(self.auto_ridge_scale))
                # case when it is not lasso
                if self.reg_penalty_type != 1:
                    # weak prior for betas
                    nr_beta = pyro.sample(
                        "nr_beta",
                        dist.FoldedDistribution(
                            dist.Normal(self.nr_beta_prior, nr_sigma)))
                else:
                    nr_beta = pyro.sample(
                        "nr_beta",
                        dist.FoldedDistribution(
                            dist.Laplace(self.nr_beta_prior,
                                         self.lasso_scale)))
            nr = nr_beta @ self.nr_mat.transpose(-1, -2)

        if self.num_of_rr == 0:
            rr = torch.zeros(num_of_obs)
            rr_beta = pyro.deterministic("rr_beta", torch.zeros(0))
        else:
            with pyro.plate("rr", self.num_of_rr):
                # fixed scale ridge
                if self.reg_penalty_type == 0:
                    rr_sigma = self.rr_sigma_prior
                # auto scale ridge
                elif self.reg_penalty_type == 2:
                    # weak prior for sigma
                    rr_sigma = pyro.sample(
                        "rr_sigma", dist.HalfCauchy(self.auto_ridge_scale))
                # case when it is not lasso
                if self.reg_penalty_type != 1:
                    # weak prior for betas
                    rr_beta = pyro.sample(
                        "rr_beta", dist.Normal(self.rr_beta_prior, rr_sigma))
                else:
                    rr_beta = pyro.sample(
                        "rr_beta",
                        dist.Laplace(self.rr_beta_prior, self.lasso_scale))
            rr = rr_beta @ self.rr_mat.transpose(-1, -2)

        # a hack to avoid picking up a stray size-1 dimension introduced by the rr_beta and pr_beta sampling
        r = pr + nr + rr
        if r.dim() > 1:
            r = r.unsqueeze(-2)

        # trend parameters
        # local trend proportion
        lt_coef = pyro.sample("lt_coef", dist.Uniform(0, 1))
        # global trend proportion
        gt_coef = pyro.sample("gt_coef", dist.Uniform(-0.5, 0.5))
        # global trend parameter
        gt_pow = pyro.sample("gt_pow", dist.Uniform(0, 1))

        # seasonal parameters
        if self.is_seasonal:
            # seasonality smoothing parameter
            if self.sea_sm_input < 0:
                sea_sm = pyro.sample("sea_sm", dist.Uniform(0, 1))
            else:
                sea_sm = torch.tensor(self.sea_sm_input, dtype=torch.double)
                extra_out['sea_sm'] = sea_sm

            # initial seasonality
            # 33% lift is with 1 sd prob.
            init_sea = pyro.sample(
                "init_sea",
                dist.Normal(0, 0.33).expand([self.seasonality]).to_event(1))
            init_sea = init_sea - init_sea.mean(-1, keepdim=True)

        b = [None] * num_of_obs  # slope
        l = [None] * num_of_obs  # level
        if self.is_seasonal:
            s = [None] * (num_of_obs + self.seasonality)
            for t in range(self.seasonality):
                s[t] = init_sea[..., t]
            s[self.seasonality] = init_sea[..., 0]
        else:
            s = [torch.tensor(0.)] * num_of_obs

        # states initial condition
        b[0] = torch.zeros_like(slp_sm)
        if self.is_seasonal:
            l[0] = response[0] - r[..., 0] - s[0]
        else:
            l[0] = response[0] - r[..., 0]

        # update process
        for t in range(1, num_of_obs):
            # this update equation uses l[t-1] ONLY,
            # intentionally different from the Holt-Winters form;
            # this change is suggested by Slawek's original SLGT model
            l[t] = lev_sm * (response[t] - s[t] -
                             r[..., t]) + (1 - lev_sm) * l[t - 1]
            b[t] = slp_sm * (l[t] - l[t - 1]) + (1 - slp_sm) * b[t - 1]
            if self.is_seasonal:
                s[t + self.seasonality] = \
                    sea_sm * (response[t] - l[t] - r[..., t]) + (1 - sea_sm) * s[t]

        # evaluation process
        # vectorize as much math as possible
        for lst in [b, l, s]:
            # torch.stack requires all items to have the same shape, but the
            # initial items of our lists may not have batch_shape, so we expand.
            lst[0] = lst[0].expand_as(lst[-1])
        b = torch.stack(b, dim=-1).reshape(b[0].shape[:-1] + (-1, ))
        l = torch.stack(l, dim=-1).reshape(l[0].shape[:-1] + (-1, ))
        s = torch.stack(s, dim=-1).reshape(s[0].shape[:-1] + (-1, ))

        lgt_sum = l + gt_coef * l.abs()**gt_pow + lt_coef * b
        lgt_sum = torch.cat([l[..., :1], lgt_sum[..., :-1]],
                            dim=-1)  # shift by 1
        # a hack here as well to get rid of the extra "1" in r.shape
        if r.dim() >= 2:
            r = r.squeeze(-2)
        yhat = lgt_sum + s[..., :num_of_obs] + r

        with pyro.plate("response_plate", num_of_obs - 1):
            pyro.sample("response",
                        dist.StudentT(nu, yhat[..., 1:], obs_sigma),
                        obs=response[1:])

        # we care about beta, not pr_beta, nr_beta, ...
        extra_out['beta'] = torch.cat([pr_beta, nr_beta, rr_beta], dim=-1)

        extra_out.update({'b': b, 'l': l, 's': s, 'lgt_sum': lgt_sum})
        return extra_out
Example #4
def model(x, y, alt_av, alt_ids_cuda):
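    # NOTE: this snippet relies on names defined at module level in the
    # original script: diagonal_alpha, diagonal_beta_mu, non_mix_params,
    # mix_params, num_resp, num_alternatives, beta_to_params_map, zeros_vec,
    # T, and BATCH_SIZE.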
    # global parameters in the model
    if diagonal_alpha:
        alpha_mu = pyro.sample(
            "alpha",
            dist.Normal(torch.zeros(len(non_mix_params), device=x.device),
                        1).to_event(1))
    else:
        alpha_mu = pyro.sample(
            "alpha",
            dist.MultivariateNormal(
                torch.zeros(len(non_mix_params), device=x.device),
                scale_tril=torch.tril(
                    1 * torch.eye(len(non_mix_params), device=x.device))))

    if diagonal_beta_mu:
        beta_mu = pyro.sample(
            "beta_mu",
            dist.Normal(torch.zeros(len(mix_params), device=x.device),
                        1.).to_event(1))
    else:
        beta_mu = pyro.sample(
            "beta_mu",
            dist.MultivariateNormal(
                torch.zeros(len(mix_params), device=x.device),
                scale_tril=torch.tril(
                    1 * torch.eye(len(mix_params), device=x.device))))

    # Vector of variances for each of the d variables
    theta = pyro.sample(
        "theta",
        dist.HalfCauchy(
            10. * torch.ones(len(mix_params), device=x.device)).to_event(1))
    # Lower cholesky factor of a correlation matrix
    # eta = 1 implies a uniform distribution over correlation matrices
    eta = torch.ones(1, device=x.device)
    L_omega = pyro.sample("L_omega",
                          dist.LKJCorrCholesky(len(mix_params), eta))
    # Lower cholesky factor of the covariance matrix
    L_Omega = torch.mm(torch.diag(theta.sqrt()), L_omega)

    # local parameters in the model
    random_params = pyro.sample(
        "beta_resp",
        dist.MultivariateNormal(beta_mu.repeat(num_resp, 1),
                                scale_tril=L_Omega).to_event(1))

    # vector of respondent parameters: global + local (respondent)
    params_resp = torch.cat([alpha_mu.repeat(num_resp, 1), random_params],
                            dim=-1)

    # vector of betas of MXL (may repeat the same learnable parameter multiple times; random + fixed effects)
    beta_resp = torch.cat([
        params_resp[:, beta_to_params_map[i]] for i in range(num_alternatives)
    ],
                          dim=-1)

    with pyro.plate("locals", len(x), subsample_size=BATCH_SIZE) as ind:

        with pyro.plate("data_resp", T):
            # compute utilities for each alternative
            utilities = torch.scatter_add(
                zeros_vec[:, ind, :], 2,
                alt_ids_cuda[ind, :, :].transpose(0, 1),
                torch.mul(x[ind, :, :].transpose(0, 1), beta_resp[ind, :]))

            # adjust utility for unavailable alternatives
            utilities += alt_av[ind, :, :].transpose(0, 1)

            # likelihood
            pyro.sample("obs",
                        dist.Categorical(logits=utilities),
                        obs=y[ind, :].transpose(0, 1))
Example #5
def init_params(data):
    params = {}
    params["sigma_a1"] = pyro.sample("sigma_a1", dist.HalfCauchy(2.5))
    params["sigma_a2"] = pyro.sample("sigma_a2", dist.HalfCauchy(2.5))
    params["sigma_y"] = pyro.sample("sigma_y", dist.HalfCauchy(2.5))
    return params
Example #6
def init_params(data):
    params = {}
    params["sigma_alpha"] = pyro.sample("sigma_alpha", dist.HalfCauchy(2.5))
    params["sigma_beta"] = pyro.sample("sigma_beta", dist.HalfCauchy(2.5))

    return params
Example #7
def torus_dbn(phis=None,
              psis=None,
              lengths=None,
              num_sequences=None,
              num_states=55,
              prior_conc=0.1,
              prior_loc=0.0,
              prior_length_shape=100.,
              prior_length_rate=100.,
              prior_kappa_min=10.,
              prior_kappa_max=1000.):
    # From https://pyro.ai/examples/hmm.html
    with ignore_jit_warnings():
        if lengths is not None:
            assert num_sequences is None
            num_sequences = int(lengths.shape[0])
        else:
            assert num_sequences is not None
    transition_probs = pyro.sample(
        'transition_probs',
        dist.Dirichlet(
            torch.ones(num_states, num_states, dtype=torch.float) *
            num_states).to_event(1))
    length_shape = pyro.sample('length_shape',
                               dist.HalfCauchy(prior_length_shape))
    length_rate = pyro.sample('length_rate',
                              dist.HalfCauchy(prior_length_rate))
    phi_locs = pyro.sample(
        'phi_locs',
        dist.VonMises(
            torch.ones(num_states, dtype=torch.float) * prior_loc,
            torch.ones(num_states, dtype=torch.float) *
            prior_conc).to_event(1))
    phi_kappas = pyro.sample(
        'phi_kappas',
        dist.Uniform(
            torch.ones(num_states, dtype=torch.float) * prior_kappa_min,
            torch.ones(num_states, dtype=torch.float) *
            prior_kappa_max).to_event(1))
    psi_locs = pyro.sample(
        'psi_locs',
        dist.VonMises(
            torch.ones(num_states, dtype=torch.float) * prior_loc,
            torch.ones(num_states, dtype=torch.float) *
            prior_conc).to_event(1))
    psi_kappas = pyro.sample(
        'psi_kappas',
        dist.Uniform(
            torch.ones(num_states, dtype=torch.float) * prior_kappa_min,
            torch.ones(num_states, dtype=torch.float) *
            prior_kappa_max).to_event(1))
    element_plate = pyro.plate('elements', 1, dim=-1)
    with pyro.plate('sequences', num_sequences, dim=-2) as batch:
        if lengths is not None:
            lengths = lengths[batch]
            obs_length = lengths.float().unsqueeze(-1)
        else:
            obs_length = None
        state = 0
        sam_lengths = pyro.sample('length',
                                  dist.TransformedDistribution(
                                      dist.GammaPoisson(
                                          length_shape, length_rate),
                                      AffineTransform(0., 1.)),
                                  obs=obs_length)
        if lengths is None:
            lengths = sam_lengths.squeeze(-1).long()
        for t in pyro.markov(range(lengths.max())):
            with poutine.mask(mask=(t < lengths).unsqueeze(-1)):
                state = pyro.sample(f'state_{t}',
                                    dist.Categorical(transition_probs[state]),
                                    infer={'enumerate': 'parallel'})
                if phis is not None:
                    obs_phi = Vindex(phis)[batch, t].unsqueeze(-1)
                else:
                    obs_phi = None
                if psis is not None:
                    obs_psi = Vindex(psis)[batch, t].unsqueeze(-1)
                else:
                    obs_psi = None
                with element_plate:
                    pyro.sample(f'phi_{t}',
                                dist.VonMises(phi_locs[state],
                                              phi_kappas[state]),
                                obs=obs_phi)
                    pyro.sample(f'psi_{t}',
                                dist.VonMises(psi_locs[state],
                                              psi_kappas[state]),
                                obs=obs_psi)
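
Because the hidden states are sampled with infer={'enumerate': 'parallel'}, this model is meant to be paired with an enumerating ELBO, as in the Pyro HMM example it links to. A minimal training sketch, assuming synthetic angle data and a MAP (AutoDelta) guide:

import math

import torch
import pyro
import pyro.poutine as poutine
from pyro.infer import SVI, TraceEnum_ELBO
from pyro.infer.autoguide import AutoDelta

num_sequences, max_len = 8, 20
phis = (torch.rand(num_sequences, max_len) - 0.5) * 2 * math.pi
psis = (torch.rand(num_sequences, max_len) - 0.5) * 2 * math.pi
lengths = torch.full((num_sequences,), max_len, dtype=torch.long)

# MAP guide over the continuous sites; the discrete states stay hidden from
# the guide and are summed out by TraceEnum_ELBO.
guide = AutoDelta(
    poutine.block(
        torus_dbn,
        hide_fn=lambda msg: msg["name"] is not None and msg["name"].startswith("state")))
svi = SVI(torus_dbn, guide, pyro.optim.Adam({"lr": 0.01}),
          TraceEnum_ELBO(max_plate_nesting=2))
for step in range(100):
    loss = svi.step(phis, psis, lengths)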
Example #8
def model(data):
    mu = pyro.sample('mu', dist.Normal(0., 1.))
    sigma = pyro.sample('sigma', dist.HalfCauchy(5.))
    with pyro.plate('observe_data'):
        pyro.sample('obs', dist.Normal(mu, sigma), obs=data)
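
A minimal sketch of running the model above on synthetic data (the dataset and sampler settings are illustrative):

import torch
from pyro.infer import MCMC, NUTS

data = 3.0 + 0.5 * torch.randn(100)
mcmc = MCMC(NUTS(model), num_samples=300, warmup_steps=100)
mcmc.run(data)
samples = mcmc.get_samples()
print(samples["mu"].mean(), samples["sigma"].mean())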
Example #9
    def model(self, data):
        '''
        Define the parameters
        '''
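        # NOTE: `fun` is assumed to be `torch.nn.functional`, imported
        # elsewhere in the original module (fun.pad is used below).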
        n_ind = data['n_ind']
        n_trt = data['n_trt']
        n_tms = data['n_tms']
        n_mrk = data['n_mrk']

        n_prs = n_trt * n_tms * n_mrk

        plt_ind = pyro.plate('individuals', n_ind, dim=-3)
        plt_trt = pyro.plate('treatments', n_trt, dim=-2)
        plt_tms = pyro.plate('times', n_tms, dim=-1)

        pars = {}
        # covariance factors
        with plt_tms:
            # learning dt time step sizes
            # if k(t1,t2) is independent of time, can instead learn scales and variances for RBF kernels that use data['time_vals']
            pars['dt0'] = pyro.sample('dt0', dist.Normal(0, 1))
            pars['dt1'] = pyro.sample('dt1', dist.Normal(0, 1))

        pars['theta_trt0'] = pyro.sample('theta_trt0',
                                         dist.HalfCauchy(torch.ones(n_trt)))
        pars['theta_mrk0'] = pyro.sample('theta_mrk0',
                                         dist.HalfCauchy(torch.ones(n_mrk)))
        pars['theta_trt1'] = pyro.sample('theta_trt1',
                                         dist.HalfCauchy(torch.ones(n_trt)))
        pars['L_omega_trt1'] = pyro.sample(
            'L_omega_trt1', dist.LKJCorrCholesky(n_trt, torch.ones(1)))
        pars['theta_mrk1'] = pyro.sample('theta_mrk1',
                                         dist.HalfCauchy(torch.ones(n_mrk)))
        pars['L_omega_mrk1'] = pyro.sample(
            'L_omega_mrk1', dist.LKJCorrCholesky(n_mrk, torch.ones(1)))

        times0 = fun.pad(torch.cumsum(pars['dt0'].exp().log1p(), 0), (1, 0),
                         value=0)[:-1].unsqueeze(1)
        times1 = fun.pad(torch.cumsum(pars['dt1'].exp().log1p(), 0), (1, 0),
                         value=0)[:-1].unsqueeze(1)
        cov_t0 = (-torch.cdist(times0, times0)).exp()
        cov_t1 = (-torch.cdist(times1, times1)).exp()

        cov_i0 = pars['theta_trt0'].diag()
        L_Omega_trt = torch.mm(torch.diag(pars['theta_trt1'].sqrt()),
                               pars['L_omega_trt1'])
        cov_i1 = L_Omega_trt.mm(L_Omega_trt.t())

        cov_m0 = pars['theta_mrk0'].diag()
        L_Omega_mrk = torch.mm(torch.diag(pars['theta_mrk1'].sqrt()),
                               pars['L_omega_mrk1'])
        cov_m1 = L_Omega_mrk.mm(L_Omega_mrk.t())

        # kronecker product of the factors
        cov_itm0 = torch.einsum('ij,tu,mn->itmjun',
                                [cov_i0, cov_t0, cov_m0]).view(n_prs, n_prs)
        cov_itm1 = torch.einsum('ij,tu,mn->itmjun',
                                [cov_i1, cov_t1, cov_m1]).view(n_prs, n_prs)

        # global and individual level params of each marker, treatment, and time point
        pars['glb'] = pyro.sample(
            'glb', dist.MultivariateNormal(torch.zeros(n_prs), cov_itm0))
        with plt_ind:
            pars['ind'] = pyro.sample(
                'ind', dist.MultivariateNormal(torch.zeros(n_prs), cov_itm1))

        # observation noise, time series bias and scale
        pars['noise_scale'] = pyro.sample('noise_scale',
                                          dist.HalfCauchy(torch.ones(n_mrk)))
        pars['t0_scale'] = pyro.sample('t0_scale',
                                       dist.HalfCauchy(torch.ones(n_mrk)))
        with plt_ind:
            pars['t0'] = pyro.sample(
                't0',
                dist.MultivariateNormal(torch.zeros(n_mrk),
                                        pars['t0_scale'].diag()))
            with plt_trt, plt_tms:
                pars['noise'] = pyro.sample(
                    'noise',
                    dist.MultivariateNormal(torch.zeros(n_mrk),
                                            pars['noise_scale'].diag()))

        # likelihood of the data
        distr = self.get_distr(data, pars)
        pyro.sample('obs', distr, obs=data['Y'])
Example #10
def spire_model(priors, sub=1):
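    # Each element of `priors` is expected to provide: nsrc, sim (flattened
    # map pixels), snim (per-pixel noise map), bkg = (mean, sigma) of the
    # background prior, prior_flux_lower / prior_flux_upper, snpix (number of
    # map pixels), and amat_row / amat_col / amat_data (sparse pointing-matrix
    # triplets).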

    if len(priors) != 3:
        raise ValueError("spire_model expects exactly three priors (psw, pmw, plw)")
    band_plate = pyro.plate('bands', len(priors), dim=-2)
    src_plate = pyro.plate('nsrc', priors[0].nsrc, dim=-1)
    psw_plate = pyro.plate('psw_pixels',
                           priors[0].sim.size,
                           dim=-3,
                           subsample_size=np.rint(
                               sub * priors[0].sim.size).astype(int))
    pmw_plate = pyro.plate('pmw_pixels',
                           priors[1].sim.size,
                           dim=-3,
                           subsample_size=np.rint(
                               sub * priors[1].sim.size).astype(int))
    plw_plate = pyro.plate('plw_pixels',
                           priors[2].sim.size,
                           dim=-3,
                           subsample_size=np.rint(
                               sub * priors[2].sim.size).astype(int))
    pointing_matrices = [
        torch.sparse_coo_tensor(torch.LongTensor([p.amat_row, p.amat_col]),
                                torch.Tensor(p.amat_data),
                                torch.Size([p.snpix, p.nsrc])) for p in priors
    ]

    bkg_prior = torch.tensor([p.bkg[0] for p in priors])
    bkg_prior_sig = torch.tensor([p.bkg[1] for p in priors])
    nsrc = priors[0].nsrc

    f_low_lim = torch.tensor([p.prior_flux_lower for p in priors],
                             dtype=torch.float)
    f_up_lim = torch.tensor([p.prior_flux_upper for p in priors],
                            dtype=torch.float)

    with band_plate as ind_band:
        # HalfCauchy takes a single scale argument
        sigma_conf = pyro.sample(
            'sigma_conf',
            dist.HalfCauchy(torch.tensor([1.0])).expand(
                [1]).to_event(1)).squeeze(-1)
        bkg = pyro.sample('bkg',
                          dist.Normal(-5,
                                      0.5).expand([1]).to_event(1)).squeeze(-1)
        with src_plate as ind_src:
            src_f = pyro.sample('src_f',
                                dist.Uniform(0, 1).expand(
                                    [1]).to_event(1)).squeeze(-1)
    f_vec = (f_up_lim - f_low_lim) * src_f + f_low_lim
    db_hat_psw = torch.sparse.mm(pointing_matrices[0],
                                 f_vec[0, ...].unsqueeze(-1)) + bkg[0]
    db_hat_pmw = torch.sparse.mm(pointing_matrices[1],
                                 f_vec[1, ...].unsqueeze(-1)) + bkg[1]
    db_hat_plw = torch.sparse.mm(pointing_matrices[2],
                                 f_vec[2, ...].unsqueeze(-1)) + bkg[2]
    sigma_tot_psw = torch.sqrt(
        torch.pow(torch.tensor(priors[0].snim), 2) +
        torch.pow(sigma_conf[0], 2))
    sigma_tot_pmw = torch.sqrt(
        torch.pow(torch.tensor(priors[1].snim), 2) +
        torch.pow(sigma_conf[1], 2))
    sigma_tot_plw = torch.sqrt(
        torch.pow(torch.tensor(priors[2].snim), 2) +
        torch.pow(sigma_conf[2], 2))
    with psw_plate as ind_psw:
        psw_map = pyro.sample("obs_psw",
                              dist.Normal(db_hat_psw.squeeze()[ind_psw],
                                          sigma_tot_psw[ind_psw]),
                              obs=torch.tensor(priors[0].sim[ind_psw]))
    with pmw_plate as ind_pmw:
        pmw_map = pyro.sample("obs_pmw",
                              dist.Normal(db_hat_pmw.squeeze()[ind_pmw],
                                          sigma_tot_pmw[ind_pmw]),
                              obs=torch.tensor(priors[1].sim[ind_pmw]))
    with plw_plate as ind_plw:
        plw_map = pyro.sample("obs_plw",
                              dist.Normal(db_hat_plw.squeeze()[ind_plw],
                                          sigma_tot_plw[ind_plw]),
                              obs=torch.tensor(priors[2].sim[ind_plw]))
    return psw_map, pmw_map, plw_map