def test_from_proto(self, val_type, val, arg_value):
    """Deserialize `arg_value` for 'my_val' and compare against `val`.

    The serialized op carries `arg_value` under the 'my_val' argument and is
    deserialized with the 'linear' arg function language. The result must
    equal `GateWithAttribute(val)` applied to grid qubit (1, 2).
    `val_type` is accepted but not read by this test body.
    """
    my_val_arg = op_deserializer.DeserializingArg(
        serialized_name='my_val',
        constructor_arg_name='val',
    )
    deserializer = op_deserializer.GateOpDeserializer(
        serialized_gate_id='my_gate',
        gate_constructor=GateWithAttribute,
        args=[my_val_arg])

    proto_fields = {
        'gate': {
            'id': 'my_gate'
        },
        'args': {
            'my_val': arg_value
        },
        'qubits': [{
            'id': '1_2'
        }]
    }
    serialized = op_proto(proto_fields)

    target_qubit = cirq.GridQubit(1, 2)
    actual_op = deserializer.from_proto(serialized,
                                        arg_function_language='linear')
    self.assertEqual(actual_op, GateWithAttribute(val)(target_qubit))
def _fsim_gate_deserializer():
    """Build the deserializer registered under the "FSIM" gate id.

    Returns:
        An `op_deserializer.GateOpDeserializer` whose gate constructor
        combines the serialized angles with their scalars before creating a
        `cirq.FSimGate`.
    """

    def _scalar_combiner(theta, theta_scalar, phi, phi_scalar, control_qubits,
                         control_values):
        """Multiply each rounded angle by its rounded scalar, then build.

        This is a workaround to support symbol scalar multiplication.
        See `_eigen_gate_deserializer` for details.
        """
        scaled_theta = _round(theta) * _round(theta_scalar)
        scaled_phi = _round(phi) * _round(phi_scalar)
        fsim_gate = cirq.FSimGate(theta=scaled_theta, phi=scaled_phi)
        return _optional_control_promote(fsim_gate, control_qubits,
                                         control_values)

    # Every argument is deserialized under the same name it is passed to
    # `_scalar_combiner` with.
    arg_names = (
        "theta",
        "phi",
        "theta_scalar",
        "phi_scalar",
        "control_qubits",
        "control_values",
    )
    args = [
        op_deserializer.DeserializingArg(serialized_name=name,
                                         constructor_arg_name=name)
        for name in arg_names
    ]
    return op_deserializer.GateOpDeserializer(
        serialized_gate_id="FSIM",
        gate_constructor=_scalar_combiner,
        args=args)
def test_from_proto_required_arg_not_assigned(self):
    """A required arg that is present but empty must raise ValueError."""
    required_arg = op_deserializer.DeserializingArg(
        serialized_name='my_val',
        constructor_arg_name='val',
    )
    optional_arg = op_deserializer.DeserializingArg(
        serialized_name='not_req',
        constructor_arg_name='not_req',
        required=False)
    deserializer = op_deserializer.GateOpDeserializer(
        serialized_gate_id='my_gate',
        gate_constructor=GateWithAttribute,
        args=[required_arg, optional_arg])

    # 'my_val' is listed in the proto but carries no value at all.
    proto_fields = {
        'gate': {
            'id': 'my_gate'
        },
        'args': {
            'my_val': {}
        },
        'qubits': [{
            'id': '1_2'
        }]
    }
    serialized = op_proto(proto_fields)

    with self.assertRaises(ValueError):
        deserializer.from_proto(serialized)
