コード例 #1
0
    def _parameter_control_dependencies(self, is_init):
        """Builds the validation ops for `total_count` and `concentration`.

        Args:
          is_init: Python `bool`; whether this is the initialization-time
            check (as opposed to a use-time check).

        Returns:
          A `list` of assertion ops. It is always empty when
          `self.validate_args` is `False`, since every check is gated on it.
        """
        if not self.validate_args:
            return []

        checks = []

        if is_init:
            # assert_categorical_event_shape handles both the static and dynamic case.
            checks.extend(
                distribution_util.assert_categorical_event_shape(
                    self._concentration))

        # At init, only non-ref params are checked; at use time, only
        # ref-like params are (their values may differ from construction).
        if is_init != tensor_util.is_ref(self._total_count):
            count = tf.convert_to_tensor(self._total_count)
            checks += [
                distribution_util.assert_casting_closed(
                    count,
                    target_dtype=tf.int32,
                    message=
                    'total_count cannot contain fractional components.'),
                assert_util.assert_non_negative(
                    count,
                    message='total_count must be non-negative'),
            ]

        if is_init != tensor_util.is_ref(self._concentration):
            checks.append(
                assert_util.assert_positive(
                    self._concentration,
                    message='Concentration parameter must be positive.'))

        return checks
コード例 #2
0
 def _parameter_control_dependencies(self, is_init):
   """Returns assertion ops validating the distribution's parameters.

   Args:
     is_init: Python `bool`; whether this is the initialization-time check.

   Returns:
     A `list` of assertion ops: whatever
     `categorical_lib.maybe_assert_categorical_param_correctness` returns for
     `probs`/`logits`, followed (only when `self.validate_args`) by the
     `total_count` checks.
   """
   # Copy into a fresh list so the appends below never mutate a container
   # that the helper might own.
   assertions = list(categorical_lib.maybe_assert_categorical_param_correctness(
       is_init, self.validate_args, self._probs, self._logits))
   if not self.validate_args:
     return assertions
   # At init only a non-ref `total_count` is checked; at use time only a
   # ref-like one is.
   if is_init != tensor_util.is_ref(self.total_count):
     total_count = tf.convert_to_tensor(self.total_count)
     # Descriptive messages, consistent with the other parameter checks.
     assertions.append(distribution_util.assert_casting_closed(
         total_count, target_dtype=tf.int32,
         message='total_count cannot contain fractional components.'))
     assertions.append(assert_util.assert_non_negative(
         total_count, message='total_count must be non-negative.'))
   return assertions
コード例 #3
0
 def _sample_control_dependencies(self, x):
   """Returns assertion ops checking that `x` holds valid category indices.

   Args:
     x: Samples to validate; converted with `tf.convert_to_tensor`.

   Returns:
     A `list` of assertion ops; empty when `self.validate_args` is `False`.
   """
   assertions = []
   if not self.validate_args:
     return assertions
   # Convert once so `x.dtype` is defined for array-like input and all three
   # checks share a single tensor.
   x = tf.convert_to_tensor(x)
   assertions.append(distribution_util.assert_casting_closed(
       x, target_dtype=tf.int32,
       message='OrderedLogistic samples cannot contain fractional components.'))
   assertions.append(assert_util.assert_non_negative(
       x, message='OrderedLogistic samples must be non-negative.'))
   # Category indices run from 0 to num_categories - 1, so the upper bound is
   # strict; `assert_less_equal` accepted the out-of-range `num_categories`.
   # NOTE(review): assumes `_num_categories()` is K + 1 for K cutpoints, which
   # makes this exactly the `<= K` stated in the message -- confirm.
   assertions.append(
       assert_util.assert_less(
           x, tf.cast(self._num_categories(), x.dtype),
           message=('OrderedLogistic samples must be `>= 0` and `<= K` '
                    'where `K` is the number of cutpoints.')))
   return assertions
