Example #1
0
def test_dgapl21l1():
    """Test duality gap for L21 + L1 regularization.

    For alpha = alpha_max the zero solution is optimal, so its duality gap
    must be 0; for smaller alphas the solver must converge (gap ~ 0) to a
    non-zero solution.
    """
    n_orient = 2
    M, G, active_set = _generate_tf_data()
    n_times = M.shape[1]
    n_sources = G.shape[1]
    tstep, wsize = 4, 32
    n_steps = int(np.ceil(n_times / float(tstep)))
    n_freqs = wsize // 2 + 1
    n_coefs = n_steps * n_freqs
    phi = _Phi(wsize, tstep, n_coefs)
    phiT = _PhiT(tstep, n_freqs, n_steps, n_times)

    for l1_ratio in [0.05, 0.1]:
        alpha_max = norm_epsilon_inf(G, M, phi, l1_ratio, n_orient)
        alpha_space = (1. - l1_ratio) * alpha_max
        alpha_time = l1_ratio * alpha_max

        Z = np.zeros([n_sources, n_coefs])
        shape = (-1, n_steps, n_freqs)
        # for alpha = alpha_max, Z = 0 is the solution so the dgap is 0
        gap = dgap_l21l1(M, G, Z, np.ones(n_sources, dtype=bool), alpha_space,
                         alpha_time, phi, phiT, shape, n_orient, -np.inf)[0]
        assert_allclose(0., gap)

        # check that solutions for alphas smaller than alpha_max are non 0
        for divisor in (1.01, 5.):
            X_hat_tf, active_set_hat_tf, E, gap = tf_mixed_norm_solver(
                M, G, alpha_space / divisor, alpha_time / divisor,
                maxit=200, tol=1e-8, verbose=True, debias=False,
                n_orient=n_orient, tstep=tstep, wsize=wsize,
                return_gap=True)
            # allow possible small numerical errors (negative gap)
            assert_array_less(-1e-10, gap)
            assert_array_less(gap, 1e-8)
            # NOTE(review): assumes the solver returns a boolean mask over
            # all n_sources (as MNE's tf_mixed_norm_solver does); len() of
            # such a mask is constant, so count the active entries instead.
            assert_array_less(1, np.sum(active_set_hat_tf))
Example #2
0
def test_norm_epsilon():
    """Test computation of epsilon norm on TF coefficients."""
    tstep = np.array([2])
    wsize = np.array([4])
    n_times = 10
    n_steps = np.ceil(n_times / tstep.astype(float)).astype(int)
    n_freqs = wsize // 2 + 1
    n_coefs = n_steps * n_freqs
    phi = _Phi(wsize, tstep, n_coefs, n_times)
    # n_steps / n_freqs / n_coefs are length-1 arrays (a single STFT
    # dictionary); take the scalar total explicitly instead of relying on the
    # deprecated implicit conversion of size-1 arrays to scalars.
    n_coefs_total = int(n_coefs[0])

    Y = np.zeros(n_coefs_total)
    l1_ratio = 0.03
    assert_allclose(norm_epsilon(Y, l1_ratio, phi), 0.)

    Y[0] = 2.
    assert_allclose(norm_epsilon(Y, l1_ratio, phi), np.max(Y))

    l1_ratio = 1.
    assert_allclose(norm_epsilon(Y, l1_ratio, phi), np.max(Y))
    # dummy value without random:
    Y = np.arange(n_coefs_total)
    l1_ratio = 0.0
    assert_allclose(
        norm_epsilon(Y, l1_ratio, phi) ** 2,
        stft_norm2(Y.reshape(-1, n_freqs[0], n_steps[0])))

    l1_ratio = 0.03
    # test that vanilla epsilon norm = weights equal to 1
    w_time = np.ones(n_coefs_total)
    # seeded so the test is reproducible
    rng = np.random.RandomState(0)
    Y = np.abs(rng.randn(n_coefs_total))
    assert_allclose(norm_epsilon(Y, l1_ratio, phi),
                    norm_epsilon(Y, l1_ratio, phi, w_time=w_time))

    # scaling w_time and w_space by the same amount should divide
    # epsilon norm by the same amount
    Y = np.arange(n_coefs_total) + 1
    mult = 2.
    assert_allclose(
        norm_epsilon(Y, l1_ratio, phi, w_space=1,
                     w_time=np.ones(n_coefs_total)) / mult,
        norm_epsilon(Y, l1_ratio, phi, w_space=mult,
                     w_time=mult * np.ones(n_coefs_total)))
Example #3
0
def test_dgapl21l1():
    """Test duality gap for L21 + L1 regularization.

    Uses two STFT dictionaries. For alpha = alpha_max the zero solution is
    optimal, so its duality gap must be 0; for smaller alphas the solver must
    converge (gap ~ 0) to a non-zero solution.
    """
    n_orient = 2
    M, G, active_set = _generate_tf_data()
    n_times = M.shape[1]
    n_sources = G.shape[1]
    tstep, wsize = np.array([4, 2]), np.array([64, 16])
    n_steps = np.ceil(n_times / tstep.astype(float)).astype(int)
    n_freqs = wsize // 2 + 1
    n_coefs = n_steps * n_freqs
    phi = _Phi(wsize, tstep, n_coefs)
    phiT = _PhiT(tstep, n_freqs, n_steps, n_times)

    for l1_ratio in [0.05, 0.1]:
        alpha_max = norm_epsilon_inf(G, M, phi, l1_ratio, n_orient)
        alpha_space = (1. - l1_ratio) * alpha_max
        alpha_time = l1_ratio * alpha_max

        Z = np.zeros([n_sources, phi.n_coefs.sum()])
        # for alpha = alpha_max, Z = 0 is the solution so the dgap is 0
        gap = dgap_l21l1(M, G, Z, np.ones(n_sources, dtype=bool),
                         alpha_space, alpha_time, phi, phiT,
                         n_orient, -np.inf)[0]
        assert_allclose(0., gap)

        # check that solutions for alphas smaller than alpha_max are non 0
        for divisor in (1.01, 5.):
            X_hat_tf, active_set_hat_tf, E, gap = tf_mixed_norm_solver(
                M, G, alpha_space / divisor, alpha_time / divisor,
                maxit=200, tol=1e-8, verbose=True, debias=False,
                n_orient=n_orient, tstep=tstep, wsize=wsize,
                return_gap=True)
            # allow possible small numerical errors (negative gap)
            assert_array_less(-1e-10, gap)
            assert_array_less(gap, 1e-8)
            # NOTE(review): assumes the solver returns a boolean mask over
            # all n_sources (as MNE's tf_mixed_norm_solver does); len() of
            # such a mask is constant, so count the active entries instead.
            assert_array_less(1, np.sum(active_set_hat_tf))
