class AggregatedTokenMetrics(Metrics):
    def __init__(self):
        super().__init__()
        self.common = BaseAccuracyMetrics()
        self.target_non_unk = IndexedAccuracyMetrics('Target not unk')
        self.prediction_non_unk = IndexedAccuracyMetrics('Prediction not unk')

    def drop_state(self):
        self.common.drop_state()
        self.target_non_unk.drop_state()
        self.prediction_non_unk.drop_state()

    def report(self, prediction_target):
        prediction, target = prediction_target
        prediction = prediction.view(-1)
        target = target.view(-1)
        self.common.report((prediction, target))

        # Positions where the prediction / the ground truth is not the <unk> token (id 0).
        pred_non_unk_indices = (prediction != 0).nonzero().squeeze()
        target_non_unk_indices = (target != 0).nonzero().squeeze()
        self.prediction_non_unk.report(prediction, target, pred_non_unk_indices)
        self.target_non_unk.report(prediction, target, target_non_unk_indices)

    def get_current_value(self, should_print=False):
        # P1: overall accuracy; P2: correct non-<unk> predictions over all tokens;
        # P3: accuracy restricted to non-<unk> targets; P4: accuracy restricted to non-<unk> predictions.
        p1 = self.common.get_current_value(False)
        print('P1 = {}'.format(p1))
        print('P2 = {}'.format(
            self.prediction_non_unk.metrics.hits / (self.common.hits + self.common.misses)
        ))
        print('P3 = {}'.format(self.target_non_unk.get_current_value(False)))
        print('P4 = {}'.format(self.prediction_non_unk.get_current_value(False)))
        return p1
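# Hypothetical usage sketch (not part of the original module). It assumes the
# <unk> token has id 0, as the comparisons above imply, and that predictions
# are already argmax-ed token ids:
#
#     metrics = AggregatedTokenMetrics()
#     prediction = torch.tensor([3, 0, 7, 2, 5, 0])  # predicted token ids
#     target = torch.tensor([3, 4, 7, 0, 5, 1])      # ground-truth token ids
#     metrics.report((prediction, target))
#     metrics.get_current_value(should_print=True)   # prints P1..P4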
class SingleNonTerminalAccuracyMetrics(Metrics):
    def __init__(self, non_terminals_file, results_dir=None, add_unk=False, dim=2):
        """
        :param non_terminals_file: file with json of non-terminals
        :param results_dir: where to save json with accuracies per non-terminal
        :param add_unk: whether to append an <unk> entry to the non-terminals list
        :param dim: dimension to run max function on for predicted values
        """
        super().__init__()
        print('Python SingleNonTerminalAccuracyMetrics created!')
        self.non_terminals = read_non_terminals(non_terminals_file)
        if add_unk:
            self.non_terminals.append('<unk>')
        self.non_terminals_number = len(self.non_terminals)
        self.results_dir = results_dir
        self.dim = dim
        # One accuracy tracker per non-terminal (placeholder label).
        self.accuracies = [
            IndexedAccuracyMetrics(label='ERROR') for _ in self.non_terminals
        ]
class TerminalAccuracyMetrics(Metrics):
    def __init__(self, dim=2):
        super().__init__()
        self.dim = dim
        self.general_accuracy = BaseAccuracyMetrics()
        self.empty_accuracy = IndexedAccuracyMetrics(
            label='Accuracy on terminals whose ground truth is <empty>'
        )
        self.non_empty_accuracy = IndexedAccuracyMetrics(
            label='Accuracy on terminals whose ground truth is not <empty>'
        )
        self.ground_not_unk_accuracy = IndexedAccuracyMetrics(
            label='Accuracy on terminals whose ground truth is not <unk> (and not <empty>)'
        )
        self.model_not_unk_accuracy = IndexedAccuracyMetrics(
            label='Accuracy on terminals the model predicted as non-<unk> (and ground truth is not <empty>)'
        )

    def drop_state(self):
        self.general_accuracy.drop_state()
        self.empty_accuracy.drop_state()
        self.non_empty_accuracy.drop_state()
        self.ground_not_unk_accuracy.drop_state()
        self.model_not_unk_accuracy.drop_state()

    def report(self, prediction_target):
        prediction, target = prediction_target
        _, predicted = torch.max(prediction, dim=self.dim)
        predicted = predicted.view(-1)
        target = target.view(-1)
        self.general_accuracy.report((predicted, target))

        if not self.is_train:
            # Positions where the ground truth is / is not the <empty> terminal.
            empty_indexes = torch.nonzero(target == EMPTY_TOKEN_ID).squeeze()
            self.empty_accuracy.report(predicted, target, empty_indexes)

            non_empty_indexes = torch.nonzero(target - EMPTY_TOKEN_ID).squeeze()
            self.non_empty_accuracy.report(predicted, target, non_empty_indexes)

            # Restrict to non-<empty> positions before checking for <unk>.
            predicted = torch.index_select(predicted, 0, non_empty_indexes)
            target = torch.index_select(target, 0, non_empty_indexes)

            ground_not_unk_indexes = torch.nonzero(target - UNKNOWN_TOKEN_ID).squeeze()
            self.ground_not_unk_accuracy.report(predicted, target, ground_not_unk_indexes)

            model_not_unk_indexes = torch.nonzero(predicted - UNKNOWN_TOKEN_ID).squeeze()
            self.model_not_unk_accuracy.report(predicted, target, model_not_unk_indexes)

    def get_current_value(self, should_print=False):
        general_accuracy = self.general_accuracy.get_current_value(should_print=should_print)
        if (not self.is_train) and should_print:
            self.empty_accuracy.get_current_value(should_print=True)
            self.non_empty_accuracy.get_current_value(should_print=True)
            self.ground_not_unk_accuracy.get_current_value(should_print=True)
            self.model_not_unk_accuracy.get_current_value(should_print=True)
        return general_accuracy
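if __name__ == '__main__':
    # Hypothetical smoke test, not part of the original sources: runs
    # TerminalAccuracyMetrics on random logits. The vocabulary size, tensor
    # shapes and the hand-picked targets (which assume EMPTY_TOKEN_ID == 0)
    # are illustrative assumptions only.
    import torch

    vocab_size = 50
    logits = torch.randn(2, 6, vocab_size)  # (batch, seq_len, vocab)
    targets = torch.tensor([[0, 3, 5, 1, 0, 7],
                            [5, 0, 1, 8, 3, 2]])

    terminal_metrics = TerminalAccuracyMetrics(dim=2)
    terminal_metrics.is_train = False  # evaluation mode so the per-bucket metrics are reported
    terminal_metrics.report((logits, targets))
    terminal_metrics.get_current_value(should_print=True)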