コード例 #1
0
ファイル: scatter.py プロジェクト: sharadmv/jax
def _scatter_impl(x, y, scatter_op, treedef, static_idx, dynamic_idx,
                  indices_are_sorted, unique_indices, normalize_indices):
    """Scatter the updates ``y`` into ``x`` at an index using ``scatter_op``.

    The index is reassembled from ``treedef``, ``static_idx`` and
    ``dynamic_idx``, turned into a gather-style indexer, and the indexer's
    gather dimension numbers are re-used as scatter dimension numbers.

    Args:
      x: operand that is scattered into.
      y: update values; broadcast to the indexer's slice shape.
      scatter_op: callable invoked as ``scatter_op(operand, indices, updates,
        dnums, indices_are_sorted=..., unique_indices=...)``.
      treedef, static_idx, dynamic_idx: pieces of the index, merged by
        ``jnp._merge_static_and_dynamic_indices``.
      indices_are_sorted: forwarded unchanged to ``scatter_op``.
      unique_indices: forwarded unchanged to ``scatter_op``.
      normalize_indices: forwarded to ``jnp._index_to_gather``.

    Returns:
      The output of ``scatter_op`` converted back to the dtype ``x`` had on
      entry (i.e. before dtype promotion with ``y``).
    """
    # Record the operand dtype up front: promotion below may change it.
    result_dtype = lax.dtype(x)
    operand, updates = jnp._promote_dtypes(x, y)

    merged_index = jnp._merge_static_and_dynamic_indices(
        treedef, static_idx, dynamic_idx)
    indexer = jnp._index_to_gather(
        jnp.shape(operand), merged_index, normalize_indices=normalize_indices)

    # Shape the updates like the slice: broadcast, then drop the
    # `None`/`jnp.newaxis` dimensions, then undo any reversed dimensions.
    updates = jnp.broadcast_to(updates, tuple(indexer.slice_shape))
    updates = jnp.squeeze(updates, axis=indexer.newaxis_dims)
    if indexer.reversed_y_dims:
        updates = lax.rev(updates, indexer.reversed_y_dims)

    # Gather dimension numbers map onto scatter dimension numbers (cf.
    # lax._gather_transpose_rule).
    gather_dnums = indexer.dnums
    scatter_dnums = lax.ScatterDimensionNumbers(
        update_window_dims=gather_dnums.offset_dims,
        inserted_window_dims=gather_dnums.collapsed_slice_dims,
        scatter_dims_to_operand_dims=gather_dnums.start_index_map)

    scattered = scatter_op(
        operand,
        indexer.gather_indices,
        updates,
        scatter_dnums,
        indices_are_sorted=indices_are_sorted,
        unique_indices=unique_indices)
    return lax.convert_element_type(scattered, result_dtype)
コード例 #2
0
ファイル: csr.py プロジェクト: rsepassi/jax
 def matmat(self, B):
     """Multiply this sparse matrix by the matrix ``B`` via ``csr_matmat``.

     ``self.data`` and ``B`` are promoted to a common dtype first. The kernel
     is called with this object's shape reversed and ``transpose=True``.
     NOTE(review): presumably the buffers hold the transpose in CSR layout --
     confirm against the enclosing class.
     """
     promoted_data, promoted_rhs = _promote_dtypes(self.data, B)
     return csr_matmat(
         promoted_data, self.indices, self.indptr, promoted_rhs,
         shape=self.shape[::-1], transpose=True)
コード例 #3
0
ファイル: csr.py プロジェクト: rsepassi/jax
 def matvec(self, v):
     """Multiply this sparse matrix by the vector ``v`` via ``csr_matvec``.

     ``self.data`` and ``v`` are promoted to a common dtype first. The kernel
     is called with this object's shape reversed and ``transpose=True``.
     NOTE(review): presumably the buffers hold the transpose in CSR layout --
     confirm against the enclosing class.
     """
     promoted_data, promoted_vec = _promote_dtypes(self.data, v)
     return csr_matvec(
         promoted_data, self.indices, self.indptr, promoted_vec,
         shape=self.shape[::-1], transpose=True)
