示例#1
0
 def test_state_create(self):
     """Check that State can be constructed, and that bad arguments fail.

     Builds one layer with the default backend and one with an explicit
     cirq.Simulator backend, then checks that passing a non-backend
     positional argument raises a TypeError mentioning the bad value.
     """
     default_layer = state.State()
     simulator_layer = state.State(backend=cirq.Simulator())
     # Only construction is under test; the instances are not used further.
     del default_layer, simulator_layer
     with self.assertRaisesRegex(TypeError, expected_regex="junk is invalid"):
         state.State('junk')
示例#2
0
    def test_state_invalid_combinations(self, backend):
        """Test with valid type inputs and valid values, but an incorrect combo.

        Each call below is well-typed on its own but inconsistent as a whole,
        so the layer call must raise. The regex is empty because only the fact
        that an error is raised is checked, not its message.

        Args:
            backend: simulator backend forwarded to `state.State`; presumably
                supplied by a parameterization decorator outside this view.
        """
        state_calc = state.State(backend)
        symbol = sympy.Symbol('alpha')
        circuit = cirq.Circuit(cirq.H(cirq.GridQubit(0, 0))**symbol)
        with self.assertRaisesRegex(Exception, expected_regex=""):
            # No value provided for the named symbol.
            # NOTE(review): a `repetitions=5` argument was removed here. No
            # other State call in this file passes it, so it risked making
            # this case fail on an unexpected keyword instead of on the
            # missing symbol values.
            state_calc([circuit, circuit], symbol_names=[symbol])

        with self.assertRaisesRegex(Exception, expected_regex=""):
            # No name provided for the supplied values.
            state_calc([circuit, circuit],
                       symbol_names=[],
                       symbol_values=[[2.0], [3.0]])

        with self.assertRaisesRegex(Exception, expected_regex=""):
            # Deceptive: one circuit in a list vs. two rows of values, so the
            # batch sizes disagree. Otherwise fine.
            state_calc([circuit],
                       symbol_names=['alpha'],
                       symbol_values=[[2.0], [3.0]])

        with self.assertRaisesRegex(Exception, expected_regex=""):
            # Wrong symbol name. Two circuits are passed so that the batch
            # matches the two rows of values; the unknown name is then the
            # only thing wrong with this call.
            state_calc([circuit, circuit],
                       symbol_names=['alphaaaa'],
                       symbol_values=[[2.0], [3.0]])

        with self.assertRaisesRegex(Exception, expected_regex=""):
            # Too many symbol values per row for a single symbol name.
            state_calc(circuit,
                       symbol_names=['alpha'],
                       symbol_values=[[2.0, 4.0], [3.0, 5.0]])
示例#3
0
    def test_state_invalid_type_inputs(self):
        """Check that State rejects arguments of the wrong type.

        Each table entry pairs the expected exception type and message pattern
        with a deferred call that should trigger it. Entries run in order
        against one default-backend layer.
        """
        state_calc = state.State()
        bad_calls = [
            (TypeError, "circuits cannot be parsed",
             lambda: state_calc('junk_circuit')),
            (TypeError, "symbol_values cannot be parsed",
             lambda: state_calc(cirq.Circuit(), symbol_values='junk')),
            (TypeError, "symbol_names cannot be parsed",
             lambda: state_calc(cirq.Circuit(),
                                symbol_values=[],
                                symbol_names='junk')),
            (TypeError, "Cannot convert",
             lambda: state_calc(cirq.Circuit(),
                                symbol_values=[['bad']],
                                symbol_names=['name'])),
            (TypeError, "must be a string.",
             lambda: state_calc(cirq.Circuit(),
                                symbol_values=[[0.5]],
                                symbol_names=[0.33333])),
            (ValueError, "must be unique.",
             lambda: state_calc(cirq.Circuit(),
                                symbol_values=[[0.5]],
                                symbol_names=['duplicate', 'duplicate'])),
        ]
        for error_type, pattern, bad_call in bad_calls:
            with self.assertRaisesRegex(error_type, expected_regex=pattern):
                bad_call()
