Ejemplo n.º 1
0
    def test_vmap_after(self):
        """Checks recall of lax.approx_max_k when batched with jax.vmap."""
        batch, qy_size, db_size, feature_dim, k = 4, 128, 1024, 32, 10
        min_recall = 0.95
        rng = jtu.rand_default(self.rng())
        # The batch dimension is the trailing axis of both operands.
        qy = rng([qy_size, feature_dim, batch], np.float32)
        db = rng([db_size, feature_dim, batch], np.float32)

        def to_rows(args):
            # The identical re-layout is applied to both index arrays so that
            # compute_recall compares them position-for-position.
            return lax.reshape(lax.transpose(args, [2, 0, 1]),
                               [qy_size * batch, k])

        # Reference: batched dot product followed by an exact top-k.
        exact_scores = lax.dot_general(qy, db, (([1], [1]), ([2], [2])))
        gt_args = to_rows(lax.top_k(exact_scores, k)[1])

        def scored_approx_max_k(q, d):
            return lax.approx_max_k(q @ d.transpose(), k)

        ann_args = to_rows(jax.vmap(scored_approx_max_k, (2, 2))(qy, db)[1])
        self.assertGreater(
            compute_recall(np.asarray(ann_args), np.asarray(gt_args)),
            min_recall)
Ejemplo n.º 2
0
def threefry_random_bits(key: jnp.ndarray, bit_width, shape):
    """Sample uniform random bits of given width and shape using PRNG key.

    Args:
      key: raw threefry PRNG key; rejected unless ``_is_threefry_prng_key``
        accepts it.
      bit_width: bits per output element; one of 8, 16, 32 or 64.
      shape: output shape, normalized with ``core.as_named_shape``. For each
        named axis the declared size must equal ``lax.psum(1, name)``, and the
        key is folded with that axis' index.

    Returns:
      Array of dtype ``UINT_DTYPES[bit_width]`` reshaped to ``shape``.

    Raises:
      TypeError: if ``key`` is invalid or ``bit_width`` is unsupported.
      ValueError: if a named axis size does not match its real size.
    """
    if not _is_threefry_prng_key(key):
        raise TypeError("_random_bits got invalid prng key.")
    if bit_width not in (8, 16, 32, 64):
        raise TypeError("requires 8-, 16-, 32- or 64-bit field width.")
    shape = core.as_named_shape(shape)
    for name, size in shape.named_items:
        real_size = lax.psum(1, name)
        if real_size != size:
            raise ValueError(
                f"The shape of axis {name} was specified as {size}, "
                f"but it really is {real_size}")
        # Fold this position along the named axis into the key so that each
        # index along that axis derives a different key.
        axis_index = lax.axis_index(name)
        key = threefry_fold_in(key, axis_index)
    # Number of output elements in the positional (unnamed) part of the shape.
    size = prod(shape.positional)
    # Compute ceil(bit_width * size / 32) in a way that is friendly to shape
    # polymorphism
    max_count, r = divmod(bit_width * size, 32)
    if r > 0:
        max_count += 1

    # The counters are uint32 iotas. A constant count is split into full
    # blocks of jnp.iinfo(np.uint32).max counters plus a remainder; a
    # symbolic (non-constant) count always takes the single-call path.
    if core.is_constant_dim(max_count):
        nblocks, rem = divmod(max_count, jnp.iinfo(np.uint32).max)
    else:
        nblocks, rem = 0, max_count

    if not nblocks:
        bits = threefry_2x32(key, lax.iota(np.uint32, rem))
    else:
        # One subkey per full block, plus `last_key` for the remainder block.
        keys = threefry_split(key, nblocks + 1)
        subkeys, last_key = keys[:-1], keys[-1]
        blocks = vmap(threefry_2x32,
                      in_axes=(0, None))(subkeys,
                                         lax.iota(np.uint32,
                                                  jnp.iinfo(np.uint32).max))
        last = threefry_2x32(last_key, lax.iota(np.uint32, rem))
        bits = lax.concatenate([blocks.ravel(), last], 0)

    dtype = UINT_DTYPES[bit_width]
    if bit_width == 64:
        # max_count == 2 * size here: the first half of the 32-bit words
        # becomes the high bits and the second half the low bits.
        bits = [lax.convert_element_type(x, dtype) for x in jnp.split(bits, 2)]
        bits = lax.shift_left(bits[0], dtype(32)) | bits[1]
    elif bit_width in [8, 16]:
        # this is essentially bits.view(dtype)[:size]
        # Each 32-bit word yields 32 // bit_width lanes: lane r is
        # (word >> (bit_width * r)) masked to bit_width bits. The reshape with
        # dimensions=(1, 0) lays the lanes of each word out consecutively,
        # and the trailing slice drops lanes beyond `size`.
        bits = lax.bitwise_and(
            np.uint32(np.iinfo(dtype).max),
            lax.shift_right_logical(
                lax.broadcast(bits, (1, )),
                lax.mul(
                    np.uint32(bit_width),
                    lax.broadcasted_iota(np.uint32, (32 // bit_width, 1), 0))))
        bits = lax.reshape(bits, (np.uint32(max_count * 32 // bit_width), ),
                           (1, 0))
        bits = lax.convert_element_type(bits, dtype)[:size]
    return lax.reshape(bits, shape)
Ejemplo n.º 3
0
def _lu_python(x):
    """Default LU decomposition in Python, where no better version exists.

    Any leading (batch) dimensions are flattened into a single axis so the
    unbatched ``_lu_blocked`` can be vmapped over it; the batch dimensions are
    then restored on both outputs. Returns ``(lu, pivot)``.
    """
    m, n = x.shape[-2:]
    batch_dims = tuple(x.shape[:-2])
    if not batch_dims:
        pivot, lu = _lu_blocked(x)
        return lu, pivot

    flat_batch = onp.prod(batch_dims, dtype=onp.int64)
    flat_x = lax.reshape(x, (flat_batch, m, n))
    pivot, lu = api.vmap(_lu_blocked)(flat_x)
    lu = lax.reshape(lu, batch_dims + (m, n))
    pivot = lax.reshape(pivot, batch_dims + (min(m, n), ))
    return lu, pivot
Ejemplo n.º 4
0
def repeat(a, repeats, axis=None):
    """Repeat each element of ``a`` ``repeats`` times along ``axis``.

    Only a scalar ``repeats`` is supported. With ``axis=None`` (or a scalar
    ``a``) the input is flattened first and repeated along axis 0. A negative
    ``axis`` counts from the end.
    """
    if not isscalar(repeats):
        raise NotImplementedError(
            "np.repeat implementation only supports scalar repeats")
    if axis is None or isscalar(a):
        a, axis = ravel(a), 0
    in_shape = list(shape(a))
    rank = len(in_shape)
    if axis < 0:
        axis += rank
    if not 0 <= axis < rank:
        raise ValueError(
            "axis {} is out of bounds for array of dimension {}".format(
                axis, rank))

    # Insert an axis of length `repeats` right after `axis`, broadcast into it,
    # then merge the pair:
    #   [..., X, ...] -> [..., X, repeats, ...] -> [..., X * repeats, ...]
    expanded_shape = in_shape[:axis + 1] + [repeats] + in_shape[axis + 1:]
    kept_dims = onp.concatenate(
        (onp.arange(0, axis + 1), onp.arange(axis + 2, rank + 1)))
    out_shape = list(in_shape)
    out_shape[axis] *= repeats
    return lax.reshape(
        lax.broadcast_in_dim(a, expanded_shape, kept_dims), out_shape)
Ejemplo n.º 5
0
def array(object, dtype=None, copy=True, order="K", ndmin=0):
    """Build an array from ``object``, optionally converting to ``dtype``.

    Supports objects exposing ``__asarray__``, existing ``ndarray`` values,
    (nested) lists/tuples, and scalars. ``copy`` is ignored; only
    ``order='K'`` with ``ndmin=0`` is implemented.
    """
    del copy  # Unused.
    if ndmin != 0 or order != "K":
        raise NotImplementedError("Only implemented for order='K', ndmin=0.")

    def _maybe_cast(x):
        # Convert only when a dtype was requested and differs from x's dtype.
        if dtype and _dtype(x) != dtype:
            return lax.convert_element_type(x, dtype)
        return x

    if hasattr(object, '__asarray__'):
        return object.__asarray__(dtype)
    if isinstance(object, ndarray):
        return _maybe_cast(object)
    if isinstance(object, (list, tuple)):
        if not object:
            return onp.array([], dtype)
        # Stack the converted elements along a new leading axis.
        return concatenate(
            [expand_dims(array(elt, dtype=dtype), 0) for elt in object])
    if isscalar(object):
        return _maybe_cast(lax.reshape(object, ()))
    raise TypeError("Unexpected input type for array: {}".format(
        type(object)))
Ejemplo n.º 6
0
def threefry_2x32(keypair, count):
    """Apply the Threefry 2x32 hash.

    Args:
      keypair: a pair of 32bit unsigned integers used for the key.
      count: an array of dtype uint32 used for the counts.

    Returns:
      An array of dtype uint32 with the same shape as `count`.

    Raises:
      TypeError: if any of the key words or `count` is not uint32.
      core.InconclusiveDimensionOperation: if the parity of `count.size`
        cannot be decided under shape polymorphism.
    """
    key1, key2 = keypair
    seen_dtypes = [lax.dtype(x) for x in [key1, key2, count]]
    if not all(d == np.uint32 for d in seen_dtypes):
        msg = "threefry_2x32 requires uint32 arguments, got {}"
        raise TypeError(msg.format(seen_dtypes))

    try:
        needs_pad = count.size % 2
    except core.InconclusiveDimensionOperation as e:
        msg = (
            "jax.random functions have limited support for shape polymorphism. "
            "In particular, the product of the known dimensions must be even.")
        raise core.InconclusiveDimensionOperation(msg) from e

    flat = count.ravel()
    if needs_pad:
        # Append one zero counter so the counts split into two equal halves.
        flat = jnp.concatenate([flat, np.uint32([0])])
    first_half, second_half = jnp.split(flat, 2)

    hashed = jnp.concatenate(
        threefry2x32_p.bind(key1, key2, first_half, second_half))
    assert hashed.dtype == np.uint32
    if needs_pad:
        hashed = hashed[:-1]  # drop the padding counter's output
    return lax.reshape(hashed, count.shape)
Ejemplo n.º 7
0
    def _cumulative_reduction(a,
                              axis: Optional[Union[int, Tuple[int,
                                                              ...]]] = None,
                              dtype=None,
                              out=None):
        """Validate arguments, prepare ``a`` and apply the enclosing ``reduction``.

        With ``axis=None`` (or a scalar ``a``) the input is flattened and reduced
        along axis 0. When the enclosing ``fill_nan`` is set, NaNs are replaced
        by ``fill_value`` first. Boolean input without an explicit ``dtype`` is
        converted to the canonical integer dtype.
        """
        fn_name = np_reduction.__name__
        _check_arraylike(fn_name, a)
        if out is not None:
            raise NotImplementedError(
                f"The 'out' argument to jnp.{fn_name} "
                f"is not supported.")
        lax_internal._check_user_dtype_supported(dtype, fn_name)

        if axis is None or _isscalar(a):
            a, axis = lax.reshape(a, (np.size(a), )), 0

        axis = _canonicalize_axis(axis, len(np.shape(a)))

        if fill_nan:
            a = _where(lax_internal._isnan(a), _lax_const(a, fill_value), a)

        if not dtype and dtypes.dtype(a) == np.bool_:
            dtype = dtypes.canonicalize_dtype(dtypes.int_)
        if dtype:
            a = lax.convert_element_type(a, dtype)

        return reduction(a, axis)
Ejemplo n.º 8
0
def _lu_solve_core(lu, pivots, b, trans):
    """Solve a linear system using a packed LU factorization.

    ``b`` is viewed as an ``(m, -1)`` matrix of right-hand sides and the result
    is reshaped back to ``b.shape``.

    * ``trans == 0``: permute rows by the pivots, solve with the unit-lower
      factor, then with the upper factor.
    * ``trans == 1``: solve with the transposed upper then transposed unit-lower
      factor, then undo the permutation.
    * ``trans == 2``: as 1, but with conjugate transposes.

    Raises:
      ValueError: for any other ``trans``.
    """
    m = lu.shape[0]
    perm = lu_pivots_to_permutation(pivots, m)
    rhs = np.reshape(b, (m, -1))
    if trans == 0:
        rhs = rhs[perm, :]
        rhs = triangular_solve(lu, rhs, left_side=True, lower=True,
                               unit_diagonal=True)
        rhs = triangular_solve(lu, rhs, left_side=True, lower=False)
    elif trans in (1, 2):
        conjugate = trans == 2
        rhs = triangular_solve(lu, rhs, left_side=True, lower=False,
                               transpose_a=True, conjugate_a=conjugate)
        rhs = triangular_solve(lu, rhs, left_side=True, lower=True,
                               unit_diagonal=True, transpose_a=True,
                               conjugate_a=conjugate)
        rhs = rhs[np.argsort(perm), :]
    else:
        raise ValueError(
            "'trans' value must be 0, 1, or 2, got {}".format(trans))
    return lax.reshape(rhs, b.shape)
Ejemplo n.º 9
0
def threefry_seed(seed: int) -> jnp.ndarray:
    """Create a single raw threefry PRNG key given an integer seed.

    Args:
      seed: a 64- or 32-bit integer used as the value of the key.

    Returns:
      The PRNG key contents, modeled as an array of shape (2,) and dtype
      uint32: first the word obtained by shifting the seed right by 32 bits,
      then its low 32 bits.

    Raises:
      TypeError: if the seed is not a scalar or not of integer dtype.
    """
    # Python ints go through int64 first to avoid an OverflowError in X32 mode.
    # This breaks JIT invariance for large ints, but supports the common
    # use-case of instantiating with Python hashes in X32 mode.
    seed_arr = jnp.asarray(np.int64(seed) if isinstance(seed, int) else seed)
    if seed_arr.shape:
        raise TypeError(f"PRNG key seed must be a scalar; got {seed!r}.")
    if not np.issubdtype(seed_arr.dtype, np.integer):
        raise TypeError(f"PRNG key seed must be an integer; got {seed!r}")

    def as_uint32_word(value):
        # uint32 conversion, packaged as a length-1 vector for concatenation.
        return lax.reshape(lax.convert_element_type(value, np.uint32), [1])

    high_word = as_uint32_word(
        lax.shift_right_logical(seed_arr, lax._const(seed_arr, 32)))
    low_word = as_uint32_word(
        jnp.bitwise_and(seed_arr, np.uint32(0xFFFFFFFF)))
    return lax.concatenate([high_word, low_word], 0)
Ejemplo n.º 10
0
def scan_reference(f, init, xs):
    """Unrolled Python-loop reference implementation of a scan.

    Threads the carry through ``f(carry, x)`` for each ``x`` in ``xs`` and
    stacks the per-step outputs along a new leading axis. Returns
    ``(final_carry, stacked_outputs)``.
    """

    def with_leading_axis(y):
        return lax.reshape(y, (1, ) + onp.shape(y))

    carry, rows = init, []
    for x in xs:
        carry, y = f(carry, x)
        rows.append(with_leading_axis(y))
    return carry, lax.concatenate(rows, 0)
Ejemplo n.º 11
0
def promote_shapes(*args, shape=()):
    """Left-pad argument shapes with singleton axes for broadcasting.

    Every argument whose rank is below that of the broadcast of ``shape`` with
    all argument shapes is reshaped with leading 1s up to that rank; the others
    are returned as-is. With fewer than two arguments and an empty ``shape``,
    ``args`` itself is returned.
    """
    # adapted from lax.lax_numpy
    if len(args) < 2 and not shape:
        return args
    arg_shapes = [np.shape(arg) for arg in args]
    target_rank = len(lax.broadcast_shapes(shape, *arg_shapes))
    promoted = []
    for arg, arg_shape in zip(args, arg_shapes):
        missing = target_rank - len(arg_shape)
        if missing > 0:
            arg = lax.reshape(arg, (1,) * missing + arg_shape)
        promoted.append(arg)
    return promoted
Ejemplo n.º 12
0
def _promote_shapes(*args):
  """Prepend implicit leading singleton dimensions for Numpy broadcasting.

  Arguments already at the broadcast rank are returned unchanged; with fewer
  than two arguments, ``args`` itself is returned.
  """
  if len(args) < 2:
    return args
  arg_shapes = [shape(arg) for arg in args]
  result_rank = len(_broadcast_shapes(*arg_shapes))

  def _pad_rank(arg, arg_shape):
    # Only reshape when leading singleton axes actually need to be added.
    if len(arg_shape) == result_rank:
      return arg
    return lax.reshape(arg, (1,) * (result_rank - len(arg_shape)) + arg_shape)

  return [_pad_rank(arg, s) for arg, s in zip(args, arg_shapes)]
Ejemplo n.º 13
0
def squeeze(a, axis=None):
  """Remove length-one axes from ``a``.

  Args:
    a: array-like input.
    axis: None to drop every length-one axis, or an int / sequence of ints
      selecting the axes to drop. Indices are taken modulo ``ndim(a)``, so
      negative values count from the end.

  Returns:
    ``a`` itself when there is nothing to remove, otherwise ``a`` reshaped
    without the selected axes.

  Raises:
    ValueError: if ``axis`` selects an axis whose size is not one (as
      ``numpy.squeeze`` does). Such axes used to be silently kept.
  """
  a_shape = shape(a)
  if axis is None:
    if 1 not in a_shape:
      return a
    return lax.reshape(a, [d for d in a_shape if d != 1])

  if not a_shape:
    # 0-d input: there are no axes to remove; kept as a no-op as before.
    return a
  axes = frozenset(onp.mod(axis, ndim(a)).reshape(-1))
  if any(a_shape[i] != 1 for i in axes):
    raise ValueError(
        "cannot select an axis to squeeze out which has size not equal to one")
  if not axes:
    return a
  return lax.reshape(a, [d for i, d in enumerate(a_shape) if i not in axes])
Ejemplo n.º 14
0
def matmul(a, b):
    """NumPy-style ``matmul`` with broadcasting over leading batch dimensions.

    A 1-D ``a`` is treated as a row (1, n) and a 1-D ``b`` as a column (n, 1);
    the corresponding axis is removed from the result.
    """
    _check_arraylike("matmul", a, b)
    a_is_vec = ndim(a) == 1
    b_is_vec = ndim(b) == 1
    if a_is_vec:
        a = lax.reshape(a, (1, ) + shape(a))
    if b_is_vec:
        b = lax.reshape(b, shape(b) + (1, ))

    a, b = _promote_dtypes(a, b)
    batch_shape = _broadcast_shapes(shape(a)[:-2], shape(b)[:-2])
    a = broadcast_to(a, batch_shape + shape(a)[-2:])
    b = broadcast_to(b, batch_shape + shape(b)[-2:])
    batch_dims = tuple(range(len(batch_shape)))
    # Contract a's last axis with b's second-to-last axis.
    contracting = ((ndim(a) - 1, ), (ndim(b) - 2, ))
    result = lax.dot_general(a, b, (contracting, (batch_dims, batch_dims)))

    if not (a_is_vec or b_is_vec):
        return result
    m, n = shape(result)[-2:]
    kept = (() if a_is_vec else (m, )) + (() if b_is_vec else (n, ))
    return lax.reshape(result, batch_shape + kept)
Ejemplo n.º 15
0
    def testSparseReshapeWithDimensions(self, shape, new_shape, n_batch,
                                        n_dense, dimensions):
        """Sparsified lax.reshape with `dimensions` matches the dense result."""
        dense = jtu.rand_some_zero(self.rng())(shape, 'int32')
        sparse = BCOO.fromdense(dense, n_batch=n_batch, n_dense=n_dense)

        def reshape_fn(x):
            return lax.reshape(x, new_shape, dimensions=dimensions)

        sparsified = self.sparsify(reshape_fn)
        expected = sparsified(dense)
        actual = sparsified(sparse)
        self.assertArraysEqual(expected, actual.todense())
Ejemplo n.º 16
0
def reshape(a, newshape, order="C"):
  """NumPy-style reshape of ``a`` to ``newshape`` (one entry may be -1).

  ``order`` selects the element order: "C" (or None) row-major, "F"
  column-major, "A" column-major only when ``isfortran(a)``. Any other value
  raises ValueError.
  """
  if order in ("C", None):
    dims = None
  elif order == "F" or (order == "A" and isfortran(a)):
    dims = onp.arange(ndim(a))[::-1]
  elif order == "A":
    dims = onp.arange(ndim(a))
  else:
    raise ValueError("Unexpected value for 'order' argument: {}.".format(order))

  # Resolve (and validate) any -1 in newshape with numpy, using a zero-stride
  # dummy array so no real data is allocated.
  resolved_shape = onp.reshape(onp.broadcast_to(0, a.shape), newshape).shape
  return lax.reshape(a, resolved_shape, dims)
Ejemplo n.º 17
0
 def chooser_taylor_rule(primals_in, series_in, **params):
   """Taylor (jet) rule for the enclosing ``chooser_fun`` reduction.

   The primal output is ``chooser_fun(operand, **params)``. Each series term
   is summed over the positions where ``operand`` equals the chosen value
   along ``params["axes"]`` and divided by the number of such positions.
   """
   (operand,) = primals_in
   (gs,) = series_in
   primal_out = chooser_fun(operand, **params)
   axes = params.pop("axes", None)
   series_dtype = gs[0].dtype
   # Reduced axes become size 1 so primal_out can be compared with operand.
   keepdims_shape = [1 if i in axes else d
                     for i, d in enumerate(operand.shape)]
   is_chosen = lax._eq_meet(operand, lax.reshape(primal_out, keepdims_shape))
   location_indicators = lax.convert_element_type(is_chosen, series_dtype)
   counts = lax._reduce_sum(location_indicators, axes)

   def average_over_chosen(g):
     picked = lax.mul(g, location_indicators)
     return lax.div(lax._reduce_sum(picked, axes), counts)

   return primal_out, [average_over_chosen(g) for g in gs]
Ejemplo n.º 18
0
def dot(a, b):
    """NumPy-style ``dot``.

    Scalars are multiplied elementwise, operands of rank <= 2 use ``lax.dot``,
    and otherwise the last axis of ``a`` is contracted with the second-to-last
    axis of ``b`` (its only axis when ``b`` is 1-D).
    """
    _check_arraylike("dot", a, b)
    a, b = _promote_dtypes(a, b)
    a_ndim, b_ndim = ndim(a), ndim(b)
    if 0 in (a_ndim, b_ndim):
        return lax.mul(a, b)
    if _max(a_ndim, b_ndim) <= 2:
        return lax.dot(a, b)

    a_matrix = reshape(a, (-1, shape(a)[-1]))
    if _ndim(b) in {1, 2}:
        b_operand = b
    else:
        # Move the contracted axis of b to the front and flatten the rest.
        b_operand = reshape(moveaxis(b, -2, 0), (shape(b)[-2], -1))
    out = lax.dot(a_matrix, b_operand)
    # b.shape[-2:][1:] is () for a 1-D b and (b.shape[-1],) otherwise.
    return lax.reshape(out, a.shape[:-1] + b.shape[:-2] + b.shape[-2:][1:])
Ejemplo n.º 19
0
  def reduction(a, axis=None, dtype=None, out=None, keepdims=False):
    """Reduce ``a`` over ``axis`` with the enclosing ``op`` / ``init_val``.

    The input is first cast to the dtype ``np_fun`` yields for a 0-d array of
    ``a``'s dtype; the result is cast to ``dtype`` when one is given and it
    differs. ``keepdims`` keeps reduced axes as size 1. ``out`` is unsupported.
    """
    if out is not None:
      raise ValueError("reduction does not support `out` argument.")

    if not isinstance(a, ndarray):
      a = asarray(a)
    dims = _reduction_dims(a, axis)
    result_dtype = _dtype(np_fun(onp.ones((), dtype=_dtype(a))))
    if _dtype(a) != result_dtype:
      a = lax.convert_element_type(a, result_dtype)
    result = lax.reduce(a, _reduction_init_val(a, init_val), op, dims)
    if keepdims:
      kept_shape = lax.subvals(shape(a), zip(dims, (1,) * len(dims)))
      result = lax.reshape(result, kept_shape)
    if not dtype:
      return result
    if onp.dtype(dtype) != onp.dtype(result_dtype):
      result = lax.convert_element_type(result, dtype)
    return result
Ejemplo n.º 20
0
def _reshape_papply_rule(name, vals, axes, new_sizes, dimensions, old_sizes):
    operand, = vals
    axis, = axes

    def filter_ones(xs):
        return filter(lambda x: x != 1, xs)

    def find_new_axis(old_axis, old_sizes, new_sizes):
        if len(filter_ones(new_sizes)) != len(filter_ones(old_sizes)):
            return None
        num_before = len(filter_ones(old_sizes[:old_axis]))
        sz = old_sizes[old_axis]
        for i, new_sz in enumerate(new_sizes):
            if num_before == 0:
                if new_sz == sz:
                    return i
                elif new_sz != 1:
                    return None
            elif new_sz != 1:
                num_before -= 1
        return None

    err = NotImplementedError(
        'papply of reshape that would change hidden dimension size')

    if dimensions is None:
        new_axis = find_new_axis(axis, old_sizes, new_sizes)
        if new_axis is not None:
            if (lax.prod(old_sizes[:axis]) != lax.prod(new_sizes[:new_axis])
                    or lax.prod(old_sizes[axis + 1:]) != lax.prod(
                        new_sizes[new_axis + 1:])):
                raise err
            new_sizes_ = new_sizes[:new_axis] + new_sizes[new_axis + 1:]
            return lax.reshape(operand, new_sizes_,
                               dimensions=dimensions), new_axis
        else:
            raise err
    else:
        raise NotImplementedError('papply of reshape with `dimensions`')
Ejemplo n.º 21
0
 def update_entry(arr, val, i, j):
     """Return ``arr`` with the element at ``(i, j)`` replaced by ``val``."""
     patch = lax.reshape(val, [1, 1])  # a 1x1 block written at offset (i, j)
     start = (i, j)
     return lax.dynamic_update_slice(arr, patch, start)
Ejemplo n.º 22
0
 def reshape_op(self, params, inputs):
     """Reshape ``inputs`` to ``(inputs.shape[0], *self.target_shape)``.

     The leading axis is preserved; ``params`` is not used.
     """
     leading = inputs.shape[0]
     return lax.reshape(inputs, (leading, ) + tuple(self.target_shape))
Ejemplo n.º 23
0
 def testReshapeGrad(self, arg_shape, out_shape, permutation, dtype, rng_factory):
   """Checks order-2 fwd/rev gradients of lax.reshape with `permutation`."""
   operand = rng_factory(self.rng())(arg_shape, dtype)

   def reshape_fn(x):
     return lax.reshape(x, out_shape, permutation)

   check_grads(reshape_fn, (operand,), 2, ["fwd", "rev"], eps=1.)
Ejemplo n.º 24
0
def softmax(attn_weights, norm_dims, dtype, softmax_hparams, quant_context):
    """Normalizes attention.

    Three code paths are selected from ``softmax_hparams``:
      * None, or equal to ``SoftmaxHParams(None, None, None)``: exact softmax
        computed in the log domain (``logsumexp``).
      * ``quant_hparams`` set: ``lax.cond`` on ``quant_context.quantize_acts``
        chooses between a quantized softmax and the exact one.
      * otherwise: an approximate softmax built from ``exponential`` (with
        ``exp_hparams``) and ``reciprocal`` (with ``reciprocal_hparams``).

    Args:
      attn_weights: attention logits to normalize.
      norm_dims: axes over which the softmax is normalized.
      dtype: dtype the result is cast to (and intermediates quantized to).
      softmax_hparams: ``SoftmaxHParams`` or None; see above.
      quant_context: read only in the quantized path (``quantize_acts``).

    Returns:
      The normalized weights, cast to ``dtype``.
    """
    a = attn_weights

    # Exact softmax: exp(a - logsumexp(a)) over norm_dims.
    def unquantized_softmax(a):
        a = lax.exp(
            a - jax.scipy.special.logsumexp(a, axis=norm_dims, keepdims=True))
        return a.astype(dtype)

    # Quantize intermediate activations with QuantOps.
    # Currently only supports unscaled floating-point formats.
    def quantized_softmax(a):
        # We compute softmax as exp(x-max(x))/sum_i(exp(x_i-max(x))), quantizing
        # intermediate values. Note this differs from the log-domain
        # implementation of softmax used above.
        quant_hparams = softmax_hparams.quant_hparams
        fp_quant_config = QuantOps.FloatQuant(is_scaled=False,
                                              fp_spec=quant_hparams.prec)
        quant_ops = QuantOps.create_symmetric_fp(fp_quant=fp_quant_config,
                                                 bounds=None)

        a = quant_ops.to_quantized(a, dtype=dtype)
        # Note that the max of a quantized vector is necessarily also quantized to
        # the same precision since the max of a vector must be an existing element
        # of the vector, so we don't need to explicitly insert a quantization
        # operator to the output of the max reduction.
        a_max = jnp.max(a, axis=norm_dims, keepdims=True)
        a_minus_max = quant_ops.to_quantized(a - a_max, dtype=dtype)
        a_exp = quant_ops.to_quantized(jnp.exp(a_minus_max), dtype=dtype)

        # The reduction itself accumulates at `reduction_prec`.
        sum_exp_quantized_reduction = quantization.quantized_sum(
            a_exp,
            axis=norm_dims,
            keepdims=True,
            prec=quant_hparams.reduction_prec)
        sum_exp = quant_ops.to_quantized(sum_exp_quantized_reduction,
                                         dtype=dtype)

        inv_sum_exp = quant_ops.to_quantized(jnp.reciprocal(sum_exp),
                                             dtype=dtype)
        a_softmax = quant_ops.to_quantized(a_exp * inv_sum_exp, dtype=dtype)

        return a_softmax.astype(dtype)

    # If no params, return accurate Softmax.
    # NOTE(review): the equality test runs before `is None`, so this relies on
    # SoftmaxHParams.__eq__ tolerating a None operand — confirm.
    if softmax_hparams == SoftmaxHParams(None, None,
                                         None) or softmax_hparams is None:
        return unquantized_softmax(a)

    # TODO(shivaniagrawal): Partial sum quantization (if enabled) will happen for
    # the entire training run, even before the global activation start step.
    if softmax_hparams.quant_hparams is not None:
        return lax.cond(quant_context.quantize_acts, quantized_softmax,
                        unquantized_softmax, a)

    # Approximated Softmax
    # NOTE(review): exp_hparams is dereferenced unconditionally below; this
    # path presumably requires it to be set — confirm with callers.
    exp_hparams = softmax_hparams.exp_hparams
    recip_hparams = softmax_hparams.reciprocal_hparams

    # Subtract the max over the normalized dimensions. `shape` is a's shape
    # with every norm dim set to 1; `dimadd` reshapes a reduced result back to
    # that keepdims shape so it broadcasts against `a`.
    shape = jax.util.subvals(onp.shape(a),
                             zip(norm_dims, (1, ) * len(norm_dims)))
    dimadd = lambda x: lax.reshape(x, shape)
    # pylint: disable=protected-access
    amax = lax.reduce(a, lax_numpy._constant_like(a, -onp.inf), lax.max,
                      norm_dims)
    # Non-finite maxima are replaced by 0 before subtracting.
    amax = lax.select(lax.is_finite(amax), amax, lax.full_like(amax, 0))
    amax_singletons = dimadd(amax)
    asubmax = lax.sub(a, amax_singletons)

    # Calculate approximated exponential
    approx_exp = exponential(asubmax, dtype, exp_hparams)

    # Sum of the approximated exponentials, in keepdims shape.
    asumexp = dimadd(
        lax.reduce(approx_exp, lax_numpy._constant_like(a, 0), lax.add,
                   norm_dims))

    # If sum_high_bound is set (and non-zero), clip sum(exp(x-M)) to
    # [sum_low_bound, sum_high_bound]. The lower bound is 1, or
    # 1 - exp(low_bound) when low_bound != 0 and clip_and_subtract is set.
    if exp_hparams.sum_high_bound is not None and exp_hparams.sum_high_bound != 0:
        sum_low_bound = 1.
        if (exp_hparams.low_bound != 0) and exp_hparams.clip_and_subtract:
            sum_low_bound = 1 - onp.exp(exp_hparams.low_bound)
        asumexp = jnp.clip(asumexp, sum_low_bound, exp_hparams.sum_high_bound)

    # Approximation of reciprocal.
    arecip = reciprocal(asumexp, dtype, recip_hparams)
    return lax.mul(approx_exp, arecip).astype(dtype)
Ejemplo n.º 25
0
def _threefry_split(key, num) -> jnp.ndarray:
    """Derive ``num`` keys from ``key``.

    Hashes the counters ``0 .. 2 * num - 1`` with ``threefry_2x32`` and groups
    the resulting uint32 words into ``num`` rows of two.
    """
    hashed = threefry_2x32(key, lax.iota(np.uint32, num * 2))
    return lax.reshape(hashed, (num, 2))
Ejemplo n.º 26
0
def expand_dims(a, axis):
    """Insert a length-one axis into ``a`` at position ``axis``.

    ``axis`` is taken modulo ``ndim(a) + 1``, so -1 appends a trailing axis.
    """
    a_shape = _shape(a)
    pos = axis % (ndim(a) + 1)
    return lax.reshape(a, a_shape[:pos] + (1, ) + a_shape[pos:])
Ejemplo n.º 27
0
 def testReshape(self, arg_shape, out_shape, dtype, dimensions, bdims,
                 rng_factory):
     """Runs _CheckBatching on lax.reshape with the given `dimensions`."""
     rng = rng_factory(self.rng())

     def reshape_fn(x):
         return lax.reshape(x, out_shape, dimensions=dimensions)

     self._CheckBatching(reshape_fn, 10, bdims, (arg_shape, ), (dtype, ), rng)
Ejemplo n.º 28
0
 def flatten(x):
     """Reshape ``x`` to 1-D with length ``x.shape[0] * x.shape[1]``.

     The reshape only succeeds when any remaining axes have total size 1,
     e.g. for a 2-D ``x``.
     """
     rows, cols = x.shape[0], x.shape[1]
     return lax.reshape(x, (rows * cols, ))
Ejemplo n.º 29
0
 def testReshape(self, arg_shape, out_shape, dtype, dimensions, bdims):
     """Runs _CheckBatching on lax.reshape with the given `dimensions`."""

     def reshape_fn(x):
         return lax.reshape(x, out_shape, dimensions=dimensions)

     rng = jtu.rand_default(self.rng())
     self._CheckBatching(reshape_fn, 10, bdims, (arg_shape, ), (dtype, ), rng)
Ejemplo n.º 30
0
def _rewriting_take(arr, idx, axis=0):
    """A function like numpy.take that handles boxes and rewrites to LAX.

    Dispatches on the kind of ``idx``:
      * ``()``, ``Ellipsis`` or ``slice(None)``: ``arr`` is returned unchanged.
      * ``None``: a new leading axis is inserted.
      * integer scalar: static indexing when concrete; when only abstract
        (traced), the index is wrapped with ``mod`` and indexed dynamically.
      * slice with static start/stop/step: ``lax.slice_in_dim``, followed by
        ``lax.rev`` when ``_static_idx`` reports a reversed slice.
      * tuple of 0-d entries: entries are applied one axis at a time, then
        ``None`` entries are re-inserted as length-1 axes.
      * advanced integer-array indexing, with or without slices/None.

    Args:
      arr: array to index.
      idx: the index expression.
      axis: axis that a non-tuple scalar or slice index applies to.

    Returns:
      The indexed array.

    Raises:
      IndexError: for non-static slice bounds or unsupported indexing modes.
    """

    # Handle special indexers: (), Ellipsis, slice(None), and None.
    # The empty tuple is detected by type and length: the previous `idx is ()`
    # identity test relied on CPython interning the empty tuple and emits a
    # SyntaxWarning ("is" with a literal) on Python 3.8+.
    idx_is_empty_tuple = isinstance(idx, tuple) and len(idx) == 0
    if idx_is_empty_tuple or idx is Ellipsis or _is_slice_none(idx):
        return arr
    elif idx is None:
        return expand_dims(arr, 0)

    # Handle int index
    _int = lambda aval: not aval.shape and onp.issubdtype(
        aval.dtype, onp.integer)
    try:
        abstract_idx = core.get_aval(idx)
    except TypeError:
        abstract_idx = None

    if isinstance(abstract_idx, ConcreteArray) and _int(abstract_idx):
        return lax.index_in_dim(arr, idx, axis, False)
    elif isinstance(abstract_idx, ShapedArray) and _int(abstract_idx):
        # Wrap negative indices before dynamic indexing.
        idx = mod(idx, arr.shape[axis])
        return lax.dynamic_index_in_dim(arr, idx, axis, False)

    # Handle slice index (only static, otherwise an error is raised)
    elif isinstance(idx, slice):
        if not _all(
                elt is None or isinstance(core.get_aval(elt), ConcreteArray)
                for elt in (idx.start, idx.stop, idx.step)):
            msg = (
                "Array slice indices must have static start/stop/step to be used "
                "with Numpy indexing syntax. Try lax.dynamic_slice instead.")
            raise IndexError(msg)
        else:
            start, limit, stride, needs_rev = _static_idx(idx, arr.shape[axis])
            result = lax.slice_in_dim(arr, start, limit, stride, axis=axis)
            return lax.rev(result, [axis]) if needs_rev else result

    # Handle non-advanced tuple indices by recursing once
    elif isinstance(idx, tuple) and _all(onp.ndim(elt) == 0 for elt in idx):
        canonical_idx = _canonicalize_tuple_index(arr, idx)
        result, axis = arr, 0
        for elt in (elt for elt in canonical_idx if elt is not None):
            result = _rewriting_take(result, elt, axis=axis)
            axis += isinstance(elt,
                               slice)  # advance axis index if not eliminated
        # Re-insert a length-1 axis for every None; int entries removed theirs.
        unexpanded_shape_itr = iter(result.shape)
        result_shape = tuple(1 if elt is None else next(unexpanded_shape_itr)
                             for elt in canonical_idx
                             if not isinstance(elt, int))
        return lax.reshape(result, result_shape)

    # Handle advanced indexing (non-tuple sequence, ndarray of dtype int or bool,
    # or a tuple with at least one sequence object).
    # https://docs.scipy.org/doc/numpy/reference/arrays.indexing.html#advanced-indexing
    # https://gist.github.com/seberg/976373b6a2b7c4188591

    # Handle integer array indexing *without* ellipsis/slices/nones
    # https://docs.scipy.org/doc/numpy/reference/arrays.indexing.html#integer-array-indexing
    if _is_advanced_int_indexer_without_slices(idx):
        if isinstance(idx, list):
            if _any(_shape(e) for e in idx):
                # At least one sequence element in the index list means broadcasting.
                idx = broadcast_arrays(*idx)
            else:
                # The index list is a flat list of integers.
                idx = [
                    lax.concatenate([lax.reshape(e, (1, )) for e in idx], 0)
                ]
        else:
            # The indexer is just a single integer array.
            idx = [idx]

        flat_idx = tuple(
            mod(ravel(x), arr.shape[i]) for i, x in enumerate(idx))
        out = lax.index_take(arr, flat_idx, tuple(range(len(idx))))
        return lax.reshape(out, idx[0].shape + _shape(arr)[len(idx):])

    # Handle integer array indexing *with* ellipsis/slices/nones by recursing once
    # https://docs.scipy.org/doc/numpy/reference/arrays.indexing.html#combining-advanced-and-basic-indexing
    elif _is_advanced_int_indexer(idx):
        canonical_idx = _canonicalize_tuple_index(arr, tuple(idx))
        # First apply only the basic (non-array) part of the index.
        idx_noadvanced = [
            slice(None) if _is_int(e) else e for e in canonical_idx
        ]
        arr_sliced = _rewriting_take(arr, tuple(idx_noadvanced))

        advanced_pairs = ((e, i) for i, e in enumerate(canonical_idx)
                          if _is_int(e))
        idx_advanced, axes = zip(*advanced_pairs)
        idx_advanced = broadcast_arrays(*idx_advanced)

        flat_idx = tuple(
            mod(ravel(x), arr_sliced.shape[i])
            for i, x in zip(axes, idx_advanced))
        out = lax.index_take(arr_sliced, flat_idx, axes)
        shape_suffix = tuple(onp.delete(_shape(arr_sliced), axes))
        out = lax.reshape(out, idx_advanced[0].shape + shape_suffix)

        # With contiguous advanced axes, the broadcast index dimensions are
        # moved back to the position of the first advanced axis.
        axes_are_contiguous = onp.all(onp.diff(axes) == 1)
        if axes_are_contiguous:
            start = axes[0]
            naxes = idx_advanced[0].ndim
            out = moveaxis(out, list(range(naxes)),
                           list(range(start, start + naxes)))
        return out

    msg = "Indexing mode not yet supported. Open a feature request!\n{}"
    raise IndexError(msg.format(idx))