コード例 #1
0
def send_rpc(route: str, in_msg: message.Message, res_class: message.Message, server_id: str='') -> message.Message:
    """Sends an RPC to another Pitaya server and returns the decoded reply.

    Args:
        route: Route of the remote handler (UTF-8 encoded for the native call).
        in_msg: Protobuf request message; it is serialized and handed to
            ``LIB.tfg_pitc_RPC``.
        res_class: Protobuf message *class* used to decode the ``data`` field
            of the returned ``Response``.
        server_id: Optional id of the target server ('' by default).

    Returns:
        A new ``res_class`` instance merged from the response payload.

    Raises:
        TypeError: If ``in_msg`` is not a protobuf message instance or
            ``res_class`` is not a protobuf message class.
        Exception: If the native RPC call reports failure; the message holds
            the Pitaya error code and text.
    """
    if not isinstance(in_msg, message.Message) or not issubclass(res_class, message.Message):
        raise TypeError('in_msg must be a protobuf message instance and '
                        'res_class a protobuf message class')
    msg_bytes = in_msg.SerializeToString()
    msg_len = len(msg_bytes)
    # One C-level copy instead of splatting every byte as a separate
    # constructor argument.
    c_bytes = (c_char * msg_len).from_buffer_copy(msg_bytes)
    ret_ptr = POINTER(MemoryBuffer)()
    err = PitayaError()
    ok = LIB.tfg_pitc_RPC(server_id.encode('utf-8'), route.encode('utf-8'),
                          addressof(c_bytes), msg_len, byref(ret_ptr), byref(err))
    if not ok:
        exception_msg = "code: {} msg: {}".format(err.code, err.msg)
        LIB.tfg_pitc_FreePitayaError(err)
        raise Exception(exception_msg)
    try:
        buf = ret_ptr.contents
        # `.raw` returns all `size` bytes. `.value` would stop at the first
        # NUL byte and truncate any serialized protobuf containing 0x00.
        ret_bytes = (c_char * buf.size).from_address(buf.data).raw
        response = Response()
        response.MergeFromString(ret_bytes)
        result = res_class()
        result.MergeFromString(response.data)
    finally:
        # Release the native buffer even when decoding raises.
        LIB.tfg_pitc_FreeMemoryBuffer(ret_ptr)
    return result
コード例 #2
0
def get_contained_resource(
        contained_resource: message.Message) -> message.Message:
    """Returns the resource held by a `ContainedResource` wrapper message.

    Args:
      contained_resource: A message whose descriptor name is
        `ContainedResource`.

    Returns:
      The value of whichever member of the `oneof_resource` oneof is set.

    Raises:
      TypeError: If the descriptor name of `contained_resource` is not
        `ContainedResource`.
      ValueError: If no member of the `oneof_resource` oneof is set.
    """
    # TODO: Use an annotation here.
    descriptor_name = contained_resource.DESCRIPTOR.name
    if descriptor_name != 'ContainedResource':
        raise TypeError('Expected `ContainedResource` but got: '
                        f'{type(contained_resource)}.')

    populated_field = contained_resource.WhichOneof('oneof_resource')
    if populated_field is not None:
        return proto_utils.get_value_at_field(contained_resource,
                                              populated_field)
    raise ValueError('`ContainedResource` oneof not set.')
コード例 #3
0
  def _GetMethodUrlAndPathParamsNames(
      self,
      handler_name: str,
      args: message.Message,
  ) -> Tuple[str, str, Iterable[str]]:
    """Resolves the HTTP method, URL and path-parameter names for a handler.

    Args:
      handler_name: Endpoint name as known to `self.handlers_map`.
      args: Request message (may be falsy). Only set fields that the endpoint
        expects (per `handlers_map.is_endpoint_expecting`) become URL
        parameters.

    Returns:
      A tuple `(method, url, path_param_names)`:
      - `method` is an HTTP method name other than "HEAD", taken from the
        last rule whose endpoint equals `handler_name`.
      - `url` is the external URL produced by `self.urls.build`.
      - `path_param_names` lists the argument fields used to build `url`.

    Raises:
      RuntimeError: If no rule for `handler_name` offers a non-HEAD method.
    """
    path_params = {}  # Dict[str, Union[int, str]]
    if args:
      for field, value in args.ListFields():
        if self.handlers_map.is_endpoint_expecting(handler_name, field.name):
          path_params[field.name] = self._CoerceValueToQueryStringType(
              field, value)

    url = self.urls.build(handler_name, path_params, force_external=True)

    method = None
    for rule in self.handlers_map.iter_rules():
      if rule.endpoint != handler_name:
        continue
      # HEAD is not reported as the rule's method. A rule that offers only
      # HEAD (or declares no methods at all) is skipped rather than raising
      # IndexError/TypeError, so the RuntimeError below is raised instead.
      candidates = [m for m in (rule.methods or ()) if m != "HEAD"]
      if candidates:
        method = candidates[0]

    if not method:
      raise RuntimeError("Can't find method for %s" % handler_name)

    return method, url, list(path_params.keys())
