コード例 #1
0
 def __init__(self, biort='near_sym_a', qshift='qshift_a', mode='symmetric',
              magbias=1e-2, combine_colour=False):
     """Store the layer settings and build its frozen DTCWT filters.

     Args:
         biort: name of the level-1 biorthogonal filter set, looked up with
             ``_biort``. The value ``'near_sym_b_bp'`` selects the
             bandpass-diagonal variant (``self.bandpass_diag = True``), which
             also loads the extra ``h2*`` filters.
         qshift: name of the q-shift filter set, looked up with ``_qshift``.
             Must be ``'qshift_b_bp'`` when ``biort == 'near_sym_b_bp'``.
         mode: padding mode name. It is kept verbatim in ``self.mode_str``
             and as an integer code (via ``mode_to_int``) in ``self.mode``.
         magbias: stored on ``self.magbias``.
         combine_colour: stored on ``self.combine_colour``.

     Raises:
         AssertionError: if ``biort == 'near_sym_b_bp'`` but ``qshift`` is
             not ``'qshift_b_bp'``.
     """
     super().__init__()
     self.biort = biort
     # Record the q-shift filter name that was actually requested.
     self.qshift = qshift
     # Have to convert the string to an int as the grad checks don't work
     # with string inputs
     self.mode_str = mode
     self.mode = mode_to_int(mode)
     self.magbias = magbias
     self.combine_colour = combine_colour
     # All filters are wrapped as Parameters with requires_grad=False, so
     # they are registered on the module but are not trained.
     if biort == 'near_sym_b_bp':
         assert qshift == 'qshift_b_bp', \
             "biort='near_sym_b_bp' requires qshift='qshift_b_bp'"
         self.bandpass_diag = True
         h0o, _, h1o, _, h2o, _ = _biort(biort)
         self.h0o = torch.nn.Parameter(prep_filt(h0o, 1), False)
         self.h1o = torch.nn.Parameter(prep_filt(h1o, 1), False)
         self.h2o = torch.nn.Parameter(prep_filt(h2o, 1), False)
         h0a, h0b, _, _, h1a, h1b, _, _, h2a, h2b, _, _ = _qshift('qshift_b_bp')
         self.h0a = torch.nn.Parameter(prep_filt(h0a, 1), False)
         self.h0b = torch.nn.Parameter(prep_filt(h0b, 1), False)
         self.h1a = torch.nn.Parameter(prep_filt(h1a, 1), False)
         self.h1b = torch.nn.Parameter(prep_filt(h1b, 1), False)
         self.h2a = torch.nn.Parameter(prep_filt(h2a, 1), False)
         self.h2b = torch.nn.Parameter(prep_filt(h2b, 1), False)
     else:
         self.bandpass_diag = False
         h0o, _, h1o, _ = _biort(biort)
         self.h0o = torch.nn.Parameter(prep_filt(h0o, 1), False)
         self.h1o = torch.nn.Parameter(prep_filt(h1o, 1), False)
         h0a, h0b, _, _, h1a, h1b, _, _ = _qshift(qshift)
         self.h0a = torch.nn.Parameter(prep_filt(h0a, 1), False)
         self.h0b = torch.nn.Parameter(prep_filt(h0b, 1), False)
         self.h1a = torch.nn.Parameter(prep_filt(h1a, 1), False)
         self.h1b = torch.nn.Parameter(prep_filt(h1b, 1), False)
