def _eigen_gate_serializer(gate_type, serialized_id):
    """Build the standard serializer used for eigen gates.

    The serialized arguments are:
      * ``exponent`` / ``exponent_scalar``: the gate exponent passed through
        ``_symbol_extractor`` and ``_scalar_extractor`` respectively.
      * ``global_shift``: the gate's ``_global_shift`` cast to ``float``.
      * ``control_qubits`` / ``control_values``: produced by
        ``_serialize_controls`` / ``_serialize_control_vals``.

    Args:
        gate_type: The cirq gate class handled by this serializer.
        serialized_id: The gate id written into the serialized proto.

    Returns:
        An ``op_serializer.GateOpSerializer`` whose predicate is
        ``_CONSTANT_TRUE``.
    """
    arg = op_serializer.SerializingArg

    def exponent_symbol(op):
        return _symbol_extractor(op.gate.exponent)

    def exponent_scalar(op):
        return _scalar_extractor(op.gate.exponent)

    def global_shift(op):
        # Cast explicitly so the value always serializes as a float.
        return float(op.gate._global_shift)

    serializing_args = [
        arg(serialized_name="exponent",
            serialized_type=float,
            op_getter=exponent_symbol),
        arg(serialized_name="exponent_scalar",
            serialized_type=float,
            op_getter=exponent_scalar),
        arg(serialized_name="global_shift",
            serialized_type=float,
            op_getter=global_shift),
        arg(serialized_name="control_qubits",
            serialized_type=str,
            op_getter=lambda op: _serialize_controls(op)),
        arg(serialized_name="control_values",
            serialized_type=str,
            op_getter=lambda op: _serialize_control_vals(op)),
    ]
    return op_serializer.GateOpSerializer(
        gate_type=gate_type,
        serialized_gate_id=serialized_id,
        args=serializing_args,
        can_serialize_predicate=_CONSTANT_TRUE)
def test_multiple_serializers(self):
    """Two serializers for one gate type are chosen by their predicates."""

    def half_turns_arg():
        return op_serializer.SerializingArg(serialized_name='half_turns',
                                            serialized_type=float,
                                            op_getter='exponent')

    # Handles every XPowGate whose exponent is not exactly 1.
    pow_serializer = op_serializer.GateOpSerializer(
        gate_type=cirq.XPowGate,
        serialized_gate_id='x_pow',
        args=[half_turns_arg()],
        can_serialize_predicate=lambda op: op.gate.exponent != 1)
    # Handles only the full X rotation.
    full_x_serializer = op_serializer.GateOpSerializer(
        gate_type=cirq.XPowGate,
        serialized_gate_id='x',
        args=[half_turns_arg()],
        can_serialize_predicate=lambda op: op.gate.exponent == 1)
    gate_set = serializable_gate_set.SerializableGateSet(
        gate_set_name='my_gate_set',
        serializers=[pow_serializer, full_x_serializer],
        deserializers=[])
    qubit = cirq.GridQubit(1, 1)

    full_x_id = gate_set.serialize_op(cirq.X(qubit)).gate.id
    self.assertEqual(full_x_id, 'x')
    half_x_id = gate_set.serialize_op(cirq.X(qubit)**0.5).gate.id
    self.assertEqual(half_x_id, 'x_pow')
