def test_numpy_batch_ch(data_format):
    """Compare batched multi-channel TF transform with per-slice numpy transform.

    A batch of 5 random 4-channel 100x100 inputs is laid out as
    ``(N, H, W, C)`` when ``data_format == "nhwc"`` and as ``(N, C, H, W)``
    otherwise. ``Transform2d.forward_channels`` is run once on the whole
    batch. Every (image, channel) slice of its lowpass, highpasses and
    scales is then checked against ``Transform2d_np.forward`` applied to the
    matching 2-D input slice.
    """
    batch, channels, size = 5, 4, 100
    if data_format == "nhwc":
        X = np.random.randn(batch, size, size, channels)
    else:
        X = np.random.randn(batch, channels, size, size)
    f = Transform2d()
    p = f.forward_channels(X, data_format=data_format, include_scale=True)
    f1 = Transform2d_np()

    for i in range(batch):
        for j in range(channels):
            # The same index tuple selects image i, channel j from the input
            # and from every output array, whichever layout is in use.
            if data_format == "nhwc":
                idx = (i, slice(None), slice(None), j)
            else:
                idx = (i, j)
            p1 = f1.forward(X[idx], include_scale=True)
            np.testing.assert_array_almost_equal(
                p.lowpass[idx], p1.lowpass, decimal=PRECISION_DECIMAL)
            # zip() stops at the shorter sequence, so check level counts
            # explicitly; otherwise a missing level would go unnoticed.
            assert len(p.highpasses) == len(p1.highpasses)
            for x, y in zip(p.highpasses, p1.highpasses):
                np.testing.assert_array_almost_equal(
                    x[idx], y, decimal=PRECISION_DECIMAL)
            assert len(p.scales) == len(p1.scales)
            for x, y in zip(p.scales, p1.scales):
                np.testing.assert_array_almost_equal(
                    x[idx], y, decimal=PRECISION_DECIMAL)
def test_results_match_invmask(biort, qshift, gain_mask):
    """Masked inverse: numpy and TF reconstructions of ``mandrill`` agree.

    Each implementation is built with the given ``biort``/``qshift`` filters.
    It takes the image forward to 4 levels (scales included) and then
    inverts that pyramid with ``gain_mask``. The two reconstructions must
    match to ``PRECISION_DECIMAL`` places.
    """
    def _masked_roundtrip(transform_cls):
        # Forward to 4 levels, then invert with the gain mask applied.
        xfm = transform_cls(biort=biort, qshift=qshift)
        pyramid = xfm.forward(mandrill, nlevels=4, include_scale=True)
        return xfm.inverse(pyramid, gain_mask)

    expected = _masked_roundtrip(Transform2d_np)
    actual = _masked_roundtrip(Transform2d)
    np.testing.assert_array_almost_equal(
        expected, actual, decimal=PRECISION_DECIMAL)
def test_results_match_endtoend(test_input, biort, qshift):
    """Compare the TF graph round trip, fed via a placeholder, with numpy.

    The numpy implementation provides the reference reconstruction after
    a 4-level forward/inverse round trip (scales included). The TF
    implementation builds the same round trip on a float32 placeholder
    shaped like the input's first two dimensions. That graph is evaluated
    in a session with ``test_input`` fed in.
    """
    image = test_input
    rows, cols = image.shape[0], image.shape[1]

    # Reference reconstruction from the numpy implementation.
    ref_xfm = Transform2d_np(biort=biort, qshift=qshift)
    ref_pyramid = ref_xfm.forward(image, nlevels=4, include_scale=True)
    expected = ref_xfm.inverse(ref_pyramid)

    # Build the TF round-trip graph on a placeholder, then run it.
    image_ph = tf.placeholder(tf.float32, [rows, cols])
    tf_xfm = Transform2d(biort=biort, qshift=qshift)
    tf_pyramid = tf_xfm.forward(image_ph, nlevels=4, include_scale=True)
    recon_op = tf_xfm.inverse(tf_pyramid)
    with tf.Session() as sess:
        actual = sess.run(recon_op, feed_dict={image_ph: image})

    np.testing.assert_array_almost_equal(
        expected, actual, decimal=PRECISION_DECIMAL)
