def __init__(self,
               total_count,
               logits=None,
               probs=None,
               validate_args=False,
               allow_nan_stats=True,
               name="Multinomial"):
    """Initialize a batch of Multinomial distributions.

    Args:
      total_count: Non-negative floating point tensor with shape broadcastable
        to `[N1,..., Nm]` with `m >= 0`. Defines this as a batch of
        `N1 x ... x Nm` different Multinomial distributions. Its components
        should be equal to integer values.
      logits: Floating point tensor representing unnormalized log-probabilities
        of a positive event with shape broadcastable to `[N1,..., Nm, K]` with
        `m >= 0`, and the same dtype as `total_count`. Defines this as a batch
        of `N1 x ... x Nm` different `K` class Multinomial distributions. Only
        one of `logits` or `probs` should be passed in.
      probs: Positive floating point tensor with shape broadcastable to
        `[N1,..., Nm, K]` with `m >= 0` and the same dtype as `total_count`.
        Defines this as a batch of `N1 x ... x Nm` different `K` class
        Multinomial distributions. `probs`'s components along the last axis
        should sum to `1`. Only one of `logits` or `probs` should be passed in.
      validate_args: Python `bool`, default `False`. When `True` distribution
        parameters are checked for validity despite possibly degrading runtime
        performance. When `False` invalid inputs may silently render incorrect
        outputs.
      allow_nan_stats: Python `bool`, default `True`. When `True`, statistics
        (e.g., mean, mode, variance) use the value "`NaN`" to indicate the
        result is undefined. When `False`, an exception is raised if one or
        more of the statistic's batch members are undefined.
      name: Python `str` name prefixed to Ops created by this class.
    """
    # Snapshot the constructor arguments for the `parameters` record. This must
    # remain the first statement: any local bound before it would be captured
    # as well.
    parameters = dict(locals())
    # `name` is rebound to the scope string opened here; that string is what
    # gets forwarded to the superclass constructor below.
    with tf.name_scope(name, values=[total_count, logits, probs]) as name:
      # Resolve one dtype shared by the supplied parameters, with `tf.float32`
      # passed to `common_dtype` as the fallback.
      dtype = dtype_util.common_dtype([total_count, logits, probs], tf.float32)
      self._total_count = tf.convert_to_tensor(
          total_count, name="total_count", dtype=dtype)
      if validate_args:
        # Attach the helper's runtime check (per its name: non-negative,
        # integer-valued) to `total_count`.
        self._total_count = (
            distribution_util.embed_check_nonnegative_integer_form(
                self._total_count))
      # Both `_logits` and `_probs` are populated from whichever one was
      # supplied. `multidimensional=True` presumably treats the last axis as
      # the `K` classes (TODO confirm against the helper).
      self._logits, self._probs = distribution_util.get_logits_and_probs(
          logits=logits,
          probs=probs,
          multidimensional=True,
          validate_args=validate_args,
          name=name,
          dtype=dtype)
      # Cache `total_count * probs` (presumably the distribution mean). The
      # `[..., tf.newaxis]` adds a trailing axis so `total_count` broadcasts
      # across the class dimension of `probs`.
      self._mean_val = self._total_count[..., tf.newaxis] * self._probs
    super(Multinomial, self).__init__(
        dtype=dtype,
        reparameterization_type=reparameterization.NOT_REPARAMETERIZED,
        validate_args=validate_args,
        allow_nan_stats=allow_nan_stats,
        parameters=parameters,
        graph_parents=[self._total_count, self._logits, self._probs],
        name=name)
