예제 #1
0
def _apply_gradient_steps(output, target, a, b):
    """Add the scaled gradient of ``output`` to ``target`` five times.

    On every step the gradient of ``output`` with respect to ``target`` is
    requested from ``tj.gradients``, the current state is logged, and
    ``target`` is updated in place with ``target.v + gradient * 0.1``.

    Args:
        output: node whose gradient is requested.
        target: variable to update; expected to be ``a`` or ``b``.
        a: first operand, only used for logging.
        b: second operand, only used for logging.
    """
    for i in range(5):
        g = tj.gradients(output, [target])
        LOGGER.info("%s *** a: %s -- b: %s -- g: %s -- dt: 0.1" % (i, a, b, g))

        # The gradient is *added*, so the target moves in the direction of
        # the gradient's sign.
        target.update(target.v + g[0] * 1e-1)


def test_1d_simples():
    """Test trivial gradient updates for add, sub, mul and div.

    Each section builds ``a = [5, 5, 5]`` and ``b = [10, 10, 10]``, combines
    them with one binary operation, applies five gradient-following updates
    to one operand and asserts that the operand moved in the expected
    direction.
    """
    LOGGER.info("Testing Addition.")
    a = tj.var(np.ones(3) * 5)
    b = tj.var(np.ones(3) * 10)

    c = tj.add(a, b)
    _apply_gradient_steps(c, a, a, b)

    assert _true(a.v > 5.0), "A should be larger than 5 but is %s" % a

    LOGGER.info("Testing Subtraction.")
    a = tj.var(np.ones(3) * 5)
    b = tj.var(np.ones(3) * 10)

    c = tj.sub(a, b)
    _apply_gradient_steps(c, b, a, b)

    # The check is on ``b``, so the failure message must report ``b``.
    assert _true(b.v < 10.0), "B should be smaller than 10 but is %s" % b

    LOGGER.info("Testing Multiplication.")
    a = tj.var(np.ones(3) * 5)
    b = tj.var(np.ones(3) * 10)

    c = tj.mul(a, b)
    _apply_gradient_steps(c, a, a, b)

    assert _true(a.v > 5.0), "A should be larger than 5 but is %s" % a

    LOGGER.info("Testing Division.")
    a = tj.var(np.ones(3) * 5)
    b = tj.var(np.ones(3) * 10)

    c = tj.div(a, b)
    _apply_gradient_steps(c, b, a, b)

    # The check is on ``b``, so the failure message must report ``b``.
    assert _true(b.v < 10.0), "B should be smaller than 10 but is %s" % b
예제 #2
0
def test_convex():
    """Check that gradient steps shrink the input of two convex objectives.

    Both objectives are run for vector sizes 1 and 3, starting from all
    ones, with 100 updates that subtract the scaled gradient. Afterwards
    every component is expected to have absolute value below 1.
    """
    LOGGER.info("Testing simple convex: x^2")
    for size in (1, 3):
        x = tj.var(np.ones(size))
        objective = tj.mul(x, x)

        LOGGER.info("Initial a: %s" % x)

        for _ in range(100):
            grads = tj.gradients(objective, [x])

            x.update(x.v - grads[0] * 1e-1)

        LOGGER.info("Final a: %s" % x)

        magnitude = abs(x.v)
        assert _true(magnitude < 1.0), (
            "A should be smaller than 1 but is %s" % x)

    LOGGER.info("Testing more complex convex: (x * 5 + 3 - x)^2")

    for size in (1, 3):
        x = tj.var(np.ones(size))
        factor = tj.var(5)

        # Build (x * 5 + 3 - x)^2 one operation at a time.
        scaled = tj.mul(x, factor)
        shifted = tj.add(scaled, 3)
        residual = tj.sub(shifted, x)
        objective = tj.mul(residual, residual)

        LOGGER.info("Initial a: %s" % x)

        for _ in range(100):
            grads = tj.gradients(objective, [x])

            # Smaller step than the x^2 case above.
            x.update(x.v - grads[0] * 1e-2)

        LOGGER.info("Final a: %s" % x.v)

        magnitude = abs(x.v)
        assert _true(magnitude < 1.0), (
            "A should be smaller than 1 but is %s" % x)
