def test_dropout(self, model, width, same_inputs, is_ntk, padding, strides,
                 filter_shape, phi, use_pooling, proj_into_2d):
    """Build a dropout-enabled net and run the empirical agreement check.

    Raises `jtu.SkipTest` for parameter combinations that are duplicates,
    incorrectly shaped, or not meant to run on the current backend.
    """
    platform = xla_bridge.get_backend().platform
    if same_inputs and platform == 'tpu':
        raise jtu.SkipTest(
            'Skip TPU test for `same_inputs`. Need to handle '
            'random keys carefully for dropout + empirical kernel.')

    # Settings that are fixed for every dropout case.
    pool_type = 'AVG'
    use_dropout = True
    is_res = False
    layer_norm = None
    parameterization = 'ntk'
    W_std, b_std = 2.**0.5, 0.5**0.5
    is_conv = 'conv' in model

    # Check for duplicate / incorrectly-shaped NN configs / wrong backend.
    if not is_conv:
        if (filter_shape != FILTER_SHAPES[0] or
                padding != PADDINGS[0] or
                strides != STRIDES[0] or
                proj_into_2d != PROJECTIONS[0] or
                use_pooling):
            raise jtu.SkipTest('FC models do not have these parameters.')
    else:
        if platform == 'cpu':
            raise jtu.SkipTest('Not running CNN models on CPU to save time.')

        if is_res and (
                (strides is not None and strides != (1, 1)) or
                (padding == 'VALID' and filter_shape != (1, 1))):
            raise jtu.SkipTest(
                'Different paths in a residual models need to return'
                ' outputs of the same shape.')

    net = _get_net(W_std, b_std, filter_shape, is_conv, use_pooling, is_res,
                   padding, phi, strides, width, is_ntk, proj_into_2d,
                   pool_type, layer_norm, parameterization, use_dropout)
    self._check_agreement_with_empirical(
        net, same_inputs, is_conv, use_dropout, is_ntk, proj_into_2d)
def test_exact(self, model, width, strides, padding, phi, same_inputs,
               filter_size, use_pooling, is_ntk, is_res, proj_into_2d):
    """Run the empirical agreement check for one network configuration.

    Raises `jtu.SkipTest` for parameter combinations that are duplicates,
    incorrectly shaped, or not meant to run on the current backend.
    """
    is_conv = 'conv' in model

    # Check for duplicate / incorrectly-shaped NN configs / wrong backend.
    if not is_conv:
        if (filter_size != FILTER_SIZES[0] or
                padding != PADDINGS[0] or
                strides != STRIDES[0] or
                proj_into_2d != PROJECTIONS[0] or
                use_pooling):
            raise jtu.SkipTest('FC models do not have these parameters.')
    else:
        if xla_bridge.get_backend().platform == 'cpu':
            raise jtu.SkipTest('Not running CNN models on CPU to save time.')

        if is_res and (
                (strides is not None and strides != (1, 1)) or
                (padding == 'VALID' and filter_size != (1, 1))):
            raise jtu.SkipTest(
                'Different paths in a residual models need to return'
                ' outputs of the same shape.')

    if (proj_into_2d.startswith('ATTN') and
            strides == (2, 1) and
            padding == 'VALID' and
            xla_bridge.get_backend().platform == 'tpu'):
        # TODO: speed up the vmap alternative impl or fix the current one
        raise jtu.SkipTest('ATTN forward pass on TPU is broken if one of'
                           ' the spatial dimensions is singleton.')

    layer_norm = None
    W_std, b_std = 2.**0.5, 0.5**0.5
    self._check_agreement_with_empirical(
        W_std, b_std, filter_size, is_conv, is_ntk, is_res, layer_norm,
        padding, phi, proj_into_2d, same_inputs, strides, use_pooling, width)
def from_dlpack(dlpack, backend=None):
    """Wrap the DLPack tensor `dlpack` in a `DeviceArray`.

    The resulting `DeviceArray` shares memory with `dlpack`.

    Args:
      dlpack: a DLPack tensor, on either CPU or GPU.
      backend: deprecated, do not use.
    """
    if jax.lib._xla_extension_version < 25:
        # TODO(phawkins): drop the backend argument after deleting this case.
        if not backend:
            backend = xla_bridge.get_backend()
        client = getattr(backend, "client", backend)
        buf = xla_client._xla.dlpack_managed_tensor_to_buffer(dlpack, client)
    else:
        cpu_backend = xla_bridge.get_backend("cpu")
        try:
            gpu_backend = xla_bridge.get_backend("gpu")
        except RuntimeError:
            # No GPU backend could be obtained; pass None in its place.
            gpu_backend = None
        buf = xla_client._xla.dlpack_managed_tensor_to_buffer(
            dlpack, cpu_backend, gpu_backend)

    shape = buf.xla_shape()
    assert not shape.is_tuple()
    aval = core.ShapedArray(shape.dimensions(), shape.numpy_dtype())
    return xla.make_device_array(aval, buf.device(), buf)  # pytype: disable=attribute-error
