示例#1
0
    def record_trial(self, trial_idx, inputs, correct_input, wta_net, wta_monitor):
        """Record the inputs, network response and outcome of one trial.

        Smooths the excitatory (groups 0 and 1) and inhibitory population
        rates held by ``wta_monitor`` with a 5 ms Gaussian kernel, optionally
        stores them, determines the response time and choice with
        ``get_response_time``, prints a one-line summary, and updates the
        per-trial arrays, the running accuracy and, for every connection in
        ``self.record_connections``, a dense copy of its weight matrix.

        :param trial_idx: column index of this trial in the per-trial arrays.
        :param inputs: values stored in ``self.trial_inputs[:, trial_idx]``.
        :param correct_input: index of the correct choice for this trial.
        :param wta_net: network whose ``connections[conn].W`` are recorded.
        :param wta_monitor: object whose ``monitors`` dict holds the rate monitors.
        """
        self.trial_inputs[:, trial_idx] = inputs

        monitors = wta_monitor.monitors
        e_rate_0 = monitors['excitatory_rate_0'].smooth_rate(width=5 * ms, filter='gaussian')
        e_rate_1 = monitors['excitatory_rate_1'].smooth_rate(width=5 * ms, filter='gaussian')
        i_rate = monitors['inhibitory_rate'].smooth_rate(width=5 * ms, filter='gaussian')

        if self.record_firing_rates:
            self.pop_rates['excitatory_rate_0'].append(e_rate_0)
            self.pop_rates['excitatory_rate_1'].append(e_rate_1)
            self.pop_rates['inhibitory_rate'].append(i_rate)

        rt, choice = get_response_time(np.array([e_rate_0, e_rate_1]), self.sim_params.stim_start_time,
                                       self.sim_params.stim_end_time,
                                       upper_threshold=self.network_params.resp_threshold,
                                       dt=self.sim_params.dt)

        correct = choice == correct_input
        # A choice of -1 is treated as "no response" below.
        if choice > -1:
            print('response time = %.3f correct = %d' % (rt, int(correct)))
        else:
            print('no response!')
            self.num_no_response += 1
        self.trial_rt[0, trial_idx] = rt
        self.trial_resp[0, trial_idx] = choice
        self.trial_correct[0, trial_idx] = correct
        # Running accuracy over trials 0..trial_idx. Sum exactly those trials so the
        # numerator matches the denominator, and force float division so an
        # integer/bool trial_correct array does not truncate under Python 2.
        n_done = trial_idx + 1
        self.correct_avg[0, trial_idx] = float(np.sum(self.trial_correct[0, :n_done])) / n_done
        for conn in self.record_connections:
            self.trial_weights[conn].append(wta_net.connections[conn].W.todense())
