async def async_deserialize(mesh, mesh_axes, tensorstore_spec, global_shape=None):
    """Load the array described by `tensorstore_spec`, shard by shard.

    Args:
        mesh: Forwarded to `gda.get_shard_shape` and
            `create_async_gda_from_callback`.
        mesh_axes: Forwarded together with `mesh`.
        tensorstore_spec: JSON-like spec handed to `ts.Spec`.
        global_shape: Shape to load the array as. When `None`, the shape of
            the opened store is used.

    Returns:
        The result of `create_async_gda_from_callback` for the chosen shape,
        mesh, mesh axes and the per-index read callback.
    """
    # Await the open rather than blocking on `.result()`, so the event loop
    # is not stalled while the store is being opened.
    t = await ts.open(ts.Spec(tensorstore_spec), open=True)
    shape = t.shape if global_shape is None else global_shape
    # NOTE(review): padding is detected by element count only. A requested
    # shape that differs from the stored one without having more elements
    # takes the plain read path below -- confirm callers never do that.
    requires_padding = prod(shape) > prod(t.shape)

    if requires_padding:
        new_shard_shape = gda.get_shard_shape(shape, mesh, mesh_axes)

    async def cb(index):
        if not requires_padding:
            return await t[index].read()

        # This is needed because the shape the array was saved with is
        # smaller than the requested shape of the array in which it will be
        # reloaded. So the extra values will be filled with 0s.
        out = np.zeros(new_shard_shape, dtype=t.dtype.numpy_dtype)
        requested_domain = ts.IndexTransform(input_shape=shape)[index].domain
        restricted_domain = t.domain.intersect(requested_domain)
        # View `out` in the coordinates of the requested domain, then copy
        # only the part that actually exists in the stored array.
        destination = ts.array(out)[
            ts.d[:].translate_to[requested_domain.origin]
        ][restricted_domain]
        await destination.write(t[restricted_domain])
        return out

    return await create_async_gda_from_callback(shape, mesh, mesh_axes, cb)
async def async_deserialize(ckpt_path, mesh, mesh_axes, tensorstore_spec):
    """Load the array described by `tensorstore_spec`, shard by shard.

    Args:
        ckpt_path: Not used by this function; kept so the signature stays
            unchanged for existing callers.
        mesh: Forwarded to `create_async_gsda_from_callback`.
        mesh_axes: Forwarded together with `mesh`.
        tensorstore_spec: JSON-like spec handed to `ts.Spec`.

    Returns:
        The result of `create_async_gsda_from_callback` for the stored shape,
        mesh, mesh axes and the per-index read callback.
    """
    # Await the open rather than blocking on `.result()`, so the event loop
    # is not stalled while the store is being opened.
    t = await ts.open(ts.Spec(tensorstore_spec), open=True)

    async def cb(index):
        # Read only the slice of the store addressed by `index`.
        return await t[index].read()

    return await create_async_gsda_from_callback(t.shape, mesh, mesh_axes, cb)
async def async_serialize(gda_inp: gda.GlobalDeviceArray, tensorstore_spec,
                          commit_future=None):
    """Write the local shards of `gda_inp` to the store in `tensorstore_spec`.

    Args:
        gda_inp: Array whose `local_shards` are written.
        tensorstore_spec: JSON-like spec handed to `ts.Spec`. A 'metadata'
            entry is added to it in place when `_spec_has_metadata` reports
            none.
        commit_future: Optional list. When given, each write's `commit`
            future is appended to it and only the `copy` future is awaited
            here; when `None`, the `commit` future is awaited directly.

    Returns:
        The list produced by `asyncio.gather` over the per-shard writes.
    """
    # 'metadata' may not be present at the top level (for example, if we are
    # using a 'cast' driver).
    metadata_present = _spec_has_metadata(tensorstore_spec)
    if not metadata_present:
        tensorstore_spec['metadata'] = _get_metadata(gda_inp)

    store = await ts.open(
        ts.Spec(tensorstore_spec),
        create=True,
        open=True,
        context=TS_CONTEXT,
    )

    async def _write_shard(shard):
        # Only the shard with replica id 0 is written.
        if shard.replica_id != 0:
            return
        pending = store[shard.index].write(shard.data)
        if commit_future is None:
            await pending.commit
            return
        assert isinstance(commit_future, list)
        # Hand the commit to the caller; wait here only for the copy.
        commit_future.append(pending.commit)
        await pending.copy

    shard_writes = jax.tree_util.tree_map(_write_shard, gda_inp.local_shards)
    return await asyncio.gather(*shard_writes)
