def test_ipow_banned():
    """In-place exponentiation (``**=``) on a Var must raise TypeError."""
    base = Var([1, 2, 3])
    with pytest.raises(TypeError):
        base **= 3
def test_rdiv_vars():
    """number / Var: 6 / 3 == 2, and d(6/x)/dx = -6/x**2 == -2/3 at x == 3."""
    denominator = Var([3, 3, 3])
    quotient = 6 / denominator
    quotient.backward()
    expected_grad = np.array([-2 / 3, -2 / 3, -2 / 3])
    assert np.all(quotient == Var([2, 2, 2]))
    assert np.all(quotient.grad(denominator) == expected_grad)
def test_rdiv_var_fails_with_divide_by_zero():
    """number / Var with a zero element must raise a division error."""
    # Either exception type is accepted.
    division_errors = (ZeroDivisionError, FloatingPointError)
    with pytest.raises(division_errors):
        1 / Var([0, 2, 3])
def test_mul_var_number():
    """Var * 5 scales each element by 5; the gradient is 5 for every element."""
    source = Var([3, 2, 1])
    scaled = source * 5
    scaled.backward()
    assert np.all(scaled == Var([15, 10, 5]))
    assert np.all(scaled.grad(source) == np.array([5, 5, 5]))
def test_div_var_number():
    """Var / 5 divides each element by 5; the gradient is 1/5 for every element."""
    numerator = Var([10, 10])
    ratio = numerator / 5
    ratio.backward()
    assert np.all(ratio == Var([2, 2]))
    assert np.all(ratio.grad(numerator) == np.array([0.2, 0.2]))
def test_cosh():
    """cosh(0) == 1, and d/dx cosh(x) = sinh(x) == 0 at x == 0."""
    x = Var(0.)
    out = ops.cosh(x)
    out.backward()
    assert out.val == pytest.approx(1)
    assert out.grad(x) == np.array([0])
def test_atanh():
    """arctanh(0) == 0, and d/dx arctanh(x) = 1/(1 - x**2) == 1 at x == 0."""
    x = Var(0.)
    out = ops.arctanh(x)
    out.backward()
    assert out.val == pytest.approx(0)
    assert out.grad(x) == np.array([1])
def test_rsub_var():
    """number - Var: 6 - [1, 2, 3] == [5, 4, 3], with gradient -1 per element."""
    operand = Var([1, 2, 3])
    difference = 6 - operand
    # NOTE(review): gradient is read after operand.forward() (not backward()) —
    # presumably forward-mode propagation; confirm against Var's API.
    operand.forward()
    assert np.all(difference == Var([5, 4, 3]))
    assert np.all(difference.grad(operand) == np.array([-1, -1, -1]))
def test_exp():
    """exp(0) == 1, and d/dx exp(x) = exp(x) == 1 at x == 0."""
    x = Var(0)
    out = ops.exp(x)
    out.backward()
    assert out.val == pytest.approx(1)
    assert out.grad(x) == np.array([1])
def test_radd_var():
    """number + Var: 6 + [1, 2, 3] == [7, 8, 9], with gradient 1 per element."""
    operand = Var([1, 2, 3])
    total = 6 + operand
    operand.forward()
    assert np.all(total == Var([7, 8, 9]))
    assert np.all(total.grad(operand) == np.array([1, 1, 1]))
def test_sub_var_number():
    """Var - number: [1, 2, 3] - 5 == [-4, -3, -2], with gradient 1 per element."""
    operand = Var([1, 2, 3])
    shifted = operand - 5
    operand.forward()
    assert np.all(shifted == Var([-4, -3, -2]))
    assert np.all(shifted.grad(operand) == np.array([1, 1, 1]))
def test_init_var():
    """A Var exposes the values it was constructed from via ``.val``."""
    instance = Var([1, 2, 3])
    expected = [1, 2, 3]
    assert np.all(instance.val == expected)
def test_add_var_number():
    """Var + number: [1, 2, 3] + 5 == [6, 7, 8], with gradient 1 per element."""
    operand = Var([1, 2, 3])
    shifted = operand + 5
    operand.forward()
    assert np.all(shifted == Var([6, 7, 8]))
    assert np.all(shifted.grad(operand) == np.array([1, 1, 1]))
def test_abs_var():
    """Built-in abs() on a vector Var: values are |x|, gradient is sign(x).

    Named ``test_abs_var`` (not ``test_abs``) because the scalar ``ops.abs``
    test in this module is also called ``test_abs``. A later ``def`` with the
    same name rebinds it, so pytest would silently never collect this test.
    """
    x = Var([1, 2, -3])
    y = abs(x)
    x.forward()
    assert np.all(y == Var([1, 2, 3]))
    assert np.all(y.grad(x) == np.array([1, 1, -1]))
def test_acos():
    """arccos(0) == pi/2, and d/dx arccos(x) = -1/sqrt(1 - x**2) == -1 at x == 0."""
    var1 = Var(0.)
    var2 = ops.arccos(var1)
    var2.backward()
    # Compare against the exact constant so pytest's default tolerance applies.
    # A truncated literal needs a loose tolerance (abs=1e-2) that would also
    # accept wrong results such as 1.58.
    assert var2.val == pytest.approx(np.pi / 2)
    assert var2.grad(var1) == np.array([-1])
