class AggregatedTokenMetrics(Metrics):
    """Token-level accuracy, split by whether the prediction / target id is non-zero.

    ``get_current_value`` prints four figures:

    * P1 - accuracy over all reported tokens (``self.common``);
    * P2 - hits counted by ``self.prediction_non_unk`` divided by the total
      number of tokens seen by ``self.common`` (hits + misses);
    * P3 - ``self.target_non_unk``: accuracy restricted to positions whose
      target id is non-zero;
    * P4 - ``self.prediction_non_unk``: accuracy restricted to positions whose
      predicted id is non-zero.

    NOTE(review): id 0 is presumably the <unk> token in this vocabulary (the
    sub-metric labels say "not unk"), but the code compares against the
    literal 0 rather than a named constant -- confirm against the vocabulary.
    """

    def __init__(self):
        super().__init__()
        self.common = BaseAccuracyMetrics()
        self.target_non_unk = IndexedAccuracyMetrics('Target not unk')
        self.prediction_non_unk = IndexedAccuracyMetrics('Prediction not unk')

    def drop_state(self):
        """Reset the accumulated state of every sub-metric."""
        self.common.drop_state()
        self.target_non_unk.drop_state()
        self.prediction_non_unk.drop_state()

    def report(self, prediction_target):
        """Accumulate statistics for one batch.

        Args:
            prediction_target: a ``(prediction, target)`` pair of tensors of
                token ids. Both are flattened with ``view(-1)`` before use, so
                they are assumed to have the same number of elements.
        """
        prediction, target = prediction_target
        prediction = prediction.view(-1)
        target = target.view(-1)

        self.common.report((prediction, target))

        # nonzero() on a 1-D tensor yields an (N, 1) tensor. view(-1) always gives
        # an N-element index vector, whereas squeeze() would collapse the N == 1
        # case into a 0-d scalar.
        pred_non_unk_indices = (prediction != 0).nonzero().view(-1)
        target_non_unk_indices = (target != 0).nonzero().view(-1)

        self.prediction_non_unk.report(prediction, target, pred_non_unk_indices)
        self.target_non_unk.report(prediction, target, target_non_unk_indices)

    def get_current_value(self, should_print=False):
        """Print P1-P4 and return the overall accuracy (P1).

        Args:
            should_print: accepted for interface compatibility with the other
                metrics; the four figures are always printed regardless.

        Returns:
            The value of ``self.common.get_current_value(False)`` (P1).
        """
        overall = self.common.get_current_value(False)
        print('P1 = {}'.format(overall))

        total = self.common.hits + self.common.misses
        # Before the first report() nothing has been counted, so the ratio is
        # undefined; report NaN instead of raising ZeroDivisionError.
        if total:
            p2 = self.prediction_non_unk.metrics.hits / total
        else:
            p2 = float('nan')
        print('P2 = {}'.format(p2))

        print('P3 = {}'.format(self.target_non_unk.get_current_value(False)))
        print('P4 = {}'.format(self.prediction_non_unk.get_current_value(False)))
        return overall
class TerminalAccuracyMetrics(Metrics):
    """Accuracy of terminal predictions, with per-category breakdowns.

    ``general_accuracy`` is updated on every ``report``. When ``self.is_train``
    is false, four extra breakdowns are collected as well:

    * ``empty_accuracy`` - positions whose target is ``EMPTY_TOKEN_ID``;
    * ``non_empty_accuracy`` - positions whose target is not ``EMPTY_TOKEN_ID``;
    * ``ground_not_unk_accuracy`` - among non-empty targets, positions whose
      target is not ``UNKNOWN_TOKEN_ID``;
    * ``model_not_unk_accuracy`` - among non-empty targets, positions whose
      prediction is not ``UNKNOWN_TOKEN_ID``.
    """

    def __init__(self, dim=2):
        """
        Args:
            dim: dimension of the score tensor over which ``torch.max`` picks
                the predicted id (presumably the vocabulary axis of
                (batch, seq, vocab) scores when 2 -- confirm with callers).
        """
        super().__init__()
        self.dim = dim
        self.general_accuracy = BaseAccuracyMetrics()
        self.empty_accuracy = IndexedAccuracyMetrics(
            label='Accuracy on terminals that ground truth is <empty>'
        )
        self.non_empty_accuracy = IndexedAccuracyMetrics(
            label='Accuracy on terminals that ground truth is not <empty>'
        )
        self.ground_not_unk_accuracy = IndexedAccuracyMetrics(
            label='Accuracy on terminals that ground truth is not <unk> (and ground truth is not <empty>)'
        )
        self.model_not_unk_accuracy = IndexedAccuracyMetrics(
            label='Accuracy on terminals that model predicted to non <unk> (and ground truth is not <empty>)'
        )

    def drop_state(self):
        """Reset the accumulated state of every sub-metric."""
        self.general_accuracy.drop_state()
        self.empty_accuracy.drop_state()
        self.non_empty_accuracy.drop_state()
        self.ground_not_unk_accuracy.drop_state()
        self.model_not_unk_accuracy.drop_state()

    def report(self, prediction_target):
        """Accumulate statistics for one batch.

        Args:
            prediction_target: a ``(prediction, target)`` pair. ``prediction``
                holds scores; the predicted id is the argmax over ``self.dim``.
                ``target`` holds ground-truth ids. Both are flattened to 1-D.
        """
        prediction, target = prediction_target
        _, predicted = torch.max(prediction, dim=self.dim)
        predicted = predicted.view(-1)
        target = target.view(-1)

        self.general_accuracy.report((predicted, target))

        if not self.is_train:
            # nonzero() on a 1-D tensor returns (N, 1). view(-1) keeps an
            # N-element index vector even when N == 1, where squeeze() would
            # produce a 0-d scalar.
            # Use EMPTY_TOKEN_ID for both masks so that the empty and non-empty
            # breakdowns partition the positions (the empty mask previously
            # compared against a hard-coded 0).
            empty_indexes = torch.nonzero(target == EMPTY_TOKEN_ID).view(-1)
            self.empty_accuracy.report(predicted, target, empty_indexes)

            non_empty_indexes = torch.nonzero(target != EMPTY_TOKEN_ID).view(-1)
            self.non_empty_accuracy.report(predicted, target, non_empty_indexes)

            # The <unk> breakdowns are computed only over non-empty targets.
            predicted = torch.index_select(predicted, 0, non_empty_indexes)
            target = torch.index_select(target, 0, non_empty_indexes)

            ground_not_unk_indexes = torch.nonzero(target != UNKNOWN_TOKEN_ID).view(-1)
            self.ground_not_unk_accuracy.report(predicted, target, ground_not_unk_indexes)

            model_not_unk_indexes = torch.nonzero(predicted != UNKNOWN_TOKEN_ID).view(-1)
            self.model_not_unk_accuracy.report(predicted, target, model_not_unk_indexes)

    def get_current_value(self, should_print=False):
        """Return the overall accuracy, optionally printing the breakdowns.

        Args:
            should_print: forwarded to ``general_accuracy``. When true and
                ``self.is_train`` is false, the four breakdown metrics are
                printed as well.

        Returns:
            The value of ``self.general_accuracy.get_current_value(...)``.
        """
        general_accuracy = self.general_accuracy.get_current_value(should_print=should_print)
        if (not self.is_train) and should_print:
            self.empty_accuracy.get_current_value(should_print=True)
            self.non_empty_accuracy.get_current_value(should_print=True)
            self.ground_not_unk_accuracy.get_current_value(should_print=True)
            self.model_not_unk_accuracy.get_current_value(should_print=True)
        return general_accuracy