Example #1
  def forward(self, belief, state, deterministic=False, with_logprob=False,):
    raw_init_std = np.log(np.exp(self.init_std) - 1)
    hidden = self.act_fn(self.fc1(torch.cat([belief, state], dim=-1)))
    hidden = self.act_fn(self.fc2(hidden))
    hidden = self.act_fn(self.fc3(hidden))
    hidden = self.act_fn(self.fc4(hidden))
    hidden = self.fc5(hidden)
    mean, std = torch.chunk(hidden, 2, dim=-1)

    # # ---------
    # mean = self.mean_scale * torch.tanh(mean / self.mean_scale)  # bound the action to [-5, 5] --> to avoid numerical instabilities.  For computing log-probabilities, we need to invert the tanh and this becomes difficult in highly saturated regions.
    # speed = torch.full(mean.shape, 0.3).to("cuda")
    # mean = torch.cat((mean, speed), -1)
    #
    # std = F.softplus(std + raw_init_std) + self.min_std
    #
    # speed = torch.full(std.shape, 0.0).to("cuda")
    # std = torch.cat((std, speed), -1)
    #
    # dist = torch.distributions.Normal(mean, std)
    # transform = [torch.distributions.transforms.TanhTransform()]
    # dist = torch.distributions.TransformedDistribution(dist, transform)
    # dist = torch.distributions.independent.Independent(dist, 1)  # introduces dependence between action dimensions
    # dist = SampleDist(dist)  # after transforming a distribution, some methods (entropy, mean, mode) become invalid, so SampleDist approximates them
    # return dist  # dist ~ tanh(Normal(mean, std)); when sampling, use rsample() for the reparameterization trick


    mean = self.mean_scale * torch.tanh(mean / self.mean_scale)  # bound the mean to [-mean_scale, mean_scale] to avoid numerical instabilities: computing log-probabilities requires inverting the tanh, which is ill-conditioned in highly saturated regions
    std = F.softplus(std + raw_init_std) + self.min_std

    dist = torch.distributions.Normal(mean, std)
    # TanhTransform = ComposeTransform([AffineTransform(0., 2.), SigmoidTransform(), AffineTransform(-1., 2.)])
    if self.fix_speed:
      transform = [AffineTransform(0., 2.), SigmoidTransform(), AffineTransform(-1., 2.)]

    else:
      transform = [AffineTransform(0., 2.), SigmoidTransform(), AffineTransform(-1., 2.),  # TanhTransform
                   AffineTransform(loc=torch.tensor([0.0, self.throtlle_base]).to("cuda"),
                                   scale=torch.tensor([1.0, 0.2]).to("cuda"))]  # TODO: this is specific to the donkeycar env

    dist = TransformedDistribution(dist, transform)
    # dist = torch.distributions.independent.Independent(dist, 1)  # Introduces dependence between actions dimension
    dist = SampleDist(dist)  # after transforming a distribution, some methods (entropy, mean, mode) become invalid, so SampleDist approximates them

    if deterministic:
      action = dist.mean
    else:
      action = dist.rsample()

    # log-prob is currently unused
    if with_logprob:
      logp_pi = dist.log_prob(action).sum(dim=1)
    else:
      logp_pi = None
    # action: [batch, act_dim], logp_pi: [batch]
    # dist ~ tanh(Normal(mean, std)); when sampling, use rsample() for the reparameterization trick
    if self.fix_speed:
      action = torch.cat((action, self.throtlle_base * torch.ones_like(action, requires_grad=False)), dim=-1)
    return action, logp_pi
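A quick side note on the Affine/Sigmoid/Affine chain used above: it is exactly tanh, since tanh(x) = 2*sigmoid(2x) - 1. A minimal standalone check (not part of the example):

import torch
from torch.distributions.transforms import (AffineTransform, SigmoidTransform,
                                             ComposeTransform)

# compose: x -> 2x -> sigmoid(2x) -> 2*sigmoid(2x) - 1 == tanh(x)
tanh_like = ComposeTransform([AffineTransform(0., 2.), SigmoidTransform(),
                              AffineTransform(-1., 2.)])
x = torch.linspace(-3., 3., 7)
assert torch.allclose(tanh_like(x), torch.tanh(x), atol=1e-6)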
Example #2
 def __init__(self, loc, scale, validate_args=None):
     self.loc, self.scale = broadcast_all(loc, scale)
     finfo = _finfo(self.loc)
     if isinstance(loc, Number) and isinstance(scale, Number):
         base_dist = Uniform(finfo.tiny, 1 - finfo.eps)
     else:
         base_dist = Uniform(self.loc.new(self.loc.size()).fill_(finfo.tiny), 1 - finfo.eps)
     transforms = [ExpTransform().inv, AffineTransform(loc=0, scale=-torch.ones_like(self.scale)),
                   ExpTransform().inv, AffineTransform(loc=loc, scale=-self.scale)]
     super(Gumbel, self).__init__(base_dist, transforms, validate_args=validate_args)
Example #3
    def __init__(self, w, p, temperature=0.1, validate_args=None):
        relaxed_bernoulli = RelaxedBernoulli(temperature, p)
        affine_transform = AffineTransform(0, w)
        one_minus_p = AffineTransform(1, -1)
        super(BernoulliDropoutDistribution,
              self).__init__(relaxed_bernoulli,
                             ComposeTransform([one_minus_p, affine_transform]),
                             validate_args)

        self.relaxed_bernoulli = relaxed_bernoulli
        self.affine_transform = affine_transform
Example #4
    def __init__(self, a, b, validate_args=None):
        self.a, self.b = broadcast_all(a, b)
        self.a_reciprocal = self.a.reciprocal()
        self.b_reciprocal = self.b.reciprocal()
        base_dist = Uniform(torch.full_like(self.a, EPS),
                            torch.full_like(self.a, 1. - EPS))
        transforms = [
            AffineTransform(loc=1, scale=-1),
            PowerTransform(self.b_reciprocal),
            AffineTransform(loc=1, scale=-1),
            PowerTransform(self.a_reciprocal)
        ]

        super(Kumaraswamy, self).__init__(base_dist,
                                          transforms,
                                          validate_args=validate_args)
Example #5
 def __init__(self, scale, alpha, validate_args=None):
     self.scale, self.alpha = broadcast_all(scale, alpha)
     base_dist = Exponential(self.alpha, validate_args=validate_args)
     transforms = [ExpTransform(), AffineTransform(loc=0, scale=self.scale)]
     super(Pareto, self).__init__(base_dist,
                                  transforms,
                                  validate_args=validate_args)
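As a sanity sketch (assuming the Pareto class defined above), the transform chain says X = scale * exp(E) with E ~ Exponential(alpha), so log(X / scale) should be Exponential(alpha) with mean 1/alpha:

import torch

p = Pareto(scale=torch.tensor(1.0), alpha=torch.tensor(3.0))
s = p.sample((20000,))
# with scale = 1, log(X) = E, whose mean is 1/alpha
print(torch.log(s).mean())  # ~ 0.33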
Example #6
 def __init__(self, concentration1, concentration0, loc, scale, validate_args=None):
     base_dist = Beta(concentration1, concentration0, validate_args=validate_args)
     super(AffineBeta, self).__init__(
         base_dist,
         AffineTransform(loc=loc, scale=scale),
         validate_args=validate_args,
     )
Example #7
def get_transforms(cache_size):
    transforms = [
        AbsTransform(cache_size=cache_size),
        ExpTransform(cache_size=cache_size),
        PowerTransform(exponent=2,
                       cache_size=cache_size),
        PowerTransform(exponent=torch.tensor(5.).normal_(),
                       cache_size=cache_size),
        SigmoidTransform(cache_size=cache_size),
        TanhTransform(cache_size=cache_size),
        AffineTransform(0, 1, cache_size=cache_size),
        AffineTransform(1, -2, cache_size=cache_size),
        AffineTransform(torch.randn(5),
                        torch.randn(5),
                        cache_size=cache_size),
        AffineTransform(torch.randn(4, 5),
                        torch.randn(4, 5),
                        cache_size=cache_size),
        SoftmaxTransform(cache_size=cache_size),
        StickBreakingTransform(cache_size=cache_size),
        LowerCholeskyTransform(cache_size=cache_size),
        CorrCholeskyTransform(cache_size=cache_size),
        ComposeTransform([
            AffineTransform(torch.randn(4, 5),
                            torch.randn(4, 5),
                            cache_size=cache_size),
        ]),
        ComposeTransform([
            AffineTransform(torch.randn(4, 5),
                            torch.randn(4, 5),
                            cache_size=cache_size),
            ExpTransform(cache_size=cache_size),
        ]),
        ComposeTransform([
            AffineTransform(0, 1, cache_size=cache_size),
            AffineTransform(torch.randn(4, 5),
                            torch.randn(4, 5),
                            cache_size=cache_size),
            AffineTransform(1, -2, cache_size=cache_size),
            AffineTransform(torch.randn(4, 5),
                            torch.randn(4, 5),
                            cache_size=cache_size),
        ]),
    ]
    transforms += [t.inv for t in transforms]
    return transforms
Example #8
 def __init__(self, data_dim=28 * 28, device='cpu'):
     self.m = TransformedDistribution(
         Uniform(torch.zeros(data_dim, device=device),
                 torch.ones(data_dim, device=device)), [
                     SigmoidTransform().inv,
                     AffineTransform(torch.zeros(data_dim, device=device),
                                     torch.ones(data_dim, device=device))
                 ])
Example #9
 def expand(self, batch_shape, _instance=None):
     new = self._get_checked_instance(Kuma, _instance)
     new.a = self.a.expand(batch_shape)
     new.b = self.b.expand(batch_shape)
     new.a_reciprocal = new.a.reciprocal()
     new.b_reciprocal = new.b.reciprocal()
     base_dist = self.base_dist.expand(batch_shape)
     transforms = [
         AffineTransform(loc=1, scale=-1),
         PowerTransform(self.b_reciprocal),
         AffineTransform(loc=1, scale=-1),
         PowerTransform(self.a_reciprocal)
     ]
     super(Kumaraswamy, new).__init__(base_dist,
                                      transforms,
                                      validate_args=False)
     new._validate_args = self._validate_args
     return new
Example #10
 def __init__(self, scale, concentration, validate_args=None):
     self.scale, self.concentration = broadcast_all(scale, concentration)
     self.concentration_reciprocal = self.concentration.reciprocal()
     base_dist = Exponential(torch.ones_like(self.scale), validate_args=validate_args)
     transforms = [PowerTransform(exponent=self.concentration_reciprocal),
                   AffineTransform(loc=0, scale=self.scale)]
     super(Weibull, self).__init__(base_dist,
                                   transforms,
                                   validate_args=validate_args)
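A short sanity sketch (assuming the Weibull class above): the chain implements X = scale * E**(1/concentration) with E ~ Exponential(1); for concentration = 1 this reduces to an exponential with mean equal to scale:

import torch

w = Weibull(scale=torch.tensor(2.0), concentration=torch.tensor(1.0))
s = w.sample((20000,))
print(s.mean())  # ~ 2.0 when concentration == 1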
Example #11
    def __init__(self, w, p, l, temperature=0.1, validate_args=None):
        relaxed_bernoulli = RelaxedBernoulli(temperature, p)
        affine_transform = AffineTransform(w, l - w)
        super(ToeplitzBernoulliDistribution,
              self).__init__(relaxed_bernoulli, affine_transform,
                             validate_args)

        self.relaxed_bernoulli = relaxed_bernoulli
        self.affine_transform = affine_transform
Example #12
 def __init__(self, concentration1, concentration0, validate_args=None):
     self.concentration1, self.concentration0 = broadcast_all(concentration1, concentration0)
     finfo = torch.finfo(self.concentration0.dtype)
     base_dist = Uniform(torch.full_like(self.concentration0, 0),
                         torch.full_like(self.concentration0, 1),
                         validate_args=validate_args)
     transforms = [PowerTransform(exponent=self.concentration0.reciprocal()),
                   AffineTransform(loc=1., scale=-1.),
                   PowerTransform(exponent=self.concentration1.reciprocal())]
     super(Kumaraswamy, self).__init__(base_dist, transforms, validate_args=validate_args)
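Illustrative check (assuming the Kumaraswamy class above): the chain maps U -> U**(1/b) -> 1 - U**(1/b) -> (1 - U**(1/b))**(1/a), an inverse-CDF style construction whose samples stay in the unit interval:

import torch

k = Kumaraswamy(torch.tensor(2.0), torch.tensor(5.0))
s = k.sample((1000,))
assert ((s >= 0.) & (s <= 1.)).all()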
Example #13
def reshape_transform(transform, shape):
    # Needed to squash batch dims for testing jacobian
    if isinstance(transform, AffineTransform):
        if isinstance(transform.loc, Number):
            return transform
        try:
            return AffineTransform(transform.loc.expand(shape), transform.scale.expand(shape), cache_size=transform._cache_size)
        except RuntimeError:
            return AffineTransform(transform.loc.reshape(shape), transform.scale.reshape(shape), cache_size=transform._cache_size)
    if isinstance(transform, ComposeTransform):
        reshaped_parts = []
        for p in transform.parts:
            reshaped_parts.append(reshape_transform(p, shape))
        return ComposeTransform(reshaped_parts, cache_size=transform._cache_size)
    if isinstance(transform.inv, AffineTransform):
        return reshape_transform(transform.inv, shape).inv
    if isinstance(transform.inv, ComposeTransform):
        return reshape_transform(transform.inv, shape).inv
    return transform
Example #14
 def __init__(
     self,
     concentration1: Union[float, Tensor],
     concentration0: Union[float, Tensor],
     validate_args: bool = False,
 ):
     self.concentration1, self.concentration0 = broadcast_all(
         concentration1, concentration0)
     base_dist = Uniform(
         torch.full_like(self.concentration0, 0.0),
         torch.full_like(self.concentration0, 1.0),
     )
     transforms = [
         AffineTransform(loc=1.0, scale=-1.0),
         PowerTransform(exponent=self.concentration0.reciprocal()),
         AffineTransform(loc=1.0, scale=-1.0),
         PowerTransform(exponent=self.concentration1.reciprocal()),
     ]
     super().__init__(base_dist, transforms, validate_args=validate_args)
Example #15
def test_save_load_transform():
    # Evaluating `log_prob` will create a weakref `_inv` which cannot be pickled. Here, we check
    # that `__getstate__` correctly handles the weakref, and that we can evaluate the density after.
    dist = TransformedDistribution(Normal(0, 1), [AffineTransform(2, 3)])
    x = torch.linspace(0, 1, 10)
    log_prob = dist.log_prob(x)
    stream = io.BytesIO()
    torch.save(dist, stream)
    stream.seek(0)
    other = torch.load(stream)
    assert torch.allclose(log_prob, other.log_prob(x))
Example #16
def true_model(design):
    w1 = torch.tensor([-1., 1.])
    w2 = torch.tensor([-.5, .5, -.5, .5, -.5, 2., -2., 2., -2., 0.])
    w = torch.cat([w1, w2], dim=-1)
    k = torch.tensor(.1)
    response_mean = rmv(design, w)

    base_dist = dist.Normal(response_mean, torch.tensor(1.)).to_event(1)
    k = k.expand(response_mean.shape)
    transforms = [AffineTransform(loc=0., scale=k), SigmoidTransform()]
    response_dist = dist.TransformedDistribution(base_dist, transforms)
    return pyro.sample("y", response_dist)
Example #17
 def forward(self, state, mean_action=False):
     mu, log_std = self.network(state).chunk(2, dim=-1)
     log_std = torch.clamp(
         log_std, LOG_MIN,
         LOG_MAX)  # to make it not too random/deterministic
     normal = TransformedDistribution(
         Independent(Normal(mu, log_std.exp()), 1),
         [TanhTransform(),
          AffineTransform(loc=self.loc, scale=self.scale)])
      if mean_action:
          return self.scale * torch.tanh(mu) + self.loc
     return normal
Example #18
 def expand(self, batch_shape, _instance=None):
     new = self._get_checked_instance(Weibull, _instance)
     new.scale = self.scale.expand(batch_shape)
     new.concentration = self.concentration.expand(batch_shape)
     new.concentration_reciprocal = new.concentration.reciprocal()
     base_dist = self.base_dist.expand(batch_shape)
     transforms = [PowerTransform(exponent=new.concentration_reciprocal),
                   AffineTransform(loc=0, scale=new.scale)]
     super(Weibull, new).__init__(base_dist,
                                  transforms,
                                  validate_args=False)
     new._validate_args = self._validate_args
     return new
Example #19
 def forward(self, state):
     policy_mean, policy_log_std = self.policy(state).chunk(2, dim=1)
     policy_log_std = torch.clamp(policy_log_std,
                                  min=self.log_std_min,
                                  max=self.log_std_max)
     policy = TransformedDistribution(
         Independent(Normal(policy_mean, policy_log_std.exp()), 1), [
             TanhTransform(),
             AffineTransform(loc=self.action_loc, scale=self.action_scale)
         ])
     policy.mean_ = self.action_scale * torch.tanh(
         policy.base_dist.mean
     ) + self.action_loc  # TODO: See if mean attr can be overwritten
     return policy
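A small sketch (hypothetical bounds) of the usual convention for action_loc and action_scale so that the tanh output (-1, 1) maps onto an action box [low, high]:

import torch

low, high = torch.tensor([-2.0, 0.0]), torch.tensor([2.0, 1.0])
action_scale = (high - low) / 2
action_loc = (high + low) / 2
# the extremes of tanh land exactly on the box edges
assert torch.allclose(action_scale * torch.tensor(-1.0) + action_loc, low)
assert torch.allclose(action_scale * torch.tensor(1.0) + action_loc, high)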
Example #20
def test_compose_affine(event_dims):
    transforms = [AffineTransform(torch.zeros((1,) * e), 1, event_dim=e) for e in event_dims]
    transform = ComposeTransform(transforms)
    assert transform.codomain.event_dim == max(event_dims)
    assert transform.domain.event_dim == max(event_dims)

    base_dist = Normal(0, 1)
    if transform.domain.event_dim:
        base_dist = base_dist.expand((1,) * transform.domain.event_dim)
    dist = TransformedDistribution(base_dist, transform.parts)
    assert dist.support.event_dim == max(event_dims)

    base_dist = Dirichlet(torch.ones(5))
    if transform.domain.event_dim > 1:
        base_dist = base_dist.expand((1,) * (transform.domain.event_dim - 1))
    dist = TransformedDistribution(base_dist, transforms)
    assert dist.support.event_dim == max(1, max(event_dims))
Example #21
def sigmoid_example(design):
    n = design.shape[-2]
    random_effect_k = pyro.sample("k", dist.Gamma(2.*torch.ones(n), torch.tensor(2.)))
    random_effect_offset = pyro.sample("w2", dist.Normal(torch.tensor(0.), torch.ones(n)))
    w1 = pyro.sample("w1", dist.Normal(torch.tensor([1., -1.]),
                                       torch.tensor([10., 10.])).to_event(1))
    mean = torch.matmul(design[..., :-2], w1.unsqueeze(-1)).squeeze(-1)
    offset_mean = mean + random_effect_offset

    base_dist = dist.Normal(offset_mean, torch.tensor(1.)).to_event(1)
    transforms = [
        AffineTransform(loc=torch.tensor(0.), scale=random_effect_k),
        SigmoidTransform()
    ]
    response_dist = dist.TransformedDistribution(base_dist, transforms)
    y = pyro.sample("y", response_dist)
    return y
Example #22
def test_logistic():
    base_distribution = Uniform(0, 1)
    transforms = [SigmoidTransform().inv, AffineTransform(loc=torch.tensor([2.]), scale=torch.tensor([1.]))]
    model = TransformedDistribution(base_distribution, transforms)
    transform = Logistic(2., 1.)

    x = model.sample((4,)).reshape(-1, 1)
    assert torch.all(transform.log_prob(x)- model.log_prob(x).view(-1) < 1e-4)

    x = transform.sample(4)
    assert x.shape == (4, 1)
    assert torch.all(transform.log_prob(x)- model.log_prob(x).view(-1) < 1e-4)

    x = transform.sample(1)
    assert x.shape == (1, 1)
    assert torch.all(transform.log_prob(x)- model.log_prob(x).view(-1) < 1e-4)

    transform.get_parameters()
Example #23
    def SampleAction(self, mean, std):
        # mean and std are predicted by the neural network; this function
        # samples an action from the resulting policy distribution
        mu = mean
        sig = std * 0.3  # shrink the standard deviation to 0.3 of its predicted value
        u_range = self.args['U_UB'] - self.args['U_LB']

        GPol = norm(
            mu, sig
        )  # defining gaussian distribution with mean and std as parameterised
        scale = AffineTransform(self.args['U_LB'], u_range)
        GPol = TransformedDistribution(GPol, scale)

        action = GPol.sample()  # drawing randomly from normal distribution
        assert len(action) == 1
        logGP = GPol.log_prob(
            action)  # calculating log probability of action taken

        return action.cpu(), logGP
Example #24
 def __init__(self,
              obs_dim,
              act_dim,
              act_low,
              act_high,
              log_std_min=-20,
              log_std_max=20,
              hidden_size=256):
     super(GaussianActorNetwork, self).__init__(obs_dim,
                                                hidden_size=hidden_size)
     self._mean_layer = nn.Linear(self._hidden_size, act_dim)
     self._std_layer = nn.Linear(self._hidden_size, act_dim)
     self._act_dim = act_dim
     self._log_std_min = log_std_min
     self._log_std_max = log_std_max
     act_scale = torch.FloatTensor(act_high - act_low).to(device)
     act_low = torch.FloatTensor(act_low).to(device)
     self._transforms = [
         SigmoidTransform(),
         AffineTransform(loc=act_low, scale=act_scale)
     ]
Example #25
def forge_distribution(mean, sigma, lower_limit=0.0, upper_limit=5.0):
    """
    Find the required concentration hyperparameters in the canonical Beta distribution
    that will return the desired mean and deviation after the affine transformation.
    """
    width = upper_limit - lower_limit
    assert width > 0
    assert sigma < EPS + width / 2, f"invalid std: {sigma.item()}"

    canonical_mean = (mean - lower_limit) / width
    canonical_sigma = sigma / width  # the std scales linearly under the affine transform

    alpha_plus_beta = (canonical_mean *
                       (1 - canonical_mean) / canonical_sigma**2) - 1
    alpha = canonical_mean * alpha_plus_beta
    beta = (1 - canonical_mean) * alpha_plus_beta

    canonical = Beta(alpha, beta)
    transformation = AffineTransform(loc=lower_limit, scale=width)
    transformed = TransformedDistribution(canonical, transformation)

    return transformed
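A hypothetical usage sketch (values invented; EPS is the module-level constant referenced above); the transformed samples should have roughly the requested mean:

import torch

d = forge_distribution(torch.tensor(2.0), torch.tensor(0.5))
x = d.sample((10000,))
print(x.mean())  # ~ 2.0, the requested mean on [0, 5]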
Example #26
    def __init__(self, prior, coupling, in_out_dim, mid_dim, hidden,
                 bottleneck, compress, device, n_layers):
        """Initialize a NICE.

        Args:
            coupling: number of coupling layers.
            in_out_dim: input/output dimensions.
            mid_dim: number of units in a hidden layer.
            hidden: number of hidden layers.
            device: run on cpu or gpu
        """
        super(NICE, self).__init__()
        self.device = device
        if prior == 'gaussian':
            self.prior = torch.distributions.Normal(
                torch.tensor(0.).to(device),
                torch.tensor(1.).to(device))
        elif prior == 'logistic':
            self.prior = TransformedDistribution(
                Uniform(
                    torch.tensor(0.).to(device),
                    torch.tensor(1.).to(device)),
                [SigmoidTransform().inv,
                 AffineTransform(loc=0., scale=1.)])
        else:
            raise ValueError('Prior not implemented.')

        self.in_out_dim = in_out_dim
        self.coupling = coupling
        self.n_layers = n_layers
        layer = AdditiveCoupling if coupling == 'additive' else AffineCoupling
        self.coupling_layers = nn.ModuleList([
            layer(in_out_dim, mid_dim, hidden, i % 2)
            for i in range(self.n_layers)
        ]).to(device)
        self.scale = Scaling(in_out_dim).to(device)
        self.bottleneck_factor = compress
        self.bottleneck_loss = nn.MSELoss()
        self.bottleneck = bottleneck
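Side sketch of the 'logistic' prior above: the logit of a Uniform(0, 1) is the standard logistic distribution, whose log-density has the closed form log p(x) = -x - 2*softplus(-x); the TransformedDistribution reproduces it:

import torch
import torch.nn.functional as F
from torch.distributions import Uniform, TransformedDistribution
from torch.distributions.transforms import SigmoidTransform, AffineTransform

prior = TransformedDistribution(
    Uniform(torch.tensor(0.), torch.tensor(1.)),
    [SigmoidTransform().inv, AffineTransform(loc=0., scale=1.)])
x = torch.linspace(-4., 4., 9)
assert torch.allclose(prior.log_prob(x), -x - 2 * F.softplus(-x), atol=1e-5)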
Example #27
def test_transformed_distribution(base_batch_dim, base_event_dim,
                                  transform_dim, num_transforms, sample_shape):
    shape = torch.Size([2, 3, 4, 5])
    base_dist = Normal(0, 1)
    base_dist = base_dist.expand(shape[4 - base_batch_dim - base_event_dim:])
    if base_event_dim:
        base_dist = Independent(base_dist, base_event_dim)
    transforms = [
        AffineTransform(torch.zeros(shape[4 - transform_dim:]), 1),
        ReshapeTransform((4, 5), (20, )),
        ReshapeTransform((3, 20), (6, 10))
    ]
    transforms = transforms[:num_transforms]
    transform = ComposeTransform(transforms)

    # Check validation in .__init__().
    if base_batch_dim + base_event_dim < transform.domain.event_dim:
        with pytest.raises(ValueError):
            TransformedDistribution(base_dist, transforms)
        return
    d = TransformedDistribution(base_dist, transforms)

    # Check sampling is sufficiently expanded.
    x = d.sample(sample_shape)
    assert x.shape == sample_shape + d.batch_shape + d.event_shape
    num_unique = len(set(x.reshape(-1).tolist()))
    assert num_unique >= 0.9 * x.numel()

    # Check log_prob shape on full samples.
    log_prob = d.log_prob(x)
    assert log_prob.shape == sample_shape + d.batch_shape

    # Check log_prob shape on partial samples.
    y = x
    while y.dim() > len(d.event_shape):
        y = y[0]
    log_prob = d.log_prob(y)
    assert log_prob.shape == d.batch_shape
Example #28
    def __init__(self, a, theta, alpha, beta):
        """
        The Amoroso distribution is a very flexible 4 parameter distribution which 
        contains many important exponential families as special cases. 

        *PDF*
        ```
        Amoroso(x | a, θ, α, β) = 1/gamma(α) * abs(β/θ) * ((x - a)/θ)**(α*β-1) * exp(-((x - a)/θ)**β)
        for:
            x, a, θ, α, β \in reals, α > 0
        support:
            x >= a if θ > 0
            x <= a if θ < 0
        ```
        """
        self.a, self.theta, self.alpha, self.beta = broadcast_all(
            a, theta, alpha, beta)

        base_dist = Gamma(self.alpha, 1.)
        transform = ComposeTransform([
            AffineTransform(-self.a / self.theta, 1 / self.theta),
            PowerTransform(self.beta),
        ]).inv
        super().__init__(base_dist, transform)
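A quick check (assuming the class above is named Amoroso, as in its docstring): with a = 0, theta = 1, beta = 1 the transform chain is the identity, so the density should match a plain Gamma(alpha, 1):

import torch
from torch.distributions import Gamma

am = Amoroso(torch.tensor(0.), torch.tensor(1.), torch.tensor(3.), torch.tensor(1.))
x = torch.linspace(0.5, 5.0, 10)
ref = Gamma(torch.tensor(3.), torch.tensor(1.)).log_prob(x)
assert torch.allclose(am.log_prob(x), ref, atol=1e-5)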
Example #29
 def create_distribution(self, scale, shape, shift):
     wd = Weibull(scale=scale, concentration=shape)
     transforms = AffineTransform(loc=shift, scale=1.)
     weibull = TransformedDistribution(wd, transforms)
     return weibull
Example #30
def bayesian_linear_model(design,
                          w_means={},
                          w_sqrtlambdas={},
                          re_group_sizes={},
                          re_alphas={},
                          re_betas={},
                          obs_sd=None,
                          alpha_0=None,
                          beta_0=None,
                          response="normal",
                          response_label="y",
                          k=None):
    """
    A pyro model for Bayesian linear regression.

    If :param:`response` is `"normal"` this corresponds to a linear regression
    model

        :math:`Y = Xw + \\epsilon`

    with :math:`\\epsilon` i.i.d. zero-mean Gaussian. The observation standard deviation
    (:param:`obs_sd`) may be known or unknown. If unknown, it is assumed to follow an
    inverse Gamma distribution with parameters :param:`alpha_0` and :param:`beta_0`.

    If the response type is `"bernoulli"` we instead have :math:`Y \\sim Bernoulli(p)`
    with

        :math:`logit(p) = Xw`

    Given parameter groups in :param:`w_means` and :param:`w_sqrtlambdas`, the fixed effects
    regression coefficient is taken to be Gaussian with mean `w_mean` and standard deviation
    given by

        :math:`\\sigma / \\sqrt{\\lambda}`

    corresponding to the normal inverse Gamma family.

    The random effects coefficient is constructed as follows. For each random effect
    group, standard deviations for that group are sampled from a normal inverse Gamma
    distribution. For each group, a random effect coefficient is then sampled from a zero
    mean Gaussian with those standard deviations.

    :param torch.Tensor design: a tensor with last two dimensions `n` and `p`
            corresponding to observations and features respectively.
    :param OrderedDict w_means: map from variable names to tensors of fixed effect means.
    :param OrderedDict w_sqrtlambdas: map from variable names to tensors of square root
        :math:`\\lambda` values for fixed effects.
    :param OrderedDict re_group_sizes: map from variable names to int representing the
        group size
    :param OrderedDict re_alphas: map from variable names to `torch.Tensor`, the tensor
        consists of Gamma dist :math:`\\alpha` values
    :param OrderedDict re_betas: map from variable names to `torch.Tensor`, the tensor
        consists of Gamma dist :math:`\\beta` values
    :param torch.Tensor obs_sd: the observation standard deviation (if assumed known).
        This is still relevant in the case of Bernoulli observations when coefficients
        are sampled using `w_sqrtlambdas`.
    :param torch.Tensor alpha_0: Gamma :math:`\\alpha` parameter for unknown observation
        covariance.
    :param torch.Tensor beta_0: Gamma :math:`\\beta` parameter for unknown observation
        covariance.
    :param str response: Emission distribution. May be `"normal"` or `"bernoulli"`.
    :param str response_label: Variable label for response.
    :param torch.Tensor k: Only used for a sigmoid response. The slope of the sigmoid
        transformation.
    """
    # design is size batch x n x p
    # tau is size batch
    batch_shape = design.shape[:-2]
    with ExitStack() as stack:
        for plate in iter_plates_to_shape(batch_shape):
            stack.enter_context(plate)

        if obs_sd is None:
            # First, sample tau (observation precision)
            tau_prior = dist.Gamma(alpha_0.unsqueeze(-1),
                                   beta_0.unsqueeze(-1)).to_event(1)
            tau = pyro.sample("tau", tau_prior)
            obs_sd = 1. / torch.sqrt(tau)

        elif alpha_0 is not None or beta_0 is not None:
            warnings.warn("Values of `alpha_0` and `beta_0` unused becased"
                          "`obs_sd` was specified already.")

        obs_sd = obs_sd.expand(batch_shape + (1, ))

        # Build the regression coefficient
        w = []
        # Allow different names for different coefficient groups
        # Process fixed effects
        for name, w_sqrtlambda in w_sqrtlambdas.items():
            w_mean = w_means[name]
            # Place a normal prior on the regression coefficient
            w_prior = dist.Normal(w_mean, obs_sd / w_sqrtlambda).to_event(1)
            w.append(pyro.sample(name, w_prior))
        # Process random effects
        for name, group_size in re_group_sizes.items():
            # Sample `G` once for this group
            alpha, beta = re_alphas[name], re_betas[name]
            G_prior = dist.Gamma(alpha, beta).to_event(1)
            G = 1. / torch.sqrt(pyro.sample("G_" + name, G_prior))
            # Repeat `G` for each group
            repeat_shape = tuple(1 for _ in batch_shape) + (group_size, )
            u_prior = dist.Normal(torch.tensor(0.),
                                  G.repeat(repeat_shape)).to_event(1)
            w.append(pyro.sample(name, u_prior))
        # Regression coefficient `w` is batch x p
        w = broadcast_cat(w)

        # Run the regressor forward conditioned on inputs
        prediction_mean = rmv(design, w)
        if response == "normal":
            # y is an n-vector: hence use .to_event(1)
            return pyro.sample(
                response_label,
                dist.Normal(prediction_mean, obs_sd).to_event(1))
        elif response == "bernoulli":
            return pyro.sample(
                response_label,
                dist.Bernoulli(logits=prediction_mean).to_event(1))
        elif response == "sigmoid":
            base_dist = dist.Normal(prediction_mean, obs_sd).to_event(1)
            # You can add loc via the linear model itself
            k = k.expand(prediction_mean.shape)
            transforms = [
                AffineTransform(loc=torch.tensor(0.), scale=k),
                SigmoidTransform()
            ]
            response_dist = dist.TransformedDistribution(base_dist, transforms)
            return pyro.sample(response_label, response_dist)
        else:
            raise ValueError(
                "Unknown response distribution: '{}'".format(response))