コード例 #1
0
    def test_window(self):
        """Test if the window parameter is correctly interpreted."""
        cch_win, bin_ids = sc.cch(
            self.binned_st1, self.binned_st2, window=[-30, 30])
        cch_win_mem, bin_ids_mem = sc.cch(
            self.binned_st1, self.binned_st2, window=[-30, 30],
            method='memory')

        self.assertEqual(len(bin_ids), cch_win.shape[0])
        assert_array_equal(bin_ids, np.arange(-30, 31, 1))
        assert_array_equal(
            (bin_ids - 0.5) * self.binned_st1.bin_size, cch_win.times)

        assert_array_equal(bin_ids_mem, np.arange(-30, 31, 1))
        assert_array_equal(
            (bin_ids_mem - 0.5) * self.binned_st1.bin_size, cch_win.times)

        assert_array_equal(cch_win, cch_win_mem)
        cch_unclipped, _ = sc.cross_correlation_histogram(
            self.binned_st1, self.binned_st2, window='full', binary=False)
        assert_array_equal(cch_win, cch_unclipped[19:80])

        _, bin_ids = sc.cch(
            self.binned_st1, self.binned_st2, window=[20, 30])
        _, bin_ids_mem = sc.cch(
            self.binned_st1, self.binned_st2, window=[20, 30], method='memory')

        assert_array_equal(bin_ids, np.arange(20, 31, 1))
        assert_array_equal(bin_ids_mem, np.arange(20, 31, 1))

        _, bin_ids = sc.cch(
            self.binned_st1, self.binned_st2, window=[-30, -20])

        _, bin_ids_mem = sc.cch(
            self.binned_st1, self.binned_st2, window=[-30, -20],
            method='memory')

        assert_array_equal(bin_ids, np.arange(-30, -19, 1))
        assert_array_equal(bin_ids_mem, np.arange(-30, -19, 1))

        # Check for wrong assignments to the window parameter
        # Test for window longer than the total length of the spike trains
        self.assertRaises(
            ValueError, sc.cross_correlation_histogram, self.binned_st1,
            self.binned_st2, window=[-60, 50])
        self.assertRaises(
            ValueError, sc.cross_correlation_histogram, self.binned_st1,
            self.binned_st2, window=[-50, 60])
        # Test for no integer or wrong string in input
        self.assertRaises(
            ValueError, sc.cross_correlation_histogram, self.binned_st1,
            self.binned_st2, window=[-25.5, 25.5])
        self.assertRaises(
            ValueError, sc.cross_correlation_histogram, self.binned_st1,
            self.binned_st2, window='test')
コード例 #2
0
    def test_window(self):
        '''Test if the window parameter is correctly interpreted.'''
        cch_win, bin_ids = sc.cch(
            self.binned_st1, self.binned_st2, window=[-30, 30])
        cch_win_mem, bin_ids_mem = sc.cch(
            self.binned_st1, self.binned_st2, window=[-30, 30])

        assert_array_equal(bin_ids, np.arange(-30, 31, 1))
        assert_array_equal(
            (bin_ids - 0.5) * self.binned_st1.binsize, cch_win.times)

        assert_array_equal(bin_ids_mem, np.arange(-30, 31, 1))
        assert_array_equal(
            (bin_ids_mem - 0.5) * self.binned_st1.binsize, cch_win.times)

        assert_array_equal(cch_win, cch_win_mem)
        cch_unclipped, _ = sc.cross_correlation_histogram(
            self.binned_st1, self.binned_st2, window='full', binary=False)
        assert_array_equal(cch_win, cch_unclipped[19:80])

        _, bin_ids = sc.cch(
            self.binned_st1, self.binned_st2, window=[20, 30])
        _, bin_ids_mem = sc.cch(
            self.binned_st1, self.binned_st2, window=[20, 30], method='memory')

        assert_array_equal(bin_ids, np.arange(20, 31, 1))
        assert_array_equal(bin_ids_mem, np.arange(20, 31, 1))

        _, bin_ids = sc.cch(
            self.binned_st1, self.binned_st2, window=[-30, -20])

        _, bin_ids_mem = sc.cch(
            self.binned_st1, self.binned_st2, window=[-30, -20],
            method='memory')

        assert_array_equal(bin_ids, np.arange(-30, -19, 1))
        assert_array_equal(bin_ids_mem, np.arange(-30, -19, 1))

        # Check for wrong assignments to the window parameter
        # Test for window longer than the total length of the spike trains
        self.assertRaises(
            ValueError, sc.cross_correlation_histogram, self.binned_st1,
            self.binned_st2, window=[-60, 50])
        self.assertRaises(
            ValueError, sc.cross_correlation_histogram, self.binned_st1,
            self.binned_st2, window=[-50, 60])
        # Test for no integer or wrong string in input
        self.assertRaises(
            KeyError, sc.cross_correlation_histogram, self.binned_st1,
            self.binned_st2, window=[-25.5, 25.5])
        self.assertRaises(
            KeyError, sc.cross_correlation_histogram, self.binned_st1,
            self.binned_st2, window='test')
コード例 #3
0
    def test_window(self):
        '''Test if the window parameter is correctly interpreted.'''
        cch_win, bin_ids = sc.cch(
            self.binned_st1, self.binned_st2, window=[-30, 30])
        cch_win_mem, bin_ids_mem = sc.cch(
            self.binned_st1, self.binned_st2, window=[-30, 30])

        assert_array_equal(bin_ids, np.arange(-30, 31, 1))
        assert_array_equal(
            (bin_ids - 0.5) * self.binned_st1.binsize, cch_win.times)

        assert_array_equal(bin_ids_mem, np.arange(-30, 31, 1))
        assert_array_equal(
            (bin_ids_mem - 0.5) * self.binned_st1.binsize, cch_win.times)

        assert_array_equal(cch_win, cch_win_mem)
        cch_unclipped, _ = sc.cross_correlation_histogram(
            self.binned_st1, self.binned_st2, window='full', binary=False)
        assert_array_equal(cch_win, cch_unclipped[19:80])

        cch_win, bin_ids = sc.cch(
            self.binned_st1, self.binned_st2, window=[-25*pq.ms, 25*pq.ms])
        cch_win_mem, bin_ids_mem = sc.cch(
            self.binned_st1, self.binned_st2, window=[-25*pq.ms, 25*pq.ms],
            method='memory')

        assert_array_equal(bin_ids, np.arange(-25, 26, 1))
        assert_array_equal(
            (bin_ids - 0.5) * self.binned_st1.binsize, cch_win.times)

        assert_array_equal(bin_ids_mem, np.arange(-25, 26, 1))
        assert_array_equal(
            (bin_ids_mem - 0.5) * self.binned_st1.binsize, cch_win.times)

        assert_array_equal(cch_win, cch_win_mem)

        _, bin_ids = sc.cch(
            self.binned_st1, self.binned_st2, window=[20, 30])
        _, bin_ids_mem = sc.cch(
            self.binned_st1, self.binned_st2, window=[20, 30], method='memory')

        assert_array_equal(bin_ids, np.arange(20, 31, 1))
        assert_array_equal(bin_ids_mem, np.arange(20, 31, 1))

        _, bin_ids = sc.cch(
            self.binned_st1, self.binned_st2, window=[-30, -20])

        _, bin_ids_mem = sc.cch(
            self.binned_st1, self.binned_st2, window=[-30, -20],
            method='memory')

        assert_array_equal(bin_ids, np.arange(-30, -19, 1))
        assert_array_equal(bin_ids_mem, np.arange(-30, -19, 1))

        # Cehck for wrong assignments to the window parameter
        self.assertRaises(
            ValueError, sc.cross_correlation_histogram, self.binned_st1,
            self.binned_st2, window=[-60, 50])
        self.assertRaises(
            ValueError, sc.cross_correlation_histogram, self.binned_st1,
            self.binned_st2, window=[-60, 50], method='memory')

        self.assertRaises(
            ValueError, sc.cross_correlation_histogram, self.binned_st1,
            self.binned_st2, window=[-50, 60])
        self.assertRaises(
            ValueError, sc.cross_correlation_histogram, self.binned_st1,
            self.binned_st2, window=[-50, 60], method='memory')

        self.assertRaises(
            ValueError, sc.cross_correlation_histogram, self.binned_st1,
            self.binned_st2, window=[-25.5*pq.ms, 25*pq.ms])
        self.assertRaises(
            ValueError, sc.cross_correlation_histogram, self.binned_st1,
            self.binned_st2, window=[-25.5*pq.ms, 25*pq.ms], method='memory')

        self.assertRaises(
            ValueError, sc.cross_correlation_histogram, self.binned_st1,
            self.binned_st2, window=[-25*pq.ms, 25.5*pq.ms])
        self.assertRaises(
            ValueError, sc.cross_correlation_histogram, self.binned_st1,
            self.binned_st2, window=[-25*pq.ms, 25.5*pq.ms], method='memory')

        self.assertRaises(
            ValueError, sc.cross_correlation_histogram, self.binned_st1,
            self.binned_st2, window=[-60*pq.ms, 50*pq.ms])
        self.assertRaises(
            ValueError, sc.cross_correlation_histogram, self.binned_st1,
            self.binned_st2, window=[-60*pq.ms, 50*pq.ms], method='memory')

        self.assertRaises(
            ValueError, sc.cross_correlation_histogram, self.binned_st1,
            self.binned_st2, window=[-50*pq.ms, 60*pq.ms])
        self.assertRaises(
            ValueError, sc.cross_correlation_histogram, self.binned_st1,
            self.binned_st2, window=[-50*pq.ms, 60*pq.ms], method='memory')
