# blech_post_process.py
#
# Interactive post-processing of clustering results: for each electrode the
# user picks clusters from a chosen solution (optionally merging them, or
# re-splitting one with a GMM) and saves them as sorted units in the
# experiment's HDF5 file.
import ast
import os
import subprocess

import easygui
import numpy as np
import pylab as plt
import tables
from sklearn.mixture import GaussianMixture

import blech_waveforms_datashader
# Get directory where the hdf5 file sits, and change to that directory.
# easygui.diropenbox returns None when the dialog is cancelled; exit cleanly
# in that case instead of failing with a TypeError inside os.chdir(None).
dir_name = easygui.diropenbox()
if dir_name is None:
    raise SystemExit('No directory selected - exiting')
os.chdir(dir_name)
# Clean up the memory monitor files: collect the per-electrode files in
# ./memory_monitor_clustering into a single memory_usage.txt and delete the
# originals. An existing memory_usage.txt means this was already done, so the
# step is skipped; it is also skipped if the monitor directory does not exist.
memory_dir = './memory_monitor_clustering'
summary_path = os.path.join(memory_dir, 'memory_usage.txt')
if os.path.isdir(memory_dir) and not os.path.exists(summary_path):
    # List before creating the summary file so it is not processed itself
    monitor_files = os.listdir(memory_dir)
    with open(summary_path, 'w') as summary:
        for file_name in monitor_files:
            monitor_path = os.path.join(memory_dir, file_name)
            try:
                mem_usage = np.loadtxt(monitor_path)
            except (OSError, ValueError):
                # Unreadable / non-numeric entry: leave it alone and move on
                continue
            # file_name[:-4] drops the 4-character extension (e.g. '.txt')
            print('electrode' + file_name[:-4], '\t', str(mem_usage) + 'MB', file=summary)
            try:
                os.remove(monitor_path)
            except OSError as err:
                # Clean-up is best effort; report but keep going
                print('Could not remove %s: %s' % (monitor_path, err))
# Look for the hdf5 file in the directory. As before, the last matching entry
# from os.listdir is used; listing order is arbitrary, so warn on ambiguity.
h5_files = [name for name in os.listdir('./') if name.endswith('.h5')]
if not h5_files:
    raise SystemExit('No .h5 file found in %s' % os.getcwd())
if len(h5_files) > 1:
    print('Warning: found several .h5 files %s - using %s' % (h5_files, h5_files[-1]))
hdf5_name = h5_files[-1]
# Open the hdf5 file
hf5 = tables.open_file(hdf5_name, 'r+')
# Delete the raw node, if it exists in the hdf5 file, to cut down on file size
try:
    hf5.remove_node('/raw', recursive = True)
except tables.NoSuchNodeError:
    print("Raw recordings have already been removed, so moving on ..")
else:
    # Close the file and ptrepack it into a new, compressed copy
    hf5.close()
    print("Raw recordings removed")
    repacked_name = hdf5_name[:-3] + "_repacked.h5"
    try:
        # Argument list (no shell), so file names with spaces are safe
        repack = subprocess.run(['ptrepack', '--chunkshape=auto', '--propindexes',
                                 '--complevel=9', '--complib=blosc',
                                 hdf5_name, repacked_name])
        repacked_ok = repack.returncode == 0
    except OSError as err:
        # e.g. ptrepack is not installed / not on PATH
        print('Could not run ptrepack: %s' % err)
        repacked_ok = False
    if repacked_ok:
        # Only delete the old (big) hdf5 file once the repacked copy was written
        os.remove(hdf5_name)
        hf5 = tables.open_file(repacked_name, 'r+')
        print("File repacked")
    else:
        # Keep working on the original file rather than losing the data
        print('ptrepack failed - keeping %s and continuing with it' % hdf5_name)
        hf5 = tables.open_file(hdf5_name, 'r+')
# Make the sorted_units group in the hdf5 file if it doesn't already exist.
# Checked explicitly so that unrelated errors are no longer silently swallowed.
if '/sorted_units' not in hf5:
    hf5.create_group('/', 'sorted_units')
