Exemple #1
0
def lu_pivots_to_permutation(swaps, m):
    """Converts the pivots (row swaps) returned by LU to a permutation.

    We build a permutation rather than applying `swaps` directly to the rows
    of a matrix because lax loops aren't differentiable.

    Args:
      swaps: an array of shape (..., k) of row swaps to perform
      m: the size of the output permutation. m should be >= k.
    Returns:
      An int32 array of shape (..., m).
    """
    assert len(swaps.shape) >= 1
    batch_dims = swaps.shape[:-1]
    k = swaps.shape[-1]

    # Open-mesh index arrays addressing every batch element. They depend only
    # on the static batch shape, so build them once here instead of inside
    # the loop body on every trace.
    iotas = np.ix_(*(lax.iota(np.int32, b) for b in batch_dims))

    def body_fn(i, permutation):
        # For every batch element, exchange entry i with entry swaps[..., i]
        # along the last axis of the permutation.
        j = swaps[..., i]
        x = permutation[..., i]
        y = permutation[iotas + (j, )]
        permutation = ops.index_update(permutation, ops.index[..., i], y)
        return ops.index_update(permutation, ops.index[iotas + (j, )], x)

    # Start from the identity permutation 0..m-1 along the last axis,
    # repeated over the batch dimensions.
    permutation = lax.broadcasted_iota(np.int32, batch_dims + (m, ),
                                       len(batch_dims))
    return lax.fori_loop(onp.array(0, onp.int32), onp.array(k, onp.int32),
                         body_fn, permutation)
Exemple #2
0
Fichier : ann.py Projet : gtr8/jax
def _approx_top_k_jvp(primals, tangents, *, k, reduction_dimension,
                      recall_target, is_max_k, reduction_input_size_override,
                      aggregate_to_topk):
    """JVP rule for the approximate top-k computation.

    Re-runs ``approx_max_k`` (when ``is_max_k``) or ``approx_min_k`` on the
    primal operand. The output tangent is the input tangent gathered at the
    returned indices along ``reduction_dimension``. The index output gets a
    symbolic-zero tangent.

    Returns:
      ``((values, indices), (values_tangent, indices_tangent))``.
    """
    (operand,) = primals
    (tangent,) = tangents

    top_k = approx_max_k if is_max_k else approx_min_k
    val_out, arg_out = top_k(operand, k, reduction_dimension, recall_target,
                             reduction_input_size_override, aggregate_to_topk)

    if type(tangent) is ad_util.Zero:
        tangent_out = ad_util.Zero.from_value(val_out)
    else:
        out_shape = arg_out.shape
        ndims = len(out_shape)
        axis = (reduction_dimension + ndims
                if reduction_dimension < 0 else reduction_dimension)
        # Along the reduction axis, index with the selected positions. Along
        # every other axis, index with that axis' own coordinates (an iota),
        # so the gather is element-wise.
        gather_idx = tuple(
            arg_out if d == axis else lax.broadcasted_iota(
                arg_out.dtype, out_shape, d)
            for d in range(ndims))
        tangent_out = tangent[gather_idx]

    return (val_out, arg_out), (tangent_out, ad_util.Zero.from_value(arg_out))
Exemple #3
0
def svd(a,
        full_matrices: bool = True,
        compute_uv: bool = True,
        hermitian: bool = False):
    """Singular value decomposition of ``a``.

    Non-hermitian inputs are forwarded to ``lax_linalg.svd``.

    For ``hermitian=True`` the result is derived from ``lax_linalg.eigh``.
    The singular values are the absolute eigenvalues, sorted in descending
    order. ``u`` holds the matching eigenvector columns. ``vh`` is
    ``_H(u * sign)``, where ``sign`` is the sign of each eigenvalue (``_H``
    is presumably the conjugate transpose). ``full_matrices`` is not
    consulted on the hermitian path.

    Returns:
      ``(u, s, vh)`` when ``compute_uv`` is true, otherwise only ``s``.
    """
    a, = _promote_dtypes_inexact(jnp.asarray(a))
    if not hermitian:
        return lax_linalg.svd(a,
                              full_matrices=full_matrices,
                              compute_uv=compute_uv)

    # eigh's first result is used as the eigenvector matrix, its second as
    # the eigenvalues.
    eigvecs, eigvals = lax_linalg.eigh(a)
    s = lax.abs(eigvals)
    last_axis = s.ndim - 1
    if not compute_uv:
        return lax.rev(lax.sort(s, dimension=-1), dimensions=[last_axis])

    sign = lax.sign(eigvals)
    order = lax.broadcasted_iota(np.int64, s.shape, dimension=last_axis)
    # Sort ascending by magnitude (carrying original positions and signs
    # along), then reverse every array to get descending singular values.
    ascending = lax.sort((s, order, sign), dimension=-1, num_keys=1)
    s, order, sign = (lax.rev(arr, dimensions=[last_axis])
                      for arr in ascending)
    u = jnp.take_along_axis(eigvecs, order[..., None, :], axis=-1)
    vh = _H(u * sign[..., None, :])
    return u, s, vh
