def test_prob_hess(self, method):
    """Check the probability Hessian of the circuit H - RZ(a) - RX(b).

    The Hessian is built with ``Hessian(hess_method=method)`` and compared
    against the analytic second derivatives of the outcome probabilities:

        d^2p0/da^2 = - sin(a)sin(b) / 2
        d^2p1/da^2 =   sin(a)sin(b) / 2
        d^2p0/dadb =   cos(a)cos(b) / 2
        d^2p1/dadb = - cos(a)cos(b) / 2

    Args:
        method: Forwarded unchanged as ``hess_method`` to ``Hessian``.
    """
    a = Parameter('a')
    b = Parameter('b')

    # Single-qubit state preparation: H, then RZ(a), then RX(b).
    qreg = QuantumRegister(1)
    circuit = QuantumCircuit(qreg)
    circuit.h(qreg)
    circuit.rz(a, qreg[0])
    circuit.rx(b, qreg[0])
    state = CircuitStateFn(primitive=circuit, coeff=1.)

    # Requested Hessian entries: (a, a) and the mixed (a, b) derivative.
    hessian_pairs = [(a, a), (a, b)]
    prob_hess = Hessian(hess_method=method).convert(operator=state,
                                                    params=hessian_pairs)

    inv_2_sqrt2 = 1 / (2 * np.sqrt(2))
    # Each case pairs a parameter binding with the expected [p0, p1]
    # derivatives, one row per requested Hessian entry.
    cases = [
        ({a: np.pi / 4, b: 0},
         [[0, 0], [inv_2_sqrt2, -inv_2_sqrt2]]),
        ({a: np.pi / 4, b: np.pi / 4},
         [[-1 / 4, 1 / 4], [1 / 4, -1 / 4]]),
        ({a: np.pi / 2, b: np.pi},
         [[0, 0], [0, 0]]),
    ]

    for binding, expected in cases:
        evaluated = prob_hess.assign_parameters(binding).eval()
        # Index into ``expected`` explicitly so that an unexpectedly long
        # result fails loudly instead of being silently truncated.
        for idx, entry in enumerate(evaluated):
            np.testing.assert_array_almost_equal(entry, expected[idx], decimal=1)
def test_operator_coefficient_hessian(self, method):
    """Check the Hessian with respect to the observable's coefficients.

    The observable is H = c_0^2 * X + c_0 * c_1 * Z, measured on the state
    prepared by H - RZ(a) - RX(b), for which

        <Z> = Tr(|psi><psi| Z) = sin(a)sin(b)
        <X> = Tr(|psi><psi| X) = cos(a)

    First derivatives:

        d<H>/dc_0 = 2 * c_0 * <X> + c_1 * <Z>
        d<H>/dc_1 = c_0 * <Z>

    Second derivatives (the values under test):

        d^2<H>/dc_0^2   = 2 * <X>
        d^2<H>/dc_0dc_1 = d^2<H>/dc_1dc_0 = <Z>
        d^2<H>/dc_1^2   = 0

    Args:
        method: Forwarded unchanged as ``hess_method`` to ``Hessian``.
    """
    a = Parameter('a')
    b = Parameter('b')

    # Single-qubit state preparation: H, then RZ(a), then RX(b).
    qreg = QuantumRegister(1)
    circuit = QuantumCircuit(qreg)
    circuit.h(qreg)
    circuit.rz(a, qreg[0])
    circuit.rx(b, qreg[0])

    coeff_0 = Parameter('c_0')
    coeff_1 = Parameter('c_1')
    ham = coeff_0 * coeff_0 * X + coeff_1 * coeff_0 * Z

    measurement = ~StateFn(ham)
    state = CircuitStateFn(primitive=circuit, coeff=1.)
    op = measurement @ state

    # Requested Hessian entries, in the order of the expected values below.
    hessian_pairs = [(coeff_0, coeff_0), (coeff_0, coeff_1), (coeff_1, coeff_1)]
    coeff_hess = Hessian(hess_method=method).convert(op, hessian_pairs)

    # Each case pairs a full parameter binding with the expected
    # [d^2/dc_0^2, d^2/dc_0dc_1, d^2/dc_1^2] values.
    cases = [
        ({coeff_0: 0.5, coeff_1: -1, a: np.pi / 4, b: np.pi},
         [2 / np.sqrt(2), 0, 0]),
        ({coeff_0: 0.5, coeff_1: -1, a: np.pi / 4, b: np.pi / 4},
         [2 / np.sqrt(2), 1 / 2, 0]),
    ]

    for binding, expected in cases:
        evaluated = coeff_hess.assign_parameters(binding).eval()
        np.testing.assert_array_almost_equal(evaluated, expected, decimal=1)