def _identity_gate_serializer():
    """Make a standard serializer for the single qubit identity.

    Returns:
        An ``op_serializer.GateOpSerializer`` for ``cirq.IdentityGate`` with
        serialized gate id ``"I"`` and ``_CONSTANT_TRUE`` as its predicate.

    The getter of the placeholder ``unused`` argument raises ``ValueError``
    when the identity acts on more than one qubit. Such operations therefore
    fail when they are serialized.
    """

    def _identity_check(x):
        """Return True for a 1-qubit identity, else raise ValueError."""
        if x.gate.num_qubits() != 1:
            # Adjacent string literals are concatenated verbatim. Each piece
            # needs its own trailing space so the sentences stay separated.
            raise ValueError("Multi-Qubit identity gate not supported. "
                             "Given: {}. To work around this, use "
                             "cirq.I.on_each instead.".format(str(x)))
        return True

    # Here `args` is used for two reasons. 1. GateOpSerializer doesn't work well
    # with empty arg lists. 2. It is a nice way to check identity gate size.
    args = [
        op_serializer.SerializingArg(serialized_name="unused",
                                     serialized_type=bool,
                                     op_getter=_identity_check),
        op_serializer.SerializingArg(
            serialized_name="control_qubits",
            serialized_type=str,
            op_getter=lambda x: _serialize_controls(x)),
        op_serializer.SerializingArg(
            serialized_name="control_values",
            serialized_type=str,
            op_getter=lambda x: _serialize_control_vals(x))
    ]
    return op_serializer.GateOpSerializer(
        gate_type=cirq.IdentityGate,
        serialized_gate_id="I",
        args=args,
        can_serialize_predicate=_CONSTANT_TRUE)
def test_to_proto_not_required_ok(self):
    """Test that an absent non-required arg does not block serialization."""
    qubit = cirq.GridQubit(1, 2)
    required_arg = op_serializer.SerializingArg(serialized_name='my_val',
                                                serialized_type=float,
                                                op_getter='val')
    # `not_req` has no matching attribute, but required=False tolerates that.
    optional_arg = op_serializer.SerializingArg(serialized_name='not_req',
                                                serialized_type=float,
                                                op_getter='not_req',
                                                required=False)
    serializer = op_serializer.GateOpSerializer(
        gate_type=GateWithProperty,
        serialized_gate_id='my_gate',
        args=[required_arg, optional_arg])

    my_val_entry = {'arg_value': {'float_value': 0.125}}
    expected = op_proto({
        'gate': {'id': 'my_gate'},
        'args': {'my_val': my_val_entry},
        'qubits': [{'id': '1_2'}],
    })
    actual = serializer.to_proto(GateWithProperty(0.125)(qubit))
    self.assertEqual(actual, expected)
def _asymmetric_depolarize_serializer():
    """Build the serializer for ``cirq.AsymmetricDepolarizingChannel``.

    The serialized arguments are the float probabilities ``p_x``, ``p_y`` and
    ``p_z`` read straight from the gate. The ``control_qubits`` and
    ``control_values`` fields are always empty strings. The serialized gate
    id is ``"ADP"``.
    """

    def _gate_attr(attr_name):
        # cirq channels can't contain symbols, so the raw value is used as-is.
        return lambda op: getattr(op.gate, attr_name)

    def _no_controls(_):
        return ''

    probability_args = [
        op_serializer.SerializingArg(serialized_name=name,
                                     serialized_type=float,
                                     op_getter=_gate_attr(name))
        for name in ("p_x", "p_y", "p_z")
    ]
    control_args = [
        op_serializer.SerializingArg(serialized_name=name,
                                     serialized_type=str,
                                     op_getter=_no_controls)
        for name in ("control_qubits", "control_values")
    ]
    return op_serializer.GateOpSerializer(
        gate_type=cirq.AsymmetricDepolarizingChannel,
        serialized_gate_id="ADP",
        args=probability_args + control_args,
        can_serialize_predicate=_CONSTANT_TRUE)
def test_to_proto_callable(self, val_type, val, arg_value):
    """A callable op_getter is used to produce the serialized arg value."""
    my_val_arg = op_serializer.SerializingArg(serialized_name='my_val',
                                              serialized_type=val_type,
                                              op_getter=get_val)
    serializer = op_serializer.GateOpSerializer(gate_type=GateWithMethod,
                                                serialized_gate_id='my_gate',
                                                args=[my_val_arg])
    qubit = cirq.GridQubit(1, 2)
    operation = GateWithMethod(val)(qubit)

    actual = serializer.to_proto(operation, arg_function_language='linear')
    expected = op_proto({
        'gate': {'id': 'my_gate'},
        'args': {'my_val': arg_value},
        'qubits': [{'id': '1_2'}],
    })
    self.assertEqual(actual, expected)
