Ejemplo n.º 1
0
 def get_continuations(
     self,
     beam_hypotheses: List[DuoRATHypothesis],
     step: int,
     memory: torch.Tensor,
     question_position_map: Dict[Any, Deque[Any]],
     columns_position_map: Dict[Any, Deque[Any]],
     tables_position_map: Dict[Any, Deque[Any]],
     grammar_constrained_inference: bool,
 ):
     """Collect the scored one-token continuations of every beam hypothesis.

     A ``MaskAction`` placeholder token is appended to a *copy* of each
     hypothesis' builder, so the hypotheses themselves are not modified. The
     copies are batched and decoded once against ``memory``, which is
     expanded along its first dimension to one entry per hypothesis. The
     resulting copy/generate log-probs and pointer/output logits are then
     scored hypothesis by hypothesis via ``get_hyp_continuations``.

     Returns:
         A flat list with the candidates of all hypotheses, in beam order.
     """
     device = next(self.parameters()).device

     def masked_item(hyp):
         # A fresh placeholder per hypothesis. copy=True keeps the MASK
         # action out of the hypothesis' own builder.
         placeholder = ActionToken(
             key=MaskAction(),
             value=MaskAction(),
             scope=AttentionScope(scope_name=AttentionScopeName.TARGET),
         )
         extended = hyp.beam_builder.add_action_token(
             action_token=placeholder, copy=True)
         return extended.build(device=device)

     decoder_batch = duo_rat_decoder_batch(
         items=list(map(masked_item, beam_hypotheses)))

     # presumably memory has a leading batch dim of 1 — expand() needs that
     num_hypotheses = len(beam_hypotheses)
     expanded_memory = memory.expand(num_hypotheses, -1, -1)
     output = self._decode(memory=expanded_memory, batch=decoder_batch)

     p_copy_gen_logprobs = self.copy_logprob(output)
     batch_size, seq_len = decoder_batch.target.shape
     assert p_copy_gen_logprobs.shape == (batch_size, seq_len, 2)
     assert not torch.isnan(p_copy_gen_logprobs).any()

     copy_logits = self.pointer_network(query=output, keys=expanded_memory)
     gen_logits = self.out_proj(output)

     continuations = []
     for idx, hyp in enumerate(beam_hypotheses):
         assert isinstance(hyp.beam_builder.parsing_result, Partial)
         continuations.extend(
             self.get_hyp_continuations(
                 decoder_batch=decoder_batch,
                 copy_logits=copy_logits,
                 gen_logits=gen_logits,
                 p_copy_gen_logprobs=p_copy_gen_logprobs,
                 hypothesis_id=idx,
                 hypothesis=hyp,
                 step=step,
                 question_position_map=question_position_map,
                 columns_position_map=columns_position_map,
                 tables_position_map=tables_position_map,
                 grammar_constrained_inference=grammar_constrained_inference,
             ))
     return continuations
Ejemplo n.º 2
0
 def items_to_duo_rat_batch(
         self, preproc_items: List[RATPreprocItem]) -> DuoRATBatch:
     """Convert preprocessed items into a model-ready ``DuoRATBatch``.

     The decoder half is batched and then passed through
     ``mask_duo_rat_decoder_batch``, configured with this object's relation
     types, mask-sampling config and the vocab id of ``MaskAction``. The
     encoder half is batched unchanged.
     """
     duorat_items = self.preproc_items_to_duorat_items(preproc_items)

     decoder_items = tuple(item.decoder_item for item in duorat_items)
     unmasked_decoder_batch = duo_rat_decoder_batch(items=decoder_items)
     mask_id = self.preproc.target_vocab[MaskAction()]
     masked_decoder_batch = mask_duo_rat_decoder_batch(
         batch=unmasked_decoder_batch,
         action_relation_types=self.target_relation_types,
         memory_relation_types=self.memory_relation_types,
         mask_sampling_config=self.mask_sampling_config,
         mask_value=mask_id,
     )

     encoder_items = tuple(item.encoder_item for item in duorat_items)
     encoder_batch = duo_rat_encoder_batch(items=encoder_items)

     return DuoRATBatch(
         encoder_batch=encoder_batch,
         decoder_batch=masked_decoder_batch,
     )
