示例#1
0
    def setUp(self):
        """Create a ``Metric`` named ``'test'`` whose hooks are all mocks, and wrap it in a ``ToDict``.

        ``process`` and ``process_final`` return their own method names as fixed strings so the tests can
        recognise what ``ToDict`` hands back.
        """
        metric = Metric('test')
        for hook in ('train', 'eval', 'reset'):
            setattr(metric, hook, Mock())
        for hook in ('process', 'process_final'):
            # Each mock returns its own hook name, e.g. process() -> 'process'.
            setattr(metric, hook, Mock(return_value=hook))

        self._metric = metric
        self._to_dict = ToDict(metric)
示例#2
0
        def build(self):
            """Return the inner metric wrapped in a ``ToDict``, building it first if it is a ``MetricFactory``."""
            inner = self.inner
            if isinstance(inner, MetricFactory):
                # A factory must be turned into a concrete metric before it can be wrapped.
                inner = inner.build()
            return ToDict(inner)
示例#3
0
def mean(clazz):
    """The :func:`mean` decorator adds a :class:`.Mean` to the :class:`.MetricTree`, which then outputs a mean value
    at the end of each epoch. At build time a :class:`.MetricTree` is created around the inner class if it is not one
    already. The :class:`.Mean` is wrapped in a :class:`.ToDict` so that its result is reported under the decorated
    metric's name.

    Example: ::

        >>> import torch
        >>> from torchbearer import metrics

        >>> @metrics.mean
        ... @metrics.lambda_metric('my_metric')
        ... def metric(y_pred, y_true):
        ...     return y_pred + y_true
        ...
        >>> metric.reset({})
        >>> metric.process({'y_pred':torch.Tensor([2]), 'y_true':torch.Tensor([2])}) # 4
        {}
        >>> metric.process({'y_pred':torch.Tensor([3]), 'y_true':torch.Tensor([3])}) # 6
        {}
        >>> metric.process({'y_pred':torch.Tensor([4]), 'y_true':torch.Tensor([4])}) # 8
        {}
        >>> metric.process_final()
        {'my_metric': 6.0}

    :param clazz: The class to *decorate*
    :return: A :class:`.MetricTree` with a :class:`.Mean` appended or a wrapper class that extends :class:`.MetricTree`
    """
    def add_mean(metric):
        # The Mean reuses the decorated metric's name, so no suffix is added to the output key.
        return ToDict(Mean(metric.name))

    return _wrap_and_add_to_tree(clazz, add_mean)
示例#4
0
def std(clazz):
    """The :func:`std` decorator adds a :class:`.Std` to the :class:`.MetricTree`, which then outputs a population
    standard deviation at the end of each epoch. At build time a :class:`.MetricTree` is created around the inner class
    if it is not one already. The :class:`.Std` is wrapped in a :class:`.ToDict` whose key is the decorated metric's
    name with '_std' appended.

    Example: ::

        >>> import torch
        >>> from torchbearer import metrics

        >>> @metrics.std
        ... @metrics.lambda_metric('my_metric')
        ... def metric(y_pred, y_true):
        ...     return y_pred + y_true
        ...
        >>> metric.reset({})
        >>> metric.process({'y_pred':torch.Tensor([2]), 'y_true':torch.Tensor([2])}) # 4
        {}
        >>> metric.process({'y_pred':torch.Tensor([3]), 'y_true':torch.Tensor([3])}) # 6
        {}
        >>> metric.process({'y_pred':torch.Tensor([4]), 'y_true':torch.Tensor([4])}) # 8
        {}
        >>> '%.4f' % metric.process_final()['my_metric_std']
        '1.6330'

    :param clazz: The class to *decorate*
    :return: A :class:`.MetricTree` with a :class:`.Std` appended or a wrapper class that extends :class:`.MetricTree`
    """
    def add_std(metric):
        # Suffix the name so the std output does not collide with the metric's own key.
        return ToDict(Std(metric.name + '_std'))

    return _wrap_and_add_to_tree(clazz, add_std)


def to_dict(clazz):
    """Wrap a :class:`.Metric` class or instance with :class:`.ToDict` so that its future output is returned as a
    `dict[name, value]`.

    Example: ::

        >>> from torchbearer import metrics

        >>> @metrics.lambda_metric('my_metric')
        ... def my_metric(y_pred, y_true):
        ...     return y_pred + y_true
        ...
        >>> my_metric.process({'y_pred':4, 'y_true':5})
        9

        >>> @metrics.to_dict
        ... @metrics.lambda_metric('my_metric')
        ... def my_metric(y_pred, y_true):
        ...     return y_pred + y_true
        ...
        >>> my_metric.process({'y_pred':4, 'y_true':5})
        {'my_metric': 9}

    :param clazz: The class to *decorate*
    :return: A :class:`.ToDict` instance or a :class:`.ToDict` wrapper of the given class
    """
    # An already-constructed metric can be wrapped directly.
    if not inspect.isclass(clazz):
        return ToDict(clazz)

    # For a class, return a ToDict subclass whose constructor builds the wrapped metric from the same arguments.
    class Wrapper(ToDict):
        def __init__(self, *args, **kwargs):
            super(Wrapper, self).__init__(clazz(*args, **kwargs))

    return Wrapper
