Beispiel #1
0
def test_backpropping_non_numeric_gradient_raises(constant: bool,
                                                  arr: np.ndarray,
                                                  op_before: bool,
                                                  op_after: bool):
    """Backprop through ``old_op`` raises ``InvalidGradient`` unless constant.

    Builds a small graph on a tensor created from ``arr``:
    an optional ``x += 1``, then ``old_op(x)``, then an optional ``x * 2``.
    It then calls ``backward()`` on the result.

    - Non-constant tensor: ``InvalidGradient`` is expected.
    - Constant tensor: backprop is not triggered, so no exception is expected.

    Presumably ``old_op`` produces a non-numeric gradient (per the test
    name) — TODO confirm against its definition.

    NOTE(review): another function with this exact name follows in this
    source. If both live in the same module, the later definition shadows
    this one and pytest only collects the later one. Confirm, and rename
    one of them if so.

    Parameters
    ----------
    constant : bool
        Whether the input tensor is created with ``constant=True``.
    arr : np.ndarray
        Data used to build the input tensor.
    op_before : bool
        If True, apply ``x += 1`` before ``old_op``.
    op_after : bool
        If True, apply ``x * 2`` after ``old_op``.

    The parameters are presumably supplied by parametrization decorators not
    visible here — verify.
    """
    x = Tensor(arr, constant=constant)

    if op_before:
        x += 1

    x = old_op(x)

    if op_after:
        x = x * 2

    # if constant tensor, backprop should not be triggered - no exception raised
    with (pytest.raises(InvalidGradient)
          if not constant else does_not_raise()):
        x.backward()
def test_backpropping_non_numeric_gradient_raises(constant: bool,
                                                  arr: np.ndarray,
                                                  op_before: bool,
                                                  op_after: bool):
    """Backprop through ``old_op`` raises ``InvalidGradient`` mentioning NoneType.

    Builds a small graph on a tensor created from ``arr``:
    an optional ``x += 1``, then ``old_op(x)``, then an optional ``x * 2``.
    It then calls ``backward()`` on the result.

    - Non-constant tensor: ``InvalidGradient`` must be raised, and its
      message must contain ``"NoneType"``.
    - Constant tensor: backprop is not triggered, so no exception is expected
      and the message is not inspected.

    Parameters
    ----------
    constant : bool
        Whether the input tensor is created with ``constant=True``.
    arr : np.ndarray
        Data used to build the input tensor.
    op_before : bool
        If True, apply ``x += 1`` before ``old_op``.
    op_after : bool
        If True, apply ``x * 2`` after ``old_op``.

    The parameters are presumably supplied by parametrization decorators not
    visible here — verify.
    """
    x = Tensor(arr, constant=constant)

    if op_before:
        x += 1

    x = old_op(x)

    if op_after:
        x = x * 2

    # if constant tensor, backprop should not be triggered - no exception raised
    with (pytest.raises(InvalidGradient)
          if not constant else does_not_raise()) as exc_info:
        x.backward()

    # Branch on the same flag that chose the context manager, rather than on
    # whatever `does_not_raise()` happens to yield from `__enter__`.
    # Inspect the exception only after the `with` block has exited.
    if not constant:
        err_msg = str(exc_info.value)
        assert "NoneType" in err_msg
Beispiel #3
0
def test_invalid_gradient_raises(constant: bool):
    """Passing a non-numeric seed gradient to ``backward`` is rejected.

    Calls ``backward("bad")`` on ``Tensor(3) * 2``.

    - Non-constant tensor: ``InvalidGradient`` is expected.
    - Constant tensor: no exception is expected.

    ``constant`` is presumably supplied by a parametrization decorator not
    visible here — verify.
    """
    doubled = Tensor(3, constant=constant) * 2
    expectation = (does_not_raise() if constant
                   else pytest.raises(InvalidGradient))
    with expectation:
        doubled.backward("bad")