Example #1
0
class MixupAcc(AdvancedMetric):
    """Categorical accuracy for mixup training, reported under the name ``'mixup_acc'``.

    Training: ``state[torchbearer.Y_TRUE]`` is unpacked into two targets. Categorical
    accuracy is computed against each one on a shallow copy of the state. The two
    results are blended as ``lam * acc_first + (1 - lam) * acc_second``, where
    ``lam = state[torchbearer.MIXUP_LAMBDA]``.

    Validation: the wrapped categorical-accuracy node is applied to the arguments as-is.
    """

    def __init__(self):
        super(MixupAcc, self).__init__('mixup_acc')
        # Only the root node of the CategoricalAccuracy tree is stored and driven directly.
        self.cat_acc = CategoricalAccuracy().root

    def process_train(self, *args):
        """Return the lambda-weighted categorical accuracy against both mixup targets.

        ``args[0]`` is the state dict; its ``Y_TRUE`` entry must unpack into exactly
        two targets.
        """
        super(MixupAcc, self).process_train(*args)
        state = args[0]
        first_target, second_target = state[torchbearer.Y_TRUE]

        def accuracy_against(target):
            # Work on a shallow copy so the caller's state keeps its paired Y_TRUE.
            substituted = state.copy()
            substituted[torchbearer.Y_TRUE] = target
            return self.cat_acc.process(substituted)

        first_acc = accuracy_against(first_target)
        second_acc = accuracy_against(second_target)

        lam = state[torchbearer.MIXUP_LAMBDA]
        return first_acc * lam + second_acc * (1 - lam)

    def process_validate(self, *args):
        """Return plain categorical accuracy on the unmodified arguments."""
        super(MixupAcc, self).process_validate(*args)
        return self.cat_acc.process(*args)

    def reset(self, state):
        """Reset the wrapped categorical-accuracy node."""
        self.cat_acc.reset(state)
Example #2
0
    def __init__(self):
        """Start with a CategoricalAccuracy placeholder and mirror its name."""
        super(DefaultAccuracy,
              self).__init__('placeholder')  # Don't set yet, wait for reset

        # Default to CategoricalAccuracy; presumably replaced later (see "wait for reset").
        default_metric = CategoricalAccuracy()
        self.metric = default_metric
        self.name = default_metric.name
        self._loaded = False
    def test_ignore_index(self):
        """Check per-sample results of CategoricalAccuracy(ignore_index=1) against [1, 1, 0]."""
        root = CategoricalAccuracy(ignore_index=1).root  # Get root node of Tree for testing
        expected = [1, 1, 0]

        root.train()
        result = root.process(self._state)
        for i, want in enumerate(expected):
            self.assertEqual(result[i], want,
                             msg='returned: %s expected: %s in: %s' % (result[i], want, result))
 def setUp(self):
     """Build a five-sample, three-class state; the first three predictions hit the label."""
     y_true = Variable(torch.LongTensor([0, 1, 2, 2, 1]))
     y_pred = Variable(torch.FloatTensor([
         [0.9, 0.1, 0.1],  # Correct
         [0.1, 0.9, 0.1],  # Correct
         [0.1, 0.1, 0.9],  # Correct
         [0.9, 0.1, 0.1],  # Incorrect
         [0.9, 0.1, 0.1],  # Incorrect
     ]))
     self._state = {torchbearer.Y_TRUE: y_true, torchbearer.Y_PRED: y_pred}
     # Expected per-sample correctness flags, in sample order.
     self._targets = [1, 1, 1, 0, 0]
     self._metric = CategoricalAccuracy().root  # Get root node of Tree for testing
