Example #1
0
    def test_jax(self):
        """Check that ``fn.diag`` turns a 1D JAX array into a diagonal matrix,
        both on the main diagonal and on the first superdiagonal (``k=1``)."""
        values = [0.1, 0.2, 0.3]
        t = jnp.array(values)

        assert fn.allclose(fn.diag(t), onp.diag(values))
        assert fn.allclose(fn.diag(t, k=1), onp.diag(values, k=1))
Example #2
0
    def test_tensorflow(self):
        """Check that ``fn.diag`` applied to a TensorFlow variable returns a
        ``tf.Tensor`` holding the expected diagonal matrix, and honours ``k``."""
        values = [0.1, 0.2, 0.3]
        t = tf.Variable(values)

        main_diag = fn.diag(t)
        assert isinstance(main_diag, tf.Tensor)
        assert fn.allclose(main_diag, onp.diag(values))

        upper_diag = fn.diag(t, k=1)
        assert fn.allclose(upper_diag, onp.diag(values, k=1))
Example #3
0
    def test_array(self):
        """Check that ``fn.diag`` applied to a NumPy array returns an
        ``np.ndarray`` holding the expected diagonal matrix, and honours ``k``."""
        values = [0.1, 0.2, 0.3]
        t = np.array(values)

        main_diag = fn.diag(t)
        assert isinstance(main_diag, np.ndarray)
        assert fn.allclose(main_diag, onp.diag(values))

        upper_diag = fn.diag(t, k=1)
        assert fn.allclose(upper_diag, onp.diag(values, k=1))
Example #4
0
    def test_torch(self):
        """Check that ``fn.diag`` applied to a torch tensor returns a
        ``torch.Tensor`` holding the expected diagonal matrix, and honours ``k``."""
        values = [0.1, 0.2, 0.3]
        t = torch.tensor(values)

        main_diag = fn.diag(t)
        assert isinstance(main_diag, torch.Tensor)
        assert fn.allclose(main_diag, onp.diag(values))

        upper_diag = fn.diag(t, k=1)
        assert fn.allclose(upper_diag, onp.diag(values, k=1))
Example #5
0
    def test_torch(self):
        """Check that ``fn.scatter_element_add`` on torch tensors produces the
        expected values and propagates gradients to both the target tensor and
        the added value."""
        x = torch.tensor([[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]], requires_grad=True)
        y = torch.tensor(0.56, requires_grad=True)

        res = fn.scatter_element_add(x, [1, 2], y ** 2)
        assert isinstance(res, torch.Tensor)

        expected = onp.array([[1.0, 1.0, 1.0], [1.0, 1.0, 1.3136]])
        assert fn.allclose(res.detach(), expected)

        # Only element [1, 2] feeds the loss, so d(loss)/dx is one-hot there
        # and d(loss)/dy = 2y.
        res[1, 2].backward()
        assert fn.allclose(x.grad, onp.array([[0, 0, 0], [0, 0, 1.0]]))
        assert fn.allclose(y.grad, 2 * y)
Example #6
0
    def test_jax(self):
        """Test that a JAX array is differentiable when using scatter addition.

        ``fn.scatter_element_add`` adds ``y**2`` to element ``[1, 2]`` of ``x``.
        The result must be a JAX array with the expected values. The gradient of
        that element must be one-hot w.r.t. ``x`` and equal ``2y`` w.r.t. ``y``.
        """
        x = jnp.array([[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]])
        y = jnp.array(0.56)

        def cost(weights):
            return fn.scatter_element_add(weights[0], [1, 2], weights[1] ** 2)

        res = cost([x, y])
        # ``jax.interpreters.xla.DeviceArray`` no longer exists in recent JAX
        # releases; ``jnp.ndarray`` is the stable public type for isinstance checks.
        assert isinstance(res, jnp.ndarray)
        assert fn.allclose(res, onp.array([[1.0, 1.0, 1.0], [1.0, 1.0, 1.3136]]))

        grad = jax.grad(lambda weights: cost(weights)[1, 2])([x, y])
        assert fn.allclose(grad[0], onp.array([[0, 0, 0], [0, 0, 1.0]]))
        assert fn.allclose(grad[1], 2 * y)
