Ejemplo n.º 1
0
 def _assert_valid_sample(self, x):
   """Optionally attaches validity assertions to a sample `x`.

   A valid sample has only non-positive entries, and the log-sum-exp of its
   final dimension is (approximately) zero, i.e. `exp(x)` sums to one along
   that axis.

   Args:
     x: Sample `Tensor` to validate.

   Returns:
     `x` itself when `self.validate_args` is falsy; otherwise the result of
     `distribution_util.with_dependencies` applied to the two assertions
     and `x`.
   """
   if not self.validate_args:
     return x
   return distribution_util.with_dependencies([
       assert_util.assert_non_positive(
           x,
           message=('Samples must be less than or equal to `0` for '
                    '`ExpRelaxedOneHotCategorical` or `1` for '
                    '`RelaxedOneHotCategorical`.')),
       # logsumexp over the last axis is 0 exactly when exp(x) sums to 1
       # there; assert_near tolerates floating-point error.
       assert_util.assert_near(
           tf.zeros([], dtype=self.dtype),
           tf.reduce_logsumexp(x, axis=[-1]),
           message=('Final dimension of samples must sum to `0` for '
                    '`ExpRelaxedOneHotCategorical` or `1` '
                    'for `RelaxedOneHotCategorical`.')),
   ], x)
Ejemplo n.º 2
0
 def _sample_control_dependencies(self, x):
   """Builds the assertions that check whether `x` is a valid sample.

   A valid sample has only non-positive entries, and the log-sum-exp of its
   final dimension is (approximately) zero, i.e. `exp(x)` sums to one along
   that axis.

   Args:
     x: Sample `Tensor` to validate.

   Returns:
     A new list of assertion ops, or an empty list when
     `self.validate_args` is falsy.
   """
   if not self.validate_args:
     return []
   return [
       assert_util.assert_non_positive(
           x,
           message=('Samples must be less than or equal to `0` for '
                    '`ExpRelaxedOneHotCategorical` or `1` for '
                    '`RelaxedOneHotCategorical`.')),
       # logsumexp over the last axis is 0 exactly when exp(x) sums to 1
       # there; assert_near tolerates floating-point error.
       assert_util.assert_near(
           tf.zeros([], dtype=self.dtype),
           tf.reduce_logsumexp(x, axis=[-1]),
           # The original message had a stray `''.'` literal that injected
           # "for ." into the text; it is removed here.
           message=('Final dimension of samples must sum to `0` for '
                    '`ExpRelaxedOneHotCategorical` or `1` '
                    'for `RelaxedOneHotCategorical`.')),
   ]