def test_allExpressions(self):
    """Compare value of all expressions of depth 1 and 2 with math library.
    Also check for correct reference counting.

    Each generated expression string is evaluated twice: once on plain
    floats with the ``math`` module as the reference, and once on Salt
    objects with the ``salty`` module.  Expressions whose float reference
    cannot be computed (ValueError, ZeroDivisionError, OverflowError), is
    not a float, or exceeds 1e10 in magnitude are skipped.
    """
    def generate_all_combinations(indep):
        """Generate all possible combinations of expressions, depth 2
        (currently approx. 50000)"""
        us = [u % i for u in UNARIES for i in indep]
        uu = [u % i for u in UNARIES for i in us]
        bss = [b % (i, j) for b in BINARIES for i in indep for j in indep]
        ub = [u % i for u in UNARIES for i in bss]
        bsu = [b % (i, j) for b in BINARIES for i in indep for j in us]
        bus = [b % (j, i) for b in BINARIES for i in indep for j in us]
        buu = [b % (i, j) for b in BINARIES for i in us for j in us]
        bbs = [b % (i, j) for b in BINARIES for i in bss for j in indep]
        bsb = [b % (j, i) for b in BINARIES for i in bss for j in indep]
        bbb = [b % (i, j) for b in BINARIES for i in bss for j in bss]
        bbu = [b % (i, j) for b in BINARIES for i in bss for j in us]
        bub = [b % (j, i) for b in BINARIES for i in bss for j in us]
        return (us + uu + bss + ub + bsu + bus + buu + bbs + bsb + bbb +
                bbu + bub)

    def public_names(module):
        """Return a dict of the module's attributes not starting with '_'."""
        return {n: getattr(module, n) for n in dir(module)
                if not n.startswith("_")}

    indep = ["z", "o", "x", "y"]
    allExpressions = generate_all_combinations(indep)

    z, o = Salt(SYM_ZERO.dup()), Salt(SYM_ONE.dup())
    x, y = Leaf(0.4), Leaf(1.5)

    # Reference namespace: plain floats plus every public math function.
    scopeFloat = {"z": z.value, "o": o.value, "x": x.value, "y": y.value,
                  "inv": lambda x: 1.0 / x, "squ": lambda x: x * x}
    scopeFloat.update(public_names(__import__('math')))
    # Namespace under test: Salt objects plus every public salty function.
    scopeSalt = {"z": z, "o": o, "x": x, "y": y}
    scopeSalt.update(public_names(__import__('salty')))

    for ex in allExpressions:
        try:
            ref = eval(ex, scopeFloat)
        except (ValueError, ZeroDivisionError, OverflowError):
            # The float reference is undefined for this expression.
            continue
        if not isinstance(ref, float) or abs(ref) > 1e10:
            # e.g. complex results of ``**`` or huge values: not comparable.
            continue
        res = eval(ex, scopeSalt).value
        err = "Expression '%s' went south: %r != %r" % (ex, ref, res)
        try:
            self.assertAlmostEqual(ref, res, places=5, msg=err)
        except TypeError:
            # ``res`` is not a number at all; report it as a failure rather
            # than an error.  AssertionError is deliberately not caught.
            self.fail(err)

    # and now, reference count check on independent nodes
    zn, on, xn = z.node, o.node, x.node
    del scopeSalt, z, o, x
    self.assertEqual(zn.ref_count, 1)  # global object SYM_ZERO remaining
    self.assertEqual(on.ref_count, 1)  # global object SYM_ONE remaining
    self.assertEqual(xn.ref_count, 0)
def test_add(self):
    """Sum of two leaves reflects a changed input after invalidation."""
    left = Leaf(3.0)
    right = Leaf(4.0)
    total = left + right
    self.assertEqual(total.value, 7.0)
    # Change one input, then invalidate the sum before reading it again.
    left.value = 5.0
    total.invalidate()
    self.assertEqual(total.value, 9.0)
def test_div(self):
    """Derivative of a / b w.r.t. both operands; no references leak."""
    num = Leaf(6.0)
    den = Leaf(4.0)
    quotient = num / den
    grads = Derivative([quotient], [num, den])[0].value
    # d(a/b)/da = 1/b = 0.25, d(a/b)/db = -a/b**2 = -0.375
    self.assertListEqual(grads, [0.25, -0.375])
    del quotient
    self.assertEqual(num.node.ref_count, 1)
    self.assertEqual(den.node.ref_count, 1)
def test_mul(self):
    """Derivative of a * b w.r.t. both operands; no references leak."""
    left = Leaf(11.0)
    right = Leaf(7.0)
    product = left * right
    grads = Derivative([product], [left, right])[0].value
    # d(a*b)/da = b, d(a*b)/db = a
    self.assertListEqual(grads, [7.0, 11.0])
    del product
    self.assertEqual(left.node.ref_count, 1)
    self.assertEqual(right.node.ref_count, 1)
