def test_unitary_to_rot_jax(self, U, expected_gate, expected_params):
        """Test that the transform works in the JAX interface.

        ``jax_enable_x64`` is switched on for the duration of the test and the
        previous value is restored afterwards, even if an assertion fails.
        """
        jax = pytest.importorskip("jax")

        # Enable float64 support
        from jax.config import config

        remember = config.read("jax_enable_x64")
        config.update("jax_enable_x64", True)

        try:
            # complex128 matches the x64 mode enabled above; complex64 would
            # discard the extra precision we just turned on.
            U = jax.numpy.array(U, dtype=jax.numpy.complex128)

            transformed_qfunc = unitary_to_rot(qfunc)

            ops = qml.transforms.make_tape(transformed_qfunc)(U).operations

            assert len(ops) == 3

            assert isinstance(ops[0], qml.Hadamard)
            assert ops[0].wires == Wires("a")

            assert isinstance(ops[1], expected_gate)
            assert ops[1].wires == Wires("a")
            assert qml.math.allclose(
                [jax.numpy.asarray(x) for x in ops[1].parameters], expected_params
            )

            assert isinstance(ops[2], qml.CNOT)
            assert ops[2].wires == Wires(["b", "a"])
        finally:
            # Undo the global flag change so later tests see the original setting.
            config.update("jax_enable_x64", remember)
    def test_commute_controlled_jax(self):
        """Test QNode and gradient in JAX interface.

        ``jax_enable_x64`` is enabled for the test and restored afterwards.
        """
        jax = pytest.importorskip("jax")
        from jax import numpy as jnp

        from jax.config import config

        remember = config.read("jax_enable_x64")
        config.update("jax_enable_x64", True)

        try:
            original_qnode = qml.QNode(qfunc, dev, interface="jax")
            transformed_qnode = qml.QNode(transformed_qfunc, dev, interface="jax")

            x = jnp.array([0.3, 0.4], dtype=jnp.float64)

            # Check that the numerical output is the same
            assert qml.math.allclose(original_qnode(x), transformed_qnode(x))

            # Check that the gradient is the same
            assert qml.math.allclose(
                jax.grad(original_qnode)(x), jax.grad(transformed_qnode)(x)
            )

            # Check operation list
            ops = transformed_qnode.qtape.operations
            compare_operation_lists(ops, expected_op_list, expected_wires_list)
        finally:
            # Restore the global JAX precision flag so other tests are unaffected.
            config.update("jax_enable_x64", remember)
# Example #3
# 0
    def test_compile_jax(self, diff_method):
        """Test QNode and gradient in JAX interface.

        ``jax_enable_x64`` is enabled for the test and restored afterwards.
        """
        jax = pytest.importorskip("jax")
        from jax import numpy as jnp

        from jax.config import config

        remember = config.read("jax_enable_x64")
        config.update("jax_enable_x64", True)

        try:
            original_qnode = qml.QNode(qfunc, dev, interface="jax", diff_method=diff_method)
            transformed_qnode = qml.QNode(
                transformed_qfunc, dev, interface="jax", diff_method=diff_method
            )

            x = jnp.array([0.1, 0.2, 0.3], dtype=jnp.float64)
            params = jnp.ones((2, 3), dtype=jnp.float64)

            # Check that the numerical output is the same
            assert qml.math.allclose(original_qnode(x, params), transformed_qnode(x, params))

            # Check that the gradient (w.r.t. ``params``) is the same
            assert qml.math.allclose(
                jax.grad(original_qnode, argnums=1)(x, params),
                jax.grad(transformed_qnode, argnums=1)(x, params),
                atol=1e-7,
            )

            # Check operation list
            ops = transformed_qnode.qtape.operations
            compare_operation_lists(ops, expected_op_list, expected_wires_list)
        finally:
            # Restore the global JAX precision flag so other tests are unaffected.
            config.update("jax_enable_x64", remember)
# Example #4
# 0
    def test_zyz_decomposition_jax(self, U, expected_gate, expected_params):
        """Test that a one-qubit operation in JAX is correctly decomposed.

        ``jax_enable_x64`` is enabled for the test and restored afterwards.
        """
        jax = pytest.importorskip("jax")

        # Enable float64 support
        from jax.config import config

        remember = config.read("jax_enable_x64")
        config.update("jax_enable_x64", True)

        try:
            U = jax.numpy.array(U, dtype=jax.numpy.complex128)

            obtained_gates = zyz_decomposition(U, wire="a")

            assert len(obtained_gates) == 1
            assert isinstance(obtained_gates[0], expected_gate)
            assert obtained_gates[0].wires == Wires("a")

            assert qml.math.allclose(
                qml.math.unwrap(obtained_gates[0].parameters), expected_params
            )

            # A single-parameter result is an RZ; otherwise a full Rot.
            gate_cls = qml.RZ if obtained_gates[0].num_params == 1 else qml.Rot
            obtained_mat = gate_cls(*obtained_gates[0].parameters, wires=0).matrix

            assert check_matrix_equivalence(obtained_mat, U)
        finally:
            # Restore the global JAX precision flag so other tests are unaffected.
            config.update("jax_enable_x64", remember)