Example #4
0
def test_norm_epsilon():
    """Test computation of epsilon norm on TF coefficients."""
    tstep = np.array([2])
    wsize = np.array([4])
    n_times = 10
    n_steps = np.ceil(n_times / tstep.astype(float)).astype(int)
    n_freqs = wsize // 2 + 1
    n_coefs = n_steps * n_freqs
    phi = _Phi(wsize, tstep, n_coefs)
    # n_steps / n_freqs / n_coefs are length-1 arrays (a single STFT
    # dictionary); take the scalar total explicitly instead of relying on the
    # deprecated implicit conversion of size-1 arrays to scalars.
    n_coefs_total = int(n_coefs[0])

    Y = np.zeros(n_coefs_total)
    l1_ratio = 0.5
    assert_allclose(norm_epsilon(Y, l1_ratio, phi), 0.)

    Y[0] = 2.
    assert_allclose(norm_epsilon(Y, l1_ratio, phi), np.max(Y))

    l1_ratio = 1.
    assert_allclose(norm_epsilon(Y, l1_ratio, phi), np.max(Y))
    # dummy value without random:
    Y = np.arange(n_coefs_total)
    l1_ratio = 0.
    assert_allclose(norm_epsilon(Y, l1_ratio, phi) ** 2,
                    stft_norm2(Y.reshape(-1, n_freqs[0], n_steps[0])))
Example #5
0
def test_norm_epsilon():
    """Test computation of epsilon norm on TF coefficients."""
    tstep = np.array([2])
    wsize = np.array([4])
    n_times = 10
    n_steps = np.ceil(n_times / tstep.astype(float)).astype(int)
    n_freqs = wsize // 2 + 1
    n_coefs = n_steps * n_freqs
    phi = _Phi(wsize, tstep, n_coefs)
    # n_steps / n_freqs / n_coefs are length-1 arrays (a single STFT
    # dictionary); take the scalar total explicitly instead of relying on the
    # deprecated implicit conversion of size-1 arrays to scalars.
    n_coefs_total = int(n_coefs[0])

    Y = np.zeros(n_coefs_total)
    l1_ratio = 0.5
    assert_allclose(norm_epsilon(Y, l1_ratio, phi), 0.)

    Y[0] = 2.
    assert_allclose(norm_epsilon(Y, l1_ratio, phi), np.max(Y))

    l1_ratio = 1.
    assert_allclose(norm_epsilon(Y, l1_ratio, phi), np.max(Y))
    # dummy value without random:
    Y = np.arange(n_coefs_total)
    l1_ratio = 0.
    assert_allclose(norm_epsilon(Y, l1_ratio, phi) ** 2,
                    stft_norm2(Y.reshape(-1, n_freqs[0], n_steps[0])))
snr = Simu_obj['snr']
X = Simu_obj['X']
coef = Simu_obj['coef']
Z_true = Simu_obj['Z_true']


n_channels, n_times = evoked_list[0].data.shape
n_trials = len(evoked_list)

# STFT parameters and derived sizes
wsize, tstep = 16, 4
n_step = int(np.ceil(n_times / float(tstep)))
n_freq = wsize // 2 + 1
n_coefs = n_step * n_freq
p = 2
n_coefs_all = n_coefs * p
phi = _Phi(wsize, tstep, n_coefs)
phiT = _PhiT(tstep, n_freq, n_step, n_times)

G = fwd['sol']['data']
n_dipoles = G.shape[1]
# Groups: the first two ROIs are one group each; every dipole outside all
# ROIs forms its own single-dipole group.
# NOTE(review): only label_ind[0] and label_ind[1] become groups, while the
# exclusion below removes dipoles of *every* ROI; with more than two ROIs the
# extra ROIs' dipoles would end up in no group -- confirm two ROIs are intended.
DipoleGroup = [label_ind[0], label_ind[1]]
non_label_ind = np.arange(n_dipoles)
for roi_ind in label_ind:
    non_label_ind = np.setdiff1d(non_label_ind, roi_ind)
DipoleGroup.extend(np.array([dipole]) for dipole in non_label_ind)

