Пример #1
0
def _mul_sparse(spenv, *argspecs):
    """Multiply two operands, at least one of which is expected to be sparse.

    Args:
        spenv: environment used to look up operand buffers and to register
            (``push``) the resulting data buffer.
        *argspecs: exactly two operand specs.

    Returns:
        A one-element tuple holding the ``ArgSpec`` of the product.

    Raises:
        NotImplementedError: for sparse-sparse operands with different shapes,
            different batch/dense dimensions or different index buffers, and
            for a dense operand that is neither 0-d nor shaped like the
            sparse operand.
    """
    X, Y = argspecs

    if not (X.is_sparse() and Y.is_sparse()):
        # Mixed case: make sure the sparse operand is the left one.
        sp_arg, other = (Y, X) if Y.is_sparse() else (X, Y)
        other_data = other.data(spenv)
        if other_data.ndim == 0:
            # A 0-d operand scales every stored value.
            out_data = lax.mul(sp_arg.data(spenv), other_data)
        elif other_data.shape == sp_arg.shape:
            # Same-shape operand: pick its entries at the stored indices.
            stored = sp_arg.data(spenv)
            picked = sparse.bcoo_extract(sp_arg.indices(spenv), other_data)
            out_data = lax.mul(stored, picked)
        else:
            raise NotImplementedError(
                "Multiplication between sparse and dense matrices of different shape."
            )
        return (ArgSpec(sp_arg.shape, spenv.push(out_data),
                        sp_arg.indices_ref), )

    if X.shape != Y.shape:
        raise NotImplementedError(
            "Multiplication between sparse matrices of different shapes.")

    if X.indices_ref == Y.indices_ref:
        # Both operands share one index buffer, so the data buffers line up.
        out_data = lax.mul(X.data(spenv), Y.data(spenv))
        return (ArgSpec(X.shape, spenv.push(out_data), X.indices_ref), )

    same_index_rank = X.indices(spenv).ndim == Y.indices(spenv).ndim
    if not same_index_rank or X.data(spenv).ndim != Y.data(spenv).ndim:
        raise NotImplementedError(
            "Multiplication between sparse matrices with different batch/dense dimensions."
        )
    raise NotImplementedError(
        "Multiplication between sparse matrices with different sparsity patterns."
    )
Пример #2
0
def logpdf(x, loc=0, scale=1):
    """Log-density of the Cauchy distribution.

    Computes ``-(log(pi * scale) + log1p(z**2))`` with
    ``z = (x - loc) / scale``, after promoting the arguments to a common
    inexact dtype.
    """
    x, loc, scale = _promote_args_inexact("cauchy.logpdf", x, loc, scale)
    z = lax.div(lax.sub(x, loc), scale)
    log_norm = lax.log(lax.mul(_constant_like(x, np.pi), scale))
    log_kernel = lax.log1p(lax.mul(z, z))
    return lax.neg(lax.add(log_norm, log_kernel))
Пример #3
0
 def cfun(x):
   """Piecewise multiple of ``x`` built from nested ``lax.cond`` calls.

   Returns ``2 * x`` when ``x < 2``, ``3 * x`` when ``2 <= x < 5`` and
   ``4 * x`` otherwise.
   """
   def at_least_two(x):
     # The false branch receives the constant 4 as operand and multiplies it
     # by the closed-over ``x``.
     return lax.cond(lax.lt(x, 5),
                     x, lambda x: lax.mul(3, x),
                     4, lambda y: lax.mul(y, x))

   return lax.cond(lax.lt(x, 2),
                   x, lambda x: lax.mul(2, x),
                   x, at_least_two)
Пример #4
0
 def inner_cond(x):
   """Return ``3 * x`` when ``x < 5`` and ``4 * x`` otherwise, via ``lax.cond``."""
   def triple(v):
     return lax.mul(3, v)

   def times_x(y):
     # ``y`` is the false-branch operand (the constant 4); ``x`` is closed over.
     return lax.mul(y, x)

   below_five = lax.lt(x, 5)
   return lax.cond(below_five, x, triple, 4, times_x)
Пример #5
0
def cdf(x, loc=0, scale=1):
    """Cumulative distribution function of the Laplace distribution.

    With ``z = (x - loc) / scale`` the result is ``0.5 * exp(z)`` where
    ``z <= 0`` and ``1 - 0.5 * exp(-z)`` elsewhere.
    """
    x, loc, scale = _promote_args_inexact("laplace.cdf", x, loc, scale)
    half = _constant_like(x, 0.5)
    one = _constant_like(x, 1)
    zero = _constant_like(x, 0)
    z = lax.div(lax.sub(x, loc), scale)
    lower_tail = lax.mul(half, lax.exp(z))
    upper_tail = lax.sub(one, lax.mul(half, lax.exp(lax.neg(z))))
    return lax.select(lax.le(z, zero), lower_tail, upper_tail)
