def test_apply_unitaries():
    """Check `cirq.apply_unitaries` across qubit orders, argument styles and fallbacks.

    Covers the H / CNOT / controlled-H sequence with two qubit orderings and
    with explicitly supplied args, empty input, non-unitary input with and
    without a `default`, and args whose size disagrees with `qubits`.
    """
    a, b, c = cirq.LineQubit.range(3)

    def entangling_ops():
        # Rebuilt on every call so each case receives its own operation list.
        return [cirq.H(a), cirq.CNOT(a, b), cirq.H(c).controlled_by(b)]

    def noisy_ops():
        # A depolarizing channel, used as the non-unitary input.
        return [cirq.depolarize(0.5).on(a)]

    amplitudes_abc = [np.sqrt(0.5), 0, 0, 0, 0, 0, 0.5, 0.5]
    amplitudes_acb = [np.sqrt(0.5), 0, 0, 0, 0, 0.5, 0, 0.5]

    # Qubits listed as (a, b, c).
    state = cirq.apply_unitaries(unitary_values=entangling_ops(), qubits=[a, b, c])
    np.testing.assert_allclose(state.reshape(8), amplitudes_abc, atol=1e-8)

    # Different order: qubits listed as (a, c, b).
    state = cirq.apply_unitaries(unitary_values=entangling_ops(), qubits=[a, c, b])
    np.testing.assert_allclose(state.reshape(8), amplitudes_acb, atol=1e-8)

    # Explicit arguments.
    state = cirq.apply_unitaries(
        unitary_values=entangling_ops(),
        qubits=[a, b, c],
        args=cirq.ApplyUnitaryArgs.default(num_qubits=3),
    )
    np.testing.assert_allclose(state.reshape(8), amplitudes_abc, atol=1e-8)

    # Empty input, without and with an explicit default.
    state = cirq.apply_unitaries(unitary_values=[], qubits=[])
    np.testing.assert_allclose(state, [1])
    state = cirq.apply_unitaries(unitary_values=[], qubits=[], default=None)
    np.testing.assert_allclose(state, [1])

    # Non-unitary operation: raises without a default, otherwise returns the default.
    with pytest.raises(TypeError, match='non-unitary'):
        cirq.apply_unitaries(unitary_values=noisy_ops(), qubits=[a])
    fallback_none = cirq.apply_unitaries(unitary_values=noisy_ops(), qubits=[a], default=None)
    assert fallback_none is None
    fallback_one = cirq.apply_unitaries(unitary_values=noisy_ops(), qubits=[a], default=1)
    assert fallback_one == 1

    # Inconsistent arguments: one-qubit args paired with an empty qubit list.
    with pytest.raises(ValueError, match='len'):
        cirq.apply_unitaries(
            unitary_values=[], qubits=[], args=cirq.ApplyUnitaryArgs.default(1)
        )
def test_apply_unitaries_mixed_qid_shapes():
    """Check `cirq.apply_unitaries` on qids of dimension 3 and 4.

    The operations mix gates acting on the full dimension of each qid with
    qubit gates applied through `with_dimension(2)`. Three cases are checked:
    default args, identity-tensor args yielding the 12x12 identity, and
    identity-tensor args yielding a specific 12x12 permutation matrix.
    """

    class PlusOneMod3Gate(cirq.SingleQubitGate):
        def _qid_shape_(self):
            return (3,)

        def _unitary_(self):
            return np.array([
                [0, 0, 1],
                [1, 0, 0],
                [0, 1, 0],
            ])  # yapf: disable

    class PlusOneMod4Gate(cirq.SingleQubitGate):
        def _qid_shape_(self):
            return (4,)

        def _unitary_(self):
            return np.array([
                [0, 0, 0, 1],
                [1, 0, 0, 0],
                [0, 1, 0, 0],
                [0, 0, 1, 0],
            ])  # yapf: disable

    a, b = cirq.LineQid.for_qid_shape((3, 4))

    # Small factories; each call builds a fresh operation.
    def plus_one_a():
        return PlusOneMod3Gate().on(a.with_dimension(3))

    def plus_one_b():
        return PlusOneMod4Gate().on(b.with_dimension(4))

    def flip(qid):
        return cirq.X(qid.with_dimension(2))

    def cnot_ab():
        return cirq.CNOT(a.with_dimension(2), b.with_dimension(2))

    def cancelling_ops():
        return [
            plus_one_a(),
            flip(a),
            cnot_ab(),
            cnot_ab(),
            flip(a),
            plus_one_a(),
            plus_one_a(),
        ]

    def identity_args():
        # Target and buffer are separate identity tensors over the (3, 4) shape.
        return cirq.ApplyUnitaryArgs(
            target_tensor=cirq.eye_tensor((3, 4), dtype=np.complex64),
            available_buffer=cirq.eye_tensor((3, 4), dtype=np.complex64),
            axes=(0, 1),
        )

    # Default args: the result flattens to the first basis vector.
    outcome = cirq.apply_unitaries(unitary_values=cancelling_ops(), qubits=[a, b])
    np.testing.assert_allclose(outcome.reshape(12), [1] + [0] * 11, atol=1e-8)

    # Identity tensors as input: the same sequence yields the 12x12 identity.
    outcome = cirq.apply_unitaries(
        unitary_values=cancelling_ops(), qubits=[a, b], args=identity_args()
    )
    np.testing.assert_allclose(outcome.reshape(12, 12), np.eye(12), atol=1e-8)

    # A longer sequence that also acts on b.
    permuting_ops = [
        plus_one_a(),
        flip(a),
        plus_one_b(),
        plus_one_b(),
        flip(b),
        plus_one_b(),
        plus_one_b(),
        cnot_ab(),
        plus_one_b(),
        flip(b),
        cnot_ab(),
        flip(a),
        plus_one_a(),
        plus_one_a(),
    ]
    outcome = cirq.apply_unitaries(
        unitary_values=permuting_ops, qubits=[a, b], args=identity_args()
    )
    # Row i of the expected matrix has its single 1 in column hot_columns[i].
    hot_columns = [0, 2, 1, 3, 4, 6, 5, 7, 10, 9, 8, 11]
    expected = np.eye(12, dtype=int)[hot_columns]
    np.testing.assert_allclose(outcome.reshape(12, 12), expected, atol=1e-8)