def testAssertIntegerForm(self):
  # Exactly integer-valued input: the assertion must pass.
  integral = [1., 5, 10, 15, 20]
  # Inputs whose first component is not an integer; each must trip the
  # assertion:
  #   1.1    -- plainly fractional.
  #   1.0001 -- differs from 1 by more than float32 eps (about 1e-7).
  #   1e-8   -- tiny (below float32 eps) but non-zero, so still not an integer.
  fractional_cases = [
      [1.1, 5, 10, 15, 20],
      [1.0001, 5, 10, 15, 20],
      [1e-8, 5, 10, 15, 20],
  ]
  with self.test_session():
    with tf.control_dependencies(
        [distribution_util.assert_integer_form(integral)]):
      tf.identity(integral).eval()

    for values in fractional_cases:
      with self.assertRaisesOpError("x has non-integer components"):
        with tf.control_dependencies(
            [distribution_util.assert_integer_form(values)]):
          tf.identity(values).eval()
    def testAssertIntegerForm(self):
        # Every input is fed at run time through a float32 placeholder.
        # The first sample is exactly integer-valued and must pass the check.
        # Each of the others has a non-integral first component and must fail:
        # 1.1 is plainly fractional, 1.0001 differs from 1 by more than
        # float32 eps (about 1e-7), and 1e-8 is tiny but still non-zero.
        samples = [
            [1., 5, 10, 15, 20],
            [1.1, 5, 10, 15, 20],
            [1.0001, 5, 10, 15, 20],
            [1e-8, 5, 10, 15, 20],
        ]
        placeholders = [
            array_ops.placeholder(dtypes.float32) for _ in samples]
        feed_dict = dict(zip(placeholders, samples))
        integral = placeholders[0]
        non_integral = placeholders[1:]

        with self.test_session():
            with ops.control_dependencies(
                    [distribution_util.assert_integer_form(integral)]):
                array_ops.identity(integral).eval(feed_dict=feed_dict)

            for tensor in non_integral:
                with self.assertRaisesOpError(
                        "x has non-integer components"):
                    with ops.control_dependencies(
                            [distribution_util.assert_integer_form(tensor)]):
                        array_ops.identity(tensor).eval(feed_dict=feed_dict)
    def testAssertIntegerForm(self):
        # An exactly integer-valued input, which must pass the check.
        accepted = [1., 5, 10, 15, 20]
        # Inputs whose first component is not an integer; each must fail.
        # 1.1 is plainly fractional, 1.0001 differs from 1 by more than
        # float32 eps (about 1e-7), and 1e-8 is a tiny non-zero value.
        rejected = (
            [1.1, 5, 10, 15, 20],
            [1.0001, 5, 10, 15, 20],
            [1e-8, 5, 10, 15, 20],
        )
        with self.test_session():
            integer_check = distribution_util.assert_integer_form(accepted)
            with ops.control_dependencies([integer_check]):
                array_ops.identity(accepted).eval()

            for bad_values in rejected:
                with self.assertRaisesOpError("x has non-integer components"):
                    bad_check = distribution_util.assert_integer_form(
                        bad_values)
                    with ops.control_dependencies([bad_check]):
                        array_ops.identity(bad_values).eval()
  def testAssertIntegerForm(self):
    # All inputs are fed through float32 placeholders at run time.
    ok_values = [1., 5, 10, 15, 20]  # integer-valued: the check must pass
    bad_values = [
        [1.1, 5, 10, 15, 20],     # plainly fractional
        [1.0001, 5, 10, 15, 20],  # off by more than float32 eps (~1e-7)
        [1e-8, 5, 10, 15, 20],    # tiny but non-zero: still not an integer
    ]
    ok = array_ops.placeholder(dtypes.float32)
    bad = [array_ops.placeholder(dtypes.float32) for _ in bad_values]
    feed_dict = {ok: ok_values}
    feed_dict.update(zip(bad, bad_values))

    with self.test_session():
      with ops.control_dependencies(
          [distribution_util.assert_integer_form(ok)]):
        array_ops.identity(ok).eval(feed_dict=feed_dict)

      for placeholder in bad:
        with self.assertRaisesOpError("x has non-integer components"):
          with ops.control_dependencies(
              [distribution_util.assert_integer_form(placeholder)]):
            array_ops.identity(placeholder).eval(feed_dict=feed_dict)
 def _check_n(self, n):
   """Convert `n` to a tensor, adding validation when `self.validate_args`.

   Args:
     n: Value convertible to a tensor; converted under the name "n".

   Returns:
     The converted tensor. When `self.validate_args` is truthy, it is wrapped
     by `control_flow_ops.with_dependencies` so that it depends on
     `check_ops.assert_non_negative` and
     `distribution_util.assert_integer_form` applied to `n`.
   """
   n = ops.convert_to_tensor(n, name="n")
   if self.validate_args:
     checks = [check_ops.assert_non_negative(n),
               distribution_util.assert_integer_form(n)]
     n = control_flow_ops.with_dependencies(checks, n)
   return n
 def _assert_valid_n(self, n, validate_args):
   """Return `n` as a tensor, optionally guarded by validity assertions.

   Args:
     n: Value convertible to a tensor; converted under the name "n".
     validate_args: If truthy, the result depends on a non-negativity
       assertion and an integer-form assertion on `n`.

   Returns:
     The converted tensor, wrapped by `control_flow_ops.with_dependencies`
     when `validate_args` is truthy.
   """
   n_tensor = ops.convert_to_tensor(n, name="n")
   if validate_args:
     guards = [check_ops.assert_non_negative(n_tensor),
               distribution_util.assert_integer_form(n_tensor)]
     return control_flow_ops.with_dependencies(guards, n_tensor)
   return n_tensor
 def _check_integer(self, value):
   """Convert `value` to a tensor and, if validating, assert it is integral.

   All ops are created under the "check_integer" name scope.

   Args:
     value: Value convertible to a tensor; converted under the name "value".

   Returns:
     The converted tensor. When `self.validate_args` is truthy, it carries a
     control dependency on `distribution_util.assert_integer_form`.
   """
   with ops.name_scope("check_integer", values=[value]):
     value = ops.convert_to_tensor(value, name="value")
     if self.validate_args:
       is_integral = distribution_util.assert_integer_form(
           value, message="value has non-integer components.")
       value = control_flow_ops.with_dependencies([is_integral], value)
     return value
