示例#1
0
 def test_np_backend(self):
     """Verify that the NumPy backend is selected and detected correctly.

     Checks that ``get_backend('numpy')`` returns the NumPy module registered
     in ``LIBRARIES`` and that ``get_array_module`` maps both a freshly
     created NumPy array and ``self.input`` (converted with
     ``change_backend``) back to that same module.
     """
     lib, lib_name = get_backend('numpy')
     backend_ok = lib_name == 'numpy' and lib is LIBRARIES['numpy']
     if not backend_ok:
         raise AssertionError('numpy get_backend fails!')

     numpy_lib = LIBRARIES['numpy']
     converted = change_backend(self.input, 'numpy')
     detected_fresh = get_array_module(numpy_lib.ones(1))
     if (detected_fresh is not numpy_lib
             or get_array_module(converted) is not numpy_lib):
         raise AssertionError('numpy backend fails!')
示例#2
0
 def test_tf_backend(self):
     """Verify that the TensorFlow backend is selected and detected correctly.

     Checks that ``get_backend('tensorflow')`` returns the TensorFlow module
     registered in ``LIBRARIES`` and that ``get_array_module`` maps both a
     freshly created TensorFlow array and ``self.input`` (converted with
     ``change_backend``) to that same module.
     """
     lib, lib_name = get_backend('tensorflow')
     backend_ok = lib_name == 'tensorflow' and lib is LIBRARIES['tensorflow']
     if not backend_ok:
         raise AssertionError('tensorflow get_backend fails!')

     tf_lib = LIBRARIES['tensorflow']
     converted = change_backend(self.input, 'tensorflow')
     detected_fresh = get_array_module(tf_lib.ones(1))
     if (detected_fresh is not tf_lib
             or get_array_module(converted) is not tf_lib):
         raise AssertionError('tensorflow backend fails!')
示例#3
0
    def is_restart(self, z_old, x_new, x_old):
        """Check whether the algorithm needs to restart.

        This method implements the checks necessary to tell whether the
        algorithm needs to restart depending on the restarting strategy.
        It also updates the FISTA parameters according to the restarting
        strategy (namely beta and r).

        Parameters
        ----------
        z_old: numpy.ndarray
            Corresponds to y_n in :cite:`liang2018`.
        x_new: numpy.ndarray
            Corresponds to x_{n+1} in :cite:`liang2018`.
        x_old: numpy.ndarray
            Corresponds to x_n in :cite:`liang2018`.

        Returns
        -------
        bool
            Whether the algorithm should restart (always ``False`` when
            ``restart_strategy`` is ``None``)

        Notes
        -----
        Implements restarting and safeguarding steps in alg 4-5 of
        :cite:`liang2018`.

        Side effects, depending on ``self.restart_strategy``:

        * any strategy containing ``'adaptive'``: ``self.r_lazy`` is
          multiplied by ``self.xi_restart`` when the criterion holds;
        * ``'adaptive-ii'`` / ``'adaptive-2'``: ``self._t_now`` is reset to 1
          when the criterion holds;
        * ``'greedy'``: ``self._delta0`` is initialised on the first call,
          afterwards ``self._safeguard`` is updated.

        """
        xp = backend.get_array_module(x_new)

        if self.restart_strategy is None:
            return False

        # Restart test <y_n - x_{n+1}, x_{n+1} - x_n> >= 0. ``vdot``
        # conjugates its first argument and returns a complex scalar for
        # complex data; compare its real part (the real inner product)
        # instead of relying on lexicographic ordering of complex numbers.
        # For real arrays ``real`` is a no-op.
        criterion = xp.real(xp.vdot(z_old - x_new, x_new - x_old)) >= 0

        if criterion:
            if 'adaptive' in self.restart_strategy:
                self.r_lazy *= self.xi_restart
            if self.restart_strategy in {'adaptive-ii', 'adaptive-2'}:
                self._t_now = 1

        if self.restart_strategy == 'greedy':
            cur_delta = xp.linalg.norm(x_new - x_old)
            if self._delta0 is None:
                # First call: fix the reference step size.
                self._delta0 = self.s_greedy * cur_delta
            else:
                self._safeguard = cur_delta >= self._delta0

        return criterion
示例#4
0
文件: cost.py 项目: sfarrens/ModOpt
    def _check_cost(self):
        """Check cost function.

        This method tests the cost function for convergence in the specified
        interval of iterations using the last :math:`n` (``test_range``) cost
        values. The current cost is appended to an internal list; once the
        list holds ``test_range`` values, the mean of the newer half is
        compared with the mean of the older half and the list is reset.

        Returns
        -------
        bool
            ``True`` if the relative change in cost is at most the tolerance,
            ``False`` otherwise (including when fewer than ``test_range``
            values have been collected so far)

        """
        # Add current cost value to the test list
        self._test_list.append(self.cost)

        xp = get_array_module(self.cost)

        # Not enough cost values collected yet: nothing to test.
        if len(self._test_list) != self._test_range:
            return False

        half = len(self._test_list) // 2

        # The mean of the newer (second) half of the test list
        t1 = xp.mean(xp.array(self._test_list[half:]), axis=0)
        # The mean of the older (first) half of the test list
        t2 = xp.mean(xp.array(self._test_list[:half]), axis=0)

        # Calculate the relative change across the test list. The rounded
        # norm is tested (rather than t1 itself) so array-valued costs do not
        # hit the ambiguous truth value of an array; for scalar costs this is
        # the same test. A norm that rounds to zero skips the division.
        t1_norm = xp.linalg.norm(t1)
        if xp.around(t1_norm, decimals=16):
            cost_diff = xp.linalg.norm(t1 - t2) / t1_norm
        else:
            cost_diff = 0

        # Reset the test list
        self._test_list = []

        if self._verbose:
            print(' - CONVERGENCE TEST - ')
            print(' - CHANGE IN COST:', cost_diff)
            print('')

        # Check for convergence
        return cost_diff <= self._tolerance
