コード例 #1
0
ファイル: serialization.py プロジェクト: John1Tang/jax
async def async_deserialize(mesh,
                            mesh_axes,
                            tensorstore_spec,
                            global_shape=None):
    """Asynchronously load a stored TensorStore array as a sharded array.

    Args:
      mesh: Device mesh passed through to ``gda.get_shard_shape`` and
        ``create_async_gda_from_callback``.
      mesh_axes: Partition spec passed through alongside ``mesh``.
      tensorstore_spec: Spec handed to ``ts.Spec`` to open the stored array.
      global_shape: Optional shape of the array to build. Defaults to the
        stored array's shape. When it differs from the stored shape, each
        shard is built from the overlap between the requested shard and the
        stored array; positions outside the stored array are filled with 0s.

    Returns:
      The awaited result of ``create_async_gda_from_callback`` for the
      requested shape.
    """
    # Await the open rather than calling ``.result()`` so the event loop is
    # not blocked while TensorStore opens the array.
    t = await ts.open(ts.Spec(tensorstore_spec), open=True)
    # Normalize to a tuple so a list-valued ``global_shape`` compares (and
    # hashes) like ``t.shape`` does.
    shape = tuple(t.shape if global_shape is None else global_shape)
    # Any per-axis mismatch (padding, cropping, or both) must go through the
    # domain-intersection path below. Comparing only element counts would send
    # shapes that grow along one axis but shrink along another to the direct
    # read, which indexes the stored array with out-of-range shard indices.
    shape_mismatch = shape != tuple(t.shape)

    if shape_mismatch:
        new_shard_shape = gda.get_shard_shape(shape, mesh, mesh_axes)

    async def cb(index):
        if not shape_mismatch:
            return await t[index].read()
        # The requested shape differs from the saved one, so start from a
        # zero-filled shard and copy in only what the stored array covers.
        out = np.zeros(new_shard_shape, dtype=t.dtype.numpy_dtype)
        requested_domain = ts.IndexTransform(input_shape=shape)[index].domain
        restricted_domain = t.domain.intersect(requested_domain)
        # View ``out`` in global coordinates (origin = the shard's origin) so
        # the stored region can be written at matching positions.
        target = ts.array(out)[ts.d[:].translate_to[requested_domain.origin]]
        await target[restricted_domain].write(t[restricted_domain])
        return out

    return await create_async_gda_from_callback(shape, mesh, mesh_axes, cb)
コード例 #2
0
async def async_deserialize(mesh,
                            mesh_axes,
                            tensorstore_spec,
                            global_shape=None):
    """Asynchronously load a stored TensorStore array as a sharded array.

    Every shard is assembled into a zero-filled buffer from the overlap
    between the requested shard and the stored array. A ``global_shape``
    larger than the stored shape is therefore padded with 0s; a smaller one
    is cropped.

    Args:
      mesh: Device mesh passed through to ``gda.get_shard_shape`` and
        ``create_async_gda_from_callback``.
      mesh_axes: Partition spec passed through alongside ``mesh``.
      tensorstore_spec: Spec handed to ``ts.Spec`` to open the stored array.
      global_shape: Optional shape of the array to build; defaults to the
        stored array's shape.

    Returns:
      The awaited result of ``create_async_gda_from_callback`` for the
      requested shape.
    """
    # Await the open rather than calling ``.result()`` so the event loop is
    # not blocked while TensorStore opens the array.
    t = await ts.open(ts.Spec(tensorstore_spec), open=True)
    # Normalize to a tuple so a list-valued ``global_shape`` behaves like
    # ``t.shape`` downstream.
    shape = tuple(t.shape if global_shape is None else global_shape)
    new_shard_shape = gda.get_shard_shape(shape, mesh, mesh_axes)

    async def cb(index):
        out = np.zeros(new_shard_shape, dtype=t.dtype.numpy_dtype)
        requested_domain = ts.IndexTransform(input_shape=shape)[index].domain
        restricted_domain = t.domain.intersect(requested_domain)
        # View ``out`` in global coordinates (origin = the shard's origin) so
        # the stored region can be written at matching positions.
        target = ts.array(out)[ts.d[:].translate_to[requested_domain.origin]]
        await target[restricted_domain].write(t[restricted_domain])
        return out

    return await create_async_gda_from_callback(shape, mesh, mesh_axes, cb)
コード例 #3
0
 def test_gda_shape_0_1d_mesh(self):
   """Check a zero-size 1-D GDA on a mesh of shape (8,) with P(None)."""
   # The trailing comma makes ('x',) a one-element tuple of axis names;
   # ('x') would be just the parenthesized string 'x'.
   global_mesh = jtu.create_global_mesh((8,), ('x',))
   global_input_shape = (0,)
   # No array dimension is mapped to a mesh axis.
   mesh_axes = P(None)

   def cb(index):
     # Every shard of a (0,)-shaped array is empty, so the index is unused.
     return np.array([])

   gda = GlobalDeviceArray.from_callback(global_input_shape, global_mesh,
                                         mesh_axes, cb)
   self.assertEqual(gda.ndim, 1)
   self.assertEqual(gda.size, 0)
   for i, s in enumerate(gda.local_shards):
     # Each local shard covers the whole (empty) axis and carries its own
     # replica id.
     self.assertEqual(s.index, (slice(None),))
     self.assertEqual(s.replica_id, i)
     self.assertArraysEqual(s.data.to_py(), np.array([]))
   # np.array([]) is float64; the GDA is expected to report float32
   # (presumably because x64 is disabled in this test config -- confirm).
   self.assertEqual(gda.dtype, np.float32)
   self.assertEqual(
       gda_lib.get_shard_shape(global_input_shape, global_mesh, mesh_axes),
       (0,))