コード例 #4
0
 def _sample_control_dependencies(self, x):
   """Returns assertion ops checking that `x` holds valid category indices.

   Args:
     x: Samples to validate; converted with `tf.convert_to_tensor`.

   Returns:
     A `list` of assertion ops; empty when `self.validate_args` is `False`.
   """
   assertions = []
   if not self.validate_args:
     return assertions
   # Convert once so `x.dtype` is defined for array-like input and all three
   # checks share a single tensor.
   x = tf.convert_to_tensor(x)
   assertions.append(distribution_util.assert_casting_closed(
       x, target_dtype=tf.int32,
       message='Categorical samples cannot contain fractional components.'))
   assertions.append(assert_util.assert_non_negative(
       x, message='Categorical samples must be non-negative.'))
   # Valid indices are 0..n-1 (as the message states), so the bound is
   # strict; `assert_less_equal` wrongly accepted `x == n`.
   assertions.append(
       assert_util.assert_less(
           x, tf.cast(self._num_categories(), x.dtype),
           message=('Categorical samples must be between `0` and `n-1` '
                    'where `n` is the number of categories.')))
   return assertions
コード例 #5
0
ファイル: binomial.py プロジェクト: mederrata/probability
 def _sample_control_dependencies(self, counts):
   """Builds assertion ops that validate `counts` samples.

   When `self.validate_args` is `False` nothing is checked and an empty list
   is returned. Otherwise `counts` must be integer-valued, non-negative and
   itemwise no larger than `self.total_count`.
   """
   if not self.validate_args:
     return []
   fractional_msg = 'counts cannot contain fractional components.'
   negative_msg = 'counts must be non-negative.'
   bound_msg = ('Sampled counts must be itemwise less than '
                'or equal to `total_count` parameter.')
   return [
       distribution_util.assert_casting_closed(
           counts, target_dtype=tf.int32, message=fractional_msg),
       assert_util.assert_non_negative(counts, message=negative_msg),
       assert_util.assert_less_equal(
           counts, self.total_count, message=bound_msg),
   ]
コード例 #6
0
ファイル: bates.py プロジェクト: yfe404/probability
    def _parameter_control_dependencies(self, is_init):
        """Returns assertion ops validating `total_count`, `low` and `high`.

        Args:
          is_init: Python `bool`; whether this is the initialization-time check.

        Returns:
          A `list` of assertion ops; empty when `self.validate_args` is `False`.

        Raises:
          ValueError: if `is_init` and `self._batch_shape()` raises
            `ValueError` (incompatible parameter shapes). The original error
            is chained as the cause.
        """
        if not self.validate_args:
            return []

        if is_init:
            try:
                self._batch_shape()
            except ValueError as e:
                # Report static shapes: unlike `tf.shape(...)` tensors, they
                # render readably in graph mode as well as eagerly.
                shapes = [tf.convert_to_tensor(p).shape
                          for p in (self.total_count, self.low, self.high)]
                raise ValueError(
                    'Arguments `total_count`, `low` and `high` must have compatible '
                    'shapes; total_count.shape={}, low.shape={}, '
                    'high.shape={}.'.format(*shapes)) from e

        assertions = []

        # At init only non-ref params are checked; at use time only ref-like
        # ones are.
        if is_init != tensor_util.is_ref(self.total_count):
            total_count = tf.convert_to_tensor(self.total_count)
            # Per-dtype upper bound on `total_count`; larger values are
            # reported as numerically unstable.
            limit = BATES_TOTAL_COUNT_STABILITY_LIMITS[self.dtype]
            msg = '`total_count` must be representable as a 32-bit integer.'
            assertions.extend([
                assert_util.assert_positive(
                    total_count, message='`total_count` must be positive.'),
                distribution_util.assert_casting_closed(total_count,
                                                        target_dtype=tf.int32,
                                                        message=msg),
                assert_util.assert_less_equal(
                    tf.cast(total_count, self.dtype),
                    tf.cast(limit, self.dtype),
                    message='`total_count` > {} is numerically unstable.'.
                    format(limit))
            ])

        # The ordering check involves both bounds, so at use time it reruns
        # if either of them is ref-like.
        if is_init != (tensor_util.is_ref(self.low)
                       or tensor_util.is_ref(self.high)):
            assertions.append(
                assert_util.assert_less(
                    self.low,
                    self.high,
                    message='`low` must be less than `high`.'))

        return assertions