コード例 #2
0
def test_gradients_inv(biort, qshift, size, J):
    """ Gradient of inverse function should be forward function with filters
    swapped.

    A random gradient ``g`` is backpropagated through ``DTCWTInverse``. The
    gradients that reach its lowpass and bandpass inputs must match running
    ``DTCWTForward`` on ``g`` with the time-reversed synthesis filters.
    """
    im = np.random.randn(5, 6, *size).astype('float32')
    imt = torch.tensor(im, dtype=torch.float32, device=dev)
    ifm = DTCWTInverse(biort=biort, qshift=qshift).to(dev)
    _, g0o, _, g1o = _biort(biort)
    _, _, g0a, g0b, _, _, g1a, g1b = _qshift(qshift)
    # Forward transform built from the reversed synthesis (g*) filters.
    ifm_grad = DTCWTForward(J=J,
                            biort=(g0o[::-1], g1o[::-1]),
                            qshift=(g0a[::-1], g0b[::-1], g1a[::-1],
                                    g1b[::-1])).to(dev)
    # Only used to get the coefficient shapes for the inverse's inputs.
    yl, yh = ifm_grad(imt)
    g = torch.randn(*imt.shape, device=dev)
    ylv = torch.randn(*yl.shape, requires_grad=True, device=dev)
    yhv = [torch.randn(*h.shape, requires_grad=True, device=dev) for h in yh]
    Y = ifm((ylv, yhv))
    Y.backward(g)

    # Check the lowpass gradient is the same
    ref_lp, ref_bp = ifm_grad(g)
    np.testing.assert_array_almost_equal(ylv.grad.detach().cpu(), ref_lp.cpu())
    # check the bandpasses are the same
    for y, ref in zip(yhv, ref_bp):
        np.testing.assert_array_almost_equal(y.grad.detach().cpu(), ref.cpu())
コード例 #3
0
def test_gradients_fwd(biort, qshift, size, J):
    """Backprop through DTCWTForward must equal DTCWTInverse built from the
    time-reversed analysis filters.

    The lowpass output is checked first. Then each highpass level is
    checked on its own, with the lowpass input of the reference held at
    zero.
    """
    im = np.random.randn(5, 6, *size).astype('float32')
    imt = torch.tensor(im, dtype=torch.float32, requires_grad=True, device=dev)
    xfm = DTCWTForward(biort=biort, qshift=qshift, J=J).to(dev)

    h0o, _, h1o, _ = _biort(biort)
    h0a, h0b, _, _, h1a, h1b, _, _ = _qshift(qshift)
    rev_biort = (h0o[::-1], h1o[::-1])
    rev_qshift = (h0a[::-1], h0b[::-1], h1a[::-1], h1b[::-1])
    xfm_grad = DTCWTInverse(biort=rev_biort, qshift=rev_qshift).to(dev)

    Yl, Yh = xfm(imt)

    # Lowpass output only.
    Ylg = torch.randn(*Yl.shape, device=dev)
    Yl.backward(Ylg, retain_graph=True)
    expected = xfm_grad((Ylg, [None] * J))
    np.testing.assert_array_almost_equal(imt.grad.detach().cpu(),
                                         expected.cpu())

    # One highpass level at a time.
    for level, highpass in enumerate(Yh):
        imt.grad.zero_()
        g = torch.randn(*highpass.shape, device=dev)
        highpass.backward(g, retain_graph=True)
        bands = [None] * J
        bands[level] = g
        expected = xfm_grad((torch.zeros_like(Yl), bands))
        np.testing.assert_array_almost_equal(imt.grad.detach().cpu(),
                                             expected.cpu())
コード例 #4
0
    def __init__(self,
                 biort='near_sym_a',
                 qshift='qshift_a',
                 J=3,
                 o_dim=2,
                 ri_dim=-1):
        """Build the frozen synthesis filters for a J-level inverse DTCWT.

        Args:
            biort: name of a biorthogonal filter set (looked up with
                ``_biort``), or a pair ``(g0o, g1o)`` of level-1 filters.
            qshift: name of a q-shift filter set (looked up with ``_qshift``),
                or a 4-tuple ``(g0a, g0b, g1a, g1b)``.
            J: number of levels; selects ``tf.ifm<J>`` as ``self.dtcwt_func``.
            o_dim: stored on ``self.o_dim``.
            ri_dim: stored on ``self.ri_dim``.
        """
        super().__init__()
        self.biort = biort
        self.qshift = qshift
        self.o_dim = o_dim
        self.ri_dim = ri_dim
        self.J = J

        if isinstance(biort, str):
            _, g0o, _, g1o = _biort(biort)
        else:
            g0o, g1o = biort[0], biort[1]
        if isinstance(qshift, str):
            _, _, g0a, g0b, _, _, g1a, g1b = _qshift(qshift)
        else:
            g0a, g0b, g1a, g1b = qshift[0], qshift[1], qshift[2], qshift[3]

        # Each filter is registered as a Parameter with requires_grad=False.
        frozen = (('g0o', g0o), ('g1o', g1o), ('g0a', g0a),
                  ('g0b', g0b), ('g1a', g1a), ('g1b', g1b))
        for attr, taps in frozen:
            setattr(self, attr, torch.nn.Parameter(prep_filt(taps, 1), False))

        # Pick the J-level inverse routine.
        self.dtcwt_func = getattr(tf, 'ifm{}'.format(J))
