コード例 #1
0
def table_input_tokens(
        sql_schema: SQLSchema, table_id: TableId,
        scoping: Scoping) -> Generator[TableToken[str], None, None]:
    """Produce one input ``TableToken`` per tokenized piece of a table name.

    The pieces are read from ``sql_schema.tokenized_table_names[table_id]``.
    Every token uses ``table_id`` as its key, the piece as its value, and a
    freshly constructed ``AttentionScope`` chosen from ``scoping``:

    * ``NoScoping``     -> ``AttentionScopeName.INPUT``
    * ``CoarseScoping`` -> ``AttentionScopeName.SCHEMA``
    * ``FineScoping``   -> ``AttentionScopeName.TABLE`` extended by ``table_id``

    The scoping check and the schema lookup run when this function is called;
    only the token objects themselves are built lazily.

    Raises:
        NotImplementedError: if ``scoping`` is none of the types listed above.
    """
    if isinstance(scoping, NoScoping):
        scope_kwargs = {"scope_name": AttentionScopeName.INPUT}
    elif isinstance(scoping, CoarseScoping):
        scope_kwargs = {"scope_name": AttentionScopeName.SCHEMA}
    elif isinstance(scoping, FineScoping):
        scope_kwargs = {
            "scope_name": AttentionScopeName.TABLE,
            "scope_extension": table_id,
        }
    else:
        raise NotImplementedError

    name_pieces = sql_schema.tokenized_table_names[table_id]
    # A new AttentionScope is instantiated for each emitted token.
    return (
        TableToken(key=table_id,
                   value=piece,
                   scope=AttentionScope(**scope_kwargs))
        for piece in name_pieces
    )
コード例 #2
0
def column_input_tokens(
        sql_schema: SQLSchema, column_id: ColumnId,
        scoping: Scoping) -> Generator[ColumnToken[str], None, None]:
    """Produce one input ``ColumnToken`` per tokenized piece of a column name.

    The pieces are read from ``sql_schema.tokenized_column_names[column_id]``.
    Every token uses ``column_id`` as its key, the piece as its value, and a
    freshly constructed ``AttentionScope`` chosen from ``scoping``:

    * ``NoScoping``     -> ``AttentionScopeName.INPUT``
    * ``CoarseScoping`` -> ``AttentionScopeName.SCHEMA``
    * ``FineScoping``   -> ``AttentionScopeName.COLUMN`` extended by
      ``column_id``

    The scoping check and the schema lookup run when this function is called;
    only the token objects themselves are built lazily.

    Raises:
        NotImplementedError: if ``scoping`` is none of the types listed above.
    """
    if isinstance(scoping, NoScoping):
        scope_kwargs = {"scope_name": AttentionScopeName.INPUT}
    elif isinstance(scoping, CoarseScoping):
        scope_kwargs = {"scope_name": AttentionScopeName.SCHEMA}
    elif isinstance(scoping, FineScoping):
        scope_kwargs = {
            "scope_name": AttentionScopeName.COLUMN,
            "scope_extension": column_id,
        }
    else:
        raise NotImplementedError

    name_pieces = sql_schema.tokenized_column_names[column_id]
    # A new AttentionScope is instantiated for each emitted token.
    return (
        ColumnToken(key=column_id,
                    value=piece,
                    scope=AttentionScope(**scope_kwargs))
        for piece in name_pieces
    )
