Example #1
0
def test_per_axis_wavelets():
    """Round-trip swtn/iswtn (and swt2/iswt2) with a separate wavelet per axis.

    Also checks that a 1-tuple of wavelets is accepted and that a 2-element
    wavelet sequence for 3-D data is rejected with ValueError by both swtn
    and iswtn.
    """
    rng = np.random.RandomState(1234)
    x = rng.randn(16, 16, 16)
    n_levels = 3

    # Entries may be Wavelet instances or wavelet names.
    per_axis = (pywt.Wavelet('haar'), 'sym2', 'db4')
    single = per_axis[:1]
    first_two = per_axis[:2]

    # One wavelet for each of the three axes.
    c = pywt.swtn(x, per_axis, level=n_levels)
    assert_allclose(pywt.iswtn(c, per_axis), x, atol=1e-14)

    # A 1-tuple is also accepted.
    c = pywt.swtn(x, single, level=n_levels)
    assert_allclose(pywt.iswtn(c, single), x, atol=1e-14)

    # Two wavelets for three transformed axes is an error both ways.
    assert_raises(ValueError, pywt.swtn, x, first_two, n_levels)
    assert_raises(ValueError, pywt.iswtn, c, first_two)

    with warnings.catch_warnings():
        warnings.simplefilter('ignore', FutureWarning)
        # swt2/iswt2 accept per-axis wavelets as well (2 axes, 2 wavelets).
        x2 = x[..., 0]
        c2 = pywt.swt2(x2, first_two, n_levels)
        assert_allclose(pywt.iswt2(c2, first_two), x2, atol=1e-14)
Example #2
0
def test_per_axis_wavelets():
    """Round-trip swtn/iswtn (and swt2/iswt2) with a separate wavelet per axis.

    Also checks that a 1-tuple of wavelets is accepted and that a 2-element
    wavelet sequence for 3-D data is rejected with ValueError by both swtn
    and iswtn.
    """
    rng = np.random.RandomState(1234)
    x = rng.randn(16, 16, 16)
    n_levels = 3

    # Entries may be Wavelet instances or wavelet names.
    per_axis = (pywt.Wavelet('haar'), 'sym2', 'db4')
    single = per_axis[:1]
    first_two = per_axis[:2]

    # One wavelet for each of the three axes.
    c = pywt.swtn(x, per_axis, level=n_levels)
    assert_allclose(pywt.iswtn(c, per_axis), x, atol=1e-14)

    # A 1-tuple is also accepted.
    c = pywt.swtn(x, single, level=n_levels)
    assert_allclose(pywt.iswtn(c, single), x, atol=1e-14)

    # Two wavelets for three transformed axes is an error both ways.
    assert_raises(ValueError, pywt.swtn, x, first_two, n_levels)
    assert_raises(ValueError, pywt.iswtn, c, first_two)

    with warnings.catch_warnings():
        warnings.simplefilter('ignore', FutureWarning)
        # swt2/iswt2 accept per-axis wavelets as well (2 axes, 2 wavelets).
        x2 = x[..., 0]
        c2 = pywt.swt2(x2, first_two, n_levels)
        assert_allclose(pywt.iswt2(c2, first_two), x2, atol=1e-14)
Example #3
0
def _swt3(matrix, axes):  # Stationary Wavelet Transform 3D
    """Single-level 'coif1' stationary wavelet transform of a 3-D array.

    The array is wrap-padded by one sample along every odd-length axis so
    that ``pywt.swtn`` accepts it. It is transformed once along ``axes``,
    and each sub-band is cropped back to the original shape.

    Parameters
    ----------
    matrix : array_like
        3-D input data (anything ``numpy.asarray`` accepts).
    axes : sequence of int
        Exactly two axes of ``matrix`` to transform.

    Returns
    -------
    numpy.ndarray
        Float array of shape ``(matrix.shape[0], matrix.shape[1], 4)``.
        Channels 0..3 hold the LL, LH, HL and HH sub-bands ('a' -> 'L',
        'd' -> 'H'), each taken at index 0 of the last axis.

    Raises
    ------
    ValueError
        If ``matrix`` is not 3-D or ``axes`` does not name exactly two axes.
    """
    wavelet = 'coif1'

    matrix = numpy.asarray(matrix)
    if matrix.ndim != 3:
        raise ValueError('Expected 3D data array')
    axes = tuple(axes)
    if len(axes) != 2:
        # The LL/LH/HL/HH output below only exists for a 2-axis transform;
        # previously this surfaced as an opaque KeyError on 'LL'.
        raise ValueError('Expected exactly two transform axes')

    original_shape = matrix.shape
    # pywt.swtn needs every transformed length to be a multiple of
    # 2**level (2 for a single level), so pad odd axes by one sample.
    padding = tuple((0, 1 if dim % 2 != 0 else 0) for dim in original_shape)
    data = numpy.pad(matrix, padding, 'wrap')  # returns a new array

    # One decomposition level: swtn returns a one-element list of dicts.
    dec = pywt.swtn(data, wavelet, level=1, start_level=0, axes=axes)[0]

    # Remove the padded sample again on every odd-length axis.
    crop = tuple(slice(None, -1 if dim % 2 != 0 else None)
                 for dim in original_shape)
    bands = {
        str(name).replace('a', 'L').replace('d', 'H'): coeff[crop]
        for name, coeff in dec.items()
    }

    rst = numpy.zeros((original_shape[0], original_shape[1], 4))
    for channel, band_name in enumerate(('LL', 'LH', 'HL', 'HH')):
        rst[:, :, channel] = bands[band_name][:, :, 0]

    return rst
Example #4
0
def _swt3(inputImage, axes, **kwargs):  # Stationary Wavelet Transform 3D
    """Stationary wavelet transform of a 3-D SimpleITK image.

    Keyword arguments ``wavelet`` (default ``'coif1'``), ``level`` (default
    1) and ``start_level`` (default 0) control the transform. The first
    ``start_level`` levels are computed but only their approximation is
    kept. Odd-length axes are wrap-padded by one sample before transforming
    and cropped again, and every returned image copies ``inputImage``'s
    information.

    Returns
    -------
    approximation : SimpleITK.Image
        Approximation band after the last computed level.
    ret : list of dict
        One dict per computed level mapping detail-band names ('a' -> 'L'
        low-pass, 'd' -> 'H' high-pass, e.g. 'LLH') to images. The
        approximation band is excluded here and returned separately.
    """
    wavelet = kwargs.get('wavelet', 'coif1')
    level = kwargs.get('level', 1)
    start_level = kwargs.get('start_level', 0)

    matrix = numpy.asarray(sitk.GetArrayFromImage(inputImage))
    if matrix.ndim != 3:
        raise ValueError('Expected 3D data array')

    original_shape = matrix.shape
    odd = [dim % 2 != 0 for dim in original_shape]
    # pywt.swtn needs even lengths for one level: wrap-pad odd axes by one.
    data = numpy.pad(matrix.copy(), tuple([(0, 1 if o else 0) for o in odd]), 'wrap')
    # Index that drops the padded sample again.
    unpad = tuple(slice(None, -1 if o else None) for o in odd)

    if not isinstance(wavelet, pywt.Wavelet):
        wavelet = pywt.Wavelet(wavelet)

    approx_key = 'a' * len(axes)

    def one_level(arr):
        # Single-level SWT of `arr`; returns the band-name -> array dict.
        return pywt.swtn(arr, wavelet, level=1, start_level=0, axes=axes)[0]

    # Skip the first start_level levels, keeping only the approximation.
    for _ in range(start_level):
        data = one_level(data)[approx_key].copy()

    ret = []
    for _ in range(level):
        dec = one_level(data)
        data = dec[approx_key].copy()

        # The approximation is returned once, built from `data` after the
        # loop, so only the detail bands are collected per level.
        details = {}
        for name, coeff in six.iteritems(dec):
            if name == approx_key:
                continue
            image = sitk.GetImageFromArray(coeff.copy()[unpad])
            image.CopyInformation(inputImage)
            details[str(name).replace('a', 'L').replace('d', 'H')] = image
        ret.append(details)

    approximation = sitk.GetImageFromArray(data[unpad])
    approximation.CopyInformation(inputImage)

    return approximation, ret