コード例 #4
0
 def __matmul__(self, other):
   """Implement ``self @ other`` for a dense 1-D or 2-D right operand.

   Raises:
     NotImplementedError: if ``other`` is a ``JAXSparse`` object, or if its
       number of dimensions is neither 1 nor 2.
   """
   if isinstance(other, JAXSparse):
     raise NotImplementedError("matmul between two sparse objects.")
   other = jnp.asarray(other)
   data, other = _promote_dtypes(self.data, other)
   # Select the kernel by operand rank: vectors -> matvec, matrices -> matmat.
   kernels = {1: csr_matvec, 2: csr_matmat}
   kernel = kernels.get(other.ndim)
   if kernel is None:
     raise NotImplementedError(f"matmul with object of shape {other.shape}")
   return kernel(data, self.indices, self.indptr, other, shape=self.shape)
コード例 #5
0
ファイル: coo.py プロジェクト: cloudhan/jax
 def __matmul__(self, other):
   """Implement ``self @ other`` for a dense 1-D or 2-D right operand.

   A new ``COO`` carrying the dtype-promoted data (same ``row``/``col`` and
   the same ``_info`` fields) is built and handed to the kernel.

   Raises:
     NotImplementedError: if ``other`` is a ``JAXSparse`` object, or if its
       number of dimensions is neither 1 nor 2.
   """
   if isinstance(other, JAXSparse):
     raise NotImplementedError("matmul between two sparse objects.")
   other = jnp.asarray(other)
   data, other = _promote_dtypes(self.data, other)
   info_kwargs = self._info._asdict()
   self_promoted = COO((data, self.row, self.col), **info_kwargs)
   rank = other.ndim
   if rank == 1:
     return coo_matvec(self_promoted, other)
   if rank == 2:
     return coo_matmat(self_promoted, other)
   raise NotImplementedError(f"matmul with object of shape {other.shape}")
コード例 #6
0
ファイル: scatter.py プロジェクト: xueeinstein/jax
def _scatter_impl(x, y, scatter_op, treedef, static_idx, dynamic_idx,
                  indices_are_sorted, unique_indices, mode, normalize_indices):
    """Scatter the updates ``y`` into ``x`` at an index using ``scatter_op``.

    Emits a ``FutureWarning`` when the dtype of ``x`` differs from
    ``dtypes.result_type(x, y)``. Returns ``x`` untouched when the indexed
    slice is empty; otherwise the scatter result is converted back to the
    dtype and weak-type flag ``x`` had on entry.

    Args:
      x: operand that is scattered into.
      y: update values; broadcast to the indexer's slice shape.
      scatter_op: callable performing the scatter.
      treedef, static_idx, dynamic_idx: pieces of the index, merged by
        ``jnp._merge_static_and_dynamic_indices``.
      indices_are_sorted: OR-ed with the indexer's own flag.
      unique_indices: OR-ed with the indexer's own flag.
      mode: forwarded unchanged to ``scatter_op``.
      normalize_indices: forwarded to ``jnp._index_to_gather``.
    """
    # Capture dtype and weak-typedness now; promotion below may change them.
    dtype = lax.dtype(x)
    weak_type = dtypes.is_weakly_typed(x)

    if dtype != dtypes.result_type(x, y):
        # TODO(jakevdp): change this to an error after the deprecation period.
        message = (
            "scatter inputs have incompatible types: cannot safely cast "
            f"value from dtype={lax.dtype(y)} to dtype={lax.dtype(x)}. "
            "In future JAX releases this will result in an error.")
        warnings.warn(message, FutureWarning)

    merged_index = jnp._merge_static_and_dynamic_indices(
        treedef, static_idx, dynamic_idx)
    indexer = jnp._index_to_gather(
        jnp.shape(x), merged_index, normalize_indices=normalize_indices)

    # Nothing to write for an empty slice: skip the scatter entirely. This is
    # a fast path and also covers cases like zeros(0)[array([], int32)].
    if core.is_empty_shape(indexer.slice_shape):
        return x

    operand, updates = jnp._promote_dtypes(x, y)

    # Shape the updates like the slice: broadcast, then drop the
    # `None`/`jnp.newaxis` dimensions, then undo any reversed dimensions.
    updates = jnp.broadcast_to(updates, tuple(indexer.slice_shape))
    updates = jnp.squeeze(updates, axis=indexer.newaxis_dims)
    if indexer.reversed_y_dims:
        updates = lax.rev(updates, indexer.reversed_y_dims)

    # Gather dimension numbers map onto scatter dimension numbers (cf.
    # lax._gather_transpose_rule).
    gather_dnums = indexer.dnums
    scatter_dnums = lax.ScatterDimensionNumbers(
        update_window_dims=gather_dnums.offset_dims,
        inserted_window_dims=gather_dnums.collapsed_slice_dims,
        scatter_dims_to_operand_dims=gather_dnums.start_index_map)

    scattered = scatter_op(
        operand,
        indexer.gather_indices,
        updates,
        scatter_dnums,
        indices_are_sorted=indexer.indices_are_sorted or indices_are_sorted,
        unique_indices=indexer.unique_indices or unique_indices,
        mode=mode)
    return lax_internal._convert_element_type(scattered, dtype, weak_type)
