def test_kron_bases_repeat_sanity_checks(basis, repeat):
    """Sanity-check ``cirq.kron_bases(basis, repeat=repeat)``.

    Verifies that the product basis has one element per combination of
    factor names, that distinct elements are orthogonal under the
    Hilbert-Schmidt inner product, and that each element's inner product
    with itself has magnitude at least 1.

    Args:
        basis: Operator basis given as a name -> matrix mapping
            (e.g. PAULI_BASIS / STANDARD_BASIS used elsewhere in this file).
        repeat: Number of tensor factors of ``basis`` to combine.
    """
    product_basis = cirq.kron_bases(basis, repeat=repeat)
    # Derive the expected size from the input basis instead of hard-coding 4,
    # which is only correct for four-element bases.
    assert len(product_basis) == len(basis) ** repeat
    for name1, matrix1 in product_basis.items():
        for name2, matrix2 in product_basis.items():
            p = cirq.hilbert_schmidt_inner_product(matrix1, matrix2)
            if name1 != name2:
                # Distinct basis elements must be orthogonal.
                assert p == 0
            else:
                # An element paired with itself must have non-trivial norm.
                assert abs(p) >= 1
def test_kron_bases(basis1, basis2, expected_kron_basis):
    """Compare ``cirq.kron_bases(basis1, basis2)`` with an explicit expectation.

    Args:
        basis1: First operator basis (name -> matrix mapping).
        basis2: Second operator basis (name -> matrix mapping).
        expected_kron_basis: Expected name -> matrix mapping of the
            Kronecker-product basis.
    """
    # NOTE(review): this chunk shows a second `def test_kron_bases` with the
    # same signature further down; if both live in one module, the later
    # definition shadows this one and only it is collected — confirm and
    # remove the duplicate.
    kron_basis = cirq.kron_bases(basis1, basis2)
    # One product element per (basis1 element, basis2 element) pair, derived
    # from the inputs instead of hard-coding 16.
    assert len(kron_basis) == len(basis1) * len(basis2)
    assert set(kron_basis.keys()) == set(expected_kron_basis.keys())
    for name, matrix in kron_basis.items():
        # np.array_equal also requires equal shapes; `np.all(a == b)` would
        # broadcast and could accept a matrix of the wrong shape.
        assert np.array_equal(matrix, expected_kron_basis[name])
for row_outer in range(2) for row_inner in range(2) for col_outer in range(2) for col_inner in range(2) }), )) def test_kron_bases(basis1, basis2, expected_kron_basis): kron_basis = cirq.kron_bases(basis1, basis2) assert len(kron_basis) == 16 assert set(kron_basis.keys()) == set(expected_kron_basis.keys()) for name in kron_basis.keys(): assert np.all(kron_basis[name] == expected_kron_basis[name]) @pytest.mark.parametrize('basis1,basis2', ( (PAULI_BASIS, cirq.kron_bases(PAULI_BASIS)), (STANDARD_BASIS, cirq.kron_bases(STANDARD_BASIS, repeat=1)), (cirq.kron_bases(PAULI_BASIS, PAULI_BASIS), cirq.kron_bases(PAULI_BASIS, repeat=2)), (cirq.kron_bases( cirq.kron_bases(PAULI_BASIS, repeat=2), cirq.kron_bases(PAULI_BASIS, repeat=3), PAULI_BASIS), cirq.kron_bases(PAULI_BASIS, repeat=6)), (cirq.kron_bases( cirq.kron_bases(PAULI_BASIS, STANDARD_BASIS), cirq.kron_bases(PAULI_BASIS, STANDARD_BASIS)), cirq.kron_bases(PAULI_BASIS, STANDARD_BASIS, repeat=2)), )) def test_kron_bases_consistency(basis1, basis2): assert set(basis1.keys()) == set(basis2.keys())