コード例 #4
0
ファイル: _json_printer.py プロジェクト: anniyanvr/fhir
    def _print_reference(self, reference: message.Message) -> None:
        """Prints `reference`, rewriting typed references as URIs in PURE mode.

        When `self.json_format` is `_FhirJsonFormat.PURE` and the `reference`
        oneof holds a member other than `uri`, the method prints a copy of the
        message instead. In that copy, `uri.value` is set to
        `references.reference_to_string(reference)`. Otherwise the reference
        is printed as-is.

        Args:
          reference: The reference message to print.
        """
        pure_json = self.json_format == _FhirJsonFormat.PURE
        populated = reference.WhichOneof('reference')
        is_typed_reference = populated is not None and populated != 'uri'

        if not (pure_json and is_typed_reference):
            self._print_message(reference)
            return

        # Pure FHIR JSON has no structured references, so serialize the
        # structured reference into a FHIR uri string on a copy.
        uri_reference = copy.copy(reference)
        # Setting the uri field overwrites whichever oneof member was set.
        uri_field = proto_utils.get_value_at_field(uri_reference, 'uri')
        proto_utils.set_value_at_field(
            uri_field, 'value', references.reference_to_string(reference))
        self._print_message(uri_reference)
コード例 #5
0
def frame_encode(msg: Message) -> bytes:
    """Serializes `msg` and returns the COBS encoding of the wire bytes.

    NOTE(review): no 0x00 frame delimiter is appended here; confirm whether
    the caller or transport adds one.
    """
    return cobs.encode(msg.SerializeToString())
コード例 #6
0
ファイル: proto.py プロジェクト: namely/cos-python-sample
def get_field(msg: Message, field_name: str):
    '''Return the value of `field_name` on `msg` if it is set, else None.

    Presence is checked with `msg.HasField`. Per the protobuf docs, that
    raises ValueError for fields without presence tracking (e.g. repeated
    fields), so this is meant for singular message / optional fields.
    '''
    if msg.HasField(field_name):
        return getattr(msg, field_name)
    return None
コード例 #7
0
ファイル: pipeline_state.py プロジェクト: jay90099/tfx
def _base64_encode(msg: message.Message) -> str:
    """Serializes `msg` and returns its wire bytes as standard base64 text."""
    encoded = base64.b64encode(msg.SerializeToString())
    return encoded.decode('utf-8')
コード例 #8
0
ファイル: py_converters.py プロジェクト: kokizzu/CompilerGym
 def __call__(self, message: Message) -> Any:
     """Converts `message` with the converter chosen by its `type_id`.

     If `type_id` is set, the message goes to
     `self.conversion_map[message.type_id]`. A lookup failure propagates,
     presumably a KeyError if the map is a dict. Otherwise the message goes
     to `self.default_converter`.
     """
     if not message.HasField("type_id"):
         return self.default_converter(message)
     converter = self.conversion_map[message.type_id]
     return converter(message)
コード例 #9
0
 def encode(self, message: Message) -> bytes:
     """Returns the protobuf wire-format bytes of `message`."""
     return message.SerializeToString()
コード例 #10
0
def write_tfrecord_file(file_name: Text, proto: Message) -> None:
  """Writes `proto`, serialized, as the single record of a TFRecord file.

  Missing parent directories of `file_name` are created first.

  Args:
    file_name: Destination path of the TFRecord file.
    proto: Message written as one record via `SerializeToString()`.
  """
  dir_name = os.path.dirname(file_name)
  # A bare file name has no directory component (dirname returns ''), so
  # there is nothing to create; skip the call instead of passing ''.
  if dir_name:
    tf.gfile.MakeDirs(dir_name)
  with tf.python_io.TFRecordWriter(file_name) as writer:
    writer.write(proto.SerializeToString())
