コード例 #1
0
def get_socket(
    wire: WireSpec, message_classes: Dict[str, Type[Message]], config: Message
) -> Optional[Socket]:
    """Build the ``Socket`` described by ``wire`` from its value in ``config``.

    The message class for ``wire.type`` is looked up in ``message_classes`` and
    scanned for its first message-typed field whose type is ``net_pb2.Socket``.
    Host and port are read from that field of ``config.<wire.name>``; the
    protocol is taken from a ``wire_pb2.Mark`` option on that field and
    defaults to ``wire_pb2.Mark.TCP``.

    Args:
        wire: Wire description; its ``name``, ``type`` and ``optional``
            attributes are read.
        message_classes: Message classes keyed by the names used in ``wire.type``.
        config: Message expected to carry a field named ``wire.name``.

    Returns:
        The resulting ``Socket``, or ``None`` if the wire is optional and unset
        in ``config``, or if its message type has no ``net_pb2.Socket`` field.

    Raises:
        AssertionError: If a non-optional wire is unset in ``config``.
        KeyError: If ``wire.type`` is missing from ``message_classes``.
    """
    if not config.HasField(wire.name):
        if wire.optional:
            return None
        # An explicit raise (instead of ``assert``) keeps this check active under
        # ``python -O``; the exception type is unchanged for existing callers.
        raise AssertionError(f"Config unset: {wire.name}")

    wire_type = message_classes[wire.type]
    socket_type_fd = next(
        (
            fd
            for fd in wire_type.DESCRIPTOR.fields
            if fd.type == FieldDescriptor.TYPE_MESSAGE
            and fd.message_type.full_name == net_pb2.Socket.DESCRIPTOR.full_name
        ),
        None,
    )
    if socket_type_fd is None:
        return None

    protocol = wire_pb2.Mark.TCP
    for _, option in socket_type_fd.GetOptions().ListFields():
        # Should more than one Mark option be listed, the last one wins.
        if isinstance(option, wire_pb2.Mark):
            protocol = option.protocol

    wire_value = getattr(config, wire.name)
    socket_value = getattr(wire_value, socket_type_fd.name)
    return Socket(host=socket_value.host, port=socket_value.port, _protocol=protocol)
コード例 #2
0
    def _from_proto(self, proto: message.Message) -> "BaseModelCardField":
        """Populate this object from ``proto`` and return ``self``.

        Every field declared on ``proto``'s descriptor must exist as an
        attribute on this object. Repeated fields always replace the current
        attribute value: message elements are converted one by one through
        ``_from_proto``, scalar elements are copied via slicing. Singular
        fields are only touched when ``proto.HasField`` reports them as set;
        message-typed ones are filled in place through the existing
        attribute's ``_from_proto``, scalar ones are assigned.

        Args:
          proto: Message to read from; must be an instance of
            ``self._proto_type``.

        Returns:
          ``self``.

        Raises:
          TypeError: If ``proto`` is not an instance of ``self._proto_type``.
          ValueError: If ``proto`` declares a field this object has no
            attribute for.
        """
        if not isinstance(proto, self._proto_type):
            raise TypeError("%s is expected. However %s is provided." %
                            (self._proto_type, type(proto)))

        for fd in proto.DESCRIPTOR.fields:
            name = fd.name
            if not hasattr(self, name):
                raise ValueError("%s has no such field named '%s.'" %
                                 (self, name))

            is_message = fd.type == descriptor.FieldDescriptor.TYPE_MESSAGE

            if fd.label == descriptor.FieldDescriptor.LABEL_REPEATED:
                if is_message:
                    # Start from an empty list so old elements are discarded.
                    setattr(self, name, [])
                    for item in getattr(proto, name):
                        # The element class is the first type argument of the
                        # attribute's List[...] annotation.
                        child = self.__annotations__[name].__args__[0]()  # pytype: disable=attribute-error
                        child._from_proto(item)  # pylint: disable=protected-access
                        getattr(self, name).append(child)
                else:
                    # Slicing copies the repeated container's values.
                    setattr(self, name, getattr(proto, name)[:])
                continue

            # Singular fields: leave the attribute alone when the proto lacks it.
            if not proto.HasField(name):
                continue
            if is_message:
                getattr(self, name)._from_proto(getattr(proto, name))  # pylint: disable=protected-access
            else:
                setattr(self, name, getattr(proto, name))

        return self
