Example #1
def build_sparse_array(data_buf, indices_buf):
    # `aval` and `device` are free variables: this function is a closure,
    # presumably returned by a result handler (see the sketch below).
    data = xla.make_device_array(aval.data_aval, device,
                                 lazy.array(aval.data_aval.shape),
                                 data_buf)
    indices = xla.make_device_array(aval.indices_aval, device,
                                    lazy.array(aval.indices_aval.shape),
                                    indices_buf)
    return SparseArray(aval, data, indices)
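
The free variables above suggest `build_sparse_array` is produced by a result handler. A minimal sketch of such an enclosing function follows; the name `sparse_array_result_handler` and its `(device, aval)` signature are assumptions for illustration, not taken from the snippet:

def sparse_array_result_handler(device, aval):
    # Bind `device` and the abstract value `aval` (with its `data_aval`
    # and `indices_aval` fields) for the closure shown in Example #1.
    def build_sparse_array(data_buf, indices_buf):
        data = xla.make_device_array(aval.data_aval, device,
                                     lazy.array(aval.data_aval.shape),
                                     data_buf)
        indices = xla.make_device_array(aval.indices_aval, device,
                                        lazy.array(aval.indices_aval.shape),
                                        indices_buf)
        return SparseArray(aval, data, indices)
    return build_sparse_array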
Example #2
def from_dlpack(dlpack, backend=None):
    """Returns a `DeviceArray` representation of a DLPack tensor `dlpack`.

    The returned `DeviceArray` shares memory with `dlpack`.

    Args:
      dlpack: a DLPack tensor, on either CPU or GPU.
      backend: deprecated, do not use.
    """
    if jax.lib._xla_extension_version >= 25:
        cpu_backend = xla_bridge.get_backend("cpu")
        try:
            gpu_backend = xla_bridge.get_backend("gpu")
        except RuntimeError:
            gpu_backend = None
        buf = xla_client._xla.dlpack_managed_tensor_to_buffer(
            dlpack, cpu_backend, gpu_backend)
    else:
        # TODO(phawkins): drop the backend argument after deleting this case.
        backend = backend or xla_bridge.get_backend()
        client = getattr(backend, "client", backend)
        buf = xla_client._xla.dlpack_managed_tensor_to_buffer(dlpack, client)

    xla_shape = buf.xla_shape()
    assert not xla_shape.is_tuple()
    aval = core.ShapedArray(xla_shape.dimensions(), xla_shape.numpy_dtype())
    return xla.make_device_array(aval, buf.device(), buf)  # pytype: disable=attribute-error
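
A typical round trip through this function, assuming it is exposed as `jax.dlpack.from_dlpack` alongside a matching `jax.dlpack.to_dlpack` (a sketch of era-appropriate usage, not a guaranteed API):

import jax.dlpack
import jax.numpy as jnp

x = jnp.arange(6.0).reshape(2, 3)
capsule = jax.dlpack.to_dlpack(x)    # export x's buffer as a DLPack capsule;
                                     # do not reuse x, the export may consume it
y = jax.dlpack.from_dlpack(capsule)  # zero-copy: y aliases the same memory
print(y.shape, y.dtype)              # (2, 3) float32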
Example #3
def from_dlpack(dlpack, backend=None):
    """Returns a `DeviceArray` representation of a DLPack tensor `dlpack`.

    The returned `DeviceArray` shares memory with `dlpack`.

    Args:
      dlpack: a DLPack tensor, on either CPU or GPU.
      backend: experimental, optional: the platform on which `dlpack` lives.
    """
    # TODO(phawkins): ideally the user wouldn't need to provide a backend and we
    # would be able to figure it out from the DLPack.
    backend = backend or xla_bridge.get_backend()
    client = getattr(backend, "client", backend)
    buf = xla_client._xla.dlpack_managed_tensor_to_buffer(dlpack, client)
    xla_shape = buf.xla_shape()
    assert not xla_shape.is_tuple()
    aval = core.ShapedArray(xla_shape.dimensions(), xla_shape.numpy_dtype())
    return xla.make_device_array(aval, buf.device(), buf)  # pytype: disable=attribute-error
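
This variant additionally accepts an experimental `backend` argument naming the platform the capsule lives on. A minimal sketch of passing it explicitly, under the same `jax.dlpack` exposure assumed above:

import jax.dlpack
import jax.numpy as jnp
from jax.lib import xla_bridge

x = jnp.ones((3,), jnp.float32)
capsule = jax.dlpack.to_dlpack(x)
# Explicitly pick the CPU platform rather than the default backend.
y = jax.dlpack.from_dlpack(capsule, backend=xla_bridge.get_backend("cpu"))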
Example #4
File: dlpack.py  Project: Jakob-Unfried/jax
def from_dlpack(dlpack):
    """Returns a `DeviceArray` representation of a DLPack tensor `dlpack`.

    The returned `DeviceArray` shares memory with `dlpack`.

    Args:
      dlpack: a DLPack tensor, on either CPU or GPU.
    """
    cpu_backend = xla_bridge.get_backend("cpu")
    try:
        gpu_backend = xla_bridge.get_backend("gpu")
    except RuntimeError:
        gpu_backend = None
    buf = xla_client._xla.dlpack_managed_tensor_to_buffer(
        dlpack, cpu_backend, gpu_backend)

    xla_shape = buf.xla_shape()
    assert not xla_shape.is_tuple()
    aval = core.ShapedArray(xla_shape.dimensions(), xla_shape.numpy_dtype())
    return xla.make_device_array(aval, buf.device(), buf)  # pytype: disable=attribute-error
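
The main use case is zero-copy import from another framework. A sketch assuming PyTorch is installed; `torch.utils.dlpack.to_dlpack` is PyTorch's DLPack exporter, and `from_dlpack` is the function above as exposed by `jax.dlpack`:

import torch.utils.dlpack
import jax.dlpack

t = torch.arange(4, dtype=torch.float32)
capsule = torch.utils.dlpack.to_dlpack(t)  # export the torch tensor
x = jax.dlpack.from_dlpack(capsule)        # DeviceArray sharing t's memory
print(x)                                   # [0. 1. 2. 3.]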