Code example #1
File: loader.py  Project: SebiSebi/DataMine
def ARCDataset(arc_type):
    """
    Loads an ARC dataset given the partition (see the ARCType enum).
    Any error during reading will generate an exception.

    Returns a Pandas DataFrame with 4 columns:
    * 'id': string
    * 'question': string
    * 'answers': list[string], 3 <= length <= 5
    * 'correct': oneof('A', 'B', 'C', 'D', 'E', '1', '2', '3', '4')
    """
    assert (isinstance(arc_type, ARCType))
    download_dataset(Collection.ALLEN_AI_ARC, check_shallow_integrity)
    all_data = []
    all_ids = set()
    with open(type_to_data_file(arc_type), "rt") as f:
        for line in f:
            entry = json.loads(line)
            assert (isinstance(entry, dict))
            assert (len(entry) == 3)

            # Extract fields.
            question_id = entry["id"]
            correct_answer = entry["answerKey"]
            entry = entry["question"]
            assert (len(entry) == 2)
            question = entry["stem"]
            for choice in entry["choices"]:
                assert (isinstance(choice["label"], string_types))
                assert (len(choice["label"]) == 1)
            answers = [
                choice["text"] for choice in sorted(entry["choices"],
                                                    key=lambda x: x["label"])
            ]

            # Validate fields.
            assert (isinstance(question_id, string_types))
            assert (isinstance(correct_answer, string_types))
            assert (correct_answer in
                    ["A", "B", "C", "D", "E", "1", "2", "3", "4"])
            assert (correct_answer in [x["label"] for x in entry["choices"]])
            assert (valid_choices([x["label"] for x in entry["choices"]]))
            assert (isinstance(question, string_types))
            assert (len(answers) in [3, 4, 5])
            for answer in answers:
                assert (isinstance(answer, string_types))

            assert (question_id not in all_ids)
            all_ids.add(question_id)
            all_data.append({
                "id": question_id,
                "question": question,
                "answers": answers,
                "correct": correct_answer
            })
    assert (len(all_data) == len(all_ids))
    df = pd.DataFrame(all_data)
    return df
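
A minimal usage sketch (the ARCType member below is hypothetical; the actual partition names live in the ARCType enum, which is not shown on this page):

df = ARCDataset(ARCType.EASY_TRAIN)  # hypothetical partition name
print(df.columns.tolist())  # expected: ['id', 'question', 'answers', 'correct']
print(df.iloc[0]["question"])
print(df.iloc[0]["answers"])  # 3 to 5 candidate answers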
Code example #2
def TriviaQADataset(trivia_qa_type):
    """
    Loads a TriviaQA dataset given the split (see the TriviaQAType enum).
    Any error during reading will generate an exception.
    """
    assert (isinstance(trivia_qa_type, TriviaQAType))
    download_dataset(Collection.TRIVIA_QA, check_shallow_integrity)
    raise NotImplementedError("Implement TriviaQA")
Code example #3
    def test_exception_raised_if_url_not_reachable(self, mock_config):
        # We also check that the dataset directory is not re-created
        # if it already exists.
        os.makedirs(
                os.path.join(datamine_cache_dir(), self.FAKE_DATASET.name),
                mode=0o755
        )
        mock_config.return_value = self.FAKE_CONFIG
        with self.assertRaises(Exception):
            download_dataset(Collection.RACE, lambda _: False)
Code example #4
    def test_dataset_is_downloaded_if_missing(self, mock_config):
        mock_config.return_value = self.FAKE_CONFIG
        responses.add(responses.GET, "http://fake-website.com/my/files.zip",
                      body=self.FAKE_URL_DATA1, status=200,
                      headers={'content-length': str(len(self.FAKE_URL_DATA1))},  # noqa: E501
                      stream=True)
        responses.add(responses.GET, "http://fake-website.com/my2/file.json",
                      body=self.FAKE_URL_DATA2, status=200,
                      headers={'content-length': str(len(self.FAKE_URL_DATA2))},  # noqa: E501
                      stream=True)
        return_code = download_dataset(Collection.RACE, lambda _: False)
        self.assertEqual(return_code, 2)

        data_dir = os.path.join(datamine_cache_dir(), self.FAKE_DATASET.name)
        self.assertEqual(
                open(os.path.join(data_dir, "1.txt"), "rt").read(),
                "First question"
        )
        self.assertEqual(
                open(os.path.join(data_dir, "2.txt"), "rt").read(),
                "Second question"
        )
        self.assertEqual(
                open(os.path.join(data_dir, "dir/3.txt"), "rt").read(),
                "Third question"
        )
        self.assertEqual(
                open(os.path.join(data_dir, "file.json"), "rt").read(),
                "This is a JSON file."
        )
Code example #5
ファイル: loader.py プロジェクト: SebiSebi/DataMine
def CSQADataset(csqa_type):
    """
    TODO(sebisebi): add description
    """
    assert(isinstance(csqa_type, CSQAType))
    download_dataset(Collection.CSQA, check_shallow_integrity)
    all_ids = set()
    all_data = []
    with open(type_to_data_file(csqa_type), "rt") as f:
        for line in f:
            entry = json.loads(line)
            assert(len(entry) == (2 if csqa_type == CSQAType.TEST else 3))
            question_id = entry["id"]
            correct_answer = entry.get("answerKey", None)
            entry = entry["question"]
            assert(isinstance(question_id, string_types))
            if csqa_type != CSQAType.TEST:
                assert(correct_answer in ["A", "B", "C", "D", "E"])
            else:
                assert(correct_answer is None)
            assert(len(entry) == 3)

            question = entry["stem"]
            question_concept = entry["question_concept"]
            answers = [
                choice["text"]
                for choice in sorted(
                    entry["choices"],
                    key=lambda x: x["label"])
            ]
            assert(isinstance(question, string_types))
            assert(isinstance(question_concept, string_types))
            assert(len(answers) == 5)
            for answer in answers:
                assert(isinstance(answer, string_types))
            assert(question_id not in all_ids)
            all_ids.add(question_id)
            all_data.append({
                "id": question_id,
                "question": question,
                "answers": answers,
                "correct": correct_answer,
                "question_concept": question_concept
            })
    assert(len(all_ids) == len(all_data))
    df = pd.DataFrame(all_data)
    return df
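
A usage sketch (only CSQAType.TEST is confirmed by the code above; the TRAIN member name is assumed):

train_df = CSQADataset(CSQAType.TRAIN)  # assumed member name
assert train_df["correct"].isin(["A", "B", "C", "D", "E"]).all()
test_df = CSQADataset(CSQAType.TEST)
assert test_df["correct"].isnull().all()  # no labels in the test split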
Code example #6
    def test_dataset_not_downloaded_if_locally_available(self, mock_fn):
        def fake_integrity_check(dataset_id):
            self.assertEqual(dataset_id, Collection.RACE)
            return True

        return_code = download_dataset(Collection.RACE, fake_integrity_check)
        mock_fn.assert_not_called()
        self.assertEqual(return_code, 1)
Code example #7
def OBQAFacts():
    """
    Yields the 1326 core science facts from the OpenBook QA dataset.

    Examples:
        * wind causes erosion
        * wind is a renewable resource

    Returns a generator of facts (strings).
    """
    download_dataset(Collection.ALLEN_AI_OBQA, check_shallow_integrity)
    facts_file = os.path.join(OBQA_CACHE_DIR, "OpenBookQA-V1-Sep2018", "Data",
                              "Main", "openbook.txt")
    with open(facts_file, "rt") as f:
        for line in f:
            fact = line.strip(string.whitespace + "\"")
            if len(fact) > 0:
                yield fact
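
Since OBQAFacts yields facts lazily, a typical consumption pattern is to materialize or iterate it (the fact count comes from the docstring above):

facts = list(OBQAFacts())
assert len(facts) == 1326  # per the docstring above
print(facts[0])  # e.g. "wind causes erosion"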
Code example #8
def download():
    assert (len(sys.argv) >= 1)
    assert (sys.argv[0] == "data_mine download")

    if len(sys.argv) != 2:
        msg.error("Usage: python -m data_mine download <dataset_name>",
                  exits=1)  # noqa: E501
    dataset_name = sys.argv[1]
    if dataset_name not in set([x.name for x in Collection]):
        msg.error("Invalid dataset: {}".format(dataset_name))
        msg.info("Available datasets:")
        msg.info("\n".join(sorted([x.name for x in Collection])), exits=1)

    dataset_id = Collection.from_str(dataset_name)
    msg.info("Checking if {} is already downloaded ...".format(dataset_name))
    return_code = download_dataset(dataset_id, check_deep_integrity)
    if return_code == 1:
        msg.info("{} already available at: {}".format(
            dataset_name, os.path.join(datamine_cache_dir(), dataset_name)))
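
download() reads its arguments from sys.argv and, judging by the asserts at the top, expects argv[0] to have been rewritten to "data_mine download" by the CLI wrapper. A driving sketch under that assumption:

import sys

# Sketch only: mirror the argv shape that download() asserts on.
sys.argv = ["data_mine download", "RACE"]
download()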
Code example #9
def OBQADataset(obqa_type, with_retrieved_facts=False):
    """
    Loads an OpenBookQA dataset given the type (see the OBQAType enum).
    Any error during reading will generate an exception.

    Returns a Pandas DataFrame with 4 columns:
    * 'id': string
    * 'question': string
    * 'answers': list[string], length = 4
    * 'correct': oneof('A', 'B', 'C', 'D')

    If `with_retrieved_facts` is True then a new column is added with
    the name `retrieved_facts`. This column is an array with 4 dictionaries
    each describing the retrieved facts supporting the corresponding candidate
    answer. The retrieved facts dict for a given answer has the following keys:
    * 'context': string - a string with multiple facts from the OBQA "book".
    * 'token_based': list[string]: a list of facts extracted with Lucene.
    * 'vector_based': list[string]: a list of facts extracted with embeddings.
    The list of facts (both token and vector based) are sorted in decreasing
    order of the relevance score (the first fact in the list is the most
    relevant). We recommend the "context" field to be used. It is a carefully
    constructed string using the top 5 token facts, top 5 vector facts and
    then interleaving the remaining facts (until about 500 tokens made up
    context). Facts are concatenated using " . " as a separator.
    """
    assert (isinstance(obqa_type, OBQAType))
    download_dataset(Collection.ALLEN_AI_OBQA, check_shallow_integrity)
    all_data = []
    all_ids = set()
    retrieved_facts = None
    if with_retrieved_facts:
        retrieved_facts = json.load(
            open(os.path.join(OBQA_CACHE_DIR,
                              "extracted_facts.json")))  # noqa: E501
    with open(type_to_data_file(obqa_type), "rt") as f:
        for line in f:
            entry = json.loads(line)
            assert (len(entry) == 3)

            question_id = entry["id"]
            correct_answer = entry["answerKey"]
            assert (isinstance(question_id, string_types))
            assert (correct_answer in ["A", "B", "C", "D"])

            entry = entry["question"]
            assert (len(entry) == 2)
            question = entry["stem"]
            answers = [
                choice["text"] for choice in sorted(entry["choices"],
                                                    key=lambda x: x["label"])
            ]
            assert (isinstance(question, string_types))
            assert (len(answers) == 4)
            for answer in answers:
                assert (isinstance(answer, string_types))
            assert (question_id not in all_ids)
            all_ids.add(question_id)
            new_row = {
                "id": question_id,
                "question": question,
                "answers": answers,
                "correct": correct_answer
            }
            if with_retrieved_facts:
                queries = [question + " " + answer for answer in answers]
                facts = [retrieved_facts[query] for query in queries]
                assert (len(facts) == 4)
                for fact in facts:
                    assert (len(fact) == 3)
                    assert (isinstance(fact, dict))
                    assert ("context" in fact)
                    assert ("token_based" in fact)
                    assert ("vector_based" in fact)
                new_row["retrieved_facts"] = facts
            all_data.append(new_row)
    assert (len(all_data) == len(all_ids))
    df = pd.DataFrame(all_data)
    return df
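
A usage sketch (the OBQAType member name is assumed):

df = OBQADataset(OBQAType.TRAIN, with_retrieved_facts=True)  # assumed member
row = df.iloc[0]
assert len(row["retrieved_facts"]) == 4  # one dict per candidate answer
print(row["retrieved_facts"][0]["context"])  # the recommended field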
Code example #10
File: loader.py  Project: SebiSebi/DataMine
def HotpotQADataset(hotpot_qa_type):
    """
    Loads a HotpotQA dataset given the split (see the HotpotQAType enum).
    Any error during reading will generate an exception.

    Returns a Pandas DataFrame with 8 columns:
    * 'id': string;
    * 'question': string;
    * 'answer': string (this is `None` for the test split);
    * 'gold_paragraphs': list[string];
    * 'supporting_facts': list;
    * 'context': list;
    * 'question_type': string, oneof('comparison', 'bridge', 'None');
    * 'question_level': string, oneof('easy', 'medium', 'hard', 'None').

    Please read the dataset's additional information page for a detailed
    explanation on the semantics of the fields above.
    """
    assert(isinstance(hotpot_qa_type, HotpotQAType))
    download_dataset(Collection.HOTPOT_QA, check_shallow_integrity)
    data = json.load(open(type_to_data_file(hotpot_qa_type), "rt"))
    assert(isinstance(data, list))
    processed_data = []
    all_ids = set()
    for entry in data:
        assert(isinstance(entry, dict))
        if hotpot_qa_type != HotpotQAType.TEST_FULLWIKI:
            assert(len(entry) == 7)
        else:
            assert(len(entry) == 3)  # _id, question, context

        # Extract fields.
        question_id = entry["_id"]
        question = entry["question"]
        answer = entry.get("answer", None)
        supporting_facts = entry.get("supporting_facts", [])
        context = entry["context"]
        question_type = entry.get("type", None)
        question_level = entry.get("level", None)

        # Validate fields.
        assert(isinstance(question_id, string_types))
        assert(isinstance(question, string_types))
        assert(isinstance(supporting_facts, list))
        assert(isinstance(context, list))
        if hotpot_qa_type != HotpotQAType.TEST_FULLWIKI:
            assert(isinstance(answer, string_types))
            assert(len(supporting_facts) > 0)
            assert(question_type in ["comparison", "bridge"])
            assert(question_level in ["easy", "medium", "hard"])
        else:
            assert(answer is None)
            assert(len(supporting_facts) == 0)
            assert(question_type is None)
            assert(question_level is None)
        # Note: most of the questions have 10 contexts.
        assert(len(context) <= 10)

        # Get the list of supporting sentences by joining the supporting
        # facts with the context by title. There can be duplicate titles
        # in the supporting facts.
        titles = [title for title, _ in supporting_facts]
        titles = list(more_itertools.unique_everseen(titles))
        title2contents = {title: sentences for title, sentences in context}
        assert(len(title2contents) == len(context))
        if hotpot_qa_type == HotpotQAType.DEV_FULLWIKI:
            titles = filter(lambda title: title in title2contents, titles)
        gold_paragraphs = [' '.join(title2contents[title]) for title in titles]
        for paragraph in gold_paragraphs:
            assert(isinstance(paragraph, string_types))
        if hotpot_qa_type == HotpotQAType.TRAIN:
            assert(len(gold_paragraphs) == 2)
        elif hotpot_qa_type == HotpotQAType.DEV_DISTRACTOR:
            assert(len(gold_paragraphs) == 2)
        elif hotpot_qa_type == HotpotQAType.DEV_FULLWIKI:
            assert(len(gold_paragraphs) <= 2)
        else:
            assert(hotpot_qa_type == HotpotQAType.TEST_FULLWIKI)
            assert(len(gold_paragraphs) == 0)

        assert(question_id not in all_ids)
        all_ids.add(question_id)
        processed_data.append({
            "id": question_id,
            "question": question,
            "answer": answer,
            "gold_paragraphs": gold_paragraphs,
            "supporting_facts": supporting_facts,
            "context": context,
            "question_type": question_type,
            "question_level": question_level
        })
    assert(len(processed_data) == len(data))
    assert(len(data) == len(all_ids))
    df = pd.DataFrame(processed_data)
    return df
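
A usage sketch (both member names appear in the branches above):

dev_df = HotpotQADataset(HotpotQAType.DEV_DISTRACTOR)
assert (dev_df["gold_paragraphs"].map(len) == 2).all()
test_df = HotpotQADataset(HotpotQAType.TEST_FULLWIKI)
assert test_df["answer"].isnull().all()  # the test split has no answers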
Code example #11
def CosmosQADataset(cosmos_qa_type):
    """
    Loads a Cosmos QA dataset given the split (see the CosmosQAType enum).
    Any error during reading will generate an exception.

    Returns a Pandas DataFrame with 5 columns:
    * 'id': string
    * 'question': string
    * 'context': string
    * 'answers': list[string], length = 4
    * 'correct': oneof('A', 'B', 'C', 'D') or None for the test split
    """
    assert (isinstance(cosmos_qa_type, CosmosQAType))
    download_dataset(Collection.COSMOS_QA, check_shallow_integrity)

    def extract_answers(entry):
        for i in range(0, 4):
            key = "answer{}".format(i)
            answer = entry[key]
            assert (isinstance(answer, string_types))
            yield answer
            del entry[key]

    all_ids = set()
    all_data = []
    with open(type_to_data_file(cosmos_qa_type), "rt") as f:
        for line in f:
            entry = json.loads(line)
            assert (isinstance(entry, dict))
            if cosmos_qa_type != CosmosQAType.TEST:
                assert (len(entry) == 8)
            else:
                assert (len(entry) == 7)

            # Extract data.
            question_id = entry["id"]
            question = entry["question"]
            context = entry["context"]
            answers = list(extract_answers(entry))
            label = entry.get("label", None)
            if label is not None:
                label = chr(ord('A') + int(label))

            # Validate data.
            assert (isinstance(question_id, string_types))
            assert (isinstance(question, string_types))
            assert (isinstance(context, string_types))
            assert (isinstance(answers, list) and len(answers) == 4)
            if cosmos_qa_type == CosmosQAType.TEST:
                assert (label is None)
            else:
                assert (label in ["A", "B", "C", "D"])

            assert (question_id not in all_ids)
            all_ids.add(question_id)
            all_data.append({
                "id": question_id,
                "question": question,
                "context": context,
                "answers": answers,
                "correct": label,
            })
    assert (len(all_data) == len(all_ids))
    df = pd.DataFrame(all_data)
    return df
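
A usage sketch (only CosmosQAType.TEST is confirmed by the code above):

test_df = CosmosQADataset(CosmosQAType.TEST)
assert test_df["correct"].isnull().all()  # unlabeled split
assert (test_df["answers"].map(len) == 4).all()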
Code example #12
def RACEDataset(race_type):
    """
    Loads a RACE dataset given the type (see the RACEType enum).
    Any error during reading will generate an exception.

    Returns a Pandas DataFrame with 5 columns:
    * 'article': string
    * 'question': string
    * 'answers': list[string], length = 4
    * 'correct': oneof('A', 'B', 'C', 'D')
    * 'id': string

    The returned IDs are unique and have this format: `index`-`passage_id`.
    Examples: 1-middle1548.txt, 2-middle1548.txt, etc. The `passage_id` is
    frequently the name of the file. All the questions related to the same
    passage are grouped in the same file in the RACE dataset (convention).
    Because each RACE file contains multiple questions, the counter is
    necessary in order to guarantee that IDs are unique (the file name
    alone is not sufficient). We translate the `passage_id` into the
    `question_id` using the per-passage question counter.
    """
    assert (isinstance(race_type, RACEType))
    download_dataset(Collection.RACE, check_shallow_integrity)
    dirpath = type_to_data_directory(race_type)
    all_data = []
    q_ids = {}
    for path in os.listdir(dirpath):
        assert (os.path.isfile(os.path.join(dirpath, path)))
        with open(os.path.join(dirpath, path), 'rt') as f:
            entry = json.load(f)
            """
            Each passage is a JSON file. The JSON file contains these fields:

            1. article: A string, which is the passage.
            2. questions: A string list. Each string is a query. We have two
                          types of questions. First one is an interrogative
                          sentence. Another one has a placeholder, which is
                          represented by _.
            3. options: A list of the options list. Each options list contains
                        4 strings, which are the candidate option.
            4. answers: A list contains the golden label of each query.
            5. id: Each passage has an id in this dataset. Note: the ids are
                   not unique in the question set! Questions in the same file
                   have the same id (the name of the file). This id is more of
                   a passage id than a question id.
            """
            assert (len(entry) == 5)
            assert (set(entry.keys()) == {
                "article", "questions", "options", "answers", "id"
            })
            article = entry["article"]
            questions = entry["questions"]
            options = entry["options"]
            answers = entry["answers"]
            q_id = entry["id"]
            assert (isinstance(article, string_types))
            assert (isinstance(questions, list))
            assert (isinstance(options, list))
            assert (isinstance(answers, list))
            assert (isinstance(q_id, string_types))
            assert (len(questions) == len(options))
            assert (len(questions) == len(answers))
            for question, option, answer in zip(questions, options, answers):
                assert (isinstance(question, string_types))
                assert (isinstance(option, list) and len(option) == 4)
                assert (isinstance(answer, string_types))
                assert (answer in ["A", "B", "C", "D"])
                all_data.append({
                    'article': article,
                    'question': question,
                    'answers': option,
                    'correct': answer,
                    'id': next_question_id(q_ids, q_id)
                })
    df = pd.DataFrame(all_data)
    return df
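
The next_question_id helper is referenced above but not shown on this page. A minimal sketch, assuming a per-passage counter that yields the documented `index`-`passage_id` format:

def next_question_id(q_ids, q_id):
    # Hypothetical sketch: count questions per passage (1-based) and
    # prefix the passage id with the counter, e.g. "1-middle1548.txt".
    q_ids[q_id] = q_ids.get(q_id, 0) + 1
    return "{}-{}".format(q_ids[q_id], q_id)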
Code example #13
File: loader.py  Project: SebiSebi/DataMine
def DROPDataset(drop_type):
    """
    TODO(sebisebi): add description
    """
    assert (isinstance(drop_type, DROPType))
    download_dataset(Collection.ALLEN_AI_DROP, check_shallow_integrity)

    def parse_answer(answer):
        """
        Answer format (sanitized of other unwanted fields):

        "answer": {
            "number": "3",
            "date": {
                "day": "",
                "month": "",
                "year": ""
            },
            "spans": [],
        }

        Returns the type and the answer as a string.

        The type can be:
            a) "number" (be aware that this can be integer or real).
            b) "date"
            c) "spans"
            d) None, some answers are completely empty.
        """
        assert (len(answer) == 3)
        assert (set(answer.keys()) == set(["number", "date", "spans"]))

        def is_number():
            return len(answer["number"]) > 0

        def is_date():
            date = answer["date"]
            if len(date) == 0:
                return False
            assert (set(date.keys()) == set(["day", "month", "year"]))
            return len("" + date["day"] + date["month"] + date["year"]) > 0

        def is_span():
            return len(answer["spans"]) > 0

        if is_number():
            assert (not is_date())
            assert (not is_span())
            float(answer["number"])  # Raises ValueError if not numeric.
            return "number", str(answer["number"])

        if is_date():
            assert (not is_number())
            assert (not is_span())
            return "date", serialize_date(answer["date"])

        if is_span():
            assert (not is_number())
            assert (not is_date())
            return "spans", ", ".join(answer["spans"])

        return None, None

    # The Subject ID represents the context category. Examples include
    # history_4122, nfl_3073 or history_3259. It seems that all questions
    # target NFL or history subjects.
    all_query_ids = set()
    all_questions = []
    data = json.load(open(type_to_data_file(drop_type), "rt"))
    for subject_id in sorted(data.keys()):
        entry = data[subject_id]
        assert (len(entry) == 3)  # passage, qa_pairs and wiki_url
        passage = entry["passage"]
        assert (isinstance(passage, string_types))
        """
        {
            "question": "How many points were scored first?",
            "answer": {
                "number": "3",
                "date": {
                    "day": "",
                    "month": "",
                    "year": ""
                },
                "spans": [],

                "hit_id": "",  # Not useful, always empty when present.
                "worker_id": ""  # Not useful, always empty when present.
            },
            "query_id": "33f9f7bd-518b-45ae-86d5-c1475167d54f",
            "highlights": [],  # Always empty or missing.
            "question_type": [],  # Always empty or missing.
            "validated_answers": [],  # Always empty or missing.
            "expert_answers": []  # Always empty or missing.
        }
        """
        for qa_pair in entry["qa_pairs"]:
            unwanted_fields = [
                "highlights", "question_type", "validated_answers",
                "expert_answers", "workerid", "workerscore",
                "incorrect_options", "ai_answer"
            ]
            for unwanted_field in unwanted_fields:
                if unwanted_field in qa_pair:
                    del qa_pair[unwanted_field]
            assert (len(qa_pair) == 3)

            question = qa_pair["question"]
            answer = qa_pair["answer"]
            query_id = qa_pair["query_id"]
            assert (isinstance(question, string_types))
            assert (isinstance(answer, dict))
            assert (isinstance(query_id, string_types))

            # This answer has 2 correct answers. Manually remove the false one.
            if query_id == "daf712ed-3849-48a1-b9b5-f7d21b0c0ab7":
                answer["number"] = ""

            # This is a duplicate query.
            if query_id == "28553293-d719-441b-8f00-ce3dc6df5398":
                if query_id in all_query_ids:
                    continue

            # Sanitize the answer object.
            unwanted_fields = ["worker_id", "hit_id"]
            for unwanted_field in unwanted_fields:
                if unwanted_field in answer:
                    assert (len(answer[unwanted_field]) == 0)
                    del answer[unwanted_field]
            assert (len(answer) == 3)
            answer_type, parsed_answer = parse_answer(answer)
            if answer_type is None:
                continue
            assert (answer_type in ["number", "date", "spans"])
            assert (len(parsed_answer) >= 1)
            all_query_ids.add(str(query_id))
            all_questions.append({
                "query_id": query_id,
                "question": question,
                "passage": passage,
                "answer_type": answer_type,
                "parsed_answer": parsed_answer,
                "original_answer": answer
            })
    assert (len(all_questions) == len(all_query_ids))

    df = pd.DataFrame(all_questions)
    return df
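
The serialize_date helper used above is not shown on this page. A plausible sketch, assuming the non-empty day/month/year parts are simply joined with spaces:

def serialize_date(date):
    # Hypothetical sketch: keep only the populated parts of the date.
    parts = [date["day"], date["month"], date["year"]]
    return " ".join(part for part in parts if part)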