Ejemplo n.º 1
0
def test_print_op_bprop():
    """
    Check the derivative of PrintOp with respect to its input.

    The intent is that backprop through PrintOp hands the incoming delta
    through unchanged; the check itself is delegated to check_derivative.
    """
    # Single axis of length 10 for the placeholder input.
    axis = ng.make_axis(length=10)
    axes = ng.make_axes([axis])
    x = ng.placeholder(axes)

    # Random starting point drawn from [-1, 1) over the placeholder's axes.
    x_value = rng.uniform(-1, 1, x.axes)

    printed = ng.PrintOp(x)
    check_derivative(printed, x, 0.001, x_value, atol=1e-3, rtol=1e-3)
Ejemplo n.º 2
0
def test_print_op_fprop(capfd):
    """
    Check the forward pass of PrintOp.

    The computed result must match the input value, and both the value and
    the supplied prefix must show up in captured stdout.
    """
    prefix = 'prefix'

    axes = ng.make_axes([ng.make_axis(length=1)])
    x = ng.placeholder(axes)

    # Use a hardcoded integer value so there is no rounding to worry about
    # in the string comparison of the final assertions.
    x_value = np.array([1])

    print_op = ng.PrintOp(x, prefix)
    with executor(print_op, x) as compute:
        result = compute(x_value)

    ng.testing.assert_allclose(result, x_value)

    # Only stdout is inspected; stderr is read but ignored.
    out, _ = capfd.readouterr()
    expected_text = str(x_value[0])
    assert expected_text in out
    assert prefix in out