Example #1
def add_predictors(decoder):
    """Adds all enabled predictors to the ``decoder``. This function 
    makes heavy use of the global ``args`` which contains the
    SGNMT configuration. Particularly, it reads out ``args.predictors``
    and adds appropriate instances to ``decoder``.
    TODO: Refactor this method as it is waaaay tooooo looong
    
    Args:
        decoder (Decoder):  Decoding strategy, see ``create_decoder()``.
            This method will add predictors to this instance with
            ``add_predictor()``
    """
    preds = args.predictors.split(",")
    if not preds:
        logging.fatal("Require at least one predictor! See the --predictors "
                      "argument for more information.")
        return
    weights = None
    if args.predictor_weights:
        weights = args.predictor_weights.strip().split(",")
        if len(preds) != len(weights):
            logging.fatal("Specified %d predictors, but %d weights. Please "
                          "revise the --predictors and --predictor_weights "
                          "arguments" % (len(preds), len(weights)))
            return

    pred_weight = 1.0
    try:
        for idx, pred in enumerate(preds):  # Add predictors one by one
            wrappers = []
            if '_' in pred:
                # Handle weights when we have wrapper predictors. The
                # rightmost token is the base predictor; the preceding
                # tokens are wrappers, reversed so that the leftmost
                # wrapper ends up outermost after the wrapping loop below.
                wrappers = pred.split('_')
                pred = wrappers[-1]
                wrappers = wrappers[-2::-1]
                if weights:
                    wrapper_weights = [
                        float(w) for w in weights[idx].split('_')
                    ]
                    pred_weight = wrapper_weights[-1]
                    wrapper_weights = wrapper_weights[-2::-1]
                else:
                    wrapper_weights = [1.0] * len(wrappers)
            elif weights:
                pred_weight = float(weights[idx])

            # Create predictor instances for the string argument ``pred``
            if pred == "nmt":
                nmt_engine = _get_override_args("nmt_engine")
                if nmt_engine == 'blocks':
                    get_nmt_predictor = blocks_get_nmt_predictor
                    get_default_nmt_config = blocks_get_default_nmt_config
                elif nmt_engine == 'tensorflow':
                    get_nmt_predictor = tf_get_nmt_predictor
                    get_default_nmt_config = tf_get_default_nmt_config
                elif nmt_engine != 'none':
                    logging.fatal("NMT engine %s is not supported (yet)!" %
                                  nmt_engine)
                # If the engine is 'none' or unsupported, get_nmt_predictor
                # is normally still unbound here, so the call below raises
                # NameError, which the handler at the end of this function
                # catches and turns into remove_predictors().
                p = get_nmt_predictor(
                    args, _get_override_args("nmt_path"),
                    _parse_config_param("nmt_config",
                                        get_default_nmt_config()))
            elif pred == "fst":
                p = FstPredictor(_get_override_args("fst_path"),
                                 args.use_fst_weights,
                                 args.normalize_fst_weights,
                                 skip_bos_weight=args.fst_skip_bos_weight,
                                 to_log=args.fst_to_log)
            elif pred == "nfst":
                p = NondeterministicFstPredictor(
                    _get_override_args("fst_path"),
                    args.use_fst_weights,
                    args.normalize_fst_weights,
                    args.fst_skip_bos_weight,
                    to_log=args.fst_to_log)
            elif pred == "forced":
                p = ForcedPredictor(args.trg_test)
            elif pred == "bow":
                p = BagOfWordsPredictor(
                    args.trg_test,
                    args.bow_accept_subsets,
                    args.bow_accept_duplicates,
                    args.heuristic_scores_file,
                    args.collect_statistics,
                    "consumed" in args.bow_heuristic_strategies,
                    "remaining" in args.bow_heuristic_strategies,
                    args.bow_diversity_heuristic_factor,
                    args.bow_equivalence_vocab_size)
            elif pred == "bowsearch":
                p = BagOfWordsSearchPredictor(
                    decoder,
                    args.hypo_recombination,
                    args.trg_test,
                    args.bow_accept_subsets,
                    args.bow_accept_duplicates,
                    args.heuristic_scores_file,
                    args.collect_statistics,
                    "consumed" in args.bow_heuristic_strategies,
                    "remaining" in args.bow_heuristic_strategies,
                    args.bow_diversity_heuristic_factor,
                    args.bow_equivalence_vocab_size)
            elif pred == "forcedlst":
                feat_name = _get_override_args("forcedlst_sparse_feat")
                p = ForcedLstPredictor(args.trg_test, args.use_nbest_weights,
                                       feat_name if feat_name else None)
            elif pred == "rtn":
                p = RtnPredictor(args.rtn_path,
                                 args.use_rtn_weights,
                                 args.normalize_rtn_weights,
                                 to_log=args.fst_to_log,
                                 minimize_rtns=args.minimize_rtns,
                                 rmeps=args.remove_epsilon_in_rtns)
            elif pred == "srilm":
                p = SRILMPredictor(args.srilm_path, args.srilm_order,
                                   args.srilm_convert_to_ln)
            elif pred == "nplm":
                p = NPLMPredictor(args.nplm_path, args.normalize_nplm_probs)
            elif pred == "rnnlm":
                p = tf_get_rnnlm_predictor(_get_override_args("rnnlm_path"),
                                           _get_override_args("rnnlm_config"),
                                           tf_get_rnnlm_prefix())
            elif pred == "lstm":
                p = ChainerLstmPredictor(args.lstm_path)
            elif pred == "wc":
                p = WordCountPredictor(args.wc_word)
            elif pred == "ngramc":
                p = NgramCountPredictor(_get_override_args("ngramc_path"),
                                        _get_override_args("ngramc_order"),
                                        args.ngramc_discount_factor)
            elif pred == "unkc":
                p = UnkCountPredictor(
                    args.unkc_src_vocab_size,
                    [float(lam) for lam in args.unk_count_lambdas.split(',')])
            elif pred == "length":
                length_model_weights = [
                    float(w) for w in args.length_model_weights.split(',')
                ]
                p = NBLengthPredictor(args.src_test_raw, length_model_weights,
                                      args.use_length_point_probs,
                                      args.length_model_offset)
            elif pred == "extlength":
                p = ExternalLengthPredictor(args.extlength_path)
            elif pred == "lrhiero":
                fw = None
                if args.grammar_feature_weights:
                    fw = [
                        float(w)
                        for w in args.grammar_feature_weights.split(',')
                    ]
                p = RuleXtractPredictor(args.rules_path,
                                        args.use_grammar_weights, fw)
            elif pred == "vanilla":
                continue
            else:
                logging.fatal("Predictor '%s' not available. Please check "
                              "--predictors for spelling errors." % pred)
                decoder.remove_predictors()
                return
            for wrapper in wrappers:
                # Embed predictor ``p`` into wrapper predictors if necessary
                # TODO: Use wrapper_weights
                if wrapper == "idxmap":
                    src_path = _get_override_args("src_idxmap")
                    trg_path = _get_override_args("trg_idxmap")
                    if isinstance(p, UnboundedVocabularyPredictor):
                        p = UnboundedIdxmapPredictor(src_path, trg_path, p,
                                                     1.0)
                    else:  # idxmap predictor for bounded predictors
                        p = IdxmapPredictor(src_path, trg_path, p, 1.0)
                elif wrapper == "altsrc":
                    src_test = _get_override_args("altsrc_test")
                    if isinstance(p, UnboundedVocabularyPredictor):
                        p = UnboundedAltsrcPredictor(src_test, p)
                    else:  # altsrc predictor for bounded predictors
                        p = AltsrcPredictor(src_test, p)
                elif wrapper == "word2char":
                    map_path = _get_override_args("word2char_map")
                    # word2char always wraps unbounded predictors
                    p = Word2charPredictor(map_path, p)
                elif wrapper == "fsttok":
                    fsttok_path = _get_override_args("fsttok_path")
                    # fsttok always wraps unbounded predictors
                    p = FSTTokPredictor(fsttok_path, p)
                elif wrapper == "unkvocab":
                    # unkvocab always wraps bounded predictors
                    p = UnkvocabPredictor(args.trg_vocab_size, p)
                else:
                    logging.fatal("Predictor wrapper '%s' not available. "
                                  "Please double-check --predictors for "
                                  "spelling errors." % wrapper)
                    decoder.remove_predictors()
                    return
            decoder.add_predictor(pred, p, pred_weight)
            logging.info("Added predictor {} with weight {}".format(
                pred, pred_weight))
    except IOError as e:
        logging.fatal("One of the files required for setting up the "
                      "predictors could not be read: %s" % e)
        decoder.remove_predictors()
    except NameError as e:
        logging.fatal("Could not find external library: %s. Please make sure "
                      "that your PYTHONPATH and LD_LIBRARY_PATH contains all "
                      "paths required for the predictors." % e)
        decoder.remove_predictors()
    except ValueError as e:
        logging.fatal("A number format error occurred while configuring the "
                      "predictors: %s. Please double-check all integer- or "
                      "float-valued parameters such as --predictor_weights and"
                      " try again. Stack trace: %s" %
                      (e, traceback.format_exc()))
        decoder.remove_predictors()
    except Exception as e:
        logging.fatal("An unexpected %s has occurred while setting up the pre"
                      "dictors: %s Stack trace: %s" %
                      (sys.exc_info()[0], e, traceback.format_exc()))
        decoder.remove_predictors()
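
The underscore syntax parsed at the top of the loop is easy to get backwards: in a spec such as "fsttok_word2char_nmt", the rightmost token names the base predictor, and the remaining tokens are wrappers that the loop applies innermost-out, so the leftmost wrapper ends up outermost. A minimal sketch of that parsing (the helper name parse_predictor_spec is made up for illustration; SGNMT does this inline, as shown above):

def parse_predictor_spec(spec, weight_spec=None):
    """Hypothetical helper: split a predictor spec such as
    'fsttok_word2char_nmt' into the base predictor and its wrappers.
    The slicing mirrors the inline parsing in add_predictors() above."""
    parts = spec.split('_')
    pred = parts[-1]          # rightmost token is the base predictor
    wrappers = parts[-2::-1]  # remaining tokens, reversed: the wrapping
                              # loop applies them inner to outer, so the
                              # leftmost token becomes the outermost wrapper
    if weight_spec is None:
        pred_weight = 1.0
        wrapper_weights = [1.0] * len(wrappers)
    else:
        weights = [float(w) for w in weight_spec.split('_')]
        pred_weight = weights[-1]
        wrapper_weights = weights[-2::-1]
    return pred, wrappers, pred_weight, wrapper_weights

# parse_predictor_spec("fsttok_word2char_nmt", "0.1_0.2_0.7")
# -> ('nmt', ['word2char', 'fsttok'], 0.7, [0.2, 0.1])

Note that wrapper_weights is computed but, per the TODO in the wrapping loop, not yet used when the wrappers are embedded; the idxmap wrappers, for instance, are hard-coded to weight 1.0.
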
Example #2
def add_predictors(decoder):
    """Adds all enabled predictors to the ``decoder``. This function 
    makes heavy use of the global ``args`` which contains the
    SGNMT configuration. Particularly, it reads out ``args.predictors``
    and adds appropriate instances to ``decoder``.
    TODO: Refactor this method as it is waaaay tooooo looong
    
    Args:
        decoder (Decoder):  Decoding strategy, see ``create_decoder()``.
            This method will add predictors to this instance with
            ``add_predictor()``
    """
    preds = utils.split_comma(args.predictors)
    if not preds:
        logging.fatal("Require at least one predictor! See the --predictors "
                      "argument for more information.")
        return
    weights = None
    if args.predictor_weights:
        weights = utils.split_comma(args.predictor_weights)
        if len(preds) != len(weights):
            logging.fatal("Specified %d predictors, but %d weights. Please "
                      "revise the --predictors and --predictor_weights "
                      "arguments" % (len(preds), len(weights)))
            return
    
    pred_weight = 1.0
    try:
        for idx, pred in enumerate(preds):  # Add predictors one by one
            wrappers = []
            if '_' in pred:
                # Handle weights when we have wrapper predictors. See the
                # parsing sketch after Example #1: the rightmost token is
                # the base predictor, the rest are wrappers, outermost
                # first after the reversal.
                wrappers = pred.split('_')
                pred = wrappers[-1]
                wrappers = wrappers[-2::-1]
                if weights:
                    wrapper_weights = [
                        float(w) for w in weights[idx].split('_')
                    ]
                    pred_weight = wrapper_weights[-1]
                    wrapper_weights = wrapper_weights[-2::-1]
                else:
                    wrapper_weights = [1.0] * len(wrappers)
            elif weights:
                pred_weight = float(weights[idx])

            # Create predictor instances for the string argument ``pred``
            if pred == "nmt":
                nmt_engine = _get_override_args("nmt_engine")
                if nmt_engine == 'blocks':
                    nmt_config = _parse_config_param(
                        "nmt_config", blocks_get_default_nmt_config())
                    p = blocks_get_nmt_predictor(
                        args, _get_override_args("nmt_path"), nmt_config)
                elif nmt_engine == 'tensorflow':
                    nmt_config = _parse_config_param(
                        "nmt_config", tf_get_default_nmt_config())
                    p = tf_get_nmt_predictor(
                        args, _get_override_args("nmt_path"), nmt_config)
                elif nmt_engine != 'none':
                    # Without this early exit, ``p`` would keep the value
                    # from the previous loop iteration (or be unbound).
                    logging.fatal("NMT engine %s is not supported (yet)!" %
                                  nmt_engine)
                    decoder.remove_predictors()
                    return
            elif pred == "nizza":
                p = NizzaPredictor(_get_override_args("pred_src_vocab_size"),
                                   _get_override_args("pred_trg_vocab_size"),
                                   _get_override_args("nizza_model"),
                                   _get_override_args("nizza_hparams_set"),
                                   _get_override_args("nizza_checkpoint_dir"),
                                   single_cpu_thread=args.single_cpu_thread)
            elif pred == "lexnizza":
                p = LexNizzaPredictor(
                    _get_override_args("pred_src_vocab_size"),
                    _get_override_args("pred_trg_vocab_size"),
                    _get_override_args("nizza_model"),
                    _get_override_args("nizza_hparams_set"),
                    _get_override_args("nizza_checkpoint_dir"),
                    single_cpu_thread=args.single_cpu_thread,
                    alpha=args.lexnizza_alpha,
                    beta=args.lexnizza_beta,
                    trg2src_model_name=args.lexnizza_trg2src_model,
                    trg2src_hparams_set_name=args.lexnizza_trg2src_hparams_set,
                    trg2src_checkpoint_dir=args.lexnizza_trg2src_checkpoint_dir,
                    shortlist_strategies=args.lexnizza_shortlist_strategies,
                    max_shortlist_length=args.lexnizza_max_shortlist_length,
                    min_id=args.lexnizza_min_id)
            elif pred == "t2t":
                p = T2TPredictor(_get_override_args("pred_src_vocab_size"),
                                 _get_override_args("pred_trg_vocab_size"),
                                 _get_override_args("t2t_model"),
                                 _get_override_args("t2t_problem"),
                                 _get_override_args("t2t_hparams_set"),
                                 args.t2t_usr_dir,
                                 _get_override_args("t2t_checkpoint_dir"),
                                 t2t_unk_id=_get_override_args("t2t_unk_id"),
                                 single_cpu_thread=args.single_cpu_thread,
                                 max_terminal_id=args.syntax_max_terminal_id,
                                 pop_id=args.syntax_pop_id)
            elif pred == "fertt2t":
                p = FertilityT2TPredictor(
                                 _get_override_args("pred_src_vocab_size"),
                                 _get_override_args("pred_trg_vocab_size"),
                                 _get_override_args("t2t_model"),
                                 _get_override_args("t2t_problem"),
                                 _get_override_args("t2t_hparams_set"),
                                 args.t2t_usr_dir,
                                 _get_override_args("t2t_checkpoint_dir"),
                                 single_cpu_thread=args.single_cpu_thread,
                                 max_terminal_id=args.syntax_max_terminal_id,
                                 pop_id=args.syntax_pop_id)
            elif pred == "bracket":
                p = BracketPredictor(args.syntax_max_terminal_id,
                                     args.syntax_pop_id,
                                     max_depth=args.syntax_max_depth,
                                     extlength_path=args.extlength_path)
            elif pred == "osm":
                p = OSMPredictor(args.osm_type)
            elif pred == "forcedosm":
                p = ForcedOSMPredictor(args.trg_test)
            elif pred == "fst":
                p = FstPredictor(_get_override_args("fst_path"),
                                 args.use_fst_weights,
                                 args.normalize_fst_weights,
                                 skip_bos_weight=args.fst_skip_bos_weight,
                                 to_log=args.fst_to_log)
            elif pred == "nfst":
                p = NondeterministicFstPredictor(_get_override_args("fst_path"),
                                                 args.use_fst_weights,
                                                 args.normalize_fst_weights,
                                                 args.fst_skip_bos_weight,
                                                 to_log=args.fst_to_log)
            elif pred == "forced":
                p = ForcedPredictor(args.trg_test)
            elif pred == "bow":
                p = BagOfWordsPredictor(
                                args.trg_test,
                                args.bow_accept_subsets,
                                args.bow_accept_duplicates,
                                args.heuristic_scores_file,
                                args.collect_statistics,
                                "consumed" in args.bow_heuristic_strategies,
                                "remaining" in args.bow_heuristic_strategies,
                                args.bow_diversity_heuristic_factor,
                                _get_override_args("pred_trg_vocab_size"))
            elif pred == "bowsearch":
                p = BagOfWordsSearchPredictor(
                                decoder,
                                args.hypo_recombination,
                                args.trg_test,
                                args.bow_accept_subsets,
                                args.bow_accept_duplicates,
                                args.heuristic_scores_file,
                                args.collect_statistics,
                                "consumed" in args.bow_heuristic_strategies,
                                "remaining" in args.bow_heuristic_strategies,
                                args.bow_diversity_heuristic_factor,
                                _get_override_args("pred_trg_vocab_size"))
            elif pred == "forcedlst":
                feat_name = _get_override_args("forcedlst_sparse_feat")
                p = ForcedLstPredictor(args.trg_test,
                                       args.use_nbest_weights,
                                       args.forcedlst_match_unk,
                                       feat_name if feat_name else None)
            elif pred == "rtn":
                p = RtnPredictor(args.rtn_path,
                                 args.use_rtn_weights,
                                 args.normalize_rtn_weights,
                                 to_log=args.fst_to_log,
                                 minimize_rtns=args.minimize_rtns,
                                 rmeps=args.remove_epsilon_in_rtns)
            elif pred == "srilm":
                p = SRILMPredictor(args.lm_path, 
                                   _get_override_args("ngramc_order"),
                                   args.srilm_convert_to_ln)
            elif pred == "kenlm":
                p = KenLMPredictor(args.lm_path)
            elif pred == "nplm":
                p = NPLMPredictor(args.nplm_path, args.normalize_nplm_probs)
            elif pred == "rnnlm":
                p = tf_get_rnnlm_predictor(_get_override_args("rnnlm_path"),
                                           _get_override_args("rnnlm_config"),
                                           tf_get_rnnlm_prefix())
            elif pred == "wc":
                p = WordCountPredictor(
                    args.wc_word,
                    args.wc_nonterminal_penalty,
                    args.syntax_nonterminal_ids,
                    args.syntax_min_terminal_id,
                    args.syntax_max_terminal_id,
                    _get_override_args("pred_trg_vocab_size"))
            elif pred == "ngramc":
                p = NgramCountPredictor(_get_override_args("ngramc_path"),
                                        _get_override_args("ngramc_order"),
                                        args.ngramc_discount_factor)
            elif pred == "unkc":
                p = UnkCountPredictor(
                    _get_override_args("pred_src_vocab_size"),
                    utils.split_comma(args.unk_count_lambdas, float))
            elif pred == "length":
                length_model_weights = utils.split_comma(
                    args.length_model_weights, float)
                p = NBLengthPredictor(args.src_test_raw, 
                                      length_model_weights, 
                                      args.use_length_point_probs,
                                      args.length_model_offset)
            elif pred == "extlength":
                p = ExternalLengthPredictor(args.extlength_path)
            elif pred == "lrhiero":
                fw = None
                if args.grammar_feature_weights:
                    fw = utils.split_comma(args.grammar_feature_weights, float)
                p = RuleXtractPredictor(args.rules_path,
                                        args.use_grammar_weights,
                                        fw)
            elif pred == "vanilla":
                continue
            else:
                logging.fatal("Predictor '%s' not available. Please check "
                              "--predictors for spelling errors." % pred)
                decoder.remove_predictors()
                return
            for wrapper in wrappers:
                # Embed predictor ``p`` into wrapper predictors if necessary
                # TODO: Use wrapper_weights
                if wrapper == "idxmap":
                    src_path = _get_override_args("src_idxmap")
                    trg_path = _get_override_args("trg_idxmap")
                    if isinstance(p, UnboundedVocabularyPredictor):
                        p = UnboundedIdxmapPredictor(src_path, trg_path, p,
                                                     1.0)
                    else:  # idxmap predictor for bounded predictors
                        p = IdxmapPredictor(src_path, trg_path, p, 1.0)
                elif wrapper == "maskvocab":
                    words = utils.split_comma(args.maskvocab_words, int)
                    if isinstance(p, UnboundedVocabularyPredictor): 
                        p = UnboundedMaskvocabPredictor(words, p) 
                    else: # idxmap predictor for bounded predictors
                        p = MaskvocabPredictor(words, p)
                elif wrapper == "weightnt":
                    p = WeightNonTerminalPredictor(
                        p, 
                        args.syntax_nonterminal_factor,
                        args.syntax_nonterminal_ids,
                        args.syntax_min_terminal_id,
                        args.syntax_max_terminal_id,
                        _get_override_args("pred_trg_vocab_size"))
                elif wrapper == "parse":
                    if args.parse_tok_grammar:
                        if args.parse_bpe_path:
                            p = BpeParsePredictor(
                                args.syntax_path,
                                args.syntax_bpe_path,
                                p,
                                args.syntax_word_out,
                                args.normalize_fst_weights,
                                norm_alpha=args.syntax_norm_alpha,
                                beam_size=args.syntax_internal_beam,
                                max_internal_len=args.syntax_max_internal_len,
                                allow_early_eos=args.syntax_allow_early_eos,
                                consume_out_of_class=args.syntax_consume_ooc,
                                terminal_restrict=args.syntax_terminal_restrict,
                                internal_only_restrict=args.syntax_internal_only,
                                eow_ids=args.syntax_eow_ids,
                                terminal_ids=args.syntax_terminal_ids)
                        else:
                            p = TokParsePredictor(
                                args.syntax_path,
                                p,
                                args.syntax_word_out,
                                args.normalize_fst_weights,
                                norm_alpha=args.syntax_norm_alpha,
                                beam_size=args.syntax_internal_beam,
                                max_internal_len=args.syntax_max_internal_len,
                                allow_early_eos=args.syntax_allow_early_eos,
                                consume_out_of_class=args.syntax_consume_ooc)
                    else:
                        p = ParsePredictor(
                            p,
                            args.normalize_fst_weights,
                            beam_size=args.syntax_internal_beam,
                            max_internal_len=args.syntax_max_internal_len,
                            nonterminal_ids=args.syntax_nonterminal_ids)
                elif wrapper == "altsrc":
                    src_test = _get_override_args("altsrc_test")
                    if isinstance(p, UnboundedVocabularyPredictor): 
                        p = UnboundedAltsrcPredictor(src_test, p)
                    else: # altsrc predictor for bounded predictors
                        p = AltsrcPredictor(src_test, p)
                elif wrapper == "word2char":
                    map_path = _get_override_args("word2char_map")
                    # word2char always wraps unbounded predictors
                    p = Word2charPredictor(map_path, p)
                elif wrapper == "skipvocab":
                    # skipvocab always wraps unbounded predictors
                    p = SkipvocabPredictor(args.skipvocab_max_id, 
                                           args.skipvocab_stop_size, 
                                           args.beam, 
                                           p)
                elif wrapper == "fsttok":
                    fsttok_path = _get_override_args("fsttok_path")
                    # fsttok always wraps unbounded predictors
                    p = FSTTokPredictor(fsttok_path,
                                        args.fst_unk_id,
                                        args.fsttok_max_pending_score,
                                        p)
                elif wrapper == "ngramize":
                    # ngramize always wraps bounded predictors
                    p = NgramizePredictor(args.min_ngram_order, 
                                          args.max_ngram_order,
                                          args.max_len_factor, p)
                elif wrapper == "unkvocab":
                    # unkvocab always wraps bounded predictors
                    p = UnkvocabPredictor(args.trg_vocab_size, p)
                else:
                    logging.fatal("Predictor wrapper '%s' not available. "
                                  "Please double-check --predictors for "
                                  "spelling errors." % wrapper)
                    decoder.remove_predictors()
                    return
            decoder.add_predictor(pred, p, pred_weight)
            logging.info("Initialized predictor {} (weight: {})".format(
                             pred, pred_weight))
    except IOError as e:
        logging.fatal("One of the files required for setting up the "
                      "predictors could not be read: %s" % e)
        decoder.remove_predictors()
    except AttributeError as e:
        logging.fatal("Invalid argument for one of the predictors: %s."
                       "Stack trace: %s" % (e, traceback.format_exc()))
        decoder.remove_predictors()
    except NameError as e:
        logging.fatal("Could not find external library: %s. Please make sure "
                      "that your PYTHONPATH and LD_LIBRARY_PATH contains all "
                      "paths required for the predictors. Stack trace: %s" % 
                      (e, traceback.format_exc()))
        decoder.remove_predictors()
    except ValueError as e:
        logging.fatal("A number format error occurred while configuring the "
                      "predictors: %s. Please double-check all integer- or "
                      "float-valued parameters such as --predictor_weights and"
                      " try again. Stack trace: %s" % (e, traceback.format_exc()))
        decoder.remove_predictors()
    except Exception as e:
        logging.fatal("An unexpected %s has occurred while setting up the pre"
                      "dictors: %s Stack trace: %s" % (sys.exc_info()[0],
                                                       e,
                                                       traceback.format_exc()))
        decoder.remove_predictors()
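
Compared to Example #1, this revision routes all comma-splitting through utils.split_comma, which applies an optional conversion function and, unlike "".split(","), returns an empty list for an empty string, so the "if not preds" guard can actually fire. A minimal stand-in consistent with how the helper is called above (not SGNMT's actual implementation, which may differ in details such as whitespace handling):

def split_comma(text, func=None):
    """Minimal stand-in for utils.split_comma as used above: split a
    comma-separated string, optionally mapping each element with func."""
    if not text:
        return []
    parts = [part.strip() for part in text.split(",")]
    return [func(part) for part in parts] if func else parts

# split_comma("nmt,wc")          -> ['nmt', 'wc']
# split_comma("0.5,1.0", float)  -> [0.5, 1.0]
# split_comma("")                -> []  (lets the 'if not preds' guard fire)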