def clear_refrac(self, threshed, ms=2):
    """Remove spikes from the refractory period of all channels.

    Parameters
    ----------
    threshed : array_like
        Array of ones and zeros.
    ms : int or None, optional, default 2
        The length of the refractory period in milliseconds. ``0`` or
        ``None`` disables clearing, in which case `threshed` itself is
        returned unchanged.

    Raises
    ------
    AssertionError
        * If `ms` is neither an instance of the ADT ``numbers.Integral``
          nor ``None``.
        * If `ms` is less than 0.

    Returns
    -------
    r : SpikeDataFrame
        The thresholded and refractory-period-cleared array of booleans
        indicating the sample point at which a spike was above threshold
        (or `threshed` itself when `ms` is 0 or ``None``).
    """
    # Raise explicitly rather than via ``assert`` so validation still runs
    # under ``python -O``; AssertionError is kept because it is the
    # documented exception type. ``type(None)`` is used instead of
    # ``types.NoneType``, which is missing on Python 3.0-3.9.
    if not isinstance(ms, (numbers.Integral, type(None))):
        raise AssertionError('"ms" must be an integer or None')

    # Check for None before comparing: ``None >= 0`` raises TypeError on
    # Python 3, so the original ``ms >= 0 or ms is None`` rejected None.
    if ms is not None and ms < 0:
        raise AssertionError(
            'refractory period must be a nonnegative integer or None')

    if not ms:
        # ms is None or 0: there is no refractory period to clear.
        return threshed

    # copy so we don't write over the values
    clr = threshed.values.copy()

    # get the number of samples in ms milliseconds
    ms_fs = samples_per_ms(self.fs, ms)

    # TODO: make sure samples by channels is shape of clr
    # WARNING: you must pass a np.uint8 type array (view or otherwise).
    # NOTE: inside a method body this bare name resolves to the
    # module-level ``clear_refrac`` function, not to this method. It writes
    # its result in place through the uint8 view; the cleared data is read
    # back from ``clr`` below.
    clear_refrac(clr.view(np.uint8), ms_fs)

    return SpikeDataFrame(clr, self.meta, index=threshed.index,
                          columns=threshed.columns)
def test_one_thresh(self):
    """Clearing via a uint8 view changes the data; a raw bool array is rejected."""
    # Threshold the fixture signal at a random level and keep an untouched
    # copy so we can tell afterwards whether clearing changed anything.
    spikes = self.x > rand()
    before = spikes.copy()

    # Clear in place through a uint8 view of the boolean array.
    clear_refrac(spikes.view(np.uint8), self.window)

    # Passing the boolean array directly (no uint8 view) must raise.
    with self.assertRaises(ValueError):
        clear_refrac(spikes, self.window)

    # The in-place clear is expected to have changed at least one sample.
    # NOTE(review): with a random threshold this could be flaky if no
    # spikes fall within a refractory window — confirm fixture guarantees.
    self.assertFalse(np.array_equal(spikes, before))