Esempio n. 1
0
    def _args_from_proto(
            self, proto: v2.program_pb2.Operation, *,
            arg_function_language: str) -> Dict[str, arg_func_langs.ARG_LIKE]:
        """Build constructor keyword arguments from the args stored in `proto`.

        Each entry of `self._args` names a key of `proto.args`
        (`serialized_name`) and the constructor keyword it feeds
        (`constructor_arg_name`).

        Args:
            proto: The Operation proto whose `args` map is read. It is not
                modified.
            arg_function_language: The `arg_function_language` field from
                `Program.Language`, forwarded to
                `arg_func_langs.arg_from_proto`.

        Returns:
            A dict from constructor argument name to value. An arg missing
            from `proto.args` with a non-None default maps to that default.
            Otherwise the value is deserialized, passed through `value_func`
            when one is set, and omitted if the result is None.

        Raises:
            ValueError: If a required arg is absent from `proto.args` and has
                no default.
        """
        return_args = {}
        for arg in self._args:
            if arg.serialized_name in proto.args:
                arg_proto = proto.args[arg.serialized_name]
            else:
                # Compare against None so falsy defaults (0, False, '') are
                # still applied instead of being silently ignored.
                if arg.default is not None:
                    return_args[arg.constructor_arg_name] = arg.default
                    continue
                if arg.required:
                    raise ValueError(
                        f'Argument {arg.serialized_name} '
                        'not in deserializing args, but is required.')
                # Do not index `proto.args` with a missing key: protobuf maps
                # insert the key on lookup, which would mutate the caller's
                # proto. An empty Arg deserializes the same way.
                arg_proto = v2.program_pb2.Arg()

            value = arg_func_langs.arg_from_proto(
                arg_proto,
                arg_function_language=arg_function_language,
                required_arg_name=arg.serialized_name if arg.required else None,
            )

            if arg.value_func is not None:
                value = arg.value_func(value)

            if value is not None:
                return_args[arg.constructor_arg_name] = value
        return return_args
Esempio n. 2
0
def test_missing_required_arg():
    """Empty or malformed args produce the expected deserialization outcome.

    - An empty FloatArg with a required name reports that name as missing.
    - An empty Arg with a required name is an unrecognized argument type.
    - An Arg whose function type is unknown is rejected.
    - An empty Arg without a required name deserializes to None.
    """
    lang = 'exp'
    unknown_func_arg = v2.program_pb2.Arg(
        func=v2.program_pb2.ArgFunction(type='magic'))

    with pytest.raises(ValueError, match='blah is missing'):
        _ = float_arg_from_proto(
            v2.program_pb2.FloatArg(),
            arg_function_language=lang,
            required_arg_name='blah',
        )
    with pytest.raises(ValueError, match='unrecognized argument type'):
        _ = arg_from_proto(
            v2.program_pb2.Arg(),
            arg_function_language=lang,
            required_arg_name='blah',
        )
    with pytest.raises(ValueError, match='Unrecognized function type '):
        _ = arg_from_proto(
            unknown_func_arg,
            arg_function_language=lang,
            required_arg_name='blah',
        )

    result = arg_from_proto(v2.program_pb2.Arg(), arg_function_language=lang)
    assert result is None
Esempio n. 3
0
def test_double_value():
    """A double_value arg deserializes to the matching Python number.

    Only this direction is checked. For backwards compatibility, a Python
    float is serialized as a float_val, never a double_val, so double_val
    conversion is one-way.
    """
    arg_msg = v2.program_pb2.Arg()
    arg_msg.arg_value.double_value = 1.0
    assert arg_from_proto(arg_msg, arg_function_language='') == 1