Example #7
0
    def test_tensorflow(self):
        """Check that ``fn.scatter_element_add`` works with TensorFlow variables:
        the forward value is correct and gradients reach both inputs."""
        x = tf.Variable([[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]])
        y = tf.Variable(0.56)

        # The forward computation must be recorded on the tape to differentiate it.
        with tf.GradientTape() as tape:
            res = fn.scatter_element_add(x, [1, 2], y ** 2)
            loss = res[1, 2]

        assert isinstance(res, tf.Tensor)
        expected = onp.array([[1.0, 1.0, 1.0], [1.0, 1.0, 1.3136]])
        assert fn.allclose(res, expected)

        dx, dy = tape.gradient(loss, [x, y])
        assert fn.allclose(dx, onp.array([[0, 0, 0], [0, 0, 1.0]]))
        assert fn.allclose(dy, 2 * y)
Example #8
0
    def test_array(self):
        """Check that ``fn.scatter_element_add`` on trainable NumPy arrays gives
        the expected values and is differentiable with ``qml.grad``."""
        x = np.array([[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]], requires_grad=True)
        y = np.array(0.56, requires_grad=True)

        def cost(weights):
            return fn.scatter_element_add(weights[0], [1, 2], weights[1] ** 2)

        def element_cost(weights):
            # Scalar objective: only the scattered element.
            return cost(weights)[1, 2]

        res = cost([x, y])
        assert isinstance(res, np.ndarray)
        expected = onp.array([[1.0, 1.0, 1.0], [1.0, 1.0, 1.3136]])
        assert fn.allclose(res, expected)

        grads = qml.grad(element_cost)([x, y])
        assert fn.allclose(grads[0], onp.array([[0, 0, 0], [0, 0, 1.0]]))
        assert fn.allclose(grads[1], 2 * y)
Example #9
0
def compute_vjp(dy, jac):
    """Convenience function to compute the vector-Jacobian product for a given
    vector of gradient outputs and a Jacobian.

    Args:
        dy (tensor_like): vector of gradient outputs
        jac (tensor_like): Jacobian matrix. For an n-dimensional ``dy``
            vector, the first n-dimensions of ``jac`` should match
            the shape of ``dy``.

    Returns:
        tensor_like: the vector-Jacobian product, or ``None`` if ``jac`` is
        ``None``. If ``dy`` is all zeros, a zero vector is returned. It has one
        entry per column of the reshaped 2D Jacobian, is converted to ``dy``'s
        interface, and is cast to ``dy``'s dtype.
    """
    if jac is None:
        return None

    # Flatten dy so the contraction below is always a single-axis tensordot.
    dy_row = math.reshape(dy, [-1])

    # Bring the Jacobian into dy's interface (skipped when dy is an np.ndarray).
    if not isinstance(dy_row, np.ndarray):
        jac = math.convert_like(jac, dy_row)

    # Collapse jac to 2D: each row lines up with one entry of the flattened dy.
    jac = math.reshape(jac, [dy_row.shape[0], -1])

    try:
        if math.allclose(dy, 0):
            # If the dy vector is zero, then the
            # corresponding element of the VJP will be zero.
            num_params = jac.shape[1]
            res = math.convert_like(np.zeros([num_params]), dy)
            # np.zeros is float64 regardless of dy; cast so the shortcut result
            # matches dy's dtype instead of silently upcasting.
            return math.cast_like(res, dy)
    except (AttributeError, TypeError):
        # The zero check is only a best-effort shortcut. If it cannot be
        # evaluated for this input, fall through to the general contraction.
        pass

    return math.tensordot(jac, dy_row, [[0], [0]])
Example #10
0
def test_block_diag(tensors):
    """Check that ``fn.block_diag`` places each input block along the diagonal
    of a larger matrix, with zeros everywhere else."""
    expected = np.array(
        [
            [1, 2, 0, 0, 0],
            [3, 4, 0, 0, 0],
            [0, 0, 1, 2, 0],
            [0, 0, -1, -6, 0],
            [0, 0, 0, 0, 5],
        ]
    )
    assert fn.allclose(fn.block_diag(tensors), expected)
