コード例 #1
0
 def testAnalysisLowpassFiltersAreNormalized(self):
     """Tests that each analysis lowpass filter's outer product sums to 2.

     For every wavelet type listed by `wavelet.generate_filters()`, all entries
     of the outer product of `analysis_lo` with itself are summed and compared
     against 2 (the normalization under which the analysis lowpass filter
     doubles the input's magnitude).
     """
     for wavelet_type in wavelet.generate_filters():
         lowpass = wavelet.generate_filters(wavelet_type).analysis_lo
         # Column vector times row vector gives the full outer product.
         outer_product = lowpass[:, None] * lowpass[None, :]
         np.testing.assert_allclose(
             np.sum(outer_product), 2., atol=1e-10, rtol=1e-10)
コード例 #2
0
    def transform_to_mat(self, x):
        """Converts a batch of images into a matrix with one row per image.

        Args:
          x: A rank-4 batch of images laid out as
            (num_batches, width, height, num_channels).

        Returns:
          A tensor of shape (num_batches, width * height * num_channels)
          holding `x` after the color space conversion and the spatial
          representation configured on this object have been applied.
        """
        assert len(x.shape) == 4
        images = torch.as_tensor(x)
        # 'RGB' inputs are left untouched; only 'YUV' needs a conversion.
        if self.color_space == 'YUV':
            images = util.rgb_to_syuv(images)

        _, width, height, num_channels = images.shape
        # Move channels next to the batch axis and fold the two together,
        # giving (num_batches * num_channels, width, height).
        channels_first = images.permute(0, 3, 1, 2)
        planes = torch.reshape(channels_first, (-1, width, height))

        # Apply the spatial representation to every plane. Any representation
        # not matched below (i.e. 'PIXEL') leaves the planes as they are.
        if self.representation in wavelet.generate_filters():
            pyramid = wavelet.construct(planes, self.wavelet_num_levels,
                                        self.representation)
            rescaled = wavelet.rescale(pyramid, self.wavelet_scale_base)
            planes = wavelet.flatten(rescaled)
        elif self.representation == 'DCT':
            planes = util.image_dct(planes)

        # Undo the folding: (num_batches * num_channels, width, height) ->
        # (num_batches, width, height, num_channels) -> one row per image.
        unfolded = torch.reshape(planes, (-1, num_channels, width, height))
        channels_last = unfolded.permute(0, 2, 3, 1)
        return torch.reshape(channels_last,
                             [-1, width * height * num_channels])
コード例 #3
0
    def testAccurateRoundTripWithSmallRandomImages(self):
        """Tests that collapse(construct(x)) == x for small square images.

        Images have shape [1, k, k] for every k in range(0, 5), so the empty
        image (k = 0) is included. Each image is decomposed with the number of
        levels reported by wavelet.get_max_num_levels() for its shape.
        """
        for wavelet_type in wavelet.generate_filters():
            for side in range(5):
                shape = [1, side, side]
                max_levels = wavelet.get_max_num_levels(shape)
                image = np.random.uniform(size=shape)

                reconstruction = wavelet.collapse(
                    wavelet.construct(image, max_levels, wavelet_type),
                    wavelet_type)
                np.testing.assert_allclose(
                    reconstruction, image, atol=1e-8, rtol=1e-8)
コード例 #4
0
 def _collapse_preserves_dtype(self, float_dtype):
     """Checks that collapse()'s output has the same precision as its input.

     Builds a random pyramid made of three levels (three bands each, with
     spatial sizes 8, 4 and 2) followed by a 2x2 residual, every array having
     a leading dimension of 3 and being cast with `float_dtype`. For every
     wavelet type, the collapsed result must have dtype `float_dtype`.

     Args:
       float_dtype: Callable used to cast the random numpy arrays, and the
         value the output's numpy dtype is compared against (presumably
         np.float32 or np.float64 -- confirm against the callers).
     """
     band_sizes = [8, 4, 2]
     pyramid = [
         [float_dtype(np.random.normal(size=(3, size, size)))
          for _ in range(3)]
         for size in band_sizes
     ]
     # The residual shares the spatial size of the last (smallest) band level.
     residual_size = band_sizes[-1]
     pyramid.append(
         float_dtype(
             np.random.normal(size=(3, residual_size, residual_size))))
     for wavelet_type in wavelet.generate_filters():
         collapsed = wavelet.collapse(pyramid, wavelet_type)
         np.testing.assert_equal(
             collapsed.detach().numpy().dtype, float_dtype)