Example #8
0
  def __init__(self,
               n,
               logits=None,
               p=None,
               validate_args=False,
               allow_nan_stats=True,
               name="Binomial"):
    """Initialize a batch of Binomial distributions.

    Args:
      n: Non-negative floating point tensor, broadcastable to `[N1,..., Nm]`
        with `m >= 0`, of the same dtype as `p` or `logits`. Its components
        should be integer-valued. Together with `p`/`logits`, it defines a
        batch of `N1 x ... x Nm` different Binomial distributions.
      logits: Floating point tensor of log-odds of success, broadcastable to
        `[N1,..., Nm]` with `m >= 0`, of the same dtype as `n`. Each entry
        parameterizes an independent Binomial distribution.
      p: Floating point tensor of success probabilities with values in
        `[0, 1]`, broadcastable to `[N1,..., Nm]` with `m >= 0`. Each entry
        parameterizes an independent Binomial distribution.
      validate_args: `Boolean`, default `False`. Whether to assert valid values
        for parameters `n`, `p`, and `x` in `prob` and `log_prob`. If `False`
        and inputs are invalid, correct behavior is not guaranteed.
      allow_nan_stats: `Boolean`, default `True`. If `False`, raise an
        exception if a statistic (e.g. mean/mode/etc...) is undefined for any
        batch member. If `True`, batch members with valid parameters leading to
        undefined statistics will return NaN for this statistic.
      name: The name to prefix Ops created by this distribution class.

    Examples:

    ```python
    # Define 1-batch of a binomial distribution.
    dist = Binomial(n=2., p=.9)

    # Define a 2-batch.
    dist = Binomial(n=[4., 5], p=[.1, .3])
    ```

    """
    # `logits`/`p` are resolved before the distribution's name scope opens.
    self._logits, self._p = distribution_util.get_logits_and_prob(
        name=name, logits=logits, p=p, validate_args=validate_args)
    with ops.name_scope(name, values=[n]) as ns:
      n_checks = []
      if validate_args:
        n_checks.append(check_ops.assert_non_negative(
            n, message="n has negative components."))
        n_checks.append(distribution_util.assert_integer_form(
            n, message="n has non-integer components."))
      # Ops created inside this context (including `self._n`) depend on the
      # checks above. The list is empty when not validating.
      with ops.control_dependencies(n_checks):
        self._n = array_ops.identity(n, name="n")
        super(Binomial, self).__init__(
            dtype=self._p.dtype,
            parameters={"n": self._n, "p": self._p, "logits": self._logits},
            is_continuous=False,
            is_reparameterized=False,
            validate_args=validate_args,
            allow_nan_stats=allow_nan_stats,
            name=ns)
