Example #1
0
def test_nonconstant_s0_raises(s0, dropout: float, out_constant: bool):
    """Check when ``gru`` rejects the initial hidden state ``s0``.

    ``gru`` must raise ``ValueError`` only when ``out_constant`` is False
    and ``s0`` is neither ``None``, a plain ``np.ndarray``, nor an object
    whose ``constant`` attribute is truthy. Every other combination must
    complete without raising.
    """
    T, N, C, D = 5, 1, 3, 2
    X = Tensor(np.random.rand(T, N, C))
    # Three (D, D), three (C, D) and three (D,) parameter tensors, drawn in
    # this fixed order so the random stream is consumed the same way.
    Wz, Wr, Wh = Tensor(np.random.rand(3, D, D))
    Uz, Ur, Uh = Tensor(np.random.rand(3, C, D))
    bz, br, bh = Tensor(np.random.rand(3, D))

    # Short-circuit order matters: ``s0.constant`` is only looked up once
    # s0 is known to be neither None nor an ndarray (and only if the output
    # is non-constant).
    s0_accepted = (
        out_constant
        or s0 is None
        or isinstance(s0, np.ndarray)
        or s0.constant
    )
    if s0_accepted:
        expectation = does_not_raise()
    else:
        expectation = pytest.raises(ValueError)

    params = (Uz, Wz, bz, Ur, Wr, br, Uh, Wh, bh)
    with expectation:
        gru(X, *params, s0=s0, dropout=dropout, constant=out_constant)
def test_backpropping_non_numeric_gradient_raises(constant: bool,
                                                  arr: np.ndarray,
                                                  op_before: bool,
                                                  op_after: bool):
    """Backprop through ``old_op`` raises ``InvalidGradient`` unless the
    input tensor is constant, in which case nothing is raised.

    ``op_before`` optionally applies ``+= 1`` before ``old_op``, and
    ``op_after`` optionally applies ``* 2`` after it.

    NOTE(review): a function further down this file is defined with this
    exact name. If both end up in one module, only the later definition is
    bound at module level (and collected by pytest), so this one is dead
    code. Confirm which version is meant to be kept.
    """
    x = Tensor(arr, constant=constant)

    if op_before:
        x += 1

    # presumably old_op's backward pass yields a non-numeric gradient —
    # confirm against its definition
    x = old_op(x)

    if op_after:
        x = x * 2

    # if constant tensor, backprop should not be triggered - no exception raised
    with (pytest.raises(InvalidGradient)
          if not constant else does_not_raise()):
        x.backward()
def test_backpropping_non_numeric_gradient_raises(constant: bool,
                                                  arr: np.ndarray,
                                                  op_before: bool,
                                                  op_after: bool):
    """Backprop through ``old_op`` raises ``InvalidGradient`` mentioning
    ``NoneType`` unless the input tensor is constant.

    ``op_before`` optionally applies ``+= 1`` before ``old_op``, and
    ``op_after`` optionally applies ``* 2`` after it. For a constant input,
    ``backward`` must complete without raising.
    """
    x = Tensor(arr, constant=constant)

    if op_before:
        x += 1

    # presumably old_op's backward pass returns None as its gradient, which
    # is what the "NoneType" message check below relies on — confirm
    x = old_op(x)

    if op_after:
        x = x * 2

    if constant:
        # if constant tensor, backprop should not be triggered - no
        # exception raised
        with does_not_raise():
            x.backward()
        return

    # Branch explicitly on ``constant`` rather than on whatever value the
    # no-raise context manager yields from ``__enter__``.
    with pytest.raises(InvalidGradient) as excinfo:
        x.backward()

    assert "NoneType" in str(excinfo.value)
Example #4
0
def test_invalid_gradient_raises(constant: bool):
    """Seeding ``backward`` with the string ``"bad"`` must raise
    ``InvalidGradient`` for a non-constant tensor; for a constant tensor no
    exception is expected.
    """
    doubled = Tensor(3, constant=constant) * 2

    if constant:
        expectation = does_not_raise()
    else:
        expectation = pytest.raises(InvalidGradient)

    with expectation:
        doubled.backward("bad")