import numpy as np
import pytest
import torch
from contextlib import contextmanager

from pytorch_wavelets import DTCWTForward, DTCWTInverse

# NumPy reference implementation the tests compare against. The import path
# here is an assumption: it is the batched reference transform shipped
# alongside these tests, whose .forward returns (yl, yh[, ys]) tuples and
# whose .inverse returns an array.
from ref_transform import Transform2d_np

# Assumed module scaffolding (device and comparison precisions). The
# pytest.mark.parametrize grids below are likewise assumed stand-ins; the
# original grids were not preserved.
dev = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
PRECISION_FLOAT = 3
PRECISION_DOUBLE = 7


@contextmanager
def set_double_precision():
    # Temporarily make float64 the default torch dtype
    old_dtype = torch.get_default_dtype()
    try:
        torch.set_default_dtype(torch.float64)
        yield
    finally:
        torch.set_default_dtype(old_dtype)


@pytest.mark.parametrize("J, o_before_c", [
    (1, False), (2, False), (3, False), (2, True), (3, True)])
def test_fwd(J, o_before_c):
    # Forward transform matches the numpy reference. With o_before_c the
    # orientation axis comes before the channel axis, so swap them back
    # before comparing.
    X = 100 * np.random.randn(3, 5, 100, 100)
    Xt = torch.tensor(X, dtype=torch.get_default_dtype(), device=dev)
    xfm = DTCWTForward(J=J, o_before_c=o_before_c).to(dev)
    Yl, Yh = xfm(Xt)
    f1 = Transform2d_np()
    yl, yh = f1.forward(X, nlevels=J)
    np.testing.assert_array_almost_equal(Yl.cpu(), yl, decimal=PRECISION_FLOAT)
    for i in range(len(yh)):
        if o_before_c:
            np.testing.assert_array_almost_equal(
                Yh[i][..., 0].cpu().transpose(2, 1), yh[i].real,
                decimal=PRECISION_FLOAT)
            np.testing.assert_array_almost_equal(
                Yh[i][..., 1].cpu().transpose(2, 1), yh[i].imag,
                decimal=PRECISION_FLOAT)
        else:
            np.testing.assert_array_almost_equal(
                Yh[i][..., 0].cpu(), yh[i].real, decimal=PRECISION_FLOAT)
            np.testing.assert_array_almost_equal(
                Yh[i][..., 1].cpu(), yh[i].imag, decimal=PRECISION_FLOAT)
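
# A hedged orientation aid, not part of the original suite: the default
# DTCWTForward output layout these tests index into. Yl is the real lowpass
# and each Yh[j] stacks the six subband orientations on dim 2 (o_dim) and
# the real/imaginary pair on the last dim (ri_dim).
def _example_default_shapes():
    xfm = DTCWTForward(J=3).to(dev)
    x = torch.randn(1, 1, 64, 64, device=dev)
    Yl, Yh = xfm(x)
    print(Yl.shape)     # torch.Size([1, 1, 16, 16])
    print(Yh[0].shape)  # torch.Size([1, 1, 6, 32, 32, 2])
    print(Yh[2].shape)  # torch.Size([1, 1, 6, 8, 8, 2])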

@pytest.mark.parametrize("o_dim, ri_dim", [
    (2, -1), (1, -1), (2, 1), (3, 1)])
def test_fwd_ri_dim(o_dim, ri_dim):
    J = 3
    X = 100 * np.random.randn(3, 5, 100, 100)
    Xt = torch.tensor(X, dtype=torch.get_default_dtype(), device=dev)
    xfm = DTCWTForward(J=J, o_dim=o_dim, ri_dim=ri_dim).to(dev)
    Yl, Yh = xfm(Xt)
    f1 = Transform2d_np()
    yl, yh = f1.forward(X, nlevels=J)
    np.testing.assert_array_almost_equal(Yl.cpu(), yl, decimal=PRECISION_FLOAT)
    # Taking out the real/imaginary axis shifts the orientation axis down
    # by one if it sat after ri_dim
    if (ri_dim % 6) < o_dim:
        o_dim -= 1
    for i in range(len(yh)):
        ours_r = np.take(Yh[i].cpu().numpy(), 0, ri_dim)
        ours_i = np.take(Yh[i].cpu().numpy(), 1, ri_dim)
        for l in range(6):
            ours = np.take(ours_r, l, o_dim)
            np.testing.assert_array_almost_equal(
                ours, yh[i][:, :, l].real, decimal=PRECISION_FLOAT)
            ours = np.take(ours_i, l, o_dim)
            np.testing.assert_array_almost_equal(
                ours, yh[i][:, :, l].imag, decimal=PRECISION_FLOAT)
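
# Hedged illustration, not an original test, of what o_dim/ri_dim control:
# they choose which axes of each Yh[j] hold the 6 orientations and the
# real/imag pair. The shape shown assumes the orientation axis is inserted
# at o_dim, pushing the channel axis along by one.
def _example_o_ri_dims():
    xfm = DTCWTForward(J=1, o_dim=1, ri_dim=-1).to(dev)
    x = torch.randn(1, 1, 64, 64, device=dev)
    _, Yh = xfm(x)
    print(Yh[0].shape)  # expected: torch.Size([1, 6, 1, 32, 32, 2])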

@pytest.mark.parametrize("J, o_dim", [(1, 2), (2, 2), (3, 2), (2, 1)])
def test_inv_skip_hps(J, o_dim):
    # Randomly turn on/off the highpass inputs
    hps = np.random.binomial(size=J, n=1, p=0.5).astype('bool')
    Yl = 100 * np.random.randn(3, 5, 64, 64)
    Yhr = [[np.random.randn(3, 5, 2**j, 2**j) for l in range(6)]
           for j in range(4 + J, 4, -1)]
    Yhi = [[np.random.randn(3, 5, 2**j, 2**j) for l in range(6)]
           for j in range(4 + J, 4, -1)]
    Yh1 = [np.stack(r, axis=2) + 1j * np.stack(i, axis=2)
           for r, i in zip(Yhr, Yhi)]
    Yh2 = [np.stack((np.stack(r, axis=o_dim), np.stack(i, axis=o_dim)),
                    axis=-1) for r, i in zip(Yhr, Yhi)]
    Yh2 = [torch.tensor(yh, dtype=torch.float32, device=dev) for yh in Yh2]
    # Skipped levels are an empty tensor on the torch side and zeros on the
    # numpy side
    for j in range(J):
        if hps[j]:
            Yh2[j] = torch.tensor([])
            Yh1[j] = np.zeros_like(Yh1[j])

    ifm = DTCWTInverse(J=J, o_dim=o_dim).to(dev)
    X = ifm((torch.tensor(Yl, dtype=torch.float32, requires_grad=True,
                          device=dev), Yh2))

    # Also test giving None instead of an empty tensor
    for j in range(J):
        if hps[j]:
            Yh2[j] = None
    X2 = ifm((torch.tensor(Yl, dtype=torch.float32, device=dev), Yh2))

    f1 = Transform2d_np()
    x = f1.inverse(Yl, Yh1)
    np.testing.assert_array_almost_equal(
        X.detach().cpu(), x, decimal=PRECISION_FLOAT)
    np.testing.assert_array_almost_equal(
        X2.cpu(), x, decimal=PRECISION_FLOAT)

    # Test gradients are ok
    X.backward(torch.ones_like(X))

@pytest.mark.parametrize("J, o_before_c", [
    (1, False), (2, False), (3, False), (2, True), (3, True)])
def test_inv(J, o_before_c):
    Yl = 100 * np.random.randn(3, 5, 64, 64)
    Yhr = [np.random.randn(3, 5, 6, 2**j, 2**j) for j in range(4 + J, 4, -1)]
    Yhi = [np.random.randn(3, 5, 6, 2**j, 2**j) for j in range(4 + J, 4, -1)]
    Yh1 = [yhr + 1j * yhi for yhr, yhi in zip(Yhr, Yhi)]
    if o_before_c:
        # Move the orientation axis in front of the channel axis
        Yh2 = [torch.tensor(np.stack((yhr, yhi), axis=-1),
                            dtype=torch.float32, device=dev).transpose(1, 2)
               for yhr, yhi in zip(Yhr, Yhi)]
    else:
        Yh2 = [torch.tensor(np.stack((yhr, yhi), axis=-1),
                            dtype=torch.float32, device=dev)
               for yhr, yhi in zip(Yhr, Yhi)]
    ifm = DTCWTInverse(J=J, o_before_c=o_before_c).to(dev)
    X = ifm((torch.tensor(Yl, dtype=torch.float32, device=dev), Yh2))
    f1 = Transform2d_np()
    x = f1.inverse(Yl, Yh1)
    np.testing.assert_array_almost_equal(X.cpu(), x, decimal=PRECISION_FLOAT)

@pytest.mark.parametrize("J, o_before_c", [
    (1, False), (2, False), (3, False), (2, True), (3, True)])
def test_fwd_skip_hps(J, o_before_c):
    X = 100 * np.random.randn(3, 5, 100, 100)
    # Randomly turn on/off the highpass outputs
    hps = np.random.binomial(size=J, n=1, p=0.5).astype('bool')
    xfm = DTCWTForward(J=J, skip_hps=hps, o_before_c=o_before_c).to(dev)
    Yl, Yh = xfm(torch.tensor(X, dtype=torch.float32, device=dev))
    f1 = Transform2d_np()
    yl, yh = f1.forward(X, nlevels=J)
    np.testing.assert_array_almost_equal(Yl.cpu(), yl, decimal=PRECISION_FLOAT)
    for j in range(J):
        if hps[j]:
            # Skipped levels should come back as empty tensors
            assert Yh[j].shape == torch.Size([0])
        elif o_before_c:
            np.testing.assert_array_almost_equal(
                Yh[j][..., 0].cpu().transpose(2, 1), yh[j].real,
                decimal=PRECISION_FLOAT)
            np.testing.assert_array_almost_equal(
                Yh[j][..., 1].cpu().transpose(2, 1), yh[j].imag,
                decimal=PRECISION_FLOAT)
        else:
            np.testing.assert_array_almost_equal(
                Yh[j][..., 0].cpu(), yh[j].real, decimal=PRECISION_FLOAT)
            np.testing.assert_array_almost_equal(
                Yh[j][..., 1].cpu(), yh[j].imag, decimal=PRECISION_FLOAT)

@pytest.mark.parametrize("scales", [
    (True,), (True, True), (False, True), (True, True, True)])
def test_bwd_include_scale(scales):
    X = 100 * np.random.randn(3, 5, 100, 100)
    J = len(scales)
    xfm = DTCWTForward(J=J, include_scale=scales).to(dev)
    Ys, Yh = xfm(
        torch.tensor(X, dtype=torch.float32, requires_grad=True, device=dev))
    f1 = Transform2d_np()
    yl, yh, ys = f1.forward(X, nlevels=J, include_scale=True)
    # Check a gradient can be passed back through each returned scale.
    # Scales that were turned off come back as empty tensors with no graph
    # attached, so skip those.
    for s in Ys:
        if s.numel() > 0:
            s.backward(torch.ones_like(s), retain_graph=True)

@pytest.mark.parametrize("J, o_dim", [(1, 2), (2, 2), (3, 2), (2, 1)])
def test_inv_o_dim(J, o_dim):
    Yl = 100 * np.random.randn(3, 5, 64, 64)
    Yhr = [[np.random.randn(3, 5, 2**j, 2**j) for l in range(6)]
           for j in range(4 + J, 4, -1)]
    Yhi = [[np.random.randn(3, 5, 2**j, 2**j) for l in range(6)]
           for j in range(4 + J, 4, -1)]
    Yh1 = [np.stack(r, axis=2) + 1j * np.stack(i, axis=2)
           for r, i in zip(Yhr, Yhi)]
    Yh2 = [np.stack((np.stack(r, axis=o_dim), np.stack(i, axis=o_dim)),
                    axis=-1) for r, i in zip(Yhr, Yhi)]
    Yh2 = [torch.tensor(yh, dtype=torch.float32, device=dev) for yh in Yh2]
    ifm = DTCWTInverse(J=J, o_dim=o_dim).to(dev)
    X = ifm((torch.tensor(Yl, dtype=torch.float32, device=dev), Yh2))
    f1 = Transform2d_np()
    x = f1.inverse(Yl, Yh1)
    np.testing.assert_array_almost_equal(X.cpu(), x, decimal=PRECISION_FLOAT)

@pytest.mark.parametrize("scales", [
    (True,), (True, True), (False, True), (True, False, True)])
def test_fwd_include_scale(scales):
    X = 100 * np.random.randn(3, 5, 100, 100)
    # The scales flags select which intermediate lowpass outputs to return
    J = len(scales)
    xfm = DTCWTForward(J=J, include_scale=scales).to(dev)
    Ys, Yh = xfm(torch.tensor(X, dtype=torch.float32, device=dev))
    f1 = Transform2d_np()
    yl, yh, ys = f1.forward(X, nlevels=J, include_scale=True)
    for j in range(J):
        if not scales[j]:
            # Scales that were turned off come back as empty tensors
            assert Ys[j].shape == torch.Size([0])
        else:
            np.testing.assert_array_almost_equal(
                Ys[j].cpu(), ys[j], decimal=PRECISION_FLOAT)

@pytest.mark.parametrize("biort, qshift, size, J", [
    ('near_sym_a', 'qshift_a', (128, 128), 3),
    ('near_sym_b', 'qshift_b', (100, 100), 2),
    ('legall', 'qshift_06', (64, 64), 1)])
def test_end2end(biort, qshift, size, J):
    im = np.random.randn(5, 6, *size).astype('float32')
    imt = torch.tensor(im, dtype=torch.float32, requires_grad=True,
                       device=dev)
    xfm = DTCWTForward(J=J, biort=biort, qshift=qshift).to(dev)
    Yl, Yh = xfm(imt)
    ifm = DTCWTInverse(J=J, biort=biort, qshift=qshift).to(dev)
    y = ifm((Yl, Yh))

    # Compare with numpy results
    f_np = Transform2d_np(biort=biort, qshift=qshift)
    yl, yh = f_np.forward(im, nlevels=J)
    y2 = f_np.inverse(yl, yh)
    np.testing.assert_array_almost_equal(y.detach().cpu(), y2,
                                         decimal=PRECISION_FLOAT)

    # Test gradients are ok
    y.backward(torch.ones_like(y))
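
# Hedged aside, not an original test: the DTCWT is designed for perfect
# reconstruction, so a forward/inverse round trip should also recover the
# input itself to within float precision (the assertion above only checks
# agreement with the numpy round trip).
def _example_round_trip():
    xfm = DTCWTForward(J=3).to(dev)
    ifm = DTCWTInverse(J=3).to(dev)
    x = torch.randn(1, 3, 64, 64, device=dev)
    y = ifm(xfm(x))
    print('max abs round-trip error:', (x - y).abs().max().item())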

@pytest.mark.parametrize("J, o_dim", [(1, 2), (2, 2), (3, 1)])
def test_fwd_double(J, o_dim):
    # Build the input and the transform while float64 is the default dtype
    with set_double_precision():
        X = 100 * np.random.randn(3, 5, 100, 100)
        Xt = torch.tensor(X, dtype=torch.get_default_dtype(), device=dev)
        xfm = DTCWTForward(J=J, o_dim=o_dim).to(dev)
    Yl, Yh = xfm(Xt)
    assert Yl.dtype == torch.float64
    f1 = Transform2d_np()
    yl, yh = f1.forward(X, nlevels=J)
    np.testing.assert_array_almost_equal(Yl.cpu(), yl,
                                         decimal=PRECISION_DOUBLE)
    for i in range(len(yh)):
        for l in range(6):
            ours_r = np.take(Yh[i][..., 0].cpu().numpy(), l, o_dim)
            ours_i = np.take(Yh[i][..., 1].cpu().numpy(), l, o_dim)
            np.testing.assert_array_almost_equal(
                ours_r, yh[i][:, :, l].real, decimal=PRECISION_DOUBLE)
            np.testing.assert_array_almost_equal(
                ours_i, yh[i][:, :, l].imag, decimal=PRECISION_DOUBLE)

@pytest.mark.parametrize("ri_dim", [-1, 1, 3])
def test_inv_ri_dim(ri_dim):
    Yl = 100 * np.random.randn(3, 5, 64, 64)
    J = 3
    Yhr = [np.random.randn(3, 5, 6, 2**j, 2**j) for j in range(4 + J, 4, -1)]
    Yhi = [np.random.randn(3, 5, 6, 2**j, 2**j) for j in range(4 + J, 4, -1)]
    Yh1 = [yhr + 1j * yhi for yhr, yhi in zip(Yhr, Yhi)]
    Yh2 = [torch.tensor(np.stack((yhr, yhi), axis=ri_dim),
                        dtype=torch.float32, device=dev)
           for yhr, yhi in zip(Yhr, Yhi)]
    # Inserting the real/imaginary axis at or before the orientation axis
    # (index 2) pushes the orientations along to index 3
    if (ri_dim % 6) <= 2:
        o_dim = 3
    else:
        o_dim = 2
    ifm = DTCWTInverse(J=J, o_dim=o_dim, ri_dim=ri_dim).to(dev)
    X = ifm((torch.tensor(Yl, dtype=torch.float32, device=dev), Yh2))
    f1 = Transform2d_np()
    x = f1.inverse(Yl, Yh1)
    np.testing.assert_array_almost_equal(X.cpu(), x, decimal=PRECISION_FLOAT)
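
# Hedged convenience entry point, not part of the original suite: runs a
# few of the tests with illustrative parameter values (assumptions, not the
# original pytest grids) so the module can be smoke-tested without pytest.
if __name__ == '__main__':
    test_fwd_ri_dim(o_dim=2, ri_dim=-1)
    test_inv_o_dim(J=2, o_dim=2)
    test_inv_skip_hps(J=2, o_dim=2)
    test_fwd_double(J=2, o_dim=2)
    test_inv_ri_dim(ri_dim=-1)
    test_end2end('near_sym_a', 'qshift_a', (128, 128), 3)
    print('smoke tests passed')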