Example #1
def split_data(dataset_path, output_dir_path='split_data'):
    original = NewsQaDataset.load_combined(dataset_path)

    logging.info("Loading story ID's split.")
    train_story_ids = set(
        pd.read_csv(os.path.join(_dir_name,
                                 'train_story_ids.csv'))['story_id'].values)
    dev_story_ids = set(
        pd.read_csv(os.path.join(_dir_name,
                                 'dev_story_ids.csv'))['story_id'].values)
    test_story_ids = set(
        pd.read_csv(os.path.join(_dir_name,
                                 'test_story_ids.csv'))['story_id'].values)

    train_data = []
    dev_data = []
    test_data = []

    for row in tqdm(original.itertuples(),
                    total=len(original),
                    mininterval=2,
                    unit_scale=True,
                    unit=" questions",
                    desc="Splitting data"):
        story_id = row.story_id

        # Filter out questions where no answer was picked, since these weren't used in the original paper.
        # FIXME Soon, if data was tokenized first, then it won't have answer_char_ranges, so we should check something else.
        # See the FIXME in the tokenizer for what field to check.
        answer_char_ranges = row.answer_char_ranges.split('|')
        none_count = answer_char_ranges.count('None')
        if none_count == len(answer_char_ranges):
            continue
        if story_id in train_story_ids:
            train_data.append(row)
        elif story_id in dev_story_ids:
            dev_data.append(row)
        elif story_id in test_story_ids:
            test_data.append(row)
        else:
            logging.warning("%s is not in train, dev, or test", story_id)

    if not os.path.exists(output_dir_path):
        os.makedirs(output_dir_path)

    def _write_to_csv(data, path):
        logging.info("Writing %d rows to %s", len(data), path)
        pd.DataFrame(data=data).to_csv(path,
                                       columns=original.columns.values,
                                       index=False,
                                       encoding='utf-8')

    assert len(train_data) == 92549, "Incorrect amount of training data."
    assert len(dev_data) == 5166, "Incorrect amount of validation data."
    assert len(test_data) == 5126, "Incorrect amount of test data."

    logging.info("Writing split data to %s", output_dir_path)
    _write_to_csv(train_data, os.path.join(output_dir_path, 'train.csv'))
    _write_to_csv(dev_data, os.path.join(output_dir_path, 'dev.csv'))
    _write_to_csv(test_data, os.path.join(output_dir_path, 'test.csv'))
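A minimal usage sketch for the function above. The import path is an assumption (adjust it to whichever module defines split_data); Example #8 below calls the same function on the tokenized CSV.

# Hypothetical import path -- adjust to the module that defines split_data.
from maluuba.newsqa.split_data import split_data

# Writes train.csv, dev.csv, and test.csv under 'split_data/'.
split_data('newsqa-data-tokenized-v1.csv', output_dir_path='split_data')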
Example #2
    def setUpClass(cls):
        dir_name = os.path.dirname(os.path.abspath(__file__))
        tokenized_data_path = os.path.join(dir_name, '..',
                                           'newsqa-data-tokenized-v1.csv')
        if not os.path.exists(tokenized_data_path):
            combined_data_path = os.path.join(
                dir_name, '../../../combined-newsqa-data-v1.csv')
            combined_data_path = os.path.abspath(combined_data_path)

            if not os.path.exists(combined_data_path):
                NewsQaDataset().dump(path=combined_data_path)

            tokenize(combined_data_path=combined_data_path,
                     output_path=tokenized_data_path)

        cls.dataset = NewsQaDataset.load_combined(tokenized_data_path)
Example #3
def tokenize(cnn_stories='cnn_stories.tgz',
             csv_dataset='newsqa-data-v1.csv',
             combined_data_path='combined-newsqa-data-v1.csv',
             output_path='newsqa-data-tokenized-v1.csv'):
    newsqa_data = NewsQaDataset(cnn_stories,
                                csv_dataset,
                                combined_data_path=combined_data_path)
    dataset = newsqa_data.dataset

    dir_name = os.path.dirname(os.path.abspath(__file__))
    requirements = (dir_name, os.path.join(dir_name, 'stanford-postagger.jar'),
                    os.path.join(dir_name, 'slf4j-api.jar'))
    for req in requirements:
        if not os.path.exists(req):
            raise Exception(
                "Missing `%s`\n"
                "Please refer to the README in the root of the project "
                "regarding the JARs required." % req)

    packed_filename = os.path.join(dir_name, csv_dataset + '.pck')
    unpacked_filename = os.path.join(dir_name, csv_dataset + '.tpck')

    logging.info("(1/3) - Packing data to `%s`.", packed_filename)
    with io.open(packed_filename, mode='w', encoding='utf-8') as writer:
        pack(dataset, writer)
    logging.info("(2/3) - Tokenizing packed file to `%s`.", unpacked_filename)
    classpath = os.pathsep.join(requirements)

    cmd = 'javac -classpath %s %s' % (
        classpath, os.path.join(dir_name, 'TokenizerSplitter.java'))
    logging.info("Running `%s`", cmd)
    exit_status = os.system(cmd)
    if exit_status:
        sys.exit(exit_status)

    cmd = 'java -classpath %s TokenizerSplitter %s > %s' % (
        classpath, packed_filename, unpacked_filename)
    logging.info("Running `%s`\nMaluuba: The warnings below are normal.", cmd)
    exit_status = os.system(cmd)
    if exit_status:
        sys.exit(exit_status)

    os.remove(packed_filename)

    logging.info("(3/3) - Unpacking tokenized file to `%s`", output_path)
    with io.open(unpacked_filename, mode='r', encoding='utf-8') as packed:
        unpack(dataset, packed, output_path)

    os.remove(unpacked_filename)
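A minimal call sketch for the function above; every argument shown is simply the function's own default, so a bare tokenize() is equivalent. The Stanford tokenizer JARs named in the requirements check must sit next to the module.

tokenize(cnn_stories='cnn_stories.tgz',
         csv_dataset='newsqa-data-v1.csv',
         combined_data_path='combined-newsqa-data-v1.csv',
         output_path='newsqa-data-tokenized-v1.csv')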
Example #4
    def test_load_combined(self):
        dir_name = os.path.dirname(os.path.abspath(__file__))
        combined_data_path = os.path.join(dir_name, '..',
                                          'combined-newsqa-data-v1.csv')

        dataset = NewsQaDataset.load_combined(combined_data_path)

        row = dataset.iloc[0]
        self.assertEqual(
            './cnn/stories/42d01e187213e86f5fe617fe32e716ff7fa3afc4.story',
            row.story_id)
        self.assertEqual("What was the amount of children murdered?",
                         row.question)
        self.assertEqual('294:297|None|None', row['answer_char_ranges'])
        self.assertEqual(0.0, row['is_answer_absent'])
        self.assertEqual('0.0', row['is_question_bad'])
        self.assertEqual('{"none": 1, "294:297": 2}', row['validated_answers'])
        self.assertEqual("NEW DELHI, India (CNN) -- A high court in nort",
                         row.story_text[:46])
        self.assertEqual({"19 "}, _get_answers(row))
Example #5
    def test_load_combined(self):
        dir_name = os.path.dirname(os.path.abspath(__file__))
        combined_data_path = os.path.join(
            dir_name, '../../../combined-newsqa-data-v1.csv')
        combined_data_path = os.path.abspath(combined_data_path)

        if not os.path.exists(combined_data_path):
            self.newsqa_dataset.dump(path=combined_data_path)

        dataset = NewsQaDataset.load_combined(combined_data_path)

        for original_row in tqdm(self.newsqa_dataset.dataset.itertuples(),
                                 desc="Comparing stories",
                                 total=len(self.newsqa_dataset.dataset),
                                 unit_scale=True,
                                 mininterval=2,
                                 unit=" rows"):
            expected = original_row.story_text
            actual = dataset.iloc[original_row.Index].story_text
            self.assertEqual(
                expected,
                actual,
                msg="Story texts at position %d are not equal."
                "\nExpected:\"%s\""
                "\n     Got:\"%s\"" %
                (original_row.Index, repr(expected), repr(actual)))

        row = dataset.iloc[0]
        self.assertEqual(
            './cnn/stories/42d01e187213e86f5fe617fe32e716ff7fa3afc4.story',
            row.story_id)
        self.assertEqual("What was the amount of children murdered?",
                         row.question)
        self.assertEqual('294:297|None|None', row['answer_char_ranges'])
        self.assertEqual(0.0, row['is_answer_absent'])
        self.assertEqual('0.0', row['is_question_bad'])
        self.assertEqual('{"none": 1, "294:297": 2}', row['validated_answers'])
        self.assertEqual("NEW DELHI, India (CNN) -- A high court in nort",
                         row.story_text[:46])
        self.assertEqual({"19 "}, _get_answers(row))
Example #6
dir_name = os.path.dirname(os.path.abspath(__file__))
default_dataset_path = os.path.join(dir_name,
                                    'combined-newsqa-data-v1.csv')
default_output_dir = os.path.join(dir_name, 'split-data-nil')

parser = argparse.ArgumentParser()
parser.add_argument('--dataset_path',
                    default=default_dataset_path,
                    help="The path to the dataset to split. Default: %s" %
                    default_dataset_path)
parser.add_argument(
    '--output_dir',
    default=default_output_dir,
    help="The folder in which to put the split-up data. Default: %s" %
    default_output_dir)
args = parser.parse_args()

original = NewsQaDataset.load_combined(args.dataset_path)

logging.info("Loading story ID's split.")
train_story_ids = set(
    pd.read_csv(os.path.join(dir_name,
                             'train_story_ids.csv'))['story_id'].values)
dev_story_ids = set(
    pd.read_csv(os.path.join(dir_name,
                             'dev_story_ids.csv'))['story_id'].values)
test_story_ids = set(
    pd.read_csv(os.path.join(dir_name,
                             'test_story_ids.csv'))['story_id'].values)

train_data = []
dev_data = []
test_data = []
Example #7
    def setUpClass(cls):
        cls.newsqa_dataset = NewsQaDataset()
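Note that unittest invokes setUpClass on the class itself, so it must be declared a @classmethod; the excerpts above omit the decorator. A minimal sketch of the full pattern (the class name is hypothetical):

import unittest

class TestNewsQaDataset(unittest.TestCase):
    @classmethod
    def setUpClass(cls):
        # Runs once for the whole class, so the expensive dataset
        # construction happens a single time.
        cls.newsqa_dataset = NewsQaDataset()

if __name__ == '__main__':
    unittest.main()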
Example #8
import argparse
import logging
import os

try:
    # Prefer a more specific path for when you run from the root of this repo
    # or if the root of the repo is in your path.
    from maluuba.newsqa.data_processing import NewsQaDataset
except ImportError:
    # In case you're running this file from this folder.
    from data_processing import NewsQaDataset
# The script also uses tokenize, split_data, and simplify, defined in
# sibling modules; those imports are not shown in this excerpt.

if __name__ == "__main__":
    dir_name = os.path.dirname(os.path.abspath(__file__))

    parser = argparse.ArgumentParser()
    parser.add_argument('--cnn_stories_path',
                        default=os.path.join(dir_name, 'cnn_stories.tgz'),
                        help="The path to the CNN stories (cnn_stories.tgz).")
    parser.add_argument(
        '--dataset_path',
        default=os.path.join(dir_name, 'newsqa-data-v1.csv'),
        help="The path to the dataset with questions and answers.")
    args = parser.parse_args()

    newsqa_data = NewsQaDataset(args.cnn_stories_path, args.dataset_path)

    logger = logging.getLogger('newsqa')
    logger.setLevel(logging.INFO)

    # Dump the dataset to common formats.
    newsqa_data.dump(path='combined-newsqa-data-v1.json')
    newsqa_data.dump(path='combined-newsqa-data-v1.csv')

    tokenized_data_path = os.path.join(dir_name,
                                       'newsqa-data-tokenized-v1.csv')
    tokenize(output_path=tokenized_data_path)
    split_data(dataset_path=tokenized_data_path)
    simplify(output_dir_path='split_data')
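Assuming the script above is saved as, say, tokenize_dataset.py (the filename is a guess), the whole pipeline runs from the command line with "python tokenize_dataset.py"; both --cnn_stories_path and --dataset_path fall back to files next to the script.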
Example #9
import argparse
import itertools
import os

try:
    # Prefer a more specific path for when you run from the root of this repo
    # or if the root of the repo is in your path.
    from maluuba.newsqa.data_processing import NewsQaDataset
except ImportError:
    # In case you're running this file from this folder.
    from data_processing import NewsQaDataset

if __name__ == "__main__":
    dir_name = os.path.dirname(os.path.abspath(__file__))

    parser = argparse.ArgumentParser()
    parser.add_argument('--cnn_stories_path', default=os.path.join(dir_name, 'cnn_stories.tgz'),
                        help="The path to the CNN stories (cnn_stories.tgz).")
    parser.add_argument('--dataset_path', default=os.path.join(dir_name, 'newsqa-data-v1.csv'),
                        help="The path to the dataset with questions and answers.")
    args = parser.parse_args()

    newsqa_data = NewsQaDataset(args.cnn_stories_path, args.dataset_path)

    # Dump the dataset to one file.
    newsqa_data.dump(path='combined-newsqa-data-v1.csv')

    print("Some answers:")
    for _, row in itertools.islice(newsqa_data.get_questions_and_answers().iterrows(), 10):
        print(row)
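A small variation on the loop above, printing only the question text. It assumes the DataFrame returned by get_questions_and_answers() has a 'question' column, as the combined CSV in Example #4 does:

    for _, row in itertools.islice(newsqa_data.get_questions_and_answers().iterrows(), 10):
        print(row['question'])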