def test_exact(self, model, width, strides, padding, phi, same_inputs,
               filter_size, use_pooling, is_ntk, is_res, proj_into_2d):
    """Compare the analytic kernel with a Monte Carlo estimate of it.

    For `proj_into_2d == 'ATTN_PARAM'` only a single-sample empirical kernel
    is evaluated and nothing is compared. Raises `jtu.SkipTest` for
    duplicate, incorrectly shaped, or wrong-backend configurations.
    """
    is_conv = 'conv' in model

    # Check for duplicate / incorrectly-shaped NN configs / wrong backend.
    if not is_conv:
        if (filter_size != FILTER_SIZES[0] or
                padding != PADDINGS[0] or
                strides != STRIDES[0] or
                proj_into_2d != PROJECTIONS[0] or
                use_pooling):
            raise jtu.SkipTest('FC models do not have these parameters.')
    else:
        if xla_bridge.get_backend().platform == 'cpu':
            raise jtu.SkipTest('Not running CNN models on CPU to save time.')

        if is_res and (
                (strides is not None and strides != (1, 1)) or
                (padding == 'VALID' and filter_size != (1, 1))):
            raise jtu.SkipTest(
                'Different paths in a residual models need to return'
                ' outputs of the same shape.')

    if (proj_into_2d.startswith('ATTN') and
            strides == (2, 1) and
            padding == 'VALID' and
            xla_bridge.get_backend().platform == 'tpu'):
        # TODO(jirihron): speed up the vmap alternative impl or fix the
        # current one
        raise jtu.SkipTest('ATTN forward pass on TPU is broken if one of'
                           ' the spatial dimensions is singleton.')

    W_std, b_std = 2.**0.5, 0.5**0.5

    key = random.PRNGKey(1)
    x1, x2 = _get_inputs(key, is_conv, same_inputs, INPUT_SHAPE)
    init_fn, apply_fn, kernel_fn = _get_net(
        W_std, b_std, filter_size, is_conv, use_pooling, is_res, padding,
        phi, strides, width, is_ntk, proj_into_2d)

    def _sample_kernel(n_samples, which):
        # Monte Carlo kernel estimate over `n_samples` draws.
        estimator = monte_carlo.monte_carlo_kernel_fn(
            init_fn, apply_fn, key, n_samples)
        return estimator(x1, x2, which)

    get = 'ntk' if is_ntk else 'nngp'

    if proj_into_2d == 'ATTN_PARAM':
        # no analytic kernel available, just test forward/backward pass
        _sample_kernel(1, get)
        return

    exact = kernel_fn(x1, x2, get)
    empirical = _sample_kernel(N_SAMPLES, get)
    if is_ntk:
        empirical = np.reshape(empirical, exact.shape)
    utils.assert_close_matrices(self, empirical, exact, RTOL)
def test_convert_scalars(self):
    """Check values and dtypes of buffers from `jax_jit._ScalarToBuffer`.

    Python int, float, bool and complex scalars are converted and compared
    against the dtype expected for the current `jax_enable_x64` setting and
    against the dtype `jnp.asarray` produces for the same scalar.

    Returns early, without asserting anything, when `jaxlib.version` is
    older than (0, 1, 53).
    """
    # TODO(jblespiau): Remove when the version is out.
    if jaxlib.version < (0, 1, 53):
        return
    jax_jit = jaxlib.jax_jit

    jax_enable_x64 = FLAGS.jax_enable_x64

    if jax_enable_x64:
        int_type = np.int64
        float_type = np.float64
        complex_type = np.complex128
    else:
        int_type = np.int32
        float_type = np.float32
        complex_type = np.complex64

    def scalar_to_py(value):
        # Convert one Python scalar to a buffer and bring it back to host.
        return jax_jit._ScalarToBuffer(
            value, jax_enable_x64, xla_bridge.get_backend()).to_py()

    # int
    res = scalar_to_py(1)
    self.assertEqual(res, 1)
    self.assertEqual(res.dtype, int_type)
    # We also compare to the Python Jax API, to make sure we have the exact
    # same behavior. When Jax removes the flag and removes this feature, this
    # test will fail.
    self.assertEqual(jnp.asarray(1).dtype, res.dtype)

    # float
    res = scalar_to_py(1.0)
    self.assertEqual(res, 1.0)
    self.assertEqual(res.dtype, float_type)
    self.assertEqual(jnp.asarray(1.0).dtype, res.dtype)

    # bool
    for bool_value in [True, False]:
        res = scalar_to_py(bool_value)
        self.assertEqual(res, np.asarray(bool_value))
        # `np.bool` was a deprecated alias that NumPy 1.24 removed;
        # `np.bool_` names the same dtype on every NumPy release.
        self.assertEqual(res.dtype, np.bool_)
        self.assertEqual(jnp.asarray(bool_value).dtype, res.dtype)

    # Complex
    res = scalar_to_py(1 + 1j)
    self.assertEqual(res, 1 + 1j)
    self.assertEqual(res.dtype, complex_type)
    self.assertEqual(jnp.asarray(1 + 1j).dtype, res.dtype)
