Example #1
def add_heuristics(decoder):
    """Adds all enabled heuristics to the ``decoder``. This is relevant
    for heuristic-based search strategies like A*. This method relies
    on the global ``args`` variable and reads out ``args.heuristics``.

    Args:
        decoder (Decoder):  Decoding strategy, see ``create_decoder()``.
            This method will add heuristics to this instance with
            ``add_heuristic()``
    """
    if args.heuristic_predictors == 'all':
        h_predictors = decoder.predictors
    else:
        h_predictors = [
            decoder.predictors[int(idx)]
            for idx in utils.split_comma(args.heuristic_predictors)
        ]
    decoder.set_heuristic_predictors(h_predictors)
    for name in utils.split_comma(args.heuristics):
        if name == 'greedy':
            decoder.add_heuristic(
                GreedyHeuristic(args, args.cache_heuristic_estimates))
        elif name == 'predictor':
            decoder.add_heuristic(PredictorHeuristic())
        elif name == 'stats':
            decoder.add_heuristic(
                StatsHeuristic(args.heuristic_scores_file,
                               args.collect_statistics))
        elif name == 'scoreperword':
            decoder.add_heuristic(ScorePerWordHeuristic())
        elif name == 'lasttoken':
            decoder.add_heuristic(LastTokenHeuristic())
        else:
            logging.fatal("Heuristic %s not available. Please double-check "
                          "the --heuristics parameter." % name)
Example #2
def add_heuristics(decoder):
    """Adds all enabled heuristics to the ``decoder``. This is relevant
    for heuristic-based search strategies like A*. This method relies
    on the global ``args`` variable and reads out ``args.heuristics``.
    
    Args:
        decoder (Decoder):  Decoding strategy, see ``create_decoder()``.
            This method will add heuristics to this instance with
            ``add_heuristic()``
    """
    if args.heuristic_predictors == 'all':
        h_predictors = decoder.predictors
    else:
        h_predictors = [decoder.predictors[int(idx)]
                        for idx in utils.split_comma(args.heuristic_predictors)]
    decoder.set_heuristic_predictors(h_predictors)
    for name in utils.split_comma(args.heuristics):
        if name == 'greedy':
            decoder.add_heuristic(GreedyHeuristic(args,
                                                  args.cache_heuristic_estimates))
        elif name == 'predictor':
            decoder.add_heuristic(PredictorHeuristic())
        elif name == 'stats':
            decoder.add_heuristic(StatsHeuristic(args.heuristic_scores_file,
                                                 args.collect_statistics))
        elif name == 'scoreperword':
            decoder.add_heuristic(ScorePerWordHeuristic())
        elif name == 'lasttoken':
            decoder.add_heuristic(LastTokenHeuristic())
        else:
            logging.fatal("Heuristic %s not available. Please double-check "
                          "the --heuristics parameter." % name)
Example #3
def create_output_handlers():
    """Creates the output handlers defined in the ``io`` module.
    These handlers create output files in different formats from the
    decoding results.

    Args:
        args: Global command line arguments.

    Returns:
        list. List of output handlers according to --outputs
    """
    if not args.outputs:
        return []
    trg_map = {} if utils.trg_cmap else utils.trg_wmap
    outputs = []
    start_sen_id = 0
    if args.range:
        idx, _ = args.range.split(":")
        start_sen_id = int(idx) - 1  # -1 because --range indices start with 1
    for name in utils.split_comma(args.outputs):
        if '%s' in args.output_path:
            path = args.output_path % name
        else:
            path = args.output_path
        if name == "text":
            outputs.append(TextOutputHandler(path, trg_map))
        elif name == "delay":
            outputs.append(DelayOutputHandler(path))
        elif name == "nbest":
            outputs.append(
                NBestOutputHandler(path, utils.split_comma(args.predictors),
                                   start_sen_id, trg_map))
        elif name == "ngram":
            outputs.append(
                NgramOutputHandler(path, args.min_ngram_order,
                                   args.max_ngram_order, start_sen_id))
        elif name == "timecsv":
            outputs.append(
                TimeCSVOutputHandler(path, utils.split_comma(args.predictors),
                                     start_sen_id))
        elif name == "fst":
            outputs.append(
                FSTOutputHandler(path, start_sen_id, args.fst_unk_id))
        elif name == "sfst":
            outputs.append(
                StandardFSTOutputHandler(path, start_sen_id, args.fst_unk_id))
        else:
            logging.fatal("Output format %s not available. Please double-check"
                          " the --outputs parameter." % name)
    return outputs
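
Note: if ``--output_path`` contains a ``%s`` placeholder, each format listed in ``--outputs`` is written to its own file; otherwise all handlers share one path. A standalone replay of that templating (the values below are demo data, not SGNMT defaults):

# Replays the '%s' path templating from create_output_handlers().
output_path = "decoded.%s"  # e.g. --output_path decoded.%s
for name in ["text", "nbest", "ngram"]:
    path = output_path % name if "%s" in output_path else output_path
    print(name, "->", path)
# text -> decoded.text
# nbest -> decoded.nbest
# ngram -> decoded.ngram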
Example #4
def __init__(self,
             max_terminal_id,
             closing_bracket_id,
             max_depth=-1,
             extlength_path=""):
    """Creates a new bracket predictor.

    Args:
        max_terminal_id (int): All IDs greater than this are
            brackets
        closing_bracket_id (string): All brackets except these ones are
            opening. Comma-separated list of integers.
        max_depth (int): If positive, restrict the maximum depth
        extlength_path (string): If this is set, restrict the
            number of terminals to the distribution specified in
            the referenced file. Terminals can be implicit: We
            count a single terminal between each adjacent opening
            and closing bracket.
    """
    super(BracketPredictor, self).__init__()
    self.max_terminal_id = max_terminal_id
    try:
        self.closing_bracket_ids = utils.split_comma(
            closing_bracket_id, int)
    except (AttributeError, ValueError):
        # Fallback: closing_bracket_id is a single ID, not a
        # comma-separated string.
        self.closing_bracket_ids = [int(closing_bracket_id)]
    self.max_depth = max_depth if max_depth >= 0 else 1000000
    if extlength_path:
        self.length_scores = load_external_lengths(extlength_path)
    else:
        self.length_scores = None
        self.max_length = 1000000
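
Note: the ``try``/``except`` above exists because ``closing_bracket_id`` may arrive either as a comma-separated string or as a plain integer. The same fallback in isolation (``parse_bracket_ids`` is a hypothetical helper name, not SGNMT code):

def parse_bracket_ids(closing_bracket_id):
    try:
        # Works when closing_bracket_id is a string like "41,42".
        return [int(el) for el in closing_bracket_id.split(",")]
    except AttributeError:
        # A plain int has no .split(); treat it as a single ID.
        return [int(closing_bracket_id)]

print(parse_bracket_ids("41,42"))  # [41, 42]
print(parse_bracket_ids(41))       # [41]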
Example #5
def create_output_handlers():
    """Creates the output handlers defined in the ``io`` module. 
    These handlers create output files in different formats from the
    decoding results.
    
    Args:
        args: Global command line arguments.
    
    Returns:
        list. List of output handlers according to --outputs
    """
    if not args.outputs:
        return []
    trg_map = {} if utils.trg_cmap else utils.trg_wmap
    outputs = []
    for name in utils.split_comma(args.outputs):
        if '%s' in args.output_path:
            path = args.output_path % name
        else:
            path = args.output_path
        if name == "text":
            outputs.append(TextOutputHandler(path, trg_map))
        elif name == "nbest":
            outputs.append(NBestOutputHandler(path, 
                                              utils.split_comma(args.predictors),
                                              trg_map))
        elif name == "ngram":
            outputs.append(NgramOutputHandler(path,
                                              args.min_ngram_order,
                                              args.max_ngram_order))
        elif name == "timecsv":
            outputs.append(TimeCSVOutputHandler(path, 
                                                utils.split_comma(args.predictors)))
        elif name == "fst":
            outputs.append(FSTOutputHandler(path,
                                            args.fst_unk_id))
        elif name == "sfst":
            outputs.append(StandardFSTOutputHandler(path,
                                                    args.fst_unk_id))
        else:
            logging.fatal("Output format %s not available. Please double-check"
                          " the --outputs parameter." % name)
    return outputs
Example #6
def create_output_handlers():
    """Creates the output handlers defined in the ``io`` module. 
    These handlers create output files in different formats from the
    decoding results.
    
    Args:
        args: Global command line arguments.
    
    Returns:
        list. List of output handlers according to --outputs
    """
    if not args.outputs:
        return []
    outputs = []
    for name in utils.split_comma(args.outputs):
        if '%s' in args.output_path:
            path = args.output_path % name
        else:
            path = args.output_path
        if name == "text":
            outputs.append(TextOutputHandler(path))
        elif name == "nbest":
            outputs.append(NBestOutputHandler(path, 
                                              utils.split_comma(args.predictors)))
        elif name == "ngram":
            outputs.append(NgramOutputHandler(path,
                                              args.min_ngram_order,
                                              args.max_ngram_order))
        elif name == "timecsv":
            outputs.append(TimeCSVOutputHandler(path, 
                                                utils.split_comma(args.predictors)))
        elif name == "fst":
            outputs.append(FSTOutputHandler(path,
                                            args.fst_unk_id))
        elif name == "sfst":
            outputs.append(StandardFSTOutputHandler(path,
                                                    args.fst_unk_id))
        else:
            logging.fatal("Output format %s not available. Please double-check"
                          " the --outputs parameter." % name)
    return outputs
Example #7
def __init__(self, decoder_args):
    """Creates a new beam decoder with cumulative predictor score
    limits. In addition to the constructor of `BeamDecoder`, the
    following values are fetched from `decoder_args`:

        pred_limits (string): Comma-separated list of predictor
                              score limits.
    """
    super(PredLimitBeamDecoder, self).__init__(decoder_args)
    self.pred_limits = []
    for l in utils.split_comma(decoder_args.pred_limits):
        try:
            self.pred_limits.append(float(l))
        except ValueError:
            # Unparsable entries default to no limit.
            self.pred_limits.append(utils.NEG_INF)
    logging.info("Cumulative predictor score limits: %s" %
                 self.pred_limits)
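
Note: each comma-separated entry is parsed as a float, and entries that do not parse fall back to ``utils.NEG_INF``, i.e. no limit for that predictor. The parsing in isolation (``NEG_INF`` here is a stand-in, assuming utils.NEG_INF is negative infinity):

NEG_INF = float("-inf")  # stand-in for utils.NEG_INF
pred_limits = []
for l in "0.5,,-1.2".split(","):
    try:
        pred_limits.append(float(l))
    except ValueError:
        pred_limits.append(NEG_INF)  # empty/unparsable entry: no limit
print(pred_limits)  # [0.5, -inf, -1.2]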
Example #8
def get_domain_task_weights(w):
    """Gets an array of domain-task weights from the string ``w``.

    Returns:
        None if ``w`` is None or contains a non-square number of
        weights (currently invalid); otherwise a 2D numpy float
        array of weights.
    """
    if w is None:
        logging.critical(
            'Need bayesian_domain_task_weights for state-dependent BI')
    else:
        domain_weights = utils.split_comma(w, float)
        num_domains = int(len(domain_weights)**0.5)
        if len(domain_weights) == num_domains**2:
            weights_array = np.reshape(domain_weights,
                                       (num_domains, num_domains))
            logging.info('Using {} for Bayesian Interpolation'.format(
                weights_array))
            return weights_array
        else:
            logging.critical(
                'Need square number of domain-task weights, have {}'.format(
                    len(domain_weights)))
Example #9
def get_domain_task_weights(w):
    """Gets an array of domain-task weights from the string ``w``.

    Returns:
        None if ``w`` is None or contains a non-square number of
        weights (currently invalid); otherwise a 2D numpy float
        array of weights.
    """
    if w is None:
        logging.critical(
            'Need bayesian_domain_task_weights for state-dependent BI')
    else:
        domain_weights = utils.split_comma(w, float)
        num_domains = int(len(domain_weights) ** 0.5)
        if len(domain_weights) == num_domains ** 2:
            weights_array = np.reshape(domain_weights,
                                       (num_domains, num_domains))
            logging.info('Using {} for Bayesian Interpolation'.format(
                weights_array))
            return weights_array
        else:
            logging.critical(
                'Need square number of domain-task weights, have {}'.format(
                    len(domain_weights)))
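
Note: both versions reshape a flat list of n*n weights into an n-by-n matrix, which is why the weight string must contain a square number of entries. With four demo weights:

import numpy as np

domain_weights = [0.9, 0.1, 0.2, 0.8]           # demo values
num_domains = int(len(domain_weights) ** 0.5)   # 2
assert len(domain_weights) == num_domains ** 2  # square number check
print(np.reshape(domain_weights, (num_domains, num_domains)))
# [[0.9 0.1]
#  [0.2 0.8]]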
Example #10
    def __init__(self,
                 src_vocab_size,
                 trg_vocab_size,
                 model_name,
                 hparams_set_name,
                 checkpoint_dir,
                 single_cpu_thread,
                 alpha,
                 beta,
                 shortlist_strategies,
                 trg2src_model_name="",
                 trg2src_hparams_set_name="",
                 trg2src_checkpoint_dir="",
                 max_shortlist_length=0,
                 min_id=0,
                 nizza_unk_id=None):
        """Initializes a nizza predictor.

        Args:
            src_vocab_size (int): Source vocabulary size (called inputs_vocab_size
                in nizza)
            trg_vocab_size (int): Target vocabulary size (called targets_vocab_size
                in nizza)
            model_name (string): Name of the nizza model
            hparams_set_name (string): Name of the nizza hyper-parameter set
            checkpoint_dir (string): Path to the Nizza checkpoint directory. The
                                     predictor will load the topmost checkpoint in
                                     the `checkpoints` file.
            single_cpu_thread (bool): If true, prevent tensorflow from
                                      doing multithreading.
            alpha (float): Score for each matching word
            beta (float): Penalty for each uncovered word at the end
            shortlist_strategies (string): Comma-separated list of shortlist
                strategies.
            trg2src_model_name (string): Name of the target2source nizza model
            trg2src_hparams_set_name (string): Name of the nizza hyper-parameter set
                                     for the target2source model
            trg2src_checkpoint_dir (string): Path to the Nizza checkpoint directory
                                     for the target2source model. The
                                     predictor will load the topmost checkpoint in
                                     the `checkpoints` file.
            max_shortlist_length (int): If a shortlist exceeds this limit,
                initialize the initial coverage with 1 at this position. If
                zero, do not apply any limit
            min_id (int): Do not use IDs below this threshold (filters out most
                frequent words).
            nizza_unk_id (int): If set, use this as UNK id. Otherwise, the
                nizza model is assumed to have no UNKs

        Raises:
            IOError if checkpoint file not found.
        """
        super(LexNizzaPredictor, self).__init__(src_vocab_size,
                                                trg_vocab_size,
                                                model_name,
                                                hparams_set_name,
                                                checkpoint_dir,
                                                single_cpu_thread,
                                                nizza_unk_id=nizza_unk_id)
        self.alpha = alpha
        self.alpha_is_zero = alpha == 0.0
        self.beta = beta
        self.shortlist_strategies = utils.split_comma(shortlist_strategies)
        self.max_shortlist_length = max_shortlist_length
        self.min_id = min_id
        if trg2src_checkpoint_dir:
            self.use_trg2src = True
            predictor_graph = tf.Graph()
            with predictor_graph.as_default() as g:
                hparams = registry.get_registered_hparams_set(
                    trg2src_hparams_set_name)
                hparams.add_hparam("inputs_vocab_size", trg_vocab_size)
                hparams.add_hparam("targets_vocab_size", src_vocab_size)
                run_config = tf.contrib.learn.RunConfig()
                run_config = run_config.replace(
                    model_dir=trg2src_checkpoint_dir)
                model = registry.get_registered_model(trg2src_model_name,
                                                      hparams, run_config)
                features = {
                    "inputs": tf.expand_dims(tf.range(trg_vocab_size), 0)
                }
                mode = tf.estimator.ModeKeys.PREDICT
                trg2src_lex_logits = model.precompute(features, mode, hparams)
                # Precompute trg2src partitions
                partitions = tf.reduce_logsumexp(trg2src_lex_logits, axis=-1)
                self._trg2src_src_words_var = tf.placeholder(
                    dtype=tf.int32,
                    shape=[None],
                    name="sgnmt_trg2src_src_words")
                # trg2src_lex_logits has shape [1, trg_vocab_size, src_vocab_size]
                self.trg2src_logits = tf.gather(
                    tf.transpose(trg2src_lex_logits[0, :, :]),
                    self._trg2src_src_words_var)
                # trg2src_logits has shape [len(src_words), trg_vocab_size]
                self.trg2src_mon_sess = self.create_session(
                    trg2src_checkpoint_dir)
                logging.debug("Precomputing lexnizza trg2src partitions...")
                self.trg2src_partitions = self.trg2src_mon_sess.run(partitions)
        else:
            self.use_trg2src = False
            logging.warning("No target-to-source model specified for lexnizza.")
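
Note: the trg2src branch gathers one row per source word from the transposed lexical logits, exactly as the inline shape comments describe. The same shape bookkeeping replayed with numpy instead of TensorFlow (vocabulary sizes are demo values):

import numpy as np

trg_vocab_size, src_vocab_size = 5, 3
lex_logits = np.zeros((1, trg_vocab_size, src_vocab_size))
src_words = np.array([0, 2])         # IDs of the source words
logits = lex_logits[0].T[src_words]  # gather rows after transposing
print(logits.shape)                  # (2, 5) == (len(src_words), trg_vocab_size)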
Example #11
def add_predictors(decoder):
    """Adds all enabled predictors to the ``decoder``. This function
    makes heavy use of the global ``args`` which contains the
    SGNMT configuration. Particularly, it reads out ``args.predictors``
    and adds appropriate instances to ``decoder``.
    TODO: Refactor this method as it is waaaay tooooo looong

    Args:
        decoder (Decoder):  Decoding strategy, see ``create_decoder()``.
            This method will add predictors to this instance with
            ``add_predictor()``
    """
    preds = utils.split_comma(args.predictors)
    if not preds:
        logging.fatal("Require at least one predictor! See the --predictors "
                      "argument for more information.")
    weights = None
    if args.predictor_weights:
        weights = utils.split_comma(args.predictor_weights)
        if len(preds) != len(weights):
            logging.fatal("Specified %d predictors, but %d weights. Please "
                          "revise the --predictors and --predictor_weights "
                          "arguments" % (len(preds), len(weights)))
            return

    pred_weight = 1.0
    try:
        for idx, pred in enumerate(preds):  # Add predictors one by one
            wrappers = []
            if '_' in pred:
                # Handle weights when we have wrapper predictors
                wrappers = pred.split('_')
                pred = wrappers[-1]
                wrappers = wrappers[-2::-1]
                if weights:
                    wrapper_weights = [
                        float(w) for w in weights[idx].split('_')
                    ]
                    pred_weight = wrapper_weights[-1]
                    wrapper_weights = wrapper_weights[-2::-1]
                else:
                    wrapper_weights = [1.0] * len(wrappers)
            elif weights:
                pred_weight = float(weights[idx])

            # Create predictor instances for the string argument ``pred``
            if pred == "nmt":
                nmt_engine = _get_override_args("nmt_engine")
                if nmt_engine == 'blocks':
                    nmt_config = _parse_config_param(
                        "nmt_config", blocks_get_default_nmt_config())
                    p = blocks_get_nmt_predictor(
                        args, _get_override_args("nmt_path"), nmt_config)
                elif nmt_engine == 'tensorflow':
                    nmt_config = _parse_config_param(
                        "nmt_config", tf_get_default_nmt_config())
                    p = tf_get_nmt_predictor(args,
                                             _get_override_args("nmt_path"),
                                             nmt_config)
                elif nmt_engine != 'none':
                    logging.fatal("NMT engine %s is not supported (yet)!" %
                                  nmt_engine)
            elif pred == "nizza":
                p = NizzaPredictor(_get_override_args("pred_src_vocab_size"),
                                   _get_override_args("pred_trg_vocab_size"),
                                   _get_override_args("nizza_model"),
                                   _get_override_args("nizza_hparams_set"),
                                   _get_override_args("nizza_checkpoint_dir"),
                                   single_cpu_thread=args.single_cpu_thread)
            elif pred == "lexnizza":
                p = LexNizzaPredictor(
                    _get_override_args("pred_src_vocab_size"),
                    _get_override_args("pred_trg_vocab_size"),
                    _get_override_args("nizza_model"),
                    _get_override_args("nizza_hparams_set"),
                    _get_override_args("nizza_checkpoint_dir"),
                    single_cpu_thread=args.single_cpu_thread,
                    alpha=args.lexnizza_alpha,
                    beta=args.lexnizza_beta,
                    trg2src_model_name=args.lexnizza_trg2src_model,
                    trg2src_hparams_set_name=args.lexnizza_trg2src_hparams_set,
                    trg2src_checkpoint_dir=args.lexnizza_trg2src_checkpoint_dir,
                    shortlist_strategies=args.lexnizza_shortlist_strategies,
                    max_shortlist_length=args.lexnizza_max_shortlist_length,
                    min_id=args.lexnizza_min_id)
            elif pred == "t2t":
                p = T2TPredictor(_get_override_args("pred_src_vocab_size"),
                                 _get_override_args("pred_trg_vocab_size"),
                                 _get_override_args("t2t_model"),
                                 _get_override_args("t2t_problem"),
                                 _get_override_args("t2t_hparams_set"),
                                 args.t2t_usr_dir,
                                 _get_override_args("t2t_checkpoint_dir"),
                                 single_cpu_thread=args.single_cpu_thread,
                                 max_terminal_id=args.syntax_max_terminal_id,
                                 pop_id=args.syntax_pop_id)
            elif pred == "simt2t":
                p = SimT2TPredictor(
                    _get_override_args("pred_src_vocab_size"),
                    _get_override_args("pred_trg_vocab_size"),
                    _get_override_args("t2t_model"),
                    _get_override_args("t2t_problem"),
                    _get_override_args("t2t_hparams_set"),
                    args.t2t_usr_dir,
                    _get_override_args("t2t_checkpoint_dir"),
                    single_cpu_thread=args.single_cpu_thread,
                    max_terminal_id=args.syntax_max_terminal_id,
                    pop_id=args.syntax_pop_id)
            elif pred == "simt2tv2":
                p = SimT2TPredictor_v2(
                    _get_override_args("pred_src_vocab_size"),
                    _get_override_args("pred_trg_vocab_size"),
                    _get_override_args("t2t_model"),
                    _get_override_args("t2t_problem"),
                    _get_override_args("t2t_hparams_set"),
                    args.t2t_usr_dir,
                    _get_override_args("t2t_checkpoint_dir"),
                    single_cpu_thread=args.single_cpu_thread,
                    max_terminal_id=args.syntax_max_terminal_id,
                    pop_id=args.syntax_pop_id)
            elif pred == "bracket":
                p = BracketPredictor(args.syntax_max_terminal_id,
                                     args.syntax_pop_id,
                                     max_depth=args.syntax_max_depth,
                                     extlength_path=args.extlength_path)
            elif pred == "osm":
                p = OSMPredictor()
            elif pred == "forcedosm":
                p = ForcedOSMPredictor(args.trg_test)
            elif pred == "fst":
                p = FstPredictor(_get_override_args("fst_path"),
                                 args.use_fst_weights,
                                 args.normalize_fst_weights,
                                 skip_bos_weight=args.fst_skip_bos_weight,
                                 to_log=args.fst_to_log)
            elif pred == "nfst":
                p = NondeterministicFstPredictor(
                    _get_override_args("fst_path"),
                    args.use_fst_weights,
                    args.normalize_fst_weights,
                    args.fst_skip_bos_weight,
                    to_log=args.fst_to_log)
            elif pred == "forced":
                p = ForcedPredictor(args.trg_test)
            elif pred == "bow":
                p = BagOfWordsPredictor(
                    args.trg_test,
                    args.bow_accept_subsets,
                    args.bow_accept_duplicates,
                    args.heuristic_scores_file,
                    args.collect_statistics,
                    "consumed" in args.bow_heuristic_strategies,
                    "remaining" in args.bow_heuristic_strategies,
                    args.bow_diversity_heuristic_factor,
                    _get_override_args("pred_trg_vocab_size"))
            elif pred == "bowsearch":
                p = BagOfWordsSearchPredictor(
                    decoder,
                    args.hypo_recombination,
                    args.trg_test,
                    args.bow_accept_subsets,
                    args.bow_accept_duplicates,
                    args.heuristic_scores_file,
                    args.collect_statistics,
                    "consumed" in args.bow_heuristic_strategies,
                    "remaining" in args.bow_heuristic_strategies,
                    args.bow_diversity_heuristic_factor,
                    _get_override_args("pred_trg_vocab_size"))
            elif pred == "forcedlst":
                feat_name = _get_override_args("forcedlst_sparse_feat")
                p = ForcedLstPredictor(args.trg_test, args.use_nbest_weights,
                                       feat_name if feat_name else None)
            elif pred == "rtn":
                p = RtnPredictor(args.rtn_path,
                                 args.use_rtn_weights,
                                 args.normalize_rtn_weights,
                                 to_log=args.fst_to_log,
                                 minimize_rtns=args.minimize_rtns,
                                 rmeps=args.remove_epsilon_in_rtns)
            elif pred == "srilm":
                p = SRILMPredictor(args.srilm_path, args.srilm_order,
                                   args.srilm_convert_to_ln)
            elif pred == "nplm":
                p = NPLMPredictor(args.nplm_path, args.normalize_nplm_probs)
            elif pred == "rnnlm":
                p = tf_get_rnnlm_predictor(_get_override_args("rnnlm_path"),
                                           _get_override_args("rnnlm_config"),
                                           tf_get_rnnlm_prefix())
            elif pred == "wc":
                p = WordCountPredictor(args.wc_word)
            elif pred == "ngramc":
                p = NgramCountPredictor(_get_override_args("ngramc_path"),
                                        _get_override_args("ngramc_order"),
                                        args.ngramc_discount_factor)
            elif pred == "unkc":
                p = UnkCountPredictor(
                    _get_override_args("pred_src_vocab_size"),
                    [float(l) for l in args.unk_count_lambdas.split(',')])
            elif pred == "length":
                length_model_weights = [
                    float(w) for w in args.length_model_weights.split(',')
                ]
                p = NBLengthPredictor(args.src_test_raw, length_model_weights,
                                      args.use_length_point_probs,
                                      args.length_model_offset)
            elif pred == "extlength":
                p = ExternalLengthPredictor(args.extlength_path)
            elif pred == "lrhiero":
                fw = None
                if args.grammar_feature_weights:
                    fw = [
                        float(w)
                        for w in args.grammar_feature_weights.split(',')
                    ]
                p = RuleXtractPredictor(args.rules_path,
                                        args.use_grammar_weights, fw)
            elif pred == "vanilla":
                continue
            else:
                logging.fatal("Predictor '%s' not available. Please check "
                              "--predictors for spelling errors." % pred)
                decoder.remove_predictors()
                return
            for wrapper in wrappers:
                # Embed predictor ``p`` into wrapper predictors if necessary
                # TODO: Use wrapper_weights
                if wrapper == "idxmap":
                    src_path = _get_override_args("src_idxmap")
                    trg_path = _get_override_args("trg_idxmap")
                    if isinstance(p, UnboundedVocabularyPredictor):
                        p = UnboundedIdxmapPredictor(src_path, trg_path, p,
                                                     1.0)
                    else:  # idxmap predictor for bounded predictors
                        p = IdxmapPredictor(src_path, trg_path, p, 1.0)
                elif wrapper == "altsrc":
                    src_test = _get_override_args("altsrc_test")
                    if isinstance(p, UnboundedVocabularyPredictor):
                        p = UnboundedAltsrcPredictor(src_test, p)
                    else:  # altsrc predictor for bounded predictors
                        p = AltsrcPredictor(src_test, p)
                elif wrapper == "word2char":
                    map_path = _get_override_args("word2char_map")
                    # word2char always wraps unbounded predictors
                    p = Word2charPredictor(map_path, p)
                elif wrapper == "skipvocab":
                    # skipvocab always wraps unbounded predictors
                    p = SkipvocabPredictor(args.skipvocab_max_id,
                                           args.skipvocab_stop_size, args.beam,
                                           p)
                elif wrapper == "fsttok":
                    fsttok_path = _get_override_args("fsttok_path")
                    # fsttok always wraps unbounded predictors
                    p = FSTTokPredictor(fsttok_path, args.fst_unk_id,
                                        args.fsttok_max_pending_score, p)
                elif wrapper == "ngramize":
                    # ngramize always wraps bounded predictors
                    p = NgramizePredictor(args.min_ngram_order,
                                          args.max_ngram_order,
                                          args.max_len_factor, p)
                elif wrapper == "unkvocab":
                    # unkvocab always wraps bounded predictors
                    p = UnkvocabPredictor(args.trg_vocab_size, p)
                else:
                    logging.fatal("Predictor wrapper '%s' not available. "
                                  "Please double-check --predictors for "
                                  "spelling errors." % wrapper)
                    decoder.remove_predictors()
                    return
            decoder.add_predictor(pred, p, pred_weight)
            logging.info("Initialized predictor {} (weight: {})".format(
                pred, pred_weight))
    except IOError as e:
        logging.fatal("One of the files required for setting up the "
                      "predictors could not be read: %s" % e)
        decoder.remove_predictors()
    except AttributeError as e:
        logging.fatal("Invalid argument for one of the predictors: %s" % e)
        decoder.remove_predictors()
    except NameError as e:
        logging.fatal("Could not find external library: %s. Please make sure "
                      "that your PYTHONPATH and LD_LIBRARY_PATH contains all "
                      "paths required for the predictors. Stack trace: %s" %
                      (e, traceback.format_exc()))
        decoder.remove_predictors()
    except ValueError as e:
        logging.fatal("A number format error occurred while configuring the "
                      "predictors: %s. Please double-check all integer- or "
                      "float-valued parameters such as --predictor_weights and"
                      " try again. Stack trace: %s" %
                      (e, traceback.format_exc()))
        decoder.remove_predictors()
    except Exception as e:
        logging.fatal("An unexpected %s has occurred while setting up the pre"
                      "dictors: %s Stack trace: %s" %
                      (sys.exc_info()[0], e, traceback.format_exc()))
        decoder.remove_predictors()
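
Note: the underscore handling at the top of the loop is easy to miss: the last underscore-separated token is the base predictor, and everything before it is a wrapper chain applied inside-out. Traced by hand for a hypothetical spec string:

pred = "idxmap_word2char_nmt"  # hypothetical --predictors entry
wrappers = pred.split("_")     # ['idxmap', 'word2char', 'nmt']
pred = wrappers[-1]            # base predictor: 'nmt'
wrappers = wrappers[-2::-1]    # ['word2char', 'idxmap'], innermost first
print(pred, wrappers)
# The wrapper loop then wraps p in word2char first, then idxmap, so
# the leftmost name in the spec ends up as the outermost wrapper.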
Example #12
def add_predictors(decoder):
    """Adds all enabled predictors to the ``decoder``. This function 
    makes heavy use of the global ``args`` which contains the
    SGNMT configuration. Particularly, it reads out ``args.predictors``
    and adds appropriate instances to ``decoder``.
    TODO: Refactor this method as it is waaaay tooooo looong
    
    Args:
        decoder (Decoder):  Decoding strategy, see ``create_decoder()``.
            This method will add predictors to this instance with
            ``add_predictor()``
    """
    preds = utils.split_comma(args.predictors)
    if not preds:
        logging.fatal("Require at least one predictor! See the --predictors "
                      "argument for more information.")
    weights = None
    if args.predictor_weights:
        weights = utils.split_comma(args.predictor_weights)
        if len(preds) != len(weights):
            logging.fatal("Specified %d predictors, but %d weights. Please "
                      "revise the --predictors and --predictor_weights "
                      "arguments" % (len(preds), len(weights)))
            return
    
    pred_weight = 1.0
    try:
        for idx, pred in enumerate(preds): # Add predictors one by one
            wrappers = []
            if '_' in pred: 
                # Handle weights when we have wrapper predictors
                wrappers = pred.split('_')
                pred = wrappers[-1]
                wrappers = wrappers[-2::-1]
                if weights:
                    wrapper_weights = [float(w) for w in weights[idx].split('_')]
                    pred_weight = wrapper_weights[-1]
                    wrapper_weights = wrapper_weights[-2::-1]
                else:
                    wrapper_weights = [1.0] * len(wrappers)
            elif weights:
                pred_weight = float(weights[idx])

            # Create predictor instances for the string argument ``pred``
            if pred == "nmt":
                # TODO: Clean this up and make a blocks and tfnmt predictor
                nmt_engine = _get_override_args("nmt_engine")
                if nmt_engine == 'blocks':
                    nmt_config = _parse_config_param(
                        "nmt_config", blocks_get_default_nmt_config())
                    p = blocks_get_nmt_predictor(
                        args, _get_override_args("nmt_path"), nmt_config)
                elif nmt_engine == 'tensorflow':
                    nmt_config = _parse_config_param(
                        "nmt_config", tf_get_default_nmt_config())
                    p = tf_get_nmt_predictor(
                        args, _get_override_args("nmt_path"), nmt_config)
                elif nmt_engine != 'none':
                    logging.fatal("NMT engine %s is not supported (yet)!" % nmt_engine)
            elif pred == "nizza":
                p = NizzaPredictor(_get_override_args("pred_src_vocab_size"),
                                   _get_override_args("pred_trg_vocab_size"),
                                   _get_override_args("nizza_model"),
                                   _get_override_args("nizza_hparams_set"),
                                   _get_override_args("nizza_checkpoint_dir"),
                                   n_cpu_threads=args.n_cpu_threads)
            elif pred == "lexnizza":
                p = LexNizzaPredictor(
                    _get_override_args("pred_src_vocab_size"),
                    _get_override_args("pred_trg_vocab_size"),
                    _get_override_args("nizza_model"),
                    _get_override_args("nizza_hparams_set"),
                    _get_override_args("nizza_checkpoint_dir"),
                    n_cpu_threads=args.n_cpu_threads,
                    alpha=args.lexnizza_alpha,
                    beta=args.lexnizza_beta,
                    trg2src_model_name=args.lexnizza_trg2src_model,
                    trg2src_hparams_set_name=args.lexnizza_trg2src_hparams_set,
                    trg2src_checkpoint_dir=args.lexnizza_trg2src_checkpoint_dir,
                    shortlist_strategies=args.lexnizza_shortlist_strategies,
                    max_shortlist_length=args.lexnizza_max_shortlist_length,
                    min_id=args.lexnizza_min_id)
            elif pred == "t2t":
                p = T2TPredictor(_get_override_args("pred_src_vocab_size"),
                                 _get_override_args("pred_trg_vocab_size"),
                                 _get_override_args("t2t_model"),
                                 _get_override_args("t2t_problem"),
                                 _get_override_args("t2t_hparams_set"),
                                 args.t2t_usr_dir,
                                 _get_override_args("t2t_checkpoint_dir"),
                                 t2t_unk_id=_get_override_args("t2t_unk_id"),
                                 n_cpu_threads=args.n_cpu_threads,
                                 max_terminal_id=args.syntax_max_terminal_id,
                                 pop_id=args.syntax_pop_id)
            elif pred == "segt2t":
                p = SegT2TPredictor(_get_override_args("pred_src_vocab_size"),
                                    _get_override_args("pred_trg_vocab_size"),
                                    _get_override_args("t2t_model"),
                                    _get_override_args("t2t_problem"),
                                    _get_override_args("t2t_hparams_set"),
                                    args.t2t_usr_dir,
                                    _get_override_args("t2t_checkpoint_dir"),
                                    t2t_unk_id=_get_override_args("t2t_unk_id"),
                                    n_cpu_threads=args.n_cpu_threads,
                                    max_terminal_id=args.syntax_max_terminal_id,
                                    pop_id=args.syntax_pop_id)
            elif pred == "editt2t":
                p = EditT2TPredictor(_get_override_args("pred_src_vocab_size"),
                                     _get_override_args("pred_trg_vocab_size"),
                                     _get_override_args("t2t_model"),
                                     _get_override_args("t2t_problem"),
                                     _get_override_args("t2t_hparams_set"),
                                     args.trg_test,
                                     args.beam,
                                     args.t2t_usr_dir,
                                     _get_override_args("t2t_checkpoint_dir"),
                                     t2t_unk_id=_get_override_args("t2t_unk_id"),
                                     n_cpu_threads=args.n_cpu_threads,
                                     max_terminal_id=args.syntax_max_terminal_id,
                                     pop_id=args.syntax_pop_id)
            elif pred == "fertt2t":
                p = FertilityT2TPredictor(
                    _get_override_args("pred_src_vocab_size"),
                    _get_override_args("pred_trg_vocab_size"),
                    _get_override_args("t2t_model"),
                    _get_override_args("t2t_problem"),
                    _get_override_args("t2t_hparams_set"),
                    args.t2t_usr_dir,
                    _get_override_args("t2t_checkpoint_dir"),
                    n_cpu_threads=args.n_cpu_threads,
                    max_terminal_id=args.syntax_max_terminal_id,
                    pop_id=args.syntax_pop_id)
            elif pred == "bracket":
                p = BracketPredictor(args.syntax_max_terminal_id,
                                     args.syntax_pop_id,
                                     max_depth=args.syntax_max_depth,
                                     extlength_path=args.extlength_path)
            elif pred == "osm":
                p = OSMPredictor(args.osm_type)
            elif pred == "forcedosm":
                p = ForcedOSMPredictor(args.trg_test)
            elif pred == "fst":
                p = FstPredictor(_get_override_args("fst_path"),
                                 args.use_fst_weights,
                                 args.normalize_fst_weights,
                                 skip_bos_weight=args.fst_skip_bos_weight,
                                 to_log=args.fst_to_log)
            elif pred == "nfst":
                p = NondeterministicFstPredictor(_get_override_args("fst_path"),
                                                 args.use_fst_weights,
                                                 args.normalize_fst_weights,
                                                 args.fst_skip_bos_weight,
                                                 to_log=args.fst_to_log)
            elif pred == "forced":
                p = ForcedPredictor(args.trg_test)
            elif pred == "bow":
                p = BagOfWordsPredictor(
                                args.trg_test,
                                args.bow_accept_subsets,
                                args.bow_accept_duplicates,
                                args.heuristic_scores_file,
                                args.collect_statistics,
                                "consumed" in args.bow_heuristic_strategies,
                                "remaining" in args.bow_heuristic_strategies,
                                args.bow_diversity_heuristic_factor,
                                _get_override_args("pred_trg_vocab_size"))
            elif pred == "bowsearch":
                p = BagOfWordsSearchPredictor(
                                decoder,
                                args.hypo_recombination,
                                args.trg_test,
                                args.bow_accept_subsets,
                                args.bow_accept_duplicates,
                                args.heuristic_scores_file,
                                args.collect_statistics,
                                "consumed" in args.bow_heuristic_strategies,
                                "remaining" in args.bow_heuristic_strategies,
                                args.bow_diversity_heuristic_factor,
                                _get_override_args("pred_trg_vocab_size"))
            elif pred == "forcedlst":
                feat_name = _get_override_args("forcedlst_sparse_feat")
                p = ForcedLstPredictor(args.trg_test,
                                       args.use_nbest_weights,
                                       args.forcedlst_match_unk,
                                       feat_name if feat_name else None)
            elif pred == "rtn":
                p = RtnPredictor(args.rtn_path,
                                 args.use_rtn_weights,
                                 args.normalize_rtn_weights,
                                 to_log=args.fst_to_log,
                                 minimize_rtns=args.minimize_rtns,
                                 rmeps=args.remove_epsilon_in_rtns)
            elif pred == "srilm":
                p = SRILMPredictor(args.lm_path, 
                                   _get_override_args("ngramc_order"),
                                   args.srilm_convert_to_ln)
            elif pred == "kenlm":
                p = KenLMPredictor(args.lm_path)
            elif pred == "nplm":
                p = NPLMPredictor(args.nplm_path, args.normalize_nplm_probs)
            elif pred == "rnnlm":
                p = tf_get_rnnlm_predictor(_get_override_args("rnnlm_path"),
                                           _get_override_args("rnnlm_config"),
                                           tf_get_rnnlm_prefix())
            elif pred == "wc":
                p = WordCountPredictor(args.wc_word,
                                       args.wc_nonterminal_penalty,
                                       args.syntax_nonterminal_ids,
                                       args.syntax_min_terminal_id,
                                       args.syntax_max_terminal_id,
                                       args.negative_wc,
                                       _get_override_args("pred_trg_vocab_size"))
            elif pred == "ngramc":
                p = NgramCountPredictor(_get_override_args("ngramc_path"),
                                        _get_override_args("ngramc_order"),
                                        args.ngramc_discount_factor)
            elif pred == "unkc":
                p = UnkCountPredictor(
                     _get_override_args("pred_src_vocab_size"), 
                     utils.split_comma(args.unk_count_lambdas, float))
            elif pred == "length":
                length_model_weights = utils.split_comma(
                    args.length_model_weights, float)
                p = NBLengthPredictor(args.src_test_raw, 
                                      length_model_weights, 
                                      args.use_length_point_probs,
                                      args.length_model_offset)
            elif pred == "extlength":
                p = ExternalLengthPredictor(args.extlength_path)
            elif pred == "lrhiero":
                fw = None
                if args.grammar_feature_weights:
                    fw = utils.split_comma(args.grammar_feature_weights, float)
                p = RuleXtractPredictor(args.rules_path,
                                        args.use_grammar_weights,
                                        fw)
            elif pred == "vanilla":
                continue
            else:
                logging.fatal("Predictor '%s' not available. Please check "
                              "--predictors for spelling errors." % pred)
                decoder.remove_predictors()
                return
            for wrapper in wrappers:
                # Embed predictor ``p`` into wrapper predictors if necessary
                # TODO: Use wrapper_weights
                if wrapper == "idxmap":
                    src_path = _get_override_args("src_idxmap")
                    trg_path = _get_override_args("trg_idxmap")
                    if isinstance(p, UnboundedVocabularyPredictor): 
                        p = UnboundedIdxmapPredictor(src_path, trg_path, p, 1.0) 
                    else: # idxmap predictor for bounded predictors
                        p = IdxmapPredictor(src_path, trg_path, p, 1.0)
                elif wrapper == "maskvocab":
                    words = utils.split_comma(args.maskvocab_words, int)
                    if isinstance(p, UnboundedVocabularyPredictor): 
                        p = UnboundedMaskvocabPredictor(words, p) 
                    else: # maskvocab predictor for bounded predictors
                        p = MaskvocabPredictor(words, p)
                elif wrapper == "weightnt":
                    p = WeightNonTerminalPredictor(
                        p, 
                        args.syntax_nonterminal_factor,
                        args.syntax_nonterminal_ids,
                        args.syntax_min_terminal_id,
                        args.syntax_max_terminal_id,
                        _get_override_args("pred_trg_vocab_size"))
                elif wrapper == "parse":
                    if args.parse_tok_grammar:
                        if args.parse_bpe_path:
                            p = BpeParsePredictor(
                                args.syntax_path,
                                args.syntax_bpe_path,
                                p,
                                args.syntax_word_out,
                                args.normalize_fst_weights,
                                norm_alpha=args.syntax_norm_alpha,
                                beam_size=args.syntax_internal_beam,
                                max_internal_len=args.syntax_max_internal_len,
                                allow_early_eos=args.syntax_allow_early_eos,
                                consume_out_of_class=args.syntax_consume_ooc,
                                terminal_restrict=args.syntax_terminal_restrict,
                                internal_only_restrict=args.syntax_internal_only,
                                eow_ids=args.syntax_eow_ids,
                                terminal_ids=args.syntax_terminal_ids)
                        else:
                            p = TokParsePredictor(
                                args.syntax_path,
                                p,
                                args.syntax_word_out,
                                args.normalize_fst_weights,
                                norm_alpha=args.syntax_norm_alpha,
                                beam_size=args.syntax_internal_beam,
                                max_internal_len=args.syntax_max_internal_len,
                                allow_early_eos=args.syntax_allow_early_eos,
                                consume_out_of_class=args.syntax_consume_ooc)
                    else:
                        p = ParsePredictor(
                            p,
                            args.normalize_fst_weights,
                            beam_size=args.syntax_internal_beam,
                            max_internal_len=args.syntax_max_internal_len,
                            nonterminal_ids=args.syntax_nonterminal_ids)
                elif wrapper == "altsrc":
                    src_test = _get_override_args("altsrc_test")
                    if isinstance(p, UnboundedVocabularyPredictor): 
                        p = UnboundedAltsrcPredictor(src_test, p)
                    else: # altsrc predictor for bounded predictors
                        p = AltsrcPredictor(src_test, p)
                elif wrapper == "rank":
                    if isinstance(p, UnboundedVocabularyPredictor): 
                        p = UnboundedRankPredictor(p)
                    else: # rank predictor for bounded predictors
                        p = RankPredictor(p)
                elif wrapper == "glue":
                    if isinstance(p, UnboundedVocabularyPredictor): 
                        p = UnboundedGluePredictor(args.max_len_factor, p)
                    else: # glue predictor for bounded predictors
                        p = GluePredictor(args.max_len_factor, p)
                elif wrapper == "word2char":
                    map_path = _get_override_args("word2char_map")
                    # word2char always wraps unbounded predictors
                    p = Word2charPredictor(map_path, p)
                elif wrapper == "skipvocab":
                    # skipvocab always wraps unbounded predictors
                    p = SkipvocabPredictor(args.skipvocab_max_id, 
                                           args.skipvocab_stop_size, 
                                           args.beam, 
                                           p)
                elif wrapper == "fsttok":
                    fsttok_path = _get_override_args("fsttok_path")
                    # fsttok always wraps unbounded predictors
                    p = FSTTokPredictor(fsttok_path,
                                        args.fst_unk_id,
                                        args.fsttok_max_pending_score,
                                        p)
                elif wrapper == "ngramize":
                    # ngramize always wraps bounded predictors
                    p = NgramizePredictor(args.min_ngram_order, 
                                          args.max_ngram_order,
                                          args.max_len_factor, p)
                elif wrapper == "unkvocab":
                    # unkvocab always wraps bounded predictors
                    p = UnkvocabPredictor(args.trg_vocab_size, p)
                else:
                    logging.fatal("Predictor wrapper '%s' not available. "
                                  "Please double-check --predictors for "
                                  "spelling errors." % wrapper)
                    decoder.remove_predictors()
                    return
            decoder.add_predictor(pred, p, pred_weight)
            logging.info("Initialized predictor {} (weight: {})".format(
                             pred, pred_weight))
    except IOError as e:
        logging.fatal("One of the files required for setting up the "
                      "predictors could not be read: %s" % e)
        decoder.remove_predictors()
    except AttributeError as e:
        logging.fatal("Invalid argument for one of the predictors: %s."
                       "Stack trace: %s" % (e, traceback.format_exc()))
        decoder.remove_predictors()
    except NameError as e:
        logging.fatal("Could not find external library: %s. Please make sure "
                      "that your PYTHONPATH and LD_LIBRARY_PATH contains all "
                      "paths required for the predictors. Stack trace: %s" % 
                      (e, traceback.format_exc()))
        decoder.remove_predictors()
    except ValueError as e:
        logging.fatal("A number format error occurred while configuring the "
                      "predictors: %s. Please double-check all integer- or "
                      "float-valued parameters such as --predictor_weights and"
                      " try again. Stack trace: %s" % (e, traceback.format_exc()))
        decoder.remove_predictors()
    except Exception as e:
        logging.fatal("An unexpected %s has occurred while setting up the pre"
                      "dictors: %s Stack trace: %s" % (sys.exc_info()[0],
                                                       e,
                                                       traceback.format_exc()))
        decoder.remove_predictors()
Example #13
def add_predictors(decoder):
    """Adds all enabled predictors to the ``decoder``. This function 
    makes heavy use of the global ``args`` which contains the
    SGNMT configuration. Particularly, it reads out ``args.predictors``
    and adds appropriate instances to ``decoder``.
    TODO: Refactor this method as it is waaaay tooooo looong
    
    Args:
        decoder (Decoder):  Decoding strategy, see ``create_decoder()``.
            This method will add predictors to this instance with
            ``add_predictor()``
    """
    preds = utils.split_comma(args.predictors)
    if not preds:
        logging.fatal("Require at least one predictor! See the --predictors "
                      "argument for more information.")
    weights = None
    if args.predictor_weights:
        weights = utils.split_comma(args.predictor_weights)
        if len(preds) != len(weights):
            logging.fatal("Specified %d predictors, but %d weights. Please "
                      "revise the --predictors and --predictor_weights "
                      "arguments" % (len(preds), len(weights)))
            return
    
    pred_weight = 1.0
    try:
        for idx, pred in enumerate(preds): # Add predictors one by one
            wrappers = []
            if '_' in pred:
                # Split wrapper predictors off the core predictor and
                # distribute the weights accordingly
                wrappers = pred.split('_')
                pred = wrappers[-1]
                wrappers = wrappers[-2::-1]
                if weights:
                    wrapper_weights = [float(w) for w in weights[idx].split('_')]
                    pred_weight = wrapper_weights[-1]
                    wrapper_weights = wrapper_weights[-2::-1]
                else:
                    wrapper_weights = [1.0] * len(wrappers)
            elif weights:
                pred_weight = float(weights[idx])
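
            # Worked example (illustrative comment, not part of the original
            # code): pred == "word2char_idxmap_t2t" with weight string
            # "0.5_0.7_1.0" yields pred == "t2t", wrappers == ["idxmap",
            # "word2char"] (applied innermost-first in the wrapper loop
            # below), pred_weight == 1.0, and wrapper_weights == [0.7, 0.5]
            # (parsed but, per the TODO below, not yet used).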

            # Create predictor instances for the string argument ``pred``
            if pred == "nizza":
                p = NizzaPredictor(_get_override_args("pred_src_vocab_size"),
                                   _get_override_args("pred_trg_vocab_size"),
                                   _get_override_args("nizza_model"),
                                   _get_override_args("nizza_hparams_set"),
                                   _get_override_args("nizza_checkpoint_dir"),
                                   n_cpu_threads=args.n_cpu_threads)
            elif pred == "lexnizza":
                p = LexNizzaPredictor(_get_override_args("pred_src_vocab_size"),
                                      _get_override_args("pred_trg_vocab_size"),
                                      _get_override_args("nizza_model"),
                                      _get_override_args("nizza_hparams_set"),
                                      _get_override_args("nizza_checkpoint_dir"),
                                      n_cpu_threads=args.n_cpu_threads,
                                      alpha=args.lexnizza_alpha,
                                      beta=args.lexnizza_beta,
                                      trg2src_model_name=args.lexnizza_trg2src_model,
                                      trg2src_hparams_set_name=args.lexnizza_trg2src_hparams_set,
                                      trg2src_checkpoint_dir=args.lexnizza_trg2src_checkpoint_dir,
                                      shortlist_strategies=args.lexnizza_shortlist_strategies,
                                      max_shortlist_length=args.lexnizza_max_shortlist_length,
                                      min_id=args.lexnizza_min_id)
            elif pred == "t2t":
                p = T2TPredictor(_get_override_args("pred_src_vocab_size"),
                                 _get_override_args("pred_trg_vocab_size"),
                                 _get_override_args("t2t_model"),
                                 _get_override_args("t2t_problem"),
                                 _get_override_args("t2t_hparams_set"),
                                 args.t2t_usr_dir,
                                 _get_override_args("t2t_checkpoint_dir"),
                                 t2t_unk_id=_get_override_args("t2t_unk_id"),
                                 n_cpu_threads=args.n_cpu_threads,
                                 max_terminal_id=args.syntax_max_terminal_id,
                                 pop_id=args.syntax_pop_id)
            elif pred == "segt2t":
                p = SegT2TPredictor(_get_override_args("pred_src_vocab_size"),
                                    _get_override_args("pred_trg_vocab_size"),
                                    _get_override_args("t2t_model"),
                                    _get_override_args("t2t_problem"),
                                    _get_override_args("t2t_hparams_set"),
                                    args.t2t_usr_dir,
                                    _get_override_args("t2t_checkpoint_dir"),
                                    t2t_unk_id=_get_override_args("t2t_unk_id"),
                                    n_cpu_threads=args.n_cpu_threads,
                                    max_terminal_id=args.syntax_max_terminal_id,
                                    pop_id=args.syntax_pop_id)
            elif pred == "editt2t":
                p = EditT2TPredictor(_get_override_args("pred_src_vocab_size"),
                                     _get_override_args("pred_trg_vocab_size"),
                                     _get_override_args("t2t_model"),
                                     _get_override_args("t2t_problem"),
                                     _get_override_args("t2t_hparams_set"),
                                     args.trg_test,
                                     args.beam,
                                     args.t2t_usr_dir,
                                     _get_override_args("t2t_checkpoint_dir"),
                                     t2t_unk_id=_get_override_args("t2t_unk_id"),
                                     n_cpu_threads=args.n_cpu_threads,
                                     max_terminal_id=args.syntax_max_terminal_id,
                                     pop_id=args.syntax_pop_id)
            elif pred == "fertt2t":
                p = FertilityT2TPredictor(
                                 _get_override_args("pred_src_vocab_size"),
                                 _get_override_args("pred_trg_vocab_size"),
                                 _get_override_args("t2t_model"),
                                 _get_override_args("t2t_problem"),
                                 _get_override_args("t2t_hparams_set"),
                                 args.t2t_usr_dir,
                                 _get_override_args("t2t_checkpoint_dir"),
                                 n_cpu_threads=args.n_cpu_threads,
                                 max_terminal_id=args.syntax_max_terminal_id,
                                 pop_id=args.syntax_pop_id)
            elif pred == "fairseq":
                p = FairseqPredictor(_get_override_args("fairseq_path"),
                                     args.fairseq_user_dir,
                                     args.fairseq_lang_pair,
                                     args.n_cpu_threads,
                                     args.subtract_uni,
                                     args.subtract_marg,
                                     _get_override_args("marg_path"),
                                     args.lmbda,
                                     args.ppmi,
                                     args.epsilon)
            elif pred == "bracket":
                p = BracketPredictor(args.syntax_max_terminal_id,
                                     args.syntax_pop_id,
                                     max_depth=args.syntax_max_depth,
                                     extlength_path=args.extlength_path)
            elif pred == "osm":
                p = OSMPredictor(args.src_wmap,
                                 args.trg_wmap,
                                 use_jumps=args.osm_use_jumps,
                                 use_auto_pop=args.osm_use_auto_pop,
                                 use_unpop=args.osm_use_unpop,
                                 use_pop2=args.osm_use_pop2,
                                 use_src_eop=args.osm_use_src_eop,
                                 use_copy=args.osm_use_copy)
            elif pred == "forcedosm":
                p = ForcedOSMPredictor(args.src_wmap, 
                                       args.trg_wmap, 
                                       args.trg_test)
            elif pred == "fst":
                p = FstPredictor(_get_override_args("fst_path"),
                                 args.use_fst_weights,
                                 args.normalize_fst_weights,
                                 skip_bos_weight=args.fst_skip_bos_weight,
                                 to_log=args.fst_to_log)
            elif pred == "nfst":
                p = NondeterministicFstPredictor(_get_override_args("fst_path"),
                                                 args.use_fst_weights,
                                                 args.normalize_fst_weights,
                                                 args.fst_skip_bos_weight,
                                                 to_log=args.fst_to_log)
            elif pred == "forced":
                p = ForcedPredictor(
                                args.trg_test, 
                                utils.split_comma(args.forced_spurious, int))
            elif pred == "bow":
                p = BagOfWordsPredictor(
                                args.trg_test,
                                args.bow_accept_subsets,
                                args.bow_accept_duplicates,
                                args.heuristic_scores_file,
                                args.collect_statistics,
                                "consumed" in args.bow_heuristic_strategies,
                                "remaining" in args.bow_heuristic_strategies,
                                args.bow_diversity_heuristic_factor,
                                _get_override_args("pred_trg_vocab_size"))
            elif pred == "bowsearch":
                p = BagOfWordsSearchPredictor(
                                decoder,
                                args.hypo_recombination,
                                args.trg_test,
                                args.bow_accept_subsets,
                                args.bow_accept_duplicates,
                                args.heuristic_scores_file,
                                args.collect_statistics,
                                "consumed" in args.bow_heuristic_strategies,
                                "remaining" in args.bow_heuristic_strategies,
                                args.bow_diversity_heuristic_factor,
                                _get_override_args("pred_trg_vocab_size"))
            elif pred == "forcedlst":
                feat_name = _get_override_args("forcedlst_sparse_feat")
                p = ForcedLstPredictor(args.trg_test,
                                       args.use_nbest_weights,
                                       args.forcedlst_match_unk,
                                       feat_name if feat_name else None)
            elif pred == "rtn":
                p = RtnPredictor(args.rtn_path,
                                 args.use_rtn_weights,
                                 args.normalize_rtn_weights,
                                 to_log=args.fst_to_log,
                                 minimize_rtns=args.minimize_rtns,
                                 rmeps=args.remove_epsilon_in_rtns)
            elif pred == "kenlm":
                p = KenLMPredictor(args.lm_path)
            elif pred == "wc":
                p = WordCountPredictor(args.wc_word,
                                       args.wc_nonterminal_penalty,
                                       args.syntax_nonterminal_ids,
                                       args.syntax_min_terminal_id,
                                       args.syntax_max_terminal_id,
                                       args.negative_wc,
                                       _get_override_args("pred_trg_vocab_size"))
            elif pred == "ngramc":
                p = NgramCountPredictor(_get_override_args("ngramc_path"),
                                        _get_override_args("ngramc_order"),
                                        args.ngramc_discount_factor)
            elif pred == "unkc":
                p = UnkCountPredictor(
                     _get_override_args("pred_src_vocab_size"), 
                     utils.split_comma(args.unk_count_lambdas, float))
            elif pred == "length":
                length_model_weights = utils.split_comma(
                    args.length_model_weights, float)
                p = NBLengthPredictor(args.src_test_raw, 
                                      length_model_weights, 
                                      args.use_length_point_probs,
                                      args.length_model_offset)
            elif pred == "extlength":
                p = ExternalLengthPredictor(args.extlength_path)
            elif pred == "lrhiero":
                fw = None
                if args.grammar_feature_weights:
                    fw = utils.split_comma(args.grammar_feature_weights, float)
                p = RuleXtractPredictor(args.rules_path,
                                        args.use_grammar_weights,
                                        fw)
            else:
                logging.fatal("Predictor '%s' not available. Please check "
                              "--predictors for spelling errors." % pred)
                decoder.remove_predictors()
                return
            for wrapper in wrappers:
                # Embed predictor ``p`` into wrapper predictors if necessary
                # TODO: Use wrapper_weights
                if wrapper == "idxmap":
                    src_path = _get_override_args("src_idxmap")
                    trg_path = _get_override_args("trg_idxmap")
                    if isinstance(p, UnboundedVocabularyPredictor): 
                        p = UnboundedIdxmapPredictor(src_path, trg_path, p, 1.0) 
                    else: # idxmap predictor for bounded predictors
                        p = IdxmapPredictor(src_path, trg_path, p, 1.0)
                elif wrapper == "maskvocab":
                    if isinstance(p, UnboundedVocabularyPredictor): 
                        p = UnboundedMaskvocabPredictor(args.maskvocab_vocab, p)
                    else: # maskvocab predictor for bounded predictors
                        p = MaskvocabPredictor(args.maskvocab_vocab, p)
                elif wrapper == "weightnt":
                    p = WeightNonTerminalPredictor(
                        p, 
                        args.syntax_nonterminal_factor,
                        args.syntax_nonterminal_ids,
                        args.syntax_min_terminal_id,
                        args.syntax_max_terminal_id,
                        _get_override_args("pred_trg_vocab_size"))
                elif wrapper == "parse":
                    if args.parse_tok_grammar:
                        if args.parse_bpe_path:
                            p = BpeParsePredictor(
                                args.syntax_path,
                                args.syntax_bpe_path,
                                p,
                                args.syntax_word_out,
                                args.normalize_fst_weights,
                                norm_alpha=args.syntax_norm_alpha,
                                beam_size=args.syntax_internal_beam,
                                max_internal_len=args.syntax_max_internal_len,
                                allow_early_eos=args.syntax_allow_early_eos,
                                consume_out_of_class=args.syntax_consume_ooc,
                                terminal_restrict=args.syntax_terminal_restrict,
                                internal_only_restrict=args.syntax_internal_only,
                                eow_ids=args.syntax_eow_ids,
                                terminal_ids=args.syntax_terminal_ids)
                        else:
                            p = TokParsePredictor(
                                args.syntax_path,
                                p,
                                args.syntax_word_out,
                                args.normalize_fst_weights,
                                norm_alpha=args.syntax_norm_alpha,
                                beam_size=args.syntax_internal_beam,
                                max_internal_len=args.syntax_max_internal_len,
                                allow_early_eos=args.syntax_allow_early_eos,
                                consume_out_of_class=args.syntax_consume_ooc)
                    else:
                        p = ParsePredictor(
                            p,
                            args.normalize_fst_weights,
                            beam_size=args.syntax_internal_beam,
                            max_internal_len=args.syntax_max_internal_len,
                            nonterminal_ids=args.syntax_nonterminal_ids)
                elif wrapper == "altsrc":
                    src_test = _get_override_args("altsrc_test")
                    if isinstance(p, UnboundedVocabularyPredictor): 
                        p = UnboundedAltsrcPredictor(src_test, p)
                    else: # altsrc predictor for bounded predictors
                        p = AltsrcPredictor(src_test, p)
                elif wrapper == "rank":
                    if isinstance(p, UnboundedVocabularyPredictor): 
                        p = UnboundedRankPredictor(p)
                    else: # rank predictor for bounded predictors
                        p = RankPredictor(p)
                elif wrapper == "glue":
                    if isinstance(p, UnboundedVocabularyPredictor): 
                        p = UnboundedGluePredictor(args.max_len_factor, p)
                    else: # glue predictor for bounded predictors
                        p = GluePredictor(args.max_len_factor, p)
                elif wrapper == "word2char":
                    map_path = _get_override_args("word2char_map")
                    # word2char always wraps unbounded predictors
                    p = Word2charPredictor(map_path, p)
                elif wrapper == "skipvocab":
                    # skipvocab always wraps unbounded predictors
                    p = SkipvocabPredictor(args.skipvocab_vocab, 
                                           args.skipvocab_stop_size, 
                                           args.beam, 
                                           p)
                elif wrapper == "fsttok":
                    fsttok_path = _get_override_args("fsttok_path")
                    # fsttok always wraps unbounded predictors
                    p = FSTTokPredictor(fsttok_path,
                                        args.fst_unk_id,
                                        args.fsttok_max_pending_score,
                                        p)
                elif wrapper == "ngramize":
                    # ngramize always wraps bounded predictors
                    p = NgramizePredictor(args.min_ngram_order, 
                                          args.max_ngram_order,
                                          args.max_len_factor, p)
                elif wrapper == "unkvocab":
                    # unkvocab always wraps bounded predictors
                    p = UnkvocabPredictor(args.trg_vocab_size, p)
                else:
                    logging.fatal("Predictor wrapper '%s' not available. "
                                  "Please double-check --predictors for "
                                  "spelling errors." % wrapper)
                    decoder.remove_predictors()
                    return
            decoder.add_predictor(pred, p, pred_weight)
            logging.info("Initialized predictor {} (weight: {})".format(
                             pred, pred_weight))
    except IOError as e:
        logging.fatal("One of the files required for setting up the "
                      "predictors could not be read: %s" % e)
        decoder.remove_predictors()
    except AttributeError as e:
        logging.fatal("Invalid argument for one of the predictors: %s."
                       "Stack trace: %s" % (e, traceback.format_exc()))
        decoder.remove_predictors()
    except NameError as e:
        logging.fatal("Could not find external library: %s. Please make sure "
                      "that your PYTHONPATH and LD_LIBRARY_PATH contains all "
                      "paths required for the predictors. Stack trace: %s" % 
                      (e, traceback.format_exc()))
        decoder.remove_predictors()
    except ValueError as e:
        logging.fatal("A number format error occurred while configuring the "
                      "predictors: %s. Please double-check all integer- or "
                      "float-valued parameters such as --predictor_weights and"
                      " try again. Stack trace: %s" % (e, traceback.format_exc()))
        decoder.remove_predictors()
    except Exception as e:
        logging.fatal("An unexpected %s has occurred while setting up the pre"
                      "dictors: %s Stack trace: %s" % (sys.exc_info()[0],
                                                       e,
                                                       traceback.format_exc()))
        decoder.remove_predictors()
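Before the big if/elif cascade runs, add_predictors() pairs each comma-separated predictor name with a weight. A self-contained sketch of that pairing step (split_comma below is a stand-in assumed to mirror utils.split_comma, including the optional conversion function seen in the calls above):

from types import SimpleNamespace

def split_comma(s, func=str):
    # Stand-in for utils.split_comma as used in add_predictors().
    return [func(x.strip()) for x in s.split(",")] if s else []

args = SimpleNamespace(predictors="fst,t2t", predictor_weights="0.4,0.6")
preds = split_comma(args.predictors)
weights = split_comma(args.predictor_weights, float)
assert len(preds) == len(weights)  # mirrors the check in add_predictors()
print(list(zip(preds, weights)))   # [('fst', 0.4), ('t2t', 0.6)]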
Example No. 14
    def __init__(self, src_vocab_size, trg_vocab_size, model_name, 
                 hparams_set_name, checkpoint_dir, single_cpu_thread,
                 alpha, beta, shortlist_strategies,
                 trg2src_model_name="", trg2src_hparams_set_name="",
                 trg2src_checkpoint_dir="",
                 max_shortlist_length=0,
                 min_id=0,
                 nizza_unk_id=None):
        """Initializes a nizza predictor.

        Args:
            src_vocab_size (int): Source vocabulary size (called inputs_vocab_size
                in nizza)
            trg_vocab_size (int): Target vocabulary size (called targets_vocab_size
                in nizza)
            model_name (string): Name of the nizza model
            hparams_set_name (string): Name of the nizza hyper-parameter set
            checkpoint_dir (string): Path to the Nizza checkpoint directory. The
                                     predictor will load the topmost checkpoint in
                                     the `checkpoints` file.
            single_cpu_thread (bool): If true, prevent tensorflow from
                                      doing multithreading.
            alpha (float): Score for each matching word
            beta (float): Penalty for each uncovered word at the end
            shortlist_strategies (string): Comma-separated list of shortlist
                strategies.
            trg2src_model_name (string): Name of the target2source nizza model
            trg2src_hparams_set_name (string): Name of the nizza hyper-parameter set
                                     for the target2source model
            trg2src_checkpoint_dir (string): Path to the Nizza checkpoint directory
                                     for the target2source model. The
                                     predictor will load the topmost checkpoint in
                                     the `checkpoints` file.
            max_shortlist_length (int): If a shortlist exceeds this limit,
                initialize the initial coverage with 1 at this position. If
                zero, do not apply any limit.
            min_id (int): Do not use IDs below this threshold (filters out most
                frequent words).
            nizza_unk_id (int): If set, use this as the UNK id. Otherwise, the
                nizza model is assumed to have no UNKs.

        Raises:
            IOError: If the checkpoint file is not found.
        """
        super(LexNizzaPredictor, self).__init__(
                src_vocab_size, trg_vocab_size, model_name, hparams_set_name, 
                checkpoint_dir, single_cpu_thread, nizza_unk_id=nizza_unk_id)
        self.alpha = alpha
        self.alpha_is_zero = alpha == 0.0
        self.beta = beta
        self.shortlist_strategies = utils.split_comma(shortlist_strategies)
        self.max_shortlist_length = max_shortlist_length
        self.min_id = min_id
        if trg2src_checkpoint_dir:
            self.use_trg2src = True
            predictor_graph = tf.Graph()
            with predictor_graph.as_default() as g:
                hparams = registry.get_registered_hparams_set(
                    trg2src_hparams_set_name)
                hparams.add_hparam("inputs_vocab_size", trg_vocab_size)
                hparams.add_hparam("targets_vocab_size", src_vocab_size)
                run_config = tf.contrib.learn.RunConfig()
                run_config = run_config.replace(model_dir=trg2src_checkpoint_dir)
                model = registry.get_registered_model(
                    trg2src_model_name, hparams, run_config)
                features = {"inputs": tf.expand_dims(tf.range(trg_vocab_size), 0)}
                mode = tf.estimator.ModeKeys.PREDICT
                trg2src_lex_logits = model.precompute(features, mode, hparams)
                # Precompute trg2src partitions
                partitions = tf.reduce_logsumexp(trg2src_lex_logits, axis=-1)
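                # partitions has shape [1, trg_vocab_size]: one log-partition
                # per target word, marginalized over the source vocabulary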
                self._trg2src_src_words_var = tf.placeholder(
                    dtype=tf.int32, shape=[None],
                    name="sgnmt_trg2src_src_words")
                # trg2src_lex_logits has shape [1, trg_vocab_size, src_vocab_size]
                self.trg2src_logits = tf.gather(
                    tf.transpose(trg2src_lex_logits[0, :, :]),
                    self._trg2src_src_words_var)
                # trg2src_logits has shape [len(src_words), trg_vocab_size]
                self.trg2src_mon_sess = self.create_session(trg2src_checkpoint_dir)
                logging.debug("Precomputing lexnizza trg2src partitions...")
                self.trg2src_partitions = self.trg2src_mon_sess.run(partitions)
        else:
            self.use_trg2src = False
            logging.warn("No target-to-source model specified for lexnizza.")