def test_multiple_serializers(self):
    """Route XPowGate ops to one of two serializers via their predicates.

    Both serializers accept cirq.XPowGate. The 'x' serializer only takes
    exponent == 1 and the 'x_pow' serializer takes every other exponent, so
    the gate set must pick 'x' for cirq.X and 'x_pow' for a square root of X.
    """

    def make_x_serializer(gate_id, predicate):
        # Both serializers share everything except their id and predicate.
        return op_serializer.GateOpSerializer(
            gate_type=cirq.XPowGate,
            serialized_gate_id=gate_id,
            args=[
                op_serializer.SerializingArg(serialized_name='half_turns',
                                             serialized_type=float,
                                             op_getter='exponent')
            ],
            can_serialize_predicate=predicate)

    partial_turn = make_x_serializer('x_pow',
                                     lambda op: op.gate.exponent != 1)
    full_turn = make_x_serializer('x', lambda op: op.gate.exponent == 1)
    gate_set = serializable_gate_set.SerializableGateSet(
        gate_set_name='my_gate_set',
        serializers=[partial_turn, full_turn],
        deserializers=[])
    qubit = cirq.GridQubit(1, 1)
    self.assertEqual(gate_set.serialize_op(cirq.X(qubit)).gate.id, 'x')
    sqrt_x = cirq.X(qubit)**0.5
    self.assertEqual(gate_set.serialize_op(sqrt_x).gate.id, 'x_pow')
Example #2
0
def _identity_gate_serializer():
    """Make a standard serializer for the single qubit identity.

    Returns:
        An `op_serializer.GateOpSerializer` for `cirq.IdentityGate` with the
        serialized gate id "I". Its args are a placeholder bool "unused"
        (filled by a size check), plus "control_qubits" and
        "control_values" strings produced by `_serialize_controls` and
        `_serialize_control_vals`.

    The size check raises `ValueError` for identity gates acting on more than
    one qubit. It is installed as the "unused" arg's getter, so it presumably
    fires when the serializer evaluates its args; confirm against
    `op_serializer`.
    """

    def _identity_check(x):
        # Used as an op_getter. Its return value fills the "unused" arg, and
        # raising here rejects multi-qubit identities.
        if x.gate.num_qubits() != 1:
            # The trailing space after "supported." keeps the implicitly
            # concatenated sentences from running together.
            raise ValueError("Multi-Qubit identity gate not supported. "
                             "Given: {}. To work around this, use "
                             "cirq.I.on_each instead.".format(str(x)))
        return True

    # Here `args` is used for two reasons. 1. GateOpSerializer doesn't work well
    # with empty arg lists. 2. It is a nice way to check identity gate size.
    args = [
        op_serializer.SerializingArg(serialized_name="unused",
                                     serialized_type=bool,
                                     op_getter=_identity_check),
        op_serializer.SerializingArg(
            serialized_name="control_qubits",
            serialized_type=str,
            op_getter=lambda x: _serialize_controls(x)),
        op_serializer.SerializingArg(
            serialized_name="control_values",
            serialized_type=str,
            op_getter=lambda x: _serialize_control_vals(x))
    ]
    return op_serializer.GateOpSerializer(
        gate_type=cirq.IdentityGate,
        serialized_gate_id="I",
        args=args,
        can_serialize_predicate=_CONSTANT_TRUE)
