Ejemplo n.º 1
0
 def setUp(self):
     """Build an ``EpochLambda`` around a mocked metric plus five fake batch states.

     Each state holds a single sample: batch ``i`` has target ``[i]`` and
     prediction ``[i / 10]``, and every state uses the ``'cpu'`` device.
     """
     self._metric_function = Mock(return_value='test')
     self._metric = EpochLambda('test', self._metric_function, step_size=3)
     reset_state = {torchbearer.DEVICE: 'cpu', torchbearer.DATA_TYPE: torch.float32}
     self._metric.reset(reset_state)
     predictions = (0.0, 0.1, 0.2, 0.3, 0.4)
     self._states = [
         {
             torchbearer.BATCH: batch,
             torchbearer.Y_TRUE: torch.LongTensor([batch]),
             torchbearer.Y_PRED: torch.FloatTensor([pred]),
             torchbearer.DEVICE: 'cpu',
         }
         for batch, pred in enumerate(predictions)
     ]
Ejemplo n.º 2
0
    def _add_metric(self, state):
        """Attach an epoch-level pycm confusion-matrix metric to ``state``'s metric list.

        The wrapped function turns per-class scores into predicted labels, builds a
        ``pycm.ConfusionMatrix`` (forwarding ``self.kwargs``), and hands it to every
        callable in ``self._handlers`` together with ``state``. The existing
        ``state[torchbearer.METRIC_LIST]`` is replaced by a ``MetricList`` containing
        the old list followed by the new metric.
        """
        from pycm import ConfusionMatrix

        def _build_confusion_matrix(y_pred, y_true):
            # torch.max(..., 1) returns (values, indices); the indices along
            # dim 1 are the predicted class labels.
            predicted_labels = torch.max(y_pred, 1)[1]
            matrix = ConfusionMatrix(
                y_true.cpu().numpy(),
                predicted_labels.cpu().numpy(),
                **self.kwargs
            )
            for callback in self._handlers:
                callback(matrix, state)

        # NOTE(review): the positional ``False`` is presumably EpochLambda's
        # ``running`` flag — confirm against the EpochLambda signature.
        epoch_metric = EpochLambda('pycm', _build_confusion_matrix, False)
        epoch_metric.reset(state)
        existing_metrics = state[torchbearer.METRIC_LIST]
        state[torchbearer.METRIC_LIST] = MetricList([existing_metrics, epoch_metric])
Ejemplo n.º 3
0
    def test_not_running(self):
        """A non-running ``EpochLambda`` must not call its metric function while training."""
        non_running = EpochLambda('test', self._metric_function, running=False, step_size=6)
        non_running.reset({torchbearer.DEVICE: 'cpu', torchbearer.DATA_TYPE: torch.float32})
        non_running.train()

        # Feed the same batch-0 state twelve times; with running=False no
        # per-step call to the metric function is expected.
        first_state = self._states[0]
        for _ in range(12):
            non_running.process(first_state)

        self._metric_function.assert_not_called()
Ejemplo n.º 4
0
 def setUp(self):
     """Prepare a mocked metric function, an ``EpochLambda`` wrapping it, and five batch states."""
     mocked = Mock(return_value='test')
     self._metric_function = mocked
     self._metric = EpochLambda('test', mocked, step_size=3)
     self._metric.reset({torchbearer.DEVICE: 'cpu', torchbearer.DATA_TYPE: torch.float32})
     # (target, prediction) pairs for batches 0..4; the batch index equals the target.
     samples = [(0, 0.0), (1, 0.1), (2, 0.2), (3, 0.3), (4, 0.4)]
     self._states = [{torchbearer.BATCH: target,
                      torchbearer.Y_TRUE: torch.LongTensor([target]),
                      torchbearer.Y_PRED: torch.FloatTensor([prediction]),
                      torchbearer.DEVICE: 'cpu'}
                     for target, prediction in samples]
