Example #1
0
 def test_create_device_mesh_for_tpu_v4(self, devices, mesh_shape,
                                        expected_assignment):
   """Checks the assignment returned by `_create_device_mesh_for_tpu_v4`.

   Args:
     devices: callable invoked as `devices(True)` to obtain the device list.
     mesh_shape: mesh shape forwarded to `_create_device_mesh_for_tpu_v4`.
     expected_assignment: value the returned assignment must equal.
   """
   process_0_devices = mock_2x2x1_devices(True)
   all_devices = devices(True)
   normalized_mesh = mesh_utils._jax_devices_order_normalized(
       process_0_devices, all_devices)
   mesh_and_assignment = mesh_utils._create_device_mesh_for_tpu_v4(
       normalized_mesh, mesh_shape)
   # Only the second element of the returned pair is checked here.
   _, actual_assignment = mesh_and_assignment
   self.assertEqual(actual_assignment, expected_assignment)
Example #2
0
 def test_jax_devices_order_normalized(self, devices, expected_shape):
   """Checks the shape and per-index coords of the normalized device array.

   Args:
     devices: callable invoked as `devices(True)` to obtain the device list.
     expected_shape: three-element shape `(x, y, z)` that the normalized
       array's `.shape` must equal.
   """
   process_0_devices = mock_2x2x1_devices(True)
   all_devices = devices(True)
   device_grid = mesh_utils._jax_devices_order_normalized(
       process_0_devices, all_devices)
   self.assertEqual(device_grid.shape, expected_shape)
   dim_x, dim_y, dim_z = expected_shape
   # Major-to-minor order is x, y, z: the entry at index (i, j, k) must
   # report exactly (i, j, k) as its coords.
   grid_indices = (
       (i, j, k)
       for i in range(dim_x)
       for j in range(dim_y)
       for k in range(dim_z))
   for index in grid_indices:
     self.assertEqual(device_grid[index].coords, index)