def testLaplace(self, dtype): key = random.PRNGKey(0) rand = lambda key: random.laplace(key, (10000, ), dtype) crand = api.jit(rand) uncompiled_samples = rand(key) compiled_samples = crand(key) for samples in [uncompiled_samples, compiled_samples]: self._CheckKolmogorovSmirnovCDF(samples, scipy.stats.laplace().cdf)
def testLaplace(self, dtype):
    """KS-test eager and jitted `random.laplace` draws against the Laplace(0, 1) CDF.

    Skipped on TPU for dtypes narrower than 3 bytes (e.g. 16-bit types).
    """
    on_tpu = jtu.device_under_test() == "tpu"
    if on_tpu and jnp.dtype(dtype).itemsize < 3:
        raise SkipTest("random.laplace() not supported on TPU for 16-bit types.")

    seed = random.PRNGKey(0)

    def draw(k):
        return random.laplace(k, (10000,), dtype)

    draw_jit = api.jit(draw)
    eager_samples = draw(seed)
    jitted_samples = draw_jit(seed)
    reference_cdf = scipy.stats.laplace().cdf
    for batch in (eager_samples, jitted_samples):
        self._CheckKolmogorovSmirnovCDF(batch, reference_cdf)
def sample_momentum(rng, state, n_disc):
    """Draw momenta for a state split into continuous and discontinuous parts.

    The first ``len(state) - n_disc`` entries of ``state`` get standard-normal
    momenta; the last ``n_disc`` entries get standard-Laplace momenta.

    Args:
      rng: JAX PRNG key.
      state: sequence of (pytrees of) arrays; it is sliced and the two parts
        are concatenated with ``+``, so presumably a list or tuple.
      n_disc: number of trailing entries of ``state`` treated as discontinuous.
        May be 0.

    Returns:
      ``p_cont + p_disc``: momenta with the same leaf shapes as ``state``.
    """
    state = tree_util.tree_map(np.asarray, state)
    rng_normal, rng_laplace = random.split(rng)
    # Split at an explicit index: state[:-n_disc] would be empty (and
    # state[-n_disc:] would be everything) when n_disc == 0.
    split = len(state) - n_disc
    s_cont, s_disc = state[:split], state[split:]
    # NOTE(review): assumes split_rng_as returns a tree of keys matching the
    # structure of its second argument — confirm in utils.
    rng_normal = utils.split_rng_as(rng_normal, s_cont)
    rng_laplace = utils.split_rng_as(rng_laplace, s_disc)
    # tree_map accepts extra trees positionally; tree_multimap was removed
    # from JAX.
    p_cont = tree_util.tree_map(
        lambda s, r: random.normal(r, shape=s.shape), s_cont, rng_normal)
    p_disc = tree_util.tree_map(
        lambda s, r: random.laplace(r, shape=s.shape), s_disc, rng_laplace)
    return p_cont + p_disc
def sample(self, rng_key, sample_shape=()):
    """Draw samples from ``Laplace(self.loc, self.scale)``.

    Args:
      rng_key: JAX PRNG key.
      sample_shape: leading shape of independent draws.

    Returns:
      Array of shape ``sample_shape + batch_shape + event_shape``.
    """
    shape = sample_shape + self.batch_shape + self.event_shape
    # jax.random.laplace(key, shape, dtype) draws from the standard
    # Laplace(0, 1). Location and scale are applied afterwards via the affine
    # transform; they must not be passed positionally as shape/dtype.
    return self.loc + self.scale * random.laplace(rng_key, shape)
return data, rand_key # Signal M = 2 ## Dont change N = N_train x_test = np.float32(np.linspace(0, 1., N * M, endpoint=False)) x_train = x_test[::M] search_vals = 2.**np.linspace(-5., 4., 1 * 8 + 1) bval_generators = { 'gaussian': (32, lambda key, sc, N: random.normal(key, [N]) * sc), 'unif': (64, lambda key, sc, N: random.uniform(key, [N]) * sc), 'power1': (80, lambda key, sc, N: (sc**random.uniform(key, [N]))), 'laplace': (20, lambda key, sc, N: random.laplace(key, [N]) * sc), } names = list(bval_generators.keys()) train_fn = lambda s, key, ab: train_model_lite( key, network_size, learning_rate, sgd_iters, (x_train, s[::2]), (x_test[1::2], s[1::2]), optimizers.adam, ab) best_powers = [.4, .75, 1.5] outputs_meta = [] dense_meta = [] s_lists = [] # print(tqdm(zip(data_powers, best_powers))) for p, bp in zip(data_powers, best_powers): s_list, rand_key = data_maker(rand_key, N * M, N_test_signals, p)
def sample(self, key, sample_shape=()):
    """Return ``loc + scale * X`` where X is standard-Laplace noise.

    The output shape is ``sample_shape + batch_shape + event_shape``.
    """
    full_shape = sample_shape + self.batch_shape + self.event_shape
    standard = random.laplace(key, shape=full_shape)
    return self.loc + standard * self.scale
def laplace(loc=0.0, scale=1.0, size=None):
    """Draw samples from a Laplace (double exponential) distribution.

    Follows ``numpy.random.laplace``: the result is ``loc + scale * X`` with X
    drawn from the standard Laplace(0, 1) distribution via ``jax.random``.

    Args:
      loc: location of the distribution (scalar or array-like).
      scale: scale of the distribution (scalar or array-like).
      size: output shape. When None, the output takes the broadcast shape of
        ``loc`` and ``scale`` (a scalar for the defaults).

    Returns:
      JaxArray of samples.
    """
    import numpy as np

    shape = _size2shape(size)
    if size is None:
        # Without an explicit size, give each element of an array-valued
        # loc/scale its own independent draw instead of one shared sample.
        shape = np.broadcast_shapes(shape, np.shape(loc), np.shape(scale))
    samples = jr.laplace(DEFAULT.split_key(), shape=shape)
    # NOTE(review): assumes loc/scale support arithmetic with jax arrays
    # (true for Python scalars, NumPy and JAX arrays) — confirm for JaxArray.
    return JaxArray(loc + scale * samples)
def lnmax(rng, votes, per_example_epsilon):
    """Report-noisy-max over per-example vote counts (LNMax).

    From "LNMax: Discovering frequent patterns in sensitive data",
    Bhaskar et al. Laplace noise with scale ``2 / per_example_epsilon`` is
    added to every count, and the argmax over the last axis of each row is
    returned.
    """
    noise_scale = 1 / (per_example_epsilon / 2)
    noise = random.laplace(rng, votes.shape)
    noisy_votes = votes + noise_scale * noise
    return np.argmax(noisy_votes, axis=1)