def test_defaults(self):
    """Args absent from the proto fall back to their declared defaults."""
    val_arg = op_deserializer.DeserializingArg(serialized_name='my_val',
                                               constructor_arg_name='val',
                                               default=1.0)
    not_req_arg = op_deserializer.DeserializingArg(
        serialized_name='not_req',
        constructor_arg_name='not_req',
        default='hello',
        required=False)
    deserializer = op_deserializer.GateOpDeserializer(
        serialized_gate_id='my_gate',
        gate_constructor=GateWithAttribute,
        args=[val_arg, not_req_arg])

    # The proto supplies no args, so both values must come from defaults.
    proto_fields = {
        'gate': {
            'id': 'my_gate'
        },
        'args': {},
        'qubits': [{
            'id': '1_2'
        }]
    }
    serialized = op_proto(proto_fields)

    expected_gate = GateWithAttribute(1.0)
    expected_gate.not_req = 'hello'

    actual_op = deserializer.from_proto(serialized)
    self.assertEqual(actual_op, expected_gate(cirq.GridQubit(1, 2)))
def test_from_proto_not_required_ok(self):
    """Omitting an arg marked `required=False` still deserializes."""
    required_arg = op_deserializer.DeserializingArg(
        serialized_name='my_val',
        constructor_arg_name='val',
    )
    optional_arg = op_deserializer.DeserializingArg(
        serialized_name='not_req',
        constructor_arg_name='not_req',
        required=False)
    deserializer = op_deserializer.GateOpDeserializer(
        serialized_gate_id='my_gate',
        gate_constructor=GateWithAttribute,
        args=[required_arg, optional_arg])

    # Only the required 'my_val' arg is present in the proto.
    proto_fields = {
        'gate': {
            'id': 'my_gate'
        },
        'args': {
            'my_val': {
                'arg_value': {
                    'float_value': 0.125
                }
            }
        },
        'qubits': [{
            'id': '1_2'
        }]
    }
    serialized = op_proto(proto_fields)

    target_qubit = cirq.GridQubit(1, 2)
    actual_op = deserializer.from_proto(serialized)
    self.assertEqual(actual_op, GateWithAttribute(0.125)(target_qubit))
def test_from_proto_missing_required_arg(self):
    """Leaving out the required arg must raise ValueError."""
    required_arg = op_deserializer.DeserializingArg(
        serialized_name='my_val',
        constructor_arg_name='val',
    )
    optional_arg = op_deserializer.DeserializingArg(
        serialized_name='not_req',
        constructor_arg_name='not_req',
        required=False)
    deserializer = op_deserializer.GateOpDeserializer(
        serialized_gate_id='my_gate',
        gate_constructor=GateWithAttribute,
        args=[required_arg, optional_arg])

    # Only the optional 'not_req' arg is supplied; 'my_val' is absent.
    proto_fields = {
        'gate': {
            'id': 'my_gate'
        },
        'args': {
            'not_req': {
                'arg_value': {
                    'float_value': 0.125
                }
            }
        },
        'qubits': [{
            'id': '1_2'
        }]
    }
    serialized = op_proto(proto_fields)

    with self.assertRaises(ValueError):
        deserializer.from_proto(serialized)
def test_from_proto_value_func(self):
    """A `value_func` is applied to the deserialized value (0.125 -> 1.125)."""
    incremented_arg = op_deserializer.DeserializingArg(
        serialized_name='my_val',
        constructor_arg_name='val',
        value_func=lambda x: x + 1)
    deserializer = op_deserializer.GateOpDeserializer(
        serialized_gate_id='my_gate',
        gate_constructor=GateWithAttribute,
        args=[incremented_arg])

    proto_fields = {
        'gate': {
            'id': 'my_gate'
        },
        'args': {
            'my_val': {
                'arg_value': {
                    'float_value': 0.125
                }
            }
        },
        'qubits': [{
            'id': '1_2'
        }]
    }
    serialized = op_proto(proto_fields)

    target_qubit = cirq.GridQubit(1, 2)
    actual_op = deserializer.from_proto(serialized)
    self.assertEqual(actual_op, GateWithAttribute(1.125)(target_qubit))