コード例 #4
0
    def test_cross_correlation_histogram(self):
        '''
        Test generic result of a cross-correlation histogram between two binned
        spike trains.
        '''
        # Calculate CCH using Elephant (normal and binary version) with
        # mode equal to 'full' (whole spike trains are correlated)
        cch_clipped, bin_ids_clipped = sc.cross_correlation_histogram(
            self.binned_st1, self.binned_st2, window='full',
            binary=True)
        cch_unclipped, bin_ids_unclipped = sc.cross_correlation_histogram(
            self.binned_st1, self.binned_st2, window='full', binary=False)

        cch_clipped_mem, bin_ids_clipped_mem = sc.cross_correlation_histogram(
            self.binned_st1, self.binned_st2, window='full',
            binary=True, method='memory')
        cch_unclipped_mem, bin_ids_unclipped_mem = \
            sc.cross_correlation_histogram(
                self.binned_st1, self.binned_st2, window='full',
                binary=False, method='memory')
        # Check consistency two methods
        assert_array_equal(
            np.squeeze(cch_clipped.magnitude), np.squeeze(
                cch_clipped_mem.magnitude))
        assert_array_equal(
            np.squeeze(cch_clipped.times), np.squeeze(
                cch_clipped_mem.times))
        assert_array_equal(
            np.squeeze(cch_unclipped.magnitude), np.squeeze(
                cch_unclipped_mem.magnitude))
        assert_array_equal(
            np.squeeze(cch_unclipped.times), np.squeeze(
                cch_unclipped_mem.times))
        assert_array_almost_equal(bin_ids_clipped, bin_ids_clipped_mem)
        assert_array_almost_equal(bin_ids_unclipped, bin_ids_unclipped_mem)

        # Check normal correlation Note: Use numpy correlate to verify result.
        # Note: numpy conventions for input array 1 and input array 2 are
        # swapped compared to Elephant!
        mat1 = self.binned_st1.to_array()[0]
        mat2 = self.binned_st2.to_array()[0]
        target_numpy = np.correlate(mat2, mat1, mode='full')
        assert_array_equal(
            target_numpy, np.squeeze(cch_unclipped.magnitude))

        # Check cross correlation function for several displacements tau
        # Note: Use Elephant corrcoeff to verify result
        tau = [-25.0, 0.0, 13.0]  # in ms
        for t in tau:
            # adjust t_start, t_stop to shift by tau
            t0 = np.min([self.st_1.t_start + t * pq.ms, self.st_2.t_start])
            t1 = np.max([self.st_1.t_stop + t * pq.ms, self.st_2.t_stop])
            st1 = neo.SpikeTrain(self.st_1.magnitude + t, units='ms',
                                 t_start=t0 * pq.ms, t_stop=t1 * pq.ms)
            st2 = neo.SpikeTrain(self.st_2.magnitude, units='ms',
                                 t_start=t0 * pq.ms, t_stop=t1 * pq.ms)
            binned_sts = conv.BinnedSpikeTrain([st1, st2],
                                               binsize=1 * pq.ms,
                                               t_start=t0 * pq.ms,
                                               t_stop=t1 * pq.ms)
            # caluclate corrcoef
            corrcoef = sc.corrcoef(binned_sts)[1, 0]

            # expand t_stop to have two spike trains with same length as st1,
            # st2
            st1 = neo.SpikeTrain(self.st_1.magnitude, units='ms',
                                 t_start=self.st_1.t_start,
                                 t_stop=self.st_1.t_stop + np.abs(t) * pq.ms)
            st2 = neo.SpikeTrain(self.st_2.magnitude, units='ms',
                                 t_start=self.st_2.t_start,
                                 t_stop=self.st_2.t_stop + np.abs(t) * pq.ms)
            binned_st1 = conv.BinnedSpikeTrain(
                st1, t_start=0 * pq.ms, t_stop=(50 + np.abs(t)) * pq.ms,
                binsize=1 * pq.ms)
            binned_st2 = conv.BinnedSpikeTrain(
                st2, t_start=0 * pq.ms, t_stop=(50 + np.abs(t)) * pq.ms,
                binsize=1 * pq.ms)
            # calculate CCHcoef and take value at t=tau
            CCHcoef, _ = sc.cch(binned_st1, binned_st2,
                                cross_corr_coef=True)
            left_edge = - binned_st1.num_bins + 1
            tau_bin = int(t / float(binned_st1.binsize.magnitude))
            assert_array_equal(
                corrcoef, CCHcoef[tau_bin - left_edge].magnitude)

        # Check correlation using binary spike trains
        mat1 = np.array(self.binned_st1.to_bool_array()[0], dtype=int)
        mat2 = np.array(self.binned_st2.to_bool_array()[0], dtype=int)
        target_numpy = np.correlate(mat2, mat1, mode='full')
        assert_array_equal(
            target_numpy, np.squeeze(cch_clipped.magnitude))

        # Check the time axis and bin IDs of the resulting AnalogSignal
        assert_array_almost_equal(
            (bin_ids_clipped - 0.5) * self.binned_st1.binsize,
            cch_unclipped.times)
        assert_array_almost_equal(
            (bin_ids_clipped - 0.5) * self.binned_st1.binsize,
            cch_clipped.times)

        # Calculate CCH using Elephant (normal and binary version) with
        # mode equal to 'valid' (only completely overlapping intervals of the
        # spike trains are correlated)
        cch_clipped, bin_ids_clipped = sc.cross_correlation_histogram(
            self.binned_st1, self.binned_st2, window='valid',
            binary=True)
        cch_unclipped, bin_ids_unclipped = sc.cross_correlation_histogram(
            self.binned_st1, self.binned_st2, window='valid',
            binary=False)
        cch_clipped_mem, bin_ids_clipped_mem = sc.cross_correlation_histogram(
            self.binned_st1, self.binned_st2, window='valid',
            binary=True, method='memory')
        cch_unclipped_mem, bin_ids_unclipped_mem = \
            sc.cross_correlation_histogram(
                self.binned_st1, self.binned_st2, window='valid',
                binary=False, method='memory')

        # Check consistency two methods
        assert_array_equal(
            np.squeeze(cch_clipped.magnitude), np.squeeze(
                cch_clipped_mem.magnitude))
        assert_array_equal(
            np.squeeze(cch_clipped.times), np.squeeze(
                cch_clipped_mem.times))
        assert_array_equal(
            np.squeeze(cch_unclipped.magnitude), np.squeeze(
                cch_unclipped_mem.magnitude))
        assert_array_equal(
            np.squeeze(cch_unclipped.times), np.squeeze(
                cch_unclipped_mem.times))
        assert_array_equal(bin_ids_clipped, bin_ids_clipped_mem)
        assert_array_equal(bin_ids_unclipped, bin_ids_unclipped_mem)

        # Check normal correlation Note: Use numpy correlate to verify result.
        # Note: numpy conventions for input array 1 and input array 2 are
        # swapped compared to Elephant!
        mat1 = self.binned_st1.to_array()[0]
        mat2 = self.binned_st2.to_array()[0]
        target_numpy = np.correlate(mat2, mat1, mode='valid')
        assert_array_equal(
            target_numpy, np.squeeze(cch_unclipped.magnitude))

        # Check correlation using binary spike trains
        mat1 = np.array(self.binned_st1.to_bool_array()[0], dtype=int)
        mat2 = np.array(self.binned_st2.to_bool_array()[0], dtype=int)
        target_numpy = np.correlate(mat2, mat1, mode='valid')
        assert_array_equal(
            target_numpy, np.squeeze(cch_clipped.magnitude))

        # Check the time axis and bin IDs of the resulting AnalogSignal
        assert_array_equal(
            (bin_ids_clipped - 0.5) * self.binned_st1.binsize,
            cch_unclipped.times)
        assert_array_equal(
            (bin_ids_clipped - 0.5) * self.binned_st1.binsize,
            cch_clipped.times)

        # Check for wrong window parameter setting
        self.assertRaises(
            KeyError, sc.cross_correlation_histogram, self.binned_st1,
            self.binned_st2, window='dsaij')
        self.assertRaises(
            KeyError, sc.cross_correlation_histogram, self.binned_st1,
            self.binned_st2, window='dsaij', method='memory')
コード例 #5
0
    def test_cross_correlation_histogram(self):
        '''
        Test generic result of a cross-correlation histogram between two binned
        spike trains.
        '''
        # Calculate CCH using Elephant (normal and binary version) with
        # mode equal to 'full' (whole spike trains are correlated)
        cch_clipped, bin_ids_clipped = sc.cross_correlation_histogram(
            self.binned_st1, self.binned_st2, window='full',
            binary=True)
        cch_unclipped, bin_ids_unclipped = sc.cross_correlation_histogram(
            self.binned_st1, self.binned_st2, window='full', binary=False)

        cch_clipped_mem, bin_ids_clipped_mem = sc.cross_correlation_histogram(
            self.binned_st1, self.binned_st2, window='full',
            binary=True, method='memory')
        cch_unclipped_mem, bin_ids_unclipped_mem = sc.cross_correlation_histogram(
            self.binned_st1, self.binned_st2, window='full',
            binary=False, method='memory')
        # Check consistency two methods
        assert_array_equal(
            np.squeeze(cch_clipped.magnitude), np.squeeze(
                cch_clipped_mem.magnitude))
        assert_array_equal(
            np.squeeze(cch_clipped.times), np.squeeze(
                cch_clipped_mem.times))
        assert_array_equal(
            np.squeeze(cch_unclipped.magnitude), np.squeeze(
                cch_unclipped_mem.magnitude))
        assert_array_equal(
            np.squeeze(cch_unclipped.times), np.squeeze(
                cch_unclipped_mem.times))
        assert_array_almost_equal(bin_ids_clipped, bin_ids_clipped_mem)
        assert_array_almost_equal(bin_ids_unclipped, bin_ids_unclipped_mem)

        # Check normal correlation Note: Use numpy correlate to verify result.
        # Note: numpy conventions for input array 1 and input array 2 are
        # swapped compared to Elephant!
        mat1 = self.binned_st1.to_array()[0]
        mat2 = self.binned_st2.to_array()[0]
        target_numpy = np.correlate(mat2, mat1, mode='full')
        assert_array_equal(
            target_numpy, np.squeeze(cch_unclipped.magnitude))


        # Check cross correlation function for several displacements tau
        # Note: Use Elephant corrcoeff to verify result 
        tau = [-25.0, 0.0, 13.0] # in ms
        for t in tau:
            # adjust t_start, t_stop to shift by tau
            t0 = np.min([self.st_1.t_start+t*pq.ms, self.st_2.t_start])
            t1 = np.max([self.st_1.t_stop+t*pq.ms, self.st_2.t_stop])            
            st1 = neo.SpikeTrain(self.st_1.magnitude+t, units='ms',
                                t_start = t0*pq.ms, t_stop = t1*pq.ms)           
            st2 = neo.SpikeTrain(self.st_2.magnitude, units='ms',
                                t_start = t0*pq.ms, t_stop = t1*pq.ms)              
            binned_sts = conv.BinnedSpikeTrain([st1, st2],
                                               binsize=1*pq.ms,
                                               t_start = t0*pq.ms,
                                               t_stop = t1*pq.ms)
            # caluclate corrcoef
            corrcoef = sc.corrcoef(binned_sts)[1,0]    
            
            # expand t_stop to have two spike trains with same length as st1, st2
            st1 = neo.SpikeTrain(self.st_1.magnitude, units='ms',
                                t_start = self.st_1.t_start, 
                                t_stop = self.st_1.t_stop+np.abs(t)*pq.ms)           
            st2 = neo.SpikeTrain(self.st_2.magnitude, units='ms',
                                t_start = self.st_2.t_start, 
                                t_stop = self.st_2.t_stop+np.abs(t)*pq.ms)
            binned_st1 = conv.BinnedSpikeTrain(
                st1, t_start=0*pq.ms, t_stop=(50+np.abs(t))*pq.ms,
                binsize=1 * pq.ms)
            binned_st2 = conv.BinnedSpikeTrain(
                st2, t_start=0 * pq.ms, t_stop=(50+np.abs(t))*pq.ms,
                binsize=1 * pq.ms)
            # calculate CCHcoef and take value at t=tau
            CCHcoef, _  = sc.cch(binned_st1, binned_st2, 
                               cross_corr_coef=True)
            l = - binned_st1.num_bins + 1
            tau_bin = int(t/float(binned_st1.binsize.magnitude))
            assert_array_equal(corrcoef, CCHcoef[tau_bin-l].magnitude)
            

        # Check correlation using binary spike trains
        mat1 = np.array(self.binned_st1.to_bool_array()[0], dtype=int)
        mat2 = np.array(self.binned_st2.to_bool_array()[0], dtype=int)
        target_numpy = np.correlate(mat2, mat1, mode='full')
        assert_array_equal(
            target_numpy, np.squeeze(cch_clipped.magnitude))

        # Check the time axis and bin IDs of the resulting AnalogSignal
        assert_array_almost_equal(
            (bin_ids_clipped - 0.5) * self.binned_st1.binsize,
            cch_unclipped.times)
        assert_array_almost_equal(
            (bin_ids_clipped - 0.5) * self.binned_st1.binsize,
            cch_clipped.times)

        # Calculate CCH using Elephant (normal and binary version) with
        # mode equal to 'valid' (only completely overlapping intervals of the
        # spike trains are correlated)
        cch_clipped, bin_ids_clipped = sc.cross_correlation_histogram(
            self.binned_st1, self.binned_st2, window='valid',
            binary=True)
        cch_unclipped, bin_ids_unclipped = sc.cross_correlation_histogram(
            self.binned_st1, self.binned_st2, window='valid',
            binary=False)
        cch_clipped_mem, bin_ids_clipped_mem = sc.cross_correlation_histogram(
            self.binned_st1, self.binned_st2, window='valid',
            binary=True, method='memory')
        cch_unclipped_mem, bin_ids_unclipped_mem = sc.cross_correlation_histogram(
            self.binned_st1, self.binned_st2, window='valid',
            binary=False, method='memory')

        # Check consistency two methods
        assert_array_equal(
            np.squeeze(cch_clipped.magnitude), np.squeeze(
                cch_clipped_mem.magnitude))
        assert_array_equal(
            np.squeeze(cch_clipped.times), np.squeeze(
                cch_clipped_mem.times))
        assert_array_equal(
            np.squeeze(cch_unclipped.magnitude), np.squeeze(
                cch_unclipped_mem.magnitude))
        assert_array_equal(
            np.squeeze(cch_unclipped.times), np.squeeze(
                cch_unclipped_mem.times))
        assert_array_equal(bin_ids_clipped, bin_ids_clipped_mem)
        assert_array_equal(bin_ids_unclipped, bin_ids_unclipped_mem)

        # Check normal correlation Note: Use numpy correlate to verify result.
        # Note: numpy conventions for input array 1 and input array 2 are
        # swapped compared to Elephant!
        mat1 = self.binned_st1.to_array()[0]
        mat2 = self.binned_st2.to_array()[0]
        target_numpy = np.correlate(mat2, mat1, mode='valid')
        assert_array_equal(
            target_numpy, np.squeeze(cch_unclipped.magnitude))

        # Check correlation using binary spike trains
        mat1 = np.array(self.binned_st1.to_bool_array()[0], dtype=int)
        mat2 = np.array(self.binned_st2.to_bool_array()[0], dtype=int)
        target_numpy = np.correlate(mat2, mat1, mode='valid')
        assert_array_equal(
            target_numpy, np.squeeze(cch_clipped.magnitude))

        # Check the time axis and bin IDs of the resulting AnalogSignal
        assert_array_equal(
            (bin_ids_clipped - 0.5) * self.binned_st1.binsize,
            cch_unclipped.times)
        assert_array_equal(
            (bin_ids_clipped - 0.5) * self.binned_st1.binsize,
            cch_clipped.times)

        # Check for wrong window parameter setting
        self.assertRaises(
            KeyError, sc.cross_correlation_histogram, self.binned_st1,
            self.binned_st2, window='dsaij')
        self.assertRaises(
            KeyError, sc.cross_correlation_histogram, self.binned_st1,
            self.binned_st2, window='dsaij', method='memory')