Example #5
0
def test_swtn_axes():
    """Check swtn's axes handling and its input validation."""
    atol = 1e-14
    current_wavelet = pywt.Wavelet('db2')
    # smallest power-of-two length that is >= both filter lengths
    input_length_power = int(
        np.ceil(np.log2(max(current_wavelet.dec_len,
                            current_wavelet.rec_len))))
    input_length = 2**(input_length_power)
    X = np.arange(input_length**2).reshape(input_length, input_length)
    coeffs = pywt.swtn(X, current_wavelet, level=1, axes=None)[0]
    # opposite order: the mixed 'ad'/'da' bands swap names
    coeffs2 = pywt.swtn(X, current_wavelet, level=1, axes=(1, 0))[0]
    assert_allclose(coeffs['aa'], coeffs2['aa'], atol=atol)
    assert_allclose(coeffs['ad'], coeffs2['da'], atol=atol)
    assert_allclose(coeffs['da'], coeffs2['ad'], atol=atol)
    assert_allclose(coeffs['dd'], coeffs2['dd'], atol=atol)

    # 0-level transform
    empty = pywt.swtn(X, current_wavelet, level=0)
    assert_equal(empty, [])

    # duplicate axes not allowed
    assert_raises(ValueError, pywt.swtn, X, current_wavelet, 1, axes=(0, 0))

    # empty input: np.asarray([]) is 1-D with zero size (not 0-D)
    assert_raises(ValueError, pywt.swtn, np.asarray([]), current_wavelet, 1)

    # start_level too large
    assert_raises(ValueError,
                  pywt.swtn,
                  X,
                  current_wavelet,
                  level=1,
                  start_level=2)

    # level < 1 in swt_axis call
    assert_raises(ValueError,
                  swt_axis,
                  X,
                  current_wavelet,
                  level=0,
                  start_level=0)
    # odd-sized data not allowed. Use a valid level (1). With level=0 the
    # "level < 1" error checked above could satisfy this assertion, and the
    # odd length along axis 0 would go untested.
    assert_raises(ValueError,
                  swt_axis,
                  X[:-1, :],
                  current_wavelet,
                  level=1,
                  start_level=0,
                  axis=0)
Example #6
0
def test_iswtn_mixed_dtypes():
    """iswtn returns double precision when coefficient dtypes are mixed.

    In the first coefficient dict, the approximation band is cast to one
    dtype and the detail bands to another. The reconstruction must be
    float64 (complex128 for complex input) and match the input to 1e-3.
    """
    rng = np.random.RandomState(0)
    real_input = rng.randn(8, 8, 8)
    complex_input = real_input + 1j*real_input
    wavelet_name = 'sym2'
    dtype_pairs = [(np.float64, np.float32),
                   (np.float32, np.float64),
                   (np.float16, np.float64),
                   (np.complex128, np.complex64),
                   (np.complex64, np.complex128)]
    complex_dtypes = [np.complex64, np.complex128]

    for approx_dtype, detail_dtype in dtype_pairs:
        is_complex = approx_dtype in complex_dtypes
        x = complex_input if is_complex else real_input
        expected_dtype = np.complex128 if is_complex else np.float64

        coeffs = pywt.swtn(x, wavelet_name, 2)
        approx_key = 'a' * x.ndim
        first = coeffs[0]
        approx = first.pop(approx_key).astype(approx_dtype)
        mixed = {key: band.astype(detail_dtype) for key, band in first.items()}
        mixed[approx_key] = approx
        coeffs[0] = mixed

        y = pywt.iswtn(coeffs, wavelet_name)
        assert_equal(expected_dtype, y.dtype)
        assert_allclose(y, x, rtol=1e-3, atol=1e-3)
Example #7
0
 def setup(self, D, n, wavelet, dtype):
     """Replace ``self.data`` with its swtn coefficients for iswtn timing.

     Raises NotImplementedError when ``pywt.iswtn`` cannot be imported
     (presumably so the benchmark runner skips this case -- confirm).
     """
     try:
         from pywt import iswtn  # noqa: F401  (imported only as a probe)
     except ImportError:
         raise NotImplementedError("iswtn not available")
     # Base-class setup presumably creates self.data / self.level -- confirm.
     super(IswtnTimeSuite, self).setup(D, n, wavelet, dtype)
     coeffs = pywt.swtn(self.data, wavelet, level=self.level)
     self.data = coeffs
Example #8
0
def wavelet_features(img, mask):
    """Texture features for every band of a single-level 'coif1' SWT of img.

    ``img`` and ``mask`` are zero-padded by one element at the end of each
    odd-length axis among the first three (sizes taken from ``img.shape``;
    presumably a 3-D volume), so that ``pywt.swtn`` accepts them. For each
    band key, the dicts returned by ``group1_features``,
    ``gray_level_cooccurrence_features`` and ``gray_level_runlength_features``
    are merged, in that order, into one dict keyed ``"<band>_<feature>"``.
    """
    shape = img.shape
    pad_width = [(0, shape[axis] % 2) for axis in range(3)]
    padded_img = np.pad(img, pad_width=pad_width, mode="constant",
                        constant_values=0)
    padded_mask = np.pad(mask, pad_width=pad_width, mode="constant",
                         constant_values=0)

    bands = pywt.swtn(padded_img, 'coif1', level=1)[0]

    features = {}
    for band_name in bands:
        band = bands[band_name]
        per_extractor = (
            group1_features(band),
            gray_level_cooccurrence_features(band, padded_mask),
            gray_level_runlength_features(band, padded_mask),
        )
        for feature_dict in per_extractor:
            for name in feature_dict:
                features[band_name + "_" + name] = feature_dict[name]

    return features
