def testAdd(self):
    """papply(np.add, 2) under serial_pmap agrees with plain numpy addition."""
    operand = onp.array([[1, 2, 3], [4, 5, 6]])
    pfun, axis_name = papply(np.add, 2)

    result = serial_pmap(pfun, axis_name)(operand, operand)

    # Integer inputs: dtypes are checked too, so the result must stay integral.
    self.assertAllClose(result, operand + operand, check_dtypes=True)
def testSum(self):
    """papply(np.sum) traces like a bare psum and matches onp.sum under pmap."""
    pfun, axis_name = papply(np.sum)

    # The parallelized reduction must trace to the same jaxpr as a lone psum.
    probe = onp.zeros(5)
    traced = make_jaxpr(pfun)(probe)
    reference = make_jaxpr(lambda x: psum(x, axis_name))(probe)
    assert repr(traced) == repr(reference)

    values = onp.arange(3.)
    result = pmap(pfun, axis_name)(values)
    self.assertAllClose(result, onp.sum(values), check_dtypes=False)
def testAddBroadcasting(self):
    """papply of adding a Python scalar matches numpy broadcasting under pmap."""
    def add_three(x):
        return x + 3

    operand = onp.array([[1, 2], [3, 4]])
    pfun, axis_name = papply(add_three)

    result = pmap(pfun, axis_name)(operand)
    self.assertAllClose(result, operand + 3, check_dtypes=True)
def testTransposeAndAddRank3(self):
    """papply of x + x.T on a rank-3 float32 array agrees with numpy."""
    def add_transpose(x):
        return x + x.T

    cube = onp.arange(8., dtype=onp.float32).reshape((2, 2, 2))
    pfun, axis_name = papply(add_transpose, 2)

    result = serial_pmap(pfun, axis_name)(cube)
    self.assertAllClose(result, cube + cube.T, check_dtypes=False)
def testSum(self):
    """A leading-axis sum under papply traces to a psum and matches onp.sum.

    The SPMD function produced by papply is traced on a single per-device
    shard. The reference psum program is traced on the same shard shape so
    the two jaxprs are compared like-for-like.
    """
    pfun, axis_name = papply(lambda x: np.sum(x, axis=0), 5)

    # Each of the 5 mapped positions of a (5, 3) input sees a (3,) shard.
    # Tracing the reference on the full (5, 3) array instead would make the
    # comparison diverge if the jaxpr repr includes argument shapes.
    shard = onp.ones(3)
    jaxpr = make_jaxpr(pfun)(shard)
    expected_jaxpr = make_jaxpr(lambda x: lax_parallel.psum(x, axis_name))(
        shard)
    assert repr(jaxpr) == repr(expected_jaxpr)

    arg = onp.arange(15.).reshape((5, 3))
    # The psum result is the same on every mapped position, so row 0 suffices.
    ans = serial_pmap(pfun, axis_name)(arg)[0]
    expected = onp.sum(arg, axis=0)
    self.assertAllClose(ans, expected, check_dtypes=False)
def testAddBroadcasting(self):
    """papply of scalar broadcasting-add should match numpy (currently skipped)."""
    # The SkipTest must be *raised*. Returning the exception instance is a
    # silent no-op, and unittest would record this test as passed, not skipped.
    raise SkipTest("test doesn't pass yet")  # TODO(frostig)

    def fun(x):
        return x + 3

    x = onp.array([[1, 2], [3, 4]])
    expected = x + 3
    pfun, axis_name = papply(fun, 2)
    ans = serial_pmap(pfun, axis_name)(x)
    self.assertAllClose(ans, expected, check_dtypes=True)
def testTransposeWithOddPermutation(self):
    """papply of a (2, 0, 1) transpose matches numpy for several cube sizes."""
    perm = (2, 0, 1)

    def fun(x):
        return np.transpose(x, perm)

    xs = [
        onp.reshape(onp.arange(8., dtype=onp.float32), (2, 2, 2)),
        onp.reshape(onp.arange(27., dtype=onp.float32), (3, 3, 3)),
    ]
    for x in xs:
        # Compute the reference with plain onp, as the other tests do, so the
        # oracle does not share an implementation with the function under test.
        expected = onp.transpose(x, perm)
        pfun, axis_name = papply(fun, x.shape[0])
        ans = serial_pmap(pfun, axis_name)(x)
        self.assertAllClose(ans, expected, check_dtypes=False)