def static_cast(*xs):
    """Cast each value to the lowest dtype that can express it.

    On the 'tpu' platform every value is cast to `float32` instead.
    Returns a generator that yields one array per input value.
    """
    # NOTE(schsam): static_cast is so named because it cannot be jit.
    if xla_bridge.get_backend().platform == 'tpu':
        def _cast(value):
            return np.array(value, np.float32)
    else:
        def _cast(value):
            return np.array(value, dtype=onp.min_scalar_type(value))

    return (_cast(x) for x in xs)
def testIsOnCPU(self):
    """Check `utils.is_on_cpu` for default, host-fetched and jitted samples."""
    for dtype in [np.float32, np.float64]:
        with self.subTest(dtype=dtype):

            def sample():
                return random.normal(random.PRNGKey(1), (2, 3), dtype)

            def sample_fetched():
                return device_get(sample())

            sample_jit = jit(sample)
            # x_cpu_jit = jit(x_cpu)
            fetched_jit_cpu = jit(sample_fetched, backend='cpu')

            self.assertTrue(utils.is_on_cpu(sample_fetched()))
            # TODO(mattjj): re-enable this when device_put under jit works
            # self.assertTrue(utils.is_on_cpu(x_cpu_jit()))
            self.assertTrue(utils.is_on_cpu(fetched_jit_cpu()))

            # Samples left where they were produced are on CPU only when
            # the backend platform itself is 'cpu'.
            default_is_cpu = xla_bridge.get_backend().platform == 'cpu'
            check = self.assertTrue if default_is_cpu else self.assertFalse
            check(utils.is_on_cpu(sample()))
            check(utils.is_on_cpu(sample_jit()))
def testIsOnCPU(self):
    """Check `predict._is_on_cpu` for default, host-fetched and jitted samples."""
    for dtype in [np.float32, np.float64]:
        with self.subTest(dtype=dtype):

            def sample():
                return random.normal(random.PRNGKey(1), (2, 3), dtype)

            def sample_fetched():
                return device_get(sample())

            sample_jit = jit(sample)
            fetched_jit = jit(sample_fetched)
            fetched_jit_cpu = jit(sample_fetched, backend='cpu')

            self.assertTrue(predict._is_on_cpu(sample_fetched()))
            self.assertTrue(predict._is_on_cpu(fetched_jit()))
            self.assertTrue(predict._is_on_cpu(fetched_jit_cpu()))

            # Samples left where they were produced are on CPU only when
            # the backend platform itself is 'cpu'.
            default_is_cpu = xla_bridge.get_backend().platform == 'cpu'
            check = self.assertTrue if default_is_cpu else self.assertFalse
            check(predict._is_on_cpu(sample()))
            check(predict._is_on_cpu(sample_jit()))
def test_layernorm(self, model, width, same_inputs, is_ntk, proj_into_2d,
                   layer_norm):
    """Run the empirical agreement check with the given `layer_norm`.

    Raises `jtu.SkipTest` for duplicate / incorrectly-shaped configurations
    and for conv models on the 'cpu' platform.
    """
    is_conv = 'conv' in model

    # Check for duplicate / incorrectly-shaped NN configs / wrong backend.
    if not is_conv:
        if proj_into_2d != PROJECTIONS[0] or layer_norm != LAYER_NORM[0]:
            raise jtu.SkipTest('FC models do not have these parameters.')
    elif xla_bridge.get_backend().platform == 'cpu':
        raise jtu.SkipTest(
            'Not running CNN models on CPU to save time.')

    # Everything below is held fixed; only the arguments above vary.
    W_std, b_std = 2.**0.5, 0.5**0.5
    filter_size = FILTER_SIZES[0]
    padding = PADDINGS[0]
    strides = STRIDES[0]
    phi = stax.Relu()
    use_pooling = False
    is_res = False
    parameterization = 'ntk'
    use_dropout = False

    self._check_agreement_with_empirical(
        W_std, b_std, filter_size, is_conv, is_ntk, is_res, layer_norm,
        padding, phi, proj_into_2d, same_inputs, strides, use_pooling,
        width, parameterization, use_dropout)
def device_memory_profile(backend: Optional[str] = None) -> bytes:
    """Captures a JAX device memory profile as ``pprof``-format protocol buffer.

    A device memory profile is a snapshot of the state of memory. It describes
    the JAX :class:`jax.DeviceArray` and executable objects present in memory
    together with their allocation sites.

    For more information on how to use the device memory profiler, see
    :doc:`/device_memory_profiling`.

    The profiling system works by instrumenting JAX on-device allocations,
    capturing a Python stack trace for each allocation. The instrumentation is
    always enabled; :func:`device_memory_profile` provides an API to capture
    it.

    The output of :func:`device_memory_profile` is a binary protocol buffer
    that can be interpreted and visualized by the `pprof tool
    <https://github.com/google/pprof>`_.

    Args:
      backend: optional; the name of the JAX backend for which the device
        memory profile should be collected.

    Returns:
      A byte string containing a binary `pprof`-format protocol buffer.
    """
    client = xla_bridge.get_backend(backend)
    return xla_client.heap_profile(client)