Example #9
0
def test_iswtn_mixed_dtypes():
    """iswtn returns double precision when coefficient dtypes are mixed.

    In the first coefficient dict, the approximation band is cast to one
    dtype and the detail bands to another. The reconstruction must be
    float64 (complex128 for complex input) and match the input to 1e-3.
    """
    rng = np.random.RandomState(0)
    real_input = rng.randn(8, 8, 8)
    complex_input = real_input + 1j*real_input
    wavelet_name = 'sym2'
    dtype_pairs = [(np.float64, np.float32),
                   (np.float32, np.float64),
                   (np.float16, np.float64),
                   (np.complex128, np.complex64),
                   (np.complex64, np.complex128)]
    complex_dtypes = [np.complex64, np.complex128]

    for approx_dtype, detail_dtype in dtype_pairs:
        is_complex = approx_dtype in complex_dtypes
        x = complex_input if is_complex else real_input
        expected_dtype = np.complex128 if is_complex else np.float64

        coeffs = pywt.swtn(x, wavelet_name, 2)
        approx_key = 'a' * x.ndim
        first = coeffs[0]
        approx = first.pop(approx_key).astype(approx_dtype)
        mixed = {key: band.astype(detail_dtype) for key, band in first.items()}
        mixed[approx_key] = approx
        coeffs[0] = mixed

        y = pywt.iswtn(coeffs, wavelet_name)
        assert_equal(expected_dtype, y.dtype)
        assert_allclose(y, x, rtol=1e-3, atol=1e-3)
Example #10
0
    def stationary_wavelet_transform(self, waveform, wavelet, level):
        """Compute the SWT of a waveform, then strip the added padding.

        The waveform is padded (``self.add_padding``) up to the length that
        ``self.calculate_decomposition_level`` reports for ``level``, and
        transformed with ``pywt.swtn``. Then ``pad_before``/``pad_after``
        samples are sliced off every approximation ('a') and detail ('d')
        array.

        Returns
        -------
        list of dict
            The ``pywt.swtn`` output (one dict per level, keys 'a' and 'd'),
            with each array cropped to ``[pad_before:len(padded) - pad_after]``.
        """
        waveform_length = len(waveform)

        # Minimum waveform length for an SWT of the requested level.
        waveform_length_updated = self.calculate_decomposition_level(
            waveform_length, level)

        waveform_padded, pad_before, pad_after = self.add_padding(
            waveform, waveform_length_updated)

        swt = pywt.swtn(waveform_padded,
                        wavelet=wavelet,
                        level=level,
                        start_level=0)

        # The crop is the same for every level and band: build it once.
        keep = slice(pad_before, len(waveform_padded) - pad_after)
        for coeffs in swt:
            coeffs['a'] = coeffs['a'][keep]  # approximation
            coeffs['d'] = coeffs['d'][keep]  # detail

        return swt
Example #11
0
    def _analysis(self, data, **kwargs):
        """Decompose a real signal with pywt.

        Uses ``pywt.wavedecn`` (with ``self.padding_mode``) when
        ``self.is_decimated`` is truthy, otherwise ``pywt.swtn``. Both use
        ``self.trf``, ``self.nb_scale`` levels and ``self.axes``. The
        coefficients are reorganised by ``self._organize_pysap``, and
        ``self.nb_band_per_scale`` is set to the length of each header entry.

        Parameters
        ----------
        data: nd-array
            a real array to be decomposed.

        Returns
        -------
        analysis_data: nd_array
            the decomposition coefficients.
        analysis_header: dict
            the decomposition associated information.
        """
        shared = {'level': self.nb_scale, 'axes': self.axes}
        if not self.is_decimated:
            # Undecimated (stationary) transform; no padding mode is passed.
            coeffs = pywt.swtn(data, self.trf, **shared)
        else:
            coeffs = pywt.wavedecn(data, self.trf, mode=self.padding_mode,
                                   **shared)

        analysis_data, analysis_header = self._organize_pysap(coeffs)
        self.nb_band_per_scale = [len(info) for info in analysis_header]

        return analysis_data, analysis_header
Example #12
0
    def op(self, data):
        """Apply the forward wavelet transform to ``data``.

        With ``self.undecimated`` truthy this uses ``pywt.swtn``; otherwise
        it uses ``pywt.wavedecn`` with ``self.mode``. The result is passed
        through ``self.flatten``. ``self.coeffs_shape`` is updated in both
        cases; ``self.coeffs`` is only updated for the decimated transform.

        Parameters
        ----------
        data: np.ndarray(m', n') or np.ndarray(m', n', p')
            input 2D or 3D data array.

        Returns
        -------
        coeffs: np.ndarray
            the wavelet coefficients as returned by ``self.flatten``.
        """
        if not self.undecimated:
            coeffs_dict = pywt.wavedecn(data, self.pywt_transform,
                                        level=self.nb_scale, mode=self.mode)
            self.coeffs, self.coeffs_shape = self.flatten(coeffs_dict)
            return self.coeffs

        coeffs_dict = pywt.swtn(data, self.pywt_transform, level=self.nb_scale)
        coeffs, self.coeffs_shape = self.flatten(coeffs_dict)
        return coeffs
def main(out_file, bins, signal_events, background_events,
         max_transient_events, time_steps, cmap):
    '''
    Use a toy model to create a transient appearing in the FoV of another source.
    A steady background is subtracted and denoised using wavelets.
    This script then creates an animated gif of the whole shebang saved under the
    OUT_FILE argument.
    '''
    # The same [bins, bins] list is used for both simulations.
    square_bins = [bins, bins]
    cube_steady = simulate_steady_source(
        num_slices=time_steps,
        source_count=signal_events,
        background_count=background_events,
        bins=square_bins,
    )

    def time_dependency():
        return transient_gaussian(time_steps=time_steps,
                                  max_events=max_transient_events)

    cube_with_transient = simulate_steady_source_with_transient(
        time_dependency,
        source_count=signal_events,
        background_count=background_events,
        bins=square_bins,
    )

    # Subtract the mean of the steady cube along axis 0 as the noise level.
    residual = cube_with_transient - cube_steady.mean(axis=0)
    wavelet = 'bior1.3'
    coeffs = pywt.swtn(data=residual, wavelet=wavelet, level=2)

    # Threshold away noisy coefficients, then reconstruct.
    thresholded = thresholding_3d(coeffs, k=30)
    cube_smoothed = pywt.iswtn(coeffs=thresholded, wavelet=wavelet)

    # Max over axes 1 and 2 (for a 3-D cube): a possible trigger criterion.
    trans_factor = cube_smoothed.max(axis=1).max(axis=1)

    plotter = TransientPlotter(
        cube_with_transient,
        cube_smoothed,
        trans_factor,
        cmap=cmap,
    )

    print('Plotting animation. (Be patient)')
    anim = animation.FuncAnimation(
        plotter.fig,
        plotter.step,
        frames=time_steps,
        interval=15,
        blit=True,
    )

    anim.save(out_file, writer='imagemagick', fps=25)
