def test_binding_symbols_to_circuit_binds_them_to_wavefunction_operation(self):
        """Binding a symbol substitutes it in gates and wavefunction operations alike.

        ``beta`` occurs both inside a ``MultiPhaseOperation`` and as an ``RX``
        angle; after binding it to 0.3 both occurrences hold the number and
        ``alpha`` is the only symbol left free.
        """
        alpha, beta = sympy.symbols("alpha, beta")
        unbound_circuit = Circuit(
            [RX(alpha)(0), MultiPhaseOperation((beta, 0.5)), RX(beta)(1)]
        )
        circuit = unbound_circuit.bind({beta: 0.3})

        assert circuit.free_symbols == [alpha]

        expected_operations = [
            RX(alpha)(0),
            MultiPhaseOperation((0.3, 0.5)),
            RX(0.3)(1),
        ]
        assert circuit.operations == expected_operations

    def test_binding_some_params_leaves_free_params(self):
        """Binding only part of the parameters leaves the others free.

        ``theta2`` is the only symbol missing from the binding map, so it is
        the only one expected in ``free_symbols`` afterwards.
        """
        theta1, theta2, theta3 = sympy.symbols("theta1:4")
        operations = [
            RX(theta1)(0),
            RY(theta2)(1),
            RZ(theta3)(0),
            RX(theta2)(0),
        ]
        circuit = Circuit(operations)

        partial_binding = {theta1: 0.5, theta3: 3.14}
        assert circuit.bind(partial_binding).free_symbols == [theta2]

    def test_binding_all_params_leaves_no_free_symbols(self):
        """Binding every parameter yields a circuit without free symbols."""
        alpha, beta, gamma = sympy.symbols("alpha,beta,gamma")
        full_binding = {alpha: 0.5, beta: 3.14, gamma: 0}
        circuit = Circuit(
            [RX(alpha)(0), RY(beta)(1), RZ(gamma)(0), RX(gamma)(0)]
        )

        remaining_symbols = circuit.bind(full_binding).free_symbols
        assert not remaining_symbols

    def test_binding_excessive_params_binds_only_the_existing_ones(self):
        """Symbols that do not occur in the circuit are ignored when binding.

        ``other_param`` is not used by any operation; binding it together with
        ``theta1`` leaves ``theta2`` and ``theta3`` as the free symbols.
        """
        theta1, theta2, theta3 = sympy.symbols("theta1:4")
        other_param = sympy.symbols("other_param")
        circuit = Circuit(
            [RX(theta1)(0), RY(theta2)(1), RZ(theta3)(0), RX(theta2)(0)]
        )

        symbols_map = {theta1: -np.pi, other_param: 42}
        bound_circuit = circuit.bind(symbols_map)

        assert bound_circuit.free_symbols == [theta2, theta3]

    def test_symbols_of_wavefunction_operations_are_present_in_circuits_free_symbols(
        self,
    ):
        """Symbols used by wavefunction operations count as circuit free symbols."""
        alpha, beta = sympy.symbols("alpha, beta")
        operations = [RX(alpha)(0), MultiPhaseOperation((beta, 0.5))]

        assert Circuit(operations).free_symbols == [alpha, beta]
def test_splitting_circuits_partitions_it_into_expected_chunks():
    """split_circuit yields ``(predicate_result, Circuit)`` pairs for consecutive runs.

    Adjacent operations with the same predicate result are expected to land in
    one chunk, and every chunk is expected to have ``n_qubits=4``.
    """

    def _predicate(operation):
        # True only for gate operations whose gate is named RX, RY or RZ.
        return isinstance(operation, GateOperation) and operation.gate.name in (
            "RX",
            "RY",
            "RZ",
        )

    circuit = Circuit(
        [RX(np.pi)(0), RZ(np.pi / 2)(1), CNOT(2, 3), RY(np.pi / 4)(2), X(0), Y(1)]
    )

    expected_partition = [
        (True, Circuit([RX(np.pi)(0), RZ(np.pi / 2)(1)], n_qubits=4)),
        (False, Circuit([CNOT(2, 3)], n_qubits=4)),
        (True, Circuit([RY(np.pi / 4)(2)], n_qubits=4)),
        (False, Circuit([X(0), Y(1)], n_qubits=4)),
    ]

    assert list(split_circuit(circuit, _predicate)) == expected_partition


def test_circuit_bound_with_all_params_contains_bound_gates():
    """Binding a whole circuit equals binding each of its gates individually.

    Defined at module level (no ``self``) so that pytest collects and runs it.
    """
    theta1, theta2, theta3 = sympy.symbols("theta1:4")
    symbols_map = {theta1: 0.5, theta2: 3.14, theta3: 0}

    circuit = Circuit(
        [
            RX(theta1)(0),
            RY(theta2)(1),
            RZ(theta3)(0),
            RX(theta3)(0),
        ]
    )
    bound_circuit = circuit.bind(symbols_map)

    expected_circuit = Circuit(
        [
            RX(theta1).bind(symbols_map)(0),
            RY(theta2).bind(symbols_map)(1),
            RZ(theta3).bind(symbols_map)(0),
            RX(theta3).bind(symbols_map)(0),
        ]
    )

    assert bound_circuit == expected_circuit


class TestGates:
    """Symbolic simulation of small parametrised circuits."""

    @pytest.fixture
    def simulator(self) -> SymbolicSimulator:
        """Provide a new SymbolicSimulator instance to each test."""
        symbolic_simulator = SymbolicSimulator()
        return symbolic_simulator

    @pytest.mark.parametrize(
        ("circuit", "expected_wavefunction"),
        [
            # RX(theta)|0> = cos(theta/2)|0> - i sin(theta/2)|1>
            (
                Circuit([RX(Symbol("theta"))(0)]),
                Wavefunction(
                    [
                        1.0 * cos(Symbol("theta") / 2),
                        -1j * sin(Symbol("theta") / 2),
                    ]
                ),
            ),
            # RY(theta)|1> = -sin(theta/2)|0> + cos(theta/2)|1>
            (
                Circuit([X(0), RY(Symbol("theta"))(0)]),
                Wavefunction(
                    [
                        -1.0 * sin(Symbol("theta") / 2),
                        1.0 * cos(Symbol("theta") / 2),
                    ]
                ),
            ),
            # U3(theta, phi, lambda) applied to H|0> = (|0> + |1>) / sqrt(2)
            (
                Circuit(
                    [
                        H(0),
                        U3(Symbol("theta"), Symbol("phi"), Symbol("lambda"))(0),
                    ]
                ),
                Wavefunction(
                    [
                        cos(Symbol("theta") / 2) / sqrt(2)
                        + -exp(I * Symbol("lambda"))
                        * sin(Symbol("theta") / 2)
                        / sqrt(2),
                        exp(I * Symbol("phi")) * sin(Symbol("theta") / 2) / sqrt(2)
                        + exp(I * (Symbol("lambda") + Symbol("phi")))
                        * cos(Symbol("theta") / 2)
                        / sqrt(2),
                    ]
                ),
            ),
        ],
    )
    def test_wavefunction_works_as_expected_with_symbolic_circuits(
        self,
        simulator: SymbolicSimulator,
        circuit: Circuit,
        expected_wavefunction: Wavefunction,
    ):
        """The simulator's symbolic amplitudes match the hand-derived ones."""
        assert simulator.get_wavefunction(circuit) == expected_wavefunction