def add_input_token(
    self, input_token: Token[InputId, str], copy: bool = False
) -> "DuoRATInputSegmentBuilder":
    """Append ``input_token`` to this input segment.

    The token's position is shifted by the segment's running position
    pointer, embedded under both input vocabularies, and fanned out to
    every mask/index sub-builder.

    :param input_token: token to append (position is relative to the
        tokens added so far).
    :param copy: when True, work on a deep copy and leave ``self``
        untouched; otherwise mutate ``self`` in place.
    :return: the builder that received the token (copy or ``self``).
    """
    # Work on a deep copy when requested, otherwise mutate in place.
    builder = deepcopy(self) if copy is True else self

    # Shift the token's relative position into the segment's absolute
    # coordinate frame.
    shifted = replace(
        input_token,
        position=Pos(builder.input_token_max_position_pointer + input_token.position),
    )
    # Advance the running pointer past the token just placed.
    builder.input_token_max_position_pointer += input_token.position + 1

    # Embed the token value under each of the two input vocabularies
    # (order preserved: vocabulary A first, then B).
    for sub_builder, to_id in (
        (builder.input_a_builder, self.input_a_str_to_id),
        (builder.input_b_builder, self.input_b_str_to_id),
    ):
        sub_builder.add_token(token=replace(shifted, value=to_id(shifted.value)))

    # Propagate the positioned token to the mask/type/position/index
    # sub-builders.
    builder.input_key_padding_mask_builder.add_token(token=shifted)
    builder.input_attention_mask_builder.add_token(token=shifted)
    builder.input_token_type_ids_builder.add_token(token=shifted)
    builder.input_position_ids_builder.add_token(token=shifted)
    builder.input_to_source_gather_index_builder.add_input_token(input_token=shifted)
    return builder
def add_source_token(
    self, source_token: Token[InputId, str], copy: bool = False
) -> "DuoRATEncoderItemBuilder":
    """Append ``source_token`` to the encoder item.

    The token's position is shifted by the running source position
    pointer, recorded in ``positioned_source_tokens``, and broadcast to
    every input segment and source-side sub-builder.

    :param source_token: token to append (position is relative to the
        tokens added so far).
    :param copy: when True, work on a deep copy and leave ``self``
        untouched; otherwise mutate ``self`` in place.
    :return: the builder that received the token (copy or ``self``).
    """
    builder = deepcopy(self) if copy is True else self

    # Shift the relative position into the absolute source coordinate
    # frame and remember the positioned token.
    shifted = replace(
        source_token,
        position=Pos(builder.source_token_max_position_pointer + source_token.position),
    )
    builder.positioned_source_tokens.append(shifted)
    builder.source_token_max_position_pointer += source_token.position + 1

    # Every input segment needs to see the new source token.
    for segment_builder in builder.input_segment_builders.values():
        segment_builder.add_positioned_source_token(positioned_source_token=shifted)

    # Fan the positioned token out to the source-side sub-builders.
    builder.input_to_source_gather_index_builder.add_source_token(source_token=shifted)
    builder.source_key_padding_mask_builder.add_token(token=shifted)
    builder.source_attention_mask_builder.add_token(token=shifted)
    builder.source_relations_builder.add_source_token(source_token=shifted)
    return builder
class QuestionToken(Token[QuestionTokenId, VT]):
    """A token drawn from the natural-language question.

    Unlike the other token kinds in this file it additionally carries the
    raw (pre-processing) value and schema-linking match tags.
    """

    key: QuestionTokenId
    value: VT
    # raw_value: presumably the un-normalized surface form of `value` —
    # TODO confirm against the tokenizer that constructs these.
    raw_value: VT
    scope: AttentionScope
    # Relative position; defaults to the origin.
    position: Pos = Pos(0)
    # Schema-linking match tags; empty by default.
    match_tags: Tuple[MatchTag, ...] = tuple()