示例#6
0
    def setUp(self):
        """Install mocks on a ``Metric`` named ``'test'`` and build the ``ToDict`` under test around it."""
        self._metric = Metric('test')
        mocked_hooks = {
            'train': Mock(),
            'eval': Mock(),
            'reset': Mock(),
            'process': Mock(return_value='process'),
            'process_final': Mock(return_value='process_final'),
        }
        for attr, mock in mocked_hooks.items():
            setattr(self._metric, attr, mock)

        self._to_dict = ToDict(self._metric)
示例#7
0
        def build(self):
            """Build the inner metric, make sure it is a ``MetricTree`` and attach a running-mean child to it.

            The child is a ``RunningMean`` named ``'running_' + <tree name>`` wrapped in a ``ToDict``; ``batch_size``
            and ``step_size`` are taken from the enclosing scope.
            """
            built = self.inner.build() if isinstance(self.inner, MetricFactory) else self.inner
            tree = built if isinstance(built, MetricTree) else MetricTree(built)
            running = RunningMean('running_' + tree.name, batch_size=batch_size, step_size=step_size)
            tree.add_child(ToDict(running))
            return tree
示例#8
0
        def build(self):
            """Build the inner metric, make sure it is a ``MetricTree`` and attach a ``Std`` child named
            ``<tree name> + '_std'`` (wrapped in a ``ToDict``)."""
            source = self.inner
            if isinstance(source, MetricFactory):
                source = source.build()
            tree = source if isinstance(source, MetricTree) else MetricTree(source)
            std_child = Std(tree.name + '_std')
            tree.add_child(ToDict(std_child))
            return tree
示例#9
0
def running_mean(clazz=None, batch_size=50, step_size=10, dim=None):
    r"""The :func:`running_mean` decorator is used to add a :class:`.RunningMean` to the :class:`.MetricTree`. If the
    inner class is not a :class:`.MetricTree` then one will be created. The :class:`.RunningMean` will be wrapped in a
    :class:`.ToDict` (with 'running\_' prepended to the name) for simplicity.

    .. note::
        The decorator function does not need to be called if not desired, both: `@running_mean` and `@running_mean()`
        are acceptable.

    Example: ::

        >>> import torch
        >>> from torchbearer import metrics

        >>> @metrics.running_mean(step_size=2) # Update every 2 steps
        ... @metrics.lambda_metric('my_metric')
        ... def metric(y_pred, y_true):
        ...     return y_pred + y_true
        ...
        >>> metric.reset({})
        >>> metric.process({'y_pred':torch.Tensor([2]), 'y_true':torch.Tensor([2])}) # 4
        {'running_my_metric': 4.0}
        >>> metric.process({'y_pred':torch.Tensor([3]), 'y_true':torch.Tensor([3])}) # 6
        {'running_my_metric': 4.0}
        >>> metric.process({'y_pred':torch.Tensor([4]), 'y_true':torch.Tensor([4])}) # 8, triggers update
        {'running_my_metric': 6.0}

    Args:
        clazz: The class to *decorate*. If ``None``, a decorator that remembers the remaining arguments is returned.
        batch_size (int): See :class:`.RunningMean`
        step_size (int): See :class:`.RunningMean`
        dim (int, tuple): See :class:`.RunningMean`

    Returns:
        decorator or :class:`.MetricTree` instance or wrapper
    """
    # The docstring is a raw string so the Sphinx escape \_ is kept literally. In a normal string it is an invalid
    # Python escape sequence and triggers a DeprecationWarning / SyntaxWarning when the module is compiled.
    if clazz is None:
        # Called as ``@running_mean(...)``: defer until the object being decorated is supplied.
        def decorator(clazz):
            return running_mean(clazz,
                                batch_size=batch_size,
                                step_size=step_size,
                                dim=dim)

        return decorator

    return _wrap_and_add_to_tree(
        clazz, lambda metric: ToDict(
            RunningMean('running_' + metric.name,
                        batch_size=batch_size,
                        step_size=step_size,
                        dim=dim)))
