示例#1
0
    def f_eval(task: Task) -> List[Row]:
        """Produce one ``guessers`` feature Row per candidate guess in ``task.guess_df``.

        The question position (qnum, sentence, token) and the fold are read from the
        first row of the frame only (see the referenced comment for why that is safe).
        For every guess, each guesser in the broadcast score map contributes
        ``<guesser>_score:<score> <guesser>_found:1`` when it has a score for
        (qnum, sentence, token, guess), and ``<guesser>_found:-1`` otherwise. The
        feature value of each Row is ``'|guessers '`` followed by the formatted guess and
        those per-guesser features, space separated. An empty frame yields ``[]``.
        """
        score_map = b_guesser_score_map.value.scores()
        guess_df = task.guess_df
        if len(guess_df) == 0:
            return []

        # Refer to code in evaluate_feature_question for explanation why this is safe
        head = guess_df.iloc[0]
        qnum, sentence, token = int(head.qnum), int(head.sentence), int(head.token)
        fold = head.fold

        rows = []
        for guess in guess_df.guess:
            lookup_key = (qnum, sentence, token, guess)
            features = [format_guess(guess)]
            for guesser_name in score_map:
                guesser_scores = score_map[guesser_name]
                if lookup_key not in guesser_scores:
                    features.append('{}_found:-1'.format(guesser_name))
                    continue
                features.append('{guesser}_score:{score} {guesser}_found:1'.format(
                    guesser=guesser_name, score=guesser_scores[lookup_key]))
            rows.append(Row(fold, qnum, sentence, token, guess, 'guessers',
                            '|guessers ' + ' '.join(features)))
        return rows
示例#2
0
def preprocess_titles():
    """Write a filtered copy of the sorted titles list, one output line per input line.

    Each line of ``data/titles-sorted.txt`` is stripped and lower-cased. If its
    ``format_guess`` form is the (normalized) page of some question in the question
    database, the lower-cased title is written to
    ``data/processed-titles-sorted.txt``; otherwise a single ``@`` placeholder is
    written, so output line N always corresponds to input line N.
    """
    db = QuestionDatabase()
    pages = {format_guess(page) for page in db.questions_with_pages().keys()}
    # Both files are managed by ``with`` so they are closed even if the DB query or a
    # write fails (the input file used to be closed manually and leaked on error).
    with open('data/titles-sorted.txt') as titles_file, \
            open('data/processed-titles-sorted.txt', 'w') as f:
        for line in titles_file:
            title = line.strip().lower()
            f.write((title if format_guess(title) in pages else '@') + '\n')
示例#3
0
def question_recall(guesses, qst, question_lookup):
    """Locate the correct answer among ``guesses`` ranked by descending score.

    Args:
        guesses: iterable of guess rows exposing ``.score`` and ``.guess``.
        qst: ``(qnum, sentence, token)`` identifying the question position.
        question_lookup: mapping from qnum to a question exposing ``.page``.

    Returns:
        ``(qnum, sentence, token, rank)`` where ``rank`` is the 1-based position of the
        first guess equal to ``format_guess(page)``, or ``None`` if it is absent.
    """
    qnum, sentence, token = qst
    correct_page = format_guess(question_lookup[qnum].page)
    ranked = sorted(guesses, key=lambda row: row.score, reverse=True)
    rank = next(
        (position for position, row in enumerate(ranked, start=1)
         if correct_page == row.guess),
        None)
    return qnum, sentence, token, rank