Example #14
0
def _swt3(inputImage, wavelet='coif1', level=1, start_level=0, axes=(2, 1, 0)):
    """Stationary wavelet transform of a 3-D SimpleITK image.

    The first ``start_level`` levels are computed but only their
    approximation is kept. Then ``level`` single-level transforms are
    applied along ``axes``. Odd-length axes are wrap-padded by one sample
    before transforming and cropped again, and every returned image copies
    ``inputImage``'s information.

    Parameters
    ----------
    inputImage : SimpleITK.Image
        3-D image to transform.
    wavelet : str or pywt.Wavelet
        Wavelet to use (converted with ``pywt.Wavelet`` if not one already).
    level : int
        Number of levels to return.
    start_level : int
        Number of leading levels to skip.
    axes : sequence of int
        Array axes to transform.

    Returns
    -------
    approximation : SimpleITK.Image
        Approximation band after the last computed level.
    ret : list of dict
        One dict per returned level mapping every band name ('a' -> 'L',
        'd' -> 'H'; including that level's approximation band) to an image.

    Raises
    ------
    ValueError
        If the image array is not 3-D.
    """
    matrix = numpy.asarray(sitk.GetArrayFromImage(inputImage))
    if matrix.ndim != 3:
        raise ValueError('Expected 3D data array')

    original_shape = matrix.shape
    # One SWT level needs even lengths: wrap-pad odd axes by one sample.
    padding = tuple((0, 1 if dim % 2 != 0 else 0) for dim in original_shape)
    data = numpy.pad(matrix, padding, 'wrap')  # returns a new array
    # Index that drops the padding again. It must be a tuple: indexing with
    # a *list* of slices was deprecated in NumPy 1.15 and is an error in
    # current NumPy releases.
    crop = tuple(slice(None, -1 if dim % 2 != 0 else None)
                 for dim in original_shape)

    if not isinstance(wavelet, pywt.Wavelet):
        wavelet = pywt.Wavelet(wavelet)

    approx_key = 'a' * len(axes)

    for _ in range(start_level):
        dec = pywt.swtn(data, wavelet, level=1, start_level=0, axes=axes)[0]
        data = dec[approx_key].copy()

    ret = []
    for _ in range(level):
        dec = pywt.swtn(data, wavelet, level=1, start_level=0, axes=axes)[0]
        data = dec[approx_key].copy()

        dec_im = {}
        for dec_name, dec_image in dec.items():
            sitk_image = sitk.GetImageFromArray(dec_image.copy()[crop])
            sitk_image.CopyInformation(inputImage)
            dec_im[str(dec_name).replace('a', 'L').replace('d', 'H')] = sitk_image

        ret.append(dec_im)

    approximation = sitk.GetImageFromArray(data[crop])
    approximation.CopyInformation(inputImage)

    return approximation, ret
Example #15
0
def test_swtn_iswtn_unique_shape_per_axis():
    """swtn/iswtn round trip for every axis ordering of a (1, 48, 32) array.

    Regression test for gh-460: each axis has a distinct length, and only
    the non-singleton axes are transformed.
    """
    rng = np.random.RandomState(0)
    wav = 'sym2'
    max_level = 3
    for shape in permutations((1, 48, 32)):
        axes = [axis for axis, size in enumerate(shape) if size != 1]
        x = rng.standard_normal(shape)
        coeffs = pywt.swtn(x, wav, max_level, axes=axes)
        reconstructed = pywt.iswtn(coeffs, wav, axes=axes)
        assert_allclose(x, reconstructed, rtol=1e-10, atol=1e-10)
Example #16
0
 def initialize_wl_operators(self):
     """Return the forward/inverse wavelet operator pair ``(H, Ht)``.

     * decimated (``self.use_decimated``): ``pywt.wavedecn`` / ``pywt.waverecn``
     * else if ``use_swtn``: ``pywt.swtn`` / ``pywt.iswtn``
     * else: ``pywt.swt2`` on ``np.squeeze(x)`` / ``pywt.iswt2`` with a new
       leading axis added to its result.

     The lambdas read ``self`` attributes each time they are called, not
     when they are built.
     """
     # NOTE(review): `use_swtn` is not `self.use_swtn`; inside a method it
     # resolves as a module-level (global) name -- confirm that is intended.
     # NOTE(review): the iswt2 inverse does not pass `axes=self.axes`, unlike
     # the swt2 forward -- confirm.
     if self.use_decimated:
         forward = lambda x: pywt.wavedecn(x, wavelet=self.wl_type, axes=self.axes, level=self.decomp_lvl)
         inverse = lambda x: pywt.waverecn(x, wavelet=self.wl_type, axes=self.axes)
     elif use_swtn:
         forward = lambda x: pywt.swtn(x, wavelet=self.wl_type, axes=self.axes, level=self.decomp_lvl)
         inverse = lambda x: pywt.iswtn(x, wavelet=self.wl_type, axes=self.axes)
     else:
         forward = lambda x: pywt.swt2(np.squeeze(x), wavelet=self.wl_type, axes=self.axes, level=self.decomp_lvl)
         inverse = lambda x: pywt.iswt2(x, wavelet=self.wl_type)[np.newaxis, ...]
     return forward, inverse
Example #17
0
def test_swtn_iswtn_unique_shape_per_axis():
    """swtn/iswtn round trip for every axis ordering of a (1, 48, 32) array.

    Regression test for gh-460: each axis has a distinct length, and only
    the non-singleton axes are transformed.
    """
    rng = np.random.RandomState(0)
    wav = 'sym2'
    max_level = 3
    for shape in permutations((1, 48, 32)):
        axes = [axis for axis, size in enumerate(shape) if size != 1]
        x = rng.standard_normal(shape)
        coeffs = pywt.swtn(x, wav, max_level, axes=axes)
        reconstructed = pywt.iswtn(coeffs, wav, axes=axes)
        assert_allclose(x, reconstructed, rtol=1e-10, atol=1e-10)
Example #18
0
def test_iswtn_errors():
    """iswtn rejects bad axes and mismatched coefficient shapes."""
    w = pywt.Wavelet('db1')
    axes = (0, 1)
    max_level = 2
    x = np.arange(8**3).reshape(8, 8, 8)
    coeffs = pywt.swtn(x, w, max_level, axes=axes)

    # more axes than were transformed -> ValueError
    assert_raises(ValueError, pywt.iswtn, coeffs, w, axes=(0, 1, 2))
    # an axis listed twice -> ValueError
    assert_raises(ValueError, pywt.iswtn, coeffs, w, axes=(0, 0))
    # shorten one detail band of the first dict -> RuntimeError
    first_level = coeffs[0]
    first_level['da'] = first_level['da'][:-1, :]
    assert_raises(RuntimeError, pywt.iswtn, coeffs, w, axes=axes)
Example #19
0
def test_iswtn_errors():
    """iswtn rejects bad axes and mismatched coefficient shapes."""
    w = pywt.Wavelet('db1')
    axes = (0, 1)
    max_level = 2
    x = np.arange(8**3).reshape(8, 8, 8)
    coeffs = pywt.swtn(x, w, max_level, axes=axes)

    # more axes than were transformed -> ValueError
    assert_raises(ValueError, pywt.iswtn, coeffs, w, axes=(0, 1, 2))
    # an axis listed twice -> ValueError
    assert_raises(ValueError, pywt.iswtn, coeffs, w, axes=(0, 0))
    # shorten one detail band of the first dict -> RuntimeError
    first_level = coeffs[0]
    first_level['da'] = first_level['da'][:-1, :]
    assert_raises(RuntimeError, pywt.iswtn, coeffs, w, axes=axes)
