Example #1
0
 def setUp(self):
     """Build a mocked metric function, wrap it in a BatchLambda and create three fixture states.

     The mock always returns 'test'. Each state maps Y_TRUE and Y_PRED to a
     single-element FloatTensor wrapped in a Variable.
     """
     metric_fn = Mock(return_value='test')
     self._metric_function = metric_fn
     self._metric = BatchLambda('test', metric_fn)
     # One (y_true, y_pred) pair per fixture state, in order.
     value_pairs = ((1, 2), (3, 4), (5, 6))
     self._states = [
         {
             torchbearer.Y_TRUE: Variable(torch.FloatTensor([y_true])),
             torchbearer.Y_PRED: Variable(torch.FloatTensor([y_pred])),
         }
         for y_true, y_pred in value_pairs
     ]
Example #2
0
class TestBatchLambda(unittest.TestCase):
    """Check that BatchLambda passes each state's prediction and target to its metric function.

    The same expectation is checked in both training and evaluation mode.
    """
    # NOTE(review): a second class named TestBatchLambda is defined later in this
    # listing. It rebinds the module-level name, so test loaders only see that later
    # definition. Confirm that only one copy is intended.

    def setUp(self):
        """Create a mocked metric function, its BatchLambda wrapper, and three fixture states."""
        self._metric_function = Mock(return_value='test')
        self._metric = BatchLambda('test', self._metric_function)
        self._states = [{
            torchbearer.Y_TRUE: Variable(torch.FloatTensor([1])),
            torchbearer.Y_PRED: Variable(torch.FloatTensor([2])),
        }, {
            torchbearer.Y_TRUE: Variable(torch.FloatTensor([3])),
            torchbearer.Y_PRED: Variable(torch.FloatTensor([4])),
        }, {
            torchbearer.Y_TRUE: Variable(torch.FloatTensor([5])),
            torchbearer.Y_PRED: Variable(torch.FloatTensor([6])),
        }]

    def _process_states_and_check_calls(self):
        """Feed every fixture state to the metric, then verify the recorded calls.

        Asserts that the metric function was called with
        ``(state[Y_PRED].data, state[Y_TRUE].data)`` for each state, in that
        order. ``assert_has_calls`` without ``any_order`` requires these calls
        to appear as a consecutive sequence.
        """
        expected_calls = []
        for state in self._states:
            self._metric.process(state)
            # Build the expected call after processing, matching what the metric saw.
            expected_calls.append(
                call(state[torchbearer.Y_PRED].data,
                     state[torchbearer.Y_TRUE].data))
        self._metric_function.assert_has_calls(expected_calls)

    def test_train(self):
        """Training mode: the metric function receives (y_pred, y_true) for every batch."""
        self._metric.train()
        self._process_states_and_check_calls()

    def test_validate(self):
        """Evaluation mode: the metric function receives (y_pred, y_true) for every batch."""
        self._metric.eval()
        self._process_states_and_check_calls()
class TestBatchLambda(unittest.TestCase):
    """Verify BatchLambda calls its metric function with (y_pred.data, y_true.data) per batch."""
    # NOTE(review): another class named TestBatchLambda appears earlier in this
    # listing. This definition rebinds the name and shadows it. Confirm that only
    # one copy is intended.

    def setUp(self):
        """Create a mocked metric function, its BatchLambda wrapper, and three fixture states."""
        self._metric_function = Mock(return_value='test')
        self._metric = BatchLambda('test', self._metric_function)
        self._states = [
            {torchbearer.Y_TRUE: Variable(torch.FloatTensor([1])),
             torchbearer.Y_PRED: Variable(torch.FloatTensor([2]))},
            {torchbearer.Y_TRUE: Variable(torch.FloatTensor([3])),
             torchbearer.Y_PRED: Variable(torch.FloatTensor([4]))},
            {torchbearer.Y_TRUE: Variable(torch.FloatTensor([5])),
             torchbearer.Y_PRED: Variable(torch.FloatTensor([6]))},
        ]

    def _run_states_and_assert_calls(self):
        """Process each fixture state and assert the matching metric-function calls, in order."""
        calls = []
        for state in self._states:
            self._metric.process(state)
            calls.append(
                call(state[torchbearer.Y_PRED].data, state[torchbearer.Y_TRUE].data))
        self._metric_function.assert_has_calls(calls)

    def test_train(self):
        """Training mode forwards every state to the metric function."""
        self._metric.train()
        self._run_states_and_assert_calls()

    def test_validate(self):
        """Evaluation mode forwards every state to the metric function."""
        self._metric.eval()
        self._run_states_and_assert_calls()
 def build(self):
     """Return an EpochLambda if ``on_epoch`` is truthy, otherwise a BatchLambda.

     ``on_epoch``, ``name`` and ``metric_function`` are not defined in this block.
     They are resolved from an enclosing or module scope that is not visible here.
     """
     # Only the selected class name is looked up, as in an if/else.
     metric_cls = EpochLambda if on_epoch else BatchLambda
     return metric_cls(name, metric_function)
Example #5
0
 def decorator(metric_function):
     """Wrap ``metric_function`` in a BatchLambda, or in an EpochLambda when ``on_epoch`` is truthy.

     ``on_epoch`` and ``name`` are free names resolved from an enclosing scope
     that is not visible here.
     """
     if not on_epoch:
         return BatchLambda(name, metric_function)
     return EpochLambda(name, metric_function)
 def setUp(self):
     """Prepare a BatchLambda around a mock metric function, plus three fixture states.

     The mock returns 'test'. Each state holds Y_TRUE/Y_PRED single-element
     FloatTensors wrapped in a Variable.
     """
     self._metric_function = Mock(return_value='test')
     self._metric = BatchLambda('test', self._metric_function)

     def make_state(true_value, pred_value):
         # Y_TRUE is built before Y_PRED, keeping the original key order.
         return {
             torchbearer.Y_TRUE: Variable(torch.FloatTensor([true_value])),
             torchbearer.Y_PRED: Variable(torch.FloatTensor([pred_value])),
         }

     self._states = [make_state(1, 2), make_state(3, 4), make_state(5, 6)]