Esempio n. 1
0
def _bcoo_dot_general_impl(lhs_data, lhs_indices, rhs, *, dimension_numbers, lhs_shape):
  """Concrete evaluation of ``dot_general`` between a sparse BCOO lhs and a dense rhs.

  Args:
    lhs_data: data buffer of the sparse lhs; its leading ``n_batch`` axes are
      batch axes.
    lhs_indices: index buffer of the sparse lhs. This function reads it as
      ``(*batch, n_sparse, nse)``: ``n_sparse = shape[-2]`` and
      ``n_batch = ndim - 2``.
    rhs: dense right-hand operand.
    dimension_numbers: ``((lhs_contracting, rhs_contracting), (lhs_batch, rhs_batch))``,
      the same structure that ``lax.dot_general`` takes.
    lhs_shape: logical (dense) shape of the sparse lhs; forwarded to the
      abstract-eval rule, which validates the inputs.

  Returns:
    A dense array with the shape and dtype computed by
    ``_bcoo_dot_general_abstract_eval``.
  """
  lhs_data = jnp.asarray(lhs_data)
  lhs_indices = jnp.asarray(lhs_indices)
  rhs = jnp.asarray(rhs)
  # Validate all inputs via abstract_eval; the returned aval also fixes the
  # result shape and dtype used below.
  out_aval = _bcoo_dot_general_abstract_eval(lhs_data.aval, lhs_indices.aval, rhs.aval,
                                             dimension_numbers=dimension_numbers,
                                             lhs_shape=lhs_shape)

  (lhs_contracting, rhs_contracting), (lhs_batch, rhs_batch) = dimension_numbers
  n_sparse = lhs_indices.shape[-2]
  n_batch = lhs_indices.ndim - 2

  # Move lhs batch dimensions to the front
  if lhs_batch:
    perm = list(lhs_batch) + remaining(range(n_batch), lhs_batch)
    lhs_data = lhs_data.transpose(perm + list(range(n_batch, lhs_data.ndim)))
    lhs_indices = lhs_indices.transpose(perm + list(range(n_batch, lhs_indices.ndim)))

  # Move lhs contracting dimensions to the front of sparse dims, in order.
  # lhs_contracting is given in array-axis coordinates; shift it into
  # sparse-dimension coordinates.
  n_contracting = len(lhs_contracting)
  lhs_contracting = [d - n_batch for d in lhs_contracting]
  perm = list(lhs_contracting) + remaining(range(n_sparse), lhs_contracting)
  # Explicit integer dtype: when n_sparse == 0, perm is empty and jnp.array([])
  # would default to a floating dtype, which is rejected as an array index.
  lhs_indices = lhs_indices[..., jnp.array(perm, dtype=int), :]

  # Move rhs batch dimensions then contracting dimensions to the front, in order
  perm = (list(rhs_batch) + list(rhs_contracting) +
          remaining(range(rhs.ndim), rhs_batch, rhs_contracting))
  rhs = rhs.transpose(perm)

  out_array = jnp.zeros(out_aval.shape, out_aval.dtype)

  def result(out_array, lhs_data, lhs_indices, rhs):
    # Unbatched kernel. Once every batch axis has been mapped away,
    # lhs_indices is (n_sparse, nse), so iterating over it yields one index
    # vector per sparse dimension.
    idx = tuple(lhs_indices)
    # The leading n_contracting sparse dims index into rhs; the rest index
    # into the output.
    idx_right, idx_out = idx[:n_contracting], idx[n_contracting:]
    # With a contraction, pair each stored element with its gathered rhs
    # slice by batching over axis 0. Otherwise, take an outer product.
    ctc = [0] if n_contracting else []
    prod = lax.dot_general(lhs_data, rhs[idx_right], (([], []), (ctc, ctc)))
    # Scatter-add into the output at the free sparse indices. When there are
    # none, reduce over axis 0 instead.
    return out_array.at[idx_out].add(prod) if idx_out else prod.sum(0)

  # Wrap the kernel in one vmap per batch dimension, last dimension first, so
  # that the outermost vmap maps batch dimension 0.
  for i in range(n_batch)[::-1]:
    axes_in = [0, 0, 0, 0]
    # A size-1 batch axis in data or indices is dropped and left unmapped
    # (broadcast) rather than mapped.
    if lhs_data.shape[i] == 1:
      lhs_data = lax.squeeze(lhs_data, (i,))
      axes_in[1] = None
    if lhs_indices.shape[i] == 1:
      lhs_indices = lax.squeeze(lhs_indices, (i,))
      axes_in[2] = None
    # Only the first len(lhs_batch) batch dims are shared with rhs. rhs is
    # not mapped over the remaining lhs batch dims.
    if i >= len(lhs_batch):
      axes_in[3] = None
    result = vmap(result, tuple(axes_in))
  return result(out_array, lhs_data, lhs_indices, rhs)