示例#2
0
文件: plot.py 项目: jbonaiuto/pySBI
def plot_network_firing_rates(e_rates, sim_params, network_params, std_e_rates=None, i_rate=None, std_i_rate=None,
                              plt_title=None, labels=None, ax=None):
    """Plot the population firing rates of one network trial.

    Excitatory group rates are drawn on one axes with the stimulus window
    shaded, a dashed horizontal line at the response threshold and a dashed
    vertical line at the response time returned by ``get_response_time``.
    When ``i_rate`` is given, the inhibitory rate is drawn on a second
    subplot below the first.

    :param e_rates: excitatory rates indexed as ``e_rates[group, sample]``.
    :param sim_params: reads ``stim_start_time``, ``stim_end_time``,
        ``trial_duration`` and ``dt``.
    :param network_params: reads ``resp_threshold`` and ``num_groups``.
    :param std_e_rates: optional spread per group (same indexing as
        ``e_rates``), drawn as a shaded band.
    :param i_rate: optional 1-D inhibitory rate.
    :param std_i_rate: optional spread of ``i_rate``, drawn as a shaded band.
    :param plt_title: optional title of the excitatory axes.
    :param labels: optional legend labels; ``labels[num_groups]`` labels the
        inhibitory trace.
    :param ax: optional axes for the excitatory rates. NOTE: it is ignored
        when ``i_rate`` is given, because subplots 211/212 of the current
        figure are used instead.
    """
    rt, _ = get_response_time(e_rates, sim_params.stim_start_time, sim_params.stim_end_time,
                              upper_threshold=network_params.resp_threshold, dt=sim_params.dt)

    if ax is None:
        figure()

    # Common y-axis top: the threshold or the largest plotted rate, skipping the
    # first 500 samples (presumably an initial transient -- TODO confirm).
    max_rates = [network_params.resp_threshold]
    if i_rate is not None:
        max_rates.append(np.max(i_rate[500:]))
    for i in range(network_params.num_groups):
        max_rates.append(np.max(e_rates[i, 500:]))
    max_rate = np.max(max_rates)
    y_top = max_rate + 5

    stim_offset = sim_params.stim_start_time / ms
    stim_width = (sim_params.stim_end_time - sim_params.stim_start_time) / ms

    def _shade_stimulus(target_ax):
        # Time axes are relative to stimulus onset, so the stimulus window starts at x=0.
        target_ax.add_patch(Rectangle((0, 0), stim_width, y_top, alpha=0.25, facecolor='yellow', edgecolor='none'))

    if i_rate is not None:
        ax = subplot(211)
    elif ax is None:
        ax = subplot(111)
    _shade_stimulus(ax)

    # Sample times in ms relative to stimulus onset; identical for every group.
    e_time_ticks = (np.arange(e_rates.shape[1]) * sim_params.dt) / ms - stim_offset
    for idx in range(network_params.num_groups):
        label = labels[idx] if labels is not None else 'e %d' % idx
        baseline, = ax.plot(e_time_ticks, e_rates[idx, :], label=label)
        if std_e_rates is not None:
            ax.fill_between(e_time_ticks, e_rates[idx, :] - std_e_rates[idx, :], e_rates[idx, :] + std_e_rates[idx, :],
                            alpha=0.5, facecolor=baseline.get_color())
    # Use the Axes API so a caller-supplied (non-current) ax gets its own limits/labels.
    ax.set_ylim(0, y_top)
    threshold = network_params.resp_threshold / hertz
    ax.plot([0 - stim_offset, (sim_params.trial_duration - sim_params.stim_start_time) / ms],
            [threshold, threshold], 'k--')
    ax.plot([rt, rt], [0, y_top], 'k--')
    ax.legend(loc='best')
    ax.set_ylabel('Firing rate (Hz)')
    if plt_title is not None:
        ax.set_title(plt_title)

    if i_rate is not None:
        ax = subplot(212)
        _shade_stimulus(ax)
        label = labels[network_params.num_groups] if labels is not None else 'i'
        i_time_ticks = (np.arange(len(i_rate)) * sim_params.dt) / ms - stim_offset
        baseline, = ax.plot(i_time_ticks, i_rate, label=label)
        if std_i_rate is not None:
            ax.fill_between(i_time_ticks, i_rate - std_i_rate, i_rate + std_i_rate, alpha=0.5,
                            facecolor=baseline.get_color())
        ax.set_ylim(0, y_top)
        # Full-height response-time line, matching the excitatory axes.
        ax.plot([rt, rt], [0, y_top], 'k--')
        ax.set_ylabel('Firing rate (Hz)')
    ax.set_xlabel('Time (ms)')
