Example #1
0
def split_data(dataset_path, output_dir_path='split_data'):
    """Split the combined NewsQA dataset into train/dev/test CSV files.

    Membership of each row is decided by its story ID's presence in
    ``train_story_ids.csv``, ``dev_story_ids.csv``, or
    ``test_story_ids.csv`` located in ``_dir_name``.

    :param dataset_path: Path to the combined NewsQA dataset CSV.
    :param output_dir_path: Directory in which to write ``train.csv``,
        ``dev.csv``, and ``test.csv``. Created if it does not exist.
    :raises ValueError: If a split does not contain the expected number
        of rows for the v1 dataset.
    """
    original = NewsQaDataset.load_combined(dataset_path)

    logging.info("Loading story ID's split.")

    def _load_story_ids(file_name):
        # Each CSV has a 'story_id' column listing the stories in that split.
        path = os.path.join(_dir_name, file_name)
        return set(pd.read_csv(path)['story_id'].values)

    train_story_ids = _load_story_ids('train_story_ids.csv')
    dev_story_ids = _load_story_ids('dev_story_ids.csv')
    test_story_ids = _load_story_ids('test_story_ids.csv')

    train_data = []
    dev_data = []
    test_data = []

    for row in tqdm(original.itertuples(),
                    total=len(original),
                    mininterval=2,
                    unit_scale=True,
                    unit=" questions",
                    desc="Splitting data"):
        story_id = row.story_id

        # Filter out when no answer was picked because these weren't used in the original paper.
        # FIXME Soon, if data was tokenized first, then it won't have answer_char_ranges, so we should check something else.
        # See the FIXME in the tokenizer for what field to check.
        answer_char_ranges = row.answer_char_ranges.split('|')
        if answer_char_ranges.count('None') == len(answer_char_ranges):
            continue
        if story_id in train_story_ids:
            train_data.append(row)
        elif story_id in dev_story_ids:
            dev_data.append(row)
        elif story_id in test_story_ids:
            test_data.append(row)
        else:
            logging.warning("%s is not in train, dev, nor test", story_id)

    # exist_ok avoids a race between the existence check and directory creation.
    os.makedirs(output_dir_path, exist_ok=True)

    def _write_to_csv(data, path):
        # Write rows using the original column order so output matches input schema.
        logging.info("Writing %d rows to %s", len(data), path)
        pd.DataFrame(data=data).to_csv(path,
                                       columns=original.columns.values,
                                       index=False,
                                       encoding='utf-8')

    # Explicit raises instead of `assert` so these sanity checks still run
    # when Python is invoked with -O (assert statements are stripped then).
    expected_counts = (
        (train_data, 92549, "training"),
        (dev_data, 5166, "validation"),
        (test_data, 5126, "test"),
    )
    for data, expected, split_name in expected_counts:
        if len(data) != expected:
            raise ValueError("Incorrect amount of %s data: expected %d, got %d."
                             % (split_name, expected, len(data)))

    logging.info("Writing split data to %s", output_dir_path)
    _write_to_csv(train_data, os.path.join(output_dir_path, 'train.csv'))
    _write_to_csv(dev_data, os.path.join(output_dir_path, 'dev.csv'))
    _write_to_csv(test_data, os.path.join(output_dir_path, 'test.csv'))
Example #2
0
    def setUpClass(cls):
        """Prepare the tokenized dataset once and load it for all tests."""
        base_dir = os.path.dirname(os.path.abspath(__file__))
        tokenized_path = os.path.join(base_dir, '..',
                                      'newsqa-data-tokenized-v1.csv')
        # Tokenize on the first run only; subsequent runs reuse the cached file.
        if not os.path.exists(tokenized_path):
            combined_path = os.path.join(base_dir, '..',
                                         'combined-newsqa-data-v1.csv')
            tokenize(combined_data_path=combined_path,
                     output_path=tokenized_path)

        cls.dataset = NewsQaDataset.load_combined(tokenized_path)