コード例 #7
0
def block_diag(*arrs):
  """Assemble a block-diagonal 2-D array from the given arrays.

  Every argument is made at least 2-D with ``jnp.atleast_2d`` and placed as
  one block on the diagonal, in argument order; everything off the diagonal
  is filled with zeros of the result dtype. With no arguments the result is
  ``jnp.zeros((1, 0))``.

  Args:
    *arrs: input arrays, each with at most 2 dimensions.

  Returns:
    A 2-D array whose diagonal blocks are the (dtype-promoted) inputs.

  Raises:
    ValueError: if any argument has more than 2 dimensions. The message
      reports the offending number of dimensions and the argument position.
  """
  if not arrs:
    arrs = [jnp.zeros((1, 0))]
  # Reject bad input before doing any dtype promotion, and report the number
  # of dimensions rather than dumping the whole offending array into the
  # message.
  for position, arr in enumerate(arrs):
    ndim = jnp.ndim(arr)
    if ndim > 2:
      raise ValueError("Arguments to jax.scipy.linalg.block_diag must have at "
                       "most 2 dimensions, got {} at argument {}."
                       .format(ndim, position))
  arrs = jnp._promote_dtypes(*arrs)
  arrs = [jnp.atleast_2d(a) for a in arrs]
  acc = arrs[0]
  dtype = lax.dtype(acc)
  zero = dtype.type(0)
  for a in arrs[1:]:
    _, c = a.shape
    # Shift the new block right past the columns accumulated so far ...
    a = lax.pad(a, zero, ((0, 0, 0), (acc.shape[-1], 0, 0)))
    # ... and widen the accumulator by the new block's column count.
    acc = lax.pad(acc, zero, ((0, 0, 0), (0, c, 0)))
    acc = lax.concatenate([acc, a], dimension=0)
  return acc