コード例 #6
0
def run_net(tr):

    # prefs.codegen.target = 'numpy'
    # prefs.codegen.target = 'cython'
    if tr.n_threads > 1:
        prefs.devices.cpp_standalone.openmp_threads = tr.n_threads

    set_device('cpp_standalone',
               directory='./builds/%.4d' % (tr.v_idx),
               build_on_run=False)

    print("Started process with id ", str(tr.v_idx))

    T = tr.T1 + tr.T2 + tr.T3 + tr.T4 + tr.T5

    namespace = tr.netw.f_to_dict(short_names=True, fast_access=True)
    namespace['idx'] = tr.v_idx

    defaultclock.dt = tr.netw.sim.dt

    # collect all network components dependent on configuration
    # (e.g. poisson vs. memnoise) and add them to the Brian 2
    # network object later
    netw_objects = []

    if tr.external_mode == 'memnoise':
        neuron_model = tr.condlif_memnoise
    elif tr.external_mode == 'poisson':
        neuron_model = tr.condlif_poisson

    GExc = NeuronGroup(
        N=tr.N_e,
        model=neuron_model,
        threshold=tr.nrnEE_thrshld,
        reset=tr.nrnEE_reset,  #method=tr.neuron_method,
        namespace=namespace)
    GInh = NeuronGroup(
        N=tr.N_i,
        model=neuron_model,
        threshold='V > Vt',
        reset='V=Vr_i',  #method=tr.neuron_method,
        namespace=namespace)

    if tr.external_mode == 'memnoise':
        GExc.mu, GInh.mu = tr.mu_e, tr.mu_i
        GExc.sigma, GInh.sigma = tr.sigma_e, tr.sigma_i

    GExc.Vt, GInh.Vt = tr.Vt_e, tr.Vt_i
    GExc.V , GInh.V  = np.random.uniform(tr.Vr_e/mV, tr.Vt_e/mV,
                                         size=tr.N_e)*mV, \
                       np.random.uniform(tr.Vr_i/mV, tr.Vt_i/mV,
                                         size=tr.N_i)*mV

    netw_objects.extend([GExc, GInh])

    synEE_pre_mod = mod.synEE_pre
    synEE_post_mod = mod.synEE_post

    if tr.external_mode == 'poisson':

        if tr.PInp_mode == 'pool':
            PInp = PoissonGroup(tr.NPInp,
                                rates=tr.PInp_rate,
                                namespace=namespace,
                                name='poissongroup_exc')
            sPN = Synapses(target=GExc,
                           source=PInp,
                           model=tr.poisson_mod,
                           on_pre='gfwd_post += a_EPoi',
                           namespace=namespace,
                           name='synPInpExc')

            sPN_src, sPN_tar = generate_N_connections(N_tar=tr.N_e,
                                                      N_src=tr.NPInp,
                                                      N=tr.NPInp_1n)

        elif tr.PInp_mode == 'indep':
            PInp = PoissonGroup(tr.N_e,
                                rates=tr.PInp_rate,
                                namespace=namespace)
            sPN = Synapses(target=GExc,
                           source=PInp,
                           model=tr.poisson_mod,
                           on_pre='gfwd_post += a_EPoi',
                           namespace=namespace,
                           name='synPInp_inhInh')
            sPN_src, sPN_tar = range(tr.N_e), range(tr.N_e)

        sPN.connect(i=sPN_src, j=sPN_tar)

        if tr.PInp_mode == 'pool':
            PInp_inh = PoissonGroup(tr.NPInp_inh,
                                    rates=tr.PInp_inh_rate,
                                    namespace=namespace,
                                    name='poissongroup_inh')
            sPNInh = Synapses(target=GInh,
                              source=PInp_inh,
                              model=tr.poisson_mod,
                              on_pre='gfwd_post += a_EPoi',
                              namespace=namespace)
            sPNInh_src, sPNInh_tar = generate_N_connections(N_tar=tr.N_i,
                                                            N_src=tr.NPInp_inh,
                                                            N=tr.NPInp_inh_1n)

        elif tr.PInp_mode == 'indep':

            PInp_inh = PoissonGroup(tr.N_i,
                                    rates=tr.PInp_inh_rate,
                                    namespace=namespace)
            sPNInh = Synapses(target=GInh,
                              source=PInp_inh,
                              model=tr.poisson_mod,
                              on_pre='gfwd_post += a_EPoi',
                              namespace=namespace)
            sPNInh_src, sPNInh_tar = range(tr.N_i), range(tr.N_i)

        sPNInh.connect(i=sPNInh_src, j=sPNInh_tar)

        netw_objects.extend([PInp, sPN, PInp_inh, sPNInh])

    if tr.syn_noise:
        synEE_mod = '''%s 
                       %s''' % (tr.synEE_noise, tr.synEE_mod)
    else:
        synEE_mod = '''%s 
                       %s''' % (tr.synEE_static, tr.synEE_mod)

    if tr.stdp_active:
        synEE_pre_mod = '''%s 
                            %s''' % (synEE_pre_mod, mod.synEE_pre_STDP)
        synEE_post_mod = '''%s 
                            %s''' % (synEE_post_mod, mod.synEE_post_STDP)

    if tr.synEE_rec:
        synEE_pre_mod = '''%s 
                            %s''' % (synEE_pre_mod, mod.synEE_pre_rec)
        synEE_post_mod = '''%s 
                            %s''' % (synEE_post_mod, mod.synEE_post_rec)

    # E<-E advanced synapse model, rest simple
    SynEE = Synapses(target=GExc,
                     source=GExc,
                     model=synEE_mod,
                     on_pre=synEE_pre_mod,
                     on_post=synEE_post_mod,
                     namespace=namespace,
                     dt=tr.synEE_mod_dt)
    SynIE = Synapses(target=GInh,
                     source=GExc,
                     on_pre='ge_post += a_ie',
                     namespace=namespace)
    SynEI = Synapses(target=GExc,
                     source=GInh,
                     on_pre='gi_post += a_ei',
                     namespace=namespace)
    SynII = Synapses(target=GInh,
                     source=GInh,
                     on_pre='gi_post += a_ii',
                     namespace=namespace)

    if tr.strct_active:
        sEE_src, sEE_tar = generate_full_connectivity(tr.N_e, same=True)
        SynEE.connect(i=sEE_src, j=sEE_tar)
        SynEE.syn_active = 0

    else:
        srcs_full, tars_full = generate_full_connectivity(tr.N_e, same=True)
        SynEE.connect(i=srcs_full, j=tars_full)
        SynEE.syn_active = 0

    sIE_src, sIE_tar = generate_connections(tr.N_i, tr.N_e, tr.p_ie)
    sEI_src, sEI_tar = generate_connections(tr.N_e, tr.N_i, tr.p_ei)
    sII_src, sII_tar = generate_connections(tr.N_i, tr.N_i, tr.p_ii, same=True)

    SynIE.connect(i=sIE_src, j=sIE_tar)
    SynEI.connect(i=sEI_src, j=sEI_tar)
    SynII.connect(i=sII_src, j=sII_tar)

    tr.f_add_result('sIE_src', sIE_src)
    tr.f_add_result('sIE_tar', sIE_tar)
    tr.f_add_result('sEI_src', sEI_src)
    tr.f_add_result('sEI_tar', sEI_tar)
    tr.f_add_result('sII_src', sII_src)
    tr.f_add_result('sII_tar', sII_tar)

    if tr.syn_noise:
        SynEE.syn_sigma = tr.syn_sigma

    SynEE.insert_P = tr.insert_P
    SynEE.p_inactivate = tr.p_inactivate
    SynEE.stdp_active = 1

    ATM_vals = np.random.normal(loc=tr.ATotalMax,
                                scale=tr.ATotalMax_sd,
                                size=tr.N_e * (tr.N_e - 1))
    assert np.min(ATM_vals) > 0.
    SynEE.ATotalMax = ATM_vals

    # make randomly chosen synapses active at beginning
    rs = np.random.uniform(size=tr.N_e * (tr.N_e - 1))
    initial_active = (rs < tr.p_ee).astype('int')
    initial_a = initial_active * tr.a_ee
    SynEE.syn_active = initial_active
    SynEE.a = initial_a

    # recording of stdp in T4
    SynEE.stdp_rec_start = tr.T1 + tr.T2 + tr.T3
    SynEE.stdp_rec_max = tr.T1 + tr.T2 + tr.T3 + tr.stdp_rec_T

    # synaptic scaling
    if tr.netw.config.scl_active:

        if tr.syn_scl_rec:
            SynEE.scl_rec_start = tr.T1 + tr.T2 + tr.T3
            SynEE.scl_rec_max = tr.T1 + tr.T2 + tr.T3 + tr.scl_rec_T
        else:
            SynEE.scl_rec_start = T + 10 * second
            SynEE.scl_rec_max = T

        SynEE.summed_updaters['Asum_post']._clock = Clock(
            dt=tr.dt_synEE_scaling)
        synscaling = SynEE.run_regularly(tr.synEE_scaling,
                                         dt=tr.dt_synEE_scaling,
                                         when='end',
                                         name='syn_scaling')

    # # intrinsic plasticity
    # if tr.netw.config.it_active:
    #     GExc.h_ip = tr.h_ip
    #     GExc.run_regularly(tr.intrinsic_mod, dt = tr.it_dt, when='end')

    # structural plasticity
    if tr.netw.config.strct_active:
        if tr.strct_mode == 'zero':
            if tr.turnover_rec:
                strct_mod = '''%s 
                                %s''' % (tr.strct_mod, tr.turnover_rec_mod)
            else:
                strct_mod = tr.strct_mod

            strctplst = SynEE.run_regularly(strct_mod,
                                            dt=tr.strct_dt,
                                            when='end',
                                            name='strct_plst_zero')

        elif tr.strct_mode == 'thrs':
            if tr.turnover_rec:
                strct_mod_thrs = '''%s 
                                %s''' % (tr.strct_mod_thrs,
                                         tr.turnover_rec_mod)
            else:
                strct_mod_thrs = tr.strct_mod_thrs

            strctplst = SynEE.run_regularly(strct_mod_thrs,
                                            dt=tr.strct_dt,
                                            when='end',
                                            name='strct_plst_thrs')

    netw_objects.extend([SynEE, SynEI, SynIE, SynII])

    # keep track of the number of active synapses
    sum_target = NeuronGroup(1, 'c : 1 (shared)', dt=tr.csample_dt)

    sum_model = '''NSyn : 1 (constant)
                   c_post = (1.0*syn_active_pre)/NSyn : 1 (summed)'''
    sum_connection = Synapses(target=sum_target,
                              source=SynEE,
                              model=sum_model,
                              dt=tr.csample_dt,
                              name='get_active_synapse_count')
    sum_connection.connect()
    sum_connection.NSyn = tr.N_e * (tr.N_e - 1)

    if tr.adjust_insertP:
        # homeostatically adjust growth rate
        growth_updater = Synapses(sum_target, SynEE)
        growth_updater.run_regularly('insert_P_post *= 0.1/c_pre',
                                     when='after_groups',
                                     dt=tr.csample_dt,
                                     name='update_insP')
        growth_updater.connect(j='0')

        netw_objects.extend([sum_target, sum_connection, growth_updater])

    # -------------- recording ------------------

    GExc_recvars = []
    if tr.memtraces_rec:
        GExc_recvars.append('V')
    if tr.vttraces_rec:
        GExc_recvars.append('Vt')
    if tr.getraces_rec:
        GExc_recvars.append('ge')
    if tr.gitraces_rec:
        GExc_recvars.append('gi')
    if tr.gfwdtraces_rec and tr.external_mode == 'poisson':
        GExc_recvars.append('gfwd')

    GInh_recvars = GExc_recvars

    GExc_stat = StateMonitor(GExc,
                             GExc_recvars,
                             record=[0, 1, 2],
                             dt=tr.GExc_stat_dt)
    GInh_stat = StateMonitor(GInh,
                             GInh_recvars,
                             record=[0, 1, 2],
                             dt=tr.GInh_stat_dt)

    SynEE_recvars = []
    if tr.synee_atraces_rec:
        SynEE_recvars.append('a')
    if tr.synee_activetraces_rec:
        SynEE_recvars.append('syn_active')
    if tr.synee_Apretraces_rec:
        SynEE_recvars.append('Apre')
    if tr.synee_Aposttraces_rec:
        SynEE_recvars.append('Apost')

    SynEE_stat = StateMonitor(SynEE,
                              SynEE_recvars,
                              record=range(tr.n_synee_traces_rec),
                              when='end',
                              dt=tr.synEE_stat_dt)

    if tr.adjust_insertP:

        C_stat = StateMonitor(sum_target,
                              'c',
                              dt=tr.csample_dt,
                              record=[0],
                              when='end')
        insP_stat = StateMonitor(SynEE,
                                 'insert_P',
                                 dt=tr.csample_dt,
                                 record=[0],
                                 when='end')
        netw_objects.extend([C_stat, insP_stat])

    GExc_spks = SpikeMonitor(GExc)
    GInh_spks = SpikeMonitor(GInh)

    GExc_rate = PopulationRateMonitor(GExc)
    GInh_rate = PopulationRateMonitor(GInh)

    if tr.external_mode == 'poisson':
        PInp_spks = SpikeMonitor(PInp)
        PInp_rate = PopulationRateMonitor(PInp)
        netw_objects.extend([PInp_spks, PInp_rate])

    if tr.synee_a_nrecpoints == 0:
        SynEE_a_dt = 10 * tr.sim.T2
    else:
        SynEE_a_dt = tr.sim.T2 / tr.synee_a_nrecpoints
    SynEE_a = StateMonitor(SynEE, ['a', 'syn_active'],
                           record=range(tr.N_e * (tr.N_e - 1)),
                           dt=SynEE_a_dt,
                           when='end',
                           order=100)

    netw_objects.extend([
        GExc_stat, GInh_stat, SynEE_stat, SynEE_a, GExc_spks, GInh_spks,
        GExc_rate, GInh_rate
    ])

    net = Network(*netw_objects)

    def set_active(*argv):
        for net_object in argv:
            net_object.active = True

    def set_inactive(*argv):
        for net_object in argv:
            net_object.active = False

    ### Simulation periods

    # --------- T1 ---------
    # initial recording period,
    # all recorders active

    set_active(GExc_spks, GInh_spks, SynEE_stat, GExc_stat, GInh_stat,
               GExc_rate, GInh_rate)

    if tr.external_mode == 'poisson':
        set_active(PInp_spks, PInp_rate)

    net.run(tr.sim.T1, report='text', report_period=300 * second, profile=True)

    # --------- T2 ---------
    # main simulation period
    # only active recordings are:
    #   1) turnover 2) C_stat 3) SynEE_a

    set_inactive(GExc_spks, GInh_spks, SynEE_stat, GExc_stat, GInh_stat,
                 GExc_rate, GInh_rate)

    if tr.external_mode == 'poisson':
        set_inactive(PInp_spks, PInp_rate)

    net.run(tr.sim.T2, report='text', report_period=300 * second, profile=True)

    # --------- T3 ---------
    # second recording period,
    # all recorders active

    set_active(GExc_spks, GInh_spks, SynEE_stat, GExc_stat, GInh_stat,
               GExc_rate, GInh_rate)

    if tr.external_mode == 'poisson':
        set_active(PInp_spks, PInp_rate)

    net.run(tr.sim.T3, report='text', report_period=300 * second, profile=True)

    # --------- T4 ---------
    # record STDP and scaling weight changes to file
    # through the cpp models

    set_inactive(GExc_spks, GInh_spks, SynEE_stat, GExc_stat, GInh_stat,
                 GExc_rate, GInh_rate)

    if tr.external_mode == 'poisson':
        set_inactive(PInp_spks, PInp_rate)

    net.run(tr.sim.T4, report='text', report_period=300 * second, profile=True)

    # --------- T5 ---------
    # freeze network and record Exc spikes
    # for cross correlations

    synscaling.active = False
    strctplst.active = False
    SynEE.stdp_active = 0

    set_active(GExc_spks)

    net.run(tr.sim.T5, report='text', report_period=300 * second, profile=True)

    SynEE_a.record_single_timestep()

    device.build(directory='builds/%.4d' % (tr.v_idx),
                 clean=True,
                 compile=True,
                 run=True,
                 debug=False)

    # -----------------------------------------

    # save monitors as raws in build directory
    raw_dir = 'builds/%.4d/raw/' % (tr.v_idx)

    if not os.path.exists(raw_dir):
        os.makedirs(raw_dir)

    with open(raw_dir + 'namespace.p', 'wb') as pfile:
        pickle.dump(namespace, pfile)

    with open(raw_dir + 'gexc_stat.p', 'wb') as pfile:
        pickle.dump(GExc_stat.get_states(), pfile)
    with open(raw_dir + 'ginh_stat.p', 'wb') as pfile:
        pickle.dump(GInh_stat.get_states(), pfile)

    with open(raw_dir + 'synee_stat.p', 'wb') as pfile:
        pickle.dump(SynEE_stat.get_states(), pfile)
    with open(raw_dir + 'synee_a.p', 'wb') as pfile:
        SynEE_a_states = SynEE_a.get_states()
        if tr.crs_crrs_rec:
            SynEE_a_states['i'] = list(SynEE.i)
            SynEE_a_states['j'] = list(SynEE.j)
        pickle.dump(SynEE_a_states, pfile)

    if tr.adjust_insertP:
        with open(raw_dir + 'c_stat.p', 'wb') as pfile:
            pickle.dump(C_stat.get_states(), pfile)

        with open(raw_dir + 'insP_stat.p', 'wb') as pfile:
            pickle.dump(insP_stat.get_states(), pfile)

    with open(raw_dir + 'gexc_spks.p', 'wb') as pfile:
        pickle.dump(GExc_spks.get_states(), pfile)
    with open(raw_dir + 'ginh_spks.p', 'wb') as pfile:
        pickle.dump(GInh_spks.get_states(), pfile)

    if tr.external_mode == 'poisson':
        with open(raw_dir + 'pinp_spks.p', 'wb') as pfile:
            pickle.dump(PInp_spks.get_states(), pfile)

    with open(raw_dir + 'gexc_rate.p', 'wb') as pfile:
        pickle.dump(GExc_rate.get_states(), pfile)
        if tr.rates_rec:
            pickle.dump(GExc_rate.smooth_rate(width=25 * ms), pfile)
    with open(raw_dir + 'ginh_rate.p', 'wb') as pfile:
        pickle.dump(GInh_rate.get_states(), pfile)
        if tr.rates_rec:
            pickle.dump(GInh_rate.smooth_rate(width=25 * ms), pfile)

    if tr.external_mode == 'poisson':
        with open(raw_dir + 'pinp_rate.p', 'wb') as pfile:
            pickle.dump(PInp_rate.get_states(), pfile)
            if tr.rates_rec:
                pickle.dump(PInp_rate.smooth_rate(width=25 * ms), pfile)

    # ----------------- add raw data ------------------------
    fpath = 'builds/%.4d/' % (tr.v_idx)

    from pathlib import Path

    Path(fpath + 'turnover').touch()
    turnover_data = np.genfromtxt(fpath + 'turnover', delimiter=',')
    os.remove(fpath + 'turnover')

    with open(raw_dir + 'turnover.p', 'wb') as pfile:
        pickle.dump(turnover_data, pfile)

    Path(fpath + 'spk_register').touch()
    spk_register_data = np.genfromtxt(fpath + 'spk_register', delimiter=',')
    os.remove(fpath + 'spk_register')

    with open(raw_dir + 'spk_register.p', 'wb') as pfile:
        pickle.dump(spk_register_data, pfile)

    Path(fpath + 'scaling_deltas').touch()
    scaling_deltas_data = np.genfromtxt(fpath + 'scaling_deltas',
                                        delimiter=',')
    os.remove(fpath + 'scaling_deltas')

    with open(raw_dir + 'scaling_deltas.p', 'wb') as pfile:
        pickle.dump(scaling_deltas_data, pfile)

    with open(raw_dir + 'profiling_summary.txt', 'w+') as tfile:
        tfile.write(str(profiling_summary(net)))

    # --------------- cross-correlations ---------------------

    if tr.crs_crrs_rec:

        GExc_spks = GExc_spks.get_states()
        synee_a = SynEE_a_states
        wsize = 100 * pq.ms

        for binsize in [1 * pq.ms, 2 * pq.ms, 5 * pq.ms]:

            wlen = int(wsize / binsize)

            ts, idxs = GExc_spks['t'], GExc_spks['i']
            idxs = idxs[ts > tr.T1 + tr.T2 + tr.T3 + tr.T4]
            ts = ts[ts > tr.T1 + tr.T2 + tr.T3 + tr.T4]
            ts = ts - (tr.T1 + tr.T2 + tr.T3 + tr.T4)

            sts = [
                neo.SpikeTrain(ts[idxs == i] / second * pq.s,
                               t_stop=tr.T5 / second * pq.s)
                for i in range(tr.N_e)
            ]

            crs_crrs, syn_a = [], []

            for f, (i, j) in enumerate(zip(synee_a['i'], synee_a['j'])):
                if synee_a['syn_active'][-1][f] == 1:

                    crs_crr, cbin = cch(BinnedSpikeTrain(sts[i],
                                                         binsize=binsize),
                                        BinnedSpikeTrain(sts[j],
                                                         binsize=binsize),
                                        cross_corr_coef=True,
                                        border_correction=True,
                                        window=(-1 * wlen, wlen))

                    crs_crrs.append(list(np.array(crs_crr).T[0]))
                    syn_a.append(synee_a['a'][-1][f])

            fname = 'crs_crrs_wsize%dms_binsize%fms_full' % (wsize / pq.ms,
                                                             binsize / pq.ms)

            df = {
                'cbin': cbin,
                'crs_crrs': np.array(crs_crrs),
                'syn_a': np.array(syn_a),
                'binsize': binsize,
                'wsize': wsize,
                'wlen': wlen
            }

            with open('builds/%.4d/raw/' % (tr.v_idx) + fname + '.p',
                      'wb') as pfile:
                pickle.dump(df, pfile)

    # -----------------  clean up  ---------------------------
    shutil.rmtree('builds/%.4d/results/' % (tr.v_idx))

    # ---------------- plot results --------------------------

    #os.chdir('./analysis/file_based/')

    from analysis.overview_fb import overview_figure
    overview_figure('builds/%.4d' % (tr.v_idx), namespace)

    from analysis.synw_fb import synw_figure
    synw_figure('builds/%.4d' % (tr.v_idx), namespace)

    from analysis.synw_log_fb import synw_log_figure
    synw_log_figure('builds/%.4d' % (tr.v_idx), namespace)
