コード例 #1
0
ファイル: custom_object_test.py プロジェクト: qqsun8819/jax
 def build_sparse_array(data_buf, indices_buf):
     """Assemble a `SparseArray` from a data buffer and an indices buffer.

     Each buffer is wrapped in an `xla.DeviceArray` using the matching
     sub-aval (`aval.data_aval` / `aval.indices_aval`) and a lazy expression
     built from that sub-aval's shape.

     NOTE(review): `aval` and `device` are not parameters; they are resolved
     from a surrounding scope that is not visible here -- confirm against the
     enclosing definition.

     Args:
       data_buf: buffer backing the `data` component.
       indices_buf: buffer backing the `indices` component.

     Returns:
       `SparseArray(aval, data, indices)` built from the two wrapped arrays.
     """
     values_aval = aval.data_aval
     values = xla.DeviceArray(
         values_aval, device, lazy.array(values_aval.shape), data_buf)

     positions_aval = aval.indices_aval
     positions = xla.DeviceArray(
         positions_aval, device, lazy.array(positions_aval.shape), indices_buf)

     return SparseArray(aval, values, positions)
コード例 #2
0
ファイル: dlpack.py プロジェクト: raj0088/jax
def from_dlpack(dlpack, backend=None):
    """Returns a `DeviceArray` representation of a DLPack tensor `dlpack`.

    The returned `DeviceArray` shares memory with `dlpack`.

    Args:
      dlpack: a DLPack tensor, on either CPU or GPU.
      backend: experimental, optional: the platform on which `dlpack` lives.
        When omitted (or otherwise falsy), `xla_bridge.get_backend()` is used.

    Returns:
      An `xla.DeviceArray` wrapping the buffer imported from `dlpack`, with an
      abstract value built from the buffer's dimensions and NumPy dtype.

    Raises:
      AssertionError: if the imported buffer reports a tuple shape.
    """
    # TODO(phawkins): ideally the user wouldn't need to provide a backend and we
    # would be able to figure it out from the DLPack.
    if not backend:
        backend = xla_bridge.get_backend()

    # Use the backend's `client` attribute when it has one; otherwise the
    # backend object itself is passed as the client.
    client = getattr(backend, "client", backend)

    buffer = xla_client._xla.dlpack_managed_tensor_to_buffer(dlpack, client)
    buffer_shape = buffer.shape()
    assert not buffer_shape.is_tuple()

    aval = core.ShapedArray(
        buffer_shape.dimensions(), buffer_shape.numpy_dtype())
    return xla.DeviceArray(aval, buffer.device(), lazy.array(aval.shape), buffer)  # pytype: disable=attribute-error