def test_gpu_translation_rule(self):
    """Check the GPU rule for `csr_todense_p` is registered iff the version allows.

    The last whitespace-separated token of the backend's `platform_version`
    is parsed as an integer; the rule is expected only when it is >= 11000.
    """
    version = xla_bridge.get_backend().platform_version
    if version == "<unknown>":
        cuda_version = None
    else:
        cuda_version = int(version.split()[-1])

    gpu_rules = xla.backend_specific_translations["gpu"]
    if cuda_version is not None and cuda_version >= 11000:
        self.assertIn(sparse_ops.csr_todense_p, gpu_rules)
    else:
        self.assertNotIn(sparse_ops.csr_todense_p, gpu_rules)
def test_parameterizations(self, model, width, same_inputs, is_ntk,
                           filter_shape, proj_into_2d, parameterization):
    """Build a net with the given `parameterization` and check it empirically.

    Raises `jtu.SkipTest` for duplicate / incorrectly-shaped configurations
    and for conv models on the 'cpu' platform.
    """
    is_conv = 'conv' in model

    # Everything in this group is held fixed for this test.
    W_std, b_std = 2.**0.5, 0.5**0.5
    padding = PADDINGS[0]
    strides = STRIDES[0]
    phi = stax.Relu()
    use_pooling = False
    is_res = False
    layer_norm = None
    pool_type = 'AVG'
    use_dropout = False

    # Check for duplicate / incorrectly-shaped NN configs / wrong backend.
    if not is_conv:
        if proj_into_2d != PROJECTIONS[0]:
            raise jtu.SkipTest('FC models do not have these parameters.')
    elif xla_bridge.get_backend().platform == 'cpu':
        raise jtu.SkipTest('Not running CNN models on CPU to save time.')

    net = _get_net(W_std, b_std, filter_shape, is_conv, use_pooling, is_res,
                   padding, phi, strides, width, is_ntk, proj_into_2d,
                   pool_type, layer_norm, parameterization, use_dropout)
    self._check_agreement_with_empirical(
        net, same_inputs, is_conv, use_dropout, is_ntk, proj_into_2d)
def testPredictOnCPU(self):
    """Check predictions at `t=None` and `t=inf` agree and where they live.

    When test inputs are supplied, both predictions are expected on CPU
    exactly when `store_on_device` is false or the platform is 'cpu'.
    """
    x_train = random.normal(random.PRNGKey(1), (4, 4, 4, 2))
    x_test = random.normal(random.PRNGKey(1), (8, 4, 4, 2))
    y_train = random.uniform(random.PRNGKey(1), (4, 2))

    _, _, kernel_fn = stax.serial(
        stax.Conv(1, (3, 3)),
        stax.Relu(),
        stax.Flatten(),
        stax.Dense(1))

    get_options = ['ntk', 'nngp', ('nngp', 'ntk'), ('ntk', 'nngp')]

    for store_on_device in [False, True]:
        for device_count in [0, 1]:
            for get in get_options:
                for x_label in [None, 'x_test']:
                    with self.subTest(store_on_device=store_on_device,
                                      device_count=device_count,
                                      get=get,
                                      x=x_label):
                        batched_kernel_fn = batch.batch(
                            kernel_fn, 2, device_count, store_on_device)
                        predictor = predict.gradient_descent_mse_ensemble(
                            batched_kernel_fn, x_train, y_train)

                        # The label only selects between "no test inputs"
                        # and the actual test array.
                        inputs = None if x_label is None else x_test

                        predict_none = predictor(
                            None, inputs, get, compute_cov=True)
                        predict_inf = predictor(
                            np.inf, inputs, get, compute_cov=True)
                        self.assertAllClose(predict_none, predict_inf)

                        if inputs is None:
                            continue

                        on_cpu = (
                            not store_on_device or
                            xla_bridge.get_backend().platform == 'cpu')
                        self.assertEqual(
                            on_cpu, utils.is_on_cpu(predict_inf))
                        self.assertEqual(
                            on_cpu, utils.is_on_cpu(predict_none))
def train_loop(key,
               init_params,
               loss_fn,
               parallel=True,
               summarize_fn=default_summarize,
               lr=1e-4,
               num_steps=int(1e5),
               summarize_every=100,
               checkpoint_every=5000,
               clobber_checkpoint=False,
               logdir="/tmp/lda_inference"):
    """Dispatch training to the parallel or the local training loop.

    `parallel_train_loop` is used when `parallel` is set and
    `can_train_parallel()` is true. Otherwise `local_train_loop` is used; a
    notice is printed when parallel training was requested but unavailable.
    All remaining arguments are forwarded unchanged to the chosen loop.
    """
    if parallel and can_train_parallel():
        train_fn = parallel_train_loop
    else:
        if parallel:
            # Parallel training was asked for but cannot be done here.
            print(
                "Platform is %s and num devices is %d, defaulting to local training."
                % (xla_bridge.get_backend().platform,
                   len(xla_bridge.devices())))
        train_fn = local_train_loop

    train_fn(
        key,
        init_params,
        loss_fn,
        summarize_fn=summarize_fn,
        lr=lr,
        num_steps=num_steps,
        summarize_every=summarize_every,
        checkpoint_every=checkpoint_every,
        clobber_checkpoint=clobber_checkpoint,
        logdir=logdir)