Ejemplo n.º 5
0
class TestEpochLambda(unittest.TestCase):
    """Tests for ``EpochLambda`` with a ``Mock`` standing in for the metric function.

    The mock records every call it receives, so the tests can inspect the
    ``(y_pred, y_true)`` tensors that ``EpochLambda`` passed to it.
    """

    def setUp(self):
        self._metric_function = Mock(return_value='test')
        self._metric = EpochLambda('test', self._metric_function, step_size=3)
        self._metric.reset({
            torchbearer.DEVICE: 'cpu',
            torchbearer.DATA_TYPE: torch.float32
        })
        # Five single-sample batches: batch i has y_true == [i], y_pred == [i / 10].
        self._states = [{
            torchbearer.BATCH: 0,
            torchbearer.Y_TRUE: torch.LongTensor([0]),
            torchbearer.Y_PRED: torch.FloatTensor([0.0]),
            torchbearer.DEVICE: 'cpu'
        }, {
            torchbearer.BATCH: 1,
            torchbearer.Y_TRUE: torch.LongTensor([1]),
            torchbearer.Y_PRED: torch.FloatTensor([0.1]),
            torchbearer.DEVICE: 'cpu'
        }, {
            torchbearer.BATCH: 2,
            torchbearer.Y_TRUE: torch.LongTensor([2]),
            torchbearer.Y_PRED: torch.FloatTensor([0.2]),
            torchbearer.DEVICE: 'cpu'
        }, {
            torchbearer.BATCH: 3,
            torchbearer.Y_TRUE: torch.LongTensor([3]),
            torchbearer.Y_PRED: torch.FloatTensor([0.3]),
            torchbearer.DEVICE: 'cpu'
        }, {
            torchbearer.BATCH: 4,
            torchbearer.Y_TRUE: torch.LongTensor([4]),
            torchbearer.Y_PRED: torch.FloatTensor([0.4]),
            torchbearer.DEVICE: 'cpu'
        }]

    def _assert_metric_call(self, call, expected_pred, expected_true):
        """Assert that one recorded metric-function call received the expected tensors.

        Args:
            call: An entry of ``Mock.call_args_list``; ``call[0]`` is the tuple of
                positional arguments, ``(y_pred, y_true)``.
            expected_pred: Expected predictions, compared element-wise within 1e-12.
            expected_true: Expected targets, compared element-wise for equality.
        """
        args = call[0]
        # ``.all()`` must be *called*: the bare ``.all`` attribute is a bound
        # method, which is always truthy and made these assertions vacuous.
        self.assertTrue(
            torch.lt(
                torch.abs(torch.add(args[0], -expected_pred)),
                1e-12).all())
        self.assertTrue(torch.eq(args[1], expected_true).all())

    def test_train(self):
        self._metric.train()
        # With step_size=3 the metric function is expected to fire on batches
        # 0 and 3, each time with every sample accumulated so far.
        expected_calls = [
            (torch.FloatTensor([0.0]), torch.LongTensor([0])),
            (torch.FloatTensor([0.0, 0.1, 0.2, 0.3]),
             torch.LongTensor([0, 1, 2, 3])),
        ]
        for state in self._states:
            self._metric.process(state)
        self.assertEqual(2, len(self._metric_function.call_args_list))
        for call, (expected_pred, expected_true) in zip(
                self._metric_function.call_args_list, expected_calls):
            self._assert_metric_call(call, expected_pred, expected_true)
        self._metric_function.reset_mock()
        self._metric.process_final({})

        # The final call sees all five samples.
        self.assertEqual(self._metric_function.call_count, 1)
        self._assert_metric_call(
            self._metric_function.call_args_list[0],
            torch.FloatTensor([0.0, 0.1, 0.2, 0.3, 0.4]),
            torch.LongTensor([0, 1, 2, 3, 4]))

    def test_validate(self):
        self._metric.eval()
        for state in self._states:
            self._metric.process(state)
        # In eval mode nothing is computed until the final validation step.
        self._metric_function.assert_not_called()
        self._metric.process_final_validate({})

        self.assertEqual(self._metric_function.call_count, 1)
        self._assert_metric_call(
            self._metric_function.call_args_list[0],
            torch.FloatTensor([0.0, 0.1, 0.2, 0.3, 0.4]),
            torch.LongTensor([0, 1, 2, 3, 4]))

    def test_not_running(self):
        metric = EpochLambda('test',
                             self._metric_function,
                             running=False,
                             step_size=6)
        metric.reset({
            torchbearer.DEVICE: 'cpu',
            torchbearer.DATA_TYPE: torch.float32
        })
        metric.train()

        for _ in range(12):
            metric.process(self._states[0])

        self._metric_function.assert_not_called()
Ejemplo n.º 6
0
 def build(self):
     """Return an ``EpochLambda`` when ``on_epoch`` is truthy, otherwise a ``BatchLambda``.

     ``on_epoch``, ``name`` and ``metric_function`` are free variables resolved
     from an enclosing scope that is not visible here.
     """
     metric_cls = EpochLambda if on_epoch else BatchLambda
     return metric_cls(name, metric_function)