Esempio n. 4
0
def test_correspondence(min_lang: str, value: ARG_LIKE,
                        proto: v2.program_pb2.Arg):
    """`value` and `proto` must round-trip in every language from `min_lang` on.

    Languages ordered before `min_lang` in LANGUAGE_ORDER must reject both
    serialization and deserialization.
    """
    expected_msg = v2.program_pb2.Arg()
    json_format.ParseDict(proto, expected_msg)
    cutoff = LANGUAGE_ORDER.index(min_lang)

    for position, language in enumerate(LANGUAGE_ORDER):
        if position >= cutoff:
            # Supported: both directions must agree with the fixture.
            round_tripped = arg_from_proto(expected_msg,
                                           arg_function_language=language)
            serialized = arg_to_proto(value, arg_function_language=language)
            serialized_dict = json_format.MessageToDict(
                serialized,
                including_default_value_fields=True,
                preserving_proto_field_name=True,
                use_integers_for_enums=True,
            )

            assert round_tripped == value
            assert serialized_dict == proto
            continue

        # Unsupported: both directions must fail.
        with pytest.raises(ValueError,
                           match='not supported by arg_function_language'):
            _ = arg_to_proto(value, arg_function_language=language)
        with pytest.raises(ValueError, match='Unrecognized function type'):
            _ = arg_from_proto(expected_msg, arg_function_language=language)
Esempio n. 5
0
def test_unsupported_function_language():
    """An unknown arg_function_language is rejected in both directions."""
    bogus_language = 'NEVER GONNAH APPEN'
    sym_a = sympy.Symbol('a')
    sym_b = sympy.Symbol('b')

    for expression in (sym_a + sym_b, 3 * sym_b):
        with pytest.raises(ValueError,
                           match='Unrecognized arg_function_language'):
            _ = arg_to_proto(expression, arg_function_language=bogus_language)

    add_msg = v2.program_pb2.Arg(
        func=v2.program_pb2.ArgFunction(
            type='add',
            args=[
                v2.program_pb2.Arg(symbol='a'),
                v2.program_pb2.Arg(symbol='b'),
            ],
        ))
    with pytest.raises(ValueError, match='Unrecognized arg_function_language'):
        _ = arg_from_proto(add_msg, arg_function_language=bogus_language)
