import logging
import pprint

from cam.sgnmt.blocks.alignment.nam import align_with_nam
from cam.sgnmt.blocks.alignment.nmt import align_with_nmt
from cam.sgnmt.blocks.nmt import blocks_get_default_nmt_config
from cam.sgnmt.misc.sparse import FileBasedFeatMap
from cam.sgnmt.output import CSVAlignmentOutputHandler, \
                             NPYAlignmentOutputHandler, \
                             TextAlignmentOutputHandler
from cam.sgnmt.ui import get_align_parser

logging.basicConfig(format='%(asctime)s %(levelname)s: %(message)s')
logging.getLogger().setLevel(logging.INFO)

parser = get_align_parser()
args = parser.parse_args()

# Get configuration
configuration = blocks_get_default_nmt_config()
for k in dir(args):
    if k in configuration:
        configuration[k] = getattr(args, k)
if configuration['src_sparse_feat_map']:
    configuration['src_sparse_feat_map'] = FileBasedFeatMap(
        configuration['enc_embed'],
        configuration['src_sparse_feat_map'])
if configuration['trg_sparse_feat_map']:
    configuration['trg_sparse_feat_map'] = FileBasedFeatMap(
        configuration['dec_embed'],
        configuration['trg_sparse_feat_map'])
logging.info("Model options:\n{}".format(pprint.pformat(configuration)))

# Align
if args.alignment_model == "nam":
    alignments = align_with_nam(configuration, args)
elif args.alignment_model == "nmt":
    alignments = align_with_nmt(configuration, args)
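
# A hedged, self-contained sketch (not part of the original script) of the
# args-to-config override pattern used above: every parsed command line
# attribute whose name matches a key in the default configuration dict
# replaces that key, so explicit flags win over defaults. The key and flag
# names below are illustrative assumptions, not real SGNMT options.
def _example_override_config():
    import argparse
    configuration = {'enc_embed': 620, 'dec_embed': 500}  # assumed defaults
    parser = argparse.ArgumentParser()
    parser.add_argument("--enc_embed", type=int, default=620)
    args = parser.parse_args(["--enc_embed", "1000"])
    for k in dir(args):
        if k in configuration:
            configuration[k] = getattr(args, k)
    return configuration  # {'enc_embed': 1000, 'dec_embed': 500}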
def _get_dataset_with_mono(mono_data_integration='exp3s',
                           backtrans_nmt_config='',
                           backtrans_store=True,
                           add_mono_dummy_data=True,
                           min_parallel_data=0.2,
                           backtrans_reload_frequency=0,
                           backtrans_max_same_word=0.3,
                           src_data='',
                           trg_data='',
                           src_mono_data='',
                           trg_mono_data='',
                           src_vocab_size=30000,
                           trg_vocab_size=30000,
                           src_sparse_feat_map='',
                           trg_sparse_feat_map='',
                           saveto='',
                           **kwargs):
    """Creates a parallel data stream which mixes in monolingual data.
    This is based on the ``ParallelSource`` framework in ``stream``.
    The arguments to this method are given by the configuration dict.
    """
    src_sens = stream.load_sentences_from_file(src_data, src_vocab_size)
    trg_sens = stream.load_sentences_from_file(trg_data, trg_vocab_size)
    trg_mono_sens = stream.load_sentences_from_file(trg_mono_data,
                                                    trg_vocab_size)
    # Build the config for the back-translation system from "key=value" pairs
    backtrans_config = blocks_get_default_nmt_config()
    if backtrans_nmt_config:
        for pair in backtrans_nmt_config.split(","):
            (k, v) = pair.split("=", 1)
            backtrans_config[k] = type(backtrans_config[k])(v)
    parallel_src = ShuffledParallelSource(src_sens, trg_sens)
    dummy_src = None
    if add_mono_dummy_data:
        dummy_src = DummyParallelSource(utils.GO_ID, trg_mono_sens)
    if backtrans_store:
        backtrans_file = "%s/backtrans.txt" % saveto
        old_backtrans_src = OldBacktranslatedParallelSource(backtrans_file)
        backtrans_src = BacktranslatedParallelSource(
            trg_mono_sens,
            backtrans_config,
            backtrans_file,
            backtrans_max_same_word,
            backtrans_reload_frequency,
            old_backtrans_src)
    else:
        backtrans_src = BacktranslatedParallelSource(
            trg_mono_sens,
            backtrans_config,
            None,
            backtrans_max_same_word,
            backtrans_reload_frequency)
    if min_parallel_data > 0.0:
        # Guarantee a minimum share of true parallel data in each source
        if add_mono_dummy_data:
            dummy_src = MergedParallelSource(parallel_src,
                                             dummy_src,
                                             min_parallel_data)
        backtrans_src = MergedParallelSource(parallel_src,
                                             backtrans_src,
                                             min_parallel_data)
        if backtrans_store:
            # old_backtrans_src is only defined when backtrans_store is set
            old_backtrans_src = MergedParallelSource(parallel_src,
                                                     old_backtrans_src,
                                                     min_parallel_data)
    sources = []
    sources.append(parallel_src)
    if add_mono_dummy_data:
        sources.append(dummy_src)
    sources.append(backtrans_src)
    if backtrans_store:
        sources.append(old_backtrans_src)
    return ParallelSourceSwitchDataset(
        sources,
        src_vocab_size,
        trg_vocab_size,
        src_sparse_feat_map=src_sparse_feat_map,
        trg_sparse_feat_map=trg_sparse_feat_map)
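
# Hedged usage sketch (not part of the original module): since
# ``_get_dataset_with_mono`` takes all of its parameters as keyword
# arguments and swallows extras via ``**kwargs``, the natural call site
# expands the configuration dict directly. All paths and values below are
# made up for illustration.
def _example_build_mono_dataset():
    config = {
        'backtrans_nmt_config': 'enc_embed=620,dec_embed=500',
        'backtrans_store': True,
        'add_mono_dummy_data': True,
        'min_parallel_data': 0.2,
        'src_data': './data/train.ids.src',      # assumed path
        'trg_data': './data/train.ids.trg',      # assumed path
        'trg_mono_data': './data/mono.ids.trg',  # assumed path
        'saveto': './train_dir',                 # assumed path
    }
    return _get_dataset_with_mono(**config)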
import logging
import pprint

from cam.sgnmt.blocks.alignment.nmt import align_with_nmt
from cam.sgnmt.blocks.nmt import blocks_get_default_nmt_config, \
                                 get_blocks_align_parser
from cam.sgnmt.misc.sparse import FileBasedFeatMap
from cam.sgnmt.output import CSVAlignmentOutputHandler, \
                             NPYAlignmentOutputHandler, \
                             TextAlignmentOutputHandler

logging.basicConfig(format='%(asctime)s %(levelname)s: %(message)s')
logging.getLogger().setLevel(logging.INFO)

parser = get_blocks_align_parser()
args = parser.parse_args()

# Get configuration
configuration = blocks_get_default_nmt_config()
for k in dir(args):
    if k in configuration:
        configuration[k] = getattr(args, k)
if configuration['src_sparse_feat_map']:
    configuration['src_sparse_feat_map'] = FileBasedFeatMap(
        configuration['enc_embed'],
        configuration['src_sparse_feat_map'])
if configuration['trg_sparse_feat_map']:
    configuration['trg_sparse_feat_map'] = FileBasedFeatMap(
        configuration['dec_embed'],
        configuration['trg_sparse_feat_map'])
logging.info("Model options:\n{}".format(pprint.pformat(configuration)))

# Align
if args.alignment_model == "nam":
def add_predictors(decoder):
    """Adds all enabled predictors to the ``decoder``. This function
    makes heavy use of the global ``args`` which contains the SGNMT
    configuration. Particularly, it reads out ``args.predictors`` and
    adds appropriate instances to ``decoder``.
    TODO: Refactor this method as it is waaaay tooooo looong

    Args:
        decoder (Decoder):  Decoding strategy, see ``create_decoder()``.
            This method will add predictors to this instance with
            ``add_predictor()``
    """
    preds = utils.split_comma(args.predictors)
    if not preds:
        logging.fatal("Require at least one predictor! See the --predictors "
                      "argument for more information.")
    weights = None
    if args.predictor_weights:
        weights = utils.split_comma(args.predictor_weights)
        if len(preds) != len(weights):
            logging.fatal("Specified %d predictors, but %d weights. Please "
                          "revise the --predictors and --predictor_weights "
                          "arguments" % (len(preds), len(weights)))
            return
    pred_weight = 1.0
    try:
        for idx, pred in enumerate(preds):  # Add predictors one by one
            wrappers = []
            if '_' in pred:
                # Handle weights when we have wrapper predictors
                wrappers = pred.split('_')
                pred = wrappers[-1]
                wrappers = wrappers[-2::-1]
                if weights:
                    wrapper_weights = [float(w)
                                       for w in weights[idx].split('_')]
                    pred_weight = wrapper_weights[-1]
                    wrapper_weights = wrapper_weights[-2::-1]
                else:
                    wrapper_weights = [1.0] * len(wrappers)
            elif weights:
                pred_weight = float(weights[idx])

            # Create predictor instances for the string argument ``pred``
            if pred == "nmt":
                nmt_engine = _get_override_args("nmt_engine")
                if nmt_engine == 'blocks':
                    nmt_config = _parse_config_param(
                        "nmt_config", blocks_get_default_nmt_config())
                    p = blocks_get_nmt_predictor(
                        args, _get_override_args("nmt_path"), nmt_config)
                elif nmt_engine == 'tensorflow':
                    nmt_config = _parse_config_param(
                        "nmt_config", tf_get_default_nmt_config())
                    p = tf_get_nmt_predictor(
                        args, _get_override_args("nmt_path"), nmt_config)
                elif nmt_engine != 'none':
                    logging.fatal("NMT engine %s is not supported (yet)!"
                                  % nmt_engine)
            elif pred == "nizza":
                p = NizzaPredictor(
                    _get_override_args("pred_src_vocab_size"),
                    _get_override_args("pred_trg_vocab_size"),
                    _get_override_args("nizza_model"),
                    _get_override_args("nizza_hparams_set"),
                    _get_override_args("nizza_checkpoint_dir"),
                    single_cpu_thread=args.single_cpu_thread)
            elif pred == "lexnizza":
                p = LexNizzaPredictor(
                    _get_override_args("pred_src_vocab_size"),
                    _get_override_args("pred_trg_vocab_size"),
                    _get_override_args("nizza_model"),
                    _get_override_args("nizza_hparams_set"),
                    _get_override_args("nizza_checkpoint_dir"),
                    single_cpu_thread=args.single_cpu_thread,
                    alpha=args.lexnizza_alpha,
                    beta=args.lexnizza_beta,
                    trg2src_model_name=args.lexnizza_trg2src_model,
                    trg2src_hparams_set_name=args.lexnizza_trg2src_hparams_set,
                    trg2src_checkpoint_dir=args.lexnizza_trg2src_checkpoint_dir,
                    shortlist_strategies=args.lexnizza_shortlist_strategies,
                    max_shortlist_length=args.lexnizza_max_shortlist_length,
                    min_id=args.lexnizza_min_id)
            elif pred == "t2t":
                p = T2TPredictor(
                    _get_override_args("pred_src_vocab_size"),
                    _get_override_args("pred_trg_vocab_size"),
                    _get_override_args("t2t_model"),
                    _get_override_args("t2t_problem"),
                    _get_override_args("t2t_hparams_set"),
                    args.t2t_usr_dir,
                    _get_override_args("t2t_checkpoint_dir"),
                    single_cpu_thread=args.single_cpu_thread,
                    max_terminal_id=args.syntax_max_terminal_id,
                    pop_id=args.syntax_pop_id)
            elif pred == "simt2t":
                p = SimT2TPredictor(
                    _get_override_args("pred_src_vocab_size"),
                    _get_override_args("pred_trg_vocab_size"),
                    _get_override_args("t2t_model"),
                    _get_override_args("t2t_problem"),
                    _get_override_args("t2t_hparams_set"),
                    args.t2t_usr_dir,
                    _get_override_args("t2t_checkpoint_dir"),
                    single_cpu_thread=args.single_cpu_thread,
                    max_terminal_id=args.syntax_max_terminal_id,
                    pop_id=args.syntax_pop_id)
            elif pred == "simt2tv2":
                p = SimT2TPredictor_v2(
                    _get_override_args("pred_src_vocab_size"),
                    _get_override_args("pred_trg_vocab_size"),
                    _get_override_args("t2t_model"),
                    _get_override_args("t2t_problem"),
                    _get_override_args("t2t_hparams_set"),
                    args.t2t_usr_dir,
                    _get_override_args("t2t_checkpoint_dir"),
                    single_cpu_thread=args.single_cpu_thread,
                    max_terminal_id=args.syntax_max_terminal_id,
                    pop_id=args.syntax_pop_id)
            elif pred == "bracket":
                p = BracketPredictor(args.syntax_max_terminal_id,
                                     args.syntax_pop_id,
                                     max_depth=args.syntax_max_depth,
                                     extlength_path=args.extlength_path)
            elif pred == "osm":
                p = OSMPredictor()
            elif pred == "forcedosm":
                p = ForcedOSMPredictor(args.trg_test)
            elif pred == "fst":
                p = FstPredictor(_get_override_args("fst_path"),
                                 args.use_fst_weights,
                                 args.normalize_fst_weights,
                                 skip_bos_weight=args.fst_skip_bos_weight,
                                 to_log=args.fst_to_log)
            elif pred == "nfst":
                p = NondeterministicFstPredictor(
                    _get_override_args("fst_path"),
                    args.use_fst_weights,
                    args.normalize_fst_weights,
                    args.fst_skip_bos_weight,
                    to_log=args.fst_to_log)
            elif pred == "forced":
                p = ForcedPredictor(args.trg_test)
            elif pred == "bow":
                p = BagOfWordsPredictor(
                    args.trg_test,
                    args.bow_accept_subsets,
                    args.bow_accept_duplicates,
                    args.heuristic_scores_file,
                    args.collect_statistics,
                    "consumed" in args.bow_heuristic_strategies,
                    "remaining" in args.bow_heuristic_strategies,
                    args.bow_diversity_heuristic_factor,
                    _get_override_args("pred_trg_vocab_size"))
            elif pred == "bowsearch":
                p = BagOfWordsSearchPredictor(
                    decoder,
                    args.hypo_recombination,
                    args.trg_test,
                    args.bow_accept_subsets,
                    args.bow_accept_duplicates,
                    args.heuristic_scores_file,
                    args.collect_statistics,
                    "consumed" in args.bow_heuristic_strategies,
                    "remaining" in args.bow_heuristic_strategies,
                    args.bow_diversity_heuristic_factor,
                    _get_override_args("pred_trg_vocab_size"))
            elif pred == "forcedlst":
                feat_name = _get_override_args("forcedlst_sparse_feat")
                p = ForcedLstPredictor(args.trg_test,
                                       args.use_nbest_weights,
                                       feat_name if feat_name else None)
            elif pred == "rtn":
                p = RtnPredictor(args.rtn_path,
                                 args.use_rtn_weights,
                                 args.normalize_rtn_weights,
                                 to_log=args.fst_to_log,
                                 minimize_rtns=args.minimize_rtns,
                                 rmeps=args.remove_epsilon_in_rtns)
            elif pred == "srilm":
                p = SRILMPredictor(args.srilm_path,
                                   args.srilm_order,
                                   args.srilm_convert_to_ln)
            elif pred == "nplm":
                p = NPLMPredictor(args.nplm_path, args.normalize_nplm_probs)
            elif pred == "rnnlm":
                p = tf_get_rnnlm_predictor(_get_override_args("rnnlm_path"),
                                           _get_override_args("rnnlm_config"),
                                           tf_get_rnnlm_prefix())
            elif pred == "wc":
                p = WordCountPredictor(args.wc_word)
            elif pred == "ngramc":
                p = NgramCountPredictor(_get_override_args("ngramc_path"),
                                        _get_override_args("ngramc_order"),
                                        args.ngramc_discount_factor)
            elif pred == "unkc":
                p = UnkCountPredictor(
                    _get_override_args("pred_src_vocab_size"),
                    [float(l) for l in args.unk_count_lambdas.split(',')])
            elif pred == "length":
                length_model_weights = [
                    float(w) for w in args.length_model_weights.split(',')]
                p = NBLengthPredictor(args.src_test_raw,
                                      length_model_weights,
                                      args.use_length_point_probs,
                                      args.length_model_offset)
            elif pred == "extlength":
                p = ExternalLengthPredictor(args.extlength_path)
            elif pred == "lrhiero":
                fw = None
                if args.grammar_feature_weights:
                    fw = [float(w)
                          for w in args.grammar_feature_weights.split(',')]
                p = RuleXtractPredictor(args.rules_path,
                                        args.use_grammar_weights,
                                        fw)
            elif pred == "vanilla":
                continue
            else:
                logging.fatal("Predictor '%s' not available. Please check "
                              "--predictors for spelling errors." % pred)
                decoder.remove_predictors()
                return
            for _, wrapper in enumerate(wrappers):
                # Embed predictor ``p`` into wrapper predictors if necessary
                # TODO: Use wrapper_weights
                if wrapper == "idxmap":
                    src_path = _get_override_args("src_idxmap")
                    trg_path = _get_override_args("trg_idxmap")
                    if isinstance(p, UnboundedVocabularyPredictor):
                        p = UnboundedIdxmapPredictor(src_path, trg_path,
                                                     p, 1.0)
                    else:  # idxmap predictor for bounded predictors
                        p = IdxmapPredictor(src_path, trg_path, p, 1.0)
                elif wrapper == "altsrc":
                    src_test = _get_override_args("altsrc_test")
                    if isinstance(p, UnboundedVocabularyPredictor):
                        p = UnboundedAltsrcPredictor(src_test, p)
                    else:  # altsrc predictor for bounded predictors
                        p = AltsrcPredictor(src_test, p)
                elif wrapper == "word2char":
                    map_path = _get_override_args("word2char_map")
                    # word2char always wraps unbounded predictors
                    p = Word2charPredictor(map_path, p)
                elif wrapper == "skipvocab":
                    # skipvocab always wraps unbounded predictors
                    p = SkipvocabPredictor(args.skipvocab_max_id,
                                           args.skipvocab_stop_size,
                                           args.beam,
                                           p)
                elif wrapper == "fsttok":
                    fsttok_path = _get_override_args("fsttok_path")
                    # fsttok always wraps unbounded predictors
                    p = FSTTokPredictor(fsttok_path,
                                        args.fst_unk_id,
                                        args.fsttok_max_pending_score,
                                        p)
                elif wrapper == "ngramize":
                    # ngramize always wraps bounded predictors
                    p = NgramizePredictor(args.min_ngram_order,
                                          args.max_ngram_order,
                                          args.max_len_factor,
                                          p)
                elif wrapper == "unkvocab":
                    # unkvocab always wraps bounded predictors
                    p = UnkvocabPredictor(args.trg_vocab_size, p)
                else:
                    logging.fatal("Predictor wrapper '%s' not available. "
                                  "Please double-check --predictors for "
                                  "spelling errors." % wrapper)
                    decoder.remove_predictors()
                    return
            decoder.add_predictor(pred, p, pred_weight)
            logging.info("Initialized predictor {} (weight: {})".format(
                pred, pred_weight))
    except IOError as e:
        logging.fatal("One of the files required for setting up the "
                      "predictors could not be read: %s" % e)
        decoder.remove_predictors()
    except AttributeError as e:
        logging.fatal("Invalid argument for one of the predictors: %s" % e)
        decoder.remove_predictors()
    except NameError as e:
        logging.fatal("Could not find external library: %s. Please make sure "
                      "that your PYTHONPATH and LD_LIBRARY_PATH contains all "
                      "paths required for the predictors. Stack trace: %s"
                      % (e, traceback.format_exc()))
        decoder.remove_predictors()
    except ValueError as e:
        logging.fatal("A number format error occurred while configuring the "
                      "predictors: %s. Please double-check all integer- or "
                      "float-valued parameters such as --predictor_weights "
                      "and try again. Stack trace: %s"
                      % (e, traceback.format_exc()))
        decoder.remove_predictors()
    except Exception as e:
        logging.fatal("An unexpected %s has occurred while setting up the "
                      "predictors: %s Stack trace: %s"
                      % (sys.exc_info()[0], e, traceback.format_exc()))
        decoder.remove_predictors()
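
# Hedged illustration (not part of the original module) of the predictor
# string convention parsed at the top of ``add_predictors``: in a string
# like "idxmap_word2char_nmt" the last element names the inner predictor and
# everything before it is a wrapper, with the leftmost wrapper applied
# outermost. A weight string like "0.5_0.6_1.0" is split the same way, and
# its last element becomes the weight of the inner predictor.
def _example_parse_predictor_string(pred="idxmap_word2char_nmt",
                                    weight="0.5_0.6_1.0"):
    wrappers = pred.split('_')
    inner_pred = wrappers[-1]    # "nmt"
    # Reverse so that the wrapper closest to the predictor comes first,
    # matching the inside-out wrapping loop above:
    wrappers = wrappers[-2::-1]  # ["word2char", "idxmap"]
    wrapper_weights = [float(w) for w in weight.split('_')]
    pred_weight = wrapper_weights[-1]          # 1.0
    wrapper_weights = wrapper_weights[-2::-1]  # [0.6, 0.5]
    return inner_pred, wrappers, pred_weight, wrapper_weights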
def add_predictors(decoder):
    """Adds all enabled predictors to the ``decoder``. This function
    makes heavy use of the global ``args`` which contains the SGNMT
    configuration. Particularly, it reads out ``args.predictors`` and
    adds appropriate instances to ``decoder``.
    TODO: Refactor this method as it is waaaay tooooo looong

    Args:
        decoder (Decoder):  Decoding strategy, see ``create_decoder()``.
            This method will add predictors to this instance with
            ``add_predictor()``
    """
    preds = utils.split_comma(args.predictors)
    if not preds:
        logging.fatal("Require at least one predictor! See the --predictors "
                      "argument for more information.")
    weights = None
    if args.predictor_weights:
        weights = utils.split_comma(args.predictor_weights)
        if len(preds) != len(weights):
            logging.fatal("Specified %d predictors, but %d weights. Please "
                          "revise the --predictors and --predictor_weights "
                          "arguments" % (len(preds), len(weights)))
            return
    pred_weight = 1.0
    try:
        for idx, pred in enumerate(preds):  # Add predictors one by one
            wrappers = []
            if '_' in pred:
                # Handle weights when we have wrapper predictors
                wrappers = pred.split('_')
                pred = wrappers[-1]
                wrappers = wrappers[-2::-1]
                if weights:
                    wrapper_weights = [float(w)
                                       for w in weights[idx].split('_')]
                    pred_weight = wrapper_weights[-1]
                    wrapper_weights = wrapper_weights[-2::-1]
                else:
                    wrapper_weights = [1.0] * len(wrappers)
            elif weights:
                pred_weight = float(weights[idx])

            # Create predictor instances for the string argument ``pred``
            if pred == "nmt":
                # TODO: Clean this up and make a blocks and tfnmt predictor
                nmt_engine = _get_override_args("nmt_engine")
                if nmt_engine == 'blocks':
                    nmt_config = _parse_config_param(
                        "nmt_config", blocks_get_default_nmt_config())
                    p = blocks_get_nmt_predictor(
                        args, _get_override_args("nmt_path"), nmt_config)
                elif nmt_engine == 'tensorflow':
                    nmt_config = _parse_config_param(
                        "nmt_config", tf_get_default_nmt_config())
                    p = tf_get_nmt_predictor(
                        args, _get_override_args("nmt_path"), nmt_config)
                elif nmt_engine != 'none':
                    logging.fatal("NMT engine %s is not supported (yet)!"
                                  % nmt_engine)
            elif pred == "nizza":
                p = NizzaPredictor(
                    _get_override_args("pred_src_vocab_size"),
                    _get_override_args("pred_trg_vocab_size"),
                    _get_override_args("nizza_model"),
                    _get_override_args("nizza_hparams_set"),
                    _get_override_args("nizza_checkpoint_dir"),
                    n_cpu_threads=args.n_cpu_threads)
            elif pred == "lexnizza":
                p = LexNizzaPredictor(
                    _get_override_args("pred_src_vocab_size"),
                    _get_override_args("pred_trg_vocab_size"),
                    _get_override_args("nizza_model"),
                    _get_override_args("nizza_hparams_set"),
                    _get_override_args("nizza_checkpoint_dir"),
                    n_cpu_threads=args.n_cpu_threads,
                    alpha=args.lexnizza_alpha,
                    beta=args.lexnizza_beta,
                    trg2src_model_name=args.lexnizza_trg2src_model,
                    trg2src_hparams_set_name=args.lexnizza_trg2src_hparams_set,
                    trg2src_checkpoint_dir=args.lexnizza_trg2src_checkpoint_dir,
                    shortlist_strategies=args.lexnizza_shortlist_strategies,
                    max_shortlist_length=args.lexnizza_max_shortlist_length,
                    min_id=args.lexnizza_min_id)
            elif pred == "t2t":
                p = T2TPredictor(
                    _get_override_args("pred_src_vocab_size"),
                    _get_override_args("pred_trg_vocab_size"),
                    _get_override_args("t2t_model"),
                    _get_override_args("t2t_problem"),
                    _get_override_args("t2t_hparams_set"),
                    args.t2t_usr_dir,
                    _get_override_args("t2t_checkpoint_dir"),
                    t2t_unk_id=_get_override_args("t2t_unk_id"),
                    n_cpu_threads=args.n_cpu_threads,
                    max_terminal_id=args.syntax_max_terminal_id,
                    pop_id=args.syntax_pop_id)
            elif pred == "segt2t":
                p = SegT2TPredictor(
                    _get_override_args("pred_src_vocab_size"),
                    _get_override_args("pred_trg_vocab_size"),
                    _get_override_args("t2t_model"),
                    _get_override_args("t2t_problem"),
                    _get_override_args("t2t_hparams_set"),
                    args.t2t_usr_dir,
                    _get_override_args("t2t_checkpoint_dir"),
                    t2t_unk_id=_get_override_args("t2t_unk_id"),
                    n_cpu_threads=args.n_cpu_threads,
                    max_terminal_id=args.syntax_max_terminal_id,
                    pop_id=args.syntax_pop_id)
            elif pred == "editt2t":
                p = EditT2TPredictor(
                    _get_override_args("pred_src_vocab_size"),
                    _get_override_args("pred_trg_vocab_size"),
                    _get_override_args("t2t_model"),
                    _get_override_args("t2t_problem"),
                    _get_override_args("t2t_hparams_set"),
                    args.trg_test,
                    args.beam,
                    args.t2t_usr_dir,
                    _get_override_args("t2t_checkpoint_dir"),
                    t2t_unk_id=_get_override_args("t2t_unk_id"),
                    n_cpu_threads=args.n_cpu_threads,
                    max_terminal_id=args.syntax_max_terminal_id,
                    pop_id=args.syntax_pop_id)
            elif pred == "fertt2t":
                p = FertilityT2TPredictor(
                    _get_override_args("pred_src_vocab_size"),
                    _get_override_args("pred_trg_vocab_size"),
                    _get_override_args("t2t_model"),
                    _get_override_args("t2t_problem"),
                    _get_override_args("t2t_hparams_set"),
                    args.t2t_usr_dir,
                    _get_override_args("t2t_checkpoint_dir"),
                    n_cpu_threads=args.n_cpu_threads,
                    max_terminal_id=args.syntax_max_terminal_id,
                    pop_id=args.syntax_pop_id)
            elif pred == "bracket":
                p = BracketPredictor(args.syntax_max_terminal_id,
                                     args.syntax_pop_id,
                                     max_depth=args.syntax_max_depth,
                                     extlength_path=args.extlength_path)
            elif pred == "osm":
                p = OSMPredictor(args.osm_type)
            elif pred == "forcedosm":
                p = ForcedOSMPredictor(args.trg_test)
            elif pred == "fst":
                p = FstPredictor(_get_override_args("fst_path"),
                                 args.use_fst_weights,
                                 args.normalize_fst_weights,
                                 skip_bos_weight=args.fst_skip_bos_weight,
                                 to_log=args.fst_to_log)
            elif pred == "nfst":
                p = NondeterministicFstPredictor(
                    _get_override_args("fst_path"),
                    args.use_fst_weights,
                    args.normalize_fst_weights,
                    args.fst_skip_bos_weight,
                    to_log=args.fst_to_log)
            elif pred == "forced":
                p = ForcedPredictor(args.trg_test)
            elif pred == "bow":
                p = BagOfWordsPredictor(
                    args.trg_test,
                    args.bow_accept_subsets,
                    args.bow_accept_duplicates,
                    args.heuristic_scores_file,
                    args.collect_statistics,
                    "consumed" in args.bow_heuristic_strategies,
                    "remaining" in args.bow_heuristic_strategies,
                    args.bow_diversity_heuristic_factor,
                    _get_override_args("pred_trg_vocab_size"))
            elif pred == "bowsearch":
                p = BagOfWordsSearchPredictor(
                    decoder,
                    args.hypo_recombination,
                    args.trg_test,
                    args.bow_accept_subsets,
                    args.bow_accept_duplicates,
                    args.heuristic_scores_file,
                    args.collect_statistics,
                    "consumed" in args.bow_heuristic_strategies,
                    "remaining" in args.bow_heuristic_strategies,
                    args.bow_diversity_heuristic_factor,
                    _get_override_args("pred_trg_vocab_size"))
            elif pred == "forcedlst":
                feat_name = _get_override_args("forcedlst_sparse_feat")
                p = ForcedLstPredictor(args.trg_test,
                                       args.use_nbest_weights,
                                       args.forcedlst_match_unk,
                                       feat_name if feat_name else None)
            elif pred == "rtn":
                p = RtnPredictor(args.rtn_path,
                                 args.use_rtn_weights,
                                 args.normalize_rtn_weights,
                                 to_log=args.fst_to_log,
                                 minimize_rtns=args.minimize_rtns,
                                 rmeps=args.remove_epsilon_in_rtns)
            elif pred == "srilm":
                p = SRILMPredictor(args.lm_path,
                                   _get_override_args("ngramc_order"),
                                   args.srilm_convert_to_ln)
            elif pred == "kenlm":
                p = KenLMPredictor(args.lm_path)
            elif pred == "nplm":
                p = NPLMPredictor(args.nplm_path, args.normalize_nplm_probs)
            elif pred == "rnnlm":
                p = tf_get_rnnlm_predictor(_get_override_args("rnnlm_path"),
                                           _get_override_args("rnnlm_config"),
                                           tf_get_rnnlm_prefix())
            elif pred == "wc":
                p = WordCountPredictor(args.wc_word,
                                       args.wc_nonterminal_penalty,
                                       args.syntax_nonterminal_ids,
                                       args.syntax_min_terminal_id,
                                       args.syntax_max_terminal_id,
                                       args.negative_wc,
                                       _get_override_args("pred_trg_vocab_size"))
            elif pred == "ngramc":
                p = NgramCountPredictor(_get_override_args("ngramc_path"),
                                        _get_override_args("ngramc_order"),
                                        args.ngramc_discount_factor)
            elif pred == "unkc":
                p = UnkCountPredictor(
                    _get_override_args("pred_src_vocab_size"),
                    utils.split_comma(args.unk_count_lambdas, float))
            elif pred == "length":
                length_model_weights = utils.split_comma(
                    args.length_model_weights, float)
                p = NBLengthPredictor(args.src_test_raw,
                                      length_model_weights,
                                      args.use_length_point_probs,
                                      args.length_model_offset)
            elif pred == "extlength":
                p = ExternalLengthPredictor(args.extlength_path)
            elif pred == "lrhiero":
                fw = None
                if args.grammar_feature_weights:
                    fw = utils.split_comma(args.grammar_feature_weights,
                                           float)
                p = RuleXtractPredictor(args.rules_path,
                                        args.use_grammar_weights,
                                        fw)
            elif pred == "vanilla":
                continue
            else:
                logging.fatal("Predictor '%s' not available. Please check "
                              "--predictors for spelling errors." % pred)
                decoder.remove_predictors()
                return
            for _, wrapper in enumerate(wrappers):
                # Embed predictor ``p`` into wrapper predictors if necessary
                # TODO: Use wrapper_weights
                if wrapper == "idxmap":
                    src_path = _get_override_args("src_idxmap")
                    trg_path = _get_override_args("trg_idxmap")
                    if isinstance(p, UnboundedVocabularyPredictor):
                        p = UnboundedIdxmapPredictor(src_path, trg_path,
                                                     p, 1.0)
                    else:  # idxmap predictor for bounded predictors
                        p = IdxmapPredictor(src_path, trg_path, p, 1.0)
                elif wrapper == "maskvocab":
                    words = utils.split_comma(args.maskvocab_words, int)
                    if isinstance(p, UnboundedVocabularyPredictor):
                        p = UnboundedMaskvocabPredictor(words, p)
                    else:  # maskvocab predictor for bounded predictors
                        p = MaskvocabPredictor(words, p)
                elif wrapper == "weightnt":
                    p = WeightNonTerminalPredictor(
                        p,
                        args.syntax_nonterminal_factor,
                        args.syntax_nonterminal_ids,
                        args.syntax_min_terminal_id,
                        args.syntax_max_terminal_id,
                        _get_override_args("pred_trg_vocab_size"))
                elif wrapper == "parse":
                    if args.parse_tok_grammar:
                        if args.parse_bpe_path:
                            p = BpeParsePredictor(
                                args.syntax_path,
                                args.syntax_bpe_path,
                                p,
                                args.syntax_word_out,
                                args.normalize_fst_weights,
                                norm_alpha=args.syntax_norm_alpha,
                                beam_size=args.syntax_internal_beam,
                                max_internal_len=args.syntax_max_internal_len,
                                allow_early_eos=args.syntax_allow_early_eos,
                                consume_out_of_class=args.syntax_consume_ooc,
                                terminal_restrict=args.syntax_terminal_restrict,
                                internal_only_restrict=args.syntax_internal_only,
                                eow_ids=args.syntax_eow_ids,
                                terminal_ids=args.syntax_terminal_ids)
                        else:
                            p = TokParsePredictor(
                                args.syntax_path,
                                p,
                                args.syntax_word_out,
                                args.normalize_fst_weights,
                                norm_alpha=args.syntax_norm_alpha,
                                beam_size=args.syntax_internal_beam,
                                max_internal_len=args.syntax_max_internal_len,
                                allow_early_eos=args.syntax_allow_early_eos,
                                consume_out_of_class=args.syntax_consume_ooc)
                    else:
                        p = ParsePredictor(
                            p,
                            args.normalize_fst_weights,
                            beam_size=args.syntax_internal_beam,
                            max_internal_len=args.syntax_max_internal_len,
                            nonterminal_ids=args.syntax_nonterminal_ids)
                elif wrapper == "altsrc":
                    src_test = _get_override_args("altsrc_test")
                    if isinstance(p, UnboundedVocabularyPredictor):
                        p = UnboundedAltsrcPredictor(src_test, p)
                    else:  # altsrc predictor for bounded predictors
                        p = AltsrcPredictor(src_test, p)
                elif wrapper == "rank":
                    if isinstance(p, UnboundedVocabularyPredictor):
                        p = UnboundedRankPredictor(p)
                    else:  # rank predictor for bounded predictors
                        p = RankPredictor(p)
                elif wrapper == "glue":
                    if isinstance(p, UnboundedVocabularyPredictor):
                        p = UnboundedGluePredictor(args.max_len_factor, p)
                    else:  # glue predictor for bounded predictors
                        p = GluePredictor(args.max_len_factor, p)
                elif wrapper == "word2char":
                    map_path = _get_override_args("word2char_map")
                    # word2char always wraps unbounded predictors
                    p = Word2charPredictor(map_path, p)
                elif wrapper == "skipvocab":
                    # skipvocab always wraps unbounded predictors
                    p = SkipvocabPredictor(args.skipvocab_max_id,
                                           args.skipvocab_stop_size,
                                           args.beam,
                                           p)
                elif wrapper == "fsttok":
                    fsttok_path = _get_override_args("fsttok_path")
                    # fsttok always wraps unbounded predictors
                    p = FSTTokPredictor(fsttok_path,
                                        args.fst_unk_id,
                                        args.fsttok_max_pending_score,
                                        p)
                elif wrapper == "ngramize":
                    # ngramize always wraps bounded predictors
                    p = NgramizePredictor(args.min_ngram_order,
                                          args.max_ngram_order,
                                          args.max_len_factor,
                                          p)
                elif wrapper == "unkvocab":
                    # unkvocab always wraps bounded predictors
                    p = UnkvocabPredictor(args.trg_vocab_size, p)
                else:
                    logging.fatal("Predictor wrapper '%s' not available. "
                                  "Please double-check --predictors for "
                                  "spelling errors." % wrapper)
                    decoder.remove_predictors()
                    return
            decoder.add_predictor(pred, p, pred_weight)
            logging.info("Initialized predictor {} (weight: {})".format(
                pred, pred_weight))
    except IOError as e:
        logging.fatal("One of the files required for setting up the "
                      "predictors could not be read: %s" % e)
        decoder.remove_predictors()
    except AttributeError as e:
        logging.fatal("Invalid argument for one of the predictors: %s. "
                      "Stack trace: %s" % (e, traceback.format_exc()))
        decoder.remove_predictors()
    except NameError as e:
        logging.fatal("Could not find external library: %s. Please make sure "
                      "that your PYTHONPATH and LD_LIBRARY_PATH contains all "
                      "paths required for the predictors. Stack trace: %s"
                      % (e, traceback.format_exc()))
        decoder.remove_predictors()
    except ValueError as e:
        logging.fatal("A number format error occurred while configuring the "
                      "predictors: %s. Please double-check all integer- or "
                      "float-valued parameters such as --predictor_weights "
                      "and try again. Stack trace: %s"
                      % (e, traceback.format_exc()))
        decoder.remove_predictors()
    except Exception as e:
        logging.fatal("An unexpected %s has occurred while setting up the "
                      "predictors: %s Stack trace: %s"
                      % (sys.exc_info()[0], e, traceback.format_exc()))
        decoder.remove_predictors()
# Print out result
for task in sorted(finished_tasks, key=lambda t: t.sen_id):
    logging.info("Decoded (ID: %d): %s" % (
        task.sen_id,
        ' '.join([str(w) for w in task.get_best_translation()])))
    logging.info("Stats (ID: %d): %s" % (task.sen_id,
                                         task.get_stats_string()))
logging.info("Decoding finished. Time: %.6f" % (stop_time - start_time))
os.system('kill %d' % os.getpid())


# MAIN ENTRY POINT

# Get configuration
config = blocks_get_default_nmt_config()
for k in dir(args):
    if k in config:
        config[k] = getattr(args, k)
logging.info("Model options:\n{}".format(pprint.pformat(config)))

np.show_config()
nmt_model = NMTModel(config)
nmt_model.set_up()
loader = LoadNMTUtils(get_nmt_model_path_best_bleu(config),
                      config['saveto'],
                      nmt_model.search_model)
loader.load_weights()

src_sentences = load_sentences(args.src_test,