예제 #1
0
def download_mutual_plus_data_and_write_config(
    task_name: str, task_data_path: str, task_config_path: str
):
    """Download the MuTual-plus splits, pack each into one JSONL file, and write the config.

    Each split is fetched as individually numbered text files
    (``<phase>_1.txt`` .. ``<phase>_N.txt``) into a scratch directory
    ``<task_data_path>/<phase>``. The lines returned by ``py_io.read_file_lines``
    for every file are gathered, in download order, and written to
    ``<task_data_path>/<phase>.jsonl``; the scratch directory is then deleted.

    Args:
        task_name: Value stored under both "task" and "name" in the config.
        task_data_path: Directory that receives the per-split JSONL files.
        task_config_path: Path of the JSON task config to write.
    """
    # Number of numbered text files fetched for each split.
    files_per_phase = {"train": 7088, "dev": 886, "test": 886}

    os.makedirs(task_data_path, exist_ok=True)
    for split in ("train", "dev", "test"):
        os.makedirs(f"{task_data_path}/{split}", exist_ok=True)

    for phase, file_count in files_per_phase.items():
        phase_dir = os.path.join(task_data_path, phase)
        collected = []
        # Remote files are numbered starting at 1.
        for index in range(1, file_count + 1):
            file_name = f"{phase}_{index}.txt"
            local_path = os.path.join(phase_dir, file_name)
            url = (
                "https://raw.githubusercontent.com/Nealcly/MuTual/"
                + f"master/data/mutual_plus/{phase}/{file_name}"
            )
            download_utils.download_file(url, local_path)
            collected.extend(py_io.read_file_lines(local_path))
        py_io.write_jsonl(collected, os.path.join(task_data_path, f"{phase}.jsonl"))
        # The individual text files are no longer needed once the JSONL exists.
        shutil.rmtree(phase_dir)

    # The config calls the validation split "val" although its file is dev.jsonl.
    config_key_to_phase = (("train", "train"), ("val", "dev"), ("test", "test"))
    config = {
        "task": task_name,
        "paths": {
            key: os.path.join(task_data_path, f"{phase}.jsonl")
            for key, phase in config_key_to_phase
        },
        "name": task_name,
    }
    py_io.write_json(data=config, path=task_config_path)
예제 #2
0
파일: xtreme.py 프로젝트: v-mipeng/jiant
def download_xnli_data_and_write_config(task_data_base_path: str, task_config_base_path: str):
    """Download XNLI 1.0 and write per-language val/test data plus one config per language.

    The archive is unpacked into a scratch directory ``xnli_temp`` under
    ``task_data_base_path``. The dev and test JSONL files are grouped by each
    record's "language" field. For every language found in the dev data, a
    directory ``xnli_<lang>`` is created holding ``val.jsonl`` and
    ``test.jsonl``, and ``xnli_<lang>_config.json`` is written under
    ``task_config_base_path``. The scratch directory is deleted at the end.

    Args:
        task_data_base_path: Directory under which per-language data directories are created.
        task_config_base_path: Directory that receives the per-language config files.
    """
    scratch_dir = py_io.create_dir(task_data_base_path, "xnli_temp")
    download_utils.download_and_unzip(
        "https://dl.fbaipublicfiles.com/XNLI/XNLI-1.0.zip", scratch_dir,
    )

    def language_of(example):
        return example["language"]

    def load_by_language(file_name):
        # Read one split from the unpacked archive and bucket its records by language.
        records = py_io.read_jsonl(os.path.join(scratch_dir, "XNLI-1.0", file_name))
        return datastructures.group_by(records, key_func=language_of)

    val_by_lang = load_by_language("xnli.dev.jsonl")
    test_by_lang = load_by_language("xnli.test.jsonl")

    # The set of languages is taken from the dev split.
    for lang in sorted(val_by_lang):
        task_name = f"xnli_{lang}"
        lang_dir = py_io.create_dir(task_data_base_path, task_name)
        val_path = os.path.join(lang_dir, "val.jsonl")
        test_path = os.path.join(lang_dir, "test.jsonl")
        py_io.write_jsonl(data=val_by_lang[lang], path=val_path)
        py_io.write_jsonl(data=test_by_lang[lang], path=test_path)

        config = {
            "task": "xnli",
            "paths": {"val": val_path, "test": test_path},
            "name": task_name,
            "kwargs": {"language": lang},
        }
        config_path = os.path.join(task_config_base_path, f"{task_name}_config.json")
        py_io.write_json(data=config, path=config_path)

    shutil.rmtree(scratch_dir)
