def obj_fun(x):
    """Smooth part of the objective: scaled least-squares loss plus L2 penalty.

    Computes ``0.5 * ||b - A x||^2 / A.shape[0] + 0.5 * l2_reg * x.x``.
    ``A``, ``b`` and ``l2_reg`` are read from module scope, not passed in.
    """
    return 0.5 * np.linalg.norm(b - A.dot(x)) ** 2 / A.shape[0] + 0.5 * l2_reg * x.dot(x)


def grad(x):
    """Gradient of ``obj_fun`` with respect to ``x``.

    Equals ``-A^T (b - A x) / A.shape[0] + l2_reg * x``; like ``obj_fun`` it
    reads ``A``, ``b`` and ``l2_reg`` from module scope.
    """
    return - A.T.dot(b - A.dot(x)) / A.shape[0] + l2_reg * x


# 2 x 3 grid of axes, one column per regularization strength. The top row
# (ax[0, i]) shows the recovered solution as an image; the bottom row is not
# drawn in the visible code -- presumably filled further down, TODO confirm.
f, ax = plt.subplots(2, 3, sharey=False)
# Weights of the TV(x) penalty to compare, one per column of the figure.
all_alphas = [1e-6, 1e-3, 1e-1]
# NOTE(review): not used in the visible code; presumably per-column x-axis
# limits for the bottom-row plots -- verify against the code that follows.
xlim = [0.02, 0.02, 0.1]
for i, alpha in enumerate(all_alphas):
    # Same iteration budget for both solvers so their traces are comparable.
    max_iter = 5000

    # The traced quantity is the full objective obj_fun(x) + alpha * TV(x).
    # NOTE(review): the lambda closes over the loop variable ``alpha``; this is
    # only correct if Trace evaluates it during this iteration (i.e. while
    # three_split runs), not after the loop has advanced -- confirm.
    trace_three = Trace(lambda x: obj_fun(x) + alpha * TV(x))
    # Three-operator splitting: the smooth term goes through ``grad``; the
    # penalty is handled by two prox operators (presumably row-wise and
    # column-wise 1-D TV, per their names), with ``alpha`` passed as both
    # ``alpha`` and ``beta``. tol=1e-16 presumably keeps it from stopping
    # before max_iter -- confirm against three_split.
    out_tos = three_split(
        obj_fun, grad, prox_tv1d_rows, prox_tv1d_cols, np.zeros(n_features),
        alpha=alpha, beta=alpha, g_prox_args=(n_rows, n_cols),
        h_prox_args=(n_rows, n_cols), callback=trace_three,
        max_iter=max_iter, tol=1e-16)

    # Baseline: proximal gradient with a single 2-D TV prox, started from the
    # same zero vector. The trailing (1000, 1e-1) in g_prox_args are presumably
    # the inner prox solver's iteration cap and tolerance -- confirm against
    # prox_tv2d. No ``tol`` is passed, so proximal_gradient's default applies.
    trace_gd = Trace(lambda x: obj_fun(x) + alpha * TV(x))
    out_gd = proximal_gradient(
        obj_fun, grad, prox_tv2d, np.zeros(n_features), alpha=alpha,
        g_prox_args=(n_rows, n_cols, 1000, 1e-1), max_iter=max_iter,
        callback=trace_gd)

    # Top row: the three-split solution reshaped to an (n_rows, n_cols) image.
    ax[0, i].set_title(r'$\lambda=%s$' % alpha)
    ax[0, i].imshow(out_tos.x.reshape((n_rows, n_cols)),
                    interpolation='nearest', cmap=plt.cm.Blues)
    ax[0, i].set_xticks(())
    ax[0, i].set_yticks(())

    # Smallest objective value recorded by either solver; presumably the
    # reference optimum for the convergence plots that follow -- TODO confirm.
    fmin = min(np.min(trace_three.values), np.min(trace_gd.values))