Example #1
    def test_spectral_dac_svd(self, linear_size, seed, dtype):
        if jnp.dtype(dtype).name in ("bfloat16", "float16"):
            if jtu.device_under_test() != "cpu":
                raise unittest.SkipTest("Skip half precision off CPU.")

        rng = self.rng()
        A = rng.randn(linear_size, linear_size).astype(dtype)
        if jnp.dtype(dtype).name in ("bfloat16", "float16"):
            self.assertRaises(NotImplementedError, jax._src.scipy.eigh.svd, A)
            return
        S_expected = np.linalg.svd(A, compute_uv=False)
        U, S, V = jax._src.scipy.eigh.svd(A)
        recon = jnp.dot((U * jnp.expand_dims(S, 0)),
                        V,
                        precision=lax.Precision.HIGHEST)
        eps = jnp.finfo(dtype).eps
        eps = eps * jnp.linalg.norm(A) * 15
        self.assertAllClose(np.sort(S), np.sort(S_expected), atol=eps)
        self.assertAllClose(A, recon, atol=eps)

        # U is unitary.
        u_unitary_delta = jnp.dot(U.conj().T,
                                  U,
                                  precision=lax.Precision.HIGHEST)
        u_eye = jnp.eye(u_unitary_delta.shape[0], dtype=dtype)
        self.assertAllClose(u_unitary_delta, u_eye, atol=eps)

        # V is unitary.
        v_unitary_delta = jnp.dot(V.conj().T,
                                  V,
                                  precision=lax.Precision.HIGHEST)
        v_eye = jnp.eye(v_unitary_delta.shape[0], dtype=dtype)
        self.assertAllClose(v_unitary_delta, v_eye, atol=eps)
Example #2
def _contact_points_translation_rule(c, *args):
    shapes = list(map(c.get_shape, args))
    if any(shape.element_type() != np.float64 for shape in shapes):
        raise ValueError("float64 precision is required")

    dims = tuple(shapes[0].dimensions())
    if any(dims != s.dimensions() for s in shapes):
        raise ValueError("Dimension mismatch")
    N = np.prod(dims).astype(np.int32)

    order = tuple(range(len(dims) - 1, -1, -1))
    shape = xla_client.Shape.array_shape(jnp.dtype(np.float64), dims, order)

    return xops.CustomCallWithLayout(
        c,
        b"contact_points",
        operands=(xops.ConstantLiteral(c, N), ) + args,
        shape_with_layout=xla_client.Shape.tuple_shape((
            shape,
            shape,
            xla_client.Shape.array_shape(jnp.dtype(np.int32), dims, order),
        )),
        operand_shapes_with_layout=(xla_client.Shape.array_shape(
            jnp.dtype(jnp.int32), (), ()), ) + tuple(shape for _ in args),
    )
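A translation rule like this is only half the story; it has to be registered against a custom primitive. Below is a minimal sketch, not from the original source: `contact_points_p` is a hypothetical primitive, and `xla.backend_specific_translations` is the older, since-removed JAX registry that this style of rule targeted.

# Sketch only: wiring the rule to a hypothetical primitive via the legacy registry.
from jax import core
from jax.interpreters import xla

contact_points_p = core.Primitive("contact_points")
contact_points_p.multiple_results = True  # the custom call returns a 3-tuple
xla.backend_specific_translations["cpu"][contact_points_p] = _contact_points_translation_rule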
Example #3
  def test_staging_nested_including_shape_arg(self):
    # This test covers the _get_tracers_only_in_shapes logic in partial_eval.py.
    n = core.DShapedArray((), jnp.dtype('int32'), weak_type=False)
    a = core.DShapedArray((n,), jnp.dtype('float32'), weak_type=False)
    b = core.DShapedArray((n,), jnp.dtype('float32'), weak_type=False)

    @lu.wrap_init
    def f(x, y):
      @jax.jit
      def g(_, x, y, z, w):
        return (x, w)
      return g(x.shape[0], x, y, x, y)

    jaxpr, _, _ = pe.trace_to_jaxpr_dynamic(f, [n, a, b],
                                            keep_inputs=[False, True, True])

    self.assertLen(jaxpr.eqns, 1)
    eqn = jaxpr.eqns[0]
    self.assertIsInstance(eqn.primitive, core.CallPrimitive)
    inner_jaxpr = eqn.params['call_jaxpr']
    self.assertIsInstance(inner_jaxpr, core.Jaxpr)

    self.assertLen(inner_jaxpr.invars, 1 + 4)  # one axis size var
    self.assertEqual((inner_jaxpr.invars[0],), inner_jaxpr.invars[1].aval.shape)
    self.assertEqual((inner_jaxpr.invars[0],), inner_jaxpr.invars[2].aval.shape)
    self.assertEqual((inner_jaxpr.invars[0],), inner_jaxpr.invars[3].aval.shape)
    self.assertEqual((inner_jaxpr.invars[0],), inner_jaxpr.invars[4].aval.shape)
Example #4
    def test_staging_primitive_applications(self):
        n = core.DShapedArray((), jnp.dtype('int32'), weak_type=False)
        a = core.DShapedArray((pe.DBIdx(0), ),
                              jnp.dtype('float32'),
                              weak_type=False)
        b = core.DShapedArray((pe.DBIdx(0), ),
                              jnp.dtype('float32'),
                              weak_type=False)

        @lu.wrap_init
        def f(x, y):
            z = lax.mul(x, y)
            w = lax.sin(z)
            u = lax_internal._reduce_sum(w, [0])
            return (u, )

        jaxpr, _, _ = pe.trace_to_jaxpr_dynamic(
            f, [n, a, b], keep_inputs=[False, True, True])

        self.assertLen(jaxpr.invars,
                       1 + 2)  # one axis size var, two other inputs
        self.assertLen(jaxpr.eqns, 3)
        self.assertLen(jaxpr.eqns[0].outvars, 1)
        self.assertEqual(jaxpr.eqns[0].outvars[0].aval.shape,
                         jaxpr.invars[1].aval.shape)

        self.assertLen(jaxpr.outvars, 1)
        self.assertEqual(jaxpr.outvars[0].aval.shape, ())
Example #5
def _translation_rule(name, spec, c, *args):
    shapes = tuple(c.get_shape(arg) for arg in args)
    dims = OrderedDict(
        (s["name"], shapes[s["coords"][0]].dimensions()[s["coords"][1]])
        for s in spec["dimensions"])
    if any(shape.element_type() != np.float64 for shape in shapes):
        raise ValueError(f"{spec['name']} requires float64 precision")

    return xops.CustomCallWithLayout(
        c,
        name,
        operands=tuple(
            xops.ConstantLiteral(c, np.int32(v))
            for v in dims.values()) + args,
        shape_with_layout=xla_client.Shape.tuple_shape(
            tuple(
                xla_client.Shape.array_shape(
                    jnp.dtype(np.float64),
                    tuple(dims[k] for k in s["shape"]),
                    tuple(range(len(s["shape"]) - 1, -1, -1)),
                ) for s in spec["outputs"] + spec["extra_outputs"])),
        operand_shapes_with_layout=tuple(
            xla_client.Shape.array_shape(jnp.dtype(jnp.int32), (), ())
            for _ in range(len(dims))) + tuple(
                xla_client.Shape.array_shape(
                    jnp.dtype(np.float64),
                    tuple(dims[k] for k in s["shape"]),
                    tuple(range(len(s["shape"]) - 1, -1, -1)),
                ) for s in spec["inputs"]),
    )
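Everything above is driven by the `spec` dictionary. The illustrative spec below is an assumption inferred from the code, not taken from the original project: one dimension "N" read from axis 0 of operand 0, with float64 inputs and outputs of shape (N,).

# Hypothetical spec consistent with _translation_rule's expectations.
example_spec = {
    "name": "my_op",
    "dimensions": [{"name": "N", "coords": (0, 0)}],
    "inputs": [{"shape": ("N",)}],
    "outputs": [{"shape": ("N",)}],
    "extra_outputs": [],
}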
Example #6
    def testArrayCasts(self):
        for t in [
                jnp.bool_, jnp.int32, jnp.bfloat16, jnp.float32, jnp.complex64
        ]:
            a = np.array([1, 2.5, -3.7])
            self.assertEqual(a.astype(t).dtype, jnp.dtype(t))
            self.assertEqual(jnp.array(a).astype(t).dtype, jnp.dtype(t))
Example #7
def _translation_rule(spec, c, *args):
    vals = spec["get_dims"](*(c.get_shape(a).dimensions() for a in args))
    dtype = c.get_shape(args[0]).element_type()
    if dtype != np.float64:
        raise ValueError("Invalid dtype; must be float64")

    return xops.CustomCallWithLayout(
        c,
        spec["xla_name"],
        operands=tuple(
            xops.ConstantLiteral(c, np.int32(v))
            for v in vals.values()) + args,
        shape_with_layout=xla_client.Shape.tuple_shape(
            tuple(
                xla_client.Shape.array_shape(
                    jnp.dtype(dtype),
                    shape,
                    tuple(range(len(shape) - 1, -1, -1)),
                ) for dtype, shape in ((
                    s.get("dtype", np.float64),
                    eval(s["shape"], dict(vals)),
                ) for s in spec["outputs"] + spec["extra_outputs"]))),
        operand_shapes_with_layout=tuple(
            xla_client.Shape.array_shape(jnp.dtype(jnp.int32), (), ())
            for _ in range(len(vals))) + tuple(
                xla_client.Shape.array_shape(
                    jnp.dtype(dtype),
                    shape,
                    tuple(range(len(shape) - 1, -1, -1)),
                ) for dtype, shape in (
                    (s.get("dtype", np.float64), eval(s["shape"], dict(vals)))
                    for s in spec["inputs"])),
    )
Example #8
    def check_shapes(c_in, c_out):
        if not isinstance(c_in, jnp.ndarray) or not isinstance(c_out, jnp.ndarray):
            return
        if jnp.shape(c_in) != jnp.shape(c_out) or jnp.dtype(c_in) != jnp.dtype(c_out):
            raise ValueError("c_in and c_out must have the same shape and dtype")
Example #9
  def test_staging_nested(self):
    n = core.DShapedArray((), jnp.dtype('int32'), weak_type=False)
    a = core.DShapedArray((n,), jnp.dtype('float32'), weak_type=False)
    b = core.DShapedArray((n,), jnp.dtype('float32'), weak_type=False)

    @lu.wrap_init
    def f(x, y):
      @jax.jit
      def g(x, y, z, w):
        return (x, w)
      return g(x, y, x, y)

    jaxpr, _, _ = pe.trace_to_jaxpr_dynamic(f, [n, a, b],
                                            keep_inputs=[False, True, True])

    self.assertLen(jaxpr.invars, 1 + 2)  # one axis size var, two other inputs
    self.assertEqual((jaxpr.invars[0],), jaxpr.invars[1].aval.shape)
    self.assertEqual((jaxpr.invars[0],), jaxpr.invars[2].aval.shape)

    self.assertLen(jaxpr.outvars, 2)
    self.assertEqual((jaxpr.invars[0],), jaxpr.outvars[0].aval.shape)
    self.assertEqual((jaxpr.invars[0],), jaxpr.outvars[1].aval.shape)

    self.assertLen(jaxpr.eqns, 1)
    eqn = jaxpr.eqns[0]
    self.assertIsInstance(eqn.primitive, core.CallPrimitive)
    inner_jaxpr = eqn.params['call_jaxpr']
    self.assertIsInstance(inner_jaxpr, core.Jaxpr)

    self.assertLen(inner_jaxpr.invars, 1 + 4)  # one axis size var
    self.assertEqual((inner_jaxpr.invars[0],), inner_jaxpr.invars[1].aval.shape)
    self.assertEqual((inner_jaxpr.invars[0],), inner_jaxpr.invars[2].aval.shape)
    self.assertEqual((inner_jaxpr.invars[0],), inner_jaxpr.invars[3].aval.shape)
    self.assertEqual((inner_jaxpr.invars[0],), inner_jaxpr.invars[4].aval.shape)
Example #10
def _von_mises_centered(key, concentration, shape, dtype):
    # Cutoff from TensorFlow probability
    # (https://github.com/tensorflow/probability/blob/f051e03dd3cc847d31061803c2b31c564562a993/tensorflow_probability/python/distributions/von_mises.py#L567-L570)
    s_cutoff_map = {
        jnp.dtype(jnp.float16): 1.8e-1,
        jnp.dtype(jnp.float32): 2e-2,
        jnp.dtype(jnp.float64): 1.2e-4,
    }
    s_cutoff = s_cutoff_map.get(dtype)

    r = 1.0 + jnp.sqrt(1.0 + 4.0 * concentration**2)
    rho = (r - jnp.sqrt(2.0 * r)) / (2.0 * concentration)
    s_exact = (1.0 + rho**2) / (2.0 * rho)

    s_approximate = 1.0 / concentration

    s = jnp.where(concentration > s_cutoff, s_exact, s_approximate)

    def cond_fn(*args):
        """ check if all are done or reached max number of iterations """
        i, _, done, _, _ = args[0]
        return jnp.bitwise_and(i < 100, jnp.logical_not(jnp.all(done)))

    def body_fn(*args):
        i, key, done, _, w = args[0]
        uni_ukey, uni_vkey, key = random.split(key, 3)

        u = random.uniform(
            key=uni_ukey,
            shape=shape,
            dtype=concentration.dtype,
            minval=-1.0,
            maxval=1.0,
        )
        z = jnp.cos(jnp.pi * u)
        w = jnp.where(done, w,
                      (1.0 + s * z) / (s + z))  # Update where not done

        y = concentration * (s - w)
        v = random.uniform(key=uni_vkey,
                           shape=shape,
                           dtype=concentration.dtype)

        accept = (y * (2.0 - y) >= v) | (jnp.log(y / v) + 1.0 >= y)

        return i + 1, key, accept | done, u, w

    init_done = jnp.zeros(shape, dtype=bool)
    init_u = jnp.zeros(shape)
    init_w = jnp.zeros(shape)

    _, _, done, u, w = lax.while_loop(
        cond_fun=cond_fn,
        body_fun=body_fn,
        init_val=(jnp.array(0), key, init_done, init_u, init_w),
    )

    return jnp.sign(u) * jnp.arccos(w)
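A usage sketch for the sampler above (it is an internal helper, so the calling convention here is inferred from the signature rather than from documented API):

import jax.numpy as jnp
from jax import random

key = random.PRNGKey(0)
kappa = jnp.full((5,), 2.0, dtype=jnp.float32)  # concentration parameter
samples = _von_mises_centered(key, kappa, (5,), jnp.dtype(jnp.float32))
# samples lie in [-pi, pi] and are centered at 0.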
Example #11
    def testPolar(self, n_zero_sv, degeneracy, geometric_spectrum, max_sv,
                  shape, method, side, nonzero_condition_number, dtype, seed):
        """ Tests jax.scipy.linalg.polar."""
        if jtu.device_under_test() != "cpu":
            if jnp.dtype(dtype).name in ("bfloat16", "float16"):
                raise unittest.SkipTest("Skip half precision off CPU.")

        m, n = shape
        if (method == "qdwh" and ((side == "left" and m >= n) or
                                  (side == "right" and m < n))):
            raise unittest.SkipTest("method=qdwh does not support these sizes")

        matrix, _ = _initialize_polar_test(self.rng(), shape, n_zero_sv,
                                           degeneracy, geometric_spectrum,
                                           max_sv, nonzero_condition_number,
                                           dtype)
        if jnp.dtype(dtype).name in ("bfloat16", "float16"):
            self.assertRaises(NotImplementedError,
                              jsp.linalg.polar,
                              matrix,
                              method=method,
                              side=side)
            return

        unitary, posdef = jsp.linalg.polar(matrix, method=method, side=side)
        if shape[0] >= shape[1]:
            should_be_eye = np.matmul(unitary.conj().T, unitary)
        else:
            should_be_eye = np.matmul(unitary, unitary.conj().T)
        tol = 500 * float(jnp.finfo(matrix.dtype).eps)
        eye_mat = np.eye(should_be_eye.shape[0], dtype=should_be_eye.dtype)
        with self.subTest('Test unitarity.'):
            self.assertAllClose(eye_mat, should_be_eye, atol=tol * min(shape))

        with self.subTest('Test Hermiticity.'):
            self.assertAllClose(posdef,
                                posdef.conj().T,
                                atol=tol * jnp.linalg.norm(posdef))

        ev, _ = np.linalg.eigh(posdef)
        ev = ev[np.abs(ev) > tol * np.linalg.norm(posdef)]
        negative_ev = jnp.sum(ev < 0.)
        with self.subTest('Test positive definiteness.'):
            self.assertEqual(negative_ev, 0)

        if side == "right":
            recon = jnp.matmul(unitary,
                               posdef,
                               precision=lax.Precision.HIGHEST)
        elif side == "left":
            recon = jnp.matmul(posdef,
                               unitary,
                               precision=lax.Precision.HIGHEST)
        with self.subTest('Test reconstruction.'):
            self.assertAllClose(matrix,
                                recon,
                                atol=tol * jnp.linalg.norm(matrix))
Example #12
    def test_typecheck_staging_nested(self):
        n = core.ShapedArray((), jnp.dtype('int32'), weak_type=False)
        m = core.ShapedArray((), jnp.dtype('int32'), weak_type=False)
        a = core.DShapedArray((DBIdx(0), ),
                              jnp.dtype('float32'),
                              weak_type=False)
        b = core.DShapedArray((DBIdx(1), ),
                              jnp.dtype('float32'),
                              weak_type=False)

        @lu.wrap_init
        def f(a, b):
            @jax.jit
            def g(x):
                return x

            return g(a),

        jaxpr, _, _ = pe.trace_to_jaxpr_dynamic(
            f, [n, m, a, b], keep_inputs=[False, False, True, True])
        # { lambda ; a:i32[] b:i32[] c:f32[a] d:f32[b]. let
        #     e:f32[a] = xla_call[
        #       call_jaxpr={ lambda ; f:i32[] g:f32[f]. let  in (g,) }
        #       name=g
        #     ] a c
        #   in (e,) }
        core.check_jaxpr(jaxpr)  # no problems here...

        # Let's introduce a type error by applying the called jaxpr to arguments
        # with types which aren't consistent with its input binders:
        _, _, c, d = jaxpr.invars
        jaxpr.eqns[0].invars[1] = d
        # { lambda ; a:i32[] b:i32[] c:f32[a] d:f32[b]. let
        #     e:f32[a] = xla_call[
        #       call_jaxpr={ lambda ; f:i32[] g:f32[f]. let  in (g,) }
        #       name=g
        #     ] a d   !!! type error here !!!
        #   in (e,) }
        with self.assertRaisesRegex(TypeError, "passes operand"):
            core.check_jaxpr(jaxpr)

        # Restore the original jaxpr:
        jaxpr.eqns[0].invars[1] = c
        core.check_jaxpr(jaxpr)  # no problems here...

        # Let's introduce another type error by setting the call result let binders
        # to have the wrong type:
        jaxpr.eqns[0].outvars[0] = core.Var(0, '', d.aval)
        # { lambda ; a:i32[] b:i32[] c:f32[a] d:f32[b]. let
        #     e:f32[b] = xla_call[   !!! type error here !!!
        #       call_jaxpr={ lambda ; f:i32[] g:f32[f]. let  in (g,) }
        #       name=g
        #     ] a c
        #   in (h,) }
        with self.assertRaisesRegex(TypeError, "inconsistently typed as"):
            core.check_jaxpr(jaxpr)
Example #13
def _threefry2x32_abstract_eval(*args):
  if any(a.dtype != jnp.uint32 for a in args):
    raise TypeError("Arguments to threefry2x32 must have uint32 type, got {}"
                    .format(args))
  if all(isinstance(arg, core.ShapedArray) for arg in args):
    shape = lax._broadcasting_shape_rule(*args)
    named_shape = core.join_named_shapes(*(a.named_shape for a in args))
    aval = core.ShapedArray(shape, jnp.dtype(jnp.uint32), named_shape=named_shape)
  else:
    aval = core.UnshapedArray(jnp.dtype(jnp.uint32))
  return (aval,) * 2
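An abstract-eval rule like this is attached to its primitive with `Primitive.def_abstract_eval`. A minimal sketch follows; the primitive construction here is assumed, not copied from the source:

from jax import core

threefry2x32_p = core.Primitive("threefry2x32")
threefry2x32_p.multiple_results = True  # the rule returns a pair of avals
threefry2x32_p.def_abstract_eval(_threefry2x32_abstract_eval)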
Example #14
def get_epsilon(dtype: jnp.dtype) -> float:
    """Helper for grabbing type-specific precision constants.

    Args:
        dtype: Datatype.

    Returns:
        Output float.
    """
    return {
        jnp.dtype("float32"): 1e-5,
        jnp.dtype("float64"): 1e-10,
    }[dtype]
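A quick usage check: the table is keyed on exact `jnp.dtype` instances, so any dtype other than float32 or float64 raises a KeyError.

assert get_epsilon(jnp.dtype("float32")) == 1e-5
assert get_epsilon(jnp.dtype("float64")) == 1e-10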
Example #15
    def __init__(self, dtype, shape):
        """Initialize this LinearOperator.
        To be called by subclasses. ``dtype`` may be None; ``shape`` should
        be convertible to a length-2 tuple.
        """
        if dtype is not None:
            dtype = np.dtype(dtype)

        shape = tuple(shape)
        if not isshape(shape):
            raise ValueError("invalid shape %r (must be 2-d)" % (shape, ))

        self.dtype = np.dtype('float32')  # force float32
        self.shape = shape
Example #16
    def testMultivariateNormal(self, dim, dtype):
        if jtu.device_under_test() == "tpu" and jnp.dtype(dtype).itemsize < 3:
            raise SkipTest(
                "random.multivariate_normal() not supported on TPU for 16-bit types."
            )
        r = np.random.RandomState(dim)
        mean = r.randn(dim)
        cov_factor = r.randn(dim, dim)
        cov = np.dot(cov_factor, cov_factor.T) + dim * np.eye(dim)

        key = random.PRNGKey(0)
        rand = partial(random.multivariate_normal,
                       mean=mean,
                       cov=cov,
                       shape=(10000, ))
        crand = api.jit(rand)

        uncompiled_samples = np.asarray(rand(key), np.float64)
        compiled_samples = np.asarray(crand(key), np.float64)

        inv_scale = scipy.linalg.lapack.dtrtri(np.linalg.cholesky(cov),
                                               lower=True)[0]
        for samples in [uncompiled_samples, compiled_samples]:
            centered = samples - mean
            whitened = np.einsum('nj,ij->ni', centered, inv_scale)

            # This is a quick-and-dirty multivariate normality check that tests that a
            # uniform mixture of the marginals along the covariance matrix's
            # eigenvectors follow a standard normal distribution.
            self._CheckKolmogorovSmirnovCDF(whitened.ravel(),
                                            scipy.stats.norm().cdf)
Example #17
    def _init_dtype(self):
        """Called from subclasses at the end of the __init__ routine."""
        if self.dtype is None:
            # v = np.zeros(self.shape[-1])
            self.dtype = np.dtype('float32')  # self.matvec(v).dtype; force float32
Example #18
def _use_cholesky(u, params):
    """Uses Cholesky decomposition."""
    a, b, c = params
    _, n = u.shape
    x = c * (u.T.conj() @ u) + jnp.eye(n, dtype=jnp.dtype(u))

    # `y` is lower triangular.
    y = lax_linalg.cholesky(x, symmetrize_input=False)

    z = lax_linalg.triangular_solve(y,
                                    u.T,
                                    left_side=True,
                                    lower=True,
                                    conjugate_a=True).conj()

    z = lax_linalg.triangular_solve(y,
                                    z,
                                    left_side=True,
                                    lower=True,
                                    transpose_a=True,
                                    conjugate_a=True).T.conj()

    e = b / c
    u = e * u + (a - e) * z
    return u
Example #19
def _use_cholesky(u, m, n, params):
    """QDWH iteration using Cholesky decomposition.

  Args:
  u: a matrix, with static (padded) shape M x N
  m, n: the dynamic shape of the matrix, where m <= M and n <= N.
  params: the QDWH parameters.
  """
    a, b, c = params
    _, N = u.shape
    x = c * (u.T.conj() @ u) + jnp.eye(N, dtype=jnp.dtype(u))
    # Pads the lower-right corner with the identity matrix to prevent the Cholesky
    # decomposition from failing due to the matrix not being PSD if padded with
    # zeros.
    x = _mask(x, (n, n), jnp.eye(N, dtype=x.dtype))

    # `y` is lower triangular.
    y = lax_linalg.cholesky(x, symmetrize_input=False)

    z = lax_linalg.triangular_solve(y,
                                    u.T,
                                    left_side=True,
                                    lower=True,
                                    conjugate_a=True).conj()

    z = lax_linalg.triangular_solve(y,
                                    z,
                                    left_side=True,
                                    lower=True,
                                    transpose_a=True,
                                    conjugate_a=True).T.conj()

    e = b / c
    u = e * u + (a - e) * z
    return u
Example #20
    def testInitializerProvider(self, initializer_provider, shape, dtype):
        rng = random.PRNGKey(0)
        initializer = initializer_provider(dtype=dtype)
        val = initializer(rng, shape)

        self.assertEqual(shape, jnp.shape(val))
        self.assertEqual(jax.dtypes.canonicalize_dtype(dtype), jnp.dtype(val))
Example #21
def _rbg_random_bits(key: jnp.ndarray, bit_width: int,
                     shape: Sequence[int]) -> jnp.ndarray:
    if not (key.shape == (4, ) and key.dtype == jnp.dtype('uint32')):
        raise TypeError("_rbg_random_bits got invalid prng key.")
    if bit_width not in (8, 16, 32, 64):
        raise TypeError("requires 8-, 16-, 32- or 64-bit field width.")
    _, bits = lax.rng_bit_generator(key, shape, dtype=UINT_DTYPES[bit_width])
    return bits
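A usage sketch, assuming `UINT_DTYPES` is the usual width-to-dtype map ({8: jnp.uint8, 16: jnp.uint16, 32: jnp.uint32, 64: jnp.uint64}):

key = jnp.zeros((4,), dtype=jnp.uint32)  # a real key would come from a seed
bits = _rbg_random_bits(key, bit_width=32, shape=(2, 3))
# bits has shape (2, 3) and dtype uint32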
Example #22
    def random_variable(rng, size, dtype, *args):
        # `name` is captured from the enclosing scope in the original source.
        prng = jax.random.PRNGKey(rng["state"]["key"][0])
        dtype = jnp.dtype(dtype)
        data = getattr(jax.random, name)(key=prng, shape=size)
        smpl_value = jnp.array(data, dtype=dtype)
        prng = jax.random.split(prng, num=1)[0]
        # index_update is functional: assign its result, it does not mutate in place.
        rng["state"]["key"] = jax.ops.index_update(rng["state"]["key"], 0, prng[0])
        return (rng, smpl_value)
Example #23
    def testScalarInstantiation(self):
        for t in [
                jnp.bool_, jnp.int32, jnp.bfloat16, jnp.float32, jnp.complex64
        ]:
            a = t(1)
            self.assertEqual(a.dtype, jnp.dtype(t))
            self.assertIsInstance(a, xla.DeviceArray)
            self.assertEqual(0, jnp.ndim(a))
Example #24
    def test_staging_nested_including_shape_arg(self):
        n = core.DShapedArray((), jnp.dtype('int32'), weak_type=False)
        a = core.DShapedArray((pe.DBIdx(0), ),
                              jnp.dtype('float32'),
                              weak_type=False)
        b = core.DShapedArray((pe.DBIdx(0), ),
                              jnp.dtype('float32'),
                              weak_type=False)

        @lu.wrap_init
        def f(x, y):
            @jax.jit
            def g(_, x, y, z, w):
                return (x, w)

            return g(x.shape[0], x, y, x, y)

        jaxpr, _, _ = pe.trace_to_jaxpr_dynamic(
            f, [n, a, b], keep_inputs=[False, True, True])
        print(jaxpr)

        # { lambda ; a:i32[] b:f32[a] c:f32[a]. let
        #     d:f32[a] e:f32[a] = xla_call[
        #       call_jaxpr={ lambda ; f:i32[] g:i32[] h:f32[f] i:f32[f] j:f32[f] k:f32[f]. let
        #
        #         in (h, k) }
        #       name=g
        #     ] a a b c b c
        #   in (d, e) }

        self.assertLen(jaxpr.eqns, 1)
        eqn = jaxpr.eqns[0]
        self.assertIsInstance(eqn.primitive, core.CallPrimitive)
        inner_jaxpr = eqn.params['call_jaxpr']
        self.assertIsInstance(inner_jaxpr, core.Jaxpr)

        self.assertLen(inner_jaxpr.invars, 1 + 4)  # one axis size var
        self.assertEqual((inner_jaxpr.invars[0], ),
                         inner_jaxpr.invars[1].aval.shape)
        self.assertEqual((inner_jaxpr.invars[0], ),
                         inner_jaxpr.invars[2].aval.shape)
        self.assertEqual((inner_jaxpr.invars[0], ),
                         inner_jaxpr.invars[3].aval.shape)
        self.assertEqual((inner_jaxpr.invars[0], ),
                         inner_jaxpr.invars[4].aval.shape)
Example #25
def multiply_add_xla_translation(c, xc, yc, zc):
    """The compilation to XLA of the primitive.

    Given an XlaBuilder and XlaOps for each argument, return the XlaOp for the
    result of the function.

    Does not need to be a JAX-traceable function.
    """
    return c.CustomCall(
        b'multiply_add_f32',
        operands=(xc, yc, zc),
        shape_with_layout=xla_client.Shape.array_shape(jnp.dtype(jnp.float32),
                                                       (), ()),
        operand_shapes_with_layout=(xla_client.Shape.array_shape(
            jnp.dtype(jnp.float32), (),
            ()), xla_client.Shape.array_shape(jnp.dtype(jnp.float32), (), ()),
                                    xla_client.Shape.array_shape(
                                        jnp.dtype(jnp.float32), (), ())))
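In the era of this translation API, the rule was registered per backend. A sketch, assuming a primitive object `multiply_add_p` as in the JAX custom-primitive tutorial this snippet resembles:

from jax import core
from jax.interpreters import xla

multiply_add_p = core.Primitive("multiply_add")
xla.backend_specific_translations["cpu"][multiply_add_p] = multiply_add_xla_translation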
Example #26
  def test_staging_basic(self):
    n = core.ShapedArray((), jnp.dtype('int32'), weak_type=False)
    a = core.DShapedArray((n,), jnp.dtype('float32'), weak_type=False)
    b = core.DShapedArray((n,), jnp.dtype('float32'), weak_type=False)

    @lu.wrap_init
    def f(x, y):
      return x, y

    jaxpr, _, _ = pe.trace_to_jaxpr_dynamic(f, [n, a, b],
                                            keep_inputs=[False, True, True])

    self.assertLen(jaxpr.invars, 3)
    self.assertEqual((jaxpr.invars[0],), jaxpr.invars[1].aval.shape)
    self.assertEqual((jaxpr.invars[0],), jaxpr.invars[2].aval.shape)

    self.assertLen(jaxpr.outvars, 2)
    self.assertEqual((jaxpr.invars[0],), jaxpr.outvars[0].aval.shape)
    self.assertEqual((jaxpr.invars[0],), jaxpr.outvars[1].aval.shape)
Example #27
    def test_dump_dict_json(self):
        """Tests JSON dumping function."""
        data_dict = {
            'np_float': np.dtype('float32').type(1.0),
            'jnp_float': jnp.dtype('float32').type(1.0),
            'np_int': np.dtype('int32').type(1),
            'jnp_int': jnp.dtype('int32').type(1),
            'np_array': np.array(1.0, dtype=np.float32),
            'jnp_array': jnp.array(1.0, dtype=jnp.float32),
        }
        converted_dict = {
            key: utils._np_converter(value)
            for key, value in data_dict.items()
        }
        json_path = tempfile.NamedTemporaryFile()
        utils.dump_dict_json(data_dict, json_path.name)

        with open(json_path.name, 'r') as input_file:
            loaded_dict = json.load(input_file)
        self.assertDictEqual(loaded_dict, converted_dict)
Example #28
 def __init__(self, val, dtype=None):
   if isinstance(val, tuple):
     head, tail = val
   elif isinstance(val, str):
     dtype = jnp.dtype(dtype or 'float64').type
     val = decimal.Decimal(val)
     head = dtype(val)
     tail = 0 if np.isinf(head) else dtype(val - decimal.Decimal(float(head)))
   elif isinstance(val, int):
     dtype = jnp.dtype(dtype or 'float64').type
     head = dtype(val)
     tail = 0 if np.isinf(head) else dtype(val - int(head))
   elif isinstance(val, _DoubleDouble):
     head, tail = val.head, val.tail
   else:
     head, tail = val, jnp.zeros_like(val)
   dtype = dtype or jnp.result_type(head, tail)
   head = jnp.asarray(head, dtype=dtype)
   tail = jnp.asarray(tail, dtype=dtype)
   self.head, self.tail = _normalize(head, tail)
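A construction sketch (illustrative, not from the source): starting from a decimal string, `head` holds the nearest representable float and `tail` the rounding residual, so the pair carries roughly twice the working precision.

x = _DoubleDouble("0.1", dtype="float32")
# x.head is the float32 closest to 0.1; x.tail is the tiny remainder, so
# float(x.head) + float(x.tail) approximates 0.1 better than head alone.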
Example #29
def _use_qr(u, params):
    """Uses QR decomposition."""
    a, b, c = params
    m, n = u.shape
    y = jnp.concatenate([jnp.sqrt(c) * u, jnp.eye(n, dtype=jnp.dtype(u))])
    q, _ = lax_linalg.qr(y, full_matrices=False)
    q1 = q[:m, :]
    q2 = (q[m:, :]).T.conj()
    e = b / c
    u = (e * u + (a - e) / jnp.sqrt(c) * jnp.einsum('ij,jk->ik', q1, q2))
    return u
Example #30
    def test_spectral_dac_eigh(self, linear_size, seed, dtype):
        if jtu.device_under_test() != "cpu":
            raise unittest.SkipTest("Skip eigh off CPU for now.")
        if jnp.dtype(dtype).name in ("bfloat16", "float16"):
            if jtu.device_under_test() != "cpu":
                raise unittest.SkipTest("Skip half precision off CPU.")

        rng = self.rng()
        H = rng.randn(linear_size, linear_size)
        H = jnp.array(0.5 * (H + H.conj().T)).astype(dtype)
        if jnp.dtype(dtype).name in ("bfloat16", "float16"):
            self.assertRaises(NotImplementedError, jax._src.scipy.eigh.eigh, H)
            return
        evs, V = jax._src.scipy.eigh.eigh(H)
        ev_exp, eV_exp = jnp.linalg.eigh(H)
        HV = jnp.dot(H, V, precision=lax.Precision.HIGHEST)
        vV = evs * V
        eps = jnp.finfo(H.dtype).eps
        atol = jnp.linalg.norm(H) * eps
        self.assertAllClose(ev_exp, jnp.sort(evs), atol=20 * atol)
        self.assertAllClose(HV, vV, atol=30 * atol)