Exemple #2
0
 def _log_prob(self, x):
   """Log probability mass `x * log(1 - p) + log(p)` evaluated at `x`.

   When `validate_args` is False, `x` is floored first, matching `_cdf`.
   Points with `x < 0` lie outside the support {0, 1, 2, ...} and get a
   log-probability of `-inf`. Without that mask, the closed form would return
   a finite value larger than `log(p)`.

   Args:
     x: Floating point tensor of counts.

   Returns:
     Tensor of log-probabilities, broadcast against `self.probs`.
   """
   if self.validate_args:
     x = distribution_util.embed_check_nonnegative_integer_form(x)
   else:
     # For consistency with cdf, we take the floor.
     x = tf.floor(x)
   x *= tf.ones_like(self.probs)
   probs = self.probs * tf.ones_like(x)
   # At x == 0 the `x * log1p(-p)` term must be exactly 0 even when p == 1
   # (where log1p(-p) is -inf). Masking p to 0 there avoids 0 * -inf = NaN.
   safe_domain = tf.where(tf.equal(x, 0.), tf.zeros_like(probs), probs)
   log_prob = x * tf.log1p(-safe_domain) + tf.log(probs)
   # Same support convention as `_cdf`, which returns 0 for x < 0.
   neg_inf = tf.ones_like(log_prob) * float("-inf")
   return tf.where(x < 0., neg_inf, log_prob)
Exemple #3
0
 def _cdf(self, x):
   """Cumulative distribution function evaluated at `x`.

   For `x >= 0` this is `1 - (1 - p)**(x + 1)`, computed as
   `-expm1((1 + x) * log1p(-p))`. For `x < 0` it is `0`.
   """
   if not self.validate_args:
     # Whether or not x is integer-form, the following is well-defined.
     # However, scipy takes the floor, so we do too.
     x = tf.floor(x)
   else:
     x = distribution_util.embed_check_nonnegative_integer_form(x)
   x = x * tf.ones_like(self.probs)
   log_tail = (1. + x) * tf.log1p(-self.probs)
   cdf_in_support = -tf.math.expm1(log_tail)
   return tf.where(x < 0., tf.zeros_like(x), cdf_in_support)
 def _maybe_assert_valid_sample(self, counts):
   """Check counts for proper shape, values, then return tensor version."""
   if not self.validate_args:
     return counts
   counts = distribution_util.embed_check_nonnegative_integer_form(counts)
   return control_flow_ops.with_dependencies([
       tf.assert_equal(
           self.total_count, tf.reduce_sum(counts, -1),
           message="counts must sum to `self.total_count`"),
   ], counts)
Exemple #5
0
 def _maybe_assert_valid_sample(self, counts):
     """Check counts for proper shape, values, then return tensor version."""
     if not self.validate_args:
         return counts
     counts = distribution_util.embed_check_nonnegative_integer_form(counts)
     return distribution_util.with_dependencies([
         tf.compat.v1.assert_equal(
             self.total_count,
             tf.reduce_sum(input_tensor=counts, axis=-1),
             message="counts must sum to `self.total_count`"),
     ], counts)
Exemple #6
0
 def _maybe_assert_valid_sample(self, counts):
     """Check counts for proper shape, values, then return tensor version."""
     if not self.validate_args:
         return counts
     counts = distribution_util.embed_check_nonnegative_integer_form(counts)
     msg = ('Sampled counts must be itemwise less than '
            'or equal to `total_count` parameter.')
     return distribution_util.with_dependencies([
         assert_util.assert_less_equal(
             counts, self.total_count, message=msg),
     ], counts)
Exemple #7
0
 def _maybe_assert_valid_sample(self, counts):
   """Check counts for proper shape, values, then return tensor version."""
   if not self.validate_args:
     return counts
   counts = distribution_util.embed_check_nonnegative_integer_form(counts)
   return distribution_util.with_dependencies([
       assert_util.assert_less_equal(
           counts,
           self.total_count,
           message="counts are not less than or equal to n."),
   ], counts)
