コード例 #1
0
def test_wta(p_intra,
             p_inter,
             inputs,
             single_inh_pop=False,
             muscimol_amount=0 * nS,
             injection_site=0):
    """Run (and plot) a single 1-second, two-group WTA simulation.

    Connectivity is symmetric: ``p_intra`` is used for both e->e and i->i
    connections, ``p_inter`` for both e->i and i->e. The background and
    external-input probabilities (p_b_e, p_x_e) are fixed at 0.1.

    :param p_intra: connection probability within a population type
    :param p_inter: connection probability between population types
    :param inputs: indexable with at least two entries; entries 0 and 1 are
        converted with ``float`` and scaled by ``Hz`` to give each group's
        input rate
    :param single_inh_pop: forwarded to ``run_wta``
    :param muscimol_amount: forwarded to ``run_wta``
    :param injection_site: forwarded to ``run_wta``
    """
    params = default_params()
    params.p_b_e = 0.1
    params.p_x_e = 0.1
    params.p_e_e = p_intra
    params.p_e_i = p_inter
    params.p_i_i = p_intra
    params.p_i_e = p_inter

    num_groups = 2
    rates = np.zeros(num_groups)
    for group in range(num_groups):
        rates[group] = float(inputs[group]) * Hz

    run_wta(params, num_groups, rates, 1.0 * second,
            record_lfp=False, record_neuron_state=True, plot_output=True,
            single_inh_pop=single_inh_pop,
            muscimol_amount=muscimol_amount,
            injection_site=injection_site)
コード例 #2
0
def get_prob(x, output_dir):
    """Simulate a few randomly split two-choice WTA trials and score them.

    Connection probabilities come from ``x``: x[0] -> p_e_e, x[1] -> p_e_i,
    x[2] -> p_i_i, x[3] -> p_i_e. The background and external-input
    probabilities are fixed (p_b_e=0.1, p_x_e=0.05).

    Each of the five 1-second trials splits a total input of 40.0 uniformly
    at random between the two groups. Each trial writes firing rates and
    inputs to ``<output_dir>/<parameter description>.trial.<n>.h5``.

    On the last trial, if one group has never received the larger input, that
    trial's split is redrawn so the group gets at least half of the total.
    The written trials are then scored with ``get_auc`` (10 extra trials
    requested), and its result is returned.
    """
    num_groups = 2
    trial_duration = 1 * second
    input_sum = 40.0
    num_trials = 5
    num_extra_trials = 10

    wta_params = default_params()
    wta_params.p_b_e = 0.1
    wta_params.p_x_e = 0.05
    wta_params.p_e_e = x[0]
    wta_params.p_e_i = x[1]
    wta_params.p_i_i = x[2]
    wta_params.p_i_e = x[3]

    file_desc = (
        'wta.groups.%d.duration.%0.3f.p_b_e.%0.3f.p_x_e.%0.3f.p_e_e.%0.3f.p_e_i.%0.3f.p_i_i.%0.3f.p_i_e.%0.3f'
        % (num_groups, trial_duration, wta_params.p_b_e, wta_params.p_x_e,
           wta_params.p_e_e, wta_params.p_e_i, wta_params.p_i_i, wta_params.p_i_e))
    file_prefix = os.path.join(output_dir, file_desc)

    # times_larger[g]: number of trials in which group g received the larger
    # input (ties are credited to group 1).
    times_larger = [0, 0]
    last_trial = num_trials - 1
    for trial in range(num_trials):
        inputs = np.zeros(2)
        inputs[0] = np.random.random() * input_sum
        inputs[1] = input_sum - inputs[0]
        winner = 0 if inputs[0] > inputs[1] else 1
        times_larger[winner] += 1

        if trial == last_trial:
            # Ensure each group is favoured in at least one trial; group 0 is
            # checked first and at most one group is corrected.
            for favoured in (0, 1):
                if times_larger[favoured] == 0:
                    other = 1 - favoured
                    inputs[favoured] = input_sum * 0.5 + np.random.random() * input_sum * 0.5
                    inputs[other] = input_sum - inputs[favoured]
                    times_larger[favoured] += 1
                    break

        run_wta(wta_params, num_groups, inputs, trial_duration,
                output_file='%s.trial.%d.h5' % (file_prefix, trial),
                record_lfp=False, record_voxel=False,
                record_neuron_state=False, record_spikes=False,
                record_firing_rate=True, record_inputs=True,
                single_inh_pop=False)

    return get_auc(file_prefix, num_trials, num_extra_trials, num_groups)
