def testAdd(self):
    # Element-wise addition: np.add transformed by _papply and executed via
    # soft_pmap on (x, x) must give the same values and dtype as x + x.
    operand = onp.array([[1, 2, 3], [4, 5, 6]])
    reference = operand + operand

    parallel_add, axis = _papply(np.add)
    result = soft_pmap(parallel_add, axis)(operand, operand)

    self.assertAllClose(result, reference, check_dtypes=True)
def testSelect(self):
    # lax.select with a closed-over predicate and false-branch: running the
    # papply-transformed function under soft_pmap must match a direct call.
    predicate = onp.arange(15).reshape((5, 3)) % 4 == 1
    on_false = onp.zeros((5, 3))

    def fun(t):
        return lax.select(predicate, t, on_false)

    on_true = onp.ones((5, 3))

    parallel_fun, axis = _papply(fun)
    result = soft_pmap(parallel_fun, axis)(on_true)
    reference = fun(on_true)

    self.assertAllClose(result, reference, check_dtypes=True)
def testTransposeAndAddRank3(self):
    # x + x.T on a rank-3 (2, 2, 2) float32 array: the _papply-transformed
    # function run through _serial_pmap must match plain NumPy (values only).
    def add_own_transpose(arr):
        return arr + arr.T

    data = onp.arange(8., dtype=onp.float32).reshape((2, 2, 2))
    reference = data + data.T

    parallel_fun, axis = _papply(add_own_transpose, 2)
    result = _serial_pmap(parallel_fun, axis)(data)

    self.assertAllClose(result, reference, check_dtypes=False)
def testMax(self):
    # A max-reduction over axis 0, once transformed by _papply, is expected to
    # trace to the same jaxpr as a lax.pmax over the generated axis name, and
    # to produce the NumPy result when run under soft_pmap.
    pfun, axis_name = _papply(lambda x: np.max(x, axis=0))

    jaxpr = make_jaxpr(pfun)(onp.ones(3))
    expected_jaxpr = make_jaxpr(
        lambda x: lax.pmax(x, axis_name))(onp.zeros((5, 3)))
    # assertEqual instead of a bare `assert`: it still runs under `python -O`
    # and shows both reprs when the comparison fails.
    self.assertEqual(repr(jaxpr), repr(expected_jaxpr))

    arg = onp.arange(15.).reshape((5, 3))
    # Only the first entry along the mapped axis is compared with the
    # reduced value.
    ans = soft_pmap(pfun, axis_name)(arg)[0]
    expected = onp.max(arg, axis=0)
    self.assertAllClose(ans, expected, check_dtypes=False)
def testDot(self): raise SkipTest("known failure") # TODO(frostig) x = onp.reshape(onp.arange(4., dtype=onp.float32), (2, 2)) def fun(x, y): return lax.dot(x, y) expected = fun(x, x) pfun, axis_name = _papply(fun) ans = soft_pmap(pfun, axis_name)(x, x) ans = self.dedup(ans, expected.ndim) self.assertAllClose(ans, expected, check_dtypes=False)
def testAddBroadcasting(self): raise SkipTest("test doesn't pass yet") # TODO(frostig) def fun(x): return x + 3 x = onp.array([[1, 2], [3, 4]]) expected = x + 3 pfun, axis_name = _papply(fun) ans = soft_pmap(pfun, axis_name)(x) self.assertAllClose(ans, expected, check_dtypes=True)
def testSum(self):
    # A sum-reduction over axis 0, once transformed by _papply, is expected to
    # trace to the same jaxpr as lax_parallel.psum over the generated axis
    # name, and to produce the NumPy result when run under _serial_pmap.
    # NOTE(review): the second _papply argument (5) presumably is the size of
    # the mapped axis; it equals the leading dimension of `arg` -- confirm.
    pfun, axis_name = _papply(lambda x: np.sum(x, axis=0), 5)

    jaxpr = make_jaxpr(pfun)(onp.ones(3))
    expected_jaxpr = make_jaxpr(
        lambda x: lax_parallel.psum(x, axis_name))(onp.zeros((5, 3)))
    # assertEqual instead of a bare `assert`: it still runs under `python -O`
    # and shows both reprs when the comparison fails.
    self.assertEqual(repr(jaxpr), repr(expected_jaxpr))

    arg = onp.arange(15.).reshape((5, 3))
    # Only the first entry along the mapped axis is compared with the
    # reduced value.
    ans = _serial_pmap(pfun, axis_name)(arg)[0]
    expected = onp.sum(arg, axis=0)
    self.assertAllClose(ans, expected, check_dtypes=False)
