コード例 #1
0
# Two independent, initially empty lists. Going by their names, they presumably
# collect per-fold source-side and target-side tag predictions from the fold
# loop that follows (that loop is not fully visible here) - TODO confirm.
dev_fold_sources_tags, dev_fold_targets_tags = [], []

for i in range(microtransquest_config["n_fold"]):

    if os.path.exists(microtransquest_config['output_dir']) and os.path.isdir(
            microtransquest_config['output_dir']):
        shutil.rmtree(microtransquest_config['output_dir'])

    if microtransquest_config["evaluate_during_training"]:
        raw_train, raw_eval = train_test_split(raw_train_df,
                                               test_size=0.1,
                                               random_state=SEED * i)
        model = MicroTransQuestModel(MODEL_TYPE,
                                     MODEL_NAME,
                                     labels=["OK", "BAD"],
                                     args=microtransquest_config)
        model.train_model(raw_train, eval_data=raw_eval)
        model = MicroTransQuestModel(MODEL_TYPE,
                                     microtransquest_config["best_model_dir"],
                                     labels=["OK", "BAD"],
                                     args=microtransquest_config)

    else:
        model = MicroTransQuestModel(MODEL_TYPE,
                                     MODEL_NAME,
                                     labels=["OK", "BAD"],
                                     args=microtransquest_config)
        model.train_model(raw_train_df)

    sources_tags, targets_tags = model.predict(test_sentences,
コード例 #2
0
ファイル: app.py プロジェクト: TharinduDR/TransQuest-UI
import logging

from transquest.algo.sentence_level.monotransquest.run_model import MonoTransQuestModel
from transquest.algo.word_level.microtransquest.run_model import MicroTransQuestModel

from transquest_ui.app.args import microtransquest_config, monotransquest_config


class PredictedToken:
    """Plain record pairing a piece of text with a quality value.

    Both values are stored exactly as given; nothing is validated or converted.
    """

    def __init__(self, text, quality):
        # NOTE(review): `quality` is presumably one of the word-level labels
        # ("OK" / "BAD") used by the models in this module - TODO confirm.
        self.text, self.quality = text, quality


# Minimal model-args override that turns off multiprocessing. In the code
# visible here it is only referenced by the commented-out loaders; the live
# models are built from the imported config dicts instead.
model_args = dict(use_multiprocessing=False)

# Load the QE models once, when this module is imported. Both pass
# use_cuda=False, which keeps them off the GPU.
# NOTE(review): the live loaders take their settings from the imported
# microtransquest_config / monotransquest_config dicts, while the disabled
# loaders further down pass `model_args` - confirm which is intended before
# re-enabling them.

# Word-level model from the en_de checkpoint; label set is "OK" / "BAD".
en_de_word = MicroTransQuestModel("xlmroberta",
                                  "TransQuest/microtransquest-en_de-wiki",
                                  args=microtransquest_config,
                                  labels=["OK", "BAD"],
                                  use_cuda=False)
# Sentence-level model from the en_de checkpoint with a single output
# (num_labels=1); presumably a direct-assessment ("da") regression score,
# going by the checkpoint name - TODO confirm.
en_de_da = MonoTransQuestModel("xlmroberta",
                               "TransQuest/monotransquest-da-en_de-wiki",
                               args=monotransquest_config,
                               num_labels=1,
                               use_cuda=False)

# Disabled language pairs.
# NOTE(review): `multilingual` below loads the en_zh word-level checkpoint,
# which looks like a copy-paste slip - verify the intended multilingual
# word-level model before re-enabling it.
# en_zh_word = MicroTransQuestModel("xlmroberta", "TransQuest/microtransquest-en_zh-wiki", args=model_args, labels=["OK", "BAD"], use_cuda=False)
# en_zh_da = MonoTransQuestModel("xlmroberta", "TransQuest/monotransquest-da-en_zh-wiki", args=model_args, num_labels=1, use_cuda=False)
#
# multilingual = MicroTransQuestModel("xlmroberta", "TransQuest/microtransquest-en_zh-wiki", args=model_args, labels=["OK", "BAD"], use_cuda=False)
# multilingual_da = MonoTransQuestModel("xlmroberta", "TransQuest/monotransquest-da-multilingual", args=model_args, num_labels=1, use_cuda=False)

logging.info("Finished loading models")