コード例 #1
0
 def calculate(
         self,
         sources: Sequence[prensor.NodeTensor],
         destinations: Sequence[expression.Expression],
         options: calculate_options.Options,
         side_info: Optional[prensor.Prensor] = None
 ) -> _ProtoRootNodeTensor:
     """Parses `self._tensor_of_protos` into a root node tensor.

     Only the fields reported by `_get_needed_fields(destinations)` are
     requested from the parser.

     Args:
       sources: Must be empty; a non-empty value is rejected.
       destinations: Expressions consuming this root; they determine which
         fields are parsed.
       options: Calculation options. When `use_string_view` is truthy,
         `self._tensor_of_protos` is also passed as the backing string tensor
         (requires the 'binary' message format).
         `experimental_honor_proto3_optional_semantics` is forwarded to the
         parser unchanged.
       side_info: Not referenced by this implementation.

     Returns:
       A `_ProtoRootNodeTensor` built from the element count of
       `self._tensor_of_protos` (as int64) and the parsed fields.

     Raises:
       ValueError: If `sources` is non-empty.
       AssertionError: If string views are requested but the message format
         is not 'binary' (only while assertions are enabled).
     """
     if sources:
         raise ValueError("_ProtoRootExpression has no sources")
     protos = self._tensor_of_protos
     # Element count of the input tensor, used as the root node's size.
     num_protos = tf.size(protos, out_type=tf.int64)
     wanted = _get_needed_fields(destinations)
     string_view_requested = options.use_string_view
     if string_view_requested:
         # NOTE(review): `assert` is stripped under `python -O`; consider an
         # explicit exception if this constraint must always be enforced.
         assert self._message_format == "binary", (
             "`options.use_string_view` is only compatible with 'binary' message "
             "format. Please create the root expression with "
             "message_format='binary'.")
     backing = protos if string_view_requested else None
     honor_optional = options.experimental_honor_proto3_optional_semantics
     parsed_fields = parse_message_level_ex.parse_message_level_ex(
         protos,
         self._descriptor,
         wanted,
         message_format=self._message_format,
         backing_str_tensor=backing,
         honor_proto3_optional_semantics=honor_optional)
     return _ProtoRootNodeTensor(num_protos, parsed_fields)
コード例 #2
0
 def calculate_from_parsed_field(self, parsed_field, destinations):
   """Parses the payload of `parsed_field` into a child node tensor.

   Args:
     parsed_field: A parsed field; its `value` is handed to the parser as the
       tensor of protos and its `index` is forwarded to the result.
     destinations: Expressions consuming this node; they determine which
       fields are parsed.

   Returns:
     A `_ProtoChildNodeTensor` of `parsed_field.index`, `self.is_repeated`
     and the parsed fields.
   """
   wanted = _get_needed_fields(destinations)
   payload = parsed_field.value
   parsed = parse_message_level_ex.parse_message_level_ex(
       payload, self._desc, wanted)
   return _ProtoChildNodeTensor(
       parsed_field.index, self.is_repeated, parsed)
コード例 #3
0
 def calculate_from_parsed_field(
         self, parsed_field: struct2tensor_ops._ParsedField,
         destinations: Sequence[expression.Expression]
 ) -> prensor.NodeTensor:
     """Parses the payload of `parsed_field` into a child node tensor.

     Args:
       parsed_field: Its `value` is passed to the parser as the tensor of
         protos; its `index` is forwarded to the result.
       destinations: Expressions consuming this node; they determine which
         fields are parsed.

     Returns:
       A `_ProtoChildNodeTensor` of `parsed_field.index`, `self.is_repeated`
       and the parsed fields.
     """
     wanted = _get_needed_fields(destinations)
     payload = parsed_field.value
     parsed = parse_message_level_ex.parse_message_level_ex(
         payload, self._desc, wanted)
     return _ProtoChildNodeTensor(
         parsed_field.index, self.is_repeated, parsed)
コード例 #4
0
 def calculate(self, sources, destinations, options):
   """Parses `self._tensor_of_protos` into a root node tensor.

   Args:
     sources: Must be empty; a non-empty value is rejected.
     destinations: Expressions consuming this root; they determine which
       fields are parsed.
     options: Not referenced by this implementation.

   Returns:
     A `_ProtoRootNodeTensor` built from the int64 element count of
     `self._tensor_of_protos` and the parsed fields.

   Raises:
     ValueError: If `sources` is non-empty.
   """
   if sources:
     raise ValueError("_ProtoRootExpression has no sources")
   protos = self._tensor_of_protos
   num_protos = tf.size(protos, out_type=tf.int64)
   wanted = _get_needed_fields(destinations)
   parsed = parse_message_level_ex.parse_message_level_ex(
       protos, self._descriptor, wanted)
   return _ProtoRootNodeTensor(num_protos, parsed)
コード例 #5
0
def _run_parse_message_level_ex(proto_list, fields, sess):
    """Serializes `proto_list`, parses `fields`, and evaluates the result.

    Args:
      proto_list: Proto messages to serialize; the first element's
        `DESCRIPTOR` is used for parsing, so the list must be non-empty.
      fields: Field names forwarded to `parse_message_level_ex`.
      sess: Session whose `run` evaluates the parsed tensors.

    Returns:
      Whatever `sess.run` returns for a dict mapping each parsed field name to
      `{_INDEX: index, _VALUE: value}`.
    """
    blobs = [proto.SerializeToString() for proto in proto_list]
    parsed = parse_message_level_ex.parse_message_level_ex(
        tf.constant(blobs), proto_list[0].DESCRIPTOR, fields)
    fetches = {
        name: {_INDEX: field.index, _VALUE: field.value}
        for name, field in parsed.items()
    }
    return sess.run(fetches)
