def support(self):
    """Return the combined support of all distributions in ``self.dists``.

    The per-distribution supports are concatenated along ``dim=1`` using
    ``self.dims_per_dist`` as segment lengths, and the result is wrapped in
    an ``independent`` constraint with one reinterpreted batch dimension.
    """
    # Unwrap every `independent` constraint (as used by e.g.
    # `MultivariateNormal`) before concatenating. Checking a sample against
    # an `independent` constraint yields a 1D `[True]`, while a
    # non-`independent` one (e.g. for `Gamma`) yields a 2D `[[True]]`.
    # Mixing the two inside `constraints.cat(..., dim=1)` fails, so only
    # base constraints are concatenated.
    base_supports = [
        d.support.base_constraint
        if isinstance(d.support, constraints.independent)
        else d.support
        for d in self.dists
    ]

    concatenated = constraints.cat(
        base_supports, dim=1, lengths=self.dims_per_dist
    )

    # Re-wrap as `independent` so that the `log_abs_det` has the correct
    # shape, i.e. is summed over the parameter dimensions.
    return constraints.independent(concatenated, reinterpreted_batch_ndims=1)
def codomain(self):
    """Return the concatenated codomain of all transforms in ``self.transforms``.

    Each transform's ``codomain`` constraint is collected in order and the
    collection is combined via ``constraints.cat`` along ``self.dim``, with
    ``self.lengths`` giving the size of each segment.
    """
    part_codomains = [t.codomain for t in self.transforms]
    return constraints.cat(part_codomains, self.dim, self.lengths)