def _gamma_grad(sample, a):
    """Apply `_gamma_grad_one` elementwise to flattened `sample` and `a`.

    Uses `lax.map` on the 'cpu' platform and `vmap` elsewhere. The result
    is reshaped to the shape of `a`.
    """
    flat_samples = np.reshape(sample, -1)
    flat_alphas = np.reshape(a, -1)

    if xla_bridge.get_backend().platform == 'cpu':
        def _one(pair):
            return _gamma_grad_one(*pair)

        grads = lax.map(_one, (flat_samples, flat_alphas))
    else:
        grads = vmap(_gamma_grad_one)(flat_samples, flat_alphas)

    return grads.reshape(onp.shape(a))
def testNTKMeanPrediction(self, train_shape, test_shape, network, out_logits):
    """Compare the `gp_inference` NTK mean with a Monte Carlo estimate.

    Also checks that the returned variance has one row per test point and
    that its smallest eigenvalue is greater than -1e-10.
    """
    key = random.PRNGKey(0)

    key, subkey = random.split(key)
    data_train = np.cos(random.normal(subkey, train_shape))

    key, subkey = random.split(key)
    data_labels = np.array(
        random.bernoulli(subkey, shape=(train_shape[0], out_logits)),
        np.float32)

    key, subkey = random.split(key)
    data_test = np.cos(random.normal(subkey, test_shape))

    _, _, ker_fun = _build_network(train_shape[1:], network, out_logits)
    mean_pred, var = predict.gp_inference(
        ker_fun,
        data_train,
        data_labels,
        data_test,
        diag_reg=0.,
        mode='NTK',
        compute_var=True)

    on_tpu = xla_bridge.get_backend().platform == 'tpu'
    eigh = np.onp.linalg.eigh if on_tpu else np.linalg.eigh

    self.assertEqual(var.shape[0], data_test.shape[0])
    smallest_eigenvalue = np.min(eigh(var)[0])
    self.assertGreater(smallest_eigenvalue + 1e-10, 0.)

    def mc_sampling(count=10):
        # Average the test prediction at time 1.0e8 over `count` draws.
        total = 0.
        mc_key = random.PRNGKey(100)
        for _ in range(count):
            mc_key, draw_key = random.split(mc_key)
            params, f, theta = _empirical_kernel(
                draw_key, train_shape[1:], network, out_logits)
            predictor = predict.gradient_descent_mse(
                theta(data_train, None),
                data_labels,
                theta(data_test, data_train))
            fx_train_0 = f(params, data_train)
            fx_test_0 = f(params, data_test)
            _, fx_test_t = predictor(1.0e8, fx_train_0, fx_test_0)
            total += fx_test_t
        return total / count

    mean_emp = mc_sampling(100)
    self.assertAllClose(mean_pred, mean_emp, True, RTOL, ATOL)
def testPredictOnCPU(self):
    """Check predictions at `t=None` and `t=inf` agree and where they live.

    When test inputs are supplied, both predictions are expected on CPU
    exactly when `store_on_device` is false or the platform is 'cpu'.
    """

    def _seed_key():
        return stateless_uniform(
            shape=[2], seed=[1, 1], minval=None, maxval=None, dtype=tf.int32)

    key1, key2, key3 = [_seed_key() for _ in range(3)]

    x_train = np.asarray(normal((4, 4, 4, 2), seed=key1))
    x_test = np.asarray(normal((8, 4, 4, 2), seed=key2))
    y_train = np.asarray(stateless_uniform(shape=(4, 2), seed=key3))

    _, _, kernel_fn = stax.serial(
        stax.Conv(1, (3, 3)),
        stax.Relu(),
        stax.Flatten(),
        stax.Dense(1))

    get_options = ['ntk', 'nngp', ('nngp', 'ntk'), ('ntk', 'nngp')]

    for store_on_device in [False, True]:
        for device_count in [0, 1]:
            for get in get_options:
                for x_label in [None, 'x_test']:
                    with self.subTest(store_on_device=store_on_device,
                                      device_count=device_count,
                                      get=get,
                                      x=x_label):
                        batched_kernel_fn = batch.batch(
                            kernel_fn, 2, device_count, store_on_device)
                        predictor = predict.gradient_descent_mse_ensemble(
                            batched_kernel_fn, x_train, y_train)

                        # The label only selects between "no test inputs"
                        # and the actual test array.
                        inputs = None if x_label is None else x_test

                        predict_none = predictor(
                            None, inputs, get, compute_cov=True)
                        predict_inf = predictor(
                            np.inf, inputs, get, compute_cov=True)
                        self.assertAllClose(predict_none, predict_inf)

                        if inputs is None:
                            continue

                        on_cpu = (
                            not store_on_device or
                            xla_bridge.get_backend().platform == 'cpu')
                        self.assertEqual(
                            on_cpu, utils.is_on_cpu(predict_inf))
                        self.assertEqual(
                            on_cpu, utils.is_on_cpu(predict_none))
