예제 #1
0
def test_update_nr_samples(history: History):
    """After update_nr_samples, the PRE_TIME population reports the new sample count."""
    history.store_initial_data(None, {}, {}, {}, ["m0"], "", "", "")

    def pre_time_samples():
        # Fetch a fresh populations frame on every call and keep only the
        # PRE_TIME row(s)' sample counts.
        populations = history.get_all_populations()
        is_pre_time = populations['t'] == History.PRE_TIME
        return populations[is_pre_time]['samples'].values

    assert 0 == pre_time_samples()

    history.update_nr_samples(History.PRE_TIME, 43)
    assert 43 == pre_time_samples()
예제 #2
0
def test_update_after_calibration(history: History):
    """update_after_calibration stores the PRE_TIME sample count and the end time."""
    history.store_initial_data(None, {}, {}, {}, ["m0"], "", "", "")

    def read_pre_time():
        # Return the freshly fetched populations frame together with the
        # sample counts of its PRE_TIME row(s).
        frame = history.get_all_populations()
        return frame, frame[frame['t'] == History.PRE_TIME]['samples'].values

    _, samples = read_pre_time()
    assert 0 == samples

    end_time = datetime.datetime.now()
    history.update_after_calibration(43, end_time=end_time)

    pops, samples = read_pre_time()
    assert 43 == samples
    assert pops.population_end_time[0] == end_time
예제 #3
0
def test_population_retrieval(history: History):
    """Appended populations come back with their epsilon, sample count and alive models.

    Generation t=2 is appended twice; both rows must be retrievable in
    insertion order.
    """
    # NOTE(review): judging by the alive_models assertions below, the argument
    # of rand_pop_list appears to select the model index — confirm.
    history.append_population(
        1, .23, Population(rand_pop_list(0)), 234, ["m1"])
    history.append_population(
        2, .123, Population(rand_pop_list(0)), 345, ["m1"])
    history.append_population(
        2, .1235, Population(rand_pop_list(5)), 20345, ["m1"] * 6)
    history.append_population(
        3, .12330, Population(rand_pop_list(30)), 30345, ["m1"] * 31)
    df = history.get_all_populations()

    # Expected rows per generation, in insertion order: (epsilon, samples).
    expected = {
        1: [(.23, 234)],
        2: [(.123, 345), (.1235, 20345)],
        3: [(.12330, 30345)],
    }
    for t, rows in expected.items():
        generation = df[df.t == t]
        for i, (epsilon, samples) in enumerate(rows):
            assert generation.epsilon.iloc[i] == epsilon
            assert generation.samples.iloc[i] == samples

    assert history.alive_models(1) == [0]
    assert history.alive_models(2) == [0, 5]
    assert history.alive_models(3) == [30]
예제 #4
0
파일: plots.py 프로젝트: Sandalmoth/ratrack
def abc_info(paramfile, obsfile, dbfile, run_id, save):
    """
    Plots for examining ABC fitting process.

    Produces three figures for the ABC run ``run_id`` stored in the SQLite
    database ``dbfile``:

    1. model probabilities per generation (bar chart),
    2. ABC diagnostics per generation (particles, epsilon, samples),
    3. per-parameter median and inter-quartile band over generations,
       one colour per model.

    :param paramfile: parameter file, passed to ``simtools.parse_params``.
    :param obsfile: observation file, passed to ``simtools.parse_observations``.
    :param dbfile: path to the SQLite database holding the ABC history.
    :param run_id: id of the ABC run inside the database.
    :param save: path of a multi-page PDF to write the figures to, or ``None``
        to show each figure with ``plt.show()`` instead.
    """
    abc_history = History('sqlite:///' + dbfile)
    abc_history.id = run_id

    observed = simtools.parse_observations(obsfile)
    # Called for its side effect; simtools.PARAMS is read when plotting.
    # NOTE(review): presumably parse_params populates PARAMS — confirm.
    simtools.parse_params(paramfile, observed)

    # NOTE(review): assumes the models alive at t=0 are numbered 0..n-1.
    num_models = abc_history.nr_of_models_alive(0)

    pdf_out = None
    try:
        _abc_info_plot_model_probabilities(abc_history)
        if save is not None:
            # Open the multi-page PDF only once the first figure exists.
            pdf_out = PdfPages(save)
        _abc_info_emit(pdf_out)

        _abc_info_plot_diagnostics(abc_history)
        _abc_info_emit(pdf_out)

        parameters = _abc_info_parameter_names(abc_history, num_models)
        if parameters:
            _abc_info_plot_parameters(abc_history, num_models, parameters)
            _abc_info_emit(pdf_out)
    finally:
        # Close the PDF even if a later plot fails, so the file handle is
        # released and the pages written so far are finalized.
        if pdf_out is not None:
            pdf_out.close()


