Esempio n. 1
0
def simplex_prox(tensor, parameter):
    """
    Projects each column of the input tensor on the simplex of radius parameter.

    Negative entries are set to zero first; every column is then shifted by a
    common threshold and clipped at zero so that its entries are non-negative
    and sum to `parameter`.

    Parameters
    ----------
    tensor : ndarray
        Two-dimensional tensor; the projection is applied along axis 0,
        i.e. independently to every column.
    parameter : float
        Radius of the simplex (target sum of each column). Expected to be
        positive; for a non-positive radius no threshold is applied.

    Returns
    -------
    ndarray
        Tensor of the same shape as `tensor`.

    References
    ----------
    .. [1]: Held, Michael, Philip Wolfe, and Harlan P. Crowder.
            "Validation of subgradient optimization."
            Mathematical programming 6.1 (1974): 62-88.
    """
    row, col = tl.shape(tensor)
    # Only a lower bound is wanted here. Passing tl.max(tensor) as the upper
    # bound breaks the clip when every entry is negative (upper < lower).
    tensor = tl.clip(tensor, a_min=0)
    tensor_sort = tl.sort(tensor, axis=0, descending=True)

    # Candidate thresholds, one per rank k = 1..row and per column:
    #   theta_k = (sum of the k largest entries - parameter) / k
    # The division by k is required: without it the support size is
    # underestimated and the result does not sum to `parameter`.
    ranks = tl.reshape(tl.arange(1, row + 1, **tl.context(tensor)), [row, 1])
    thresholds = (tl.cumsum(tensor_sort, axis=0) - parameter) / ranks

    # Number of entries per column that stay strictly positive after shifting.
    to_change = tl.sum(tl.where(tensor_sort > thresholds, 1.0, 0.0), axis=0)

    difference = tl.zeros(col)
    for i in range(col):
        support = int(to_change[i])
        if support > 0:
            # The threshold of the last rank satisfying the condition is the
            # shift that puts the column on the simplex.
            difference = tl.index_update(
                difference, tl.index[i], thresholds[support - 1, i])
    return tl.clip(tensor - difference, a_min=0)
Esempio n. 2
0
def sparsify_tensor(tensor, card):
    """Zeros out all elements in the `tensor` except `card` elements with maximum absolute values.

    Elements whose absolute value ties with the smallest kept magnitude are
    all kept, so the result may hold more than `card` non-zero elements.

    Parameters
    ----------
    tensor : ndarray
    card : int
        Desired number of non-zero elements in the `tensor`.
        If it is at least the number of elements, `tensor` is returned as is.

    Returns
    -------
    ndarray of shape tensor.shape

    Raises
    ------
    ValueError
        If `card` is negative.
    """
    if card < 0:
        raise ValueError("card must be non-negative, got {}".format(card))
    if card >= np.prod(tensor.shape):
        return tensor

    zeros = tl.zeros(tensor.shape, **tl.context(tensor))
    if card == 0:
        # Indexing the sorted magnitudes with [-0] would pick the smallest
        # magnitude as the bound and leave the tensor untouched.
        return zeros

    # Magnitude of the card-th largest element: everything strictly below it
    # is zeroed out.
    bound = tl.sort(tl.abs(tensor), axis=None)[-card]
    return tl.where(tl.abs(tensor) < bound, zeros, tensor)