Example #1
def threefry_seed(seed: int) -> jnp.ndarray:
    """Create a single raw threefry PRNG key given an integer seed.

  Args:
    seed: a 64- or 32-bit integer used as the value of the key.

  Returns:
    The PRNG key contents, modeled as an array of shape (2,) and dtype
    uint32. The key is constructed from a 64-bit seed by effectively
    bit-casting to a pair of uint32 values (or from a 32-bit seed by
    first padding out with zeros).
  """
    # Avoid an OverflowError in X32 mode by first converting ints to int64.
    # This breaks JIT invariance for large ints, but supports the common
    # use-case of instantiating with Python hashes in X32 mode.
    if isinstance(seed, int):
        seed_arr = jnp.asarray(np.int64(seed))
    else:
        seed_arr = jnp.asarray(seed)
    if seed_arr.shape:
        raise TypeError(f"PRNG key seed must be a scalar; got {seed!r}.")
    if not np.issubdtype(seed_arr.dtype, np.integer):
        raise TypeError(f"PRNG key seed must be an integer; got {seed!r}")

    convert = lambda k: lax.reshape(lax.convert_element_type(k, np.uint32),
                                    [1])
    k1 = convert(
        lax.shift_right_logical(seed_arr, lax_internal._const(seed_arr, 32)))
    k2 = convert(jnp.bitwise_and(seed_arr, np.uint32(0xFFFFFFFF)))
    return lax.concatenate([k1, k2], 0)
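
As a quick illustration of the bit-splitting described in the docstring, here is a standalone NumPy sketch (not part of the function above) showing how a 64-bit seed maps to the two uint32 key words:

import numpy as np

seed = 0x0123456789ABCDEF
hi = np.uint32(seed >> 32)           # upper 32 bits -> first key word
lo = np.uint32(seed & 0xFFFFFFFF)    # lower 32 bits -> second key word
print(hi, lo)                        # the (2,)-shaped uint32 key holds [hi, lo]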
Example #2
def _cond_batching_rule(axis_size, axis_name, main_type, args, dims, branches,
                        linear):
    index, *ops = args
    index_dim, *op_dims = dims

    if index_dim is not batching.not_mapped:
        # Convert to a lax.select. While we could get away with not
        # broadcasting some operands yet, because all outputs must be
        # broadcast together for the select anyway, we broadcast the input
        # operands for simplicity and leave optimizations to XLA.
        # TODO(mattjj,frostig): assumes branches are side-effect-free, revise!
        index, *ops = (batching.bdim_at_front(x, d, axis_size)
                       for x, d in zip(args, dims))

        in_batched = [True] * len(branches[0].in_avals)
        out_batched = [True] * len(branches[0].out_avals)

        branches_batched = [
            batching.batch_jaxpr(jaxpr, axis_size, in_batched, out_batched,
                                 axis_name, main_type)[0] for jaxpr in branches
        ]

        branch_outs = []
        for i, jaxpr in enumerate(branches_batched):
            # Perform a select on the inputs for safety of reverse-mode autodiff; see
            # https://github.com/google/jax/issues/1052
            predicate = lax.eq(index, lax._const(index, i))
            ops_ = [
                _bcast_select(predicate, x, lax.stop_gradient(x)) for x in ops
            ]
            branch_outs.append(core.jaxpr_as_fun(jaxpr)(*ops_))
        out = [_bcast_select_n(index, *outs) for outs in zip(*branch_outs)]
        return out, [0 if b else None for b in out_batched]
    else:
        ops_bat = [d is not batching.not_mapped for d in op_dims]
        ops = [
            batching.moveaxis(x, d, 0) if b else x
            for b, x, d in zip(ops_bat, ops, op_dims)
        ]

        branches_out_bat = [
            batching.batch_jaxpr(jaxpr, axis_size, ops_bat, False, axis_name,
                                 main_type)[1] for jaxpr in branches
        ]
        out_bat = [any(bat) for bat in zip(*branches_out_bat)]
        branches_batched = tuple(
            batching.batch_jaxpr(jaxpr, axis_size, ops_bat, out_bat, axis_name,
                                 main_type)[0] for jaxpr in branches)

        out_dims = [0 if b else batching.not_mapped for b in out_bat]
        out = cond_p.bind(index,
                          *ops,
                          branches=branches_batched,
                          linear=linear)
        return out, out_dims
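
For context, a sketch of the user-facing behavior this rule enables: vmapping lax.cond over a batched predicate takes the first branch above, so both branches run and their outputs are combined with a select (illustrative only):

import jax
import jax.numpy as jnp
from jax import lax

def f(pred, x):
    # With a batched `pred`, vmap evaluates both branches and selects.
    return lax.cond(pred, lambda x: x + 1.0, lambda x: x * 2.0, x)

preds = jnp.array([True, False, True])
xs = jnp.arange(3.0)
print(jax.vmap(f)(preds, xs))  # [1. 2. 3.]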
Example #3
    def _eye(cls, N, M, k, *, dtype=None, index_dtype='int32'):
        if k > 0:
            diag_size = min(N, M - k)
        else:
            diag_size = min(N + k, M)

        if diag_size <= 0:
            # if k is out of range, return an empty matrix.
            return cls._empty((N, M), dtype=dtype, index_dtype=index_dtype)

        k = jnp.asarray(k)
        data = jnp.ones(diag_size, dtype=dtype)
        idx = jnp.arange(diag_size, dtype=index_dtype)
        zero = _const(idx, 0)
        k = _const(idx, k)
        row = lax.sub(idx, lax.cond(k >= 0, lambda: zero, lambda: k))
        col = lax.add(idx, lax.cond(k <= 0, lambda: zero, lambda: k))
        return cls((data, row, col),
                   shape=(N, M),
                   rows_sorted=True,
                   cols_sorted=True)
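
To see the index arithmetic in isolation, here is a NumPy-only sketch (hypothetical values, not the class method itself) of how diag_size, row, and col come out for a positive offset:

import numpy as np

N, M, k = 4, 5, 2
diag_size = min(N, M - k) if k > 0 else min(N + k, M)   # 3
idx = np.arange(diag_size)
row = idx - (0 if k >= 0 else k)    # [0 1 2]
col = idx + (k if k > 0 else 0)     # [2 3 4]
print(row, col)                     # ones at (0, 2), (1, 3), (2, 4)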
Example #4
    def _eye(cls, N, M, k, *, dtype=None, index_dtype='int32'):
        if k > 0:
            diag_size = min(N, M - k)
        else:
            diag_size = min(N + k, M)

        if diag_size <= 0:
            # if k is out of range, return an empty matrix.
            return cls._empty((N, M), dtype=dtype, index_dtype=index_dtype)

        k = jnp.asarray(k)
        data = jnp.ones(diag_size, dtype=dtype)
        idx = jnp.arange(diag_size, dtype=index_dtype)
        zero = _const(idx, 0)
        k = _const(idx, k)
        col = lax.add(idx, lax.cond(k <= 0, lambda: zero, lambda: k))
        indices = col.astype(index_dtype)
        # TODO(jakevdp): this can be done more efficiently.
        row = lax.sub(idx, lax.cond(k >= 0, lambda: zero, lambda: k))
        indptr = jnp.zeros(N + 1, dtype=index_dtype).at[1:].set(
            jnp.cumsum(jnp.bincount(row, length=N).astype(index_dtype)))
        return cls((data, indices, indptr), shape=(N, M))
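
A brief NumPy sketch of the indptr construction: indptr[i + 1] holds the cumulative count of stored entries in rows 0..i, which is exactly what the bincount/cumsum pair computes:

import numpy as np

N = 4
row = np.array([0, 1, 2])                  # row index of each stored entry
counts = np.bincount(row, minlength=N)     # entries per row: [1 1 1 0]
indptr = np.zeros(N + 1, dtype=np.int32)
indptr[1:] = np.cumsum(counts)             # [0 1 2 3 3]
print(indptr)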
Example #5
def _atan2_taylor(primals_in, series_in):
    x, y = primals_in
    primal_out = lax.atan2(x, y)

    x, series = jet(lax.div, primals_in, series_in)
    one = lax_internal._const(x, 1)
    c0, cs = jet(lambda x: lax.div(one, 1 + lax.square(x)), (x, ), (series, ))
    c = [c0] + cs
    u = [x] + series
    v = [primal_out] + [None] * len(series)
    for k in range(1, len(v)):
        v[k] = fact(k - 1) * sum(
            _scale(k, j) * c[k - j] * u[j] for j in range(1, k + 1))
    primal_out, *series_out = v
    return primal_out, series_out
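
This rule backs jax.experimental.jet for atan2; a hedged usage sketch (arbitrary values, following the documented jet calling convention of one coefficient list per input):

import jax.numpy as jnp
from jax.experimental.jet import jet

primals = (1.0, 2.0)                  # expansion point (x, y)
series = ([0.1, 0.0], [0.2, 0.0])     # one coefficient list per input
primal_out, series_out = jet(jnp.arctan2, primals, series)
print(primal_out, series_out)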
Example #6
def _reduce_window_prod(operand: Array, window_dimensions: core.Shape,
                        window_strides: Sequence[int],
                        padding: Sequence[Tuple[int, int]],
                        base_dilation: Optional[Sequence[int]] = None,
                        window_dilation: Optional[Sequence[int]] = None) -> Array:
  init_value = lax._const(operand, 1)
  jaxpr, consts = lax._reduction_jaxpr(lax.mul, lax._abstractify(init_value))
  if base_dilation is None:
    base_dilation = (1,) * len(window_dimensions)
  if window_dilation is None:
    window_dilation = (1,) * len(window_dimensions)
  out, = reduce_window_p.bind(
      operand, init_value, jaxpr=jaxpr, consts=consts,
      window_dimensions=tuple(window_dimensions),
      window_strides=tuple(window_strides), padding=tuple(padding),
      base_dilation=tuple(base_dilation),
      window_dilation=tuple(window_dilation))
  return out
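
For reference, the same windowed product can be written with the public lax.reduce_window API (a minimal sketch; parameters chosen for illustration):

import jax.numpy as jnp
from jax import lax

x = jnp.array([1., 2., 3., 4.])
out = lax.reduce_window(x, 1., lax.mul,
                        window_dimensions=(2,),
                        window_strides=(1,),
                        padding=((0, 0),))
print(out)  # [ 2.  6. 12.] -- product over each length-2 window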
Example #7
def deriv_prop(prim, deriv, primals_in, series_in):
    x, = primals_in
    series, = series_in
    primal_out = prim.bind(x)
    c0, cs = jet(deriv, primals_in, series_in)
    c = [c0] + cs
    u = [x] + series
    v = [primal_out] + [None] * len(series)
    for k in range(1, len(v)):
        v[k] = fact(k - 1) * sum(
            _scale(k, j) * c[k - j] * u[j] for j in range(1, k + 1))
    primal_out, *series_out = v
    return primal_out, series_out


def_deriv(
    lax.erf_p, lambda x: lax.mul(lax._const(x, 2. / np.sqrt(np.pi)),
                                 lax.exp(lax.neg(lax.square(x)))))


def def_comp(prim, comp):
    """
  Define the jet rule for a primitive in terms of a composition of simpler primitives.
  """
    jet_rules[prim] = partial(jet, comp)


def_comp(lax.expm1_p, lambda x: lax.exp(x) - 1)
def_comp(lax.log1p_p, lambda x: lax.log(1 + x))
def_comp(lax.sqrt_p, lambda x: x**0.5)
def_comp(lax.rsqrt_p, lambda x: x**-0.5)
def_comp(lax.asinh_p, lambda x: lax.log(x + lax.sqrt(lax.square(x) + 1)))
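
Once registered via def_comp, these composition rules let jet expand the corresponding primitives directly, e.g. (a hedged sketch following the documented jet calling convention):

import jax.numpy as jnp
from jax.experimental.jet import jet

# Taylor-expand expm1 around 0.5 with input series [1., 0., 0.].
primal_out, coeffs = jet(jnp.expm1, (0.5,), ([1.0, 0.0, 0.0],))
print(primal_out, coeffs)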
Example #8
def deriv_prop(prim, deriv, primals_in, series_in):
  x, = primals_in
  series, = series_in
  primal_out = prim.bind(x)
  c0, cs = jet(deriv, primals_in, series_in)
  c = [c0] + cs
  u = [x] + series
  v = [primal_out] + [None] * len(series)
  for k in range(1, len(v)):
    v[k] = fact(k-1) * sum(_scale(k, j) * c[k-j] * u[j] for j in range(1, k + 1))
  primal_out, *series_out = v
  return primal_out, series_out


def_deriv(lax.erf_p, lambda x: lax.mul(lax._const(x, 2. / np.sqrt(np.pi)), lax.exp(lax.neg(lax.square(x)))))


def def_comp(prim, comp):
  """
  Define the jet rule for a primitive in terms of a composition of simpler primitives.
  """
  jet_rules[prim] = partial(jet, comp)


def_comp(lax.expm1_p, lambda x: lax.exp(x) - 1)
def_comp(lax.log1p_p, lambda x: lax.log(1 + x))
def_comp(lax.sqrt_p, lambda x: x ** 0.5)
def_comp(lax.rsqrt_p, lambda x: x ** -0.5)
def_comp(lax.asinh_p, lambda x: lax.log(x + lax.sqrt(lax.square(x) + 1)))
def_comp(lax.acosh_p, lambda x: lax.log(x + lax.sqrt(lax.square(x) - 1)))
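
Likewise, the def_deriv registration above lets jet propagate a series through lax.erf by repeatedly applying its closed-form derivative (illustrative call, arbitrary values):

from jax import lax
from jax.experimental.jet import jet

primal_out, coeffs = jet(lax.erf, (0.3,), ([1.0, 0.0],))
print(primal_out, coeffs)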
Example #9
def norm(x,
         ord=None,
         axis: Union[None, Tuple[int, ...], int] = None,
         keepdims=False):
    x = _promote_arg_dtypes(jnp.asarray(x))
    x_shape = jnp.shape(x)
    ndim = len(x_shape)

    if axis is None:
        # NumPy has an undocumented behavior that admits arbitrary rank inputs if
        # `ord` is None: https://github.com/numpy/numpy/issues/14215
        if ord is None:
            return jnp.sqrt(
                jnp.sum(jnp.real(x * jnp.conj(x)), keepdims=keepdims))
        axis = tuple(range(ndim))
    elif isinstance(axis, tuple):
        axis = tuple(canonicalize_axis(x, ndim) for x in axis)
    else:
        axis = (canonicalize_axis(axis, ndim), )

    num_axes = len(axis)
    if num_axes == 1:
        if ord is None or ord == 2:
            return jnp.sqrt(
                jnp.sum(jnp.real(x * jnp.conj(x)),
                        axis=axis,
                        keepdims=keepdims))
        elif ord == jnp.inf:
            return jnp.amax(jnp.abs(x), axis=axis, keepdims=keepdims)
        elif ord == -jnp.inf:
            return jnp.amin(jnp.abs(x), axis=axis, keepdims=keepdims)
        elif ord == 0:
            return jnp.sum(x != 0,
                           dtype=jnp.finfo(lax.dtype(x)).dtype,
                           axis=axis,
                           keepdims=keepdims)
        elif ord == 1:
            # Numpy has a special case for ord == 1 as an optimization. We don't
            # really need the optimization (XLA could do it for us), but the Numpy
            # code has slightly different type promotion semantics, so we need a
            # special case too.
            return jnp.sum(jnp.abs(x), axis=axis, keepdims=keepdims)
        else:
            abs_x = jnp.abs(x)
            ord = lax_internal._const(abs_x, ord)
            out = jnp.sum(abs_x**ord, axis=axis, keepdims=keepdims)
            return jnp.power(out, 1. / ord)

    elif num_axes == 2:
        row_axis, col_axis = cast(Tuple[int, ...], axis)
        if ord is None or ord in ('f', 'fro'):
            return jnp.sqrt(
                jnp.sum(jnp.real(x * jnp.conj(x)),
                        axis=axis,
                        keepdims=keepdims))
        elif ord == 1:
            if not keepdims and col_axis > row_axis:
                col_axis -= 1
            return jnp.amax(jnp.sum(jnp.abs(x),
                                    axis=row_axis,
                                    keepdims=keepdims),
                            axis=col_axis,
                            keepdims=keepdims)
        elif ord == -1:
            if not keepdims and col_axis > row_axis:
                col_axis -= 1
            return jnp.amin(jnp.sum(jnp.abs(x),
                                    axis=row_axis,
                                    keepdims=keepdims),
                            axis=col_axis,
                            keepdims=keepdims)
        elif ord == jnp.inf:
            if not keepdims and row_axis > col_axis:
                row_axis -= 1
            return jnp.amax(jnp.sum(jnp.abs(x),
                                    axis=col_axis,
                                    keepdims=keepdims),
                            axis=row_axis,
                            keepdims=keepdims)
        elif ord == -jnp.inf:
            if not keepdims and row_axis > col_axis:
                row_axis -= 1
            return jnp.amin(jnp.sum(jnp.abs(x),
                                    axis=col_axis,
                                    keepdims=keepdims),
                            axis=row_axis,
                            keepdims=keepdims)
        elif ord in ('nuc', 2, -2):
            x = jnp.moveaxis(x, axis, (-2, -1))
            if ord == 2:
                reducer = jnp.amax
            elif ord == -2:
                reducer = jnp.amin
            else:
                reducer = jnp.sum
            y = reducer(svd(x, compute_uv=False), axis=-1)
            if keepdims:
                y = jnp.expand_dims(y, axis)
            return y
        else:
            raise ValueError("Invalid order '{}' for matrix norm.".format(ord))
    else:
        raise ValueError(
            "Invalid axis values ({}) for jnp.linalg.norm.".format(axis))