Ejemplo n.º 3
0
 def build(self, device: torch.device) -> DuoRATDecoderItem:
     """Materialise every sub-builder into tensors on ``device``.

     Two right-shifted tensors are derived in addition to the built fields:

     * ``shifted_target``: the vocab id of ``MaskAction`` followed by
       ``target`` without its last element;
     * ``shifted_memory_relations``: an all-zero row followed by
       ``memory_relations`` without its last row.

     Returns:
         A ``DuoRATDecoderItem``. Both ``masked_target`` and ``target`` are
         set to the same built target tensor here.
     """
     target = self.target_builder.build(device=device)
     mask_id = self.target_vocab[MaskAction()]
     # Prepend MASK and drop the final element (shift right by one).
     shifted_target = torch.cat([target.new_full((1,), mask_id), target[:-1]])

     memory_relations = self.memory_relations_builder.build(device=device)
     zero_row = memory_relations.new_zeros((1, memory_relations.shape[1]))
     shifted_memory_relations = torch.cat([zero_row, memory_relations[:-1]])

     frontier_fields = self.frontier_fields_builder.build(device=device)
     frontier_field_types = self.frontier_field_types_builder.build(
         device=device)
     target_relations = self.target_relations_builder.build(device=device)
     target_attention_mask = self.target_attention_mask_builder.build(
         device=device)
     target_key_padding_mask = self.target_key_padding_mask_builder.build(
         device=device)
     memory_attention_mask = self.memory_attention_mask_builder.build(
         device=device)
     memory_key_padding_mask = self.memory_key_padding_mask_builder.build(
         device=device)
     valid_copy_mask = self.valid_copy_mask_builder.build(device=device)
     copy_target_mask = self.copy_target_mask_builder.build(device=device)
     valid_actions_mask = self.valid_actions_mask_builder.build(device=device)

     return DuoRATDecoderItem(
         masked_target=target,
         shifted_target=shifted_target,
         frontier_fields=frontier_fields,
         frontier_field_types=frontier_field_types,
         target_relations=target_relations,
         target_attention_mask=target_attention_mask,
         target_key_padding_mask=target_key_padding_mask,
         memory_relations=memory_relations,
         shifted_memory_relations=shifted_memory_relations,
         memory_attention_mask=memory_attention_mask,
         memory_key_padding_mask=memory_key_padding_mask,
         valid_copy_mask=valid_copy_mask,
         copy_target_mask=copy_target_mask,
         valid_actions_mask=valid_actions_mask,
         target=target,
     )
Ejemplo n.º 4
0
    def save(self) -> None:
        """Save the examples, then build and pickle the target action vocabulary.

        Every grammar production (as an ``ApplyRuleAction``), plus
        ``ReduceAction`` and ``MaskAction``, is given a count of at least
        ``self.min_freq`` in ``self.target_vocab_counter``. Because the vocab
        is built with ``min_freq=self.min_freq``, this presumably guarantees
        that these actions pass the frequency cut-off even if never observed.
        The resulting ``ActionVocab`` (``UNK`` special, ``max_size=50000``) is
        stored on ``self.target_vocab`` and pickled to
        ``self.target_vocab_path``.
        """
        self.save_examples()

        # Mandatory actions: all production rules, Reduce and MASK. Other
        # counter entries (e.g. GenToken actions gathered during
        # preprocessing — presumably those not copyable from the encoder
        # sequence) are left as they are.
        mandatory_actions = itertools.chain(
            (
                ApplyRuleAction(production=production)
                for production in self.transition_system.grammar.id2prod.values()
            ),
            (ReduceAction(), MaskAction()),
        )
        for action in mandatory_actions:
            # Raise the count to min_freq without clobbering a higher observed
            # count. Overwriting it would throw away real frequency
            # information that affects vocab ordering/truncation.
            self.target_vocab_counter[action] = max(
                self.target_vocab_counter.get(action, 0), self.min_freq)

        self.target_vocab = ActionVocab(
            counter=self.target_vocab_counter,
            max_size=50000,
            min_freq=self.min_freq,
            specials=[ActionVocab.UNK],
        )
        with open(self.target_vocab_path, "wb") as f:
            pickle.dump(self.target_vocab, f)