コード例 #5
0
    def __init__(self, C=None, biort='near_sym_a', qshift='qshift_a', J=3,
                 o_before_c=False):
        """Build the frozen synthesis filters for a J-level inverse DTCWT.

        Args:
            C: deprecated and ignored; passing anything other than None
                only triggers a warning.
            biort: name of a biorthogonal filter set (looked up with
                ``_biort``), or a pair ``(g0o, g1o)`` of level-1 filters.
            qshift: name of a q-shift filter set (looked up with ``_qshift``),
                or a 4-tuple ``(g0a, g0b, g1a, g1b)``.
            J: number of levels; selects ``tf.ifm<J>`` as ``self.dtcwt_func``.
            o_before_c: stored on ``self.o_before_c``.
        """
        super().__init__()
        if C is not None:
            warnings.warn('C parameter is deprecated. do not need to pass it')
        self.biort = biort
        self.qshift = qshift
        self.o_before_c = o_before_c
        self.J = J

        if isinstance(biort, str):
            _, g0o, _, g1o = _biort(biort)
        else:
            g0o, g1o = biort[0], biort[1]
        if isinstance(qshift, str):
            _, _, g0a, g0b, _, _, g1a, g1b = _qshift(qshift)
        else:
            g0a, g0b, g1a, g1b = qshift[0], qshift[1], qshift[2], qshift[3]

        # Each filter is registered as a Parameter with requires_grad=False.
        frozen = (('g0o', g0o), ('g1o', g1o), ('g0a', g0a),
                  ('g0b', g0b), ('g1a', g1a), ('g1b', g1b))
        for attr, taps in frozen:
            setattr(self, attr, torch.nn.Parameter(prep_filt(taps, 1), False))

        # Pick the J-level inverse routine.
        self.dtcwt_func = getattr(tf, 'ifm{}'.format(J))
コード例 #6
0
 def __init__(self,
              biort='near_sym_a',
              qshift='qshift_a',
              o_dim=2,
              ri_dim=-1,
              mode='symmetric'):
     """Register the synthesis filters of the inverse DTCWT as buffers.

     Args:
         biort: name of a biorthogonal filter set (looked up with
             ``_biort``), or a pair ``(g0o, g1o)`` of level-1 filters.
         qshift: name of a q-shift filter set (looked up with ``_qshift``),
             or a 4-tuple ``(g0a, g0b, g1a, g1b)``.
         o_dim: stored on ``self.o_dim``.
         ri_dim: stored on ``self.ri_dim``.
         mode: stored on ``self.mode``.
     """
     super().__init__()
     self.biort = biort
     self.qshift = qshift
     self.o_dim = o_dim
     self.ri_dim = ri_dim
     self.mode = mode

     if isinstance(biort, str):
         _, g0o, _, g1o = _biort(biort)
     else:
         g0o, g1o = biort[0], biort[1]
     if isinstance(qshift, str):
         _, _, g0a, g0b, _, _, g1a, g1b = _qshift(qshift)
     else:
         g0a, g0b, g1a, g1b = qshift[0], qshift[1], qshift[2], qshift[3]

     # Buffers, not Parameters: the filters belong to the module but are
     # never returned by .parameters().
     named_filters = (('g0o', g0o), ('g1o', g1o), ('g0a', g0a),
                      ('g0b', g0b), ('g1a', g1a), ('g1b', g1b))
     for buf_name, taps in named_filters:
         self.register_buffer(buf_name, prep_filt(taps, 1))