def _res_tf_to_jax(res_tf):
    """Convert a TF result to a JAX value, via DLPack when possible.

    The DLPack path is taken for a `tf.Tensor` whose dtype is in
    `dlpack.SUPPORTED_DTYPES` and whose backing device type (lowercased) is
    in `_DLPACK_PLATFORMS`. Anything else goes through `np.asarray` and
    `_device_put_raw`.
    """
    dlpack_candidate = (
        isinstance(res_tf, tf.Tensor) and
        res_tf.dtype in dlpack.SUPPORTED_DTYPES)

    if dlpack_candidate:
        device_spec = tf.DeviceSpec.from_string(res_tf.backing_device)
        platform = device_spec.device_type.lower()
        if platform in _DLPACK_PLATFORMS:
            capsule = tf.experimental.dlpack.to_dlpack(res_tf)
            backend = xla_bridge.get_backend(platform)
            return jax.dlpack.from_dlpack(capsule, backend=backend)

    return _device_put_raw(np.asarray(res_tf))
def _binomial(key, p, n, shape):
    """Apply `_binomial_dispatch` elementwise over broadcast `p` and `n`.

    When `shape` is falsy, the broadcast shape of `p` and `n` is used. One
    key is split off per element. Uses `lax.map` on the 'cpu' platform and
    `vmap` elsewhere; the result is reshaped to `shape`.
    """
    if not shape:
        shape = lax.broadcast_shapes(np.shape(p), np.shape(n))

    # reshape to map over axis 0
    flat_p = np.reshape(np.broadcast_to(p, shape), -1)
    flat_n = np.reshape(np.broadcast_to(n, shape), -1)
    keys = random.split(key, np.size(flat_p))

    if xla_bridge.get_backend().platform == 'cpu':
        def _from_tuple(operands):
            return _binomial_dispatch(*operands)

        flat_result = lax.map(_from_tuple, (keys, flat_p, flat_n))
    else:
        def _from_args(*operands):
            return _binomial_dispatch(*operands)

        flat_result = vmap(_from_args)(keys, flat_p, flat_n)

    return np.reshape(flat_result, shape)
def _eigh(mat):
    """Platform specific eigh."""
    # TODO(schsam): Eventually, we may want to handle non-symmetric kernels
    # for e.g. masking. Additionally, once JAX supports eigh on TPU, we
    # probably want to switch to JAX's eigh.
    if xla_bridge.get_backend().platform == 'tpu':
        return np.onp.linalg.eigh(mat)

    # Off TPU the decomposition is jitted; the 'cpu' backend is requested
    # when `mat` is reported to be on CPU.
    if _is_on_cpu(mat):
        jitted_eigh = jit(np.linalg.eigh, backend='cpu')
    else:
        jitted_eigh = jit(np.linalg.eigh)
    return jitted_eigh(mat)
def testTorchToJaxFailure(self):
    """Check that importing a column-sliced torch tensor via DLPack fails.

    The conversion is expected to raise `RuntimeError` with the
    "trivial (compact) striding" message.
    """
    tensor = torch.arange(6).reshape((2, 3))
    sliced_capsule = torch.utils.dlpack.to_dlpack(tensor[:, :2])

    backend = xla_bridge.get_backend()
    client = getattr(backend, "client", backend)

    expected_message = (
        r'Unimplemented: Only DLPack tensors with trivial \(compact\) '
        r'striding are supported')
    with self.assertRaisesRegex(RuntimeError, expected_message):
        xla_client._xla.dlpack_managed_tensor_to_buffer(
            sliced_capsule, client)
def stub_out_pmap(batch, count):
    """Replace `batch.pmap` with `_jit_vmap` when running on GPU or CPU.

    Does nothing unless `count > 0`. On the 'gpu' and 'cpu' platforms,
    `batch.xla_bridge` is also replaced by a stub whose `device_count()`
    returns `count`.
    """
    # If we are using GPU or CPU stub out pmap with vmap to simulate
    # multi-core.
    if not count > 0:
        return

    class xla_bridge_stub(object):

        def device_count(self):
            return count

    if xla_bridge.get_backend().platform in ('gpu', 'cpu'):
        batch.pmap = _jit_vmap
        batch.xla_bridge = xla_bridge_stub()
def _poisson(key, rate, shape, dtype):
    """Apply `_poisson_one` elementwise over `rate` broadcast to `shape`.

    When `shape` is falsy, the shape of `rate` is used. One key is split off
    per element. Uses `lax.map` on the 'cpu' platform and `vmap` elsewhere;
    the result is converted to `dtype` and reshaped to `shape`.
    """
    # Ref: https://en.wikipedia.org/wiki/Poisson_distribution#Generating_Poisson-distributed_random_variables
    if not shape:
        shape = np.shape(rate)

    float_rate = lax.convert_element_type(
        rate, canonicalize_dtype(np.float64))
    full_rate = np.broadcast_to(float_rate, shape)
    rng_keys = random.split(key, np.size(full_rate))

    # Both mapping paths hand `_poisson_one` a single (key, rate) tuple.
    operands = (rng_keys, np.reshape(full_rate, -1))
    if xla_bridge.get_backend().platform == 'cpu':
        counts = lax.map(_poisson_one, operands)
    else:
        counts = vmap(_poisson_one)(operands)

    counts = lax.convert_element_type(counts, dtype)
    return np.reshape(counts, shape)