コード例 #3
0
def table_source_tokens(sql_schema: SQLSchema, table_id: TableId,
                        scoping: Scoping) -> Tuple[TableToken[TableId]]:
    """Return a one-element tuple holding the source token for a table.

    The token uses ``table_id`` as both key and value. Its attention scope
    depends on ``scoping``:

    * ``NoScoping``     -> ``AttentionScopeName.SOURCE``
    * ``CoarseScoping`` -> ``AttentionScopeName.SCHEMA``
    * ``FineScoping``   -> ``AttentionScopeName.TABLE`` extended by ``table_id``

    ``sql_schema`` is accepted but not read by this function.

    Raises:
        NotImplementedError: if ``scoping`` is none of the types listed above.
    """
    if isinstance(scoping, NoScoping):
        token_scope = AttentionScope(scope_name=AttentionScopeName.SOURCE)
    elif isinstance(scoping, CoarseScoping):
        token_scope = AttentionScope(scope_name=AttentionScopeName.SCHEMA)
    elif isinstance(scoping, FineScoping):
        token_scope = AttentionScope(scope_name=AttentionScopeName.TABLE,
                                     scope_extension=table_id)
    else:
        raise NotImplementedError

    source_token = TableToken(key=table_id, value=table_id, scope=token_scope)
    return (source_token, )
コード例 #4
0
def column_source_tokens(sql_schema: SQLSchema, column_id: ColumnId,
                         scoping: Scoping) -> Tuple[ColumnToken[ColumnId]]:
    """Return a one-element tuple holding the source token for a column.

    The token uses ``column_id`` as both key and value. Its attention scope
    depends on ``scoping``:

    * ``NoScoping``     -> ``AttentionScopeName.SOURCE``
    * ``CoarseScoping`` -> ``AttentionScopeName.SCHEMA``
    * ``FineScoping``   -> ``AttentionScopeName.COLUMN`` extended by
      ``column_id``

    ``sql_schema`` is accepted but not read by this function.

    Raises:
        NotImplementedError: if ``scoping`` is none of the types listed above.
    """
    if isinstance(scoping, NoScoping):
        token_scope = AttentionScope(scope_name=AttentionScopeName.SOURCE)
    elif isinstance(scoping, CoarseScoping):
        token_scope = AttentionScope(scope_name=AttentionScopeName.SCHEMA)
    elif isinstance(scoping, FineScoping):
        token_scope = AttentionScope(scope_name=AttentionScopeName.COLUMN,
                                     scope_extension=column_id)
    else:
        raise NotImplementedError

    source_token = ColumnToken(key=column_id,
                               value=column_id,
                               scope=token_scope)
    return (source_token, )
コード例 #5
0
ファイル: tokens.py プロジェクト: zhaoxlpku/duorat
def action_tokens(
    actions: Iterable[Action], scoping: Scoping
) -> Generator[ActionToken[Action], None, None]:
    """Wrap each action in an ``ActionToken`` scoped to the target sequence.

    Each action serves as both key and value of its token. All tokens share a
    single ``AttentionScope`` named ``AttentionScopeName.TARGET``.

    Raises:
        NotImplementedError: if ``scoping`` is not a ``NoScoping`` instance
            (raised at call time, before any token is produced).
    """
    # Only the unscoped variant is supported for target-side tokens.
    if not isinstance(scoping, NoScoping):
        raise NotImplementedError

    target_scope = AttentionScope(scope_name=AttentionScopeName.TARGET)
    return (
        ActionToken(key=act, value=act, scope=target_scope)
        for act in actions
    )
コード例 #6
0
def question_input_tokens(
        question: Iterable[PreprocQuestionToken],
        scoping: Scoping) -> Generator[QuestionToken[str], None, None]:
    """Convert preprocessed question tokens into scoped ``QuestionToken``s.

    Each output token copies ``key``, ``value`` and ``match_tags`` from its
    input token and also stores the input ``value`` as ``raw_value``. All
    output tokens share one ``AttentionScope`` whose name depends on
    ``scoping``:

    * ``NoScoping``                      -> ``AttentionScopeName.INPUT``
    * ``CoarseScoping`` / ``FineScoping`` -> ``AttentionScopeName.QUESTION``

    Raises:
        NotImplementedError: if ``scoping`` is none of the types listed above
            (raised at call time, before any token is produced).
    """
    if isinstance(scoping, NoScoping):
        scope_name = AttentionScopeName.INPUT
    elif isinstance(scoping, (CoarseScoping, FineScoping)):
        scope_name = AttentionScopeName.QUESTION
    else:
        raise NotImplementedError

    shared_scope = AttentionScope(scope_name=scope_name)
    return (
        QuestionToken(
            key=tok.key,
            value=tok.value,
            scope=shared_scope,
            raw_value=tok.value,
            match_tags=tok.match_tags,
        )
        for tok in question
    )
