def _bcoo_dot_general_impl(lhs_data, lhs_indices, rhs, *, dimension_numbers, lhs_shape):
  """Dense-output implementation of dot_general with a sparse (BCOO) lhs.

  Args:
    lhs_data: data buffer of the sparse lhs; its leading dims are batch dims.
    lhs_indices: index buffer of the sparse lhs.  The `shape[-2]` and
      `[..., perm, :]` accesses below imply a (batch..., n_sparse, nse)
      layout — confirm against the BCOO layout used elsewhere in this file.
    rhs: dense right-hand operand.
    dimension_numbers: ((lhs_contracting, rhs_contracting),
      (lhs_batch, rhs_batch)), as in lax.dot_general.
    lhs_shape: logical shape of the sparse lhs (used for validation only).

  Returns:
    A dense array with the shape/dtype given by the abstract eval rule.
  """
  lhs_data = jnp.asarray(lhs_data)
  lhs_indices = jnp.asarray(lhs_indices)
  rhs = jnp.asarray(rhs)
  # Validate all inputs via abstract_eval; this also yields the output aval.
  out_aval = _bcoo_dot_general_abstract_eval(lhs_data.aval, lhs_indices.aval, rhs.aval,
                                             dimension_numbers=dimension_numbers,
                                             lhs_shape=lhs_shape)
  (lhs_contracting, rhs_contracting), (lhs_batch, rhs_batch) = dimension_numbers
  n_sparse = lhs_indices.shape[-2]
  n_batch = lhs_indices.ndim - 2

  # Move lhs batch dimensions to the front
  if lhs_batch:
    perm = list(lhs_batch) + remaining(range(n_batch), lhs_batch)
    lhs_data = lhs_data.transpose(perm + list(range(n_batch, lhs_data.ndim)))
    lhs_indices = lhs_indices.transpose(perm + list(range(n_batch, lhs_indices.ndim)))

  # Move lhs contracting dimensions to the front of sparse dims, in order.
  # Only the index rows are permuted; data is unchanged because sparse dims
  # live entirely in the index buffer.
  n_contracting = len(lhs_contracting)
  lhs_contracting = [d - n_batch for d in lhs_contracting]
  perm = list(lhs_contracting) + remaining(range(n_sparse), lhs_contracting)
  lhs_indices = lhs_indices[..., jnp.array(perm), :]

  # Move rhs batch dimensions then contracting dimensions to the front, in order
  perm = (list(rhs_batch) + list(rhs_contracting) +
          remaining(range(rhs.ndim), rhs_batch, rhs_contracting))
  rhs = rhs.transpose(perm)

  out_array = jnp.zeros(out_aval.shape, out_aval.dtype)

  def result(out_array, lhs_data, lhs_indices, rhs):
    # Unbatched kernel: lhs_indices is (n_sparse, nse).  Split the sparse
    # coordinates into those that are contracted (used to gather from rhs)
    # and those that index the output.
    idx = tuple(lhs_indices)
    idx_right, idx_out = idx[:n_contracting], idx[n_contracting:]
    ctc = [0] if n_contracting else []
    # Batch over the nse axis: each stored value multiplies its gathered rhs
    # slice; no dimensions are contracted inside the dot_general itself.
    prod = lax.dot_general(lhs_data, rhs[idx_right], (([], []), (ctc, ctc)))
    # Scatter-add into the output at the remaining sparse coordinates; if all
    # sparse dims were contracted, just sum over nse.
    return out_array.at[idx_out].add(prod) if idx_out else prod.sum(0)

  # Wrap the kernel in one vmap per lhs batch dimension, innermost first.
  for i in range(n_batch)[::-1]:
    axes_in = [0, 0, 0, 0]  # in_axes for (out_array, lhs_data, lhs_indices, rhs)
    if lhs_data.shape[i] == 1:
      # Size-1 batch dims are broadcast (in_axes=None) instead of mapped.
      lhs_data = lax.squeeze(lhs_data, (i,))
      axes_in[1] = None
    if lhs_indices.shape[i] == 1:
      lhs_indices = lax.squeeze(lhs_indices, (i,))
      axes_in[2] = None
    if i >= len(lhs_batch):
      # rhs only participates in the dot_general batch dimensions, which were
      # moved to the front above.
      axes_in[3] = None
    result = vmap(result, tuple(axes_in))
  return result(out_array, lhs_data, lhs_indices, rhs)