# Example #5
# 0
    def __init__(self, wires, *, shots=None, prng_key=None, analytic=None):
        # Choose complex/real dtypes that follow JAX's global x64 setting.
        use_x64 = jax_config.read("jax_enable_x64")
        self.C_DTYPE, self.R_DTYPE = (
            (jnp.complex128, jnp.float64) if use_x64 else (jnp.complex64, jnp.float32)
        )
        super().__init__(wires, shots=shots, cache=0, analytic=analytic)

        # The special-cased apply methods for these gates are slower under the
        # JAX implementation, so remove them from the dispatch table.
        for gate_name in ("PauliY", "Hadamard", "CZ"):
            del self._apply_ops[gate_name]
        self._prng_key = prng_key
# Example #6
# 0
  def testStaticShapeErrors(self):
    # With jit disabled the shape arguments are concrete, so nothing to test.
    if config.read("jax_disable_jit"):
      raise SkipTest("test only relevant when jit enabled")

    @api.jit
    def feature_map(n, d, sigma=1.0, seed=123):
      # n and d are not marked static, so under jit they are traced values;
      # using them as array shapes is expected to raise.
      rng = random.PRNGKey(seed)
      weights = random.normal(rng, (d, n)) / sigma
      freq = random.normal(rng, (d, )) / sigma
      offset = 2 * jnp.pi * random.uniform(rng, (d, ))

      def phi(x, t):
        return jnp.sqrt(2.0 / d) * jnp.cos(jnp.matmul(weights, x) + freq * t + offset)
      return phi

    with self.assertRaisesRegex(TypeError, 'Shapes must be 1D.*'):
      feature_map(5, 3)
    def test_gradient_unitary_to_rot_jax(self, rot_angles, diff_method):
        """Tests differentiability in jax interface.

        ``jax_enable_x64`` is enabled for the test and restored afterwards.
        """
        jax = pytest.importorskip("jax")
        from jax import numpy as jnp

        # Enable float64 support
        from jax.config import config

        remember = config.read("jax_enable_x64")
        config.update("jax_enable_x64", True)

        try:
            def qfunc_with_qubit_unitary(angles):
                z = angles[0]
                x = angles[1]

                Z_mat = jnp.array([[jnp.exp(-1j * z / 2), 0.0], [0.0, jnp.exp(1j * z / 2)]])

                c = jnp.cos(x / 2)
                s = jnp.sin(x / 2) * 1j
                X_mat = jnp.array([[c, -s], [-s, c]])

                qml.Hadamard(wires="a")
                qml.QubitUnitary(Z_mat, wires="a")
                qml.QubitUnitary(X_mat, wires="b")
                qml.CNOT(wires=["b", "a"])
                return qml.expval(qml.PauliX(wires="a"))

            # Setting the dtype to complex64 causes the gradients to be complex...
            input_angles = jnp.array(rot_angles, dtype=jnp.float64)

            dev = qml.device("default.qubit", wires=["a", "b"])

            original_qnode = qml.QNode(
                original_qfunc_for_grad, dev, interface="jax", diff_method=diff_method
            )
            original_result = original_qnode(input_angles)

            transformed_qfunc = unitary_to_rot(qfunc_with_qubit_unitary)
            transformed_qnode = qml.QNode(
                transformed_qfunc, dev, interface="jax", diff_method=diff_method
            )
            transformed_result = transformed_qnode(input_angles)
            assert qml.math.allclose(original_result, transformed_result)

            original_grad = jax.grad(original_qnode)(input_angles)
            transformed_grad = jax.grad(transformed_qnode)(input_angles)
            assert qml.math.allclose(original_grad, transformed_grad, atol=1e-7)
        finally:
            # Restore the global JAX precision flag so other tests are unaffected.
            config.update("jax_enable_x64", remember)
