def _parallel(kernel_fn: KernelFn, device_count: int = -1) -> KernelFn:
  """Returns a function that computes a kernel in batches in parallel.

  When batching in parallel, the data is split over a set number of devices.
  The number of devices must be less than or equal to the number of physical
  devices. Moreover, the device count needs to divide the dataset size. If the
  dataset has fewer rows than `device_count`, each row is placed on its own
  device instead.

  Given two datasets `x1` and `x2`, parallel splits the kernel calculation over
  devices such that each device computes a batch of rows of shape
  `[|x1| / device_count, |x2|]`.

  Args:
    kernel_fn: A function that computes a kernel between two datasets,
      `kernel_fn(x1, x2)` or the compositional kernel for an input kernel
      `kernel_fn(kernel_in)`. Here `x1` and `x2` are `np.ndarray`s of floats of
      shape `(n1,) + input_shape` and `(n2,) + input_shape`; `kernel_in` is a
      Kernel object. The kernel function should return a `PyTree`.
    device_count: Integer specifying the number of devices over which to split
      the data. If `device_count == -1`, the computation is parallelized over
      all available devices.

  Returns:
    A new function with the same signature as kernel_fn that computes the kernel
    by batching over the dataset in parallel over a specified number of cores.

  Raises (from the returned function):
    ValueError: if the dataset size is at least `device_count` but not
      divisible by it.
  """
  if device_count == -1:
    device_count = xla_bridge.device_count()

  def _get_n_per_device(n1):
    """Returns `(n1_per_device, _device_count)` for a dataset of `n1` rows."""
    _device_count = device_count
    n1_per_device, ragged = divmod(n1, device_count)
    if n1_per_device and ragged:
      raise ValueError(
          ('Dataset size ({}) must divide number of '
           'physical devices ({}).').format(n1, device_count))
    elif not n1_per_device:
      # Fewer rows than devices: put a single row on each of `n1` devices.
      _device_count = ragged
      n1_per_device = 1
    return n1_per_device, _device_count

  def parallel_fn_x1(x1, x2=None, *args, **kwargs):
    x2_is_none = x2 is None
    if x2_is_none:
      # TODO(schsam): Only compute the upper triangular part of the kernel.
      x2 = x1

    n1 = x1.shape[0]

    assert x1.shape[1:] == x2.shape[1:]
    input_shape = x1.shape[1:]

    n1_per_device, _device_count = _get_n_per_device(n1)
    # Pmap over exactly as many devices as `x1` is split into, so that the
    # broadcasted array arguments and the sharded `x1` agree on the leading
    # (device) axis even when `n1 < device_count`.
    _kernel_fn = _jit_or_pmap_broadcast(kernel_fn, _device_count)

    for k, v in kwargs.items():
      if _is_np_ndarray(v):
        assert isinstance(v, tuple) and len(v) == 2
        # Only the first element of the pair is split across devices,
        # the same way `x1` is.
        v0 = np.reshape(v[0], (_device_count, n1_per_device,) + v[0].shape[1:])
        kwargs[k] = (v0, v[1])

    x1 = np.reshape(x1, (_device_count, n1_per_device,) + input_shape)
    kernel = _kernel_fn(x1, x2, *args, **kwargs)
    return _flatten_kernel(kernel, x2_is_none, True)

  def parallel_fn_kernel(kernel, *args, **kwargs):
    n1 = kernel.cov1.shape[0]

    n1_per_device, _device_count = _get_n_per_device(n1)
    # See `parallel_fn_x1`: the pmap size must match the number of shards.
    _kernel_fn = _jit_or_pmap_broadcast(kernel_fn, _device_count)

    cov2_is_none = kernel.cov2 is None
    kernel = _reshape_kernel_for_pmap(kernel, _device_count, n1_per_device)
    kernel = _kernel_fn(kernel, *args, **kwargs)
    if cov2_is_none:
      kernel = kernel.replace(cov2=None)
    return _flatten_kernel(kernel, cov2_is_none, True)

  @utils.wraps(kernel_fn)
  def parallel_fn(x1_or_kernel, x2=None, *args, **kwargs):
    if isinstance(x1_or_kernel, np.ndarray):
      return parallel_fn_x1(x1_or_kernel, x2, *args, **kwargs)
    elif isinstance(x1_or_kernel, Kernel):
      # `x2` may be an array; `not x2` would be ambiguous for arrays.
      assert x2 is None
      return parallel_fn_kernel(x1_or_kernel, *args, **kwargs)
    raise NotImplementedError()

  # Set function attributes so that `serial` can detect whether or not it is
  # acting on a parallel function.
  parallel_fn.device_count = device_count
  return parallel_fn
def run(self, rng_key, *args, extra_fields=(), init_params=None, **kwargs):
    """
    Run the MCMC samplers and collect samples.

    :param random.PRNGKey rng_key: Key that drives the sampling. With several chains a
        batch of `num_chains` keys may be passed; a single unbatched key is split into
        `num_chains` keys automatically.
    :param args: Positional arguments forwarded to
        :meth:`numpyro.infer.mcmc.MCMCKernel.init`, typically the inputs of the `model`.
    :param extra_fields: Additional :data:`numpyro.infer.mcmc.HMCState` fields (besides
        `z` and `diverging`) to record during the run.
    :type extra_fields: tuple or list
    :param init_params: Starting point for sampling; its type must match the input
        type of `potential_fn`. With more than one chain, the first leaf of
        `init_params` must have a leading dimension equal to `num_chains`.
    :param kwargs: Keyword arguments forwarded to
        :meth:`numpyro.infer.mcmc.MCMCKernel.init`, typically the keyword inputs of
        the `model`.
    :raises ValueError: if `init_params` is given for several chains and its leading
        dimension differs from `num_chains`.

    When `chain_method` is ``'parallel'`` but fewer XLA devices than chains are
    available, the chains are drawn sequentially and a warning is emitted.

    .. note:: jax allows python code to continue even when the compiled code has not
        finished yet. This can cause troubles when trying to profile the code for speed.
        See https://jax.readthedocs.io/en/latest/async_dispatch.html and
        https://jax.readthedocs.io/en/latest/profiling.html for pointers on profiling
        jax programs.
    """
    self._args = args
    self._kwargs = kwargs
    init_state = self._get_cached_init_state(rng_key, args, kwargs)
    n_chains = self.num_chains

    # Fan a single (unbatched) key out into one key per chain.
    if n_chains > 1 and rng_key.ndim == 1:
        rng_key = random.split(rng_key, n_chains)

    # A stored warmup state takes precedence over the cached init state.
    if self._warmup_state is not None:
        self._set_collection_params(0, self.num_samples, self.num_samples)
        init_state = self._warmup_state._replace(rng_key=rng_key)

    chain_method = self.chain_method
    if chain_method == 'parallel':
        if xla_bridge.device_count() < n_chains:
            chain_method = 'sequential'
            available = xla_bridge.device_count()
            warnings.warn(
                'There are not enough devices to run parallel chains: expected {} but got {}.'
                ' Chains will be drawn sequentially. If you are running MCMC in CPU,'
                ' consider to use `numpyro.set_host_device_count({})` at the beginning'
                ' of your program.'.format(n_chains, available, n_chains))

    if n_chains > 1 and init_params is not None:
        leading_leaf = tree_flatten(init_params)[0][0]
        if jnp.shape(leading_leaf)[0] != n_chains:
            raise ValueError('`init_params` must have the same leading dimension'
                             ' as `num_chains`.')

    assert isinstance(extra_fields, (tuple, list))
    requested = (self._sample_field, ) + tuple(self._default_fields) + tuple(extra_fields)
    collect_fields = tuple(set(requested))
    chain_fn = partial(self._single_chain_mcmc,
                       args=args,
                       kwargs=kwargs,
                       collect_fields=collect_fields)
    map_args = (rng_key, init_state, init_params)

    if n_chains == 1:
        states_flat, last_state = chain_fn(map_args)
        states = tree_map(lambda leaf: leaf[jnp.newaxis, ...], states_flat)
    else:
        if chain_method == 'sequential':
            mapper = _laxmap if self.progress_bar else lax.map
            states, last_state = mapper(chain_fn, map_args)
        elif chain_method == 'parallel':
            states, last_state = pmap(chain_fn)(map_args)
            # TODO: remove when https://github.com/google/jax/issues/3597 is resolved
            states = device_put(states)
        else:
            assert chain_method == 'vectorized'
            states, last_state = chain_fn(map_args)
            # swap num_samples x num_chains to num_chains x num_samples
            states = tree_map(lambda leaf: jnp.swapaxes(leaf, 0, 1), states)
        states_flat = tree_map(
            lambda leaf: jnp.reshape(leaf, (-1, ) + leaf.shape[2:]), states)

    self._last_state = last_state
    self._states = states
    self._states_flat = states_flat
    self._set_collection_params()