def test_state_hessian(self, method):
    """Check the state Hessian of <H> for H = 0.5 * X - 1 * Z.

    The state is prepared by H - RZ(a) - RX(b), for which

        Tr(|psi><psi| Z) = sin(a)sin(b)
        Tr(|psi><psi| X) = cos(a)

    so the expected second derivatives are

        d^2<H>/da^2  = - 0.5 cos(a) + 1 sin(a)sin(b)
        d^2<H>/dadb  = - 1 cos(a)cos(b)
        d^2<H>/db^2  = + 1 sin(a)sin(b)

    Args:
        method: Forwarded unchanged as ``hess_method`` to ``Hessian``.
    """
    ham = 0.5 * X - 1 * Z
    a = Parameter('a')
    b = Parameter('b')

    # Single-qubit state preparation: H, then RZ(a), then RX(b).
    qreg = QuantumRegister(1)
    circuit = QuantumCircuit(qreg)
    circuit.h(qreg)
    circuit.rz(a, qreg[0])
    circuit.rx(b, qreg[0])

    measurement = ~StateFn(ham)
    state = CircuitStateFn(primitive=circuit, coeff=1.)
    op = measurement @ state

    # Requested Hessian entries, in the order of the expected values below.
    hessian_pairs = [(a, a), (a, b), (b, b)]
    state_hess = Hessian(hess_method=method).convert(operator=op,
                                                     params=hessian_pairs)

    inv_sqrt2 = 1 / np.sqrt(2)
    # Each case pairs a parameter binding with the expected
    # [d^2/da^2, d^2/dadb, d^2/db^2] values.
    cases = [
        ({a: np.pi / 4, b: np.pi},
         [-0.5 / np.sqrt(2), inv_sqrt2, 0]),
        ({a: np.pi / 4, b: np.pi / 4},
         [-0.5 / np.sqrt(2) + 0.5, -1 / 2., 0.5]),
        ({a: np.pi / 2, b: np.pi / 4},
         [inv_sqrt2, 0, inv_sqrt2]),
    ]

    for binding, expected in cases:
        evaluated = state_hess.assign_parameters(binding).eval()
        np.testing.assert_array_almost_equal(evaluated, expected, decimal=1)
def test_state_hessian_custom_combo_fn(self, method):
    """Check the state Hessian of an operator with a user-defined combo_fn.

    The inner expectation value uses H = 0.5 * X - 1 * Z on the state
    prepared by H - RZ(a) - RX(b), for which

        Tr(|psi><psi| Z) = sin(a)sin(b)
        Tr(|psi><psi| X) = cos(a)

        d^2<H>/da^2  = - 0.5 cos(a) + 1 sin(a)sin(b)
        d^2<H>/dadb  = - 1 cos(a)cos(b)
        d^2<H>/db^2  = + 1 sin(a)sin(b)

    That expectation value is wrapped in a ``ListOp`` whose combo_fn maps
    it to <H>^3 + 4 * <H>; the expected numbers below are the second
    derivatives of that composite function.

    Args:
        method: Forwarded unchanged as ``hess_method`` to ``Hessian``.
    """
    def cubic_plus_linear(x):
        # Custom combination applied to the single ListOp entry.
        return x[0]**3 + 4 * x[0]

    ham = 0.5 * X - 1 * Z
    a = Parameter('a')
    b = Parameter('b')

    # Single-qubit state preparation: H, then RZ(a), then RX(b).
    qreg = QuantumRegister(1)
    circuit = QuantumCircuit(qreg)
    circuit.h(qreg)
    circuit.rz(a, qreg[0])
    circuit.rx(b, qreg[0])

    measurement = ~StateFn(ham)
    state = CircuitStateFn(primitive=circuit, coeff=1.)
    op = ListOp([measurement @ state], combo_fn=cubic_plus_linear)

    # Requested Hessian entries, in the order of the expected values below.
    hessian_pairs = [(a, a), (a, b), (b, b)]
    state_hess = Hessian(hess_method=method).convert(operator=op,
                                                     params=hessian_pairs)

    # Each case pairs a parameter binding with the expected
    # [d^2/da^2, d^2/dadb, d^2/db^2] values of the composite function.
    cases = [
        ({a: np.pi / 4, b: np.pi},
         [-1.28163104, 2.56326208, 1.06066017]),
        ({a: np.pi / 4, b: np.pi / 4},
         [-0.04495626, -2.40716991, 1.8125]),
        ({a: np.pi / 2, b: np.pi / 4},
         [2.82842712, -1.5, 1.76776695]),
    ]

    for binding, expected in cases:
        evaluated = state_hess.assign_parameters(binding).eval()
        np.testing.assert_array_almost_equal(evaluated, expected, decimal=1)