def scan_freq(subj,block,phase_elec,amp_elec,surrogate_analysis):
    """Scan phase frequency / bandwidth pairs for phase-amplitude coupling.

    For every centre frequency in ``fp`` and every bandwidth in
    ``fp_bandwidth`` the band-passed phase of ``phase_elec`` is paired with
    the 70-150 Hz analytic amplitude of ``amp_elec`` and a modulation index
    (MI) is computed over the artifact-free samples of the block.  MI is the
    Kullback-Leibler distance of the amplitude-by-phase distribution from the
    uniform distribution, divided by the entropy of the uniform distribution
    (see http://jn.physiology.org/content/104/2/1195).

    NOTE: a second ``scan_freq`` taking four arguments is defined further down
    in this file; because it is defined later it replaces this definition
    when the module is imported.

    Parameters
    ----------
    subj, block : str
        Subject and block identifiers.  They are used to build the data paths
        and as attribute names inside the ``subj_globals`` struct.
    phase_elec, amp_elec : int
        1-based electrode numbers; row ``elec-1`` of ``ecogData`` is read.
    surrogate_analysis : int
        When equal to 1, ``kruns`` surrogate MI values are computed for each
        frequency/bandwidth pair by circularly shifting the amplitude series
        by a random number of samples.

    Returns
    -------
    None.  The MI matrix (``len(fp)`` x ``len(fp_bandwidth)``) is written with
    ``numpy.save`` to ``MI_output_path``; the surrogate array (with a trailing
    ``kruns`` axis) is written to ``MI_output_path + '_surrogate'`` when
    requested.  The output directory must already exist.
    """
    import scipy.io as sio
    from numpy import arange,array,append,zeros,pi,angle,logical_and,mean,roll,save,floor,ceil,setdiff1d #for efficiency
    from eegfilt import eegfilt
    from scipy.signal import hilbert
    from math import log
    from random import randint

    #data_path = '/home/jcase/data/' + subj + block + '/' + subj + '_' + block + '_data.mat'
    #subglo_path = '/home/jcase/data/subj_globals.mat'
    #MI_output_path = '/home/jcase/data/' + subj + block + '/MI/e' + str(phase_elec) + '_e' + str(amp_elec)

    data_path = '/Users/johncase/Documents/UCSF Data/' + subj + '/' + subj + block + '/data/' + subj + '_' + block + '_data.mat'
    subglo_path = '/Users/johncase/Documents/UCSF Data/subj_globals.mat'
    MI_output_path = '/Users/johncase/Documents/UCSF Data/' + subj + '/' + subj + block + '/analysis/MI/e' + str(phase_elec) + '_e' + str(amp_elec)

    #Load ECOG Data (keep only the two channels that are needed)
    ecog_data = sio.loadmat(data_path)['ecogData']
    amp_raw_data = ecog_data[amp_elec-1,:]
    pha_raw_data = ecog_data[phase_elec-1,:]
    del ecog_data

    #Load subject globals
    all_subj_data = sio.loadmat(subglo_path,struct_as_record=False, squeeze_me=True)['subj_globals']
    subj_data = getattr(getattr(all_subj_data,subj),block)
    srate = subj_data.srate
    per_chan_bad_epochs = subj_data.per_chan_bad_epochs
    allstimtimes = subj_data.allstimtimes

    #Surrogate Runs
    kruns = 10000

    #Phase-providing frequency
    #fp = arange(1,15.1,0.1)
    #fp_bandwidth = arange(0.5,5.1,0.1)
    fp = arange(1,21,1)
    fp_bandwidth = arange(1,11,1)

    #Amplitude-providing frequency
    fa = array([70,150])

    #Define phase bins: exactly n_bins left edges covering [-pi, pi).
    #(arange(-pi,pi-bin_size,bin_size) excludes its end point and therefore
    #dropped the last bin, or not, depending on floating point rounding.)
    n_bins = 20
    bin_size = 2*pi/n_bins
    bins = -pi + bin_size*arange(n_bins)

    def bad_samples_for(elec):
        #Integer sample indices covered by the artifact epochs of one electrode.
        #Epoch bounds are multiplied by srate, so they are presumably in
        #seconds -- TODO confirm against subj_globals.
        epochs = per_chan_bad_epochs[elec-1]
        if epochs.size == 2:
            #an entry of size 2 is treated as one (start, stop) pair
            epochs = [epochs]
        samples = array([],dtype=int)
        for epoch in epochs:
            #round outwards to whole samples; fractional bounds would never
            #match the integer sample indices of the time window
            samples = append(samples,arange(int(floor(srate*epoch[0])),int(ceil(srate*epoch[1]))))
        return samples

    def modulation_index(bin_masks,amp):
        #Mean amplitude within each phase bin -> distribution of amplitude(phase)
        bin_dist = array([mean(amp[mask]) for mask in bin_masks])

        #Normalize distribution to yield pseudo "probability density function" (PDF)
        bin_dist = bin_dist / bin_dist.sum()

        #Calculate Shannon entropy of PDF
        h_p = 0
        for mybin in bin_dist:
            h_p = h_p - mybin * log(mybin)

        #(KL distance between PDF and uniform) / (entropy of uniform)
        return (log(len(bin_masks)) - h_p) / log(len(bin_masks))

    #Define time_window (roughly entire block, exclude artifacts samples later)
    t_0 = int(round(allstimtimes[0,0]*srate))
    t_end = int(round((allstimtimes[-1,1]+3) *srate))
    t_win = arange(t_0,t_end)

    #Determine samples with artifacts on either electrode
    bad_samp = bad_samples_for(phase_elec)
    if not phase_elec == amp_elec:
        bad_samp = append(bad_samp,bad_samples_for(amp_elec))

    #good_samps: sorted, so the remaining samples stay in temporal order
    #(the circular shift used for the surrogates relies on that order)
    good_samps = setdiff1d(t_win,bad_samp)

    #Do high-gamma filtering (high-pass, then low-pass, then Hilbert amplitude)
    hg,filtwt = eegfilt(amp_raw_data,srate,fa[0],[])
    hg,filtwt = eegfilt(hg[0],srate,[],fa[1])
    hg_amp = abs(hilbert(hg[0][0:len(hg[0])-1]))
    hg_amp = hg_amp[good_samps] #exclude bad samples

    #Calculate MI for each phase-providing central-frequencies / bandwidths
    MI = zeros([len(fp),len(fp_bandwidth)])
    if surrogate_analysis == 1:
        MI_surrogate = zeros([len(fp),len(fp_bandwidth),kruns])

    for iFreq,freq in enumerate(fp):
        for iBand,band in enumerate(fp_bandwidth):

            #explicit float division: band is an integer and band/2 would
            #floor under Python 2
            half_band = band/2.0

            #skip bands whose lower edge falls below 0.5 Hz (MI stays 0)
            if freq-half_band < 0.5:
                continue

            print('freq = ' + str(freq) + ', bw = ' + str(band))

            #Do phase-providing phase filtering
            pha = eegfilt(pha_raw_data,srate,[],freq+half_band)[0][0]
            pha = eegfilt(pha,srate,freq-half_band,[])[0][0]
            pha = angle(hilbert(pha[0:len(pha)-1]))
            pha = pha[good_samps] #exclude bad samples

            #The phase does not change between surrogate runs, so the bin
            #membership of every sample is determined once
            bin_masks = [logical_and(pha>=edge,pha<edge+bin_size) for edge in bins]

            MI[iFreq,iBand] = modulation_index(bin_masks,hg_amp)

            if surrogate_analysis == 1:
                for iRun in range(kruns):
                    shift_factor = randint(0,len(hg_amp))
                    MI_surrogate[iFreq,iBand,iRun] = modulation_index(bin_masks,roll(hg_amp,shift_factor))

    with open(MI_output_path,'wb') as out_file:
        save(out_file,MI)
    if surrogate_analysis == 1:
        with open(MI_output_path+'_surrogate','wb') as out_file:
            save(out_file,MI_surrogate)
