Beispiel #1
0
 def _stats(self, b, moments='mv'):
     mu, mu2, g1, g2 = None, None, None, None
     if 'm' in moments:
         mask = b > 1
         bt = np.extract(mask, b)
         mu = np.where(mask, bt / (bt - 1.0), np.inf)
     if 'v' in moments:
         mask = b > 2
         bt = np.extract(mask, b)
         mu2 = np.where(mask, bt / (bt - 2.0) / (bt - 1.0) ** 2, np.inf)
     if 's' in moments:
         mask = b > 3
         bt = np.extract(mask, b)
         vals = 2 * (bt + 1.0) * np.sqrt(bt - 2.0) / ((bt - 3.0) * np.sqrt(bt))
         g1 = np.where(mask, vals, np.nan)
     if 'k' in moments:
         mask = b > 4
         bt = np.extract(mask, b)
         vals = (6.0 * np.polyval([1.0, 1.0, -6, -2], bt)
                 / np.polyval([1.0, -7.0, 12.0, 0.0], bt))
         g2 = np.where(mask, vals, np.nan)
     return mu, mu2, g1, g2
Beispiel #2
0
def extract(condition, arr):
  """Return the elements of ``arr`` where ``condition`` is true.

  Any ``JaxArray`` argument is unwrapped to its ``.value`` before the call
  to ``jnp.extract``, and the result is always wrapped in a new ``JaxArray``.
  """
  cond = condition.value if isinstance(condition, JaxArray) else condition
  data = arr.value if isinstance(arr, JaxArray) else arr
  picked = jnp.extract(cond, data)
  return JaxArray(picked)