Ejemplo n.º 7
0
 def decorator(metric_function):
     """Wrap ``metric_function`` in ``EpochLambda`` if ``on_epoch`` is truthy, else ``BatchLambda``.

     ``on_epoch`` and ``name`` are free variables resolved outside this function.
     """
     if not on_epoch:
         return BatchLambda(name, metric_function)
     return EpochLambda(name, metric_function)
Ejemplo n.º 8
0
class TestEpochLambda(unittest.TestCase):
    """Tests for ``EpochLambda``; a ``Mock`` metric function records the (y_pred, y_true) it receives."""

    def setUp(self):
        self._metric_function = Mock(return_value='test')
        self._metric = EpochLambda('test', self._metric_function, step_size=3)
        self._metric.reset({torchbearer.DEVICE: 'cpu', torchbearer.DATA_TYPE: torch.float32})
        # Five single-sample batches: batch i has y_true == [i], y_pred == [i / 10].
        self._states = [{torchbearer.BATCH: 0, torchbearer.Y_TRUE: torch.LongTensor([0]), torchbearer.Y_PRED: torch.FloatTensor([0.0]), torchbearer.DEVICE: 'cpu'},
                        {torchbearer.BATCH: 1, torchbearer.Y_TRUE: torch.LongTensor([1]), torchbearer.Y_PRED: torch.FloatTensor([0.1]), torchbearer.DEVICE: 'cpu'},
                        {torchbearer.BATCH: 2, torchbearer.Y_TRUE: torch.LongTensor([2]), torchbearer.Y_PRED: torch.FloatTensor([0.2]), torchbearer.DEVICE: 'cpu'},
                        {torchbearer.BATCH: 3, torchbearer.Y_TRUE: torch.LongTensor([3]), torchbearer.Y_PRED: torch.FloatTensor([0.3]), torchbearer.DEVICE: 'cpu'},
                        {torchbearer.BATCH: 4, torchbearer.Y_TRUE: torch.LongTensor([4]), torchbearer.Y_PRED: torch.FloatTensor([0.4]), torchbearer.DEVICE: 'cpu'}]

    def _assert_metric_call(self, call, expected_pred, expected_true):
        """Assert one ``call_args_list`` entry received ``(expected_pred, expected_true)``.

        Predictions (``call[0][0]``) are compared element-wise within 1e-12;
        targets (``call[0][1]``) are compared element-wise for equality.
        """
        args = call[0]
        # Call ``.all()``: a bare ``.all`` is a bound method, always truthy, so the assertion never failed.
        self.assertTrue(torch.lt(torch.abs(torch.add(args[0], -expected_pred)), 1e-12).all())
        self.assertTrue(torch.eq(args[1], expected_true).all())

    def test_train(self):
        self._metric.train()
        # step_size=3: calls are expected on batches 0 and 3 with the samples accumulated so far.
        expected_calls = [(torch.FloatTensor([0.0]), torch.LongTensor([0])),
                          (torch.FloatTensor([0.0, 0.1, 0.2, 0.3]), torch.LongTensor([0, 1, 2, 3]))]
        for state in self._states:
            self._metric.process(state)
        self.assertEqual(2, len(self._metric_function.call_args_list))
        for call, (expected_pred, expected_true) in zip(self._metric_function.call_args_list, expected_calls):
            self._assert_metric_call(call, expected_pred, expected_true)
        self._metric_function.reset_mock()
        self._metric.process_final({})

        self._metric_function.assert_called_once()
        self._assert_metric_call(self._metric_function.call_args_list[0],
                                 torch.FloatTensor([0.0, 0.1, 0.2, 0.3, 0.4]), torch.LongTensor([0, 1, 2, 3, 4]))

    def test_validate(self):
        self._metric.eval()
        for state in self._states:
            self._metric.process(state)
        # Nothing is computed in eval mode until the final validation step.
        self._metric_function.assert_not_called()
        self._metric.process_final_validate({})

        self._metric_function.assert_called_once()
        self._assert_metric_call(self._metric_function.call_args_list[0],
                                 torch.FloatTensor([0.0, 0.1, 0.2, 0.3, 0.4]), torch.LongTensor([0, 1, 2, 3, 4]))

    def test_not_running(self):
        metric = EpochLambda('test', self._metric_function, running=False, step_size=6)
        metric.reset({torchbearer.DEVICE: 'cpu', torchbearer.DATA_TYPE: torch.float32})
        metric.train()

        for _ in range(12):
            metric.process(self._states[0])

        self._metric_function.assert_not_called()