def test_1d_simples():
    """Test trivial gradient updates.

    For each elementwise binary op (add, sub, mul, div) on two 3-vectors
    ``a = 5`` and ``b = 10``, take five steps of
    ``x <- x + 0.1 * gradient(c, x)`` on one operand. Then check that the
    operand moved the way the sign of the derivative implies:

    * add: ``a`` is expected to grow.
    * sub: ``b`` is expected to shrink.
    * mul: ``a`` is expected to grow.
    * div: ``b`` is expected to shrink.
    """
    LOGGER.info("Testing Addition.")
    a = tj.var(np.ones(3) * 5)
    b = tj.var(np.ones(3) * 10)
    c = tj.add(a, b)
    for i in range(5):
        g = tj.gradients(c, [a])
        LOGGER.info("%s *** a: %s -- b: %s -- g: %s -- dt: 0.1" % (i, a, b, g))
        a.update(a.v + g[0] * 1e-1)
    assert _true(a.v > 5.0), "A should be larger than 5 but is %s" % a

    LOGGER.info("Testing Subtraction.")
    a = tj.var(np.ones(3) * 5)
    b = tj.var(np.ones(3) * 10)
    c = tj.sub(a, b)
    for i in range(5):
        g = tj.gradients(c, [b])
        LOGGER.info("%s *** a: %s -- b: %s -- g: %s -- dt: 0.1" % (i, a, b, g))
        b.update(b.v + g[0] * 1e-1)
    # The assertion is about ``b``, so the failure message must name and
    # print ``b`` (it previously reported ``a``).
    assert _true(b.v < 10.0), "B should be smaller than 10 but is %s" % b

    LOGGER.info("Testing Multiplication.")
    a = tj.var(np.ones(3) * 5)
    b = tj.var(np.ones(3) * 10)
    c = tj.mul(a, b)
    for i in range(5):
        g = tj.gradients(c, [a])
        LOGGER.info("%s *** a: %s -- b: %s -- g: %s -- dt: 0.1" % (i, a, b, g))
        a.update(a.v + g[0] * 1e-1)
    assert _true(a.v > 5.0), "A should be larger than 5 but is %s" % a

    LOGGER.info("Testing Division.")
    a = tj.var(np.ones(3) * 5)
    b = tj.var(np.ones(3) * 10)
    c = tj.div(a, b)
    for i in range(5):
        g = tj.gradients(c, [b])
        LOGGER.info("%s *** a: %s -- b: %s -- g: %s -- dt: 0.1" % (i, a, b, g))
        b.update(b.v + g[0] * 1e-1)
    # Report the operand actually under test (previously printed ``a``).
    assert _true(b.v < 10.0), "B should be smaller than 10 but is %s" % b
def test_convex():
    """Test optimising a simple convex function.

    Runs plain gradient descent on two convex objectives, for input sizes
    1 and 3, starting from ones:

    * ``x^2``, with 100 steps of size 0.1.
    * ``(x * 5 + 3 - x)^2``, with 100 steps of size 0.01.

    After each run, every element of ``x`` is asserted to satisfy
    ``abs(x) < 1``.
    """
    LOGGER.info("Testing simple convex: x^2")
    for size in [1, 3]:
        x = tj.var(np.ones(size))
        square = tj.mul(x, x)
        LOGGER.info("Initial a: %s" % x)
        for _ in range(100):
            grad = tj.gradients(square, [x])
            x.update(x.v - grad[0] * 1e-1)
        LOGGER.info("Final a: %s" % x)
        assert _true(
            abs(x.v) < 1.0), "A should be smaller than 1 but is %s" % x

    LOGGER.info("Testing more complex convex: (x * 5 + 3 - x)^2")
    for size in [1, 3]:
        x = tj.var(np.ones(size))
        five = tj.var(5)
        # Build x * 5 + 3 - x innermost-first, then square it.
        inner = tj.sub(tj.add(tj.mul(x, five), 3), x)
        loss = tj.mul(inner, inner)
        LOGGER.info("Initial a: %s" % x)
        for _ in range(100):
            grad = tj.gradients(loss, [x])
            x.update(x.v - grad[0] * 1e-2)
        LOGGER.info("Final a: %s" % x.v)
        assert _true(
            abs(x.v) < 1.0), "A should be smaller than 1 but is %s" % x
def test_adding_monoids():
    """See if monoids are added to the graph.

    Checks that:

    * building ``tj`` variables and operations grows ``tj.tjgraph`` by the
      expected number of nodes;
    * custom ``name=`` values appear among the node names but not among the
      variable names;
    * ``tj.tjgraph.clear()`` empties both the node and variable lists.
    """

    def _assert_cleared():
        # Clearing must drop every node and every variable from the graph.
        tj.tjgraph.clear()
        nodes = tj.tjgraph.get_nodes()
        variables = tj.tjgraph.get_variables()
        LOGGER.info("Testing clearing")
        assert len(nodes) == 0, "Nodes should be 0 is %s" % len(nodes)
        assert len(variables) == 0, "Vars should be 0 is %s" % len(variables)

    # Start from an empty graph so counts are not polluted by other tests.
    tj.tjgraph.clear()
    LOGGER.info("Checking right number of nodes.")
    a = tj.var(np.random.rand(1, 5))
    b = tj.var(np.random.rand(1, 5))
    c = tj.add(a, b, name="first-addition")
    # ``variables`` avoids shadowing the builtin ``vars``.
    nodes = tj.tjgraph.get_nodes()
    variables = tj.tjgraph.get_variables()
    # The expected counts imply get_nodes() includes the two variables plus
    # the addition.
    assert len(nodes) == 3, "Nodes should be 3 is %s" % len(nodes)
    assert len(variables) == 2, "Vars should be 2 is %s" % len(variables)

    LOGGER.info("Checking names.")
    var_names = [v.name for v in variables]
    node_names = [n.name for n in nodes]
    assert "first-addition" in node_names,\
        "first-addition should be in %s" % node_names
    assert "first-addition" not in var_names,\
        "first-addition should not be in %s" % var_names

    d = tj.mul(c, b, name="first-multiplication")
    nodes = tj.tjgraph.get_nodes()
    node_names = [n.name for n in nodes]
    LOGGER.info("Checking more right number of nodes.")
    assert len(nodes) == 4, "Nodes should be 4 is %s" % len(nodes)
    assert "first-addition" in node_names,\
        "first-addition should be in %s" % node_names

    _assert_cleared()

    a = tj.var(np.random.rand(1, 5))
    b = tj.var(np.random.rand(1, 5))
    c = tj.add(a, b)
    c = tj.add(c, a)
    c = tj.add(c, b)
    d = tj.mul(c, b, name="cookie")
    e = tj.sub(c, d)
    f = tj.div(d, e, name="milk")
    f = tj.add(f, f)
    nodes = tj.tjgraph.get_nodes()
    variables = tj.tjgraph.get_variables()
    var_names = [v.name for v in variables]
    node_names = [n.name for n in nodes]
    LOGGER.info("Test adding a lot of monoids")
    # 2 variables + 7 operations (add, add, add, mul, sub, div, add).
    assert len(nodes) == 9, "Nodes should be 9 is %s" % len(nodes)
    assert len(variables) == 2, "Vars should be 2 is %s" % len(variables)

    LOGGER.info("Checking names")
    assert "cookie" in node_names,\
        "cookie should be in %s" % node_names
    assert "cookie" not in var_names,\
        "cookie should not be in %s" % var_names
    assert "milk" in node_names,\
        "milk should be in %s" % node_names
    assert "milk" not in var_names,\
        "milk should not be in %s" % var_names

    _assert_cleared()