コード例 #7
0
    def generate_cch_array(self, spiketrains, maxlag=None, **kwargs):
        if not hasattr(self, 'cch_array'):
            return self.cch_array
        else:
            if 'binsize' in self.params:
                binsize = self.params['binsize']
            else:
                t_lims = [(st.t_start, st.t_stop) for st in spiketrains]
                tmin = min(t_lims, key=lambda f: f[0])[0]
                tmax = max(t_lims, key=lambda f: f[1])[1]
                T = tmax - tmin
                binsize = T / float(self.params['num_bins'])
            if maxlag is None:
                maxlag = self.params['maxlag']
            else:
                self.params['maxlag'] = maxlag
            if type(maxlag) == quantity.Quantity:
                maxlag = int(
                    float(maxlag.rescale('ms')) / float(binsize.rescale('ms')))
            try:
                from mpi4py import MPI
                mpi = True
            except:
                mpi = False
            N = len(spiketrains)
            B = 2 * maxlag + 1
            pairs_idx = np.triu_indices(N, 1)
            pairs = [[i, j] for i, j in zip(pairs_idx[0], pairs_idx[1])]
            if mpi:
                comm = MPI.COMM_WORLD
                rank = comm.Get_rank()
                Nnodes = comm.Get_size()
                comm.Barrier()
                if rank == 0:
                    split = np.array_split(pairs, Nnodes)
                else:
                    split = None
                pair_per_node = int(np.ceil(float(len(pairs)) / Nnodes))
                split_pairs = comm.scatter(split, root=0)
            else:
                split_pairs = pairs
                pair_per_node = len(pairs)

            cch_array = np.zeros((pair_per_node, B))
            max_cc = 0
            for count, (i, j) in enumerate(split_pairs):
                binned_sts_i = self.robust_BinnedSpikeTrain(spiketrains[i],
                                                            binsize=binsize)
                binned_sts_j = self.robust_BinnedSpikeTrain(spiketrains[j],
                                                            binsize=binsize)
                cch_array[count] = np.squeeze(
                    cch(binned_sts_i,
                        binned_sts_j,
                        window=[-maxlag, maxlag],
                        cross_corr_coef=True,
                        **kwargs)[0])
                max_cc = max([max_cc, max(cch_array[count])])
            if mpi:
                pop_cch = comm.gather(cch_array, root=0)
                pop_max_cc = comm.gather(max_cc, root=0)
                if rank == 0:
                    cch_array = pop_cch
                    max_cc = pop_max_cc
            self.cch_array = cch_array
            self.max_cc = max_cc
            return self.cch_array