コード例 #8
0
def _scatter_impl(x, y, scatter_op, treedef, static_idx, dynamic_idx,
                  indices_are_sorted, unique_indices, mode, normalize_indices):
    """Scatter the updates ``y`` into ``x`` at an index using ``scatter_op``.

    Returns ``x`` untouched when the indexed slice is empty; otherwise the
    scatter result is converted back to the dtype and weak-type flag ``x``
    had on entry.

    Args:
      x: operand that is scattered into.
      y: update values; broadcast to the indexer's slice shape.
      scatter_op: callable performing the scatter.
      treedef, static_idx, dynamic_idx: pieces of the index, merged by
        ``jnp._merge_static_and_dynamic_indices``.
      indices_are_sorted: OR-ed with the indexer's own flag.
      unique_indices: OR-ed with the indexer's own flag.
      mode: forwarded unchanged to ``scatter_op``.
      normalize_indices: forwarded to ``jnp._index_to_gather``.
    """
    # Capture dtype and weak-typedness now; promotion below may change them.
    original_dtype = lax.dtype(x)
    original_weak_type = dtypes.is_weakly_typed(x)

    merged_index = jnp._merge_static_and_dynamic_indices(
        treedef, static_idx, dynamic_idx)
    indexer = jnp._index_to_gather(
        jnp.shape(x), merged_index, normalize_indices=normalize_indices)

    # Nothing to write for an empty slice: skip the scatter entirely. This is
    # a fast path and also covers cases like zeros(0)[array([], int32)].
    if core.is_empty_shape(indexer.slice_shape):
        return x

    operand, updates = jnp._promote_dtypes(x, y)

    # Shape the updates like the slice: broadcast, then drop the
    # `None`/`jnp.newaxis` dimensions, then undo any reversed dimensions.
    updates = jnp.broadcast_to(updates, tuple(indexer.slice_shape))
    updates = jnp.squeeze(updates, axis=indexer.newaxis_dims)
    if indexer.reversed_y_dims:
        updates = lax.rev(updates, indexer.reversed_y_dims)

    # Gather dimension numbers map onto scatter dimension numbers (cf.
    # lax._gather_transpose_rule).
    gather_dnums = indexer.dnums
    scatter_dnums = lax.ScatterDimensionNumbers(
        update_window_dims=gather_dnums.offset_dims,
        inserted_window_dims=gather_dnums.collapsed_slice_dims,
        scatter_dims_to_operand_dims=gather_dnums.start_index_map)

    scattered = scatter_op(
        operand,
        indexer.gather_indices,
        updates,
        scatter_dnums,
        indices_are_sorted=indexer.indices_are_sorted or indices_are_sorted,
        unique_indices=indexer.unique_indices or unique_indices,
        mode=mode)
    return lax._convert_element_type(
        scattered, original_dtype, original_weak_type)
コード例 #9
0
ファイル: csr.py プロジェクト: rsepassi/jax
 def matmat(self, B):
     """Multiply this sparse matrix by the matrix ``B`` via ``csr_matmat``.

     ``self.data`` and ``B`` are promoted to a common dtype before the call.
     """
     promoted_data, promoted_rhs = _promote_dtypes(self.data, B)
     return csr_matmat(
         promoted_data, self.indices, self.indptr, promoted_rhs,
         shape=self.shape)
コード例 #10
0
ファイル: csr.py プロジェクト: rsepassi/jax
 def matvec(self, v):
     """Multiply this sparse matrix by the vector ``v`` via ``csr_matvec``.

     ``self.data`` and ``v`` are promoted to a common dtype before the call.
     """
     promoted_data, promoted_vec = _promote_dtypes(self.data, v)
     return csr_matvec(
         promoted_data, self.indices, self.indptr, promoted_vec,
         shape=self.shape)
コード例 #11
0
ファイル: coo.py プロジェクト: rsepassi/jax
 def matmat(self, B):
   """Multiply this sparse matrix by the matrix ``B`` via ``coo_matmat``.

   ``self.data`` and ``B`` are promoted to a common dtype before the call.
   """
   promoted_data, promoted_rhs = _promote_dtypes(self.data, B)
   return coo_matmat(
       promoted_data, self.row, self.col, promoted_rhs, shape=self.shape)
コード例 #12
0
ファイル: coo.py プロジェクト: rsepassi/jax
 def matvec(self, v):
   """Multiply this sparse matrix by the vector ``v`` via ``coo_matvec``.

   ``self.data`` and ``v`` are promoted to a common dtype before the call.
   """
   promoted_data, promoted_vec = _promote_dtypes(self.data, v)
   return coo_matvec(
       promoted_data, self.row, self.col, promoted_vec, shape=self.shape)