Example #1
    def test_compute_scores(self):
        # TODO(halilakin): Verify behaviour in batch mode
        test_args = test_utils.ModelParamsDict()
        _, src_dict, tgt_dict = test_utils.prepare_inputs(test_args)
        task = tasks.PytorchTranslateTask(test_args, src_dict, tgt_dict)
        model = task.build_model(test_args)

        with patch(
                "pytorch_translate.utils.load_diverse_ensemble_for_inference",
                return_value=([model], test_args, task),
        ):
            scorer = SimpleModelScorer(test_args, "/tmp/model_path.txt")
            tgt_tokens = torch.tensor([[2, 11, 22, 0], [2, 33, 44, 55]])
            logprobs = torch.zeros(tgt_tokens.shape[0], tgt_tokens.shape[1],
                                   len(tgt_dict))
            logprobs[0, 0, 11] = 0.5
            logprobs[0, 1, 22] = 1.5
            logprobs[0, 3, :] = 5

            logprobs[1, 0, 33] = 0.5
            logprobs[1, 1, 44] = 1.5
            logprobs[1, 2, 55] = 2.5

            hypos_scores = scorer.compute_scores(tgt_tokens, logprobs)
            assert hypos_scores[0] == 2.0
            assert hypos_scores[1] == 4.5
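The expected values follow from summing, at each position, the log-probability assigned to the next reference token. A minimal sketch of that computation (a hypothetical stand-in for compute_scores, not the library's implementation, assuming pad id 0 as used in tgt_tokens above):

import torch

PAD_ID = 0  # assumption: matches the pad id implied by this test's tgt_tokens

def compute_scores_sketch(tgt_tokens, logprobs):
    """Hypothetical reference implementation: the distribution at position j
    scores the token at position j + 1, so the final position's distribution
    (filled with 5 above) is never read."""
    bsz, tgt_len = tgt_tokens.shape
    scores = torch.zeros(bsz)
    for i in range(bsz):
        for j in range(tgt_len - 1):
            next_token = tgt_tokens[i, j + 1].item()
            if next_token != PAD_ID:  # padded positions contribute nothing
                scores[i] += logprobs[i, j, next_token]
    return scores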
Example #2
    def __init__(self, args):
        self.args = args

        assert (
            args.l2r_model_path is not None
        ), "Rescoring needs --l2r-model-path which generated given hypotheses"
        self.l2r_model_scorer = SimpleModelScorer(args, args.l2r_model_path)
        self.forward_task = self.l2r_model_scorer.task

        self.r2l_model_scorer = None
        if args.r2l_model_path:
            self.r2l_model_scorer = R2LModelScorer(args, args.r2l_model_path)

        self.reverse_model_scorer = None
        if args.reverse_model_path:
            self.reverse_model_scorer = ReverseModelScorer(
                args, args.reverse_model_path, self.forward_task)

        self.lm_scorer = None
        if args.lm_model_path:
            self.lm_scorer = LMScorer(args, args.lm_model_path,
                                      self.forward_task)

        self.cloze_transformer_scorer = None
        if args.cloze_transformer_path:
            self.cloze_transformer_scorer = SimpleModelScorer(
                args, args.cloze_transformer_path)
Example #3
    def test_convert_hypos_to_tgt_tokens(self):
        with patch(
                "pytorch_translate.utils.load_diverse_ensemble_for_inference",
                return_value=([self.model], self.args, self.task),
        ):
            scorer = SimpleModelScorer(self.args, "/tmp/model_path.txt")
            hypos = [
                {"tokens": torch.Tensor([1, 2, 3, 4, 5])},
                {"tokens": torch.Tensor([1, 2, 3, 4])},
                {"tokens": torch.Tensor([1, 2, 3])},
                {"tokens": torch.Tensor([1, 2])},
                {"tokens": torch.Tensor([1])},
            ]
            tgt_tokens = scorer.convert_hypos_to_tgt_tokens(hypos)

            pad = self.task.tgt_dict.pad()
            eos = self.task.tgt_dict.eos()
            expected_tgt_tokens = torch.Tensor([
                [eos, 1, 2, 3, 4, 5],
                [eos, 1, 2, 3, 4, pad],
                [eos, 1, 2, 3, pad, pad],
                [eos, 1, 2, pad, pad, pad],
                [eos, 1, pad, pad, pad, pad],
            ]).type_as(tgt_tokens)
            assert torch.equal(tgt_tokens, expected_tgt_tokens)
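The expected matrix pins down the conversion: an eos column is prepended and shorter hypotheses are right-padded. A sketch of that behavior (a hypothetical helper, assuming only what the assertions show):

import torch

def convert_hypos_sketch(hypos, eos, pad):
    """Hypothetical equivalent of convert_hypos_to_tgt_tokens: prepend eos
    to every hypothesis and right-pad all rows to the longest one."""
    max_len = max(len(h["tokens"]) for h in hypos)
    rows = [
        [eos] + h["tokens"].long().tolist() + [pad] * (max_len - len(h["tokens"]))
        for h in hypos
    ]
    return torch.tensor(rows)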
Example #4
    def test_simple_scorer_prepare_inputs(self):
        pad = self.task.tgt_dict.pad()
        eos = self.task.tgt_dict.eos()

        src_tokens = torch.tensor([[6, 7, 8]], dtype=torch.int)
        hypos = [
            {"tokens": torch.tensor([12, 13, 14, eos], dtype=torch.int)},
            {"tokens": torch.tensor([22, 23, eos], dtype=torch.int)},
        ]

        with patch(
            "pytorch_translate.utils.load_diverse_ensemble_for_inference",
            return_value=([self.model], self.args, self.task),
        ):
            scorer = SimpleModelScorer(
                self.args, "/tmp/model_path.txt", None, self.task
            )
            (encoder_inputs, tgt_tokens) = scorer.prepare_inputs(src_tokens, hypos)

            # Test encoder inputs
            assert torch.equal(
                encoder_inputs[0], torch.tensor([[6, 7, 8], [6, 7, 8]], dtype=torch.int)
            ), "Encoder inputs are not as expected"
            assert torch.equal(
                encoder_inputs[1], torch.tensor([3, 3], dtype=torch.int)
            ), "Src lengths are not as expected"

            # Test target tokens
            assert torch.equal(
                tgt_tokens,
                torch.tensor(
                    [[eos, 12, 13, 14, eos], [eos, 22, 23, eos, pad]], dtype=torch.int
                ),
            ), "Target tokens are not as expected"
Example #5
    def test_padding(self):
        """Same sentence should produce the same score with or without padding
        """
        eos = self.task.tgt_dict.eos()

        src_tokens = torch.tensor([[6, 7, 8]], dtype=torch.long)
        hypos_with_padding = [
            {"tokens": torch.tensor([12, 13, 14, 15, 16, eos], dtype=torch.long)},
            {"tokens": torch.tensor([22, 23, 24, 25, eos], dtype=torch.long)},
            {"tokens": torch.tensor([32, 33, eos], dtype=torch.long)},
        ]
        hypos_without_padding = [
            {"tokens": torch.tensor([22, 23, 24, 25, eos], dtype=torch.long)},
            {"tokens": torch.tensor([32, 33, eos], dtype=torch.long)},
        ]

        with patch(
            "pytorch_translate.utils.load_diverse_ensemble_for_inference",
            return_value=([self.model], self.args, self.task),
        ):
            scorer = SimpleModelScorer(
                self.args, "/tmp/model_path.txt", None, self.task
            )
            scores_with_padding = scorer.score(src_tokens, hypos_with_padding)
            scores_without_padding = scorer.score(src_tokens, hypos_without_padding)
            assert (
                scores_with_padding[1] == scores_without_padding[0]
                and scores_with_padding[2] == scores_without_padding[1]
            ), "Scores with and without padding should not be different"
Example #6
    def test_reverse_model_scorer(self):
        """Verify that the reverse model works correctly by building one
        forward and one backward scorer and asserting that the two produce
        the same scores when source and targets are swapped.
        """
        eos = self.task.tgt_dict.eos()

        src_tokens = torch.tensor([[6, 7, 8]], dtype=torch.long)
        hypos = [
            {"tokens": torch.tensor([12, 13, 14, 15, eos], dtype=torch.long)},
            {"tokens": torch.tensor([22, 23, 24, eos], dtype=torch.long)},
            {"tokens": torch.tensor([32, 33, eos], dtype=torch.long)},
        ]

        reverse_src_tokens_0 = torch.tensor([[12, 13, 14, 15]], dtype=torch.long)
        reverse_src_tokens_1 = torch.tensor([[22, 23, 24]], dtype=torch.long)
        reverse_src_tokens_2 = torch.tensor([[32, 33]], dtype=torch.long)
        reverse_hypos = [{"tokens": torch.tensor([6, 7, 8, eos], dtype=torch.long)}]

        with patch(
                "pytorch_translate.utils.load_diverse_ensemble_for_inference",
                return_value=([self.model], self.args, self.task),
        ):
            scorer = SimpleModelScorer(self.args, "/tmp/model_path.txt", None,
                                       self.task)
            reverse_scorer = ReverseModelScorer(self.args,
                                                "/tmp/model_path.txt", None,
                                                self.task)
            forward_scores = scorer.score(src_tokens, hypos)
            reverse_score_0 = reverse_scorer.score(reverse_src_tokens_0,
                                                   reverse_hypos)
            reverse_score_1 = reverse_scorer.score(reverse_src_tokens_1,
                                                   reverse_hypos)
            reverse_score_2 = reverse_scorer.score(reverse_src_tokens_2,
                                                   reverse_hypos)

            assert forward_scores[0] == reverse_score_0[0]
            assert forward_scores[1] == reverse_score_1[0]
            assert forward_scores[2] == reverse_score_2[0]
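The test exercises a symmetry: as it is set up here, the reverse scorer sees each hypothesis as its source and the original source as its single hypothesis. A hypothetical helper making that swap explicit, mirroring the tensors built above:

import torch

def score_hypo_via_reverse_model(reverse_scorer, src_tokens, hypo, eos):
    """Hypothetical helper: the reverse scorer takes the hypothesis (without
    its trailing eos) as the source and the eos-terminated original source
    as its lone hypothesis, reproducing the forward score."""
    reverse_src = hypo["tokens"][:-1].unsqueeze(0)  # drop eos, add batch dim
    reverse_hypos = [{"tokens": torch.cat([src_tokens[0], torch.tensor([eos])])}]
    return reverse_scorer.score(reverse_src, reverse_hypos)[0]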
Example #7
    def test_convert_hypos_to_tgt_tokens(self):
        test_args = test_utils.ModelParamsDict()
        _, src_dict, tgt_dict = test_utils.prepare_inputs(test_args)
        task = tasks.PytorchTranslateTask(test_args, src_dict, tgt_dict)
        model = task.build_model(test_args)

        with patch(
                "pytorch_translate.utils.load_diverse_ensemble_for_inference",
                return_value=([model], test_args, task),
        ):
            scorer = SimpleModelScorer(test_args, None)

            hypos = [
                {"tokens": torch.Tensor([1, 2, 3, 4, 5])},
                {"tokens": torch.Tensor([1, 2, 3, 4])},
                {"tokens": torch.Tensor([1, 2, 3])},
                {"tokens": torch.Tensor([1, 2])},
                {"tokens": torch.Tensor([1])},
            ]
            tgt_tokens = scorer.convert_hypos_to_tgt_tokens(hypos)

            pad = task.tgt_dict.pad()
            eos = task.tgt_dict.eos()
            expected_tgt_tokens = torch.Tensor([
                [eos, 1, 2, 3, 4, 5],
                [eos, 1, 2, 3, 4, pad],
                [eos, 1, 2, 3, pad, pad],
                [eos, 1, 2, pad, pad, pad],
                [eos, 1, pad, pad, pad, pad],
            ]).type_as(tgt_tokens)
            assert torch.equal(tgt_tokens, expected_tgt_tokens)
Example #8
    def test_compute_scores(self):
        # TODO(halilakin): Verify behaviour in batch mode
        with patch(
                "pytorch_translate.utils.load_diverse_ensemble_for_inference",
                return_value=([self.model], self.args, self.task),
        ):
            scorer = SimpleModelScorer(self.args, "/tmp/model_path.txt")
            tgt_tokens = torch.tensor([[2, 11, 22, 0], [2, 33, 44, 55]])
            logprobs = torch.zeros(tgt_tokens.shape[0], tgt_tokens.shape[1],
                                   len(self.task.tgt_dict))
            logprobs[0, 0, 11] = 0.5
            logprobs[0, 1, 22] = 1.5
            logprobs[0, 3, :] = 5

            logprobs[1, 0, 33] = 0.5
            logprobs[1, 1, 44] = 1.5
            logprobs[1, 2, 55] = 2.5

            hypos_scores = scorer.compute_scores(tgt_tokens, logprobs)
            assert hypos_scores[0] == 2.0
            assert hypos_scores[1] == 4.5
Example #9
    def __init__(self, args, forward_task=None, models=None):
        """models = {'l2r_model': {'model': model, 'task': task}, ...}"""
        self.args = args
        if models is None:
            models = {}
        self.l2r_model_scorer = None
        if args.l2r_model_path or models.get("l2r_model", None):
            self.l2r_model_scorer = SimpleModelScorer(
                args, args.l2r_model_path, models.get("l2r_model", None), forward_task
            )

        self.r2l_model_scorer = None
        if args.r2l_model_path or models.get("r2l_model", None):
            self.r2l_model_scorer = R2LModelScorer(
                args, args.r2l_model_path, models.get("r2l_model", None), forward_task
            )

        self.reverse_model_scorer = None
        if args.reverse_model_path or models.get("reverse_model", None):
            self.reverse_model_scorer = ReverseModelScorer(
                args,
                args.reverse_model_path,
                models.get("reverse_model", None),
                forward_task,
            )

        self.lm_scorer = None
        if args.lm_model_path or models.get("lm_model", None):
            self.lm_scorer = LMScorer(
                args, args.lm_model_path, models.get("lm_model", None), forward_task
            )

        self.cloze_transformer_scorer = None
        if args.cloze_transformer_path or models.get("cloze_model", None):
            self.cloze_transformer_scorer = SimpleModelScorer(
                args,
                args.cloze_transformer_path,
                models.get("cloze_model", None),
                forward_task,
            )
Example #10
class Rescorer:
    """Reranks n-best hypotheses based on extra models and parameters"""

    def __init__(self, args, forward_task=None, models=None):
        """models = {'l2r_model': {'model': model, 'task': task}, ...}"""
        self.args = args
        if models is None:
            models = {}
        self.l2r_model_scorer = None
        if args.l2r_model_path or models.get("l2r_model", None):
            self.l2r_model_scorer = SimpleModelScorer(
                args, args.l2r_model_path, models.get("l2r_model", None), forward_task
            )

        self.r2l_model_scorer = None
        if args.r2l_model_path or models.get("r2l_model", None):
            self.r2l_model_scorer = R2LModelScorer(
                args, args.r2l_model_path, models.get("r2l_model", None), forward_task
            )

        self.reverse_model_scorer = None
        if args.reverse_model_path or models.get("reverse_model", None):
            self.reverse_model_scorer = ReverseModelScorer(
                args,
                args.reverse_model_path,
                models.get("reverse_model", None),
                forward_task,
            )

        self.lm_scorer = None
        if args.lm_model_path or models.get("lm_model", None):
            self.lm_scorer = LMScorer(
                args, args.lm_model_path, models.get("lm_model", None), forward_task
            )

        self.cloze_transformer_scorer = None
        if args.cloze_transformer_path or models.get("cloze_model", None):
            self.cloze_transformer_scorer = SimpleModelScorer(
                args,
                args.cloze_transformer_path,
                models.get("cloze_model", None),
                forward_task,
            )

    def score(self, src_tokens, hypos):
        """run models and compute scores based on p(y), p(x|y) etc."""

        scores = torch.zeros((len(hypos), len(FeatureList)), dtype=torch.float)

        self.compute_l2r_model_scores(src_tokens, hypos, scores)
        self.compute_r2l_model_scores(src_tokens, hypos, scores)
        self.compute_reverse_model_scores(src_tokens, hypos, scores)
        self.compute_lm_scores(src_tokens, hypos, scores)
        self.compute_cloze_transformer_scores(src_tokens, hypos, scores)

        return scores

    def compute_l2r_model_scores(self, src_tokens, hypos, scores):
        if not self.l2r_model_scorer:
            return
        l2r_scores = self.l2r_model_scorer.score(src_tokens, hypos)
        scores[:, FeatureList.L2R_MODEL_SCORE.value] = l2r_scores[:]

    def compute_r2l_model_scores(self, src_tokens, hypos, scores):
        if not self.r2l_model_scorer:
            return
        r2l_scores = self.r2l_model_scorer.score(src_tokens, hypos)
        scores[:, FeatureList.R2L_MODEL_SCORE.value] = r2l_scores[:]

    def compute_reverse_model_scores(self, src_tokens, hypos, scores):
        """computes p(x|y) for each hypothesis. """
        if not self.reverse_model_scorer:
            return

        scores[
            :, FeatureList.REVERSE_MODEL_SCORE.value
        ] = self.reverse_model_scorer.score(src_tokens, hypos)

    def compute_lm_scores(self, src_tokens, hypos, scores):
        """computes p(x|y) for each hypothesis. """
        if not self.lm_scorer:
            return

        lm_scores = self.lm_scorer.score(src_tokens, hypos)
        scores[:, FeatureList.LM_SCORE.value] = lm_scores[:]

    def compute_cloze_transformer_scores(self, src_tokens, hypos, scores):
        if not self.cloze_transformer_scorer:
            return

        cloze_scores = self.cloze_transformer_scorer.score(src_tokens, hypos)
        scores[:, FeatureList.CLOZE_SCORE.value] = cloze_scores[:]
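Each compute_* method fills one column of the (num_hypos, num_features) matrix via FeatureList.*.value. FeatureList is imported from elsewhere in the package; a minimal sketch of the shape the indexing requires, with member names taken from this file and the ordering an assumption:

from enum import Enum

class FeatureList(Enum):
    """Assumed layout: one column per scorer, matching the
    FeatureList.*.value indices used in Rescorer.score()."""
    L2R_MODEL_SCORE = 0
    R2L_MODEL_SCORE = 1
    REVERSE_MODEL_SCORE = 2
    LM_SCORE = 3
    CLOZE_SCORE = 4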
Example #11
class Rescorer:
    """Reranks n-best hypotheses based on extra models and parameters"""
    def __init__(self, args):
        self.args = args

        assert (
            args.l2r_model_path is not None
        ), "Rescoring needs --l2r-model-path which generated given hypotheses"
        self.l2r_model_scorer = SimpleModelScorer(args, args.l2r_model_path)
        self.forward_task = self.l2r_model_scorer.task

        self.r2l_model_scorer = None
        if args.r2l_model_path:
            self.r2l_model_scorer = R2LModelScorer(args, args.r2l_model_path)

        self.reverse_model_scorer = None
        if args.reverse_model_path:
            self.reverse_model_scorer = ReverseModelScorer(
                args, args.reverse_model_path, self.forward_task)

        self.lm_scorer = None
        if args.lm_model_path:
            self.lm_scorer = LMScorer(args, args.lm_model_path,
                                      self.forward_task)

        self.cloze_transformer_scorer = None
        if args.cloze_transformer_path:
            self.cloze_transformer_scorer = SimpleModelScorer(
                args, args.cloze_transformer_path)

    def score(self, src_tokens, hypos):
        """run models and compute scores based on p(y), p(x|y) etc."""
        scores = torch.zeros((len(hypos), len(FeatureList)), dtype=torch.float)

        self.compute_l2r_model_scores(src_tokens, hypos, scores)
        self.compute_r2l_model_scores(src_tokens, hypos, scores)
        self.compute_reverse_model_scores(src_tokens, hypos, scores)
        self.compute_lm_scores(src_tokens, hypos, scores)
        self.compute_cloze_transformer_scores(src_tokens, hypos, scores)

        return scores

    def compute_l2r_model_scores(self, src_tokens, hypos, scores):
        l2r_scores = self.l2r_model_scorer.score(src_tokens, hypos)
        scores[:, FeatureList.L2R_MODEL_SCORE.value] = l2r_scores[:]

    def compute_r2l_model_scores(self, src_tokens, hypos, scores):
        if not self.r2l_model_scorer:
            return
        r2l_scores = self.r2l_model_scorer.score(src_tokens, hypos)
        scores[:, FeatureList.R2L_MODEL_SCORE.value] = r2l_scores[:]

    def compute_reverse_model_scores(self, src_tokens, hypos, scores):
        """computes p(x|y) for each hypothesis. """
        if not self.reverse_model_scorer:
            return

        scores[
            :, FeatureList.REVERSE_MODEL_SCORE.value
        ] = self.reverse_model_scorer.score(src_tokens, hypos)

    def compute_lm_scores(self, src_tokens, hypos, scores):
        """computes p(x|y) for each hypothesis. """
        if not self.lm_scorer:
            return

        lm_scores = self.lm_scorer.score(src_tokens, hypos)
        scores[:, FeatureList.LM_SCORE.value] = lm_scores[:]

    def compute_cloze_transformer_scores(self, src_tokens, hypos, scores):
        if not self.cloze_transformer_scorer:
            return

        cloze_scores = self.cloze_transformer_scorer.score(src_tokens, hypos)
        scores[:, FeatureList.CLOZE_SCORE.value] = cloze_scores[:]
Example #12
    def forward(self, model, sample, reduce=True):
        """Compute the loss for the given sample.

        Returns a tuple with three elements:
        1) the loss
        2) the sample size, which is used as the denominator for the gradient
        3) logging outputs to display while training
        """
        src_tokens = sample["net_input"]["src_tokens"]
        beam_size = self.args.rl_num_trajectory
        bsz, srclen = src_tokens.size()
        encoder_input = {
            "src_tokens": sample["net_input"]["src_tokens"],
            "src_lengths": sample["net_input"]["src_lengths"],
        }

        # 1) Generate hypos
        translator = generate.build_sequence_generator(self.args, self.task, [model])
        with torch.no_grad():
            seq_hypos = translator.generate(
                encoder_input,
                beam_size,
                maxlen=int(self.args.max_len_a * srclen + self.args.max_len_b),
            )

        word_hypos = [[{"tokens": sample["target"][k]}] for k in range(bsz)]

        ## Mix sequence-level and word-level hypos, then flatten
        hypos = [seq_hypos[j] + word_hypos[j] for j in range(bsz)]
        hypos = [hypo for sent_hypos in hypos for hypo in sent_hypos]
        hypos_len = (
            torch.tensor([len(hypo["tokens"]) for hypo in hypos])
            .type_as(src_tokens)
            .float()
        )
        # indices of the word-level hypos (the appended target sentences)
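        # Worked example (assumed sizes bsz=2, beam_size=3): the flattened
        # list has 8 hypos, 4 per sentence (3 from the beam plus the target),
        # so torch.arange(3, 8, 4) selects indices 3 and 7.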
        mask_index = torch.arange(beam_size, (beam_size + 1) * bsz, beam_size + 1).view(
            -1
        )

        # 2) Compute (log)-probs via forward models
        self.self_rescorer.model = model
        self.self_rescorer.task = self.task
        model.train()
        assert self.self_rescorer.model.training, "model should be in training mode"

        hypo_encoder_inputs, hypo_tokens = self.self_rescorer.prepare_inputs(
            src_tokens, hypos
        )
        hypo_logprobs, hypo_encoder_outs, forward_logprobs = self.self_rescorer.score_tokens(
            hypo_encoder_inputs, hypo_tokens
        )
        hypo_logprobs /= hypos_len ** self.args.rescore_length_penalty

        # 3) Sequence level
        seq_loss = torch.zeros(1).type_as(hypo_logprobs)
        if self.args.rl_weight > 0.0:
            ## 3.1) Compute seq-level rewards
            with torch.no_grad():
                rescorer = Rescorer(self.args, self.task, self.rescore_models)
                scores = rescorer.score(src_tokens, hypos)
                rewards = self.combine_score(src_tokens, hypos, hypos_len, scores)
            assert not rewards.requires_grad, "no grads flow back to generation"
            ## 3.2) Compute Policy Gradient loss
            rewards = rewards.type_as(hypo_logprobs)
            seq_mask = hypo_logprobs.new_ones(hypo_logprobs.size())
            seq_mask[mask_index] = 0.0
            seq_loss = -1.0 * (seq_mask * hypo_logprobs * rewards).sum()

        # 4) Word-level
        word_loss = torch.zeros(1).type_as(hypo_logprobs)
        if self.args.word_weight > 0.0:
            ## 4.1) Compute word-level rewards from a left-right rescoring model
            with torch.no_grad():
                teacher_model = self.rescore_models[self.args.word_model]
                teacher = SimpleModelScorer(self.args, None, teacher_model, self.task)
                _, _, teacher_logprobs = teacher.score_tokens(
                    hypo_encoder_inputs, hypo_tokens
                )
            ## 4.2) Compute word-level loss
            f_logprob, f_index = forward_logprobs.topk(self.args.topk_words)
            word_mask = f_logprob.new_zeros(f_logprob.size())
            word_mask[mask_index, :, :] = 1.0
            ## KL(p_s || p_t) = \sum p_s log p_s - \sum p_s log p_t, aka RL + maxEnt
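            ## Restricting the sum to the student's top-k tokens (--topk-words)
            ## approximates the full-vocabulary KL; word_mask keeps only the
            ## rows of the appended target sentences.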
            word_loss = (
                word_mask
                * f_logprob.exp()
                * (f_logprob - 1.0 * teacher_logprobs.gather(-1, f_index))
            ).sum()

        # 5) Compute Cross-entropy loss
        eos = self.task.target_dictionary.eos()
        target_tokens = torch.cat(
            (
                torch.zeros(bsz, 1).fill_(eos).type_as(sample["target"]),
                sample["target"],
            ),
            dim=1,
        )
        target_encoder_inputs = (
            encoder_input["src_tokens"],
            [encoder_input["src_lengths"][0].item()],
        )
        target_logprobs, target_encoder_out, _ = self.self_rescorer.score_tokens(
            target_encoder_inputs, target_tokens
        )
        nll_loss = -1.0 * target_logprobs.sum()

        # 6) Gather losses
        loss = (
            self.args.rl_weight * seq_loss
            + self.args.word_weight * word_loss
            + nll_loss
        )

        # Logging
        sample_size = (
            sample["target"].size(0) if self.args.sentence_avg else sample["ntokens"]
        )
        logging_output = {
            "loss": utils.item(loss.data) if reduce else loss.data,
            "nll_loss": utils.item(nll_loss.data) if reduce else nll_loss.data,
            "ntokens": sample["ntokens"],
            "nsentences": sample["target"].size(0),
            "sample_size": sample_size,
        }
        return loss, sample_size, logging_output
Example #13
    def __init__(self, args, task):
        super().__init__(args, task)
        self.self_rescorer = SimpleModelScorer(args, None, None, task)
        self.rescore_models = self.load_rescore_models(args)
        self.args = args
        self.task = task
Example #14
class RescoringCriterion(FairseqCriterion):
    def __init__(self, args, task):
        super().__init__(args, task)
        self.self_rescorer = SimpleModelScorer(args, None, None, task)
        self.rescore_models = self.load_rescore_models(args)
        self.args = args
        self.task = task

    @staticmethod
    def add_args(parser):
        """Add criterion-specific arguments to the parser."""
        rescore_add_args(parser)
        parser.add_argument(
            "--rl-weight",
            type=float,
            default=0.1,
            help="trade-off coefficient of rl loss",
        )
        parser.add_argument(
            "--rl-num-trajectory",
            type=int,
            default=3,
            help="num trajectory in rl training",
        )
        parser.add_argument(
            "--topk-words",
            type=int,
            default=8,
            help="match topk words at each time step",
        )
        parser.add_argument(
            "--word-weight", type=float, default=1.0, help="weight for word level"
        )
        parser.add_argument(
            "--word-model",
            type=str,
            default="cloze_model",
            help="word-level teacher model",
        )

    def forward(self, model, sample, reduce=True):
        """Compute the loss for the given sample.

        Returns a tuple with three elements:
        1) the loss
        2) the sample size, which is used as the denominator for the gradient
        3) logging outputs to display while training
        """
        src_tokens = sample["net_input"]["src_tokens"]
        beam_size = self.args.rl_num_trajectory
        bsz, srclen = src_tokens.size()
        encoder_input = {
            "src_tokens": sample["net_input"]["src_tokens"],
            "src_lengths": sample["net_input"]["src_lengths"],
        }

        # 1) Generate hypos
        translator = generate.build_sequence_generator(self.args, self.task, [model])
        with torch.no_grad():
            seq_hypos = translator.generate(
                encoder_input,
                beam_size,
                maxlen=int(self.args.max_len_a * srclen + self.args.max_len_b),
            )

        word_hypos = [[{"tokens": sample["target"][k]}] for k in range(bsz)]

        ## Mix sequence-level and word-level hypos, then flatten
        hypos = [seq_hypos[j] + word_hypos[j] for j in range(bsz)]
        hypos = [hypo for sent_hypos in hypos for hypo in sent_hypos]
        hypos_len = (
            torch.tensor([len(hypo["tokens"]) for hypo in hypos])
            .type_as(src_tokens)
            .float()
        )
        # indices of the word-level hypos (the appended target sentences)
        mask_index = torch.arange(beam_size, (beam_size + 1) * bsz, beam_size + 1).view(
            -1
        )

        # 2) Compute (log)-probs via forward models
        self.self_rescorer.model = model
        self.self_rescorer.task = self.task
        model.train()
        assert self.self_rescorer.model.training, "model should be in training mode"

        hypo_encoder_inputs, hypo_tokens = self.self_rescorer.prepare_inputs(
            src_tokens, hypos
        )
        hypo_logprobs, hypo_encoder_outs, forward_logprobs = self.self_rescorer.score_tokens(
            hypo_encoder_inputs, hypo_tokens
        )
        hypo_logprobs /= hypos_len ** self.args.rescore_length_penalty

        # 3) Sequence level
        seq_loss = torch.zeros(1).type_as(hypo_logprobs)
        if self.args.rl_weight > 0.0:
            ## 3.1) Compute seq-level rewards
            with torch.no_grad():
                rescorer = Rescorer(self.args, self.task, self.rescore_models)
                scores = rescorer.score(src_tokens, hypos)
                rewards = self.combine_score(src_tokens, hypos, hypos_len, scores)
            assert not rewards.requires_grad, "no grads flow back to generation"
            ## 3.2) Compute Policy Gradient loss
            rewards = rewards.type_as(hypo_logprobs)
            seq_mask = hypo_logprobs.new_ones(hypo_logprobs.size())
            seq_mask[mask_index] = 0.0
            seq_loss = -1.0 * (seq_mask * hypo_logprobs * rewards).sum()

        # 4) Word-level
        word_loss = torch.zeros(1).type_as(hypo_logprobs)
        if self.args.word_weight > 0.0:
            ## 4.1) Compute word-level rewards from a left-right rescoring model
            with torch.no_grad():
                teacher_model = self.rescore_models[self.args.word_model]
                teacher = SimpleModelScorer(self.args, None, teacher_model, self.task)
                _, _, teacher_logprobs = teacher.score_tokens(
                    hypo_encoder_inputs, hypo_tokens
                )
            ## 4.2) Compute word-level loss
            f_logprob, f_index = forward_logprobs.topk(self.args.topk_words)
            word_mask = f_logprob.new_zeros(f_logprob.size())
            word_mask[mask_index, :, :] = 1.0
            ## KL(p_s || p_t) = \sum p_s log p_s - \sum p_s log p_t, aka RL + maxEnt
            word_loss = (
                word_mask
                * f_logprob.exp()
                * (f_logprob - 1.0 * teacher_logprobs.gather(-1, f_index))
            ).sum()

        # 5) Compute Cross-entropy loss
        eos = self.task.target_dictionary.eos()
        target_tokens = torch.cat(
            (
                torch.zeros(bsz, 1).fill_(eos).type_as(sample["target"]),
                sample["target"],
            ),
            dim=1,
        )
        target_encoder_inputs = (
            encoder_input["src_tokens"],
            [encoder_input["src_lengths"][0].item()],
        )
        target_logprobs, target_encoder_out, _ = self.self_rescorer.score_tokens(
            target_encoder_inputs, target_tokens
        )
        nll_loss = -1.0 * target_logprobs.sum()

        # 6) Gather losses
        loss = (
            self.args.rl_weight * seq_loss
            + self.args.word_weight * word_loss
            + nll_loss
        )

        # Logging
        sample_size = (
            sample["target"].size(0) if self.args.sentence_avg else sample["ntokens"]
        )
        logging_output = {
            "loss": utils.item(loss.data) if reduce else loss.data,
            "nll_loss": utils.item(nll_loss.data) if reduce else nll_loss.data,
            "ntokens": sample["ntokens"],
            "nsentences": sample["target"].size(0),
            "sample_size": sample_size,
        }
        return loss, sample_size, logging_output

    @staticmethod
    def aggregate_logging_outputs(logging_outputs):
        """Aggregate logging outputs from data parallel training."""
        loss_sum = sum(log.get("loss", 0) for log in logging_outputs)
        nll_loss_sum = sum(log.get("nll_loss", 0) for log in logging_outputs)
        ntokens = sum(log.get("ntokens", 0) for log in logging_outputs)
        nsentences = sum(log.get("nsentences", 0) for log in logging_outputs)
        sample_size = sum(log.get("sample_size", 0) for log in logging_outputs)
        return {
            "loss": loss_sum / sample_size if sample_size > 0 else 0.0,
            "nll_loss": nll_loss_sum / ntokens / math.log(2),
            "ntokens": ntokens,
            "nsentences": nsentences,
            "sample_size": sample_size,
        }

    def combine_score(self, src_tokens, hypos, hypos_len, scores):
        """ Rescore translations and combine weights to find top hypo tokens
        """
        # Prepare all the weights and call combine weighted scores
        args = self.args
        weights = [
            args.l2r_model_weight,
            args.r2l_model_weight,
            args.reverse_model_weight,
            args.lm_model_weight,
            args.cloze_transformer_weight,
        ]
        bsz, src_len = src_tokens.size()
        hypos_len = hypos_len.type_as(scores)
        combined_scores = combine_weighted_scores(
            scores, weights, src_len, hypos_len, args.length_penalty
        )
        return combined_scores
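combine_weighted_scores is imported from the rescoring module. Judging from this call site it reduces the feature matrix to one scalar per hypothesis; a hedged sketch of such a reduction (the length normalization is an assumption, src_len is accepted only for signature compatibility):

import torch

def combine_weighted_scores_sketch(scores, weights, src_len, hypos_len, length_penalty):
    """Hypothetical stand-in: weighted sum over feature columns,
    normalized by hypothesis length raised to the penalty."""
    weights = torch.tensor(weights, dtype=scores.dtype)
    combined = (scores * weights).sum(dim=1)  # (num_hypos,)
    return combined / hypos_len ** length_penalty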

    def load_rescore_models(self, args):
        """load rescoring models"""
        models = {}
        if args.l2r_model_path:
            l2r_model, _, l2r_task = pytorch_translate_utils.load_diverse_ensemble_for_inference(
                [args.l2r_model_path]
            )
            models["l2r_model"] = {"model": l2r_model[0], "task": l2r_task}
        if args.r2l_model_path:
            r2l_model, _, r2l_task = pytorch_translate_utils.load_diverse_ensemble_for_inference(
                [args.r2l_model_path]
            )
            models["r2l_model"] = {"model": r2l_model[0], "task": r2l_task}
        if args.reverse_model_path:
            reverse_model, _, reverse_task = pytorch_translate_utils.load_diverse_ensemble_for_inference(
                [args.reverse_model_path]
            )
            models["reverse_model"] = {"model": reverse_model[0], "task": reverse_task}
        if args.lm_model_path:
            lm_model, _, lm_task = pytorch_translate_utils.load_diverse_ensemble_for_inference(
                [args.lm_model_path]
            )
            models["lm_model"] = {"model": lm_model[0], "task": lm_task}
        if args.cloze_transformer_path:
            cloze_model, _, cloze_task = pytorch_translate_utils.load_diverse_ensemble_for_inference(
                [args.cloze_transformer_path]
            )
            models["cloze_model"] = {"model": cloze_model[0], "task": cloze_task}
        return models
Example #15
class RescoringCriterion(FairseqCriterion):
    def __init__(self, args, task):
        super().__init__(args, task)
        self.self_rescorer = SimpleModelScorer(args, None, None, task)
        self.rescore_models = self.load_rescore_models(args)
        self.args = args
        self.task = task

    @staticmethod
    def add_args(parser):
        """Add criterion-specific arguments to the parser."""
        rescore_add_args(parser)
        parser.add_argument(
            "--rl-weight",
            type=float,
            default=0.1,
            help="trade-off coefficient of rl loss",
        )
        parser.add_argument(
            "--rl-num-trajectory",
            type=int,
            default=3,
            help="num trajectory in rl training",
        )

    def forward(self, model, sample, reduce=True):
        """Compute the loss for the given sample.

        Returns a tuple with three elements:
        1) the loss
        2) the sample size, which is used as the denominator for the gradient
        3) logging outputs to display while training
        """
        src_tokens = sample["net_input"]["src_tokens"]
        beam_size = self.args.rl_num_trajectory
        bsz, srclen = src_tokens.size()
        encoder_input = {
            "src_tokens": sample["net_input"]["src_tokens"],
            "src_lengths": sample["net_input"]["src_lengths"],
        }

        # 1) Generate hypos
        translator = generate.build_sequence_generator(self.args, self.task, [model])
        with torch.no_grad():
            hypos = translator.generate(
                encoder_input,
                beam_size,
                maxlen=int(self.args.max_len_a * srclen + self.args.max_len_b),
            )
        ## flatten nested list; length becomes bsz * beam_size
        hypos = [hypo for sent_hypos in hypos for hypo in sent_hypos]
        hypos_len = (
            torch.tensor([len(hypo["tokens"]) for hypo in hypos])
            .type_as(src_tokens)
            .float()
        )

        # 2) Compute (log)-probs via forward models
        self.self_rescorer.model = model
        self.self_rescorer.task = self.task
        model.train()
        assert self.self_rescorer.model.training, "model should be in training mode"

        hypo_encoder_inputs, hypo_tokens = self.self_rescorer.prepare_inputs(
            src_tokens, hypos)
        hypo_logprobs, hypo_encoder_outs, _ = self.self_rescorer.score_tokens(
            hypo_encoder_inputs, hypo_tokens)
        hypo_logprobs /= hypos_len ** self.args.rescore_length_penalty

        # 3) Compute rewards from rescoring models
        with torch.no_grad():
            rescorer = Rescorer(self.args, self.task, self.rescore_models)
            scores = rescorer.score(src_tokens, hypos)
            rewards = self.combine_score(src_tokens, hypos, hypos_len, scores)
        assert not rewards.requires_grad, "no grads flow back to generation"

        # 4) Compute Policy Gradient loss
        rewards = rewards.type_as(hypo_logprobs)
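        # REINFORCE-style surrogate: minimizing -(log p(y|x) * R).sum() performs
        # gradient ascent on the expected reward; R itself carries no gradient.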
        rl_loss = -1.0 * (hypo_logprobs * rewards).sum()

        # 5) Compute Cross-entropy loss
        eos = self.task.target_dictionary.eos()
        target_tokens = torch.cat(
            (
                torch.zeros(bsz, 1).fill_(eos).type_as(sample["target"]),
                sample["target"],
            ),
            dim=1,
        )
        target_encoder_inputs = (
            encoder_input["src_tokens"],
            [encoder_input["src_lengths"][0].item()],
        )
        target_logprobs, target_encoder_out, _ = self.self_rescorer.score_tokens(
            target_encoder_inputs, target_tokens)
        nll_loss = -1.0 * target_logprobs.sum()

        # 6) Gather losses
        loss = self.args.rl_weight * rl_loss + nll_loss

        # Logging
        sample_size = (sample["target"].size(0)
                       if self.args.sentence_avg else sample["ntokens"])
        logging_output = {
            "loss": utils.item(loss.data) if reduce else loss.data,
            "nll_loss": utils.item(nll_loss.data) if reduce else nll_loss.data,
            "ntokens": sample["ntokens"],
            "nsentences": sample["target"].size(0),
            "sample_size": sample_size,
        }
        return loss, sample_size, logging_output

    @staticmethod
    def aggregate_logging_outputs(logging_outputs):
        """Aggregate logging outputs from data parallel training."""
        loss_sum = sum(log.get("loss", 0) for log in logging_outputs)
        nll_loss_sum = sum(log.get("nll_loss", 0) for log in logging_outputs)
        ntokens = sum(log.get("ntokens", 0) for log in logging_outputs)
        nsentences = sum(log.get("nsentences", 0) for log in logging_outputs)
        sample_size = sum(log.get("sample_size", 0) for log in logging_outputs)
        return {
            "loss": loss_sum / sample_size if sample_size > 0 else 0.0,
            "nll_loss": nll_loss_sum / ntokens / math.log(2),
            "ntokens": ntokens,
            "nsentences": nsentences,
            "sample_size": sample_size,
        }

    def combine_score(self, src_tokens, hypos, hypos_len, scores):
        """ Rescore translations and combine weights to find top hypo tokens
        """
        # Prepare all the weights and call combine weighted scores
        args = self.args
        weights = [
            args.l2r_model_weight,
            args.r2l_model_weight,
            args.reverse_model_weight,
            args.lm_model_weight,
            args.cloze_transformer_weight,
        ]
        bsz, src_len = src_tokens.size()
        hypos_len = hypos_len.type_as(scores)
        combined_scores = combine_weighted_scores(
            scores, weights, src_len, hypos_len, args.length_penalty
        )
        return combined_scores

    def load_rescore_models(self, args):
        """load rescoring models"""
        models = {}
        if args.l2r_model_path:
            l2r_model, _, l2r_task = pytorch_translate_utils.load_diverse_ensemble_for_inference(
                [args.l2r_model_path])
            models["l2r_model"] = {"model": l2r_model[0], "task": l2r_task}
        if args.r2l_model_path:
            r2l_model, _, r2l_task = pytorch_translate_utils.load_diverse_ensemble_for_inference(
                [args.r2l_model_path])
            models["r2l_model"] = {"model": r2l_model[0], "task": r2l_task}
        if args.reverse_model_path:
            reverse_model, _, reverse_task = pytorch_translate_utils.load_diverse_ensemble_for_inference(
                [args.reverse_model_path])
            models["reverse_model"] = {
                "model": reverse_model[0],
                "task": reverse_task
            }
        if args.lm_model_path:
            lm_model, _, lm_task = pytorch_translate_utils.load_diverse_ensemble_for_inference(
                [args.lm_model_path])
            models["lm_model"] = {"model": lm_model[0], "task": lm_task}
        if args.cloze_transformer_path:
            cloze_model, _, cloze_task = pytorch_translate_utils.load_diverse_ensemble_for_inference(
                [args.cloze_transformer_path])
            models["cloze_model"] = {
                "model": cloze_model[0],
                "task": cloze_task
            }
        return models