# Heuristic for choosing the l1 shrinkage threshold tau (assumes mu > 0).
# AtA is positive semidefinite, so the eigenvalues of (muI + AtA) are
# mu + lambda_i(AtA) >= mu, hence
#   || (muI + AtA)^-1 ||_2 = 1/(mu + lambda_min(AtA)) <= 1/mu,
# with equality whenever A has a nontrivial null space (e.g. when n > m).
# --> || [(muI + AtA)^-1]y ||_2 <= (1/mu)*||y||_2 = alpha
# --> so perhaps shrink-threshold values below about alpha/n
#     (NOTE(review): originally written as "alpha/Ny"; n is assumed here
#     to match the next line -- confirm)
# so set tau in the range of (k*alpha/n) * mu = k/n
# (for ||y||=1 and A.shape = (m,n))

# Penalty parameter for this (B-based) run.
mu = 20. 
# tau = mu/n, i.e. k = mu in the k/n heuristic above.
tau = float(mu) / n

# l1 norm of x; passed to the solvers below as phi.
phi1 = lambda x: np.sum(np.abs(x))
# Proximity map for the l1 term, built from tau and mu (presumably
# soft-thresholding at a level set by tau/mu -- confirm against
# salsa.l1_proximity_map).
phi_map1 = salsa.l1_proximity_map(tau, mu)
# Linear solver used by SALSA; by its name, presumably for (BtB + mu*I)
# -- confirm against trn.BtBpI_solver.
BtB_solve = trn.BtBpI_solver(mu=mu)

# quadratic + l1 regularized SALSA on B, warm-started from x0.
# x1, u1 are the two values returned by qreg_salsa (u1 presumably the
# split/auxiliary variable -- confirm).
x1, u1 = salsa.qreg_salsa(
    B, Bt, BtB_solve, tst_col, phi1, phi_map1, mu, n_iter=400,
    x0=x0, verbose=True, save_sol=True, rtol=1e-6
    )

# strict basis pursuit SALSA
# Intent: choose mu so that the shrinkage threshold is roughly the ratio
# of class columns to total columns -- i.e., the level of a uniform
# distribution of coefficients.
# NOTE(review): mu is hard-coded to 100.0 below rather than derived from
# that ratio (if the threshold is tau/mu with tau=1, the ratio would give
# mu = n / len(cls_cols)) -- confirm which value is intended.

# Solver for the BBt system used by basis-pursuit SALSA (no mu argument).
BBt_solve = trn.BBt_solver()
# NOTE: rebinds the module-level mu (previously 20. for the run above).
mu = 100.0
# l1 proximity map with tau=1; mu only enters the algorithm through here,
# since bp_salsa is not passed mu directly.
phi_map2 = salsa.l1_proximity_map(1, mu)
# Reuses phi1 (l1 norm) from the quadratic + l1 run above.
x2, u2 = salsa.bp_salsa(
    B, Bt, BBt_solve, tst_col, phi1, phi_map2, n_iter=400, x0=x0,
    verbose=True, save_sol=True, rtol=1e-6
    )
# Quadratic + l1 regularized SALSA on A (the C-SALSA run follows below).
# Threshold heuristic (same reasoning as for the B-based runs): since
#   || [(muI + AtA)^-1]y ||_2 <= (1/mu)*||y||_2 = alpha,
# tau in the range of (k*alpha/n) * mu = k/n is a natural scale
# (for ||y||=1 and A.shape = (m,n)).

# want to enforce accuracy pretty heavily
mu = 1.
# The k/n heuristic above would suggest tau = mu / n; a fixed tau is used
# for this run instead.
tau = 1/10.
# l1 norm of x (redefined here so this section does not rely on earlier ones).
phi1 = lambda x: np.sum(np.abs(x))
phi_map1 = salsa.l1_proximity_map(tau, mu)
# By its name, presumably a solver for (AtA + mu*I) -- confirm against trn.
AtA_solve = trn.AtApI_solver(mu=mu)

# quadratic + l1 regularized SALSA on A, warm-started from mmse_x.
# NOTE: overwrites x1, u1 from the earlier B-based run.
x1, u1 = salsa.qreg_salsa(
    A, At, AtA_solve, tst_col, phi1, phi_map1, mu,
    x0=mmse_x, rtol=1e-4, save_sol=True, n_iter=200
    )

# C-SALSA (constrained SALSA) on A, warm-started from mmse_x.
# Solver built with mu=1; by its name, presumably for (AtA + I)
# -- confirm against trn.AtApI_solver.
AtApI_solve = trn.AtApI_solver(mu=1)
# The class-ratio heuristic would give mu2 = float(n) / len(cls_cols);
# a fixed value is used instead.
mu2 = 15.
# NOTE(review): mu2 is not used by this run -- the proximity map below is
# built with n as its mu argument, which differs from both phi_map1 and the
# earlier phi_map2. If mu2 was meant to set the threshold, this should be
# salsa.l1_proximity_map(1, mu2); confirm intent before changing.
phi_map2 = salsa.l1_proximity_map(1, n)
# Tolerance passed to c_salsa (presumably the data-fit constraint bound
# ||A x - y||_2 <= eps -- confirm against salsa.c_salsa).
eps = 5e-3
# Uses phi1 (l1 norm) from an earlier section.
# NOTE: overwrites x2, u2 from the basis-pursuit run.
x2, u2 = salsa.c_salsa(
    A, At, AtApI_solve, tst_col, eps, phi1, phi_map2,
    x0=mmse_x, rtol=1e-4, n_iter=400, save_sol=True
    )