def test_binary(
    self,
    primitive: Optional[Primitive],
    shape1,
    shape2,
    dtype,
    params
):
  # TODO(romann): revisit when bugs below are fixed.
  if primitive == lax.conv_general_dilated_p:
    if jax.default_backend() == 'tpu':
      raise absltest.SkipTest('http://b/235167364')

    elif (jax.default_backend() == 'gpu' and
          params['batch_group_count'] != 1):
      raise absltest.SkipTest('http://b/235485533')

  if len(shape1) > 3 or len(shape2) > 3:
    test_utils.skip_test(self)

  self._test_primitive(primitive, [shape1, shape2], dtype, params)
def test_parallel_in(self, same_inputs, kernel_type):
  platform = default_backend()
  rtol = RTOL if platform != 'tpu' else 0.05

  rng = random.PRNGKey(0)
  input_key1, input_key2, mc_key = random.split(rng, 3)

  x1_1, x2_1 = _get_inputs(input_key1, same_inputs, (BATCH_SIZE, 2))
  x1_2, x2_2 = _get_inputs(input_key2, same_inputs, (BATCH_SIZE, 3))

  x1 = (x1_1, x1_2)
  x2 = (x2_1, x2_2)

  N = 2 ** 7

  def net(logits):
    return stax.serial(
        stax.parallel(stax.Dense(N), stax.Dense(N)),
        stax.serial(stax.FanInSum(), stax.Dense(logits)))

  init_fn, apply_fn, kernel_fn = net(N if kernel_type == 'nngp' else 1)

  kernel_fn_empirical = nt.monte_carlo_kernel_fn(
      init_fn, apply_fn, mc_key, N_SAMPLES,
      trace_axes=(-1,),
      implementation=_DEFAULT_TESTING_NTK_IMPLEMENTATION,
      vmap_axes=((0, 0), 0, {})
  )

  test_utils.assert_close_matrices(self,
                                   kernel_fn(x1, x2, kernel_type),
                                   kernel_fn_empirical(x1, x2, kernel_type),
                                   rtol)
def main(unused_argv):
  key1, key2, key3 = random.split(random.PRNGKey(1), 3)
  x1 = random.normal(key1, (2, 8, 8, 3))
  x2 = random.normal(key2, (3, 8, 8, 3))

  # A vanilla CNN.
  init_fn, f, _ = stax.serial(
      stax.Conv(8, (3, 3)),
      stax.Relu(),
      stax.Conv(8, (3, 3)),
      stax.Relu(),
      stax.Conv(8, (3, 3)),
      stax.Flatten(),
      stax.Dense(10)
  )

  _, params = init_fn(key3, x1.shape)
  kwargs = dict(
      f=f,
      trace_axes=(),
      vmap_axes=0,
  )

  # Default, baseline Jacobian contraction.
  jacobian_contraction = nt.empirical_ntk_fn(
      **kwargs,
      implementation=nt.NtkImplementation.JACOBIAN_CONTRACTION)

  # (3, 2, 10, 10) full `np.ndarray` test-train NTK.
  ntk_jc = jacobian_contraction(x2, x1, params)

  # NTK-vector products-based implementation.
  ntk_vector_products = nt.empirical_ntk_fn(
      **kwargs,
      implementation=nt.NtkImplementation.NTK_VECTOR_PRODUCTS)

  ntk_vp = ntk_vector_products(x2, x1, params)

  # Structured derivatives-based implementation.
  structured_derivatives = nt.empirical_ntk_fn(
      **kwargs,
      implementation=nt.NtkImplementation.STRUCTURED_DERIVATIVES)

  ntk_sd = structured_derivatives(x2, x1, params)

  # Auto-FLOPs-selecting implementation. Doesn't work correctly on CPU/GPU.
  auto = nt.empirical_ntk_fn(
      **kwargs,
      implementation=nt.NtkImplementation.AUTO)

  ntk_auto = auto(x2, x1, params)

  # Check that implementations match.
  for ntk1 in [ntk_jc, ntk_vp, ntk_sd, ntk_auto]:
    for ntk2 in [ntk_jc, ntk_vp, ntk_sd, ntk_auto]:
      diff = np.max(np.abs(ntk1 - ntk2))
      print(f'NTK implementation diff {diff}.')
      assert diff < (1e-4 if jax.default_backend() != 'tpu' else 0.1), diff

  print('All NTK implementations match.')
def assert_close(x, y, tol=3e-5):
  if default_backend() == 'tpu':
    # TODO(romann): understand why TPUs have high errors.
    tol = 0.21
  self.assertLess(
      np.max(np.abs(x - y)) / (np.mean(np.abs(x)) + np.mean(np.abs(y))),
      tol)
def test_is_on_cpu(self):
  dtypes = [np.float16, np.float32]
  float64 = jax.dtypes.canonicalize_dtype(np.float64)
  if float64 != np.float32:
    dtypes += [float64]

  for dtype in dtypes:
    with self.subTest(dtype=dtype):

      def x():
        return random.normal(random.PRNGKey(1), (2, 3), dtype)

      def x_cpu():
        return device_get(random.normal(random.PRNGKey(1), (2, 3), dtype))

      x_jit = jit(x)
      # x_cpu_jit = jit(x_cpu)
      x_cpu_jit_cpu = jit(x_cpu, backend='cpu')

      self.assertTrue(utils.is_on_cpu(x_cpu()))
      # TODO(mattjj): re-enable this when device_put under jit works
      # self.assertTrue(utils.is_on_cpu(x_cpu_jit()))
      self.assertTrue(utils.is_on_cpu(x_cpu_jit_cpu()))

      if jax.default_backend() == 'cpu':
        self.assertTrue(utils.is_on_cpu(x()))
        self.assertTrue(utils.is_on_cpu(x_jit()))
      else:
        self.assertFalse(utils.is_on_cpu(x()))
        self.assertFalse(utils.is_on_cpu(x_jit()))
def _optimize() -> str:
  """Return contraction order for `np.einsum` based on platform.

  Introduced after https://github.com/google/jax/pull/7512 since TPU
  seems to be more precise in `greedy` mode.
  """
  return 'greedy' if jax.default_backend() == 'tpu' else 'optimal'
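A minimal usage sketch (not from the source): the string returned by `_optimize` is meant to be forwarded to the `optimize` argument of `np.einsum`, where `np` refers to `jax.numpy` in the snippet above. The operands below are illustrative placeholders.

import jax.numpy as np

# Illustrative operands; the contraction path for the three-operand product
# is chosen according to the platform-dependent string from `_optimize`.
a = np.ones((8, 16))
b = np.ones((16, 4))
c = np.ones((4, 2))
out = np.einsum('ij,jk,kl->il', a, b, c, optimize=_optimize())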
def skip_test(
    self,
    msg: str = 'Skipping large tests for speed.',
    platforms: Tuple[str, ...] = ('cpu',)
):
  if jax.default_backend() in platforms:
    raise parameterized.TestCase.skipTest(self, msg)
def double_buffer_on_gpu(ds):
  if jax.default_backend() == "gpu":
    # This keeps two batches per-device in memory at all times, allowing
    # h2d transfers to overlap with execution (see b/173483287 for details).
    return double_buffer(ds)
  else:
    return ds
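A minimal usage sketch (not from the source): `double_buffer_on_gpu` wraps an iterator of per-step batches so that, on GPU, the next batch is already on device while the current one is being consumed. `make_fake_dataset` below is a hypothetical stand-in for a real input pipeline.

import itertools as it
import numpy as onp

def make_fake_dataset():
  # Hypothetical stand-in for a real input pipeline: an endless batch stream.
  return it.repeat(onp.ones([8, 32], dtype=onp.float32))

ds = double_buffer_on_gpu(make_fake_dataset())
for step, batch in enumerate(ds):
  # On GPU backends, `batch` is already resident on device here.
  if step == 2:
    break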
def test_nested_parallel(self, same_inputs, kernel_type):
  platform = default_backend()
  rtol = RTOL if platform != 'tpu' else 0.05

  rng = random.PRNGKey(0)
  (input_key1,
   input_key2,
   input_key3,
   input_key4,
   mask_key,
   mc_key) = random.split(rng, 6)

  x1_1, x2_1 = _get_inputs(input_key1, same_inputs, (BATCH_SIZE, 5))
  x1_2, x2_2 = _get_inputs(input_key2, same_inputs, (BATCH_SIZE, 2, 2, 2))
  x1_3, x2_3 = _get_inputs(input_key3, same_inputs, (BATCH_SIZE, 2, 2, 3))
  x1_4, x2_4 = _get_inputs(input_key4, same_inputs, (BATCH_SIZE, 3, 4))

  m1_key, m2_key, m3_key, m4_key = random.split(mask_key, 4)

  x1_1 = test_utils.mask(
      x1_1, mask_constant=-1, mask_axis=(1,), key=m1_key, p=0.5)
  x1_2 = test_utils.mask(
      x1_2, mask_constant=-1, mask_axis=(2, 3,), key=m2_key, p=0.5)

  if not same_inputs:
    x2_3 = test_utils.mask(
        x2_3, mask_constant=-1, mask_axis=(1, 3,), key=m3_key, p=0.5)
    x2_4 = test_utils.mask(
        x2_4, mask_constant=-1, mask_axis=(2,), key=m4_key, p=0.5)

  x1 = (((x1_1, x1_2), x1_3), x1_4)
  x2 = (((x2_1, x2_2), x2_3), x2_4) if not same_inputs else None

  N_in = 2 ** 7

  # We only include dropout on non-TPU backends, because it takes large N to
  # converge on TPU.
  dropout_or_id = stax.Dropout(0.9) if platform != 'tpu' else stax.Identity()

  init_fn, apply_fn, kernel_fn = stax.parallel(
      stax.parallel(
          stax.parallel(stax.Dense(N_in),
                        stax.serial(stax.Conv(N_in + 1, (2, 2)),
                                    stax.Flatten())),
          stax.serial(stax.Conv(N_in + 2, (2, 2)),
                      dropout_or_id,
                      stax.GlobalAvgPool())),
      stax.Conv(N_in + 3, (2,)))

  kernel_fn_empirical = nt.monte_carlo_kernel_fn(
      init_fn, apply_fn, mc_key, N_SAMPLES,
      implementation=_DEFAULT_TESTING_NTK_IMPLEMENTATION,
      vmap_axes=(((((0, 0), 0), 0), (((0, 0), 0), 0), {})
                 if platform == 'tpu' else None)
  )

  test_utils.assert_close_matrices(
      self,
      kernel_fn(x1, x2, get=kernel_type, mask_constant=-1),
      kernel_fn_empirical(x1, x2, get=kernel_type, mask_constant=-1),
      rtol)
def test_parallel_out(self, same_inputs, kernel_type):
  platform = default_backend()
  rtol = RTOL if platform != 'tpu' else 0.05

  rng = random.PRNGKey(0)
  input_key1, mc_key = random.split(rng, 2)

  x1, x2 = _get_inputs(input_key1, same_inputs, (BATCH_SIZE, 1))

  N = 2 ** 10

  def net(logits):
    return stax.serial(
        stax.Dense(N),
        stax.FanOut(2),
        stax.parallel(stax.Dense(logits), stax.Dense(logits)))

  init_fn, apply_fn, kernel_fn = net(N if kernel_type == 'nngp' else 1)

  kernel_fn_empirical = nt.monte_carlo_kernel_fn(
      init_fn, apply_fn, mc_key, N_SAMPLES,
      trace_axes=(-1,),
      implementation=2,
      vmap_axes=(0, [0, 0], {}))

  test_utils.assert_close_matrices(self,
                                   kernel_fn(x1, x2, kernel_type),
                                   kernel_fn_empirical(x1, x2, kernel_type),
                                   rtol)
def _test_activation(self, activation_fn, same_inputs, model, get,
                     rbf_gamma=None):
  if 'conv' in model:
    test_utils.skip_test(self)

  key = random.PRNGKey(1)
  key, split = random.split(key)
  output_dim = 1024 if get == 'nngp' else 1
  b_std = 0.5
  W_std = 2.0
  if activation_fn[2].__name__ == 'Sin':
    W_std = 0.9
  if activation_fn[2].__name__ == 'Rbf':
    W_std = 1.0
    b_std = 0.0

  if model == 'fc':
    rtol = 0.04
    X0_1 = random.normal(key, (4, 2))
    X0_2 = None if same_inputs else random.normal(split, (2, 2))
    affine = stax.Dense(1024, W_std, b_std)
    readout = stax.Dense(output_dim)
    depth = 1

  else:
    rtol = 0.05
    X0_1 = random.normal(key, (2, 4, 4, 3))
    X0_2 = None if same_inputs else random.normal(split, (4, 4, 4, 3))
    affine = stax.Conv(512, (3, 2), W_std=W_std, b_std=b_std, padding='SAME')
    readout = stax.serial(stax.GlobalAvgPool() if 'pool' in model else
                          stax.Flatten(),
                          stax.Dense(output_dim))
    depth = 2

  if default_backend() == 'cpu':
    num_samplings = 200
    rtol *= 2
  else:
    num_samplings = (500 if activation_fn[2].__name__ in ('Sin', 'Rbf')
                     else 300)

  init_fn, apply_fn, kernel_fn = stax.serial(
      *[affine, activation_fn] * depth, readout)
  analytic_kernel = kernel_fn(X0_1, X0_2, get)

  mc_kernel_fn = nt.monte_carlo_kernel_fn(
      init_fn, apply_fn, split, num_samplings,
      implementation=2,
      vmap_axes=0
  )
  empirical_kernel = mc_kernel_fn(X0_1, X0_2, get)
  test_utils.assert_close_matrices(self, analytic_kernel, empirical_kernel,
                                   rtol)

  # Check match with explicit RBF
  if rbf_gamma is not None and get == 'nngp' and model == 'fc':
    input_dim = X0_1.shape[1]
    _, _, kernel_fn = self._RBF(rbf_gamma / input_dim)
    direct_rbf_kernel = kernel_fn(X0_1, X0_2, get)
    test_utils.assert_close_matrices(self, analytic_kernel,
                                     direct_rbf_kernel, rtol)
def _binomial(key, p, n, shape):
    shape = shape or lax.broadcast_shapes(jnp.shape(p), jnp.shape(n))
    # reshape to map over axis 0
    p = jnp.reshape(jnp.broadcast_to(p, shape), -1)
    n = jnp.reshape(jnp.broadcast_to(n, shape), -1)
    key = random.split(key, jnp.size(p))
    if jax.default_backend() == "cpu":
        ret = lax.map(lambda x: _binomial_dispatch(*x), (key, p, n))
    else:
        ret = vmap(lambda *x: _binomial_dispatch(*x))(key, p, n)
    return jnp.reshape(ret, shape)
def test_pjit_inherits_effects(self):
  if jax.default_backend() not in {'gpu', 'tpu'}:
    raise unittest.SkipTest("pjit only supports GPU and TPU backends")

  def f(x):
    effect_p.bind(effect='foo')
    effect_p.bind(effect='bar')
    return x

  f = pjit.pjit(f, in_axis_resources=pjit.PartitionSpec('x'),
                out_axis_resources=pjit.PartitionSpec('x'))
  with self.assertRaisesRegex(NotImplementedError, 'Effects not supported'):
    with maps.Mesh(np.array(jax.devices()), ['x']):
      jax.make_jaxpr(f)(jnp.arange(jax.local_device_count()))
def stub_out_pmap(batch: ModuleType, count: int):
  # If we are using GPU or CPU stub out pmap with vmap to simulate multi-core.
  if count > 0:

    class xla_bridge_stub:

      def device_count(self) -> int:
        return count

    platform = jax.default_backend()
    if platform == 'gpu' or platform == 'cpu':
      batch.pmap = _jit_vmap
      batch.xla_bridge = xla_bridge_stub()
def testJaxRoundTrip(self, shape, dtype, take_ownership, gpu):
  rng = jtu.rand_default(self.rng())
  np = rng(shape, dtype)
  if gpu and jax.default_backend() == "cpu":
    raise unittest.SkipTest("Skipping GPU test case on CPU")
  if (not gpu and jax.default_backend() == "gpu" and
      jax.lib._xla_extension_version < 25):
    raise unittest.SkipTest("Mixed CPU/GPU dlpack support requires jaxlib "
                            "0.1.68 or newer")
  device = jax.devices("gpu" if gpu else "cpu")[0]
  x = jax.device_put(np, device)
  dlpack = jax.dlpack.to_dlpack(x, take_ownership=take_ownership)
  self.assertEqual(take_ownership, x.device_buffer.is_deleted())
  y = jax.dlpack.from_dlpack(dlpack)
  self.assertEqual(y.device(), device)
  self.assertAllClose(np.astype(x.dtype), y)

  self.assertRaisesRegex(RuntimeError,
                         "DLPack tensor may be consumed at most once",
                         lambda: jax.dlpack.from_dlpack(dlpack))
def test_sample_vs_analytic_nngp(self, batch_size, device_count,
                                 store_on_device):
  test_utils.stub_out_pmap(batching, device_count)

  x1, x2, init_fn, apply_fn, stax_kernel_fn, key = _get_inputs_and_model(
      WIDTH, 256, jax.default_backend() == 'tpu')

  sample = monte_carlo.monte_carlo_kernel_fn(init_fn, apply_fn, key, 200,
                                             batch_size, device_count,
                                             store_on_device)

  ker_empirical = sample(x1, x2, 'nngp')
  ker_analytic = stax_kernel_fn(x1, x2, 'nngp')

  test_utils.assert_close_matrices(self, ker_analytic, ker_empirical, 2e-2)
def test_sparse_inputs(self, act, kernel, do_stabilize):
  if do_stabilize and act != 'relu':
    raise absltest.SkipTest('Stabilization possible only in Relu.')

  key = random.PRNGKey(1)

  input_count = 4
  sparse_count = 2
  input_size = 3
  width = 1024

  # NOTE(schsam): It seems that convergence is slower when inputs are sparse.
  samples = N_SAMPLES

  if default_backend() == 'gpu':
    tol = 5e-4
    samples = 100 * N_SAMPLES
  else:
    tol = {onp.dtype(onp.float32): 5e-2, onp.dtype(onp.float64): 5e-3}

  # a batch of dense inputs
  x_dense = random.normal(key, (input_count, input_size))
  x_sparse = x_dense.at[:sparse_count, :].set(0.)

  activation = (stax.Relu(do_stabilize=do_stabilize)
                if act == 'relu' else stax.Erf())

  init_fn, apply_fn, kernel_fn = stax.serial(
      stax.Dense(width),
      activation,
      stax.Dense(1 if kernel == 'ntk' else width))
  exact = kernel_fn(x_sparse, None, kernel)

  mc = nt.monte_carlo_kernel_fn(
      init_fn,
      apply_fn,
      random.split(key, 2)[0],
      samples,
      vmap_axes=0,
      device_count=-1,
      implementation=2
  )(x_sparse, None, kernel)

  mc = np.reshape(mc, exact.shape)

  assert not np.any(np.isnan(exact))
  self.assertAllClose(exact[sparse_count:, sparse_count:],
                      mc[sparse_count:, sparse_count:],
                      rtol=tol, atol=tol)
def _check_agreement_with_empirical(
    self,
    net,
    same_inputs,
    use_dropout,
    is_ntk,
    rtol=RTOL,
    atol=ATOL
):
  ((init_fn, apply_fn, kernel_fn),
   input_shape, device_count, channel_axis) = net

  num_samples = N_SAMPLES * 5 if use_dropout else N_SAMPLES
  key = random.PRNGKey(1)
  x1, x2 = _get_inputs(key, same_inputs, input_shape)

  if default_backend() == 'tpu' and use_dropout:
    # including a test case for tpu + dropout with (parallel + batching)
    batch_size = 2
  else:
    batch_size = 0

  x1_out_shape, params = init_fn(key, x1.shape)
  if same_inputs:
    assert x2 is None
  if x2 is None:
    x2_out_shape = x1_out_shape
  else:
    x2_out_shape, params = init_fn(key, x2.shape)
  del params

  def _get_empirical(n_samples, get):
    kernel_fn_empirical = nt.monte_carlo_kernel_fn(
        init_fn, apply_fn, key, n_samples,
        device_count=device_count,
        trace_axes=(channel_axis,),
        batch_size=batch_size,
        implementation=2
    )
    if same_inputs:
      assert x2 is None
    return kernel_fn_empirical(x1, x2, get)

  if is_ntk:
    exact, shape1, shape2 = kernel_fn(x1, x2, ('ntk', 'shape1', 'shape2'))
    empirical = _get_empirical(num_samples, 'ntk')
  else:
    exact, shape1, shape2 = kernel_fn(x1, x2, ('nngp', 'shape1', 'shape2'))
    empirical = _get_empirical(num_samples, 'nngp')

  test_utils.assert_close_matrices(self, exact, empirical, rtol, atol)
  self.assertEqual(shape1, x1_out_shape)
  self.assertEqual(shape2, x2_out_shape)
def _subsample_fn(size, subsample_size, rng_key=None):
    assert rng_key is not None, "Missing random key to generate subsample indices."
    if jax.default_backend() == "cpu":
        # ref: https://en.wikipedia.org/wiki/Fisher%E2%80%93Yates_shuffle#The_modern_algorithm
        rng_keys = random.split(rng_key, subsample_size)

        def body_fn(val, idx):
            i_p1 = size - idx
            i = i_p1 - 1
            j = random.randint(rng_keys[idx], (), 0, i_p1)
            val = val.at[jnp.array([i, j])].set(val[jnp.array([j, i])])
            return val, None

        val, _ = lax.scan(body_fn, jnp.arange(size), jnp.arange(subsample_size))
        return val[-subsample_size:]
    else:
        return random.choice(rng_key, size, (subsample_size,), replace=False)
def test_parallel_in_out(self, same_inputs, kernel_type):
  platform = default_backend()
  rtol = RTOL if platform != 'tpu' else 0.05

  rng = random.PRNGKey(0)
  input_key1, input_key2, mc_key = random.split(rng, 3)

  x1_1, x2_1 = _get_inputs(input_key1, same_inputs, (BATCH_SIZE, 1))
  x1_2, x2_2 = _get_inputs(input_key2, same_inputs, (BATCH_SIZE, 2))

  x1 = (x1_1, x1_2)
  x2 = (x2_1, x2_2)

  N_in = 2 ** 10
  N_out = N_in if kernel_type == 'nngp' else 1

  readin = stax.serial(stax.parallel(stax.Dense(N_in), stax.Dense(N_in)),
                       stax.FanInSum())
  readout = stax.serial(stax.FanOut(3),
                        stax.parallel(stax.Dense(N_out),
                                      stax.Dense(N_out + 1),
                                      stax.Dense(N_out + 2)))

  init_fn, apply_fn, _ = stax.serial(readin, readout)

  K_readin_fn = jit(readin[2])
  K_readout_fn = jit(functools.partial(readout[2], get=kernel_type))

  kernel_fn_empirical = nt.monte_carlo_kernel_fn(
      init_fn, apply_fn, mc_key, N_SAMPLES,
      trace_axes=(-1,),
      implementation=2,
      vmap_axes=((0, 0), [0, 0, 0], {})
  )

  test_utils.assert_close_matrices(
      self,
      K_readout_fn(K_readin_fn(x1, x2)),
      kernel_fn_empirical(x1, x2, get=kernel_type),
      rtol)

  # Check Both (here we just want to make sure we _can_ compute the output).
  K_readin_fn = jit(readin[2])
  K_readout_fn = jit(functools.partial(readout[2], get=('nngp', 'ntk')))

  K_readout_fn(K_readin_fn(x1, x2))
def test_double_buffer(self):
  if jax.default_backend() != "gpu":
    self.skipTest("Only necessary on GPU.")

  n = jax.local_device_count()
  dataset = it.repeat(np.ones([n]))
  iterator = iter(utils.double_buffer(dataset))

  batch_ptrs = []
  while len(batch_ptrs) < 4:
    batch = next(iterator)
    ptrs = [b.unsafe_buffer_pointer() for b in batch.device_buffers]
    batch_ptrs.append(ptrs)
    del batch

  self.assertEqual(batch_ptrs[0], batch_ptrs[2])
  self.assertEqual(batch_ptrs[1], batch_ptrs[3])
  self.assertNotEqual(batch_ptrs[0], batch_ptrs[1])
  self.assertNotEqual(batch_ptrs[2], batch_ptrs[3])
def x1_is_x2(x1: np.ndarray,
             x2: Optional[np.ndarray] = None,
             eps: float = 1e-12) -> Union[bool, np.ndarray]:
  if not isinstance(x1, (onp.ndarray, np.ndarray)):
    raise TypeError('`x1` must be an ndarray. A {} is found.'.format(type(x1)))

  if x2 is None:
    return True

  if x1 is x2:
    return True

  if x1.shape != x2.shape:
    return False

  if jax.default_backend() == 'tpu':
    eps = 1e-4

  return np.all(np.abs(x1 - x2) < eps)
def test_elementwise_numerical(self, same_inputs, model, phi, get):
  if 'conv' in model:
    test_utils.skip_test(self)

  key, split = random.split(random.PRNGKey(1))

  output_dim = 1
  b_std = 0.01
  W_std = 1.0
  rtol = 2e-3
  deg = 25
  if get == 'ntk':
    rtol *= 2
  if default_backend() == 'tpu':
    rtol *= 2

  if model == 'fc':
    X0_1 = random.normal(key, (3, 7))
    X0_2 = None if same_inputs else random.normal(split, (5, 7))
    affine = stax.Dense(1024, W_std, b_std)
    readout = stax.Dense(output_dim)
    depth = 1
  else:
    X0_1 = random.normal(key, (2, 8, 8, 3))
    X0_2 = None if same_inputs else random.normal(split, (3, 8, 8, 3))
    affine = stax.Conv(1024, (3, 2), W_std=W_std, b_std=b_std, padding='SAME')
    readout = stax.serial(stax.GlobalAvgPool() if 'pool' in model else
                          stax.Flatten(),
                          stax.Dense(output_dim))
    depth = 2

  _, _, kernel_fn = stax.serial(*[affine, phi] * depth, readout)
  analytic_kernel = kernel_fn(X0_1, X0_2, get)

  fn = lambda x: phi[1]((), x)
  _, _, kernel_fn = stax.serial(
      *[affine, stax.ElementwiseNumerical(fn, deg=deg)] * depth, readout)
  numerical_activation_kernel = kernel_fn(X0_1, X0_2, get)

  test_utils.assert_close_matrices(self, analytic_kernel,
                                   numerical_activation_kernel, rtol)
def _device_to_host_funcs():
  """Generates device-to-host transfer functions."""
  if jax.default_backend() == "cpu":
    # device-to-host does not incur transfer on the CPU backend.
    return []

  with jax.transfer_guard_host_to_device("allow"):
    device_arrays = [jnp.ones(1) for _ in range(6)]
  return [
      # (function name, is an explicit transfer?, function)
      ("device_to_host_jax_device_get", True,
       lambda: jax.device_get(device_arrays[0])),
      ("device_to_host_np_asarray", False,
       lambda: np.asarray(device_arrays[1])),
      ("device_to_host_copy_to_host_async", False,
       lambda: device_arrays[2].copy_to_host_async()),
      ("device_to_host_np_add", False, lambda: np.add(device_arrays[3], 1)),
      ("device_to_host_str", False, lambda: str(device_arrays[4])),
      ("device_to_host_pickle_dumps", False,
       lambda: pickle.dumps(device_arrays[5])),
  ]
def _subsample_fn(size, subsample_size, rng_key=None):
    if rng_key is None:
        raise ValueError(
            "Missing random key to generate subsample indices."
            " Algorithms like HMC/NUTS do not support subsampling."
            " You might want to use SVI or HMCECS instead."
        )
    if jax.default_backend() == "cpu":
        # ref: https://en.wikipedia.org/wiki/Fisher%E2%80%93Yates_shuffle#The_modern_algorithm
        rng_keys = random.split(rng_key, subsample_size)

        def body_fn(val, idx):
            i_p1 = size - idx
            i = i_p1 - 1
            j = random.randint(rng_keys[idx], (), 0, i_p1)
            val = val.at[jnp.array([i, j])].set(val[jnp.array([j, i])])
            return val, None

        val, _ = lax.scan(body_fn, jnp.arange(size), jnp.arange(subsample_size))
        return val[-subsample_size:]
    else:
        return random.choice(rng_key, size, (subsample_size,), replace=False)
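A minimal usage sketch (not from the source) of the helper above: draw 3 distinct indices out of 10 without replacement.

from jax import random

rng_key = random.PRNGKey(0)
idx = _subsample_fn(10, 3, rng_key=rng_key)
# `idx` has shape (3,) and holds distinct integers in [0, 10), whether the
# Fisher-Yates path (CPU backend) or the `random.choice` path was taken.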
def setUpModule():
  if jax.default_backend() not in {'gpu', 'tpu'}:
    raise unittest.SkipTest("pjit only supports GPU and TPU backends")
  jtu.set_spmd_lowering_flag(True)
def fori_collect(
    lower,
    upper,
    body_fun,
    init_val,
    transform=identity,
    progbar=True,
    return_last_val=False,
    collection_size=None,
    thinning=1,
    **progbar_opts,
):
    """
    This looping construct works like :func:`~jax.lax.fori_loop` but with the
    additional effect of collecting values from the loop body. In addition,
    this allows for post-processing of these samples via `transform`, and
    progress bar updates. Note that, `progbar=False` will be faster,
    especially when collecting a lot of samples. Refer to example usage in
    :func:`~numpyro.infer.mcmc.hmc`.

    :param int lower: the index to start the collective work. In other words,
        we will skip collecting the first `lower` values.
    :param int upper: number of times to run the loop body.
    :param body_fun: a callable that takes a collection of `np.ndarray` and
        returns a collection with the same shape and `dtype`.
    :param init_val: initial value to pass as argument to `body_fun`. Can be
        any Python collection type containing `np.ndarray` objects.
    :param transform: a callable to post-process the values returned by
        `body_fn`.
    :param progbar: whether to post progress bar updates.
    :param bool return_last_val: If `True`, the last value is also returned.
        This has the same type as `init_val`.
    :param thinning: Positive integer that controls the thinning ratio for
        retained values. Defaults to 1, i.e. no thinning.
    :param int collection_size: Size of the returned collection. If not
        specified, the size will be ``(upper - lower) // thinning``. If the
        size is larger than ``(upper - lower) // thinning``, only the top
        ``(upper - lower) // thinning`` entries will be non-zero.
    :param `**progbar_opts`: optional additional progress bar arguments. A
        `diagnostics_fn` can be supplied which when passed the current value
        from `body_fun` returns a string that is used to update the progress
        bar postfix. Also a `progbar_desc` keyword argument can be supplied
        which is used to label the progress bar.
    :return: collection with the same type as `init_val` with values collected
        along the leading axis of `np.ndarray` objects.
    """
    assert lower <= upper
    assert thinning >= 1
    collection_size = (
        (upper - lower) // thinning if collection_size is None else collection_size
    )
    assert collection_size >= (upper - lower) // thinning
    init_val_flat, unravel_fn = ravel_pytree(transform(init_val))
    start_idx = lower + (upper - lower) % thinning
    num_chains = progbar_opts.pop("num_chains", 1)
    # host_callback does not work yet with multi-GPU platforms
    # See: https://github.com/google/jax/issues/6447
    if num_chains > 1 and jax.default_backend() == "gpu":
        warnings.warn(
            "We will disable progress bar because it does not work yet on multi-GPUs platforms.",
            stacklevel=find_stack_level(),
        )
        progbar = False

    @cached_by(fori_collect, body_fun, transform)
    def _body_fn(i, vals):
        val, collection, start_idx, thinning = vals
        val = body_fun(val)
        idx = (i - start_idx) // thinning
        collection = cond(
            idx >= 0,
            collection,
            lambda x: x.at[idx].set(ravel_pytree(transform(val))[0]),
            collection,
            identity,
        )
        return val, collection, start_idx, thinning

    collection = jnp.zeros(
        (collection_size,) + init_val_flat.shape, dtype=init_val_flat.dtype
    )
    if not progbar:
        last_val, collection, _, _ = fori_loop(
            0, upper, _body_fn, (init_val, collection, start_idx, thinning)
        )
    elif num_chains > 1:
        progress_bar_fori_loop = progress_bar_factory(upper, num_chains)
        _body_fn_pbar = progress_bar_fori_loop(_body_fn)
        last_val, collection, _, _ = fori_loop(
            0, upper, _body_fn_pbar, (init_val, collection, start_idx, thinning)
        )
    else:
        diagnostics_fn = progbar_opts.pop("diagnostics_fn", None)
        progbar_desc = progbar_opts.pop("progbar_desc", lambda x: "")

        vals = (init_val, collection, device_put(start_idx), device_put(thinning))
        if upper == 0:
            # special case, only compiling
            jit(_body_fn)(0, vals)
        else:
            with tqdm.trange(upper) as t:
                for i in t:
                    vals = jit(_body_fn)(i, vals)
                    t.set_description(progbar_desc(i), refresh=False)
                    if diagnostics_fn:
                        t.set_postfix_str(diagnostics_fn(vals[0]), refresh=False)

        last_val, collection, _, _ = vals

    unravel_collection = vmap(unravel_fn)(collection)
    return (unravel_collection, last_val) if return_last_val else unravel_collection
def test_fan_in_fc(self, same_inputs, axis, n_branches, get, branch_in,
                   fan_in_mode):
  if fan_in_mode in ['FanInSum', 'FanInProd']:
    if axis != 0:
      raise absltest.SkipTest('`FanInSum` and `FanInProd` are skipped when '
                              'axis != 0.')
    axis = None

  if (fan_in_mode == 'FanInSum' or
      axis == 0) and branch_in == 'dense_after_branch_in':
    raise absltest.SkipTest('`FanInSum` and `FanInConcat(0)` '
                            'require `is_gaussian`.')

  if ((axis == 1 or fan_in_mode == 'FanInProd') and
      branch_in == 'dense_before_branch_in'):
    raise absltest.SkipTest(
        '`FanInConcat` or `FanInProd` on feature axis requires a dense layer '
        'after concatenation or Hadamard product.')

  if fan_in_mode == 'FanInSum':
    fan_in_layer = stax.FanInSum()
  elif fan_in_mode == 'FanInProd':
    fan_in_layer = stax.FanInProd()
  else:
    fan_in_layer = stax.FanInConcat(axis)

  if n_branches != 2:
    test_utils.skip_test(self)

  key = random.PRNGKey(1)
  X0_1 = np.cos(random.normal(key, (4, 3)))
  X0_2 = None if same_inputs else random.normal(key, (8, 3))

  width = 1024
  n_samples = 256 * 2

  if default_backend() == 'tpu':
    tol = 0.07
  else:
    tol = 0.02

  dense = stax.Dense(width, 1.25, 0.1)
  input_layers = [dense, stax.FanOut(n_branches)]

  branches = []
  for b in range(n_branches):
    branch_layers = [FanInTest._get_phi(b)]
    for i in range(b):
      multiplier = 1 if axis not in (1, -1) else (1 + 0.25 * i)
      branch_layers += [
          stax.Dense(int(width * multiplier), 1. + 2 * i, 0.5 + i),
          FanInTest._get_phi(i)
      ]

    if branch_in == 'dense_before_branch_in':
      branch_layers += [dense]
    branches += [stax.serial(*branch_layers)]

  output_layers = [fan_in_layer, stax.Relu()]
  if branch_in == 'dense_after_branch_in':
    output_layers.insert(1, dense)

  nn = stax.serial(*(input_layers + [stax.parallel(*branches)] +
                     output_layers))

  if get == 'nngp':
    init_fn, apply_fn, kernel_fn = nn
  elif get == 'ntk':
    init_fn, apply_fn, kernel_fn = stax.serial(nn, stax.Dense(1, 1.25, 0.5))
  else:
    raise ValueError(get)

  kernel_fn_mc = nt.monte_carlo_kernel_fn(
      init_fn,
      apply_fn,
      key,
      n_samples,
      device_count=0 if axis in (0, -2) else -1,
      implementation=_DEFAULT_TESTING_NTK_IMPLEMENTATION,
      vmap_axes=None if axis in (0, -2) else 0,
  )

  exact = kernel_fn(X0_1, X0_2, get=get)
  empirical = kernel_fn_mc(X0_1, X0_2, get=get)
  test_utils.assert_close_matrices(self, empirical, exact, tol)
def test_fan_in_conv(self, same_inputs, axis, n_branches, get, branch_in,
                     readout, fan_in_mode):
  test_utils.skip_test(self)

  if fan_in_mode in ['FanInSum', 'FanInProd']:
    if axis != 0:
      raise absltest.SkipTest('`FanInSum` and `FanInProd()` are skipped when '
                              'axis != 0.')
    axis = None

  if (fan_in_mode == 'FanInSum' or
      axis in [0, 1, 2]) and branch_in == 'dense_after_branch_in':
    raise absltest.SkipTest('`FanInSum` and `FanInConcat(0/1/2)` '
                            'require `is_gaussian`.')

  if ((axis == 3 or fan_in_mode == 'FanInProd') and
      branch_in == 'dense_before_branch_in'):
    raise absltest.SkipTest('`FanInConcat` or `FanInProd` on feature axis '
                            'requires a dense layer after concatenation '
                            'or Hadamard product.')

  if fan_in_mode == 'FanInSum':
    fan_in_layer = stax.FanInSum()
  elif fan_in_mode == 'FanInProd':
    fan_in_layer = stax.FanInProd()
  else:
    fan_in_layer = stax.FanInConcat(axis)

  key = random.PRNGKey(1)
  X0_1 = random.normal(key, (2, 5, 6, 3))
  X0_2 = None if same_inputs else random.normal(key, (3, 5, 6, 3))

  if default_backend() == 'tpu':
    width = 2048
    n_samples = 1024
    tol = 0.02
  else:
    width = 1024
    n_samples = 512
    tol = 0.01

  conv = stax.Conv(out_chan=width,
                   filter_shape=(3, 3),
                   padding='SAME',
                   W_std=1.25,
                   b_std=0.1)

  input_layers = [conv, stax.FanOut(n_branches)]

  branches = []
  for b in range(n_branches):
    branch_layers = [FanInTest._get_phi(b)]
    for i in range(b):
      multiplier = 1 if axis not in (3, -1) else (1 + 0.25 * i)
      branch_layers += [
          stax.Conv(out_chan=int(width * multiplier),
                    filter_shape=(i + 1, 4 - i),
                    padding='SAME',
                    W_std=1.25 + i,
                    b_std=0.1 + i),
          FanInTest._get_phi(i)
      ]

    if branch_in == 'dense_before_branch_in':
      branch_layers += [conv]
    branches += [stax.serial(*branch_layers)]

  output_layers = [
      fan_in_layer,
      stax.Relu(),
      stax.GlobalAvgPool() if readout == 'pool' else stax.Flatten()
  ]
  if branch_in == 'dense_after_branch_in':
    output_layers.insert(1, conv)

  nn = stax.serial(*(input_layers + [stax.parallel(*branches)] +
                     output_layers))

  init_fn, apply_fn, kernel_fn = stax.serial(
      nn, stax.Dense(1 if get == 'ntk' else width, 1.25, 0.5))

  kernel_fn_mc = nt.monte_carlo_kernel_fn(
      init_fn,
      apply_fn,
      key,
      n_samples,
      device_count=0 if axis in (0, -4) else -1,
      implementation=_DEFAULT_TESTING_NTK_IMPLEMENTATION,
      vmap_axes=None if axis in (0, -4) else 0,
  )

  exact = kernel_fn(X0_1, X0_2, get=get)
  empirical = kernel_fn_mc(X0_1, X0_2, get=get)
  test_utils.assert_close_matrices(self, empirical, exact, tol)
def test_mask_conv(self, same_inputs, get, mask_axis, mask_constant, concat,
                   proj, p, n, transpose):
  if isinstance(concat, int) and concat > n:
    raise absltest.SkipTest('Concatenation axis out of bounds.')

  test_utils.skip_test(self)
  if default_backend() == 'gpu' and n > 3:
    raise absltest.SkipTest('>=4D-CNN is not supported on GPUs.')

  width = 256
  n_samples = 256
  tol = 0.03
  key = random.PRNGKey(1)

  spatial_shape = ((1, 2, 3, 2, 1) if transpose else (15, 8, 9))[:n]
  filter_shape = ((2, 3, 1, 2, 1) if transpose else (7, 2, 3))[:n]
  strides = (2, 1, 3, 2, 3)[:n]
  spatial_spec = 'HWDZX'[:n]
  dimension_numbers = ('N' + spatial_spec + 'C',
                       'OI' + spatial_spec,
                       'N' + spatial_spec + 'C')

  x1 = np.cos(random.normal(key, (2,) + spatial_shape + (2,)))
  x1 = test_utils.mask(x1, mask_constant, mask_axis, key, p)

  if same_inputs:
    x2 = None
  else:
    x2 = np.cos(random.normal(key, (4,) + spatial_shape + (2,)))
    x2 = test_utils.mask(x2, mask_constant, mask_axis, key, p)

  def get_attn():
    return stax.GlobalSelfAttention(
        n_chan_out=width,
        n_chan_key=width,
        n_chan_val=int(np.round(float(width) / int(np.sqrt(width)))),
        n_heads=int(np.sqrt(width)),
    ) if proj == 'avg' else stax.Identity()

  conv = stax.ConvTranspose if transpose else stax.Conv

  nn = stax.serial(
      stax.FanOut(3),
      stax.parallel(
          stax.serial(
              conv(dimension_numbers=dimension_numbers,
                   out_chan=width,
                   strides=strides,
                   filter_shape=filter_shape,
                   padding='CIRCULAR',
                   W_std=1.5,
                   b_std=0.2),
              stax.LayerNorm(axis=(1, -1)),
              stax.Abs(),
              stax.DotGeneral(rhs=0.9),
              conv(dimension_numbers=dimension_numbers,
                   out_chan=width,
                   strides=strides,
                   filter_shape=filter_shape,
                   padding='VALID',
                   W_std=1.2,
                   b_std=0.1),
          ),
          stax.serial(
              conv(dimension_numbers=dimension_numbers,
                   out_chan=width,
                   strides=strides,
                   filter_shape=filter_shape,
                   padding='SAME',
                   W_std=0.1,
                   b_std=0.3),
              stax.Relu(),
              stax.Dropout(0.7),
              conv(dimension_numbers=dimension_numbers,
                   out_chan=width,
                   strides=strides,
                   filter_shape=filter_shape,
                   padding='VALID',
                   W_std=0.9,
                   b_std=1.),
          ),
          stax.serial(
              get_attn(),
              conv(dimension_numbers=dimension_numbers,
                   out_chan=width,
                   strides=strides,
                   filter_shape=filter_shape,
                   padding='CIRCULAR',
                   W_std=1.,
                   b_std=0.1),
              stax.Erf(),
              stax.Dropout(0.2),
              stax.DotGeneral(rhs=0.7),
              conv(dimension_numbers=dimension_numbers,
                   out_chan=width,
                   strides=strides,
                   filter_shape=filter_shape,
                   padding='VALID',
                   W_std=1.,
                   b_std=0.1),
          )),
      (stax.FanInSum() if concat is None else stax.FanInConcat(concat)),
      get_attn(),
      {
          'avg': stax.GlobalAvgPool(),
          'sum': stax.GlobalSumPool(),
          'flatten': stax.Flatten(),
      }[proj],
  )

  if get == 'nngp':
    init_fn, apply_fn, kernel_fn = stax.serial(nn, stax.Dense(width, 1., 0.))
  elif get == 'ntk':
    init_fn, apply_fn, kernel_fn = stax.serial(nn, stax.Dense(1, 1., 0.))
  else:
    raise ValueError(get)

  kernel_fn_mc = nt.monte_carlo_kernel_fn(
      init_fn,
      apply_fn,
      key,
      n_samples,
      device_count=0 if concat in (0, -n) else -1,
      implementation=_DEFAULT_TESTING_NTK_IMPLEMENTATION,
      vmap_axes=None if concat in (0, -n) else 0,
  )

  kernel_fn = jit(kernel_fn, static_argnames='get')
  exact = kernel_fn(x1, x2, get, mask_constant=mask_constant)
  empirical = kernel_fn_mc(x1, x2, get=get, mask_constant=mask_constant)
  test_utils.assert_close_matrices(self, empirical, exact, tol)