Пример #6
0
 def fun(x):
     """Return ``2*x`` for ``x < 2``, ``3*x`` for ``2 <= x < 5``, else ``4*x``."""
     if x < 2:
         factor = 2
     elif x < 5:
         factor = 3
     else:
         factor = 4
     return lax.mul(factor, x)
Пример #7
0
def _logaddexp2_jvp(primals, tangents):
  """JVP rule for ``logaddexp2``.

  The output tangent is ``t1 * 2**(x1 - out) + t2 * 2**(x2 - out)``, where
  every operand of the subtractions is first passed through ``_replace_inf``.

  Returns:
    ``(primal_out, tangent_out)``.
  """
  x1, x2 = primals
  t1, t2 = tangents
  x1, x2, t1, t2 = _promote_args_inexact("logaddexp2_jvp", x1, x2, t1, t2)
  primal_out = logaddexp2(x1, x2)

  def weighted(tangent, primal):
    # Contribution of one input: tangent * 2**(primal - primal_out).
    shift = lax.sub(_replace_inf(primal), _replace_inf(primal_out))
    return lax.mul(tangent, exp2(shift))

  tangent_out = lax.add(weighted(t1, x1), weighted(t2, x2))
  return primal_out, tangent_out
Пример #8
0
 def fun(x):
   """Return ``2*x`` for ``x < 2``, ``3*x`` for ``2 <= x < 5``, else ``4*x``."""
   if x < 2:
     return lax.mul(2, x)
   if x < 5:
     return lax.mul(3, x)
   return lax.mul(4, x)
Пример #9
0
def multigammaln(a, d):
  """Log of the multivariate gamma function.

  Computes ``d * (d - 1) / 4 * log(pi) + sum_{i=0}^{d-1} gammaln(a - i / 2)``,
  reducing over a trailing axis appended to ``a``.

  Args:
    a: array-like argument; promoted to an inexact dtype.
    d: dimension; must be a concrete value convertible to ``int``.

  Returns:
    Array shaped like ``a``.
  """
  # Keep the validated Python int: it determines the length of the arange
  # below and must not be replaced by the promoted floating-point array.
  d_int = core.concrete_or_error(int, d, "d argument of multigammaln")
  a, d = _promote_args_inexact("multigammaln", a, d_int)

  constant = lax.mul(lax.mul(lax.mul(_constant_like(a, 0.25), d),
                             lax.sub(d, _constant_like(a, 1))),
                     lax.log(_constant_like(a, np.pi)))
  # Build the offsets directly in the promoted dtype so that lax.div sees
  # operands of one dtype.
  half_steps = lax.div(jnp.arange(d_int, dtype=d.dtype), _constant_like(a, 2))
  res = jnp.sum(gammaln(jnp.expand_dims(a, axis=-1) - half_steps), axis=-1)
  return res + constant
Пример #10
0
def logpdf(x, a, b, loc=0, scale=1):
    """Log-density of the beta distribution with shape parameters ``a``, ``b``.

    With ``y = (x - loc) / scale`` the value is
    ``-betaln(a, b) + (a - 1) * log(y) + (b - 1) * log1p(-y) - log(scale)``;
    it is ``-inf`` wherever ``x > loc + scale`` or ``x < loc``.
    """
    x, a, b, loc, scale = _promote_args_inexact("beta.logpdf", x, a, b, loc,
                                                scale)
    one = _constant_like(x, 1)
    shape_term = lax.neg(betaln(a, b))
    y = lax.div(lax.sub(x, loc), scale)
    a_part = lax.mul(lax.sub(a, one), lax.log(y))
    b_part = lax.mul(lax.sub(b, one), lax.log1p(lax.neg(y)))
    unnormalized = lax.add(shape_term, lax.add(a_part, b_part))
    log_probs = lax.sub(unnormalized, lax.log(scale))
    out_of_support = logical_or(lax.gt(x, lax.add(loc, scale)),
                                lax.lt(x, loc))
    return where(out_of_support, -inf, log_probs)
