Beispiel #1
0
def inv_j1(ll, highr, highi, g0, g1, o_dim, h_dim, w_dim, mode):
    """ Level1 inverse dtcwt.

    Have it as a separate function as can be used by the forward pass of the
    inverse transform and the backward pass of the forward transform.

    Args:
        ll: lowpass coefficients. May be None or a 0-dim tensor, in which
            case only the bandpass coefficients contribute. Rows/cols are
            read from ``ll.shape[2:]``, so presumably (N, C, H, W) —
            TODO confirm.
        highr, highi: real / imaginary bandpass coefficients. If ``highr``
            is None or a 0-dim tensor, the bandpass part is skipped.
        g0, g1: synthesis filters. g0 is applied to ll and hl, g1 to lh and
            hh, and both are used in the final row filtering.
        o_dim: orientation dimension, passed to ``orientations_to_highs``.
        h_dim, w_dim: dimensions of ``highr`` holding rows / cols. When ll
            is not exactly twice that size, one row/col is cropped from each
            side of ll so it lines up with the bandpasses.
        mode: padding mode forwarded to every rowfilter/colfilter call.

    Returns:
        y: the reconstructed signal.
    """
    if highr is None or highr.shape == torch.Size([]):
        # Lowpass only. Forward ``mode`` here as well; otherwise this branch
        # would use whatever default colfilter/rowfilter define instead of
        # the padding mode the caller asked for.
        y = rowfilter(colfilter(ll, g0, mode), g0, mode)
    else:
        # Get the double sampled bandpass coefficients
        lh, hl, hh = orientations_to_highs(highr, highi, o_dim)

        # Interpolate the bandpasses along the columns
        hi = colfilter(hh, g1, mode) + colfilter(hl, g0, mode)
        lo = colfilter(lh, g1, mode)
        del lh, hl, hh

        if ll is not None and ll.shape != torch.Size([]):
            # Possibly cut back some rows/cols to make the ll match the highs
            r, c = ll.shape[2:]
            r1, c1 = highr.shape[h_dim], highr.shape[w_dim]
            if r != r1 * 2:
                ll = ll[:, :, 1:-1]
            if c != c1 * 2:
                ll = ll[:, :, :, 1:-1]
            # Out-of-place add keeps the autograd graph free of in-place ops
            lo = lo + colfilter(ll, g0, mode)

        y = rowfilter(hi, g1, mode) + rowfilter(lo, g0, mode)

    return y
Beispiel #2
0
def fwd_j1_rot(x, h0, h1, h2, skip_hps, o_dim, mode):
    """Compute the first level of the forward DTCWT using three filters.

    Kept as a standalone function because it serves both the forward pass
    of the forward transform and the backward pass of the inverse transform.

    Returns ``(ll, highr, highi)``. When ``skip_hps`` is truthy, only the
    lowpass ``ll`` is computed, and ``highr``/``highi`` are 0-dim zero
    tensors made with ``x.new_zeros``.
    """
    if skip_hps:
        # Lowpass only: h0 along the rows, then h0 along the columns.
        lowpass = colfilter(rowfilter(x, h0, mode), h0, mode)
        return lowpass, x.new_zeros([]), x.new_zeros([])

    # Level 1 forward (biorthogonal analysis filters): rows first ...
    row_lo = rowfilter(x, h0, mode)
    row_hi = rowfilter(x, h1, mode)
    row_ba = rowfilter(x, h2, mode)

    # ... then columns.
    band_lh = colfilter(row_lo, h1, mode)
    band_hl = colfilter(row_hi, h0, mode)
    band_hh = colfilter(row_ba, h2, mode)
    lowpass = colfilter(row_lo, h0, mode)
    del row_lo, row_hi, row_ba

    real, imag = highs_to_orientations(band_lh, band_hl, band_hh, o_dim)
    return lowpass, real, imag
Beispiel #3
0
def test_equal_numpy_biort2():
    """rowfilter on an odd-sized crop of barbara matches ref_rowfilter."""
    taps = _biort('near_sym_b')[0]
    patch = barbara[:, 52:407, 30:401]
    patch_t = torch.tensor(patch, dtype=torch.float32).unsqueeze(0).to(dev)
    expected = ref_rowfilter(patch, taps)
    kernel = prep_filt(taps, 1).to(dev)
    result = rowfilter(patch_t, kernel)
    np.testing.assert_array_almost_equal(result.cpu()[0], expected,
                                         decimal=4)
