Пример #1
0
 def _cdf(self, positive_counts):
     """Evaluate the CDF `P[X <= positive_counts]`.

     The CDF of a discrete distribution is a step function on the whole real
     line, so the argument is floored regardless of `validate_args`, and
     arguments below zero yield a CDF of zero.

     Args:
       positive_counts: Tensor of points at which to evaluate the CDF.

     Returns:
       Tensor equal to `betainc(total_count, floor(x) + 1, sigmoid(-logits))`
       where `x >= 0`, and `0` where `x < 0`.
     """
     if self.validate_args:
         # `check_integer=False` because the CDF is defined on the whole real
         # line; non-negativity is still checked.
         positive_counts = distribution_util.embed_check_nonnegative_discrete(
             positive_counts, check_integer=False)
     # Floor unconditionally: without it, non-integer inputs produce a smooth
     # interpolation rather than the step-function CDF.
     positive_counts = math_ops.floor(positive_counts)
     # The regularized incomplete beta requires a positive second argument, so
     # evaluate it on a clamped copy and zero out negative inputs afterwards.
     safe_counts = math_ops.maximum(positive_counts, 0.)
     cdf = math_ops.betainc(self.total_count, safe_counts + 1.,
                            math_ops.sigmoid(-self.logits))
     return cdf * math_ops.cast(positive_counts >= 0., cdf.dtype)
Пример #2
0
 def _cdf(self, x):
     """Evaluate the CDF `P[X <= x]` as `igammac(floor(x + 1), rate)`.

     Args:
       x: Tensor of points at which to evaluate the CDF.

     Returns:
       Tensor of CDF values; `0` wherever `x < 0`.
     """
     if self.validate_args:
         # We set `check_integer=False` since the CDF is defined on whole real
         # line.
         x = distribution_util.embed_check_nonnegative_discrete(
             x, check_integer=False)
     # For x < 0 the shape argument floor(x + 1) is <= 0, where igammac is
     # undefined. Clamp it to >= 1 (a no-op for x >= 0) and then zero out the
     # entries that lie below the support.
     safe_a = math_ops.maximum(math_ops.floor(x + 1), 1.)
     cdf = math_ops.igammac(safe_a, self.rate)
     return cdf * math_ops.cast(x >= 0., cdf.dtype)
Пример #3
0
 def _log_normalization(self, positive_counts):
     """Return `-lgamma(r + k) + lgamma(k + 1) + lgamma(r)`.

     Here `r` is `self.total_count` and `k` is `positive_counts`. When
     `validate_args` is set, `k` is first passed through
     `embed_check_nonnegative_discrete` with `check_integer=True`.
     """
     if self.validate_args:
         positive_counts = distribution_util.embed_check_nonnegative_discrete(
             positive_counts, check_integer=True)
     r = self.total_count
     # Accumulate left-to-right in the same order as the closed form above.
     log_norm = -math_ops.lgamma(r + positive_counts)
     log_norm += math_ops.lgamma(positive_counts + 1.)
     log_norm += math_ops.lgamma(r)
     return log_norm
Пример #4
0
 def _cdf(self, x):
   """Evaluate the CDF `P[X <= x]` as `igammac(floor(x + 1), rate)`.

   Args:
     x: Tensor of points at which to evaluate the CDF.

   Returns:
     Tensor of CDF values; `0` wherever `x < 0`.
   """
   if self.validate_args:
     # We set `check_integer=False` since the CDF is defined on whole real
     # line.
     x = distribution_util.embed_check_nonnegative_discrete(
         x, check_integer=False)
   # For x < 0 the shape argument floor(x + 1) is <= 0, where igammac is
   # undefined. Clamp it to >= 1 (a no-op for x >= 0) and then zero out the
   # entries that lie below the support.
   safe_a = math_ops.maximum(math_ops.floor(x + 1), 1.)
   cdf = math_ops.igammac(safe_a, self.rate)
   return cdf * math_ops.cast(x >= 0., cdf.dtype)
Пример #5
0
 def _log_normalization(self, positive_counts):
   """Return `-lgamma(r + k) + lgamma(k + 1) + lgamma(r)`.

   Here `r` is `self.total_count` and `k` is `positive_counts`. When
   `validate_args` is set, `k` is first passed through
   `embed_check_nonnegative_discrete` with `check_integer=True`.
   """
   if self.validate_args:
     positive_counts = distribution_util.embed_check_nonnegative_discrete(
         positive_counts, check_integer=True)
   r = self.total_count
   # Accumulate left-to-right in the same order as the closed form above.
   log_norm = -math_ops.lgamma(r + positive_counts)
   log_norm += math_ops.lgamma(positive_counts + 1.)
   log_norm += math_ops.lgamma(r)
   return log_norm