Exemple #4
0
def threefry_random_bits(key: jnp.ndarray, bit_width, shape):
    """Sample uniform random bits of given width and shape using PRNG key.

    Args:
      key: PRNG key; must satisfy ``_is_threefry_prng_key``.
      bit_width: bits per sample; one of 8, 16, 32 or 64.
      shape: requested output shape, normalized with
        ``core.as_named_shape``. For each named axis, the index along that
        axis is folded into ``key``. The positional dimensions determine how
        many values are generated.

    Returns:
      The generated bits, passed through ``lax.reshape(bits, shape)``. For
      8-, 16- and 64-bit widths the values are converted to
      ``UINT_DTYPES[bit_width]``. For 32-bit widths they are the words
      produced by ``threefry_2x32``, unchanged.

    Raises:
      TypeError: if ``key`` is not a threefry key, or ``bit_width`` is not a
        supported width.
      ValueError: if a named axis's declared size differs from
        ``lax.psum(1, name)``.
    """
    if not _is_threefry_prng_key(key):
        raise TypeError("_random_bits got invalid prng key.")
    if bit_width not in (8, 16, 32, 64):
        raise TypeError("requires 8-, 16-, 32- or 64-bit field width.")
    shape = core.as_named_shape(shape)
    for name, size in shape.named_items:
        # psum of 1 over the named axis gives the axis' actual size, which
        # must match the size declared in `shape`.
        real_size = lax.psum(1, name)
        if real_size != size:
            raise ValueError(
                f"The shape of axis {name} was specified as {size}, "
                f"but it really is {real_size}")
        # Fold the index along this named axis into the key (presumably so
        # each position along the axis draws different bits).
        axis_index = lax.axis_index(name)
        key = threefry_fold_in(key, axis_index)
    size = prod(shape.positional)
    # Compute ceil(bit_width * size / 32) in a way that is friendly to shape
    # polymorphism
    max_count, r = divmod(bit_width * size, 32)
    if r > 0:
        max_count += 1

    # max_count is the number of 32-bit words needed. Each threefry_2x32
    # call below hashes a uint32 iota of at most iinfo(uint32).max counters,
    # so a constant-sized request is split into `nblocks` full blocks plus
    # `rem` leftover words.
    if core.is_constant_dim(max_count):
        nblocks, rem = divmod(max_count, jnp.iinfo(np.uint32).max)
    else:
        # Symbolic (shape-polymorphic) count: cannot be split, so generate it
        # as a single block.
        nblocks, rem = 0, max_count

    if not nblocks:
        bits = threefry_2x32(key, lax.iota(np.uint32, rem))
    else:
        # One subkey per full block plus a final key for the remainder.
        keys = threefry_split(key, nblocks + 1)
        subkeys, last_key = keys[:-1], keys[-1]
        blocks = vmap(threefry_2x32,
                      in_axes=(0, None))(subkeys,
                                         lax.iota(np.uint32,
                                                  jnp.iinfo(np.uint32).max))
        last = threefry_2x32(last_key, lax.iota(np.uint32, rem))
        bits = lax.concatenate([blocks.ravel(), last], 0)

    dtype = UINT_DTYPES[bit_width]
    if bit_width == 64:
        # Split the 32-bit words into two halves. The first half supplies the
        # high 32 bits and the second half the low 32 bits of each sample.
        bits = [lax.convert_element_type(x, dtype) for x in jnp.split(bits, 2)]
        bits = lax.shift_left(bits[0], dtype(32)) | bits[1]
    elif bit_width in [8, 16]:
        # this is essentially bits.view(dtype)[:size]
        # Each 32-bit word is shifted right by 0, bit_width, 2*bit_width, ...
        # and masked to bit_width bits. That yields 32 // bit_width
        # sub-words per word, arranged as a (32 // bit_width, num_words)
        # array (presumably via size-1 broadcasting of the two operands).
        bits = lax.bitwise_and(
            np.uint32(np.iinfo(dtype).max),
            lax.shift_right_logical(
                lax.broadcast(bits, (1, )),
                lax.mul(
                    np.uint32(bit_width),
                    lax.broadcasted_iota(np.uint32, (32 // bit_width, 1), 0))))
        # dimensions=(1, 0) transposes before flattening, so the sub-words of
        # each 32-bit word end up adjacent, lowest bits first.
        bits = lax.reshape(bits, (np.uint32(max_count * 32 // bit_width), ),
                           (1, 0))
        # Drop the padding words generated by rounding up to whole words.
        bits = lax.convert_element_type(bits, dtype)[:size]
    return lax.reshape(bits, shape)
Exemple #5
0
def _mask(x, dims, alternative=0):
  """Return `x` with entries outside the dynamic extents `dims` replaced.

  A position `idx` is kept when `idx[i] < dims[i]` holds for every `i` whose
  `dims[i]` is not None. All other positions take `alternative`, which is
  broadcast against `x`. If every entry of `dims` is None, `x` itself is
  returned.
  """
  assert jnp.ndim(x) == len(dims)
  masks = [lax.broadcasted_iota(jnp.int32, x.shape, axis) < limit
           for axis, limit in enumerate(dims) if limit is not None]
  if not masks:
    return x
  keep = masks[0]
  for extra in masks[1:]:
    keep = keep & extra
  return jnp.where(keep, x, alternative)
Exemple #6
0
def _intersect1d_sorted_mask(ar1, ar2, return_indices=False):
    """Jit-compatible helper for intersect1d.

    Sorts the concatenation of ``ar1`` and ``ar2`` and flags equal neighbors.

    Returns:
      ``(aux, mask)``: ``aux`` is the sorted concatenation and ``mask[i]`` is
      true when ``aux[i] == aux[i + 1]``. When ``return_indices`` is true, a
      third element ``indices`` gives, for each entry of ``aux``, its position
      in the concatenated input.
    """
    combined = concatenate((ar1, ar2))
    if not return_indices:
        aux = sort(combined)
        return aux, aux[1:] == aux[:-1]

    positions = lax.broadcasted_iota(np.int64, np.shape(combined),
                                     dimension=0)
    aux, indices = lax.sort_key_val(combined, positions)
    return aux, aux[1:] == aux[:-1], indices
Exemple #7
0
def lu_pivots_to_permutation(swaps, m):
    """Converts the pivots (row swaps) returned by LU to a permutation.

    A permutation is built rather than applying `swaps` directly to the rows
    of a matrix because lax loops aren't differentiable.

    Args:
      swaps: an array of shape (..., k) of row swaps to perform
      m: the size of the output permutation. m should be >= k.
    Returns:
      An int32 array of shape (..., m).
    """
    assert len(swaps.shape) >= 1
    batch_dims = swaps.shape[:-1]
    num_swaps = swaps.shape[-1]
    last_axis = len(batch_dims)

    # Identity permutation 0..m-1 along the last axis, for every batch element.
    identity = lax.broadcasted_iota(np.int32, batch_dims + (m, ), last_axis)
    start = onp.array(0, onp.int32)
    stop = onp.array(num_swaps, onp.int32)
    # The loop carry is (permutation, swaps); only the permutation is returned.
    final_carry = lax.fori_loop(start, stop, _lu_pivots_body_fn,
                                (identity, swaps))
    final_permutation, _ = final_carry
    return final_permutation
Exemple #8
0
def _one_hot(x: Array, num_classes: int, *, dtype: Any,
             axis: Union[int, AxisName]) -> Array:
    """One-hot encode ``x`` into ``num_classes`` classes along ``axis``.

    If ``axis`` can be canonicalized as an integer axis of the
    ``x.ndim + 1``-dimensional output, a new axis is inserted there. Each
    element of ``x`` is then compared against the class indices
    ``0..num_classes-1`` laid out along that axis.

    Otherwise (``util.canonicalize_axis`` raises ``TypeError``), ``axis`` is
    treated as a named axis. ``num_classes`` must equal its size, and the
    result is ``x == lax.axis_index(axis)``.

    In both cases the boolean comparison is cast to ``dtype``.

    Raises:
      ValueError: for a named axis whose size differs from ``num_classes``.
    """
    num_classes = core.concrete_or_error(
        int, num_classes,
        "The error arose in jax.nn.one_hot argument `num_classes`.")
    dtype = dtypes.canonicalize_dtype(dtype)
    x = jnp.asarray(x)
    try:
        class_axis = util.canonicalize_axis(axis, x.ndim + 1)
    except TypeError:
        # Not an integer axis: interpret `axis` as a named axis.
        named_size = lax.psum(1, axis)
        if num_classes != named_size:
            raise ValueError(
                f"Expected num_classes to match the size of axis {axis}, "
                f"but {num_classes} != {named_size}") from None
        return jnp.asarray(x == lax.axis_index(axis), dtype=dtype)

    int_axis = operator.index(axis)  # type: ignore[arg-type]
    expanded = lax.expand_dims(x, (int_axis, ))
    # Shape with num_classes at the new axis and size 1 everywhere else.
    iota_shape = [num_classes if d == class_axis else 1
                  for d in range(x.ndim + 1)]
    class_ids = lax.broadcasted_iota(x.dtype, iota_shape, class_axis)
    return jnp.asarray(expanded == class_ids, dtype=dtype)