Ejemplo n.º 1
0
 def build_generator(self, args):
     """Create an iterative-refinement decoder over the target dictionary.

     Each decoding option is looked up on ``args``; an option that is not
     present falls back to the default passed to ``getattr`` below.

     Args:
         args: namespace holding (optionally) the ``iter_decode_*`` and
             ``decoding_format`` options.

     Returns:
         An ``IterativeRefinementGenerator`` built on ``self.target_dictionary``.
     """
     from fairseq.iterative_refinement_generator import IterativeRefinementGenerator

     eos_penalty = getattr(args, 'iter_decode_eos_penalty', 0.0)
     max_iter = getattr(args, 'iter_decode_max_iter', 10)
     decoding_format = getattr(args, 'decoding_format', None)
     # `adaptive` is simply the negation of the force-max-iter flag.
     force_max_iter = getattr(args, 'iter_decode_force_max_iter', False)

     return IterativeRefinementGenerator(
         self.target_dictionary,
         eos_penalty=eos_penalty,
         max_iter=max_iter,
         decoding_format=decoding_format,
         adaptive=not force_max_iter,
     )
Ejemplo n.º 2
0
    def build_generator(self, models, args):
        """Create an iterative-refinement decoder, optionally ending on a language token.

        ``models`` is not used here; it is accepted so the signature matches the
        API used for ``SequenceGenerator``. Decoding options are read from
        ``args``, each falling back to the default given to ``getattr``.

        When ``self.args.use_lang_token`` is truthy (absent counts as False), the
        generator's ``eos`` is replaced with the index of the token
        ``[<self.args.target_lang>]`` in ``self.target_dictionary``.

        Raises:
            ValueError: if language tokens are enabled but the target-language
                token resolves to the dictionary's unknown-symbol index.
        """
        # add models input to match the API for SequenceGenerator
        from fairseq.iterative_refinement_generator import IterativeRefinementGenerator
        gen = IterativeRefinementGenerator(
            self.target_dictionary,
            eos_penalty=getattr(args, 'iter_decode_eos_penalty', 0.0),
            max_iter=getattr(args, 'iter_decode_max_iter', 10),
            beam_size=getattr(args, 'iter_decode_with_beam', 1),
            reranking=getattr(args, 'iter_decode_with_external_reranker',
                              False),
            decoding_format=getattr(args, 'decoding_format', None),
            adaptive=not getattr(args, 'iter_decode_force_max_iter', False),
            retain_history=getattr(args, 'retain_iter_history', False))

        # Read with a default, like every decoding option above, so arg
        # namespaces that lack this flag do not raise AttributeError.
        if getattr(self.args, 'use_lang_token', False):
            lang_token = '[{}]'.format(self.args.target_lang)
            lang_idx = self.target_dictionary.index(lang_token)
            # NOTE(review): assumes target_dictionary is a fairseq Dictionary,
            # whose index() returns the unk index for unknown symbols instead of
            # raising -- without this check <unk> would silently become EOS.
            if lang_idx == self.target_dictionary.unk():
                raise ValueError(
                    'language token {} not found in target dictionary'.format(
                        lang_token))
            gen.eos = lang_idx
        return gen
Ejemplo n.º 3
0
 def build_generator(self, models, args):
     """Create an iterative-refinement decoder over the target dictionary.

     ``models`` is accepted but unused, so the call signature matches the
     API used for ``SequenceGenerator``. Every decoding option is taken from
     ``args`` when present, otherwise from the default listed below.

     Returns:
         An ``IterativeRefinementGenerator`` built on ``self.target_dictionary``.
     """
     from fairseq.iterative_refinement_generator import IterativeRefinementGenerator

     generator_kwargs = {
         'eos_penalty': getattr(args, 'iter_decode_eos_penalty', 0.0),
         'max_iter': getattr(args, 'iter_decode_max_iter', 10),
         'beam_size': getattr(args, 'iter_decode_with_beam', 1),
         'reranking': getattr(args, 'iter_decode_with_external_reranker', False),
         'decoding_format': getattr(args, 'decoding_format', None),
         # Negation of the force-max-iter flag.
         'adaptive': not getattr(args, 'iter_decode_force_max_iter', False),
         'retain_history': getattr(args, 'retain_iter_history', False),
     }
     return IterativeRefinementGenerator(self.target_dictionary, **generator_kwargs)
Ejemplo n.º 4
0
    def build_generator(self, models, args, **unused):
        """Create an iterative-refinement decoder over the target dictionary.

        ``models`` and any extra keyword arguments are accepted and ignored,
        so the signature matches the API used for ``SequenceGenerator``.

        Args:
            models: unused.
            args: namespace from which decoding options are read; a missing
                option falls back to the default shown at its lookup.

        Returns:
            An ``IterativeRefinementGenerator`` built on ``self.target_dictionary``.
        """
        from fairseq.iterative_refinement_generator import IterativeRefinementGenerator

        def option(name, fallback):
            # Read one decoding option from args, tolerating its absence.
            return getattr(args, name, fallback)

        return IterativeRefinementGenerator(
            self.target_dictionary,
            eos_penalty=option("iter_decode_eos_penalty", 0.0),
            max_iter=option("iter_decode_max_iter", 10),
            beam_size=option("iter_decode_with_beam", 1),
            reranking=option("iter_decode_with_external_reranker", False),
            decoding_format=option("decoding_format", None),
            adaptive=not option("iter_decode_force_max_iter", False),
            retain_history=option("retain_iter_history", False),
        )
Ejemplo n.º 5
0
 def build_generator(self, models, args):
     """Create an iterative-refinement decoder with optional initial tokens.

     ``models`` is not used; it is accepted so the signature matches the API
     used for ``SequenceGenerator``. All options are read from ``args``, each
     falling back to a default when absent.

     ``args.init_output`` selects the initial output: ``'blank'`` (also the
     fallback when the option is missing) passes ``init_tokens=None``; any
     other value is forwarded unchanged as ``init_tokens``.

     Returns:
         An ``IterativeRefinementGenerator`` built on ``self.target_dictionary``.
     """
     # add models input to match the API for SequenceGenerator
     from fairseq.iterative_refinement_generator import IterativeRefinementGenerator
     # Read with a default like every other option, so arg namespaces that
     # lack `init_output` no longer raise AttributeError.
     init_output = getattr(args, 'init_output', 'blank')
     init_tokens = None if init_output == 'blank' else init_output
     return IterativeRefinementGenerator(
         self.target_dictionary,
         eos_penalty=getattr(args, 'iter_decode_eos_penalty', 0.0),
         max_iter=getattr(args, 'iter_decode_max_iter', 10),
         beam_size=getattr(args, 'iter_decode_with_beam', 1),
         reranking=getattr(args, 'iter_decode_with_external_reranker', False),
         decoding_format=getattr(args, 'decoding_format', None),
         adaptive=not getattr(args, 'iter_decode_force_max_iter', False),
         retain_history=getattr(args, 'retain_iter_history', False),
         init_tokens=init_tokens
     )