Example #1
    def __init__(self,
                 biort='near_sym_a',
                 qshift='qshift_a',
                 mode='symmetric',
                 magbias=1e-2,
                 combine_colour=False):
        super().__init__()
        # Have to convert the string to an int as the grad checks don't work
        # with string inputs
        self.combine_colour = combine_colour
        self.MagFn1 = MagFn(b=magbias, c=1)
        self.MagFn2 = MagFn(b=magbias, c=1)
        self.MagFn1c = MagFn(b=magbias, c=1)
        self.MagFn2c = MagFn(b=magbias, c=1)
        self.MagFn3 = MagFn(b=magbias, c=1)
        self.xfm1 = DTCWTForward(J=2,
                                 biort=biort,
                                 qshift=qshift,
                                 o_dim=2,
                                 ri_dim=-1,
                                 mode=mode)
        self.xfm2 = DTCWTForward(J=1,
                                 biort=biort,
                                 qshift=qshift,
                                 o_dim=2,
                                 ri_dim=-1,
                                 mode=mode)
        Hr, Hi = filters_rotated()
        self.Hr = nn.Parameter(Hr, requires_grad=False)
        self.Hi = nn.Parameter(Hi, requires_grad=False)
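
For reference, a minimal sketch (assuming the pytorch_wavelets package) of the coefficient layout that the o_dim=2, ri_dim=-1 settings above produce:

import torch
from pytorch_wavelets import DTCWTForward

# With o_dim=2 and ri_dim=-1, each highpass tensor stacks the 6 DTCWT
# orientations on dim 2 and the real/imaginary parts on the last dim.
xfm = DTCWTForward(J=2, o_dim=2, ri_dim=-1)
Yl, Yh = xfm(torch.randn(1, 3, 64, 64))
print(Yl.shape)     # lowpass:  torch.Size([1, 3, 32, 32])
print(Yh[0].shape)  # level 1:  torch.Size([1, 3, 6, 32, 32, 2])
print(Yh[1].shape)  # level 2:  torch.Size([1, 3, 6, 16, 16, 2])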
Example #2
def test_fwd(J, o_before_c):
    X = 100 * np.random.randn(3, 5, 100, 100)
    Xt = torch.tensor(X, dtype=torch.get_default_dtype(), device=dev)
    xfm = DTCWTForward(J=J, o_before_c=o_before_c).to(dev)
    Yl, Yh = xfm(Xt)
    f1 = Transform2d_np()
    yl, yh = f1.forward(X, nlevels=J)

    np.testing.assert_array_almost_equal(Yl.cpu(), yl, decimal=PRECISION_FLOAT)
    for i in range(len(yh)):
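        # With o_before_c the orientation axis comes before channels, so
        # transpose(2, 1) swaps them back before comparing to the numpy ref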
        if o_before_c:
            np.testing.assert_array_almost_equal(
                Yh[i][..., 0].cpu().transpose(2, 1), yh[i].real,
                decimal=PRECISION_FLOAT)
            np.testing.assert_array_almost_equal(
                Yh[i][..., 1].cpu().transpose(2, 1), yh[i].imag,
                decimal=PRECISION_FLOAT)
        else:
            np.testing.assert_array_almost_equal(Yh[i][..., 0].cpu(),
                                                 yh[i].real,
                                                 decimal=PRECISION_FLOAT)
            np.testing.assert_array_almost_equal(Yh[i][..., 1].cpu(),
                                                 yh[i].imag,
                                                 decimal=PRECISION_FLOAT)
Example #3
def test_fwd_ri_dim(o_dim, ri_dim):
    J = 3
    X = 100 * np.random.randn(3, 5, 100, 100)
    Xt = torch.tensor(X, dtype=torch.get_default_dtype(), device=dev)
    xfm = DTCWTForward(J=J, o_dim=o_dim, ri_dim=ri_dim).to(dev)
    Yl, Yh = xfm(Xt)
    f1 = Transform2d_np()
    yl, yh = f1.forward(X, nlevels=J)

    np.testing.assert_array_almost_equal(Yl.cpu(), yl, decimal=PRECISION_FLOAT)

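    # Yh is 6-dimensional, so ri_dim % 6 maps a negative ri_dim to its
    # positive index; np.take removes the axis it extracts, so an o_dim
    # sitting after ri_dim shifts down by one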
    if (ri_dim % 6) < o_dim:
        o_dim -= 1

    for i in range(len(yh)):
        ours_r = np.take(Yh[i].cpu().numpy(), 0, ri_dim)
        ours_i = np.take(Yh[i].cpu().numpy(), 1, ri_dim)
        for l in range(6):
            ours = np.take(ours_r, l, o_dim)
            np.testing.assert_array_almost_equal(ours,
                                                 yh[i][:, :, l].real,
                                                 decimal=PRECISION_FLOAT)
            ours = np.take(ours_i, l, o_dim)
            np.testing.assert_array_almost_equal(ours,
                                                 yh[i][:, :, l].imag,
                                                 decimal=PRECISION_FLOAT)
Example #4
def test_gradients_fwd(biort, qshift, size, J):
    """ Gradient of forward function should be inverse function with filters
    swapped """
    im = np.random.randn(5, 6, *size).astype('float32')
    imt = torch.tensor(im, dtype=torch.float32, requires_grad=True, device=dev)
    xfm = DTCWTForward(biort=biort, qshift=qshift, J=J).to(dev)
    h0o, g0o, h1o, g1o = _biort(biort)
    h0a, h0b, g0a, g0b, h1a, h1b, g1a, g1b = _qshift(qshift)
    xfm_grad = DTCWTInverse(biort=(h0o[::-1], h1o[::-1]),
                            qshift=(h0a[::-1], h0b[::-1], h1a[::-1],
                                    h1b[::-1])).to(dev)
    Yl, Yh = xfm(imt)
    Ylg = torch.randn(*Yl.shape, device=dev)
    Yl.backward(Ylg, retain_graph=True)
    ref = xfm_grad((Ylg, [None] * J))
    np.testing.assert_array_almost_equal(imt.grad.detach().cpu(), ref.cpu())
    for j, y in enumerate(Yh):
        imt.grad.zero_()
        g = torch.randn(*y.shape, device=dev)
        y.backward(g, retain_graph=True)
        hps = [None] * J
        hps[j] = g
        ref = xfm_grad((torch.zeros_like(Yl), hps))
        np.testing.assert_array_almost_equal(imt.grad.detach().cpu(),
                                             ref.cpu())
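
The property these gradient tests exercise, in a minimal self-contained form: backpropagating a cotangent g through a linear map applies its adjoint, which for the DTCWT amounts to an inverse transform with time-reversed filters. A toy matrix version of the same check:

import torch

# For a linear map f(x) = A @ x, backward with cotangent g yields
# x.grad = A^T @ g, i.e. the adjoint of f applied to g.
A = torch.randn(4, 3)
x = torch.randn(3, requires_grad=True)
g = torch.randn(4)
(A @ x).backward(g)
assert torch.allclose(x.grad, A.t() @ g)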
Example #5
def test_gradients_inv(biort, qshift, size, J):
    """ Gradient of forward function should be inverse function with filters
    swapped """
    im = np.random.randn(5, 6, *size).astype('float32')
    imt = torch.tensor(im, dtype=torch.float32, device=dev)
    ifm = DTCWTInverse(biort=biort, qshift=qshift).to(dev)
    h0o, g0o, h1o, g1o = _biort(biort)
    h0a, h0b, g0a, g0b, h1a, h1b, g1a, g1b = _qshift(qshift)
    ifm_grad = DTCWTForward(J=J,
                            biort=(g0o[::-1], g1o[::-1]),
                            qshift=(g0a[::-1], g0b[::-1], g1a[::-1],
                                    g1b[::-1])).to(dev)
    yl, yh = ifm_grad(imt)
    g = torch.randn(*imt.shape, device=dev)
    ylv = torch.randn(*yl.shape, requires_grad=True, device=dev)
    yhv = [torch.randn(*h.shape, requires_grad=True, device=dev) for h in yh]
    Y = ifm((ylv, yhv))
    Y.backward(g)

    # Check the lowpass gradient is the same
    ref_lp, ref_bp = ifm_grad(g)
    np.testing.assert_array_almost_equal(ylv.grad.detach().cpu(), ref_lp.cpu())
    # check the bandpasses are the same
    for y, ref in zip(yhv, ref_bp):
        np.testing.assert_array_almost_equal(y.grad.detach().cpu(), ref.cpu())
Example #6
def test_fwd_skip_hps(J, o_before_c):
    X = 100 * np.random.randn(3, 5, 100, 100)
    # Randomly turn on/off the highpass outputs
    hps = np.random.binomial(size=J, n=1, p=0.5).astype('bool')
    xfm = DTCWTForward(J=J, skip_hps=hps, o_before_c=o_before_c).to(dev)
    Yl, Yh = xfm(torch.tensor(X, dtype=torch.float32, device=dev))
    f1 = Transform2d_np()
    yl, yh = f1.forward(X, nlevels=J)

    np.testing.assert_array_almost_equal(Yl.cpu(), yl, decimal=PRECISION_FLOAT)
    for j in range(J):
        if hps[j]:
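            # Skipped scales come back as empty tensors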
            assert Yh[j].shape == torch.Size([0])
        else:
            if o_before_c:
                np.testing.assert_array_almost_equal(
                    Yh[j][..., 0].cpu().transpose(2, 1), yh[j].real,
                    decimal=PRECISION_FLOAT)
                np.testing.assert_array_almost_equal(
                    Yh[j][..., 1].cpu().transpose(2, 1), yh[j].imag,
                    decimal=PRECISION_FLOAT)
            else:
                np.testing.assert_array_almost_equal(Yh[j][..., 0].cpu(),
                                                     yh[j].real,
                                                     decimal=PRECISION_FLOAT)
                np.testing.assert_array_almost_equal(Yh[j][..., 1].cpu(),
                                                     yh[j].imag,
                                                     decimal=PRECISION_FLOAT)
Example #7
def forward(size, no_grad, J, no_hp=False, dev='cuda'):
    x = torch.randn(*size, requires_grad=(not no_grad)).to(dev)
    xfm = DTCWTForward(J=J, skip_hps=no_hp, o_dim=1).to(dev)
    Yl, Yh = xfm(x)
    if not no_grad:
        Yl.backward(torch.ones_like(Yl))
    return Yl.mean(), [y.mean() for y in Yh]
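
A hypothetical invocation of this profiling helper (the sizes and arguments are illustrative only):

yl_mean, yh_means = forward((10, 3, 256, 256), no_grad=False, J=2, dev='cpu')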
Example #8
def forward(size, no_grad, J, no_hp=False, dev='cuda'):
    x = torch.randn(*size, requires_grad=(not no_grad)).to(dev)
    xfm = DTCWTForward(J=J, skip_hps=no_hp, o_dim=1, mode='symmetric').to(dev)
    for _ in range(5):
        Yl, Yh = xfm(x)
        if not no_grad:
            Yl.backward(torch.ones_like(Yl))
    return Yl, Yh
Example #9
def end_to_end(size, no_grad, J, no_hp=False, dev='cuda'):
    x = torch.randn(*size, requires_grad=(not no_grad)).to(dev)
    xfm = DTCWTForward(J=J, skip_hps=no_hp).to(dev)
    ifm = DTCWTInverse(J=J).to(dev)
    Yl, Yh = xfm(x)
    Y = ifm((Yl, Yh))
    if not no_grad:
        Y.backward(torch.ones_like(Y))
    return Y
Example #10
def test_bwd_include_scale(scales):
    X = 100 * np.random.randn(3, 5, 100, 100)
    # Turn on/off returning the lowpass coefficients at each scale
    J = len(scales)
    xfm = DTCWTForward(J=J, include_scale=scales).to(dev)
    Ys, Yh = xfm(
        torch.tensor(X, dtype=torch.float32, requires_grad=True, device=dev))
    f1 = Transform2d_np()
    yl, yh, ys = f1.forward(X, nlevels=J, include_scale=True)

    for s in Ys:
        s.backward(torch.ones_like(s), retain_graph=True)
Example #11
    def __init__(self, stride=2):
        super().__init__()
        self.xfm = DTCWTForward(J=1, o_dim=1, ri_dim=2)
        self.mag = MagReshape(b=1e-5, o_dim=1, ri_dim=2)
        if stride == 2:
            self.lp = nn.AvgPool2d(2)
            self.bp = lambda x: x
            self.avg = lambda x: 2 * func.avg_pool2d(x, 2)
        elif stride == 1:
            self.lp = lambda x: x
            self.bp = lambda x: 0.5 * func.interpolate(
                x, scale_factor=2, mode='bilinear', align_corners=False)
        else:
            raise ValueError("stride must be 1 or 2")
Example #12
    def __init__(self,
                 C,
                 F,
                 lp_size=3,
                 bp_sizes=(1, ),
                 biort='near_sym_a',
                 qshift='qshift_a',
                 xfm=True,
                 ifm=True,
                 wd=0,
                 wd1=None,
                 lp_nl=None,
                 bp_nl=None,
                 lp_nl_kwargs={},
                 bp_nl_kwargs={}):
        super().__init__()
        self.C = C
        self.F = F
        # If the mixing size for a scale is 0, skip computing the DTCWT at
        # that scale
        skip_hps = [s == 0 for s in bp_sizes]
        self.J = len(bp_sizes)
        self.wd = wd
        self.wd1 = wd1

        if xfm:
            self.XFM = DTCWTForward(biort=biort,
                                    qshift=qshift,
                                    J=self.J,
                                    skip_hps=skip_hps,
                                    o_dim=2,
                                    ri_dim=-1)
        else:
            self.XFM = lambda x: x

        self.GainLayer = WaveGainLayer(C, F, lp_size, bp_sizes, wd=wd, wd1=wd1)

        if not isinstance(bp_nl, (list, tuple)):
            bp_nl = [bp_nl] * self.J
        self.NL = WaveNonLinearity(F, lp_nl, bp_nl, lp_nl_kwargs, bp_nl_kwargs)

        if ifm:
            self.IFM = DTCWTInverse(biort=biort,
                                    qshift=qshift,
                                    o_dim=2,
                                    ri_dim=-1)
        else:
            self.IFM = lambda x: x
Example #13
def test_fwd_include_scale(scales):
    X = 100 * np.random.randn(3, 5, 100, 100)
    # Turn on/off returning the lowpass coefficients at each scale
    J = len(scales)
    xfm = DTCWTForward(J=J, include_scale=scales).to(dev)
    Ys, Yh = xfm(torch.tensor(X, dtype=torch.float32, device=dev))
    f1 = Transform2d_np()
    yl, yh, ys = f1.forward(X, nlevels=J, include_scale=True)

    for j in range(J):
        if not scales[j]:
            assert Ys[j].shape == torch.Size([0])
        else:
            np.testing.assert_array_almost_equal(
                Ys[j].cpu(), ys[j], decimal=PRECISION_FLOAT)
Example #14
def test_end2end(biort, qshift, size, J):
    im = np.random.randn(5, 6, *size).astype('float32')
    imt = torch.tensor(im, dtype=torch.float32, requires_grad=True, device=dev)
    xfm = DTCWTForward(J=J, biort=biort, qshift=qshift).to(dev)
    Yl, Yh = xfm(imt)
    ifm = DTCWTInverse(J=J, biort=biort, qshift=qshift).to(dev)
    y = ifm((Yl, Yh))

    # Compare with numpy results
    f_np = Transform2d_np(biort=biort, qshift=qshift)
    yl, yh = f_np.forward(im, nlevels=J)
    y2 = f_np.inverse(yl, yh)

    np.testing.assert_array_almost_equal(y.detach().cpu(), y2, decimal=PRECISION_FLOAT)

    # Test gradients are ok
    y.backward(torch.ones_like(y))
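
The same round trip stripped to its essentials (a sketch assuming pytorch_wavelets; the DTCWT reconstructs its input to floating-point precision):

import torch
from pytorch_wavelets import DTCWTForward, DTCWTInverse

x = torch.randn(1, 3, 64, 64)
xfm = DTCWTForward(J=3, biort='near_sym_a', qshift='qshift_a')
ifm = DTCWTInverse(biort='near_sym_a', qshift='qshift_a')
y = ifm(xfm(x))  # xfm returns (Yl, Yh); ifm consumes that tuple directly
print((y - x).abs().max())  # on the order of 1e-6 for float32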
Example #15
    def __init__(self,
                 wt_type='DTCWT',
                 biort='near_sym_b',
                 qshift='qshift_b',
                 J=5,
                 wave='db3',
                 mode='zero',
                 device='cuda',
                 requires_grad=True):
        super().__init__()
        if wt_type == 'DTCWT':
            self.xfm = DTCWTForward(biort=biort, qshift=qshift, J=J).to(device)
            self.ifm = DTCWTInverse(biort=biort, qshift=qshift).to(device)
        elif wt_type == 'DWT':
            self.xfm = DWTForward(wave=wave, J=J, mode=mode).to(device)
            self.ifm = DWTInverse(wave=wave, mode=mode).to(device)
        else:
            raise ValueError('unsupported wavelet transform type')
        self.J = J
        self.wt_type = wt_type
Example #16
def test_fwd_double(J, o_dim):
    with set_double_precision():
        X = 100 * np.random.randn(3, 5, 100, 100)
        Xt = torch.tensor(X, dtype=torch.get_default_dtype(), device=dev)
        xfm = DTCWTForward(J=J, o_dim=o_dim).to(dev)
        Yl, Yh = xfm(Xt)
    assert Yl.dtype == torch.float64
    f1 = Transform2d_np()
    yl, yh = f1.forward(X, nlevels=J)

    np.testing.assert_array_almost_equal(
        Yl.cpu(), yl, decimal=PRECISION_DOUBLE)
    for i in range(len(yh)):
        for l in range(6):
            ours_r = np.take(Yh[i][..., 0].cpu().numpy(), l, o_dim)
            ours_i = np.take(Yh[i][..., 1].cpu().numpy(), l, o_dim)
            np.testing.assert_array_almost_equal(
                ours_r, yh[i][:, :, l].real, decimal=PRECISION_DOUBLE)
            np.testing.assert_array_almost_equal(
                ours_i, yh[i][:, :, l].imag, decimal=PRECISION_DOUBLE)
Example #17
    def _init(self):
        # To get the right coeff sizes, run a forward DTCWT on a kernel of
        # the size you ultimately want after reconstruction.
        x = torch.zeros(self.F, self.C, self.k, self.k)
        torch.nn.init.xavier_uniform_(x)
        xfm = DTCWTForward(J=self.J)
        yl, yh = xfm(x)
        self.downsample = False

        if self.k == 4:
            if self.J == 1:
                self.downsample = True
                yl = func.avg_pool2d(yl, 2)
            self.u_lp = nn.Parameter(torch.zeros_like(yl))
            self.uj = nn.Parameter(torch.zeros_like(yh[-1]))
            self.u_lp.data = yl.data
            self.uj.data = yh[-1].data
            if self.right:
                self.pad = (1, 2, 1, 2)
            else:
                self.pad = (2, 1, 2, 1)
        elif self.k == 8:
            if self.J == 1:
                self.downsample = True
                yl = func.avg_pool2d(yl, 2)
            self.u_lp = nn.Parameter(torch.zeros_like(yl))
            self.uj = nn.Parameter(torch.zeros_like(yh[-1]))
            self.u_lp.data = yl.data
            self.uj.data = yh[-1].data
            if self.right:
                self.pad = (3, 4, 3, 4)
            else:
                self.pad = (4, 3, 4, 3)
        else:
            raise NotImplementedError
        return yl, yh
Example #18
def test_odd_rows_and_cols():
    xfm = DTCWTForward(J=3).to(dev)
    Yl, Yh = xfm(barbara_t[:, :, :509, :509])
Example #19
def test_specific_wavelet():
    xfm = DTCWTForward(J=3, biort='antonini', qshift='qshift_06').to(dev)
    Yl, Yh = xfm(barbara_t)
    assert len(Yl.shape) == 4
    assert len(Yh) == 3
    assert Yh[0].shape[-1] == 2
Example #20
def test_simple():
    xfm = DTCWTForward(J=3).to(dev)
    Yl, Yh = xfm(barbara_t)
    assert len(Yl.shape) == 4
    assert len(Yh) == 3
    assert Yh[0].shape[-1] == 2
Example #21
def main():
    xfm = DTCWTForward(J=1)
    #  xfm.h0o = xfm.h0a
    #  xfm.h1o = xfm.h1a
    ifm = DTCWTInverse()
    #  ifm.g0o = ifm.g0a
    #  ifm.g1o = ifm.g1a
    b1 = (ifm.g0o.data.numpy().ravel()[::-1],
          ifm.g1o.data.numpy().ravel()[::-1])
    xfm2 = DTCWTForward(J=1, biort=b1)
    b1 = (np.copy(xfm.h0o.data.numpy().ravel()[::-1]),
          np.copy(xfm.h1o.data.numpy().ravel()[::-1]))
    ifm2 = DTCWTInverse(biort=b1)
    #  xfm2 = xfm
    #  ifm2 = ifm
    wd = 1e-2
    N = 8
    pad = (N - 1) // 2
    #  U = np.exp(-1j*2*np.pi*index/N)
    #  Us = 1/N * np.conj(U)

    w = np.random.randn(8, 5, N, N).astype('float32')
    W = torch.randn(8, 5, N, N, requires_grad=True)
    W_hat_lp, (W_hat_bp, ) = xfm(W.data)
    W_hat_lp.requires_grad = True
    W_hat_bp.requires_grad = True

    optim1 = torch.optim.SGD([W], lr=0.1, momentum=0.0, weight_decay=0)
    optim2 = torch.optim.SGD([W_hat_lp, W_hat_bp],
                             lr=0.1,
                             momentum=0.0,
                             weight_decay=0)
    #  optim1 = torch.optim.Adam([W,], lr=0.01)
    #  optim2 = CplxAdam([W_hat,], lr=0.01)
    #  optim1 = torch.optim.Adagrad([W,], lr=0.1)
    #  optim2 = torch.optim.Adagrad([W_hat,], lr=0.1)
    loss1 = torch.nn.MSELoss()
    loss2 = torch.nn.MSELoss()
    for i in range(10):
        print('Testing step {}'.format(i))
        optim1.zero_grad()
        optim2.zero_grad()

        W2 = ifm((W_hat_lp, (W_hat_bp, )))
        W2.retain_grad()
        np.testing.assert_array_almost_equal(W2.detach().numpy(),
                                             W.detach().numpy(),
                                             decimal=4)

        x = torch.randn(10, 5, 32, 32)
        y1 = F.conv2d(x, W, padding=pad)
        y2 = F.conv2d(x, W2, padding=pad)
        np.testing.assert_array_almost_equal(y1.detach().numpy(),
                                             y2.detach().numpy(),
                                             decimal=4)
        y_target = torch.randn(10, 8, 31, 31)

        output1 = loss1(y1, y_target)
        output2 = loss2(y2, y_target)
        output1.backward()
        output2.backward()

        # Check the gradients are the same before regularization
        np.testing.assert_array_almost_equal(W2.grad.numpy(),
                                             W.grad.numpy(),
                                             decimal=4)

        reg1 = wd * reg_loss(W, 'l2')
        reg2 = wd * reg_loss(W_hat_lp, 'l2') + wd * reg_loss(W_hat_bp, 'l2')

        # Do some sanity checks on gradients
        np.testing.assert_array_almost_equal(W.grad.data.numpy(),
                                             W2.grad.data.numpy(),
                                             decimal=4)

        # DTCWT coeffs = a, DTCWT grads = da, pixel weights = b, pixel grads = db
        a_lp = W_hat_lp.data.clone()
        a_bp = W_hat_bp.data.clone()
        da_lp = W_hat_lp.grad.data.clone()
        da_bp = W_hat_bp.grad.data.clone()
        b = W.data.clone()
        db = W.grad.data.clone()

        # Check that b -> a and db -> da
        b_lp, (b_bp, ) = xfm(b)
        db_lp, (db_bp, ) = xfm2(db)
        np.testing.assert_array_almost_equal(a_lp, b_lp, decimal=4)
        np.testing.assert_array_almost_equal(a_bp, b_bp, decimal=4)
        np.testing.assert_array_almost_equal(da_lp, db_lp, decimal=4)
        np.testing.assert_array_almost_equal(da_bp, db_bp, decimal=4)

        # Check that a -> b and da -> db
        a = ifm((a_lp, (a_bp, )))
        da = ifm2((da_lp, (da_bp, )))
        np.testing.assert_array_almost_equal(a, b, decimal=4)
        np.testing.assert_array_almost_equal(da, db, decimal=4)

        # Add in regularization
        #  reg1.backward()
        #  reg2.backward()
        optim1.step()
        optim2.step()

    print('Done! They matched')
Example #22
    def __init__(self):
        super().__init__()
        self.xfm = DTCWTForward(J=3, C=3)
        self.ifm = DTCWTInverse(J=3, C=3)
        self.sparsify = SparsifyWaveCoeffs2(3, 3)
Example #23
def forward(size, no_grad, J, no_hp=False, dev='cuda'):
    x = torch.randn(*size, requires_grad=(not no_grad)).to(dev)
    xfm = DTCWTForward(J=J, skip_hps=no_hp).to(dev)
    Yl, Yh = xfm(x)
    if not no_grad:
        Yl.backward(torch.ones_like(Yl))