def build(self):
    """Build the wrapped metric and attach a standard-deviation child to it.

    ``self.inner`` is either a ``MetricFactory`` (which is built here) or an
    already constructed metric. The result is always a ``MetricTree``. A
    ``ToDict(Std(...))`` child named ``<tree name> + '_std'`` is added to it.
    """
    if isinstance(self.inner, MetricFactory):
        metric = self.inner.build()
    else:
        metric = self.inner
    # Only wrap plain metrics; an existing tree is extended in place.
    tree = metric if isinstance(metric, MetricTree) else MetricTree(metric)
    tree.add_child(ToDict(Std(tree.name + '_std')))
    return tree
def build(self):
    """Build the wrapped metric and attach a running-mean child to it.

    ``self.inner`` is either a ``MetricFactory`` (which is built here) or an
    already constructed metric. The result is always a ``MetricTree``. A
    ``ToDict(RunningMean(...))`` child named ``'running_' + <tree name>`` is
    added to it.
    """
    if isinstance(self.inner, MetricFactory):
        metric = self.inner.build()
    else:
        metric = self.inner
    # Only wrap plain metrics; an existing tree is extended in place.
    tree = metric if isinstance(metric, MetricTree) else MetricTree(metric)
    # NOTE(review): batch_size / step_size are not attributes of self; they
    # presumably come from an enclosing (decorator) scope -- confirm.
    running = RunningMean('running_' + tree.name,
                          batch_size=batch_size,
                          step_size=step_size)
    tree.add_child(ToDict(running))
    return tree
def test_process_final(self):
    """Check MetricTree.process_final routing and result merging.

    The root is called with the original arguments, and each child is called
    with the root's return value. The test expects the tree to return the
    dict produced by the children, with a child that returns None
    contributing nothing.
    """
    root = Metric('test')
    root.process_final = Mock(return_value='test')
    leaf1 = Metric('test')
    leaf1.process_final = Mock(return_value={'test': 10})
    leaf2 = Metric('test')
    leaf2.process_final = Mock(return_value=None)

    tree = MetricTree(root)
    tree.add_child(leaf1)
    tree.add_child(leaf2)

    # assertEqual shows both values on failure; assertTrue(a == b) only
    # reports "False is not true".
    self.assertEqual(tree.process_final('args'), {'test': 10})
    root.process_final.assert_called_once_with('args')
    # Children receive the root's output, not the original arguments.
    leaf1.process_final.assert_called_once_with('test')
    leaf2.process_final.assert_called_once_with('test')
def test_dict_return(self):
    """Check MetricTree.process when the root metric returns a dict.

    The root returns ``{0: 'test', 1: 'something else'}``. The test expects
    the children to be called with ``'test'`` rather than with the whole
    dict, and expects the tree's result to be the children's dict output.
    """
    root = Metric('test')
    root.process = Mock(return_value={0: 'test', 1: 'something else'})
    leaf1 = Metric('test')
    leaf1.process = Mock(return_value={'test': 10})
    leaf2 = Metric('test')
    leaf2.process = Mock(return_value=None)

    tree = MetricTree(root)
    tree.add_child(leaf1)
    tree.add_child(leaf2)

    # assertEqual shows both values on failure; assertTrue(a == b) only
    # reports "False is not true".
    self.assertEqual(tree.process('args'), {'test': 10})
    root.process.assert_called_once_with('args')
    leaf1.process.assert_called_once_with('test')
    leaf2.process.assert_called_once_with('test')
