Example #1
class AggregatedTokenMetrics(Metrics):
    def __init__(self):
        super().__init__()
        self.common = BaseAccuracyMetrics()
        self.target_non_unk = IndexedAccuracyMetrics('Target not unk')
        self.prediction_non_unk = IndexedAccuracyMetrics('Prediction not unk')

    def drop_state(self):
        self.common.drop_state()
        self.target_non_unk.drop_state()
        self.prediction_non_unk.drop_state()

    def report(self, prediction_target):
        prediction, target = prediction_target
        prediction = prediction.view(-1)
        target = target.view(-1)

        self.common.report((prediction, target))

        # token id 0 is treated as the <unk> id here
        pred_non_unk_indices = (prediction != 0).nonzero().squeeze()
        target_non_unk_indices = (target != 0).nonzero().squeeze()

        self.prediction_non_unk.report(prediction, target,
                                       pred_non_unk_indices)
        self.target_non_unk.report(prediction, target, target_non_unk_indices)

    def get_current_value(self, should_print=False):
        # P1: overall token accuracy; P2: hits on non-<unk> predictions divided
        # by the total number of reported tokens; P3/P4: accuracy restricted to
        # non-<unk> targets / non-<unk> predictions respectively.
        print('P1 = {}'.format(self.common.get_current_value(False)))
        print('P2 = {}'.format(self.prediction_non_unk.metrics.hits /
                               (self.common.hits + self.common.misses)))
        print('P3 = {}'.format(self.target_non_unk.get_current_value(False)))
        print('P4 = {}'.format(
            self.prediction_non_unk.get_current_value(False)))
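The helper classes used above (Metrics, BaseAccuracyMetrics, IndexedAccuracyMetrics) are not shown on this page. The following is a minimal sketch of what their interface might look like, inferred only from how the examples call them (hits/misses counters, a metrics attribute on the indexed variant, and report/drop_state/get_current_value methods); the actual classes in the source project may differ.

import torch


class Metrics:
    """Assumed base class: the examples only read an is_train flag from it."""

    def __init__(self):
        self.is_train = True


class BaseAccuracyMetrics:
    """Assumed stand-in: elementwise hit/miss counters over reported batches."""

    def __init__(self):
        self.hits = 0
        self.misses = 0

    def drop_state(self):
        self.hits = 0
        self.misses = 0

    def report(self, prediction_target):
        prediction, target = prediction_target
        correct = (prediction == target).sum().item()
        self.hits += correct
        self.misses += prediction.numel() - correct

    def get_current_value(self, should_print=False):
        total = self.hits + self.misses
        value = self.hits / total if total else 0.0
        if should_print:
            print('Accuracy = {}'.format(value))
        return value


class IndexedAccuracyMetrics:
    """Assumed stand-in: accuracy restricted to a given index tensor."""

    def __init__(self, label):
        self.label = label
        self.metrics = BaseAccuracyMetrics()

    def drop_state(self):
        self.metrics.drop_state()

    def report(self, prediction, target, indexes):
        prediction = torch.index_select(prediction, 0, indexes)
        target = torch.index_select(target, 0, indexes)
        self.metrics.report((prediction, target))

    def get_current_value(self, should_print=False):
        value = self.metrics.get_current_value(False)
        if should_print:
            print('{}: {}'.format(self.label, value))
        return value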
Example #2
class TerminalAccuracyMetrics(Metrics):
    def __init__(self, dim=2):
        super().__init__()
        self.dim = dim
        self.general_accuracy = BaseAccuracyMetrics()
        self.empty_accuracy = IndexedAccuracyMetrics(
            label='Accuracy on terminals that ground truth is <empty>'
        )
        self.non_empty_accuracy = IndexedAccuracyMetrics(
            label='Accuracy on terminals that ground truth is not <empty>'
        )
        self.ground_not_unk_accuracy = IndexedAccuracyMetrics(
            label='Accuracy on terminals that ground truth is not <unk> (and ground truth is not <empty>)'
        )
        self.model_not_unk_accuracy = IndexedAccuracyMetrics(
            label='Accuracy on terminals that model predicted to non <unk> (and ground truth is not <empty>)'
        )

    def drop_state(self):
        self.general_accuracy.drop_state()
        self.empty_accuracy.drop_state()
        self.non_empty_accuracy.drop_state()
        self.ground_not_unk_accuracy.drop_state()
        self.model_not_unk_accuracy.drop_state()

    def report(self, prediction_target):
        prediction, target = prediction_target
        _, predicted = torch.max(prediction, dim=self.dim)
        predicted = predicted.view(-1)
        target = target.view(-1)

        self.general_accuracy.report((predicted, target))

        if not self.is_train:
            # positions where the ground truth is the <empty> terminal
            empty_indexes = torch.nonzero(target == EMPTY_TOKEN_ID).squeeze()
            self.empty_accuracy.report(predicted, target, empty_indexes)

            # positions where the ground truth is a real (non-<empty>) terminal
            non_empty_indexes = torch.nonzero(target != EMPTY_TOKEN_ID).squeeze()
            self.non_empty_accuracy.report(predicted, target, non_empty_indexes)

            # restrict both tensors to the non-<empty> positions before the <unk> checks
            predicted = torch.index_select(predicted, 0, non_empty_indexes)
            target = torch.index_select(target, 0, non_empty_indexes)

            ground_not_unk_indexes = torch.nonzero(target != UNKNOWN_TOKEN_ID).squeeze()
            self.ground_not_unk_accuracy.report(predicted, target, ground_not_unk_indexes)

            model_not_unk_indexes = torch.nonzero(predicted != UNKNOWN_TOKEN_ID).squeeze()
            self.model_not_unk_accuracy.report(predicted, target, model_not_unk_indexes)

    def get_current_value(self, should_print=False):
        general_accuracy = self.general_accuracy.get_current_value(should_print=should_print)
        if (not self.is_train) and should_print:
            self.empty_accuracy.get_current_value(should_print=True)
            self.non_empty_accuracy.get_current_value(should_print=True)
            self.ground_not_unk_accuracy.get_current_value(should_print=True)
            self.model_not_unk_accuracy.get_current_value(should_print=True)
        return general_accuracy
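A minimal usage sketch for Example #2, assuming the TerminalAccuracyMetrics class above, the stand-in helpers sketched after Example #1, and placeholder values for EMPTY_TOKEN_ID and UNKNOWN_TOKEN_ID (the real constants live in the source project and are not shown on this page):

import torch
import torch.nn.functional as F

# Placeholder vocabulary ids; the real constants are defined in the source project.
EMPTY_TOKEN_ID = 0
UNKNOWN_TOKEN_ID = 1

metrics = TerminalAccuracyMetrics(dim=2)
metrics.is_train = False  # flag assumed to come from the Metrics base class

# Toy batch: 2 sequences of length 5 over a 10-token terminal vocabulary.
target = torch.tensor([[0, 0, 2, 3, 1], [4, 1, 0, 5, 6]])
predicted_ids = torch.tensor([[0, 2, 2, 3, 1], [4, 1, 1, 5, 0]])
prediction = F.one_hot(predicted_ids, num_classes=10).float()  # argmax along dim=2 recovers predicted_ids

metrics.drop_state()
metrics.report((prediction, target))
accuracy = metrics.get_current_value(should_print=True)
print('overall terminal accuracy: {}'.format(accuracy))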
Example #3
def eval_nt(results_dir, save_dir, group=False):
    reader = ResultsReader(results_dir=results_dir)

    metrics = SequentialMetrics([
        NonTerminalsMetricsWrapper(BaseAccuracyMetrics()),
        SingleNonTerminalAccuracyMetrics(
            non_terminals_file='data/ast/non_terminals.json',
            results_dir=save_dir,
            group=group,
            dim=None)
    ])

    # run_nt_metrics(reader, metrics)

    metrics.drop_state()
    metrics.report((torch.from_numpy(reader.predicted[:, :, 0]),
                    ASTTarget(torch.from_numpy(reader.target), None)))
    metrics.get_current_value(should_print=True)
Example #5
def get_accuracy_result(results_dir):
    reader = ResultsReader(results_dir=results_dir)
    metrics = NonTerminalsMetricsWrapper(BaseAccuracyMetrics())
    run_nt_metrics(reader, metrics)