Пример #1
0
 def testBaseCase(self):
     """tree_psum inside a one-device pmap returns every leaf unchanged."""
     # Pick leaves with leading dimension one so pmap maps over exactly one
     # device, which lets these tests run on a single device.
     data = {"a": jnp.array([1]), "b": jnp.array([2])}
     data_summed = jax.pmap(lambda x: utils.tree_psum(x, axis_name="i"),
                            axis_name="i")(data)
     # Compare tree structure and each leaf explicitly: assertEqual on arrays
     # relies on the truth value of an elementwise `==`, which is ambiguous
     # for arrays with more than one element.
     self.assertEqual(jax.tree_util.tree_structure(data_summed),
                      jax.tree_util.tree_structure(data))
     for leaf_summed, leaf in zip(jax.tree_util.tree_leaves(data_summed),
                                  jax.tree_util.tree_leaves(data)):
         self.assertTrue(jnp.array_equal(leaf_summed, leaf),
                         msg=f"{leaf_summed} != {leaf}")
Пример #2
0
 def testAxisNameMismatch(self):
     """Summing over an axis name that the enclosing pmap does not bind fails."""
     leaf = jnp.array([1])

     def sum_over_i(tree):
         # The pmap below binds "j", so "i" is unbound in here.
         return utils.tree_psum(tree, axis_name="i")

     with self.assertRaises(NameError):
         jax.pmap(sum_over_i, axis_name="j")(leaf)
Пример #3
0
 def testNoPmapWrapper(self):
     """Calling tree_psum outside of any pmap raises NameError."""
     leaf = jnp.array([1])
     # Nothing binds the axis name "i" here, so it is undefined.
     with self.assertRaises(NameError):
         utils.tree_psum(leaf, axis_name="i")
Пример #4
0
 def testNumDevicesMismatch(self):
     """Mapping over more shards than there are local devices raises."""
     # Derive the shard count from the host instead of hard-coding 2, so the
     # test still exercises the mismatch on machines with several devices.
     num_shards = jax.local_device_count() + 1
     data = jnp.arange(num_shards)
     with self.assertRaises(ValueError):
         jax.pmap(lambda x: utils.tree_psum(x, axis_name="i"),
                  axis_name="i")(data)
Пример #5
0
 def testNotNumpy(self):
     """Passing a plain Python list instead of an array raises ValueError."""
     plain_list = [1]

     def sum_over_i(tree):
         return utils.tree_psum(tree, axis_name="i")

     # NOTE(review): the list is a pytree whose only leaf is the scalar 1; the
     # ValueError presumably comes from pmap finding no leading axis to map
     # over, rather than from tree_psum itself — confirm.
     with self.assertRaises(ValueError):
         jax.pmap(sum_over_i, axis_name="i")(plain_list)
Пример #6
0
 def testSingleLeafTree(self):
     """A bare array (a one-leaf tree) comes back unchanged on one device."""
     data = jnp.array([1])
     data_summed = jax.pmap(lambda x: utils.tree_psum(x, axis_name="i"),
                            axis_name="i")(data)
     # jnp.array_equal checks shape and values without relying on the truth
     # value of an elementwise comparison, which assertEqual would need.
     self.assertTrue(jnp.array_equal(data_summed, data),
                     msg=f"{data_summed} != {data}")
Пример #7
0
 def testEmpty(self):
     """Leaves with a zero-length leading axis make the pmap call fail."""
     empty_tree = {"a": jnp.array([]), "b": jnp.array([])}

     def sum_over_i(tree):
         return utils.tree_psum(tree, axis_name="i")

     # NOTE(review): the exception type presumably comes from JAX's handling
     # of a zero-sized mapped axis rather than from tree_psum, and may differ
     # across JAX versions — confirm against the pinned version.
     with self.assertRaises(ZeroDivisionError):
         jax.pmap(sum_over_i, axis_name="i")(empty_tree)