예제 #1
0
def create_roc_report(file_prefix,
                      num_groups,
                      contrast_range,
                      num_trials,
                      reports_dir,
                      regenerate_plot=True):
    """Build the ROC section of a WTA network report.

    Computes the overall AUC and one AUC per option (group), and (re)draws
    the per-option ROC curves to ``<reports_dir>/img/roc.png`` and ``.eps``.

    Args:
        file_prefix: prefix of the trial data files, passed to the ROC/AUC
            helpers.
        num_groups: number of options; one ROC curve and one AUC per option.
        contrast_range: contrast levels, passed to the ROC/AUC helpers.
        num_trials: number of trials per contrast level.
        reports_dir: report root directory; the figure goes under ``img/``.
        regenerate_plot: redraw the figure even if ``img/roc.png`` exists.

    Returns:
        Struct with ``auc``, ``auc_single_option`` (list, one entry per
        option) and ``roc_url`` (report-relative path of the PNG).
    """
    num_extra_trials = 10
    roc_report = Struct()
    roc_report.auc = get_auc(file_prefix, contrast_range, num_trials,
                             num_extra_trials, num_groups)
    # The per-option AUCs are report data, not plot data: compute them even
    # when the cached figure is reused (previously they were only filled in
    # while redrawing, leaving an empty list otherwise).
    roc_report.auc_single_option = [
        get_auc_single_option(file_prefix, contrast_range, num_trials,
                              num_extra_trials, i)
        for i in range(num_groups)
    ]
    roc_url = 'img/roc.png'
    fname = os.path.join(reports_dir, roc_url)
    roc_report.roc_url = roc_url
    if regenerate_plot or not os.path.exists(fname):
        fig = plt.figure()
        for i in range(num_groups):
            roc = get_roc_single_option(file_prefix, contrast_range,
                                        num_trials, num_extra_trials, i)
            # Column 0 is plotted as FPR, column 1 as TPR (see axis labels).
            plt.plot(roc[:, 0], roc[:, 1], 'x-', label='option %d' % i)
        # Diagonal reference line.
        plt.plot([0, 1], [0, 1], '--')
        plt.xlabel('False Positive Rate')
        plt.ylabel('True Positive Rate')
        save_to_png(fig, fname)
        save_to_eps(fig, os.path.join(reports_dir, 'img/roc.eps'))
        # Close this figure specifically, not whatever pyplot deems current.
        plt.close(fig)
    return roc_report
예제 #2
0
파일: wta.py 프로젝트: jbonaiuto/pySBI
def create_roc_report(file_prefix, num_groups, contrast_range, num_trials, reports_dir, regenerate_plot=True):
    """Build the ROC section of a WTA network report.

    Computes the overall AUC and one AUC per option (group), and (re)draws
    the per-option ROC curves to ``<reports_dir>/img/roc.png`` and ``.eps``.

    Args:
        file_prefix: prefix of the trial data files, passed to the ROC/AUC helpers.
        num_groups: number of options; one ROC curve and one AUC per option.
        contrast_range: contrast levels, passed to the ROC/AUC helpers.
        num_trials: number of trials per contrast level.
        reports_dir: report root directory; the figure goes under ``img/``.
        regenerate_plot: redraw the figure even if ``img/roc.png`` exists.

    Returns:
        Struct with ``auc``, ``auc_single_option`` (list, one entry per option)
        and ``roc_url`` (report-relative path of the PNG).
    """
    num_extra_trials=10
    roc_report=Struct()
    roc_report.auc=get_auc(file_prefix, contrast_range, num_trials, num_extra_trials, num_groups)
    # Per-option AUCs are report data: compute them even when the cached figure
    # is reused (previously they stayed empty unless the plot was redrawn).
    roc_report.auc_single_option=[get_auc_single_option(file_prefix, contrast_range, num_trials, num_extra_trials, i)
                                  for i in range(num_groups)]
    roc_url = 'img/roc.png'
    fname=os.path.join(reports_dir, roc_url)
    roc_report.roc_url=roc_url
    if regenerate_plot or not os.path.exists(fname):
        fig=plt.figure()
        for i in range(num_groups):
            roc=get_roc_single_option(file_prefix, contrast_range, num_trials, num_extra_trials, i)
            # Column 0 is plotted as FPR, column 1 as TPR (see axis labels).
            plt.plot(roc[:,0],roc[:,1],'x-',label='option %d' % i)
        # Diagonal reference line.
        plt.plot([0,1],[0,1],'--')
        plt.xlabel('False Positive Rate')
        plt.ylabel('True Positive Rate')
        save_to_png(fig, fname)
        save_to_eps(fig, os.path.join(reports_dir, 'img/roc.eps'))
        # Close this figure specifically, not whatever pyplot deems current.
        plt.close(fig)
    return roc_report
예제 #3
0
def _trial_figure_url(prefix, contrast, trial_idx, ext='png'):
    """Return the report-relative path of one per-trial figure.

    >>> _trial_figure_url('lfp', 0.1, 3)
    'img/lfp.contrast.0.1000.trial.3.png'
    """
    return 'img/%s.contrast.%0.4f.trial.%d.%s' % (prefix, contrast, trial_idx,
                                                  ext)


def _save_trial_figure(fig, reports_dir, prefix, contrast, trial_idx):
    """Save ``fig`` as PNG and EPS under ``reports_dir``, then close it."""
    for ext, save in (('png', save_to_png), ('eps', save_to_eps)):
        save(fig,
             os.path.join(reports_dir,
                          _trial_figure_url(prefix, contrast, trial_idx, ext)))
    # Close this figure specifically, not whatever pyplot deems current.
    plt.close(fig)