def _reset_channel_serializer():
    """Build the serializer for ``cirq.ResetChannel`` (gate id ``"RST"``).

    The channel has no numeric parameters. Only the ``control_qubits`` and
    ``control_values`` fields are emitted, and both are always empty strings.
    """

    def _no_controls(_):
        return ''

    control_args = [
        op_serializer.SerializingArg(serialized_name=name,
                                     serialized_type=str,
                                     op_getter=_no_controls)
        for name in ("control_qubits", "control_values")
    ]
    return op_serializer.GateOpSerializer(
        gate_type=cirq.ResetChannel,
        serialized_gate_id="RST",
        args=control_args,
        can_serialize_predicate=_CONSTANT_TRUE)
def _bit_flip_channel_serializer():
    """Build the serializer for ``cirq.BitFlipChannel`` (gate id ``"BF"``).

    The flip probability ``p`` is serialized as a float. The
    ``control_qubits`` and ``control_values`` fields are always empty strings.
    """

    def _flip_probability(op):
        # cirq channels can't contain symbols, so the raw value is used as-is.
        return op.gate.p

    def _no_controls(_):
        return ''

    serializing_args = [
        op_serializer.SerializingArg(serialized_name="p",
                                     serialized_type=float,
                                     op_getter=_flip_probability)
    ]
    for control_field in ("control_qubits", "control_values"):
        serializing_args.append(
            op_serializer.SerializingArg(serialized_name=control_field,
                                         serialized_type=str,
                                         op_getter=_no_controls))
    return op_serializer.GateOpSerializer(
        gate_type=cirq.BitFlipChannel,
        serialized_gate_id="BF",
        args=serializing_args,
        can_serialize_predicate=_CONSTANT_TRUE)
def _phase_damp_channel_serializer():
    """Build the serializer for ``cirq.PhaseDampingChannel`` (id ``"PD"``).

    The damping parameter ``gamma`` is serialized as a float. The
    ``control_qubits`` and ``control_values`` fields are always empty strings.
    """

    def _damping(op):
        # cirq channels can't contain symbols, so the raw value is used as-is.
        return op.gate.gamma

    def _no_controls(_):
        return ''

    serializing_args = [
        op_serializer.SerializingArg(serialized_name="gamma",
                                     serialized_type=float,
                                     op_getter=_damping)
    ]
    for control_field in ("control_qubits", "control_values"):
        serializing_args.append(
            op_serializer.SerializingArg(serialized_name=control_field,
                                         serialized_type=str,
                                         op_getter=_no_controls))
    return op_serializer.GateOpSerializer(
        gate_type=cirq.PhaseDampingChannel,
        serialized_gate_id="PD",
        args=serializing_args,
        can_serialize_predicate=_CONSTANT_TRUE)
def _fsim_gate_serializer():
    """Build the serializer for ``cirq.FSimGate`` (gate id ``"FSIM"``).

    ``theta`` and ``phi`` are each serialized twice: once through
    ``_symbol_extractor`` (as ``theta`` / ``phi``) and once through
    ``_scalar_extractor`` (as ``theta_scalar`` / ``phi_scalar``). Control
    information comes from ``_serialize_controls`` and
    ``_serialize_control_vals``.
    """
    arg = op_serializer.SerializingArg

    def symbol_of(attr_name):
        return lambda op: _symbol_extractor(getattr(op.gate, attr_name))

    def scalar_of(attr_name):
        return lambda op: _scalar_extractor(getattr(op.gate, attr_name))

    fsim_args = [
        arg(serialized_name="theta",
            serialized_type=float,
            op_getter=symbol_of("theta")),
        arg(serialized_name="phi",
            serialized_type=float,
            op_getter=symbol_of("phi")),
        arg(serialized_name="theta_scalar",
            serialized_type=float,
            op_getter=scalar_of("theta")),
        arg(serialized_name="phi_scalar",
            serialized_type=float,
            op_getter=scalar_of("phi")),
        arg(serialized_name="control_qubits",
            serialized_type=str,
            op_getter=lambda op: _serialize_controls(op)),
        arg(serialized_name="control_values",
            serialized_type=str,
            op_getter=lambda op: _serialize_control_vals(op)),
    ]
    return op_serializer.GateOpSerializer(
        gate_type=cirq.FSimGate,
        serialized_gate_id="FSIM",
        args=fsim_args,
        can_serialize_predicate=_CONSTANT_TRUE)
