import numpy as np

def softmax(a, b, alpha=1, normalize=0):
    """The soft maximum of a and b: softmax(a, b) = log(e^a + e^b),
    with alpha sharpening the approximation to the true maximum.
    normalize should be 0 if a or b could be negative, and can be 1.0
    (more accurate) if a and b are strictly positive.
    Also called the alpha-quasimax:
            J. D. Cook. Basic properties of the soft maximum.
            Working Paper Series 70, UT MD Anderson Cancer Center
            Department of Biostatistics, 2011.
            http://biostats.bepress.com/mdandersonbiostat/paper7
    """
    return np.log(np.exp(a * alpha) + np.exp(b * alpha) - normalize) / alpha
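# Example (illustrative): the soft maximum slightly exceeds the true maximum,
# and larger alpha brings it closer to the hard max:
#   softmax(2.0, 3.0)            -> log(e^2 + e^3) ~ 3.313
#   softmax(2.0, 3.0, alpha=10)  -> ~ 3.000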
def vector_sort(matrices, X, key, alpha=1):
    """
    Sort a matrix X, applying a differentiable function "key" to each vector
    while sorting. Uses softmax to weight components of the matrix.
    
    For example, selecting the nth element of each vector by 
    multiplying with a one-hot vector.
    
    Parameters:
    ------------
        matrices:   the nxn bitonic sort matrices created by bitonic_matrices
        X:          an [n,d] matrix of elements
        key:        a function taking a d-element vector and returning a scalar
        alpha=1.0:  smoothing to apply; smaller alpha gives smoother but less
                    accurate sorting, larger alpha gives a harder max at the
                    cost of numerical instability
        
    Returns:
    ----------
        X_sorted: [n,d] matrix (approximately) sorted according to the key function
        
    """
    for l, r, map_l, map_r in matrices:

        x = key(X)
        # compute weighting on the scalar function
        a, b = l @ x, r @ x
        a_weight = np.exp(a * alpha) / (np.exp(a * alpha) + np.exp(b * alpha))
        b_weight = 1 - a_weight
        # apply weighting to the full vectors
        aX = l @ X
        bX = r @ X
        w_max = (a_weight * aX.T + b_weight * bX.T).T
        w_min = (b_weight * aX.T + a_weight * bX.T).T
        # recombine into the full vector
        X = (map_l @ w_max) + (map_r @ w_min)
    return X
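# Example usage (a sketch; assumes bitonic_matrices(n) from this package builds
# the (l, r, map_l, map_r) tuples described above):
#   matrices = bitonic_matrices(8)
#   X = np.random.randn(8, 3)
#   # sort the rows of X by their first column, selected with a one-hot key
#   X_sorted = vector_sort(matrices, X, key=lambda v: v @ np.array([1.0, 0.0, 0.0]), alpha=5)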
def comparison_sort(matrices, x, compare_fn, alpha=1, scale=250):
    """
    Sort a tensor x, applying a differentiable comparison function "compare_fn"
    while sorting. Uses softmax to weight components of the tensor.
           
    Parameters:
    ------------
        matrices:   the nxn bitonic sort matrices created by bitonic_matrices
        x:          an [n,...] tensor of elements
        compare_fn: a differentiable comparison function compare_fn(a,b)
                    taking a pair of [n//2,...] tensors and returning a signed [n//2] vector.
        alpha=1.0:  smoothing to apply; smaller alpha gives smoother but less
                    accurate sorting, larger alpha gives a harder max at the
                    cost of numerical instability
        scale=250:  scaling applied to output of compare_fn. Default is useful for 
                    comparison functions returning values in the range ~[-1, 1]
        
    Returns:
    ----------
        x_sorted: [n,...] tensor (approximately) sorted according to compare_fn
        
    """
    for l, r, map_l, map_r in matrices:
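        # score each (left, right) pair with the differentiable comparison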
        score = compare_fn((x.T @ l.T).T, (x.T @ r.T).T)
        a, b = score * scale, score * -scale
        a_weight = np.exp(a * alpha) / (np.exp(a * alpha) + np.exp(b * alpha))
        b_weight = 1 - a_weight
        # apply weighting to the full vectors
        aX = x.T @ l.T
        bX = x.T @ r.T
        w_max = (a_weight * aX + b_weight * bX)
        w_min = (b_weight * aX + a_weight * bX)
        # recombine into the full vector
        x = (w_max @ map_l.T) + (w_min @ map_r.T)
        x = x.T

    return x
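# Example usage (a sketch; assumes bitonic_matrices(n) from this package and a
# comparison function returning values in roughly [-1, 1]):
#   matrices = bitonic_matrices(8)
#   x = np.random.randn(8)
#   x_sorted = comparison_sort(matrices, x, compare_fn=lambda a, b: np.tanh(a - b))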
def softmax_smooth(a, b, smooth=0):
    """The smoothed softmaximum of softmax(a,b) = log(e^a + a^b).
    With smooth=0.0, is softmax; with smooth=1.0, averages a and b"""
    t = smooth / 2.0
    return np.log(np.exp((1 - t) * a + b * t) +
                  np.exp((1 - t) * b + t * a)) - np.log(1 + smooth)
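# Example (illustrative):
#   softmax_smooth(2.0, 3.0)             -> log(e^2 + e^3) ~ 3.313  (plain softmax)
#   softmax_smooth(2.0, 3.0, smooth=1.0) -> 2.5                     (plain average)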
def smoothmax(a, b, alpha=1):
    """The smooth maximum of a and b: an average of a and b weighted by the
    softmax of alpha*a and alpha*b, approaching max(a, b) as alpha grows."""
    return (a * np.exp(a * alpha) +
            b * np.exp(b * alpha)) / (np.exp(a * alpha) + np.exp(b * alpha))
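# Example (illustrative): smoothmax is an exponentially weighted average of its
# arguments, pulled towards the larger one as alpha grows:
#   smoothmax(2.0, 3.0)            -> ~ 2.731
#   smoothmax(2.0, 3.0, alpha=10)  -> ~ 3.000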
def order_matrix(original, sortd, sigma=0.1):
    """Apply a simple RBF kernel to the difference between original and sortd,
    with the kernel width set by sigma. Normalise each row to sum to 1.0."""
    diff = ((original).reshape(-1, 1) - sortd.reshape(1, -1))**2
    rbf = np.exp(-(diff) / (2 * sigma**2))
    return (rbf.T / np.sum(rbf, axis=1)).T
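# Example (illustrative): recover a soft permutation matrix relating the original
# values to their (approximately) sorted positions:
#   original = np.array([3.0, 1.0, 2.0])
#   P = order_matrix(original, np.sort(original), sigma=0.1)
#   # each row of P is a soft one-hot vector giving that element's sorted position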