def _gamma_impl(key, a):
    """Apply `_gamma_one` elementwise over `a`, one split key per element.

    Uses `lax.map` on the 'cpu' platform and `vmap` elsewhere. Returns a
    one-element tuple holding the samples reshaped to the shape of `a`.
    """
    a_shape = np.shape(a)

    # split key to match the shape of a
    key_ndim = np.ndim(key) - 1
    splits_per_key = prod(a_shape[key_ndim:])
    base_keys = np.reshape(key, (-1, 2))
    split_keys = vmap(split, in_axes=(0, None))(base_keys, splits_per_key)
    keys = np.reshape(split_keys, (-1, 2))
    alphas = np.reshape(a, -1)

    if xla_bridge.get_backend().platform == 'cpu':
        def _one(pair):
            return _gamma_one(*pair)

        samples = lax.map(_one, (keys, alphas))
    else:
        samples = vmap(_gamma_one)(keys, alphas)

    return (np.reshape(samples, a_shape),)
def test_conv_local_general_dilated(self, n, padding, lhs_spec, rhs_spec,
                                    out_spec):
    """Make sure LCN with tiled CNN kernel matches CNN."""
    if xla_bridge.get_backend().platform == 'cpu' and n > 1:
        raise absltest.SkipTest('Skipping large tests on CPU.')

    ndim = n + 2
    default_lhs_spec = 'NCHWDX'[:ndim]
    default_rhs_spec = 'OIHWDX'[:ndim]

    default_lhs = random.normal(random.PRNGKey(1), (2, 4, 7, 6, 5, 8)[:ndim])
    default_rhs = random.normal(random.PRNGKey(2), (3, 4, 2, 3, 1, 2)[:ndim])

    # Permute the default-layout operands into the requested layouts.
    lhs_perm = [default_lhs_spec.index(axis) for axis in lhs_spec]
    rhs_perm = [default_rhs_spec.index(axis) for axis in rhs_spec]
    lhs = np.transpose(default_lhs, lhs_perm)
    rhs = np.transpose(default_rhs, rhs_perm)

    kwargs = dict(
        lhs=lhs,
        window_strides=(1, 2, 3, 4)[:n],
        padding=padding,
        rhs_dilation=(2, 1, 3, 2)[:n],
        dimension_numbers=(lhs_spec, rhs_spec, out_spec))

    out_conv = lax.conv_general_dilated(rhs=rhs, **kwargs)

    # Tile the kernel over the conv output's spatial dimensions.
    out_in_axes = (rhs_spec.index('O'), rhs_spec.index('I'))
    tiled = np.moveaxis(rhs, out_in_axes, (0, 1))
    tiled = tiled.reshape((tiled.shape[0], -1) + (1,) * n)

    spatial_out = tuple(
        out_conv.shape[out_spec.index(axis)]
        for axis in default_rhs_spec[2:])
    tiled = np.broadcast_to(tiled, tiled.shape[:2] + spatial_out)
    rhs_local = np.transpose(tiled, rhs_perm)

    filter_shape = [
        size for size, axis in zip(rhs.shape, rhs_spec)
        if axis not in ('O', 'I')
    ]
    out_local = utils.conv_local_general_dilated(
        rhs=rhs_local, filter_shape=filter_shape, **kwargs)

    self.assertAllClose(out_conv, out_local, atol=1e-5, rtol=1e-5)
def _replicate(x, devices=None):
    """Put a copy of `x` on each device and wrap them in a `ShardedDeviceArray`.

    When `devices` is None, the devices are chosen to match the default
    device assignment used in pmap.
    """
    x = jax.numpy.array(x)

    if devices is None:
        # match the default device assignments used in pmap:
        # for single-host, that's the XLA default device assignment
        # for multi-host, it's the order of jax.local_devices()
        if jax.host_count() != 1:
            devices = jax.local_devices()
        else:
            assignment = xb.get_backend().get_default_device_assignment(
                jax.device_count())
            devices = []
            for device in assignment:
                if device.host_id == jax.host_id():
                    devices.append(device)

    aval = jax.ShapedArray((len(devices),) + x.shape, x.dtype)
    buffers = []
    for device in devices:
        buffers.append(jax.interpreters.xla.device_put(x, device=device))
    return jax.pxla.ShardedDeviceArray(aval, buffers)
