def table_input_tokens(
    sql_schema: SQLSchema, table_id: TableId, scoping: Scoping
) -> Generator[TableToken[str], None, None]:
    # Choose the attention scope for the table's name tokens based on the scoping mode.
    if isinstance(scoping, NoScoping):
        make_scope = lambda _table_id: AttentionScope(
            scope_name=AttentionScopeName.INPUT
        )
    elif isinstance(scoping, CoarseScoping):
        make_scope = lambda _table_id: AttentionScope(
            scope_name=AttentionScopeName.SCHEMA
        )
    elif isinstance(scoping, FineScoping):
        make_scope = lambda table_id: AttentionScope(
            scope_name=AttentionScopeName.TABLE, scope_extension=table_id
        )
    else:
        raise NotImplementedError
    # One TableToken per sub-token of the tokenized table name.
    return (
        TableToken(key=table_id, value=s, scope=make_scope(table_id))
        for s in sql_schema.tokenized_table_names[table_id]
    )
def column_input_tokens(
    sql_schema: SQLSchema, column_id: ColumnId, scoping: Scoping
) -> Generator[ColumnToken[str], None, None]:
    # Choose the attention scope for the column's name tokens based on the scoping mode.
    if isinstance(scoping, NoScoping):
        make_scope = lambda _column_id: AttentionScope(
            scope_name=AttentionScopeName.INPUT
        )
    elif isinstance(scoping, CoarseScoping):
        make_scope = lambda _column_id: AttentionScope(
            scope_name=AttentionScopeName.SCHEMA
        )
    elif isinstance(scoping, FineScoping):
        make_scope = lambda column_id: AttentionScope(
            scope_name=AttentionScopeName.COLUMN, scope_extension=column_id
        )
    else:
        raise NotImplementedError
    # One ColumnToken per sub-token of the tokenized column name.
    return (
        ColumnToken(key=column_id, value=s, scope=make_scope(column_id))
        for s in sql_schema.tokenized_column_names[column_id]
    )
def table_source_tokens(
    sql_schema: SQLSchema, table_id: TableId, scoping: Scoping
) -> Tuple[TableToken[TableId]]:
    # A single source token per table whose value is the table id itself.
    if isinstance(scoping, NoScoping):
        make_scope = lambda _table_id: AttentionScope(
            scope_name=AttentionScopeName.SOURCE
        )
    elif isinstance(scoping, CoarseScoping):
        make_scope = lambda _table_id: AttentionScope(
            scope_name=AttentionScopeName.SCHEMA
        )
    elif isinstance(scoping, FineScoping):
        make_scope = lambda table_id: AttentionScope(
            scope_name=AttentionScopeName.TABLE, scope_extension=table_id
        )
    else:
        raise NotImplementedError
    return (TableToken(key=table_id, value=table_id, scope=make_scope(table_id)),)
def column_source_tokens(
    sql_schema: SQLSchema, column_id: ColumnId, scoping: Scoping
) -> Tuple[ColumnToken[ColumnId]]:
    # A single source token per column whose value is the column id itself.
    if isinstance(scoping, NoScoping):
        make_scope = lambda _column_id: AttentionScope(
            scope_name=AttentionScopeName.SOURCE
        )
    elif isinstance(scoping, CoarseScoping):
        make_scope = lambda _column_id: AttentionScope(
            scope_name=AttentionScopeName.SCHEMA
        )
    elif isinstance(scoping, FineScoping):
        make_scope = lambda column_id: AttentionScope(
            scope_name=AttentionScopeName.COLUMN, scope_extension=column_id
        )
    else:
        raise NotImplementedError
    return (ColumnToken(key=column_id, value=column_id, scope=make_scope(column_id)),)
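# Illustrative sketch (not part of the original module): the four helpers above
# repeat the same scoping dispatch and differ only in which scope names they map
# to. A hypothetical factored-out helper could look like this; `_scope_for` and
# its parameters are assumed names, not existing DuoRAT API.
def _scope_for(
    entity_id,
    scoping: Scoping,
    no_scoping_name: AttentionScopeName,
    fine_scope_name: AttentionScopeName,
) -> AttentionScope:
    # NoScoping: one flat scope (INPUT for name tokens, SOURCE for source tokens).
    if isinstance(scoping, NoScoping):
        return AttentionScope(scope_name=no_scoping_name)
    # CoarseScoping: all schema items share the SCHEMA scope.
    if isinstance(scoping, CoarseScoping):
        return AttentionScope(scope_name=AttentionScopeName.SCHEMA)
    # FineScoping: one scope per table/column, keyed by its id.
    if isinstance(scoping, FineScoping):
        return AttentionScope(scope_name=fine_scope_name, scope_extension=entity_id)
    raise NotImplementedError
# With such a helper, table_input_tokens could call
# _scope_for(table_id, scoping, AttentionScopeName.INPUT, AttentionScopeName.TABLE).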
def action_tokens(
    actions: Iterable[Action], scoping: Scoping
) -> Generator[ActionToken[Action], None, None]:
    if isinstance(scoping, NoScoping):
        scope = AttentionScope(scope_name=AttentionScopeName.TARGET)
    else:
        raise NotImplementedError
    return (
        ActionToken(key=action, value=action, scope=scope) for action in actions
    )
def question_input_tokens(
    question: Iterable[PreprocQuestionToken], scoping: Scoping
) -> Generator[QuestionToken[str], None, None]:
    # All question tokens share one scope: the whole input under NoScoping,
    # otherwise the dedicated QUESTION scope.
    if isinstance(scoping, NoScoping):
        scope = AttentionScope(scope_name=AttentionScopeName.INPUT)
    elif isinstance(scoping, (CoarseScoping, FineScoping)):
        scope = AttentionScope(scope_name=AttentionScopeName.QUESTION)
    else:
        raise NotImplementedError
    return (
        QuestionToken(
            key=token.key,
            value=token.value,
            scope=scope,
            raw_value=token.value,
            match_tags=token.match_tags,
        )
        for token in question
    )
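# Illustrative sketch (not part of the original module): the input-token helpers
# above are meant to be concatenated into one source sequence of question,
# column, and table tokens. `example_source_token_stream` is a hypothetical
# name, and the ordering shown here is an assumption; the actual composition is
# done by the encoder item builder and may order or interleave tokens differently.
def example_source_token_stream(
    question: Iterable[PreprocQuestionToken],
    sql_schema: SQLSchema,
    scoping: Scoping,
):
    # Question tokens first, then every column's name tokens, then every table's.
    # This assumes tokenized_column_names / tokenized_table_names are mappings
    # keyed by ColumnId / TableId.
    yield from question_input_tokens(question, scoping)
    for column_id in sql_schema.tokenized_column_names:
        yield from column_input_tokens(sql_schema, column_id, scoping)
    for table_id in sql_schema.tokenized_table_names:
        yield from table_input_tokens(sql_schema, table_id, scoping)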
def get_continuations(
    self,
    beam_hypotheses: List[DuoRATHypothesis],
    step: int,
    memory: torch.Tensor,
    question_position_map: Dict[Any, Deque[Any]],
    columns_position_map: Dict[Any, Deque[Any]],
    tables_position_map: Dict[Any, Deque[Any]],
    grammar_constrained_inference: bool,
):
    device = next(self.parameters()).device
    # Copies of the builders are made here so that adding the mask actions
    # stays confined to this batch construction:
    decoder_batch = duo_rat_decoder_batch(
        items=[
            hypothesis.beam_builder.add_action_token(
                action_token=ActionToken(
                    key=MaskAction(),
                    value=MaskAction(),
                    scope=AttentionScope(scope_name=AttentionScopeName.TARGET),
                ),
                copy=True,
            ).build(device=device)
            for hypothesis in beam_hypotheses
        ]
    )
    # Decode all hypotheses against the same (expanded) encoder memory.
    expanded_memory = memory.expand(len(beam_hypotheses), -1, -1)
    output = self._decode(memory=expanded_memory, batch=decoder_batch)
    p_copy_gen_logprobs = self.copy_logprob(output)
    batch_size, seq_len = decoder_batch.target.shape
    assert p_copy_gen_logprobs.shape == (batch_size, seq_len, 2)
    assert not torch.isnan(p_copy_gen_logprobs).any()
    copy_logits = self.pointer_network(query=output, keys=expanded_memory)
    gen_logits = self.out_proj(output)
    # For each hypothesis, record all possible continuations.
    continuations = []
    for hypothesis_id, hypothesis in enumerate(beam_hypotheses):
        assert isinstance(hypothesis.beam_builder.parsing_result, Partial)
        continuations += self.get_hyp_continuations(
            decoder_batch=decoder_batch,
            copy_logits=copy_logits,
            gen_logits=gen_logits,
            p_copy_gen_logprobs=p_copy_gen_logprobs,
            hypothesis_id=hypothesis_id,
            hypothesis=hypothesis,
            step=step,
            question_position_map=question_position_map,
            columns_position_map=columns_position_map,
            tables_position_map=tables_position_map,
            grammar_constrained_inference=grammar_constrained_inference,
        )
    return continuations
def parse_decode(
    self,
    encoder_item_builder: DuoRATEncoderItemBuilder,
    memory: torch.Tensor,
    beam_size: int,
    decode_max_time_step: int,
    grammar_constrained_inference: bool,
) -> List[FinishedBeam]:
    # Map each source token value to the positions at which it appears, so that
    # copy actions can be resolved to memory positions during decoding.
    question_position_map: Dict[Any, Deque[Any]] = defaultdict(deque)
    columns_position_map: Dict[Any, Deque[Any]] = defaultdict(deque)
    tables_position_map: Dict[Any, Deque[Any]] = defaultdict(deque)
    for positioned_source_token in encoder_item_builder.positioned_source_tokens:
        if isinstance(positioned_source_token, QuestionToken):
            question_position_map[positioned_source_token.raw_value].append(
                positioned_source_token.position
            )
        elif isinstance(positioned_source_token, ColumnToken):
            columns_position_map[positioned_source_token.value].append(
                positioned_source_token.position
            )
        elif isinstance(positioned_source_token, TableToken):
            tables_position_map[positioned_source_token.value].append(
                positioned_source_token.position
            )
        else:
            raise ValueError(
                "Unsupported token type: {}".format(repr(positioned_source_token))
            )
    # Start the beam with an empty decoder item builder.
    initial_hypothesis = DuoRATHypothesis(
        beam_builder=DuoRATDecoderItemBuilder(
            positioned_source_tokens=encoder_item_builder.positioned_source_tokens,
            target_vocab=self.preproc.target_vocab,
            transition_system=self.preproc.transition_system,
            allow_unk=False,
            source_attention_scoping=self.source_attention_scoping,
            target_attention_scoping=self.target_attention_scoping,
            target_relation_types=self.target_relation_types,
            memory_relation_types=self.memory_relation_types,
        ),
        scores=[],
        tokens=[],
    )
    # Each accepted candidate extends its parent hypothesis with one more
    # TARGET-scoped action token.
    res = beam_search(
        initial_hypothesis,
        beam_size,
        decode_max_time_step,
        get_new_hypothesis=lambda candidate: DuoRATHypothesis(
            beam_builder=candidate.prev_hypothesis.beam_builder.add_action_token(
                action_token=ActionToken(
                    key=candidate.token,
                    value=candidate.token,
                    scope=AttentionScope(scope_name=AttentionScopeName.TARGET),
                ),
                copy=True,
            ),
            score=candidate.score,
            tokens=candidate.prev_hypothesis.tokens + [candidate.token],
            scores=candidate.prev_hypothesis.scores + [candidate.score],
        ),
        get_continuations=partial(
            self.get_continuations,
            memory=memory,
            question_position_map=question_position_map,
            columns_position_map=columns_position_map,
            tables_position_map=tables_position_map,
            grammar_constrained_inference=grammar_constrained_inference,
        ),
    )
    return [
        FinishedBeam(
            ast=hypothesis.beam_builder.parsing_result.res, score=hypothesis.score
        )
        for hypothesis in res
    ]
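# Illustrative usage sketch (assumed, not from the original module): given the
# FinishedBeam list returned by parse_decode, the highest-scoring hypothesis'
# decoded AST would typically be taken as the prediction. `best_ast` is a
# hypothetical helper name.
def best_ast(finished_beams: List[FinishedBeam]):
    # parse_decode returns one FinishedBeam per surviving hypothesis; pick the
    # one with the highest score and return its AST.
    return max(finished_beams, key=lambda beam: beam.score).ast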