def test_backpropping_non_numeric_gradient_raises(
    constant: bool, arr: np.ndarray, op_before: bool, op_after: bool
):
    """Check that backprop through ``old_op`` fails for a non-constant tensor.

    NOTE(review): another function with exactly this name is defined later in
    this file. The later ``def`` rebinds the module-level name, so pytest only
    collects that later definition and this one never runs. Consider removing
    this copy, or renaming it together with whatever decorators supply its
    arguments (not visible here).

    Parameters
    ----------
    constant : bool
        Forwarded to ``Tensor``; when True, ``backward()`` is expected not to
        raise.
    arr : np.ndarray
        Data used to construct the input tensor.
    op_before : bool
        If True, add 1 to the tensor before applying ``old_op``.
    op_after : bool
        If True, multiply the output of ``old_op`` by 2 before backprop.
    """
    x = Tensor(arr, constant=constant)
    if op_before:
        x += 1
    x = old_op(x)
    if op_after:
        x = x * 2
    # If the tensor is constant, backprop should not be triggered, so no
    # exception is raised.
    with (pytest.raises(InvalidGradient) if not constant else does_not_raise()):
        x.backward()
def test_backpropping_non_numeric_gradient_raises(
    constant: bool, arr: np.ndarray, op_before: bool, op_after: bool
):
    """Backprop through ``old_op`` must reject a non-numeric gradient.

    Builds a tensor from ``arr`` and optionally adds 1 before ``old_op``. It
    optionally multiplies the result by 2 afterwards, then calls
    ``backward()``. For a non-constant tensor this must raise
    ``InvalidGradient``, and the error message must mention ``NoneType``
    (presumably because ``old_op``'s backward produces a ``None`` gradient;
    confirm against its definition). For a constant tensor, backprop should
    not be triggered, so no exception is expected.

    Parameters
    ----------
    constant : bool
        Forwarded to ``Tensor``; selects whether an exception is expected.
    arr : np.ndarray
        Data used to construct the input tensor.
    op_before : bool
        If True, add 1 to the tensor before applying ``old_op``.
    op_after : bool
        If True, multiply the output of ``old_op`` by 2 before backprop.
    """
    x = Tensor(arr, constant=constant)
    if op_before:
        x += 1
    x = old_op(x)
    if op_after:
        x = x * 2

    # If the tensor is constant, backprop should not be triggered, so no
    # exception is raised.
    with (pytest.raises(InvalidGradient) if not constant else does_not_raise()) as exc_info:
        x.backward()

    # Gate the message check on ``constant``, the same condition that chose
    # the context manager. Checking ``exc_info is not None`` only works if
    # ``does_not_raise()`` happens to yield ``None``. The check sits outside
    # the ``with`` block because code after the raising call inside it would
    # never execute.
    if not constant:
        err_msg = str(exc_info.value)
        assert "NoneType" in err_msg, f"unexpected InvalidGradient message: {err_msg!r}"
def test_invalid_gradient_raises(constant: bool):
    """Seeding ``backward`` with a non-numeric gradient (``"bad"``).

    A non-constant tensor must raise ``InvalidGradient``. For a constant
    tensor, the call is expected to complete without raising.

    Parameters
    ----------
    constant : bool
        Forwarded to ``Tensor``; selects whether an exception is expected.
    """
    doubled = Tensor(3, constant=constant) * 2
    if constant:
        expectation = does_not_raise()
    else:
        expectation = pytest.raises(InvalidGradient)
    with expectation:
        doubled.backward("bad")