Example #9
0
 def _assert_valid_sample(self, x, check_integer=True):
   if not self.validate_args: return x
   with ops.name_scope("check_x", values=[x]):
     dependencies = [check_ops.assert_non_negative(x)]
     if check_integer:
       dependencies += [distribution_util.assert_integer_form(
           x, message="x has non-integer components.")]
     return control_flow_ops.with_dependencies(dependencies, x)
Example #10
0
 def _assert_valid_sample(self, x, check_integer=True):
   if not self.validate_args:
     return x
   dependencies = [check_ops.assert_non_negative(x)]
   if check_integer:
     dependencies += [distribution_util.assert_integer_form(
         x, message="x has non-integer components.")]
   return control_flow_ops.with_dependencies(dependencies, x)
 def _check_integer(self, value):
   """Tensor-ize `value`, asserting integer form when `self.validate_args`.

   Args:
     value: Value convertible to a tensor; converted under the name "value"
       inside the "check_integer" name scope.

   Returns:
     The converted tensor, with a control dependency on
     `distribution_util.assert_integer_form` when validating.
   """
   with ops.name_scope("check_integer", values=[value]):
     as_tensor = ops.convert_to_tensor(value, name="value")
     if not self.validate_args:
       return as_tensor
     integer_guard = distribution_util.assert_integer_form(
         as_tensor, message="value has non-integer components.")
     return control_flow_ops.with_dependencies([integer_guard], as_tensor)
Example #12
0
 def _check_x(self, x, check_integer=True):
   """Convert `x` to a tensor and optionally attach sample assertions.

   Args:
     x: Value convertible to a tensor; converted under the name "x" inside
       the "check_x" name scope.
     check_integer: If truthy (and validating), also assert integer form.

   Returns:
     The converted tensor. When `self.validate_args` is truthy, it depends on
     a non-negativity assertion and, optionally, an integer-form assertion.
   """
   with ops.name_scope('check_x', values=[x]):
     x_tensor = ops.convert_to_tensor(x, name="x")
     if self.validate_args:
       guards = [check_ops.assert_non_negative(x_tensor)]
       if check_integer:
         guards.append(distribution_util.assert_integer_form(
             x_tensor, message="x has non-integer components."))
       x_tensor = control_flow_ops.with_dependencies(guards, x_tensor)
     return x_tensor
 def _maybe_assert_valid_total_count(self, total_count, validate_args):
     if not validate_args:
         return total_count
     return control_flow_ops.with_dependencies([
         check_ops.assert_non_negative(
             total_count, message="total_count must be non-negative."),
         distribution_util.assert_integer_form(
             total_count,
             message="total_count cannot contain fractional values."),
     ], total_count)
Example #14
0
 def _check_counts(self, counts):
   """Convert `counts` to a tensor and, if validating, attach assertions.

   The assertions are created only when `self.validate_args` is truthy. They
   check that `counts` is non-negative, is less than or equal to `self._n`,
   and is integer-valued.

   Args:
     counts: Value convertible to a tensor (named "counts_before_deps").

   Returns:
     The converted tensor, with control dependencies on the assertions when
     validating.
   """
   counts = ops.convert_to_tensor(counts, name="counts_before_deps")
   if self.validate_args:
     non_negative = check_ops.assert_non_negative(
         counts, message="counts has negative components.")
     at_most_n = check_ops.assert_less_equal(
         counts, self._n, message="counts are not less than or equal to n.")
     integral = distribution_util.assert_integer_form(
         counts, message="counts have non-integer components.")
     counts = control_flow_ops.with_dependencies(
         [non_negative, at_most_n, integral], counts)
   return counts