コード例 #3
0
ファイル: test_dist.py プロジェクト: jbonaiuto/pySBI
def get_prob(x, output_dir):
    """Run five randomized two-group WTA trials and return their get_auc score.

    x -- sequence whose entries 0..3 give p_e_e, p_e_i, p_i_i and p_i_e
    output_dir -- directory receiving one HDF5 file per trial

    Each trial splits a total input of 40.0 uniformly at random between the
    two groups. If by the last trial one group has never received the larger
    share, the last split is redrawn so that group gets at least half.
    """
    num_groups=2
    trial_duration=1*second
    input_sum=40.0
    num_trials=5
    num_extra_trials=10

    params=default_params()
    params.p_b_e=0.1
    params.p_x_e=0.05
    params.p_e_e=x[0]
    params.p_e_i=x[1]
    params.p_i_i=x[2]
    params.p_i_e=x[3]

    desc='wta.groups.%d.duration.%0.3f.p_b_e.%0.3f.p_x_e.%0.3f.p_e_e.%0.3f.p_e_i.%0.3f.p_i_i.%0.3f.p_i_e.%0.3f' % \
         (num_groups, trial_duration, params.p_b_e, params.p_x_e, params.p_e_e, params.p_e_i,
          params.p_i_i, params.p_i_e)
    prefix=os.path.join(output_dir,desc)

    # Per-group count of trials in which that group got the larger input.
    favoured_counts=[0,0]
    for trial in range(num_trials):
        split=np.zeros(2)
        split[0]=np.random.random()*input_sum
        split[1]=input_sum-split[0]
        favoured_counts[0 if split[0]>split[1] else 1]+=1

        # On the final trial, force a split favouring whichever group (0 checked
        # first) has never been favoured.
        needs_fix=[g for g in (0,1) if favoured_counts[g]==0] if trial==num_trials-1 else []
        if needs_fix:
            g=needs_fix[0]
            split[g]=input_sum*0.5+np.random.random()*input_sum*0.5
            split[1-g]=input_sum-split[g]
            favoured_counts[g]+=1

        run_wta(params, num_groups, split, trial_duration, output_file='%s.trial.%d.h5' % (prefix,trial),
            record_lfp=False, record_voxel=False, record_neuron_state=False, record_spikes=False,
            record_firing_rate=True, record_inputs=True, single_inh_pop=False)

    return get_auc(prefix, num_trials, num_extra_trials, num_groups)
コード例 #4
0
ファイル: test_wta.py プロジェクト: jbonaiuto/pySBI
def test_wta(p_intra, p_inter, inputs, single_inh_pop=False, muscimol_amount=0*nS, injection_site=0):
    """Run and plot one 1-second, two-group WTA trial with symmetric connectivity.

    p_intra sets p_e_e and p_i_i, p_inter sets p_e_i and p_i_e; p_b_e and
    p_x_e are fixed at 0.1. inputs[0] and inputs[1] (converted with float and
    scaled by Hz) are the two groups' input rates. The remaining arguments
    are forwarded to run_wta.
    """
    params=default_params()
    params.p_b_e=0.1
    params.p_x_e=0.1
    params.p_e_e=p_intra
    params.p_e_i=p_inter
    params.p_i_i=p_intra
    params.p_i_e=p_inter

    group_count=2
    freqs=np.zeros(group_count)
    for g in range(group_count):
        freqs[g]=float(inputs[g])*Hz

    run_wta(params, group_count, freqs, 1.0*second,
        record_lfp=False, record_neuron_state=True, plot_output=True,
        single_inh_pop=single_inh_pop, muscimol_amount=muscimol_amount, injection_site=injection_site)
