Example #1
0
 def test_lax_slice(self):
   """Shape-polymorphic checks for `lax.slice` over a symbolic length `n`."""
   def drop_first(x):
     # Elements 1..n-1: a length-n input yields n - 1 elements.
     return lax.slice(x, (1,), (x.shape[0],))

   def stride_by_length(x):
     # A stride equal to the full length keeps only element 0: shape '1'.
     return lax.slice(x, (0,), (x.shape[0],), (x.shape[0],))

   self.check(drop_first, ['n'], 'n+-1', {'n': 2}, [(3,)], ['float_'],
              jtu.rand_default(self.rng()))
   # TODO: the upper-half slicing case below is currently disabled:
   # self.check(lambda x: lax.slice(x, (x.shape[0] // 2,), (x.shape[0],)),
   #            ['2*n'], 'n', {'n': 2}, [(6,)], ['float_'],
   #            jtu.rand_default(self.rng()))
   self.check(stride_by_length, ['n'], '1', {'n': 2}, [(5,)], ['float_'],
              jtu.rand_default(self.rng()))
Example #2
0
 def test_lax_slice(self):
   """Shape-polymorphic checks for `lax.slice` (currently skipped)."""
   raise SkipTest("Failing after fixing Poly unsoundness #4878")

   # Everything below is unreachable while the skip above is in place.
   def drop_first(x):
     # Elements 1..n-1: a length-n input yields n - 1 elements.
     return lax.slice(x, (1,), (x.shape[0],))

   def stride_by_length(x):
     # A stride equal to the full length keeps only element 0: shape '1'.
     return lax.slice(x, (0,), (x.shape[0],), (x.shape[0],))

   self.check(drop_first, ['n'], 'n+-1', {'n': 2}, [(3,)], ['float_'],
              jtu.rand_default(self.rng()))
   # TODO: the upper-half slicing case below is currently disabled:
   # self.check(lambda x: lax.slice(x, (x.shape[0] // 2,), (x.shape[0],)),
   #            ['2*n'], 'n', {'n': 2}, [(6,)], ['float_'],
   #            jtu.rand_default(self.rng()))
   self.check(stride_by_length, ['n'], '1', {'n': 2}, [(5,)], ['float_'],
              jtu.rand_default(self.rng()))
Example #3
0
class JAX2DexTest(unittest.TestCase):
    """Tests for the JAX-to-Dex path (see ``dexjit`` in the last test).

    Most tests are generated by ``lax_test(fn, make_args)``, where
    ``make_args`` is a zero-argument callable returning the argument tuple.
    NOTE(review): presumably ``lax_test`` runs ``fn`` both through Dex and
    through JAX and compares the results; confirm against its definition.
    """
    # Elementwise primitives on random inputs.
    test_sin = lax_test(lax.sin, lambda: (rn(10, 10), ))
    test_cos = lax_test(lax.cos, lambda: (rn(10, 10), ))
    test_neg = lax_test(lax.neg, lambda: (rn(10, 10), ))
    test_log = lax_test(lax.log, lambda: (rn(10, 10), ))
    test_exp = lax_test(lax.exp, lambda: (rn(10, 10), ))
    test_pow = lax_test(lax.pow, lambda:
                        (rn(10), jnp.arange(10, dtype=np.float32)))
    test_integer_pow = lax_test(lambda x: lax.integer_pow(x, 2), lambda:
                                (rn(10, 10), ))
    # select driven by a scalar predicate computed from a Python float.
    test_scalar_select_lt = lax_test(lambda i, x, y: lax.select(i < 2.0, x, y),
                                     lambda: (1.0, rn(10), rn(10)))

    # squeeze removing zero, one, two, and all size-1 axes.
    test_squeeze_none = lax_test(lambda x: lax.squeeze(x, []), lambda:
                                 (rn(10, 10), ))
    test_squeeze_one = lax_test(lambda x: lax.squeeze(x, [1]), lambda:
                                (rn(10, 1, 10), ))
    test_squeeze_two = lax_test(lambda x: lax.squeeze(x, [0, 2]), lambda:
                                (rn(1, 10, 1), ))
    test_squeeze_all = lax_test(lambda x: lax.squeeze(x, [0, 1]), lambda:
                                (rn(1, 1), ))

    # Unstrided slices (strides=None) in 1-D and 3-D.
    test_slice_1d = lax_test(lambda x: lax.slice(x, [2], [5], None), lambda:
                             (rn(10), ))
    test_slice_3d = lax_test(
        lambda x: lax.slice(x, [2, 0, 0], [5, 10, 2], None), lambda:
        (rn(10, 10, 2), ))

    # Concatenation along axis 0 with equal and with unequal leading sizes.
    test_concat_uniform = lax_test(partial(lax.concatenate, dimension=0),
                                   lambda: ([rn(4, 2) for _ in range(3)], ))
    test_concat_ragged = lax_test(
        partial(lax.concatenate, dimension=0), lambda:
        ([rn(1, 2, 4), rn(5, 2, 4), rn(2, 2, 4)], ))

    # dot_general contracting lhs axis 1 with rhs axis 0, no batch dims.
    test_dot_general_matmul = lax_test(
        partial(lax.dot_general, dimension_numbers=(((1, ), (0, )), ((), ()))),
        lambda: (rn(4, 8), rn(8, 16)))
    test_dot_general_matvec = lax_test(
        partial(lax.dot_general, dimension_numbers=(((1, ), (0, )), ((), ()))),
        lambda: (rn(4, 8), rn(8)))

    def test_canonicalize_dtype(self):
        """dexjit must agree with jax.jit in both values and result dtype."""
        # f closes over a float64 constant; both paths see the same inputs.
        c = np.arange(5, dtype=np.float64)
        f = lambda x: x * c
        x = np.ones(5, dtype=np.float64)
        dy = dexjit(f)(x)
        jy = jax.jit(f)(x)
        np.testing.assert_allclose(dy, jy)
        self.assertEqual(dy.dtype, jy.dtype)