Пример #6
0
 def _cdf(self, positive_counts):
   """Evaluate the CDF `P[X <= positive_counts]`.

   The CDF of a discrete distribution is a step function on the whole real
   line, so the argument is floored regardless of `validate_args`, and
   arguments below zero yield a CDF of zero.

   Args:
     positive_counts: Tensor of points at which to evaluate the CDF.

   Returns:
     Tensor equal to `betainc(total_count, floor(x) + 1, sigmoid(-logits))`
     where `x >= 0`, and `0` where `x < 0`.
   """
   if self.validate_args:
     # `check_integer=False` because the CDF is defined on the whole real
     # line; non-negativity is still checked.
     positive_counts = distribution_util.embed_check_nonnegative_discrete(
         positive_counts, check_integer=False)
   # Floor unconditionally: without it, non-integer inputs produce a smooth
   # interpolation rather than the step-function CDF.
   positive_counts = math_ops.floor(positive_counts)
   # The regularized incomplete beta requires a positive second argument, so
   # evaluate it on a clamped copy and zero out negative inputs afterwards.
   safe_counts = math_ops.maximum(positive_counts, 0.)
   cdf = math_ops.betainc(
       self.total_count, safe_counts + 1.,
       math_ops.sigmoid(-self.logits))
   return cdf * math_ops.cast(positive_counts >= 0., cdf.dtype)
Пример #7
0
 def _maybe_assert_valid_sample(self, event, check_integer=True):
   """Return `event`, gated on validity checks when `validate_args` is set.

   With validation enabled, `event` is passed through
   `embed_check_nonnegative_discrete` and made to depend on an assertion that
   every entry is at most 1. With validation disabled, `event` is returned
   unchanged.
   """
   if self.validate_args:
     checked = distribution_util.embed_check_nonnegative_discrete(
         event, check_integer=check_integer)
     at_most_one = check_ops.assert_less_equal(
         checked, array_ops.ones_like(checked),
         message="event is not less than or equal to 1.")
     event = control_flow_ops.with_dependencies([at_most_one], checked)
   return event
Пример #8
0
    def _log_prob(self, counts):
        """Return `counts * log1p(-p) + log(p)` with `p = self.probs`.

        Wherever `counts == 0`, the `log1p` term uses 0 in place of `p`, so a
        zero count contributes nothing to that term even when `p == 1`.
        """
        if self.validate_args:
            counts = distribution_util.embed_check_nonnegative_discrete(
                counts, check_integer=True)
        # Broadcast counts and probs against each other so `where` sees
        # matching shapes.
        counts = counts * array_ops.ones_like(self.probs)
        p = self.probs * array_ops.ones_like(counts)
        is_zero = math_ops.equal(counts, 0.)
        p_or_zero = array_ops.where(is_zero, array_ops.zeros_like(p), p)
        count_term = counts * math_ops.log1p(-p_or_zero)
        return count_term + math_ops.log(p)
Пример #9
0
 def _maybe_assert_valid_sample(self, event, check_integer=True):
     """Return `event`, gated on validity checks when `validate_args` is set.

     With validation enabled, `event` is passed through
     `embed_check_nonnegative_discrete` and made to depend on an assertion
     that every entry is at most 1. With validation disabled, `event` is
     returned unchanged.
     """
     if self.validate_args:
         checked = distribution_util.embed_check_nonnegative_discrete(
             event, check_integer=check_integer)
         upper_bound = array_ops.ones_like(checked)
         at_most_one = check_ops.assert_less_equal(
             checked,
             upper_bound,
             message="event is not less than or equal to 1.")
         event = control_flow_ops.with_dependencies([at_most_one], checked)
     return event