コード例 #7
0
ファイル: transform2d.py プロジェクト: kevinbro96/SID
 def __init__(self, biort='near_sym_a', qshift='qshift_a', o_dim=2,
              ri_dim=-1, mode='symmetric'):
     """Build the frozen synthesis filters of the inverse DTCWT.

     Args:
         biort: name of a biorthogonal filter set (looked up with
             ``_biort``), or a pair ``(g0o, g1o)`` of level-1 filters.
         qshift: name of a q-shift filter set (looked up with ``_qshift``),
             or a 4-tuple ``(g0a, g0b, g1a, g1b)``.
         o_dim: stored on ``self.o_dim``.
         ri_dim: stored on ``self.ri_dim``.
         mode: stored on ``self.mode``.
     """
     super().__init__()
     self.biort = biort
     self.qshift = qshift
     self.o_dim = o_dim
     self.ri_dim = ri_dim
     self.mode = mode

     if isinstance(biort, str):
         _, g0o, _, g1o = _biort(biort)
     else:
         g0o, g1o = biort[0], biort[1]
     if isinstance(qshift, str):
         _, _, g0a, g0b, _, _, g1a, g1b = _qshift(qshift)
     else:
         g0a, g0b, g1a, g1b = qshift[0], qshift[1], qshift[2], qshift[3]

     # Each filter is registered as a Parameter with requires_grad=False.
     frozen = (('g0o', g0o), ('g1o', g1o), ('g0a', g0a),
               ('g0b', g0b), ('g1a', g1a), ('g1b', g1b))
     for attr, taps in frozen:
         setattr(self, attr, torch.nn.Parameter(prep_filt(taps, 1), False))
コード例 #8
0
ファイル: transform2d.py プロジェクト: csinva/local-vae
    def __init__(self,
                 biort='near_sym_a',
                 qshift='qshift_a',
                 J=3,
                 skip_hps=False,
                 include_scale=False,
                 o_dim=2,
                 ri_dim=-1,
                 mode='symmetric'):
        """Build a J-level forward DTCWT whose analysis filters are trainable.

        Args:
            biort: name of a biorthogonal filter set (looked up with
                ``_biort``), or a pair ``(h0o, h1o)`` of level-1 filters.
            qshift: name of a q-shift filter set (looked up with ``_qshift``),
                or a 4-tuple ``(h0a, h0b, h1a, h1b)``.
            J: number of decomposition levels.
            skip_hps: one flag used for every level, or a per-level
                list/tuple/ndarray of flags.
            include_scale: same convention as ``skip_hps``.
            o_dim: dimension holding the orientations.
            ri_dim: dimension holding the real/imaginary parts.
            mode: stored on ``self.mode``.

        Raises:
            ValueError: if ``o_dim == ri_dim``.
        """
        super().__init__()
        if o_dim == ri_dim:
            raise ValueError("Orientations and real/imaginary parts must be "
                             "in different dimensions.")

        self.biort = biort
        self.qshift = qshift
        self.J = J
        self.o_dim = o_dim
        self.ri_dim = ri_dim
        self.mode = mode

        if isinstance(biort, str):
            h0o, _, h1o, _ = _biort(biort)
        else:
            h0o, h1o = biort[0], biort[1]
        if isinstance(qshift, str):
            h0a, h0b, _, _, h1a, h1b, _, _ = _qshift(qshift)
        else:
            h0a, h0b, h1a, h1b = qshift[0], qshift[1], qshift[2], qshift[3]

        # These filters are learnable: each Parameter has requires_grad=True.
        trainable = (('h0o', h0o), ('h1o', h1o), ('h0a', h0a),
                     ('h0b', h0b), ('h1a', h1a), ('h1b', h1b))
        for attr, taps in trainable:
            setattr(self, attr, nn.Parameter(prep_filt(taps, 1),
                                             requires_grad=True))

        def _per_level(flag):
            # A list/tuple/ndarray is taken as already per-level; anything
            # else is repeated once for each of the J levels.
            if isinstance(flag, (list, tuple, ndarray)):
                return flag
            return [flag] * self.J

        self.skip_hps = _per_level(skip_hps)
        self.include_scale = _per_level(include_scale)
コード例 #9
0
def test_equal_numpy_biort2():
    """rowfilter on a 355x371 crop of barbara (both sizes odd) must match
    the NumPy reference ref_rowfilter to 4 decimal places."""
    filt = _biort('near_sym_b')[0]
    crop = barbara[:, 52:407, 30:401]
    crop_t = torch.tensor(crop, dtype=torch.float32).unsqueeze(0).to(dev)
    expected = ref_rowfilter(crop, filt)
    got = rowfilter(crop_t, prep_filt(filt, 1).to(dev))
    np.testing.assert_array_almost_equal(got[0].cpu(), expected, decimal=4)
