def test_grad_combo_fn_chain_rule(self, method):
        """Check the gradient of a ListOp that uses custom combo functions.

        A ``ListOp`` wrapping the decomposed 2-qubit ``RealAmplitudes`` state is
        built with a custom ``combo_fn`` and ``grad_combo_fn``, converted with
        ``Gradient(grad_method=method)``, bound to random parameter values
        (NumPy seeded with 2) and compared against hard-coded reference values.

        Args:
            method: value forwarded to ``Gradient`` as ``grad_method``.
        """
        np.random.seed(2)

        def combo_fn(x):
            # Sum of log(a * conj(a)) over the amplitudes, divided by minus
            # the number of amplitudes.
            amps = x[0].primitive.data
            probs = np.multiply(amps, np.conj(amps))
            return np.sum(np.log(probs)) / (-len(amps))

        def grad_combo_fn(x):
            # One entry -1 / p for every p in a * conj(a).
            amps = x[0].primitive.data
            probs = np.multiply(amps, np.conj(amps))
            return [-1 / p for p in probs]

        ansatz = RealAmplitudes(2, reps=1)
        operator = ListOp(
            [StateFn(ansatz.decompose())],
            combo_fn=combo_fn,
            grad_combo_fn=grad_combo_fn,
        )
        gradient = Gradient(grad_method=method).convert(operator)

        # Random values for every circuit parameter, keyed by parameter.
        bindings = dict(
            zip(
                ansatz.ordered_parameters,
                np.random.rand(len(ansatz.ordered_parameters)),
            )
        )

        expected = [
            [-0.16666259133549044 + 0j],
            [-7.244949702732864 + 0j],
            [-2.979791752749964 + 0j],
            [-5.310186078432614 + 0j],
        ]
        evaluated = gradient.assign_parameters(bindings).eval()
        np.testing.assert_array_almost_equal(evaluated, expected)
    def test_real_amplitudes_circuit_5q(self):
        """Collect and synthesize linear functions of a 5-qubit RealAmplitudes circuit.

        After running ``CollectLinearFunctions`` on the decomposed
        ``RealAmplitudes(5, reps=2)`` circuit, 2 ``linear_function`` operations
        are expected; after running ``LinearFunctionsSynthesis`` on that result,
        8 ``cx`` gates are expected.
        """
        decomposed = RealAmplitudes(5, reps=2).decompose()

        # Step 1: gather linear gates into linear_function blocks.
        collected = PassManager(CollectLinearFunctions()).run(decomposed)
        num_linear_blocks = collected.count_ops()["linear_function"]
        self.assertEqual(num_linear_blocks, 2)

        # Step 2: synthesize the collected blocks and count the resulting CNOTs.
        synthesized = PassManager(LinearFunctionsSynthesis()).run(collected)
        num_cx = synthesized.count_ops()["cx"]
        self.assertEqual(num_cx, 8)
    def test_grad_combo_fn_chain_rule_nat_grad(self):
        """Check the natural gradient of a ListOp that uses custom combo functions.

        Uses ``NaturalGradient`` with ``grad_method="lin_comb"`` and
        ``regularization="ridge"`` on a ``ListOp`` wrapping the decomposed 2-qubit
        ``RealAmplitudes`` state, binds random parameter values (NumPy seeded
        with 2) and compares the result to reference values to 3 decimals.
        The test is skipped if ``MissingOptionalLibraryError`` is raised.
        """
        np.random.seed(2)

        def combo_fn(x):
            # Sum of log(a * conj(a)) over the amplitudes, divided by minus
            # the number of amplitudes.
            amps = x[0].primitive.data
            probs = np.multiply(amps, np.conj(amps))
            return np.sum(np.log(probs)) / (-len(amps))

        def grad_combo_fn(x):
            # One entry -1 / p for every p in a * conj(a).
            amps = x[0].primitive.data
            probs = np.multiply(amps, np.conj(amps))
            return [-1 / p for p in probs]

        try:
            ansatz = RealAmplitudes(2, reps=1)
            operator = ListOp(
                [StateFn(ansatz.decompose())],
                combo_fn=combo_fn,
                grad_combo_fn=grad_combo_fn,
            )
            converter = NaturalGradient(grad_method="lin_comb", regularization="ridge")
            nat_grad = converter.convert(operator, ansatz.ordered_parameters)

            # Random values for every circuit parameter, keyed by parameter.
            bindings = dict(
                zip(
                    ansatz.ordered_parameters,
                    np.random.rand(len(ansatz.ordered_parameters)),
                )
            )
            expected = [
                [0.20777236],
                [-18.92560338],
                [-15.89005475],
                [-10.44002031],
            ]
            evaluated = nat_grad.assign_parameters(bindings).eval()
            np.testing.assert_array_almost_equal(evaluated, expected, decimal=3)
        except MissingOptionalLibraryError as ex:
            # An optional dependency is unavailable: skip rather than fail.
            self.skipTest(str(ex))
    return np.sum(np.log(pdf)) / (-len(amplitudes))


def grad_combo_fn(x):
    """Return ``-1 / p`` for every entry ``p`` of ``a * conj(a)``.

    Args:
        x: sequence whose first item exposes the amplitudes as
            ``x[0].primitive.data``.

    Returns:
        A list with one ``-1 / p`` entry per amplitude, in order.
    """
    amps = x[0].primitive.data
    probs = np.multiply(amps, np.conj(amps))
    return [-1 / p for p in probs]



# Script: evaluate a natural gradient through a shot-based CircuitSampler and
# print the result next to hard-coded reference values.
# NOTE(review): `combo_fn` is not defined in the visible part of this file;
# presumably it is defined above -- confirm.
# NOTE(review): no np.random.seed call is visible in this script, so value_dict
# may differ between runs unless a seed is set earlier -- confirm.

# 2-qubit RealAmplitudes ansatz with a single repetition.
qc = RealAmplitudes(2, reps=1)
# Wrap the decomposed circuit's state in a ListOp with custom combo functions.
grad_op = ListOp(
    [StateFn(qc.decompose())], combo_fn=combo_fn, grad_combo_fn=grad_combo_fn
)
# Natural gradient (lin_comb method, ridge regularization) with respect to the
# ansatz's ordered parameters.
grad = NaturalGradient(grad_method="lin_comb", regularization="ridge").convert(
    grad_op, qc.ordered_parameters
)
# One random value from np.random.rand per ordered parameter.
value_dict = dict(
    zip(qc.ordered_parameters, np.random.rand(len(qc.ordered_parameters)))
)
# Reference values that are printed for comparison below.
correct_values = [[0.20777236], [-18.92560338], [-15.89005475], [-10.44002031]]
# Run on the BasicAer qasm_simulator backend with 5000 shots.
backend = BasicAer.get_backend("qasm_simulator")
q_instance = QuantumInstance(backend=backend, shots=5000)
# Each parameter value is wrapped in a one-element list for the sampler, so
# the first element of the evaluated result is printed.
sampler = CircuitSampler(backend=q_instance).convert(
    grad, params={k: [v] for k, v in value_dict.items()}
)
print('Sampler ', sampler.eval()[0])
print('Correct Values ', correct_values)