Ejemplo n.º 1
0
def test_log_grad():
    """The derivative of log(x) evaluated at x = 2 should equal 1/x = 0.5."""
    x = ad.variable(2)
    log_x = ad.log(x)
    grad_node = log_x.grad(x)

    with ad.Session() as session:
        result = session.run(grad_node)

    expected = 1 / 2
    assert result.value == expected
Ejemplo n.º 2
0
def test_reciprocal_grad():
    """d(1/x)/dx = -1/x**2, so the gradient at x = 2 should be -0.25."""
    x = ad.variable(2)
    inv_x = ad.reciprocal(x)
    grad_node = inv_x.grad(x)

    with ad.Session() as session:
        result = session.run(grad_node)

    expected = -1 / 4
    assert result.value == expected
Ejemplo n.º 3
0
def test_session_neg():
    """Negating the constant 1 and running it in a session should yield -1."""
    one = ad.constant(1, 'a')
    assert one.value == 1

    expected = ad.constant(-1, 'c')
    negated = ad.neg(one)  # only constructs the graph node; evaluated below

    with ad.Session() as session:
        result = session.run(negated)

    assert result.value == expected.value
Ejemplo n.º 4
0
def test_session_reciprocal():
    """Evaluating reciprocal(3) in a session should give the same value as 1 / 3."""
    three = ad.constant(3, 'a')
    assert three.value == 3

    expected = ad.constant(1 / 3, 'c')
    inverse = ad.reciprocal(three)  # only constructs the graph node

    with ad.Session() as session:
        result = session.run(inverse)

    assert result.value == expected.value
Ejemplo n.º 5
0
def test_session_log():
    """Evaluating log(3) in a session should match ``math.log(3)``.

    The comparison uses ``math.isclose`` instead of ``==``: log(3) is an
    irrational value, and an implementation that computes it by a different
    route than ``math.log`` may legitimately differ in the last ulp.
    """
    a = ad.constant(3, 'a')

    assert a.value == 3

    c_true = ad.constant(math.log(3), 'c')
    c = ad.log(a)  # build graph

    with ad.Session() as sess:
        # eval graph
        c_out = sess.run(c)

    # Tolerant float comparison; exact equality is rounding-dependent.
    assert math.isclose(c_out.value, c_true.value)
Ejemplo n.º 6
0
def test_pow_grad():
    """Check both partial derivatives of ``c = a ** b`` at a = 2, b = 3.

    Expected values:
      * dc/da = b * a**(b - 1) = 3 * 2**2 = 12, compared exactly.
      * dc/db = a**b * ln(a) = 8 * ln(2), compared with ``math.isclose``.
        The second value involves a logarithm. Its last-bit rounding depends
        on how the gradient expression is evaluated, so exact ``==`` would
        be brittle.
    """
    a = ad.variable(2)
    b = ad.variable(3)
    c = ad.pow(a, b)

    dc_da = c.grad(a)
    dc_db = c.grad(b)

    with ad.Session() as sess:
        dc_da_out = sess.run(dc_da)
        dc_db_out = sess.run(dc_db)

    assert dc_da_out.value == 3 * 2**2
    # Tolerant comparison: the result depends on float rounding of ln(2).
    assert math.isclose(dc_db_out.value, 2**3 * math.log(2))
Ejemplo n.º 7
0
def test_add_grad():
    """Both partial derivatives of x + y should be 1."""
    x = ad.variable(1)
    y = ad.variable(2)
    total = ad.add(x, y)

    grad_nodes = [total.grad(x), total.grad(y)]

    with ad.Session() as session:
        outputs = [session.run(node) for node in grad_nodes]

    for out in outputs:
        assert out.value == 1
Ejemplo n.º 8
0
def test_session_pow():
    """Raising 2 to the power 3 inside a session should produce 8."""
    base = ad.constant(2, 'a')
    exponent = ad.constant(3, 'b')

    for node, wanted in ((base, 2), (exponent, 3)):
        assert node.value == wanted

    expected = ad.constant(2**3, 'c')
    power = ad.pow(base, exponent)  # only constructs the graph node

    with ad.Session() as session:
        result = session.run(power)

    assert result.value == expected.value
Ejemplo n.º 9
0
def test_session_div():
    """Dividing 1 by 2 inside a session should produce 0.5."""
    numerator = ad.constant(1, 'a')
    denominator = ad.constant(2, 'b')

    for node, wanted in ((numerator, 1), (denominator, 2)):
        assert node.value == wanted

    expected = ad.constant(1 / 2, 'c')
    quotient = ad.div(numerator, denominator)  # only constructs the graph node

    with ad.Session() as session:
        result = session.run(quotient)

    assert result.value == expected.value
Ejemplo n.º 10
0
def test_session_graph():
    """A session built with an explicit graph should expose a non-None ``graph``."""
    explicit_graph = ad.Graph()
    session = ad.Session(explicit_graph)
    attached = session.graph
    assert attached is not None
Ejemplo n.º 11
0
def test_session_default_graph():
    """A session created with no arguments should still have ``_graph_set`` populated."""
    session = ad.Session()
    graph_set = session._graph_set
    assert graph_set is not None