Example #20
0
def fuse_stationary_wavelets(first,
                             second,
                             *,
                             levels=None,
                             pca=False,
                             wavelet=None):
    """Fuse two images in the stationary-wavelet domain.

    Both inputs are padded (``swt_pad_funcs``) and decomposed with a
    normalised ``pywt.swtn`` over axes (0, 1), using ``trim_approx=True``.
    The approximation arrays are averaged. At each level, positions where
    ``_coeff_strength`` of ``second``'s detail bands exceeds that of
    ``first``'s take every detail band from ``second``. The inverse
    transform is then un-padded and returned.

    ``levels`` defaults to 6 and ``wavelet`` to ``'sym4'`` when None.
    """
    levels = 6 if levels is None else levels
    wavelet = 'sym4' if wavelet is None else wavelet

    pad, unpad = swt_pad_funcs(first.shape, levels)
    padded_first = pad(first)
    padded_second = pad(second)

    swt_kwargs = dict(level=levels, axes=(0, 1), norm=True, trim_approx=True)
    # Element 0 is the approximation array; the rest are per-level dicts.
    fused = pywt.swtn(padded_first, wavelet, **swt_kwargs)
    other = pywt.swtn(padded_second, wavelet, **swt_kwargs)
    del padded_first, padded_second

    fused[0] = (fused[0] + other[0]) / 2
    for level_a, level_b in zip(fused[1:], other[1:]):
        take_b = (_coeff_strength(level_b.values(), pca)
                  > _coeff_strength(level_a.values(), pca))
        for key, band in level_a.items():
            band[take_b, ...] = level_b[key][take_b, ...]
    del other

    return unpad(pywt.iswtn(fused, wavelet, axes=(0, 1), norm=True))
Example #21
0
def test_swtn_axes():
    """Check swtn's axes handling and its input validation."""
    atol = 1e-14
    current_wavelet = pywt.Wavelet('db2')
    # smallest power-of-two length that is >= both filter lengths
    input_length_power = int(np.ceil(np.log2(max(
        current_wavelet.dec_len,
        current_wavelet.rec_len))))
    input_length = 2**(input_length_power)
    X = np.arange(input_length**2).reshape(input_length, input_length)
    coeffs = pywt.swtn(X, current_wavelet, level=1, axes=None)[0]
    # opposite order: the mixed 'ad'/'da' bands swap names
    coeffs2 = pywt.swtn(X, current_wavelet, level=1, axes=(1, 0))[0]
    assert_allclose(coeffs['aa'], coeffs2['aa'], atol=atol)
    assert_allclose(coeffs['ad'], coeffs2['da'], atol=atol)
    assert_allclose(coeffs['da'], coeffs2['ad'], atol=atol)
    assert_allclose(coeffs['dd'], coeffs2['dd'], atol=atol)

    # 0-level transform
    empty = pywt.swtn(X, current_wavelet, level=0)
    assert_equal(empty, [])

    # duplicate axes not allowed
    assert_raises(ValueError, pywt.swtn, X, current_wavelet, 1, axes=(0, 0))

    # empty input: np.asarray([]) is 1-D with zero size (not 0-D)
    assert_raises(ValueError, pywt.swtn, np.asarray([]), current_wavelet, 1)

    # start_level too large
    assert_raises(ValueError, pywt.swtn, X, current_wavelet,
                  level=1, start_level=2)

    # level < 1 in swt_axis call
    assert_raises(ValueError, swt_axis, X, current_wavelet, level=0,
                  start_level=0)
    # odd-sized data not allowed. Use a valid level (1). With level=0 the
    # "level < 1" error checked above could satisfy this assertion, and the
    # odd length along axis 0 would go untested.
    assert_raises(ValueError, swt_axis, X[:-1, :], current_wavelet, level=1,
                  start_level=0, axis=0)
Example #22
0
def test_swtn_iswtn_integration(wavelets=None):
    # This function performs a round-trip swtn/iswtn transform for various
    # possible combinations of:
    #   1.) 1 out of 2 axes of a 2D array
    #   2.) 2 out of 3 axes of a 3D array
    # each with norm in {True, False} (norm=True only for orthogonal
    # wavelets) and trim_approx in {True, False}.
    #
    # To keep test time down, only wavelets of length <= 8 are run.
    #
    # This test does not validate swtn or iswtn individually, but only
    # confirms that iswtn yields an (almost) perfect reconstruction of swtn
    # and that iswtn leaves its input coefficients unmodified.
    max_level = 3
    if wavelets is None:
        wavelets = pywt.wavelist(kind='discrete')
        if 'dmey' in wavelets:
            # The 'dmey' wavelet is a special case - disregard it for now
            wavelets.remove('dmey')
    for ndim_transform in range(1, 3):
        ndim = ndim_transform + 1
        for axes in combinations(range(ndim), ndim_transform):
            for current_wavelet_str in wavelets:
                wav = pywt.Wavelet(current_wavelet_str)
                if wav.dec_len > 8:
                    continue  # avoid excessive test duration
                input_length_power = int(
                    np.ceil(np.log2(max(wav.dec_len, wav.rec_len))))
                N = 2**(input_length_power + max_level - 1)
                X = np.arange(N**ndim).reshape((N, ) * ndim)

                for norm in [True, False]:
                    if norm and not wav.orthogonal:
                        # norm=True with a non-orthogonal wavelet only
                        # emits a warning; skip that combination
                        continue
                    for trim_approx in [True, False]:
                        coeffs = pywt.swtn(X,
                                           wav,
                                           max_level,
                                           axes=axes,
                                           trim_approx=trim_approx,
                                           norm=norm)
                        coeffs_copy = deepcopy(coeffs)
                        Y = pywt.iswtn(coeffs, wav, axes=axes, norm=norm)
                        assert_allclose(Y, X, rtol=1e-5, atol=1e-5)

                        # Verify the inverse transform didn't modify any
                        # coeffs. Done for every combination, not only the
                        # last one. With trim_approx=True the first entry
                        # is a bare approximation array, not a dict.
                        for c, c2 in zip(coeffs, coeffs_copy):
                            if isinstance(c, dict):
                                for k, v in c.items():
                                    assert_array_equal(c2[k], v)
                            else:
                                assert_array_equal(c2, c)
Example #23
0
def test_swtn_variance_and_energy_preservation():
    """Verify that the nD SWT partitions variance among the coefficients."""
    # With norm=True and an orthogonal wavelet, the coefficient variances
    # sum to the signal variance and the L2 norm is preserved.
    x = np.random.RandomState(5).randn(64, 64)
    coeffs = pywt.swtn(x, 'db2', level=4, trim_approx=True, norm=True)

    # trim_approx=True: coeffs[0] is the approximation array and the rest
    # are dicts of detail bands.
    flat = [coeffs[0].ravel()]
    flat.extend(band.ravel() for level in coeffs[1:] for band in level.values())

    assert_allclose(np.sum([np.var(v) for v in flat]), np.var(x))

    # L2-norm (energy) preservation
    assert_allclose(np.linalg.norm(x), np.linalg.norm(np.concatenate(flat)))

    # a non-orthogonal wavelet with norm=True warns
    assert_warns(UserWarning, pywt.swtn, x, 'bior2.2', level=4, norm=True)
