def fwd_j2plus_rot(x, h0a, h1a, h0b, h1b, h2a, h2b, skip_hps, o_dim, mode):
    """ Level 2 plus forward dtcwt.

    Kept as a separate function because it is used both by the forward pass
    of the forward transform and by the backward pass of the inverse
    transform.

    Args:
        x: input to be filtered along rows (rowdfilt) and then columns
            (coldfilt).
        h0a, h0b, h1a, h1b, h2a, h2b: the three filter pairs handed to
            rowdfilt/coldfilt.
        skip_hps: when truthy, only the ``ll`` output is computed.
        o_dim: forwarded to ``highs_to_orientations``.
        mode: forwarded unchanged to every rowdfilt/coldfilt call.

    Returns:
        ``(ll, highr, highi)``. ``highr`` and ``highi`` are ``None`` when
        ``skip_hps`` is truthy.
    """
    if skip_hps:
        # Only the h0 pair is needed: filter rows, then columns.
        lowpass = rowdfilt(x, h0b, h0a, False, mode)
        lowpass = coldfilt(lowpass, h0b, h0a, False, mode)
        return lowpass, None, None

    # Row filtering with each of the three filter pairs.
    row_h0 = rowdfilt(x, h0b, h0a, False, mode)
    row_h1 = rowdfilt(x, h1b, h1a, True, mode)
    row_h2 = rowdfilt(x, h2b, h2a, True, mode)

    # Column filtering of the row outputs to form the subbands.
    lh = coldfilt(row_h0, h1b, h1a, True, mode)
    hl = coldfilt(row_h1, h0b, h0a, False, mode)
    hh = coldfilt(row_h2, h2b, h2a, True, mode)
    lowpass = coldfilt(row_h0, h0b, h0a, False, mode)

    # Drop the row-filtered intermediates before the orientation step so
    # they are no longer referenced here.
    del row_h0, row_h1, row_h2

    highr, highi = highs_to_orientations(lh, hl, hh, o_dim)
    return lowpass, highr, highi
def test_gradients():
    """Check that gradients propagate back through ``coldfilt``.

    A random upstream gradient is backpropagated from the ``coldfilt``
    output to the input image. The resulting gradient must have the input's
    shape and contain only finite values. The original version computed the
    gradient but never inspected it.
    """
    ha = qshift('qshift_c')[0]
    hb = qshift('qshift_c')[1]
    im_t = torch.unsqueeze(
        torch.tensor(barbara, dtype=torch.float32, requires_grad=True),
        dim=0)
    # The fourth positional argument is the ``highpass`` flag.
    y_t = coldfilt(im_t, prep_filt(ha, 1), prep_filt(hb, 1),
                   np.sum(ha*hb) > 0)
    dy = np.random.randn(*tuple(y_t.shape)).astype('float32')
    (grad,) = torch.autograd.grad(y_t, im_t, grad_outputs=torch.tensor(dy))
    assert grad.shape == im_t.shape
    assert torch.isfinite(grad).all()
def test_equal_numpy_qshift1(hp):
    """Compare coldfilt on barbara against the NumPy reference ref_coldfilt.

    Uses the qshift_a filters at indices 4/5 when ``hp`` is truthy and
    indices 0/1 otherwise. ``hp`` is also passed through as ``highpass``.
    """
    idx_a, idx_b = (4, 5) if hp else (0, 1)
    ha = qshift('qshift_a')[idx_a]
    hb = qshift('qshift_a')[idx_b]
    expected = ref_coldfilt(barbara, ha, hb)
    filt_a = prep_filt(ha, 1).to(dev)
    filt_b = prep_filt(hb, 1).to(dev)
    out = coldfilt(barbara_t, filt_a, filt_b, highpass=hp)
    np.testing.assert_array_almost_equal(out[0].cpu(), expected, decimal=4)
def test_equal_small_in(hp):
    """Compare coldfilt against ref_coldfilt on a small 4x4 crop of barbara.

    The crop keeps the first axis whole and takes indices 0..3 on axes 1
    and 2. Filter selection follows ``hp`` exactly as in the full-size
    comparison: qshift_a indices 4/5 when truthy, 0/1 otherwise.
    """
    idx_a, idx_b = (4, 5) if hp else (0, 1)
    ha = qshift('qshift_a')[idx_a]
    hb = qshift('qshift_a')[idx_b]
    crop = barbara[:, :4, :4]
    crop_t = torch.tensor(crop, dtype=torch.float32).unsqueeze(0).to(dev)
    expected = ref_coldfilt(crop, ha, hb)
    filt_a = prep_filt(ha, 1).to(dev)
    filt_b = prep_filt(hb, 1).to(dev)
    out = coldfilt(crop_t, filt_a, filt_b, highpass=hp)
    np.testing.assert_array_almost_equal(out[0].cpu(), expected, decimal=4)
def test_output_size():
    """Check that coldfilt's output shape, excluding the first dim, equals bshape_half."""
    filt_a = prep_filt((-1, 1), 1).to(dev)
    filt_b = prep_filt((1, -1), 1).to(dev)
    out_shape = list(coldfilt(barbara_t, filt_a, filt_b).shape)
    assert out_shape[1:] == bshape_half
def test_good_input_size_non_orthogonal():
    """coldfilt must not raise when the last axis has 511 entries.

    Uses the (1, 1) / (1, -1) filter pair.
    """
    filt_a = prep_filt((1, 1), 1).to(dev)
    filt_b = prep_filt((1, -1), 1).to(dev)
    cropped = barbara_t[:, :, :, :511]
    coldfilt(cropped, filt_a, filt_b)
def test_good_input_size():
    """coldfilt must not raise when the last axis has 511 entries.

    Uses the (-1, 1) / (1, -1) filter pair.
    """
    filt_a = prep_filt((-1, 1), 1).to(dev)
    filt_b = prep_filt((1, -1), 1).to(dev)
    cropped = barbara_t[:, :, :, :511]
    coldfilt(cropped, filt_a, filt_b)
def test_bad_input_size():
    """coldfilt must raise ValueError when the third axis has 511 entries.

    Setup (filter construction and slicing) happens outside the ``raises``
    block. That way only the coldfilt call can satisfy the expected
    exception, and a ValueError from setup cannot make the test pass.
    """
    ha = prep_filt((-1, 1), 1).to(dev)
    hb = prep_filt((1, -1), 1).to(dev)
    short_rows = barbara_t[:, :, :511, :]
    with raises(ValueError):
        coldfilt(short_rows, ha, hb)
def test_different_size():
    """coldfilt must raise ValueError when ha and hb have different lengths.

    Here ha has 4 taps and hb has 3. The filters are built outside the
    ``raises`` block so only the coldfilt call can satisfy the expected
    exception.
    """
    ha = prep_filt((-0.5, -1, 2, 0.5), 1).to(dev)
    hb = prep_filt((-1, 2, 1), 1).to(dev)
    with raises(ValueError):
        coldfilt(barbara_t, ha, hb)
def test_odd_filter():
    """coldfilt must raise ValueError for odd-length (3-tap) filters.

    The filters are built outside the ``raises`` block so only the coldfilt
    call can satisfy the expected exception.
    """
    ha = prep_filt((-1, 2, -1), 1).to(dev)
    hb = prep_filt((-1, 2, 1), 1).to(dev)
    with raises(ValueError):
        coldfilt(barbara_t, ha, hb)