def test_update_losses(self):
    """utils.update_losses records each paired value on its loss.

    The same (loss, value) pairs are fed to utils.update_losses twice, and
    each loss's history is expected to hold its value once per call.
    """
    first_loss = RALoss('l1', 5)
    second_loss = RALoss('l2', 5)
    pairs = [(first_loss, 10), (second_loss, 20)]

    # Two passes over the same pairs -> two history entries per loss.
    utils.update_losses(pairs)
    utils.update_losses(pairs)

    expected_first = [10] * 2
    expected_second = [20] * 2
    self.assertEqual(first_loss.get_history(), expected_first)
    self.assertEqual(second_loss.get_history(), expected_second)
def test_get_value(self):
    """RALoss keeps the full update history and reports the expected values.

    NOTE(review): the constructor's second argument (3) presumably sets an
    averaging window -- get_value() == 6.0 matches the mean of the last three
    samples (5, 6, 7). Confirm against RALoss.
    """
    ra_loss = RALoss('name', 3)
    samples = list(range(8))  # 0, 1, ..., 7
    for sample in samples:
        ra_loss.update(sample)

    # Every update is retained, in order.
    self.assertEqual(samples, ra_loss.get_history())
    self.assertEqual(6.0, ra_loss.get_value())
    self.assertEqual(3.0, ra_loss.get_value(i=4))