示例#1
0
def test_log_grad():
    """Gradient of log(x) at x = 2 should evaluate to 1/x = 0.5."""
    x = ad.variable(2)
    y = ad.log(x)
    # Gradient node for dy/dx; it is only evaluated by the session below.
    grad_node = y.grad(x)

    with ad.Session() as session:
        result = session.run(grad_node)

    expected = 0.5  # d/dx log(x) = 1/x, evaluated at x = 2
    assert result.value == expected
示例#2
0
def test_reciprocal_grad():
    """Gradient of 1/x at x = 2 should evaluate to -1/x**2 = -0.25."""
    x = ad.variable(2)
    y = ad.reciprocal(x)
    # Gradient node for dy/dx; it is only evaluated by the session below.
    grad_node = y.grad(x)

    with ad.Session() as session:
        result = session.run(grad_node)

    expected = -0.25  # d/dx (1/x) = -1/x**2, evaluated at x = 2
    assert result.value == expected
示例#3
0
def test_session_neg():
    """Evaluating neg(1) in a session should match the constant -1."""
    operand = ad.constant(1, 'a')
    assert operand.value == 1

    expected = ad.constant(-1, 'c')
    negated = ad.neg(operand)  # graph construction only

    with ad.Session() as session:
        result = session.run(negated)  # graph evaluation

    assert result.value == expected.value
示例#4
0
def test_session_reciprocal():
    """Evaluating reciprocal(3) in a session should match the constant 1/3.

    1/3 is not exactly representable in binary floating point. An
    implementation that computes the reciprocal differently (e.g. as
    ``x ** -1``) may differ in the last bit, so the comparison uses
    ``math.isclose`` instead of exact equality.
    """
    import math  # local import keeps this test self-contained

    a = ad.constant(3, 'a')
    assert a.value == 3

    c_true = ad.constant(1 / 3, 'c')
    c = ad.reciprocal(a)  # build graph

    with ad.Session() as sess:
        # eval graph
        c_out = sess.run(c)

    assert math.isclose(c_out.value, c_true.value)
示例#5
0
def test_session_log():
    """Evaluating log(3) in a session should match ``math.log(3)``.

    The backend's logarithm may not be the same routine as ``math.log``
    (e.g. a numpy ufunc), and results can differ by an ulp. The comparison
    therefore uses ``math.isclose`` rather than exact equality.
    """
    import math  # already used by this test; local import keeps it self-contained

    a = ad.constant(3, 'a')

    assert a.value == 3

    c_true = ad.constant(math.log(3), 'c')
    c = ad.log(a)  # build graph

    with ad.Session() as sess:
        # eval graph
        c_out = sess.run(c)

    assert math.isclose(c_out.value, c_true.value)
示例#6
0
def test_pow_grad():
    """Check both partial derivatives of a**b at a = 2, b = 3.

    d/da a**b = b * a**(b-1) = 12, which is an exact integer and is compared
    exactly. d/db a**b = a**b * ln(a) = 8 * ln(2) is irrational. The backend
    may form that product in a different order or via different routines,
    so it is compared with ``math.isclose``.
    """
    import math  # already used by this test; local import keeps it self-contained

    a = ad.variable(2)
    b = ad.variable(3)
    c = ad.pow(a, b)

    dc_da = c.grad(a)
    dc_db = c.grad(b)

    with ad.Session() as sess:
        dc_da_out = sess.run(dc_da)
        dc_db_out = sess.run(dc_db)

    assert dc_da_out.value == 3 * 2**2
    assert math.isclose(dc_db_out.value, 2**3 * math.log(2))
示例#7
0
def test_add_grad():
    """Each operand of add(x, y) should receive a gradient of exactly 1."""
    x = ad.variable(1)
    y = ad.variable(2)
    total = ad.add(x, y)

    # Gradient nodes w.r.t. x then y, in that order.
    grad_nodes = [total.grad(x), total.grad(y)]

    with ad.Session() as session:
        results = [session.run(node) for node in grad_nodes]

    for res in results:
        assert res.value == 1
示例#8
0
def test_session_pow():
    """Evaluating pow(2, 3) in a session should match the constant 8."""
    base = ad.constant(2, 'a')
    exponent = ad.constant(3, 'b')

    assert base.value == 2
    assert exponent.value == 3

    expected = ad.constant(8, 'c')  # 2 ** 3
    power = ad.pow(base, exponent)  # graph construction only

    with ad.Session() as session:
        result = session.run(power)  # graph evaluation

    assert result.value == expected.value
示例#9
0
def test_session_div():
    """Evaluating div(1, 2) in a session should match the constant 0.5."""
    numerator = ad.constant(1, 'a')
    denominator = ad.constant(2, 'b')

    assert numerator.value == 1
    assert denominator.value == 2

    expected = ad.constant(0.5, 'c')  # 1 / 2, exactly representable
    quotient = ad.div(numerator, denominator)  # graph construction only

    with ad.Session() as session:
        result = session.run(quotient)  # graph evaluation

    assert result.value == expected.value
示例#10
0
def test_session_graph():
    """A Session constructed with an explicit Graph should have a non-None ``graph``."""
    # NOTE(review): the session is not used as a context manager here, unlike
    # the other tests; confirm whether it needs explicit closing.
    session = ad.Session(ad.Graph())
    assert session.graph is not None
示例#11
0
def test_session_default_graph():
    """A Session built with no arguments should have a non-None ``_graph_set``."""
    session = ad.Session()
    # NOTE(review): `_graph_set` is a private attribute. If it is a boolean
    # flag, `is not None` also passes when it is False. Confirm the intent,
    # and consider checking the public `graph` attribute as test_session_graph does.
    assert session._graph_set is not None