def test_from_proto_value_type_not_recognized(self):
    """An empty `arg_value` must fail with 'Unrecognized value type'."""
    my_val_arg = op_deserializer.DeserializingArg(
        serialized_name='my_val',
        constructor_arg_name='val',
    )
    deserializer = op_deserializer.GateOpDeserializer(
        serialized_gate_id='my_gate',
        gate_constructor=GateWithAttribute,
        args=[my_val_arg])

    # 'arg_value' is set but holds none of the concrete value fields.
    proto_fields = {
        'gate': {
            'id': 'my_gate'
        },
        'args': {
            'my_val': {
                'arg_value': {},
            }
        },
        'qubits': [{
            'id': '1_2'
        }]
    }
    serialized = op_proto(proto_fields)

    with self.assertRaisesRegex(ValueError,
                                expected_regex='Unrecognized value type'):
        deserializer.from_proto(serialized)
def test_from_proto_required_missing(self):
    """The error for a missing required arg must mention 'my_val'."""
    my_val_arg = op_deserializer.DeserializingArg(
        serialized_name='my_val',
        constructor_arg_name='val',
    )
    deserializer = op_deserializer.GateOpDeserializer(
        serialized_gate_id='my_gate',
        gate_constructor=GateWithAttribute,
        args=[my_val_arg])

    # The proto only carries 'not_my_val', which the deserializer does not
    # declare; the declared 'my_val' is absent.
    proto_fields = {
        'gate': {
            'id': 'my_gate'
        },
        'args': {
            'not_my_val': {
                'arg_value': {
                    'float_value': 0.125
                }
            }
        },
        'qubits': [{
            'id': '1_2'
        }]
    }
    serialized = op_proto(proto_fields)

    with self.assertRaisesRegex(Exception, expected_regex='my_val'):
        deserializer.from_proto(serialized)
def _bit_flip_channel_deserializer():
    """Build the deserializer registered under the "BF" gate id.

    Returns:
        An `op_deserializer.GateOpDeserializer` configured with
        `cirq.BitFlipChannel` as its gate constructor and a single
        argument, "p".
    """
    p_arg = op_deserializer.DeserializingArg(serialized_name="p",
                                             constructor_arg_name="p")
    return op_deserializer.GateOpDeserializer(
        serialized_gate_id="BF",
        gate_constructor=cirq.BitFlipChannel,
        args=[p_arg])
def _phase_damp_channel_deserializer():
    """Build the deserializer registered under the "PD" gate id.

    Returns:
        An `op_deserializer.GateOpDeserializer` configured with
        `cirq.PhaseDampingChannel` as its gate constructor and a single
        argument, "gamma".
    """
    gamma_arg = op_deserializer.DeserializingArg(
        serialized_name="gamma", constructor_arg_name="gamma")
    return op_deserializer.GateOpDeserializer(
        serialized_gate_id="PD",
        gate_constructor=cirq.PhaseDampingChannel,
        args=[gamma_arg])
def _amplitude_damp_channel_deserializer():
    """Make standard deserializer for amplitude damping channel.

    Returns:
        An `op_deserializer.GateOpDeserializer` registered under the "AD"
        gate id, configured with `cirq.AmplitudeDampingChannel` as its gate
        constructor and a single argument, "gamma".
    """
    args = [
        op_deserializer.DeserializingArg(serialized_name="gamma",
                                         constructor_arg_name="gamma")
    ]
    return op_deserializer.GateOpDeserializer(
        serialized_gate_id="AD",
        gate_constructor=cirq.AmplitudeDampingChannel,
        args=args)
def _depolarize_channel_deserializer():
    """Build the deserializer registered under the "DP" gate id.

    Returns:
        An `op_deserializer.GateOpDeserializer` configured with
        `cirq.DepolarizingChannel` as its gate constructor and a single
        argument, "p".
    """
    p_arg = op_deserializer.DeserializingArg(serialized_name="p",
                                             constructor_arg_name="p")
    return op_deserializer.GateOpDeserializer(
        serialized_gate_id="DP",
        gate_constructor=cirq.DepolarizingChannel,
        args=[p_arg])
