def testAnalysisLowpassFiltersAreNormalized(self):
  """Tests that the analysis lowpass filter doubles the input's magnitude."""
  for wavelet_type in wavelet.generate_filters():
    filters = wavelet.generate_filters(wavelet_type)
    # The sum of the outer product of the analysis lowpass filter with itself.
    magnitude = np.sum(filters.analysis_lo[:, np.newaxis] *
                       filters.analysis_lo[np.newaxis, :])
    np.testing.assert_allclose(magnitude, 2., atol=1e-10, rtol=1e-10)
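# A minimal numeric sketch of the identity the test above relies on: the sum of
# the outer product of a vector with itself equals the square of the vector's
# sum, so an analysis lowpass filter whose taps sum to sqrt(2) has a
# "magnitude" of exactly 2. The filter below is illustrative, not a real
# wavelet filter.
import numpy as np

taps = np.full(4, np.sqrt(2.) / 4.)  # four equal taps that sum to sqrt(2)
outer_sum = np.sum(taps[:, np.newaxis] * taps[np.newaxis, :])
assert np.isclose(outer_sum, np.sum(taps)**2)
assert np.isclose(outer_sum, 2.)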
def transform_to_mat(self, x):
  """Transforms a batch of images to a matrix."""
  assert len(x.shape) == 4
  x = torch.as_tensor(x)
  if self.color_space == 'YUV':
    x = util.rgb_to_syuv(x)
  # If `color_space` == 'RGB', do nothing.

  # Reshape `x` from
  #   (num_batches, width, height, num_channels) to
  #   (num_batches * num_channels, width, height)
  _, width, height, num_channels = x.shape
  x_stack = torch.reshape(x.permute(0, 3, 1, 2), (-1, width, height))

  # Turn each channel in `x_stack` into the spatial representation specified
  # by `representation`.
  if self.representation in wavelet.generate_filters():
    x_stack = wavelet.flatten(
        wavelet.rescale(
            wavelet.construct(x_stack, self.wavelet_num_levels,
                              self.representation),
            self.wavelet_scale_base))
  elif self.representation == 'DCT':
    x_stack = util.image_dct(x_stack)
  # If `representation` == 'PIXEL', do nothing.

  # Reshape `x_stack` from
  #   (num_batches * num_channels, width, height) to
  #   (num_batches, num_channels * width * height)
  x_mat = torch.reshape(
      torch.reshape(x_stack,
                    (-1, num_channels, width, height)).permute(0, 2, 3, 1),
      [-1, width * height * num_channels])
  return x_mat
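# A standalone sketch (not part of the class above) of the reshape bookkeeping
# in transform_to_mat(), for the 'PIXEL' case where no color-space or spatial
# transform is applied: (B, W, H, C) -> (B * C, W, H) -> (B, W * H * C). The
# shapes here are illustrative.
import torch

x = torch.zeros(2, 8, 8, 3)  # (num_batches, width, height, num_channels)
_, width, height, num_channels = x.shape
x_stack = torch.reshape(x.permute(0, 3, 1, 2), (-1, width, height))
assert x_stack.shape == (2 * 3, 8, 8)
x_mat = torch.reshape(
    torch.reshape(x_stack, (-1, num_channels, width, height)).permute(
        0, 2, 3, 1), [-1, width * height * num_channels])
assert x_mat.shape == (2, 8 * 8 * 3)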
def testAccurateRoundTripWithSmallRandomImages(self):
  """Tests that collapse(construct(x)) == x for x = [1, k, k], k in [0, 4]."""
  for wavelet_type in wavelet.generate_filters():
    for width in range(0, 5):
      sz = [1, width, width]
      num_levels = wavelet.get_max_num_levels(sz)
      im = np.random.uniform(size=sz)
      pyr = wavelet.construct(im, num_levels, wavelet_type)
      recon = wavelet.collapse(pyr, wavelet_type)
      np.testing.assert_allclose(recon, im, atol=1e-8, rtol=1e-8)
def _collapse_preserves_dtype(self, float_dtype):
  """Checks that collapse()'s output has the same precision as its input."""
  x = []
  for n in [8, 4, 2]:
    band = []
    for _ in range(3):
      band.append(float_dtype(np.random.normal(size=(3, n, n))))
    x.append(band)
  x.append(float_dtype(np.random.normal(size=(3, n, n))))
  for wavelet_type in wavelet.generate_filters():
    y = wavelet.collapse(x, wavelet_type)
    np.testing.assert_equal(y.detach().numpy().dtype, float_dtype)
def testAccurateRoundTripWithLargeRandomImages(self):
  """Tests that collapse(construct(x)) == x for large random x's."""
  for wavelet_type in wavelet.generate_filters():
    for _ in range(4):
      num_levels = np.int32(np.ceil(4 * np.random.uniform()))
      sz_clamp = 2**(num_levels - 1) + 1
      sz = np.maximum(
          np.int32(
              np.ceil(np.array([2, 32, 32]) * np.random.uniform(size=3))),
          np.array([0, sz_clamp, sz_clamp]))
      im = np.random.uniform(size=sz)
      pyr = wavelet.construct(im, num_levels, wavelet_type)
      recon = wavelet.collapse(pyr, wavelet_type)
      np.testing.assert_allclose(recon, im, atol=1e-8, rtol=1e-8)
def testDecompositionIsNonRedundant(self):
  """Tests that the wavelet construction is not redundant.

  If the wavelet decomposition is not redundant, then we should be able to
  1) Construct a wavelet decomposition
  2) Alter a single coefficient in the decomposition
  3) Collapse that decomposition into an image and back
  and the two wavelet decompositions should be the same.
  """
  for wavelet_type in wavelet.generate_filters():
    for _ in range(4):
      # Construct an image and a wavelet decomposition of it.
      num_levels = np.int32(np.ceil(4 * np.random.uniform()))
      sz_clamp = 2**(num_levels - 1) + 1
      sz = np.maximum(
          np.int32(
              np.ceil(np.array([2, 32, 32]) * np.random.uniform(size=3))),
          np.array([0, sz_clamp, sz_clamp]))
      im = np.random.uniform(size=sz)
      pyr = wavelet.construct(im, num_levels, wavelet_type)

      # Pick a coefficient at random in the decomposition to alter.
      d = np.int32(np.floor(np.random.uniform() * len(pyr)))
      v = np.random.uniform()
      if d == (len(pyr) - 1):
        if np.prod(pyr[d].shape) > 0:
          c, i, j = np.int32(
              np.floor(
                  np.array(np.random.uniform(size=3)) *
                  pyr[d].shape)).tolist()
          pyr[d][c, i, j] = v
      else:
        b = np.int32(np.floor(np.random.uniform() * len(pyr[d])))
        if np.prod(pyr[d][b].shape) > 0:
          c, i, j = np.int32(
              np.floor(
                  np.array(np.random.uniform(size=3)) *
                  pyr[d][b].shape)).tolist()
          pyr[d][b][c, i, j] = v

      # Collapse and then reconstruct the wavelet decomposition, and check
      # that it is unchanged.
      recon = wavelet.collapse(pyr, wavelet_type)
      pyr_again = wavelet.construct(recon, num_levels, wavelet_type)
      self._assert_pyramids_close(pyr, pyr_again, 1e-8)
def testWaveletTransformationIsVolumePreserving(self):
  """Tests that construct() is volume preserving when size is a power of 2."""
  for wavelet_type in wavelet.generate_filters():
    sz = (1, 4, 4)
    num_levels = 2
    # Construct the Jacobian of construct().
    im = np.float32(np.random.uniform(0., 1., sz))
    jacobian = []
    vec = lambda x: torch.reshape(x, [-1])
    for d in range(im.size):
      var_im = torch.autograd.Variable(torch.tensor(im), requires_grad=True)
      coeff = vec(
          wavelet.flatten(
              wavelet.construct(var_im, num_levels, wavelet_type)))[d]
      coeff.backward()
      jacobian.append(np.reshape(var_im.grad.detach().numpy(), [-1]))
    jacobian = np.stack(jacobian, 1)
    # Assert that the determinant of the Jacobian is close to 1.
    det = np.linalg.det(jacobian)
    np.testing.assert_allclose(det, 1., atol=1e-5, rtol=1e-5)
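# A small self-contained illustration of the property tested above: a linear
# map preserves volume exactly when the absolute determinant of its Jacobian
# is 1. The 2x2 orthonormal Haar-style matrix here is illustrative only.
import numpy as np

haar = np.array([[1., 1.], [1., -1.]]) / np.sqrt(2.)
assert np.isclose(np.abs(np.linalg.det(haar)), 1.)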
def __init__(self,
             image_size,
             float_dtype,
             device,
             color_space='YUV',
             representation='CDF9/7',
             wavelet_num_levels=5,
             wavelet_scale_base=1,
             use_students_t=False,
             **kwargs):
  """Sets up the adaptive form of the robust loss on a set of images.

  This function is a wrapper around AdaptiveLossFunction. It requires inputs
  of a specific shape and size, and constructs internal parameters describing
  each non-batch dimension. By default, this function uses a CDF9/7 wavelet
  decomposition in a YUV color space, which often works well.

  Args:
    image_size: The size (width, height, num_channels) of the input images.
    float_dtype: The dtype of the floats used as input.
    device: The device to use.
    color_space: The color space that `x` will be transformed into before
      computing the loss. Must be 'RGB' (in which case no transformation is
      applied) or 'YUV' (in which case we actually use a volume-preserving
      scaled YUV colorspace so that log-likelihoods still have meaning, see
      util.rgb_to_syuv()). Note that changing this argument does not change
      the assumption that `x` is the set of differences between RGB images,
      it just changes what color space `x` is converted to from RGB when
      computing the loss.
    representation: The spatial image representation that `x` will be
      transformed into after converting the color space and before computing
      the loss. If this is a valid type of wavelet according to
      wavelet.generate_filters() then that is what will be used, but we also
      support setting this to 'DCT' which applies a 2D DCT to the images, and
      to 'PIXEL' which applies no transformation to the image, thereby
      causing the loss to be imposed directly on pixels.
    wavelet_num_levels: If `representation` is a kind of wavelet, this is the
      number of levels used when constructing wavelet representations.
      Otherwise this is ignored. Should probably be set to as large as
      possible a value that is supported by the input resolution, such as
      that produced by wavelet.get_max_num_levels().
    wavelet_scale_base: If `representation` is a kind of wavelet, this is the
      base of the scaling used when constructing wavelet representations.
      Otherwise this is ignored. For image_lossfun() to be volume preserving
      (a useful property when evaluating generative models) this value must
      be == 1. If the goal of this loss isn't proper statistical modeling,
      then modifying this value (say, setting it to 0.5 or 2) may
      significantly improve performance.
    use_students_t: If true, use the NLL of Student's T-distribution instead
      of the adaptive loss. This causes all `alpha_*` inputs to be ignored.
    **kwargs: Arguments to be passed to the underlying lossfun().

  Raises:
    ValueError: if `color_space` or `representation` are unsupported color
      spaces or image representations, respectively.
""" super(AdaptiveImageLossFunction, self).__init__() color_spaces = ['RGB', 'YUV'] if color_space not in color_spaces: raise ValueError('`color_space` must be in {}, but is {!r}'.format( color_spaces, color_space)) representations = wavelet.generate_filters() + ['DCT', 'PIXEL'] if representation not in representations: raise ValueError( '`representation` must be in {}, but is {!r}'.format( representations, representation)) assert len(image_size) == 3 self.color_space = color_space self.representation = representation self.wavelet_num_levels = wavelet_num_levels self.wavelet_scale_base = wavelet_scale_base self.use_students_t = use_students_t self.image_size = image_size if float_dtype == np.float32: float_dtype = torch.float32 if float_dtype == np.float64: float_dtype = torch.float64 self.float_dtype = float_dtype self.device = device if isinstance(device, int) or\ (isinstance(device, str) and 'cuda' in device) or\ (isinstance(device, torch.device) and device.type == 'cuda'): torch.cuda.set_device(self.device) x_example = torch.zeros([1] + list(self.image_size)).type( self.float_dtype) x_example_mat = self.transform_to_mat(x_example) self.num_dims = x_example_mat.shape[1] if self.use_students_t: self.adaptive_lossfun = StudentsTLossFunction( self.num_dims, self.float_dtype, self.device, **kwargs) else: self.adaptive_lossfun = AdaptiveLossFunction( self.num_dims, self.float_dtype, self.device, **kwargs)
def _construct_preserves_dtype(self, float_dtype):
  """Checks that construct()'s output has the same precision as its input."""
  x = float_dtype(np.random.normal(size=(3, 16, 16)))
  for wavelet_type in wavelet.generate_filters():
    y = wavelet.flatten(wavelet.construct(x, 3, wavelet_type))
    np.testing.assert_equal(y.detach().numpy().dtype, float_dtype)