Пример #11
0
def logpdf(x, df, loc=0, scale=1):
  """Log-density of Student's t distribution with ``df`` degrees of freedom.

  With ``z = (x - loc) / scale`` the value is
  ``-(lgamma(df/2) + log(scale**2 * pi * df) / 2 - lgamma(df/2 + 1/2)
  + (df/2 + 1/2) * log1p(z**2 / df))``.
  """
  x, df, loc, scale = _promote_args_inexact("t.logpdf", x, df, loc, scale)
  two = _lax_const(x, 2)
  z = lax.div(lax.sub(x, loc), scale)
  half_df = lax.div(df, two)
  half_df_plus_half = lax.add(half_df, _lax_const(x, 0.5))

  # Normalisation: lgamma(df/2) + log(scale^2 * pi * df)/2 - lgamma((df+1)/2).
  scale_sq_pi = lax.mul(lax.mul(scale, scale), _lax_const(x, np.pi))
  half_log_area = lax.div(lax.log(lax.mul(scale_sq_pi, df)), two)
  log_norm = lax.sub(lax.add(lax.lgamma(half_df), half_log_area),
                     lax.lgamma(half_df_plus_half))

  z_sq_over_df = lax.div(lax.mul(z, z), df)
  tail = lax.mul(half_df_plus_half, lax.log1p(z_sq_over_df))
  return lax.neg(lax.add(log_norm, tail))
Пример #12
0
def multigammaln(a, d):
    """Log of the multivariate gamma function.

    Computes ``0.25 * d * (d - 1) * log(pi)`` plus the sum over
    ``i in arange(d)`` of ``gammaln(a - i / 2)``, reduced over a trailing
    axis appended to ``a``. ``d`` is converted to the dtype of the promoted
    ``a``.
    """
    a, = _promote_args_inexact("multigammaln", a)
    d = lax.convert_element_type(d, lax.dtype(a))

    quarter_d = lax.mul(_constant_like(a, 0.25), d)
    d_minus_one = lax.sub(d, _constant_like(a, 1))
    log_pi = lax.log(_constant_like(a, np.pi))
    constant = lax.mul(lax.mul(quarter_d, d_minus_one), log_pi)

    half_steps = lax.div(jnp.arange(d), _constant_like(a, 2))
    shifted = jnp.expand_dims(a, axis=-1) - half_steps
    res = jnp.sum(gammaln(shifted), axis=-1)
    return res + constant
Пример #13
0
def multigammaln(a, d):
    """Log of the multivariate gamma function.

    ``d`` must be a concrete value convertible to ``int``. The result is
    ``0.25 * d * (d - 1) * log(pi)`` plus the sum over ``i in range(d)`` of
    ``gammaln(a - i / 2)``, reduced over a trailing axis appended to ``a``.
    """
    d = core.concrete_or_error(int, d, "d argument of multigammaln")
    a, d_ = _promote_args_inexact("multigammaln", a, d)

    quarter_d = lax.mul(_lax_const(a, 0.25), d_)
    d_minus_one = lax.sub(d_, _lax_const(a, 1))
    constant = lax.mul(lax.mul(quarter_d, d_minus_one),
                       lax.log(_lax_const(a, np.pi)))

    # Offsets 0, 1/2, ..., (d - 1)/2, given leading singleton axes so they
    # broadcast against ``a`` with one trailing axis appended.
    half_steps = lax.div(jnp.arange(d, dtype=d_.dtype), _lax_const(a, 2))
    leading_axes = tuple(range(a.ndim))
    shifted = (jnp.expand_dims(a, axis=-1) -
               jnp.expand_dims(half_steps, axis=leading_axes))
    res = jnp.sum(gammaln(shifted), axis=-1)
    return res + constant
Пример #14
0
def logpdf(x, df, loc=0, scale=1):
    """Log-density of the chi-squared distribution with ``df`` degrees of freedom.

    With ``y = (x - loc) / scale`` the value is
    ``(df/2 - 1) * log(y) - y/2 - lgamma(df/2) - log(2) * df / 2 - log(scale)``;
    it is ``-inf`` wherever ``x < loc``.
    """
    x, df, loc, scale = _promote_args_inexact("chi2.logpdf", x, df, loc, scale)
    one = _constant_like(x, 1)
    two = _constant_like(x, 2)
    y = lax.div(lax.sub(x, loc), scale)
    half_df = lax.div(df, two)

    power_term = lax.mul(lax.sub(half_df, one), lax.log(y))
    kernel = lax.sub(power_term, lax.div(y, two))

    log_two_term = lax.div(lax.mul(lax.log(two), df), two)
    nrml_cnst = lax.neg(lax.add(lax.lgamma(half_df), log_two_term))

    log_probs = lax.add(lax.sub(nrml_cnst, lax.log(scale)), kernel)
    return where(lax.lt(x, loc), -inf, log_probs)
