Example #1
 def __init__(self,
              loc=0.,
              covariance_matrix=None,
              precision_matrix=None,
              scale_tril=None,
              validate_args=None):
     if np.isscalar(loc):
         loc = np.expand_dims(loc, axis=-1)
      # temporarily append a new axis to loc
     loc = loc[..., np.newaxis]
     if covariance_matrix is not None:
         loc, self.covariance_matrix = promote_shapes(
             loc, covariance_matrix)
         self.scale_tril = np.linalg.cholesky(self.covariance_matrix)
     elif precision_matrix is not None:
         loc, self.precision_matrix = promote_shapes(loc, precision_matrix)
         self.scale_tril = cholesky_of_inverse(self.precision_matrix)
     elif scale_tril is not None:
         loc, self.scale_tril = promote_shapes(loc, scale_tril)
     else:
         raise ValueError(
             'One of `covariance_matrix`, `precision_matrix`, `scale_tril`'
             ' must be specified.')
     batch_shape = lax.broadcast_shapes(
         np.shape(loc)[:-2],
         np.shape(self.scale_tril)[:-2])
     event_shape = np.shape(self.scale_tril)[-1:]
     self.loc = np.broadcast_to(np.squeeze(loc, axis=-1),
                                batch_shape + event_shape)
     super(MultivariateNormal, self).__init__(batch_shape=batch_shape,
                                              event_shape=event_shape,
                                              validate_args=validate_args)
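
Every constructor in these examples leans on the same two helpers: promote_shapes, which pads its arguments with leading singleton axes until they share a common rank, and lax.broadcast_shapes, which computes the batch shape those arguments broadcast to. A minimal stand-alone sketch of that pattern, assuming promote_shapes comes from numpyro.distributions.util:

    import jax.numpy as jnp
    from jax import lax
    from numpyro.distributions.util import promote_shapes

    loc = jnp.zeros(3)    # shape (3,)
    scale = 2.0           # Python scalar

    # promote_shapes pads each argument with leading singleton axes so that
    # all arguments share a common rank; the data itself is not copied.
    loc, scale = promote_shapes(loc, scale)
    print(jnp.shape(loc), jnp.shape(scale))   # (3,) (1,)

    # lax.broadcast_shapes computes the common batch shape, which the
    # constructors above pass to Distribution.__init__.
    batch_shape = lax.broadcast_shapes(jnp.shape(loc), jnp.shape(scale))
    print(batch_shape)                        # (3,)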
Example #2
 def __init__(self, base_gamma, high, validate_args=None):
     assert isinstance(base_gamma, Gamma)
     batch_shape = lax.broadcast_shapes(base_gamma.batch_shape,
                                        jnp.shape(high))
     self.base_gamma = tree_map(
         lambda p: promote_shapes(p, shape=batch_shape)[0], base_gamma)
     (self.high, ) = promote_shapes(high, shape=batch_shape)
     self._support = constraints.interval(0.0, high)
     super().__init__(batch_shape, validate_args=validate_args)
Example #3
 def __init__(self, base_gamma, low, validate_args=None):
     assert isinstance(base_gamma, Gamma)
     batch_shape = lax.broadcast_shapes(base_gamma.batch_shape,
                                        jnp.shape(low))
     self.base_gamma = tree_map(
         lambda p: promote_shapes(p, shape=batch_shape)[0], base_gamma)
     (self.low, ) = promote_shapes(low, shape=batch_shape)
     self._support = constraints.greater_than(low)
     super().__init__(batch_shape, validate_args=validate_args)
Example #4
 def __init__(self, logits, total_count=1, validate_args=None):
     if jnp.ndim(logits) < 1:
         raise ValueError("`logits` parameter must be at least one-dimensional.")
     batch_shape, event_shape = self.infer_shapes(jnp.shape(logits), jnp.shape(total_count))
     self.logits = promote_shapes(logits, shape=batch_shape + jnp.shape(logits)[-1:])[0]
     self.total_count = promote_shapes(total_count, shape=batch_shape)[0]
     super(MultinomialLogits, self).__init__(batch_shape=batch_shape,
                                             event_shape=event_shape,
                                             validate_args=validate_args)
Example #5
 def __init__(self, probs, total_count=1, validate_args=None):
     if jnp.ndim(probs) < 1:
         raise ValueError("`probs` parameter must be at least one-dimensional.")
     batch_shape = lax.broadcast_shapes(jnp.shape(probs)[:-1], jnp.shape(total_count))
     self.probs = promote_shapes(probs, shape=batch_shape + jnp.shape(probs)[-1:])[0]
     self.total_count = promote_shapes(total_count, shape=batch_shape)[0]
     super(MultinomialProbs, self).__init__(batch_shape=batch_shape,
                                            event_shape=jnp.shape(self.probs)[-1:],
                                            validate_args=validate_args)
Example #6
 def __init__(self, predictor, cutpoints, validate_args=None):
     if jnp.ndim(predictor) == 0:
         (predictor, ) = promote_shapes(predictor, shape=(1, ))
     else:
         predictor = predictor[..., None]
     predictor, self.cutpoints = promote_shapes(predictor, cutpoints)
     self.predictor = predictor[..., 0]
     probs = transforms.SimplexToOrderedTransform(self.predictor).inv(
         self.cutpoints)
     super(OrderedLogistic, self).__init__(probs,
                                           validate_args=validate_args)
Example #7
    def __init__(self, concentration, total_count=1, validate_args=None):
        if jnp.ndim(concentration) < 1:
            raise ValueError("`concentration` parameter must be at least one-dimensional.")

        batch_shape = lax.broadcast_shapes(jnp.shape(concentration)[:-1], jnp.shape(total_count))
        concentration_shape = batch_shape + jnp.shape(concentration)[-1:]
        self.concentration, = promote_shapes(concentration, shape=concentration_shape)
        self.total_count, = promote_shapes(total_count, shape=batch_shape)
        concentration = jnp.broadcast_to(self.concentration, concentration_shape)
        self._dirichlet = Dirichlet(concentration)
        super().__init__(
            self._dirichlet.batch_shape, self._dirichlet.event_shape, validate_args=validate_args)
Example #8
 def __init__(self, base_dist, low=0.0, validate_args=None):
     assert isinstance(base_dist, self.supported_types)
     assert (
         base_dist.support is constraints.real
     ), "The base distribution should be univariate and have real support."
     batch_shape = lax.broadcast_shapes(base_dist.batch_shape, jnp.shape(low))
     self.base_dist = tree_map(
         lambda p: promote_shapes(p, shape=batch_shape)[0], base_dist
     )
     (self.low,) = promote_shapes(low, shape=batch_shape)
     self._support = constraints.greater_than(low)
     super().__init__(batch_shape, validate_args=validate_args)
Example #9
 def __init__(self, predictor, cutpoints, validate_args=None):
     if jnp.ndim(predictor) == 0:
         predictor, = promote_shapes(predictor, shape=(1,))
     else:
         predictor = predictor[..., None]
     predictor, self.cutpoints = promote_shapes(predictor, cutpoints)
     self.predictor = predictor[..., 0]
     cumulative_probs = expit(cutpoints - predictor)
     # add two boundary points 0 and 1
     pad_width = [(0, 0)] * (jnp.ndim(cumulative_probs) - 1) + [(1, 1)]
     cumulative_probs = jnp.pad(cumulative_probs, pad_width, constant_values=(0, 1))
     probs = cumulative_probs[..., 1:] - cumulative_probs[..., :-1]
     super(OrderedLogistic, self).__init__(probs, validate_args=validate_args)
Example #10
    def __init__(self,
                 v=0.,
                 log_density=0.,
                 event_dim=0,
                 validate_args=None,
                 value=None):
        if value is not None:
            v = value
            warnings.warn(
                "`value` argument has been deprecated in favor of `v` argument.",
                FutureWarning)

        if event_dim > jnp.ndim(v):
            raise ValueError(
                'Expected event_dim <= v.dim(), actual {} vs {}'.format(
                    event_dim, jnp.ndim(v)))
        batch_dim = jnp.ndim(v) - event_dim
        batch_shape = jnp.shape(v)[:batch_dim]
        event_shape = jnp.shape(v)[batch_dim:]
        self.v = lax.convert_element_type(v, canonicalize_dtype(jnp.float64))
        # NB: following Pyro implementation, log_density should be broadcasted to batch_shape
        self.log_density = promote_shapes(log_density, shape=batch_shape)[0]
        super(Delta, self).__init__(batch_shape,
                                    event_shape,
                                    validate_args=validate_args)
Example #11
    def __init__(
        self, loc, scale, D_sparse, W_sparse, eigenvals, alpha, tau, validate_args=None
    ):
        """Connects to the next available port.

        Args:
          loc: loc parameter of the normal RV
          scale: scale parameter of the normal RV
          D_sparse: N length vector containing the number of neighbours of each
            node in the adjacency matrix W
          W_sparse: (N x 2) matrix encoding the neighbour relationship as a
            graph edgeset
          eigenvals: Eigenvalues of D^(1/2)WD^(1/2)
          alpha: alpha parameter (0 < alpha < 1)
          tau: tau parameter

        Returns:
          A SparseCAR distribution.
        """
        self.loc, self.scale = promote_shapes(loc, scale)
        self.D_sparse = D_sparse
        self.W_sparse = W_sparse
        self.eigenvals = eigenvals
        self.alpha = alpha
        self.tau = tau
        batch_shape = lax.broadcast_shapes(jnp.shape(loc), jnp.shape(scale))
        super(SparseCAR, self).__init__(
            batch_shape=batch_shape, validate_args=validate_args
        )
Example #12
    def __init__(self, loc, cov_factor, cov_diag, validate_args=None):
        if np.ndim(loc) < 1:
            raise ValueError("`loc` must be at least one-dimensional.")
        event_shape = np.shape(loc)[-1:]
        if np.ndim(cov_factor) < 2:
            raise ValueError("`cov_factor` must be at least two-dimensional, "
                             "with optional leading batch dimensions")
        if np.shape(cov_factor)[-2:-1] != event_shape:
            raise ValueError(
                "`cov_factor` must be a batch of matrices with shape {} x m".
                format(event_shape[0]))
        if np.shape(cov_diag)[-1:] != event_shape:
            raise ValueError(
                "`cov_diag` must be a batch of vectors with shape {}".format(
                    event_shape))

        loc, cov_factor, cov_diag = promote_shapes(
            loc[..., np.newaxis], cov_factor, cov_diag[..., np.newaxis])
        batch_shape = lax.broadcast_shapes(np.shape(loc), np.shape(cov_factor),
                                           np.shape(cov_diag))[:-2]
        self.loc = np.broadcast_to(loc[..., 0], batch_shape + event_shape)
        self.cov_factor = cov_factor
        cov_diag = cov_diag[..., 0]
        self.cov_diag = cov_diag
        self._capacitance_tril = _batch_capacitance_tril(cov_factor, cov_diag)
        super(LowRankMultivariateNormal,
              self).__init__(batch_shape=batch_shape,
                             event_shape=event_shape,
                             validate_args=validate_args)
Example #13
 def __init__(self, low=0., loc=0., scale=1., validate_args=None):
     self.low, self.loc, self.scale = promote_shapes(low, loc, scale)
     base_loc = (loc - low) / scale
     base_dist = _BaseTruncatedNormal(base_loc)
     super(TruncatedNormal, self).__init__(base_dist,
                                           AffineTransform(low, scale),
                                           validate_args=validate_args)
Example #14
 def __init__(self, low=0., loc=0., scale=1., validate_args=None):
     self.low, self.loc, self.scale = promote_shapes(low, loc, scale)
     batch_shape = lax.broadcast_shapes(np.shape(low), np.shape(loc),
                                        np.shape(scale))
     self._normal = Normal(self.loc, self.scale)
     super(TruncatedNormal, self).__init__(batch_shape=batch_shape,
                                           validate_args=validate_args)
Example #15
 def tree_flatten(self):
     prepend_ndim = len(self.batch_shape) - len(self.base_dist.batch_shape)
     base_dist = tree_util.tree_map(
         lambda x: promote_shapes(x, shape=(1,) * prepend_ndim + jnp.shape(x))[0],
         self.base_dist)
     base_flatten, base_aux = base_dist.tree_flatten()
     return base_flatten, (type(self.base_dist), base_aux, self.batch_shape)
Example #16
 def __init__(self, low=0., high=1., validate_args=None):
     self.low, self.high = promote_shapes(low, high)
     batch_shape = lax.broadcast_shapes(np.shape(low), np.shape(high))
     base_dist = _BaseUniform(batch_shape)
     super(Uniform, self).__init__(base_dist,
                                   AffineTransform(low, high - low),
                                   validate_args=validate_args)
Example #17
    def __init__(
        self,
        phi_loc,
        psi_loc,
        phi_concentration,
        psi_concentration,
        correlation=None,
        weighted_correlation=None,
        validate_args=None,
    ):
        assert (correlation is None) != (weighted_correlation is None)

        if weighted_correlation is not None:
            correlation = (weighted_correlation *
                           jnp.sqrt(phi_concentration * psi_concentration) +
                           1e-8)

        (
            self.phi_loc,
            self.psi_loc,
            self.phi_concentration,
            self.psi_concentration,
            self.correlation,
        ) = promote_shapes(phi_loc, psi_loc, phi_concentration,
                           psi_concentration, correlation)
        batch_shape = lax.broadcast_shapes(
            jnp.shape(phi_loc),
            jnp.shape(psi_loc),
            jnp.shape(phi_concentration),
            jnp.shape(psi_concentration),
            jnp.shape(correlation),
        )
        super().__init__(batch_shape, (2, ), validate_args)
Example #18
 def __init__(self, df, loc=0., scale=1., validate_args=None):
     batch_shape = lax.broadcast_shapes(np.shape(df), np.shape(loc),
                                        np.shape(scale))
     self.df = np.broadcast_to(df, batch_shape)
     self.loc, self.scale = promote_shapes(loc, scale, shape=batch_shape)
     self._chi2 = Chi2(self.df)
     super(StudentT, self).__init__(batch_shape,
                                    validate_args=validate_args)
Example #19
 def log_prob(self, value):
     if self._validate_args:
         self._validate_sample(value)
     value = np.expand_dims(value, -1)
     log_pmf = self.logits - logsumexp(self.logits, axis=-1, keepdims=True)
     value, log_pmf = promote_shapes(value, log_pmf)
     value = value[..., :1]
     return np.take_along_axis(log_pmf, value, -1)[..., 0]
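
In the log_prob above, the promoted value acts as an integer index array for a gather along the class axis. A small stand-alone sketch of that gather, with made-up shapes, using jnp.take_along_axis:

    import jax.numpy as jnp
    from jax.scipy.special import logsumexp

    logits = jnp.array([[0.1, 0.2, 0.7],
                        [0.3, 0.3, 0.4]])   # two categorical distributions over 3 classes
    value = jnp.array([2, 0])               # one class index per batch element

    log_pmf = logits - logsumexp(logits, axis=-1, keepdims=True)
    value = jnp.expand_dims(value, -1)      # shape (2, 1) so it can index the last axis
    log_prob = jnp.take_along_axis(log_pmf, value, -1)[..., 0]
    print(log_prob.shape)                   # (2,)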
Example #20
 def __init__(self, mu, tau, validate_args=None):
     self.mu, self.tau = promote_shapes(mu, tau)
      # convert the (mu, tau) parametrisation to the (r, p) parametrisation
     self.r = tau
     self.var = mu + 1 / self.r * mu**2
     self.p = (self.var - mu) / self.var
     self._gamma = dist.Gamma(self.r, (1 - self.p) / self.p)
     super(NegativeBinomial, self).__init__(self._gamma.batch_shape,
                                            validate_args=validate_args)
Example #21
 def __init__(self, concentration1, concentration0, total_count=1, validate_args=None):
     self.concentration1, self.concentration0, self.total_count = promote_shapes(
         concentration1, concentration0, total_count
     )
     batch_shape = lax.broadcast_shapes(jnp.shape(concentration1), jnp.shape(concentration0),
                                        jnp.shape(total_count))
     concentration1 = jnp.broadcast_to(concentration1, batch_shape)
     concentration0 = jnp.broadcast_to(concentration0, batch_shape)
     self._beta = Beta(concentration1, concentration0)
     super(BetaBinomial, self).__init__(batch_shape, validate_args=validate_args)
Example #22
 def __init__(self, v=0., log_density=0., event_dim=0, validate_args=None):
     if event_dim > jnp.ndim(v):
         raise ValueError('Expected event_dim <= v.dim(), actual {} vs {}'
                          .format(event_dim, jnp.ndim(v)))
     batch_dim = jnp.ndim(v) - event_dim
     batch_shape = jnp.shape(v)[:batch_dim]
     event_shape = jnp.shape(v)[batch_dim:]
     self.v = v
     # NB: following Pyro implementation, log_density should be broadcasted to batch_shape
     self.log_density = promote_shapes(log_density, shape=batch_shape)[0]
     super(Delta, self).__init__(batch_shape, event_shape, validate_args=validate_args)
Example #23
 def __init__(self, value=0., log_density=0., event_ndim=0, validate_args=None):
     if event_ndim > np.ndim(value):
          raise ValueError('Expected event_ndim <= value.dim(), actual {} vs {}'
                          .format(event_ndim, np.ndim(value)))
     batch_dim = np.ndim(value) - event_ndim
     batch_shape = np.shape(value)[:batch_dim]
     event_shape = np.shape(value)[batch_dim:]
     self.value = lax.convert_element_type(value, xla_bridge.canonicalize_dtype(np.float64))
     # NB: following Pyro implementation, log_density should be broadcasted to batch_shape
     self.log_density = promote_shapes(log_density, shape=batch_shape)[0]
     super(Delta, self).__init__(batch_shape, event_shape, validate_args=validate_args)
Example #24
    def __init__(self, loc, concentration, validate_args=None):
        """  von Mises distribution for sampling directions.

        :param loc: center of distribution
        :param concentration: concentration of distribution
        """
        self.loc, self.concentration = promote_shapes(loc, concentration)

        batch_shape = lax.broadcast_shapes(jnp.shape(concentration), jnp.shape(loc))

        super(VonMises, self).__init__(batch_shape=batch_shape,
                                       validate_args=validate_args)
Example #25
    def _process_parameters(self, n, p):
        p_ = 1. - p[..., :-1].sum(axis=-1)
        p, p_ = promote_shapes(p, p_[..., np.newaxis])
        # the last category carries the residual mass (as in scipy.stats.multinomial);
        # JAX arrays are immutable, so keep the updated array returned by lax
        p = lax.dynamic_update_slice_in_dim(p, p_, np.shape(p)[-1] - 1, axis=-1)

        # true for bad p
        pcond = np.any(p < 0, axis=-1) | np.any(p > 1, axis=-1)

        # true for bad n
        n = np.array(n, dtype=np.int32)
        ncond = n <= 0

        return n, p, ncond | pcond
Example #26
 def __init__(self, base_dist, gate, *, validate_args=None):
     batch_shape = lax.broadcast_shapes(jnp.shape(gate),
                                        base_dist.batch_shape)
     (self.gate, ) = promote_shapes(gate, shape=batch_shape)
     assert base_dist.support.is_discrete
     if base_dist.event_shape:
         raise ValueError(
             "ZeroInflatedProbs expected empty base_dist.event_shape but got {}"
             .format(base_dist.event_shape))
     # XXX: we might need to promote parameters of base_dist but let's keep
     # this simplified for now
     self.base_dist = base_dist.expand(batch_shape)
     super(ZeroInflatedProbs, self).__init__(batch_shape,
                                             validate_args=validate_args)
Example #27
    def __init__(self, loc, scale, W_sparse, validate_args=None):
        """Connects to the next available port.

        Args:
          loc: loc parameter of the normal RV
          scale: scale parameter of the normal RV
          W_sparse: (N x 2) matrix encoding the neighbour relationship as a
            graph edgeset
        """
        self.loc, self.scale = promote_shapes(loc, scale)
        batch_shape = lax.broadcast_shapes(jnp.shape(loc), jnp.shape(scale))
        self.W_sparse = W_sparse
        super(SparseICAR, self).__init__(
            batch_shape=batch_shape, validate_args=validate_args
        )
Example #28
    def __init__(
        self,
        phi_loc,
        psi_loc,
        phi_concentration,
        psi_concentration,
        correlation=None,
        weighted_correlation=None,
        validate_args=None,
    ):

        assert (correlation is None) != (weighted_correlation is None)

        if weighted_correlation is not None:
            correlation = (weighted_correlation *
                           jnp.sqrt(phi_concentration * psi_concentration) +
                           1e-8)

        (
            self.phi_loc,
            self.psi_loc,
            self.phi_concentration,
            self.psi_concentration,
            self.correlation,
        ) = promote_shapes(phi_loc, psi_loc, phi_concentration,
                           psi_concentration, correlation)
        batch_shape = lax.broadcast_shapes(
            jnp.shape(phi_loc),
            jnp.shape(psi_loc),
            jnp.shape(phi_concentration),
            jnp.shape(psi_concentration),
            jnp.shape(correlation),
        )
        super().__init__(batch_shape, (2, ), validate_args)

        if self._validate_args and jnp.any(
                phi_concentration * psi_concentration <= correlation**2):
            warnings.warn(
                f"{self.__class__.__name__} bimodal due to concentration-correlation relation, "
                f"sampling will likely fail.",
                UserWarning,
            )
Example #29
 def __init__(self, loc=0., scale=1., validate_args=None):
     self.loc, self.scale = promote_shapes(loc, scale)
     batch_shape = lax.broadcast_shapes(np.shape(loc), np.shape(scale))
     super(Cauchy, self).__init__(batch_shape=batch_shape,
                                  validate_args=validate_args)
Example #30
 def __init__(self, concentration, rate=1., validate_args=None):
     self.concentration, self.rate = promote_shapes(concentration, rate)
     batch_shape = lax.broadcast_shapes(np.shape(concentration),
                                        np.shape(rate))
     super(Gamma, self).__init__(batch_shape=batch_shape,
                                 validate_args=validate_args)
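
Taken together, this pattern lets every constructor accept scalars and arrays interchangeably while still reporting a consistent batch_shape. A short usage sketch, assuming the Gamma constructor above is numpyro.distributions.Gamma:

    import jax.numpy as jnp
    import numpyro.distributions as dist

    # a vector concentration broadcasts against a scalar rate
    d = dist.Gamma(jnp.array([1.0, 2.0, 3.0]), rate=0.5)
    print(d.batch_shape)   # (3,)
    print(d.event_shape)   # ()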