def testAssertIntegerForm(self):
    """`assert_integer_form` passes integral floats and flags fractional ones."""
    # Only `x` holds exactly-integral values and should pass the check.
    x = array_ops.placeholder(dtypes.float32)
    # `y` has an obviously fractional first component.
    y = array_ops.placeholder(dtypes.float32)
    # `z`'s first component is off by 1e-4, which is not below float32
    # machine epsilon (~1.2e-7).
    z = array_ops.placeholder(dtypes.float32)
    # `w`'s first component (1e-8) is tiny but non-zero; it must still be
    # flagged as non-integer.
    w = array_ops.placeholder(dtypes.float32)
    feed_dict = {
        x: [1., 5, 10, 15, 20],
        y: [1.1, 5, 10, 15, 20],
        z: [1.0001, 5, 10, 15, 20],
        w: [1e-8, 5, 10, 15, 20],
    }

    def _eval_guarded(tensor):
        # Evaluate `tensor` with the integer-form assertion as a control
        # dependency, so a failing assertion surfaces as an op error.
        with ops.control_dependencies(
                [distribution_util.assert_integer_form(tensor)]):
            array_ops.identity(tensor).eval(feed_dict=feed_dict)

    with self.test_session():
        _eval_guarded(x)
        for fractional in (y, z, w):
            with self.assertRaisesOpError("has non-integer components"):
                _eval_guarded(fractional)
def _check_integer(self, value):
    """Convert `value` to a tensor, optionally guarded by an integer check.

    Args:
      value: Anything accepted by `tf.convert_to_tensor`.

    Returns:
      The converted tensor. When `self.validate_args` is truthy, the tensor is
      returned through `control_flow_ops.with_dependencies` with a
      `distribution_util.assert_integer_form` check attached.
    """
    with tf.name_scope("check_integer", values=[value]):
        tensor = tf.convert_to_tensor(value, name="value")
        if self.validate_args:
            integer_check = distribution_util.assert_integer_form(
                tensor, message="value has non-integer components.")
            return control_flow_ops.with_dependencies([integer_check], tensor)
        return tensor
def _check_integer(self, value):
    """Return `value` as a tensor, attaching an integrality assertion if enabled.

    Args:
      value: Anything accepted by `ops.convert_to_tensor`.

    Returns:
      The converted tensor, or, when `self.validate_args` is truthy, that
      tensor wrapped by `control_flow_ops.with_dependencies` so that
      `distribution_util.assert_integer_form` is a control dependency.
    """
    with ops.name_scope("check_integer", values=[value]):
        converted = ops.convert_to_tensor(value, name="value")
        if self.validate_args:
            guard = distribution_util.assert_integer_form(
                converted, message="value has non-integer components.")
            converted = control_flow_ops.with_dependencies([guard], converted)
        return converted
def _maybe_assert_valid_total_count(self, total_count, validate_args):
    """Optionally guard `total_count` with non-negativity and integrality checks.

    Args:
      total_count: Value to validate; returned untouched when not validating.
      validate_args: If falsy, skip all checks.

    Returns:
      `total_count` itself, or `total_count` wrapped by
      `control_flow_ops.with_dependencies` with both assertions attached.
    """
    if validate_args:
        checks = [
            tf.assert_non_negative(
                total_count, message="total_count must be non-negative."),
            distribution_util.assert_integer_form(
                total_count,
                message="total_count cannot contain fractional components."),
        ]
        return control_flow_ops.with_dependencies(checks, total_count)
    return total_count
def _assert_valid_sample(self, x, check_integer=False):
    """Attach sample-validity assertions to `x` when `self.validate_args` is set.

    Args:
      x: Sample value to check.
      check_integer: If true, also assert that `x` has no fractional part.

    Returns:
      `x` unchanged when not validating; otherwise `x` wrapped by
      `control_flow_ops.with_dependencies` with a non-negativity check (and,
      optionally, an integer-form check).
    """
    if self.validate_args:
        with ops.name_scope("check_x", values=[x]):
            guards = [check_ops.assert_non_negative(x)]
            if check_integer:
                guards.append(distribution_util.assert_integer_form(
                    x, message="x has non-integer components."))
            x = control_flow_ops.with_dependencies(guards, x)
    return x
def _assert_valid_sample(self, counts):
    """Optionally validate `counts`, then return it.

    When `self.validate_args` is falsy, `counts` is returned as-is. Otherwise
    `counts` is wrapped by `control_flow_ops.with_dependencies` with three
    assertions: non-negative entries, a last-axis sum equal to `self.n`, and
    integer-valued entries.

    Args:
      counts: Count values to check.

    Returns:
      `counts`, possibly with validation control dependencies attached.
    """
    if not self.validate_args:
        return counts
    return control_flow_ops.with_dependencies([
        check_ops.assert_non_negative(
            counts, message="counts has negative components."),
        # Axis passed positionally: works with both the legacy
        # (`reduction_indices`) and current (`axis`) `reduce_sum` signatures,
        # and avoids the deprecated keyword.
        check_ops.assert_equal(
            self.n, math_ops.reduce_sum(counts, -1),
            message="counts do not sum to n."),
        distribution_util.assert_integer_form(
            counts, message="counts have non-integer components."),
    ], counts)
