def testSoftPmapAxisIndex(self):
  n = 4 * xla_bridge.device_count()
  def f(x):
    return x * lax.axis_index('i')
  ans = soft_pmap(f, 'i')(2 * np.ones(n))
  expected = 2 * onp.arange(n)
  self.assertAllClose(ans, expected, check_dtypes=False)

def testSoftPmapPsum(self):
  n = 4 * xla_bridge.device_count()
  def f(x):
    return x / lax.psum(x, 'i')
  ans = soft_pmap(f, 'i')(np.ones(n))
  expected = onp.ones(n) / n
  self.assertAllClose(ans, expected, check_dtypes=False)

def testSelect(self):
  p = onp.arange(15).reshape((5, 3)) % 4 == 1
  f = onp.zeros((5, 3))
  def fun(t):
    return lax.select(p, t, f)

  t = onp.ones((5, 3))
  ans = soft_pmap(*_papply(fun))(t)
  expected = fun(t)
  self.assertAllClose(ans, expected, check_dtypes=True)

def testMax(self):
  pfun, axis_name = _papply(lambda x: np.max(x, axis=0))

  jaxpr = make_jaxpr(pfun)(onp.ones(3))
  expected_jaxpr = make_jaxpr(
      lambda x: lax.pmax(x, axis_name))(onp.zeros((5, 3)))
  assert repr(jaxpr) == repr(expected_jaxpr)

  arg = onp.arange(15.).reshape((5, 3))
  ans = soft_pmap(pfun, axis_name)(arg)[0]
  expected = onp.max(arg, axis=0)
  self.assertAllClose(ans, expected, check_dtypes=False)

def testDot(self):
  raise SkipTest("known failure")  # TODO(frostig)
  x = onp.reshape(onp.arange(4., dtype=onp.float32), (2, 2))

  def fun(x, y):
    return lax.dot(x, y)

  expected = fun(x, x)
  pfun, axis_name = _papply(fun)
  ans = soft_pmap(pfun, axis_name)(x, x)
  ans = self.dedup(ans, expected.ndim)
  self.assertAllClose(ans, expected, check_dtypes=False)

def testAddBroadcasting(self):
  raise SkipTest("test doesn't pass yet")  # TODO(frostig)

  def fun(x):
    return x + 3

  x = onp.array([[1, 2], [3, 4]])
  expected = x + 3
  pfun, axis_name = _papply(fun)
  ans = soft_pmap(pfun, axis_name)(x)
  self.assertAllClose(ans, expected, check_dtypes=True)

def testSoftPmapDevicePersistence(self):
  device_count = xla_bridge.device_count()
  shape = (2 * 2 * device_count, 2, 3)

  # check that we can maintain device persistence across calls
  x = onp.arange(prod(shape)).reshape(shape)
  x = soft_pmap(lambda x: x)(x)
  self.assertIsInstance(x, pxla.ShardedDeviceArray)
  x._npy_value = onp.float32(onp.nan)  # can't be coerced to ndarray for xfer
  x = soft_pmap(lambda x: x)(x)  # doesn't crash
  self.assertIsInstance(x, pxla.ShardedDeviceArray)

  # check that we don't crash when we can't maintain device persistence
  x = onp.arange(prod(shape)).reshape(shape)
  x = soft_pmap(lambda x: x)(x)
  self.assertIsInstance(x, pxla.ShardedDeviceArray)
  y = x.reshape(device_count, -1)
  self.assertIsInstance(y, xla.DeviceArray)  # should have forced collection
  soft_pmap(lambda x: x)(y)  # doesn't crash
  z = x + 2
  self.assertIsInstance(z, xla.DeviceArray)  # should have forced collection
  x._npy_value = onp.float32(onp.nan)  # can't be coerced to ndarray for xfer
  self.assertRaisesRegex(
      RuntimeError,
      '.*does not match host shape or layout of computation parameter 0.*',
      lambda: x + 2)

  # check that different axis merges aren't a problem
  x = onp.arange(prod(shape)).reshape(shape)
  x = soft_pmap(lambda x: x)(x)
  self.assertIsInstance(x, pxla.ShardedDeviceArray)
  x = x.reshape(2 * device_count, 2, 2, 3)  # axis merge of the wrong size
  self.assertIsInstance(x, xla.DeviceArray)  # should have forced collection

def testDotGeneral(self, matching, coloring, split):
  BATCH, CONTRACT, _ = range(3)
  SPLIT_LHS, SPLIT_RHS, SPLIT_BOTH = range(3)

  x = onp.reshape(onp.arange(8.), (2, 2, 2))
  y = onp.reshape(onp.arange(8.), (2, 2, 2)) + 4.

  cdims = [(i, matching[i]) for i in range(3) if coloring[i] == CONTRACT]
  bdims = [(i, matching[i]) for i in range(3) if coloring[i] == BATCH]
  dimension_numbers = [
      list(zip(*cdims)) or [(), ()],
      list(zip(*bdims)) or [(), ()]
  ]

  def f(x, y):
    return lax.dot_general(x, y, dimension_numbers)

  if split == SPLIT_LHS:
    fun = lambda x: f(x, y)
  elif split == SPLIT_RHS:
    fun = lambda y: f(x, y)
  else:
    fun = f

  try:
    if split != SPLIT_BOTH:
      expected = fun(x)
      pfun, axis_name = _papply(fun)
      ans = soft_pmap(pfun, axis_name)(x)
    else:
      expected = fun(x, y)
      pfun, axis_name = _papply(fun)
      ans = soft_pmap(pfun, axis_name)(x, y)
  except (NotImplementedError, TypeError) as e:
    raise SkipTest(str(e)) from e

  ans = self.dedup(ans, expected.ndim)
  self.assertAllClose(ans, expected, check_dtypes=False)

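# Illustrative sketch (not part of the test suite above): `dimension_numbers`
# for lax.dot_general is a pair ((lhs_contracting, rhs_contracting),
# (lhs_batch, rhs_batch)). The hypothetical helper below only shows how an
# ordinary matrix multiply is expressed: contract the last axis of the lhs
# with the first axis of the rhs, with no batch dimensions.
def _dot_general_matmul_sketch():
  a = onp.arange(6.).reshape((2, 3))
  b = onp.arange(12.).reshape((3, 4))
  matmul_dnums = (((1,), (0,)), ((), ()))  # contract lhs axis 1 with rhs axis 0
  return lax.dot_general(a, b, matmul_dnums)  # same values as onp.dot(a, b)
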
def testLogSoftmax(self):
  raise SkipTest("test doesn't pass yet")  # TODO(frostig)

  def fun(x):
    return x - np.log(np.sum(np.exp(x)))

  pfun, axis_name = _papply(fun)

  jaxpr = make_jaxpr(pfun)(onp.zeros(5))
  expected_jaxpr = make_jaxpr(
      lambda x: x - np.log(lax.psum(np.exp(x), axis_name)))(onp.zeros(5))
  assert repr(jaxpr) == repr(expected_jaxpr)

  ans = soft_pmap(pfun, axis_name)(onp.arange(1., 5.))
  expected = fun(onp.arange(1., 5.))
  self.assertAllClose(ans, expected, check_dtypes=False)

def _soft_parallel(ker_fun, batch_size):
  """Returns a function that computes a kernel in batches in parallel.

  The current implementation uses jax.soft_pmap to simulate a given batch
  size. However, it is possible that larger batches will be chosen if

    (n1 * n2 // batch_size ** 2) > physical_device_count

  In this case soft_pmap will implicitly attempt to use a larger batch size.
  To use a fixed batch size independent of the physical device count, compose
  this function with serial.

  In a future CL, it might be a good idea to introduce an `auto` batching
  function that composes serial and parallel automatically.

  Args:
    ker_fun: A function that computes a kernel between two datasets,
      ker_fun(x1, x2). Here x1 and x2 are `np.ndarray`s of floats of shape
      [n1,] + input_shape and [n2,] + input_shape. The kernel function
      should return a PyTree.
    batch_size: Integer specifying the size of batches in which to split the
      data.

  Returns:
    A new function with the same signature as ker_fun that computes the kernel
    by batching over the dataset in parallel with the specified batch_size.
  """
  ker_fun = soft_pmap(soft_pmap(ker_fun))

  def parallel_fn(x1, x2=None, *args, **kwargs):
    if x2 is None:
      # TODO(schsam): Only compute the upper triangular part of the kernel.
      x2 = x1

    n1 = x1.shape[0]
    n2 = x2.shape[0]
    input_shape = x1.shape[1:]

    n1_batches, ragged = divmod(n1, batch_size)
    if ragged:
      # TODO(schsam): Relax this constraint.
      raise ValueError((
          'Number of examples in x1 must divide batch size. Found |x1| = {} '
          'and batch size = {}').format(n1, batch_size))

    n2_batches, ragged = divmod(n2, batch_size)
    if ragged:
      # TODO(schsam): Relax this constraint.
      raise ValueError((
          'Number of examples in x2 must divide batch size. Found |x2| = {} '
          'and batch size = {}').format(n2, batch_size))

    x1s = np.reshape(x1, (n1_batches, 1, batch_size,) + input_shape)
    x1s = np.broadcast_to(
        x1s, (n1_batches, n2_batches, batch_size,) + input_shape)
    x2s = np.reshape(x2, (1, n2_batches, batch_size,) + input_shape)
    x2s = np.broadcast_to(
        x2s, (n1_batches, n2_batches, batch_size,) + input_shape)

    kernel = ker_fun(x1s, x2s, *args, **kwargs)
    return _flatten_kernel(kernel)

  return parallel_fn
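
# Illustrative sketch (hypothetical, not part of this module): how a caller
# might wrap a toy kernel with `_soft_parallel`. The `ker_fun` and shapes
# below are assumptions for illustration only; the exact output format depends
# on how `_flatten_kernel` reassembles the per-block results.
def _soft_parallel_usage_sketch():
  def ker_fun(x1, x2):
    # Toy linear kernel: one [batch_size, batch_size] block per pair of batches.
    return np.dot(x1, np.transpose(x2))

  x = np.reshape(np.arange(12., dtype=np.float32), (6, 2))
  batched_ker_fun = _soft_parallel(ker_fun, batch_size=3)
  return batched_ker_fun(x)  # kernel of x with itself, computed block by block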