Example #11
0
 def test_sequence(self, a, interface):
     """Check that a plain Python sequence mixing floats and ``a`` is converted by
     ``fn.diag`` into a diagonal tensor of the expected ``interface``.

     ``a`` is presumably a tensor with value 0.5 supplied by the parametrization.
     """
     res = fn.diag([0.1, 0.2, a])
     assert fn.get_interface(res) == interface
     assert fn.allclose(res, onp.diag([0.1, 0.2, 0.5]))
Example #12
0
    def test_extra_width(self):
        """Check the geometry and label of a gate box drawn with ``extra_width``:
        the box is widened to 1.2 and its label stays at the origin."""
        drawer = MPLDrawer(1, 1)
        drawer.box_gate(0, 0, text="Wide Gate", extra_width=0.4)

        rect = drawer.ax.patches[0]
        text = drawer.ax.texts[0]

        # Box geometry
        assert allclose(rect.get_xy(), (-0.6, -0.4))
        assert rect.get_height() == 0.8
        assert allclose(rect.get_width(), 1.2)

        # Label
        assert text.get_text() == "Wide Gate"
        assert text.get_position() == (0, 0)
        plt.close()
Example #13
0
    def test_multidimensional_indexing_along_axis_autograd(self):
        """Check that ``fn.take`` with a 2D index array gathers along axis 1 and
        is differentiable with ``qml.grad``."""
        t = np.array([[[1, 2], [3, 4], [-1, 1]], [[5, 6], [0, -1], [2, 1]]])
        indices = np.array([[0, 0], [1, 0]])

        def cost_fn(t):
            return fn.sum(fn.take(t, indices, axis=1))

        gathered = np.array(
            [[[[1, 2], [1, 2]], [[3, 4], [1, 2]]], [[[5, 6], [5, 6]], [[0, -1], [5, 6]]]]
        )
        assert fn.allclose(cost_fn(t), np.sum(gathered))

        # Along axis 1, index 0 appears three times in ``indices``, index 1 once
        # and index 2 never; the gradient counts those occurrences.
        expected_grad = np.array([[[3, 3], [1, 1], [0, 0]], [[3, 3], [1, 1], [0, 0]]])
        assert fn.allclose(qml.grad(cost_fn)(t), expected_grad)
Example #14
0
 def test_array_indexing_along_axis(self, t):
     """Check that ``fn.take`` with a list of (possibly negative) indices
     gathers the requested entries along ``axis=2``."""
     res = fn.take(t, [0, 1, -2], axis=2)
     expected = np.array(
         [[[1, 2, 1], [3, 4, 3], [-1, 1, -1]], [[5, 6, 5], [0, -1, 0], [2, 1, 2]]]
     )
     assert fn.allclose(res, expected)
Example #15
0
 def test_multidimensional_indexing_along_axis(self, t):
     """Check that ``fn.take`` with a 2D index array along ``axis=1`` inserts
     the index array's shape in place of that axis."""
     res = fn.take(t, np.array([[0, 0], [1, 0]]), axis=1)
     expected = np.array(
         [[[[1, 2], [1, 2]], [[3, 4], [1, 2]]], [[[5, 6], [5, 6]], [[0, -1], [5, 6]]]]
     )
     assert fn.allclose(res, expected)
