def test_print_op_bprop():
    """
    Verify the backward pass of PrintOp: the op is expected to hand the
    incoming delta straight through, which is checked here by running
    check_derivative on a PrintOp wrapped around a placeholder.
    """
    # a single axis of length 10 is all the placeholder needs
    axes = ng.make_axes([ng.make_axis(length=10)])
    x = ng.placeholder(axes)

    # pick a random starting point in [-1, 1) for the derivative check
    x_value = rng.uniform(-1, 1, x.axes)

    print_op = ng.PrintOp(x)
    check_derivative(print_op, x, 0.001, x_value, atol=1e-3, rtol=1e-3)
def test_print_op_fprop(capfd):
    """
    Verify the forward pass of PrintOp: the value coming out must match the
    value going in, and both the value and the supplied prefix must show up
    in the captured stdout.
    """
    axes = ng.make_axes([ng.make_axis(length=1)])
    x = ng.placeholder(axes)

    # use a fixed integer value so that the string comparison in the final
    # asserts is not affected by any rounding
    x_value = np.array([1])

    print_op = ng.PrintOp(x, 'prefix')
    with executor(print_op, x) as run_op:
        result = run_op(x_value)

    # fprop must leave the input untouched
    ng.testing.assert_allclose(result, x_value)

    # fprop must also have printed the value together with the prefix
    out, err = capfd.readouterr()
    expected_value_text = str(x_value[0])
    assert expected_value_text in out
    assert 'prefix' in out