true_active_set = np.union1d(label_ind[0], label_ind[1])
# uniform group weights (np.float was removed in NumPy 1.24; use float)
DipoleGroupWeight = np.ones(len(DipoleGroup)) / float(len(DipoleGroup))
Example #7
0
def solve_stft_regression_L2_tsparse(M, G_list, G_ind, X, Z0,
                                     active_set_z0, active_t_ind_z0,
                                     coef_non_zero_mat,
                                     wsize=16, tstep=4, delta=0,
                                     maxit=200, tol=1e-3,
                                     lipschitz_constant=None,
                                     Flag_backtrack=True, eta=1.5, L0=1.0,
                                     Flag_verbose=False):
    """
    Solve the L2-penalized STFT regression on a fixed support.

    Uses accelerated gradient descent (FISTA without a non-smooth penalty)
    on the given active set to minimise
        1/2||R||_F^2 + delta ||Z||_F^2
    Entries where coef_non_zero_mat is zero are kept at zero.

    Input:
       M, [n_channels, n_times, n_trials] array of the sensor data
       G_list, a list of [n_channels, n_dipoles] forward gain matrices
       G_ind, [n_trials], the index of the G in G_list used for each trial
       X, [n_trials, p], the design matrix; it must include an all-ones column
       Z0, [n_active_dipoles, p*n_freq*n_active_step], initial value
       active_set_z0, [n_dipoles,] a boolean array, the active set of dipoles
       active_t_ind_z0, [n_step,] a boolean array, the active set of time
           steps (the union over all frequencies, columns of X and dipoles)
       coef_non_zero_mat, boolean mask with the same shape as Z0.
           active_set_z0 / active_t_ind_z0 may be a superset of the true
           support; this mask is assumed to be the UNION of the non-zero
           coefficients over all columns of X and all trials.
       wsize, window size of the STFT
       tstep, length of the time step
       delta, the L2 regularization parameter
       maxit, maximum number of iterations allowed
       tol, tolerance on the relative change of the iterate
       lipschitz_constant, the Lipschitz constant; only used when
           Flag_backtrack is False (estimated if None)
       Flag_backtrack, if True, use backtracking starting from L0
       eta, factor by which L grows during backtracking
       L0, initial 1/stepsize for backtracking
       Flag_verbose, if True, print progress at every iteration

       No trial-by-trial flag is supported in this version.
    Output:
       Z, the final FISTA extrapolated point Y with masked entries set to 0
       objz, the smooth objective f_l2 at the last gradient-step iterate
           (NOTE(review): evaluated at Z, not at the returned Y)
    """
    n_sensors, n_times, q = M.shape
    n_dipoles = G_list[0].shape[1]
    p = X.shape[1]
    n_step = int(np.ceil(n_times / float(tstep)))
    n_freq = wsize // 2 + 1
    n_coefs = n_step * n_freq

    # the mask is already given for all p columns of X
    coef_non_zero_mat_full = coef_non_zero_mat
    phi = _Phi(wsize, tstep, n_coefs)
    phiT = _PhiT(tstep, n_freq, n_step, n_times)
    sparse_phi = sparse_Phi(wsize, tstep)
    sparse_phiT = sparse_PhiT(tstep, n_freq, n_step, n_times)

    if (active_t_ind_z0.sum() * n_freq * p != Z0.shape[1]
            or n_dipoles != len(active_set_z0)):
        # same space-separated output as the former Python 2 print statement
        print("%s %s %s %s" % (active_t_ind_z0.sum() * n_freq * p,
                               Z0.shape[1], n_dipoles, len(active_set_z0)))
        raise ValueError("wrong number of dipoles or coefs")

    if lipschitz_constant is None and not Flag_backtrack:
        lipschitz_constant = 1.1 * (get_lipschitz_const(
            M, G_list[0], X, phi, phiT, Flag_trial_by_trial=False,
            n_coefs=n_coefs, tol=1e-3) + 2.0 * delta)
        print("lipschitz_constant = %e" % lipschitz_constant)
    # L is 1/stepsize
    L = L0 if Flag_backtrack else lipschitz_constant

    # initialization: masked entries are forced to zero
    active_t_ind_z = active_t_ind_z0.copy()
    active_set_z = active_set_z0.copy()
    Z = Z0.copy()
    Z[coef_non_zero_mat_full == 0] = 0
    Y = Z0.copy()
    Y[coef_non_zero_mat_full == 0] = 0
    n_active_dipole = active_set_z.sum()
    # number of coefficients per column of X, and in total
    n_coefs_z = n_freq * active_t_ind_z.sum()
    n_coefs_all_active = p * n_coefs_z

    tau = 1.0
    obj = np.inf

    # split trials by run
    # NOTE(review): assumes G_ind takes the values 0..n_run-1; otherwise
    # trials of runs with larger indices would be silently dropped.
    n_run = len(np.unique(G_ind))
    M_list = list()
    X_list = list()
    for run_id in range(n_run):
        M_list.append(M[:, :, G_ind == run_id])
        X_list.append(X[G_ind == run_id, :])

    # the first part of the gradient is fixed: -G^T (sum_r M(r)) PhiT
    gradient_y0 = get_gradient0_tsparse(M_list, G_list, X_list, p, n_run,
                                        n_active_dipole, active_set_z, n_times,
                                        n_coefs_z, n_coefs_all_active,
                                        active_t_ind_z,
                                        sparse_phi, sparse_phiT)

    for i in range(maxit):
        Z0 = Z.copy()
        # second part of the gradient, depends on Y:
        # +G^T G (sum_r X_k^(r) sum_k Z_k X_k^r) PhiT Phi
        gradient_y1 = get_gradient1_tsparse(M_list, G_list, X_list, Y, p,
                                            n_run, n_active_dipole,
                                            active_set_z, n_times,
                                            n_coefs_z, n_coefs_all_active,
                                            active_t_ind_z,
                                            sparse_phi, sparse_phiT)
        gradient_y = gradient_y0 + gradient_y1
        # gradient of the L2 penalty
        gradient_y += 2 * delta * Y
        gradient_y[coef_non_zero_mat_full == 0] = 0

        # gradient step Z = Y - gradient_y / L
        Z = Y - (gradient_y / L)
        objz = f_l2(Z, active_set_z, M, G_list, G_ind, X, n_coefs, q, p,
                    phiT, delta)

        if Flag_backtrack:
            objy = f_l2(Y, active_set_z, M, G_list, G_ind, X, n_coefs, q, p,
                        phiT, delta)
            # Backtracking for a pure gradient step: increase L until
            # f(Y - g/L) <= f(Y) - 0.5/L ||g||^2, see
            # https://www.cs.cmu.edu/~ggordon/10725-F12/slides/05-gd-revisited.pdf
            # page 10
            grad_sq = np.sum(np.abs(gradient_y) ** 2)
            diff_bt = objz - objy + (0.5 / L) * grad_sq
            while diff_bt > 0:
                L = L * eta
                Z = Y - (gradient_y / L)
                objz = f_l2(Z, active_set_z, M, G_list, G_ind, X, n_coefs,
                            q, p, phiT, delta)
                diff_bt = objz - objy + (0.5 / L) * grad_sq

        # ==== FISTA step, update the variables =====
        tau0 = tau
        tau = 0.5 * (1 + np.sqrt(4 * tau ** 2 + 1))
        diff = Z - Z0
        Y = Z + (tau0 - 1.0) / tau * diff

        # ===== progress and stopping criterion =====
        old_obj = obj
        # "obj" tracks the norm of the gradient, not the objective
        obj = np.linalg.norm(gradient_y)
        diff_obj = old_obj - obj
        relative_diff = np.abs(diff).sum() / np.sum(abs(Y))
        if Flag_verbose:
            print("\n iteration %d" % i)
            print("sum sq gradient = %f" % obj)
            print("diff_obj = %e" % (diff_obj))
            print("diff = %e" % relative_diff)
            print("obj = %f" % objz)
        if relative_diff < tol:
            print("convergence reached!")
            break
    Z = Y.copy()
    Z[coef_non_zero_mat_full == 0] = 0
    return Z, objz
