Example #1
def main():
    parser = argparse.ArgumentParser(description='Train a model on the Hotpot QA dataset')
    parser.add_argument("name", help="Where to store the model")
    parser.add_argument("--elmo", action='store_true', help="Whether to use elmo or not")
    parser.add_argument("-c", "--continue_model", action='store_true', help="Whether to start a new run or "
                                                                            "continue an existing one")
    args = parser.parse_args()

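    # Store this script's own source as the run notes, so the exact training configuration is recorded.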
    with open(__file__, "r") as f:
        notes = f.read()

    continue_existing_run = args.continue_model
    # save_preprocessed = args.save
    if continue_existing_run:
        print("We will continue an existing run!")
    else:
        print("We will start a new run!")

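    # New runs get a timestamp-suffixed output directory; resumed runs reuse the given name as-is.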
    out = args.name + "-" + datetime.now().strftime("%m%d-%H%M%S")
    if continue_existing_run:
        out = args.name

    # model = get_model_with_yes_no(rnn_dim=150, use_elmo=args.elmo, keep_rate=0.7)
    model = get_full_hotpot_model(rnn_dim=150, use_elmo=args.elmo, keep_rate=0.8)

    corpus = HotpotQuestions()
    train_batcher = ClusteredBatcher(25, multiple_contexts_len, truncate_batches=True)
    dev_batcher = ClusteredBatcher(90, multiple_contexts_len, truncate_batches=True)
    data = HotpotQATrainingData(corpus=corpus, train_batcher=train_batcher, dev_batcher=dev_batcher,
                                sample_filter=HotpotQuestionFilterWithSpans(1, keep_yes_no=True),
                                preprocessor=HotpotTextLengthPreprocessorWithSpans(600),
                                sample_train=None, sample_dev=None, sample_seed=18,
                                group_pairs_in_batches=True, distractor_pairs=2)

    eval = [LossEvaluator(), MultiParagraphSpanEvaluator(8, "hotpot", yes_no_option=True, supporting_facts_option=True)]

    n_epochs = 80

    eval_samples = dict(dev=None, train=3000)

    params = TrainParams(
        SerializableOptimizer("Adadelta", dict(learning_rate=1.0)),
        num_epochs=n_epochs, ema=0.999, max_checkpoints_to_keep=2,
        async_encoding=8, log_period=30, eval_period=1200, save_period=1200,
        eval_samples=eval_samples, best_weights=('dev', 'b8/question-text-f1'),
        monitor_gradients=True, clip_norm=None
    )

    if not continue_existing_run:
        trainer.start_training(data, model, params, eval, model_dir.ModelDir(out), notes, save_graph=False)
    else:
        resume_training_with(data=data, out=model_dir.ModelDir(out),
                             train_params=params, evaluators=eval, notes=notes, start_eval=True)
Example #2
def build_squad_elmo(vocab_file, embd_file):
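    # Build the vocabulary and embedding files for the ELMo SQuAD relevance setup from the
    # preprocessed training data.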
    corpus = SquadRelevanceCorpus()
    some_batcher = ClusteredBatcher(64, multiple_contexts_len, truncate_batches=True)
    data = SquadBinaryRelevanceTrainingData(corpus=corpus, train_batcher=some_batcher, dev_batcher=some_batcher,
                                            sample_filter=None, preprocessor=None,
                                            sample_train=None, sample_dev=None, sample_seed=18)
    build_vocab_from_preprocessed(data, vocab_file, embd_file)
Example #3
def main():
    parser = argparse.ArgumentParser(description='Train a model on the Squad relevance dataset')
    parser.add_argument("name", help="Where to store the model")
    parser.add_argument("--elmo", action='store_true', help="Whether to use elmo or not")
    parser.add_argument("-c", "--continue_model", action='store_true', help="Whether to start a new run or "
                                                                            "continue an existing one")
    args = parser.parse_args()

    with open(__file__, "r") as f:
        notes = f.read()

    continue_existing_run = args.continue_model
    # save_preprocessed = args.save
    if continue_existing_run:
        print("We will continue an existing run!")
    else:
        print("We will start a new run!")

    out = args.name + "-" + datetime.now().strftime("%m%d-%H%M%S")
    if continue_existing_run:
        out = args.name

    # model = get_basic_model(500, post_merger_params=None, use_elmo=args.elmo, keep_rate=0.8)
    # model = get_context_to_question_model(rnn_dim=150, q2c=False, res_rnn=True, res_self_att=False)
    # model = get_context_with_bottleneck_to_question_model(rnn_dim=500, q2c=False, res_rnn=True, res_self_att=False)
    # model = get_ablate_model()
    # model = get_fixed_context_to_question(150)
    # model = get_bottleneck_to_seq_model(500, q2c=False, res_rnn=True, res_self_att=False, seq_len=50)
    # model = get_multi_encode_model(0, 200, num_encodings=5, map_embed=False)
    # model = get_multi_encode_softmax_weighting_model(0, 400, num_encodings=5, map_embed=False)
    model = get_sentences_model(512, use_elmo=args.elmo, keep_rate=0.8)

    corpus = SquadRelevanceCorpus()
    train_batcher = ClusteredBatcher(45, multiple_contexts_len, truncate_batches=True)
    dev_batcher = ClusteredBatcher(128, multiple_contexts_len, truncate_batches=True)
    data = SquadBinaryRelevanceTrainingData(corpus=corpus, train_batcher=train_batcher, dev_batcher=dev_batcher,
                                            sample_filter=None, preprocessor=SquadTextLengthPreprocessor(600),
                                            sample_train=None, sample_dev=None, sample_seed=18)

    eval = [LossEvaluator(), BinaryClassificationEvaluator()]

    n_epochs = 80

    adadelta = SerializableOptimizer("Adadelta", dict(learning_rate=1.0))
    momentum = SerializableOptimizer("Momentum", dict(learning_rate=0.01, momentum=0.9, use_nesterov=True))
    adam = SerializableOptimizer("Adam", dict(learning_rate=1e-4))

    reduce_lr_on_plateau = ReduceLROnPlateau(dev_name='dev', scalar_name='loss', factor=0.2,
                                             patience=8, verbose=1, mode='min', terminate_th=1e-5)
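    # Only the Adadelta optimizer is used below; the Momentum and Adam optimizers and the
    # ReduceLROnPlateau schedule are defined for experimentation but not passed to TrainParams
    # (reduce_lr_on_plateau is left as None).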

    params = TrainParams(
        adadelta,
        num_epochs=n_epochs, ema=0.999, max_checkpoints_to_keep=2,
        async_encoding=8, log_period=30, eval_period=1800, save_period=1800,
        eval_samples=dict(dev=None, train=3000), best_weights=('dev', 'binary-relevance/average_precision'),
        monitor_gradients=True, clip_norm=None, regularization_lambda=None, reduce_lr_on_plateau=None
    )

    if not continue_existing_run:
        trainer.start_training(data, model, params, eval, model_dir.ModelDir(out), notes, save_graph=False)
    else:
        resume_training_with(data=data, out=model_dir.ModelDir(out),
                             train_params=params, evaluators=eval, notes=notes, start_eval=True)
Example #4
def main():
    parser = argparse.ArgumentParser(
        description='Train a model on the Hotpot pairwise relevance dataset')
    parser.add_argument("name", help="Where to store the model")
    parser.add_argument("-c",
                        "--continue_model",
                        action='store_true',
                        help="Whether to start a new run or "
                        "continue an existing one")
    args = parser.parse_args()

    with open(__file__, "r") as f:
        notes = f.read()

    continue_existing_run = args.continue_model
    # save_preprocessed = args.save
    if continue_existing_run:
        print("We will continue an existing run!")
    else:
        print("We will start a new run!")

    out = args.name + "-" + datetime.now().strftime("%m%d-%H%M%S")
    if continue_existing_run:
        out = args.name

    # model = get_model(rnn_dim=150)
    # model = get_contexts_to_question_model(rnn_dim=150, post_merge='res_self_att')
    model = get_multi_hop_model(rnn_dim=150,
                                c2c=True,
                                q2c=False,
                                res_rnn=True,
                                res_self_att=False,
                                post_merge=True,
                                encoder='max',
                                merge_type='max',
                                num_c2c_hops=1)
    # model = get_context_only_model(rnn_dim=150, res_rnn=True, res_self_att=False,
    #                                encoder='max', num_c2c_hops=1)

    corpus = HotpotQuestions()
    train_batcher = ClusteredBatcher(45,
                                     multiple_contexts_len,
                                     truncate_batches=True)
    dev_batcher = ClusteredBatcher(45,
                                   multiple_contexts_len,
                                   truncate_batches=True)
    data = HotpotBinaryRelevanceTrainingData(
        corpus=corpus,
        train_batcher=train_batcher,
        dev_batcher=dev_batcher,
        sample_filter=HotpotQuestionFilter(2),
        preprocessor=HotpotTextLengthPreprocessor(600),
        sample_train=None,
        sample_dev=None,
        sample_seed=18,
        add_gold_distractors=True)

    eval = [LossEvaluator(), BinaryClassificationEvaluator()]

    n_epochs = 80

    params = TrainParams(SerializableOptimizer("Adadelta",
                                               dict(learning_rate=1.0)),
                         num_epochs=n_epochs,
                         ema=0.999,
                         max_checkpoints_to_keep=2,
                         async_encoding=8,
                         log_period=30,
                         eval_period=1800,
                         save_period=1800,
                         eval_samples=dict(dev=None, train=3000),
                         best_weights=('dev', 'binary-relevance/f1_score'),
                         monitor_gradients=True)

    if not continue_existing_run:
        trainer.start_training(data,
                               model,
                               params,
                               eval,
                               model_dir.ModelDir(out),
                               notes,
                               save_graph=False)
    else:
        resume_training_with(data=data,
                             out=model_dir.ModelDir(out),
                             train_params=params,
                             evaluators=eval,
                             notes=notes,
                             start_eval=True)
Example #5
def main():
    parser = argparse.ArgumentParser(description='Evaluate a model on Hotpot')
    parser.add_argument('model', help='model directory to evaluate')
    parser.add_argument('-n', '--sample_questions', type=int, default=None,
                        help="(for testing) run on a subset of questions")
    parser.add_argument('-b', '--batch_size', type=int, default=64,
                        help="Batch size, larger sizes can be faster but uses more memory")
    parser.add_argument('-s', '--step', default=None,
                        help="Weights to load, can be a checkpoint step or 'latest'")
    # parser.add_argument('-c', '--corpus', choices=["dev", "train"], default="dev")
    parser.add_argument('--num_runs', type=int, default=1,
                        help="Number of different seeds to test on, for more accurate results")
    parser.add_argument('--no_ema', action="store_true", help="Don't use EMA weights even if they exist")
    args = parser.parse_args()

    model_dir = ModelDir(args.model)

    corpus = HotpotQuestions()
    # if args.corpus == "dev":
    #     questions = corpus.get_dev()
    # else:
    #     questions = corpus.get_train()
    questions = corpus.get_dev()

    question_filter = HotpotQuestionFilter(2)  # TODO add option to cancel this, and more fine-grained analysis
    question_preprocessor = HotpotTextLengthPreprocessor(600)
    questions = [question_preprocessor.preprocess(x) for x in questions
                 if (question_filter.keep(x) and question_preprocessor.preprocess(x) is not None)]

    if args.sample_questions:
        questions = sorted(questions, key=lambda x: x.question_id)
        np.random.RandomState(0).shuffle(questions)
        questions = questions[:args.sample_questions]

    batcher = ClusteredBatcher(args.batch_size, multiple_contexts_len, truncate_batches=True)
    datasets = [HotpotStratifiedBinaryQuestionParagraphPairsDataset(questions, batcher, fixed_dataset=True, sample_seed=i)
                for i in range(args.num_runs)]
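    # One fixed dev dataset per random seed; the scalar metrics are averaged over the seeds below.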

    evaluators = [BinaryClassificationEvaluator(), RecordFineGrainedBinaryPrediction()]

    if args.step is not None:
        if args.step == "latest":
            checkpoint = model_dir.get_latest_checkpoint()
        else:
            checkpoint = model_dir.get_checkpoint(int(args.step))
    else:
        checkpoint = model_dir.get_best_weights()
        if checkpoint is not None:
            print("Using best weights")
        else:
            print("Using latest checkpoint")
            checkpoint = model_dir.get_latest_checkpoint()

    model = model_dir.get_model()

    evaluation = trainer.test(model, evaluators, {f'dev_seed_{i}': dataset for i, dataset in enumerate(datasets)},
                              corpus.get_resource_loader(), checkpoint, not args.no_ema, 10)

    scalars = {key: np.mean([eval_dict.scalars[key] for eval_dict in evaluation.values()])
               for key in evaluation['dev_seed_0'].scalars.keys()}

    # Print the scalar results in a two column table
    # scalars = evaluation.scalars
    cols = list(sorted(scalars.keys()))
    table = [cols]
    header = ["Metric", ""]
    table.append([("%s" % scalars[x] if x in scalars else "-") for x in cols])
    print_table([header] + transpose_lists(table))
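For reference: several of these evaluation scripts finish by printing the scalar metrics as a two-column table via transpose_lists and print_table. Those helpers come from the project's utilities and are not shown in the snippets; the following is a minimal sketch of what they are assumed to do (illustrative only, the real implementations may differ).

def transpose_lists(lists):
    # [[name1, name2, ...], [val1, val2, ...]] -> [[name1, val1], [name2, val2], ...]
    return [list(row) for row in zip(*lists)]

def print_table(table):
    # Left-align every column to the width of its widest cell.
    widths = [max(len(str(row[i])) for row in table) for i in range(len(table[0]))]
    for row in table:
        print("  ".join(str(cell).ljust(width) for cell, width in zip(row, widths)))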
Example #6
def main():
    parser = argparse.ArgumentParser(description='Train a model on the Hotpot iterative relevance dataset')
    parser.add_argument("name", help="Where to store the model")
    parser.add_argument("--elmo", action='store_true', help="Whether to use elmo or not")
    parser.add_argument("--label-method", choices=["br-as-cp", "span", 'tfidf'], default="tfidf")
    parser.add_argument("--rank", action='store_true', help="Whether to use ranking loss or not")
    parser.add_argument("-c", "--continue_model", action='store_true', help="Whether to start a new run or "
                                                                            "continue an existing one")
    args = parser.parse_args()

    with open(__file__, "r") as f:
        notes = f.read()

    continue_existing_run = args.continue_model
    # save_preprocessed = args.save
    if continue_existing_run:
        print("We will continue an existing run!")
    else:
        print("We will start a new run!")

    out = args.name + "-" + datetime.now().strftime("%m%d-%H%M%S")
    if continue_existing_run:
        out = args.name

    print(f"Labeling method: {args.label_method}")

    # model = get_model(rnn_dim=500, use_elmo=args.elmo, keep_rate=0.8)
    model = get_reread_model(rnn_dim=512, use_elmo=args.elmo, encoder_keep_rate=0.8, reread_keep_rate=0.8,
                             two_phase_att=False, res_rnn=True, res_self_att=False,
                             multiply_iteration_probs=False, reformulate_by_context=False,
                             rank_first=True, first_rank_lambda=1.0,
                             rank_second=True, second_rank_lambda=1.0,
                             reread_rnn_dim=None, ranking_gamma=1.0)
    # model = get_reread_simple_score(rnn_dim=512, use_elmo=args.elmo, keep_rate=0.8,
    #                                 two_phase_att=False, res_rnn=True, res_self_att=False, reformulate_by_context=False,
    #                                 rank_first=True,
    #                                 rank_second=True,
    #                                 reread_rnn_dim=None)
    # model = get_reread_merge_model(rnn_dim=512, use_elmo=args.elmo, keep_rate=0.8,
    #                                res_rnn=True, res_self_att=False,
    #                                multiply_iteration_probs=False)

    corpus = HotpotQuestions()
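    # With the ranking loss (--rank), pairs are grouped in batches together with distractors,
    # smaller batch sizes are used, and grouped dev/train evaluation samples are added below.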
    if not args.rank:
        train_batcher = ClusteredBatcher(45, multiple_contexts_len, truncate_batches=True)
        dev_batcher = ClusteredBatcher(90, multiple_contexts_len, truncate_batches=True)
        data = HotpotIterativeRelevanceTrainingData(corpus=corpus, train_batcher=train_batcher, dev_batcher=dev_batcher,
                                                    sample_filter=HotpotQuestionFilter(2),
                                                    preprocessor=HotpotTextLengthPreprocessor(600),
                                                    sample_train=None, sample_dev=None, sample_seed=18,
                                                    bridge_as_comparison=args.label_method == 'br-as-cp',
                                                    label_by_span=args.label_method == 'span')
    else:
        train_batcher = ClusteredBatcher(25, multiple_contexts_len, truncate_batches=True)
        dev_batcher = ClusteredBatcher(75, multiple_contexts_len, truncate_batches=True)
        data = HotpotIterativeRelevanceTrainingData(corpus=corpus, train_batcher=train_batcher, dev_batcher=dev_batcher,
                                                    sample_filter=HotpotQuestionFilter(2),
                                                    preprocessor=HotpotTextLengthPreprocessor(600),
                                                    sample_train=None, sample_dev=None, sample_seed=18,
                                                    bridge_as_comparison=args.label_method == 'br-as-cp',
                                                    group_pairs_in_batches=True,
                                                    label_by_span=args.label_method == 'span',
                                                    num_distractors_in_batch=2,
                                                    max_batch_size=model.max_batch_size)

    eval = [LossEvaluator(), IterativeRelevanceEvaluator()]

    n_epochs = 80

    eval_samples = dict(dev=None, train=1500)
    if args.rank:
        eval_samples.update(dict(dev_grouped=1500, train_grouped=1500))

    params = TrainParams(
        SerializableOptimizer("Adadelta", dict(learning_rate=1.0)),
        num_epochs=n_epochs, ema=0.999, max_checkpoints_to_keep=2,
        async_encoding=8, log_period=30, eval_period=1800, save_period=1800,
        eval_samples=eval_samples, best_weights=('dev', 'iterative-relevance/second/average_precision'),
        monitor_gradients=True, clip_norm=None
    )

    if not continue_existing_run:
        trainer.start_training(data, model, params, eval, model_dir.ModelDir(out), notes, save_graph=False)
    else:
        resume_training_with(data=data, out=model_dir.ModelDir(out),
                             train_params=params, evaluators=eval, notes=notes, start_eval=True)
Example #7
def main():
    parser = argparse.ArgumentParser(description='Evaluate a model on SQuAD')
    parser.add_argument('model', help='model directory to evaluate')
    parser.add_argument('-n', '--sample_questions', type=int, default=None,
                        help="(for testing) run on a subset of questions")
    parser.add_argument('-b', '--batch_size', type=int, default=64,
                        help="Batch size, larger sizes can be faster but uses more memory")
    parser.add_argument('-s', '--step', default=None,
                        help="Weights to load, can be a checkpoint step or 'latest'")
    # parser.add_argument('-c', '--corpus', choices=["dev", "train"], default="dev")
    parser.add_argument('--no-ema', action="store_true", help="Don't use EMA weights even if they exist")
    parser.add_argument('--per-doc', action='store_true', help="Whether to test only against full doc, or against "
                                                               "distractors.")
    parser.add_argument('--save-errors', default=None, type=str)
    args = parser.parse_args()

    model_dir = ModelDir(args.model)

    corpus = SquadRelevanceCorpus()
    # if args.corpus == "dev":
    #     questions = corpus.get_dev()
    # else:
    #     questions = corpus.get_train()
    questions = corpus.get_dev()

    question_preprocessor = SquadTextLengthPreprocessor(600)
    questions = [question_preprocessor.preprocess(x) for x in questions
                 if (question_preprocessor.preprocess(x) is not None)]

    if args.sample_questions:
        questions = sorted(questions, key=lambda x: x.question_id)
        np.random.RandomState(0).shuffle(questions)
        questions = questions[:args.sample_questions]

    batcher = ClusteredBatcher(args.batch_size, multiple_contexts_len, truncate_batches=True)
    if args.per_doc:
        data = SquadFullDocumentDataset(questions, batcher, corpus.dev_title_to_document)
    else:
        data = SquadFullQuestionParagraphPairsDataset(questions, batcher)

    evaluators = [BinaryClassificationEvaluator(), RecordFineGrainedBinaryPrediction(), RecordFullRankings()]

    if args.step is not None:
        if args.step == "latest":
            checkpoint = model_dir.get_latest_checkpoint()
        else:
            checkpoint = model_dir.get_checkpoint(int(args.step))
    else:
        checkpoint = model_dir.get_best_weights()
        if checkpoint is not None:
            print("Using best weights")
        else:
            print("Using latest checkpoint")
            checkpoint = model_dir.get_latest_checkpoint()

    model = model_dir.get_model()

    evaluation = trainer.test(model, evaluators, {'dev_full': data},
                              corpus.get_resource_loader(), checkpoint, not args.no_ema, 10)['dev_full']

    if args.save_errors is not None:
        errors_dict = evaluation.per_sample['per_question_errors']

        def format_error(wrong: Tuple[BinaryQuestionAndParagraphs, float],
                         correct: Tuple[BinaryQuestionAndParagraphs, float]):
            question = ' '.join(wrong[0].question)
            qid = wrong[0].question_id
            wrong_text = ' '.join(wrong[0].paragraphs[0])
            wrong_score = wrong[1]
            correct_text = ' '.join(correct[0].paragraphs[0])
            correct_score = correct[1]
            return f"Question: {question}, ID: {qid}\n" \
                   f"Incorrect First Place: (score: {wrong_score})\n{wrong_text}\n" \
                   f"Correct Passage: (score: {correct_score})\n{correct_text}\n"

        with open(args.save_errors, 'wt') as f:
            for false_par, true_par in errors_dict.values():
                f.write(format_error(false_par, true_par))

    # Print the scalar results in a two column table
    scalars = evaluation.scalars
    cols = list(sorted(scalars.keys()))
    table = [cols]
    header = ["Metric", ""]
    table.append([("%s" % scalars[x] if x in scalars else "-") for x in cols])
    print_table([header] + transpose_lists(table))
Example #8
def main():
    parser = argparse.ArgumentParser(
        description='Full ranking evaluation on Hotpot')
    parser.add_argument('model', help='model directory to evaluate')
    parser.add_argument('-n',
                        '--sample_questions',
                        type=int,
                        default=None,
                        help="(for testing) run on a subset of questions")
    parser.add_argument(
        '-b',
        '--batch_size',
        type=int,
        default=64,
        help="Batch size, larger sizes can be faster but uses more memory")
    parser.add_argument(
        '-s',
        '--step',
        default=None,
        help="Weights to load, can be a checkpoint step or 'latest'")
    # parser.add_argument('-c', '--corpus', choices=["dev", "train"], default="dev")
    parser.add_argument('--no_ema',
                        action="store_true",
                        help="Don't use EMA weights even if they exist")
    args = parser.parse_args()

    model_dir = ModelDir(args.model)

    corpus = HotpotQuestions()
    # if args.corpus == "dev":
    #     questions = corpus.get_dev()
    # else:
    #     questions = corpus.get_train()
    questions = corpus.get_dev()

    question_preprocessor = HotpotTextLengthPreprocessor(600)
    questions = [
        question_preprocessor.preprocess(x) for x in questions
        if (question_preprocessor.preprocess(x) is not None)
    ]

    if args.sample_questions:
        questions = sorted(questions, key=lambda x: x.question_id)
        np.random.RandomState(0).shuffle(questions)
        questions = questions[:args.sample_questions]

    batcher = ClusteredBatcher(args.batch_size,
                               multiple_contexts_len,
                               truncate_batches=True)
    data = HotpotFullQuestionParagraphPairsDataset(questions, batcher)

    evaluators = [
        BinaryClassificationEvaluator(),
        RecordFineGrainedBinaryPrediction(),
        RecordFullRankings()
    ]

    if args.step is not None:
        if args.step == "latest":
            checkpoint = model_dir.get_latest_checkpoint()
        else:
            checkpoint = model_dir.get_checkpoint(int(args.step))
    else:
        checkpoint = model_dir.get_best_weights()
        if checkpoint is not None:
            print("Using best weights")
        else:
            print("Using latest checkpoint")
            checkpoint = model_dir.get_latest_checkpoint()

    model = model_dir.get_model()

    evaluation = trainer.test(model, evaluators, {'dev_full': data},
                              corpus.get_resource_loader(), checkpoint,
                              not args.no_ema, 10)['dev_full']

    # Print the scalar results in a two column table
    scalars = evaluation.scalars
    cols = list(sorted(scalars.keys()))
    table = [cols]
    header = ["Metric", ""]
    table.append([("%s" % scalars[x] if x in scalars else "-") for x in cols])
    print_table([header] + transpose_lists(table))
Example #9
title_counts = load_counts(join(LOCAL_DATA_DIR, 'hotpot', 'wiki_title_word_counts.txt'))
word_counts.update(title_counts)
voc = set(word_counts.keys())

print("Loading encoder...")

spec = QuestionAndParagraphsSpec(batch_size=None, max_num_contexts=2,
                                 max_num_question_words=None, max_num_context_words=None)
encoder = SentenceEncoderIterativeModel(model_dir_path=args.encoder_model, vocabulary=voc,
                                        spec=spec, loader=loader, use_char_inputs=False,
                                        use_ema=not args.no_ema,
                                        checkpoint=args.checkpoint)

print("Loading QA model...")
evaluators = [RecordHotpotQAPrediction(15, True, sp_prediction=True, disable_tqdm=True)]
batcher = ClusteredBatcher(64, multiple_contexts_len, truncate_batches=True)
qa_model_dir = ModelDir(args.qa_model)
checkpoint = None
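# checkpoint stays None here, so the checks below fall through to loading the latest QA-model checkpoint.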
if checkpoint == 'best':
    checkpoint = qa_model_dir.get_best_weights()
if checkpoint is not None:
    print("Using best weights")
else:
    print("Using latest checkpoint")
    checkpoint = qa_model_dir.get_latest_checkpoint()
qa_model = qa_model_dir.get_model()
assert isinstance(qa_model, AttentionQAFullHotpot)
qa_sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True), graph=tf.Graph())
qa_spec = QuestionAndParagraphsSpec(batch_size=None, max_num_contexts=1,
                                    max_num_question_words=None, max_num_context_words=None)
with qa_sess.graph.as_default():
Example #10
def main():
    parser = argparse.ArgumentParser(
        description='Full ranking evaluation on Hotpot')
    parser.add_argument('model', help='model directory to evaluate')
    parser.add_argument(
        'output',
        type=str,
        help="Store the per-paragraph results in csv format in this file, "
        "or the json prediction if in test mode")
    parser.add_argument('-n',
                        '--sample_questions',
                        type=int,
                        default=None,
                        help="(for testing) run on a subset of questions")
    parser.add_argument(
        '-b',
        '--batch_size',
        type=int,
        default=64,
        help="Batch size, larger sizes can be faster but uses more memory")
    parser.add_argument(
        '-s',
        '--step',
        default=None,
        help="Weights to load, can be a checkpoint step or 'latest'")
    parser.add_argument('-a',
                        '--answer_bound',
                        type=int,
                        default=8,
                        help="Max answer span length")
    parser.add_argument('-c',
                        '--corpus',
                        choices=[
                            "distractors", "gold", "hotpot_file",
                            "retrieval_file", "top_titles"
                        ],
                        default="distractors")
    parser.add_argument('-t',
                        '--tokens',
                        type=int,
                        default=None,
                        help="Max tokens per a paragraph")
    parser.add_argument('--input_file', type=str, default=None)
    parser.add_argument('--docs_file', type=str, default=None)
    parser.add_argument('--num_workers',
                        type=int,
                        default=16,
                        help='Number of workers for tokenizing')
    parser.add_argument('--no_ema',
                        action="store_true",
                        help="Don't use EMA weights even if they exist")
    parser.add_argument('--no_sp',
                        action="store_true",
                        help="Don't predict supporting facts")
    parser.add_argument('--test_mode',
                        action='store_true',
                        help="produce a prediction file, no answers given")
    args = parser.parse_args()

    model_dir = ModelDir(args.model)
    batcher = ClusteredBatcher(args.batch_size,
                               multiple_contexts_len,
                               truncate_batches=True)
    loader = ResourceLoader()

    if args.corpus not in {"distractors", "gold"} and args.input_file is None:
        raise ValueError(
            "Must pass an input file if not using precomputed dataset")

    if args.corpus in {"distractors", "gold"} and args.test_mode:
        raise ValueError(
            "Test mode not available in 'distractors' or 'gold' mode")

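    # 'distractors' and 'gold' evaluate on the preprocessed HotpotQA dev set; the file-based modes
    # ('hotpot_file', 'retrieval_file', 'top_titles') build paragraph pairs from --input_file.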
    if args.corpus in {"distractors", "gold"}:

        corpus = HotpotQuestions()
        loader = corpus.get_resource_loader()
        questions = corpus.get_dev()

        question_preprocessor = HotpotTextLengthPreprocessorWithSpans(
            args.tokens)
        questions = [
            question_preprocessor.preprocess(x) for x in questions
            if (question_preprocessor.preprocess(x) is not None)
        ]

        if args.sample_questions:
            questions = sorted(questions, key=lambda x: x.question_id)
            np.random.RandomState(0).shuffle(questions)
            questions = questions[:args.sample_questions]

        data = HotpotFullQADistractorsDataset(questions, batcher)
        gold_idxs = set(data.gold_idxs)
        if args.corpus == 'gold':
            data.samples = [data.samples[i] for i in data.gold_idxs]
        qid2samples = {}
        qid2idx = {}
        for i, sample in enumerate(data.samples):
            key = sample.question_id
            if key in qid2samples:
                qid2samples[key].append(sample)
                qid2idx[key].append(i)
            else:
                qid2samples[key] = [sample]
                qid2idx[key] = [i]
        questions = []
        print("Ranking pairs...")
        gold_ranks = []
        for qid, samples in tqdm(qid2samples.items()):
            question = " ".join(samples[0].question)
            pars = [" ".join(x.paragraphs[0]) for x in samples]
            ranks = get_paragraph_ranks(question, pars)
            for sample, rank, idx in zip(samples, ranks, qid2idx[qid]):
                questions.append(
                    RankedQAPair(question=sample.question,
                                 paragraphs=sample.paragraphs,
                                 spans=np.zeros((0, 2), dtype=np.int32),
                                 question_id=sample.question_id,
                                 answer=sample.answer,
                                 rank=rank,
                                 q_type=sample.q_type,
                                 sentence_segments=sample.sentence_segments))
                if idx in gold_idxs:
                    gold_ranks.append(rank + 1)
        print(f"Mean rank: {np.mean(gold_ranks)}")
        ranks_counter = Counter(gold_ranks)
        for i in range(1, 11):
            print(f"Hits at {i}: {ranks_counter[i]}")

    elif args.corpus == 'hotpot_file':  # a hotpot json format input file. We rank the pairs with tf-idf
        with open(args.input_file, 'r') as f:
            hotpot_data = json.load(f)
        if args.sample_questions:
            hotpot_data = sorted(hotpot_data, key=lambda x: x['_id'])
            np.random.RandomState(0).shuffle(hotpot_data)
            hotpot_data = hotpot_data[:args.sample_questions]
        title2sentences = {
            context[0]: context[1]
            for q in hotpot_data for context in q['context']
        }
        question_tok_texts = tokenize_texts(
            [q['question'] for q in hotpot_data], num_workers=args.num_workers)
        sentences_tok = tokenize_texts(list(title2sentences.values()),
                                       num_workers=args.num_workers,
                                       sentences=True)
        if args.tokens is not None:
            sentences_tok = [
                truncate_paragraph(p, args.tokens) for p in sentences_tok
            ]
        title2tok_sents = {
            title: sentences
            for title, sentences in zip(title2sentences.keys(), sentences_tok)
        }
        questions = []
        for idx, question in enumerate(tqdm(hotpot_data,
                                            desc='tf-idf ranking')):
            q_titles = [title for title, _ in question['context']]
            par_pairs = [(title1, title2) for i, title1 in enumerate(q_titles)
                         for title2 in q_titles[i + 1:]]
            if len(par_pairs) == 0:
                continue
            ranks = get_paragraph_ranks(question['question'], [
                ' '.join(title2sentences[t1] + title2sentences[t2])
                for t1, t2 in par_pairs
            ])
            for rank, par_pair in zip(ranks, par_pairs):
                sent_tok_pair = title2tok_sents[par_pair[0]] + title2tok_sents[
                    par_pair[1]]
                sentence_segments, _ = get_segments_from_sentences_fix_sup(
                    sent_tok_pair, np.zeros(0))
                missing_sent_idx = [[
                    i for i, sent in enumerate(title2tok_sents[title])
                    if len(sent) == 0
                ] for title in par_pair]
                questions.append(
                    RankedQAPair(
                        question=question_tok_texts[idx],
                        paragraphs=[flatten_iterable(sent_tok_pair)],
                        spans=np.zeros((0, 2), dtype=np.int32),
                        question_id=question['_id'],
                        answer='noanswer'
                        if args.test_mode else question['answer'],
                        rank=rank,
                        q_type='null' if args.test_mode else question['type'],
                        sentence_segments=[sentence_segments],
                        par_titles_num_sents=[
                            (title,
                             sum(1 for sent in title2tok_sents[title]
                                 if len(sent) > 0)) for title in par_pair
                        ],
                        missing_sent_idxs=missing_sent_idx,
                        true_sp=[]
                        if args.test_mode else question['supporting_facts']))
    elif args.corpus == 'retrieval_file' or args.corpus == 'top_titles':
        if args.docs_file is None:
            print("Using DB documents")
            doc_db = DocDB(config.DOC_DB, full_docs=False)
        else:
            with open(args.docs_file, 'r') as f:
                docs = json.load(f)
        with open(args.input_file, 'r') as f:
            retrieval_data = json.load(f)
        if args.sample_questions:
            retrieval_data = sorted(retrieval_data, key=lambda x: x['qid'])
            np.random.RandomState(0).shuffle(retrieval_data)
            retrieval_data = retrieval_data[:args.sample_questions]

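        # Paragraph names follow the '<title>_<paragraph index>' convention; the text comes either
        # from the document DB or from the provided --docs_file.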
        def parname_to_text(par_name):
            par_title = par_name_to_title(par_name)
            par_num = int(par_name.split('_')[-1])
            if args.docs_file is None:
                return doc_db.get_doc_sentences(par_title)
            return docs[par_title][par_num]

        if args.corpus == 'top_titles':
            print("Top TF-IDF!")
            for q in retrieval_data:
                top_titles = q['top_titles'][:10]
                q['paragraph_pairs'] = [(title1 + '_0', title2 + '_0')
                                        for i, title1 in enumerate(top_titles)
                                        for title2 in top_titles[i + 1:]]

        question_tok_texts = tokenize_texts(
            [q['question'] for q in retrieval_data],
            num_workers=args.num_workers)
        all_parnames = list(
            set([
                parname for q in retrieval_data
                for pair in q['paragraph_pairs'] for parname in pair
            ]))
        texts_tok = tokenize_texts([parname_to_text(x) for x in all_parnames],
                                   num_workers=args.num_workers,
                                   sentences=True)
        if args.tokens is not None:
            texts_tok = [truncate_paragraph(p, args.tokens) for p in texts_tok]
        parname2tok_text = {
            parname: text
            for parname, text in zip(all_parnames, texts_tok)
        }
        questions = []
        for idx, question in enumerate(retrieval_data):
            for rank, par_pair in enumerate(question['paragraph_pairs']):
                tok_pair = parname2tok_text[par_pair[0]] + parname2tok_text[
                    par_pair[1]]
                sentence_segments, _ = get_segments_from_sentences_fix_sup(
                    tok_pair, np.zeros(0))
                missing_sent_idx = [[
                    i for i, sent in enumerate(parname2tok_text[parname])
                    if len(sent) == 0
                ] for parname in par_pair]
                questions.append(
                    RankedQAPair(
                        question=question_tok_texts[idx],
                        paragraphs=[flatten_iterable(tok_pair)],
                        spans=np.zeros((0, 2), dtype=np.int32),
                        question_id=question['qid'],
                        answer='noanswer'
                        if args.test_mode else question['answers'][0],
                        rank=rank,
                        q_type='null' if args.test_mode else question['type'],
                        sentence_segments=[sentence_segments],
                        par_titles_num_sents=[
                            (par_name_to_title(parname),
                             sum(1 for sent in parname2tok_text[parname]
                                 if len(sent) > 0)) for parname in par_pair
                        ],
                        missing_sent_idxs=missing_sent_idx,
                        true_sp=[]
                        if args.test_mode else question['supporting_facts']))
    else:
        raise NotImplementedError()

    data = DummyDataset(questions, batcher)
    evaluators = [
        RecordHotpotQAPrediction(args.answer_bound,
                                 True,
                                 sp_prediction=not args.no_sp)
    ]

    if args.step is not None:
        if args.step == "latest":
            checkpoint = model_dir.get_latest_checkpoint()
        else:
            checkpoint = model_dir.get_checkpoint(int(args.step))
    else:
        checkpoint = model_dir.get_best_weights()
        if checkpoint is not None:
            print("Using best weights")
        else:
            print("Using latest checkpoint")
            checkpoint = model_dir.get_latest_checkpoint()

    model = model_dir.get_model()

    evaluation = trainer.test(model, evaluators, {args.corpus: data}, loader,
                              checkpoint, not args.no_ema, 10)[args.corpus]

    print("Saving result")
    output_file = args.output

    df = pd.DataFrame(evaluation.per_sample)

    df.sort_values(["question_id", "rank"], inplace=True, ascending=True)
    group_by = ["question_id"]

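    # get_ranked_scores filters the per-pair results by question type ('Br'/'Cp'), picks the
    # answer/sp/joint EM or F1 column, and returns one score per number of top-ranked pairs
    # considered (one row per N in the table printed below).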
    def get_ranked_scores(score_name):
        filtered_df = df[df.type == 'comparison'] if "Cp" in score_name else \
            df[df.type == 'bridge'] if "Br" in score_name else df
        target_prefix = 'joint' if 'joint' in score_name else 'sp' if 'sp' in score_name else 'text'
        target_score = f"{target_prefix}_{'em' if 'EM' in score_name else 'f1'}"
        return compute_ranked_scores_with_yes_no(
            filtered_df,
            span_q_col="span_question_scores",
            yes_no_q_col="yes_no_question_scores",
            yes_no_scores_col="yes_no_confidence_scores",
            span_scores_col="predicted_score",
            span_target_score=target_score,
            group_cols=group_by)

    if not args.test_mode:
        score_names = ["EM", "F1", "Br EM", "Br F1", "Cp EM", "Cp F1"]
        if not args.no_sp:
            score_names.extend([
                f"{prefix} {name}" for prefix in ['sp', 'joint']
                for name in score_names
            ])

        table = [["N Paragraphs"] + score_names]
        scores = [get_ranked_scores(score_name) for score_name in score_names]
        table += list([str(i + 1), *["%.4f" % x for x in score_vals]]
                      for i, score_vals in enumerate(zip(*scores)))
        print_table(table)

        df.to_csv(output_file, index=False)

    else:
        df_to_pred(df, output_file)
Example #11
def main():
    parser = argparse.ArgumentParser(description='Full ranking evaluation on Hotpot')
    parser.add_argument('model', help='model directory to evaluate')
    parser.add_argument('-n', '--sample_questions', type=int, default=None,
                        help="(for testing) run on a subset of questions")
    parser.add_argument('-b', '--batch_size', type=int, default=64,
                        help="Batch size, larger sizes can be faster but uses more memory")
    parser.add_argument('-s', '--step', default=None,
                        help="Weights to load, can be a checkpoint step or 'latest'")
    # parser.add_argument('-c', '--corpus', choices=["dev", "train"], default="dev")
    parser.add_argument('--no-ema', action="store_true", help="Don't use EMA weights even if they exist")
    parser.add_argument('--save-errors', default=None, type=str)
    parser.add_argument('--br-as-cp', action='store_true')
    parser.add_argument('--mult-probs', action='store_true')
    args = parser.parse_args()

    model_dir = ModelDir(args.model)

    corpus = HotpotQuestions()
    # if args.corpus == "dev":
    #     questions = corpus.get_dev()
    # else:
    #     questions = corpus.get_train()
    questions = corpus.get_dev()

    question_preprocessor = HotpotTextLengthPreprocessor(600)
    questions = [question_preprocessor.preprocess(x) for x in questions
                 if (question_preprocessor.preprocess(x) is not None)]

    if args.sample_questions:
        questions = sorted(questions, key=lambda x: x.question_id)
        np.random.RandomState(0).shuffle(questions)
        questions = questions[:args.sample_questions]

    batcher = ClusteredBatcher(args.batch_size, multiple_contexts_len, truncate_batches=True)
    data = HotpotFullIterativeDataset(questions, batcher, bridge_as_comparison=args.br_as_cp)

    evaluators = [IterativeRelevanceEvaluator(), RecordFullIterativeRankings(multiply_iteration_probs=args.mult_probs)]

    if args.step is not None:
        if args.step == "latest":
            checkpoint = model_dir.get_latest_checkpoint()
        else:
            checkpoint = model_dir.get_checkpoint(int(args.step))
    else:
        checkpoint = model_dir.get_best_weights()
        if checkpoint is not None:
            print("Using best weights")
        else:
            print("Using latest checkpoint")
            checkpoint = model_dir.get_latest_checkpoint()

    model = model_dir.get_model()

    evaluation = trainer.test(model, evaluators, {'dev_full': data},
                              corpus.get_resource_loader(), checkpoint, not args.no_ema, 10)['dev_full']

    if args.save_errors is not None:
        first_errors_dict = evaluation.per_sample['per_question_first_errors']
        second_errors_dict = evaluation.per_sample['per_question_second_errors']

        def format_first_error(wrong: Tuple[IterativeQuestionAndParagraphs, float, float],
                               correct: Tuple[IterativeQuestionAndParagraphs, float, float]):
            question = ' '.join(wrong[0].question)
            qid = wrong[0].question_id
            q_type = wrong[0].q_type
            wrong_text = ' '.join(wrong[0].paragraphs[0])
            wrong_score = wrong[1]
            correct_text = ' '.join(correct[0].paragraphs[0])
            correct_score = correct[1]
            return f"Question: {question}, ID: {qid}, type: {q_type}\n" \
                   f"Incorrect First Place: (score: {wrong_score})\n{wrong_text}\n-\n" \
                   f"Correct Passage: (score: {correct_score})\n{correct_text}\n***\n"

        def format_second_error(wrong: Tuple[IterativeQuestionAndParagraphs, float, float],
                                correct: Tuple[IterativeQuestionAndParagraphs, float, float]):
            question = ' '.join(wrong[0].question)
            qid = wrong[0].question_id
            q_type = wrong[0].q_type
            wrong_texts = [' '.join(par) for par in wrong[0].paragraphs]
            wrong_first_score = wrong[1]
            wrong_final_score = wrong[2]
            correct_texts = [' '.join(par) for par in correct[0].paragraphs]
            correct_first_score = correct[1]
            correct_final_score = correct[2]
            return f"Question: {question}, ID: {qid}, type: {q_type}\n" \
                   f"Incorrect First Place Pair: (score: {wrong_final_score})\n" \
                   f"Paragraph 1 (score: {wrong_first_score})\n" \
                   f"{wrong_texts[0]}\n" \
                   f"Paragraph 2:\n" \
                   f"{wrong_texts[1]}\n-\n" \
                   f"Correct Pair: (score: {correct_final_score})\n" \
                   f"Paragraph 1 (score: {correct_first_score})\n" \
                   f"{correct_texts[0]}\n" \
                   f"Paragraph 2:\n" \
                   f"{correct_texts[1]}\n***\n"

        with open(args.save_errors, 'wt') as f:
            f.write("First paragraph errors:\n*****************************\n")
            for false_par, true_par in first_errors_dict.values():
                f.write(format_first_error(false_par, true_par))
            f.write("Second paragraph errors:\n*****************************\n")
            for false_par, true_par in second_errors_dict.values():
                f.write(format_second_error(false_par, true_par))

    # Print the scalar results in a two column table
    scalars = evaluation.scalars
    cols = list(sorted(scalars.keys()))
    table = [cols]
    header = ["Metric", ""]
    table.append([("%s" % scalars[x] if x in scalars else "-") for x in cols])
    print_table([header] + transpose_lists(table))
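The deterministic question subsampling used across these evaluation scripts (sort by question_id for a stable order, shuffle with a fixed seed, truncate to the requested size) can be factored into a small helper. This is an illustrative sketch, not a function from the project:

import numpy as np

def sample_questions(questions, n, seed=0):
    # Sort for a stable order, shuffle with a fixed seed, then keep the first n questions.
    questions = sorted(questions, key=lambda x: x.question_id)
    np.random.RandomState(seed).shuffle(questions)
    return questions[:n]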