コード例 #6
0
 def calculate_from_parsed_field(
         self, parsed_field: struct2tensor_ops._ParsedField,
         destinations: Sequence[expression.Expression],
         use_string_view: bool) -> prensor.NodeTensor:
     """Parses the payload of `parsed_field` into a child node tensor.

     Args:
       parsed_field: Its `value` is passed to the parser as the tensor of
         protos; its `index` is forwarded to the result.
       destinations: Expressions consuming this node; they determine which
         fields are parsed.
       use_string_view: When truthy, `self._backing_str_tensor` is passed to
         the parser as the backing string tensor; otherwise None is passed.

     Returns:
       A `_ProtoChildNodeTensor` of `parsed_field.index`, `self.is_repeated`
       and the parsed fields.
     """
     wanted = _get_needed_fields(destinations)
     backing = self._backing_str_tensor if use_string_view else None
     parsed = parse_message_level_ex.parse_message_level_ex(
         parsed_field.value,
         self._desc,
         wanted,
         backing_str_tensor=backing)
     return _ProtoChildNodeTensor(
         parsed_field.index, self.is_repeated, parsed)
コード例 #7
0
 def calculate(
         self,
         sources: Sequence[prensor.NodeTensor],
         destinations: Sequence[expression.Expression],
         options: calculate_options.Options,
         side_info: Optional[prensor.Prensor] = None
 ) -> _ProtoRootNodeTensor:
     """Parses `self._tensor_of_protos` into a root node tensor.

     The parser is told to expect `self._message_format`.

     Args:
       sources: Must be empty; a non-empty value is rejected.
       destinations: Expressions consuming this root; they determine which
         fields are parsed.
       options: Not referenced by this implementation.
       side_info: Not referenced by this implementation.

     Returns:
       A `_ProtoRootNodeTensor` built from the int64 element count of
       `self._tensor_of_protos` and the parsed fields.

     Raises:
       ValueError: If `sources` is non-empty.
     """
     if sources:
         raise ValueError("_ProtoRootExpression has no sources")
     protos = self._tensor_of_protos
     num_protos = tf.size(protos, out_type=tf.int64)
     wanted = _get_needed_fields(destinations)
     parsed = parse_message_level_ex.parse_message_level_ex(
         protos,
         self._descriptor,
         wanted,
         message_format=self._message_format)
     return _ProtoRootNodeTensor(num_protos, parsed)
コード例 #8
0
def _run_parse_message_level_ex(proto_list, fields, message_format="binary"):
    """Serializes `proto_list` in `message_format` and parses `fields` from it.

    Args:
      proto_list: Proto messages to serialize; the first element's
        `DESCRIPTOR` is used for parsing, so the list must be non-empty.
      fields: Field names forwarded to `parse_message_level_ex`.
      message_format: "text" (serialize with `text_format.MessageToString`)
        or "binary" (serialize with `SerializeToString`); also forwarded to
        the parser.

    Returns:
      A dict mapping each parsed field name to
      `{_INDEX: index, _VALUE: value}`. The tensors are returned as-is; no
      session is run here.

    Raises:
      ValueError: If `message_format` is neither "text" nor "binary".
    """
    if message_format == "text":
        serialized = [text_format.MessageToString(x) for x in proto_list]
    elif message_format == "binary":
        serialized = [x.SerializeToString() for x in proto_list]
    else:
        raise ValueError('Message format must be one of "text", "binary"')
    # Pass message_format by keyword so it cannot silently bind to a different
    # positional parameter of parse_message_level_ex.
    parsed_field_dict = parse_message_level_ex.parse_message_level_ex(
        tf.constant(serialized), proto_list[0].DESCRIPTOR, fields,
        message_format=message_format)
    return {
        key: {_INDEX: value.index, _VALUE: value.value}
        for key, value in parsed_field_dict.items()
    }
コード例 #9
0
 def calculate_from_parsed_field(
         self,
         parsed_field: struct2tensor_ops._ParsedField,  # pylint:disable=protected-access
         destinations: Sequence[expression.Expression],
         options: calculate_options.Options) -> prensor.NodeTensor:
     """Parses the payload of `parsed_field` into a child node tensor.

     Args:
       parsed_field: Its `value` is passed to the parser as the tensor of
         protos; its `index` is forwarded to the result.
       destinations: Expressions consuming this node; they determine which
         fields are parsed.
       options: When `use_string_view` is truthy, `self._backing_str_tensor`
         is passed to the parser as the backing string tensor.
         `experimental_honor_proto3_optional_semantics` is forwarded to the
         parser unchanged.

     Returns:
       A `_ProtoChildNodeTensor` of `parsed_field.index`, `self.is_repeated`
       and the parsed fields.
     """
     wanted = _get_needed_fields(destinations)
     backing = self._backing_str_tensor if options.use_string_view else None
     honor_optional = options.experimental_honor_proto3_optional_semantics
     parsed = parse_message_level_ex.parse_message_level_ex(
         parsed_field.value,
         self._desc,
         wanted,
         backing_str_tensor=backing,
         honor_proto3_optional_semantics=honor_optional)
     return _ProtoChildNodeTensor(
         parsed_field.index, self.is_repeated, parsed)