Example #4
0
def split(ary, indices_or_sections, axis=0):
  """Split `ary` into consecutive sub-arrays along `axis`, like `numpy.split`.

  Sub-array boundaries are found by letting NumPy split a zero-stride dummy
  array of the same shape, so `indices_or_sections` accepts whatever
  `numpy.split` accepts. The data itself is then cut with `lax.slice`.
  """
  # A broadcast scalar has zero strides; only its shape matters here.
  template = onp.broadcast_to(0, ary.shape)
  pieces = onp.split(template, indices_or_sections, axis)
  bounds = onp.cumsum([0] + [onp.shape(piece)[axis] for piece in pieces])
  begin, end = [0] * ndim(ary), shape(ary)

  def with_axis(seq, value):
    # Copy of `seq` with the entry at position `axis` replaced by `value`.
    return lax.subvals(seq, [(axis, value)])

  return [lax.slice(ary, with_axis(begin, lo), with_axis(end, hi))
          for lo, hi in zip(bounds[:-1], bounds[1:])]
Example #5
0
  def testSliceLax(self):
    """vmap of a 1-D `lax.slice` equals slicing every row of a 2-D array."""
    x = np.random.RandomState(0).randn(5, 10)
    expected_ans = x[:, 2:4]
    ans = vmap(lambda row: lax.slice(row, (2,), (4,)))(x)
    self.assertAllClose(ans, expected_ans, check_dtypes=False)
Example #6
0
def _matrix_take(ndarray, idx, block_size=1):
    """Similar to numpy.take using LAX operations.

    Selects from the last two (matrix) dimensions of ``ndarray``; any leading
    dimensions are kept whole.

    Args:
      ndarray: array with at least two dimensions.
      idx: a pair ``(idx_i, idx_j)`` of row and column indexers.
      block_size: forwarded to ``_canonical_idx`` for both axes.

    Returns:
      The selected sub-array, with the row and/or column axis reversed when
      ``_canonical_idx`` reports it.
    """
    idx_i, idx_j = idx
    sli, row_rev = _canonical_idx(ndarray.shape, idx_i, -2, block_size)
    slj, col_rev = _canonical_idx(ndarray.shape, idx_j, -1, block_size)

    batch_ndim = ndarray.ndim - 2
    start_indices = [0] * batch_ndim + [sli.start, slj.start]
    limit_indices = list(ndarray.shape[:-2]) + [sli.stop, slj.stop]
    strides = [1] * batch_ndim + [sli.step, slj.step]
    out = lax.slice(ndarray, start_indices, limit_indices, strides)

    # The row/column axes are the LAST two axes of `out`. lax.rev takes
    # absolute axis numbers, so use batch_ndim / batch_ndim + 1; the previous
    # onp.where([row_rev, col_rev]) produced 0/1, which reversed batch axes
    # whenever ndarray.ndim > 2.
    rev_dims = [axis for axis, flag in ((batch_ndim, row_rev),
                                        (batch_ndim + 1, col_rev)) if flag]
    if rev_dims:
        out = lax.rev(out, rev_dims)
    return out
