예제 #1
0
    def testAssertIntegerForm(self):
        """Check `assert_integer_form` on one passing and three failing feeds.

        `x` holds only integer-valued floats and must evaluate cleanly; `y`,
        `z` and `w` each have a non-integer first component and must raise an
        op error mentioning "has non-integer components".
        """
        # This should only be detected as an integer.
        x = array_ops.placeholder(dtypes.float32)
        # First component (1.1) is plainly fractional.
        y = array_ops.placeholder(dtypes.float32)
        # First component isn't less than float32.eps = 1e-7
        z = array_ops.placeholder(dtypes.float32)
        # This shouldn't be detected as an integer.
        w = array_ops.placeholder(dtypes.float32)

        tail = [5, 10, 15, 20]
        feed_dict = {
            x: [1.] + tail,
            y: [1.1] + tail,
            z: [1.0001] + tail,
            w: [1e-8] + tail,
        }

        with self.test_session():
            # Integer-valued input: the assertion must let evaluation through.
            integer_check = distribution_util.assert_integer_form(x)
            with ops.control_dependencies([integer_check]):
                array_ops.identity(x).eval(feed_dict=feed_dict)

            # Every remaining placeholder must trip the assertion.
            for non_integer in (y, z, w):
                with self.assertRaisesOpError("has non-integer components"):
                    with ops.control_dependencies(
                            [distribution_util.assert_integer_form(
                                non_integer)]):
                        array_ops.identity(non_integer).eval(
                            feed_dict=feed_dict)
 def _check_integer(self, value):
   """Convert `value` to a tensor, optionally asserting integer form.

   Args:
     value: Input accepted by `tf.convert_to_tensor`.

   Returns:
     The converted tensor. When `self.validate_args` is truthy, the result
     additionally depends on `distribution_util.assert_integer_form`.
   """
   with tf.name_scope("check_integer", values=[value]):
     tensor = tf.convert_to_tensor(value, name="value")
     if self.validate_args:
       integer_check = distribution_util.assert_integer_form(
           tensor, message="value has non-integer components.")
       return control_flow_ops.with_dependencies([integer_check], tensor)
     # Validation disabled: hand back the plain converted tensor.
     return tensor
 def _check_integer(self, value):
   """Return `value` as a tensor, guarded by an integer-form check if enabled.

   Args:
     value: Input accepted by `ops.convert_to_tensor`.

   Returns:
     The converted tensor; with `self.validate_args` set it is wrapped so
     that `distribution_util.assert_integer_form` runs as a dependency.
   """
   with ops.name_scope("check_integer", values=[value]):
     converted = ops.convert_to_tensor(value, name="value")
     if self.validate_args:
       checks = [
           distribution_util.assert_integer_form(
               converted, message="value has non-integer components."),
       ]
       return control_flow_ops.with_dependencies(checks, converted)
     # No validation requested, so no assertion op is created.
     return converted
예제 #4
0
 def _maybe_assert_valid_total_count(self, total_count, validate_args):
   if not validate_args:
     return total_count
   return control_flow_ops.with_dependencies([
       tf.assert_non_negative(
           total_count, message="total_count must be non-negative."),
       distribution_util.assert_integer_form(
           total_count,
           message="total_count cannot contain fractional components."),
   ], total_count)
예제 #5
0
 def _assert_valid_sample(self, x, check_integer=False):
     if not self.validate_args: return x
     with ops.name_scope("check_x", values=[x]):
         dependencies = [check_ops.assert_non_negative(x)]
         if check_integer:
             dependencies += [
                 distribution_util.assert_integer_form(
                     x, message="x has non-integer components.")
             ]
         return control_flow_ops.with_dependencies(dependencies, x)
예제 #6
0
 def _maybe_assert_valid_total_count(self, total_count, validate_args):
   if not validate_args:
     return total_count
   return control_flow_ops.with_dependencies([
       tf.assert_non_negative(
           total_count, message="total_count must be non-negative."),
       distribution_util.assert_integer_form(
           total_count,
           message="total_count cannot contain fractional components."),
   ], total_count)
예제 #7
0
 def _assert_valid_sample(self, counts):
     """Check counts for proper shape, values, then return tensor version."""
     if not self.validate_args: return counts
     return control_flow_ops.with_dependencies([
         check_ops.assert_non_negative(
             counts, message="counts has negative components."),
         check_ops.assert_equal(self.n,
                                math_ops.reduce_sum(counts,
                                                    reduction_indices=[-1]),
                                message="counts do not sum to n."),
         distribution_util.assert_integer_form(
             counts, message="counts have non-integer components.")
     ], counts)
 def _maybe_assert_valid_sample(self, counts):
   """Check counts for proper shape, values, then return tensor version."""
   if not self.validate_args:
     return counts
   return control_flow_ops.with_dependencies([
       check_ops.assert_non_negative(
           counts,
           message="counts must be non-negative."),
       check_ops.assert_equal(
           self.total_count, math_ops.reduce_sum(counts, -1),
           message="counts last-dimension must sum to `self.total_count`"),
       distribution_util.assert_integer_form(
           counts,
           message="counts cannot contain fractional components."),
   ], counts)