Example #16
0
def _su2su2_to_tensor_products(U):
    r"""Given a matrix :math:`U = A \otimes B` in SU(2) x SU(2), extract the two SU(2)
    operations A and B.

    This process has been described in detail in the Appendix of Coffey & Deiotte
    https://link.springer.com/article/10.1007/s11128-009-0156-3

    Args:
        U (tensor_like): a 4x4 matrix, assumed to be the tensor product of two
            SU(2) matrices.

    Returns:
        tuple[tensor_like, tensor_like]: the 2x2 factors ``(A, B)``, each
        converted to the same interface as ``U``.
    """

    # First, write A = [[a1, a2], [-a2*, a1*]], which we can do for any SU(2) element.
    # Then, A \otimes B = [[a1 B, a2 B], [-a2*B, a1*B]] = [[C1, C2], [C3, C4]]
    # where the Ci are 2x2 matrices.
    C1 = U[0:2, 0:2]
    C2 = U[0:2, 2:4]
    C3 = U[2:4, 0:2]
    C4 = U[2:4, 2:4]

    # From the definition of A \otimes B, C1 C4^\dag = a1^2 I, so we can extract a1
    C14 = math.dot(C1, math.conj(math.T(C4)))
    a1 = math.sqrt(math.cast_like(C14[0, 0], 1j))

    # Similarly, -C2 C3^\dag = a2^2 I, so we can extract a2
    C23 = math.dot(C2, math.conj(math.T(C3)))
    a2 = math.sqrt(-math.cast_like(C23[0, 0], 1j))

    # This gets us a1, a2 up to a sign. To resolve the sign, ensure that
    # C1 C2^dag = a1 a2* I
    C12 = math.dot(C1, math.conj(math.T(C2)))

    if not math.allclose(a1 * math.conj(a2), C12[0, 0]):
        # Rebind rather than negate in place (``a2 *= -1``). With PyTorch, ``a2``
        # is the output of ``sqrt``, whose value autograd saves for the backward
        # pass, so an in-place update can make backpropagation fail.
        a2 = -a2

    # Construct A
    A = math.stack(
        [math.stack([a1, a2]),
         math.stack([-math.conj(a2), math.conj(a1)])])

    # Next, extract B. Can do from any of the C, just need to be careful in
    # case one of the elements of A is 0.
    if not math.allclose(A[0, 0], 0.0, atol=1e-6):
        B = C1 / math.cast_like(A[0, 0], 1j)
    else:
        B = C2 / math.cast_like(A[0, 1], 1j)

    return math.convert_like(A, U), math.convert_like(B, U)
Example #17
0
def _yzy_to_zyz(middle_yzy):
    """Converts a set of angles representing a sequence of rotations RY, RZ, RY into
    an equivalent sequence of the form RZ, RY, RZ.

    Any rotation in 3-dimensional space (or, equivalently, any single-qubit unitary)
    can be expressed as a sequence of rotations about 3 axes in 12 different ways.
    Typically, the arbitrary single-qubit rotation is expressed as RZ(a) RY(b) RZ(c),
    but there are some situations, e.g., composing two such rotations, where we need
    to convert between representations. This function converts the angles of a sequence

    .. math::

       RY(y_1) RZ(z) RY(y_2)

    into the form

    .. math::

       RZ(z_1) RY(y) RZ(z_2)

    This is accomplished by first converting the rotation to quaternion form, and then
    extracting the desired set of angles.

    Args:
        middle_yzy (tensor_like or Sequence[float]): the three angles
            ``(y1, z, y2)`` of ``RY(y1) RZ(z) RY(y2)``, in that order.

    Returns:
        tensor_like: the stacked angles ``(z1, y, z2)`` of the equivalent
        ``RZ(z1) RY(y) RZ(z2)`` sequence. If all input angles are close to zero,
        ``stack([0.0, 0.0, 0.0])`` is returned directly.
    """
    stacked = stack(middle_yzy)
    # Identity rotation: skip the quaternion conversion entirely.
    if allclose(stacked, cast_like(zeros(3), stacked)):
        return stack([0.0, 0.0, 0.0])

    y1, z, y2 = middle_yzy[0], middle_yzy[1], middle_yzy[2]

    # First, compute the quaternion representation
    # https://ntrs.nasa.gov/api/citations/19770024290/downloads/19770024290.pdf
    qw = cos(z / 2) * cos(0.5 * (y1 + y2))
    qx = sin(z / 2) * sin(0.5 * (y1 - y2))
    qy = cos(z / 2) * sin(0.5 * (y1 + y2))
    qz = sin(z / 2) * cos(0.5 * (y1 - y2))

    # Now convert from YZY Euler angles to ZYZ angles
    # Source: http://bediyap.com/programming/convert-quaternion-to-euler-rotations/
    z1_arg1 = 2 * (qy * qz - qw * qx)
    z1_arg2 = 2 * (qx * qz + qw * qy)
    z1 = arctan2(z1_arg1, z1_arg2)

    y = arccos(qw ** 2 - qx ** 2 - qy ** 2 + qz ** 2)

    z2_arg1 = 2 * (qy * qz + qw * qx)
    z2_arg2 = -2 * (qx * qz - qw * qy)
    z2 = arctan2(z2_arg1, z2_arg2)

    return stack([z1, y, z2])