コード例 #5
0
ファイル: test_wta.py プロジェクト: jbonaiuto/pySBI
def test_contrast(p_intra, p_inter, num_trials, data_path, muscimol_amount=0*nS, injection_site=0, single_inh_pop=False):
    """Plot max BOLD against input contrast for a two-group WTA network.

    For each contrast c in a fixed range, the total input 40.0 is split as
    40*(1+c)/2 and the remainder; the order of the two values is shuffled on
    every trial. Each of ``num_trials`` trials per contrast is run for 1 s.
    The maxima of voxel_monitor['y'] and voxel_exc_monitor['y'] are
    collected, and each is plotted against contrast with a least-squares line.

    data_path -- directory for per-trial output files, or None to not save
    muscimol_amount, injection_site, single_inh_pop -- forwarded to run_wta
    """
    num_groups=2
    trial_duration=1.0*second

    wta_params=default_params()
    wta_params.p_b_e=0.1
    wta_params.p_x_e=0.1
    wta_params.p_e_e=p_intra
    wta_params.p_e_i=p_inter
    wta_params.p_i_i=p_intra
    wta_params.p_i_e=p_inter
    input_sum=40.0

    contrast_range=[0.0, 0.0625, 0.125, 0.25, 0.5, 1.0]
    num_samples=len(contrast_range)*num_trials
    trial_contrast=np.zeros([num_samples,1])
    trial_max_bold=np.zeros(num_samples)
    trial_max_exc_bold=np.zeros(num_samples)

    # Running row index; equals contrast_index*num_trials + trial_index.
    sample=0
    for contrast in contrast_range:
        print('Testing contrast %0.4f' % contrast)
        inputs=np.zeros(2)
        inputs[0]=(input_sum*(contrast+1.0)/2.0)
        inputs[1]=input_sum-inputs[0]

        for trial in range(num_trials):
            print('Trial %d' % trial)
            trial_contrast[sample]=contrast
            # Randomize which group gets the stronger input.
            np.random.shuffle(inputs)

            input_freq=np.zeros(num_groups)
            for k in range(num_groups):
                input_freq[k]=float(inputs[k])*Hz

            file_name='wta.groups.%d.duration.%0.3f.p_b_e.%0.3f.p_x_e.%0.3f.p_e_e.%0.3f.p_e_i.%0.3f.p_i_i.%0.3f.p_i_e.%0.3f.contrast.%0.4f.trial.%d.h5' %\
                 (num_groups, trial_duration, wta_params.p_b_e, wta_params.p_x_e, wta_params.p_e_e, wta_params.p_e_i,
                  wta_params.p_i_i, wta_params.p_i_e, contrast, trial)
            out_file=os.path.join(data_path,file_name) if data_path is not None else None

            monitor=run_wta(wta_params, num_groups, input_freq, trial_duration, record_neuron_state=True,
                output_file=out_file, muscimol_amount=muscimol_amount, injection_site=injection_site,
                single_inh_pop=single_inh_pop)

            trial_max_bold[sample]=np.max(monitor.voxel_monitor['y'].values)
            trial_max_exc_bold[sample]=np.max(monitor.voxel_exc_monitor['y'].values)
            sample+=1

    x_min=np.min(contrast_range)
    x_max=np.max(contrast_range)

    def plot_fit(values, marker, ylabel):
        """New figure: scatter values vs. contrast with their least-squares line."""
        plt.figure()
        reg=LinearRegression()
        reg.fit(trial_contrast,values)
        slope=reg.coef_[0]
        intercept=reg.intercept_
        plt.plot(trial_contrast, values, marker)
        plt.plot([x_min,x_max],[slope*x_min+intercept,slope*x_max+intercept],'--')
        plt.xlabel('Input Contrast')
        plt.ylabel(ylabel)
        plt.show()

    plot_fit(trial_max_bold, 'x', 'Max BOLD')
    plot_fit(trial_max_exc_bold, 'o', 'Max BOLD (exc only)')
