def forward(self, belief, state, deterministic=False, with_logprob=False):
    # dist ~ tanh(Normal(mean, std)); when sampling, use rsample() for the reparameterization trick
    raw_init_std = np.log(np.exp(self.init_std) - 1)
    hidden = self.act_fn(self.fc1(torch.cat([belief, state], dim=-1)))
    hidden = self.act_fn(self.fc2(hidden))
    hidden = self.act_fn(self.fc3(hidden))
    hidden = self.act_fn(self.fc4(hidden))
    hidden = self.fc5(hidden)
    mean, std = torch.chunk(hidden, 2, dim=-1)
    # Bound the mean to [-5, 5] to avoid numerical instabilities: computing log-probabilities
    # requires inverting the tanh, which becomes difficult in highly saturated regions.
    mean = self.mean_scale * torch.tanh(mean / self.mean_scale)
    std = F.softplus(std + raw_init_std) + self.min_std
    dist = torch.distributions.Normal(mean, std)
    # TanhTransform = ComposeTransform([AffineTransform(0., 2.), SigmoidTransform(), AffineTransform(-1., 2.)])
    if self.fix_speed:
        transform = [AffineTransform(0., 2.), SigmoidTransform(), AffineTransform(-1., 2.)]
    else:
        transform = [AffineTransform(0., 2.), SigmoidTransform(), AffineTransform(-1., 2.),  # TanhTransform
                     AffineTransform(loc=torch.tensor([0.0, self.throtlle_base]).to("cuda"),
                                     scale=torch.tensor([1.0, 0.2]).to("cuda"))]  # TODO: this is limited to the donkeycar env
    dist = TransformedDistribution(dist, transform)
    # dist = torch.distributions.independent.Independent(dist, 1)  # introduces dependence between action dimensions
    # After transforming a distribution, some methods (entropy, mean, mode) may become invalid,
    # so SampleDist approximates them by sampling.
    dist = SampleDist(dist)
    if deterministic:
        action = dist.mean
    else:
        action = dist.rsample()
    # log-probabilities are not used at the moment
    if with_logprob:
        logp_pi = dist.log_prob(action).sum(dim=1)
    else:
        logp_pi = None
    # action dim: [batch, act_dim], logp_pi dim: [batch]
    if self.fix_speed:
        action = torch.cat((action, self.throtlle_base * torch.ones_like(action, requires_grad=False)), dim=-1)
    return action, logp_pi

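# The policy above relies on a SampleDist wrapper that is not shown in this snippet.
# A minimal sketch of such a wrapper, assuming it approximates the moments of the
# transformed distribution with Monte Carlo samples (class name, sample count and
# method set are illustrative, not the original implementation):
import torch


class SampleDist:
    """Monte Carlo wrapper around a transformed distribution."""

    def __init__(self, dist, samples=100):
        self._dist = dist
        self._samples = samples

    @property
    def mean(self):
        # empirical mean over reparameterized draws
        return self._dist.rsample((self._samples,)).mean(0)

    def rsample(self):
        return self._dist.rsample()

    def log_prob(self, value):
        return self._dist.log_prob(value)

    def entropy(self):
        # elementwise Monte Carlo entropy estimate: -E[log p(x)]
        sample = self._dist.rsample((self._samples,))
        return -self._dist.log_prob(sample).mean(0)
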
def __init__(self, loc, scale, validate_args=None):
    self.loc, self.scale = broadcast_all(loc, scale)
    finfo = _finfo(self.loc)
    if isinstance(loc, Number) and isinstance(scale, Number):
        base_dist = Uniform(finfo.tiny, 1 - finfo.eps)
    else:
        base_dist = Uniform(self.loc.new(self.loc.size()).fill_(finfo.tiny), 1 - finfo.eps)
    transforms = [ExpTransform().inv,
                  AffineTransform(loc=0, scale=-torch.ones_like(self.scale)),
                  ExpTransform().inv,
                  AffineTransform(loc=loc, scale=-self.scale)]
    super(Gumbel, self).__init__(base_dist, transforms, validate_args=validate_args)

def __init__(self, w, p, temperature=0.1, validate_args=None):
    relaxed_bernoulli = RelaxedBernoulli(temperature, p)
    affine_transform = AffineTransform(0, w)
    one_minus_p = AffineTransform(1, -1)
    super(BernoulliDropoutDistribution, self).__init__(
        relaxed_bernoulli,
        ComposeTransform([one_minus_p, affine_transform]),
        validate_args)
    self.relaxed_bernoulli = relaxed_bernoulli
    self.affine_transform = affine_transform

def __init__(self, a, b, validate_args=None):
    self.a, self.b = broadcast_all(a, b)
    self.a_reciprocal = self.a.reciprocal()
    self.b_reciprocal = self.b.reciprocal()
    base_dist = Uniform(torch.full_like(self.a, EPS), torch.full_like(self.a, 1. - EPS))
    transforms = [
        AffineTransform(loc=1, scale=-1),
        PowerTransform(self.b_reciprocal),
        AffineTransform(loc=1, scale=-1),
        PowerTransform(self.a_reciprocal)
    ]
    super(Kumaraswamy, self).__init__(base_dist, transforms, validate_args=validate_args)

def __init__(self, scale, alpha, validate_args=None):
    self.scale, self.alpha = broadcast_all(scale, alpha)
    base_dist = Exponential(self.alpha, validate_args=validate_args)
    transforms = [ExpTransform(), AffineTransform(loc=0, scale=self.scale)]
    super(Pareto, self).__init__(base_dist, transforms, validate_args=validate_args)

def __init__(self, concentration1, concentration0, loc, scale, validate_args=None):
    base_dist = Beta(concentration1, concentration0, validate_args=validate_args)
    super(AffineBeta, self).__init__(
        base_dist,
        AffineTransform(loc=loc, scale=scale),
        validate_args=validate_args,
    )

def get_transforms(cache_size):
    transforms = [
        AbsTransform(cache_size=cache_size),
        ExpTransform(cache_size=cache_size),
        PowerTransform(exponent=2, cache_size=cache_size),
        PowerTransform(exponent=torch.tensor(5.).normal_(), cache_size=cache_size),
        SigmoidTransform(cache_size=cache_size),
        TanhTransform(cache_size=cache_size),
        AffineTransform(0, 1, cache_size=cache_size),
        AffineTransform(1, -2, cache_size=cache_size),
        AffineTransform(torch.randn(5), torch.randn(5), cache_size=cache_size),
        AffineTransform(torch.randn(4, 5), torch.randn(4, 5), cache_size=cache_size),
        SoftmaxTransform(cache_size=cache_size),
        StickBreakingTransform(cache_size=cache_size),
        LowerCholeskyTransform(cache_size=cache_size),
        CorrCholeskyTransform(cache_size=cache_size),
        ComposeTransform([
            AffineTransform(torch.randn(4, 5), torch.randn(4, 5), cache_size=cache_size),
        ]),
        ComposeTransform([
            AffineTransform(torch.randn(4, 5), torch.randn(4, 5), cache_size=cache_size),
            ExpTransform(cache_size=cache_size),
        ]),
        ComposeTransform([
            AffineTransform(0, 1, cache_size=cache_size),
            AffineTransform(torch.randn(4, 5), torch.randn(4, 5), cache_size=cache_size),
            AffineTransform(1, -2, cache_size=cache_size),
            AffineTransform(torch.randn(4, 5), torch.randn(4, 5), cache_size=cache_size),
        ]),
    ]
    transforms += [t.inv for t in transforms]
    return transforms

def __init__(self, data_dim=28 * 28, device='cpu'):
    self.m = TransformedDistribution(
        Uniform(torch.zeros(data_dim, device=device), torch.ones(data_dim, device=device)),
        [
            SigmoidTransform().inv,
            AffineTransform(torch.zeros(data_dim, device=device), torch.ones(data_dim, device=device))
        ])

def expand(self, batch_shape, _instance=None):
    new = self._get_checked_instance(Kuma, _instance)
    new.a = self.a.expand(batch_shape)
    new.b = self.b.expand(batch_shape)
    new.a_reciprocal = new.a.reciprocal()
    new.b_reciprocal = new.b.reciprocal()
    base_dist = self.base_dist.expand(batch_shape)
    # use the expanded reciprocals for the transforms
    transforms = [
        AffineTransform(loc=1, scale=-1),
        PowerTransform(new.b_reciprocal),
        AffineTransform(loc=1, scale=-1),
        PowerTransform(new.a_reciprocal)
    ]
    super(Kumaraswamy, new).__init__(base_dist, transforms, validate_args=False)
    new._validate_args = self._validate_args
    return new

def __init__(self, scale, concentration, validate_args=None):
    self.scale, self.concentration = broadcast_all(scale, concentration)
    self.concentration_reciprocal = self.concentration.reciprocal()
    base_dist = Exponential(torch.ones_like(self.scale), validate_args=validate_args)
    transforms = [PowerTransform(exponent=self.concentration_reciprocal),
                  AffineTransform(loc=0, scale=self.scale)]
    super(Weibull, self).__init__(base_dist, transforms, validate_args=validate_args)

def __init__(self, w, p, l, temperature=0.1, validate_args=None):
    relaxed_bernoulli = RelaxedBernoulli(temperature, p)
    affine_transform = AffineTransform(w, l - w)
    super(ToeplitzBernoulliDistribution, self).__init__(relaxed_bernoulli, affine_transform, validate_args)
    self.relaxed_bernoulli = relaxed_bernoulli
    self.affine_transform = affine_transform

def __init__(self, concentration1, concentration0, validate_args=None):
    self.concentration1, self.concentration0 = broadcast_all(concentration1, concentration0)
    finfo = torch.finfo(self.concentration0.dtype)
    base_dist = Uniform(torch.full_like(self.concentration0, 0),
                        torch.full_like(self.concentration0, 1),
                        validate_args=validate_args)
    transforms = [PowerTransform(exponent=self.concentration0.reciprocal()),
                  AffineTransform(loc=1., scale=-1.),
                  PowerTransform(exponent=self.concentration1.reciprocal())]
    super(Kumaraswamy, self).__init__(base_dist, transforms, validate_args=validate_args)

def reshape_transform(transform, shape):
    # Needed to squash batch dims for testing jacobian
    if isinstance(transform, AffineTransform):
        if isinstance(transform.loc, Number):
            return transform
        try:
            return AffineTransform(transform.loc.expand(shape), transform.scale.expand(shape),
                                   cache_size=transform._cache_size)
        except RuntimeError:
            return AffineTransform(transform.loc.reshape(shape), transform.scale.reshape(shape),
                                   cache_size=transform._cache_size)
    if isinstance(transform, ComposeTransform):
        reshaped_parts = []
        for p in transform.parts:
            reshaped_parts.append(reshape_transform(p, shape))
        return ComposeTransform(reshaped_parts, cache_size=transform._cache_size)
    if isinstance(transform.inv, AffineTransform):
        return reshape_transform(transform.inv, shape).inv
    if isinstance(transform.inv, ComposeTransform):
        return reshape_transform(transform.inv, shape).inv
    return transform

def __init__(
    self,
    concentration1: Union[float, Tensor],
    concentration0: Union[float, Tensor],
    validate_args: bool = False,
):
    self.concentration1, self.concentration0 = broadcast_all(concentration1, concentration0)
    base_dist = Uniform(
        torch.full_like(self.concentration0, 0.0),
        torch.full_like(self.concentration0, 1.0),
    )
    transforms = [
        AffineTransform(loc=1.0, scale=-1.0),
        PowerTransform(exponent=self.concentration0.reciprocal()),
        AffineTransform(loc=1.0, scale=-1.0),
        PowerTransform(exponent=self.concentration1.reciprocal()),
    ]
    super().__init__(base_dist, transforms, validate_args=validate_args)

def test_save_load_transform():
    # Evaluating `log_prob` will create a weakref `_inv` which cannot be pickled. Here, we check
    # that `__getstate__` correctly handles the weakref, and that we can evaluate the density after.
    dist = TransformedDistribution(Normal(0, 1), [AffineTransform(2, 3)])
    x = torch.linspace(0, 1, 10)
    log_prob = dist.log_prob(x)
    stream = io.BytesIO()
    torch.save(dist, stream)
    stream.seek(0)
    other = torch.load(stream)
    assert torch.allclose(log_prob, other.log_prob(x))

def true_model(design):
    w1 = torch.tensor([-1., 1.])
    w2 = torch.tensor([-.5, .5, -.5, .5, -.5, 2., -2., 2., -2., 0.])
    w = torch.cat([w1, w2], dim=-1)
    k = torch.tensor(.1)
    response_mean = rmv(design, w)
    base_dist = dist.Normal(response_mean, torch.tensor(1.)).to_event(1)
    k = k.expand(response_mean.shape)
    transforms = [AffineTransform(loc=0., scale=k), SigmoidTransform()]
    response_dist = dist.TransformedDistribution(base_dist, transforms)
    return pyro.sample("y", response_dist)

def forward(self, state, mean_action=False):
    mu, log_std = self.network(state).chunk(2, dim=-1)
    log_std = torch.clamp(log_std, LOG_MIN, LOG_MAX)  # keep the policy neither too random nor too deterministic
    normal = TransformedDistribution(
        Independent(Normal(mu, log_std.exp()), 1),
        [TanhTransform(), AffineTransform(loc=self.loc, scale=self.scale)])
    if mean_action:
        # deterministic action: apply the same tanh + affine squashing to the mean
        return self.scale * torch.tanh(mu) + self.loc
    return normal

def expand(self, batch_shape, _instance=None):
    new = self._get_checked_instance(Weibull, _instance)
    new.scale = self.scale.expand(batch_shape)
    new.concentration = self.concentration.expand(batch_shape)
    new.concentration_reciprocal = new.concentration.reciprocal()
    base_dist = self.base_dist.expand(batch_shape)
    transforms = [PowerTransform(exponent=new.concentration_reciprocal),
                  AffineTransform(loc=0, scale=new.scale)]
    super(Weibull, new).__init__(base_dist, transforms, validate_args=False)
    new._validate_args = self._validate_args
    return new

def forward(self, state):
    policy_mean, policy_log_std = self.policy(state).chunk(2, dim=1)
    policy_log_std = torch.clamp(policy_log_std, min=self.log_std_min, max=self.log_std_max)
    policy = TransformedDistribution(
        Independent(Normal(policy_mean, policy_log_std.exp()), 1),
        [
            TanhTransform(),
            AffineTransform(loc=self.action_loc, scale=self.action_scale)
        ])
    policy.mean_ = self.action_scale * torch.tanh(
        policy.base_dist.mean) + self.action_loc  # TODO: See if mean attr can be overwritten
    return policy

def test_compose_affine(event_dims):
    transforms = [AffineTransform(torch.zeros((1,) * e), 1, event_dim=e) for e in event_dims]
    transform = ComposeTransform(transforms)
    assert transform.codomain.event_dim == max(event_dims)
    assert transform.domain.event_dim == max(event_dims)

    base_dist = Normal(0, 1)
    if transform.domain.event_dim:
        base_dist = base_dist.expand((1,) * transform.domain.event_dim)
    dist = TransformedDistribution(base_dist, transform.parts)
    assert dist.support.event_dim == max(event_dims)

    base_dist = Dirichlet(torch.ones(5))
    if transform.domain.event_dim > 1:
        base_dist = base_dist.expand((1,) * (transform.domain.event_dim - 1))
    dist = TransformedDistribution(base_dist, transforms)
    assert dist.support.event_dim == max(1, max(event_dims))

def sigmoid_example(design):
    n = design.shape[-2]
    random_effect_k = pyro.sample("k", dist.Gamma(2. * torch.ones(n), torch.tensor(2.)))
    random_effect_offset = pyro.sample("w2", dist.Normal(torch.tensor(0.), torch.ones(n)))
    w1 = pyro.sample("w1", dist.Normal(torch.tensor([1., -1.]),
                                       torch.tensor([10., 10.])).to_event(1))
    mean = torch.matmul(design[..., :-2], w1.unsqueeze(-1)).squeeze(-1)
    offset_mean = mean + random_effect_offset
    base_dist = dist.Normal(offset_mean, torch.tensor(1.)).to_event(1)
    transforms = [
        AffineTransform(loc=torch.tensor(0.), scale=random_effect_k),
        SigmoidTransform()
    ]
    response_dist = dist.TransformedDistribution(base_dist, transforms)
    y = pyro.sample("y", response_dist)
    return y

def test_logistic():
    base_distribution = Uniform(0, 1)
    transforms = [SigmoidTransform().inv,
                  AffineTransform(loc=torch.tensor([2.]), scale=torch.tensor([1.]))]
    model = TransformedDistribution(base_distribution, transforms)
    transform = Logistic(2., 1.)

    x = model.sample((4,)).reshape(-1, 1)
    assert torch.all(transform.log_prob(x) - model.log_prob(x).view(-1) < 1e-4)

    x = transform.sample(4)
    assert x.shape == (4, 1)
    assert torch.all(transform.log_prob(x) - model.log_prob(x).view(-1) < 1e-4)

    x = transform.sample(1)
    assert x.shape == (1, 1)
    assert torch.all(transform.log_prob(x) - model.log_prob(x).view(-1) < 1e-4)

    transform.get_parameters()

def SampleAction(self, mean, std):
    # mean and std are predicted by the neural network; this function draws an action
    # from the resulting policy distribution and returns its log-probability.
    mu = mean
    sig = std * 0.3  # constrain the standard deviation to at most 0.3 * mu
    u_range = self.args['U_UB'] - self.args['U_LB']
    GPol = norm(mu, sig)  # Gaussian distribution parameterised by the predicted mean and std
    scale = AffineTransform(self.args['U_LB'], u_range)
    GPol = TransformedDistribution(GPol, scale)
    action = GPol.sample()  # draw randomly from the transformed distribution
    assert len(action) == 1
    logGP = GPol.log_prob(action)  # log-probability of the action taken
    return action.cpu(), logGP

def __init__(self, obs_dim, act_dim, act_low, act_high,
             log_std_min=-20, log_std_max=20, hidden_size=256):
    super(GaussianActorNetwork, self).__init__(obs_dim, hidden_size=hidden_size)
    self._mean_layer = nn.Linear(self._hidden_size, act_dim)
    self._std_layer = nn.Linear(self._hidden_size, act_dim)
    self._act_dim = act_dim
    self._log_std_min = log_std_min
    self._log_std_max = log_std_max
    act_scale = torch.FloatTensor(act_high - act_low).to(device)
    act_low = torch.FloatTensor(act_low).to(device)
    self._transforms = [
        SigmoidTransform(),
        AffineTransform(loc=act_low, scale=act_scale)
    ]

def forge_distribution(mean, sigma, lower_limit=0.0, upper_limit=5.0):
    """
    Find the required concentration hyperparameters in the canonical Beta distribution
    that will return the desired mean and deviation after the affine transformation.
    """
    width = upper_limit - lower_limit
    assert width > 0
    assert sigma < EPS + width / 2, f"invalid std: {sigma.item()}"

    canonical_mean = (mean - lower_limit) / width
    canonical_sigma = sigma / width**2
    alpha_plus_beta = (canonical_mean * (1 - canonical_mean) / canonical_sigma**2) - 1
    alpha = canonical_mean * alpha_plus_beta
    beta = (1 - canonical_mean) * alpha_plus_beta

    canonical = Beta(alpha, beta)
    transformation = AffineTransform(loc=lower_limit, scale=width)
    transformed = TransformedDistribution(canonical, transformation)
    return transformed

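# Illustrative use of forge_distribution above. The requested mean and deviation are
# arbitrary values chosen only to show the call and the support of the result; it
# assumes the module-level EPS used in the assert is defined.
import torch

target_mean = torch.tensor(2.5)
target_sigma = torch.tensor(0.5)
d = forge_distribution(target_mean, target_sigma)  # defaults: lower_limit=0.0, upper_limit=5.0
samples = d.sample((1000,))                        # samples lie in [0, 5]
log_p = d.log_prob(samples)                        # density under the affine-transformed Beta
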
def __init__(self, prior, coupling, in_out_dim, mid_dim, hidden, bottleneck,
             compress, device, n_layers):
    """Initialize a NICE model.

    Args:
        prior: prior over the latent space ('gaussian' or 'logistic').
        coupling: type of coupling layer ('additive' or affine).
        in_out_dim: input/output dimensions.
        mid_dim: number of units in a hidden layer.
        hidden: number of hidden layers.
        n_layers: number of coupling layers.
        device: run on cpu or gpu.
    """
    super(NICE, self).__init__()
    self.device = device
    if prior == 'gaussian':
        self.prior = torch.distributions.Normal(
            torch.tensor(0.).to(device), torch.tensor(1.).to(device))
    elif prior == 'logistic':
        # standard logistic prior built as an inverse-sigmoid transform of Uniform(0, 1)
        self.prior = TransformedDistribution(
            Uniform(torch.tensor(0.).to(device), torch.tensor(1.).to(device)),
            [SigmoidTransform().inv, AffineTransform(loc=0., scale=1.)])
    else:
        raise ValueError('Prior not implemented.')
    self.in_out_dim = in_out_dim
    self.coupling = coupling
    self.n_layers = n_layers
    layer = AdditiveCoupling if coupling == 'additive' else AffineCoupling
    self.coupling_layers = nn.ModuleList([
        layer(in_out_dim, mid_dim, hidden, i % 2) for i in range(self.n_layers)
    ]).to(device)
    self.scale = Scaling(in_out_dim).to(device)
    self.bottleneck_factor = compress
    self.bottleneck_loss = nn.MSELoss()
    self.bottleneck = bottleneck

def test_transformed_distribution(base_batch_dim, base_event_dim, transform_dim,
                                  num_transforms, sample_shape):
    shape = torch.Size([2, 3, 4, 5])
    base_dist = Normal(0, 1)
    base_dist = base_dist.expand(shape[4 - base_batch_dim - base_event_dim:])
    if base_event_dim:
        base_dist = Independent(base_dist, base_event_dim)
    transforms = [
        AffineTransform(torch.zeros(shape[4 - transform_dim:]), 1),
        ReshapeTransform((4, 5), (20,)),
        ReshapeTransform((3, 20), (6, 10))
    ]
    transforms = transforms[:num_transforms]
    transform = ComposeTransform(transforms)

    # Check validation in .__init__().
    if base_batch_dim + base_event_dim < transform.domain.event_dim:
        with pytest.raises(ValueError):
            TransformedDistribution(base_dist, transforms)
        return
    d = TransformedDistribution(base_dist, transforms)

    # Check sampling is sufficiently expanded.
    x = d.sample(sample_shape)
    assert x.shape == sample_shape + d.batch_shape + d.event_shape
    num_unique = len(set(x.reshape(-1).tolist()))
    assert num_unique >= 0.9 * x.numel()

    # Check log_prob shape on full samples.
    log_prob = d.log_prob(x)
    assert log_prob.shape == sample_shape + d.batch_shape

    # Check log_prob shape on partial samples.
    y = x
    while y.dim() > len(d.event_shape):
        y = y[0]
    log_prob = d.log_prob(y)
    assert log_prob.shape == d.batch_shape

def __init__(self, a, theta, alpha, beta):
    """
    The Amoroso distribution is a very flexible 4-parameter distribution which
    contains many important exponential families as special cases.

    *PDF*
    ```
    Amoroso(x | a, θ, α, β) = 1/gamma(α) * abs(β/θ) * ((x - a)/θ)**(α*β - 1)
                              * exp(-((x - a)/θ)**β)

    for:
        x, a, θ, α, β ∈ reals, α > 0
    support:
        x >= a if θ > 0
        x <= a if θ < 0
    ```
    """
    self.a, self.theta, self.alpha, self.beta = broadcast_all(a, theta, alpha, beta)
    base_dist = Gamma(self.alpha, 1.)
    transform = ComposeTransform([
        AffineTransform(-self.a / self.theta, 1 / self.theta),
        PowerTransform(self.beta),
    ]).inv
    super().__init__(base_dist, transform)

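# Sanity-check sketch for the construction above (illustrative; it assumes the class
# is named Amoroso and subclasses TransformedDistribution). With a=0, θ=1, β=1 the
# transform reduces to the identity, so the density should match a plain Gamma(α, 1).
import torch
from torch.distributions import Gamma

alpha = torch.tensor(3.0)
amoroso = Amoroso(torch.tensor(0.), torch.tensor(1.), alpha, torch.tensor(1.))
x = torch.linspace(0.5, 5.0, 10)
assert torch.allclose(amoroso.log_prob(x), Gamma(alpha, torch.tensor(1.)).log_prob(x), atol=1e-5)
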
def create_distribution(self, scale, shape, shift):
    wd = Weibull(scale=scale, concentration=shape)
    transforms = AffineTransform(loc=shift, scale=1.)
    weibull = TransformedDistribution(wd, transforms)
    return weibull

def bayesian_linear_model(design, w_means={}, w_sqrtlambdas={}, re_group_sizes={},
                          re_alphas={}, re_betas={}, obs_sd=None,
                          alpha_0=None, beta_0=None, response="normal",
                          response_label="y", k=None):
    """
    A pyro model for Bayesian linear regression.

    If :param:`response` is `"normal"` this corresponds to a linear regression model
    :math:`Y = Xw + \\epsilon` with :math:`\\epsilon` i.i.d. zero-mean Gaussian. The
    observation standard deviation (:param:`obs_sd`) may be known or unknown. If
    unknown, it is assumed to follow an inverse Gamma distribution with parameters
    :param:`alpha_0` and :param:`beta_0`.

    If the response type is `"bernoulli"` we instead have
    :math:`Y \\sim Bernoulli(p)` with :math:`logit(p) = Xw`.

    Given parameter groups in :param:`w_means` and :param:`w_sqrtlambda`, the fixed
    effects regression coefficient is taken to be Gaussian with mean `w_mean` and
    standard deviation given by :math:`\\sigma / \\sqrt{\\lambda}`, corresponding to
    the normal inverse Gamma family.

    The random effects coefficient is constructed as follows. For each random effect
    group, standard deviations for that group are sampled from a normal inverse Gamma
    distribution. For each group, a random effect coefficient is then sampled from a
    zero-mean Gaussian with those standard deviations.

    :param torch.Tensor design: a tensor with last two dimensions `n` and `p`
        corresponding to observations and features respectively.
    :param OrderedDict w_means: map from variable names to tensors of fixed effect means.
    :param OrderedDict w_sqrtlambdas: map from variable names to tensors of square root
        :math:`\\lambda` values for fixed effects.
    :param OrderedDict re_group_sizes: map from variable names to int representing the
        group size.
    :param OrderedDict re_alphas: map from variable names to `torch.Tensor`, the tensor
        consists of Gamma dist :math:`\\alpha` values.
    :param OrderedDict re_betas: map from variable names to `torch.Tensor`, the tensor
        consists of Gamma dist :math:`\\beta` values.
    :param torch.Tensor obs_sd: the observation standard deviation (if assumed known).
        This is still relevant in the case of Bernoulli observations when coefficients
        are sampled using `w_sqrtlambdas`.
    :param torch.Tensor alpha_0: Gamma :math:`\\alpha` parameter for unknown observation
        covariance.
    :param torch.Tensor beta_0: Gamma :math:`\\beta` parameter for unknown observation
        covariance.
    :param str response: Emission distribution. May be `"normal"`, `"bernoulli"` or
        `"sigmoid"`.
    :param str response_label: Variable label for response.
    :param torch.Tensor k: Only used for a sigmoid response. The slope of the sigmoid
        transformation.
    """
    # design is size batch x n x p
    # tau is size batch
    batch_shape = design.shape[:-2]
    with ExitStack() as stack:
        for plate in iter_plates_to_shape(batch_shape):
            stack.enter_context(plate)

        if obs_sd is None:
            # First, sample tau (observation precision)
            tau_prior = dist.Gamma(alpha_0.unsqueeze(-1), beta_0.unsqueeze(-1)).to_event(1)
            tau = pyro.sample("tau", tau_prior)
            obs_sd = 1. / torch.sqrt(tau)
        elif alpha_0 is not None or beta_0 is not None:
            warnings.warn("Values of `alpha_0` and `beta_0` unused because "
                          "`obs_sd` was specified already.")
        obs_sd = obs_sd.expand(batch_shape + (1,))

        # Build the regression coefficient
        w = []
        # Allow different names for different coefficient groups
        # Process fixed effects
        for name, w_sqrtlambda in w_sqrtlambdas.items():
            w_mean = w_means[name]
            # Place a normal prior on the regression coefficient
            w_prior = dist.Normal(w_mean, obs_sd / w_sqrtlambda).to_event(1)
            w.append(pyro.sample(name, w_prior))
        # Process random effects
        for name, group_size in re_group_sizes.items():
            # Sample `G` once for this group
            alpha, beta = re_alphas[name], re_betas[name]
            G_prior = dist.Gamma(alpha, beta).to_event(1)
            G = 1. / torch.sqrt(pyro.sample("G_" + name, G_prior))
            # Repeat `G` for each group
            repeat_shape = tuple(1 for _ in batch_shape) + (group_size,)
            u_prior = dist.Normal(torch.tensor(0.), G.repeat(repeat_shape)).to_event(1)
            w.append(pyro.sample(name, u_prior))
        # Regression coefficient `w` is batch x p
        w = broadcast_cat(w)

        # Run the regressor forward conditioned on inputs
        prediction_mean = rmv(design, w)
        if response == "normal":
            # y is an n-vector: hence use .to_event(1)
            return pyro.sample(response_label,
                               dist.Normal(prediction_mean, obs_sd).to_event(1))
        elif response == "bernoulli":
            return pyro.sample(response_label,
                               dist.Bernoulli(logits=prediction_mean).to_event(1))
        elif response == "sigmoid":
            base_dist = dist.Normal(prediction_mean, obs_sd).to_event(1)
            # You can add loc via the linear model itself
            k = k.expand(prediction_mean.shape)
            transforms = [
                AffineTransform(loc=torch.tensor(0.), scale=k),
                SigmoidTransform()
            ]
            response_dist = dist.TransformedDistribution(base_dist, transforms)
            return pyro.sample(response_label, response_dist)
        else:
            raise ValueError("Unknown response distribution: '{}'".format(response))