def run_batch(self, batch: Batch, max_output_length: int, beam_size: int,
              beam_alpha: float) -> (np.array, np.array):
    """
    Get outputs and attention scores for a given batch

    :param batch: batch to generate hypotheses for
    :param max_output_length: maximum length of hypotheses
    :param beam_size: size of the beam for beam search, if < 2 use greedy
    :param beam_alpha: alpha value for beam search
    :return: stacked_output: hypotheses for batch,
        stacked_attention_scores: attention scores for batch
    """
    encoder_output, encoder_hidden = self.encode(
        batch.src, batch.src_lengths, batch.src_mask, self.encoder)

    if self.encoder_2:
        encoder_output_2, encoder_hidden_2 = self.encode(
            src=batch.src_prev,
            src_length=batch.src_prev_lengths,
            src_mask=batch.src_prev_mask,
            encoder=self.encoder_2)
        x = self.last_layer(encoder_output, batch.src_mask,
                            encoder_output_2, batch.src_prev_mask)
        encoder_output, encoder_hidden = self.last_layer_norm(x), None

    # if maximum output length is not globally specified, adapt to src len
    if max_output_length is None:
        max_output_length = int(max(batch.src_lengths.cpu().numpy()) * 1.5)

    # greedy decoding
    if beam_size < 2:
        stacked_output, stacked_attention_scores = greedy(
            encoder_hidden=encoder_hidden,
            encoder_output=encoder_output,
            eos_index=self.eos_index,
            src_mask=batch.src_mask,
            embed=self.trg_embed,
            bos_index=self.bos_index,
            decoder=self.decoder,
            max_output_length=max_output_length)
        # batch, time, max_src_length
    else:
        # beam search
        stacked_output, stacked_attention_scores = beam_search(
            size=beam_size,
            encoder_output=encoder_output,
            encoder_hidden=encoder_hidden,
            src_mask=batch.src_mask,
            embed=self.trg_embed,
            max_output_length=max_output_length,
            alpha=beam_alpha,
            eos_index=self.eos_index,
            pad_index=self.pad_index,
            bos_index=self.bos_index,
            decoder=self.decoder)

    return stacked_output, stacked_attention_scores
def run_batch(self, batch: Batch, max_output_length: int, beam_size: int,
              beam_alpha: float, return_logp: bool = False) \
        -> (np.array, np.array, Optional[np.array]):
    """
    Get outputs and attention scores for a given batch

    :param batch: batch to generate hypotheses for
    :param max_output_length: maximum length of hypotheses
    :param beam_size: size of the beam for beam search, if 0 use greedy
    :param beam_alpha: alpha value for beam search
    :param return_logp: keep track of log probabilities as well
    :return:
        - stacked_output: hypotheses for batch,
        - stacked_attention_scores: attention scores for batch
        - log_probs: log probabilities for batch hypotheses
    """
    encoder_output, encoder_hidden = self.encode(
        batch.src, batch.src_lengths, batch.src_mask)

    # if maximum output length is not globally specified, adapt to src len
    if max_output_length is None:
        max_output_length = int(max(batch.src_lengths.cpu().numpy()) * 1.5)

    # greedy decoding
    if beam_size == 0:
        stacked_output, stacked_attention_scores, logprobs = greedy(
            encoder_hidden=encoder_hidden,
            encoder_output=encoder_output,
            src_mask=batch.src_mask,
            embed=self.trg_embed,
            bos_index=self.bos_index,
            decoder=self.decoder,
            max_output_length=max_output_length,
            eos_index=self.eos_index,
            return_logp=return_logp)
        # batch, time, max_src_length
    else:
        # beam size > 0
        stacked_output, stacked_attention_scores, logprobs = beam_search(
            size=beam_size,
            encoder_output=encoder_output,
            encoder_hidden=encoder_hidden,
            src_mask=batch.src_mask,
            embed=self.trg_embed,
            max_output_length=max_output_length,
            alpha=beam_alpha,
            eos_index=self.eos_index,
            pad_index=self.pad_index,
            bos_index=self.bos_index,
            decoder=self.decoder,
            return_logp=return_logp)

    return stacked_output, stacked_attention_scores, logprobs
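# Illustrative usage sketch for the variant above (an assumption, not code from
# this repository): iterate over batches, decode with beam search, and convert
# the returned index arrays back to token sequences with the target vocabulary.
# `model`, `data_iter`, and the hyperparameter values are hypothetical placeholders.
for batch in data_iter:
    output, attention_scores, log_probs = model.run_batch(
        batch=batch,
        max_output_length=None,   # fall back to 1.5 * max src length
        beam_size=5,              # 0 would select greedy decoding instead
        beam_alpha=1.0,
        return_logp=True)
    hypotheses = model.trg_vocab.arrays_to_sentences(output)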
def run_batch(self, batch: Batch, max_output_length: int, beam_size: int,
              beam_alpha: float):
    """
    Get outputs and attention scores for a given batch

    :param batch: batch to generate hypotheses for
    :param max_output_length: maximum length of hypotheses
    :param beam_size: size of the beam for beam search, if 0 use greedy
    :param beam_alpha: alpha value for beam search
    :return: stacked_output: hypotheses for batch,
        stacked_attention_scores: attention scores for batch
    """
    encoder_output, encoder_hidden = self.encode(
        batch.src, batch.src_lengths, batch.src_mask)

    # if maximum output length is not globally specified, adapt to src len
    if max_output_length is None:
        max_output_length = int(max(batch.src_lengths.cpu().numpy()) * 1.5)

    # greedy decoding
    if beam_size == 0:
        stacked_output, stacked_attention_scores = greedy(
            encoder_hidden=encoder_hidden,
            encoder_output=encoder_output,
            src_mask=batch.src_mask,
            embed=self.trg_embed,
            bos_index=self.bos_index,
            decoder=self.decoder,
            max_output_length=max_output_length)
        # batch, time, max_src_length
    else:
        # beam search
        stacked_output, stacked_attention_scores = beam_search(
            size=beam_size,
            encoder_output=encoder_output,
            encoder_hidden=encoder_hidden,
            src_mask=batch.src_mask,
            embed=self.trg_embed,
            max_output_length=max_output_length,
            alpha=beam_alpha,
            eos_index=self.eos_index,
            pad_index=self.pad_index,
            bos_index=self.bos_index,
            decoder=self.decoder)

    return stacked_output, stacked_attention_scores
def run_batch(self, batch: Batch, max_output_length: int, beam_size: int,
              beam_alpha: float) -> (np.array, np.array, np.array):
    """
    Get outputs and attention scores for a given batch

    :param batch: batch to generate hypotheses for
    :param max_output_length: maximum length of hypotheses
    :param beam_size: size of the beam for beam search, if 0 use greedy
    :param beam_alpha: alpha value for beam search
    :return: stacked_output: hypotheses for batch,
        stacked_attention_scores: attention scores for batch,
        stacked_kb_att_scores: KB attention scores for batch
    """
    encoder_output, encoder_hidden = self.encode(
        batch.src, batch.src_lengths, batch.src_mask)

    # if maximum output length is not globally specified, adapt to src len
    if max_output_length is None:
        max_output_length = int(max(batch.src_lengths.cpu().numpy()) * 1.5)

    if hasattr(batch, "kbsrc"):
        # B x KB x EMB; B x KB; B x KB
        kb_keys, kb_values, kb_values_embed, kb_trv, kb_mask = \
            self.preprocess_batch_kb(batch, kbattdims=self.kb_att_dims)
        if kb_keys is None:
            knowledgebase = None
        else:
            knowledgebase = (kb_keys, kb_values, kb_values_embed, kb_mask)
    else:
        knowledgebase = None

    # greedy decoding
    if beam_size == 0:
        stacked_output, stacked_attention_scores, stacked_kb_att_scores, _ = greedy(
            encoder_hidden=encoder_hidden,
            encoder_output=encoder_output,
            src_mask=batch.src_mask,
            embed=self.trg_embed,
            bos_index=self.bos_index,
            decoder=self.decoder,
            generator=self.generator,
            max_output_length=max_output_length,
            knowledgebase=knowledgebase)
        # batch, time, max_src_length
    else:
        # beam search
        stacked_output, stacked_attention_scores, stacked_kb_att_scores = \
            beam_search(
                decoder=self.decoder,
                generator=self.generator,
                size=beam_size,
                encoder_output=encoder_output,
                encoder_hidden=encoder_hidden,
                src_mask=batch.src_mask,
                embed=self.trg_embed,
                max_output_length=max_output_length,
                alpha=beam_alpha,
                eos_index=self.eos_index,
                pad_index=self.pad_index,
                bos_index=self.bos_index,
                knowledgebase=knowledgebase)

    if knowledgebase is not None and self.do_postproc:
        with self.Timer("postprocessing hypotheses"):
            # replace kb value tokens with actual values in hypotheses, e.g.
            # ['your','@event','is','at','@meeting_time']
            #   => ['your', 'conference', 'is', 'at', '7pm']
            # assert kb_values.shape[1] == 1, kb_values.shape
            stacked_output = self.postprocess_batch_hypotheses(
                stacked_output, stacked_kb_att_scores, kb_values, kb_trv)

        print(f"proc_batch: Hypotheses: "
              f"{self.trv_vocab.arrays_to_sentences(stacked_output)}")
    else:
        print(f"proc_batch: Hypotheses: "
              f"{self.trg_vocab.arrays_to_sentences(stacked_output)}")

    return stacked_output, stacked_attention_scores, stacked_kb_att_scores
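# Sketch of the post-processing idea used above: canonical KB placeholders in a
# hypothesis (e.g. '@meeting_time') are replaced by the surface form of the KB
# entry that received the most KB attention at that decoding step. This is an
# illustrative stand-in for `postprocess_batch_hypotheses`, not its actual
# implementation; the argument names and shapes are assumptions.
import numpy as np

def replace_canonical_tokens(hyp_tokens, kb_att_scores, kb_value_strings):
    """hyp_tokens: list of str tokens for one hypothesis
    kb_att_scores: np.array of shape (trg_len, kb_size), KB attention per step
    kb_value_strings: list of str, surface form of each KB value"""
    postprocessed = []
    for t, token in enumerate(hyp_tokens):
        if token.startswith("@") and len(kb_value_strings) > 0:
            # take the KB entry this decoding step attended to most
            postprocessed.append(
                kb_value_strings[int(np.argmax(kb_att_scores[t]))])
        else:
            postprocessed.append(token)
    return postprocessed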
def get_loss_for_batch(self, batch: Batch, loss_function: nn.Module,
                       max_output_length: int = None,
                       e_i: float = 1.,
                       greedy_threshold: float = 0.9) -> Tensor:
    """
    Compute non-normalized loss and number of tokens for a batch

    :param batch: batch to compute loss for
    :param loss_function: loss function, computes for input and target
        a scalar loss for the complete batch
    :param max_output_length: maximum length of hypotheses
    :param e_i: scheduled sampling probability of taking the true label
        instead of the model generation at every decoding step
        (https://arxiv.org/abs/1506.03099 Section 2.4)
    :param greedy_threshold: only actually do greedy search once e_i is
        below this threshold
    :return: batch_loss: sum of losses over non-pad elements in the batch
    """
    print(f"\n{'-'*10}GET LOSS FWD PASS: START current batch{'-'*10}\n")

    assert 0. <= e_i <= 1., f"e_i={e_i} should be a probability"
    # prefer to still do teacher forcing while e_i (the probability of
    # taking the true label) is high in scheduled sampling
    do_teacher_force = e_i >= greedy_threshold

    trg, trg_input, trg_mask = batch.trg, batch.trg_input, batch.trg_mask
    batch_size = trg.size(0)

    if hasattr(batch, "kbsrc"):
        kb_keys, kb_values, kb_values_embed, _, kb_mask = \
            self.preprocess_batch_kb(batch, kbattdims=self.kb_att_dims)
    else:
        kb_keys = None

    log_probs = None

    if kb_keys is not None:
        # kb task
        assert batch.kbsrc is not None, batch.kbsrc

        # FIXME hardcoded attribute name
        if hasattr(batch, "trgcanon"):
            # get loss on canonized target data during validation,
            # see joeynmt.prediction.validate_on_data
            # batch size sanity check
            assert batch.trgcanon.shape[0] == batch.trg.shape[0], \
                [t.shape for t in [batch.trg, batch.trgcanon]]
            # reassign these variables for loss calculation
            trg, trg_input, trg_mask = \
                batch.trgcanon, batch.trgcanon_input, batch.trgcanon_mask

        if not do_teacher_force:
            # scheduled sampling:
            # only use true labels with probability 0 <= e_i < 1;
            # otherwise take the previous model generation
            # => do a greedy search (autoregressive training as hinted at
            # in Eric et al.)
            with self.Timer("model training: KB Task: do greedy search"):
                encoder_output, encoder_hidden = self.encode(
                    batch.src, batch.src_lengths, batch.src_mask)

                # if maximum output length is not globally specified,
                # adapt to src len
                if max_output_length is None:
                    max_output_length = int(
                        max(batch.src_lengths.cpu().numpy()) * 1.5)

                print(f"in model.glfb; kb_keys are {kb_keys}")

                stacked_output, stacked_attention_scores, \
                    stacked_kb_att_scores, log_probs = greedy(
                        encoder_hidden=encoder_hidden,
                        encoder_output=encoder_output,
                        src_mask=batch.src_mask,
                        embed=self.trg_embed,
                        bos_index=self.bos_index,
                        decoder=self.decoder,
                        generator=self.generator,
                        max_output_length=trg.size(-1),
                        knowledgebase=(kb_keys, kb_values,
                                       kb_values_embed, kb_mask),
                        trg_input=trg_input,
                        e_i=e_i,
                    )
        else:
            # take true label at every step => just do fwd pass
            # (normal teacher forcing training)
            with self.Timer("model training: KB Task: model fwd pass"):
                hidden, att_probs, out, kb_probs, _, _ = self.forward(
                    src=batch.src,
                    trg_input=trg_input,
                    src_mask=batch.src_mask,
                    src_lengths=batch.src_lengths,
                    trg_mask=trg_mask,
                    kb_keys=kb_keys,
                    kb_mask=kb_mask,
                    kb_values_embed=kb_values_embed)
    else:
        # vanilla, not kb task
        if not do_teacher_force:
            raise NotImplementedError(
                "scheduled sampling only works for KB task atm")
        hidden, att_probs, out, kb_probs, _, _ = self.forward(
            src=batch.src,
            trg_input=trg_input,
            src_mask=batch.src_mask,
            src_lengths=batch.src_lengths,
            trg_mask=trg_mask)
        kb_values = None

    if log_probs is None:
        # same generator fwd pass for KB task and no KB task if teacher forcing:
        # pass output through Generator and add biases for KB entries
        # in vocab indexes of kb values
        log_probs = self.generator(out, kb_probs=kb_probs, kb_values=kb_values)

    if hasattr(batch, "trgcanon"):
        # only calculate loss on this field of the batch during
        # validation loss calculation
        assert not log_probs.requires_grad, \
            "using batch.trgcanon shouldn't happen / be done during " \
            "training (canonized data is used in the 'trg' field there)"

    # check number of classes equals prediction distribution support
    # (can otherwise lead to nasty CUDA device side asserts that don't
    # give a traceback to here)
    assert log_probs.size(-1) == self.generator.output_size, \
        (log_probs.shape, self.generator.output_size)

    # compute batch loss
    try:
        batch_loss = loss_function(log_probs, trg)
    except Exception as e:
        print(f"batch_size: {batch_size}")
        print(f"log_probs= {log_probs.shape}")
        print(f"trg = {trg.shape}")
        raise e

    # confirm trg is actually canonical:
    # input(f"loss is calculated on these sequences: "
    #       f"{self.trv_vocab.arrays_to_sentences(trg.cpu().numpy())}")

    with self.Timer("debugging: greedy hypothesis:"):
        mle_tokens = argmax(log_probs, dim=-1)  # torch argmax
        mle_tokens = mle_tokens.cpu().numpy()
        print(f"proc_batch: Hypothesis: "
              f"{self.trg_vocab.arrays_to_sentences(mle_tokens)[-1]}")

    print(f"\n{'-'*10}GET LOSS FWD PASS: END current batch{'-'*10}\n")

    # batch loss = sum xent over all elements in batch that are not pad
    return batch_loss
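# The `e_i` argument above is the scheduled sampling probability of feeding the
# gold token at each decoding step (Bengio et al. 2015,
# https://arxiv.org/abs/1506.03099, Section 2.4). Below is a minimal sketch of
# one decay schedule proposed in that paper; `inverse_sigmoid_decay` is an
# illustrative helper, not part of this model, and k is an assumed hyperparameter.
import math

def inverse_sigmoid_decay(step: int, k: float = 1000.0) -> float:
    """e_i = k / (k + exp(step / k)) with k >= 1: starts near 1.0
    (mostly teacher forcing) and decays towards 0.0 (mostly feeding
    back the model's own greedy predictions)."""
    return k / (k + math.exp(step / k))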