コード例 #10
0
def test_gradients():
    """Backpropagation through rowfilter must produce an input gradient.

    A random upstream gradient ``dy`` is pushed through ``rowfilter`` with
    ``torch.autograd.grad``. The gradient returned for the input must have
    the same shape as the input.
    """
    h = _biort('near_sym_b')[0]
    im_t = torch.unsqueeze(torch.tensor(barbara,
                                        dtype=torch.float32,
                                        requires_grad=True),
                           dim=0).to(dev)
    y_t = rowfilter(im_t, prep_filt(h, 1).to(dev))
    dy = np.random.randn(*y_t.shape).astype('float32')
    # grad_outputs must be on the same device as y_t. A bare torch.tensor(dy)
    # always lives on the CPU, and autograd rejects it when dev is a GPU.
    (grad_in,) = torch.autograd.grad(
        y_t, im_t, grad_outputs=torch.tensor(dy, device=y_t.device))
    assert grad_in.shape == im_t.shape
コード例 #11
0
    def __init__(self,
                 biort='near_sym_a',
                 qshift='qshift_a',
                 J=3,
                 skip_hps=False,
                 include_scale=False,
                 downsample=False,
                 o_dim=2,
                 ri_dim=-1):
        """Build the frozen analysis filters for a J-level forward DTCWT.

        Args:
            biort: name of a biorthogonal filter set (looked up with
                ``_biort``), or a pair ``(h0o, h1o)`` of level-1 filters.
            qshift: name of a q-shift filter set (looked up with ``_qshift``),
                or a 4-tuple ``(h0a, h0b, h1a, h1b)``.
            J: number of decomposition levels.
            skip_hps: one flag used for every level, or a per-level
                list/tuple/ndarray of flags.
            include_scale: same convention as ``skip_hps``. If any entry is
                True, ``tf.xfm<J>scale`` is used instead of ``tf.xfm<J>``.
            downsample: stored on ``self.downsample``.
            o_dim: dimension holding the orientations.
            ri_dim: dimension holding the real/imaginary parts.

        Raises:
            ValueError: if ``o_dim == ri_dim``.
        """
        super().__init__()
        if o_dim == ri_dim:
            raise ValueError("Orientations and real/imaginary parts must be "
                             "in different dimensions.")

        self.biort = biort
        self.qshift = qshift
        self.J = J
        self.downsample = downsample
        self.o_dim = o_dim
        self.ri_dim = ri_dim

        if isinstance(biort, str):
            h0o, _, h1o, _ = _biort(biort)
        else:
            h0o, h1o = biort[0], biort[1]
        if isinstance(qshift, str):
            h0a, h0b, _, _, h1a, h1b, _, _ = _qshift(qshift)
        else:
            h0a, h0b, h1a, h1b = qshift[0], qshift[1], qshift[2], qshift[3]

        # Each filter is registered as a Parameter with requires_grad=False.
        frozen = (('h0o', h0o), ('h1o', h1o), ('h0a', h0a),
                  ('h0b', h0b), ('h1a', h1a), ('h1b', h1b))
        for attr, taps in frozen:
            setattr(self, attr, torch.nn.Parameter(prep_filt(taps, 1), False))

        def _per_level(flag):
            # A list/tuple/ndarray is taken as already per-level; anything
            # else is repeated once for each of the J levels.
            if isinstance(flag, (list, tuple, ndarray)):
                return flag
            return [flag] * self.J

        self.skip_hps = _per_level(skip_hps)
        self.include_scale = _per_level(include_scale)

        # Pick the J-level forward routine, using the scale-returning variant
        # when any level asks for it.
        suffix = 'scale' if True in self.include_scale else ''
        self.dtcwt_func = getattr(tf, 'xfm{}{}'.format(J, suffix))
