Example #1
 def test_diverse_beam_search(self):
     generator = SequenceGenerator(
         self.tgt_dict,
         beam_size=2,
         diverse_beam_groups=2,
         diverse_beam_strength=0.,
     )
     sample = {
         'net_input': {
             'src_tokens': self.src_tokens,
             'src_lengths': self.src_lengths
         }
     }
     hypos = generator.generate([self.model], sample)
     eos, w1, w2 = self.eos, self.w1, self.w2
     # sentence 1, beam 1
     self.assertHypoTokens(hypos[0][0], [w1, w1, eos])
     self.assertHypoScore(hypos[0][0], [0.9, 0.6, 1.0])
     # sentence 1, beam 2
     self.assertHypoTokens(hypos[0][1], [w1, w1, eos])
     self.assertHypoScore(hypos[0][1], [0.9, 0.6, 1.0])
     # sentence 2, beam 1
     self.assertHypoTokens(hypos[1][0], [w1, w2, eos])
     self.assertHypoScore(hypos[1][0], [0.7, 0.4, 0.9])
     # sentence 2, beam 2
     self.assertHypoTokens(hypos[1][1], [w1, w2, eos])
     self.assertHypoScore(hypos[1][1], [0.7, 0.4, 0.9])
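
The probability lists in these assertions are per-step probabilities, not final scores. As a minimal sketch (not the test suite's actual helper), the aggregate hypothesis score they imply can be reconstructed by summing log-probabilities and length-normalizing, assuming fairseq's default normalization score = sum(log p) / len ** lenpen; `expected_hypo_score` is a hypothetical name:

import torch

def expected_hypo_score(pos_probs, normalized=True, lenpen=1.0):
    # sum of per-token log-probabilities, optionally divided by length ** lenpen
    pos_scores = torch.tensor(pos_probs).log()
    score = pos_scores.sum()
    if normalized:
        score = score / pos_scores.numel() ** lenpen
    return score.item()

# the [0.9, 0.6, 1.0] hypothesis above: (log 0.9 + log 0.6 + log 1.0) / 3 ≈ -0.2054
print(expected_hypo_score([0.9, 0.6, 1.0]))
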
Example #2
 def test_diverse_beam_search(self):
     search_strategy = search.DiverseSiblingsSearch(self.tgt_dict,
                                                    diversity_rate=0.5)
     generator = SequenceGenerator(self.tgt_dict,
                                   beam_size=2,
                                   search_strategy=search_strategy)
     sample = {
         "net_input": {
             "src_tokens": self.src_tokens,
             "src_lengths": self.src_lengths,
         }
     }
     hypos = generator.generate([self.model], sample)
     eos, w1, w2 = self.eos, self.w1, self.w2
     # sentence 1, beam 1
     self.assertHypoTokens(hypos[0][0], [w1, w1, eos])
     self.assertHypoScore(hypos[0][0], [0.9, 0.6, 1.0], [0, 1, 1], 0.5)
     # sentence 1, beam 2
     self.assertHypoTokens(hypos[0][1], [w1, w2, eos])
     self.assertHypoScore(hypos[0][1], [0.9, 0.4, 1.0], [0, 2, 1], 0.5)
     # sentence 2, beam 1
     self.assertHypoTokens(hypos[1][0], [w1, w2, eos])
     self.assertHypoScore(hypos[1][0], [0.7, 0.4, 0.9], [0, 1, 1], 0.5)
     # sentence 2, beam 2
     self.assertHypoTokens(hypos[1][1], [w1, w1, eos])
     self.assertHypoScore(hypos[1][1], [0.7, 0.35, 0.9], [0, 2, 1], 0.5)
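
The extra arguments in the assertions above ([0, 1, 1] and 0.5) correspond to the sibling-rank penalty of diverse siblings search (Li & Jurafsky, 2016): candidates expanded from the same parent beam are ranked and penalized by rank * diversity_rate before they compete across parents. A minimal standalone sketch of that penalty, not fairseq's `DiverseSiblingsSearch` implementation:

import torch

def sibling_rank_penalty(lprobs, diversity_rate=0.5, k=2):
    # lprobs: (num_parents, vocab) log-probabilities for expanding each parent beam
    scores, indices = lprobs.topk(k, dim=-1)            # best k siblings per parent
    ranks = torch.arange(1, k + 1, dtype=scores.dtype)  # 1-based sibling ranks
    return scores - ranks * diversity_rate, indices     # lower-ranked siblings are discouraged

lprobs = torch.tensor([[0.9, 0.6, 0.1], [0.7, 0.4, 0.35]]).log()
print(sibling_rank_penalty(lprobs))
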
Example #3
 def test_topp_sampling_search_low_prob(self):
     # Given a sampling_topp low enough, we expect only the top-1 token
     # to be sampled, which always results in the same output.
     low_sampling_topp = self.min_top1_prob / 2.0
     generator = SequenceGenerator(self.tgt_dict,
                                   beam_size=2,
                                   sampling=True,
                                   sampling_topp=low_sampling_topp)
     sample = {
         'net_input': {
             'src_tokens': self.src_tokens,
             'src_lengths': self.src_lengths
         }
     }
     hypos = generator.generate([self.model], sample)
     eos, w1 = self.eos, self.w1
     # sentence 1, beam 1
     self.assertHypoTokens(hypos[0][0], [w1, w1, eos])
     self.assertHypoScore(hypos[0][0], [1.0, 0.4, 1.0])
     # sentence 1, beam 2
     self.assertHypoTokens(hypos[0][1], [w1, w1, eos])
     self.assertHypoScore(hypos[0][1], [1.0, 0.4, 1.0])
     # sentence 2, beam 1
     self.assertHypoTokens(hypos[1][0], [w1, w1, eos])
     self.assertHypoScore(hypos[1][0], [1.0, 0.4, 1.0])
     # sentence 2, beam 2
     self.assertHypoTokens(hypos[1][1], [w1, w1, eos])
     self.assertHypoScore(hypos[1][1], [1.0, 0.4, 1.0])
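
`sampling_topp` below the smallest top-1 probability means the nucleus contains only the single most likely token, so sampling becomes deterministic, which is what the identical hypotheses above check. A minimal, generic sketch of the top-p (nucleus) filtering step, written for a 1-D distribution and not taken from fairseq's `search.Sampling` code:

import torch

def top_p_filter(probs, p):
    # keep the smallest set of tokens whose cumulative probability reaches p, renormalize
    sorted_probs, sorted_idx = probs.sort(descending=True)
    cumulative_before = sorted_probs.cumsum(dim=-1) - sorted_probs
    keep = cumulative_before < p
    keep[0] = True  # always keep the most likely token
    filtered = torch.zeros_like(probs)
    filtered[sorted_idx[keep]] = sorted_probs[keep]
    return filtered / filtered.sum()

probs = torch.tensor([0.6, 0.3, 0.1])
print(top_p_filter(probs, p=0.2))   # only the 0.6 token survives -> deterministic sampling
print(top_p_filter(probs, p=0.75))  # top-2 tokens survive -> outputs can differ
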
Example #4
    def generate(self, src_sents):
        self._model.split_to_gpus(n_gpus=1)
        self._model.eval()

        src_text = src_sents[0]
        generator = SequenceGenerator(tgt_dict=self._model.dictionary,
                                      max_len_b=200,
                                      min_len=50,
                                      beam_size=10,
                                      len_penalty=2.,
                                      no_repeat_ngram_size=3)

        src_tokens = self._model.encode(src_text)
        if src_tokens.shape[0] > SRC_MAX_LEN:
            src_tokens = src_tokens[-SRC_MAX_LEN:]

        outputs = generator.generate(
            models=[self._model.model],
            sample={
                'net_input': {
                    'src_tokens': src_tokens.unsqueeze(0).to('cuda'),
                    'src_lengths': torch.tensor([len(src_tokens)]).to('cuda')
                }
            },
            bos_token=self._model.dictionary.bos())

        return [self._model.decode(outputs[0][0]['tokens'].cpu())]
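
The `no_repeat_ngram_size=3` argument above blocks any token that would recreate a trigram already present in the hypothesis. A minimal standalone sketch of that blocking rule (generic, not fairseq's implementation):

def banned_next_tokens(prev_tokens, ngram_size=3):
    # if the last (ngram_size - 1) tokens occurred before, ban whatever followed them
    if len(prev_tokens) < ngram_size:
        return set()
    prefix = tuple(prev_tokens[-(ngram_size - 1):])
    banned = set()
    for i in range(len(prev_tokens) - ngram_size + 1):
        if tuple(prev_tokens[i:i + ngram_size - 1]) == prefix:
            banned.add(prev_tokens[i + ngram_size - 1])
    return banned

# "the cat sat" already occurred, so after "... the cat" the token "sat" is banned
print(banned_next_tokens(["the", "cat", "sat", "on", "the", "cat"]))  # {'sat'}
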
Example #5
def main():
    parser = options.get_parser('Generation')
    parser.add_argument('--path', metavar='FILE', required=True, action='append',
                        help='path(s) to model file(s)')
    options.add_dataset_args(parser)
    options.add_generation_args(parser)

    args = parser.parse_args()
    print(args)

    use_cuda = torch.cuda.is_available() and not args.cpu

    # Load ensemble
    print('| loading model(s) from {}'.format(', '.join(args.path)))
    models, model_args = utils.load_ensemble_for_inference(args.path, data_dir=args.data)
    src_dict, dst_dict = models[0].src_dict, models[0].dst_dict

    print('| [{}] dictionary: {} types'.format(model_args.source_lang, len(src_dict)))
    print('| [{}] dictionary: {} types'.format(model_args.target_lang, len(dst_dict)))

    # Optimize ensemble for generation
    for model in models:
        model.make_generation_fast_(
            beamable_mm_beam_size=None if args.no_beamable_mm else args.beam)

    # Initialize generator
    translator = SequenceGenerator(
        models, beam_size=args.beam, stop_early=(not args.no_early_stop),
        normalize_scores=(not args.unnormalized), len_penalty=args.lenpen,
        unk_penalty=args.unkpen)
    if use_cuda:
        translator.cuda()

    # Load alignment dictionary for unknown word replacement
    # (None if no unknown word replacement, empty if no path to align dictionary)
    align_dict = utils.load_align_dict(args.replace_unk)

    print('| Type the input sentence and press return:')
    for src_str in sys.stdin:
        src_str = src_str.strip()
        src_tokens = tokenizer.Tokenizer.tokenize(src_str, src_dict, add_if_not_exist=False).long()
        if use_cuda:
            src_tokens = src_tokens.cuda()
        translations = translator.generate(Variable(src_tokens.view(1, -1)))
        hypos = translations[0]
        print('O\t{}'.format(src_str))

        # Process top predictions
        for hypo in hypos[:min(len(hypos), args.nbest)]:
            hypo_tokens, hypo_str, alignment = utils.post_process_prediction(
                hypo_tokens=hypo['tokens'].int().cpu(),
                src_str=src_str,
                alignment=hypo['alignment'].int().cpu(),
                align_dict=align_dict,
                dst_dict=dst_dict,
                remove_bpe=args.remove_bpe)
            print('H\t{}\t{}'.format(hypo['score'], hypo_str))
            print('A\t{}'.format(' '.join(map(str, alignment))))
Example #6
def main(args):
    print(args)

    use_cuda = torch.cuda.is_available() and not args.cpu

    # Load ensemble
    print('| loading model(s) from {}'.format(', '.join(args.path)))
    models, model_args = utils.load_ensemble_for_inference(args.path, data_dir=args.data)
    src_dict, dst_dict = models[0].src_dict, models[0].dst_dict

    print('| [{}] dictionary: {} types'.format(model_args.source_lang, len(src_dict)))
    print('| [{}] dictionary: {} types'.format(model_args.target_lang, len(dst_dict)))

    # Optimize ensemble for generation
    for model in models:
        model.make_generation_fast_(
            beamable_mm_beam_size=None if args.no_beamable_mm else args.beam,
        )

    # Initialize generator
    translator = SequenceGenerator(
        models, beam_size=args.beam, stop_early=(not args.no_early_stop),
        normalize_scores=(not args.unnormalized), len_penalty=args.lenpen,
        unk_penalty=args.unkpen)
    if use_cuda:
        translator.cuda()

    # Load alignment dictionary for unknown word replacement
    # (None if no unknown word replacement, empty if no path to align dictionary)
    align_dict = utils.load_align_dict(args.replace_unk)

    print('| Type the input sentence and press return:')
    for src_str in sys.stdin:
        src_str = src_str.strip()
        src_tokens = tokenizer.Tokenizer.tokenize(src_str, src_dict, add_if_not_exist=False).long()
        if use_cuda:
            src_tokens = src_tokens.cuda()
        src_lengths = src_tokens.new([src_tokens.numel()])
        translations = translator.generate(
            Variable(src_tokens.view(1, -1)),
            Variable(src_lengths.view(-1)),
        )
        hypos = translations[0]
        print('O\t{}'.format(src_str))

        # Process top predictions
        for hypo in hypos[:min(len(hypos), args.nbest)]:
            hypo_tokens, hypo_str, alignment = utils.post_process_prediction(
                hypo_tokens=hypo['tokens'].int().cpu(),
                src_str=src_str,
                alignment=hypo['alignment'].int().cpu(),
                align_dict=align_dict,
                dst_dict=dst_dict,
                remove_bpe=args.remove_bpe,
            )
            print('H\t{}\t{}'.format(hypo['score'], hypo_str))
            print('A\t{}'.format(' '.join(map(str, alignment))))
Example #7
 def test_encoder_with_different_output_len(self):
     generator = SequenceGenerator(self.tgt_dict, beam_size=2, max_len_b=2)
     args = self.model.encoder.args
     task = test_utils.TestTranslationTask.setup_task(
         args, self.tgt_dict, self.tgt_dict)
     reshaping_model = test_utils.TestReshapingModel.build_model(args, task)
     hypos = generator.generate([reshaping_model], self.sample)
     for sent in [0, 1]:
         for beam in [0, 1]:
             assert hypos[sent][beam]['attention'] is not None
Example #8
    def test_topp_sampling_search_high_prob(self):
        # Given a sampling_topp high enough, any of the top-2 tokens
        # could be sampled, so outputs can differ.
        high_sampling_topp = (self.min_top1_prob + self.min_top2_prob) / 2.0
        generator = SequenceGenerator(
            self.tgt_dict, beam_size=2, sampling=True, sampling_topp=high_sampling_topp
        )
        sample = {
            "net_input": {
                "src_tokens": self.src_tokens,
                "src_lengths": self.src_lengths,
            }
        }
        hypos = generator.generate([self.model], sample)
        eos, w1, w2 = self.eos, self.w1, self.w2
        # sentence 1, beam 1
        self.assertTrue(
            self.hypoTokens(hypos[0][0], [w1, w1, eos])
            or self.hypoTokens(hypos[0][0], [w1, w2, eos])
        )
        self.assertTrue(
            self.hypoScore(hypos[0][0], [1.0, 0.4, 1.0])
            or self.hypoScore(hypos[0][0], [1.0, 0.35, 1.0])
        )

        # sentence 1, beam 2
        self.assertTrue(
            self.hypoTokens(hypos[0][1], [w1, w1, eos])
            or self.hypoTokens(hypos[0][1], [w1, w2, eos])
        )
        self.assertTrue(
            self.hypoScore(hypos[0][1], [1.0, 0.4, 1.0])
            or self.hypoScore(hypos[0][1], [1.0, 0.35, 1.0])
        )

        # sentence 2, beam 1
        self.assertTrue(
            self.hypoTokens(hypos[1][0], [w1, w1, eos])
            or self.hypoTokens(hypos[1][0], [w1, w2, eos])
        )
        self.assertTrue(
            self.hypoScore(hypos[1][0], [1.0, 0.4, 1.0])
            or self.hypoScore(hypos[1][0], [1.0, 0.35, 1.0])
        )

        # sentence 2, beam 2
        self.assertTrue(
            self.hypoTokens(hypos[1][1], [w1, w1, eos])
            or self.hypoTokens(hypos[1][1], [w1, w2, eos])
        )
        self.assertTrue(
            self.hypoScore(hypos[1][1], [1.0, 0.4, 1.0])
            or self.hypoScore(hypos[1][1], [1.0, 0.35, 1.0])
        )
Example #9
 def test_no_stop_early(self):
     generator = SequenceGenerator([self.model], self.tgt_dict, stop_early=False)
     hypos = generator.generate(self.src_tokens, self.src_lengths, beam_size=2)
     eos, w1, w2 = self.eos, self.w1, self.w2
     # sentence 1, beam 1
     self.assertHypoTokens(hypos[0][0], [w1, eos])
     self.assertHypoScore(hypos[0][0], [0.9, 1.0])
     # sentence 1, beam 2
     self.assertHypoTokens(hypos[0][1], [w2, w1, w2, eos])
     self.assertHypoScore(hypos[0][1], [0.1, 0.9, 0.9, 1.0])
     # sentence 2, beam 1
     self.assertHypoTokens(hypos[1][0], [w2, w2, w2, w2, eos])
     self.assertHypoScore(hypos[1][0], [0.3, 0.9, 0.99, 0.4, 1.0])
     # sentence 2, beam 2
     self.assertHypoTokens(hypos[1][1], [w1, w2, w1, eos])
     self.assertHypoScore(hypos[1][1], [0.7, 0.4, 0.4, 1.0])
Example #10
 def test_maxlen(self):
     generator = SequenceGenerator([self.model], self.tgt_dict, maxlen=2)
     hypos = generator.generate(self.encoder_input, beam_size=2)
     eos, w1, w2 = self.tgt_dict.eos(), self.w1, self.w2
     # sentence 1, beam 1
     self.assertHypoTokens(hypos[0][0], [w1, eos])
     self.assertHypoScore(hypos[0][0], [0.9, 1.0])
     # sentence 1, beam 2
     self.assertHypoTokens(hypos[0][1], [w2, w2, eos])
     self.assertHypoScore(hypos[0][1], [0.1, 0.1, 0.6])
     # sentence 2, beam 1
     self.assertHypoTokens(hypos[1][0], [w1, w2, eos])
     self.assertHypoScore(hypos[1][0], [0.7, 0.4, 0.6])
     # sentence 2, beam 2
     self.assertHypoTokens(hypos[1][1], [w2, w2, eos])
     self.assertHypoScore(hypos[1][1], [0.3, 0.9, 0.01])
Example #11
 def test_with_normalization(self):
     generator = SequenceGenerator([self.model], self.tgt_dict)
     hypos = generator.generate(self.src_tokens, self.src_lengths, beam_size=2)
     eos, w1, w2 = self.eos, self.w1, self.w2
     # sentence 1, beam 1
     self.assertHypoTokens(hypos[0][0], [w1, eos])
     self.assertHypoScore(hypos[0][0], [0.9, 1.0])
     # sentence 1, beam 2
     self.assertHypoTokens(hypos[0][1], [w2, w1, w2, eos])
     self.assertHypoScore(hypos[0][1], [0.1, 0.9, 0.9, 1.0])
     # sentence 2, beam 1
     self.assertHypoTokens(hypos[1][0], [w1, w2, w1, eos])
     self.assertHypoScore(hypos[1][0], [0.7, 0.4, 0.4, 1.0])
     # sentence 2, beam 2
     self.assertHypoTokens(hypos[1][1], [w1, w2, eos])
     self.assertHypoScore(hypos[1][1], [0.7, 0.4, 0.6])
Example #12
 def test_with_normalization(self):
     generator = SequenceGenerator(self.tgt_dict, beam_size=2)
     hypos = generator.generate([self.model], self.sample)
     eos, w1, w2 = self.tgt_dict.eos(), self.w1, self.w2
     # sentence 1, beam 1
     self.assertHypoTokens(hypos[0][0], [w1, eos])
     self.assertHypoScore(hypos[0][0], [0.9, 1.0])
     # sentence 1, beam 2
     self.assertHypoTokens(hypos[0][1], [w2, w1, w2, eos])
     self.assertHypoScore(hypos[0][1], [0.1, 0.9, 0.9, 1.0])
     # sentence 2, beam 1
     self.assertHypoTokens(hypos[1][0], [w1, w2, w1, eos])
     self.assertHypoScore(hypos[1][0], [0.7, 0.4, 0.4, 1.0])
     # sentence 2, beam 2
     self.assertHypoTokens(hypos[1][1], [w1, w2, eos])
     self.assertHypoScore(hypos[1][1], [0.7, 0.4, 0.6])
Example #13
 def test_with_normalization(self):
     generator = SequenceGenerator([self.model])
     hypos = generator.generate(self.src_tokens, self.src_lengths, beam_size=2)
     eos, w1, w2 = self.eos, self.w1, self.w2
     # sentence 1, beam 1
     self.assertHypoTokens(hypos[0][0], [w1, eos])
     self.assertHypoScore(hypos[0][0], [0.9, 1.0])
     # sentence 1, beam 2
     self.assertHypoTokens(hypos[0][1], [w2, w1, w2, eos])
     self.assertHypoScore(hypos[0][1], [0.1, 0.9, 0.9, 1.0])
     # sentence 2, beam 1
     self.assertHypoTokens(hypos[1][0], [w1, w2, w1, eos])
     self.assertHypoScore(hypos[1][0], [0.7, 0.4, 0.4, 1.0])
     # sentence 2, beam 2
     self.assertHypoTokens(hypos[1][1], [w1, w2, eos])
     self.assertHypoScore(hypos[1][1], [0.7, 0.4, 0.6])
Example #14
 def test_no_stop_early(self):
     generator = SequenceGenerator([self.model], stop_early=False)
     hypos = generator.generate(self.src_tokens, self.src_lengths, beam_size=2)
     eos, w1, w2 = self.eos, self.w1, self.w2
     # sentence 1, beam 1
     self.assertHypoTokens(hypos[0][0], [w1, eos])
     self.assertHypoScore(hypos[0][0], [0.9, 1.0])
     # sentence 1, beam 2
     self.assertHypoTokens(hypos[0][1], [w2, w1, w2, eos])
     self.assertHypoScore(hypos[0][1], [0.1, 0.9, 0.9, 1.0])
     # sentence 2, beam 1
     self.assertHypoTokens(hypos[1][0], [w2, w2, w2, w2, eos])
     self.assertHypoScore(hypos[1][0], [0.3, 0.9, 0.99, 0.4, 1.0])
     # sentence 2, beam 2
     self.assertHypoTokens(hypos[1][1], [w1, w2, w1, eos])
     self.assertHypoScore(hypos[1][1], [0.7, 0.4, 0.4, 1.0])
Example #15
 def test_with_lenpen_favoring_long_hypos(self):
     lenpen = 5.0
     generator = SequenceGenerator([self.model], len_penalty=lenpen)
     hypos = generator.generate(self.src_tokens, self.src_lengths, beam_size=2)
     eos, w1, w2 = self.eos, self.w1, self.w2
     # sentence 1, beam 1
     self.assertHypoTokens(hypos[0][0], [w2, w1, w2, eos])
     self.assertHypoScore(hypos[0][0], [0.1, 0.9, 0.9, 1.0], lenpen=lenpen)
     # sentence 1, beam 2
     self.assertHypoTokens(hypos[0][1], [w1, eos])
     self.assertHypoScore(hypos[0][1], [0.9, 1.0], lenpen=lenpen)
     # sentence 2, beam 1
     self.assertHypoTokens(hypos[1][0], [w1, w2, w1, eos])
     self.assertHypoScore(hypos[1][0], [0.7, 0.4, 0.4, 1.0], lenpen=lenpen)
     # sentence 2, beam 2
     self.assertHypoTokens(hypos[1][1], [w1, w2, eos])
     self.assertHypoScore(hypos[1][1], [0.7, 0.4, 0.6], lenpen=lenpen)
Example #16
 def test_with_lenpen_favoring_short_hypos(self):
     lenpen = 0.6
     generator = SequenceGenerator(self.tgt_dict, beam_size=2, len_penalty=lenpen)
     hypos = generator.generate([self.model], self.sample)
     eos, w1, w2 = self.tgt_dict.eos(), self.w1, self.w2
     # sentence 1, beam 1
     self.assertHypoTokens(hypos[0][0], [w1, eos])
     self.assertHypoScore(hypos[0][0], [0.9, 1.0], lenpen=lenpen)
     # sentence 1, beam 2
     self.assertHypoTokens(hypos[0][1], [w2, w1, w2, eos])
     self.assertHypoScore(hypos[0][1], [0.1, 0.9, 0.9, 1.0], lenpen=lenpen)
     # sentence 2, beam 1
     self.assertHypoTokens(hypos[1][0], [w1, w2, eos])
     self.assertHypoScore(hypos[1][0], [0.7, 0.4, 0.6], lenpen=lenpen)
     # sentence 2, beam 2
     self.assertHypoTokens(hypos[1][1], [w1, w2, w1, eos])
     self.assertHypoScore(hypos[1][1], [0.7, 0.4, 0.4, 1.0], lenpen=lenpen)
Example #17
 def test_with_lenpen_favoring_long_hypos(self):
     lenpen = 5.0
     generator = SequenceGenerator([self.model], self.tgt_dict, len_penalty=lenpen)
     hypos = generator.generate(self.src_tokens, self.src_lengths, beam_size=2)
     eos, w1, w2 = self.eos, self.w1, self.w2
     # sentence 1, beam 1
     self.assertHypoTokens(hypos[0][0], [w2, w1, w2, eos])
     self.assertHypoScore(hypos[0][0], [0.1, 0.9, 0.9, 1.0], lenpen=lenpen)
     # sentence 1, beam 2
     self.assertHypoTokens(hypos[0][1], [w1, eos])
     self.assertHypoScore(hypos[0][1], [0.9, 1.0], lenpen=lenpen)
     # sentence 2, beam 1
     self.assertHypoTokens(hypos[1][0], [w1, w2, w1, eos])
     self.assertHypoScore(hypos[1][0], [0.7, 0.4, 0.4, 1.0], lenpen=lenpen)
     # sentence 2, beam 2
     self.assertHypoTokens(hypos[1][1], [w1, w2, eos])
     self.assertHypoScore(hypos[1][1], [0.7, 0.4, 0.6], lenpen=lenpen)
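
Why the beams reorder between the two length-penalty tests: with normalized scores sum(log p) / len ** lenpen, a large exponent shrinks the magnitude of long hypotheses' (negative) scores much faster than short ones', so long hypotheses win; an exponent below 1 does the opposite. A quick check with the probabilities from the tests above, assuming that normalization formula:

import math

def normalized_score(pos_probs, lenpen):
    return sum(math.log(p) for p in pos_probs) / len(pos_probs) ** lenpen

short_hypo = [0.9, 1.0]            # [w1, eos]
long_hypo = [0.1, 0.9, 0.9, 1.0]   # [w2, w1, w2, eos]

for lenpen in (0.6, 5.0):
    print(lenpen, normalized_score(short_hypo, lenpen), normalized_score(long_hypo, lenpen))
# lenpen=0.6: short ≈ -0.070 beats long ≈ -1.094   -> the short hypo is beam 1
# lenpen=5.0: long ≈ -0.0025 beats short ≈ -0.0033 -> the long hypo is beam 1
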
Example #18
 def _generate(self, opt, src_tokens):
     translator = SequenceGenerator([self.trainer.get_model()],
                                    self.fairseq_dict,
                                    beam_size=opt.beam,
                                    stop_early=(not opt.no_early_stop),
                                    normalize_scores=(not opt.unnormalized),
                                    len_penalty=opt.lenpen)
     translator.cuda()
     tokens = src_tokens
     translations = translator.generate(
         Variable(tokens), Variable(self._positions_for_tokens(tokens)))
     results = [t[0] for t in translations]
     output_lines = [[] for _ in range(len(results))]
     for i in range(len(results)):
         output_lines[i] = ' '.join(self.fairseq_dict[idx]
                                    for idx in results[i]['tokens'][:-1])
     return output_lines
Example #19
 def test_without_normalization(self):
     # Sentence 1: unchanged from the normalized case
     # Sentence 2: beams swap order
     generator = SequenceGenerator([self.model], self.tgt_dict, normalize_scores=False)
     hypos = generator.generate(self.encoder_input, beam_size=2)
     eos, w1, w2 = self.tgt_dict.eos(), self.w1, self.w2
     # sentence 1, beam 1
     self.assertHypoTokens(hypos[0][0], [w1, eos])
     self.assertHypoScore(hypos[0][0], [0.9, 1.0], normalized=False)
     # sentence 1, beam 2
     self.assertHypoTokens(hypos[0][1], [w2, w1, w2, eos])
     self.assertHypoScore(hypos[0][1], [0.1, 0.9, 0.9, 1.0], normalized=False)
     # sentence 2, beam 1
     self.assertHypoTokens(hypos[1][0], [w1, w2, eos])
     self.assertHypoScore(hypos[1][0], [0.7, 0.4, 0.6], normalized=False)
     # sentence 2, beam 2
     self.assertHypoTokens(hypos[1][1], [w1, w2, w1, eos])
     self.assertHypoScore(hypos[1][1], [0.7, 0.4, 0.4, 1.0], normalized=False)
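
The swap noted in the comment follows directly from dropping the length normalization: for sentence 2, the raw sum of log-probabilities favors the shorter hypothesis, while the length-normalized score favors the longer one. A quick check with the probabilities above, assuming score = sum(log p), optionally divided by length:

import math

def raw(probs):
    return sum(math.log(p) for p in probs)

hypo_long = [0.7, 0.4, 0.4, 1.0]   # [w1, w2, w1, eos]
hypo_short = [0.7, 0.4, 0.6]       # [w1, w2, eos]

print(raw(hypo_long), raw(hypo_short))          # -2.189 < -1.784: the short hypo wins unnormalized
print(raw(hypo_long) / 4, raw(hypo_short) / 3)  # -0.547 > -0.595: the long hypo wins normalized
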
Example #20
 def _generate(self, opt, src_tokens):
     translator = SequenceGenerator(
         [self.trainer.get_model()],
         self.fairseq_dict,
         beam_size=opt.beam,
         stop_early=(not opt.no_early_stop),
         normalize_scores=(not opt.unnormalized),
         len_penalty=opt.lenpen)
     translator.cuda()
     tokens = src_tokens
     translations = translator.generate(
         Variable(tokens), Variable(self._positions_for_tokens(tokens)))
     results = [t[0] for t in translations]
     output_lines = [[] for _ in range(len(results))]
     for i in range(len(results)):
         output_lines[i] = ' '.join(self.fairseq_dict[idx]
                                    for idx in results[i]['tokens'][:-1])
     return output_lines
Example #21
 def test_without_normalization(self):
     # Sentence 1: unchanged from the normalized case
     # Sentence 2: beams swap order
     generator = SequenceGenerator([self.model], self.tgt_dict, normalize_scores=False)
     hypos = generator.generate(self.src_tokens, self.src_lengths, beam_size=2)
     eos, w1, w2 = self.eos, self.w1, self.w2
     # sentence 1, beam 1
     self.assertHypoTokens(hypos[0][0], [w1, eos])
     self.assertHypoScore(hypos[0][0], [0.9, 1.0], normalized=False)
     # sentence 1, beam 2
     self.assertHypoTokens(hypos[0][1], [w2, w1, w2, eos])
     self.assertHypoScore(hypos[0][1], [0.1, 0.9, 0.9, 1.0], normalized=False)
     # sentence 2, beam 1
     self.assertHypoTokens(hypos[1][0], [w1, w2, eos])
     self.assertHypoScore(hypos[1][0], [0.7, 0.4, 0.6], normalized=False)
     # sentence 2, beam 2
     self.assertHypoTokens(hypos[1][1], [w1, w2, w1, eos])
     self.assertHypoScore(hypos[1][1], [0.7, 0.4, 0.4, 1.0], normalized=False)
Example #22
 def test_no_stop_early(self):
     generator = SequenceGenerator(self.tgt_dict,
                                   stop_early=False,
                                   beam_size=2)
     hypos = generator.generate([self.model], self.sample)
     eos, w1, w2 = self.tgt_dict.eos(), self.w1, self.w2
     # sentence 1, beam 1
     self.assertHypoTokens(hypos[0][0], [w1, eos])
     self.assertHypoScore(hypos[0][0], [0.9, 1.0])
     # sentence 1, beam 2
     self.assertHypoTokens(hypos[0][1], [w2, w1, w2, eos])
     self.assertHypoScore(hypos[0][1], [0.1, 0.9, 0.9, 1.0])
     # sentence 2, beam 1
     self.assertHypoTokens(hypos[1][0], [w2, w2, w2, w2, eos])
     self.assertHypoScore(hypos[1][0], [0.3, 0.9, 0.99, 0.4, 1.0])
     # sentence 2, beam 2
     self.assertHypoTokens(hypos[1][1], [w1, w2, w1, eos])
     self.assertHypoScore(hypos[1][1], [0.7, 0.4, 0.4, 1.0])
Example #23
 def test_diverse_beam_search(self):
     generator = SequenceGenerator(
         [self.model], self.tgt_dict,
         beam_size=2, diverse_beam_groups=2, diverse_beam_strength=0.,
     )
     hypos = generator.generate(self.src_tokens, self.src_lengths)
     eos, w1, w2 = self.eos, self.w1, self.w2
     # sentence 1, beam 1
     self.assertHypoTokens(hypos[0][0], [w1, w1, eos])
     self.assertHypoScore(hypos[0][0], [0.9, 0.6, 1.0])
     # sentence 1, beam 2
     self.assertHypoTokens(hypos[0][1], [w1, w1, eos])
     self.assertHypoScore(hypos[0][1], [0.9, 0.6, 1.0])
     # sentence 2, beam 1
     self.assertHypoTokens(hypos[1][0], [w1, w2, eos])
     self.assertHypoScore(hypos[1][0], [0.7, 0.4, 0.9])
     # sentence 2, beam 2
     self.assertHypoTokens(hypos[1][1], [w1, w2, eos])
     self.assertHypoScore(hypos[1][1], [0.7, 0.4, 0.9])
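
`diverse_beam_groups=2` splits the beam into groups that are expanded one after another, and `diverse_beam_strength` penalizes tokens already chosen by earlier groups at the same step (Vijayakumar et al., 2016). With strength 0 the penalty vanishes, each group reduces to an independent size-1 beam search, and both returned beams are identical, as asserted above. A minimal sketch of the group penalty, not fairseq's `DiverseBeamSearch` implementation:

import torch

def apply_group_penalty(lprobs, prev_group_tokens, strength):
    # lprobs: (vocab,) log-probabilities for the current group at this step;
    # each token already picked by an earlier group is penalized by `strength`
    penalty = torch.zeros_like(lprobs)
    for tok in prev_group_tokens:
        penalty[tok] += strength
    return lprobs - penalty

lprobs = torch.tensor([0.9, 0.6, 0.1]).log()
print(apply_group_penalty(lprobs, prev_group_tokens=[0], strength=0.0))  # unchanged
print(apply_group_penalty(lprobs, prev_group_tokens=[0], strength=0.5))  # token 0 discouraged
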
Example #24
    def generate(self, cond, top_k, top_p):
        self._model.split_to_gpus(1)
        self._model.eval()

        generator = SequenceGenerator(tgt_dict=self._model.dictionary,
                                      max_len_b=BART_MAX_LEN,
                                      sampling=True,
                                      sampling_topk=top_k,
                                      sampling_topp=top_p)

        src_tokens = self._model.encode(cond)[:BART_MAX_LEN]

        outputs = generator.generate(
            models=[self._model.model],
            sample={
                'net_input': {
                    'src_tokens': src_tokens.unsqueeze(0).to('cuda'),
                    'src_lengths': torch.tensor([len(src_tokens)]).to('cuda')
                }
            })

        return self._model.decode(outputs[0][0]['tokens'].cpu())
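
`sampling_topk` restricts sampling to the k most likely tokens in the same way `sampling_topp` restricts it to a probability mass. A minimal, generic sketch of the top-k filter, written for a 1-D distribution and not taken from fairseq's implementation:

import torch

def top_k_filter(probs, k):
    # keep the k most likely tokens and renormalize
    values, indices = probs.topk(k)
    filtered = torch.zeros_like(probs)
    filtered[indices] = values
    return filtered / filtered.sum()

print(top_k_filter(torch.tensor([0.5, 0.3, 0.15, 0.05]), k=2))  # -> [0.625, 0.375, 0.0, 0.0]
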
Example #25
def debug_generate(model, loader, vocab, visdom):
    from fairseq.sequence_generator import SequenceGenerator
    closure = lambda s: visdom.log("gen-output", "text-replace", s)
    seq_gen = SequenceGenerator([model], vocab, beam_size=5)
    #pbar = tqdm_progress_bar(loader, epoch=epoch)
    model.eval()  # generate without dropout; model.train() below restores training mode
    all_lines = []  # accumulate lines across all batches (reset once, not per batch)
    for src, src_lens, _, tgt, tgt_lens, _ in loader:
        src = src.to(device)  # `device` is assumed to be defined at module level
        encoder_input = {"src_tokens": src, "src_lengths": src_lens}
        samples = seq_gen.generate(encoder_input, maxlen=20)
        for i, sample in enumerate(samples):
            src_str = vocab.string(src[i, :])
            tgt_str = vocab.string(tgt[i, :])
            pred_str = vocab.string(sample[0]['tokens'])
            lines = [
                "> {}".format(src_str), "< {}".format(pred_str),
                "= {}".format(tgt_str), ""
            ]
            all_lines.extend(lines)
    model.train()

    txt_dump = '<br>'.join(all_lines[:100])
    #print('\n'.join(all_lines))
    closure(txt_dump)
Example #26
    def _backtranslation_dataset_helper(
        self,
        remove_eos_from_input_src,
        remove_eos_from_output_src,
    ):
        tgt_dataset = LanguagePairDataset(
            src=self.tgt_dataset,
            src_sizes=self.tgt_dataset.sizes,
            src_dict=self.tgt_dict,
            tgt=None,
            tgt_sizes=None,
            tgt_dict=None,
        )

        generator = SequenceGenerator(
            [self.model],
            tgt_dict=self.tgt_dict,
            max_len_a=0,
            max_len_b=200,
            beam_size=2,
            unk_penalty=0,
        )

        backtranslation_dataset = BacktranslationDataset(
            tgt_dataset=TransformEosDataset(
                dataset=tgt_dataset,
                eos=self.tgt_dict.eos(),
                # remove eos from the input src
                remove_eos_from_src=remove_eos_from_input_src,
            ),
            src_dict=self.tgt_dict,
            backtranslation_fn=(
                lambda sample: generator.generate([self.model], sample)),
            output_collater=TransformEosDataset(
                dataset=tgt_dataset,
                eos=self.tgt_dict.eos(),
                # if we remove eos from the input src, then we need to add it
                # back to the output tgt
                append_eos_to_tgt=remove_eos_from_input_src,
                remove_eos_from_src=remove_eos_from_output_src,
            ).collater,
            cuda=self.cuda,
        )
        dataloader = torch.utils.data.DataLoader(
            backtranslation_dataset,
            batch_size=2,
            collate_fn=backtranslation_dataset.collater,
        )
        backtranslation_batch_result = next(iter(dataloader))

        eos, pad, w1, w2 = self.tgt_dict.eos(), self.tgt_dict.pad(), self.w1, self.w2

        # Note that we sort by src_lengths and add left padding, so actually
        # ids will look like: [1, 0]
        expected_src = torch.LongTensor([[w1, w2, w1, eos],
                                         [pad, pad, w1, eos]])
        if remove_eos_from_output_src:
            expected_src = expected_src[:, :-1]
        expected_tgt = torch.LongTensor([[w1, w2, eos], [w1, w2, eos]])
        generated_src = backtranslation_batch_result["net_input"]["src_tokens"]
        tgt_tokens = backtranslation_batch_result["target"]

        self.assertTensorEqual(expected_src, generated_src)
        self.assertTensorEqual(expected_tgt, tgt_tokens)
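
What the test above exercises, in a nutshell: the wrapped generator turns each target-side sentence into a synthetic source, and the resulting (synthetic source, original target) pairs are collated for training. A heavily simplified sketch of that flow; `backtranslate` is a hypothetical helper, and the batch layout follows the test's convention of loading the target data as the `src` side:

def backtranslate(generator, model, tgt_batch):
    # beam-search the target-side batch, keep the best hypothesis per sentence
    hypos = generator.generate([model], tgt_batch)
    synthetic_sources = [beams[0]["tokens"] for beams in hypos]
    original_targets = tgt_batch["net_input"]["src_tokens"]
    # pair each synthetic source with its original target sentence
    return list(zip(synthetic_sources, original_targets))
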
Example #27
class FairseqAgent(Agent):
    """Agent which takes an input sequence and produces an output sequence.

    For more information, see Convolutional Sequence to Sequence Learning
     `(Gehring et al. 2017) <https://arxiv.org/abs/1705.03122>`_.
    """

    @staticmethod
    def add_cmdline_args(argparser):
        """Add command-line arguments specifically for this agent."""
        DictionaryAgent.add_cmdline_args(argparser)
        agent = argparser.add_argument_group('Fairseq Arguments')
        agent.add_argument(
            '-tr', '--truncate',
            type=int, default=-1,
            help='truncate input & output lengths to speed up training (may '
                 'reduce accuracy). This fixes all input and output to have a '
                 'maximum length. This reduces the total amount of padding in '
                 'the batches.')
        agent.add_argument(
            '--max-positions',
            default=1024,
            type=int,
            metavar='N',
            help='max number of tokens in the sequence')
        agent.add_argument(
            '--seed',
            default=1,
            type=int,
            metavar='N',
            help='pseudo random number generator seed')
        options.add_optimization_args(argparser)
        options.add_generation_args(argparser)
        options.add_model_args(argparser)

    def __init__(self, opt, shared=None):
        # initialize defaults first
        super().__init__(opt, shared)
        if not shared:
            # this is not a shared instance of this class, so do full
            # initialization. if shared is set, only set up shared members.
            saved_state = None
            if opt.get('model_file') and os.path.isfile(opt['model_file']):
                # load model parameters if available
                print('Loading existing model params from ' +
                      opt['model_file'])
                new_opt, saved_state = self.load(opt['model_file'])
                # override options with stored ones
                opt = self._override_opt(new_opt)

            self.args = OptWrapper(opt)
            self.fairseq_dict = _make_fairseq_dict(DictionaryAgent(opt))
            self.id = 'Fairseq'
            self.truncate = opt['truncate'] if opt['truncate'] > 0 else None

            self.EOS = self.fairseq_dict[self.fairseq_dict.eos()]
            self.EOS_TENSOR = (torch.LongTensor(1, 1)
                               .fill_(self.fairseq_dict.eos()))
            self.NULL_IDX = self.fairseq_dict.pad()

            encoder = fconv.Encoder(
                self.fairseq_dict,
                embed_dim=self.args.encoder_embed_dim,
                convolutions=eval(self.args.encoder_layers),
                dropout=self.args.dropout,
                max_positions=self.args.max_positions)
            decoder = fconv.Decoder(
                self.fairseq_dict,
                embed_dim=self.args.decoder_embed_dim,
                convolutions=eval(self.args.decoder_layers),
                out_embed_dim=self.args.decoder_out_embed_dim,
                attention=eval(self.args.decoder_attention),
                dropout=self.args.dropout,
                max_positions=self.args.max_positions)
            self.model = fconv.FConvModel(encoder, decoder)

            # from fairseq's build_criterion()
            if self.args.label_smoothing > 0:
                self.criterion = criterions.LabelSmoothedCrossEntropyCriterion(
                    self.args.label_smoothing, self.NULL_IDX)
            else:
                self.criterion = criterions.CrossEntropyCriterion(
                    self.NULL_IDX)

            self.trainer = MultiprocessingTrainer(self.args, self.model, self.criterion)
            if saved_state is not None:
                self.set_states(saved_state)
        self.reset()

    def _override_opt(self, new_opt):
        """Set overridable opts from loaded opt file.

        Print out each added key and each overridden key.
        Only override args specific to the model.
        """
        model_args = {
            'arch',
            'encoder-embed-dim',
            'encoder-layers',
            'decoder-embed-dim',
            'decoder-layers',
            'decoder-out-embed-dim',
            'decoder-attention',
        }

        for k, v in new_opt.items():
            if k not in model_args:
                # skip non-model args
                continue
            if k not in self.opt:
                print('Adding new option [ {k}: {v} ]'.format(k=k, v=v))
            elif self.opt[k] != v:
                print('Overriding option [ {k}: {old} => {v}]'.format(
                    k=k, old=self.opt[k], v=v))
            self.opt[k] = v
        return self.opt

    def reset(self):
        """Reset observation and episode_done."""
        self.observation = None
        self.episode_done = True

    def observe(self, observation):
        # shallow copy observation (deep copy can be expensive)
        observation = observation.copy()
        if not self.episode_done:
            # if the last example wasn't the end of an episode, then we need to
            # recall what was said in that example
            prev_dialogue = self.observation['text']
            observation['text'] = prev_dialogue + '\n' + observation['text']
        self.observation = observation
        self.episode_done = observation['episode_done']
        return observation

    def act(self):
        # call batch_act with this batch of one
        return self.batch_act([self.observation])[0]

    def batch_act(self, observations):
        bsz = len(observations)
        # initialize a table of replies with this agent's id
        batch_reply = [{'id': self.getID()} for _ in range(bsz)]

        # convert the observations into batches of inputs and targets
        # valid_inds tells us the indices of all valid examples
        # e.g. for input [{}, {'text': 'hello'}, {}, {}], valid_inds is [1]
        # since the other three elements had no 'text' field

        # also, split observations into sub-batches based on number of gpus
        obs_split = np.array_split(observations, self.trainer.num_replicas)
        samples = [self.batchify(obs) for obs in obs_split]
        samples = [s for s in samples if s[0] is not None]
        any_valid = any(len(s[0]) > 0 for s in samples)

        if not any_valid:
            # no valid examples, just return the empty responses we set up
            return batch_reply

        # produce predictions if testing; otherwise, train
        has_targets = any(s[1] is not None for s in samples)
        if not has_targets:
            offset = 0
            for s in samples:
                xs = s[0]
                valid_inds = s[2]

                predictions = self._generate(self.args, xs)
                for i in range(len(predictions)):
                    # map the predictions back to non-empty examples in the batch
                    batch_reply[valid_inds[i] + offset]['text'] = predictions[i]
                    if i == 0:
                        print('prediction:', predictions[i])
                offset += len(valid_inds)
        else:
            loss = self._train(samples)

            batch_reply[0]['metrics'] = {}
            for k, v in loss.items():
                batch_reply[0]['metrics'][k] = v * bsz
                if k == 'loss':
                    batch_reply[0]['metrics']['perplexity'] = 2 ** v * bsz

        return batch_reply

    def parse(self, string):
        return [self.fairseq_dict.index(word) for word in string.split(' ')]

    def batchify(self, observations):
        """Convert a list of observations into input & target tensors."""
        # valid examples
        exs = [ex for ex in observations if 'text' in ex]
        # the indices of the valid (non-empty) tensors
        valid_inds = [i for i, ex in enumerate(observations) if 'text' in ex]

        # set up the input tensors
        batchsize = len(exs)
        if batchsize == 0:
            return None, None, None
        # tokenize the text
        parsed_x = [deque(maxlen=self.truncate) for _ in exs]
        for dq, ex in zip(parsed_x, exs):
            dq += self.parse(ex['text'])
        # parsed = [self.parse(ex['text']) for ex in exs]
        max_x_len = max((len(x) for x in parsed_x))
        for x in parsed_x:
            # left pad with zeros
            x.extendleft([self.fairseq_dict.pad()] * (max_x_len - len(x)))
        xs = torch.LongTensor(parsed_x)

        # set up the target tensors
        ys = None
        if 'labels' in exs[0]:
            # randomly select one of the labels to update on, if multiple
            labels = [random.choice(ex.get('labels', [''])) for ex in exs]
            parsed_y = [deque(maxlen=self.truncate) for _ in labels]
            for dq, y in zip(parsed_y, labels):
                dq.extendleft(reversed(self.parse(y)))
            for y in parsed_y:
                y.append(self.fairseq_dict.eos())
            # append EOS to each label
            max_y_len = max(len(y) for y in parsed_y)
            for y in parsed_y:
                y += [self.fairseq_dict.pad()] * (max_y_len - len(y))
            ys = torch.LongTensor(parsed_y)
        return xs, ys, valid_inds

    def _positions_for_tokens(self, tokens):
        size = tokens.size()
        not_pad = tokens.ne(self.fairseq_dict.pad()).long()
        new_pos = tokens.new(size).fill_(self.fairseq_dict.pad())
        new_pos += not_pad
        for i in range(1, size[1]):
            new_pos[:, i] += new_pos[:, i-1] - 1
        return new_pos

    def _right_shifted_ys(self, ys):
        result = torch.LongTensor(ys.size())
        result[:, 0] = self.fairseq_dict.index(self.EOS)
        result[:, 1:] = ys[:, :-1]
        return result

    def _generate(self, opt, src_tokens):
        if not hasattr(self, 'translator'):
            self.translator = SequenceGenerator(
                [self.trainer.get_model()],
                beam_size=opt.beam,
                stop_early=(not opt.no_early_stop),
                normalize_scores=(not opt.unnormalized),
                len_penalty=opt.lenpen)
            self.translator.cuda()
        tokens = src_tokens.cuda(non_blocking=True)  # 'async' is a reserved keyword in Python 3
        token_pos = Variable(self._positions_for_tokens(tokens).cuda())
        translations = self.translator.generate(Variable(tokens), token_pos)
        results = [t[0] for t in translations]
        output_lines = [[] for _ in range(len(results))]
        for i in range(len(results)):
            output_lines[i] = ' '.join(self.fairseq_dict[idx]
                                       for idx in results[i]['tokens'][:-1])
        return output_lines

    def _train(self, samples):
        """Update the model using the targets."""
        for i, sample in enumerate(samples):
            # add extra info to samples
            sample = {
                'src_tokens': sample[0],
                'input_tokens': self._right_shifted_ys(sample[1]),
                'target': sample[1],
                'id': None
            }
            sample['ntokens'] = sum(len(t) for t in sample['target'])
            sample['src_positions'] = self._positions_for_tokens(
                sample['src_tokens'])
            sample['input_positions'] = self._positions_for_tokens(
                sample['input_tokens'])
            samples[i] = sample
        return self.trainer.train_step(samples)

    def save(self, path=None):
        path = self.opt.get('model_file', None) if path is None else path
        if path and hasattr(self, 'trainer'):
            model = {}
            model['state_dict'] = self.trainer.get_model().state_dict()
            model['opt'] = self.opt
            with open(path, 'wb') as write:
                torch.save(model, write)

    def shutdown(self):
        """Save the state of the model when shutdown."""
        path = self.opt.get('model_file', None)
        if path is not None:
            self.save(path + '.shutdown_state')
        super().shutdown()

    def load(self, path):
        """Return opt and model states."""
        with open(path, 'rb') as read:
            model = torch.load(read)
        return model['opt'], model['state_dict']

    def set_states(self, state_dict):
        """Set the state dict of the model from saved states."""
        self.trainer.get_model().load_state_dict(state_dict)
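
The position bookkeeping in `_positions_for_tokens` is easy to misread: padding positions keep the pad index, and real tokens get consecutive positions starting at pad + 1, even with left padding. A standalone re-implementation of the same arithmetic with a small worked example (pad index 1 chosen arbitrarily):

import torch

def positions_for_tokens(tokens, pad):
    not_pad = tokens.ne(pad).long()
    new_pos = torch.full_like(tokens, pad)
    new_pos += not_pad
    for i in range(1, tokens.size(1)):
        new_pos[:, i] += new_pos[:, i - 1] - 1
    return new_pos

tokens = torch.LongTensor([[1, 1, 5, 6],    # left-padded sentence
                           [4, 5, 6, 7]])   # full-length sentence
print(positions_for_tokens(tokens, pad=1))
# tensor([[1, 1, 2, 3],
#         [2, 3, 4, 5]])
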
Example #28
class TranslationStructuredPredictionTask(translation.TranslationTask):
    """
    Translate from one (source) language to another (target) language.

    Compared to :class:`TranslationTask`, this version performs
    generation during training and computes sequence-level losses.

    Args:
        src_dict (Dictionary): dictionary for the source language
        tgt_dict (Dictionary): dictionary for the target language

    .. note::

        The translation task is compatible with :mod:`train.py <train>`,
        :mod:`generate.py <generate>` and :mod:`interactive.py <interactive>`.

    The translation task provides the following additional command-line
    arguments:

    .. argparse::
        :ref: fairseq.tasks.translation_parser
        :prog:
    """
    @staticmethod
    def add_args(parser):
        """Add task-specific arguments to the parser."""
        translation.TranslationTask.add_args(parser)
        parser.add_argument('--seq-beam',
                            default=5,
                            type=int,
                            metavar='N',
                            help='beam size for sequence training')
        parser.add_argument('--seq-keep-reference',
                            default=False,
                            action='store_true',
                            help='retain the reference in the list of hypos')
        parser.add_argument(
            '--seq-scorer',
            default='bleu',
            metavar='SCORER',
            choices=['bleu', 'simile', 'mixed', 'cl-simile'],
            help='optimization metric for sequence level training')

        parser.add_argument('--seq-gen-with-dropout',
                            default=False,
                            action='store_true',
                            help='use dropout to generate hypos')
        parser.add_argument(
            '--seq-max-len-a',
            default=0,
            type=float,
            metavar='N',
            help='generate sequences of maximum length ax + b, '
            'where x is the source length')
        parser.add_argument(
            '--seq-max-len-b',
            default=200,
            type=int,
            metavar='N',
            help='generate sequences of maximum length ax + b, '
            'where x is the source length')
        parser.add_argument('--seq-remove-bpe',
                            nargs='?',
                            const='@@ ',
                            default=None,
                            help='remove BPE tokens before scoring')
        parser.add_argument('--seq-sampling',
                            default=False,
                            action='store_true',
                            help='use sampling instead of beam search')
        parser.add_argument(
            '--seq-unkpen',
            default=0,
            type=float,
            help='unknown word penalty to be used in seq generation')
        parser.add_argument(
            '--simile-lenpen',
            default=0.25,
            type=float,
            help='length penalty exponent applied to SimiLe costs')
        parser.add_argument(
            '--mixed-ratio',
            default=0.5,
            type=float,
            help='mixing ratio for the mixed scorer')
        parser.add_argument(
            '--cl-ratio',
            default=0.0,
            type=float,
            help='cross-lingual ratio for the cl-simile scorer')
        parser.add_argument(
            '--cl-file',
            default="all",
            choices=["all", "wmt"],
            help='which file set to use for the cl-simile scorer')

    def __init__(self, args, src_dict, tgt_dict):
        super().__init__(args, src_dict, tgt_dict)
        self.args = args
        self._generator = None
        self._scorers = {}

    @classmethod
    def setup_task(cls, args, **kwargs):
        """Setup the task (e.g., load dictionaries).

        Args:
            args (argparse.Namespace): parsed command-line arguments
        """
        return super(TranslationStructuredPredictionTask,
                     cls).setup_task(args, **kwargs)

    def build_criterion(self, args):
        """
        Build the :class:`~fairseq.criterions.FairseqCriterion` instance for
        this task.

        Args:
            args (argparse.Namespace): parsed command-line arguments

        Returns:
            a :class:`~fairseq.criterions.FairseqCriterion` instance
        """
        from fairseq import criterions
        criterion = criterions.build_criterion(args, self)
        assert isinstance(criterion, criterions.FairseqSequenceCriterion)
        return criterion

    def train_step(self,
                   sample,
                   model,
                   criterion,
                   optimizer,
                   ignore_grad=False):
        """
        Do forward and backward, and return the loss as computed by *criterion*
        for the given *model* and *sample*.

        Args:
            sample (dict): the mini-batch. The format is defined by the
                :class:`~fairseq.data.FairseqDataset`.
            model (~fairseq.models.BaseFairseqModel): the model
            criterion (~fairseq.criterions.FairseqCriterion): the criterion
            optimizer (~fairseq.optim.FairseqOptimizer): the optimizer
            ignore_grad (bool): multiply loss by 0 if this is set to True

        Returns:
            tuple:
                - the loss
                - the sample size, which is used as the denominator for the
                  gradient
                - logging outputs to display while training
        """
        # control dropout during generation
        model.train(self.args.seq_gen_with_dropout)

        # generate hypotheses
        self._generate_hypotheses(model, sample)

        return super().train_step(
            sample=sample,
            model=model,
            criterion=criterion,
            optimizer=optimizer,
            ignore_grad=ignore_grad,
        )

    def valid_step(self, sample, model, criterion):
        model.eval()
        self._generate_hypotheses(model, sample)
        return super().valid_step(sample=sample,
                                  model=model,
                                  criterion=criterion)

    def _generate_hypotheses(self, model, sample):
        # initialize generator
        if self._generator is None:
            self._generator = SequenceGenerator(
                self.target_dictionary,
                beam_size=self.args.seq_beam,
                max_len_a=self.args.seq_max_len_a,
                max_len_b=self.args.seq_max_len_b,
                unk_penalty=self.args.seq_unkpen,
                sampling=self.args.seq_sampling,
            )

        # generate hypotheses
        sample['hypos'] = self._generator.generate(
            [model],
            sample,
        )

        # add reference to the set of hypotheses
        if self.args.seq_keep_reference:
            self.add_reference_to_hypotheses_(sample)

    def add_reference_to_hypotheses_(self, sample):
        """
        Add the reference translation to the set of hypotheses. This can be
        called from the criterion's forward.
        """
        if 'includes_reference' in sample:
            return
        sample['includes_reference'] = True
        target = sample['target']
        pad_idx = self.target_dictionary.pad()
        for i, hypos_i in enumerate(sample['hypos']):
            # insert reference as first hypothesis
            ref = utils.strip_pad(target[i, :], pad_idx)
            hypos_i.insert(0, {
                'tokens': ref,
                'score': None,
            })

    def get_new_sample_for_hypotheses(self, orig_sample):
        """
        Extract hypotheses from *orig_sample* and return a new collated sample.
        """
        ids = orig_sample['id'].tolist()
        pad_idx = self.source_dictionary.pad()
        samples = [
            {
                'id': ids[i],
                'source': utils.strip_pad(
                    orig_sample['net_input']['src_tokens'][i, :], pad_idx),
                'target': hypo['tokens'],
            }
            for i, hypos_i in enumerate(orig_sample['hypos'])
            for hypo in hypos_i
        ]
        return language_pair_dataset.collate(
            samples,
            pad_idx=pad_idx,
            eos_idx=self.source_dictionary.eos(),
            left_pad_source=self.args.left_pad_source,
            left_pad_target=self.args.left_pad_target,
            sort=False,
        )

    def get_sequence_scorer(self, scorer):
        if scorer not in self._scorers:
            tgt_dict = self.target_dictionary
            src_dict = self.source_dictionary
            if scorer == 'bleu':
                self._scorers[scorer] = BleuScorer(
                    tgt_dict,
                    bpe_symbol=self.args.seq_remove_bpe,
                )
            elif scorer == 'simile':
                self._scorers[scorer] = SimileScorer(
                    tgt_dict,
                    bpe_symbol=self.args.seq_remove_bpe,
                    args=self.args,
                )
            elif scorer == 'cl-simile':
                self._scorers[scorer] = CrossLingualSimileScorer(
                    tgt_dict,
                    src_dict,
                    self.args.cl_ratio,
                    bpe_symbol=self.args.seq_remove_bpe,
                    args=self.args,
                )
            else:
                raise ValueError('Unknown sequence scorer {}'.format(scorer))
        return self._scorers[scorer]

    def get_costs(self, sample, scorer=None):
        """Get costs for hypotheses using the specified *scorer*."""
        if scorer is None:
            scorer = self.get_sequence_scorer(self.args.seq_scorer)

        bsz = len(sample['hypos'])
        nhypos = len(sample['hypos'][0])
        target = sample['target'].int()
        source = sample['net_input']['src_tokens'].int()

        pad_idx = self.target_dictionary.pad()
        assert pad_idx == self.source_dictionary.pad()

        costs = torch.zeros(bsz, nhypos).to(sample['target'].device)

        if self.args.seq_scorer == "simile":
            for i, hypos_i in enumerate(sample['hypos']):
                ref = utils.strip_pad(target[i, :], pad_idx).cpu()
                ref = scorer.preprocess_ref(ref)
                ref_len = len(ref.split())
                hypos = []
                hypo_lens = []

                for j, hypo in enumerate(hypos_i):
                    hyp = scorer.preprocess_hypo(hypo)
                    hypos.append(hyp)
                    hypo_lens.append(len(hyp.split()))

                _costs = scorer.get_costs(ref, hypos)

                for j, _ in enumerate(hypos_i):
                    lp = np.exp(1 - max(ref_len, hypo_lens[j]) /
                                float(min(ref_len, hypo_lens[j])))
                    costs[i, j] = 1 - lp ** self.args.simile_lenpen * _costs[j].item()
        elif self.args.seq_scorer == "cl-simile":
            for i, hypos_i in enumerate(sample['hypos']):
                ref = utils.strip_pad(target[i, :], pad_idx).cpu()
                ref = scorer.preprocess_ref(ref)
                src = utils.strip_pad(source[i, :], pad_idx).cpu()
                src = scorer.preprocess_src(src)

                ref_len = len(ref.split())
                hypos = []
                hypo_lens = []

                for j, hypo in enumerate(hypos_i):
                    hyp = scorer.preprocess_hypo(hypo)
                    hypos.append(hyp)
                    hypo_lens.append(len(hyp.split()))

                _costs = scorer.get_costs(ref, hypos, src)

                for j, _ in enumerate(hypos_i):
                    lp = np.exp(1 - max(ref_len, hypo_lens[j]) /
                                float(min(ref_len, hypo_lens[j])))
                    costs[i, j] = 1 - lp ** self.args.simile_lenpen * _costs[j].item()
        else:
            for i, hypos_i in enumerate(sample['hypos']):
                ref = utils.strip_pad(target[i, :], pad_idx).cpu()
                ref = scorer.preprocess_ref(ref)
                for j, hypo in enumerate(hypos_i):
                    costs[i, j] = scorer.get_cost(ref,
                                                  scorer.preprocess_hypo(hypo))
        return scorer.postprocess_costs(costs)
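
The `simile` and `cl-simile` branches above combine the scorer's output with a length penalty lp = exp(1 - max(ref_len, hyp_len) / min(ref_len, hyp_len)) and turn it into a cost via 1 - lp ** simile_lenpen * score. A small worked example of that arithmetic; the lengths and the similarity score are made up:

import math

def simile_cost(ref_len, hyp_len, simile_score, simile_lenpen=0.25):
    # same arithmetic as the simile/cl-simile branches of get_costs() above
    lp = math.exp(1 - max(ref_len, hyp_len) / min(ref_len, hyp_len))
    return 1 - lp ** simile_lenpen * simile_score

print(simile_cost(ref_len=10, hyp_len=10, simile_score=0.9))  # ≈ 0.100 (no length mismatch)
print(simile_cost(ref_len=10, hyp_len=8, simile_score=0.9))   # ≈ 0.155 (length mismatch penalized)
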
Example #29
class RewardCrossEntropyCriterion(FairseqCriterion):
    def __init__(self, args, task):
        super().__init__(args, task)
        self.eps = args.label_smoothing
        from fairseq.sequence_generator import SequenceGenerator
        self.gen = SequenceGenerator(task.target_dictionary,
                                     beam_size=args.beam_size)
        if args.reward == "bleurt":
            from fairseq.distributed_utils import get_rank
            sys.argv = sys.argv[:1]
            my_rank = 0 if torch.cuda.device_count() <= 1 else get_rank()
            os.environ["CUDA_VISIBLE_DEVICES"] = str(my_rank % 4)
            from bleurt import score
            from transformers import cached_path
            import tensorflow as tf
            gpus = tf.config.experimental.list_physical_devices('GPU')
            if gpus:
                this_gpu = gpus[my_rank % 4]
                tf.config.set_visible_devices([this_gpu], 'GPU')
                try:
                    tf.config.experimental.set_memory_growth(this_gpu, True)
                    tf.config.experimental.set_virtual_device_configuration(
                        this_gpu, [
                            tf.config.experimental.VirtualDeviceConfiguration(
                                memory_limit=2048)
                        ])
                    logical_devices = tf.config.list_logical_devices('GPU')
                    self.logical_device = tf.device(logical_devices[0].name)
                    print("num of logical gpus", len(logical_devices))
                except RuntimeError as e:
                    print(e)
            with self.logical_device:
                self.bleurt_scorer = score.BleurtScorer(
                    os.path.join(
                        cached_path(
                            "https://storage.googleapis.com/bleurt-oss/bleurt-base-128.zip",
                            extract_compressed_file=True), "bleurt-base-128"))

    @staticmethod
    def add_args(parser):
        """Add criterion-specific arguments to the parser."""
        # fmt: off
        parser.add_argument(
            '--label-smoothing',
            default=0.,
            type=float,
            metavar='D',
            help='epsilon for label smoothing, 0 means no label smoothing')
        parser.add_argument('--proxyloss2', action="store_true")
        parser.add_argument('--bleurt-scale', action="store_true")
        parser.add_argument('--contrastive', action="store_true")
        parser.add_argument('--m', default=10, type=float)
        parser.add_argument('--reward', default="sbleu", type=str)
        parser.add_argument('--beam-size', type=int, default=4)
        # fmt: on

    def forward(self, model, sample, reduce=True):
        """Compute the loss for the given sample.

        Returns a tuple with three elements:
        1) the loss
        2) the sample size, which is used as the denominator for the gradient
        3) logging outputs to display while training
        """
        # >>>> Sample for reward >>>
        is_training = model.training
        model.eval()
        B = sample["target"].shape[0]
        gen_results = []
        for shard_i in range(math.ceil(float(B) / SHARD_SIZE)):
            start = shard_i * SHARD_SIZE
            end = (shard_i + 1) * SHARD_SIZE
            sub_sample = {
                "net_input": {
                    "src_tokens":
                    sample["net_input"]["src_tokens"][start:end],
                    "src_lengths":
                    sample["net_input"]["src_lengths"][start:end],
                    "prev_output_tokens":
                    sample["net_input"]["prev_output_tokens"][start:end]
                }
            }
            sub_results = [[
                p["tokens"][:60] for p in results
            ] for results in self.gen.generate([model], sub_sample)]
            gen_results.extend(sub_results)
        targets = sample["target"] * torch.gt(sample["target"], 1)
        if self.args.reward == "sbleu":
            rewards = []
            for batch_i in range(len(gen_results)):
                batch_rewards = []
                for seq in gen_results[batch_i]:
                    batch_rewards.append(
                        self.compute_reward(seq, targets[batch_i]))
                rewards.append(batch_rewards)
            rewards = torch.tensor(rewards)
        elif self.args.reward == "bleurt":
            hyps = []
            tgts = []
            for batch_i in range(len(gen_results)):
                for seq in gen_results[batch_i]:
                    hyps.append(
                        self.task.tgt_dict.string(seq, bpe_symbol="@@ "))
                    tgts.append(
                        self.task.tgt_dict.string(
                            targets[batch_i]
                            [:torch.gt(targets[batch_i], 0).sum()],
                            bpe_symbol="@@ "))
            with self.logical_device:
                scores = torch.tensor(self.bleurt_scorer.score(tgts, hyps))
                if self.args.bleurt_scale:
                    rewards = scores * 100.
                else:
                    rewards = torch.exp(scores) * 100.
            rewards = rewards.view(B, -1)
        best_idx = rewards.argmax(1)
        # idxp = np.random.randint(rewards.shape[1], size=(rewards.shape[0], 2))
        # idxp_tensor = torch.tensor(idxp)
        # selected_rewards = torch.cat([
        #     rewards[torch.arange(rewards.size(0)), idxp_tensor[:, 0]][:, None],
        #     rewards[torch.arange(rewards.size(0)), idxp_tensor[:, 1]][:, None]
        # ], 1)
        # valid_mask = selected_rewards[:, 0] > selected_rewards[:, 1]
        # reversed_selected_rewards = torch.cat([selected_rewards[:, 1][:, None], selected_rewards[:, 0][:, None]], 1)
        # selected_rewards = selected_rewards * valid_mask[:, None] + reversed_selected_rewards * valid_mask.logical_not()[:, None]
        # reversed_idxp_tensor = torch.cat([idxp_tensor[:, 1][:, None], idxp_tensor[:, 0][:, None]], 1)
        # idxp_tensor = idxp_tensor * valid_mask[:, None] + reversed_idxp_tensor * valid_mask.logical_not()[:, None]
        # best_results = [res[idx] for res, idx in zip(gen_results, idxp_tensor[:, 0])]
        best_results = [res[idx] for res, idx in zip(gen_results, best_idx)]
        if not self.args.proxyloss2:
            maxlen = max([len(r) for r in best_results])
            new_target = targets.new_ones(targets.shape[0], maxlen)
            for i, seq in enumerate(best_results):
                new_target[i, :seq.shape[0]] = seq
            first_col = new_target.new_ones(new_target.shape[0]) * 2
            new_decoder_input = torch.cat(
                [first_col[:, None], new_target[:, :-1]], 1)
        else:
            # argmax_results = [res[0] for res in gen_results]
            # worst_results = [res[idx] for res, idx in zip(gen_results, idxp_tensor[:, 1])]
            worst_results = [
                res[idx] for res, idx in zip(gen_results, rewards.argmin(1))
            ]
            merged_results = best_results + worst_results
            maxlen = max([len(r) for r in merged_results])
            new_target = targets.new_ones(len(merged_results), maxlen)
            for i, seq in enumerate(merged_results):
                new_target[i, :seq.shape[0]] = seq
            first_col = new_target.new_ones(new_target.shape[0]) * 2
            new_decoder_input = torch.cat(
                [first_col[:, None], new_target[:, :-1]], 1)
        sample["net_input"]["prev_output_tokens"] = new_decoder_input
        sample["target"] = new_target
        best_reward = rewards[torch.arange(rewards.shape[0]), best_idx].cuda()
        worst_reward = rewards[torch.arange(rewards.shape[0]),
                               rewards.argmin(1)].cuda()
        # best_reward = selected_rewards[:, 0].cuda()
        # worst_reward = selected_rewards[:, 1].cuda()
        argmax_reward = rewards[:, 0].cuda()
        mean_reward = rewards.mean(1).cuda()
        if is_training:
            model.train()
        # >>>>
        if not self.args.proxyloss2:
            decoder_out = model.forward(sample['net_input']["src_tokens"],
                                        sample['net_input']["src_lengths"],
                                        new_decoder_input)
            loss, nll_loss = self.compute_loss(model,
                                               decoder_out,
                                               sample,
                                               reduce=reduce)
        else:
            # repeated_encoder_out = {"encoder_out": torch.cat([encoder_out["encoder_out"], encoder_out["encoder_out"]]),
            #                         "encoder_padding_mask":
            #                             torch.cat([encoder_out["encoder_padding_mask"], encoder_out["encoder_padding_mask"]])
            #                             if encoder_out["encoder_padding_mask"] is not None else None
            #                         }
            repeated_src_tokens = torch.cat([
                sample['net_input']["src_tokens"],
                sample['net_input']["src_tokens"]
            ])
            repeated_src_lengths = torch.cat([
                sample['net_input']["src_lengths"],
                sample['net_input']["src_lengths"]
            ])
            decoder_out = model.forward(repeated_src_tokens,
                                        repeated_src_lengths,
                                        new_decoder_input)
            loss, nll = self.compute_loss(model,
                                          decoder_out, {"target": new_target},
                                          reduce=False,
                                          return_full_mat=True)
            token_mask = torch.ne(new_target, self.padding_idx)
            loss = (loss.view(new_target.shape) *
                    token_mask).sum(1) / token_mask.sum(1)
            nll = (nll.view(new_target.shape) *
                   token_mask).sum(1) / token_mask.sum(1)
            if self.args.contrastive:
                loss = (loss[:B] - loss[-B:]) * 10.
            else:
                loss = (loss[:B] - loss[-B:]) * 10. + self.args.m * (
                    (best_reward - worst_reward) / 100.)
                loss = torch.gt(loss, 0) * loss
            # loss = ((best_reward - worst_reward) / 100.) * (loss[:B] - loss[-B:]) * 10.
            nll_loss = (nll[:B] - nll[-B:]) * 10.
            if reduce:
                loss = loss.sum()
                nll_loss = nll_loss.sum()
        sample_size = B if self.args.sentence_avg or self.args.proxyloss2 else sample[
            'ntokens']
        logging_output = {
            'loss':
            utils.item(loss.data) if reduce else loss.data,
            'nll_loss':
            utils.item(nll_loss.data) if reduce else nll_loss.data,
            'best_r':
            utils.item(best_reward.sum().data) if reduce else best_reward.data,
            'argmax_r':
            utils.item(argmax_reward.sum().data)
            if reduce else argmax_reward.data,
            'avg_r':
            utils.item(mean_reward.sum().data) if reduce else mean_reward.data,
            'ntokens':
            sample['ntokens'],
            'nsentences':
            B,
            'sample_size':
            sample_size,
        }
        return loss, sample_size, logging_output

    def compute_loss(self,
                     model,
                     net_output,
                     sample,
                     reduce=True,
                     return_full_mat=False):
        lprobs = model.get_normalized_probs(net_output, log_probs=True)
        lprobs = lprobs.view(-1, lprobs.size(-1))
        target = model.get_targets(sample, net_output).view(-1, 1)
        if return_full_mat:
            ignore_index = None
            loss = -lprobs.gather(dim=1, index=target).flatten()
            nll_loss = loss
        else:
            ignore_index = self.padding_idx
            loss, nll_loss = label_smoothed_nll_loss(
                lprobs,
                target,
                self.eps,
                ignore_index=ignore_index,
                reduce=reduce,
            )
        return loss, nll_loss

    @staticmethod
    def aggregate_logging_outputs(logging_outputs):
        """Aggregate logging outputs from data parallel training."""
        ntokens = sum(log.get('ntokens', 0) for log in logging_outputs)
        nsentences = sum(log.get('nsentences', 0) for log in logging_outputs)
        sample_size = sum(log.get('sample_size', 0) for log in logging_outputs)
        return {
            'loss':
            sum(log.get('loss', 0) for log in logging_outputs) / sample_size /
            math.log(2) if sample_size > 0 else 0.,
            'nll_loss':
            sum(log.get('nll_loss', 0) for log in logging_outputs) / ntokens /
            math.log(2) if ntokens > 0 else 0.,
            'best_r':
            sum(log.get('best_r', 0) for log in logging_outputs) / nsentences /
            math.log(2) if nsentences > 0 else 0.,
            'argmax_r':
            sum(log.get('argmax_r', 0) for log in logging_outputs) /
            nsentences / math.log(2) if nsentences > 0 else 0.,
            'avg_r':
            sum(log.get('avg_r', 0) for log in logging_outputs) / nsentences /
            math.log(2) if nsentences > 0 else 0.,
            'ntokens':
            ntokens,
            'nsentences':
            nsentences,
            'sample_size':
            sample_size,
        }

    def compute_reward(self, yhat, target):
        return self._sbleu(yhat, target)

    def _sbleu(self, yhat, target):
        tgt_seq = target.int().cpu().numpy()
        sampled_tokens = yhat.int().cpu().numpy()
        tgt_mask = np.greater(tgt_seq, 0)
        yhat_mask = np.greater(sampled_tokens, 0)
        target_len = int(tgt_mask.sum())
        yhat_len = int(yhat_mask.sum())
        ref_tokens = tgt_seq[:target_len]
        out_tokens = list(sampled_tokens[:yhat_len])
        ref_tokens = self.task.tgt_dict.string(ref_tokens).replace("@@ ",
                                                                   "").split()
        out_tokens = self.task.tgt_dict.string(out_tokens).replace("@@ ",
                                                                   "").split()
        return smoothed_bleu(out_tokens, ref_tokens)
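The _sbleu helper above delegates to a smoothed_bleu function imported elsewhere in the repository. As a hedged, self-contained approximation (an assumption about its behaviour, not the project's implementation), a sentence-level BLEU with add-one smoothing over the n-gram precisions looks roughly like this:

import math
from collections import Counter

def _ngrams(tokens, n):
    return Counter(tuple(tokens[i:i + n]) for i in range(len(tokens) - n + 1))

def smoothed_sentence_bleu(hyp, ref, max_n=4):
    if len(hyp) == 0:
        return 0.0
    log_prec = 0.0
    for n in range(1, max_n + 1):
        hyp_ngrams, ref_ngrams = _ngrams(hyp, n), _ngrams(ref, n)
        overlap = sum((hyp_ngrams & ref_ngrams).values())
        total = sum(hyp_ngrams.values())
        # add-one smoothing keeps a missing n-gram order from zeroing the score
        log_prec += math.log((overlap + 1.0) / (total + 1.0))
    brevity = min(1.0, math.exp(1.0 - len(ref) / len(hyp)))
    return 100.0 * brevity * math.exp(log_prec / max_n)

print(smoothed_sentence_bleu("a b c d e".split(), "a b c d e".split()))  # 100.0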
Example #30
class FairseqAgent(TorchAgent):
    """Generic wrapper around fairseq for use in ParlAI"""

    metrics = {}

    @classmethod
    def add_cmdline_args(cls, argparser):
        """Add command-line arguments specifically for this agent."""
        # first we need to add the general torch agent operations
        super(FairseqAgent, cls).add_cmdline_args(argparser)

        # let's store any defaults that were overridden
        old_defaults = argparser._defaults
        if 'clip_norm' not in old_defaults:
            # fairseq has a few awful defaults
            old_defaults['clip_norm'] = 1.0
        if 'optimizer' not in old_defaults:
            old_defaults['optimizer'] = 'adam'
            old_defaults['adam_betas'] = '(0.9,0.98)'

        agent = argparser.add_argument_group('Fairseq Arguments')
        agent.add_argument('--fp16',
                           default=False,
                           type='bool',
                           help='Use fp16 training')
        agent.add_argument(
            '--fp16-init-scale',
            default=2**7,
            type=int,
            help='default FP16 loss scale',
        )
        agent.add_argument(
            '--seed',
            default=1,
            type=int,
            metavar='N',
            help='pseudo random number generator seed',
        )
        agent.add_argument(
            '--skip-generation',
            default=False,
            type='bool',
            metavar='BOOL',
            help=
            'Skips test time beam search. Much faster if you only need PPL',
        )

        # Check subargs for generation, optimizers, criterions, archs, etc
        options.add_generation_args(argparser)
        options.add_optimization_args(argparser)
        options.add_checkpoint_args(argparser)

        # restore any user set defaults that fairseq possibly overrode
        argparser.set_defaults(**old_defaults)
        known_args = argparser.parse_known_args(nohelp=True)[0]

        if hasattr(known_args, "optimizer"):
            optimizer = known_args.optimizer
            opt_group = argparser.add_argument_group(
                '{} optimizer arguments'.format(optimizer))
            optim.OPTIMIZER_REGISTRY[optimizer].add_args(opt_group)
        if hasattr(known_args, "lr_scheduler"):
            lr_scheduler = known_args.lr_scheduler
            lr_group = argparser.add_argument_group(
                '{} scheduler arguments'.format(lr_scheduler))
            optim.lr_scheduler.LR_SCHEDULER_REGISTRY[lr_scheduler].add_args(
                lr_group)
        # We need to find out the fairseq model-specific options, so grab the
        # architecture stuff and look up its options
        arch_group = options.add_model_args(argparser)
        # Fairseq marks the arch flag as required, but it may be specified
        # by a saved model cache, so we do some weird stuff to undo that
        for a in arch_group._actions:
            if a.dest == "arch":
                a.required = False
                a.default = None
                break

        # once again restore any user-set defaults
        argparser.set_defaults(**old_defaults)
        known_args = argparser.parse_known_args(nohelp=True)[0]

        if hasattr(known_args, "arch") and known_args.arch is not None:
            arch = known_args.arch
            arch_group = argparser.add_argument_group(
                "{} architecture arguments".format(arch))
            models.ARCH_MODEL_REGISTRY[arch].add_args(arch_group)

        if hasattr(known_args, "criterion"):
            crit_group = argparser.add_argument_group(
                '{} criterion arguments'.format(known_args.criterion))
            criterions.CRITERION_REGISTRY[known_args.criterion].add_args(
                crit_group)

        # one last time, restore any user set defaults
        argparser.set_defaults(**old_defaults)

    @staticmethod
    def dictionary_class():
        # Force use of the Fairseq Dictionary
        return _FairseqDictionary

    def __init__(self, opt, shared=None):
        # In general use a basic TorchAgent wherever possible
        super().__init__(opt, shared)
        if not shared:
            # this is not a shared instance of this class, so do full initialization

            # check early if we're going to be loading the model from a checkpoint
            model_file_exists = self.opt.get('model_file') and os.path.isfile(
                self.opt['model_file'])

            # fairseq expects options to be in argparse format, instead of a dict
            # We also need to do some argument postprocessing and whatnot
            # We'll skip pretrained embeddings if we're going to override them with
            # a model checkpoint anyway
            self.args, self.opt = _fairseq_opt_wrapper(opt, model_file_exists)

            # seed the RNG
            torch.manual_seed(self.args.seed)

            # Just some identifying info
            self.id = "fairseq:{}".format(self.args.arch)

            # We need a placeholder task for fairseq
            self.task = _ParlaiTask(self.dict)

            # meters for keeping track of loss, ppl, etc.
            self.meters = defaultdict(AverageMeter)

            # actually construct the model and generator
            self.model = self.build_model()

            # Construct the generator and scorer
            self.generator = SequenceGenerator(
                [self.model],
                tgt_dict=self.dict,
                beam_size=self.args.beam,
                stop_early=(not self.args.no_early_stop),
                normalize_scores=(not self.args.unnormalized),
                len_penalty=self.args.lenpen,
                unk_penalty=self.args.unkpen,
                sampling=self.args.sampling,
                sampling_topk=self.args.sampling_topk,
                sampling_temperature=self.args.sampling_temperature,
            )
            self.scorer = SequenceScorer([self.model], self.dict)

            # set up the grader and the trainer
            self.criterion = criterions.build_criterion(self.args, self.task)

            # TODO: we might choose to add a --no-fp16 opt in the future to
            # explicitly disable fp16 instead
            if (self.use_cuda and not self.args.fp16
                    and torch.cuda.get_device_capability(0)[0] >= 7):
                print("Heads up: using --fp16 could be a lot faster!")
            if self.use_cuda:
                self.trainer = trainer.Trainer(self.args, self.task,
                                               self.model, self.criterion,
                                               None)
                self.trainer._build_optimizer()
            else:
                self.trainer = None

            # if the model already existed, let's preload it and the trainer
            if model_file_exists:
                print('Loading existing model params from ' +
                      self.opt['model_file'])
                self.load(self.opt.get('model_file'))

            # move things to the GPU if possible
            if self.use_cuda:
                self.model = self.model.cuda()
                self.generator = self.generator.cuda()
        else:
            self.model = shared['model']
            self.trainer = shared['trainer']
            self.generator = shared['generator']
            self.dict = shared['dict']
            self.args = shared['args']
            self.meters = shared['meters']

        # Start things off clean
        self.reset()

    def _check_opts_unchanged(self, saved_opts, current_opts):
        """Verify that critical options do not differ in command line vs saved model"""
        for k in NON_OVERRIDABLE_ARGS:
            if k not in saved_opts or k not in current_opts:
                # if it's not an option needed by this fairseq model, don't stress
                continue
            if saved_opts[k] != current_opts[k]:
                raise ValueError(
                    '{} cannot be overridden when --model-file is specified'.
                    format(k))

    def build_model(self):
        """
        Construct the actual Fairseq model. Default implementation is to use
        Fairseq's arch builder, but this method may be overridden to build custom
        models.
        """
        model_class = models.ARCH_MODEL_REGISTRY[self.args.arch]
        model = model_class.build_model(self.args, self.task)
        if self.args.embedding_type != 'random':
            self._copy_embeddings(model.encoder.embed_tokens.weight,
                                  self.args.embedding_type)
        return model

    def share(self):
        shared = super().share()
        shared['model'] = self.model
        shared['trainer'] = self.trainer
        shared['generator'] = self.generator
        shared['dict'] = self.dict
        shared['args'] = self.args
        shared['meters'] = self.meters
        return shared

    def save(self, path):
        """Save using fairseq's checkpointing."""
        if not path:
            return
        self.trainer.save_checkpoint(path, {'opt': self.opt, 'epoch': 0})
        # Parlai expects options to also be saved
        with open(path + '.opt', 'w') as handle:
            # overridden options shouldn't be stored, only the main ones
            if 'override' in self.opt:
                del self.opt['override']
            json.dump(self.opt, handle)

        # force save the dict
        self.dict.save(path + '.dict', sort=False)

    def load(self, path):
        """Load using fairseq's checkpointing."""
        if self.trainer:
            old_options = self.trainer.load_checkpoint(
                path, self.args.reset_optimizer)
            self._check_opts_unchanged(old_options, self.opt)
        else:
            load_model_state(path, self.model)

    def shutdown(self):
        if not hasattr(self, 'trainer'):
            # looks like this is a "fake" model that isn't actually used for batch_act.
            # we don't need to save this one.
            return
        super().shutdown()

    def reset(self):
        """Reset observation and episode_done."""
        super().reset()
        self.reset_metrics()

    def is_valid(self, obs):
        """Override from TorchAgent.
        Check if an observation has no tokens in it."""
        return len(obs.get('text_vec', [])) > 0

    def batchify(self, obs_batch):
        """
        Override parent batchify to set requirements for fairseq.

        Fairseq depends on sorted batch inputs for a call to rnn.pad_packed_sequence.
        Fairseq models cannot handle zero length sentences
        """
        return super().batchify(obs_batch, sort=True)

    def _update_metrics(self, metrics, sample):
        if metrics is None:
            # probably got an overflow in fp16 mode. don't count this sample
            return

        bsz = len(sample['target'])
        ntok = sample['ntokens']
        ssize = metrics['sample_size']

        for k, v in metrics.items():
            if k in {'ntokens', 'nsentences', 'sample_size'}:
                # don't need these
                continue
            elif k == "nll_loss":
                # nll loss is always normalized by ntokens
                self.meters[k].update(v, ntok)
            elif k == "loss":
                # loss is explicitly normalized by passed up sample size
                self.meters[k].update(v, ssize)
            else:
                # assume everything else is averaged over bsz
                self.meters[k].update(v, bsz)

    def train_step(self, batch):
        """Process batch of inputs and targets and train on them.

        :param batch: parlai.core.torch_agent.Batch, contains tensorized
                      version of observations.
        """
        if batch.text_vec is None:
            return
        self.is_training = True
        sample = self._make_sample(batch)
        self.model.train()
        metrics = self.trainer.train_step([sample])
        self._update_metrics(metrics, sample)

    def eval_step(self, batch):
        """Process batch of inputs.

        If the batch includes labels, calculate validation metrics as well.
        If --skip-generation is not set, return a prediction for each input.

        :param batch: parlai.core.torch_agent.Batch, contains tensorized
                      version of observations.
        """
        if batch.text_vec is None:
            return
        self.is_training = False
        samples = self._make_sample(batch)
        self.model.eval()
        if batch.label_vec is not None and self.trainer is not None:
            # Interactive mode won't have a gold label
            metrics = self.trainer.valid_step(samples)
            self._update_metrics(metrics, samples)

        # Output placeholders
        reranked_cands = None
        generated_output = None

        # Grade each of the candidate sequences
        if batch.candidate_vecs is not None:
            bsz = len(batch.text_vec)
            reranked_cands = []
            # score the candidates for each item in the batch separately, so that
            # we can support variable number of candidates
            for i in range(bsz):
                cands = batch.candidate_vecs[i]
                if not cands:
                    reranked_cands.append(None)
                    continue
                ncand = len(cands)
                # repeat the input many times
                xs = batch.text_vec[i].unsqueeze(0).expand(ncand, -1)
                # some models crash if there's leading padding on every example
                xs = xs[:, :batch.text_lengths[i]]
                # and appropriately pack the outputs
                ys, _ = padded_tensor(cands, self.NULL_IDX, self.use_cuda)
                s = self._make_sample(xs=xs, ys=ys)
                # perform the actual grading, extract the scores
                scored = list(
                    self.scorer.score_batched_itr([s], cuda=self.use_cuda))
                scores = [s[3][0]['score'].item() for s in scored]
                # intentional hanging comma here; argsort returns a list
                ranked, = argsort(scores, batch.candidates[i], descending=True)
                reranked_cands.append(ranked)

        # Next generate freely to create our response
        if not self.args.skip_generation:
            generated_output = self._generate(samples)
        elif reranked_cands:
            # we're skipping generation, but we're also grading candidates
            # so output the highest ranked candidate
            # In the case of zero candidates, we don't have something to rank,
            # so we may need to pass on that None
            generated_output = [
                ranked and ranked[0] or None for ranked in reranked_cands
            ]
        else:
            # no output at all
            pass

        return Output(generated_output, reranked_cands)

    def _generate(self, samples):
        no_prev_token = {
            k: v
            for k, v in samples['net_input'].items()
            if k != 'prev_output_tokens'
        }
        gens = self.generator.generate(no_prev_token, maxlen=64)
        bsz = samples['net_input']['src_tokens'].size(0)
        responses = []
        for i in range(bsz):
            beams = gens[i]
            selected = max(beams, key=lambda x: x["score"])
            tokens = selected["tokens"]
            start = 0
            end = -1
            for j, t in enumerate(tokens):
                t = t.item()
                if t == self.dict.bos_index:
                    # don't include <s> token
                    start = j + 1
                    continue
                if t == self.dict.eos_index:
                    # stop (and don't include) </s> token
                    end = j
                    break
            responses.append(self.dict.vec2txt(tokens[start:end]))
        return responses

    def report(self):
        """Return metrics calculated by the model."""
        # if we haven't initialized yet, just return a dummy object
        if not hasattr(self, "trainer"):
            return {}

        output = {k: v.avg for k, v in self.meters.items()}

        if "nll_loss" in self.meters:
            # special case, we used sentence averaging so ppl comes from nll_loss
            output["ppl"] = np.exp2(self.meters["nll_loss"].avg)
        else:
            # normal case, just use loss
            output["ppl"] = np.exp2(self.meters["loss"].avg)

        # Fairseq trainer metrics we'll pass up the way
        trainer_metrics = {"ups", "wps", "gnorm", "clip"}
        if self.is_training:
            for k in trainer_metrics:
                output[k] = self.trainer.meters[k].avg

        # for display purposes
        output = {k: round_sigfigs(v, 4) for k, v in output.items()}
        return output

    def reset_metrics(self):
        """Reset metrics calculated by the model back to zero."""
        if not hasattr(self, "trainer"):
            # We haven't set up the trainer yet, so we don't have any metrics
            return
        # We need to reset everything
        self.meters.clear()
        if self.trainer:
            for k in self.trainer.meters:
                self.trainer.meters[k].reset()

    def receive_metrics(self, metrics_dict):
        """Update lr scheduler with validation loss."""
        # TODO: this should be smarter
        self.trainer.lr_step(-1, metrics_dict["loss"])

    # Helper functions
    def _seq_length(self, xs):
        """Compute length of the sequence (non-padded size)."""
        return xs.ne(self.dict.pad_index).long().sum(dim=-1)

    def _right_shifted_ys(self, ys):
        """Replace first token with EOS and shift remaining tokens right 1."""
        result = torch.LongTensor(ys.size())
        result[:, 0] = self.dict.eos_index
        result[:, 1:] = ys[:, :-1]
        return result

    def _make_sample(self, batch=None, xs=None, ys=None):
        """Generate a sample object that Fairseq expects."""
        # add extra info to samples
        if batch is None and xs is None:
            raise ValueError("Must supply either batch or xs")
        if batch is None and ys is None:
            raise ValueError("Must supply either batch or ys")
        if xs is None:
            xs = batch.text_vec
        if ys is None:
            ys = batch.label_vec
        repadded = convert_padding_direction(xs,
                                             self.dict.pad(),
                                             right_to_left=True)
        sample = {}
        sample["id"] = torch.arange(len(xs) - 1)
        sample["net_input"] = {
            "src_tokens": repadded,
            "src_lengths": self._seq_length(xs),
        }
        if ys is not None:
            sample["target"] = ys
            sample["ntokens"] = sum(self._seq_length(ys)).item()
            sample["net_input"]["prev_output_tokens"] = self._right_shifted_ys(
                ys)
        return sample
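For reference, the teacher-forcing input built by _right_shifted_ys amounts to prepending EOS and dropping the last target token. A tiny standalone illustration (the token ids and EOS index are made up):

import torch

ys = torch.tensor([[5, 6, 7, 8]])            # gold target ids (hypothetical)
eos = 2                                      # assumed EOS index
prev_output_tokens = torch.full_like(ys, eos)
prev_output_tokens[:, 1:] = ys[:, :-1]
print(prev_output_tokens)                    # tensor([[2, 5, 6, 7]])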
Example #31
class MaskDiscriminatorTask(MaskMLETask):

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.generator = self.load_pretrained_generator(args[0].generator_path)
        if not args[0].cpu:
            self.generator.cuda()

        self.passed_iters = 0
        self.sequence_generator = SequenceGenerator(self.target_dictionary, beam_size=1)

    @staticmethod
    def add_args(parser):
        MaskMLETask.add_args(parser)

        parser.add_argument('--generator-path', type=str, help='path to trained generator')

    def load_pretrained_generator(self, path, arg_overrides=None):
        model = utils.load_checkpoint_to_cpu(path)
        args = model['args']
        state_dict = model['model']
        if not(arg_overrides is None):
            args = utils.override_model_args(args, arg_overrides)
        src_dict = self.source_dictionary
        tgt_dict = self.target_dictionary
        assert src_dict.pad() == tgt_dict.pad()
        assert src_dict.eos() == tgt_dict.eos()
        assert src_dict.unk() == tgt_dict.unk()

        task = MaskMLETask(args, src_dict, tgt_dict)
        model = task.build_model(args)
        model.upgrade_state_dict(state_dict)
        model.load_state_dict(state_dict, strict=True)
        return model

    def process_sample(self, sample, p):
        mask = torch.distributions.Bernoulli(torch.Tensor([p]))
        target = sample['target'].clone()

        mask_tensor = mask.sample(target.size())[:, :, 0].to(target.device)

        pad_idx = self.target_dictionary.pad()
        mask_idx = self.target_dictionary.index("<MASK>")

        target[(target != pad_idx) & (
            mask_tensor.byte())] = mask_idx
        mask_tensor[(target == pad_idx)] = 0

        sample['net_input']['masked_tgt'] = target
        sample['masks'] = mask_tensor
        return sample

    def get_mask_rate(self):
        return 0.8
        #  return torch.clamp(0.1 + self.passed_iters * 0.01, 0., 1.)

    def train_step(self, sample, model, criterion, optimizer, ignore_grad=False):
        """
        Do forward and backward, and return the loss as computed by *criterion*
        for the given *model* and *sample*.
        Args:
            sample (dict): the mini-batch. The format is defined by the
                :class:`~fairseq.data.FairseqDataset`.
            model (~fairseq.models.BaseFairseqModel): the model
            criterion (~fairseq.criterions.FairseqCriterion): the criterion
            optimizer (~fairseq.optim.FairseqOptimizer): the optimizer
            ignore_grad (bool): multiply loss by 0 if this is set to True
        Returns:
            tuple:
                - the loss
                - the sample size, which is used as the denominator for the
                  gradient
                - logging outputs to display while training
        """

        p = self.get_mask_rate()
        sample = self.process_sample(sample, p=p)

        self.generator.eval()
        model.train()

        generated = self.sequence_generator.generate((self.generator, ), sample,
                                                     substitute=True, mask_token=self.target_dictionary.index('<MASK>'))

        max_len = sample['target'].shape[1]
        tokens = [x[0]['tokens'] for x in generated]
        lengths = [min(max_len, x.shape[0]) for x in tokens]
        generated_tokens = torch.stack([torch.cat(
            (
                sample['target'].new_full(
                    (max_len - length,),
                    self.target_dictionary.pad()
                ),
                x[:length],
            )
        ) for x, length in zip(tokens, lengths)])

        sample['generated_tokens'] = generated_tokens

        # print('Target', sample['target'][0])
        # print('Generated', generated_tokens[0])

        loss, sample_size, logging_output = criterion(model, sample)
        if ignore_grad:
            loss *= 0
        optimizer.backward(loss)
        self.passed_iters += 1
        return loss, sample_size, logging_output

    def valid_step(self, sample, model, criterion):
        p = self.get_mask_rate()
        sample = self.process_sample(sample, p=p)
        self.generator.eval()
        model.eval()
        with torch.no_grad():
            generated = self.sequence_generator.generate((self.generator,), sample,
                                                         substitute=True,
                                                         mask_token=self.target_dictionary.index(
                                                             '<MASK>'))
            max_len = sample['target'].shape[1]
            tokens = [x[0]['tokens'] for x in generated]
            lengths = [min(max_len, x.shape[0]) for x in tokens]
            generated_tokens = torch.stack([torch.cat(
                (
                    sample['target'].new_full((max_len - length ,), self.target_dictionary.pad()),
                    x[:length]
                )
            ) for x, length in zip(tokens, lengths)])
            sample['generated_tokens'] = generated_tokens

            loss, sample_size, logging_output = criterion(model, sample)
        return loss, sample_size, logging_output

    def inference_step(self, generator, models, sample, prefix_tokens=None):
        p = self.get_mask_rate()
        sample = self.process_sample(sample, p=p)
        with torch.no_grad():
            return generator.generate(models, sample, prefix_tokens=prefix_tokens,
                                      substitute=True,
                                      mask_token=self.target_dictionary.index(
                                          '<MASK>'))
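The Bernoulli masking done in process_sample can be exercised on its own. A minimal sketch with assumed pad and <MASK> indices (1 and 99 below are placeholders):

import torch

def mask_targets(target, p, pad_idx, mask_idx):
    mask = torch.distributions.Bernoulli(torch.tensor([p])).sample(target.size())[..., 0]
    masked = target.clone()
    masked[(target != pad_idx) & mask.bool()] = mask_idx   # mask real tokens with probability p
    mask[target == pad_idx] = 0                            # never count pad positions as masked
    return masked, mask

tgt = torch.tensor([[4, 5, 6, 1, 1]])                      # 1 = pad (assumption)
masked, mask = mask_targets(tgt, p=0.8, pad_idx=1, mask_idx=99)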
Example #32
def main():
    args = parser.parse_args()
    np.random.seed(args.seed)
    torch.random.manual_seed(args.seed)

    dictionary = Dictionary.load(args.vocab_path)
    dictionary.truncate(args.max_vocab_size)

    test_dataset = SummaryDataset(os.path.join(args.data_path, 'test'),
                                  dictionary=dictionary,
                                  max_article_size=args.max_source_positions,
                                  max_summary_size=args.max_target_positions,
                                  max_elements=10 if args.debug else None)
    test_sampler = SequentialSampler(test_dataset)
    test_dataloader = DataLoader(dataset=test_dataset,
                                 batch_size=args.batch_size,
                                 sampler=test_sampler,
                                 num_workers=args.num_workers,
                                 collate_fn=lambda samples: collate(
                                     samples, dictionary.pad_index,
                                     dictionary.eos_index))

    summarization_task = SummarizationTask(args, dictionary)
    if args.model == 'transformer':
        args.local_transformer = False
        # transformer.base_architecture(args)
        transformer.transformer_small(args)
        model = transformer.TransformerModel.build_model(
            args, summarization_task).to(args.device)
    elif args.model == 'lstm':
        lstm.base_architecture(args)
        args.criterion = None
        model = lstm.LSTMModel.build_model(args,
                                           summarization_task).to(args.device)
    elif args.model == 'lightconv':
        args.encoder_conv_type = 'lightweight'
        args.decoder_conv_type = 'lightweight'
        args.weight_softmax = True
        lightconv.lightconv_small(args)
        model = lightconv.LightConvModel.build_model(
            args, summarization_task).to(args.device)
    elif args.model == 'localtransformer':
        args.local_transformer = True
        # transformer.base_architecture(args)
        transformer.transformer_small(args)
        model = transformer.TransformerModel.build_model(
            args, summarization_task).to(args.device)
    elif args.model == 'transformer_conv':
        # args.local_transformer = True
        # transformer.base_architecture(args)
        transformer_conv.transformer_conv_small(args)
        model = transformer_conv.TransformerConvModel.build_model(
            args, summarization_task).to(args.device)
    elif args.model == 'transformer_mc':
        # args.local_transformer = True
        # transformer.base_architecture(args)
        transformer_mc.transformer_mc_small(args)
        model = transformer_mc.TransformerMCModel.build_model(
            args, summarization_task).to(args.device)

    if args.model_path:
        model.load_state_dict(torch.load(args.model_path))

    generator = SequenceGenerator(dictionary,
                                  beam_size=args.beam_size,
                                  max_len_b=args.max_target_positions)

    avg_rouge_score = defaultdict(float)

    for batch_idx, batch in enumerate(test_dataloader):
        src_tokens = batch['net_input']['src_tokens'].to(args.device)
        src_lengths = batch['net_input']['src_lengths'].to(args.device)

        references = batch['target']
        references = [
            remove_special_tokens(ref, dictionary) for ref in references
        ]
        references = [dictionary.string(ref) for ref in references]

        # encoder_input = {'src_tokens': src_tokens, 'src_lengths': src_lengths}
        hypos = generator.generate([model], {
            'net_input': {
                'src_tokens': src_tokens,
                'src_lengths': src_lengths
            }
        })

        hypotheses = [hypo[0]['tokens'] for hypo in hypos]
        assert len(hypotheses) == src_tokens.size()[0]  # = size of the batch
        hypotheses = [
            remove_special_tokens(hypo, dictionary) for hypo in hypotheses
        ]
        hypotheses = [dictionary.string(hyp) for hyp in hypotheses]

        if args.verbose:
            print("\nComparison references/hypotheses:")
            for ref, hypo in zip(references, hypotheses):
                print(ref)
                print(hypo)
                print()

        avg_rouge_score_batch = compute_rouge.compute_score(
            references, hypotheses)
        print("rouge for this batch:", avg_rouge_score_batch)

        compute_rouge.update(avg_rouge_score, batch_idx * args.batch_size,
                             avg_rouge_score_batch, len(hypotheses))

    return avg_rouge_score
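compute_rouge.update above maintains a running average of the per-batch ROUGE scores. A dependency-free stand-in for that bookkeeping (a sketch, not the project's helper): an incremental weighted mean over the number of summaries seen so far.

def update_running_average(avg_scores, n_seen, batch_scores, batch_size):
    total = n_seen + batch_size
    for key, batch_value in batch_scores.items():
        avg_scores[key] = (avg_scores.get(key, 0.0) * n_seen + batch_value * batch_size) / total
    return avg_scores

avg = {}
avg = update_running_average(avg, 0, {"rouge-1": 0.40}, 8)
avg = update_running_average(avg, 8, {"rouge-1": 0.50}, 8)
print(avg)  # {'rouge-1': 0.45}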
Example #33
    def forward(self, model, sample, reduce=True):
        # sample mode
        #print('!!!RL loss.')
        model.eval()
        # src_dict = self.task.source_dictionary
        tgt_dict = self.task.target_dictionary
        eos_idx = self.task.target_dictionary.eos()
        sample_beam = self.args.sample_beam
        translator = SequenceGenerator(
            [model],
            tgt_dict=tgt_dict,
            sampling=self.args.multinomial_sample_train,
            beam_size=sample_beam,
            minlen=1)
        translator.cuda()
        ct = 0
        translations = []

        s = utils.move_to_cuda(sample)
        input = s['net_input']
        max_len = 200
        with torch.no_grad():
            hypos = translator.generate(
                input['src_tokens'],
                input['src_lengths'],
                beam_size=sample_beam,
                maxlen=max_len,
            )
        for i, id in enumerate(s['id'].data):
            src = input['src_tokens'].data[i, :]
            # remove padding from ref
            ref = utils.strip_pad(
                s['target'].data[i, :],
                tgt_dict.pad()) if s['target'] is not None else None
            translations.append((id, src, ref, hypos[i]))
            ct += 1
        # print("sample batch size:", ct)

        model.train()

        # MLE loss
        mle_net_output = model(**sample['net_input'])
        mle_lprobs = model.get_normalized_probs(mle_net_output, log_probs=True)
        mle_lprobs = mle_lprobs.view(-1, mle_lprobs.size(-1))
        mle_target = model.get_targets(sample, mle_net_output).view(-1)
        mle_loss = F.nll_loss(mle_lprobs,
                              mle_target,
                              size_average=False,
                              ignore_index=self.padding_idx,
                              reduce=reduce)
        mle_tokens = sample['ntokens']
        avg_mle_loss = mle_loss / mle_tokens
        print('avg_mle_loss:', avg_mle_loss)
        # RL loss
        batch_rl_loss = 0
        batch_tokens = 0
        sample_ind = 0
        for sample_id, src_tokens, tgt_tokens, hypos in translations:
            # calculate bleu
            sample_ind += 1
            rewards = torch.Tensor(sample_beam).float().cuda()
            logprobs = torch.Tensor(sample_beam).float().cuda()
            for i in range(sample_beam):
                hypo = hypos[i]
                trans_tokens = hypo['tokens']
                rewards[i] = self.compute_gleu(tgt_tokens.cpu(),
                                               trans_tokens.cpu(),
                                               max_order=self.args.max_order,
                                               gram=self.args.gram).cuda()
                # one_sample loss calculation
                tgt_input_tokens = trans_tokens.new(
                    trans_tokens.shape).fill_(0)
                assert trans_tokens[-1] == eos_idx
                tgt_input_tokens[0] = eos_idx
                tgt_input_tokens[1:] = trans_tokens[:-1]
                train_sample = {
                    'net_input': {
                        'src_tokens':
                        src_tokens.view(1, -1),
                        'src_lengths':
                        torch.LongTensor([src_tokens.numel()]),
                        'prev_output_tokens':
                        tgt_input_tokens.view(1, -1),
                    },
                    'target': trans_tokens.view(1, -1)
                }
                train_sample = utils.move_to_cuda(train_sample)
                net_output = model(**train_sample['net_input'])
                lprobs = model.get_normalized_probs(net_output, log_probs=True)
                lprobs = lprobs.view(-1, lprobs.size(-1))
                target = model.get_targets(train_sample,
                                           net_output).view(-1, 1)
                non_pad_mask = target.ne(tgt_dict.pad())
                lprob = -lprobs.gather(dim=-1, index=target)[non_pad_mask]
                logprobs[i] = torch.sum(lprob)
                ntokens = len(train_sample['target'])
                batch_tokens += ntokens
            rl_loss = torch.sum(logprobs *
                                (rewards - rewards.mean()))  # one sample loss
            batch_rl_loss += rl_loss

        avg_rl_loss = batch_rl_loss / batch_tokens
        print('avg_rl_loss:', avg_rl_loss)
        if self.args.mle_weight:
            assert self.args.rl_weight
            total_loss = self.args.mle_weight * avg_mle_loss + self.args.rl_weight * avg_rl_loss
            total_tokens = batch_tokens + mle_tokens
        else:
            total_loss = avg_rl_loss
            total_tokens = batch_tokens
        logging_output = {
            'loss': utils.item(total_loss.data),
            'ntokens': total_tokens,
            'sample_size': total_tokens,
        }
        print('total: ', total_loss)
        return total_loss, total_tokens, logging_output
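The RL term above is REINFORCE with the per-sentence mean reward as baseline: each sampled hypothesis' summed negative log-probability is weighted by how much its reward beats the average over the sample beam. A compact sketch of that loss:

import torch

def reinforce_loss(neg_logprobs, rewards):
    # neg_logprobs, rewards: shape (sample_beam,); minimizing this raises the
    # probability of hypotheses whose reward exceeds the per-sample mean.
    return torch.sum(neg_logprobs * (rewards - rewards.mean()))

neg_lp = torch.tensor([4.0, 6.0, 5.0])   # summed -log p(y | x) per sampled hypothesis
r = torch.tensor([0.9, 0.2, 0.5])        # e.g. sentence-level GLEU rewards
print(reinforce_loss(neg_lp, r))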
Example #34
def main():
    parser = options.get_parser('Generation')
    parser.add_argument('--path', metavar='FILE', required=True, action='append',
                        help='path(s) to model file(s)')
    dataset_args = options.add_dataset_args(parser)
    dataset_args.add_argument('-i', '--interactive', action='store_true',
                              help='generate translations in interactive mode')
    dataset_args.add_argument('--batch-size', default=32, type=int, metavar='N',
                              help='batch size')
    dataset_args.add_argument('--gen-subset', default='test', metavar='SPLIT',
                              help='data subset to generate (train, valid, test)')
    options.add_generation_args(parser)

    args = parser.parse_args()
    print(args)

    if args.no_progress_bar:
        progress_bar.enabled = False
    use_cuda = torch.cuda.is_available() and not args.cpu

    # Load model and dataset
    print('| loading model(s) from {}'.format(', '.join(args.path)))
    models, dataset = utils.load_ensemble_for_inference(args.path, args.data)

    print('| [{}] dictionary: {} types'.format(dataset.src, len(dataset.src_dict)))
    print('| [{}] dictionary: {} types'.format(dataset.dst, len(dataset.dst_dict)))
    if not args.interactive:
        print('| {} {} {} examples'.format(args.data, args.gen_subset, len(dataset.splits[args.gen_subset])))

    # Optimize model for generation
    for model in models:
        model.make_generation_fast_(not args.no_beamable_mm)

    # Initialize generator
    translator = SequenceGenerator(models, dataset.dst_dict, beam_size=args.beam,
                                   stop_early=(not args.no_early_stop),
                                   normalize_scores=(not args.unnormalized),
                                   len_penalty=args.lenpen)
    align_dict = {}
    if args.unk_replace_dict != '':
        assert args.interactive, "Unknown word replacement requires access to the original source and is " \
                                 "only supported in interactive mode"
        with open(args.unk_replace_dict, 'r') as f:
            for line in f:
                l = line.split()
                align_dict[l[0]] = l[1]

    def replace_unk(hypo_str, align_str, src, unk):
        hypo_tokens = hypo_str.split()
        src_tokens = tokenizer.tokenize_line(src)
        align_idx = [int(i) for i in align_str.split()]
        for i, ht in enumerate(hypo_tokens):
            if ht == unk:
                src_token = src_tokens[align_idx[i]]
                if src_token in align_dict:
                    hypo_tokens[i] = align_dict[src_token]
                else:
                    hypo_tokens[i] = src_token
        return ' '.join(hypo_tokens)

    if use_cuda:
        translator.cuda()

    bpe_symbol = '@@ ' if args.remove_bpe else None
    def display_hypotheses(id, src, orig, ref, hypos):
        id_str = '' if id is None else '-{}'.format(id)
        src_str = to_sentence(dataset.src_dict, src, bpe_symbol)
        print('S{}\t{}'.format(id_str, src_str))
        if orig is not None:
            print('O{}\t{}'.format(id_str, orig.strip()))
        if ref is not None:
            print('T{}\t{}'.format(id_str, to_sentence(dataset.dst_dict, ref, bpe_symbol, ref_unk=True)))
        for hypo in hypos:
            hypo_str = to_sentence(dataset.dst_dict, hypo['tokens'], bpe_symbol)
            align_str = ' '.join(map(str, hypo['alignment']))
            if args.unk_replace_dict != '':
                hypo_str = replace_unk(hypo_str, align_str, orig, unk_symbol(dataset.dst_dict))
            print('H{}\t{}\t{}'.format(
                id_str, hypo['score'], hypo_str))
            print('A{}\t{}'.format(id_str, align_str))

    if args.interactive:
        for line in sys.stdin:
            tokens = tokenizer.Tokenizer.tokenize(line, dataset.src_dict, add_if_not_exist=False).long()
            start = dataset.src_dict.pad() + 1
            positions = torch.arange(start, start + len(tokens)).type_as(tokens)
            if use_cuda:
                positions = positions.cuda()
                tokens = tokens.cuda()
            translations = translator.generate(Variable(tokens.view(1, -1)), Variable(positions.view(1, -1)))
            hypos = translations[0]
            display_hypotheses(None, tokens, line, None, hypos[:min(len(hypos), args.nbest)])

    else:
        def maybe_remove_bpe(tokens):
            """Helper for removing BPE symbols from a hypothesis."""
            if not args.remove_bpe:
                return tokens
            assert (tokens == dataset.dst_dict.pad()).sum() == 0
            hypo_minus_bpe = to_sentence(dataset.dst_dict, tokens, bpe_symbol)
            return tokenizer.Tokenizer.tokenize(hypo_minus_bpe, dataset.dst_dict, add_if_not_exist=True)

        # Generate and compute BLEU score
        scorer = bleu.Scorer(dataset.dst_dict.pad(), dataset.dst_dict.eos(), dataset.dst_dict.unk())
        itr = dataset.dataloader(args.gen_subset, batch_size=args.batch_size, max_positions=args.max_positions)
        num_sentences = 0
        with progress_bar(itr, smoothing=0, leave=False) as t:
            wps_meter = TimeMeter()
            gen_timer = StopwatchMeter()
            translations = translator.generate_batched_itr(
                t, maxlen_a=args.max_len_a, maxlen_b=args.max_len_b,
                cuda_device=0 if use_cuda else None, timer=gen_timer)
            for id, src, ref, hypos in translations:
                ref = ref.int().cpu()
                top_hypo = hypos[0]['tokens'].int().cpu()
                scorer.add(maybe_remove_bpe(ref), maybe_remove_bpe(top_hypo))
                display_hypotheses(id, src, None, ref, hypos[:min(len(hypos), args.nbest)])

                wps_meter.update(src.size(0))
                t.set_postfix(wps='{:5d}'.format(round(wps_meter.avg)))
                num_sentences += 1

        print('| Translated {} sentences ({} tokens) in {:.1f}s ({:.2f} tokens/s)'.format(
            num_sentences, gen_timer.n, gen_timer.sum, 1. / gen_timer.avg))
        print('| Generate {} with beam={}: {}'.format(args.gen_subset, args.beam, scorer.result_string()))
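The --remove-bpe handling joins "@@ "-marked subword pieces back into whole words before scoring. A tiny sketch of that post-processing (a hedged stand-in for the library's BPE stripping, not its exact code):

def remove_bpe(sentence, bpe_symbol="@@ "):
    return (sentence + " ").replace(bpe_symbol, "").rstrip()

print(remove_bpe("die Kat@@ ze sitzt"))   # "die Katze sitzt"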
Example #35
class Handler(BaseDynaHandler):
    """Use Fairseq model for translation.
    To use this handler, download one of the Flores pretrained model:

    615M parameters:
        https://dl.fbaipublicfiles.com/flores101/pretrained_models/flores101_mm100_615M.tar.gz
    175M parameters:
        https://dl.fbaipublicfiles.com/flores101/pretrained_models/flores101_mm100_175M.tar.gz

    and extract the files next to this one.
    Notably there should be a "dict.txt" and a "sentencepiece.bpe.model".
    """
    def initialize(self, context):
        """
        load model and extra files.
        """
        logger.info(
            f"Will initialize with system_properties: {context.system_properties}"
        )
        model_pt_path, model_file_dir, device = self._handler_initialize(
            context)
        config = json.loads(
            (Path(model_file_dir) / "model_generation.json").read_text())
        self.device = device

        translation_cfg = TranslationConfig()
        self.vocab = TranslationTask.load_dictionary("dict.txt")

        self.spm = sentencepiece.SentencePieceProcessor()
        self.spm.Load("sentencepiece.bpe.model")
        logger.info("Loaded sentencepiece.bpe.model")

        if config.get("dummy", False):
            self.sequence_generator = FakeGenerator()
            logger.warning("Will use a FakeGenerator model, only testing BPE")
        else:
            task = TranslationTask(translation_cfg, self.vocab, self.vocab)
            [model], cfg = fairseq.checkpoint_utils.load_model_ensemble(
                [model_pt_path], task=task)
            model.eval().to(self.device)
            logger.info(
                f"Loaded model from {model_pt_path} to device {self.device}")
            logger.info(
                f"Will use the following config: {json.dumps(config, indent=4)}"
            )
            self.sequence_generator = SequenceGenerator(
                [model],
                tgt_dict=self.vocab,
                beam_size=config.get("beam_size", 1),
                max_len_a=config.get("max_len_a", 1.3),
                max_len_b=config.get("max_len_b", 5),
                min_len=config.get("min_len", 5),
            )

        self.taskIO = TaskIO()
        self.initialized = True

    def lang_token(self, lang: str) -> int:
        """Converts the ISO 639-3 language code to MM100 language codes."""
        simple_lang = ISO2M100[lang]
        token = self.vocab.index(f"__{simple_lang}__")
        assert token != self.vocab.unk(
        ), f"Unknown language '{lang}' ({simple_lang})"
        return token

    def tokenize(self, line: str) -> list:
        words = self.spm.EncodeAsPieces(line.strip())
        tokens = [self.vocab.index(word) for word in words]
        return tokens

    def preprocess_one(self, sample) -> dict:
        """
        preprocess data into a format that the model can do inference on
        """
        # TODO: this doesn't seem to produce good results. wrong EOS / BOS ?
        tokens = self.tokenize(sample["sourceText"])
        src_token = self.lang_token(sample["sourceLanguage"])
        tgt_token = self.lang_token(sample["targetLanguage"])
        return {
            "src_tokens": [src_token] + tokens + [self.vocab.eos()],
            "src_length": len(tokens) + 1,
            "tgt_token": tgt_token,
        }

    def preprocess(self, samples) -> dict:
        samples = [self.preprocess_one(s) for s in samples]
        prefix_tokens = torch.tensor([[s["tgt_token"]] for s in samples])
        src_lengths = torch.tensor([s["src_length"] for s in samples])
        src_tokens = data_utils.collate_tokens(
            [torch.tensor(s["src_tokens"]) for s in samples],
            self.vocab.pad(),
            self.vocab.eos(),
        )
        return {
            "nsentences": len(samples),
            "ntokens": src_lengths.sum().item(),
            "net_input": {
                "src_tokens": src_tokens.to(self.device),
                "src_lengths": src_lengths.to(self.device),
            },
            "prefix_tokens": prefix_tokens.to(self.device),
        }

    def strip_pad(self, sentence):
        assert sentence.ndim == 1
        return sentence[sentence.ne(self.vocab.pad())]

    @torch.no_grad()
    def inference(self, input_data: dict) -> list:
        generated = self.sequence_generator.generate(
            models=[],
            sample=input_data,
            prefix_tokens=input_data["prefix_tokens"],
        )
        # `generate` returns a list of samples
        # with several hypothesis per sample
        # and a dict per hypothesis.
        # We also need to strip the language token.
        return [hypos[0]["tokens"][1:] for hypos in generated]

    def postprocess(self, inference_output, samples: list) -> list:
        """
        post process inference output into a response.
        response should be a list of json
        the response format will need to pass the validation in
        ```
        dynalab.tasks.flores_small1.TaskIO().verify_response(response)
        ```
        """
        translations = [
            self.vocab.string(self.strip_pad(sentence), "sentencepiece")
            for sentence in inference_output
        ]
        return [
            # Signing required by dynabench, don't remove.
            self.taskIO.sign_response(
                {
                    "id": sample["uid"],
                    "translatedText": translation
                },
                sample,
            ) for translation, sample in zip(translations, samples)
        ]
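
# Minimal usage sketch (illustrative only; assumes a handler instance that the
# serving framework has already constructed and initialized):
#
#   samples = [{
#       "uid": "sample-1",
#       "sourceText": "Hello world.",
#       "sourceLanguage": "eng",
#       "targetLanguage": "fra",
#   }]
#   batch = handler.preprocess(samples)
#   hypos = handler.inference(batch)
#   response = handler.postprocess(hypos, samples)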
Beispiel #36
0
class FairseqAgent(TorchAgent):
    """Generic wrapper around fairseq for use in ParlAI"""

    metrics = {}

    # TODO: merge with TorchAgent.add_cmdline_args
    @staticmethod
    def add_cmdline_args(argparser):
        """Add command-line arguments specifically for this agent."""
        # first we need to add the general torch agent operations
        TorchAgent.add_cmdline_args(argparser)

        agent = argparser.add_argument_group('Fairseq Arguments')
        agent.add_argument(
            '--seed',
            default=1,
            type=int,
            metavar='N',
            help='pseudo random number generator seed'
        )
        agent.add_argument(
            '--skip-generation',
            default=False,
            type=bool,
            metavar='BOOL',
            help='Skips test time beam search. Much faster if you only need PPL',
        )

        # Dictionary construction stuff. Using the subclass in case we end up
        # needing any fairseq specific things
        _FairseqDictionary.add_cmdline_args(argparser)

        # Optimization and learning rate schedule specific arguments
        options.add_optimization_args(argparser)
        known_args = argparser.parse_known_args(nohelp=True)[0]
        if hasattr(known_args, "optimizer"):
            optimizer = known_args.optimizer
            opt_group = argparser.add_argument_group(
                '{} optimizer arguments'.format(optimizer)
            )
            optim.OPTIMIZER_REGISTRY[optimizer].add_args(opt_group)
        if hasattr(known_args, "lr_scheduler"):
            lr_scheduler = known_args.lr_scheduler
            lr_group = argparser.add_argument_group(
                '{} scheduler arguments'.format(lr_scheduler)
            )
            optim.lr_scheduler.LR_SCHEDULER_REGISTRY[lr_scheduler].add_args(lr_group)

        # Generation arguments
        options.add_generation_args(argparser)

        # We need to find out the fairseq model-specific options, so grab the
        # architecture stuff and look up its options
        arch_group = options.add_model_args(argparser)
        # Fairseq marks the arch flag as required, but it may be specified
        # by a saved model cache, so we do some weird stuff to undo that
        for a in arch_group._actions:
            if a.dest == "arch":
                a.required = False
                a.default = None
                break
        known_args = argparser.parse_known_args(nohelp=True)[0]
        if hasattr(known_args, "arch") and known_args.arch is not None:
            arch = known_args.arch
            arch_group = argparser.add_argument_group(
                "{} architecture arguments".format(arch)
            )
            models.ARCH_MODEL_REGISTRY[arch].add_args(arch_group)

        # Override a few defaults from within fairseq to more sensible defaults
        argparser.set_defaults(
            clip_norm=0.1,
            adam_betas="(0.9,0.98)"
        )

    def __init__(self, opt, shared=None):
        # In general use a basic TorchAgent wherever possible
        super().__init__(opt, shared)
        if not shared:
            # this is not a shared instance of this class, so do full initialization

            # fairseq expects options to be in argparse format, instead of a dict
            # We also need to do some argument postprocessing and whatnot
            self.args, self.opt = _fairseq_opt_wrapper(opt)

            # seed the RNG
            torch.manual_seed(self.args.seed)

            # Just some identifying info
            self.id = "fairseq:{}".format(self.args.arch)

            # construct dictionaries for parlai frontend and fairseq backend
            self.dict = _FairseqDictionary(self.opt)

            # We need a placeholder task for fairseq
            self.task = _ParlaiTask(self.dict)

            # actually construct the model and generator
            model_class = models.ARCH_MODEL_REGISTRY[self.args.arch]
            self.model = model_class.build_model(self.args, self.task)
            self.generator = SequenceGenerator(
                [self.model],
                tgt_dict=self.dict,
                beam_size=self.args.beam,
                stop_early=(not self.args.no_early_stop),
                normalize_scores=(not self.args.unnormalized),
                len_penalty=self.args.lenpen,
            )
            # set up the grader and the trainer
            # TODO: maybe support label smoothing here
            self.criterion = CrossEntropyCriterion(self.args, self.task)

            if self.args.fp16:
                self.trainer = fp16_trainer.FP16Trainer(
                    self.args, self.task, self.model, self.criterion
                )
            else:
                # TODO: we might choose to add a --no-fp16 opt in the future to
                # explicitly disable fp16 instead
                if torch.cuda.is_available() and torch.cuda.get_device_capability(0)[0] >= 7:
                    print("Heads up: using --fp16 could be a lot faster!")
                self.trainer = trainer.Trainer(
                    self.args, self.task, self.model, self.criterion
                )

            # if the model already existed, let's preload it and the trainer
            if self.opt.get('model_file') and os.path.isfile(self.opt['model_file']):
                print('Loading existing model params from ' + self.opt['model_file'])
                self.load(self.opt.get('model_file'))

            # move things to the GPU if possible
            if self.use_cuda:
                self.model = self.model.cuda()
                self.generator = self.generator.cuda()

        # Start things off clean
        self.reset()

    def _check_opts_unchanged(self, saved_opts, current_opts):
        """Verify that critical options do not differ in command line vs saved model"""
        for k in NON_OVERRIDABLE_ARGS:
            if k not in saved_opts or k not in current_opts:
                # if it's not an option needed by this fairseq model, don't stress
                continue
            if saved_opts[k] != current_opts[k]:
                raise ValueError(
                    '{} cannot be overridden when --model-file is specified'.format(k)
                )

    def save(self, path):
        """Save using fairseq's checkpointing."""
        if not path:
            return
        self.trainer.save_checkpoint(path, {'opt': self.opt, 'epoch': 0})
        # Parlai expects options to also be saved
        with open(path + ".opt", 'wb') as handle:
            # overridden options shouldn't be stored, only the main ones
            if 'override' in self.opt:
                del self.opt['override']
            pickle.dump(self.opt, handle, protocol=pickle.HIGHEST_PROTOCOL)

    def load(self, path):
        """Load using fairseq's checkpointing."""
        old_options = self.trainer.load_checkpoint(path)
        self._check_opts_unchanged(old_options, self.opt)

    def shutdown(self):
        if not hasattr(self, 'trainer'):
            # looks like this is a "fake" model that isn't actually used for batch_act.
            # we don't need to save this one.
            return
        super().shutdown()

    def reset(self):
        """Reset observation and episode_done."""
        super().reset()
        self.reset_metrics()

    def batch_act(self, observations):
        bsz = len(observations)
        # initialize a table of replies with this agent's id
        batch_reply = [{"id": self.getID()} for _ in range(bsz)]

        # torchagent boilerplate
        self.is_training = any(["labels" in obs for obs in observations])
        vec_obs = [self.vectorize(obs) for obs in observations]
        xs, _, ys, _, valid_inds = self.map_valid(vec_obs)
        if xs is None:
            return batch_reply

        # here begins fairseq specific stuff
        samples = self._make_sample(xs, ys)

        if self.is_training:
            self.model.train()
            self.trainer.train_step(samples)
        else:
            # grade the evaluation label
            self.model.eval()
            if ys is not None:
                # Interactive mode won't have a gold label
                self.trainer.valid_step(samples)

            # Grade each of the candidate sequences
            # TODO: grade everything in observations[i]['label_candidates']

            # Next generate freely to create our response
            if self.args.skip_generation:
                # skip the generation step
                for i in valid_inds:
                    batch_reply[i]["text"] = ""
            else:
                # actually do the generation
                for i, response in zip(valid_inds, self._generate(samples)):
                    batch_reply[i]["text"] = response

        return batch_reply

    def _generate(self, samples):
        src_tokens = samples["net_input"]["src_tokens"]
        src_lengths = samples["net_input"]["src_lengths"]
        gens = self.generator.generate(src_tokens, src_lengths, maxlen=64)
        responses = []
        for i in range(len(src_tokens)):
            beams = gens[i]
            selected = max(beams, key=lambda x: x["score"])
            tokens = selected["tokens"]
            start = 0
            end = -1
            for j, t in enumerate(tokens):
                t = t.item()
                if t == self.dict.bos_index:
                    # don't include <s> token
                    start = j + 1
                    continue
                if t == self.dict.eos_index:
                    # stop (and don't include) </s> token
                    end = j
                    break
            responses.append(self.dict.vec2txt(tokens[start:end]))
        return responses

    def report(self):
        # if we haven't initialized yet, just return a dummy object
        if not hasattr(self, "trainer"):
            return {}

        # These are the metrics we'll pass up the way, and their new names
        train_metrics = {"train_loss", "ups", "wps", "gnorm", "clip"}
        valid_metrics = {"valid_loss"}

        metrics = train_metrics if self.is_training else valid_metrics

        m = {k: self.trainer.meters[k].avg for k in metrics}

        # additionally output perplexity. note that fairseq models use base 2
        # in cross_entropy:
        # github.com/pytorch/fairseq/blob/master/fairseq/criterions/cross_entropy.py#L55
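        # e.g. a loss of 3.0 bits per token corresponds to a perplexity
        # of 2 ** 3.0 = 8.0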
        if "train_loss" in m:
            m["train_ppl"] = np.exp2(m["train_loss"])
        if "valid_loss" in m:
            m["ppl"] = np.exp2(m["valid_loss"])

        for k, v in m.items():
            # clean up: rounds to sigfigs and converts tensors to floats
            m[k] = round_sigfigs(v, 4)

        return m

    def reset_metrics(self):
        if not hasattr(self, "trainer"):
            # We haven't initialized the trainer yet, so we don't have any metrics
            return
        # We need to reset everything
        for k in self.trainer.meters:
            self.trainer.meters[k].reset()

    def receive_metrics(self, metrics_dict):
        """Used to update lr scheduler."""
        self.trainer.lr_step(-1, metrics_dict["valid_loss"])

    # Helper functions
    def _seq_length(self, xs):
        """Computes length of the sequence (non-padded size)"""
        return xs.ne(self.dict.pad_index).long().sum(dim=-1)

    def _right_shifted_ys(self, ys):
        """Replaces first token with EOS and shifts the remaining tokens right one."""
        result = torch.LongTensor(ys.size())
        result[:, 0] = self.dict.eos_index
        result[:, 1:] = ys[:, :-1]
        return result

    def _make_sample(self, xs, ys):
        """Generates a sample object that Fairseq expects."""
        # add extra info to samples
        # TODO: should the right/left padding thing be in torch agent?
        repadded = convert_padding_direction(xs, self.dict.pad(), right_to_left=True)
        sample = {}
        sample["net_input"] = {
            "src_tokens": repadded,
            "src_lengths": self._seq_length(xs),
        }
        if ys is not None:
            sample["target"] = ys
            sample["ntokens"] = sum(self._seq_length(ys)).item()
            sample["net_input"]["prev_output_tokens"] = self._right_shifted_ys(ys)
        return sample
Beispiel #37
0
class FairseqAgent(TorchAgent):
    """Generic wrapper around fairseq for use in ParlAI"""

    DEFAULT_OPTIONS = {
        "adam_betas": "(0.9,0.98)",
        "optimizer": "adam",
        "clip_norm": 0.1,
    }

    metrics = {}

    @classmethod
    def add_cmdline_args(cls, argparser):
        """Add command-line arguments specifically for this agent."""
        # first we need to add the general torch agent operations
        TorchAgent.add_cmdline_args(argparser)

        agent = argparser.add_argument_group('Fairseq Arguments')
        agent.add_argument('--fp16',
                           default=False,
                           type=bool,
                           help='Use fp16 training')
        agent.add_argument('--seed',
                           default=1,
                           type=int,
                           metavar='N',
                           help='pseudo random number generator seed')
        agent.add_argument(
            '--skip-generation',
            default=False,
            type=bool,
            metavar='BOOL',
            help=
            'Skips test time beam search. Much faster if you only need PPL',
        )

        # Dictionary construction stuff. Using the subclass in case we end up
        # needing any fairseq specific things
        cls.dictionary_class().add_cmdline_args(argparser)

        # Check subargs for generation, optimizers, criterions, archs, etc
        options.add_generation_args(argparser)
        options.add_optimization_args(argparser)

        # make sure we set defaults according to the model before parsing
        argparser.set_defaults(**cls.DEFAULT_OPTIONS)
        known_args = argparser.parse_known_args(nohelp=True)[0]

        if hasattr(known_args, "optimizer"):
            optimizer = known_args.optimizer
            opt_group = argparser.add_argument_group(
                '{} optimizer arguments'.format(optimizer))
            optim.OPTIMIZER_REGISTRY[optimizer].add_args(opt_group)
        if hasattr(known_args, "lr_scheduler"):
            lr_scheduler = known_args.lr_scheduler
            lr_group = argparser.add_argument_group(
                '{} scheduler arguments'.format(lr_scheduler))
            optim.lr_scheduler.LR_SCHEDULER_REGISTRY[lr_scheduler].add_args(
                lr_group)
        # We need to find out the fairseq model-specific options, so grab the
        # architecture stuff and look up its options
        arch_group = options.add_model_args(argparser)
        # Fairseq marks the arch flag as required, but it may be specified
        # by a saved model cache, so we do some weird stuff to undo that
        for a in arch_group._actions:
            if a.dest == "arch":
                a.required = False
                a.default = None
                break

        # make sure we set defaults according to parlai model before parsing
        argparser.set_defaults(**cls.DEFAULT_OPTIONS)
        known_args = argparser.parse_known_args(nohelp=True)[0]

        if hasattr(known_args, "arch") and known_args.arch is not None:
            arch = known_args.arch
            arch_group = argparser.add_argument_group(
                "{} architecture arguments".format(arch))
            models.ARCH_MODEL_REGISTRY[arch].add_args(arch_group)

        if hasattr(known_args, "criterion"):
            crit_group = argparser.add_argument_group(
                '{} criterion arguments'.format(known_args.criterion))
            criterions.CRITERION_REGISTRY[known_args.criterion].add_args(
                crit_group)

        # As one final check, let's make sure we set defaults correctly
        argparser.set_defaults(**cls.DEFAULT_OPTIONS)

    @staticmethod
    def dictionary_class():
        # Force use of the Fairseq Dictionary
        return _FairseqDictionary

    def __init__(self, opt, shared=None):
        # In general use a basic TorchAgent wherever possible
        super().__init__(opt, shared)
        if not shared:
            # this is not a shared instance of this class, so do full initialization

            # check early if we're going to be loading the model from a checkpoint
            model_file_exists = (self.opt.get('model_file')
                                 and os.path.isfile(self.opt['model_file']))

            # fairseq expects options to be in argparse format, instead of a dict
            # We also need to do some argument postprocessing and whatnot
            # We'll skip pretrained embeddings if we're going to override them with
            # a model checkpoint anyway
            self.args, self.opt = _fairseq_opt_wrapper(opt, model_file_exists)

            # seed the RNG
            torch.manual_seed(self.args.seed)

            # Just some identifying info
            self.id = "fairseq:{}".format(self.args.arch)

            # We need a placeholder task for fairseq
            self.task = _ParlaiTask(self.dict)

            # actually construct the model and generator
            self.model = self.build_model()

            # Construct the generator and scorer
            self.generator = SequenceGenerator(
                [self.model],
                tgt_dict=self.dict,
                beam_size=self.args.beam,
                stop_early=(not self.args.no_early_stop),
                normalize_scores=(not self.args.unnormalized),
                len_penalty=self.args.lenpen,
                unk_penalty=self.args.unkpen,
                sampling=self.args.sampling,
                sampling_topk=self.args.sampling_topk,
                sampling_temperature=self.args.sampling_temperature,
            )
            self.scorer = SequenceScorer([self.model], self.dict)

            # set up the grader and the trainer
            self.criterion = criterions.build_criterion(self.args, self.task)

            if getattr(self.args, 'fp16', None):
                self.trainer = fp16_trainer.FP16Trainer(
                    self.args, self.task, self.model, self.criterion)
            else:
                # TODO: we might choose to add a --no-fp16 opt in the future to
                # explicitly disable fp16 instead
                if torch.cuda.is_available() and torch.cuda.get_device_capability(0)[0] >= 7:
                    print("Heads up: using --fp16 could be a lot faster!")
                self.trainer = trainer.Trainer(self.args, self.task,
                                               self.model, self.criterion)

            # if the model already existed, let's preload it and the trainer
            if model_file_exists:
                print('Loading existing model params from ' +
                      self.opt['model_file'])
                self.load(self.opt.get('model_file'))

            # move things to the GPU if possible
            if self.use_cuda:
                self.model = self.model.cuda()
                self.generator = self.generator.cuda()
        else:
            self.model = shared['model']
            self.trainer = shared['trainer']
            self.generator = shared['generator']
            self.dict = shared['dict']
            self.args = shared['args']

        # Start things off clean
        self.reset()

    def _check_opts_unchanged(self, saved_opts, current_opts):
        """Verify that critical options do not differ in command line vs saved model"""
        for k in NON_OVERRIDABLE_ARGS:
            if k not in saved_opts or k not in current_opts:
                # if it's not an option needed by this fairseq model, don't stress
                continue
            if saved_opts[k] != current_opts[k]:
                raise ValueError(
                    '{} cannot be overridden when --model-file is specified'.
                    format(k))

    def build_model(self):
        """
        Construct the actual Fairseq model. Default implementation is to use
        Fairseq's arch builder, but this method may be overridden to build custom
        models.
        """
        model_class = models.ARCH_MODEL_REGISTRY[self.args.arch]
        return model_class.build_model(self.args, self.task)

    def share(self):
        shared = super().share()
        shared['model'] = self.model
        shared['trainer'] = self.trainer
        shared['generator'] = self.generator
        shared['dict'] = self.dict
        shared['args'] = self.args
        return shared

    def save(self, path):
        """Save using fairseq's checkpointing."""
        if not path:
            return
        self.trainer.save_checkpoint(path, {'opt': self.opt, 'epoch': 0})
        # Parlai expects options to also be saved
        with open(path + ".opt", 'wb') as handle:
            # overridden options shouldn't be stored, only the main ones
            if 'override' in self.opt:
                del self.opt['override']
            pickle.dump(self.opt, handle, protocol=pickle.HIGHEST_PROTOCOL)

    def load(self, path):
        """Load using fairseq's checkpointing."""
        old_options = self.trainer.load_checkpoint(path)
        self._check_opts_unchanged(old_options, self.opt)

    def shutdown(self):
        if not hasattr(self, 'trainer'):
            # looks like this is a "fake" model that isn't actually used for batch_act.
            # we don't need to save this one.
            return
        super().shutdown()

    def reset(self):
        """Reset observation and episode_done."""
        super().reset()
        self.reset_metrics()

    def batchify(self, obs_batch):
        """
        Override parent batchify to set requirements for fairseq.

        Fairseq depends on sorted batch inputs for a call to rnn.pad_packed_sequence.
        Fairseq models cannot handle zero-length sentences.
        """
        return super().batchify(obs_batch,
                                sort=True,
                                is_valid=_is_nonempty_observation)

    def train_step(self, batch):
        """Process batch of inputs and targets and train on them.

        :param batch: parlai.core.torch_agent.Batch, contains tensorized
                      version of observations.
        """
        if batch.text_vec is None:
            return
        self.is_training = True
        samples = self._make_sample(batch.text_vec, batch.label_vec)
        self.model.train()
        self.trainer.train_step(samples)

    def eval_step(self, batch):
        """Process batch of inputs.

        If the batch includes labels, calculate validation metrics as well.
        If --skip-generation is not set, return a prediction for each input.

        :param batch: parlai.core.torch_agent.Batch, contains tensorized
                      version of observations.
        """
        if batch.text_vec is None:
            return
        self.is_training = False
        samples = self._make_sample(batch.text_vec, batch.label_vec)
        self.model.eval()
        if batch.label_vec is not None:
            # Interactive mode won't have a gold label
            self.trainer.valid_step(samples)

        # Output placeholders
        reranked_cands = None
        generated_output = None

        # Grade each of the candidate sequences
        if batch.candidate_vecs is not None:
            bsz = len(batch.text_vec)
            reranked_cands = []
            # score the candidates for each item in the batch separately, so that
            # we can support variable number of candidates
            for i in range(bsz):
                cands = batch.candidate_vecs[i]
                if not cands:
                    reranked_cands.append(None)
                    continue
                ncand = len(cands)
                # repeat the input many times
                xs = batch.text_vec[i].unsqueeze(0).expand(ncand, -1)
                # some models crash if there's leading padding on every example
                xs = xs[:, :batch.text_lengths[i]]
                # and appropriately pack the outputs
                ys, _ = padded_tensor(cands, self.NULL_IDX, self.use_cuda)
                s = self._make_sample(xs, ys)
                # perform the actual grading, extract the scores
                scored = list(
                    self.scorer.score_batched_itr([s], cuda=self.use_cuda))
                scores = [s[3][0]['score'].item() for s in scored]
                # intentional hanging comma here; argsort returns a list
                ranked, = argsort(scores, batch.candidates[i], descending=True)
                reranked_cands.append(ranked)

        # Next generate freely to create our response
        if not self.args.skip_generation:
            generated_output = self._generate(samples)
        elif reranked_cands:
            # we're skipping generation, but we are still grading candidates,
            # so output the highest ranked candidate for each example.
            # An example with zero candidates has nothing to rank, so its
            # None placeholder is passed through unchanged.
            generated_output = [
                ranked[0] if ranked else None for ranked in reranked_cands
            ]
        else:
            # no output at all
            pass

        return Output(generated_output, reranked_cands)

    def _generate(self, samples):
        src_tokens = samples["net_input"]["src_tokens"]
        src_lengths = samples["net_input"]["src_lengths"]
        gens = self.generator.generate(src_tokens, src_lengths, maxlen=64)
        responses = []
        for i in range(len(src_tokens)):
            beams = gens[i]
            selected = max(beams, key=lambda x: x["score"])
            tokens = selected["tokens"]
            start = 0
            end = -1
            for j, t in enumerate(tokens):
                t = t.item()
                if t == self.dict.bos_index:
                    # don't include <s> token
                    start = j + 1
                    continue
                if t == self.dict.eos_index:
                    # stop (and don't include) </s> token
                    end = j
                    break
            responses.append(self.dict.vec2txt(tokens[start:end]))
        return responses

    def report(self):
        """Return metrics calculated by the model."""
        # if we haven't initialized yet, just return a dummy object
        if not hasattr(self, "trainer"):
            return {}

        # These are the metrics we'll pass up the way, and their new names
        train_metrics = {"train_loss", "ups", "wps", "gnorm", "clip"}
        valid_metrics = {"valid_loss"}

        metrics = train_metrics if self.is_training else valid_metrics

        m = {k: self.trainer.meters[k].avg for k in metrics}

        # additionally output perplexity. note that fairseq models use base 2
        # in cross_entropy:
        # github.com/pytorch/fairseq/blob/master/fairseq/criterions/cross_entropy.py#L55
        if "train_loss" in m:
            m["train_ppl"] = np.exp2(m["train_loss"])
        if "valid_loss" in m:
            m["ppl"] = np.exp2(m["valid_loss"])

        for k, v in m.items():
            # clean up: rounds to sigfigs and converts tensors to floats
            m[k] = round_sigfigs(v, 4)

        return m

    def reset_metrics(self):
        """Reset metrics calculated by the model back to zero."""
        if not hasattr(self, "trainer"):
            # We haven't set up the trainer yet, so we don't have any metrics
            return
        # We need to reset everything
        for k in self.trainer.meters:
            self.trainer.meters[k].reset()

    def receive_metrics(self, metrics_dict):
        """Update lr scheduler with validation loss."""
        self.trainer.lr_step(-1, metrics_dict["valid_loss"])

    # Helper functions
    def _seq_length(self, xs):
        """Compute length of the sequence (non-padded size)."""
        return xs.ne(self.dict.pad_index).long().sum(dim=-1)

    def _right_shifted_ys(self, ys):
        """Replace first token with EOS and shift remaining tokens right 1."""
        result = torch.LongTensor(ys.size())
        result[:, 0] = self.dict.eos_index
        result[:, 1:] = ys[:, :-1]
        return result

    def _make_sample(self, xs, ys):
        """Generate a sample object that Fairseq expects."""
        # add extra info to samples
        # TODO: should the right/left padding thing be in torch agent?
        sample = {}
        sample["id"] = torch.arange(len(xs) - 1)
        sample["net_input"] = {
            "src_tokens": xs,
            "src_lengths": self._seq_length(xs),
        }
        if ys is not None:
            sample["target"] = ys
            sample["ntokens"] = sum(self._seq_length(ys)).item()
            sample["net_input"]["prev_output_tokens"] = self._right_shifted_ys(
                ys)
        return sample