Example #7
0
def slice_dependency_rule(outstart, outcount, operand, start_indices,
                          limit_indices, strides):
    """Dependency rule for ``lax.slice``.

    Given an output box (start ``outstart``; only the *shape* of ``outcount``
    is used), returns ``([inbox], [count], to_outslice)``:

    * ``inbox`` is ``(start, size)`` of the input region that box reads;
    * ``count`` is an int array of shape ``size`` with 1 where an input
      element is read (every element when unstrided, every ``strides``-th
      one otherwise) and 0 elsewhere;
    * ``to_outslice`` maps the input region to the output values.

    ``operand`` and ``limit_indices`` are accepted but unused.
    """
    shape_out = np.asarray(outcount.shape)

    if strides is not None:
        step = np.asarray(strides)
        box = (start_indices + outstart * step, (shape_out - 1) * step + 1)
        mask = np.zeros(box[1], int)
        mask[tuple(slice(None, None, s) for s in step)] = 1

        def to_outslice(inslice):
            # Keep every `step`-th element of the fetched input region.
            return lax.slice(inslice, np.zeros_like(step), inslice.shape, step)

        return [box], [mask], to_outslice

    # Unstrided: the input region is the output box shifted by start_indices,
    # and it is already exactly the output.
    box = start_indices + outstart, shape_out
    return [box], [np.ones(box[1], int)], lambda inslice: inslice