コード例 #6
0
ファイル: test_wta.py プロジェクト: jbonaiuto/pySBI
def test_contrast_lesion(p_intra, p_inter, trial_numbers, data_path, muscimol_amount=0*nS, injection_site=0,
                         single_inh_pop=False, plot_summary=True):
    """Compare max BOLD vs. input contrast in control and muscimol-lesioned networks.

    Every (contrast, trial number) combination is simulated twice with the
    same connectivity:
    - once as a control, with no muscimol arguments passed to run_wta;
    - once with ``muscimol_amount`` applied at ``injection_site``.
    For every run, the maxima of voxel_monitor['y'] and
    voxel_exc_monitor['y'] are recorded. When ``plot_summary`` is true and
    there are trials, each measure is plotted against contrast with one
    least-squares line per condition.

    p_intra -- connection probability for e->e and i->i
    p_inter -- connection probability for e->i and i->e
    trial_numbers -- sized iterable of trial numbers; used in the output file
        names and the progress messages
    data_path -- directory for per-trial output files, or None to not save
    muscimol_amount, injection_site -- lesion parameters (lesioned runs only)
    single_inh_pop -- forwarded to run_wta in both conditions
    plot_summary -- whether to show the summary plots
    """
    num_groups=2
    trial_duration=1.0*second

    wta_params=default_params()
    wta_params.p_b_e=0.1
    wta_params.p_x_e=0.1
    wta_params.p_e_e=p_intra
    wta_params.p_e_i=p_inter
    wta_params.p_i_i=p_intra
    wta_params.p_i_e=p_inter
    input_sum=40.0

    contrast_range=[0.0, 0.0625, 0.125, 0.25, 0.5, 1.0]
    num_trials=len(trial_numbers)
    num_samples=len(contrast_range)*num_trials
    # Row i*num_trials+j holds contrast_range[i]; the layout is the same for both
    # conditions, so it is built once instead of being rewritten per condition.
    trial_contrast=np.repeat(contrast_range, num_trials).reshape(-1,1)

    def run_condition(file_prefix, lesion_kwargs):
        """Simulate every (contrast, trial) pair; return (max BOLD, max exc-only BOLD) arrays."""
        max_bold=np.zeros(num_samples)
        max_exc_bold=np.zeros(num_samples)
        for i,contrast in enumerate(contrast_range):
            print('Testing contrast %0.4f' % contrast)
            inputs=np.zeros(2)
            inputs[0]=(input_sum*(contrast+1.0)/2.0)
            inputs[1]=input_sum-inputs[0]

            for j,trial_idx in enumerate(trial_numbers):
                # Log the trial number used in the file name (both conditions alike).
                print('Trial %d' % trial_idx)
                # Randomize which group gets the stronger input.
                np.random.shuffle(inputs)

                input_freq=np.zeros(num_groups)
                for k in range(num_groups):
                    input_freq[k]=float(inputs[k])*Hz

                file_name='%swta.groups.%d.duration.%0.3f.p_b_e.%0.3f.p_x_e.%0.3f.p_e_e.%0.3f.p_e_i.%0.3f.p_i_i.%0.3f.p_i_e.%0.3f.contrast.%0.4f.trial.%d.h5' %\
                     (file_prefix, num_groups, trial_duration, wta_params.p_b_e, wta_params.p_x_e, wta_params.p_e_e,
                      wta_params.p_e_i, wta_params.p_i_i, wta_params.p_i_e, contrast, trial_idx)

                out_file=None
                if data_path is not None:
                    out_file=os.path.join(data_path,file_name)
                wta_monitor=run_wta(wta_params, num_groups, input_freq, trial_duration, output_file=out_file,
                    single_inh_pop=single_inh_pop, record_spikes=False, record_lfp=False, save_summary_only=True,
                    **lesion_kwargs)

                max_bold[i*num_trials+j]=np.max(wta_monitor.voxel_monitor['y'].values)
                max_exc_bold[i*num_trials+j]=np.max(wta_monitor.voxel_exc_monitor['y'].values)
        return max_bold, max_exc_bold

    # Control runs deliberately omit the muscimol arguments.
    trial_max_bold,trial_max_exc_bold=run_condition('', {})
    lesioned_trial_max_bold,lesioned_trial_max_exc_bold=run_condition(
        'lesioned.', {'muscimol_amount': muscimol_amount, 'injection_site': injection_site})

    # With no trials there is nothing to fit (LinearRegression rejects 0 samples).
    if plot_summary and num_trials > 0:
        x_min=np.min(contrast_range)
        x_max=np.max(contrast_range)

        def plot_comparison(control_vals, lesioned_vals, marker, ylabel):
            """New figure: both conditions vs. contrast, each with its least-squares line."""
            plt.figure()
            fits=[]
            for vals in (control_vals, lesioned_vals):
                clf=LinearRegression()
                clf.fit(trial_contrast,vals)
                fits.append((clf.coef_[0], clf.intercept_))

            plt.plot(trial_contrast, control_vals, marker+'b')
            plt.plot(trial_contrast, lesioned_vals, marker+'r')
            for (a,b),line_style,label in zip(fits, ('--b','--r'), ('Control','Lesioned')):
                plt.plot([x_min,x_max],[a*x_min+b,a*x_max+b],line_style,label=label)
            plt.xlabel('Input Contrast')
            plt.ylabel(ylabel)
            plt.legend()
            plt.show()

        plot_comparison(trial_max_bold, lesioned_trial_max_bold, 'x', 'Max BOLD')
        plot_comparison(trial_max_exc_bold, lesioned_trial_max_exc_bold, 'o', 'Max BOLD (exc only)')
