Example #1
0
    def test_vmap_after(self):
        """approx_max_k applied under vmap must exceed the recall target.

        The batch axis is the last axis (2) of both ``qy`` and ``db``. Ground
        truth comes from an exact ``lax.top_k`` over batched dot products.
        """
        batch = 4
        qy_size = 128
        db_size = 1024
        feature_dim = 32
        k = 10
        rng = jtu.rand_default(self.rng())
        qy = rng([qy_size, feature_dim, batch], np.float32)
        db = rng([db_size, feature_dim, batch], np.float32)
        recall = 0.95

        # Create ground truth. dot_general places batch dimensions first, so
        # gt_scores is [batch, qy_size, db_size] and gt_args is
        # [batch, qy_size, k]. Flattening the two leading axes yields one row
        # of k neighbor indices per (batch, query) pair.
        gt_scores = lax.dot_general(qy, db, (([1], [1]), ([2], [2])))
        _, gt_args = lax.top_k(gt_scores, k)
        gt_args = lax.reshape(gt_args, [batch * qy_size, k])

        # test target
        def approx_max_k(qy, db):
            scores = qy @ db.transpose()
            return lax.approx_max_k(scores, k)

        # vmap stacks outputs along axis 0 by default, so ann_args is also
        # [batch, qy_size, k] and lines up row-for-row with gt_args. No
        # transpose is needed. Moving k to the front before reshaping would
        # make each row mix the same rank from different queries instead of
        # holding one query's neighbor list.
        _, ann_args = jax.vmap(approx_max_k, (2, 2))(qy, db)
        ann_args = lax.reshape(ann_args, [batch * qy_size, k])
        ann_recall = compute_recall(np.asarray(ann_args), np.asarray(gt_args))
        self.assertGreater(ann_recall, recall)
Example #2
0
    def test_transpose(self):
        """lax.transpose keeps device placement: uncommitted in, uncommitted out;
        committed in, committed to the same device out (checked via the
        assert_*_to_device helpers)."""
        # devices[2] is used below. Without this guard the test dies with an
        # IndexError on hosts that have fewer than three devices.
        if jax.device_count() < 3:
            self.skipTest("test requires 3 devices")
        devices = self.get_devices()

        x = jnp.ones((2, 3))
        self.assert_uncommitted_to_device(x, devices[0])

        y = lax.transpose(x, (1, 0))
        self.assert_uncommitted_to_device(y, devices[0])
        z = lax.transpose(jax.device_put(x, devices[2]), (1, 0))
        self.assert_committed_to_device(z, devices[2])
Example #3
0
    def test_transpose(self):
        """Transposing preserves commitment: uncommitted stays uncommitted on
        the first device, and an operand committed to the third device yields
        a result committed there too."""
        # The committed case targets devices[2], hence the three-device minimum.
        if jax.device_count() < 3:
            self.skipTest("test requires 3 devices")
        devs = self.get_devices()

        ones = jnp.ones((2, 3))
        self.assert_uncommitted_to_device(ones, devs[0])

        swapped = lax.transpose(ones, (1, 0))
        self.assert_uncommitted_to_device(swapped, devs[0])

        committed_input = jax.device_put(ones, devs[2])
        swapped_committed = lax.transpose(committed_input, (1, 0))
        self.assert_committed_to_device(swapped_committed, devs[2])
Example #4
0
def transpose_dependency_rule(outstart, outcount, operand, permutation):
    """Map an output slice of ``lax.transpose`` back to the input slice it reads.

    Args:
      outstart: start index of the output slice, one entry per output axis.
      outcount: array whose ``.shape`` is the output slice shape, or a value
        recognised by ``is_ones``.
      operand: unused here.
      permutation: the transpose permutation (output axis i reads input axis
        ``permutation[i]``).

    Returns:
      A 3-tuple: a one-element list holding ``(input_start, input_shape)``, a
      one-element list holding the input-side counts, and a function that
      transposes an input slice into output layout.
    """
    inverse = np.argsort(permutation)
    in_start = np.take(outstart, inverse)
    in_shape = np.take(outcount.shape, inverse)

    if is_ones(outcount):
        in_count = Ones(in_shape)
    else:
        in_count = np.transpose(outcount, inverse)

    def to_output(inslice):
        return lax.transpose(inslice, permutation)

    return [(in_start, in_shape)], [in_count], to_output