Example #8
0
def select_delta_stft_regression_cv(M, G_list, G_ind, X, Z00,
                                    active_set_z0, active_t_ind_z0,
                                    coef_non_zero_mat,
                                    delta_seq, cv_partition_ind,
                                    wsize=16, tstep=4,
                                    maxit=200, tol=1e-3,
                                    Flag_backtrack=True, L0=1.0, eta=1.5,
                                    Flag_verbose=False):
    ''' Find the best L2 regularization parameter delta by cross validation.

    For each fold, solve_stft_regression_L2_tsparse is fit on the training
    trials for every delta in delta_seq, and the fitted regression
    coefficients are scored with get_MSE_stft_regresion_tsparse on the
    held-out trials. Within a fold, the solution for one delta is the warm
    start for the next; every fold restarts from Z00.

    Input:
        M, [n_channels, n_times, n_trials] array of the sensor data
        G_list, a list of [n_channels, n_dipoles] forward gain matrices
        G_ind, [n_trials], the index of the G in G_list used for each trial
        X, [n_trials, p], the design matrix; it must include an all-ones column
        Z00, [n_active_dipoles, p*n_freq*n_active_step], initial value
        e.g.
            # initial value
                Z00 = (np.random.randn(n_true_dipoles, n_coefs_all_active) \
                + np.random.randn(n_true_dipoles, n_coefs_all_active)*1j)*1E-15
                n_coefs_all_active = active_t_ind_z0.sum()*n_freq*pq

        active_set_z0, [n_dipoles,] a boolean array, the active set of dipoles
        active_t_ind_z0, [n_step,] a boolean array, the active set of time
            steps (the union over all frequencies, columns of X and dipoles)
        coef_non_zero_mat, boolean mask with the same shape as Z00;
            active_set_z0 / active_t_ind_z0 may be a superset of the true
            support, so this mask is assumed to be the UNION of the non-zero
            coefficients over all columns of X and all trials
        delta_seq, np.array, a sequence of delta to be tested
        cv_partition_ind, [n_trials,] the cross-validation group label of
            each trial (any labels; each unique label is one fold)
        wsize, window size of the STFT
        tstep, length of the time step
        maxit, maximum number of iterations
        tol, tolerance
        Flag_backtrack, L0, eta, Flag_verbose, passed to the solver
    Output:
       delta_star, the delta with the smallest fold-averaged CV MSE
       cv_MSE, [n_delta,] the fold-averaged CV MSE for each element of
           delta_seq
    '''
    _, n_times, _ = M.shape
    n_dipoles = G_list[0].shape[1]
    if len(active_set_z0) != n_dipoles:
        raise ValueError("the number of dipoles does not match")
    p = X.shape[1]
    n_step = int(np.ceil(n_times / float(tstep)))
    n_freq = wsize // 2 + 1
    n_coefs = n_step * n_freq
    # number of columns of Z restricted to the active time steps
    n_coefs_all_active = p * active_t_ind_z0.sum() * n_freq

    # The fixed Lipschitz constant is only used by the solver when
    # backtracking is disabled (with backtracking it starts from L0), so skip
    # the costly estimate otherwise.
    lipschitz_constant0 = None
    if not Flag_backtrack:
        phi = _Phi(wsize, tstep, n_coefs)
        phiT = _PhiT(tstep, n_freq, n_step, n_times)
        # n_coefs by keyword, consistent with the solvers' calls
        lipschitz_constant0 = get_lipschitz_const(
            M, G_list[0], X, phi, phiT, n_coefs=n_coefs,
            tol=1e-3, Flag_trial_by_trial=False)

    cv_partition_ind = np.asarray(cv_partition_ind)
    # iterate over the actual fold labels rather than assuming 0..n_fold-1
    folds = np.unique(cv_partition_ind)
    n_delta = len(delta_seq)
    cv_MSE = np.zeros([len(folds), n_delta])

    for j, fold in enumerate(folds):
        # partition
        test_trials = np.nonzero(cv_partition_ind == fold)[0]
        train_trials = np.nonzero(cv_partition_ind != fold)[0]
        Mtrain = M[:, :, train_trials]
        Xtrain = X[train_trials, :]
        Mtest = M[:, :, test_trials]
        Xtest = X[test_trials, :]
        G_ind_train = G_ind[train_trials]
        G_ind_test = G_ind[test_trials]

        Z0 = Z00
        for i, tmp_delta in enumerate(delta_seq):
            # Lipschitz constant of the penalized smooth objective
            L = None
            if lipschitz_constant0 is not None:
                L = (lipschitz_constant0 + 2 * tmp_delta) * 1.1
            # training
            Z, _ = solve_stft_regression_L2_tsparse(
                Mtrain, G_list, G_ind_train, Xtrain, Z0,
                active_set_z0, active_t_ind_z0, coef_non_zero_mat,
                wsize=wsize, tstep=tstep, delta=tmp_delta,
                maxit=maxit, tol=tol, lipschitz_constant=L,
                Flag_backtrack=Flag_backtrack, L0=L0, eta=eta,
                Flag_verbose=Flag_verbose)
            # only take the regression coefficients out
            Z_star = Z[:, :n_coefs_all_active]
            # warm start for the next delta
            Z0 = Z.copy()
            # testing
            tmp_val, _, _, _ = get_MSE_stft_regresion_tsparse(
                Mtest, G_list, G_ind_test, Xtest,
                Z_star, active_set_z0, active_t_ind_z0,
                wsize=wsize, tstep=tstep)
            cv_MSE[j, i] = tmp_val

    cv_MSE = np.mean(cv_MSE, axis=0)
    best_ind = np.argmin(cv_MSE)
    delta_star = delta_seq[best_ind]
    return delta_star, cv_MSE
          