Пример #15
0
def _mul_sparse(spenv, *spvalues):
    """Multiply two operands, at least one of which is expected to be sparse.

    Args:
        spenv: environment used to read operand data and to build the
            sparse result value.
        *spvalues: exactly two operand values.

    Returns:
        A one-element tuple holding the sparse result value.
    """
    X, Y = spvalues

    if not (X.is_sparse() and Y.is_sparse()):
        # Mixed case: treat the sparse operand as the left one.
        sp_val, other = (Y, X) if Y.is_sparse() else (X, Y)
        sp_array = spvalues_to_arrays(spenv, sp_val)
        out_data = bcoo_multiply_dense(sp_array, spenv.data(other))
        # The result reuses the sparse operand's indices and index flags.
        return (spenv.sparse(sp_val.shape,
                             out_data,
                             indices_ref=sp_val.indices_ref,
                             indices_sorted=sp_val.indices_sorted,
                             unique_indices=sp_val.unique_indices), )

    if X.indices_ref == Y.indices_ref and X.unique_indices:
        # Shared, duplicate-free indices: multiply the data buffers directly.
        if config.jax_enable_checks:
            assert X.indices_sorted == Y.indices_sorted
            assert X.unique_indices == Y.unique_indices
        out_data = lax.mul(spenv.data(X), spenv.data(Y))
        return (spenv.sparse(X.shape,
                             out_data,
                             indices_ref=X.indices_ref,
                             indices_sorted=X.indices_sorted,
                             unique_indices=True), )

    # General sparse-sparse product on the materialized operands.
    X_promoted, Y_promoted = spvalues_to_arrays(spenv, spvalues)
    mat = bcoo_multiply_sparse(X_promoted, Y_promoted)
    return (spenv.sparse(mat.shape, mat.data, mat.indices), )
Пример #16
0
def logpdf(x, loc=0, scale=1):
    """Log-density of the normal distribution.

    Computes ``-(log(2 * pi * scale**2) + (x - loc)**2 / scale**2) / 2``.
    """
    x, loc, scale = _promote_args_inexact("norm.logpdf", x, loc, scale)
    two = _constant_like(x, 2)
    variance = lax.pow(scale, two)
    log_normalizer = lax.log(lax.mul(_constant_like(x, 2 * np.pi), variance))
    sq_deviation = lax.pow(lax.sub(x, loc), two)
    quadratic = lax.div(sq_deviation, variance)
    return lax.div(lax.neg(lax.add(log_normalizer, quadratic)), two)
Пример #17
0
def logsumexp(a, axis=None, b=None, keepdims=False, return_sign=False):
    """Compute ``log(sum(b * exp(a)))`` over ``axis``, shifted by the maximum.

    Args:
        a: input array.
        axis: axis or axes to reduce over.
        b: optional weights; entries of ``a`` where ``b == 0`` are replaced
            by ``-inf`` before the reduction.
        keepdims: keep the reduced axes in the result.
        return_sign: when true, return ``(out, sign)`` where ``out`` is the
            log of the absolute value of the sum.

    Returns:
        ``out`` or ``(out, sign)``. With weights and ``return_sign=False``,
        entries whose sum is negative are NaN.
    """
    weighted = b is not None
    if weighted:
        a, b = _promote_args_inexact("logsumexp", a, b)
        a = jnp.where(b != 0, a, -jnp.inf)

    pos_dims, dims = _reduction_dims(a, axis)

    # Shift by the maximum (0 where it is not finite); the shift is excluded
    # from differentiation.
    raw_max = jnp.max(a, axis=dims, keepdims=keepdims)
    finite_max = lax.select(lax.is_finite(raw_max), raw_max,
                            lax.full_like(raw_max, 0))
    amax = lax.stop_gradient(finite_max)
    amax_with_dims = amax if keepdims else lax.expand_dims(amax, pos_dims)
    shifted_exp = lax.exp(lax.sub(a, amax_with_dims))

    if weighted:
        sumexp = jnp.sum(lax.mul(shifted_exp, b),
                         axis=dims,
                         keepdims=keepdims)
        sign = lax.stop_gradient(lax.sign(sumexp))
        out = lax.add(lax.log(lax.abs(sumexp)), amax)
    else:
        total = jnp.sum(shifted_exp, axis=dims, keepdims=keepdims)
        out = lax.add(lax.log(total), amax)
        sign = jnp.where(jnp.isnan(out), np.nan, 1.0).astype(out.dtype)
        sign = jnp.where(out == -np.inf, 0.0, sign)

    if return_sign:
        return (out, sign)
    if weighted:
        out = jnp.where(sign < 0, np.nan, out)
    return out
