def simplex_prox(tensor, parameter):
    """
    Projects the input tensor on the simplex of radius parameter.

    Each column of `tensor` is projected independently. Negative entries are
    clamped to zero before the projection is computed.

    Parameters
    ----------
    tensor : ndarray
        Two-dimensional tensor; the projection is applied along axis 0
        (i.e. column by column).
    parameter : float
        Radius of the simplex: the value every projected column sums to.

    Returns
    -------
    ndarray
        Tensor of the same shape as `tensor` with non-negative entries.
        For a non-positive `parameter` no shift is applied and the clamped
        tensor is returned.

    References
    ----------
    .. [1]: Held, Michael, Philip Wolfe, and Harlan P. Crowder.
            "Validation of subgradient optimization."
            Mathematical programming 6.1 (1974): 62-88.
    """
    row, col = tl.shape(tensor)
    tensor = tl.clip(tensor, 0, tl.max(tensor))
    tensor_sort = tl.sort(tensor, axis=0, descending=True)

    # thresholds[k - 1, j] = (sum of the k largest entries of column j
    # - parameter) / k. The (row, 1) column of 1, 2, ..., row broadcasts over
    # the columns. Dividing by k is essential: comparing the sorted values
    # against the undivided cumulative sum under-estimates the support size
    # and yields columns that do not sum to `parameter`.
    counts = tl.cumsum(tl.ones([row, 1]), axis=0)
    thresholds = (tl.cumsum(tensor_sort, axis=0) - parameter) / counts

    # Number of entries per column that stay strictly positive after the
    # projection (the condition holds on a prefix of the sorted column).
    to_change = tl.sum(tl.where(tensor_sort > thresholds, 1.0, 0.0), axis=0)

    # Shift to subtract from each column. Columns with no active entry keep a
    # zero shift, which also avoids dividing by a zero count.
    difference = tl.zeros(col)
    for i in range(col):
        if to_change[i] > 0:
            difference = tl.index_update(
                difference, tl.index[i], thresholds[int(to_change[i] - 1), i]
            )
    return tl.clip(tensor - difference, a_min=0)
def sparsify_tensor(tensor, card):
    """Zeros out all elements in the `tensor` except `card` elements with
    maximum absolute values.

    Parameters
    ----------
    tensor : ndarray
    card : int
        Desired number of non-zero elements in the `tensor`.
        If `card` is greater than or equal to the number of elements, the
        tensor is returned unchanged. If `card` is zero or negative, an
        all-zero tensor is returned.

    Returns
    -------
    ndarray of shape tensor.shape

    Notes
    -----
    Elements whose absolute value equals the `card`-th largest absolute value
    are all kept, so the result can hold more than `card` non-zero elements
    when there are ties.
    """
    if card >= np.prod(tensor.shape):
        return tensor

    zeros = tl.zeros(tensor.shape, **tl.context(tensor))
    if card <= 0:
        # Indexing the sorted magnitudes with [-0] would select the smallest
        # one and leave the tensor untouched; nothing should be kept here.
        return zeros

    magnitudes = tl.abs(tensor)
    # The card-th largest magnitude is the smallest value that is kept.
    bound = tl.sort(magnitudes, axis=None)[-card]
    return tl.where(magnitudes < bound, zeros, tensor)