Ejemplo n.º 5
0
 def add_action_token(
     self,
     action_token: ActionToken[Action],
     copy: bool = False,
 ) -> "DuoRATDecoderItemBuilder":
     """Append ``action_token`` to this builder, or to a deep copy of it.

     The token's ``position`` is treated as an offset from
     ``action_token_max_position_pointer``. The pointer is then advanced past
     it (offset + 1). Unless the token's value is ``MaskAction()``, the
     parsing state is advanced with ``cont``. The recorded ``ActionInfo``
     uses the parent position and frontier field of the state *before* that
     advance.

     Args:
         action_token: the action token to append.
         copy: if it is exactly ``True`` (identity check), operate on a deep
             copy and leave ``self`` untouched.

     Returns:
         The builder that received the token (``self`` or the copy).

     Raises:
         ValueError: if parsing is already ``Done`` or in an unknown state.
     """
     recipient = deepcopy(self) if copy is True else self
     state = recipient.parsing_result
     if isinstance(state, Done):
         raise ValueError("A complete action sequence cannot be continued")
     if not isinstance(state, Partial):
         raise ValueError("Invalid parsing state: {}".format(state))

     base = recipient.action_token_max_position_pointer
     offset = action_token.position
     positioned = replace(action_token, position=Pos(base + offset))
     recipient.action_token_max_position_pointer = base + offset + 1

     # Snapshot parent/frontier before the parser state moves on.
     info = ActionInfo(
         action=positioned.value,
         parent_pos=state.parent_pos,
         frontier_field=state.frontier_field,
     )
     # A MASK placeholder is never fed to the parser.
     if positioned.value != MaskAction():
         recipient.parsing_result = state.cont(positioned.position,
                                               positioned.value)

     recipient._add_positioned_action_info_token(
         positioned_action_info_token=ActionInfoToken(
             key=info,
             value=info,
             position=positioned.position,
             scope=positioned.scope,
         ))
     return recipient