Example #5
0
def _transpose_papply_rule(name, vals, dims, permutation):
    """papply rule for ``lax.transpose`` applied to a split (per-shard) operand.

    Args:
      name: axis name forwarded to ``pswapaxes`` when a collective swap is needed.
      vals: one-element sequence holding the local operand ``x``.
      dims: one-element sequence holding ``xdim``. This is presumably the
        operand axis the papply is split along (TODO confirm against the
        papply tracer).
      permutation: the transpose permutation of the unsplit operand.

    Returns:
      ``(x, xdim)``: the transformed local value and ``xdim`` unchanged.
    """
    x, = vals
    xdim, = dims
    perm = list(permutation)
    if perm[xdim] == xdim:
        # The split axis maps to itself, so only a local transpose is done.
        # NOTE(review): this branch applies the full-length `perm`, whereas the
        # else-branch first removes `xdim` from `perm`. If the local shard
        # lacks the split axis, the ranks here would disagree — confirm.
        x = lax.transpose(x, perm)
        out_dim = xdim
    else:
        # in_dim: the output position whose source is input axis `xdim`.
        in_dim, = [i for i in range(len(perm)) if perm[i] == xdim]
        # out_dim: the input axis that lands at output position `xdim`.
        out_dim = perm[xdim]
        # Rewrite the entries at positions in_dim and out_dim before dropping
        # the split axis.
        perm[in_dim] = out_dim
        perm[out_dim] = in_dim
        # Remove position `xdim` and renumber the remaining axes down by one
        # so that `perm` matches an operand without that axis.
        perm = perm[:xdim] + perm[xdim + 1:]
        perm = [i - 1 if i > xdim else i for i in perm]
        x = lax.transpose(x, perm)
        # Collective swap over `name`, involving local axis `in_dim`.
        x = pswapaxes(x, name, in_dim)
    # NOTE(review): `out_dim` is assigned in both branches but never used;
    # the rule always reports `xdim`. Confirm whether `out_dim` was intended.
    return x, xdim
Example #6
0
def _moveaxis(a, source: int, destination: int):
    """Move the single axis ``source`` of ``a`` to position ``destination``.

    A local, single-axis simplification of ``jnp.moveaxis``. Both axes may be
    negative and are canonicalised against ``np.ndim(a)``.
    """
    _check_arraylike("moveaxis", a)
    arr = _asarray(a)
    rank = np.ndim(arr)
    src = _canonicalize_axis(source, rank)
    dst = _canonicalize_axis(destination, rank)
    # Every other axis keeps its relative order; `src` is spliced in at `dst`.
    others = [axis for axis in range(rank) if axis != src]
    permutation = others[:dst] + [src] + others[dst:]
    return lax.transpose(arr, permutation)
Example #7
0
def _pswapaxes_serial_pmap_rule(vals, axes, axis):
    x, = vals
    axis_in, = axes
    if x.shape[axis_in] != x.shape[axis]:
        raise ValueError("pswapaxes between non-square dimensions")
    perm = list(range(x.ndim))
    perm[axis_in] = axis
    perm[axis] = axis_in
    return lax.transpose(x, perm), axis_in
Example #8
0
def _transpose_j(eqn: JaxprEqn, idx: int, invals: List[ShapedArray],
                 cts_in: ShapedArray) -> np.ndarray:
    """Dense Jacobian of a ``transpose`` equation with respect to input ``idx``.

    The result has shape ``cts_in.shape + invals[idx].shape``: output axes
    first, then input axes.
    """
    operand = invals[idx]
    rank = cts_in.ndim
    # _eye_like presumably builds an identity matching cts_in/operand sizes;
    # view it as operand.shape + operand.shape.
    jac = _eye_like(cts_in, operand).reshape(operand.shape * 2)
    # Permute only the leading (output) half; the trailing input axes stay put.
    trailing_axes = tuple(range(rank, 2 * rank))
    jac = lax.transpose(jac, eqn.params['permutation'] + trailing_axes)
    return jac.reshape(cts_in.shape + operand.shape)