示例#3
0
def run_rl_simulation(mat_file, alpha=0.4, beta=5.0, background_freq=None, p_dcs=0*pA, i_dcs=0*pA, dcs_start_time=0*ms,
                      output_file=None):
    """Run a two-option reinforcement-learning task with the WTA network.

    The reward-probability walk ('probswalk') and reward magnitudes ('mags',
    divided by 100) are read from ``mat_file``. On each trial:
    1. The expected reward of each option times its magnitude sets the
       network inputs (40 + 40 * value).
    2. The network's decision and response time are read from the smoothed
       excitatory rates.
    3. A reward of 1 is drawn with the chosen option's walk probability.
    4. The chosen option's expected reward is updated as
       ``(1 - alpha) * old + alpha * reward``.

    Trials without a response receive no reward and leave expected rewards
    unchanged. ``fit_behavior`` is applied to the whole session at the end.

    :param mat_file: MATLAB file containing ``store.dat`` with fields
        'probswalk' and 'mags'.
    :param alpha: learning rate of the expected-reward update.
    :param beta: only used to derive ``background_freq`` when that is None.
    :param background_freq: network background frequency; derived from
        ``beta`` when None.
    :param p_dcs: copied to ``sim_params.p_dcs``.
    :param i_dcs: copied to ``sim_params.i_dcs``.
    :param dcs_start_time: copied to ``sim_params.dcs_start_time``.
    :param output_file: optional HDF5 path receiving parameters, per-trial
        firing rates and behavioural results. The file is closed even if
        the simulation fails.
    :raises ValueError: if ``store.dat`` lacks a 'probswalk' or 'mags' field.
    """
    # loadmat returns MATLAB structs as nested 1x1 structured arrays, hence the [0][0] indexing.
    dat = scipy.io.loadmat(mat_file)['store']['dat'][0][0]
    field_names = dat.dtype.names or ()
    for required in ('probswalk', 'mags'):
        if required not in field_names:
            raise ValueError("%s: store.dat has no '%s' field" % (mat_file, required))
    record = dat[0][0]
    prob_walk = record['probswalk'].astype(np.float32, copy=False)
    mags = record['mags'].astype(np.float32, copy=False)
    mags /= 100.0

    wta_params = default_params()
    wta_params.input_var = 0*Hz

    sim_params = simulation_params()
    sim_params.p_dcs = p_dcs
    sim_params.i_dcs = i_dcs
    sim_params.dcs_start_time = dcs_start_time

    # Expected reward of each of the two options, learned across trials.
    exp_rew = np.array([0.5, 0.5])
    if background_freq is None:
        # Fixed linear mapping from beta to background frequency.
        background_freq = (beta - 161.08) / -.17
    wta_params.background_freq = background_freq

    trials = prob_walk.shape[1]
    sim_params.ntrials = trials

    vals = np.zeros(prob_walk.shape)
    choice = np.zeros(trials)
    rew = np.zeros(trials)
    rts = np.zeros(trials)
    inputs = np.zeros(prob_walk.shape)

    f = h5py.File(output_file, 'w') if output_file is not None else None
    try:
        if f is not None:
            f.attrs['alpha'] = alpha
            f.attrs['beta'] = beta
            f.attrs['mat_file'] = mat_file
            for group_name, params in (('sim_params', sim_params), ('network_params', wta_params),
                                       ('pyr_params', pyr_params), ('inh_params', inh_params)):
                group = f.create_group(group_name)
                for attr, value in params.iteritems():
                    group.attrs[attr] = value

        for trial in range(sim_params.ntrials):
            print('Trial %d' % trial)
            vals[:, trial] = exp_rew
            ev = vals[:, trial] * mags[:, trial]
            inputs[:, trial] = 40.0 + 40.0 * ev

            trial_monitor = run_wta(wta_params, inputs[:, trial], sim_params, record_lfp=False, record_voxel=False,
                                    record_neuron_state=False, record_spikes=True, record_firing_rate=True,
                                    record_inputs=False, plot_output=False)

            e_rates = np.array([trial_monitor.monitors['excitatory_rate_%d' % i].smooth_rate(width=5 * ms,
                                                                                            filter='gaussian')
                                for i in range(wta_params.num_groups)])
            i_rates = [trial_monitor.monitors['inhibitory_rate'].smooth_rate(width=5 * ms, filter='gaussian')]

            if f is not None:
                trial_group = f.create_group('trial %d' % trial)
                trial_group['e_rates'] = e_rates
                trial_group['i_rates'] = np.array(i_rates)

            rt, decision_idx = get_response_time(e_rates, sim_params.stim_start_time, sim_params.stim_end_time,
                                                 upper_threshold=wta_params.resp_threshold, lower_threshold=None,
                                                 dt=sim_params.dt)

            reward = 0.0
            # decision_idx is -1 when there was no response; indexing exp_rew with
            # it would wrongly update the last option, so only learn on a decision.
            if decision_idx >= 0:
                if np.random.random() <= prob_walk[decision_idx, trial]:
                    reward = 1.0
                exp_rew[decision_idx] = (1.0 - alpha) * exp_rew[decision_idx] + alpha * reward
            choice[trial] = decision_idx
            rts[trial] = rt
            rew[trial] = reward

        param_ests, prop_correct = fit_behavior(prob_walk, mags, rew, choice)

        if f is not None:
            f.attrs['est_alpha'] = param_ests[0]
            f.attrs['est_beta'] = param_ests[1]
            f.attrs['prop_correct'] = prop_correct

            f['prob_walk'] = prob_walk
            f['mags'] = mags
            f['rew'] = rew
            f['choice'] = choice
            f['vals'] = vals
            f['inputs'] = inputs
            f['rts'] = rts
    finally:
        if f is not None:
            f.close()
