def download_squad_data_and_write_config(
    task_name: str, task_data_path: str, task_config_path: str
):
    if task_name == "squad_v1":
        train_file = "train-v1.1.json"
        dev_file = "dev-v1.1.json"
        version_2_with_negative = False
    elif task_name == "squad_v2":
        train_file = "train-v2.0.json"
        dev_file = "dev-v2.0.json"
        version_2_with_negative = True
    else:
        raise KeyError(task_name)
    os.makedirs(task_data_path, exist_ok=True)
    train_path = os.path.join(task_data_path, train_file)
    val_path = os.path.join(task_data_path, dev_file)
    download_utils.download_file(
        url=f"https://rajpurkar.github.io/SQuAD-explorer/dataset/{train_file}",
        file_path=train_path,
    )
    download_utils.download_file(
        url=f"https://rajpurkar.github.io/SQuAD-explorer/dataset/{dev_file}",
        file_path=val_path,
    )
    py_io.write_json(
        data={
            "task": "squad",
            "paths": {"train": train_path, "val": val_path},
            "version_2_with_negative": version_2_with_negative,
            "name": task_name,
        },
        path=task_config_path,
    )

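# A minimal usage sketch (added for illustration, not part of the original module): the paths
# below are hypothetical placeholders, and the call assumes the surrounding repo's
# `download_utils` and `py_io` helpers are importable as in the rest of this file.
#
#     download_squad_data_and_write_config(
#         task_name="squad_v1",
#         task_data_path="/tmp/tasks/data/squad_v1",
#         task_config_path="/tmp/tasks/configs/squad_v1_config.json",
#     )
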
def download_arct_data_and_write_config(task_name: str, task_data_path: str, task_config_path: str):
    os.makedirs(task_data_path, exist_ok=True)
    file_name_list = [
        "train-doubled.tsv",
        "train-w-swap-doubled.tsv",
        "train-w-swap.tsv",
        "train.tsv",
        "dev.tsv",
        "test.tsv",
    ]
    for file_name in file_name_list:
        download_utils.download_file(
            "https://raw.githubusercontent.com/UKPLab/argument-reasoning-comprehension-task/"
            f"master/experiments/src/main/python/data/{file_name}",
            os.path.join(task_data_path, file_name),
        )
    py_io.write_json(
        data={
            "task": task_name,
            "paths": {
                "train": os.path.join(task_data_path, "train.tsv"),
                # The downloaded dev split is "dev.tsv"; expose it as "val" in the config.
                "val": os.path.join(task_data_path, "dev.tsv"),
                "test": os.path.join(task_data_path, "test.tsv"),
                "train_doubled": os.path.join(task_data_path, "train-doubled.tsv"),
                "train_w_swap": os.path.join(task_data_path, "train-w-swap.tsv"),
                "train_w_swap_doubled": os.path.join(task_data_path, "train-w-swap-doubled.tsv"),
            },
            "name": task_name,
        },
        path=task_config_path,
    )

def download_mutual_plus_data_and_write_config(
    task_name: str, task_data_path: str, task_config_path: str
):
    os.makedirs(task_data_path, exist_ok=True)
    os.makedirs(os.path.join(task_data_path, "train"), exist_ok=True)
    os.makedirs(os.path.join(task_data_path, "dev"), exist_ok=True)
    os.makedirs(os.path.join(task_data_path, "test"), exist_ok=True)
    num_files = {"train": 7088, "dev": 886, "test": 886}
    for phase in num_files:
        examples = []
        for i in range(num_files[phase]):
            file_name = phase + "_" + str(i + 1) + ".txt"
            download_utils.download_file(
                "https://raw.githubusercontent.com/Nealcly/MuTual/"
                f"master/data/mutual_plus/{phase}/{file_name}",
                os.path.join(task_data_path, phase, file_name),
            )
            for line in py_io.read_file_lines(os.path.join(task_data_path, phase, file_name)):
                examples.append(line)
        py_io.write_jsonl(examples, os.path.join(task_data_path, phase + ".jsonl"))
        shutil.rmtree(os.path.join(task_data_path, phase))
    py_io.write_json(
        data={
            "task": task_name,
            "paths": {
                "train": os.path.join(task_data_path, "train.jsonl"),
                "val": os.path.join(task_data_path, "dev.jsonl"),
                "test": os.path.join(task_data_path, "test.jsonl"),
            },
            "name": task_name,
        },
        path=task_config_path,
    )

def download_xquad_data_and_write_config(task_data_base_path: str, task_config_base_path: str):
    languages = "ar de el en es hi ru th tr vi zh".split()
    for lang in languages:
        task_name = f"xquad_{lang}"
        task_data_path = py_io.create_dir(task_data_base_path, task_name)
        path = os.path.join(task_data_path, "xquad.json")
        download_utils.download_file(
            url=f"https://raw.githubusercontent.com/deepmind/xquad/master/xquad.{lang}.json",
            file_path=path,
        )
        py_io.write_json(
            data={
                "task": "xquad",
                "paths": {"val": path},
                "name": task_name,
                "kwargs": {"language": lang},
            },
            path=os.path.join(task_config_base_path, f"{task_name}_config.json"),
            skip_if_exists=True,
        )

def download_senteval_data_and_write_config(
    task_name: str, task_data_path: str, task_config_path: str
):
    name_map = {
        "senteval_bigram_shift": "bigram_shift",
        "senteval_coordination_inversion": "coordination_inversion",
        "senteval_obj_number": "obj_number",
        "senteval_odd_man_out": "odd_man_out",
        "senteval_past_present": "past_present",
        "senteval_sentence_length": "sentence_length",
        "senteval_subj_number": "subj_number",
        "senteval_top_constituents": "top_constituents",
        "senteval_tree_depth": "tree_depth",
        "senteval_word_content": "word_content",
    }
    dataset_name = name_map[task_name]
    os.makedirs(task_data_path, exist_ok=True)
    # data contains all train/val/test examples; the first column indicates the split
    data_path = os.path.join(task_data_path, "data.tsv")
    download_utils.download_file(
        url="https://raw.githubusercontent.com/facebookresearch/SentEval/master/data/probing/"
        f"{dataset_name}.txt",
        file_path=data_path,
    )
    py_io.write_json(
        data={"task": task_name, "paths": {"data": data_path}, "name": task_name},
        path=task_config_path,
    )

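# Note (added; describes the upstream SentEval probing files as I understand them, not something
# this function checks): each downloaded `.txt` is tab-separated, with the split indicator in the
# first column ("tr"/"va"/"te"), the label in the second, and the sentence in the third, e.g.
#
#     tr<TAB><label><TAB><sentence tokens ...>
#
# The row above is an illustrative placeholder, not taken from the actual data.
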
def download_fever_nli_data_and_write_config(
    task_name: str, task_data_path: str, task_config_path: str
):
    os.makedirs(task_data_path, exist_ok=True)
    download_utils.download_and_unzip(
        "https://www.dropbox.com/s/hylbuaovqwo2zav/nli_fever.zip?dl=1", task_data_path,
    )
    # Since the FEVER NLI dataset doesn't have labels for the dev set, we also download the
    # original FEVER dev set and match example CIDs to obtain labels.
    orig_dev_path = os.path.join(task_data_path, "fever-dev-temp.jsonl")
    download_utils.download_file(
        "https://s3-eu-west-1.amazonaws.com/fever.public/shared_task_dev.jsonl", orig_dev_path,
    )
    id_to_label = {}
    for line in py_io.read_jsonl(orig_dev_path):
        if "id" not in line:
            logging.warning("FEVER dev dataset is missing ID.")
            continue
        if "label" not in line:
            logging.warning("FEVER dev dataset is missing label.")
            continue
        id_to_label[line["id"]] = line["label"]
    os.remove(orig_dev_path)
    dev_path = os.path.join(task_data_path, "nli_fever", "dev_fitems.jsonl")
    dev_examples = []
    for line in py_io.read_jsonl(dev_path):
        if "cid" not in line:
            logging.warning("Data in {} is missing CID.".format(dev_path))
            continue
        if int(line["cid"]) not in id_to_label:
            logging.warning("Could not match CID {} to dev data.".format(line["cid"]))
            continue
        dev_example = line
        dev_example["label"] = id_to_label[int(line["cid"])]
        dev_examples.append(dev_example)
    py_io.write_jsonl(dev_examples, os.path.join(task_data_path, "val.jsonl"))
    os.remove(dev_path)
    for phase in ["train", "test"]:
        os.rename(
            os.path.join(task_data_path, "nli_fever", f"{phase}_fitems.jsonl"),
            os.path.join(task_data_path, f"{phase}.jsonl"),
        )
    shutil.rmtree(os.path.join(task_data_path, "nli_fever"))
    py_io.write_json(
        data={
            "task": task_name,
            "paths": {
                "train": os.path.join(task_data_path, "train.jsonl"),
                "val": os.path.join(task_data_path, "val.jsonl"),
                "test": os.path.join(task_data_path, "test.jsonl"),
            },
            "name": task_name,
        },
        path=task_config_path,
    )

def download_mctaco_data_and_write_config(task_name: str, task_data_path: str, task_config_path: str):
    os.makedirs(task_data_path, exist_ok=True)
    file_name_list = ["dev_3783.tsv", "test_9442.tsv"]
    for file_name in file_name_list:
        download_utils.download_file(
            f"https://raw.githubusercontent.com/CogComp/MCTACO/master/dataset/{file_name}",
            os.path.join(task_data_path, file_name),
        )
    py_io.write_json(
        data={
            "task": task_name,
            "paths": {
                "val": os.path.join(task_data_path, "dev_3783.tsv"),
                "test": os.path.join(task_data_path, "test_9442.tsv"),
            },
            "name": task_name,
        },
        path=task_config_path,
    )

def download_spatial_data_and_write_config(task_name: str, task_data_path: str, task_config_path: str):
    os.makedirs(task_data_path, exist_ok=True)
    data_path = os.path.join(task_data_path, "data.tsv")
    download_utils.download_file(
        url="http://lukeholman.net/remote/spatial_template_sentences.csv", file_path=data_path,
    )
    py_io.write_json(
        data={
            "task": "spatial",
            "paths": {
                "train": os.path.join(task_data_path, "train.jsonl"),
                "val": os.path.join(task_data_path, "valid.jsonl"),
                "test": os.path.join(task_data_path, "tests.jsonl"),
            },
            "name": "spatial",
        },
        path=task_config_path,
    )

def download_acceptability_judgments_data_and_write_config(
    task_name: str, task_data_path: str, task_config_path: str
):
    dataset_name = {
        "acceptability_definiteness": "definiteness",
        "acceptability_coord": "coordinating-conjunctions",
        "acceptability_whwords": "whwords",
        "acceptability_eos": "eos",
    }[task_name]
    os.makedirs(task_data_path, exist_ok=True)
    # data contains all train/val/test examples
    # metadata contains the split indicators
    # (there are 10 CV folds, we use fold1 by default, see below)
    data_path = os.path.join(task_data_path, "data.json")
    metadata_path = os.path.join(task_data_path, "metadata.json")
    download_utils.download_file(
        url="https://raw.githubusercontent.com/decompositional-semantics-initiative/DNC/master/"
        f"function_words/ACCEPTABILITY/acceptability-{dataset_name}_data.json",
        file_path=data_path,
    )
    download_utils.download_file(
        url="https://raw.githubusercontent.com/decompositional-semantics-initiative/DNC/master/"
        f"function_words/ACCEPTABILITY/acceptability-{dataset_name}_metadata.json",
        file_path=metadata_path,
    )
    py_io.write_json(
        data={
            "task": task_name,
            "paths": {"data": data_path, "metadata": metadata_path},
            "name": task_name,
            "kwargs": {"fold": "fold1"},  # use fold1 (out of 10) by default
        },
        path=task_config_path,
    )

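# Note (added; assumes the metadata's 10 cross-validation folds are named "fold1".."fold10"):
# a different fold can be selected after generation by editing the written config, e.g.
#
#     config = py_io.read_json(task_config_path)
#     config["kwargs"]["fold"] = "fold2"   # hypothetical alternative fold name
#     py_io.write_json(data=config, path=task_config_path)
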
def download_piqa_data_and_write_config(task_name: str, task_data_path: str, task_config_path: str):
    os.makedirs(task_data_path, exist_ok=True)
    download_utils.download_file(
        "https://yonatanbisk.com/piqa/data/train.jsonl",
        os.path.join(task_data_path, "train.jsonl"),
    )
    download_utils.download_file(
        "https://yonatanbisk.com/piqa/data/train-labels.lst",
        os.path.join(task_data_path, "train-labels.lst"),
    )
    download_utils.download_file(
        "https://yonatanbisk.com/piqa/data/valid.jsonl",
        os.path.join(task_data_path, "valid.jsonl"),
    )
    download_utils.download_file(
        "https://yonatanbisk.com/piqa/data/valid-labels.lst",
        os.path.join(task_data_path, "valid-labels.lst"),
    )
    download_utils.download_file(
        "https://yonatanbisk.com/piqa/data/tests.jsonl",
        os.path.join(task_data_path, "tests.jsonl"),
    )
    py_io.write_json(
        data={
            "task": task_name,
            "paths": {
                "train": os.path.join(task_data_path, "train.jsonl"),
                "train_labels": os.path.join(task_data_path, "train-labels.lst"),
                "val": os.path.join(task_data_path, "valid.jsonl"),
                "val_labels": os.path.join(task_data_path, "valid-labels.lst"),
                "test": os.path.join(task_data_path, "tests.jsonl"),
            },
            "name": task_name,
        },
        path=task_config_path,
    )

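# Note (added; describes the upstream PIQA release as I understand it, not something this
# function enforces): the `*-labels.lst` files are expected to hold one label per line, aligned
# with the rows of the corresponding `.jsonl` file, and the test split ships without labels,
# which is why no `tests-labels.lst` is downloaded here.
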
def download_mrqa_natural_questions_data_and_write_config(
    task_name: str, task_data_path: str, task_config_path: str
):
    os.makedirs(task_data_path, exist_ok=True)
    download_utils.download_file(
        "https://s3.us-east-2.amazonaws.com/mrqa/release/v2/train/NaturalQuestionsShort.jsonl.gz",
        os.path.join(task_data_path, "train.jsonl.gz"),
    )
    download_utils.download_file(
        "https://s3.us-east-2.amazonaws.com/mrqa/release/v2/dev/NaturalQuestionsShort.jsonl.gz",
        os.path.join(task_data_path, "val.jsonl.gz"),
    )
    py_io.write_json(
        data={
            "task": task_name,
            "paths": {
                "train": os.path.join(task_data_path, "train.jsonl.gz"),
                "val": os.path.join(task_data_path, "val.jsonl.gz"),
            },
            "name": task_name,
        },
        path=task_config_path,
    )

def download_newsqa_data_and_write_config(
    task_name: str, task_data_path: str, task_config_path: str
):
    def get_consensus_answer(row_):
        answer_char_start, answer_char_end = None, None
        if row_.validated_answers:
            validated_answers_ = json.loads(row_.validated_answers)
            answer_, max_count = max(validated_answers_.items(), key=itemgetter(1))
            total_count = sum(validated_answers_.values())
            if max_count >= total_count / 2.0:
                if answer_ != "none" and answer_ != "bad_question":
                    answer_char_start, answer_char_end = map(int, answer_.split(":"))
                else:
                    # No valid answer.
                    pass
        else:
            # Check row_.answer_char_ranges for the most common answer.
            # No validation was done, so there must be an answer with consensus.
            answers = Counter()
            for user_answer in row_.answer_char_ranges.split("|"):
                for ans in user_answer.split(","):
                    answers[ans] += 1
            top_answer = answers.most_common(1)
            if top_answer:
                top_answer, _ = top_answer[0]
                if ":" in top_answer:
                    answer_char_start, answer_char_end = map(int, top_answer.split(":"))
        return answer_char_start, answer_char_end

    def load_combined(path):
        result = pd.read_csv(
            path,
            encoding="utf-8",
            dtype=dict(is_answer_absent=float),
            na_values=dict(question=[], story_text=[], validated_answers=[]),
            keep_default_na=False,
        )
        if "story_text" in result.keys():
            for row_ in display.tqdm(
                result.itertuples(), total=len(result), desc="Adjusting story texts"
            ):
                story_text_ = row_.story_text.replace("\r\n", "\n")
                result.at[row_.Index, "story_text"] = story_text_
        return result

    def _map_answers(answers):
        result = []
        for a in answers.split("|"):
            user_answers = []
            result.append(dict(sourcerAnswers=user_answers))
            for r in a.split(","):
                if r == "None":
                    user_answers.append(dict(noAnswer=True))
                else:
                    start_, end_ = map(int, r.split(":"))
                    user_answers.append(dict(s=start_, e=end_))
        return result

    def strip_empty_strings(strings):
        while strings and strings[-1] == "":
            del strings[-1]
        return strings

    # Require: cnn_stories.tgz
    cnn_stories_path = os.path.join(task_data_path, "cnn_stories.tgz")
    assert os.path.exists(cnn_stories_path), (
        "Download CNN Stories from https://cs.nyu.edu/~kcho/DMQA/ and save to " + cnn_stories_path
    )
    # Require: newsqa-data-v1/newsqa-data-v1.csv
    dataset_path = os.path.join(task_data_path, "newsqa-data-v1", "newsqa-data-v1.csv")
    if os.path.exists(dataset_path):
        pass
    elif os.path.exists(os.path.join(task_data_path, "newsqa-data-v1.zip")):
        download_utils.unzip_file(
            zip_path=os.path.join(task_data_path, "newsqa-data-v1.zip"),
            extract_location=task_data_path,
            delete=False,
        )
    else:
        raise AssertionError(
            "Download https://www.microsoft.com/en-us/research/project/newsqa-dataset/#!download"
            " and save to " + os.path.join(task_data_path, "newsqa-data-v1.zip")
        )
    # Download auxiliary data
    os.makedirs(task_data_path, exist_ok=True)
    file_name_list = [
        "train_story_ids.csv",
        "dev_story_ids.csv",
        "test_story_ids.csv",
        "stories_requiring_extra_newline.csv",
        "stories_requiring_two_extra_newlines.csv",
        "stories_to_decode_specially.csv",
    ]
    for file_name in file_name_list:
        download_utils.download_file(
            f"https://raw.githubusercontent.com/Maluuba/newsqa/master/maluuba/newsqa/{file_name}",
            os.path.join(task_data_path, file_name),
        )
    dataset = load_combined(dataset_path)
    remaining_story_ids = set(dataset["story_id"])
    with open(
        os.path.join(task_data_path, "stories_requiring_extra_newline.csv"), "r", encoding="utf-8"
    ) as f:
        stories_requiring_extra_newline = set(f.read().split("\n"))
    with open(
        os.path.join(task_data_path, "stories_requiring_two_extra_newlines.csv"),
        "r",
        encoding="utf-8",
    ) as f:
        stories_requiring_two_extra_newlines = set(f.read().split("\n"))
    with open(
        os.path.join(task_data_path, "stories_to_decode_specially.csv"), "r", encoding="utf-8"
    ) as f:
        stories_to_decode_specially = set(f.read().split("\n"))

    # Start combining data files
    story_id_to_text = {}
    with tarfile.open(cnn_stories_path, mode="r:gz", encoding="utf-8") as t:
        highlight_indicator = "@highlight"
        copyright_line_pattern = re.compile(
            "^(Copyright|Entire contents of this article copyright, )"
        )
        with display.tqdm(total=len(remaining_story_ids), desc="Getting story texts") as pbar:
            for member in t.getmembers():
                story_id = member.name
                if story_id in remaining_story_ids:
                    remaining_story_ids.remove(story_id)
                    story_file = t.extractfile(member)
                    # Correct discrepancies in stories.
                    # Problems are caused by using several programming languages and libraries.
                    # When ingesting the stories, we started with Python 2.
                    # After dealing with unicode issues, we tried switching to Python 3.
                    # That caused inconsistency problems so we switched back to Python 2.
                    # Furthermore, when crowdsourcing, JavaScript and HTML templating perturbed
                    # the stories.
                    # So here we map the text to be compatible with the indices.
                    lines = map(lambda s_: s_.strip().decode("utf-8"), story_file.readlines())
                    story_file.close()
                    lines = list(lines)
                    highlights_start = lines.index(highlight_indicator)
                    story_lines = lines[:highlights_start]
                    story_lines = strip_empty_strings(story_lines)
                    while len(story_lines) > 1 and copyright_line_pattern.search(story_lines[-1]):
                        story_lines = strip_empty_strings(story_lines[:-2])
                    if story_id in stories_requiring_two_extra_newlines:
                        story_text = "\n\n\n".join(story_lines)
                    elif story_id in stories_requiring_extra_newline:
                        story_text = "\n\n".join(story_lines)
                    else:
                        story_text = "\n".join(story_lines)
                    story_text = story_text.replace("\xe2\x80\xa2", "\xe2\u20ac\xa2")
                    story_text = story_text.replace("\xe2\x82\xac", "\xe2\u201a\xac")
                    story_text = story_text.replace("\r", "\n")
                    if story_id in stories_to_decode_specially:
                        story_text = story_text.replace("\xe9", "\xc3\xa9")
                    story_id_to_text[story_id] = story_text
                    pbar.update()
                    if len(remaining_story_ids) == 0:
                        break

    for row in display.tqdm(dataset.itertuples(), total=len(dataset), desc="Setting story texts"):
        # Set story_text since we cannot include it in the dataset.
        story_text = story_id_to_text[row.story_id]
        dataset.at[row.Index, "story_text"] = story_text
        # Handle endings that are too large.
        answer_char_ranges = row.answer_char_ranges.split("|")
        updated_answer_char_ranges = []
        ranges_updated = False
        for user_answer_char_ranges in answer_char_ranges:
            updated_user_answer_char_ranges = []
            for char_range in user_answer_char_ranges.split(","):
                if char_range != "None":
                    start, end = map(int, char_range.split(":"))
                    if end > len(story_text):
                        ranges_updated = True
                        end = len(story_text)
                    if start < end:
                        updated_user_answer_char_ranges.append("%d:%d" % (start, end))
                    else:
                        # It's unclear why, but sometimes the end is not after the start.
                        # We'll filter these out.
                        ranges_updated = True
                else:
                    updated_user_answer_char_ranges.append(char_range)
            if updated_user_answer_char_ranges:
                updated_user_answer_char_ranges = ",".join(updated_user_answer_char_ranges)
                updated_answer_char_ranges.append(updated_user_answer_char_ranges)
        if ranges_updated:
            updated_answer_char_ranges = "|".join(updated_answer_char_ranges)
            dataset.at[row.Index, "answer_char_ranges"] = updated_answer_char_ranges
        if row.validated_answers and not pd.isnull(row.validated_answers):
            updated_validated_answers = {}
            validated_answers = json.loads(row.validated_answers)
            for char_range, count in validated_answers.items():
                if ":" in char_range:
                    start, end = map(int, char_range.split(":"))
                    if end > len(story_text):
                        ranges_updated = True
                        end = len(story_text)
                    if start < end:
                        char_range = "{}:{}".format(start, end)
                        updated_validated_answers[char_range] = count
                    else:
                        # It's unclear why, but sometimes the end is not after the start.
                        # We'll filter these out.
                        ranges_updated = True
                else:
                    updated_validated_answers[char_range] = count
            if ranges_updated:
                updated_validated_answers = json.dumps(
                    updated_validated_answers, ensure_ascii=False, separators=(",", ":")
                )
                dataset.at[row.Index, "validated_answers"] = updated_validated_answers

    # Process Splits
    data = []
    cache = dict()
    train_story_ids = set(
        pd.read_csv(os.path.join(task_data_path, "train_story_ids.csv"))["story_id"].values
    )
    dev_story_ids = set(
        pd.read_csv(os.path.join(task_data_path, "dev_story_ids.csv"))["story_id"].values
    )
    test_story_ids = set(
        pd.read_csv(os.path.join(task_data_path, "test_story_ids.csv"))["story_id"].values
    )

    def _get_data_type(story_id_):
        if story_id_ in train_story_ids:
            return "train"
        elif story_id_ in dev_story_ids:
            return "dev"
        elif story_id_ in test_story_ids:
            return "test"
        else:
            raise ValueError("{} not found in any story ID set.".format(story_id_))

    for row in display.tqdm(dataset.itertuples(), total=len(dataset), desc="Building json"):
        questions = cache.get(row.story_id)
        if questions is None:
            questions = []
            datum = dict(
                storyId=row.story_id,
                type=_get_data_type(row.story_id),
                text=row.story_text,
                questions=questions,
            )
            cache[row.story_id] = questions
            data.append(datum)
        q = dict(
            q=row.question,
            answers=_map_answers(row.answer_char_ranges),
            isAnswerAbsent=row.is_answer_absent,
        )
        if row.is_question_bad != "?":
            q["isQuestionBad"] = float(row.is_question_bad)
        if row.validated_answers and not pd.isnull(row.validated_answers):
            validated_answers = json.loads(row.validated_answers)
            q["validatedAnswers"] = []
            for answer, count in validated_answers.items():
                answer_item = dict(count=count)
                if answer == "none":
                    answer_item["noAnswer"] = True
                elif answer == "bad_question":
                    answer_item["badQuestion"] = True
                else:
                    s, e = map(int, answer.split(":"))
                    answer_item["s"] = s
                    answer_item["e"] = e
                q["validatedAnswers"].append(answer_item)
        consensus_start, consensus_end = get_consensus_answer(row)
        if consensus_start is None and consensus_end is None:
            if q.get("isQuestionBad", 0) >= 0.5:
                q["consensus"] = dict(badQuestion=True)
            else:
                q["consensus"] = dict(noAnswer=True)
        else:
            q["consensus"] = dict(s=consensus_start, e=consensus_end)
        questions.append(q)

    phase_dict = {"train": [], "val": [], "test": []}
    phase_map = {"train": "train", "dev": "val", "test": "test"}
    for entry in data:
        phase = phase_map[entry["type"]]
        output_entry = {"text": entry["text"], "storyId": entry["storyId"], "qas": []}
        for qn in entry["questions"]:
            if "badQuestion" in qn["consensus"] or "noAnswer" in qn["consensus"]:
                continue
"answer": qn["consensus"]}) phase_dict[phase].append(output_entry) for phase, phase_data in phase_dict.items(): py_io.write_jsonl(phase_data, os.path.join(task_data_path, f"{phase}.jsonl")) py_io.write_json( data={ "task": task_name, "paths": { "train": os.path.join(task_data_path, "train.jsonl"), "val": os.path.join(task_data_path, "val.jsonl"), "test": os.path.join(task_data_path, "val.jsonl"), }, "name": task_name, }, path=task_config_path, ) for file_name in file_name_list: os.remove(os.path.join(task_data_path, file_name))
def download_tydiqa_data_and_write_config(task_data_base_path: str, task_config_base_path: str):
    tydiqa_temp_path = py_io.create_dir(task_data_base_path, "tydiqa_temp")
    full_train_path = os.path.join(tydiqa_temp_path, "tydiqa-goldp-v1.1-train.json")
    download_utils.download_file(
        "https://storage.googleapis.com/tydiqa/v1.1/tydiqa-goldp-v1.1-train.json",
        full_train_path,
    )
    download_utils.download_and_untar(
        "https://storage.googleapis.com/tydiqa/v1.1/tydiqa-goldp-v1.1-dev.tgz", tydiqa_temp_path,
    )
    languages_dict = {
        "arabic": "ar",
        "bengali": "bn",
        "english": "en",
        "finnish": "fi",
        "indonesian": "id",
        "korean": "ko",
        "russian": "ru",
        "swahili": "sw",
        "telugu": "te",
    }
    # Split train data by language
    data = py_io.read_json(full_train_path)
    lang2data = {lang: [] for lang in languages_dict.values()}
    for doc in data["data"]:
        for par in doc["paragraphs"]:
            context = par["context"]
            for qa in par["qas"]:
                question = qa["question"]
                question_id = qa["id"]
                example_lang = languages_dict[question_id.split("-")[0]]
                q_id = question_id.split("-")[-1]
                for answer in qa["answers"]:
                    a_start, a_text = answer["answer_start"], answer["text"]
                    a_end = a_start + len(a_text)
                    assert context[a_start:a_end] == a_text
                lang2data[example_lang].append(
                    {
                        "paragraphs": [
                            {
                                "context": context,
                                "qas": [
                                    {"answers": qa["answers"], "question": question, "id": q_id}
                                ],
                            }
                        ]
                    }
                )
    for full_lang, lang in languages_dict.items():
        task_name = f"tydiqa_{lang}"
        task_data_path = py_io.create_dir(task_data_base_path, task_name)
        train_path = os.path.join(task_data_path, f"tydiqa.{lang}.train.json")
        # Write only this language's share of the train split, not the full multilingual file.
        py_io.write_json(
            data={"data": lang2data[lang]},
            path=train_path,
            skip_if_exists=True,
        )
        val_path = os.path.join(task_data_path, f"tydiqa.{lang}.dev.json")
        os.rename(
            src=os.path.join(
                tydiqa_temp_path, "tydiqa-goldp-v1.1-dev", f"tydiqa-goldp-dev-{full_lang}.json"
            ),
            dst=val_path,
        )
        py_io.write_json(
            data={
                "task": "tydiqa",
                "paths": {"train": train_path, "val": val_path},
                "kwargs": {"language": lang},
                "name": task_name,
            },
            path=os.path.join(task_config_base_path, f"{task_name}_config.json"),
            skip_if_exists=True,
        )
    shutil.rmtree(tydiqa_temp_path)

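# For reference (added): for each language code the loop above writes a config of this shape,
# where the ".../" prefixes stand in for task_data_base_path and are illustrative only:
#
#     {
#         "task": "tydiqa",
#         "paths": {"train": ".../tydiqa_ar/tydiqa.ar.train.json",
#                   "val": ".../tydiqa_ar/tydiqa.ar.dev.json"},
#         "kwargs": {"language": "ar"},
#         "name": "tydiqa_ar",
#     }
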
def download_udpos_data_and_write_config(task_data_base_path: str, task_config_base_path: str):
    # UDPOS requires networkx==1.11

    def _read_one_file(file):
        # Adapted from https://github.com/JunjieHu/xtreme/blob/
        # 9fe0b142d0ee3eb7dd047ab86f12a76702e79bb4/utils_preprocess.py
        data = []
        sent, tag, lines = [], [], []
        for line in open(file, "r"):
            items = line.strip().split("\t")
            if len(items) != 10:
                num_empty = sum([int(w == "_") for w in sent])
                if num_empty == 0 or num_empty < len(sent) - 1:
                    data.append((sent, tag, lines))
                sent, tag, lines = [], [], []
            else:
                sent.append(items[1].strip())
                tag.append(items[3].strip())
                lines.append(line.strip())
                assert len(sent) == int(items[0]), "line={}, sent={}, tag={}".format(
                    line, sent, tag
                )
        return data

    def _remove_empty_space(data):
        # Adapted from https://github.com/google-research/xtreme/blob/
        # 522434d1aece34131d997a97ce7e9242a51a688a/utils_preprocess.py#L212
        new_data = {}
        for split in data:
            new_data[split] = []
            for sent, tag, lines in data[split]:
                new_sent = ["".join(w.replace("\u200c", "").split(" ")) for w in sent]
                lines = [line.replace("\u200c", "") for line in lines]
                assert len(" ".join(new_sent).split(" ")) == len(tag)
                new_data[split].append((new_sent, tag, lines))
        return new_data

    def check_file(file):
        # Adapted from https://github.com/google-research/xtreme/blob/
        # 522434d1aece34131d997a97ce7e9242a51a688a/utils_preprocess.py#L223
        for i, l in enumerate(open(file)):
            items = l.strip().split("\t")
            assert len(items[0].split(" ")) == len(items[1].split(" ")), "idx={}, line={}".format(
                i, l
            )

    def _write_files(data, output_dir, lang_, suffix):
        # Adapted from https://github.com/google-research/xtreme/blob/
        # 522434d1aece34131d997a97ce7e9242a51a688a/utils_preprocess.py#L228
        for split in data:
            if len(data[split]) > 0:
                prefix = os.path.join(output_dir, f"{split}-{lang_}")
                if suffix == "mt":
                    path = prefix + ".mt.tsv"
                    if os.path.exists(path):
                        logger.info("Skip writing to %s since it already exists.", path)
                    else:
                        with py_io.get_lock(path):
                            with open(path, "w") as fout:
                                for idx, (sent, tag, _) in enumerate(data[split]):
                                    newline = "\n" if idx != len(data[split]) - 1 else ""
                                    fout.write(
                                        "{}\t{}{}".format(" ".join(sent), " ".join(tag), newline)
                                    )
                        check_file(prefix + ".mt.tsv")
                        logger.info(" - finish checking " + prefix + ".mt.tsv")
                elif suffix == "tsv":
                    path = prefix + ".tsv"
                    if os.path.exists(path):
                        logger.info("Skip writing to %s since it already exists.", path)
                    else:
                        with py_io.get_lock(path):
                            with open(path, "w") as fout:
                                for sidx, (sent, tag, _) in enumerate(data[split]):
                                    for widx, (w, t) in enumerate(zip(sent, tag)):
                                        newline = (
                                            ""
                                            if (sidx == len(data[split]) - 1)
                                            and (widx == len(sent) - 1)
                                            else "\n"
                                        )
                                        fout.write("{}\t{}{}".format(w, t, newline))
                                    fout.write("\n")
                elif suffix == "conll":
                    path = prefix + ".conll"
                    if os.path.exists(path):
                        logger.info("Skip writing to %s since it already exists.", path)
                    else:
                        with open(path, "w") as fout:
                            for _, _, lines in data[split]:
                                for line in lines:
                                    fout.write(line.strip() + "\n")
                                fout.write("\n")
                logger.info(f"finish writing file to {prefix}.{suffix}")

    languages = (
        "af ar bg de el en es et eu fa fi fr he hi hu id it ja "
        "kk ko mr nl pt ru ta te th tl tr ur vi yo zh"
    ).split()
    udpos_temp_path = py_io.create_dir(task_data_base_path, "udpos_temp")
    download_utils.download_and_untar(
        "https://lindat.mff.cuni.cz/repository/xmlui/bitstream/handle/11234/1-3105/"
        "ud-treebanks-v2.5.tgz",
        udpos_temp_path,
    )
    download_utils.download_file(
        "https://raw.githubusercontent.com/google-research/xtreme/master/third_party/"
"ud-conversion-tools/lib/conll.py", os.path.join(udpos_temp_path, "conll.py"), ) conll = filesystem.import_from_path( os.path.join(udpos_temp_path, "conll.py")) conllu_path_ls = sorted( glob.glob(os.path.join(udpos_temp_path, "*", "*", "*.conllu"))) conll_path = os.path.join(udpos_temp_path, "conll") # === Convert conllu files to conll === # for input_path in display.tqdm( conllu_path_ls, desc="Convert conllu files to conll format"): input_path_fol, input_path_file = os.path.split(input_path) lang = input_path_file.split("_")[0] os.makedirs(os.path.join(conll_path, lang), exist_ok=True) output_path = os.path.join( conll_path, lang, strings.replace_suffix(input_path_file, "conllu", "conll")) pos_rank_precedence_dict = { "default": ("VERB NOUN PROPN PRON ADJ NUM ADV INTJ AUX ADP DET PART CCONJ SCONJ X PUNCT" ).split(" "), "es": "VERB AUX PRON ADP DET".split(" "), "fr": "VERB AUX PRON NOUN ADJ ADV ADP DET PART SCONJ CONJ".split(" "), "it": "VERB AUX ADV PRON ADP DET INTJ".split(" "), } if lang in pos_rank_precedence_dict: current_pos_precedence_list = pos_rank_precedence_dict[lang] else: current_pos_precedence_list = pos_rank_precedence_dict["default"] cio = conll.CoNLLReader() orig_treebank = cio.read_conll_u(input_path) modif_treebank = copy.copy(orig_treebank) for s in modif_treebank: s.filter_sentence_content( replace_subtokens_with_fused_forms=True, posPreferenceDict=current_pos_precedence_list, node_properties_to_remove=False, remove_deprel_suffixes=False, remove_arabic_diacritics=False, ) cio.write_conll( list_of_graphs=modif_treebank, conll_path=Path(output_path), conllformat="conll2006", print_fused_forms=True, print_comments=False, ) # === Convert conll to final format === # for lang in display.tqdm(languages, desc="Convert conll to final format"): task_name = f"udpos_{lang}" task_data_path = os.path.join(task_data_base_path, task_name) os.makedirs(task_data_path, exist_ok=True) all_examples = {k: [] for k in ["train", "val", "test"]} for path in sorted(glob.glob(os.path.join(conll_path, lang, "*.conll"))): examples = _read_one_file(path) if "train" in path: all_examples["train"] += examples elif "dev" in path: all_examples["val"] += examples elif "test" in path: all_examples["test"] += examples else: raise KeyError() all_examples = _remove_empty_space(all_examples) _write_files( data=all_examples, output_dir=task_data_path, lang_=lang, suffix="tsv", ) paths_dict = { phase: os.path.join(task_data_path, f"{phase}-{lang}.tsv") for phase, phase_data in all_examples.items() if len(phase_data) > 0 } py_io.write_json( data={ "task": "udpos", "paths": paths_dict, "name": task_name, "kwargs": { "language": lang }, }, path=os.path.join(task_config_base_path, f"{task_name}_config.json"), skip_if_exists=True, ) shutil.rmtree(udpos_temp_path)