# Example #8
# 0
  def test_correctly_capture_default(self, jit, enable_or_disable):
    if jit == "cpp" and not config.omnistaging_enabled:
      self.skipTest("cpp_jit requires omnistaging")

    # Build and run the jitted function inside a scope that flips x64. The
    # assertions below check that the output dtype follows the setting active
    # at call time, not the one active when the function was defined.
    with enable_or_disable():
      fn = _maybe_jit(jit, lambda: jnp.arange(10.0))
      fn()

    default_dtype = "float32"
    if config.read("jax_enable_x64"):
      default_dtype = "float64"
    self.assertEqual(fn().dtype, default_dtype)

    for scope, want in ((enable_x64, "float64"), (disable_x64, "float32")):
      with scope():
        self.assertEqual(fn().dtype, want)
# Example #9
# 0
    def test_two_qubit_decomposition_jax(self, U, wires):
        """Test that a two-qubit operation in JAX is correctly decomposed.

        ``jax_enable_x64`` is enabled for the test and restored afterwards.
        """
        jax = pytest.importorskip("jax")

        from jax.config import config

        remember = config.read("jax_enable_x64")
        config.update("jax_enable_x64", True)

        try:
            U = jax.numpy.array(U, dtype=jax.numpy.complex128)

            obtained_decomposition = two_qubit_decomposition(U, wires=wires)

            # Record the decomposition on a tape so its full matrix can be rebuilt.
            with qml.tape.QuantumTape() as tape:
                for op in obtained_decomposition:
                    qml.apply(op)

            obtained_matrix = get_unitary_matrix(tape, wire_order=wires)()

            assert check_matrix_equivalence(U, obtained_matrix, atol=1e-7)
        finally:
            # Restore the global JAX precision flag so other tests are unaffected.
            config.update("jax_enable_x64", remember)
    def test_gradient_unitary_to_rot_two_qubit_jax(self):
        """Tests differentiability in jax interface.

        ``jax_enable_x64`` is enabled for the test and restored afterwards.
        """
        jax = pytest.importorskip("jax")
        from jax import numpy as jnp

        from jax.config import config

        remember = config.read("jax_enable_x64")
        config.update("jax_enable_x64", True)

        try:
            U0 = jnp.array(test_two_qubit_unitaries[0], dtype=jnp.complex128)
            U1 = jnp.array(test_two_qubit_unitaries[1], dtype=jnp.complex128)

            def two_qubit_decomp_qnode(x):
                qml.RX(x, wires=0)
                qml.QubitUnitary(U0, wires=[0, 1])
                qml.QubitUnitary(U1, wires=[1, 2])
                return qml.expval(qml.PauliZ(0) @ qml.PauliZ(1) @ qml.PauliZ(2))

            x = jnp.array(0.1, dtype=jnp.float64)

            dev = qml.device("default.qubit", wires=3)

            original_qnode = qml.QNode(
                two_qubit_decomp_qnode, device=dev, interface="jax", diff_method="backprop"
            )

            transformed_qfunc = unitary_to_rot(two_qubit_decomp_qnode)
            transformed_qnode = qml.QNode(
                transformed_qfunc, dev, interface="jax", diff_method="backprop"
            )

            assert qml.math.allclose(original_qnode(x), transformed_qnode(x))

            # 1 normal operation (RX) + 10 for the first decomp and 2 for the second
            assert len(transformed_qnode.qtape.operations) == 13

            original_grad = jax.grad(original_qnode, argnums=0)(x)
            transformed_grad = jax.grad(transformed_qnode, argnums=0)(x)

            assert qml.math.allclose(original_grad, transformed_grad, atol=1e-6)
        finally:
            # Restore the global JAX precision flag so other tests are unaffected.
            config.update("jax_enable_x64", remember)
# Example #11
# 0
    def test_coefficients_jax_interface(self):
        """Test that coefficients are correctly computed when using the JAX interface.

        ``jax_enable_x64`` is enabled for the test and restored afterwards,
        including when the assertion fails.
        """
        jax = pytest.importorskip("jax")

        # Need to enable float64 support
        from jax.config import config

        remember = config.read("jax_enable_x64")
        config.update("jax_enable_x64", True)

        try:
            qnode = qml.QNode(self.circuit,
                              self.dev,
                              interface="jax",
                              diff_method="parameter-shift")

            weights = jax.numpy.array([0.5, 0.2])

            obtained_result = coefficients(partial(qnode, weights), 2, 1)

            assert np.allclose(obtained_result, self.expected_result)
        finally:
            # Previously the restore ran only on success; a failing assert left
            # x64 enabled for every later test.
            config.update("jax_enable_x64", remember)
    def test_merge_amplitude_embedding_jax(self):
        """Test QNode in JAX interface.

        ``jax_enable_x64`` is enabled for the test and restored afterwards.
        """
        jax = pytest.importorskip("jax")
        from jax import numpy as jnp

        from jax.config import config

        remember = config.read("jax_enable_x64")
        config.update("jax_enable_x64", True)

        try:
            def qfunc(amplitude):
                qml.AmplitudeEmbedding(amplitude, wires=0)
                qml.AmplitudeEmbedding(amplitude, wires=1)
                return qml.state()

            dev = qml.device("default.qubit", wires=2)
            optimized_qfunc = qml.transforms.merge_amplitude_embedding(qfunc)
            optimized_qnode = qml.QNode(optimized_qfunc, dev, interface="jax")

            amplitude = jnp.array([0.0, 1.0], dtype=jnp.float64)
            # Check the state |11> is being generated.
            assert optimized_qnode(amplitude)[-1] == 1
        finally:
            # Restore the global JAX precision flag so other tests are unaffected.
            config.update("jax_enable_x64", remember)
