def testRescaleAndUnrescaleReproducesInput(self):
    """Checks that rescale(rescale(x, k), 1/k) reproduces the pyramid x.

    A random image is decomposed into a 4-level 'LeGall5/3' pyramid. The
    pyramid is rescaled by a random positive factor and then by its
    reciprocal, and the round trip must match the original to within 1e-8.
    """
    image = np.random.uniform(size=(2, 32, 32))
    # exp() of a standard normal sample: always strictly positive.
    factor = np.exp(np.random.normal())
    original = wavelet.construct(image, 4, 'LeGall5/3')
    round_trip = wavelet.rescale(wavelet.rescale(original, factor), 1. / factor)
    self._assert_pyramids_close(original, round_trip, 1e-8)
def transform_to_mat(self, x):
    """Converts a batch of images into a 2D matrix with one row per image.

    Steps:
      1. If `self.color_space` is 'YUV', the batch is passed through
         `util.rgb_to_syuv`. 'RGB' leaves it unchanged.
      2. Every channel of every image is treated as its own single-channel
         image and mapped into the space named by `self.representation`:
           - a wavelet pyramid (any name returned by
             `wavelet.generate_filters()`), which is rescaled and flattened;
           - 'DCT', via `util.image_dct`;
           - 'PIXEL', which leaves it unchanged.
      3. The channels are unfolded again and each image is laid out as a
         single row.

    Args:
      x: A rank-4 batch, indexed (batch, width, height, channels) by this
        code. Anything `torch.as_tensor` accepts and that has a `.shape`.

    Returns:
      A torch tensor of shape (batch, width * height * channels).
    """
    assert len(x.shape) == 4
    x = torch.as_tensor(x)
    if self.color_space == 'YUV':
        x = util.rgb_to_syuv(x)
    # An 'RGB' color space leaves `x` untouched.

    _, width, height, num_channels = x.shape

    # Move channels ahead of the spatial axes and fold them into the batch
    # axis: (batch, width, height, channels) -> (batch * channels, width, height).
    channels_first = x.permute(0, 3, 1, 2)
    stacked = torch.reshape(channels_first, (-1, width, height))

    if self.representation in wavelet.generate_filters():
        pyramid = wavelet.construct(
            stacked, self.wavelet_num_levels, self.representation)
        pyramid = wavelet.rescale(pyramid, self.wavelet_scale_base)
        # NOTE(review): the reshape below assumes wavelet.flatten yields
        # (batch * channels, width, height) again -- confirm in wavelet.py.
        stacked = wavelet.flatten(pyramid)
    elif self.representation == 'DCT':
        stacked = util.image_dct(stacked)
    # A 'PIXEL' representation leaves `stacked` untouched.

    # Undo the channel folding:
    # (batch * channels, width, height) -> (batch, width * height * channels).
    per_image = torch.reshape(stacked, (-1, num_channels, width, height))
    channels_last = per_image.permute(0, 2, 3, 1)
    return torch.reshape(channels_last, [-1, width * height * num_channels])
def testRescaleOneHalfIsNormalized(self):
    """Checks rescale(construct(k), 0.5)[-1] == k for a constant image k.

    The check is repeated for pyramid depths 0 through 4, drawing a fresh
    random constant each time. It compares the last element of the rescaled
    'LeGall5/3' pyramid against that constant.
    """
    for levels in range(5):
        value = np.random.uniform()
        constant_image = value * np.ones((2, 32, 32))
        rescaled = wavelet.rescale(
            wavelet.construct(constant_image, levels, 'LeGall5/3'), 0.5)
        last_level = rescaled[-1]
        np.testing.assert_allclose(
            last_level,
            value * np.ones_like(last_level),
            atol=1e-8,
            rtol=1e-8)
def testRescaleDoesNotAffectTheFirstLevel(self):
    """Checks that rescale(x, s)[0] == x[0] for an arbitrary positive scale s.

    A random image is decomposed into a 4-level 'LeGall5/3' pyramid and
    rescaled by exp() of a standard normal sample. Only element 0 of each
    pyramid is compared.
    """
    image = np.random.uniform(size=(2, 32, 32))
    original = wavelet.construct(image, 4, 'LeGall5/3')
    rescaled = wavelet.rescale(original, np.exp(np.random.normal()))
    # Slicing rather than indexing hands the helper a one-element pyramid,
    # matching how it is called with whole pyramids elsewhere.
    first_only = slice(0, 1)
    self._assert_pyramids_close(
        original[first_only], rescaled[first_only], 1e-8)
def testRescaleOneIsANoOp(self):
    """Checks that rescaling a pyramid with a base of 1 leaves it unchanged."""
    original = wavelet.construct(
        np.random.uniform(size=(2, 32, 32)), 4, 'LeGall5/3')
    self._assert_pyramids_close(original, wavelet.rescale(original, 1.), 1e-8)