def test_floordivision():
    """Floor division of nodes/constants yields a Node with zero derivative."""
    expected = Node(2, 0)

    # Node // Node, first through the helper function, then the operator.
    assert floordiv(Node(8, 8), Node(3, 3)) == expected
    assert Node(8, 8) // Node(3, 3) == expected

    # Plain numbers passed to the helper are wrapped into a Node.
    assert floordiv(11, 3) == Node(3, 0)

    # Node // constant: the helper and the operator must agree.
    assert floordiv(Node(7.5, 3.2), 2) == Node(7.5, 3.2) // 2 == Node(3, 0)
def test_multiplication():
    """Multiplication applies the product rule to the derivative part."""
    product = Node(3, 10)

    # Node * Node: helper function and overloaded operator.
    assert multiplication(Node(1, 2), Node(3, 4)) == product
    assert Node(1, 2) * Node(3, 4) == product

    # Two plain numbers give a Node whose derivative is zero.
    assert multiplication(1, 2) == Node(2, 0)

    # Scaling a Node by a constant scales value and derivative alike.
    assert multiplication(Node(2, 2), 2) == Node(4, 4)
def test_addition():
    """Addition sums values and derivatives component-wise."""
    total = Node(4, 6)

    # Node + Node: helper function and overloaded operator.
    assert addition(Node(1, 2), Node(3, 4)) == total
    assert Node(1, 2) + Node(3, 4) == total

    # Two constants: the result carries no derivative.
    assert addition(1, 2) == Node(3, 0)

    # Adding a constant shifts the value but keeps the derivative.
    assert addition(Node(1, 2), 2) == Node(3, 2)
def test_subtraction():
    """Subtraction, expressed as adding a negated operand and via ``-``."""
    difference = Node(2, 2)

    # Node - Node as addition of a negation, then via the operator.
    assert addition(Node(3, 4), -Node(1, 2)) == difference
    assert Node(3, 4) - Node(1, 2) == difference

    # Two constants through the helper.
    assert addition(2, -1) == Node(1, 0)

    # Node minus a constant leaves the derivative untouched.
    assert addition(Node(3, 2), -2) == Node(1, 2)
def test_floor():
    """``floor`` rounds toward -inf and returns a Node with zero derivative."""
    node_cases = (
        (Node(3.3, 3), Node(3, 0)),
        (Node(3.9, 3), Node(3, 0)),
        (Node(-2.5, -2), Node(-3, 0)),
    )
    for given, expected in node_cases:
        assert floor(given) == expected

    # Bare numbers are accepted too.
    assert floor(3.3) == Node(3, 0)
    assert floor(-3.2) == Node(-4, 0)
def test_exp():
    """``exp`` maps value v to e**v and scales the derivative by e**v."""
    e_pos, e_neg = np.exp(2), np.exp(-2)

    # Node inputs: derivative follows the chain rule.
    assert exp(Node(2, 0)) == Node(e_pos, 0)
    assert exp(Node(2, 2)) == Node(e_pos, 2 * e_pos)
    assert exp(Node(-2, -2)) == Node(e_neg, -2 * e_neg)

    # Constant inputs produce a zero derivative.
    for x in (1, -1):
        assert exp(x) == Node(np.exp(x), 0)
def test_trunc():
    """``trunc`` rounds toward zero for both Nodes and plain numbers."""
    # Node inputs.
    for value, truncated in ((2.1, 2), (2.9, 2), (-1.5, -1)):
        assert trunc(Node(value, 0)) == Node(truncated, 0)

    # Constant inputs.
    for value, truncated in ((1.5, 1), (-1.5, -1)):
        assert trunc(value) == Node(truncated, 0)
def test_round():
    """``_round`` rounds the value, optionally to ``ndigits``; derivative -> 0."""
    # No digit count supplied.
    assert _round(Node(2.56, 5.55)) == Node(3, 0)
    assert _round(1.5) == Node(2, 0)

    # Explicit digit count (passed positionally, as callers do).
    with_digits = (
        (Node(2.56, 5.55), 1, Node(2.6, 0)),
        (1.53, 1, Node(1.5, 0)),
        (1.539, 2, Node(1.54, 0)),
    )
    for arg, ndigits, expected in with_digits:
        assert _round(arg, ndigits) == expected
def test_ceil():
    """``ceil`` rounds toward +inf and returns a Node with zero derivative."""
    # Node inputs.
    assert ceil(Node(4.1, 3)) == Node(5, 0)
    assert ceil(Node(-2.5, -2)) == Node(-2, 0)

    # Constant inputs.
    for value, ceiling in ((5.1, 6), (-3.5, -3)):
        assert ceil(value) == Node(ceiling, 0)