def create_trial_report(data,
                        reports_dir,
                        contrast,
                        trial_idx,
                        regenerate_plots=True):
    """Summarize one simulated trial and (re)draw its figures.

    Args:
        data: trial record (e.g. a ``FileInfo``). Optional recordings that
            are ``None`` are skipped and their ``*_url`` field stays ``None``.
        reports_dir: report root directory; figures go under ``img/``.
        contrast: contrast level, used only in figure file names.
        trial_idx: trial number within the contrast level, used in file names.
        regenerate_plots: redraw figures even if their PNG already exists.

    Returns:
        Struct with ``input_freq``, ``input_contrast``, ``correct`` (1.0 or
        0.0), ``rt``, ``max_input``, ``max_rate``, ``max_bold`` and one
        report-relative ``*_url`` per figure (``None`` when not drawn).
        ``max_bold`` is 0 without a voxel recording and NaN when every BOLD
        sample is NaN.
    """
    trial = Struct()
    trial.input_freq = data.input_freq
    trial.input_contrast = abs(data.input_freq[0] - data.input_freq[1]) / sum(
        data.input_freq)

    # Correct when the group receiving the stronger input has the higher peak
    # excitatory rate over samples 6500-7500; equal inputs count as incorrect.
    trial.correct = 0.0
    option_idx = -1
    if data.input_freq[0] > data.input_freq[1]:
        option_idx = 0
    elif data.input_freq[1] > data.input_freq[0]:
        option_idx = 1
    if option_idx > -1:
        window = slice(6500, 7500)
        if np.max(data.e_firing_rates[option_idx, window]) > np.max(
                data.e_firing_rates[1 - option_idx, window]):
            trial.correct = 1.0
    trial.rt = data.rt

    # argmax returns the first index holding the maximum input.
    max_input_idx = int(np.argmax(trial.input_freq))
    trial.max_input = trial.input_freq[max_input_idx]
    trial.max_rate = np.max(data.e_firing_rates[max_input_idx])

    trial.e_raster_url = None
    trial.i_raster_url = None
    if data.e_spike_neurons is not None and data.i_spike_neurons is not None:
        # Excitatory groups are drawn with 4/5 of network_group_size neurons,
        # inhibitory groups with 1/5.
        rasters = (
            ('e_raster', data.e_spike_neurons, data.e_spike_times,
             int(4 * data.network_group_size / 5)),
            ('i_raster', data.i_spike_neurons, data.i_spike_times,
             int(data.network_group_size / 5)),
        )
        for prefix, neurons, times, group_size in rasters:
            furl = _trial_figure_url(prefix, contrast, trial_idx)
            setattr(trial, prefix + '_url', furl)
            if regenerate_plots or not os.path.exists(
                    os.path.join(reports_dir, furl)):
                fig = plot_raster(neurons, times,
                                  [group_size] * data.num_groups)
                _save_trial_figure(fig, reports_dir, prefix, contrast,
                                   trial_idx)

    trial.firing_rate_url = None
    if data.e_firing_rates is not None and data.i_firing_rates is not None:
        furl = _trial_figure_url('firing_rate', contrast, trial_idx)
        trial.firing_rate_url = furl
        if regenerate_plots or not os.path.exists(
                os.path.join(reports_dir, furl)):
            fig = plt.figure()
            # Top panel: excitatory populations; bottom: inhibitory.
            for subplot_code, rates in ((211, data.e_firing_rates),
                                        (212, data.i_firing_rates)):
                ax = plt.subplot(subplot_code)
                for i, pop_rate in enumerate(rates):
                    # x axis assumes 0.1 ms between samples (label says ms).
                    ax.plot(np.arange(len(pop_rate)) * .1,
                            pop_rate / Hz,
                            label='group %d' % i)
                plt.xlabel('Time (ms)')
                plt.ylabel('Firing Rate (Hz)')
            _save_trial_figure(fig, reports_dir, 'firing_rate', contrast,
                               trial_idx)

    trial.neural_state_url = None
    if data.neural_state_rec is not None:
        furl = _trial_figure_url('neural_state', contrast, trial_idx)
        trial.neural_state_url = furl
        if regenerate_plots or not os.path.exists(
                os.path.join(reports_dir, furl)):
            rec = data.neural_state_rec
            traces = (('g_ampa_r', 'AMPA-recurrent'),
                      ('g_ampa_x', 'AMPA-task'),
                      ('g_ampa_b', 'AMPA-backgrnd'),
                      ('g_nmda', 'NMDA'),
                      ('g_gaba_a', 'GABA_A'))
            fig = plt.figure()
            for i in range(data.num_groups):
                times = np.arange(len(rec['g_ampa_r'][i * 2])) * .1
                # Rows i*2 and i*2+1 of each record are drawn for group i,
                # in the left and right column of that group's subplot row.
                for col in range(2):
                    # (nrows, ncols, index) form also works for >9 groups,
                    # unlike the packed three-digit subplot code.
                    ax = plt.subplot(data.num_groups, 2, i * 2 + col + 1)
                    for key, label in traces:
                        ax.plot(times, rec[key][i * 2 + col] / nA,
                                label=label)
                    plt.xlabel('Time (ms)')
                    plt.ylabel('Conductance (nA)')
            _save_trial_figure(fig, reports_dir, 'neural_state', contrast,
                               trial_idx)

    trial.lfp_url = None
    if data.lfp_rec is not None:
        furl = _trial_figure_url('lfp', contrast, trial_idx)
        trial.lfp_url = furl
        if regenerate_plots or not os.path.exists(
                os.path.join(reports_dir, furl)):
            fig = plt.figure()
            ax = plt.subplot(111)
            lfp = get_lfp_signal(data)
            # NOTE(review): x is the raw sample index here (no .1 factor as in
            # the other plots) but labelled ms -- confirm the LFP sampling.
            ax.plot(np.arange(len(lfp)), lfp / mA)
            plt.xlabel('Time (ms)')
            plt.ylabel('LFP (mA)')
            _save_trial_figure(fig, reports_dir, 'lfp', contrast, trial_idx)

    trial.voxel_url = None
    trial.max_bold = 0
    if data.voxel_rec is not None:
        bold = np.asarray(data.voxel_rec['y'][0], dtype=float)
        valid_bold = bold[~np.isnan(bold)]
        # Report NaN (not a -1000 sentinel) when no BOLD sample is valid, so
        # the peak is recognizably missing instead of a fake extreme value.
        trial.max_bold = (float(np.max(valid_bold)) if valid_bold.size else
                          float('nan'))
        furl = _trial_figure_url('voxel', contrast, trial_idx)
        trial.voxel_url = furl
        if regenerate_plots or not os.path.exists(
                os.path.join(reports_dir, furl)):
            end_idx = int(data.trial_duration / ms / .1)
            fig = plt.figure()
            ax = plt.subplot(211)
            ax.plot(
                np.arange(end_idx) * .1,
                data.voxel_rec['G_total'][0][:end_idx] / nA)
            plt.xlabel('Time (ms)')
            plt.ylabel('Total Synaptic Activity (nA)')
            ax = plt.subplot(212)
            ax.plot(
                np.arange(len(data.voxel_rec['y'][0])) * .1 * ms,
                data.voxel_rec['y'][0])
            plt.xlabel('Time (s)')
            plt.ylabel('BOLD')
            _save_trial_figure(fig, reports_dir, 'voxel', contrast, trial_idx)
    return trial
예제 #4
0
def _fit_bold_regression(x, trial_max_bold, reports_dir, name, xlabel,
                         regenerate_plot):
    """Regress per-trial max BOLD on ``x`` and draw the scatter + fit plot.

    The figure is written to ``img/<name>.png`` and ``img/<name>.eps`` under
    ``reports_dir`` when ``regenerate_plot`` is true or the PNG is missing.
    ``x`` is passed straight to ``LinearRegression.fit``, so it must be 2-D
    (n_samples, n_features); only the first coefficient is reported.

    Returns:
        ``(slope, intercept, r_sqr, url)`` where ``r_sqr`` is
        ``LinearRegression.score`` on the fitted data and ``url`` is the
        report-relative PNG path.
    """
    clf = LinearRegression()
    clf.fit(x, trial_max_bold)
    slope = clf.coef_[0]
    intercept = clf.intercept_
    r_sqr = clf.score(x, trial_max_bold)

    url = 'img/%s.png' % name
    fname = os.path.join(reports_dir, url)
    if regenerate_plot or not os.path.exists(fname):
        fig = plt.figure()
        plt.plot(x, trial_max_bold, 'x')
        x_min = np.min(x)
        x_max = np.max(x)
        # Fitted line across the observed x range.
        plt.plot([x_min, x_max],
                 [slope * x_min + intercept, slope * x_max + intercept], '--')
        plt.xlabel(xlabel)
        plt.ylabel('Max BOLD')
        save_to_png(fig, fname)
        save_to_eps(fig, os.path.join(reports_dir, 'img/%s.eps' % name))
        # Close this figure specifically, not whatever pyplot deems current.
        plt.close(fig)
    return slope, intercept, r_sqr, url


def create_bold_report(reports_dir,
                       trial_contrast,
                       trial_max_bold,
                       trial_max_rate,
                       trial_rt,
                       regenerate_plot=True):
    """Relate per-trial max BOLD to contrast, max firing rate and response time.

    For each predictor a linear regression of max BOLD is fitted and a
    scatter + fit figure is drawn under ``<reports_dir>/img/``.

    Args:
        reports_dir: report root directory.
        trial_contrast: per-trial input contrast, 2-D (n_trials, 1).
        trial_max_bold: per-trial max BOLD, 1-D (n_trials,).
        trial_max_rate: per-trial max firing rate, 2-D (n_trials, 1).
        trial_rt: per-trial response time, 2-D (n_trials, 1).
        regenerate_plot: redraw figures even if their PNG already exists.

    Returns:
        Struct with ``bold_<x>_slope``/``_intercept``/``_r_sqr`` for
        ``contrast``, ``firing_rate`` and ``rt``, plus ``contrast_bold_url``,
        ``firing_rate_bold_url`` and ``response_time_bold_url``.
    """
    report_info = Struct()

    (report_info.bold_contrast_slope,
     report_info.bold_contrast_intercept,
     report_info.bold_contrast_r_sqr,
     report_info.contrast_bold_url) = _fit_bold_regression(
         trial_contrast, trial_max_bold, reports_dir, 'contrast_bold',
         'Input Contrast', regenerate_plot)

    (report_info.bold_firing_rate_slope,
     report_info.bold_firing_rate_intercept,
     report_info.bold_firing_rate_r_sqr,
     report_info.firing_rate_bold_url) = _fit_bold_regression(
         trial_max_rate, trial_max_bold, reports_dir, 'firing_rate_bold',
         'Max Firing Rate', regenerate_plot)

    (report_info.bold_rt_slope,
     report_info.bold_rt_intercept,
     report_info.bold_rt_r_sqr,
     report_info.response_time_bold_url) = _fit_bold_regression(
         trial_rt, trial_max_bold, reports_dir, 'response_time_bold',
         'Response Time', regenerate_plot)

    return report_info
예제 #5
0
def _extrapolate_missing_bold(contrast_bold_means, contrast_idx):
    """Return a stand-in max-BOLD value for a trial whose BOLD is NaN.

    ``contrast_bold_means`` holds the mean max BOLD of each contrast level
    already processed. If the two previous levels are both below 1.0 the
    value is extrapolated linearly from them; otherwise, if the previous
    level is below 1.0 it is doubled; otherwise 1.0 is used.
    """
    if contrast_idx > 1:
        prev = contrast_bold_means[contrast_idx - 1]
        prev2 = contrast_bold_means[contrast_idx - 2]
        if prev < 1.0 and prev2 < 1.0:
            return prev + (prev - prev2)
    if contrast_idx > 0 and contrast_bold_means[contrast_idx - 1] < 1.0:
        return contrast_bold_means[contrast_idx - 1] * 2.0
    return 1.0