def add_input_token(
    self, input_token: Token[InputId, str], copy: bool = False
) -> "DuoRATEncoderItemBuilder":
    """Append ``input_token`` to the encoder item.

    Routes the token to the segment builder for its attention scope,
    then — unless the maximum supported input length would be exceeded —
    positions the token and fans it out to the encoder-level
    mask/type/position/index/relation sub-builders.

    :param input_token: token to append (position is relative to the
        tokens added so far).
    :param copy: when True, operate on a deep copy and leave ``self``
        untouched; otherwise mutate ``self`` in place.
    :return: the builder that received the token (copy or ``self``).
    """
    builder = deepcopy(self) if copy is True else self
    # NOTE(review): the per-scope segment builder receives the token
    # *before* the truncation check below, so on truncation the segment
    # state advances while the encoder-level tensors do not — confirm
    # this divergence is intended.
    builder.input_segment_builders[input_token.scope].add_input_token(
        input_token=input_token
    )
    if (
        builder.max_supported_input_length is not None
        and builder.input_token_max_position_pointer + input_token.position + 1
        > builder.max_supported_input_length
    ):
        # Token would push the item past the supported length: warn and
        # drop it from the encoder-level tensors (early return).
        logger.warning(
            "input token tensor has been truncated to {} tokens, "
            "original length was {} tokens".format(
                builder.max_supported_input_length,
                builder.input_token_max_position_pointer + input_token.position + 1,
            )
        )
        return builder
    else:
        # Shift the token's relative position into the item's absolute
        # coordinate frame.
        positioned_input_token = replace(
            input_token,
            position=Pos(
                builder.input_token_max_position_pointer + input_token.position
            ),
        )
        # Advance the running position pointer past the placed token.
        builder.input_token_max_position_pointer = (
            builder.input_token_max_position_pointer + input_token.position + 1
        )
        # Embed the token value under input vocabulary A...
        builder.input_a_builder.add_token(
            token=replace(
                positioned_input_token,
                value=self.input_a_str_to_id(positioned_input_token.value),
            )
        )
        # ...and under input vocabulary B.
        builder.input_b_builder.add_token(
            token=replace(
                positioned_input_token,
                value=self.input_b_str_to_id(positioned_input_token.value),
            )
        )
        # Propagate the positioned token to the remaining sub-builders.
        builder.input_key_padding_mask_builder.add_token(
            token=positioned_input_token
        )
        builder.input_attention_mask_builder.add_token(token=positioned_input_token)
        builder.input_token_type_ids_builder.add_token(token=positioned_input_token)
        builder.input_position_ids_builder.add_token(token=positioned_input_token)
        builder.input_to_source_gather_index_builder.add_input_token(
            input_token=positioned_input_token
        )
        builder.source_relations_builder.add_input_token(
            input_token=positioned_input_token
        )
        return builder
def add_action_token(
    self,
    action_token: ActionToken[Action],
    copy: bool = False,
) -> "DuoRATDecoderItemBuilder":
    """Append ``action_token`` to the decoder item.

    Shifts the token's position, records an ``ActionInfo`` capturing the
    parser context at the time of the action, and advances the grammar
    parsing state (except for mask actions, which are not parsed).

    :param action_token: action token to append (position is relative to
        the actions added so far).
    :param copy: when True, operate on a deep copy and leave ``self``
        untouched; otherwise mutate ``self`` in place.
    :return: the builder that received the token (copy or ``self``).
    :raises ValueError: if the action sequence is already complete, or
        the parsing state is neither ``Done`` nor ``Partial``.
    """
    builder = deepcopy(self) if copy is True else self

    # Guard clauses: only a Partial parse can be continued.
    if isinstance(builder.parsing_result, Done):
        raise ValueError("A complete action sequence cannot be continued")
    if not isinstance(builder.parsing_result, Partial):
        raise ValueError("Invalid parsing state: {}".format(builder.parsing_result))

    # Shift the relative position into the decoder's absolute frame and
    # advance the running position pointer.
    shifted = replace(
        action_token,
        position=Pos(builder.action_token_max_position_pointer + action_token.position),
    )
    builder.action_token_max_position_pointer += action_token.position + 1

    # Capture the parser context *before* advancing the parser.
    action_info = ActionInfo(
        action=shifted.value,
        parent_pos=builder.parsing_result.parent_pos,
        frontier_field=builder.parsing_result.frontier_field,
    )
    # Don't try to parse a mask action
    if shifted.value != MaskAction():
        builder.parsing_result = builder.parsing_result.cont(
            shifted.position, shifted.value
        )

    builder._add_positioned_action_info_token(
        positioned_action_info_token=ActionInfoToken(
            key=action_info,
            value=action_info,
            position=shifted.position,
            scope=shifted.scope,
        )
    )
    return builder
class ActionInfoToken(Token[ActionInfo, VT]):
    """A token whose key is the ``ActionInfo`` of a decoded action."""

    key: ActionInfo
    value: VT
    scope: AttentionScope
    # Relative position; defaults to the origin.
    position: Pos = Pos(0)
class TableToken(Token[TableId, VT]):
    """A token representing (part of) a database table name."""

    key: TableId
    value: VT
    scope: AttentionScope
    # Relative position; defaults to the origin.
    position: Pos = Pos(0)
class ColumnToken(Token[ColumnId, VT]):
    """A token representing (part of) a database column name."""

    key: ColumnId
    value: VT
    scope: AttentionScope
    # Relative position; defaults to the origin.
    position: Pos = Pos(0)