示例#4
0
    def __getitem__(self, key: str):
        """Return the WikipediaPage for ``key``, loading and caching it as needed.

        Lookup order:
        1. the in-memory ``self.cache`` (keyed by ``format_guess(key)``);
        2. a pickle file under ``self.path`` whose name is the key with ``/`` replaced
           by ``---``;
        3. otherwise the page is built with ``self.load_page``. For keys in
           ``self.countries``, the pages named by each ``COUNTRY_SUB`` prefix plus the
           country's suffix are merged with the page itself. The result is pickled to
           the cache file.

        If nothing could be loaded, an empty ``WikipediaPage()`` is returned. It is
        written to disk only when ``self.write_dummy`` is set.
        """
        key = format_guess(key)
        if key in self.cache:
            return self.cache[key]

        # replace() is a no-op for keys without "/", so one expression covers both cases.
        filename = "%s/%s" % (self.path, key.replace("/", "---"))

        def write_cache_file(obj):
            # ``with`` guarantees the handle is flushed and closed right after the dump.
            with open(filename, 'wb') as out:
                pickle.dump(obj, out, protocol=pickle.HIGHEST_PROTOCOL)

        page = None
        if os.path.exists(filename):
            # NOTE: pickle.load can execute arbitrary code; this is only acceptable because
            # the cache directory is populated by this class. Never point it at untrusted data.
            try:
                with open(filename, 'rb') as f:
                    page = pickle.load(f)
            except (pickle.UnpicklingError, EOFError):
                # Corrupt or truncated (e.g. empty) cache file: fall through and rebuild it.
                page = None
            except AttributeError:
                log.info("Error loading %s" % key)
                page = None
            except ImportError:
                log.info("Error importing %s" % key)
                page = None

        if page is None:
            if key in self.countries:
                raw = [
                    self.load_page("%s%s" % (x, self.countries[key]))
                    for x in COUNTRY_SUB
                ]
                raw.append(self.load_page(key))
                log.info("%s is a country!" % key)
            else:
                raw = [self.load_page(key)]

            raw = [x for x in raw if x is not None]
            if raw:
                if len(raw) > 1:
                    log.info("%i pages for %s" % (len(raw), key))
                page = WikipediaPage(
                    "\n".join(x.content for x in raw),
                    seq(raw).map(lambda x: x.links).flatten().list(),
                    seq(raw).map(lambda x: x.categories).flatten().list())

                log.info("Writing file to %s" % filename)
                write_cache_file(page)
            else:
                log.info("Dummy page for %s" % key)
                page = WikipediaPage()
                if self.write_dummy:
                    write_cache_file(page)

        self.cache[key] = page
        return page
示例#5
0
 def train(self, training_data):
     """Build the ElasticSearch index with one document per normalized answer page.

     ``training_data[0]`` and ``training_data[1]`` are zipped together. Each element of
     the first is joined with spaces into a paragraph, and each element of the second is
     normalized with ``format_guess``. All paragraphs sharing a page are concatenated
     (space separated, in input order) into that page's document.
     """
     paragraphs_by_page = {}
     for sentences, ans in zip(training_data[0], training_data[1]):
         paragraphs_by_page.setdefault(format_guess(ans), []).append(' '.join(sentences))
     # Join once per page: ``documents[page] += ...`` on a dict value re-copies the growing
     # string on every append (quadratic for pages with many questions).
     documents = {page: ' '.join(paragraphs)
                  for page, paragraphs in paragraphs_by_page.items()}
     ElasticSearchIndex.build(documents)
示例#6
0
 def train(self, training_data: TrainingData) -> None:
     """Build the Whoosh wiki index with one document per normalized answer page.

     ``training_data[0]`` and ``training_data[1]`` are zipped together. Each element of
     the first is joined with spaces into a paragraph, and each element of the second is
     normalized with ``format_guess``. All paragraphs sharing a page are concatenated
     (space separated, in input order). The index is written to
     ``WHOOSH_WIKI_INDEX_PATH``.
     """
     paragraphs_by_page = {}
     for sentence, ans in zip(training_data[0], training_data[1]):
         paragraphs_by_page.setdefault(format_guess(ans), []).append(' '.join(sentence))
     # Join once per page: ``documents[page] += ...`` on a dict value re-copies the growing
     # string on every append (quadratic for pages with many questions).
     documents = {page: ' '.join(paragraphs)
                  for page, paragraphs in paragraphs_by_page.items()}
     WhooshWikiIndex.build(documents, index_path=WHOOSH_WIKI_INDEX_PATH)