# Example #13
# 0
def test_get_unitary_matrix_interface_jax():
    """Test with JAX interface.

    ``jax_enable_x64`` is enabled for the test and restored afterwards.
    """

    jax = pytest.importorskip("jax")
    from jax import numpy as jnp
    from jax.config import config

    remember = config.read("jax_enable_x64")
    config.update("jax_enable_x64", True)

    try:
        dev = qml.device("default.qubit", wires=3)

        def circuit(theta):
            qml.RZ(theta[0], wires=0)
            qml.RZ(theta[1], wires=1)
            qml.CRY(theta[2], wires=[1, 2])
            return qml.expval(qml.PauliZ(1))

        # set qnode interface
        qnode = qml.QNode(circuit, dev, interface="jax")

        get_matrix = get_unitary_matrix(qnode)

        # input jax parameters
        theta = jnp.array([0.1, 0.2, 0.3], dtype=jnp.float64)

        matrix = get_matrix(theta)

        # expected matrix: the two RZ layers, then CRY acting on wires 1 and 2
        matrix1 = np.kron(
            qml.RZ(theta[0], wires=0).matrix, np.kron(qml.RZ(theta[1], wires=1).matrix, I)
        )
        matrix2 = np.kron(I, qml.CRY(theta[2], wires=[1, 2]).matrix)
        expected_matrix = matrix2 @ matrix1

        assert np.allclose(matrix, expected_matrix)
    finally:
        # Restore the global JAX precision flag so other tests are unaffected.
        config.update("jax_enable_x64", remember)
from neural_tangents import predict
from neural_tangents import stax
from neural_tangents.utils import batch
from neural_tangents.utils import empirical

config.parse_flags_with_absl()

MATRIX_SHAPES = [(3, 3), (4, 4)]

GETS = ('ntk', 'nngp', ('ntk', 'nngp'))

# Comparison tolerances; loosened when x64 is disabled.
if config.read('jax_enable_x64'):
    RTOL = 0.1
    ATOL = 0.1
else:
    RTOL = 0.2
    ATOL = 0.2

FLAT = 'FLAT'
POOLING = 'POOLING'

# TODO(schsam): Add a pooling test when multiple inputs are supported in
# Conv + Pooling.
TRAIN_SHAPES = [(4, 4), (4, 8), (8, 8), (6, 4, 4, 3)]
TEST_SHAPES = [(2, 4), (6, 8), (16, 8), (2, 4, 4, 3)]
NETWORK = [FLAT, FLAT, FLAT, FLAT]
# Defined once; an identical duplicate assignment earlier in the module was removed.
OUTPUT_LOGITS = [1, 2, 3]