Пример #18
0
def nanvar(a,
           axis: Optional[Union[int, Tuple[int, ...]]] = None,
           dtype=None,
           out=None,
           ddof=0,
           keepdims=False,
           where=None):
    """Variance along ``axis`` that ignores NaN entries.

    Args:
        a: input array.
        axis: axis or axes to reduce over.
        dtype: requested result dtype.
        out: unsupported; must be ``None``.
        ddof: value subtracted from the non-NaN count to form the divisor.
        keepdims: keep the reduced axes in the result.
        where: forwarded to the mean and sum reductions.

    Returns:
        The variance converted to the resolved ``dtype``; NaN wherever the
        non-NaN count minus ``ddof`` is not positive.

    Raises:
        NotImplementedError: if ``out`` is given.
    """
    _check_arraylike("nanvar", a)
    lax_internal._check_user_dtype_supported(dtype, "nanvar")
    if out is not None:
        raise NotImplementedError(
            "The 'out' argument to jnp.nanvar is not supported.")

    a_dtype, dtype = _var_promote_types(dtypes.dtype(a), dtype)
    a_mean = nanmean(a, axis, dtype=a_dtype, keepdims=True, where=where)

    nan_mask = lax_internal._isnan(a)
    # double-where trick for gradients.
    deviation = _where(nan_mask, 0, a - a_mean)
    if dtypes.issubdtype(deviation.dtype, np.complexfloating):
        squared = lax.real(lax.mul(deviation, lax.conj(deviation)))
    else:
        squared = lax.square(deviation)

    valid_count = sum(lax_internal.bitwise_not(nan_mask),
                      axis=axis,
                      keepdims=keepdims,
                      where=where)
    normalizer = valid_count - ddof
    too_few = lax.le(normalizer, 0)

    total = sum(squared, axis, keepdims=keepdims, where=where)
    total = _where(too_few, np.nan, total)
    # Divide by 1 in the masked slots; those results are already NaN.
    divisor = _where(too_few, 1, normalizer)
    variance = lax.div(total, lax.convert_element_type(divisor, total.dtype))
    return lax.convert_element_type(variance, dtype)
Пример #19
0
def var(a, axis=None, keepdims=False, ddof=0):
    """Variance of ``a`` along ``axis``; only ``ddof=0`` is supported.

    Raises:
        NotImplementedError: if ``ddof`` is not 0.
    """
    if ddof != 0:
        raise NotImplementedError("Only implemented for ddof=0.")
    deviation = subtract(a, mean(a, axis, keepdims=True))
    # For complex input the absolute value is squared.
    magnitude = lax.abs(deviation) if iscomplexobj(deviation) else deviation
    return mean(lax.mul(magnitude, magnitude), axis, keepdims=keepdims)
Пример #20
0
def xlog1py(x, y):
    """Compute ``x * log1p(y)``, returning 0 wherever ``x == 0``.

    Where ``x`` is zero both operands are replaced by 1 before the product is
    formed, and the result in those slots is overwritten with zero.
    """
    x, y = _promote_args_inexact("xlog1py", x, y)
    nonzero = x != 0.
    x_filled = jnp.where(nonzero, x, 1.)
    y_filled = jnp.where(nonzero, y, 1.)
    product = lax.mul(x_filled, lax.log1p(y_filled))
    return jnp.where(nonzero, product, jnp.zeros_like(x))
Пример #21
0
def _mul_sparse(spenv, *argspecs):
    """Multiply two operands, at least one of which is expected to be sparse.

    Args:
        spenv: environment used to look up operand buffers and to register
            (``push``) result buffers.
        *argspecs: exactly two operand specs.

    Returns:
        A one-element tuple holding the ``ArgSpec`` of the product.
    """
    X, Y = argspecs

    if not (X.is_sparse() and Y.is_sparse()):
        # Mixed case: treat the sparse operand as the left one.
        sp_arg, other = (Y, X) if Y.is_sparse() else (X, Y)
        out_data = bcoo_multiply_dense(sp_arg.data(spenv),
                                       sp_arg.indices(spenv),
                                       other.data(spenv),
                                       shape=sp_arg.shape)
        return (ArgSpec(sp_arg.shape, spenv.push(out_data),
                        sp_arg.indices_ref), )

    if X.indices_ref == Y.indices_ref:
        # TODO(jakevdp): this is inaccurate if there are duplicate indices
        out_data = lax.mul(X.data(spenv), Y.data(spenv))
        return (ArgSpec(X.shape, spenv.push(out_data), X.indices_ref), )

    # Different index buffers: general sparse-sparse product.
    data, indices, shape = bcoo_multiply_sparse(X.data(spenv),
                                                X.indices(spenv),
                                                Y.data(spenv),
                                                Y.indices(spenv),
                                                lhs_shape=X.shape,
                                                rhs_shape=Y.shape)
    return (ArgSpec(shape, spenv.push(data), spenv.push(indices)), )
