def testNested(self):
  f = lambda x: lax_parallel.psum(lax_parallel.psum(x, 'i'), 'j')
  x = onp.ones((2, 2))
  ans1 = serial_pmap(serial_pmap(f, 'i'), 'j')(x)
  ans2 = serial_pmap(serial_pmap(f, 'j'), 'i')(x)
  expected = 4 * onp.ones((2, 2))
  self.assertAllClose(ans1, expected, check_dtypes=False)
  self.assertAllClose(ans2, expected, check_dtypes=False)
def testNestedBasic(self):
  f = lambda x: psum(psum(x, 'i'), 'j')
  f = pmap(pmap(f, 'i'), 'j')

  def sum_and_broadcast(x, axis):
    return onp.repeat(onp.sum(x, axis, keepdims=True), x.shape[axis], axis)

  shape = (xla_bridge.device_count(), 1, 4)
  x = onp.arange(prod(shape), dtype=onp.float32).reshape(shape)

  ans = f(x)
  expected = sum_and_broadcast(sum_and_broadcast(x, 0), 1)
  self.assertAllClose(ans, expected, check_dtypes=False)
def testBasic(self):
  f = pmap(lambda x: x - psum(x, 'i'), axis_name='i')

  shape = (xla_bridge.device_count(), 4)
  x = onp.arange(prod(shape), dtype=onp.float32).reshape(shape)
  expected = x - onp.sum(x, 0)

  ans = f(x)
  self.assertAllClose(ans, expected, check_dtypes=False)
def mapped_update(i, opt_state, batch, rng):
  """This is a multi-device version of the update function above."""
  # We assume all tensors have the first dimension = num_devices.
  _, opt_update = optimizer(lr_fun)
  params = trax_opt.get_params(opt_state)
  grads = backend.grad(loss_fun)(params, batch, predict_fun, rng)
  grads = jax.tree_util.tree_map(
      lambda g: lax_parallel.psum(g, "batch"), grads)
  return opt_update(i, grads, opt_state)
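The function above only defines the per-device step; it still has to be mapped over devices so that the psum over the "batch" axis actually runs as a cross-device reduction. Below is a minimal, self-contained sketch (not from the original source: the toy loss, the fixed learning rate, and the plain SGD step in place of opt_update are assumptions) of how such an update is typically driven with jax.pmap, using the same "batch" axis name as the psum inside it.

import functools

import jax
import jax.numpy as jnp
from jax import lax

def toy_loss(params, batch):
  # Hypothetical stand-in for loss_fun: a scalar "model" that scales x.
  x, y = batch
  pred = x * params
  return jnp.mean((pred - y) ** 2)

@functools.partial(jax.pmap, axis_name="batch")
def simple_mapped_update(params, batch):
  # Each device differentiates its own shard, then psum combines the
  # gradients across devices under the "batch" axis name.
  grads = jax.grad(toy_loss)(params, batch)
  grads = jax.tree_util.tree_map(lambda g: lax.psum(g, "batch"), grads)
  return params - 0.01 * grads  # plain SGD step instead of opt_update

Usage sketch: the leading dimension of every argument must equal the device count, e.g. params = jnp.ones((jax.local_device_count(),)) and batch arrays shaped (num_devices, per_device_batch, ...).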
def testSum(self):
  pfun, axis_name = papply(lambda x: np.sum(x, axis=0), 5)

  jaxpr = make_jaxpr(pfun)(onp.ones(3))
  expected_jaxpr = make_jaxpr(
      lambda x: lax_parallel.psum(x, axis_name))(onp.zeros((5, 3)))
  assert repr(jaxpr) == repr(expected_jaxpr)

  arg = onp.arange(15.).reshape((5, 3))
  ans = serial_pmap(pfun, axis_name)(arg)[0]
  expected = onp.sum(arg, axis=0)
  self.assertAllClose(ans, expected, check_dtypes=False)
def testLogSoftmax(self):
  raise SkipTest("test doesn't pass yet")  # TODO(frostig)

  def fun(x):
    return x - np.log(np.sum(np.exp(x)))

  pfun, axis_name = _papply(fun, 5)

  jaxpr = make_jaxpr(pfun)(onp.zeros(5))
  expected_jaxpr = make_jaxpr(
      lambda x: x - np.log(lax_parallel.psum(np.exp(x), axis_name)))(
          onp.zeros(5))
  assert repr(jaxpr) == repr(expected_jaxpr)

  ans = _serial_pmap(pfun, axis_name)(onp.arange(1., 5.))
  expected = fun(onp.arange(1., 5.))
  self.assertAllClose(ans, expected, check_dtypes=False)
def testPsumMultiple(self):
  f = lambda x: psum(x, ('i', 'j'))
  f = pmap(pmap(f, 'i'), 'j')

  def sum_and_broadcast(x, axis):
    return onp.repeat(onp.sum(x, axis, keepdims=True), x.shape[axis], axis)

  device_count = xla_bridge.device_count()
  num_pairs, ragged = divmod(device_count, 2)
  if num_pairs > 1 and not ragged:
    shape = (num_pairs, 2, 4)
  else:
    shape = (device_count, 1, 4)
  x = onp.arange(prod(shape), dtype=onp.float32).reshape(shape)

  ans = f(x)
  expected = sum_and_broadcast(sum_and_broadcast(x, 0), 1)
  self.assertAllClose(ans, expected, check_dtypes=False)
def testLogSoftmax(self):
  f = lambda x: x - np.log(lax_parallel.psum(np.exp(x), 'i'))
  x = onp.log(onp.arange(1., 10., dtype=onp.float32))
  ans = serial_pmap(f, axis_name='i')(x)
  expected = x - onp.log(onp.sum(onp.exp(x)))
  self.assertAllClose(ans, expected, check_dtypes=False)
def testReduceSum(self):
  f = lambda x: lax_parallel.psum(x, 'i')
  ans = serial_pmap(f, axis_name='i')(onp.ones(4))
  expected = 4 * onp.ones(4)
  self.assertAllClose(ans, expected, check_dtypes=False)
def f(x, y):
  return psum(5. * np.cos(x) * np.sin(y), 'i')