def test_np_backend(self):
    """Check that the numpy backend is resolved and applied.

    Raises
    ------
    AssertionError
        If ``get_backend('numpy')`` does not return the name ``'numpy'``
        together with ``LIBRARIES['numpy']``, or if ``get_array_module``
        does not report ``LIBRARIES['numpy']`` for a freshly created array
        or for the converted test input

    """
    module, name = get_backend('numpy')
    resolved = name == 'numpy' and module == LIBRARIES['numpy']
    if not resolved:
        raise AssertionError('numpy get_backend fails!')

    converted = change_backend(self.input, 'numpy')
    detected = (
        get_array_module(LIBRARIES['numpy'].ones(1)) == LIBRARIES['numpy']
        and get_array_module(converted) == LIBRARIES['numpy']
    )
    if not detected:
        raise AssertionError('numpy backend fails!')
def test_tf_backend(self):
    """Check that the tensorflow backend is resolved and applied.

    Raises
    ------
    AssertionError
        If ``get_backend('tensorflow')`` does not return the name
        ``'tensorflow'`` together with ``LIBRARIES['tensorflow']``, or if
        ``get_array_module`` does not report ``LIBRARIES['tensorflow']``
        for a freshly created array or for the converted test input

    """
    module, name = get_backend('tensorflow')
    resolved = name == 'tensorflow' and module == LIBRARIES['tensorflow']
    if not resolved:
        raise AssertionError('tensorflow get_backend fails!')

    converted = change_backend(self.input, 'tensorflow')
    detected = (
        get_array_module(LIBRARIES['tensorflow'].ones(1))
        == LIBRARIES['tensorflow']
        and get_array_module(converted) == LIBRARIES['tensorflow']
    )
    if not detected:
        raise AssertionError('tensorflow backend fails!')
def is_restart(self, z_old, x_new, x_old):
    """Check whether the algorithm needs to restart.

    Evaluates the restart criterion for the configured restarting
    strategy and, as a side effect, updates the FISTA parameters tied to
    that strategy (``r_lazy``, ``_t_now``, ``_delta0`` and ``_safeguard``).

    Parameters
    ----------
    z_old: numpy.ndarray
        Corresponds to y_n in :cite:`liang2018`.
    x_new: numpy.ndarray
        Corresponds to x_{n+1} in :cite:`liang2018`.
    x_old: numpy.ndarray
        Corresponds to x_n in :cite:`liang2018`.

    Returns
    -------
    bool
        Whether the algorithm should restart; always ``False`` when no
        restart strategy is set

    Notes
    -----
    Implements restarting and safeguarding steps in alg 4-5 of
    :cite:`liang2018`

    """
    xp = backend.get_array_module(x_new)

    strategy = self.restart_strategy
    if strategy is None:
        return False

    residual = z_old - x_new
    step = x_new - x_old
    should_restart = xp.vdot(residual, step) >= 0

    # Parameter updates that only apply when the criterion fires.
    if should_restart and 'adaptive' in strategy:
        self.r_lazy *= self.xi_restart
    if should_restart and strategy in {'adaptive-ii', 'adaptive-2'}:
        self._t_now = 1

    # Greedy safeguard: the first call records the reference step size,
    # later calls compare the current step size against it.
    if strategy == 'greedy':
        step_norm = xp.linalg.norm(step)
        if self._delta0 is None:
            self._delta0 = self.s_greedy * step_norm
        else:
            self._safeguard = step_norm >= self._delta0

    return should_restart
def _check_cost(self):
    """Check cost function.

    Records the current cost and, once ``test_range`` values have been
    collected, compares the mean of the later half of the collected values
    with the mean of the earlier half to decide on convergence. The
    collected values are discarded after every comparison.

    Returns
    -------
    bool
        Result of the convergence test; ``False`` while fewer than
        ``test_range`` values have been collected

    """
    # Record the current cost value
    self._test_list.append(self.cost)

    xp = get_array_module(self.cost)

    # Not enough cost values collected yet: no test to perform
    if len(self._test_list) != self._test_range:
        return False

    midpoint = len(self._test_list) // 2

    # Mean of the second (most recent) half of the collected values
    recent_mean = xp.mean(xp.array(self._test_list[midpoint:]), axis=0)
    # Mean of the first (earliest) half of the collected values
    earlier_mean = xp.mean(xp.array(self._test_list[:midpoint]), axis=0)

    # Relative change between the two halves; reported as zero when the
    # recent mean rounds to zero (presumably to avoid dividing by zero)
    if xp.around(recent_mean, decimals=16):
        change = xp.linalg.norm(recent_mean - earlier_mean)
        cost_diff = change / xp.linalg.norm(recent_mean)
    else:
        cost_diff = 0

    # Start collecting a fresh window of cost values
    self._test_list = []

    if self._verbose:
        print(' - CONVERGENCE TEST - ')
        print(' - CHANGE IN COST:', cost_diff)
        print('')

    return cost_diff <= self._tolerance
