def download_pawsx_data_and_write_config(task_data_base_path: str, task_config_base_path: str):
    """Download PAWS-X, split it into one task per language, and write a config for each.

    The archive is untarred into a temporary ``pawsx_temp`` directory under
    ``task_data_base_path``. Each language folder under ``x-final`` is then moved
    to ``<task_data_base_path>/pawsx_<lang>``, and the temporary directory is removed.

    Args:
        task_data_base_path: Directory that receives one ``pawsx_<lang>`` data
            directory per language found in the archive.
        task_config_base_path: Directory that receives one
            ``pawsx_<lang>_config.json`` file per language.
    """
    pawsx_temp_path = py_io.create_dir(task_data_base_path, "pawsx_temp")
    download_utils.download_and_untar(
        "https://storage.googleapis.com/paws/pawsx/x-final.tar.gz",
        pawsx_temp_path,
    )
    # Sorted so languages are processed (and configs written) in a stable order.
    languages = sorted(os.listdir(os.path.join(pawsx_temp_path, "x-final")))
    for lang in languages:
        task_name = f"pawsx_{lang}"
        task_dir = os.path.join(task_data_base_path, task_name)
        os.rename(
            src=os.path.join(pawsx_temp_path, "x-final", lang),
            dst=task_dir,
        )
        paths_dict = {
            "val": os.path.join(task_dir, "dev_2k.tsv"),
            "test": os.path.join(task_dir, "test_2k.tsv"),
        }
        # Only English is given a training split; other languages get val/test only.
        if lang == "en":
            paths_dict["train"] = os.path.join(task_dir, "train.tsv")
            # set_dict_keys returns the re-ordered mapping (and asserts the key
            # set matches exactly, so it must only run when "train" exists).
            # Keep the result so "train" comes first in the written config.
            paths_dict = datastructures.set_dict_keys(paths_dict, ["train", "val", "test"])
        py_io.write_json(
            data={
                "task": "pawsx",
                "paths": paths_dict,
                "name": task_name,
                "kwargs": {"language": lang},
            },
            path=os.path.join(task_config_base_path, f"{task_name}_config.json"),
        )
    shutil.rmtree(pawsx_temp_path)
def test_set_dict_keys():
    """set_dict_keys returns the mapping ordered like the requested keys.

    A key list that omits one of the dict's keys must be rejected with an
    AssertionError.
    """
    source = {"a": 1, "b": 2, "c": 3}
    for wanted_order in (("a", "b", "c"), ("a", "c", "b")):
        reordered = py_datastructures.set_dict_keys(source, list(wanted_order))
        assert list(reordered) == list(wanted_order)

    # "c" is left out of the requested keys, so the call has to fail.
    with pytest.raises(AssertionError):
        py_datastructures.set_dict_keys(source, ["a", "b"])