def testDotGeneral(self, matching, coloring, split):
    # Runs lax.dot_general through _papply + soft_pmap and compares against a
    # direct call.
    #   coloring[i] -- marks dimension i as batch (0) or contracting (1);
    #                  any other value leaves the dimension out of both lists.
    #   matching[i] -- the dimension that dimension i is paired with.
    #   split       -- which operand(s) are passed as arguments to the
    #                  transformed function (0: lhs, 1: rhs, 2: both).
    # NotImplementedError / TypeError from the pipeline turn into a skip.
    BATCH, CONTRACT = 0, 1
    SPLIT_LHS, SPLIT_RHS, SPLIT_BOTH = 0, 1, 2

    x = onp.arange(8.).reshape((2, 2, 2))
    y = onp.arange(8.).reshape((2, 2, 2)) + 4.

    def pairs_colored(color):
        return [(d, matching[d]) for d in range(3) if coloring[d] == color]

    contracting = pairs_colored(CONTRACT)
    batch = pairs_colored(BATCH)
    # zip(*pairs) yields nothing for an empty list, hence the explicit
    # fallback to a pair of empty tuples.
    dimension_numbers = [
        list(zip(*contracting)) or [(), ()],
        list(zip(*batch)) or [(), ()],
    ]

    def f(x, y):
        return lax.dot_general(x, y, dimension_numbers)

    if split == SPLIT_LHS:
        fun = lambda lhs: f(lhs, y)
    elif split == SPLIT_RHS:
        fun = lambda rhs: f(x, rhs)
    else:
        fun = f

    # In both single-operand cases the one argument supplied is `x`.
    call_args = (x, y) if split == SPLIT_BOTH else (x,)

    try:
        expected = fun(*call_args)
        pfun, axis_name = _papply(fun)
        ans = soft_pmap(pfun, axis_name)(*call_args)
    except (NotImplementedError, TypeError) as e:
        raise SkipTest(str(e)) from e

    ans = self.dedup(ans, expected.ndim)
    self.assertAllClose(ans, expected, check_dtypes=False)
def testTransposeWithOddPermutation(self):
    # Transposition with the permutation (2, 0, 1) on cubic float32 inputs of
    # side 2 and 3: _papply + _serial_pmap must agree with a direct transpose.
    permutation = (2, 0, 1)

    def fun(arr):
        return np.transpose(arr, (2, 0, 1))

    inputs = (
        onp.arange(8., dtype=onp.float32).reshape((2, 2, 2)),
        onp.arange(27., dtype=onp.float32).reshape((3, 3, 3)),
    )

    for data in inputs:
        reference = np.transpose(data, permutation)
        # The leading dimension of the input is handed to _papply.
        parallel_fun, axis = _papply(fun, data.shape[0])
        result = _serial_pmap(parallel_fun, axis)(data)
        self.assertAllClose(result, reference, check_dtypes=False)
def testTranspose(self):
    # Plain matrix transpose on square float32 inputs of side 2 and 3:
    # _papply + _serial_pmap must agree with NumPy's `.T`.
    def transpose(arr):
        return arr.T

    inputs = (
        onp.arange(4., dtype=onp.float32).reshape((2, 2)),
        onp.arange(9., dtype=onp.float32).reshape((3, 3)),
    )

    for data in inputs:
        reference = data.T
        # The leading dimension of the input is handed to _papply.
        parallel_fun, axis = _papply(transpose, data.shape[0])
        result = _serial_pmap(parallel_fun, axis)(data)
        self.assertAllClose(result, reference, check_dtypes=False)
