예제 #1
0
def test_hdi_plot_bad_dimensions():
    """Check that hdi_plot raises ValueError when the sample shape does not fit x."""
    n_points = 10
    x_grid = linspace(0, 12, n_points)

    # Both axes of the sample have n_points + 1 entries, so neither matches
    # the length of x_grid.
    oversized = n_points + 1
    mismatched = array(
        [default_rng(1324).normal(size=oversized) for _ in range(oversized)]
    )

    with pytest.raises(ValueError):
        hdi_plot(x_grid, mismatched)
예제 #2
0
def test_hdi_plot():
    """Smoke-test hdi_plot: one filled region per interval and sensible axis limits."""
    n_points = 10
    x_lo, x_hi = 0, 12
    x_grid = linspace(x_lo, x_hi, n_points)

    # NOTE(review): the generator is seeded identically for every row, so
    # every row is the same draw; repeating one draw n_points times builds
    # exactly the same array. With identical samples the intervals presumably
    # collapse onto that single curve -- confirm this is intended, since the
    # min/max axis checks below may rely on it.
    single_curve = default_rng(1324).normal(size=n_points)
    samples = array([single_curve] * n_points)
    requested = [0.5, 0.65, 0.95]

    ax = hdi_plot(x_grid, samples, requested)

    # Counting PolyCollection children assumes hdi_plot draws exactly one
    # filled region per requested interval; this is an implementation detail
    # and the check will need updating if the plotting approach changes.
    poly_count = sum(
        isinstance(artist, PolyCollection) for artist in ax.get_children()
    )
    assert poly_count == len(requested)

    # Axes.axis() returns (xmin, xmax, ymin, ymax); the view must contain
    # the whole x range and every sampled value.
    x_min, x_max, y_min, y_max = ax.axis()
    assert x_min <= x_lo
    assert x_max >= x_hi
    assert y_min <= samples.min()
    assert y_max >= samples.max()
예제 #3
0
# Posterior-predictive plot for PeakModel using highest-density intervals.
# NOTE(review): relies on `chain`, `PeakModel`, `plt`, `array` and `linspace`
# being defined earlier in the script (not visible in this chunk).

# generate an axis on which to evaluate the model
x_fits = linspace(0, 12, 500)
# get the sample
sample = chain.get_sample()
# pass each through the forward model
# (presumably one row per posterior sample, one column per point of x_fits)
curves = array([PeakModel.forward_model(x_fits, theta) for theta in sample])

# We could plot the predictions for each sample all on a single graph, but this is
# often cluttered and difficult to interpret.

# A better option is to use the hdi_plot function from the plotting module to plot
# highest-density intervals for each point where the model is evaluated:
# NOTE(review): mid-script import, presumably kept here for tutorial narrative.
from inference.plotting import hdi_plot
fig = plt.figure(figsize=(8, 5))
ax = fig.add_subplot(111)
# draw the 68% and 95% highest-density intervals onto the explicit axes `ax`
hdi_plot(x_fits, curves, intervals=[0.68, 0.95], axis=ax)

# plot the MAP estimate (the sample with the single highest posterior probability)
MAP_prediction = PeakModel.forward_model(x_fits, chain.mode())
# dashed line drawn on top of the interval bands
ax.plot(x_fits,
        MAP_prediction,
        ls='dashed',
        lw=3,
        c='C0',
        label='MAP estimate')
