Example #1
import csv
import pickle

# COUNTRY_LIST_PATH, QuestionDatabase, and log are assumed to come from the
# surrounding qanta project.
def create_wikipedia_redirect_pickle(redirect_csv, output_pickle):
    countries = {}
    with open(COUNTRY_LIST_PATH) as f:
        for line in f:
            k, v = line.split('\t')
            countries[k] = v.strip()

    db = QuestionDatabase()
    pages = set(db.all_answers().values())

    with open(redirect_csv) as redirect_f:
        redirects = {}
        n_total = 0
        n_selected = 0
        for row in csv.reader(redirect_f, quotechar='"', escapechar='\\'):
            n_total += 1
            source = row[0]
            target = row[1]
            if (target not in pages or source in countries
                    or target.startswith('WikiProject')
                    or target.endswith("_topics")
                    or target.endswith("_(overview)")):
                continue
            redirects[source] = target
            n_selected += 1

        log.info(
            'Filtered {} raw wikipedia redirects to {} matching redirects'.
            format(n_total, n_selected))

    with open(output_pickle, 'wb') as output_f:
        pickle.dump(redirects, output_f)
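
For context, a minimal invocation sketch; both paths below are hypothetical, and the redirect CSV is assumed to hold (source, target) rows as parsed above.

# Hypothetical paths, shown only to illustrate the call signature.
create_wikipedia_redirect_pickle(
    'data/external/wikipedia_redirects.csv',
    'output/wikipedia_redirects.pickle',
)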
Example #2
from collections import defaultdict

# QuestionDatabase, conf, and log are assumed to come from the surrounding
# qanta project.
def process_file(filename):
    with open(filename, 'r') as f:
        questions = defaultdict(set)
        for line in f:
            tokens = line.split()
            offset = 1 if int(tokens[0]) == -1 else 0
            ident = tokens[1 + offset].replace("'", "").split('_')
            q = int(ident[0])
            s = int(ident[1])
            t = int(ident[2])
            guess = tokens[3 + offset]
            questions[(q, s, t)].add(guess)
        qdb = QuestionDatabase('data/questions.db')
        answers = qdb.all_answers()
        recall = 0
        warn = 0
        for ident, guesses in questions.items():
            if len(guesses) < conf['n_guesses']:
                log.info("WARNING LOW GUESSES")
                log.info(
                    'Question {0} is missing guesses, only has {1}'.format(
                        ident, len(guesses)))
                warn += 1
            correct = answers[ident[0]].replace(' ', '_') in guesses
            recall += correct
        log.info('Recall: {0} Total: {1}'.format(recall / len(questions),
                                                 len(questions)))
        log.info('Warned lines: {0}'.format(warn))
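
The input format is implicit in the indexing above; a self-contained sketch with a hypothetical line shows how one record is decoded (an optional leading -1 shifts every index right by one):

# Hypothetical prediction line, inferred from the parsing logic above.
line = "1 '100001_0_5 0.87 Albert_Einstein"
tokens = line.split()
offset = 1 if int(tokens[0]) == -1 else 0
ident = tokens[1 + offset].replace("'", "").split('_')
print(int(ident[0]), int(ident[1]), int(ident[2]), tokens[3 + offset])
# -> 100001 0 5 Albert_Einstein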
Example #3
from typing import Sequence

# load_predictions, load_meta, seq (from the functional package), log, and the
# Prediction/ScoredGuess/Line tuples are assumed project-level imports, as is
# QuestionDatabase.
def load_data(pred_file: str, meta_file: str, q_db: QuestionDatabase) -> Sequence:
    preds = load_predictions(pred_file)
    metas = load_meta(meta_file)
    answers = q_db.all_answers()

    def create_line(group):
        question = group[0]
        elements = group[1]
        st_groups = (
            seq(elements).group_by(lambda x: (x[0].sentence, x[0].token)).sorted()
        )
        st_lines = []
        for st, v in st_groups:
            scored_guesses = (
                seq(v)
                .map(lambda x: ScoredGuess(x[0].score, x[1].guess))
                .sorted(reverse=True)
                .list()
            )
            st_lines.append(
                Line(
                    question,
                    st[0],
                    st[1],
                    scored_guesses[0].score > 0,
                    scored_guesses[0].guess,
                    answers[question],
                    scored_guesses,
                )
            )
        return question, st_lines

    def fix_missing_label(pm):
        prediction = pm[0]
        meta = pm[1]
        if (
            prediction.question is None
            or prediction.token is None
            or prediction.sentence is None
        ):
            log.info(
                "WARNING: Prediction malformed, fixing with meta line: {0}".format(
                    prediction
                )
            )
            prediction = Prediction(
                prediction.score, meta.question, meta.sentence, meta.token
            )
        assert meta.question == prediction.question
        assert meta.sentence == prediction.sentence
        assert meta.token == prediction.token
        return prediction, meta

    return (
        preds.zip(metas)
        .map(fix_missing_label)
        .group_by(lambda x: x[0].question)
        .map(create_line)
    )
Example #4
from pyspark import SparkContext

# read_dfs, group_features, format_guess, FEATURE_NAMES, VW_FOLDS, and
# QuestionDatabase are assumed to come from the surrounding qanta project.
def create_output(path: str):
    df = read_dfs(path).cache()
    question_db = QuestionDatabase()
    answers = question_db.all_answers()
    for qnum in answers:
        answers[qnum] = format_guess(answers[qnum])

    sc = SparkContext.getOrCreate()  # type: SparkContext
    b_answers = sc.broadcast(answers)

    def generate_string(group):
        rows = group[1]
        result = ""
        feature_values = []
        meta = None
        qnum = None
        sentence = None
        token = None
        guess = None
        for name in FEATURE_NAMES:
            named_feature_list = list(
                filter(lambda r: r.feature_name == name, rows))
            if len(named_feature_list) != 1:
                raise ValueError(
                    'Expected exactly one row per feature name, found {}'.format(
                        len(named_feature_list)))
            named_feature = named_feature_list[0]
            if meta is None:
                qnum = named_feature.qnum
                sentence = named_feature.sentence
                token = named_feature.token
                guess = named_feature.guess
                meta = '{} {} {} {}'.format(qnum, named_feature.sentence,
                                            named_feature.token, guess)
            feature_values.append(named_feature.feature_value)
        vw_features = ' '.join(feature_values)
        # Guard on the actual feature string (the original asserted on an
        # unused empty `result` variable, which made the check a no-op).
        assert '@' not in vw_features, \
            '@ is a special character that is split on and not allowed in the feature line'
        if guess == b_answers.value[qnum]:
            vw_label = "1 '{}_{}_{} ".format(qnum, sentence, token)
        else:
            vw_label = "-1 '{}_{}_{} ".format(qnum, sentence, token)

        return vw_label + vw_features + '@' + meta

    for fold in VW_FOLDS:
        group_features(df.filter(df.fold == fold))\
            .map(generate_string)\
            .saveAsTextFile('output/vw_input/{0}.vw'.format(fold))
    sc.stop()
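
Putting the label, feature, and meta pieces together, each emitted Vowpal Wabbit line has the shape below (all values hypothetical):

# Hypothetical values illustrating one line of the generated VW input.
qnum, sentence, token, guess = 100001, 2, 14, 'Albert_Einstein'
vw_label = "1 '{}_{}_{} ".format(qnum, sentence, token)
vw_features = '0.3 1.7'
meta = '{} {} {} {}'.format(qnum, sentence, token, guess)
print(vw_label + vw_features + '@' + meta)
# -> 1 '100001_2_14 0.3 1.7@100001 2 14 Albert_Einstein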
Example #5
import pickle

# QuestionDatabase is assumed to come from the surrounding qanta project.
def generate_questions():
    with open('data/100_possible_questions.pickle', 'rb') as f:
        qs = pickle.load(f)

    with open('data/qb_questions.txt', 'w') as f:
        for q in qs:
            f.write(q.flatten_text())
            f.write('\n')

    db = QuestionDatabase()
    answers = db.all_answers().values()
    with open('data/answers.txt', 'w') as f:
        for a in answers:
            f.write(a.lower().replace(' ', '_'))
            f.write('\n')
Example #6
import os

# QuestionDatabase, safe_path, WIKI_PAGE_PATH, normalize_wikipedia_title,
# WikipediaPage, and write_page are assumed project-level imports.
def create_wikipedia_cache(dump_path):
    from qanta.spark import create_spark_session

    spark = create_spark_session()
    db = QuestionDatabase()
    answers = set(db.all_answers().values())
    b_answers = spark.sparkContext.broadcast(answers)
    # Paths used in Spark must be absolute, and the target directory must exist
    page_path = os.path.abspath(safe_path(WIKI_PAGE_PATH))

    def create_page(row):
        title = normalize_wikipedia_title(row.title)
        filter_answers = b_answers.value
        if title in filter_answers:
            page = WikipediaPage(title, row.text, None, None, row.id, row.url)
            write_page(page, page_path=page_path)

    spark.read.json(dump_path).rdd.foreach(create_page)
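
A minimal usage sketch; the dump path is hypothetical, and each JSON row is assumed to expose the title, text, id, and url fields read above.

# Hypothetical path to a parsed Wikipedia JSON dump.
create_wikipedia_cache('data/external/wikipedia/parsed-wiki')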
Example #7
    # Module-level imports assumed by this method: os, pickle, numpy as np,
    # namedtuple (collections), seq (functional), plus the qanta project's
    # helpers (constants c, compute_* metrics, plotting and report utilities).
    def create_report(self, directory: str):
        with open(os.path.join(directory, 'guesser_params.pickle'), 'rb') as f:
            params = pickle.load(f)
        dev_guesses = AbstractGuesser.load_guesses(directory,
                                                   folds=[c.GUESSER_DEV_FOLD])

        qdb = QuestionDatabase()
        questions = qdb.all_questions()

        # Compute recall and accuracy
        dev_recall = compute_fold_recall(dev_guesses, questions)
        dev_questions = {
            qnum: q
            for qnum, q in questions.items() if q.fold == c.GUESSER_DEV_FOLD
        }
        dev_recall_stats = compute_recall_at_positions(dev_recall)
        dev_summary_accuracy = compute_summary_accuracy(
            dev_questions, dev_recall_stats)
        dev_summary_recall = compute_summary_recall(dev_questions,
                                                    dev_recall_stats)

        accuracy_plot('/tmp/dev_accuracy.png', dev_summary_accuracy,
                      'Guesser Dev')
        recall_plot('/tmp/dev_recall.png', dev_questions, dev_summary_recall,
                    'Guesser Dev')

        # Compute metrics on the number of answerable questions for the
        # requested dataset
        all_answers = set(qdb.all_answers().values())
        all_questions = list(qdb.all_questions().values())
        answer_lookup = qdb.all_answers()
        dataset = self.qb_dataset()
        training_data = dataset.training_data()

        min_n_answers = set(training_data[1])

        train_questions = [
            q for q in all_questions if q.fold == c.GUESSER_TRAIN_FOLD
        ]
        train_answers = {q.page for q in train_questions}

        # NOTE: this re-binds dev_questions from the qnum -> question dict
        # above to a plain list of dev-fold questions.
        dev_questions = [
            q for q in all_questions if q.fold == c.GUESSER_DEV_FOLD
        ]
        dev_answers = {q.page for q in dev_questions}

        min_n_train_questions = [
            q for q in train_questions if q.page in min_n_answers
        ]

        all_common_train_dev = train_answers.intersection(dev_answers)
        min_common_train_dev = min_n_answers.intersection(dev_answers)

        all_train_answerable_questions = [
            q for q in train_questions if q.page in train_answers
        ]
        all_dev_answerable_questions = [
            q for q in dev_questions if q.page in train_answers
        ]

        min_train_answerable_questions = [
            q for q in train_questions if q.page in min_n_answers
        ]
        min_dev_answerable_questions = [
            q for q in dev_questions if q.page in min_n_answers
        ]

        # The next section of code generates the percent of questions correct by the number
        # of training examples.
        Row = namedtuple('Row', [
            'fold', 'guess', 'guesser', 'qnum', 'score', 'sentence', 'token',
            'correct', 'answerable_1', 'answerable_2', 'n_examples'
        ])

        train_example_count_lookup = seq(train_questions) \
            .group_by(lambda q: q.page) \
            .smap(lambda page, group: (page, len(group))) \
            .dict()

        def guess_to_row(*args):
            guess = args[1]
            qnum = args[3]
            answer = answer_lookup[qnum]

            return Row(
                *args,
                answer == guess,
                answer in train_answers,
                answer in min_n_answers,
                train_example_count_lookup.get(answer, 0),
            )

        dev_data = seq(dev_guesses) \
            .smap(guess_to_row) \
            .group_by(lambda r: (r.qnum, r.sentence)) \
            .smap(lambda key, group: seq(group).max_by(lambda q: q.sentence)) \
            .to_pandas(columns=Row._fields)
        dev_data['correct_int'] = dev_data['correct'].astype(int)
        dev_data['ones'] = 1
        dev_counts = dev_data\
            .groupby('n_examples')\
            .agg({'correct_int': np.mean, 'ones': np.sum})\
            .reset_index()
        correct_by_n_count_plot('/tmp/dev_correct_by_count.png', dev_counts,
                                'Guesser Dev')
        n_train_vs_fold_plot('/tmp/n_train_vs_dev.png', dev_counts,
                             'Guesser Dev')

        with open(os.path.join(directory, 'guesser_report.pickle'), 'wb') as f:
            pickle.dump(
                {
                    'dev_accuracy': dev_summary_accuracy,
                    'guesser_name': self.display_name(),
                    'guesser_params': params
                }, f)

        output = safe_path(os.path.join(directory, 'guesser_report.pdf'))
        report = ReportGenerator('guesser.md')
        report.create({
            'dev_recall_plot': '/tmp/dev_recall.png',
            'dev_accuracy_plot': '/tmp/dev_accuracy.png',
            'dev_accuracy': dev_summary_accuracy,
            'guesser_name': self.display_name(),
            'guesser_params': params,
            'n_answers_all_folds': len(all_answers),
            'n_total_train_questions': len(train_questions),
            'n_train_questions': len(min_n_train_questions),
            'n_dev_questions': len(dev_questions),
            'n_total_train_answers': len(train_answers),
            'n_train_answers': len(min_n_answers),
            'n_dev_answers': len(dev_answers),
            'all_n_common_train_dev': len(all_common_train_dev),
            'all_p_common_train_dev': len(all_common_train_dev) / max(1, len(dev_answers)),
            'min_n_common_train_dev': len(min_common_train_dev),
            'min_p_common_train_dev': len(min_common_train_dev) / max(1, len(dev_answers)),
            'all_n_answerable_train': len(all_train_answerable_questions),
            'all_p_answerable_train': len(all_train_answerable_questions) / len(train_questions),
            'all_n_answerable_dev': len(all_dev_answerable_questions),
            'all_p_answerable_dev': len(all_dev_answerable_questions) / len(dev_questions),
            'min_n_answerable_train': len(min_train_answerable_questions),
            'min_p_answerable_train': len(min_train_answerable_questions) / len(train_questions),
            'min_n_answerable_dev': len(min_dev_answerable_questions),
            'min_p_answerable_dev': len(min_dev_answerable_questions) / len(dev_questions),
            'dev_correct_by_count_plot': '/tmp/dev_correct_by_count.png',
            'n_train_vs_dev_plot': '/tmp/n_train_vs_dev.png',
        }, output)