示例#1
0
    def _store(self):
        """Build the dictionary of items to store for this result.

        Entries of `self._data` that are `Quantity` instances are wrapped in
        an `ObjectTable` and stored under `key + BrianResult.IDENTIFIER`. The
        table layout depends on `self._storage_mode`:

        * `BrianResult.STRING_MODE`: one `'data'` column holding the output
          of `in_best_unit(python_code=True)`.
        * `BrianResult.FLOAT_MODE`: a `'value'` column with `float(entry)`
          and a `'unit'` column with `repr(get_unit_fast(entry))`.

        Every other entry is passed through unchanged under its own key.

        :return: Dictionary mapping storage keys to the objects to store.
        :raises RuntimeError: If a `Quantity` is encountered while the
            storage mode is neither `STRING_MODE` nor `FLOAT_MODE`.
        """
        items_to_store = {}

        for name in self._data:
            entry = self._data[name]

            # Plain (unit-less) entries need no conversion.
            if not isinstance(entry, Quantity):
                items_to_store[name] = entry
                continue

            if self._storage_mode == BrianResult.STRING_MODE:
                as_string = entry.in_best_unit(python_code=True)
                table = ObjectTable(data={'data': [as_string]})
            elif self._storage_mode == BrianResult.FLOAT_MODE:
                unit_repr = repr(get_unit_fast(entry))
                magnitude = float(entry)
                table = ObjectTable(
                    data={'value': [magnitude], 'unit': [unit_repr]})
            else:
                raise RuntimeError('You shall not pass!')

            # The identifier suffix marks the entry as a converted quantity.
            items_to_store[name + BrianResult.IDENTIFIER] = table

        return items_to_store
示例#2
0
    def check_state_spike_monitor(self, res, monitor):
        """Assert that the result `res` mirrors the data held by `monitor`.

        Compares delay, spike count, source (as string) and variable names,
        then the individual spike records according to `res.v_storage_mode`,
        and finally the unit string stored for every recorded variable.

        :param res: Result object to verify.
        :param monitor: Monitor whose data `res` is expected to contain.
        :raises RuntimeError: If `res.v_storage_mode` is neither
            `BrianMonitorResult.TABLE_MODE` nor
            `BrianMonitorResult.ARRAY_MODE`.
        """
        self.assertTrue(comp.nested_equal(monitor.delay, res.delay))
        self.assertTrue(comp.nested_equal(monitor.nspikes, res.nspikes))
        self.assertTrue(comp.nested_equal(str(monitor.source), res.source))
        self.assertTrue(comp.nested_equal(monitor._varnames, res.varnames))

        self.assertEqual('second', res.spiketimes_unit)

        if res.v_storage_mode == BrianMonitorResult.TABLE_MODE:
            # All spikes live in a single frame, one row per spike.
            spike_frame = res.spikes
            spiked_list = sorted(
                list(set(spike_frame['neuron'].to_dict().values())))
            self.assertEqual(spiked_list, res.neurons_with_spikes)

            for idx, val_tuple in enumerate(monitor.spikes):
                # Each spike record is (neuron, time, var_0, var_1, ...).
                neuron = val_tuple[0]
                time = val_tuple[1]
                vals = val_tuple[2:]

                self.assertEqual(neuron, spike_frame['neuron'][idx])
                self.assertEqual(float(time), spike_frame['spiketimes'][idx])

                for idx_var, varname in enumerate(res.varnames):
                    val = vals[idx_var]
                    self.assertEqual(float(val), spike_frame[varname][idx])

        elif res.v_storage_mode == BrianMonitorResult.ARRAY_MODE:

            self.assertTrue('%0' in res.format_string and
                            'd' in res.format_string)

            spiked_set = set()
            for item_name in res:

                if (item_name.startswith('spiketimes') and
                        not item_name.endswith('unit')):
                    # The neuron index is the last '_'-separated token.
                    neuron_id = int(item_name.split('_')[-1])
                    spiked_set.add(neuron_id)

                    times = monitor.times(neuron_id)
                    self.assertTrue(comp.nested_equal(times, res[item_name]))

                for varname in res.varnames:
                    if (item_name.startswith(varname) and
                            not item_name.endswith('unit')):
                        neuron_id = int(item_name.split('_')[-1])
                        values = monitor.values(varname, neuron_id)

                        # Compare the recorded values with the stored item.
                        self.assertTrue(
                            comp.nested_equal(values, res[item_name]))

            spiked_list = sorted(list(spiked_set))
            self.assertEqual(spiked_list, res.neurons_with_spikes)
        else:
            raise RuntimeError('You shall not pass!')

        # Check units. `assertEqual` is required: `assertTrue(a, b)` treats
        # `b` as the failure message and never compares the two values.
        for idx, varname in enumerate(monitor._varnames):
            unit = repr(get_unit_fast(monitor.spikes[0][idx + 2]))
            self.assertEqual(unit, res[varname + '_unit'])
