def build_generator(self, args):
    # Choose search strategy. Defaults to Beam Search.
    sampling = getattr(args, 'sampling', False)
    sampling_topk = getattr(args, 'sampling_topk', -1)
    sampling_topp = getattr(args, 'sampling_topp', -1.0)
    diverse_beam_groups = getattr(args, 'diverse_beam_groups', -1)
    # note: no trailing comma here, which would silently make this a tuple
    diverse_beam_strength = getattr(args, 'diverse_beam_strength', 0.5)
    match_source_len = getattr(args, 'match_source_len', False)
    diversity_rate = getattr(args, 'diversity_rate', -1)
    if (
        sum(
            int(cond)
            for cond in [
                sampling,
                diverse_beam_groups > 0,
                match_source_len,
                diversity_rate > 0,
            ]
        )
        > 1
    ):
        raise ValueError('Provided Search parameters are mutually exclusive.')
    assert sampling_topk < 0 or sampling, '--sampling-topk requires --sampling'
    assert sampling_topp < 0 or sampling, '--sampling-topp requires --sampling'

    if sampling:
        search_strategy = search.Sampling(
            self.target_dictionary, sampling_topk, sampling_topp)
    elif diverse_beam_groups > 0:
        search_strategy = search.DiverseBeamSearch(
            self.target_dictionary, diverse_beam_groups, diverse_beam_strength)
    elif match_source_len:
        # this is useful for tagging applications where the output
        # length should match the input length, so we hardcode the
        # length constraints for simplicity
        search_strategy = search.LengthConstrainedBeamSearch(
            self.target_dictionary, min_len_a=1, min_len_b=0, max_len_a=1, max_len_b=0,
        )
    elif diversity_rate > -1:
        search_strategy = search.DiverseSiblingsSearch(
            self.target_dictionary, diversity_rate)
    else:
        search_strategy = search.BeamSearch(self.target_dictionary)

    if args.context_type == 'src':
        seq_cls = AudioContextAwareSequenceGenerator
    else:
        seq_cls = TargetContextAwareSequenceGenerator
    return seq_cls(
        self.target_dictionary,
        beam_size=getattr(args, 'beam', 5),
        max_len_a=getattr(args, 'max_len_a', 0),
        max_len_b=getattr(args, 'max_len_b', 200),
        min_len=getattr(args, 'min_len', 1),
        normalize_scores=(not getattr(args, 'unnormalized', False)),
        len_penalty=getattr(args, 'lenpen', 1),
        unk_penalty=getattr(args, 'unkpen', 0),
        temperature=getattr(args, 'temperature', 1.),
        match_source_len=getattr(args, 'match_source_len', False),
        no_repeat_ngram_size=getattr(args, 'no_repeat_ngram_size', 0),
        search_strategy=search_strategy,
    )
def build_model(cls, args, task):
    """Build a new model instance."""

    # make sure all arguments are present in older models
    base_architecture(args)

    if args.encoder_layers_to_keep:
        args.encoder_layers = len(args.encoder_layers_to_keep.split(","))
    if args.decoder_layers_to_keep:
        args.decoder_layers = len(args.decoder_layers_to_keep.split(","))

    if getattr(args, "max_source_positions", None) is None:
        args.max_source_positions = DEFAULT_MAX_SOURCE_POSITIONS
    if getattr(args, "max_target_positions", None) is None:
        args.max_target_positions = DEFAULT_MAX_TARGET_POSITIONS

    src_dict, tgt_dict = task.source_dictionary, task.target_dictionary

    generator = None
    if args.use_sentence_level_oracles:
        from fairseq.sequence_generator import SequenceGenerator
        import fairseq.search as search

        search_strategy = search.LengthConstrainedBeamSearch(
            tgt_dict, min_len_a=1, min_len_b=0, max_len_a=1, max_len_b=0,
        )
        generator = SequenceGenerator(
            tgt_dict,
            beam_size=args.oracle_search_beam_size,
            match_source_len=False,
            max_len_a=1,
            max_len_b=100,
            search_strategy=search_strategy,
        )

    if args.share_all_embeddings:
        if src_dict != tgt_dict:
            raise ValueError("--share-all-embeddings requires a joined dictionary")
        if args.encoder_embed_dim != args.decoder_embed_dim:
            raise ValueError(
                "--share-all-embeddings requires --encoder-embed-dim to match --decoder-embed-dim"
            )
        if args.decoder_embed_path and (
            args.decoder_embed_path != args.encoder_embed_path
        ):
            raise ValueError(
                "--share-all-embeddings not compatible with --decoder-embed-path"
            )
        encoder_embed_tokens = cls.build_embedding(
            args, src_dict, args.encoder_embed_dim, args.encoder_embed_path
        )
        decoder_embed_tokens = encoder_embed_tokens
        args.share_decoder_input_output_embed = True
    else:
        encoder_embed_tokens = cls.build_embedding(
            args, src_dict, args.encoder_embed_dim, args.encoder_embed_path
        )
        decoder_embed_tokens = cls.build_embedding(
            args, tgt_dict, args.decoder_embed_dim, args.decoder_embed_path
        )

    encoder = cls.build_encoder(args, src_dict, encoder_embed_tokens)
    decoder = cls.build_decoder(args, tgt_dict, decoder_embed_tokens)
    return cls(args, encoder, decoder, generator)
def build_generator(task, models, args):
    if getattr(args, "score_reference", False):
        from fairseq.sequence_scorer import SequenceScorer

        return SequenceScorer(
            task.target_dictionary,
            compute_alignment=getattr(args, "print_alignment", False),
        )

    # These classes are instantiated below, so the import must be active.
    from fairseq.sequence_generator import (
        SequenceGenerator,
        SequenceGeneratorWithAlignment,
    )

    # Choose search strategy. Defaults to Beam Search.
    sampling = getattr(args, "sampling", False)
    sampling_topk = getattr(args, "sampling_topk", -1)
    sampling_topp = getattr(args, "sampling_topp", -1.0)
    diverse_beam_groups = getattr(args, "diverse_beam_groups", -1)
    diverse_beam_strength = getattr(args, "diverse_beam_strength", 0.5)
    match_source_len = getattr(args, "match_source_len", False)
    diversity_rate = getattr(args, "diversity_rate", -1)
    if (
        sum(
            int(cond)
            for cond in [
                sampling,
                diverse_beam_groups > 0,
                match_source_len,
                diversity_rate > 0,
            ]
        )
        > 1
    ):
        raise ValueError("Provided Search parameters are mutually exclusive.")
    assert sampling_topk < 0 or sampling, "--sampling-topk requires --sampling"
    assert sampling_topp < 0 or sampling, "--sampling-topp requires --sampling"

    if sampling:
        search_strategy = search.Sampling(
            task.target_dictionary, sampling_topk, sampling_topp)
    elif diverse_beam_groups > 0:
        search_strategy = search.DiverseBeamSearch(
            task.target_dictionary, diverse_beam_groups, diverse_beam_strength)
    elif match_source_len:
        # this is useful for tagging applications where the output
        # length should match the input length, so we hardcode the
        # length constraints for simplicity
        search_strategy = search.LengthConstrainedBeamSearch(
            task.target_dictionary, min_len_a=1, min_len_b=0, max_len_a=1, max_len_b=0,
        )
    elif diversity_rate > -1:
        search_strategy = search.DiverseSiblingsSearch(
            task.target_dictionary, diversity_rate)
    else:
        # project-specific BeamSearch variant that accepts an extra `mc` argument
        search_strategy = BeamSearch(task.target_dictionary, getattr(args, 'mc', None))

    if getattr(args, "print_alignment", False):
        seq_gen_cls = SequenceGeneratorWithAlignment
    else:
        seq_gen_cls = SequenceGenerator

    return seq_gen_cls(
        models,
        task.target_dictionary,
        beam_size=getattr(args, "beam", 5),
        max_len_a=getattr(args, "max_len_a", 0),
        max_len_b=getattr(args, "max_len_b", 200),
        min_len=getattr(args, "min_len", 1),
        normalize_scores=(not getattr(args, "unnormalized", False)),
        len_penalty=getattr(args, "lenpen", 1),
        unk_penalty=getattr(args, "unkpen", 0),
        temperature=getattr(args, "temperature", 1.0),
        match_source_len=getattr(args, "match_source_len", False),
        no_repeat_ngram_size=getattr(args, "no_repeat_ngram_size", 0),
        search_strategy=search_strategy,
    )
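
# Hedged usage sketch (not in the source): calling the helper above with a
# plain argparse Namespace. `task` and `models` are assumed to exist already;
# the attribute names follow the getattr() defaults read inside
# build_generator, so anything omitted falls back to its default.
from argparse import Namespace

gen_args = Namespace(sampling=True, sampling_topp=0.9, temperature=0.7)
generator = build_generator(task, models, gen_args)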
def __init__(
    self,
    tgt_dict,
    beam_size=1,
    max_len_a=0,
    max_len_b=200,
    min_len=1,
    stop_early=True,
    normalize_scores=True,
    len_penalty=1.,
    unk_penalty=0.,
    retain_dropout=False,
    sampling=False,
    sampling_topk=-1,
    sampling_temperature=1.,
    diverse_beam_groups=-1,
    diverse_beam_strength=0.5,
    match_source_len=False,
    no_repeat_ngram_size=0,
):
    """Generates translations of a given source sentence.

    Args:
        tgt_dict (~fairseq.data.Dictionary): target dictionary
        beam_size (int, optional): beam width (default: 1)
        max_len_a/b (int, optional): generate sequences of maximum length
            ax + b, where x is the source length
        min_len (int, optional): the minimum length of the generated output
            (not including end-of-sentence)
        stop_early (bool, optional): stop generation immediately after we
            finalize beam_size hypotheses, even though longer hypotheses
            might have better normalized scores (default: True)
        normalize_scores (bool, optional): normalize scores by the length
            of the output (default: True)
        len_penalty (float, optional): length penalty, where <1.0 favors
            shorter, >1.0 favors longer sentences (default: 1.0)
        unk_penalty (float, optional): unknown word penalty, where <0
            produces more unks, >0 produces fewer (default: 0.0)
        retain_dropout (bool, optional): use dropout when generating
            (default: False)
        sampling (bool, optional): sample outputs instead of beam search
            (default: False)
        sampling_topk (int, optional): only sample among the top-k choices
            at each step (default: -1)
        sampling_temperature (float, optional): temperature for sampling,
            where values >1.0 produce more uniform sampling and values
            <1.0 produce sharper sampling (default: 1.0)
        diverse_beam_groups/strength (float, optional): parameters for
            Diverse Beam Search sampling
        match_source_len (bool, optional): outputs should match the source
            length (default: False)
    """
    self.pad = tgt_dict.pad()
    self.unk = tgt_dict.unk()
    self.eos = tgt_dict.eos()
    self.vocab_size = len(tgt_dict)
    # the max beam size is the dictionary size - 1, since we never select pad
    self.beam_size = min(beam_size, self.vocab_size - 1)
    self.max_len_a = max_len_a
    self.max_len_b = max_len_b
    self.min_len = min_len
    self.stop_early = stop_early
    self.normalize_scores = normalize_scores
    self.len_penalty = len_penalty
    self.unk_penalty = unk_penalty
    self.retain_dropout = retain_dropout
    self.match_source_len = match_source_len
    self.no_repeat_ngram_size = no_repeat_ngram_size

    assert sampling_topk < 0 or sampling, '--sampling-topk requires --sampling'

    if sampling:
        self.search = search.Sampling(tgt_dict, sampling_topk, sampling_temperature)
    elif diverse_beam_groups > 0:
        self.search = search.DiverseBeamSearch(tgt_dict, diverse_beam_groups, diverse_beam_strength)
    elif match_source_len:
        self.search = search.LengthConstrainedBeamSearch(
            tgt_dict, min_len_a=1, min_len_b=0, max_len_a=1, max_len_b=0,
        )
    else:
        self.search = search.BeamSearch(tgt_dict)
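
# Minimal usage sketch (an assumption, not from the source): this __init__ is
# presumed to belong to a SequenceGenerator-style class, and `tgt_dict` to be
# a fairseq Dictionary loaded elsewhere. Shown here: a top-k sampling
# configuration; sampling_topk requires sampling=True, per the assertion above.
generator = SequenceGenerator(
    tgt_dict,
    beam_size=5,
    sampling=True,
    sampling_topk=10,
    sampling_temperature=0.8,
)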
def build_generator(self, models, args):
    from examples.waitk.generators.waitk_sequence_generator import (
        WaitkSequenceGenerator,
    )

    # Choose search strategy. Defaults to Beam Search.
    sampling = getattr(args, "sampling", False)
    sampling_topk = getattr(args, "sampling_topk", -1)
    sampling_topp = getattr(args, "sampling_topp", -1.0)
    diverse_beam_groups = getattr(args, "diverse_beam_groups", -1)
    diverse_beam_strength = getattr(args, "diverse_beam_strength", 0.5)
    match_source_len = getattr(args, "match_source_len", False)
    diversity_rate = getattr(args, "diversity_rate", -1)
    if (
        sum(
            int(cond)
            for cond in [
                sampling,
                diverse_beam_groups > 0,
                match_source_len,
                diversity_rate > 0,
            ]
        )
        > 1
    ):
        raise ValueError("Provided Search parameters are mutually exclusive.")
    assert sampling_topk < 0 or sampling, "--sampling-topk requires --sampling"
    assert sampling_topp < 0 or sampling, "--sampling-topp requires --sampling"

    if sampling:
        search_strategy = search.Sampling(
            self.target_dictionary, sampling_topk, sampling_topp)
    elif diverse_beam_groups > 0:
        search_strategy = search.DiverseBeamSearch(
            self.target_dictionary, diverse_beam_groups, diverse_beam_strength)
    elif match_source_len:
        # this is useful for tagging applications where the output
        # length should match the input length, so we hardcode the
        # length constraints for simplicity
        search_strategy = search.LengthConstrainedBeamSearch(
            self.target_dictionary, min_len_a=1, min_len_b=0, max_len_a=1, max_len_b=0,
        )
    elif diversity_rate > -1:
        search_strategy = search.DiverseSiblingsSearch(
            self.target_dictionary, diversity_rate)
    else:
        search_strategy = search.BeamSearch(self.target_dictionary)

    return WaitkSequenceGenerator(
        models,
        self.target_dictionary,
        beam_size=getattr(args, "beam", 5),
        max_len_a=getattr(args, "max_len_a", 0),
        max_len_b=getattr(args, "max_len_b", 200),
        min_len=getattr(args, "min_len", 1),
        normalize_scores=(not getattr(args, "unnormalized", False)),
        len_penalty=getattr(args, "lenpen", 1),
        unk_penalty=getattr(args, "unkpen", 0),
        temperature=getattr(args, "temperature", 1.0),
        match_source_len=getattr(args, "match_source_len", False),
        no_repeat_ngram_size=getattr(args, "no_repeat_ngram_size", 0),
        search_strategy=search_strategy,
        waitk=args.eval_waitk,
    )
def build_generator(
    self,
    models,
    args,
    seq_gen_cls=None,
    extra_gen_cls_kwargs=None,
    prefix_allowed_tokens_fn=None,
):
    """
    Build a :class:`~fairseq.SequenceGenerator` instance for this task.

    Args:
        models (List[~fairseq.models.FairseqModel]): ensemble of models
        args (fairseq.dataclass.configs.GenerationConfig):
            configuration object (dataclass) for generation
        extra_gen_cls_kwargs (Dict[str, Any]): extra options to pass
            through to SequenceGenerator
        prefix_allowed_tokens_fn (Callable[[int, torch.Tensor], List[int]]):
            If provided, this function constrains the beam search to
            allowed tokens only at each step. The provided function
            should take 2 arguments: the batch ID (`batch_id: int`)
            and a unidimensional tensor of token ids (`inputs_ids:
            torch.Tensor`). It has to return a `List[int]` with the
            allowed tokens for the next generation step, conditioned
            on the previously generated tokens (`inputs_ids`) and
            the batch ID (`batch_id`). This argument is useful for
            constrained generation conditioned on the prefix, as
            described in "Autoregressive Entity Retrieval"
            (https://arxiv.org/abs/2010.00904) and
            https://github.com/facebookresearch/GENRE.
    """
    if getattr(args, "score_reference", False):
        from fairseq.sequence_scorer import SequenceScorer

        return SequenceScorer(
            self.target_dictionary,
            compute_alignment=getattr(args, "print_alignment", False),
        )

    from fairseq.sequence_generator import (
        SequenceGenerator,
        SequenceGeneratorWithAlignment,
    )

    # Choose search strategy. Defaults to Beam Search.
    sampling = getattr(args, "sampling", False)
    sampling_topk = getattr(args, "sampling_topk", -1)
    sampling_topp = getattr(args, "sampling_topp", -1.0)
    diverse_beam_groups = getattr(args, "diverse_beam_groups", -1)
    diverse_beam_strength = getattr(args, "diverse_beam_strength", 0.5)
    match_source_len = getattr(args, "match_source_len", False)
    diversity_rate = getattr(args, "diversity_rate", -1)
    constrained = getattr(args, "constraints", False)
    if prefix_allowed_tokens_fn is None:
        prefix_allowed_tokens_fn = getattr(args, "prefix_allowed_tokens_fn", None)
    if (
        sum(
            int(cond)
            for cond in [
                sampling,
                diverse_beam_groups > 0,
                match_source_len,
                diversity_rate > 0,
            ]
        )
        > 1
    ):
        raise ValueError("Provided Search parameters are mutually exclusive.")
    assert sampling_topk < 0 or sampling, "--sampling-topk requires --sampling"
    assert sampling_topp < 0 or sampling, "--sampling-topp requires --sampling"

    if sampling:
        search_strategy = search.Sampling(
            self.target_dictionary, sampling_topk, sampling_topp
        )
    elif diverse_beam_groups > 0:
        search_strategy = search.DiverseBeamSearch(
            self.target_dictionary, diverse_beam_groups, diverse_beam_strength
        )
    elif match_source_len:
        # this is useful for tagging applications where the output
        # length should match the input length, so we hardcode the
        # length constraints for simplicity
        search_strategy = search.LengthConstrainedBeamSearch(
            self.target_dictionary,
            min_len_a=1,
            min_len_b=0,
            max_len_a=1,
            max_len_b=0,
        )
    elif diversity_rate > -1:
        search_strategy = search.DiverseSiblingsSearch(
            self.target_dictionary, diversity_rate
        )
    elif constrained:
        search_strategy = search.LexicallyConstrainedBeamSearch(
            self.target_dictionary, args.constraints
        )
    elif prefix_allowed_tokens_fn:
        search_strategy = search.PrefixConstrainedBeamSearch(
            self.target_dictionary, prefix_allowed_tokens_fn
        )
    else:
        search_strategy = search.BeamSearch(self.target_dictionary)

    extra_gen_cls_kwargs = extra_gen_cls_kwargs or {}
    if seq_gen_cls is None:
        if getattr(args, "print_alignment", False):
            seq_gen_cls = SequenceGeneratorWithAlignment
            extra_gen_cls_kwargs["print_alignment"] = args.print_alignment
        else:
            seq_gen_cls = SequenceGenerator

    return seq_gen_cls(
        models,
        self.target_dictionary,
        beam_size=getattr(args, "beam", 5),
        max_len_a=getattr(args, "max_len_a", 0),
        max_len_b=getattr(args, "max_len_b", 200),
        min_len=getattr(args, "min_len", 1),
        normalize_scores=(not getattr(args, "unnormalized", False)),
        len_penalty=getattr(args, "lenpen", 1),
        unk_penalty=getattr(args, "unkpen", 0),
        temperature=getattr(args, "temperature", 1.0),
        match_source_len=getattr(args, "match_source_len", False),
        no_repeat_ngram_size=getattr(args, "no_repeat_ngram_size", 0),
        search_strategy=search_strategy,
        **extra_gen_cls_kwargs,
    )
def build_generator(self, models, args, seq_gen_cls=None, extra_gen_cls_kwargs=None):
    if getattr(args, "score_reference", False):
        args.score_reference = False
        logger.warning(
            "--score-reference is not applicable to speech recognition, ignoring it."
        )

    from fairseq.sequence_generator import SequenceGenerator

    # Choose search strategy. Defaults to Beam Search.
    sampling = getattr(args, "sampling", False)
    sampling_topk = getattr(args, "sampling_topk", -1)
    sampling_topp = getattr(args, "sampling_topp", -1.0)
    diverse_beam_groups = getattr(args, "diverse_beam_groups", -1)
    diverse_beam_strength = getattr(args, "diverse_beam_strength", 0.5)
    match_source_len = getattr(args, "match_source_len", False)
    diversity_rate = getattr(args, "diversity_rate", -1)
    if (
        sum(
            int(cond)
            for cond in [
                sampling,
                diverse_beam_groups > 0,
                match_source_len,
                diversity_rate > 0,
            ]
        )
        > 1
    ):
        raise ValueError("Provided Search parameters are mutually exclusive.")
    assert sampling_topk < 0 or sampling, "--sampling-topk requires --sampling"
    assert sampling_topp < 0 or sampling, "--sampling-topp requires --sampling"

    if sampling:
        search_strategy = search.Sampling(
            self.target_dictionary, sampling_topk, sampling_topp)
    elif diverse_beam_groups > 0:
        search_strategy = search.DiverseBeamSearch(
            self.target_dictionary, diverse_beam_groups, diverse_beam_strength)
    elif match_source_len:
        # this is useful for tagging applications where the output
        # length should match the input length, so we hardcode the
        # length constraints for simplicity
        search_strategy = search.LengthConstrainedBeamSearch(
            self.target_dictionary, min_len_a=1, min_len_b=0, max_len_a=1, max_len_b=0,
        )
    elif diversity_rate > -1:
        search_strategy = search.DiverseSiblingsSearch(
            self.target_dictionary, diversity_rate)
    else:
        search_strategy = search.BeamSearch(self.target_dictionary)

    if seq_gen_cls is None:
        seq_gen_cls = SequenceGenerator
    extra_gen_cls_kwargs = extra_gen_cls_kwargs or {}
    extra_gen_cls_kwargs["lm_weight"] = getattr(args, "lm_weight", 0.0)
    extra_gen_cls_kwargs["eos_factor"] = getattr(args, "eos_factor", None)

    return seq_gen_cls(
        models,
        self.target_dictionary,
        beam_size=getattr(args, "beam", 5),
        max_len_a=getattr(args, "max_len_a", 0),
        max_len_b=getattr(args, "max_len_b", 200),
        min_len=getattr(args, "min_len", 1),
        normalize_scores=(not getattr(args, "unnormalized", False)),
        len_penalty=getattr(args, "lenpen", 1),
        unk_penalty=getattr(args, "unkpen", 0),
        temperature=getattr(args, "temperature", 1.),
        match_source_len=getattr(args, "match_source_len", False),
        no_repeat_ngram_size=getattr(args, "no_repeat_ngram_size", 0),
        search_strategy=search_strategy,
        **extra_gen_cls_kwargs,
    )
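
# Hedged usage note (an assumption, not from the source): lm_weight and
# eos_factor are read from args and unconditionally overwrite any values
# passed in via extra_gen_cls_kwargs, so setting them on args is the path
# this code actually supports. `task`, `models`, and `args` are assumed.
args.lm_weight = 0.2   # weight for an external LM, if the generator supports one
args.eos_factor = 1.5  # gates EOS emission; must be >= 1.0 where enforced
generator = task.build_generator(models, args)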
def __init__(
    self,
    models,
    tgt_dict,
    beam_size=1,
    minlen=1,
    maxlen=None,
    stop_early=True,
    normalize_scores=True,
    len_penalty=1.,
    unk_penalty=0.,
    retain_dropout=False,
    sampling=False,
    sampling_topk=-1,
    sampling_temperature=1.,
    diverse_beam_groups=-1,
    diverse_beam_strength=0.5,
    match_source_len=False,
    no_repeat_ngram_size=0,
):
    """Generates translations of a given source sentence.

    Args:
        beam_size (int, optional): beam width (default: 1)
        min/maxlen (int, optional): the length of the generated output will
            be bounded by minlen and maxlen (not including end-of-sentence)
        stop_early (bool, optional): stop generation immediately after we
            finalize beam_size hypotheses, even though longer hypotheses
            might have better normalized scores (default: True)
        normalize_scores (bool, optional): normalize scores by the length
            of the output (default: True)
        len_penalty (float, optional): length penalty, where <1.0 favors
            shorter, >1.0 favors longer sentences (default: 1.0)
        unk_penalty (float, optional): unknown word penalty, where <0
            produces more unks, >0 produces fewer (default: 0.0)
        retain_dropout (bool, optional): use dropout when generating
            (default: False)
        sampling (bool, optional): sample outputs instead of beam search
            (default: False)
        sampling_topk (int, optional): only sample among the top-k choices
            at each step (default: -1)
        sampling_temperature (float, optional): temperature for sampling,
            where values >1.0 produce more uniform sampling and values
            <1.0 produce sharper sampling (default: 1.0)
        diverse_beam_groups/strength (float, optional): parameters for
            Diverse Beam Search sampling
        match_source_len (bool, optional): outputs should match the source
            length (default: False)
    """
    self.models = models
    self.pad = tgt_dict.pad()
    self.unk = tgt_dict.unk()
    self.eos = tgt_dict.eos()
    self.vocab_size = len(tgt_dict)
    self.beam_size = beam_size
    self.minlen = minlen
    max_decoder_len = min(m.max_decoder_positions() for m in self.models)
    max_decoder_len -= 1  # we define maxlen not including the EOS marker
    self.maxlen = max_decoder_len if maxlen is None else min(maxlen, max_decoder_len)
    self.stop_early = stop_early
    self.normalize_scores = normalize_scores
    self.len_penalty = len_penalty
    self.unk_penalty = unk_penalty
    self.retain_dropout = retain_dropout
    self.match_source_len = match_source_len
    self.no_repeat_ngram_size = no_repeat_ngram_size
    self.js = {}

    assert sampling_topk < 0 or sampling, '--sampling-topk requires --sampling'

    if sampling:
        self.search = search.Sampling(tgt_dict, sampling_topk, sampling_temperature)
    elif diverse_beam_groups > 0:
        self.search = search.DiverseBeamSearch(tgt_dict, diverse_beam_groups, diverse_beam_strength)
    elif match_source_len:
        self.search = search.LengthConstrainedBeamSearch(
            tgt_dict, min_len_a=1, min_len_b=0, max_len_a=1, max_len_b=0,
        )
    else:
        self.search = search.BeamSearch(tgt_dict)
def build_generator(self, models, args, seq_gen_cls=None, extra_gen_cls_kwargs=None):
    if getattr(args, "score_reference", False):
        from fairseq.sequence_scorer import SequenceScorer

        return SequenceScorer(
            self.target_dictionary,
            compute_alignment=getattr(args, "print_alignment", False),
        )

    from fairseq.sequence_generator import (
        SequenceGenerator,
        SequenceGeneratorWithAlignment,
    )
    try:
        from .generator import TextRecognitionGenerator
    except ImportError:
        from generator import TextRecognitionGenerator
    try:
        from fairseq.fb_sequence_generator import FBSequenceGenerator
    except ModuleNotFoundError:
        pass

    # Choose search strategy. Defaults to Beam Search.
    sampling = getattr(args, "sampling", False)
    sampling_topk = getattr(args, "sampling_topk", -1)
    sampling_topp = getattr(args, "sampling_topp", -1.0)
    diverse_beam_groups = getattr(args, "diverse_beam_groups", -1)
    diverse_beam_strength = getattr(args, "diverse_beam_strength", 0.5)
    match_source_len = getattr(args, "match_source_len", False)
    diversity_rate = getattr(args, "diversity_rate", -1)
    constrained = getattr(args, "constraints", False)
    prefix_allowed_tokens_fn = getattr(args, "prefix_allowed_tokens_fn", None)
    if (
        sum(
            int(cond)
            for cond in [
                sampling,
                diverse_beam_groups > 0,
                match_source_len,
                diversity_rate > 0,
            ]
        )
        > 1
    ):
        raise ValueError("Provided Search parameters are mutually exclusive.")
    assert sampling_topk < 0 or sampling, "--sampling-topk requires --sampling"
    assert sampling_topp < 0 or sampling, "--sampling-topp requires --sampling"

    if sampling:
        search_strategy = search.Sampling(
            self.target_dictionary, sampling_topk, sampling_topp)
    elif diverse_beam_groups > 0:
        search_strategy = search.DiverseBeamSearch(
            self.target_dictionary, diverse_beam_groups, diverse_beam_strength)
    elif match_source_len:
        # this is useful for tagging applications where the output
        # length should match the input length, so we hardcode the
        # length constraints for simplicity
        search_strategy = search.LengthConstrainedBeamSearch(
            self.target_dictionary, min_len_a=1, min_len_b=0, max_len_a=1, max_len_b=0,
        )
    elif diversity_rate > -1:
        search_strategy = search.DiverseSiblingsSearch(
            self.target_dictionary, diversity_rate)
    elif constrained:
        search_strategy = search.LexicallyConstrainedBeamSearch(
            self.target_dictionary, args.constraints)
    elif prefix_allowed_tokens_fn:
        search_strategy = search.PrefixConstrainedBeamSearch(
            self.target_dictionary, prefix_allowed_tokens_fn)
    else:
        search_strategy = search.BeamSearch(self.target_dictionary)

    extra_gen_cls_kwargs = extra_gen_cls_kwargs or {}
    if seq_gen_cls is None:
        if getattr(args, "print_alignment", False):
            seq_gen_cls = SequenceGeneratorWithAlignment
            extra_gen_cls_kwargs["print_alignment"] = args.print_alignment
        elif getattr(args, "fb_seq_gen", False):
            seq_gen_cls = FBSequenceGenerator
        else:
            seq_gen_cls = TextRecognitionGenerator

    return seq_gen_cls(
        models,
        self.target_dictionary,
        beam_size=getattr(args, "beam", 5),
        max_len_a=getattr(args, "max_len_a", 0),
        max_len_b=getattr(args, "max_len_b", 200),
        min_len=getattr(args, "min_len", 1),
        normalize_scores=(not getattr(args, "unnormalized", False)),
        len_penalty=getattr(args, "lenpen", 1),
        unk_penalty=getattr(args, "unkpen", 0),
        temperature=getattr(args, "temperature", 1.0),
        match_source_len=getattr(args, "match_source_len", False),
        no_repeat_ngram_size=getattr(args, "no_repeat_ngram_size", 0),
        search_strategy=search_strategy,
        **extra_gen_cls_kwargs,
    )
def __init__(
    self,
    tgt_dict,
    beam_size=1,
    max_len_a=0,
    max_len_b=200,
    min_len=1,
    normalize_scores=True,
    len_penalty=1.,
    unk_penalty=0.,
    retain_dropout=False,
    sampling=False,
    sampling_topk=-1,
    sampling_topp=-1.0,
    temperature=1.,
    diverse_beam_groups=-1,
    diverse_beam_strength=0.5,
    match_source_len=False,
    no_repeat_ngram_size=0,
    coverage_weight=0.0,
    eos_factor=None,
):
    """Generates translations of a given source sentence.

    Args:
        tgt_dict (~fairseq.data.Dictionary): target dictionary
        beam_size (int, optional): beam width (default: 1)
        max_len_a/b (int, optional): generate sequences of maximum length
            ax + b, where x is the source length
        min_len (int, optional): the minimum length of the generated output
            (not including end-of-sentence)
        normalize_scores (bool, optional): normalize scores by the length
            of the output (default: True)
        len_penalty (float, optional): length penalty, where <1.0 favors
            shorter, >1.0 favors longer sentences (default: 1.0)
        unk_penalty (float, optional): unknown word penalty, where <0
            produces more unks, >0 produces fewer (default: 0.0)
        retain_dropout (bool, optional): use dropout when generating
            (default: False)
        sampling (bool, optional): sample outputs instead of beam search
            (default: False)
        sampling_topk (int, optional): only sample among the top-k choices
            at each step (default: -1)
        sampling_topp (float, optional): only sample among the smallest set
            of words whose cumulative probability mass exceeds p at each
            step (default: -1.0)
        temperature (float, optional): temperature, where values >1.0
            produce more uniform samples and values <1.0 produce sharper
            samples (default: 1.0)
        diverse_beam_groups/strength (float, optional): parameters for
            Diverse Beam Search sampling
        match_source_len (bool, optional): outputs should match the source
            length (default: False)
    """
    self.pad = tgt_dict.pad()
    self.unk = tgt_dict.unk()
    self.eos = tgt_dict.eos()
    self.vocab_size = len(tgt_dict)
    # the max beam size is the dictionary size - 1, since we never select pad
    self.beam_size = min(beam_size, self.vocab_size - 1)
    self.max_len_a = max_len_a
    self.max_len_b = max_len_b
    self.min_len = min_len
    self.normalize_scores = normalize_scores
    self.len_penalty = len_penalty
    self.unk_penalty = unk_penalty
    self.retain_dropout = retain_dropout
    self.temperature = temperature
    self.match_source_len = match_source_len
    self.no_repeat_ngram_size = no_repeat_ngram_size
    self.coverage_weight = coverage_weight
    self.eos_factor = eos_factor

    assert sampling_topk < 0 or sampling, '--sampling-topk requires --sampling'
    assert sampling_topp < 0 or sampling, '--sampling-topp requires --sampling'
    assert temperature > 0, '--temperature must be greater than 0'
    assert eos_factor is None or eos_factor >= 1.0, '--eos-factor must be >= 1.0 if set'

    if sampling:
        self.search = search.Sampling(tgt_dict, sampling_topk, sampling_topp)
    elif diverse_beam_groups > 0:
        self.search = search.DiverseBeamSearch(tgt_dict, diverse_beam_groups, diverse_beam_strength)
    elif match_source_len:
        self.search = search.LengthConstrainedBeamSearch(
            tgt_dict, min_len_a=1, min_len_b=0, max_len_a=1, max_len_b=0,
        )
    else:
        self.search = search.BeamSearch(tgt_dict)
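
# Minimal sketch (an assumption, not from the source): constructing this
# generator variant with an EOS gate. Per the assertion above, eos_factor
# must be >= 1.0 when set; `tgt_dict` is assumed to be a loaded fairseq
# Dictionary, and the class name is presumed to be SequenceGenerator.
generator = SequenceGenerator(
    tgt_dict,
    beam_size=10,
    len_penalty=1.0,
    eos_factor=1.5,
)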
def build_generator(self, args):
    if getattr(args, 'score_reference', False):
        from fairseq.sequence_scorer import SequenceScorer

        return SequenceScorer(
            self.target_dictionary,
            compute_alignment=getattr(args, 'print_alignment', False),
        )

    from fairseq.sequence_generator import SequenceGenerator, SequenceGeneratorWithAlignment

    # Choose search strategy. Defaults to Beam Search.
    sampling = getattr(args, 'sampling', False)
    sampling_topk = getattr(args, 'sampling_topk', -1)
    sampling_topp = getattr(args, 'sampling_topp', -1.0)
    diverse_beam_groups = getattr(args, 'diverse_beam_groups', -1)
    # note: no trailing comma here, which would silently make this a tuple
    diverse_beam_strength = getattr(args, 'diverse_beam_strength', 0.5)
    match_source_len = getattr(args, 'match_source_len', False)
    diversity_rate = getattr(args, 'diversity_rate', -1)
    # Options for the custom generator extensions; currently unused because
    # the corresponding keyword arguments are commented out in the call below.
    dedup = getattr(args, 'dedup', False)
    verb_idxs = getattr(args, 'verb_idxs', [])
    banned_toks = getattr(args, 'banned_toks', [])
    coef_trainer = getattr(args, 'coef_trainer', None)
    coefs = getattr(args, 'coefs', [])
    learn = getattr(args, 'learn', False)
    learn_every_token = getattr(args, 'learn_every_token', False)
    if (
        sum(
            int(cond)
            for cond in [
                sampling,
                diverse_beam_groups > 0,
                match_source_len,
                diversity_rate > 0,
            ]
        )
        > 1
    ):
        raise ValueError('Provided Search parameters are mutually exclusive.')
    assert sampling_topk < 0 or sampling, '--sampling-topk requires --sampling'
    assert sampling_topp < 0 or sampling, '--sampling-topp requires --sampling'

    if sampling:
        search_strategy = search.Sampling(
            self.target_dictionary, sampling_topk, sampling_topp)
    elif diverse_beam_groups > 0:
        search_strategy = search.DiverseBeamSearch(
            self.target_dictionary, diverse_beam_groups, diverse_beam_strength)
    elif match_source_len:
        # this is useful for tagging applications where the output
        # length should match the input length, so we hardcode the
        # length constraints for simplicity
        search_strategy = search.LengthConstrainedBeamSearch(
            self.target_dictionary, min_len_a=1, min_len_b=0, max_len_a=1, max_len_b=0,
        )
    elif diversity_rate > -1:
        search_strategy = search.DiverseSiblingsSearch(
            self.target_dictionary, diversity_rate)
    else:
        search_strategy = search.BeamSearch(self.target_dictionary)

    if getattr(args, 'print_alignment', False):
        seq_gen_cls = SequenceGeneratorWithAlignment
    else:
        seq_gen_cls = SequenceGenerator

    return seq_gen_cls(
        self.target_dictionary,
        beam_size=getattr(args, 'beam', 5),
        max_len_a=getattr(args, 'max_len_a', 0),
        max_len_b=getattr(args, 'max_len_b', 200),
        min_len=getattr(args, 'min_len', 1),
        normalize_scores=(not getattr(args, 'unnormalized', False)),
        len_penalty=getattr(args, 'lenpen', 1),
        unk_penalty=getattr(args, 'unkpen', 0),
        temperature=getattr(args, 'temperature', 1.),
        match_source_len=getattr(args, 'match_source_len', False),
        no_repeat_ngram_size=getattr(args, 'no_repeat_ngram_size', 0),
        search_strategy=search_strategy,
        # dedup=dedup, verb=verb_idxs, banned_toks=banned_toks,
        # coef_trainer=coef_trainer, coefs=coefs, learn=learn,
        # learn_every_token=learn_every_token,
    )
def build_generator(self, args):
    if args.score_reference:
        args.score_reference = False
        logger.warning(
            '--score-reference is not applicable to speech recognition, ignoring it.'
        )

    from fairseq.sequence_generator import SequenceGenerator

    # Choose search strategy. Defaults to Beam Search.
    sampling = getattr(args, 'sampling', False)
    sampling_topk = getattr(args, 'sampling_topk', -1)
    sampling_topp = getattr(args, 'sampling_topp', -1.0)
    diverse_beam_groups = getattr(args, 'diverse_beam_groups', -1)
    # note: no trailing comma here, which would silently make this a tuple
    diverse_beam_strength = getattr(args, 'diverse_beam_strength', 0.5)
    match_source_len = getattr(args, 'match_source_len', False)
    diversity_rate = getattr(args, 'diversity_rate', -1)
    if (
        sum(
            int(cond)
            for cond in [
                sampling,
                diverse_beam_groups > 0,
                match_source_len,
                diversity_rate > 0,
            ]
        )
        > 1
    ):
        raise ValueError('Provided Search parameters are mutually exclusive.')
    assert sampling_topk < 0 or sampling, '--sampling-topk requires --sampling'
    assert sampling_topp < 0 or sampling, '--sampling-topp requires --sampling'

    if sampling:
        search_strategy = search.Sampling(
            self.target_dictionary, sampling_topk, sampling_topp)
    elif diverse_beam_groups > 0:
        search_strategy = search.DiverseBeamSearch(
            self.target_dictionary, diverse_beam_groups, diverse_beam_strength)
    elif match_source_len:
        # this is useful for tagging applications where the output
        # length should match the input length, so we hardcode the
        # length constraints for simplicity
        search_strategy = search.LengthConstrainedBeamSearch(
            self.target_dictionary, min_len_a=1, min_len_b=0, max_len_a=1, max_len_b=0,
        )
    elif diversity_rate > -1:
        search_strategy = search.DiverseSiblingsSearch(
            self.target_dictionary, diversity_rate)
    else:
        search_strategy = search.BeamSearch(self.target_dictionary)

    return SequenceGenerator(
        self.target_dictionary,
        beam_size=getattr(args, 'beam', 5),
        max_len_a=getattr(args, 'max_len_a', 0),
        max_len_b=getattr(args, 'max_len_b', 200),
        min_len=getattr(args, 'min_len', 1),
        normalize_scores=(not getattr(args, 'unnormalized', False)),
        len_penalty=getattr(args, 'lenpen', 1),
        unk_penalty=getattr(args, 'unkpen', 0),
        temperature=getattr(args, 'temperature', 1.),
        match_source_len=getattr(args, 'match_source_len', False),
        no_repeat_ngram_size=getattr(args, 'no_repeat_ngram_size', 0),
        search_strategy=search_strategy,
        eos_factor=getattr(args, 'eos_factor', None),
    )