# K-fold training loop for a word-level MicroTransQuest QE model.
# NOTE(review): the original source had its line structure collapsed; the
# indentation below is reconstructed — TODO confirm against the original file.
# Accumulators for per-fold predictions on the dev/test set.
dev_fold_sources_tags = []
dev_fold_targets_tags = []
for i in range(microtransquest_config["n_fold"]):
    # Start each fold from a clean output directory so checkpoints from the
    # previous fold cannot leak into this one.
    if os.path.exists(microtransquest_config['output_dir']) and os.path.isdir(
            microtransquest_config['output_dir']):
        shutil.rmtree(microtransquest_config['output_dir'])
    if microtransquest_config["evaluate_during_training"]:
        # Hold out 10% for eval; seed varies per fold so each fold sees a
        # different split.
        raw_train, raw_eval = train_test_split(raw_train_df, test_size=0.1,
                                               random_state=SEED * i)
        model = MicroTransQuestModel(MODEL_TYPE, MODEL_NAME, labels=["OK", "BAD"],
                                     args=microtransquest_config)
        model.train_model(raw_train, eval_data=raw_eval)
        # Reload the best checkpoint selected during training for prediction.
        model = MicroTransQuestModel(MODEL_TYPE, microtransquest_config["best_model_dir"],
                                     labels=["OK", "BAD"], args=microtransquest_config)
    else:
        # No eval split: train on the full training frame.
        model = MicroTransQuestModel(MODEL_TYPE, MODEL_NAME, labels=["OK", "BAD"],
                                     args=microtransquest_config)
        model.train_model(raw_train_df)
    # Predict word-level OK/BAD tags for this fold.
    # NOTE(review): this call is truncated in the visible source; the remaining
    # arguments continue beyond this chunk.
    sources_tags, targets_tags = model.predict(test_sentences,
import logging

from transquest.algo.word_level.microtransquest.run_model import MicroTransQuestModel
# Fix: MonoTransQuestModel and logging were used below without being imported,
# which raises NameError at import time.
from transquest.algo.sentence_level.monotransquest.run_model import MonoTransQuestModel

from transquest_ui.app.args import microtransquest_config, monotransquest_config


class PredictedToken:
    """A single token paired with its predicted translation-quality value.

    Attributes:
        text: The token string.
        quality: The predicted quality for the token (e.g. an "OK"/"BAD"
            tag or a numeric score, depending on the producing model).
    """

    def __init__(self, text, quality):
        self.text = text
        self.quality = quality


# Shared override used by the commented-out model configurations below.
model_args = {"use_multiprocessing": False}

# English->German quality-estimation models, loaded once at import time:
# a word-level (token OK/BAD tagging) model and a sentence-level direct
# assessment (regression, num_labels=1) model. CPU only (use_cuda=False).
en_de_word = MicroTransQuestModel("xlmroberta", "TransQuest/microtransquest-en_de-wiki",
                                  args=microtransquest_config, labels=["OK", "BAD"],
                                  use_cuda=False)
en_de_da = MonoTransQuestModel("xlmroberta", "TransQuest/monotransquest-da-en_de-wiki",
                               args=monotransquest_config, num_labels=1, use_cuda=False)

# Additional language pairs, disabled to keep startup cost down; re-enable as needed.
# en_zh_word = MicroTransQuestModel("xlmroberta", "TransQuest/microtransquest-en_zh-wiki", args=model_args, labels=["OK", "BAD"], use_cuda=False)
# en_zh_da = MonoTransQuestModel("xlmroberta", "TransQuest/monotransquest-da-en_zh-wiki", args=model_args, num_labels=1, use_cuda=False)
#
# multilingual = MicroTransQuestModel("xlmroberta", "TransQuest/microtransquest-en_zh-wiki", args=model_args, labels=["OK", "BAD"], use_cuda=False)
# multilingual_da = MonoTransQuestModel("xlmroberta", "TransQuest/monotransquest-da-multilingual", args=model_args, num_labels=1, use_cuda=False)

logging.info("Finished loading models")