def main(argv):
    del argv  # Unused.

    if FLAGS.debug:
        random.seed(0)

    reformulator_instance = reformulator.Reformulator(
        hparams_path=FLAGS.hparams_path,
        source_prefix='<en> <2en> ',  # Hard-coded; originally FLAGS.source_prefix.
        out_dir=FLAGS.out_dir,
        environment_server_address=FLAGS.environment_server_address)
    environment_fn = environment_client.make_environment_reward_fn(
        FLAGS.environment_server_address,
        mode=FLAGS.mode,
        env_call_parallelism=FLAGS.env_sample_parallelism)

    eval_environment_fn = environment_client.make_environment_reward_fn(
        FLAGS.environment_server_address,
        mode='searchqa',
        env_call_parallelism=FLAGS.env_eval_parallelism)
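
    # Note: environment_fn, eval_environment_fn, and the data read below are not
    # used by this inference-only example; they are only needed for training.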

    # Read data.
    questions, annotations, docid_2_answer = read_data(
        questions_file=FLAGS.train_questions,
        annotations_file=FLAGS.train_annotations,
        answers_file=FLAGS.train_data,
        preprocessing_mode=FLAGS.mode)
    dev_questions, dev_annotations, dev_docid_2_answer = read_data(
        questions_file=FLAGS.dev_questions,
        annotations_file=FLAGS.dev_annotations,
        answers_file=FLAGS.dev_data,
        preprocessing_mode=FLAGS.mode,
        max_lines=FLAGS.max_dev_examples)

    ### Inference:
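    # Pool the original question(s) together with rewrites from two successive
    # reformulation rounds; the pool is deduplicated before being written out.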
    all_reformulated_question = []
    custom_questions = [
        'What is the reimbursement policies and how to claim it?'
    ]
    ## Other sample questions: 'What are the ways to save tax?',
    ## 'How many casual leaves is one entitled to in a year?',
    ## 'Why did Katappa kill Bahubali?'
    all_reformulated_question.append(custom_questions)
    # print('custom_questions:', type(custom_questions))
    responses = reformulator_instance.reformulate(
        questions=custom_questions,
        inference_mode=reformulator_pb2.ReformulatorRequest.BEAM_SEARCH)

    ### Available inference modes: GREEDY, SAMPLING, BEAM_SEARCH, TRIE_GREEDY,
    ### TRIE_SAMPLE, TRIE_BEAM_SEARCH.

    # print('responses:', responses)
    # Discard answers.
    custom_reformulations = [[rf.reformulation for rf in rsp]
                             for rsp in responses]
    all_reformulated_question.append(custom_reformulations)
    # for i in range(len(custom_reformulations)):
    #   print('The reformulations of "', custom_questions[i], '" are:', custom_reformulations[i])

    # ----------------------------------------------------------------------------------------------
    print('-------------------- reformulations of reformulations --------------------')
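    # Second round: feed each first-round rewrite back through the reformulator
    # to produce reformulations of the reformulations.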
    # print('custom_reformulations:', type(custom_reformulations), len(custom_reformulations))
    for j in range(len(custom_reformulations)):
        # print('-----------------------reformulation of ',j,' reformulations---------------------------')
        responses_of1st_infer = reformulator_instance.reformulate(
            questions=custom_reformulations[j],
            inference_mode=reformulator_pb2.ReformulatorRequest.BEAM_SEARCH)

        custom_reformulations_of1st_infer = [[rf.reformulation for rf in rsp]
                                             for rsp in responses_of1st_infer]
        # for k in range(len(custom_reformulations_of1st_infer)):
        #     print('----------------------------------------------------')
        #     print('The reformulations of "', custom_reformulations[j][k],
        #           '" are:', custom_reformulations_of1st_infer[k])
        all_reformulated_question.append(custom_reformulations_of1st_infer)

    all_reformulated_question = flatten(all_reformulated_question)
    unique_questions = set(all_reformulated_question)
    print('all_reformulated_question:', len(all_reformulated_question),
          len(unique_questions), unique_questions)

    # Write one deduplicated question per line.
    with open('all_reformulated_question.txt', 'w') as out_file:
        for q in unique_questions:
            out_file.write(q + '\n')
Example #2
def main(argv):
    del argv  # Unused.

    if FLAGS.debug:
        random.seed(0)

    reformulator_instance = reformulator.Reformulator(
        hparams_path=FLAGS.hparams_path,
        source_prefix=FLAGS.source_prefix,
        out_dir=FLAGS.out_dir,
        environment_server_address=FLAGS.environment_server_address)
    environment_fn = environment_client.make_environment_reward_fn(
        FLAGS.environment_server_address,
        mode=FLAGS.mode,
        env_call_parallelism=FLAGS.env_sample_parallelism)

    eval_environment_fn = environment_client.make_environment_reward_fn(
        FLAGS.environment_server_address,
        mode='searchqa',
        env_call_parallelism=FLAGS.env_eval_parallelism)

    # Read data.
    questions, annotations, docid_2_answer = read_data(
        questions_file=FLAGS.train_questions,
        annotations_file=FLAGS.train_annotations,
        answers_file=FLAGS.train_data,
        preprocessing_mode=FLAGS.mode)
    dev_questions, dev_annotations, dev_docid_2_answer = read_data(
        questions_file=FLAGS.dev_questions,
        annotations_file=FLAGS.dev_annotations,
        answers_file=FLAGS.dev_data,
        preprocessing_mode=FLAGS.mode,
        max_lines=FLAGS.max_dev_examples)

    # Summary writer that writes events to a folder. TensorBoard will later read
    # from it.
    summary_writer = tf.summary.FileWriter(
        os.path.join(
            FLAGS.tensorboard_dir,
            'reformulator_and_selector_training_log_' + str(time.time())))

    if FLAGS.enable_selector_training:
        selector_model = selector.Selector()
        last_save_step = 0

    global_step = 0
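    # global_step counts training batches across epochs; it drives the eval
    # cadence, selector checkpointing, and the TensorBoard summaries below.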
    for epoch in range(FLAGS.epochs):
        for batch_id, (questions_batch, annotations_batch) in enumerate(
                batch(questions, annotations, FLAGS.batch_size_train)):
            # Run eval every num_steps_per_eval batches.
            if global_step % FLAGS.num_steps_per_eval == 0:
                if FLAGS.debug:
                    print('Running eval...')
                eval_start_time = time.time()

                if not FLAGS.enable_selector_training:
                    eval_f1_avg = _run_reformulator_eval(
                        dev_questions, dev_annotations, reformulator_instance,
                        environment_fn, FLAGS.batch_size_eval)
                else:
                    eval_f1_avg = _run_eval_with_selector(
                        questions=dev_questions,
                        annotations=dev_annotations,
                        docid_2_answer=dev_docid_2_answer,
                        reformulator_instance=reformulator_instance,
                        selector_model=selector_model,
                        batch_size=FLAGS.batch_size_eval,
                        environment_fn=eval_environment_fn)

                # Correct the average F1 score for deleted datapoints in the SearchQA
                # dataset.
                if FLAGS.mode == 'searchqa':
                    eval_f1_avg = _correct_searchqa_score(eval_f1_avg,
                                                          dataset='dev')

                eval_time = time.time() - eval_start_time

                misc_utils.add_summary(summary_writer,
                                       global_step,
                                       tag='eval_f1_avg',
                                       value=eval_f1_avg)
                misc_utils.add_summary(summary_writer,
                                       global_step,
                                       tag='eval_time',
                                       value=eval_time)

                if FLAGS.debug:
                    print('Avg F1 on dev: {}.'.format(eval_f1_avg))
                    print('Time to finish eval: {}'.format(eval_time))

            start_time = time.time()
            if FLAGS.debug:
                print('Epoch {}, Batch {}.'.format(epoch, batch_id))
                print('Question: [{}]; Id: {}'.format(questions_batch[0],
                                                      annotations_batch[0]))

            # Retrieve rewrites for selector training using beam search.
            if FLAGS.enable_selector_training:
                responses_beam = reformulator_instance.reformulate(
                    questions=questions_batch,
                    inference_mode=reformulator_pb2.ReformulatorRequest.
                    BEAM_SEARCH)

                # Discard answers.
                reformulations_beam = [[rf.reformulation for rf in rsp]
                                       for rsp in responses_beam]

            if FLAGS.enable_reformulator_training:
                # Train reformulator model.
                if FLAGS.debug:
                    print('Training reformulator...')

                reformulator_loss, f1s, reformulations = reformulator_instance.train(
                    sources=questions_batch, annotations=annotations_batch)

                f1_avg = f1s.mean()

                if [] in reformulations:
                    if FLAGS.debug:
                        print('Found empty rewrites! Skipping this batch.')
                    continue

                if FLAGS.debug:
                    print('Rewrite: {}'.format(safe_string(reformulations[0])))
                    print('Avg F1: {}'.format(f1_avg))
                    print('Loss  : {}'.format(reformulator_loss))

                # Write the f1_avg and loss to Tensorboard.
                misc_utils.add_summary(summary_writer,
                                       global_step,
                                       tag='f1_avg',
                                       value=f1_avg)
                misc_utils.add_summary(summary_writer,
                                       global_step,
                                       tag='reformulator_loss',
                                       value=reformulator_loss)

            # Train selector model.
            if FLAGS.enable_selector_training:
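                # Score each beam-search rewrite against the environment; the
                # resulting (question, answer, score) triples are the selector's
                # training examples.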
                (selector_questions, selector_answers,
                 selector_scores) = query_environment(
                     original_questions=questions_batch,
                     rewrites=reformulations_beam,
                     annotations=annotations_batch,
                     environment_fn=eval_environment_fn,
                     docid_2_answer=docid_2_answer,
                     token_level_f1_scores=False)

                if FLAGS.debug:
                    print('Training selector...')

                train_selector_loss, train_selector_accuracy = selector_model.train(
                    selector_questions, selector_answers, selector_scores)

                # Regularly save a checkpoint.
                if global_step - last_save_step >= FLAGS.steps_per_save_selector:
                    selector_model.save(str(global_step))
                    last_save_step = global_step
                    print('Selector saved at step: {}'.format(global_step))

                if FLAGS.debug:
                    print('Train Accuracy: {}'.format(train_selector_accuracy))
                    print('Train Loss    : {}'.format(train_selector_loss))

                # Write the accuracy and loss to Tensorboard.
                misc_utils.add_summary(summary_writer,
                                       global_step,
                                       tag='train_selector_accuracy',
                                       value=train_selector_accuracy)
                misc_utils.add_summary(summary_writer,
                                       global_step,
                                       tag='train_selector_loss',
                                       value=train_selector_loss)

            iteration_time = time.time() - start_time
            if FLAGS.debug:
                print('Iteration time: {}'.format(iteration_time))

            misc_utils.add_summary(summary_writer,
                                   global_step,
                                   tag='iteration_time',
                                   value=iteration_time)

            # Increment the global step counter.
            global_step += 1


## ----------------------------------------------------------------------------------------------------------

reformulator_instance = reformulator.Reformulator(
    hparams_path='px/nmt/example_configs/reformulator.json',
    source_prefix='<en> <2en> ',
    out_dir='./tmp/active-qa/reformulator/',
    environment_server_address='localhost:10000')
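
# Note: the hparams path, checkpoint directory, and server address above are
# example values; an ActiveQA environment server is assumed to be listening at
# localhost:10000.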


### Flatten an arbitrarily nested list into a single flat list.
def flatten(lst):
    return sum(([x] if not isinstance(x, list) else flatten(x) for x in lst),
               [])
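# Example: flatten([['a', ['b']], 'c']) == ['a', 'b', 'c']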


def question_paraphrase(query):
    # print("query0:", type(query), query)
    questions = query
    all_reformulated_question = []
    responses = reformulator_instance.reformulate(