コード例 #7
0
ファイル: duorat.py プロジェクト: zhaoxlpku/duorat
 def get_continuations(
     self,
     beam_hypotheses: List[DuoRATHypothesis],
     step: int,
     memory: torch.Tensor,
     question_position_map: Dict[Any, Deque[Any]],
     columns_position_map: Dict[Any, Deque[Any]],
     tables_position_map: Dict[Any, Deque[Any]],
     grammar_constrained_inference: bool,
 ):
     """Gather the candidate continuations of every hypothesis in the beam.

     A mask action token is appended to a copy of each hypothesis' builder,
     the resulting items are batched and decoded against ``memory``, and the
     copy/generate outputs are handed to ``self.get_hyp_continuations`` once
     per hypothesis.

     Args:
         beam_hypotheses: hypotheses currently on the beam.
         step: forwarded unchanged to ``get_hyp_continuations``.
         memory: encoder memory, expanded along its first dimension to
             ``len(beam_hypotheses)``.
             NOTE(review): presumably shaped (1, seq, dim) - confirm.
         question_position_map: forwarded to ``get_hyp_continuations``.
         columns_position_map: forwarded to ``get_hyp_continuations``.
         tables_position_map: forwarded to ``get_hyp_continuations``.
         grammar_constrained_inference: forwarded to
             ``get_hyp_continuations``.

     Returns:
         A list concatenating, in beam order, what
         ``get_hyp_continuations`` returns for each hypothesis.
     """
     device = next(self.parameters()).device

     # The mask action is added to a *copy* of each builder (copy=True) so
     # that the hypotheses' own builders are not affected by this step.
     masked_items = []
     for hyp in beam_hypotheses:
         masked_builder = hyp.beam_builder.add_action_token(
             action_token=ActionToken(
                 key=MaskAction(),
                 value=MaskAction(),
                 scope=AttentionScope(scope_name=AttentionScopeName.TARGET),
             ),
             copy=True,
         )
         masked_items.append(masked_builder.build(device=device))
     decoder_batch = duo_rat_decoder_batch(items=masked_items)

     beam_width = len(beam_hypotheses)
     expanded_memory = memory.expand(beam_width, -1, -1)
     output = self._decode(memory=expanded_memory, batch=decoder_batch)

     p_copy_gen_logprobs = self.copy_logprob(output)
     batch_size, seq_len = decoder_batch.target.shape
     assert p_copy_gen_logprobs.shape == (batch_size, seq_len, 2)
     assert not torch.isnan(p_copy_gen_logprobs).any()

     copy_logits = self.pointer_network(query=output, keys=expanded_memory)
     gen_logits = self.out_proj(output)

     # Arguments that are identical for every hypothesis in the beam.
     shared_kwargs = dict(
         decoder_batch=decoder_batch,
         copy_logits=copy_logits,
         gen_logits=gen_logits,
         p_copy_gen_logprobs=p_copy_gen_logprobs,
         step=step,
         question_position_map=question_position_map,
         columns_position_map=columns_position_map,
         tables_position_map=tables_position_map,
         grammar_constrained_inference=grammar_constrained_inference,
     )

     # For each hypothesis, record all possible continuations.
     continuations = []
     for hypothesis_id, hypothesis in enumerate(beam_hypotheses):
         assert isinstance(hypothesis.beam_builder.parsing_result, Partial)
         continuations.extend(
             self.get_hyp_continuations(
                 hypothesis_id=hypothesis_id,
                 hypothesis=hypothesis,
                 **shared_kwargs,
             ))
     return continuations