def scan_freq(subj,block,phase_elec,amp_elec):
    """Fine-grained scan of phase frequency / bandwidth for the modulation index.

    For every centre frequency in ``fp`` (1-15 Hz in 0.1 Hz steps) and every
    bandwidth in ``fp_bandwidth`` (0.5-5 Hz in 0.1 Hz steps) the band-passed
    phase of ``phase_elec`` is paired with the 70-150 Hz analytic amplitude
    of ``amp_elec`` and the modulation index (MI) is computed over the
    artifact-free samples of the block.  MI is the Kullback-Leibler distance
    of the amplitude-by-phase distribution from the uniform distribution,
    divided by the entropy of the uniform distribution (see
    http://jn.physiology.org/content/104/2/1195).

    NOTE: this file also defines an earlier ``scan_freq`` with an extra
    ``surrogate_analysis`` argument; this later definition is the one bound
    to the name after the module is imported.

    Parameters
    ----------
    subj, block : str
        Subject and block identifiers.  They are used to build the data paths
        and as attribute names inside the ``subj_globals`` struct.
    phase_elec, amp_elec : int
        1-based electrode numbers; row ``elec-1`` of ``ecogData`` is read.

    Returns
    -------
    None.  The MI matrix (``len(fp)`` x ``len(fp_bandwidth)``) is written with
    ``numpy.save`` to ``MI_output_path``.  The output directory must already
    exist.
    """
    import scipy.io as sio
    import numpy as np
    from eegfilt import eegfilt
    import scipy.signal as sig
    import math

    data_path = '/home/jcase/data/' + subj + block + '/' + subj + '_' + block + '_data.mat'
    subglo_path = '/home/jcase/data/subj_globals.mat'
    MI_output_path = '/home/jcase/data/' + subj + block + '/MI/e' + str(phase_elec) + '_e' + str(amp_elec)

    #Load ECOG Data (keep only the two channels that are needed)
    ecog_data = sio.loadmat(data_path)['ecogData']
    amp_raw_data = ecog_data[amp_elec-1,:]
    pha_raw_data = ecog_data[phase_elec-1,:]
    del ecog_data

    #Load subject globals
    all_subj_data = sio.loadmat(subglo_path,struct_as_record=False, squeeze_me=True)['subj_globals']
    subj_data = getattr(getattr(all_subj_data,subj),block)
    srate = subj_data.srate
    per_chan_bad_epochs = subj_data.per_chan_bad_epochs
    allstimtimes = subj_data.allstimtimes

    #Phase-providing frequency
    fp = np.arange(1,15.1,0.1)
    fp_bandwidth = np.arange(0.5,5.1,0.1)

    #Amplitude-providing frequency
    fa = np.array([70,150])

    #Define phase bins: exactly n_bins left edges covering [-pi, pi).
    #(np.arange(-pi,pi-bin_size,bin_size) excludes its end point and
    #therefore dropped the last bin, or not, depending on float rounding.)
    n_bins = 20
    bin_size = 2*np.pi/n_bins
    bins = -np.pi + bin_size*np.arange(n_bins)

    def bad_samples_for(elec):
        #Integer sample indices covered by the artifact epochs of one electrode.
        #Epoch bounds are multiplied by srate, so they are presumably in
        #seconds -- TODO confirm against subj_globals.
        epochs = per_chan_bad_epochs[elec-1]
        if epochs.size == 2:
            #an entry of size 2 is treated as one (start, stop) pair
            epochs = [epochs]
        samples = np.array([],dtype=int)
        for epoch in epochs:
            #round outwards to whole samples; fractional bounds would never
            #match the integer sample indices of the time window
            first = int(np.floor(srate*epoch[0]))
            last = int(np.ceil(srate*epoch[1]))
            samples = np.append(samples,np.arange(first,last))
        return samples

    #Define time_window (roughly entire block, exclude artifacts samples later)
    t_0 = int(round(allstimtimes[0,0]*srate))
    t_end = int(round((allstimtimes[-1,1]+3) *srate))
    t_win = np.arange(t_0,t_end)

    #Determine samples with artifacts on either electrode
    bad_samp = bad_samples_for(phase_elec)
    if not phase_elec == amp_elec:
        bad_samp = np.append(bad_samp,bad_samples_for(amp_elec))

    #good_samps (sorted, i.e. kept in temporal order)
    good_samps = np.setdiff1d(t_win,bad_samp)

    #Do high-gamma filtering (high-pass, then low-pass, then Hilbert amplitude)
    hg,filtwt = eegfilt(amp_raw_data,srate,fa[0],[])
    hg,filtwt = eegfilt(hg[0],srate,[],fa[1])
    hg_amp = abs(sig.hilbert(hg[0][0:len(hg[0])-1]))
    hg_amp = hg_amp[good_samps] #exclude bad samples

    #Calculate MI for each phase-providing central-frequencies / bandwidths
    MI = np.zeros([len(fp),len(fp_bandwidth)])

    for iFreq,freq in enumerate(fp):
        for iBand,band in enumerate(fp_bandwidth):

            half_band = band/2.0

            #skip bands whose lower edge falls below 0.5 Hz (MI stays 0)
            if freq-half_band < 0.5:
                continue

            print('freq = ' + str(freq) + ', bw = ' + str(band))

            #Do phase-providing phase filtering
            pha = eegfilt(pha_raw_data,srate,[],freq+half_band)[0][0]
            pha = eegfilt(pha,srate,freq-half_band,[])[0][0]
            pha = np.angle(sig.hilbert(pha[0:len(pha)-1]))
            pha = pha[good_samps] #exclude bad samples

            #Calculate mean amplitude within each phase bin to yield a
            #distribution of amplitude(phase)
            bin_dist = np.zeros(n_bins)
            for iBin,edge in enumerate(bins):
                in_bin = np.logical_and(pha>=edge,pha<edge+bin_size)
                bin_dist[iBin] = np.mean(hg_amp[in_bin])

            #Normalize distribution to yield pseudo "probability density function" (PDF)
            bin_dist = bin_dist / bin_dist.sum()

            #Calculate Shannon entropy of PDF
            h_p = 0
            for mybin in bin_dist:
                h_p = h_p - mybin * math.log(mybin)

            #MI = (Kullback-Leibler distance between h_p and uniform
            #distribution) / (Entropy of uniform distribution)
            MI[iFreq,iBand] = (math.log(n_bins) - h_p) / math.log(n_bins)

    with open(MI_output_path,'wb') as out_file:
        np.save(out_file,MI)
