Ejemplo n.º 1
0
def check_band_energies(coeff_1, coeff_2, rtol=1e-4, atol=1e-4):
    '''
    Assert that two pyramids carry the same energy in every band.

    Bands are paired by position (the iteration order of each dictionary),
    not by key. For every pair the energy ``E = sum(abs(band)**2)`` is
    computed and the two scalars are compared with
    ``np.testing.assert_allclose``, i.e. the check is
    abs(E_1 - E_2) <= atol + rtol*abs(E_2)

    Args:
    coeff_1: first dictionary of torch tensors corresponding to each band
    coeff_2: second dictionary of torch tensors corresponding to each band
    rtol: relative tolerance forwarded to ``np.testing.assert_allclose``
    atol: absolute tolerance forwarded to ``np.testing.assert_allclose``

    Raises:
    AssertionError: if the two dictionaries do not hold the same number of
        bands, or if the energies of any pair of bands differ by more than
        the tolerance.
    '''
    # Pairing is positional, so a size mismatch would otherwise either raise
    # an opaque IndexError or silently leave trailing bands unchecked.
    if len(coeff_1) != len(coeff_2):
        raise AssertionError(
            'Pyramids have a different number of bands: %d vs %d'
            % (len(coeff_1), len(coeff_2)))

    for band_1, band_2 in zip(coeff_1.values(), coeff_2.values()):
        # Only band_1 is inspected to pick the conversion; band_2 is assumed
        # to use the same layout. A trailing dimension of size 2 is
        # presumably a (real, imaginary) pair -- confirm against
        # torch_complex_to_numpy.
        if band_1.shape[-1] == 2:
            band_1 = torch_complex_to_numpy(band_1)
            band_2 = torch_complex_to_numpy(band_2)
        else:
            band_1 = to_numpy(band_1)
            band_2 = to_numpy(band_2)
        band_1 = band_1.squeeze()
        band_2 = band_2.squeeze()

        np.testing.assert_allclose(np.sum(np.abs(band_1)**2),
                                   np.sum(np.abs(band_2)**2),
                                   rtol=rtol,
                                   atol=atol)
Ejemplo n.º 2
0
 def test_pyr_to_tensor(self, img, spyr, scales, rtol=1e-12, atol=1e-12):
     '''
     Check that converting the forward output back to a dictionary
     reproduces the coefficients stored in ``spyr.pyr_coeffs``.

     The output of ``spyr.forward(img, scales=scales)`` is passed through
     ``spyr.convert_tensor_to_pyr`` and every resulting band is compared with
     the band at the same position (iteration order, not key) of
     ``spyr.pyr_coeffs``.

     Raises:
     AssertionError: if ``spyr.pyr_coeffs`` holds fewer bands than the
         converted dictionary, or if any pair of bands differs by more than
         the tolerance.
     '''
     pyr_tensor = spyr.forward(img, scales=scales)
     pyr_coeff_dict = spyr.convert_tensor_to_pyr(pyr_tensor)
     # Build each key list once rather than on every loop iteration.
     converted_keys = list(pyr_coeff_dict.keys())
     stored_keys = list(spyr.pyr_coeffs.keys())
     # zip() would silently truncate, so fail explicitly when the reference
     # dictionary cannot supply a partner for every converted band.
     if len(stored_keys) < len(converted_keys):
         raise AssertionError(
             'spyr.pyr_coeffs has %d bands but the converted tensor has %d'
             % (len(stored_keys), len(converted_keys)))
     for k1, k2 in zip(converted_keys, stored_keys):
         np.testing.assert_allclose(to_numpy(pyr_coeff_dict[k1]),
                                    to_numpy(spyr.pyr_coeffs[k2]),
                                    rtol=rtol,
                                    atol=atol)
Ejemplo n.º 3
0
 def test_partial_recon(self, img, spyr):
     '''
     Compare partial reconstructions from ``spyr`` with those of a reference
     pyramid built with ``pt.pyramids.SteerablePyramidFreq``.

     Every combination of the level selections and band selections listed
     below is reconstructed by both pyramids and the results must agree to
     within rtol=1e-4 / atol=1e-4.
     '''
     spyr.forward(img)
     # Integer scale keys are 0-indexed (the lowest has key[0] == 0), so the
     # height handed to the reference pyramid is the largest one plus 1.
     int_scales = [key[0] for key in spyr.pyr_coeffs.keys()
                   if isinstance(key[0], int)]
     n_scales = max(int_scales) + 1
     image_np = to_numpy(img.squeeze())
     reference = pt.pyramids.SteerablePyramidFreq(image_np,
                                                  height=n_scales,
                                                  order=spyr.order,
                                                  is_complex=spyr.is_complex)
     level_choices = ['all', [0], [1, 3], [1, 3, 4]]
     band_choices = ['all', [1], [1, 3]]
     for levels in level_choices:
         for bands in band_choices:
             actual = to_numpy(spyr.recon_pyr(levels, bands).squeeze())
             expected = reference.recon_pyr(levels, bands)
             np.testing.assert_allclose(actual,
                                        expected,
                                        rtol=1e-4,
                                        atol=1e-4)