Example #24
0
    def direct_swt(self, x):
        """Perform the direct (forward) stationary wavelet transform.

        If ``self.pad_on_demand`` is not None and any ``self.pad_axes``
        entry is nonzero, ``x`` is first padded along ``self.axes[ax]`` for
        each such entry ``ax``. Padding adds ceil(p/2) samples before and
        floor(p/2) after, using ``self.pad_on_demand`` as the ``np.pad`` mode.

        :param x: Data to transform.
        :type x: `numpy.array_like`

        :return: Transformed data (``pywt.swtn`` output, ``trim_approx=True``).
        :rtype: list
        """
        needs_padding = self.pad_on_demand is not None and np.any(self.pad_axes)
        if needs_padding:
            for ax in np.nonzero(self.pad_axes)[0]:
                total = self.pad_axes[ax]
                before = np.ceil(total / 2).astype(np.intp)
                after = np.floor(total / 2).astype(np.intp)
                widths = [(0, 0)] * len(x.shape)
                widths[self.axes[ax]] = (before, after)
                x = np.pad(x, widths, mode=self.pad_on_demand)

        return pywt.swtn(
            x,
            wavelet=self.wavelet,
            axes=self.axes,
            norm=self.normalized,
            level=self.level,
            trim_approx=True,
        )
Example #25
0
def test_swtn_iswtn_integration(wavelets=None):
    """Round-trip swtn/iswtn over partial-axis transforms.

    Covers every choice of 1 out of 2 axes of a 2-D array and 2 out of 3
    axes of a 3-D array. Every wavelet with ``dec_len <= 8`` is used, to
    keep the run short; 'dmey' is dropped from the default list.

    This does not validate swtn or iswtn individually. It only checks that
    iswtn(swtn(X)) ~= X and that iswtn leaves its input coefficients
    unmodified.
    """
    max_level = 3
    if wavelets is None:
        wavelets = pywt.wavelist(kind='discrete')
        if 'dmey' in wavelets:
            # The 'dmey' wavelet is a special case - disregard it for now
            wavelets.remove('dmey')

    for ndim in (2, 3):
        n_axes = ndim - 1
        for axes in combinations(range(ndim), n_axes):
            for name in wavelets:
                wav = pywt.Wavelet(name)
                if wav.dec_len > 8:
                    continue  # avoid excessive test duration
                longest_filter = max(wav.dec_len, wav.rec_len)
                N = 2**(int(np.ceil(np.log2(longest_filter))) + max_level - 1)
                X = np.arange(N**ndim).reshape((N, ) * ndim)

                coeffs = pywt.swtn(X, wav, max_level, axes=axes)
                coeffs_copy = deepcopy(coeffs)
                Y = pywt.iswtn(coeffs, wav, axes=axes)
                assert_allclose(Y, X, rtol=1e-5, atol=1e-5)

                # iswtn must not have modified its input coefficients
                for level, level_copy in zip(coeffs, coeffs_copy):
                    for key, band in level.items():
                        assert_array_equal(level_copy[key], band)
    def denoise_and_compare_cubes(self, steady_cube, cube_with_transient):
        """Denoise a background-subtracted cube, save an animation of it,
        and return a trigger criterion.

        The mean of ``steady_cube`` along axis 0 is subtracted from
        ``cube_with_transient``. The result goes through a 2-level
        'bior1.3' ``pywt.swtn``, is thresholded with
        ``thresholding_3d(k=30)``, and is reconstructed. The animation is
        saved to ``build/anim_<id>.gif``, where ``<id>`` is element 0 of
        ``self.window.popleft()``.

        Returns
        -------
        trans_factor
            Maximum of the denoised cube over axes 1 and 2 (for a 3-D cube).
        """
        wavelet = 'bior1.3'
        residual = cube_with_transient - steady_cube.mean(axis=0)
        coeffs = pywt.swtn(data=residual, wavelet=wavelet, level=2)

        # remove noisy coefficients, then reconstruct
        thresholded = thresholding_3d(coeffs, k=30)
        smoothed = pywt.iswtn(coeffs=thresholded, wavelet=wavelet)

        # a criterion which could be used to trigger on a transient
        trans_factor = smoothed.max(axis=1).max(axis=1)

        plotter = TransientPlotter(
            cube_with_transient,
            smoothed,
            trans_factor,
            cmap='viridis',
        )

        print('Plotting animation. (Be patient)')
        anim = animation.FuncAnimation(
            plotter.fig,
            plotter.step,
            frames=len(residual),
            interval=15,
            blit=True,
        )

        window_id = self.window.popleft()[0]
        anim.save('build/anim_{}.gif'.format(window_id),
                  writer='imagemagick',
                  fps=25)

        return trans_factor
# Script: load an image, take its Y (luma) channel and apply a 3-level
# 'db1' stationary wavelet transform.
# NOTE(review): the input path is hard-coded to one machine.
imgPath = r'C:\Users\rites\Desktop\ProjectSubmit\Input\001_F.png'
print('Path Imported')

I = Image.open(imgPath)
print('Image Imported')

b = I.convert('YCbCr')
print('Image Converted to YCbCr')

y, cb, cr = b.split()

# swtn returns one dict of bands per level; three levels are unpacked here.
one, two, three = pywt.swtn(y, 'db1', 3)
print('SWT applied Successfully')

LL, LH, HL, HH = (three[key] for key in ('aa', 'ad', 'da', 'dd'))
row, column = LL.shape
print('Got LL(approximation) Band')

C = []

# Disabled block-list initialisation kept from the original:
# Block = np.zeros(8*8).reshape(8,8)
# for i in range (1,(row-7)*(column-7)+1):
#    C.append(Block)
Example #28
0
 def time_swtn(self, D, n, wavelet, dtype):
     """Benchmark a single forward ``pywt.swtn`` call on ``self.data``.

     ``D``, ``n`` and ``dtype`` are not used here -- presumably they are
     benchmark parameters consumed by a setup method; confirm.
     """
     pywt.swtn(self.data, wavelet, level=self.level)
