# Example #1
# 0
 def items_to_duo_rat_batch(
         self, preproc_items: List[RATPreprocItem]) -> DuoRATBatch:
     """Build a full DuoRAT batch (encoder + masked decoder) from preprocessed items.

     The decoder half is passed through mask_duo_rat_decoder_batch so that a
     sampled subset of target positions is replaced by the MaskAction token.
     """
     duorat_items = self.preproc_items_to_duorat_items(preproc_items)
     encoder_batch = duo_rat_encoder_batch(
         items=tuple(it.encoder_item for it in duorat_items))
     raw_decoder_batch = duo_rat_decoder_batch(
         items=tuple(it.decoder_item for it in duorat_items))
     # Mask sampling is configured on the model; the mask value is the
     # vocabulary id of MaskAction in the target vocabulary.
     masked_decoder_batch = mask_duo_rat_decoder_batch(
         batch=raw_decoder_batch,
         action_relation_types=self.target_relation_types,
         memory_relation_types=self.memory_relation_types,
         mask_sampling_config=self.mask_sampling_config,
         mask_value=self.preproc.target_vocab[MaskAction()],
     )
     return DuoRATBatch(
         encoder_batch=encoder_batch,
         decoder_batch=masked_decoder_batch,
     )
# Example #2
# 0
 def parse(
     self,
     preproc_items: List[RATPreprocItem],
     decode_max_time_step: int,
     beam_size: int,
 ) -> List[FinishedBeam]:
     """Decode beams for exactly one preprocessed item.

     Encodes the single item, then delegates beam-search decoding to
     parse_decode. Without grammar-constrained inference only greedy
     decoding (beam_size == 1) is supported.
     """
     assert len(preproc_items) == 1
     # Equivalent to: if not grammar-constrained, require beam_size == 1.
     assert self.grammar_constrained_inference or beam_size == 1
     item = preproc_items[0]
     enc_item, builder = self._get_encoder_item(preproc_item=item)
     memory = self._encode(
         batch=duo_rat_encoder_batch(items=[enc_item]))
     return self.parse_decode(
         encoder_item_builder=builder,
         memory=memory,
         beam_size=beam_size,
         decode_max_time_step=decode_max_time_step,
         grammar_constrained_inference=self.grammar_constrained_inference,
     )
# Example #3
# 0
    def compute_loss(self,
                     preproc_items: List[RATPreprocItem],
                     debug=False) -> torch.Tensor:
        """Return the mean training loss over a batch of preprocessed items.

        Runs the full forward pass (encoder + decoder) and averages the
        per-position loss produced by _compute_loss.

        NOTE(review): `debug` is not used in this body — presumably consumed
        by an override or kept for interface compatibility; confirm.
        """
        duorat_items = self.preproc_items_to_duorat_items(preproc_items)
        enc_batch = duo_rat_encoder_batch(
            items=tuple(it.encoder_item for it in duorat_items))
        dec_batch = duo_rat_decoder_batch(
            items=tuple(it.decoder_item for it in duorat_items))
        memory, output = self.forward(batch=DuoRATBatch(
            encoder_batch=enc_batch,
            decoder_batch=dec_batch,
        ))
        # Fail fast on numerical blow-ups before computing the loss.
        assert not torch.isnan(memory).any()
        assert not torch.isnan(output).any()
        per_position_loss = self._compute_loss(
            memory=memory,
            output=output,
            target_key_padding_mask=dec_batch.target_key_padding_mask,
            valid_copy_mask=dec_batch.valid_copy_mask,
            copy_target_mask=dec_batch.copy_target_mask,
            valid_actions_mask=dec_batch.valid_actions_mask,
            target=dec_batch.target,
        )
        return per_position_loss.mean()