def get_continuations(
    self,
    beam_hypotheses: List[DuoRATHypothesis],
    step: int,
    memory: torch.Tensor,
    question_position_map: Dict[Any, Deque[Any]],
    columns_position_map: Dict[Any, Deque[Any]],
    tables_position_map: Dict[Any, Deque[Any]],
    grammar_constrained_inference: bool,
):
    """Score one decoding step for every beam hypothesis and collect all
    candidate continuations across the whole beam."""
    device = next(self.parameters()).device
    # Copy each builder before appending the mask action token so that the
    # temporary mask token stays confined to this step and does not leak
    # into the hypotheses themselves.
    masked_items = [
        hyp.beam_builder.add_action_token(
            action_token=ActionToken(
                key=MaskAction(),
                value=MaskAction(),
                scope=AttentionScope(scope_name=AttentionScopeName.TARGET),
            ),
            copy=True,
        ).build(device=device)
        for hyp in beam_hypotheses
    ]
    decoder_batch = duo_rat_decoder_batch(items=masked_items)
    expanded_memory = memory.expand(len(beam_hypotheses), -1, -1)
    output = self._decode(memory=expanded_memory, batch=decoder_batch)

    p_copy_gen_logprobs = self.copy_logprob(output)
    (batch_size, seq_len) = decoder_batch.target.shape
    assert p_copy_gen_logprobs.shape == (batch_size, seq_len, 2)
    assert not torch.isnan(p_copy_gen_logprobs).any()

    copy_logits = self.pointer_network(query=output, keys=expanded_memory)
    gen_logits = self.out_proj(output)

    # For each hypothesis, gather every possible continuation.
    continuations = []
    for hyp_idx, hyp in enumerate(beam_hypotheses):
        assert isinstance(hyp.beam_builder.parsing_result, Partial)
        continuations.extend(
            self.get_hyp_continuations(
                decoder_batch=decoder_batch,
                copy_logits=copy_logits,
                gen_logits=gen_logits,
                p_copy_gen_logprobs=p_copy_gen_logprobs,
                hypothesis_id=hyp_idx,
                hypothesis=hyp,
                step=step,
                question_position_map=question_position_map,
                columns_position_map=columns_position_map,
                tables_position_map=tables_position_map,
                grammar_constrained_inference=grammar_constrained_inference,
            )
        )
    return continuations
def items_to_duo_rat_batch(
    self, preproc_items: List[RATPreprocItem]
) -> DuoRATBatch:
    """Build a DuoRATBatch from preprocessed items, applying mask sampling
    to the decoder targets."""
    items = self.preproc_items_to_duorat_items(preproc_items)
    encoder_batch = duo_rat_encoder_batch(
        items=tuple(item.encoder_item for item in items)
    )
    # Sample and apply mask actions over the raw decoder batch.
    masked_decoder_batch = mask_duo_rat_decoder_batch(
        batch=duo_rat_decoder_batch(
            items=tuple(item.decoder_item for item in items)
        ),
        action_relation_types=self.target_relation_types,
        memory_relation_types=self.memory_relation_types,
        mask_sampling_config=self.mask_sampling_config,
        mask_value=self.preproc.target_vocab[MaskAction()],
    )
    return DuoRATBatch(
        encoder_batch=encoder_batch,
        decoder_batch=masked_decoder_batch,
    )
def compute_loss(
    self, preproc_items: List[RATPreprocItem], debug=False
) -> torch.Tensor:
    """Return the mean loss over a batch of preprocessed items.

    NOTE(review): ``debug`` is accepted for interface compatibility but is
    not read anywhere in this body — confirm whether callers rely on it.
    """
    items = self.preproc_items_to_duorat_items(preproc_items)
    decoder_batch = duo_rat_decoder_batch(
        items=tuple(item.decoder_item for item in items)
    )
    batch = DuoRATBatch(
        encoder_batch=duo_rat_encoder_batch(
            items=tuple(item.encoder_item for item in items)
        ),
        decoder_batch=decoder_batch,
    )
    memory, output = self.forward(batch=batch)
    # Fail fast on numerical blow-ups before they propagate into the loss.
    assert not torch.isnan(memory).any()
    assert not torch.isnan(output).any()
    per_item_loss = self._compute_loss(
        memory=memory,
        output=output,
        target_key_padding_mask=decoder_batch.target_key_padding_mask,
        valid_copy_mask=decoder_batch.valid_copy_mask,
        copy_target_mask=decoder_batch.copy_target_mask,
        valid_actions_mask=decoder_batch.valid_actions_mask,
        target=decoder_batch.target,
    )
    return per_item_loss.mean()