ax.errorbar(x_data,
            y_data,
            yerr=y_error,
            linestyle='none',
            c='red',
예제 #4
0
def test_hdi_plot_bad_intervals():
    """Check that hdi_plot raises ValueError for an out-of-range interval (1.2 here).

    The x-axis and the sample are given compatible, well-formed shapes so that
    the out-of-range interval is the only invalid input. Previously the sample
    was 1-D, so a dimension check could raise the ValueError instead (or an
    IndexError, if validation order changes), and the test would not actually
    pin down interval validation.
    """
    n_points = 5
    intervals = [0.5, 0.65, 1.2, 0.95]
    x = zeros(n_points)
    # 2-D sample: one row per sample, one column per point of x. A square
    # shape matches x along either axis, so dimensions cannot be the problem.
    sample = zeros((n_points, n_points))

    with pytest.raises(ValueError):
        hdi_plot(x, sample, intervals)
예제 #5
0
# Summary plot of the PDF, then a highest-density-interval plot of the model
# predictions against the data.
# NOTE(review): relies on `pdf`, `chain`, `model`, `x_data`, `y_data`, `plt`,
# `array` and `linspace` being defined earlier in the script (not visible here).

# plot the PDF
pdf.plot_summary()

# You may also want to assess the level of uncertainty in the model predictions.
# This can be done easily by passing each sample through the forward-model
# and observing the resulting distribution of model predictions.

# generate an axis on which to evaluate the model
M = 500
x_fits = linspace(400, 450, M)
# get the sample
sample = chain.get_sample()
# pass each through the forward model
# (presumably one row per posterior sample, one column per point of x_fits)
curves = array([model.forward_model(x_fits, theta) for theta in sample])

plt.figure(figsize=(8, 5))

# We can use the hdi_plot function from the plotting module to plot
# highest-density intervals for each point where the model is evaluated:
# NOTE(review): mid-script import, presumably kept here for tutorial narrative.
from inference.plotting import hdi_plot
# No axis is passed, so hdi_plot presumably draws on the current axes of the
# figure created above -- confirm against the hdi_plot docs.
hdi_plot(x_fits, curves)

# build the rest of the plot
plt.plot(x_data, y_data, 'D', c='red', label='data')
plt.xlabel('wavelength (nm)')
plt.ylabel('intensity')
# view 410-440 nm, narrower than the 400-450 nm range the model is evaluated on
plt.xlim([410, 440])
plt.legend()
plt.grid()
plt.tight_layout()
plt.show()
예제 #6
0
# Gallery script: plot HDIs of HdiPosterior.forward over posterior samples,
# overlay the data, and save the figure to 'gallery_hdi.png'.
# NOTE(review): relies on `chain`, `HdiPosterior`, `x`, `y`, `s`, `plt`,
# `array` and `linspace` being defined earlier in the script (not visible here).

# NOTE(review): burn/thin presumably control which samples get_sample()
# returns below -- confirm against the chain class documentation.
chain.burn = 5000
chain.thin = 2

# Optional diagnostic plots, disabled for the gallery figure:
# chain.plot_diagnostics()
# chain.trace_plot()
# chain.matrix_plot()


x_fits = linspace(0,10,100)
sample = chain.get_sample()
# pass each through the forward model
# (presumably one row per posterior sample, one column per point of x_fits)
curves = array([HdiPosterior.forward(x_fits, theta) for theta in sample])

# We can use the hdi_plot function from the plotting module to plot
# highest-density intervals for each point where the model is evaluated:
from inference.plotting import hdi_plot

fig = plt.figure(figsize = (5,4))
ax = fig.add_subplot(111)

# draw the HDIs (default intervals) onto the explicit axes `ax`
hdi_plot(x_fits, curves, axis=ax)
# data points with error bars: red diamonds with black edges, no connecting line
ax.errorbar(x, y, yerr = s, c = 'red', markeredgecolor = 'black', marker = 'D', ls = 'none', markersize=5, label = 'data')
# ax.set_ylim([20.,None])
ax.set_xlim([0,10])
# ax.set_xticks([])
# ax.set_yticks([])
# NOTE(review): `ax1` is not defined in this script; use `ax` if re-enabled.
# ax1.set_title('Gibbs sampling')
# NOTE(review): the legend is created after tight_layout, so the layout is
# computed without it; consider calling plt.legend() first.
plt.tight_layout()
plt.legend()
# writes the figure to the current working directory
plt.savefig('gallery_hdi.png')
plt.show()