def create_wta_network_report(file_prefix,
                              contrast_range,
                              num_trials,
                              reports_dir,
                              edesc,
                              regenerate_network_plots=True,
                              regenerate_trial_plots=True):
    """Build the HTML report for one WTA network across all contrast levels.

    Loads ``<file_prefix>.contrast.<c>.trial.<i>.h5`` for every contrast and
    trial, creates per-trial reports, per-contrast accuracy, an input/output
    rate regression, the BOLD and ROC sections, and renders
    ``wta_network.<basename(file_prefix)>.html`` into ``reports_dir``.

    Args:
        file_prefix: path prefix of the trial data files.
        contrast_range: contrast levels to load.
        num_trials: trials per contrast level.
        reports_dir: output directory for the report and its figures.
        edesc: experiment description stored on the report.
        regenerate_network_plots: redraw network-level figures even if cached.
        regenerate_trial_plots: redraw per-trial figures even if cached.

    Returns:
        The populated report Struct.
    """
    make_report_dirs(reports_dir)

    report_info = Struct()
    report_info.edesc = edesc
    report_info.trials = []

    data_file_prefix = os.path.split(file_prefix)[1]

    n_contrasts = len(contrast_range)
    total_trials = num_trials * n_contrasts
    trial_contrast = np.zeros([total_trials, 1])
    trial_max_bold = np.zeros(total_trials)
    trial_max_input = np.zeros([total_trials, 1])
    trial_max_rate = np.zeros([total_trials, 1])
    trial_rt = np.zeros([total_trials, 1])

    report_info.contrast_accuracy = np.zeros([n_contrasts, 1])

    # Network settings copied from the first trial of each contrast level.
    shared_attrs = ('wta_params', 'voxel_params', 'num_groups',
                    'trial_duration', 'background_rate', 'stim_start_time',
                    'stim_end_time', 'network_group_size',
                    'background_input_size', 'task_input_size')

    bold_means = []
    for c_idx, contrast in enumerate(contrast_range):
        report_info.contrast_accuracy[c_idx] = 0.0
        first_row = c_idx * num_trials
        for t_idx in range(num_trials):
            file_name = '%s.contrast.%0.4f.trial.%d.h5' % (file_prefix,
                                                           contrast, t_idx)
            print('opening %s' % file_name)
            data = FileInfo(file_name)

            if t_idx == 0:
                for attr in shared_attrs:
                    setattr(report_info, attr, getattr(data, attr))

            row = first_row + t_idx
            trial = create_trial_report(
                data,
                reports_dir,
                contrast,
                t_idx,
                regenerate_plots=regenerate_trial_plots)
            trial_contrast[row] = trial.input_contrast
            if math.isnan(trial.max_bold):
                trial_max_bold[row] = _extrapolate_missing_bold(bold_means,
                                                                c_idx)
            else:
                trial_max_bold[row] = trial.max_bold

            report_info.contrast_accuracy[c_idx] += trial.correct

            trial_max_input[row] = trial.max_input
            trial_max_rate[row] = trial.max_rate
            trial_rt[row] = trial.rt
            report_info.trials.append(trial)
        report_info.contrast_accuracy[c_idx] /= float(num_trials)
        bold_means.append(
            np.mean(trial_max_bold[first_row:first_row + num_trials]))

    io_fit = LinearRegression()
    io_fit.fit(trial_max_input, trial_max_rate)
    slope = io_fit.coef_[0]
    intercept = io_fit.intercept_
    report_info.io_slope = slope
    report_info.io_intercept = intercept
    report_info.io_r_sqr = io_fit.score(trial_max_input, trial_max_rate)

    io_url = 'img/input_output_rate.png'
    io_fname = os.path.join(reports_dir, io_url)
    report_info.input_output_rate_url = io_url
    if regenerate_network_plots or not os.path.exists(io_fname):
        fig = plt.figure()
        plt.plot(trial_max_input, trial_max_rate, 'x')
        lo = np.min(trial_max_input)
        hi = np.max(trial_max_input)
        # Identity line, then the fitted regression line.
        plt.plot([lo, hi], [lo, hi], '--')
        plt.plot([lo, hi], [slope * lo + intercept, slope * hi + intercept],
                 '--')
        plt.xlabel('Max Input Rate')
        plt.ylabel('Max Population Rate')
        save_to_png(fig, io_fname)
        save_to_eps(fig, os.path.join(reports_dir,
                                      'img/input_output_rate.eps'))
        plt.close()

    report_info.bold = create_bold_report(
        reports_dir,
        trial_contrast,
        trial_max_bold,
        trial_max_rate,
        trial_rt,
        regenerate_plot=regenerate_network_plots)

    report_info.roc = create_roc_report(
        file_prefix,
        report_info.num_groups,
        contrast_range,
        num_trials,
        reports_dir,
        regenerate_plot=regenerate_network_plots)

    # Render the HTML report.
    env = Environment(loader=FileSystemLoader(TEMPLATE_DIR))
    template = env.get_template('wta_network_instance.html')
    html_path = os.path.join(reports_dir,
                             'wta_network.%s.html' % data_file_prefix)
    template.stream(rinfo=report_info).dump(html_path)

    return report_info
예제 #6
0
def create_all_reports(data_dir,
                       num_groups,
                       trial_duration,
                       p_b_e_range,
                       p_x_e_range,
                       p_e_e_range,
                       p_e_i_range,
                       p_i_i_range,
                       p_i_e_range,
                       contrast_range,
                       num_trials,
                       e_desc,
                       base_report_dir,
                       regenerate_network_plots=True,
                       regenerate_trial_plots=True,
                       smooth_missing_params=False,
                       summary_filename='wta_network_summary.h5'):
    """Report every tested parameter combination and summarize them.

    For each combination returned by ``get_tested_param_combos`` whose values
    all appear in the given ranges, a per-network report is built when all
    trials exist. The collected statistics fill a ``SummaryData`` (written to
    ``<base_report_dir>/<summary_filename>``). That data feeds two Bayesian
    analyses, rendered under ``bold-contrast`` and ``bold-firing_rate``.

    Args:
        data_dir: directory holding the trial data files.
        num_groups, trial_duration: network settings used in file names.
        p_*_range: lists of tested values for each connection parameter.
        contrast_range, num_trials: trials expected per network.
        e_desc: experiment description used in file names and reports.
        base_report_dir: root directory for all generated reports.
        regenerate_network_plots, regenerate_trial_plots: redraw cached
            figures.
        smooth_missing_params: forwarded to ``SummaryData.fill``.
        summary_filename: file name of the written summary data.
    """
    make_report_dirs(base_report_dir)

    summary_data = SummaryData(num_groups=num_groups,
                               num_trials=num_trials,
                               trial_duration=trial_duration,
                               p_b_e_range=p_b_e_range,
                               p_x_e_range=p_x_e_range,
                               p_e_e_range=p_e_e_range,
                               p_e_i_range=p_e_i_range,
                               p_i_i_range=p_i_i_range,
                               p_i_e_range=p_i_e_range)

    # Per-network statistics keyed by the tuple of each parameter's index in
    # its *_range list; handed to SummaryData.fill below.
    bc_slope_dict = {}
    bc_intercept_dict = {}
    bc_r_sqr_dict = {}
    auc_dict = {}
    bfr_slope_dict = {}
    bfr_intercept_dict = {}
    bfr_r_sqr_dict = {}

    param_combos = get_tested_param_combos(data_dir, num_groups,
                                           trial_duration, contrast_range,
                                           num_trials, e_desc)

    # Same statistics keyed by the raw parameter values, for the summary
    # templates; 0 marks combinations without a complete set of trials.
    report_info = Struct()
    report_info.edesc = e_desc
    report_info.roc_auc = {}
    report_info.bc_slope = {}
    report_info.bc_intercept = {}
    report_info.bc_r_sqr = {}
    report_info.bfr_slope = {}
    report_info.bfr_intercept = {}
    report_info.bfr_r_sqr = {}

    for (p_b_e, p_x_e, p_e_e, p_e_i, p_i_i, p_i_e) in param_combos:
        if not (p_b_e in p_b_e_range and p_x_e in p_x_e_range
                and p_e_e in p_e_e_range and p_e_i in p_e_i_range
                and p_i_i in p_i_i_range and p_i_e in p_i_e_range):
            continue

        param_key = (p_b_e, p_x_e, p_e_e, p_e_i, p_i_i, p_i_e)
        index_key = (p_b_e_range.index(round(p_b_e, 3)),
                     p_x_e_range.index(round(p_x_e, 3)),
                     p_e_e_range.index(round(p_e_e, 3)),
                     p_e_i_range.index(round(p_e_i, 3)),
                     p_i_i_range.index(round(p_i_i, 3)),
                     p_i_e_range.index(round(p_i_e, 3)))

        file_desc='wta.groups.%d.duration.%0.3f.p_b_e.%0.3f.p_x_e.%0.3f.p_e_e.%0.3f.p_e_i.%0.3f.p_i_i.%0.3f.p_i_e.%0.3f.%s' %\
                  (num_groups, trial_duration, p_b_e, p_x_e, p_e_e, p_e_i, p_i_i, p_i_e, e_desc)
        file_prefix = os.path.join(data_dir, file_desc)
        reports_dir = os.path.join(base_report_dir, file_desc)
        if all_trials_exist(file_prefix, contrast_range, num_trials):
            print('Creating report for %s' % file_desc)
            wta_report = create_wta_network_report(
                file_prefix,
                contrast_range,
                num_trials,
                reports_dir,
                e_desc,
                regenerate_network_plots=regenerate_network_plots,
                regenerate_trial_plots=regenerate_trial_plots)
            bold = wta_report.bold
            auc = wta_report.roc.auc

            bc_slope_dict.setdefault(index_key, []).append(
                bold.bold_contrast_slope)
            bc_intercept_dict.setdefault(index_key, []).append(
                bold.bold_contrast_intercept)
            bc_r_sqr_dict.setdefault(index_key, []).append(
                bold.bold_contrast_r_sqr)
            auc_dict.setdefault(index_key, []).append(auc)
            bfr_slope_dict.setdefault(index_key, []).append(
                bold.bold_firing_rate_slope)
            bfr_intercept_dict.setdefault(index_key, []).append(
                bold.bold_firing_rate_intercept)
            bfr_r_sqr_dict.setdefault(index_key, []).append(
                bold.bold_firing_rate_r_sqr)

            report_info.roc_auc[param_key] = auc
            report_info.bc_slope[param_key] = bold.bold_contrast_slope
            report_info.bc_intercept[param_key] = bold.bold_contrast_intercept
            report_info.bc_r_sqr[param_key] = bold.bold_contrast_r_sqr
            report_info.bfr_slope[param_key] = bold.bold_firing_rate_slope
            report_info.bfr_intercept[param_key] = \
                bold.bold_firing_rate_intercept
            report_info.bfr_r_sqr[param_key] = bold.bold_firing_rate_r_sqr
        else:
            # Incomplete data: record zeros so every combination has an entry.
            # (Previously the last table was misspelled `brf_r_sqr`, which
            # raised AttributeError here.)
            for table in (report_info.roc_auc, report_info.bc_slope,
                          report_info.bc_intercept, report_info.bc_r_sqr,
                          report_info.bfr_slope, report_info.bfr_intercept,
                          report_info.bfr_r_sqr):
                table[param_key] = 0

    report_info.num_groups = num_groups
    report_info.trial_duration = trial_duration
    report_info.num_trials = num_trials
    report_info.p_b_e_range = p_b_e_range
    report_info.p_x_e_range = p_x_e_range
    report_info.p_e_e_range = p_e_e_range
    report_info.p_e_i_range = p_e_i_range
    report_info.p_i_i_range = p_i_i_range
    report_info.p_i_e_range = p_i_e_range

    summary_data.fill(auc_dict,
                      bc_slope_dict,
                      bc_intercept_dict,
                      bc_r_sqr_dict,
                      bfr_slope_dict,
                      bfr_intercept_dict,
                      bfr_r_sqr_dict,
                      smooth_missing_params=smooth_missing_params)

    summary_data.write_to_file(os.path.join(base_report_dir, summary_filename))

    bc_bayes_analysis = run_bayesian_analysis(
        summary_data.auc, summary_data.bc_slope, summary_data.bc_intercept,
        summary_data.bc_r_sqr, num_trials, p_b_e_range, p_e_e_range,
        p_e_i_range, p_i_e_range, p_i_i_range, p_x_e_range)

    bfr_bayes_analysis = run_bayesian_analysis(
        summary_data.auc, summary_data.bfr_slope, summary_data.bfr_intercept,
        summary_data.bfr_r_sqr, num_trials, p_b_e_range, p_e_e_range,
        p_e_i_range, p_i_e_range, p_i_i_range, p_x_e_range)

    bc_base_dir = os.path.join(base_report_dir, 'bold-contrast')
    make_report_dirs(bc_base_dir)
    render_summary_report(bc_base_dir, bc_bayes_analysis, p_b_e_range,
                          p_e_e_range, p_e_i_range, p_i_e_range, p_i_i_range,
                          p_x_e_range, report_info)

    bfr_base_dir = os.path.join(base_report_dir, 'bold-firing_rate')
    make_report_dirs(bfr_base_dir)
    render_summary_report(bfr_base_dir, bfr_bayes_analysis, p_b_e_range,
                          p_e_e_range, p_e_i_range, p_i_e_range, p_i_i_range,
                          p_x_e_range, report_info)
