Example #1
    def run_batch(self, batch: Batch, max_output_length: int, beam_size: int,
                  beam_alpha: float) -> Tuple[np.ndarray, np.ndarray]:
        """
        Get outputs and attention scores for a given batch

        :param batch: batch to generate hypotheses for
        :param max_output_length: maximum length of hypotheses
        :param beam_size: size of the beam for beam search; if < 2, greedy decoding is used
        :param beam_alpha: alpha value for beam search
        :return: stacked_output: hypotheses for batch,
            stacked_attention_scores: attention scores for batch
        """
        encoder_output, encoder_hidden = self.encode(batch.src,
                                                     batch.src_lengths,
                                                     batch.src_mask,
                                                     self.encoder)
        if self.encoder_2:
            # a second encoder encodes the previous source segment (batch.src_prev);
            # its output is fused with the current encoder output via self.last_layer
            encoder_output_2, encoder_hidden_2 = self.encode(
                src=batch.src_prev,
                src_length=batch.src_prev_lengths,
                src_mask=batch.src_prev_mask,
                encoder=self.encoder_2)
            x = self.last_layer(encoder_output, batch.src_mask,
                                encoder_output_2, batch.src_prev_mask)

            encoder_output, encoder_hidden = self.last_layer_norm(x), None

        # if maximum output length is not globally specified, adapt to src len
        if max_output_length is None:
            max_output_length = int(max(batch.src_lengths.cpu().numpy()) * 1.5)

        # greedy decoding
        if beam_size < 2:
            stacked_output, stacked_attention_scores = greedy(
                encoder_hidden=encoder_hidden,
                encoder_output=encoder_output,
                eos_index=self.eos_index,
                src_mask=batch.src_mask,
                embed=self.trg_embed,
                bos_index=self.bos_index,
                decoder=self.decoder,
                max_output_length=max_output_length)
            # batch, time, max_src_length
        else:  # beam size
            stacked_output, stacked_attention_scores = \
                    beam_search(
                        size=beam_size, encoder_output=encoder_output,
                        encoder_hidden=encoder_hidden,
                        src_mask=batch.src_mask, embed=self.trg_embed,
                        max_output_length=max_output_length,
                        alpha=beam_alpha, eos_index=self.eos_index,
                        pad_index=self.pad_index,
                        bos_index=self.bos_index,
                        decoder=self.decoder)

        return stacked_output, stacked_attention_scores
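A minimal call-site sketch (hypothetical: `model`, `batch`, and `model.trg_vocab` are assumptions, not shown in the listing above). The returned arrays hold token ids, which still have to be cut at the end-of-sequence token and mapped back to strings, e.g. with the vocabulary's arrays_to_sentences as in the later examples:

import torch

# hypothetical call site; `model` is an instance of the class above, `batch` a Batch object
model.eval()
with torch.no_grad():
    output_ids, attention_scores = model.run_batch(
        batch=batch,
        max_output_length=None,   # None falls back to 1.5 * max source length
        beam_size=5,              # a value < 2 would trigger greedy decoding instead
        beam_alpha=1.0)           # length-penalty exponent used by beam search
hypotheses = model.trg_vocab.arrays_to_sentences(output_ids)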
Example #2
    def run_batch(self, batch: Batch, max_output_length: int, beam_size: int,
                  beam_alpha: float, return_logp: bool = False) \
            -> Tuple[np.ndarray, np.ndarray, Optional[np.ndarray]]:
        """
        Get outputs and attention scores for a given batch

        :param batch: batch to generate hypotheses for
        :param max_output_length: maximum length of hypotheses
        :param beam_size: size of the beam for beam search, if 0 use greedy
        :param beam_alpha: alpha value for beam search
        :param return_logp: keep track of log probabilities as well
        :return:
            - stacked_output: hypotheses for batch,
            - stacked_attention_scores: attention scores for batch
            - log_probs: log probabilities for batch hypotheses
        """
        encoder_output, encoder_hidden = self.encode(batch.src,
                                                     batch.src_lengths,
                                                     batch.src_mask)

        # if maximum output length is not globally specified, adapt to src len
        if max_output_length is None:
            max_output_length = int(max(batch.src_lengths.cpu().numpy()) * 1.5)

        # greedy decoding
        if beam_size == 0:
            stacked_output, stacked_attention_scores, logprobs = greedy(
                encoder_hidden=encoder_hidden,
                encoder_output=encoder_output,
                src_mask=batch.src_mask,
                embed=self.trg_embed,
                bos_index=self.bos_index,
                decoder=self.decoder,
                max_output_length=max_output_length,
                eos_index=self.eos_index,
                return_logp=return_logp)
            # batch, time, max_src_length
        else:  # beam size > 0
            stacked_output, stacked_attention_scores, logprobs = \
                beam_search(size=beam_size, encoder_output=encoder_output,
                            encoder_hidden=encoder_hidden,
                            src_mask=batch.src_mask, embed=self.trg_embed,
                            max_output_length=max_output_length,
                            alpha=beam_alpha, eos_index=self.eos_index,
                            pad_index=self.pad_index, bos_index=self.bos_index,
                            decoder=self.decoder, return_logp=return_logp)

        return stacked_output, stacked_attention_scores, logprobs
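These listings pass beam_alpha into beam_search but do not show how it is applied; one common choice (an assumption here, not confirmed by the listing) is the Google-NMT-style length penalty, where a finished hypothesis with log probability `score` is ranked by `score` divided by the penalty below:

def length_penalty(length: int, alpha: float) -> float:
    # Google-NMT-style length penalty; larger alpha favours longer hypotheses
    return ((5.0 + length) / 6.0) ** alpha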
Example #3
    def run_batch(self, batch: Batch, max_output_length: int, beam_size: int,
                  beam_alpha: float):
        """
        Get outputs and attention scores for a given batch

        :param batch: batch to generate hypotheses for
        :param max_output_length: maximum length of hypotheses
        :param beam_size: size of the beam for beam search, if 0 use greedy
        :param beam_alpha: alpha value for beam search
        :return: stacked_output: hypotheses for batch,
            stacked_attention_scores: attention scores for batch
        """
        encoder_output, encoder_hidden = self.encode(batch.src,
                                                     batch.src_lengths,
                                                     batch.src_mask)

        # if maximum output length is not globally specified, adapt to src len
        if max_output_length is None:
            max_output_length = int(max(batch.src_lengths.cpu().numpy()) * 1.5)

        # greedy decoding
        if beam_size == 0:
            stacked_output, stacked_attention_scores = greedy(
                encoder_hidden=encoder_hidden,
                encoder_output=encoder_output,
                src_mask=batch.src_mask,
                embed=self.trg_embed,
                bos_index=self.bos_index,
                decoder=self.decoder,
                max_output_length=max_output_length)
            # batch, time, max_src_length
        else:  # beam size
            stacked_output, stacked_attention_scores = \
                beam_search(size=beam_size, encoder_output=encoder_output,
                            encoder_hidden=encoder_hidden,
                            src_mask=batch.src_mask, embed=self.trg_embed,
                            max_output_length=max_output_length,
                            alpha=beam_alpha, eos_index=self.eos_index,
                            pad_index=self.pad_index, bos_index=self.bos_index,
                            decoder=self.decoder)

        return stacked_output, stacked_attention_scores
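The greedy() helper called above is not part of these listings; the following is a minimal, self-contained sketch of greedy decoding under a hypothetical decoder_step interface (one call returns next-token log probabilities for the whole batch), ignoring attention scores:

import numpy as np

def greedy_sketch(decoder_step, bos_index: int, eos_index: int,
                  max_output_length: int, batch_size: int) -> np.ndarray:
    # decoder_step(prefix) -> array of shape (batch, vocab) with next-token log probabilities
    ys = np.full((batch_size, 1), bos_index, dtype=np.int64)
    finished = np.zeros(batch_size, dtype=bool)
    for _ in range(max_output_length):
        log_probs = decoder_step(ys)             # one decoding step on the current prefix
        next_tokens = log_probs.argmax(-1)       # pick the most likely token per sentence
        ys = np.concatenate([ys, next_tokens[:, None]], axis=1)
        finished |= next_tokens == eos_index     # mark sentences that produced </s>
        if finished.all():
            break
    return ys[:, 1:]                             # drop the BOS column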
Example #4
    def run_batch(self, batch: Batch, max_output_length: int, beam_size: int,
                  beam_alpha: float) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
        """
        Get outputs and attention scores for a given batch

        :param batch: batch to generate hypotheses for
        :param max_output_length: maximum length of hypotheses
        :param beam_size: size of the beam for beam search, if 0 use greedy
        :param beam_alpha: alpha value for beam search
        :return: 
            stacked_output: hypotheses for batch,
            stacked_attention_scores: attention scores for batch
        """

        encoder_output, encoder_hidden = self.encode(batch.src,
                                                     batch.src_lengths,
                                                     batch.src_mask)

        # if maximum output length is not globally specified, adapt to src len
        if max_output_length is None:
            max_output_length = int(max(batch.src_lengths.cpu().numpy()) * 1.5)

        if hasattr(batch, "kbsrc"):
            # B x KB x EMB; B x KB; B x KB
            kb_keys, kb_values, kb_values_embed, kb_trv, kb_mask = self.preprocess_batch_kb(
                batch, kbattdims=self.kb_att_dims)
            if kb_keys is None:
                knowledgebase = None
            else:
                knowledgebase = (kb_keys, kb_values, kb_values_embed, kb_mask)
        else:
            knowledgebase = None

        # greedy decoding
        if beam_size == 0:
            stacked_output, stacked_attention_scores, stacked_kb_att_scores, _ = greedy(
                encoder_hidden=encoder_hidden,
                encoder_output=encoder_output,
                src_mask=batch.src_mask,
                embed=self.trg_embed,
                bos_index=self.bos_index,
                decoder=self.decoder,
                generator=self.generator,
                max_output_length=max_output_length,
                knowledgebase=knowledgebase)
            # batch, time, max_src_length
        else:  # beam size
            stacked_output, stacked_attention_scores, stacked_kb_att_scores = \
                    beam_search(
                        decoder=self.decoder,
                        generator=self.generator,
                        size=beam_size, encoder_output=encoder_output,
                        encoder_hidden=encoder_hidden,
                        src_mask=batch.src_mask, embed=self.trg_embed,
                        max_output_length=max_output_length,
                        alpha=beam_alpha, eos_index=self.eos_index,
                        pad_index=self.pad_index,
                        bos_index=self.bos_index,
                        knowledgebase=knowledgebase)

        if knowledgebase is not None and self.do_postproc:
            with self.Timer("postprocessing hypotheses"):
                # replace kb value tokens with actual values in hypotheses, e.g.
                # ['your','@event','is','at','@meeting_time'] => ['your', 'conference', 'is', 'at', '7pm']
                # assert kb_values.shape[1] == 1, kb_values.shape
                stacked_output = self.postprocess_batch_hypotheses(
                    stacked_output, stacked_kb_att_scores, kb_values, kb_trv)

            print(
                f"proc_batch: Hypotheses: {self.trv_vocab.arrays_to_sentences(stacked_output)}"
            )
        else:
            print(
                f"proc_batch: Hypotheses: {self.trg_vocab.arrays_to_sentences(stacked_output)}"
            )

        return stacked_output, stacked_attention_scores, stacked_kb_att_scores
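postprocess_batch_hypotheses is not shown in this listing; a rough sketch of the substitution described in the comment above (replacing canonical tokens such as '@meeting_time' with the attended true value), assuming a hypothetical data layout of token strings, a (trg_len, kb_size) KB attention matrix, and a list of KB surface values:

def substitute_kb_placeholders(hypothesis, kb_att_scores, kb_true_values):
    # hypothesis: list of decoded token strings
    # kb_att_scores: array of shape (trg_len, kb_size), KB attention per decoding step
    # kb_true_values: surface string for each knowledgebase entry
    out = []
    for t, token in enumerate(hypothesis):
        if token.startswith("@"):                        # canonical KB placeholder token
            best_entry = int(kb_att_scores[t].argmax())  # most attended KB entry at this step
            out.append(kb_true_values[best_entry])
        else:
            out.append(token)
    return out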
Example #5
    def get_loss_for_batch(self,
                           batch: Batch,
                           loss_function: nn.Module,
                           max_output_length: int = None,
                           e_i: float = 1.,
                           greedy_threshold: float = 0.9) -> Tensor:
        """
        Compute the non-normalized loss for a batch

        :param batch: batch to compute loss for
        :param loss_function: loss function, computes for input and target
            a scalar loss for the complete batch
        :param max_output_length: maximum length of hypotheses
        :param e_i: scheduled sampling probability of taking the true label vs.
            the model generation at every decoding step
            (https://arxiv.org/abs/1506.03099, Section 2.4)
        :param greedy_threshold: only actually do greedy search once e_i is below this threshold
        :return: batch_loss: sum of losses over non-pad elements in the batch
        """

        print(f"\n{'-'*10}GET LOSS FWD PASS: START current batch{'-'*10}\n")

        assert 0. <= e_i <= 1., f"e_i={e_i} should be a probability"
        # prefer to still do teacher forcing while e_i (the probability of taking the true label) is high
        do_teacher_force = e_i >= greedy_threshold

        trg, trg_input, trg_mask = batch.trg, batch.trg_input, batch.trg_mask
        batch_size = trg.size(0)

        if hasattr(batch, "kbsrc"):
            kb_keys, kb_values, kb_values_embed, _, kb_mask = self.preprocess_batch_kb(
                batch, kbattdims=self.kb_att_dims)
        else:
            kb_keys = None

        log_probs = None

        if kb_keys is not None:  # kb task
            assert batch.kbsrc is not None, batch.kbsrc

            # FIXME hardcoded attribute name
            if hasattr(batch, "trgcanon"):
                # get loss on canonized target data during validation, see joeynmt.prediction.validate_on_data
                # batch size sanity check
                assert batch.trgcanon.shape[0] == batch.trg.shape[0], [
                    t.shape for t in [batch.trg, batch.trgcanon]
                ]
                # reassign these variables for loss calculation
                trg, trg_input, trg_mask = batch.trgcanon, batch.trgcanon_input, batch.trgcanon_mask

            if not do_teacher_force:  # scheduled sampling
                # only use true labels with probability 0 <= e_i < 1; otherwise take previous model generation;
                # => do a greedy search (autoregressive training as hinted at in Eric et al)
                with self.Timer("model training: KB Task: do greedy search"):

                    encoder_output, encoder_hidden = self.encode(
                        batch.src, batch.src_lengths, batch.src_mask)

                    # if maximum output length is not globally specified, adapt to src len
                    if max_output_length is None:
                        max_output_length = int(
                            max(batch.src_lengths.cpu().numpy()) * 1.5)

                    print(f"in model.glfb; kb_keys are {kb_keys}")
                    stacked_output, stacked_attention_scores, stacked_kb_att_scores, log_probs = greedy(
                        encoder_hidden=encoder_hidden,
                        encoder_output=encoder_output,
                        src_mask=batch.src_mask,
                        embed=self.trg_embed,
                        bos_index=self.bos_index,
                        decoder=self.decoder,
                        generator=self.generator,
                        max_output_length=trg.size(-1),
                        knowledgebase=(kb_keys, kb_values, kb_values_embed,
                                       kb_mask),
                        trg_input=trg_input,
                        e_i=e_i,
                    )
            else:  # take true label at every step => just do fwd pass (normal teacher forcing training)
                with self.Timer("model training: KB Task: model fwd pass"):

                    hidden, att_probs, out, kb_probs, _, _ = self.forward(
                        src=batch.src,
                        trg_input=trg_input,
                        src_mask=batch.src_mask,
                        src_lengths=batch.src_lengths,
                        trg_mask=trg_mask,
                        kb_keys=kb_keys,
                        kb_mask=kb_mask,
                        kb_values_embed=kb_values_embed)

        else:  # vanilla, not kb task
            if not do_teacher_force:
                raise NotImplementedError(
                    "scheduled sampling only works for KB task atm")

            hidden, att_probs, out, kb_probs, _, _ = self.forward(
                src=batch.src,
                trg_input=trg_input,
                src_mask=batch.src_mask,
                src_lengths=batch.src_lengths,
                trg_mask=trg_mask)
            kb_values = None

        if log_probs is None:
            # same generator fwd pass for KB task and no KB task if teacher forcing
            # pass output through Generator and add biases for KB entries in vocab indexes of kb values
            log_probs = self.generator(out,
                                       kb_probs=kb_probs,
                                       kb_values=kb_values)

        if hasattr(batch, "trgcanon"):
            # only calculate loss on this field of the batch during validation loss calculation
            assert not log_probs.requires_grad, \
                "using batch.trgcanon shouldn't happen / be done during training (canonized data is used in the 'trg' field there)"

        # check that the number of classes equals the prediction distribution support
        # (a mismatch can otherwise lead to nasty CUDA device-side asserts that don't give a traceback to here)
        assert log_probs.size(-1) == self.generator.output_size, (
            log_probs.shape, self.generator.output_size)

        # compute batch loss
        try:
            batch_loss = loss_function(log_probs, trg)
        except Exception as e:
            print(f"batch_size: {batch_size}")
            print(f"log_probs = {log_probs.shape}")
            print(f"trg = {trg.shape}")
            raise e

        # confirm trg is actually canonical:
        # input(f"loss is calculated on these sequences: {self.trv_vocab.arrays_to_sentences(trg.cpu().numpy())}")

        with self.Timer("debugging: greedy hypothesis:"):
            mle_tokens = argmax(log_probs, dim=-1)  # torch argmax
            mle_tokens = mle_tokens.cpu().numpy()

            print(
                f"proc_batch: Hypothesis: {self.trg_vocab.arrays_to_sentences(mle_tokens)[-1]}"
            )

        print(f"\n{'-'*10}GET LOSS FWD PASS: END current batch{'-'*10}\n")

        # batch loss = sum xent over all elements in batch that are not pad
        return batch_loss
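The per-step scheduled-sampling choice controlled by e_i happens inside greedy() and is not shown above; a minimal sketch of the decision (hypothetical tensors of gold and predicted token ids), following Bengio et al., 2015, the paper linked in the docstring:

import torch

def choose_next_input(gold_tokens: torch.Tensor, model_tokens: torch.Tensor,
                      e_i: float) -> torch.Tensor:
    # with probability e_i feed the gold token at the next step,
    # otherwise feed the model's own prediction
    take_gold = torch.rand(gold_tokens.shape, device=gold_tokens.device) < e_i
    return torch.where(take_gold, gold_tokens, model_tokens)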