Example #1
def _parallel(kernel_fn: KernelFn, device_count: int = -1) -> KernelFn:
  """Returns a function that computes a kernel in batches in parallel.

  When batching in parallel, the data is split over a set number of devices.
  The number of devices must be less than or equal to the number of physical
  devices. Moreover, the device count must evenly divide the dataset size.

  Given two datasets `x1` and `x2`, parallel splits the kernel calculation over
  devices such that each device computes a batch of rows of shape
  `[|x1| / device_count, |x2|]`.

  Args:
    kernel_fn:
      A function that computes a kernel between two datasets,
      `kernel_fn(x1, x2)` or the compositional kernel for an input kernel
      `kernel_fn(kernel_in)`. Here `x1` and `x2` are `np.ndarray`s of floats of
      shape `(n1,) + input_shape` and `(n2,) + input_shape`; `kernel_in` is a
      Kernel object. The kernel function should return a `PyTree`.
    device_count:
      Integer specifying the number of devices over which to split the data. If
      `device_count == -1`, the computation is parallelized over all available
      devices.

  Returns:
    A new function with the same signature as kernel_fn that computes the kernel
    by batching over the dataset in parallel over a specified number of cores.
  """
  kernel_fn = _jit_or_pmap_broadcast(kernel_fn, device_count)
  if device_count == -1:
    device_count = xla_bridge.device_count()

  def parallel_fn_x1(x1, x2=None, *args, **kwargs):
    x2_is_none = x2 is None
    if x2_is_none:
      # TODO(schsam): Only compute the upper triangular part of the kernel.
      x2 = x1

    n1 = x1.shape[0]

    assert x1.shape[1:] == x2.shape[1:]
    input_shape = x1.shape[1:]

    _device_count = device_count

    n1_per_device, ragged = divmod(n1, device_count)
    if n1_per_device and ragged:
      raise ValueError(
          ('Dataset size ({}) must be divisible by the number of '
           'physical devices ({}).').format(n1, device_count))
    elif not n1_per_device:
      _device_count = ragged
      n1_per_device = 1

    for k, v in kwargs.items():

      if _is_np_ndarray(v):
        assert isinstance(v, tuple) and len(v) == 2
        v0 = np.reshape(v[0], (_device_count, n1_per_device,) + v[0].shape[1:])
        kwargs[k] = (v0, v[1])

    x1 = np.reshape(x1, (_device_count, n1_per_device,) + input_shape)
    kernel = kernel_fn(x1, x2, *args, **kwargs)
    return _flatten_kernel(kernel, x2_is_none, True)

  def parallel_fn_kernel(kernel, *args, **kwargs):
    n1 = kernel.cov1.shape[0]

    _device_count = device_count

    n1_per_device, ragged = divmod(n1, device_count)
    if n1_per_device and ragged:
      raise ValueError(
          ('Dataset size ({}) must be divisible by the number of '
           'physical devices ({}).').format(n1, device_count))
    elif not n1_per_device:
      _device_count = ragged
      n1_per_device = 1

    cov2_is_none = kernel.cov2 is None
    kernel = _reshape_kernel_for_pmap(kernel, _device_count, n1_per_device)
    kernel = kernel_fn(kernel, *args, **kwargs)
    if cov2_is_none:
      kernel = kernel.replace(cov2=None)
    return _flatten_kernel(kernel, cov2_is_none, True)

  @utils.wraps(kernel_fn)
  def parallel_fn(x1_or_kernel, x2=None, *args, **kwargs):
    if isinstance(x1_or_kernel, np.ndarray):
      return parallel_fn_x1(x1_or_kernel, x2, *args, **kwargs)
    elif isinstance(x1_or_kernel, Kernel):
      assert not x2
      return parallel_fn_kernel(x1_or_kernel, *args, **kwargs)
    raise NotImplementedError()

  # Set function attributes so that `serial` can detect whether or not it is
  # acting on a parallel function.
  parallel_fn.device_count = device_count
  return parallel_fn
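A minimal, self-contained sketch of the same row-splitting idea using plain `jax.pmap` (this is only an illustration, not the library's `_parallel`; the toy `kernel_fn` and the dataset sizes below are made up, and `|x1|` is chosen so the device count divides it):

import jax.numpy as np
from jax import pmap
from jax.lib import xla_bridge

def kernel_fn(x1, x2):
  # Toy RBF kernel between two batches of flat inputs.
  sq_dists = np.sum((x1[:, None, :] - x2[None, :, :]) ** 2, axis=-1)
  return np.exp(-0.5 * sq_dists)

device_count = xla_bridge.device_count()
n1 = 4 * device_count  # |x1| chosen so that device_count divides it
x1 = np.arange(n1 * 3, dtype=np.float32).reshape(n1, 3)
x2 = np.arange(5 * 3, dtype=np.float32).reshape(5, 3)

# Each device receives a [n1 / device_count, |x2|] block of rows of the kernel.
x1_sharded = np.reshape(x1, (device_count, n1 // device_count) + x1.shape[1:])
x2_bcast = np.broadcast_to(x2, (device_count,) + x2.shape)
k_blocks = pmap(kernel_fn)(x1_sharded, x2_bcast)      # (device_count, n1 / device_count, |x2|)
k = np.reshape(k_blocks, (n1,) + k_blocks.shape[2:])  # (|x1|, |x2|)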
Example #2
    def run(self, rng_key, *args, extra_fields=(), init_params=None, **kwargs):
        """
        Run the MCMC samplers and collect samples.

        :param random.PRNGKey rng_key: Random number generator key to be used for the sampling.
            For multi-chain sampling, a batch of `num_chains` keys can be supplied. If `rng_key`
            does not have a batch dimension, it will be split into a batch of `num_chains` keys.
        :param args: Arguments to be provided to the :meth:`numpyro.infer.mcmc.MCMCKernel.init` method.
            These are typically the arguments needed by the `model`.
        :param extra_fields: Extra fields (aside from `z`, `diverging`) from :data:`numpyro.infer.mcmc.HMCState`
            to collect during the MCMC run.
        :type extra_fields: tuple or list
        :param init_params: Initial parameters to begin sampling. The type must be consistent
            with the input type to `potential_fn`.
        :param kwargs: Keyword arguments to be provided to the :meth:`numpyro.infer.mcmc.MCMCKernel.init`
            method. These are typically the keyword arguments needed by the `model`.

        .. note:: JAX allows Python code to continue even when the compiled code has not finished yet.
            This can cause trouble when trying to profile the code for speed.
            See https://jax.readthedocs.io/en/latest/async_dispatch.html and
            https://jax.readthedocs.io/en/latest/profiling.html for pointers on profiling jax programs.
        """
        self._args = args
        self._kwargs = kwargs
        init_state = self._get_cached_init_state(rng_key, args, kwargs)
        if self.num_chains > 1 and rng_key.ndim == 1:
            rng_key = random.split(rng_key, self.num_chains)

        if self._warmup_state is not None:
            self._set_collection_params(0, self.num_samples, self.num_samples)
            init_state = self._warmup_state._replace(rng_key=rng_key)

        chain_method = self.chain_method
        if (chain_method == 'parallel'
                and xla_bridge.device_count() < self.num_chains):
            chain_method = 'sequential'
            warnings.warn(
                'There are not enough devices to run parallel chains: expected {} but got {}.'
                ' Chains will be drawn sequentially. If you are running MCMC on CPU,'
                ' consider using `numpyro.set_host_device_count({})` at the beginning'
                ' of your program.'.format(self.num_chains,
                                           xla_bridge.device_count(),
                                           self.num_chains))

        if init_params is not None and self.num_chains > 1:
            prototype_init_val = tree_flatten(init_params)[0][0]
            if jnp.shape(prototype_init_val)[0] != self.num_chains:
                raise ValueError(
                    '`init_params` must have the same leading dimension'
                    ' as `num_chains`.')
        assert isinstance(extra_fields, (tuple, list))
        collect_fields = tuple(
            set((self._sample_field, ) + tuple(self._default_fields) +
                tuple(extra_fields)))
        partial_map_fn = partial(self._single_chain_mcmc,
                                 args=args,
                                 kwargs=kwargs,
                                 collect_fields=collect_fields)
        map_args = (rng_key, init_state, init_params)
        if self.num_chains == 1:
            states_flat, last_state = partial_map_fn(map_args)
            states = tree_map(lambda x: x[jnp.newaxis, ...], states_flat)
        else:
            if chain_method == 'sequential':
                if self.progress_bar:
                    states, last_state = _laxmap(partial_map_fn, map_args)
                else:
                    states, last_state = lax.map(partial_map_fn, map_args)
            elif chain_method == 'parallel':
                states, last_state = pmap(partial_map_fn)(map_args)
                # TODO: remove when https://github.com/google/jax/issues/3597 is resolved
                states = device_put(states)
            else:
                assert chain_method == 'vectorized'
                states, last_state = partial_map_fn(map_args)
                # swap num_samples x num_chains to num_chains x num_samples
                states = tree_map(lambda x: jnp.swapaxes(x, 0, 1), states)
            states_flat = tree_map(
                lambda x: jnp.reshape(x, (-1, ) + x.shape[2:]), states)
        self._last_state = last_state
        self._states = states
        self._states_flat = states_flat
        self._set_collection_params()
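A minimal usage sketch of `run` with several chains on a CPU-only machine (hedged: the toy `model` below is made up, and the import locations assume a NumPyro version where `MCMC`/`NUTS` live under `numpyro.infer`, as in Example #13). `set_host_device_count` must be called before JAX initializes its backend for the parallel chain method to see the extra devices:

import numpyro
numpyro.set_host_device_count(2)  # must happen before JAX creates its CPU devices

from jax import random
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS

def model():
    numpyro.sample('x', dist.Normal(0., 1.))

mcmc = MCMC(NUTS(model), num_warmup=100, num_samples=100, num_chains=2,
            chain_method='parallel')
mcmc.run(random.PRNGKey(0), extra_fields=('num_steps',))
samples = mcmc.get_samples()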
Example #3
 def testMakeJaxprOfOpenSpmd(self):
     f = lambda x: x - lax.psum(x, 'i')
     shape = (xla_bridge.device_count(), 4)
     x = onp.arange(prod(shape), dtype=onp.float32).reshape(shape)
     make_jaxpr(f)(x)  # doesn't crash
Example #4
    def jit_or_pmap_broadcast(f, device_count=-1):
        """Pmap `f` over the first argument by closing over or broadcasting others.

    Args:
      f: function to pmap. First argument must be a `np.ndarray` with leading
        axis having the size of `device_count`.
      device_count: number of XLA devices. `-1` means all available devices. `0`
        means to just `jit` the function.

    Returns:
      A function of the same signature as `f` pmapped over the first argument
        with other arguments either closed over (non-`np.ndarray`s in `args` and
        all `kwargs`) or broadcasted to `(device_count,) + old_shape` (for
        `np.ndarray`s). If `device_count == 0`, `f` is closed over and jitted
        over all non-array arguments and all `kwargs`.

    Raises:
      An error if `kwargs` have a `np.ndarray`.
      TODO(romann): treat `np.ndarray`s in `kwargs` when JAX allows it. See
      https://github.com/google/jax/issues/912
    """
        key = (f, device_count)

        if device_count == -1:
            device_count = xla_bridge.device_count()

        # TODO(romann): adapt this when JAX allows `axis_in` for `pmap`.
        def broadcast(arg):
            if device_count == 0:
                return arg
            return np.broadcast_to(arg, (device_count, ) + arg.shape)

        def f_pmapped(x, *args, **kwargs):
            args_np, args_np_idxs = [], []
            args_other = {}

            # TODO(romann): treat `np.ndarray`s in `kwargs` when JAX allows it.
            # https://github.com/google/jax/issues/912
            # Filter out `np.ndarray`s from other arguments.
            for i, arg in enumerate(args):
                if _is_np_ndarray(arg):
                    args_np.append(arg)
                    args_np_idxs.append(i)
                else:
                    args_other[i] = arg

            # Check cache before jitting.
            _key = key + tuple(args_other.items()) + tuple(kwargs.items())
            if _key in cache:
                _f = cache[_key]
            else:
                # Define a `np.ndarray`-only function as a closure over other arguments.
                def _f(_x, *_args_np):
                    # Merge args.
                    _args_np = {
                        i: _arg_np
                        for i, _arg_np in zip(args_np_idxs, _args_np)
                    }
                    _args = _merge_dicts(_args_np, args_other)
                    _args = tuple(v for k, v in sorted(_args.items()))
                    return f(_x, *_args, **kwargs)

                _f = jit(_f) if device_count == 0 else pmap(_f)
                cache[_key] = _f

            # Broadcast `np.ndarray` arguments and apply the new function to them.
            args_np = tree_map(broadcast, args_np)
            return _f(x, *args_np)

        return f_pmapped
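The broadcast-then-pmap trick above can be shown in isolation with plain `jax.pmap` (a sketch with made-up shapes): the first argument already carries a leading device axis, while another array argument is broadcast to a per-device copy so both can be mapped together.

import jax.numpy as np
from jax import pmap
from jax.lib import xla_bridge

def f(x, w):
  return np.dot(x, w)

device_count = xla_bridge.device_count()
x = np.ones((device_count, 4, 3))  # already sharded over devices
w = np.ones((3, 2))                # same value needed on every device

# Broadcast `w` to `(device_count,) + w.shape` so it can be pmapped alongside `x`;
# non-array arguments would instead be closed over (and the whole thing jitted
# when device_count == 0).
w_bcast = np.broadcast_to(w, (device_count,) + w.shape)
y = pmap(f)(x, w_bcast)            # shape (device_count, 4, 2)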
Example #5
  def testRule30(self):
    # This is a test of collective_permute implementing a simple halo exchange
    # to run a rule 30 simulation: https://en.wikipedia.org/wiki/Rule_30
    # Halo exchange should be useful in spatially-sharded convolutions and in
    # other simulations.
    device_count = xla_bridge.device_count()

    def send_right(x, axis_name):
      left_perm = [(i, (i + 1) % device_count) for i in range(device_count)]
      return lax.ppermute(x, perm=left_perm, axis_name=axis_name)

    def send_left(x, axis_name):
      left_perm = [((i + 1) % device_count, i) for i in range(device_count)]
      return lax.ppermute(x, perm=left_perm, axis_name=axis_name)

    def update_board(board):
      left = board[:-2]
      right = board[2:]
      center = board[1:-1]
      return lax.bitwise_xor(left, lax.bitwise_or(center, right))

    @partial(pmap, axis_name='i')
    def step(board_slice):
      left, right = board_slice[:1], board_slice[-1:]
      right, left = send_left(left, 'i'), send_right(right, 'i')
      enlarged_board_slice = np.concatenate([left, board_slice, right])
      return update_board(enlarged_board_slice)

    board = onp.zeros(40, dtype=bool)
    board[board.shape[0] // 2] = True
    reshaped_board = board.reshape((device_count, -1))

    boards = []
    def print_board(board):
      boards.append(''.join('*' if x else ' ' for x in board.ravel()))

    print_board(reshaped_board)
    for _ in range(20):
      reshaped_board = step(reshaped_board)
      print_board(reshaped_board)

    ans = '\n'.join(boards)
    expected = '\n'.join((
        '                    *                   ',
        '                   ***                  ',
        '                  **  *                 ',
        '                 ** ****                ',
        '                **  *   *               ',
        '               ** **** ***              ',
        '              **  *    *  *             ',
        '             ** ****  ******            ',
        '            **  *   ***     *           ',
        '           ** **** **  *   ***          ',
        '          **  *    * **** **  *         ',
        '         ** ****  ** *    * ****        ',
        '        **  *   ***  **  ** *   *       ',
        '       ** **** **  *** ***  ** ***      ',
        '      **  *    * ***   *  ***  *  *     ',
        '     ** ****  ** *  * *****  *******    ',
        '    **  *   ***  **** *    ***      *   ',
        '   ** **** **  ***    **  **  *    ***  ',
        '  **  *    * ***  *  ** *** ****  **  * ',
        ' ** ****  ** *  ******  *   *   *** ****',
        ' *  *   ***  ****     **** *** **   *   ',
    ))

    print(ans)
    self.assertEqual(ans, expected)
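The halo exchange in this test boils down to `lax.ppermute` moving each device's boundary values to a neighbour. A stripped-down sketch of just that collective, with made-up data:

from functools import partial
import numpy as onp
from jax import lax, pmap
from jax.lib import xla_bridge

device_count = xla_bridge.device_count()

@partial(pmap, axis_name='i')
def shift_right(x):
  # Cyclically send each device's value to its right neighbour, as in `send_right` above.
  perm = [(i, (i + 1) % device_count) for i in range(device_count)]
  return lax.ppermute(x, perm=perm, axis_name='i')

print(shift_right(onp.arange(device_count)))  # a cyclic shift of [0, 1, ..., device_count - 1]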
Example #6
File: mcmc.py Project: juvu/numpyro
def mcmc(num_warmup,
         num_samples,
         init_params,
         num_chains=1,
         sampler='hmc',
         constrain_fn=None,
         print_summary=True,
         **sampler_kwargs):
    """
    Convenience wrapper for MCMC samplers -- runs warmup, prints
    diagnostic summary and returns a collection of samples
    from the posterior.

    :param num_warmup: Number of warmup steps.
    :param num_samples: Number of samples to generate from the Markov chain.
    :param init_params: Initial parameters to begin sampling. The type must
        be consistent with the input type to `potential_fn`.
    :param num_chains: Number of MCMC chains to run. Default is 1.
    :param sampler: currently, only `hmc` is implemented (default).
    :param constrain_fn: Callable that converts a collection of unconstrained
        sample values returned from the sampler to constrained values that
        lie within the support of the sample sites.
    :param print_summary: Whether to print diagnostics summary for
        each sample site. Default is ``True``.
    :param `**sampler_kwargs`: Sampler specific keyword arguments.

         - *HMC*: Refer to :func:`~numpyro.mcmc.hmc` and
           :func:`~numpyro.mcmc.hmc.init_kernel` for accepted arguments. Note
           that all arguments must be provided as keywords.

    :return: collection of samples from the posterior.

    .. testsetup::

       import jax
       from jax import random
       import jax.numpy as np
       import numpyro.distributions as dist
       from numpyro.handlers import sample
       from numpyro.hmc_util import initialize_model
       from numpyro.mcmc import hmc
       from numpyro.util import fori_collect

    .. doctest::

        >>> true_coefs = np.array([1., 2., 3.])
        >>> data = random.normal(random.PRNGKey(2), (2000, 3))
        >>> dim = 3
        >>> labels = dist.Bernoulli(logits=(true_coefs * data).sum(-1)).sample(random.PRNGKey(3))
        >>>
        >>> def model(data, labels):
        ...     coefs_mean = np.zeros(dim)
        ...     coefs = sample('beta', dist.Normal(coefs_mean, np.ones(3)))
        ...     intercept = sample('intercept', dist.Normal(0., 10.))
        ...     return sample('y', dist.Bernoulli(logits=(coefs * data + intercept).sum(-1)), obs=labels)
        >>>
        >>> init_params, potential_fn, constrain_fn = initialize_model(random.PRNGKey(0), model,
        ...                                                            data, labels)
        >>> num_warmup, num_samples = 1000, 1000
        >>> samples = mcmc(num_warmup, num_samples, init_params,
        ...                potential_fn=potential_fn,
        ...                constrain_fn=constrain_fn)  # doctest: +SKIP
        warmup: 100%|██████████| 1000/1000 [00:09<00:00, 109.40it/s, 1 steps of size 5.83e-01. acc. prob=0.79]
        sample: 100%|██████████| 1000/1000 [00:00<00:00, 1252.39it/s, 1 steps of size 5.83e-01. acc. prob=0.85]


                           mean         sd       5.5%      94.5%      n_eff       Rhat
            coefs[0]       0.96       0.07       0.85       1.07     455.35       1.01
            coefs[1]       2.05       0.09       1.91       2.20     332.00       1.01
            coefs[2]       3.18       0.13       2.96       3.37     320.27       1.00
           intercept      -0.03       0.02      -0.06       0.00     402.53       1.00
    """
    sequential_chain = False
    if xla_bridge.device_count() < num_chains:
        sequential_chain = True
        warnings.warn(
            'There are not enough devices to run parallel chains: expected {} but got {}.'
            ' Chains will be drawn sequentially. If you are running `mcmc` on CPU,'
            ' consider disabling XLA intra-op parallelism by setting the environment'
            ' flag "XLA_FLAGS=--xla_force_host_platform_device_count={}".'.
            format(num_chains, xla_bridge.device_count(), num_chains))
    progbar = sampler_kwargs.pop('progbar', True)
    if num_chains > 1:
        progbar = False

    if sampler == 'hmc':
        if constrain_fn is None:
            constrain_fn = identity
        potential_fn = sampler_kwargs.pop('potential_fn')
        kinetic_fn = sampler_kwargs.pop('kinetic_fn', None)
        algo = sampler_kwargs.pop('algo', 'NUTS')
        if num_chains > 1:
            rngs = sampler_kwargs.pop('rng',
                                      vmap(PRNGKey)(np.arange(num_chains)))
        else:
            rng = sampler_kwargs.pop('rng', PRNGKey(0))

        init_kernel, sample_kernel = hmc(potential_fn, kinetic_fn, algo)
        if progbar:
            hmc_state = init_kernel(init_params,
                                    num_warmup,
                                    progbar=progbar,
                                    rng=rng,
                                    **sampler_kwargs)
            samples_flat = fori_collect(0,
                                        num_samples,
                                        sample_kernel,
                                        hmc_state,
                                        transform=lambda x: constrain_fn(x.z),
                                        progbar=progbar,
                                        diagnostics_fn=get_diagnostics_str,
                                        progbar_desc='sample')
            samples = tree_map(lambda x: x[np.newaxis, ...], samples_flat)
        else:

            def single_chain_mcmc(rng, init_params):
                hmc_state = init_kernel(init_params,
                                        num_warmup,
                                        run_warmup=False,
                                        rng=rng,
                                        **sampler_kwargs)
                samples = fori_collect(num_warmup,
                                       num_warmup + num_samples,
                                       sample_kernel,
                                       hmc_state,
                                       transform=lambda x: constrain_fn(x.z),
                                       progbar=progbar)
                return samples

            if num_chains == 1:
                samples_flat = single_chain_mcmc(rng, init_params)
                samples = tree_map(lambda x: x[np.newaxis, ...], samples_flat)
            else:
                if sequential_chain:
                    samples = []
                    for i in range(num_chains):
                        init_params_i = tree_map(lambda x: x[i], init_params)
                        samples.append(
                            jit(single_chain_mcmc)(rngs[i], init_params_i))
                    samples = tree_multimap(lambda *args: np.stack(args),
                                            *samples)
                else:
                    samples = pmap(single_chain_mcmc)(rngs, init_params)
                samples_flat = tree_map(
                    lambda x: np.reshape(x, (-1, ) + x.shape[2:]), samples)

        if print_summary:
            summary(samples)
        return samples_flat
    else:
        raise ValueError('sampler: {} not recognized'.format(sampler))
Example #7
 def test_jit_device_assignment(self):
     device_num = xb.device_count() - 1
     x = api.jit(lambda x: x, device_assignment=device_num)(3.)
     self.assertIsInstance(x, DeviceArray)
     self.assertEqual(x.device_buffer.device(), device_num)
Example #8
if __name__ == "__main__":
  layer_sizes = [784, 1024, 1024, 10]
  param_scale = 0.1
  step_size = 0.001
  num_epochs = 10
  batch_size = 128

  train_images, train_labels, test_images, test_labels = datasets.mnist()
  num_train = train_images.shape[0]
  num_complete_batches, leftover = divmod(num_train, batch_size)
  num_batches = num_complete_batches + bool(leftover)

  # For this manual SPMD example, we get the number of devices (e.g. GPUs or
  # TPU cores) that we're using, and use it to reshape data minibatches.
  num_devices = xla_bridge.device_count()
  def data_stream():
    rng = npr.RandomState(0)
    while True:
      perm = rng.permutation(num_train)
      for i in range(num_batches):
        batch_idx = perm[i * batch_size:(i + 1) * batch_size]
        images, labels = train_images[batch_idx], train_labels[batch_idx]
        # For this SPMD example, we reshape the data batch dimension into two
        # batch dimensions, one of which is mapped over parallel devices.
        batch_size_per_device, ragged = divmod(images.shape[0], num_devices)
        if ragged:
          msg = "batch size must be divisible by device count, got {} and {}."
          raise ValueError(msg.format(batch_size, num_devices))
        shape_prefix = (num_devices, batch_size_per_device)
        images = images.reshape(shape_prefix + images.shape[1:])
        labels = labels.reshape(shape_prefix + labels.shape[1:])
        yield images, labels
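In the full MNIST example this stream feeds a pmapped SPMD update step roughly like the sketch below. `loss`, `step_size` and the list-of-`(w, b)` parameter layout are assumed from the surrounding example, and the parameters must be replicated so each leaf also carries a leading `num_devices` axis.

from functools import partial
from jax import grad, lax, pmap

@partial(pmap, axis_name='batch')
def spmd_update(params, batch):
  # `loss(params, batch)` is assumed to be defined as in the original example.
  grads = grad(loss)(params, batch)
  # Sum gradients across devices with an all-reduce, then take an SGD step.
  grads = [(lax.psum(dw, 'batch'), lax.psum(db, 'batch')) for dw, db in grads]
  return [(w - step_size * dw, b - step_size * db)
          for (w, b), (dw, db) in zip(params, grads)]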
Example #9
def _parallel(
    kernel_fn: KernelFn,
    use_serial: bool = True,
    dropout_in_analytic_kernel: bool = False,
    device_count: int = -1,
) -> KernelFn:
    """Returns a function that computes a kernel in batches in parallel.

  When batching in parallel, the data is split over a set number of devices.
  The number of devices must be less than or equal to the number of physical
  devices. Moreover, the device count must evenly divide the dataset size.

  Given two datasets `x1` and `x2`, parallel splits the kernel calculation over
  devices such that each device computes a batch of rows of shape
  `[|x1| / device_count, |x2|]`.

  Args:
    kernel_fn:
      A function that computes a kernel between two datasets,
      `kernel_fn(x1, x2)` or the compositional kernel for an input kernel
      `kernel_fn(kernel_in)`. Here `x1` and `x2` are `np.ndarray`s of floats of
      shape `(n1,) + input_shape` and `(n2,) + input_shape`; `kernel_in` is a
      Kernel object. The kernel function should return a `PyTree`.
    use_serial:
      Whether `serial` will be called after `_parallel`. The only use case is to
      make sure that, when `dropout` is used in the analytic/empirical kernel,
      the batch size on each device is square.
    dropout_in_analytic_kernel:
      Whether `dropout` is used in the analytic kernel. See `use_serial` above
      for the only use case.
    device_count:
      Integer specifying the number of devices over which to split the data. If
      `device_count == -1`, the computation is parallelized over all available
      devices.

  Returns:
    A new function with the same signature as kernel_fn that computes the kernel
    by batching over the dataset in parallel over a specified number of cores.
  """

    if device_count == -1:
        device_count = xla_bridge.device_count()

    def _check_dropout(n1, n2, kwargs):
        # `kwargs` is a plain dict, so use `.get` rather than `getattr` to look up `rng`.
        dropout_in_empirical_kernel = kwargs.get('rng', None) is not None
        if n1 == n2 and (dropout_in_empirical_kernel
                         or dropout_in_analytic_kernel) and not use_serial:
            raise NotImplementedError(
                'Batching for empirical / analytic kernels with dropout'
                ' is not implemented for non-square batch size. '
                'Use `serial` (i.e. a non-zero `batch_size` in the `batch` '
                'function) to enforce a square batch size on each device.'
            )

    def _get_n_per_device(n1, n2):
        _device_count = device_count

        n1_per_device, ragged = divmod(n1, device_count)
        if n1_per_device and ragged:
            raise ValueError(
                ('Dataset size ({}) must be divisible by the number of '
                 'physical devices ({}).').format(n1, device_count))
        elif not n1_per_device:
            _device_count = ragged
            n1_per_device = 1

        return n1_per_device, _device_count

    def parallel_fn_x1(x1, x2=None, *args, **kwargs):
        x2_is_none = utils.all_none(x2)
        if x2_is_none:
            # TODO(schsam): Only compute the upper triangular part of the kernel.
            x2 = x1

        def get_batch_size(x):
            if utils.is_list_or_tuple(x):
                return get_batch_size(x[0])
            return x.shape[0]

        n1 = get_batch_size(x1)
        n2 = n1 if x2_is_none else get_batch_size(x2)

        _check_dropout(n1, n2, kwargs)
        n1_per_device, _device_count = _get_n_per_device(n1, n2)

        _kernel_fn = _jit_or_pmap_broadcast(kernel_fn, _device_count)

        @utils.nt_tree_fn()
        def batch_data(x):
            input_shape = x.shape[1:]
            return np.reshape(x, (
                _device_count,
                n1_per_device,
            ) + input_shape)

        for k, v in kwargs.items():
            if _is_np_ndarray(v):
                assert isinstance(v, tuple) and len(v) == 2
                v0 = np.reshape(v[0], (
                    _device_count,
                    n1_per_device,
                ) + v[0].shape[1:])
                kwargs[k] = (v0, v[1])

        x1 = batch_data(x1)

        kernel = _kernel_fn(x1, x2, *args, **kwargs)
        return _flatten_kernel(kernel, x2_is_none, True)

    def parallel_fn_kernel(kernel, *args, **kwargs):
        @utils.nt_tree_fn(reduce=lambda shapes: shapes[0])
        def get_batch_sizes(k):
            n1 = n2 = k.cov1.shape[0]
            if k.cov2 is not None:
                n2 = k.cov2.shape[0]
            return n1, n2

        n1, n2 = get_batch_sizes(kernel)
        _check_dropout(n1, n2, kwargs)
        n1_per_device, _device_count = _get_n_per_device(n1, n2)

        _kernel_fn = _jit_or_pmap_broadcast(kernel_fn, _device_count)

        cov2_is_none = utils.nt_tree_fn(
            reduce=lambda k: all(k))(lambda k: k.cov2 is None)(kernel)
        kernel = _reshape_kernel_for_pmap(kernel, _device_count, n1_per_device)
        kernel = _kernel_fn(kernel, *args, **kwargs)
        if cov2_is_none:
            kernel = _set_cov2_is_none(kernel)
        return _flatten_kernel(kernel, cov2_is_none, True)

    @utils.wraps(kernel_fn)
    def parallel_fn(x1_or_kernel, x2=None, *args, **kwargs):
        if utils.is_nt_tree_of(x1_or_kernel, np.ndarray):
            return parallel_fn_x1(x1_or_kernel, x2, *args, **kwargs)
        elif utils.is_nt_tree_of(x1_or_kernel, Kernel):
            assert not x2
            return parallel_fn_kernel(x1_or_kernel, *args, **kwargs)
        raise NotImplementedError()

    # Set function attributes so that `serial` can detect whether or not it is
    # acting on a parallel function.
    parallel_fn.device_count = device_count
    return parallel_fn
Example #10
 def wrapper(*args, schedule, **kwargs):
     for loop, n in schedule:
         approx_n = axis_size if n is None else n
         if loop == 'parallel' and approx_n > xla_bridge.device_count():
             raise SkipTest("this test requires more XLA devices")
     return fun(*args, schedule=schedule, **kwargs)
Example #11
from jax.lib import xla_bridge

if __name__ == "__main__":
    xla_backend = xla_bridge.get_backend()
    xla_backend_type = xla_bridge.get_backend().platform  # cpu, gpu, tpu
    print(f"XLA backend type: {xla_backend_type}")

    gpu_count = xla_bridge.device_count() if xla_backend_type == "gpu" else 0
    print(f"\nNumber of GPUs found on system: {gpu_count}")
    if xla_backend_type == "gpu":
        for idx, device in enumerate(xla_backend.devices()):
            gpu_type = "Active GPU" if idx == 0 else "GPU"
            print(f"\n{gpu_type} index: {device.id}")
            print(f"{gpu_type} name: {device.device_kind}")
Example #12
def load_data_from_jax_tfds_or_ml_prepare(dataset_name,
                                          tfds_dir=None,
                                          K=None,
                                          as_numpy=False,
                                          batch_size=128,
                                          **data_loader_kwargs):
    """
    Acquire a dataset from the official TFDS collection through the JAX wrapper, or from the ophthalmology-focused ml-prepare library.

    :param dataset_name: name of dataset
    :type dataset_name: ```str```

    :param tfds_dir: directory to look for datasets in. Default is ~/tensorflow_datasets.
    :type tfds_dir: ```None or str```

    :param K: backend engine, e.g., `np` or `tf`
    :type K: ```None or np or tf or Any```

    :param as_numpy: Convert to numpy ndarrays
    :type as_numpy: ```bool```

    :param batch_size: batch size for the returned data stream. Default is 128.
    :type batch_size: ```int```

    :param data_loader_kwargs: pass this as arguments to data_loader function
    :type data_loader_kwargs: ```**data_loader_kwargs```

    :return: Train and test dataset splits
    :rtype: ```Tuple[tf.data.Dataset, tf.data.Dataset] or Tuple[np.ndarray, np.ndarray]```
    """

    data_loader_kwargs.update({
        'dataset_name': dataset_name,
        'tfds_dir': tfds_dir,
    })
    if 'scale' not in data_loader_kwargs:
        data_loader_kwargs['scale'] = 255

    if dataset_name in datasets2classes:
        return load_data_from_ml_prepare(dataset_name=dataset_name,
                                         tfds_dir=tfds_dir,
                                         **data_loader_kwargs)
    else:
        ml_params_jax.stolen.datasets._DATA = tfds_dir
        train_images, train_labels, test_images, test_labels = getattr(
            ml_params_jax.stolen.datasets, dataset_name)()
        num_train = train_images.shape[0]
        num_complete_batches, leftover = divmod(num_train, batch_size)
        num_batches = num_complete_batches + bool(leftover)

        # For this manual SPMD example, we get the number of devices (e.g. GPUs or
        # TPU cores) that we're using, and use it to reshape data minibatches.
        num_devices = xla_bridge.device_count()

        def data_stream():
            rng = npr.RandomState(0)
            while True:
                perm = rng.permutation(num_train)
                for i in range(num_batches):
                    batch_idx = perm[i * batch_size:(i + 1) * batch_size]
                    images, labels = train_images[batch_idx], train_labels[
                        batch_idx]
                    # For this SPMD example, we reshape the data batch dimension into two
                    # batch dimensions, one of which is mapped over parallel devices.
                    batch_size_per_device, ragged = divmod(
                        images.shape[0], num_devices)
                    if ragged:
                        msg = "batch size must be divisible by device count, got {} and {}."
                        raise ValueError(msg.format(batch_size, num_devices))
                    shape_prefix = (num_devices, batch_size_per_device)
                    images = images.reshape(shape_prefix + images.shape[1:])
                    labels = labels.reshape(shape_prefix + labels.shape[1:])
                    yield images, labels

        return data_stream(
        ), num_batches, num_devices, train_images, train_labels, test_images, test_labels
Example #13
def test_mcmc_one_chain(deterministic, find_heuristic_step_size):
    GLOBAL["count"] = 0
    mcmc = MCMC(NUTS(model, find_heuristic_step_size=find_heuristic_step_size),
                100, 100)
    mcmc.run(random.PRNGKey(0), deterministic=deterministic)
    mcmc.get_samples()

    num_traces_for_heuristic = 2 if find_heuristic_step_size else 0
    if deterministic:
        assert GLOBAL["count"] == 4 + num_traces_for_heuristic
    else:
        assert GLOBAL["count"] == 3 + num_traces_for_heuristic


@pytest.mark.parametrize('deterministic', [True, False])
@pytest.mark.skipif(xla_bridge.device_count() < 2,
                    reason="only one device is available")
def test_mcmc_parallel_chain(deterministic):
    GLOBAL["count"] = 0
    mcmc = MCMC(NUTS(model), 100, 100, num_chains=2)
    mcmc.run(random.PRNGKey(0), deterministic=deterministic)
    mcmc.get_samples()

    if deterministic:
        assert GLOBAL["count"] == 4
    else:
        assert GLOBAL["count"] == 3


@pytest.mark.parametrize('deterministic', [True, False])
def test_autoguide(deterministic):
Example #14
def _parallel(kernel_fn, device_count=-1):
  """Returns a function that computes a kernel in batches in parallel.

  When batching in parallel, the data is split over a set number of devices.
  The number of devices must be less than or equal to the number of physical
  devices. Moreover, the device count must evenly divide the dataset size.

  Given two datasets x1 and x2, parallel splits the kernel calculation over
  devices such that each device computes a batch of rows of shape
  [|x1| / device_count, |x2|].

  Args:
    kernel_fn: A function that computes a kernel between two datasets,
        `kernel_fn(x1, x2)` or the compositional kernel for an input kernel
        `kernel_fn(kernel_in)`. Here x1 and x2 are `np.ndarray`s of floats of
        shape [n1] + input_shape and [n2] + input_shape; `kernel_in` is a Kernel
        object. The kernel function should return a PyTree.
    device_count: Integer specifying the number of devices over which to split
        the data. If device_count = -1, the computation is parallelized over all
        available devices.

  Returns:
    A new function with the same signature as kernel_fn that computes the kernel
    by batching over the dataset in parallel over a specified number of cores.
  """
  kernel_fn = _jit_or_pmap_broadcast(kernel_fn, device_count)
  if device_count == -1:
    device_count = xla_bridge.device_count()

  def parallel_fn_x1(x1, x2=None, *args, **kwargs):
    if 'key' in kwargs:
      raise NotImplementedError('Batching for the empirical kernel with dropout'
                                ' is not implemented. ')
    x2_is_none = x2 is None
    if x2_is_none:
      # TODO(schsam): Only compute the upper triangular part of the kernel.
      x2 = x1

    n1 = x1.shape[0]

    assert x1.shape[1:] == x2.shape[1:]
    input_shape = x1.shape[1:]

    _device_count = device_count

    n1_per_device, ragged = divmod(n1, device_count)
    if n1_per_device and ragged:
      raise ValueError(
          ('Dataset size ({}) must be divisible by the number of '
           'physical devices ({}).').format(n1, device_count))
    elif not n1_per_device:
      _device_count = ragged
      n1_per_device = 1

    x1 = np.reshape(x1, (_device_count, n1_per_device,) + input_shape)
    kernel = kernel_fn(x1, x2, *args, **kwargs)
    return _flatten_kernel(kernel, x2_is_none, True)

  def parallel_fn_kernel(kernel, *args, **kwargs):
    n1 = kernel.var1.shape[0]

    _device_count = device_count

    n1_per_device, ragged = divmod(n1, device_count)
    if n1_per_device and ragged:
      raise ValueError(
          ('Dataset size ({}) must be divisible by the number of '
           'physical devices ({}).').format(n1, device_count))
    elif not n1_per_device:
      _device_count = ragged
      n1_per_device = 1

    kernel_dict = kernel._asdict()

    var2 = kernel_dict['var2']
    var2_is_none = var2 is None
    if var2 is None:
      var2 = kernel_dict['var1']
    kernel_dict['var2'] = np.broadcast_to(var2, (_device_count,) + var2.shape)
    kernel_dict['x1_is_x2'] = np.broadcast_to(
        kernel_dict['x1_is_x2'],
        (_device_count,) + kernel_dict['x1_is_x2'].shape)

    for k, v in kernel_dict.items():
      if k in ('nngp', 'ntk', 'var1'):
        kernel_dict[k] = \
            np.reshape(v, (_device_count, n1_per_device,) + v.shape[1:])
      if k in ('shape1',):
        kernel_dict[k] = (n1_per_device,) + v[1:]
    kernel = kernel_fn(Kernel(**kernel_dict), *args, **kwargs)
    if var2_is_none:
      kernel = kernel._replace(var2=None)
    return _flatten_kernel(kernel, var2_is_none, True)

  @utils.wraps(kernel_fn)
  def parallel_fn(x1_or_kernel, x2=None, *args, **kwargs):
    if isinstance(x1_or_kernel, np.ndarray):
      return parallel_fn_x1(x1_or_kernel, x2, *args, **kwargs)
    elif isinstance(x1_or_kernel, Kernel):
      assert not x2
      return parallel_fn_kernel(x1_or_kernel, *args, **kwargs)
    raise NotImplementedError()

  # Set function attributes so that `serial` can detect whether or not it is
  # acting on a parallel function.
  parallel_fn.device_count = device_count
  return parallel_fn
Example #15
    def jit_or_pmap_broadcast(f: Callable, device_count: int = -1) -> Callable:
        """Pmap `f` over the first argument by closing over or broadcasting others.

    Args:
      f:
        function to pmap. First argument must be an `np.ndarray` or a Kernel.
        In either case, ndarrays should have a leading axis having the size of
        `device_count`.
      device_count:
        number of XLA devices. `-1` means all available devices. `0` means to
        just `jit` the function.

    Returns:
      A function of the same signature as `f` pmapped over the `np.ndarray`s in
      the first argument. Other arguments are either closed over
      (non-`np.ndarray`s in `args` and all `kwargs`) or broadcasted to
      `(device_count,) + old_shape` (for `np.ndarray`s). If `device_count == 0`,
      `f` is closed over and jitted over all non-array arguments and all
      `kwargs`.

    Raises:
      An error if `kwargs` have a `np.ndarray`.
      TODO(romann): treat `np.ndarray`s in `kwargs` when JAX allows it. See
      https://github.com/google/jax/issues/912
    """
        key = (f, device_count)

        if device_count == -1:
            device_count = xla_bridge.device_count()

        # TODO(romann): adapt this when JAX allows `axis_in` for `pmap`.
        def broadcast(arg: np.ndarray) -> np.ndarray:
            if device_count == 0:
                return arg
            # If the argument has already been sharded, no need to broadcast it.
            if isinstance(arg,
                          ShardedDeviceArray) and arg.shape[0] == device_count:
                return arg
            return np.broadcast_to(arg, (device_count, ) + arg.shape)

        @utils.wraps(f)
        def f_pmapped(x_or_kernel: Union[np.ndarray, Kernel], *args, **kwargs):
            args_np, args_np_idxs = [], []
            args_other = {}

            # TODO(romann): treat `np.ndarray`s in `kwargs` when JAX allows it.
            # https://github.com/google/jax/issues/912
            # Filter out `np.ndarray`s from other arguments.
            for i, arg in enumerate(args):
                if _is_np_ndarray(arg):
                    args_np.append(arg)
                    args_np_idxs.append(i)
                else:
                    args_other[i] = arg
            kwargs_np = {}
            kwargs_other = {}
            for k, v in kwargs.items():
                if _is_np_ndarray(v):
                    assert isinstance(v, tuple) and len(v) == 2
                    kwargs_np[k] = (v[0], broadcast(v[1]))
                else:
                    kwargs_other[k] = v

            # Check cache before jitting.
            _key = key + \
                tuple(args_other.items()) + \
                tuple(kwargs_other.items())

            # If any of the instance inside `_key` is a tf.Tensor object, use `ref()`
            # method to avoid directly hashing the TF Tensor.
            _key, tree = tree_flatten(_key)
            for i in range(len(_key)):
                if isinstance(_key[i], tf.Tensor):
                    _key[i] = tuple(map(tuple, _key[i].numpy()))
                elif isinstance(_key[i], onp.ndarray):
                    _key[i] = tuple(map(tuple, _key[i]))
            _key = tuple(_key)

            if _key in cache:
                _f = cache[_key]
            else:
                # Define a `np.ndarray`-only function as a closure over other arguments.
                def _f(_x_or_kernel, *_args_np, **_kwargs_np):
                    # Merge args.
                    _args_np = {
                        i: _arg_np
                        for i, _arg_np in zip(args_np_idxs, _args_np)
                    }
                    _args = {**_args_np, **args_other}
                    _args = tuple(v for k, v in sorted(_args.items()))
                    _kwargs = {**_kwargs_np, **kwargs_other}
                    return f(_x_or_kernel, *_args, **_kwargs)

                _f = jit(_f) if device_count == 0 else pmap(_f)
                cache[_key] = _f

            # Broadcast `np.ndarray` arguments and apply the new function to them.
            args_np = tree_map(broadcast, args_np)
            return _f(x_or_kernel, *args_np, **kwargs_np)

        return f_pmapped
Example #16
  def jit_or_pmap_broadcast(f, device_count=-1):
    """Pmap `f` over the first argument by closing over or broadcasting others.

    Args:
      f: function to pmap. First argument must be an `np.ndarray` or a Kernel.
        In either case, ndarrays should have a leading axis having the size of
        `device_count`.
      device_count: number of XLA devices. `-1` means all available devices. `0`
        means to just `jit` the function.

    Returns:
      A function of the same signature as `f` pmapped over the ndarrays in the
      first argument. Other arguments are either closed over (non-`np.ndarray`s
      in `args` and all `kwargs`) or broadcasted to
      `(device_count,) + old_shape` (for `np.ndarray`s). If `device_count == 0`,
      `f` is closed over and jitted over all non-array arguments and all
      `kwargs`.

    Raises:
      An error if `kwargs` have a `np.ndarray`.
      TODO(romann): treat `np.ndarray`s in `kwargs` when JAX allows it. See
      https://github.com/google/jax/issues/912
    """
    key = (f, device_count)

    if device_count == -1:
      device_count = xla_bridge.device_count()

    # TODO(romann): adapt this when JAX allows `axis_in` for `pmap`.
    def broadcast(arg):
      if device_count == 0:
        return arg
      # If the argument has already been sharded, no need to broadcast it.
      if isinstance(arg, ShardedDeviceArray) and arg.shape[0] == device_count:
        return arg
      return np.broadcast_to(arg, (device_count,) + arg.shape)

    @utils.wraps(f)
    def f_pmapped(x_or_kernel, *args, **kwargs):
      args_np, args_np_idxs = [], []
      args_other = {}

      is_input_kernel = isinstance(x_or_kernel, Kernel)
      x_or_kernel_np = {}
      x_or_kernel_other = {}

      if is_input_kernel:
        kernel_dict = x_or_kernel._asdict()
        for k, v in kernel_dict.items():
          if isinstance(v, np.ndarray):
            x_or_kernel_np[k] = v
          else:
            x_or_kernel_other[k] = v
      else:
        x_or_kernel_np = x_or_kernel

      # TODO(romann): treat `np.ndarray`s in `kwargs` when JAX allows it.
      # https://github.com/google/jax/issues/912
      # Filter out `np.ndarray`s from other arguments.
      for i, arg in enumerate(args):
        if _is_np_ndarray(arg):
          args_np.append(arg)
          args_np_idxs.append(i)
        else:
          args_other[i] = arg

      # Check cache before jitting.
      _key = key + \
          tuple(args_other.items()) + \
          tuple(kwargs.items()) + \
          tuple(x_or_kernel_other.items())
      if _key in cache:
        _f = cache[_key]
      else:
        # Define a `np.ndarray`-only function as a closure over other arguments.
        def _f(_x_or_kernel_np, *_args_np):
          # Merge Kernel.
          if is_input_kernel:
            _x_or_kernel_np = {**_x_or_kernel_np, **x_or_kernel_other}
            _x_or_kernel_np = Kernel(**_x_or_kernel_np)
          # Merge args.
          _args_np = {i: _arg_np for i, _arg_np in zip(args_np_idxs, _args_np)}
          _args = {**_args_np, **args_other}
          _args = tuple(v for k, v in sorted(_args.items()))
          return f(_x_or_kernel_np, *_args, **kwargs)

        _f = jit(_f) if device_count == 0 else pmap(_f)
        cache[_key] = _f

      # Broadcast `np.ndarray` arguments and apply the new function to them.
      args_np = tree_map(broadcast, args_np)
      return _f(x_or_kernel_np, *args_np)

    return f_pmapped
Example #17
 def testIssue804(self):
   num_devices = xla_bridge.device_count()
   f = partial(lax.scan, lambda c, x: (c + lax.psum(x, "i") , c), 0.)
   api.pmap(f, axis_name="i")(np.ones((num_devices, 4)))  # doesn't crash
Example #18
    def run(self,
            rng_key,
            *args,
            extra_fields=(),
            collect_warmup=False,
            init_params=None,
            **kwargs):
        """
        Run the MCMC samplers and collect samples.

        :param random.PRNGKey rng_key: Random number generator key to be used for the sampling.
        :param args: Arguments to be provided to the :meth:`numpyro.infer.mcmc.MCMCKernel.init` method.
            These are typically the arguments needed by the `model`.
        :param extra_fields: Extra fields (aside from `z`, `diverging`) from :data:`numpyro.infer.mcmc.HMCState`
            to collect during the MCMC run.
        :type extra_fields: tuple or list
        :param bool collect_warmup: Whether to collect samples from the warmup phase. Defaults
            to `False`.
        :param init_params: Initial parameters to begin sampling. The type must be consistent
            with the input type to `potential_fn`.
        :param kwargs: Keyword arguments to be provided to the :meth:`numpyro.infer.mcmc.MCMCKernel.init`
            method. These are typically the keyword arguments needed by the `model`.
        """
        self._args = args
        self._kwargs = kwargs
        chain_method = self.chain_method
        if (chain_method == 'parallel'
                and xla_bridge.device_count() < self.num_chains):
            chain_method = 'sequential'
            warnings.warn(
                'There are not enough devices to run parallel chains: expected {} but got {}.'
                ' Chains will be drawn sequentially. If you are running MCMC on CPU,'
                ' consider using `numpyro.set_host_device_count({})` at the beginning'
                ' of your program.'.format(self.num_chains,
                                           xla_bridge.device_count(),
                                           self.num_chains))

        if init_params is not None and self.num_chains > 1:
            prototype_init_val = tree_flatten(init_params)[0][0]
            if np.shape(prototype_init_val)[0] != self.num_chains:
                raise ValueError(
                    '`init_params` must have the same leading dimension'
                    ' as `num_chains`.')
        assert isinstance(extra_fields, (tuple, list))
        collect_fields = tuple(set(('z', 'diverging') + tuple(extra_fields)))
        if self.num_chains == 1:
            states_flat = self._single_chain_mcmc(
                (rng_key, init_params), collect_fields, collect_warmup, args,
                kwargs)
            states = tree_map(lambda x: x[np.newaxis, ...], states_flat)
        else:
            rng_keys = random.split(rng_key, self.num_chains)
            partial_map_fn = partial(self._single_chain_mcmc,
                                     collect_fields=collect_fields,
                                     collect_warmup=collect_warmup,
                                     args=args,
                                     kwargs=kwargs)
            if chain_method == 'sequential':
                if self.progress_bar:
                    map_fn = partial(_laxmap, partial_map_fn)
                else:
                    map_fn = partial(lax.map, partial_map_fn)
            elif chain_method == 'parallel':
                map_fn = pmap(partial_map_fn)
            elif chain_method == 'vectorized':
                map_fn = partial_map_fn
            else:
                raise ValueError(
                    'Only supporting the following methods to draw chains:'
                    ' "sequential", "parallel", or "vectorized"')
            states = map_fn((rng_keys, init_params))
            if chain_method == 'vectorized':
                # swap num_samples x num_chains to num_chains x num_samples
                states = tree_map(lambda x: np.swapaxes(x, 0, 1), states)
            states_flat = tree_map(
                lambda x: np.reshape(x, (-1, ) + x.shape[2:]), states)
        self._states = states
        self._states_flat = states_flat
Example #19
 def testShardedDeviceArrayBlockUntilReady(self):
     x = onp.arange(xla_bridge.device_count())
     x = pmap(lambda x: x)(x)
     x.block_until_ready()  # doesn't crash
Example #20
def _parallel(ker_fun, device_count=-1):
    """Returns a function that computes a kernel in batches in parallel.

  When batching in parallel, the data is split over a set number of devices.
  The number of devices must be less than or equal to the number of physical
  devices. Moreover, the device count must evenly divide the dataset size.

  Given two datasets x1 and x2, parallel splits the kernel calculation over
  devices such that each device computes a batch of rows of shape
  [|x1| / device_count, |x2|].

  Args:
    ker_fun: A function that computes a kernel between two datasets,
        ker_fun(x1, x2). Here x1 and x2 are `np.ndarray`s of floats of shape
        [n1,] + input_shape and [n2,] + input_shape. The kernel function
        should return a PyTree.
    device_count: Integer specifying the number of devices over which to split
        the data. If device_count = -1, the computation is parallelized over all
        available devices.

  Returns:
    A new function with the same signature as ker_fun that computes the kernel
    by batching over the dataset in parallel over a specified number of cores.
  """
    ker_fun = _jit_or_pmap_broadcast(ker_fun, device_count)
    if device_count == -1:
        device_count = xla_bridge.device_count()

    def parallel_fn(x1, x2=None, *args, **kwargs):
        if x2 is None:
            # TODO(schsam): Only compute the upper triangular part of the kernel.
            x2 = x1

        n1 = x1.shape[0]

        assert x1.shape[1:] == x2.shape[1:]
        input_shape = x1.shape[1:]

        _device_count = device_count

        n1_per_device, ragged = divmod(n1, device_count)
        if n1_per_device and ragged:
            raise ValueError(
                ('Dataset size ({}) must be divisible by the number of '
                 'physical devices ({}).').format(n1, device_count))
        elif not n1_per_device:
            _device_count = ragged
            n1_per_device = 1

        x1 = np.reshape(x1, (
            _device_count,
            n1_per_device,
        ) + input_shape)
        kernel = ker_fun(x1, x2, *args, **kwargs)
        return _flatten_kernel(kernel)

    # Set function attributes so that `serial` can detect whether or not it is
    # acting on a parallel function.
    parallel_fn.is_parallel = True
    parallel_fn.device_count = device_count

    return parallel_fn
Example #21
 def testMismatchedAxisSizes(self):
   n = xla_bridge.device_count()
   f = pmap(lambda x, y: x + y)
   jtu.check_raises_regexp(
       lambda: f(onp.random.randn(n), onp.random.randn(n - 1)), ValueError,
       "Axis size .* does not match leading dimension of shape .*")