Example #3
0
def _fsim_gate_serializer():
    """Build the serializer for `cirq.FSimGate` (serialized id "FSIM").

    Each angle (theta, then phi) is written twice: once through
    `_symbol_extractor` and once, under a "_scalar" suffixed name, through
    `_scalar_extractor`. Control qubits and control values follow as
    strings.
    """

    def make_arg(name, arg_type, getter):
        return op_serializer.SerializingArg(serialized_name=name,
                                            serialized_type=arg_type,
                                            op_getter=getter)

    fsim_args = [
        make_arg("theta", float,
                 lambda op: _symbol_extractor(op.gate.theta)),
        make_arg("phi", float,
                 lambda op: _symbol_extractor(op.gate.phi)),
        make_arg("theta_scalar", float,
                 lambda op: _scalar_extractor(op.gate.theta)),
        make_arg("phi_scalar", float,
                 lambda op: _scalar_extractor(op.gate.phi)),
        make_arg("control_qubits", str,
                 lambda op: _serialize_controls(op)),
        make_arg("control_values", str,
                 lambda op: _serialize_control_vals(op)),
    ]
    return op_serializer.GateOpSerializer(
        gate_type=cirq.FSimGate,
        serialized_gate_id="FSIM",
        args=fsim_args,
        can_serialize_predicate=_CONSTANT_TRUE)
Example #4
0
def _eigen_gate_serializer(gate_type, serialized_id):
    """Build a standard serializer for an eigen gate class.

    Args:
        gate_type: The cirq gate class this serializer accepts.
        serialized_id: The gate id written into the serialized proto.

    Returns:
        An `op_serializer.GateOpSerializer` that records the exponent (split
        into symbol and scalar parts via `_symbol_extractor` and
        `_scalar_extractor`), the global shift as a float, and the control
        qubits and values as strings.
    """

    def make_arg(name, arg_type, getter):
        return op_serializer.SerializingArg(serialized_name=name,
                                            serialized_type=arg_type,
                                            op_getter=getter)

    eigen_args = [
        make_arg("exponent", float,
                 lambda op: _symbol_extractor(op.gate.exponent)),
        make_arg("exponent_scalar", float,
                 lambda op: _scalar_extractor(op.gate.exponent)),
        # Reads cirq's private `_global_shift` attribute directly.
        make_arg("global_shift", float,
                 lambda op: float(op.gate._global_shift)),
        make_arg("control_qubits", str,
                 lambda op: _serialize_controls(op)),
        make_arg("control_values", str,
                 lambda op: _serialize_control_vals(op)),
    ]
    return op_serializer.GateOpSerializer(
        gate_type=gate_type,
        serialized_gate_id=serialized_id,
        args=eigen_args,
        can_serialize_predicate=_CONSTANT_TRUE)
Example #5
0
    def test_to_proto_not_required_ok(self):
        """Test that serialization succeeds when a non-required arg is absent."""
        my_val_arg = op_serializer.SerializingArg(serialized_name='my_val',
                                                  serialized_type=float,
                                                  op_getter='val')
        optional_arg = op_serializer.SerializingArg(serialized_name='not_req',
                                                    serialized_type=float,
                                                    op_getter='not_req',
                                                    required=False)
        serializer = op_serializer.GateOpSerializer(
            gate_type=GateWithProperty,
            serialized_gate_id='my_gate',
            args=[my_val_arg, optional_arg])
        # Only 'my_val' is expected in the proto; the optional 'not_req'
        # arg must not appear.
        expected = op_proto({
            'gate': {'id': 'my_gate'},
            'args': {'my_val': {'arg_value': {'float_value': 0.125}}},
            'qubits': [{'id': '1_2'}],
        })

        qubit = cirq.GridQubit(1, 2)
        operation = GateWithProperty(0.125)(qubit)
        self.assertEqual(serializer.to_proto(operation), expected)
Example #6
0
 def test_to_proto_callable(self, val_type, val, arg_value):
     """Test that a callable op_getter's value is written into the proto."""
     getter_arg = op_serializer.SerializingArg(serialized_name='my_val',
                                               serialized_type=val_type,
                                               op_getter=get_val)
     serializer = op_serializer.GateOpSerializer(gate_type=GateWithMethod,
                                                 serialized_gate_id='my_gate',
                                                 args=[getter_arg])
     qubit = cirq.GridQubit(1, 2)
     operation = GateWithMethod(val)(qubit)
     result = serializer.to_proto(operation, arg_function_language='linear')
     expected = op_proto({
         'gate': {'id': 'my_gate'},
         'args': {'my_val': arg_value},
         'qubits': [{'id': '1_2'}],
     })
     self.assertEqual(result, expected)