Example #9
0
def solve_stft_regression_tree_group(M, G_list, G_ind, X,
                                     alpha, beta, gamma,
                                     DipoleGroup, DipoleGroupWeight,
                                     Z_ini, active_set_z_ini,
                                     n_orient=1, wsize=16, tstep=4,
                                     maxit=200, tol=1e-3,
                                     lipschitz_constant=None,
                                     Flag_backtrack=True, eta=1.5, L0=1.0,
                                     Flag_verbose=False):
    """
    Solve the STFT regression with the hard-coded tree-group penalty (FISTA).

    Input:
       M, [n_sensors, n_times, n_trials] array of the sensor data
       G_list, a list of [n_sensors, n_dipoles] forward gain matrices
       G_ind, [n_trials], index of run number (which G to use)
       X, [n_trials, p], the design matrix; it must include an all-ones column
       alpha, tuning parameter for the regularization
       beta, tuning parameter balancing single frequency-time bases and dipoles
       gamma, the penalty on the absolute value of the entries in Z
       DipoleGroup, grouping of dipoles:
                    dipoles in the same ROI are in the same group,
                    dipoles outside ROIs form one-dipole groups
       DipoleGroupWeight, weights of each dipole group (ROI)
       Z_ini, [n_active_dipoles, n_coefs*p], initial value of Z (rows of the
              active set), ravel order [n_dipoles, p, n_freqs, n_step]
       active_set_z_ini, [n_dipoles,] boolean, active set of dipoles
       n_orient, number of orientations for each dipole.
                 If n_orient == 3, DipoleGroup still indexes columns of G,
                 but must consist of whole consecutive orientation triplets,
                 and the initial active set must contain all or none of the
                 orientations of each dipole.
       wsize, window size of the STFT
       tstep, step in time of the STFT
       maxit, maximum number of iterations allowed
       tol, tolerance on the relative change of Z and of the objective
       lipschitz_constant, the Lipschitz constant (only used when
                 Flag_backtrack is False; estimated if None)
       Flag_backtrack, if True, use backtracking instead of a constant
                 stepsize (the Lipschitz constant)
       eta, L0, the growth factor of L and the initial 1/stepsize
       Flag_verbose, if True, print the objective values and differences
    Output:
        Z, the solution, only the rows in the active set. Rows whose
           coefficients are all zero are dropped, so not all rows of a group
           are guaranteed to be kept.
           NOTE(review): this is the final FISTA extrapolated point Y, not
           the last proximal iterate.
        active_set_z, a boolean vector, the active dipoles (rows)
        obj, the objective (smooth part + tree norm) at the last proximal
           iterate
        If the proximal step yields an empty active set, (0, 0, 0) is
        returned instead.
    Raises:
        ValueError, if the n_orient == 3 structure of the active set or of
        DipoleGroup is violated, or Z_ini does not match the active set.
    """
    # check the active set structure and group structure for n_orient == 3
    if n_orient == 3:
        active_set_mat = np.reshape(active_set_z_ini, [-1, n_orient])
        # all or none of the orientations of a dipole must be active
        # (compare with "!=": boolean "-" raises TypeError in modern NumPy)
        if np.any(np.any(active_set_mat, axis=1) !=
                  np.all(active_set_mat, axis=1)):
            raise ValueError("wrong active set for n_orient = 3")
        # each group must be made of consecutive index triplets
        for group in DipoleGroup:
            if np.remainder(len(group), n_orient) != 0:
                raise ValueError("wrong group")
            tmp_mat = np.reshape(group, [-1, n_orient], order='C')
            if np.any(np.diff(tmp_mat, axis=1) != 1):
                raise ValueError("wrong group")

    n_sensors, n_times, q = M.shape
    n_dipoles = G_list[0].shape[1]
    p = X.shape[1]
    n_step = int(np.ceil(n_times / float(tstep)))
    n_freq = wsize // 2 + 1
    n_coefs = n_step * n_freq

    # STFT and inverse STFT operators
    phi = _Phi(wsize, tstep, n_coefs)
    phiT = _PhiT(tstep, n_freq, n_step, n_times)
    if lipschitz_constant is None and not Flag_backtrack:
        lipschitz_constant = 1.1 * get_lipschitz_const(
            M, G_list[0], X, phi, phiT, Flag_trial_by_trial=False,
            n_coefs=n_coefs, tol=1e-3)
        print("lipschitz_constant = %e" % lipschitz_constant)
    # L is 1/stepsize
    L = L0 if Flag_backtrack else lipschitz_constant

    # Z and Y only store the rows of their active sets
    Z = Z_ini.copy()
    Y = Z_ini.copy()
    active_set_z = active_set_z_ini.copy()
    active_set_y = active_set_z_ini.copy()
    if Z.shape[0] != active_set_z.sum() or Z.shape[1] != n_coefs * p:
        raise ValueError('Z0 shape does not match active sets')

    tau = 1.0
    obj = np.inf

    # split trials by run
    # NOTE(review): assumes G_ind takes the values 0..n_run-1
    n_run = len(np.unique(G_ind))
    M_list = [M[:, :, G_ind == run_id] for run_id in range(n_run)]
    X_list = [X[G_ind == run_id, :] for run_id in range(n_run)]

    def _prox_step(grad, Y_rows, rows_y, step_L):
        # proximal operator of the tree norm at Y - grad/L (full matrix)
        point = -grad / step_L
        point[rows_y, :] += Y_rows
        return prox_tree_hard_coded_full_matrix(
            point, n_coefs, p, alpha / step_L, beta / step_L, gamma / step_L,
            DipoleGroup, DipoleGroupWeight, n_orient)

    def _backtrack_gap(f_z, f_y, grad_rows, diff, step_L):
        # f(z) - [f(y) + <grad, z - y> + L/2 ||z - y||^2], real inner product
        return (f_z - f_y
                - np.sum(np.real(grad_rows) * np.real(diff))
                - np.sum(np.imag(grad_rows) * np.imag(diff))
                - 0.5 * step_L * np.sum(np.abs(diff) ** 2))

    # the first part of the gradient is fixed: -G^T (sum_r M(r)) PhiT,
    # computed on all dipoles
    gradient_y0 = get_gradient0(M_list, G_list, X_list, p, n_run,
                                n_dipoles, np.ones(n_dipoles, dtype=bool),
                                n_times, n_coefs, n_coefs * p, phi, phiT)
    for i in range(maxit):
        Z0 = Z.copy()
        active_set_z0 = active_set_z.copy()
        # only needs the active rows of Y, but returns a full matrix
        gradient_y1 = get_gradient1_L21(M_list, G_list, X_list, Y, p, n_run,
                                        int(active_set_y.sum()), active_set_y,
                                        n_times, n_coefs, n_coefs * p,
                                        phi, phiT)
        gradient_y = gradient_y0 + gradient_y1

        # ==== ISTA step, get the proximal operator ====
        Z, active_set_z = _prox_step(gradient_y, Y, active_set_y, L)
        if not np.any(active_set_z):
            print("active_set = 0")
            return 0, 0, 0
        objz = f(Z, active_set_z, M, G_list, G_ind, X, n_coefs, q, p, phiT)
        # Z - Y on the union of the active sets
        diff_z, active_set_diff_z = _add_z(
            [Z, -Y], np.vstack([active_set_z, active_set_y]))

        if Flag_backtrack:
            objy = f(Y, active_set_y, M, G_list, G_ind, X, n_coefs, q, p, phiT)
            diff_bt = _backtrack_gap(objz, objy,
                                     gradient_y[active_set_diff_z],
                                     diff_z, L)
            while diff_bt > 0:
                L = L * eta
                Z, active_set_z = _prox_step(gradient_y, Y, active_set_y, L)
                if not np.any(active_set_z):
                    print("active_set = 0")
                    return 0, 0, 0
                objz = f(Z, active_set_z, M, G_list, G_ind, X, n_coefs, q, p,
                         phiT)
                diff_z, active_set_diff_z = _add_z(
                    [Z, -Y], np.vstack([active_set_z, active_set_y]))
                diff_bt = _backtrack_gap(objz, objy,
                                         gradient_y[active_set_diff_z],
                                         diff_z, L)

        # ==== FISTA step, update the variables =====
        tau0 = tau
        tau = 0.5 * (1 + np.sqrt(4 * tau ** 2 + 1))
        if not np.any(active_set_diff_z):
            print("active_set = 0")
            return 0, 0, 0
        # y <- z + (tau0-1)/tau (z-z0) = (tau+tau0-1)/tau z - (tau0-1)/tau z0
        Y, active_set_y = _add_z(
            [(tau + tau0 - 1) / tau * Z, -(tau0 - 1) / tau * Z0],
            np.vstack([active_set_z, active_set_z0]))

        # ===== compute objective function, check stopping criteria ====
        old_obj = obj
        full_Z = np.zeros([n_dipoles, n_coefs * p], dtype=complex)
        full_Z[active_set_z, :] = Z
        obj = objz + get_tree_norm_hard_coded(full_Z, n_coefs, p,
                                              alpha, beta, gamma,
                                              DipoleGroup, DipoleGroupWeight,
                                              n_orient)
        diff_obj = old_obj - obj
        relative_diff = np.sum(np.abs(diff_z)) / np.sum(np.abs(Z))
        if Flag_verbose:
            print("\n iteration %d" % i)
            print("diff_obj = %e" % (diff_obj / obj))
            print("obj = %e" % obj)
            print("diff = %e" % relative_diff)
        if relative_diff < tol and np.abs(diff_obj / obj) < tol:
            print("convergence reached!")
            break
    Z = Y.copy()
    active_set_z = active_set_y.copy()
    return Z, active_set_z, obj