Example #18
0
def fuse_rot_angles(angles_1, angles_2):
    """Compute the set of rotation angles that is obtained when composing
    two ``qml.Rot`` operations.

    The ``qml.Rot`` operation represents the most general single-qubit operation.
    Two such operations can be fused into a new operation, however the angular dependence
    is non-trivial.

    Args:
        angles_1 (tensor_like or Sequence[float]): A set of three angles for the
            first ``qml.Rot`` operation.
        angles_2 (tensor_like or Sequence[float]): A set of three angles for the
            second ``qml.Rot`` operation.

    Returns:
        tensor_like: the three rotation angles, combined with ``stack``, for a single
        ``qml.Rot`` operation that implements the same operation as the two sets
        of input angles. Every branch returns the same kind of object.
    """
    # Check for all-zero instances; if there are some, we can just return the sum.
    are_angles_1_zero = allclose(angles_1, zeros(3))
    are_angles_2_zero = allclose(angles_2, zeros(3))

    if are_angles_1_zero or are_angles_2_zero:
        return stack([angles_1[i] + angles_2[i] for i in range(3)])

    # RZ(a) RY(b) RZ(c) fused with RZ(d) RY(e) RZ(f)
    # first produces RZ(a) RY(b) RZ(c+d) RY(e) RZ(f)
    left_z = angles_1[0]
    middle_yzy = angles_1[1], angles_1[2] + angles_2[0], angles_2[1]
    right_z = angles_2[2]

    # There are a few other cases to consider where things can be 0 and
    # avoid having to use the quaternion conversion routine
    # If b = 0, then we have RZ(a + c + d) RY(e) RZ(f)
    if allclose(middle_yzy[0], 0.0):
        # Then if e is close to zero, return a single rotation RZ(a + c + d + f).
        # Stacked (not a plain list) so the return type matches the other branches.
        if allclose(middle_yzy[2], 0.0):
            return stack([left_z + middle_yzy[1] + right_z, 0.0, 0.0])
        return stack([left_z + middle_yzy[1], middle_yzy[2], right_z])

    # If c + d is close to 0, then we have the case RZ(a) RY(b + e) RZ(f)
    if allclose(middle_yzy[1], 0.0):
        # If b + e is 0, we have RZ(a + f)
        if allclose(middle_yzy[0] + middle_yzy[2], 0.0):
            return stack([left_z + right_z, 0.0, 0.0])
        return stack([left_z, middle_yzy[0] + middle_yzy[2], right_z])

    # If e is close to 0, then we have the case RZ(a) RY(b) RZ(c + d + f)
    # The case where b is also 0 is already covered by the first check above,
    # so only one case here.
    if allclose(middle_yzy[2], 0.0):
        return stack([left_z, middle_yzy[0], middle_yzy[1] + right_z])

    # Otherwise, we need to turn the RY(b) RZ(c+d) RY(e) into something
    # of the form RZ(u) RY(v) RZ(w)
    u, v, w = _yzy_to_zyz(middle_yzy)

    # Then we can combine to create
    # RZ(a + u) RY(v) RZ(w + f)
    return stack([left_z + u, v, w + right_z])
Example #19
0
    def test_sum_axis_keepdims(self, t1):
        """Summing over axes 0 and 2 with ``keepdims=True`` must keep those axes
        as singleton dimensions in the result."""
        res = fn.sum(t1, axis=(0, 2), keepdims=True)

        # Tensor types that expose ``.numpy()`` (e.g. TensorFlow/PyTorch) are
        # compared through their NumPy view.
        res = res.numpy() if hasattr(res, "numpy") else res

        assert fn.allclose(res, np.array([[[14], [6], [3]]]))
        assert res.shape == (1, 3, 1)
Example #20
0
    def test_sum_axis(self, t1):
        """Summing over axes 0 and 2 must remove those axes, leaving a 1D
        result of length 3."""
        res = fn.sum(t1, axis=(0, 2))

        # Tensor types that expose ``.numpy()`` are compared through their NumPy view.
        res = res.numpy() if hasattr(res, "numpy") else res

        assert fn.allclose(res, np.array([14, 6, 3]))
        assert res.shape == (3,)