コード例 #5
0
 def testAccurateRoundTripWithLargeRandomImages(self):
   """Tests that collapse(construct(x)) == x for randomly sized random images.

   For each wavelet type, four trials are run. Each trial draws a random
   number of levels, a random image shape bounded by [2, 32, 32], and a random
   image of that shape, then checks that the reconstruction matches the image
   to within 1e-8.
   """
   for wavelet_type in wavelet.generate_filters():
     for _ in range(4):
       num_levels = np.int32(np.ceil(4 * np.random.uniform()))
       # Lower bound on the two spatial sizes, derived from `num_levels`.
       min_side = 2**(num_levels - 1) + 1
       random_shape = np.int32(
           np.ceil(np.array([2, 32, 32]) * np.random.uniform(size=3)))
       shape = np.maximum(random_shape, np.array([0, min_side, min_side]))
       image = np.random.uniform(size=shape)
       reconstruction = wavelet.collapse(
           wavelet.construct(image, num_levels, wavelet_type), wavelet_type)
       np.testing.assert_allclose(reconstruction, image, atol=1e-8, rtol=1e-8)
コード例 #6
0
    def testDecompositionIsNonRedundant(self):
        """Tests that wavelet construction is not redundant.

        If the wavelet decomposition is not redundant, then we should be able
        to
        1) Construct a wavelet decomposition
        2) Alter a single coefficient in the decomposition
        3) Collapse that decomposition into an image and back
        and the two wavelet decompositions should be the same.

        The check is repeated on 4 random images for every wavelet type.
        """
        for wavelet_type in wavelet.generate_filters():
            for _ in range(4):
                # Construct an image and a wavelet decomposition of it.
                num_levels = np.int32(np.ceil(4 * np.random.uniform()))
                sz_clamp = 2**(num_levels - 1) + 1
                sz = np.maximum(
                    np.int32(
                        np.ceil(
                            np.array([2, 32, 32]) *
                            np.random.uniform(size=3))),
                    np.array([0, sz_clamp, sz_clamp]),
                )
                im = np.random.uniform(size=sz)
                pyr = wavelet.construct(im, num_levels, wavelet_type)

                # Pick a coefficient at random in the decomposition to alter.
                # The last pyramid entry is indexed directly as an array;
                # every other entry is a sequence of bands, one of which is
                # picked at random.
                d = np.int32(np.floor(np.random.uniform() * len(pyr)))
                v = np.random.uniform()
                if d == (len(pyr) - 1):
                    target = pyr[d]
                else:
                    b = np.int32(
                        np.floor(np.random.uniform() * len(pyr[d])))
                    target = pyr[d][b]
                # An empty array has no coefficient to alter.
                if np.prod(target.shape) > 0:
                    c, i, j = np.int32(
                        np.floor(
                            np.array(np.random.uniform(size=3)) *
                            target.shape)).tolist()
                    # In-place write, so the change is visible through `pyr`.
                    target[c, i, j] = v

                # Collapse and then reconstruct the wavelet decomposition,
                # and check that it is unchanged. This must happen once per
                # trial, otherwise only the last pyramid built is verified.
                recon = wavelet.collapse(pyr, wavelet_type)
                pyr_again = wavelet.construct(recon, num_levels, wavelet_type)
                self._assert_pyramids_close(pyr, pyr_again, 1e-8)
コード例 #7
0
 def testWaveletTransformationIsVolumePreserving(self):
   """Tests that construct() is volume preserving when size is a power of 2.

   The Jacobian of the flattened 2-level decomposition with respect to a
   random (1, 4, 4) float32 image is assembled one output coefficient at a
   time with autograd. Its determinant must be close to 1 for every wavelet
   type.
   """
   shape = (1, 4, 4)
   num_levels = 2
   for wavelet_type in wavelet.generate_filters():
     image = np.float32(np.random.uniform(0., 1., shape))
     gradients = []
     for index in range(image.size):
       # A fresh leaf tensor per coefficient keeps each gradient separate.
       leaf = torch.autograd.Variable(torch.tensor(image), requires_grad=True)
       flat_coefficients = torch.reshape(
           wavelet.flatten(wavelet.construct(leaf, num_levels, wavelet_type)),
           [-1])
       flat_coefficients[index].backward()
       gradients.append(np.reshape(leaf.grad.detach().numpy(), [-1]))
     # Column `index` holds the gradient of output coefficient `index`.
     jacobian = np.stack(gradients, 1)
     np.testing.assert_allclose(
         np.linalg.det(jacobian), 1., atol=1e-5, rtol=1e-5)
