def transform_to_mat(self, x):
    """Transforms a batch of images to a matrix.

    The images are optionally converted to another color space
    (`self.color_space`). Each channel is then converted into the spatial
    representation named by `self.representation`. Finally each image is
    flattened into one row of the output matrix.

    Args:
      x: A 4D array-like of images with shape
        (num_batches, width, height, num_channels). Anything accepted by
        `torch.as_tensor` works (tensor, numpy array, nested list).

    Returns:
      A 2D tensor of shape (num_batches, width * height * num_channels). Each
      row is one image flattened in (width, height, channel) order, so the
      channel index varies fastest.
    """
    # Convert first so that `.shape` is available for any array-like input.
    x = torch.as_tensor(x)
    assert len(x.shape) == 4, (
        'Expected a 4D batch of images, got shape {}'.format(tuple(x.shape)))
    if self.color_space == 'YUV':
        x = util.rgb_to_syuv(x)
    # Any other `color_space` (e.g. 'RGB') is left unchanged.

    # Reshape `x` from
    #   (num_batches, width, height, num_channels) to
    #   (num_batches * num_channels, width, height)
    # so that every channel of every image is transformed independently.
    _, width, height, num_channels = x.shape
    x_stack = torch.reshape(x.permute(0, 3, 1, 2), (-1, width, height))

    # Turn each channel in `x_stack` into the spatial representation specified
    # by `representation`.
    if self.representation in wavelet.generate_filters():
        x_stack = wavelet.flatten(
            wavelet.rescale(
                wavelet.construct(x_stack, self.wavelet_num_levels,
                                  self.representation),
                self.wavelet_scale_base))
    elif self.representation == 'DCT':
        x_stack = util.image_dct(x_stack)
    # Any other `representation` (e.g. 'PIXEL') is left unchanged.

    # NOTE(review): the reshape below assumes the chosen representation keeps
    # width * height coefficients per channel — confirm for non-power-of-2
    # sizes with wavelet representations.
    # Reshape `x_stack` from
    #   (num_batches * num_channels, width, height) to
    #   (num_batches, width, height, num_channels) (channels moved last) to
    #   (num_batches, width * height * num_channels)
    x_mat = torch.reshape(
        torch.reshape(x_stack, (-1, num_channels, width, height)).permute(
            0, 2, 3, 1), [-1, width * height * num_channels])
    return x_mat
def testWaveletTransformationIsVolumePreserving(self):
    """Tests that construct() is volume preserving when size is a power of 2.

    For every wavelet type, the Jacobian of
    `wavelet.flatten(wavelet.construct(.))` with respect to its input is built
    one output coefficient at a time by backpropagation. Its determinant must
    be 1 up to tolerance.
    """
    # A single 4x4 image: 4 is a power of 2, and two levels fit.
    sz = (1, 4, 4)
    num_levels = 2
    for wavelet_type in wavelet.generate_filters():
        im = np.float32(np.random.uniform(0., 1., sz))
        jacobian = []
        for d in range(im.size):
            # Use a fresh leaf tensor for each coefficient so that gradients
            # from earlier iterations do not accumulate into `.grad`.
            var_im = torch.tensor(im, requires_grad=True)
            coeffs = torch.reshape(
                wavelet.flatten(
                    wavelet.construct(var_im, num_levels, wavelet_type)),
                [-1])
            coeffs[d].backward()
            # Gradient of coefficient `d` w.r.t. the input, i.e. one row of the
            # Jacobian.
            jacobian.append(np.reshape(var_im.grad.detach().numpy(), [-1]))
        # Stacking along axis 1 yields the transposed Jacobian. Its
        # determinant is identical to that of the Jacobian itself.
        jacobian = np.stack(jacobian, 1)
        # Assert that the determinant of the Jacobian is close to 1.
        det = np.linalg.det(jacobian)
        np.testing.assert_allclose(
            det, 1., atol=1e-5, rtol=1e-5,
            err_msg='wavelet_type={}'.format(wavelet_type))
def _construct_preserves_dtype(self, float_dtype):
    """Asserts that construct() keeps the floating-point precision of its input.

    Args:
      float_dtype: A callable that casts a numpy array to the precision under
        test; presumably a numpy floating scalar type such as `np.float32`.
        For every filter in `wavelet.generate_filters()`, the flattened
        output of `wavelet.construct` must come back with this dtype.
    """
    samples = np.random.normal(size=(3, 16, 16))
    inputs = float_dtype(samples)
    num_levels = 3
    for filter_type in wavelet.generate_filters():
        coefficients = wavelet.construct(inputs, num_levels, filter_type)
        flat = wavelet.flatten(coefficients)
        output_dtype = flat.detach().numpy().dtype
        np.testing.assert_equal(output_dtype, float_dtype)