Example 1
 def sample(self, key, sample_shape=()):
     eps = random.normal(key, shape=sample_shape + self.batch_shape)
     return self.loc + eps * self.scale
Example 2
 def __init__(self, dims, scope_var: OrderedDict):
     key = random.PRNGKey(py_random.randrange(9999))
     scope_var["log_signal_variance"] = random.normal(key, shape=[dims])
Example 3
 def testDtypeErrorMessage(self):
   with self.assertRaisesRegex(ValueError, r"dtype argument to.*"):
     random.normal(random.PRNGKey(0), (), dtype=jnp.int32)
Example 4
 def init_params(self):
     params = [
         random.normal(self.rnd_key, (self.vocab_size(), )),
         np.zeros((1, ))
     ]
     return [self.initialization_scale * p for p in params]
Example 5
 def sample(self, rng_key, sample_shape):
     shape = sample_shape + self.batch_shape + self.event_shape
     return np.exp(self.sigma * random.normal(rng_key, shape) + self.mu)
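Usage note: the sample above is log-normal because it exponentiates an affine transform of standard-normal noise. A minimal standalone sketch of the same idea (the values of mu and sigma here are illustrative only):

import jax.numpy as jnp
from jax import random

key = random.PRNGKey(0)
mu, sigma = 0.5, 0.25
samples = jnp.exp(sigma * random.normal(key, (10000,)) + mu)
# jnp.log(samples) is approximately Normal(mu, sigma**2), so these should be
# close to mu and sigma respectively.
print(jnp.log(samples).mean(), jnp.log(samples).std())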
Example 6
File: advi.py Project: zoemcc/jax
def diag_gaussian_sample(rng, mean, log_std):
    # Take a single sample from a diagonal multivariate Gaussian.
    return mean + np.exp(log_std) * random.normal(rng, mean.shape)
Example 7
import jax
import jax.numpy as jnp
from jax import random
from flax import optim  # assumed: the (legacy) flax.optim API used below

from modax.training.losses.SBL import loss_fn_SBL
from modax.training import train_max_iter
from sklearn.linear_model import ARDRegression
from modax.linear_model.SBL import SBL
# `Deepmod`, `burgers` and `create_update` used below also come from the modax
# project; their exact module paths are not shown in this excerpt.

# %% Making data
key = random.PRNGKey(42)

x = jnp.linspace(-3, 4, 50)
t = jnp.linspace(0.5, 5.0, 20)
t_grid, x_grid = jnp.meshgrid(t, x, indexing="ij")
u = burgers(x_grid, t_grid, 0.1, 1.0)

X = jnp.concatenate([t_grid.reshape(-1, 1), x_grid.reshape(-1, 1)], axis=1)
y = u.reshape(-1, 1)
y += 0.10 * jnp.std(y) * random.normal(key, y.shape)

# %% Building model and params
model = Deepmod([30, 30, 30, 1])
variables = model.init(key, X)

optimizer = optim.Adam(learning_rate=2e-3, beta1=0.99, beta2=0.99)
state, params = variables.pop("params")
optimizer = optimizer.create(params)

state = (state, {"prior_init": None})  # adding prior to state
update_fn = create_update(loss_fn_SBL, (model, X, y, False))

# optimizer, state = train_max_iter(update_fn, optimizer, state, 10000)

grad_fn = jax.value_and_grad(loss_fn_SBL, has_aux=True)
Example 8
 def transition_sample(self, x_previous: jnp.ndarray, t_previous: float,
                       t_new: float,
                       random_key: jnp.ndarray) -> jnp.ndarray:
     return self.transition_function(x_previous, t_previous, t_new) \
            + random.normal(random_key, shape=x_previous.shape) @ self.transition_covariance_sqrt
Example 9
 def initial_sample(self, t: float,
                    random_key: jnp.ndarray) -> Union[float, jnp.ndarray]:
     return random.normal(random_key, shape=(self.dim,)) @ self.initial_covariance_sqrt.T \
            + self.initial_mean
Example 10
 def init(key, shape, dtype=dtype):
     dtype = dtypes.canonicalize_dtype(dtype)
     return random.normal(key, shape, dtype) * stddev
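Usage note: dtypes.canonicalize_dtype maps the requested dtype to one the backend actually supports (for example, float64 becomes float32 unless 64-bit mode is enabled). A minimal standalone sketch of the same initializer pattern with the closure variables stddev and dtype made explicit (the values here are assumptions for illustration):

from jax import dtypes, random
import jax.numpy as jnp

stddev = 1e-2

def init(key, shape, dtype=jnp.float32):
    dtype = dtypes.canonicalize_dtype(dtype)
    return random.normal(key, shape, dtype) * stddev

w = init(random.PRNGKey(0), (3, 4))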
Example 11
 def initializer(key, shape, dtype=jnp.float32):
     x = random.normal(key, shape, dtype) * (-random_sign_init) + 1.0
     return x.astype(dtype)
Example 12
def random_orthonormal(rng, n):
    u, _, vh = jnp.linalg.svd(random.normal(rng, shape=(n, n)))
    return u @ vh
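Usage note: multiplying the U and Vh factors of the SVD of a Gaussian matrix gives an orthonormal matrix, so the result should satisfy Q @ Q.T ≈ I. A minimal check, assuming random_orthonormal as defined above is in scope:

import jax.numpy as jnp
from jax import random

q = random_orthonormal(random.PRNGKey(0), 4)
print(jnp.allclose(q @ q.T, jnp.eye(4), atol=1e-4))  # expected: True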
Example 13
def random_psd(rng, n):
    x = random.normal(rng, shape=(n, n))
    return x.T @ x
Example 14
def reparameterize(rng, mean, logvar):
    std = jnp.exp(0.5 * logvar)
    eps = random.normal(rng, logvar.shape)
    return mean + eps * std
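Usage note: this is the VAE reparameterization trick; the noise eps is independent of mean and logvar, so gradients flow through both parameters. A minimal sketch, assuming the reparameterize function above is in scope:

import jax
import jax.numpy as jnp
from jax import random

mean, logvar = jnp.zeros(3), jnp.zeros(3)
z = reparameterize(random.PRNGKey(0), mean, logvar)
# Differentiating through the sample w.r.t. the distribution parameters:
g = jax.grad(lambda m: jnp.sum(reparameterize(random.PRNGKey(0), m, logvar) ** 2))(mean)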
Example 15
    def testGpInference(self):
        reg = 1e-5
        key = random.PRNGKey(1)
        x_train = random.normal(key, (4, 2))
        init_fn, apply_fn, kernel_fn_analytic = stax.serial(
            stax.Dense(32, 2., 0.5), stax.Relu(), stax.Dense(10, 2., 0.5))
        y_train = random.normal(key, (4, 10))
        for kernel_fn_is_analytic in [True, False]:
            if kernel_fn_is_analytic:
                kernel_fn = kernel_fn_analytic
            else:
                _, params = init_fn(key, x_train.shape)
                kernel_fn_empirical = empirical.empirical_kernel_fn(apply_fn)

                def kernel_fn(x1, x2, get):
                    return kernel_fn_empirical(x1, x2, get, params)

            for get in [
                    None, 'nngp', 'ntk', ('nngp', ), ('ntk', ),
                ('nngp', 'ntk'), ('ntk', 'nngp')
            ]:
                k_dd = kernel_fn(x_train, None, get)

                gp_inference = predict.gp_inference(k_dd,
                                                    y_train,
                                                    diag_reg=reg)
                gd_ensemble = predict.gradient_descent_mse_ensemble(
                    kernel_fn, x_train, y_train, diag_reg=reg)
                for x_test in [None, 'x_test']:
                    x_test = None if x_test is None else random.normal(
                        key, (8, 2))
                    k_td = None if x_test is None else kernel_fn(
                        x_test, x_train, get)

                    for compute_cov in [True, False]:
                        with self.subTest(
                                kernel_fn_is_analytic=kernel_fn_is_analytic,
                                get=get,
                                x_test=x_test if x_test is None else 'x_test',
                                compute_cov=compute_cov):
                            if compute_cov:
                                nngp_tt = (True if x_test is None else
                                           kernel_fn(x_test, None, 'nngp'))
                            else:
                                nngp_tt = None

                            out_ens = gd_ensemble(None, x_test, get,
                                                  compute_cov)
                            out_ens_inf = gd_ensemble(np.inf, x_test, get,
                                                      compute_cov)
                            self._assertAllClose(out_ens_inf, out_ens, 0.08)

                            if (get is not None and 'nngp' not in get
                                    and compute_cov and k_td is not None):
                                with self.assertRaises(ValueError):
                                    out_gp_inf = gp_inference(
                                        get=get,
                                        k_test_train=k_td,
                                        nngp_test_test=nngp_tt)
                            else:
                                out_gp_inf = gp_inference(
                                    get=get,
                                    k_test_train=k_td,
                                    nngp_test_test=nngp_tt)
                                self.assertAllClose(out_ens, out_gp_inf)
Example 16
  def test_fan_in_fc(self, same_inputs, axis, n_branches, get, branch_in):
    if axis in (None, 0) and branch_in == 'dense_after_branch_in':
      raise jtu.SkipTest('`FanInSum` and `FanInConcat(0)` '
                         'require `is_gaussian`.')

    if axis == 1 and branch_in == 'dense_before_branch_in':
      raise jtu.SkipTest('`FanInConcat` on feature axis requires a dense layer '
                         'after concatenation.')

    key = random.PRNGKey(1)
    X0_1 = random.normal(key, (10, 20))
    X0_2 = None if same_inputs else random.normal(key, (8, 20))

    if xla_bridge.get_backend().platform == 'tpu':
      width = 2048
      n_samples = 1024
      tol = 0.02
    else:
      width = 1024
      n_samples = 256
      tol = 0.01

    dense = stax.Dense(width, 1.25, 0.1)
    input_layers = [dense,
                    stax.FanOut(n_branches)]

    branches = []
    for b in range(n_branches):
      branch_layers = [FanInTest._get_phi(b)]
      for i in range(b):
        branch_layers += [
            stax.Dense(width, 1. + 2 * i, 0.5 + i),
            FanInTest._get_phi(i)]

      if branch_in == 'dense_before_branch_in':
        branch_layers += [dense]
      branches += [stax.serial(*branch_layers)]

    output_layers = [
        stax.FanInSum() if axis is None else stax.FanInConcat(axis),
        stax.Relu()
    ]
    if branch_in == 'dense_after_branch_in':
      output_layers.insert(1, dense)

    nn = stax.serial(*(input_layers + [stax.parallel(*branches)] +
                       output_layers))

    if get == 'nngp':
      init_fn, apply_fn, kernel_fn = nn
    elif get == 'ntk':
      init_fn, apply_fn, kernel_fn = stax.serial(nn, stax.Dense(1, 1.25, 0.5))
    else:
      raise ValueError(get)

    kernel_fn_mc = monte_carlo.monte_carlo_kernel_fn(
        init_fn, apply_fn, key, n_samples, device_count=0)

    exact = kernel_fn(X0_1, X0_2, get=get)
    empirical = kernel_fn_mc(X0_1, X0_2, get=get)
    empirical = empirical.reshape(exact.shape)
    utils.assert_close_matrices(self, empirical, exact, tol)
Example 17
    def testPredictND(self):
        n_chan = 6
        key = random.PRNGKey(1)
        im_shape = (5, 4, 3)
        n_train = 2
        n_test = 2
        x_train = random.normal(key, (n_train, ) + im_shape)
        y_train = random.uniform(key, (n_train, 3, 2, n_chan))
        init_fn, apply_fn, _ = stax.Conv(n_chan, (3, 2), (1, 2))
        _, params = init_fn(key, x_train.shape)
        fx_train_0 = apply_fn(params, x_train)

        for trace_axes in [(), (-1, ), (-2, ), (-3, ), (0, 1), (2, 3), (2, ),
                           (1, 3), (0, -1), (0, 0, -3), (0, 1, 2, 3),
                           (0, 1, -1, 2)]:
            for ts in [None, np.arange(6).reshape((2, 3))]:
                for x in [None, 'x_test']:
                    with self.subTest(trace_axes=trace_axes, ts=ts, x=x):
                        t_shape = ts.shape if ts is not None else ()
                        y_test_shape = t_shape + (n_test, ) + y_train.shape[1:]
                        y_train_shape = t_shape + y_train.shape
                        x = x if x is None else random.normal(
                            key, (n_test, ) + im_shape)
                        fx_test_0 = None if x is None else apply_fn(params, x)

                        kernel_fn = empirical.empirical_kernel_fn(
                            apply_fn, trace_axes=trace_axes)

                        # TODO(romann): investigate the SIGTERM error on CPU.
                        # kernel_fn = jit(kernel_fn, static_argnums=(2,))
                        ntk_train_train = kernel_fn(x_train, None, 'ntk',
                                                    params)
                        if x is not None:
                            ntk_test_train = kernel_fn(x, x_train, 'ntk',
                                                       params)

                        loss = lambda x, y: 0.5 * np.mean(x - y)**2
                        predict_fn_mse = predict.gradient_descent_mse(
                            ntk_train_train, y_train, trace_axes=trace_axes)

                        predict_fn_mse_ensemble = predict.gradient_descent_mse_ensemble(
                            kernel_fn,
                            x_train,
                            y_train,
                            trace_axes=trace_axes,
                            params=params)

                        if x is None:
                            p_train_mse = predict_fn_mse(ts, fx_train_0)
                        else:
                            p_train_mse, p_test_mse = predict_fn_mse(
                                ts, fx_train_0, fx_test_0, ntk_test_train)
                            self.assertAllClose(y_test_shape, p_test_mse.shape)
                        self.assertAllClose(y_train_shape, p_train_mse.shape)

                        p_nngp_mse_ens, p_ntk_mse_ens = predict_fn_mse_ensemble(
                            ts, x, ('nngp', 'ntk'), compute_cov=True)
                        ref_shape = y_train_shape if x is None else y_test_shape
                        self.assertAllClose(ref_shape,
                                            p_ntk_mse_ens.mean.shape)
                        self.assertAllClose(ref_shape,
                                            p_nngp_mse_ens.mean.shape)

                        if ts is not None:
                            predict_fn = predict.gradient_descent(
                                loss,
                                ntk_train_train,
                                y_train,
                                trace_axes=trace_axes)

                            if x is None:
                                p_train = predict_fn(ts, fx_train_0)
                            else:
                                p_train, p_test = predict_fn(
                                    ts, fx_train_0, fx_test_0, ntk_test_train)
                                self.assertAllClose(y_test_shape, p_test.shape)
                            self.assertAllClose(y_train_shape, p_train.shape)
Example 18
  def test_fan_in_conv(self,
                       same_inputs,
                       axis,
                       n_branches,
                       get,
                       branch_in,
                       readout):
    if xla_bridge.get_backend().platform == 'cpu':
      raise jtu.SkipTest('Not running CNNs on CPU to save time.')

    if axis in (None, 0, 1, 2) and branch_in == 'dense_after_branch_in':
      raise jtu.SkipTest('`FanInSum` and `FanInConcat(0/1/2)` '
                         'require `is_gaussian`.')

    if axis == 3 and branch_in == 'dense_before_branch_in':
      raise jtu.SkipTest('`FanInConcat` on feature axis requires a dense layer '
                         'after concatenation.')

    key = random.PRNGKey(1)
    X0_1 = random.normal(key, (2, 5, 6, 3))
    X0_2 = None if same_inputs else random.normal(key, (3, 5, 6, 3))

    if xla_bridge.get_backend().platform == 'tpu':
      width = 2048
      n_samples = 1024
      tol = 0.02
    else:
      width = 1024
      n_samples = 512
      tol = 0.01

    conv = stax.Conv(out_chan=width,
                     filter_shape=(3, 3),
                     padding='SAME',
                     W_std=1.25,
                     b_std=0.1)

    input_layers = [conv,
                    stax.FanOut(n_branches)]

    branches = []
    for b in range(n_branches):
      branch_layers = [FanInTest._get_phi(b)]
      for i in range(b):
        branch_layers += [
            stax.Conv(
                out_chan=width,
                filter_shape=(i + 1, 4 - i),
                padding='SAME',
                W_std=1.25 + i,
                b_std=0.1 + i),
            FanInTest._get_phi(i)]

      if branch_in == 'dense_before_branch_in':
        branch_layers += [conv]
      branches += [stax.serial(*branch_layers)]

    output_layers = [
        stax.FanInSum() if axis is None else stax.FanInConcat(axis),
        stax.Relu(),
        stax.GlobalAvgPool() if readout == 'pool' else stax.Flatten()
    ]
    if branch_in == 'dense_after_branch_in':
      output_layers.insert(1, conv)

    nn = stax.serial(*(input_layers + [stax.parallel(*branches)] +
                       output_layers))

    init_fn, apply_fn, kernel_fn = stax.serial(
        nn, stax.Dense(1 if get == 'ntk' else width, 1.25, 0.5))

    kernel_fn_mc = monte_carlo.monte_carlo_kernel_fn(
        init_fn,
        apply_fn,
        key,
        n_samples,
        device_count=0 if axis in (0, -4) else -1)

    exact = kernel_fn(X0_1, X0_2, get=get)
    empirical = kernel_fn_mc(X0_1, X0_2, get=get)
    empirical = empirical.reshape(exact.shape)
    utils.assert_close_matrices(self, empirical, exact, tol)
Example 19
 def test_create_model(self):
     variables = train.initialized(random.PRNGKey(0), 224)
     x = random.normal(random.PRNGKey(1), (8, 224, 224, 3))
     y = train.model(train=False).apply(variables, x)
     self.assertEqual(y.shape, (8, 1000))
Example 20
        ll = -0.5 * jnp.square(target - p).sum(axis=(-1,
                                                     -2)) / jnp.square(scale)
        return ll

    return log_density


rng = random.PRNGKey(args.seed)
rng, rng_data, rng_ortho, rng_noise = random.split(rng, 4)
rng, rng_haar, rng_acc = random.split(rng, 3)
rng, rng_deq, rng_bij = random.split(rng, 3)
rng, rng_train = random.split(rng, 2)
rng, rng_amb, rng_mse, rng_kl = random.split(rng, 4)

num_dims = 3
data = random.normal(rng_data, [10, num_dims])
O = pd.orthogonal.haar.rvs(rng_ortho, 1, num_dims)[0]
noise = args.noise_scale * random.normal(rng_noise, data.shape)
target = data @ O.T + noise
log_density = log_density_factory(data, target, args.noise_scale)
U, _, VT = jnp.linalg.svd(data.T @ target)
Oml = (U @ VT).T

xhaar = pd.orthogonal.haar.rvs(rng_haar, 10000000, num_dims)
lprop = pd.orthogonal.haar.logpdf(xhaar)
ld = log_density(xhaar)
lm = -lprop[0] + log_density(Oml)
la = ld - lprop - lm
logu = jnp.log(random.uniform(rng_acc, [len(xhaar)]))
xobs = xhaar[logu < la]
print('number of rejection samples: {}'.format(len(xobs)))
Example 21
    def test_vmap_axes(self, same_inputs):
        n1, n2 = 3, 4
        c1, c2, c3 = 9, 5, 7
        h2, h3, w3 = 6, 8, 2

        def get_x(n, k):
            k1, k2, k3 = random.split(k, 3)
            x1 = random.normal(k1, (n, c1))
            x2 = random.normal(k2, (h2, n, c2))
            x3 = random.normal(k3, (c3, w3, n, h3))
            x = [(x1, x2), x3]
            return x

        x1 = get_x(n1, random.PRNGKey(1))
        x2 = get_x(n2, random.PRNGKey(2)) if not same_inputs else None

        p1 = random.normal(random.PRNGKey(5), (n1, h2, h2))
        p2 = None if same_inputs else random.normal(random.PRNGKey(6),
                                                    (n2, h2, h2))

        init_fn, apply_fn, _ = stax.serial(
            stax.parallel(
                stax.parallel(
                    stax.serial(stax.Dense(4, 2., 0.1), stax.Relu(),
                                stax.Dense(3, 1., 0.15)),  # 1
                    stax.serial(
                        stax.Conv(7, (2, ),
                                  padding='SAME',
                                  dimension_numbers=('HNC', 'OIH', 'NHC')),
                        stax.Erf(), stax.Aggregate(1, 0, -1),
                        stax.GlobalAvgPool(), stax.Dense(3, 0.5, 0.2)),  # 2
                ),
                stax.serial(
                    stax.Conv(5, (2, 3),
                              padding='SAME',
                              dimension_numbers=('CWNH', 'IOHW', 'HWCN')),
                    stax.Sin(),
                )  # 3
            ),
            stax.parallel(
                stax.FanInSum(),
                stax.Conv(2, (2, 1),
                          dimension_numbers=('HWCN', 'OIHW', 'HNWC'))))

        _, params = init_fn(random.PRNGKey(3), tree_map(np.shape, x1))
        implicit = jit(nt.empirical_ntk_fn(apply_fn, implementation=2))
        direct = jit(nt.empirical_ntk_fn(apply_fn, implementation=1))

        implicit_batched = jit(
            nt.empirical_ntk_fn(apply_fn,
                                vmap_axes=([(0, 1), 2], [-2,
                                                         -3], dict(pattern=0)),
                                implementation=2))
        direct_batched = jit(
            nt.empirical_ntk_fn(apply_fn,
                                vmap_axes=([(-2, -2),
                                            -2], [0, 1], dict(pattern=-3)),
                                implementation=1))

        k = direct(x1, x2, params, pattern=(p1, p2))

        self.assertAllClose(k, implicit(x1, x2, params, pattern=(p1, p2)))
        self.assertAllClose(k, direct_batched(x1, x2, params,
                                              pattern=(p1, p2)))
        self.assertAllClose(k,
                            implicit_batched(x1, x2, params, pattern=(p1, p2)))
Example 22
 def get_2d_array(self):
   return random.normal(self.key, (self.nb_rows, self.nb_columns))
Example 23
 def random_normal(self, shape):
     self.rnd_key, subkey = jax.random.split(self.rnd_key)
     return random.normal(subkey,
                          shape,
                          dtype=to_numpy_dtype(self.float_type))
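Usage note: the snippet above splits the stored key before each draw, so repeated calls produce fresh samples instead of re-yielding the same values. A minimal sketch of the same pattern outside a class:

from jax import random

key = random.PRNGKey(0)
key, subkey = random.split(key)
a = random.normal(subkey, (2,))
key, subkey = random.split(key)
b = random.normal(subkey, (2,))
# `a` and `b` differ; reusing one key for both calls would make them identical.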
Example 24
 def update_fn(self, rng, x, t):
     f, G = self.rsde.discretize(x, t)
     z = random.normal(rng, x.shape)
     x_mean = x - f
     x = x_mean + batch_mul(G, z)
     return x, x_mean
Example 25
def sample_diag_gaussian(mu, log_std, subkey):
    """Reparameterization trick for getting z from x."""
    return random.normal(subkey, mu.shape) * np.exp(log_std) + mu
Example 26
File: stax.py Project: tonyduan/jax
 def init(rng, shape):
     std = lax.convert_element_type(stddev, np.float32)
     return std * random.normal(rng, shape, dtype=np.float32)
Example 27
 def testNormalBfloat16(self):
   # Passing bfloat16 as dtype string.
   # https://github.com/google/jax/issues/6813
   res_bfloat16_str = random.normal(random.PRNGKey(0), dtype='bfloat16')
   res_bfloat16 = random.normal(random.PRNGKey(0), dtype=jnp.bfloat16)
   self.assertAllClose(res_bfloat16, res_bfloat16_str)
Example 28
File: stax.py Project: tonyduan/jax
 def init(rng, shape):
     fan_in, fan_out = shape[in_axis], shape[out_axis]
     size = onp.prod(onp.delete(shape, [in_axis, out_axis]))
     std = scale / np.sqrt((fan_in + fan_out) / 2. * size)
     std = lax.convert_element_type(std, np.float32)
     return std * random.normal(rng, shape, dtype=np.float32)
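Usage note: this looks like a variance-scaling (Glorot-style) initializer; scale, in_axis and out_axis are closure variables of the surrounding initializer factory. A minimal standalone sketch with those variables made explicit (the default values here are assumptions for illustration):

import numpy as onp
import jax.numpy as np
from jax import random

def make_variance_scaling_init(scale=1.0, in_axis=-2, out_axis=-1):
    def init(rng, shape):
        fan_in, fan_out = shape[in_axis], shape[out_axis]
        size = onp.prod(onp.delete(shape, [in_axis, out_axis]))
        std = scale / np.sqrt((fan_in + fan_out) / 2. * size)
        return std * random.normal(rng, shape, dtype=np.float32)
    return init

W = make_variance_scaling_init()(random.PRNGKey(0), (128, 64))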
Example 29
 def f(x):
   return random.normal(random.PRNGKey(x), (int(1e12),))
Example 30
 def sample(self, key, sample_shape=()):
     eps = random.normal(key,
                         shape=sample_shape + self.batch_shape +
                         self.event_shape)
     return self.loc + np.squeeze(
         np.matmul(self.scale_tril, eps[..., np.newaxis]), axis=-1)
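Usage note: the sample above applies a lower-triangular factor scale_tril to standard-normal noise, which is the standard way to draw correlated Gaussian samples. A minimal standalone sketch where scale_tril comes from a Cholesky factorization (the covariance matrix here is illustrative only):

import jax.numpy as jnp
from jax import random

loc = jnp.zeros(3)
cov = jnp.array([[2.0, 0.5, 0.0],
                 [0.5, 1.0, 0.3],
                 [0.0, 0.3, 1.5]])
scale_tril = jnp.linalg.cholesky(cov)
eps = random.normal(random.PRNGKey(0), shape=(10000, 3))
samples = loc + eps @ scale_tril.T
# The empirical covariance of `samples` approximates `cov`.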