Example 1
def test_gradients_inv(biort, qshift, size, J):
    """ Gradient of forward function should be inverse function with filters
    swapped """
    im = np.random.randn(5, 6, *size).astype('float32')
    imt = torch.tensor(im, dtype=torch.float32, device=dev)
    ifm = DTCWTInverse(biort=biort, qshift=qshift).to(dev)
    h0o, g0o, h1o, g1o = _biort(biort)
    h0a, h0b, g0a, g0b, h1a, h1b, g1a, g1b = _qshift(qshift)
    ifm_grad = DTCWTForward(J=J,
                            biort=(g0o[::-1], g1o[::-1]),
                            qshift=(g0a[::-1], g0b[::-1], g1a[::-1],
                                    g1b[::-1])).to(dev)
    yl, yh = ifm_grad(imt)
    g = torch.randn(*imt.shape, device=dev)
    ylv = torch.randn(*yl.shape, requires_grad=True, device=dev)
    yhv = [torch.randn(*h.shape, requires_grad=True, device=dev) for h in yh]
    Y = ifm((ylv, yhv))
    Y.backward(g)

    # Check the lowpass gradient is the same
    ref_lp, ref_bp = ifm_grad(g)
    np.testing.assert_array_almost_equal(ylv.grad.detach().cpu(), ref_lp.cpu())
    # check the bandpasses are the same
    for y, ref in zip(yhv, ref_bp):
        np.testing.assert_array_almost_equal(y.grad.detach().cpu(), ref.cpu())
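
What this verifies is the standard adjoint identity: the backward pass of a linear transform applies its transpose, and for the DTCWT the transpose of the inverse is the forward transform with time-reversed synthesis filters. A generic way to check the same property is a dot-product test (a hedged sketch in plain torch, not from the source):

# For a linear map A, <A x, g> == <x, A^T g>. Autograd's backward applies
# A^T, so comparing the two inner products checks the adjoint numerically.
import torch

def dot_product_test(apply_A, x_shape):
    x = torch.randn(*x_shape, dtype=torch.float64, requires_grad=True)
    y = apply_A(x)
    g = torch.randn_like(y)
    y.backward(g)
    lhs = (y * g).sum()       # <A x, g>
    rhs = (x * x.grad).sum()  # <x, A^T g>
    return torch.allclose(lhs, rhs)

lin = torch.nn.Linear(4, 3, bias=False).double()  # any linear map works
print(dot_product_test(lin, (2, 4)))              # True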
Example 2
def test_inv(J, o_before_c):
    Yl = 100 * np.random.randn(3, 5, 64, 64)
    Yhr = [np.random.randn(3, 5, 6, 2**j, 2**j) for j in range(4 + J, 4, -1)]
    Yhi = [np.random.randn(3, 5, 6, 2**j, 2**j) for j in range(4 + J, 4, -1)]
    Yh1 = [yhr + 1j * yhi for yhr, yhi in zip(Yhr, Yhi)]
    if o_before_c:
        Yh2 = [
            torch.tensor(np.stack((yhr, yhi), axis=-1),
                         dtype=torch.float32,
                         device=dev).transpose(1, 2)
            for yhr, yhi in zip(Yhr, Yhi)
        ]
    else:
        Yh2 = [
            torch.tensor(np.stack((yhr, yhi), axis=-1),
                         dtype=torch.float32,
                         device=dev) for yhr, yhi in zip(Yhr, Yhi)
        ]

    ifm = DTCWTInverse(J=J, o_before_c=o_before_c).to(dev)
    X = ifm((torch.tensor(Yl, dtype=torch.float32, device=dev), Yh2))
    f1 = Transform2d_np()
    x = f1.inverse(Yl, Yh1)

    np.testing.assert_array_almost_equal(X.cpu(), x, decimal=PRECISION_FLOAT)
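
In the default layout each highpass tensor is (N, C, 6, H_j, W_j, 2), with the six orientations at dim 2 and the real/imaginary pair last; `o_before_c=True` expects the orientations in front of the channels, which is what the `.transpose(1, 2)` above produces. A small shape check (a sketch, assuming only numpy and torch):

# Shape sketch for the two layouts exercised above.
import numpy as np
import torch

yhr = np.random.randn(3, 5, 6, 32, 32)            # (N, C, 6, H, W) real part
yhi = np.random.randn(3, 5, 6, 32, 32)            # imaginary part
yh = torch.tensor(np.stack((yhr, yhi), axis=-1))  # (3, 5, 6, 32, 32, 2)
print(yh.shape)                  # channels before orientations (default)
print(yh.transpose(1, 2).shape)  # (3, 6, 5, 32, 32, 2): o_before_c layout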
Example 3
def test_gradients_fwd(biort, qshift, size, J):
    """ Gradient of forward function should be inverse function with filters
    swapped """
    im = np.random.randn(5, 6, *size).astype('float32')
    imt = torch.tensor(im, dtype=torch.float32, requires_grad=True, device=dev)
    xfm = DTCWTForward(biort=biort, qshift=qshift, J=J).to(dev)
    h0o, g0o, h1o, g1o = _biort(biort)
    h0a, h0b, g0a, g0b, h1a, h1b, g1a, g1b = _qshift(qshift)
    xfm_grad = DTCWTInverse(biort=(h0o[::-1], h1o[::-1]),
                            qshift=(h0a[::-1], h0b[::-1], h1a[::-1],
                                    h1b[::-1])).to(dev)
    Yl, Yh = xfm(imt)
    Ylg = torch.randn(*Yl.shape, device=dev)
    Yl.backward(Ylg, retain_graph=True)
    ref = xfm_grad((Ylg, [None] * J))
    np.testing.assert_array_almost_equal(imt.grad.detach().cpu(), ref.cpu())
    for j, y in enumerate(Yh):
        imt.grad.zero_()
        g = torch.randn(*y.shape, device=dev)
        y.backward(g, retain_graph=True)
        hps = [None] * J
        hps[j] = g
        ref = xfm_grad((torch.zeros_like(Yl), hps))
        np.testing.assert_array_almost_equal(imt.grad.detach().cpu(),
                                             ref.cpu())
Example 4
def test_inv_skip_hps(J, o_dim):
    hps = np.random.binomial(size=J, n=1, p=0.5).astype('bool')
    Yl = 100 * np.random.randn(3, 5, 64, 64)
    Yhr = [[np.random.randn(3, 5, 2**j, 2**j) for _ in range(6)]
           for j in range(4 + J, 4, -1)]
    Yhi = [[np.random.randn(3, 5, 2**j, 2**j) for _ in range(6)]
           for j in range(4 + J, 4, -1)]
    Yh1 = [np.stack(r, axis=2) + 1j * np.stack(i, axis=2)
           for r, i in zip(Yhr, Yhi)]
    Yh2 = [np.stack((np.stack(r, axis=o_dim), np.stack(i, axis=o_dim)), axis=-1)
           for r, i in zip(Yhr, Yhi)]
    Yh2 = [torch.tensor(yh, dtype=torch.float32, device=dev) for yh in Yh2]
    for j in range(J):
        if hps[j]:
            Yh2[j] = torch.tensor([])
            Yh1[j] = np.zeros_like(Yh1[j])

    ifm = DTCWTInverse(J=J, o_dim=o_dim).to(dev)
    X = ifm((torch.tensor(Yl, dtype=torch.float32, requires_grad=True,
                          device=dev), Yh2))
    # Also test giving None instead of an empty tensor
    for j in range(J):
        if hps[j]:
            Yh2[j] = None
    X2 = ifm((torch.tensor(Yl, dtype=torch.float32, device=dev), Yh2))
    f1 = Transform2d_np()
    x = f1.inverse(Yl, Yh1)

    np.testing.assert_array_almost_equal(
        X.detach().cpu(), x, decimal=PRECISION_FLOAT)
    np.testing.assert_array_almost_equal(
        X2.cpu(), x, decimal=PRECISION_FLOAT)

    # Test gradients are ok
    X.backward(torch.ones_like(X))
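
The empty-tensor/None handling also gives a cheap way to reconstruct from the lowpass alone. A minimal sketch, assuming the library accepts None placeholders exactly as the test above asserts:

import torch
from pytorch_wavelets import DTCWTForward, DTCWTInverse

x = torch.randn(1, 3, 64, 64)
xfm = DTCWTForward(J=3)
ifm = DTCWTInverse()
yl, yh = xfm(x)
x_lowpass = ifm((yl, [None] * len(yh)))  # every bandpass scale skipped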
Example 5
def end_to_end(size, no_grad, J, no_hp=False, dev='cuda'):
    x = torch.randn(*size, requires_grad=(not no_grad)).to(dev)
    xfm = DTCWTForward(J=J, skip_hps=no_hp).to(dev)
    ifm = DTCWTInverse(J=J).to(dev)
    Yl, Yh = xfm(x)
    Y = ifm((Yl, Yh))
    if not no_grad:
        Y.backward(torch.ones_like(Y))
    return Y
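
A possible invocation of this benchmarking helper (the argument values are illustrative, not from the source):

Y = end_to_end((8, 3, 256, 256), no_grad=False, J=3, dev='cpu')
print(Y.shape)  # torch.Size([8, 3, 256, 256])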
Example 6
def inverse(size, no_grad, J, no_hp=False, dev='cuda'):
    yl = torch.randn(size[0], size[1], size[2] >> (J-1), size[3] >> (J-1),
                     requires_grad=(not no_grad)).to(dev)
    yh = [torch.randn(size[0], size[1], 6, size[2] >> j, size[3] >> j, 2,
                      requires_grad=(not no_grad)).to(dev)
          for j in range(1,J+1)]
    ifm = DTCWTInverse(J=J).to(dev)
    Y = ifm((yl, yh))
    if not no_grad:
        Y.backward(torch.ones_like(Y))
    return Y
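
The right shifts are just integer division by powers of two, so scale j holds coefficients at 1/2^j of the input resolution (and the lowpass at 1/2^(J-1), matching the shapes built above):

assert (64 >> 2) == 64 // 2**2 == 16  # size >> j equals size // 2**j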
Example 7
    def __init__(self, C, F, k=4, stride=1, J=1, wd=0, wd1=None, right=True):
        super().__init__()
        self.wd = wd
        if wd1 is None:
            self.wd1 = wd
        else:
            self.wd1 = wd1
        self.k = k
        self.C = C
        self.F = F
        self.J = J
        self.right = right
        self.ifm = DTCWTInverse()
        self._init()
Example 8
    def __init__(self,
                 C,
                 F,
                 lp_size=3,
                 bp_sizes=(1, ),
                 biort='near_sym_a',
                 qshift='qshift_a',
                 xfm=True,
                 ifm=True,
                 wd=0,
                 wd1=None,
                 lp_nl=None,
                 bp_nl=None,
                 lp_nl_kwargs={},
                 bp_nl_kwargs={}):
        super().__init__()
        self.C = C
        self.F = F
        # If any of the mixing for a scale is 0, don't calculate the dtcwt at
        # that scale
        skip_hps = [s == 0 for s in bp_sizes]
        self.J = len(bp_sizes)
        self.wd = wd
        self.wd1 = wd1

        if xfm:
            self.XFM = DTCWTForward(biort=biort,
                                    qshift=qshift,
                                    J=self.J,
                                    skip_hps=skip_hps,
                                    o_dim=2,
                                    ri_dim=-1)
        else:
            self.XFM = lambda x: x

        self.GainLayer = WaveGainLayer(C, F, lp_size, bp_sizes, wd=wd, wd1=wd1)

        if not isinstance(bp_nl, (list, tuple)):
            bp_nl = [bp_nl] * self.J
        self.NL = WaveNonLinearity(F, lp_nl, bp_nl, lp_nl_kwargs, bp_nl_kwargs)

        if ifm:
            self.IFM = DTCWTInverse(biort=biort,
                                    qshift=qshift,
                                    o_dim=2,
                                    ri_dim=-1)
        else:
            self.IFM = lambda x: x
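
The excerpt stops at the constructor. A plausible forward pass, inferred from the attribute names rather than taken from the source, would chain the four stages; this assumes GainLayer and NL accept and return the (yl, yh) pair the transform produces:

# Hedged sketch of a matching forward() (NOT in the excerpt; inferred
# from the attributes wired up above):
def forward(self, x):
    coeffs = self.XFM(x)             # (yl, yh), or x unchanged if xfm=False
    coeffs = self.GainLayer(coeffs)  # learned per-subband mixing
    coeffs = self.NL(coeffs)         # optional nonlinearities
    return self.IFM(coeffs)          # back to pixels, or identity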
Example 9
def test_inv(J, o_dim):
    Yl = 100 * np.random.randn(3, 5, 64, 64)
    Yhr = [[np.random.randn(3, 5, 2**j, 2**j) for _ in range(6)]
           for j in range(4 + J, 4, -1)]
    Yhi = [[np.random.randn(3, 5, 2**j, 2**j) for _ in range(6)]
           for j in range(4 + J, 4, -1)]
    Yh1 = [np.stack(r, axis=2) + 1j * np.stack(i, axis=2)
           for r, i in zip(Yhr, Yhi)]
    Yh2 = [np.stack((np.stack(r, axis=o_dim), np.stack(i, axis=o_dim)), axis=-1)
           for r, i in zip(Yhr, Yhi)]
    Yh2 = [torch.tensor(yh, dtype=torch.float32, device=dev) for yh in Yh2]
    ifm = DTCWTInverse(J=J, o_dim=o_dim).to(dev)
    X = ifm((torch.tensor(Yl, dtype=torch.float32, device=dev), Yh2))
    f1 = Transform2d_np()
    x = f1.inverse(Yl, Yh1)

    np.testing.assert_array_almost_equal(
        X.cpu(), x, decimal=PRECISION_FLOAT)
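
The double np.stack above inserts the six orientations at o_dim and the real/imaginary pair last; for o_dim=2 this recovers the default (N, C, 6, H, W, 2) layout. A quick shape check in plain numpy:

import numpy as np
r = [np.random.randn(3, 5, 16, 16) for _ in range(6)]  # 6 orientation slices
i = [np.random.randn(3, 5, 16, 16) for _ in range(6)]
yh = np.stack((np.stack(r, axis=2), np.stack(i, axis=2)), axis=-1)
print(yh.shape)  # (3, 5, 6, 16, 16, 2)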
Example 10
def test_end2end(biort, qshift, size, J):
    im = np.random.randn(5, 6, *size).astype('float32')
    imt = torch.tensor(im, dtype=torch.float32, requires_grad=True, device=dev)
    xfm = DTCWTForward(J=J, biort=biort, qshift=qshift).to(dev)
    Yl, Yh = xfm(imt)
    ifm = DTCWTInverse(J=J, biort=biort, qshift=qshift).to(dev)
    y = ifm((Yl, Yh))

    # Compare with numpy results
    f_np = Transform2d_np(biort=biort, qshift=qshift)
    yl, yh = f_np.forward(im, nlevels=J)
    y2 = f_np.inverse(yl, yh)

    np.testing.assert_array_almost_equal(y.detach().cpu(), y2, decimal=PRECISION_FLOAT)

    # Test gradients are ok
    y.backward(torch.ones_like(y))
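
Stripped of the test scaffolding, the round trip this test exercises reduces to a few lines (a minimal sketch against the pytorch_wavelets API; the error is near zero because the DTCWT filter banks are perfect-reconstruction):

import torch
from pytorch_wavelets import DTCWTForward, DTCWTInverse

x = torch.randn(1, 3, 64, 64)
xfm = DTCWTForward(J=3, biort='near_sym_a', qshift='qshift_a')
ifm = DTCWTInverse(biort='near_sym_a', qshift='qshift_a')
yl, yh = xfm(x)
x_hat = ifm((yl, yh))
print((x - x_hat).abs().max())  # ~1e-6 in float32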
Example 11
    def __init__(self,
                 wt_type='DTCWT',
                 biort='near_sym_b',
                 qshift='qshift_b',
                 J=5,
                 wave='db3',
                 mode='zero',
                 device='cuda',
                 requires_grad=True):
        super().__init__()
        if wt_type == 'DTCWT':
            self.xfm = DTCWTForward(biort=biort, qshift=qshift, J=J).to(device)
            self.ifm = DTCWTInverse(biort=biort, qshift=qshift).to(device)
        elif wt_type == 'DWT':
            self.xfm = DWTForward(wave=wave, J=J, mode=mode).to(device)
            self.ifm = DWTInverse(wave=wave, mode=mode).to(device)
        else:
            raise ValueError('no such type of wavelet transform is supported')
        self.J = J
        self.wt_type = wt_type
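
Hypothetical usage of this wrapper (the excerpt omits the class name, so WaveletXfm below is a stand-in; device='cpu' sidesteps the CUDA default):

wt = WaveletXfm(wt_type='DWT', wave='db3', J=3, device='cpu')
yl, yh = wt.xfm(torch.randn(1, 3, 64, 64))
x_hat = wt.ifm((yl, yh))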
Example 12
def test_inv_ri_dim(ri_dim):
    Yl = 100 * np.random.randn(3, 5, 64, 64)
    J = 3
    Yhr = [np.random.randn(3, 5, 6, 2**j, 2**j) for j in range(4 + J, 4, -1)]
    Yhi = [np.random.randn(3, 5, 6, 2**j, 2**j) for j in range(4 + J, 4, -1)]
    Yh1 = [yhr + 1j * yhi for yhr, yhi in zip(Yhr, Yhi)]
    Yh2 = [torch.tensor(np.stack((yhr, yhi), axis=ri_dim),
                        dtype=torch.float32, device=dev)
           for yhr, yhi in zip(Yhr, Yhi)]

    # Keep the orientation axis clear of the real/imag axis: normalise
    # ri_dim into 0..5; if it lands in the first three dims, put the six
    # orientations at dim 3, otherwise at dim 2.
    if (ri_dim % 6) <= 2:
        o_dim = 3
    else:
        o_dim = 2

    ifm = DTCWTInverse(J=J, o_dim=o_dim, ri_dim=ri_dim).to(dev)
    X = ifm((torch.tensor(Yl, dtype=torch.float32, device=dev), Yh2))
    f1 = Transform2d_np()
    x = f1.inverse(Yl, Yh1)

    np.testing.assert_array_almost_equal(
        X.cpu(), x, decimal=PRECISION_FLOAT)
Example 13
    def __init__(self):
        super().__init__()
        self.xfm = DTCWTForward(J=3, C=3)
        self.ifm = DTCWTInverse(J=3, C=3)
        self.sparsify = SparsifyWaveCoeffs2(3, 3)
Example 14
def main():
    xfm = DTCWTForward(J=1)
    #  xfm.h0o = xfm.h0a
    #  xfm.h1o = xfm.h1a
    ifm = DTCWTInverse()
    #  ifm.g0o = ifm.g0a
    #  ifm.g1o = ifm.g1a
    b1 = (ifm.g0o.data.numpy().ravel()[::-1],
          ifm.g1o.data.numpy().ravel()[::-1])
    xfm2 = DTCWTForward(J=1, biort=b1)
    b1 = (np.copy(xfm.h0o.data.numpy().ravel()[::-1]),
          np.copy(xfm.h1o.data.numpy().ravel()[::-1]))
    ifm2 = DTCWTInverse(biort=b1)
    #  xfm2 = xfm
    #  ifm2 = ifm
    wd = 1e-2
    N = 8
    pad = (N - 1) // 2
    #  U = np.exp(-1j*2*np.pi*index/N)
    #  Us = 1/N * np.conj(U)

    W = torch.randn(8, 5, N, N, requires_grad=True)
    W_hat_lp, (W_hat_bp, ) = xfm(W.data)
    W_hat_lp.requires_grad = True
    W_hat_bp.requires_grad = True

    optim1 = torch.optim.SGD([W], lr=0.1, momentum=0.0, weight_decay=0)
    optim2 = torch.optim.SGD([W_hat_lp, W_hat_bp],
                             lr=0.1,
                             momentum=0.0,
                             weight_decay=0)
    #  optim1 = torch.optim.Adam([W,], lr=0.01)
    #  optim2 = CplxAdam([W_hat,], lr=0.01)
    #  optim1 = torch.optim.Adagrad([W,], lr=0.1)
    #  optim2 = torch.optim.Adagrad([W_hat,], lr=0.1)
    loss1 = torch.nn.MSELoss()
    loss2 = torch.nn.MSELoss()
    for i in range(10):
        print('Testing step {}'.format(i))
        optim1.zero_grad()
        optim2.zero_grad()

        W2 = ifm((W_hat_lp, (W_hat_bp, )))
        W2.retain_grad()
        np.testing.assert_array_almost_equal(W2.detach().numpy(),
                                             W.detach().numpy(),
                                             decimal=4)

        x = torch.randn(10, 5, 32, 32)
        y1 = F.conv2d(x, W, padding=pad)
        y2 = F.conv2d(x, W2, padding=pad)
        np.testing.assert_array_almost_equal(y1.detach().numpy(),
                                             y2.detach().numpy(),
                                             decimal=4)
        y_target = torch.randn(10, 8, 31, 31)

        output1 = loss1(y1, y_target)
        output2 = loss2(y2, y_target)
        output1.backward()
        output2.backward()

        # Check the gradients are the same before regularization
        np.testing.assert_array_almost_equal(W2.grad.numpy(),
                                             W.grad.numpy(),
                                             decimal=4)

        reg1 = wd * reg_loss(W, 'l2')
        reg2 = wd * reg_loss(W_hat_lp, 'l2') + wd * reg_loss(W_hat_bp, 'l2')


        # DTCWT coeffs: a_lp/a_bp (grads da_lp/da_bp); pixel weights: b (grads db)
        a_lp = W_hat_lp.data.clone()
        a_bp = W_hat_bp.data.clone()
        da_lp = W_hat_lp.grad.data.clone()
        da_bp = W_hat_bp.grad.data.clone()
        b = W.data.clone()
        db = W.grad.data.clone()

        # Check that xfm(b) == a and xfm2(db) == da (pixel -> DTCWT domain)
        b_lp, (b_bp, ) = xfm(b)
        db_lp, (db_bp, ) = xfm2(db)
        np.testing.assert_array_almost_equal(a_lp, b_lp, decimal=4)
        np.testing.assert_array_almost_equal(a_bp, b_bp, decimal=4)
        np.testing.assert_array_almost_equal(da_lp, db_lp, decimal=4)
        np.testing.assert_array_almost_equal(da_bp, db_bp, decimal=4)

        # Check that ifm(a) == b and ifm2(da) == db (DTCWT -> pixel domain)
        a = ifm((a_lp, (a_bp, )))
        da = ifm2((da_lp, (da_bp, )))
        np.testing.assert_array_almost_equal(a, b, decimal=4)
        np.testing.assert_array_almost_equal(da, db, decimal=4)

        # Add in regularization
        #  reg1.backward()
        #  reg2.backward()
        optim1.step()
        optim2.step()

    print('Done! They matched')
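
Taken together, the loop checks a chain-rule identity: with W = ifm(W_hat) and ifm linear, dL/dW_hat is the filter-reversed forward transform applied to dL/dW, which is exactly why xfm2 and ifm2 are built from reversed filters at the top of main(). Plain SGD without weight decay therefore takes matching steps in the pixel and DTCWT domains, and the asserts confirm the two parametrisations stay in lockstep through every iteration.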