Example #7
0
def _asymmetric_depolarize_serializer():
    """Build the serializer for `cirq.AsymmetricDepolarizingChannel` ("ADP").

    The probabilities p_x, p_y and p_z are copied straight from the gate as
    floats. The control fields are always serialized as empty strings.
    """

    # cirq channels can't contain symbols, so each probability is read
    # directly rather than split into symbol and scalar parts.
    def probability_arg(field):
        return op_serializer.SerializingArg(
            serialized_name=field,
            serialized_type=float,
            op_getter=lambda op: getattr(op.gate, field))

    def empty_str_arg(field):
        return op_serializer.SerializingArg(serialized_name=field,
                                            serialized_type=str,
                                            op_getter=lambda op: '')

    channel_args = [probability_arg(field) for field in ("p_x", "p_y", "p_z")]
    channel_args.append(empty_str_arg("control_qubits"))
    channel_args.append(empty_str_arg("control_values"))
    return op_serializer.GateOpSerializer(
        gate_type=cirq.AsymmetricDepolarizingChannel,
        serialized_gate_id="ADP",
        args=channel_args,
        can_serialize_predicate=_CONSTANT_TRUE)
Example #8
0
 def test_to_proto_unsupported_type(self):
     """Test that an arg declared with an unsupported type (bytes) errors."""
     bytes_arg = op_serializer.SerializingArg(serialized_name='my_val',
                                              serialized_type=bytes,
                                              op_getter='val')
     serializer = op_serializer.GateOpSerializer(gate_type=GateWithProperty,
                                                 serialized_gate_id='my_gate',
                                                 args=[bytes_arg])
     qubit = cirq.GridQubit(1, 2)
     # The error message is expected to mention the offending type.
     with self.assertRaisesRegex(ValueError, expected_regex='bytes'):
         serializer.to_proto(GateWithProperty(b's')(qubit))
Example #9
0
 def test_to_proto_type_mismatch(self, val_type, val):
     """Test that a value whose type differs from the declared type errors."""
     typed_arg = op_serializer.SerializingArg(serialized_name='my_val',
                                              serialized_type=val_type,
                                              op_getter='val')
     serializer = op_serializer.GateOpSerializer(gate_type=GateWithProperty,
                                                 serialized_gate_id='my_gate',
                                                 args=[typed_arg])
     qubit = cirq.GridQubit(1, 2)
     # The error should name the actual type of the supplied value.
     actual_type_name = str(type(val))
     with self.assertRaisesRegex(ValueError, expected_regex=actual_type_name):
         serializer.to_proto(GateWithProperty(val)(qubit))
Example #10
0
 def test_to_proto_required_but_not_present(self):
     """Test that a required arg whose getter yields None errors."""
     missing_arg = op_serializer.SerializingArg(serialized_name='my_val',
                                                serialized_type=float,
                                                op_getter=lambda op: None)
     serializer = op_serializer.GateOpSerializer(gate_type=GateWithProperty,
                                                 serialized_gate_id='my_gate',
                                                 args=[missing_arg])
     qubit = cirq.GridQubit(1, 2)
     with self.assertRaisesRegex(ValueError, expected_regex='required'):
         serializer.to_proto(GateWithProperty(1.0)(qubit))
Example #11
0
 def test_to_proto_no_getattr(self):
     """Test that a string op_getter naming a missing attribute ('nope') errors."""
     bad_getter_arg = op_serializer.SerializingArg(serialized_name='my_val',
                                                   serialized_type=float,
                                                   op_getter='nope')
     serializer = op_serializer.GateOpSerializer(gate_type=GateWithProperty,
                                                 serialized_gate_id='my_gate',
                                                 args=[bad_getter_arg])
     qubit = cirq.GridQubit(1, 2)
     with self.assertRaisesRegex(ValueError, expected_regex='does not have'):
         serializer.to_proto(GateWithProperty(1.0)(qubit))