示例#3
0
    def _store(self):
        """Build the dictionary of items to store for this parameter.

        If `self._data` is not a `Quantity`, storing is delegated to the
        parent class. Otherwise the data is wrapped in an `ObjectTable`
        under `'data' + BrianParameter.IDENTIFIER` and, if `f_has_range()`
        is true, the explored range is stored under
        `'explored_data' + BrianParameter.IDENTIFIER`.

        * `BrianParameter.STRING_MODE`: values are stored as the strings
          returned by `in_best_unit(python_code=True)` in a `'data'` column.
        * `BrianParameter.FLOAT_MODE`: values are stored as floats in a
          `'value'` column; the data table also has a `'unit'` column with
          `repr(get_unit_fast(self._data))`.

        After the dictionary is built, `self._locked` is set to `True`.

        :return: Dictionary mapping storage keys to the objects to store.
        :raises RuntimeError: If the storage mode is neither `STRING_MODE`
            nor `FLOAT_MODE`.
        """
        if not isinstance(self._data, Quantity):
            # Unit-less data is handled by the parent implementation.
            return super(BrianParameter, self)._store()

        mode = self._storage_mode
        tables = {}

        if mode == BrianParameter.STRING_MODE:
            data_repr = self._data.in_best_unit(python_code=True)
            tables['data' + BrianParameter.IDENTIFIER] = ObjectTable(
                data={'data': [data_repr]})

            if self.f_has_range():
                range_reprs = [point.in_best_unit(python_code=True)
                               for point in self._explored_range]
                tables['explored_data' + BrianParameter.IDENTIFIER] = \
                    ObjectTable(data={'data': range_reprs})

        elif mode == BrianParameter.FLOAT_MODE:
            unit_repr = repr(get_unit_fast(self._data))
            magnitude = float(self._data)
            tables['data' + BrianParameter.IDENTIFIER] = ObjectTable(
                data={'value': [magnitude], 'unit': [unit_repr]})

            if self.f_has_range():
                # Only the float values are stored for the range; no unit
                # column is written to this table.
                range_values = [float(point)
                                for point in self._explored_range]
                tables['explored_data' + BrianParameter.IDENTIFIER] = \
                    ObjectTable(data={'value': range_values})

        else:
            raise RuntimeError('You shall not pass!')

        self._locked = True
        return tables
