Exemplo n.º 1
0
def model(y):
    """Multivariate normal model with an LKJ prior over correlations.

    ``y`` is an (N, d) observation tensor. Per-dimension scales get a
    HalfCauchy prior and the correlation matrix's Cholesky factor gets a
    uniform (eta=1) LKJ prior; observations are conditioned in a plate.
    Returns the observed sample site.
    """
    N, d = y.shape
    opts = dict(dtype=y.dtype, device=y.device)
    # Vector of variances for each of the d variables
    theta = pyro.sample("theta", dist.HalfCauchy(torch.ones(d, **opts)))
    # eta = 1 implies a uniform distribution over correlation matrices.
    concentration = torch.ones(1, **opts)
    L_omega = pyro.sample("L_omega", dist.LKJCorrCholesky(d, concentration))
    # Scale the correlation Cholesky factor into a covariance Cholesky factor.
    L_Omega = theta.sqrt().diag() @ L_omega
    # For inference with SVI, one might prefer to use torch.bmm(theta.sqrt().diag_embed(), L_omega)

    # Zero mean vector for every dimension.
    loc = torch.zeros(d, **opts)

    with pyro.plate("observations", N):
        obs = pyro.sample(
            "obs", dist.MultivariateNormal(loc, scale_tril=L_Omega), obs=y)
    return obs
Exemplo n.º 2
0
 def model():
     """Draw a 2x2 correlation Cholesky factor from an LKJ(eta=1) prior."""
     concentration = torch.tensor(1.)
     return pyro.sample('x', dist.LKJCorrCholesky(2, concentration))
Exemplo n.º 3
0
Arquivo: bart.py Projeto: nwjnwj/pyro
    def model(self, zero_data, covariates):
        """Pyro forecasting model for bivariate hourly count data.

        Builds a weekly-seasonal prediction (repeated seasonality plus slow
        seasonal drift) and a GaussianHMM joint noise model, then registers
        both via ``self.predict``.

        :param zero_data: tensor whose trailing shape ``(duration, 2)`` is
            read; the two columns are (arrivals, departures).
        :param covariates: not used in this body (presumably required by the
            forecaster interface — confirm against the base class).
        """
        period = 24 * 7  # one week of hourly time steps
        duration, dim = zero_data.shape[-2:]
        assert dim == 2  # Data is bivariate: (arrivals, departures).

        # Sample global parameters.
        noise_scale = pyro.sample(
            "noise_scale",
            dist.LogNormal(torch.full((dim, ), -3.), 1.).to_event(1))
        assert noise_scale.shape[-1:] == (dim, )
        trans_timescale = pyro.sample(
            "trans_timescale",
            dist.LogNormal(torch.zeros(dim), 1).to_event(1))
        assert trans_timescale.shape[-1:] == (dim, )

        # A single scalar location shared by (and expanded to) both dims.
        trans_loc = pyro.sample("trans_loc", dist.Cauchy(0, 1 / period))
        trans_loc = trans_loc.unsqueeze(-1).expand(trans_loc.shape + (dim, ))
        assert trans_loc.shape[-1:] == (dim, )
        # Transition noise: per-dim scales times an LKJ correlation Cholesky
        # factor yields the scale_tril of a full covariance.
        trans_scale = pyro.sample(
            "trans_scale",
            dist.LogNormal(torch.zeros(dim), 0.1).to_event(1))
        trans_corr = pyro.sample("trans_corr",
                                 dist.LKJCorrCholesky(dim, torch.ones(())))
        trans_scale_tril = trans_scale.unsqueeze(-1) * trans_corr
        assert trans_scale_tril.shape[-2:] == (dim, dim)

        # Observation noise, constructed the same way as the transition noise.
        obs_scale = pyro.sample(
            "obs_scale",
            dist.LogNormal(torch.zeros(dim), 0.1).to_event(1))
        obs_corr = pyro.sample("obs_corr",
                               dist.LKJCorrCholesky(dim, torch.ones(())))
        obs_scale_tril = obs_scale.unsqueeze(-1) * obs_corr
        assert obs_scale_tril.shape[-2:] == (dim, dim)

        # Note the initial seasonality should be sampled in a plate with the
        # same dim as the time_plate, dim=-1. That way we can repeat the dim
        # below using periodic_repeat().
        with pyro.plate("season_plate", period, dim=-1):
            season_init = pyro.sample(
                "season_init",
                dist.Normal(torch.zeros(dim), 1).to_event(1))
            assert season_init.shape[-2:] == (period, dim)

        # Sample independent noise at each time step.
        with self.time_plate:
            season_noise = pyro.sample("season_noise",
                                       dist.Normal(0, noise_scale).to_event(1))
            assert season_noise.shape[-2:] == (duration, dim)

        # Construct a prediction. This prediction has an exactly repeated
        # seasonal part plus slow seasonal drift. We use two deterministic,
        # linear functions to transform our diagonal Normal noise to nontrivial
        # samples from a Gaussian process.
        prediction = (periodic_repeat(season_init, duration, dim=-2) +
                      periodic_cumsum(season_noise, period, dim=-2))
        assert prediction.shape[-2:] == (duration, dim)

        # Construct a joint noise model. This model is a GaussianHMM, whose
        # .rsample() and .log_prob() methods are parallelized over time; thus
        # the entire model is parallelized over time.
        init_dist = dist.Normal(torch.zeros(dim), 100).to_event(1)
        # Diagonal transition matrix: independent per-dim exponential decay.
        trans_mat = trans_timescale.neg().exp().diag_embed()
        trans_dist = dist.MultivariateNormal(trans_loc,
                                             scale_tril=trans_scale_tril)
        obs_mat = torch.eye(dim)
        obs_dist = dist.MultivariateNormal(torch.zeros(dim),
                                           scale_tril=obs_scale_tril)
        noise_model = dist.GaussianHMM(init_dist,
                                       trans_mat,
                                       trans_dist,
                                       obs_mat,
                                       obs_dist,
                                       duration=duration)
        assert noise_model.event_shape == (duration, dim)

        # The final statement registers our noise model and prediction.
        self.predict(noise_model, prediction)
