Example 1
def test_per_axis_wavelets():
    # tests a separate wavelet for each axis.
    rstate = np.random.RandomState(1234)
    data = rstate.randn(16, 16, 16)
    level = 3

    # wavelet can be a string or wavelet object
    wavelets = (pywt.Wavelet('haar'), 'sym2', 'db4')

    coefs = pywt.swtn(data, wavelets, level=level)
    assert_allclose(pywt.iswtn(coefs, wavelets), data, atol=1e-14)

    # 1-tuple also okay
    coefs = pywt.swtn(data, wavelets[:1], level=level)
    assert_allclose(pywt.iswtn(coefs, wavelets[:1]), data, atol=1e-14)

    # length of wavelets doesn't match the length of axes
    assert_raises(ValueError, pywt.swtn, data, wavelets[:2], level)
    assert_raises(ValueError, pywt.iswtn, coefs, wavelets[:2])

    with warnings.catch_warnings():
        warnings.simplefilter('ignore', FutureWarning)
        # swt2/iswt2 also support per-axis wavelets/modes
        data2 = data[..., 0]
        coefs2 = pywt.swt2(data2, wavelets[:2], level)
        assert_allclose(pywt.iswt2(coefs2, wavelets[:2]), data2, atol=1e-14)
Example 2
    def adj_op(self, coeffs):
        """
        Define the wavelet adjoint operator.
        This method returns the reconstructed image.

        Parameters
        ----------
        coeffs: np.ndarray
            the wavelet coefficients.

        Returns
        -------
        data: np.ndarray((m, n)) or np.ndarray((m, n, p))
            the 2D or 3D reconstructed data.
        """
        self.coeffs = coeffs
        if self.undecimated:
            coeffs_dict = self.unflatten(coeffs, self.coeffs_shape)
            data = pywt.iswtn(coeffs_dict, self.pywt_transform)
        else:
            coeffs_dict = self.unflatten(coeffs, self.coeffs_shape)
            data = pywt.waverecn(coeffs=coeffs_dict,
                                 wavelet=self.pywt_transform,
                                 mode=self.mode)
        return data
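The flatten/unflatten helpers referenced above belong to the surrounding class and are not shown here. The sketch below is a minimal, self-contained illustration of the same idea (the helper names and layout are assumptions, not the class's actual implementation): the list of per-level swtn coefficient dicts is raveled into a single vector and rebuilt before calling pywt.iswtn. It assumes the default trim_approx=False layout.

import numpy as np
import pywt

def flatten_swtn(coeffs):
    # Concatenate every coefficient array into one 1D vector, remembering
    # the per-level keys and shapes so the layout can be rebuilt later.
    flat, structure = [], []
    for level in coeffs:
        keys = sorted(level)
        structure.append([(k, level[k].shape) for k in keys])
        flat.extend(level[k].ravel() for k in keys)
    return np.concatenate(flat), structure

def unflatten_swtn(vector, structure):
    # Rebuild the list-of-dicts layout expected by pywt.iswtn.
    coeffs, start = [], 0
    for level_struct in structure:
        level = {}
        for key, shape in level_struct:
            size = int(np.prod(shape))
            level[key] = vector[start:start + size].reshape(shape)
            start += size
        coeffs.append(level)
    return coeffs

data = np.random.randn(16, 16)
vec, structure = flatten_swtn(pywt.swtn(data, 'sym4', level=2))
rec = pywt.iswtn(unflatten_swtn(vec, structure), 'sym4')
assert np.allclose(rec, data)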
Example 3
def test_iswtn_mixed_dtypes():
    # Mixed precision inputs give double precision output
    rstate = np.random.RandomState(0)
    x_real = rstate.randn(8, 8, 8)
    x_complex = x_real + 1j*x_real
    wav = 'sym2'
    for dtype1, dtype2 in [(np.float64, np.float32),
                           (np.float32, np.float64),
                           (np.float16, np.float64),
                           (np.complex128, np.complex64),
                           (np.complex64, np.complex128)]:

        if dtype1 in [np.complex64, np.complex128]:
            x = x_complex
            output_dtype = np.complex128
        else:
            x = x_real
            output_dtype = np.float64

        coeffs = pywt.swtn(x, wav, 2)
        # different precision for the approximation coefficients
        a = coeffs[0].pop('a' * x.ndim)
        a = a.astype(dtype1)
        coeffs[0] = {k: c.astype(dtype2) for k, c in coeffs[0].items()}
        coeffs[0]['a' * x.ndim] = a
        y = pywt.iswtn(coeffs, wav)
        assert_equal(output_dtype, y.dtype)
        assert_allclose(y, x, rtol=1e-3, atol=1e-3)
Example 4
def main(out_file, bins, signal_events, background_events,
         max_transient_events, time_steps, cmap):
    '''
    Use a toy model to create a transient appearing in the FoV of another source.
    A steady background is subtracted and denoised using wavelets.
    This script then creates an animated gif of the whole shebang, saved under the
    OUT_FILE argument.
    '''

    bins = [bins, bins]
    cube_steady = simulate_steady_source(
        num_slices=time_steps,
        source_count=signal_events,
        background_count=background_events,
        bins=bins,
    )

    def time_dependency():
        return transient_gaussian(time_steps=time_steps,
                                  max_events=max_transient_events)

    cube_with_transient = simulate_steady_source_with_transient(
        time_dependency,
        source_count=signal_events,
        background_count=background_events,
        bins=bins)

    # remove mean measured noise from current cube
    cube = cube_with_transient - cube_steady.mean(axis=0)
    coeffs = pywt.swtn(
        data=cube,
        wavelet='bior1.3',
        level=2,
    )

    # remove noisy coefficients.
    ct = thresholding_3d(coeffs, k=30)
    cube_smoothed = pywt.iswtn(coeffs=ct, wavelet='bior1.3')

    # some criterion which could be used to trigger this.
    trans_factor = cube_smoothed.max(axis=1).max(axis=1)

    p = TransientPlotter(
        cube_with_transient,
        cube_smoothed,
        trans_factor,
        cmap=cmap,
    )

    print('Plotting animation. (Be patient)')
    anim = animation.FuncAnimation(
        p.fig,
        p.step,
        frames=time_steps,
        interval=15,
        blit=True,
    )

    anim.save(out_file, writer='imagemagick', fps=25)
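thresholding_3d above comes from the surrounding project and its implementation is not shown here. As a rough, hypothetical stand-in, a hard threshold over the swtn detail coefficients could look like the sketch below (the helper name and the thresholding rule are assumptions; only the coefficient layout matches what pywt.swtn/iswtn produce and expect).

import pywt

def hard_threshold_coeffs(coeffs, threshold):
    # Zero out small detail coefficients; keep the approximation band as-is.
    thresholded = []
    for level in coeffs:
        new_level = {}
        for key, arr in level.items():
            if set(key) == {'a'}:
                new_level[key] = arr  # pure approximation band
            else:
                new_level[key] = pywt.threshold(arr, threshold, mode='hard')
        thresholded.append(new_level)
    return thresholded

# e.g. ct = hard_threshold_coeffs(pywt.swtn(cube, 'bior1.3', level=2), threshold=0.5)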
Example 5
def test_swtn_iswtn_unique_shape_per_axis():
    # test case for gh-460
    _shape = (1, 48, 32)  # unique shape per axis
    wav = 'sym2'
    max_level = 3
    rstate = np.random.RandomState(0)
    for shape in permutations(_shape):
        # transform only along the non-singleton axes
        axes = [ax for ax, s in enumerate(shape) if s != 1]
        x = rstate.standard_normal(shape)
        c = pywt.swtn(x, wav, max_level, axes=axes)
        r = pywt.iswtn(c, wav, axes=axes)
        assert_allclose(x, r, rtol=1e-10, atol=1e-10)
Example 6
    def initialize_wl_operators(self):
        if self.use_decimated:
            H = lambda x: pywt.wavedecn(x, wavelet=self.wl_type, axes=self.axes, level=self.decomp_lvl)
            Ht = lambda x: pywt.waverecn(x, wavelet=self.wl_type, axes=self.axes)
        else:
            # use_swtn is assumed to be a module-level flag (defined elsewhere)
            # selecting the n-dimensional stationary transform over the 2D one.
            if use_swtn:
                H = lambda x: pywt.swtn(x, wavelet=self.wl_type, axes=self.axes, level=self.decomp_lvl)
                Ht = lambda x: pywt.iswtn(x, wavelet=self.wl_type, axes=self.axes)
            else:
                H = lambda x: pywt.swt2(np.squeeze(x), wavelet=self.wl_type, axes=self.axes, level=self.decomp_lvl)
                Ht = lambda x: pywt.iswt2(x, wavelet=self.wl_type)[np.newaxis, ...]
        return (H, Ht)
Example 7
def test_swtn_iswtn_integration(wavelets=None):
    # This function performs a round-trip swtn/iswtn transform for various
    # possible combinations of:
    #   1.) 1 out of 2 axes of a 2D array
    #   2.) 2 out of 3 axes of a 3D array
    #
    # To keep test time down, only wavelets of length <= 8 are run.
    #
    # This test does not validate swtn or iswtn individually, but only
    # confirms that iswtn yields an (almost) perfect reconstruction of swtn.
    max_level = 3
    if wavelets is None:
        wavelets = pywt.wavelist(kind='discrete')
        if 'dmey' in wavelets:
            # The 'dmey' wavelet is a special case - disregard it for now
            wavelets.remove('dmey')
    for ndim_transform in range(1, 3):
        ndim = ndim_transform + 1
        for axes in combinations(range(ndim), ndim_transform):
            for current_wavelet_str in wavelets:
                wav = pywt.Wavelet(current_wavelet_str)
                if wav.dec_len > 8:
                    continue  # avoid excessive test duration
                input_length_power = int(
                    np.ceil(np.log2(max(wav.dec_len, wav.rec_len))))
                N = 2**(input_length_power + max_level - 1)
                X = np.arange(N**ndim).reshape((N, ) * ndim)

                for norm in [True, False]:
                    if norm and not wav.orthogonal:
                        # skip non-orthogonal wavelets to avoid warnings
                        continue
                    for trim_approx in [True, False]:
                        coeffs = pywt.swtn(X,
                                           wav,
                                           max_level,
                                           axes=axes,
                                           trim_approx=trim_approx,
                                           norm=norm)
                        coeffs_copy = deepcopy(coeffs)
                        Y = pywt.iswtn(coeffs, wav, axes=axes, norm=norm)
                        assert_allclose(Y, X, rtol=1e-5, atol=1e-5)

                # verify the inverse transform didn't modify any coeffs
                for c, c2 in zip(coeffs, coeffs_copy):
                    for k, v in c.items():
                        assert_array_equal(c2[k], v)
Example 8
    def _synthesis(self, analysis_data, analysis_header):
        """ Reconstruct a real signal from the wavelet coefficients using pywt.

        Parameters
        ----------
        analysis_data: list of nd-array
            the wavelet coefficients array.
        analysis_header: dict
            the wavelet decomposition parameters.

        Returns
        -------
        data: nd-array
            the reconstructed data array.
        """
        coeffs = self._organize_pywt(analysis_data, analysis_header)
        if self.is_decimated:
            data = pywt.waverecn(coeffs, self.trf, mode=self.padding_mode,
                                 axes=self.axes)
        else:
            data = pywt.iswtn(coeffs, self.trf, axes=self.axes)
        return data
Example 9
    def denoise_and_compare_cubes(self, steady_cube, cube_with_transient):
        cube = cube_with_transient - steady_cube.mean(axis=0)
        coeffs = pywt.swtn(
            data=cube,
            wavelet='bior1.3',
            level=2,
        )

        # remove noisy coefficients.
        ct = thresholding_3d(coeffs, k=30)
        cube_smoothed = pywt.iswtn(coeffs=ct, wavelet='bior1.3')

        # some criterion which could be used to trigger this.
        trans_factor = cube_smoothed.max(axis=1).max(axis=1)

        # return trans_factor

        p = TransientPlotter(
            cube_with_transient,
            cube_smoothed,
            trans_factor,
            cmap='viridis',
        )

        print('Plotting animation. (Be patient)')
        anim = animation.FuncAnimation(
            p.fig,
            p.step,
            frames=len(cube),
            interval=15,
            blit=True,
        )

        anim.save('build/anim_{}.gif'.format(self.window.popleft()[0]),
                  writer='imagemagick',
                  fps=25)

        return trans_factor
Example 10
def test_swtn_iswtn_integration(wavelets=None):
    # This function performs a round-trip swtn/iswtn transform for various
    # possible combinations of:
    #   1.) 1 out of 2 axes of a 2D array
    #   2.) 2 out of 3 axes of a 3D array
    #
    # To keep test time down, only wavelets of length <= 8 are run.
    #
    # This test does not validate swtn or iswtn individually, but only
    # confirms that iswtn yields an (almost) perfect reconstruction of swtn.
    max_level = 3
    if wavelets is None:
        wavelets = pywt.wavelist(kind='discrete')
        if 'dmey' in wavelets:
            # The 'dmey' wavelet is a special case - disregard it for now
            wavelets.remove('dmey')
    for ndim_transform in range(1, 3):
        ndim = ndim_transform + 1
        for axes in combinations(range(ndim), ndim_transform):
            for current_wavelet_str in wavelets:
                wav = pywt.Wavelet(current_wavelet_str)
                if wav.dec_len > 8:
                    continue  # avoid excessive test duration
                input_length_power = int(np.ceil(np.log2(max(
                    wav.dec_len,
                    wav.rec_len))))
                N = 2**(input_length_power + max_level - 1)
                X = np.arange(N**ndim).reshape((N, )*ndim)

                coeffs = pywt.swtn(X, wav, max_level, axes=axes)
                coeffs_copy = deepcopy(coeffs)
                Y = pywt.iswtn(coeffs, wav, axes=axes)
                assert_allclose(Y, X, rtol=1e-5, atol=1e-5)

                # verify the inverse transform didn't modify any coeffs
                for c, c2 in zip(coeffs, coeffs_copy):
                    for k, v in c.items():
                        assert_array_equal(c2[k], v)
Example 11
    def inverse_swt(self, y):
        """Perform the inverse wavelet transform.

        :param y: Data to inverse-transform.
        :type y: list

        :return: Inverse-transformed data.
        :rtype: `numpy.array_like`
        """
        x = pywt.iswtn(y,
                       wavelet=self.wavelet,
                       axes=self.axes,
                       norm=self.normalized)
        if self.pad_on_demand is not None and np.any(self.pad_axes):
            for ax in np.nonzero(self.pad_axes)[0]:
                pad_l = np.ceil(self.pad_axes[ax] / 2).astype(np.intp)
                pad_h = np.floor(self.pad_axes[ax] / 2).astype(np.intp)
                slices = [slice(None)] * len(x.shape)
                slices[self.axes[ax]] = slice(pad_l,
                                              x.shape[self.axes[ax]] - pad_h,
                                              1)
                x = x[tuple(slices)]
        return x
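The cropping above undoes padding applied on the forward side: pywt.swtn requires the length along every transformed axis to be a multiple of 2**level. Below is a minimal, standalone sketch of that pad-then-crop pattern (independent of the class above; function names and the edge-padding choice are illustrative).

import numpy as np
import pywt

def swt_with_padding(x, wavelet='db2', level=2, axes=(0, 1)):
    # Pad each transformed axis up to the next multiple of 2**level.
    pad_width = [(0, 0)] * x.ndim
    for ax in axes:
        target = int(np.ceil(x.shape[ax] / 2**level)) * 2**level
        pad_width[ax] = (0, target - x.shape[ax])
    x_pad = np.pad(x, pad_width, mode='edge')
    return pywt.swtn(x_pad, wavelet, level=level, axes=axes), pad_width

def iswt_with_cropping(coeffs, pad_width, wavelet='db2', axes=(0, 1)):
    # Invert the transform, then crop the padding back off.
    x = pywt.iswtn(coeffs, wavelet, axes=axes)
    slices = [slice(None)] * x.ndim
    for ax in axes:
        slices[ax] = slice(0, x.shape[ax] - pad_width[ax][1])
    return x[tuple(slices)]

img = np.random.randn(30, 45)
coeffs, pad_width = swt_with_padding(img)
assert np.allclose(iswt_with_cropping(coeffs, pad_width), img)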
Example 12
def fuse_stationary_wavelets(first,
                             second,
                             *,
                             levels=None,
                             pca=False,
                             wavelet=None):
    if levels is None:
        levels = 6
    if wavelet is None:
        wavelet = 'sym4'

    pad, unpad = swt_pad_funcs(first.shape, levels)
    first = pad(first)
    second = pad(second)

    first = pywt.swtn(first,
                      wavelet,
                      level=levels,
                      axes=(0, 1),
                      norm=True,
                      trim_approx=True)
    second = pywt.swtn(second,
                       wavelet,
                       level=levels,
                       axes=(0, 1),
                       norm=True,
                       trim_approx=True)

    first[0] = (first[0] + second[0]) / 2
    for first_cs, second_cs in zip(first[1:], second[1:]):
        mask = _coeff_strength(second_cs.values(), pca) > _coeff_strength(
            first_cs.values(), pca)
        for k, v in first_cs.items():
            v[mask, ...] = second_cs[k][mask, ...]
    del second
    first = pywt.iswtn(first, wavelet, axes=(0, 1), norm=True)
    return unpad(first)
Example 13
    def time_iswtn(self, D, n, wavelet, dtype):
        pywt.iswtn(self.data, wavelet)
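The two-line benchmark above only makes sense with a matching setup step that precomputes self.data; a plausible sketch of that pairing is below (the class name and the choice of level are assumptions, and n must be a multiple of 2**level for swtn to accept the input).

import numpy as np
import pywt

class IswtnSuiteSketch:
    # Hypothetical asv-style benchmark: compute coefficients once in setup,
    # then time only the inverse transform.
    def setup(self, D, n, wavelet, dtype):
        x = np.ones((n,) * D, dtype=dtype)
        self.data = pywt.swtn(x, wavelet, level=2)

    def time_iswtn(self, D, n, wavelet, dtype):
        pywt.iswtn(self.data, wavelet)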
Example 14
def fuse_focal_stack_kmax(images,
                          *,
                          k=None,
                          levels=None,
                          wavelet=None,
                          pca=None,
                          in_memory=None,
                          sharpness_sigma=None):
    if k is None:
        k = 3
    if levels is None:
        levels = 3
    if wavelet is None:
        wavelet = 'sym4'
    if pca is None:
        pca = True
    if in_memory is None:
        in_memory = False
    if sharpness_sigma is None:
        sharpness_sigma = 3

    k = min(k, len(images))

    pad, unpad = swt_pad_funcs(images[0].shape, levels)
    sharpnesses = temporary_array_list(
        (sharpness(pad(image), sharpness_sigma, pca=pca) for image in images),
        in_memory=in_memory)
    kmax = kmax_sharpnesses(sharpnesses, k)

    bases = temporary_array_list()
    ads = [temporary_array_list() for _ in range(levels)]
    das = [temporary_array_list() for _ in range(levels)]
    dds = [temporary_array_list() for _ in range(levels)]

    for image in images:
        image = pad(image)
        shape = image.shape
        coeffs = pywt.swtn(image,
                           wavelet,
                           level=levels,
                           axes=(0, 1),
                           norm=True,
                           trim_approx=True)
        bases.append(coeffs[0])
        for level in range(levels):
            cl = coeffs[level + 1]
            ads[level].append(cl['ad'])
            das[level].append(cl['da'])
            dds[level].append(cl['dd'])

    base = np.empty_like(bases[0])
    ad = [np.empty_like(ads[level][0]) for level in range(levels)]
    da = [np.empty_like(das[level][0]) for level in range(levels)]
    dd = [np.empty_like(dds[level][0]) for level in range(levels)]
    first = np.full(shape[:2], True)

    for i, s in enumerate(sharpnesses):
        mask = np.any(s[:, :, np.newaxis] == kmax, axis=2)
        mask_and_first = mask & first
        base[mask_and_first, ...] = bases[i][mask_and_first, ...]
        for level in range(levels):
            ad[level][mask_and_first, ...] = ads[level][i][mask_and_first, ...]
            da[level][mask_and_first, ...] = das[level][i][mask_and_first, ...]
            dd[level][mask_and_first, ...] = dds[level][i][mask_and_first, ...]
        del mask_and_first
        mask_and_not_first = mask & ~first
        base[mask_and_not_first, ...] += bases[i][mask_and_not_first, ...]
        for level in range(levels):
            cmask = (
                (reduce_color_dimension(ads[level][i]**2) +
                 reduce_color_dimension(das[level][i]**2) +
                 reduce_color_dimension(dds[level][i]**2)) >
                (reduce_color_dimension(ad[level]**2) +
                 reduce_color_dimension(da[level]**2) +
                 reduce_color_dimension(dd[level]**2))) & mask_and_not_first
            ad[level][cmask, ...] = ads[level][i][cmask, ...]
            da[level][cmask, ...] = das[level][i][cmask, ...]
            dd[level][cmask, ...] = dds[level][i][cmask, ...]
        first[mask] = False

    base /= k
    coeffs = [base] + [
        dict(ad=ad[level], da=da[level], dd=dd[level])
        for level in range(levels)
    ]
    return unpad(pywt.iswtn(coeffs, wavelet, axes=(0, 1), norm=True))
Example 15
def compute_depth_map(
    depth_cues,
    iterations=500,
    lambda_tv=2.0,
    lambda_d2=0.05,
    lambda_wl=None,
    use_defocus=1.0,
    use_correspondence=1.0,
    use_xcorrelation=0.0,
):
    """Computes a depth map from the given depth cues.

    This depth map is based on the procedure from:

    M. W. Tao, et al., "Depth from combining defocus and correspondence using
    light-field cameras," in Proceedings of the IEEE International Conference on
    Computer Vision, 2013, pp. 673–680.

    :param depth_cues: The depth cues
    :type depth_cues: dict
    :param iterations: Number of iterations, defaults to 500
    :type iterations: int, optional
    :param lambda_tv: Lambda value of the TV term, defaults to 2.0
    :type lambda_tv: float, optional
    :param lambda_d2: Lambda value of the smoothing term, defaults to 0.05
    :type lambda_d2: float, optional
    :param lambda_wl: Lambda value of the wavelet term, defaults to None
    :type lambda_wl: float, optional
    :param use_defocus: Weight of defocus cues, defaults to 1.0
    :type use_defocus: float, optional
    :param use_correspondence: Weight of correspondence cues, defaults to 1.0
    :type use_correspondence: float, optional
    :param use_xcorrelation: Weight of the cross-correlation cues, defaults to 0.0
    :type use_xcorrelation: float, optional

    :raises ValueError: In case of requested wavelet regularization but not available

    :returns: The depth map
    :rtype: `numpy.array_like`
    """
    if not (lambda_wl is None or (has_pywt and use_swtn)):
        raise ValueError("Wavelet regularization requested but not available")

    use_defocus = np.fmax(use_defocus, 0.0)
    use_defocus = np.fmin(use_defocus, 1.0)
    use_correspondence = np.fmax(use_correspondence, 0.0)
    use_correspondence = np.fmin(use_correspondence, 1.0)
    use_xcorrelation = np.fmax(use_xcorrelation, 0.0)
    use_xcorrelation = np.fmin(use_xcorrelation, 1.0)

    W_d = depth_cues["confidence_defocus"]
    a_d = depth_cues["depth_defocus"]

    W_c = depth_cues["confidence_correspondence"]
    a_c = depth_cues["depth_correspondence"]

    W_x = depth_cues["confidence_xcorrelation"]
    a_x = depth_cues["depth_xcorrelation"]

    if use_defocus > 0 and (W_d.size == 0 or a_d.size == 0):
        use_defocus = 0
        warnings.warn("Defocusing parameters were not passed, disabling their use")

    if use_correspondence > 0 and (W_c.size == 0 or a_c.size == 0):
        use_correspondence = 0
        warnings.warn("Correspondence parameters were not passed, disabling their use")

    if use_xcorrelation > 0 and (W_x.size == 0 or a_x.size == 0):
        use_xcorrelation = 0
        warnings.warn("Cross-correlation parameters were not passed, disabling their use")

    if use_defocus:
        img_size = a_d.shape
        data_type = a_d.dtype
    elif use_correspondence:
        img_size = a_c.shape
        data_type = a_c.dtype
    elif use_xcorrelation:
        img_size = a_x.shape
        data_type = a_x.dtype
    else:
        raise ValueError("Cannot proceed: at least one of the defocus, correspondence, or cross-correlation cues must be usable")

    if lambda_wl is not None and has_pywt is False:
        lambda_wl = None
        print("WARNING - wavelets selected but not available")

    depth = np.zeros(img_size, dtype=data_type)
    depth_it = depth

    q_g = np.zeros(np.concatenate(((2,), img_size)), dtype=data_type)
    tau = 4 * lambda_tv
    if lambda_d2 is not None:
        q_l = np.zeros(img_size, dtype=data_type)
        tau += 8 * lambda_d2
    if use_defocus > 0:
        q_d = np.zeros(img_size, dtype=data_type)
        tau += W_d
    if use_correspondence > 0:
        q_c = np.zeros(img_size, dtype=data_type)
        tau += W_c
    if use_xcorrelation > 0:
        q_x = np.zeros(img_size, dtype=data_type)
        tau += W_x
    if lambda_wl is not None:
        wl_type = "sym4"
        wl_lvl = np.fmin(pywt.dwtn_max_level(img_size, wl_type), 2)
        print("Wavelets selected! Wl type: %s, Wl lvl %d" % (wl_type, wl_lvl))
        q_wl = pywt.swtn(depth, wl_type, wl_lvl)
        tau += lambda_wl * (2 ** wl_lvl)
        sigma_wl = 1 / (2 ** np.arange(wl_lvl, 0, -1))
    tau = 1 / tau

    for ii in range(iterations):
        (d0, d1) = _gradient2(depth_it)
        d_2 = np.stack((d0, d1)) / 2
        q_g += d_2
        grad_l2_norm = np.fmax(1, np.sqrt(np.sum(q_g ** 2, axis=0)))
        q_g /= grad_l2_norm

        update = -lambda_tv * _divergence2(q_g[0, :, :], q_g[1, :, :])
        if lambda_d2 is not None:
            l_dep = _laplacian2(depth_it)
            q_l += l_dep / 8
            q_l /= np.fmax(1, np.abs(q_l))

            update += lambda_d2 * _laplacian2(q_l)

        if use_defocus > 0:
            q_d += depth_it - a_d
            q_d /= np.fmax(1, np.abs(q_d))

            update += use_defocus * W_d * q_d

        if use_correspondence > 0:
            q_c += depth_it - a_c
            q_c /= np.fmax(1, np.abs(q_c))

            update += use_correspondence * W_c * q_c

        if use_xcorrelation > 0:
            q_x += depth_it - a_x
            q_x /= np.fmax(1, np.abs(q_x))

            update += use_xcorrelation * W_x * q_x

        if lambda_wl is not None:
            d = pywt.swtn(depth_it, wl_type, wl_lvl)
            for ii_l in range(wl_lvl):
                for k in q_wl[ii_l].keys():
                    q_wl[ii_l][k] += d[ii_l][k] * sigma_wl[ii_l]
                    q_wl[ii_l][k] /= np.fmax(1, np.abs(q_wl[ii_l][k]))
            update += lambda_wl * pywt.iswtn(q_wl, wl_type)

        depth_new = depth - update * tau
        depth_it = depth_new + (depth_new - depth)
        depth = depth_new

    return depth
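A hedged usage sketch for compute_depth_map, assuming the cue arrays were estimated elsewhere: the dictionary keys are the ones the function itself reads, while the random data and shapes are purely illustrative (unused cues can be passed as empty arrays).

import numpy as np

rng = np.random.RandomState(0)
shape = (64, 64)
depth_cues = {
    "confidence_defocus": rng.rand(*shape),
    "depth_defocus": rng.rand(*shape),
    "confidence_correspondence": rng.rand(*shape),
    "depth_correspondence": rng.rand(*shape),
    "confidence_xcorrelation": np.empty((0,)),  # cross-correlation cue unused here
    "depth_xcorrelation": np.empty((0,)),
}
depth_map = compute_depth_map(depth_cues, iterations=50, lambda_tv=2.0, lambda_d2=0.05)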
Example 16
def main(gamma_file, proton_file, output_file):
    bins = [80, 80]
    bin_range = [[62.5, 78.5], [-12.4, 12.4]]

    df_gammas = loadFile(gamma_file)
    df_protons = loadFile(proton_file)

    print('Read {} gammas and {} protons'.format(len(df_gammas),
                                                 len(df_protons)))
    factor = (10E5 * len(df_gammas)) / len(df_protons)

    print(factor)

    df_background = df_protons[df_protons['prediction:signal:mean'] > 0.87]
    df_signal = df_gammas[df_gammas['prediction:signal:mean'] > 0.87]

    print('Read {} signal events and {} background events'.format(
        len(df_signal), len(df_background)))
    ratio = len(df_background) / len(df_signal)
    expected_background = int(ratio * len(df_background) * factor)
    print('Upsampling background to get {} events'.format(expected_background))
    df_background = df_protons.sample(expected_background, replace=True)

    cube_background = create_cube(df_background.sample(frac=0.5),
                                  bins=bins,
                                  bin_range=bin_range)
    cube_gammas = create_cube(df_gammas.sample(frac=0.5),
                              bins=bins,
                              bin_range=bin_range)

    cube_steady = cube_background + cube_gammas

    cube_background = create_cube(df_background.sample(frac=0.5),
                                  bins=bins,
                                  bin_range=bin_range)
    cube_gammas = create_cube(df_gammas.sample(frac=0.5),
                              bins=bins,
                              bin_range=bin_range)
    cube_bright_gammas = create_cube(df_gammas, bins=bins, bin_range=bin_range)

    cube_with_transient = np.vstack(
        (cube_background + cube_gammas, cube_background + cube_bright_gammas))

    # remove mean measured noise from current cube
    cube = cube_with_transient - cube_steady.mean(axis=0)
    coeffs = pywt.swtn(
        data=cube,
        wavelet='bior1.3',
        level=2,
    )

    # remove noisy coefficients.
    ct = thresholding_3d(coeffs, k=30)
    cube_smoothed = pywt.iswtn(coeffs=ct, wavelet='bior1.3')

    # some criterion which could be used to trigger this.
    trans_factor = cube_smoothed.max(axis=1).max(axis=1)

    p = TransientPlotter(
        cube_with_transient,
        cube_smoothed,
        trans_factor,
        cmap='viridis',
    )

    print('Plotting animation. (Be patient)')
    anim = animation.FuncAnimation(
        p.fig,
        p.step,
        frames=len(cube),
        interval=15,
        blit=True,
    )

    anim.save('anim.gif', writer='imagemagick', fps=25)