Exemple #8
0
 def _maybe_assert_valid_sample(self, counts):
   """Check counts for proper shape, values, then return tensor version."""
   if not self.validate_args:
     return counts
   counts = distribution_util.embed_check_nonnegative_integer_form(counts)
   return control_flow_ops.with_dependencies([
       tf.assert_less_equal(
           counts,
           self.total_count,
           message="counts are not less than or equal to n."),
   ], counts)
 def _maybe_assert_valid_sample(self, counts):
     """Check counts for proper shape, values, then return tensor version."""
     if not self.validate_args:
         return counts
     counts = distribution_util.embed_check_nonnegative_integer_form(counts)
     return control_flow_ops.with_dependencies([
         tf.assert_equal(
             self.total_count,
             tf.reduce_sum(counts, -1),
             message="counts last-dimension must sum to `self.total_count`"
         ),
     ], counts)
Exemple #10
0
    def __init__(self,
                 total_count,
                 concentration,
                 validate_args=False,
                 allow_nan_stats=True,
                 name="DirichletMultinomial"):
        """Initialize a batch of DirichletMultinomial distributions.

        Args:
          total_count: Non-negative floating point tensor, whose dtype is the
            same as `concentration`. The shape is broadcastable to
            `[N1,..., Nm]` with `m >= 0`. Defines this as a batch of
            `N1 x ... x Nm` different Dirichlet multinomial distributions. Its
            components should be equal to integer values.
          concentration: Positive floating point tensor, whose dtype is the
            same as `total_count`, with shape broadcastable to
            `[N1,..., Nm, K]` with `m >= 0`. Defines this as a batch of
            `N1 x ... x Nm` different `K` class Dirichlet multinomial
            distributions.
          validate_args: Python `bool`, default `False`. When `True`
            distribution parameters are checked for validity despite possibly
            degrading runtime performance. When `False` invalid inputs may
            silently render incorrect outputs.
          allow_nan_stats: Python `bool`, default `True`. When `True`,
            statistics (e.g., mean, mode, variance) use the value "`NaN`" to
            indicate the result is undefined. When `False`, an exception is
            raised if one or more of the statistic's batch members are
            undefined.
          name: Python `str` name prefixed to Ops created by this class.
        """
        # Broadcasting works because:
        # * The broadcasting convention is to prepend dimensions of size [1], and
        #   we use the last dimension for the distribution, whereas
        #   the batch dimensions are the leading dimensions, which forces the
        #   distribution dimension to be defined explicitly (i.e. it cannot be
        #   created automatically by prepending). This forces enough explicitness.
        # * All calls involving `counts` eventually require a broadcast between
        #   `counts` and concentration.
        # * We broadcast explicitly to include the effect of `counts` on
        #   `concentration` for calls that do not involve `counts`.
        #
        # Snapshot the constructor arguments; this must precede any other
        # local binding or the extra locals would be captured too.
        parameters = dict(locals())
        # `name` is rebound to the opened scope string and forwarded below.
        with tf.compat.v2.name_scope(name) as name:
            dtype = dtype_util.common_dtype([total_count, concentration],
                                            tf.float32)
            self._total_count = tf.convert_to_tensor(value=total_count,
                                                     name="total_count",
                                                     dtype=dtype)
            if validate_args:
                # Attach the helper's runtime check (per its name:
                # non-negative, integer-valued) to `total_count`.
                self._total_count = (
                    distribution_util.embed_check_nonnegative_integer_form(
                        self._total_count))
            # NOTE(review): `_maybe_assert_valid_concentration` is defined
            # elsewhere; presumably it adds concentration checks when
            # `validate_args` is True — confirm.
            self._concentration = self._maybe_assert_valid_concentration(
                tf.convert_to_tensor(value=concentration,
                                     name="concentration",
                                     dtype=dtype), validate_args)
            # Sum of the concentration over its last (class) axis.
            self._total_concentration = tf.reduce_sum(
                input_tensor=self._concentration, axis=-1)
            # Broadcast `concentration` against `total_count`'s shape; the
            # added trailing axis lines up with the class axis.
            self._broadcasted_concentration = tf.ones_like(
                self._total_count[..., tf.newaxis]) * self._concentration
        super(DirichletMultinomial, self).__init__(
            dtype=dtype,
            validate_args=validate_args,
            allow_nan_stats=allow_nan_stats,
            reparameterization_type=reparameterization.NOT_REPARAMETERIZED,
            parameters=parameters,
            graph_parents=[self._total_count, self._concentration],
            name=name)