示例#7
0
    def initialize_cache(path):
        """
        Access every question page through the cache at ``path``, forcing a prefetch of
        all wiki pages.

        Each page title from the question database is normalized with ``format_guess``.
        It is then passed, together with a ``CachedWikipedia`` for ``path``, to
        ``access_page`` using a multiprocessing ``Pool``.
        """
        db = QuestionDatabase(QB_QUESTION_DB)
        pages = db.questions_with_pages()
        cw = CachedWikipedia(path)
        input_data = [(format_guess(title), cw) for title in pages.keys()]
        # The context manager shuts the worker processes down once starmap (which blocks
        # until every task has finished) returns; the pool used to be left running.
        with Pool() as pool:
            pool.starmap(access_page, input_data)
示例#8
0
 def run(self):
     """Export every test-fold question to ``EXPO_QUESTIONS`` as CSV, one row per sentence.

     Columns are ``id,answer,sent,text``: the question number, the
     ``format_guess``-normalized answer page, the sentence index (0 through the largest
     key of ``q.text``), and that sentence's text. Questions with no text are skipped.
     """
     db = QuestionDatabase(QB_QUESTION_DB)
     questions = db.all_questions()
     with open(safe_path(EXPO_QUESTIONS), 'w', newline='') as f:
         writer = csv.writer(f, delimiter=',')
         # Write the header through the csv writer so it uses the same '\r\n' terminator
         # as the data rows instead of a lone '\n'.
         writer.writerow(['id', 'answer', 'sent', 'text'])
         for q in questions.values():
             # max() below would raise ValueError on a question without any sentences.
             if q.fold != 'test' or not q.text:
                 continue
             answer = format_guess(q.page)
             for i in range(max(q.text.keys()) + 1):
                 writer.writerow([q.qnum, answer, i, q.text[i]])
示例#9
0
def create_output(path: str):
    """Convert the feature rows read from ``path`` into Vowpal Wabbit input files.

    For each fold in ``VW_FOLDS``, that fold's rows are grouped with
    ``group_features``. Each group becomes one line:
    ``<label> '<qnum>_<sentence>_<token> <feature values>@<qnum> <sentence> <token> <guess>``.
    The label is ``1`` when the guess equals the question's ``format_guess``-normalized
    answer and ``-1`` otherwise. Output is saved to ``output/vw_input/<fold>.vw``, and
    the Spark context is stopped afterwards (also on failure).

    Raises:
        ValueError: (inside the Spark job) if a group does not hold exactly one row for
            some feature in ``FEATURE_NAMES``, or if a feature value contains ``@``.
    """
    df = read_dfs(path).cache()
    question_db = QuestionDatabase()
    answers = {qnum: format_guess(answer)
               for qnum, answer in question_db.all_answers().items()}

    sc = SparkContext.getOrCreate()  # type: SparkContext
    b_answers = sc.broadcast(answers)

    def generate_string(group):
        # ``group[1]`` holds the rows of one group; materialize them once because they are
        # scanned once per feature name below.
        rows = list(group[1])
        feature_values = []
        meta = None
        qnum = None
        sentence = None
        token = None
        guess = None
        for name in FEATURE_NAMES:
            named_feature_list = [r for r in rows if r.feature_name == name]
            if len(named_feature_list) != 1:
                raise ValueError(
                    'Expected exactly one row for feature {} but found {}'.format(
                        name, len(named_feature_list)))
            named_feature = named_feature_list[0]
            if meta is None:
                # Position and guess metadata come from the first feature's row.
                qnum = named_feature.qnum
                sentence = named_feature.sentence
                token = named_feature.token
                guess = named_feature.guess
                meta = '{} {} {} {}'.format(qnum, sentence, token, guess)
            feature_values.append(named_feature.feature_value)

        vw_features = ' '.join(feature_values)
        # '@' separates the VW line from the metadata suffix, so it must not occur in the
        # feature values themselves.
        if '@' in vw_features:
            raise ValueError(
                '@ is a special character that is split on and not allowed in the feature line')

        label = 1 if guess == b_answers.value[qnum] else -1
        vw_label = "{} '{}_{}_{} ".format(label, qnum, sentence, token)
        return vw_label + vw_features + '@' + meta

    try:
        for fold in VW_FOLDS:
            group_features(df.filter(df.fold == fold))\
                .map(generate_string)\
                .saveAsTextFile('output/vw_input/{0}.vw'.format(fold))
    finally:
        sc.stop()
