Example #1
    def _dynamics(self, features):
        """
        Compute dynamics parameters from time features.
        """
        state_dim = self.args.state_dim
        gate_rate_dim = 2 * self.num_stations**2

        init_loc = torch.zeros(state_dim)
        init_scale_tril = pyro.param(
            "init_scale",
            torch.full((state_dim, ), 10.),
            constraint=constraints.positive).diag_embed()
        init_dist = dist.MultivariateNormal(init_loc,
                                            scale_tril=init_scale_tril)

        trans_matrix = pyro.param("trans_matrix", 0.99 * torch.eye(state_dim))
        trans_loc = torch.zeros(state_dim)
        trans_scale_tril = pyro.param(
            "trans_scale",
            0.1 * torch.ones(state_dim),
            constraint=constraints.positive).diag_embed()
        trans_dist = dist.MultivariateNormal(trans_loc,
                                             scale_tril=trans_scale_tril)

        obs_matrix = pyro.param("obs_matrix",
                                torch.randn(state_dim, gate_rate_dim))
        obs_matrix.data /= obs_matrix.data.norm(dim=-1, keepdim=True)
        loc_scale = self.nn(features)
        loc, scale = loc_scale.reshape(loc_scale.shape[:-1] +
                                       (2, gate_rate_dim)).unbind(-2)
        scale = bounded_exp(scale, bound=10.)
        obs_dist = dist.Normal(loc, scale).to_event(1)

        return init_dist, trans_matrix, trans_dist, obs_matrix, obs_dist
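
The helper `bounded_exp` above is defined in the surrounding example code, not in Pyro itself; a plausible sketch, assuming it is meant as a smooth approximation of exp that saturates at `bound`:

import math

def bounded_exp(x, bound):
    # Sketch: behaves like exp(x) for x << log(bound), saturates at `bound`.
    return (x - math.log(bound)).sigmoid() * bound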
Example #2
    def model(self):
        self.set_mode("model")
        N = self.X.size(0)
        Kff = self.kernel(self.X).contiguous()
        Kff.view(-1)[::N + 1] += self.jitter  # add jitter to the diagonal
        Lff = Kff.cholesky()

        zero_loc = self.X.new_zeros(self.f_loc.shape)
        if self.whiten:
            identity = eye_like(self.X, N)
            pyro.sample(
                self._pyro_get_fullname("f"),
                dist.MultivariateNormal(
                    zero_loc,
                    scale_tril=identity).to_event(zero_loc.dim() - 1))
            f_scale_tril = Lff.matmul(self.f_scale_tril)
            f_loc = Lff.matmul(self.f_loc.unsqueeze(-1)).squeeze(-1)
        else:
            pyro.sample(
                self._pyro_get_fullname("f"),
                dist.MultivariateNormal(
                    zero_loc, scale_tril=Lff).to_event(zero_loc.dim() - 1))
            f_scale_tril = self.f_scale_tril
            f_loc = self.f_loc

        f_loc = f_loc + self.mean_function(self.X)
        f_var = f_scale_tril.pow(2).sum(dim=-1)
        if self.y is None:
            return f_loc, f_var
        else:
            return self.likelihood(f_loc, f_var, self.y)
Example #3
def model(data):
    # Global variables.
    weights = pyro.param(
        "weights",
        torch.FloatTensor([0.5]),
        constraint=constraints.unit_interval
    )
    scales = pyro.param(
        "scales",
        torch.stack([torch.eye(2), torch.eye(2)]),
        constraint=constraints.positive
    )

    locs = torch.stack([
        pyro.sample(
            "locs_{}".format(k),
            dist.MultivariateNormal(torch.zeros(2), 2 * torch.eye(2))
        ) for k in range(K)
    ])

    with pyro.plate("data", data.size(0), subsample_size=4) as ind:
        # Local variables.
        assignment = pyro.sample(
            "assignment",
            dist.Bernoulli(torch.ones(len(ind)) * weights)
        ).to(torch.int64)
        pyro.sample(
            "obs",
            dist.MultivariateNormal(locs[assignment], scales[assignment]),
            obs=data[ind]
        )
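
A hedged usage sketch, not from the original source: because `assignment` is discrete, a mixture model like this is usually trained with enumeration; `my_guide` below is a hypothetical guide over the `locs_k` sites.

from pyro.infer import SVI, TraceEnum_ELBO, config_enumerate
from pyro.optim import Adam

guide = config_enumerate(my_guide, "parallel")  # my_guide: hypothetical
svi = SVI(model, guide, Adam({"lr": 0.1}),
          loss=TraceEnum_ELBO(max_plate_nesting=1))
for step in range(200):
    svi.step(data)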
Example #4
    def model(self):
        # Global variables
        weights = pyro.sample('weights',
                              dist.Dirichlet(0.5 * torch.ones(self.n_comp)))

        with pyro.plate('components', self.n_comp):
            locs = pyro.sample('locs',
                               dist.MultivariateNormal(
                                   torch.zeros(self.shape[1]),
                                   torch.eye(self.shape[1]))
                               )
            scale = pyro.sample('scale', dist.LogNormal(0., 2.))

        # One isotropic covariance per component, shape (n_comp, D, D).
        f = torch.eye(self.shape[1]) * scale[:, None, None]

        with pyro.plate('data', self.shape[0]):
            # Local variables.
            assignment = pyro.sample('assignment', dist.Categorical(weights))
            pyro.sample('obs', dist.MultivariateNormal(locs[assignment],
                                                       f[assignment]),
                        obs=self.tensor_train)
Example #5
    def guide(self):
        self.set_mode("guide")

        if self.coord:
            for v in range(self.V_net):
                for h in range(self.H_dim):
                    for sr_i in range(self.sr_dim):
                        for lw_i in range(self.lw_dim):
                            pyro.sample(
                                f'f_coord_v{v}_h{h}_sr{sr_i}_lw{lw_i}',
                                dist.MultivariateNormal(
                                    self.gp.coord.loc[v, h, sr_i, lw_i, :],
                                    scale_tril=self.gp.coord.cov_tril[
                                        v, h, sr_i, lw_i, :, :]).to_event(
                                            self.gp.coord.loc[v, h, sr_i,
                                                              lw_i, :].dim() -
                                            1))

        if self.socpop:
            for v in range(self.V_net):
                for sp_i in range(2):
                    for lw_i in range(self.lw_dim):
                        pyro.sample(
                            f'f_socpop_v{v}_{["soc","pop"][sp_i]}_lw{lw_i}',
                            dist.MultivariateNormal(
                                self.gp.socpop.loc[v, sp_i, lw_i, :],
                                scale_tril=self.gp.socpop.cov_tril[v, sp_i,
                                                                   lw_i, :, :]
                            ).to_event(self.gp.socpop.loc[v, sp_i,
                                                          lw_i, :].dim() - 1))
Example #6
    def model(self, seq):
        mu0 = torch.zeros(self.emb_dim).to(self.device)
        tri0 = self.tri0  # create this when initializing. (takes 4ms each time!)

        muV = pyro.sample("muV",
                          dist.MultivariateNormal(loc=mu0, scale_tril=tri0))

        with plate("item_loop", self.num_items):
            V = pyro.sample(f"V", dist.MultivariateNormal(muV,
                                                          scale_tril=tri0))

        # LIFT MODULE:
        prior = {
            'linear.bias': dist.Normal(0, 1),
            'V.weight': Deterministic_distr(V)
        }
        lifted_module = pyro.random_module("net", self, prior=prior)

        lifted_reg_model = lifted_module()
        lifted_reg_model.lstm.flatten_parameters()

        with pyro.plate("data", len(seq),
                        subsample_size=self.batch_size) as ind:
            batch_seq = seq[ind, ]
            x = batch_seq[:, :-1]
            y = batch_seq[:, 1:]
            batch_mask = (y != 0).float()

            lprobs = lifted_reg_model(x)
            data = pyro.sample(
                "obs_x",
                dist.Categorical(logits=lprobs).mask(batch_mask).to_event(2),
                obs=y)
        return lifted_reg_model
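
`Deterministic_distr` is user-defined, not a Pyro distribution; a minimal sketch of what the prior dict above appears to assume, pinning a site to a fixed tensor via `dist.Delta`:

def Deterministic_distr(value):
    # Point mass that always returns `value`, so the lifted module's
    # 'V.weight' is tied to the embeddings sampled above (sketch).
    return dist.Delta(value).to_event(value.dim())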
Example #7
    def model(self):
        self.set_mode("model")

        M = self.Xu.size(0)
        Kuu = self.kernel(self.Xu).contiguous()
        Kuu.view(-1)[::M + 1] += self.jitter  # add jitter to the diagonal
        Luu = Kuu.cholesky()

        zero_loc = self.Xu.new_zeros(self.u_loc.shape)
        if self.whiten:
            identity = eye_like(self.Xu, M)
            pyro.sample(self._pyro_get_fullname("u"),
                        dist.MultivariateNormal(zero_loc, scale_tril=identity)
                            .to_event(zero_loc.dim() - 1))
        else:
            pyro.sample(self._pyro_get_fullname("u"),
                        dist.MultivariateNormal(zero_loc, scale_tril=Luu)
                            .to_event(zero_loc.dim() - 1))

        f_loc, f_var = conditional(self.X, self.Xu, self.kernel, self.u_loc, self.u_scale_tril,
                                   Luu, full_cov=False, whiten=self.whiten, jitter=self.jitter)

        f_loc = f_loc + self.mean_function(self.X)
        if self.y is None:
            return f_loc, f_var
        else:
            # we would like to load likelihood's parameters outside poutine.scale context
            self.likelihood._load_pyro_samples()
            with poutine.scale(scale=self.num_data / self.X.size(0)):
                return self.likelihood(f_loc, f_var, self.y)
Example #8
    def model(self, seq):
        bias = dist.Normal(0., 1.)
        mu0 = torch.zeros(self.emb_dim).to(self.device)
        var0 = torch.diag(torch.ones(self.emb_dim).to(self.device) * 2)

        muV = pyro.sample("muV",
                          dist.MultivariateNormal(loc=mu0,
                                                  covariance_matrix=var0))

        with plate("item_loop", self.num_items):
            V = pyro.sample("V", dist.MultivariateNormal(muV, var0))

        # LIFT MODULE:
        prior = {'linear.bias': bias,
                 'V.weight': Deterministic_distr(V)}
        lifted_module = pyro.random_module("net", self, prior=prior)

        lifted_reg_model = lifted_module()
        lifted_reg_model.lstm.flatten_parameters()

        with pyro.plate("data", len(seq),
                        subsample_size=self.batch_size) as ind:
            batch_seq = seq[ind]
            batch_mask = (batch_seq != 0).float()

            lprobs = lifted_reg_model(batch_seq)
            data = pyro.sample(
                "obs_x",
                dist.Categorical(logits=lprobs).mask(batch_mask).to_event(2),
                obs=batch_seq)
        return lifted_reg_model
Example #9
    def guide(self):
        """Approximate posterior for the horseshoe prior. We assume posterior in the form
        of the multivariate normal distriburtion for the global mean and standard deviation
        and multivariate normal distribution for the parameters of each subject independently.
        """
        nsub = self.runs  # number of subjects
        npar = self.npar  # number of parameters
        trns = biject_to(constraints.positive)

        m_hyp = param('m_hyp', zeros(2 * npar))
        st_hyp = param('scale_tril_hyp',
                       torch.eye(2 * npar),
                       constraint=constraints.lower_cholesky)
        hyp = sample('hyp',
                     dist.MultivariateNormal(m_hyp, scale_tril=st_hyp),
                     infer={'is_auxiliary': True})

        unc_mu = hyp[..., :npar]
        unc_tau = hyp[..., npar:]

        c_tau = trns(unc_tau)

        ld_tau = trns.inv.log_abs_det_jacobian(c_tau, unc_tau)
        ld_tau = sum_rightmost(ld_tau, ld_tau.dim() - c_tau.dim() + 1)

        sample("mu", dist.Delta(unc_mu, event_dim=1))
        sample("tau", dist.Delta(c_tau, log_density=ld_tau, event_dim=1))

        m_locs = param('m_locs', zeros(nsub, npar))
        st_locs = param('scale_tril_locs',
                        torch.eye(npar).repeat(nsub, 1, 1),
                        constraint=constraints.lower_cholesky)

        with plate('runs', nsub):
            sample("locs", dist.MultivariateNormal(m_locs, scale_tril=st_locs))
Example #10
def model(data):
    _, dim = data.shape
    weights = pyro.sample('weights', dist.Dirichlet(torch.ones(K)))

    with pyro.plate('c1', K):
        mus = pyro.sample(
            'mus',
            dist.MultivariateNormal(torch.zeros(dim),
                                    torch.diag(torch.ones(dim) * 10.0)))
    assert (mus.size() == (K, dim))

    with pyro.plate('dim', dim):
        with pyro.plate('c2', K):
            lambdas = pyro.sample('lambdas', dist.LogNormal(0, 2))
    assert (lambdas.size() == (K, dim))

    # Diagonal covariance per component, shape (K, dim, dim).
    scales = torch.diag_embed(lambdas)
    assert scales.size() == (K, dim, dim)

    with pyro.plate('data', len(data)):
        assignments = pyro.sample('assignments', dist.Categorical(weights))
        pyro.sample('obs',
                    dist.MultivariateNormal(mus[assignments],
                                            scales[assignments]),
                    obs=data)
Example #11
def model(X, Y, U, V):
    phi = 1
    device = 'cuda' if is_cuda else 'cpu'
    d_i = pyro.sample(
        "d_i",
        dist.MultivariateNormal(torch.zeros(E, device=device),
                                phi * torch.eye(E, device=device)))

    with pyro.plate('observations', len(X)):
        u = U[X[:, 0]].view(-1, 1, E)
        v = (V[X[:, 1]] + d_i).view(-1, E, 1)
        logit = torch.sum(torch.bmm(u, v), axis=1)
        target = pyro.sample('obs', dist.Bernoulli(logits=logit), obs=Y)
Example #12
    def model(self, X=None, y=None):
        N = X.shape[0]
        D = X.shape[1]
        pyro.module("MDN", self)
        pi, loc, Sigma_tril = self.mdn(X)
        locT = torch.transpose(loc, 0, 1)
        Sigma_trilT = torch.transpose(Sigma_tril, 0, 1)
        assert pi.shape == (N, self.K)
        assert locT.shape == (self.K, N, D)
        assert Sigma_trilT.shape == (self.K, N, D, D)
        with pyro.plate("data", N):
            assignment = pyro.sample("assignment", dist.Categorical(pi))
            if len(assignment.shape) == 1:
                # gather's output takes the index's shape, so expand the
                # index over the event dims to pick out full rows/matrices.
                idx = assignment.view(1, -1, 1).expand(1, N, D)
                _mu = torch.gather(locT, 0, idx)[0]
                idx2 = assignment.view(1, -1, 1, 1).expand(1, N, D, D)
                _scale_tril = torch.gather(Sigma_trilT, 0, idx2)[0]
                sample = pyro.sample('obs',
                                     dist.MultivariateNormal(
                                         _mu, scale_tril=_scale_tril),
                                     obs=y)
            else:
                _mu = locT[assignment][:, 0]
                _scale_tril = Sigma_trilT[assignment][:, 0]
                sample = pyro.sample('obs',
                                     dist.MultivariateNormal(
                                         _mu, scale_tril=_scale_tril),
                                     obs=y)

        return pi, loc, Sigma_tril, sample
Example #13
def my_local_guide(x, y, alt_av, alt_ids):
    if diagonal_alpha:
        alpha_loc = pyro.param(
            'alpha_loc', torch.randn(len(non_mix_params), device=x.device))
        alpha_scale = pyro.param(
            'alpha_scale',
            1 * torch.ones(len(non_mix_params), device=x.device),
            constraint=constraints.positive)
        alpha = pyro.sample("alpha",
                            dist.Normal(alpha_loc, alpha_scale).to_event(1))
    else:
        alpha_loc = pyro.param(
            'alpha_loc', torch.randn(len(non_mix_params), device=x.device))
        alpha_scale = pyro.param(
            "alpha_scale",
            torch.tril(1 * torch.eye(len(non_mix_params), device=x.device)),
            constraint=constraints.lower_cholesky)
        alpha = pyro.sample(
            "alpha", dist.MultivariateNormal(alpha_loc,
                                             scale_tril=alpha_scale))

    if diagonal_beta_mu:
        beta_mu_loc = pyro.param('beta_mu_loc',
                                 torch.randn(len(mix_params), device=x.device))
        beta_mu_scale = pyro.param(
            'beta_mu_scale',
            1 * torch.ones(len(mix_params), device=x.device),
            constraint=constraints.positive)
        beta_mu = pyro.sample(
            "beta_mu",
            dist.Normal(beta_mu_loc, beta_mu_scale).to_event(1))
    else:
        beta_mu_loc = pyro.param('beta_mu_loc',
                                 torch.randn(len(mix_params), device=x.device))
        beta_mu_scale = pyro.param(
            "beta_mu_scale",
            torch.tril(1 * torch.eye(len(mix_params), device=x.device)),
            constraint=constraints.lower_cholesky)
        beta_mu = pyro.sample(
            "beta_mu",
            dist.MultivariateNormal(beta_mu_loc, scale_tril=beta_mu_scale))

    # Use an amortized guide for local variables.
    pyro.module("predictor", predictor)
    one_hot = torch.zeros(num_resp,
                          T,
                          num_alternatives,
                          device=x.device,
                          dtype=torch.float)
    one_hot = one_hot.scatter(2, y.unsqueeze(2).long(), 1)
    inference_data = torch.cat([one_hot, x, alt_av_cuda.float()], dim=-1)
    beta_loc = predictor.forward(inference_data.flatten(1, 2).unsqueeze(1))
    beta_scale = pyro.param(
        'beta_resp_scale',
        torch.tril(1. * torch.eye(len(mix_params), device=x.device)),
        constraint=constraints.lower_cholesky)
    pyro.sample(
        "beta_resp",
        dist.MultivariateNormal(beta_loc, scale_tril=beta_scale).to_event(1))
Example #14
def normal_inv_gamma_family_guide(design, obs_sd, w_sizes, mf=False):
    """Normal inverse Gamma family guide.

    If `obs_sd` is known, this is a multivariate Normal family with separate
    parameters for each batch. `w` is sampled from a Gaussian with mean
    `mw_param` and covariance matrix derived from `obs_sd * lambda_param`;
    the two parameters `mw_param` and `lambda_param` are learned.

    If `obs_sd=None`, this is a four-parameter family. The observation precision
    `tau` is sampled from a Gamma distribution with parameters `alpha`, `beta`
    (separate for each batch). We let `obs_sd = 1./torch.sqrt(tau)` and then
    proceed as above.

    :param torch.Tensor design: a tensor with last two dimensions `n` and `p`
        corresponding to observations and features respectively.
    :param torch.Tensor obs_sd: observation standard deviation, or `None` to use
        inverse Gamma
    :param OrderedDict w_sizes: map from variable names to torch.Size
    """
    # design is size batch x n x p
    # tau is size batch
    tau_shape = design.shape[:-2]
    with ExitStack() as stack:
        for plate in iter_plates_to_shape(tau_shape):
            stack.enter_context(plate)

        if obs_sd is None:
            # First, sample tau (observation precision)
            alpha = softplus(
                pyro.param("invsoftplus_alpha", 20.0 * torch.ones(tau_shape))
            )
            beta = softplus(
                pyro.param("invsoftplus_beta", 20.0 * torch.ones(tau_shape))
            )
            # Global variable
            tau_prior = dist.Gamma(alpha, beta)
            tau = pyro.sample("tau", tau_prior)
            obs_sd = 1.0 / torch.sqrt(tau)

        # response will be shape batch x n
        obs_sd = obs_sd.expand(tau_shape).unsqueeze(-1)

        for name, size in w_sizes.items():
            w_shape = tau_shape + size
            # Set up mu and lambda
            mw_param = pyro.param("{}_guide_mean".format(name), torch.zeros(w_shape))
            scale_tril = pyro.param(
                "{}_guide_scale_tril".format(name),
                torch.eye(*size).expand(tau_shape + size + size),
                constraint=constraints.lower_cholesky,
            )
            # guide distributions for w
            if mf:
                w_dist = dist.MultivariateNormal(mw_param, scale_tril=scale_tril)
            else:
                w_dist = dist.MultivariateNormal(
                    mw_param, scale_tril=obs_sd.unsqueeze(-1) * scale_tril
                )
            pyro.sample(name, w_dist)
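
A usage sketch with hypothetical shapes (a batch of 3 designs, n=10 observations, p=2 features), assuming the module-level helpers (`softplus`, `iter_plates_to_shape`) are in scope:

from collections import OrderedDict

design = torch.randn(3, 10, 2)
w_sizes = OrderedDict(w=torch.Size([2]))
normal_inv_gamma_family_guide(design, obs_sd=None, w_sizes=w_sizes)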
Example #15
 def model(loc, cov):
     x = pyro.param("x", torch.randn(2))
     y = pyro.param("y", torch.randn(3, 2))
     z = pyro.param("z", torch.randn(4, 2).abs(), constraint=constraints.greater_than(-1))
     pyro.sample("obs_x", dist.MultivariateNormal(loc, cov), obs=x)
     with pyro.plate("y_plate", 3):
         pyro.sample("obs_y", dist.MultivariateNormal(loc, cov), obs=y)
     with pyro.plate("z_plate", 4):
         pyro.sample("obs_z", dist.MultivariateNormal(loc, cov), obs=z)
Example #16
    def model(N):

        with pyro.plate("x_plate", N):
            z1 = pyro.sample(
                "z1", dist.MultivariateNormal(torch.zeros(2), torch.eye(2)))
            z2 = pyro.sample(
                "z2", dist.MultivariateNormal(torch.zeros(2), torch.eye(2)))
            return pyro.sample("x",
                               dist.MultivariateNormal(z1 + z2, torch.eye(2)))
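
Usage sketch: `z1`, `z2`, and `x` each have event shape (2,) batched over the plate, so one call returns an (N, 2) tensor.

x = model(10)
assert x.shape == (10, 2)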
Example #17
def my_local_guide(x, y, alt_av, alt_ids):
    if diagonal_alpha:
        alpha_loc = pyro.param(
            'alpha_loc', torch.randn(len(non_mix_params), device=x.device))
        alpha_scale = pyro.param(
            'alpha_scale',
            1. * torch.ones(len(non_mix_params), device=x.device),
            constraint=constraints.positive)
        alpha = pyro.sample("alpha",
                            dist.Normal(alpha_loc, alpha_scale).to_event(1))
    else:
        alpha_loc = pyro.param(
            'alpha_loc', torch.randn(len(non_mix_params), device=x.device))
        alpha_scale = pyro.param(
            "alpha_scale",
            torch.tril(1. * torch.eye(len(non_mix_params), device=x.device)),
            constraint=constraints.lower_cholesky)
        alpha = pyro.sample(
            "alpha", dist.MultivariateNormal(alpha_loc,
                                             scale_tril=alpha_scale))

    if diagonal_beta_mu:
        beta_mu_loc = pyro.param('beta_mu_loc',
                                 torch.randn(len(mix_params), device=x.device))
        beta_mu_scale = pyro.param(
            'beta_mu_scale',
            1. * torch.ones(len(mix_params), device=x.device),
            constraint=constraints.positive)
        beta_mu = pyro.sample(
            "beta_mu",
            dist.Normal(beta_mu_loc, beta_mu_scale).to_event(1))
    else:
        beta_mu_loc = pyro.param('beta_mu_loc',
                                 torch.randn(len(mix_params), device=x.device))
        beta_mu_scale = pyro.param(
            "beta_mu_scale",
            torch.tril(1. * torch.eye(len(mix_params), device=x.device)),
            constraint=constraints.lower_cholesky)
        beta_mu = pyro.sample(
            "beta_mu",
            dist.MultivariateNormal(beta_mu_loc, scale_tril=beta_mu_scale))

    beta_loc = pyro.param(
        'beta_resp_loc', torch.randn(num_resp,
                                     len(mix_params),
                                     device=x.device))
    beta_scale = pyro.param(
        'beta_resp_scale',
        torch.tril(
            1. * torch.eye(len(mix_params), len(mix_params), device=x.device)),
        constraint=constraints.lower_cholesky)
    pyro.sample(
        "beta_resp",
        dist.MultivariateNormal(beta_loc, scale_tril=beta_scale).to_event(1))
Example #18
 def step(self, state, datum=None):
     # trans_dist and obs_dist are assumed to be defined in the enclosing scope.
     state["z"] = pyro.sample(
         "z_{}".format(self.t),
         dist.MultivariateNormal(state["z"],
                                 scale_tril=trans_dist.scale_tril))
     datum = pyro.sample(
         "obs_{}".format(self.t),
         dist.MultivariateNormal(state["z"],
                                 scale_tril=obs_dist.scale_tril),
         obs=datum)
     self.t += 1
     return datum
Example #19
def get_replicated_data(data, mu, cov, pi):
    data_rep = []
    for i in range(len(data)):
        # Per-iteration site names avoid name collisions if this is traced.
        cluster = pyro.sample('category_{}'.format(i),
                              dist.Categorical(torch.tensor(pi)))
        idx = cluster.item()
        sample = pyro.sample("obs_{}".format(i),
                             dist.MultivariateNormal(mu[idx], cov[idx]))
        while sample[0] < min(data[:, 0]) or sample[1] < min(data[:, 1]):
            # Only keep points inside the observed range.
            sample = pyro.sample("obs_{}".format(i),
                                 dist.MultivariateNormal(mu[idx], cov[idx]))
        data_rep.append(sample.tolist())
    data_rep = torch.tensor(data_rep)
    return data_rep
Example #20
    def guide(self):
        self.set_mode("guide")
        self._load_pyro_samples()

        pyro.sample(
            self._pyro_get_fullname("f"),
            dist.MultivariateNormal(
                self.f_loc,
                scale_tril=self.f_scale_tril).to_event(self.f_loc.dim() - 1))
        pyro.sample(
            self._pyro_get_fullname("g"),
            dist.MultivariateNormal(
                self.g_loc,
                scale_tril=self.g_scale_tril).to_event(self.g_loc.dim() - 1))
Example #21
def get_samples(num_samples=100):
    # underlying parameters
    mu1 = torch.tensor([0., 5.])
    sig1 = torch.tensor([[2., 0.], [0., 3.]])
    mu2 = torch.tensor([5., 0.])
    sig2 = torch.tensor([[4., 0.], [0., 1.]])

    # generate samples
    dist1 = dist.MultivariateNormal(mu1, sig1)
    samples1 = [pyro.sample("samples1", dist1) for _ in range(num_samples)]
    dist2 = dist.MultivariateNormal(mu2, sig2)
    samples2 = [pyro.sample("samples2", dist2) for _ in range(num_samples)]

    return torch.cat((torch.stack(samples1), torch.stack(samples2)))
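
Outside an inference context, `pyro.sample` simply calls the distribution's sampler, so the Python loops above can be replaced by one batched draw per component; an equivalent sketch reusing `mu1`, `sig1`, `mu2`, `sig2` from above:

def get_samples_vectorized(num_samples=100):
    # One batched call per component instead of a Python loop.
    samples1 = dist.MultivariateNormal(mu1, sig1).sample((num_samples,))
    samples2 = dist.MultivariateNormal(mu2, sig2).sample((num_samples,))
    return torch.cat((samples1, samples2))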
Example #22
def test_kl_independent_normal_mvn(batch_shape, size):
    loc = torch.randn(batch_shape + (size, ))
    scale = torch.randn(batch_shape + (size, )).exp()
    p1 = dist.Normal(loc, scale).to_event(1)
    p2 = dist.MultivariateNormal(loc, scale_tril=scale.diag_embed())

    loc = torch.randn(batch_shape + (size, ))
    cov = torch.randn(batch_shape + (size, size))
    cov = cov @ cov.transpose(-1, -2) + 0.01 * torch.eye(size)
    q = dist.MultivariateNormal(loc, covariance_matrix=cov)

    actual = kl_divergence(p1, q)
    expected = kl_divergence(p2, q)
    assert_close(actual, expected)
Example #23
    def guide(self):
        """Approximate posterior for the Dirichlet process prior.
        """
        nsub = self.runs  # number of subjects
        npar = self.npar  # number of parameters
        kmax = self.kmax  # maximum number of components

        gaa = param("ga_a", ones(1), constraint=constraints.positive)
        gar = param("ga_r", .1 * ones(1), constraint=constraints.positive)
        sample('alpha', dist.Gamma(gaa, gaa / gar))

        gba = param("gb_beta_a",
                    ones(kmax - 1),
                    constraint=constraints.positive)
        gbb = param("gb_beta_b",
                    ones(kmax - 1),
                    constraint=constraints.positive)
        beta = sample("beta", dist.Beta(gba, gbb).to_event(1))

        with plate('classes', kmax, dim=-2):
            m_mu = param('m_mu', zeros(kmax, npar))
            st_mu = param('scale_tril_mu',
                          torch.eye(npar).repeat(kmax, 1, 1),
                          constraint=constraints.lower_cholesky)
            mu = sample("mu", dist.MultivariateNormal(m_mu, scale_tril=st_mu))

            m_tau = param('m_hyp', zeros(kmax, npar))
            st_tau = param('scale_tril_hyp',
                           torch.eye(npar).repeat(kmax, 1, 1),
                           constraint=constraints.lower_cholesky)
            mn = dist.MultivariateNormal(m_tau, scale_tril=st_tau)
            sample(
                "tau",
                dist.TransformedDistribution(mn,
                                             [dist.transforms.ExpTransform()]))

        m_locs = param('m_locs', zeros(nsub, npar))
        st_locs = param('scale_tril_locs',
                        torch.eye(npar).repeat(nsub, 1, 1),
                        constraint=constraints.lower_cholesky)

        class_probs = param('class_probs',
                            ones(nsub, kmax) / kmax,
                            constraint=constraints.simplex)
        with plate('subjects', nsub, dim=-2):
            sample('class',
                   dist.Categorical(class_probs),
                   infer={"enumerate": "parallel"})
            sample("locs", dist.MultivariateNormal(m_locs, scale_tril=st_locs))
Example #24
    def guide_horseshoe_plus(self):
        npar = self.npars  # number of parameters
        nsub = self.runs  # number of subjects
        trns = biject_to(constraints.positive)

        m_hyp = param('m_hyp', zeros(2 * npar))
        st_hyp = param('scale_tril_hyp',
                       torch.eye(2 * npar),
                       constraint=constraints.lower_cholesky)
        hyp = sample('hyp',
                     dist.MultivariateNormal(m_hyp, scale_tril=st_hyp),
                     infer={'is_auxiliary': True})

        unc_mu = hyp[:npar]
        unc_sigma = hyp[npar:]

        c_sigma = trns(unc_sigma)

        ld_sigma = trns.inv.log_abs_det_jacobian(c_sigma, unc_sigma)
        ld_sigma = sum_rightmost(ld_sigma, ld_sigma.dim() - c_sigma.dim() + 1)

        mu_g = sample("mu_g", dist.Delta(unc_mu, event_dim=1))
        sigma_g = sample("sigma_g",
                         dist.Delta(c_sigma, log_density=ld_sigma,
                                    event_dim=1))

        m_tmp = param('m_tmp', zeros(nsub, 2 * npar))
        st_tmp = param('s_tmp',
                       torch.eye(2 * npar).repeat(nsub, 1, 1),
                       constraint=constraints.lower_cholesky)

        with plate('subjects', nsub):
            tmp = sample('tmp',
                         dist.MultivariateNormal(m_tmp, scale_tril=st_tmp),
                         infer={'is_auxiliary': True})

            unc_locs = tmp[..., :npar]
            unc_scale = tmp[..., npar:]

            c_scale = trns(unc_scale)

            ld_scale = trns.inv.log_abs_det_jacobian(c_scale, unc_scale)
            ld_scale = sum_rightmost(ld_scale,
                                     ld_scale.dim() - c_scale.dim() + 1)

            x = sample("x", dist.Delta(unc_locs, event_dim=1))
            sigma_x = sample("sigma_x",
                             dist.Delta(c_scale, log_density=ld_scale,
                                        event_dim=1))

        return {'mu_g': mu_g, 'sigma_g': sigma_g, 'sigma_x': sigma_x, 'x': x}
Example #25
 def guide(self, data):
     sample_size = self.sample_size
     subsample_size = self.subsample_size
     pyro.module('encoder', self.encoder)
     if self.x_feature == 1:
         with pyro.plate("data",
                         sample_size,
                         subsample_size=subsample_size,
                         dim=-2) as idx:
             data_ = data[idx]
             data_nan = torch.isnan(data_)
             if data_nan.any():
                 data_ = torch.where(data_nan, torch.full_like(data_, -1),
                                     data_)
             x_local, x_scale = self.encoder.forward(data_)
             pyro.sample('x', dist.Normal(x_local, x_scale))
     else:
         transform = LowerCholeskyTransform()
         with pyro.plate("data", sample_size,
                         subsample_size=subsample_size) as idx:
             data_ = data[idx]
             data_nan = torch.isnan(data_)
             if data_nan.any():
                 data_ = torch.where(data_nan, torch.full_like(data_, -1),
                                     data_)
             x_local, x_scale = self.encoder.forward(data_)
             pyro.sample(
                 'x',
                 dist.MultivariateNormal(x_local,
                                         scale_tril=transform(x_scale)))
Example #26
 def __init__(self,
              mdisc_log_local=0,
              mdisc_log_scale=0.5,
              mdiff_local=0.5,
              mdiff_scale=1,
              x_feature=2,
              x_local=None,
              x_cov=None,
              D=1,
              *args,
              **kwargs):
     super().__init__(*args, **kwargs)
     mdisc = torch.FloatTensor(self.item_size).log_normal_(
         mdisc_log_local, mdisc_log_scale)
     mdiff = torch.FloatTensor(self.item_size).normal_(
         mdiff_local, mdiff_scale)
     self.a = self.gen_a(self.item_size, mdisc, x_feature)
     b = -mdiff * mdisc
     self.b = b.view(1, -1)
     self.x_feature = x_feature
     if x_local is None:
         x_local = torch.zeros((x_feature, ))
     if x_cov is None:
         x_cov = torch.eye(x_feature)
     self.x = dist.MultivariateNormal(x_local, x_cov).sample(
         (self.sample_size, ))
     self.D = D
Example #27
def test_masked_mixture_multivariate(sample_shape, batch_shape):
    event_shape = torch.Size((8,))
    component0 = dist.MultivariateNormal(
        torch.zeros(event_shape), torch.eye(event_shape[0])
    )
    component1 = dist.Uniform(
        torch.zeros(event_shape), torch.ones(event_shape)
    ).to_event(1)
    if batch_shape:
        component0 = component0.expand_by(batch_shape)
        component1 = component1.expand_by(batch_shape)
    mask = torch.empty(batch_shape).bernoulli_(0.5).bool()
    d = dist.MaskedMixture(mask, component0, component1)
    assert d.batch_shape == batch_shape
    assert d.event_shape == event_shape

    assert d.sample().shape == batch_shape + event_shape
    assert d.mean.shape == batch_shape + event_shape
    assert d.variance.shape == batch_shape + event_shape
    x = d.sample(sample_shape)
    assert x.shape == sample_shape + batch_shape + event_shape

    log_prob = d.log_prob(x)
    assert log_prob.shape == sample_shape + batch_shape
    assert not torch_isnan(log_prob)
    log_prob_0 = component0.log_prob(x)
    log_prob_1 = component1.log_prob(x)
    mask = mask.expand(sample_shape + batch_shape)
    assert_equal(log_prob[mask], log_prob_1[mask])
    assert_equal(log_prob[~mask], log_prob_0[~mask])
Example #28
 def model(cov):
     w = pyro.sample("w", dist.Normal(0, 1000).expand([2]).to_event(1))
     x = pyro.sample("x", dist.Normal(0, 1000).expand([1]).to_event(1))
     y = pyro.sample("y", dist.Normal(0, 1000).expand([1]).to_event(1))
     z = pyro.sample("z", dist.Normal(0, 1000).expand([1]).to_event(1))
     wxyz = torch.cat([w, x, y, z])
     pyro.sample("obs", dist.MultivariateNormal(torch.zeros(5), cov), obs=wxyz)
Example #29
    def model(self, bows, embeddings, article_ids):
        pyro.module("topic_recognition_net", self.topic_recognition_net)
        with pyro.plate("articles", bows.shape[0]):
            # instead of a Dirichlet prior, we use a log-normal distribution
            prop_mu = bows.new_zeros((bows.shape[0], self.nav_topics))
            prop_sigma = bows.new_ones((bows.shape[0], self.nav_topics))
            props = pyro.sample(
                "theta",
                dist.LogNormal(prop_mu, prop_sigma).to_event(1))

            topics_mu, topics_sigma = self.topic_recognition_net(props)

            for batch_article_id, article_id in enumerate(article_ids):
                nav_embeddings = torch.tensor(
                    embeddings[self.article_navs[article_id]],
                    dtype=torch.float32).to(device)
                for article_nav_id in pyro.plate(
                        "navs_{}".format(article_id),
                        len(self.article_navs[article_id])):
                    pyro.sample("nav_{}_{}".format(article_id, article_nav_id),
                                dist.MultivariateNormal(
                                    topics_mu[batch_article_id],
                                    scale_tril=torch.diag(
                                        topics_sigma[batch_article_id])),
                                obs=nav_embeddings[article_nav_id])
Example #30
def program_arbitrary(nn_model, p_tgt, std=0.05):
    '''
    A probabilistic model for enforcing p = p_tgt:
    sample u_i ~ Normal(p_i, std);
    the posterior is then p(z | u_i = p_tgt[i]).
    '''
    if nn_model.device == 'cuda':
        typ = torch.cuda.FloatTensor
    elif nn_model.device == 'cpu':
        typ = torch.FloatTensor
    torch.set_default_tensor_type(typ)

    std = torch.tensor(std).float()
    latent_dim = nn_model.latent_dim
    loc = torch.zeros(latent_dim)
    cov = torch.eye(latent_dim)
    z = pyro.sample('z', dist.MultivariateNormal(loc, cov))
    prob = nn_model.predict_from_latent(z)
    N = len(prob)
    us = []
    for i in range(N):
        us.append(
            pyro.sample('u_%i' % i,
                        dist.Normal(prob[i], std),
                        obs=torch.tensor(p_tgt[i])))
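
A hedged inference sketch: the `u_i` sites are already conditioned through `obs=`, so any Pyro posterior method applies directly, e.g. self-normalized importance sampling over `z`:

import pyro.infer

importance = pyro.infer.Importance(program_arbitrary, num_samples=500)
marginal = pyro.infer.EmpiricalMarginal(importance.run(nn_model, p_tgt),
                                        sites="z")
z_post = marginal.sample((100,))  # approximate posterior draws of z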