Ejemplo n.º 1
0
    def test_multi_circuit_batch(self):
        """Check that a batch of parameterized circuits is evaluated per row.

        Two copies of the same parameterized circuit are converted to a tensor
        batch, and each copy is paired with its own row of ``symbol_values``.
        The layer output must match cirq's unitary for each resolved circuit.
        """
        alpha = sympy.Symbol('alpha')
        batch_values = np.array([[0.5], [3.5]])
        qubit = cirq.GridQubit(0, 0)
        circuit = cirq.Circuit(cirq.H(qubit)**alpha)
        circuit_batch = util.convert_to_tensor([circuit, circuit])

        layer_output = unitary.Unitary()(circuit_batch,
                                         symbol_names=[alpha],
                                         symbol_values=batch_values)

        # One reference unitary per value row, in batch order.
        expected = [
            cirq.unitary(cirq.resolve_parameters(circuit, {alpha: row[0]}))
            for row in batch_values
        ]
        self.assertAllClose(layer_output, expected)
Ejemplo n.º 2
0
    def test_single_circuit_batch_inputs(self):
        """Test that one circuit can be evaluated against a batch of values.

        A single, unbatched circuit with one parameter is given two rows of
        ``symbol_values``. The layer should return one unitary per row, each
        matching cirq's unitary for the circuit resolved with that row's value.
        """
        a_symbol = sympy.Symbol('alpha')
        some_values = np.array([[0.5], [3.5]])
        circuit = cirq.Circuit(cirq.H(cirq.GridQubit(0, 0))**a_symbol)
        # Note: `circuit` is passed bare (not in a list); the two value rows
        # determine the size of the output batch.
        results = unitary.Unitary()(circuit,
                                    symbol_names=[a_symbol],
                                    symbol_values=some_values)
        expected = [
            cirq.unitary(cirq.resolve_parameters(circuit, {a_symbol: row[0]}))
            for row in some_values
        ]
        self.assertAllClose(results, expected)
Ejemplo n.º 3
0
    def test_op_errors(self):
        """Test that op errors can be hit.

        Each case is constructed so that exactly one thing is wrong. In
        particular, the number of circuits agrees with the number of
        ``symbol_values`` rows (or the circuit is passed unbatched), so a
        failure cannot come from a circuit/value batch-size mismatch instead
        of the error the case is meant to exercise.
        """
        u_calc = unitary.Unitary()
        symbol = sympy.Symbol('alpha')
        circuit = cirq.Circuit(cirq.H(cirq.GridQubit(0, 0))**symbol)
        with self.assertRaisesRegex(Exception, expected_regex=""):
            # wrong symbol name. Two circuits for two value rows, so the
            # misspelled name is the only problem; a single [circuit] with two
            # value rows is rejected even with the correct name.
            u_calc([circuit, circuit],
                   symbol_names=['alphaaaa'],
                   symbol_values=[[2.0], [3.0]])

        with self.assertRaisesRegex(Exception, expected_regex=""):
            # too many symbol values provided: two values per row for a
            # single symbol name. The unbatched circuit is paired with each
            # value row.
            u_calc(circuit,
                   symbol_names=['alpha'],
                   symbol_values=[[2.0, 4.0], [3.0, 5.0]])
Ejemplo n.º 4
0
    def test_input_errors(self):
        """Test that bad inputs are caught by input_check.py."""
        u_calc = unitary.Unitary()
        symbol = sympy.Symbol('alpha')
        circuit = cirq.Circuit(cirq.H(cirq.GridQubit(0, 0))**symbol)
        with self.assertRaisesRegex(Exception, expected_regex=""):
            # no value provided. Only `symbol_values` is omitted. No stray
            # keyword arguments (e.g. `repetitions`, which no other call here
            # passes and which a unitary has no use for) are included, so the
            # failure cannot come from an unexpected kwarg instead.
            u_calc([circuit, circuit], symbol_names=[symbol])

        with self.assertRaisesRegex(Exception, expected_regex=""):
            # no name provided.
            u_calc([circuit, circuit],
                   symbol_names=[],
                   symbol_values=[[2.0], [3.0]])

        with self.assertRaisesRegex(Exception, expected_regex=""):
            # deceptive, but the circuit shouldn't be in a list. otherwise fine.
            # (One listed circuit vs. two value rows is a batch-size mismatch.)
            u_calc([circuit],
                   symbol_names=['alpha'],
                   symbol_values=[[2.0], [3.0]])
Ejemplo n.º 5
0
 def test_basic_inputs_fixed(self):
     """Test that Unitary layer outputs work on a hand-computed case.

     An unparameterized circuit holding a single X gate should produce the
     Pauli-X matrix, wrapped in a batch of size one.
     """
     simple_circuit = cirq.Circuit(cirq.X(cirq.GridQubit(0, 0)))
     # Pauli-X swaps the two computational basis states.
     true_u = np.array([[0, 1], [1, 0]], dtype=np.complex64)
     tfq_u = unitary.Unitary()(simple_circuit)
     # The layer returns a batch, so compare against a one-element list.
     self.assertAllClose(tfq_u, [true_u])
Ejemplo n.º 6
0
 def test_basic_inputs(self):
     """Test that Unitary layer outputs work end to end.

     For an unparameterized single-Hadamard circuit, the layer's output should
     equal ``cirq.unitary`` of the same circuit, wrapped in a batch of one.
     """
     simple_circuit = cirq.Circuit(cirq.H(cirq.GridQubit(0, 0)))
     cirq_u = cirq.unitary(simple_circuit)
     tfq_u = unitary.Unitary()(simple_circuit)
     # The layer returns a batch, so compare against a one-element list.
     self.assertAllClose(tfq_u, [cirq_u])
Ejemplo n.º 7
0
 def test_unitary_create(self):
     """Test that Unitary layers can be created."""
     # Constructing the layer with no arguments must not raise.
     _ = unitary.Unitary()