def _maybe_validate_target_accept_prob(target_accept_prob, validate_args): """Validates that target_accept_prob is in (0, 1).""" if not validate_args: return target_accept_prob with tf.control_dependencies([ tf.assert_greater( target_accept_prob, 0., message='`target_accept_prob` must be > 0.' ), tf.assert_less( target_accept_prob, tf.ones_like(target_accept_prob), message='`target_accept_prob` must be < 1.')]): return tf.identity(target_accept_prob)
def _maybe_validate_target_accept_prob(target_accept_prob, validate_args): """Validates that target_accept_prob is in (0, 1).""" if not validate_args: return target_accept_prob assertions = [ tf.assert_greater(target_accept_prob, tf.zeros([], dtype=target_accept_prob.dtype), message='`target_accept_prob` must be > 0.'), tf.assert_less(target_accept_prob, tf.ones([], dtype=target_accept_prob.dtype), message='`target_accept_prob` must be < 1.') ] with tf.control_dependencies(assertions): return tf.identity(target_accept_prob)