Esempio n. 6
0
    def _deserialize_gate_op(
        self,
        operation_proto: v2.program_pb2.Operation,
        *,
        arg_function_language: str = '',
        constants: Optional[List[v2.program_pb2.Constant]] = None,
        deserialized_constants: Optional[List[Any]] = None,
    ) -> cirq.Operation:
        """Deserialize an Operation from a cirq_google.api.v2.Operation.

        Args:
            operation_proto: The cirq_google.api.v2.Operation proto to
                deserialize.
            arg_function_language: The `arg_function_language` field from
                `Program.Language`.
            constants: The list of Constant protos referenced by constant
                table indices in `operation_proto`. It is used to resolve
                `token_constant_index`.
            deserialized_constants: The deserialized contents of `constants`.
                It is used to resolve `qubit_constant_index`.

        Returns:
            The deserialized Operation.

        Raises:
            ValueError: If the operation cannot be deserialized. Causes
                include an unsupported gate type, a gate argument of an
                unexpected type, or a constant-table index that cannot be
                resolved.
        """

        def float_arg(arg_proto):
            # Shared settings for reading every float-valued gate parameter.
            return arg_func_langs.float_arg_from_proto(
                arg_proto,
                arg_function_language=arg_function_language,
                required_arg_name=None,
            )

        qubits = []
        for index in operation_proto.qubit_constant_index:
            # Fail loudly rather than silently dropping qubits, which would
            # build the gate on the wrong (possibly empty) set of qubits.
            if deserialized_constants is None or index >= len(deserialized_constants):
                raise ValueError(
                    f'Qubit constant index {index} cannot be resolved: '
                    'deserialized_constants is missing or too short.'
                )
            qubits.append(deserialized_constants[index])
        # Preserve previous functionality in case the constants table was not used.
        qubits.extend(v2.qubit_from_proto_id(q.id) for q in operation_proto.qubits)

        which_gate_type = operation_proto.WhichOneof('gate_value')

        # Gates whose only parameter is the `exponent` field of their sub-proto.
        single_exponent_gates = {
            'xpowgate': cirq.XPowGate,
            'ypowgate': cirq.YPowGate,
            'zpowgate': cirq.ZPowGate,
            'czpowgate': cirq.CZPowGate,
            'iswappowgate': cirq.ISwapPowGate,
        }

        if which_gate_type in single_exponent_gates:
            gate_proto = getattr(operation_proto, which_gate_type)
            gate_type = single_exponent_gates[which_gate_type]
            op = gate_type(exponent=float_arg(gate_proto.exponent))(*qubits)
            if which_gate_type == 'zpowgate' and gate_proto.is_physical_z:
                op = op.with_tags(PhysicalZTag())
        elif which_gate_type == 'phasedxpowgate':
            exponent = float_arg(operation_proto.phasedxpowgate.exponent)
            phase_exponent = float_arg(operation_proto.phasedxpowgate.phase_exponent)
            op = cirq.PhasedXPowGate(exponent=exponent, phase_exponent=phase_exponent)(*qubits)
        elif which_gate_type == 'phasedxzgate':
            x_exponent = float_arg(operation_proto.phasedxzgate.x_exponent)
            z_exponent = float_arg(operation_proto.phasedxzgate.z_exponent)
            axis_phase_exponent = float_arg(operation_proto.phasedxzgate.axis_phase_exponent)
            op = cirq.PhasedXZGate(
                x_exponent=x_exponent,
                z_exponent=z_exponent,
                axis_phase_exponent=axis_phase_exponent,
            )(*qubits)
        elif which_gate_type == 'fsimgate':
            theta = float_arg(operation_proto.fsimgate.theta)
            phi = float_arg(operation_proto.fsimgate.phi)
            if isinstance(theta, (int, float, sympy.Basic)) and isinstance(
                phi, (int, float, sympy.Basic)
            ):
                op = cirq.FSimGate(theta=theta, phi=phi)(*qubits)
            else:
                raise ValueError('theta and phi must be specified for FSimGate')
        elif which_gate_type == 'measurementgate':
            key = arg_func_langs.arg_from_proto(
                operation_proto.measurementgate.key,
                arg_function_language=arg_function_language,
                required_arg_name=None,
            )
            invert_mask = arg_func_langs.arg_from_proto(
                operation_proto.measurementgate.invert_mask,
                arg_function_language=arg_function_language,
                required_arg_name=None,
            )
            if isinstance(invert_mask, list) and isinstance(key, str):
                op = cirq.MeasurementGate(
                    num_qubits=len(qubits), key=key, invert_mask=tuple(invert_mask)
                )(*qubits)
            else:
                raise ValueError(f'Incorrect types for measurement gate {invert_mask} {key}')
        elif which_gate_type == 'waitgate':
            total_nanos = float_arg(operation_proto.waitgate.duration_nanos)
            op = cirq.WaitGate(duration=cirq.Duration(nanos=total_nanos))(*qubits)
        else:
            raise ValueError(
                f'Unsupported serialized gate with type "{which_gate_type}".'
                f'\n\noperation_proto:\n{operation_proto}'
            )

        which = operation_proto.WhichOneof('token')
        if which == 'token_constant_index':
            if not constants:
                raise ValueError(
                    'Proto has references to constants table '
                    'but none was passed in, value ='
                    f'{operation_proto}'
                )
            token_index = operation_proto.token_constant_index
            if token_index >= len(constants):
                raise ValueError(
                    f'Token constant index {token_index} is out of range for the '
                    f'constants table (length {len(constants)}).'
                )
            op = op.with_tags(CalibrationTag(constants[token_index].string_value))
        elif which == 'token_value':
            op = op.with_tags(CalibrationTag(operation_proto.token_value))

        return op
