Exemplo n.º 1
0
def _ibp_integer_pow(x: PrimitiveInput, y: int) -> IntervalBound:
    """Propagation of IBP bounds through integer_pow.

    Args:
      x: Bound on the argument that is raised element-wise to the power ``y``;
        its ``lower`` and ``upper`` attributes are read.
      y: Fixed, non-negative integer exponent.

    Returns:
      out_bounds: IntervalBound enclosing ``x ** y`` element-wise.

    Raises:
      NotImplementedError: If ``y`` is negative.
    """
    if y < 0:
        # Negative powers are unbounded near zero; not supported by this rule.
        raise NotImplementedError(
            f'IBP through integer_pow is not implemented for negative '
            f'exponents (got y={y}).')
    l_pow = lax.integer_pow(x.lower, y)
    u_pow = lax.integer_pow(x.upper, y)

    if y == 0 or y % 2 == 1:
        # y == 0: x ** 0 is identically 1, so [1, 1] is exact. (The even-power
        # branch below would loosen it to [0, 1] whenever the interval contains
        # zero.)
        # Odd y: x ** y is monotonically increasing, so the endpoints map
        # directly onto the output endpoints.
        return IntervalBound(l_pow, u_pow)

    # Even y >= 2: x ** y decreases on (-inf, 0] and increases on [0, inf).
    contains_zero = jnp.logical_and(jnp.less_equal(x.lower, 0),
                                    jnp.greater_equal(x.upper, 0))
    # The minimum is 0 when the interval straddles zero; otherwise it is
    # attained at the endpoint closest to zero.
    lower = jnp.where(contains_zero, jnp.zeros_like(x.lower),
                      jnp.minimum(l_pow, u_pow))
    # The maximum is always attained at one of the two endpoints.
    upper = jnp.maximum(l_pow, u_pow)
    return IntervalBound(lower, upper)
Exemplo n.º 2
0
def power(x1, x2):
    """Raise ``x1`` to the power ``x2``.

    When ``x2`` is a concrete (trace-time known) value that ``operator.index``
    accepts as an integer, the result is computed with ``lax.integer_pow``
    (binary exponentiation). ``lax.pow`` may be imprecise for floating-point
    values, and this path keeps the common ``x ** 2``-style pattern precise.
    Every other exponent is delegated to ``_power`` (defined elsewhere).
    """
    int_exponent = None
    if isinstance(core.get_aval(x2), core.ConcreteArray):
        try:
            int_exponent = operator.index(x2)
        except TypeError:
            pass  # Concrete but not integral: use the general path below.
    if int_exponent is None:
        return _power(x1, x2)
    return lax.integer_pow(x1, int_exponent)