コード例 #3
0
ファイル: model_base.py プロジェクト: wanderer6994/noisicaa
    def set_value(
            self, instance: 'ObjectBase', pb: protobuf.Message, pool: 'Pool', value: OBJECT
    ) -> None:
        """Store a reference to ``value`` (by its ``id``) in field ``self.name`` of ``pb``.

        ``None`` clears the field and is only accepted when ``self.allow_none``
        is true; any other value is checked against ``self.otype``. Nothing
        happens if the stored id already equals the new one. Otherwise the
        previously referenced object is looked up in ``pool``, the field is
        updated, and ``instance.property_changed`` receives a
        ``PropertyValueChange`` unless ``instance.in_setup`` is true.

        Raises:
            ValueError: If ``value`` is ``None`` and ``self.allow_none`` is false.
        """
        if value is None:
            if not self.allow_none:
                raise ValueError("None not allowed for '%s'." % self.name)
        else:
            _checktype(value, self.otype)

        old_id = getattr(pb, self.name) if pb.HasField(self.name) else None
        new_id = None if value is None else value.id
        if new_id == old_id:
            # Same reference as before: no write, no notification.
            return

        old_value = None if old_id is None else pool[old_id]

        if value is not None:
            setattr(pb, self.name, value.id)
        else:
            pb.ClearField(self.name)

        if instance.in_setup:
            return
        instance.property_changed(PropertyValueChange(instance, self.name, old_value, value))
コード例 #4
0
ファイル: model_base.py プロジェクト: wanderer6994/noisicaa
 def get_value(self, instance: 'ObjectBase', pb: protobuf.Message, pool: 'Pool') -> PROTO:
     """Return field ``self.name`` of ``pb``, or a fallback when it is unset.

     Fallback order for an unset field: ``self.default`` if it is not
     ``None``, then ``None`` if ``self.allow_none`` is true.

     Raises:
         ValueNotSetError: If the field is unset, there is no default and
             ``None`` is not allowed.
     """
     if not pb.HasField(self.name):
         if self.default is not None:
             return self.default
         if not self.allow_none:
             raise ValueNotSetError("Value '%s' has not been set." % self.name)
         return None
     return getattr(pb, self.name)
コード例 #5
0
def _is_field_present(msg: message.Message,
                      field_name: str,
                      error_msgs: List[str],
                      error_prefix: str = '') -> bool:
  """Reports whether a field of a proto message is set.

  Args:
    msg: The message to inspect.
    field_name: Name of the field passed to ``msg.HasField``.
    error_msgs: Receives a "<prefix><field_name> field is missing" entry
      when the field is unset (mutated in place).
    error_prefix: Text prepended to the recorded error message.

  Returns:
    True if the field is set, False otherwise.
  """
  present = bool(msg.HasField(field_name))
  if not present:
    error_msgs.append(f'{error_prefix}{field_name} field is missing')
  return present
コード例 #6
0
ファイル: model_base.py プロジェクト: wanderer6994/noisicaa
 def get_value(self, instance: 'ObjectBase', pb: protobuf.Message, pool: 'Pool') -> OBJECT:
     """Resolve the object referenced by field ``self.name`` of ``pb``.

     The field holds an object id, which is looked up in ``pool``.

     Returns:
         The referenced object, or ``None`` if the field is unset and
         ``self.allow_none`` is true.

     Raises:
         InvalidReferenceError: If the stored id is not in ``pool``; the
             ``KeyError`` from the lookup is chained as its cause.
         ValueNotSetError: If the field is unset and ``None`` is not allowed.
     """
     if pb.HasField(self.name):
         obj_id = getattr(pb, self.name)
         try:
             obj = pool[obj_id]
         except KeyError as exc:
             # Chain explicitly so the failed lookup is reported as the direct
             # cause rather than as "another exception occurred while handling".
             raise InvalidReferenceError(
                 "%s.%s" % (type(instance).__name__, self.name)) from exc
         return cast(OBJECT, obj)
     if self.allow_none:
         return None
     raise ValueNotSetError("Value '%s' has not been set." % self.name)
