def test_basis_encoding3(self):
    """ Input Validation

    ``BasisEncoding.circuit`` is expected to raise ``AssertionError`` when
    handed a feature vector with a non-binary entry (here ``0.5``).
    """
    bad_features = [1, 0.5]
    encoder = encode.BasisEncoding()
    self.assertRaises(AssertionError, encoder.circuit, bad_features)
def test_combine1(self):
    """combine() called without a measurement yields a qiskit QuantumCircuit.

    An 8-element basis-encoded input is combined with a default
    TreeTensorNetwork model.
    """
    x = np.array([1, 0, 0, 0, 0, 0, 0, 1])
    encoder = encode.BasisEncoding()
    model = TreeTensorNetwork()
    full_circuit = combine(x, encoder, model)
    # Printed for manual inspection when the runner does not capture stdout.
    print(full_circuit)
    # assertIsInstance reports the offending type on failure, unlike
    # assertTrue(isinstance(...)), which only says "False is not true".
    self.assertIsInstance(full_circuit, qiskit.QuantumCircuit)
def test_basis_encoding1(self):
    """ n_qubits

    ``BasisEncoding.n_qubits`` should report as many qubits as the input
    has features.
    """
    features = [1, 0, 0, 0, 1, 1, 1]
    qubit_count = encode.BasisEncoding().n_qubits(features)
    self.assertEqual(len(features), qubit_count, 'Error in n_qubits')
def test_combine9(self):
    """ With measurement argument --- ProbabilityThreshold

    combine() given a ProbabilityThreshold measurement should still yield a
    qiskit QuantumCircuit.
    """
    x = np.array([1, 0, 0, 1])
    encoder = encode.BasisEncoding()
    model = TreeTensorNetwork()
    # NOTE(review): meaning of the argument 3 (qubit index? threshold?) is
    # defined in the measurement module — confirm there.
    measure = measurement.ProbabilityThreshold(3)
    full_circuit = combine(x, encoder, model, measure)
    # Printed for manual inspection when the runner does not capture stdout.
    print(full_circuit)
    # assertIsInstance reports the offending type on failure.
    self.assertIsInstance(full_circuit, qiskit.QuantumCircuit)
def test_combine8(self):
    """ With measurement argument --- Expectation

    combine() given an Expectation measurement of the Y observable should
    still yield a qiskit QuantumCircuit.
    """
    x = np.array([1, 0, 0, 1])
    encoder = encode.BasisEncoding()
    model = TreeTensorNetwork()
    Y_obs = Observable.Y()
    # NOTE(review): the leading 0 is presumably the measured qubit index —
    # confirm against measurement.Expectation.
    measure = measurement.Expectation(0, observable=Y_obs)
    full_circuit = combine(x, encoder, model, measure)
    # Printed for manual inspection when the runner does not capture stdout.
    print(full_circuit)
    # assertIsInstance reports the offending type on failure.
    self.assertIsInstance(full_circuit, qiskit.QuantumCircuit)
def test_run(self):
    """run() returns a numpy array with one prediction per input row."""
    X = np.array([[1, 0, 0, 0],
                  [0, 1, 0, 0],
                  [0, 0, 1, 0],
                  [0, 0, 0, 1],
                  [1, 0, 0, 0]])
    encoder = encode.BasisEncoding()
    model = TreeTensorNetwork()
    measure = measurement.Probability(0)
    predictions = run(X, encoder, model, measure)
    # Two separate assertions so a failure names the violated property
    # instead of collapsing both into a single opaque boolean.
    self.assertIsInstance(predictions, np.ndarray)
    self.assertEqual(len(predictions), len(X))
def test_basis_encoding2(self):
    """ Default

    Sampling the basis-encoding circuit for [1, 0, 0, 0, 1, 1, 1] should
    produce exactly one outcome, labelled '1000111'.
    """
    x = [1, 0, 0, 0, 1, 1, 1]
    basis = encode.BasisEncoding()
    circuit = basis.circuit(x)
    counts = get_counts(circuit)
    expected_dirac_label = '1000111'
    # Collect every observed label directly; the previous loop silently kept
    # only the last key and carried an unused `value` local.
    observed_labels = list(counts)
    self.assertEqual(len(observed_labels), 1)
    self.assertEqual(observed_labels[0], expected_dirac_label)