コード例 #7
0
def run_rl_simulation(mat_file, alpha=0.4, beta=5.0, background_freq=None, p_dcs=0*pA, i_dcs=0*pA, dcs_start_time=0*ms,
                      output_file=None):
    """Run a reinforcement-learning choice task in which a WTA network makes the decisions.

    The reward-probability walk ('probswalk') and reward magnitudes ('mags')
    are read from the struct ``store.dat`` in ``mat_file``. The magnitudes are
    divided by 100. For each trial (one per column of the probability walk):

    * the input to each group is ``40 + 40 * expected_reward * magnitude``;
    * the network is run, and ``get_response_time`` extracts the decision and
      response time from the smoothed excitatory firing rates;
    * if a decision was made (non-negative index), reward 1.0 is delivered
      with the chosen option's probability. The chosen option's expected
      reward is then updated as ``(1 - alpha) * old + alpha * reward``.

    Finally ``fit_behavior`` is applied to the simulated rewards and choices.

    Parameters
    ----------
    mat_file : str
        Path of the .mat file holding ``store.dat``.
    alpha : float
        Learning rate of the expected-reward update.
    beta : float
        Only used to derive ``background_freq`` when that is None, via
        ``(beta - 161.08) / -0.17``.
    background_freq : optional
        Background rate assigned to the network parameters.
    p_dcs, i_dcs, dcs_start_time
        Copied onto the simulation parameters (presumably stimulation
        currents for pyramidal / inhibitory cells and their onset; confirm
        against run_wta).
    output_file : str, optional
        If given, an HDF5 file receiving:
        - the parameters;
        - the per-trial smoothed rates;
        - the behavioural results (estimated alpha/beta, proportion correct,
          choices, rewards, response times, values, inputs).

    Raises
    ------
    ValueError
        If ``store.dat`` has no 'probswalk' or 'mags' field.
    """
    mat = scipy.io.loadmat(mat_file)
    dat = mat['store']['dat'][0][0]
    field_names = list(dat.dtype.names or ())
    missing = [name for name in ('probswalk', 'mags') if name not in field_names]
    if missing:
        # Fail loudly instead of silently reading the last field (index -1).
        raise ValueError('%s: store.dat has no field(s) %s' % (mat_file, ', '.join(missing)))
    record = dat[0][0]
    prob_walk = record[field_names.index('probswalk')].astype(np.float32, copy=False)
    mags = record[field_names.index('mags')].astype(np.float32, copy=False)
    mags /= 100.0

    wta_params=default_params()
    wta_params.input_var=0*Hz

    sim_params=simulation_params()
    sim_params.p_dcs=p_dcs
    sim_params.i_dcs=i_dcs
    sim_params.dcs_start_time=dcs_start_time

    # Initial expected reward of each of the two options.
    exp_rew=np.array([0.5, 0.5])
    if background_freq is None:
        # Fixed linear mapping from beta to the background rate.
        background_freq=(beta-161.08)/-.17
    wta_params.background_freq=background_freq

    trials=prob_walk.shape[1]
    sim_params.ntrials=trials

    vals=np.zeros(prob_walk.shape)
    choice=np.zeros(trials)
    rew=np.zeros(trials)
    rts=np.zeros(trials)
    inputs=np.zeros(prob_walk.shape)

    f=None
    if output_file is not None:
        f=h5py.File(output_file, 'w')
    try:
        if f is not None:
            f.attrs['alpha']=alpha
            f.attrs['beta']=beta
            f.attrs['mat_file']=mat_file
            for group_name, params in (('sim_params', sim_params), ('network_params', wta_params),
                                       ('pyr_params', pyr_params), ('inh_params', inh_params)):
                param_group=f.create_group(group_name)
                for attr, value in params.iteritems():
                    param_group.attrs[attr]=value

        for trial in range(sim_params.ntrials):
            print('Trial %d' % trial)
            # Expected rewards the decision on this trial is based on.
            vals[:,trial]=exp_rew
            ev=vals[:,trial]*mags[:,trial]
            inputs[:,trial]=40.0+40.0*ev

            trial_monitor=run_wta(wta_params, inputs[:,trial], sim_params, record_lfp=False, record_voxel=False,
                record_neuron_state=False, record_spikes=True, record_firing_rate=True, record_inputs=False,
                plot_output=False)

            e_rates=[trial_monitor.monitors['excitatory_rate_%d' % i].smooth_rate(width=5*ms, filter='gaussian')
                     for i in range(wta_params.num_groups)]
            i_rates=[trial_monitor.monitors['inhibitory_rate'].smooth_rate(width=5*ms, filter='gaussian')]

            if f is not None:
                trial_group=f.create_group('trial %d' % trial)
                trial_group['e_rates']=np.array(e_rates)
                trial_group['i_rates']=np.array(i_rates)

            rt,decision_idx=get_response_time(e_rates, sim_params.stim_start_time, sim_params.stim_end_time,
                upper_threshold=wta_params.resp_threshold, lower_threshold=None, dt=sim_params.dt)

            reward=0.0
            if decision_idx>=0:
                if np.random.random()<=prob_walk[decision_idx,trial]:
                    reward=1.0
                # Only a chosen option is learned about. Without this guard a
                # missing decision (negative index) would update exp_rew[-1],
                # i.e. wrongly push the last option's value toward 0.
                exp_rew[decision_idx]=(1.0-alpha)*exp_rew[decision_idx]+alpha*reward
            choice[trial]=decision_idx
            rts[trial]=rt
            rew[trial]=reward

        param_ests,prop_correct=fit_behavior(prob_walk, mags, rew, choice)

        if f is not None:
            f.attrs['est_alpha']=param_ests[0]
            f.attrs['est_beta']=param_ests[1]
            f.attrs['prop_correct']=prop_correct

            f['prob_walk']=prob_walk
            f['mags']=mags
            f['rew']=rew
            f['choice']=choice
            f['vals']=vals
            f['inputs']=inputs
            f['rts']=rts
    finally:
        # Close the HDF5 file even if a simulation or the behavioural fit raises.
        if f is not None:
            f.close()