Esempio n. 2
0
class JAX2DexTest(unittest.TestCase):
    """Checks Dex lowering of individual ``lax`` primitives against JAX.

    Most cases are generated with ``lax_test(fn, make_args)``. Every
    ``make_args`` used here is a zero-argument callable that returns a tuple
    of inputs for ``fn``. Presumably ``lax_test`` unpacks that tuple as
    positional arguments — confirm in ``lax_test``.
    """

    # Class-body helpers. They are deleted at the end of the class so they
    # never become attributes of the TestCase.
    def _square_input():
        return (rn(10, 10), )

    _matmul_dims = (((1, ), (0, )), ((), ()))

    # Elementwise primitives.
    test_sin = lax_test(lax.sin, _square_input)
    test_cos = lax_test(lax.cos, _square_input)
    test_neg = lax_test(lax.neg, _square_input)
    test_log = lax_test(lax.log, _square_input)
    test_exp = lax_test(lax.exp, _square_input)
    test_pow = lax_test(
        lax.pow,
        lambda: (rn(10), jnp.arange(10, dtype=np.float32)),
    )
    test_integer_pow = lax_test(
        lambda x: lax.integer_pow(x, 2),
        _square_input,
    )
    test_scalar_select_lt = lax_test(
        lambda i, x, y: lax.select(i < 2.0, x, y),
        lambda: (1.0, rn(10), rn(10)),
    )

    # Squeeze removing no axes, one axis, several axes, and every axis.
    test_squeeze_none = lax_test(
        lambda x: lax.squeeze(x, []),
        _square_input,
    )
    test_squeeze_one = lax_test(
        lambda x: lax.squeeze(x, [1]),
        lambda: (rn(10, 1, 10), ),
    )
    test_squeeze_two = lax_test(
        lambda x: lax.squeeze(x, [0, 2]),
        lambda: (rn(1, 10, 1), ),
    )
    test_squeeze_all = lax_test(
        lambda x: lax.squeeze(x, [0, 1]),
        lambda: (rn(1, 1), ),
    )

    # Static slicing in one and three dimensions.
    test_slice_1d = lax_test(
        lambda x: lax.slice(x, [2], [5], None),
        lambda: (rn(10), ),
    )
    test_slice_3d = lax_test(
        lambda x: lax.slice(x, [2, 0, 0], [5, 10, 2], None),
        lambda: (rn(10, 10, 2), ),
    )

    # Concatenation along axis 0 of equal-sized and ragged pieces.
    test_concat_uniform = lax_test(
        partial(lax.concatenate, dimension=0),
        lambda: ([rn(4, 2), rn(4, 2), rn(4, 2)], ),
    )
    test_concat_ragged = lax_test(
        partial(lax.concatenate, dimension=0),
        lambda: ([rn(1, 2, 4), rn(5, 2, 4), rn(2, 2, 4)], ),
    )

    # dot_general contracting lhs axis 1 with rhs axis 0, with no batch axes.
    test_dot_general_matmul = lax_test(
        partial(lax.dot_general, dimension_numbers=_matmul_dims),
        lambda: (rn(4, 8), rn(8, 16)),
    )
    test_dot_general_matvec = lax_test(
        partial(lax.dot_general, dimension_numbers=_matmul_dims),
        lambda: (rn(4, 8), rn(8)),
    )

    del _square_input, _matmul_dims

    def test_canonicalize_dtype(self):
        """Dex and JAX agree on values and output dtype for float64 inputs."""
        coeffs = np.arange(5, dtype=np.float64)
        ones = np.ones(5, dtype=np.float64)

        def scale(v):
            return v * coeffs

        dex_out = dexjit(scale)(ones)
        jax_out = jax.jit(scale)(ones)
        np.testing.assert_allclose(dex_out, jax_out)
        self.assertEqual(dex_out.dtype, jax_out.dtype)