Exemplo n.º 3
0
class JAX2DexTest(unittest.TestCase):
    """Tests that run individual lax primitives through the Dex backend.

    Most test methods are produced by the ``lax_test`` helper (defined
    elsewhere), which receives the function under test and a zero-argument
    callable that builds its arguments.
    NOTE(review): presumably ``lax_test`` runs the function under ``dexjit``
    and compares against JAX, and ``rn`` builds a random array of the given
    shape -- confirm against their definitions.
    """
    # Element-wise unary primitives.
    test_sin = lax_test(lax.sin, lambda: (rn(10, 10), ))
    test_cos = lax_test(lax.cos, lambda: (rn(10, 10), ))
    test_neg = lax_test(lax.neg, lambda: (rn(10, 10), ))
    test_log = lax_test(lax.log, lambda: (rn(10, 10), ))
    test_exp = lax_test(lax.exp, lambda: (rn(10, 10), ))
    # Powers: general pow with float32 exponents 0..9, and integer_pow with a
    # fixed exponent of 2.
    test_pow = lax_test(lax.pow, lambda:
                        (rn(10), jnp.arange(10, dtype=np.float32)))
    test_integer_pow = lax_test(lambda x: lax.integer_pow(x, 2), lambda:
                                (rn(10, 10), ))
    # select driven by a scalar predicate (i < 2.0) choosing between two arrays.
    test_scalar_select_lt = lax_test(lambda i, x, y: lax.select(i < 2.0, x, y),
                                     lambda: (1.0, rn(10), rn(10)))

    # squeeze removing no, one, two, and all size-1 dimensions.
    test_squeeze_none = lax_test(lambda x: lax.squeeze(x, []), lambda:
                                 (rn(10, 10), ))
    test_squeeze_one = lax_test(lambda x: lax.squeeze(x, [1]), lambda:
                                (rn(10, 1, 10), ))
    test_squeeze_two = lax_test(lambda x: lax.squeeze(x, [0, 2]), lambda:
                                (rn(1, 10, 1), ))
    test_squeeze_all = lax_test(lambda x: lax.squeeze(x, [0, 1]), lambda:
                                (rn(1, 1), ))

    # Static slicing of 1-D and 3-D operands (strides=None).
    test_slice_1d = lax_test(lambda x: lax.slice(x, [2], [5], None), lambda:
                             (rn(10), ))
    test_slice_3d = lax_test(
        lambda x: lax.slice(x, [2, 0, 0], [5, 10, 2], None), lambda:
        (rn(10, 10, 2), ))

    # Concatenation along axis 0 of equally and unequally sized operands.
    test_concat_uniform = lax_test(partial(lax.concatenate, dimension=0),
                                   lambda: ([rn(4, 2) for _ in range(3)], ))
    test_concat_ragged = lax_test(
        partial(lax.concatenate, dimension=0), lambda:
        ([rn(1, 2, 4), rn(5, 2, 4), rn(2, 2, 4)], ))

    # dot_general contracting dim 1 of the lhs with dim 0 of the rhs, with no
    # batch dims: a matrix-matrix and a matrix-vector product.
    test_dot_general_matmul = lax_test(
        partial(lax.dot_general, dimension_numbers=(((1, ), (0, )), ((), ()))),
        lambda: (rn(4, 8), rn(8, 16)))
    test_dot_general_matvec = lax_test(
        partial(lax.dot_general, dimension_numbers=(((1, ), (0, )), ((), ()))),
        lambda: (rn(4, 8), rn(8)))

    def test_canonicalize_dtype(self):
        """dexjit and jax.jit agree in value and dtype on float64 inputs."""
        # A float64 constant captured by the closure, multiplied with a float64
        # input. NOTE(review): presumably this checks that both paths apply the
        # same float64 dtype canonicalization -- confirm.
        c = np.arange(5, dtype=np.float64)
        f = lambda x: x * c
        x = np.ones(5, dtype=np.float64)
        dy = dexjit(f)(x)
        jy = jax.jit(f)(x)
        np.testing.assert_allclose(dy, jy)
        self.assertEqual(dy.dtype, jy.dtype)
Exemplo n.º 4
0
Arquivo: jet.py Projeto: 0x0is1/jax
def _integer_pow_taylor(primals_in, series_in, *, y):
    if y == 0:
        return jet(jnp.ones_like, primals_in, series_in)
    elif y == 1:
        return jet(lambda x: x, primals_in, series_in)
    elif y == 2:
        return jet(lambda x: x * x, primals_in, series_in)
    x, = primals_in
    series, = series_in
    u = [x] + series
    v = [lax.integer_pow(x, y)] + [None] * len(series)
    for k in range(1, len(v)):
        vu = sum(_scale(k, j) * v[k - j] * u[j] for j in range(1, k + 1))
        uv = sum(_scale(k, j) * u[k - j] * v[j] for j in range(1, k))
        v[k] = jnp.where(x == 0, 0, fact(k - 1) * (y * vu - uv) / x)
    primal_out, *series_out = v

    return primal_out, series_out
Exemplo n.º 5
0
  def integer_pow(x):
    """Cube ``x`` element-wise with ``lax.integer_pow``."""
    return lax.integer_pow(x, y=3)

  # Run lax.is_finite on a float64 scalar through print_ir (defined elsewhere;
  # presumably it lowers the function and prints the IR). The FileCheck
  # directives that follow expect an mhlo.is_finite op on tensor<f64>.
  # CHECK-LABEL: TEST: is_finite float64[]
  # CHECK: mhlo.is_finite
  # CHECK-SAME: tensor<f64>
  print_ir(np.float64(0))(lax.is_finite)
Exemplo n.º 6
0
def reciprocal(x):
    """Element-wise reciprocal ``1 / x``, computed as ``lax.integer_pow(x, -1)``.

    The input first passes through ``_promote_dtypes_inexact`` (presumably
    promoting integer inputs to a floating dtype, per the helper's name).
    """
    _check_arraylike("reciprocal", x)
    (promoted,) = _promote_dtypes_inexact(x)
    return lax.integer_pow(promoted, y=-1)
Exemplo n.º 7
0
def square(x):
    """Element-wise square of ``x``, computed as ``lax.integer_pow(x, 2)``."""
    _check_arraylike("square", x)
    return lax.integer_pow(x, y=2)
Exemplo n.º 8
0
 def integer_pow(x):
     """Cube ``x`` element-wise with ``lax.integer_pow``."""
     return lax.integer_pow(x, y=3)
Exemplo n.º 9
0
 def f(x):
     """Square ``x`` element-wise with ``lax.integer_pow``."""
     return lax.integer_pow(x, y=2)