Exemple #11
0
    def __init__(self,
                 total_count,
                 logits=None,
                 probs=None,
                 validate_args=False,
                 allow_nan_stats=True,
                 name="Multinomial"):
        """Initialize a batch of Multinomial distributions.

        Args:
          total_count: Non-negative floating point tensor with shape
            broadcastable to `[N1,..., Nm]` with `m >= 0`. Defines this as a
            batch of `N1 x ... x Nm` different Multinomial distributions. Its
            components should be equal to integer values.
          logits: Floating point tensor representing unnormalized
            log-probabilities of a positive event with shape broadcastable to
            `[N1,..., Nm, K]` with `m >= 0`, and the same dtype as
            `total_count`. Defines this as a batch of `N1 x ... x Nm`
            different `K` class Multinomial distributions. Only one of
            `logits` or `probs` should be passed in.
          probs: Positive floating point tensor with shape broadcastable to
            `[N1,..., Nm, K]` with `m >= 0` and the same dtype as
            `total_count`. Defines this as a batch of `N1 x ... x Nm`
            different `K` class Multinomial distributions. `probs`'s
            components along the last axis should sum to `1`. Only one of
            `logits` or `probs` should be passed in.
          validate_args: Python `bool`, default `False`. When `True`
            distribution parameters are checked for validity despite possibly
            degrading runtime performance. When `False` invalid inputs may
            silently render incorrect outputs.
          allow_nan_stats: Python `bool`, default `True`. When `True`,
            statistics (e.g., mean, mode, variance) use the value "`NaN`" to
            indicate the result is undefined. When `False`, an exception is
            raised if one or more of the statistic's batch members are
            undefined.
          name: Python `str` name prefixed to Ops created by this class.
        """
        # Snapshot the constructor arguments; this must remain the first
        # statement or any earlier local would be captured as well.
        parameters = dict(locals())
        # `name` is rebound to the opened scope string and forwarded below.
        with tf.compat.v2.name_scope(name) as name:
            # One dtype shared by the supplied parameters; `tf.float32` is
            # the fallback passed to `common_dtype`.
            dtype = dtype_util.common_dtype([total_count, logits, probs],
                                            tf.float32)
            self._total_count = tf.convert_to_tensor(value=total_count,
                                                     name="total_count",
                                                     dtype=dtype)
            if validate_args:
                # Attach the helper's runtime check (per its name:
                # non-negative, integer-valued) to `total_count`.
                self._total_count = (
                    distribution_util.embed_check_nonnegative_integer_form(
                        self._total_count))
            # Both `_logits` and `_probs` are populated from whichever one
            # was supplied. `multidimensional=True` presumably treats the last
            # axis as the `K` classes (TODO confirm against the helper).
            self._logits, self._probs = distribution_util.get_logits_and_probs(
                logits=logits,
                probs=probs,
                multidimensional=True,
                validate_args=validate_args,
                name=name,
                dtype=dtype)
            # Cache `total_count * probs` (presumably the mean); the trailing
            # `tf.newaxis` lets `total_count` broadcast over the class axis.
            self._mean_val = self._total_count[..., tf.newaxis] * self._probs
        super(Multinomial, self).__init__(
            dtype=dtype,
            reparameterization_type=reparameterization.NOT_REPARAMETERIZED,
            validate_args=validate_args,
            allow_nan_stats=allow_nan_stats,
            parameters=parameters,
            graph_parents=[self._total_count, self._logits, self._probs],
            name=name)
 def _log_normalization(self, x):
   """Return `-lgamma(total_count + x) + lgamma(1 + x) + lgamma(total_count)`.

   When `validate_args` is set, `x` is first passed through
   `embed_check_nonnegative_integer_form`.
   """
   if self.validate_args:
     x = distribution_util.embed_check_nonnegative_integer_form(x)
   total = self.total_count
   # Accumulate left to right to keep the original summation order.
   log_norm = -tf.lgamma(total + x)
   log_norm = log_norm + tf.lgamma(1. + x)
   log_norm = log_norm + tf.lgamma(total)
   return log_norm
 def _log_unnormalized_prob(self, x):
   """Return `total_count * log_sigmoid(-logits) + x * log_sigmoid(logits)`.

   When `validate_args` is set, `x` is first passed through
   `embed_check_nonnegative_integer_form`.
   """
   if self.validate_args:
     x = distribution_util.embed_check_nonnegative_integer_form(x)
   count_term = self.total_count * tf.log_sigmoid(-self.logits)
   x_term = x * tf.log_sigmoid(self.logits)
   return count_term + x_term