コード例 #11
0
ファイル: proto.py プロジェクト: eaugeas/hilo-tfx
def serialize(stream: io.RawIOBase, message: Message):
    """Writes the wire-format serialization of `message` to `stream`.

    `io.RawIOBase.write` may accept fewer bytes than requested, so this keeps
    writing the remainder until the whole payload has been written.

    Raises:
        BlockingIOError: If a `write` call accepts no bytes, e.g. a
            non-blocking raw stream whose `write` returns None.
    """
    remaining = memoryview(message.SerializeToString())
    while remaining:
        written = stream.write(remaining)
        if not written:
            raise BlockingIOError('stream accepted no bytes; '
                                  f'{len(remaining)} bytes left unwritten')
        remaining = remaining[written:]
コード例 #12
0
async def write_pbmsg(stream: asyncio.StreamWriter, pbmsg: PBMessage) -> None:
    """Writes `pbmsg` to `stream` as an unsigned-varint length prefix + payload.

    The message is serialized once, before any await. The length prefix
    therefore always matches the bytes actually sent, even if another task
    mutates `pbmsg` while the prefix is being written.
    """
    msg_bytes: bytes = pbmsg.SerializeToString()
    await write_unsigned_varint(stream, len(msg_bytes))
    stream.write(msg_bytes)
    # Respect the transport's flow control. Without drain() the write buffer
    # can grow without bound when the peer reads slowly.
    await stream.drain()
コード例 #13
0
async def read_pbmsg_safe(stream: anyio.abc.SocketStream, pbmsg: PBMessage) -> None:
    """Reads one varint-length-prefixed message from `stream` into `pbmsg`.

    `pbmsg` is filled in place via `ParseFromString`; nothing is returned.
    NOTE(review): the length prefix is not bounded here, so a peer can
    request an arbitrarily large read.
    """
    payload_len = await read_unsigned_varint(stream)
    payload = await stream.receive_exactly(payload_len)
    pbmsg.ParseFromString(payload)
コード例 #14
0
async def write_pbmsg(stream: anyio.abc.SocketStream, pbmsg: PBMessage) -> None:
    """Sends `pbmsg` on `stream` as an unsigned-varint length prefix + payload.

    The message is serialized once, before any await. The length prefix
    therefore always matches the bytes actually sent, even if another task
    mutates `pbmsg` while the prefix is being written.
    """
    msg_bytes: bytes = pbmsg.SerializeToString()
    await write_unsigned_varint(stream, len(msg_bytes))
    await stream.send_all(msg_bytes)
コード例 #15
0
 def write(self, message: Message):
     """Serializes `message`, logs the bytes at INFO, then sends them.

     The serialized bytes are passed unchanged to `self._raw_write`.
     NOTE(review): every payload is logged at INFO level; consider DEBUG if
     messages are large or sensitive.
     """
     payload = message.SerializeToString()
     logger.info("sending: %r", payload)
     self._raw_write(payload)
コード例 #16
0
def _attempt_parse(obj: ProtobufMessage, data: bytes) -> None:
    """Parses `data` into `obj` in place, normalizing decode failures.

    Raises:
        ParseError: If protobuf reports `data` as undecodable. The original
            `ProtobufDecodeError` is attached as `__cause__`.
    """
    try:
        obj.ParseFromString(data)
    except ProtobufDecodeError as exc:
        raise ParseError("Incorrect protobuf message") from exc