Example #5
0
class DefaultAccuracy(Metric):
    """The default accuracy metric loads in a different accuracy metric depending on the loss function or criterion in
    use at the start of training. Default for keys: `acc`, `accuracy`. The following bindings are in place for both nn
    and functional variants:

    - cross entropy loss -> :class:`.CategoricalAccuracy` [DEFAULT]
    - nll loss -> :class:`.CategoricalAccuracy`
    - mse loss -> :class:`.MeanSquaredError`
    - bce loss -> :class:`.BinaryAccuracy`
    - bce loss with logits -> :class:`.BinaryAccuracy`

    The choice is made once, on the first call to :meth:`reset`. The criterion's ``__name__`` (or, failing that, its
    class name) is looked up in ``__loss_map__``. Until then, and for unmapped criteria, :class:`.CategoricalAccuracy`
    is used. The most recent :meth:`train` / :meth:`eval` mode is re-applied to a newly selected metric, including the
    ``data_key`` passed to :meth:`eval`.
    """
    def __init__(self):
        super(DefaultAccuracy,
              self).__init__('placeholder')  # Don't set yet, wait for reset

        self.metric = CategoricalAccuracy()  # Default to CategoricalAccuracy
        self.name = self.metric.name
        self._loaded = False
        self._train = True
        # data_key from the latest eval() call. It is re-applied if reset() swaps in a different metric.
        self._data_key = None

    def train(self):
        """Put this metric (and the currently wrapped metric) into training mode."""
        self._train = True
        return self.metric.train()

    def eval(self, data_key=None):
        """Put this metric into evaluation mode, remembering ``data_key`` for a later metric swap."""
        self._train = False
        self._data_key = data_key
        return self.metric.eval(data_key=data_key)

    def process(self, *args):
        """Delegate to the wrapped metric's ``process``."""
        return self.metric.process(*args)

    def process_final(self, *args):
        """Delegate to the wrapped metric's ``process_final``."""
        return self.metric.process_final(*args)

    def reset(self, state):
        """On the first call, select the metric matching the criterion; then reset the wrapped metric.

        Args:
            state (dict): The current state. On the first call it must contain ``torchbearer.CRITERION``.
                ``torchbearer.DATA`` is read only when a new metric is loaded in eval mode and :meth:`eval` was
                called without a ``data_key``.

        Returns:
            Whatever the wrapped metric's ``reset`` returns.
        """
        if not self._loaded:
            criterion = state[torchbearer.CRITERION]

            name = None

            if hasattr(criterion, '__name__'):
                name = criterion.__name__
            elif hasattr(criterion, '__class__'):
                name = criterion.__class__.__name__

            if name is not None and name in __loss_map__:
                self.metric = __loss_map__[name]()
                self.name = self.metric.name
                if self._train:
                    self.metric.train()
                else:
                    # Honour the key the caller gave to eval(). Fall back to the state's data key only when
                    # none was given, which matches the previous behaviour for eval() with no arguments.
                    data_key = self._data_key
                    if data_key is None:
                        data_key = state[torchbearer.DATA]
                    self.metric.eval(data_key=data_key)

            self._loaded = True

        return self.metric.reset(state)
class TestCategoricalAccuracy(unittest.TestCase):
    """Per-sample checks for CategoricalAccuracy in train and eval modes."""

    def setUp(self):
        # Five samples over three classes: the first three predictions hit the
        # true label, the last two do not.
        labels = Variable(torch.LongTensor([0, 1, 2, 2, 1]))
        scores = Variable(torch.FloatTensor([
            [0.9, 0.1, 0.1],  # Correct
            [0.1, 0.9, 0.1],  # Correct
            [0.1, 0.1, 0.9],  # Correct
            [0.9, 0.1, 0.1],  # Incorrect
            [0.9, 0.1, 0.1],  # Incorrect
        ]))
        self._state = {torchbearer.Y_TRUE: labels, torchbearer.Y_PRED: scores}
        self._targets = [1, 1, 1, 0, 0]
        self._metric = CategoricalAccuracy()

    def _assert_per_sample(self, result, expected):
        """Assert ``result[i] == expected[i]`` for every index of ``expected``."""
        for i, want in enumerate(expected):
            got = result[i]
            self.assertEqual(got, want,
                             msg='returned: %s expected: %s in: %s' % (got, want, result))

    def test_train_process(self):
        self._metric.train()
        self._assert_per_sample(self._metric.process(self._state), self._targets)

    def test_validate_process(self):
        self._metric.eval()
        self._assert_per_sample(self._metric.process(self._state), self._targets)
 def setUp(self):
     """Build a five-sample, three-class state; the first three predictions hit the label."""
     y_true = Variable(torch.LongTensor([0, 1, 2, 2, 1]))
     y_pred = Variable(torch.FloatTensor([
         [0.9, 0.1, 0.1],  # Correct
         [0.1, 0.9, 0.1],  # Correct
         [0.1, 0.1, 0.9],  # Correct
         [0.9, 0.1, 0.1],  # Incorrect
         [0.9, 0.1, 0.1],  # Incorrect
     ]))
     self._state = {torchbearer.Y_TRUE: y_true, torchbearer.Y_PRED: y_pred}
     # Expected per-sample correctness flags, in sample order.
     self._targets = [1, 1, 1, 0, 0]
     self._metric = CategoricalAccuracy()