Example #21
0
    def test_concatenate_flattened_arrays(self, t1):
        """With ``axis=None``, ``fn.concatenate`` flattens every input before joining."""
        res = fn.concatenate([t1, onp.array([5])], axis=None)

        # Tensor types that expose ``.numpy()`` are compared through their NumPy view.
        res = res.numpy() if hasattr(res, "numpy") else res

        assert fn.allclose(res, np.array([1, 2, 5]))
        assert list(res.shape) == [3]
Example #22
0
def test_allclose(t1, t2):
    """Compare ``fn.allclose`` against an element-wise exact-equality reference."""
    res = fn.allclose(t1, t2)

    # tf.Variable inputs are turned into plain tensors before iterating below.
    t1, t2 = (
        tf.convert_to_tensor(t) if isinstance(t, tf.Variable) else t for t in (t1, t2)
    )

    expected = all(float(a) == float(b) for a, b in zip(t1, t2))
    assert res == expected
Example #23
0
    def test_stack_axis(self, t1):
        """Stacking with ``axis=1`` places the inputs side by side as columns."""
        res = fn.stack([t1, onp.array([3, 4])], axis=1)

        # Tensor types that expose ``.numpy()`` are compared through their NumPy view.
        res = res.numpy() if hasattr(res, "numpy") else res

        assert fn.allclose(res, np.array([[1, 3], [2, 4]]))
        assert list(res.shape) == [2, 2]
Example #24
0
def _compute_num_cnots(U):
    r"""Compute the number of CNOTs required to implement a U in SU(4). This is based on
    the trace of

    .. math::

        \gamma(U) = (E^\dag U E) (E^\dag U E)^T,

    and follows the arguments of this paper: https://arxiv.org/abs/quant-ph/0308045.

    Args:
        U (tensor_like): a 4x4 matrix, presumed to be in SU(4).

    Returns:
        int: 0, 1, 2 or 3, the number of CNOTs needed.
    """
    u = math.dot(Edag, math.dot(U, E))
    gamma = math.dot(u, math.T(u))
    tr = math.trace(gamma)

    # 0 CNOTs (tensor product): the trace is +/- 4. A tolerance of about 1e-7 is
    # needed so that U given to 8 decimal places is still recognised.
    if any(math.allclose(tr, target, atol=1e-7) for target in (4, -4)):
        return 0

    # Separating the 1- and 2-CNOT cases requires the eigenvalues of gamma.
    imag_evs = math.sort(math.imag(math.linalg.eigvals(gamma)))

    # 1 CNOT: the trace is 0 AND the eigenvalues are [-1j, -1j, 1j, 1j]. The
    # eigenvalue test excludes special 2-CNOT cases that also have trace 0.
    is_one_cnot = math.allclose(tr, 0j, atol=1e-7) and math.allclose(
        imag_evs, [-1, -1, 1, 1]
    )
    if is_one_cnot:
        return 1

    # 2 CNOTs when the trace is purely real (or 0). Otherwise the trace has both
    # real and imaginary parts, which needs 3 CNOTs.
    return 2 if math.allclose(math.imag(tr), 0.0, atol=1e-7) else 3
Example #25
0
def _convert_to_su2(U):
    r"""Check unitarity of a matrix and convert it to :math:`SU(2)` if possible.

    Args:
        U (array[complex]): A matrix, presumed to be :math:`2 \times 2` and unitary.

    Returns:
        array[complex]: A :math:`2 \times 2` matrix in :math:`SU(2)` that is
        equivalent to U up to a global phase.

    Raises:
        ValueError: if ``U U^\dagger`` is not close to the identity.
    """
    # Unitarity check: U U^dagger must be the 2x2 identity.
    gram = math.dot(U, math.T(math.conj(U)))
    if not math.allclose(gram, math.eye(2), atol=1e-7):
        raise ValueError("Operator must be unitary.")

    # Closed-form 2x2 determinant.
    det = U[0, 0] * U[1, 1] - U[0, 1] * U[1, 0]

    # Determinant already ~1: U is in SU(2) as given.
    if math.allclose(det, [1.0]):
        return U

    # Strip the global phase. Scaling by exp(-i arg(det) / 2) multiplies the
    # 2x2 determinant by exp(-i arg(det)), leaving |det|.
    exp_angle = -1j * math.cast_like(math.angle(det), 1j) / 2
    return math.cast_like(U, exp_angle) * math.exp(exp_angle)
