-
Notifications
You must be signed in to change notification settings - Fork 0
/
AudioTools.py
651 lines (518 loc) · 25.5 KB
/
AudioTools.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
"""Module providing objects to detect audio onsets."""
from __future__ import print_function
from __future__ import division
from builtins import zip
from builtins import range
from builtins import object
from past.utils import old_div
import numpy as np
import scipy as sp
import scipy.signal
import matplotlib.pyplot as plt
import matplotlib.mlab as mlab
import sys
import os.path
def check_audio_alignment(rs=None, ns5_loader=None, timestamps=None,
                          plot=True, analog_channels=None, dstart=-1000,
                          dstop=8000, smoothing_std=100, also_smooth=True,
                          **plot_kwargs):
    """Extracts audio waveforms around each timestamp, to check for sync error.

    rs : RecordingSession
    ns5_loader, timestamps, analog_channels : extracted from rs if None
    plot : if True, dispatches result (and plot_kwargs) to plot_audio_alignment
    dstart, dstop : # of samples relative to timestamp
    also_smooth : if True, also returns a smoothed version of the data.
        Uses a Gaussian of standard deviation equal to smoothing_std samples.

    Returns:
        dict { 'n' : samples_from_timestamp, 'raw' : data,
            'smoothed' : smoothed_data if also_smooth }
        The shape of 'raw' and 'smoothed' is
        (n_channels, n_timestamps, n_samples).
    """
    if timestamps is None:
        timestamps = rs.read_timestamps()
    if ns5_loader is None:
        ns5_loader = rs.get_ns5_loader()
    if analog_channels is None:
        # presumably analog channels are numbered from 128 in the ns5
        # file -- TODO confirm against the loader's convention
        analog_channels = np.asarray(rs.read_analog_channel_ids()) + 128

    # Set up return values
    # np.int was removed from numpy; the builtin int is equivalent here
    res = {}
    res['n'] = np.arange(dstart, dstop, dtype=int)

    # Extract slices of audio channels around each timestamp
    raw_slices = []
    for timestamp in timestamps:
        raw = ns5_loader.get_chunk_by_channel(start=timestamp + dstart,
                                              stop=timestamp + dstop)
        audio_slc = np.array([raw[ch] for ch in analog_channels],
                             dtype=float)
        raw_slices.append(audio_slc)
    # (n_timestamps, n_channels, n_samples) -> (n_channels, n_timestamps, ...)
    res['raw'] = np.asarray(raw_slices).swapaxes(0, 1)

    if also_smooth:
        # square and smooth the signal to make it easier to eyeball
        gstd = smoothing_std  # default 100 samples, still plenty of wiggles
        glen = int(2.5 * gstd)  # save time by truncation
        # Incantation such that b[0] == 1.0
        # scipy.signal.gaussian was removed; the windows submodule is the
        # modern home of this function.
        b = scipy.signal.windows.gaussian(glen * 2, gstd, sym=False)[glen:]
        b = b / np.sum(b ** 2)

        # Do the filtering. filtfilt raises ValueError when the traces are
        # shorter than its padding requirement; degrade gracefully.
        try:
            res['smoothed'] = scipy.signal.filtfilt(b, [1],
                                                    res['raw'] ** 2)
        except Exception:
            print("warning: smoothing error")
            also_smooth = False

    if plot:
        tight_plot_raw(res, analog_channels, **plot_kwargs)
        wide_plot_smoothed(res, analog_channels, **plot_kwargs)

    return res
def tight_plot_raw(res, analog_channels, start=-50, stop=50,
                   savefig=False, **plot_kwargs):
    """Plot raw traces per channel, zoomed tightly around the timestamp.

    savefig : False for interactive display; True to save figures in the
        current directory; or a directory name to save them there.
        Saved figures are closed after writing.
    """
    for idx, chan in enumerate(analog_channels):
        axis = plot_audio_alignment(res['raw'][idx], res['n'],
                                    **plot_kwargs)
        # vertical line marking the timestamp itself (sample 0)
        axis.plot([0, 0], axis.get_ylim(), 'k')
        axis.set_title('channel %d' % chan)
        axis.set_xlim((start, stop))
        plt.grid()
        fig = axis.get_figure()
        if savefig:
            if savefig is True:
                savefig = '.'
            filename = os.path.join(savefig, './tight_plot_raw_%d.png' % chan)
            fig.savefig(filename)
            plt.close(fig)
