def testT(self, df, dtype):
    """Check ``random.t`` samples against the SciPy t-distribution CDF.

    Draws 10000 samples for the given ``df`` and ``dtype`` twice from the
    same key (``random.PRNGKey(0)``): once by calling the sampler directly
    and once through ``api.jit``. Each batch is passed to
    ``self._CheckKolmogorovSmirnovCDF`` together with
    ``scipy.stats.t(df).cdf``.

    Args:
        df: Degrees-of-freedom value forwarded to ``random.t`` and
            ``scipy.stats.t``.
        dtype: dtype forwarded to ``random.t``.
    """
    seed_key = random.PRNGKey(0)

    def draw(prng_key, dof):
        # Sample count is fixed; only the key and df vary per call.
        return random.t(prng_key, dof, (10000, ), dtype)

    jitted_draw = api.jit(draw)

    eager_batch = draw(seed_key, df)
    jitted_batch = jitted_draw(seed_key, df)

    # Both execution paths must pass the same distribution check.
    for batch in (eager_batch, jitted_batch):
        self._CheckKolmogorovSmirnovCDF(batch, scipy.stats.t(df).cdf)
def standard_t(df, size=None):
    """Draw samples with ``jr.t`` and wrap the result in a ``JaxArray``.

    Args:
        df: Degrees-of-freedom value, forwarded to ``jr.t`` as ``df``.
        size: Requested output size; converted to a shape by
            ``_size2shape``. Defaults to ``None``.

    Returns:
        A ``JaxArray`` wrapping the array returned by ``jr.t``.
    """
    # Obtain the key first, then the shape, matching the original
    # argument evaluation order.
    sample_key = DEFAULT.split_key()
    sample_shape = _size2shape(size)
    drawn = jr.t(sample_key, df=df, shape=sample_shape)
    return JaxArray(drawn)
def sample(self, rng_key, sample_shape=()):
    """Draw samples with ``random.t`` using this object's ``df``.

    The requested output shape is ``sample_shape`` followed by
    ``self.batch_shape`` and then ``self.event_shape``.

    Args:
        rng_key: Key passed to ``random.t`` as its first argument.
        sample_shape: Leading shape prepended to the batch and event
            shapes. Defaults to ``()``.

    Returns:
        The value returned by ``random.t(rng_key, self.df, shape)``.
    """
    # Concatenate left to right: sample dims, then batch dims, then event dims.
    full_shape = sample_shape + self.batch_shape
    full_shape = full_shape + self.event_shape
    return random.t(rng_key, self.df, full_shape)