Example #9
0
def moveaxis(a, source, destination):
    """Move axes of ``a`` to new positions, mirroring ``numpy.moveaxis``.

    Args:
      a: array-like operand.
      source: int or sequence of ints; original positions of the axes to move.
        Negative values count from the end.
      destination: int or sequence of ints; target positions for the
        corresponding ``source`` axes.

    Returns:
      ``lax.transpose(a, perm)``, where ``perm`` places each source axis at
      its destination and keeps the remaining axes in their original order.

    Raises:
      ValueError: if ``source`` and ``destination`` differ in length, if an
        axis lies outside ``[-ndim(a), ndim(a))``, or if an axis is repeated.
    """
    nd = ndim(a)

    def _canonicalize(axes, argname):
        # Bounds-check before wrapping. A bare onp.mod would silently turn
        # axis 5 on a 3-d array into axis 2.
        axes = onp.asarray(axes).reshape(-1)
        if onp.any((axes < -nd) | (axes >= nd)):
            raise ValueError("{} {} out of bounds for array of dimension {}".format(
                argname, axes.tolist(), nd))
        axes = onp.mod(axes, nd).tolist()
        # Repeated axes would yield a permutation of the wrong length or an
        # arbitrary ordering, so reject them as NumPy does.
        if len(set(axes)) != len(axes):
            raise ValueError("repeated axis in `{}` argument: {}".format(
                argname, axes))
        return axes

    source = _canonicalize(source, "source")
    destination = _canonicalize(destination, "destination")
    if len(source) != len(destination):
        raise ValueError("Inconsistent number of elements: {} vs {}".format(
            len(source), len(destination)))
    perm = [i for i in range(nd) if i not in source]
    # Insert in increasing destination order so that each insert lands at its
    # final index and does not shift positions already filled.
    for dest, src in sorted(zip(destination, source)):
        perm.insert(dest, src)
    return lax.transpose(a, perm)
Example #10
0
def _reshape_j(eqn: JaxprEqn, idx: int, invals: List[ShapedArray],
               cts_in: ShapedArray) -> np.ndarray:
    """Dense Jacobian of a ``reshape`` equation with respect to input ``idx``.

    Args:
      eqn: the reshape equation. Its optional ``dimensions`` param is a
        permutation applied to the operand before reshaping.
      idx: index of the input in ``invals``.
      invals: abstract values of the equation's inputs.
      cts_in: abstract value carrying the equation's output shape.

    Returns:
      An array of shape ``cts_in.shape + invals[idx].shape`` (output axes
      first, then input axes).
    """
    inval = invals[idx]
    # A reshape preserves the element count, so input and output share one
    # square identity. _eye_like presumably builds it — TODO confirm.
    j = _eye_like(inval, inval)
    j = j.reshape(inval.shape * 2)

    inval_dims = tuple(i + inval.ndim for i in range(inval.ndim))
    if eqn.params['dimensions'] is not None:
        # reshape with `dimensions` transposes the operand first, so permute
        # the leading (output) half of the Jacobian the same way. tuple()
        # guards against a list-valued param.
        j = lax.transpose(j, tuple(eqn.params['dimensions']) + inval_dims)
    # The leading half must take the reshape's *output* shape. Using
    # inval.shape here would only be correct when the reshape is a no-op.
    j = j.reshape(cts_in.shape + inval.shape)
    return j
