Beispiel #1
0
 def log_prob(self, value):
     """Multinomial log-pmf for a distribution parameterized by ``self.probs``.

     Evaluates ``log n! + sum_i (x_i * log p_i - log x_i!)`` reduced over the
     last axis, where ``n`` is ``self.total_count`` and ``x`` is ``value``.
     ``value`` and ``total_count`` are first cast to the floating dtype of
     ``self.probs``.
     """
     if self._validate_args:
         self._validate_sample(value)
     float_dtype = get_dtypes(self.probs)[0]
     counts = lax.convert_element_type(value, float_dtype)
     n = lax.convert_element_type(self.total_count, float_dtype)
     log_multinomial_coeff = gammaln(n + 1)
     # Per-category terms; xlogy presumably maps 0 * log(0) to 0 (confirm
     # against its definition).
     per_category = xlogy(counts, self.probs) - gammaln(counts + 1)
     return log_multinomial_coeff + np.sum(per_category, axis=-1)
Beispiel #2
0
 def log_prob(self, value):
     """Multinomial log-pmf for a distribution parameterized by ``self.logits``.

     The logits need not be normalized: the result is
     ``sum_i (x_i * l_i - log x_i!) - (n * logsumexp(l) - log n!)``, with the
     sum and ``logsumexp`` taken over the last axis. ``value`` and
     ``total_count`` are first cast to the floating dtype of ``self.logits``.
     """
     if self._validate_args:
         self._validate_sample(value)
     float_dtype = get_dtypes(self.logits)[0]
     counts = lax.convert_element_type(value, float_dtype)
     n = lax.convert_element_type(self.total_count, float_dtype)
     unnormalized = np.sum(counts * self.logits - gammaln(counts + 1), axis=-1)
     # log of the normalizing constant: n * log(sum exp(l)) minus log n!.
     log_normalizer = n * logsumexp(self.logits, axis=-1) - gammaln(n + 1)
     return unnormalized - log_normalizer
Beispiel #3
0
 def log_prob(self, value):
     """Binomial log-pmf for a distribution parameterized by ``self.probs``.

     Computes ``log C(n, k) + k * log(p) + (n - k) * log(1 - p)`` elementwise,
     with ``n = self.total_count``, ``k = value`` and ``p = self.probs``, after
     casting ``n`` and ``k`` to the floating dtype of ``self.probs``.
     """
     if self._validate_args:
         self._validate_sample(value)
     float_dtype = get_dtypes(self.probs)[0]
     k = lax.convert_element_type(value, float_dtype)
     n = lax.convert_element_type(self.total_count, float_dtype)
     failures = n - k
     log_binom_coeff = gammaln(n + 1) - gammaln(k + 1) - gammaln(failures + 1)
     # xlogy / xlog1py presumably evaluate 0 * log(...) as 0 so p in {0, 1}
     # stays finite; confirm against their definitions.
     return (log_binom_coeff
             + xlogy(k, self.probs)
             + xlog1py(failures, -self.probs))
Beispiel #4
0
 def log_prob(self, value):
     """Binomial log-pmf for a distribution parameterized by ``self.logits``.

     The result is ``k * l - log k! - log (n - k)! - (n * softplus(l) - log n!)``.
     ``softplus(l) = log(1 + exp(l))`` is evaluated as
     ``max(l, 0) + log1p(exp(-|l|))`` so ``exp`` never sees a large positive
     argument.
     """
     if self._validate_args:
         self._validate_sample(value)
     float_dtype = get_dtypes(self.logits)[0]
     k = lax.convert_element_type(value, float_dtype)
     n = lax.convert_element_type(self.total_count, float_dtype)
     logits = self.logits
     # n * softplus(logits) - log n!, written in the overflow-safe form.
     log_normalizer = (n * np.clip(logits, 0)
                       + xlog1py(n, np.exp(-np.abs(logits)))
                       - gammaln(n + 1))
     return k * logits - gammaln(k + 1) - gammaln(n - k + 1) - log_normalizer
Beispiel #5
0
 def log_prob(self, value):
     """Bernoulli log-pmf from ``self.logits``.

     Returns the negated binary cross-entropy between the logits and
     ``value``, which is first cast to the floating dtype of ``self.logits``.
     """
     if self._validate_args:
         self._validate_sample(value)
     target = lax.convert_element_type(value, get_dtypes(self.logits)[0])
     neg_log_likelihood = binary_cross_entropy_with_logits(self.logits, target)
     return -neg_log_likelihood
Beispiel #6
0
def _to_logits_multinom(probs):
    """Convert multinomial probabilities to logits via an elementwise ``log``.

    ``log(0)`` would be ``-inf``. Every result is therefore floored at the
    most negative finite value of ``probs``' dtype, so zero-probability
    categories map to a finite logit.
    """
    floor = np.finfo(get_dtypes(probs)[0]).min
    # maximum(x, floor) is the same one-sided clamp as clip(x, a_min=floor).
    return np.maximum(np.log(probs), floor)
Beispiel #7
0
 def log_prob(self, value):
     """Poisson log-pmf of ``value`` under ``self.rate``.

     Computes ``k * log(rate) - log(k!) - rate`` with ``k = value``, after
     casting ``value`` to the floating dtype of ``self.rate``.

     The ``k * log(rate)`` term is defined as exactly 0 wherever ``k == 0``.
     Without this, a rate of exactly 0 (e.g. after underflow) gives
     ``0 * log(0) = 0 * -inf = nan`` instead of the correct ``log p(0 | 0) = 0``.
     This matches the ``xlogy`` treatment used by the binomial and multinomial
     log-pmfs. For every ``rate > 0`` the result is unchanged.
     """
     if self._validate_args:
         self._validate_sample(value)
     value = lax.convert_element_type(value, get_dtypes(self.rate)[0])
     is_zero = value == 0
     # Double-where: substitute a harmless rate (1.) where value == 0 so that
     # neither log(rate) nor its gradient value / rate produces inf/nan in the
     # branch that the outer where then discards.
     safe_rate = np.where(is_zero, 1., self.rate)
     log_rate_term = np.where(is_zero, 0., value * np.log(safe_rate))
     return log_rate_term - gammaln(value + 1) - self.rate