예제 #7
0
파일: summary.py 프로젝트: jbonaiuto/pySBI
def create_summary_report(summary_file_name, base_report_dir, e_desc):
    """Render the Bayesian summary reports from a saved summary data file.

    Reads ``summary_file_name`` into a ``SummaryData``, tabulates the
    per-combination statistics (zeros where ``auc`` is not positive), runs
    the BOLD-contrast and BOLD-firing-rate Bayesian analyses and renders each
    into its own subdirectory of ``base_report_dir``.
    """
    make_report_dirs(base_report_dir)

    summary=SummaryData()
    summary.read_from_file(summary_file_name)

    report_info=Struct()
    report_info.edesc=e_desc
    for field in ('roc_auc', 'io_slope', 'io_intercept', 'io_r_sqr', 'bc_slope', 'bc_intercept', 'bc_r_sqr',
                  'bfr_slope', 'bfr_intercept', 'bfr_r_sqr'):
        setattr(report_info, field, {})

    # report_info table -> SummaryData array it is read from (io_* stay empty).
    field_sources=(('roc_auc', 'auc'), ('bc_slope', 'bc_slope'), ('bc_intercept', 'bc_intercept'),
                   ('bc_r_sqr', 'bc_r_sqr'), ('bfr_slope', 'bfr_slope'), ('bfr_intercept', 'bfr_intercept'),
                   ('bfr_r_sqr', 'bfr_r_sqr'))
    for i,p_b_e in enumerate(summary.p_b_e_range):
        for j,p_x_e in enumerate(summary.p_x_e_range):
            for k,p_e_e in enumerate(summary.p_e_e_range):
                for l,p_e_i in enumerate(summary.p_e_i_range):
                    for m,p_i_i in enumerate(summary.p_i_i_range):
                        for n,p_i_e in enumerate(summary.p_i_e_range):
                            param_key=(p_b_e,p_x_e,p_e_e,p_e_i,p_i_i,p_i_e)
                            # A non-positive AUC marks an untested combination.
                            tested=summary.auc[i,j,k,l,m,n]>0
                            for field, source in field_sources:
                                if tested:
                                    value=getattr(summary, source)[i,j,k,l,m,n]
                                else:
                                    value=0
                                getattr(report_info, field)[param_key]=value

    report_info.num_groups=summary.num_groups
    report_info.trial_duration=summary.trial_duration
    report_info.num_trials=summary.num_trials
    report_info.p_b_e_range=summary.p_b_e_range
    report_info.p_x_e_range=summary.p_x_e_range
    # NOTE(review): the last p_e_e value is dropped here but not in
    # create_all_reports -- confirm this is intentional.
    report_info.p_e_e_range=summary.p_e_e_range[:-1]
    report_info.p_e_i_range=summary.p_e_i_range
    report_info.p_i_i_range=summary.p_i_i_range
    report_info.p_i_e_range=summary.p_i_e_range

    bc_bayes_analysis=run_bayesian_analysis(summary.auc, summary.bc_slope, summary.bc_intercept,
        summary.bc_r_sqr, summary.num_trials, summary.p_b_e_range, summary.p_e_e_range,
        summary.p_e_i_range, summary.p_i_e_range, summary.p_i_i_range, summary.p_x_e_range)

    bfr_bayes_analysis=run_bayesian_analysis(summary.auc, summary.bfr_slope, summary.bfr_intercept,
        summary.bfr_r_sqr, summary.num_trials, summary.p_b_e_range, summary.p_e_e_range,
        summary.p_e_i_range, summary.p_i_e_range, summary.p_i_i_range, summary.p_x_e_range)

    for subdir, analysis in (('bold-contrast', bc_bayes_analysis), ('bold-firing_rate', bfr_bayes_analysis)):
        out_dir=os.path.join(base_report_dir, subdir)
        make_report_dirs(out_dir)
        render_summary_report(out_dir, analysis, summary.p_b_e_range, summary.p_e_e_range,
            summary.p_e_i_range, summary.p_i_e_range, summary.p_i_i_range, summary.p_x_e_range,
            report_info)
예제 #8
0
파일: wta.py 프로젝트: jbonaiuto/pySBI
def _fit_and_plot_bold(reports_dir, predictor, trial_max_bold, plot_name, x_label, regenerate_plot):
    """Regress per-trial max BOLD on one predictor and optionally plot the fit.

    Fits ``trial_max_bold`` against ``predictor`` with ``LinearRegression``. When
    ``regenerate_plot`` is true or ``img/<plot_name>.png`` does not yet exist under
    ``reports_dir``, plots the samples plus the fitted line and saves the figure
    as both PNG and EPS.

    Returns:
        (slope, intercept, r_sqr, url): ``slope`` is ``coef_[0]``, ``r_sqr`` is the
        regression's ``score`` on the fitted data, and ``url`` is the PNG path
        relative to ``reports_dir``.
    """
    clf=LinearRegression()
    clf.fit(predictor, trial_max_bold)
    slope=clf.coef_[0]
    intercept=clf.intercept_
    r_sqr=clf.score(predictor, trial_max_bold)

    furl='img/%s.png' % plot_name
    fname=os.path.join(reports_dir, furl)
    if regenerate_plot or not os.path.exists(fname):
        fig=plt.figure()
        plt.plot(predictor, trial_max_bold, 'x')
        # Draw the fitted line across the observed predictor range.
        x_min=np.min(predictor)
        x_max=np.max(predictor)
        plt.plot([x_min,x_max],[slope*x_min+intercept,slope*x_max+intercept],'--')
        plt.xlabel(x_label)
        plt.ylabel('Max BOLD')
        save_to_png(fig, fname)
        save_to_eps(fig, os.path.join(reports_dir, 'img/%s.eps' % plot_name))
        # Close the figure created above explicitly rather than "the current one".
        plt.close(fig)
    return slope, intercept, r_sqr, furl


