Example #1
class Expm1Transform(ExpTransform):

    codomain = constraints.greater_than_eq(-1.0)

    def _call(self, x):
        return super()._call(x) - 1.0

    def _inverse(self, y):
        return super()._inverse(y + 1.0)
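A minimal usage sketch for the transform above (assumes PyTorch and the class definition as written; values are illustrative). The forward pass computes exp(x) - 1 and the inverse recovers x via log(y + 1).

import torch

t = Expm1Transform()
x = torch.tensor([-1.0, 0.0, 2.0])
y = t(x)                            # exp(x) - 1, lies in the codomain [-1, inf)
assert torch.allclose(t.inv(y), x)  # inverse maps back via log(y + 1)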
Example #2
class ZeroInflatedNegativeBinomial(ZeroInflatedDistribution):
    """
    A Zero Inflated Negative Binomial distribution.

    :param total_count: non-negative number of negative Bernoulli trials.
    :type total_count: float or torch.Tensor
    :param torch.Tensor probs: Event probabilities of success in the half open interval [0, 1).
    :param torch.Tensor logits: Event log-odds for probabilities of success.
    :param torch.Tensor gate: probability of extra zeros.
    :param torch.Tensor gate_logits: logits of extra zeros.
    """

    arg_constraints = {
        "total_count": constraints.greater_than_eq(0),
        "probs": constraints.half_open_interval(0.0, 1.0),
        "logits": constraints.real,
        "gate": constraints.unit_interval,
        "gate_logits": constraints.real,
    }
    support = constraints.nonnegative_integer

    def __init__(self,
                 total_count,
                 *,
                 probs=None,
                 logits=None,
                 gate=None,
                 gate_logits=None,
                 validate_args=None):
        base_dist = NegativeBinomial(
            total_count=total_count,
            probs=probs,
            logits=logits,
            validate_args=False,
        )
        base_dist._validate_args = validate_args

        super().__init__(base_dist,
                         gate=gate,
                         gate_logits=gate_logits,
                         validate_args=validate_args)

    @property
    def total_count(self):
        return self.base_dist.total_count

    @property
    def probs(self):
        return self.base_dist.probs

    @property
    def logits(self):
        return self.base_dist.logits
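A brief usage sketch of the class above (assumes the Pyro base classes it builds on, ZeroInflatedDistribution and NegativeBinomial, are importable; the parameter values are arbitrary):

import torch

d = ZeroInflatedNegativeBinomial(
    total_count=torch.tensor(10.0),
    logits=torch.tensor(-0.5),   # success log-odds of the NB component
    gate=torch.tensor(0.3),      # probability of an extra structural zero
)
x = d.sample(torch.Size([5]))    # non-negative integer counts
print(d.log_prob(x))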
Example #3
    def support(self):
        return constraints.greater_than_eq(self.scale)
Example #4
class NegativeBinomial(Distribution):
    r"""
    Creates a Negative Binomial distribution, i.e. distribution
    of the number of successful independent and identical Bernoulli trials
    before :attr:`total_count` failures are achieved. The probability
    of success of each Bernoulli trial is :attr:`probs`.

    Args:
        total_count (float or Tensor): non-negative number of negative Bernoulli
            trials to stop, although the distribution is still valid for real
            valued count
        probs (Tensor): Event probabilities of success in the half open interval [0, 1)
        logits (Tensor): Event log-odds for probabilities of success
    """
    arg_constraints = {'total_count': constraints.greater_than_eq(0),
                       'probs': constraints.half_open_interval(0., 1.),
                       'logits': constraints.real}
    support = constraints.nonnegative_integer

    def __init__(self, total_count, probs=None, logits=None, validate_args=None):
        if (probs is None) == (logits is None):
            raise ValueError("Either `probs` or `logits` must be specified, but not both.")
        if probs is not None:
            self.total_count, self.probs, = broadcast_all(total_count, probs)
            self.total_count = self.total_count.type_as(self.probs)
        else:
            self.total_count, self.logits, = broadcast_all(total_count, logits)
            self.total_count = self.total_count.type_as(self.logits)

        self._param = self.probs if probs is not None else self.logits
        batch_shape = self._param.size()
        super(NegativeBinomial, self).__init__(batch_shape, validate_args=validate_args)

    def expand(self, batch_shape, _instance=None):
        new = self._get_checked_instance(NegativeBinomial, _instance)
        batch_shape = torch.Size(batch_shape)
        new.total_count = self.total_count.expand(batch_shape)
        if 'probs' in self.__dict__:
            new.probs = self.probs.expand(batch_shape)
            new._param = new.probs
        if 'logits' in self.__dict__:
            new.logits = self.logits.expand(batch_shape)
            new._param = new.logits
        super(NegativeBinomial, new).__init__(batch_shape, validate_args=False)
        new._validate_args = self._validate_args
        return new

    def _new(self, *args, **kwargs):
        return self._param.new(*args, **kwargs)

    @property
    def mean(self):
        return self.total_count * torch.exp(self.logits)

    @property
    def variance(self):
        return self.mean / torch.sigmoid(-self.logits)

    @lazy_property
    def logits(self):
        return probs_to_logits(self.probs, is_binary=True)

    @lazy_property
    def probs(self):
        return logits_to_probs(self.logits, is_binary=True)

    @property
    def param_shape(self):
        return self._param.size()

    @lazy_property
    def _gamma(self):
        # Note we avoid validating because self.total_count can be zero.
        return torch.distributions.Gamma(concentration=self.total_count,
                                         rate=torch.exp(-self.logits),
                                         validate_args=False)

    def sample(self, sample_shape=torch.Size()):
        with torch.no_grad():
            rate = self._gamma.sample(sample_shape=sample_shape)
            return torch.poisson(rate)

    def log_prob(self, value):
        if self._validate_args:
            self._validate_sample(value)

        log_unnormalized_prob = (self.total_count * F.logsigmoid(-self.logits) +
                                 value * F.logsigmoid(self.logits))

        log_normalization = (-torch.lgamma(self.total_count + value) + torch.lgamma(1. + value) +
                             torch.lgamma(self.total_count))

        return log_unnormalized_prob - log_normalization
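A quick sanity check of the class above (it mirrors torch.distributions.NegativeBinomial; the numbers below follow from mean = total_count * p / (1 - p) and variance = mean / (1 - p)):

import torch

nb = NegativeBinomial(total_count=torch.tensor(5.0), probs=torch.tensor(0.25))
print(nb.mean)      # 5 * 0.25 / 0.75 ~= 1.667
print(nb.variance)  # 1.667 / 0.75  ~= 2.222
samples = nb.sample(torch.Size([1000]))  # Gamma-Poisson mixture draws
print(nb.log_prob(samples[:3]))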
Example #5
class NegativeBinomialMixture(Distribution):
    """
    Negative binomial mixture distribution.

    See :class:`~scvi.distributions.NegativeBinomial` for further description
    of parameters.

    Parameters
    ----------
    mu1
        Mean of the component 1 distribution.
    mu2
        Mean of the component 2 distribution.
    theta1
        Inverse dispersion for component 1.
    mixture_logits
        Logits of the probability of belonging to component 1.
    theta2
        Inverse dispersion for component 2. If `None`, assumed to be equal to `theta1`.
    validate_args
        Raise ValueError if arguments do not match constraints
    """

    arg_constraints = {
        "mu1": constraints.greater_than_eq(0),
        "mu2": constraints.greater_than_eq(0),
        "theta1": constraints.greater_than_eq(0),
        "mixture_probs": constraints.half_open_interval(0.0, 1.0),
        "mixture_logits": constraints.real,
    }
    support = constraints.nonnegative_integer

    def __init__(
        self,
        mu1: torch.Tensor,
        mu2: torch.Tensor,
        theta1: torch.Tensor,
        mixture_logits: torch.Tensor,
        theta2: Optional[torch.Tensor] = None,
        validate_args: bool = False,
    ):

        (
            self.mu1,
            self.theta1,
            self.mu2,
            self.mixture_logits,
        ) = broadcast_all(mu1, theta1, mu2, mixture_logits)

        super().__init__(validate_args=validate_args)

        if theta2 is not None:
            # broadcast_all returns a tuple; keep only the broadcast theta2
            _, self.theta2 = broadcast_all(mu1, theta2)
        else:
            self.theta2 = None

    @property
    def mean(self):
        pi = self.mixture_probs
        return pi * self.mu1 + (1 - pi) * self.mu2

    @lazy_property
    def mixture_probs(self) -> torch.Tensor:
        return logits_to_probs(self.mixture_logits, is_binary=True)

    def sample(
        self, sample_shape: Union[torch.Size, Tuple] = torch.Size()
    ) -> torch.Tensor:
        with torch.no_grad():
            pi = self.mixture_probs
            mixing_sample = torch.distributions.Bernoulli(pi).sample()
            mu = self.mu1 * mixing_sample + self.mu2 * (1 - mixing_sample)
            if self.theta2 is None:
                theta = self.theta1
            else:
                theta = self.theta1 * mixing_sample + self.theta2 * (1 - mixing_sample)
            gamma_d = _gamma(mu, theta)
            p_means = gamma_d.sample(sample_shape)

            # Clamping as distributions objects can have buggy behaviors when
            # their parameters are too high
            l_train = torch.clamp(p_means, max=1e8)
            counts = Poisson(
                l_train
            ).sample()  # Shape : (n_samples, n_cells_batch, n_features)
            return counts

    def log_prob(self, value: torch.Tensor) -> torch.Tensor:
        try:
            self._validate_sample(value)
        except ValueError:
            warnings.warn(
                "The value argument must be within the support of the distribution",
                UserWarning,
            )
        return log_mixture_nb(
            value,
            self.mu1,
            self.mu2,
            self.theta1,
            self.theta2,
            self.mixture_logits,
            eps=1e-08,
        )
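A construction sketch for the mixture above (sample and log_prob additionally rely on the scvi-tools helpers _gamma and log_mixture_nb, so only the closed-form properties are shown; values are illustrative):

import torch

d = NegativeBinomialMixture(
    mu1=torch.tensor([2.0]),
    mu2=torch.tensor([20.0]),
    theta1=torch.tensor([1.5]),
    mixture_logits=torch.tensor([0.0]),  # sigmoid(0) = 0.5, i.e. a 50/50 mixture
)
print(d.mixture_probs)  # 0.5
print(d.mean)           # 0.5 * 2.0 + 0.5 * 20.0 = 11.0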
Example #6
class ZeroInflatedNegativeBinomial(NegativeBinomial):
    r"""
    Zero-inflated negative binomial distribution.

    One of the following parameterizations must be provided:

    1. (`total_count`, `probs`) where `total_count` is the number of failures until
       the experiment is stopped and `probs` the success probability.
    2. (`mu`, `theta`), the parameterization used by scvi-tools, where the parameters
       respectively control the mean and inverse dispersion of the distribution.

    In the (`mu`, `theta`) parameterization, samples from the negative binomial are generated as follows:

    1. :math:`w \sim \textrm{Gamma}(\underbrace{\theta}_{\text{shape}}, \underbrace{\theta/\mu}_{\text{rate}})`
    2. :math:`x \sim \textrm{Poisson}(w)`

    Parameters
    ----------
    total_count
        Number of failures until the experiment is stopped.
    probs
        The success probability.
    mu
        Mean of the distribution.
    theta
        Inverse dispersion.
    zi_logits
        Logits of the zero-inflation probability.
    validate_args
        Raise ValueError if arguments do not match constraints
    """

    arg_constraints = {
        "mu": constraints.greater_than_eq(0),
        "theta": constraints.greater_than_eq(0),
        "zi_probs": constraints.half_open_interval(0.0, 1.0),
        "zi_logits": constraints.real,
    }
    support = constraints.nonnegative_integer

    def __init__(
        self,
        total_count: Optional[torch.Tensor] = None,
        probs: Optional[torch.Tensor] = None,
        logits: Optional[torch.Tensor] = None,
        mu: Optional[torch.Tensor] = None,
        theta: Optional[torch.Tensor] = None,
        zi_logits: Optional[torch.Tensor] = None,
        validate_args: bool = False,
    ):

        super().__init__(
            total_count=total_count,
            probs=probs,
            logits=logits,
            mu=mu,
            theta=theta,
            validate_args=validate_args,
        )
        self.zi_logits, self.mu, self.theta = broadcast_all(
            zi_logits, self.mu, self.theta
        )

    @property
    def mean(self):
        pi = self.zi_probs
        return (1 - pi) * self.mu

    @property
    def variance(self):
        raise NotImplementedError

    @lazy_property
    def zi_logits(self) -> torch.Tensor:
        return probs_to_logits(self.zi_probs, is_binary=True)

    @lazy_property
    def zi_probs(self) -> torch.Tensor:
        return logits_to_probs(self.zi_logits, is_binary=True)

    def sample(
        self, sample_shape: Union[torch.Size, Tuple] = torch.Size()
    ) -> torch.Tensor:
        with torch.no_grad():
            samp = super().sample(sample_shape=sample_shape)
            is_zero = torch.rand_like(samp) <= self.zi_probs
            samp[is_zero] = 0.0
            return samp

    def log_prob(self, value: torch.Tensor) -> torch.Tensor:
        try:
            self._validate_sample(value)
        except ValueError:
            warnings.warn(
                "The value argument must be within the support of the distribution",
                UserWarning,
            )
        return log_zinb_positive(value, self.mu, self.theta, self.zi_logits, eps=1e-08)
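An illustrative construction using the (mu, theta, zi_logits) parameterization (sampling and log_prob depend on the scvi-tools helpers _gamma and log_zinb_positive; values are arbitrary):

import torch

d = ZeroInflatedNegativeBinomial(
    mu=torch.tensor([4.0]),
    theta=torch.tensor([2.0]),
    zi_logits=torch.tensor([-1.0]),  # sigmoid(-1) ~= 0.27 probability of extra zeros
)
print(d.mean)                        # (1 - zi_probs) * mu ~= 0.73 * 4.0
counts = d.sample(torch.Size([10]))  # zero-inflated NB draws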
Example #7
class NegativeBinomial(Distribution):
    r"""
    Negative binomial distribution.

    One of the following parameterizations must be provided:

    1. (`total_count`, `probs`) where `total_count` is the number of failures until
       the experiment is stopped and `probs` the success probability.
    2. (`mu`, `theta`), the parameterization used by scvi-tools, where the parameters
       respectively control the mean and inverse dispersion of the distribution.

    In the (`mu`, `theta`) parameterization, samples from the negative binomial are generated as follows:

    1. :math:`w \sim \textrm{Gamma}(\underbrace{\theta}_{\text{shape}}, \underbrace{\theta/\mu}_{\text{rate}})`
    2. :math:`x \sim \textrm{Poisson}(w)`

    Parameters
    ----------
    total_count
        Number of failures until the experiment is stopped.
    probs
        The success probability.
    mu
        Mean of the distribution.
    theta
        Inverse dispersion.
    validate_args
        Raise ValueError if arguments do not match constraints
    """

    arg_constraints = {
        "mu": constraints.greater_than_eq(0),
        "theta": constraints.greater_than_eq(0),
    }
    support = constraints.nonnegative_integer

    def __init__(
        self,
        total_count: Optional[torch.Tensor] = None,
        probs: Optional[torch.Tensor] = None,
        logits: Optional[torch.Tensor] = None,
        mu: Optional[torch.Tensor] = None,
        theta: Optional[torch.Tensor] = None,
        validate_args: bool = False,
    ):
        self._eps = 1e-8
        if (mu is None) == (total_count is None):
            raise ValueError(
                "Please use one of the two possible parameterizations. Refer to the documentation for more information."
            )

        using_param_1 = total_count is not None and (
            logits is not None or probs is not None
        )
        if using_param_1:
            logits = logits if logits is not None else probs_to_logits(probs)
            total_count = total_count.type_as(logits)
            total_count, logits = broadcast_all(total_count, logits)
            mu, theta = _convert_counts_logits_to_mean_disp(total_count, logits)
        else:
            mu, theta = broadcast_all(mu, theta)
        self.mu = mu
        self.theta = theta
        super().__init__(validate_args=validate_args)

    @property
    def mean(self):
        return self.mu

    @property
    def variance(self):
        return self.mean + (self.mean ** 2) / self.theta

    def sample(
        self, sample_shape: Union[torch.Size, Tuple] = torch.Size()
    ) -> torch.Tensor:
        with torch.no_grad():
            gamma_d = self._gamma()
            p_means = gamma_d.sample(sample_shape)

            # Clamping as distributions objects can have buggy behaviors when
            # their parameters are too high
            l_train = torch.clamp(p_means, max=1e8)
            counts = Poisson(
                l_train
            ).sample()  # Shape : (n_samples, n_cells_batch, n_vars)
            return counts

    def log_prob(self, value: torch.Tensor) -> torch.Tensor:
        if self._validate_args:
            try:
                self._validate_sample(value)
            except ValueError:
                warnings.warn(
                    "The value argument must be within the support of the distribution",
                    UserWarning,
                )
        return log_nb_positive(value, mu=self.mu, theta=self.theta, eps=self._eps)

    def _gamma(self):
        return _gamma(self.theta, self.mu)
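The two-step generative process described in the docstring, written out directly with plain PyTorch (a sketch; the class' own sample() does the equivalent and additionally clamps the Gamma draw):

import torch

mu, theta = torch.tensor([4.0]), torch.tensor([2.0])
w = torch.distributions.Gamma(concentration=theta, rate=theta / mu).sample()
x = torch.poisson(w)  # one negative binomial draw with mean mu and inverse dispersion theta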
Example #8
class NegativeBinomial(Distribution):
    r"""Negative Binomial(NB) distribution using two parameterizations:

    - (`total_count`, `probs`) where `total_count` is the number of failures
        until the experiment is stopped
        and `probs` the success probability.
    - The (`mu`, `theta`) parameterization is the one used by scVI. These parameters respectively
    control the mean and overdispersion of the distribution.

    `_convert_mean_disp_to_counts_logits` and `_convert_counts_logits_to_mean_disp` provide ways to convert
    one parameterization to another.

    Parameters
    ----------

    Returns
    -------
    """
    arg_constraints = {
        "mu": constraints.greater_than_eq(0),
        "theta": constraints.greater_than_eq(0),
    }
    support = constraints.nonnegative_integer

    def __init__(
        self,
        total_count: torch.Tensor = None,
        probs: torch.Tensor = None,
        logits: torch.Tensor = None,
        mu: torch.Tensor = None,
        theta: torch.Tensor = None,
        validate_args=True,
    ):
        self._eps = 1e-8
        if (mu is None) == (total_count is None):
            raise ValueError(
                "Please use one of the two possible parameterizations. Refer to the documentation for more information."
            )

        using_param_1 = total_count is not None and (logits is not None
                                                     or probs is not None)
        if using_param_1:
            logits = logits if logits is not None else probs_to_logits(probs)
            total_count = total_count.type_as(logits)
            total_count, logits = broadcast_all(total_count, logits)
            mu, theta = _convert_counts_logits_to_mean_disp(
                total_count, logits)
        else:
            mu, theta = broadcast_all(mu, theta)
        self.mu = mu
        self.theta = theta
        super().__init__(validate_args=validate_args)

    def sample(self, sample_shape=torch.Size()):
        gamma_d = self._gamma()
        p_means = gamma_d.sample(sample_shape)

        # Clamping as distributions objects can have buggy behaviors when
        # their parameters are too high
        l_train = torch.clamp(p_means, max=1e8)
        counts = Poisson(
            l_train).sample()  # Shape : (n_samples, n_cells_batch, n_genes)
        return counts

    def log_prob(self, value):
        if self._validate_args:
            try:
                self._validate_sample(value)
            except ValueError:
                warnings.warn(
                    "The value argument must be within the support of the distribution",
                    UserWarning,
                )
        return log_nb_positive(value,
                               mu=self.mu,
                               theta=self.theta,
                               eps=self._eps)

    def _gamma(self):
        concentration = self.theta
        rate = self.theta / self.mu
        # Important remark: Gamma is parametrized by the rate = 1/scale!
        gamma_d = Gamma(concentration=concentration, rate=rate)
        return gamma_d
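For reference, a sketch of what the _convert_* helpers mentioned in the docstring compute, under the usual convention total_count = theta and logits = log(mu) - log(theta). These are illustrative stand-ins, not the library functions themselves:

import torch

def mean_disp_to_counts_logits(mu, theta, eps=1e-8):
    total_count = theta
    logits = (mu + eps).log() - (theta + eps).log()
    return total_count, logits

def counts_logits_to_mean_disp(total_count, logits):
    theta = total_count
    mu = logits.exp() * theta
    return mu, theta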
Example #9
class ZeroInflatedNegativeBinomial(NegativeBinomial):
    r"""Zero Inflated Negative Binomial distribution.

    zi_logits correspond to the zero-inflation logits
        mu + mu ** 2 / theta
    The negative binomial component parameters can follow two two parameterizations:
    - The first one corresponds to the parameterization NB(`total_count`, `probs`)
        where `total_count` is the number of failures until the experiment is stopped
        and `probs` the success probability.
    - The (`mu`, `theta`) parameterization is the one used by scVI. These parameters respectively
    control the mean and overdispersion of the distribution.

    `_convert_mean_disp_to_counts_logits` and `_convert_counts_logits_to_mean_disp`
    provide ways to convert one parameterization to another.

    Parameters
    ----------

    Returns
    -------
    """
    arg_constraints = {
        "mu": constraints.greater_than_eq(0),
        "theta": constraints.greater_than_eq(0),
        "zi_probs": constraints.half_open_interval(0.0, 1.0),
        "zi_logits": constraints.real,
    }
    support = constraints.nonnegative_integer

    def __init__(
        self,
        total_count: torch.Tensor = None,
        probs: torch.Tensor = None,
        logits: torch.Tensor = None,
        mu: torch.Tensor = None,
        theta: torch.Tensor = None,
        zi_logits: torch.Tensor = None,
        validate_args=True,
    ):

        super().__init__(
            total_count=total_count,
            probs=probs,
            logits=logits,
            mu=mu,
            theta=theta,
            validate_args=validate_args,
        )
        self.zi_logits, self.mu, self.theta = broadcast_all(
            zi_logits, self.mu, self.theta)

    @lazy_property
    def zi_logits(self) -> torch.Tensor:
        return probs_to_logits(self.zi_probs, is_binary=True)

    @lazy_property
    def zi_probs(self) -> torch.Tensor:
        return logits_to_probs(self.zi_logits, is_binary=True)

    def sample(
        self, sample_shape: Union[torch.Size, Tuple] = torch.Size()
    ) -> torch.Tensor:
        with torch.no_grad():
            samp = super().sample(sample_shape=sample_shape)
            is_zero = torch.rand_like(samp) <= self.zi_probs
            samp[is_zero] = 0.0
            return samp

    def log_prob(self, value: torch.Tensor) -> torch.Tensor:
        try:
            self._validate_sample(value)
        except ValueError:
            warnings.warn(
                "The value argument must be within the support of the distribution",
                UserWarning,
            )
        return log_zinb_positive(value,
                                 self.mu,
                                 self.theta,
                                 self.zi_logits,
                                 eps=1e-08)
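What log_zinb_positive evaluates, written out as a reference sketch of the standard ZINB log-density with pi = sigmoid(zi_logits). This is not the scvi-tools implementation, just the same mixture density spelled out:

import torch
import torch.nn.functional as F

def zinb_log_prob(x, mu, theta, zi_logits, eps=1e-8):
    # log NB(x | mu, theta) + log(1 - pi) for the positive-count branch
    log_theta_mu = torch.log(theta + mu + eps)
    nb_case = (
        theta * (torch.log(theta + eps) - log_theta_mu)
        + x * (torch.log(mu + eps) - log_theta_mu)
        + torch.lgamma(x + theta)
        - torch.lgamma(theta)
        - torch.lgamma(x + 1)
        + F.logsigmoid(-zi_logits)  # log(1 - pi)
    )
    # log(pi + (1 - pi) * NB(0 | mu, theta)) for the zero branch
    zero_case = torch.logaddexp(
        F.logsigmoid(zi_logits),
        F.logsigmoid(-zi_logits) + theta * (torch.log(theta + eps) - log_theta_mu),
    )
    return torch.where(x < eps, zero_case, nb_case)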
Example #10
def identity(x):
    return x


def _constraint_hash(constraint: constraints.Constraint) -> int:
    assert isinstance(constraint, constraints.Constraint)
    out = hash(type(constraint))
    out ^= hash(frozenset(constraint.__dict__.items()))
    return out


# useful if the distribution's init has multiple ways of specifying a parameter (e.g. either logits or probs)
_distribution_to_param_names = {NegativeBinomial: ['probs', 'total_count']}

# torch.distributions has a whole 'transforms' module but I don't know if it provides a mapping
_constraint_to_ilink = {
    _constraint_hash(constraints.positive): torch.exp,
    _constraint_hash(constraints.greater_than(0)): torch.exp,
    _constraint_hash(constraints.unit_interval): torch.sigmoid,
    _constraint_hash(constraints.real): identity,
    # TODO: is there a way to make these work?
    _constraint_hash(constraints.greater_than_eq(0)): torch.exp,
    _constraint_hash(constraints.half_open_interval(0, 1)): torch.sigmoid,
    # TODO: constraints.interval
}
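An illustrative lookup with the mapping above (names come from this snippet): a parameter's constraint is hashed and resolved to an inverse-link function that maps an unconstrained value into the valid range.

from torch.distributions import constraints

ilink = _constraint_to_ilink[_constraint_hash(constraints.greater_than_eq(0))]
# ilink is torch.exp here, so an unconstrained parameter z is mapped to ilink(z) >= 0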
Example #11
class PiecewiseConstantBirthDeath(Distribution):
    r"""Piecewise constant birth death model

    :param lambda_: birth rates
    :param mu: death rates
    :param psi: sampling rates
    :param rho: sampling effort
    :param origin: time at which the process starts (i.e. t_0)
    :param times: times of rate shift events
    :param relative_times: times are relative to origin
    :param survival: condition on observing at least one sample
    :param validate_args:
    """
    arg_constraints = {
        'lambda_': constraints.greater_than_eq(0.0),
        'mu': constraints.positive,
        'psi': constraints.greater_than_eq(0.0),
        'rho': constraints.unit_interval,
    }

    def __init__(
        self,
        lambda_: Tensor,
        mu: Tensor,
        psi: Tensor,
        rho: Tensor,
        origin: Tensor,
        times: Tensor = None,
        relative_times=False,
        survival: bool = True,
        validate_args=None,
    ):
        self.lambda_ = lambda_
        self.mu = mu
        self.psi = psi
        self.rho = rho
        self.times = times
        self.origin = origin
        self.relative_times = relative_times
        self.survival = survival
        batch_shape, event_shape = self.mu.shape[:-1], self.mu.shape[-1:]
        super().__init__(batch_shape, event_shape, validate_args=validate_args)

    def log_q(self, A, B, t, t_i):
        """Probability density of lineage alive between time t and t_i gives
        rise to observed clade."""
        e = torch.exp(-A * (t - t_i))
        return torch.log(4.0 * e / torch.pow(
            e * (1.0 + B) + (1.0 - B),
            2,
        ))

    def log_p(self, t, t_i):
        """Probability density of lineage alive between time t and t_i has no
        descendant at time t_m."""
        sum_term = self.lambda_ + self.mu + self.psi
        m = self.mu.shape[-1]

        A = torch.sqrt(
            torch.pow(self.lambda_ - self.mu - self.psi, 2.0) +
            4.0 * self.lambda_ * self.psi)
        B = torch.zeros_like(self.mu, dtype=self.mu.dtype)
        p = torch.ones(self.mu.shape[:-1] + (m + 1, ), dtype=self.mu.dtype)
        exp_A_term = torch.exp(A * (t - t_i))
        inv_2lambda = 1.0 / (2.0 * self.lambda_)

        for i in torch.arange(m - 1, -1, step=-1):
            B[..., i] += ((1.0 - 2.0 *
                           (1.0 - self.rho[..., i]) * p[..., i + 1].clone()) *
                          self.lambda_[..., i] + self.mu[..., i] +
                          self.psi[..., i]) / A[..., i]
            term = exp_A_term[..., i] * (1.0 + B[..., i])
            one_minus_Bi = 1.0 - B[..., i]
            p[...,
              i] *= (sum_term[..., i] - A[..., i] * (term - one_minus_Bi) /
                     (term + one_minus_Bi)) * inv_2lambda[..., i]
        return p, A, B

    def log_prob(self, node_heights: torch.Tensor):
        taxa_shape = node_heights.shape[:-1] + (int(
            (node_heights.shape[-1] + 1) / 2), )
        tip_heights = node_heights[..., :taxa_shape[-1]]
        serially_sampled = torch.any(tip_heights > 0.0)

        m = self.mu.shape[-1]

        if self.times is None:
            dtimes = (self.origin / m).expand(self.origin.shape[:-1] + (m, ))
            times = torch.cat(
                (torch.zeros(dtimes.shape[:-1] +
                             (1, ), dtype=dtimes.dtype), dtimes),
                -1).cumsum(-1)
        else:
            times = self.times

        times = torch.broadcast_to(times,
                                   self.mu.shape[:-1] + times.shape[-1:])

        if self.relative_times and self.times is not None:
            times = times * self.origin

        p, A, B = self.log_p(times[..., 1:], times[..., :-1])

        # first term
        e = torch.exp(-A[..., 0] * times[..., 1])
        q0 = 4.0 * e / torch.pow(e * (1.0 - B[..., 0]) + (1.0 + B[..., 0]), 2)

        log_p = torch.log(q0)
        # condition on sampling at least one individual
        if self.survival:
            log_p -= torch.log(1.0 - p[..., 0])

        # calculate l(x) with l(t)=1 iff t_{i-1} <= t < t_i
        x = times[..., -1:] - node_heights[..., taxa_shape[-1]:]
        indices_x = torch.max(times.unsqueeze(-2) >= x.unsqueeze(-1),
                              dim=-1)[1] - 1
        log_p += (torch.log(self.lambda_.gather(-1, indices_x)) + self.log_q(
            A.gather(-1, indices_x),
            B.gather(-1, indices_x),
            x,
            torch.gather(times[..., 1:], -1, indices_x),
        )).sum(-1)

        y = times[..., -1:] - tip_heights
        if serially_sampled:
            indices_y = torch.max(times.unsqueeze(-2) >= y.unsqueeze(-1),
                                  dim=-1)[1] - 1
            log_p += (torch.log(self.psi.gather(-1, indices_y)) - self.log_q(
                A.gather(-1, indices_y),
                B.gather(-1, indices_y),
                y,
                times.gather(-1, indices_y + 1),
            )).sum(-1)

        # last term
        if m > 1:
            # number of degree 2 vertices
            ni = (
                torch.sum(x.unsqueeze(-2) < times[..., 1:].unsqueeze(-1), -1) -
                torch.sum(y.unsqueeze(-2) <= times[..., 1:].unsqueeze(-1),
                          -1))[..., :-1] + 1.0

            # contemporaneous sampling term
            log_p += (ni * (self.log_q(A[..., 1:], B[..., 1:],
                                       times[..., 1:-1], times[..., 2:]) +
                            torch.log(1.0 - self.rho[..., :-1]))).sum(-1)

        N = torch.sum(
            times[..., 1:].unsqueeze(-2) == torch.unsqueeze(
                times[..., -1:] - tip_heights, -1),
            -2,
        )
        mask = (N > 0).logical_and(self.rho > 0.0)
        if torch.any(mask):
            p = torch.masked_select(N, mask) * torch.masked_select(
                self.rho, mask).log()
            log_p += p.squeeze() if log_p.dim() == 0 else p
        return log_p
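A minimal construction sketch for the model above with a single rate epoch (m = 1); the parameter values are purely illustrative, and node_heights concatenates the tip heights followed by the internal-node heights, as assumed by log_prob:

import torch

bd = PiecewiseConstantBirthDeath(
    lambda_=torch.tensor([2.0]),   # birth rate
    mu=torch.tensor([1.0]),        # death rate
    psi=torch.tensor([0.5]),       # serial sampling rate
    rho=torch.tensor([1.0]),       # sampling effort at present
    origin=torch.tensor([10.0]),   # start of the process
)
heights = torch.tensor([0.0, 0.0, 0.0, 1.5, 3.0])  # 3 tips at present, 2 internal nodes
print(bd.log_prob(heights))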
Example #12
class LogNormalNegativeBinomial(TorchDistribution):
    r"""
    A three-parameter generalization of the Negative Binomial distribution [1].
    It can be understood as a continuous mixture of Negative Binomial distributions
    in which we inject Normally-distributed noise into the logits of the Negative
    Binomial distribution:

    .. math::

        \begin{eqnarray}
        &\rm{LNNB}(y | \rm{total\_count}=\nu, \rm{logits}=\ell, \rm{multiplicative\_noise\_scale}=sigma) = \\
        &\int d\epsilon \mathcal{N}(\epsilon | 0, \sigma)
        \rm{NB}(y | \rm{total\_count}=\nu, \rm{logits}=\ell + \epsilon)
        \end{eqnarray}

    where :math:`y \ge 0` is a non-negative integer. Thus while a Negative Binomial distribution
    can be formulated as a Poisson distribution with a Gamma-distributed rate, this distribution
    adds an additional level of variability by also modulating the rate by Log Normally-distributed
    multiplicative noise.

    This distribution has a mean given by

    .. math::
        \mathbb{E}[y] = \nu e^{\ell + \tfrac{1}{2}\sigma^2} = e^{\ell + \log \nu + \tfrac{1}{2}\sigma^2}

    and a variance given by

    .. math::
        \rm{Var}[y] = \mathbb{E}[y] + \left( e^{\sigma^2} (1 + 1/\nu) - 1 \right) \left( \mathbb{E}[y] \right)^2

    Thus while a given mean and variance together uniquely characterize a Negative Binomial distribution, there is a
    one-dimensional family of Log Normal Negative Binomial distributions with a given mean and variance.

    Note that in some applications it may be useful to parameterize the logits as

    .. math::
        \ell = \ell^\prime - \log \nu - \tfrac{1}{2}\sigma^2

    so that the mean is given by :math:`\mathbb{E}[y] = e^{\ell^\prime}` and does not depend on :math:`\nu`
    and :math:`\sigma`, which serve to determine the higher moments.

    References:

    [1] "Lognormal and Gamma Mixed Negative Binomial Regression,"
    Mingyuan Zhou, Lingbo Li, David Dunson, and Lawrence Carin.

    :param total_count: non-negative number of negative Bernoulli trials. The variance decreases
        as `total_count` increases.
    :type total_count: float or torch.Tensor
    :param torch.Tensor logits: Event log-odds for probabilities of success for underlying
        Negative Binomial distribution.
    :param torch.Tensor multiplicative_noise_scale: Controls the level of the injected Normal logit noise.
    :param int num_quad_points: Number of quadrature points used to compute the (approximate) `log_prob`.
        Defaults to 8.
    """
    arg_constraints = {
        "total_count": constraints.greater_than_eq(0),
        "logits": constraints.real,
        "multiplicative_noise_scale": constraints.positive,
    }
    support = constraints.nonnegative_integer

    def __init__(
        self,
        total_count,
        logits,
        multiplicative_noise_scale,
        *,
        num_quad_points=8,
        validate_args=None,
    ):
        if num_quad_points < 1:
            raise ValueError("num_quad_points must be positive.")

        total_count, logits, multiplicative_noise_scale = broadcast_all(
            total_count, logits, multiplicative_noise_scale)

        self.quad_points, self.log_weights = get_quad_rule(
            num_quad_points, logits)
        quad_logits = (
            logits.unsqueeze(-1) +
            multiplicative_noise_scale.unsqueeze(-1) * self.quad_points)
        self.nb_dist = NegativeBinomial(total_count=total_count.unsqueeze(-1),
                                        logits=quad_logits)

        self.multiplicative_noise_scale = multiplicative_noise_scale
        self.total_count = total_count
        self.logits = logits
        self.num_quad_points = num_quad_points

        batch_shape = broadcast_shape(multiplicative_noise_scale.shape,
                                      self.nb_dist.batch_shape[:-1])
        event_shape = torch.Size()

        super().__init__(batch_shape, event_shape, validate_args)

    def log_prob(self, value):
        nb_log_prob = self.nb_dist.log_prob(value.unsqueeze(-1))
        return torch.logsumexp(self.log_weights + nb_log_prob, axis=-1)

    def sample(self, sample_shape=torch.Size()):
        raise NotImplementedError

    def expand(self, batch_shape, _instance=None):
        new = self._get_checked_instance(type(self), _instance)
        batch_shape = torch.Size(batch_shape)
        total_count = self.total_count.expand(batch_shape)
        logits = self.logits.expand(batch_shape)
        multiplicative_noise_scale = self.multiplicative_noise_scale.expand(
            batch_shape)
        LogNormalNegativeBinomial.__init__(
            new,
            total_count,
            logits,
            multiplicative_noise_scale,
            num_quad_points=self.num_quad_points,
            validate_args=False,
        )
        new._validate_args = self._validate_args
        return new

    @lazy_property
    def mean(self):
        return torch.exp(self.logits + self.total_count.log() +
                         0.5 * self.multiplicative_noise_scale.pow(2.0))

    @lazy_property
    def variance(self):
        kappa = (torch.exp(self.multiplicative_noise_scale.pow(2.0)) *
                 (1 + 1 / self.total_count) - 1)
        return self.mean + kappa * self.mean.pow(2.0)
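An illustrative use of the class above (it matches pyro.distributions.LogNormalNegativeBinomial; get_quad_rule and the NegativeBinomial component are assumed importable as in the snippet, and the values are arbitrary):

import torch

d = LogNormalNegativeBinomial(
    total_count=torch.tensor(4.0),
    logits=torch.tensor(0.5),
    multiplicative_noise_scale=torch.tensor(0.3),
    num_quad_points=8,
)
print(d.mean)                         # exp(logits + log(total_count) + 0.5 * scale^2)
print(d.log_prob(torch.tensor(3.0)))  # numerical quadrature over the injected logit noise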