예제 #1
0
 def _parameter_control_dependencies(self, is_init):
   """Returns the assertions that validate the TruncatedCauchy parameters.

   Nothing is checked unless `self.validate_args` is set. A parameter is
   checked only when `is_init` differs from `tensor_util.is_ref(param)`:
   with `is_init=True` that selects parameters that are not refs, and with
   `is_init=False` those that are refs (presumably mutable variables whose
   value can change after construction -- confirm against `tensor_util`).

   Checks are emitted in this order: `low` finite, `high` finite, `loc`
   finite, `scale` positive, `scale` finite, and finally `high > low`
   whenever either bound was selected.

   Args:
     is_init: Python `bool`; presumably `True` when called during
       construction (TODO confirm against the caller).

   Returns:
     A list with the results of the `assert_util` calls; empty when
     `validate_args` is falsy.
   """
   if not self.validate_args:
     return []

   def _needs_check(param):
     return is_init != tensor_util.is_ref(param)

   checks = []
   low_selected = _needs_check(self.low)
   high_selected = _needs_check(self.high)
   low_t = None
   high_t = None

   if low_selected:
     low_t = tf.convert_to_tensor(self.low)
     checks.append(
         assert_util.assert_finite(low_t, message='`low` is not finite'))
   if high_selected:
     high_t = tf.convert_to_tensor(self.high)
     checks.append(
         assert_util.assert_finite(high_t, message='`high` is not finite'))
   if _needs_check(self.loc):
     checks.append(
         assert_util.assert_finite(self.loc, message='`loc` is not finite'))
   if _needs_check(self.scale):
     scale_t = tf.convert_to_tensor(self.scale)
     checks.append(
         assert_util.assert_positive(
             scale_t, message='`scale` must be positive'))
     checks.append(
         assert_util.assert_finite(scale_t, message='`scale` is not finite'))

   if low_selected or high_selected:
     # Reuse tensors converted above; convert only the bound that was not
     # already selected on its own.
     if low_t is None:
       low_t = tf.convert_to_tensor(self.low)
     if high_t is None:
       high_t = tf.convert_to_tensor(self.high)
     checks.append(
         assert_util.assert_greater(
             high_t,
             low_t,
             message='TruncatedCauchy not defined when `low >= high`.'))
   return checks
예제 #2
0
 def _validate(self):
     """Builds a grouped op asserting that the parameters are valid.

     Asserts that `scale` is positive, that the upper bound is strictly
     greater than the lower bound, and that `low`, `high`, `loc` and `scale`
     are all finite. Every assertion carries a message so a failure
     identifies which parameter is invalid.

     Returns:
       The result of `tf.group` over all assertion ops, named
       "ValidationOps".
     """
     vops = [
         assert_util.assert_positive(self._scale,
                                     message="scale not positive"),
         # Compare the bounds directly instead of asserting `high - low > 0`:
         # for distinct but tiny bounds the subtraction can round to zero
         # (e.g. under flush-to-zero arithmetic) and fail spuriously.
         assert_util.assert_greater(
             self._high, self._low,
             message="Upper bound not greater than lower bound"),
         assert_util.assert_finite(self._low,
                                   message="Lower bound not finite"),
         assert_util.assert_finite(self._high,
                                   message="Upper bound not finite"),
         assert_util.assert_finite(self._loc, message="Loc not finite"),
         assert_util.assert_finite(self._scale, message="scale not finite"),
     ]
     return tf.group(*vops, name="ValidationOps")
 def mapper(x):
     """Applies `constraint_fn` to `x` and asserts the result is finite.

     `x` is converted with `tf.convert_to_tensor` before `constraint_fn` is
     applied; a non-finite result triggers an assertion with the message
     'param non-finite'.

     Returns:
       The value returned by `assert_util.assert_finite`, converted with
       `.numpy()` when TensorFlow is executing eagerly. Presumably the
       conversion avoids holding eager Tensors (cf. b/128974935) -- confirm.
     """
     constrained = constraint_fn(tf.convert_to_tensor(value=x))
     checked = assert_util.assert_finite(constrained,
                                         message='param non-finite')
     return checked.numpy() if tf.executing_eagerly() else checked
 def mapper(x):
     """Constrains `x` with `constraint_fn` and asserts the result is finite.

     Returns:
       The value returned by `assert_util.assert_finite` (assertion message
       'param non-finite'); when executing eagerly it is converted with
       `.numpy()` before being returned.
     """
     as_tensor = tf.convert_to_tensor(value=x)
     checked = assert_util.assert_finite(
         constraint_fn(as_tensor), message='param non-finite')
     if not tf.executing_eagerly():
         return checked
     # TODO(b/128974935): Eager segfault when Tensors retained by hypothesis?
     return checked.numpy()
예제 #5
0
 def mapper(x):
   """Applies `constraint_fn` to `x` and asserts the result is finite.

   Returns:
     The value returned by `assert_util.assert_finite` with the message
     'param non-finite' (no NumPy conversion in this variant).
   """
   constrained = constraint_fn(tf.convert_to_tensor(x))
   return assert_util.assert_finite(constrained, message='param non-finite')