def testLogSoftmax(self): raise SkipTest("test doesn't pass yet") # TODO(frostig) def fun(x): return x - np.log(np.sum(np.exp(x))) pfun, axis_name = _papply(fun) jaxpr = make_jaxpr(pfun)(onp.zeros(5)) expected_jaxpr = make_jaxpr( lambda x: x - np.log(lax.psum(np.exp(x), axis_name)))(onp.zeros(5)) assert repr(jaxpr) == repr(expected_jaxpr) ans = soft_pmap(pfun, axis_name)(onp.arange(1., 5.)) expected = fun(onp.arange(1., 5.)) self.assertAllClose(ans, expected, check_dtypes=False)
def testDot(self): return SkipTest("test doesn't pass yet") # TODO(frostig) def fun(x, y): return lax.dot(x, y) xs = [ onp.reshape(onp.arange(4., dtype=onp.float32), (2, 2)), onp.reshape(onp.arange(9., dtype=onp.float32), (3, 3)), ] in_axes_combos = [(0, 0), (0, 1)] # [(1, 0)] for in_axes in in_axes_combos: for x in xs: expected = fun(x, x) pfun, axis_name = _papply(fun, x.shape[0], in_axes=in_axes) ans = _serial_pmap(pfun, axis_name)(x, x) self.assertAllClose(ans, expected, check_dtypes=False)
def testSelect(self):
    # lax.select transformed by _papply with in_axes=(None, 0, None) is
    # expected to trace to the same jaxpr as a select whose predicate and
    # false-branch are passed through lax_parallel.psplit_like against the
    # true-branch, and to reproduce lax.select when run under _serial_pmap.
    pfun, axis_name = _papply(lax.select, 5, in_axes=(None, 0, None))

    p = onp.arange(15).reshape((5, 3)) % 4 == 1
    t = onp.ones((5, 3))
    f = onp.zeros((5, 3))

    # Tracing uses one row of `t` but the full `p` and `f`.
    jaxpr = make_jaxpr(pfun)(p, t[0], f)

    def expected_spmd(p, t, f):
        return lax.select(
            lax_parallel.psplit_like(p, t, axis_name),
            t,
            lax_parallel.psplit_like(f, t, axis_name))

    expected_jaxpr = make_jaxpr(expected_spmd)(p, t[0], f)
    # assertEqual instead of a bare `assert`: it still runs under `python -O`
    # and shows both reprs when the comparison fails.
    self.assertEqual(repr(jaxpr), repr(expected_jaxpr))

    ans = _serial_pmap(pfun, axis_name, in_axes=(None, 0, None))(p, t, f)
    expected = lax.select(p, t, f)
    self.assertAllClose(ans, expected, check_dtypes=True)
def testMap(self):
    # An element-wise function (np.sin) transformed by _papply is called
    # directly, without a pmap wrapper, and must match NumPy's sin.
    parallel_sin, _ = _papply(np.sin)
    result = parallel_sin(onp.arange(3.))
    reference = onp.sin(onp.arange(3.))
    self.assertAllClose(result, reference, check_dtypes=False)
def testIdentity(self):
    # The identity function transformed by _papply, called directly, must
    # hand its input back unchanged (values only; dtypes are not compared).
    parallel_identity, _ = _papply(lambda x: x)
    result = parallel_identity(onp.arange(3))
    reference = onp.arange(3)
    self.assertAllClose(result, reference, check_dtypes=False)
def testMakeJaxprPapplyComposition(self):
    # Currently skipped unconditionally.  The intended check is only that
    # make_jaxpr can trace a _papply-transformed np.where without raising.
    # TODO(mattjj)
    raise SkipTest(
        "fails because select's papply rule calls an SPMD primitive")

    ones = onp.ones(3)
    x = ones
    b = ones  # same array object as `x`
    parallel_where, _ = _papply(lambda a: np.where(x, a, b))
    make_jaxpr(parallel_where)(onp.ones(3))  # doesn't crash