Example #1
    def testLaplace(self, dtype):
        key = random.PRNGKey(0)
        rand = lambda key: random.laplace(key, (10000, ), dtype)
        crand = api.jit(rand)

        uncompiled_samples = rand(key)
        compiled_samples = crand(key)

        for samples in [uncompiled_samples, compiled_samples]:
            self._CheckKolmogorovSmirnovCDF(samples, scipy.stats.laplace().cdf)
Example #2
  def testLaplace(self, dtype):
    if jtu.device_under_test() == "tpu" and jnp.dtype(dtype).itemsize < 3:
      raise SkipTest("random.laplace() not supported on TPU for 16-bit types.")
    key = random.PRNGKey(0)
    rand = lambda key: random.laplace(key, (10000,), dtype)
    crand = api.jit(rand)

    uncompiled_samples = rand(key)
    compiled_samples = crand(key)

    for samples in [uncompiled_samples, compiled_samples]:
      self._CheckKolmogorovSmirnovCDF(samples, scipy.stats.laplace().cdf)
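Both tests above rely on the suite-internal helper _CheckKolmogorovSmirnovCDF. As a minimal, self-contained sketch of the same kind of check using scipy.stats.kstest directly (the 0.01 significance threshold is an assumption, not a value taken from the JAX test suite):

import numpy as np
import scipy.stats
from jax import random

key = random.PRNGKey(0)
samples = np.asarray(random.laplace(key, (10000,)))  # standard Laplace(0, 1) draws
# Kolmogorov-Smirnov test of the samples against the Laplace CDF.
statistic, p_value = scipy.stats.kstest(samples, scipy.stats.laplace().cdf)
assert p_value > 0.01, f"samples do not look Laplace(0, 1) (p={p_value:.4f})"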
Example #3
def sample_momentum(rng, state, n_disc):
    """Draw Gaussian momentum for the continuous coordinates and Laplace
    momentum for the last `n_disc` (discrete) coordinates of `state`."""
    state = tree_util.tree_map(np.asarray, state)

    rngn, rngl = random.split(rng)
    s_cont, s_disc = state[:-n_disc], state[-n_disc:]
    rngn = utils.split_rng_as(rngn, s_cont)
    rngl = utils.split_rng_as(rngl, s_disc)

    # One key per leaf: Gaussian momentum for the continuous part,
    # Laplace momentum for the discrete part.
    p_cont = tree_util.tree_multimap(
        lambda s, r: random.normal(r, shape=s.shape), s_cont, rngn)
    p_disc = tree_util.tree_multimap(
        lambda s, r: random.laplace(r, shape=s.shape), s_disc, rngl)

    return p_cont + p_disc
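The function above depends on the project helper utils.split_rng_as. A rough, self-contained sketch of the same per-leaf keying idea (the example state shapes are made up):

import jax.numpy as jnp
from jax import random, tree_util

key = random.PRNGKey(0)
s_disc = [jnp.zeros((3,)), jnp.zeros((2, 2))]            # illustrative discrete state
leaves, treedef = tree_util.tree_flatten(s_disc)
keys = tree_util.tree_unflatten(treedef, list(random.split(key, len(leaves))))
# One independent Laplace momentum draw per pytree leaf.
p_disc = tree_util.tree_map(
    lambda s, r: random.laplace(r, shape=s.shape), s_disc, keys)
print([p.shape for p in p_disc])                         # [(3,), (2, 2)]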
Example #4
File: laplace.py Project: riversdark/mcx
 def sample(self, rng_key, sample_shape=()):
     shape = sample_shape + self.batch_shape + self.event_shape
     # loc + scale * standard Laplace draws of the requested shape.
     return self.loc + self.scale * random.laplace(rng_key, shape)
Example #5
    return data, rand_key


# Signal
M = 2  ## Don't change
N = N_train
x_test = np.float32(np.linspace(0, 1., N * M, endpoint=False))
x_train = x_test[::M]

search_vals = 2.**np.linspace(-5., 4., 1 * 8 + 1)

bval_generators = {
    'gaussian': (32, lambda key, sc, N: random.normal(key, [N]) * sc),
    'unif': (64, lambda key, sc, N: random.uniform(key, [N]) * sc),
    'power1': (80, lambda key, sc, N: (sc**random.uniform(key, [N]))),
    'laplace': (20, lambda key, sc, N: random.laplace(key, [N]) * sc),
}

names = list(bval_generators.keys())
train_fn = lambda s, key, ab: train_model_lite(
    key, network_size, learning_rate, sgd_iters, (x_train, s[::2]),
    (x_test[1::2], s[1::2]), optimizers.adam, ab)

best_powers = [.4, .75, 1.5]
outputs_meta = []
dense_meta = []
s_lists = []
# print(tqdm(zip(data_powers, best_powers)))

for p, bp in zip(data_powers, best_powers):
    s_list, rand_key = data_maker(rand_key, N * M, N_test_signals, p)
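A small, standalone sketch of how one entry of bval_generators is applied (the key, scale sc, and length below are illustrative assumptions):

from jax import random

key = random.PRNGKey(0)
sc, n = 4.0, 20
# The 'laplace' generator: Laplace(0, 1) draws rescaled by the search value sc.
bvals = random.laplace(key, [n]) * sc
print(bvals.shape)                                       # (20,)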
Example #6
 def sample(self, key, sample_shape=()):
     eps = random.laplace(key,
                          shape=sample_shape + self.batch_shape +
                          self.event_shape)
     return self.loc + eps * self.scale
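Examples #4 and #6 both use the location-scale property: a draw from Laplace(loc, scale) is loc + scale * Laplace(0, 1). A quick sanity-check sketch (the loc and scale values are arbitrary):

import jax.numpy as jnp
from jax import random

key = random.PRNGKey(42)
loc, scale = 2.0, 0.5
eps = random.laplace(key, shape=(100000,))
samples = loc + scale * eps
# For Laplace(loc, scale): mean = loc, variance = 2 * scale**2.
print(jnp.mean(samples))                                 # close to 2.0
print(jnp.var(samples))                                  # close to 0.5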
Example #7
def laplace(loc=0.0, scale=1.0, size=None):
    # This wrapper only supports the standard Laplace(0, 1) distribution.
    assert loc == 0.
    assert scale == 1.
    return JaxArray(jr.laplace(DEFAULT.split_key(), shape=_size2shape(size)))
Example #8
def lnmax(rng, votes, per_example_epsilon):
  """LNMax: Discovering frequent patterns in sensitive data, Bhaskar et al."""
  votes = votes + (1 / (per_example_epsilon / 2)) * random.laplace(rng, votes.shape)
  return np.argmax(votes, axis=1)
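A self-contained sketch of the noisy-argmax call above; the vote counts and epsilon are made-up inputs:

import jax.numpy as jnp
from jax import random

rng = random.PRNGKey(0)
votes = jnp.array([[10., 3., 7.], [2., 8., 8.]])         # (n_examples, n_classes)
per_example_epsilon = 1.0
# Laplace noise with scale 2 / epsilon, then a per-example argmax.
noisy = votes + (1 / (per_example_epsilon / 2)) * random.laplace(rng, votes.shape)
print(jnp.argmax(noisy, axis=1))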