Exemple #14
0
 def _log_unnormalized_prob(self, x):
   """Return `total_count * log_sigmoid(-logits) + x * log_sigmoid(logits)`.

   When `validate_args` is set, `x` is first passed through
   `embed_check_nonnegative_integer_form`.
   """
   if self.validate_args:
     x = distribution_util.embed_check_nonnegative_integer_form(x)
   count_term = self.total_count * tf.math.log_sigmoid(-self.logits)
   x_term = x * tf.math.log_sigmoid(self.logits)
   return count_term + x_term
Exemple #15
0
 def _log_normalization(self, x):
   """Return `-lgamma(total_count + x) + lgamma(1 + x) + lgamma(total_count)`.

   When `validate_args` is set, `x` is first passed through
   `embed_check_nonnegative_integer_form`.
   """
   if self.validate_args:
     x = distribution_util.embed_check_nonnegative_integer_form(x)
   total = self.total_count
   # Accumulate left to right to keep the original summation order.
   log_norm = -tf.math.lgamma(total + x)
   log_norm = log_norm + tf.math.lgamma(1. + x)
   log_norm = log_norm + tf.math.lgamma(total)
   return log_norm
Exemple #16
0
 def _cdf(self, x):
   """CDF via the regularized incomplete beta function.

   Returns `betainc(total_count, 1 + x, sigmoid(-logits))`. When
   `validate_args` is set, `x` is first passed through
   `embed_check_nonnegative_integer_form`.
   """
   if self.validate_args:
     x = distribution_util.embed_check_nonnegative_integer_form(x)
   upper = tf.sigmoid(-self.logits)
   return tf.math.betainc(self.total_count, 1. + x, upper)
Exemple #17
0
 def _cdf(self, x):
   if self.validate_args:
     x = distribution_util.embed_check_nonnegative_integer_form(x)
   raise NotImplementedError()
 def _maybe_assert_valid_sample(self, x):
   """Check counts for proper shape and values, then return tensor version."""
   if not self.validate_args:
     return x
   return distribution_util.embed_check_nonnegative_integer_form(x)
 def _cdf(self, x):
   """CDF via the regularized incomplete beta function.

   Returns `betainc(total_count, 1 + x, sigmoid(-logits))`. When
   `validate_args` is set, `x` is first passed through
   `embed_check_nonnegative_integer_form`.
   """
   if self.validate_args:
     x = distribution_util.embed_check_nonnegative_integer_form(x)
   upper = tf.sigmoid(-self.logits)
   return tf.betainc(self.total_count, 1. + x, upper)
Exemple #20
0
 def _log_normalization(self, x):
   """Return `lgamma(disp) + lgamma(x + 1) - lgamma(x + disp)`.

   When `validate_args` is set, `x` is first passed through
   `embed_check_nonnegative_integer_form`.
   """
   if self.validate_args:
     x = distribution_util.embed_check_nonnegative_integer_form(x)
   # Parenthesized continuation; evaluation order matches the original
   # ((a + b) - c).
   return (tf.math.lgamma(self.disp)
           + tf.math.lgamma(x + 1)
           - tf.math.lgamma(x + self.disp))