def _abc_info_emit(pdf_out):
    """Append the current figure to ``pdf_out`` and close it; show it if ``pdf_out`` is None."""
    if pdf_out is not None:
        pdf_out.savefig()
        # Close saved figures so they do not pile up in pyplot's registry.
        plt.close()
    else:
        plt.show()


def _abc_info_plot_model_probabilities(abc_history):
    """Draw a bar chart of model probabilities per generation."""
    axs = abc_history.get_model_probabilities().plot.bar()
    axs.set_ylabel("Probability")
    axs.set_xlabel("Generation")
    limits = simtools.PARAMS['abc_params']['resolution_limits']
    # Inclusive range of resolutions, one legend entry per model.
    resolutions = list(range(limits[0], limits[1] + 1))
    axs.legend(resolutions,
               title="Reconstruction resolution")


def _abc_info_plot_diagnostics(abc_history):
    """Draw particles, epsilon and samples per generation in three stacked axes."""
    fig, ax = plt.subplots(nrows=3, sharex=True)

    t_axis = list(range(abc_history.max_t + 1))

    populations = abc_history.get_all_populations()
    # Keep only rows with t >= 0 (drops any pre-time/calibration row with a
    # negative t). NOTE(review): assumes exactly one row per generation so the
    # lengths match t_axis — confirm.
    populations = populations[populations.t >= 0]

    ax[0].plot(t_axis, populations['particles'])
    ax[1].plot(t_axis, populations['epsilon'])
    ax[2].plot(t_axis, populations['samples'])

    ax[0].set_title('ABC parameters per generation')
    ax[0].set_ylabel('Particles')
    ax[1].set_ylabel('Epsilon')
    ax[2].set_ylabel('Samples')
    ax[-1].set_xlabel('Generation (t)')
    ax[0].xaxis.set_major_locator(MaxNLocator(integer=True))

    fig.set_size_inches(8, 5)


def _abc_info_parameter_names(abc_history, num_models):
    """Return the union of parameter column names over all models at t=0, in deterministic order."""
    names = set()
    for m in range(num_models):
        abc_data, __ = abc_history.get_distribution(m=m, t=0)
        names.update(abc_data.columns)
    # Group by the last character (the original ordering key). Break ties on
    # the full name so the order does not depend on set iteration order.
    return sorted(names, key=lambda name: (name[-1], name))


def _abc_info_parameter_quartiles(abc_history, m, t_axis, parameters):
    """Return (q1, median, q3) dicts mapping parameter -> per-generation values for model ``m``.

    Entries stay NaN for generations in which the model lacks the parameter.
    """
    n_generations = len(t_axis)
    qs1 = {param: np.full(n_generations, np.nan) for param in parameters}
    medians = {param: np.full(n_generations, np.nan) for param in parameters}
    qs3 = {param: np.full(n_generations, np.nan) for param in parameters}

    for i, generation in enumerate(t_axis):
        # NOTE(review): the particle weights (second return value) are
        # ignored, so these are unweighted percentiles — confirm intended.
        abc_data, __ = abc_history.get_distribution(m=m, t=generation)
        for param in parameters:
            if param not in abc_data:
                continue
            qs1[param][i], medians[param][i], qs3[param][i] = np.percentile(
                np.asarray(abc_data[param]), [25, 50, 75]
            )
    return qs1, medians, qs3


def _abc_info_plot_parameters(abc_history, num_models, parameters):
    """Draw one row per parameter: each model's median over generations with an IQR band."""
    t_axis = np.arange(abc_history.max_t + 1)
    # One row per parameter in the union over all models; a per-model maximum
    # can be smaller than the union and overflow the axes array. squeeze=False
    # keeps a 2-D array even for a single row, so indexing works uniformly.
    __, axs = plt.subplots(nrows=len(parameters), sharex=True, sharey=True,
                           squeeze=False)
    axs = axs[:, 0]

    for m in range(num_models):
        qs1, medians, qs3 = _abc_info_parameter_quartiles(
            abc_history, m, t_axis, parameters)
        for ax, param in zip(axs, parameters):
            # Skip parameters this model never has (median NaN everywhere).
            if np.all(np.isnan(medians[param])):
                continue
            ax.plot(t_axis, medians[param], color=COLORS[m])
            ax.fill_between(t_axis, qs1[param], qs3[param],
                            color=COLORS[m], alpha=0.2)
            # Strip the first 10 characters of the name. NOTE(review):
            # presumably a fixed prefix such as 'birthrate.' — confirm.
            ax.set_ylabel(param[10:])

    axs[-1].set_xlabel('Generation (t)')