def model(x, y, alt_av, alt_ids_cuda):
    """Pyro model for a mixed logit (MXL) discrete-choice problem.

    Fixed-effect parameters (``alpha``) and mixed-parameter means
    (``beta_mu``) are sampled globally; respondent-level coefficients
    (``beta_resp``) are drawn from a multivariate normal whose covariance
    Cholesky factor combines HalfCauchy scales with an LKJ correlation
    prior. The likelihood is a Categorical over alternatives for each
    respondent and choice occasion.

    NOTE(review): depends on module-level globals (``diagonal_alpha``,
    ``diagonal_beta_mu``, ``non_mix_params``, ``mix_params``, ``num_resp``,
    ``beta_to_params_map``, ``num_alternatives``, ``zeros_vec``, ``T``,
    ``BATCH_SIZE``) — verify against the surrounding file.

    :param x: attribute tensor, indexed here as x[resp, occasion, feature]
        — TODO confirm shape against caller
    :param y: chosen-alternative index tensor, indexed as y[resp, occasion]
        — TODO confirm
    :param alt_av: additive availability adjustment per alternative
        (presumably 0 for available, large negative otherwise — verify)
    :param alt_ids_cuda: alternative ids used to scatter coefficients into
        per-alternative utilities
    """
    # global parameters in the model
    if diagonal_alpha:
        alpha_mu = pyro.sample(
            "alpha",
            dist.Normal(torch.zeros(len(non_mix_params), device=x.device),
                        1).to_event(1))
    else:
        alpha_mu = pyro.sample(
            "alpha",
            dist.MultivariateNormal(
                torch.zeros(len(non_mix_params), device=x.device),
                scale_tril=torch.tril(
                    1 * torch.eye(len(non_mix_params), device=x.device))))

    if diagonal_beta_mu:
        beta_mu = pyro.sample(
            "beta_mu",
            dist.Normal(torch.zeros(len(mix_params), device=x.device),
                        1.).to_event(1))
    else:
        beta_mu = pyro.sample(
            "beta_mu",
            dist.MultivariateNormal(
                torch.zeros(len(mix_params), device=x.device),
                scale_tril=torch.tril(
                    1 * torch.eye(len(mix_params), device=x.device))))

    # Vector of variances for each of the d variables
    theta = pyro.sample(
        "theta",
        dist.HalfCauchy(
            10. * torch.ones(len(mix_params), device=x.device)).to_event(1))
    # Lower cholesky factor of a correlation matrix
    eta = 1. * torch.ones(
        1, device=x.device
    )  # Implies a uniform distribution over correlation matrices
    L_omega = pyro.sample("L_omega",
                          dist.LKJCorrCholesky(len(mix_params), eta))
    # Lower cholesky factor of the covariance matrix
    L_Omega = torch.mm(torch.diag(theta.sqrt()), L_omega)

    # local parameters in the model: one coefficient vector per respondent,
    # all sharing the mean beta_mu and covariance Cholesky factor L_Omega
    random_params = pyro.sample(
        "beta_resp",
        dist.MultivariateNormal(beta_mu.repeat(num_resp, 1),
                                scale_tril=L_Omega).to_event(1))

    # vector of respondent parameters: global + local (respondent)
    params_resp = torch.cat([alpha_mu.repeat(num_resp, 1), random_params],
                            dim=-1)

    # vector of betas of MXL (may repeat the same learnable parameter multiple times; random + fixed effects)
    beta_resp = torch.cat([
        params_resp[:, beta_to_params_map[i]] for i in range(num_alternatives)
    ],
                          dim=-1)

    with pyro.plate("locals", len(x), subsample_size=BATCH_SIZE) as ind:

        with pyro.plate("data_resp", T):
            # compute utilities for each alternative by scattering the
            # per-feature contributions x * beta into alternative slots
            utilities = torch.scatter_add(
                zeros_vec[:, ind, :], 2,
                alt_ids_cuda[ind, :, :].transpose(0, 1),
                torch.mul(x[ind, :, :].transpose(0, 1), beta_resp[ind, :]))

            # adjust utility for unavailable alternatives
            utilities += alt_av[ind, :, :].transpose(0, 1)

            # likelihood
            pyro.sample("obs",
                        dist.Categorical(logits=utilities),
                        obs=y[ind, :].transpose(0, 1))
