コード例 #1
0
 def _sample_control_dependencies(self, x):
   """Returns assertion ops validating sample `x` when `validate_args` is set.

   Only the ops produced by
   `distribution_util.assert_nonnegative_integer_form(x)` are collected
   (presumably checking that `x` is non-negative and integer-valued). No
   shape check is performed, and `x` itself is not returned.

   Args:
     x: Candidate sample (presumably a `Tensor`).

   Returns:
     A list of assertion ops; empty when `self.validate_args` is False.
   """
   assertions = []
   if not self.validate_args:
     return assertions
   assertions.extend(distribution_util.assert_nonnegative_integer_form(x))
   return assertions
コード例 #2
0
 def _sample_control_dependencies(self, x):
   """Collects validity checks for a sample whose last axis should sum to 1.

   Nothing is checked unless `self.validate_args` is True.

   Args:
     x: Candidate sample (presumably a `Tensor` with a trailing event axis).

   Returns:
     A list of assertion ops: those from
     `distribution_util.assert_nonnegative_integer_form(x)`, followed by one
     asserting that `x` sums to 1 over its last axis.
   """
   if not self.validate_args:
     return []
   # Together with the helper's non-negative-integer checks (assumed), a sum
   # of 1 along the last axis leaves exactly one entry equal to 1.
   return [
       *distribution_util.assert_nonnegative_integer_form(x),
       assert_util.assert_equal(
           tf.ones([], dtype=x.dtype),
           tf.reduce_sum(x, axis=[-1]),
           message='Last dimension of sample must sum to 1.'),
   ]
コード例 #3
0
 def _sample_control_dependencies(self, x):
   """Collects validity checks for a sample whose elements must not exceed 1.

   When `self.validate_args` is True, the checks are the ops from
   `distribution_util.assert_nonnegative_integer_form(x)` plus an elementwise
   upper bound of 1.

   Args:
     x: Candidate sample (presumably a `Tensor`).

   Returns:
     A list of assertion ops; empty when `self.validate_args` is False.
   """
   if not self.validate_args:
     return []
   checks = list(distribution_util.assert_nonnegative_integer_form(x))
   one = tf.ones([], dtype=x.dtype)
   checks.append(
       assert_util.assert_less_equal(x, one,
                                     message='Elements cannot exceed 1.'))
   return checks
コード例 #4
0
ファイル: multinomial.py プロジェクト: xzxzmmnn/probability
 def _parameter_control_dependencies(self, is_init):
   """Gathers assertions that validate this distribution's parameters.

   Args:
     is_init: Python `bool`. It is forwarded to the categorical-parameter
       helper and compared with `tensor_util.is_ref(self.total_count)` to
       decide whether `total_count` is checked on this call.

   Returns:
     The list returned by
     `categorical_lib.maybe_assert_categorical_param_correctness`, extended in
     place with the `total_count` checks when they apply.
   """
   assertions = categorical_lib.maybe_assert_categorical_param_correctness(
       is_init, self.validate_args, self._probs, self._logits)
   # `total_count` is checked when `is_init` differs from its `is_ref` status:
   # at construction for non-ref values, on later calls for ref values
   # (presumably variables whose value may change after construction).
   check_total_count = (
       self.validate_args and
       is_init != tensor_util.is_ref(self.total_count))
   if check_total_count:
     assertions.extend(
         distribution_util.assert_nonnegative_integer_form(self.total_count))
   return assertions
コード例 #5
0
ファイル: multinomial.py プロジェクト: xzxzmmnn/probability
 def _maybe_assert_valid_sample(self, counts):
   """Returns assertion ops checking that `counts` is a valid sample.

   When `self.validate_args` is True, `counts` must pass
   `distribution_util.assert_nonnegative_integer_form`, and its last axis must
   sum to `self.total_count`. No shape check is performed, and `counts` itself
   is not returned.

   Args:
     counts: Candidate sample of counts (presumably a `Tensor`).

   Returns:
     A list of assertion ops; empty when `self.validate_args` is False.
   """
   if not self.validate_args:
     return []
   # Copy into a list owned by this method before appending, instead of
   # mutating whatever container the helper handed back.
   assertions = list(
       distribution_util.assert_nonnegative_integer_form(counts))
   assertions.append(assert_util.assert_equal(
       self.total_count,
       tf.reduce_sum(counts, axis=-1),
       message='counts must sum to `self.total_count`'))
   return assertions
コード例 #6
0
 def _sample_control_dependencies(self, x):
   """Returns assertion ops checking that `x` holds valid category indices.

   When `self.validate_args` is True, `x` must pass
   `distribution_util.assert_nonnegative_integer_form`, and every element must
   be strictly less than `n = self._num_categories()`, i.e. lie in
   `[0, n - 1]`.

   Args:
     x: Candidate sample of category indices (presumably a `Tensor`).

   Returns:
     A list of assertion ops; empty when `self.validate_args` is False.
   """
   assertions = []
   if not self.validate_args:
     return assertions
   assertions.extend(distribution_util.assert_nonnegative_integer_form(x))
   # Indices are 0-based, so `n` itself is out of range. A non-strict `<=`
   # bound would accept `x == n`, contradicting the `n-1` in the message.
   assertions.append(
       assert_util.assert_less(
           x, tf.cast(self._num_categories(), x.dtype),
           message=('Categorical samples must be between `0` and `n-1` '
                    'where `n` is the number of categories.')))
   return assertions
