def create_decoder():
    """Creates the ``Decoder`` instance.

    This specifies the search strategy used to traverse the space spanned
    by the predictors. This method relies on the global ``args`` variable.

    TODO: Refactor to avoid long argument lists

    Side effects:
        For ``args.decoder == "vanilla"`` this overwrites ``args.predictors``
        with ``"vanilla"``.

    Returns:
        Decoder. Instance of the search strategy, after ``add_predictors``
        (and ``add_heuristics`` when ``args.heuristics`` is set) has been
        applied to it. If ``args.range`` is set, the decoder's start
        sentence id is set to the part of ``args.range`` before ``":"``
        minus one.

    Raises:
        SystemExit: if no decoder was constructed, e.g. because
            ``args.decoder`` does not name a known decoder.
    """
    # NOTE(review): this file defines ``create_decoder`` a second time
    # further down; at import time the later definition replaces this one.
    import sys  # local import keeps this block independent of file imports

    decoder = None
    if args.decoder == "greedy":
        decoder = GreedyDecoder(args)
    elif args.decoder == "beam":
        decoder = BeamDecoder(args, args.hypo_recombination, args.beam,
                              args.pure_heuristic_scores,
                              args.decoder_diversity_factor,
                              args.early_stopping)
    elif args.decoder == "multisegbeam":
        decoder = MultisegBeamDecoder(args, args.hypo_recombination,
                                      args.beam, args.multiseg_tokenizations,
                                      args.early_stopping, args.max_word_len)
    elif args.decoder == "syncbeam":
        decoder = SyncBeamDecoder(args, args.hypo_recombination, args.beam,
                                  args.pure_heuristic_scores,
                                  args.decoder_diversity_factor,
                                  args.early_stopping, args.sync_symbol,
                                  args.max_word_len)
    elif args.decoder == "dfs":
        decoder = DFSDecoder(args, args.early_stopping,
                             args.max_node_expansions)
    elif args.decoder == "restarting":
        decoder = RestartingDecoder(args, args.hypo_recombination,
                                    args.max_node_expansions,
                                    args.low_decoder_memory,
                                    args.restarting_node_score,
                                    args.stochastic_decoder,
                                    args.decode_always_single_step)
    elif args.decoder == "bow":
        decoder = BOWDecoder(args, args.hypo_recombination,
                             args.max_node_expansions,
                             args.stochastic_decoder, args.early_stopping,
                             args.decode_always_single_step)
    elif args.decoder == "flip":
        decoder = FlipDecoder(args, args.trg_test, args.max_node_expansions,
                              args.early_stopping, args.flip_strategy)
    elif args.decoder == "bigramgreedy":
        decoder = BigramGreedyDecoder(args, args.trg_test,
                                      args.max_node_expansions,
                                      args.early_stopping)
    elif args.decoder == "bucket":
        decoder = BucketDecoder(
            args, args.hypo_recombination, args.max_node_expansions,
            args.low_decoder_memory, args.beam, args.pure_heuristic_scores,
            args.decoder_diversity_factor, args.early_stopping,
            args.stochastic_decoder, args.bucket_selector,
            args.bucket_score_strategy, args.collect_statistics)
    elif args.decoder == "astar":
        # At least one hypothesis is requested even if args.nbest < 1.
        decoder = AstarDecoder(args, args.beam, args.pure_heuristic_scores,
                               args.early_stopping, max(1, args.nbest))
    elif args.decoder == "vanilla":
        decoder = get_nmt_vanilla_decoder(
            args, args.nmt_path,
            _parse_config_param("nmt_config", get_default_nmt_config()))
        args.predictors = "vanilla"
    else:
        logging.fatal("Decoder %s not available. Please double-check the "
                      "--decoder parameter." % args.decoder)

    if decoder is None:
        # logging.fatal only logs; previously execution continued and
        # add_predictors(decoder) failed with UnboundLocalError.
        sys.exit("Could not initialize decoder.")

    add_predictors(decoder)
    # Add heuristics for search strategies like A*
    if args.heuristics:
        add_heuristics(decoder)
    # Update start sentence id if necessary: only the part of args.range
    # before ":" is used here.
    if args.range:
        idx, _ = (args.range.split(":") if ":" in args.range
                  else (args.range, 0))
        decoder.set_start_sen_id(int(idx) - 1)  # -1 because indices start with 1
    return decoder
def create_decoder():
    """Instantiate the search strategy selected by ``args.decoder``.

    The returned ``Decoder`` determines how the space spanned by the
    predictors is traversed. All configuration is read from the
    module-level ``args`` object.

    TODO: Refactor to avoid long argument lists

    Side effects:
        Selecting ``vanilla`` overwrites ``args.predictors`` with
        ``"vanilla"``. An unknown decoder name, or any exception raised
        while building the decoder, is reported through ``logging.fatal``
        (exceptions include their stack trace).

    Returns:
        Decoder. Instance of the search strategy after ``add_predictors``
        (and ``add_heuristics`` when ``args.heuristics`` is set) has been
        applied to it.

    Raises:
        SystemExit: if no decoder could be built.
    """
    # Each factory is a lambda, so a decoder class is only looked up and
    # instantiated when its name is the one selected.
    factories = (
        ("greedy", lambda: GreedyDecoder(args)),
        ("beam", lambda: BeamDecoder(args)),
        ("multisegbeam", lambda: MultisegBeamDecoder(
            args, args.hypo_recombination, args.beam,
            args.multiseg_tokenizations, args.early_stopping,
            args.max_word_len)),
        ("syncbeam", lambda: SyncBeamDecoder(args)),
        ("mbrbeam", lambda: MBRBeamDecoder(args)),
        ("sepbeam", lambda: SepBeamDecoder(args)),
        ("syntaxbeam", lambda: SyntaxBeamDecoder(args)),
        ("combibeam", lambda: CombiBeamDecoder(args)),
        ("dfs", lambda: DFSDecoder(args)),
        ("restarting", lambda: RestartingDecoder(
            args, args.hypo_recombination, args.max_node_expansions,
            args.low_decoder_memory, args.restarting_node_score,
            args.stochastic_decoder, args.decode_always_single_step)),
        ("bow", lambda: BOWDecoder(args)),
        ("flip", lambda: FlipDecoder(args)),
        ("bigramgreedy", lambda: BigramGreedyDecoder(args)),
        ("bucket", lambda: BucketDecoder(
            args, args.hypo_recombination, args.max_node_expansions,
            args.low_decoder_memory, args.beam, args.pure_heuristic_scores,
            args.decoder_diversity_factor, args.early_stopping,
            args.stochastic_decoder, args.bucket_selector,
            args.bucket_score_strategy, args.collect_statistics)),
        ("astar", lambda: AstarDecoder(args)),
    )

    decoder = None
    try:
        for name, build in factories:
            if args.decoder == name:
                decoder = build()
                break
        else:
            # "vanilla" is handled separately because it also rewrites
            # args.predictors after construction.
            if args.decoder == "vanilla":
                decoder = construct_nmt_vanilla_decoder()
                args.predictors = "vanilla"
            else:
                logging.fatal("Decoder %s not available. Please double-check "
                              "the --decoder parameter." % args.decoder)
    except Exception as e:
        logging.fatal("An %s has occurred while initializing the decoder: %s"
                      " Stack trace: %s"
                      % (type(e), e, traceback.format_exc()))

    if decoder is None:
        sys.exit("Could not initialize decoder.")

    add_predictors(decoder)
    # Search strategies such as A* additionally need heuristics.
    if args.heuristics:
        add_heuristics(decoder)
    return decoder