Example #15
0
 def _check_counts(self, counts):
   """Return `counts` as a tensor, validated when `self.validate_args`.

   When validating, the result depends on assertions that `counts` is
   non-negative, does not exceed `self._n`, and has integer components.

   Args:
     counts: Value convertible to a tensor (named "counts_before_deps").

   Returns:
     The converted tensor, possibly wrapped with control dependencies.
   """
   counts = ops.convert_to_tensor(counts, name="counts_before_deps")
   if not self.validate_args:
     return counts
   guards = []
   guards.append(check_ops.assert_non_negative(
       counts, message="counts has negative components."))
   guards.append(check_ops.assert_less_equal(
       counts, self._n, message="counts are not less than or equal to n."))
   guards.append(distribution_util.assert_integer_form(
       counts, message="counts have non-integer components."))
   return control_flow_ops.with_dependencies(guards, counts)
 def _maybe_assert_valid_total_count(self, total_count, validate_args):
   if not validate_args:
     return total_count
   return control_flow_ops.with_dependencies([
       check_ops.assert_non_negative(
           total_count,
           message="total_count must be non-negative."),
       distribution_util.assert_integer_form(
           total_count,
           message="total_count cannot contain fractional values."),
   ], total_count)
Example #17
0
 def _assert_valid_sample(self, counts):
   """Check counts for proper shape, values, then return tensor version."""
   if not self.validate_args: return counts
   return control_flow_ops.with_dependencies([
       check_ops.assert_non_negative(
           counts, message="counts has negative components."),
       check_ops.assert_equal(
           self.n, math_ops.reduce_sum(counts, reduction_indices=[-1]),
           message="counts do not sum to n."),
       distribution_util.assert_integer_form(
           counts, message="counts have non-integer components.")
   ], counts)
Example #18
0
 def _check_x(self, x, check_integer=True):
   """Tensor-ize `x`, attaching sample assertions when validating.

   Args:
     x: Value convertible to a tensor; converted under the name "x" inside
       the "check_x" name scope.
     check_integer: If truthy (and validating), also assert integer form.

   Returns:
     The converted tensor, with control dependencies on a non-negativity
     assertion (and optionally an integer-form assertion) when
     `self.validate_args` is truthy.
   """
   with ops.name_scope('check_x', values=[x]):
     x = ops.convert_to_tensor(x, name="x")
     if not self.validate_args:
       return x
     checks = [check_ops.assert_non_negative(x)]
     checks.extend(
         [distribution_util.assert_integer_form(
             x, message="x has non-integer components.")]
         if check_integer else [])
     return control_flow_ops.with_dependencies(checks, x)
 def _assert_valid_counts(self, counts):
   """Convert `counts` to a tensor, validating it when requested.

   When `self.validate_args` is truthy, the result depends on assertions
   that `counts` is non-negative, that its sum over the last axis equals
   `self._n`, and that it is integer-valued.
   """
   counts = ops.convert_to_tensor(counts, name="counts")
   if self.validate_args:
     last_axis_sums = math_ops.reduce_sum(counts, reduction_indices=[-1])
     non_negative = check_ops.assert_non_negative(counts)
     sums_match = check_ops.assert_equal(
         self._n, last_axis_sums, message="counts do not sum to n")
     integral = distribution_util.assert_integer_form(counts)
     counts = control_flow_ops.with_dependencies(
         [non_negative, sums_match, integral], counts)
   return counts
 def _assert_valid_counts(self, counts):
   """Return `counts` as a tensor, guarded by assertions when validating.

   The guards check non-negativity, that the last-axis sum equals `self._n`,
   and integer form. They are added only if `self.validate_args` is truthy.
   """
   counts = ops.convert_to_tensor(counts, name="counts")
   if not self.validate_args:
     return counts
   summed = math_ops.reduce_sum(counts, reduction_indices=[-1])
   guards = [check_ops.assert_non_negative(counts)]
   guards.append(check_ops.assert_equal(
       self._n, summed, message="counts do not sum to n"))
   guards.append(distribution_util.assert_integer_form(counts))
   return control_flow_ops.with_dependencies(guards, counts)