Example #26
0
def _convert_to_su4(U):
    r"""Check unitarity of a 4x4 matrix and convert it to :math:`SU(4)` if the determinant is not 1.

    Args:
        U (array[complex]): A matrix, presumed to be :math:`4 \times 4` and unitary.

    Returns:
        array[complex]: A :math:`4 \times 4` matrix in :math:`SU(4)` that is
        equivalent to U up to a global phase.

    Raises:
        ValueError: if ``U U^\dagger`` is not close to the identity.
    """
    # Unitarity check: U U^dagger must be the 4x4 identity.
    gram = math.dot(U, math.T(math.conj(U)))
    if not math.allclose(gram, math.eye(4), atol=1e-7):
        raise ValueError("Operator must be unitary.")

    det = math.linalg.det(U)

    # Determinant already ~1: U is in SU(4) as given.
    if math.allclose(det, 1.0):
        return U

    # Strip the global phase. Scaling by exp(-i arg(det) / 4) multiplies the
    # 4x4 determinant by exp(-i arg(det)), leaving |det|.
    exp_angle = -1j * math.cast_like(math.angle(det), 1j) / 4
    return math.cast_like(U, det) * math.exp(exp_angle)
Example #27
0
    def test_ones_like_explicit_dtype(self, t):
        """Test that the ones like function creates the correct
        shape and type tensor when an explicit ``dtype`` is requested."""
        res = fn.ones_like(t, dtype=np.float16)

        # Python sequences have no ``.shape``; compare against their array form.
        if isinstance(t, (list, tuple)):
            t = onp.asarray(t)

        assert res.shape == t.shape
        assert fn.get_interface(res) == fn.get_interface(t)
        assert fn.allclose(res, np.ones(t.shape))

        # If tensorflow or pytorch, extract a view of the underlying data.
        # Only ``res`` needs converting: ``t`` is not used after this point, and
        # ``.numpy()`` raises on torch tensors that require grad.
        if hasattr(res, "numpy"):
            res = res.numpy()

        assert onp.asarray(res).dtype.type is np.float16
Example #28
0
    def test_measure(self):
        """Check the box, arc and arrow patches produced by ``measure``."""
        drawer = MPLDrawer(1, 1)
        drawer.measure(0, 0)
        patches = drawer.ax.patches

        box = patches[0]
        assert box.get_xy() == (-0.4, -0.4)
        assert box.get_width() == 0.8
        assert box.get_height() == 0.8

        arc = patches[1]
        assert arc.center == (0, 0.05)
        assert arc.theta1 == 180
        assert arc.theta2 == 0
        assert allclose(arc.height, 0.44)
        assert arc.width == 0.48

        assert isinstance(patches[2], FancyArrow)

        plt.close()
Example #29
0
    def test_measure(self):
        """Check that the measure box, arc and arrow are sized relative to the
        drawer's ``_box_length`` and ``_pad``."""
        drawer = MPLDrawer(1, 1)
        drawer.measure(0, 0)

        length = drawer._box_length
        pad = drawer._pad
        corner = -length / 2.0 + pad

        box = drawer.ax.patches[0]
        assert box.get_x() == corner
        assert box.get_y() == corner
        assert box.get_width() == length - 2 * pad

        arc = drawer.ax.patches[1]
        assert arc.center == (0, length / 16)
        assert arc.theta1 == 180
        assert arc.theta2 == 0
        assert allclose(arc.height, 0.55 * length)
        assert arc.width == 0.6 * length

        assert isinstance(drawer.ax.patches[2], FancyArrow)

        plt.close()