示例#10
0
 def create_line(group):
     """Build the per-position ``Line`` records for one question.

     ``group`` is indexed as ``(question, elements)``. Each element is a pair whose first
     item carries ``sentence``/``token``/``score`` and whose second carries ``guess``.
     Elements are bucketed by (sentence, token), and the buckets are visited in sorted
     order. Within a bucket the ``ScoredGuess`` values are ranked in descending order,
     and the top one is reported along with whether its score is positive and the
     normalized answer ``format_guess(answers[question])``.

     Returns:
         ``(question, lines)`` with one ``Line`` per (sentence, token) bucket.
     """
     question = group[0]
     positions = seq(group[1]).group_by(
         lambda pair: (pair[0].sentence, pair[0].token)).sorted()
     lines = []
     for (sentence, token), pairs in positions:
         ranked = seq(pairs) \
             .map(lambda pair: ScoredGuess(pair[0].score, pair[1].guess)) \
             .sorted(reverse=True) \
             .list()
         top = ranked[0]
         lines.append(Line(question, sentence, token, top.score > 0, top.guess,
                           format_guess(answers[question]), ranked))
     return question, lines
示例#11
0
文件: lm_wrapper.py 项目: xxlatgh/qb
 def normalize_title(corpus, title):
     """Return ``title`` normalized with ``format_guess`` and prefixed by ``corpus``."""
     return corpus + format_guess(title)