def test_results_match_inverse(test_input, biort, qshift):
    """Inverse transform: numpy and TF reconstructions of ``test_input`` agree.

    Both implementations, built with the given ``biort``/``qshift``
    filters, take the same image forward to 4 levels (scales included)
    and back. The two reconstructions must match to ``PRECISION_DECIMAL``
    places.
    """
    im = test_input
    f_np = Transform2d_np(biort=biort, qshift=qshift)
    p_np = f_np.forward(im, nlevels=4, include_scale=True)
    X_np = f_np.inverse(p_np)

    # Feed the TF transform the same image, so both inverses are
    # reconstructing the same signal from equivalent pyramids.
    f_tf = Transform2d(biort=biort, qshift=qshift)
    p_tf = f_tf.forward(im, nlevels=4, include_scale=True)
    # Invert the TF pyramid; its result is compared directly with the
    # numpy reconstruction below.
    X_tf = f_tf.inverse(p_tf)

    np.testing.assert_array_almost_equal(
        X_np, X_tf, decimal=PRECISION_DECIMAL)
def test_results_match(test_input, biort, qshift):
    """
    Compare forward transform with numpy forward transform for the test
    input image: lowpass, every highpass level and every scale.
    """
    im = test_input
    f_np = Transform2d_np(biort=biort, qshift=qshift)
    p_np = f_np.forward(im, include_scale=True)

    f_tf = Transform2d(biort=biort, qshift=qshift)
    p_tf = f_tf.forward(im, include_scale=True)

    np.testing.assert_array_almost_equal(
        p_np.lowpass, p_tf.lowpass, decimal=PRECISION_DECIMAL)

    # zip() stops at the shorter sequence, so check level counts first;
    # otherwise a missing level would silently go unchecked.
    assert len(p_np.highpasses) == len(p_tf.highpasses)
    for h_np, h_tf in zip(p_np.highpasses, p_tf.highpasses):
        np.testing.assert_array_almost_equal(
            h_np, h_tf, decimal=PRECISION_DECIMAL)

    assert len(p_np.scales) == len(p_tf.scales)
    for s_np, s_tf in zip(p_np.scales, p_tf.scales):
        np.testing.assert_array_almost_equal(
            s_np, s_tf, decimal=PRECISION_DECIMAL)
def test_numpy_in():
    """Check that Transform2d accepts numpy input and matches Transform2d_np.

    The check runs twice, each time on a fresh random 100x100 input:

    * First with the default forward call, comparing lowpass and highpasses.
    * Then with ``include_scale=True``, which also compares the scales.
    """
    f = Transform2d()
    f1 = Transform2d_np()
    for include_scale in (False, True):
        # The first pass deliberately leaves include_scale at its default.
        kwargs = {'include_scale': True} if include_scale else {}
        X = np.random.randn(100, 100)
        p = f.forward(X, **kwargs)
        p1 = f1.forward(X, **kwargs)

        np.testing.assert_array_almost_equal(
            p.lowpass, p1.lowpass, decimal=PRECISION_DECIMAL)
        # zip() stops at the shorter sequence, so check level counts first.
        assert len(p.highpasses) == len(p1.highpasses)
        for x, y in zip(p.highpasses, p1.highpasses):
            np.testing.assert_array_almost_equal(
                x, y, decimal=PRECISION_DECIMAL)

        if include_scale:
            assert len(p.scales) == len(p1.scales)
            for x, y in zip(p.scales, p1.scales):
                np.testing.assert_array_almost_equal(
                    x, y, decimal=PRECISION_DECIMAL)
def __init__(self, biort='near_sym_a', qshift='qshift_a'):
    """Create the wrapped numpy transform and store it as ``self.xfm``.

    Args:
        biort: filter identifier forwarded to ``Transform2d_np`` as
            ``biort`` (default ``'near_sym_a'``).
        qshift: filter identifier forwarded to ``Transform2d_np`` as
            ``qshift`` (default ``'qshift_a'``).
    """
    self.xfm = Transform2d_np(biort=biort, qshift=qshift)