コード例 #12
0
 def __init__(self, biort='near_sym_a', mode='symmetric', magbias=1e-2):
     """Store settings and build the frozen level-1 filters and lowpass pool.

     Args:
         biort: name of the biorthogonal filter set, looked up with ``_biort``.
         mode: padding mode name. It is kept verbatim in ``self.mode_str``
             and as an integer code (via ``mode_to_int``) in ``self.mode``.
         magbias: stored on ``self.magbias``.
     """
     super().__init__()
     self.biort = biort
     # Grad checks can't handle string arguments, so an integer code for
     # the mode is kept next to the original string.
     self.mode_str = mode
     self.mode = mode_to_int(mode)
     self.magbias = magbias

     lowpass, _, highpass, _ = _biort(biort)
     for attr, taps in (('h0o', lowpass), ('h1o', highpass)):
         setattr(self, attr, torch.nn.Parameter(prep_filt(taps, 1), False))

     # 2x2 average pooling, stored as self.lp_pool.
     self.lp_pool = nn.AvgPool2d(2)
コード例 #13
0
    def __init__(self, C=None, biort='near_sym_a', qshift='qshift_a',
                 J=3, skip_hps=False, o_before_c=False, include_scale=False,
                 downsample=False):
        """Build the frozen analysis filters for a J-level forward DTCWT.

        Args:
            C: deprecated and ignored; passing anything other than None
                only triggers a warning.
            biort: name of a biorthogonal filter set (looked up with
                ``_biort``), or a pair ``(h0o, h1o)`` of level-1 filters.
            qshift: name of a q-shift filter set (looked up with ``_qshift``),
                or a 4-tuple ``(h0a, h0b, h1a, h1b)``.
            J: number of decomposition levels.
            skip_hps: one flag used for every level, or a per-level
                list/tuple/ndarray of flags.
            o_before_c: stored on ``self.o_before_c``.
            include_scale: same convention as ``skip_hps``. If any entry is
                True, ``tf.xfm<J>scale`` is used instead of ``tf.xfm<J>``.
            downsample: stored on ``self.downsample``.
        """
        super().__init__()
        if C is not None:
            warnings.warn('C parameter is deprecated. do not need to pass it '
                          'anymore.')

        self.biort = biort
        self.qshift = qshift
        self.o_before_c = o_before_c
        self.J = J
        self.downsample = downsample

        if isinstance(biort, str):
            h0o, _, h1o, _ = _biort(biort)
        else:
            h0o, h1o = biort[0], biort[1]
        if isinstance(qshift, str):
            h0a, h0b, _, _, h1a, h1b, _, _ = _qshift(qshift)
        else:
            h0a, h0b, h1a, h1b = qshift[0], qshift[1], qshift[2], qshift[3]

        # Each filter is registered as a Parameter with requires_grad=False.
        frozen = (('h0o', h0o), ('h1o', h1o), ('h0a', h0a),
                  ('h0b', h0b), ('h1a', h1a), ('h1b', h1b))
        for attr, taps in frozen:
            setattr(self, attr, torch.nn.Parameter(prep_filt(taps, 1), False))

        def _per_level(flag):
            # A list/tuple/ndarray is taken as already per-level; anything
            # else is repeated once for each of the J levels.
            if isinstance(flag, (list, tuple, ndarray)):
                return flag
            return [flag] * self.J

        self.skip_hps = _per_level(skip_hps)
        self.include_scale = _per_level(include_scale)

        # Pick the J-level forward routine, using the scale-returning variant
        # when any level asks for it.
        suffix = 'scale' if True in self.include_scale else ''
        self.dtcwt_func = getattr(tf, 'xfm{}{}'.format(J, suffix))
コード例 #14
0
def test_equal_numpy_biort1():
    """rowfilter applied to barbara_t with near_sym_b's first filter must
    match ref_rowfilter on barbara to 4 decimal places."""
    filt = _biort('near_sym_b')[0]
    expected = ref_rowfilter(barbara, filt)
    got = rowfilter(barbara_t, prep_filt(filt, 1).to(dev))
    np.testing.assert_array_almost_equal(got[0].cpu(), expected, decimal=4)
コード例 #15
0
def test_biort():
    """Row-filtering barbara_t with the 'antonini' h0o filter must leave
    every dimension after the first equal to ``bshape``."""
    antonini_h0 = _biort('antonini')[0]
    filt_t = prep_filt(antonini_h0, 1).to(dev)
    out = rowfilter(barbara_t, filt_t)
    assert list(out.shape[1:]) == bshape