# Define a unit_descriptor class: the row layout of the pytables table that
# records properties of each sorted unit (more columns can be added here).
class unit_descriptor(tables.IsDescription):
    """Row layout of the /unit_descriptor table: one row per saved sorted unit.

    All columns are 32-bit integers; this script writes the three flag
    columns as 0/1.
    """
    # Electrode number the unit was picked from
    electrode_number = tables.Int32Col()
    # 1 if the user marked the unit as an almost-sure single unit, else 0
    single_unit = tables.Int32Col()
    # 1 if the single unit was classified as regular spiking (pyramidal cells)
    regular_spiking = tables.Int32Col()
    # 1 if the single unit was classified as fast spiking (PV+ interneurons)
    fast_spiking = tables.Int32Col()
# Make a table at the root of the hdf5 file (/unit_descriptor) describing the
# sorted units. If it already exists (e.g. from an earlier session), reopen it
# so new rows are appended to the existing ones.
if '/unit_descriptor' in hf5:
    table = hf5.root.unit_descriptor
else:
    table = hf5.create_table('/', 'unit_descriptor', description = unit_descriptor)
# Question texts reused by several of the unit-saving paths below
SINGLE_UNIT_MSG = 'I am almost-SURE that this is a beautiful single unit (True = Yes, False = No)'
UNIT_TYPE_MSG = 'What type of unit is this (Regular spiking = Pyramidal cells, Fast spiking = PV+ interneurons)?'


def _ask_int(msg, field):
    """Ask for a single integer in an easygui.multenterbox.

    Returns the integer, or None if the dialog was cancelled or the entry is
    not an integer.
    """
    answer = easygui.multenterbox(msg = msg, fields = [field])
    try:
        return int(answer[0])
    except (TypeError, ValueError, IndexError):
        # TypeError: dialog cancelled (answer is None); ValueError: not an int
        return None


def _ask_true_false(msg):
    """Ask a True/False question in an easygui.multchoicebox.

    Returns the bool for the first selected choice, or None if the dialog was
    cancelled or nothing was selected.
    """
    choice = easygui.multchoicebox(msg = msg, choices = ('True', 'False'))
    if not choice:
        return None
    return ast.literal_eval(choice[0])


def _isi_violations(times):
    """Return (% of ISIs < 1 ms, % of ISIs < 2 ms) for a set of spike times.

    Times are divided by 30 to get milliseconds, i.e. they are assumed to be
    sample indices at 30 samples per ms (the rate used in the plot labels).
    Percentages are relative to the number of spikes; an empty input gives
    (0.0, 0.0) instead of dividing by zero.
    """
    n_spikes = len(times)
    if n_spikes == 0:
        return 0.0, 0.0
    isis = np.ediff1d(np.sort(times)) / 30.0
    return (100.0 * float(np.sum(isis < 1.0)) / n_spikes,
            100.0 * float(np.sum(isis < 2.0)) / n_spikes)


def _plot_waveforms(waveforms, x, title):
    """Draw waveforms with blech_waveforms_datashader and label the axes; the caller calls plt.show()."""
    fig, ax = blech_waveforms_datashader.waveforms_datashader(waveforms, x)
    ax.set_xlabel('Sample (30 samples per ms)')
    ax.set_ylabel('Voltage (microvolts)')
    ax.set_title(title)
    return fig, ax


def _next_unit_number():
    """Return one past the highest unitNNN number under /sorted_units, or 0 if there are no units yet."""
    numbers = [int(node._v_name[len('unit'):]) for node in hf5.list_nodes('/sorted_units')]
    return max(numbers) + 1 if numbers else 0


def _classify_unit(single_unit_msg):
    """Ask whether a unit is a single unit and, if so, whether it is regular or fast spiking.

    Returns the unit_descriptor flag columns as a dict, or None if the user
    cancelled either question.
    """
    is_single = _ask_true_false(single_unit_msg)
    if is_single is None:
        return None
    flags = {'single_unit': int(is_single), 'regular_spiking': 0, 'fast_spiking': 0}
    if is_single:
        unit_type = easygui.multchoicebox(msg = UNIT_TYPE_MSG, choices = ('regular_spiking', 'fast_spiking'))
        if not unit_type:
            return None
        flags[unit_type[0]] = 1
    return flags