示例#4
0
 def test_state_basic_inputs(self):
     """Smoke-test State with the simplest accepted argument combinations.

     Passes a bare circuit, a one-element circuit list, and a circuit with one
     symbol named first by a plain string and then by a sympy.Symbol. Nothing
     is asserted beyond the calls completing without raising.
     """
     state_calc = state.State()
     # Unparameterized circuit, alone and wrapped in a list.
     state_calc(cirq.Circuit())
     state_calc([cirq.Circuit()])
     # One symbol value, keyed by each supported kind of symbol name.
     for symbol_name in ('name', sympy.Symbol('name')):
         state_calc(cirq.Circuit(),
                    symbol_names=[symbol_name],
                    symbol_values=[[0.5]])
示例#5
0
    def test_state_invalid_shape_inputs(self):
        """Check that State rejects arguments whose nesting is wrong.

        Covers a list nested inside symbol_names, a flat (un-nested)
        symbol_values list, and a doubly nested list of circuits. Entries run
        in order against one default-backend layer.
        """
        state_calc = state.State()
        bad_calls = [
            (TypeError, "string or sympy.Symbol",
             lambda: state_calc(cirq.Circuit(),
                                symbol_values=[[0.5]],
                                symbol_names=[[]])),
            (Exception, "rank 1",
             lambda: state_calc(cirq.Circuit(),
                                symbol_values=[0.5],
                                symbol_names=['name'])),
            (Exception, "rank 2",
             lambda: state_calc([[cirq.Circuit()]],
                                symbol_values=[[0.5]],
                                symbol_names=['name'])),
        ]
        for error_type, pattern, bad_call in bad_calls:
            with self.assertRaisesRegex(error_type, expected_regex=pattern):
                bad_call()
示例#6
0
    def test_state_output(self, backend_output):
        """Check the layer's output for a two-qubit H-then-CNOT circuit.

        The layer allows only two output kinds, depending on whether a
        wavefunction or a density matrix simulator backs it. Any pre- or
        post-processing inside the layer must not make its output deviate
        structurally from the expected value.

        Args:
            backend_output: indexable pair of (backend, expected per-circuit
                output); presumably supplied by a parameterization decorator
                outside this view.
        """
        backend, expected = backend_output[0], backend_output[1]
        state_executor = state.State(backend=backend)
        control, target = cirq.GridQubit.rect(1, 2)
        entangler = cirq.Circuit([cirq.H.on(control), cirq.CNOT(control, target)])
        batch = util.convert_to_tensor([entangler, entangler])
        actual = state_executor(batch).to_list()
        self.assertAllClose(actual, [expected, expected])
示例#7
0
 def test_sample_outputs_simple(self):
     """Call State with circuits only and check the dense output shape.

     Two copies of a one-qubit Hadamard circuit are expected to produce a
     (2, 2) tensor once the layer output is densified with to_tensor().
     """
     # NOTE(review): the name says "sample" but this exercises State; the
     # name is kept unchanged because it is the test's identifier.
     layer = state.State()
     hadamard = cirq.Circuit(cirq.H(cirq.GridQubit(0, 0)))
     dense = layer([hadamard, hadamard]).to_tensor()
     reference = np.empty((2, 2))
     self.assertShapeEqual(reference, dense)
示例#8
0
 def test_state_one_circuit(self):
     """Check that a single unbatched circuit is accepted with batched values.

     The circuit is paired with a (5, 0) float32 symbol_values tensor: five
     rows and no symbols. Only the absence of an error is checked.
     """
     layer = state.State()
     five_empty_rows = tf.zeros((5, 0), dtype=tf.dtypes.float32)
     layer(cirq.Circuit(), symbol_values=five_empty_rows)