Пример #22
0
def logsumexp(a, axis=None, b=None, keepdims=False, return_sign=False):
    """Compute ``log(sum(b * exp(a)))`` over ``axis``, shifted by the maximum.

    Args:
        a: input array.
        axis: axis or axes to reduce over.
        b: optional weights, broadcast against ``a``.
        keepdims: re-insert the reduced axes in the result.
        return_sign: when true, return ``(out, sign)`` where ``out`` is the
            log of the absolute value of the sum.

    Returns:
        ``out`` or ``(out, sign)``. With weights and ``return_sign=False``,
        entries whose sum is negative are NaN.
    """
    weighted = b is not None
    if weighted:
        a, b = jnp.broadcast_arrays(a, b)
    dims = _reduction_dims(a, axis)

    def restore_dims(arr):
        # Re-insert the reduced axes as singleton dimensions.
        return lax.expand_dims(arr, dims)

    # Shift by the maximum (0 where it is not finite); the shift is excluded
    # from differentiation.
    raw_max = lax.reduce(a, _constant_like(a, -np.inf), lax.max, dims)
    finite_max = lax.select(lax.is_finite(raw_max), raw_max,
                            lax.full_like(raw_max, 0))
    amax = lax.stop_gradient(finite_max)
    shifted_exp = lax.exp(lax.sub(a, restore_dims(amax)))

    if weighted:
        sumexp = lax.reduce(lax.mul(shifted_exp, b),
                            _constant_like(a, 0), lax.add, dims)
        sign = lax.stop_gradient(lax.sign(sumexp))
        out = lax.add(lax.log(lax.abs(sumexp)), amax)
    else:
        total = lax.reduce(shifted_exp, _constant_like(a, 0), lax.add, dims)
        out = lax.add(lax.log(total), amax)
        sign = jnp.where(jnp.isnan(out), np.nan, 1.0).astype(out.dtype)
        sign = jnp.where(out == -np.inf, 0.0, sign)

    if return_sign:
        if keepdims:
            return (restore_dims(out), restore_dims(sign))
        return (out, sign)
    if weighted:
        out = jnp.where(sign < 0, np.nan, out)
    return restore_dims(out) if keepdims else out
Пример #23
0
def exponential(tensor, dtype, exp_hparams):
    """Calculates an exponential approximation based on exp hyper params.

    Args:
        tensor: input array.
        dtype: dtype used for the constants created here.
        exp_hparams: object exposing ``low_bound``, ``clip_and_subtract`` and
            ``linear_gradient``.

    Returns:
        ``exp(tensor)``, or ``clip(linear_gradient * tensor + 1, 0, 1)`` when
        ``linear_gradient`` is set and non-zero; ``tensor`` is first clipped
        and/or shifted as selected by the hyper params.
    """
    low_bound = exp_hparams.low_bound
    # A non-zero low_bound clips the input to [low_bound, 0].
    if low_bound != 0:
        tensor = jnp.clip(tensor, low_bound, 0.)

    # TODO(luispazos) Use standard calls to top level jnp functions.
    if exp_hparams.clip_and_subtract:
        # NOTE(review): exp(low_bound) is subtracted from the input *before*
        # exponentiation, not from the exponential itself - confirm intent.
        # pylint: disable=protected-access
        offset = lax_numpy._constant_like(tensor,
                                          onp.exp(low_bound)).astype(dtype)
        tensor = lax.sub(tensor, offset)

    slope = exp_hparams.linear_gradient
    if slope is None or slope == 0:
        return lax.exp(tensor)

    # Linear stand-in for exp: the line a*x + 1 passes through (0, 1) and is
    # clipped to the range [0, 1].
    one = jnp.full(tensor.shape, 1.).astype(dtype)
    gradient = jnp.full(tensor.shape, slope).astype(dtype)
    return jnp.clip(lax.add(lax.mul(tensor, gradient), one), 0, 1)
