def test_sum_forwards_equals_backwards(self, its = 200):
    """Check that wnet.sum agrees when run forwards and backwards.

    For each iteration a fresh net is built with gen_simple_ConcreteNet and
    summed twice over the log-reals ring, once with forwards = True and once
    with forwards = False; the two totals must be close according to
    ring.isClose.

    its -- number of generated nets to check.
    """
    def deltaTime(label):
        # every label is given a zero time delta
        return 0

    for it in range(its):
        net = gen_simple_ConcreteNet(defaultGenLabel, deltaTime,
                                     sortable = True, pathMustExist = True)
        ring = semiring.LogRealsField()
        # Memoized so that a given label maps to the same weight in both the
        # forwards and the backwards pass (None labels get the ring's one).
        labelToWeight = memoize(
            lambda label: ring.one if label is None else randn()
        )

        def getAgenda(forwards):
            return wnet.SimpleSumAgenda(ring)

        totalWeightForwards = wnet.sum(net, labelToWeight, ring,
                                       getAgenda = getAgenda,
                                       forwards = True)
        # The backwards total must really be computed backwards; passing
        # forwards = True here would make the comparison trivially true.
        totalWeightBackwards = wnet.sum(net, labelToWeight, ring,
                                        getAgenda = getAgenda,
                                        forwards = False)
        assert ring.isClose(totalWeightForwards, totalWeightBackwards)
def test_memoize(self):
    """Check that a memoized function is evaluated once per distinct input.

    The function is wrapped in FnEval, whose evalCount attribute is read
    after each call to see whether the underlying function was invoked
    again.
    """
    def square(value):
        return value * value

    counted = FnEval(square)
    cached = memoize(counted)
    assert counted.evalCount == 0

    # (argument, evalCount expected after the call); repeated arguments
    # must leave the count unchanged.
    steps = (
        (0.1, 1),
        (0.2, 2),
        (0.1, 2),
        (0.2, 2),
        (0.2, 2),
        (0.3, 3),
    )
    for arg, expectedCount in steps:
        assert cached(arg) == square(arg)
        assert counted.evalCount == expectedCount