def _maybe_assert_valid_sample(self, counts):
    """Return `counts`, attaching validity assertions if `self.validate_args`.

    The assertions check that `counts` is non-negative, that its last axis
    sums to `self.total_count`, and that it has no fractional components.

    Args:
      counts: Count values to check.

    Returns:
      `counts`, possibly wrapped by `control_flow_ops.with_dependencies`.
    """
    if self.validate_args:
        assertions = [
            check_ops.assert_non_negative(
                counts, message="counts must be non-negative."),
            check_ops.assert_equal(
                self.total_count,
                math_ops.reduce_sum(counts, -1),
                message="counts last-dimension must sum to `self.total_count`"),
            distribution_util.assert_integer_form(
                counts, message="counts cannot contain fractional components."),
        ]
        return control_flow_ops.with_dependencies(assertions, counts)
    return counts
def _maybe_assert_valid_sample(self, counts):
    """Optionally guard `counts` with non-negativity, sum and integrality checks.

    Args:
      counts: Count values to check.

    Returns:
      `counts` unchanged when `self.validate_args` is falsy; otherwise
      `counts` wrapped by `control_flow_ops.with_dependencies`.
    """
    if not self.validate_args:
        return counts
    non_negative = check_ops.assert_non_negative(
        counts, message="counts must be non-negative.")
    # Read `total_count` before building the reduction, preserving op order.
    expected_total = self.total_count
    last_axis_sum = math_ops.reduce_sum(counts, -1)
    sums_match = check_ops.assert_equal(
        expected_total, last_axis_sum,
        message="counts last-dimension must sum to `self.total_count`")
    integral = distribution_util.assert_integer_form(
        counts, message="counts cannot contain fractional components.")
    return control_flow_ops.with_dependencies(
        [non_negative, sums_match, integral], counts)
def __init__(self,
             n,
             logits=None,
             p=None,
             validate_args=False,
             allow_nan_stats=True,
             name="Multinomial"):
    """Initialize a batch of Multinomial distributions.

    Args:
      n: Non-negative floating point tensor with shape broadcastable to
        `[N1,..., Nm]` with `m >= 0`. Defines this as a batch of
        `N1 x ... x Nm` different Multinomial distributions. Its components
        should be equal to integer values.
      logits: Floating point tensor representing the log-odds of a
        positive event with shape broadcastable to `[N1,..., Nm, k], m >= 0`,
        and the same dtype as `n`. Defines this as a batch of `N1 x ... x Nm`
        different `k` class Multinomial distributions. Only one of `logits` or
        `p` should be passed in.
      p: Positive floating point tensor with shape broadcastable to
        `[N1,..., Nm, k]` `m >= 0` and same dtype as `n`. Defines this as
        a batch of `N1 x ... x Nm` different `k` class Multinomial
        distributions. `p`'s components in the last portion of its shape
        should sum up to 1. Only one of `logits` or `p` should be passed in.
      validate_args: `Boolean`, default `False`. Whether to assert valid
        values for parameters `n` and `p`, and `x` in `prob` and `log_prob`.
        If `False`, correct behavior is not guaranteed.
      allow_nan_stats: `Boolean`, default `True`. If `False`, raise an
        exception if a statistic (e.g. mean/mode/etc...) is undefined for any
        batch member. If `True`, batch members with valid parameters leading
        to undefined statistics will return NaN for this statistic.
      name: The name to prefix Ops created by this distribution class.

    Examples:

    ```python
    # Define 1-batch of 2-class multinomial distribution,
    # also known as a Binomial distribution.
    dist = Multinomial(n=2., p=[.1, .9])

    # Define a 2-batch of 3-class distributions.
    dist = Multinomial(n=[4., 5], p=[[.1, .3, .6], [.4, .05, .55]])
    ```
    """
    # Must run before any other local is bound: captures exactly the
    # constructor arguments.
    parameters = locals()
    parameters.pop("self")
    # `logits` is listed too, so graph inference sees whichever of
    # `logits` / `p` the caller actually supplied.
    with ops.name_scope(name, values=[n, logits, p]) as ns:
        with ops.control_dependencies([
            check_ops.assert_non_negative(
                n, message="n has negative components."),
            distribution_util.assert_integer_form(
                n, message="n has non-integer components.")
        ] if validate_args else []):
            self._logits, self._p = distribution_util.get_logits_and_prob(
                name=name, logits=logits, p=p, validate_args=validate_args,
                multidimensional=True)
            self._n = array_ops.identity(n, name="convert_n")
            # Per the Args above, `n` has batch shape [N1, ..., Nm] while `p`
            # has shape [N1, ..., Nm, k]. Add a trailing axis to `n` so each
            # batch member's count scales its own row of probabilities.
            # A bare `n * p` would broadcast `n` against the class axis.
            self._mean_val = array_ops.expand_dims(self._n, -1) * self._p
            # Summing the mean over the class axis yields a tensor with the
            # batch shape.
            self._broadcast_shape = math_ops.reduce_sum(
                self._mean_val, axis=-1, keepdims=False)
    super(NonPermutedMultinomial, self).__init__(
        dtype=self._p.dtype,
        is_continuous=False,
        is_reparameterized=False,
        validate_args=validate_args,
        allow_nan_stats=allow_nan_stats,
        parameters=parameters,
        graph_parents=[
            self._p, self._n, self._mean_val, self._logits,
            self._broadcast_shape
        ],
        name=ns)