コード例 #1
0
ファイル: duorat.py プロジェクト: ThanThoai/Text2SqlVN
 def add_input_token(
     self, input_token: Token[InputId, str], copy: bool = False
 ) -> "DuoRATInputSegmentBuilder":
     """Register an input token with every input-side sub-builder.

     The token's position is shifted by the builder's running
     ``input_token_max_position_pointer``; the pointer is then moved to one
     past the shifted position.

     Args:
         input_token: Token to add. Its ``position`` is treated as an offset
             relative to the current pointer, and its ``value`` is passed
             through ``input_a_str_to_id`` / ``input_b_str_to_id`` before
             being handed to the ``input_a`` / ``input_b`` builders.
         copy: When exactly ``True``, a deep copy of this builder is updated
             and returned; otherwise this builder is updated in place.

     Returns:
         The builder that received the token (``self`` or its deep copy).
     """
     target = self
     if copy is True:
         target = deepcopy(self)

     # Shifted position of the new token; the pointer ends up one past it.
     shifted = target.input_token_max_position_pointer + input_token.position
     positioned = replace(input_token, position=Pos(shifted))
     target.input_token_max_position_pointer = shifted + 1

     # The two id builders receive the token with its value mapped to an id.
     token_a = replace(
         positioned, value=self.input_a_str_to_id(positioned.value)
     )
     target.input_a_builder.add_token(token=token_a)
     token_b = replace(
         positioned, value=self.input_b_str_to_id(positioned.value)
     )
     target.input_b_builder.add_token(token=token_b)

     # These builders all take the positioned token unchanged.
     for sub_builder in (
         target.input_key_padding_mask_builder,
         target.input_attention_mask_builder,
         target.input_token_type_ids_builder,
         target.input_position_ids_builder,
     ):
         sub_builder.add_token(token=positioned)

     target.input_to_source_gather_index_builder.add_input_token(
         input_token=positioned
     )
     return target
コード例 #2
0
ファイル: duorat.py プロジェクト: ThanThoai/Text2SqlVN
 def add_source_token(
     self, source_token: Token[InputId, str], copy: bool = False
 ) -> "DuoRATEncoderItemBuilder":
     """Register a source token with every source-side sub-builder.

     The token's position is shifted by the builder's running
     ``source_token_max_position_pointer``; the pointer is then moved to one
     past the shifted position. The positioned token is recorded in
     ``positioned_source_tokens`` and forwarded to each input segment
     builder, the input-to-source gather index builder, the source key
     padding mask builder, the source attention mask builder and the source
     relations builder, in that order.

     Args:
         source_token: Token to add. Its ``position`` is treated as an
             offset relative to the current pointer.
         copy: When exactly ``True``, a deep copy of this builder is updated
             and returned; otherwise this builder is updated in place.

     Returns:
         The builder that received the token (``self`` or its deep copy).
     """
     target = self
     if copy is True:
         target = deepcopy(self)

     # Shifted position of the new token; the pointer ends up one past it.
     shifted = target.source_token_max_position_pointer + source_token.position
     positioned = replace(source_token, position=Pos(shifted))
     target.positioned_source_tokens.append(positioned)
     target.source_token_max_position_pointer = shifted + 1

     # Every input segment builder is told about the new source token.
     for segment_builder in target.input_segment_builders.values():
         segment_builder.add_positioned_source_token(
             positioned_source_token=positioned
         )

     target.input_to_source_gather_index_builder.add_source_token(
         source_token=positioned
     )
     target.source_key_padding_mask_builder.add_token(token=positioned)
     target.source_attention_mask_builder.add_token(token=positioned)
     target.source_relations_builder.add_source_token(source_token=positioned)
     return target
コード例 #3
0
class QuestionToken(Token[QuestionTokenId, VT]):
    """Token whose key is a ``QuestionTokenId`` and whose value has type ``VT``.

    NOTE(review): no decorator is visible in this excerpt; the annotated
    fields with defaults suggest this is a dataclass -- confirm against the
    full file.
    """

    # Identifier of the token.
    key: QuestionTokenId
    # Payload carried by the token.
    value: VT
    # Second payload of the same type as ``value``; presumably the
    # unprocessed form of the value -- TODO confirm against callers.
    raw_value: VT
    # Attention scope of the token.
    scope: AttentionScope
    # Position of the token; defaults to ``Pos(0)``.
    position: Pos = Pos(0)
    # Match tags attached to the token; defaults to an empty tuple.
    match_tags: Tuple[MatchTag, ...] = tuple()