コード例 #7
0
 def _sample_control_dependencies(self, counts):
   """Returns assertion ops checking that `counts` is a valid sample.

   When `self.validate_args` is True, `counts` must pass
   `distribution_util.assert_nonnegative_integer_form`, and its last axis must
   sum to `self.total_count`. No shape check is performed, and `counts` itself
   is not returned.

   Args:
     counts: Candidate sample of counts (presumably a `Tensor`).

   Returns:
     A list of assertion ops; empty when `self.validate_args` is False.
   """
   assertions = []
   if not self.validate_args:
     return assertions
   assertions.extend(distribution_util.assert_nonnegative_integer_form(counts))
   assertions.append(assert_util.assert_equal(
       self.total_count,
       tf.reduce_sum(counts, axis=-1),
       message='counts must sum to `self.total_count`'))
   return assertions
コード例 #8
0
 def _sample_control_dependencies(self, x):
   """Returns assertion ops checking that `x` holds valid ordinal categories.

   When `self.validate_args` is True, `x` must pass
   `distribution_util.assert_nonnegative_integer_form`, and every element must
   be strictly less than `self._num_categories()`.

   Args:
     x: Candidate sample of category labels (presumably a `Tensor`).

   Returns:
     A list of assertion ops; empty when `self.validate_args` is False.
   """
   assertions = []
   if not self.validate_args:
     return assertions
   assertions.extend(distribution_util.assert_nonnegative_integer_form(x))
   # NOTE(review): assumes `_num_categories()` returns K + 1 (K cutpoints split
   # the line into K + 1 categories labelled 0..K). Then the strict bound
   # `x < K + 1` is exactly the `x <= K` promised by the message; a `<=` here
   # would let the out-of-range label K + 1 through. Confirm against
   # `_num_categories`.
   assertions.append(
       assert_util.assert_less(
           x, tf.cast(self._num_categories(), x.dtype),
           message=('StoppingRatioLogistic samples must be `>= 0` and `<= K` '
                    'where `K` is the number of cutpoints.')))
   return assertions
コード例 #9
0
 def _sample_control_dependencies(self, x):
   """Builds assertion ops that validate a sample of counts.

   When `self.validate_args` is True, `x` must pass
   `distribution_util.assert_nonnegative_integer_form`, and its last axis must
   sum to `self.total_count`.

   Args:
     x: Candidate sample (presumably a `Tensor` of counts).

   Returns:
     A list of assertion ops; empty when `self.validate_args` is False.
   """
   if not self.validate_args:
     return []
   checks = list(distribution_util.assert_nonnegative_integer_form(x))
   row_total = tf.reduce_sum(x, axis=-1)
   checks.append(
       assert_util.assert_equal(
           self.total_count,
           row_total,
           message='counts last-dimension must sum to `self.total_count`'))
   return checks
コード例 #10
0
 def _sample_control_dependencies(self, counts):
   """Builds assertion ops that validate sampled `counts`.

   When `self.validate_args` is True, `counts` must pass
   `distribution_util.assert_nonnegative_integer_form` and be elementwise
   `<= self.total_count`.

   Args:
     counts: Candidate sample (presumably a `Tensor` of counts).

   Returns:
     A list of assertion ops; empty when `self.validate_args` is False.
   """
   if not self.validate_args:
     return []
   integer_checks = distribution_util.assert_nonnegative_integer_form(counts)
   bound_check = assert_util.assert_less_equal(
       counts, self.total_count,
       message=('Sampled counts must be itemwise less than '
                'or equal to `total_count` parameter.'))
   return [*integer_checks, bound_check]
コード例 #11
0
  def _parameter_control_dependencies(self, is_init):
    """Gathers assertions validating `concentration` and `total_count`.

    Args:
      is_init: Python `bool`. The event-shape check runs only when it is True.
        For each parameter, it is compared with `tensor_util.is_ref(...)` to
        decide whether that parameter is checked on this call.

    Returns:
      A list of assertion ops. It is empty when `self.validate_args` is False,
      although `tensor_util.is_ref` is still evaluated in that case.
    """
    checks = []

    if is_init and self.validate_args:
      # The helper covers both statically known and dynamic event shapes.
      checks += distribution_util.assert_categorical_event_shape(
          self._concentration)

    # A parameter is checked when `is_init` differs from its `is_ref` status:
    # non-ref values at construction, ref values on later calls.
    if (is_init != tensor_util.is_ref(self._total_count)
        and self.validate_args):
      checks += distribution_util.assert_nonnegative_integer_form(
          self._total_count)

    if (is_init != tensor_util.is_ref(self._concentration)
        and self.validate_args):
      checks.append(
          assert_util.assert_positive(
              self._concentration,
              message='Concentration parameter must be positive.'))

    return checks
コード例 #12
0
ファイル: poisson.py プロジェクト: seanmb/probability
 def _sample_control_dependencies(self, x):
   """Returns non-negative-integer checks for sample `x` when validating.

   Args:
     x: Candidate sample (presumably a `Tensor` of counts).

   Returns:
     A list holding the ops from
     `distribution_util.assert_nonnegative_integer_form(x)` when
     `self.validate_args` is True, otherwise an empty list.
   """
   checks = []
   if self.validate_args:
     checks += distribution_util.assert_nonnegative_integer_form(x)
   return checks
コード例 #13
0
 def _maybe_assert_valid_sample(self, x):
   """Returns the helper's non-negative-integer assertions for `x`, or `[]`.

   Args:
     x: Candidate sample (presumably a `Tensor`).

   Returns:
     The value of `distribution_util.assert_nonnegative_integer_form(x)`,
     passed through as-is when `self.validate_args` is True; otherwise a new
     empty list.
   """
   return (distribution_util.assert_nonnegative_integer_form(x)
           if self.validate_args else [])