Example #29
0
def fuse_focal_stack_kmax(images,
                          *,
                          k=None,
                          levels=None,
                          wavelet=None,
                          pca=None,
                          in_memory=None,
                          sharpness_sigma=None):
    """Fuse a focal stack in the stationary-wavelet domain.

    Each image is padded and transformed with ``pywt.swtn`` over axes 0 and 1
    (``norm=True``, ``trim_approx=True``). The coefficients are then merged
    per pixel:

    * only images whose sharpness at that pixel equals one of the values in
      ``kmax_sharpnesses(sharpnesses, k)`` contribute (presumably the ``k``
      largest sharpnesses per pixel -- confirm against that helper);
    * the approximation coefficients of contributing images are summed and
      divided by ``k``;
    * for every level, the (ad, da, dd) detail triple whose summed squared
      magnitude (after ``reduce_color_dimension``) is largest is kept.

    The fused coefficients are inverted with ``pywt.iswtn`` and unpadded.

    :param images: Indexable, sized sequence of images; padding is derived
        from ``images[0].shape`` and presumably applies to all of them.
    :param k: Number of images fused per pixel, default 3, capped at
        ``len(images)``.
    :param levels: Number of SWT levels, default 3.
    :param wavelet: Wavelet passed to pywt, default ``'sym4'``.
    :param pca: Forwarded to ``sharpness``, default True.
    :param in_memory: Forwarded to ``temporary_array_list`` for the
        sharpness maps, default False.
    :param sharpness_sigma: Forwarded to ``sharpness``, default 3.
    :returns: The fused image after inverse SWT and unpadding.
    """
    if k is None:
        k = 3
    if levels is None:
        levels = 3
    if wavelet is None:
        wavelet = 'sym4'
    if pca is None:
        pca = True
    if in_memory is None:
        in_memory = False
    if sharpness_sigma is None:
        sharpness_sigma = 3

    k = min(k, len(images))

    pad, unpad = swt_pad_funcs(images[0].shape, levels)
    sharpnesses = temporary_array_list(
        (sharpness(pad(image), sharpness_sigma, pca=pca) for image in images),
        in_memory=in_memory)
    kmax = kmax_sharpnesses(sharpnesses, k)

    bases = temporary_array_list()
    ads = [temporary_array_list() for _ in range(levels)]
    das = [temporary_array_list() for _ in range(levels)]
    dds = [temporary_array_list() for _ in range(levels)]

    for image in images:
        image = pad(image)
        shape = image.shape
        coeffs = pywt.swtn(image,
                           wavelet,
                           level=levels,
                           axes=(0, 1),
                           norm=True,
                           trim_approx=True)
        # coeffs[0] is the approximation; coeffs[1:] hold one detail dict per
        # level, in the order they are stored and later rebuilt below.
        bases.append(coeffs[0])
        for level in range(levels):
            cl = coeffs[level + 1]
            ads[level].append(cl['ad'])
            das[level].append(cl['da'])
            dds[level].append(cl['dd'])

    # Zero-initialised (not np.empty_like): the energy comparison below
    # squares the whole fused arrays, and any pixel never selected would
    # otherwise return uninitialised memory.
    base = np.zeros_like(bases[0])
    ad = [np.zeros_like(ads[level][0]) for level in range(levels)]
    da = [np.zeros_like(das[level][0]) for level in range(levels)]
    dd = [np.zeros_like(dds[level][0]) for level in range(levels)]
    # True where no contributing image has been written yet.
    first = np.full(shape[:2], True)

    for i, s in enumerate(sharpnesses):
        mask = np.any(s[:, :, np.newaxis] == kmax, axis=2)
        # The first contributor at a pixel initialises it; later contributors
        # are accumulated (base) or compared against it (details).
        mask_and_first = mask & first
        mask_and_not_first = mask & ~first

        # Read each stored coefficient array once per image.
        base_i = bases[i]
        base[mask_and_first, ...] = base_i[mask_and_first, ...]
        base[mask_and_not_first, ...] += base_i[mask_and_not_first, ...]
        del base_i

        for level in range(levels):
            ad_i = ads[level][i]
            da_i = das[level][i]
            dd_i = dds[level][i]
            ad[level][mask_and_first, ...] = ad_i[mask_and_first, ...]
            da[level][mask_and_first, ...] = da_i[mask_and_first, ...]
            dd[level][mask_and_first, ...] = dd_i[mask_and_first, ...]
            # Replace the fused detail triple only where this image's triple
            # carries more energy, and only at already-initialised pixels.
            cmask = (
                (reduce_color_dimension(ad_i**2) +
                 reduce_color_dimension(da_i**2) +
                 reduce_color_dimension(dd_i**2)) >
                (reduce_color_dimension(ad[level]**2) +
                 reduce_color_dimension(da[level]**2) +
                 reduce_color_dimension(dd[level]**2))) & mask_and_not_first
            ad[level][cmask, ...] = ad_i[cmask, ...]
            da[level][cmask, ...] = da_i[cmask, ...]
            dd[level][cmask, ...] = dd_i[cmask, ...]
        first[mask] = False

    base /= k
    coeffs = [base] + [
        dict(ad=ad[level], da=da[level], dd=dd[level])
        for level in range(levels)
    ]
    return unpad(pywt.iswtn(coeffs, wavelet, axes=(0, 1), norm=True))
def main(gamma_file, proton_file, output_file):
    """Inject a simulated transient into event cubes, denoise it and animate.

    Gamma and proton events are read with ``loadFile``. Events with
    ``'prediction:signal:mean' > 0.87`` determine how many proton events are
    resampled as background. A steady cube and a cube with a bright
    transient appended along axis 0 are built with ``create_cube``. The
    steady mean is subtracted, and the result is smoothed with a 2-level
    'bior1.3' SWT plus ``thresholding_3d(..., k=30)``. Raw and smoothed
    cubes are then animated.

    :param gamma_file: Passed to ``loadFile`` for the gamma events.
    :param proton_file: Passed to ``loadFile`` for the proton events.
    :param output_file: Destination of the animation GIF; ``'anim.gif'`` is
        used when it is empty or None.
    """
    bins = [80, 80]
    bin_range = [[62.5, 78.5], [-12.4, 12.4]]

    df_gammas = loadFile(gamma_file)
    df_protons = loadFile(proton_file)

    print('Read {} gammas and {} protons'.format(len(df_gammas),
                                                 len(df_protons)))
    factor = (10E5 * len(df_gammas)) / len(df_protons)

    print(factor)

    df_background = df_protons[df_protons['prediction:signal:mean'] > 0.87]
    df_signal = df_gammas[df_gammas['prediction:signal:mean'] > 0.87]

    print('Read {} signal events and {} background events'.format(
        len(df_signal), len(df_background)))
    ratio = len(df_background) / len(df_signal)
    expected_background = int(ratio * len(df_background) * factor)
    print('Upsampling background to get {} events'.format(expected_background))
    df_background = df_protons.sample(expected_background, replace=True)

    cube_background = create_cube(df_background.sample(frac=0.5),
                                  bins=bins,
                                  bin_range=bin_range)
    cube_gammas = create_cube(df_gammas.sample(frac=0.5),
                              bins=bins,
                              bin_range=bin_range)

    cube_steady = cube_background + cube_gammas

    cube_background = create_cube(df_background.sample(frac=0.5),
                                  bins=bins,
                                  bin_range=bin_range)
    cube_gammas = create_cube(df_gammas.sample(frac=0.5),
                              bins=bins,
                              bin_range=bin_range)
    cube_bright_gammas = create_cube(df_gammas, bins=bins, bin_range=bin_range)

    # Normal cube followed by a cube with all gammas (the "transient").
    cube_with_transient = np.vstack(
        (cube_background + cube_gammas, cube_background + cube_bright_gammas))

    # remove mean measured noise from current cube
    cube = cube_with_transient - cube_steady.mean(axis=0)
    coeffs = pywt.swtn(
        data=cube,
        wavelet='bior1.3',
        level=2,
    )

    # remove noisy coefficients.
    ct = thresholding_3d(coeffs, k=30)
    cube_smoothed = pywt.iswtn(coeffs=ct, wavelet='bior1.3')

    # some Criterion which could be used to trigger this.
    trans_factor = cube_smoothed.max(axis=1).max(axis=1)

    p = TransientPlotter(
        cube_with_transient,
        cube_smoothed,
        trans_factor,
        cmap='viridis',
    )

    print('Plotting animation. (Be patient)')
    anim = animation.FuncAnimation(
        p.fig,
        p.step,
        frames=len(cube),
        interval=15,
        blit=True,
    )

    # Honour the caller-supplied destination instead of a fixed file name.
    anim.save(output_file or 'anim.gif', writer='imagemagick', fps=25)