コード例 #8
0
def test_contrast(p_intra,
                  p_inter,
                  num_trials,
                  data_path,
                  muscimol_amount=0 * nS,
                  injection_site=0,
                  single_inh_pop=False):
    """Plot max BOLD against input contrast for a two-group WTA network.

    For each contrast c in a fixed range, the total input 40.0 is split as
    40*(1+c)/2 and the remainder. The order of the two values is shuffled on
    every trial, and each of ``num_trials`` trials per contrast runs for 1 s.
    The maxima of voxel_monitor['y'] and voxel_exc_monitor['y'] are
    collected, and each is plotted against contrast with a least-squares line.

    :param data_path: directory for per-trial output files, or None
    :param muscimol_amount: forwarded to run_wta
    :param injection_site: forwarded to run_wta
    :param single_inh_pop: forwarded to run_wta
    """
    num_groups = 2
    trial_duration = 1.0 * second

    wta_params = default_params()
    wta_params.p_b_e = 0.1
    wta_params.p_x_e = 0.1
    wta_params.p_e_e = p_intra
    wta_params.p_e_i = p_inter
    wta_params.p_i_i = p_intra
    wta_params.p_i_e = p_inter
    input_sum = 40.0

    contrast_range = [0.0, 0.0625, 0.125, 0.25, 0.5, 1.0]
    num_samples = len(contrast_range) * num_trials
    trial_contrast = np.zeros([num_samples, 1])
    trial_max_bold = np.zeros(num_samples)
    trial_max_exc_bold = np.zeros(num_samples)

    # Running row index; equals contrast_index * num_trials + trial_index.
    row = 0
    for contrast in contrast_range:
        print('Testing contrast %0.4f' % contrast)
        split = np.zeros(2)
        split[0] = (input_sum * (contrast + 1.0) / 2.0)
        split[1] = input_sum - split[0]

        for trial in range(num_trials):
            print('Trial %d' % trial)
            trial_contrast[row] = contrast
            # Randomize which group gets the stronger input.
            np.random.shuffle(split)

            rates = np.zeros(num_groups)
            for group in range(num_groups):
                rates[group] = float(split[group]) * Hz

            file_name = (
                'wta.groups.%d.duration.%0.3f.p_b_e.%0.3f.p_x_e.%0.3f.p_e_e.%0.3f.p_e_i.%0.3f.p_i_i.%0.3f.p_i_e.%0.3f.contrast.%0.4f.trial.%d.h5'
                % (num_groups, trial_duration, wta_params.p_b_e, wta_params.p_x_e,
                   wta_params.p_e_e, wta_params.p_e_i, wta_params.p_i_i,
                   wta_params.p_i_e, contrast, trial))
            out_file = (os.path.join(data_path, file_name)
                        if data_path is not None else None)

            monitor = run_wta(wta_params,
                              num_groups,
                              rates,
                              trial_duration,
                              record_neuron_state=True,
                              output_file=out_file,
                              muscimol_amount=muscimol_amount,
                              injection_site=injection_site,
                              single_inh_pop=single_inh_pop)

            trial_max_bold[row] = np.max(monitor.voxel_monitor['y'].values)
            trial_max_exc_bold[row] = np.max(
                monitor.voxel_exc_monitor['y'].values)
            row += 1

    x_min = np.min(contrast_range)
    x_max = np.max(contrast_range)

    def plot_fit(values, marker, ylabel):
        """New figure: scatter values vs. contrast with their least-squares line."""
        plt.figure()
        model = LinearRegression()
        model.fit(trial_contrast, values)
        slope = model.coef_[0]
        intercept = model.intercept_
        plt.plot(trial_contrast, values, marker)
        plt.plot([x_min, x_max],
                 [slope * x_min + intercept, slope * x_max + intercept], '--')
        plt.xlabel('Input Contrast')
        plt.ylabel(ylabel)
        plt.show()

    plot_fit(trial_max_bold, 'x', 'Max BOLD')
    plot_fit(trial_max_exc_bold, 'o', 'Max BOLD (exc only)')