示例#5
0
    def get_spec_rad(self, tolerance=1e-6, max_iter=20, extra_factor=1.0):
        """Get spectral radius.

        This method calculates the spectral radius of ``self._operator`` with
        the power method. The result (multiplied by ``extra_factor``) is
        stored in ``self.spec_rad`` and its inverse in ``self.inv_spec_rad``.

        Parameters
        ----------
        tolerance : float, optional
            Tolerance threshold for convergence (default is ``1e-6``)
        max_iter : int, optional
            Maximum number of iterations (default is ``20``)
        extra_factor : float, optional
            Extra multiplicative factor for calculating the spectral radius
            (default is ``1.0``)

        Raises
        ------
        ValueError
            If ``max_iter`` is less than 1 (no estimate could be produced)

        """
        # Without at least one iteration x_new_norm would never be defined.
        if max_iter < 1:
            raise ValueError('max_iter must be at least 1.')

        # Set (or reset) values of x.
        x_old = self._set_initial_x()

        # x_old is only updated in place (copyto below), so its array module
        # never changes and can be resolved once.
        xp = get_array_module(x_old)

        # Iterate until the L2 norm of x converges.
        for i_elem in range(max_iter):

            x_old_norm = xp.linalg.norm(x_old)

            # Each iterate is A(x) / ||x||, so its norm is the current
            # estimate ||A x|| / ||x||; convergence compares consecutive
            # estimates.
            x_new = self._operator(x_old) / x_old_norm

            x_new_norm = xp.linalg.norm(x_new)

            if xp.abs(x_new_norm - x_old_norm) < tolerance:
                message = (' - Power Method converged after {0} iterations!')
                if self._verbose:
                    print(message.format(i_elem + 1))
                break

            elif i_elem == max_iter - 1 and self._verbose:
                message = (
                    ' - Power Method did not converge after {0} iterations!')
                print(message.format(max_iter))

            xp.copyto(x_old, x_new)

        self.spec_rad = x_new_norm * extra_factor
        self.inv_spec_rad = 1.0 / self.spec_rad
示例#6
0
def thresh(input_data, threshold, threshold_type='hard'):
    r"""Threshold data.

    This method performs hard or soft thresholding on the input data.

    Parameters
    ----------
    input_data : numpy.ndarray, list or tuple
        Input data array
    threshold : float or numpy.ndarray
        Threshold level(s)
    threshold_type : {'hard', 'soft'}
        Type of thresholding to apply (default is ``'hard'``)

    Returns
    -------
    numpy.ndarray
        Thresholded data

    Raises
    ------
    ValueError
        If ``threshold_type`` is not ``'hard'`` or ``'soft'``

    Notes
    -----
    Implements one of the following two equations:

    * Hard Threshold
        .. math::
            \mathrm{HT}_\lambda(x) =
            \begin{cases}
            x & \text{if } |x|\geq\lambda \\
            0 & \text{otherwise}
            \end{cases}

    * Soft Threshold
        .. math::
            \mathrm{ST}_\lambda(x) =
            \begin{cases}
            x-\lambda\text{sign}(x) & \text{if } |x|\geq\lambda \\
            0 & \text{otherwise}
            \end{cases}

    Examples
    --------
    >>> import numpy as np
    >>> from modopt.signal.noise import thresh
    >>> np.random.seed(1)
    >>> x = np.random.randint(-9, 9, 10)
    >>> x
    array([-4,  2,  3, -1,  0,  2, -4,  6, -9,  7])
    >>> thresh(x, 4)
    array([-4,  0,  0,  0,  0,  0, -4,  6, -9,  7])

    >>> import numpy as np
    >>> from modopt.signal.noise import thresh
    >>> np.random.seed(1)
    >>> x = np.random.random_sample((3, 3))
    >>> x
    array([[4.17022005e-01, 7.20324493e-01, 1.14374817e-04],
           [3.02332573e-01, 1.46755891e-01, 9.23385948e-02],
           [1.86260211e-01, 3.45560727e-01, 3.96767474e-01]])
    >>> thresh(x, 0.2, threshold_type='soft')
    array([[0.217022  , 0.52032449, 0.        ],
           [0.10233257, 0.        , 0.        ],
           [0.        , 0.14556073, 0.19676747]])

    """
    xp = get_array_module(input_data)

    input_data = xp.array(input_data)

    if threshold_type not in {'hard', 'soft'}:
        raise ValueError(
            'Invalid threshold type. Options are "hard" or "soft"',
        )

    if threshold_type == 'soft':
        # Clamp |x| from below by machine epsilon so x == 0 does not cause a
        # division by zero; max(1 - t/|x|, 0) * x equals x - t*sign(x) when
        # |x| >= t and 0 otherwise.
        denominator = xp.maximum(xp.finfo(np.float64).eps, xp.abs(input_data))
        max_value = xp.maximum((1.0 - threshold / denominator), 0)

        # Rounded to 15 decimals, presumably to strip floating-point residue
        # from the multiplication.
        return xp.around(max_value * input_data, decimals=15)

    # Hard thresholding: keep entries with |x| >= threshold, zero the rest.
    return input_data * (xp.abs(input_data) >= threshold)