Пример #10
0
 def _cdf(self, counts):
     """Evaluate the CDF `1 - (1 - probs) ** (floor(counts) + 1)`.

     The CDF is a step function defined on the whole real line, so `counts`
     is floored regardless of `validate_args`. Entries with `counts < 0` map
     to 0.

     Args:
       counts: Tensor of points at which to evaluate the CDF.

     Returns:
       Tensor of CDF values, broadcast against `self.probs`.
     """
     if self.validate_args:
         # We set `check_integer=False` since the CDF is defined on whole real
         # line.
         counts = distribution_util.embed_check_nonnegative_discrete(
             counts, check_integer=False)
     # Floor unconditionally: otherwise a non-integer x without validation
     # would return 1 - (1 - p)**(x + 1), which is not the step-function CDF.
     counts = math_ops.floor(counts)
     # Broadcast against probs so `where` sees matching shapes.
     counts *= array_ops.ones_like(self.probs)
     return array_ops.where(
         counts < 0., array_ops.zeros_like(counts), -math_ops.expm1(
             (counts + 1) * math_ops.log1p(-self.probs)))
Пример #11
0
  def _maybe_assert_valid_sample(self, counts):
    """Return `counts`, gated on validity checks when `validate_args` is set.

    Validation passes `counts` through `embed_check_nonnegative_discrete`
    (with `check_integer=True`) and adds a dependency on an assertion that the
    sum over the last axis equals `self.total_count`. Without validation,
    `counts` is returned unchanged.
    """
    if self.validate_args:
      counts = distribution_util.embed_check_nonnegative_discrete(
          counts, check_integer=True)
      sums_to_total = check_ops.assert_equal(
          self.total_count, math_ops.reduce_sum(counts, -1),
          message="counts must sum to `self.total_count`")
      counts = control_flow_ops.with_dependencies([sums_to_total], counts)
    return counts
Пример #12
0
  def _maybe_assert_valid_sample(self, counts):
    """Return `counts`, gated on validity checks when `validate_args` is set.

    Validation passes `counts` through `embed_check_nonnegative_discrete`
    (with `check_integer=True`) and adds a dependency on an assertion that the
    sum over the last axis equals `self.total_count`. Without validation,
    `counts` is returned unchanged.
    """
    if self.validate_args:
      checked = distribution_util.embed_check_nonnegative_discrete(
          counts, check_integer=True)
      row_totals = math_ops.reduce_sum(checked, -1)
      sums_to_total = check_ops.assert_equal(
          self.total_count, row_totals,
          message="counts must sum to `self.total_count`")
      counts = control_flow_ops.with_dependencies([sums_to_total], checked)
    return counts
Пример #13
0
  def _maybe_assert_valid_sample(self, counts, check_integer=True):
    """Return `counts`, gated on validity checks when `validate_args` is set.

    Validation passes `counts` through `embed_check_nonnegative_discrete` and
    adds a dependency on an assertion that every entry is at most
    `self.total_count`. Without validation, `counts` is returned unchanged.
    """
    if self.validate_args:
      checked = distribution_util.embed_check_nonnegative_discrete(
          counts, check_integer=check_integer)
      within_total = check_ops.assert_less_equal(
          checked, self.total_count,
          message="counts are not less than or equal to n.")
      counts = control_flow_ops.with_dependencies([within_total], checked)
    return counts
Пример #14
0
  def _log_prob(self, counts):
    """Return `counts * log1p(-p) + log(p)` with `p = self.probs`.

    Wherever `counts == 0`, the `log1p` term uses 0 in place of `p`, so a
    zero count contributes nothing to that term even when `p == 1`.
    """
    if self.validate_args:
      counts = distribution_util.embed_check_nonnegative_discrete(
          counts, check_integer=True)
    # Broadcast counts and probs against each other so `where` sees matching
    # shapes.
    counts = counts * array_ops.ones_like(self.probs)
    p = self.probs * array_ops.ones_like(counts)
    is_zero = math_ops.equal(counts, 0.)
    p_or_zero = array_ops.where(is_zero, array_ops.zeros_like(p), p)
    count_term = counts * math_ops.log1p(-p_or_zero)
    return count_term + math_ops.log(p)
Пример #15
0
    def _maybe_assert_valid_sample(self, counts, check_integer=True):
        """Return `counts`, gated on validity checks when `validate_args` is set.

        Validation passes `counts` through `embed_check_nonnegative_discrete`
        and adds a dependency on an assertion that every entry is at most
        `self.total_count`. Without validation, `counts` is returned unchanged.
        """
        if self.validate_args:
            checked = distribution_util.embed_check_nonnegative_discrete(
                counts, check_integer=check_integer)
            within_total = check_ops.assert_less_equal(
                checked,
                self.total_count,
                message="counts are not less than or equal to n.")
            counts = control_flow_ops.with_dependencies(
                [within_total], checked)
        return counts