Example #21
0
 def _maybe_assert_valid_sample(self, counts):
     """Check counts for proper shape, values, then return tensor version."""
     if not self.validate_args:
         return counts
     return control_flow_ops.with_dependencies([
         check_ops.assert_non_negative(
             counts, message="counts must be non-negative."),
         check_ops.assert_equal(
             self.total_count,
             math_ops.reduce_sum(counts, -1),
             message="counts must sum to `self.total_count`"),
         distribution_util.assert_integer_form(
             counts,
             message="counts cannot contain fractional components."),
     ], counts)
 def _maybe_assert_valid_sample(self, counts):
   """Check counts for proper shape, values, then return tensor version."""
   if not self.validate_args:
     return counts
   return control_flow_ops.with_dependencies([
       check_ops.assert_non_negative(
           counts,
           message="counts must be non-negative."),
       check_ops.assert_equal(
           self.total_count, math_ops.reduce_sum(counts, -1),
           message="counts last-dimension must sum to `self.total_count`"),
       distribution_util.assert_integer_form(
           counts,
           message="counts cannot contain fractional components."),
   ], counts)
Example #23
0
  def __init__(self,
               n,
               logits=None,
               p=None,
               validate_args=False,
               allow_nan_stats=True,
               name="Multinomial"):
    """Initialize a batch of Multinomial distributions.

    Args:
      n:  Non-negative floating point tensor with shape broadcastable to
        `[N1,..., Nm]` with `m >= 0`. Defines this as a batch of
        `N1 x ... x Nm` different Multinomial distributions.  Its components
        should be equal to integer values.
      logits: Floating point tensor of logits with shape broadcastable to
        `[N1,..., Nm, k], m >= 0`, and the same dtype as `n`. Defines this as
        a batch of `N1 x ... x Nm` different `k` class Multinomial
        distributions.
      p:  Positive floating point tensor with shape broadcastable to
        `[N1,..., Nm, k]` `m >= 0` and same dtype as `n`.  Defines this as
        a batch of `N1 x ... x Nm` different `k` class Multinomial
        distributions. `p`'s components in the last portion of its shape should
        sum up to 1.
      validate_args: `Boolean`, default `False`.  Whether to assert valid
        values for parameters `n` and `p`, and `x` in `prob` and `log_prob`.
        If `False`, correct behavior is not guaranteed.
      allow_nan_stats: `Boolean`, default `True`.  If `False`, raise an
        exception if a statistic (e.g. mean/mode/etc...) is undefined for any
        batch member.  If `True`, batch members with valid parameters leading to
        undefined statistics will return NaN for this statistic.
      name: The name to prefix Ops created by this distribution class.

    Examples:

    ```python
    # Define 1-batch of 2-class multinomial distribution,
    # also known as a Binomial distribution.
    dist = Multinomial(n=2., p=[.1, .9])

    # Define a 2-batch of 3-class distributions.
    dist = Multinomial(n=[4., 5], p=[[.1, .3, .6], [.4, .05, .55]])
    ```

    """

    self._logits, self._p = distribution_util.get_logits_and_prob(
        name=name, logits=logits, p=p, validate_args=validate_args,
        multidimensional=True)
    with ops.name_scope(name, values=[n, self._p]) as ns:
      # When `validate_args` is set, ops created inside this context depend on
      # the non-negativity and integer-form checks on `n`.
      with ops.control_dependencies([
          check_ops.assert_non_negative(
              n, message="n has negative components."),
          distribution_util.assert_integer_form(
              n, message="n has non-integer components.")
      ] if validate_args else []):
        self._n = array_ops.identity(n, name="convert_n")
        # Per-class mean: `n` gains a trailing axis and is scaled by `p`.
        self._mean_val = array_ops.expand_dims(n, -1) * self._p
        # Reduction of the mean over the last (class) axis. Presumably kept for
        # its broadcast batch shape, per the attribute name.
        self._broadcast_shape = math_ops.reduce_sum(
            self._mean_val, reduction_indices=[-1], keep_dims=False)
        super(Multinomial, self).__init__(
            dtype=self._p.dtype,
            # "mean" stores the mean tensor computed above. `self._mean` is
            # not assigned in this constructor (presumably it is the `_mean`
            # method), so using it would store a bound method here instead of
            # a tensor like the other entries.
            parameters={"p": self._p,
                        "n": self._n,
                        "mean": self._mean_val,
                        "logits": self._logits,
                        "broadcast_shape": self._broadcast_shape},
            is_continuous=False,
            is_reparameterized=False,
            validate_args=validate_args,
            allow_nan_stats=allow_nan_stats,
            name=ns)