Exemplo n.º 1
0
def _run_eval_with_selector(questions, annotations, docid_2_answer,
                            reformulator_instance, selector_model, batch_size,
                            environment_fn):
    """Runs a joined eval with the reformulator and selector model.

    The inputs are processed in batches of `batch_size`. For each batch:
    1. The reformulator produces rewrites using beam search.
    2. The environment is queried for the original questions plus those
       rewrites, requesting token-level F1 scores.
    3. The selector model evaluates the result.

    Args:
      questions: Questions to evaluate on; batched together with `annotations`.
      annotations: Annotations aligned element-wise with `questions`.
      docid_2_answer: Passed through unchanged to `query_environment`.
      reformulator_instance: Object whose `reformulate()` returns, per input
        question, an iterable of responses exposing a `.reformulation` field.
      selector_model: Object whose `eval(question_and_rewrites, answers,
        scores)` result is collected once per batch.
      batch_size: Number of questions per batch.
      environment_fn: Environment reward function passed to
        `query_environment`.

    Returns:
      `np.mean` over batches of the values returned by `selector_model.eval`.
    """
    f1s = []
    for batch_id, (questions_batch, annotations_batch) in enumerate(
            batch(questions, annotations, batch_size)):
        responses = reformulator_instance.reformulate(
            questions=questions_batch,
            inference_mode=reformulator_pb2.ReformulatorRequest.BEAM_SEARCH)

        # Discard answers; keep only the rewrite strings for each question.
        reformulations = [[rf.reformulation for rf in rsp]
                          for rsp in responses]

        question_and_rewrites, answers, scores = query_environment(
            original_questions=questions_batch,
            rewrites=reformulations,
            annotations=annotations_batch,
            environment_fn=environment_fn,
            docid_2_answer=docid_2_answer,
            token_level_f1_scores=True)
        # NOTE(review): if `selector_model.eval` returns a per-batch average,
        # averaging those below weights a short final batch the same as a full
        # one -- confirm this is intended.
        f1s.append(selector_model.eval(question_and_rewrites, answers, scores))
        if FLAGS.debug and batch_id == 0:
            print('Running Eval...')
            print('Questions: {}, Annotation: {}'.format(
                safe_string(questions_batch[0]),
                safe_string(annotations_batch[0])))
            print('Rewrites: {}'.format(safe_string(reformulations[0])))
            # zip() is lazy on Python 3; materialize it so the actual pairs are
            # printed instead of '<zip object at ...>'.
            print('Answers and Scores: {}'.format(
                list(zip(safe_string(answers[0]), scores[0]))))

    return np.mean(f1s)