Example #3
0
    def test_load_combined(self):
        """The first row of the combined dataset has the expected field values."""
        base_dir = os.path.dirname(os.path.abspath(__file__))
        data_path = os.path.join(base_dir, '..',
                                 'combined-newsqa-data-v1.csv')

        dataset = NewsQaDataset.load_combined(data_path)

        # Spot-check every field of the very first question in the dataset.
        first = dataset.iloc[0]
        self.assertEqual(
            './cnn/stories/42d01e187213e86f5fe617fe32e716ff7fa3afc4.story',
            first.story_id)
        self.assertEqual("What was the amount of children murdered?",
                         first.question)
        self.assertEqual('294:297|None|None', first['answer_char_ranges'])
        self.assertEqual(0.0, first['is_answer_absent'])
        self.assertEqual('0.0', first['is_question_bad'])
        self.assertEqual('{"none": 1, "294:297": 2}', first['validated_answers'])
        self.assertEqual("NEW DELHI, India (CNN) -- A high court in nort",
                         first.story_text[:46])
        self.assertEqual({"19 "}, _get_answers(first))
Example #4
0
    def test_load_combined(self):
        """Round-trip check: dumped combined CSV loads back with identical stories."""
        base_dir = os.path.dirname(os.path.abspath(__file__))
        data_path = os.path.abspath(
            os.path.join(base_dir, '../../../combined-newsqa-data-v1.csv'))

        # Create the combined file on demand so the test is self-contained.
        if not os.path.exists(data_path):
            self.newsqa_dataset.dump(path=data_path)

        dataset = NewsQaDataset.load_combined(data_path)

        # Every story text must survive the dump/load round trip unchanged.
        source = self.newsqa_dataset.dataset
        for source_row in tqdm(source.itertuples(),
                               desc="Comparing stories",
                               total=len(source),
                               unit_scale=True,
                               mininterval=2,
                               unit=" rows"):
            expected = source_row.story_text
            actual = dataset.iloc[source_row.Index].story_text
            self.assertEqual(
                expected,
                actual,
                msg="Story texts at position %d are not equal."
                "\nExpected:\"%s\""
                "\n     Got:\"%s\"" %
                (source_row.Index, repr(expected), repr(actual)))

        # Spot-check every field of the very first question in the dataset.
        first = dataset.iloc[0]
        self.assertEqual(
            './cnn/stories/42d01e187213e86f5fe617fe32e716ff7fa3afc4.story',
            first.story_id)
        self.assertEqual("What was the amount of children murdered?",
                         first.question)
        self.assertEqual('294:297|None|None', first['answer_char_ranges'])
        self.assertEqual(0.0, first['is_answer_absent'])
        self.assertEqual('0.0', first['is_question_bad'])
        self.assertEqual('{"none": 1, "294:297": 2}', first['validated_answers'])
        self.assertEqual("NEW DELHI, India (CNN) -- A high court in nort",
                         first.story_text[:46])
        self.assertEqual({"19 "}, _get_answers(first))
                                    'combined-newsqa-data-v1.csv')
default_output_dir = os.path.join(dir_name, 'split-data-nil')

# Command-line interface: both paths default to locations relative to this script.
parser = argparse.ArgumentParser()
parser.add_argument('--dataset_path',
                    default=default_dataset_path,
                    help="The path to the dataset to split. Default: %s" %
                    default_dataset_path)
parser.add_argument(
    '--output_dir',
    default=default_output_dir,
    help="The path folder to put the split up data. Default: %s" %
    default_output_dir)
args = parser.parse_args()

# Load the combined dataset that will be partitioned into train/dev/test.
original = NewsQaDataset.load_combined(args.dataset_path)

logging.info("Loading story ID's split.")
# Each *_story_ids.csv has a 'story_id' column listing the stories in that split.
train_story_ids = set(
    pd.read_csv(os.path.join(dir_name,
                             'train_story_ids.csv'))['story_id'].values)
dev_story_ids = set(
    pd.read_csv(os.path.join(dir_name,
                             'dev_story_ids.csv'))['story_id'].values)
test_story_ids = set(
    pd.read_csv(os.path.join(dir_name,
                             'test_story_ids.csv'))['story_id'].values)

# Accumulators for the rows assigned to each split (populated further
# down in the script, beyond this excerpt).
train_data = []
dev_data = []
test_data = []