Example #30
0
def _decomposition_2_cnots(U, wires):
    r"""Decompose a two-qubit unitary that needs two CNOTs.

    If 2 CNOTs are required, we can write the circuit as
     -╭U- = -A--╭X--RZ(d)--╭X--C-
     -╰U- = -B--╰C--RX(p)--╰C--D-
    We need to find the angles for the Z and X rotations such that the inner
    part has the same spectrum as U, and then we can recover A, B, C, D.

    The diagram is written in matrix-product order, U = (A ⊗ B) V (C ⊗ D).
    As a circuit, C and D are applied first, which is the order of the
    returned operations.

    Args:
        U (tensor_like): the 4x4 matrix to decompose. NOTE(review): presumably
            already in SU(4) and known to need exactly 2 CNOTs; confirm with caller.
        wires (Sequence): the two wires to act on. ``wires[0]`` carries A and C;
            ``wires[1]`` carries B and D.

    Returns:
        list: the operations in application order: C and D, then the interior
        CNOT / rotation / CNOT block, then A and B.
    """
    # Compute the rotation angles
    u = math.dot(Edag, math.dot(U, E))
    gammaU = math.dot(u, math.T(u))
    evs, _ = math.linalg.eig(gammaU)

    # These choices are based on Proposition III.3 of
    # https://arxiv.org/abs/quant-ph/0308045
    # There is, however, a special case where the circuit has the form
    # -╭U- = -A--╭C--╭X--C-
    # -╰U- = -B--╰X--╰C--D-
    #
    # or some variant of this, where the two CNOTs are adjacent.
    #
    # What happens here is that the set of evs is -1, -1, 1, 1 and we can write
    # -╭U- = -A--╭X--SZ--╭X--C-
    # -╰U- = -B--╰C--SX--╰C--D-
    # where SZ and SX are square roots of Z and X respectively. (This
    # decomposition comes from using Hadamards to flip the direction of the
    # first CNOT, and then decomposing them and merging single-qubit gates.) For
    # some reason this case is not handled properly with the full algorithm, so
    # we treat it separately.

    # Sort the real parts so the special-case test is independent of eigenvalue order.
    sorted_evs = math.sort(math.real(evs))

    if math.allclose(sorted_evs, [-1, -1, 1, 1]):
        # Special case: the square-root interior (S on wires[0], SX on wires[1])
        # between two CNOTs with control wires[1] and target wires[0].
        interior_decomp = [
            qml.CNOT(wires=[wires[1], wires[0]]),
            qml.S(wires=wires[0]),
            qml.SX(wires=wires[1]),
            qml.CNOT(wires=[wires[1], wires[0]]),
        ]

        # S \otimes SX
        inner_matrix = S_SX
    else:
        # For the non-special case, the eigenvalues come in conjugate pairs.
        # We need to find two non-conjugate eigenvalues to extract the angles.
        x = math.angle(evs[0])
        y = math.angle(evs[1])

        # If it was the conjugate, grab a different eigenvalue.
        if math.allclose(x, -y):
            y = math.angle(evs[2])

        # Half-sum and half-difference of the two phases give the RZ and RX angles.
        delta = (x + y) / 2
        phi = (x - y) / 2

        interior_decomp = [
            qml.CNOT(wires=[wires[1], wires[0]]),
            qml.RZ(delta, wires=wires[0]),
            qml.RX(phi, wires=wires[1]),
            qml.CNOT(wires=[wires[1], wires[0]]),
        ]

        # Matrix of the interior rotations, RZ(delta) ⊗ RX(phi). Casting delta
        # to complex is presumably needed so the RZ matrix is built with a
        # complex dtype under every interface; confirm.
        # NOTE(review): ``.matrix`` is read as an attribute; newer PennyLane
        # releases use ``qml.matrix(op)`` instead, so verify against the
        # installed version.
        RZd = qml.RZ(math.cast_like(delta, 1j), wires=0).matrix
        RXp = qml.RX(phi, wires=0).matrix
        inner_matrix = math.kron(RZd, RXp)

    # We need the matrix representation of this interior part, V, in order to
    # decompose U = (A \otimes B) V (C \otimes D)
    V = math.dot(math.cast_like(CNOT10, U),
                 math.dot(inner_matrix, math.cast_like(CNOT10, U)))

    # Now we find the A, B, C, D in SU(2), and return the decomposition
    A, B, C, D = _extract_su2su2_prefactors(U, V)

    A_ops = zyz_decomposition(A, wires[0])
    B_ops = zyz_decomposition(B, wires[1])
    C_ops = zyz_decomposition(C, wires[0])
    D_ops = zyz_decomposition(D, wires[1])

    # Circuit order: (C ⊗ D) first, then V, then (A ⊗ B).
    return C_ops + D_ops + interior_decomp + A_ops + B_ops