コード例 #8
0
    def __init__(self,
                 image_size,
                 float_dtype,
                 device,
                 color_space='YUV',
                 representation='CDF9/7',
                 wavelet_num_levels=5,
                 wavelet_scale_base=1,
                 use_students_t=False,
                 **kwargs):
        """Builds the adaptive robust loss for images of a fixed size.

        This wraps AdaptiveLossFunction (or StudentsTLossFunction). Inputs
        must have a specific shape and size, and one set of internal
        parameters is created per non-batch dimension of the transformed
        image. The defaults (a CDF9/7 wavelet decomposition in a YUV color
        space) often work well.

        Args:
          image_size: The size (width, height, num_channels) of the input
            images.
          float_dtype: The dtype of the floats used as input. np.float32 and
            np.float64 are mapped to the corresponding torch dtypes.
          device: The device to use. If it is an int, a string containing
            'cuda', or a CUDA torch.device, it is also made the current CUDA
            device.
          color_space: The color space that `x` will be transformed into
            before computing the loss. Must be 'RGB' (no transformation) or
            'YUV' (a volume-preserving scaled YUV colorspace, so that
            log-likelihoods still have meaning, see util.rgb_to_syuv()).
            Changing this does not change the assumption that `x` is the set
            of differences between RGB images; it only changes the color
            space `x` is converted to from RGB when computing the loss.
          representation: The spatial image representation that `x` will be
            transformed into after the color space conversion and before
            computing the loss. Any wavelet type listed by
            wavelet.generate_filters() is accepted, as are 'DCT' (a 2D DCT of
            the images) and 'PIXEL' (no transformation, so the loss is
            imposed directly on pixels).
          wavelet_num_levels: If `representation` is a kind of wavelet, the
            number of levels used when constructing wavelet representations;
            ignored otherwise. Should probably be as large as the input
            resolution supports, such as the value produced by
            wavelet.get_max_num_levels().
          wavelet_scale_base: If `representation` is a kind of wavelet, the
            base of the scaling used when constructing wavelet
            representations; ignored otherwise. For image_lossfun() to be
            volume preserving (a useful property when evaluating generative
            models) this must be == 1. If the goal of this loss isn't proper
            statistical modeling, other values (say 0.5 or 2) may
            significantly improve performance.
          use_students_t: If true, use the NLL of Student's T-distribution
            instead of the adaptive loss. This causes all `alpha_*` inputs to
            be ignored.
          **kwargs: Arguments to be passed to the underlying lossfun().

        Raises:
          ValueError: if `color_space` or `representation` are unsupported
            color spaces or image representations, respectively.
          AssertionError: if `image_size` does not have exactly 3 entries.
        """
        super(AdaptiveImageLossFunction, self).__init__()

        valid_color_spaces = ['RGB', 'YUV']
        if color_space not in valid_color_spaces:
            message = '`color_space` must be in {}, but is {!r}'
            raise ValueError(message.format(valid_color_spaces, color_space))
        valid_representations = wavelet.generate_filters() + ['DCT', 'PIXEL']
        if representation not in valid_representations:
            message = '`representation` must be in {}, but is {!r}'
            raise ValueError(
                message.format(valid_representations, representation))
        assert len(image_size) == 3

        self.color_space = color_space
        self.representation = representation
        self.wavelet_num_levels = wavelet_num_levels
        self.wavelet_scale_base = wavelet_scale_base
        self.use_students_t = use_students_t
        self.image_size = image_size

        # Accept numpy float types as aliases for the matching torch dtypes.
        if float_dtype == np.float32:
            float_dtype = torch.float32
        if float_dtype == np.float64:
            float_dtype = torch.float64
        self.float_dtype = float_dtype
        self.device = device

        # Select the CUDA device when `device` designates one.
        is_device_index = isinstance(device, int)
        is_cuda_string = isinstance(device, str) and 'cuda' in device
        is_cuda_device = (
            isinstance(device, torch.device) and device.type == 'cuda')
        if is_device_index or is_cuda_string or is_cuda_device:
            torch.cuda.set_device(self.device)

        # Push a dummy single-image batch through the transform to learn how
        # many dimensions each transformed image has.
        probe = torch.zeros([1] + list(self.image_size)).type(
            self.float_dtype)
        self.num_dims = self.transform_to_mat(probe).shape[1]

        loss_class = (StudentsTLossFunction
                      if self.use_students_t else AdaptiveLossFunction)
        self.adaptive_lossfun = loss_class(
            self.num_dims, self.float_dtype, self.device, **kwargs)
コード例 #9
0
 def _construct_preserves_dtype(self, float_dtype):
     """Checks that construct()'s output has the same precision as its input.

     A random (3, 16, 16) array cast with `float_dtype` is decomposed with 3
     levels for every wavelet type, and the flattened result must have dtype
     `float_dtype`.

     Args:
       float_dtype: Callable used to cast the random numpy array, and the
         value the output's numpy dtype is compared against.
     """
     image = float_dtype(np.random.normal(size=(3, 16, 16)))
     num_levels = 3
     for wavelet_type in wavelet.generate_filters():
         pyramid = wavelet.construct(image, num_levels, wavelet_type)
         flattened = wavelet.flatten(pyramid)
         np.testing.assert_equal(
             flattened.detach().numpy().dtype, float_dtype)