Code example #1
import logging
import pprint

from cam.sgnmt.blocks.alignment.nam import align_with_nam
from cam.sgnmt.blocks.alignment.nmt import align_with_nmt
from cam.sgnmt.blocks.nmt import blocks_get_default_nmt_config
from cam.sgnmt.misc.sparse import FileBasedFeatMap
from cam.sgnmt.output import CSVAlignmentOutputHandler, \
    NPYAlignmentOutputHandler, TextAlignmentOutputHandler
from cam.sgnmt.ui import get_align_parser

logging.basicConfig(format='%(asctime)s %(levelname)s: %(message)s')
logging.getLogger().setLevel(logging.INFO)

parser = get_align_parser()
args = parser.parse_args()

# Get configuration
configuration = blocks_get_default_nmt_config()
for k in dir(args):
    if k in configuration:
        configuration[k] = getattr(args, k)
if configuration['src_sparse_feat_map']:
    configuration['src_sparse_feat_map'] = FileBasedFeatMap(
        configuration['enc_embed'], configuration['src_sparse_feat_map'])
if configuration['trg_sparse_feat_map']:
    configuration['trg_sparse_feat_map'] = FileBasedFeatMap(
        configuration['dec_embed'], configuration['trg_sparse_feat_map'])
logging.info("Model options:\n{}".format(pprint.pformat(configuration)))

# Align
if args.alignment_model == "nam":
    alignments = align_with_nam(configuration, args)
elif args.alignment_model == "nmt":
    alignments = align_with_nmt(configuration, args)
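
The loop under "Get configuration" copies every command-line attribute whose name matches a configuration key into the NMT config dict. A minimal, self-contained sketch of this pattern (the option names here are illustrative, not SGNMT's actual flags):

import argparse

# Default configuration with a few illustrative keys.
configuration = {"enc_embed": 620, "beam_size": 12, "saveto": "train"}

parser = argparse.ArgumentParser()
parser.add_argument("--beam_size", type=int, default=12)
parser.add_argument("--saveto", default="train")
args = parser.parse_args(["--beam_size", "8"])

# dir(args) lists all attribute names on the namespace; only those
# that already exist in the configuration override the defaults.
for k in dir(args):
    if k in configuration:
        configuration[k] = getattr(args, k)

print(configuration)  # {'enc_embed': 620, 'beam_size': 8, 'saveto': 'train'}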
Code example #2
def _get_dataset_with_mono(mono_data_integration='exp3s',
                           backtrans_nmt_config='',
                           backtrans_store=True,
                           add_mono_dummy_data=True,
                           min_parallel_data=0.2,
                           backtrans_reload_frequency=0,
                           backtrans_max_same_word=0.3,
                           src_data='',
                           trg_data='',
                           src_mono_data='',
                           trg_mono_data='',
                           src_vocab_size=30000,
                           trg_vocab_size=30000,
                           src_sparse_feat_map='',
                           trg_sparse_feat_map='',
                           saveto='',
                           **kwargs):
    """Creates a parallel data stream with monolingual data. This is
    based on the ``ParallelSource`` framework in ``stream``.
    
    The arguments to this method are given by the configuration dict.
    """
    src_sens = stream.load_sentences_from_file(src_data, src_vocab_size)
    trg_sens = stream.load_sentences_from_file(trg_data, trg_vocab_size)
    trg_mono_sens = stream.load_sentences_from_file(trg_mono_data,
                                                    trg_vocab_size)

    backtrans_config = blocks_get_default_nmt_config()
    if backtrans_nmt_config:
        # Apply "key=value" overrides, casting each value to the type
        # of the corresponding default config entry.
        for pair in backtrans_nmt_config.split(","):
            (k, v) = pair.split("=", 1)
            backtrans_config[k] = type(backtrans_config[k])(v)

    parallel_src = ShuffledParallelSource(src_sens, trg_sens)
    dummy_src = None
    if add_mono_dummy_data:
        dummy_src = DummyParallelSource(utils.GO_ID, trg_mono_sens)
    if backtrans_store:
        backtrans_file = "%s/backtrans.txt" % saveto
        old_backtrans_src = OldBacktranslatedParallelSource(backtrans_file)
        backtrans_src = BacktranslatedParallelSource(
            trg_mono_sens, backtrans_config, backtrans_file,
            backtrans_max_same_word, backtrans_reload_frequency,
            old_backtrans_src)
    else:
        backtrans_src = BacktranslatedParallelSource(
            trg_mono_sens, backtrans_config, None, backtrans_max_same_word,
            backtrans_reload_frequency)

    if min_parallel_data > 0.0:
        if add_mono_dummy_data:
            dummy_src = MergedParallelSource(parallel_src, dummy_src,
                                             min_parallel_data)
        backtrans_src = MergedParallelSource(parallel_src, backtrans_src,
                                             min_parallel_data)
        old_backtrans_src = MergedParallelSource(parallel_src,
                                                 old_backtrans_src,
                                                 min_parallel_data)
    sources = []
    sources.append(parallel_src)
    if add_mono_dummy_data:
        sources.append(dummy_src)
    sources.append(backtrans_src)
    if backtrans_store:
        sources.append(old_backtrans_src)

    return ParallelSourceSwitchDataset(sources,
                                       src_vocab_size,
                                       trg_vocab_size,
                                       src_sparse_feat_map=src_sparse_feat_map,
                                       trg_sparse_feat_map=trg_sparse_feat_map)
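
The stream construction above mixes genuine parallel data with dummy and back-translated sources, and MergedParallelSource with min_parallel_data appears to enforce a minimum share of real parallel sentence pairs in each mixed source. A toy sketch of that mixing idea (the probabilistic semantics are an assumption, not SGNMT's implementation):

import itertools
import random

def merged_source(parallel, synthetic, min_parallel=0.2):
    """Toy generator: draw from ``parallel`` with probability
    ``min_parallel``, otherwise from ``synthetic`` (assumed semantics)."""
    while True:
        if random.random() < min_parallel:
            yield next(parallel)
        else:
            yield next(synthetic)

# Usage with two infinite toy sources:
par = itertools.cycle([("src parallel", "trg parallel")])
syn = itertools.cycle([("src synthetic", "trg synthetic")])
mixed = merged_source(par, syn, min_parallel=0.2)
print([next(mixed) for _ in range(4)])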
Code example #3
File: align.py  Project: ucam-smt/sgnmt
import logging
import pprint

from cam.sgnmt.blocks.alignment.nam import align_with_nam
from cam.sgnmt.blocks.alignment.nmt import align_with_nmt
from cam.sgnmt.blocks.nmt import blocks_get_default_nmt_config
from cam.sgnmt.misc.sparse import FileBasedFeatMap
from cam.sgnmt.output import CSVAlignmentOutputHandler, \
    NPYAlignmentOutputHandler, TextAlignmentOutputHandler
from cam.sgnmt.blocks.nmt import get_blocks_align_parser


logging.basicConfig(format='%(asctime)s %(levelname)s: %(message)s')
logging.getLogger().setLevel(logging.INFO)

parser = get_blocks_align_parser()
args = parser.parse_args()

# Get configuration
configuration = blocks_get_default_nmt_config()
for k in dir(args):
    if k in configuration:
        configuration[k] = getattr(args, k)
if configuration['src_sparse_feat_map']:
    configuration['src_sparse_feat_map'] = FileBasedFeatMap(
                                        configuration['enc_embed'],
                                        configuration['src_sparse_feat_map'])
if configuration['trg_sparse_feat_map']:
    configuration['trg_sparse_feat_map'] = FileBasedFeatMap(
                                        configuration['dec_embed'],
                                        configuration['trg_sparse_feat_map'])
logging.info("Model options:\n{}".format(pprint.pformat(configuration)))
    
# Align
if args.alignment_model == "nam":
    alignments = align_with_nam(configuration, args)
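
Both align.py variants also upgrade the sparse feature map entries in place: if a file path was supplied, the plain string in the config dict is replaced by a FileBasedFeatMap built from the embedding size and that path. A toy version of this string-to-object config upgrade (ToyFeatMap is an illustrative stand-in, not SGNMT's class):

class ToyFeatMap:
    """Illustrative stand-in for FileBasedFeatMap."""
    def __init__(self, dim, path):
        self.dim = dim
        self.path = path

config = {"enc_embed": 620, "src_sparse_feat_map": "feats.txt"}
if config["src_sparse_feat_map"]:
    # Replace the file name with an object the model can query later,
    # mirroring the snippets above.
    config["src_sparse_feat_map"] = ToyFeatMap(config["enc_embed"],
                                               config["src_sparse_feat_map"])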
Code example #4
File: decode_utils.py  Project: Jack44Wang/sgnmt
def add_predictors(decoder):
    """Adds all enabled predictors to the ``decoder``. This function
    makes heavy use of the global ``args`` which contains the
    SGNMT configuration. Particularly, it reads out ``args.predictors``
    and adds appropriate instances to ``decoder``.
    TODO: Refactor this method as it is waaaay tooooo looong

    Args:
        decoder (Decoder):  Decoding strategy, see ``create_decoder()``.
            This method will add predictors to this instance with
            ``add_predictor()``
    """
    preds = utils.split_comma(args.predictors)
    if not preds:
        logging.fatal("Require at least one predictor! See the --predictors "
                      "argument for more information.")
    weights = None
    if args.predictor_weights:
        weights = utils.split_comma(args.predictor_weights)
        if len(preds) != len(weights):
            logging.fatal("Specified %d predictors, but %d weights. Please "
                          "revise the --predictors and --predictor_weights "
                          "arguments" % (len(preds), len(weights)))
            return

    pred_weight = 1.0
    try:
        for idx, pred in enumerate(preds):  # Add predictors one by one
            wrappers = []
            if '_' in pred:
                # Handle weights when we have wrapper predictors
                wrappers = pred.split('_')
                pred = wrappers[-1]
                wrappers = wrappers[-2::-1]
                if weights:
                    wrapper_weights = [
                        float(w) for w in weights[idx].split('_')
                    ]
                    pred_weight = wrapper_weights[-1]
                    wrapper_weights = wrapper_weights[-2::-1]
                else:
                    wrapper_weights = [1.0] * len(wrappers)
            elif weights:
                pred_weight = float(weights[idx])

            # Create predictor instances for the string argument ``pred``
            if pred == "nmt":
                nmt_engine = _get_override_args("nmt_engine")
                if nmt_engine == 'blocks':
                    nmt_config = _parse_config_param(
                        "nmt_config", blocks_get_default_nmt_config())
                    p = blocks_get_nmt_predictor(
                        args, _get_override_args("nmt_path"), nmt_config)
                elif nmt_engine == 'tensorflow':
                    nmt_config = _parse_config_param(
                        "nmt_config", tf_get_default_nmt_config())
                    p = tf_get_nmt_predictor(args,
                                             _get_override_args("nmt_path"),
                                             nmt_config)
                elif nmt_engine != 'none':
                    logging.fatal("NMT engine %s is not supported (yet)!" %
                                  nmt_engine)
            elif pred == "nizza":
                p = NizzaPredictor(_get_override_args("pred_src_vocab_size"),
                                   _get_override_args("pred_trg_vocab_size"),
                                   _get_override_args("nizza_model"),
                                   _get_override_args("nizza_hparams_set"),
                                   _get_override_args("nizza_checkpoint_dir"),
                                   single_cpu_thread=args.single_cpu_thread)
            elif pred == "lexnizza":
                p = LexNizzaPredictor(
                    _get_override_args("pred_src_vocab_size"),
                    _get_override_args("pred_trg_vocab_size"),
                    _get_override_args("nizza_model"),
                    _get_override_args("nizza_hparams_set"),
                    _get_override_args("nizza_checkpoint_dir"),
                    single_cpu_thread=args.single_cpu_thread,
                    alpha=args.lexnizza_alpha,
                    beta=args.lexnizza_beta,
                    trg2src_model_name=args.lexnizza_trg2src_model,
                    trg2src_hparams_set_name=args.lexnizza_trg2src_hparams_set,
                    trg2src_checkpoint_dir=args.lexnizza_trg2src_checkpoint_dir,
                    shortlist_strategies=args.lexnizza_shortlist_strategies,
                    max_shortlist_length=args.lexnizza_max_shortlist_length,
                    min_id=args.lexnizza_min_id)
            elif pred == "t2t":
                p = T2TPredictor(_get_override_args("pred_src_vocab_size"),
                                 _get_override_args("pred_trg_vocab_size"),
                                 _get_override_args("t2t_model"),
                                 _get_override_args("t2t_problem"),
                                 _get_override_args("t2t_hparams_set"),
                                 args.t2t_usr_dir,
                                 _get_override_args("t2t_checkpoint_dir"),
                                 single_cpu_thread=args.single_cpu_thread,
                                 max_terminal_id=args.syntax_max_terminal_id,
                                 pop_id=args.syntax_pop_id)
            elif pred == "simt2t":
                p = SimT2TPredictor(
                    _get_override_args("pred_src_vocab_size"),
                    _get_override_args("pred_trg_vocab_size"),
                    _get_override_args("t2t_model"),
                    _get_override_args("t2t_problem"),
                    _get_override_args("t2t_hparams_set"),
                    args.t2t_usr_dir,
                    _get_override_args("t2t_checkpoint_dir"),
                    single_cpu_thread=args.single_cpu_thread,
                    max_terminal_id=args.syntax_max_terminal_id,
                    pop_id=args.syntax_pop_id)
            elif pred == "simt2tv2":
                p = SimT2TPredictor_v2(
                    _get_override_args("pred_src_vocab_size"),
                    _get_override_args("pred_trg_vocab_size"),
                    _get_override_args("t2t_model"),
                    _get_override_args("t2t_problem"),
                    _get_override_args("t2t_hparams_set"),
                    args.t2t_usr_dir,
                    _get_override_args("t2t_checkpoint_dir"),
                    single_cpu_thread=args.single_cpu_thread,
                    max_terminal_id=args.syntax_max_terminal_id,
                    pop_id=args.syntax_pop_id)
            elif pred == "bracket":
                p = BracketPredictor(args.syntax_max_terminal_id,
                                     args.syntax_pop_id,
                                     max_depth=args.syntax_max_depth,
                                     extlength_path=args.extlength_path)
            elif pred == "osm":
                p = OSMPredictor()
            elif pred == "forcedosm":
                p = ForcedOSMPredictor(args.trg_test)
            elif pred == "fst":
                p = FstPredictor(_get_override_args("fst_path"),
                                 args.use_fst_weights,
                                 args.normalize_fst_weights,
                                 skip_bos_weight=args.fst_skip_bos_weight,
                                 to_log=args.fst_to_log)
            elif pred == "nfst":
                p = NondeterministicFstPredictor(
                    _get_override_args("fst_path"),
                    args.use_fst_weights,
                    args.normalize_fst_weights,
                    args.fst_skip_bos_weight,
                    to_log=args.fst_to_log)
            elif pred == "forced":
                p = ForcedPredictor(args.trg_test)
            elif pred == "bow":
                p = BagOfWordsPredictor(
                    args.trg_test,
                    args.bow_accept_subsets,
                    args.bow_accept_duplicates,
                    args.heuristic_scores_file,
                    args.collect_statistics,
                    "consumed" in args.bow_heuristic_strategies,
                    "remaining" in args.bow_heuristic_strategies,
                    args.bow_diversity_heuristic_factor,
                    _get_override_args("pred_trg_vocab_size"))
            elif pred == "bowsearch":
                p = BagOfWordsSearchPredictor(
                    decoder,
                    args.hypo_recombination,
                    args.trg_test,
                    args.bow_accept_subsets,
                    args.bow_accept_duplicates,
                    args.heuristic_scores_file,
                    args.collect_statistics,
                    "consumed" in args.bow_heuristic_strategies,
                    "remaining" in args.bow_heuristic_strategies,
                    args.bow_diversity_heuristic_factor,
                    _get_override_args("pred_trg_vocab_size"))
            elif pred == "forcedlst":
                feat_name = _get_override_args("forcedlst_sparse_feat")
                p = ForcedLstPredictor(args.trg_test, args.use_nbest_weights,
                                       feat_name if feat_name else None)
            elif pred == "rtn":
                p = RtnPredictor(args.rtn_path,
                                 args.use_rtn_weights,
                                 args.normalize_rtn_weights,
                                 to_log=args.fst_to_log,
                                 minimize_rtns=args.minimize_rtns,
                                 rmeps=args.remove_epsilon_in_rtns)
            elif pred == "srilm":
                p = SRILMPredictor(args.srilm_path, args.srilm_order,
                                   args.srilm_convert_to_ln)
            elif pred == "nplm":
                p = NPLMPredictor(args.nplm_path, args.normalize_nplm_probs)
            elif pred == "rnnlm":
                p = tf_get_rnnlm_predictor(_get_override_args("rnnlm_path"),
                                           _get_override_args("rnnlm_config"),
                                           tf_get_rnnlm_prefix())
            elif pred == "wc":
                p = WordCountPredictor(args.wc_word)
            elif pred == "ngramc":
                p = NgramCountPredictor(_get_override_args("ngramc_path"),
                                        _get_override_args("ngramc_order"),
                                        args.ngramc_discount_factor)
            elif pred == "unkc":
                p = UnkCountPredictor(
                    _get_override_args("pred_src_vocab_size"),
                    [float(l) for l in args.unk_count_lambdas.split(',')])
            elif pred == "length":
                length_model_weights = [
                    float(w) for w in args.length_model_weights.split(',')
                ]
                p = NBLengthPredictor(args.src_test_raw, length_model_weights,
                                      args.use_length_point_probs,
                                      args.length_model_offset)
            elif pred == "extlength":
                p = ExternalLengthPredictor(args.extlength_path)
            elif pred == "lrhiero":
                fw = None
                if args.grammar_feature_weights:
                    fw = [
                        float(w)
                        for w in args.grammar_feature_weights.split(',')
                    ]
                p = RuleXtractPredictor(args.rules_path,
                                        args.use_grammar_weights, fw)
            elif pred == "vanilla":
                continue
            else:
                logging.fatal("Predictor '%s' not available. Please check "
                              "--predictors for spelling errors." % pred)
                decoder.remove_predictors()
                return
            for _, wrapper in enumerate(wrappers):
                # Embed predictor ``p`` into wrapper predictors if necessary
                # TODO: Use wrapper_weights
                if wrapper == "idxmap":
                    src_path = _get_override_args("src_idxmap")
                    trg_path = _get_override_args("trg_idxmap")
                    if isinstance(p, UnboundedVocabularyPredictor):
                        p = UnboundedIdxmapPredictor(src_path, trg_path, p,
                                                     1.0)
                    else:  # idxmap predictor for bounded predictors
                        p = IdxmapPredictor(src_path, trg_path, p, 1.0)
                elif wrapper == "altsrc":
                    src_test = _get_override_args("altsrc_test")
                    if isinstance(p, UnboundedVocabularyPredictor):
                        p = UnboundedAltsrcPredictor(src_test, p)
                    else:  # altsrc predictor for bounded predictors
                        p = AltsrcPredictor(src_test, p)
                elif wrapper == "word2char":
                    map_path = _get_override_args("word2char_map")
                    # word2char always wraps unbounded predictors
                    p = Word2charPredictor(map_path, p)
                elif wrapper == "skipvocab":
                    # skipvocab always wraps unbounded predictors
                    p = SkipvocabPredictor(args.skipvocab_max_id,
                                           args.skipvocab_stop_size, args.beam,
                                           p)
                elif wrapper == "fsttok":
                    fsttok_path = _get_override_args("fsttok_path")
                    # fsttok always wraps unbounded predictors
                    p = FSTTokPredictor(fsttok_path, args.fst_unk_id,
                                        args.fsttok_max_pending_score, p)
                elif wrapper == "ngramize":
                    # ngramize always wraps bounded predictors
                    p = NgramizePredictor(args.min_ngram_order,
                                          args.max_ngram_order,
                                          args.max_len_factor, p)
                elif wrapper == "unkvocab":
                    # unkvocab always wraps bounded predictors
                    p = UnkvocabPredictor(args.trg_vocab_size, p)
                else:
                    logging.fatal("Predictor wrapper '%s' not available. "
                                  "Please double-check --predictors for "
                                  "spelling errors." % wrapper)
                    decoder.remove_predictors()
                    return
            decoder.add_predictor(pred, p, pred_weight)
            logging.info("Initialized predictor {} (weight: {})".format(
                pred, pred_weight))
    except IOError as e:
        logging.fatal("One of the files required for setting up the "
                      "predictors could not be read: %s" % e)
        decoder.remove_predictors()
    except AttributeError as e:
        logging.fatal("Invalid argument for one of the predictors: %s" % e)
        decoder.remove_predictors()
    except NameError as e:
        logging.fatal("Could not find external library: %s. Please make sure "
                      "that your PYTHONPATH and LD_LIBRARY_PATH contains all "
                      "paths required for the predictors. Stack trace: %s" %
                      (e, traceback.format_exc()))
        decoder.remove_predictors()
    except ValueError as e:
        logging.fatal("A number format error occurred while configuring the "
                      "predictors: %s. Please double-check all integer- or "
                      "float-valued parameters such as --predictor_weights and"
                      " try again. Stack trace: %s" %
                      (e, traceback.format_exc()))
        decoder.remove_predictors()
    except Exception as e:
        logging.fatal("An unexpected %s has occurred while setting up the pre"
                      "dictors: %s Stack trace: %s" %
                      (sys.exc_info()[0], e, traceback.format_exc()))
        decoder.remove_predictors()
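
The wrapper handling at the top of the loop deserves unpacking: a predictor string such as "idxmap_word2char_nmt" is split on "_", the last token names the core predictor, and the remaining tokens, reversed, become wrappers that are applied inside-out, so the leftmost wrapper ends up outermost. A standalone sketch of just that parsing step:

def parse_predictor(spec, weight_spec=None):
    """Replicates the splitting logic from add_predictors above."""
    wrappers, wrapper_weights = [], []
    pred, pred_weight = spec, 1.0
    if '_' in spec:
        parts = spec.split('_')
        pred = parts[-1]
        wrappers = parts[-2::-1]  # reversed, e.g. ['word2char', 'idxmap']
        if weight_spec:
            ws = [float(w) for w in weight_spec.split('_')]
            pred_weight = ws[-1]
            wrapper_weights = ws[-2::-1]
        else:
            wrapper_weights = [1.0] * len(wrappers)
    elif weight_spec:
        pred_weight = float(weight_spec)
    return pred, pred_weight, wrappers, wrapper_weights

# Core predictor 'nmt' with weight 0.5, wrapped by 'word2char' first
# and 'idxmap' second (idxmap outermost):
print(parse_predictor("idxmap_word2char_nmt", "1.0_1.0_0.5"))
# -> ('nmt', 0.5, ['word2char', 'idxmap'], [1.0, 1.0])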
Code example #5
File: decode_utils.py  Project: ucam-smt/sgnmt
def add_predictors(decoder):
    """Adds all enabled predictors to the ``decoder``. This function 
    makes heavy use of the global ``args`` which contains the
    SGNMT configuration. Particularly, it reads out ``args.predictors``
    and adds appropriate instances to ``decoder``.
    TODO: Refactor this method as it is waaaay tooooo looong
    
    Args:
        decoder (Decoder):  Decoding strategy, see ``create_decoder()``.
            This method will add predictors to this instance with
            ``add_predictor()``
    """
    preds = utils.split_comma(args.predictors)
    if not preds:
        logging.fatal("Require at least one predictor! See the --predictors "
                      "argument for more information.")
    weights = None
    if args.predictor_weights:
        weights = utils.split_comma(args.predictor_weights)
        if len(preds) != len(weights):
            logging.fatal("Specified %d predictors, but %d weights. Please "
                      "revise the --predictors and --predictor_weights "
                      "arguments" % (len(preds), len(weights)))
            return
    
    pred_weight = 1.0
    try:
        for idx, pred in enumerate(preds): # Add predictors one by one
            wrappers = []
            if '_' in pred: 
                # Handle weights when we have wrapper predictors
                wrappers = pred.split('_')
                pred = wrappers[-1]
                wrappers = wrappers[-2::-1]
                if weights:
                    wrapper_weights = [float(w) for w in weights[idx].split('_')]
                    pred_weight = wrapper_weights[-1]
                    wrapper_weights = wrapper_weights[-2::-1]
                else:
                    wrapper_weights = [1.0] * len(wrappers)
            elif weights:
                pred_weight = float(weights[idx])

            # Create predictor instances for the string argument ``pred``
            if pred == "nmt":
                # TODO: Clean this up and make a blocks and tfnmt predictor
                nmt_engine = _get_override_args("nmt_engine")
                if nmt_engine == 'blocks':
                    nmt_config = _parse_config_param(
                        "nmt_config", blocks_get_default_nmt_config())
                    p = blocks_get_nmt_predictor(
                        args, _get_override_args("nmt_path"), nmt_config)
                elif nmt_engine == 'tensorflow':
                    nmt_config = _parse_config_param(
                        "nmt_config", tf_get_default_nmt_config())
                    p = tf_get_nmt_predictor(
                        args, _get_override_args("nmt_path"), nmt_config)
                elif nmt_engine != 'none':
                    logging.fatal("NMT engine %s is not supported (yet)!" % nmt_engine)
            elif pred == "nizza":
                p = NizzaPredictor(_get_override_args("pred_src_vocab_size"),
                                   _get_override_args("pred_trg_vocab_size"),
                                   _get_override_args("nizza_model"),
                                   _get_override_args("nizza_hparams_set"),
                                   _get_override_args("nizza_checkpoint_dir"),
                                   n_cpu_threads=args.n_cpu_threads)
            elif pred == "lexnizza":
                p = LexNizzaPredictor(
                    _get_override_args("pred_src_vocab_size"),
                    _get_override_args("pred_trg_vocab_size"),
                    _get_override_args("nizza_model"),
                    _get_override_args("nizza_hparams_set"),
                    _get_override_args("nizza_checkpoint_dir"),
                    n_cpu_threads=args.n_cpu_threads,
                    alpha=args.lexnizza_alpha,
                    beta=args.lexnizza_beta,
                    trg2src_model_name=args.lexnizza_trg2src_model,
                    trg2src_hparams_set_name=args.lexnizza_trg2src_hparams_set,
                    trg2src_checkpoint_dir=args.lexnizza_trg2src_checkpoint_dir,
                    shortlist_strategies=args.lexnizza_shortlist_strategies,
                    max_shortlist_length=args.lexnizza_max_shortlist_length,
                    min_id=args.lexnizza_min_id)
            elif pred == "t2t":
                p = T2TPredictor(_get_override_args("pred_src_vocab_size"),
                                 _get_override_args("pred_trg_vocab_size"),
                                 _get_override_args("t2t_model"),
                                 _get_override_args("t2t_problem"),
                                 _get_override_args("t2t_hparams_set"),
                                 args.t2t_usr_dir,
                                 _get_override_args("t2t_checkpoint_dir"),
                                 t2t_unk_id=_get_override_args("t2t_unk_id"),
                                 n_cpu_threads=args.n_cpu_threads,
                                 max_terminal_id=args.syntax_max_terminal_id,
                                 pop_id=args.syntax_pop_id)
            elif pred == "segt2t":
                p = SegT2TPredictor(_get_override_args("pred_src_vocab_size"),
                                    _get_override_args("pred_trg_vocab_size"),
                                    _get_override_args("t2t_model"),
                                    _get_override_args("t2t_problem"),
                                    _get_override_args("t2t_hparams_set"),
                                    args.t2t_usr_dir,
                                    _get_override_args("t2t_checkpoint_dir"),
                                    t2t_unk_id=_get_override_args("t2t_unk_id"),
                                    n_cpu_threads=args.n_cpu_threads,
                                    max_terminal_id=args.syntax_max_terminal_id,
                                    pop_id=args.syntax_pop_id)
            elif pred == "editt2t":
                p = EditT2TPredictor(_get_override_args("pred_src_vocab_size"),
                                     _get_override_args("pred_trg_vocab_size"),
                                     _get_override_args("t2t_model"),
                                     _get_override_args("t2t_problem"),
                                     _get_override_args("t2t_hparams_set"),
                                     args.trg_test,
                                     args.beam,
                                     args.t2t_usr_dir,
                                     _get_override_args("t2t_checkpoint_dir"),
                                     t2t_unk_id=_get_override_args("t2t_unk_id"),
                                     n_cpu_threads=args.n_cpu_threads,
                                     max_terminal_id=args.syntax_max_terminal_id,
                                     pop_id=args.syntax_pop_id)
            elif pred == "fertt2t":
                p = FertilityT2TPredictor(
                                 _get_override_args("pred_src_vocab_size"),
                                 _get_override_args("pred_trg_vocab_size"),
                                 _get_override_args("t2t_model"),
                                 _get_override_args("t2t_problem"),
                                 _get_override_args("t2t_hparams_set"),
                                 args.t2t_usr_dir,
                                 _get_override_args("t2t_checkpoint_dir"),
                                 n_cpu_threads=args.n_cpu_threads,
                                 max_terminal_id=args.syntax_max_terminal_id,
                                 pop_id=args.syntax_pop_id)
            elif pred == "bracket":
                p = BracketPredictor(args.syntax_max_terminal_id,
                                     args.syntax_pop_id,
                                     max_depth=args.syntax_max_depth,
                                     extlength_path=args.extlength_path)
            elif pred == "osm":
                p = OSMPredictor(args.osm_type)
            elif pred == "forcedosm":
                p = ForcedOSMPredictor(args.trg_test)
            elif pred == "fst":
                p = FstPredictor(_get_override_args("fst_path"),
                                 args.use_fst_weights,
                                 args.normalize_fst_weights,
                                 skip_bos_weight=args.fst_skip_bos_weight,
                                 to_log=args.fst_to_log)
            elif pred == "nfst":
                p = NondeterministicFstPredictor(_get_override_args("fst_path"),
                                                 args.use_fst_weights,
                                                 args.normalize_fst_weights,
                                                 args.fst_skip_bos_weight,
                                                 to_log=args.fst_to_log)
            elif pred == "forced":
                p = ForcedPredictor(args.trg_test)
            elif pred == "bow":
                p = BagOfWordsPredictor(
                                args.trg_test,
                                args.bow_accept_subsets,
                                args.bow_accept_duplicates,
                                args.heuristic_scores_file,
                                args.collect_statistics,
                                "consumed" in args.bow_heuristic_strategies,
                                "remaining" in args.bow_heuristic_strategies,
                                args.bow_diversity_heuristic_factor,
                                _get_override_args("pred_trg_vocab_size"))
            elif pred == "bowsearch":
                p = BagOfWordsSearchPredictor(
                                decoder,
                                args.hypo_recombination,
                                args.trg_test,
                                args.bow_accept_subsets,
                                args.bow_accept_duplicates,
                                args.heuristic_scores_file,
                                args.collect_statistics,
                                "consumed" in args.bow_heuristic_strategies,
                                "remaining" in args.bow_heuristic_strategies,
                                args.bow_diversity_heuristic_factor,
                                _get_override_args("pred_trg_vocab_size"))
            elif pred == "forcedlst":
                feat_name = _get_override_args("forcedlst_sparse_feat")
                p = ForcedLstPredictor(args.trg_test,
                                       args.use_nbest_weights,
                                       args.forcedlst_match_unk,
                                       feat_name if feat_name else None)
            elif pred == "rtn":
                p = RtnPredictor(args.rtn_path,
                                 args.use_rtn_weights,
                                 args.normalize_rtn_weights,
                                 to_log=args.fst_to_log,
                                 minimize_rtns=args.minimize_rtns,
                                 rmeps=args.remove_epsilon_in_rtns)
            elif pred == "srilm":
                p = SRILMPredictor(args.lm_path, 
                                   _get_override_args("ngramc_order"),
                                   args.srilm_convert_to_ln)
            elif pred == "kenlm":
                p = KenLMPredictor(args.lm_path)
            elif pred == "nplm":
                p = NPLMPredictor(args.nplm_path, args.normalize_nplm_probs)
            elif pred == "rnnlm":
                p = tf_get_rnnlm_predictor(_get_override_args("rnnlm_path"),
                                           _get_override_args("rnnlm_config"),
                                           tf_get_rnnlm_prefix())
            elif pred == "wc":
                p = WordCountPredictor(args.wc_word,
                                       args.wc_nonterminal_penalty,
                                       args.syntax_nonterminal_ids,
                                       args.syntax_min_terminal_id,
                                       args.syntax_max_terminal_id,
                                       args.negative_wc,
                                       _get_override_args("pred_trg_vocab_size"))
            elif pred == "ngramc":
                p = NgramCountPredictor(_get_override_args("ngramc_path"),
                                        _get_override_args("ngramc_order"),
                                        args.ngramc_discount_factor)
            elif pred == "unkc":
                p = UnkCountPredictor(
                     _get_override_args("pred_src_vocab_size"), 
                     utils.split_comma(args.unk_count_lambdas, float))
            elif pred == "length":
                length_model_weights = utils.split_comma(
                    args.length_model_weights, float)
                p = NBLengthPredictor(args.src_test_raw, 
                                      length_model_weights, 
                                      args.use_length_point_probs,
                                      args.length_model_offset)
            elif pred == "extlength":
                p = ExternalLengthPredictor(args.extlength_path)
            elif pred == "lrhiero":
                fw = None
                if args.grammar_feature_weights:
                    fw = utils.split_comma(args.grammar_feature_weights, float)
                p = RuleXtractPredictor(args.rules_path,
                                        args.use_grammar_weights,
                                        fw)
            elif pred == "vanilla":
                continue
            else:
                logging.fatal("Predictor '%s' not available. Please check "
                              "--predictors for spelling errors." % pred)
                decoder.remove_predictors()
                return
            for _, wrapper in enumerate(wrappers):
                # Embed predictor ``p`` into wrapper predictors if necessary
                # TODO: Use wrapper_weights
                if wrapper == "idxmap":
                    src_path = _get_override_args("src_idxmap")
                    trg_path = _get_override_args("trg_idxmap")
                    if isinstance(p, UnboundedVocabularyPredictor): 
                        p = UnboundedIdxmapPredictor(src_path, trg_path, p, 1.0) 
                    else: # idxmap predictor for bounded predictors
                        p = IdxmapPredictor(src_path, trg_path, p, 1.0)
                elif wrapper == "maskvocab":
                    words = utils.split_comma(args.maskvocab_words, int)
                    if isinstance(p, UnboundedVocabularyPredictor): 
                        p = UnboundedMaskvocabPredictor(words, p) 
                    else: # idxmap predictor for bounded predictors
                        p = MaskvocabPredictor(words, p)
                elif wrapper == "weightnt":
                    p = WeightNonTerminalPredictor(
                        p, 
                        args.syntax_nonterminal_factor,
                        args.syntax_nonterminal_ids,
                        args.syntax_min_terminal_id,
                        args.syntax_max_terminal_id,
                        _get_override_args("pred_trg_vocab_size"))
                elif wrapper == "parse":
                    if args.parse_tok_grammar:
                        if args.parse_bpe_path:
                            p = BpeParsePredictor(
                                args.syntax_path,
                                args.syntax_bpe_path,
                                p,
                                args.syntax_word_out,
                                args.normalize_fst_weights,
                                norm_alpha=args.syntax_norm_alpha,
                                beam_size=args.syntax_internal_beam,
                                max_internal_len=args.syntax_max_internal_len,
                                allow_early_eos=args.syntax_allow_early_eos,
                                consume_out_of_class=args.syntax_consume_ooc,
                                terminal_restrict=args.syntax_terminal_restrict,
                                internal_only_restrict=args.syntax_internal_only,
                                eow_ids=args.syntax_eow_ids,
                                terminal_ids=args.syntax_terminal_ids)
                        else:
                            p = TokParsePredictor(
                                args.syntax_path,
                                p,
                                args.syntax_word_out,
                                args.normalize_fst_weights,
                                norm_alpha=args.syntax_norm_alpha,
                                beam_size=args.syntax_internal_beam,
                                max_internal_len=args.syntax_max_internal_len,
                                allow_early_eos=args.syntax_allow_early_eos,
                                consume_out_of_class=args.syntax_consume_ooc)
                    else:
                        p = ParsePredictor(
                            p,
                            args.normalize_fst_weights,
                            beam_size=args.syntax_internal_beam,
                            max_internal_len=args.syntax_max_internal_len,
                            nonterminal_ids=args.syntax_nonterminal_ids)
                elif wrapper == "altsrc":
                    src_test = _get_override_args("altsrc_test")
                    if isinstance(p, UnboundedVocabularyPredictor): 
                        p = UnboundedAltsrcPredictor(src_test, p)
                    else: # altsrc predictor for bounded predictors
                        p = AltsrcPredictor(src_test, p)
                elif wrapper == "rank":
                    if isinstance(p, UnboundedVocabularyPredictor): 
                        p = UnboundedRankPredictor(p)
                    else: # rank predictor for bounded predictors
                        p = RankPredictor(p)
                elif wrapper == "glue":
                    if isinstance(p, UnboundedVocabularyPredictor): 
                        p = UnboundedGluePredictor(args.max_len_factor, p)
                    else: # glue predictor for bounded predictors
                        p = GluePredictor(args.max_len_factor, p)
                elif wrapper == "word2char":
                    map_path = _get_override_args("word2char_map")
                    # word2char always wraps unbounded predictors
                    p = Word2charPredictor(map_path, p)
                elif wrapper == "skipvocab":
                    # skipvocab always wraps unbounded predictors
                    p = SkipvocabPredictor(args.skipvocab_max_id, 
                                           args.skipvocab_stop_size, 
                                           args.beam, 
                                           p)
                elif wrapper == "fsttok":
                    fsttok_path = _get_override_args("fsttok_path")
                    # fsttok always wraps unbounded predictors
                    p = FSTTokPredictor(fsttok_path,
                                        args.fst_unk_id,
                                        args.fsttok_max_pending_score,
                                        p)
                elif wrapper == "ngramize":
                    # ngramize always wraps bounded predictors
                    p = NgramizePredictor(args.min_ngram_order, 
                                          args.max_ngram_order,
                                          args.max_len_factor, p)
                elif wrapper == "unkvocab":
                    # unkvocab always wraps bounded predictors
                    p = UnkvocabPredictor(args.trg_vocab_size, p)
                else:
                    logging.fatal("Predictor wrapper '%s' not available. "
                                  "Please double-check --predictors for "
                                  "spelling errors." % wrapper)
                    decoder.remove_predictors()
                    return
            decoder.add_predictor(pred, p, pred_weight)
            logging.info("Initialized predictor {} (weight: {})".format(
                             pred, pred_weight))
    except IOError as e:
        logging.fatal("One of the files required for setting up the "
                      "predictors could not be read: %s" % e)
        decoder.remove_predictors()
    except AttributeError as e:
        logging.fatal("Invalid argument for one of the predictors: %s."
                       "Stack trace: %s" % (e, traceback.format_exc()))
        decoder.remove_predictors()
    except NameError as e:
        logging.fatal("Could not find external library: %s. Please make sure "
                      "that your PYTHONPATH and LD_LIBRARY_PATH contains all "
                      "paths required for the predictors. Stack trace: %s" % 
                      (e, traceback.format_exc()))
        decoder.remove_predictors()
    except ValueError as e:
        logging.fatal("A number format error occurred while configuring the "
                      "predictors: %s. Please double-check all integer- or "
                      "float-valued parameters such as --predictor_weights and"
                      " try again. Stack trace: %s" % (e, traceback.format_exc()))
        decoder.remove_predictors()
    except Exception as e:
        logging.fatal("An unexpected %s has occurred while setting up the pre"
                      "dictors: %s Stack trace: %s" % (sys.exc_info()[0],
                                                       e,
                                                       traceback.format_exc()))
        decoder.remove_predictors()
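
Both versions lean on the helper _get_override_args, which is not shown in these snippets. The call sites suggest it resolves arguments such as nmt_path or t2t_checkpoint_dir that may carry one value per predictor. A plausible sketch with explicit parameters (the real helper presumably reads the global args and a predictor counter; the comma-splitting behavior shown here is an assumption, not SGNMT's actual implementation):

def get_override_args(args, field, pred_idx):
    """Hypothetical: return the value of ``field`` for the
    ``pred_idx``-th predictor. A comma-separated argument holds one
    value per predictor; a single value is shared by all (assumed)."""
    raw = getattr(args, field)
    if isinstance(raw, str) and ',' in raw:
        values = raw.split(',')
        return values[pred_idx] if pred_idx < len(values) else values[-1]
    return raw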
Code example #6
File: batch_decode.py  Project: Jack44Wang/sgnmt
    # Print out result
    for task in sorted(finished_tasks, key=lambda t: t.sen_id): 
        logging.info("Decoded (ID: %d): %s" % (
                        task.sen_id,
                        ' '.join([str(w) for w in task.get_best_translation()])))
        logging.info("Stats (ID: %d): %s" % (task.sen_id,
                                             task.get_stats_string()))
    logging.info("Decoding finished. Time: %.6f" % (stop_time - start_time))
    # Force-terminate the whole process by sending SIGTERM to itself
    os.system('kill %d' % os.getpid())
        

# MAIN ENTRY POINT

# Get configuration
config = blocks_get_default_nmt_config()
for k in dir(args):
    if k in config:
        config[k] = getattr(args, k)
logging.info("Model options:\n{}".format(pprint.pformat(config)))
np.show_config()

nmt_model = NMTModel(config)
nmt_model.set_up()

loader = LoadNMTUtils(get_nmt_model_path_best_bleu(config),
                      config['saveto'],
                      nmt_model.search_model)
loader.load_weights()

src_sentences = load_sentences(args.src_test,