Пример #24
0
def sinc(x):
    """Compute ``sin(pi * x) / (pi * x)``, with a dedicated value at ``x == 0``.

    Where ``x`` is zero the denominator is replaced by 1 to keep the division
    well defined, and the result comes from ``_sinc_maclaurin(0, pi * x)``.
    """
    _check_arraylike("sinc", x)
    x, = _promote_dtypes_inexact(x)
    at_origin = lax.eq(x, _lax_const(x, 0))
    pi_x = lax.mul(_lax_const(x, np.pi), x)
    denominator = _where(at_origin, _lax_const(x, 1), pi_x)
    origin_value = _sinc_maclaurin(0, pi_x)
    ratio = lax.div(lax.sin(denominator), denominator)
    return _where(at_origin, origin_value, ratio)
Пример #25
0
def logpdf(x, a, loc=0, scale=1):
  """Log-density of the gamma distribution with shape parameter ``a``.

  With ``y = (x - loc) / scale`` the value is
  ``(a - 1) * log(y) - y - (gammaln(a) + log(scale))``; it is ``-inf``
  wherever ``x < loc``.
  """
  x, a, loc, scale = _promote_args_inexact("gamma.logpdf", x, a, loc, scale)
  y = lax.div(lax.sub(x, loc), scale)
  a_minus_one = lax.sub(a, _constant_like(x, 1))
  log_kernel = lax.sub(lax.mul(a_minus_one, lax.log(y)), y)
  log_norm = lax.add(gammaln(a), lax.log(scale))
  log_probs = lax.sub(log_kernel, log_norm)
  return where(lax.lt(x, loc), -inf, log_probs)
Пример #26
0
def logpdf(x, b, loc=0, scale=1):
    """Log-density of the Pareto distribution with shape parameter ``b``.

    With ``z = (x - loc) / scale`` the value is
    ``-(log(scale / b) + (b + 1) * log(z))``; it is ``-inf`` wherever
    ``x < loc + scale``.
    """
    x, b, loc, scale = _promote_args_inexact("pareto.logpdf", x, b, loc, scale)
    z = lax.div(lax.sub(x, loc), scale)
    log_norm = lax.log(lax.div(scale, b))
    b_plus_one = lax.add(b, _constant_like(x, 1))
    power_term = lax.mul(b_plus_one, lax.log(z))
    log_probs = lax.neg(lax.add(log_norm, power_term))
    below_support = lax.lt(x, lax.add(loc, scale))
    return where(below_support, -inf, log_probs)
Пример #27
0
def round(a, decimals=0):
    """Round ``a`` to ``decimals`` decimal places.

    Integer-typed input is returned unchanged. Otherwise the value is scaled
    by ``10**decimals``, rounded with ``lax.round`` and scaled back; with
    ``decimals == 0`` ``lax.round`` is applied directly.
    """
    if onp.issubdtype(_dtype(a), onp.integer):
        # no-op on integer types
        return a
    if decimals == 0:
        return lax.round(a)

    scale = _constant_like(a, 10**decimals)
    scaled = lax.mul(a, scale)
    return lax.div(lax.round(scaled), scale)
