Example #1
    def linear_model_formula(self, y, design, target_labels):
        # Posterior mean: apply the (optionally softplus-constrained)
        # regressor for each target label to the observations y
        if self.use_softplus:
            mu = {l: rmv(self.softplus(self.regressor[l]), y) for l in target_labels}
        else:
            mu = {l: rmv(self.regressor[l], y) for l in target_labels}
        # Cholesky factor of the posterior covariance: keep the lower triangle
        scale_tril = {l: rtril(self.scale_tril[l]) for l in target_labels}

        return mu, scale_tril
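
Throughout these examples, rmv, rvv, rtril, rinverse, lexpand and tensor_to_dict are small batched tensor helpers from pyro.contrib.util. A minimal sketch of the semantics of the three used most often (not the library source):

import torch

def rmv(A, b):
    # Right-most matrix-vector product, broadcasting over batch dims
    return A.matmul(b.unsqueeze(-1)).squeeze(-1)

def rvv(a, b):
    # Right-most inner product, broadcasting over batch dims
    return (a * b).sum(-1)

def rtril(M):
    # Lower-triangular part of the right-most two dimensions
    return M.tril()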
Example #2
File: guides.py Project: zyxue/pyro
    def linear_model_formula(self, y, design, target_labels):
        # Tikhonov-regularized (ridge) least squares:
        #   mu = (X^T X + D)^{-1} X^T y, with a learnable positive diagonal D
        tikhonov_diag = torch.diag(self.softplus(self.tikhonov_diag))
        xtx = torch.matmul(design.transpose(-1, -2), design) + tikhonov_diag
        xtxi = rinverse(xtx, sym=True)
        mu = rmv(xtxi, rmv(design.transpose(-1, -2), y))

        # Extract the sub-vector of mu belonging to each target label
        mu = tensor_to_dict(self.w_sizes, mu, subset=target_labels)
        scale_tril = {l: rtril(self.scale_tril[l]) for l in target_labels}

        return mu, scale_tril
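
The formula above is ridge-regularized least squares, mu = (X^T X + D)^{-1} X^T y. A quick unbatched sanity check with plain torch (shapes here are hypothetical):

import torch

X = torch.randn(4, 2)                  # n=4 observations, p=2 features
y = torch.randn(4)
D = torch.diag(torch.full((2,), 0.1))  # Tikhonov diagonal
mu = torch.linalg.solve(X.t() @ X + D, X.t() @ y)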
Example #3
def dv_critic(design, trace, observation_labels, target_labels):
    # Concatenate the sampled targets and observations into one vector x
    y_dict = {l: trace.nodes[l]["value"] for l in observation_labels}
    theta_dict = {l: trace.nodes[l]["value"] for l in target_labels}
    x = torch.cat(list(theta_dict.values()) + list(y_dict.values()), dim=-1)

    # Score x with a learnable quadratic form x^T B x
    B = pyro.param("B", torch.zeros(5, 5))
    return rvv(x, rmv(B, x))
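
For unbatched inputs, rvv(x, rmv(B, x)) reduces to the quadratic form x^T B x; note that B's (5, 5) shape hard-codes the combined dimension of the targets and observations. A quick check:

import torch
from pyro.contrib.util import rmv, rvv

x, B = torch.randn(5), torch.randn(5, 5)
assert torch.allclose(rvv(x, rmv(B, x)), x @ B @ x)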
Example #4
def posterior_guide(y_dict, design, observation_labels, target_labels):
    # Amortized Gaussian posterior over w: the mean is a learnable linear
    # map A of the observations; the covariance is a learnable Cholesky factor
    y = torch.cat(list(y_dict.values()), dim=-1)
    A = pyro.param("A", torch.zeros(2, 3))
    scale_tril = pyro.param("scale_tril", torch.tensor([[1., 0.], [0., 1.5]]),
                            constraint=torch.distributions.constraints.lower_cholesky)
    mu = rmv(A, y)
    pyro.sample("w", dist.MultivariateNormal(mu, scale_tril=scale_tril))
Example #5
def likelihood_guide(theta_dict, design, observation_labels, target_labels):
    # Gaussian likelihood guide centred at the linear predictor X * theta
    theta = torch.cat(list(theta_dict.values()), dim=-1)
    centre = rmv(design, theta)

    # The "_l" suffix avoids a name collision with other guides' params
    mu = pyro.param("mu_l", torch.zeros(3))
    scale_tril = pyro.param("scale_tril_l", torch.eye(3),
                            constraint=torch.distributions.constraints.lower_cholesky)

    pyro.sample("y", dist.MultivariateNormal(centre + mu, scale_tril=scale_tril))
Example #6
def true_model(design):
    # Ground truth: 12 regression coefficients (two groups concatenated)
    w1 = torch.tensor([-1., 1.])
    w2 = torch.tensor([-.5, .5, -.5, .5, -.5, 2., -2., 2., -2., 0.])
    w = torch.cat([w1, w2], dim=-1)
    k = torch.tensor(.1)
    response_mean = rmv(design, w)

    # Sigmoid response: y = sigmoid(k * z) with z ~ Normal(Xw, 1)
    base_dist = dist.Normal(response_mean, torch.tensor(1.)).to_event(1)
    k = k.expand(response_mean.shape)
    transforms = [AffineTransform(loc=0., scale=k), SigmoidTransform()]
    response_dist = dist.TransformedDistribution(base_dist, transforms)
    return pyro.sample("y", response_dist)
Example #7
    def get_params(self, y_dict, design, target_labels):

        y = torch.cat(list(y_dict.values()), dim=-1)

        # Coefficients get the linear-model update; the precision tau
        # gets its own conjugate update below
        coefficient_labels = [label for label in target_labels if label != self.tau_label]
        mu, scale_tril = self.linear_model_formula(y, design, coefficient_labels)
        mu_vec = torch.cat(list(mu.values()), dim=-1)

        # Normal-inverse-gamma rate update: beta = b0 + (y^T y - y^T X mu) / 2
        yty = rvv(y, y)
        ytxmu = rvv(y, rmv(design, mu_vec))
        beta = self.b0 + .5*(yty - ytxmu)

        return mu, scale_tril, self.alpha, beta
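
The two reductions above are ordinary dot products; unbatched, they can be checked against plain torch (shapes here are hypothetical):

import torch
from pyro.contrib.util import rmv, rvv

y, X, mu = torch.randn(4), torch.randn(4, 2), torch.randn(2)
assert torch.allclose(rvv(y, y), y @ y)
assert torch.allclose(rvv(y, rmv(X, mu)), y @ (X @ mu))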
Example #8
    def model(design):
        # sigmoid_design, sigmoid_label, sigmoid_alpha/beta and the coef*
        # and observation_* names below are closed over from the enclosing
        # factory function
        batch_shape = design.shape[:-2]
        k_shape = batch_shape + (sigmoid_design.shape[-1],)
        # Sample the sigmoid slopes from a Gamma prior and map them onto
        # the observations via the sigmoid design matrix
        k = pyro.sample(
            sigmoid_label,
            dist.Gamma(
                sigmoid_alpha.expand(k_shape), sigmoid_beta.expand(k_shape)
            ).to_event(1),
        )
        k_assigned = rmv(sigmoid_design, k)

        return bayesian_linear_model(
            design,
            w_means=OrderedDict([(coef1_label, coef1_mean), (coef2_label, coef2_mean)]),
            w_sqrtlambdas={
                coef1_label: 1.0 / (observation_sd * coef1_sd),
                coef2_label: 1.0 / (observation_sd * coef2_sd),
            },
            obs_sd=observation_sd,
            response="sigmoid",
            response_label=observation_label,
            k=k_assigned,
        )
Example #9
File: glmm.py Project: jamestwebber/pyro
import warnings
from contextlib import ExitStack

import torch
import pyro
import pyro.distributions as dist
from torch.distributions.transforms import AffineTransform, SigmoidTransform
from pyro.contrib.util import iter_plates_to_shape, rmv

# Note: broadcast_cat, used below, is a small helper defined alongside this
# model in glmm.py; it concatenates the coefficient groups on the final
# dimension after broadcasting their batch dimensions.


def bayesian_linear_model(design,
                          w_means={},
                          w_sqrtlambdas={},
                          re_group_sizes={},
                          re_alphas={},
                          re_betas={},
                          obs_sd=None,
                          alpha_0=None,
                          beta_0=None,
                          response="normal",
                          response_label="y",
                          k=None):
    """
    A pyro model for Bayesian linear regression.

    If :param:`response` is `"normal"` this corresponds to a linear regression
    model

        :math:`Y = Xw + \\epsilon`

    with :math:`\\epsilon` i.i.d. zero-mean Gaussian. The observation standard deviation
    (:param:`obs_sd`) may be known or unknown. If unknown, it is assumed to follow an
    inverse Gamma distribution with parameters :param:`alpha_0` and :param:`beta_0`.

    If the response type is `"bernoulli"` we instead have :math:`Y \\sim Bernoulli(p)`
    with

        :math:`logit(p) = Xw`

    Given parameter groups in :param:`w_means` and :param:`w_sqrtlambdas`, the fixed effects
    regression coefficient is taken to be Gaussian with the corresponding mean and standard deviation
    given by

        :math:`\\sigma / \\sqrt{\\lambda}`

    corresponding to the normal inverse Gamma family.

    The random effects coefficients are constructed as follows. For each random
    effect group, a precision is sampled from a Gamma distribution and converted
    to a standard deviation. A random effect coefficient for the group is then
    sampled from a zero-mean Gaussian with that standard deviation.

    :param torch.Tensor design: a tensor with last two dimensions `n` and `p`
            corresponding to observations and features respectively.
    :param OrderedDict w_means: map from variable names to tensors of fixed effect means.
    :param OrderedDict w_sqrtlambdas: map from variable names to tensors of square root
        :math:`\\lambda` values for fixed effects.
    :param OrderedDict re_group_sizes: map from variable names to int, the size of
        each random effect group.
    :param OrderedDict re_alphas: map from variable names to `torch.Tensor` of
        Gamma distribution :math:`\\alpha` values.
    :param OrderedDict re_betas: map from variable names to `torch.Tensor` of
        Gamma distribution :math:`\\beta` values.
    :param torch.Tensor obs_sd: the observation standard deviation (if assumed known).
        This is still relevant in the case of Bernoulli observations when coefficients
        are sampled using `w_sqrtlambdas`.
    :param torch.Tensor alpha_0: Gamma :math:`\\alpha` parameter for unknown observation
        covariance.
    :param torch.Tensor beta_0: Gamma :math:`\\beta` parameter for unknown observation
        covariance.
    :param str response: Emission distribution. May be `"normal"` or `"bernoulli"`.
    :param str response_label: Variable label for response.
    :param torch.Tensor k: Only used for a sigmoid response. The slope of the sigmoid
        transformation.
    """
    # design is size batch x n x p
    # tau is size batch
    batch_shape = design.shape[:-2]
    with ExitStack() as stack:
        for plate in iter_plates_to_shape(batch_shape):
            stack.enter_context(plate)

        if obs_sd is None:
            # First, sample tau (observation precision)
            tau_prior = dist.Gamma(alpha_0.unsqueeze(-1),
                                   beta_0.unsqueeze(-1)).to_event(1)
            tau = pyro.sample("tau", tau_prior)
            obs_sd = 1. / torch.sqrt(tau)

        elif alpha_0 is not None or beta_0 is not None:
            warnings.warn("Values of `alpha_0` and `beta_0` are unused because "
                          "`obs_sd` was specified already.")

        obs_sd = obs_sd.expand(batch_shape + (1, ))

        # Build the regression coefficient
        w = []
        # Allow different names for different coefficient groups
        # Process fixed effects
        for name, w_sqrtlambda in w_sqrtlambdas.items():
            w_mean = w_means[name]
            # Place a normal prior on the regression coefficient
            w_prior = dist.Normal(w_mean, obs_sd / w_sqrtlambda).to_event(1)
            w.append(pyro.sample(name, w_prior))
        # Process random effects
        for name, group_size in re_group_sizes.items():
            # Sample `G` once for this group
            alpha, beta = re_alphas[name], re_betas[name]
            G_prior = dist.Gamma(alpha, beta).to_event(1)
            G = 1. / torch.sqrt(pyro.sample("G_" + name, G_prior))
            # Repeat `G` for each group
            repeat_shape = tuple(1 for _ in batch_shape) + (group_size, )
            u_prior = dist.Normal(torch.tensor(0.),
                                  G.repeat(repeat_shape)).to_event(1)
            w.append(pyro.sample(name, u_prior))
        # Regression coefficient `w` is batch x p
        w = broadcast_cat(w)

        # Run the regressor forward conditioned on inputs
        prediction_mean = rmv(design, w)
        if response == "normal":
            # y is an n-vector: hence use .to_event(1)
            return pyro.sample(
                response_label,
                dist.Normal(prediction_mean, obs_sd).to_event(1))
        elif response == "bernoulli":
            return pyro.sample(
                response_label,
                dist.Bernoulli(logits=prediction_mean).to_event(1))
        elif response == "sigmoid":
            base_dist = dist.Normal(prediction_mean, obs_sd).to_event(1)
            # You can add loc via the linear model itself
            k = k.expand(prediction_mean.shape)
            transforms = [
                AffineTransform(loc=torch.tensor(0.), scale=k),
                SigmoidTransform()
            ]
            response_dist = dist.TransformedDistribution(base_dist, transforms)
            return pyro.sample(response_label, response_dist)
        else:
            raise ValueError(
                "Unknown response distribution: '{}'".format(response))
Example #10
File: test_util.py Project: yufengwa/pyro
def test_rmv(A, b):
    # Unbatched: rmv must agree with torch's mv
    assert_equal(rmv(A, b), A.mv(b), prec=1e-8)
    # Batched: left-expand both operands to a (5, 4) batch and check that
    # rmv broadcasts the matrix-vector product over it
    batched_A = lexpand(A, 5, 4)
    batched_b = lexpand(b, 5, 4)
    expected_Ab = lexpand(A.mv(b), 5, 4)
    assert_equal(rmv(batched_A, batched_b), expected_Ab, prec=1e-8)
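
lexpand (also from pyro.contrib.util) expands a tensor by adding new leading batch dimensions; a minimal sketch of its behaviour (not the library source):

import torch

def lexpand(A, *dims):
    # Expand A on the left with the given batch dimensions
    return A.expand(tuple(dims) + A.shape)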