Example #11
0
def _bcoo_dot_general_transpose(ct, lhs_data, lhs_indices, rhs, *, dimension_numbers, lhs_shape):
  """Transpose rule for ``bcoo_dot_general`` w.r.t. ``lhs_data`` or ``rhs``.

  ``lhs_indices`` must be a concrete (defined) value; this is asserted. If
  ``lhs_data`` is the undefined primal, its cotangent is computed. Otherwise
  the else-branch treats ``rhs`` as the linear input.

  Args:
    ct: cotangent of the output, possibly an ``ad.Zero`` instance.
    lhs_data, lhs_indices: sparse BCOO operand (data and indices).
    rhs: dense operand, possibly an undefined primal.
    dimension_numbers: ``((lhs_contract, rhs_contract), (lhs_batch, rhs_batch))``.
    lhs_shape: dense shape of the sparse lhs.

  Returns:
    A triple ``(lhs_data_ct_or_lhs_data, lhs_indices, rhs_or_rhs_ct)``. The
    entry for the non-linear operand is passed through unchanged. When ``ct``
    is an ``ad.Zero`` instance, the ``ad.Zero`` class itself is returned.
  """
  assert not ad.is_undefined_primal(lhs_indices)
  # NOTE(review): returns the ad.Zero *class*, not an instance. Presumably the
  # caller special-cases this sentinel — confirm.
  if type(ct) is ad.Zero:
    return ad.Zero
  (lhs_contract, rhs_contract), (lhs_batch, rhs_batch) = dimension_numbers
  lhs_ndim = len(lhs_shape)
  rhs_ndim = rhs.aval.ndim if ad.is_undefined_primal(rhs) else rhs.ndim
  # Presumably the axes that are neither contracted nor batched ("free" axes);
  # used as lists below.
  lhs_kept = remaining(range(lhs_ndim), lhs_contract, lhs_batch)
  rhs_kept = remaining(range(rhs_ndim), rhs_contract, rhs_batch)
  # Presumably the positions in `ct` of the batch, lhs-free and rhs-free axes,
  # in that order — TODO confirm ranges_like.
  ans_batch, ans_lhs, ans_rhs = ranges_like(lhs_batch, lhs_kept, rhs_kept)
  if ad.is_undefined_primal(lhs_data):
    # Contract ct's rhs-free axes against rhs's free axes, batching on batch axes.
    dims = ((ans_rhs, rhs_kept), (ans_batch, rhs_batch))
    # The dot_general output lists rhs's remaining (contracted) axes in rhs
    # order; these are the lhs contract axes matched to them.
    lhs_contract_sorted_by_rhs = list(np.take(lhs_contract, np.argsort(rhs_contract)))
    # TODO: extract these sparse indices without constructing the dense matrix.
    # argsort inverts the layout [batch, lhs_kept, contract] back to lhs axis order.
    out_axes = np.argsort(list(lhs_batch) + lhs_kept + lhs_contract_sorted_by_rhs)
    out_dense = lax.transpose(lax.dot_general(ct, rhs, dimension_numbers=dims), out_axes)
    # Keep only the entries at the sparse lhs's index positions.
    return bcoo_extract(lhs_indices, out_dense), lhs_indices, rhs
  else:
    # Contract the lhs free axes against ct's lhs-free axes, batching on batch axes.
    dims = ((lhs_kept, ans_lhs), (lhs_batch, ans_batch))
    rhs_contract_sorted_by_lhs = list(np.take(rhs_contract, np.argsort(lhs_contract)))
    # argsort inverts the layout [batch, contract, rhs_kept] back to rhs axis order.
    out_axes = np.argsort(list(rhs_batch) + rhs_contract_sorted_by_lhs + rhs_kept)
    result = bcoo_dot_general(lhs_data, lhs_indices, ct, lhs_shape=lhs_shape, dimension_numbers=dims)
    return lhs_data, lhs_indices, lax.transpose(result, out_axes)
Example #12
0
def test_conv_general_dilated(lhs_shape, rhs_shape, dtype, strides, padding,
                              lhs_dilation, rhs_dilation, feature_group_count,
                              batch_group_count, dimension_numbers, perms,
                              rng_factory):
    """Run ``lax.conv_general_dilated`` through ``tu.check_lazy_fun``.

    Random operands are first permuted by ``perms`` into compatible shapes.
    """
    rng = rng_factory(np.random)
    lhs_perm, rhs_perm = perms  # permute to compatible shapes
    lhs_operand = lax.transpose(rng(lhs_shape, dtype), lhs_perm)
    rhs_operand = lax.transpose(rng(rhs_shape, dtype), rhs_perm)

    def conv(lhs, rhs):
        return lax.conv_general_dilated(
            lhs, rhs, strides, padding, lhs_dilation, rhs_dilation,
            dimension_numbers,
            feature_group_count=feature_group_count,
            batch_group_count=batch_group_count)

    tu.check_lazy_fun(conv, lhs_operand, rhs_operand, rtol=.005, atol=.2)