def test_spec_indexing_unknown_rank():
    """Indexing this spec with a new axis is expected to raise ValueError."""
    spec_json = {
        "driver": "array",
        "array": [[1, 2], [3, 4]],
        "dtype": "int32",
    }
    spec = ts.Spec(spec_json)

    expect_unspecified = pytest.raises(
        ValueError, match="IndexTransform is unspecified")
    with expect_unspecified:
        spec[..., ts.newaxis]
def test_spec_indexing():
    """Indexing a spec applies the same indexing to its transform."""
    unbounded_rank2 = {
        "input_inclusive_min": [["-inf"], ["-inf"]],
        "input_exclusive_max": [["+inf"], ["+inf"]],
    }
    original = ts.Spec({
        "driver": "array",
        "array": [[1, 2], [3, 4]],
        "dtype": "int32",
        "transform": unbounded_rank2,
    })

    # Build the expectation by indexing a bare rank-2 transform the same way.
    identity = ts.IndexTransform(input_rank=2)
    expected_transform = identity[..., ts.newaxis].to_json()
    expected = ts.Spec({
        "driver": "array",
        "array": [[1, 2], [3, 4]],
        "dtype": "int32",
        "transform": expected_transform,
    })

    indexed = original[..., ts.newaxis]
    assert indexed == expected
async def _write_array(shard):
    """Write `shard.data` at `shard.index` if the shard has replica id 0.

    Opens (creating if needed) the store described by the enclosing scope's
    `tensorstore_spec` on every call, then writes the shard into it.
    """
    if shard.replica_id != 0:
        return

    io_limits = {'file_io_concurrency': {'limit': 128}}
    store = await ts.open(
        ts.Spec(tensorstore_spec),
        create=True,
        open=True,
        context=ts.Context(io_limits),
    )
    await store[shard.index].write(shard.data)
def test_spec_pickle():
    """A `ts.Spec` compares equal to itself after a pickle round trip."""
    # The unused `driver_json` local from the original was dropped; it was
    # never read by the test.
    s = ts.Spec({
        "driver": "array",
        "array": [[1, 2], [3, 4]],
        "dtype": "int32",
        "transform": {
            "input_inclusive_min": [["-inf"], ["-inf"]],
            "input_exclusive_max": [["+inf"], ["+inf"]],
        },
    })
    assert pickle.loads(pickle.dumps(s)) == s
async def async_serialize(gda_inp: gda.GlobalDeviceArray, tensorstore_spec,
                          commit_future=None):
    """Write the local shards of `gda_inp` to the store in `tensorstore_spec`.

    Args:
        gda_inp: Array whose `local_shards` are written.
        tensorstore_spec: JSON-like spec handed to `ts.Spec`. A 'metadata'
            entry is added to it in place when `_spec_has_metadata` reports
            none.
        commit_future: Optional list. When given, the creating open future
            (on process 0) and each write's `commit` future are appended to
            it instead of being awaited here.

    Returns:
        The list produced by `asyncio.gather` over the per-shard writes.
    """
    # 'metadata' may not be present at the top level (for example, if we are
    # using a 'cast' driver).
    metadata_present = _spec_has_metadata(tensorstore_spec)
    if not metadata_present:
        tensorstore_spec['metadata'] = _get_metadata(gda_inp)

    on_process_zero = jax.process_index() == 0
    if on_process_zero:
        create_future = ts.open(
            ts.Spec(tensorstore_spec),
            create=True,
            open=True,
            context=TS_CONTEXT,
        )
        if commit_future is None:
            await create_future
        else:
            # Asynchronous case: the caller awaits the creation later.
            assert isinstance(commit_future, list)
            commit_future.append(create_future)

    # `ts.open` runs twice for process 0: the first call above only yields
    # the future that may be awaited in the background. Per the original
    # author's note, this second call with `assume_metadata=True` does no
    # I/O and returns the tensorstore object. Every process other than 0
    # only performs this open.
    store = await ts.open(
        ts.Spec(tensorstore_spec),
        open=True,
        assume_metadata=True,
        context=TS_CONTEXT,
    )

    async def _write_shard(shard):
        # Only the shard with replica id 0 is written.
        if shard.replica_id != 0:
            return
        pending = store[shard.index].write(shard.data)
        if commit_future is None:
            await pending.commit
            return
        assert isinstance(commit_future, list)
        # Hand the commit to the caller; wait here only for the copy.
        commit_future.append(pending.commit)
        await pending.copy

    shard_writes = jax.tree_util.tree_map(_write_shard, gda_inp.local_shards)
    return await asyncio.gather(*shard_writes)