def create_bold_report(reports_dir, trial_contrast, trial_max_bold, trial_max_rate, trial_rt, regenerate_plot=True):
    """Relate per-trial max BOLD to input contrast, max firing rate and response time.

    Each predictor gets its own linear fit and scatter/fit figure via
    ``_fit_and_plot_bold``.

    Args:
        reports_dir: report directory; figures are written to its img/ subdirectory.
        trial_contrast, trial_max_rate, trial_rt: per-trial predictors, passed as
            the X argument of LinearRegression.fit.
        trial_max_bold: per-trial max BOLD, the regression target.
        regenerate_plot: redraw figures even when the PNG already exists.

    Returns:
        Struct with bold_contrast_*, bold_firing_rate_* and bold_rt_* slope /
        intercept / r_sqr fields, plus contrast_bold_url, firing_rate_bold_url and
        response_time_bold_url.
    """
    report_info=Struct()

    (report_info.bold_contrast_slope, report_info.bold_contrast_intercept,
     report_info.bold_contrast_r_sqr, report_info.contrast_bold_url)=_fit_and_plot_bold(
        reports_dir, trial_contrast, trial_max_bold, 'contrast_bold', 'Input Contrast', regenerate_plot)

    (report_info.bold_firing_rate_slope, report_info.bold_firing_rate_intercept,
     report_info.bold_firing_rate_r_sqr, report_info.firing_rate_bold_url)=_fit_and_plot_bold(
        reports_dir, trial_max_rate, trial_max_bold, 'firing_rate_bold', 'Max Firing Rate', regenerate_plot)

    (report_info.bold_rt_slope, report_info.bold_rt_intercept,
     report_info.bold_rt_r_sqr, report_info.response_time_bold_url)=_fit_and_plot_bold(
        reports_dir, trial_rt, trial_max_bold, 'response_time_bold', 'Response Time', regenerate_plot)

    return report_info
예제 #9
0
파일: wta.py 프로젝트: jbonaiuto/pySBI
def create_trial_report(data, reports_dir, contrast, trial_idx, regenerate_plots=True):
    """Summarize a single WTA trial and render its diagnostic figures.

    Reads behavioral measures from ``data`` (a loaded trial-file object) and, for
    each recording present on it, writes a PNG and EPS figure under
    ``reports_dir/img`` named after ``contrast`` and ``trial_idx``. A figure is
    redrawn only when ``regenerate_plots`` is true or its PNG does not exist yet.

    Returns:
        Struct with input_freq, input_contrast, correct (1.0 or 0.0), rt,
        max_input, max_rate, max_bold and one relative ``*_url`` per figure (None
        when the corresponding recording is missing). ``max_bold`` is 0 when there
        is no voxel recording and NaN when the recording has no non-NaN sample.
    """
    def figure_paths(kind):
        # Relative PNG url plus absolute PNG and EPS paths for one per-trial figure.
        stem='img/%s.contrast.%0.4f.trial.%d' % (kind, contrast, trial_idx)
        return stem+'.png', os.path.join(reports_dir, stem+'.png'), os.path.join(reports_dir, stem+'.eps')

    trial = Struct()
    trial.input_freq=data.input_freq
    trial.input_contrast=abs(data.input_freq[0]-data.input_freq[1])/sum(data.input_freq)
    trial.correct=0.0
    # Option that received the stronger input; -1 when both inputs are equal.
    option_idx=-1
    if data.input_freq[0]>data.input_freq[1]:
        option_idx=0
    elif data.input_freq[1]>data.input_freq[0]:
        option_idx=1
    if option_idx>-1:
        # Correct when that option's excitatory group peaks higher than the other
        # group within samples 6500-7500 of the firing-rate traces.
        if np.max(data.e_firing_rates[option_idx, 6500:7500])>np.max(data.e_firing_rates[1 - option_idx, 6500:7500]):
            trial.correct=1.0
    trial.rt=data.rt

    max_input_idx=np.where(trial.input_freq==np.max(trial.input_freq))[0][0]
    trial.max_input=trial.input_freq[max_input_idx]
    trial.max_rate=np.max(data.e_firing_rates[max_input_idx])

    trial.e_raster_url = None
    trial.i_raster_url = None
    if data.e_spike_neurons is not None and data.i_spike_neurons is not None:
        furl, fname, eps_fname = figure_paths('e_raster')
        trial.e_raster_url = furl
        if regenerate_plots or not os.path.exists(fname):
            # Excitatory raster groups use 4/5 of network_group_size.
            e_group_sizes=[int(4*data.network_group_size/5) for _ in range(data.num_groups)]
            fig=plot_raster(data.e_spike_neurons, data.e_spike_times, e_group_sizes)
            save_to_png(fig, fname)
            save_to_eps(fig, eps_fname)
            plt.close()

        furl, fname, eps_fname = figure_paths('i_raster')
        trial.i_raster_url = furl
        if regenerate_plots or not os.path.exists(fname):
            # Inhibitory raster groups use the remaining 1/5 of network_group_size.
            i_group_sizes=[int(data.network_group_size/5) for _ in range(data.num_groups)]
            fig=plot_raster(data.i_spike_neurons, data.i_spike_times, i_group_sizes)
            save_to_png(fig, fname)
            save_to_eps(fig, eps_fname)
            plt.close()

    trial.firing_rate_url = None
    if data.e_firing_rates is not None and data.i_firing_rates is not None:
        furl, fname, eps_fname = figure_paths('firing_rate')
        trial.firing_rate_url = furl
        if regenerate_plots or not os.path.exists(fname):
            fig = plt.figure()
            # Top panel: excitatory groups; bottom panel: inhibitory groups.
            for panel, rates in ((211, data.e_firing_rates), (212, data.i_firing_rates)):
                ax = plt.subplot(panel)
                for i, pop_rate in enumerate(rates):
                    # Samples are plotted at 0.1 ms spacing.
                    ax.plot(np.array(range(len(pop_rate))) *.1, pop_rate / Hz, label='group %d' % i)
                plt.xlabel('Time (ms)')
                plt.ylabel('Firing Rate (Hz)')
            save_to_png(fig, fname)
            save_to_eps(fig, eps_fname)
            plt.close()

    trial.neural_state_url=None
    if data.neural_state_rec is not None:
        furl, fname, eps_fname = figure_paths('neural_state')
        trial.neural_state_url = furl
        if regenerate_plots or not os.path.exists(fname):
            conductances=(('g_ampa_r', 'AMPA-recurrent'), ('g_ampa_x', 'AMPA-task'), ('g_ampa_b', 'AMPA-backgrnd'),
                          ('g_nmda', 'NMDA'), ('g_gaba_a', 'GABA_A'))
            fig = plt.figure()
            for i in range(data.num_groups):
                times=np.array(range(len(data.neural_state_rec['g_ampa_r'][i*2])))*.1
                # Two recorded neurons per group (rows i*2 and i*2+1), one subplot
                # each, on a num_groups x 2 grid. The (nrows, ncols, index) form
                # keeps working when the index exceeds 9, unlike a 3-digit code.
                for col in range(2):
                    neuron_idx=i*2+col
                    ax = plt.subplot(data.num_groups, 2, neuron_idx+1)
                    for key, label in conductances:
                        ax.plot(times, data.neural_state_rec[key][neuron_idx] / nA, label=label)
                    plt.xlabel('Time (ms)')
                    plt.ylabel('Conductance (nA)')
            save_to_png(fig, fname)
            save_to_eps(fig, eps_fname)
            plt.close()

    trial.lfp_url = None
    if data.lfp_rec is not None:
        furl, fname, eps_fname = figure_paths('lfp')
        trial.lfp_url = furl
        if regenerate_plots or not os.path.exists(fname):
            fig = plt.figure()
            ax = plt.subplot(111)
            lfp=get_lfp_signal(data)
            ax.plot(np.array(range(len(lfp))), lfp / mA)
            plt.xlabel('Time (ms)')
            plt.ylabel('LFP (mA)')
            save_to_png(fig, fname)
            save_to_eps(fig, eps_fname)
            plt.close()

    trial.voxel_url = None
    trial.max_bold=0
    if data.voxel_rec is not None:
        # Peak BOLD over the non-NaN samples. When no valid sample exists report NaN
        # (not a numeric sentinel such as -1000), so callers that test
        # math.isnan(max_bold) can substitute an estimate instead of regressing on
        # a bogus value.
        valid_bold=[val for val in data.voxel_rec['y'][0] if not math.isnan(val)]
        trial.max_bold=max(valid_bold) if valid_bold else float('nan')
        furl, fname, eps_fname = figure_paths('voxel')
        trial.voxel_url = furl
        if regenerate_plots or not os.path.exists(fname):
            end_idx=int(data.trial_duration/ms/.1)
            fig = plt.figure()
            ax = plt.subplot(211)
            ax.plot(np.array(range(end_idx))*.1, data.voxel_rec['G_total'][0][:end_idx] / nA)
            plt.xlabel('Time (ms)')
            plt.ylabel('Total Synaptic Activity (nA)')
            ax = plt.subplot(212)
            ax.plot(np.array(range(len(data.voxel_rec['y'][0])))*.1*ms, data.voxel_rec['y'][0])
            plt.xlabel('Time (s)')
            plt.ylabel('BOLD')
            save_to_png(fig, fname)
            save_to_eps(fig, eps_fname)
            plt.close()
    return trial
