示例#1
0
import urllib.request
import os

# Prepare the working directory, fetch the Romanian-English training archive,
# build the model, and load the parallel sentences into `nmt_sentence_pairs`.

# exist_ok=True creates the directory only when missing, without the
# check-then-create race of a separate os.path.exists() test.
os.makedirs(TEMP_DIRECTORY, exist_ok=True)

# Optionally fetch model files from Google Drive first.
if GOOGLE_DRIVE:
    download_from_google_drive(DRIVE_FILE_ID, MODEL_NAME)

# Download the training archive into the current working directory.
urllib.request.urlretrieve(
    "https://www.quest.dcs.shef.ac.uk/wmt20_files_qe/training_ro-en.tar.gz",
    "training_ro-en.tar.gz")

model = QuestModel(MODEL_TYPE,
                   MODEL_NAME,
                   num_labels=1,
                   use_cuda=torch.cuda.is_available(),
                   args=transformer_nmt_config)

# The context manager closes the archive even if extraction fails.
# NOTE(review): extractall() trusts the member paths of a downloaded archive;
# consider passing filter="data" on Python versions that support it.
with tarfile.open("training_ro-en.tar.gz", "r:gz") as tar:
    tar.extractall()

# Read with an explicit encoding so non-ASCII (e.g. Romanian diacritics) text
# decodes the same way regardless of the platform's default locale.
with open('train.roen.ro', encoding='utf-8') as f:
    romanian_lines = f.read().splitlines()

with open('train.roen.en', encoding='utf-8') as f:
    english_lines = f.read().splitlines()

# Pair source and target sentences line by line, capped at the first
# 1,000,000 lines of each file; zip() stops at the shorter of the two.
nmt_sentence_pairs = [
    [romanian, english]
    for romanian, english in zip(romanian_lines[:1000000],
                                 english_lines[:1000000])
]
示例#2
0
# Sanity check: `index` must contain exactly 1000 entries.
assert (len(index) == 1000)
# The branch below runs only when the config enables evaluation during training.
if transformer_config["evaluate_during_training"]:
    # More than one fold configured: train one model per fold.
    if transformer_config["n_fold"] > 1:
        # Zero-initialised arrays with one row per dev/test item and one
        # column per fold (presumably filled with per-fold predictions in
        # code not visible in this fragment).
        dev_preds = np.zeros((len(dev), transformer_config["n_fold"]))
        test_preds = np.zeros((len(test), transformer_config["n_fold"]))
        for i in range(transformer_config["n_fold"]):

            # If the configured output directory exists, delete it
            # recursively before building this fold's model.
            if os.path.exists(
                    transformer_config['output_dir']) and os.path.isdir(
                        transformer_config['output_dir']):
                shutil.rmtree(transformer_config['output_dir'])

            # Build a new model from MODEL_NAME for this fold.
            model = QuestModel(MODEL_TYPE,
                               MODEL_NAME,
                               num_labels=1,
                               use_cuda=torch.cuda.is_available(),
                               args=transformer_config)
            # Hold out 10% of `train` for evaluation; the split seed is
            # SEED * i, so it differs per fold (and is 0 for the first fold).
            train_df, eval_df = train_test_split(train,
                                                 test_size=0.1,
                                                 random_state=SEED * i)
            # Train on the split, passing the metric callables by keyword.
            model.train_model(train_df,
                              eval_df=eval_df,
                              pearson_corr=pearson_corr,
                              spearman_corr=spearman_corr,
                              mae=mean_absolute_error)
            # Re-create the model from the configured "best_model_dir"
            # (presumably the best checkpoint written during training --
            # TODO confirm against QuestModel.train_model).
            model = QuestModel(MODEL_TYPE,
                               transformer_config["best_model_dir"],
                               num_labels=1,
                               use_cuda=torch.cuda.is_available(),
                               args=transformer_config)