Example #1
def _add_sparse(spenv, *argspecs):
    X, Y = argspecs
    if X.is_sparse() and Y.is_sparse():
        if X.shape != Y.shape:
            raise NotImplementedError(
                "Addition between sparse matrices of different shapes.")
        if X.indices_ref == Y.indices_ref:
            out_data = lax.add(X.data(spenv), Y.data(spenv))
            out_argspec = ArgSpec(X.shape, spenv.push(out_data), X.indices_ref)
        elif X.indices(spenv).ndim != Y.indices(spenv).ndim or X.data(
                spenv).ndim != Y.data(spenv).ndim:
            raise NotImplementedError(
                "Addition between sparse matrices with different batch/dense dimensions."
            )
        else:
            out_indices = lax.concatenate(
                [X.indices(spenv), Y.indices(spenv)],
                dimension=X.indices(spenv).ndim - 1)
            out_data = lax.concatenate(
                [X.data(spenv), Y.data(spenv)],
                dimension=X.indices(spenv).ndim - 2)
            out_argspec = ArgSpec(X.shape, spenv.push(out_data),
                                  spenv.push(out_indices))
    else:
        raise NotImplementedError("Addition between sparse and dense matrix.")

    return (out_argspec, )
Example #2
def _add_sparse(spenv, *spvalues):
    X, Y = spvalues
    if X.is_sparse() and Y.is_sparse():
        if X.shape != Y.shape:
            raise NotImplementedError(
                "Addition between sparse matrices of different shapes.")
        if X.indices_ref == Y.indices_ref:
            out_data = lax.add(spenv.data(X), spenv.data(Y))
            if config.jax_enable_checks:
                assert X.indices_sorted == Y.indices_sorted
                assert X.unique_indices == Y.unique_indices
            out_spvalue = spenv.sparse(X.shape,
                                       out_data,
                                       indices_ref=X.indices_ref,
                                       indices_sorted=X.indices_sorted,
                                       unique_indices=X.unique_indices)
        elif spenv.indices(X).ndim != spenv.indices(Y).ndim or spenv.data(
                X).ndim != spenv.data(Y).ndim:
            raise NotImplementedError(
                "Addition between sparse matrices with different batch/dense dimensions."
            )
        else:
            out_indices = lax.concatenate(
                [spenv.indices(X), spenv.indices(Y)],
                dimension=spenv.indices(X).ndim - 2)
            out_data = lax.concatenate(
                [spenv.data(X), spenv.data(Y)],
                dimension=spenv.indices(X).ndim - 2)
            out_spvalue = spenv.sparse(X.shape, out_data, out_indices)
    else:
        raise NotImplementedError("Addition between sparse and dense array.")

    return (out_spvalue, )
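A minimal NumPy sketch of the fallback branch above (the `_todense` helper is illustrative, not part of the JAX sources): when two COO-style operands do not share an index buffer, their (indices, data) pairs are simply concatenated along the nse axis, and any duplicate indices accumulate when the result is densified.

import numpy as np

def _todense(shape, indices, data):
    # indices: (nse, ndim) integer array; duplicate entries are summed.
    out = np.zeros(shape)
    np.add.at(out, tuple(indices.T), data)
    return out

x_idx, x_dat = np.array([[0, 0], [1, 2]]), np.array([1.0, 2.0])
y_idx, y_dat = np.array([[0, 0], [2, 1]]), np.array([5.0, 7.0])

sum_idx = np.concatenate([x_idx, y_idx], axis=0)
sum_dat = np.concatenate([x_dat, y_dat], axis=0)

assert np.allclose(_todense((3, 3), sum_idx, sum_dat),
                   _todense((3, 3), x_idx, x_dat) + _todense((3, 3), y_idx, y_dat))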
Example #3
def testConcatenateGrad(self, dim, base_shape, dtype, num_arrs, rng_factory):
    rng = rng_factory(self.rng())
    shapes = [base_shape[:dim] + (size,) + base_shape[dim+1:]
              for size, _ in zip(itertools.cycle([3, 1, 4]), range(num_arrs))]
    operands = tuple(rng(shape, dtype) for shape in shapes)
    concatenate = lambda *args: lax.concatenate(args, dim)
    check_grads(concatenate, operands, 2, ["fwd", "rev"], eps=1.)
Example #4
def test_concatenate(self):
    self.check(lambda x, y, z: lax.concatenate([x, y, z], 0),
               ['n', 'm', 'n'], 'm + 2 * n', {'n': 2, 'm': 3},
               [(4,), (3,), (4,)], ['float_', 'float_', 'float_'],
               jtu.rand_default(self.rng()))
Example #5
def _dct_ortho_norm(out, axis):
    factor = lax.concatenate([
        lax.full((1, ), 4, out.dtype),
        lax.full((out.shape[axis] - 1, ), 2, out.dtype)
    ], 0)
    factor = lax.expand_dims(factor, [a for a in range(out.ndim) if a != axis])
    return out / lax.sqrt(factor * out.shape[axis])
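A quick check of the scaling built above (plain NumPy, for illustration): for an axis of length N the factor vector is [4, 2, ..., 2], so the k = 0 coefficient is divided by sqrt(4*N) and every other coefficient by sqrt(2*N), matching the "ortho" normalization of a type-II DCT.

import numpy as np

N = 5
factor = np.concatenate([np.full((1,), 4.0), np.full((N - 1,), 2.0)])
print(np.sqrt(factor * N))   # approx. [4.472 3.162 3.162 3.162 3.162]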
Example #6
def concatenate_dependency_rule(outstart, outcount, *operands, dimension):
    if not is_ones(outcount):
        raise NotImplementedError
    dim = dimension
    outstart, outshape = list(outstart), list(outcount.shape)
    dimstart, dimshape = outstart[dim], outshape[dim]
    position = 0
    inboxes = []
    incounts = []
    for operand in operands:
        shape = operand.shape
        if dimstart < position + shape[dim] and position < dimstart + dimshape:
            instart = (outstart[:dim] + [max(0, dimstart - position)] +
                       outstart[dim + 1:])
            inshape = (outshape[:dim] + [
                min(dimstart + dimshape - position, shape[dim], dimshape,
                    position + shape[dim] - instart[dim])
            ] + outshape[dim + 1:])
            inboxes.append((instart, inshape))
            incounts.append(Ones(inshape))
        else:
            inboxes.append(None)
            incounts.append(None)
        position += shape[dim]

    return inboxes, incounts, lambda *inslices: lax.concatenate(
        [x for x in inslices if x is not None], dimension)
Example #7
  def testConcatenate(self):
    R = lambda *shape: np.random.RandomState(0).randn(*shape).astype(np.float32)

    fun = lambda *args: lax.concatenate(args, dimension=0)
    x, y, z = R(10, 2, 3), R(1, 10, 3), R(4, 3)
    ans = vmap(fun, in_axes=(0, 1, None))(x, y, z)
    expected_ans = np.concatenate([x, np.swapaxes(y, 0, 1),
                                    np.broadcast_to(z, (10, 4, 3))], 1)
    self.assertAllClose(ans, expected_ans, check_dtypes=False)

    fun = lambda *args: lax.concatenate(args, dimension=1)
    x, y, z = R(10, 2, 1), R(2, 3), R(2, 4, 10)
    ans = vmap(fun, in_axes=(0, None, 2))(x, y, z)
    expected_ans = np.concatenate([x, np.broadcast_to(y, (10, 2, 3)),
                                    np.moveaxis(z, 2, 0)], 2)
    self.assertAllClose(ans, expected_ans, check_dtypes=False)
Example #8
def threefry_seed(seed: int) -> jnp.ndarray:
    """Create a single raw threefry PRNG key given an integer seed.

  Args:
    seed: a 64- or 32-bit integer used as the value of the key.

  Returns:
    The PRNG key contents, modeled as an array of shape (2,) and dtype
    uint32. The key is constructed from a 64-bit seed by effectively
    bit-casting to a pair of uint32 values (or from a 32-bit seed by
    first padding out with zeros).
  """
    # Avoid OverflowError in X32 mode by first converting ints to int64.
    # This breaks JIT invariance for large ints, but supports the common
    # use-case of instantiating with Python hashes in X32 mode.
    if isinstance(seed, int):
        seed_arr = jnp.asarray(np.int64(seed))
    else:
        seed_arr = jnp.asarray(seed)
    if seed_arr.shape:
        raise TypeError(f"PRNG key seed must be a scalar; got {seed!r}.")
    if not np.issubdtype(seed_arr.dtype, np.integer):
        raise TypeError(f"PRNG key seed must be an integer; got {seed!r}")

    convert = lambda k: lax.reshape(lax.convert_element_type(k, np.uint32),
                                    [1])
    k1 = convert(lax.shift_right_logical(seed_arr, lax._const(seed_arr, 32)))
    k2 = convert(jnp.bitwise_and(seed_arr, np.uint32(0xFFFFFFFF)))
    return lax.concatenate([k1, k2], 0)
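A hedged reference sketch in plain NumPy (the name `threefry_seed_reference` is illustrative, not part of JAX) of what the conversion above computes for a 64-bit seed: the high and low 32-bit halves, stacked into a shape-(2,) uint32 key.

import numpy as np

def threefry_seed_reference(seed: int) -> np.ndarray:
    # Illustrative re-derivation: high word first, then low word.
    seed = np.int64(seed)
    hi = np.uint32((seed >> 32) & 0xFFFFFFFF)
    lo = np.uint32(seed & 0xFFFFFFFF)
    return np.array([hi, lo], dtype=np.uint32)

print(threefry_seed_reference(42))   # [ 0 42]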
Example #9
def scan_reference(f, init, xs):
    carry = init
    ys = []
    for x in xs:
        (carry, y) = f(carry, x)
        ys.append(lax.reshape(y, (1, ) + onp.shape(y)))
    ys = lax.concatenate(ys, 0)
    return carry, ys
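A short usage sketch (assuming jax and numpy are importable; the running-sum `f` is just an example): the reference loop above should agree with lax.scan on both the final carry and the stacked outputs.

import numpy as onp
from jax import lax

f = lambda carry, x: (carry + x, carry + x)
xs = onp.arange(5.0)
carry_ref, ys_ref = scan_reference(f, 0.0, xs)
carry, ys = lax.scan(f, 0.0, xs)
# both give carry == 10.0 and ys == [0., 1., 3., 6., 10.]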
Example #10
def test_concatenate(dim, base_shape, dtype, num_arrs, rng_factory):
    rng = rng_factory(np.random)
    shapes = [
        base_shape[:dim] + (size, ) + base_shape[dim + 1:]
        for size, _ in zip(itertools.cycle([3, 1, 4]), range(num_arrs))
    ]
    args = [rng(shape, dtype) for shape in shapes]
    op = lambda *args: lax.concatenate(args, dim)
    tu.check_lazy_fun(op, *args)
Example #11
def threefry_random_bits(key: jnp.ndarray, bit_width, shape):
    """Sample uniform random bits of given width and shape using PRNG key."""
    if not _is_threefry_prng_key(key):
        raise TypeError("_random_bits got invalid prng key.")
    if bit_width not in (8, 16, 32, 64):
        raise TypeError("requires 8-, 16-, 32- or 64-bit field width.")
    shape = core.as_named_shape(shape)
    for name, size in shape.named_items:
        real_size = lax.psum(1, name)
        if real_size != size:
            raise ValueError(
                f"The shape of axis {name} was specified as {size}, "
                f"but it really is {real_size}")
        axis_index = lax.axis_index(name)
        key = threefry_fold_in(key, axis_index)
    size = prod(shape.positional)
    # Compute ceil(bit_width * size / 32) in a way that is friendly to shape
    # polymorphism
    max_count, r = divmod(bit_width * size, 32)
    if r > 0:
        max_count += 1

    if core.is_constant_dim(max_count):
        nblocks, rem = divmod(max_count, jnp.iinfo(np.uint32).max)
    else:
        nblocks, rem = 0, max_count

    if not nblocks:
        bits = threefry_2x32(key, lax.iota(np.uint32, rem))
    else:
        keys = threefry_split(key, nblocks + 1)
        subkeys, last_key = keys[:-1], keys[-1]
        blocks = vmap(threefry_2x32,
                      in_axes=(0, None))(subkeys,
                                         lax.iota(np.uint32,
                                                  jnp.iinfo(np.uint32).max))
        last = threefry_2x32(last_key, lax.iota(np.uint32, rem))
        bits = lax.concatenate([blocks.ravel(), last], 0)

    dtype = UINT_DTYPES[bit_width]
    if bit_width == 64:
        bits = [lax.convert_element_type(x, dtype) for x in jnp.split(bits, 2)]
        bits = lax.shift_left(bits[0], dtype(32)) | bits[1]
    elif bit_width in [8, 16]:
        # this is essentially bits.view(dtype)[:size]
        bits = lax.bitwise_and(
            np.uint32(np.iinfo(dtype).max),
            lax.shift_right_logical(
                lax.broadcast(bits, (1, )),
                lax.mul(
                    np.uint32(bit_width),
                    lax.broadcasted_iota(np.uint32, (32 // bit_width, 1), 0))))
        bits = lax.reshape(bits, (np.uint32(max_count * 32 // bit_width), ),
                           (1, 0))
        bits = lax.convert_element_type(bits, dtype)[:size]
    return lax.reshape(bits, shape)
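A small worked instance of the shape-polymorphism-friendly ceiling above (plain Python, illustrative values): divmod replaces math.ceil so the same computation also works with symbolic sizes.

bit_width, size = 16, 5
max_count, r = divmod(bit_width * size, 32)   # (2, 16)
if r > 0:
    max_count += 1
print(max_count)   # 3 == ceil(16 * 5 / 32)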
Example #12
def block_diag(*arrs):
  if len(arrs) == 0:
    arrs = [jnp.zeros((1, 0))]
  arrs = jnp._promote_dtypes(*arrs)
  bad_shapes = [i for i, a in enumerate(arrs) if jnp.ndim(a) > 2]
  if bad_shapes:
    raise ValueError("Arguments to jax.scipy.linalg.block_diag must have at "
                     "most 2 dimensions, got {} at argument {}."
                     .format(arrs[bad_shapes[0]], bad_shapes[0]))
  arrs = [jnp.atleast_2d(a) for a in arrs]
  acc = arrs[0]
  dtype = lax.dtype(acc)
  for a in arrs[1:]:
    _, c = a.shape
    a = lax.pad(a, dtype.type(0), ((0, 0, 0), (acc.shape[-1], 0, 0)))
    acc = lax.pad(acc, dtype.type(0), ((0, 0, 0), (0, c, 0)))
    acc = lax.concatenate([acc, a], dimension=0)
  return acc
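A hedged usage sketch of the pad-and-concatenate scheme above; scipy.linalg.block_diag is the behaviour being mirrored, so it serves here as the reference.

import numpy as np
from scipy.linalg import block_diag as scipy_block_diag

a = np.ones((2, 2))
b = 2 * np.ones((1, 3))
print(scipy_block_diag(a, b))
# [[1. 1. 0. 0. 0.]
#  [1. 1. 0. 0. 0.]
#  [0. 0. 2. 2. 2.]]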
Example #13
File: fft.py Project: raj0088/jax
def _irfft_transpose(t, fft_lengths):
  # The transpose of IRFFT is the RFFT of the cotangent times a scaling
  # factor and a mask. The mask scales the cotangent for the Hermitian
  # symmetric components of the RFFT by a factor of two, since these components
  # are de-duplicated in the RFFT.
  x = fft(t, xla_client.FftType.RFFT, fft_lengths)
  n = x.shape[-1]
  is_odd = fft_lengths[-1] % 2
  full = partial(lax.full_like, t, dtype=t.dtype)
  mask = lax.concatenate(
      [full(1.0, shape=(1,)),
       full(2.0, shape=(n - 2 + is_odd,)),
       full(1.0, shape=(1 - is_odd,))],
      dimension=0)
  scale = 1 / prod(fft_lengths)
  out = scale * mask * x
  assert out.dtype == _complex_dtype(t.dtype), (out.dtype, t.dtype)
  return out
Example #14
def _scale_and_translate(x, output_shape, scale, translate, kernel, antialias):
    input_shape = x.shape
    assert len(input_shape) == len(output_shape)
    assert len(input_shape) == len(scale)
    assert len(input_shape) == len(translate)
    spatial_dims = np.nonzero(
        np.not_equal(input_shape, output_shape) | np.not_equal(scale, 1)
        | np.not_equal(translate, 0))[0]
    if len(spatial_dims) == 0:
        return x
    output_spatial_shape = tuple(np.array(output_shape)[spatial_dims])
    indices = []
    contractions = []
    slice_shape = list(input_shape)
    in_indices = list(range(len(output_shape) + len(spatial_dims)))
    out_indices = list(range(len(output_shape)))
    for i, d in enumerate(spatial_dims):
        m = input_shape[d]
        n = output_shape[d]
        starts, weights = _compute_spans(m,
                                         n,
                                         scale[d],
                                         translate[d],
                                         kernel,
                                         antialias=antialias)
        starts = lax.broadcast_in_dim(starts, output_spatial_shape + (1, ),
                                      (i, ))
        slice_shape[d] = weights.shape[1]
        indices.append(starts.astype(np.int32))
        contractions.append(weights.astype(x.dtype))
        contractions.append([len(output_shape) + i, d])
        out_indices[d] = len(output_shape) + i
    index = lax.concatenate(indices, len(output_spatial_shape))
    dnums = lax.GatherDimensionNumbers(offset_dims=tuple(
        range(len(output_shape))),
                                       collapsed_slice_dims=(),
                                       start_index_map=tuple(spatial_dims))
    out = lax.gather(x, index, dnums, slice_shape)
    contractions.append(out_indices)
    return jnp.einsum(out,
                      in_indices,
                      *contractions,
                      precision=lax.Precision.HIGHEST)
Example #15
def _irfft_transpose(t, fft_lengths):
    # The transpose of IRFFT is the RFFT of the cotangent times a scaling
    # factor and a mask. The mask scales the cotangent for the Hermitian
    # symmetric components of the RFFT by a factor of two, since these components
    # are de-duplicated in the RFFT.
    x = fft(t, xla_client.FftType.RFFT, fft_lengths)
    n = x.shape[-1]
    is_odd = fft_lengths[-1] % 2
    full = partial(lax.full_like, t, dtype=t.dtype)
    mask = lax.concatenate([
        full(1.0, shape=(1, )),
        full(2.0, shape=(n - 2 + is_odd, )),
        full(1.0, shape=(1 - is_odd, ))
    ],
                           dimension=0)
    scale = 1 / prod(fft_lengths)
    out = scale * lax.expand_dims(mask, range(x.ndim - 1)) * x
    assert out.dtype == _complex_dtype(t.dtype), (out.dtype, t.dtype)
    # Use JAX's convention for complex gradients
    # https://github.com/google/jax/issues/6223#issuecomment-807740707
    return lax.conj(out)
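A quick worked instance of the mask built above, for fft_lengths = (8,): the RFFT output has n = 8 // 2 + 1 = 5 bins and is_odd = 0, so the mask is [1., 2., 2., 2., 1.]; the DC and Nyquist bins appear once in the RFFT while each interior bin stands for two conjugate-symmetric frequencies.

n, is_odd = 8 // 2 + 1, 8 % 2
mask_lengths = (1, n - 2 + is_odd, 1 - is_odd)   # (1, 3, 1) -> [1., 2., 2., 2., 1.]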
Example #16
def _concatenate_j(eqn: JaxprEqn, idx: int, invals: List[ShapedArray],
                   cts_in: ShapedArray) -> np.ndarray:
    dimension = eqn.params['dimension']

    js = []
    inval = invals[idx]
    for i in range(len(invals)):
        inval_i = invals[i]
        inval_i_shape = tuple(
            inval_i.shape[k] if k == dimension else inval.shape[k]
            for k in range(inval.ndim))

        if i == idx:
            j = np.eye(inval.size, dtype=inval.dtype)
        else:
            inval_i_size = onp.prod(inval_i_shape)
            j = np.zeros((inval_i_size, inval.size), inval.dtype)

        j = j.reshape(inval_i_shape + inval.shape)
        js.append(j)

    j = lax.concatenate(js, dimension)
    j = j.reshape(cts_in.shape + inval.shape)
    return j
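An illustrative 2-D picture (plain NumPy, not from the JAX sources) of the structure being assembled above: flattened, the Jacobian of concatenate([x, y], 0) with respect to x is an identity block stacked on top of a zero block.

import numpy as np

x_size, y_size = 2, 3
d_dx = np.concatenate([np.eye(x_size), np.zeros((y_size, x_size))], axis=0)
# d_dx.shape == (x_size + y_size, x_size): rows index the output, columns the input.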
Example #17
def concat(y):
    return lax.concatenate([x, y], 0)
Example #18
def duplicate(x):
    assert python_should_be_executing
    return lax.concatenate([x, x], 0)
Example #19
def cat(x, y, z):
    return lax.concatenate([x, y, x], 0)
Example #20
def _dct_interleave(x, axis):
    v0 = lax.slice_in_dim(x, None, None, 2, axis)
    v1 = lax.rev(lax.slice_in_dim(x, 1, None, 2, axis), (axis, ))
    return lax.concatenate([v0, v1], axis)
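A hedged sketch of the interleaving above in plain NumPy: the even-indexed entries in order, followed by the odd-indexed entries in reverse.

import numpy as np

x = np.arange(8)
v0, v1 = x[::2], x[1::2][::-1]
print(np.concatenate([v0, v1]))   # [0 2 4 6 7 5 3 1]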
Example #21
def concatenate(arrays, axis=0):
    if not arrays:
        raise ValueError("Need at least one array to concatenate.")
    return lax.concatenate(_promote_dtypes(*arrays), axis % ndim(arrays[0]))
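A brief usage note on the `axis % ndim(arrays[0])` trick above: it maps a negative axis onto the matching non-negative one, since lax.concatenate takes the dimension as a plain non-negative index. For instance, via the public jax.numpy wrapper:

import jax.numpy as jnp

a = jnp.ones((2, 3))
b = jnp.zeros((2, 2))
print(jnp.concatenate([a, b], axis=-1).shape)   # (2, 5)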
Example #22
def _rewriting_take(arr, idx, axis=0):
    """A function like numpy.take that handles boxes and rewrites to LAX."""

    # Handle special indexers: (), Ellipsis, slice(None), and None.
    # TODO(mattjj): don't compare empty tuple identity (though works for CPython)
    if idx is () or idx is Ellipsis or _is_slice_none(idx):  # pylint: disable=literal-comparison
        return arr
    elif idx is None:
        return expand_dims(arr, 0)

    # Handle int index
    _int = lambda aval: not aval.shape and onp.issubdtype(
        aval.dtype, onp.integer)
    try:
        abstract_idx = core.get_aval(idx)
    except TypeError:
        abstract_idx = None

    if isinstance(abstract_idx, ConcreteArray) and _int(abstract_idx):
        return lax.index_in_dim(arr, idx, axis, False)
    elif isinstance(abstract_idx, ShapedArray) and _int(abstract_idx):
        idx = mod(idx, arr.shape[axis])
        return lax.dynamic_index_in_dim(arr, idx, axis, False)

    # Handle slice index (only static, otherwise an error is raised)
    elif isinstance(idx, slice):
        if not _all(
                elt is None or isinstance(core.get_aval(elt), ConcreteArray)
                for elt in (idx.start, idx.stop, idx.step)):
            msg = (
                "Array slice indices must have static start/stop/step to be used "
                "with Numpy indexing syntax. Try lax.dynamic_slice instead.")
            raise IndexError(msg)
        else:
            start, limit, stride, needs_rev = _static_idx(idx, arr.shape[axis])
            result = lax.slice_in_dim(arr, start, limit, stride, axis=axis)
            return lax.rev(result, [axis]) if needs_rev else result

    # Handle non-advanced tuple indices by recursing once
    elif isinstance(idx, tuple) and _all(onp.ndim(elt) == 0 for elt in idx):
        canonical_idx = _canonicalize_tuple_index(arr, idx)
        result, axis = arr, 0
        for elt in (elt for elt in canonical_idx if elt is not None):
            result = _rewriting_take(result, elt, axis=axis)
            axis += isinstance(elt,
                               slice)  # advance axis index if not eliminated
        unexpanded_shape_itr = iter(result.shape)
        result_shape = tuple(1 if elt is None else next(unexpanded_shape_itr)
                             for elt in canonical_idx
                             if not isinstance(elt, int))
        return lax.reshape(result, result_shape)

    # Handle advanced indexing (non-tuple sequence, ndarray of dtype int or bool,
    # or a tuple with at least one sequence object).
    # https://docs.scipy.org/doc/numpy/reference/arrays.indexing.html#advanced-indexing
    # https://gist.github.com/seberg/976373b6a2b7c4188591

    # Handle integer array indexing *without* ellipsis/slices/nones
    # https://docs.scipy.org/doc/numpy/reference/arrays.indexing.html#integer-array-indexing
    if _is_advanced_int_indexer_without_slices(idx):
        if isinstance(idx, list):
            if _any(_shape(e) for e in idx):
                # At least one sequence element in the index list means broadcasting.
                idx = broadcast_arrays(*idx)
            else:
                # The index list is a flat list of integers.
                idx = [
                    lax.concatenate([lax.reshape(e, (1, )) for e in idx], 0)
                ]
        else:
            # The indexer is just a single integer array.
            idx = [idx]

        flat_idx = tuple(
            mod(ravel(x), arr.shape[i]) for i, x in enumerate(idx))
        out = lax.index_take(arr, flat_idx, tuple(range(len(idx))))
        return lax.reshape(out, idx[0].shape + _shape(arr)[len(idx):])

    # Handle integer array indexing *with* ellipsis/slices/nones by recursing once
    # https://docs.scipy.org/doc/numpy/reference/arrays.indexing.html#combining-advanced-and-basic-indexing
    elif _is_advanced_int_indexer(idx):
        canonical_idx = _canonicalize_tuple_index(arr, tuple(idx))
        idx_noadvanced = [
            slice(None) if _is_int(e) else e for e in canonical_idx
        ]
        arr_sliced = _rewriting_take(arr, tuple(idx_noadvanced))

        advanced_pairs = ((e, i) for i, e in enumerate(canonical_idx)
                          if _is_int(e))
        idx_advanced, axes = zip(*advanced_pairs)
        idx_advanced = broadcast_arrays(*idx_advanced)

        flat_idx = tuple(
            mod(ravel(x), arr_sliced.shape[i])
            for i, x in zip(axes, idx_advanced))
        out = lax.index_take(arr_sliced, flat_idx, axes)
        shape_suffix = tuple(onp.delete(_shape(arr_sliced), axes))
        out = lax.reshape(out, idx_advanced[0].shape + shape_suffix)

        axes_are_contiguous = onp.all(onp.diff(axes) == 1)
        if axes_are_contiguous:
            start = axes[0]
            naxes = idx_advanced[0].ndim
            out = moveaxis(out, list(range(naxes)),
                           list(range(start, start + naxes)))
        return out

    msg = "Indexing mode not yet supported. Open a feature request!\n{}"
    raise IndexError(msg.format(idx))
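A small illustration (plain NumPy) of the flat-integer-list branch above: a list like [2, 0, 1] is packed into a single index array before the gather, which in NumPy terms is just fancy indexing with that array.

import numpy as np

arr = np.arange(12).reshape(4, 3)
idx = np.concatenate([np.reshape(e, (1,)) for e in [2, 0, 1]])
print(arr[idx])   # rows 2, 0 and 1 of arr, shape (3, 3)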
Example #23
def concat(x):
    return lax.concatenate([x, x], 0)
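Finally, a minimal stand-alone usage sketch of lax.concatenate itself: the operands must share rank and every dimension except the one being concatenated, and `dimension` is a non-negative axis index.

import jax.numpy as jnp
from jax import lax

x = jnp.arange(6.0).reshape(2, 3)
print(lax.concatenate([x, x], dimension=0).shape)   # (4, 3)
print(lax.concatenate([x, x], dimension=1).shape)   # (2, 6)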