示例#10
0
def std(clazz=None, unbiased=True, dim=None):
    """The :func:`std` decorator adds a :class:`.Std` to the :class:`.MetricTree`, which then outputs a sample standard
    deviation at the end of each epoch. At build time a :class:`.MetricTree` is created around the inner class if it is
    not one already. The :class:`.Std` is wrapped in a :class:`.ToDict` whose key is the decorated metric's name with
    '_std' appended.

    Example: ::

        >>> import torch
        >>> from torchbearer import metrics

        >>> @metrics.std
        ... @metrics.lambda_metric('my_metric')
        ... def metric(y_pred, y_true):
        ...     return y_pred + y_true
        ...
        >>> metric.reset({})
        >>> metric.process({'y_pred':torch.Tensor([2]), 'y_true':torch.Tensor([2])}) # 4
        {}
        >>> metric.process({'y_pred':torch.Tensor([3]), 'y_true':torch.Tensor([3])}) # 6
        {}
        >>> metric.process({'y_pred':torch.Tensor([4]), 'y_true':torch.Tensor([4])}) # 8
        {}
        >>> '%.4f' % metric.process_final()['my_metric_std']
        '2.0000'

    Args:
        clazz: The class to *decorate*. If ``None``, a decorator that remembers ``unbiased`` and ``dim`` is returned.
        unbiased (bool): See :class:`.Std`
        dim (int, tuple): See :class:`.Std`

    Returns:
        A :class:`.MetricTree` with a :class:`.Std` appended or a wrapper class that extends :class:`.MetricTree`
    """
    if clazz is not None:
        return _wrap_and_add_to_tree(
            clazz,
            lambda metric: ToDict(Std(metric.name + '_std', unbiased=unbiased, dim=dim)))

    # Used as ``@std(...)``: wait for the object being decorated, then recurse with the stored settings.
    def decorator(clazz):
        return std(clazz, unbiased=unbiased, dim=dim)

    return decorator
示例#11
0
class TestToDict(unittest.TestCase):
    """Tests for ``ToDict`` wrapping a mocked ``Metric`` named ``'test'``.

    The expected outputs key results by the metric name in train mode and by ``'val_' + name`` in eval mode.
    Equality checks use ``assertEqual`` so that a failure reports both values, not just ``False is not true``.
    """

    def setUp(self):
        # Every hook is a Mock so calls can be asserted; process/process_final return fixed sentinel strings.
        self._metric = Metric('test')
        self._metric.train = Mock()
        self._metric.eval = Mock()
        self._metric.reset = Mock()
        self._metric.process = Mock(return_value='process')
        self._metric.process_final = Mock(return_value='process_final')

        self._to_dict = ToDict(self._metric)

    def test_train_process(self):
        self._to_dict.train()
        self.assertEqual(self._metric.train.call_count, 1)

        self.assertEqual(self._to_dict.process('input'), {'test': 'process'})
        self._metric.process.assert_called_once_with('input')

    def test_train_process_final(self):
        self._to_dict.train()
        self.assertEqual(self._metric.train.call_count, 1)

        self.assertEqual(self._to_dict.process_final('input'), {'test': 'process_final'})
        self._metric.process_final.assert_called_once_with('input')

    def test_eval_process(self):
        self._to_dict.eval()
        self.assertEqual(self._metric.eval.call_count, 1)

        self.assertEqual(self._to_dict.process('input'), {'val_test': 'process'})
        self._metric.process.assert_called_once_with('input')

    def test_eval_process_final(self):
        self._to_dict.eval()
        self.assertEqual(self._metric.eval.call_count, 1)

        self.assertEqual(self._to_dict.process_final('input'), {'val_test': 'process_final'})
        self._metric.process_final.assert_called_once_with('input')

    def test_reset(self):
        self._to_dict.reset('test')
        self._metric.reset.assert_called_once_with('test')
示例#12
0
class TestToDict(unittest.TestCase):
    """Tests for ``ToDict`` wrapping a mocked ``Metric`` named ``'test'``.

    The expected outputs key results by the metric name in train mode and by ``'val_' + name`` in eval mode.
    Equality checks use ``assertEqual`` so that a failure reports both values, not just ``False is not true``.
    """

    def setUp(self):
        # Every hook is a Mock so calls can be asserted; process/process_final return fixed sentinel strings.
        self._metric = Metric('test')
        self._metric.train = Mock()
        self._metric.eval = Mock()
        self._metric.reset = Mock()
        self._metric.process = Mock(return_value='process')
        self._metric.process_final = Mock(return_value='process_final')

        self._to_dict = ToDict(self._metric)

    def test_train_process(self):
        self._to_dict.train()
        self._metric.train.assert_called_once()

        self.assertEqual(self._to_dict.process('input'), {'test': 'process'})
        self._metric.process.assert_called_once_with('input')

    def test_train_process_final(self):
        self._to_dict.train()
        self._metric.train.assert_called_once()

        self.assertEqual(self._to_dict.process_final('input'), {'test': 'process_final'})
        self._metric.process_final.assert_called_once_with('input')

    def test_eval_process(self):
        self._to_dict.eval()
        self._metric.eval.assert_called_once()

        self.assertEqual(self._to_dict.process('input'), {'val_test': 'process'})
        self._metric.process.assert_called_once_with('input')

    def test_eval_process_final(self):
        self._to_dict.eval()
        self._metric.eval.assert_called_once()

        self.assertEqual(self._to_dict.process_final('input'), {'val_test': 'process_final'})
        self._metric.process_final.assert_called_once_with('input')

    def test_reset(self):
        self._to_dict.reset('test')
        self._metric.reset.assert_called_once_with('test')