Example #1
def test_dgapl21l1():
    """Test duality gap for L21 + L1 regularization."""
    n_orient = 2
    M, G, active_set = _generate_tf_data()
    n_times = M.shape[1]
    n_sources = G.shape[1]
    tstep, wsize = 4, 32
    n_steps = int(np.ceil(n_times / float(tstep)))
    n_freqs = wsize // 2 + 1
    n_coefs = n_steps * n_freqs
    phi = _Phi(wsize, tstep, n_coefs)
    phiT = _PhiT(tstep, n_freqs, n_steps, n_times)

    for l1_ratio in [0.05, 0.1]:
        alpha_max = norm_epsilon_inf(G, M, phi, l1_ratio, n_orient)
        alpha_space = (1. - l1_ratio) * alpha_max
        alpha_time = l1_ratio * alpha_max

        Z = np.zeros([n_sources, n_coefs])
        shape = (-1, n_steps, n_freqs)
        # for alpha = alpha_max, Z = 0 is the solution so the dgap is 0
        gap = dgap_l21l1(M, G, Z, np.ones(n_sources, dtype=bool), alpha_space,
                         alpha_time, phi, phiT, shape, n_orient, -np.inf)[0]

        assert_allclose(0., gap)
        # check that solution for alpha smaller than alpha_max is non 0:
        X_hat_tf, active_set_hat_tf, E, gap = tf_mixed_norm_solver(
            M,
            G,
            alpha_space / 1.01,
            alpha_time / 1.01,
            maxit=200,
            tol=1e-8,
            verbose=True,
            debias=False,
            n_orient=n_orient,
            tstep=tstep,
            wsize=wsize,
            return_gap=True)
        # allow possible small numerical errors (negative gap)
        assert_array_less(-1e-10, gap)
        assert_array_less(gap, 1e-8)
        assert_array_less(1, len(active_set_hat_tf))

        X_hat_tf, active_set_hat_tf, E, gap = tf_mixed_norm_solver(
            M,
            G,
            alpha_space / 5.,
            alpha_time / 5.,
            maxit=200,
            tol=1e-8,
            verbose=True,
            debias=False,
            n_orient=n_orient,
            tstep=tstep,
            wsize=wsize,
            return_gap=True)
        assert_array_less(-1e-10, gap)
        assert_array_less(gap, 1e-8)
        assert_array_less(1, len(active_set_hat_tf))
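
The test above relies on alpha_max being the smallest penalty for which Z = 0 is optimal (zero duality gap), and on scaling it down slightly to obtain a non-trivial solution. The same mechanism can be illustrated on a plain-Lasso analogue with only numpy; this is a conceptual sketch, not the actual L21+L1 computation performed by norm_epsilon_inf.

import numpy as np

def soft_threshold(x, alpha):
    """Proximal operator of alpha * ||.||_1, applied elementwise."""
    return np.sign(x) * np.maximum(np.abs(x) - alpha, 0.0)

rng = np.random.default_rng(0)
G = rng.standard_normal((5, 3))
M = rng.standard_normal(5)

# For min 0.5*||M - G z||^2 + alpha*||z||_1, z = 0 is optimal iff
# alpha >= ||G^T M||_inf; alpha_max in the test plays the analogous role.
alpha_max = np.max(np.abs(G.T @ M))
assert np.allclose(soft_threshold(G.T @ M, alpha_max), 0.0)      # zero solution at alpha_max
assert np.any(soft_threshold(G.T @ M, alpha_max / 1.01) != 0.0)  # non-zero just below it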
def get_sensor_MSE_on_held_data_individual_G(Z, active_set, test_evoked_list,
                       test_X, G_list, G_ind_test, method, 
                       wsize = 16, tstep = 4):
    '''
    Compute the mean squared error on the held-out sensor data.
    This is done on the original sensor data, before pre-whitening and SSP.
    To make sure this is comparable for all three methods, I rewrote the computation here.
    Input:
        Z, [active_set.sum(), n_coefs*p]
        active_set,
        test_evoked_list, evoked list for testing
        test_X, [n_trials, p]
        G_list,
        G_ind_test
        method, "STFT-R" or "MNE-R"
        wsize, tstep
    Output: MSE_sensor
    '''
    # note that M is not whitened here
    n_channels, n_times = test_evoked_list[0].data.shape
    n_trials = len(test_evoked_list)
    p = test_X.shape[1]
    
    if method == "MNE-R":
        n_times = test_evoked_list[0].data.shape[1]
        n_step = int(np.ceil(n_times/float(tstep)))
        n_freq = wsize//2+1
        n_step = n_times // tstep
        n_coefs = n_freq*n_step
        phiT = _PhiT(tstep, n_freq, n_step, n_times)
        
    
    M = np.zeros([n_channels, n_times, n_trials ])
    for i in range(n_trials):
         M[:,:,i] = test_evoked_list[i].data 
 
    MSE_sensor = 0.0
    for i in range(n_trials):
        tmp_Z = np.reshape(Z, [active_set.sum(), p, n_coefs])
        tmp_Z = np.sum(  np.swapaxes(tmp_Z, 1,2) * test_X[i,:], axis = 2)
        tmp_source = phiT(tmp_Z)
        predicted = G_list[G_ind_test[i]][:,active_set].dot(tmp_source)     
        MSE_sensor +=  np.sum ( (M[:,:,i] - predicted)**2 )
        
    MSE_sensor /= n_trials 
    return MSE_sensor  
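
The key step inside the loop is the per-trial prediction: reshape Z to [dipole, regressor, coefficient], weight by that trial's row of test_X, apply the inverse STFT, and project through the forward matrix. A shape-only sketch with synthetic arrays follows; phiT is replaced by a stand-in that just takes the real part, so this illustrates the bookkeeping, not the actual inverse STFT.

import numpy as np

rng = np.random.default_rng(0)
n_active, p, n_coefs, n_channels, n_times = 4, 2, 6, 8, 6   # toy sizes, n_coefs == n_times here

Z_toy = rng.standard_normal((n_active, p * n_coefs)) * (1 + 1j)
x_trial = rng.standard_normal(p)                 # one row of test_X
G_active = rng.standard_normal((n_channels, n_active))
phiT_stub = lambda coefs: np.real(coefs)         # stand-in for _PhiT, shapes only

tmp_Z = Z_toy.reshape(n_active, p, n_coefs)                      # [dipole, regressor, coef]
tmp_Z = np.sum(np.swapaxes(tmp_Z, 1, 2) * x_trial, axis=2)       # weight by this trial's X
predicted = G_active.dot(phiT_stub(tmp_Z))                       # [n_channels, n_times]
print(predicted.shape)                                           # (8, 6)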
Example #3
def test_dgapl21l1():
    """Test duality gap for L21 + L1 regularization."""
    n_orient = 2
    M, G, active_set = _generate_tf_data()
    n_times = M.shape[1]
    n_sources = G.shape[1]
    tstep, wsize = np.array([4, 2]), np.array([64, 16])
    n_steps = np.ceil(n_times / tstep.astype(float)).astype(int)
    n_freqs = wsize // 2 + 1
    n_coefs = n_steps * n_freqs
    phi = _Phi(wsize, tstep, n_coefs)
    phiT = _PhiT(tstep, n_freqs, n_steps, n_times)

    for l1_ratio in [0.05, 0.1]:
        alpha_max = norm_epsilon_inf(G, M, phi, l1_ratio, n_orient)
        alpha_space = (1. - l1_ratio) * alpha_max
        alpha_time = l1_ratio * alpha_max

        Z = np.zeros([n_sources, phi.n_coefs.sum()])
        # for alpha = alpha_max, Z = 0 is the solution so the dgap is 0
        gap = dgap_l21l1(M, G, Z, np.ones(n_sources, dtype=bool),
                         alpha_space, alpha_time, phi, phiT,
                         n_orient, -np.inf)[0]

        assert_allclose(0., gap)
        # check that solution for alpha smaller than alpha_max is non 0:
        X_hat_tf, active_set_hat_tf, E, gap = tf_mixed_norm_solver(
            M, G, alpha_space / 1.01, alpha_time / 1.01, maxit=200, tol=1e-8,
            verbose=True, debias=False, n_orient=n_orient, tstep=tstep,
            wsize=wsize, return_gap=True)
        # allow possible small numerical errors (negative gap)
        assert_array_less(-1e-10, gap)
        assert_array_less(gap, 1e-8)
        assert_array_less(1, len(active_set_hat_tf))

        X_hat_tf, active_set_hat_tf, E, gap = tf_mixed_norm_solver(
            M, G, alpha_space / 5., alpha_time / 5., maxit=200, tol=1e-8,
            verbose=True, debias=False, n_orient=n_orient, tstep=tstep,
            wsize=wsize, return_gap=True)
        assert_array_less(-1e-10, gap)
        assert_array_less(gap, 1e-8)
        assert_array_less(1, len(active_set_hat_tf))
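
Compared with the first example, this variant stacks two STFT dictionaries (two tstep/wsize pairs), so the coefficient matrix Z has phi.n_coefs.sum() columns. A quick numpy check of that bookkeeping, assuming an arbitrary n_times of 100 for illustration:

import numpy as np

n_times = 100                                                  # assumed, for illustration only
tstep = np.array([4, 2])
wsize = np.array([64, 16])

n_steps = np.ceil(n_times / tstep.astype(float)).astype(int)   # [25, 50]
n_freqs = wsize // 2 + 1                                       # [33, 9]
n_coefs = n_steps * n_freqs                                    # [825, 450]
print(n_coefs, n_coefs.sum())                                  # total columns of Z: 1275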
X = Simu_obj['X']
coef = Simu_obj['coef']
Z_true = Simu_obj['Z_true']


n_channels, n_times = evoked_list[0].data.shape
n_trials = len(evoked_list)

wsize, tstep = 16,4
n_step = int(np.ceil(n_times/float(tstep)))
n_freq = wsize// 2+1
n_coefs = n_step*n_freq
p = 2
n_coefs_all = n_coefs*p
phi = _Phi(wsize, tstep, n_coefs)
phiT = _PhiT(tstep, n_freq, n_step, n_times)

G = fwd['sol']['data']
n_dipoles = G.shape[1]
DipoleGroup = list([label_ind[0], label_ind[1]])
non_label_ind = np.arange(0,n_dipoles,1)
for i in range(len(label_ind)):
    non_label_ind = np.setdiff1d(non_label_ind, label_ind[i])

for i in range(len(non_label_ind)):
    DipoleGroup.append(np.array([non_label_ind[i]]))
    
true_active_set = np.union1d(label_ind[0], label_ind[1])
DipoleGroupWeight = np.ones(len(DipoleGroup))/float(len(DipoleGroup))

Example #5
def solve_stft_regression_L2_tsparse(M,G_list, G_ind, X, Z0, 
                                    active_set_z0, active_t_ind_z0,
                                    coef_non_zero_mat,
                                wsize=16, tstep = 4, delta = 0,
                                maxit=200, tol = 1e-3,lipschitz_constant = None,
                                Flag_backtrack = True, eta = 1.5, L0 = 1.0,
                                Flag_verbose = False):                             
    """
    Use the accelerated gradient descent (exactly FISTA without non-smooth penalty)
        to find the solution given an active set
        min 1/2||R||_F^2 + delta ||Z||_F^2
    Input:
       M, [n_channels, n_times, n_trials] array of the sensor data
       G_list, a list of [n_channels, n_dipoles] forward gain matrix
       G_ind, [n_trial], marks the index of G for this run
       X, [n_trials, p],the design matrix, it must include an all 1 colume
       Z0, [n_active_dipoles, p*n_freq * n_active_step]
       acitve_set_z0, [n_dipoles,] a boolean array, indicating the active set of dipoles
       active_t_ind_z0, [n_step, ] a boolean array, indicating the active set of time points,
           the union of all frequencies, columns of X, and dipoles
       coef_non_zero_mat,[n_active_dipoles,n_coef*pq]  boolean matrix,
           since some active_set_z and active_t_ind_z is a super set of the active set,
           but we assume that it is the UNION of each coefficient of X and all trials
       wsize, window size of the STFT
       tstep, length of the time step
       delta, the regularization parameter
       maxit, maximum number of iteration allowed
       tol, tolerance of the objective function
       lipschitz_constant, the lipschitz constant
       
       No Flag trial by trial is allowed in this version
    """
    n_sensors, n_times, q= M.shape
    n_dipoles = G_list[0].shape[1]
    p = X.shape[1]
    n_step = int(np.ceil(n_times/float(tstep)))
    n_freq = wsize// 2+1
    n_coefs = n_step*n_freq

    #coef_non_zero_mat_full = np.tile(coef_non_zero_mat,[1,pq])
    coef_non_zero_mat_full = coef_non_zero_mat
    phi = _Phi(wsize, tstep, n_coefs)
    phiT = _PhiT(tstep, n_freq, n_step, n_times)
    sparse_phi = sparse_Phi(wsize, tstep)
    sparse_phiT = sparse_PhiT(tstep, n_freq, n_step, n_times)
    
    if (active_t_ind_z0.sum()*n_freq*p != Z0.shape[1] or n_dipoles != len(active_set_z0)):
        print(active_t_ind_z0.sum()*n_freq*p, Z0.shape[1], n_dipoles, len(active_set_z0))
        raise ValueError("wrong number of dipoles or coefs")

    if lipschitz_constant is None and not Flag_backtrack:
        lipschitz_constant = 1.1* (get_lipschitz_const(M,G_list[0],X,phi,phiT,
                                    Flag_trial_by_trial = False,n_coefs = n_coefs,
                                    tol = 1e-3) +2.0*delta)
        print("lipschitz_constant = %e" % lipschitz_constant)
    if Flag_backtrack:
        L = L0
    else:
        L = lipschitz_constant
                                       
    #initialization
    active_t_ind_z = active_t_ind_z0.copy()
    active_set_z = active_set_z0.copy()
    Z = Z0.copy()
    Z[coef_non_zero_mat_full==0] = 0
    Y = Z0.copy()    
    Y[coef_non_zero_mat_full==0] = 0
    n_active_dipole = active_set_z.sum()
    # number of coefficients
    n_coefs_z = n_freq*active_t_ind_z.sum()
    n_coefs_all_active = p*n_coefs_z
    #==== the main loop =====  
    tau, tau0 = 1.0, 1.0
    obj = np.inf
    old_obj = np.inf
    
    # prepare the M_list from the M
    n_run = len(np.unique(G_ind))
    M_list = list()
    X_list = list()
    for run_id in range(n_run):
        M_list.append(M[:,:,G_ind == run_id])
        X_list.append(X[G_ind == run_id, :])
           
    # gradient part one is fixed:   -G^T( \sum_r M(r)) PhiT
    gradient_y0 = get_gradient0_tsparse(M_list, G_list, X_list, p, n_run,
                  n_active_dipole, active_set_z, n_times,
                  n_coefs_z, n_coefs_all_active, 
                  active_t_ind_z,
                  sparse_phi, sparse_phiT)
        
    #  iterations, we only need to compute the second part of gradient 
        #  +G^T G(\sum_r X_k^(r) \sum_k Z_k  X_k^r) PhiT Phi
    for i in range(maxit):
        Z0 = Z.copy()
        gradient_y1 = get_gradient1_tsparse(M_list, G_list, X_list, Y, p, n_run,
                  n_active_dipole, active_set_z, n_times,
                  n_coefs_z, n_coefs_all_active, 
                  active_t_ind_z,
                  sparse_phi, sparse_phiT)
        gradient_y = gradient_y0 + gradient_y1
        # compare the gradient(tested)
#        if False: 
#            gradient_y2 = np.zeros([n_active_dipole, n_coefs_all_active], dtype =np.complex)
#            R_all_sq = np.zeros([n_sensors, n_times])
#            for r in range(q):
#                # tmp_coef = y(0) + \sum_k y(k)* X(r,k)
#                tmp_coef = np.zeros([n_active_dipole,n_coefs_z], dtype = np.complex)
#                for k in range(p):
#                    tmp_coef += Y[:,k*n_coefs_z:(k+1)*n_coefs_z]*X[r,k]
#                # current residual for this trial            
#                tmpR = np.real(M[:,:,r] - G_list[G_ind[r]][:,active_set_z].dot(sparse_phiT(tmp_coef, active_t_ind_z)))           
#                R_all_sq += np.abs(tmpR)**2
#                tmpA = G_list[G_ind[r]][:,active_set_z].T.dot(sparse_phi(tmpR, active_t_ind_z))
#                for k in range(p):
#                    gradient_y2[:,k*n_coefs_z:(k+1)*n_coefs_z] += - X[r,k]*tmpA 
#            print "obj = %e" %(0.5* R_all_sq.sum() + delta*np.sum(np.abs(Y)**2))
#            #plt.plot(np.real(gradient_y2.ravel()), np.real(gradient_y.ravel()), '.')
#            #plt.plot(np.real(gradient_y2.ravel()), np.real(gradient_y2.ravel()), 'r')
#            print np.linalg.norm(gradient_y2-gradient_y)/np.linalg.norm(gradient_y2)
#            #gradient_y = gradient_y2.copy()
        # the L2 penalty
        gradient_y += 2*delta*Y
        # compute the variable  y- 1/L gradient y
        gradient_y[coef_non_zero_mat_full==0] =0
        
        # FISTA backtracking criterion for a pure gradient step:
        # backtrack while f(Y - gradient_y/L) > f(Y) - (0.5/L) ||gradient_y||^2
        Z = Y- (gradient_y/L)
        objz = f_l2(Z, active_set_z, M, G_list, G_ind, X, n_coefs, q, p, phiT, delta)

        if Flag_backtrack:
            objy = f_l2( Y, active_set_z, M, G_list, G_ind, X, n_coefs, q, p, phiT,delta)
            # Ryan's slides
            # https://www.cs.cmu.edu/~ggordon/10725-F12/slides/05-gd-revisited.pdf
            # page 10
            diff_bt = objz-objy+ (0.5/L)* np.sum( np.abs(gradient_y)**2)
            while diff_bt > 0:
                L = L*eta
                Z = Y-(gradient_y/L)
                objz = f_l2( Z, active_set_z, M, G_list, G_ind, X, n_coefs, q, p, phiT, delta)
                diff_bt = objz-objy + (0.5/L)* np.sum( np.abs(gradient_y)**2)
                
        ## ==== FISTA step, update the variables =====
        tau0 = tau;
        tau = 0.5*(1+ np.sqrt(4*tau**2+1))
        diff = Z-Z0
        Y = Z + (tau0 - 1.0)/tau* diff
        ## ===== track the gradient norm, check stopping criteria ====
        old_obj = obj
        obj = np.linalg.norm(gradient_y)  # note: this tracks the gradient norm, not the full objective
        diff_obj = old_obj-obj
        if Flag_verbose:
            print("\n iteration %d" % i)
            print("gradient norm = %f" % obj)
            print("diff_obj = %e" % (diff_obj))
            print("diff = %e" % (np.abs(diff).sum()/np.sum(abs(Y))))
            print("obj = %f" % objz)
        stop = np.abs(diff).sum()/np.sum(abs(Y)) < tol
        if stop:
            print("convergence reached!")
            break           
    Z = Y.copy() 
    Z[coef_non_zero_mat_full ==0] =0
    return Z, objz
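
The iteration above is FISTA on a smooth objective with an L2 penalty, plus backtracking on the step size 1/L. A self-contained sketch of the same recipe on a small real-valued least-squares problem is below; it mirrors the tau update, the momentum step, and the backtracking criterion, but it is not the STFT solver itself.

import numpy as np

rng = np.random.default_rng(0)
A = rng.standard_normal((20, 5))
b = rng.standard_normal(20)
delta = 0.1                                        # L2 penalty, as in the solver above

f = lambda z: 0.5 * np.sum((b - A.dot(z)) ** 2) + delta * np.sum(z ** 2)
grad = lambda z: -A.T.dot(b - A.dot(z)) + 2 * delta * z

L, eta = 1.0, 1.5                                  # L0 and the backtracking factor
Z = np.zeros(5); Y = Z.copy(); tau = 1.0
for it in range(200):
    Z_prev = Z.copy()
    g = grad(Y)
    Z = Y - g / L
    # backtracking: accept the step only if f(Y - g/L) <= f(Y) - (0.5/L)*||g||^2
    while f(Z) > f(Y) - 0.5 / L * np.sum(g ** 2):
        L = L * eta
        Z = Y - g / L
    # FISTA momentum update
    tau0, tau = tau, 0.5 * (1 + np.sqrt(4 * tau ** 2 + 1))
    Y = Z + (tau0 - 1.0) / tau * (Z - Z_prev)
    if np.abs(Z - Z_prev).sum() / max(np.abs(Y).sum(), 1e-12) < 1e-8:
        break
print(f(Z))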
Example #6
def select_delta_stft_regression_cv(M,G_list, G_ind,X,Z00,
                                    active_set_z0, active_t_ind_z0,
                                    coef_non_zero_mat,
                                    delta_seq,cv_partition_ind,
                                    wsize=16, tstep = 4, 
                                    maxit=200, tol = 1e-3,
                                    Flag_backtrack = True, L0 = 1.0, eta = 1.5,
                                    Flag_verbose = False): 
    ''' Find the best L2 regularization parameter delta by cross-validation.
        Note that in training the trial-by-trial parameters are estimated,
        but in testing only the regression coefficients are used.

    Input:
        M, [n_channels, n_times, n_trials] array of the sensor data
        G_list, a list of [n_channels, n_dipoles] forward gain matrix
        G_ind, [n_trial], marks the index of G for this run
        X, [n_trials, p], the design matrix, it must include an all-1 column
        Z00, [n_active_dipoles, p*n_freq * n_active_step]
        e.g 
            # initial value
                Z00 = (np.random.randn(n_true_dipoles, n_coefs_all_active) \
                + np.random.randn(n_true_dipoles, n_coefs_all_active)*1j)*1E-15  
                n_coefs_all_active = active_t_ind_z0.sum()*n_freq*pq
                
        active_set_z0, [n_dipoles,] a boolean array, indicating the active set of dipoles
        active_t_ind_z0, [n_step, ] a boolean array, indicating the active set of time points,
            the union of all frequencies, columns of X, and dipoles
        coef_non_zero_mat, [n_active_dipoles, n_coefs*pq] boolean matrix;
            active_set_z and active_t_ind_z may be supersets of the true active set,
            assumed to be the UNION over the coefficients of X and all trials
        delta_seq, np.array, a sequence of delta to be tested
        cv_partition_ind, an integer array of which cross-validation group
                    each trial is in
        wsize, window size of the STFT
        tstep, length of the time step
        maxit, maximum number of iteration
        tol, tolerance,
    Output:
       delta_star, the best delta  
       cv_MSE, the cv MSE for all elements in delta_seq
    '''
    n_sensors, n_times, n_trials = M.shape
    n_dipoles = G_list[0].shape[1]
    if len(active_set_z0) != n_dipoles:
        raise ValueError("the number of dipoles does not match") 
    p = X.shape[1]
    n_step = int(np.ceil(n_times/float(tstep)))
    n_freq = wsize// 2+1
    n_coefs = n_step*n_freq
    phi = _Phi(wsize, tstep, n_coefs)
    phiT = _PhiT(tstep, n_freq, n_step, n_times)
    n_fold = len(np.unique(cv_partition_ind))
    n_delta = len(delta_seq)
    # n_true_dipoles = np.sum(active_set_z0)
    lipschitz_constant0 = get_lipschitz_const(M,G_list[0],X,phi,phiT,n_coefs,
                          tol = 1e-3,Flag_trial_by_trial = False)
    cv_MSE = np.zeros([n_fold, n_delta])
    
    for j in range(n_fold):
        # partition
        test_trials = np.nonzero(cv_partition_ind == j)[0]
        train_trials = np.nonzero(cv_partition_ind != j)[0]
        Z0 = Z00
        tmp_coef_non_zero_mat =  coef_non_zero_mat
     
        Mtrain = M[:,:,train_trials]
        Xtrain = X[train_trials,:]
        Mtest = M[:,:,test_trials]
        Xtest = X[test_trials,:] 
        G_ind_train = G_ind[train_trials]
        G_ind_test = G_ind[test_trials]
        for i in range(n_delta):
            tmp_delta = delta_seq[i]
            # lipschitz constant
            L = (lipschitz_constant0+ 2*tmp_delta)*1.1
            # training  
            Z, _ = solve_stft_regression_L2_tsparse (Mtrain,G_list,
                                                     G_ind_train, Xtrain, Z0,
                        active_set_z0, active_t_ind_z0, tmp_coef_non_zero_mat,
                        wsize=wsize, tstep =tstep,delta = tmp_delta,
                        maxit=maxit,tol = tol,lipschitz_constant =L,
                        Flag_backtrack = Flag_backtrack, L0 = L0, eta = eta,
                        Flag_verbose = Flag_verbose)
            # only take the regression coefficients out
            Z_star = Z[:,0:p*active_t_ind_z0.sum()*n_freq]
            Z0 = Z.copy()
            # testing
            tmp_val, _,_,_ = get_MSE_stft_regresion_tsparse(Mtest,G_list,
                                           G_ind_test,Xtest,
                                 Z_star, active_set_z0, active_t_ind_z0,
                                 wsize=wsize, tstep = tstep)
            cv_MSE[j,i] =  tmp_val
            # debug
            #import matplotlib.pyplot as plt
            #plt.figure()
            #plt.plot(np.real(Z_star).T)
            #plt.title(tmp_delta)
            
    cv_MSE = np.mean(cv_MSE, axis = 0)
    best_ind = np.argmin(cv_MSE)
    delta_star = delta_seq[best_ind]
    return delta_star, cv_MSE
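
A minimal sketch of the surrounding bookkeeping: how a cv_partition_ind array could be built and how delta_star is picked from the fold-by-delta MSE matrix (the MSE values here are random placeholders, not model output).

import numpy as np

rng = np.random.default_rng(0)
n_trials, n_fold = 20, 5
delta_seq = np.array([0.01, 0.1, 1.0, 10.0])

# assign each trial to one of n_fold roughly balanced groups
cv_partition_ind = rng.permutation(np.arange(n_trials) % n_fold)

cv_MSE = rng.random((n_fold, len(delta_seq)))   # placeholder per-fold test MSEs

mean_MSE = np.mean(cv_MSE, axis=0)              # average over folds, as in the function above
delta_star = delta_seq[np.argmin(mean_MSE)]
print(cv_partition_ind, delta_star)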
          
Example #7
def solve_stft_regression_tree_group(M,G_list, G_ind,X,
                                alpha,beta, gamma, 
                                DipoleGroup,DipoleGroupWeight,
                                Z_ini, active_set_z_ini, 
                                n_orient=1, wsize=16, tstep = 4,
                                maxit=200, tol = 1e-3,lipschitz_constant = None,
                                Flag_backtrack = True, eta = 1.5, L0 = 1.0,
                                Flag_verbose = False):                             
    """    
    Input:
       M, [n_sensors,n_times,n_trials] array of the sensor data
       G_list,  a list of [n_sensors, n_dipoles] forward gain matrix
       G_ind, [n_trials], index of run number (which G to use)
       X, [n_trials, p], the design matrix, it must include an all-1 column
       alpha, tuning parameter for the regularization
       beta, the tuning parameter balancing single time-frequency basis elements and dipoles
       gamma, the penalty on the absolute value of entries in Z
       DipoleGroup, grouping of dipoles,
                    the dipoles in the same ROI are in the same group,
                    the dipoles outside ROIs form one-dipole groups
       DipoleGroupWeight, weights of each dipole group (ROI)
       Z_ini, [n_dipoles, n_coefs*p], initial value of Z, the ravel order is [n_dipoles, p, n_freqs, n_step]
       active_set_z_ini, [n_dipoles,]  boolean, active_set of dipoles
       n_orient, number of orientations for each dipole
                 note that if n_orient == 3, the DipoleGroup is still on the columns of G, 
                 but grouped in the 3-set way. 
                 The initial values and initial active set of Z should also be correspondingly correct. 
       wsize, window size of the STFT
       tstep, step in time of the stft
       maxit, maximum number of iteration allowed
       tol, tolerance of the objective function
       lipschitz_constant, the lipschitz constant,
       Flag_backtrack, if True, use backtracking instead of constant stepsize ( the lipschitz constant )
       eta, L0, the shrinking parameters and initial 1/stepsize
       Flag_verbose, if true, print the objective values and difference, else not
    Output:
        Z, the solution, only rows in the active set;
             note that it is not guaranteed that all rows from the same group will be in the final solution,
             if all of the coefficients of a row are zero, it is dropped too.
        active_set_z,  a boolean vector, active dipoles (rows)
        obj, the objective function
    """
    # check the active_set_structure and group structure for n_orient ==3
    if n_orient == 3:
        active_set_mat = active_set_z_ini.copy()
        active_set_mat = active_set_mat.reshape([-1,n_orient])
        any_ind = np.any(active_set_mat,axis =1)
        all_ind = np.all(active_set_mat,axis =1)
        if np.sum(np.abs(any_ind-all_ind)) >0:
            raise ValueError("wrong active set for n_orient = 3")
        # DipoleGroup must also satisfy the structure
        for l in range(len(DipoleGroup)):
            if np.remainder(len(DipoleGroup[l]),n_orient)!=0:
                raise ValueError("wrong group")
            tmp_mat = np.reshape(DipoleGroup[l],[-1,n_orient],order = 'C')
            if np.sum(np.abs(tmp_mat[:,2] - tmp_mat[:,1] - 1)) != 0 \
                    or np.sum(np.abs(tmp_mat[:,1] - tmp_mat[:,0] - 1)) != 0:
                raise ValueError("wrong group")
                
    n_sensors, n_times, q = M.shape
    n_dipoles = G_list[0].shape[1]
    p = X.shape[1]
    n_step = int(np.ceil(n_times/float(tstep)))
    n_freq = wsize// 2+1
    n_coefs = n_step*n_freq

    # create the sparse and non sparse version of the STFT, iSTFT
    phi = _Phi(wsize, tstep, n_coefs)
    phiT = _PhiT(tstep, n_freq, n_step, n_times)
    # initialization    
    if lipschitz_constant is None and not Flag_backtrack: 
        lipschitz_constant = 1.1* get_lipschitz_const(M,G_list[0],X,phi,phiT,
                                    Flag_trial_by_trial = False,n_coefs = n_coefs,
                                    tol = 1e-3)
        print "lipschitz_constant = %e" % lipschitz_constant
        
    if Flag_backtrack:
        L = L0   
    else:
        L = lipschitz_constant
    # indices for the active set is only for rows
    Z = Z_ini.copy()
    Y = Z_ini.copy()
    active_set_z = active_set_z_ini.copy()
    active_set_y = active_set_z_ini.copy() 
    
    if Z.shape[0] != active_set_z.sum() or Z.shape[1] != n_coefs*p:
        raise ValueError('Z0 shape does not match active sets')
    #==== the main loop =====  
    tau, tau0 = 1.0, 1.0
    obj = np.inf
    old_obj = np.inf
    
    # prepare the M_list from the M
    n_run = len(np.unique(G_ind))
    M_list = list()
    X_list = list()
    for run_id in range(n_run):
        M_list.append(M[:,:,G_ind == run_id])
        X_list.append(X[G_ind == run_id, :])
           
    # gradient part one is fixed:   -G^T( \sum_r M(r)) PhiT
    # this should be on all dipoles
    gradient_y0 = get_gradient0(M_list, G_list, X_list, p, n_run,
                  n_dipoles, np.ones(n_dipoles, dtype = bool), n_times,
                  n_coefs, n_coefs*p, phi, phiT)
    # only keep on full matrix
    #gradient_y = np.zeros([n_dipoles, n_coefs*p],dtype =np.complex)
    for i in range(maxit):
        Z0 = Z.copy()
        active_set_z0 = active_set_z.copy()
        # this part can be only on the active set
        # but gradient_y1 should be a full matrix
        gradient_y1 = get_gradient1_L21(M_list, G_list, X_list, Y, p, n_run,
                  int(active_set_y.sum()), active_set_y, n_times,
                  n_coefs, n_coefs*p, phi, phiT)
        gradient_y = gradient_y0 + gradient_y1
  
        # active_set_z0, active rows/dipoles for z0
        # active_t_ind_z0, active columns/time_points for z0
        
        tmp_Y_L_gradient_Y = -gradient_y/L
        tmp_Y_L_gradient_Y[active_set_y,:] += Y
        ## ==== ISTA step, get the proximal operator ====
        ## Input must be a full matrix, so the active set is set to full
        Z, active_set_z = prox_tree_hard_coded_full_matrix(tmp_Y_L_gradient_Y, n_coefs, p, 
                                 alpha/L, beta/L, gamma/L,
                                 DipoleGroup, DipoleGroupWeight, n_orient)
        # check if it is zero solution                       
        if not np.any(active_set_z):
            print "active_set = 0"
            return 0,0,0,0                 
        objz = f( Z, active_set_z, M, G_list, G_ind, X, n_coefs, q, p, phiT)
        # compute Z-Y
        diff_z, active_set_diff_z = _add_z([Z,-Y],  np.vstack([active_set_z, active_set_y]))
        
        if Flag_backtrack:
            objy = f( Y, active_set_y, M, G_list, G_ind, X, n_coefs, q, p, phiT)
            # compute the criterion for back track: f(z)-f(y) -grad_y.dot(z-y) +0.5*(z-y)**2
            # note gradient_y is a full matrix
            diff_bt = objz-objy- np.sum( np.real(gradient_y[active_set_diff_z])* np.real(diff_z))\
                      - np.sum(np.imag(gradient_y[active_set_diff_z])* np.imag(diff_z)) \
                      -0.5*L*np.sum( np.abs(diff_z)**2)   
            while diff_bt > 0:
                L = L*eta
                tmp_Y_L_gradient_Y = -gradient_y/L
                tmp_Y_L_gradient_Y[active_set_y,:] += Y
                Z, active_set_z = prox_tree_hard_coded_full_matrix(tmp_Y_L_gradient_Y, n_coefs, p, 
                                         alpha/L, beta/L, gamma/L,
                                         DipoleGroup, DipoleGroupWeight, n_orient)
                if not np.any(active_set_z):
                    print "active_set = 0"
                    return 0,0,0                        
                objz = f( Z, active_set_z, M, G_list, G_ind, X, n_coefs, q, p, phiT)
                # Z-Y
                diff_z, active_set_diff_z = _add_z([Z,-Y],  np.vstack([active_set_z, active_set_y]))
                # the criterion for back track: 
                diff_bt = objz-objy- np.sum( np.real(gradient_y[active_set_diff_z])* np.real(diff_z))\
                      - np.sum(np.imag(gradient_y[active_set_diff_z])* np.imag(diff_z)) \
                      -0.5*L*np.sum( np.abs(diff_z)**2)   
                               
        ## ==== FISTA step, update the variables =====
        tau0 = tau
        tau = 0.5*(1+ np.sqrt(4*tau**2+1))

        if not np.any(active_set_diff_z):
            print("active_set = 0")
            return 0, 0, 0
        else:
            # y <- z + (tau0-1)/tau (z-z0) = (tau+tau0-1)/tau z - (tau0-1)/tau z0
            Y, active_set_y = _add_z( [(tau+tau0-1)/tau*Z, -(tau0-1)/tau*Z0],
                                       np.vstack([active_set_z, active_set_z0]))
            
        ## ===== compute objective function, check stopping criteria ====
        old_obj = obj
        full_Z = np.zeros([n_dipoles, n_coefs*p], dtype = complex)
        full_Z[active_set_z, :] = Z
        
        obj = objz \
            + get_tree_norm_hard_coded(full_Z,n_coefs, p, 
              alpha, beta, gamma, DipoleGroup, DipoleGroupWeight, n_orient)
        diff_obj = old_obj-obj
        relative_diff = np.sum(np.abs(diff_z))/np.sum(np.abs(Z))
        if Flag_verbose:
            print("\n iteration %d" % i)
            print("diff_obj = %e" % (diff_obj/obj))
            print("obj = %e" % obj)
            print("diff = %e" % relative_diff)
        stop = ( relative_diff < tol  and np.abs(diff_obj/obj) < tol)
        if stop:
            print("convergence reached!")
            break    
    Z = Y.copy() 
    active_set_z = active_set_y.copy()
    return Z, active_set_z, obj
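
The penalty enforced by the proximal step has three levels: alpha weights an L2 norm over each dipole group, beta an L2 norm across the p regression coefficients of each dipole/time-frequency atom, and gamma an elementwise L1 norm. The sketch below writes that norm out for a small dense Z with n_orient = 1; it is a plausible reading of what get_tree_norm_hard_coded computes, not the library function itself.

import numpy as np

def tree_norm_sketch(Z_full, n_coefs, p, alpha, beta, gamma,
                     DipoleGroup, DipoleGroupWeight):
    """Three-level tree norm for n_orient = 1; Z_full is [n_dipoles, p*n_coefs],
    ravel order [dipole, p, coef]."""
    # level 1 (alpha): weighted L2 norm over each dipole group
    level1 = sum(alpha * w * np.sqrt(np.sum(np.abs(Z_full[g, :]) ** 2))
                 for g, w in zip(DipoleGroup, DipoleGroupWeight))
    # level 2 (beta): L2 norm across the p regression coefficients,
    # separately for each dipole and each time-frequency coefficient
    Z3 = Z_full.reshape(Z_full.shape[0], p, n_coefs)
    level2 = beta * np.sum(np.sqrt(np.sum(np.abs(Z3) ** 2, axis=1)))
    # level 3 (gamma): plain L1 on every entry
    level3 = gamma * np.sum(np.abs(Z_full))
    return level1 + level2 + level3

rng = np.random.default_rng(0)
n_dipoles, p, n_coefs = 6, 2, 4
Z_full = rng.standard_normal((n_dipoles, p * n_coefs)) + 1j * rng.standard_normal((n_dipoles, p * n_coefs))
DipoleGroup = [np.array([0, 1, 2]), np.array([3]), np.array([4]), np.array([5])]
DipoleGroupWeight = np.ones(len(DipoleGroup)) / float(len(DipoleGroup))
print(tree_norm_sketch(Z_full, n_coefs, p, 1.0, 0.5, 0.1, DipoleGroup, DipoleGroupWeight))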
Example #8
def compute_dual_gap(M, G_list, G_ind, X, Z, active_set, 
              alpha, beta, gamma,
              DipoleGroup, DipoleGroupWeight, n_orient,
              wsize = 16, tstep = 4):
    """ 
    Compute the duality gap, and check the feasibility of the dual function.
    Input:
        M, G_list, G_ind, X, Z, active_set, all the full primal.
        alpha, beta, gamma, penalty parameters
        n_orient, number of orientations
    Output:
        dict(feasibility_dist = feasibility_dist,
                gradient = gradient,
                feasibility_dist_DipoleGroup = feasibility_dist_DipoleGroup )       
        feasibility_dist is the main criterion; if it is small enough, accept the results
        gradient, the gradient
    """
    # initialization, only update when seeing greater values
    if gamma == 0:
        raise ValueError( "the third level must be penalized!!")
    
    n_dipoles = G_list[0].shape[1]
    n_sensors, n_times, q= M.shape
    n_step = int(np.ceil(n_times/float(tstep)))
    n_freq = wsize// 2+1
    n_coefs = n_step*n_freq
    phi = _Phi(wsize, tstep, n_coefs)
    phiT = _PhiT(tstep, n_freq, n_step, n_times)   
    n_trials, p = X.shape
               
    
    n_run = len(np.unique(G_ind))
    M_list = list()
    X_list = list()
    for run_id in range(n_run):
        M_list.append(M[:,:,G_ind == run_id])
        X_list.append(X[G_ind == run_id, :])
    
    # compute the gradient to check feasibility
    gradient_y0 = get_gradient0(M_list, G_list, X_list, p, n_run,
                  n_dipoles, np.ones(n_dipoles, dtype = bool), n_times,
                  n_coefs, n_coefs*p, phi, phiT)
    gradient_y1 = get_gradient1_L21(M_list, G_list, X_list, Z, p, n_run,
              int(active_set.sum()), active_set, n_times,
              n_coefs, n_coefs*p, phi,phiT)
    gradient = gradient_y0 + gradient_y1
  
    # sanity check
    # for each dipole, get the maximum abs value among the real and imag parts
    #max_grad = np.max(np.vstack([np.max(np.abs(np.real(gradient)), axis = 1), 
    #               np.max(np.abs(np.imag(gradient)),axis = 1)]), axis = 0)
                   
    alpha_weight = np.zeros(n_dipoles)
    for i in range(len(DipoleGroup)):
        alpha_weight[DipoleGroup[i]] = alpha*DipoleGroupWeight[i]

    # if max_grad is greater than alpha+beta+gamma, 
    # then the dual variable can not be feasible
 
    #if np.any(max_grad > alpha_weight+ beta+gamma):
    #     feasibility_dist = np.inf
    #else:
    #    
    # feasibility check. b = A^T u + \sum_g D_g^T v_g
    # where g is the non zero groups 
    b = gradient.copy()
    active_set_ind = np.nonzero(active_set)[0]
    Z_full = np.zeros([n_dipoles, Z.shape[1]], dtype = complex)
    Z_full[active_set,:] = Z
    
    # add g in the alpha level, dipole groups 
    # a bool vector showing whether the group is in the active set
    DipoleGroup_active = np.zeros(len(DipoleGroup), dtype = bool)
    for i in range(len(DipoleGroup)):
        if np.intersect1d(DipoleGroup[i],active_set_ind).size >0:
            DipoleGroup_active[i] = True
            # l2 norm of the group
            l2_norm_alpha = np.sqrt(np.sum( ( np.abs(Z_full[DipoleGroup[i],:]) )**2) )
            # add the sum of the gradient
            if l2_norm_alpha == 0:
                raise ValueError("all zero in an active group!")
            b[DipoleGroup[i],:]+= Z_full[DipoleGroup[i],:]/l2_norm_alpha * alpha_weight[i]
    
    # add g in the beta level, same dipole, same stft coef
    if n_orient == 1:
        Z_reshape = np.reshape(Z, [active_set.sum(), p, -1])
        # active_set.sum() x  n_coefs
        l2_norm_beta = np.sqrt(np.sum( (np.abs(Z_reshape))**2, axis = 1))
        # active_set.sum() x  n_coefs * p
        l2_norm_beta_large = np.tile(l2_norm_beta, [1, p])                
       
    else: # n_orient == 3
        Z_reshape = np.reshape(Z, [active_set.sum()//3,3, p, -1])
        l2_norm_beta = np.sqrt( np.sum(   np.sum( ( np.abs(Z_reshape) )**2, axis = 2) , axis = 1 ) )
        l2_norm_beta_large = np.reshape(  np.tile(l2_norm_beta,[ 1, 3*p]) , Z.shape)

    tmp_add_to_b = np.zeros([active_set.sum(), Z.shape[1]], dtype = complex)
    nonzero_beta = l2_norm_beta_large > 0
    tmp_add_to_b[nonzero_beta] = Z[nonzero_beta]/l2_norm_beta_large[nonzero_beta]*beta
    b[active_set,:] += tmp_add_to_b
    
    # add g in the gamma level, each element of the matrix
    if n_orient == 1:
        l2_norm_gamma = np.abs(Z)   
    else: # n_orient == 3:
        Z_reshape = np.reshape(Z, [active_set.sum()//3,3,-1])
        l2_norm_gamma = np.sqrt ( np.sum( ( np.abs(Z_reshape) )**2, axis = 1 ) )
        l2_norm_gamma = np.reshape( np.tile(l2_norm_gamma, [1,3* Z.shape[1]]), Z.shape)
        
    nonzero_gamma = l2_norm_gamma > 0
    tmp_add_to_b = np.zeros([active_set.sum(), Z.shape[1]], dtype = complex)
    tmp_add_to_b[nonzero_gamma] = Z[nonzero_gamma]/l2_norm_gamma[nonzero_gamma]*gamma
    b[active_set,:] += tmp_add_to_b 
    if np.any(np.isnan(b)):
        raise ValueError("nan found in b!")
    # use coordinate descent to solve the feasibility problem
    nonzero_Z = np.abs(Z) >0 
    feasibility_result = get_feasibility(b, active_set, DipoleGroup, 
                                 alpha_weight, beta, gamma,
                                 n_coefs, p,  nonzero_Z, 
                                 DipoleGroup_active) 
    feasibility_dist =  feasibility_result['feasibility_dist']
    # the dist in each dipole, squared
    feasibility_dist_in_alpha_level = feasibility_result['feasibility_dist_in_alpha_level'] 
    feasibility_dist_DipoleGroup = np.zeros(len(DipoleGroup))
    for i in range(len(DipoleGroup)):
        feasibility_dist_DipoleGroup[i] = \
              np.sqrt( np.sum(feasibility_dist_in_alpha_level[DipoleGroup[i]]**2))                        
    
         
    return dict(feasibility_dist = feasibility_dist,
                gradient = gradient, b = b,
                feasibility_dist_DipoleGroup = feasibility_dist_DipoleGroup)   
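
The feasibility distance generalizes a familiar check from plain group lasso: the gradient of the data-fit term must lie inside the dual ball of the penalty, group by group, and the feasibility distance measures how far it falls outside. A simplified sketch for an ordinary group lasso at z = 0 (only the alpha level, no beta/gamma terms, no coordinate descent) is:

import numpy as np

rng = np.random.default_rng(0)
G = rng.standard_normal((10, 6))
M = rng.standard_normal(10)
groups = [np.array([0, 1]), np.array([2, 3]), np.array([4, 5])]
alpha = 50.0                                  # large enough that z = 0 should be optimal

z = np.zeros(6)
gradient = -G.T.dot(M - G.dot(z))             # gradient of 0.5*||M - Gz||^2 at z

# z = 0 is optimal iff ||gradient[g]||_2 <= alpha for every group g;
# the feasibility distance is the distance of the gradient to that dual set.
feasibility_dist = np.sqrt(sum(max(np.linalg.norm(gradient[g]) - alpha, 0.0) ** 2
                               for g in groups))
print(feasibility_dist)                       # 0.0 when the dual constraints hold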
Example #9
def solve_stft_regression_tree_group_tsparse(M,G_list, G_ind,X,
                                alpha,beta, gamma, 
                                DipoleGroup,DipoleGroupWeight,
                                Z_ini, active_set_z_ini, 
                                n_orient=1, wsize=16, tstep = 4,
                                maxit=200, tol = 1e-3,lipschitz_constant = None,
                                Flag_backtrack = True, eta = 1.5, L0 = 1.0,
                                Flag_verbose = False):                             
    """    
    Input:
       M, [n_sensors,n_times,n_trials] array of the sensor data
       G_list,  a list of [n_sensors, n_dipoles] forward gain matrix, for different runs if applicable
       G_ind, [n_trials], index of run number (which G to use)
       X, [n_trials, p], the design matrix, it must include an all-1 column
       alpha, tuning parameter for the regularization
       beta, the tuning parameter for the time-frequency basis
       gamma, the penalty on the absolute value of entries in Z
       DipoleGroup, grouping of dipoles,
                    the dipoles in the same ROI are in the same group,
                    the dipoles outside ROIs form one-dipole groups
       DipoleGroupWeight, weights of each dipole group (ROI)
       Z_ini, [n_dipoles, n_coefs*p], initial value of Z, the ravel order is [n_dipoles, p, n_freqs, n_step]
       active_set_z_ini, [n_dipoles,]  boolean, active_set of dipoles
       n_orient, number of orientations for each dipole
                 note that if n_orient == 3, the DipoleGroup is still on the columns of G, 
                 but grouped in the 3-set way. 
                 The initial values and initial active set of Z should also be correspondingly correct. 
       wsize, window size of the STFT
       tstep, step in time of the stft
       maxit, maximum number of iteration allowed
       tol, tolerance of the objective function
       lipschitz_constant, the lipschitz constant,
       Flag_backtrack, if True, use backtracking instead of constant stepsize ( the lipschitz constant )
       eta, L0, the shrinking parameters and initial 1/stepsize
       Flag_verbose, if true, print the objective values and difference, else not
    Output:
        Z, the solution, only rows in the active set;
             note that it is not guaranteed that all rows from the same group will be in the final solution,
             if all of the coefficients of a row are zero, it is dropped too.
        active_set_z,  a boolean vector, active dipoles (rows)
        obj, the objective function
    """
    # check the active_set_structure and group structure for n_orient ==3
    if n_orient == 3:
        active_set_mat = active_set_z_ini.copy()
        active_set_mat = active_set_mat.reshape([-1,n_orient])
        any_ind = np.any(active_set_mat,axis =1)
        all_ind = np.all(active_set_mat,axis =1)
        if np.sum(np.abs(any_ind-all_ind)) >0:
            raise ValueError("wrong active set for n_orient = 3")
        # DipoleGroup must also satisfy the structure
        for l in range(len(DipoleGroup)):
            if np.remainder(len(DipoleGroup[l]),n_orient)!=0:
                raise ValueError("wrong group")
            tmp_mat = np.reshape(DipoleGroup[l],[-1,n_orient],order = 'C')
            if np.sum(np.abs(tmp_mat[:,2] - tmp_mat[:,1] - 1)) != 0 \
                    or np.sum(np.abs(tmp_mat[:,1] - tmp_mat[:,0] - 1)) != 0:
                raise ValueError("wrong group")
                
    n_sensors, n_times, q = M.shape
    n_dipoles = G_list[0].shape[1]
    p = X.shape[1]
    n_step = int(np.ceil(n_times/float(tstep)))
    n_freq = wsize// 2+1
    n_coefs = n_step*n_freq

    # create the sparse and non sparse version of the STFT, iSTFT
    phi = _Phi(wsize, tstep, n_coefs)
    phiT = _PhiT(tstep, n_freq, n_step, n_times)
    sparse_phi = sparse_Phi(wsize, tstep)
    sparse_phiT = sparse_PhiT(tstep, n_freq, n_step, n_times)
    
    # initialization    
    if lipschitz_constant is None and not Flag_backtrack: 
        lipschitz_constant = 1.1* get_lipschitz_const(M,G_list[0],X,phi,phiT,
                                    Flag_trial_by_trial = False,n_coefs = n_coefs,
                                    tol = 1e-3)
        print "lipschitz_constant = %e" % lipschitz_constant
        
    if Flag_backtrack:
        L = L0   
    else:
        L = lipschitz_constant
    # indices for the active set is only for rows
    Z = Z_ini.copy()
    Y = Z_ini.copy()
    active_set_z = active_set_z_ini.copy()
    active_set_y = active_set_z_ini.copy() 
    
    # my code has sparse time steps. (sparse_phi and sparse_phiT)
    # for consistency, I still kept the active_t_ind variables, 
    # but make sure they are all true, all time steps are used. 
    active_t_ind_full = np.ones(n_step, dtype = bool)
    if Z.shape[0] != active_set_z.sum() or Z.shape[1] != n_coefs*p:
        raise ValueError('Z0 shape does not match active sets')
    #==== the main loop =====  
    tau, tau0 = 1.0, 1.0
    obj = np.inf
    old_obj = np.inf
    
    # prepare the M_list from the M
    n_run = len(np.unique(G_ind))
    M_list = list()
    X_list = list()
    for run_id in range(n_run):
        M_list.append(M[:,:,G_ind == run_id])
        X_list.append(X[G_ind == run_id, :])
           
    # gradient part one is fixed:   -G^T( \sum_r M(r)) PhiT
    # this should be on all dipoles
    gradient_y0 = get_gradient0_tsparse(M_list, G_list, X_list, p, n_run,
                  n_dipoles, np.ones(n_dipoles, dtype = bool), n_times,
                  n_coefs, n_coefs*p, active_t_ind_full,
                  sparse_phi, sparse_phiT)
    # only keep on full matrix
    #gradient_y = np.zeros([n_dipoles, n_coefs*p],dtype =np.complex)
    for i in range(maxit):
        Z0 = Z.copy()
        active_set_z0 = active_set_z.copy()
        # this part can be only on the active set
        # but gradient_y1 should be a full matrix
        gradient_y1 = get_gradient1_L21_tsparse(M_list, G_list, X_list, Y, p, n_run,
                  int(active_set_y.sum()), active_set_y, n_times,
                  n_coefs, n_coefs*p, active_t_ind_full,
                  sparse_phi, sparse_phiT)
        gradient_y = gradient_y0 + gradient_y1
  
        # active_set_z0, active rows/dipoles for z0
        # active_t_ind_z0, active columns/time_points for z0
        
        # verify the gradient with two different computations
#        if False:
#            gradient_y20 = np.zeros([n_dipoles, n_coefs*p], dtype = np.complex)
#            gradient_y21 = np.zeros([n_dipoles, n_coefs*p], dtype = np.complex)
#            #GTG_active = G.T.dot(G[:,active_set_y])
#            Y_reshape = np.reshape(Y[:,0:p*n_coefs],[active_set_y.sum(), p,n_coefs])
#            Y_reshape = Y_reshape.swapaxes(1,2)
#            for run_id in range(n_run):
#                GTG_active = G_list[run_id].T.dot(G_list[run_id][:,active_set_y])
#                for k in range(p):
#                    # first term of the gradient
#                    M_sum = np.sum(M_list[run_id]*X_list[run_id][:,k],axis = 2)
#                    # G^T(\sum_r X_k(r) M(r))\Phi
#                    GTM_sumPhi = G_list[run_id].T.dot(sparse_phi(M_sum,active_t_ind_full))
#                    # second term of the gradient
#                    sum_X_coef = np.zeros([active_set_y.sum(), n_coefs], dtype = np.complex)
#                    # the first level is n_coefs, second level is p
#                    for r in range(X_list[run_id].shape[0]):
#                        sum_X_coef += np.sum(Y_reshape*X_list[run_id][r,:],axis = 2)*X_list[run_id][r,k]
#                    gradient_y20[:,k*n_coefs:(k+1)*n_coefs] += -GTM_sumPhi 
#                    gradient_y21 [:,k*n_coefs:(k+1)*n_coefs] += sparse_phi(sparse_phiT(GTG_active.dot(sum_X_coef),
#                                                          active_t_ind_full),active_t_ind_full)
#            #print (np.linalg.norm(gradient_y0-gradient_y20)/np.linalg.norm(gradient_y20)) 
#            #print (np.linalg.norm(gradient_y1-gradient_y21)/np.linalg.norm(gradient_y21)) 
#            gradient_y2 = gradient_y20+gradient_y21                                              
#            print "diff gradient=%e" % (np.linalg.norm(gradient_y-gradient_y2)/np.linalg.norm(gradient_y2)) 
#            #plt.plot(np.real(gradient_y2).ravel(),np.real(gradient_y).ravel(), '.' )
#            gradient_y = gradient_y2.copy()
        
        tmp_Y_L_gradient_Y = -gradient_y/L
        tmp_Y_L_gradient_Y[active_set_y,:] += Y
        ## ==== ISTA step, get the proximal operator ====
        ## Input must be a full matrix, so the active set is set to full
        Z, active_set_z = prox_tree_hard_coded_full_matrix(tmp_Y_L_gradient_Y, n_coefs, p, 
                                 alpha/L, beta/L, gamma/L,
                                 DipoleGroup, DipoleGroupWeight, n_orient)
        # check if it is zero solution                       
        if not np.any(active_set_z):
            print "active_set = 0"
            return 0,0,0                
        objz = f( Z, active_set_z, M, G_list, G_ind, X, n_coefs, q, p, phiT)
        # compute Z-Y
        diff_z, active_set_diff_z = _add_z([Z,-Y],  np.vstack([active_set_z, active_set_y]))
        
        if Flag_backtrack:
            objy = f( Y, active_set_y, M, G_list, G_ind, X, n_coefs, q, p, phiT)
            # compute the criterion for back track: f(z)-f(y) -grad_y.dot(z-y) -0.5*(z-y)**2
            # note gradient_y is a full matrix
            # note diff_z is complex
            # http://www.seas.ucla.edu/~vandenbe/236C/lectures/fgrad.pdf 7-17+ FISTA paper Beck
            diff_bt = objz-objy- np.sum( np.real(gradient_y[active_set_diff_z])* np.real(diff_z))\
                      - np.sum(np.imag(gradient_y[active_set_diff_z])* np.imag(diff_z)) \
                      -0.5*L*np.sum( np.abs(diff_z)**2)
            while diff_bt > 0:
                L = L*eta
                tmp_Y_L_gradient_Y = -gradient_y/L
                tmp_Y_L_gradient_Y[active_set_y,:] += Y
                Z, active_set_z = prox_tree_hard_coded_full_matrix(tmp_Y_L_gradient_Y, n_coefs, p, 
                                         alpha/L, beta/L, gamma/L,
                                         DipoleGroup, DipoleGroupWeight, n_orient)
                if not np.any(active_set_z):
                    print "active_set = 0"
                    return 0,0,0                        
                objz = f( Z, active_set_z, M, G_list, G_ind, X, n_coefs, q, p, phiT)
                # Z-Y
                diff_z, active_set_diff_z = _add_z([Z,-Y],  np.vstack([active_set_z, active_set_y]))
                # the criterion for back track: 
                diff_bt = objz-objy- np.sum( np.real(gradient_y[active_set_diff_z])* np.real(diff_z))\
                      - np.sum(np.imag(gradient_y[active_set_diff_z])* np.imag(diff_z)) \
                      -0.5*L*np.sum( np.abs(diff_z)**2)                
        
        ## ==== FISTA step, update the variables =====
        tau0 = tau
        tau = 0.5*(1+ np.sqrt(4*tau**2+1))

        if not np.any(active_set_diff_z):
            print("active_set = 0")
            return 0, 0, 0
        else:
            # y <- z + (tau0-1)/tau (z-z0) = (tau+tau0-1)/tau z - (tau0-1)/tau z0
            Y, active_set_y = _add_z( [(tau+tau0-1)/tau*Z, -(tau0-1)/tau*Z0],
                                       np.vstack([active_set_z, active_set_z0]))
            
            # commented after verification
            #ind_y = np.nonzero(active_set_y)[0]
            #Z_larger = np.zeros([active_set_y.sum(), n_coefs*p], dtype = np.complex)
            #tmp_intersection = np.all(np.vstack([active_set_y, active_set_z]),axis=0)
            #tmp_ind_inter = np.nonzero(tmp_intersection)[0]
            #tmp_ind = [k0 for k0 in range(len(ind_y)) if ind_y[k0] in tmp_ind_inter]
            #Z_larger[tmp_ind,:] = Z
            # expand Z0 
            #Z0_larger = np.zeros([active_set_y.sum(), n_coefs*p], dtype = np.complex)
            #tmp_intersection0 = np.all(np.vstack([active_set_y, active_set_z0]),axis=0)
            #tmp_ind_inter0 = np.nonzero(tmp_intersection0)[0]
            #tmp_ind0 = [k0 for k0 in range(len(ind_y)) if ind_y[k0] in tmp_ind_inter0]        
            #Z0_larger[tmp_ind0,:] = Z0
            ## update y
            #diff = Z_larger-Z0_larger
            #Y1 = Z_larger + (tau0 - 1.0)/tau *diff 
            #print "Y1-Y=%f" %np.linalg.norm(Y1-Y)
            
        ## ===== compute objective function, check stopping criteria ====
        old_obj = obj
        full_Z = np.zeros([n_dipoles, n_coefs*p], dtype = complex)
        full_Z[active_set_z, :] = Z

        # residuals in each trial 
        # if we use the trial by trial model, also update 
        # the gradient of the trial-by-trial coefficients 
        # debug, can be commented soon.                         
        #R_all_sq = 0           
        #for r in range(q):
        #    tmp_coef = np.zeros([active_set_z.sum(),n_coefs], dtype = np.complex)
        #    for k in range(p):
        #        tmp_coef += Z[:,k*n_coefs:(k+1)*n_coefs]*X[r,k]
        #    tmpR = M[:,:,r] - phiT(G_list[G_ind[r]][:,active_set_z].dot(tmp_coef))
        #    R_all_sq += np.sum(tmpR**2)    
        #print "diff in f(z) = %f" %(objz - 0.5* R_all_sq.sum() )
        obj = objz \
            + get_tree_norm_hard_coded(full_Z,n_coefs, p, 
              alpha, beta, gamma, DipoleGroup, DipoleGroupWeight, n_orient)
        diff_obj = old_obj-obj
        relative_diff = np.sum(np.abs(diff_z))/np.sum(np.abs(Z))
        if Flag_verbose:
            print("\n iteration %d" % i)
            print("diff_obj = %e" % (diff_obj/obj))
            print("obj = %e" % obj)
            print("diff = %e" % relative_diff)
        stop = ( relative_diff < tol  and np.abs(diff_obj/obj) < tol)
        if stop:
            print("convergence reached!")
            break    
    Z = Y.copy() 
    active_set_z = active_set_y.copy()
    return Z, active_set_z, obj
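
The momentum step combines matrices whose rows live on different active sets (Z on active_set_z, Y on active_set_y). A small sketch of what a helper like _add_z presumably does, i.e. take the union of the row supports and add the aligned rows, is below; the real helper may additionally prune rows that become all zero.

import numpy as np

def add_z_sketch(Z_list, active_set_rows):
    """Add matrices whose rows live on (possibly different) boolean active sets.
    Z_list[k] has active_set_rows[k].sum() rows; the result lives on the union."""
    active_union = np.any(active_set_rows, axis=0)
    union_idx = np.nonzero(active_union)[0]
    out = np.zeros((active_union.sum(), Z_list[0].shape[1]), dtype=np.complex128)
    for Z_k, act in zip(Z_list, active_set_rows):
        rows_in_union = np.searchsorted(union_idx, np.nonzero(act)[0])
        out[rows_in_union, :] += Z_k
    return out, active_union

# tiny check: two single-row matrices supported on different dipoles
A = np.array([[1 + 1j, 2 + 0j]])
B = np.array([[10 + 0j, 20 + 0j]])
act_A = np.array([True, False, False])
act_B = np.array([False, False, True])
Zsum, act = add_z_sketch([A, B], np.vstack([act_A, act_B]))
print(act, Zsum)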
def compare_with_truth_individual_G(Z, active_set, source_data, 
                       Z_true, true_stc_data_list,  true_ind, 
                       ROI_ind, n_coefs, X, label_ind,
                       method,  
                       Flag_reconstruct_source = False,
                       tstep = 4, wsize = 16, Flag_sparse = True):
    '''
    Given a stft coefficient solution, compare with the truth
    Input:
        Z, [active_set.sum(), n_coefs*p]
        active_set, [n_dipole,] boolean
        source_data, if None, use Z to estimate it, if given,
                     it should be [active_set.sum(),n_times, n_trials]
                     The source_data outputted by the model may not be the same as predicted by Z,
                     for mne, it is the raw mne solution before reconstruction,
                     for stft_reg, it could be the trial by trial model
        Z_true, [n_dipoles, n_coefs*p]
        true_stc_data_list, list of the true stc data
        true_ind, indices of truly active dipoles
        ROI_ind, indices of dipoles that we care about
        n_coefs,
        X, [n_trials, p]
        label_ind, list of indices in each ROI
        ROI_curve,  PCA of the ROI curve
        ROI_curve_var_percent, percent of variance the components explain
        method = "stft_reg", "mne" or "mne_stft"
        Flag_reconstruct_source = False,
            if True, use Z to construct the source data
        
    Output:
        result = dict(coef_error_ROI = coef_error_ROI, coef_error = coef_error,
                  curve_error_ROI = curve_error_ROI, curve_error = curve_error,
                  MSE_source = MSE_source, MSE_source_ROI = MSE_source_ROI)
    ''' 
    
    # ==================== the error in source space =======================
    # if source_data is None, do not compute the MSEs
    n_dipoles = active_set.shape[0]
    [n_trials, p] = X.shape
    MSE_source_ROI, MSE_source = 0.0,0.0
    MSE_source_abs_ROI, MSE_source_abs = 0.0,0.0
    if source_data is not None:
        n_times = source_data.shape[1]
        # compute the distance to truth with the given source_data
        for i in range(n_trials):
            # estimated data, full and within ROI
            tmp_stc_full = np.zeros([n_dipoles, n_times])
            tmp_stc_full[active_set, :] = source_data[:,:,i]
            tmp_stc_ROI = tmp_stc_full[ROI_ind,:]
            # true data
            if Flag_sparse: # the true data contains only dipoles in true_ind
                tmp_true_stc_data = np.zeros([n_dipoles,n_times])
                tmp_true_stc_data[true_ind,:] = true_stc_data_list[i]
            else:  # the true data contains all dipoles
                tmp_true_stc_data = true_stc_data_list[i]
                
            tmp_true_stc_data_ROI = tmp_true_stc_data[ROI_ind,:]
            
            MSE_source_ROI += np.sum(( tmp_stc_ROI - tmp_true_stc_data_ROI )**2)
            MSE_source += np.sum(( tmp_stc_full - tmp_true_stc_data )**2)
            MSE_source_abs_ROI += np.sum(( np.abs(tmp_stc_ROI) - np.abs(tmp_true_stc_data_ROI) )**2)
            MSE_source_abs += np.sum(( np.abs(tmp_stc_full) - np.abs(tmp_true_stc_data) )**2)
            
        MSE_source_ROI /= n_trials
        MSE_source /= n_trials
        MSE_source_abs_ROI /= n_trials
        MSE_source_abs /= n_trials
        
    # ==================== the reconstruction error in source space =======================
    MSE_source_ROI_recon, MSE_source_recon = 0.0,0.0
    MSE_source_abs_ROI_recon, MSE_source_abs_recon = 0.0,0.0
    
    if Flag_reconstruct_source:
        n_times = true_stc_data_list[0].shape[1]
        n_freq = wsize//2+1
        n_step = n_times // tstep
        phiT = _PhiT(tstep, n_freq, n_step, n_times)
        for i in range(n_trials):
            if method == "mne_stft" or method == "stft_reg":
                tmpZ = np.reshape(Z,[active_set.sum(), p, n_coefs])
                tmpZ = np.swapaxes(tmpZ, 1,2)
                tmpZ = np.sum(tmpZ*X[i,:],axis = 2)
                tmp_source_data = phiT(tmpZ)
            else:
                # method = mne, Z [n_sources, n_times, p]
                tmp_source_data = np.sum(Z*X[i,:],axis = 2)
                
            tmp_stc_full = np.zeros([n_dipoles, n_times])
            tmp_stc_full[active_set, :] = tmp_source_data
            tmp_stc_ROI = tmp_stc_full[ROI_ind,:]
            
            # the true data
            if Flag_sparse:
                tmp_true_stc_data = np.zeros([n_dipoles,n_times])
                tmp_true_stc_data[true_ind,:] = true_stc_data_list[i]
            else:
                tmp_true_stc_data = true_stc_data_list[i]
        
            true_stc_data_ROI = tmp_true_stc_data[ROI_ind,:]
            MSE_source_ROI_recon += np.sum((tmp_stc_ROI - true_stc_data_ROI)**2)
            MSE_source_recon += np.sum((tmp_stc_full - tmp_true_stc_data)**2)
            MSE_source_abs_ROI_recon += np.sum( (np.abs(tmp_stc_ROI) - np.abs(true_stc_data_ROI) )**2)
            MSE_source_abs_recon += np.sum(( np.abs(tmp_stc_full) - np.abs(tmp_true_stc_data))**2)
        
        MSE_source_ROI_recon /= n_trials
        MSE_source_recon /= n_trials
        MSE_source_abs_ROI_recon /= n_trials
        MSE_source_abs_recon /= n_trials
    
    # ======================other errors ===================================    
    
    # ============errors of coefficients================================
    Z_full = np.zeros([n_dipoles,n_coefs*p], dtype = complex)
    Z_full[active_set,:] = Z
    Z_ROI = Z_full[ROI_ind,:]
    
    # Z_true is sparse, only at the ROIs
    Z_true_full = np.zeros([n_dipoles,n_coefs*p], dtype = complex)
    Z_true_full[true_ind,:] = Z_true
    Z_true_ROI = Z_true_full[ROI_ind,:]
    
    coef_error_ROI = np.sqrt( np.sum( np.abs( Z_ROI - Z_true_ROI ) **2) )
    coef_error = np.sqrt( np.sum( np.abs( Z_full - Z_true_full ) **2) )
    
    # measure the correlation 
    result = dict(coef_error_ROI = coef_error_ROI,coef_error = coef_error,
                  MSE_source = MSE_source, MSE_source_ROI = MSE_source_ROI,
                  MSE_source_ROI_recon = MSE_source_ROI_recon, 
                  MSE_source_recon = MSE_source_recon,
                  MSE_source_abs_ROI_recon = MSE_source_abs_ROI_recon,
                  MSE_source_abs_recon = MSE_source_abs_recon,
                  MSE_source_abs_ROI = MSE_source_abs_ROI,
                  MSE_source_abs = MSE_source_abs)
    return result