コード例 #7
0
ファイル: _validate.py プロジェクト: vmagamedov/harness
 def dispatch(self, rules: Message) -> None:
     """Call ``self.visit_<field>(value)`` for every field set in ``rules``.

     Fields are visited in the order of ``self.rule_descriptor.fields``. A
     repeated field counts as set when its value is non-empty; a singular
     field when ``rules.HasField`` reports it. Unset fields are skipped.

     Raises:
         NotImplementedError: If a set field has no ``visit_<field>`` method;
             the message is the field's full name and the ``AttributeError``
             is chained as its cause.
     """
     for field in self.rule_descriptor.fields:
         if field.label == FieldDescriptor.LABEL_REPEATED:
             rule_value = getattr(rules, field.name)
             if not rule_value:
                 continue
         else:
             if not rules.HasField(field.name):
                 continue
             rule_value = getattr(rules, field.name)
         try:
             visit_fn = getattr(self, f"visit_{field.name}")
         except AttributeError as exc:
             # The missing visitor method is the direct cause; chain it explicitly.
             raise NotImplementedError(field.full_name) from exc
         visit_fn(rule_value)
コード例 #8
0
def field_content_length(msg: message.Message,
                         field: Union[descriptor.FieldDescriptor, str]) -> int:
    """Returns the size of the field.

    Args:
      msg: The Message whose fields to examine.
      field: The FieldDescriptor or name of the field to examine.

    Returns:
      For a repeated field, the number of elements it holds. For a singular
      field, 1 if the field is set and 0 if it is not.
    """
    fd = (_field_descriptor_for_name(msg, field)
          if isinstance(field, str) else field)

    if not field_is_repeated(fd):
        return 1 if msg.HasField(fd.name) else 0
    return len(getattr(msg, fd.name))
コード例 #9
0
ファイル: model_base.py プロジェクト: wanderer6994/noisicaa
    def set_value(
            self, instance: 'ObjectBase', pb: protobuf.Message, pool: 'Pool', value: VALUE
    ) -> None:
        """Assign ``value`` to scalar field ``self.name`` of ``pb``.

        The previous value is the field's current value when set, otherwise
        ``self.default``. ``None`` clears the field and is only accepted when
        ``self.allow_none`` is true. If the new value differs from the previous
        one and ``instance.in_setup`` is false, ``instance.property_changed``
        receives a ``PropertyValueChange``.

        Raises:
            ValueError: If ``value`` is ``None`` and ``self.allow_none`` is false.
        """
        field = self.name
        old_value = getattr(pb, field) if pb.HasField(field) else self.default

        if value is not None:
            setattr(pb, field, value)
        elif self.allow_none:
            pb.ClearField(field)
        else:
            raise ValueError("None not allowed for '%s'." % field)

        if value != old_value and not instance.in_setup:
            instance.property_changed(PropertyValueChange(instance, field, old_value, value))
コード例 #10
0
ファイル: model_base.py プロジェクト: wanderer6994/noisicaa
    def set_value(
            self, instance: 'ObjectBase', pb: protobuf.Message, pool: 'Pool', value: PROTOVAL
    ) -> None:
        """Write ``value`` into message field ``self.name`` of ``pb``.

        The previous value is decoded with ``self.ptype.from_proto`` when the
        field is set, otherwise ``self.default`` is used. ``None`` clears the
        field and is only accepted when ``self.allow_none`` is true. Any other
        value is checked against ``self.ptype`` and copied in via
        ``value.to_proto()``. If the new value differs from the previous one
        and ``instance.in_setup`` is false, ``instance.property_changed``
        receives a ``PropertyValueChange``.

        Raises:
            ValueError: If ``value`` is ``None`` and ``self.allow_none`` is false.
        """
        field = self.name
        if pb.HasField(field):
            old_value = self.ptype.from_proto(getattr(pb, field))
        else:
            old_value = self.default

        if value is not None:
            _checktype(value, self.ptype)
            getattr(pb, field).CopyFrom(value.to_proto())
        elif self.allow_none:
            pb.ClearField(field)
        else:
            raise ValueError("None not allowed for '%s'." % field)

        if value != old_value and not instance.in_setup:
            instance.property_changed(PropertyValueChange(instance, field, old_value, value))
コード例 #11
0
ファイル: proto.py プロジェクト: namely/cos-python-sample
def get_field(msg: Message, field_name):
    """Return the value of field ``field_name`` on ``msg``, or ``None`` if unset.

    Presence is decided by ``msg.HasField(field_name)``.
    """
    if not msg.HasField(field_name):
        return None
    # Use the getattr() builtin instead of invoking the __getattribute__ dunder.
    return getattr(msg, field_name)