コード例 #17
0
def substitute_runtime_parameter(
    msg: message.Message, parameter_bindings: Mapping[str, types.Property]
) -> Mapping[str, types.Property]:
    """Substitutes runtime-parameter placeholders in `msg` with bound values.

    `msg` is modified in place.
    - Substitution happens only on `pipeline_pb2.Value` messages whose
      `value` oneof holds `runtime_parameter` or
      `structural_runtime_parameter`. When such a placeholder resolves to a
      non-None value, the `Value` is cleared and the resolved value is
      written into its `field_value`.
    - Any other message is walked recursively: singular sub-messages,
      repeated elements and map values.
    - A non-message input is a no-op and returns an empty dict.

    Args:
      msg: The message to change in place.
      parameter_bindings: A dict of parameter keys to parameter values that
        will be used to substitute the runtime parameter placeholder.

    Returns:
      A dict of runtime parameter names to their resolved values.
      - A `runtime_parameter` that resolves to None is still recorded (with
        value None), but its placeholder is left untouched.
      - For structural parameters, this dict is passed to
        `_get_structural_runtime_parameter_value`, which presumably records
        the parameters it resolves (confirm against that helper).
    """
    if not isinstance(msg, message.Message):
        # Reached via recursion for non-message values (e.g. scalar map
        # values); there is nothing to substitute.
        return {}

    parameters = {}
    # If the message is a pipeline_pb2.Value instance, try to substitute its
    # runtime-parameter placeholder using the parameter bindings.
    if isinstance(msg, pipeline_pb2.Value):
        value = cast(pipeline_pb2.Value, msg)
        # `which` is captured before Clear() below, so at most one of the two
        # branches runs.
        which = value.WhichOneof('value')
        if which == 'runtime_parameter':
            real_value = _get_runtime_parameter_value(value.runtime_parameter,
                                                      parameter_bindings)
            # Recorded even when unresolved (None).
            parameters[value.runtime_parameter.name] = real_value
            if real_value is None:
                # Leave the placeholder in place.
                return parameters
            # Drop the placeholder, then store the resolved value.
            value.Clear()
            data_types_utils.set_metadata_value(
                metadata_value=value.field_value, value=real_value)
        if which == 'structural_runtime_parameter':
            real_value = _get_structural_runtime_parameter_value(
                value.structural_runtime_parameter, parameter_bindings,
                parameters)
            if real_value is None:
                return parameters
            value.Clear()
            data_types_utils.set_metadata_value(
                metadata_value=value.field_value, value=real_value)

        return parameters

    # For other cases, recursively call into sub-messages if any. Results are
    # merged with dict.update, so a later occurrence of the same parameter
    # name overrides an earlier one.
    for field, sub_message in msg.ListFields():
        # No-op for non-message types.
        if field.type != descriptor.FieldDescriptor.TYPE_MESSAGE:
            continue
        # Map fields: the entry type carries the `map_entry` option. Recurse
        # into every value of the map.
        elif (field.message_type.has_options
              and field.message_type.GetOptions().map_entry):
            for key in sub_message:
                parameters.update(
                    substitute_runtime_parameter(sub_message[key],
                                                 parameter_bindings))
        # Evaluates every entry in a list.
        elif field.label == descriptor.FieldDescriptor.LABEL_REPEATED:
            for element in sub_message:
                parameters.update(
                    substitute_runtime_parameter(element, parameter_bindings))
        # Evaluates sub-message.
        else:
            parameters.update(
                substitute_runtime_parameter(sub_message, parameter_bindings))
    return parameters
コード例 #18
0
 def do_nothing(self, request, context):
     """Handler stub: ignores `request` and `context`, returns a new empty `Message`."""
     empty_response = Message()
     return empty_response
コード例 #19
0
 def make_command(self, message: Message) -> Command:
     """Make a command instance from the given message.

     The command is named after the message's concrete class, and its `data`
     holds the message's serialized bytes.
     """
     command_name = type(message).__name__
     payload = message.SerializeToString()
     return Command(name=command_name, data=payload)
コード例 #20
0
async def read_pbmsg_safe(stream: asyncio.StreamReader,
                          pbmsg: PBMessage) -> None:
    """Reads one varint-length-prefixed message from `stream` into `pbmsg`.

    `pbmsg` is filled in place via `ParseFromString`; nothing is returned.
    Per the asyncio docs, `readexactly` raises `asyncio.IncompleteReadError`
    if EOF arrives before the full payload.
    NOTE(review): the length prefix is not bounded here.
    """
    payload_len = await read_unsigned_varint(stream)
    payload = await stream.readexactly(payload_len)
    pbmsg.ParseFromString(payload)
コード例 #21
0
def test_unhandled_exception_rpc():
    """A handler wrapped by `rpc.unhandled_exception_rpc` returns its `Message` when it does not raise."""
    @rpc.unhandled_exception_rpc(Message)
    def do_nothing(self, request, context):
        return Message()

    result = do_nothing(MagicMock(), Message(), MagicMock())
    assert isinstance(result, Message)