Exemplo n.º 2
0
def main(argv):
    """Trains the reformulator and/or selector, with periodic dev-set eval.

    What is trained is controlled by FLAGS.enable_reformulator_training and
    FLAGS.enable_selector_training.

    Every FLAGS.num_steps_per_eval steps, the current models are evaluated on
    the dev data. Eval, training and timing summaries are written to a new
    log directory under FLAGS.tensorboard_dir.

    Args:
      argv: Command-line arguments; unused.
    """
    del argv  # Unused.

    if FLAGS.debug:
        # Seed Python's `random` module so debug runs are repeatable (no other
        # RNG is seeded here).
        random.seed(0)

    reformulator_instance = reformulator.Reformulator(
        hparams_path=FLAGS.hparams_path,
        source_prefix=FLAGS.source_prefix,
        out_dir=FLAGS.out_dir,
        environment_server_address=FLAGS.environment_server_address)
    # Environment in the configured mode; used for reformulator-only eval.
    environment_fn = environment_client.make_environment_reward_fn(
        FLAGS.environment_server_address,
        mode=FLAGS.mode,
        env_call_parallelism=FLAGS.env_sample_parallelism)

    # Environment fixed to 'searchqa' mode; used for the joint eval and for
    # querying answers for selector training.
    eval_environment_fn = environment_client.make_environment_reward_fn(
        FLAGS.environment_server_address,
        mode='searchqa',
        env_call_parallelism=FLAGS.env_eval_parallelism)

    # Read data.
    questions, annotations, docid_2_answer = read_data(
        questions_file=FLAGS.train_questions,
        annotations_file=FLAGS.train_annotations,
        answers_file=FLAGS.train_data,
        preprocessing_mode=FLAGS.mode)
    dev_questions, dev_annotations, dev_docid_2_answer = read_data(
        questions_file=FLAGS.dev_questions,
        annotations_file=FLAGS.dev_annotations,
        answers_file=FLAGS.dev_data,
        preprocessing_mode=FLAGS.mode,
        max_lines=FLAGS.max_dev_examples)

    # Summary writer that writes events to a folder. TensorBoard will later
    # read from it. The timestamp suffix gives each run its own folder.
    summary_writer = tf.summary.FileWriter(
        os.path.join(
            FLAGS.tensorboard_dir,
            'reformulator_and_selector_training_log_' + str(time.time())))

    if FLAGS.enable_selector_training:
        selector_model = selector.Selector()
        last_save_step = 0

    global_step = 0
    try:
        for epoch in range(FLAGS.epochs):
            for batch_id, (questions_batch, annotations_batch) in enumerate(
                    batch(questions, annotations, FLAGS.batch_size_train)):
                # Run eval every num_steps_per_eval batches. Use ==, not `is`:
                # int identity is an implementation detail.
                if global_step % FLAGS.num_steps_per_eval == 0:
                    if FLAGS.debug:
                        print('Running eval...')
                    eval_start_time = time.time()

                    if not FLAGS.enable_selector_training:
                        eval_f1_avg = _run_reformulator_eval(
                            dev_questions, dev_annotations,
                            reformulator_instance, environment_fn,
                            FLAGS.batch_size_eval)
                    else:
                        eval_f1_avg = _run_eval_with_selector(
                            questions=dev_questions,
                            annotations=dev_annotations,
                            docid_2_answer=dev_docid_2_answer,
                            reformulator_instance=reformulator_instance,
                            selector_model=selector_model,
                            batch_size=FLAGS.batch_size_eval,
                            environment_fn=eval_environment_fn)

                    # Correct the average F1 score for deleted datapoints in
                    # the SearchQA dataset.
                    if FLAGS.mode == 'searchqa':
                        eval_f1_avg = _correct_searchqa_score(eval_f1_avg,
                                                              dataset='dev')

                    eval_time = time.time() - eval_start_time

                    misc_utils.add_summary(summary_writer,
                                           global_step,
                                           tag='eval_f1_avg',
                                           value=eval_f1_avg)
                    misc_utils.add_summary(summary_writer,
                                           global_step,
                                           tag='eval_time',
                                           value=eval_time)

                    if FLAGS.debug:
                        print('Avg F1 on dev: {}.'.format(eval_f1_avg))
                        print('Time to finish eval: {}'.format(eval_time))

                start_time = time.time()
                if FLAGS.debug:
                    print('Epoch {}, Batch {}.'.format(epoch, batch_id))
                    print('Question: [{}]; Id: {}'.format(
                        safe_string(questions_batch[0]),
                        safe_string(annotations_batch[0])))

                # Retrieve rewrites for selector training using beam search.
                if FLAGS.enable_selector_training:
                    responses_beam = reformulator_instance.reformulate(
                        questions=questions_batch,
                        inference_mode=reformulator_pb2.ReformulatorRequest.
                        BEAM_SEARCH)

                    # Discard answers; keep only the rewrite strings.
                    reformulations_beam = [[rf.reformulation for rf in rsp]
                                           for rsp in responses_beam]

                if FLAGS.enable_reformulator_training:
                    # Train reformulator model.
                    if FLAGS.debug:
                        print('Training reformulator...')

                    (reformulator_loss, f1s,
                     reformulations) = reformulator_instance.train(
                         sources=questions_batch,
                         annotations=annotations_batch)

                    f1_avg = f1s.mean()

                    if [] in reformulations:
                        if FLAGS.debug:
                            print('Found empty rewrites! Skipping this batch.')
                        # train() above has already run on this batch, so still
                        # advance the step. Otherwise the next batch reuses this
                        # step and may re-run the dev eval and re-emit its
                        # summaries at the same step.
                        global_step += 1
                        continue

                    if FLAGS.debug:
                        print('Rewrite: {}'.format(
                            safe_string(reformulations[0])))
                        print('Avg F1: {}'.format(f1_avg))
                        print('Loss  : {}'.format(reformulator_loss))

                    # Write the f1_avg and loss to Tensorboard.
                    misc_utils.add_summary(summary_writer,
                                           global_step,
                                           tag='f1_avg',
                                           value=f1_avg)
                    misc_utils.add_summary(summary_writer,
                                           global_step,
                                           tag='reformulator_loss',
                                           value=reformulator_loss)

                # Train selector model.
                if FLAGS.enable_selector_training:
                    (selector_questions, selector_answers,
                     selector_scores) = query_environment(
                         original_questions=questions_batch,
                         rewrites=reformulations_beam,
                         annotations=annotations_batch,
                         environment_fn=eval_environment_fn,
                         docid_2_answer=docid_2_answer,
                         token_level_f1_scores=False)

                    if FLAGS.debug:
                        print('Training selector...')

                    (train_selector_loss,
                     train_selector_accuracy) = selector_model.train(
                         selector_questions, selector_answers, selector_scores)

                    # Regularly save a checkpoint.
                    if (global_step - last_save_step >=
                            FLAGS.steps_per_save_selector):
                        selector_model.save(str(global_step))
                        last_save_step = global_step
                        print('Selector saved at step: {}'.format(global_step))

                    if FLAGS.debug:
                        print('Train Accuracy: {}'.format(
                            train_selector_accuracy))
                        print('Train Loss    : {}'.format(train_selector_loss))

                    # Write the accuracy and loss to Tensorboard.
                    misc_utils.add_summary(summary_writer,
                                           global_step,
                                           tag='train_selector_accuracy',
                                           value=train_selector_accuracy)
                    misc_utils.add_summary(summary_writer,
                                           global_step,
                                           tag='train_selector_loss',
                                           value=train_selector_loss)

                iteration_time = time.time() - start_time
                if FLAGS.debug:
                    print('Iteration time: {}'.format(iteration_time))

                misc_utils.add_summary(summary_writer,
                                       global_step,
                                       tag='iteration_time',
                                       value=iteration_time)

                # Increment the global counter
                global_step += 1
    finally:
        # Flush buffered events so the final summaries reach disk even if
        # training ends early (e.g. interrupted) or raises.
        summary_writer.close()