예제 #9
0
 def _maybe_assert_valid_sample(self, counts):
     """Check counts for proper shape, values, then return tensor version."""
     if not self.validate_args:
         return counts
     return control_flow_ops.with_dependencies([
         check_ops.assert_non_negative(
             counts, message="counts must be non-negative."),
         check_ops.assert_equal(
             self.total_count,
             math_ops.reduce_sum(counts, -1),
             message="counts last-dimension must sum to `self.total_count`"
         ),
         distribution_util.assert_integer_form(
             counts,
             message="counts cannot contain fractional components."),
     ], counts)
예제 #10
0
    def __init__(self,
                 n,
                 logits=None,
                 p=None,
                 validate_args=False,
                 allow_nan_stats=True,
                 name="Multinomial"):
        """Initialize a batch of Multinomial distributions.

        Args:
          n:  Non-negative floating point tensor with shape broadcastable to
            `[N1,..., Nm]` with `m >= 0`. Defines this as a batch of
            `N1 x ... x Nm` different Multinomial distributions.  Its
            components should be equal to integer values.
          logits: Floating point tensor representing the log-odds of a
            positive event with shape broadcastable to
            `[N1,..., Nm, k], m >= 0`, and the same dtype as `n`. Defines this
            as a batch of `N1 x ... x Nm` different `k` class Multinomial
            distributions. Only one of `logits` or `p` should be passed in.
          p:  Positive floating point tensor with shape broadcastable to
            `[N1,..., Nm, k]` `m >= 0` and same dtype as `n`.  Defines this as
            a batch of `N1 x ... x Nm` different `k` class Multinomial
            distributions. `p`'s components in the last portion of its shape
            should sum up to 1. Only one of `logits` or `p` should be passed
            in.
          validate_args: `Boolean`, default `False`.  Whether to assert valid
            values for parameters `n` and `p`, and `x` in `prob` and
            `log_prob`. If `False`, correct behavior is not guaranteed.
          allow_nan_stats: `Boolean`, default `True`.  If `False`, raise an
            exception if a statistic (e.g. mean/mode/etc...) is undefined for
            any batch member.  If `True`, batch members with valid parameters
            leading to undefined statistics will return NaN for this
            statistic.
          name: The name to prefix Ops created by this distribution class.

        Examples:

        ```python
        # Define 1-batch of 2-class multinomial distribution,
        # also known as a Binomial distribution.
        dist = Multinomial(n=2., p=[.1, .9])

        # Define a 2-batch of 3-class distributions.
        dist = Multinomial(n=[4., 5], p=[[.1, .3, .6], [.4, .05, .55]])
        ```

        """
        # Take a private copy of the constructor arguments. On CPython < 3.13
        # `locals()` returns the frame's own dict, which gets refreshed (and
        # so regains `self` and later locals such as `ns`) whenever the
        # frame's locals are read again, e.g. by a debugger or tracer.
        parameters = dict(locals())
        parameters.pop("self")
        # `logits` is listed next to `n` and `p` because it is an alternative
        # tensor input to this scope.
        with ops.name_scope(name, values=[n, logits, p]) as ns:
            # The assertion ops are only created when validation is requested.
            assertions = [
                check_ops.assert_non_negative(
                    n, message="n has negative components."),
                distribution_util.assert_integer_form(
                    n, message="n has non-integer components."),
            ] if validate_args else []
            with ops.control_dependencies(assertions):
                self._logits, self._p = distribution_util.get_logits_and_prob(
                    name=name,
                    logits=logits,
                    p=p,
                    validate_args=validate_args,
                    multidimensional=True)
                self._n = array_ops.identity(n, name="convert_n")
                # NOTE(review): `n` is multiplied by `p` without expanding a
                # trailing axis, so `n` must already broadcast against
                # `[..., k]` -- confirm this is intended for this class.
                self._mean_val = n * self._p
                self._broadcast_shape = math_ops.reduce_sum(
                    self._mean_val, reduction_indices=[-1], keepdims=False)
        # Explicit two-argument `super` is kept as in the original.
        super(NonPermutedMultinomial,
              self).__init__(dtype=self._p.dtype,
                             is_continuous=False,
                             is_reparameterized=False,
                             validate_args=validate_args,
                             allow_nan_stats=allow_nan_stats,
                             parameters=parameters,
                             graph_parents=[
                                 self._p, self._n, self._mean_val,
                                 self._logits, self._broadcast_shape
                             ],
                             name=ns)