def get_spec_rad(self, tolerance=1e-6, max_iter=20, extra_factor=1.0):
    """Get spectral radius.

    This method estimates the spectral radius by repeatedly applying
    ``self._operator`` (power method) until the norm of the iterate
    changes by less than ``tolerance`` or ``max_iter`` iterations have
    been performed. The result is stored in ``self.spec_rad`` and its
    inverse in ``self.inv_spec_rad``.

    Parameters
    ----------
    tolerance : float, optional
        Tolerance threshold for convergence (default is ``1e-6``)
    max_iter : int, optional
        Maximum number of iterations, must be at least 1
        (default is ``20``)
    extra_factor : float, optional
        Extra multiplicative factor for calculating the spectral radius
        (default is ``1.0``)

    Raises
    ------
    ValueError
        If ``max_iter`` is smaller than 1

    """
    # Without at least one iteration no norm is ever computed, so there
    # would be nothing to store as the spectral radius.
    if max_iter < 1:
        raise ValueError('max_iter must be at least 1.')

    # Set (or reset) values of x.
    x_old = self._set_initial_x()

    # x_old is only ever updated in place below, so its array module
    # cannot change between iterations.
    xp = get_array_module(x_old)

    # Iterate until the L2 norm of x converges.
    for i_elem in range(max_iter):
        x_old_norm = xp.linalg.norm(x_old)
        x_new = self._operator(x_old) / x_old_norm
        x_new_norm = xp.linalg.norm(x_new)

        if xp.abs(x_new_norm - x_old_norm) < tolerance:
            if self._verbose:
                message = ' - Power Method converged after {0} iterations!'
                print(message.format(i_elem + 1))
            break

        xp.copyto(x_old, x_new)
    else:
        # Only reached when the loop ran out of iterations without a break.
        if self._verbose:
            message = (
                ' - Power Method did not converge after {0} iterations!'
            )
            print(message.format(max_iter))

    self.spec_rad = x_new_norm * extra_factor
    self.inv_spec_rad = 1.0 / self.spec_rad
def thresh(input_data, threshold, threshold_type='hard'):
    r"""Threshold data.

    This method performs hard or soft thresholding on the input data.

    Parameters
    ----------
    input_data : numpy.ndarray, list or tuple
        Input data array
    threshold : float or numpy.ndarray
        Threshold level(s)
    threshold_type : {'hard', 'soft'}
        Type of thresholding to apply (default is ``'hard'``)

    Returns
    -------
    numpy.ndarray
        Thresholded data

    Raises
    ------
    ValueError
        If ``threshold_type`` is not ``'hard'`` or ``'soft'``

    Notes
    -----
    Implements one of the following two equations:

    * Hard Threshold

        .. math::
            \mathrm{HT}_\lambda(x) =
            \begin{cases}
            x & \text{if } |x|\geq\lambda \\
            0 & \text{otherwise}
            \end{cases}

    * Soft Threshold

        .. math::
            \mathrm{ST}_\lambda(x) =
            \begin{cases}
            x-\lambda\text{sign}(x) & \text{if } |x|\geq\lambda \\
            0 & \text{otherwise}
            \end{cases}

    Examples
    --------
    >>> import numpy as np
    >>> from modopt.signal.noise import thresh
    >>> np.random.seed(1)
    >>> x = np.random.randint(-9, 9, 10)
    >>> x
    array([-4,  2,  3, -1,  0,  2, -4,  6, -9,  7])
    >>> thresh(x, 4)
    array([-4,  0,  0,  0,  0,  0, -4,  6, -9,  7])

    >>> import numpy as np
    >>> from modopt.signal.noise import thresh
    >>> np.random.seed(1)
    >>> x = np.random.ranf((3, 3))
    >>> x
    array([[4.17022005e-01, 7.20324493e-01, 1.14374817e-04],
           [3.02332573e-01, 1.46755891e-01, 9.23385948e-02],
           [1.86260211e-01, 3.45560727e-01, 3.96767474e-01]])
    >>> thresh(x, 0.2, threshold_type='soft')
    array([[0.217022  , 0.52032449, 0.        ],
           [0.10233257, 0.        , 0.        ],
           [0.        , 0.14556073, 0.19676747]])

    """
    xp = get_array_module(input_data)
    data = xp.array(input_data)

    if threshold_type not in {'hard', 'soft'}:
        raise ValueError(
            'Invalid threshold type. Options are "hard" or "soft"',
        )

    if threshold_type == 'hard':
        # Keep entries whose magnitude reaches the threshold, zero the rest.
        return data * (xp.abs(data) >= threshold)

    # Soft thresholding: scale every entry by max(1 - threshold / |x|, 0).
    # The magnitude is floored at machine epsilon so the division is
    # well defined for zero-valued entries.
    magnitude = xp.maximum(xp.finfo(np.float64).eps, xp.abs(data))
    shrinkage = xp.maximum((1.0 - threshold / magnitude), 0)

    return xp.around(shrinkage * data, decimals=15)