コード例 #1
0
ファイル: test_session.py プロジェクト: tor4z/AD_OO
def test_session_neg():
    """Negating the constant 1 and running it in a Session must give -1."""
    operand = ad.constant(1, 'a')

    assert operand.value == 1
    expected = ad.constant(-1, 'c')
    negated = ad.neg(operand)  # build the graph

    with ad.Session() as session:
        result = session.run(negated)  # evaluate the graph

    assert result.value == expected.value
コード例 #2
0
ファイル: test_session.py プロジェクト: tor4z/AD_OO
def test_session_reciprocal():
    """Evaluate ad.reciprocal(3) in a Session and compare against 1 / 3.

    1/3 is not exactly representable in binary floating point, and the
    expected value is computed independently of ad.reciprocal, so the final
    check uses a relative tolerance instead of exact float equality.
    """
    import math  # local import keeps the tolerance check self-contained

    a = ad.constant(3, 'a')
    assert a.value == 3

    c_true = ad.constant(1 / 3, 'c')
    c = ad.reciprocal(a)  # build graph

    with ad.Session() as sess:
        # eval graph
        c_out = sess.run(c)

    assert math.isclose(c_out.value, c_true.value)
コード例 #3
0
ファイル: test_session.py プロジェクト: tor4z/AD_OO
def test_session_log():
    """Evaluate ad.log(3) in a Session and compare against math.log(3).

    The expected value comes from math.log, while the computed value comes
    from whatever implementation ad.log uses internally. Two log
    implementations can legitimately differ in the last bit, so the final
    check uses math.isclose rather than exact float equality.
    """
    a = ad.constant(3, 'a')

    assert a.value == 3

    c_true = ad.constant(math.log(3), 'c')
    c = ad.log(a)  # build graph

    with ad.Session() as sess:
        # eval graph
        c_out = sess.run(c)

    assert math.isclose(c_out.value, c_true.value)
コード例 #4
0
ファイル: test_session.py プロジェクト: tor4z/AD_OO
def test_session_pow():
    """Raising 2 to the 3rd power through a Session must give 2**3."""
    base = ad.constant(2, 'a')
    exponent = ad.constant(3, 'b')

    assert base.value == 2
    assert exponent.value == 3

    expected = ad.constant(2**3, 'c')
    power = ad.pow(base, exponent)  # build the graph

    with ad.Session() as session:
        result = session.run(power)  # evaluate the graph

    assert result.value == expected.value
コード例 #5
0
ファイル: test_session.py プロジェクト: tor4z/AD_OO
def test_session_div():
    """Dividing 1 by 2 through a Session must give 1 / 2."""
    numerator = ad.constant(1, 'a')
    denominator = ad.constant(2, 'b')

    assert numerator.value == 1
    assert denominator.value == 2

    expected = ad.constant(1 / 2, 'c')
    quotient = ad.div(numerator, denominator)  # build the graph

    with ad.Session() as session:
        computed = session.run(quotient)  # evaluate the graph

    assert computed.value == expected.value
コード例 #6
0
ファイル: test_ops.py プロジェクト: tor4z/AD_OO
def test_constant():
    """ad.constant must keep the given value and name for scalars and arrays."""
    # Scalar inputs: an int first, then a float.
    for scalar in (1, 0.5):
        node = ad.constant(scalar, 'a')
        assert node.value == scalar
        assert node.name == 'a'

    # NumPy array input: compare element-wise.
    arr = np.random.rand(5)
    node = ad.constant(arr, 'a')
    assert np.array_equal(node.value, arr)
    assert node.name == 'a'
コード例 #7
0
ファイル: test_ops.py プロジェクト: tor4z/AD_OO
def test_trainable_parameters():
    """Only the three ad.variable nodes should be reported as trainable.

    Graph: ((1 * v2) + (1 * v3) + (1 * v4)) * 5, where 1 and 5 are constants
    and v2, v3, v4 are variables holding 2, 3 and 4.
    """
    one = ad.constant(1)
    var_a = ad.variable(2)
    var_b = ad.variable(3)
    var_c = ad.variable(4)
    five = ad.constant(5)

    scaled = [ad.mul(one, var) for var in (var_a, var_b, var_c)]

    total = ad.add(*scaled)
    output = ad.mul(total, five)
    params = output.trainable_parameters()

    assert len(params) == 3
    assert all(param.value in [2, 3, 4] for param in params)
コード例 #8
0
def test_trainable_parameters_grads():
    """Check the forward value and the gradients of a small graph.

    Graph: final = (c1*v2 + c1*v3 + c1*v4) * c5 with constants c1 = 1 and
    c5 = 5, and variables v2 = 2, v3 = 3, v4 = 4.

    * Forward value: (2 + 3 + 4) * 5 = 45.
    * Gradient w.r.t. each variable: c1 * c5 = 1 * 5 = 5.
    """
    node1 = ad.constant(1)
    node2 = ad.variable(2)
    node3 = ad.variable(3)
    node4 = ad.variable(4)
    node5 = ad.constant(5)

    node1_node2 = ad.mul(node1, node2)
    node1_node3 = ad.mul(node1, node3)
    node1_node4 = ad.mul(node1, node4)

    node_sum = ad.add(node1_node2, node1_node3, node1_node4)
    final_node = ad.mul(node_sum, node5)
    trainable_nodes = final_node.trainable_parameters()
    # Guard against the gradient loop below passing vacuously on empty lists.
    assert len(trainable_nodes) == 3

    with ad.Session() as sess:
        out = sess.run(final_node)
    assert out.value == 45

    # test grad
    trainable_node_grads = final_node.grad(trainable_nodes)
    assert len(trainable_node_grads) == len(trainable_nodes)

    with ad.Session() as sess:
        grad_values = list(sess.run(trainable_node_grads))
    assert len(grad_values) == len(trainable_nodes)

    # Each variable is multiplied by node1 (1), and the sum is multiplied by
    # node5 (5), so every partial derivative equals 1 * 5 = 5.
    for node, grad in zip(trainable_nodes, grad_values):
        assert node.value in [2, 3, 4], 'Not a trainable node'
        assert grad.value == 5