Beispiel #1
0
 def _build_executable_spec(
     self, node_id: str,
     spec: any_pb2.Any) -> local_deployment_config_pb2.ExecutableSpec:
   """Builds ExecutableSpec given the any proto from IntermediateDeploymentConfig.

   The candidate fields of ExecutableSpec are tried in a fixed order
   (python_class_executable_spec, then container_executable_spec); `spec` is
   unpacked into the first field whose message type it matches.

   Args:
     node_id: Id of the node; only used in the error message.
     spec: Any proto expected to hold one of the ExecutableSpec spec types.

   Returns:
     A new ExecutableSpec with the matching field populated.

   Raises:
     ValueError: If `spec` packs neither of the supported types.
   """
   executable_spec = local_deployment_config_pb2.ExecutableSpec()
   for field_name in ('python_class_executable_spec',
                      'container_executable_spec'):
     target = getattr(executable_spec, field_name)
     if spec.Is(target.DESCRIPTOR):
       spec.Unpack(target)
       return executable_spec
   raise ValueError(
       'executor spec of {} is expected to be of one of the '
       'types of tfx.orchestration.deployment_config.ExecutableSpec.spec '
       'but got type {}'.format(node_id, spec.type_url))
Beispiel #2
0
def extract_mlmd_connection(
        connection_config: any_pb2.Any) -> metadata_pb2.MLMDConnectionConfig:
    """Unpack an Any into the first matching field of MLMDConnectionConfig.

    Fields are tried in ``DESCRIPTOR.fields_by_name`` iteration order.
    ``Any.Unpack`` returns False on a type mismatch, so the first field whose
    message type matches the packed payload is the one that gets populated.

    NOTE(review): assumes every field of MLMDConnectionConfig is a message
    field; a scalar field would make ``Unpack`` fail — confirm against the
    proto definition.

    Args:
        connection_config: Any proto wrapping one of the connection configs.

    Returns:
        An MLMDConnectionConfig with the matching field set, or an empty
        MLMDConnectionConfig if no field's type matched.
    """
    connection = metadata_pb2.MLMDConnectionConfig()
    for field_name in connection.DESCRIPTOR.fields_by_name:
        if connection_config.Unpack(getattr(connection, field_name)):
            return connection
    # No field matched the packed type: the config is returned unpopulated.
    return connection
Beispiel #3
0
def UnpackAny(
        proto_any: any_pb2.Any) -> Union[UnknownProtobuf, message.Message]:
    """Unpack an Any into a message instance resolved from its type URL.

    Args:
        proto_any: The Any proto to unpack.

    Returns:
        The unpacked message, or an UnknownProtobuf carrying the lookup error
        text and the original Any when TypeUrlToMessage raises
        ProtobufTypeNotFound.
    """
    try:
        resolved = TypeUrlToMessage(proto_any.type_url)
    except ProtobufTypeNotFound as lookup_error:
        return UnknownProtobuf(str(lookup_error), proto_any)
    else:
        # The bool returned by Unpack is not checked; presumably
        # TypeUrlToMessage returns an instance of exactly the type named by
        # the URL, so the type check inside Unpack succeeds — verify.
        proto_any.Unpack(resolved)
        return resolved
Beispiel #4
0
 def __call__(self, msg: any_pb2.Any) -> Message:
     """Unpack ``msg`` into a new instance of its registered message class.

     The class is looked up in ``self.type_str_to_class_map`` by the Any's
     type name (``msg.TypeName()``), instantiated, and filled via
     ``msg.Unpack``.

     Args:
         msg: The packed ``Any`` message.

     Returns:
         The newly created, populated message instance.

     Raises:
         KeyError: If no class is registered for the Any's type name.
         ValueError: If ``msg.Unpack`` reports failure, i.e. the registered
             class does not match the packed type.
     """
     type_name = msg.TypeName()
     try:
         message_cls = self.type_str_to_class_map[type_name]
     except KeyError:
         # Keep the exception type callers may catch, but say what is missing.
         raise KeyError(
             f'No message class registered for protobuf type "{type_name}".'
         ) from None
     unpacked_message = message_cls()
     # Any.Unpack reports a type mismatch via its return value, not by raising.
     if not msg.Unpack(unpacked_message):
         raise ValueError(
             f'Failed unpacking protobuf Any message with type url "{msg.type_url}".'
         )
     return unpacked_message