Esempio n. 3
0
def _squeeze_sparse(spenv, *spvalues, dimensions):
  """Sparse rule for ``lax.squeeze`` applied to a single sparse operand.

  ``indices`` is read as ``(*batch, nse, n_sparse)`` (``n_sparse = shape[-1]``,
  ``n_batch = ndim - 2``). The dense-axis offset ``axis - n_sparse + 1``
  implies that ``data`` is laid out as ``(*batch, nse, *dense)``.

  Squeezed batch axes are removed from both buffers. Squeezed sparse axes
  drop their column of ``indices``. Squeezed dense axes are removed from
  ``data`` only.

  Raises:
    ValueError: if any requested axis does not have size 1.
  """
  (operand,) = spvalues
  dimensions = tuple(canonicalize_axis(axis, operand.ndim) for axis in dimensions)
  if not all(operand.shape[axis] == 1 for axis in dimensions):
    raise ValueError("cannot select an axis to squeeze out which has size not equal to one, "
                     f"got shape={operand.shape} and dimensions={dimensions}")

  data, indices = spenv.data(operand), spenv.indices(operand)
  n_batch = indices.ndim - 2
  n_sparse = indices.shape[-1]
  first_dense = n_batch + n_sparse  # first array axis that is a dense axis

  squeezed_batch = tuple(axis for axis in dimensions if axis < n_batch)
  squeezed_dense = tuple(axis - n_sparse + 1 for axis in dimensions if axis >= first_dense)
  kept_sparse = np.array(
      [k for k in range(n_sparse) if k + n_batch not in dimensions], dtype=int)

  new_data = lax.squeeze(data, squeezed_batch + squeezed_dense)
  new_indices = lax.squeeze(indices[..., kept_sparse], squeezed_batch)
  new_shape = tuple(size for axis, size in enumerate(operand.shape) if axis not in dimensions)
  return (spenv.sparse(new_shape, new_data, new_indices),)
Esempio n. 4
0
def squeeze_dependency_rule(outstart, outcount, operand, dimensions):
    """Map a requested output slice of ``lax.squeeze`` back to its input slice.

    Only the case accepted by ``is_ones(outcount)`` is supported. Otherwise
    ``NotImplementedError`` is raised. ``operand`` is not used.

    Returns:
      ``([(instart, inshape)], [Ones(inshape)], fn)``. ``instart`` and
      ``inshape`` are ``outstart`` and ``outcount.shape`` with a ``0`` start
      and a length-``1`` extent re-inserted at every squeezed axis.
      ``fn(inslice)`` re-applies ``lax.squeeze(inslice, dimensions)``.
    """
    if not is_ones(outcount):
        raise NotImplementedError

    start = list(outstart)
    shape = list(outcount.shape)
    # Insert in ascending axis order so each insert position is already
    # correct relative to the axes restored before it.
    for axis in np.sort(dimensions):
        start.insert(axis, 0)
        shape.insert(axis, 1)

    def recompute(inslice):
        return lax.squeeze(inslice, dimensions)

    return [(start, shape)], [Ones(shape)], recompute
Esempio n. 5
0
 def testSqueeze(self, arg_shape, dimensions, bdims):
     """Batching check for ``lax.squeeze`` on float32 operands.

     Delegates to ``self._CheckBatching`` with batch size 10 (presumably) and
     the batch-dimension spec ``bdims``.
     """
     rng = jtu.rand_default(self.rng())

     def squeeze(operand):
         return lax.squeeze(operand, dimensions)

     self._CheckBatching(squeeze, 10, bdims, (arg_shape, ), (np.float32, ), rng)
Esempio n. 6
0
 def testSqueeze(self, arg_shape, dimensions, bdims, rng_factory):
     """Batching check for ``lax.squeeze`` on float32 operands.

     Inputs come from ``rng_factory`` seeded by ``self.rng()``. Delegates to
     ``self._CheckBatching`` with batch size 10 (presumably) and ``bdims``.
     """
     rng = rng_factory(self.rng())

     def squeeze(operand):
         return lax.squeeze(operand, dimensions)

     self._CheckBatching(squeeze, 10, bdims, (arg_shape, ), (onp.float32, ), rng)
Esempio n. 7
0
def test_squeeze(shape, dtype, dimensions, rng_factory):
    """Run ``tu.check_lazy_fun`` on ``lax.squeeze`` over one random operand."""
    make_operand = rng_factory(np.random)

    def squeeze(x):
        return lax.squeeze(x, dimensions)

    operand = make_operand(shape, dtype)
    tu.check_lazy_fun(squeeze, operand)
Esempio n. 8
0
 def squeeze_op(self, params, inputs):
     """Remove axis ``self.axis`` (which must have size 1) from ``inputs``; ``params`` is unused."""
     axes = (self.axis,)
     return lax.squeeze(inputs, axes)