예제 #1
0
 def test_wrong_input_errors(self):
     """A negative deletion threshold (-1) must make deletion raise ValueError."""
     single_spiketrain = neo.SpikeTrain([1] * pq.s, t_stop=2 * pq.s)
     tool = Synchrotool([single_spiketrain],
                        sampling_rate=1 / pq.s,
                        binary=True,
                        spread=1)
     with self.assertRaises(ValueError):
         tool.delete_synchrofacts(threshold=-1)
예제 #2
0
    def _test_template(self,
                       spiketrains,
                       correct_complexities,
                       sampling_rate,
                       spread,
                       deletion_threshold=2,
                       mode='delete',
                       in_place=False,
                       binary=True):
        """Run a Synchrotool annotate-then-delete round trip and check it.

        Parameters
        ----------
        spiketrains : list of neo.SpikeTrain
            Input spiketrains. ``annotate_synchrofacts`` is expected to add a
            ``'complexity'`` array annotation to each of them; with
            ``in_place=True`` the list itself is expected to reflect the
            cleaning afterwards.
        correct_complexities : array-like
            Expected ``'complexity'`` annotation of each spiketrain, one
            sequence per spiketrain (sequences may differ in length).
        sampling_rate, spread, binary
            Forwarded to ``Synchrotool``.
        deletion_threshold : int
            Forwarded as ``threshold`` to ``delete_synchrofacts``.
        mode : str
            Forwarded to ``delete_synchrofacts``. For ``'extract'`` the spikes
            with complexity >= threshold are expected to remain; for any other
            value the spikes with complexity < threshold are expected to
            remain.
        in_place : bool
            Forwarded to ``delete_synchrofacts``. When True the cleaned
            spiketrains are read back from ``spiketrains``; otherwise they are
            taken from the list returned by ``delete_synchrofacts``.
        """
        synchrofact_obj = Synchrotool(spiketrains,
                                      sampling_rate=sampling_rate,
                                      binary=binary,
                                      spread=spread)

        # test annotation
        synchrofact_obj.annotate_synchrofacts()

        annotations = [
            st.array_annotations['complexity'] for st in spiketrains
        ]

        # Compare per spiketrain so that trains with different numbers of
        # spikes (ragged expectations, plain lists) are supported.
        self.assertEqual(len(annotations), len(correct_complexities))
        for annotation, expected in zip(annotations, correct_complexities):
            assert_array_equal(annotation, expected)
            self.assertEqual(annotation.dtype, np.dtype(np.uint16))

        # Select the expected surviving spikes BEFORE deletion: with
        # in_place=True the entries of `spiketrains` may be replaced.
        correct_spike_times = []
        for spikes, expected in zip(spiketrains, correct_complexities):
            expected = np.asarray(expected)
            if mode == 'extract':
                mask = expected >= deletion_threshold
            else:
                mask = expected < deletion_threshold
            correct_spike_times.append(spikes[mask])

        # test deletion
        returned_spiketrains = synchrofact_obj.delete_synchrofacts(
            threshold=deletion_threshold,
            in_place=in_place,
            mode=mode)

        # In place, the cleaning must be visible through the caller's list.
        # Otherwise the input list is presumably left untouched, so reading
        # `spiketrains` would compare the original input and hide any wrong
        # deletion; use the returned list instead.
        # NOTE(review): relies on delete_synchrofacts returning the list of
        # cleaned spiketrains, as documented by Elephant — confirm.
        if in_place:
            cleaned_spiketrains = spiketrains
        else:
            cleaned_spiketrains = returned_spiketrains

        # zip() would silently truncate a length mismatch.
        self.assertEqual(len(cleaned_spiketrains), len(correct_spike_times))
        for correct_st, cleaned_st in zip(correct_spike_times,
                                          cleaned_spiketrains):
            assert_array_almost_equal(cleaned_st.times, correct_st)
예제 #3
0
    def test_correct_transfer_of_spiketrain_attributes(self):
        """Check that in-place synchrofact deletion keeps spiketrain attributes.

        After ``delete_synchrofacts(mode='delete', in_place=True)`` the
        spiketrain referenced by the segment and the group must be the cleaned
        one. Its annotations, waveforms and array annotations must equal the
        originals restricted to the surviving spikes. The only addition allowed
        is the ``'complexity'`` array annotation.
        """
        sampling_rate = 1 / pq.s

        spiketrain = neo.SpikeTrain([1, 1, 5, 0] * pq.s, t_stop=10 * pq.s)
        n_spikes = len(spiketrain.times.magnitude)

        block = neo.Block()

        group = neo.Group(name='Test Group')
        block.groups.append(group)
        group.spiketrains.append(spiketrain)

        segment = neo.Segment()
        block.segments.append(segment)
        segment.block = block
        segment.spiketrains.append(spiketrain)
        spiketrain.segment = segment

        spiketrain.annotate(cool_spike_train=True)
        spiketrain.array_annotate(spike_number=np.arange(n_spikes))
        spiketrain.waveforms = np.sin(
            np.arange(n_spikes)[:, np.newaxis] +
            np.arange(n_spikes)[np.newaxis, :])

        # Only the last two spikes (at 5 s and 0 s) are expected to survive;
        # the two coincident spikes at 1 s are expected to be removed.
        correct_mask = np.array([False, False, True, True])

        # store the correct attributes
        correct_annotations = spiketrain.annotations.copy()
        correct_waveforms = spiketrain.waveforms[correct_mask].copy()
        correct_array_annotations = {
            key: value[correct_mask]
            for key, value in spiketrain.array_annotations.items()
        }

        # perform an in-place synchrofact deletion
        synchrofact_obj = Synchrotool([spiketrain],
                                      spread=0,
                                      sampling_rate=sampling_rate,
                                      binary=False)
        synchrofact_obj.delete_synchrofacts(mode='delete',
                                            in_place=True,
                                            threshold=2)

        # Ensure that the spiketrain was not duplicated
        self.assertEqual(len(block.filter(objects=neo.SpikeTrain)), 1)

        cleaned_spiketrain = segment.spiketrains[0]

        # Ensure that the spiketrain is also in the group
        self.assertEqual(len(block.groups[0].spiketrains), 1)
        self.assertIs(block.groups[0].spiketrains[0], cleaned_spiketrain)

        # 'complexity' is added by the synchrofact search, so it has no
        # counterpart among the stored attributes. Compare everything else
        # without mutating the cleaned spiketrain's own annotations.
        self.assertIn('complexity', cleaned_spiketrain.array_annotations)
        cleaned_array_annotations = {
            key: value
            for key, value in cleaned_spiketrain.array_annotations.items()
            if key != 'complexity'
        }

        self.assertDictEqual(correct_annotations,
                             cleaned_spiketrain.annotations)
        assert_array_almost_equal(cleaned_spiketrain.waveforms,
                                  correct_waveforms)
        self.assertEqual(set(cleaned_array_annotations),
                         set(correct_array_annotations))
        for key, value in correct_array_annotations.items():
            assert_array_almost_equal(value, cleaned_array_annotations[key])