Example #13
0
def broadcast_to(arr, shape):
  """Broadcast ``arr`` to ``shape`` like NumPy's ``broadcast_to``.

  Unlike NumPy, the result is not guaranteed to be a view of ``arr``. When the
  shapes already match, ``arr`` itself is returned.
  """
  if not (isinstance(arr, ndarray) or isscalar(arr)):
    arr = array(arr)
  arr_shape = _shape(arr)
  if arr_shape == shape:
    return arr

  # TODO(mattjj): revise this to call lax.broadcast_in_dim rather than
  # lax.broadcast and lax.transpose
  _broadcast_shapes(shape, arr_shape)  # error checking
  num_leading = len(shape) - len(arr_shape)
  # Trailing axes whose size differs from the target are the ones to stretch.
  stretched, = onp.where(onp.not_equal(shape[num_leading:], arr_shape))

  new_axes = tuple(range(num_leading)) + tuple(num_leading + stretched)
  old_axes = tuple(onp.delete(onp.arange(len(shape)), new_axes))
  # lax.broadcast prepends new axes; this permutation puts every axis back
  # in its target position.
  perm = onp.argsort(new_axes + old_axes)

  squeezed = squeeze(arr, stretched)
  broadcasted = lax.broadcast(squeezed, onp.take(shape, new_axes))
  return lax.transpose(broadcasted, perm)
Example #14
0
 def test_transpose(self):
   """Check lax.transpose with permutation (1, 0, 2) via self.check.

   Input spec '(a, b, c)', expected output spec 'b, a, c'.
   """
   swap_leading = lambda x: lax.transpose(x, (1, 0, 2))
   rng = jtu.rand_default(self.rng())
   self.check(swap_leading, ['(a, b, c)'], 'b, a, c', {'a': 2, 'b': 3, 'c': 4},
              [(3, 4, 5)], ['float_'], rng)
Example #15
0
 def testTranspose(self, shape, dtype, perm, bdims):
     """Batching check for lax.transpose with the parameterized ``perm``."""
     self._CheckBatching(lambda x: lax.transpose(x, perm), 5, bdims,
                         (shape,), (dtype,), jtu.rand_default(self.rng()))
Example #16
0
 def fun(x):
     # `perm` is resolved from an enclosing or module scope not shown here.
     return lax.transpose(x, permutation=perm)
Example #17
0
 def testTransposeGrad(self, shape, dtype, perm, rng_factory):
   """Check lax.transpose gradients via check_grads (order 2, fwd and rev)."""
   sample = rng_factory(self.rng())
   operand = sample(shape, dtype)
   check_grads(lambda x: lax.transpose(x, perm), (operand,), 2,
               ["fwd", "rev"], eps=1.)
Example #18
0
 def testTransposeGrad(self, shape, dtype, perm):
   """Check lax.transpose gradients via check_grads (order 2, fwd and rev)."""
   def transpose(x):
     return lax.transpose(x, perm)
   operand = jtu.rand_default(self.rng())(shape, dtype)
   check_grads(transpose, (operand,), 2, ["fwd", "rev"], eps=1.)
Example #19
0
 def testTranspose(self, shape, dtype, perm, bdims, rng_factory):
     """Batching check for lax.transpose with the parameterized ``perm``."""
     sampler = rng_factory(self.rng())

     def op(x):
         return lax.transpose(x, perm)

     self._CheckBatching(op, 5, bdims, (shape,), (dtype,), sampler)
Example #20
0
def test_transpose(shape, dtype, permutation, rng_factory):
    """Run lax.transpose with ``permutation`` through ``tu.check_lazy_fun``."""
    arg = rng_factory(np.random)(shape, dtype)

    def transpose_fn(x):
        return lax.transpose(x, permutation=permutation)

    tu.check_lazy_fun(transpose_fn, arg)
Example #21
0
def _matrix_transpose(ndarray):
    dims = tuple(range(ndarray.ndim))
    dims = dims[:-2] + (dims[-1], dims[-2])
    return lax.transpose(ndarray, dims)
Example #22
0
def swapaxes(a, axis1, axis2):
    """Return ``a`` with axes ``axis1`` and ``axis2`` interchanged.

    Negative axes index from the end, as with NumPy array indexing.
    """
    perm = onp.arange(ndim(a))
    # Fancy indexing reads both source entries before writing, so this swaps.
    perm[[axis1, axis2]] = perm[[axis2, axis1]]
    return lax.transpose(a, perm)
Example #23
0
def transpose(x, axis=None):
    """Permute the axes of ``x``; by default reverse them, like ``numpy.transpose``.

    ``axis`` is the permutation to apply (NumPy calls this ``axes``).
    """
    if axis is None:
        rank = ndim(x)
        axis = onp.arange(rank - 1, -1, -1)
    return lax.transpose(x, axis)