コード例 #22
0
def copy_code(source: message.Message, target: message.Message):
    """Copies a FHIR Code value from `source` into `target`.

    - If both messages have the same type, `target.CopyFrom(source)` is used.
    - Otherwise the common `id` and `extension` fields are copied and the
      `value` field is transferred. Enum and string representations are
      converted when the two `value` fields differ in type.

    All descriptor/type checks run before `target` is modified.

    Args:
      source: The FHIR Code instance to copy from.
      target: The target FHIR Code instance to copy to.

    Raises:
      fhir_errors.InvalidFhirError: If either message is not a type or profile
        of Code, or either lacks a `value` field.
      ValueError: If either `value` field's type is not in `_CODE_TYPES`.
    """
    if not fhir_types.is_type_or_profile_of_code(source.DESCRIPTOR):
        raise fhir_errors.InvalidFhirError(
            f'Source: {source.DESCRIPTOR.full_name} '
            'is not type or profile of Code.')

    if not fhir_types.is_type_or_profile_of_code(target.DESCRIPTOR):
        raise fhir_errors.InvalidFhirError(
            f'Target: {target.DESCRIPTOR.full_name} '
            'is not type or profile of Code.')

    if proto_utils.are_same_message_type(source.DESCRIPTOR, target.DESCRIPTOR):
        target.CopyFrom(source)
        return

    source_value_field = source.DESCRIPTOR.fields_by_name.get('value')
    target_value_field = target.DESCRIPTOR.fields_by_name.get('value')
    if source_value_field is None or target_value_field is None:
        raise fhir_errors.InvalidFhirError(
            'Unable to copy code from '
            f'{source.DESCRIPTOR.full_name} '
            f'to {target.DESCRIPTOR.full_name}.')

    # Handle specialized codes. Validate the value-field types *before*
    # copying anything, so this error cannot leave `target` partially
    # modified.
    if (source_value_field.type not in _CODE_TYPES
            or target_value_field.type not in _CODE_TYPES):
        raise ValueError(
            f'Unable to copy from {source.DESCRIPTOR.full_name} '
            f'to {target.DESCRIPTOR.full_name}. Must have a field '
            'of TYPE_ENUM or TYPE_STRING.')

    proto_utils.copy_common_field(source, target, 'id')
    proto_utils.copy_common_field(source, target, 'extension')

    source_value = proto_utils.get_value_at_field(source, source_value_field)
    if source_value_field.type == target_value_field.type:
        # Perform a simple assignment if value_field types are equivalent.
        proto_utils.set_value_at_field(target, target_value_field,
                                       source_value)
    elif source_value_field.type == descriptor.FieldDescriptor.TYPE_STRING:
        # String code -> enum code on the target.
        source_enum_value = code_string_to_enum_value_descriptor(
            source_value, target_value_field.enum_type)
        proto_utils.set_value_at_field(target, target_value_field,
                                       source_enum_value.number)
    elif source_value_field.type == descriptor.FieldDescriptor.TYPE_ENUM:
        # Enum code -> string code on the target.
        source_string_value = enum_value_descriptor_to_code_string(
            source_value_field.enum_type.values_by_number[source_value])
        proto_utils.set_value_at_field(target, target_value_field,
                                       source_string_value)
    else:  # Should never hit given the _CODE_TYPES check above.
        raise ValueError('Unexpected generic value field type: '
                         f'{source_value_field.type}. Must be a field of '
                         'TYPE_ENUM or TYPE_STRING in order to copy.')
コード例 #23
0
def _serialize(message: Message) -> bytes:
    """Returns the wire-format bytes of `message` wrapped in a gzip stream."""
    return gzip.compress(message.SerializeToString())
コード例 #24
0
def _(message: Message) -> dict:
    """Converts `message` into plain Python data.

    - A message with a non-None `exprs` attribute becomes a list: each
      element of `to_dict(exprs)` is passed through `expr_to_obj`.
    - Any other message becomes a dict mapping each set field's name to
      `to_dict` of its value.

    NOTE(review): the `-> dict` annotation does not cover the list case; it
    is left as-is because `singledispatch.register` evaluates annotations.
    """
    exprs = getattr(message, 'exprs', None)
    if exprs is None:
        return {
            field_desc.name: to_dict(field_value)
            for field_desc, field_value in message.ListFields()
        }
    return [expr_to_obj(expr) for expr in to_dict(exprs)]