def test_pow(self):
    """Derivative of a ** b w.r.t. base and exponent; no references leak."""
    base = Leaf(3.0)
    expo = Leaf(4.0)
    power = base**expo
    grads = Derivative([power], [base, expo])[0].value
    # d/da = b * a**(b-1) = 108, d/db = a**b * ln(a) = 81 * ln(3)
    self.assertListEqual(grads, [108.0, 81 * math.log(3.0)])
    del power
    self.assertEqual(base.node.ref_count, 1)
    self.assertEqual(expo.node.ref_count, 1)
def test_div(self):
    """Sparse derivative of a / b w.r.t. both operands; no references leak."""
    num = Leaf(6.0)
    den = Leaf(4.0)
    quotient = num / den
    jac = sparse_derivative([quotient], [num, den])
    grads = [jac[0][0].value, jac[0][1].value]
    del jac  # keep only plain floats alive, not the derivative structure
    # d(a/b)/da = 1/b = 0.25, d(a/b)/db = -a/b**2 = -0.375
    self.assertListEqual(grads, [0.25, -0.375])
    del quotient
    self.assertEqual(num.node.ref_count, 1)
    self.assertEqual(den.node.ref_count, 1)
def test_sub(self):
    """Sparse derivative of a - b w.r.t. both operands; no references leak."""
    minuend = Leaf(11.0)
    subtrahend = Leaf(7.0)
    diff = minuend - subtrahend
    jac = sparse_derivative([diff], [minuend, subtrahend])
    grads = [jac[0][0].value, jac[0][1].value]
    del jac  # keep only plain floats alive, not the derivative structure
    self.assertListEqual(grads, [1.0, -1.0])
    del diff
    self.assertEqual(minuend.node.ref_count, 1)
    self.assertEqual(subtrahend.node.ref_count, 1)
def test_uneven(self):
    """simplify() on an irregular nested container of identical products."""
    p, q = Leaf(3.0), Leaf(2.0)
    nested = [
        p * q,
        (p * q, p * q),
        {p * q, p * q, p * q},
        ([p * q, p * q], (p * q, p * q)),
    ]
    # Ten products are built above; presumably simplify() reports the
    # number of duplicates it merged -- verify against simplify's contract.
    dup_count = simplify(nested)
    self.assertEqual(dup_count, 9)
    self.assertEqual(p.node.ref_count, 2)
def test_sel(self):
    """Derivative of a.select(b) follows the sign of the selector."""
    value = Leaf(13.0)
    switch = Leaf(4.0)
    picked = value.select(switch)
    jac = Derivative([picked], [value, switch])
    self.assertListEqual(jac.value, [[1.0, 0.0]])
    # With a negative selector the derivative w.r.t. the value becomes 0.
    switch.value = -1.0
    jac[0][0].invalidate()
    self.assertListEqual(jac.value, [[0.0, 0.0]])
    del picked, jac
    self.assertEqual(value.node.ref_count, 1)
    self.assertEqual(switch.node.ref_count, 1)
def test_sel(self):
    """Sparse derivative of a.select(b) follows the sign of the selector."""
    value = Leaf(13.0)
    switch = Leaf(4.0)
    picked = value.select(switch)
    jac = sparse_derivative([picked], [value, switch])
    entries = [jac[0][0].value, jac[0][1].value]
    self.assertListEqual(entries, [1.0, 0.0])
    # With a negative selector the derivative w.r.t. the value becomes 0.
    switch.value = -1.0
    jac[0][0].invalidate()
    self.assertEqual(jac[0][0].value, 0.0)
    del picked, jac
    self.assertEqual(value.node.ref_count, 1)
    self.assertEqual(switch.node.ref_count, 1)
def test_mul(self):
    """Sparse derivative of a * b w.r.t. both operands; no references leak."""
    a, b = Leaf(11.0), Leaf(7.0)
    c = a * b
    res = sparse_derivative([c], [a, b])
    # Rebinding ``res`` to plain floats releases this test's reference to
    # the derivative structure before the ref-count checks below.
    res = [res[0][0].value, res[0][1].value]
    # d(a*b)/da = b, d(a*b)/db = a
    self.assertListEqual(res, [7.0, 11.0])
    del c
    self.assertEqual(a.node.ref_count, 1)
    self.assertEqual(b.node.ref_count, 1)
def test_cosh(self):
    """d/da cosh(a) equals sinh(a); no references leak."""
    arg = Leaf(2.0)
    expected = math.sinh(2.0)
    fn_val = cosh(arg)
    slope = Derivative([fn_val], [arg])[0, 0].value
    self.assertEqual(slope, expected)
    del fn_val
    self.assertEqual(arg.node.ref_count, 1)