Example #12
0
 def test_can_serialize_operation_subclass(self):
     """Test that subclass instances are accepted, subject to the predicate.

     SubclassGate is presumably a GateWithAttribute subclass (defined
     elsewhere); the predicate only accepts gates with val == 1.
     """
     val_arg = op_serializer.SerializingArg(serialized_name='my_val',
                                            serialized_type=float,
                                            op_getter='val')
     serializer = op_serializer.GateOpSerializer(
         gate_type=GateWithAttribute,
         serialized_gate_id='my_gate',
         args=[val_arg],
         can_serialize_predicate=lambda op: op.gate.val == 1)
     qubit = cirq.GridQubit(1, 1)
     self.assertTrue(
         serializer.can_serialize_operation(SubclassGate(1)(qubit)))
     self.assertFalse(
         serializer.can_serialize_operation(SubclassGate(0)(qubit)))
Example #13
0
 def test_to_proto_gate_mismatch(self):
     """Test that serializing a gate of the wrong class errors."""
     val_arg = op_serializer.SerializingArg(serialized_name='my_val',
                                            serialized_type=float,
                                            op_getter='val')
     serializer = op_serializer.GateOpSerializer(gate_type=GateWithProperty,
                                                 serialized_gate_id='my_gate',
                                                 args=[val_arg])
     qubit = cirq.GridQubit(1, 2)
     # The message should name the given gate class, then the expected one.
     mismatch_pattern = 'GateWithAttribute.*GateWithProperty'
     with self.assertRaisesRegex(ValueError, expected_regex=mismatch_pattern):
         serializer.to_proto(GateWithAttribute(1.0)(qubit))
Example #14
0
def _reset_channel_serializer():
    """Build the serializer for `cirq.ResetChannel` (serialized id "RST").

    The channel has no parameters of its own. Only the control fields are
    emitted, and both are always empty strings.
    """
    # cirq channels can't contain symbols, so nothing but the (empty)
    # control fields needs to be recorded.
    control_args = [
        op_serializer.SerializingArg(serialized_name=field,
                                     serialized_type=str,
                                     op_getter=lambda op: '')
        for field in ("control_qubits", "control_values")
    ]
    return op_serializer.GateOpSerializer(
        gate_type=cirq.ResetChannel,
        serialized_gate_id="RST",
        args=control_args,
        can_serialize_predicate=_CONSTANT_TRUE)
Example #15
0
def _phase_damp_channel_serializer():
    """Build the serializer for `cirq.PhaseDampingChannel` ("PD").

    The damping strength `gamma` is read directly from the gate as a float.
    The control fields are always serialized as empty strings.
    """
    # cirq channels can't contain symbols, so `gamma` needs no symbol/scalar
    # split.
    gamma_arg = op_serializer.SerializingArg(
        serialized_name="gamma",
        serialized_type=float,
        op_getter=lambda op: op.gate.gamma)
    control_args = [
        op_serializer.SerializingArg(serialized_name=field,
                                     serialized_type=str,
                                     op_getter=lambda op: '')
        for field in ("control_qubits", "control_values")
    ]
    return op_serializer.GateOpSerializer(
        gate_type=cirq.PhaseDampingChannel,
        serialized_gate_id="PD",
        args=[gamma_arg] + control_args,
        can_serialize_predicate=_CONSTANT_TRUE)