Esempio n. 7
0
    def from_proto(
        self,
        proto: v2.program_pb2.CircuitOperation,
        *,
        arg_function_language: str = '',
        constants: List[v2.program_pb2.Constant] = None,
        deserialized_constants: List[Any] = None,
    ) -> cirq.CircuitOperation:
        """Turns a cirq.google.api.v2.CircuitOperation proto into a CircuitOperation.

        Args:
            proto: The proto object to be deserialized.
            arg_function_language: The `arg_function_language` field from
                `Program.Language`.
            constants: The list of Constant protos referenced by constant
                table indices in `proto`. This list should already have been
                parsed to produce 'deserialized_constants'.
            deserialized_constants: The deserialized contents of `constants`.

        Returns:
            The deserialized CircuitOperation represented by `proto`.

        Raises:
            ValueError: If the circuit operation proto cannot be deserialized
                because it is malformed. Causes include a missing `constants`
                or `deserialized_constants`, and a circuit constant index that
                is out of range or does not refer to a FrozenCircuit. An
                `arg_map` key or value of an unsupported type is also rejected.
        """
        if constants is None or deserialized_constants is None:
            raise ValueError(
                'CircuitOp deserialization requires a constants list and a corresponding list of '
                'post-deserialization values (deserialized_constants).')
        if len(deserialized_constants) <= proto.circuit_constant_index:
            raise ValueError(
                f'Constant index {proto.circuit_constant_index} in CircuitOperation '
                'does not appear in the deserialized_constants list '
                f'(length {len(deserialized_constants)}).')
        circuit = deserialized_constants[proto.circuit_constant_index]
        if not isinstance(circuit, cirq.FrozenCircuit):
            raise ValueError(
                f'Constant at index {proto.circuit_constant_index} was expected to be a circuit, '
                f'but it has type {type(circuit)} in the deserialized_constants list.'
            )

        rep_spec = proto.repetition_specification
        which_rep_spec = rep_spec.WhichOneof('repetition_value')
        if which_rep_spec == 'repetition_count':
            rep_ids = None
            repetitions = rep_spec.repetition_count
        elif which_rep_spec == 'repetition_ids':
            # Copy out of the proto's repeated field so the resulting operation
            # does not alias (and observe later mutations of) the input proto.
            rep_ids = list(rep_spec.repetition_ids.ids)
            repetitions = len(rep_ids)
        else:
            # No repetition specification: run the circuit once.
            rep_ids = None
            repetitions = 1

        qubit_map = {
            v2.qubit_from_proto_id(entry.key.id):
            v2.qubit_from_proto_id(entry.value.id)
            for entry in proto.qubit_map.entries
        }
        measurement_key_map = {
            entry.key.string_key: entry.value.string_key
            for entry in proto.measurement_key_map.entries
        }
        arg_map = {
            arg_func_langs.arg_from_proto(
                entry.key, arg_function_language=arg_function_language):
            arg_func_langs.arg_from_proto(
                entry.value, arg_function_language=arg_function_language)
            for entry in proto.arg_map.entries
        }

        # All keys are validated before any value so error precedence is stable.
        for arg in arg_map:
            if not isinstance(arg, (str, sympy.Symbol)):
                raise ValueError(
                    'Invalid key parameter type in deserialized CircuitOperation. '
                    f'Expected str or sympy.Symbol, found {type(arg)}.'
                    f'\nFull arg: {arg}')

        for arg in arg_map.values():
            if not isinstance(arg, (str, sympy.Symbol, float, int)):
                raise ValueError(
                    'Invalid value parameter type in deserialized CircuitOperation. '
                    f'Expected str, sympy.Symbol, or number; found {type(arg)}.'
                    f'\nFull arg: {arg}')

        return cirq.CircuitOperation(
            circuit,
            repetitions,
            qubit_map,
            measurement_key_map,
            arg_map,  # type: ignore
            rep_ids,
        )