def support(self):
    if not self.transforms:
        return self.base_dist.support
    support = self.transforms[-1].codomain
    if len(self.event_shape) > support.event_dim:
        support = constraints.independent(
            support, len(self.event_shape) - support.event_dim)
    return support
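A quick check of the wrapping logic above (a sketch, assuming stock `torch.distributions`):

import torch
from torch.distributions import Independent, Normal, TransformedDistribution
from torch.distributions.transforms import ExpTransform

base = Independent(Normal(torch.zeros(3), torch.ones(3)), 1)  # event_shape (3,)
d = TransformedDistribution(base, [ExpTransform()])
# ExpTransform's codomain has event_dim 0 while len(event_shape) is 1, so
# the positive constraint gets wrapped as constraints.independent(..., 1).
print(d.support.event_dim)  # 1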
Example 2
def codomain(self):
    if not self.parts:
        return constraints.real
    codomain = self.parts[-1].codomain
    # Adjust event_dim to be the maximum among all parts.
    event_dim = self.parts[0].domain.event_dim
    for part in self.parts:
        event_dim += part.codomain.event_dim - part.domain.event_dim
        event_dim = max(event_dim, part.codomain.event_dim)
    assert event_dim >= codomain.event_dim
    if event_dim > codomain.event_dim:
        codomain = constraints.independent(codomain,
                                           event_dim - codomain.event_dim)
    return codomain
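A small trace of the `event_dim` bookkeeping above (a sketch, assuming the standard `torch.distributions` transforms; `ReshapeTransform` is one transform whose domain and codomain event dims differ):

import torch
from torch.distributions.transforms import (ComposeTransform, ExpTransform,
                                            ReshapeTransform)

t = ComposeTransform([
    ReshapeTransform((6,), (2, 3)),  # domain event_dim 1 -> codomain event_dim 2
    ExpTransform(),                  # elementwise: event_dim 0 on both sides
])
# The loop starts at event_dim 1, Reshape raises it to 2, Exp keeps it, so
# the real-valued codomain gets wrapped as constraints.independent(real, 2).
print(t.codomain.event_dim)  # 2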
Example 3
def domain(self):
    if not self.parts:
        return constraints.real
    domain = self.parts[0].domain
    # Adjust event_dim to be the maximum among all parts.
    event_dim = self.parts[-1].codomain.event_dim
    for part in reversed(self.parts):
        event_dim += part.domain.event_dim - part.codomain.event_dim
        event_dim = max(event_dim, part.domain.event_dim)
    assert event_dim >= domain.event_dim
    if event_dim > domain.event_dim:
        domain = constraints.independent(domain,
                                         event_dim - domain.event_dim)
    return domain
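The mirror walk for `domain`, traversed in reverse over the same composition as in the Example 2 sketch (again a hedged sketch on stock `torch.distributions` transforms):

import torch
from torch.distributions.transforms import (ComposeTransform, ExpTransform,
                                            ReshapeTransform)

t = ComposeTransform([ReshapeTransform((6,), (2, 3)), ExpTransform()])
# Walking backwards: Exp contributes nothing, Reshape lowers event_dim from
# 2 to 1, and its own independent(real, 1) domain is already the maximum,
# so no extra wrapping is applied.
print(t.domain.event_dim)  # 1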
Example 4
class UnitLowerCholeskyTransform(Transform):
    """
    Transform from unconstrained matrices to lower-triangular matrices with
    all ones diagonals.
    """

    domain = constraints.independent(constraints.real, 2)
    codomain = unit_lower_cholesky

    def __eq__(self, other):
        return isinstance(other, UnitLowerCholeskyTransform)

    def _call(self, x):
        return x.tril(-1) + torch.eye(
            x.size(-1), device=x.device, dtype=x.dtype)

    def _inverse(self, y):
        # The forward map discards the diagonal and upper triangle, so any
        # unit lower-triangular y is its own preimage.
        return y
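A quick usage sketch (assuming `torch` and the `Transform` base class are imported as the class above requires):

t = UnitLowerCholeskyTransform()
y = t(torch.randn(4, 4))
print(y.diagonal(dim1=-2, dim2=-1))                  # all ones
print(torch.allclose(y.triu(1), torch.zeros(4, 4)))  # True: nothing above the diagonal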
Example 5
class SoftplusLowerCholeskyTransform(Transform):
    """
    Transform from unconstrained matrices to lower-triangular matrices with
    nonnegative diagonal entries. This is useful for parameterizing positive
    definite matrices in terms of their Cholesky factorization.
    """
    domain = constraints.independent(constraints.real, 2)
    codomain = constraints.lower_cholesky

    def __eq__(self, other):
        return isinstance(other, SoftplusLowerCholeskyTransform)

    def _call(self, x):
        diag = softplus(x.diagonal(dim1=-2, dim2=-1))
        return x.tril(-1) + diag.diag_embed()

    def _inverse(self, y):
        diag = softplus_inv(y.diagonal(dim1=-2, dim2=-1))
        return y.tril(-1) + diag.diag_embed()
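A round-trip sketch (assuming `softplus` from `torch.nn.functional` and a `softplus_inv` helper are in scope, as the class above already requires):

t = SoftplusLowerCholeskyTransform()
y = t(torch.randn(3, 3))
print((y.diagonal(dim1=-2, dim2=-1) > 0).all())  # True: softplus makes the diagonal positive
print(torch.allclose(t(t._inverse(y)), y))       # True: _call undoes _inverse on valid outputs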
Example 6
def support(self):
    # First, we remove all `independent` constraints. This applies to e.g.
    # `MultivariateNormal`. An `independent` constraint returns a 1D `[True]`
    # when `.support.check(sample)` is called, whereas distributions that are
    # not `independent` (e.g. `Gamma`) return a 2D `[[True]]`. If such
    # constraints were combined with `constraints.cat(..., dim=1)`, the check
    # would fail because the `independent` constraint returned only a 1D
    # `[True]`.
    supports = []
    for d in self.dists:
        if isinstance(d.support, constraints.independent):
            supports.append(d.support.base_constraint)
        else:
            supports.append(d.support)
    # Wrap as `independent` in order to have the correct shape of the
    # `log_abs_det`, i.e. summed over the parameter dimensions.
    return constraints.independent(
        constraints.cat(supports, dim=1, lengths=self.dims_per_dist),
        reinterpreted_batch_ndims=1,
    )
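The shape mismatch that the first comment describes, shown directly (a sketch using stock `torch.distributions`):

import torch
from torch.distributions import Gamma, MultivariateNormal

mvn = MultivariateNormal(torch.zeros(2), torch.eye(2))
gamma = Gamma(torch.ones(2), torch.ones(2))
sample = torch.ones(1, 2)
print(mvn.support.check(sample).shape)    # torch.Size([1]): independent consumes the event dim
print(gamma.support.check(sample).shape)  # torch.Size([1, 2]): elementwise check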
Example 7
def support(self):
    result = self.base_dist.support
    if self.reinterpreted_batch_ndims:
        result = constraints.independent(result,
                                         self.reinterpreted_batch_ndims)
    return result
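For instance (a sketch; `torch.distributions.Independent` exposes this behavior):

import torch
from torch.distributions import Independent, Normal

d = Independent(Normal(torch.zeros(3), torch.ones(3)), 1)
print(d.support.event_dim)                       # 1: the base support wrapped once
print(d.support.check(torch.zeros(2, 3)).shape)  # torch.Size([2])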
Example 8
class Dirichlet(ExponentialFamily):
    r"""
    Creates a Dirichlet distribution parameterized by concentration :attr:`concentration`.

    Example::

        >>> m = Dirichlet(torch.tensor([0.5, 0.5]))
        >>> m.sample()  # Dirichlet distributed with concentration [0.5, 0.5]
        tensor([ 0.1046,  0.8954])

    Args:
        concentration (Tensor): concentration parameter of the distribution
            (often referred to as alpha)
    """
    arg_constraints = {
        'concentration': constraints.independent(constraints.positive, 1)
    }
    support = constraints.simplex
    has_rsample = True

    def __init__(self, concentration, validate_args=None):
        if concentration.dim() < 1:
            raise ValueError(
                "`concentration` parameter must be at least one-dimensional.")
        self.concentration = concentration
        batch_shape, event_shape = concentration.shape[:-1], concentration.shape[-1:]
        super(Dirichlet, self).__init__(batch_shape,
                                        event_shape,
                                        validate_args=validate_args)

    def expand(self, batch_shape, _instance=None):
        new = self._get_checked_instance(Dirichlet, _instance)
        batch_shape = torch.Size(batch_shape)
        new.concentration = self.concentration.expand(batch_shape +
                                                      self.event_shape)
        super(Dirichlet, new).__init__(batch_shape,
                                       self.event_shape,
                                       validate_args=False)
        new._validate_args = self._validate_args
        return new

    def rsample(self, sample_shape=()):
        shape = self._extended_shape(sample_shape)
        concentration = self.concentration.expand(shape)
        return _Dirichlet.apply(concentration)

    def log_prob(self, value):
        if self._validate_args:
            self._validate_sample(value)
        return ((torch.log(value) * (self.concentration - 1.0)).sum(-1) +
                torch.lgamma(self.concentration.sum(-1)) -
                torch.lgamma(self.concentration).sum(-1))

    @property
    def mean(self):
        return self.concentration / self.concentration.sum(-1, True)

    @property
    def mode(self):
        concentrationm1 = (self.concentration - 1).clamp(min=0.)
        mode = concentrationm1 / concentrationm1.sum(-1, True)
        mask = (self.concentration < 1).all(axis=-1)
        mode[mask] = torch.nn.functional.one_hot(
            mode[mask].argmax(axis=-1), concentrationm1.shape[-1]).to(mode)
        return mode

    @property
    def variance(self):
        con0 = self.concentration.sum(-1, True)
        return (self.concentration * (con0 - self.concentration) /
                (con0.pow(2) * (con0 + 1)))

    def entropy(self):
        k = self.concentration.size(-1)
        a0 = self.concentration.sum(-1)
        return (torch.lgamma(self.concentration).sum(-1) - torch.lgamma(a0) -
                (k - a0) * torch.digamma(a0) -
                ((self.concentration - 1.0) *
                 torch.digamma(self.concentration)).sum(-1))

    @property
    def _natural_params(self):
        return (self.concentration, )

    def _log_normalizer(self, x):
        return x.lgamma().sum(-1) - torch.lgamma(x.sum(-1))
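A minimal usage sketch for the class above (assuming the module-level imports it relies on, including `torch` and the `_Dirichlet` autograd function used by `rsample`):

d = Dirichlet(torch.tensor([2.0, 3.0, 5.0]))
x = d.rsample()
print(x.sum())        # ~1.0: samples live on the simplex
print(d.mean)         # tensor([0.2000, 0.3000, 0.5000])
print(d.log_prob(x))  # scalar log-density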
Example 9
def check(self, value):
    return ordered_vector.check(value) & independent(positive, 1).check(value)
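A hedged sanity check of the predicate above, assuming the `ordered_vector`, `independent`, and `positive` constraint objects the snippet references (Pyro-style names) are in scope:

import torch

v_ok = torch.tensor([0.1, 0.5, 2.0])   # strictly increasing and positive
v_bad = torch.tensor([0.5, 0.1, 2.0])  # positive but not increasing
print(ordered_vector.check(v_ok) & independent(positive, 1).check(v_ok))    # True
print(ordered_vector.check(v_bad) & independent(positive, 1).check(v_bad))  # False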
Example 10
class LowRankMultivariateNormal(Distribution):
    r"""
        Creates a multivariate normal distribution with covariance matrix having a low-rank form
        parameterized by :attr:`cov_factor`, :attr:`cov_diag`, and :attr:`cov_factor_inner`::

            covariance_matrix = cov_diag + cov_factor @ cov_factor_inner @ cov_factor.T

    """
    arg_constraints = {
        "loc": constraints.real_vector,
        "cov_factor": constraints.independent(constraints.real, 2),
        "cov_diag": constraints.independent(constraints.positive, 1),
        "cov_factor_inner": constraints.positive_definite
    }
    support = constraints.real_vector
    has_rsample = False

    def __init__(self,
                 loc: Tensor,
                 cov_factor: Tensor,
                 cov_diag: Tensor,
                 cov_factor_inner: Optional[Tensor] = None,
                 validate_args: Optional[bool] = None):

        if loc.dim() < 1:
            raise ValueError("loc must be at least one-dimensional.")
        event_shape = loc.shape[-1:]

        if cov_factor.dim() < 2:
            raise ValueError("cov_factor must be at least two-dimensional, "
                             "with optional leading batch dimensions")
        if cov_factor.shape[-2:-1] != event_shape:
            raise ValueError(
                "cov_factor must be a batch of matrices with shape {} x m".
                format(event_shape[0]))
        if cov_diag.shape[-1:] != event_shape:
            raise ValueError(
                "cov_diag must be a batch of vectors with shape {}".format(
                    event_shape))
        if cov_factor_inner is None:
            raise NotImplementedError("TODO: identity matrix")
        else:
            pass  # TODO: validate

        loc_ = loc.unsqueeze(-1)
        cov_diag_ = cov_diag.unsqueeze(-1)
        try:
            loc_, self.cov_factor, cov_diag_, self.cov_factor_inner = torch.broadcast_tensors(
                loc_, cov_factor, cov_diag_, cov_factor_inner)
        except RuntimeError as e:
            raise ValueError(
                "Incompatible batch shapes: loc {}, cov_factor {}, cov_diag {}, cov_factor_inner {}"
                .format(loc.shape, cov_factor.shape, cov_diag.shape,
                        cov_factor_inner.shape)) from e
        self.loc = loc_[..., 0]
        self.cov_diag = cov_diag_[..., 0]
        batch_shape = self.loc.shape[:-1]

        super().__init__(batch_shape, event_shape, validate_args=validate_args)

    def expand(self, batch_shape, _instance=None):
        new = self._get_checked_instance(LowRankMultivariateNormal, _instance)
        batch_shape = torch.Size(batch_shape)
        loc_shape = batch_shape + self.event_shape
        new.loc = self.loc.expand(loc_shape)
        new.cov_diag = self.cov_diag.expand(loc_shape)
        new.cov_factor = self.cov_factor.expand(loc_shape +
                                                self.cov_factor.shape[-1:])
        new.cov_factor_inner = self.cov_factor_inner.expand(
            loc_shape + self.cov_factor_inner.shape[-1:])
        super(LowRankMultivariateNormal, new).__init__(batch_shape,
                                                       self.event_shape,
                                                       validate_args=False)
        new._validate_args = self._validate_args
        return new

    @property
    def mean(self):
        return self.loc

    @lazy_property
    def variance(self):
        raise NotImplementedError("TODO")

    @lazy_property
    def covariance_matrix(self):
        raise NotImplementedError("TODO")
        # covariance_matrix = (torch.matmul(self._unbroadcasted_cov_factor,
        #                                   self._unbroadcasted_cov_factor.transpose(-1, -2))
        #                      + torch.diag_embed(self._unbroadcasted_cov_diag))
        # return covariance_matrix.expand(self._batch_shape + self._event_shape +
        #                                 self._event_shape)

    @lazy_property
    def precision_matrix(self):
        raise NotImplementedError("TODO")

    def log_prob(self, value):
        if self._validate_args:
            self._validate_sample(value)
        diff = value - self.loc
        raise NotImplementedError("TODO")

    # implemented in original ---
    def rsample(self, sample_shape=torch.Size()):
        raise NotImplementedError

    def entropy(self):
        raise NotImplementedError

    # not implemented in original ---
    def cdf(self, value):
        raise NotImplementedError

    def icdf(self, value):
        raise NotImplementedError

    def enumerate_support(self, expand=True):
        raise NotImplementedError
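A construction sketch for the stub above. Note that `cov_factor_inner` goes through the same `torch.broadcast_tensors` call as the `event x rank` factor, so for unbatched inputs the rank has to match (or broadcast against) the event size; the sketch therefore uses rank == event size == 2, and passes `cov_factor_inner` explicitly since the identity default is not implemented:

d = LowRankMultivariateNormal(
    loc=torch.zeros(2),
    cov_factor=torch.randn(2, 2),
    cov_diag=torch.ones(2),
    cov_factor_inner=torch.eye(2),
)
print(d.batch_shape, d.event_shape)  # torch.Size([]) torch.Size([2])
print(d.mean)                        # tensor([0., 0.])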
Example 11
def codomain(self):
    if self.event_dim == 0:
        return constraints.real
    return constraints.independent(constraints.real, self.event_dim)
Example 12
def codomain(self):
    return constraints.independent(constraints.real, len(self.out_shape))
Example 13
def codomain(self):
    return constraints.independent(self.base_transform.codomain,
                                   self.reinterpreted_batch_ndims)
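For instance (a sketch; `torch.distributions.transforms.IndependentTransform` exposes this behavior):

import torch
from torch.distributions.transforms import ExpTransform, IndependentTransform

t = IndependentTransform(ExpTransform(), reinterpreted_batch_ndims=1)
print(t.codomain.event_dim)  # 1: the elementwise positive codomain, wrapped once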
Example 14

class LowRankMultivariateNormal(Distribution):
    r"""
    Creates a multivariate normal distribution with covariance matrix having a low-rank form
    parameterized by :attr:`cov_factor` and :attr:`cov_diag`::

        covariance_matrix = cov_factor @ cov_factor.T + cov_diag

    Example:

        >>> m = LowRankMultivariateNormal(torch.zeros(2), torch.tensor([[1.], [0.]]), torch.ones(2))
        >>> m.sample()  # normally distributed with mean=`[0,0]`, cov_factor=`[[1],[0]]`, cov_diag=`[1,1]`
        tensor([-0.2102, -0.5429])

    Args:
        loc (Tensor): mean of the distribution with shape `batch_shape + event_shape`
        cov_factor (Tensor): factor part of low-rank form of covariance matrix with shape
            `batch_shape + event_shape + (rank,)`
        cov_diag (Tensor): diagonal part of low-rank form of covariance matrix with shape
            `batch_shape + event_shape`

    Note:
        The computation for determinant and inverse of covariance matrix is avoided when
        `cov_factor.shape[1] << cov_factor.shape[0]` thanks to `Woodbury matrix identity
        <https://en.wikipedia.org/wiki/Woodbury_matrix_identity>`_ and
        `matrix determinant lemma <https://en.wikipedia.org/wiki/Matrix_determinant_lemma>`_.
        Thanks to these formulas, we just need to compute the determinant and inverse of
        the small size "capacitance" matrix::

            capacitance = I + cov_factor.T @ inv(cov_diag) @ cov_factor
    """
    arg_constraints = {
        "loc": constraints.real_vector,
        "cov_factor": constraints.independent(constraints.real, 2),
        "cov_diag": constraints.independent(constraints.positive, 1)
    }
    support = constraints.real_vector
    has_rsample = True

    def __init__(self, loc, cov_factor, cov_diag, validate_args=None):
        if loc.dim() < 1:
            raise ValueError("loc must be at least one-dimensional.")
        event_shape = loc.shape[-1:]
        if cov_factor.dim() < 2:
            raise ValueError("cov_factor must be at least two-dimensional, "
                             "with optional leading batch dimensions")
        if cov_factor.shape[-2:-1] != event_shape:
            raise ValueError(
                "cov_factor must be a batch of matrices with shape {} x m".
                format(event_shape[0]))
        if cov_diag.shape[-1:] != event_shape:
            raise ValueError(
                "cov_diag must be a batch of vectors with shape {}".format(
                    event_shape))

        loc_ = loc.unsqueeze(-1)
        cov_diag_ = cov_diag.unsqueeze(-1)
        try:
            loc_, self.cov_factor, cov_diag_ = torch.broadcast_tensors(
                loc_, cov_factor, cov_diag_)
        except RuntimeError as e:
            raise ValueError(
                "Incompatible batch shapes: loc {}, cov_factor {}, cov_diag {}"
                .format(loc.shape, cov_factor.shape, cov_diag.shape)) from e
        self.loc = loc_[..., 0]
        self.cov_diag = cov_diag_[..., 0]
        batch_shape = self.loc.shape[:-1]

        self._unbroadcasted_cov_factor = cov_factor
        self._unbroadcasted_cov_diag = cov_diag
        self._capacitance_tril = _batch_capacitance_tril(cov_factor, cov_diag)
        super(LowRankMultivariateNormal,
              self).__init__(batch_shape,
                             event_shape,
                             validate_args=validate_args)

    def expand(self, batch_shape, _instance=None):
        new = self._get_checked_instance(LowRankMultivariateNormal, _instance)
        batch_shape = torch.Size(batch_shape)
        loc_shape = batch_shape + self.event_shape
        new.loc = self.loc.expand(loc_shape)
        new.cov_diag = self.cov_diag.expand(loc_shape)
        new.cov_factor = self.cov_factor.expand(loc_shape +
                                                self.cov_factor.shape[-1:])
        new._unbroadcasted_cov_factor = self._unbroadcasted_cov_factor
        new._unbroadcasted_cov_diag = self._unbroadcasted_cov_diag
        new._capacitance_tril = self._capacitance_tril
        super(LowRankMultivariateNormal, new).__init__(batch_shape,
                                                       self.event_shape,
                                                       validate_args=False)
        new._validate_args = self._validate_args
        return new

    @property
    def mean(self):
        return self.loc

    @lazy_property
    def variance(self):
        return (self._unbroadcasted_cov_factor.pow(2).sum(-1) +
                self._unbroadcasted_cov_diag).expand(self._batch_shape +
                                                     self._event_shape)

    @lazy_property
    def scale_tril(self):
        # The following identity is used to improve the numerical stability of
        # the Cholesky decomposition (see http://www.gaussianprocess.org/gpml/,
        # Section 3.4.3):
        #     W @ W.T + D = D^(1/2) @ (I + D^(-1/2) @ W @ W.T @ D^(-1/2)) @ D^(1/2)
        # The matrix "I + D^(-1/2) @ W @ W.T @ D^(-1/2)" has eigenvalues bounded
        # from below by 1, so it is well-conditioned and safe to Cholesky-decompose.
        n = self._event_shape[0]
        cov_diag_sqrt_unsqueeze = self._unbroadcasted_cov_diag.sqrt().unsqueeze(-1)
        Dinvsqrt_W = self._unbroadcasted_cov_factor / cov_diag_sqrt_unsqueeze
        K = torch.matmul(Dinvsqrt_W, Dinvsqrt_W.transpose(-1, -2)).contiguous()
        K.view(-1, n * n)[:, ::n + 1] += 1  # add identity matrix to K
        scale_tril = cov_diag_sqrt_unsqueeze * torch.cholesky(K)
        return scale_tril.expand(self._batch_shape + self._event_shape +
                                 self._event_shape)

    @lazy_property
    def covariance_matrix(self):
        covariance_matrix = (
            torch.matmul(self._unbroadcasted_cov_factor,
                         self._unbroadcasted_cov_factor.transpose(-1, -2)) +
            torch.diag_embed(self._unbroadcasted_cov_diag))
        return covariance_matrix.expand(self._batch_shape + self._event_shape +
                                        self._event_shape)

    @lazy_property
    def precision_matrix(self):
        # We use "Woodbury matrix identity" to take advantage of low rank form::
        #     inv(W @ W.T + D) = inv(D) - inv(D) @ W @ inv(C) @ W.T @ inv(D)
        # where :math:`C` is the capacitance matrix.
        Wt_Dinv = (self._unbroadcasted_cov_factor.transpose(-1, -2) /
                   self._unbroadcasted_cov_diag.unsqueeze(-2))
        A = torch.triangular_solve(Wt_Dinv,
                                   self._capacitance_tril,
                                   upper=False)[0]
        precision_matrix = (
            torch.diag_embed(self._unbroadcasted_cov_diag.reciprocal()) -
            torch.matmul(A.transpose(-1, -2), A))
        return precision_matrix.expand(self._batch_shape + self._event_shape +
                                       self._event_shape)

    def rsample(self, sample_shape=torch.Size()):
        shape = self._extended_shape(sample_shape)
        W_shape = shape[:-1] + self.cov_factor.shape[-1:]
        eps_W = _standard_normal(W_shape,
                                 dtype=self.loc.dtype,
                                 device=self.loc.device)
        eps_D = _standard_normal(shape,
                                 dtype=self.loc.dtype,
                                 device=self.loc.device)
        return (self.loc + _batch_mv(self._unbroadcasted_cov_factor, eps_W) +
                self._unbroadcasted_cov_diag.sqrt() * eps_D)

    def log_prob(self, value):
        if self._validate_args:
            self._validate_sample(value)
        diff = value - self.loc
        M = _batch_lowrank_mahalanobis(self._unbroadcasted_cov_factor,
                                       self._unbroadcasted_cov_diag, diff,
                                       self._capacitance_tril)
        log_det = _batch_lowrank_logdet(self._unbroadcasted_cov_factor,
                                        self._unbroadcasted_cov_diag,
                                        self._capacitance_tril)
        return -0.5 * (self._event_shape[0] * math.log(2 * math.pi) + log_det +
                       M)

    def entropy(self):
        log_det = _batch_lowrank_logdet(self._unbroadcasted_cov_factor,
                                        self._unbroadcasted_cov_diag,
                                        self._capacitance_tril)
        H = 0.5 * (self._event_shape[0] *
                   (1.0 + math.log(2 * math.pi)) + log_det)
        if len(self._batch_shape) == 0:
            return H
        else:
            return H.expand(self._batch_shape)
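A usage sketch matching the docstring example above (assuming the `torch.distributions` helpers this class imports, e.g. `_standard_normal`, `_batch_mv`, and `_batch_capacitance_tril`):

m = LowRankMultivariateNormal(torch.zeros(2), torch.tensor([[1.0], [0.0]]), torch.ones(2))
x = m.rsample()
print(m.covariance_matrix)  # tensor([[2., 0.], [0., 1.]])
print(m.log_prob(x))        # scalar log-density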
Example 15
def support(self):
    return constraints.independent(constraints.real, len(self.event_shape))
Example 16
def support(self):
    return constraints.independent(self.base_dist.support,
                                   self.reinterpreted_batch_ndims)