def wide_plot_smoothed(res, analog_channels, savefig=False, **plot_kwargs):
    """Plot the smoothed traces per channel over the full extracted window.

    savefig : False for interactive display; True to save figures in the
        current directory; or a directory name to save them there.
        Saved figures are closed after writing.
    """
    for idx, chan in enumerate(analog_channels):
        axis = plot_audio_alignment(res['smoothed'][idx], res['n'],
                                    **plot_kwargs)
        axis.set_title('channel %d' % chan)
        plt.grid()
        fig = axis.get_figure()
        if savefig:
            if savefig is True:
                savefig = '.'
            filename = os.path.join(savefig, './wide_plot_smoothed_%d.png' % chan)
            fig.savefig(filename)
            plt.close(fig)
def plot_audio_alignment(data, n=None, ax=None, start=None, stop=None,
                         plot_endpoints=False, downsample_ratio=1, **kwargs):
    """Simple plotting function for overlaid traces.

    data : 2d array of traces to plot, shape n_traces x n_samples
    n : indexes of each sample point;
        if None, will generate based on shape of data
    ax : axis to plot into; a new figure is created if None
    start, stop : scalar or per-trace sequence; samples outside
        [start, stop) are dropped (scalar) or set to NaN (per-trace)
    plot_endpoints : mark the last visible sample of each trace with 'k*'
    downsample_ratio : plot every Nth sample

    Returns the axis plotted into.
    """
    if n is None:
        # np.int was removed from numpy; the builtin int is equivalent
        n = np.arange(data.shape[1], dtype=int)
    if start is not None or stop is not None:
        # copy so the NaN-masking below never mutates the caller's array
        data = data.copy()

    # Deal with downsampling
    if downsample_ratio != 1:
        data = data[:, ::downsample_ratio]
        n = n[::downsample_ratio]

    # Mask (per-trace) or truncate (scalar) samples before start
    if start is not None:
        if hasattr(start, '__len__'):
            if len(start) != len(data):
                raise ValueError("start must be same length as data or scalar")
            for nrow in range(len(data)):
                data[nrow, n < start[nrow]] = np.nan
        else:
            data = data[:, n >= start]
            n = n[n >= start]

    # Mask (per-trace) or truncate (scalar) samples at/after stop
    if stop is not None:
        if hasattr(stop, '__len__'):
            if len(stop) != len(data):
                raise ValueError("stop must be same length as data or scalar")
            for nrow in range(len(data)):
                data[nrow, n >= stop[nrow]] = np.nan
        else:
            data = data[:, n < stop]
            n = n[n < stop]

    if ax is None:
        f = plt.figure()
        ax = f.add_subplot(111)

    # Now do the plotting
    ax.plot(n, data.transpose(), **kwargs)

    # Plot the endpoints
    if plot_endpoints:
        if stop is not None:
            # BUGFIX: original indexed stop[nrow] even when stop was a
            # scalar, which raised TypeError; broadcast scalar per trace.
            if hasattr(stop, '__len__'):
                stop_per_trace = stop
            else:
                stop_per_trace = [stop] * len(data)
            for nrow in range(len(data)):
                if stop_per_trace[nrow] > n.max():
                    idx = len(n) - 1
                else:
                    # Calculate the value of n where the trace stops
                    try:
                        idx = np.where(n == stop_per_trace[nrow])[0].item() - 1
                    except (IndexError, ValueError):
                        # in case the stop falls between values of n
                        idx = np.where(n >= stop_per_trace[nrow])[0][0].item() - 1
                ax.plot(n[idx], data[nrow, idx], 'k*')

    return ax
