Example #1
from trax import fastmath
from trax.fastmath import numpy as jnp


def mean_or_pmean(n_devices, x, axis=None):
  """jnp.mean or pmean.

  `x` is a distributed value. Directly calling jnp.mean on `x` would stack
  x's per-device components into one large array and then run jnp.mean on
  it. Under the TF backend, stacking `x` introduces a device-to-host (D2H)
  copy, so we use a collective (pmean) instead of calling jnp.mean directly
  for TF.

  Args:
    n_devices: number of devices.
    x: a distributed array.
    axis: the axis to reduce. Can only be 0 or None.

  Returns:
    A local array.
  """
  if fastmath.backend_name() == 'tensorflow-numpy' and n_devices > 1:
    if axis not in (None, 0):
      raise ValueError('axis can only be None or 0')
    x = fastmath.pmap(fastmath.psum)(x)[0] / n_devices
    if axis is None:
      x = jnp.mean(x)
    return x
  else:
    return jnp.mean(x, axis=axis)
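A minimal usage sketch (single-device path, so the call falls through to plain jnp.mean; assumes the Trax imports above):

x = jnp.arange(6.0).reshape(2, 3)
print(mean_or_pmean(1, x))          # mean over all elements -> 2.5
print(mean_or_pmean(1, x, axis=0))  # per-column means -> [1.5 2.5 3.5]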
Example #2
def _pjit(self, f, donate_argnums=()):
    """JIT f if 1 device is available and pmap if more are available."""
    if self._n_devices == 1:
        return fastmath.jit(f, donate_argnums=donate_argnums)
    else:
        return fastmath.pmap(f,
                             axis_name='batch',
                             donate_argnums=donate_argnums)
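For context, the pmap branch expects inputs with a leading axis whose length equals the device count, and axis_name='batch' lets the mapped function run cross-device collectives. A minimal sketch in plain JAX (standalone, not part of the original class):

import jax
import jax.numpy as jnp

n = jax.local_device_count()
f = jax.pmap(lambda x: jax.lax.psum(x, 'batch'), axis_name='batch')
print(f(jnp.ones((n, 3))))  # every entry equals n: each device sees the global sum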
Example #3
import jax

from trax import fastmath


def _accelerate(f, n_devices):
  """Returns an accelerated version of ``f`` running on ``n_devices``."""
  if n_devices == 0:  # no accelerators - run on CPU
    return fastmath.jit(f, device=jax.devices('cpu')[0])

  if n_devices == 1:
    return fastmath.jit(f)

  return fastmath.pmap(f, axis_name='batch')
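A minimal usage sketch (loss_fn is an illustrative stand-in, not from the source; assumes the imports above plus jax.numpy):

import jax.numpy as jnp

def loss_fn(x):
  return jnp.sum(x * x)

accelerated = _accelerate(loss_fn, n_devices=1)  # one device -> plain fastmath.jit
print(accelerated(jnp.arange(4.0)))              # 0 + 1 + 4 + 9 -> 14.0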
Example #4
def _pjit(self, f, memory_key=None, donate_argnums=()):
  """JIT f if 1 device is available and pmap if more are available."""
  should_memoize = self._jit_memory is not None and memory_key is not None
  if should_memoize and memory_key in self._jit_memory:
    logging.info('Found JITed function in memory for: %s', memory_key)
    return self._jit_memory[memory_key]
  if self._n_devices == 1:
    res = fastmath.jit(f, donate_argnums=donate_argnums)
  else:
    res = fastmath.pmap(f, axis_name='batch', donate_argnums=donate_argnums)
  if should_memoize:
    self._jit_memory[memory_key] = res
  return res
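The memoization avoids re-tracing and re-compiling when the same function is requested repeatedly. A standalone sketch of the same idea (names are illustrative, assuming the Trax fastmath import):

_memory = {}

def jit_memoized(f, key):
  # Cache compiled callables by key so repeated requests skip re-tracing.
  if key not in _memory:
    _memory[key] = fastmath.jit(f)
  return _memory[key]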
Example #5
def mean_or_pmean(n_devices, x, axis=None):
  """Computes the mean of a distributed value ``x``.

  Args:
    n_devices: Number of devices.
    x: Distributed array.
    axis: Axis along which to compute means; can only be ``0`` or ``None``.

  Returns:
    A local array.
  """
  if fastmath.backend_name() == 'tensorflow-numpy' and n_devices > 1:
    if axis not in (None, 0):
      raise ValueError('axis can only be None or 0')
    x = fastmath.pmap(fastmath.psum)(x)[0] / n_devices
    if axis is None:
      x = jnp.mean(x)
    return x
  else:
    return jnp.mean(x, axis=axis)
Example #6
def _pjit(self, f):
  """JIT f if 1 device is available and pmap if more are available."""
  if self._n_devices == 1:
    return fastmath.jit(f)
  else:
    return fastmath.pmap(f, axis_name='batch')
Example #7
def _accelerate(f, n_devices):
    """JIT-compiled version of `f` running on `n_devices`."""
    if n_devices == 1:
        return fastmath.jit(f)

    return fastmath.pmap(f, axis_name='batch')