コード例 #8
0
ファイル: duorat.py プロジェクト: zhaoxlpku/duorat
    def parse_decode(
        self,
        encoder_item_builder: DuoRATEncoderItemBuilder,
        memory: torch.Tensor,
        beam_size: int,
        decode_max_time_step: int,
        grammar_constrained_inference: bool,
    ) -> List[FinishedBeam]:
        """Run beam search over decoder actions and return the resulting beams.

        First, the positions of all source tokens are indexed: question tokens
        by ``raw_value``, column and table tokens by ``value``. Then
        ``beam_search`` is started from an empty hypothesis, using
        ``self.get_continuations`` to propose candidates and extending a copy
        of the previous hypothesis' builder for every accepted candidate.

        Args:
            encoder_item_builder: supplies ``positioned_source_tokens``.
            memory: encoder memory forwarded to ``get_continuations``.
            beam_size: forwarded to ``beam_search``.
            decode_max_time_step: forwarded to ``beam_search``.
            grammar_constrained_inference: forwarded to
                ``get_continuations``.

        Returns:
            One ``FinishedBeam`` per hypothesis returned by ``beam_search``,
            built from the hypothesis' ``parsing_result.res`` and ``score``.

        Raises:
            ValueError: if a source token is neither a ``QuestionToken``, a
                ``ColumnToken`` nor a ``TableToken``.
        """
        question_position_map: Dict[Any, Deque[Any]] = defaultdict(deque)
        columns_position_map: Dict[Any, Deque[Any]] = defaultdict(deque)
        tables_position_map: Dict[Any, Deque[Any]] = defaultdict(deque)

        source_tokens = encoder_item_builder.positioned_source_tokens
        for src_token in source_tokens:
            # Pick the deque that collects positions for this token's lookup key.
            if isinstance(src_token, QuestionToken):
                positions = question_position_map[src_token.raw_value]
            elif isinstance(src_token, ColumnToken):
                positions = columns_position_map[src_token.value]
            elif isinstance(src_token, TableToken):
                positions = tables_position_map[src_token.value]
            else:
                raise ValueError("Unsupported token type: {}".format(
                    src_token.__repr__()))
            positions.append(src_token.position)

        empty_builder = DuoRATDecoderItemBuilder(
            positioned_source_tokens=encoder_item_builder.
            positioned_source_tokens,
            target_vocab=self.preproc.target_vocab,
            transition_system=self.preproc.transition_system,
            allow_unk=False,
            source_attention_scoping=self.source_attention_scoping,
            target_attention_scoping=self.target_attention_scoping,
            target_relation_types=self.target_relation_types,
            memory_relation_types=self.memory_relation_types,
        )
        initial_hypothesis = DuoRATHypothesis(
            beam_builder=empty_builder,
            scores=[],
            tokens=[],
        )

        def extend_hypothesis(candidate):
            # Grow a copy of the parent's builder by the candidate's token so
            # that the parent hypothesis itself is left unchanged.
            parent = candidate.prev_hypothesis
            grown_builder = parent.beam_builder.add_action_token(
                action_token=ActionToken(
                    key=candidate.token,
                    value=candidate.token,
                    scope=AttentionScope(
                        scope_name=AttentionScopeName.TARGET),
                ),
                copy=True,
            )
            return DuoRATHypothesis(
                beam_builder=grown_builder,
                score=candidate.score,
                tokens=candidate.prev_hypothesis.tokens + [candidate.token],
                scores=candidate.prev_hypothesis.scores + [candidate.score],
            )

        propose_continuations = partial(
            self.get_continuations,
            memory=memory,
            question_position_map=question_position_map,
            columns_position_map=columns_position_map,
            tables_position_map=tables_position_map,
            grammar_constrained_inference=grammar_constrained_inference,
        )

        res = beam_search(
            initial_hypothesis,
            beam_size,
            decode_max_time_step,
            get_new_hypothesis=extend_hypothesis,
            get_continuations=propose_continuations,
        )

        finished_beams = []
        for final_hypothesis in res:
            finished_beams.append(
                FinishedBeam(
                    ast=final_hypothesis.beam_builder.parsing_result.res,
                    score=final_hypothesis.score))
        return finished_beams