def test_parse_map_indexing_step(self):
  """Checks that a map-indexing step splits into (map field name, key)."""
  cases = (
      ("my_map[some_key]", "some_key"),
      # Nothing between the brackets yields an empty key.
      ("my_map[]", ""),
      # The field name ends at the first "[" and the trailing "]" is dropped,
      # so any brackets/dots in between stay part of the key.
      ("my_map[[.]", "[."),
  )
  for step, expected_key in cases:
    field_name, key = parse_map_indexing_step(step)
    self.assertEqual("my_map", field_name)
    self.assertEqual(expected_key, key)
def _get_map_child(
    parent: Union[_ProtoChildExpression, _ProtoRootExpression],
    desc: descriptor.Descriptor,
    field_name: ProtoFieldName,
    backing_str_tensor: Optional[tf.Tensor],
) -> Optional[Union[_ProtoLeafExpression, _ProtoChildExpression]]:
  """Gets the child given a map field.

  Args:
    parent: the expression the child is attached to.
    desc: descriptor of the parent message.
    field_name: a map-indexing step such as `foo[bar]`; only the part before
      the brackets is used to look up the map field here.
    backing_str_tensor: forwarded unchanged to `_get_child_helper`.

  Returns:
    The result of `_get_child_helper` for the map entry's "value" field, or
    None if the named field is missing, is not a map field, or its entry
    message has no "value" field.
  """
  map_field_name, _ = path.parse_map_indexing_step(field_name)
  map_field_desc = desc.fields_by_name.get(map_field_name)
  if map_field_desc is None or not _is_map_field_desc(map_field_desc):
    return None

  # NOTE(review): _is_map_field_desc presumably already rules out the two
  # None cases below; they are kept as defensive guards.
  entry_desc = map_field_desc.message_type
  if entry_desc is None:
    return None
  value_desc = entry_desc.fields_by_name.get("value")
  if value_desc is None:
    return None

  # The map entry's "value" field is handled as an optional field.
  return _get_child_helper(parent, value_desc, field_name, backing_str_tensor)
def _get_field_names_to_parse(desc, needed_field_names):
  """Gets the field names to parse from the original protobuf.

  Args:
    desc: descriptor of the message being parsed.
    needed_field_names: the requested steps; may contain map-indexing steps
      (`foo[bar]`), extensions, and regular field names.

  Returns:
    A sorted list of distinct field names, where:
      * a map-indexing step contributes its map field name;
      * an extension step contributes "type_url" and "value" when `desc` is
        an Any descriptor;
      * every other step is passed through unchanged.
  """
  result = set()  # Set[ProtoFieldName]
  for x in needed_field_names:
    if path.is_map_indexing_step(x):
      map_field_name, _ = path.parse_map_indexing_step(x)
      result.add(map_field_name)
    elif path.is_extension(x) and is_any_descriptor(desc):
      result.add("type_url")
      result.add("value")
    else:
      result.add(x)
  # Iterating a set of strings yields an order that changes between processes
  # (string hash randomization), so sort to keep the downstream field order
  # reproducible. Assumes all field names are str.
  return sorted(result)