예제 #10
0
파일: wta.py 프로젝트: jbonaiuto/pySBI
def create_all_reports(data_dir, num_groups, trial_duration, p_b_e_range, p_x_e_range, p_e_e_range, p_e_i_range,
                       p_i_i_range, p_i_e_range, contrast_range, num_trials, e_desc, base_report_dir, regenerate_network_plots=True,
                       regenerate_trial_plots=True, smooth_missing_params=False,
                       summary_filename='wta_network_summary.h5'):
    """Build per-network reports for every tested parameter combination, then the summaries.

    For each (p_b_e, p_x_e, p_e_e, p_e_i, p_i_i, p_i_e) combination returned by
    get_tested_param_combos that lies inside the given ranges:

    - If all of its trial files exist, build its network report and record its
      ROC AUC and BOLD regression statistics.
    - Otherwise record zeros for it.

    The collected statistics fill a SummaryData, which is written to
    ``base_report_dir/summary_filename``. Bayesian analyses of the
    BOLD-contrast and BOLD-firing-rate regressions are then rendered under
    ``base_report_dir``.
    """
    make_report_dirs(base_report_dir)

    summary_data=SummaryData(num_groups=num_groups, num_trials=num_trials, trial_duration=trial_duration,
        p_b_e_range=p_b_e_range, p_x_e_range=p_x_e_range, p_e_e_range=p_e_e_range, p_e_i_range=p_e_i_range,
        p_i_i_range=p_i_i_range, p_i_e_range=p_i_e_range)

    # Lists of measures keyed by the (i,j,k,l,m,n) positions of the parameter
    # values within their ranges; handed to SummaryData.fill below.
    bc_slope_dict={}
    bc_intercept_dict={}
    bc_r_sqr_dict={}
    auc_dict={}
    bfr_slope_dict={}
    bfr_intercept_dict={}
    bfr_r_sqr_dict={}

    param_combos=get_tested_param_combos(data_dir, num_groups, trial_duration, contrast_range, num_trials, e_desc)

    # Per-combination values for the summary templates, keyed by the parameter values.
    report_info=Struct()
    report_info.edesc=e_desc
    report_info.roc_auc={}
    report_info.bc_slope={}
    report_info.bc_intercept={}
    report_info.bc_r_sqr={}
    report_info.bfr_slope={}
    report_info.bfr_intercept={}
    report_info.bfr_r_sqr={}

    for (p_b_e,p_x_e,p_e_e,p_e_i,p_i_i,p_i_e) in param_combos:
        if not (p_b_e in p_b_e_range and p_x_e in p_x_e_range and p_e_e in p_e_e_range and p_e_i in p_e_i_range
                and p_i_i in p_i_i_range and p_i_e in p_i_e_range):
            continue
        param_key=(p_b_e,p_x_e,p_e_e,p_e_i,p_i_i,p_i_e)
        # Membership was tested with the exact values, so look them up the same way;
        # rounding first could miss, or mismatch, entries with more than 3 decimals.
        grid_idx=(p_b_e_range.index(p_b_e), p_x_e_range.index(p_x_e), p_e_e_range.index(p_e_e),
                  p_e_i_range.index(p_e_i), p_i_i_range.index(p_i_i), p_i_e_range.index(p_i_e))

        file_desc='wta.groups.%d.duration.%0.3f.p_b_e.%0.3f.p_x_e.%0.3f.p_e_e.%0.3f.p_e_i.%0.3f.p_i_i.%0.3f.p_i_e.%0.3f.%s' %\
                  (num_groups, trial_duration, p_b_e, p_x_e, p_e_e, p_e_i, p_i_i, p_i_e, e_desc)
        file_prefix=os.path.join(data_dir,file_desc)
        reports_dir=os.path.join(base_report_dir,file_desc)
        if all_trials_exist(file_prefix, contrast_range, num_trials):
            print('Creating report for %s' % file_desc)
            wta_report=create_wta_network_report(file_prefix, contrast_range, num_trials, reports_dir,
                e_desc, regenerate_network_plots=regenerate_network_plots, regenerate_trial_plots=regenerate_trial_plots)
            bold=wta_report.bold

            for measure_dict, value in ((bc_slope_dict, bold.bold_contrast_slope),
                                        (bc_intercept_dict, bold.bold_contrast_intercept),
                                        (bc_r_sqr_dict, bold.bold_contrast_r_sqr),
                                        (auc_dict, wta_report.roc.auc),
                                        (bfr_slope_dict, bold.bold_firing_rate_slope),
                                        (bfr_intercept_dict, bold.bold_firing_rate_intercept),
                                        (bfr_r_sqr_dict, bold.bold_firing_rate_r_sqr)):
                measure_dict.setdefault(grid_idx, []).append(value)

            report_info.roc_auc[param_key]=wta_report.roc.auc
            report_info.bc_slope[param_key]=bold.bold_contrast_slope
            report_info.bc_intercept[param_key]=bold.bold_contrast_intercept
            report_info.bc_r_sqr[param_key]=bold.bold_contrast_r_sqr
            report_info.bfr_slope[param_key]=bold.bold_firing_rate_slope
            report_info.bfr_intercept[param_key]=bold.bold_firing_rate_intercept
            report_info.bfr_r_sqr[param_key]=bold.bold_firing_rate_r_sqr
        else:
            # Incomplete data: show zeros in the summary tables.
            report_info.roc_auc[param_key]=0
            report_info.bc_slope[param_key]=0
            report_info.bc_intercept[param_key]=0
            report_info.bc_r_sqr[param_key]=0
            report_info.bfr_slope[param_key]=0
            report_info.bfr_intercept[param_key]=0
            report_info.bfr_r_sqr[param_key]=0

    report_info.num_groups=num_groups
    report_info.trial_duration=trial_duration
    report_info.num_trials=num_trials
    report_info.p_b_e_range=p_b_e_range
    report_info.p_x_e_range=p_x_e_range
    report_info.p_e_e_range=p_e_e_range
    report_info.p_e_i_range=p_e_i_range
    report_info.p_i_i_range=p_i_i_range
    report_info.p_i_e_range=p_i_e_range

    summary_data.fill(auc_dict, bc_slope_dict, bc_intercept_dict, bc_r_sqr_dict, bfr_slope_dict, bfr_intercept_dict,
        bfr_r_sqr_dict, smooth_missing_params=smooth_missing_params)

    summary_data.write_to_file(os.path.join(base_report_dir,summary_filename))

    bc_bayes_analysis=run_bayesian_analysis(summary_data.auc, summary_data.bc_slope, summary_data.bc_intercept,
        summary_data.bc_r_sqr, num_trials, p_b_e_range, p_e_e_range, p_e_i_range, p_i_e_range, p_i_i_range, p_x_e_range)

    bfr_bayes_analysis=run_bayesian_analysis(summary_data.auc, summary_data.bfr_slope, summary_data.bfr_intercept,
            summary_data.bfr_r_sqr, num_trials, p_b_e_range, p_e_e_range, p_e_i_range, p_i_e_range, p_i_i_range, p_x_e_range)

    bc_base_dir=os.path.join(base_report_dir, 'bold-contrast')
    make_report_dirs(bc_base_dir)
    render_summary_report(bc_base_dir, bc_bayes_analysis, p_b_e_range, p_e_e_range, p_e_i_range, p_i_e_range,
        p_i_i_range, p_x_e_range, report_info)

    bfr_base_dir=os.path.join(base_report_dir, 'bold-firing_rate')
    make_report_dirs(bfr_base_dir)
    render_summary_report(bfr_base_dir, bfr_bayes_analysis, p_b_e_range, p_e_e_range, p_e_i_range, p_i_e_range,
        p_i_i_range, p_x_e_range, report_info)