コード例 #8
0
def run_net(tr):

    # prefs.codegen.target = 'numpy'
    # prefs.codegen.target = 'cython'
    if tr.n_threads > 1:
        prefs.devices.cpp_standalone.openmp_threads = tr.n_threads

    set_device('cpp_standalone',
               directory='./builds/%.4d' % (tr.v_idx),
               build_on_run=False)

    # set brian 2 and numpy random seeds
    seed(tr.random_seed)
    np.random.seed(tr.random_seed + 11)

    print("Started process with id ", str(tr.v_idx))

    T = tr.T1 + tr.T2 + tr.T3 + tr.T4 + tr.T5

    namespace = tr.netw.f_to_dict(short_names=True, fast_access=True)
    namespace['idx'] = tr.v_idx

    defaultclock.dt = tr.netw.sim.dt

    # collect all network components dependent on configuration
    # (e.g. poisson vs. memnoise) and add them to the Brian 2
    # network object later
    netw_objects = []

    if tr.external_mode == 'memnoise':
        neuron_model = tr.condlif_memnoise
    elif tr.external_mode == 'poisson':
        raise NotImplementedError
        #neuron_model = tr.condlif_poisson

    if tr.syn_cond_mode == 'exp':
        neuron_model += tr.syn_cond_EE_exp
        print("Using EE exp mode")
    elif tr.syn_cond_mode == 'alpha':
        neuron_model += tr.syn_cond_EE_alpha
        print("Using EE alpha mode")
    elif tr.syn_cond_mode == 'biexp':
        neuron_model += tr.syn_cond_EE_biexp
        namespace['invpeakEE'] = (tr.tau_e / tr.tau_e_rise) ** \
            (tr.tau_e_rise / (tr.tau_e - tr.tau_e_rise))
        print("Using EE biexp mode")

    if tr.syn_cond_mode_EI == 'exp':
        neuron_model += tr.syn_cond_EI_exp
        print("Using EI exp mode")
    elif tr.syn_cond_mode_EI == 'alpha':
        neuron_model += tr.syn_cond_EI_alpha
        print("Using EI alpha mode")
    elif tr.syn_cond_mode_EI == 'biexp':
        neuron_model += tr.syn_cond_EI_biexp
        namespace['invpeakEI'] = (tr.tau_i / tr.tau_i_rise) ** \
            (tr.tau_i_rise / (tr.tau_i - tr.tau_i_rise))
        print("Using EI biexp mode")

    GExc = NeuronGroup(
        N=tr.N_e,
        model=neuron_model,
        threshold=tr.nrnEE_thrshld,
        reset=tr.nrnEE_reset,  #method=tr.neuron_method,
        name='GExc',
        namespace=namespace)
    GInh = NeuronGroup(
        N=tr.N_i,
        model=neuron_model,
        threshold='V > Vt',
        reset='V=Vr_i',  #method=tr.neuron_method,
        name='GInh',
        namespace=namespace)

    if tr.external_mode == 'memnoise':
        # GExc.mu, GInh.mu = [0.*mV] + (tr.N_e-1)*[tr.mu_e], tr.mu_i
        # GExc.sigma, GInh.sigma = [0.*mV] + (tr.N_e-1)*[tr.sigma_e], tr.sigma_i
        GExc.mu, GInh.mu = tr.mu_e, tr.mu_i
        GExc.sigma, GInh.sigma = tr.sigma_e, tr.sigma_i

    GExc.Vt, GInh.Vt = tr.Vt_e, tr.Vt_i
    GExc.V , GInh.V  = np.random.uniform(tr.Vr_e/mV, tr.Vt_e/mV,
                                         size=tr.N_e)*mV, \
                       np.random.uniform(tr.Vr_i/mV, tr.Vt_i/mV,
                                         size=tr.N_i)*mV

    netw_objects.extend([GExc, GInh])

    if tr.external_mode == 'poisson':

        if tr.PInp_mode == 'pool':
            PInp = PoissonGroup(tr.NPInp,
                                rates=tr.PInp_rate,
                                namespace=namespace,
                                name='poissongroup_exc')
            sPN = Synapses(target=GExc,
                           source=PInp,
                           model=tr.poisson_mod,
                           on_pre='gfwd_post += a_EPoi',
                           namespace=namespace,
                           name='synPInpExc')

            sPN_src, sPN_tar = generate_N_connections(N_tar=tr.N_e,
                                                      N_src=tr.NPInp,
                                                      N=tr.NPInp_1n)

        elif tr.PInp_mode == 'indep':
            PInp = PoissonGroup(tr.N_e,
                                rates=tr.PInp_rate,
                                namespace=namespace)
            sPN = Synapses(target=GExc,
                           source=PInp,
                           model=tr.poisson_mod,
                           on_pre='gfwd_post += a_EPoi',
                           namespace=namespace,
                           name='synPInp_inhInh')
            sPN_src, sPN_tar = range(tr.N_e), range(tr.N_e)

        sPN.connect(i=sPN_src, j=sPN_tar)

        if tr.PInp_mode == 'pool':

            PInp_inh = PoissonGroup(tr.NPInp_inh,
                                    rates=tr.PInp_inh_rate,
                                    namespace=namespace,
                                    name='poissongroup_inh')

            sPNInh = Synapses(target=GInh,
                              source=PInp_inh,
                              model=tr.poisson_mod,
                              on_pre='gfwd_post += a_EPoi',
                              namespace=namespace)

            sPNInh_src, sPNInh_tar = generate_N_connections(N_tar=tr.N_i,
                                                            N_src=tr.NPInp_inh,
                                                            N=tr.NPInp_inh_1n)

        elif tr.PInp_mode == 'indep':

            PInp_inh = PoissonGroup(tr.N_i,
                                    rates=tr.PInp_inh_rate,
                                    namespace=namespace)

            sPNInh = Synapses(target=GInh,
                              source=PInp_inh,
                              model=tr.poisson_mod,
                              on_pre='gfwd_post += a_EPoi',
                              namespace=namespace)

            sPNInh_src, sPNInh_tar = range(tr.N_i), range(tr.N_i)

        sPNInh.connect(i=sPNInh_src, j=sPNInh_tar)

        netw_objects.extend([PInp, sPN, PInp_inh, sPNInh])

    if tr.syn_noise:

        if tr.syn_noise_type == 'additive':
            synEE_mod = '''%s 
                           %s''' % (tr.synEE_noise_add, tr.synEE_mod)

            synEI_mod = '''%s 
                           %s''' % (tr.synEE_noise_add, tr.synEE_mod)

        elif tr.syn_noise_type == 'multiplicative':
            synEE_mod = '''%s 
                           %s''' % (tr.synEE_noise_mult, tr.synEE_mod)

            synEI_mod = '''%s 
                           %s''' % (tr.synEE_noise_mult, tr.synEE_mod)

    else:
        synEE_mod = '''%s 
                       %s''' % (tr.synEE_static, tr.synEE_mod)

        synEI_mod = '''%s 
                       %s''' % (tr.synEE_static, tr.synEE_mod)

    if tr.scl_active:
        synEE_mod = '''%s
                       %s''' % (synEE_mod, tr.synEE_scl_mod)
        synEI_mod = '''%s
                       %s''' % (synEI_mod, tr.synEI_scl_mod)

    if tr.syn_cond_mode == 'exp':
        synEE_pre_mod = mod.synEE_pre_exp
    elif tr.syn_cond_mode == 'alpha':
        synEE_pre_mod = mod.synEE_pre_alpha
    elif tr.syn_cond_mode == 'biexp':
        synEE_pre_mod = mod.synEE_pre_biexp

    synEE_post_mod = mod.syn_post

    if tr.stdp_active:
        synEE_pre_mod = '''%s 
                            %s''' % (synEE_pre_mod, mod.syn_pre_STDP)
        synEE_post_mod = '''%s 
                            %s''' % (synEE_post_mod, mod.syn_post_STDP)

    if tr.synEE_rec:
        synEE_pre_mod = '''%s 
                            %s''' % (synEE_pre_mod, mod.synEE_pre_rec)
        synEE_post_mod = '''%s 
                            %s''' % (synEE_post_mod, mod.synEE_post_rec)

    # E<-E advanced synapse model
    SynEE = Synapses(target=GExc,
                     source=GExc,
                     model=synEE_mod,
                     on_pre=synEE_pre_mod,
                     on_post=synEE_post_mod,
                     namespace=namespace,
                     dt=tr.synEE_mod_dt)

    if tr.istdp_active and tr.istdp_type == 'dbexp':

        if tr.syn_cond_mode_EI == 'exp':
            EI_pre_mod = mod.synEI_pre_exp
        elif tr.syn_cond_mode_EI == 'alpha':
            EI_pre_mod = mod.synEI_pre_alpha
        elif tr.syn_cond_mode_EI == 'biexp':
            EI_pre_mod = mod.synEI_pre_biexp

        synEI_pre_mod = '''%s 
                            %s''' % (EI_pre_mod, mod.syn_pre_STDP)
        synEI_post_mod = '''%s 
                            %s''' % (mod.syn_post, mod.syn_post_STDP)

    elif tr.istdp_active and tr.istdp_type == 'sym':

        if tr.syn_cond_mode_EI == 'exp':
            EI_pre_mod = mod.synEI_pre_sym_exp
        elif tr.syn_cond_mode_EI == 'alpha':
            EI_pre_mod = mod.synEI_pre_sym_alpha
        elif tr.syn_cond_mode_EI == 'biexp':
            EI_pre_mod = mod.synEI_pre_sym_biexp

        synEI_pre_mod = '''%s 
                            %s''' % (EI_pre_mod, mod.syn_pre_STDP)
        synEI_post_mod = '''%s 
                            %s''' % (mod.synEI_post_sym, mod.syn_post_STDP)

    if tr.istdp_active and tr.synEI_rec:

        synEI_pre_mod = '''%s 
                            %s''' % (synEI_pre_mod, mod.synEI_pre_rec)
        synEI_post_mod = '''%s 
                            %s''' % (synEI_post_mod, mod.synEI_post_rec)

    if tr.istdp_active:
        SynEI = Synapses(target=GExc,
                         source=GInh,
                         model=synEI_mod,
                         on_pre=synEI_pre_mod,
                         on_post=synEI_post_mod,
                         namespace=namespace,
                         dt=tr.synEE_mod_dt)

    else:
        model = '''a : 1
                   syn_active : 1'''
        SynEI = Synapses(target=GExc,
                         source=GInh,
                         model=model,
                         on_pre='gi_post += a',
                         namespace=namespace)

    #other simple
    SynIE = Synapses(target=GInh,
                     source=GExc,
                     on_pre='ge_post += a_ie',
                     namespace=namespace)

    SynII = Synapses(target=GInh,
                     source=GInh,
                     on_pre='gi_post += a_ii',
                     namespace=namespace)

    sEE_src, sEE_tar = generate_full_connectivity(tr.N_e, same=True)
    SynEE.connect(i=sEE_src, j=sEE_tar)
    SynEE.syn_active = 0
    SynEE.taupre, SynEE.taupost = tr.taupre, tr.taupost

    if tr.istdp_active and tr.istrct_active:
        print('istrct active')
        sEI_src, sEI_tar = generate_full_connectivity(Nsrc=tr.N_i,
                                                      Ntar=tr.N_e,
                                                      same=False)
        SynEI.connect(i=sEI_src, j=sEI_tar)
        SynEI.syn_active = 0

    else:
        print('istrct not active')
        if tr.weight_mode == 'init':
            sEI_src, sEI_tar = generate_connections(tr.N_e, tr.N_i, tr.p_ei)
            # print('Index Zero will not get inhibition')
            # sEI_src, sEI_tar = np.array(sEI_src), np.array(sEI_tar)
            # sEI_src, sEI_tar = sEI_src[sEI_tar > 0],sEI_tar[sEI_tar > 0]

        elif tr.weight_mode == 'load':

            fpath = os.path.join(tr.basepath, tr.weight_path)

            with open(fpath + 'synei_a.p', 'rb') as pfile:
                synei_a_init = pickle.load(pfile)

            sEI_src, sEI_tar = synei_a_init['i'], synei_a_init['j']

        SynEI.connect(i=sEI_src, j=sEI_tar)

    if tr.istdp_active:
        SynEI.taupre, SynEI.taupost = tr.taupre_EI, tr.taupost_EI

    sIE_src, sIE_tar = generate_connections(tr.N_i, tr.N_e, tr.p_ie)
    sII_src, sII_tar = generate_connections(tr.N_i, tr.N_i, tr.p_ii, same=True)

    SynIE.connect(i=sIE_src, j=sIE_tar)
    SynII.connect(i=sII_src, j=sII_tar)

    tr.f_add_result('sEE_src', sEE_src)
    tr.f_add_result('sEE_tar', sEE_tar)
    tr.f_add_result('sIE_src', sIE_src)
    tr.f_add_result('sIE_tar', sIE_tar)
    tr.f_add_result('sEI_src', sEI_src)
    tr.f_add_result('sEI_tar', sEI_tar)
    tr.f_add_result('sII_src', sII_src)
    tr.f_add_result('sII_tar', sII_tar)

    if tr.syn_noise:
        SynEE.syn_sigma = tr.syn_sigma
        SynEE.run_regularly('a = clip(a,0,amax)',
                            when='after_groups',
                            name='SynEE_noise_clipper')

    if tr.syn_noise and tr.istdp_active:
        SynEI.syn_sigma = tr.syn_sigma
        SynEI.run_regularly('a = clip(a,0,amax)',
                            when='after_groups',
                            name='SynEI_noise_clipper')

    SynEE.insert_P = tr.insert_P
    SynEE.p_inactivate = tr.p_inactivate
    SynEE.stdp_active = 1
    print('Setting maximum EE weight threshold to ', tr.amax)
    SynEE.amax = tr.amax

    if tr.istdp_active:
        SynEI.insert_P = tr.insert_P_ei
        SynEI.p_inactivate = tr.p_inactivate_ei
        SynEI.stdp_active = 1
        SynEI.amax = tr.amax

    SynEE.syn_active, SynEE.a = init_synapses('EE', tr)
    SynEI.syn_active, SynEI.a = init_synapses('EI', tr)

    # recording of stdp in T4
    SynEE.stdp_rec_start = tr.T1 + tr.T2 + tr.T3
    SynEE.stdp_rec_max = tr.T1 + tr.T2 + tr.T3 + tr.stdp_rec_T

    if tr.istdp_active:
        SynEI.stdp_rec_start = tr.T1 + tr.T2 + tr.T3
        SynEI.stdp_rec_max = tr.T1 + tr.T2 + tr.T3 + tr.stdp_rec_T

    # synaptic scaling
    if tr.netw.config.scl_active:

        if tr.syn_scl_rec:
            SynEE.scl_rec_start = tr.T1 + tr.T2 + tr.T3
            SynEE.scl_rec_max = tr.T1 + tr.T2 + tr.T3 + tr.scl_rec_T
        else:
            SynEE.scl_rec_start = T + 10 * second
            SynEE.scl_rec_max = T

        if tr.sig_ATotalMax == 0.:
            GExc.ANormTar = tr.ATotalMax
        else:
            GExc.ANormTar = np.random.normal(loc=tr.ATotalMax,
                                             scale=tr.sig_ATotalMax,
                                             size=tr.N_e)

        SynEE.summed_updaters['AsumEE_post']._clock = Clock(
            dt=tr.dt_synEE_scaling)
        synee_scaling = SynEE.run_regularly(tr.synEE_scaling,
                                            dt=tr.dt_synEE_scaling,
                                            when='end',
                                            name='synEE_scaling')

    if tr.istdp_active and tr.netw.config.iscl_active:

        if tr.syn_iscl_rec:
            SynEI.scl_rec_start = tr.T1 + tr.T2 + tr.T3
            SynEI.scl_rec_max = tr.T1 + tr.T2 + tr.T3 + tr.scl_rec_T
        else:
            SynEI.scl_rec_start = T + 10 * second
            SynEI.scl_rec_max = T

        if tr.sig_iATotalMax == 0.:
            GExc.iANormTar = tr.iATotalMax
        else:
            GExc.iANormTar = np.random.normal(loc=tr.iATotalMax,
                                              scale=tr.sig_iATotalMax,
                                              size=tr.N_e)

        SynEI.summed_updaters['AsumEI_post']._clock = Clock(
            dt=tr.dt_synEE_scaling)

        synei_scaling = SynEI.run_regularly(tr.synEI_scaling,
                                            dt=tr.dt_synEE_scaling,
                                            when='end',
                                            name='synEI_scaling')

    # # intrinsic plasticity
    # if tr.netw.config.it_active:
    #     GExc.h_ip = tr.h_ip
    #     GExc.run_regularly(tr.intrinsic_mod, dt = tr.it_dt, when='end')

    # structural plasticity
    if tr.netw.config.strct_active:
        if tr.strct_mode == 'zero':
            if tr.turnover_rec:
                strct_mod = '''%s 
                                %s''' % (tr.strct_mod, tr.turnover_rec_mod)
            else:
                strct_mod = tr.strct_mod

            strctplst = SynEE.run_regularly(strct_mod,
                                            dt=tr.strct_dt,
                                            when='end',
                                            name='strct_plst_zero')

        elif tr.strct_mode == 'thrs':
            if tr.turnover_rec:
                strct_mod_thrs = '''%s 
                                %s''' % (tr.strct_mod_thrs,
                                         tr.turnover_rec_mod)
            else:
                strct_mod_thrs = tr.strct_mod_thrs

            strctplst = SynEE.run_regularly(strct_mod_thrs,
                                            dt=tr.strct_dt,
                                            when='end',
                                            name='strct_plst_thrs')

    if tr.istdp_active and tr.netw.config.istrct_active:
        if tr.strct_mode == 'zero':
            if tr.turnover_rec:
                strct_mod_EI = '''%s 
                                   %s''' % (tr.strct_mod,
                                            tr.turnoverEI_rec_mod)
            else:
                strct_mod_EI = tr.strct_mod

            strctplst_EI = SynEI.run_regularly(strct_mod_EI,
                                               dt=tr.strct_dt,
                                               when='end',
                                               name='strct_plst_EI')

        elif tr.strct_mode == 'thrs':
            raise NotImplementedError

    netw_objects.extend([SynEE, SynEI, SynIE, SynII])

    # keep track of the number of active synapses
    sum_target = NeuronGroup(1, 'c : 1 (shared)', dt=tr.csample_dt)

    sum_model = '''NSyn : 1 (constant)
                   c_post = (1.0*syn_active_pre)/NSyn : 1 (summed)'''
    sum_connection = Synapses(target=sum_target,
                              source=SynEE,
                              model=sum_model,
                              dt=tr.csample_dt,
                              name='get_active_synapse_count')
    sum_connection.connect()
    sum_connection.NSyn = tr.N_e * (tr.N_e - 1)

    if tr.adjust_insertP:
        # homeostatically adjust growth rate
        growth_updater = Synapses(sum_target, SynEE)
        growth_updater.run_regularly('insert_P_post *= 0.1/c_pre',
                                     when='after_groups',
                                     dt=tr.csample_dt,
                                     name='update_insP')
        growth_updater.connect(j='0')

        netw_objects.extend([sum_target, sum_connection, growth_updater])

    if tr.istdp_active and tr.istrct_active:

        # keep track of the number of active synapses
        sum_target_EI = NeuronGroup(1, 'c : 1 (shared)', dt=tr.csample_dt)

        sum_model_EI = '''NSyn : 1 (constant)
                          c_post = (1.0*syn_active_pre)/NSyn : 1 (summed)'''
        sum_connection_EI = Synapses(target=sum_target_EI,
                                     source=SynEI,
                                     model=sum_model_EI,
                                     dt=tr.csample_dt,
                                     name='get_active_synapse_count_EI')
        sum_connection_EI.connect()
        sum_connection_EI.NSyn = tr.N_e * tr.N_i

        if tr.adjust_EI_insertP:
            # homeostatically adjust growth rate
            growth_updater_EI = Synapses(sum_target_EI, SynEI)
            growth_updater_EI.run_regularly('insert_P_post *= 0.1/c_pre',
                                            when='after_groups',
                                            dt=tr.csample_dt,
                                            name='update_insP_EI')
            growth_updater_EI.connect(j='0')

            netw_objects.extend(
                [sum_target_EI, sum_connection_EI, growth_updater_EI])

    # -------------- recording ------------------

    GExc_recvars = []
    if tr.memtraces_rec:
        GExc_recvars.append('V')
    if tr.vttraces_rec:
        GExc_recvars.append('Vt')
    if tr.getraces_rec:
        GExc_recvars.append('ge')
    if tr.gitraces_rec:
        GExc_recvars.append('gi')
    if tr.gfwdtraces_rec and tr.external_mode == 'poisson':
        GExc_recvars.append('gfwd')

    GInh_recvars = GExc_recvars

    GExc_stat = StateMonitor(GExc,
                             GExc_recvars,
                             record=list(range(tr.nrec_GExc_stat)),
                             dt=tr.GExc_stat_dt)
    GInh_stat = StateMonitor(GInh,
                             GInh_recvars,
                             record=list(range(tr.nrec_GInh_stat)),
                             dt=tr.GInh_stat_dt)

    # SynEE stat
    SynEE_recvars = []
    if tr.synee_atraces_rec:
        SynEE_recvars.append('a')
    if tr.synee_activetraces_rec:
        SynEE_recvars.append('syn_active')
    if tr.synee_Apretraces_rec:
        SynEE_recvars.append('Apre')
    if tr.synee_Aposttraces_rec:
        SynEE_recvars.append('Apost')

    SynEE_stat = StateMonitor(SynEE,
                              SynEE_recvars,
                              record=range(tr.n_synee_traces_rec),
                              when='end',
                              dt=tr.synEE_stat_dt)

    if tr.istdp_active:
        # SynEI stat
        SynEI_recvars = []
        if tr.synei_atraces_rec:
            SynEI_recvars.append('a')
        if tr.synei_activetraces_rec:
            SynEI_recvars.append('syn_active')
        if tr.synei_Apretraces_rec:
            SynEI_recvars.append('Apre')
        if tr.synei_Aposttraces_rec:
            SynEI_recvars.append('Apost')

        SynEI_stat = StateMonitor(SynEI,
                                  SynEI_recvars,
                                  record=range(tr.n_synei_traces_rec),
                                  when='end',
                                  dt=tr.synEI_stat_dt)
        netw_objects.append(SynEI_stat)

    if tr.adjust_insertP:

        C_stat = StateMonitor(sum_target,
                              'c',
                              dt=tr.csample_dt,
                              record=[0],
                              when='end')
        insP_stat = StateMonitor(SynEE,
                                 'insert_P',
                                 dt=tr.csample_dt,
                                 record=[0],
                                 when='end')
        netw_objects.extend([C_stat, insP_stat])

    if tr.istdp_active and tr.adjust_EI_insertP:

        C_EI_stat = StateMonitor(sum_target_EI,
                                 'c',
                                 dt=tr.csample_dt,
                                 record=[0],
                                 when='end')
        insP_EI_stat = StateMonitor(SynEI,
                                    'insert_P',
                                    dt=tr.csample_dt,
                                    record=[0],
                                    when='end')
        netw_objects.extend([C_EI_stat, insP_EI_stat])

    GExc_spks = SpikeMonitor(GExc)
    GInh_spks = SpikeMonitor(GInh)

    GExc_rate = PopulationRateMonitor(GExc)
    GInh_rate = PopulationRateMonitor(GInh)

    if tr.external_mode == 'poisson':
        PInp_spks = SpikeMonitor(PInp)
        PInp_rate = PopulationRateMonitor(PInp)
        netw_objects.extend([PInp_spks, PInp_rate])

    if tr.synee_a_nrecpoints == 0 or tr.sim.T2 == 0 * second:
        SynEE_a_dt = 2 * (tr.T1 + tr.T2 + tr.T3 + tr.T4 + tr.T5)
    else:
        SynEE_a_dt = tr.sim.T2 / tr.synee_a_nrecpoints

        # make sure that choice of SynEE_a_dt does lead
        # to execessively many recordings - this can
        # happen if t1 >> t2.
        estm_nrecs = int(T / SynEE_a_dt)
        if estm_nrecs > 3 * tr.synee_a_nrecpoints:
            print('''Estimated number of EE weight recordings (%d)
            exceeds desired number (%d), increasing 
            SynEE_a_dt''' % (estm_nrecs, tr.synee_a_nrecpoints))

            SynEE_a_dt = T / tr.synee_a_nrecpoints

    SynEE_a = StateMonitor(SynEE, ['a', 'syn_active'],
                           record=range(tr.N_e * (tr.N_e - 1)),
                           dt=SynEE_a_dt,
                           when='end',
                           order=100)

    if tr.istrct_active:
        record_range = range(tr.N_e * tr.N_i)
    else:
        record_range = range(len(sEI_src))

    if tr.synei_a_nrecpoints > 0 and tr.sim.T2 > 0 * second:
        SynEI_a_dt = tr.sim.T2 / tr.synei_a_nrecpoints

        estm_nrecs = int(T / SynEI_a_dt)
        if estm_nrecs > 3 * tr.synei_a_nrecpoints:
            print('''Estimated number of EI weight recordings
            (%d) exceeds desired number (%d), increasing 
            SynEI_a_dt''' % (estm_nrecs, tr.synei_a_nrecpoints))

            SynEI_a_dt = T / tr.synei_a_nrecpoints

        SynEI_a = StateMonitor(SynEI, ['a', 'syn_active'],
                               record=record_range,
                               dt=SynEI_a_dt,
                               when='end',
                               order=100)

        netw_objects.append(SynEI_a)

    netw_objects.extend([
        GExc_stat, GInh_stat, SynEE_stat, SynEE_a, GExc_spks, GInh_spks,
        GExc_rate, GInh_rate
    ])

    if (tr.synEEdynrec
            and (2 * tr.syndynrec_npts * tr.syndynrec_dt < tr.sim.T2)):
        SynEE_dynrec = StateMonitor(SynEE, ['a'],
                                    record=range(tr.N_e * (tr.N_e - 1)),
                                    dt=tr.syndynrec_dt,
                                    name='SynEE_dynrec',
                                    when='end',
                                    order=100)
        SynEE_dynrec.active = False
        netw_objects.extend([SynEE_dynrec])

    if (tr.synEIdynrec
            and (2 * tr.syndynrec_npts * tr.syndynrec_dt < tr.sim.T2)):
        SynEI_dynrec = StateMonitor(SynEI, ['a'],
                                    record=record_range,
                                    dt=tr.syndynrec_dt,
                                    name='SynEI_dynrec',
                                    when='end',
                                    order=100)
        SynEI_dynrec.active = False
        netw_objects.extend([SynEI_dynrec])

    net = Network(*netw_objects)

    def set_active(*argv):
        for net_object in argv:
            net_object.active = True

    def set_inactive(*argv):
        for net_object in argv:
            net_object.active = False

    ### Simulation periods

    # --------- T1 ---------
    # initial recording period,
    # all recorders active

    T1T3_recorders = [
        GExc_spks, GInh_spks, SynEE_stat, GExc_stat, GInh_stat, GExc_rate,
        GInh_rate
    ]

    if tr.istdp_active:
        T1T3_recorders.append(SynEI_stat)

    set_active(*T1T3_recorders)

    if tr.external_mode == 'poisson':
        set_active(PInp_spks, PInp_rate)

    net.run(tr.sim.T1, report='text', report_period=300 * second, profile=True)

    # --------- T2 ---------
    # main simulation period
    # only active recordings are:
    #   1) turnover 2) C_stat 3) SynEE_a

    set_inactive(*T1T3_recorders)

    if tr.T2_spks_rec:
        set_active(GExc_spks, GInh_spks)

    if tr.external_mode == 'poisson':
        set_inactive(PInp_spks, PInp_rate)

    run_T2_syndynrec(net, tr, netw_objects)

    # --------- T3 ---------
    # second recording period,
    # all recorders active

    set_active(*T1T3_recorders)

    if tr.external_mode == 'poisson':
        set_active(PInp_spks, PInp_rate)

    run_T3_split(net, tr)

    # --------- T4 ---------
    # record STDP and scaling weight changes to file
    # through the cpp models

    set_inactive(*T1T3_recorders)

    if tr.external_mode == 'poisson':
        set_inactive(PInp_spks, PInp_rate)

    run_T4(net, tr)

    # --------- T5 ---------
    # freeze network and record Exc spikes
    # for cross correlations

    if tr.scl_active:
        synee_scaling.active = False
    if tr.istdp_active and tr.netw.config.iscl_active:
        synei_scaling.active = False
    if tr.strct_active:
        strctplst.active = False
    if tr.istdp_active and tr.istrct_active:
        strctplst_EI.active = False
    SynEE.stdp_active = 0
    if tr.istdp_active:
        SynEI.stdp_active = 0

    set_active(GExc_rate, GInh_rate)
    set_active(GExc_spks, GInh_spks)

    run_T5(net, tr)

    SynEE_a.record_single_timestep()
    if tr.synei_a_nrecpoints > 0 and tr.sim.T2 > 0. * second:
        SynEI_a.record_single_timestep()

    device.build(directory='builds/%.4d' % (tr.v_idx),
                 clean=True,
                 compile=True,
                 run=True,
                 debug=False)

    # -----------------------------------------

    # save monitors as raws in build directory
    raw_dir = 'builds/%.4d/raw/' % (tr.v_idx)

    if not os.path.exists(raw_dir):
        os.makedirs(raw_dir)

    with open(raw_dir + 'namespace.p', 'wb') as pfile:
        pickle.dump(namespace, pfile)

    with open(raw_dir + 'gexc_stat.p', 'wb') as pfile:
        pickle.dump(GExc_stat.get_states(), pfile)
    with open(raw_dir + 'ginh_stat.p', 'wb') as pfile:
        pickle.dump(GInh_stat.get_states(), pfile)

    with open(raw_dir + 'synee_stat.p', 'wb') as pfile:
        pickle.dump(SynEE_stat.get_states(), pfile)

    if tr.istdp_active:
        with open(raw_dir + 'synei_stat.p', 'wb') as pfile:
            pickle.dump(SynEI_stat.get_states(), pfile)

    if ((tr.synEEdynrec or tr.synEIdynrec)
            and (2 * tr.syndynrec_npts * tr.syndynrec_dt < tr.sim.T2)):

        if tr.synEEdynrec:
            with open(raw_dir + 'syneedynrec.p', 'wb') as pfile:
                pickle.dump(SynEE_dynrec.get_states(), pfile)
        if tr.synEIdynrec:
            with open(raw_dir + 'syneidynrec.p', 'wb') as pfile:
                pickle.dump(SynEI_dynrec.get_states(), pfile)

    with open(raw_dir + 'synee_a.p', 'wb') as pfile:
        SynEE_a_states = SynEE_a.get_states()
        if tr.crs_crrs_rec:
            SynEE_a_states['i'] = list(SynEE.i)
            SynEE_a_states['j'] = list(SynEE.j)
        pickle.dump(SynEE_a_states, pfile)

    if tr.synei_a_nrecpoints > 0 and tr.sim.T2 > 0. * second:
        with open(raw_dir + 'synei_a.p', 'wb') as pfile:
            SynEI_a_states = SynEI_a.get_states()
            if tr.crs_crrs_rec:
                SynEI_a_states['i'] = list(SynEI.i)
                SynEI_a_states['j'] = list(SynEI.j)
            pickle.dump(SynEI_a_states, pfile)

    if tr.adjust_insertP:
        with open(raw_dir + 'c_stat.p', 'wb') as pfile:
            pickle.dump(C_stat.get_states(), pfile)

        with open(raw_dir + 'insP_stat.p', 'wb') as pfile:
            pickle.dump(insP_stat.get_states(), pfile)

    if tr.istdp_active and tr.adjust_EI_insertP:
        with open(raw_dir + 'c_EI_stat.p', 'wb') as pfile:
            pickle.dump(C_EI_stat.get_states(), pfile)

        with open(raw_dir + 'insP_EI_stat.p', 'wb') as pfile:
            pickle.dump(insP_EI_stat.get_states(), pfile)

    with open(raw_dir + 'gexc_spks.p', 'wb') as pfile:
        pickle.dump(GExc_spks.get_states(), pfile)
    with open(raw_dir + 'ginh_spks.p', 'wb') as pfile:
        pickle.dump(GInh_spks.get_states(), pfile)

    if tr.external_mode == 'poisson':
        with open(raw_dir + 'pinp_spks.p', 'wb') as pfile:
            pickle.dump(PInp_spks.get_states(), pfile)

    with open(raw_dir + 'gexc_rate.p', 'wb') as pfile:
        pickle.dump(GExc_rate.get_states(), pfile)
        if tr.rates_rec:
            pickle.dump(GExc_rate.smooth_rate(width=25 * ms), pfile)
    with open(raw_dir + 'ginh_rate.p', 'wb') as pfile:
        pickle.dump(GInh_rate.get_states(), pfile)
        if tr.rates_rec:
            pickle.dump(GInh_rate.smooth_rate(width=25 * ms), pfile)

    if tr.external_mode == 'poisson':
        with open(raw_dir + 'pinp_rate.p', 'wb') as pfile:
            pickle.dump(PInp_rate.get_states(), pfile)
            if tr.rates_rec:
                pickle.dump(PInp_rate.smooth_rate(width=25 * ms), pfile)

    # ----------------- add raw data ------------------------
    fpath = 'builds/%.4d/' % (tr.v_idx)

    from pathlib import Path

    Path(fpath + 'turnover').touch()
    turnover_data = np.genfromtxt(fpath + 'turnover', delimiter=',')
    os.remove(fpath + 'turnover')

    with open(raw_dir + 'turnover.p', 'wb') as pfile:
        pickle.dump(turnover_data, pfile)

    Path(fpath + 'turnover_EI').touch()
    turnover_EI_data = np.genfromtxt(fpath + 'turnover_EI', delimiter=',')
    os.remove(fpath + 'turnover_EI')

    with open(raw_dir + 'turnover_EI.p', 'wb') as pfile:
        pickle.dump(turnover_EI_data, pfile)

    Path(fpath + 'spk_register').touch()
    spk_register_data = np.genfromtxt(fpath + 'spk_register', delimiter=',')
    os.remove(fpath + 'spk_register')

    with open(raw_dir + 'spk_register.p', 'wb') as pfile:
        pickle.dump(spk_register_data, pfile)

    Path(fpath + 'spk_register_EI').touch()
    spk_register_EI_data = np.genfromtxt(fpath + 'spk_register_EI',
                                         delimiter=',')
    os.remove(fpath + 'spk_register_EI')

    with open(raw_dir + 'spk_register_EI.p', 'wb') as pfile:
        pickle.dump(spk_register_EI_data, pfile)

    Path(fpath + 'scaling_deltas').touch()
    scaling_deltas_data = np.genfromtxt(fpath + 'scaling_deltas',
                                        delimiter=',')
    os.remove(fpath + 'scaling_deltas')

    with open(raw_dir + 'scaling_deltas.p', 'wb') as pfile:
        pickle.dump(scaling_deltas_data, pfile)

    Path(fpath + 'scaling_deltas_EI').touch()
    scaling_deltas_data = np.genfromtxt(fpath + 'scaling_deltas_EI',
                                        delimiter=',')
    os.remove(fpath + 'scaling_deltas_EI')

    with open(raw_dir + 'scaling_deltas_EI.p', 'wb') as pfile:
        pickle.dump(scaling_deltas_data, pfile)

    with open(raw_dir + 'profiling_summary.txt', 'w+') as tfile:
        tfile.write(str(profiling_summary(net)))

    # --------------- cross-correlations ---------------------

    if tr.crs_crrs_rec:

        GExc_spks = GExc_spks.get_states()
        synee_a = SynEE_a_states
        wsize = 100 * pq.ms

        for binsize in [1 * pq.ms, 2 * pq.ms, 5 * pq.ms]:

            wlen = int(wsize / binsize)

            ts, idxs = GExc_spks['t'], GExc_spks['i']
            idxs = idxs[ts > tr.T1 + tr.T2 + tr.T3 + tr.T4]
            ts = ts[ts > tr.T1 + tr.T2 + tr.T3 + tr.T4]
            ts = ts - (tr.T1 + tr.T2 + tr.T3 + tr.T4)

            sts = [
                neo.SpikeTrain(ts[idxs == i] / second * pq.s,
                               t_stop=tr.T5 / second * pq.s)
                for i in range(tr.N_e)
            ]

            crs_crrs, syn_a = [], []

            for f, (i, j) in enumerate(zip(synee_a['i'], synee_a['j'])):
                if synee_a['syn_active'][-1][f] == 1:

                    crs_crr, cbin = cch(BinnedSpikeTrain(sts[i],
                                                         binsize=binsize),
                                        BinnedSpikeTrain(sts[j],
                                                         binsize=binsize),
                                        cross_corr_coef=True,
                                        border_correction=True,
                                        window=(-1 * wlen, wlen))

                    crs_crrs.append(list(np.array(crs_crr).T[0]))
                    syn_a.append(synee_a['a'][-1][f])

            fname = 'crs_crrs_wsize%dms_binsize%fms_full' % (wsize / pq.ms,
                                                             binsize / pq.ms)

            df = {
                'cbin': cbin,
                'crs_crrs': np.array(crs_crrs),
                'syn_a': np.array(syn_a),
                'binsize': binsize,
                'wsize': wsize,
                'wlen': wlen
            }

            with open('builds/%.4d/raw/' % (tr.v_idx) + fname + '.p',
                      'wb') as pfile:
                pickle.dump(df, pfile)

        GInh_spks = GInh_spks.get_states()
        synei_a = SynEI_a_states
        wsize = 100 * pq.ms

        for binsize in [1 * pq.ms, 2 * pq.ms, 5 * pq.ms]:

            wlen = int(wsize / binsize)

            ts_E, idxs_E = GExc_spks['t'], GExc_spks['i']
            idxs_E = idxs_E[ts_E > tr.T1 + tr.T2 + tr.T3 + tr.T4]
            ts_E = ts_E[ts_E > tr.T1 + tr.T2 + tr.T3 + tr.T4]
            ts_E = ts_E - (tr.T1 + tr.T2 + tr.T3 + tr.T4)

            ts_I, idxs_I = GInh_spks['t'], GInh_spks['i']
            idxs_I = idxs_I[ts_I > tr.T1 + tr.T2 + tr.T3 + tr.T4]
            ts_I = ts_I[ts_I > tr.T1 + tr.T2 + tr.T3 + tr.T4]
            ts_I = ts_I - (tr.T1 + tr.T2 + tr.T3 + tr.T4)

            sts_E = [
                neo.SpikeTrain(ts_E[idxs_E == i] / second * pq.s,
                               t_stop=tr.T5 / second * pq.s)
                for i in range(tr.N_e)
            ]

            sts_I = [
                neo.SpikeTrain(ts_I[idxs_I == i] / second * pq.s,
                               t_stop=tr.T5 / second * pq.s)
                for i in range(tr.N_i)
            ]

            crs_crrs, syn_a = [], []

            for f, (i, j) in enumerate(zip(synei_a['i'], synei_a['j'])):
                if synei_a['syn_active'][-1][f] == 1:

                    crs_crr, cbin = cch(BinnedSpikeTrain(sts_I[i],
                                                         binsize=binsize),
                                        BinnedSpikeTrain(sts_E[j],
                                                         binsize=binsize),
                                        cross_corr_coef=True,
                                        border_correction=True,
                                        window=(-1 * wlen, wlen))

                    crs_crrs.append(list(np.array(crs_crr).T[0]))
                    syn_a.append(synei_a['a'][-1][f])

            fname = 'EI_crrs_wsize%dms_binsize%fms_full' % (wsize / pq.ms,
                                                            binsize / pq.ms)

            df = {
                'cbin': cbin,
                'crs_crrs': np.array(crs_crrs),
                'syn_a': np.array(syn_a),
                'binsize': binsize,
                'wsize': wsize,
                'wlen': wlen
            }

            with open('builds/%.4d/raw/' % (tr.v_idx) + fname + '.p',
                      'wb') as pfile:
                pickle.dump(df, pfile)

    # -----------------  clean up  ---------------------------
    shutil.rmtree('builds/%.4d/results/' % (tr.v_idx))
    shutil.rmtree('builds/%.4d/static_arrays/' % (tr.v_idx))
    shutil.rmtree('builds/%.4d/brianlib/' % (tr.v_idx))
    shutil.rmtree('builds/%.4d/code_objects/' % (tr.v_idx))

    # ---------------- plot results --------------------------

    #os.chdir('./analysis/file_based/')

    if tr.istdp_active:
        from src.analysis.overview_winh import overview_figure
        overview_figure('builds/%.4d' % (tr.v_idx), namespace)
    else:
        from src.analysis.overview import overview_figure
        overview_figure('builds/%.4d' % (tr.v_idx), namespace)

    from src.analysis.synw_fb import synw_figure
    synw_figure('builds/%.4d' % (tr.v_idx), namespace)
    if tr.istdp_active:
        synw_figure('builds/%.4d' % (tr.v_idx), namespace, connections='EI')

    from src.analysis.synw_log_fb import synw_log_figure
    synw_log_figure('builds/%.4d' % (tr.v_idx), namespace)
    if tr.istdp_active:
        synw_log_figure('builds/%.4d' % (tr.v_idx),
                        namespace,
                        connections='EI')