Exemplo n.º 5
0
    def model(self, data):
        '''
        Hierarchical model over individuals x treatments x times x markers.

        Covariances over the (treatment, time, marker) parameter vector are
        built as Kronecker products of per-factor covariances; time
        covariances come from learned positive time-step increments, and
        treatment/marker covariances from HalfCauchy scales plus LKJ
        correlation priors. Global and per-individual parameter vectors,
        a per-individual time-zero bias, and per-cell noise are sampled,
        then the data likelihood is formed via ``self.get_distr``.

        :param data: dict with counts 'n_ind', 'n_trt', 'n_tms', 'n_mrk'
            and observations 'Y' (presumably also 'time_vals' per the
            comment below — verify against get_distr and callers).
        '''
        n_ind = data['n_ind']
        n_trt = data['n_trt']
        n_tms = data['n_tms']
        n_mrk = data['n_mrk']

        # Total size of the flattened (treatment, time, marker) block.
        n_prs = n_trt * n_tms * n_mrk

        # Plates on distinct dims so they can be nested/combined below.
        plt_ind = pyro.plate('individuals', n_ind, dim=-3)
        plt_trt = pyro.plate('treatments', n_trt, dim=-2)
        plt_tms = pyro.plate('times', n_tms, dim=-1)

        pars = {}
        # covariance factors
        with plt_tms:
            # learning dt time step sizes
            # if k(t1,t2) is independent of time, can instead learn scales and variances for RBF kernels that use data['time_vals']
            pars['dt0'] = pyro.sample('dt0', dist.Normal(0, 1))
            pars['dt1'] = pyro.sample('dt1', dist.Normal(0, 1))

        # Per-treatment / per-marker scales; the *1 variants additionally get
        # LKJ correlation priors to form full (non-diagonal) covariances.
        pars['theta_trt0'] = pyro.sample('theta_trt0',
                                         dist.HalfCauchy(torch.ones(n_trt)))
        pars['theta_mrk0'] = pyro.sample('theta_mrk0',
                                         dist.HalfCauchy(torch.ones(n_mrk)))
        pars['theta_trt1'] = pyro.sample('theta_trt1',
                                         dist.HalfCauchy(torch.ones(n_trt)))
        pars['L_omega_trt1'] = pyro.sample(
            'L_omega_trt1', dist.LKJCorrCholesky(n_trt, torch.ones(1)))
        pars['theta_mrk1'] = pyro.sample('theta_mrk1',
                                         dist.HalfCauchy(torch.ones(n_mrk)))
        pars['L_omega_mrk1'] = pyro.sample(
            'L_omega_mrk1', dist.LKJCorrCholesky(n_mrk, torch.ones(1)))

        # Cumulative positive increments (softplus of dt) give monotone time
        # points; pad with t=0 and drop the last to keep length n_tms.
        times0 = fun.pad(torch.cumsum(pars['dt0'].exp().log1p(), 0), (1, 0),
                         value=0)[:-1].unsqueeze(1)
        times1 = fun.pad(torch.cumsum(pars['dt1'].exp().log1p(), 0), (1, 0),
                         value=0)[:-1].unsqueeze(1)
        # Exponential (Ornstein-Uhlenbeck-style) kernel over learned times.
        cov_t0 = (-torch.cdist(times0, times0)).exp()
        cov_t1 = (-torch.cdist(times1, times1)).exp()

        # Diagonal treatment covariance for the global level ...
        cov_i0 = pars['theta_trt0'].diag()
        # ... and a full covariance L L^T for the individual level.
        L_Omega_trt = torch.mm(torch.diag(pars['theta_trt1'].sqrt()),
                               pars['L_omega_trt1'])
        cov_i1 = L_Omega_trt.mm(L_Omega_trt.t())

        cov_m0 = pars['theta_mrk0'].diag()
        L_Omega_mrk = torch.mm(torch.diag(pars['theta_mrk1'].sqrt()),
                               pars['L_omega_mrk1'])
        cov_m1 = L_Omega_mrk.mm(L_Omega_mrk.t())

        # kronecker product of the factors
        cov_itm0 = torch.einsum('ij,tu,mn->itmjun',
                                [cov_i0, cov_t0, cov_m0]).view(n_prs, n_prs)
        cov_itm1 = torch.einsum('ij,tu,mn->itmjun',
                                [cov_i1, cov_t1, cov_m1]).view(n_prs, n_prs)

        # global and individual level params of each marker, treatment, and time point
        pars['glb'] = pyro.sample(
            'glb', dist.MultivariateNormal(torch.zeros(n_prs), cov_itm0))
        with plt_ind:
            pars['ind'] = pyro.sample(
                'ind', dist.MultivariateNormal(torch.zeros(n_prs), cov_itm1))

        # observation noise, time series bias and scale
        pars['noise_scale'] = pyro.sample('noise_scale',
                                          dist.HalfCauchy(torch.ones(n_mrk)))
        pars['t0_scale'] = pyro.sample('t0_scale',
                                       dist.HalfCauchy(torch.ones(n_mrk)))
        with plt_ind:
            # Per-individual time-zero offset across markers.
            pars['t0'] = pyro.sample(
                't0',
                dist.MultivariateNormal(torch.zeros(n_mrk),
                                        pars['t0_scale'].diag()))
            with plt_trt, plt_tms:
                # Independent noise for every (ind, trt, time) cell.
                pars['noise'] = pyro.sample(
                    'noise',
                    dist.MultivariateNormal(torch.zeros(n_mrk),
                                            pars['noise_scale'].diag()))

        # likelihood of the data
        distr = self.get_distr(data, pars)
        pyro.sample('obs', distr, obs=data['Y'])