def test_dtcwt2(size, J, no_grad=False, dev='cuda'):
    """Apply ``lowlevel2.quad_afb2d`` three times to a random tensor.

    Args:
        size: Sequence of ints, unpacked as the shape given to
            ``torch.randn``.
        J: Accepted but not read by the body; the number of passes is
            fixed at 3.
            NOTE(review): presumably meant to be the number of levels --
            confirm whether the loop should use it.
        no_grad: When True, the random input is created with
            ``requires_grad=False``.
        dev: Device the random input is moved to and that is handed to
            the filter-preparation helper.

    Returns:
        A tuple ``(x, yh)`` where ``x`` is the first output of the final
        ``quad_afb2d`` call and ``yh`` is a list holding the second output
        of every call, in call order.
    """
    track_grad = not no_grad
    lowpass = torch.randn(*size, requires_grad=track_grad).to(dev)

    # Only entries 0, 1, 4 and 5 of the eight values returned by level1
    # are used here.
    h0a, h0b, _, _, h1a, h1b, _, _ = level1('farras')
    cols, rows = lowlevel2.prep_filt_quad_afb2d(
        h0a, h1a, h0b, h1b, device=dev)

    highs = []
    for _ in range(3):
        # Each pass feeds the previous first output back in as the input.
        lowpass, band = lowlevel2.quad_afb2d(
            lowpass, cols, rows, mode='zero')
        highs.append(band)

    return lowpass, highs
def test_dtcwt(size, J, no_grad=False, dev='cuda'):
    """Apply ``lowlevel.afb2d_nonsep`` three times to a random tensor.

    After every pass the first output is reshaped to four dimensions
    (its first and last two dimensions are kept, everything in between is
    merged) and used as the input of the next pass. The second output of
    each pass is discarded and nothing is returned.

    Args:
        size: Sequence of ints, unpacked as the shape given to
            ``torch.randn``.
        J: Accepted but not read by the body; the number of passes is
            fixed at 3.
            NOTE(review): presumably meant to be the number of levels --
            confirm whether the loop should use it.
        no_grad: When True, the random input is created with
            ``requires_grad=False``.
        dev: Device the random input is moved to and that is handed to
            the filter-preparation helper.
    """
    track_grad = not no_grad
    x = torch.randn(*size, requires_grad=track_grad).to(dev)

    # Only entries 0, 1, 4 and 5 of the eight values returned by level1
    # are used here.
    h0a, h0b, _, _, h1a, h1b, _, _ = level1('farras')
    pair_a = (h0a, h1a)
    pair_b = (h0b, h1b)

    # Positional order of the eight (h0, h1) pairs: a, a, a, b, b, a, b, b.
    filts = lowlevel2.prep_filt_quad_afb2d_nonsep(
        *pair_a, *pair_a, *pair_a, *pair_b,
        *pair_b, *pair_a, *pair_b, *pair_b,
        device=dev)

    for _ in range(3):
        # NOTE(review): filters come from lowlevel2 while the transform is
        # taken from lowlevel -- confirm this pairing is intended.
        yl, _unused = lowlevel.afb2d_nonsep(x, filts, mode='zero')
        lead = yl.shape[0]
        height, width = yl.shape[-2], yl.shape[-1]
        x = yl.reshape(lead, -1, height, width)