def testTranspose(self):
    """papply of a 2-D transpose matches numpy for 2x2 and 3x3 inputs."""
    def transpose(x):
        return x.T

    for side in (2, 3):
        square = onp.arange(float(side * side), dtype=onp.float32).reshape(
            (side, side))
        pfun, axis_name = papply(transpose, square.shape[0])
        result = serial_pmap(pfun, axis_name)(square)
        self.assertAllClose(result, square.T, check_dtypes=False)
def testLogSoftmax(self):
    """papply of log-softmax replaces the inner sum with lax.psum."""
    def fun(x):
        log_normalizer = np.log(np.sum(np.exp(x)))
        return x - log_normalizer

    def expected_spmd(x):
        return x - np.log(lax.psum(np.exp(x), axis_name))

    pfun, axis_name = papply(fun)

    traced = make_jaxpr(pfun)(onp.zeros(5))
    reference = make_jaxpr(expected_spmd)(onp.zeros(5))
    assert repr(traced) == repr(reference)

    inputs = onp.arange(1., 5.)
    result = pmap(pfun, axis_name)(inputs)
    self.assertAllClose(result, fun(inputs), check_dtypes=False)
def testDot(self):
    """papply of lax.dot matches the unmapped dot for each supported in_axes."""
    def matmul(x, y):
        return lax.dot(x, y)

    operands = [
        onp.arange(4., dtype=onp.float32).reshape((2, 2)),
        onp.arange(9., dtype=onp.float32).reshape((3, 3)),
    ]
    # (1, 0) is presumably not supported yet and is left out -- confirm.
    for in_axes in ((0, 0), (0, 1)):
        for operand in operands:
            reference = matmul(operand, operand)
            pfun, axis_name = papply(matmul, operand.shape[0], in_axes=in_axes)
            result = serial_pmap(pfun, axis_name)(operand, operand)
            self.assertAllClose(result, reference, check_dtypes=False)
def testLogSoftmax(self):
    """papply of log-softmax should trace to a psum-based program (skipped)."""
    # Raise (not return) the SkipTest so unittest reports the test as skipped.
    # A returned exception instance is ignored and the test counts as passed.
    raise SkipTest("test doesn't pass yet")  # TODO(frostig)

    def fun(x):
        return x - np.log(np.sum(np.exp(x)))

    pfun, axis_name = papply(fun, 5)

    jaxpr = make_jaxpr(pfun)(onp.zeros(5))
    expected_jaxpr = make_jaxpr(
        lambda x: x - np.log(lax_parallel.psum(np.exp(x), axis_name)))(
            onp.zeros(5))
    assert repr(jaxpr) == repr(expected_jaxpr)

    # NOTE(review): papply above is given 5 (presumably the axis size), but
    # the mapped input below has length 4. Reconcile this when enabling the test.
    ans = serial_pmap(pfun, axis_name)(onp.arange(1., 5.))
    expected = fun(onp.arange(1., 5.))
    self.assertAllClose(ans, expected, check_dtypes=False)
def testSelect(self):
    """papply of lax.select with only the on_true operand mapped."""
    in_axes = (None, 0, None)
    pfun, axis_name = papply(lax.select, 5, in_axes=in_axes)

    pred = onp.arange(15).reshape((5, 3)) % 4 == 1
    on_true = onp.ones((5, 3))
    on_false = onp.zeros((5, 3))

    traced = make_jaxpr(pfun)(pred, on_true[0], on_false)

    def expected_spmd(p, t, f):
        # The unmapped operands (p, f) go through psplit_like against the
        # mapped t. The binding order p-then-f matches the original call.
        split_p = lax_parallel.psplit_like(p, t, axis_name)
        split_f = lax_parallel.psplit_like(f, t, axis_name)
        return lax.select(split_p, t, split_f)

    reference = make_jaxpr(expected_spmd)(pred, on_true[0], on_false)
    assert repr(traced) == repr(reference)

    mapped = serial_pmap(pfun, axis_name, in_axes=in_axes)
    result = mapped(pred, on_true, on_false)
    self.assertAllClose(result, lax.select(pred, on_true, on_false),
                        check_dtypes=True)
def testMap(self):
    """The papply'd np.sin, called directly, matches onp.sin elementwise."""
    pfun, _ = papply(np.sin, 3)
    inputs = onp.arange(3.)
    self.assertAllClose(pfun(inputs), onp.sin(inputs), check_dtypes=False)
def testIdentity(self):
    """The papply'd identity function, called directly, returns its input."""
    pfun, _ = papply(lambda x: x, 3)
    result = pfun(onp.arange(3))
    self.assertAllClose(result, onp.arange(3), check_dtypes=False)