# np.int was removed in NumPy 1.24; it was an alias of the builtin int
active_set = active_set.astype(int)
n_valid_source = len(active_set)

# parameters for the STFT transform:
# tstep_phi: time step of STFT-R
# wsize_phi: window size of STFT-R
tstep_phi = 4
wsize_phi = 16
# number of time steps
n_step = int(np.ceil(n_times / float(tstep_phi)))
# number of frequencies
n_freq = wsize_phi // 2 + 1
# n_tfs: total number of time-frequency components
n_tfs = n_step * n_freq
# STFT and inverse STFT function from MNE-python
phi = _Phi(wsize_phi, tstep_phi, n_tfs)
phiT = _PhiT(tstep_phi, n_freq, n_step, n_times)

# Z_true: regression coefficients in the time-frequency domain.
# Each ROI's [n_freq, n_step, p] coefficients are transposed to
# [p, n_freq, n_step] and flattened to match the column layout of Z_true.
coef_per_ROI = np.ones([n_ROI, n_freq, n_step, p])
Z_true = np.tile(phi(stc0_data), [1, p])
for i in range(len(labels)):
    Z_true[label_ind[i]] *= np.reshape(
        np.transpose(coef_per_ROI[i], [2, 0, 1]), [1, -1])

# create the sensor data and source data for each trial
evoked_list = list()
stc_data_list = list()
label_all = labels[0]
for label in labels[1:]:
    label_all += label
    