コード例 #9
0
def test_contrast_lesion(p_intra,
                         p_inter,
                         trial_numbers,
                         data_path,
                         muscimol_amount=0 * nS,
                         injection_site=0,
                         single_inh_pop=False,
                         plot_summary=True):
    """Compare max BOLD vs. input contrast in control and muscimol-lesioned networks.

    Every (contrast, trial number) combination is simulated twice with the
    same connectivity:
    - once as a control, with no muscimol arguments passed to run_wta;
    - once with ``muscimol_amount`` applied at ``injection_site``.
    For every run, the maxima of voxel_monitor['y'] and
    voxel_exc_monitor['y'] are recorded. When ``plot_summary`` is true and
    there are trials, each measure is plotted against contrast with one
    least-squares line per condition.

    :param p_intra: connection probability for e->e and i->i
    :param p_inter: connection probability for e->i and i->e
    :param trial_numbers: sized iterable of trial numbers; used in the output
        file names and the progress messages
    :param data_path: directory for per-trial output files, or None
    :param muscimol_amount: lesion amount (lesioned runs only)
    :param injection_site: lesion site (lesioned runs only)
    :param single_inh_pop: forwarded to run_wta in both conditions
    :param plot_summary: whether to show the summary plots
    """
    num_groups = 2
    trial_duration = 1.0 * second

    wta_params = default_params()
    wta_params.p_b_e = 0.1
    wta_params.p_x_e = 0.1
    wta_params.p_e_e = p_intra
    wta_params.p_e_i = p_inter
    wta_params.p_i_i = p_intra
    wta_params.p_i_e = p_inter
    input_sum = 40.0

    contrast_range = [0.0, 0.0625, 0.125, 0.25, 0.5, 1.0]
    num_trials = len(trial_numbers)
    num_samples = len(contrast_range) * num_trials
    # Row i * num_trials + j holds contrast_range[i]; the layout is the same for
    # both conditions, so it is built once instead of rewritten per condition.
    trial_contrast = np.repeat(contrast_range, num_trials).reshape(-1, 1)

    def run_condition(file_prefix, lesion_kwargs):
        """Simulate every (contrast, trial) pair; return (max BOLD, max exc-only BOLD) arrays."""
        max_bold = np.zeros(num_samples)
        max_exc_bold = np.zeros(num_samples)
        for i, contrast in enumerate(contrast_range):
            print('Testing contrast %0.4f' % contrast)
            inputs = np.zeros(2)
            inputs[0] = (input_sum * (contrast + 1.0) / 2.0)
            inputs[1] = input_sum - inputs[0]

            for j, trial_idx in enumerate(trial_numbers):
                # Log the trial number used in the file name (both conditions alike).
                print('Trial %d' % trial_idx)
                # Randomize which group gets the stronger input.
                np.random.shuffle(inputs)

                input_freq = np.zeros(num_groups)
                for k in range(num_groups):
                    input_freq[k] = float(inputs[k]) * Hz

                file_name = (
                    '%swta.groups.%d.duration.%0.3f.p_b_e.%0.3f.p_x_e.%0.3f.p_e_e.%0.3f.p_e_i.%0.3f.p_i_i.%0.3f.p_i_e.%0.3f.contrast.%0.4f.trial.%d.h5'
                    % (file_prefix, num_groups, trial_duration, wta_params.p_b_e,
                       wta_params.p_x_e, wta_params.p_e_e, wta_params.p_e_i,
                       wta_params.p_i_i, wta_params.p_i_e, contrast, trial_idx))

                out_file = None
                if data_path is not None:
                    out_file = os.path.join(data_path, file_name)
                wta_monitor = run_wta(wta_params,
                                      num_groups,
                                      input_freq,
                                      trial_duration,
                                      output_file=out_file,
                                      single_inh_pop=single_inh_pop,
                                      record_spikes=False,
                                      record_lfp=False,
                                      save_summary_only=True,
                                      **lesion_kwargs)

                max_bold[i * num_trials + j] = np.max(
                    wta_monitor.voxel_monitor['y'].values)
                max_exc_bold[i * num_trials + j] = np.max(
                    wta_monitor.voxel_exc_monitor['y'].values)
        return max_bold, max_exc_bold

    # Control runs deliberately omit the muscimol arguments.
    trial_max_bold, trial_max_exc_bold = run_condition('', {})
    lesioned_trial_max_bold, lesioned_trial_max_exc_bold = run_condition(
        'lesioned.', {
            'muscimol_amount': muscimol_amount,
            'injection_site': injection_site,
        })

    # With no trials there is nothing to fit (LinearRegression rejects 0 samples).
    if plot_summary and num_trials > 0:
        x_min = np.min(contrast_range)
        x_max = np.max(contrast_range)

        def plot_comparison(control_vals, lesioned_vals, marker, ylabel):
            """New figure: both conditions vs. contrast, each with its least-squares line."""
            plt.figure()
            fits = []
            for vals in (control_vals, lesioned_vals):
                clf = LinearRegression()
                clf.fit(trial_contrast, vals)
                fits.append((clf.coef_[0], clf.intercept_))

            plt.plot(trial_contrast, control_vals, marker + 'b')
            plt.plot(trial_contrast, lesioned_vals, marker + 'r')
            for (a, b), line_style, label in zip(fits, ('--b', '--r'),
                                                 ('Control', 'Lesioned')):
                plt.plot([x_min, x_max], [a * x_min + b, a * x_max + b],
                         line_style,
                         label=label)
            plt.xlabel('Input Contrast')
            plt.ylabel(ylabel)
            plt.legend()
            plt.show()

        plot_comparison(trial_max_bold, lesioned_trial_max_bold, 'x',
                        'Max BOLD')
        plot_comparison(trial_max_exc_bold, lesioned_trial_max_exc_bold, 'o',
                        'Max BOLD (exc only)')