예제 #11
0
파일: wta.py 프로젝트: jbonaiuto/pySBI
def create_wta_network_report(file_prefix, contrast_range, num_trials, reports_dir, edesc, regenerate_network_plots=True,
                              regenerate_trial_plots=True):
    """Build the report for one network parameter setting across all contrasts and trials.

    Opens '<file_prefix>.contrast.<contrast>.trial.<i>.h5' for every contrast in
    contrast_range and every i in range(num_trials), then:

    - creates a per-trial report for each file;
    - fits and plots max population rate against max input rate;
    - adds the BOLD and ROC sub-reports;
    - renders the 'wta_network_instance.html' template into reports_dir.

    Args:
        file_prefix: data file path prefix; its basename also names the HTML output.
        contrast_range: input contrasts to load.
        num_trials: number of trials per contrast.
        reports_dir: directory receiving the HTML report and its img/ figures.
        edesc: experiment description stored on the report.
        regenerate_network_plots: redraw network-level figures even if present.
        regenerate_trial_plots: redraw per-trial figures even if present.

    Returns:
        Struct with the per-trial reports, per-contrast accuracy, the
        input/output regression (io_slope, io_intercept, io_r_sqr), the bold and
        roc sub-reports, and metadata copied from the trial files.
    """

    make_report_dirs(reports_dir)

    report_info=Struct()
    report_info.edesc=edesc
    report_info.trials=[]

    (data_dir, data_file_prefix) = os.path.split(file_prefix)

    # Per-trial regression inputs, indexed by j*num_trials+i. The predictors are
    # (total_trials, 1) column arrays (presumably the 2-D X that
    # LinearRegression.fit expects); trial_max_bold is 1-D.
    total_trials=num_trials*len(contrast_range)
    trial_contrast=np.zeros([total_trials,1])
    trial_max_bold=np.zeros(total_trials)
    trial_max_input=np.zeros([total_trials,1])
    trial_max_rate=np.zeros([total_trials,1])
    trial_rt=np.zeros([total_trials,1])

    report_info.contrast_accuracy=np.zeros([len(contrast_range),1])

    # Mean max BOLD of each contrast processed so far (one entry per contrast).
    max_bold=[]
    for j,contrast in enumerate(contrast_range):
        contrast_max_bold=[]
        report_info.contrast_accuracy[j]=0.0
        for i in range(num_trials):
            file_name='%s.contrast.%0.4f.trial.%d.h5' % (file_prefix, contrast, i)
            print('opening %s' % file_name)
            data=FileInfo(file_name)

            # Copy network metadata from the first trial of each contrast; a later
            # contrast overwrites these with its own first trial's values.
            if not i:
                report_info.wta_params=data.wta_params
                report_info.voxel_params=data.voxel_params
                report_info.num_groups=data.num_groups
                report_info.trial_duration=data.trial_duration
                report_info.background_rate=data.background_rate
                report_info.stim_start_time=data.stim_start_time
                report_info.stim_end_time=data.stim_end_time
                report_info.network_group_size=data.network_group_size
                report_info.background_input_size=data.background_input_size
                report_info.task_input_size=data.task_input_size

            trial_idx=j*num_trials+i
            trial = create_trial_report(data, reports_dir, contrast, i, regenerate_plots=regenerate_trial_plots)
            trial_contrast[trial_idx]=trial.input_contrast
            if not math.isnan(trial.max_bold):
                trial_max_bold[trial_idx]=trial.max_bold
            else:
                # No valid BOLD for this trial: estimate it from earlier contrasts'
                # mean max BOLD. Extrapolate linearly from the previous two when
                # both are below 1.0, else double the previous one when it is
                # below 1.0, otherwise use 1.0.
                if j>1 and max_bold[j-1]<1.0 and max_bold[j-2]<1.0:
                    trial_max_bold[trial_idx]=max_bold[j-1]+(max_bold[j-1]-max_bold[j-2])
                elif j>0 and max_bold[j-1]<1.0:
                    trial_max_bold[trial_idx]=max_bold[j-1]*2.0
                else:
                    trial_max_bold[trial_idx]=1.0

            report_info.contrast_accuracy[j]+=trial.correct

            trial_max_input[trial_idx]=trial.max_input
            trial_max_rate[trial_idx]=trial.max_rate
            trial_rt[trial_idx]=trial.rt
            report_info.trials.append(trial)

            contrast_max_bold.append(trial_max_bold[trial_idx])
        # Fraction of correct trials at this contrast.
        report_info.contrast_accuracy[j]/=float(num_trials)
        mean_contrast_bold=np.mean(np.array(contrast_max_bold))
        max_bold.append(mean_contrast_bold)

    # Linear fit of max population rate against max input rate.
    clf=LinearRegression()
    clf.fit(trial_max_input,trial_max_rate)
    a=clf.coef_[0]
    b=clf.intercept_
    report_info.io_slope=a
    report_info.io_intercept=b
    report_info.io_r_sqr=clf.score(trial_max_input,trial_max_rate)

    furl='img/input_output_rate.png'
    fname=os.path.join(reports_dir, furl)
    report_info.input_output_rate_url=furl
    if regenerate_network_plots or not os.path.exists(fname):
        fig=plt.figure()
        plt.plot(trial_max_input, trial_max_rate, 'x')
        x_min=np.min(trial_max_input)
        x_max=np.max(trial_max_input)
        # Identity line (output == input) for reference, then the fitted line.
        plt.plot([x_min,x_max],[x_min,x_max],'--')
        plt.plot([x_min,x_max],[a*x_min+b,a*x_max+b],'--')
        plt.xlabel('Max Input Rate')
        plt.ylabel('Max Population Rate')
        save_to_png(fig, fname)
        save_to_eps(fig, os.path.join(reports_dir, 'img/input_output_rate.eps'))
        plt.close()

    report_info.bold=create_bold_report(reports_dir, trial_contrast, trial_max_bold, trial_max_rate, trial_rt,
        regenerate_plot=regenerate_network_plots)

    report_info.roc=create_roc_report(file_prefix, report_info.num_groups, contrast_range, num_trials, reports_dir,
        regenerate_plot=regenerate_network_plots)

    #create report
    template_file='wta_network_instance.html'
    env = Environment(loader=FileSystemLoader(TEMPLATE_DIR))
    template=env.get_template(template_file)

    output_file='wta_network.%s.html' % data_file_prefix
    fname=os.path.join(reports_dir,output_file)
    stream=template.stream(rinfo=report_info)
    stream.dump(fname)

    return report_info
예제 #12
0
def create_summary_report(summary_file_name, base_report_dir, e_desc):
    """Render the BOLD-contrast and BOLD-firing-rate summary reports from a saved summary file.

    Loads ``summary_file_name`` into a SummaryData. Each measure is tabulated per
    parameter combination, with zeros wherever the AUC is not positive. The
    Bayesian analysis is then run for both BOLD regressions, and each result is
    rendered under its own subdirectory of ``base_report_dir``.
    """
    make_report_dirs(base_report_dir)

    summary_data = SummaryData()
    summary_data.read_from_file(summary_file_name)

    # Measures copied from summary_data into per-combination tables.
    measure_names = ('bc_slope', 'bc_intercept', 'bc_r_sqr',
                     'bfr_slope', 'bfr_intercept', 'bfr_r_sqr')

    report_info = Struct()
    report_info.edesc = e_desc
    report_info.roc_auc = {}
    # The io_* tables are created but never filled here.
    for table_name in ('io_slope', 'io_intercept', 'io_r_sqr') + measure_names:
        setattr(report_info, table_name, {})

    for i, p_b_e in enumerate(summary_data.p_b_e_range):
        for j, p_x_e in enumerate(summary_data.p_x_e_range):
            for k, p_e_e in enumerate(summary_data.p_e_e_range):
                for l, p_e_i in enumerate(summary_data.p_e_i_range):
                    for m, p_i_i in enumerate(summary_data.p_i_i_range):
                        for n, p_i_e in enumerate(summary_data.p_i_e_range):
                            grid_idx = (i, j, k, l, m, n)
                            param_key = (p_b_e, p_x_e, p_e_e, p_e_i, p_i_i, p_i_e)
                            # A non-positive AUC marks a combination without data.
                            has_data = summary_data.auc[grid_idx] > 0
                            report_info.roc_auc[param_key] = summary_data.auc[grid_idx] if has_data else 0
                            for name in measure_names:
                                table = getattr(report_info, name)
                                table[param_key] = getattr(summary_data, name)[grid_idx] if has_data else 0

    report_info.num_groups = summary_data.num_groups
    report_info.trial_duration = summary_data.trial_duration
    report_info.num_trials = summary_data.num_trials
    report_info.p_b_e_range = summary_data.p_b_e_range
    report_info.p_x_e_range = summary_data.p_x_e_range
    # The report lists p_e_e values without the last entry of the range.
    report_info.p_e_e_range = summary_data.p_e_e_range[:-1]
    report_info.p_e_i_range = summary_data.p_e_i_range
    report_info.p_i_i_range = summary_data.p_i_i_range
    report_info.p_i_e_range = summary_data.p_i_e_range

    param_ranges = (summary_data.p_b_e_range, summary_data.p_e_e_range,
                    summary_data.p_e_i_range, summary_data.p_i_e_range,
                    summary_data.p_i_i_range, summary_data.p_x_e_range)

    # Run both analyses first, then render them in the same order.
    analyses = [
        ('bold-contrast',
         run_bayesian_analysis(summary_data.auc, summary_data.bc_slope, summary_data.bc_intercept,
                               summary_data.bc_r_sqr, summary_data.num_trials, *param_ranges)),
        ('bold-firing_rate',
         run_bayesian_analysis(summary_data.auc, summary_data.bfr_slope, summary_data.bfr_intercept,
                               summary_data.bfr_r_sqr, summary_data.num_trials, *param_ranges)),
    ]

    for subdir, analysis in analyses:
        out_dir = os.path.join(base_report_dir, subdir)
        make_report_dirs(out_dir)
        render_summary_report(out_dir, analysis, *param_ranges, report_info)