class OnsetDetector(object):
    """Given a mono or stereo audio stream, detects sound onsets.

    In every case, the audio stream is squared and smoothed to
    calculate `smoothed_power`. When smoothed_power crosses a
    threshhold, an onset is marked.
    Later error checking removes spurious onsets.
    """

    def __init__(self, input_data, F_SAMP=30000., manual_threshhold=None,
                 minimum_threshhold=None, minimum_duration_ms=5,
                 plot_debugging_figures=False, verbose=False):
        """An object that intelligently detects audio onsets in its input.

        Parameters
        ----------
        input_data : 1d or 2d array of values. If 2d, the threshhold
            crossing code is run on each channel separately. The waveform
            need only cross threshhold on one channel, not both.
        manual_threshhold : Specify a threshhold manually. Events will be
            detected when the smoothed audio power crosses this threshhold.
            Thus, make sure that you specify in units**2.
            (Regular units, not dB.)
        minimum_duration_ms : Sounds less than this duration are discarded.

        Call method execute() to actually run the code.
        If manual_threshhold is None, a threshhold will be chosen for you.
        In the case of stereo data, the threshhold is set using the first
        channel (usually left) and then the same threshhold is used for
        the other channel. The threshold will be stored in self.threshhold.

        Implementation note: slowest part of execution appears
        to be the smoothing filter. Look into implementing more efficiently.
        Smoothing an array of length 2**24 (17M) is barely tolerable (10sec).
        """
        self.input_data = input_data
        self.threshhold = manual_threshhold
        self._minimum_threshhold = minimum_threshhold
        self.F_SAMP = F_SAMP
        # int(): np.rint returns a float, which cannot be compared/used
        # as a sample count cleanly; durations are in integer samples.
        self._minimum_duration_samples = \
            int(np.rint(minimum_duration_ms / 1000. * self.F_SAMP))
        self.plot_debugging_figures = plot_debugging_figures
        self.detected_onsets = None  # so far
        self.verbose = verbose

        # We use a causal filter to be extra sure that we don't err on the
        # side of identifying the onset too soon. Note that this guarantees
        # we will identify the onset too late! Well, actually, since the
        # threshhold is set based on the delayed data, it doesn't matter
        # too much. Check debugging figure to reassure yourself that the
        # delay is not significant.
        flen = int(np.rint(.003 * self.F_SAMP))  # 3ms
        self.smoother = CausalRectangularSmoother(smoothing_filter_length=flen)

        # This is the best onset detector so far, out of the implementations
        # I've tried. It will only be used if manual threshhold not
        # specified.
        self.thresh_setter = ThreshholdAutosetterLeastSensitive

        # Check that an impulse won't be so smeared by the smoother that
        # it would pass the minimum duration requirement.
        if self._minimum_duration_samples < \
                self.smoother.smoothing_filter_length:
            print("WARNING: with current smoothing settings, even an "
                  "impulse could pass the min_duration requirement!")

    def execute(self):
        """Executes onset detection.

        If onsets or offsets are within smoothing filter length of
        beginning or end, a warning is generated.

        Raises ValueError if input_data is not 1d or 2d.
        """
        # Deal with the case of stereo input
        if self.input_data.ndim == 2:
            sound_bool = self._find_stereo_threshhold_crossings(self.input_data)
        elif self.input_data.ndim == 1:
            sound_bool = self._find_mono_threshhold_crossings(self.input_data)
        else:
            # BUGFIX: original fell through with sound_bool undefined
            raise ValueError("input_data must be 1d or 2d")

        # We need to throw out sounds that are too short.
        # Finally, will store self.detected_onsets.
        if self.verbose:
            print("Error checking")
            sys.stdout.flush()
        self._error_check_onsets(sound_bool)

        # Check for sounds too close to beginning or end.
        # BUGFIX: original wrote len(self.detected_onsets >= 1), taking
        # the length of a boolean array instead of comparing the length.
        if len(self.detected_onsets) >= 1:
            if self.detected_onsets[0] < self.smoother.smoothing_filter_length:
                print("warning: first onset is within filter width of beginning")
            if self.detected_offsets[-1] > \
                    (len(sound_bool) - self.smoother.smoothing_filter_length):
                print("warning: last offset is within filter width of end")

    def _find_mono_threshhold_crossings(self, data_vector):
        """Finds threshhold crossings in mono data_vector.

        If manual threshhold is None, attempts to autoset.
        Returns a boolean array which is True when the signal
        exceeds threshhold.

        TODO: move this method and _find_stereo to base class
        ThreshholdSetter. Doesn't need to know about anything else!
        Rearrange stuff a little so that smoothing can be done in
        parallel with multiple processes.
        """
        # Remove mean
        data_vector = data_vector - data_vector.mean()

        # Appears to work best when smoothing power first, then putting
        # into dB.
        if self.verbose:
            print("Smoothing the data")
            sys.stdout.flush()
        smoothed_power_dB = 20 * np.log10(
            self.smoother.execute(data_vector ** 2))
        # Comment to save memory, uncomment for debugging
        self.smoothed_power_dB = smoothed_power_dB

        # Autoset threshhold if it wasn't set already
        if self.threshhold is None:
            if self.verbose:
                print("Autosetting threshhold")
                sys.stdout.flush()
            # Instantiate and run threshhold setter
            self.tset = self.thresh_setter(
                input_data=smoothed_power_dB,
                plot_debugging_figures=self.plot_debugging_figures)
            th = self.tset.execute()
            # Apply minimum threshhold test
            if self._minimum_threshhold is None or th > self._minimum_threshhold:
                self.threshhold = th
            else:
                self.threshhold = self._minimum_threshhold

        # find when smoothed waveform exceeds threshhold
        if self.verbose:
            print("Finding threshhold crossings at %0.3f dB" % self.threshhold)
            sys.stdout.flush()
        if self.plot_debugging_figures:
            plt.figure()
            plt.plot(smoothed_power_dB)
            plt.plot([0, len(smoothed_power_dB)],
                     [self.threshhold, self.threshhold], 'r')
            plt.show()
        return smoothed_power_dB > self.threshhold

    def _find_stereo_threshhold_crossings(self, data_vector):
        """Finds threshhold crossings in stereo data_vector.

        We know self.input_data.ndim == 2. Calls the mono detector once on
        left channel and once again on right channel. Note that the left
        channel threshhold will be used for the right. If this is not
        appropriate, consider setting the threshhold manually.

        Returns a boolean array of same shape as one row of self.input_data
        which is True whenever either of the channels is above threshhold.
        """
        if data_vector.shape[0] != 2:
            print("WARNING: audio data should have shape (2,N).")
            data_vector = data_vector.transpose()
        sound_bool_L = self._find_mono_threshhold_crossings(data_vector[0])
        sound_bool_R = self._find_mono_threshhold_crossings(data_vector[1])
        return sound_bool_L | sound_bool_R

    def _error_check_onsets(self, sound_bool):
        """Converts threshhold crossings into validated onset/offset pairs.

        Stores self.detected_onsets and self.detected_offsets, after
        dropping incomplete sounds at the boundaries and sounds shorter
        than the minimum duration.
        """
        # Find when the threshhold crossings first happen. `onsets` and
        # `offsets` are inclusive bounds on audio power above threshhold.
        onsets = np.where(np.diff(np.where(sound_bool, 1, 0)) == 1)[0] + 1
        offsets = np.where(np.diff(np.where(sound_bool, 1, 0)) == -1)[0]

        # check that we don't start or end in the middle of a sound
        try:
            if onsets[0] > offsets[0]:
                # Extra offset at the beginning
                offsets = offsets[1:]
            if onsets[-1] > offsets[-1]:
                # Extra onset at the end
                onsets = onsets[:-1]
        except IndexError:
            # apparently no onsets or no offsets
            print("No sounds found!")
            onsets = np.array([])
            offsets = np.array([])

        if len(onsets) > 0:
            # First do some error checking.
            assert (len(onsets) == len(offsets)) and (np.all(onsets <= offsets))

            # Remove sounds that violate min_duration requirement.
            too_short_sounds = ((offsets - onsets) < self._minimum_duration_samples)
            if np.any(too_short_sounds):
                if self.verbose:
                    print("Removing %d sounds that violate duration requirement" %
                          len(np.where(too_short_sounds)[0]))
                onsets = onsets[np.logical_not(too_short_sounds)]
                offsets = offsets[np.logical_not(too_short_sounds)]

            # Warn when onsets occur very close together. This might occur
            # if the sound power briefly drops below threshhold.
            # BUGFIX: original called the undefined name `find` (an old
            # pylab helper); np.where(...)[0] is the equivalent.
            if np.any(np.diff(onsets) < self._minimum_duration_samples):
                print("WARNING: %d onsets were suspiciously close together." %
                      len(np.where(np.diff(onsets) <
                                   self._minimum_duration_samples)[0]))

        # Print the total number of sounds identified.
        if self.verbose:
            print("Identified %d sounds with average duration %0.3fs" %
                  (len(onsets), (offsets - onsets).mean() / self.F_SAMP))

        # Store detected onsets
        self.detected_onsets = onsets
        self.detected_offsets = offsets

    def _plot_debugging_figure(self, smoothed_audio_power,
                               win_duration_ms=5):
        """Plot close-ups of every detected onset/offset, overlaid.

        Useful to verify by eye that the sounds were caught correctly.
        """
        # This parameter determines how large the plotting window is.
        # int(): the value is used as an array index / slice bound below.
        WINDOW_HALF_DURATION = int(np.rint(
            win_duration_ms / 1000.0 * self.F_SAMP))  # samples

        # Initialize the figure and subplots.
        f = plt.figure()
        ax = [f.add_subplot(2, 2, n + 1) for n in range(4)]
        ax[0].set_title('Onset of sounds')
        ax[1].set_title('Offset of sounds')
        ax[2].set_title('Onset of smoothed')
        ax[3].set_title('Offset of smoothed')

        # Plot threshholds on the smoothed plots
        ax[2].plot([-WINDOW_HALF_DURATION, WINDOW_HALF_DURATION],
                   self.threshhold * np.ones((2,)), 'k:')
        ax[3].plot([-WINDOW_HALF_DURATION, WINDOW_HALF_DURATION],
                   self.threshhold * np.ones((2,)), 'k:')

        # Now plot close-up of onset and offset for each sound
        for onset, offset in zip(self.detected_onsets,
                                 self.detected_offsets):
            # Deal with case where close to edge of window. Set new
            # boundaries that are within the data range. These boundaries
            # are used below in the plotting commands.
            start_win = max(onset - WINDOW_HALF_DURATION, 0)
            stop_win = min(offset + WINDOW_HALF_DURATION,
                           len(self.input_data))

            # Plot a close-up of the onset
            ax[0].plot(
                np.arange(start_win - onset, WINDOW_HALF_DURATION),
                self.input_data[start_win:onset + WINDOW_HALF_DURATION])
            # Plot a close-up of the offset
            ax[1].plot(
                np.arange(-WINDOW_HALF_DURATION, stop_win - offset),
                self.input_data[offset - WINDOW_HALF_DURATION:stop_win])

            # Now do the same but with the smoothed data
            ax[2].plot(
                np.arange(start_win - onset, WINDOW_HALF_DURATION),
                smoothed_audio_power[start_win:onset + WINDOW_HALF_DURATION])
            ax[3].plot(
                np.arange(-WINDOW_HALF_DURATION, stop_win - offset),
                smoothed_audio_power[offset - WINDOW_HALF_DURATION:stop_win])
        plt.show()

    def commit_audio_onsets(self):
        """Writes the newly detected audio onsets to disk.
        """
        # Note: even on a 32-bit system, this format allows >19hrs of indices.
        np.savetxt('audio_onsets', self.detected_onsets, '%i')
        #self.raw_data_loader.audio_onsets = self.audio_onsets