def _recurse_validate(
    message: Message,
    name: str,
    validators: List[AbstractArgumentValidator],
    leading_parts_name: str = None,
    is_optional: bool = False,
):
    """Validates the field reached by the dotted path `name` inside `message`.

    `name` is a dot-separated field path such as "a.b.c". Path rules:
    - A "[]" suffix is stripped from every segment.
    - On the final segment of a repeated field, "[]" makes the validators run
      on each element instead of on the whole container.
    - A repeated field in the middle of the path is always expanded element
      by element.
    - An empty `name` at the top level (`leading_parts_name` is None)
      validates `message` itself, using its message descriptor.

    Args:
      message: The message to inspect.
      name: Dot-separated field path relative to `message`.
      validators: Each value reached is passed to every validator's
        `check(name, value, descriptor)`.
      leading_parts_name: Path already consumed by outer recursion levels,
        used to build reported names like "a.b[2].c"; None at the top level.
      is_optional: If True, an unset singular message field on the path ends
        validation of that path without an error.

    Returns:
      A list of error entries, in this order of discovery:
      - the `invalid_reason` of every failed check;
      - "request must have <path>" for each unset singular message field on
        the path when `is_optional` is False.
    """
    errors = []
    # Peel off the first path segment; any remaining segments are handled by
    # recursion below.
    field_name_raw, *remaining_fields = name.split(".")
    field_name = field_name_raw.rstrip("[]")

    # Drop empty segments (e.g. from a trailing or doubled ".").
    remaining_fields = [f for f in remaining_fields if f != ""]

    if leading_parts_name is None and field_name == "":
        # Root case: validate `message` itself.
        # NOTE(review): here `field_descriptor` is a message Descriptor, not a
        # FieldDescriptor, so the `.label` lookups below (reached for a name
        # like ".x") presumably fail. Confirm the root case is only used with
        # an empty `name`.
        field_value = message
        field_descriptor: FieldDescriptor = message.DESCRIPTOR  # type: ignore
        full_name = message.DESCRIPTOR.name
    else:
        field_descriptor = message.DESCRIPTOR.fields_by_name[field_name]

        full_name = field_name if leading_parts_name is None else f"{leading_parts_name}.{field_name}"
        # An unset singular sub-message can be neither descended into nor
        # validated. Report it as missing unless the path is optional.
        if (field_descriptor.label != FieldDescriptor.LABEL_REPEATED
                and field_descriptor.type == FieldDescriptor.TYPE_MESSAGE
                and not message.HasField(field_name)):
            if is_optional:
                return []
            return [f"request must have {full_name}"]

        field_value = getattr(message, field_name)

    if remaining_fields:
        # More path to walk. A repeated field fans out to each element, and
        # the element index is recorded in the reported name.
        if field_descriptor.label == FieldDescriptor.LABEL_REPEATED:
            for i, elem in enumerate(field_value):  # type: ignore
                errors.extend(
                    _recurse_validate(
                        message=elem,
                        name=".".join(remaining_fields),
                        leading_parts_name=f"{full_name}[{i}]",
                        validators=validators,
                        is_optional=is_optional,
                    ))
        else:
            errors.extend(
                _recurse_validate(
                    message=field_value,
                    name=".".join(remaining_fields),
                    leading_parts_name=full_name,
                    validators=validators,
                    is_optional=is_optional,
                ))
    else:
        # End of path: run every validator. For "field[]" on a repeated
        # field, each element is checked individually; otherwise the whole
        # value is checked once.
        for v in validators:
            if field_name_raw.endswith(
                    "[]"
            ) and field_descriptor.label == FieldDescriptor.LABEL_REPEATED:
                for i, field_value_elem in enumerate(
                        field_value):  # type: ignore
                    validation_result = v.check(f"{full_name}[{i}]",
                                                field_value_elem,
                                                field_descriptor)
                    if not validation_result.valid:
                        errors.append(validation_result.invalid_reason)
            else:
                validation_result = v.check(full_name, field_value,
                                            field_descriptor)
                if not validation_result.valid:
                    errors.append(validation_result.invalid_reason)
    return errors