def get_translator(self, model):
    """Get lazy singleton translator instance."""
    if self.translator is None:
        # Shallow-copy args so overriding the beam size does not mutate the shared config
        args_clone = copy.copy(self.args)
        if self.args.loss_beam:  # Override beam size if necessary
            args_clone.beam = self.args.loss_beam
        self.translator = generate.build_sequence_generator(args_clone, [model])
    return self.translator
Example #2
    def forward(self, model, sample, reduce=True):
        """Compute the loss for the given sample.

        Returns a tuple with three elements:
        1) the loss
        2) the sample size, which is used as the denominator for the gradient
        3) logging outputs to display while training
        """
        src_tokens = sample["net_input"]["src_tokens"]
        beam_size = self.args.rl_num_trajectory
        bsz, srclen = src_tokens.size()
        encoder_input = {
            "src_tokens": sample["net_input"]["src_tokens"],
            "src_lengths": sample["net_input"]["src_lengths"],
        }

        # 1) Generate hypos
        translator = generate.build_sequence_generator(self.args, self.task, [model])
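        # beam search runs under no_grad: it only produces trajectories to score;
        # gradients come from re-scoring the hypotheses with the model below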
        with torch.no_grad():
            seq_hypos = translator.generate(
                encoder_input,
                beam_size,
                maxlen=int(self.args.max_len_a * srclen + self.args.max_len_b),
            )

        # wrap each reference (target) sentence as a single word-level hypothesis
        word_hypos = [[{"tokens": sample["target"][k]}] for k in range(bsz)]

        ## Mix sequence- and word-level hypos, then flatten the nested list
        hypos = [seq_hypos[j] + word_hypos[j] for j in range(bsz)]
        hypos = [hypo for hypo_list in hypos for hypo in hypo_list]
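        # each sentence now contributes beam_size generated hypotheses followed by
        # one reference, so hypos has length bsz * (beam_size + 1)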
        hypos_len = (
            torch.tensor([len(hypo["tokens"]) for hypo in hypos])
            .type_as(src_tokens)
            .float()
        )
        # flat indices of the word-level hypos (i.e., the appended target sentences)
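        # the reference of sentence i sits at flat index i * (beam_size + 1) + beam_size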
        mask_index = torch.arange(beam_size, (beam_size + 1) * bsz, beam_size + 1).view(-1)

        # 2) Compute (log)-probs via forward models
        self.self_rescorer.model = model
        self.self_rescorer.task = self.task
        model.train()
        assert self.self_rescorer.model.training, "model should be in training phase"

        hypo_encoder_inputs, hypo_tokens = self.self_rescorer.prepare_inputs(
            src_tokens, hypos
        )
        hypo_logprobs, hypo_encoder_outs, forward_logprobs = self.self_rescorer.score_tokens(
            hypo_encoder_inputs, hypo_tokens
        )
        hypo_logprobs /= hypos_len ** self.args.rescore_length_penalty

        # 3) Sequence level
        seq_loss = torch.zeros(1).type_as(hypo_logprobs)
        if self.args.rl_weight > 0.0:
            ## 3.1) Compute seq-level rewards
            with torch.no_grad():
                rescorer = Rescorer(self.args, self.task, self.rescore_models)
                scores = rescorer.score(src_tokens, hypos)
                rewards = self.combine_score(src_tokens, hypos, hypos_len, scores)
            assert not rewards.requires_grad, "no grads flow back to generation"
            ## 3.2) Compute Policy Gradient loss
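            # REINFORCE-style term: weight each hypothesis' log-probability by its
            # detached reward; the appended references are masked out of this sum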
            rewards = rewards.type_as(hypo_logprobs)
            seq_mask = hypo_logprobs.new_ones(hypo_logprobs.size())
            seq_mask[mask_index] = 0.0
            seq_loss = -1.0 * (seq_mask * hypo_logprobs * rewards).sum()

        # 4) Word-level
        word_loss = torch.zeros(1).type_as(hypo_logprobs)
        if self.args.word_weight > 0.0:
            ## 4.1) Compute word-level rewards from a left-right rescoring model
            with torch.no_grad():
                teacher_model = self.rescore_models[self.args.word_model]
                teacher = SimpleModelScorer(self.args, None, teacher_model, self.task)
                _, _, teacher_logprobs = teacher.score_tokens(
                    hypo_encoder_inputs, hypo_tokens
                )
            ## 4.2) Compute word-level loss
            f_logprob, f_index = forward_logprobs.topk(self.args.topk_words)
            word_mask = f_logprob.new_zeros(f_logprob.size())
            word_mask[mask_index, :, :] = 1.0
            ## KL(p_s || p_t) = \sum p_s log p_s - \sum p_s log p_t, aka RL + maxEnt
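            # only the rows selected by mask_index (the teacher-forced references)
            # contribute, and the sum is truncated to the student's top-k tokens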
            word_loss = (
                word_mask
                * f_logprob.exp()
                * (f_logprob - 1.0 * teacher_logprobs.gather(-1, f_index))
            ).sum()

        # 5) Compute Cross-entropy loss
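        # i.e., negative log-likelihood of the reference under teacher forcing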
        eos = self.task.target_dictionary.eos()
        target_tokens = torch.cat(
            (
                torch.zeros(bsz, 1).fill_(eos).type_as(sample["target"]),
                sample["target"],
            ),
            dim=1,
        )
        target_encoder_inputs = (
            encoder_input["src_tokens"],
            [encoder_input["src_lengths"][0].item()],
        )
        target_logprobs, target_encoder_out, _ = self.self_rescorer.score_tokens(
            target_encoder_inputs, target_tokens
        )
        nll_loss = -1.0 * target_logprobs.sum()

        # 6) Gather losses
        loss = (
            self.args.rl_weight * seq_loss
            + self.args.word_weight * word_loss
            + nll_loss
        )

        # Logging
        sample_size = (
            sample["target"].size(0) if self.args.sentence_avg else sample["ntokens"]
        )
        logging_output = {
            "loss": utils.item(loss.data) if reduce else loss.data,
            "nll_loss": utils.item(nll_loss.data) if reduce else nll_loss.data,
            "ntokens": sample["ntokens"],
            "nsentences": sample["target"].size(0),
            "sample_size": sample_size,
        }
        return loss, sample_size, logging_output
    def forward(self, model, sample, reduce=True):
        """Compute the loss for the given sample.

        Returns a tuple with three elements:
        1) the loss
        2) the sample size, which is used as the denominator for the gradient
        3) logging outputs to display while training
        """
        src_tokens = sample["net_input"]["src_tokens"]
        beam_size = self.args.rl_num_trajectory
        bsz, srclen = src_tokens.size()
        encoder_input = {
            "src_tokens": sample["net_input"]["src_tokens"],
            "src_lengths": sample["net_input"]["src_lengths"],
        }

        # 1) Generate hypos
        translator = generate.build_sequence_generator(self.args, self.task,
                                                       [model])
        with torch.no_grad():
            hypos = translator.generate(
                encoder_input,
                beam_size,
                maxlen=int(self.args.max_len_a * srclen + self.args.max_len_b),
            )
        ## flatten the nested list; result has length bsz * beam_size
        hypos = [hypo for hypo_list in hypos for hypo in hypo_list]
        hypos_len = (torch.tensor([len(hypo["tokens"]) for hypo in hypos])
                     .type_as(src_tokens).float())

        # 2) Compute (log)-probs via forward models
        self.self_rescorer.model = model
        self.self_rescorer.task = self.task
        model.train()
        assert self.self_rescorer.model.training, "model should be in training phase"

        hypo_encoder_inputs, hypo_tokens = self.self_rescorer.prepare_inputs(
            src_tokens, hypos)
        hypo_logprobs, hypo_encoder_outs, _ = self.self_rescorer.score_tokens(
            hypo_encoder_inputs, hypo_tokens)
        hypo_logprobs /= hypos_len**self.args.rescore_length_penalty

        # 3) Compute rewards from rescoring models
        with torch.no_grad():
            rescorer = Rescorer(self.args, self.task, self.rescore_models)
            scores = rescorer.score(src_tokens, hypos)
            rewards = self.combine_score(src_tokens, hypos, hypos_len, scores)
        assert not rewards.requires_grad, "no grads flow back to generation"

        # 4) Compute Policy Gradient loss
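        # same REINFORCE term as above, but over the generated hypotheses only
        # (no reference hypotheses were mixed in, so no mask is needed)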
        rewards = rewards.type_as(hypo_logprobs)
        rl_loss = -1.0 * (hypo_logprobs * rewards).sum()

        # 5) Compute Cross-entropy loss
        eos = self.task.target_dictionary.eos()
        target_tokens = torch.cat(
            (
                torch.zeros(bsz, 1).fill_(eos).type_as(sample["target"]),
                sample["target"],
            ),
            dim=1,
        )
        target_encoder_inputs = (
            encoder_input["src_tokens"],
            [encoder_input["src_lengths"][0].item()],
        )
        target_logprobs, target_encoder_out, _ = self.self_rescorer.score_tokens(
            target_encoder_inputs, target_tokens)
        nll_loss = -1.0 * target_logprobs.sum()

        # 6) Gather losses
        loss = self.args.rl_weight * rl_loss + nll_loss

        # Logging
        sample_size = (sample["target"].size(0)
                       if self.args.sentence_avg else sample["ntokens"])
        logging_output = {
            "loss": utils.item(loss.data) if reduce else loss.data,
            "nll_loss": utils.item(nll_loss.data) if reduce else nll_loss.data,
            "ntokens": sample["ntokens"],
            "nsentences": sample["target"].size(0),
            "sample_size": sample_size,
        }
        return loss, sample_size, logging_output
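A minimal, self-contained sketch of the policy-gradient term from step 4, with dummy
tensors standing in for the model's hypothesis log-probabilities and the frozen
rescorers' rewards; the tensor names and sizes below are illustrative, not from the
source:

import torch

bsz, beam_size = 2, 3
# log-probability of each generated hypothesis under the model being trained
hypo_logprobs = torch.randn(bsz * beam_size, requires_grad=True)
# rewards are computed under torch.no_grad(), so they act as constants
rewards = torch.rand(bsz * beam_size)

rl_loss = -1.0 * (hypo_logprobs * rewards).sum()
rl_loss.backward()
# the gradient w.r.t. each log-probability is -reward, so hypotheses with
# larger rewards are pushed toward higher probability
assert torch.allclose(hypo_logprobs.grad, -rewards)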