def test_cos(self):
    """Sparse d/da cos(a) equals -sin(a); no references leak."""
    arg = Leaf(2.0)
    expected = -math.sin(2.0)
    fn_val = cos(arg)
    slope = sparse_derivative([fn_val], [arg])[0][0].value
    self.assertEqual(slope, expected)
    del fn_val
    self.assertEqual(arg.node.ref_count, 1)
def test_hook2(self):
    """A plain() copy of exp(a) contributes no derivative w.r.t. a."""
    src = Leaf(1.0)
    expo = exp(src)
    detached = expo.plain()
    rows = Derivative([expo, detached], [src]).value
    # d/da exp(a) = exp(a); the plain copy has zero derivative.
    self.assertEqual(rows[0][0], expo.value)
    self.assertEqual(rows[1][0], 0.0)
def test_inv(self):
    """Sparse d/da inv(a) equals -1/a**2; no references leak."""
    arg = Leaf(2.0)
    fn_val = inv(arg)
    slope = sparse_derivative([fn_val], [arg])[0][0].value
    self.assertEqual(slope, -0.25)  # -1 / 2.0**2
    del fn_val
    self.assertEqual(arg.node.ref_count, 1)
def test_tanh(self):
    """Sparse d/da tanh(a) equals cosh(a)**-2; no references leak."""
    arg = Leaf(2.0)
    expected = math.cosh(2.0)**-2
    fn_val = tanh(arg)
    slope = sparse_derivative([fn_val], [arg])[0][0].value
    self.assertEqual(slope, expected)
    del fn_val
    self.assertEqual(arg.node.ref_count, 1)
def test_sqrt(self):
    """Sparse d/da sqrt(a) equals 1/(2*sqrt(a)); no references leak."""
    arg = Leaf(9.0)
    fn_val = sqrt(arg)
    slope = sparse_derivative([fn_val], [arg])[0][0].value
    self.assertEqual(slope, 1.0 / 6.0)  # 1 / (2 * 3)
    del fn_val
    self.assertEqual(arg.node.ref_count, 1)
def test_acos(self):
    """Sparse d/da acos(a) equals -1/sqrt(1 - a**2); no references leak."""
    arg = Leaf(0.6)
    expected = -(1 - 0.36)**-0.5
    fn_val = acos(arg)
    slope = sparse_derivative([fn_val], [arg])[0][0].value
    self.assertEqual(slope, expected)
    del fn_val
    self.assertEqual(arg.node.ref_count, 1)
def test_atan(self):
    """Sparse d/da atan(a) equals 1/(1 + a**2); no references leak."""
    arg = Leaf(0.6)
    expected = 1 / (1 + 0.6 * 0.6)
    fn_val = atan(arg)
    slope = sparse_derivative([fn_val], [arg])[0][0].value
    self.assertEqual(slope, expected)
    del fn_val
    self.assertEqual(arg.node.ref_count, 1)
def test_asin(self):
    """d/da asin(a) equals 1/sqrt(1 - a**2); no references leak."""
    arg = Leaf(0.6)
    expected = (1 - 0.36)**-0.5
    fn_val = asin(arg)
    slope = Derivative([fn_val], [arg])[0, 0].value
    self.assertEqual(slope, expected)
    del fn_val
    self.assertEqual(arg.node.ref_count, 1)
def test_log(self):
    """d/da log(a) equals 1/a; no references leak."""
    arg = Leaf(4.0)
    fn_val = log(arg)
    slope = Derivative([fn_val], [arg])[0, 0].value
    self.assertEqual(slope, 0.25)  # 1 / 4.0
    del fn_val
    self.assertEqual(arg.node.ref_count, 1)
def test_tan(self):
    """d/da tan(a) equals cos(a)**-2; no references leak."""
    arg = Leaf(2.0)
    expected = math.cos(2.0)**-2
    fn_val = tan(arg)
    slope = Derivative([fn_val], [arg])[0, 0].value
    self.assertEqual(slope, expected)
    del fn_val
    self.assertEqual(arg.node.ref_count, 1)
def test_neg(self):
    """Sparse d/da (-a) equals -1; no references leak."""
    arg = Leaf(13.0)
    negated = -arg
    slope = sparse_derivative([negated], [arg])[0][0].value
    self.assertEqual(slope, -1.0)
    del negated
    self.assertEqual(arg.node.ref_count, 1)
def test_nested(self):
    """simplify() on a long chain that repeatedly adds sin(a)."""
    N = 300  # not much more, stack would be exhausted
    leaf = Leaf(2.0)
    acc = leaf
    for _step in range(N):
        acc = acc + sin(leaf)
    # Each iteration builds its own sin(a) node; N - 1 are reported.
    dup_count = simplify(acc)
    self.assertEqual(dup_count, N - 1)
    self.assertEqual(leaf.node.ref_count, 3)