예제 #3
0
def download_fever_nli_data_and_write_config(
    task_name: str, task_data_path: str, task_config_path: str
):
    """Download FEVER-NLI, attach labels to its dev split, and write the task config.

    The FEVER-NLI archive is unpacked into ``task_data_path``. Its dev file is
    joined against the original FEVER shared-task dev set (matching the NLI
    "cid" field, converted to int, with the original "id" field) to obtain a
    "label" for each dev example; the result is written to ``val.jsonl``.
    Dev records that lack a CID or cannot be matched are skipped with a
    warning. The train and test files are moved to ``train.jsonl`` and
    ``test.jsonl``, and the unpacked ``nli_fever`` directory is removed.

    Args:
        task_name: Value stored under both "task" and "name" in the config.
        task_data_path: Directory that receives train/val/test JSONL files.
        task_config_path: Path of the JSON task config to write.
    """
    os.makedirs(task_data_path, exist_ok=True)
    download_utils.download_and_unzip(
        ("https://www.dropbox.com/s/hylbuaovqwo2zav/nli_fever.zip?dl=1"), task_data_path,
    )
    # Since the FEVER NLI dataset doesn't have labels for the dev set, we also download the original
    # FEVER dev set and match example CIDs to obtain labels.
    orig_dev_path = os.path.join(task_data_path, "fever-dev-temp.jsonl")
    download_utils.download_file(
        "https://s3-eu-west-1.amazonaws.com/fever.public/shared_task_dev.jsonl", orig_dev_path,
    )
    id_to_label = {}
    for line in py_io.read_jsonl(orig_dev_path):
        if "id" not in line:
            logging.warning("FEVER dev dataset is missing ID.")
            continue
        if "label" not in line:
            logging.warning("FEVER dev dataset is missing label.")
            continue
        id_to_label[line["id"]] = line["label"]
    # The original dev set is only needed to build the ID -> label lookup.
    os.remove(orig_dev_path)

    nli_fever_dir = os.path.join(task_data_path, "nli_fever")
    dev_path = os.path.join(nli_fever_dir, "dev_fitems.jsonl")
    dev_examples = []
    for line in py_io.read_jsonl(dev_path):
        if "cid" not in line:
            logging.warning("Data in {} is missing CID.".format(dev_path))
            continue
        # Convert once; the lookup is keyed by int(cid).
        cid = int(line["cid"])
        if cid not in id_to_label:
            logging.warning("Could not match CID {} to dev data.".format(line["cid"]))
            continue
        line["label"] = id_to_label[cid]
        dev_examples.append(line)
    py_io.write_jsonl(dev_examples, os.path.join(task_data_path, "val.jsonl"))
    os.remove(dev_path)

    for phase in ["train", "test"]:
        # os.replace overwrites an existing destination on every platform; os.rename
        # raises FileExistsError on Windows if a previous run already left the file there.
        os.replace(
            os.path.join(nli_fever_dir, f"{phase}_fitems.jsonl"),
            os.path.join(task_data_path, f"{phase}.jsonl"),
        )
    shutil.rmtree(nli_fever_dir)

    py_io.write_json(
        data={
            "task": task_name,
            "paths": {
                "train": os.path.join(task_data_path, "train.jsonl"),
                "val": os.path.join(task_data_path, "val.jsonl"),
                "test": os.path.join(task_data_path, "test.jsonl"),
            },
            "name": task_name,
        },
        path=task_config_path,
    )