def testMakeJaxprOfOpenSpmd(self):
    """Tracing a function whose `psum` axis name 'i' is not bound must not raise."""
    def minus_total(v):
        return v - lax.psum(v, 'i')

    num_rows = xla_bridge.device_count()
    shape = (num_rows, 4)
    values = onp.arange(prod(shape), dtype=onp.float32)
    values = values.reshape(shape)
    # The test passes as long as tracing does not raise.
    make_jaxpr(minus_total)(values)
def jit_or_pmap_broadcast(f, device_count=-1):
  """Pmap `f` over the first argument by closing over or broadcasting others.

  Args:
    f: function to pmap. First argument must be a `np.ndarray` with leading axis
      having the size of `device_count`.
    device_count: number of XLA devices. `-1` means all available devices. `0`
      means to just `jit` the function.

  Returns:
    A function of the same signature as `f` pmapped over the first argument with
    other arguments either closed over (non-`np.ndarray`s in `args` and all
    `kwargs`) or broadcasted to `(device_count,) + old_shape` (for
    `np.ndarray`s). If `device_count == 0`, `f` is closed over and jitted over
    all non-array arguments and all `kwargs`.

    Compiled functions are memoized in the module-level `cache`. The key covers
    `f`, `device_count`, the positions of the array arguments, the non-array
    arguments and all `kwargs`.

  Raises:
    An error if `kwargs` have a `np.ndarray`: `kwargs` are part of the cache
    key, so array values make the key unhashable.

    TODO(romann): treat `np.ndarray`s in `kwargs` when JAX allows it.
    See https://github.com/google/jax/issues/912
  """
  key = (f, device_count)

  if device_count == -1:
    device_count = xla_bridge.device_count()

  # TODO(romann): adapt this when JAX allows `axis_in` for `pmap`.
  def broadcast(arg):
    if device_count == 0:
      return arg
    return np.broadcast_to(arg, (device_count, ) + arg.shape)

  def f_pmapped(x, *args, **kwargs):
    args_np, args_np_idxs = [], []
    args_other = {}

    # TODO(romann): treat `np.ndarray`s in `kwargs` when JAX allows it.
    # https://github.com/google/jax/issues/912
    # Filter out `np.ndarray`s from other arguments.
    for i, arg in enumerate(args):
      if _is_np_ndarray(arg):
        args_np.append(arg)
        args_np_idxs.append(i)
      else:
        args_other[i] = arg

    # Check cache before jitting. The cached closure `_f` below hard-codes
    # `args_np_idxs`, so the positions of the array arguments must be part of
    # the key. Otherwise e.g. `f(x, a)` and `f(x, a, b)` would share a compiled
    # function and `b` would be silently dropped by `zip`.
    _key = (key +
            (tuple(args_np_idxs),) +
            tuple(args_other.items()) +
            tuple(kwargs.items()))
    if _key in cache:
      _f = cache[_key]
    else:
      # Define a `np.ndarray`-only function as a closure over other arguments.
      def _f(_x, *_args_np):
        # Merge args; array and non-array positions are disjoint by
        # construction.
        _args_np = {i: _arg_np for i, _arg_np in zip(args_np_idxs, _args_np)}
        _args = {**_args_np, **args_other}
        _args = tuple(v for k, v in sorted(_args.items()))
        return f(_x, *_args, **kwargs)

      _f = jit(_f) if device_count == 0 else pmap(_f)
      cache[_key] = _f

    # Broadcast `np.ndarray` arguments and apply the new function to them.
    args_np = tree_map(broadcast, args_np)
    return _f(x, *args_np)

  return f_pmapped