def _gad_channel_deserializer():
    """Build the deserializer registered under the "GAD" gate id.

    Returns:
        An `op_deserializer.GateOpDeserializer` configured with
        `cirq.GeneralizedAmplitudeDampingChannel` as its gate constructor and
        the arguments "p" and "gamma", in that order.
    """
    # Serialized and constructor argument names are identical for both args.
    args = [
        op_deserializer.DeserializingArg(serialized_name=name,
                                         constructor_arg_name=name)
        for name in ("p", "gamma")
    ]
    return op_deserializer.GateOpDeserializer(
        serialized_gate_id="GAD",
        gate_constructor=cirq.GeneralizedAmplitudeDampingChannel,
        args=args)
def _phased_eigen_gate_deserializer(gate_type, serialized_id):
    """Build a deserializer for a phased eigen gate type.

    Args:
        gate_type: Callable used to construct the gate; it is invoked with
            `exponent` and `phase_exponent` keywords, plus `global_shift`
            when the deserialized shift is non-zero.
        serialized_id: Gate id the returned deserializer is registered under.

    Returns:
        An `op_deserializer.GateOpDeserializer` whose gate constructor folds
        the scalar arguments into their exponents before building the gate.
    """

    def _scalar_combiner(exponent, global_shift, exponent_scalar,
                         phase_exponent, phase_exponent_scalar, control_qubits,
                         control_values):
        """This is a workaround to support symbol scalar multiplication.

        In the future we should likely get rid of this in favor of proper
        expression parsing once cirq supports it. See cirq.op_serializer
        and cirq's program protobuf for details. This is needed for things
        like cirq.rx('alpha').
        """
        exponent = _round(exponent)
        phase_exponent = _round(phase_exponent)

        # A scalar of exactly 1.0 leaves the rounded value untouched.
        if exponent_scalar != 1.0:
            exponent = exponent * _round(exponent_scalar)
        if phase_exponent_scalar != 1.0:
            phase_exponent = phase_exponent * _round(phase_exponent_scalar)

        gate_kwargs = {"exponent": exponent}
        if global_shift != 0:
            # Only forwarded when non-zero: needed in case this specific
            # phasedeigengate doesn't have a global_phase in constructor.
            gate_kwargs["global_shift"] = _round(global_shift)
        gate_kwargs["phase_exponent"] = phase_exponent

        return _optional_control_promote(gate_type(**gate_kwargs),
                                         control_qubits, control_values)

    # Every argument is deserialized under the same name it is passed to
    # `_scalar_combiner` with.
    arg_names = (
        "phase_exponent",
        "phase_exponent_scalar",
        "exponent",
        "exponent_scalar",
        "global_shift",
        "control_qubits",
        "control_values",
    )
    args = [
        op_deserializer.DeserializingArg(serialized_name=name,
                                         constructor_arg_name=name)
        for name in arg_names
    ]
    return op_deserializer.GateOpDeserializer(
        serialized_gate_id=serialized_id,
        gate_constructor=_scalar_combiner,
        args=args)
def _asymmetric_depolarize_deserializer():
    """Build the deserializer registered under the "ADP" gate id.

    Returns:
        An `op_deserializer.GateOpDeserializer` configured with
        `cirq.AsymmetricDepolarizingChannel` as its gate constructor and the
        arguments "p_x", "p_y" and "p_z", in that order.
    """
    # Serialized and constructor argument names are identical for each arg.
    args = [
        op_deserializer.DeserializingArg(serialized_name=name,
                                         constructor_arg_name=name)
        for name in ("p_x", "p_y", "p_z")
    ]
    return op_deserializer.GateOpDeserializer(
        serialized_gate_id="ADP",
        gate_constructor=cirq.AsymmetricDepolarizingChannel,
        args=args)