def test_cache(self):
    """1 / a maps to an inverse node; float operands in mixed ops are
    compared between a division and a subtraction."""
    arg = Leaf(3.0)
    recip = 1 / arg
    self.assertEqual(recip.node.tid, ID_INV)
    ratio = 12.0 / sin(arg)
    delta = 12.0 - sin(arg)
    # NOTE(review): this compares id() of the ``tid`` attributes, which for
    # small ints is effectively an equality check on the type ids, not an
    # identity check on the cached child nodes -- confirm the intent.
    first_tid = ratio.node.childs[0].tid
    second_tid = delta.node.childs[0].tid
    self.assertEqual(id(first_tid), id(second_tid))
def test_squ(self):
    """Derivatives of a * a and squ(a) agree; no references leak."""
    arg = Leaf(13.0)
    via_mul = arg * arg
    via_squ = squ(arg)
    rows = Derivative([via_mul, via_squ], [arg]).value
    # Both equal 2 * a = 26.
    self.assertListEqual(rows, [[26.0], [26.0]])
    del via_mul, via_squ
    self.assertEqual(arg.node.ref_count, 1)
def test_squ(self):
    """Sparse derivatives of a * a and squ(a) agree; no references leak."""
    a = Leaf(13.0)
    c = a * a
    d = squ(a)
    res = sparse_derivative([c, d], [a])
    # As in the other sparse tests, each entry's ``.value`` is a scalar, so
    # the collected result is a flat list of floats.  The previous nested
    # expectation [[26.0], [26.0]] could never equal it.
    res = [res[0][0].value, res[1][0].value]
    self.assertListEqual(res, [26.0, 26.0])  # 2 * a for both
    del c, d
    self.assertEqual(a.node.ref_count, 1)
def test_empanadina(self):
    """empanadina wraps a user function returning (value, derivative)."""
    def func(x):
        # x**6 and its derivative 6 * x**5
        return x**6, 6 * x**5

    leaf = Leaf(2.0)
    root = sqrt(leaf)
    wrapped = empanadina(func, root)
    # sqrt(a)**6 = a**3 -> 8 at a = 2; its derivative 3 * a**2 -> 12.
    self.assertAlmostEqual(wrapped.value, 8.0)
    self.assertAlmostEqual(Derivative([wrapped], [leaf])[0, 0].value, 12.0)
def test_nested(self):
    """SaltArray over a nested container evaluates, re-evaluates after
    invalidate(), and recalculates via recalc()."""
    p, q = Leaf(3.0), Leaf(2.0)
    raw = [
        p * q,
        (p * q, p * q),
        {p * q, p * q, p * q},
        ([p * q, p * q], (p * q, p * q)),
    ]
    arr = SaltArray(raw)

    got = arr.value
    expected = [6.0, [6.0, 6.0], [6.0, 6.0, 6.0], [[6.0, 6.0], [6.0, 6.0]]]
    self.assertListEqual(expected, got)

    arr.invalidate()
    p.value = 4.0
    expected = [8.0, [8.0, 8.0], [8.0, 8.0, 8.0], [[8.0, 8.0], [8.0, 8.0]]]
    got = arr.value
    self.assertListEqual(expected, got)

    p.value = 5.0
    got = arr.recalc()
    expected = [10.0, [10.0, 10.0], [10.0, 10.0, 10.0],
                [[10.0, 10.0], [10.0, 10.0]]]
    self.assertListEqual(expected, got)
def test_float_mix_pow(self):
    """Mixing Salt and Python numbers in ``**``, with and without
    ALLOW_MIX_FLOAT.

    Salt.FLOAT_CACHE_MAX and Salt.ALLOW_MIX_FLOAT are class attributes shared
    with every other test.  Their previous values are restored through
    addCleanup, so this test cannot change the behavior of tests that run
    after it, regardless of its outcome.
    """
    self.addCleanup(setattr, Salt, "FLOAT_CACHE_MAX", Salt.FLOAT_CACHE_MAX)
    self.addCleanup(setattr, Salt, "ALLOW_MIX_FLOAT", Salt.ALLOW_MIX_FLOAT)

    Salt.FLOAT_CACHE_MAX = 0  # no caching here
    a = Leaf(3.0)
    b = a ** 4
    self.assertEqual(b.value, 81.0)
    b = 4.0 ** a
    self.assertEqual(b.value, 64.0)
    with self.assertRaises(ValueError):
        b = a ** "Hello World!"

    Salt.ALLOW_MIX_FLOAT = False
    with self.assertRaises(AttributeError):
        b = a ** 4.0
    with self.assertRaises(AttributeError):
        b = 4 ** a