class JAX2DexTest(unittest.TestCase):
  """Checks dex-compiled lax primitives against their plain-JAX results.

  Each `test_*` class attribute is produced by `lax_test(op, args_thunk)`,
  which presumably builds a test method comparing `op` under dexjit with the
  reference JAX execution on the thunk's arguments — confirm against
  lax_test's definition, which is outside this chunk.
  """
  # Elementwise unary ops on 10x10 inputs.
  test_sin = lax_test(lax.sin, lambda: (rn(10, 10), ))
  test_cos = lax_test(lax.cos, lambda: (rn(10, 10), ))
  test_neg = lax_test(lax.neg, lambda: (rn(10, 10), ))
  test_log = lax_test(lax.log, lambda: (rn(10, 10), ))
  test_exp = lax_test(lax.exp, lambda: (rn(10, 10), ))
  # Power ops: elementwise pow and the static integer_pow primitive.
  test_pow = lax_test(lax.pow, lambda: (rn(10), jnp.arange(10, dtype=np.float32)))
  test_integer_pow = lax_test(lambda x: lax.integer_pow(x, 2), lambda: (rn(10, 10), ))
  # select driven by a scalar predicate.
  test_scalar_select_lt = lax_test(lambda i, x, y: lax.select(i < 2.0, x, y),
                                   lambda: (1.0, rn(10), rn(10)))
  # squeeze with zero, one, two, and all axes removed.
  test_squeeze_none = lax_test(lambda x: lax.squeeze(x, []), lambda: (rn(10, 10), ))
  test_squeeze_one = lax_test(lambda x: lax.squeeze(x, [1]), lambda: (rn(10, 1, 10), ))
  test_squeeze_two = lax_test(lambda x: lax.squeeze(x, [0, 2]), lambda: (rn(1, 10, 1), ))
  test_squeeze_all = lax_test(lambda x: lax.squeeze(x, [0, 1]), lambda: (rn(1, 1), ))
  # Static slicing in 1-D and 3-D.
  test_slice_1d = lax_test(lambda x: lax.slice(x, [2], [5], None), lambda: (rn(10), ))
  test_slice_3d = lax_test(
      lambda x: lax.slice(x, [2, 0, 0], [5, 10, 2], None), lambda: (rn(10, 10, 2), ))
  # Concatenation with uniform and ragged leading dimensions.
  test_concat_uniform = lax_test(partial(lax.concatenate, dimension=0),
                                 lambda: ([rn(4, 2) for _ in range(3)], ))
  test_concat_ragged = lax_test(
      partial(lax.concatenate, dimension=0),
      lambda: ([rn(1, 2, 4), rn(5, 2, 4), rn(2, 2, 4)], ))
  # dot_general: plain matmul and matrix-vector contraction over axis 1 x axis 0.
  test_dot_general_matmul = lax_test(
      partial(lax.dot_general, dimension_numbers=(((1, ), (0, )), ((), ()))),
      lambda: (rn(4, 8), rn(8, 16)))
  test_dot_general_matvec = lax_test(
      partial(lax.dot_general, dimension_numbers=(((1, ), (0, )), ((), ()))),
      lambda: (rn(4, 8), rn(8)))

  def test_canonicalize_dtype(self):
    # float64 inputs exercise dtype canonicalization; the dex-compiled result
    # must match jax.jit in both values and resulting dtype.
    c = np.arange(5, dtype=np.float64)
    f = lambda x: x * c
    x = np.ones(5, dtype=np.float64)
    dy = dexjit(f)(x)
    jy = jax.jit(f)(x)
    np.testing.assert_allclose(dy, jy)
    self.assertEqual(dy.dtype, jy.dtype)