Example #31
0
def compute_depth_map(
    depth_cues,
    iterations=500,
    lambda_tv=2.0,
    lambda_d2=0.05,
    lambda_wl=None,
    use_defocus=1.0,
    use_correspondence=1.0,
    use_xcorrelation=0.0,
):
    """Computes a depth map from the given depth cues.

    This depth map is based on the procedure from:

    M. W. Tao, et al., "Depth from combining defocus and correspondence using
    light-field cameras," in Proceedings of the IEEE International Conference on
    Computer Vision, 2013, pp. 673–680.

    The map is computed iteratively. Each enabled term has a dual variable:
    the TV gradient term, the Laplacian term, one data term per cue and the
    optional wavelet term. Every iteration adds the current estimate's
    contribution to each dual variable and rescales it so that its
    magnitude does not exceed 1. The depth is then moved against the
    combined update, scaled by a per-pixel step size ``tau``.

    :param depth_cues: The depth cues; must provide the keys
        ``confidence_/depth_`` + ``defocus``, ``correspondence`` and
        ``xcorrelation`` as arrays (empty arrays disable a cue)
    :type depth_cues: dict
    :param iterations: Number of iterations, defaults to 500
    :type iterations: int, optional
    :param lambda_tv: Lambda value of the TV term, defaults to 2.0
    :type lambda_tv: float, optional
    :param lambda_d2: Lambda value of the smoothing term, defaults to 0.05
    :type lambda_d2: float, optional
    :param lambda_wl: Lambda value of the wavelet term, defaults to None
    :type lambda_wl: float, optional
    :param use_defocus: Weight of defocus cues, clamped to [0, 1], defaults to 1.0
    :type use_defocus: float, optional
    :param use_correspondence: Weight of correspondence cues, clamped to [0, 1], defaults to 1.0
    :type use_correspondence: float, optional
    :param use_xcorrelation: Weight of the cross-correlation cues, clamped to [0, 1], defaults to 0.0
    :type use_xcorrelation: float, optional

    :raises ValueError: In case of requested wavelet regularization but not
        available, or if none of the cues can be used

    :returns: The depth map
    :rtype: `numpy.array_like`
    """
    if not (lambda_wl is None or (has_pywt and use_swtn)):
        raise ValueError("Wavelet regularization requested but not available")

    def _clamp01(value):
        # np.fmax/np.fmin ignore NaN, so a NaN weight becomes 0 rather than
        # propagating (np.clip would keep the NaN).
        return np.fmin(np.fmax(value, 0.0), 1.0)

    use_defocus = _clamp01(use_defocus)
    use_correspondence = _clamp01(use_correspondence)
    use_xcorrelation = _clamp01(use_xcorrelation)

    W_d = depth_cues["confidence_defocus"]
    a_d = depth_cues["depth_defocus"]

    W_c = depth_cues["confidence_correspondence"]
    a_c = depth_cues["depth_correspondence"]

    W_x = depth_cues["confidence_xcorrelation"]
    a_x = depth_cues["depth_xcorrelation"]

    # Disable any requested cue whose confidence or depth array is empty.
    if use_defocus > 0 and (W_d.size == 0 or a_d.size == 0):
        use_defocus = 0
        warnings.warn("Defocusing parameters were not passed, disabling their use")

    if use_correspondence > 0 and (W_c.size == 0 or a_c.size == 0):
        use_correspondence = 0
        warnings.warn("Correspondence parameters were not passed, disabling their use")

    if use_xcorrelation > 0 and (W_x.size == 0 or a_x.size == 0):
        use_xcorrelation = 0
        warnings.warn("Cross-correlation parameters were not passed, disabling their use")

    # Output shape and dtype follow the first enabled cue.
    if use_defocus:
        img_size = a_d.shape
        data_type = a_d.dtype
    elif use_correspondence:
        img_size = a_c.shape
        data_type = a_c.dtype
    elif use_xcorrelation:
        img_size = a_x.shape
        data_type = a_x.dtype
    else:
        raise ValueError(
            "Cannot proceed: none of the Defocus, Correspondence, and "
            "Cross-correlation cues can be used")

    # Wavelet availability was already enforced by the check at the top, so
    # lambda_wl is only non-None here when pywt is usable.

    depth = np.zeros(img_size, dtype=data_type)
    depth_it = depth

    # Dual variables for every active term; tau accumulates the weight of
    # each term and is inverted into a per-pixel step size.
    q_g = np.zeros(np.concatenate(((2,), img_size)), dtype=data_type)
    tau = 4 * lambda_tv
    if lambda_d2 is not None:
        q_l = np.zeros(img_size, dtype=data_type)
        tau += 8 * lambda_d2
    if use_defocus > 0:
        q_d = np.zeros(img_size, dtype=data_type)
        tau += W_d
    if use_correspondence > 0:
        q_c = np.zeros(img_size, dtype=data_type)
        tau += W_c
    if use_xcorrelation > 0:
        q_x = np.zeros(img_size, dtype=data_type)
        tau += W_x
    if lambda_wl is not None:
        wl_type = "sym4"
        wl_lvl = np.fmin(pywt.dwtn_max_level(img_size, wl_type), 2)
        print("Wavelets selected! Wl type: %s, Wl lvl %d" % (wl_type, wl_lvl))
        q_wl = pywt.swtn(depth, wl_type, wl_lvl)
        tau += lambda_wl * (2 ** wl_lvl)
        sigma_wl = 1 / (2 ** np.arange(wl_lvl, 0, -1))
    tau = 1 / tau

    for _ in range(iterations):
        (d0, d1) = _gradient2(depth_it)
        d_2 = np.stack((d0, d1)) / 2
        q_g += d_2
        # Rescale the gradient dual so its per-pixel L2 norm is at most 1.
        grad_l2_norm = np.fmax(1, np.sqrt(np.sum(q_g ** 2, axis=0)))
        q_g /= grad_l2_norm

        update = -lambda_tv * _divergence2(q_g[0, :, :], q_g[1, :, :])
        if lambda_d2 is not None:
            l_dep = _laplacian2(depth_it)
            q_l += l_dep / 8
            q_l /= np.fmax(1, np.abs(q_l))

            update += lambda_d2 * _laplacian2(q_l)

        if use_defocus > 0:
            q_d += depth_it - a_d
            q_d /= np.fmax(1, np.abs(q_d))

            update += use_defocus * W_d * q_d

        if use_correspondence > 0:
            q_c += depth_it - a_c
            q_c /= np.fmax(1, np.abs(q_c))

            update += use_correspondence * W_c * q_c

        if use_xcorrelation > 0:
            q_x += depth_it - a_x
            q_x /= np.fmax(1, np.abs(q_x))

            update += use_xcorrelation * W_x * q_x

        if lambda_wl is not None:
            d = pywt.swtn(depth_it, wl_type, wl_lvl)
            for ii_l in range(wl_lvl):
                for k in q_wl[ii_l]:
                    q_wl[ii_l][k] += d[ii_l][k] * sigma_wl[ii_l]
                    q_wl[ii_l][k] /= np.fmax(1, np.abs(q_wl[ii_l][k]))
            update += lambda_wl * pywt.iswtn(q_wl, wl_type)

        depth_new = depth - update * tau
        # Extrapolated point (2 * new - old) used by the next dual updates.
        depth_it = depth_new + (depth_new - depth)
        depth = depth_new

    return depth