def mean_or_pmean(n_devices, x, axis=None):
  """jnp.mean or pmean.

  `x` is a distributed value. Directly calling jnp.mean on `x` means stacking
  x's components together to form a large array and then doing jnp.mean on
  it. In TF, stacking `x` will introduce D2H copy, so we use a collective
  (pmean) here instead of directly calling jnp.mean for TF.

  Args:
    n_devices: number of devices.
    x: a distributed array.
    axis: the axis to reduce. Can only be 0 or None.

  Returns:
    A local array.
  """
  # Only the TF-numpy backend on multiple devices needs the collective path;
  # everything else can reduce with a plain jnp.mean.
  if fastmath.backend_name() != 'tensorflow-numpy' or n_devices <= 1:
    return jnp.mean(x, axis=axis)
  if axis not in (None, 0):
    raise ValueError('axis can only be None or 0')
  # psum across devices, then divide: this is pmean without stacking x.
  x = fastmath.pmap(fastmath.psum)(x)[0] / n_devices
  # axis=0 was already reduced by the collective; axis=None still needs a
  # full reduction of the remaining dimensions.
  return jnp.mean(x) if axis is None else x
def _pjit(self, f, donate_argnums=()):
  """JIT f if 1 device is available and pmap if more are available."""
  # Multi-device (or zero-device) case: replicate with pmap over 'batch'.
  if self._n_devices != 1:
    return fastmath.pmap(f, axis_name='batch', donate_argnums=donate_argnums)
  # Single device: plain jit is sufficient.
  return fastmath.jit(f, donate_argnums=donate_argnums)
def _accelerate(f, n_devices):
  """Returns an accelerated version of ``f`` running on ``n_devices``."""
  if n_devices == 0:
    # No accelerators at all: compile for (and pin to) the host CPU.
    return fastmath.jit(f, device=jax.devices('cpu')[0])
  # One device jits; several devices pmap over the 'batch' axis.
  return (fastmath.jit(f) if n_devices == 1
          else fastmath.pmap(f, axis_name='batch'))
def _pjit(self, f, memory_key=None, donate_argnums=()):
  """JIT f if 1 device is available and pmap if more are available."""
  # Memoization is active only when both a cache and a key were supplied.
  memoizable = self._jit_memory is not None and memory_key is not None
  if memoizable and memory_key in self._jit_memory:
    logging.info('Found JITed function in memory for: %s', memory_key)
    return self._jit_memory[memory_key]
  # Cache miss (or no cache): compile now.
  if self._n_devices == 1:
    compiled = fastmath.jit(f, donate_argnums=donate_argnums)
  else:
    compiled = fastmath.pmap(f, axis_name='batch',
                             donate_argnums=donate_argnums)
  if memoizable:
    self._jit_memory[memory_key] = compiled
  return compiled
def mean_or_pmean(n_devices, x, axis=None):
  """Computes the mean of a distributed value ``x``.

  Args:
    n_devices: Number of devices.
    x: Distributed array.
    axis: Axis along which to compute means; can only be ``0`` or ``None``.

  Returns:
    A local array.
  """
  # The TF-numpy backend avoids stacking shards (a D2H copy) by reducing
  # with a collective; every other backend uses jnp.mean directly.
  use_collective = (fastmath.backend_name() == 'tensorflow-numpy'
                    and n_devices > 1)
  if not use_collective:
    return jnp.mean(x, axis=axis)
  if axis not in (None, 0):
    raise ValueError('axis can only be None or 0')
  per_device_mean = fastmath.pmap(fastmath.psum)(x)[0] / n_devices
  if axis is None:
    # Collapse the remaining dimensions as well.
    return jnp.mean(per_device_mean)
  return per_device_mean
def _pjit(self, f):
  """JIT f if 1 device is available and pmap if more are available."""
  return (fastmath.jit(f) if self._n_devices == 1
          else fastmath.pmap(f, axis_name='batch'))
def _accelerate(f, n_devices):
  """JIT-compiled version of `f` running on `n_devices`."""
  # Anything other than exactly one device is replicated with pmap.
  if n_devices != 1:
    return fastmath.pmap(f, axis_name='batch')
  return fastmath.jit(f)