def test_to_proto_required_but_not_present(self):
    """A required arg whose getter yields None raises ValueError."""

    def always_none(_):
        return None

    my_val_arg = op_serializer.SerializingArg(serialized_name='my_val',
                                              serialized_type=float,
                                              op_getter=always_none)
    serializer = op_serializer.GateOpSerializer(gate_type=GateWithProperty,
                                                serialized_gate_id='my_gate',
                                                args=[my_val_arg])
    qubit = cirq.GridQubit(1, 2)
    with self.assertRaisesRegex(ValueError, expected_regex='required'):
        serializer.to_proto(GateWithProperty(1.0)(qubit))
def test_to_proto_unsupported_type(self):
    """Declaring an arg as ``bytes`` is rejected during serialization."""
    bytes_arg = op_serializer.SerializingArg(serialized_name='my_val',
                                             serialized_type=bytes,
                                             op_getter='val')
    serializer = op_serializer.GateOpSerializer(gate_type=GateWithProperty,
                                                serialized_gate_id='my_gate',
                                                args=[bytes_arg])
    qubit = cirq.GridQubit(1, 2)
    with self.assertRaisesRegex(ValueError, expected_regex='bytes'):
        serializer.to_proto(GateWithProperty(b's')(qubit))
def test_to_proto_type_mismatch(self, val_type, val):
    """A value whose type disagrees with ``serialized_type`` raises."""
    typed_arg = op_serializer.SerializingArg(serialized_name='my_val',
                                             serialized_type=val_type,
                                             op_getter='val')
    serializer = op_serializer.GateOpSerializer(gate_type=GateWithProperty,
                                                serialized_gate_id='my_gate',
                                                args=[typed_arg])
    qubit = cirq.GridQubit(1, 2)
    # The error message is expected to mention the offending value's type.
    offending_type = str(type(val))
    with self.assertRaisesRegex(ValueError, expected_regex=offending_type):
        serializer.to_proto(GateWithProperty(val)(qubit))
def test_to_proto_no_getattr(self):
    """A string op_getter naming a missing attribute raises ValueError."""
    missing_attr_arg = op_serializer.SerializingArg(serialized_name='my_val',
                                                    serialized_type=float,
                                                    op_getter='nope')
    serializer = op_serializer.GateOpSerializer(gate_type=GateWithProperty,
                                                serialized_gate_id='my_gate',
                                                args=[missing_attr_arg])
    qubit = cirq.GridQubit(1, 2)
    with self.assertRaisesRegex(ValueError, expected_regex='does not have'):
        serializer.to_proto(GateWithProperty(1.0)(qubit))
def test_to_proto_gate_mismatch(self):
    """Serializing an op of an unrelated gate type raises ValueError."""
    val_arg = op_serializer.SerializingArg(serialized_name='my_val',
                                           serialized_type=float,
                                           op_getter='val')
    serializer = op_serializer.GateOpSerializer(gate_type=GateWithProperty,
                                                serialized_gate_id='my_gate',
                                                args=[val_arg])
    qubit = cirq.GridQubit(1, 2)
    # The message should name both the actual and the expected gate type.
    mismatch_pattern = 'GateWithAttribute.*GateWithProperty'
    with self.assertRaisesRegex(ValueError, expected_regex=mismatch_pattern):
        serializer.to_proto(GateWithAttribute(1.0)(qubit))