Example #11
0
def compute_dual_gap(M, G_list, G_ind, X, Z, active_set,
                     alpha, beta, gamma,
                     DipoleGroup, DipoleGroupWeight, n_orient,
                     wsize=16, tstep=4):
    """
    Compute the duality gap, and check the feasibility of the dual function.

    The gradient of the smooth part is evaluated at the primal point. Then one
    subgradient of each penalty level is added to it, giving
    b = A^T u + sum_g D_g^T v_g:
      - alpha level: dipole groups (weighted by DipoleGroupWeight)
      - beta level: same dipole, same STFT coefficient across the p regressors
      - gamma level: individual entries
    b is then passed to get_feasibility, which (per the original comment)
    solves the feasibility problem by coordinate descent.

    Input:
        M, [n_sensors, n_times, n_trials] sensor data
        G_list, list of [n_sensors, n_dipoles] forward gain matrices, one per run
        G_ind, [n_trials] run index of each trial (which G to use)
        X, [n_trials, p] design matrix
        Z, [active_set.sum(), n_coefs*p] coefficients of the active rows
        active_set, [n_dipoles] boolean vector of active rows
        alpha, beta, gamma, penalty parameters (gamma must be non-zero)
        DipoleGroup, list of index arrays, one per dipole group
        DipoleGroupWeight, weight of each dipole group
        n_orient, number of orientations (1, otherwise treated as 3)
        wsize, tstep, window size and time step of the STFT
    Output:
        dict(feasibility_dist = feasibility_dist,
             gradient = gradient, b = b,
             feasibility_dist_DipoleGroup = feasibility_dist_DipoleGroup)
        feasibility_dist, the major criterion; if it is small enough,
            accept the results
        gradient, the gradient of the smooth part
        b, the gradient plus the chosen subgradients of the penalty
        feasibility_dist_DipoleGroup, per-group feasibility distance
    Raises:
        ValueError if gamma == 0, if an active dipole group is all zero,
        or if NaN appears in b.
    """
    if gamma == 0:
        raise ValueError("the third level must be penalized!!")

    n_dipoles = G_list[0].shape[1]
    _, n_times, _ = M.shape
    n_step = int(np.ceil(n_times / float(tstep)))
    n_freq = wsize // 2 + 1
    n_coefs = n_step * n_freq
    phi = _Phi(wsize, tstep, n_coefs)
    phiT = _PhiT(tstep, n_freq, n_step, n_times)
    _, p = X.shape
    n_active = int(active_set.sum())

    # split the trials by run; each run has its own forward matrix
    n_run = len(np.unique(G_ind))
    M_list = [M[:, :, G_ind == run_id] for run_id in range(n_run)]
    X_list = [X[G_ind == run_id, :] for run_id in range(n_run)]

    # gradient of the smooth part, on all dipoles
    gradient_y0 = get_gradient0(M_list, G_list, X_list, p, n_run,
                                n_dipoles, np.ones(n_dipoles, dtype=bool),
                                n_times, n_coefs, n_coefs * p, phi, phiT)
    gradient_y1 = get_gradient1_L21(M_list, G_list, X_list, Z, p, n_run,
                                    n_active, active_set, n_times,
                                    n_coefs, n_coefs * p, phi, phiT)
    gradient = gradient_y0 + gradient_y1

    # per-DIPOLE alpha penalty: alpha times the weight of the dipole's group
    alpha_weight = np.zeros(n_dipoles)
    for i, group in enumerate(DipoleGroup):
        alpha_weight[group] = alpha * DipoleGroupWeight[i]

    # feasibility check: b = A^T u + sum_g D_g^T v_g, g over non-zero groups
    b = gradient.copy()
    active_set_ind = np.nonzero(active_set)[0]
    Z_full = np.zeros([n_dipoles, Z.shape[1]], dtype=complex)
    Z_full[active_set, :] = Z

    def _normalized(norms, weight):
        # weight * Z / norms where the norm is non-zero, 0 elsewhere
        out = np.zeros([n_active, Z.shape[1]], dtype=complex)
        nonzero = norms > 0
        out[nonzero] = Z[nonzero] / norms[nonzero] * weight
        return out

    # alpha level: dipole groups that intersect the active set
    DipoleGroup_active = np.zeros(len(DipoleGroup), dtype=bool)
    for i, group in enumerate(DipoleGroup):
        if np.intersect1d(group, active_set_ind).size == 0:
            continue
        DipoleGroup_active[i] = True
        l2_norm_alpha = np.sqrt(np.sum(np.abs(Z_full[group, :]) ** 2))
        if l2_norm_alpha == 0:
            raise ValueError("all zero in an active group!")
        # subgradient of alpha*w_g*||Z_g||_2 is alpha*w_g*Z_g/||Z_g||_2.
        # Use the GROUP weight; alpha_weight is indexed by dipole, not group.
        b[group, :] += (Z_full[group, :] / l2_norm_alpha
                        * (alpha * DipoleGroupWeight[i]))

    # beta level: same dipole, same stft coef, across the p regressors
    if n_orient == 1:
        Z_reshape = np.reshape(Z, [n_active, p, -1])
        # n_active x n_coefs
        l2_norm_beta = np.sqrt(np.sum(np.abs(Z_reshape) ** 2, axis=1))
        # n_active x n_coefs*p
        l2_norm_beta_large = np.tile(l2_norm_beta, [1, p])
    else:  # n_orient == 3
        Z_reshape = np.reshape(Z, [n_active // 3, 3, p, -1])
        l2_norm_beta = np.sqrt(np.sum(np.abs(Z_reshape) ** 2, axis=(1, 2)))
        l2_norm_beta_large = np.reshape(np.tile(l2_norm_beta, [1, 3 * p]),
                                        Z.shape)
    b[active_set, :] += _normalized(l2_norm_beta_large, beta)

    # gamma level: each element of the matrix (all orientations jointly)
    if n_orient == 1:
        l2_norm_gamma = np.abs(Z)
    else:  # n_orient == 3
        Z_reshape = np.reshape(Z, [n_active // 3, 3, -1])
        l2_norm_gamma = np.sqrt(np.sum(np.abs(Z_reshape) ** 2, axis=1))
        # each dipole norm is shared by its 3 consecutive orientation rows
        # (the former tile over 3*Z.shape[1] columns had the wrong size)
        l2_norm_gamma = np.repeat(l2_norm_gamma, 3, axis=0)
    b[active_set, :] += _normalized(l2_norm_gamma, gamma)

    if np.any(np.isnan(b)):
        raise ValueError("nan found in b!")

    # solve the feasibility problem
    nonzero_Z = np.abs(Z) > 0
    feasibility_result = get_feasibility(b, active_set, DipoleGroup,
                                         alpha_weight, beta, gamma,
                                         n_coefs, p, nonzero_Z,
                                         DipoleGroup_active)
    feasibility_dist = feasibility_result['feasibility_dist']
    # per-dipole distances in the alpha level (described as "squared" in the
    # original code; NOTE(review): confirm, they are squared again below)
    feasibility_dist_in_alpha_level = \
        feasibility_result['feasibility_dist_in_alpha_level']
    feasibility_dist_DipoleGroup = np.zeros(len(DipoleGroup))
    for i, group in enumerate(DipoleGroup):
        feasibility_dist_DipoleGroup[i] = np.sqrt(
            np.sum(feasibility_dist_in_alpha_level[group] ** 2))

    return dict(feasibility_dist=feasibility_dist,
                gradient=gradient, b=b,
                feasibility_dist_DipoleGroup=feasibility_dist_DipoleGroup)
# Example #12
# 0
def solve_stft_regression_tree_group_tsparse(M, G_list, G_ind, X,
                                             alpha, beta, gamma,
                                             DipoleGroup, DipoleGroupWeight,
                                             Z_ini, active_set_z_ini,
                                             n_orient=1, wsize=16, tstep=4,
                                             maxit=200, tol=1e-3,
                                             lipschitz_constant=None,
                                             Flag_backtrack=True, eta=1.5,
                                             L0=1.0, Flag_verbose=False):
    """
    Solve the STFT regression with the tree-group penalty by FISTA.

    Input:
       M, [n_sensors,n_times,n_trials] array of the sensor data
       G_list, a list of [n_sensors, n_dipoles] forward gain matrices,
               for different runs if applicable
       G_ind, [n_trials], index of run number (which G to use)
       X, [n_trials, p] the design matrix, it must include an all 1 column
       alpha, tuning parameter for the regularization of the dipole groups
       beta, the tuning parameter for the frequency-time basis
       gamma, the penalty on the absolute value of entries in Z
       DipoleGroup, grouping of dipoles,
                    the dipoles in the same ROI are in the same group,
                    the dipoles outside ROIs form one-dipole groups
       DipoleGroupWeight, weights of each dipole group (ROI)
       Z_ini, [n_dipoles_active, n_coefs*p], initial value of Z,
              the ravel order is [n_dipoles, p, n_freqs, n_step]
       active_set_z_ini, [n_dipoles,] boolean, active set of dipoles
       n_orient, number of orientations for each dipole.
                 If n_orient == 3, DipoleGroup is still on the columns of G,
                 but every group must consist of runs of 3 consecutive
                 indices, and each dipole must have all or none of its
                 3 orientations in the initial active set.
       wsize, window size of the STFT
       tstep, step in time of the STFT
       maxit, maximum number of iterations allowed
       tol, tolerance on both the relative change ||Z - Y||_1 / ||Z||_1
            and the relative change of the objective function
       lipschitz_constant, the Lipschitz constant (computed if None and
            Flag_backtrack is False)
       Flag_backtrack, if True, use backtracking instead of a constant
            step size (the Lipschitz constant)
       eta, L0, the step-size growth factor and the initial 1/stepsize
       Flag_verbose, if True, print the objective values and differences
    Output:
        Z, the solution (the last proximal iterate), only rows in the
           active set. Rows whose coefficients are all zero may be dropped.
        active_set_z, boolean vector, active dipoles (rows)
        obj, the objective function evaluated at Z
        If a proximal step returns an empty active set, (0, 0, 0) is
        returned instead.
    Raises:
        ValueError if the n_orient == 3 structure of the active set or of
        DipoleGroup is violated, or if Z_ini does not match the active set.
    """
    # ---- check the active set / group structure for n_orient == 3 ----
    if n_orient == 3:
        active_set_mat = np.reshape(active_set_z_ini, [-1, n_orient])
        # each dipole must have either all of its orientations active or
        # none; compare with '!=' since numpy forbids '-' on bool arrays
        if np.any(np.any(active_set_mat, axis=1)
                  != np.all(active_set_mat, axis=1)):
            raise ValueError("wrong active set for n_orient = 3")
        # DipoleGroup must also satisfy the structure: runs of n_orient
        # consecutive column indices
        for group in DipoleGroup:
            if np.remainder(len(group), n_orient) != 0:
                raise ValueError("wrong group")
            tmp_mat = np.reshape(group, [-1, n_orient], order='C')
            if (np.any(tmp_mat[:, 2] - tmp_mat[:, 1] != 1)
                    or np.any(tmp_mat[:, 1] - tmp_mat[:, 0] != 1)):
                raise ValueError("wrong group")

    _, n_times, q = M.shape
    n_dipoles = G_list[0].shape[1]
    p = X.shape[1]
    n_step = int(np.ceil(n_times / float(tstep)))
    n_freq = wsize // 2 + 1
    n_coefs = n_step * n_freq

    # the sparse and non-sparse versions of the STFT / iSTFT
    phi = _Phi(wsize, tstep, n_coefs)
    phiT = _PhiT(tstep, n_freq, n_step, n_times)
    sparse_phi = sparse_Phi(wsize, tstep)
    sparse_phiT = sparse_PhiT(tstep, n_freq, n_step, n_times)

    # ---- step size ----
    if lipschitz_constant is None and not Flag_backtrack:
        lipschitz_constant = 1.1 * get_lipschitz_const(
            M, G_list[0], X, phi, phiT,
            Flag_trial_by_trial=False, n_coefs=n_coefs, tol=1e-3)
        print("lipschitz_constant = %e" % lipschitz_constant)
    L = L0 if Flag_backtrack else lipschitz_constant

    # indices for the active set are only for rows
    Z = Z_ini.copy()
    Y = Z_ini.copy()
    active_set_z = active_set_z_ini.copy()
    active_set_y = active_set_z_ini.copy()

    # The code supports sparse time steps (sparse_phi / sparse_phiT); for
    # consistency the active_t_ind variable is kept, but all time steps
    # are used.
    active_t_ind_full = np.ones(n_step, dtype=bool)
    if Z.shape[0] != active_set_z.sum() or Z.shape[1] != n_coefs * p:
        raise ValueError('Z0 shape does not match active sets')

    tau = 1.0
    obj = np.inf

    # split the trials by run
    n_run = len(np.unique(G_ind))
    M_list = [M[:, :, G_ind == run_id] for run_id in range(n_run)]
    X_list = [X[G_ind == run_id, :] for run_id in range(n_run)]

    # gradient part one is fixed: -G^T(sum_r M(r)) PhiT, on all dipoles
    gradient_y0 = get_gradient0_tsparse(M_list, G_list, X_list, p, n_run,
                                        n_dipoles,
                                        np.ones(n_dipoles, dtype=bool),
                                        n_times, n_coefs, n_coefs * p,
                                        active_t_ind_full,
                                        sparse_phi, sparse_phiT)

    def _prox_step(y, active_y, grad, step_L):
        # proximal-gradient step from y with step size 1/step_L;
        # the prox input must be a full [n_dipoles, n_coefs*p] matrix
        point = -grad / step_L
        point[active_y, :] += y
        return prox_tree_hard_coded_full_matrix(
            point, n_coefs, p, alpha / step_L, beta / step_L, gamma / step_L,
            DipoleGroup, DipoleGroupWeight, n_orient)

    def _backtrack_gap(objz, objy, grad, diff_z, active_diff, step_L):
        # f(z) - f(y) - <grad_y, z - y> - L/2 ||z - y||^2; the step is
        # accepted when this is <= 0 (diff_z is complex, so the inner
        # product is Re*Re + Im*Im). See Beck & Teboulle, FISTA.
        grad_rows = grad[active_diff]
        return (objz - objy
                - np.sum(np.real(grad_rows) * np.real(diff_z))
                - np.sum(np.imag(grad_rows) * np.imag(diff_z))
                - 0.5 * step_L * np.sum(np.abs(diff_z) ** 2))

    # ==== the main loop ====
    for i in range(maxit):
        Z0 = Z.copy()
        active_set_z0 = active_set_z.copy()
        # this part is only on the active set, but the result is full
        gradient_y1 = get_gradient1_L21_tsparse(M_list, G_list, X_list, Y,
                                                p, n_run,
                                                int(active_set_y.sum()),
                                                active_set_y, n_times,
                                                n_coefs, n_coefs * p,
                                                active_t_ind_full,
                                                sparse_phi, sparse_phiT)
        gradient_y = gradient_y0 + gradient_y1

        # ==== ISTA step ====
        Z, active_set_z = _prox_step(Y, active_set_y, gradient_y, L)
        if not np.any(active_set_z):
            print("active_set = 0")
            return 0, 0, 0
        objz = f(Z, active_set_z, M, G_list, G_ind, X, n_coefs, q, p, phiT)
        # Z - Y
        diff_z, active_set_diff_z = _add_z(
            [Z, -Y], np.vstack([active_set_z, active_set_y]))

        if Flag_backtrack:
            objy = f(Y, active_set_y, M, G_list, G_ind, X, n_coefs, q, p,
                     phiT)
            while _backtrack_gap(objz, objy, gradient_y, diff_z,
                                 active_set_diff_z, L) > 0:
                L = L * eta
                Z, active_set_z = _prox_step(Y, active_set_y, gradient_y, L)
                if not np.any(active_set_z):
                    print("active_set = 0")
                    return 0, 0, 0
                objz = f(Z, active_set_z, M, G_list, G_ind, X, n_coefs, q,
                         p, phiT)
                diff_z, active_set_diff_z = _add_z(
                    [Z, -Y], np.vstack([active_set_z, active_set_y]))

        # ==== FISTA step, update the variables ====
        tau0 = tau
        tau = 0.5 * (1 + np.sqrt(4 * tau ** 2 + 1))
        if not np.any(active_set_diff_z):
            print("active_set = 0")
            return 0, 0, 0
        # y <- z + (tau0-1)/tau (z-z0) = (tau+tau0-1)/tau z - (tau0-1)/tau z0
        Y, active_set_y = _add_z(
            [(tau + tau0 - 1) / tau * Z, -(tau0 - 1) / tau * Z0],
            np.vstack([active_set_z, active_set_z0]))

        # ==== compute objective function, check stopping criteria ====
        old_obj = obj
        full_Z = np.zeros([n_dipoles, n_coefs * p], dtype=complex)
        full_Z[active_set_z, :] = Z
        obj = objz + get_tree_norm_hard_coded(full_Z, n_coefs, p,
                                              alpha, beta, gamma,
                                              DipoleGroup, DipoleGroupWeight,
                                              n_orient)
        diff_obj = old_obj - obj
        relative_diff = np.sum(np.abs(diff_z)) / np.sum(np.abs(Z))
        if Flag_verbose:
            print("\n iteration %d" % i)
            print("diff_obj = %e" % (diff_obj / obj))
            print("obj = %e" % obj)
            print("diff = %e" % relative_diff)
        if relative_diff < tol and np.abs(diff_obj / obj) < tol:
            print("convergence reached!")
            break

    # Return the last proximal iterate Z, not the extrapolated point Y:
    # obj was evaluated at Z, and Y is built from both Z and Z0 (see the
    # _add_z call above), so it can keep rows the prox step already zeroed.
    return Z, active_set_z, obj