def testRule30(self):
    # This is a test of collective_permute implementing a simple halo exchange
    # to run a rule 30 simulation: https://en.wikipedia.org/wiki/Rule_30
    # Halo exchange should be useful in spatially-sharded convolutions and in
    # other simulations.
    device_count = xla_bridge.device_count()

    def send_right(x, axis_name):
        # Cyclic shift to the right: device i sends `x` to device
        # (i + 1) % device_count.
        # NOTE(review): the local is named `left_perm` although this permutation
        # moves data to the right; the name looks like a copy-paste leftover.
        left_perm = [(i, (i + 1) % device_count) for i in range(device_count)]
        return lax.ppermute(x, perm=left_perm, axis_name=axis_name)

    def send_left(x, axis_name):
        # Cyclic shift to the left: device (i + 1) % device_count sends `x` to
        # device i.
        left_perm = [((i + 1) % device_count, i) for i in range(device_count)]
        return lax.ppermute(x, perm=left_perm, axis_name=axis_name)

    def update_board(board):
        # One rule 30 step on the interior cells of a halo-padded board:
        # new = left XOR (center OR right). The output is two cells shorter
        # than the input, i.e. the halos are consumed.
        left = board[:-2]
        right = board[2:]
        center = board[1:-1]
        return lax.bitwise_xor(left, lax.bitwise_or(center, right))

    @partial(pmap, axis_name='i')
    def step(board_slice):
        # Each device owns one contiguous slice of the board. Its first cell is
        # sent to the left neighbour, where it becomes that device's right halo.
        # Its last cell is sent to the right neighbour, where it becomes the
        # left halo. The exchange wraps around, so the board is periodic.
        left, right = board_slice[:1], board_slice[-1:]
        right, left = send_left(left, 'i'), send_right(right, 'i')
        enlarged_board_slice = np.concatenate([left, board_slice, right])
        return update_board(enlarged_board_slice)

    # 40 cells with a single live cell in the middle. The reshape assumes
    # device_count divides 40.
    board = onp.zeros(40, dtype=bool)
    board[board.shape[0] // 2] = True
    reshaped_board = board.reshape((device_count, -1))

    boards = []
    def print_board(board):
        # Render one row: one character per cell ('*' alive, ' ' dead).
        boards.append(''.join('*' if x else ' ' for x in board.ravel()))

    print_board(reshaped_board)
    for _ in range(20):
        reshaped_board = step(reshaped_board)
        print_board(reshaped_board)

    ans = '\n'.join(boards)
    # NOTE(review): every rendered row is 40 characters wide (`board` has 40
    # cells and `print_board` emits one character per cell), but most expected
    # rows below are much shorter. They look whitespace-collapsed; TODO confirm
    # against the upstream file before relying on this assertion.
    expected = '\n'.join((
        ' * ',
        ' *** ',
        ' ** * ',
        ' ** **** ',
        ' ** * * ',
        ' ** **** *** ',
        ' ** * * * ',
        ' ** **** ****** ',
        ' ** * *** * ',
        ' ** **** ** * *** ',
        ' ** * * **** ** * ',
        ' ** **** ** * * **** ',
        ' ** * *** ** ** * * ',
        ' ** **** ** *** *** ** *** ',
        ' ** * * *** * *** * * ',
        ' ** **** ** * * ***** ******* ',
        ' ** * *** **** * *** * ',
        ' ** **** ** *** ** ** * *** ',
        ' ** * * *** * ** *** **** ** * ',
        ' ** **** ** * ****** * * *** ****',
        ' * * *** **** **** *** ** * ',
    ))

    print(ans)
    self.assertEqual(ans, expected)
def mcmc(num_warmup, num_samples, init_params, num_chains=1, sampler='hmc',
         constrain_fn=None, print_summary=True, **sampler_kwargs):
    """
    Convenience wrapper for MCMC samplers -- runs warmup, prints diagnostic summary and
    returns a collection of samples from the posterior.

    :param num_warmup: Number of warmup steps.
    :param num_samples: Number of samples to generate from the Markov chain.
    :param init_params: Initial parameters to begin sampling. The type must be consistent
        with the input type to `potential_fn`. When ``num_chains > 1``, every leaf must
        have a leading axis of size ``num_chains`` (one entry per chain).
    :param num_chains: Number of Markov chains to run. Chains are run in parallel with
        :func:`jax.pmap` when there are at least ``num_chains`` XLA devices, and
        sequentially (with a warning) otherwise. The progress bar is only shown for a
        single chain.
    :param sampler: currently, only `hmc` is implemented (default).
    :param constrain_fn: Callable that converts a collection of unconstrained
        sample values returned from the sampler to constrained values that lie
        within the support of the sample sites. Defaults to the identity.
    :param print_summary: Whether to print diagnostics summary for each sample site.
        Default is ``True``.
    :param `**sampler_kwargs`: Sampler specific keyword arguments.

        - *HMC*: Refer to :func:`~numpyro.mcmc.hmc` and
          :func:`~numpyro.mcmc.hmc.init_kernel` for accepted arguments. Note
          that all arguments must be provided as keywords. ``potential_fn`` is
          required; ``kinetic_fn``, ``algo`` (default ``'NUTS'``), ``rng`` and
          ``progbar`` (default ``True``) are optional. With several chains,
          ``rng`` must be a batch of ``num_chains`` keys.

    :return: collection of samples from the posterior, with the chain and sample axes
        flattened into a single leading axis.
    :raises ValueError: if `sampler` is not recognized.

    .. testsetup::

       import jax
       from jax import random
       import jax.numpy as np
       import numpyro.distributions as dist
       from numpyro.handlers import sample
       from numpyro.hmc_util import initialize_model
       from numpyro.mcmc import hmc, mcmc
       from numpyro.util import fori_collect

    .. doctest::

       >>> true_coefs = np.array([1., 2., 3.])
       >>> data = random.normal(random.PRNGKey(2), (2000, 3))
       >>> dim = 3
       >>> labels = dist.Bernoulli(logits=(true_coefs * data).sum(-1)).sample(random.PRNGKey(3))
       >>>
       >>> def model(data, labels):
       ...     coefs_mean = np.zeros(dim)
       ...     coefs = sample('coefs', dist.Normal(coefs_mean, np.ones(3)))
       ...     intercept = sample('intercept', dist.Normal(0., 10.))
       ...     return sample('y', dist.Bernoulli(logits=(coefs * data + intercept).sum(-1)), obs=labels)
       >>>
       >>> init_params, potential_fn, constrain_fn = initialize_model(random.PRNGKey(0), model,
       ...                                                            data, labels)
       >>> num_warmup, num_samples = 1000, 1000
       >>> samples = mcmc(num_warmup, num_samples, init_params,
       ...                potential_fn=potential_fn,
       ...                constrain_fn=constrain_fn)  # doctest: +SKIP
       warmup: 100%|██████████| 1000/1000 [00:09<00:00, 109.40it/s, 1 steps of size 5.83e-01. acc. prob=0.79]
       sample: 100%|██████████| 1000/1000 [00:00<00:00, 1252.39it/s, 1 steps of size 5.83e-01. acc. prob=0.85]

                          mean         sd       5.5%      94.5%      n_eff       Rhat
              coefs[0]    0.96       0.07       0.85       1.07     455.35       1.01
              coefs[1]    2.05       0.09       1.91       2.20     332.00       1.01
              coefs[2]    3.18       0.13       2.96       3.37     320.27       1.00
             intercept   -0.03       0.02      -0.06       0.00     402.53       1.00
    """
    # Fail fast on an unsupported sampler, before warning or consuming kwargs.
    if sampler != 'hmc':
        raise ValueError('sampler: {} not recognized'.format(sampler))

    sequential_chain = False
    if xla_bridge.device_count() < num_chains:
        sequential_chain = True
        warnings.warn('There are not enough devices to run parallel chains: expected {} but got {}.'
                      ' Chains will be drawn sequentially. If you are running `mcmc` in CPU,'
                      ' consider to disable XLA intra-op parallelism by setting the environment'
                      ' flag "XLA_FLAGS=--xla_force_host_platform_device_count={}".'
                      .format(num_chains, xla_bridge.device_count(), num_chains))
    progbar = sampler_kwargs.pop('progbar', True)
    if num_chains > 1:
        # The progress bar is only shown when a single chain is run.
        progbar = False

    if constrain_fn is None:
        constrain_fn = identity
    potential_fn = sampler_kwargs.pop('potential_fn')
    kinetic_fn = sampler_kwargs.pop('kinetic_fn', None)
    algo = sampler_kwargs.pop('algo', 'NUTS')
    if num_chains > 1:
        rngs = sampler_kwargs.pop('rng', vmap(PRNGKey)(np.arange(num_chains)))
    else:
        rng = sampler_kwargs.pop('rng', PRNGKey(0))

    init_kernel, sample_kernel = hmc(potential_fn, kinetic_fn, algo)
    if progbar:
        # Single chain with progress bars: warmup runs inside `init_kernel`.
        hmc_state = init_kernel(init_params, num_warmup, progbar=progbar, rng=rng,
                                **sampler_kwargs)
        samples_flat = fori_collect(0, num_samples, sample_kernel, hmc_state,
                                    transform=lambda x: constrain_fn(x.z),
                                    progbar=progbar,
                                    diagnostics_fn=get_diagnostics_str,
                                    progbar_desc='sample')
        samples = tree_map(lambda x: x[np.newaxis, ...], samples_flat)
    else:
        def single_chain_mcmc(rng, init_params):
            # Warmup is folded into `fori_collect`: the first `num_warmup` steps run
            # but are not collected.
            hmc_state = init_kernel(init_params, num_warmup, run_warmup=False, rng=rng,
                                    **sampler_kwargs)
            samples = fori_collect(num_warmup, num_warmup + num_samples, sample_kernel,
                                   hmc_state,
                                   transform=lambda x: constrain_fn(x.z),
                                   progbar=progbar)
            return samples

        if num_chains == 1:
            samples_flat = single_chain_mcmc(rng, init_params)
            samples = tree_map(lambda x: x[np.newaxis, ...], samples_flat)
        else:
            if sequential_chain:
                # Build the jitted chain function once, outside the loop.
                jitted_chain = jit(single_chain_mcmc)
                per_chain = []
                for i in range(num_chains):
                    init_params_i = tree_map(lambda x: x[i], init_params)
                    per_chain.append(jitted_chain(rngs[i], init_params_i))
                samples = tree_multimap(lambda *args: np.stack(args), *per_chain)
            else:
                samples = pmap(single_chain_mcmc)(rngs, init_params)
            samples_flat = tree_map(lambda x: np.reshape(x, (-1,) + x.shape[2:]), samples)

    if print_summary:
        summary(samples)
    return samples_flat
def test_jit_device_assignment(self):
    """A jitted identity pinned to the last device returns its result on that device."""
    last_device = xb.device_count() - 1
    identity = api.jit(lambda v: v, device_assignment=last_device)
    out = identity(3.)
    self.assertIsInstance(out, DeviceArray)
    self.assertEqual(out.device_buffer.device(), last_device)
if __name__ == "__main__":
  # Training hyperparameters. `layer_sizes` presumably lists the widths of a
  # fully-connected network (784 = a flattened 28x28 MNIST image); its use is
  # not visible in this fragment.
  layer_sizes = [784, 1024, 1024, 10]
  param_scale = 0.1
  step_size = 0.001
  num_epochs = 10
  batch_size = 128

  train_images, train_labels, test_images, test_labels = datasets.mnist()
  num_train = train_images.shape[0]
  # The last batch is partial when `batch_size` does not divide `num_train`.
  num_complete_batches, leftover = divmod(num_train, batch_size)
  num_batches = num_complete_batches + bool(leftover)

  # For this manual SPMD example, we get the number of devices (e.g. GPUs or
  # TPU cores) that we're using, and use it to reshape data minibatches.
  num_devices = xla_bridge.device_count()
  def data_stream():
    # Endless stream of shuffled minibatches with a fixed seed.
    rng = npr.RandomState(0)
    while True:
      perm = rng.permutation(num_train)
      for i in range(num_batches):
        batch_idx = perm[i * batch_size:(i + 1) * batch_size]
        images, labels = train_images[batch_idx], train_labels[batch_idx]
        # For this SPMD example, we reshape the data batch dimension into two
        # batch dimensions, one of which is mapped over parallel devices.
        batch_size_per_device, ragged = divmod(images.shape[0], num_devices)
        if ragged:
          # NOTE(review): the message reports the configured `batch_size`, not
          # the actual length `images.shape[0]`, which differs for the final
          # partial batch.
          msg = "batch size must be divisible by device count, got {} and {}."
          raise ValueError(msg.format(batch_size, num_devices))
        shape_prefix = (num_devices, batch_size_per_device)
        images = images.reshape(shape_prefix + images.shape[1:])
        # NOTE(review): this fragment ends here. The matching reshape of
        # `labels` and the `yield` are not visible, so as shown `data_stream`
        # never yields. TODO confirm against the full file.
def _parallel(
    kernel_fn: KernelFn,
    use_serial: bool = True,
    dropout_in_analytic_kernel: bool = False,
    device_count: int = -1,
) -> KernelFn:
  """Returns a function that computes a kernel in batches in parallel.

  When batching in parallel, the data is split over a set number of devices.
  The number of devices must be less than or equal to the number of physical
  devices. Moreover, the device count needs to divide the dataset size. If the
  dataset has fewer rows than `device_count`, each row is placed on its own
  device instead.

  Given two datasets `x1` and `x2`, parallel splits the kernel calculation over
  devices such that each device computes a batch of rows of shape
  `[|x1| / device_count, |x2|]`.

  Args:
    kernel_fn: A function that computes a kernel between two datasets,
      `kernel_fn(x1, x2)` or the compositional kernel for an input kernel
      `kernel_fn(kernel_in)`. Here `x1` and `x2` are `np.ndarray`s of floats of
      shape `(n1,) + input_shape` and `(n2,) + input_shape`; `kernel_in` is a
      Kernel object. The kernel function should return a `PyTree`.
    use_serial: Whether `serial` will be called after `_parallel`. The only use
      case is to make sure when `dropout` is used in the analytic/empirical
      kernel, the batch size in each device is square.
    dropout_in_analytic_kernel: whether `dropout` is used in the analytic
      kernel. See `use_serial` above for the only use case.
    device_count: Integer specifying the number of devices over which to split
      the data. If `device_count == -1`, the computation is parallelized over
      all available devices.

  Returns:
    A new function with the same signature as kernel_fn that computes the kernel
    by batching over the dataset in parallel over a specified number of cores.

  Raises (from the returned function):
    ValueError: if the dataset size is at least `device_count` but not
      divisible by it.
    NotImplementedError: if dropout is used (an `rng` keyword argument or
      `dropout_in_analytic_kernel`), `n1 == n2` and `use_serial` is False.
  """
  if device_count == -1:
    device_count = xla_bridge.device_count()

  def _check_dropout(n1, n2, kwargs):
    # `kwargs` is a dict of keyword arguments, so the dropout RNG must be
    # looked up as a key. `getattr(kwargs, 'rng', None)` would always be None.
    dropout_in_empirical_kernel = kwargs.get('rng') is not None
    if (n1 == n2 and
        (dropout_in_empirical_kernel or dropout_in_analytic_kernel) and
        not use_serial):
      raise NotImplementedError(
          'Batching for empirical / analytic kernels with dropout'
          ' is not implemented for non-square batch size. '
          'Using `serial` (i.e. use a non-zero batch_size in the '
          '`batch` function.) could enforce square batch size in each device.'
      )

  def _get_n_per_device(n1, n2):
    """Returns `(n1_per_device, _device_count)`; `n2` is currently unused."""
    _device_count = device_count

    n1_per_device, ragged = divmod(n1, device_count)
    if n1_per_device and ragged:
      raise ValueError(
          ('Dataset size ({}) must divide number of '
           'physical devices ({}).').format(n1, device_count))
    elif not n1_per_device:
      # Fewer rows than devices: put a single row on each of `n1` devices.
      _device_count = ragged
      n1_per_device = 1

    return n1_per_device, _device_count

  def parallel_fn_x1(x1, x2=None, *args, **kwargs):
    x2_is_none = utils.all_none(x2)
    if x2_is_none:
      # TODO(schsam): Only compute the upper triangular part of the kernel.
      x2 = x1

    def get_batch_size(x):
      # For lists/tuples of inputs, the batch size of the first entry is used.
      if utils.is_list_or_tuple(x):
        return get_batch_size(x[0])
      return x.shape[0]

    n1 = get_batch_size(x1)
    n2 = n1 if x2_is_none else get_batch_size(x2)

    _check_dropout(n1, n2, kwargs)
    n1_per_device, _device_count = _get_n_per_device(n1, n2)

    # Pmap over exactly as many devices as `x1` is split into.
    _kernel_fn = _jit_or_pmap_broadcast(kernel_fn, _device_count)

    @utils.nt_tree_fn()
    def batch_data(x):
      input_shape = x.shape[1:]
      return np.reshape(x, (
          _device_count,
          n1_per_device,
      ) + input_shape)

    for k, v in kwargs.items():
      if _is_np_ndarray(v):
        assert isinstance(v, tuple) and len(v) == 2
        # Only the first element of the pair is split like `x1`.
        v0 = np.reshape(v[0], (
            _device_count,
            n1_per_device,
        ) + v[0].shape[1:])
        kwargs[k] = (v0, v[1])

    x1 = batch_data(x1)

    kernel = _kernel_fn(x1, x2, *args, **kwargs)
    return _flatten_kernel(kernel, x2_is_none, True)

  def parallel_fn_kernel(kernel, *args, **kwargs):
    @utils.nt_tree_fn(reduce=lambda shapes: shapes[0])
    def get_batch_sizes(k):
      n1 = n2 = k.cov1.shape[0]
      if k.cov2 is not None:
        n2 = k.cov2.shape[0]
      return n1, n2

    n1, n2 = get_batch_sizes(kernel)
    _check_dropout(n1, n2, kwargs)
    n1_per_device, _device_count = _get_n_per_device(n1, n2)

    _kernel_fn = _jit_or_pmap_broadcast(kernel_fn, _device_count)

    cov2_is_none = utils.nt_tree_fn(
        reduce=lambda k: all(k))(lambda k: k.cov2 is None)(kernel)
    kernel = _reshape_kernel_for_pmap(kernel, _device_count, n1_per_device)
    kernel = _kernel_fn(kernel, *args, **kwargs)
    if cov2_is_none:
      kernel = _set_cov2_is_none(kernel)
    return _flatten_kernel(kernel, cov2_is_none, True)

  @utils.wraps(kernel_fn)
  def parallel_fn(x1_or_kernel, x2=None, *args, **kwargs):
    if utils.is_nt_tree_of(x1_or_kernel, np.ndarray):
      return parallel_fn_x1(x1_or_kernel, x2, *args, **kwargs)
    elif utils.is_nt_tree_of(x1_or_kernel, Kernel):
      # Same check as `parallel_fn_x1`; `not x2` is ambiguous for arrays.
      assert utils.all_none(x2)
      return parallel_fn_kernel(x1_or_kernel, *args, **kwargs)
    raise NotImplementedError()

  # Set function attributes so that `serial` can detect whether or not it is
  # acting on a parallel function.
  parallel_fn.device_count = device_count
  return parallel_fn
def wrapper(*args, schedule, **kwargs):
    """Skip the test if a 'parallel' loop in `schedule` needs more devices than exist.

    A loop size of None in `schedule` stands for `axis_size` (from the enclosing scope).
    """
    needs_more_devices = any(
        loop == 'parallel'
        and (axis_size if n is None else n) > xla_bridge.device_count()
        for loop, n in schedule)
    if needs_more_devices:
        raise SkipTest("this test requires more XLA devices")
    return fun(*args, schedule=schedule, **kwargs)
from jax.lib import xla_bridge


def main():
    """Print the XLA backend type and, on a GPU backend, every visible device."""
    # Fetch the backend once instead of calling `get_backend()` twice.
    xla_backend = xla_bridge.get_backend()
    xla_backend_type = xla_backend.platform  # cpu, gpu, tpu
    print(f"XLA backend type: {xla_backend_type}")

    is_gpu = xla_backend_type == "gpu"
    gpu_count = xla_bridge.device_count() if is_gpu else 0
    print(f"\nNumber of GPUs found on system: {gpu_count}")

    if is_gpu:
        for idx, device in enumerate(xla_backend.devices()):
            # The first enumerated device is labelled as the active one.
            gpu_type = "Active GPU" if idx == 0 else "GPU"
            print(f"\n{gpu_type} index: {device.id}")
            print(f"{gpu_type} name: {device.device_kind}")


if __name__ == "__main__":
    main()
def load_data_from_jax_tfds_or_ml_prepare(dataset_name, tfds_dir=None, K=None, as_numpy=False,
                                          batch_size=128, **data_loader_kwargs):
    """
    Acquire from the official TFDS model zoo through JAX wrapper, or the ophthalmology focussed ml-prepare library

    :param dataset_name: name of dataset
    :type dataset_name: ```str```

    :param tfds_dir: directory to look for models in. Default is ~/tensorflow_datasets.
    :type tfds_dir: ```None or str```

    :param K: backend engine, e.g., `np` or `tf`. Not used by this function.
    :type K: ```None or np or tf or Any```

    :param as_numpy: Convert to numpy ndarrays. Not used by this function.
    :type as_numpy: ```bool```

    :param batch_size: Number of examples per (global) minibatch in the JAX data stream.
    :type batch_size: ```int```

    :param data_loader_kwargs: pass this as arguments to data_loader function. `dataset_name`,
      `tfds_dir` and (unless given) `scale=255` are added before forwarding to
      `load_data_from_ml_prepare`.
    :type data_loader_kwargs: ```**data_loader_kwargs```

    :return: For ml-prepare datasets (`dataset_name in datasets2classes`), whatever
      `load_data_from_ml_prepare` returns. Otherwise a 7-tuple
      ``(data_stream, num_batches, num_devices, train_images, train_labels, test_images, test_labels)``.
      Here ``data_stream`` is an endless generator of ``(images, labels)`` minibatches, each
      reshaped to ``(num_devices, len(batch) // num_devices, ...)``.
    :rtype: ```Any or Tuple[Generator, int, int, np.ndarray, np.ndarray, np.ndarray, np.ndarray]```

    :raises ValueError: (from ``data_stream``) if a minibatch length is not divisible by the
      number of devices.
    """
    data_loader_kwargs.update({
        'dataset_name': dataset_name,
        'tfds_dir': tfds_dir,
    })
    if 'scale' not in data_loader_kwargs:
        data_loader_kwargs['scale'] = 255

    if dataset_name in datasets2classes:
        # `dataset_name` and `tfds_dir` are already in `data_loader_kwargs`; passing
        # them explicitly as well would raise "got multiple values for keyword argument".
        return load_data_from_ml_prepare(**data_loader_kwargs)

    ml_params_jax.stolen.datasets._DATA = tfds_dir
    train_images, train_labels, test_images, test_labels = getattr(
        ml_params_jax.stolen.datasets, dataset_name)()
    num_train = train_images.shape[0]
    # The last minibatch is partial when `batch_size` does not divide `num_train`.
    num_complete_batches, leftover = divmod(num_train, batch_size)
    num_batches = num_complete_batches + bool(leftover)

    # For this manual SPMD example, we get the number of devices (e.g. GPUs or
    # TPU cores) that we're using, and use it to reshape data minibatches.
    num_devices = xla_bridge.device_count()

    def data_stream():
        rng = npr.RandomState(0)
        while True:
            perm = rng.permutation(num_train)
            for i in range(num_batches):
                batch_idx = perm[i * batch_size:(i + 1) * batch_size]
                images, labels = train_images[batch_idx], train_labels[batch_idx]
                # For this SPMD example, we reshape the data batch dimension into two
                # batch dimensions, one of which is mapped over parallel devices.
                actual_batch_size = images.shape[0]
                batch_size_per_device, ragged = divmod(actual_batch_size, num_devices)
                if ragged:
                    msg = "batch size must be divisible by device count, got {} and {}."
                    # Report the real batch length (smaller for the final partial batch).
                    raise ValueError(msg.format(actual_batch_size, num_devices))
                shape_prefix = (num_devices, batch_size_per_device)
                images = images.reshape(shape_prefix + images.shape[1:])
                labels = labels.reshape(shape_prefix + labels.shape[1:])
                yield images, labels

    return data_stream(), num_batches, num_devices, train_images, train_labels, test_images, test_labels
# NOTE(review): the `pytest.mark.parametrize` decorators supplying `deterministic`
# and `find_heuristic_step_size` are not visible in this chunk; presumably they
# precede this function — confirm.
def test_mcmc_one_chain(deterministic, find_heuristic_step_size):
    # `GLOBAL["count"]` is reset here and checked after sampling. Presumably
    # `model` (defined elsewhere) increments it each time it is traced —
    # confirm. The expected totals below are pinned by this test.
    GLOBAL["count"] = 0
    mcmc = MCMC(NUTS(model, find_heuristic_step_size=find_heuristic_step_size), 100, 100)
    mcmc.run(random.PRNGKey(0), deterministic=deterministic)
    mcmc.get_samples()
    # With the step-size heuristic enabled, two additional counts are expected.
    num_traces_for_heuristic = 2 if find_heuristic_step_size else 0
    if deterministic:
        assert GLOBAL["count"] == 4 + num_traces_for_heuristic
    else:
        assert GLOBAL["count"] == 3 + num_traces_for_heuristic


# Same count check as above, but running two chains; skipped on single-device
# hosts.
@pytest.mark.parametrize('deterministic', [True, False])
@pytest.mark.skipif(xla_bridge.device_count() < 2, reason="only one device is available")
def test_mcmc_parallel_chain(deterministic):
    GLOBAL["count"] = 0
    mcmc = MCMC(NUTS(model), 100, 100, num_chains=2)
    mcmc.run(random.PRNGKey(0), deterministic=deterministic)
    mcmc.get_samples()
    if deterministic:
        assert GLOBAL["count"] == 4
    else:
        assert GLOBAL["count"] == 3


# NOTE(review): the body of `test_autoguide` is cut off in this chunk; the
# definition below is incomplete as shown. TODO confirm against the full file.
@pytest.mark.parametrize('deterministic', [True, False])
def test_autoguide(deterministic):
def _parallel(kernel_fn, device_count=-1):
  """Returns a function that computes a kernel in batches in parallel.

  When batching in parallel, the data is split over a set number of devices.
  The number of devices must be less than or equal to the number of physical
  devices. Moreover, the device count needs to divide the dataset size. If the
  dataset has fewer rows than `device_count`, each row is placed on its own
  device instead.

  Given two datasets x1 and x2, parallel splits the kernel calculation over
  devices such that each device computes a batch of rows of shape
  [|x1| / device_count, |x2|].

  Args:
    kernel_fn: A function that computes a kernel between two datasets,
      `kernel_fn(x1, x2)` or the compositional kernel for an input kernel
      `kernel_fn(kernel_in)`. Here x1 and x2 are `np.ndarray`s of floats of
      shape [n1] + input_shape and [n2] + input_shape; `kernel_in` is a Kernel
      object. The kernel function should return a PyTree.
    device_count: Integer specifying the number of devices over which to split
      the data. If device_count = -1, the computation is parallelized over all
      available devices.

  Returns:
    A new function with the same signature as kernel_fn that computes the kernel
    by batching over the dataset in parallel over a specified number of cores.

  Raises (from the returned function):
    ValueError: if the dataset size is at least `device_count` but not
      divisible by it.
    NotImplementedError: if a `key` keyword argument (dropout in the empirical
      kernel) is passed together with array inputs.
  """
  if device_count == -1:
    device_count = xla_bridge.device_count()

  def _get_n_per_device(n1):
    """Returns `(n1_per_device, _device_count)` for a dataset of `n1` rows."""
    _device_count = device_count
    n1_per_device, ragged = divmod(n1, device_count)
    if n1_per_device and ragged:
      raise ValueError(
          ('Dataset size ({}) must divide number of '
           'physical devices ({}).').format(n1, device_count))
    elif not n1_per_device:
      # Fewer rows than devices: put a single row on each of `n1` devices.
      _device_count = ragged
      n1_per_device = 1
    return n1_per_device, _device_count

  def parallel_fn_x1(x1, x2=None, *args, **kwargs):
    if 'key' in kwargs:
      raise NotImplementedError('Batching for the empirical kernel with dropout'
                                ' is not implemented. ')
    x2_is_none = x2 is None
    if x2_is_none:
      # TODO(schsam): Only compute the upper triangular part of the kernel.
      x2 = x1

    n1 = x1.shape[0]

    assert x1.shape[1:] == x2.shape[1:]
    input_shape = x1.shape[1:]

    n1_per_device, _device_count = _get_n_per_device(n1)
    # Pmap over exactly as many devices as `x1` is split into, so broadcasted
    # array arguments agree with the sharded `x1` on the leading axis.
    _kernel_fn = _jit_or_pmap_broadcast(kernel_fn, _device_count)

    x1 = np.reshape(x1, (_device_count, n1_per_device,) + input_shape)
    kernel = _kernel_fn(x1, x2, *args, **kwargs)
    return _flatten_kernel(kernel, x2_is_none, True)

  def parallel_fn_kernel(kernel, *args, **kwargs):
    n1 = kernel.var1.shape[0]

    n1_per_device, _device_count = _get_n_per_device(n1)
    _kernel_fn = _jit_or_pmap_broadcast(kernel_fn, _device_count)

    kernel_dict = kernel._asdict()

    # `var2` and `x1_is_x2` are replicated on every device; `var1`, `nngp` and
    # `ntk` are split along their first axis like `x1`.
    var2 = kernel_dict['var2']
    var2_is_none = var2 is None
    if var2 is None:
      var2 = kernel_dict['var1']
    kernel_dict['var2'] = np.broadcast_to(var2, (_device_count,) + var2.shape)
    kernel_dict['x1_is_x2'] = np.broadcast_to(
        kernel_dict['x1_is_x2'],
        (_device_count,) + kernel_dict['x1_is_x2'].shape)

    for k, v in kernel_dict.items():
      if k in ('nngp', 'ntk', 'var1'):
        kernel_dict[k] = \
            np.reshape(v, (_device_count, n1_per_device,) + v.shape[1:])
      if k in ('shape1',):
        kernel_dict[k] = (n1_per_device,) + v[1:]

    kernel = _kernel_fn(Kernel(**kernel_dict), *args, **kwargs)
    if var2_is_none:
      kernel = kernel._replace(var2=None)
    return _flatten_kernel(kernel, var2_is_none, True)

  @utils.wraps(kernel_fn)
  def parallel_fn(x1_or_kernel, x2=None, *args, **kwargs):
    if isinstance(x1_or_kernel, np.ndarray):
      return parallel_fn_x1(x1_or_kernel, x2, *args, **kwargs)
    elif isinstance(x1_or_kernel, Kernel):
      # `x2` may be an array; `not x2` would be ambiguous for arrays.
      assert x2 is None
      return parallel_fn_kernel(x1_or_kernel, *args, **kwargs)
    raise NotImplementedError()

  # Set function attributes so that `serial` can detect whether or not it is
  # acting on a parallel function.
  parallel_fn.device_count = device_count
  return parallel_fn
def jit_or_pmap_broadcast(f: Callable, device_count: int = -1) -> Callable:
  """Pmap `f` over the first argument by closing over or broadcasting others.

  Args:
    f: function to pmap. First argument must be an `np.ndarray` or a Kernel.
      In either case, ndarrays should have a leading axis having the size of
      `device_count`.
    device_count: number of XLA devices. `-1` means all available devices. `0`
      means to just `jit` the function.

  Returns:
    A function of the same signature as `f` pmapped over the `np.ndarray`s in
    the first argument. Other positional arguments are either closed over
    (non-`np.ndarray`s) or broadcasted to `(device_count,) + old_shape` (for
    `np.ndarray`s). Array-valued `kwargs` must be pairs `(v0, v1)`. `v0` is
    passed to the mapped function unchanged, so presumably it must already
    carry the leading device axis. `v1` is broadcasted. Non-array `kwargs` are
    closed over. If `device_count == 0`, `f` is closed over and jitted over
    all non-array arguments.

    Compiled functions are memoized in the module-level `cache`. The key covers
    `f`, `device_count`, the positions of the array positional arguments and
    the non-array positional and keyword arguments. `tf.Tensor`s / NumPy arrays
    among those are keyed by shape, dtype and content.

  Raises:
    AssertionError: if an array-valued keyword argument is not a 2-tuple.

    TODO(romann): treat `np.ndarray`s in `kwargs` when JAX allows it.
    See https://github.com/google/jax/issues/912
  """
  key = (f, device_count)

  if device_count == -1:
    device_count = xla_bridge.device_count()

  # TODO(romann): adapt this when JAX allows `axis_in` for `pmap`.
  def broadcast(arg: np.ndarray) -> np.ndarray:
    if device_count == 0:
      return arg
    # If the argument has already been sharded, no need to broadcast it.
    if isinstance(arg, ShardedDeviceArray) and arg.shape[0] == device_count:
      return arg
    return np.broadcast_to(arg, (device_count, ) + arg.shape)

  def _hashable_leaf(leaf):
    # `tf.Tensor`s and NumPy arrays are unhashable. Key them by shape, dtype
    # and raw bytes, which works for any rank, including 0-D and 1-D arrays
    # that `tuple(map(tuple, arr))` cannot handle.
    if isinstance(leaf, tf.Tensor):
      leaf = leaf.numpy()
    if isinstance(leaf, onp.ndarray):
      return ('ndarray', leaf.shape, leaf.dtype.str, leaf.tobytes())
    return leaf

  @utils.wraps(f)
  def f_pmapped(x_or_kernel: Union[np.ndarray, Kernel], *args, **kwargs):
    args_np, args_np_idxs = [], []
    args_other = {}

    # TODO(romann): treat `np.ndarray`s in `kwargs` when JAX allows it.
    # https://github.com/google/jax/issues/912
    # Filter out `np.ndarray`s from other arguments.
    for i, arg in enumerate(args):
      if _is_np_ndarray(arg):
        args_np.append(arg)
        args_np_idxs.append(i)
      else:
        args_other[i] = arg

    kwargs_np = {}
    kwargs_other = {}
    for k, v in kwargs.items():
      if _is_np_ndarray(v):
        # Both conditions must be in the asserted expression; in
        # `assert a, b` the `b` is only the failure message.
        assert isinstance(v, tuple) and len(v) == 2
        kwargs_np[k] = (v[0], broadcast(v[1]))
      else:
        kwargs_other[k] = v

    # Check cache before jitting.
    _key = key + \
        tuple(args_other.items()) + \
        tuple(kwargs_other.items())

    # Flatten the key and make every leaf hashable (see `_hashable_leaf`).
    leaves, _ = tree_flatten(_key)
    # The cached closure `_f` below hard-codes `args_np_idxs`, so it is part of
    # the key. It is appended after flattening so it stays one positional
    # element and cannot blend into neighbouring leaves.
    _key = tuple(_hashable_leaf(leaf) for leaf in leaves) + \
        (tuple(args_np_idxs),)

    if _key in cache:
      _f = cache[_key]
    else:
      # Define a `np.ndarray`-only function as a closure over other arguments.
      def _f(_x_or_kernel, *_args_np, **_kwargs_np):
        # Merge args.
        _args_np = {
            i: _arg_np for i, _arg_np in zip(args_np_idxs, _args_np)
        }
        _args = {**_args_np, **args_other}
        _args = tuple(v for k, v in sorted(_args.items()))
        _kwargs = {**_kwargs_np, **kwargs_other}
        return f(_x_or_kernel, *_args, **_kwargs)

      _f = jit(_f) if device_count == 0 else pmap(_f)
      cache[_key] = _f

    # Broadcast `np.ndarray` arguments and apply the new function to them.
    args_np = tree_map(broadcast, args_np)
    return _f(x_or_kernel, *args_np, **kwargs_np)

  return f_pmapped
def jit_or_pmap_broadcast(f, device_count=-1):
  """Pmap `f` over the first argument by closing over or broadcasting others.

  Args:
    f: function to pmap. First argument must be an `np.ndarray` or a Kernel.
      In either case, ndarrays should have a leading axis having the size of
      `device_count`.
    device_count: number of XLA devices. `-1` means all available devices. `0`
      means to just `jit` the function.

  Returns:
    A function of the same signature as `f` pmapped over the ndarrays in the
    first argument. If the first argument is a Kernel, its `np.ndarray` fields
    are pmapped and its other fields are closed over. Other arguments are
    either closed over (non-`np.ndarray`s in `args` and all `kwargs`) or
    broadcasted to `(device_count,) + old_shape` (for `np.ndarray`s). If
    `device_count == 0`, `f` is closed over and jitted over all non-array
    arguments and all `kwargs`.

  Raises:
    An error if `kwargs` have a `np.ndarray`: all `kwargs` become part of a
    cache key, and arrays are not hashable.
    TODO(romann): treat `np.ndarray`s in `kwargs` when JAX allows it. See
    https://github.com/google/jax/issues/912
  """
  key = (f, device_count)

  if device_count == -1:
    device_count = xla_bridge.device_count()

  # TODO(romann): adapt this when JAX allows `axis_in` for `pmap`.
  def broadcast(arg):
    if device_count == 0:
      return arg
    # If the argument has already been sharded, no need to broadcast it.
    if isinstance(arg, ShardedDeviceArray) and arg.shape[0] == device_count:
      return arg
    return np.broadcast_to(arg, (device_count,) + arg.shape)

  @utils.wraps(f)
  def f_pmapped(x_or_kernel, *args, **kwargs):
    args_np, args_np_idxs = [], []
    args_other = {}

    # A Kernel is split into its `np.ndarray` fields (mapped over) and its
    # remaining fields (closed over and rebuilt inside `_f`).
    is_input_kernel = isinstance(x_or_kernel, Kernel)
    x_or_kernel_np = {}
    x_or_kernel_other = {}

    if is_input_kernel:
      kernel_dict = x_or_kernel._asdict()
      for k, v in kernel_dict.items():
        if isinstance(v, np.ndarray):
          x_or_kernel_np[k] = v
        else:
          x_or_kernel_other[k] = v
    else:
      x_or_kernel_np = x_or_kernel

    # TODO(romann): treat `np.ndarray`s in `kwargs` when JAX allows it.
    # https://github.com/google/jax/issues/912
    # Filter out `np.ndarray`s from other arguments.
    for i, arg in enumerate(args):
      if _is_np_ndarray(arg):
        args_np.append(arg)
        args_np_idxs.append(i)
      else:
        args_other[i] = arg

    # Check cache before jitting. `_f` below closes over `is_input_kernel` and
    # `args_np_idxs`, so both must be part of the key; otherwise a function
    # built for a different first-argument type or array-argument layout
    # could be reused and silently drop or misplace arrays.
    _key = (key +
            (is_input_kernel, tuple(args_np_idxs)) +
            tuple(args_other.items()) +
            tuple(kwargs.items()) +
            tuple(x_or_kernel_other.items()))

    if _key in cache:
      _f = cache[_key]
    else:
      # Define a `np.ndarray`-only function as a closure over other arguments.
      def _f(_x_or_kernel_np, *_args_np):
        # Merge Kernel.
        if is_input_kernel:
          _x_or_kernel_np = {**_x_or_kernel_np, **x_or_kernel_other}
          _x_or_kernel_np = Kernel(**_x_or_kernel_np)
        # Merge args.
        _args_np = {i: _arg_np for i, _arg_np in zip(args_np_idxs, _args_np)}
        _args = {**_args_np, **args_other}
        _args = tuple(v for k, v in sorted(_args.items()))
        return f(_x_or_kernel_np, *_args, **kwargs)

      _f = jit(_f) if device_count == 0 else pmap(_f)
      cache[_key] = _f

    # Broadcast `np.ndarray` arguments and apply the new function to them.
    args_np = tree_map(broadcast, args_np)
    return _f(x_or_kernel_np, *args_np)

  return f_pmapped
def testIssue804(self):
  """Regression test: `lax.scan` with a `psum` in its body works under pmap."""
  n_devices = xla_bridge.device_count()

  def scan_body(carry, x):
    return carry + lax.psum(x, "i"), carry

  scan_fn = partial(lax.scan, scan_body, 0.)
  api.pmap(scan_fn, axis_name="i")(np.ones((n_devices, 4)))  # doesn't crash
def run(self, rng_key, *args, extra_fields=(), collect_warmup=False, init_params=None, **kwargs):
    """
    Run the MCMC samplers and collect samples.

    :param random.PRNGKey rng_key: Random number generator key to be used for the sampling.
    :param args: Arguments to be provided to the :meth:`numpyro.infer.mcmc.MCMCKernel.init` method.
        These are typically the arguments needed by the `model`.
    :param extra_fields: Extra fields (aside from `z`, `diverging`) from :data:`numpyro.infer.mcmc.HMCState`
        to collect during the MCMC run.
    :type extra_fields: tuple or list
    :param bool collect_warmup: Whether to collect samples from the warmup phase. Defaults
        to `False`.
    :param init_params: Initial parameters to begin sampling. The type must be consistent
        with the input type to `potential_fn`. When `num_chains > 1`, the first leaf of
        `init_params` must have a leading dimension of size `num_chains`.
    :param kwargs: Keyword arguments to be provided to the :meth:`numpyro.infer.mcmc.MCMCKernel.init`
        method. These are typically the keyword arguments needed by the `model`.
    :raises ValueError: If `num_chains > 1` and the first leaf of `init_params` is a scalar
        or its leading dimension differs from `num_chains`; or if `num_chains > 1` and the
        chain method is not one of "sequential", "parallel" or "vectorized".
    """
    self._args = args
    self._kwargs = kwargs
    chain_method = self.chain_method
    # Fall back to sequential sampling when there are fewer devices than chains.
    if chain_method == 'parallel' and xla_bridge.device_count() < self.num_chains:
        chain_method = 'sequential'
        warnings.warn('There are not enough devices to run parallel chains: expected {} but got {}.'
                      ' Chains will be drawn sequentially. If you are running MCMC in CPU,'
                      ' consider to use `numpyro.set_host_device_count({})` at the beginning'
                      ' of your program.'
                      .format(self.num_chains, xla_bridge.device_count(), self.num_chains))

    if init_params is not None and self.num_chains > 1:
        prototype_init_val = tree_flatten(init_params)[0][0]
        prototype_shape = np.shape(prototype_init_val)
        # A scalar leaf has shape `()`; check for it explicitly so that the intended
        # ValueError is raised instead of an IndexError.
        if not prototype_shape or prototype_shape[0] != self.num_chains:
            raise ValueError('`init_params` must have the same leading dimension'
                             ' as `num_chains`.')

    assert isinstance(extra_fields, (tuple, list))
    collect_fields = tuple(set(('z', 'diverging') + tuple(extra_fields)))
    if self.num_chains == 1:
        states_flat = self._single_chain_mcmc((rng_key, init_params), collect_fields,
                                              collect_warmup, args, kwargs)
        # Add a leading chain axis so `states` has the same layout as the multi-chain case.
        states = tree_map(lambda x: x[np.newaxis, ...], states_flat)
    else:
        rng_keys = random.split(rng_key, self.num_chains)
        partial_map_fn = partial(self._single_chain_mcmc,
                                 collect_fields=collect_fields,
                                 collect_warmup=collect_warmup,
                                 args=args,
                                 kwargs=kwargs)
        if chain_method == 'sequential':
            if self.progress_bar:
                map_fn = partial(_laxmap, partial_map_fn)
            else:
                map_fn = partial(lax.map, partial_map_fn)
        elif chain_method == 'parallel':
            map_fn = pmap(partial_map_fn)
        elif chain_method == 'vectorized':
            map_fn = partial_map_fn
        else:
            raise ValueError('Only supporting the following methods to draw chains:'
                             ' "sequential", "parallel", or "vectorized"')
        states = map_fn((rng_keys, init_params))
        if chain_method == 'vectorized':
            # swap num_samples x num_chains to num_chains x num_samples
            states = tree_map(lambda x: np.swapaxes(x, 0, 1), states)
        # Merge the leading chain and sample axes into a single axis.
        states_flat = tree_map(lambda x: np.reshape(x, (-1,) + x.shape[2:]), states)
    self._states = states
    self._states_flat = states_flat
def testShardedDeviceArrayBlockUntilReady(self):
  """`block_until_ready` can be called on the output of an identity pmap."""
  identity = lambda v: v
  sharded = pmap(identity)(onp.arange(xla_bridge.device_count()))
  sharded.block_until_ready()  # doesn't crash
def _parallel(ker_fun, device_count=-1):
  """Returns a function that computes a kernel in batches in parallel.

  When batching in parallel, the data is split over a set number of devices.
  The number of devices must be less than or equal to the number of physical
  devices. Moreover, the device count must divide the size of the first
  dataset `x1`. If `x1` has fewer rows than `device_count`, only `|x1|`
  devices are used, with one row each.

  Given two datasets x1 and x2, parallel splits the kernel calculation over
  devices such that each device computes a batch of rows of shape
  [|x1| / device_count, |x2|].

  Args:
    ker_fun: A function that computes a kernel between two datasets,
      ker_fun(x1, x2). Here x1 and x2 are `np.ndarray`s of floats of shape
      [n1,] + input_shape and [n2,] + input_shape. The kernel function should
      return a PyTree.
    device_count: Integer specifying the number of devices over which to split
      the data. If device_count = -1 (the default), the computation is
      parallelized over all available devices. Otherwise it should be a
      positive integer. The rows are split with `divmod(n1, device_count)`,
      so `device_count = 0` is not supported here (it divides by zero).

  Returns:
    A new function with the same signature as ker_fun that computes the kernel
    by batching over the dataset in parallel over a specified number of cores.
    If `x2` is omitted, the kernel of `x1` with itself is computed. The
    returned function carries the attributes `is_parallel = True` and
    `device_count` (with `-1` already resolved to the actual device count).

  Raises:
    ValueError: (raised by the returned function) if `|x1| >= device_count`
      and `device_count` does not divide `|x1|`.
  """
  ker_fun = _jit_or_pmap_broadcast(ker_fun, device_count)

  if device_count == -1:
    device_count = xla_bridge.device_count()

  def parallel_fn(x1, x2=None, *args, **kwargs):
    if x2 is None:
      # TODO(schsam): Only compute the upper triangular part of the kernel.
      x2 = x1

    n1 = x1.shape[0]

    assert x1.shape[1:] == x2.shape[1:]
    input_shape = x1.shape[1:]

    _device_count = device_count

    n1_per_device, ragged = divmod(n1, device_count)
    if n1_per_device and ragged:
      raise ValueError(
          ('Dataset size ({}) must divide number of '
           'physical devices ({}).').format(n1, device_count))
    elif not n1_per_device:
      # Fewer rows than devices: use one device per row.
      _device_count = ragged
      n1_per_device = 1

    # Add a leading device axis so that each device receives a batch of rows.
    x1 = np.reshape(x1, (
        _device_count,
        n1_per_device,
    ) + input_shape)
    kernel = ker_fun(x1, x2, *args, **kwargs)
    return _flatten_kernel(kernel)

  # Set function attributes so that `serial` can detect whether or not it is
  # acting on a parallel function.
  parallel_fn.is_parallel = True
  parallel_fn.device_count = device_count

  return parallel_fn
def testMismatchedAxisSizes(self):
  """pmap rejects positional arguments whose leading axes differ in size."""
  num_devices = xla_bridge.device_count()
  add = pmap(lambda x, y: x + y)

  def call_with_mismatched_sizes():
    return add(onp.random.randn(num_devices),
               onp.random.randn(num_devices - 1))

  jtu.check_raises_regexp(
      call_with_mismatched_sizes, ValueError,
      "Axis size .* does not match leading dimension of shape .*")