示例#4
0
def plot_network_firing_rates(e_rates,
                              sim_params,
                              network_params,
                              std_e_rates=None,
                              i_rate=None,
                              std_i_rate=None,
                              plt_title=None,
                              labels=None,
                              ax=None):
    """Plot the population firing rates of one network trial.

    Excitatory group rates are drawn on one axes with the stimulus window
    shaded, a dashed horizontal line at the response threshold and a dashed
    vertical line at the response time returned by ``get_response_time``.
    When ``i_rate`` is given, the inhibitory rate is drawn on a second
    subplot below the first.

    :param e_rates: excitatory rates indexed as ``e_rates[group, sample]``.
    :param sim_params: reads ``stim_start_time``, ``stim_end_time``,
        ``trial_duration`` and ``dt``.
    :param network_params: reads ``resp_threshold`` and ``num_groups``.
    :param std_e_rates: optional spread per group (same indexing as
        ``e_rates``), drawn as a shaded band.
    :param i_rate: optional 1-D inhibitory rate.
    :param std_i_rate: optional spread of ``i_rate``, drawn as a shaded band.
    :param plt_title: optional title of the excitatory axes.
    :param labels: optional legend labels; ``labels[num_groups]`` labels the
        inhibitory trace.
    :param ax: optional axes for the excitatory rates. NOTE: it is ignored
        when ``i_rate`` is given, because subplots 211/212 of the current
        figure are used instead.
    """
    rt, _ = get_response_time(
        e_rates,
        sim_params.stim_start_time,
        sim_params.stim_end_time,
        upper_threshold=network_params.resp_threshold,
        dt=sim_params.dt)

    if ax is None:
        figure()

    # Common y-axis top: the threshold or the largest plotted rate, skipping the
    # first 500 samples (presumably an initial transient -- TODO confirm).
    max_rates = [network_params.resp_threshold]
    if i_rate is not None:
        max_rates.append(np.max(i_rate[500:]))
    for i in range(network_params.num_groups):
        max_rates.append(np.max(e_rates[i, 500:]))
    max_rate = np.max(max_rates)
    y_top = max_rate + 5

    stim_offset = sim_params.stim_start_time / ms
    stim_width = (sim_params.stim_end_time - sim_params.stim_start_time) / ms

    def _shade_stimulus(target_ax):
        # Time axes are relative to stimulus onset, so the stimulus window starts at x=0.
        target_ax.add_patch(
            Rectangle(
                (0, 0),
                stim_width,
                y_top,
                alpha=0.25,
                facecolor='yellow',
                edgecolor='none'))

    if i_rate is not None:
        ax = subplot(211)
    elif ax is None:
        ax = subplot(111)
    _shade_stimulus(ax)

    # Sample times in ms relative to stimulus onset; identical for every group.
    e_time_ticks = (np.arange(e_rates.shape[1]) *
                    sim_params.dt) / ms - stim_offset
    for idx in range(network_params.num_groups):
        label = labels[idx] if labels is not None else 'e %d' % idx
        baseline, = ax.plot(e_time_ticks, e_rates[idx, :], label=label)
        if std_e_rates is not None:
            ax.fill_between(e_time_ticks,
                            e_rates[idx, :] - std_e_rates[idx, :],
                            e_rates[idx, :] + std_e_rates[idx, :],
                            alpha=0.5,
                            facecolor=baseline.get_color())
    # Use the Axes API so a caller-supplied (non-current) ax gets its own limits/labels.
    ax.set_ylim(0, y_top)
    threshold = network_params.resp_threshold / hertz
    ax.plot([
        0 - stim_offset,
        (sim_params.trial_duration - sim_params.stim_start_time) / ms
    ], [threshold, threshold], 'k--')
    ax.plot([rt, rt], [0, y_top], 'k--')
    ax.legend(loc='best')
    ax.set_ylabel('Firing rate (Hz)')
    if plt_title is not None:
        ax.set_title(plt_title)

    if i_rate is not None:
        ax = subplot(212)
        _shade_stimulus(ax)
        label = labels[network_params.num_groups] if labels is not None else 'i'
        i_time_ticks = (np.arange(len(i_rate)) *
                        sim_params.dt) / ms - stim_offset
        baseline, = ax.plot(i_time_ticks, i_rate, label=label)
        if std_i_rate is not None:
            ax.fill_between(i_time_ticks,
                            i_rate - std_i_rate,
                            i_rate + std_i_rate,
                            alpha=0.5,
                            facecolor=baseline.get_color())
        ax.set_ylim(0, y_top)
        # Full-height response-time line, matching the excitatory axes.
        ax.plot([rt, rt], [0, y_top], 'k--')
        ax.set_ylabel('Firing rate (Hz)')
    ax.set_xlabel('Time (ms)')