def _save_unit(unit_name, electrode_num, unit_waveforms, unit_times, single_unit_msg = SINGLE_UNIT_MSG):
    """Store one unit under /sorted_units/<unit_name> and append its unit_descriptor row.

    The user classifies the unit before anything is written, so cancelling
    leaves the file untouched (no group without a matching table row).
    Returns True if the unit was written, False if the user cancelled.
    """
    flags = _classify_unit(single_unit_msg)
    if flags is None:
        print('Unit classification cancelled - %s was not saved' % unit_name)
        return False
    hf5.create_group('/sorted_units', unit_name)
    hf5.create_array('/sorted_units/%s' % unit_name, 'waveforms', unit_waveforms)
    hf5.create_array('/sorted_units/%s' % unit_name, 'times', unit_times)
    unit_description = table.row
    unit_description['electrode_number'] = electrode_num
    for column, value in flags.items():
        unit_description[column] = value
    unit_description.append()
    table.flush()
    hf5.flush()
    return True


try:
    # Run an infinite loop as long as the user wants to pick clusters from the electrodes
    while True:
        # Cancel (or a non-integer entry) at the electrode prompt ends the session
        electrode_num = _ask_int('Which electrode do you want to choose? Hit cancel to exit', 'Electrode #')
        if electrode_num is None:
            break
        # Get the number of clusters in the chosen solution; cancel starts over
        num_clusters = _ask_int('Which solution do you want to choose for electrode %i?' % electrode_num,
                                'Number of clusters in the solution')
        if num_clusters is None:
            continue

        # Load data from the chosen electrode and solution
        try:
            spike_waveforms = np.load('./spike_waveforms/electrode%i/spike_waveforms.npy' % electrode_num)
            spike_times = np.load('./spike_times/electrode%i/spike_times.npy' % electrode_num)
            pca_slices = np.load('./spike_waveforms/electrode%i/pca_waveforms.npy' % electrode_num)
            energy = np.load('./spike_waveforms/electrode%i/energy.npy' % electrode_num)
            amplitudes = np.load('./spike_waveforms/electrode%i/spike_amplitudes.npy' % electrode_num)
            predictions = np.load('./clustering_results/electrode%i/clusters%i/predictions.npy' % (electrode_num, num_clusters))
        except OSError as err:
            # e.g. a mistyped electrode / solution number - let the user try again
            print('Could not load data for electrode %i, solution %i: %s' % (electrode_num, num_clusters, err))
            continue
        # x axis for the waveform plots: samples per waveform / 10 points, as in
        # the original plotting code (presumably waveforms are stored upsampled 10x - confirm)
        x = np.arange(spike_waveforms.shape[1] / 10) + 1

        # Get cluster choices from the chosen solution
        clusters = easygui.multchoicebox(msg = 'Which clusters do you want to choose?',
                                         choices = tuple(str(i) for i in range(int(np.max(predictions) + 1))))
        if not clusters:
            continue
        cluster_ids = [int(cluster) for cluster in clusters]

        # With several clusters ask about merging; with one, ask about splitting it
        merge = False
        re_cluster = False
        if len(cluster_ids) > 1:
            merge = _ask_true_false('I want to merge these clusters into one unit (True = Yes, False = No)')
        else:
            re_cluster = _ask_true_false('I want to split this cluster (True = Yes, False = No)')
        if merge is None or re_cluster is None:
            continue

        # Units are numbered consecutively: unit000, unit001, ...
        unit_number = _next_unit_number()

        if re_cluster:
            # Get clustering parameters from user
            clustering_params = easygui.multenterbox(msg = 'Fill in the parameters for re-clustering (using a GMM)', fields = ['Number of clusters', 'Maximum number of iterations (1000 is more than enough)', 'Convergence criterion (usually 0.0001)', 'Number of random restarts for GMM (10 is more than enough)'])
            try:
                n_clusters = int(clustering_params[0])
                n_iter = int(clustering_params[1])
                thresh = float(clustering_params[2])
                n_restarts = int(clustering_params[3])
            except (TypeError, ValueError, IndexError):
                continue
            this_cluster = np.where(predictions == cluster_ids[0])[0]
            # GMM input, 5 columns per spike: energy and |amplitude| (each scaled
            # by its max within the cluster), then the first 3 PCs
            n_pc = 3
            try:
                data = np.zeros((len(this_cluster), n_pc + 2))
                data[:, 0] = energy[this_cluster] / np.max(energy[this_cluster])
                data[:, 1] = np.abs(amplitudes[this_cluster]) / np.max(np.abs(amplitudes[this_cluster]))
                data[:, 2:] = pca_slices[this_cluster, :n_pc]
                g = GaussianMixture(n_components = n_clusters, covariance_type = 'full', tol = thresh, max_iter = n_iter, n_init = n_restarts)
                g.fit(data)
            except ValueError as err:
                # e.g. empty cluster, or fewer spikes than requested components
                print('Re-clustering failed: %s' % err)
                continue
            if not g.converged_:
                print("Solution did not converge - try again with higher number of iterations or lower convergence criterion")
                continue
            # Show each split with its ISI violation rates
            split_predictions = g.predict(data)
            cluster_waveforms = spike_waveforms[this_cluster, :]
            cluster_times = spike_times[this_cluster]
            for split in range(n_clusters):
                split_points = np.where(split_predictions == split)[0]
                if len(split_points) == 0:
                    print('Split cluster %i has no waveforms - not plotted' % split)
                    continue
                violations1, violations2 = _isi_violations(cluster_times[split_points])
                _plot_waveforms(cluster_waveforms[split_points, :], x,
                                "Split Cluster{:d}, 2ms ISI violations={:.1f} percent".format(split, violations2) + "\n" + "1ms ISI violations={:.1f}%, Number of waveforms={:d}".format(violations1, split_points.shape[0]))
            plt.show()
            # Ask the user for the split cluster they want to keep
            chosen_split = easygui.multchoicebox(msg = 'Which split cluster do you want to choose? Hit cancel to exit',
                                                 choices = tuple(str(i) for i in range(n_clusters)))
            try:
                chosen_split = int(chosen_split[0])
            except (TypeError, ValueError, IndexError):
                continue
            keep = this_cluster[split_predictions == chosen_split]
            _save_unit('unit%03d' % unit_number, electrode_num, spike_waveforms[keep, :], spike_times[keep])

        elif len(cluster_ids) == 1:
            # A single, unsplit cluster becomes one unit
            members = np.where(predictions == cluster_ids[0])[0]
            _save_unit('unit%03d' % unit_number, electrode_num, spike_waveforms[members, :], spike_times[members])

        elif merge:
            # Concatenate the chosen clusters (in the order chosen) into one unit
            unit_waveforms = np.concatenate([spike_waveforms[np.where(predictions == c)[0], :] for c in cluster_ids])
            unit_times = np.concatenate([spike_times[np.where(predictions == c)[0]] for c in cluster_ids])
            # Show the merged cluster to the user, and ask if they still want to merge
            _plot_waveforms(unit_waveforms, x, 'Merged cluster, No. of waveforms={:d}'.format(unit_waveforms.shape[0]))
            plt.show()
            violations1, violations2 = _isi_violations(unit_times)
            proceed = _ask_true_false('My merged cluster has %.1f percent (<2ms) and %.1f percent (<1ms) ISI violations out of %i total waveforms. I want to still merge these clusters into one unit (True = Yes, False = No)' % (violations2, violations1, len(unit_times)))
            if proceed:
                _save_unit('unit%03d' % unit_number, electrode_num, unit_waveforms, unit_times)

        else:
            # Otherwise include each cluster as a separate, consecutively numbered unit
            for cluster in cluster_ids:
                members = np.where(predictions == cluster)[0]
                msg = 'I am almost-SURE that electrode: %i cluster: %i is a beautiful single unit (True = Yes, False = No)' % (electrode_num, cluster)
                if _save_unit('unit%03d' % unit_number, electrode_num, spike_waveforms[members, :], spike_times[members], msg):
                    unit_number += 1
finally:
    # Close the hdf5 file even if an error escapes the loop
    hf5.close()