예제 #13
0
파일: bayesian.py 프로젝트: jbonaiuto/pySBI
def create_bayesian_report(title, num_groups, trial_duration, roc_auc, bc_slope, bc_intercept, bc_r_sqr, evidence,
                           posterior, marginals, p_b_e_range, p_x_e_range, p_e_e_range, p_e_i_range, p_i_i_range,
                           p_i_e_range, file_prefix, reports_dir, edesc, marginal_ylim):
    """Assemble the report data for one Bayesian analysis result.

    Copies the inputs onto a Struct and stores every positive posterior entry
    keyed by its parameter values. It then renders the marginal prior,
    likelihood and posterior figures for each of the six parameters, plus seven
    pairwise joint-marginal posterior figures, and records their URLs.

    Returns:
        Struct with the copied inputs, ``posterior`` (dict of positive entries),
        ``marginal_{prior,likelihood,posterior}_<param>_url`` for each parameter
        and ``joint_marginal_<param1>_<param2>_url`` for each rendered pair.
    """
    report_info=Struct()
    report_info.title=title
    report_info.edesc=edesc
    report_info.evidence=evidence
    report_info.roc_auc=roc_auc
    report_info.bc_slope=bc_slope
    report_info.bc_intercept=bc_intercept
    report_info.bc_r_sqr=bc_r_sqr
    report_info.num_groups=num_groups
    report_info.trial_duration=trial_duration
    report_info.p_b_e_range=p_b_e_range
    report_info.p_x_e_range=p_x_e_range
    report_info.p_e_e_range=p_e_e_range
    report_info.p_e_i_range=p_e_i_range
    report_info.p_i_i_range=p_i_i_range
    report_info.p_i_e_range=p_i_e_range

    # Keep only the combinations with positive posterior mass.
    positive_posterior={}
    for i,p_b_e in enumerate(p_b_e_range):
        for j,p_x_e in enumerate(p_x_e_range):
            for k,p_e_e in enumerate(p_e_e_range):
                for l,p_e_i in enumerate(p_e_i_range):
                    for m,p_i_i in enumerate(p_i_i_range):
                        for n,p_i_e in enumerate(p_i_e_range):
                            value=posterior[i,j,k,l,m,n]
                            if value>0:
                                positive_posterior[(p_b_e,p_x_e,p_e_e,p_e_i,p_i_i,p_i_e)]=value
    report_info.posterior=positive_posterior

    param_ranges=(('p_b_e', p_b_e_range), ('p_x_e', p_x_e_range), ('p_e_e', p_e_e_range),
                  ('p_e_i', p_e_i_range), ('p_i_i', p_i_i_range), ('p_i_e', p_i_e_range))

    # One prior/likelihood/posterior marginal figure set per parameter.
    for param, values in param_ranges:
        prior_url, likelihood_url, posterior_url=render_marginal_report(param, values,
            getattr(marginals, 'prior_'+param), getattr(marginals, 'likelihood_'+param),
            getattr(marginals, 'posterior_'+param), file_prefix, reports_dir, marginal_ylim)
        setattr(report_info, 'marginal_prior_%s_url' % param, prior_url)
        setattr(report_info, 'marginal_likelihood_%s_url' % param, likelihood_url)
        setattr(report_info, 'marginal_posterior_%s_url' % param, posterior_url)

    # Parameter pairs whose joint marginal posterior is rendered.
    joint_pairs=(('p_b_e', 'p_x_e'), ('p_e_e', 'p_e_i'), ('p_e_e', 'p_i_i'), ('p_e_e', 'p_i_e'),
                 ('p_e_i', 'p_i_i'), ('p_e_i', 'p_i_e'), ('p_i_i', 'p_i_e'))
    range_of=dict(param_ranges)
    for first, second in joint_pairs:
        url=render_joint_marginal_report(first, second, range_of[first], range_of[second],
            getattr(marginals, 'posterior_%s_%s' % (first, second)), file_prefix, reports_dir)
        setattr(report_info, 'joint_marginal_%s_%s_url' % (first, second), url)

    return report_info
예제 #14
0
def create_bayesian_report(title, num_groups, trial_duration, roc_auc,
                           bc_slope, bc_intercept, bc_r_sqr, evidence,
                           posterior, marginals, p_b_e_range, p_x_e_range,
                           p_e_e_range, p_e_i_range, p_i_i_range, p_i_e_range,
                           file_prefix, reports_dir, edesc, marginal_ylim):
    """Build the Bayesian-analysis report object and render its figures.

    The returned ``Struct`` carries:

    * the scalar summary values passed in (title, edesc, evidence, roc_auc,
      bc_slope, bc_intercept, bc_r_sqr, num_groups, trial_duration);
    * the six parameter ranges as ``<param>_range`` attributes;
    * ``posterior``: a dict mapping each parameter combination
      ``(p_b_e, p_x_e, p_e_e, p_e_i, p_i_i, p_i_e)`` to its posterior value,
      containing only the entries whose value is ``> 0``;
    * ``marginal_{prior,likelihood,posterior}_<param>_url``: the three values
      returned by ``render_marginal_report`` for each parameter;
    * ``joint_marginal_<a>_<b>_url``: the value returned by
      ``render_joint_marginal_report`` for each of seven parameter pairs.

    Parameters
    ----------
    posterior :
        Indexed as ``posterior[i, j, k, l, m, n]``, with one index per
        parameter range in the order p_b_e, p_x_e, p_e_e, p_e_i, p_i_i,
        p_i_e. This presumably is a NumPy-style 6-D array (TODO confirm).
    marginals :
        Must expose ``prior_<param>``, ``likelihood_<param>`` and
        ``posterior_<param>`` for every parameter. It must also expose
        ``posterior_<a>_<b>`` for every joint pair rendered below.
    p_b_e_range, p_x_e_range, p_e_e_range, p_e_i_range, p_i_i_range, p_i_e_range :
        The values each parameter was evaluated at. They are stored on the
        report and forwarded to the rendering helpers.
    file_prefix, reports_dir :
        Forwarded unchanged to the rendering helpers.
    marginal_ylim :
        Forwarded to ``render_marginal_report`` only.

    Returns
    -------
    Struct
        The populated report object.
    """
    report = Struct()
    report.title = title
    report.edesc = edesc
    report.evidence = evidence
    report.roc_auc = roc_auc
    report.bc_slope = bc_slope
    report.bc_intercept = bc_intercept
    report.bc_r_sqr = bc_r_sqr
    report.num_groups = num_groups
    report.trial_duration = trial_duration
    report.p_b_e_range = p_b_e_range
    report.p_x_e_range = p_x_e_range
    report.p_e_e_range = p_e_e_range
    report.p_e_i_range = p_e_i_range
    report.p_i_i_range = p_i_i_range
    report.p_i_e_range = p_i_e_range

    # Keep only combinations with non-zero posterior mass. The generator
    # clauses nest in the same order as the array axes, so keys are inserted
    # in row-major order.
    report.posterior = {
        (p_b_e, p_x_e, p_e_e, p_e_i, p_i_i, p_i_e): posterior[i0, i1, i2, i3, i4, i5]
        for i0, p_b_e in enumerate(p_b_e_range)
        for i1, p_x_e in enumerate(p_x_e_range)
        for i2, p_e_e in enumerate(p_e_e_range)
        for i3, p_e_i in enumerate(p_e_i_range)
        for i4, p_i_i in enumerate(p_i_i_range)
        for i5, p_i_e in enumerate(p_i_e_range)
        if posterior[i0, i1, i2, i3, i4, i5] > 0
    }

    # Parameter names double as the attribute-name stems on `marginals` and
    # on the report, and as the label passed to the renderers.
    param_names = ('p_b_e', 'p_x_e', 'p_e_e', 'p_e_i', 'p_i_i', 'p_i_e')
    param_ranges = dict(zip(param_names, (p_b_e_range, p_x_e_range,
                                          p_e_e_range, p_e_i_range,
                                          p_i_i_range, p_i_e_range)))

    for name in param_names:
        prior_url, likelihood_url, posterior_url = render_marginal_report(
            name, param_ranges[name],
            getattr(marginals, 'prior_' + name),
            getattr(marginals, 'likelihood_' + name),
            getattr(marginals, 'posterior_' + name),
            file_prefix, reports_dir, marginal_ylim)
        setattr(report, 'marginal_prior_%s_url' % name, prior_url)
        setattr(report, 'marginal_likelihood_%s_url' % name, likelihood_url)
        setattr(report, 'marginal_posterior_%s_url' % name, posterior_url)

    # Pairs for which a 2-D joint posterior marginal is rendered, in
    # rendering order.
    joint_pairs = (
        ('p_b_e', 'p_x_e'),
        ('p_e_e', 'p_e_i'),
        ('p_e_e', 'p_i_i'),
        ('p_e_e', 'p_i_e'),
        ('p_e_i', 'p_i_i'),
        ('p_e_i', 'p_i_e'),
        ('p_i_i', 'p_i_e'),
    )
    for first, second in joint_pairs:
        pair = '%s_%s' % (first, second)
        url = render_joint_marginal_report(
            first, second, param_ranges[first], param_ranges[second],
            getattr(marginals, 'posterior_' + pair), file_prefix, reports_dir)
        setattr(report, 'joint_marginal_%s_url' % pair, url)

    return report