def _stats(self, b, moments='mv'): mu, mu2, g1, g2 = None, None, None, None if 'm' in moments: mask = b > 1 bt = np.extract(mask, b) mu = np.where(mask, bt / (bt - 1.0), np.inf) if 'v' in moments: mask = b > 2 bt = np.extract(mask, b) mu2 = np.where(mask, bt / (bt - 2.0) / (bt - 1.0) ** 2, np.inf) if 's' in moments: mask = b > 3 bt = np.extract(mask, b) vals = 2 * (bt + 1.0) * np.sqrt(bt - 2.0) / ((bt - 3.0) * np.sqrt(bt)) g1 = np.where(mask, vals, np.nan) if 'k' in moments: mask = b > 4 bt = np.extract(mask, b) vals = (6.0 * np.polyval([1.0, 1.0, -6, -2], bt) / np.polyval([1.0, -7.0, 12.0, 0.0], bt)) g2 = np.where(mask, vals, np.nan) return mu, mu2, g1, g2
def extract(condition, arr):
    """Return the elements of ``arr`` where ``condition`` is true, as a JaxArray.

    Either argument may be a ``JaxArray``. In that case its ``.value`` is
    passed to ``jnp.extract``; any other object is passed through unchanged.
    """
    def _raw(obj):
        # Peel the JaxArray wrapper off so jnp receives the underlying array.
        return obj.value if isinstance(obj, JaxArray) else obj

    return JaxArray(jnp.extract(_raw(condition), _raw(arr)))