Пример #16
0
 def _cdf(self, counts):
   """Evaluate the CDF `1 - (1 - probs) ** (floor(counts) + 1)`.

   The CDF is a step function defined on the whole real line, so `counts` is
   floored regardless of `validate_args`. Entries with `counts < 0` map to 0.

   Args:
     counts: Tensor of points at which to evaluate the CDF.

   Returns:
     Tensor of CDF values, broadcast against `self.probs`.
   """
   if self.validate_args:
     # We set `check_integer=False` since the CDF is defined on whole real
     # line.
     counts = distribution_util.embed_check_nonnegative_discrete(
         counts, check_integer=False)
   # Floor unconditionally: otherwise a non-integer x without validation would
   # return 1 - (1 - p)**(x + 1), which is not the step-function CDF.
   counts = math_ops.floor(counts)
   # Broadcast against probs so `where` sees matching shapes.
   counts *= array_ops.ones_like(self.probs)
   return array_ops.where(
       counts < 0.,
       array_ops.zeros_like(counts),
       -math_ops.expm1(
           (counts + 1) * math_ops.log1p(-self.probs)))
Пример #17
0
 def _log_unnormalized_prob(self, positive_counts):
   """Return `total_count * log1p(-probs) + positive_counts * log(probs)`.

   When `validate_args` is set, `positive_counts` is first passed through
   `embed_check_nonnegative_discrete` with `check_integer=True`.
   """
   if self.validate_args:
     positive_counts = distribution_util.embed_check_nonnegative_discrete(
         positive_counts, check_integer=True)
   total_term = self.total_count * math_ops.log1p(-self.probs)
   count_term = positive_counts * math_ops.log(self.probs)
   return total_term + count_term
Пример #18
0
 def _log_unnormalized_prob(self, positive_counts):
     """Return `total_count * log1p(-probs) + positive_counts * log(probs)`.

     When `validate_args` is set, `positive_counts` is first passed through
     `embed_check_nonnegative_discrete` with `check_integer=True`.
     """
     if self.validate_args:
         positive_counts = distribution_util.embed_check_nonnegative_discrete(
             positive_counts, check_integer=True)
     total_term = self.total_count * math_ops.log1p(-self.probs)
     count_term = positive_counts * math_ops.log(self.probs)
     return total_term + count_term
Пример #19
0
 def _log_unnormalized_prob(self, x):
   """Return `x * log(rate) - lgamma(x + 1)`.

   When `validate_args` is set, `x` is first passed through
   `embed_check_nonnegative_discrete` with `check_integer=True`.
   """
   if self.validate_args:
     x = distribution_util.embed_check_nonnegative_discrete(
         x, check_integer=True)
   rate_term = x * math_ops.log(self.rate)
   log_factorial = math_ops.lgamma(x + 1)
   return rate_term - log_factorial
Пример #20
0
 def _log_unnormalized_prob(self, x):
     """Return `x * log(rate) - lgamma(x + 1)`.

     When `validate_args` is set, `x` is first passed through
     `embed_check_nonnegative_discrete` with `check_integer=True`.
     """
     if self.validate_args:
         x = distribution_util.embed_check_nonnegative_discrete(
             x, check_integer=True)
     rate_term = x * math_ops.log(self.rate)
     log_factorial = math_ops.lgamma(x + 1)
     return rate_term - log_factorial
Пример #21
0
 def _log_prob(self, counts):
   """Return `counts * log1p(-probs) + log(probs)`.

   Wherever `counts == 0`, the `log1p` term is evaluated with 0 in place of
   `probs`. Without this, `probs == 1` and `counts == 0` would produce
   `0 * log1p(-1) = 0 * -inf = NaN` instead of `log(1) = 0`.

   Args:
     counts: Tensor of counts; when `validate_args` is set, it is first passed
       through `embed_check_nonnegative_discrete` with `check_integer=True`.

   Returns:
     Tensor of log-probabilities, broadcast against `self.probs`.
   """
   if self.validate_args:
     counts = distribution_util.embed_check_nonnegative_discrete(
         counts, check_integer=True)
   # Broadcast counts and probs against each other so `where` sees matching
   # shapes.
   counts *= array_ops.ones_like(self.probs)
   probs = self.probs * array_ops.ones_like(counts)
   safe_domain = array_ops.where(
       math_ops.equal(counts, 0.),
       array_ops.zeros_like(probs),
       probs)
   return counts * math_ops.log1p(-safe_domain) + math_ops.log(probs)