forked from wmvanvliet/beamformer_simulation
-
Notifications
You must be signed in to change notification settings - Fork 0
/
dics_megset.py
117 lines (94 loc) · 4.38 KB
/
dics_megset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
import mne
import numpy as np
import pandas as pd
from mne.beamformer import make_dics, apply_dics_csd
from config import dics_settings, fname, args
from megset.config import fname as megset_fname
from megset.config import freq_range
# Subject to analyse, taken from the command-line arguments parsed in config.
subject = args.subject
print(f'Running analysis for subject {subject}')
mne.set_log_level(False)  # Shhh: silence MNE's informational log output
###############################################################################
# Load the data
###############################################################################
# Per-subject inputs: the long epochs, the forward solution, and the dipole
# fit (``megset_fname.ecd``) that the beamformer peaks are later scored against.
epochs_path = megset_fname.epochs_long(subject=subject)
fwd_path = megset_fname.fwd(subject=subject)
ecd_path = megset_fname.ecd(subject=subject)

epochs = mne.read_epochs(epochs_path)
fwd = mne.read_forward_solution(fwd_path)
dip = mne.read_dipole(ecd_path)
###############################################################################
# Sensor-level analysis for beamformer
###############################################################################
# Per-sensor-type copies of the epochs. The DICS loop takes its channel
# selection (``.info``) from one of these.
epochs_grad, epochs_mag, epochs_joint = (
    epochs.copy().pick_types(meg=meg) for meg in ('grad', 'mag', True)
)

# Cross-spectral density matrices over the subject's frequency band, each
# averaged across frequencies with ``.mean()``:
#   csd          -0.8 .. 1.0 s  (whole epoch; used to build the filters)
#   csd_baseline -0.8 .. 0.0 s  (pre-stimulus reference window)
#   csd_ers       0.2 .. 1.0 s  (post-stimulus window)
# NOTE(review): ERS activity was described as starting 0.5 s after stimulus
# onset, but this window starts at 0.2 s. Confirm the intended tmin.
freqs = np.arange(*freq_range[subject])
csd, csd_baseline, csd_ers = (
    mne.time_frequency.csd_morlet(epochs, freqs, tmin=tmin, tmax=tmax,
                                  decim=5).mean()
    for tmin, tmax in ((-0.8, 1.0), (-0.8, 0), (0.2, 1.0))
)
###############################################################################
# Compute DICS solutions and score them against the dipole fit
###############################################################################
# For every configuration in ``dics_settings``, build a DICS filter, map the
# log ERS/baseline power ratio to source space, and score the peak of that map
# against the dipole fit: location error, focality, and (for 'max-power')
# orientation error. A failing configuration is recorded as NaN, so the
# results table keeps one row per setting.
dists = []
focs = []
ori_errors = []

# Channel selection for each supported ``sensor_type``.
sensor_infos = {
    'grad': epochs_grad.info,
    'mag': epochs_mag.info,
    'joint': epochs_joint.info,
}

for setting in dics_settings:
    (reg, sensor_type, pick_ori, inversion, weight_norm, normalize_fwd,
     real_filter, use_noise_cov, reduce_rank) = setting
    try:
        if sensor_type not in sensor_infos:
            raise ValueError(f'Invalid sensor type: {sensor_type}')
        info = sensor_infos[sensor_type]

        # Restrict info, forward model and CSD to a common set of channels.
        info_eq, fwd_eq, csd_eq = mne.channels.equalize_channels(
            [info, fwd, csd])
        filters = make_dics(info_eq, fwd_eq, csd_eq, reg=reg,
                            pick_ori=pick_ori, inversion=inversion,
                            weight_norm=weight_norm,
                            noise_csd=csd_baseline if use_noise_cov else None,
                            normalize_fwd=normalize_fwd,
                            real_filter=real_filter, reduce_rank=reduce_rank)

        # Source power in both windows, expressed as log(ERS / baseline).
        stc_baseline, _ = apply_dics_csd(csd_baseline, filters)
        stc_power, _ = apply_dics_csd(csd_ers, filters)
        stc_power /= stc_baseline
        stc_power.data = np.log(stc_power.data)
        peak_vertex, _ = stc_power.get_peak(vert_as_index=True)

        # Distance between the dipole fit and the estimated peak location.
        # Only the first dipole is used, matching the orientation check below.
        pos = fwd['source_rr'][peak_vertex]
        dist = np.linalg.norm(dip.pos[0] - pos)

        # Ratio between the peak value and the sum over all sources.
        # NOTE(review): this uses log-ratio values, which can be negative, so
        # the denominator may be near zero or negative. Confirm that this is
        # the intended focality measure.
        focality_score = stc_power.data[peak_vertex, 0] / stc_power.data.sum()

        if pick_ori == 'max-power':
            # Angle between the estimated and the dipole orientation. A
            # dipole's sign is arbitrary, so the angle is folded into
            # [0, 90] degrees via abs(). The clip stops float rounding from
            # pushing the cosine above 1, where arccos would return NaN.
            estimated_ori = filters['max_power_oris'][0][peak_vertex]
            cos_angle = np.clip(np.abs(estimated_ori @ dip.ori[0]), 0.0, 1.0)
            ori_error = np.rad2deg(np.arccos(cos_angle))
        else:
            ori_error = np.nan
    except Exception as e:
        # Best-effort sweep: report the failure and record NaNs so the
        # remaining settings still run.
        print(f'{type(e).__name__}: {e}')
        dist = np.nan
        focality_score = np.nan
        ori_error = np.nan
    print(setting, dist, focality_score, ori_error)
    dists.append(dist)
    focs.append(focality_score)
    ori_errors.append(ori_error)
###############################################################################
# Save everything to a pandas dataframe
###############################################################################
# One row per DICS setting: the configuration fields come first, followed by
# the per-setting scores (dists, focs, ori_errors).
setting_columns = ['reg', 'sensor_type', 'pick_ori', 'inversion',
                   'weight_norm', 'normalize_fwd', 'real_filter',
                   'use_noise_cov', 'reduce_rank']
df = pd.DataFrame(dics_settings, columns=setting_columns).assign(
    dist=dists,
    focality=focs,
    ori_error=ori_errors,
)
df.to_csv(fname.dics_megset_results(subject=subject))
print('OK!')