def test_default_for_key_args(self):
    """Extra args given to default_for_key reach the metric constructor.

    Registers a spy metric class under 'test' with positional arg 10 and
    keyword ``some_arg='a test'``, fetches the default, and checks the
    constructor saw exactly those arguments, exactly once.
    """
    constructor_spy = Mock()

    # Decorator form is equivalent to calling default_for_key(...)(cls);
    # the rebinding of the local class name is never used afterwards.
    @default_for_key('test', 10, some_arg='a test')
    class SpyMetric(metrics.Metric):
        def __init__(self, *ctor_args, **ctor_kwargs):
            super().__init__('test')
            # Record how the registry invoked the constructor.
            constructor_spy(*ctor_args, **ctor_kwargs)

    metrics.get_default('test')
    constructor_spy.assert_called_once_with(10, some_arg='a test')
def test_default_for_key_class_get_default(self):
    """default_for_key on a metric class makes it retrievable via get_default.

    Registers ``metrics.Loss`` under 'test', then checks that the default
    fetched through ``metrics.get_default`` reports the name 'loss' and that
    the decorator returned the class unchanged.
    """
    # Named distinctly: a second method called test_default_for_key_class
    # exists in this class and would otherwise shadow this one, so it was
    # never collected by the test runner.
    metric = default_for_key('test')(metrics.Loss)
    self.assertEqual(metrics.get_default('test').name, 'loss')
    # The decorator must hand back the very same class object.
    self.assertIs(metric, metrics.Loss)
def test_default_for_key_class(self):
    """default_for_key on a metric class registers it in DEFAULT_METRICS.

    Registers ``metrics.Loss`` under 'test', then checks the key is present
    in ``metrics.DEFAULT_METRICS``, that the stored entry reports the name
    'loss', and that the decorator returned the class unchanged.
    """
    metric = default_for_key('test')(metrics.Loss)
    # Specific assertions report the offending values on failure, unlike
    # assertTrue(<expr>) which only says "False is not true".
    self.assertIn('test', metrics.DEFAULT_METRICS)
    self.assertEqual(metrics.DEFAULT_METRICS['test'].name, 'loss')
    self.assertIs(metric, metrics.Loss)