Example #1
0
    def transform_to_mat(self, x):
        """Flattens a batch of images into a matrix with one row per image.

        Pipeline:
          1. When `self.color_space == 'YUV'`, convert the batch with
             `util.rgb_to_syuv`; any other value (e.g. 'RGB') leaves it as is.
          2. Fold the channel axis into the batch axis so that every channel is
             transformed on its own into `self.representation`: a wavelet
             filter name (construct -> rescale -> flatten), 'DCT', or anything
             else (e.g. 'PIXEL'), which keeps raw values.
          3. Unfold the channels again and flatten each image.

        Args:
          x: A rank-4 array or tensor, laid out (per the original comments) as
            (num_batches, width, height, num_channels).

        Returns:
          A tensor of shape (num_batches, width * height * num_channels).
        """
        assert len(x.shape) == 4
        images = torch.as_tensor(x)
        if self.color_space == 'YUV':
            images = util.rgb_to_syuv(images)

        _, width, height, num_channels = images.shape

        # (num_batches, width, height, num_channels)
        #   -> (num_batches * num_channels, width, height)
        channels_first = images.permute(0, 3, 1, 2)
        planes = torch.reshape(channels_first, (-1, width, height))

        if self.representation in wavelet.generate_filters():
            pyramid = wavelet.construct(planes, self.wavelet_num_levels,
                                        self.representation)
            rescaled = wavelet.rescale(pyramid, self.wavelet_scale_base)
            planes = wavelet.flatten(rescaled)
        elif self.representation == 'DCT':
            planes = util.image_dct(planes)

        # (num_batches * num_channels, width, height)
        #   -> (num_batches, width * height * num_channels)
        per_image = torch.reshape(planes, (-1, num_channels, width, height))
        channels_last = per_image.permute(0, 2, 3, 1)
        return torch.reshape(channels_last,
                             [-1, width * height * num_channels])
Example #2
0
 def testWaveletTransformationIsVolumePreserving(self):
   """Tests that flatten(construct()) is volume preserving for power-of-2 sizes.

   For every wavelet filter, the map from the pixels of a (1, 4, 4) image to
   its flattened coefficients is differentiated with autograd. Each gradient
   row forms one column of the Jacobian, and the determinant must be ~1.
   """
   sz = (1, 4, 4)
   num_levels = 2
   for wavelet_type in wavelet.generate_filters():
     im = np.float32(np.random.uniform(0., 1., sz))
     # `torch.autograd.Variable` is deprecated; a leaf tensor created with
     # `requires_grad=True` is the supported equivalent.
     var_im = torch.tensor(im, requires_grad=True)
     # Run the forward transform once and differentiate each output
     # coefficient separately (retain_graph keeps the graph alive between
     # calls), instead of rebuilding the transform for every coefficient.
     coeffs = torch.reshape(
         wavelet.flatten(wavelet.construct(var_im, num_levels, wavelet_type)),
         [-1])
     jacobian = []
     for d in range(im.size):
       (grad,) = torch.autograd.grad(coeffs[d], var_im, retain_graph=True)
       jacobian.append(np.reshape(grad.detach().numpy(), [-1]))
     jacobian = np.stack(jacobian, 1)
     # Assert that the determinant of the Jacobian is close to 1.
     det = np.linalg.det(jacobian)
     np.testing.assert_allclose(det, 1., atol=1e-5, rtol=1e-5)
Example #3
0
 def _construct_preserves_dtype(self, float_dtype):
     """Asserts that construct() keeps the floating-point precision of its input.

     Args:
       float_dtype: Callable used both to cast the random input and as the
         expected dtype of the flattened output (presumably a NumPy float
         scalar type such as `np.float32` — confirm against callers).
     """
     sample = float_dtype(np.random.normal(size=(3, 16, 16)))
     for filter_name in wavelet.generate_filters():
         pyramid = wavelet.construct(sample, 3, filter_name)
         flat = wavelet.flatten(pyramid)
         observed_dtype = flat.detach().numpy().dtype
         np.testing.assert_equal(observed_dtype, float_dtype)