def test_exact(self, model, width, strides, padding, phi, same_inputs,
               filter_size, use_pooling, is_ntk, is_res):
    """Compare the analytic kernel with a Monte Carlo estimate of it.

    The NTK is compared when `is_ntk` is set, otherwise the NNGP. Raises
    `jtu.SkipTest` for duplicate, incorrectly shaped, unsupported, or
    wrong-backend configurations.
    """
    is_conv = 'conv' in model

    # Check for duplicate / incorrectly-shaped NN configs / wrong backend.
    if not is_conv:
        if (filter_size != FILTER_SIZES[0] or
                padding != PADDINGS[0] or
                strides != STRIDES[0] or
                use_pooling):
            raise jtu.SkipTest('FC models do not have these parameters.')
    else:
        if xla_bridge.get_backend().platform == 'cpu':
            raise jtu.SkipTest(
                'Not running CNN models on CPU to save time.')

        if use_pooling and not same_inputs:
            raise jtu.SkipTest(
                'Pooling layers for different inputs or for same '
                'padding not implemented.')

        if is_res and (
                (strides is not None and strides != (1, 1)) or
                (padding == 'VALID' and filter_size != (1, 1))):
            raise jtu.SkipTest(
                'Different paths in a residual models need to return'
                ' outputs of the same shape.')

    W_std, b_std = 2.**0.5, 0.5**0.5

    key = random.PRNGKey(1)
    x1, x2 = _get_inputs(key, is_conv, same_inputs, INPUT_SHAPE)

    init_fun, apply_fun, ker_fun = _get_net(
        W_std, b_std, filter_size, is_conv, use_pooling, is_res, padding,
        phi, strides, width, is_ntk)

    if not is_ntk:
        exact = ker_fun(x1, x2, compute_ntk=False).nngp
        mc_ker_fun = monte_carlo.get_ker_fun_monte_carlo(
            init_fun, apply_fun, True, False)
        empirical = mc_ker_fun(x1, x2, key, N_SAMPLES).nngp
    else:
        exact = ker_fun(x1, x2).ntk
        mc_ker_fun = monte_carlo.get_ker_fun_monte_carlo(
            init_fun, apply_fun, False, True)
        empirical = np.reshape(
            mc_ker_fun(x1, x2, key, N_SAMPLES).ntk, exact.shape)

    utils.assert_close_matrices(self, empirical, exact, RTOL)
def _dynamic_xla_call_impl(*args, jaxpr, num_consts):
    """Build, compile and execute an XLA computation for `jaxpr`.

    `args` is split into dimension values, constants and the remaining
    arguments. The computation is compiled with the backend returned by
    `xb.get_backend(None)` and run through `execute_compiled`.
    """
    in_dim_vals, consts, args = split_list(
        args, [len(jaxpr.in_dim_binders), num_consts])

    builder = xb.make_computation_builder("dxla_call")

    dim_in_avals = [binder.aval for binder in jaxpr.in_dim_binders]
    arg_avals = map(xla.abstractify, args)
    dim_params, params = _make_params(builder, dim_in_avals, arg_avals)
    const_params = _xla_consts(builder, consts)

    dim_outs, outs = djaxpr_subcomp(
        builder, jaxpr, dim_params, const_params + params)

    # Flatten the per-output op groups into one tuple-shaped result.
    flat_out_ops = []
    for ops in dim_outs + outs:
        flat_out_ops.extend(ops)
    out_tuple = xops.Tuple(builder, flat_out_ops)

    compiled = xb.get_backend(None).compile(builder.build(out_tuple))

    result_handlers = map(result_handler, [v.aval for v in jaxpr.outs])
    out_bufcounts = [v.aval._num_buffers for v in jaxpr.outs]
    partitioner = result_partitioner(
        jaxpr.in_dim_binders, in_dim_vals, jaxpr.out_dims, out_bufcounts)

    return execute_compiled(
        compiled, partitioner, result_handlers, in_dim_vals, args)
def stub_out_pmap(batch, count):
    """Replace `batch.pmap` with `vmap` when running on GPU or CPU.

    Does nothing unless `count > 1`. On the 'gpu' and 'cpu' platforms,
    `batch.xla_bridge` is also replaced by a stub whose `device_count()`
    returns `count`.
    """
    # If we are using GPU or CPU stub out pmap with vmap to simulate
    # multi-core.
    if not count > 1:
        return

    class xla_bridge_stub(object):

        def device_count(self):
            return count

    if xla_bridge.get_backend().platform in ('gpu', 'cpu'):
        # TODO(romann): investigate why vmap is extremely slow in
        # `utils/monte_carlo_test.py`, `test_monte_carlo_vs_analytic`.
        # Example: http://sponge/e081c176-e77f-428c-846d-bafbfd86a46c
        batch.pmap = vmap
        batch.xla_bridge = xla_bridge_stub()
def _res_tf_to_jax(res_tf: TfVal, out_aval: core.AbstractValue):
    """Convert a TF result to a JAX array, via DLPack when possible.

    `res_tf` is first passed through
    `jax2tf_internal._tfval_to_tensor_jax_dtype` with `out_aval.dtype`. The
    DLPack path is taken for a `tf.Tensor` whose dtype is in
    `dlpack.SUPPORTED_DTYPES` and whose backing device type (lowercased) is
    in `_DLPACK_PLATFORMS`. Anything else goes through `np.asarray` and
    `jnp.asarray`.
    """
    res_tf, _ = jax2tf_internal._tfval_to_tensor_jax_dtype(
        res_tf, jax_dtype=out_aval.dtype)

    dlpack_candidate = (
        isinstance(res_tf, tf.Tensor) and
        res_tf.dtype in dlpack.SUPPORTED_DTYPES)

    if dlpack_candidate:
        device_spec = tf.DeviceSpec.from_string(res_tf.backing_device)
        platform = device_spec.device_type.lower()
        if platform in _DLPACK_PLATFORMS:
            capsule = tf.experimental.dlpack.to_dlpack(res_tf)
            backend = xla_bridge.get_backend(platform)
            return jax.dlpack.from_dlpack(capsule, backend=backend)

    return jnp.asarray(np.asarray(res_tf))