def compute_sari(test_data, predictions):
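    """Averages sentence-level SARI over the test set (scaled to 0-100), using each example's
    old comment as the source and its new comment as the single reference."""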
    source_sentences = [get_processed_comment_str(ex.old_comment_subtokens) for ex in test_data]
    target_sentences = [[get_processed_comment_str(ex.new_comment_subtokens)] for ex in test_data]
    predicted_sentences = [' '.join(p) for p in predictions]

    inp = zip(source_sentences, target_sentences, predicted_sentences)
    scores = []

    for source, target, predicted in inp:
        scores.append(SARIsent(source, predicted, target))
    
    return 100*sum(scores)/float(len(scores))


def compute_unchanged(test_data, predictions):
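    """Returns the percentage of predictions that are identical to the old (source) comment."""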
    source_sentences = [get_processed_comment_str(ex.old_comment_subtokens) for ex in test_data]
    predicted_sentences = [' '.join(p) for p in predictions]
    unchanged = 0

    for source, predicted in zip(source_sentences, predicted_sentences):
        if source == predicted:
            unchanged += 1
    
    return 100*(unchanged)/len(test_data)
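
# A minimal usage sketch (hypothetical variable names, commented out): both helpers take the
# list of test examples and a parallel list of tokenized predictions, and assume each example
# exposes `old_comment_subtokens` and `new_comment_subtokens` as used above.
#
# sari = compute_sari(test_data, predictions)
# unchanged_pct = compute_unchanged(test_data, predictions)
# print('SARI: {:.2f}, unchanged: {:.2f}%'.format(sari, unchanged_pct))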
Example #3
    def run_evaluation(self, test_data, rerank, model_name, method_details=None, tokenization_features=None):
        """Predicts updated comments for all comments in the test set and computes evaluation metrics."""
        self.eval()

        test_batches = self.manager.get_batches(test_data, self.get_device(), method_details=method_details,
                                                tokenization_features=tokenization_features)

        test_predictions = []
        generation_predictions = []

        gold_strs = []
        pred_strs = []
        src_strs = []

        references = []
        pred_instances = []
        inconsistency_labels = []

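        # Beam-decode each batch with gradients disabled; every prediction is a list of
        # (candidate token sequence, score) pairs, paired with a predicted inconsistency label.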
        with torch.no_grad():
            for b_idx, batch_data in enumerate(test_batches):
                print('Evaluating {}/{}: {}'.format(b_idx, len(test_batches), datetime.now().strftime("%m/%d/%Y %H:%M:%S")))
                sys.stdout.flush()
                pred, labels = self.beam_decode(batch_data)
                test_predictions.extend(pred)
                inconsistency_labels.extend(labels)

        print('Beam terminating: {}'.format(datetime.now().strftime("%m/%d/%Y %H:%M:%S")))

        if not rerank:
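            # Without reranking, keep only the token sequence of the top-scoring beam candidate.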
            test_predictions = [pred[0][0] for pred in test_predictions]
        else:
            print('Rerank starting: {}'.format(datetime.now().strftime("%m/%d/%Y %H:%M:%S")))
            comment_generation_model = self.get_generation_model()
            reranked_predictions = []
            formatted_beam_predictions = []
            ordered_ids = []
            test_example_cache = dict()
            old_comment_subtokens = []
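            # Flat buffer of beam scores, one entry per (example, beam candidate) pair; it is
            # later combined with likelihood and METEOR scores and reshaped per example.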
            model_scores = np.zeros(len(test_predictions)*len(test_predictions[0]), dtype=np.float64)

            for i in range(len(test_predictions)):
                for b, (b_pred, b_score) in enumerate(test_predictions[i]):
                    try:
                        b_pred_str = diff_utils.format_minimal_diff_spans(
                            get_processed_comment_sequence(test_data[i].old_comment_subtokens), b_pred)
                    except Exception:
                        b_pred_str = ''
                    
                    formatted_beam_predictions.append(b_pred_str.split(' '))
                    ordered_ids.append(test_data[i].id)
                    test_example_cache[test_data[i].id] = test_data[i]
                    old_comment_subtokens.append([get_processed_comment_sequence(test_data[i].old_comment_subtokens)])
                    model_scores[i*len(test_predictions[0]) + b] = b_score
            
            print('Rerank computing likelihood: {}'.format(datetime.now().strftime("%m/%d/%Y %H:%M:%S")))
            likelihood_scores = self.get_likelihood_scores(comment_generation_model,
                formatted_beam_predictions, ordered_ids, test_example_cache)
            print('Rerank computing METEOR: {}'.format(datetime.now().strftime("%m/%d/%Y %H:%M:%S")))
            old_meteor_scores = np.asarray(compute_sentence_meteor(old_comment_subtokens, formatted_beam_predictions))
            print('Rerank aggregating scores: {}'.format(datetime.now().strftime("%m/%d/%Y %H:%M:%S")))
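            # Linearly combine the decoder's beam score, the generation model's likelihood of the
            # candidate, and METEOR against the old comment, weighted by the *_LAMBDA constants.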
            rerank_scores = MODEL_LAMBDA*model_scores + LIKELIHOOD_LAMBDA*likelihood_scores + OLD_METEOR_LAMBDA*old_meteor_scores
            rerank_scores = rerank_scores.reshape((len(test_predictions), len(test_predictions[0])))
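            # For each example, select the beam candidate with the highest combined rerank score.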
            selected_indices = np.argsort(-rerank_scores, axis=-1)[:,0]
            print('Rerank computing final scores: {}'.format(datetime.now().strftime("%m/%d/%Y %H:%M:%S")))
            for i in range(len(test_predictions)):
                reranked_predictions.append(test_predictions[i][selected_indices[i]][0])
            test_predictions = reranked_predictions

        print('Final evaluation step starting: {}'.format(datetime.now().strftime("%m/%d/%Y %H:%M:%S")))

        predicted_labels = []
        gold_labels = []
        pseudo_predicted_labels = []
        correct = 0
        pseudo_correct = 0

        for i in range(len(test_predictions)):
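            # A predicted label of 0 means the comment is judged consistent, so the old comment
            # is kept as the prediction; otherwise the generated diff spans are applied to it.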
            if inconsistency_labels[i] == 0:
                pred_str = get_processed_comment_str(test_data[i].old_comment_subtokens)
            else:
                pred_str = diff_utils.format_minimal_diff_spans(
                    get_processed_comment_sequence(test_data[i].old_comment_subtokens), test_predictions[i])

            gold_str = get_processed_comment_str(test_data[i].new_comment_subtokens)
            src_str = get_processed_comment_str(test_data[i].old_comment_subtokens)
            prediction = pred_str.split()
        
            gold_strs.append(gold_str)
            pred_strs.append(pred_str)
            src_strs.append(src_str)

            predicted_label = inconsistency_labels[i]
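            # Pseudo detection label: any difference between the predicted and old comment
            # counts as an implicitly detected inconsistency.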
            pseudo_predicted_label = int(pred_str != src_str)
            gold_label = test_data[i].label

            if predicted_label == gold_label:
                correct += 1
            if pseudo_predicted_label == gold_label:
                pseudo_correct += 1
            
            predicted_labels.append(predicted_label)
            pseudo_predicted_labels.append(pseudo_predicted_label)
            gold_labels.append(gold_label)
            
            references.append([get_processed_comment_sequence(test_data[i].new_comment_subtokens)])
            pred_instances.append(prediction)

            print('Old comment: {}'.format(src_str))
            print('Gold comment: {}'.format(gold_str))
            print('Predicted comment: {}'.format(pred_str))
            print('Gold Comment changes: {}'.format(test_data[i].span_minimal_diff_comment_subtokens))
            print('Raw prediction: {}'.format(' '.join(test_predictions[i])))
            print('Inconsistency label: {}'.format(inconsistency_labels[i]))
            print('Pseudo inconsistency label: {}\n'.format(pseudo_predicted_label))
            try:
                print('Old code:\n{}\n'.format(get_old_code(test_data[i])))
            except Exception:
                print('Failed to print old code\n')
            try:
                print('New code:\n{}\n'.format(get_new_code(test_data[i])))
            except Exception:
                print('Failed to print new code\n')
            print('----------------------------')

        if rerank:
            prediction_file = '{}_beam_rerank.txt'.format(model_name)
            pseudo_detection_file = '{}_beam_rerank_pseudo_detection.txt'.format(model_name)
        else:
            prediction_file = '{}_beam.txt'.format(model_name)
            pseudo_detection_file = '{}_beam_pseudo_detection.txt'.format(model_name)
        
        detection_file = os.path.join(PREDICTION_DIR, '{}_detection.txt'.format(model_name))
        pseudo_detection_file = os.path.join(PREDICTION_DIR, pseudo_detection_file)

        prediction_file = os.path.join(PREDICTION_DIR, prediction_file)
        src_file = os.path.join(PREDICTION_DIR, '{}_src.txt'.format(model_name))
        ref_file = os.path.join(PREDICTION_DIR, '{}_ref.txt'.format(model_name))
        
        write_predictions(pred_strs, prediction_file)
        write_predictions(src_strs, src_file)
        write_predictions(gold_strs, ref_file)

        predicted_accuracy = compute_accuracy(gold_strs, pred_strs)
        predicted_bleu = compute_bleu(references, pred_instances)
        predicted_meteor = compute_meteor(references, pred_instances)
        predicted_sari = compute_sari(test_data, pred_instances)
        predicted_gleu = compute_gleu(test_data, src_file, ref_file, prediction_file)

        print('Update Accuracy: {}'.format(predicted_accuracy))
        print('Update BLEU: {}'.format(predicted_bleu))
        print('Update Meteor: {}'.format(predicted_meteor))
        print('Update SARI: {}'.format(predicted_sari))
        print('Update GLEU: {}\n'.format(predicted_gleu))

        if self.manager.task == 'dual':
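            # Dual task: write one '<example id> <predicted label>' line per example and score
            # explicit inconsistency detection against the gold labels.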
            with open(detection_file, 'w+') as f:
                for d in range(len(predicted_labels)):
                    f.write('{} {}\n'.format(test_data[d].id, predicted_labels[d]))

            detection_precision, detection_recall, detection_f1 = compute_score(
                predicted_labels, gold_labels, False)
            print('Detection Precision: {}'.format(detection_precision))
            print('Detection Recall: {}'.format(detection_recall))
            print('Detection F1: {}'.format(detection_f1))
            print('Detection Accuracy: {}\n'.format(float(correct)/len(test_data)))

        if self.manager.task == 'update':
            # Evaluating implicit detection.
            with open(pseudo_detection_file, 'w+') as f:
                for d in range(len(pseudo_predicted_labels)):
                    f.write('{} {}\n'.format(test_data[d].id, pseudo_predicted_labels[d]))
            
            pseudo_detection_precision, pseudo_detection_recall, pseudo_detection_f1 = compute_score(
                pseudo_predicted_labels, gold_labels, False)
            print('Pseudo Detection Precision: {}'.format(pseudo_detection_precision))
            print('Pseudo Detection Recall: {}'.format(pseudo_detection_recall))
            print('Pseudo Detection F1: {}'.format(pseudo_detection_f1))
            print('Pseudo Detection Accuracy: {}\n'.format(float(pseudo_correct)/len(test_data)))

        if args.update_output_file:
            update_strs = load_predicted_generation_sequences(
                args.update_output_file, indices)

            references = []
            pred_instances = []
            src_strs = []
            gold_strs = []
            pred_strs = []

            for i in range(len(examples)):
                src_str = get_processed_comment_str(
                    examples[i].old_comment_subtokens)
                src_strs.append(src_str)

                gold_str = get_processed_comment_str(
                    examples[i].new_comment_subtokens)
                gold_strs.append(gold_str)
                references.append([gold_str.split()])

                if args.detection_output_file and predicted_labels[i] == 0:
                    pred_instances.append(src_str.split())
                    pred_strs.append(src_str)
                else:
                    pred_instances.append(update_strs[i].split())
                    pred_strs.append(update_strs[i])

            prediction_file = os.path.join(os.getcwd(), 'pred.txt')