Example #1
0
 def get_continuations(
     self,
     beam_hypotheses: List[DuoRATHypothesis],
     step: int,
     memory: torch.Tensor,
     question_position_map: Dict[Any, Deque[Any]],
     columns_position_map: Dict[Any, Deque[Any]],
     tables_position_map: Dict[Any, Deque[Any]],
     grammar_constrained_inference: bool,
 ):
     """Run one decode step over all beam hypotheses and gather every
     possible continuation of each hypothesis.

     Builds a decoder batch with a temporary mask token appended to each
     hypothesis, decodes it against the (expanded) encoder memory, and
     delegates per-hypothesis expansion to ``get_hyp_continuations``.
     """
     device = next(self.parameters()).device
     # Each builder is copied before the mask token is appended, so the
     # temporary mask action stays local to this call and never leaks
     # into the persistent hypothesis state.
     decoder_batch = duo_rat_decoder_batch(
         items=[
             hyp.beam_builder.add_action_token(
                 action_token=ActionToken(
                     key=MaskAction(),
                     value=MaskAction(),
                     scope=AttentionScope(scope_name=AttentionScopeName.TARGET),
                 ),
                 copy=True,
             ).build(device=device)
             for hyp in beam_hypotheses
         ]
     )
     # Broadcast the single memory over the beam without copying data.
     expanded_memory = memory.expand(len(beam_hypotheses), -1, -1)
     output = self._decode(memory=expanded_memory, batch=decoder_batch)
     p_copy_gen_logprobs = self.copy_logprob(output)
     batch_size, seq_len = decoder_batch.target.shape
     # Sanity checks: one (copy, generate) log-prob pair per position,
     # and no NaNs from the decoder.
     assert p_copy_gen_logprobs.shape == (batch_size, seq_len, 2)
     assert not torch.isnan(p_copy_gen_logprobs).any()
     copy_logits = self.pointer_network(query=output, keys=expanded_memory)
     gen_logits = self.out_proj(output)
     # Collect the continuations of every hypothesis into one flat list.
     continuations = []
     for hyp_idx, hyp in enumerate(beam_hypotheses):
         # Only partially-parsed hypotheses can be continued.
         assert isinstance(hyp.beam_builder.parsing_result, Partial)
         continuations.extend(
             self.get_hyp_continuations(
                 decoder_batch=decoder_batch,
                 copy_logits=copy_logits,
                 gen_logits=gen_logits,
                 p_copy_gen_logprobs=p_copy_gen_logprobs,
                 hypothesis_id=hyp_idx,
                 hypothesis=hyp,
                 step=step,
                 question_position_map=question_position_map,
                 columns_position_map=columns_position_map,
                 tables_position_map=tables_position_map,
                 grammar_constrained_inference=grammar_constrained_inference,
             )
         )
     return continuations
Example #2
0
 def items_to_duo_rat_batch(
         self, preproc_items: List[RATPreprocItem]) -> DuoRATBatch:
     """Convert preprocessed items into a ``DuoRATBatch`` whose decoder
     side has had mask actions sampled in (for masked training)."""
     items = self.preproc_items_to_duorat_items(preproc_items)
     encoder_items = tuple(item.encoder_item for item in items)
     decoder_items = tuple(item.decoder_item for item in items)
     # Sample mask positions into the raw decoder batch; the mask value
     # is the vocabulary index of the MaskAction token.
     masked_decoder_batch = mask_duo_rat_decoder_batch(
         batch=duo_rat_decoder_batch(items=decoder_items),
         action_relation_types=self.target_relation_types,
         memory_relation_types=self.memory_relation_types,
         mask_sampling_config=self.mask_sampling_config,
         mask_value=self.preproc.target_vocab[MaskAction()],
     )
     return DuoRATBatch(
         encoder_batch=duo_rat_encoder_batch(items=encoder_items),
         decoder_batch=masked_decoder_batch,
     )
Example #3
0
    def compute_loss(self,
                     preproc_items: List[RATPreprocItem],
                     debug=False) -> torch.Tensor:
        """Run the full encoder/decoder forward pass on *preproc_items*
        and return the scalar mean loss over the batch."""
        items = self.preproc_items_to_duorat_items(preproc_items)
        encoder_batch = duo_rat_encoder_batch(
            items=tuple(item.encoder_item for item in items)
        )
        decoder_batch = duo_rat_decoder_batch(
            items=tuple(item.decoder_item for item in items)
        )
        memory, output = self.forward(
            batch=DuoRATBatch(
                encoder_batch=encoder_batch,
                decoder_batch=decoder_batch,
            )
        )
        # Fail fast on NaNs: they indicate a diverged model or bad batch.
        assert not torch.isnan(memory).any()
        assert not torch.isnan(output).any()
        per_item_loss = self._compute_loss(
            memory=memory,
            output=output,
            target_key_padding_mask=decoder_batch.target_key_padding_mask,
            valid_copy_mask=decoder_batch.valid_copy_mask,
            copy_target_mask=decoder_batch.copy_target_mask,
            valid_actions_mask=decoder_batch.valid_actions_mask,
            target=decoder_batch.target,
        )
        return per_item_loss.mean()