コード例 #12
0
ファイル: py_converters.py プロジェクト: kokizzu/CompilerGym
 def __call__(self, message: Message) -> Any:
     """Convert ``message`` with the converter selected by its ``type_id``.

     If ``type_id`` is set, ``self.conversion_map[message.type_id]`` is used;
     otherwise ``self.default_converter`` handles the message.
     """
     if not message.HasField("type_id"):
         return self.default_converter(message)
     convert = self.conversion_map[message.type_id]
     return convert(message)
def _recurse_validate(
    message: Message,
    name: str,
    validators: List[AbstractArgumentValidator],
    leading_parts_name: str = None,
    is_optional: bool = False,
):
    """Resolve the dotted field path ``name`` in ``message`` and run ``validators``.

    ``name`` is a ``.``-separated path of field names relative to ``message``;
    empty segments after the first are ignored. A repeated field in the middle
    of the path is descended into element by element. On the last segment, a
    ``[]`` suffix on a repeated field makes each validator check every element
    separately instead of the whole container. At the top level
    (``leading_parts_name is None``) an empty first segment addresses
    ``message`` itself.

    Args:
        message: Message to start resolving from.
        name: Dotted field path, e.g. ``"items.id"`` or ``"tags[]"``.
        validators: Objects whose ``check(name, value, descriptor)`` result
            exposes ``valid`` and ``invalid_reason``.
        leading_parts_name: Path prefix already resolved by the caller, used
            to build the names in error messages; ``None`` at the top level.
        is_optional: When true, an unset singular message field on the path
            yields no errors instead of a "request must have ..." error.

    Returns:
        List of error strings; empty when every check passed.
    """
    errors = []
    field_name_raw, *remaining_fields = name.split(".")
    field_name = field_name_raw.rstrip("[]")

    remaining_fields = [f for f in remaining_fields if f != ""]

    if leading_parts_name is None and field_name == "":
        # The path addresses the root message itself. Its DESCRIPTOR is a
        # message Descriptor, which has no ``label`` attribute, so it must
        # never be tested for LABEL_REPEATED.
        field_value = message
        field_descriptor: FieldDescriptor = message.DESCRIPTOR  # type: ignore
        full_name = message.DESCRIPTOR.name
        is_repeated = False
    else:
        field_descriptor = message.DESCRIPTOR.fields_by_name[field_name]

        full_name = field_name if leading_parts_name is None else f"{leading_parts_name}.{field_name}"
        is_repeated = field_descriptor.label == FieldDescriptor.LABEL_REPEATED
        if (not is_repeated
                and field_descriptor.type == FieldDescriptor.TYPE_MESSAGE
                and not message.HasField(field_name)):
            # An unset sub-message can be neither descended into nor validated.
            if is_optional:
                return []
            return [f"request must have {full_name}"]

        field_value = getattr(message, field_name)

    if remaining_fields:
        rest = ".".join(remaining_fields)
        if is_repeated:
            # Fan out: resolve the rest of the path in every element.
            for i, elem in enumerate(field_value):  # type: ignore
                errors.extend(
                    _recurse_validate(
                        message=elem,
                        name=rest,
                        leading_parts_name=f"{full_name}[{i}]",
                        validators=validators,
                        is_optional=is_optional,
                    ))
        else:
            errors.extend(
                _recurse_validate(
                    message=field_value,
                    name=rest,
                    leading_parts_name=full_name,
                    validators=validators,
                    is_optional=is_optional,
                ))
        return errors

    # Loop-invariant: whether a trailing "[]" expands into per-element checks.
    per_element = field_name_raw.endswith("[]") and is_repeated
    for v in validators:
        if per_element:
            for i, field_value_elem in enumerate(field_value):  # type: ignore
                validation_result = v.check(f"{full_name}[{i}]",
                                            field_value_elem,
                                            field_descriptor)
                if not validation_result.valid:
                    errors.append(validation_result.invalid_reason)
        else:
            validation_result = v.check(full_name, field_value,
                                        field_descriptor)
            if not validation_result.valid:
                errors.append(validation_result.invalid_reason)
    return errors