Example #8
0
def onnx_split(x, split=None, axis=0, n_out=None):
    """Split ``x`` along ``axis`` into consecutive chunks (ONNX ``Split``).

    Args:
      x: the array to split.
      split: optional sequence of chunk lengths along ``axis``. When omitted,
        ``x.shape[axis]`` is divided into ``n_out`` chunks of length
        ``x.shape[axis] // n_out``, and the last chunk absorbs any remainder.
      axis: the axis to split along.
      n_out: number of output chunks. Defaults to ``len(split)`` when
        ``split`` is given; required otherwise.

    Returns:
      A list of ``n_out`` arrays. The last chunk always extends to the end of
      ``axis``.

    Raises:
      ValueError: if neither ``split`` nor ``n_out`` is given.
    """
    if split is None:
        if n_out is None:
            raise ValueError("onnx_split requires either `split` or `n_out`")
        split = [x.shape[axis] // n_out] * n_out
    elif n_out is None:
        # Infer the chunk count from `split`; without this the chunk loop
        # below was driven by range(1, None) and raised TypeError.
        n_out = len(split)

    # Chunk boundaries along `axis`: 0, the running sums of `split`, and the
    # full extent (so the final chunk always reaches the end of the axis).
    edges = ([0] + [sum(split[:idx]) for idx in range(1, n_out)]
             + [x.shape[axis]])

    chunks = []
    for lo, hi in zip(edges[:-1], edges[1:]):
        start = [0] * x.ndim
        limit = list(x.shape)
        start[axis] = lo
        limit[axis] = hi
        chunks.append(lax.slice(x, start, limit))
    return chunks
Example #9
0
def _update_slice(operand, update, start_indices, update_dims):
    """Write a padded ``update`` into ``operand`` at ``start_indices``.

    Similar to ``lax.dynamic_update_slice``, but handles padded updates where
    padding values should not overwrite existing values in the array.

    Args:
      operand: the array to update.
      update: the padded array to write.
      start_indices: the offset at which to write ``update``.
      update_dims: the true dimensions of the padded update ``update``. Only
        values inside the rectangle given by ``update_dims`` are overwritten.

    Returns:
      An array with the shape of ``operand`` holding the merged values.
    """
    original_shape = operand.shape
    # Zero-pad the high end of every axis by update's extent, so a window of
    # update.shape starting inside the original bounds always fits.
    # NOTE(review): presumably this keeps dynamic_slice/dynamic_update_slice
    # from clamping (shifting) the window near the edges.
    padding = [(0, extent, 0) for extent in update.shape]
    padded = lax.pad(operand, jnp.array(0, operand.dtype), padding)
    offsets = tuple(map(jnp.int32, start_indices))
    # Merge: per the contract above, `update` wins inside update_dims and the
    # existing window contents are kept elsewhere.
    window = lax.dynamic_slice(padded, offsets, update.shape)
    window = _mask(update, update_dims, window)
    padded = lax.dynamic_update_slice(padded, window, offsets)
    # Drop the padding again.
    return lax.slice(padded, [0] * padded.ndim, original_shape)
Example #10
0
 def thunk():
     """Compute the requested ``box`` of ``arr`` and store it in ``arr.cache``.

     Reads each input's region (``inboxes``) from ``invals``, applies
     ``outslice_from_inslices`` and writes the result into ``arr.cache`` at
     the box's start offset, then marks the box KNOWN in ``arr.state``. All
     of these names come from the enclosing scope.

     Raises:
       NotImplementedError: for primitives with multiple results.
       AssertionError: if any element of the box is not REQUESTED, i.e. it
         would be computed again. Note: asserts are stripped under -O.
     """
     if primitive.multiple_results:
         raise NotImplementedError
     # TODO (j-towns): add option to disable this assert statement
     assert np.all(getbox(arr.state, box) == REQUESTED), \
       'Repeated computation detected'
     # Materialize inputs: LazyArrays contribute their cache, anything else
     # goes through jnp.asarray.
     invals_ = [
         val.cache if isinstance(val, LazyArray) else jnp.asarray(val)
         for val in invals
     ]
     # Cut each input to its box: ibox[0] is the start, ibox[1] the extent,
     # so the limit is start + extent. Inputs without a box pass None.
     # `count` is unpacked but not used here.
     inslices = [
         None if ibox is None else lax.slice(inval, ibox[0],
                                             np.add(ibox[0], ibox[1]))
         for inval, ibox, count in zip(invals_, inboxes, counts)
     ]
     outslice = outslice_from_inslices(*inslices)
     # Only the start offset of the output box is needed for the write.
     outstart, _ = box
     arr.cache = inplace_dynamic_update_slice(arr.cache, outslice, outstart)
     setbox(arr.state, box, KNOWN)
Example #11
0
def _use_qr(u, m, n, params):
    """QDWH iteration using QR decomposition.

    Args:
      u: a matrix, with static (padded) shape M x N.
      m, n: the dynamic shape of the matrix, where m <= M and n <= N.
      params: the QDWH parameters ``(a, b, c)``.

    Returns:
      The next iterate, with the same static shape as ``u``.
    """
    a, b, c = params
    rows, cols = u.shape
    sqrt_c = jnp.sqrt(c)

    # Stack sqrt(c) * u over an N x N identity (presumably joined at dynamic
    # row m; see _dynamic_concat) and take its thin QR factorization.
    stacked = _dynamic_concat(sqrt_c * u, jnp.eye(cols, dtype=jnp.dtype(u)), m)
    q, _ = lax_linalg.qr(stacked, full_matrices=False)

    # Upper factor, q[:m, :], masked to the live (m, n) region.
    q_top = _mask(lax.slice(q, (0, 0), (rows, cols)), (m, n))
    # Lower factor, (q[m:, :]).T.conj(), masked to (n, n).
    q_bottom = lax.dynamic_slice_in_dim(q, m, cols, axis=0)
    q_bottom = _mask(q_bottom, (n, n)).T.conj()

    ratio = b / c
    return (ratio * u
            + (a - ratio) / sqrt_c * jnp.einsum('ij,jk->ik', q_top, q_bottom))
Example #12
0
def test_slice(shape, dtype, starts, limits, strides, rng_factory):
    """Check `lax.slice` via `tu.check_lazy_fun` on one random operand."""
    operand = rng_factory(np.random)(shape, dtype)

    def op(x):
        return lax.slice(x, starts, limits, strides)

    tu.check_lazy_fun(op, operand)
Example #13
0
 def testSlice(self, shape, dtype, starts, limits, strides, bdims):
     """Check the batching rule of `lax.slice` for the given batch dims."""
     def op(x):
         return lax.slice(x, starts, limits, strides)

     self._CheckBatching(op, 5, bdims, (shape, ), (dtype, ),
                         jtu.rand_default(self.rng()))
Example #14
0
 def fixed_point(x):
   # Row 0: a leading 1 followed by twice x[0, 0:3].
   head = np.array([[1.]])
   doubled = 2 * lax.slice(x, [0, 0], [1, 3])
   first_row = np.concatenate([head, doubled], -1)
   # Row 1: columns 2..3 of row 0 followed by the constants 3 and 4.
   tail = lax.slice(first_row, [0, 2], [1, 4])
   second_row = np.concatenate([tail, np.array([[3., 4]])], -1)
   return np.concatenate([first_row, second_row], 0)
Example #15
0
    # Entries: _make_harness(name, variant, fn, args, poly_axes=...).
    # NOTE(review): poly_axes presumably lists, per argument, the axis treated
    # as shape-polymorphic (None = fully static); confirm in _make_harness.
    # Squeeze the trailing size-1 axis of a (4, 1) input.
    _make_harness("jnp_squeeze", "axis=1",
                  jnp.squeeze,
                  [RandArg((4, 1), _f32), StaticArg((1,))],
                  poly_axes=[0]),

    # scatter_add at indices 1 and 2 along operand axis 1
    # (scatter_dims_to_operand_dims=(1,)), with (7, 2) updates.
    _make_harness("scatter_add", "",
                  partial(lax.scatter_add, indices_are_sorted=False, unique_indices=True),
                  [RandArg((7, 4), _f32),
                   np.array([[1], [2]], np.int32),  # indices
                   RandArg((7, 2), _f32),  # updates
                   StaticArg(lax.ScatterDimensionNumbers((0,), (1,), (1,)))],
                  poly_axes=[0, None, 0]),

    # Slice keeping all of axis 0 (limit is x.shape[0]) and columns 1..2.
    _make_harness("slice", "entire_axis",
                  lambda x: lax.slice(x, start_indices=(0, 1), limit_indices=(x.shape[0], 3)),
                  [RandArg((7, 3), _f32)],
                  poly_axes=[0]),

    # select with the same array on both branches.
    _make_harness("select", "0",
                  # x.shape = (b, 3)
                  lambda x: lax.select(x > 5., x, x),
                  [RandArg((7, 3), _f32)],
                  poly_axes=[0]),

    # select under vmap, with y unbatched (in_axes=[0, None]).
    _make_harness("select", "1",
                  # x.shape = (b, 3); y.shape = (3,)
                  jax.vmap(lambda x, y: lax.select(x > 5., x, y), in_axes=[0, None]),
                  [RandArg((7, 3), _f32), RandArg((3,), _f32)],
                  poly_axes=[0, None]),
Example #16
0
 def f_jax(x):  # x.shape = (b, 3)
     """Keep every row of `x` and only columns 1 and 2."""
     limit = (x.shape[0], 3)
     return lax.slice(x, (0, 1), limit)
Example #17
0
 def test_lax_slice(self):
     """Dropping the first element maps symbolic shape ('n',) to 'n+-1'."""
     def drop_first(x):
         return lax.slice(x, (1, ), (x.shape[0], ))

     self.check(drop_first, ['n'], 'n+-1', {'n': 2}, [(3, )], ['float_'],
                jtu.rand_default(self.rng()))
Example #18
0
 def testSliceGrad(self, shape, dtype, starts, limits, strides, rng_factory):
   """Numerically check fwd- and rev-mode derivatives of `lax.slice`.

   The `2` passed to check_grads is the derivative order to check.
   """
   rng = rng_factory(self.rng())
   operand = rng(shape, dtype)
   # Named `slice_op` rather than `slice` to avoid shadowing the builtin.
   slice_op = lambda x: lax.slice(x, starts, limits, strides)
   check_grads(slice_op, (operand,), 2, ["fwd", "rev"], eps=1.)
Example #19
0
 def testSlice(self, shape, dtype, starts, limits, strides, bdims,
               rng_factory):
     """Check the batching rule of `lax.slice` with a supplied RNG factory."""
     rng = rng_factory(self.rng())

     def op(x):
         return lax.slice(x, starts, limits, strides)

     self._CheckBatching(op, 5, bdims, (shape, ), (dtype, ), rng)
Example #20
0
 def fixed_point(x):
   # A leading 1 followed by twice the first three entries of x.
   doubled = 2 * lax.slice(x, [0], [3])
   return np.concatenate([np.array([1.]), doubled])