def _get_map_parsed_fields(
    desc: descriptor.Descriptor,
    regular_fields: Mapping[StrStep, struct2tensor_ops._ParsedField],
    field_names: Set[StrStep],
    backing_str_tensor: Optional[tf.Tensor] = None
) -> Mapping[StrStep, struct2tensor_ops._ParsedField]:
  """Gets the map proto ParsedFields.

  field_names includes all the fields: map fields, any fields, and regular
  fields; only the map-indexing steps (`foo[bar]`) are handled here.

  Args:
    desc: the descriptor of the parent proto.
    regular_fields: the fields that are parsed directly from the proto; must
      contain an entry for every map field referenced in `field_names`.
    field_names: all fields needed: map fields, any fields, and regular fields.
    backing_str_tensor: a string tensor representing the root serialized
      proto. This is passed to keep string_views of the tensor valid for all
      children of the root expression.

  Returns:
    A map from field names to ParsedFields, only for the field names of the
    form foo[bar].
  """
  # {map field name: {requested key: original step string}}
  keys_by_map = collections.defaultdict(dict)
  for step in field_names:
    if not path.is_map_indexing_step(step):
      continue
    map_name, map_key = path.parse_map_indexing_step(step)
    keys_by_map[map_name][map_key] = step

  parsed = {}
  for map_name, step_by_key in keys_by_map.items():
    map_field = regular_fields[map_name]
    wanted_keys = list(step_by_key)
    serialized_entries = map_field.value
    entry_parent_index = map_field.index
    entry_desc = desc.fields_by_name[map_name].message_type
    outputs = struct2tensor_ops.parse_proto_map(
        serialized_entries, entry_parent_index, entry_desc, wanted_keys,
        backing_str_tensor)
    # parse_proto_map yields one (value, parent_index) pair per wanted key,
    # in the same order as wanted_keys.
    for map_key, (value, parent_index) in zip(wanted_keys, outputs):
      step = step_by_key[map_key]
      parsed[step] = struct2tensor_ops._ParsedField(
          field_name=step,
          field_descriptor=None,
          index=parent_index,
          value=value)
  return parsed
def _get_map_child(parent, desc, field_name):
  """Gets the child given a map field.

  Args:
    parent: the expression the child is attached to.
    desc: descriptor of the parent message.
    field_name: a map-indexing step such as `foo[bar]`; only the part before
      the brackets is used to look up the map field here.

  Returns:
    The result of `_get_child_helper` for the map entry's "value" field, or
    None if the named field is missing, is not a map field, or its entry
    message has no "value" field.
  """
  map_field_name, _ = path.parse_map_indexing_step(field_name)
  map_field_desc = desc.fields_by_name.get(map_field_name)
  if map_field_desc is None or not _is_map_field_desc(map_field_desc):
    return None

  # NOTE(review): _is_map_field_desc presumably already rules out the two
  # None cases below; they are kept as defensive guards.
  entry_desc = map_field_desc.message_type
  if entry_desc is None:
    return None
  value_desc = entry_desc.fields_by_name.get("value")
  if value_desc is None:
    return None

  # The map entry's "value" field is handled as an optional field.
  return _get_child_helper(parent, value_desc, field_name)
def _get_map_parsed_fields(desc, regular_fields, field_names):
  """Gets the map proto ParsedFields.

  field_names includes all the fields: map fields, any fields, and regular
  fields; only the map-indexing steps (`foo[bar]`) are handled here.

  Args:
    desc: the descriptor of the parent proto.
    regular_fields: the fields that are parsed directly from the proto; must
      contain an entry for every map field referenced in `field_names`.
    field_names: all fields needed: map fields, any fields, and regular fields.

  Returns:
    A map from field names to ParsedFields, only for the field names of the
    form foo[bar].
  """
  # {map field name: {requested key: original step string}}
  keys_by_map = collections.defaultdict(dict)
  for step in field_names:
    if not path.is_map_indexing_step(step):
      continue
    map_name, map_key = path.parse_map_indexing_step(step)
    keys_by_map[map_name][map_key] = step

  parsed = {}
  for map_name, step_by_key in keys_by_map.items():
    map_field = regular_fields[map_name]
    wanted_keys = list(step_by_key)
    serialized_entries = map_field.value
    entry_parent_index = map_field.index
    entry_desc = desc.fields_by_name[map_name].message_type
    outputs = struct2tensor_ops.parse_proto_map(
        serialized_entries, entry_parent_index, entry_desc, wanted_keys)
    # parse_proto_map yields one (value, parent_index) pair per wanted key,
    # in the same order as wanted_keys.
    for map_key, (value, parent_index) in zip(wanted_keys, outputs):
      step = step_by_key[map_key]
      parsed[step] = struct2tensor_ops._ParsedField(
          field_name=step,
          field_descriptor=None,
          index=parent_index,
          value=value)
  return parsed