    def run_train(self, train_examples, valid_examples):
        best_loss = float('inf')
        best_f1 = 0.0
        patience_tally = 0
        valid_batches = self.get_batches(valid_examples)

        for epoch in range(MAX_EPOCHS):
            if patience_tally > PATIENCE:
                print('Terminating')
                break

            self.train()
            train_batches = self.get_batches(train_examples, shuffle=True)

            train_loss = 0
            for batch_data in train_batches:
                train_loss += self.run_gradient_step(batch_data)

            self.eval()
            validation_loss = 0
            validation_predicted_labels = []
            validation_gold_labels = []
            with torch.no_grad():
                for batch_data in valid_batches:
                    b_loss, b_logprobs = self.forward(batch_data)
                    validation_loss += float(b_loss.cpu())
                    validation_predicted_labels.extend(
                        b_logprobs.argmax(-1).tolist())
                    validation_gold_labels.extend(batch_data.labels.tolist())

            validation_loss = validation_loss / len(valid_batches)
            validation_precision, validation_recall, validation_f1 = compute_score(
                validation_predicted_labels,
                validation_gold_labels,
                verbose=False)

            if validation_f1 >= best_f1:
                best_f1 = validation_f1
                torch.save(self, self.model_path)
                saved = True
                patience_tally = 0
            else:
                saved = False
                patience_tally += 1

            print('Epoch: {}'.format(epoch))
            print('Training loss: {:.3f}'.format(train_loss /
                                                 len(train_batches)))
            print('Validation loss: {:.3f}'.format(validation_loss))
            print('Validation precision: {:.3f}'.format(validation_precision))
            print('Validation recall: {:.3f}'.format(validation_recall))
            print('Validation f1: {:.3f}'.format(validation_f1))
            if saved:
                print('Saved')
            print('-----------------------------------')
            sys.stdout.flush()
    def run_train(self, train_examples, valid_examples):
        """Runs training over the entire training set across several epochs. Following each epoch,
           F1 on the validation data is computed. If the validation F1 has improved, save the model.
           Early-stopping is employed to stop training if validation hasn't improved for a certain number
           of epochs."""
        valid_batches = self.manager.get_batches(valid_examples, self.get_device())
        best_loss = float('inf')
        best_f1 = 0.0
        patience_tally = 0

        for epoch in range(MAX_EPOCHS):
            if patience_tally > PATIENCE:
                print('Terminating: {}'.format(epoch))
                break
            
            self.train()
            train_batches = self.manager.get_batches(train_examples, self.get_device(), shuffle=True)
            
            train_loss = 0
            for batch_data in train_batches:
                train_loss += self.run_gradient_step(batch_data)
        
            self.eval()
            validation_loss = 0
            validation_predicted_labels = []
            validation_gold_labels = []
            with torch.no_grad():
                for batch_data in valid_batches:
                    b_loss, b_logprobs = self.forward(batch_data)
                    validation_loss += float(b_loss.cpu())
                    validation_predicted_labels.extend(b_logprobs.argmax(-1).tolist())
                    validation_gold_labels.extend(batch_data.labels.tolist())

            validation_loss = validation_loss/len(valid_batches)
            validation_precision, validation_recall, validation_f1 = compute_score(
                validation_predicted_labels, validation_gold_labels, verbose=False)
            
            if validation_f1 >= best_f1:
                best_f1 = validation_f1
                torch.save(self, self.model_path)
                saved = True
                patience_tally = 0
            else:
                saved = False
                patience_tally += 1
            
            print('Epoch: {}'.format(epoch))
            print('Training loss: {:.3f}'.format(train_loss/len(train_batches)))
            print('Validation loss: {:.3f}'.format(validation_loss))
            print('Validation precision: {:.3f}'.format(validation_precision))
            print('Validation recall: {:.3f}'.format(validation_recall))
            print('Validation f1: {:.3f}'.format(validation_f1))
            if saved:
                print('Saved')
            print('-----------------------------------')
            sys.stdout.flush()
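
# --- Illustrative sketch, not part of the examples above ---
# The docstring above describes patience-based early stopping on validation F1.
# The standalone loop below isolates just that control flow with a dummy sequence
# of validation-F1 values in place of real training; the constants and function
# name here are assumptions for illustration only.
MAX_EPOCHS = 20
PATIENCE = 3

def train_with_early_stopping(f1_per_epoch):
    best_f1 = 0.0
    patience_tally = 0
    for epoch in range(MAX_EPOCHS):
        if patience_tally > PATIENCE:
            print('Terminating at epoch {}'.format(epoch))
            break
        f1 = f1_per_epoch[epoch]
        if f1 >= best_f1:
            best_f1 = f1          # improvement: this is where the model would be saved
            patience_tally = 0
        else:
            patience_tally += 1   # no improvement this epoch
        print('Epoch {}: f1={:.3f}, best={:.3f}'.format(epoch, f1, best_f1))

train_with_early_stopping([0.41, 0.47, 0.46, 0.52, 0.51, 0.50, 0.49, 0.48] + [0.48] * 12)
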
    def compute_metrics(self, predicted_labels, test_examples, write_file):
        gold_labels = []
        correct = 0

        print('Writing to: {}'.format(write_file))
        with open(write_file, 'w+') as f:
            for e, ex in enumerate(test_examples):
                f.write('{} {}\n'.format(ex.id, predicted_labels[e]))
                gold_label = ex.label
                if gold_label == predicted_labels[e]:
                    correct += 1
                gold_labels.append(gold_label)

        accuracy = float(correct) / len(test_examples)
        precision, recall, f1 = compute_score(predicted_labels, gold_labels,
                                              False)

        print('Precision: {}'.format(precision))
        print('Recall: {}'.format(recall))
        print('F1: {}'.format(f1))
        print('Accuracy: {}'.format(accuracy))
    def compute_metrics(self, predicted_labels, test_examples, model_name):
        """Computes evaluation metrics."""
        gold_labels = []
        correct = 0
        for e, ex in enumerate(test_examples):
            if ex.label == predicted_labels[e]:
                correct += 1
            gold_labels.append(ex.label)
        
        accuracy = float(correct)/len(test_examples)
        precision, recall, f1 = compute_score(predicted_labels, gold_labels)
        
        print('Precision: {}'.format(precision))
        print('Recall: {}'.format(recall))
        print('F1: {}'.format(f1))
        print('Accuracy: {}\n'.format(accuracy))

        write_file = os.path.join(DETECTION_DIR, '{}_detection.txt'.format(model_name))
        with open(write_file, 'w+') as f:
            for e, ex in enumerate(test_examples):
                f.write('{} {}\n'.format(ex.id, predicted_labels[e]))
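
# --- Illustrative sketch, not part of the examples above ---
# compute_score is imported elsewhere in this project and not shown here. The
# function below is only a plausible stand-in that matches how it is called
# (predicted labels, gold labels, optional verbose flag, returning binary
# precision/recall/F1); it is an assumption, not the project's implementation.
def compute_score(predicted_labels, gold_labels, verbose=True):
    tp = sum(1 for p, g in zip(predicted_labels, gold_labels) if p == 1 and g == 1)
    fp = sum(1 for p, g in zip(predicted_labels, gold_labels) if p == 1 and g == 0)
    fn = sum(1 for p, g in zip(predicted_labels, gold_labels) if p == 0 and g == 1)
    precision = tp / (tp + fp) if (tp + fp) > 0 else 0.0
    recall = tp / (tp + fn) if (tp + fn) > 0 else 0.0
    f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0.0
    if verbose:
        print('Precision: {:.3f}, Recall: {:.3f}, F1: {:.3f}'.format(precision, recall, f1))
    return precision, recall, f1

# Toy usage with hand-made labels.
print(compute_score([1, 0, 1, 1], [1, 0, 0, 1], verbose=False))  # (0.667, 1.0, 0.8)
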
Example #5
    def run_evaluation(self, test_data, rerank, model_name, method_details=None, tokenization_features=None):
        """Predicts updated comments for all comments in the test set and computes evaluation metrics."""
        self.eval()

        test_batches = self.manager.get_batches(test_data, self.get_device(), method_details=method_details,
                                                tokenization_features=tokenization_features)

        test_predictions = []
        generation_predictions = []

        gold_strs = []
        pred_strs = []
        src_strs = []

        references = []
        pred_instances = []
        inconsistency_labels = []

        with torch.no_grad():
            for b_idx, batch_data in enumerate(test_batches):
                print('Evaluating {}/{}: {}'.format(b_idx, len(test_batches), datetime.now().strftime("%m/%d/%Y %H:%M:%S")))
                sys.stdout.flush()
                pred, labels = self.beam_decode(batch_data)
                test_predictions.extend(pred)
                inconsistency_labels.extend(labels)

        print('Beam terminating: {}'.format(datetime.now().strftime("%m/%d/%Y %H:%M:%S")))

        if not rerank:
            test_predictions = [pred[0][0] for pred in test_predictions]
        else:
            print('Rerank starting: {}'.format(datetime.now().strftime("%m/%d/%Y %H:%M:%S")))
            comment_generation_model = self.get_generation_model()
            reranked_predictions = []
            formatted_beam_predictions = []
            ordered_ids = []
            test_example_cache = dict()
            old_comment_subtokens = []
            model_scores = np.zeros(len(test_predictions)*len(test_predictions[0]), dtype=np.float64)  # np.float was removed from NumPy; use float64

            for i in range(len(test_predictions)):
                for b, (b_pred, b_score) in enumerate(test_predictions[i]):
                    try:
                        # Format the candidate against this example's old comment subtokens
                        # (mirrors the non-rerank path below).
                        b_pred_str = diff_utils.format_minimal_diff_spans(
                            get_processed_comment_sequence(test_data[i].old_comment_subtokens), b_pred)
                    except Exception:
                        b_pred_str = ''
                    
                    formatted_beam_predictions.append(b_pred_str.split(' '))
                    ordered_ids.append(test_data[i].id)
                    test_example_cache[test_data[i].id] = test_data[i]
                    old_comment_subtokens.append([get_processed_comment_sequence(test_data[i].old_comment_subtokens)])
                    # Index into the flattened (num_examples x beam_size) score array.
                    model_scores[i*len(test_predictions[i]) + b] = b_score
            
            print('Rerank computing likelihood: {}'.format(datetime.now().strftime("%m/%d/%Y %H:%M:%S")))
            likelihood_scores = self.get_likelihood_scores(comment_generation_model,
                formatted_beam_predictions, ordered_ids, test_example_cache)
            print('Rerank computing METEOR: {}'.format(datetime.now().strftime("%m/%d/%Y %H:%M:%S")))
            old_meteor_scores = np.asarray(compute_sentence_meteor(old_comment_subtokens, formatted_beam_predictions))
            print('Rerank aggregating scores: {}'.format(datetime.now().strftime("%m/%d/%Y %H:%M:%S")))
            rerank_scores = MODEL_LAMBDA*model_scores + LIKELIHOOD_LAMBDA*likelihood_scores + OLD_METEOR_LAMBDA*old_meteor_scores
            rerank_scores = rerank_scores.reshape((len(test_predictions), len(test_predictions[0])))
            selected_indices = np.argsort(-rerank_scores, axis=-1)[:,0]
            print('Rerank computing final scores: {}'.format(datetime.now().strftime("%m/%d/%Y %H:%M:%S")))
            for i in range(len(test_predictions)):
                reranked_predictions.append(test_predictions[i][selected_indices[i]][0])
            test_predictions = reranked_predictions

        print('Final evaluation step starting: {}'.format(datetime.now().strftime("%m/%d/%Y %H:%M:%S")))

        predicted_labels = []
        gold_labels = []
        pseudo_predicted_labels = []
        correct = 0
        pseudo_correct = 0

        for i in range(len(test_predictions)):
            if inconsistency_labels[i] == 0:
                pred_str = get_processed_comment_str(test_data[i].old_comment_subtokens)
            else:
                pred_str = diff_utils.format_minimal_diff_spans(
                    get_processed_comment_sequence(test_data[i].old_comment_subtokens), test_predictions[i])

            gold_str = get_processed_comment_str(test_data[i].new_comment_subtokens)
            src_str = get_processed_comment_str(test_data[i].old_comment_subtokens)
            prediction = pred_str.split()
        
            gold_strs.append(gold_str)
            pred_strs.append(pred_str)
            src_strs.append(src_str)

            predicted_label = inconsistency_labels[i]
            pseudo_predicted_label = int(pred_str != src_str)
            gold_label = test_data[i].label

            if predicted_label == gold_label:
                correct += 1
            if pseudo_predicted_label == gold_label:
                pseudo_correct += 1
            
            predicted_labels.append(predicted_label)
            pseudo_predicted_labels.append(pseudo_predicted_label)
            gold_labels.append(gold_label)
            
            references.append([get_processed_comment_sequence(test_data[i].new_comment_subtokens)])
            pred_instances.append(prediction)

            print('Old comment: {}'.format(src_str))
            print('Gold comment: {}'.format(gold_str))
            print('Predicted comment: {}'.format(pred_str))
            print('Gold Comment changes: {}'.format(test_data[i].span_minimal_diff_comment_subtokens))
            print('Raw prediction: {}'.format(' '.join(test_predictions[i])))
            print('Inconsistency label: {}'.format(inconsistency_labels[i]))
            print('Pseudo inconsistency label: {}\n'.format(pseudo_predicted_label))
            try:
                print('Old code:\n{}\n'.format(get_old_code(test_data[i])))
            except Exception:
                print('Failed to print old code\n')
            try:
                print('New code:\n{}\n'.format(get_new_code(test_data[i])))
            except Exception:
                print('Failed to print new code\n')
            print('----------------------------')

        if rerank:
            prediction_file = '{}_beam_rerank.txt'.format(model_name)
            pseudo_detection_file = '{}_beam_rerank_pseudo_detection.txt'.format(model_name)
        else:
            prediction_file = '{}_beam.txt'.format(model_name)
            pseudo_detection_file = '{}_beam_pseudo_detection.txt'.format(model_name)
        
        detection_file = os.path.join(PREDICTION_DIR, '{}_detection.txt'.format(model_name))
        pseudo_detection_file = os.path.join(PREDICTION_DIR, pseudo_detection_file)

        prediction_file = os.path.join(PREDICTION_DIR, prediction_file)
        src_file = os.path.join(PREDICTION_DIR, '{}_src.txt'.format(model_name))
        ref_file = os.path.join(PREDICTION_DIR, '{}_ref.txt'.format(model_name))
        
        write_predictions(pred_strs, prediction_file)
        write_predictions(src_strs, src_file)
        write_predictions(gold_strs, ref_file)

        predicted_accuracy = compute_accuracy(gold_strs, pred_strs)
        predicted_bleu = compute_bleu(references, pred_instances)
        predicted_meteor = compute_meteor(references, pred_instances)
        predicted_sari = compute_sari(test_data, pred_instances)
        predicted_gleu = compute_gleu(test_data, src_file, ref_file, prediction_file)

        print('Update Accuracy: {}'.format(predicted_accuracy))
        print('Update BLEU: {}'.format(predicted_bleu))
        print('Update Meteor: {}'.format(predicted_meteor))
        print('Update SARI: {}'.format(predicted_sari))
        print('Update GLEU: {}\n'.format(predicted_gleu))

        if self.manager.task == 'dual':
            with open(detection_file, 'w+') as f:
                for d in range(len(predicted_labels)):
                    f.write('{} {}\n'.format(test_data[d].id, predicted_labels[d]))

            detection_precision, detection_recall, detection_f1 = compute_score(
                predicted_labels, gold_labels, False)
            print('Detection Precision: {}'.format(detection_precision))
            print('Detection Recall: {}'.format(detection_recall))
            print('Detection F1: {}'.format(detection_f1))
            print('Detection Accuracy: {}\n'.format(float(correct)/len(test_data)))

        if self.manager.task == 'update':
            # Evaluating implicit detection.
            with open(pseudo_detection_file, 'w+') as f:
                for d in range(len(pseudo_predicted_labels)):
                    f.write('{} {}\n'.format(test_data[d].id, pseudo_predicted_labels[d]))
            
            pseudo_detection_precision, pseudo_detection_recall, pseudo_detection_f1 = compute_score(
                pseudo_predicted_labels, gold_labels, False)
            print('Pseudo Detection Precision: {}'.format(pseudo_detection_precision))
            print('Pseudo Detection Recall: {}'.format(pseudo_detection_recall))
            print('Pseudo Detection F1: {}'.format(pseudo_detection_f1))
            print('Pseudo Detection Accuracy: {}\n'.format(float(pseudo_correct)/len(test_data)))
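
# --- Illustrative sketch, not part of the examples above ---
# Isolated version of the reranking step used when rerank=True: each beam
# candidate gets a weighted sum of its beam score, its likelihood under a
# comment-generation model, and its METEOR against the old comment, and the
# highest-scoring candidate per example is kept. The lambda values below are
# placeholders, not the project's constants.
import numpy as np

MODEL_LAMBDA, LIKELIHOOD_LAMBDA, OLD_METEOR_LAMBDA = 0.5, 0.3, 0.2

def rerank(model_scores, likelihood_scores, old_meteor_scores, beam_size):
    """All inputs are flat arrays of length num_examples * beam_size."""
    scores = (MODEL_LAMBDA * model_scores
              + LIKELIHOOD_LAMBDA * likelihood_scores
              + OLD_METEOR_LAMBDA * old_meteor_scores)
    # Equivalent to np.argsort(-scores, axis=-1)[:, 0] in the example above.
    return np.argmax(scores.reshape((-1, beam_size)), axis=-1)

# Toy usage: 2 examples, beam size 3.
selected = rerank(np.array([0.1, 0.4, 0.2, 0.9, 0.3, 0.5]),
                  np.array([0.2, 0.1, 0.6, 0.1, 0.2, 0.3]),
                  np.array([0.3, 0.3, 0.3, 0.4, 0.4, 0.4]),
                  beam_size=3)
print(selected)  # index of the chosen beam candidate for each example
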
    for e, example in enumerate(test_examples):
        if example.id in clean_ids:
            clean_positions.append(e)
    clean_test_examples = [test_examples[pos] for pos in clean_positions]

    eval_tuples = [(test_examples, positions, 'full'),
                   (clean_test_examples, clean_positions, 'clean')]

    for (examples, indices, test_type) in eval_tuples:
        if args.detection_output_file:
            predicted_labels = load_predicted_detection_labels(
                args.detection_output_file, indices)
            gold_labels = [ex.label for ex in examples]

            precision, recall, f1 = compute_score(predicted_labels,
                                                  gold_labels,
                                                  verbose=False)

            num_correct = 0
            for p, p_label in enumerate(predicted_labels):
                if p_label == gold_labels[p]:
                    num_correct += 1

            print('Detection Precision: {}'.format(precision))
            print('Detection Recall: {}'.format(recall))
            print('Detection F1: {}'.format(f1))
            print('Detection Accuracy: {}\n'.format(
                float(num_correct) / len(predicted_labels)))

        if args.update_output_file:
            update_strs = load_predicted_generation_sequences(