def _squeeze_sparse(spenv, *spvalues, dimensions):
  """Sparsify rule for lax.squeeze on a BCOO-style sparse value.

  Args:
    spenv: sparse environment providing the data/indices buffers for `arr`.
    *spvalues: exactly one sparse array value whose axes will be squeezed.
    dimensions: axes to remove; each must have size 1.

  Returns:
    A 1-tuple containing the new sparse value with the squeezed shape.

  Raises:
    ValueError: if any requested axis does not have size 1.
  """
  arr, = spvalues
  dimensions = tuple(canonicalize_axis(dim, arr.ndim) for dim in dimensions)
  if any(arr.shape[dim] != 1 for dim in dimensions):
    raise ValueError("cannot select an axis to squeeze out which has size not equal to one, "
                     f"got shape={arr.shape} and dimensions={dimensions}")
  data = spenv.data(arr)
  indices = spenv.indices(arr)
  # Layout implied by the accesses below: indices is (batch..., nse, n_sparse),
  # so the array's axes split into n_batch batch dims, then n_sparse sparse
  # dims, then dense dims stored in data after its nse axis — confirm against
  # the BCOO format definition.
  n_sparse = indices.shape[-1]
  n_batch = indices.ndim - 2
  # Partition the squeezed axes by which component of the representation owns
  # them.  sparse_dims lists the SURVIVING sparse index columns; batch_dims
  # and dense_dims list the axes to squeeze out of data/indices.
  batch_dims = tuple(d for d in dimensions if d < n_batch)
  sparse_dims = np.array([i for i in range(n_sparse) if i + n_batch not in dimensions], dtype=int)
  # `d - n_sparse + 1` maps an array dense axis to its position in data,
  # i.e. past the n_batch batch dims and the single nse axis.
  dense_dims = tuple(d - n_sparse + 1 for d in dimensions if d >= n_batch + n_sparse)
  data_out = lax.squeeze(data, batch_dims + dense_dims)
  # Sparse axes are dropped by keeping only the surviving index columns.
  indices_out = lax.squeeze(indices[..., sparse_dims], batch_dims)
  out_shape = tuple(s for i, s in enumerate(arr.shape) if i not in dimensions)
  return (spenv.sparse(out_shape, data_out, indices_out),)
def squeeze_dependency_rule(outstart, outcount, operand, dimensions):
  """Dependency rule for lax.squeeze: map an output slice back to its input slice.

  The input start/shape are the output start/shape with a 0-offset, size-1
  entry re-inserted at every squeezed axis.  Returns the input slice spec,
  its count mask, and a function that squeezes an input slice back down to
  the corresponding output slice.
  """
  if not is_ones(outcount):
    raise NotImplementedError
  in_start = list(outstart)
  in_shape = list(outcount.shape)
  # Insert in ascending axis order so earlier insertions don't shift the
  # positions of later ones.
  for axis in sorted(dimensions):
    in_start.insert(axis, 0)
    in_shape.insert(axis, 1)
  def make_outslice(inslice):
    return lax.squeeze(inslice, dimensions)
  return ([(in_start, in_shape)], [Ones(in_shape)], make_outslice)
def testSqueeze(self, arg_shape, dimensions, bdims):
  """Batching coverage for lax.squeeze over the given axes and batch dims."""
  rng = jtu.rand_default(self.rng())
  dtype = np.float32
  def op(x):
    return lax.squeeze(x, dimensions)
  self._CheckBatching(op, 10, bdims, (arg_shape, ), (dtype, ), rng)
def testSqueeze(self, arg_shape, dimensions, bdims, rng_factory):
  """Batching coverage for lax.squeeze with a caller-supplied rng factory."""
  rng = rng_factory(self.rng())
  dtype = onp.float32
  def op(x):
    return lax.squeeze(x, dimensions)
  self._CheckBatching(op, 10, bdims, (arg_shape, ), (dtype, ), rng)
def test_squeeze(shape, dtype, dimensions, rng_factory):
  """Checks the lazy evaluation of lax.squeeze on one random operand."""
  rng = rng_factory(np.random)
  operand = rng(shape, dtype)
  tu.check_lazy_fun(lambda x: lax.squeeze(x, dimensions), operand)
def squeeze_op(self, params, inputs):
  """Remove this op's configured axis (which must have size 1) from `inputs`.

  `params` is accepted for interface uniformity but unused here.
  """
  squeeze_axes = (self.axis, )
  return lax.squeeze(inputs, squeeze_axes)