Example #1
async def async_deserialize(mesh,
                            mesh_axes,
                            tensorstore_spec,
                            global_shape=None):
    t = await ts.open(ts.Spec(tensorstore_spec), open=True)
    shape = t.shape if global_shape is None else global_shape
    requires_padding = prod(shape) > prod(t.shape)

    if requires_padding:
        new_shard_shape = gda.get_shard_shape(shape, mesh, mesh_axes)

    async def cb(index):
        if requires_padding:
            # The array was saved with a smaller shape than the shape requested
            # for reload, so the shard is first zero-filled and only the
            # overlapping region is read from the store.
            out = np.zeros(new_shard_shape, dtype=t.dtype.numpy_dtype)
            requested_domain = ts.IndexTransform(
                input_shape=shape)[index].domain
            restricted_domain = t.domain.intersect(requested_domain)
            await ts.array(out)[ts.d[:].translate_to[requested_domain.origin]
                                ][restricted_domain].write(t[restricted_domain]
                                                           )
            return out
        else:
            return await t[index].read()

    return await create_async_gda_from_callback(shape, mesh, mesh_axes, cb)
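
A minimal sketch of the zero-padding restore path above, assuming an in-memory array-driver store and an illustrative target shape (neither is part of the original example):

import numpy as np
import tensorstore as ts

# The array was saved with shape (2, 2) but is restored into shape (4, 4);
# entries outside the saved domain keep their zero fill.
t = ts.open({"driver": "array", "array": [[1, 2], [3, 4]], "dtype": "int32"}).result()
global_shape = (4, 4)
out = np.zeros(global_shape, dtype=t.dtype.numpy_dtype)
requested_domain = ts.IndexTransform(input_shape=global_shape).domain
restricted_domain = t.domain.intersect(requested_domain)
ts.array(out)[restricted_domain].write(t[restricted_domain]).result()
# out[:2, :2] now holds the saved values; the rest stays 0.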
Example #2
async def async_deserialize(ckpt_path, mesh, mesh_axes, tensorstore_spec):
    t = ts.open(ts.Spec(tensorstore_spec), open=True).result()

    async def cb(index):
        return await t[index].read()

    return await create_async_gsda_from_callback(t.shape, mesh, mesh_axes, cb)
Example #3
async def async_serialize(gda_inp: gda.GlobalDeviceArray,
                          tensorstore_spec,
                          commit_future=None):
    # 'metadata' may not be present at the top level (for example, if we are using
    # a 'cast' driver).
    if not _spec_has_metadata(tensorstore_spec):
        tensorstore_spec['metadata'] = _get_metadata(gda_inp)

    t = await ts.open(ts.Spec(tensorstore_spec),
                      create=True,
                      open=True,
                      context=TS_CONTEXT)

    async def _write_array(shard):
        if shard.replica_id == 0:
            write_future = t[shard.index].write(shard.data)
            if commit_future is not None:
                assert isinstance(commit_future, list)
                commit_future.append(write_future.commit)
                await write_future.copy
            else:
                await write_future.commit

    future_write_state = jax.tree_util.tree_map(_write_array,
                                                gda_inp.local_shards)
    return await asyncio.gather(*future_write_state)
Example #4
def test_spec_indexing_unknown_rank():
  s = ts.Spec({
      "driver": "array",
      "array": [[1, 2], [3, 4]],
      "dtype": "int32",
  })
  with pytest.raises(ValueError, match="IndexTransform is unspecified"):
    s[..., ts.newaxis]
Example #5
def test_spec_indexing():
  transform = ts.IndexTransform(input_rank=2)
  s = ts.Spec({
      "driver": "array",
      "array": [[1, 2], [3, 4]],
      "dtype": "int32",
      "transform": {
          "input_inclusive_min": [["-inf"], ["-inf"]],
          "input_exclusive_max": [["+inf"], ["+inf"]],
      },
  })
  s_transformed = s[..., ts.newaxis]
  s_expected = ts.Spec({
      "driver": "array",
      "array": [[1, 2], [3, 4]],
      "dtype": "int32",
      "transform": transform[..., ts.newaxis].to_json(),
  })
  assert s_transformed == s_expected
Example #6
async def _write_array(shard):
    if shard.replica_id == 0:
        t = await ts.open(ts.Spec(tensorstore_spec),
                          create=True,
                          open=True,
                          context=ts.Context(
                              {'file_io_concurrency': {
                                  'limit': 128
                              }}))
        await t[shard.index].write(shard.data)
Example #7
def test_spec_pickle():
  s = ts.Spec({
      "driver": "array",
      "array": [[1, 2], [3, 4]],
      "dtype": "int32",
      "transform": {
          "input_inclusive_min": [["-inf"], ["-inf"]],
          "input_exclusive_max": [["+inf"], ["+inf"]],
      },
  })
  assert pickle.loads(pickle.dumps(s)) == s
Example #8
async def async_serialize(gda_inp: gda.GlobalDeviceArray, tensorstore_spec,
                          commit_future=None):
  # 'metadata' may not be present at the top level (for example, if we are using
  # a 'cast' driver).
  if not _spec_has_metadata(tensorstore_spec):
    tensorstore_spec['metadata'] = _get_metadata(gda_inp)

  if jax.process_index() == 0:
    open_future = ts.open(
        ts.Spec(tensorstore_spec), create=True, open=True, context=TS_CONTEXT)
    # Asynchronous case.
    if commit_future is not None:
      assert isinstance(commit_future, list)
      commit_future.append(open_future)
    else:
      await open_future

  # `ts.open` runs twice for process 0: the first call only produces the future
  # that is awaited on the background thread, while this second call uses
  # `assume_metadata=True`, performs no I/O, and simply returns the tensorstore
  # object. Every process other than 0 opens with `assume_metadata=True` only.
  t = await ts.open(
      ts.Spec(tensorstore_spec), open=True, assume_metadata=True, context=TS_CONTEXT)

  async def _write_array(shard):
    if shard.replica_id == 0:
      write_future = t[shard.index].write(shard.data)
      if commit_future is not None:
        assert isinstance(commit_future, list)
        commit_future.append(write_future.commit)
        await write_future.copy
      else:
        await write_future.commit

  future_write_state = jax.tree_util.tree_map(_write_array,
                                              gda_inp.local_shards)
  return await asyncio.gather(*future_write_state)
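
A hedged usage sketch for the two-phase commit_future protocol above; gda_inp and tensorstore_spec are assumed to be defined as in the example, and the function name is illustrative:

import asyncio

def checkpoint_with_deferred_commit(gda_inp, tensorstore_spec):
    commit_futures = []
    # Phase 1: async_serialize returns once device data has been copied into
    # tensorstore's write buffers; the durable-commit futures land in the list.
    asyncio.run(async_serialize(gda_inp, tensorstore_spec,
                                commit_future=commit_futures))
    # Phase 2: wait for the commits, e.g. on a background thread, so the caller
    # can resume work right after phase 1.
    for f in commit_futures:
        f.result()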
Example #9
def test_spec_init_json():
    s = ts.Spec({
        "driver": "array",
        "array": [[1, 2], [3, 4]],
        "transform": {
            "input_rank": 2
        },
        "dtype": "int32",
    })
    assert s.transform == ts.IndexTransform(input_rank=2)
    assert s.rank == 2
    assert s.ndim == 2
    assert s.dtype == ts.int32
    assert s.to_json(include_defaults=False) == {
        "driver": "array",
        "array": [[1, 2], [3, 4]],
        "dtype": "int32",
        "transform": {
            "input_inclusive_min": [["-inf"], ["-inf"]],
            "input_exclusive_max": [["+inf"], ["+inf"]],
        },
    }
    assert s.T == ts.Spec({
        "driver": "array",
        "array": [[1, 2], [3, 4]],
        "transform": {
            "input_rank": 2,
            "output": [
                {
                    "input_dimension": 1
                },
                {
                    "input_dimension": 0
                },
            ],
        },
        "dtype": "int32",
    })
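
As a small follow-up (assumed usage, not part of the original test), the same array-driver Spec can be opened and read back:

import numpy as np
import tensorstore as ts

s = ts.Spec({"driver": "array", "array": [[1, 2], [3, 4]], "dtype": "int32"})
store = ts.open(s).result()   # the array driver serves the embedded data directly
assert store[0, 1].read().result() == 2
assert np.array_equal(store.read().result(), [[1, 2], [3, 4]])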
Example #10
def test_spec_indexing_unknown_rank():
    s = ts.Spec({
        "driver": "zarr",
        "kvstore": {
            "driver": "memory"
        },
        "dtype": "int32",
    })
    assert s.rank is None
    assert s.ndim is None
    with pytest.raises(ValueError, match="IndexTransform is unspecified"):
        s[..., ts.newaxis]
    with pytest.raises(ValueError, match="IndexTransform is unspecified"):
        s.T
Example #11
async def async_deserialize(mesh,
                            mesh_axes,
                            tensorstore_spec,
                            global_shape=None):
    t = await ts.open(ts.Spec(tensorstore_spec), open=True)
    shape = t.shape if global_shape is None else global_shape
    new_shard_shape = gda.get_shard_shape(shape, mesh, mesh_axes)

    async def cb(index):
        out = np.zeros(new_shard_shape, dtype=t.dtype.numpy_dtype)
        requested_domain = ts.IndexTransform(input_shape=shape)[index].domain
        restricted_domain = t.domain.intersect(requested_domain)
        await ts.array(out)[ts.d[:].translate_to[requested_domain.origin]
                            ][restricted_domain].write(t[restricted_domain])
        return out

    return await create_async_gda_from_callback(shape, mesh, mesh_axes, cb)
Example #12
def test_spec_indexing_unknown_rank():
  s = ts.Spec({
      "driver": "zarr",
      "kvstore": {
          "driver": "memory"
      },
      "dtype": "int32",
  })
  assert s.rank is None
  assert s.ndim is None
  with pytest.raises(
      ValueError,
      match="Cannot perform indexing operations on Spec with unspecified rank"):
    s[..., ts.newaxis]
  with pytest.raises(
      ValueError,
      match="Cannot perform indexing operations on Spec with unspecified rank"):
    s.T
Example #13
async def async_serialize(gda_inp: gda.GlobalDeviceArray, tensorstore_spec):
    if not tensorstore_spec.get('metadata'):
        tensorstore_spec['metadata'] = _get_metadata(gda_inp)

    t = await ts.open(ts.Spec(tensorstore_spec),
                      create=True,
                      open=True,
                      context=ts.Context(
                          {'file_io_concurrency': {
                              'limit': 128
                          }}))

    async def _write_array(shard):
        if shard.replica_id == 0:
            await t[shard.index].write(shard.data)

    future_write_state = jax.tree_util.tree_map(_write_array,
                                                tuple(gda_inp.local_shards))
    return await asyncio.gather(*future_write_state)
Example #14
async def async_serialize(gda_inp: gda.GlobalDeviceArray, tensorstore_spec):
    # 'metadata' may not be present at the top level (for example, if we are using
    # a 'cast' driver).
    if not _spec_has_metadata(tensorstore_spec):
        tensorstore_spec['metadata'] = _get_metadata(gda_inp)

    t = await ts.open(ts.Spec(tensorstore_spec),
                      create=True,
                      open=True,
                      context=ts.Context(
                          {'file_io_concurrency': {
                              'limit': 128
                          }}))

    async def _write_array(shard):
        if shard.replica_id == 0:
            await t[shard.index].write(shard.data)

    future_write_state = jax.tree_util.tree_map(_write_array,
                                                tuple(gda_inp.local_shards))
    return await asyncio.gather(*future_write_state)
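
For context, a hypothetical example of the kind of tensorstore_spec these serialization helpers consume: a zarr array backed by a local file kvstore (the path is an illustrative placeholder; 'metadata' is filled in from the GlobalDeviceArray when absent):

tensorstore_spec = {
    'driver': 'zarr',
    'kvstore': {
        'driver': 'file',
        'path': '/tmp/checkpoint/params',  # placeholder path
    },
}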