def test_spec_init_json():
    """A spec built from JSON exposes transform, rank, dtype, JSON and `.T`."""
    spec = ts.Spec({
        "driver": "array",
        "array": [[1, 2], [3, 4]],
        "transform": {"input_rank": 2},
        "dtype": "int32",
    })

    # Basic properties derived from the rank-2 transform and the dtype.
    assert spec.transform == ts.IndexTransform(input_rank=2)
    assert spec.rank == 2
    assert spec.ndim == 2
    assert spec.dtype == ts.int32

    # JSON output without defaults spells the transform out as bounds.
    expected_json = {
        "driver": "array",
        "array": [[1, 2], [3, 4]],
        "dtype": "int32",
        "transform": {
            "input_inclusive_min": [["-inf"], ["-inf"]],
            "input_exclusive_max": [["+inf"], ["+inf"]],
        },
    }
    assert spec.to_json(include_defaults=False) == expected_json

    # `.T` should equal a spec whose outputs map to the swapped input dims.
    swapped_outputs = [{"input_dimension": dim} for dim in (1, 0)]
    expected_transposed = ts.Spec({
        "driver": "array",
        "array": [[1, 2], [3, 4]],
        "transform": {
            "input_rank": 2,
            "output": swapped_outputs,
        },
        "dtype": "int32",
    })
    assert spec.T == expected_transposed
def test_spec_indexing_unknown_rank():
    """A spec with no rank reports `None` and rejects indexing and `.T`."""
    spec_json = {
        "driver": "zarr",
        "kvstore": {"driver": "memory"},
        "dtype": "int32",
    }
    spec = ts.Spec(spec_json)

    assert spec.rank is None
    assert spec.ndim is None

    error_pattern = "IndexTransform is unspecified"
    with pytest.raises(ValueError, match=error_pattern):
        spec[..., ts.newaxis]
    with pytest.raises(ValueError, match=error_pattern):
        spec.T
async def async_deserialize(mesh, mesh_axes, tensorstore_spec, global_shape=None):
    """Load the array described by `tensorstore_spec`, shard by shard.

    Each shard is read into a zero-initialized buffer of the shard shape, and
    only the overlap between the stored domain and the requested domain is
    copied in. Positions outside the stored array therefore stay 0.

    Args:
        mesh: Forwarded to `gda.get_shard_shape` and
            `create_async_gda_from_callback`.
        mesh_axes: Forwarded together with `mesh`.
        tensorstore_spec: JSON-like spec handed to `ts.Spec`.
        global_shape: Shape to load the array as. When `None`, the shape of
            the opened store is used.

    Returns:
        The result of `create_async_gda_from_callback` for the chosen shape,
        mesh, mesh axes and the per-index read callback.
    """
    # Await the open rather than blocking on `.result()`, so the event loop
    # is not stalled while the store is being opened.
    t = await ts.open(ts.Spec(tensorstore_spec), open=True)
    shape = t.shape if global_shape is None else global_shape
    new_shard_shape = gda.get_shard_shape(shape, mesh, mesh_axes)

    async def cb(index):
        out = np.zeros(new_shard_shape, dtype=t.dtype.numpy_dtype)
        requested_domain = ts.IndexTransform(input_shape=shape)[index].domain
        restricted_domain = t.domain.intersect(requested_domain)
        # View `out` in the coordinates of the requested domain, then copy
        # only the part that actually exists in the stored array.
        destination = ts.array(out)[
            ts.d[:].translate_to[requested_domain.origin]
        ][restricted_domain]
        await destination.write(t[restricted_domain])
        return out

    return await create_async_gda_from_callback(shape, mesh, mesh_axes, cb)