Пример #28
0
def threefry_random_bits(key: jnp.ndarray, bit_width, shape):
    """Sample uniform random bits of given width and shape using PRNG key.

    Args:
        key: a threefry PRNG key (validated by ``_is_threefry_prng_key``).
        bit_width: width of each sample; one of 8, 16, 32 or 64.
        shape: requested output shape; converted with
            ``core.as_named_shape``, so it may carry named axes.

    Returns:
        An array of unsigned integers of the requested width, reshaped to
        ``shape``.

    Raises:
        TypeError: if ``key`` is not a threefry key or ``bit_width`` is not
            one of the supported widths.
        ValueError: if a named axis has a size different from the one
            declared in ``shape``.
    """
    if not _is_threefry_prng_key(key):
        raise TypeError("_random_bits got invalid prng key.")
    if bit_width not in (8, 16, 32, 64):
        raise TypeError("requires 8-, 16-, 32- or 64-bit field width.")
    shape = core.as_named_shape(shape)
    # For every named axis: check the declared size against the size obtained
    # from psum(1, name), then fold the axis index into the key.
    for name, size in shape.named_items:
        real_size = lax.psum(1, name)
        if real_size != size:
            raise ValueError(
                f"The shape of axis {name} was specified as {size}, "
                f"but it really is {real_size}")
        axis_index = lax.axis_index(name)
        key = threefry_fold_in(key, axis_index)
    # Number of samples in the positional part of the shape.
    size = prod(shape.positional)
    # Compute ceil(bit_width * size / 32) in a way that is friendly to shape
    # polymorphism
    max_count, r = divmod(bit_width * size, 32)
    if r > 0:
        max_count += 1

    # max_count is the number of 32-bit words to generate. When it is a
    # constant it is split into full blocks of uint32-max words plus a
    # remainder; otherwise everything is treated as the remainder.
    if core.is_constant_dim(max_count):
        nblocks, rem = divmod(max_count, jnp.iinfo(np.uint32).max)
    else:
        nblocks, rem = 0, max_count

    if not nblocks:
        # Everything fits in a single counter range.
        bits = threefry_2x32(key, lax.iota(np.uint32, rem))
    else:
        # One subkey per full block, plus one more key for the remainder.
        keys = threefry_split(key, nblocks + 1)
        subkeys, last_key = keys[:-1], keys[-1]
        blocks = vmap(threefry_2x32,
                      in_axes=(0, None))(subkeys,
                                         lax.iota(np.uint32,
                                                  jnp.iinfo(np.uint32).max))
        last = threefry_2x32(last_key, lax.iota(np.uint32, rem))
        bits = lax.concatenate([blocks.ravel(), last], 0)

    # For bit_width == 32 neither branch below runs and the words are used
    # as generated.
    dtype = UINT_DTYPES[bit_width]
    if bit_width == 64:
        # Split the words into two halves and combine them pairwise: the
        # first half supplies the upper 32 bits, the second half the lower.
        bits = [lax.convert_element_type(x, dtype) for x in jnp.split(bits, 2)]
        bits = lax.shift_left(bits[0], dtype(32)) | bits[1]
    elif bit_width in [8, 16]:
        # this is essentially bits.view(dtype)[:size]
        # Each word is shifted right by 0, bit_width, 2 * bit_width, ... and
        # masked to bit_width bits, giving 32 // bit_width fields per word.
        bits = lax.bitwise_and(
            np.uint32(np.iinfo(dtype).max),
            lax.shift_right_logical(
                lax.broadcast(bits, (1, )),
                lax.mul(
                    np.uint32(bit_width),
                    lax.broadcasted_iota(np.uint32, (32 // bit_width, 1), 0))))
        # Flatten with dimension order (1, 0), then drop the surplus fields
        # beyond the requested number of samples.
        bits = lax.reshape(bits, (np.uint32(max_count * 32 // bit_width), ),
                           (1, 0))
        bits = lax.convert_element_type(bits, dtype)[:size]
    return lax.reshape(bits, shape)
Пример #29
0
def _power(x1, x2):
  """Raise ``x1`` to the power ``x2`` elementwise.

  Non-integer dtypes are delegated to ``lax.pow``. Integer dtypes use binary
  exponentiation over the six low bits of the exponent.
  """
  x1, x2 = _promote_args("power", x1, x2)
  if not dtypes.issubdtype(dtypes.dtype(x1), np.integer):
    return lax.pow(x1, x2)

  # Integer power => use binary exponentiation.

  # TODO(phawkins): add integer pow support to XLA.
  num_bits = 6  # Anything more would overflow for any x1 > 1
  zero = _constant_like(x2, 0)
  one = _constant_like(x2, 1)

  # Initialize the accumulator carefully such that pow(0, x2) is zero for
  # x2 != 0.
  zero_base_nonzero_exp = lax.bitwise_and(lax.eq(x1, zero), lax.ne(x2, zero))
  result = _where(zero_base_nonzero_exp, zero, one)

  base, exponent = x1, x2
  for _ in range(num_bits):
    # Multiply in the current power of the base when the low bit is set, then
    # square the base and move on to the next exponent bit.
    low_bit = lax.bitwise_and(exponent, one)
    result = _where(low_bit, lax.mul(result, base), result)
    base = lax.mul(base, base)
    exponent = lax.shift_right_logical(exponent, one)
  return result
Пример #30
0
    def cfun(x):
      """Piecewise multiple of ``x`` built from nested ``lax.cond`` calls.

      Returns ``2 * x`` when ``x < 2``, ``3 * x`` when ``2 <= x < 5`` and
      ``4 * x`` otherwise.
      """
      def double(v):
        return lax.mul(2, v)

      def inner_cond(x):
        def triple(v):
          return lax.mul(3, v)

        def times_x(y):
          # ``y`` is the false-branch operand (the constant 4).
          return lax.mul(y, x)

        return lax.cond(lax.lt(x, 5), x, triple, 4, times_x)

      below_two = lax.lt(x, 2)
      return lax.cond(below_two, x, double, x, inner_cond)