def CFC_learning_per_stimID(subj,block,phase_elec,amp_elec,method,surrogate_analysis):
    """Compare phase-amplitude coupling between the two halves of a block, per stimulus.

    For each stimulus ID 1-3 the artifact-free trials (1 s after stimulus
    onset) are split at the median onset into an early half (index 0) and a
    late half (index 1).  For every phase frequency / bandwidth pair that is
    inside ``fp_good`` x ``fp_bandwidth_good``, coupling between the phase of
    ``phase_elec`` and the 70-150 Hz amplitude of ``amp_elec`` is computed
    for each half, together with the difference late - early.

    Parameters
    ----------
    subj, block : str
        Subject and block identifiers.  They are used to build the data paths
        and as attribute names inside the ``subj_globals`` struct.
    phase_elec, amp_elec : int
        1-based electrode numbers; row ``elec-1`` of ``ecogData`` is read.
    method : str
        ``'MI'`` for the modulation index (normalised Kullback-Leibler
        distance from the uniform distribution, see
        http://jn.physiology.org/content/104/2/1195) or ``'dPAC'`` for
        ``abs(mean(amp * (exp(1j*phase) - mean(exp(1j*phase)))))``.
    surrogate_analysis : int
        When equal to 1, ``kruns`` surrogate differences are computed per
        pair by randomly re-assigning the trial amplitude traces across the
        two halves while the phases stay in place.

    Returns
    -------
    None.  For each stimulus ID three arrays are written with ``numpy.save``:
    ``MI_output_path + '_block_diff_surrogate_<id>'``,
    ``MI_output_path + '_block_diff_<id>'`` and
    ``MI_output_path + '_block_<id>'``.  The surrogate array is all zeros
    unless ``surrogate_analysis == 1``.  The output directory must already
    exist.

    Raises
    ------
    ValueError
        If ``method`` is neither ``'MI'`` nor ``'dPAC'`` (previously such a
        call silently saved all-zero results).
    """
    import scipy.io as sio
    from numpy import concatenate,exp,arange,intersect1d,array,append,zeros,pi,angle,logical_and,mean,save,where,random,rint,floor,ceil #for efficiency
    from eegfilt import eegfilt
    from scipy.signal import hilbert
    from math import log

    if method not in ('MI','dPAC'):
        raise ValueError("method must be 'MI' or 'dPAC', got %r" % (method,))

   # subglo_path = '/home/jcase/data/subj_globals.mat'
   # data_path = '/home/jcase/data/' + subj + block + '/' + subj + '_' + block + '_data.mat'

    subglo_path = '/Users/johncase/Documents/UCSF Data/subj_globals.mat'
    data_path = '/Users/johncase/Documents/UCSF Data/' + subj + '/' + subj + block + '/data/' + subj + '_' + block + '_data.mat'
    MI_output_path = '/Users/johncase/Documents/UCSF Data/' + subj + '/' + subj + block + '/analysis/dPAC_per_stimulus/e' + str(phase_elec) + '_e' + str(amp_elec)

    #Trial window after stimulus onset
    ps = 1 #in seconds

    #Surrogate Runs
    kruns = 200

    #Amplitude-providing frequency
    fa = array([70,150])

    #Define phase bins: exactly n_bins left edges covering [-pi, pi).
    #(arange(-pi,pi-bin_size,bin_size) excludes its end point and therefore
    #dropped the last bin, or not, depending on floating point rounding.)
    n_bins = 18
    bin_size = 2*pi/n_bins
    bins = -pi + bin_size*arange(n_bins)

    #Phase-providing frequency grid; only the "good" subset is evaluated,
    #every other cell of the output stays 0
    fp = arange(1,21,1)
    fp_bandwidth = arange(1,11,1)
    fp_good = arange(4,18)
    fp_bandwidth_good = arange(6,10)

    #Load ECOG Data (keep only the two channels that are needed)
    ecog_data = sio.loadmat(data_path)['ecogData']
    amp_raw_data = ecog_data[amp_elec-1,:]
    pha_raw_data = ecog_data[phase_elec-1,:]
    del ecog_data

    #Load subject globals
    all_subj_data = sio.loadmat(subglo_path,struct_as_record=False, squeeze_me=True)['subj_globals']
    subj_data = getattr(getattr(all_subj_data,subj),block)
    srate = subj_data.srate
    stimID = subj_data.stimID
    stimtimes = subj_data.allstimtimes
    per_chan_bad_epochs = subj_data.per_chan_bad_epochs

    def bad_samples_for(elec):
        #Integer sample indices covered by the artifact epochs of one electrode.
        #Epoch bounds are multiplied by srate, so they are presumably in
        #seconds -- TODO confirm against subj_globals.
        epochs = per_chan_bad_epochs[elec-1]
        if epochs.size == 2:
            #an entry of size 2 is treated as one (start, stop) pair
            epochs = [epochs]
        samples = array([],dtype=int)
        for epoch in epochs:
            #round outwards to whole samples; fractional bounds would never
            #match the integer sample indices of a trial
            samples = append(samples,arange(int(floor(srate*epoch[0])),int(ceil(srate*epoch[1]))))
        return samples

    def flatten(trials,dtype=float):
        #Join per-trial arrays end to end; an empty trial list gives an empty array
        if trials:
            return concatenate(trials)
        return array([],dtype=dtype)

    #prepare(): everything that depends on the phase only (constant across
    #surrogate runs).  coupling(): combine it with an amplitude series.
    if method == 'MI':
        def prepare(phase):
            #bin membership of every sample
            return [logical_and(phase>=edge,phase<edge+bin_size) for edge in bins]

        def coupling(bin_masks,amp):
            #Mean amplitude within each phase bin -> distribution of amplitude(phase)
            bin_dist = array([mean(amp[mask]) for mask in bin_masks])
            #Normalize distribution to yield pseudo "probability density function" (PDF)
            bin_dist = bin_dist / bin_dist.sum()
            #Calculate Shannon entropy of PDF
            h_p = 0
            for mybin in bin_dist:
                h_p = h_p - mybin * log(mybin)
            #(KL distance between PDF and uniform) / (entropy of uniform)
            return (log(n_bins) - h_p) / log(n_bins)
    else:
        def prepare(phase):
            #unit phasors with their mean removed
            phasor = exp(1j*phase)
            return phasor - mean(phasor)

        def coupling(debiased_phasor,amp):
            return abs(mean(amp*debiased_phasor))

    #Identify bad samples on either electrode
    bad_samp = bad_samples_for(phase_elec)
    if not phase_elec == amp_elec:
        bad_samp = append(bad_samp,bad_samples_for(amp_elec))

    #Do high-gamma filtering (low-pass, then high-pass, then Hilbert amplitude)
    hg,filtwt = eegfilt(amp_raw_data,srate,[],fa[1])
    hg,filtwt = eegfilt(hg[0],srate,fa[0],[])
    hg_amp = abs(hilbert(hg[0][0:len(hg[0])-1]))

    for iStim in range(1,4):

        MI_diff_surrogate = zeros((len(fp),len(fp_bandwidth),kruns))
        MI_block = zeros((len(fp),len(fp_bandwidth),2))
        MI_diff = zeros((len(fp),len(fp_bandwidth)))

        stim_ind = where(stimID==iStim)[0]
        all_onsets = stimtimes[stim_ind,0]

        #Trials are collected afresh for every stimulus ID (one list per
        #half) so that trials of earlier stimuli do not leak into this one
        trial_samples = [[],[]]
        trial_amp = [[],[]]

        if len(all_onsets) > 0:
            #integer division: a float index is rejected by Python 3
            halfway = all_onsets[len(all_onsets)//2]

        for onset in all_onsets:
            blk = 0 if onset < halfway else 1

            trl = arange(rint(onset*srate),rint((onset+ps)*srate)).astype(int)
            #keep the trial only if it does not overlap with an artifact
            if intersect1d(trl,bad_samp).size == 0:
                trial_samples[blk].append(trl)
                trial_amp[blk].append(hg_amp[trl])

        #number of trials in each half (for surrogate analysis)
        stim_n = [len(trial_samples[0]),len(trial_samples[1])]

        #per-half sample indices and amplitudes, trials joined end to end
        samples_flat = [flatten(trial_samples[blk],int) for blk in range(2)]
        amp_flat = [flatten(trial_amp[blk]) for blk in range(2)]

        #all amplitude traces, early half first (pool for the surrogates)
        amp_pool = trial_amp[0] + trial_amp[1]

        #Calculate coupling for each phase-providing central-frequencies / bandwidths
        for iFreq,freq in enumerate(fp):
            for iBand,band in enumerate(fp_bandwidth):

                #explicit float division: band is an integer and band/2
                #would floor under Python 2
                half_band = band/2.0

                #skipped cells keep their initial value of 0
                if freq-half_band < 0.5:
                    continue
                if freq not in fp_good or band not in fp_bandwidth_good:
                    continue

                print('freq = ' + str(freq) + ', bw = ' + str(band))

                #Do phase-providing phase filtering
                pha = eegfilt(pha_raw_data,srate,[],freq+half_band)[0][0]
                pha = eegfilt(pha,srate,freq-half_band,[])[0][0]
                pha = angle(hilbert(pha[0:len(pha)-1]))

                #phases do not shuffle between runs, so predetermine
                #everything that depends on them
                prepared = [prepare(pha[samples_flat[blk]]) for blk in range(2)]

                for blk in range(2):
                    MI_block[iFreq,iBand,blk] = coupling(prepared[blk],amp_flat[blk])

                # difference statistic = Block 2 - Block 1
                MI_diff[iFreq,iBand] = MI_block[iFreq,iBand,1]-MI_block[iFreq,iBand,0]

                if surrogate_analysis == 1:

                    for iRun in range(kruns):

                        MI_block_surrogate = zeros(2)

                        #random re-assignment of amplitude traces to trials
                        order = random.permutation(len(amp_pool))

                        cnt = 0
                        for blk in range(2):
                            picked = [amp_pool[i] for i in order[cnt:cnt+stim_n[blk]]]
                            cnt = cnt + stim_n[blk]
                            MI_block_surrogate[blk] = coupling(prepared[blk],flatten(picked))

                        MI_diff_surrogate[iFreq,iBand,iRun] = MI_block_surrogate[1]-MI_block_surrogate[0]

        with open(MI_output_path+'_block_diff_surrogate_'+str(iStim),'wb') as out_file:
            save(out_file,MI_diff_surrogate)
        with open(MI_output_path+'_block_diff_'+str(iStim),'wb') as out_file:
            save(out_file,MI_diff)
        with open(MI_output_path+'_block_'+str(iStim),'wb') as out_file:
            save(out_file,MI_block)
def CFC_mean_vec_pre_post(subj,blocks,phase_elec,amp_elec,surrogate_analysis):
    """Compare mean-vector phase-amplitude coupling between two recording blocks.

    For every block, artifact-free trials (1 s after each stimulus onset,
    stimuli with ID 10 removed) are collected.  For every phase frequency /
    bandwidth pair the complex mean vector ``mean(amp * exp(1j*phase))``
    between the phase of ``phase_elec`` and the 70-150 Hz amplitude of
    ``amp_elec`` is computed per block, together with the difference of its
    length, second block - first block.

    Parameters
    ----------
    subj : str
        Subject identifier (path component and attribute name inside the
        ``subj_globals`` struct).
    blocks : sequence of str
        Block identifiers.  The result arrays have room for exactly two
        blocks and the difference is ``blocks[1] - blocks[0]``.
    phase_elec, amp_elec : int
        1-based electrode numbers; row ``elec-1`` of ``ecogData`` is read.
    surrogate_analysis : int
        When equal to 1, ``kruns`` surrogate differences are computed per
        pair by randomly re-assigning the trial amplitude traces across both
        blocks while the phases stay in place.

    Returns
    -------
    None.  ``MI_output_path + '_block_diff'`` (real) and
    ``MI_output_path + '_block'`` (complex mean vectors) are written with
    ``numpy.save``; ``MI_output_path + '_block_diff_surrogate'`` is written
    when ``surrogate_analysis == 1``.  The output directory must already
    exist.
    """
    import sys
    import scipy.io as sio
    from numpy import arange,intersect1d,array,append,concatenate,zeros,angle,mean,save,where,delete,random,rint,exp,floor,ceil #for efficiency
    from eegfilt import eegfilt
    from scipy.signal import hilbert

    #subglo_path = '/home/jcase/data/subj_globals.mat'
    #MI_output_path = '/home/jcase/data/' + subj + ''.join(blocks) + '/MI/e' + str(phase_elec) + '_e' + str(amp_elec)

    subglo_path = '/Users/johncase/Documents/UCSF Data/subj_globals.mat'
    MI_output_path = '/Users/johncase/Documents/UCSF Data/' + subj + '/' + subj + ''.join(blocks) + '/analysis/mean_vec/e' + str(phase_elec) + '_e' + str(amp_elec)

    #Trial window after stimulus onset
    ps = 1 #in seconds

    #Surrogate Runs
    kruns = 200

    #Phase-providing frequency
    #fp = arange(1,15.1,0.1)
    #fp_bandwidth = arange(0.5,5.1,0.1)
    fp = arange(1,21,1)
    fp_bandwidth = arange(1,11,1)

    #Amplitude-providing frequency
    fa = array([70,150])

    MI_block = zeros((len(fp),len(fp_bandwidth),2),complex)
    MI_diff = zeros((len(fp),len(fp_bandwidth)))
    if surrogate_analysis == 1:
        MI_diff_surrogate = zeros((len(fp),len(fp_bandwidth),kruns))

    def bad_samples_for(bad_epochs,rate,elec):
        #Integer sample indices covered by the artifact epochs of one electrode.
        #Epoch bounds are multiplied by the sampling rate, so they are
        #presumably in seconds -- TODO confirm against subj_globals.
        epochs = bad_epochs[elec-1]
        if epochs.size == 2:
            #an entry of size 2 is treated as one (start, stop) pair
            epochs = [epochs]
        samples = array([],dtype=int)
        for epoch in epochs:
            #round outwards to whole samples; fractional bounds would never
            #match the integer sample indices of a trial
            samples = append(samples,arange(int(floor(rate*epoch[0])),int(ceil(rate*epoch[1]))))
        return samples

    def flatten(trials,dtype=float):
        #Join per-trial arrays end to end; an empty trial list gives an empty array
        if trials:
            return concatenate(trials)
        return array([],dtype=dtype)

    #Per-block data, keyed by block name
    pha_raw_data = {}
    srate = {}
    trial_samples = {}   #list of sample-index arrays, one per kept trial
    trial_amp = {}       #list of high-gamma amplitude traces, one per kept trial

    #The globals file is the same for every block, so it is read once
    all_subj_data = sio.loadmat(subglo_path,struct_as_record=False, squeeze_me=True)['subj_globals']

    for blk in blocks:
        #Load ECOG Data

        #data_path = '/home/jcase/data/' + subj + blk + '/' + subj + '_' + blk + '_data.mat'
        data_path = '/Users/johncase/Documents/UCSF Data/' + subj + '/' + subj + blk + '/data/' + subj + '_' + blk + '_data.mat'

        ecog_data = sio.loadmat(data_path)['ecogData']
        amp_raw_data = ecog_data[amp_elec-1,:]
        pha_raw_data[blk] = ecog_data[phase_elec-1,:]
        del ecog_data

        #Block-specific globals
        subj_data = getattr(getattr(all_subj_data,subj),blk)
        srate[blk] = subj_data.srate
        per_chan_bad_epochs = subj_data.per_chan_bad_epochs
        all_onsets = subj_data.allstimtimes[:,0]
        stimID = subj_data.stimID
        all_onsets = delete(all_onsets,where(stimID==10),0) #delete clicks

        #Identify bad samples on either electrode for this block
        bad_samp = bad_samples_for(per_chan_bad_epochs,srate[blk],phase_elec)
        if not phase_elec == amp_elec:
            bad_samp = append(bad_samp,bad_samples_for(per_chan_bad_epochs,srate[blk],amp_elec))

        #Do high-gamma filtering (high-pass, then low-pass, then Hilbert amplitude)
        hg,filtwt = eegfilt(amp_raw_data,srate[blk],fa[0],[])
        hg,filtwt = eegfilt(hg[0],srate[blk],[],fa[1])
        hg_amp = abs(hilbert(hg[0][0:len(hg[0])-1]))

        #keep every trial that does not overlap with an artifact
        trial_samples[blk] = []
        trial_amp[blk] = []
        for onset in all_onsets:
            trl = arange(rint(onset*srate[blk]),rint((onset+ps)*srate[blk])).astype(int)
            if intersect1d(trl,bad_samp).size == 0:
                trial_samples[blk].append(trl)
                trial_amp[blk].append(hg_amp[trl])

    #number of trials in each block (for surrogate analysis)
    stim_n = [len(trial_samples[blk]) for blk in blocks]

    #per-block sample indices and amplitudes, trials joined end to end
    samples_flat = dict((blk,flatten(trial_samples[blk],int)) for blk in blocks)
    amp_flat = dict((blk,flatten(trial_amp[blk])) for blk in blocks)

    #all amplitude traces in block order (pool for the surrogates)
    amp_pool = []
    for blk in blocks:
        amp_pool.extend(trial_amp[blk])

    #Calculate mean vector for each phase-providing central-frequencies / bandwidths
    for iFreq,freq in enumerate(fp):
        for iBand,band in enumerate(fp_bandwidth):

            #explicit float division: band is an integer and band/2 would
            #floor under Python 2
            half_band = band/2.0

            #skip bands whose lower edge falls below 0.5 Hz (MI_diff stays 0)
            if freq-half_band < 0.5:
                continue

            print('freq = ' + str(freq) + ', bw = ' + str(band))

            #unit phasors of the trial samples, one array per block; they do
            #not change between surrogate runs
            phasors = []
            for iBlock,blk in enumerate(blocks):

                #Do phase-providing phase filtering
                pha = eegfilt(pha_raw_data[blk],srate[blk],[],freq+half_band)[0][0]
                pha = eegfilt(pha,srate[blk],freq-half_band,[])[0][0]
                pha = angle(hilbert(pha[0:len(pha)-1]))

                phasors.append(exp(1j * pha[samples_flat[blk]]))

                #Calculate mean vector of composite signal
                MI_block[iFreq,iBand,iBlock] = mean(amp_flat[blk] * phasors[iBlock])

            # difference statistic = Block 2 - Block 1
            MI_diff[iFreq,iBand] = abs(MI_block[iFreq,iBand,1])-abs(MI_block[iFreq,iBand,0])

            if surrogate_analysis == 1:

                for iRun in range(kruns):

                    if iRun%10 == 0:
                        #progress on a single console line; written through
                        #sys.stdout so it runs on Python 2 and Python 3
                        sys.stdout.write('{}\r'.format('Run ' + str(iRun+10)))
                        sys.stdout.flush()

                    MI_block_surrogate = zeros(2)

                    #random re-assignment of amplitude traces to trials
                    order = random.permutation(len(amp_pool))

                    cnt = 0
                    for iBlock,blk in enumerate(blocks):
                        picked = [amp_pool[i] for i in order[cnt:cnt+stim_n[iBlock]]]
                        cnt = cnt + stim_n[iBlock]

                        MI_block_surrogate[iBlock] = abs(mean(flatten(picked) * phasors[iBlock]))
                    MI_diff_surrogate[iFreq,iBand,iRun] = MI_block_surrogate[1]-MI_block_surrogate[0]

    with open(MI_output_path+'_block_diff','wb') as out_file:
        save(out_file,MI_diff)
    with open(MI_output_path+'_block','wb') as out_file:
        save(out_file,MI_block)
    if surrogate_analysis == 1:
        with open(MI_output_path+'_block_diff_surrogate','wb') as out_file:
            save(out_file,MI_diff_surrogate)
def CFC_low_high_pre_post(subj,blocks,phase_elec,amp_elec,surrogate_analysis):
    """Compare phase-amplitude coupling between two recording blocks.

    For every amplitude band in ``fa`` and every phase band in ``fp`` the
    modulation index (MI, see http://jn.physiology.org/content/104/2/1195) is
    computed separately for each block, using the ``ps``-second window after
    every non-click stimulus onset (stimID != 10) whose samples do not overlap
    a bad epoch of the phase or amplitude electrode.

    Args:
        subj: subject identifier; used to build the file paths and to select
            the subject's entry in ``subj_globals.mat``.
        blocks: sequence of block names. Two blocks are expected: the result
            arrays hold two blocks and the statistic is blocks[1] - blocks[0].
        phase_elec: 1-based index of the phase-providing electrode.
        amp_elec: 1-based index of the amplitude-providing electrode.
        surrogate_analysis: 1 to also compute ``kruns`` surrogate differences
            by shuffling the amplitude trials across both blocks.

    Returns:
        None. The results are written with ``numpy.save`` to
        ``MI_output_path`` + ``'_block'`` (MI per block),
        ``'_block_diff'`` (difference) and, when requested,
        ``'_block_diff_surrogate'``.
    """
    import sys
    import scipy.io as sio
    from numpy import arange,array,ceil,concatenate,floor,intersect1d,zeros,pi,angle,logical_and,mean,save,where,delete,random,rint #for efficiency
    from eegfilt import eegfilt
    from scipy.signal import hilbert
    from math import log

    #subglo_path = '/home/jcase/data/subj_globals.mat'
    #MI_output_path = '/home/jcase/data/' + subj + ''.join(blocks) + '/MI/e' + str(phase_elec) + '_e' + str(amp_elec)

    subglo_path = '/Users/johncase/Documents/UCSF Data/subj_globals.mat'
    MI_output_path = '/Users/johncase/Documents/UCSF Data/' + subj + '/' + subj + ''.join(blocks) + '/analysis/MI_ap/e' + str(phase_elec) + '_e' + str(amp_elec)

    #Trial window (post-stimulus), in seconds
    ps = 1

    #Surrogate Runs
    kruns = 200

    #Phase-providing frequency bands (low edge, high edge)
    fp = [(i,i+2) for i in arange(3,50,2)]

    #Amplitude-providing frequency bands (low edge, high edge)
    fa = [(i,i+4) for i in arange(50,160,4)]

    #Define phase bins. The left edges are built from an integer range so that
    #there are always exactly n_bins of them; arange with a float step and an
    #exclusive stop of pi-bin_size can come out one edge short.
    n_bins = 18
    bin_size = 2*pi/n_bins
    bins = -pi + bin_size*arange(n_bins)

    def _flatten(chunks):
        #Concatenate a list of 1-D arrays; an empty list gives an empty array
        return concatenate(chunks) if len(chunks) > 0 else array([])

    def _artifact_samples(epochs,srate):
        #Expand the [start,end] bad epochs of one channel into integer sample
        #indices (bounds are scaled by srate, so presumably given in seconds).
        #Bounds are rounded outwards: fractional sample values could never
        #match the integer trial samples they are compared against.
        pairs = [epochs] if epochs.size == 2 else epochs
        samples = [arange(int(floor(srate*epoch[0])),int(ceil(srate*epoch[1]))) for epoch in pairs]
        return _flatten(samples).astype(int)

    def _mi_from_distribution(bin_dist):
        #Normalize distribution to yield pseudo "probability density function" (PDF)
        bin_dist = bin_dist / sum(bin_dist)

        #Calculate Shannon entropy of PDF
        h_p = 0
        for mybin in bin_dist:
            h_p = h_p - mybin * log(mybin)

        #MI = (Kullback-Leibler distance between h_p and uniform
        #distribution) / (Entropy of uniform distribution)
        return (log(len(bins)) - h_p) / log(len(bins))

    def _modulation_index(phase,amp):
        #Mean amplitude within each phase bin -> distribution of amplitude(phase) -> MI
        bin_dist = zeros([len(bins)])
        for iBin,edge in enumerate(bins):
            ind = logical_and(phase>=edge,phase<edge+bin_size)
            bin_dist[iBin] = mean(amp[ind])
        return _mi_from_distribution(bin_dist)

    MI_block = zeros((len(fa),len(fp),2))
    MI_diff = zeros((len(fa),len(fp)))
    if surrogate_analysis == 1:
        MI_diff_surrogate = zeros((len(fa),len(fp),kruns))

    amp_raw_data = {}
    pha_raw_data = {}
    onsets = {}     #stimulus onsets per block (clicks removed)
    bad_samp = {}   #artifact sample indices per block

    #Load subject globals once; every block reads from the same file
    all_subj_data = sio.loadmat(subglo_path,struct_as_record=False, squeeze_me=True)['subj_globals']

    for blk in blocks:
        #Load ECOG Data

        #data_path = '/home/jcase/data/' + subj + blk + '/' + subj + '_' + blk + '_data.mat'

        data_path = '/Users/johncase/Documents/UCSF Data/' + subj + '/' + subj + blk + '/data/' + subj + '_' + blk + '_data.mat'

        ecog_data = sio.loadmat(data_path)['ecogData']
        amp_raw_data[blk] = ecog_data[amp_elec-1,:]
        pha_raw_data[blk] = ecog_data[phase_elec-1,:]
        del ecog_data

        subj_data = getattr(getattr(all_subj_data,subj),blk)
        srate = subj_data.srate   #NOTE: the last block's value is used below, i.e. srate is assumed equal across blocks

        #Keep the onsets of THIS block (a single shared variable would leave
        #only the last block's onsets for all blocks), without the clicks
        onsets[blk] = delete(subj_data.allstimtimes[:,0],where(subj_data.stimID==10)[0],0)

        #Identify bad samples for each block (phase electrode, plus the
        #amplitude electrode when it is a different channel)
        bad_epochs = subj_data.per_chan_bad_epochs
        bad_samp[blk] = _artifact_samples(bad_epochs[phase_elec-1],srate)
        if not phase_elec == amp_elec:
            bad_samp[blk] = concatenate((bad_samp[blk],_artifact_samples(bad_epochs[amp_elec-1],srate)))

    for iFreq_a,freq_a in enumerate(fa):

        #One tuple per artifact-free trial: (block name, sample indices, amplitude envelope).
        #Trials are appended block by block, so each block's trials are contiguous.
        trials = []

        for blk in blocks:
            #Do high-gamma filtering (high-pass, then low-pass, then analytic amplitude)
            amp,filtwt = eegfilt(amp_raw_data[blk],srate,freq_a[0],[])
            amp,filtwt = eegfilt(amp[0],srate,[],freq_a[1])
            amp = abs(hilbert(amp[0][0:len(amp[0])-1]))

            for onset in onsets[blk]:

                trl = arange(rint(onset*srate),rint((onset+ps)*srate)).astype(int)
                #keep the trial only if it does not overlap with an artifact
                #(size test: any() would miss an overlap consisting of sample 0 only)
                if intersect1d(trl,bad_samp[blk]).size == 0:
                    trials.append((blk,trl,amp[trl]))

        for iFreq_p,freq_p in enumerate(fp):

            print('amp = ' + str(int(mean([freq_a[0],freq_a[1]]))) + ', pha  = ' + str(int(mean([freq_p[0],freq_p[1]]))))

            trial_pha = [None]*len(trials)   #phase time course of every trial
            stim_n = []                      #number of trials in each block (for surrogate analysis)

            for iBlock,blk in enumerate(blocks):

                #Do phase-providing phase filtering (low-pass, then high-pass, then instantaneous phase)
                pha = eegfilt(pha_raw_data[blk],srate,[],freq_p[1])[0][0]
                pha = eegfilt(pha,srate,freq_p[0],[])[0][0]
                pha = angle(hilbert(pha[0:len(pha)-1]))

                #add phase information for the trials of this block
                members = [iTrial for iTrial,trial in enumerate(trials) if trial[0] == blk]
                for iTrial in members:
                    trial_pha[iTrial] = pha[trials[iTrial][1]]
                stim_n.append(len(members))

                #pool power and phase over the trials of this block
                pow_istim = _flatten([trials[iTrial][2] for iTrial in members])
                pha_istim = _flatten([trial_pha[iTrial] for iTrial in members])

                MI_block[iFreq_a,iFreq_p,iBlock] = _modulation_index(pha_istim,pow_istim)

            # difference statistic = Block 2 - Block 1
            MI_diff[iFreq_a,iFreq_p] = MI_block[iFreq_a,iFreq_p,1]-MI_block[iFreq_a,iFreq_p,0]

            if surrogate_analysis == 1:

                #Phases are never shuffled, so the within-trial indices that fall
                #into each phase bin are found once instead of once per run:
                #bin_latencies[iBlock][iBin][iTrial] -> indices relative to trial start
                bin_latencies = []
                cnt = 0
                for n_trials in stim_n:
                    block_pha = trial_pha[cnt:cnt+n_trials]
                    cnt = cnt + n_trials
                    per_bin = []
                    for edge in bins:
                        per_bin.append([where(logical_and(trl>=edge,trl<edge+bin_size))[0] for trl in block_pha])
                    bin_latencies.append(per_bin)

                amp_trl_all = [trial[2] for trial in trials]

                for iRun in range(kruns):

                    if iRun%10 == 0:
                        #progress indicator, rewritten in place
                        sys.stdout.write('{}\r'.format('Run ' + str(iRun+10)))
                        sys.stdout.flush()

                    MI_block_surrogate = zeros(2)

                    #reassign whole amplitude trials at random across both blocks
                    random.shuffle(amp_trl_all)

                    cnt = 0
                    for iBlock,n_trials in enumerate(stim_n):

                        amp_trls = amp_trl_all[cnt:cnt+n_trials]
                        cnt = cnt + n_trials

                        bin_dist = zeros([len(bins)])
                        for iBin in range(len(bins)):
                            #power from a random trial, at the latencies where the
                            #original trial's phase was inside this bin
                            picked = [amp_trls[iTrial][ind] for iTrial,ind in enumerate(bin_latencies[iBlock][iBin])]
                            bin_dist[iBin] = mean(_flatten(picked))

                        MI_block_surrogate[iBlock] = _mi_from_distribution(bin_dist)
                    MI_diff_surrogate[iFreq_a,iFreq_p,iRun] = MI_block_surrogate[1]-MI_block_surrogate[0]

    #Write results. File objects are passed (rather than names) so numpy.save
    #does not append '.npy', and are closed as soon as each array is written.
    with open(MI_output_path+'_block_diff','wb') as out_file:
        save(out_file,MI_diff)
    with open(MI_output_path+'_block','wb') as out_file:
        save(out_file,MI_block)
    if surrogate_analysis == 1:
        with open(MI_output_path+'_block_diff_surrogate','wb') as out_file:
            save(out_file,MI_diff_surrogate)
def CFC_rew_pun_permutation(subj,block,phase_elec,amp_elec,surrogate_analysis):
    """Compare phase-amplitude coupling between three stimulus conditions.

    For every phase centre frequency in ``fp`` and bandwidth in
    ``fp_bandwidth`` the modulation index (MI, see
    http://jn.physiology.org/content/104/2/1195) between the phase of
    ``phase_elec`` and the 70-150 Hz amplitude of ``amp_elec`` is computed for
    the trials of stimID 1, 2 and 3, using the samples from ``bl`` to ``ps``
    seconds around each onset that are not inside a bad epoch.

    Args:
        subj: subject identifier; used to build the file paths and to select
            the subject's entry in ``subj_globals.mat``.
        block: block name.
        phase_elec: 1-based index of the phase-providing electrode.
        amp_elec: 1-based index of the amplitude-providing electrode.
        surrogate_analysis: 1 to also compute ``kruns`` surrogate statistics,
            pairing each trial's phase with the amplitude of a randomly
            chosen stimID 1-3 trial at the same post-stimulus latencies.

    Returns:
        None. The results are written with ``numpy.save`` to
        ``MI_output_path`` + ``'_stim'`` (MI per condition), ``'_diff'``
        (sum of absolute pairwise MI differences) and, when requested,
        ``'_diff_surrogate'``. Frequency/bandwidth cells that are skipped
        hold zeros.
    """
    import sys
    import scipy.io as sio
    from numpy import arange,array,append,ceil,concatenate,floor,zeros,pi,angle,logical_and,mean,save,where,in1d,random,rint #for efficiency
    from eegfilt import eegfilt
    from scipy.signal import hilbert
    from math import log

    #data_path = '/home/jcase/data/' + subj + block + '/' + subj + '_' + block + '_data.mat'
    #subglo_path = '/home/jcase/data/subj_globals.mat'
    #MI_output_path = '/home/jcase/data/' + subj + block + '/MI/e' + str(phase_elec) + '_e' + str(amp_elec)

    data_path = '/Users/johncase/Documents/UCSF Data/' + subj + '/' + subj + block + '/data/' + subj + '_' + block + '_data.mat'
    subglo_path = '/Users/johncase/Documents/UCSF Data/subj_globals.mat'
    MI_output_path = '/Users/johncase/Documents/UCSF Data/' + subj + '/' + subj + block + '/analysis/MI/rew_pun/e' + str(phase_elec) + '_e' + str(amp_elec)

    #Load ECOG Data
    ecog_data = sio.loadmat(data_path)['ecogData']
    amp_raw_data = ecog_data[amp_elec-1,:]
    pha_raw_data = ecog_data[phase_elec-1,:]
    del ecog_data

    #Load subject globals
    all_subj_data = sio.loadmat(subglo_path,struct_as_record=False, squeeze_me=True)['subj_globals']
    subj_data = getattr(getattr(all_subj_data,subj),block)
    srate = subj_data.srate
    per_chan_bad_epochs = subj_data.per_chan_bad_epochs
    allstimtimes = subj_data.allstimtimes
    stimID = subj_data.stimID

    #Trial windows
    bl = 0
    ps = 1 #in seconds

    #Surrogate Runs
    kruns = 200

    #Stimulus conditions that are compared
    stim_ids = [1,2,3]

    #Phase-providing frequency (centre frequencies and bandwidths)
    fp = arange(2,3)
    fp_bandwidth = arange(2,3)

    #Amplitude-providing frequency
    fa = array([70,150])

    #Define phase bins. The left edges are built from an integer range so that
    #there are always exactly n_bins of them; arange with a float step and an
    #exclusive stop of pi-bin_size can come out one edge short.
    n_bins = 20
    bin_size = 2*pi/n_bins
    bins = -pi + bin_size*arange(n_bins)

    def _flatten(chunks):
        #Concatenate a list of 1-D arrays; an empty list gives an empty array
        return concatenate(chunks) if len(chunks) > 0 else array([])

    def _artifact_samples(epochs):
        #Expand the [start,end] bad epochs of one channel into integer sample
        #indices (bounds are scaled by srate, so presumably given in seconds).
        #Bounds are rounded outwards: fractional sample values could never
        #match the integer trial samples they are compared against.
        pairs = [epochs] if epochs.size == 2 else epochs
        samples = [arange(int(floor(srate*epoch[0])),int(ceil(srate*epoch[1]))) for epoch in pairs]
        return _flatten(samples).astype(int)

    def _mi_from_distribution(bin_dist):
        #Normalize distribution to yield pseudo "probability density function" (PDF)
        bin_dist = bin_dist / sum(bin_dist)

        #Calculate Shannon entropy of PDF
        h_p = 0
        for mybin in bin_dist:
            h_p = h_p - mybin * log(mybin)

        #MI = (Kullback-Leibler distance between h_p and uniform
        #distribution) / (Entropy of uniform distribution)
        return (log(len(bins)) - h_p) / log(len(bins))

    def _sum_abs_pairwise(mi):
        # difference statistic = abs(A-B) + abs(A-C) + abs(B-C)
        return abs(mi[0]-mi[1]) + abs(mi[0]-mi[2]) + abs(mi[1]-mi[2])

    #zeros (not empty): cells of skipped frequency/bandwidth combinations are
    #saved too and must not contain uninitialised memory
    MI_stim = zeros((len(fp),len(fp_bandwidth),3))
    MI_diff = zeros((len(fp),len(fp_bandwidth)))
    if surrogate_analysis == 1:
        MI_diff_surrogate = zeros((len(fp),len(fp_bandwidth),kruns))

    #Determine samples with artifacts (phase electrode, plus the amplitude
    #electrode when it is a different channel)
    bad_samp = _artifact_samples(per_chan_bad_epochs[phase_elec-1])
    if not phase_elec == amp_elec:
        bad_samp = concatenate((bad_samp,_artifact_samples(per_chan_bad_epochs[amp_elec-1])))

    #Do high-gamma filtering (high-pass, then low-pass, then analytic amplitude)
    pow,filtwt = eegfilt(amp_raw_data,srate,fa[0],[])
    pow,filtwt = eegfilt(pow[0],srate,[],fa[1])
    pow = abs(hilbert(pow[0][0:len(pow[0])-1]))

    #Onset sample of every trial and the trial indices of each condition
    onset_samples = rint(allstimtimes[:,0]*srate).astype(int)
    stim_trials = [where(stimID==iStimID)[0] for iStimID in stim_ids]

    #Calculate MI for each phase-providing central-frequencies / bandwidths

    for iFreq,freq in enumerate(fp):
        for iBand,band in enumerate(fp_bandwidth):

            #2.0 keeps the half bandwidth exact for odd integer bandwidths
            half_band = band/2.0

            if freq-half_band < 0.5:
                MI_diff[iFreq,iBand] = 0
                continue

            print('freq = ' + str(freq) + ', bw = ' + str(band))

            #Do phase-providing phase filtering (low-pass, then high-pass, then instantaneous phase)
            pha = eegfilt(pha_raw_data,srate,[],freq+half_band)[0][0]
            pha = eegfilt(pha,srate,freq-half_band,[])[0][0]
            pha = angle(hilbert(pha[0:len(pha)-1]))

            for iStim,trials in enumerate(stim_trials):

                trl_samps = array([]).astype(int)
                for trl in onset_samples[trials]:
                    trl_samps = append(trl_samps,arange(int(trl+bl*srate),int(trl+ps*srate)))

                #exclude bad samples
                trl_samps = trl_samps[~in1d(trl_samps,bad_samp)]

                #keep phase/pow info only within iStim trl windows
                pha_istim = pha[trl_samps]
                pow_istim = pow[trl_samps]

                #Calculate mean amplitude within each phase bin to yield a
                #distribution of amplitude(phase)
                bin_dist = zeros([len(bins)])
                for iBin,edge in enumerate(bins):
                    ind = logical_and(pha_istim>=edge,pha_istim<edge+bin_size)
                    bin_dist[iBin] = mean(pow_istim[ind])

                MI_stim[iFreq,iBand,iStim] = _mi_from_distribution(bin_dist)

            MI_diff[iFreq,iBand] = _sum_abs_pairwise(MI_stim[iFreq,iBand,:])

            if surrogate_analysis == 1:

                #Pool of trials whose amplitude may be drawn: the trials of the
                #compared conditions themselves (not simply the first N rows of
                #allstimtimes, which may belong to other stimIDs)
                shuffle_ind = where(in1d(stimID,stim_ids))[0]

                #Phases are never shuffled, so the artifact-free post-stimulus
                #latencies that fall into each phase bin are found once
                #instead of once per run:
                #bin_latencies[iStim][iBin][iTrial] -> indices relative to onset
                bin_latencies = []
                for trials in stim_trials:
                    per_bin = [[] for edge in bins]
                    for trl in onset_samples[trials]:
                        #samples from onset to onset + ps seconds
                        phase_trl = arange(trl,trl+ps*srate).astype(int)
                        trl_pha = pha[phase_trl]
                        trl_good = ~in1d(phase_trl,bad_samp)
                        for iBin,edge in enumerate(bins):
                            per_bin[iBin].append(where(logical_and.reduce((trl_pha>=edge,trl_pha<edge+bin_size,trl_good)))[0])
                    bin_latencies.append(per_bin)

                for iRun in range(kruns):

                    if iRun%10 == 0:
                        #progress indicator, rewritten in place
                        sys.stdout.write('{}\r'.format('Run ' + str(iRun+10)))
                        sys.stdout.flush()

                    MI_stim_surrogate = zeros(3)

                    random.shuffle(shuffle_ind)

                    cnt = 0
                    for iStim,trials in enumerate(stim_trials):

                        amp_trl_onsets = onset_samples[shuffle_ind[cnt:cnt+len(trials)]]
                        cnt = cnt + len(trials)

                        bin_dist = zeros([len(bins)])
                        for iBin in range(len(bins)):

                            bin_power = []

                            for iTrial,ind in enumerate(bin_latencies[iStim][iBin]):

                                #amp samples with the same post-stimulus latency as "ind",
                                #taken from the randomly assigned trial
                                amp_samples = amp_trl_onsets[iTrial]+ind
                                amp_samples = amp_samples[~in1d(amp_samples,bad_samp)]

                                bin_power.append(pow[amp_samples])

                            bin_dist[iBin] = mean(_flatten(bin_power))

                        MI_stim_surrogate[iStim] = _mi_from_distribution(bin_dist)
                    MI_diff_surrogate[iFreq,iBand,iRun] = _sum_abs_pairwise(MI_stim_surrogate)

    #Write results. File objects are passed (rather than names) so numpy.save
    #does not append '.npy', and are closed as soon as each array is written.
    with open(MI_output_path+'_diff','wb') as out_file:
        save(out_file,MI_diff)
    with open(MI_output_path+'_stim','wb') as out_file:
        save(out_file,MI_stim)
    if surrogate_analysis == 1:
        with open(MI_output_path+'_diff_surrogate','wb') as out_file:
            save(out_file,MI_diff_surrogate)
def calculate_CFC(data,srate,onsets,category_array,comparison_array,phase_elec,amp_elec,freqs,bandwidths,CFC_method,surrogate_analysis):

    """
    Compute cross-frequency coupling (CFC) between the phase of one electrode
    and the 70-150 Hz analytic amplitude of another, in the first second after
    each stimulus onset, per trial category and comparison group.

    Input: 1) data: matrix, <channels,samples>
           2) srate: integer, sampling rate
           3) onsets: numpy array, <length trials> stimulus onsets in samples
           4) category_array: numpy array, <length trials>, where values represent a group membership for a particular trial. Each group is independent of one another
           5) comparison_array: numpy array, <length trials>, where values correspond to which groups will be subtracted together, *within* each category (as specified above). (e.g., an array of [1 1 1 2 2 2] would yield a (2nd half - 1st half) comparison)
           6) phase_elec: integer, phase-providing electrode (1-based)
           7) amp_elec: integer, amplitude-providing electrode (1-based)
           8) freqs: numpy array, all central frequencies to be calculated
           9) bandwidths: numpy array, all bandwidths to be calculated
           10) CFC_method: string, 'MI','PAC','dPAC' (any other value leaves all CFC values at 0)
           11) surrogate_analysis: 0 or 1, compute surrogate_analysis or not

    Output: 1) CFC_groups, numpy array, <freqs,bandwidths,categories,comparisons>, CFC at all freqs and bandwidths for each category and comparison
            2) CFC_diff, numpy array, <freqs,bandwidths,categories>,   CFC at all freqs and bandwidths after comparisons within each category are subtracted (2 comparisons: second - first; 3 comparisons: sum of absolute pairwise differences; otherwise 0)
            3) CFC_diff_surrogate, numpy array, <freqs,bandwidths,categories,kruns>,   same as CFC_diff but after kruns randomized surrogate permutations (all zeros when surrogate_analysis is not 1)

    Frequency/bandwidth combinations with freq - bandwidth/2 < 0.5 are skipped
    and left at 0 in all outputs.
    """

    import numpy as np
    from eegfilt import eegfilt
    from scipy.signal import hilbert
    import random

    #Initialize
    categories = np.unique(category_array)               #number of comparisons (e.g., number of stimIDs, values, etc.)
    comparisons = np.unique(comparison_array)            #number of groups within each comparison (e.g., 2nd half vs 1st half)

    #Number of surrogate runs
    kruns = 200

    CFC_groups = np.zeros((len(freqs),len(bandwidths),len(categories),len(comparisons))) # <frequencies,bandwidths,comparisons,categories in each comparison>
    CFC_diff = np.zeros((len(freqs),len(bandwidths),len(categories)))                     # <frequencies,bandwidths,comparisons>
    CFC_surrogate_diff = np.zeros((len(freqs),len(bandwidths),len(categories),kruns))      # <frequencies,bandwidths,comparisons,surrogate runs>

    #Amplitude-providing frequency - High Gamma
    freq_amp = np.array([70,150])

    #Trial window
    post_secs = 1  #How many seconds after onset should be analyzed

    #Phase bins for the 'MI' method. The left edges are built from an integer
    #range so that there are always exactly n_bins of them; arange with a
    #float step and an exclusive stop of pi-bin_size can come out one edge short.
    n_bins = 18
    bin_size = 2*np.pi/n_bins
    bins = -np.pi + bin_size*np.arange(n_bins)

    def _flatten(chunks):
        #Concatenate a list of 1-D arrays; an empty list gives an empty array
        return np.concatenate(chunks) if len(chunks) > 0 else np.array([])

    def _coupling(pha_data,pow_data):
        #CFC between equally long phase and amplitude series, by CFC_method

        #http://jn.physiology.org/content/104/2/1195
        if CFC_method == 'MI':

            #Calculate mean amplitude within each phase bin to yield a
            #distribution of amplitude(phase)
            bin_dist = np.zeros([len(bins)])
            for iBin,edge in enumerate(bins):
                ind = np.logical_and(pha_data>=edge,pha_data<edge+bin_size)
                bin_dist[iBin] = np.mean(pow_data[ind])

            #Normalize distribution to yield pseudo probability density function
            bin_dist = bin_dist / sum(bin_dist)

            #Calculate Shannon entropy of PDF
            h_p = 0
            for mybin in bin_dist:
                h_p = h_p - mybin * np.log(mybin)

            #MI = (Kullback-Leibler distance between h_p and uniform
            #distribution) / (Entropy of uniform distribution)
            return (np.log(len(bins)) - h_p) / np.log(len(bins))

        elif CFC_method == 'PAC':
            return np.abs(np.mean(pow_data*(np.exp(1j*pha_data))))

        elif CFC_method == 'dPAC':
            return np.abs(np.mean(pow_data*(np.exp(1j*pha_data) - np.mean(np.exp(1j*pha_data)))))

        #unknown method: value stays 0
        return 0.0

    def _difference(group_values):
        #If two comparisons (e.g., late vs early), calculate difference.
        #Positive difference = 2 > 1, negative difference = 1 > 2
        if len(comparisons)==2:
            return group_values[1] - group_values[0]

        #If three comparisons, calculate the sum of the absolute differences between each pair.
        #High difference signifies selectivity, low difference signifies similarity
        elif len(comparisons)==3:
            return np.abs(group_values[1] - group_values[0]) + np.abs(group_values[2] - group_values[0]) + np.abs(group_values[2] - group_values[1])

        return 0.0

    #Keep data for only the amplitude- and phase-providing electrodes
    amp_elec_data = data[amp_elec-1,:]
    phase_elec_data = data[phase_elec-1,:]

    #Filter and calculate power for amplitude-providing electrode
    amp,filtwt = eegfilt(amp_elec_data,srate,[],freq_amp[1])  #low-pass filter
    amp,filtwt = eegfilt(amp[0],srate,freq_amp[0],[])         #high-pass filter
    amp = abs(hilbert(amp[0][0:len(amp[0])-1]))               #analytic amplitude via hilbert

    #Per trial: category label, comparison label, sample indices and analytic amplitude
    n_trials = len(onsets)
    trial_category = np.array([category_array[iTrial] for iTrial in range(n_trials)])
    trial_comparison = np.array([comparison_array[iTrial] for iTrial in range(n_trials)])
    trial_samples = [np.arange(onset,onset+np.rint(post_secs*srate)).astype(int) for onset in onsets]
    trial_amp = [amp[samples] for samples in trial_samples]

    #Trial indices of every category/comparison combination (do not depend on frequency)
    group_trials = [[np.where(np.logical_and(trial_category==category,trial_comparison==comparison))[0]
                     for comparison in comparisons] for category in categories]

    #Calculate CFC for each phase-providing central-frequencies / bandwidths
    for iFreq,freq in enumerate(freqs):
        for iBand,band in enumerate(bandwidths):

            #2.0 keeps the half bandwidth exact for odd integer bandwidths
            half_band = band/2.0

            #Skip if bandwidth is too big for frequency
            if freq-half_band < 0.5:
                CFC_diff[iFreq,iBand,:] = 0
                continue

            print('freq = ' + str(freq) + ', bw = ' + str(band))

            #Filter and calculate instantaneous phase for phase-providing electrode
            pha = eegfilt(phase_elec_data,srate,[],freq+half_band)[0][0] # low-pass filter from 0 hz to the current freq plus half the bandwidth
            pha = eegfilt(pha,srate,freq-half_band,[])[0][0]             # high-pass filter from current freq minus the other half the bandwidth to infinity
            pha = np.angle(hilbert(pha[0:len(pha)-1]))                   # calculate instantaneous phase via hilbert

            #Phase angles of each trial for the current band
            trial_pha = [pha[samples] for samples in trial_samples]

            #Calculate CFC for each category/comparison combination
            for iCategory,category in enumerate(categories):
                for iComparison,comparison in enumerate(comparisons):

                    members = group_trials[iCategory][iComparison]

                    #pool power and phase over the trials of this group
                    pow_data = _flatten([trial_amp[iTrial] for iTrial in members])
                    pha_data = _flatten([trial_pha[iTrial] for iTrial in members])

                    CFC_groups[iFreq,iBand,iCategory,iComparison] = _coupling(pha_data,pow_data)

                CFC_diff[iFreq,iBand,iCategory] = _difference(CFC_groups[iFreq,iBand,iCategory,:])

            #Surrogate distribution: run once per category, after (not inside)
            #the loop above, so the runs are not repeated for every category
            if surrogate_analysis == 1:

                for iCategory,category in enumerate(categories):

                    #it's not necessary to shuffle phases, so pool them per
                    #comparison group before randomization for efficiency
                    pha_groups = [_flatten([trial_pha[iTrial] for iTrial in members]) for members in group_trials[iCategory]]
                    group_sizes = [len(members) for members in group_trials[iCategory]]

                    amp_trials = [trial_amp[iTrial] for iTrial in np.where(trial_category==category)[0]]

                    #Shuffle the amplitude time course between trials, while maintaining the structure within each time course
                    for iRun in range(kruns):

                        CFC_group_surrogate = np.zeros(len(comparisons))

                        #Shuffle amplitudes
                        random.shuffle(amp_trials)

                        #Deal the shuffled trials out in consecutive,
                        #non-overlapping chunks of the original group sizes
                        cnt = 0
                        for iComparison,n_group in enumerate(group_sizes):

                            pow_data = _flatten(amp_trials[cnt:cnt+n_group])
                            cnt = cnt + n_group

                            CFC_group_surrogate[iComparison] = _coupling(pha_groups[iComparison],pow_data)

                        CFC_surrogate_diff[iFreq,iBand,iCategory,iRun] = _difference(CFC_group_surrogate)

    return CFC_groups, CFC_diff, CFC_surrogate_diff