class CausalRectangularSmoother(object):
    """Causal moving-average (boxcar) smoother for 1d signals."""

    def __init__(self, smoothing_filter_length=100):
        # number of taps in the rectangular filter
        self.smoothing_filter_length = smoothing_filter_length
        self.smoothing_filter = None
        self.build_filter()

    def build_filter(self):
        """Rectangular filter of unity gain"""
        n_taps = self.smoothing_filter_length
        # each tap weighs 1/N so the filter's DC gain is exactly 1
        self.smoothing_filter_b = np.full((n_taps,), 1. / n_taps)  # numerator
        self.smoothing_filter_a = np.array([1])  # denominator
        self.filtering_function = sp.signal.lfilter  # causal, 1d data only

    def execute(self, input_data):
        """Apply the causal smoothing filter to input_data and return it."""
        return self.filtering_function(b=self.smoothing_filter_b,
                                       a=self.smoothing_filter_a,
                                       x=input_data)
class ThreshholdAutosetter(object):
    """Given an unordered stream of data, chooses an event threshhold.

    Different algorithms are possible. Each operates on the distribution
    of data points to choose an "intelligent" threshhold, ie, one in which
    only very large events of non-negligible duration cross.
    """

    def __init__(self, input_data, min_p_events=.0001, max_p_events=.99,
                 max_data_points=50e6, plot_debugging_figures=False):
        """Initialize a threshhold setter.

        Parameters
        ----------
        input_data : 1d array of data. The order is irrelevant.
            Generally you will want to provide this in dB, eg
            20*np.log10(data), so that it "looks linear".
        min_p_events : minimum fraction of the data points that will be
            above threshhold.
        max_p_events : maximum fraction of the data points that will be
            above threshhold.
        max_data_points : To increase performance, we do not analyze
            all of the data points. If max_data_points < input_data.size,
            input_data will be strided to keep roughly max_data_points.
        """
        self.min_p_events = min_p_events
        self.max_p_events = max_p_events
        self.plot_debugging_figures = plot_debugging_figures

        if input_data.size > max_data_points:
            # BUGFIX: np.ceil returns a float, which is not a valid slice
            # step in Python 3; convert to int explicitly.
            stride = int(np.ceil(input_data.size / max_data_points))
            self.input_data = input_data[::stride]
        else:
            self.input_data = input_data