コード例 #4
0
ファイル: duorat.py プロジェクト: zhaoxlpku/duorat
 def add_input_token(
     self, input_token: Token[InputId, str], copy: bool = False
 ) -> "DuoRATEncoderItemBuilder":
     """Register an input token, skipping it when the length limit is hit.

     The token is first handed, unmodified, to the input segment builder
     selected by ``input_token.scope``. After that, if
     ``max_supported_input_length`` is set and adding the token would move
     the position pointer past it, a warning is logged and the remaining
     sub-builders are left untouched. Otherwise the token's position is
     shifted by ``input_token_max_position_pointer``, the pointer is moved
     to one past the shifted position, and the positioned token is
     forwarded to the input-side sub-builders and the source relations
     builder.

     Args:
         input_token: Token to add. Its ``position`` is treated as an offset
             relative to the current pointer, and its ``value`` is passed
             through ``input_a_str_to_id`` / ``input_b_str_to_id`` before
             being handed to the ``input_a`` / ``input_b`` builders.
         copy: When exactly ``True``, a deep copy of this builder is updated
             and returned; otherwise this builder is updated in place.

     Returns:
         The builder that was updated (``self`` or its deep copy).
     """
     target = self
     if copy is True:
         target = deepcopy(self)

     # The per-scope segment builder sees the token before the length check.
     target.input_segment_builders[input_token.scope].add_input_token(
         input_token=input_token
     )

     limit = target.max_supported_input_length
     shifted = target.input_token_max_position_pointer + input_token.position
     next_pointer = shifted + 1

     # Guard: over the limit -> warn and leave the other builders untouched.
     if limit is not None and next_pointer > limit:
         message = (
             "input token tensor has been truncated to {} tokens, "
             "original length was {} tokens".format(limit, next_pointer)
         )
         logger.warning(message)
         return target

     positioned = replace(input_token, position=Pos(shifted))
     target.input_token_max_position_pointer = next_pointer

     # The two id builders receive the token with its value mapped to an id.
     token_a = replace(
         positioned, value=self.input_a_str_to_id(positioned.value)
     )
     target.input_a_builder.add_token(token=token_a)
     token_b = replace(
         positioned, value=self.input_b_str_to_id(positioned.value)
     )
     target.input_b_builder.add_token(token=token_b)

     # These builders all take the positioned token unchanged.
     for sub_builder in (
         target.input_key_padding_mask_builder,
         target.input_attention_mask_builder,
         target.input_token_type_ids_builder,
         target.input_position_ids_builder,
     ):
         sub_builder.add_token(token=positioned)

     target.input_to_source_gather_index_builder.add_input_token(
         input_token=positioned
     )
     target.source_relations_builder.add_input_token(input_token=positioned)
     return target
コード例 #5
0
ファイル: duorat.py プロジェクト: ThanThoai/Text2SqlVN
 def add_action_token(
     self,
     action_token: ActionToken[Action],
     copy: bool = False,
 ) -> "DuoRATDecoderItemBuilder":
     """Append an action token and advance the parsing state.

     The token's position is shifted by the builder's running
     ``action_token_max_position_pointer``; the pointer is then moved to one
     past the shifted position. An ``ActionInfo`` is built from the token's
     value and the current parsing state's ``parent_pos`` and
     ``frontier_field``, wrapped in an ``ActionInfoToken`` and passed to
     ``_add_positioned_action_info_token``. Unless the token's value equals
     ``MaskAction()``, the parsing state is advanced with ``cont``.

     Args:
         action_token: Token to add. Its ``position`` is treated as an
             offset relative to the current pointer.
         copy: When exactly ``True``, a deep copy of this builder is updated
             and returned; otherwise this builder is updated in place.

     Returns:
         The builder that received the token (``self`` or its deep copy).

     Raises:
         ValueError: If the parsing state is ``Done``, or is neither
             ``Done`` nor ``Partial``.
     """
     target = self
     if copy is True:
         target = deepcopy(self)

     state = target.parsing_result
     if isinstance(state, Done):
         raise ValueError("A complete action sequence cannot be continued")
     if not isinstance(state, Partial):
         raise ValueError("Invalid parsing state: {}".format(state))

     # Shifted position of the new token; the pointer ends up one past it.
     shifted = target.action_token_max_position_pointer + action_token.position
     positioned = replace(action_token, position=Pos(shifted))
     target.action_token_max_position_pointer = shifted + 1

     # Built from the parsing state as it is *before* the state is advanced.
     info = ActionInfo(
         action=positioned.value,
         parent_pos=state.parent_pos,
         frontier_field=state.frontier_field,
     )

     # Mask actions are recorded but never fed to the parser.
     if positioned.value != MaskAction():
         target.parsing_result = state.cont(
             positioned.position, positioned.value
         )

     info_token = ActionInfoToken(
         key=info,
         value=info,
         position=positioned.position,
         scope=positioned.scope,
     )
     target._add_positioned_action_info_token(
         positioned_action_info_token=info_token
     )
     return target
コード例 #6
0
class ActionInfoToken(Token[ActionInfo, VT]):
    """Token whose key is an ``ActionInfo`` and whose value has type ``VT``.

    NOTE(review): no decorator is visible in this excerpt; the annotated
    fields with a default suggest this is a dataclass -- confirm against the
    full file.
    """

    # Action info identifying the token.
    key: ActionInfo
    # Payload carried by the token.
    value: VT
    # Attention scope of the token.
    scope: AttentionScope
    # Position of the token; defaults to ``Pos(0)``.
    position: Pos = Pos(0)
コード例 #7
0
class TableToken(Token[TableId, VT]):
    """Token whose key is a ``TableId`` and whose value has type ``VT``.

    NOTE(review): no decorator is visible in this excerpt; the annotated
    fields with a default suggest this is a dataclass -- confirm against the
    full file.
    """

    # Identifier of the table this token is keyed by.
    key: TableId
    # Payload carried by the token.
    value: VT
    # Attention scope of the token.
    scope: AttentionScope
    # Position of the token; defaults to ``Pos(0)``.
    position: Pos = Pos(0)
コード例 #8
0
class ColumnToken(Token[ColumnId, VT]):
    """Token whose key is a ``ColumnId`` and whose value has type ``VT``.

    NOTE(review): no decorator is visible in this excerpt; the annotated
    fields with a default suggest this is a dataclass -- confirm against the
    full file.
    """

    # Identifier of the column this token is keyed by.
    key: ColumnId
    # Payload carried by the token.
    value: VT
    # Attention scope of the token.
    scope: AttentionScope
    # Position of the token; defaults to ``Pos(0)``.
    position: Pos = Pos(0)