예제 #3
0
def test_adding_monoids():
    """Verify that operations and variables are registered on the graph.

    The test builds two small graphs, checking after each stage how many
    nodes and variables ``tj.tjgraph`` reports and that explicitly named
    operations show up among the nodes but not among the variables. It also
    checks that ``tj.tjgraph.clear()`` leaves the graph empty.
    """
    tj.tjgraph.clear()

    LOGGER.info("Checking right number of nodes.")
    lhs = tj.var(np.random.rand(1, 5))
    rhs = tj.var(np.random.rand(1, 5))

    total = tj.add(lhs, rhs, name="first-addition")

    graph_nodes = tj.tjgraph.get_nodes()
    graph_vars = tj.tjgraph.get_variables()

    # Two variables plus one operation.
    assert len(graph_nodes) == 3, (
        "Nodes should be 3 is %s" % len(graph_nodes))
    assert len(graph_vars) == 2, (
        "Vars should be 2 is %s" % len(graph_vars))

    LOGGER.info("Checking names.")
    var_names = [variable.name for variable in graph_vars]
    node_names = [node.name for node in graph_nodes]
    assert "first-addition" in node_names, (
        "first-addition should be in %s" % node_names)

    assert "first-addition" not in var_names, (
        "first-addition should not be in %s" % var_names)

    product = tj.mul(total, rhs, name="first-multiplication")

    graph_nodes = tj.tjgraph.get_nodes()
    node_names = [node.name for node in graph_nodes]

    LOGGER.info("Checking more right number of nodes.")
    assert len(graph_nodes) == 4, (
        "Nodes should be 4 is %s" % len(graph_nodes))
    assert "first-addition" in node_names, (
        "first-addition should be in %s" % node_names)

    tj.tjgraph.clear()

    graph_nodes = tj.tjgraph.get_nodes()
    graph_vars = tj.tjgraph.get_variables()

    LOGGER.info("Testing clearing")
    assert len(graph_nodes) == 0, (
        "Nodes should be 0 is %s" % len(graph_nodes))
    assert len(graph_vars) == 0, (
        "Vars should be 0 is %s" % len(graph_vars))

    lhs = tj.var(np.random.rand(1, 5))
    rhs = tj.var(np.random.rand(1, 5))

    # Seven chained operations on top of the two variables.
    running = tj.add(lhs, rhs)
    running = tj.add(running, lhs)
    running = tj.add(running, rhs)
    cookie = tj.mul(running, rhs, name="cookie")
    difference = tj.sub(running, cookie)
    milk = tj.div(cookie, difference, name="milk")
    doubled = tj.add(milk, milk)

    graph_nodes = tj.tjgraph.get_nodes()
    graph_vars = tj.tjgraph.get_variables()

    var_names = [variable.name for variable in graph_vars]
    node_names = [node.name for node in graph_nodes]

    LOGGER.info("Test adding a lot of monoids")
    assert len(graph_nodes) == 9, (
        "Nodes should be 9 is %s" % len(graph_nodes))
    assert len(graph_vars) == 2, (
        "Vars should be 2 is %s" % len(graph_vars))

    LOGGER.info("Checking names")
    assert "cookie" in node_names, (
        "cookie should be in %s" % node_names)

    assert "cookie" not in var_names, (
        "cookie should not be in %s" % var_names)

    assert "milk" in node_names, (
        "milk should be in %s" % node_names)

    assert "milk" not in var_names, (
        "milk should not be in %s" % var_names)

    tj.tjgraph.clear()

    graph_nodes = tj.tjgraph.get_nodes()
    graph_vars = tj.tjgraph.get_variables()

    LOGGER.info("Testing clearing")
    assert len(graph_nodes) == 0, (
        "Nodes should be 0 is %s" % len(graph_nodes))
    assert len(graph_vars) == 0, (
        "Vars should be 0 is %s" % len(graph_vars))