Ejemplo n.º 4
0
 def test_torch_vs_numpy_pyr(self, img, spyr):
     '''
     Check the output of ``spyr.forward(img)`` against the coefficients of a
     ``pt.pyramids.SteerablePyramidFreq`` built from the same image with the
     same height, order and complexity, using ``check_pyr_coeffs``.
     '''
     torch_spc = spyr.forward(img)
     # The largest integer scale key is 0-indexed (lowest has key[0] == 0),
     # hence the +1 to obtain the number of scales.
     top_scale = max(key[0] for key in spyr.pyr_coeffs.keys()
                     if isinstance(key[0], int))
     height = top_scale + 1
     image_np = to_numpy(img.squeeze())
     pyrtools_sp = pt.pyramids.SteerablePyramidFreq(image_np,
                                                    height=height,
                                                    order=spyr.order,
                                                    is_complex=spyr.is_complex)
     check_pyr_coeffs(pyrtools_sp.pyr_coeffs, torch_spc)
Ejemplo n.º 5
0
def check_parseval(im, coeff, rtol=1e-4, atol=0):
    '''
    Assert that the pyramid is Parseval, i.e. that the energy summed over all
    coefficient bands equals the energy of the original image.

    The comparison is done with ``np.testing.assert_allclose`` using the
    summed band energy as the actual value and the image energy as the
    desired value.

    Args:
    im: image stimulus as torch.Tensor
    coeff: dictionary of torch tensors corresponding to each band
    rtol: relative tolerance forwarded to ``np.testing.assert_allclose``
    atol: absolute tolerance forwarded to ``np.testing.assert_allclose``

    Raises:
    AssertionError: if the two energies differ by more than the tolerance.
    '''
    im_energy = np.sum(to_numpy(im)**2)
    total_band_energy = 0
    # Only the band tensors are needed, so iterate the values directly.
    for band in coeff.values():
        # A trailing dimension of size 2 is presumably a (real, imaginary)
        # pair -- confirm against torch_complex_to_numpy.
        if band.shape[-1] == 2:
            band = torch_complex_to_numpy(band)
        else:
            band = to_numpy(band)
        band = band.squeeze()

        total_band_energy += np.sum(np.abs(band)**2)

    np.testing.assert_allclose(total_band_energy,
                               im_energy,
                               rtol=rtol,
                               atol=atol)
Ejemplo n.º 6
0
 def test_recon_match_pyrtools(self, img, spyr, rtol=1e-6, atol=1e-6):
     '''
     Check that the full reconstruction from ``spyr`` agrees with the full
     reconstruction of a ``pt.pyramids.SteerablePyramidFreq`` built from the
     same image with the same height, order and complexity.
     '''
     # Expected to fail exactly when test_complete_recon fails; it is kept
     # as an additional safeguard.
     spyr.forward(img)
     # Scale keys are 0-indexed (the lowest has key[0] == 0), so one is
     # added to the largest integer key to get the height.
     int_scales = [key[0] for key in spyr.pyr_coeffs.keys()
                   if isinstance(key[0], int)]
     n_scales = max(int_scales) + 1
     image_np = to_numpy(img.squeeze())
     reference = pt.pyramids.SteerablePyramidFreq(image_np,
                                                  height=n_scales,
                                                  order=spyr.order,
                                                  is_complex=spyr.is_complex)
     actual = po.to_numpy(spyr.recon_pyr().squeeze())
     expected = reference.recon_pyr()
     np.testing.assert_allclose(actual, expected, rtol=rtol, atol=atol)
Ejemplo n.º 7
0
def check_pyr_coeffs(coeff_np, coeff_torch, rtol=1e-3, atol=1e-3):
    '''
    Assert that a numpy and a torch set of pyramid coefficients agree.

    For every key of coeff_np the matching torch band is converted to numpy,
    squeezed, and compared with ``np.testing.assert_allclose``, i.e.
    abs(torch_band - np_band) <= atol + rtol*abs(np_band)

    Inputs:
    coeff_np: numpy pyramid coefficients, keyed by band
    coeff_torch: torch pyramid coefficients; must contain every key of
        coeff_np
    rtol: relative tolerance
    atol: absolute tolerance
    Both sets of coefficients are expected to have the same number of scales,
    orientations etc.
    '''
    for key, expected in coeff_np.items():
        actual = coeff_torch[key]
        # A trailing dimension of size 2 selects the complex conversion.
        if actual.shape[-1] == 2:
            to_array = torch_complex_to_numpy
        else:
            to_array = to_numpy
        actual = to_array(actual).squeeze()
        np.testing.assert_allclose(actual, expected, rtol=rtol, atol=atol)
Ejemplo n.º 8
0
 def test_complete_recon(self, img, spyr):
     '''
     Check that ``spyr.recon_pyr()`` after a forward pass returns the input
     image to within rtol=1e-4 / atol=1e-4.
     '''
     spyr.forward(img)
     reconstructed = to_numpy(spyr.recon_pyr())
     original = to_numpy(img)
     np.testing.assert_allclose(reconstructed,
                                original,
                                rtol=1e-4,
                                atol=1e-4)