def test_log():
    """log(1) == 0, and d/dx log(x) = 1/x == 1 at x == 1."""
    x = Var(1.)
    out = ops.log(x)
    out.backward()
    assert out.val == pytest.approx(0)
    assert out.grad(x) == np.array([1])
def test_sin():
    """sin(pi) ~= 0, and d/dx sin(x) = cos(x) == -1 at x == pi."""
    x = Var(np.pi)
    out = ops.sin(x)
    out.backward()
    assert out.val == pytest.approx(0)
    assert out.grad(x) == np.array([-1])
def test_logistic():
    """logistic(0) == 0.5, and its derivative equals s * (1 - s) for s = logistic(x)."""
    x = Var(0)
    sigma = ops.logistic(x)
    sigma.backward()
    assert sigma.val == pytest.approx(0.5)
    assert sigma.grad(x) == sigma.val * (1 - sigma.val)
def test_acosh():
    """arccosh(2) matches numpy, and d/dx arccosh(x) = 1/sqrt(x**2 - 1) == 1/sqrt(3) at x == 2."""
    var1 = Var(2.)
    var2 = ops.arccosh(var1)
    var2.backward()
    # Transcendental results are compared with pytest.approx (as in the other
    # scalar-op tests) rather than exact float equality, which depends on the
    # implementation performing bit-identical operations.
    assert var2.val == pytest.approx(np.arccosh(2.))
    assert var2.grad(var1) == pytest.approx(1 / np.sqrt(3))
def test_sqrt():
    """sqrt(4) == 2, and d/dx sqrt(x) = 1/(2*sqrt(x)) == 0.25 at x == 4."""
    radicand = Var(4)
    root = ops.sqrt(radicand)
    root.backward()
    assert root.val == pytest.approx(2)
    assert root.grad(radicand) == .5 * 1 / root.val
def test_invalid_arg_raises_error():
    """Var.grad() given a non-Var argument (here a str) raises TypeError."""
    subject = Var([1, 2, 3])
    with pytest.raises(TypeError):
        subject.grad('invalid value')
def test_neg():
    """ops.neg(1) == -1, and the gradient of negation is -1."""
    x = Var(1)
    negated = ops.neg(x)
    negated.backward()
    assert negated.val == pytest.approx(-1)
    assert negated.grad(x) == -1
def test_rmul_vars():
    """number * Var: 6 * [3, 3, 3] == [18, 18, 18], with gradient 6 per element."""
    operand = Var([3, 3, 3])
    product = 6 * operand
    product.backward()
    assert np.all(product == Var([18, 18, 18]))
    assert np.all(product.grad(operand) == np.array([6, 6, 6]))
def test_cos():
    """cos(pi) == -1, and d/dx cos(x) = -sin(x) ~= 0 at x == pi."""
    x = Var(np.pi)
    out = ops.cos(x)
    out.backward()
    assert out.val == pytest.approx(-1)
    # -sin(pi) is only approximately zero in floating point, hence approx.
    wrapped_grad = np.array([out.grad(x)])
    assert wrapped_grad == pytest.approx([0])
def test_div_var_number_fails_with_divide_by_zero():
    """Var / 0. must raise a division error."""
    # Either exception type is accepted.
    division_errors = (ZeroDivisionError, FloatingPointError)
    with pytest.raises(division_errors):
        Var([1, 2, 3]) / 0.
def test_abs():
    """ops.abs(-1) == 1, and the gradient of |x| at x == -1 is -1."""
    x = Var(-1)
    magnitude = ops.abs(x)
    magnitude.backward()
    assert magnitude.val == 1.
    assert magnitude.grad(x) == -1
def test_neg_var():
    """Unary minus on a vector Var: values are negated, gradient is -1 per element.

    Named ``test_neg_var`` (not ``test_neg``) because the scalar ``ops.neg``
    test in this module is also called ``test_neg``. Redefining the name would
    rebind it, so pytest would silently drop the earlier test.
    """
    x = Var([1, 2, 3])
    y = -x
    y.backward()
    assert np.all(y == Var([-1, -2, -3]))
    assert np.all(y.grad(x) == np.array([-1, -1, -1]))
def test_composite_logexp():
    """log(exp(x)) round-trips x, and its chain-rule gradient is ~1."""
    x = Var(5)
    y = ops.log(ops.exp(x))
    y.backward()
    assert np.all(x.val == pytest.approx(y.val))
    # The chain rule gives (1 / exp(x)) * exp(x). In floating point that
    # product need not round to exactly 1, so compare approximately.
    assert np.all(y.grad(x) == pytest.approx(1))
def test_div_var_non_number():
    """Dividing a Var by a str raises TypeError."""
    numerator = Var([1, 2, 3])
    with pytest.raises(TypeError):
        numerator / 'string'
def test_idiv_banned():
    """In-place division (``/=``) on a Var is rejected with TypeError."""
    target = Var([1, 2, 3])
    with pytest.raises(TypeError):
        target /= 3