Exemplo n.º 1
0
class AggregatedTokenMetrics(Metrics):
    """Accuracy over flattened token ids, overall and on two non-zero-id subsets.

    Three trackers are maintained:

    * ``common`` - every reported position;
    * ``target_non_unk`` - positions whose target id is not 0;
    * ``prediction_non_unk`` - positions whose predicted id is not 0.

    NOTE(review): id 0 is presumably the <unk> token (judging by the tracker
    labels) - confirm against the vocabulary in use.
    """

    def __init__(self):
        super().__init__()
        self.common = BaseAccuracyMetrics()
        self.target_non_unk = IndexedAccuracyMetrics('Target not unk')
        self.prediction_non_unk = IndexedAccuracyMetrics('Prediction not unk')

    def drop_state(self):
        """Reset all three underlying trackers."""
        self.common.drop_state()
        self.target_non_unk.drop_state()
        self.prediction_non_unk.drop_state()

    def report(self, prediction_target):
        """Feed one batch into the trackers.

        Args:
            prediction_target: a ``(prediction, target)`` pair of tensors.
                Both are flattened with ``view(-1)`` and compared against 0
                directly, so ``prediction`` must already hold token ids
                (no argmax is applied here).
        """
        prediction, target = prediction_target
        prediction = prediction.view(-1)
        target = target.view(-1)

        self.common.report((prediction, target))

        # nonzero() on a 1-D tensor yields shape (k, 1). Squeeze only dim 1:
        # a bare squeeze() would also drop dim 0 when k == 1 and hand a
        # 0-dim scalar to the indexed trackers instead of a 1-D index.
        pred_non_unk_indices = (prediction != 0).nonzero().squeeze(1)
        target_non_unk_indices = (target != 0).nonzero().squeeze(1)

        self.prediction_non_unk.report(prediction, target,
                                       pred_non_unk_indices)
        self.target_non_unk.report(prediction, target, target_non_unk_indices)

    def get_current_value(self, should_print=False):
        """Print the four aggregated values and return the overall one.

        Printed lines:

        * ``P1`` - current value of ``common``;
        * ``P2`` - hits of ``prediction_non_unk`` divided by the total
          number of positions seen by ``common`` (hits + misses);
        * ``P3`` - current value of ``target_non_unk``;
        * ``P4`` - current value of ``prediction_non_unk``.

        Args:
            should_print: accepted for interface compatibility; the summary
                is always printed and the flag is not consulted.

        Returns:
            The value reported by ``common`` (the same one printed as P1).
        """
        overall = self.common.get_current_value(False)
        print('P1 = {}'.format(overall))

        total_seen = self.common.hits + self.common.misses
        print('P2 = {}'.format(self.prediction_non_unk.metrics.hits /
                               total_seen))
        print('P3 = {}'.format(self.target_non_unk.get_current_value(False)))
        print('P4 = {}'.format(
            self.prediction_non_unk.get_current_value(False)))
        return overall
Exemplo n.º 2
0
class TerminalAccuracyMetrics(Metrics):
    """Terminal prediction accuracy, overall and broken down by token subsets.

    Besides the overall accuracy, four indexed trackers are kept (filled
    only when ``self.is_train`` is falsy):

    * targets equal to ``EMPTY_TOKEN_ID``;
    * targets different from ``EMPTY_TOKEN_ID``;
    * among non-empty targets, those different from ``UNKNOWN_TOKEN_ID``;
    * among non-empty targets, those where the model's prediction differs
      from ``UNKNOWN_TOKEN_ID``.

    Args:
        dim: dimension of the raw prediction tensor over which the maximum
            is taken to obtain predicted ids (default 2).
    """

    def __init__(self, dim=2):
        super().__init__()
        self.dim = dim
        self.general_accuracy = BaseAccuracyMetrics()
        self.empty_accuracy = IndexedAccuracyMetrics(
            label='Accuracy on terminals that ground truth is <empty>'
        )
        self.non_empty_accuracy = IndexedAccuracyMetrics(
            label='Accuracy on terminals that ground truth is not <empty>'
        )
        self.ground_not_unk_accuracy = IndexedAccuracyMetrics(
            label='Accuracy on terminals that ground truth is not <unk> (and ground truth is not <empty>)'
        )
        self.model_not_unk_accuracy = IndexedAccuracyMetrics(
            label='Accuracy on terminals that model predicted to non <unk> (and ground truth is not <empty>)'
        )

    def drop_state(self):
        """Reset every underlying tracker."""
        self.general_accuracy.drop_state()
        self.empty_accuracy.drop_state()
        self.non_empty_accuracy.drop_state()
        self.ground_not_unk_accuracy.drop_state()
        self.model_not_unk_accuracy.drop_state()

    def report(self, prediction_target):
        """Feed one batch into the trackers.

        Args:
            prediction_target: a ``(prediction, target)`` pair. Predicted
                ids are the argmax of ``prediction`` over ``self.dim``;
                both predicted ids and ``target`` are then flattened.
        """
        prediction, target = prediction_target
        _, predicted = torch.max(prediction, dim=self.dim)
        predicted = predicted.view(-1)
        target = target.view(-1)

        self.general_accuracy.report((predicted, target))

        # The detailed breakdown is collected only outside of training.
        if self.is_train:
            return

        # torch.nonzero on a 1-D tensor yields shape (k, 1). squeeze(1)
        # always leaves a 1-D index; a bare squeeze() would produce a
        # 0-dim scalar whenever exactly one position matches.
        #
        # The empty subset uses EMPTY_TOKEN_ID (not a literal 0) so that it
        # is always the exact complement of the non-empty subset below.
        empty_indexes = torch.nonzero(target == EMPTY_TOKEN_ID).squeeze(1)
        self.empty_accuracy.report(predicted, target, empty_indexes)

        non_empty_indexes = torch.nonzero(target != EMPTY_TOKEN_ID).squeeze(1)
        self.non_empty_accuracy.report(predicted, target, non_empty_indexes)

        # Restrict both tensors to non-empty targets; the indexes computed
        # from here on refer to positions within these filtered tensors.
        predicted = torch.index_select(predicted, 0, non_empty_indexes)
        target = torch.index_select(target, 0, non_empty_indexes)

        ground_not_unk_indexes = torch.nonzero(target != UNKNOWN_TOKEN_ID).squeeze(1)
        self.ground_not_unk_accuracy.report(predicted, target, ground_not_unk_indexes)

        model_not_unk_indexes = torch.nonzero(predicted != UNKNOWN_TOKEN_ID).squeeze(1)
        self.model_not_unk_accuracy.report(predicted, target, model_not_unk_indexes)

    def get_current_value(self, should_print=False):
        """Return the overall accuracy, optionally printing the breakdown.

        Args:
            should_print: forwarded to the overall tracker; when true and
                ``self.is_train`` is falsy, the four subset trackers are
                asked to print their values as well.

        Returns:
            The value reported by ``general_accuracy``.
        """
        general_accuracy = self.general_accuracy.get_current_value(should_print=should_print)
        if (not self.is_train) and should_print:
            self.empty_accuracy.get_current_value(should_print=True)
            self.non_empty_accuracy.get_current_value(should_print=True)
            self.ground_not_unk_accuracy.get_current_value(should_print=True)
            self.model_not_unk_accuracy.get_current_value(should_print=True)
        return general_accuracy