Beispiel #4
0
def test_equal_small_in():
    """rowfilter on a tiny 4x4 crop matches ref_rowfilter (qshift_b)."""
    taps = _qshift('qshift_b')[0]
    patch = barbara[:, 0:4, 0:4]
    patch_t = torch.tensor(patch, dtype=torch.float32).unsqueeze(0).to(dev)
    expected = ref_rowfilter(patch, taps)
    kernel = prep_filt(taps, 1).to(dev)
    result = rowfilter(patch_t, kernel)
    np.testing.assert_array_almost_equal(result.cpu()[0], expected,
                                         decimal=4)
Beispiel #5
0
def test_gradients():
    """Backpropagate a random upstream gradient through rowfilter.

    The upstream gradient is created with ``torch.randn_like`` so it shares
    the output's device and dtype. A CPU tensor would not match ``y_t`` when
    ``dev`` is a GPU, and autograd rejects such a gradient.
    """
    h = _biort('near_sym_b')[0]
    im_t = torch.unsqueeze(torch.tensor(barbara,
                                        dtype=torch.float32,
                                        requires_grad=True),
                           dim=0).to(dev)
    y_t = rowfilter(im_t, prep_filt(h, 1).to(dev))
    dy = torch.randn_like(y_t)
    grad_im, = torch.autograd.grad(y_t, im_t, grad_outputs=dy)
    # The gradient w.r.t. the input must have the input's shape.
    assert grad_im.shape == im_t.shape
Beispiel #6
0
def test_equal_numpy_biort1():
    """rowfilter on the full image matches ref_rowfilter (near_sym_b)."""
    taps = _biort('near_sym_b')[0]
    expected = ref_rowfilter(barbara, taps)
    kernel = prep_filt(taps, 1).to(dev)
    result = rowfilter(barbara_t, kernel)
    np.testing.assert_array_almost_equal(result.cpu()[0], expected,
                                         decimal=4)
Beispiel #7
0
def test_even_size_batch():
    """An even-length filter on an all-zero input yields all zeros.

    The output's non-batch shape must equal ``bshape_extracol``.
    """
    zeros = torch.zeros((1,) + tuple(barbara.shape),
                        dtype=torch.float32).to(dev)
    out = rowfilter(zeros, prep_filt([-1, 1], 1).to(dev)).cpu().numpy()
    assert list(out.shape[1:]) == bshape_extracol
    assert np.all(out == 0.0)
Beispiel #8
0
def test_biort():
    """The antonini biort filter gives an output of shape ``bshape``."""
    kernel = prep_filt(_biort('antonini')[0], 1).to(dev)
    out_shape = rowfilter(barbara_t, kernel).shape
    assert list(out_shape[1:]) == bshape
Beispiel #9
0
def test_qshift():
    """The qshift_a filter gives an output of shape ``bshape_extracol``."""
    kernel = prep_filt(_qshift('qshift_a')[0], 1).to(dev)
    out_shape = rowfilter(barbara_t, kernel).shape
    assert list(out_shape[1:]) == bshape_extracol
Beispiel #10
0
def test_even_size():
    """A two-tap filter gives an output of shape ``bshape_extracol``."""
    kernel = prep_filt([-1, -1], 1).to(dev)
    out_shape = rowfilter(barbara_t, kernel).shape
    assert list(out_shape[1:]) == bshape_extracol
Beispiel #11
0
def test_odd_size():
    """A three-tap filter gives an output of shape ``bshape``."""
    kernel = prep_filt([-1, 2, -1], 1).to(dev)
    out_shape = rowfilter(barbara_t, kernel).shape
    assert list(out_shape[1:]) == bshape
Beispiel #12
0
def test_equal_numpy_qshift1():
    """rowfilter on the full image matches ref_rowfilter (qshift_c)."""
    taps = _qshift('qshift_c')[0]
    expected = ref_rowfilter(barbara, taps)
    kernel = prep_filt(taps, 1).to(dev)
    result = rowfilter(barbara_t, kernel)
    np.testing.assert_array_almost_equal(result.cpu()[0], expected,
                                         decimal=4)