Esempio n. 1
0
def _read_tsv_rows(path):
    """Return all rows of the tab-separated file at ``path`` as a list.

    ``path`` may start with ``~``; it is expanded before opening. The file is
    closed before this function returns.
    """
    with open(os.path.expanduser(path), newline='') as tsv_file:
        return list(csv.reader(tsv_file, delimiter='\t'))


def _write_columns_as_tsv(path, columns):
    """Write ``columns`` side by side into the tab-separated file ``path``.

    Each element of ``columns`` is a list whose first entry is the column
    label. Output row ``i`` holds entry ``i`` of every column. If the columns
    differ in length, output stops at the shortest one.
    """
    with open(os.path.expanduser(path), 'wt', newline='') as out_file:
        csv.writer(out_file, delimiter='\t').writerows(zip(*columns))


def main(args):
    """
    Compute predictions and scores for inputs using specified BERT model and LaserTagger model.

    Read input sentences from input_file_path, convert the sentences to predicted summaries using pretrained
    models whose names are specified in the list_of_models, and compute exact score and SARI score if whether_score is
    true. The predictions are stored in an output file pred.tsv. If scores are computed, the scores are stored in
    an output file score.tsv. Both output files are written to the user's home directory.

    Args:
        args: command line arguments. The attributes read here are score,
            path_to_input_file, models, grammar, abs_path_to_lasertagger,
            abs_path_to_bert, embedding_type, batch_size and masking.

    Raises:
        IndexError: if args.models is empty, or if an output/score file has
            fewer fields per line than expected.
    """

    whether_score = args.score
    input_file_path = args.path_to_input_file
    list_of_models = args.models
    whether_grammar = args.grammar

    __download_models(list_of_models)
    __validate_scripts(args)

    # Start from a clean temp folder under the home directory.
    __clean_up()
    home_dir = os.path.expanduser('~')
    os.makedirs(os.path.join(home_dir, TEMP_FOLDER_NAME), exist_ok=True)
    spaced_sentences, spaced_summaries = __preprocess_input(
        input_file_path, whether_score)

    cleaned_data_path = os.path.expanduser(TEMP_FOLDER_PATH +
                                           "/cleaned_data.tsv")
    with open(cleaned_data_path, 'wt', newline='') as out_file:
        tsv_writer = csv.writer(out_file, delimiter='\t')
        for i, sentence in enumerate(spaced_sentences):
            tsv_writer.writerow([sentence, spaced_summaries[i]])
    print("-------Number of input is", len(spaced_sentences), "-------")

    # calculate and print predictions to output file
    lasertagger_dir = os.path.expanduser(args.abs_path_to_lasertagger)
    for model in list_of_models:
        print("------Running on model", model, "-------")
        # Relative paths below are resolved against the home directory (cwd).
        prediction_command = [
            'python',
            lasertagger_dir + '/predict_main.py',
            "--input_format=wikisplit",
            "--label_map_file=./" + model + "/label_map.txt",
            "--input_file=" + "./" + TEMP_FOLDER_NAME + "/cleaned_data.tsv",
            "--saved_model=./" + model + "/export_model",
            "--vocab_file=" + os.path.expanduser(args.abs_path_to_bert) +
            "/vocab.txt",
            "--output_file=" + "./" + TEMP_FOLDER_NAME + "/output_" + model +
            ".tsv",
            "--embedding_type=" + args.embedding_type,
            "--batch_size=" + str(args.batch_size),
        ]
        if args.masking:
            prediction_command.append("--enable_masking=true")
        subprocess.call(prediction_command, cwd=home_dir)
        print("------Completed running on model", model, "-------")

    # The source sentences (field 0) and targets (field 2) are taken from the
    # first model's output file.
    first_rows = _read_tsv_rows(TEMP_FOLDER_PATH + "/output_" +
                                list_of_models[0] + ".tsv")

    # Read each output file once and post-process field 1 (the prediction);
    # the result is reused for the grammar-corrected columns.
    predictions = {}
    for model in list_of_models:
        rows = _read_tsv_rows(TEMP_FOLDER_PATH + "/output_" + model + ".tsv")
        predictions[model] = [post_processing(row[1]) for row in rows]

    columns = [["original"] + [row[0] for row in first_rows]]
    for model in list_of_models:
        columns.append([model] + predictions[model])

    if whether_grammar:
        tool = language_tool_python.LanguageTool('en-US')
        for model in list_of_models:
            corrected = [tool.correct(text) for text in predictions[model]]
            columns.append([model + "_corrected"] + corrected)

    if whether_score:
        columns.append(["target"] + [row[2] for row in first_rows])

    _write_columns_as_tsv("~/pred.tsv", columns)
    print("------Predictions written out to pred.tsv------")

    # calculate and print scores to output file if whether_score is True
    if whether_score:
        for model in list_of_models:
            print("------Calculating score for model", model, "-------")
            score_path = os.path.expanduser(TEMP_FOLDER_PATH + "/score_" +
                                            model + ".txt")
            # The scoring script's stdout is captured into the score file.
            with open(score_path, "w") as score_file:
                subprocess.call([
                    'python',
                    lasertagger_dir + '/score_main.py',
                    "--prediction_file=" + "./" + TEMP_FOLDER_NAME +
                    "/output_" + model + ".tsv",
                ],
                                cwd=home_dir,
                                stdout=score_file)

        score_columns = [[
            "score", "Exact score", "SARI score", "KEEP score",
            "ADDITION score", "DELETION score"
        ]]

        for model in list_of_models:
            score_path = os.path.expanduser(TEMP_FOLDER_PATH + "/score_" +
                                            model + ".txt")
            with open(score_path) as score_file:
                # The third whitespace-separated token of each line is kept
                # as that line's score value.
                values = [line.split()[2] for line in score_file]
            score_columns.append([model] + values)

        _write_columns_as_tsv("~/score.tsv", score_columns)
        print("------Scores written out to score.tsv------")

    __clean_up()
Esempio n. 2
0
 def test_with_redundant_marks(self):
     # A "?" directly following a "." is expected to be dropped.
     cleaned = post_processing("Test . ? Test")
     self.assertEqual(cleaned, 'Test . Test')
Esempio n. 3
0
 def test_with_leading_paired_marks_and_paired_marks_in_middle(self):
     # Leading brackets and periods are expected to be stripped, while the
     # bracket pair in the middle is kept and collapsed to "[]".
     cleaned = post_processing("[ ] . . Test [ ] Test")
     self.assertEqual(cleaned, 'Test [] Test')
Esempio n. 4
0
 def test_with_paired_marks(self):
     # A bracket pair that encloses text is expected to come back unchanged.
     cleaned = post_processing("[ Test ]")
     self.assertEqual(cleaned, '[ Test ]')
Esempio n. 5
0
 def test_with_leading_unpaired_punctuation_marks_and_other_marks(self):
     # Leading unpaired quotes and periods are expected to be removed,
     # leaving only the word.
     cleaned = post_processing("' \" . . Test")
     self.assertEqual(cleaned, 'Test')