示例#4
0
    def _store(self):
        """Return the storage dictionary for this parameter.

        Data that is not a `Quantity` is stored by the parent class. For a
        `Quantity`, an `ObjectTable` is created under
        `'data' + BrianParameter.IDENTIFIER`; when `f_has_range()` is true a
        second table for `self._explored_range` is created under
        `'explored_data' + BrianParameter.IDENTIFIER`.

        * `BrianParameter.STRING_MODE`: a `'data'` column with the strings
          returned by `in_best_unit(python_code=True)`.
        * `BrianParameter.FLOAT_MODE`: a `'value'` column with floats; the
          data table additionally carries a `'unit'` column holding
          `repr(get_unit_fast(self._data))`.

        `self._locked` is set to `True` once the dictionary is complete.

        :return: Dictionary mapping storage keys to the objects to store.
        :raises RuntimeError: If the storage mode is neither `STRING_MODE`
            nor `FLOAT_MODE`.
        """
        quantity = self._data
        if not isinstance(quantity, Quantity):
            # Nothing unit-specific to do; defer to the parent class.
            return super(BrianParameter, self)._store()

        result = {}
        storage_mode = self._storage_mode

        if storage_mode == BrianParameter.STRING_MODE:
            text = quantity.in_best_unit(python_code=True)
            result['data' + BrianParameter.IDENTIFIER] = ObjectTable(
                data={'data': [text]})

            if self.f_has_range():
                texts = [item.in_best_unit(python_code=True)
                         for item in self._explored_range]
                result['explored_data' + BrianParameter.IDENTIFIER] = \
                    ObjectTable(data={'data': texts})

        elif storage_mode == BrianParameter.FLOAT_MODE:
            unit_text = repr(get_unit_fast(quantity))
            number = float(quantity)
            result['data' + BrianParameter.IDENTIFIER] = ObjectTable(
                data={'value': [number], 'unit': [unit_text]})

            if self.f_has_range():
                # The range table holds bare floats only (no unit column).
                numbers = [float(item) for item in self._explored_range]
                result['explored_data' + BrianParameter.IDENTIFIER] = \
                    ObjectTable(data={'value': numbers})

        else:
            raise RuntimeError('You shall not pass!')

        self._locked = True
        return result
示例#5
0
    def _extract_state_spike_monitor(self, monitor):
        """Copy the data of `monitor` into this result via `f_set`.

        Always stored: `source` (as string), one `<varname>_unit` entry per
        recorded variable, `varnames`, `delay`, `nspikes` and
        `spiketimes_unit` (fixed to `'second'`).

        Depending on `self._storage_mode`:

        * `BrianMonitorResult.TABLE_MODE`: if there are spikes, stores
          `neurons_with_spikes` and a `pandas.DataFrame` named `spikes` with
          the columns `'spiketimes'`, `'neuron'` and one column per variable.
        * `BrianMonitorResult.ARRAY_MODE`: stores `format_string`, one
          `spiketimes_<neuron>` item per neuron with spikes,
          `neurons_with_spikes`, and one `<varname>_<neuron>` item per
          variable and neuron with recorded values.

        :param monitor: Monitor to extract the data from.
        :raises RuntimeError: If the storage mode is neither `TABLE_MODE`
            nor `ARRAY_MODE`.
        """
        self.f_set(source=str(monitor.source))

        varnames = monitor._varnames
        if not isinstance(varnames, tuple):
            varnames = (varnames,)

        # Spike records are (neuron, time, var_0, var_1, ...), hence the
        # offset of 2 to reach the recorded variables.
        # NOTE(review): this indexes `monitor.spikes[0]`, so it presumably
        # fails when no spike was recorded and `varnames` is non-empty --
        # confirm whether that case needs handling.
        for idx, varname in enumerate(varnames):
            unit = repr(get_unit_fast(monitor.spikes[0][idx + 2]))
            self.f_set(**{varname + '_unit': unit})

        self.f_set(varnames=varnames)

        self.f_set(delay=monitor.delay)
        self.f_set(nspikes=monitor.nspikes)
        self.f_set(spiketimes_unit='second')

        if self._storage_mode == BrianMonitorResult.TABLE_MODE:
            spike_dict = {}

            if len(monitor.spikes) > 0:
                # Materialize the transposed records: they are indexed
                # below, and on Python 3 `zip` returns an iterator that
                # does not support indexing.
                zip_lists = list(zip(*monitor.spikes))
                time_list = zip_lists[1]

                # Strip the units by converting to plain floats.
                nounit_list = [np.float64(time) for time in time_list]

                spike_dict['spiketimes'] = nounit_list
                spike_dict['neuron'] = list(zip_lists[0])

                spiked_neurons = sorted(list(set(spike_dict['neuron'])))
                if spiked_neurons:
                    self.f_set(neurons_with_spikes=spiked_neurons)

                    count = 2
                    for varname in varnames:

                        var_list = list(zip_lists[count])

                        nounit_list = [np.float64(var) for var in var_list]
                        spike_dict[varname] = nounit_list
                        count += 1

                    self.f_set(spikes=pd.DataFrame(data=spike_dict))

        elif self._storage_mode == BrianMonitorResult.ARRAY_MODE:

            format_string = self._get_format_string(monitor)
            self.f_set(format_string=format_string)

            spiked_neurons = set()

            for neuron in range(len(monitor.source)):
                spikes = monitor.times(neuron)
                if len(spikes) > 0:

                    spiked_neurons.add(neuron)

                    key = 'spiketimes_' + format_string % neuron
                    self.f_set(**{key: spikes})

            spiked_neurons = sorted(list(spiked_neurons))
            if spiked_neurons:
                self.f_set(neurons_with_spikes=spiked_neurons)

            for varname in varnames:
                for neuron in range(len(monitor.source)):
                    values = monitor.values(varname, neuron)
                    if len(values) > 0:
                        key = varname + '_' + format_string % neuron
                        self.f_set(**{key: values})
        else:
            raise RuntimeError('You shall not pass!')