def test_spec_indexing_unknown_rank():
    """A spec with no rank reports `None` and rejects indexing and `.T`."""
    spec_json = {
        "driver": "zarr",
        "kvstore": {"driver": "memory"},
        "dtype": "int32",
    }
    spec = ts.Spec(spec_json)

    assert spec.rank is None
    assert spec.ndim is None

    error_pattern = (
        "Cannot perform indexing operations on Spec with unspecified rank")
    with pytest.raises(ValueError, match=error_pattern):
        spec[..., ts.newaxis]
    with pytest.raises(ValueError, match=error_pattern):
        spec.T
async def async_serialize(gda_inp: gda.GlobalDeviceArray, tensorstore_spec):
    """Write the local shards of `gda_inp` to the store in `tensorstore_spec`.

    Args:
        gda_inp: Array whose `local_shards` are written.
        tensorstore_spec: JSON-like spec handed to `ts.Spec`. A 'metadata'
            entry is added to it in place when `_spec_has_metadata` reports
            none.

    Returns:
        The list produced by `asyncio.gather` over the per-shard writes.
    """
    # 'metadata' may not be present at the top level (for example, if we are
    # using a 'cast' driver), so a plain top-level `.get('metadata')` is not
    # enough. Use the same helper as the other serialize paths in this file.
    # NOTE(review): assumes `_spec_has_metadata` is defined at module level,
    # as its use elsewhere in this file suggests -- confirm.
    if not _spec_has_metadata(tensorstore_spec):
        tensorstore_spec['metadata'] = _get_metadata(gda_inp)

    t = await ts.open(
        ts.Spec(tensorstore_spec),
        create=True,
        open=True,
        context=ts.Context({'file_io_concurrency': {'limit': 128}}),
    )

    async def _write_array(shard):
        # Only the shard with replica id 0 is written.
        if shard.replica_id == 0:
            await t[shard.index].write(shard.data)

    future_write_state = jax.tree_util.tree_map(
        _write_array, tuple(gda_inp.local_shards))
    return await asyncio.gather(*future_write_state)
async def async_serialize(gda_inp: gda.GlobalDeviceArray, tensorstore_spec):
    """Write the local shards of `gda_inp` to the store in `tensorstore_spec`.

    Args:
        gda_inp: Array whose `local_shards` are written.
        tensorstore_spec: JSON-like spec handed to `ts.Spec`. A 'metadata'
            entry is added to it in place when `_spec_has_metadata` reports
            none.

    Returns:
        The list produced by `asyncio.gather` over the per-shard writes.
    """
    # 'metadata' may not be present at the top level (for example, if we are
    # using a 'cast' driver).
    metadata_present = _spec_has_metadata(tensorstore_spec)
    if not metadata_present:
        tensorstore_spec['metadata'] = _get_metadata(gda_inp)

    io_limits = {'file_io_concurrency': {'limit': 128}}
    store = await ts.open(
        ts.Spec(tensorstore_spec),
        create=True,
        open=True,
        context=ts.Context(io_limits),
    )

    async def _write_shard(shard):
        # Only the shard with replica id 0 is written.
        if shard.replica_id != 0:
            return
        await store[shard.index].write(shard.data)

    local_shards = tuple(gda_inp.local_shards)
    shard_writes = jax.tree_util.tree_map(_write_shard, local_shards)
    return await asyncio.gather(*shard_writes)