コード例 #9
0
for dta, sts in zip(['spinnaker', 'nest'], [sts_spinnaker, sts_nest]):
    for calc_i in range(task_starts_idx[job_parameter],
                        task_stop_idx[job_parameter]):
        # save neuron i,j index
        ni = all_combos_unit_i[calc_i]
        nj = all_combos_unit_j[calc_i]

        cc[dta]['unit_i'][calc_i] = ni
        cc[dta]['unit_j'][calc_i] = nj

        print("Cross-correlating %i and %i" % (ni, nj))

        # original CCH
        cco = stc.cch(conv.BinnedSpikeTrain(sts[ni], lag_res),
                      conv.BinnedSpikeTrain(sts[nj], lag_res),
                      window=[-max_lag_bins, max_lag_bins])[0]
        cc[dta]['original'][calc_i] = cco.magnitude
        cc[dta]['times_ms'][calc_i] = cco.times.rescale(pq.ms).magnitude

        # extract measure
        ccom = cch_measure(cco)
        cc[dta]['original_measure'][calc_i] = ccom

        surr_i = elephant.spike_train_surrogates.dither_spikes(sts[ni],
                                                               dither=50. *
                                                               pq.ms,
                                                               n=num_surrs)
        surr_j = elephant.spike_train_surrogates.dither_spikes(sts[nj],
                                                               dither=50. *
                                                               pq.ms,
コード例 #10
0
    def test_window(self):
        '''Test if the window parameter is correctly interpreted.'''
        cch_win, bin_ids = sc.cch(self.binned_st1,
                                  self.binned_st2,
                                  window=[-30, 30])
        cch_win_mem, bin_ids_mem = sc.cch(self.binned_st1,
                                          self.binned_st2,
                                          window=[-30, 30])

        assert_array_equal(bin_ids, np.arange(-30, 31, 1))
        assert_array_equal((bin_ids - 0.5) * self.binned_st1.binsize,
                           cch_win.times)

        assert_array_equal(bin_ids_mem, np.arange(-30, 31, 1))
        assert_array_equal((bin_ids_mem - 0.5) * self.binned_st1.binsize,
                           cch_win.times)

        assert_array_equal(cch_win, cch_win_mem)
        cch_unclipped, _ = sc.cross_correlation_histogram(self.binned_st1,
                                                          self.binned_st2,
                                                          window='full',
                                                          binary=False)
        assert_array_equal(cch_win, cch_unclipped[19:80])

        cch_win, bin_ids = sc.cch(self.binned_st1,
                                  self.binned_st2,
                                  window=[-25 * pq.ms, 25 * pq.ms])
        cch_win_mem, bin_ids_mem = sc.cch(self.binned_st1,
                                          self.binned_st2,
                                          window=[-25 * pq.ms, 25 * pq.ms],
                                          method='memory')

        assert_array_equal(bin_ids, np.arange(-25, 26, 1))
        assert_array_equal((bin_ids - 0.5) * self.binned_st1.binsize,
                           cch_win.times)

        assert_array_equal(bin_ids_mem, np.arange(-25, 26, 1))
        assert_array_equal((bin_ids_mem - 0.5) * self.binned_st1.binsize,
                           cch_win.times)

        assert_array_equal(cch_win, cch_win_mem)

        _, bin_ids = sc.cch(self.binned_st1, self.binned_st2, window=[20, 30])
        _, bin_ids_mem = sc.cch(self.binned_st1,
                                self.binned_st2,
                                window=[20, 30],
                                method='memory')

        assert_array_equal(bin_ids, np.arange(20, 31, 1))
        assert_array_equal(bin_ids_mem, np.arange(20, 31, 1))

        _, bin_ids = sc.cch(self.binned_st1,
                            self.binned_st2,
                            window=[-30, -20])

        _, bin_ids_mem = sc.cch(self.binned_st1,
                                self.binned_st2,
                                window=[-30, -20],
                                method='memory')

        assert_array_equal(bin_ids, np.arange(-30, -19, 1))
        assert_array_equal(bin_ids_mem, np.arange(-30, -19, 1))

        # Cehck for wrong assignments to the window parameter
        self.assertRaises(ValueError,
                          sc.cross_correlation_histogram,
                          self.binned_st1,
                          self.binned_st2,
                          window=[-60, 50])
        self.assertRaises(ValueError,
                          sc.cross_correlation_histogram,
                          self.binned_st1,
                          self.binned_st2,
                          window=[-60, 50],
                          method='memory')

        self.assertRaises(ValueError,
                          sc.cross_correlation_histogram,
                          self.binned_st1,
                          self.binned_st2,
                          window=[-50, 60])
        self.assertRaises(ValueError,
                          sc.cross_correlation_histogram,
                          self.binned_st1,
                          self.binned_st2,
                          window=[-50, 60],
                          method='memory')

        self.assertRaises(ValueError,
                          sc.cross_correlation_histogram,
                          self.binned_st1,
                          self.binned_st2,
                          window=[-25.5 * pq.ms, 25 * pq.ms])
        self.assertRaises(ValueError,
                          sc.cross_correlation_histogram,
                          self.binned_st1,
                          self.binned_st2,
                          window=[-25.5 * pq.ms, 25 * pq.ms],
                          method='memory')

        self.assertRaises(ValueError,
                          sc.cross_correlation_histogram,
                          self.binned_st1,
                          self.binned_st2,
                          window=[-25 * pq.ms, 25.5 * pq.ms])
        self.assertRaises(ValueError,
                          sc.cross_correlation_histogram,
                          self.binned_st1,
                          self.binned_st2,
                          window=[-25 * pq.ms, 25.5 * pq.ms],
                          method='memory')

        self.assertRaises(ValueError,
                          sc.cross_correlation_histogram,
                          self.binned_st1,
                          self.binned_st2,
                          window=[-60 * pq.ms, 50 * pq.ms])
        self.assertRaises(ValueError,
                          sc.cross_correlation_histogram,
                          self.binned_st1,
                          self.binned_st2,
                          window=[-60 * pq.ms, 50 * pq.ms],
                          method='memory')

        self.assertRaises(ValueError,
                          sc.cross_correlation_histogram,
                          self.binned_st1,
                          self.binned_st2,
                          window=[-50 * pq.ms, 60 * pq.ms])
        self.assertRaises(ValueError,
                          sc.cross_correlation_histogram,
                          self.binned_st1,
                          self.binned_st2,
                          window=[-50 * pq.ms, 60 * pq.ms],
                          method='memory')