# Example #1
0
def test_eq():
    """Var equality compares by value, both against Var and against float."""
    first, second = ad.Var(3), ad.Var(3)
    plain_a = plain_b = 3.

    # Equivalent to the chained form `first == second == plain_a == plain_b`.
    assert first == second and second == plain_a and plain_a == plain_b
# Example #2
0
def test_cmp():
    """Ordering comparisons work between Vars and between a Var and an int."""
    lo, hi = ad.Var(1), ad.Var(2)
    limit = 3

    # Var vs Var, in both directions.
    assert lo < hi
    assert hi > lo

    # Var vs int; `limit > v` has the int on the left, so Python falls back
    # to the Var's reflected comparison.
    for v in (lo, hi):
        assert v < limit
        assert limit > v
# Example #3
0
def test_topo1():
    """Backprop through y3 = (x1 * x2) / (x2 + x3), where x2 feeds two paths.

    At x = (1, 2, 3):
        dy3/dx1 = x2 / (x2 + x3)           = 2/5  = 0.4
        dy3/dx2 = x1 * x3 / (x2 + x3)**2   = 3/25 = 0.12
        dy3/dx3 = -x1 * x2 / (x2 + x3)**2  = -2/25 = -0.08
    """
    x1, x2, x3 = ad.Var(1), ad.Var(2), ad.Var(3)
    prod = x1 * x2
    total = x2 + x3
    ratio = prod / total
    ratio.backward()

    # Once backward() returns, no node's engine should still report working.
    graph = [x1, x2, x3, prod, total, ratio]
    assert all(not node._engine.is_working() for node in graph)

    for leaf, expected in zip((x1, x2, x3), (0.4, 0.12, -0.08)):
        assert math.isclose(leaf.grad, expected)
# Example #4
0
def test_neg():
    """Unary negation propagates a gradient of -1, alone and inside a graph."""
    # Standalone: d(-x)/dx == -1.
    x = ad.Var(2)
    negated = -x
    negated.backward()
    assert x.grad == -1

    # Inside a larger expression: d(a + (-b)) gives (1, -1).
    a, b = ad.Var(1), ad.Var(2)
    combined = a + (-b)
    combined.backward()
    assert (a.grad, b.grad) == (1, -1)
# Example #5
0
def test_add():
    """Addition passes a unit gradient to every Var operand.

    Covers Var + Var, Var + int, and int + Var (reflected operand order).
    """
    # Var + Var
    lhs, rhs = ad.Var(1), ad.Var(2)
    total = lhs + rhs
    total.backward()
    assert lhs.grad == 1
    assert rhs.grad == 1

    # Var + int
    lhs = ad.Var(1)
    total = lhs + 2
    total.backward()
    assert lhs.grad == 1

    # int + Var
    rhs = ad.Var(2)
    total = 1 + rhs
    total.backward()
    assert rhs.grad == 1
# Example #6
0
def test_truediv():
    """Check true-division gradients for Var/Var and mixed Var/int operands.

    For z = a / b: dz/da = 1 / b and dz/db = -a / b**2. At a = 1, b = 2 these
    are 0.5 and -0.25. Both are powers of two, so the exact float
    comparisons below are safe.
    """
    # div two vars
    x1, x2 = ad.Var(1), ad.Var(2)
    x3 = x1 / x2
    x3.backward()

    assert x1.grad == 0.5
    assert x2.grad == -0.25

    # div var and pyint (Var on the left)
    x1, x2 = ad.Var(1), 2
    x3 = x1 / x2
    x3.backward()

    assert x1.grad == 0.5

    # div pyint and var (reflected operand order: int / Var)
    x1, x2 = 1, ad.Var(2)
    x3 = x1 / x2
    x3.backward()

    assert x2.grad == -0.25
# Example #7
0
def test_mul():
    """Check multiplication gradients for Var*Var and mixed Var/int operands.

    For z = a * b: dz/da = b and dz/db = a.
    """
    # mul two vars
    x1, x2 = ad.Var(2), ad.Var(3)
    x3 = x1 * x2
    x3.backward()

    assert x1.grad == 3
    assert x2.grad == 2

    # mul var and pyint (Var on the left)
    x1, x2 = ad.Var(2), 3
    x3 = x1 * x2
    x3.backward()

    assert x1.grad == 3

    # mul pyint and var (reflected operand order: int * Var)
    x1, x2 = 2, ad.Var(3)
    x3 = x1 * x2
    x3.backward()

    assert x2.grad == 2
# Example #8
0
def test_sub():
    """Subtraction gives +1 to the minuend and -1 to the subtrahend.

    Covers Var - Var, Var - int, and int - Var (reflected operand order).
    """
    # Var - Var
    minuend, subtrahend = ad.Var(1), ad.Var(2)
    diff = minuend - subtrahend
    diff.backward()
    assert (minuend.grad, subtrahend.grad) == (1, -1)

    # Var - int
    minuend = ad.Var(1)
    diff = minuend - 2
    diff.backward()
    assert minuend.grad == 1

    # int - Var
    subtrahend = ad.Var(2)
    diff = 1 - subtrahend
    diff.backward()
    assert subtrahend.grad == -1
# Example #9
0
def test_reciprocal():
    """d(1/x)/dx = -1/x**2, which is -0.25 at x = 2."""
    base = ad.Var(2)
    inverse = op.reciprocal(base)
    inverse.backward()
    assert base.grad == -0.25
# Example #10
0
def test_op():
    """Arithmetic involving a Var yields a Var holding the expected value.

    Covers the forward and reflected forms of +, -, *, /, division by a
    zero-valued Var, and unary negation.
    """
    x1, x2 = ad.Var(1.), ad.Var(2)
    y = 3.

    # (deferred operation, expected value). The lambdas delay evaluation so
    # each operation runs in list order, immediately before it is checked.
    binary_cases = [
        # add, radd
        (lambda: x1 + x2, 3),
        (lambda: x1 + y, 4),
        (lambda: y + x1, 4),
        # sub, rsub
        (lambda: x1 - x2, -1),
        (lambda: x1 - y, -2),
        (lambda: y - x1, 2),
        # mul, rmul
        (lambda: x1 * x2, 2),
        (lambda: x1 * y, 3),
        (lambda: y * x1, 3),
        # truediv, rtruediv
        (lambda: x1 / x2, 0.5),
        (lambda: x1 / y, 1 / 3),
        (lambda: y / x1, 3),
    ]
    for compute, expected in binary_cases:
        result = compute()
        assert isinstance(result, ad.Var)
        assert result == ad.Var(expected)

    with pytest.raises(ZeroDivisionError):
        _ = x1 / ad.Var(0)

    # neg
    negated = -x1
    assert isinstance(negated, ad.Var)
    assert negated == ad.Var(-1)