示例#6
0
    def _extract_state_spike_monitor(self, monitor):
        """Transfer the contents of `monitor` into this result via `f_set`.

        Always stored: `source` (as string), one `<varname>_unit` entry per
        recorded variable, `varnames`, `delay`, `nspikes` and
        `spiketimes_unit` (fixed to `'second'`).

        Depending on `self._storage_mode`:

        * `BrianMonitorResult.TABLE_MODE`: if there are spikes, stores
          `neurons_with_spikes` and a `pandas.DataFrame` named `spikes` with
          the columns `'spiketimes'`, `'neuron'` and one column per variable.
        * `BrianMonitorResult.ARRAY_MODE`: stores `format_string`, one
          `spiketimes_<neuron>` item per neuron with spikes,
          `neurons_with_spikes`, and one `<varname>_<neuron>` item per
          variable and neuron with recorded values.

        :param monitor: Monitor to extract the data from.
        :raises RuntimeError: If the storage mode is neither `TABLE_MODE`
            nor `ARRAY_MODE`.
        """
        self.f_set(source=str(monitor.source))

        # A single variable name is normalized into a one-element tuple.
        recorded = monitor._varnames
        if not isinstance(recorded, tuple):
            recorded = (recorded,)

        # Spike records are (neuron, time, var_0, var_1, ...); variables
        # therefore start at position 2 of the first record.
        for position, name in enumerate(recorded, 2):
            unit_repr = repr(get_unit_fast(monitor.spikes[0][position]))
            self.f_set(**{name + '_unit': unit_repr})

        self.f_set(varnames=recorded)

        self.f_set(delay=monitor.delay)
        self.f_set(nspikes=monitor.nspikes)
        self.f_set(spiketimes_unit='second')

        if self._storage_mode == BrianMonitorResult.TABLE_MODE:
            if len(monitor.spikes) > 0:
                # Transpose the records into one column per field.
                columns = list(zip(*monitor.spikes))

                frame_data = {}
                frame_data['spiketimes'] = [np.float64(stamp)
                                            for stamp in columns[1]]
                frame_data['neuron'] = list(columns[0])

                active = sorted(set(frame_data['neuron']))
                if active:
                    self.f_set(neurons_with_spikes=active)

                    for position, name in enumerate(recorded, 2):
                        # Convert to plain floats to drop the units.
                        frame_data[name] = [np.float64(item)
                                            for item in columns[position]]

                    self.f_set(spikes=pd.DataFrame(data=frame_data))

        elif self._storage_mode == BrianMonitorResult.ARRAY_MODE:
            template = self._get_format_string(monitor)
            self.f_set(format_string=template)

            # Neurons are visited in increasing order, so this list is
            # already sorted and free of duplicates.
            active = []
            for index in range(len(monitor.source)):
                stamps = monitor.times(index)
                if len(stamps) > 0:
                    active.append(index)
                    item_key = 'spiketimes_' + (template % index)
                    self.f_set(**{item_key: stamps})

            if active:
                self.f_set(neurons_with_spikes=active)

            for name in recorded:
                for index in range(len(monitor.source)):
                    trace = monitor.values(name, index)
                    if len(trace) > 0:
                        item_key = name + '_' + (template % index)
                        self.f_set(**{item_key: trace})
        else:
            raise RuntimeError('You shall not pass!')