class TestCategoricalAccuracy(unittest.TestCase):
    """Per-sample checks for CategoricalAccuracy, including ``ignore_index`` handling."""

    def setUp(self):
        # Five samples over three classes: the first three predictions hit the
        # true label, the last two do not.
        labels = Variable(torch.LongTensor([0, 1, 2, 2, 1]))
        scores = Variable(
            torch.FloatTensor([
                [0.9, 0.1, 0.1],  # Correct
                [0.1, 0.9, 0.1],  # Correct
                [0.1, 0.1, 0.9],  # Correct
                [0.9, 0.1, 0.1],  # Incorrect
                [0.9, 0.1, 0.1],  # Incorrect
            ]))
        self._state = {torchbearer.Y_TRUE: labels, torchbearer.Y_PRED: scores}
        self._targets = [1, 1, 1, 0, 0]
        self._metric = CategoricalAccuracy()

    def _assert_per_sample(self, result, expected):
        """Assert ``result[i] == expected[i]`` for every index of ``expected``."""
        for i, want in enumerate(expected):
            got = result[i]
            self.assertEqual(got,
                             want,
                             msg='returned: %s expected: %s in: %s' % (got, want, result))

    @patch('torchbearer.metrics.primitives.CategoricalAccuracy')
    def test_ignore_index_args_passed(self, mock):
        # The factory must forward ignore_index unchanged to the primitive constructor.
        CategoricalAccuracyFactory(ignore_index=1).build()
        mock.assert_called_once_with(ignore_index=1)

    def test_ignore_index(self):
        metric = CategoricalAccuracy(ignore_index=1)
        metric.train()
        self._assert_per_sample(metric.process(self._state), [1, 1, 0])

    def test_train_process(self):
        self._metric.train()
        self._assert_per_sample(self._metric.process(self._state), self._targets)

    def test_validate_process(self):
        self._metric.eval()
        self._assert_per_sample(self._metric.process(self._state), self._targets)
class TestCategoricalAccuracy(unittest.TestCase):
    """Per-sample checks on the root node of the CategoricalAccuracy tree."""

    def setUp(self):
        # Five samples over three classes: the first three predictions hit the
        # true label, the last two do not.
        labels = Variable(torch.LongTensor([0, 1, 2, 2, 1]))
        scores = Variable(
            torch.FloatTensor([
                [0.9, 0.1, 0.1],  # Correct
                [0.1, 0.9, 0.1],  # Correct
                [0.1, 0.1, 0.9],  # Correct
                [0.9, 0.1, 0.1],  # Incorrect
                [0.9, 0.1, 0.1],  # Incorrect
            ]))
        self._state = {torchbearer.Y_TRUE: labels, torchbearer.Y_PRED: scores}
        self._targets = [1, 1, 1, 0, 0]
        # Get root node of Tree for testing
        self._metric = CategoricalAccuracy().root

    def _assert_per_sample(self, result, expected):
        """Assert ``result[i] == expected[i]`` for every index of ``expected``."""
        for i, want in enumerate(expected):
            got = result[i]
            self.assertEqual(got,
                             want,
                             msg='returned: %s expected: %s in: %s' % (got, want, result))

    def test_ignore_index(self):
        # Get root node of Tree for testing
        root = CategoricalAccuracy(ignore_index=1).root
        root.train()
        self._assert_per_sample(root.process(self._state), [1, 1, 0])

    def test_train_process(self):
        self._metric.train()
        self._assert_per_sample(self._metric.process(self._state), self._targets)

    def test_validate_process(self):
        self._metric.eval()
        self._assert_per_sample(self._metric.process(self._state), self._targets)
Example #10
0
 def __init__(self):
     """Name the metric 'mixup_acc' and store a CategoricalAccuracy root node on ``self.cat_acc``."""
     super(MixupAcc, self).__init__('mixup_acc')
     # Keep only the root node of a fresh CategoricalAccuracy tree.
     accuracy_tree = CategoricalAccuracy()
     self.cat_acc = accuracy_tree.root