예제 #1
0
    def testT(self, df, dtype):
        """Check ``random.t`` draws against scipy's Student's t CDF.

        The same PRNG key is used to draw 10000 samples twice: once eagerly
        and once through an ``api.jit``-compiled wrapper. Each batch is
        passed to ``self._CheckKolmogorovSmirnovCDF`` together with
        ``scipy.stats.t(df).cdf``.
        """
        def draw(prng_key, dof):
            # dtype is captured from the enclosing scope, not traced.
            return random.t(prng_key, dof, (10000, ), dtype)

        jitted_draw = api.jit(draw)
        prng_key = random.PRNGKey(0)

        eager_batch = draw(prng_key, df)
        jitted_batch = jitted_draw(prng_key, df)

        reference_cdf = scipy.stats.t(df).cdf
        for batch in (eager_batch, jitted_batch):
            self._CheckKolmogorovSmirnovCDF(batch, reference_cdf)
예제 #2
0
def standard_t(df, size=None):
    """Draw samples from a standard Student's t distribution.

    Args:
      df: degrees of freedom, forwarded to ``jr.t``.
      size: requested output size; converted to a shape by ``_size2shape``.

    Returns:
      The samples produced by ``jr.t``, wrapped in a ``JaxArray``.

    NOTE(review): when ``size`` is None the shape comes from
    ``_size2shape(None)``; if that is ``()``, an array-valued ``df`` may be
    rejected by ``jr.t`` (shape must broadcast with df) — confirm.
    """
    # Split first, as the original did, so the key stream advances identically.
    key = DEFAULT.split_key()
    out_shape = _size2shape(size)
    samples = jr.t(key, df=df, shape=out_shape)
    return JaxArray(samples)
예제 #3
0
 def sample(self, rng_key, sample_shape=()):
     """Draw Student's t samples with ``self.df`` degrees of freedom.

     Args:
       rng_key: PRNG key forwarded to ``random.t``.
       sample_shape: leading sample dimensions. Any sequence of ints
         (tuple, list, ...) is accepted.

     Returns:
       Samples of shape ``sample_shape + self.batch_shape + self.event_shape``.
       ``self.df`` must be broadcast-compatible with that shape.
     """
     # Normalise every component to a tuple so that, e.g., a list
     # sample_shape does not raise TypeError on tuple concatenation.
     shape = (
         tuple(sample_shape)
         + tuple(self.batch_shape)
         + tuple(self.event_shape)
     )
     return random.t(rng_key, self.df, shape)