Example #16
0
def _bit_flip_channel_serializer():
    """Build the serializer for `cirq.BitFlipChannel` (serialized id "BF").

    The flip probability `p` is read directly from the gate as a float.
    The control fields are always serialized as empty strings.
    """
    # cirq channels can't contain symbols, so `p` needs no symbol/scalar split.
    probability_arg = op_serializer.SerializingArg(
        serialized_name="p",
        serialized_type=float,
        op_getter=lambda op: op.gate.p)
    control_args = [
        op_serializer.SerializingArg(serialized_name=field,
                                     serialized_type=str,
                                     op_getter=lambda op: '')
        for field in ("control_qubits", "control_values")
    ]
    return op_serializer.GateOpSerializer(
        gate_type=cirq.BitFlipChannel,
        serialized_gate_id="BF",
        args=[probability_arg] + control_args,
        can_serialize_predicate=_CONSTANT_TRUE)
 def test_is_supported_operation_can_serialize_predicate(self):
     """Test that is_supported_operation honours the serializer predicate."""
     qubit = cirq.GridQubit(1, 2)
     half_turns_arg = op_serializer.SerializingArg(
         serialized_name='half_turns',
         serialized_type=float,
         op_getter='exponent',
     )
     full_turn_only = op_serializer.GateOpSerializer(
         gate_type=cirq.XPowGate,
         serialized_gate_id='x_pow',
         args=[half_turns_arg],
         can_serialize_predicate=lambda op: op.gate.exponent == 1.0)
     gate_set = serializable_gate_set.SerializableGateSet(
         gate_set_name='my_gate_set',
         serializers=[full_turn_only],
         deserializers=[X_DESERIALIZER])
     # Only XPowGate ops whose exponent equals 1.0 pass the predicate.
     self.assertTrue(gate_set.is_supported_operation(cirq.XPowGate()(qubit)))
     self.assertFalse(
         gate_set.is_supported_operation(cirq.XPowGate()(qubit)**0.5))
     self.assertTrue(gate_set.is_supported_operation(cirq.X(qubit)))
Example #18
0
 def test_defaults_not_serialized(self):
     """Test that an arg equal to its declared default is left out."""
     defaulted_arg = op_serializer.SerializingArg(serialized_name='my_val',
                                                  serialized_type=float,
                                                  default=1.0,
                                                  op_getter='val')
     serializer = op_serializer.GateOpSerializer(gate_type=GateWithAttribute,
                                                 serialized_gate_id='my_gate',
                                                 args=[defaulted_arg])
     qubit = cirq.GridQubit(1, 2)
     # A non-default value (0.125) must be written out explicitly.
     no_default = op_proto({
         'gate': {'id': 'my_gate'},
         'args': {'my_val': {'arg_value': {'float_value': 0.125}}},
         'qubits': [{'id': '1_2'}],
     })
     self.assertEqual(no_default,
                      serializer.to_proto(GateWithAttribute(0.125)(qubit)))
     # The default value (1.0) is dropped, so the proto carries no 'args'.
     with_default = op_proto({
         'gate': {'id': 'my_gate'},
         'qubits': [{'id': '1_2'}],
     })
     self.assertEqual(with_default,
                      serializer.to_proto(GateWithAttribute(1.0)(qubit)))
"""Test serializable_gat_set.py functionality."""

import tensorflow as tf

import cirq
from google.protobuf import json_format
from tensorflow_quantum.core.serialize import op_serializer, op_deserializer, \
    serializable_gate_set
from tensorflow_quantum.core.proto import program_pb2

# Module-level serializer for cirq.XPowGate under the id 'x_pow'. The gate's
# `exponent` attribute is written out as the float arg 'half_turns'.
X_SERIALIZER = op_serializer.GateOpSerializer(
    gate_type=cirq.XPowGate,
    serialized_gate_id='x_pow',
    args=[
        op_serializer.SerializingArg(serialized_name='half_turns',
                                     serialized_type=float,
                                     op_getter='exponent')
    ])

X_DESERIALIZER = op_deserializer.GateOpDeserializer(
    serialized_gate_id='x_pow',
    gate_constructor=cirq.XPowGate,
    args=[
        op_deserializer.DeserializingArg(
            serialized_name='half_turns',
            constructor_arg_name='exponent',
        )
    ],