def test_v3_create_device_mesh(self, devices, mesh_shape, expected_device_id_mesh):
    """Checks the non-contiguous device layout produced by _create_device_mesh.

    Args:
      devices: zero-argument callable returning the global device list; the
        first device's ``device_kind`` selects the layout.
      mesh_shape: logical mesh shape requested from _create_device_mesh.
      expected_device_id_mesh: nested sequence of device ids the resulting
        mesh is expected to contain, in the same shape.
    """
    # Devices that process 0 sees locally, mocked as a 2x2 arrangement.
    local_devices = mock_2x2_devices()
    all_devices = devices()
    kind = all_devices[0].device_kind

    built_mesh = mesh_utils._create_device_mesh(
        local_devices,
        all_devices,
        kind,
        mesh_shape,
        contiguous_submeshes=False,
    )

    # Map every device object in the mesh to its integer id, keeping the shape.
    ids_of = np.vectorize(lambda device: device.id)
    actual_ids = ids_of(built_mesh)
    self.assertAllClose(np.array(expected_device_id_mesh), actual_ids)
def test_create_contiguous_submeshes_errors(self):
    """Checks that contiguous-submesh requests fail with precise ValueErrors.

    Two cases are checked, in order:
      * topology (4, 4, 8): the error names only the physical topology.
      * topology (4, 8, 8) with mesh_shape (1, 128, 2): the error names both
        and lists the mesh_shapes that are available for that topology.
    """
    process_0_devices = mock_2x2x1_devices(True)
    v4 = mesh_utils._TPU_V4

    # (physical topology, requested mesh_shape, exact expected error text)
    error_cases = (
        (
            (4, 4, 8),
            (1, 16, 8),
            "create_device_mesh cannot create contiguous submeshes for "
            "physical mesh topology (4, 4, 8)",
        ),
        (
            (4, 8, 8),
            (1, 128, 2),
            "create_device_mesh cannot create contiguous submeshes for mesh_shape "
            "(1, 128, 2) and physical mesh topology (4, 8, 8). "
            "Available mesh_shapes: [(1, 64, 4), (1, 4, 64), (64, 4), (4, 64)]",
        ),
    )

    for topology, requested_shape, expected_message in error_cases:
        topology_devices = mock_devices(*topology, v4, two_cores_per_chip=True)
        with self.assertRaisesWithLiteralMatch(ValueError, expected_message):
            mesh_utils._create_device_mesh(
                process_0_devices,
                topology_devices,
                v4,
                requested_shape,
                contiguous_submeshes=True,
            )
def test_create_contiguous_submeshes_for_tpu_v4(self):
    """Builds every mesh_shape in _TRANSPOSE_TRICKS with contiguous submeshes.

    For each physical topology key in mesh_utils._TRANSPOSE_TRICKS, mock TPU v4
    devices (two cores per chip) are created. Each listed mesh_shape is then
    built with contiguous_submeshes=True, and the result is validated with
    self._assert_contiguous_submeshes.
    """
    v4 = mesh_utils._TPU_V4
    local_devices = mock_2x2x1_devices(True)

    for topology, candidate_shapes in mesh_utils._TRANSPOSE_TRICKS.items():
        logging.vlog(1, "topology: %s", topology)
        topology_devices = mock_devices(*topology, v4, two_cores_per_chip=True)

        for shape in candidate_shapes:
            logging.vlog(1, " mesh_shape: %s", shape)
            built_mesh = mesh_utils._create_device_mesh(
                local_devices,
                topology_devices,
                v4,
                shape,
                contiguous_submeshes=True,
            )
            self._assert_contiguous_submeshes(built_mesh)