def _identity_gate_deserializer():
    """Build the deserializer registered under the "I" gate id.

    Returns:
        An `op_deserializer.GateOpDeserializer` whose gate constructor
        ignores the "unused" argument and passes `cirq.I` together with the
        control arguments to `_optional_control_promote`.
    """

    def _cirq_i_workaround(unused, control_qubits, control_values):
        # `unused` is accepted only because it is a declared argument; it
        # plays no part in building the gate.
        return _optional_control_promote(cirq.I, control_qubits,
                                         control_values)

    args = [
        op_deserializer.DeserializingArg(serialized_name=name,
                                         constructor_arg_name=name)
        for name in ("unused", "control_qubits", "control_values")
    ]
    return op_deserializer.GateOpDeserializer(
        serialized_gate_id="I",
        gate_constructor=_cirq_i_workaround,
        args=args)
def test_from_proto_unknown_function(self):
    """A func arg with an unknown type fails: 'Unrecognized function type'."""
    my_val_arg = op_deserializer.DeserializingArg(
        serialized_name='my_val',
        constructor_arg_name='val',
    )
    deserializer = op_deserializer.GateOpDeserializer(
        serialized_gate_id='my_gate',
        gate_constructor=GateWithAttribute,
        args=[my_val_arg])

    # The function payload is well-formed apart from its 'type'.
    unknown_func = {
        'type': 'UNKNOWN_OPERATION',
        'args': [
            {
                'symbol': 'x'
            },
            {
                'arg_value': {
                    'float_value': -1.0
                }
            },
        ]
    }
    proto_fields = {
        'gate': {
            'id': 'my_gate'
        },
        'args': {
            'my_val': {
                'func': unknown_func
            }
        },
        'qubits': [{
            'id': '1_2'
        }]
    }
    serialized = op_proto(proto_fields)

    with self.assertRaisesRegex(ValueError,
                                expected_regex='Unrecognized function type'):
        deserializer.from_proto(serialized)
def test_from_proto_function_argument_not_set(self):
    """A 'mul' func with an empty operand must raise the missing-arg error."""
    my_val_arg = op_deserializer.DeserializingArg(
        serialized_name='my_val',
        constructor_arg_name='val',
    )
    deserializer = op_deserializer.GateOpDeserializer(
        serialized_gate_id='my_gate',
        gate_constructor=GateWithAttribute,
        args=[my_val_arg])

    # The second multiplication operand is an empty message.
    mul_func = {
        'type': 'mul',
        'args': [
            {
                'symbol': 'x'
            },
            {},
        ]
    }
    proto_fields = {
        'gate': {
            'id': 'my_gate'
        },
        'args': {
            'my_val': {
                'func': mul_func
            }
        },
        'qubits': [{
            'id': '1_2'
        }]
    }
    serialized = op_proto(proto_fields)

    with self.assertRaisesRegex(
            ValueError,
            expected_regex='A multiplication argument is missing'):
        deserializer.from_proto(serialized, arg_function_language='linear')
gate_type=cirq.XPowGate, serialized_gate_id='x_pow', args=[ op_serializer.SerializingArg( serialized_name='half_turns', serialized_type=float, op_getter='exponent', ) ], ) X_DESERIALIZER = op_deserializer.GateOpDeserializer( serialized_gate_id='x_pow', gate_constructor=cirq.XPowGate, args=[ op_deserializer.DeserializingArg( serialized_name='half_turns', constructor_arg_name='exponent', ) ], ) Y_SERIALIZER = op_serializer.GateOpSerializer( gate_type=cirq.YPowGate, serialized_gate_id='y_pow', args=[ op_serializer.SerializingArg( serialized_name='half_turns', serialized_type=float, op_getter='exponent', ) ],
def _reset_channel_deserializer():
    """Build the deserializer registered under the "RST" gate id.

    Returns:
        An `op_deserializer.GateOpDeserializer` configured with
        `cirq.ResetChannel` as its gate constructor and an empty argument
        list.
    """
    return op_deserializer.GateOpDeserializer(
        serialized_gate_id="RST",
        gate_constructor=cirq.ResetChannel,
        args=[])