def fwd_j2plus_rot(x, h0a, h1a, h0b, h1b, h2a, h2b, skip_hps, o_dim, mode):
    """Compute one level-2-and-above step of the forward dtcwt.

    Kept as a standalone function because it is shared by the forward pass of
    the forward transform and the backward pass of the inverse transform.

    Returns:
        A tuple ``(ll, highr, highi)``. ``ll`` is the lowpass output.
        ``highr``/``highi`` are the two outputs of ``highs_to_orientations``,
        or both ``None`` when ``skip_hps`` is truthy.
    """
    if skip_hps:
        # Highpasses are not wanted: only run the h0 row filter, then the h0
        # column filter.
        row_lo = rowdfilt(x, h0b, h0a, False, mode)
        ll = coldfilt(row_lo, h0b, h0a, False, mode)
        return ll, None, None

    # Row filtering with the three filter pairs (h0, h1, h2).
    row_lo = rowdfilt(x, h0b, h0a, False, mode)
    row_hi = rowdfilt(x, h1b, h1a, True, mode)
    row_h2 = rowdfilt(x, h2b, h2a, True, mode)

    # Column filtering of the row outputs.
    lh = coldfilt(row_lo, h1b, h1a, True, mode)
    hl = coldfilt(row_hi, h0b, h0a, False, mode)
    hh = coldfilt(row_h2, h2b, h2a, True, mode)
    ll = coldfilt(row_lo, h0b, h0a, False, mode)
    # Drop the intermediate row outputs before the orientation step.
    del row_lo, row_hi, row_h2

    highr, highi = highs_to_orientations(lh, hl, hh, o_dim)
    return ll, highr, highi
def test_equal_numpy_qshift1(hp):
    """rowdfilt on the full image agrees with ref_rowdfilt (4 decimals).

    Filters come from the 'qshift_a' bank: entries 4/5 when ``hp`` is truthy,
    entries 0/1 otherwise. ``hp`` is also passed through as ``highpass``.
    """
    ha = qshift('qshift_a')[4 if hp else 0]
    hb = qshift('qshift_a')[5 if hp else 1]
    expected = ref_rowdfilt(barbara, ha, hb)
    filt_a = prep_filt(ha, 1).to(dev)
    filt_b = prep_filt(hb, 1).to(dev)
    result = rowdfilt(barbara_t, filt_a, filt_b, highpass=hp)
    np.testing.assert_array_almost_equal(result[0].cpu(), expected, decimal=4)
def test_equal_small_in(hp):
    """rowdfilt agrees with ref_rowdfilt on a small [:, 0:4, 0:4] crop.

    Filters come from the 'qshift_b' bank: entries 4/5 when ``hp`` is truthy,
    entries 0/1 otherwise.
    """
    ha = qshift('qshift_b')[4 if hp else 0]
    hb = qshift('qshift_b')[5 if hp else 1]
    crop = barbara[:, 0:4, 0:4]
    # Add a leading batch dimension for the torch implementation.
    crop_t = torch.tensor(crop, dtype=torch.float32).unsqueeze(0).to(dev)
    expected = ref_rowdfilt(crop, ha, hb)
    filt_a = prep_filt(ha, 1).to(dev)
    filt_b = prep_filt(hb, 1).to(dev)
    result = rowdfilt(crop_t, filt_a, filt_b, highpass=hp)
    np.testing.assert_array_almost_equal(result[0].cpu(), expected, decimal=4)
def test_gradients(hp):
    """Check that gradients propagate back through rowdfilt to its input.

    A random upstream gradient is back-propagated through the filter output.
    The resulting input gradient must have the input's shape and contain only
    finite values.
    """
    if hp:
        ha = qshift('qshift_b')[4]
        hb = qshift('qshift_b')[5]
    else:
        ha = qshift('qshift_b')[0]
        hb = qshift('qshift_b')[1]
    im_t = torch.unsqueeze(
        torch.tensor(barbara, dtype=torch.float32, requires_grad=True),
        dim=0).to(dev)
    y_t = rowdfilt(im_t, prep_filt(ha, 1).to(dev), prep_filt(hb, 1).to(dev),
                   highpass=hp)
    # Use the random tensor as the upstream gradient rather than discarding
    # it, so the test actually exercises the backward pass.
    dy = torch.tensor(np.random.randn(*tuple(y_t.shape)).astype('float32'),
                      device=y_t.device)
    (dx,) = torch.autograd.grad(y_t, im_t, grad_outputs=dy)
    assert dx.shape == im_t.shape
    assert torch.isfinite(dx).all()
def test_output_size():
    """The rowdfilt output shape, ignoring the batch dim, equals bshape_half."""
    filt_a = prep_filt((-1, 1), 1).to(dev)
    filt_b = prep_filt((1, -1), 1).to(dev)
    out_shape = rowdfilt(barbara_t, filt_a, filt_b).shape
    assert list(out_shape[1:]) == bshape_half
def test_good_input_size_non_orthogonal():
    """rowdfilt must accept an odd size (511) along dim 2.

    Contrast with test_bad_input_size, which truncates dim 3 instead.
    """
    filt_a = prep_filt((1, 1), 1).to(dev)
    filt_b = prep_filt((1, -1), 1).to(dev)
    truncated = barbara_t[:, :, :511, :]
    rowdfilt(truncated, filt_a, filt_b)
def test_bad_input_size():
    """rowdfilt must raise ValueError for an odd size (511) along dim 3."""
    ha = prep_filt((-1, 1), 1).to(dev)
    hb = prep_filt((1, -1), 1).to(dev)
    # Only the call under test sits inside ``raises``, so a ValueError from
    # filter preparation cannot make this test pass vacuously.
    with raises(ValueError):
        rowdfilt(barbara_t[:, :, :, :511], ha, hb)
def test_different_size():
    """rowdfilt must raise ValueError when the two filters differ in length.

    Here ha has 4 taps and hb has 3.
    """
    ha = prep_filt((-0.5, -1, 2, 0.5), 1).to(dev)
    hb = prep_filt((-1, 2, 1), 1).to(dev)
    # Keep setup outside ``raises`` so only rowdfilt can satisfy it.
    with raises(ValueError):
        rowdfilt(barbara_t, ha, hb)
def test_odd_filter():
    """rowdfilt must raise ValueError for odd-length filters (3 taps each)."""
    ha = prep_filt((-1, 2, -1), 1).to(dev)
    hb = prep_filt((-1, 2, 1), 1).to(dev)
    # Keep setup outside ``raises`` so only rowdfilt can satisfy it.
    with raises(ValueError):
        rowdfilt(barbara_t, ha, hb)