示例#12
0
    def create_report(self, directory: str):
        """Generate the guesser report for the guesses saved in ``directory``.

        Steps:
        - Load ``guesser_params.pickle`` and the saved guesses from ``directory``.
        - Compute recall/accuracy summaries for the dev and test folds.
        - Compute answer-overlap and answerability statistics between the train folds
          and dev/test.
        - Compute accuracy grouped by the number of training questions for the answer.

        Plots are written to fixed paths under ``/tmp``. The rendered report goes to
        ``guesser_report.pdf`` and a summary (accuracies, guesser name, params) to
        ``guesser_report.pickle``, both in ``directory``.

        Args:
            directory: directory containing the guesser's saved params and guesses;
                report outputs are written back into it.
        """
        # pickle.load can execute arbitrary code: only use on directories this project wrote.
        with open(os.path.join(directory, 'guesser_params.pickle'), 'rb') as f:
            params = pickle.load(f)
        all_guesses = AbstractGuesser.load_guesses(directory)
        dev_guesses = all_guesses[all_guesses.fold == 'dev']
        test_guesses = all_guesses[all_guesses.fold == 'test']

        qdb = QuestionDatabase()
        questions = qdb.all_questions()

        # Compute recall and accuracy
        dev_recall = compute_fold_recall(dev_guesses, questions)
        test_recall = compute_fold_recall(test_guesses, questions)

        dev_question_lookup = {qnum: q for qnum, q in questions.items() if q.fold == 'dev'}
        test_question_lookup = {qnum: q for qnum, q in questions.items() if q.fold == 'test'}

        dev_recall_stats = compute_recall_at_positions(dev_recall)
        test_recall_stats = compute_recall_at_positions(test_recall)

        dev_summary_accuracy = compute_summary_accuracy(dev_question_lookup, dev_recall_stats)
        test_summary_accuracy = compute_summary_accuracy(test_question_lookup, test_recall_stats)

        dev_summary_recall = compute_summary_recall(dev_question_lookup, dev_recall_stats)
        test_summary_recall = compute_summary_recall(test_question_lookup, test_recall_stats)

        accuracy_plot('/tmp/dev_accuracy.png', dev_summary_accuracy, 'Dev')
        accuracy_plot('/tmp/test_accuracy.png', test_summary_accuracy, 'Test')
        recall_plot('/tmp/dev_recall.png', dev_question_lookup, dev_summary_recall, 'Dev')
        recall_plot('/tmp/test_recall.png', test_question_lookup, test_summary_recall, 'Test')

        # Obtain metrics on number of answerable questions based on the dataset requested
        answer_lookup = {qnum: format_guess(guess) for qnum, guess in qdb.all_answers().items()}
        all_answers = set(answer_lookup.values())
        all_questions = list(questions.values())
        dataset = self.qb_dataset()
        training_data = dataset.training_data()

        # Normalized answers present in the dataset's training data (presumably already
        # filtered by ``dataset.min_class_examples`` — confirm in the dataset class).
        min_n_answers = {format_guess(g) for g in training_data[1]}

        train_questions = [q for q in all_questions if q.fold == 'train']
        dev_questions = [q for q in all_questions if q.fold == 'dev']
        test_questions = [q for q in all_questions if q.fold == 'test']

        train_answers = {format_guess(q.page) for q in train_questions}
        dev_answers = {format_guess(q.page) for q in dev_questions}
        # A set like the other folds, so n_test_answers and the test p_common ratios
        # count distinct answers rather than questions.
        test_answers = {format_guess(q.page) for q in test_questions}

        def answerable(fold_questions, known_answers):
            # Questions of a fold whose normalized answer page is in ``known_answers``.
            return [q for q in fold_questions if format_guess(q.page) in known_answers]

        min_n_train_questions = answerable(train_questions, min_n_answers)

        all_common_train_dev = train_answers.intersection(dev_answers)
        all_common_train_test = train_answers.intersection(test_answers)

        min_common_train_dev = min_n_answers.intersection(dev_answers)
        min_common_train_test = min_n_answers.intersection(test_answers)

        all_train_answerable_questions = answerable(train_questions, train_answers)
        all_dev_answerable_questions = answerable(dev_questions, train_answers)
        all_test_answerable_questions = answerable(test_questions, train_answers)

        min_train_answerable_questions = answerable(train_questions, min_n_answers)
        min_dev_answerable_questions = answerable(dev_questions, min_n_answers)
        min_test_answerable_questions = answerable(test_questions, min_n_answers)

        # The next section of code generates the percent of questions correct by the number
        # of training examples.
        Row = namedtuple('Row', [
            'fold', 'guess', 'guesser',
            'qnum', 'score', 'sentence', 'token',
            'correct', 'answerable_1', 'answerable_2',
            'n_examples'
        ])

        train_example_count_lookup = seq(train_questions) \
            .group_by(lambda q: format_guess(q.page)) \
            .smap(lambda page, group: (page, len(group))) \
            .dict()

        def guess_to_row(*args):
            # ``args`` fill the first seven Row fields (fold .. token). They are extended
            # with correctness, answerability, and the answer's number of train questions.
            guess = args[1]
            qnum = args[3]
            answer = answer_lookup[qnum]

            return Row(
                *args,
                answer == guess,
                answer in train_answers,
                answer in min_n_answers,
                train_example_count_lookup.get(answer, 0)
            )

        def correct_by_count(guesses):
            # Per n_examples: mean correctness ('correct_int') and row count ('ones').
            # NOTE(review): groups are keyed by (qnum, sentence), so every row in a group has
            # the same sentence and max_by(sentence) just returns the group's first row.
            # Confirm whether max_by(score) (the top guess) or grouping by qnum was intended.
            data = seq(guesses) \
                .smap(guess_to_row) \
                .group_by(lambda r: (r.qnum, r.sentence)) \
                .smap(lambda key, group: seq(group).max_by(lambda q: q.sentence)) \
                .to_pandas(columns=Row._fields)
            data['correct_int'] = data['correct'].astype(int)
            data['ones'] = 1
            return data \
                .groupby('n_examples') \
                .agg({'correct_int': np.mean, 'ones': np.sum}) \
                .reset_index()

        dev_counts = correct_by_count(dev_guesses)
        correct_by_n_count_plot('/tmp/dev_correct_by_count.png', dev_counts, 'dev')
        n_train_vs_fold_plot('/tmp/n_train_vs_dev.png', dev_counts, 'dev')

        # Built from the test guesses; this used to aggregate the dev data by mistake.
        test_counts = correct_by_count(test_guesses)
        correct_by_n_count_plot('/tmp/test_correct_by_count.png', test_counts, 'test')
        n_train_vs_fold_plot('/tmp/n_train_vs_test.png', test_counts, 'test')

        def ratio(part, whole):
            # max(1, ...) keeps an empty fold from raising ZeroDivisionError.
            return len(part) / max(1, len(whole))

        report = ReportGenerator({
            'dev_recall_plot': '/tmp/dev_recall.png',
            'test_recall_plot': '/tmp/test_recall.png',
            'dev_accuracy_plot': '/tmp/dev_accuracy.png',
            'test_accuracy_plot': '/tmp/test_accuracy.png',
            'dev_accuracy': dev_summary_accuracy,
            'test_accuracy': test_summary_accuracy,
            'guesser_name': self.display_name(),
            'guesser_params': pformat(params),
            'n_answers_all_folds': len(all_answers),
            'n_total_train_questions': len(train_questions),
            'min_class_examples': dataset.min_class_examples,
            'n_train_questions': len(min_n_train_questions),
            'n_dev_questions': len(dev_questions),
            'n_test_questions': len(test_questions),
            'n_total_train_answers': len(train_answers),
            'n_train_answers': len(min_n_answers),
            'n_dev_answers': len(dev_answers),
            'n_test_answers': len(test_answers),
            'all_n_common_train_dev': len(all_common_train_dev),
            'all_n_common_train_test': len(all_common_train_test),
            'all_p_common_train_dev': ratio(all_common_train_dev, dev_answers),
            'all_p_common_train_test': ratio(all_common_train_test, test_answers),
            'min_n_common_train_dev': len(min_common_train_dev),
            'min_n_common_train_test': len(min_common_train_test),
            'min_p_common_train_dev': ratio(min_common_train_dev, dev_answers),
            'min_p_common_train_test': ratio(min_common_train_test, test_answers),
            'all_n_answerable_train': len(all_train_answerable_questions),
            'all_p_answerable_train': ratio(all_train_answerable_questions, train_questions),
            'all_n_answerable_dev': len(all_dev_answerable_questions),
            'all_p_answerable_dev': ratio(all_dev_answerable_questions, dev_questions),
            'all_n_answerable_test': len(all_test_answerable_questions),
            'all_p_answerable_test': ratio(all_test_answerable_questions, test_questions),
            'min_n_answerable_train': len(min_train_answerable_questions),
            'min_p_answerable_train': ratio(min_train_answerable_questions, train_questions),
            'min_n_answerable_dev': len(min_dev_answerable_questions),
            'min_p_answerable_dev': ratio(min_dev_answerable_questions, dev_questions),
            'min_n_answerable_test': len(min_test_answerable_questions),
            'min_p_answerable_test': ratio(min_test_answerable_questions, test_questions),
            'dev_correct_by_count_plot': '/tmp/dev_correct_by_count.png',
            'test_correct_by_count_plot': '/tmp/test_correct_by_count.png',
            'n_train_vs_dev_plot': '/tmp/n_train_vs_dev.png',
            'n_train_vs_test_plot': '/tmp/n_train_vs_test.png'
        }, 'guesser.md')
        output = safe_path(os.path.join(directory, 'guesser_report.pdf'))
        report.create(output)
        with open(os.path.join(directory, 'guesser_report.pickle'), 'wb') as f:
            pickle.dump({
                'dev_accuracy': dev_summary_accuracy,
                'test_accuracy': test_summary_accuracy,
                'guesser_name': self.display_name(),
                'guesser_params': params
            }, f)