def testBaseCase(self):
    """Summing a dict of arrays over a single-device pmap returns it unchanged."""
    # Every leaf has a leading dimension of one: these tests run on a single
    # device, so pmap maps over exactly one shard.
    tree = dict(a=jnp.array([1]), b=jnp.array([2]))
    summed = jax.pmap(
        lambda leaf: utils.tree_psum(leaf, axis_name="i"),
        axis_name="i",
    )(tree)
    self.assertEqual(summed, tree)
def testAxisNameMismatch(self):
    """psum over an axis name that pmap never bound must raise NameError."""
    leaf = jnp.array([1])

    def psum_over_i(x):
        return utils.tree_psum(x, axis_name="i")

    # pmap binds the axis as "j", while the summed function asks for "i".
    mapped = jax.pmap(psum_over_i, axis_name="j")
    with self.assertRaises(NameError):
        mapped(leaf)
def testNoPmapWrapper(self):
    """Calling tree_psum with no enclosing pmap must raise NameError."""
    leaf = jnp.array([1])
    # Nothing binds axis_name "i" here, so it is undefined.
    with self.assertRaises(NameError):
        utils.tree_psum(leaf, axis_name="i")
def testNumDevicesMismatch(self):
    """pmap over more shards than there are local devices raises ValueError.

    The leading (mapped) dimension is sized to one more than the number of
    local devices. The mismatch therefore holds on any host. Previously the
    test hard-coded a size of 2, which only failed to map on single-device
    hosts.
    """
    num_shards = jax.local_device_count() + 1
    data = jnp.arange(num_shards)
    with self.assertRaises(ValueError):
        jax.pmap(lambda x: utils.tree_psum(x, axis_name="i"), axis_name="i")(data)
def testNotNumpy(self):
    """A plain Python list holding an int (not an array) is rejected by pmap."""
    # The list's only leaf is the scalar 1, which has no leading axis for pmap
    # to map over; the test expects a ValueError.
    psum_i = jax.pmap(lambda x: utils.tree_psum(x, axis_name="i"), axis_name="i")
    payload = [1]
    with self.assertRaises(ValueError):
        psum_i(payload)
def testSingleLeafTree(self):
    """A bare array (a one-leaf tree) survives a single-device psum unchanged."""
    leaf = jnp.array([1])
    psum_i = jax.pmap(lambda x: utils.tree_psum(x, axis_name="i"), axis_name="i")
    result = psum_i(leaf)
    self.assertEqual(result, leaf)
def testEmpty(self):
    """Mapping over leaves with a zero-length leading axis raises ZeroDivisionError."""
    # NOTE(review): this exception type presumably originates inside JAX's
    # pmap handling of a size-0 mapped axis rather than in utils.tree_psum;
    # confirm it is intended to pin that JAX implementation detail.
    data = {key: jnp.array([]) for key in ("a", "b")}
    psum_i = jax.pmap(lambda x: utils.tree_psum(x, axis_name="i"), axis_name="i")
    with self.assertRaises(ZeroDivisionError):
        psum_i(data)