Example #1
File: trainer.py Project: sjliu0920/MUPPET
def resume_training_with(data: TrainingData,
                         out: ModelDir,
                         train_params: TrainParams,
                         evaluators: List[Evaluator],
                         notes: str = None,
                         dry_run: bool = False,
                         start_eval=False):
    """ Resume training an existing model with the specified parameters """
    with open(join(out.dir, "model.pkl"), "rb") as f:
        model = pickle.load(f)
    latest = out.get_latest_checkpoint()
    if latest is None:
        raise ValueError("No checkpoint to resume from found in " +
                         out.save_dir)
    print(f"Loaded checkpoint from {out.save_dir}")

    _train(model,
           data,
           latest,
           None,
           False,
           train_params,
           evaluators,
           out,
           notes,
           dry_run,
           start_eval=start_eval)
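A minimal usage sketch (not part of the project) for resuming from a saved ModelDir. It reuses the parameter dictionary that Example #2 below reads back with get_last_train_params, so the dictionary keys are the ones shown there; the directory path and the notes string are hypothetical:

out = ModelDir("models/example-run")   # hypothetical model directory
stored = out.get_last_train_params()   # dict with "data", "evaluators", "train_params" (see Example #2)
resume_training_with(data=stored["data"],
                     out=out,
                     train_params=stored["train_params"],
                     evaluators=stored["evaluators"],
                     notes="resumed after interruption",
                     dry_run=True)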
Example #2
File: trainer.py Project: sjliu0920/MUPPET
def start_training_with_params(out: ModelDir,
                               notes: str = None,
                               dry_run=False,
                               start_eval=False):
    """ Train a model with existing parameters etc """

    train_params = out.get_last_train_params()
    model = out.get_model()

    train_data = train_params["data"]

    evaluators = train_params["evaluators"]
    params = train_params["train_params"]
    params.num_epochs = 24 * 3

    _train(model, train_data, None, None, False, params, evaluators, out,
           notes, dry_run, start_eval)
Example #3
def main():
    parser = argparse.ArgumentParser(description='')
    parser.add_argument('name', help='name of output to examine')
    parser.add_argument('--eval', "-e", action="store_true")
    parser.add_argument('--n-async', "-n", type=int, default=8)
    parser.add_argument('--dev-b', type=int, default=None)
    args = parser.parse_args()

    resume_training(ModelDir(args.name), start_eval=args.eval, async_encoding=args.n_async, dev_batch_size=args.dev_b)
Example #4
def main():
    parser = argparse.ArgumentParser(description='Evaluate a model on Hotpot')
    parser.add_argument('model', help='model directory to evaluate')
    parser.add_argument('-n', '--sample_questions', type=int, default=None,
                        help="(for testing) run on a subset of questions")
    parser.add_argument('-b', '--batch_size', type=int, default=64,
                        help="Batch size, larger sizes can be faster but uses more memory")
    parser.add_argument('-s', '--step', default=None,
                        help="Weights to load, can be a checkpoint step or 'latest'")
    # parser.add_argument('-c', '--corpus', choices=["dev", "train"], default="dev")
    parser.add_argument('--num_runs', type=int, default=1,
                        help="Number of different seeds to test on, for more accurate results")
    parser.add_argument('--no_ema', action="store_true", help="Don't use EMA weights even if they exist")
    args = parser.parse_args()

    model_dir = ModelDir(args.model)

    corpus = HotpotQuestions()
    # if args.corpus == "dev":
    #     questions = corpus.get_dev()
    # else:
    #     questions = corpus.get_train()
    questions = corpus.get_dev()

    question_filter = HotpotQuestionFilter(2)  # TODO add option to cancel this, and more fine-grained analysis
    question_preprocessor = HotpotTextLengthPreprocessor(600)
    questions = [question_preprocessor.preprocess(x) for x in questions
                 if (question_filter.keep(x) and question_preprocessor.preprocess(x) is not None)]

    if args.sample_questions:
        # Sort for a stable order, then shuffle in place so the sample is deterministic
        questions = sorted(questions, key=lambda x: x.question_id)
        np.random.RandomState(0).shuffle(questions)
        questions = questions[:args.sample_questions]

    batcher = ClusteredBatcher(args.batch_size, multiple_contexts_len, truncate_batches=True)
    datasets = [HotpotStratifiedBinaryQuestionParagraphPairsDataset(questions, batcher, fixed_dataset=True, sample_seed=i)
                for i in range(args.num_runs)]

    evaluators = [BinaryClassificationEvaluator(), RecordFineGrainedBinaryPrediction()]

    if args.step is not None:
        if args.step == "latest":
            checkpoint = model_dir.get_latest_checkpoint()
        else:
            checkpoint = model_dir.get_checkpoint(int(args.step))
    else:
        checkpoint = model_dir.get_best_weights()
        if checkpoint is not None:
            print("Using best weights")
        else:
            print("Using latest checkpoint")
            checkpoint = model_dir.get_latest_checkpoint()

    model = model_dir.get_model()

    evaluation = trainer.test(model, evaluators, {f'dev_seed_{i}': dataset for i, dataset in enumerate(datasets)},
                              corpus.get_resource_loader(), checkpoint, not args.no_ema, 10)

    scalars = {key: np.mean([eval_dict.scalars[key] for eval_dict in evaluation.values()])
               for key in evaluation['dev_seed_0'].scalars.keys()}

    # Print the scalar results in a two column table
    # scalars = evaluation.scalars
    cols = list(sorted(scalars.keys()))
    table = [cols]
    header = ["Metric", ""]
    table.append([("%s" % scalars[x] if x in scalars else "-") for x in cols])
    print_table([header] + transpose_lists(table))
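The two-column layout assembled above pairs a list of metric names with a parallel list of values before handing both to print_table. A stand-alone sketch of the same shape, with hypothetical scalar values and zip(*...) standing in for transpose_lists:

scalars = {"accuracy": 0.91, "loss": 0.42}                 # hypothetical values
cols = sorted(scalars.keys())
table = [cols, ["%s" % scalars[c] for c in cols]]          # [names, values]
rows = [["Metric", ""]] + [list(r) for r in zip(*table)]   # zip(*...) plays the role of transpose_lists
for name, value in rows:
    print(f"{name:<10} {value}")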
Example #5
def main():
    parser = argparse.ArgumentParser(description='Evaluate a model on SQuAD')
    parser.add_argument('model', help='model directory to evaluate')
    parser.add_argument('-n', '--sample_questions', type=int, default=None,
                        help="(for testing) run on a subset of questions")
    parser.add_argument('-b', '--batch_size', type=int, default=64,
                        help="Batch size, larger sizes can be faster but uses more memory")
    parser.add_argument('-s', '--step', default=None,
                        help="Weights to load, can be a checkpoint step or 'latest'")
    # parser.add_argument('-c', '--corpus', choices=["dev", "train"], default="dev")
    parser.add_argument('--no-ema', action="store_true", help="Don't use EMA weights even if they exist")
    parser.add_argument('--per-doc', action='store_true',
                        help="Test against the full document rather than against distractors")
    parser.add_argument('--save-errors', default=None, type=str)
    args = parser.parse_args()

    model_dir = ModelDir(args.model)

    corpus = SquadRelevanceCorpus()
    # if args.corpus == "dev":
    #     questions = corpus.get_dev()
    # else:
    #     questions = corpus.get_train()
    questions = corpus.get_dev()

    question_preprocessor = SquadTextLengthPreprocessor(600)
    questions = [question_preprocessor.preprocess(x) for x in questions
                 if (question_preprocessor.preprocess(x) is not None)]

    if args.sample_questions:
        questions = sorted(questions, key=lambda x: x.question_id)
        np.random.RandomState(0).shuffle(questions)
        questions = questions[:args.sample_questions]

    batcher = ClusteredBatcher(args.batch_size, multiple_contexts_len, truncate_batches=True)
    if args.per_doc:
        data = SquadFullDocumentDataset(questions, batcher, corpus.dev_title_to_document)
    else:
        data = SquadFullQuestionParagraphPairsDataset(questions, batcher)

    evaluators = [BinaryClassificationEvaluator(), RecordFineGrainedBinaryPrediction(), RecordFullRankings()]

    if args.step is not None:
        if args.step == "latest":
            checkpoint = model_dir.get_latest_checkpoint()
        else:
            checkpoint = model_dir.get_checkpoint(int(args.step))
    else:
        checkpoint = model_dir.get_best_weights()
        if checkpoint is not None:
            print("Using best weights")
        else:
            print("Using latest checkpoint")
            checkpoint = model_dir.get_latest_checkpoint()

    model = model_dir.get_model()

    evaluation = trainer.test(model, evaluators, {'dev_full': data},
                              corpus.get_resource_loader(), checkpoint, not args.no_ema, 10)['dev_full']

    if args.save_errors is not None:
        errors_dict = evaluation.per_sample['per_question_errors']

        def format_error(wrong: Tuple[BinaryQuestionAndParagraphs, float],
                         correct: Tuple[BinaryQuestionAndParagraphs, float]):
            question = ' '.join(wrong[0].question)
            qid = wrong[0].question_id
            wrong_text = ' '.join(wrong[0].paragraphs[0])
            wrong_score = wrong[1]
            correct_text = ' '.join(correct[0].paragraphs[0])
            correct_score = correct[1]
            return f"Question: {question}, ID: {qid}\n" \
                   f"Incorrect First Place: (score: {wrong_score})\n{wrong_text}\n" \
                   f"Correct Passage: (score: {correct_score})\n{correct_text}\n"

        with open(args.save_errors, 'wt') as f:
            for false_par, true_par in errors_dict.values():
                f.write(format_error(false_par, true_par))

    # Print the scalar results in a two column table
    scalars = evaluation.scalars
    cols = list(sorted(scalars.keys()))
    table = [cols]
    header = ["Metric", ""]
    table.append([("%s" % scalars[x] if x in scalars else "-") for x in cols])
    print_table([header] + transpose_lists(table))
Example #6
def main():
    parser = argparse.ArgumentParser(
        description='Full ranking evaluation on Hotpot')
    parser.add_argument('model', help='model directory to evaluate')
    parser.add_argument('-n',
                        '--sample_questions',
                        type=int,
                        default=None,
                        help="(for testing) run on a subset of questions")
    parser.add_argument(
        '-b',
        '--batch_size',
        type=int,
        default=64,
        help="Batch size, larger sizes can be faster but uses more memory")
    parser.add_argument(
        '-s',
        '--step',
        default=None,
        help="Weights to load, can be a checkpoint step or 'latest'")
    # parser.add_argument('-c', '--corpus', choices=["dev", "train"], default="dev")
    parser.add_argument('--no_ema',
                        action="store_true",
                        help="Don't use EMA weights even if they exist")
    args = parser.parse_args()

    model_dir = ModelDir(args.model)

    corpus = HotpotQuestions()
    # if args.corpus == "dev":
    #     questions = corpus.get_dev()
    # else:
    #     questions = corpus.get_train()
    questions = corpus.get_dev()

    question_preprocessor = HotpotTextLengthPreprocessor(600)
    questions = [
        question_preprocessor.preprocess(x) for x in questions
        if (question_preprocessor.preprocess(x) is not None)
    ]

    if args.sample_questions:
        # Sort for a stable order, then shuffle in place so the sample is deterministic
        questions = sorted(questions, key=lambda x: x.question_id)
        np.random.RandomState(0).shuffle(questions)
        questions = questions[:args.sample_questions]

    batcher = ClusteredBatcher(args.batch_size,
                               multiple_contexts_len,
                               truncate_batches=True)
    data = HotpotFullQuestionParagraphPairsDataset(questions, batcher)

    evaluators = [
        BinaryClassificationEvaluator(),
        RecordFineGrainedBinaryPrediction(),
        RecordFullRankings()
    ]

    if args.step is not None:
        if args.step == "latest":
            checkpoint = model_dir.get_latest_checkpoint()
        else:
            checkpoint = model_dir.get_checkpoint(int(args.step))
    else:
        checkpoint = model_dir.get_best_weights()
        if checkpoint is not None:
            print("Using best weights")
        else:
            print("Using latest checkpoint")
            checkpoint = model_dir.get_latest_checkpoint()

    model = model_dir.get_model()

    evaluation = trainer.test(model, evaluators, {'dev_full': data},
                              corpus.get_resource_loader(), checkpoint,
                              not args.no_ema, 10)['dev_full']

    # Print the scalar results in a two column table
    scalars = evaluation.scalars
    cols = list(sorted(scalars.keys()))
    table = [cols]
    header = ["Metric", ""]
    table.append([("%s" % scalars[x] if x in scalars else "-") for x in cols])
    print_table([header] + transpose_lists(table))
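The step/best/latest checkpoint selection above is repeated verbatim in most of these scripts. As a sketch (not part of the project), the logic could be factored into a small helper built only from the ModelDir methods already used here:

def select_checkpoint(model_dir: ModelDir, step=None):
    # An explicit step (or "latest") takes priority; otherwise prefer best weights, then fall back to the latest checkpoint
    if step is not None:
        if step == "latest":
            return model_dir.get_latest_checkpoint()
        return model_dir.get_checkpoint(int(step))
    checkpoint = model_dir.get_best_weights()
    if checkpoint is not None:
        print("Using best weights")
        return checkpoint
    print("Using latest checkpoint")
    return model_dir.get_latest_checkpoint()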
Example #7
word_counts.update(title_counts)
voc = set(word_counts.keys())

print("Loading encoder...")

spec = QuestionAndParagraphsSpec(batch_size=None, max_num_contexts=2,
                                 max_num_question_words=None, max_num_context_words=None)
encoder = SentenceEncoderIterativeModel(model_dir_path=args.encoder_model, vocabulary=voc,
                                        spec=spec, loader=loader, use_char_inputs=False,
                                        use_ema=not args.no_ema,
                                        checkpoint=args.checkpoint)

print("Loading QA model...")
evaluators = [RecordHotpotQAPrediction(15, True, sp_prediction=True, disable_tqdm=True)]
batcher = ClusteredBatcher(64, multiple_contexts_len, truncate_batches=True)
qa_model_dir = ModelDir(args.qa_model)
# Note: checkpoint stays None here, so the 'best' branch below never triggers and the latest checkpoint is loaded
checkpoint = None
if checkpoint == 'best':
    checkpoint = qa_model_dir.get_best_weights()
if checkpoint is not None:
    print("Using best weights")
else:
    print("Using latest checkpoint")
    checkpoint = qa_model_dir.get_latest_checkpoint()
qa_model = qa_model_dir.get_model()
assert isinstance(qa_model, AttentionQAFullHotpot)
qa_sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True), graph=tf.Graph())
qa_spec = QuestionAndParagraphsSpec(batch_size=None, max_num_contexts=1,
                                    max_num_question_words=None, max_num_context_words=None)
with qa_sess.graph.as_default():
    qa_model.set_inputs(None, loader, voc=voc, input_spec=qa_spec)
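The fragment above begins after word_counts and title_counts have already been built. A minimal sketch of how such a vocabulary might be assembled; collections.Counter and the placeholder token lists are assumptions, not necessarily what the project does:

from collections import Counter

# Hypothetical token sources; in the real script these come from the corpus questions and document titles
question_tokens = [["who", "founded", "acme"], ["where", "is", "acme", "based"]]
title_tokens = [["Acme", "Corporation"]]

word_counts = Counter(w for toks in question_tokens for w in toks)
title_counts = Counter(w for toks in title_tokens for w in toks)
word_counts.update(title_counts)
voc = set(word_counts.keys())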
Example #8
def main():
    parser = argparse.ArgumentParser(
        description='Full ranking evaluation on Hotpot')
    parser.add_argument('model', help='model directory to evaluate')
    parser.add_argument(
        'output',
        type=str,
        help="Store the per-paragraph results in csv format in this file, "
        "or the json prediction if in test mode")
    parser.add_argument('-n',
                        '--sample_questions',
                        type=int,
                        default=None,
                        help="(for testing) run on a subset of questions")
    parser.add_argument(
        '-b',
        '--batch_size',
        type=int,
        default=64,
        help="Batch size, larger sizes can be faster but uses more memory")
    parser.add_argument(
        '-s',
        '--step',
        default=None,
        help="Weights to load, can be a checkpoint step or 'latest'")
    parser.add_argument('-a',
                        '--answer_bound',
                        type=int,
                        default=8,
                        help="Max answer span length")
    parser.add_argument('-c',
                        '--corpus',
                        choices=[
                            "distractors", "gold", "hotpot_file",
                            "retrieval_file", "top_titles"
                        ],
                        default="distractors")
    parser.add_argument('-t',
                        '--tokens',
                        type=int,
                        default=None,
                        help="Max tokens per a paragraph")
    parser.add_argument('--input_file', type=str, default=None)
    parser.add_argument('--docs_file', type=str, default=None)
    parser.add_argument('--num_workers',
                        type=int,
                        default=16,
                        help='Number of workers for tokenizing')
    parser.add_argument('--no_ema',
                        action="store_true",
                        help="Don't use EMA weights even if they exist")
    parser.add_argument('--no_sp',
                        action="store_true",
                        help="Don't predict supporting facts")
    parser.add_argument('--test_mode',
                        action='store_true',
                        help="produce a prediction file, no answers given")
    args = parser.parse_args()

    model_dir = ModelDir(args.model)
    batcher = ClusteredBatcher(args.batch_size,
                               multiple_contexts_len,
                               truncate_batches=True)
    loader = ResourceLoader()

    if args.corpus not in {"distractors", "gold"} and args.input_file is None:
        raise ValueError(
            "Must pass an input file if not using precomputed dataset")

    if args.corpus in {"distractors", "gold"} and args.test_mode:
        raise ValueError(
            "Test mode not available in 'distractors' or 'gold' mode")

    if args.corpus in {"distractors", "gold"}:

        corpus = HotpotQuestions()
        loader = corpus.get_resource_loader()
        questions = corpus.get_dev()

        question_preprocessor = HotpotTextLengthPreprocessorWithSpans(
            args.tokens)
        questions = [
            question_preprocessor.preprocess(x) for x in questions
            if (question_preprocessor.preprocess(x) is not None)
        ]

        if args.sample_questions:
            # Sort for a stable order, then shuffle in place so the sample is deterministic
            questions = sorted(questions, key=lambda x: x.question_id)
            np.random.RandomState(0).shuffle(questions)
            questions = questions[:args.sample_questions]

        data = HotpotFullQADistractorsDataset(questions, batcher)
        gold_idxs = set(data.gold_idxs)
        if args.corpus == 'gold':
            data.samples = [data.samples[i] for i in data.gold_idxs]
        qid2samples = {}
        qid2idx = {}
        for i, sample in enumerate(data.samples):
            key = sample.question_id
            if key in qid2samples:
                qid2samples[key].append(sample)
                qid2idx[key].append(i)
            else:
                qid2samples[key] = [sample]
                qid2idx[key] = [i]
        questions = []
        print("Ranking pairs...")
        gold_ranks = []
        for qid, samples in tqdm(qid2samples.items()):
            question = " ".join(samples[0].question)
            pars = [" ".join(x.paragraphs[0]) for x in samples]
            ranks = get_paragraph_ranks(question, pars)
            for sample, rank, idx in zip(samples, ranks, qid2idx[qid]):
                questions.append(
                    RankedQAPair(question=sample.question,
                                 paragraphs=sample.paragraphs,
                                 spans=np.zeros((0, 2), dtype=np.int32),
                                 question_id=sample.question_id,
                                 answer=sample.answer,
                                 rank=rank,
                                 q_type=sample.q_type,
                                 sentence_segments=sample.sentence_segments))
                if idx in gold_idxs:
                    gold_ranks.append(rank + 1)
        print(f"Mean rank: {np.mean(gold_ranks)}")
        ranks_counter = Counter(gold_ranks)
        for i in [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]:
            print(f"Hits at {i}: {ranks_counter[i]}")

    elif args.corpus == 'hotpot_file':  # a hotpot json format input file. We rank the pairs with tf-idf
        with open(args.input_file, 'r') as f:
            hotpot_data = json.load(f)
        if args.sample_questions:
            # Sort for a stable order, then shuffle in place so the sample is deterministic
            hotpot_data = sorted(hotpot_data, key=lambda x: x['_id'])
            np.random.RandomState(0).shuffle(hotpot_data)
            hotpot_data = hotpot_data[:args.sample_questions]
        title2sentences = {
            context[0]: context[1]
            for q in hotpot_data for context in q['context']
        }
        question_tok_texts = tokenize_texts(
            [q['question'] for q in hotpot_data], num_workers=args.num_workers)
        sentences_tok = tokenize_texts(list(title2sentences.values()),
                                       num_workers=args.num_workers,
                                       sentences=True)
        if args.tokens is not None:
            sentences_tok = [
                truncate_paragraph(p, args.tokens) for p in sentences_tok
            ]
        title2tok_sents = {
            title: sentences
            for title, sentences in zip(title2sentences.keys(), sentences_tok)
        }
        questions = []
        for idx, question in enumerate(tqdm(hotpot_data,
                                            desc='tf-idf ranking')):
            q_titles = [title for title, _ in question['context']]
            par_pairs = [(title1, title2) for i, title1 in enumerate(q_titles)
                         for title2 in q_titles[i + 1:]]
            if len(par_pairs) == 0:
                continue
            ranks = get_paragraph_ranks(question['question'], [
                ' '.join(title2sentences[t1] + title2sentences[t2])
                for t1, t2 in par_pairs
            ])
            for rank, par_pair in zip(ranks, par_pairs):
                sent_tok_pair = title2tok_sents[par_pair[0]] + title2tok_sents[
                    par_pair[1]]
                sentence_segments, _ = get_segments_from_sentences_fix_sup(
                    sent_tok_pair, np.zeros(0))
                missing_sent_idx = [[
                    i for i, sent in enumerate(title2tok_sents[title])
                    if len(sent) == 0
                ] for title in par_pair]
                questions.append(
                    RankedQAPair(
                        question=question_tok_texts[idx],
                        paragraphs=[flatten_iterable(sent_tok_pair)],
                        spans=np.zeros((0, 2), dtype=np.int32),
                        question_id=question['_id'],
                        answer='noanswer'
                        if args.test_mode else question['answer'],
                        rank=rank,
                        q_type='null' if args.test_mode else question['type'],
                        sentence_segments=[sentence_segments],
                        par_titles_num_sents=[
                            (title,
                             sum(1 for sent in title2tok_sents[title]
                                 if len(sent) > 0)) for title in par_pair
                        ],
                        missing_sent_idxs=missing_sent_idx,
                        true_sp=[]
                        if args.test_mode else question['supporting_facts']))
    elif args.corpus == 'retrieval_file' or args.corpus == 'top_titles':
        if args.docs_file is None:
            print("Using DB documents")
            doc_db = DocDB(config.DOC_DB, full_docs=False)
        else:
            with open(args.docs_file, 'r') as f:
                docs = json.load(f)
        with open(args.input_file, 'r') as f:
            retrieval_data = json.load(f)
        if args.sample_questions:
            # Sort for a stable order, then shuffle in place so the sample is deterministic
            retrieval_data = sorted(retrieval_data, key=lambda x: x['qid'])
            np.random.RandomState(0).shuffle(retrieval_data)
            retrieval_data = retrieval_data[:args.sample_questions]

        def parname_to_text(par_name):
            par_title = par_name_to_title(par_name)
            par_num = int(par_name.split('_')[-1])
            if args.docs_file is None:
                return doc_db.get_doc_sentences(par_title)
            return docs[par_title][par_num]

        if args.corpus == 'top_titles':
            print("Top TF-IDF!")
            for q in retrieval_data:
                top_titles = q['top_titles'][:10]
                q['paragraph_pairs'] = [(title1 + '_0', title2 + '_0')
                                        for i, title1 in enumerate(top_titles)
                                        for title2 in top_titles[i + 1:]]

        question_tok_texts = tokenize_texts(
            [q['question'] for q in retrieval_data],
            num_workers=args.num_workers)
        all_parnames = list(
            set([
                parname for q in retrieval_data
                for pair in q['paragraph_pairs'] for parname in pair
            ]))
        texts_tok = tokenize_texts([parname_to_text(x) for x in all_parnames],
                                   num_workers=args.num_workers,
                                   sentences=True)
        if args.tokens is not None:
            texts_tok = [truncate_paragraph(p, args.tokens) for p in texts_tok]
        parname2tok_text = {
            parname: text
            for parname, text in zip(all_parnames, texts_tok)
        }
        questions = []
        for idx, question in enumerate(retrieval_data):
            for rank, par_pair in enumerate(question['paragraph_pairs']):
                tok_pair = parname2tok_text[par_pair[0]] + parname2tok_text[
                    par_pair[1]]
                sentence_segments, _ = get_segments_from_sentences_fix_sup(
                    tok_pair, np.zeros(0))
                missing_sent_idx = [[
                    i for i, sent in enumerate(parname2tok_text[parname])
                    if len(sent) == 0
                ] for parname in par_pair]
                questions.append(
                    RankedQAPair(
                        question=question_tok_texts[idx],
                        paragraphs=[flatten_iterable(tok_pair)],
                        spans=np.zeros((0, 2), dtype=np.int32),
                        question_id=question['qid'],
                        answer='noanswer'
                        if args.test_mode else question['answers'][0],
                        rank=rank,
                        q_type='null' if args.test_mode else question['type'],
                        sentence_segments=[sentence_segments],
                        par_titles_num_sents=[
                            (par_name_to_title(parname),
                             sum(1 for sent in parname2tok_text[parname]
                                 if len(sent) > 0)) for parname in par_pair
                        ],
                        missing_sent_idxs=missing_sent_idx,
                        true_sp=[]
                        if args.test_mode else question['supporting_facts']))
    else:
        raise NotImplementedError()

    data = DummyDataset(questions, batcher)
    evaluators = [
        RecordHotpotQAPrediction(args.answer_bound,
                                 True,
                                 sp_prediction=not args.no_sp)
    ]

    if args.step is not None:
        if args.step == "latest":
            checkpoint = model_dir.get_latest_checkpoint()
        else:
            checkpoint = model_dir.get_checkpoint(int(args.step))
    else:
        checkpoint = model_dir.get_best_weights()
        if checkpoint is not None:
            print("Using best weights")
        else:
            print("Using latest checkpoint")
            checkpoint = model_dir.get_latest_checkpoint()

    model = model_dir.get_model()

    evaluation = trainer.test(model, evaluators, {args.corpus: data}, loader,
                              checkpoint, not args.no_ema, 10)[args.corpus]

    print("Saving result")
    output_file = args.output

    df = pd.DataFrame(evaluation.per_sample)

    df.sort_values(["question_id", "rank"], inplace=True, ascending=True)
    group_by = ["question_id"]

    def get_ranked_scores(score_name):
        filtered_df = df[df.type == 'comparison'] if "Cp" in score_name else \
            df[df.type == 'bridge'] if "Br" in score_name else df
        target_prefix = 'joint' if 'joint' in score_name else 'sp' if 'sp' in score_name else 'text'
        target_score = f"{target_prefix}_{'em' if 'EM' in score_name else 'f1'}"
        return compute_ranked_scores_with_yes_no(
            filtered_df,
            span_q_col="span_question_scores",
            yes_no_q_col="yes_no_question_scores",
            yes_no_scores_col="yes_no_confidence_scores",
            span_scores_col="predicted_score",
            span_target_score=target_score,
            group_cols=group_by)

    if not args.test_mode:
        score_names = ["EM", "F1", "Br EM", "Br F1", "Cp EM", "Cp F1"]
        if not args.no_sp:
            score_names.extend([
                f"{prefix} {name}" for prefix in ['sp', 'joint']
                for name in score_names
            ])

        table = [["N Paragraphs"] + score_names]
        scores = [get_ranked_scores(score_name) for score_name in score_names]
        table += list([str(i + 1), *["%.4f" % x for x in score_vals]]
                      for i, score_vals in enumerate(zip(*scores)))
        print_table(table)

        df.to_csv(output_file, index=False)

    else:
        df_to_pred(df, output_file)
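The parname_to_text helper in the example above relies on a paragraph-name convention of the form "<title>_<paragraph index>" (the top_titles branch appends '_0' to each title). A small sketch of that convention; par_name_to_title is re-implemented here as an assumed equivalent, not the project's actual function:

def par_name_to_title_sketch(par_name: str) -> str:
    # Drop the trailing "_<index>" to recover the title
    return '_'.join(par_name.split('_')[:-1])

name = "Some_Article_3"                   # hypothetical paragraph name
title = par_name_to_title_sketch(name)    # -> "Some_Article"
par_num = int(name.split('_')[-1])        # -> 3, as in parname_to_text above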
Example #9
    def __init__(self,
                 model_dir_path: str,
                 vocabulary: Set[str],
                 spec: QuestionAndParagraphsSpec,
                 loader,
                 use_char_inputs: bool,
                 use_ema: bool,
                 checkpoint: str = 'best'):
        if checkpoint not in {'best', 'latest'}:
            raise ValueError(
                "checkpoint value must be either 'best' or 'latest'.")
        self.model_dir = ModelDir(model_dir_path)
        self.checkpoint = None
        if checkpoint == 'best':
            self.checkpoint = self.model_dir.get_best_weights()
        if self.checkpoint is not None:
            print("Using best weights")
        else:
            print("Using latest checkpoint")
            self.checkpoint = self.model_dir.get_latest_checkpoint()
        print(f"Restoring checkpoint: {self.checkpoint}")
        self.model = self.model_dir.get_model()
        assert isinstance(self.model, MultipleContextModel)
        if self.model.use_elmo and use_char_inputs:
            self.model.lm_model.embed_weights_file = None
        self.sess = tf.Session(
            config=tf.ConfigProto(allow_soft_placement=True), graph=tf.Graph())
        with self.sess.graph.as_default():
            self.model.set_inputs(None,
                                  loader,
                                  voc=vocabulary,
                                  input_spec=spec)
            inputs = self.model.get_placeholders()
            input_dict = {
                p: x
                for p, x in zip(self.model.get_placeholders(), inputs)
            }
            with self.sess.as_default():
                _ = self.model.get_predictions_for(
                    input_dict)  # for building the model
            if not self.model.use_elmo or not use_char_inputs:
                saver = tf.train.Saver()
                saver.restore(self.sess, self.checkpoint)
            else:
                self.sess.run(tf.global_variables_initializer())
                optimistic_restore(self.sess, self.checkpoint)

            if use_ema:
                ema = tf.train.ExponentialMovingAverage(0)
                reader = tf.train.NewCheckpointReader(self.checkpoint)
                expected_ema_names = {
                    ema.average_name(x): x
                    for x in tf.trainable_variables()
                    if reader.has_tensor(ema.average_name(x))
                }
                if len(expected_ema_names) > 0:
                    print("Restoring EMA variables")
                    saver = tf.train.Saver(expected_ema_names)
                    saver.restore(self.sess, self.checkpoint)
        if self.model.use_elmo and not use_char_inputs:
            # TODO: Might be redundant! the weights are already saved in the checkpoint
            elmo_token_embed_placeholder, elmo_token_embed_init = \
                self.model.get_elmo_token_embed_ph_and_op()
            print("Loading ELMo weights...")
            elmo_token_embed_weights = load_elmo_pretrained_token_embeddings(
                self.model.lm_model.embed_weights_file)
            self.sess.run(elmo_token_embed_init,
                          feed_dict={
                              elmo_token_embed_placeholder:
                              elmo_token_embed_weights
                          })

        self.sess.graph.finalize()
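optimistic_restore is called above but not defined in these examples. A common implementation of that pattern (an assumption, not necessarily the project's version) restores only the variables that appear in both the current graph and the checkpoint, with matching shapes:

import tensorflow as tf

def optimistic_restore_sketch(session, checkpoint_path):
    # Restore only variables present in the checkpoint with matching shapes; leave the rest as initialized
    reader = tf.train.NewCheckpointReader(checkpoint_path)
    saved_shapes = reader.get_variable_to_shape_map()
    restorable = [v for v in tf.global_variables()
                  if v.name.split(':')[0] in saved_shapes
                  and v.get_shape().as_list() == saved_shapes[v.name.split(':')[0]]]
    tf.train.Saver(restorable).restore(session, checkpoint_path)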
Example #10
def main():
    parser = argparse.ArgumentParser(description='Full ranking evaluation on Hotpot')
    parser.add_argument('model', help='model directory to evaluate')
    parser.add_argument('-n', '--sample_questions', type=int, default=None,
                        help="(for testing) run on a subset of questions")
    parser.add_argument('-b', '--batch_size', type=int, default=64,
                        help="Batch size, larger sizes can be faster but uses more memory")
    parser.add_argument('-s', '--step', default=None,
                        help="Weights to load, can be a checkpoint step or 'latest'")
    # parser.add_argument('-c', '--corpus', choices=["dev", "train"], default="dev")
    parser.add_argument('--no-ema', action="store_true", help="Don't use EMA weights even if they exist")
    parser.add_argument('--save-errors', default=None, type=str)
    parser.add_argument('--br-as-cp', action='store_true')
    parser.add_argument('--mult-probs', action='store_true')
    args = parser.parse_args()

    model_dir = ModelDir(args.model)

    corpus = HotpotQuestions()
    # if args.corpus == "dev":
    #     questions = corpus.get_dev()
    # else:
    #     questions = corpus.get_train()
    questions = corpus.get_dev()

    question_preprocessor = HotpotTextLengthPreprocessor(600)
    questions = [question_preprocessor.preprocess(x) for x in questions
                 if (question_preprocessor.preprocess(x) is not None)]

    if args.sample_questions:
        # Sort for a stable order, then shuffle in place so the sample is deterministic
        questions = sorted(questions, key=lambda x: x.question_id)
        np.random.RandomState(0).shuffle(questions)
        questions = questions[:args.sample_questions]

    batcher = ClusteredBatcher(args.batch_size, multiple_contexts_len, truncate_batches=True)
    data = HotpotFullIterativeDataset(questions, batcher, bridge_as_comparison=args.br_as_cp)

    evaluators = [IterativeRelevanceEvaluator(), RecordFullIterativeRankings(multiply_iteration_probs=args.mult_probs)]

    if args.step is not None:
        if args.step == "latest":
            checkpoint = model_dir.get_latest_checkpoint()
        else:
            checkpoint = model_dir.get_checkpoint(int(args.step))
    else:
        checkpoint = model_dir.get_best_weights()
        if checkpoint is not None:
            print("Using best weights")
        else:
            print("Using latest checkpoint")
            checkpoint = model_dir.get_latest_checkpoint()

    model = model_dir.get_model()

    evaluation = trainer.test(model, evaluators, {'dev_full': data},
                              corpus.get_resource_loader(), checkpoint, not args.no_ema, 10)['dev_full']

    if args.save_errors is not None:
        first_errors_dict = evaluation.per_sample['per_question_first_errors']
        second_errors_dict = evaluation.per_sample['per_question_second_errors']

        def format_first_error(wrong: Tuple[IterativeQuestionAndParagraphs, float, float],
                               correct: Tuple[IterativeQuestionAndParagraphs, float, float]):
            question = ' '.join(wrong[0].question)
            qid = wrong[0].question_id
            q_type = wrong[0].q_type
            wrong_text = ' '.join(wrong[0].paragraphs[0])
            wrong_score = wrong[1]
            correct_text = ' '.join(correct[0].paragraphs[0])
            correct_score = correct[1]
            return f"Question: {question}, ID: {qid}, type: {q_type}\n" \
                   f"Incorrect First Place: (score: {wrong_score})\n{wrong_text}\n-\n" \
                   f"Correct Passage: (score: {correct_score})\n{correct_text}\n***\n"

        def format_second_error(wrong: Tuple[IterativeQuestionAndParagraphs, float, float],
                                correct: Tuple[IterativeQuestionAndParagraphs, float, float]):
            question = ' '.join(wrong[0].question)
            qid = wrong[0].question_id
            q_type = wrong[0].q_type
            wrong_texts = [' '.join(par) for par in wrong[0].paragraphs]
            wrong_first_score = wrong[1]
            wrong_final_score = wrong[2]
            correct_texts = [' '.join(par) for par in correct[0].paragraphs]
            correct_first_score = correct[1]
            correct_final_score = correct[2]
            return f"Question: {question}, ID: {qid}, type: {q_type}\n" \
                   f"Incorrect First Place Pair: (score: {wrong_final_score})\n" \
                   f"Paragraph 1 (score: {wrong_first_score})\n" \
                   f"{wrong_texts[0]}\n" \
                   f"Paragraph 2:\n" \
                   f"{wrong_texts[1]}\n-\n" \
                   f"Correct Pair: (score: {correct_final_score})\n" \
                   f"Paragraph 1 (score: {correct_first_score})\n" \
                   f"{correct_texts[0]}\n" \
                   f"Paragraph 2:\n" \
                   f"{correct_texts[1]}\n***\n"

        with open(args.save_errors, 'wt') as f:
            f.write("First paragraph errors:\n*****************************\n")
            for false_par, true_par in first_errors_dict.values():
                f.write(format_first_error(false_par, true_par))
            f.write("Second paragraph errors:\n*****************************\n")
            for false_par, true_par in second_errors_dict.values():
                f.write(format_second_error(false_par, true_par))

    # Print the scalar results in a two column table
    scalars = evaluation.scalars
    cols = list(sorted(scalars.keys()))
    table = [cols]
    header = ["Metric", ""]
    table.append([("%s" % scalars[x] if x in scalars else "-") for x in cols])
    print_table([header] + transpose_lists(table))