CONVOLUTION_CHANNELS = 256
# Example #15
# 0
 def setUp(self):
     """Turn on ``jax_debug_nans``, stashing the prior value in ``self.cfg``."""
     previous = config.read("jax_debug_nans")
     self.cfg = previous
     config.update("jax_debug_nans", True)
    def testTrainedEnsemblePredCov(self, train_shape, test_shape, network,
                                   out_logits):
        """Compare a trained network ensemble against NTK GP inference.

        Trains ``ensemble_size`` independently initialised networks with SGD
        and compares the ensemble's empirical mean and covariance on
        ``x_test`` with ``predict.gp_inference(..., 'ntk', ...,
        compute_cov=True)`` within ``RTOL``/``ATOL``.

        Note: ``network`` is accepted but not used in this body.
        """
        # Skip the x64 variant on GPU to save time.
        if xla_bridge.get_backend().platform == 'gpu' and config.read(
                'jax_enable_x64'):
            raise jtu.SkipTest('Not running GPU x64 to save time.')
        training_steps = 5000
        learning_rate = 1.0
        ensemble_size = 50

        # Dense -> Erf -> Dense network; ker_fn is passed to gp_inference below.
        init_fn, apply_fn, ker_fn = stax.serial(
            stax.Dense(1024, W_std=1.2, b_std=0.05), stax.Erf(),
            stax.Dense(out_logits, W_std=1.2, b_std=0.05))

        opt_init, opt_update, get_params = optimizers.sgd(learning_rate)
        opt_update = jit(opt_update)

        key = random.PRNGKey(0)
        # Split into a single key, tuple-unpacked.
        key, = random.split(key, 1)

        key, split = random.split(key)
        x_train = np.cos(random.normal(split, train_shape))

        # Random 0/1 targets, one column per output logit.
        key, split = random.split(key)
        y_train = np.array(
            random.bernoulli(split, shape=(train_shape[0], out_logits)),
            np.float32)
        train = (x_train, y_train)
        key, split = random.split(key)
        x_test = np.cos(random.normal(split, test_shape))

        # One PRNG key per ensemble member, consumed by the vmap below.
        ensemble_key = random.split(key, ensemble_size)

        # Mean squared error (halved) and its gradient w.r.t. the parameters.
        loss = jit(lambda params, x, y: 0.5 * np.mean(
            (apply_fn(params, x) - y)**2))
        grad_loss = jit(lambda state, x, y: grad(loss)
                        (get_params(state), x, y))

        def train_network(key):
            # Initialise from `key` and run `training_steps` SGD updates.
            _, params = init_fn(key, (-1, ) + train_shape[1:])
            opt_state = opt_init(params)
            for i in range(training_steps):
                opt_state = opt_update(i, grad_loss(opt_state, *train),
                                       opt_state)

            return get_params(opt_state)

        params = vmap(train_network)(ensemble_key)

        # Axis 0 of ensemble_fx is the ensemble member (vmap over params).
        ensemble_fx = vmap(apply_fn, (0, None))(params, x_test)
        ensemble_loss = vmap(loss, (0, None, None))(params, x_train, y_train)
        ensemble_loss = np.mean(ensemble_loss)
        # NOTE(review): unittest's assertLess signature is (first, second, msg);
        # the positional `True` here ends up as `msg` -- confirm this is intended.
        self.assertLess(ensemble_loss, 1e-5, True)

        # Empirical covariance: contract over axis 0 (ensemble) and the last axis,
        # normalised by their sizes. Presumably j/l index test points and k
        # output logits -- confirm against apply_fn's output shape.
        mean_emp = np.mean(ensemble_fx, axis=0)
        mean_subtracted = ensemble_fx - mean_emp
        cov_emp = np.einsum(
            'ijk,ilk->jl', mean_subtracted, mean_subtracted, optimize=True) / (
                mean_subtracted.shape[0] * mean_subtracted.shape[-1])

        # Small diagonal regulariser passed to gp_inference.
        reg = 1e-7
        ntk_predictions = predict.gp_inference(ker_fn,
                                               x_train,
                                               y_train,
                                               x_test,
                                               'ntk',
                                               reg,
                                               compute_cov=True)

        # NOTE(review): tolerances are passed positionally after check_dtypes;
        # confirm this assertAllClose expects (rtol, atol) in that order. It is
        # moot today because RTOL == ATOL.
        self.assertAllClose(mean_emp, ntk_predictions.mean, True, RTOL, ATOL)
        self.assertAllClose(cov_emp, ntk_predictions.covariance, True, RTOL,
                            ATOL)
# Example #17
# 0
 def test_float64(self, module_fn: ModuleFn, shape, dtype):
     """Require x64 mode, then delegate the float64 check to ``assert_dtype``."""
     x64_enabled = config.read("jax_enable_x64")
     self.assertTrue(x64_enabled)
     self.assert_dtype(jnp.float64, module_fn, shape, dtype)
# -*- coding: utf-8 -*-

import logging

from jax.config import config

logger = logging.getLogger(__name__)

# Message shown when this package has to switch JAX into float64 mode itself.
_X64_WARNING = (
    "exoplanet_core.jax only works with dtype float64. "
    "We're enabling x64 now, but you might run into issues if you've "
    "already run some jax code.\n"
    "You can squash this warning by setting the environment variable "
    "'JAX_ENABLE_X64=True' or by running:\n"
    ">>> from jax.config import config\n"
    ">>> config.update('jax_enable_x64', True)"
)

# This package requires float64; turn it on (with a warning) if it is off.
if not config.read("jax_enable_x64"):
    logger.warning(_X64_WARNING)
    config.update("jax_enable_x64", True)

__all__ = ["ops"]

from . import ops  # noqa isort:skip