def test_str():
    """``str(Node)`` reports value, derivative and back-gradient."""
    rendered = str(Node(2, 1))
    assert rendered == 'Node object with value 2 and derivative 1 and back-gradient 0'
def test_le():
    """``<=`` holds for a strictly smaller value and for an equal value.

    The original body used ``<``, so ``__le__`` was never exercised by the
    test that bears its name.
    """
    assert Node(2, 0) <= Node(3, -1)
    # Boundary: equal values must satisfy <= (this is what separates it from <).
    assert Node(2, 0) <= Node(2, 0)
def test_lt():
    """``<`` holds for a strictly smaller value and fails when reversed.

    The original body used ``<=``, so ``__lt__`` was not tested here.
    """
    assert Node(2, 0) < Node(3, 0)
    assert not (Node(3, 0) < Node(2, 0))
def test_complex():
    """Converting a Node to ``complex`` is not supported."""
    unsupported = NotImplementedError
    with pytest.raises(unsupported):
        complex(Node(1, 3))
def test_trunc_node():
    """``trunc`` on a positive Node drops the fraction and the derivative.

    Named distinctly so it no longer redefines (and thereby hides from
    pytest) the other ``test_trunc`` function in this module.
    """
    assert trunc(Node(3.33, 3)) == Node(3, 0)
def test_floor_node():
    """``floor`` on a positive Node drops the fraction and the derivative.

    Named distinctly so it no longer redefines (and thereby hides from
    pytest) the other ``test_floor`` function in this module.
    """
    assert floor(Node(3.3, 3)) == Node(3, 0)
def test_invert():
    """Bitwise inversion ``~`` of a Node raises ValueError."""
    with pytest.raises(ValueError):
        _ = ~Node(2, 0)
def test_builtin_round():
    """The ``round`` builtin works on Nodes with int or Node ``ndigits``.

    Named distinctly so it no longer redefines (and thereby hides from
    pytest) the other ``test_round`` function in this module.
    """
    assert round(Node(2.2, 3.2), 0) == Node(2, 0)
    # ndigits may itself be given as a Node.
    assert round(Node(2.2, 3.2), Node(0, 0)) == Node(2, 0)
def test_rpow():
    """Raising a plain number to a Node power is not implemented."""
    with pytest.raises(NotImplementedError):
        _ = 2 ** Node(2, 3)
def test_ceil_node():
    """``ceil`` on a positive Node rounds up and zeroes the derivative.

    Named distinctly so it no longer redefines (and thereby hides from
    pytest) the other ``test_ceil`` function in this module.
    """
    assert ceil(Node(3.3, 3)) == Node(4, 0)
def test_sub():
    """Subtracting a constant from a Node keeps its derivative."""
    result = Node(1, 2) - 2
    assert result == Node(-1, 2)
def test_complex_input():
    """Constructing a Node from a complex value raises TypeError."""
    complex_value = complex(3)
    with pytest.raises(TypeError):
        Node(complex_value, 2)
def test_rsub():
    """Constant minus Node negates the derivative."""
    result = 2 - Node(1, 2)
    assert result == Node(1, -2)
def test_ne():
    """Nodes with different value/derivative pairs compare unequal."""
    left, right = Node(2, 3), Node(3, 2)
    assert left != right
def test_pos():
    """Unary plus leaves both the value and the derivative unchanged."""
    # The original printed the result first; that debug output is dropped.
    assert +Node(-1, -2) == Node(-1, -2)
def test_gt():
    """``>`` holds for a strictly larger value and fails when reversed.

    The original body used ``>=``, so ``__gt__`` was not tested here.
    """
    assert Node(3, 0) > Node(1, 0)
    assert not (Node(1, 0) > Node(3, 0))
def test_neg():
    """Unary minus negates both the value and the derivative."""
    negated = -Node(1, 2)
    assert negated == Node(-1, -2)
def test_ge():
    """``>=`` holds for a strictly larger value and for an equal value.

    The original body used ``>``, so ``__ge__`` was never exercised by the
    test that bears its name.
    """
    assert Node(2, 0) >= Node(1, 4)
    # Boundary: equal values must satisfy >= (this is what separates it from >).
    assert Node(2, 0) >= Node(2, 0)
def test_rfloordiv():
    """Constant // Node is supported and yields a zero derivative."""
    quotient = 8 // Node(3, 3)
    assert quotient == Node(2, 0)
def test_repr():
    """``repr(Node)`` looks like the constructor call."""
    assert repr(Node(2, -1)) == 'Node(2, -1)'
def test_abs():
    """``abs`` of a negative Node flips the sign of value and derivative."""
    magnitude = abs(Node(-1, 1))
    assert magnitude == Node(1, -1)