Пример #1
0
 def test_repr(self):
     """Smoke test: ``repr`` of an array built by ``create_array`` must not raise.

     Runs with the ``jax_array`` config flag set to True. The array is
     requested with shape (8, 2) and a ``MeshPspecSharding`` over a (4, 2)
     mesh with axes ``('x', 'y')``.
     """
     with jax._src.config.jax_array(True):
         mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
         spec = sharding.MeshPspecSharding(mesh, P('x', 'y'))
         arr, _ = create_array((8, 2), spec)
         # The resulting string is discarded on purpose: the only thing
         # checked here is that building it does not raise.
         repr(arr)
Пример #2
0
 def test_array_device_get(self):
     """Check that ``jax.device_get`` on the array equals the reference data.

     The reference data is the second value returned by ``create_array``
     for shape (8, 2) with a ``MeshPspecSharding`` over a (4, 2) mesh.
     Runs with the ``jax_array`` config flag set to True.
     """
     with jax._src.config.jax_array(True):
         mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
         spec = sharding.MeshPspecSharding(mesh, P('x', 'y'))
         arr, expected = create_array((8, 2), spec)
         fetched = jax.device_get(arr)
         self.assertArraysEqual(fetched, expected)
Пример #3
0
 def test_array_sharded_astype(self):
     """Check ``astype(jnp.float32)`` on the array yields float32 with equal values.

     Verifies both the resulting dtype and that the converted array matches
     the reference data (second value returned by ``create_array``) cast to
     ``np.float32``. Runs with the ``jax_array`` config flag set to True.
     """
     with jax._src.config.jax_array(True):
         mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
         spec = sharding.MeshPspecSharding(mesh, P('x', 'y'))
         arr, host_data = create_array((8, 2), spec)

         converted = arr.astype(jnp.float32)
         self.assertEqual(converted.dtype, np.float32)

         expected = host_data.astype(np.float32)
         self.assertArraysEqual(converted, expected)
Пример #4
0
 def test_array_delete(self):
     """Check the state of the array after ``delete()`` has been called.

     After deletion, ``_check_if_deleted()`` must raise ``ValueError`` with
     a message matching 'Array has been deleted.', and both ``_npy_value``
     and ``_arrays`` must be None. Runs with the ``jax_array`` config flag
     set to True.
     """
     with jax._src.config.jax_array(True):
         mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
         spec = sharding.MeshPspecSharding(mesh, P('x', 'y'))
         doomed, _ = create_array((8, 2), spec)

         doomed.delete()

         with self.assertRaisesRegex(ValueError, 'Array has been deleted.'):
             doomed._check_if_deleted()
         # The two private fields are checked one after the other, in this
         # order.
         self.assertIsNone(doomed._npy_value)
         self.assertIsNone(doomed._arrays)
Пример #5
0
 def test_jax_array_value(self, mesh_axes):
     """Check per-shard buffers and the full value against the reference data.

     Args:
       mesh_axes: Passed as the second argument to ``MeshPspecSharding``.
         Presumably supplied by a parameterizing decorator that is not
         visible here -- confirm against the surrounding file.

     For every addressable shard, its data must hold exactly one entry in
     ``_arrays`` and that entry must equal the reference data indexed by the
     shard's ``index``. The array's ``_value`` and ``_npy_value`` must both
     equal the full reference data. Runs with the ``jax_array`` config flag
     set to True.
     """
     with jax._src.config.jax_array(True):
         mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
         spec = sharding.MeshPspecSharding(mesh, mesh_axes)
         arr, expected = create_array((8, 2), spec)

         for shard in arr.addressable_shards:
             buffers = shard.data._arrays
             self.assertLen(buffers, 1)
             self.assertArraysEqual(buffers[0], expected[shard.index])

         self.assertArraysEqual(arr._value, expected)
         self.assertArraysEqual(arr._npy_value, expected)
Пример #6
0
    def test_mesh_pspec_sharding_interface(self):
        """Check the values a ``MeshPspecSharding`` exposes for P('y', 'x').

        Builds the sharding over a (4, 2) mesh with axes ``('x', 'y')`` and,
        for a global shape of (8, 4), verifies:
          * the index assigned to the first mesh device,
          * that the device assignment equals the flattened mesh devices,
          * the type, tile dimensions and tile device order of the XLA
            op sharding.
        """
        shape = (8, 4)
        mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
        spec = sharding.MeshPspecSharding(mesh, P('y', 'x'))

        # Collect every value under test first, then run the assertions.
        index_by_device = spec.devices_indices_map(shape)
        xla_sharding = spec._to_xla_op_sharding(len(shape))
        assigned_devices = spec._device_assignment()

        first_device = mesh.devices.flat[0]
        expected_index = (slice(0, 4), slice(0, 1))
        self.assertEqual(index_by_device[first_device], expected_index)

        self.assertArraysEqual(assigned_devices, list(mesh.devices.flat))

        self.assertEqual(xla_sharding.type, xc.OpSharding.Type.OTHER)
        self.assertListEqual(xla_sharding.tile_assignment_dimensions, [2, 4])
        self.assertListEqual(
            xla_sharding.tile_assignment_devices,
            [0, 2, 4, 6, 1, 3, 5, 7],
        )