示例#7
0
    def check_state_spike_monitor(self, res, monitor):
        """Assert that the result `res` mirrors the data held by `monitor`.

        Compares delay, spike count, source (as string) and variable names,
        then the individual spike records according to `res.v_storage_mode`,
        and finally the unit string stored for every recorded variable.

        :param res: Result object to verify.
        :param monitor: Monitor whose data `res` is expected to contain.
        :raises RuntimeError: If `res.v_storage_mode` is neither
            `BrianMonitorResult.TABLE_MODE` nor
            `BrianMonitorResult.ARRAY_MODE`.
        """
        self.assertTrue(comp.nested_equal(monitor.delay, res.delay))
        self.assertTrue(comp.nested_equal(monitor.nspikes, res.nspikes))
        self.assertTrue(comp.nested_equal(str(monitor.source), res.source))
        self.assertTrue(comp.nested_equal(monitor._varnames, res.varnames))

        self.assertEqual('second', res.spiketimes_unit)

        if res.v_storage_mode == BrianMonitorResult.TABLE_MODE:
            # All spikes live in a single frame, one row per spike.
            spike_frame = res.spikes
            spiked_list = sorted(
                list(set(spike_frame['neuron'].to_dict().values())))
            self.assertEqual(spiked_list, res.neurons_with_spikes)
            for idx, val_tuple in enumerate(monitor.spikes):
                # Each spike record is (neuron, time, var_0, var_1, ...).
                neuron = val_tuple[0]
                time = val_tuple[1]
                vals = val_tuple[2:]

                self.assertEqual(neuron, spike_frame['neuron'][idx])

                self.assertEqual(float(time), spike_frame['spiketimes'][idx])

                for idx_var, varname in enumerate(res.varnames):
                    val = vals[idx_var]
                    self.assertEqual(float(val), spike_frame[varname][idx])

        elif res.v_storage_mode == BrianMonitorResult.ARRAY_MODE:

            self.assertTrue('%0' in res.format_string
                            and 'd' in res.format_string)

            spiked_set = set()
            for item_name in res:

                if item_name.startswith(
                        'spiketimes') and not item_name.endswith('unit'):
                    # The neuron index is the last '_'-separated token.
                    neuron_id = int(item_name.split('_')[-1])
                    spiked_set.add(neuron_id)

                    times = monitor.times(neuron_id)
                    self.assertTrue(comp.nested_equal(times, res[item_name]))

                for varname in res.varnames:
                    if item_name.startswith(
                            varname) and not item_name.endswith('unit'):
                        neuron_id = int(item_name.split('_')[-1])
                        values = monitor.values(varname, neuron_id)

                        # Compare the recorded values with the stored item.
                        self.assertTrue(
                            comp.nested_equal(values, res[item_name]))

            spiked_list = sorted(list(spiked_set))
            self.assertEqual(spiked_list, res.neurons_with_spikes)
        else:
            raise RuntimeError('You shall not pass!')

        # Check units. `assertEqual` is required: `assertTrue(a, b)` treats
        # `b` as the failure message and never compares the two values.
        for idx, varname in enumerate(monitor._varnames):
            unit = repr(get_unit_fast(monitor.spikes[0][idx + 2]))
            self.assertEqual(unit, res[varname + '_unit'])