Code Example #1
File: files_tasks.py Project: leehaausing/jiant
def download_squad_data_and_write_config(
    task_name: str, task_data_path: str, task_config_path: str
):
    if task_name == "squad_v1":
        train_file = "train-v1.1.json"
        dev_file = "dev-v1.1.json"
        version_2_with_negative = False
    elif task_name == "squad_v2":
        train_file = "train-v2.0.json"
        dev_file = "dev-v2.0.json"
        version_2_with_negative = True
    else:
        raise KeyError(task_name)

    os.makedirs(task_data_path, exist_ok=True)
    train_path = os.path.join(task_data_path, train_file)
    val_path = os.path.join(task_data_path, dev_file)
    download_utils.download_file(
        url=f"https://rajpurkar.github.io/SQuAD-explorer/dataset/{train_file}",
        file_path=train_path,
    )
    download_utils.download_file(
        url=f"https://rajpurkar.github.io/SQuAD-explorer/dataset/{dev_file}", file_path=val_path,
    )
    py_io.write_json(
        data={
            "task": "squad",
            "paths": {"train": train_path, "val": val_path},
            "version_2_with_negative": version_2_with_negative,
            "name": task_name,
        },
        path=task_config_path,
    )
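These downloader helpers share one calling convention: a task name, a directory for the raw data, and a path for the generated task config. A minimal usage sketch (the paths are hypothetical, and it assumes the module's own imports such as download_utils and py_io are available):

download_squad_data_and_write_config(
    task_name="squad_v1",
    task_data_path="/tmp/tasks/data/squad_v1",  # hypothetical data directory
    task_config_path="/tmp/tasks/configs/squad_v1_config.json",  # hypothetical config path
)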
Code Example #2
File: files_tasks.py Project: eritain/jiant
def download_arct_data_and_write_config(task_name: str, task_data_path: str, task_config_path: str):
    os.makedirs(task_data_path, exist_ok=True)
    file_name_list = [
        "train-doubled.tsv",
        "train-w-swap-doubled.tsv",
        "train-w-swap.tsv",
        "train.tsv",
        "dev.tsv",
        "test.tsv",
    ]
    for file_name in file_name_list:
        download_utils.download_file(
            f"https://raw.githubusercontent.com/UKPLab/argument-reasoning-comprehension-task/"
            + f"master/experiments/src/main/python/data/{file_name}",
            os.path.join(task_data_path, file_name),
        )
    py_io.write_json(
        data={
            "task": task_name,
            "paths": {
                "train": os.path.join(task_data_path, "train.tsv"),
                "val": os.path.join(task_data_path, "val.tsv"),
                "test": os.path.join(task_data_path, "test.tsv"),
                "train_doubled": os.path.join(task_data_path, "train-doubled.tsv"),
                "train_w_swap": os.path.join(task_data_path, "train-w-swap.tsv"),
                "train_w_swap_doubled": os.path.join(task_data_path, "train-w-swap-doubled.tsv"),
            },
            "name": task_name,
        },
        path=task_config_path,
    )
Code Example #3
File: files_tasks.py Project: eritain/jiant
def download_mutual_plus_data_and_write_config(
    task_name: str, task_data_path: str, task_config_path: str
):
    os.makedirs(task_data_path, exist_ok=True)
    os.makedirs(task_data_path + "/train", exist_ok=True)
    os.makedirs(task_data_path + "/dev", exist_ok=True)
    os.makedirs(task_data_path + "/test", exist_ok=True)
    num_files = {"train": 7088, "dev": 886, "test": 886}
    for phase in num_files:
        examples = []
        for i in range(num_files[phase]):
            file_name = phase + "_" + str(i + 1) + ".txt"
            download_utils.download_file(
                f"https://raw.githubusercontent.com/Nealcly/MuTual/"
                + f"master/data/mutual_plus/{phase}/{file_name}",
                os.path.join(task_data_path, phase, file_name),
            )
            for line in py_io.read_file_lines(os.path.join(task_data_path, phase, file_name)):
                examples.append(line)
        py_io.write_jsonl(examples, os.path.join(task_data_path, phase + ".jsonl"))
        shutil.rmtree(os.path.join(task_data_path, phase))

    py_io.write_json(
        data={
            "task": task_name,
            "paths": {
                "train": os.path.join(task_data_path, "train.jsonl"),
                "val": os.path.join(task_data_path, "dev.jsonl"),
                "test": os.path.join(task_data_path, "test.jsonl"),
            },
            "name": task_name,
        },
        path=task_config_path,
    )
Code Example #4
File: xtreme.py Project: HonoMi/jiant
def download_xquad_data_and_write_config(task_data_base_path: str,
                                         task_config_base_path: str):
    languages = "ar de el en es hi ru th tr vi zh".split()
    for lang in languages:
        task_name = f"xquad_{lang}"
        task_data_path = py_io.create_dir(task_data_base_path, task_name)
        path = os.path.join(task_data_path, "xquad.json")
        download_utils.download_file(
            url=
            f"https://raw.githubusercontent.com/deepmind/xquad/master/xquad.{lang}.json",
            file_path=path,
        )
        py_io.write_json(
            data={
                "task": "xquad",
                "paths": {
                    "val": path
                },
                "name": task_name,
                "kwargs": {
                    "language": lang
                },
            },
            path=os.path.join(task_config_base_path,
                              f"{task_name}_config.json"),
            skip_if_exists=True,
        )
Code Example #5
File: files_tasks.py Project: holman57/jiant
def download_senteval_data_and_write_config(task_name: str,
                                            task_data_path: str,
                                            task_config_path: str):
    name_map = {
        "senteval_bigram_shift": "bigram_shift",
        "senteval_coordination_inversion": "coordination_inversion",
        "senteval_obj_number": "obj_number",
        "senteval_odd_man_out": "odd_man_out",
        "senteval_past_present": "past_present",
        "senteval_sentence_length": "sentence_length",
        "senteval_subj_number": "subj_number",
        "senteval_top_constituents": "top_constituents",
        "senteval_tree_depth": "tree_depth",
        "senteval_word_content": "word_content",
    }
    dataset_name = name_map[task_name]
    os.makedirs(task_data_path, exist_ok=True)
    # data contains all train/val/test examples, first column indicates the split
    data_path = os.path.join(task_data_path, "data.tsv")
    download_utils.download_file(
        url=
        "https://raw.githubusercontent.com/facebookresearch/SentEval/master/data/probing/"
        f"{dataset_name}.txt",
        file_path=data_path,
    )
    py_io.write_json(
        data={
            "task": task_name,
            "paths": {
                "data": data_path
            },
            "name": task_name
        },
        path=task_config_path,
    )
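The downloaded SentEval probing file keeps every example in a single TSV whose first column marks the split. A minimal sketch of how that file could be partitioned (it assumes the standard SentEval probing layout of split, label, and sentence columns with split values "tr"/"va"/"te"; split_senteval_tsv is a hypothetical helper, not part of jiant):

import csv

def split_senteval_tsv(data_path: str):
    # Partition rows of data.tsv by the split indicator in the first column.
    splits = {"tr": [], "va": [], "te": []}
    with open(data_path, encoding="utf-8") as f:
        for split, label, sentence in csv.reader(f, delimiter="\t", quoting=csv.QUOTE_NONE):
            splits[split].append({"label": label, "sentence": sentence})
    return splits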
Code Example #6
File: files_tasks.py Project: leehaausing/jiant
def download_fever_nli_data_and_write_config(
    task_name: str, task_data_path: str, task_config_path: str
):
    os.makedirs(task_data_path, exist_ok=True)
    download_utils.download_and_unzip(
        ("https://www.dropbox.com/s/hylbuaovqwo2zav/nli_fever.zip?dl=1"), task_data_path,
    )
    # Since the FEVER NLI dataset doesn't have labels for the dev set, we also download the original
    # FEVER dev set and match example CIDs to obtain labels.
    orig_dev_path = os.path.join(task_data_path, "fever-dev-temp.jsonl")
    download_utils.download_file(
        "https://s3-eu-west-1.amazonaws.com/fever.public/shared_task_dev.jsonl", orig_dev_path,
    )
    id_to_label = {}
    for line in py_io.read_jsonl(orig_dev_path):
        if "id" not in line:
            logging.warning("FEVER dev dataset is missing ID.")
            continue
        if "label" not in line:
            logging.warning("FEVER dev dataset is missing label.")
            continue
        id_to_label[line["id"]] = line["label"]
    os.remove(orig_dev_path)

    dev_path = os.path.join(task_data_path, "nli_fever", "dev_fitems.jsonl")
    dev_examples = []
    for line in py_io.read_jsonl(dev_path):
        if "cid" not in line:
            logging.warning("Data in {} is missing CID.".format(dev_path))
            continue
        if int(line["cid"]) not in id_to_label:
            logging.warning("Could not match CID {} to dev data.".format(line["cid"]))
            continue
        dev_example = line
        dev_example["label"] = id_to_label[int(line["cid"])]
        dev_examples.append(dev_example)
    py_io.write_jsonl(dev_examples, os.path.join(task_data_path, "val.jsonl"))
    os.remove(dev_path)

    for phase in ["train", "test"]:
        os.rename(
            os.path.join(task_data_path, "nli_fever", f"{phase}_fitems.jsonl"),
            os.path.join(task_data_path, f"{phase}.jsonl"),
        )
    shutil.rmtree(os.path.join(task_data_path, "nli_fever"))

    py_io.write_json(
        data={
            "task": task_name,
            "paths": {
                "train": os.path.join(task_data_path, "train.jsonl"),
                "val": os.path.join(task_data_path, "val.jsonl"),
                "test": os.path.join(task_data_path, "test.jsonl"),
            },
            "name": task_name,
        },
        path=task_config_path,
    )
Code Example #7
File: files_tasks.py Project: holman57/jiant
def download_mctaco_data_and_write_config(task_name: str, task_data_path: str,
                                          task_config_path: str):
    os.makedirs(task_data_path, exist_ok=True)
    file_name_list = ["dev_3783.tsv", "test_9442.tsv"]
    for file_name in file_name_list:
        download_utils.download_file(
            f"https://raw.githubusercontent.com/CogComp/MCTACO/master/dataset/{file_name}",
            os.path.join(task_data_path, file_name),
        )
    py_io.write_json(
        data={
            "task": task_name,
            "paths": {
                "val": os.path.join(task_data_path, "dev_3783.tsv"),
                "test": os.path.join(task_data_path, "test_9442.tsv"),
            },
            "name": task_name,
        },
        path=task_config_path,
    )
Code Example #8
File: files_tasks.py Project: holman57/jiant
def download_spatial_data_and_write_config(task_name: str, task_data_path: str,
                                           task_config_path: str):
    os.makedirs(task_data_path, exist_ok=True)
    data_path = os.path.join(task_data_path, "data.tsv")
    download_utils.download_file(
        url="http://lukeholman.net/remote/spatial_template_sentences.csv",
        file_path=data_path,
    )
    py_io.write_json(
        data={
            "task": "spatial",
            "paths": {
                "train": os.path.join(task_data_path, "train.jsonl"),
                "val": os.path.join(task_data_path, "valid.jsonl"),
                "test": os.path.join(task_data_path, "tests.jsonl"),
            },
            "name": "spatial",
        },
        path=task_config_path,
    )
Code Example #9
File: files_tasks.py Project: holman57/jiant
def download_acceptability_judgments_data_and_write_config(
        task_name: str, task_data_path: str, task_config_path: str):
    dataset_name = {
        "acceptability_definiteness": "definiteness",
        "acceptability_coord": "coordinating-conjunctions",
        "acceptability_whwords": "whwords",
        "acceptability_eos": "eos",
    }[task_name]
    os.makedirs(task_data_path, exist_ok=True)
    # data contains all train/val/test examples
    # metadata contains the split indicators
    # (there are 10 CV folds, we use fold1 by default, see below)
    data_path = os.path.join(task_data_path, "data.json")
    metadata_path = os.path.join(task_data_path, "metadata.json")
    download_utils.download_file(
        url=
        "https://raw.githubusercontent.com/decompositional-semantics-initiative/DNC/master/"
        f"function_words/ACCEPTABILITY/acceptability-{dataset_name}_data.json",
        file_path=data_path,
    )
    download_utils.download_file(
        url=
        "https://raw.githubusercontent.com/decompositional-semantics-initiative/DNC/master/"
        f"function_words/ACCEPTABILITY/acceptability-{dataset_name}_metadata.json",
        file_path=metadata_path,
    )
    py_io.write_json(
        data={
            "task": task_name,
            "paths": {
                "data": data_path,
                "metadata": metadata_path
            },
            "name": task_name,
            "kwargs": {
                "fold": "fold1"
            },  # use fold1 (out of 10) by default
        },
        path=task_config_path,
    )
Code Example #10
File: files_tasks.py Project: yzpang/jiant
def download_piqa_data_and_write_config(task_name: str, task_data_path: str,
                                        task_config_path: str):
    os.makedirs(task_data_path, exist_ok=True)
    download_utils.download_file(
        "https://yonatanbisk.com/piqa/data/train.jsonl",
        os.path.join(task_data_path, "train.jsonl"),
    )
    download_utils.download_file(
        "https://yonatanbisk.com/piqa/data/train-labels.lst",
        os.path.join(task_data_path, "train-labels.lst"),
    )
    download_utils.download_file(
        "https://yonatanbisk.com/piqa/data/valid.jsonl",
        os.path.join(task_data_path, "valid.jsonl"),
    )
    download_utils.download_file(
        "https://yonatanbisk.com/piqa/data/valid-labels.lst",
        os.path.join(task_data_path, "valid-labels.lst"),
    )
    download_utils.download_file(
        "https://yonatanbisk.com/piqa/data/tests.jsonl",
        os.path.join(task_data_path, "tests.jsonl"),
    )

    py_io.write_json(
        data={
            "task": task_name,
            "paths": {
                "train": os.path.join(task_data_path, "train.jsonl"),
                "train_labels": os.path.join(task_data_path,
                                             "train-labels.lst"),
                "val": os.path.join(task_data_path, "valid.jsonl"),
                "val_labels": os.path.join(task_data_path, "valid-labels.lst"),
                "test": os.path.join(task_data_path, "tests.jsonl"),
            },
            "name": task_name,
        },
        path=task_config_path,
    )
Code Example #11
File: files_tasks.py Project: yzpang/jiant
def download_mrqa_natural_questions_data_and_write_config(
        task_name: str, task_data_path: str, task_config_path: str):
    os.makedirs(task_data_path, exist_ok=True)
    download_utils.download_file(
        "https://s3.us-east-2.amazonaws.com/mrqa/release/v2/train/NaturalQuestionsShort.jsonl.gz",
        os.path.join(task_data_path, "train.jsonl.gz"),
    )
    download_utils.download_file(
        "https://s3.us-east-2.amazonaws.com/mrqa/release/v2/dev/NaturalQuestionsShort.jsonl.gz",
        os.path.join(task_data_path, "val.jsonl.gz"),
    )
    py_io.write_json(
        data={
            "task": task_name,
            "paths": {
                "train": os.path.join(task_data_path, "train.jsonl.gz"),
                "val": os.path.join(task_data_path, "val.jsonl.gz"),
            },
            "name": task_name,
        },
        path=task_config_path,
    )
Code Example #12
File: files_tasks.py Project: leehaausing/jiant
def download_newsqa_data_and_write_config(
    task_name: str, task_data_path: str, task_config_path: str
):
    def get_consensus_answer(row_):
        answer_char_start, answer_char_end = None, None
        if row_.validated_answers:
            validated_answers_ = json.loads(row_.validated_answers)
            answer_, max_count = max(validated_answers_.items(), key=itemgetter(1))
            total_count = sum(validated_answers_.values())
            if max_count >= total_count / 2.0:
                if answer_ != "none" and answer_ != "bad_question":
                    answer_char_start, answer_char_end = map(int, answer_.split(":"))
                else:
                    # No valid answer.
                    pass
        else:
            # Check row_.answer_char_ranges for most common answer.
            # No validation was done so there must be an answer with consensus.
            answers = Counter()
            for user_answer in row_.answer_char_ranges.split("|"):
                for ans in user_answer.split(","):
                    answers[ans] += 1
            top_answer = answers.most_common(1)
            if top_answer:
                top_answer, _ = top_answer[0]
                if ":" in top_answer:
                    answer_char_start, answer_char_end = map(int, top_answer.split(":"))

        return answer_char_start, answer_char_end

    def load_combined(path):
        result = pd.read_csv(
            path,
            encoding="utf-8",
            dtype=dict(is_answer_absent=float),
            na_values=dict(question=[], story_text=[], validated_answers=[]),
            keep_default_na=False,
        )

        if "story_text" in result.keys():
            for row_ in display.tqdm(
                result.itertuples(), total=len(result), desc="Adjusting story texts"
            ):
                story_text_ = row_.story_text.replace("\r\n", "\n")
                result.at[row_.Index, "story_text"] = story_text_

        return result

    def _map_answers(answers):
        result = []
        for a in answers.split("|"):
            user_answers = []
            result.append(dict(sourcerAnswers=user_answers))
            for r in a.split(","):
                if r == "None":
                    user_answers.append(dict(noAnswer=True))
                else:
                    start_, end_ = map(int, r.split(":"))
                    user_answers.append(dict(s=start_, e=end_))
        return result

    def strip_empty_strings(strings):
        while strings and strings[-1] == "":
            del strings[-1]
        return strings

    # Require: cnn_stories.tgz
    cnn_stories_path = os.path.join(task_data_path, "cnn_stories.tgz")
    assert os.path.exists(cnn_stories_path), (
        "Download CNN Stories from https://cs.nyu.edu/~kcho/DMQA/ and save to " + cnn_stories_path
    )
    # Require: newsqa-data-v1/newsqa-data-v1.csv
    dataset_path = os.path.join(task_data_path, "newsqa-data-v1", "newsqa-data-v1.csv")
    if os.path.exists(dataset_path):
        pass
    elif os.path.exists(os.path.join(task_data_path, "newsqa-data-v1.zip")):
        download_utils.unzip_file(
            zip_path=os.path.join(task_data_path, "newsqa-data-v1.zip"),
            extract_location=task_data_path,
            delete=False,
        )
    else:
        raise AssertionError(
            "Download https://www.microsoft.com/en-us/research/project/newsqa-dataset/#!download"
            " and save to " + os.path.join(task_data_path, "newsqa-data-v1.zip")
        )

    # Download auxiliary data
    os.makedirs(task_data_path, exist_ok=True)
    file_name_list = [
        "train_story_ids.csv",
        "dev_story_ids.csv",
        "test_story_ids.csv",
        "stories_requiring_extra_newline.csv",
        "stories_requiring_two_extra_newlines.csv",
        "stories_to_decode_specially.csv",
    ]
    for file_name in file_name_list:
        download_utils.download_file(
            f"https://raw.githubusercontent.com/Maluuba/newsqa/master/maluuba/newsqa/{file_name}",
            os.path.join(task_data_path, file_name),
        )

    dataset = load_combined(dataset_path)
    remaining_story_ids = set(dataset["story_id"])
    with open(
        os.path.join(task_data_path, "stories_requiring_extra_newline.csv"), "r", encoding="utf-8"
    ) as f:
        stories_requiring_extra_newline = set(f.read().split("\n"))

    with open(
        os.path.join(task_data_path, "stories_requiring_two_extra_newlines.csv"),
        "r",
        encoding="utf-8",
    ) as f:
        stories_requiring_two_extra_newlines = set(f.read().split("\n"))

    with open(
        os.path.join(task_data_path, "stories_to_decode_specially.csv"), "r", encoding="utf-8"
    ) as f:
        stories_to_decode_specially = set(f.read().split("\n"))

    # Start combining data files
    story_id_to_text = {}
    with tarfile.open(cnn_stories_path, mode="r:gz", encoding="utf-8") as t:
        highlight_indicator = "@highlight"

        copyright_line_pattern = re.compile(
            "^(Copyright|Entire contents of this article copyright, )"
        )
        with display.tqdm(total=len(remaining_story_ids), desc="Getting story texts") as pbar:
            for member in t.getmembers():
                story_id = member.name
                if story_id in remaining_story_ids:
                    remaining_story_ids.remove(story_id)
                    story_file = t.extractfile(member)

                    # Correct discrepancies in stories.
                    # Problems are caused by using several programming languages and libraries.
                    # When ingesting the stories, we started with Python 2.
                    # After dealing with unicode issues, we tried switching to Python 3.
                    # That caused inconsistency problems so we switched back to Python 2.
                    # Furthermore, when crowdsourcing, JavaScript and HTML templating perturbed
                    # the stories.
                    # So here we map the text to be compatible with the indices.
                    lines = map(lambda s_: s_.strip().decode("utf-8"), story_file.readlines())

                    story_file.close()
                    lines = list(lines)
                    highlights_start = lines.index(highlight_indicator)
                    story_lines = lines[:highlights_start]
                    story_lines = strip_empty_strings(story_lines)
                    while len(story_lines) > 1 and copyright_line_pattern.search(story_lines[-1]):
                        story_lines = strip_empty_strings(story_lines[:-2])
                    if story_id in stories_requiring_two_extra_newlines:
                        story_text = "\n\n\n".join(story_lines)
                    elif story_id in stories_requiring_extra_newline:
                        story_text = "\n\n".join(story_lines)
                    else:
                        story_text = "\n".join(story_lines)

                    story_text = story_text.replace("\xe2\x80\xa2", "\xe2\u20ac\xa2")
                    story_text = story_text.replace("\xe2\x82\xac", "\xe2\u201a\xac")
                    story_text = story_text.replace("\r", "\n")
                    if story_id in stories_to_decode_specially:
                        story_text = story_text.replace("\xe9", "\xc3\xa9")
                    story_id_to_text[story_id] = story_text

                    pbar.update()

                    if len(remaining_story_ids) == 0:
                        break

    for row in display.tqdm(dataset.itertuples(), total=len(dataset), desc="Setting story texts"):
        # Set story_text since we cannot include it in the dataset.
        story_text = story_id_to_text[row.story_id]
        dataset.at[row.Index, "story_text"] = story_text

        # Handle endings that are too large.
        answer_char_ranges = row.answer_char_ranges.split("|")
        updated_answer_char_ranges = []
        ranges_updated = False
        for user_answer_char_ranges in answer_char_ranges:
            updated_user_answer_char_ranges = []
            for char_range in user_answer_char_ranges.split(","):
                if char_range != "None":
                    start, end = map(int, char_range.split(":"))
                    if end > len(story_text):
                        ranges_updated = True
                        end = len(story_text)
                    if start < end:
                        updated_user_answer_char_ranges.append("%d:%d" % (start, end))
                    else:
                        # It's unclear why but sometimes the end is after the start.
                        # We'll filter these out.
                        ranges_updated = True
                else:
                    updated_user_answer_char_ranges.append(char_range)
            if updated_user_answer_char_ranges:
                updated_user_answer_char_ranges = ",".join(updated_user_answer_char_ranges)
                updated_answer_char_ranges.append(updated_user_answer_char_ranges)
        if ranges_updated:
            updated_answer_char_ranges = "|".join(updated_answer_char_ranges)
            dataset.at[row.Index, "answer_char_ranges"] = updated_answer_char_ranges

        if row.validated_answers and not pd.isnull(row.validated_answers):
            updated_validated_answers = {}
            validated_answers = json.loads(row.validated_answers)
            for char_range, count in validated_answers.items():
                if ":" in char_range:
                    start, end = map(int, char_range.split(":"))
                    if end > len(story_text):
                        ranges_updated = True
                        end = len(story_text)
                    if start < end:
                        char_range = "{}:{}".format(start, end)
                        updated_validated_answers[char_range] = count
                    else:
                        # It's unclear why but sometimes the end is after the start.
                        # We'll filter these out.
                        ranges_updated = True
                else:
                    updated_validated_answers[char_range] = count
            if ranges_updated:
                updated_validated_answers = json.dumps(
                    updated_validated_answers, ensure_ascii=False, separators=(",", ":")
                )
                dataset.at[row.Index, "validated_answers"] = updated_validated_answers

    # Process Splits
    data = []
    cache = dict()

    train_story_ids = set(
        pd.read_csv(os.path.join(task_data_path, "train_story_ids.csv"))["story_id"].values
    )
    dev_story_ids = set(
        pd.read_csv(os.path.join(task_data_path, "dev_story_ids.csv"))["story_id"].values
    )
    test_story_ids = set(
        pd.read_csv(os.path.join(task_data_path, "test_story_ids.csv"))["story_id"].values
    )

    def _get_data_type(story_id_):
        if story_id_ in train_story_ids:
            return "train"
        elif story_id_ in dev_story_ids:
            return "dev"
        elif story_id_ in test_story_ids:
            return "test"
        else:
            raise ValueError("{} not found in any story ID set.".format(story_id_))

    for row in display.tqdm(dataset.itertuples(), total=len(dataset), desc="Building json"):
        questions = cache.get(row.story_id)
        if questions is None:
            questions = []
            datum = dict(
                storyId=row.story_id,
                type=_get_data_type(row.story_id),
                text=row.story_text,
                questions=questions,
            )
            cache[row.story_id] = questions
            data.append(datum)
        q = dict(
            q=row.question,
            answers=_map_answers(row.answer_char_ranges),
            isAnswerAbsent=row.is_answer_absent,
        )
        if row.is_question_bad != "?":
            q["isQuestionBad"] = float(row.is_question_bad)
        if row.validated_answers and not pd.isnull(row.validated_answers):
            validated_answers = json.loads(row.validated_answers)
            q["validatedAnswers"] = []
            for answer, count in validated_answers.items():
                answer_item = dict(count=count)
                if answer == "none":
                    answer_item["noAnswer"] = True
                elif answer == "bad_question":
                    answer_item["badQuestion"] = True
                else:
                    s, e = map(int, answer.split(":"))
                    answer_item["s"] = s
                    answer_item["e"] = e
                q["validatedAnswers"].append(answer_item)
        consensus_start, consensus_end = get_consensus_answer(row)
        if consensus_start is None and consensus_end is None:
            if q.get("isQuestionBad", 0) >= 0.5:
                q["consensus"] = dict(badQuestion=True)
            else:
                q["consensus"] = dict(noAnswer=True)
        else:
            q["consensus"] = dict(s=consensus_start, e=consensus_end)
        questions.append(q)

    phase_dict = {
        "train": [],
        "val": [],
        "test": [],
    }
    phase_map = {"train": "train", "dev": "val", "test": "test"}
    for entry in data:
        phase = phase_map[entry["type"]]
        output_entry = {"text": entry["text"], "storyId": entry["storyId"], "qas": []}
        for qn in entry["questions"]:
            if "badQuestion" in qn["consensus"] or "noAnswer" in qn["consensus"]:
                continue
            output_entry["qas"].append({"question": qn["q"], "answer": qn["consensus"]})
        phase_dict[phase].append(output_entry)
    for phase, phase_data in phase_dict.items():
        py_io.write_jsonl(phase_data, os.path.join(task_data_path, f"{phase}.jsonl"))
    py_io.write_json(
        data={
            "task": task_name,
            "paths": {
                "train": os.path.join(task_data_path, "train.jsonl"),
                "val": os.path.join(task_data_path, "val.jsonl"),
                "test": os.path.join(task_data_path, "val.jsonl"),
            },
            "name": task_name,
        },
        path=task_config_path,
    )
    for file_name in file_name_list:
        os.remove(os.path.join(task_data_path, file_name))
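In the NewsQA CSV processed above, each crowdworker's answer is stored as "start:end" character offsets, with "," separating multiple spans from one worker and "|" separating workers; get_consensus_answer picks the majority span. A simplified standalone sketch of that voting rule (toy_consensus and the sample string below are illustrative only):

from collections import Counter

def toy_consensus(answer_char_ranges):
    # Return the most common "start:end" span across workers, if the winner is a real span.
    votes = Counter()
    for user_answer in answer_char_ranges.split("|"):
        for span in user_answer.split(","):
            votes[span] += 1
    top = votes.most_common(1)
    if top and ":" in top[0][0]:
        start, end = map(int, top[0][0].split(":"))
        return start, end
    return None, None

print(toy_consensus("10:25|10:25,40:52|None"))  # -> (10, 25)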
Code Example #13
File: xtreme.py Project: HonoMi/jiant
def download_tydiqa_data_and_write_config(task_data_base_path: str,
                                          task_config_base_path: str):
    tydiqa_temp_path = py_io.create_dir(task_data_base_path, "tydiqa_temp")
    full_train_path = os.path.join(tydiqa_temp_path,
                                   "tydiqa-goldp-v1.1-train.json")
    download_utils.download_file(
        "https://storage.googleapis.com/tydiqa/v1.1/tydiqa-goldp-v1.1-train.json",
        full_train_path,
    )
    download_utils.download_and_untar(
        "https://storage.googleapis.com/tydiqa/v1.1/tydiqa-goldp-v1.1-dev.tgz",
        tydiqa_temp_path,
    )
    languages_dict = {
        "arabic": "ar",
        "bengali": "bn",
        "english": "en",
        "finnish": "fi",
        "indonesian": "id",
        "korean": "ko",
        "russian": "ru",
        "swahili": "sw",
        "telugu": "te",
    }

    # Split train data
    data = py_io.read_json(full_train_path)
    lang2data = {lang: [] for lang in languages_dict.values()}
    for doc in data["data"]:
        for par in doc["paragraphs"]:
            context = par["context"]
            for qa in par["qas"]:
                question = qa["question"]
                question_id = qa["id"]
                example_lang = languages_dict[question_id.split("-")[0]]
                q_id = question_id.split("-")[-1]
                for answer in qa["answers"]:
                    a_start, a_text = answer["answer_start"], answer["text"]
                    a_end = a_start + len(a_text)
                    assert context[a_start:a_end] == a_text
                lang2data[example_lang].append({
                    "paragraphs": [{
                        "context":
                        context,
                        "qas": [{
                            "answers": qa["answers"],
                            "question": question,
                            "id": q_id
                        }],
                    }]
                })

    for full_lang, lang in languages_dict.items():
        task_name = f"tydiqa_{lang}"
        task_data_path = py_io.create_dir(task_data_base_path, task_name)
        train_path = os.path.join(task_data_path, f"tydiqa.{lang}.train.json")
        py_io.write_json(
            # Write the per-language training split built above, keeping the
            # SQuAD-style top-level "data" key used by the dev files.
            data={"data": lang2data[lang]},
            path=train_path,
            skip_if_exists=True,
        )
        val_path = os.path.join(task_data_path, f"tydiqa.{lang}.dev.json")
        os.rename(
            src=os.path.join(tydiqa_temp_path, "tydiqa-goldp-v1.1-dev",
                             f"tydiqa-goldp-dev-{full_lang}.json"),
            dst=val_path,
        )
        py_io.write_json(
            data={
                "task": "tydiqa",
                "paths": {
                    "train": train_path,
                    "val": val_path
                },
                "kwargs": {
                    "language": lang
                },
                "name": task_name,
            },
            path=os.path.join(task_config_base_path,
                              f"{task_name}_config.json"),
            skip_if_exists=True,
        )
    shutil.rmtree(tydiqa_temp_path)
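The train split above keys each example on the language prefix of its TyDiQA question id, mapped through languages_dict to a two-letter code. A tiny illustration with a made-up id in the same style:

languages_dict = {"english": "en", "finnish": "fi"}  # subset of the map above
question_id = "english-6168983713920709116-0"  # illustrative id, not taken from the dataset
example_lang = languages_dict[question_id.split("-")[0]]  # -> "en"
q_id = question_id.split("-")[-1]  # -> "0"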
Code Example #14
File: xtreme.py Project: HonoMi/jiant
def download_udpos_data_and_write_config(task_data_base_path: str,
                                         task_config_base_path: str):
    # UDPOS requires networkx==1.11

    def _read_one_file(file):
        # Adapted from https://github.com/JunjieHu/xtreme/blob/
        #              9fe0b142d0ee3eb7dd047ab86f12a76702e79bb4/utils_preprocess.py
        data = []
        sent, tag, lines = [], [], []
        for line in open(file, "r"):
            items = line.strip().split("\t")
            if len(items) != 10:
                num_empty = sum([int(w == "_") for w in sent])
                if num_empty == 0 or num_empty < len(sent) - 1:
                    data.append((sent, tag, lines))
                sent, tag, lines = [], [], []
            else:
                sent.append(items[1].strip())
                tag.append(items[3].strip())
                lines.append(line.strip())
                assert len(sent) == int(
                    items[0]), "line={}, sent={}, tag={}".format(
                        line, sent, tag)
        return data

    def _remove_empty_space(data):
        # Adapted from https://github.com/google-research/xtreme/blob/
        #              522434d1aece34131d997a97ce7e9242a51a688a/utils_preprocess.py#L212
        new_data = {}
        for split in data:
            new_data[split] = []
            for sent, tag, lines in data[split]:
                new_sent = [
                    "".join(w.replace("\u200c", "").split(" ")) for w in sent
                ]
                lines = [line.replace("\u200c", "") for line in lines]
                assert len(" ".join(new_sent).split(" ")) == len(tag)
                new_data[split].append((new_sent, tag, lines))
        return new_data

    def check_file(file):
        # Adapted from https://github.com/google-research/xtreme/blob/
        #              522434d1aece34131d997a97ce7e9242a51a688a/utils_preprocess.py#L223
        for i, l in enumerate(open(file)):
            items = l.strip().split("\t")
            assert len(items[0].split(" ")) == len(
                items[1].split(" ")), "idx={}, line={}".format(i, l)

    def _write_files(data, output_dir, lang_, suffix):
        # Adapted from https://github.com/google-research/xtreme/blob/
        #              522434d1aece34131d997a97ce7e9242a51a688a/utils_preprocess.py#L228
        for split in data:
            if len(data[split]) > 0:
                prefix = os.path.join(output_dir, f"{split}-{lang_}")
                if suffix == "mt":
                    path = prefix + ".mt.tsv"
                    if os.path.exists(path):
                        logger.info(
                            'Skip writing to %s since it already exists.',
                            path)
                    else:
                        with py_io.get_lock(path):
                            with open(path, "w") as fout:
                                for idx, (sent, tag,
                                          _) in enumerate(data[split]):
                                    newline = "\n" if idx != len(
                                        data[split]) - 1 else ""
                                    fout.write("{}\t{}{}".format(
                                        " ".join(sent), " ".join(tag),
                                        newline))
                            check_file(prefix + ".mt.tsv")
                            logger.info("    - finish checking " + prefix +
                                        ".mt.tsv")
                elif suffix == "tsv":
                    path = prefix + ".tsv"
                    if os.path.exists(path):
                        logger.info(
                            'Skip writing to %s since it already exists.',
                            path)
                    else:
                        with py_io.get_lock(path):
                            with open(path, "w") as fout:
                                for sidx, (sent, tag,
                                           _) in enumerate(data[split]):
                                    for widx, (w,
                                               t) in enumerate(zip(sent, tag)):
                                        newline = (
                                            "" if
                                            (sidx == len(data[split]) - 1) and
                                            (widx == len(sent) - 1) else "\n")
                                        fout.write("{}\t{}{}".format(
                                            w, t, newline))
                                    fout.write("\n")
                elif suffix == "conll":
                    path = prefix + ".conll"
                    if os.path.exists(path):
                        logger.info(
                            'Skip writing to %s since it already exists.',
                            path)
                    else:
                        with open(path, "w") as fout:
                            for _, _, lines in data[split]:
                                for line in lines:
                                    fout.write(line.strip() + "\n")
                                fout.write("\n")
                logger.info(f"finish writing file to {prefix}.{suffix}")

    languages = ("af ar bg de el en es et eu fa fi fr he hi hu id it ja "
                 "kk ko mr nl pt ru ta te th tl tr ur vi yo zh").split()
    udpos_temp_path = py_io.create_dir(task_data_base_path, "udpos_temp")
    download_utils.download_and_untar(
        "https://lindat.mff.cuni.cz/repository/xmlui/bitstream/handle/11234/1-3105/"
        "ud-treebanks-v2.5.tgz",
        udpos_temp_path,
    )
    download_utils.download_file(
        "https://raw.githubusercontent.com/google-research/xtreme/master/third_party/"
        "ud-conversion-tools/lib/conll.py",
        os.path.join(udpos_temp_path, "conll.py"),
    )
    conll = filesystem.import_from_path(
        os.path.join(udpos_temp_path, "conll.py"))
    conllu_path_ls = sorted(
        glob.glob(os.path.join(udpos_temp_path, "*", "*", "*.conllu")))
    conll_path = os.path.join(udpos_temp_path, "conll")

    # === Convert conllu files to conll === #
    for input_path in display.tqdm(
            conllu_path_ls, desc="Convert conllu files to conll format"):
        input_path_fol, input_path_file = os.path.split(input_path)
        lang = input_path_file.split("_")[0]
        os.makedirs(os.path.join(conll_path, lang), exist_ok=True)
        output_path = os.path.join(
            conll_path, lang,
            strings.replace_suffix(input_path_file, "conllu", "conll"))
        pos_rank_precedence_dict = {
            "default":
            ("VERB NOUN PROPN PRON ADJ NUM ADV INTJ AUX ADP DET PART CCONJ SCONJ X PUNCT"
             ).split(" "),
            "es":
            "VERB AUX PRON ADP DET".split(" "),
            "fr":
            "VERB AUX PRON NOUN ADJ ADV ADP DET PART SCONJ CONJ".split(" "),
            "it":
            "VERB AUX ADV PRON ADP DET INTJ".split(" "),
        }

        if lang in pos_rank_precedence_dict:
            current_pos_precedence_list = pos_rank_precedence_dict[lang]
        else:
            current_pos_precedence_list = pos_rank_precedence_dict["default"]

        cio = conll.CoNLLReader()
        orig_treebank = cio.read_conll_u(input_path)
        modif_treebank = copy.copy(orig_treebank)

        for s in modif_treebank:
            s.filter_sentence_content(
                replace_subtokens_with_fused_forms=True,
                posPreferenceDict=current_pos_precedence_list,
                node_properties_to_remove=False,
                remove_deprel_suffixes=False,
                remove_arabic_diacritics=False,
            )

        cio.write_conll(
            list_of_graphs=modif_treebank,
            conll_path=Path(output_path),
            conllformat="conll2006",
            print_fused_forms=True,
            print_comments=False,
        )

    # === Convert conll to final format === #
    for lang in display.tqdm(languages, desc="Convert conll to final format"):
        task_name = f"udpos_{lang}"
        task_data_path = os.path.join(task_data_base_path, task_name)
        os.makedirs(task_data_path, exist_ok=True)
        all_examples = {k: [] for k in ["train", "val", "test"]}
        for path in sorted(glob.glob(os.path.join(conll_path, lang,
                                                  "*.conll"))):
            examples = _read_one_file(path)
            if "train" in path:
                all_examples["train"] += examples
            elif "dev" in path:
                all_examples["val"] += examples
            elif "test" in path:
                all_examples["test"] += examples
            else:
                raise KeyError()
        all_examples = _remove_empty_space(all_examples)
        _write_files(
            data=all_examples,
            output_dir=task_data_path,
            lang_=lang,
            suffix="tsv",
        )
        paths_dict = {
            phase: os.path.join(task_data_path, f"{phase}-{lang}.tsv")
            for phase, phase_data in all_examples.items()
            if len(phase_data) > 0
        }
        py_io.write_json(
            data={
                "task": "udpos",
                "paths": paths_dict,
                "name": task_name,
                "kwargs": {
                    "language": lang
                },
            },
            path=os.path.join(task_config_base_path,
                              f"{task_name}_config.json"),
            skip_if_exists=True,
        )
    shutil.rmtree(udpos_temp_path)
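With suffix="tsv", _write_files above emits one token and tag per line separated by a tab, with a blank line between sentences. A minimal sketch of a reader for that layout (read_udpos_tsv is a hypothetical helper, not part of jiant):

from typing import List, Tuple

def read_udpos_tsv(path: str) -> List[Tuple[List[str], List[str]]]:
    # Read back (tokens, tags) sentence pairs from a {split}-{lang}.tsv file.
    sentences, tokens, tags = [], [], []
    with open(path, encoding="utf-8") as f:
        for line in f:
            line = line.rstrip("\n")
            if not line:  # a blank line closes the current sentence
                if tokens:
                    sentences.append((tokens, tags))
                    tokens, tags = [], []
                continue
            token, tag = line.split("\t")
            tokens.append(token)
            tags.append(tag)
    if tokens:  # the final sentence is not followed by a blank line
        sentences.append((tokens, tags))
    return sentences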