async def async_deserialize(mesh, mesh_axes, tensorstore_spec, global_shape=None):
    """Asynchronously load an array from TensorStore and build a GDA from it.

    Args:
      mesh: Device mesh passed to ``gda.get_shard_shape`` and
        ``create_async_gda_from_callback``.
      mesh_axes: Partition spec passed alongside ``mesh``.
      tensorstore_spec: Value wrapped in ``ts.Spec`` and opened with
        ``open=True`` (the store must already exist).
      global_shape: Optional shape of the array to build. Defaults to the
        stored array's shape. When it differs from the stored shape, every
        shard is built in a zero-filled buffer and only the part of the shard
        that overlaps the stored array is copied in.

    Returns:
      The awaited result of ``create_async_gda_from_callback`` (presumably a
      GlobalDeviceArray -- confirm against that helper).
    """
    # Await the TensorStore future instead of calling ``.result()``: a
    # blocking wait inside a coroutine stalls the whole event loop.
    t = await ts.open(ts.Spec(tensorstore_spec), open=True)
    shape = t.shape if global_shape is None else global_shape

    # Compare per dimension, not by element count. With prod(), growth in one
    # dimension can be hidden by shrinkage in another, and any zero-size
    # dimension makes both products 0. Either case would send the fast path
    # below outside the stored domain.
    requires_padding = tuple(shape) != tuple(t.shape)
    if requires_padding:
        new_shard_shape = gda.get_shard_shape(shape, mesh, mesh_axes)

    async def cb(index):
        if not requires_padding:
            # Shapes match exactly, so every shard index is valid for `t`.
            return await t[index].read()
        # The requested shape differs from the stored shape. Cells of this
        # shard that fall outside the stored array stay 0.
        out = np.zeros(new_shard_shape, dtype=t.dtype.numpy_dtype)
        requested_domain = ts.IndexTransform(input_shape=shape)[index].domain
        restricted_domain = t.domain.intersect(requested_domain)
        # View `out` in global coordinates (origin = shard origin), then copy
        # only the overlap with the stored array into it.
        target = ts.array(out)[ts.d[:].translate_to[requested_domain.origin]]
        await target[restricted_domain].write(t[restricted_domain])
        return out

    return await create_async_gda_from_callback(shape, mesh, mesh_axes, cb)
async def async_deserialize(mesh, mesh_axes, tensorstore_spec, global_shape=None):
    """Asynchronously load an array from TensorStore and build a GDA from it.

    Every shard is materialized in a zero-filled buffer of the shard shape
    computed for the requested global shape. Only the part of the shard that
    overlaps the stored array is copied in, so a requested shape larger than
    the stored one is zero-padded.

    Args:
      mesh: Device mesh passed to ``gda.get_shard_shape`` and
        ``create_async_gda_from_callback``.
      mesh_axes: Partition spec passed alongside ``mesh``.
      tensorstore_spec: Value wrapped in ``ts.Spec`` and opened with
        ``open=True`` (the store must already exist).
      global_shape: Optional shape of the array to build. Defaults to the
        stored array's shape.

    Returns:
      The awaited result of ``create_async_gda_from_callback`` (presumably a
      GlobalDeviceArray -- confirm against that helper).
    """
    # Await the TensorStore future instead of calling ``.result()``: a
    # blocking wait inside a coroutine stalls the whole event loop.
    t = await ts.open(ts.Spec(tensorstore_spec), open=True)
    shape = t.shape if global_shape is None else global_shape
    new_shard_shape = gda.get_shard_shape(shape, mesh, mesh_axes)

    async def cb(index):
        out = np.zeros(new_shard_shape, dtype=t.dtype.numpy_dtype)
        requested_domain = ts.IndexTransform(input_shape=shape)[index].domain
        restricted_domain = t.domain.intersect(requested_domain)
        # View `out` in global coordinates (origin = shard origin), then copy
        # only the overlap with the stored array into it. Cells outside the
        # stored array keep their zero fill.
        target = ts.array(out)[ts.d[:].translate_to[requested_domain.origin]]
        await target[restricted_domain].write(t[restricted_domain])
        return out

    return await create_async_gda_from_callback(shape, mesh, mesh_axes, cb)
def test_gda_shape_0_1d_mesh(self):
    """A zero-length 1-D GDA replicated over an 8-device mesh."""
    # ('x',) is a one-element tuple of axis names. ('x') would be the plain
    # string 'x', which only works if the callee happens to iterate it.
    global_mesh = jtu.create_global_mesh((8,), ('x',))
    global_input_shape = (0,)
    # P(None): the single array dimension is not partitioned, so every
    # device holds a full (empty) replica.
    mesh_axes = P(None)

    def cb(index):
        return np.array([])

    gda = GlobalDeviceArray.from_callback(
        global_input_shape, global_mesh, mesh_axes, cb)
    self.assertEqual(gda.ndim, 1)
    self.assertEqual(gda.size, 0)
    for i, s in enumerate(gda.local_shards):
        # Each replicated shard covers the whole (empty) axis and gets a
        # distinct replica id.
        self.assertEqual(s.index, (slice(None),))
        self.assertEqual(s.replica_id, i)
        self.assertArraysEqual(s.data.to_py(), np.array([]))
    self.assertEqual(gda.dtype, np.float32)
    self.assertEqual(
        gda_lib.get_shard_shape(global_input_shape, global_mesh, mesh_axes),
        (0,))