Example #1
import unittest

import torch

from torchbearer.metrics import Metric, RunningMean


class TestRunningMean(unittest.TestCase):
    def setUp(self):
        self._metric = Metric('test')
        self._mean = RunningMean('test')
        self._cache = [
            torch.Tensor([1.0]),
            torch.Tensor([1.5]),
            torch.Tensor([2.0])
        ]
        self._target = 1.5

    def test_train(self):
        result = self._mean._process_train(torch.FloatTensor([1.0, 1.5, 2.0]))
        self.assertAlmostEqual(self._target, result, delta=0.002)

    def test_step(self):
        result = self._mean._step(self._cache)
        self.assertEqual(self._target, result)

    def test_dims(self):
        mean = RunningMean('test', dim=0)
        cache = [
            mean._process_train(torch.Tensor([[1., 2.], [3., 4.]])),
            mean._process_train(torch.Tensor([[4., 3.], [2., 1.]])),
            mean._process_train(torch.Tensor([[1., 1.], [1., 1.]]))
        ]

        res = mean._step(cache)
        self.assertTrue(len(res) == 2)
        for m in res:
            self.assertTrue(abs(m - 2.0) < 0.0001)
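The tests above exercise RunningMean's internal hooks: _process_train caches the mean of each batch and _step averages the cached values. For reference, a minimal stand-alone sketch of the same idea in plain PyTorch (SimpleRunningMean and its method names are illustrative, not part of torchbearer):

from collections import deque

import torch


class SimpleRunningMean:
    """Illustrative deque-based running mean over the last `window` batch means."""

    def __init__(self, window=50):
        self._cache = deque(maxlen=window)

    def add_batch(self, values):
        # Cache the mean of this batch, analogous to what _process_train is tested for above
        self._cache.append(values.float().mean().item())

    def value(self):
        # Average the cached batch means, analogous to what _step is tested for above
        return sum(self._cache) / len(self._cache)


rm = SimpleRunningMean()
rm.add_batch(torch.tensor([1.0, 1.5, 2.0]))
assert abs(rm.value() - 1.5) < 1e-6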
Example #2
class TestRunningMean(unittest.TestCase):
    def setUp(self):
        self._metric = Metric('test')
        self._mean = RunningMean('test')
        self._cache = [1.0, 1.5, 2.0]
        self._target = 1.5

    def test_train(self):
        result = self._mean._process_train(torch.FloatTensor([1.0, 1.5, 2.0]))
        self.assertAlmostEqual(self._target, result, delta=0.002)

    def test_step(self):
        result = self._mean._step(self._cache)
        self.assertEqual(self._target, result)
Example #3
        def build(self):
            # Build the inner metric if a factory was given rather than an instance
            if isinstance(self.inner, MetricFactory):
                inner = self.inner.build()
            else:
                inner = self.inner

            # Make sure there is a MetricTree to attach the running mean to
            if not isinstance(inner, MetricTree):
                inner = MetricTree(inner)

            # batch_size and step_size come from the enclosing scope (the decorator's arguments);
            # the child reports its value under 'running_<name>'
            inner.add_child(ToDict(RunningMean('running_' + inner.name,
                                               batch_size=batch_size,
                                               step_size=step_size)))
            return inner
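The build method above guarantees a MetricTree root and attaches the running mean as a ToDict child that reports under 'running_<name>'. Written out by hand, roughly the same wiring might look as follows, assuming Metric, MetricTree, RunningMean and ToDict are importable from torchbearer.metrics (the metric name 'test' and the sizes are placeholders):

from torchbearer.metrics import Metric, MetricTree, RunningMean, ToDict

inner = MetricTree(Metric('test'))
# The child reports its value under the key 'running_test'
inner.add_child(ToDict(RunningMean('running_' + inner.name, batch_size=50, step_size=10)))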
Example #4
def running_mean(clazz=None, batch_size=50, step_size=10, dim=None):
    """The :func:`running_mean` decorator is used to add a :class:`.RunningMean` to the :class:`.MetricTree`. If the
    inner class is not a :class:`.MetricTree` then one will be created. The :class:`.RunningMean` will be wrapped in a
    :class:`.ToDict` (with 'running\_' prepended to the name) for simplicity.

    .. note::
        The decorator function does not need to be called if no arguments are required; both `@running_mean`
        and `@running_mean()` are acceptable.

    Example: ::

        >>> import torch
        >>> from torchbearer import metrics

        >>> @metrics.running_mean(step_size=2) # Update every 2 steps
        ... @metrics.lambda_metric('my_metric')
        ... def metric(y_pred, y_true):
        ...     return y_pred + y_true
        ...
        >>> metric.reset({})
        >>> metric.process({'y_pred':torch.Tensor([2]), 'y_true':torch.Tensor([2])}) # 4
        {'running_my_metric': 4.0}
        >>> metric.process({'y_pred':torch.Tensor([3]), 'y_true':torch.Tensor([3])}) # 6
        {'running_my_metric': 4.0}
        >>> metric.process({'y_pred':torch.Tensor([4]), 'y_true':torch.Tensor([4])}) # 8, triggers update
        {'running_my_metric': 6.0}

    Args:
        clazz: The class to *decorate*
        batch_size (int): See :class:`.RunningMean`
        step_size (int): See :class:`.RunningMean`
        dim (int, tuple): See :class:`.RunningMean`

    Returns:
        decorator or :class:`.MetricTree` instance or wrapper
    """
    if clazz is None:

        def decorator(clazz):
            return running_mean(clazz,
                                batch_size=batch_size,
                                step_size=step_size,
                                dim=dim)

        return decorator

    return _wrap_and_add_to_tree(
        clazz, lambda metric: ToDict(
            RunningMean('running_' + metric.name,
                        batch_size=batch_size,
                        step_size=step_size,
                        dim=dim)))
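As the note in the docstring says, the decorator may also be applied without parentheses; clazz is then the decorated metric itself and the inner decorator branch is skipped. A minimal sketch of that path, mirroring the doctest above but with the default step_size (the printed value should match the first doctest line, {'running_my_metric': 4.0}, since the first step always triggers an update):

import torch
from torchbearer import metrics

@metrics.running_mean              # no call: clazz is passed directly
@metrics.lambda_metric('my_metric')
def metric(y_pred, y_true):
    return y_pred + y_true

metric.reset({})
# The first step triggers an update of the running mean
print(metric.process({'y_pred': torch.Tensor([2]), 'y_true': torch.Tensor([2])}))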