Beispiel #5
0
def _build_local_platform_config(
        node_id: str,
        spec: any_pb2.Any) -> local_deployment_config_pb2.LocalPlatformConfig:
    """Builds LocalPlatformConfig given the any proto from IntermediateDeploymentConfig.

    Only the docker_platform_config type is currently accepted.

    Args:
        node_id: Id of the node; only used in the error message.
        spec: Any proto expected to hold a docker platform config.

    Returns:
        A new LocalPlatformConfig with docker_platform_config populated.

    Raises:
        ValueError: If `spec` does not pack the docker_platform_config type.
    """
    platform_config = local_deployment_config_pb2.LocalPlatformConfig()
    docker_config = platform_config.docker_platform_config
    if not spec.Is(docker_config.DESCRIPTOR):
        raise ValueError(
            'Platform config of {} is expected to be of one of the types of '
            'tfx.orchestration.deployment_config.LocalPlatformConfig.config '
            'but got type {}'.format(node_id, spec.type_url))
    spec.Unpack(docker_config)
    return platform_config
Beispiel #6
0
def decode_tensor_node(graph: tf.Graph,
                       encoded_tensor_node: any_pb2.Any) -> types.TensorType:
  """Decode a Tensor node encoded with encode_tensor_node.

  Unpacks the encoded Tensor "reference" (a `meta_graph_pb2.TensorInfo`
  packed in an `Any`) and returns the node in `graph` corresponding to that
  Tensor.

  Args:
    graph: Graph in which to look up the Tensor.
    encoded_tensor_node: `Any` proto wrapping a `TensorInfo`, as produced by
      encode_tensor_node.

  Returns:
    Decoded Tensor.

  Raises:
    ValueError: If `encoded_tensor_node` does not contain a `TensorInfo`.
  """
  tensor_info = meta_graph_pb2.TensorInfo()
  # Any.Unpack returns False (leaving tensor_info empty) on a type mismatch;
  # fail here with a clear message rather than passing an empty TensorInfo
  # on to the graph lookup.
  if not encoded_tensor_node.Unpack(tensor_info):
    raise ValueError(
        'Expected encoded_tensor_node to contain a TensorInfo, but got type '
        '{}'.format(encoded_tensor_node.type_url))
  return tf.saved_model.utils.get_tensor_from_tensor_info(tensor_info, graph)
Beispiel #7
0
def unpack(data: GrpcAny, message: GrpcMessage) -> None:
    """Deserialize the payload of an ``Any`` into ``message`` in place.

    Args:
        data (:obj:`google.protobuf.any_pb2.Any`): the ``Any`` wrapping the
            serialized protocol buffer message.
        message (:obj:`google.protobuf.message.Message`): the message object
            that receives the deserialized data.

    Raises:
        ValueError: ``message`` is not a protocol buffer message object, or
            its type does not match the type packed in ``data``.
    """
    if isinstance(message, GrpcMessage):
        if data.Is(message.DESCRIPTOR):
            data.Unpack(message)
            return
        raise ValueError(
            f'invalid type. serialized message type: {data.type_url}')
    raise ValueError(
        'output message is not protocol buffer message object')
Beispiel #8
0
def unpack_any(message: any_pb2.Any, out: M) -> M:
    """Unpack ``message`` into ``out`` and return ``out``.

    Args:
        message: The ``Any`` wrapper to unpack.
        out: Message instance filled in place; its type must match the type
            packed in ``message``.

    Returns:
        ``out``, populated from ``message``.

    Raises:
        ValueError: If the type packed in ``message`` does not match ``out``.
    """
    # Any.Unpack signals a type mismatch by returning False rather than
    # raising; without this check the caller would silently get an
    # unpopulated ``out``.
    if not message.Unpack(out):
        raise ValueError(
            f'Cannot unpack Any of type {message.type_url} into '
            f'{out.DESCRIPTOR.full_name}.')
    return out