예제 #1
0
 def test_v3_create_device_mesh(self, devices, mesh_shape,
                                expected_device_id_mesh):
     """Checks the device-id layout produced by `_create_device_mesh`.

     Args:
       devices: zero-argument callable returning the global device list; the
         `device_kind` of its first entry is forwarded as the device kind.
       mesh_shape: logical mesh shape requested from `_create_device_mesh`.
       expected_device_id_mesh: nested sequence of device ids that the built
         mesh is expected to contain, compared element by element.
     """
     # Mocked 2x2 device set, passed in as process 0's local devices.
     process_0_devices = mock_2x2_devices()
     all_devices = devices()
     kind = all_devices[0].device_kind
     mesh = mesh_utils._create_device_mesh(
         process_0_devices, all_devices, kind, mesh_shape,
         contiguous_submeshes=False)
     # Swap every device object in the mesh for its integer id.
     to_ids = np.vectorize(lambda dev: dev.id)
     actual_ids = to_ids(mesh)
     expected_ids = np.array(expected_device_id_mesh)
     self.assertAllClose(expected_ids, actual_ids)
예제 #2
0
    def test_create_contiguous_submeshes_errors(self):
        """Checks the ValueErrors raised when contiguous submeshes are impossible.

        Two scenarios are exercised, in order: a topology for which no
        contiguous submeshes can be created at all, and a topology whose
        requested mesh_shape is not one of the listed available shapes.
        """
        process_0_devices = mock_2x2x1_devices(True)
        v4 = mesh_utils._TPU_V4

        # Each entry: (topology, requested mesh_shape, exact expected message).
        cases = (
            ((4, 4, 8), (1, 16, 8),
             "create_device_mesh cannot create contiguous submeshes for "
             "physical mesh topology (4, 4, 8)"),
            ((4, 8, 8), (1, 128, 2),
             "create_device_mesh cannot create contiguous submeshes for mesh_shape "
             "(1, 128, 2) and physical mesh topology (4, 8, 8). "
             "Available mesh_shapes: [(1, 64, 4), (1, 4, 64), (64, 4), (4, 64)]"),
        )
        for topology, mesh_shape, message in cases:
            # topology is always a 3-tuple here, so unpacking yields x, y, z.
            devices = mock_devices(*topology, v4, two_cores_per_chip=True)
            with self.assertRaisesWithLiteralMatch(ValueError, message):
                mesh_utils._create_device_mesh(process_0_devices,
                                               devices,
                                               v4,
                                               mesh_shape,
                                               contiguous_submeshes=True)
예제 #3
0
 def test_create_contiguous_submeshes_for_tpu_v4(self):
     """Checks every mesh_shape listed in `_TRANSPOSE_TRICKS` builds cleanly.

     For each topology key of `mesh_utils._TRANSPOSE_TRICKS`, mock TPU v4
     devices are created for that topology. Every mesh_shape listed under the
     key is then passed to `_create_device_mesh` with
     `contiguous_submeshes=True`, and the result is verified with
     `self._assert_contiguous_submeshes`.
     """
     v4 = mesh_utils._TPU_V4
     process_0_devices = mock_2x2x1_devices(True)
     tricks = mesh_utils._TRANSPOSE_TRICKS
     for topo, shapes in tricks.items():
         logging.vlog(1, "topology: %s", topo)
         # Only the first three topology entries are used as the mock dims.
         x, y, z = topo[0], topo[1], topo[2]
         topo_devices = mock_devices(x, y, z, v4, two_cores_per_chip=True)
         for shape in shapes:
             logging.vlog(1, "  mesh_shape: %s", shape)
             built = mesh_utils._create_device_mesh(
                 process_0_devices, topo_devices, v4, shape,
                 contiguous_submeshes=True)
             self._assert_contiguous_submeshes(built)