def test_can_serialize_operation_subclass(self):
    """Subclasses of the registered gate type are still checked/accepted."""
    val_arg = op_serializer.SerializingArg(serialized_name='my_val',
                                           serialized_type=float,
                                           op_getter='val')
    serializer = op_serializer.GateOpSerializer(
        gate_type=GateWithAttribute,
        serialized_gate_id='my_gate',
        args=[val_arg],
        can_serialize_predicate=lambda op: op.gate.val == 1)
    qubit = cirq.GridQubit(1, 1)

    accepted = serializer.can_serialize_operation(SubclassGate(1)(qubit))
    self.assertTrue(accepted)
    rejected = serializer.can_serialize_operation(SubclassGate(0)(qubit))
    self.assertFalse(rejected)
def test_is_supported_operation_can_serialize_predicate(self):
    """Gate-set support follows the serializer's can_serialize predicate."""
    qubit = cirq.GridQubit(1, 2)
    half_turns_arg = op_serializer.SerializingArg(
        serialized_name='half_turns',
        serialized_type=float,
        op_getter='exponent',
    )
    # Only full rotations (exponent == 1.0) are serializable.
    full_turn_serializer = op_serializer.GateOpSerializer(
        gate_type=cirq.XPowGate,
        serialized_gate_id='x_pow',
        args=[half_turns_arg],
        can_serialize_predicate=lambda op: op.gate.exponent == 1.0)
    gate_set = serializable_gate_set.SerializableGateSet(
        gate_set_name='my_gate_set',
        serializers=[full_turn_serializer],
        deserializers=[X_DESERIALIZER])

    is_supported = gate_set.is_supported_operation
    self.assertTrue(is_supported(cirq.XPowGate()(qubit)))
    self.assertFalse(is_supported(cirq.XPowGate()(qubit)**0.5))
    self.assertTrue(is_supported(cirq.X(qubit)))
def test_defaults_not_serialized(self):
    """Args whose value equals their declared default are left out."""
    val_arg = op_serializer.SerializingArg(serialized_name='my_val',
                                           serialized_type=float,
                                           default=1.0,
                                           op_getter='val')
    serializer = op_serializer.GateOpSerializer(gate_type=GateWithAttribute,
                                                serialized_gate_id='my_gate',
                                                args=[val_arg])
    qubit = cirq.GridQubit(1, 2)

    def expected_proto(args=None):
        # Build a fresh dict per call, in gate/args/qubits order.
        fields = {'gate': {'id': 'my_gate'}}
        if args is not None:
            fields['args'] = args
        fields['qubits'] = [{'id': '1_2'}]
        return op_proto(fields)

    no_default = expected_proto(
        {'my_val': {
            'arg_value': {
                'float_value': 0.125
            }
        }})
    self.assertEqual(no_default,
                     serializer.to_proto(GateWithAttribute(0.125)(qubit)))
    with_default = expected_proto()
    self.assertEqual(with_default,
                     serializer.to_proto(GateWithAttribute(1.0)(qubit)))
import tensorflow as tf
import cirq
from google.protobuf import json_format
from tensorflow_quantum.core.serialize import (op_serializer, op_deserializer,
                                               serializable_gate_set)
from tensorflow_quantum.core.proto import program_pb2

# Shared fixtures: map `cirq.XPowGate.exponent` to and from the serialized
# arg 'half_turns' under gate id 'x_pow'.
X_SERIALIZER = op_serializer.GateOpSerializer(
    serialized_gate_id='x_pow',
    gate_type=cirq.XPowGate,
    args=[
        op_serializer.SerializingArg(
            op_getter='exponent',
            serialized_name='half_turns',
            serialized_type=float,
        )
    ],
)

X_DESERIALIZER = op_deserializer.GateOpDeserializer(
    gate_constructor=cirq.XPowGate,
    serialized_gate_id='x_pow',
    args=[
        op_deserializer.DeserializingArg(
            constructor_arg_name='exponent',
            serialized_name='half_turns',
        )
    ],
)