Ejemplo n.º 6
0
 def get_hyp_continuations(
     self,
     decoder_batch: DuoRATDecoderBatch,
     copy_logits: torch.Tensor,
     gen_logits: torch.Tensor,
     p_copy_gen_logprobs: torch.Tensor,
     hypothesis_id: int,
     hypothesis: DuoRATHypothesis,
     step: int,
     question_position_map: Dict[Any, Deque[Any]],
     columns_position_map: Dict[Any, Deque[Any]],
     tables_position_map: Dict[Any, Deque[Any]],
     grammar_constrained_inference: bool,
 ):
     """Score every one-token continuation of a single beam hypothesis.

     Two kinds of candidates are produced:

     * copy candidates: only when the hypothesis' frontier field has an
       ``ASDLPrimitiveType``. There is one per token value in the question,
       column and table position maps. The score is the log-sum-exp of the
       copy log-probs over all positions listed for that token, plus
       ``p_copy_gen_logprobs[hypothesis_id, step, 0]``.
     * vocab candidates: one per target-vocab action, never ``MaskAction``.
       Under grammar-constrained inference only actions allowed by
       ``valid_actions_mask`` are used. The score is the generation log-prob
       plus ``p_copy_gen_logprobs[hypothesis_id, step, 1]``.

     Each candidate's score is offset by ``hypothesis.score``. Under
     grammar-constrained inference, disallowed copy positions/actions are set
     to -inf before the log-softmax.

     Only row ``[hypothesis_id, step]`` of the logits is normalised. The
     log-softmax is taken over the last dimension, so this equals
     normalising the whole tensor and then indexing, without the
     O(batch * seq) extra work per call.

     NOTE(review): candidate filtering reads the masks at position ``-1``
     while scores are read at ``step``. These agree only if ``step`` is the
     last decoded position — confirm with the caller.

     Returns:
         A list of ``Candidate`` objects.
     """
     parsing_result = hypothesis.beam_builder.parsing_result
     assert isinstance(parsing_result, Partial)
     continuations = []

     # Copy continuations
     frontier_field = parsing_result.frontier_field
     if frontier_field is not None and isinstance(frontier_field.type,
                                                  ASDLPrimitiveType):
         copy_log_probs = self._hyp_row_log_probs(
             logits=copy_logits,
             mask=decoder_batch.valid_copy_mask,
             hypothesis_id=hypothesis_id,
             step=step,
             apply_mask=grammar_constrained_inference,
         )
         copy_logprob = p_copy_gen_logprobs[hypothesis_id, step, 0]
         current_copy_mask = decoder_batch.valid_copy_mask[hypothesis_id, -1]
         gen_token_action = self.preproc.transition_system.get_gen_token_action(
             primitive_type=frontier_field.type)
         for position_map in (
                 question_position_map,
                 columns_position_map,
                 tables_position_map,
         ):
             for token_value, positions in position_map.items():
                 index = list(positions)
                 if (grammar_constrained_inference
                         and not bool(current_copy_mask[index].any())):
                     continue
                 score = (torch.logsumexp(copy_log_probs[index], dim=0) +
                          copy_logprob)
                 continuations.append(
                     Candidate(
                         token=gen_token_action(token=token_value),
                         score=hypothesis.score + score.item(),
                         prev_hypothesis=hypothesis,
                     ))

     # Vocab continuations
     gen_log_probs = self._hyp_row_log_probs(
         logits=gen_logits,
         mask=decoder_batch.valid_actions_mask,
         hypothesis_id=hypothesis_id,
         step=step,
         apply_mask=grammar_constrained_inference,
     )
     gen_logprob = p_copy_gen_logprobs[hypothesis_id, step, 1]
     if grammar_constrained_inference:
         action_ids = (decoder_batch.valid_actions_mask[hypothesis_id, -1]
                       .nonzero(as_tuple=False).flatten().tolist())
     else:
         action_ids = range(decoder_batch.valid_actions_mask.shape[2])
     itos = self.preproc.target_vocab.itos
     for action_id in action_ids:
         action = itos[action_id]
         # Never continue with a MaskAction.
         if action == MaskAction():
             continue
         score = gen_log_probs[action_id] + gen_logprob
         continuations.append(
             Candidate(
                 token=action,
                 score=hypothesis.score + score.item(),
                 prev_hypothesis=hypothesis,
             ))
     return continuations

 @staticmethod
 def _hyp_row_log_probs(
     logits: torch.Tensor,
     mask: torch.Tensor,
     hypothesis_id: int,
     step: int,
     apply_mask: bool,
 ) -> torch.Tensor:
     """Return log_softmax over the last dim of ``logits[hypothesis_id, step]``.

     If ``apply_mask`` is true, entries where ``mask`` is False are first set
     to -inf. The mask is broadcast to ``logits``' shape, exactly as
     ``masked_fill`` on the full tensor would do.
     """
     row = logits[hypothesis_id, step]
     if apply_mask:
         row_mask = mask.expand_as(logits)[hypothesis_id, step].to(
             device=row.device)
         row = row.masked_fill(mask=~row_mask, value=float("-inf"))
     return F.log_softmax(row, dim=-1)
Ejemplo n.º 7
0
def mask_action() -> MaskAction:
    """Create and return a new ``MaskAction`` instance."""
    action = MaskAction()
    return action