class ThreshholdAutosetterLeastSensitive(ThreshholdAutosetter):
    def execute(self):
        """Automatically calculates a reasonable threshhold to detect onsets.

        Theory
        ------
        First calculates the distribution of power in the audio signal
        across time. Presumably, for the great majority of the time,
        this power will be low. Occasionally, there will be infrequent
        bursts of audio power, which are the stimuli to be detected.

        We want to set the threshhold to split these two regimes. This
        algorithm first limits the search to threshholds that satisfy the
        minimum and maximum event probabilities set in `min_p_events`
        and `max_p_events`. Within that regime, it finds the largest gap
        between observed data samples, and sets the threshhold there.
        """
        sorted_vals = np.sort(self.input_data)
        n_pts = len(sorted_vals)

        # Restrict the search to ranks satisfying the event-probability
        # constraints.
        ranks = np.arange(n_pts)
        in_regime = ((ranks > self.min_p_events * n_pts) &
                     (ranks < self.max_p_events * n_pts))
        candidates = sorted_vals[in_regime]

        # Locate the widest gap between consecutive samples in the regime
        # and place the threshhold in the middle of that gap.
        gap_idx = np.argmax(np.diff(candidates))
        return 0.5 * (candidates[gap_idx] + candidates[gap_idx + 1])
