Example #1
def download_panx_data_and_write_config(task_data_base_path: str, task_config_base_path: str):
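    """Unpack the panx (WikiANN) NER data and write per-language task configs.

    Expects AmazonPhotos.zip to already be saved under
    <task_data_base_path>/panx_temp/. For each language, the per-phase files
    are converted into token<TAB>label TSV files and a panx_<lang>_config.json
    is written to task_config_base_path.
    """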
    def _process_one_file(infile, outfile):
        with open(infile, "r", encoding="utf-8") as fin:
            lines = fin.readlines()
        # Drop a trailing blank line, if present, so we don't emit an extra separator.
        if lines and lines[-1].strip() == "":
            lines = lines[:-1]
        with open(outfile, "w") as fout:
            for line in lines:
                items = line.strip().split("\t")
                if len(items) == 2:
                    label = items[1].strip()
                    idx = items[0].find(":")
                    if idx != -1:
                        token = items[0][idx + 1 :].strip()
                        fout.write(f"{token}\t{label}\n")
                else:
                    fout.write("\n")

    panx_temp_path = os.path.join(task_data_base_path, "panx_temp")
    zip_path = os.path.join(panx_temp_path, "AmazonPhotos.zip")
    assert os.path.exists(zip_path), (
        "Download AmazonPhotos.zip from"
        " https://www.amazon.com/clouddrive/share/d3KGCRCIYwhKJF0H3eWA26hjg2ZCRhjpEQtDL70FSBN"
        f" and save it to {zip_path}"
    )
    download_utils.unzip_file(zip_path=zip_path, extract_location=panx_temp_path)
    languages = (
        "af ar bg bn de el en es et eu fa fi fr he hi hu id it ja jv ka "
        "kk ko ml mr ms my nl pt ru sw ta te th tl tr ur vi yo zh"
    ).split()
    for lang in languages:
        task_name = f"panx_{lang}"
        untar_path = os.path.join(panx_temp_path, "panx_dataset", lang)
        os.makedirs(untar_path, exist_ok=True)
        download_utils.untar_file(
            tar_path=os.path.join(panx_temp_path, "panx_dataset", f"{lang}.tar.gz"),
            extract_location=untar_path,
            delete=True,
        )
        task_data_path = os.path.join(task_data_base_path, task_name)
        os.makedirs(task_data_path, exist_ok=True)
        filename_dict = {"train": "train", "val": "dev", "test": "test"}
        paths_dict = {}
        for phase, filename in filename_dict.items():
            in_path = os.path.join(untar_path, filename)
            out_path = os.path.join(task_data_path, f"{phase}.tsv")
            if not os.path.exists(in_path):
                continue
            _process_one_file(infile=in_path, outfile=out_path)
            paths_dict[phase] = out_path
        py_io.write_json(
            data={
                "task": "panx",
                "paths": paths_dict,
                "name": task_name,
                "kwargs": {"language": lang},
            },
            path=os.path.join(task_config_base_path, f"{task_name}_config.json"),
        )
    shutil.rmtree(os.path.join(panx_temp_path, "panx_dataset"))
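A minimal driver sketch for the function above. It assumes the example is importable (or defined in the same module) together with the jiant-style helpers it relies on (download_utils, py_io), and that AmazonPhotos.zip has already been saved under the data directory; the paths below are hypothetical.

import os

# Hypothetical base directories; adjust to your environment.
TASK_DATA_BASE = "/tmp/tasks/data"
TASK_CONFIG_BASE = "/tmp/tasks/configs"

os.makedirs(os.path.join(TASK_DATA_BASE, "panx_temp"), exist_ok=True)
os.makedirs(TASK_CONFIG_BASE, exist_ok=True)

# The function asserts that AmazonPhotos.zip is already present at
# <TASK_DATA_BASE>/panx_temp/AmazonPhotos.zip before unpacking and converting the data.
download_panx_data_and_write_config(
    task_data_base_path=TASK_DATA_BASE,
    task_config_base_path=TASK_CONFIG_BASE,
)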
Example #2
def download_newsqa_data_and_write_config(
    task_name: str, task_data_path: str, task_config_path: str
):
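    """Rebuild the NewsQA dataset from its raw pieces and write a task config.

    Expects cnn_stories.tgz and newsqa-data-v1.zip (or the extracted CSV) to
    already be present under task_data_path. Recombines the CNN story texts
    with the question/answer CSV, repairs out-of-range answer spans, writes
    train/val/test JSONL splits, and writes the task config to task_config_path.
    """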
    def get_consensus_answer(row_):
        answer_char_start, answer_char_end = None, None
        if row_.validated_answers:
            validated_answers_ = json.loads(row_.validated_answers)
            answer_, max_count = max(validated_answers_.items(), key=itemgetter(1))
            total_count = sum(validated_answers_.values())
            if max_count >= total_count / 2.0:
                if answer_ != "none" and answer_ != "bad_question":
                    answer_char_start, answer_char_end = map(int, answer_.split(":"))
                else:
                    # No valid answer.
                    pass
        else:
            # Check row_.answer_char_ranges for most common answer.
            # No validation was done so there must be an answer with consensus.
            answers = Counter()
            for user_answer in row_.answer_char_ranges.split("|"):
                for ans in user_answer.split(","):
                    answers[ans] += 1
            top_answer = answers.most_common(1)
            if top_answer:
                top_answer, _ = top_answer[0]
                if ":" in top_answer:
                    answer_char_start, answer_char_end = map(int, top_answer.split(":"))

        return answer_char_start, answer_char_end

    def load_combined(path):
        result = pd.read_csv(
            path,
            encoding="utf-8",
            dtype=dict(is_answer_absent=float),
            na_values=dict(question=[], story_text=[], validated_answers=[]),
            keep_default_na=False,
        )

        if "story_text" in result.keys():
            for row_ in display.tqdm(
                result.itertuples(), total=len(result), desc="Adjusting story texts"
            ):
                story_text_ = row_.story_text.replace("\r\n", "\n")
                result.at[row_.Index, "story_text"] = story_text_

        return result

    def _map_answers(answers):
        result = []
        for a in answers.split("|"):
            user_answers = []
            result.append(dict(sourcerAnswers=user_answers))
            for r in a.split(","):
                if r == "None":
                    user_answers.append(dict(noAnswer=True))
                else:
                    start_, end_ = map(int, r.split(":"))
                    user_answers.append(dict(s=start_, e=end_))
        return result

    def strip_empty_strings(strings):
        while strings and strings[-1] == "":
            del strings[-1]
        return strings

    # Require: cnn_stories.tgz
    cnn_stories_path = os.path.join(task_data_path, "cnn_stories.tgz")
    assert os.path.exists(cnn_stories_path), (
        "Download CNN Stories from https://cs.nyu.edu/~kcho/DMQA/ and save to " + cnn_stories_path
    )
    # Require: newsqa-data-v1/newsqa-data-v1.csv
    dataset_path = os.path.join(task_data_path, "newsqa-data-v1", "newsqa-data-v1.csv")
    if os.path.exists(dataset_path):
        pass
    elif os.path.exists(os.path.join(task_data_path, "newsqa-data-v1.zip")):
        download_utils.unzip_file(
            zip_path=os.path.join(task_data_path, "newsqa-data-v1.zip"),
            extract_location=task_data_path,
            delete=False,
        )
    else:
        raise AssertionError(
            "Download https://www.microsoft.com/en-us/research/project/newsqa-dataset/#!download"
            " and save to " + os.path.join(task_data_path, "newsqa-data-v1.zip")
        )

    # Download auxiliary data
    os.makedirs(task_data_path, exist_ok=True)
    file_name_list = [
        "train_story_ids.csv",
        "dev_story_ids.csv",
        "test_story_ids.csv",
        "stories_requiring_extra_newline.csv",
        "stories_requiring_two_extra_newlines.csv",
        "stories_to_decode_specially.csv",
    ]
    for file_name in file_name_list:
        download_utils.download_file(
            f"https://raw.githubusercontent.com/Maluuba/newsqa/master/maluuba/newsqa/{file_name}",
            os.path.join(task_data_path, file_name),
        )

    dataset = load_combined(dataset_path)
    remaining_story_ids = set(dataset["story_id"])
    with open(
        os.path.join(task_data_path, "stories_requiring_extra_newline.csv"), "r", encoding="utf-8"
    ) as f:
        stories_requiring_extra_newline = set(f.read().split("\n"))

    with open(
        os.path.join(task_data_path, "stories_requiring_two_extra_newlines.csv"),
        "r",
        encoding="utf-8",
    ) as f:
        stories_requiring_two_extra_newlines = set(f.read().split("\n"))

    with open(
        os.path.join(task_data_path, "stories_to_decode_specially.csv"), "r", encoding="utf-8"
    ) as f:
        stories_to_decode_specially = set(f.read().split("\n"))

    # Start combining data files
    story_id_to_text = {}
    with tarfile.open(cnn_stories_path, mode="r:gz", encoding="utf-8") as t:
        highlight_indicator = "@highlight"

        copyright_line_pattern = re.compile(
            "^(Copyright|Entire contents of this article copyright, )"
        )
        with display.tqdm(total=len(remaining_story_ids), desc="Getting story texts") as pbar:
            for member in t.getmembers():
                story_id = member.name
                if story_id in remaining_story_ids:
                    remaining_story_ids.remove(story_id)
                    story_file = t.extractfile(member)

                    # Correct discrepancies in stories.
                    # Problems are caused by using several programming languages and libraries.
                    # When ingesting the stories, we started with Python 2.
                    # After dealing with unicode issues, we tried switching to Python 3.
                    # That caused inconsistency problems so we switched back to Python 2.
                    # Furthermore, when crowdsourcing, JavaScript and HTML templating perturbed
                    # the stories.
                    # So here we map the text to be compatible with the indices.
                    lines = map(lambda s_: s_.strip().decode("utf-8"), story_file.readlines())

                    story_file.close()
                    lines = list(lines)
                    highlights_start = lines.index(highlight_indicator)
                    story_lines = lines[:highlights_start]
                    story_lines = strip_empty_strings(story_lines)
                    while len(story_lines) > 1 and copyright_line_pattern.search(story_lines[-1]):
                        story_lines = strip_empty_strings(story_lines[:-2])
                    if story_id in stories_requiring_two_extra_newlines:
                        story_text = "\n\n\n".join(story_lines)
                    elif story_id in stories_requiring_extra_newline:
                        story_text = "\n\n".join(story_lines)
                    else:
                        story_text = "\n".join(story_lines)

                    story_text = story_text.replace("\xe2\x80\xa2", "\xe2\u20ac\xa2")
                    story_text = story_text.replace("\xe2\x82\xac", "\xe2\u201a\xac")
                    story_text = story_text.replace("\r", "\n")
                    if story_id in stories_to_decode_specially:
                        story_text = story_text.replace("\xe9", "\xc3\xa9")
                    story_id_to_text[story_id] = story_text

                    pbar.update()

                    if len(remaining_story_ids) == 0:
                        break

    for row in display.tqdm(dataset.itertuples(), total=len(dataset), desc="Setting story texts"):
        # Set story_text since we cannot include it in the dataset.
        story_text = story_id_to_text[row.story_id]
        dataset.at[row.Index, "story_text"] = story_text

        # Handle endings that are too large.
        answer_char_ranges = row.answer_char_ranges.split("|")
        updated_answer_char_ranges = []
        ranges_updated = False
        for user_answer_char_ranges in answer_char_ranges:
            updated_user_answer_char_ranges = []
            for char_range in user_answer_char_ranges.split(","):
                if char_range != "None":
                    start, end = map(int, char_range.split(":"))
                    if end > len(story_text):
                        ranges_updated = True
                        end = len(story_text)
                    if start < end:
                        updated_user_answer_char_ranges.append("%d:%d" % (start, end))
                    else:
                        # It's unclear why, but sometimes the end is not after the start.
                        # We'll filter these out.
                        ranges_updated = True
                else:
                    updated_user_answer_char_ranges.append(char_range)
            if updated_user_answer_char_ranges:
                updated_user_answer_char_ranges = ",".join(updated_user_answer_char_ranges)
                updated_answer_char_ranges.append(updated_user_answer_char_ranges)
        if ranges_updated:
            updated_answer_char_ranges = "|".join(updated_answer_char_ranges)
            dataset.at[row.Index, "answer_char_ranges"] = updated_answer_char_ranges

        if row.validated_answers and not pd.isnull(row.validated_answers):
            updated_validated_answers = {}
            validated_answers = json.loads(row.validated_answers)
            for char_range, count in validated_answers.items():
                if ":" in char_range:
                    start, end = map(int, char_range.split(":"))
                    if end > len(story_text):
                        ranges_updated = True
                        end = len(story_text)
                    if start < end:
                        char_range = "{}:{}".format(start, end)
                        updated_validated_answers[char_range] = count
                    else:
                        # It's unclear why, but sometimes the end is not after the start.
                        # We'll filter these out.
                        ranges_updated = True
                else:
                    updated_validated_answers[char_range] = count
            if ranges_updated:
                updated_validated_answers = json.dumps(
                    updated_validated_answers, ensure_ascii=False, separators=(",", ":")
                )
                dataset.at[row.Index, "validated_answers"] = updated_validated_answers

    # Process Splits
    data = []
    cache = dict()

    train_story_ids = set(
        pd.read_csv(os.path.join(task_data_path, "train_story_ids.csv"))["story_id"].values
    )
    dev_story_ids = set(
        pd.read_csv(os.path.join(task_data_path, "dev_story_ids.csv"))["story_id"].values
    )
    test_story_ids = set(
        pd.read_csv(os.path.join(task_data_path, "test_story_ids.csv"))["story_id"].values
    )

    def _get_data_type(story_id_):
        if story_id_ in train_story_ids:
            return "train"
        elif story_id_ in dev_story_ids:
            return "dev"
        elif story_id_ in test_story_ids:
            return "test"
        else:
            raise ValueError("{} not found in any story ID set.".format(story_id_))

    for row in display.tqdm(dataset.itertuples(), total=len(dataset), desc="Building json"):
        questions = cache.get(row.story_id)
        if questions is None:
            questions = []
            datum = dict(
                storyId=row.story_id,
                type=_get_data_type(row.story_id),
                text=row.story_text,
                questions=questions,
            )
            cache[row.story_id] = questions
            data.append(datum)
        q = dict(
            q=row.question,
            answers=_map_answers(row.answer_char_ranges),
            isAnswerAbsent=row.is_answer_absent,
        )
        if row.is_question_bad != "?":
            q["isQuestionBad"] = float(row.is_question_bad)
        if row.validated_answers and not pd.isnull(row.validated_answers):
            validated_answers = json.loads(row.validated_answers)
            q["validatedAnswers"] = []
            for answer, count in validated_answers.items():
                answer_item = dict(count=count)
                if answer == "none":
                    answer_item["noAnswer"] = True
                elif answer == "bad_question":
                    answer_item["badQuestion"] = True
                else:
                    s, e = map(int, answer.split(":"))
                    answer_item["s"] = s
                    answer_item["e"] = e
                q["validatedAnswers"].append(answer_item)
        consensus_start, consensus_end = get_consensus_answer(row)
        if consensus_start is None and consensus_end is None:
            if q.get("isQuestionBad", 0) >= 0.5:
                q["consensus"] = dict(badQuestion=True)
            else:
                q["consensus"] = dict(noAnswer=True)
        else:
            q["consensus"] = dict(s=consensus_start, e=consensus_end)
        questions.append(q)

    phase_dict = {
        "train": [],
        "val": [],
        "test": [],
    }
    phase_map = {"train": "train", "dev": "val", "test": "test"}
    for entry in data:
        phase = phase_map[entry["type"]]
        output_entry = {"text": entry["text"], "storyId": entry["storyId"], "qas": []}
        for qn in entry["questions"]:
            if "badQuestion" in qn["consensus"] or "noAnswer" in qn["consensus"]:
                continue
            output_entry["qas"].append({"question": qn["q"], "answer": qn["consensus"]})
        phase_dict[phase].append(output_entry)
    for phase, phase_data in phase_dict.items():
        py_io.write_jsonl(phase_data, os.path.join(task_data_path, f"{phase}.jsonl"))
    py_io.write_json(
        data={
            "task": task_name,
            "paths": {
                "train": os.path.join(task_data_path, "train.jsonl"),
                "val": os.path.join(task_data_path, "val.jsonl"),
                "test": os.path.join(task_data_path, "val.jsonl"),
            },
            "name": task_name,
        },
        path=task_config_path,
    )
    for file_name in file_name_list:
        os.remove(os.path.join(task_data_path, file_name))
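A minimal driver sketch for the function above, under the same assumptions (the example and its jiant-style helpers are importable), with cnn_stories.tgz and newsqa-data-v1.zip already placed in the data directory; the paths below are hypothetical.

import os

# Hypothetical locations; adjust to your environment.
TASK_DATA_PATH = "/tmp/tasks/data/newsqa"
TASK_CONFIG_PATH = "/tmp/tasks/configs/newsqa_config.json"

os.makedirs(TASK_DATA_PATH, exist_ok=True)
os.makedirs(os.path.dirname(TASK_CONFIG_PATH), exist_ok=True)

# The function asserts that cnn_stories.tgz and the newsqa-data-v1 CSV (or zip)
# are already present under TASK_DATA_PATH before rebuilding the dataset.
download_newsqa_data_and_write_config(
    task_name="newsqa",
    task_data_path=TASK_DATA_PATH,
    task_config_path=TASK_CONFIG_PATH,
)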