def _wrap_and_add_to_tree(clazz, child_func):
    """Wrap ``clazz`` in a MetricTree and add ``child_func(...)`` as a child.

    If ``clazz`` is a class, a ``MetricTree`` subclass is returned. Its
    constructor does the following:

    - instantiates ``clazz`` with the same arguments;
    - adopts the instance as root, or adopts its root and children if the
      instance is already a ``MetricTree``;
    - adds ``child_func(self.root)`` as a child.

    The returned class takes ``clazz``'s module, name, qualified name and
    docstring, so it introspects like the class it decorates.

    Otherwise ``clazz`` is treated as a metric instance. It is wrapped in a
    ``MetricTree`` if it is not one already, ``child_func(tree)`` is added as
    a child, and the tree is returned.

    Args:
        clazz: A metric class or a metric instance.
        child_func: Callable returning the child metric to attach.

    Returns:
        A ``MetricTree`` subclass (for class input) or a ``MetricTree``
        instance (for instance input).
    """
    if inspect.isclass(clazz):
        class Wrapper(MetricTree):
            def __init__(self, *args, **kwargs):
                inner = clazz(*args, **kwargs)
                if isinstance(inner, MetricTree):
                    super(Wrapper, self).__init__(inner.root)
                    # Keep any children the inner tree already had.
                    self.children = inner.children
                else:
                    super(Wrapper, self).__init__(inner)
                # NOTE(review): this path passes the root metric to
                # child_func, whereas the instance path below passes the
                # tree itself -- confirm the asymmetry is intended.
                self.add_child(child_func(self.root))

        # Without this the decorated class reports itself as 'Wrapper' in
        # reprs, tracebacks and help(). Some attributes are missing or
        # read-only on older interpreters, so failures are skipped.
        for attr in ('__module__', '__name__', '__qualname__', '__doc__'):
            try:
                setattr(Wrapper, attr, getattr(clazz, attr))
            except AttributeError:
                pass
        return Wrapper
    else:
        inner = clazz
        if not isinstance(inner, MetricTree):
            inner = MetricTree(inner)
        inner.add_child(child_func(inner))
        return inner
def test_reset(self):
    """Check that MetricTree.reset passes its argument to root and child."""
    root, leaf = Metric('test'), Metric('test')
    root.reset = Mock()
    leaf.reset = Mock()

    tree = MetricTree(root)
    tree.add_child(leaf)
    tree.reset({})

    for metric in (root, leaf):
        metric.reset.assert_called_once_with({})
def test_eval(self):
    """Check that MetricTree.eval calls eval once on the root and the child."""
    root, leaf = Metric('test'), Metric('test')
    root.eval = Mock()
    leaf.eval = Mock()

    tree = MetricTree(root)
    tree.add_child(leaf)
    tree.eval()

    for metric in (root, leaf):
        metric.eval.assert_called_once()
def test_train(self):
    """Check that MetricTree.train calls train once on the root and the child."""
    root, leaf = Metric('test'), Metric('test')
    root.train = Mock()
    leaf.train = Mock()

    tree = MetricTree(root)
    tree.add_child(leaf)
    tree.train()

    for metric in (root, leaf):
        metric.train.assert_called_once()
def test_eval(self):
    """Check that MetricTree.eval calls eval exactly once on root and child.

    This version counts calls via ``call_count``. That is presumably for
    compatibility with mock versions lacking ``assert_called_once`` (added
    in Python 3.6).
    """
    root, leaf = Metric('test'), Metric('test')
    root.eval = Mock()
    leaf.eval = Mock()

    tree = MetricTree(root)
    tree.add_child(leaf)
    tree.eval()

    for mocked in (root.eval, leaf.eval):
        self.assertEqual(mocked.call_count, 1)
def test_train(self):
    """Check that MetricTree.train calls train exactly once on root and child.

    This version counts calls via ``call_count``. That is presumably for
    compatibility with mock versions lacking ``assert_called_once`` (added
    in Python 3.6).
    """
    root, leaf = Metric('test'), Metric('test')
    root.train = Mock()
    leaf.train = Mock()

    tree = MetricTree(root)
    tree.add_child(leaf)
    tree.train()

    for mocked in (root.train, leaf.train):
        self.assertEqual(mocked.call_count, 1)
def test_string(self):
    """Check that a MetricTree's str() matches that of its root metric."""
    metric = Metric('test')
    wrapped = MetricTree(metric)
    expected, actual = str(metric), str(wrapped)
    self.assertEqual(expected, actual)