class ThreshholdAutosetterMinimalHistogram(ThreshholdAutosetter):
    def execute(self, nbins=500):
        """Automatically calculates a reasonable threshhold to detect onsets.

        Theory
        ------
        First calculates the distribution of power in the audio signal
        across time. Presumably, for the great majority of the time,
        this power will be low. Occasionally, there will be infrequent
        bursts of audio power, which are the stimuli to be detected.

        We want to set the threshhold to split these two regimes. This
        algorithm takes a weighted average of all possible power
        threshholds, weighting powers that rarely occur in the data more
        highly.

        Only threshholds that satisfy the `min_p_events` and `max_p_events`
        criteria are considered. Thus, you are guaranteed that the fraction
        of events exceeding the threshhold will be between these two
        values.

        Parameters
        ----------
        nbins : The distribution is first binned into this many bins to
            make analysis easier. Default: 500
        """
        # Bin the input data. This data should `look linear`; the user
        # should convert power to dB before passing to this function.
        counts, edges = np.histogram(self.input_data, bins=nbins)
        # Convert bin edges to bin centers
        centers = edges[:-1] + 0.5 * np.diff(edges)

        # Debugging figure
        if self.plot_debugging_figures:
            plt.figure()
            plt.plot(centers, counts)
            plt.title('Histogram of input data')
            plt.xlabel('input values')
            plt.ylabel('frequency')
            plt.show()

        # Keep only the part of the histogram satisfying the minimum and
        # maximum event-probability constraints.
        cum_frac = np.float64(np.cumsum(counts)) / len(self.input_data)
        in_regime = ((cum_frac > self.min_p_events) &
                     (cum_frac < self.max_p_events))
        centers = centers[in_regime]
        weights = np.float64(counts[in_regime])

        if self.plot_debugging_figures:
            plt.plot(centers[0], 0, 'r*')
            plt.plot(centers[-1], 0, 'r*')
            plt.show()

        # Lower points in the histogram make better threshholds, so invert
        weights = np.max(weights) - weights
        # A weighted average gives the index of the best threshhold; the
        # sparsest bins are weighted the most.
        weights = weights / np.sum(weights)
        weighted_best_bin = np.sum(weights * np.arange(len(weights)))
        best_thresh = centers[int(round(weighted_best_bin))]

        if self.plot_debugging_figures:
            plt.plot(best_thresh, 0, 'k*')
            plt.show()
        return best_thresh