Esempio n. 1
0
def plot_kernel_predictions(model, data_name, legend=True, first=False):
    """Plot the true kernel against the structured and mean-field predictions.

    The ground-truth kernel ``k`` and both kernel predictions for ``model``
    are read via ``wd_results.load``. Each prediction is drawn as a mean line,
    a shaded band of mean +/- 1.96 * sqrt(var), and thin lines on the band's
    upper and lower edges.

    Args:
        model: Passed to ``wd_results.load`` to select the prediction files.
        data_name: Name of the data set whose results are loaded.
        legend: Forwarded to ``tweak``.
        first: If false, the y-axis tick labels are hidden.
    """
    truth = wd_results.load(data_name, "data.pickle")["k"]
    # Only the time grid of the second (mean-field) load is used for plotting.
    _, mean_struct, var_struct = wd_results.load(
        data_name, "structured", model, "k_pred.pickle"
    )
    t, mean_mf, var_mf = wd_results.load(
        data_name, "mean-field", model, "k_pred.pickle"
    )

    plt.plot(t, truth, label="Truth", style="train")

    predictions = (
        (mean_struct, var_struct, "Structured", "pred"),
        (mean_mf, var_mf, "Mean-field", "pred2"),
    )
    for mean, var, label, style in predictions:
        half_width = 1.96 * np.sqrt(var)
        lower = mean - half_width
        upper = mean + half_width
        plt.plot(t, mean, label=label, style=style)
        plt.fill_between(t, lower, upper, style=style)
        # Outline the band: upper edge first, then lower edge.
        plt.plot(t, upper, style=style, lw=1)
        plt.plot(t, lower, style=style, lw=1)

    plt.yticks([0, 0.5, 1])
    plt.xticks([0, 2, 4])
    plt.xlim(0, 4)
    plt.ylim(-0.25, 1.25)
    if not first:
        plt.gca().set_yticklabels([])
    tweak(legend=legend)
def plot_prediction(x, f, pred, x_obs=None, y_obs=None):
    """Plot the true function, optional observations and a prediction band.

    Args:
        x: Inputs at which ``f`` and the prediction are given.
        f: True function values at ``x``.
        pred: Tuple ``(mean, lower, upper)`` describing the prediction at ``x``.
        x_obs: Observation inputs; observations are drawn only when this is
            not ``None``.
        y_obs: Observation outputs paired with ``x_obs``.
    """
    plt.plot(x, f, label="True", style="test")

    has_observations = x_obs is not None
    if has_observations:
        plt.scatter(x_obs, y_obs, label="Observations", style="train", s=20)

    pred_mean, band_lower, band_upper = pred
    plt.plot(x, pred_mean, label="Prediction", style="pred")
    plt.fill_between(x, band_lower, band_upper, style="pred")
    tweak()
Esempio n. 3
0
def plot_compare(name1, name2, y_label=True, y_ticks=True, style2=None):
    """Compare the function predictions of two models from day 150 onwards.

    Each model is drawn as a mean line with a shaded band of
    mean +/- 1.96 * sqrt(var), outlined by thin lines (lower edge first).
    Training targets (mapped back with ``normaliser.untransform``) and test
    targets from day 150 onwards are overlaid.

    Args:
        name1: Key of the first model in ``preds_f``; drawn with style "pred".
        name2: Key of the second model in ``preds_f``.
        y_label: Whether to set the y-axis label.
        y_ticks: Whether to keep the y tick labels.
        style2: Plot style of the second model; "pred2" when ``None``.
    """
    _, mean1, var1 = preds_f[name1]
    _, mean2, var2 = preds_f[name2]

    window = t_pred >= 150
    t = t_pred[window]

    if style2 is None:
        style2 = "pred2"

    models = (
        (mean1[window], var1[window], "pred", name1),
        (mean2[window], var2[window], style2, name2),
    )
    for mean, var, style, name in models:
        spread = 1.96 * B.sqrt(var)
        plt.plot(t, mean, style=style, label=name.upper())
        plt.fill_between(t, mean - spread, mean + spread, style=style)
        plt.plot(t, mean - spread, style=style, lw=0.5)
        plt.plot(t, mean + spread, style=style, lw=0.5)

    train_window = t_train >= 150
    plt.scatter(
        t_train[train_window],
        normaliser.untransform(y_train[train_window]),
        style="train",
        label="Train",
    )
    test_window = t_test >= 150
    plt.scatter(t_test[test_window], y_test[test_window], style="test", label="Test")

    plt.xlim(t[0], t[-1])
    plt.xlabel(f"Day of {args.year}")
    if y_label:
        plt.ylabel("Crude Oil (USD)")
    if not y_ticks:
        plt.gca().set_yticklabels([])
    tweak(legend_loc="upper right")
def plot_prediction(prior, pred):
    """Plot observations, the true function and a prediction with its band.

    The axes are annotated with the prior kernel's variance and scale and
    the noise level. Observations, inputs and the true function are read
    from the module-level ``x_obs``, ``y_obs``, ``x`` and ``f_true``.

    Args:
        prior: Tuple ``(f, noise)``; ``f.kernel`` supplies the annotated
            variance and scale.
        pred: Tuple ``(mean, lower, upper)`` of the prediction at ``x``.
    """
    f, noise = prior
    mean, lower, upper = pred

    plt.scatter(x_obs, y_obs, label="Observations", style="train", s=20)
    plt.plot(x, f_true, label="True", style="test")
    plt.plot(x, mean, label="Prediction", style="pred")
    plt.fill_between(x, lower, upper, style="pred")
    plt.ylim(-2, 2)

    summary = (
        f"var = {f.kernel.factor(0):.2f}, "
        f"scale = {f.kernel.factor(1).stretches[0]:.2f}, "
        f"noise = {noise:.2f}"
    )
    # Place the summary in the bottom-left corner, in axes coordinates.
    plt.text(0.02, 0.02, summary, transform=plt.gca().transAxes)
    tweak()
Esempio n. 5
0
def plot_psd(name, y_label=True, style="pred", finish=True):
    """Plot the prediction for the PSD of model ``name``.

    Frequencies are shifted so that the first one is zero. Only frequencies
    up to 0.2 (day^-1) are plotted. The prediction is drawn as a mean line, a
    shaded band between ``lower`` and ``upper``, and thin lines on both band
    edges.

    Args:
        name: Key into ``preds_psd``; also used as the legend label.
        y_label: Whether to set the y-axis label.
        style: Plot style for the mean, band and band edges.
        finish: Whether to call ``tweak()`` at the end.
    """
    freqs, mean, lower, upper = preds_psd[name]
    # Shift out of place: an in-place ``freqs -= freqs[0]`` would modify the
    # array stored in ``preds_psd`` (for mutable arrays such as NumPy's), so
    # every other user of that entry would see shifted frequencies.
    freqs = freqs - freqs[0]

    inds = freqs <= 0.2
    freqs = freqs[inds]
    mean = mean[inds]
    lower = lower[inds]
    upper = upper[inds]

    if y_label:
        plt.ylabel("PSD (dB)")

    plt.plot(freqs, mean, style=style, label=name)
    plt.fill_between(freqs, lower, upper, style=style)
    plt.plot(freqs, lower, style=style, lw=0.5)
    plt.plot(freqs, upper, style=style, lw=0.5)
    plt.xlim(0, 0.2)
    plt.ylim(0, 60)
    plt.xlabel("Frequency (day${}^{-1}$)")
    if finish:
        tweak()
    # NOTE(review): excerpt from a larger script. `f`, `x`, `x_obs`, `y_obs`,
    # `i`, `prefix`, `context_label` and `target_label` are defined outside
    # it (presumably inside a loop over `i`) -- confirm.
    # Marginals of `f` at `x` after conditioning on the context observations.
    mean, lower, upper = (f | (x_obs, y_obs))(x).marginals()

    # Draw 3 to 5 target inputs (randint's upper bound is exclusive) in
    # [0, x.max()) and sample target values from the conditioned process.
    n_test = np.random.randint(3, 5 + 1)
    x_test = np.random.rand(n_test) * x.max()
    y_test = (f | (x_obs, y_obs))(x_test).sample().flatten()

    # Stage 1: the context data only.
    plt.figure(figsize=(2.5, 1.25))
    plt.scatter(x_obs, y_obs, style='train', s=15,
                label=f'${prefix[i]} D{context_label[i]}$')
    # Fix the y-limits from the predictive band (plus a 0.1 margin) so the
    # axes stay the same when the prediction and targets are added later.
    plt.ylim(lower.min() - 0.1, upper.max() + 0.1)
    plt.xlim(x.min(), x.max())
    plt.gca().set_xticklabels([])
    plt.gca().set_yticklabels([])
    matplotlib.rc('legend', fontsize=8)
    tweak(legend=True)

    plt.savefig(f'datas_and_predictions/data{i + 1}.pdf')
    pdfcrop(f'datas_and_predictions/data{i + 1}.pdf')

    # Stage 2: add the predictive mean and band to the same figure.
    plt.plot(x, mean, style='pred')
    plt.fill_between(x, lower, upper, style='pred')

    plt.savefig(f'datas_and_predictions/pred{i + 1}.pdf')
    pdfcrop(f'datas_and_predictions/pred{i + 1}.pdf')

    # Stage 3: add the sampled target points.
    plt.scatter(x_test, y_test, style='test', s=15,
                label=f'${prefix[i]} D{target_label[i]}$')
    tweak(legend=True)
    plt.savefig(f'datas_and_predictions/test{i + 1}.pdf')
    pdfcrop(f'datas_and_predictions/test{i + 1}.pdf')
Esempio n. 7
0
    # Nested quantile bands of the kernel samples `ks` (quantiles taken along
    # axis 1 -- presumably the sample axis; confirm). For each q the band
    # spans the q% to (100 - q)% quantiles. The translucent fills overlap, so
    # inner bands are shaded more darkly.
    for q in [1, 5, 10, 20, 30, 40]:
        plt.fill_between(
            x,
            B.quantile(ks, q / 100, axis=1),
            B.quantile(ks, 1 - q / 100, axis=1),
            facecolor="tab:blue",
            alpha=0.2,
        )
    # Sample mean along the same axis.
    plt.plot(x, B.mean(ks, axis=1), c="black")
    # If the model has `t_u` (presumably inducing-point locations), mark them
    # on the zero line.
    if hasattr(model, "t_u"):
        plt.scatter(model.t_u, model.t_u * 0, s=5, marker="o", c="black")
    plt.title(model.name + " (Kernel)")
    plt.xlabel("Time (s)")
    plt.xlim(-1.5, 1.5)
    plt.ylim(-0.5, 1.25)
    tweak(legend=False)

# For each model, draw quantile bands of its PSD samples, starting at panel 4
# of a 1 x 6 grid. NOTE(review): this excerpt is truncated mid-call.
for i, (model, (freqs, psds)) in enumerate(zip(models, model_psds)):
    plt.subplot(1, 6, 4 + i)

    def apply_to_psd(f):
        """Apply `f` to the PSD samples on a linear scale and return dB.

        `psds` is converted from dB to linear power with 10 ** (psds / 10).
        The result of `f` is converted back to dB as 10 * log10(.).
        """
        raw = 10**(psds / 10)
        return 10 * B.log(f(raw)) / B.log(10)

    # Nested q% to (100 - q)% quantile bands. The lambdas are called right
    # away inside `apply_to_psd`, so each one sees the current `q`.
    for q in [1, 5, 10, 20, 30, 40]:
        plt.fill_between(
            freqs,
            apply_to_psd(lambda x: B.quantile(x, q / 100, axis=1)),
            apply_to_psd(lambda x: B.quantile(x, 1 - q / 100, axis=1)),
            facecolor="tab:blue",
            alpha=0.2,
# Construct a prior: every process below is defined in this one measure.
prior = Measure()


def w(x):
    """Window: a Gaussian-shaped bump centred at the origin."""
    return B.exp(-(x**2) / 0.5)


# One windowed EQ process per observation location, shifted onto that location.
b = [(w * GP(EQ(), measure=prior)).shift(xi) for xi in x_obs]
# The latent function is the sum of the weighted basis functions.
f = sum(b)
# Noise process and the observation model built from it.
e = GP(Delta(), measure=prior)
y = f + 0.2 * e

# Jointly sample a true, underlying function at `x` and the observations.
f_true, y_obs = prior.sample(f(x), y(x_obs))

# Condition on the observations to make predictions.
post = prior | (y(x_obs), y_obs)

# Draw the posterior mean of every basis function. Only the first one gets a
# label, so the legend shows a single "Basis functions" entry.
for i, bi in enumerate(b):
    kw_args = {} if i else {"label": "Basis functions"}
    mean, lower, upper = post(bi(x)).marginals()
    plt.plot(x, mean, style="pred2", **kw_args)

# Truth, data, and the posterior prediction of the full latent function.
plt.plot(x, f_true, label="True", style="test")
plt.scatter(x_obs, y_obs, label="Observations", style="train", s=20)
mean, lower, upper = post(f(x)).marginals()
plt.plot(x, mean, label="Prediction", style="pred")
plt.fill_between(x, lower, upper, style="pred")
tweak()

plt.savefig("readme_example11_nonparametric_basis.png")
plt.show()
Esempio n. 9
0
# NOTE(review): the `)` below closes a call whose opening lies outside this
# excerpt.
)
# Shade the predicted spectral-density band and outline its edges. zorder=1
# draws these above the periodogram (zorder=0).
plt.fill_between(freqs, lower, upper, style="pred", zorder=1)
plt.plot(freqs, lower, style="pred", lw=1, zorder=1)
plt.plot(freqs, upper, style="pred", lw=1, zorder=1)
# Periodogram of the data, converted to dB to match the y-axis.
per_x, per = periodogram(y)
plt.plot(per_x,
         10 * np.log10(per),
         label="Periodogram",
         lw=1,
         style="train",
         zorder=0)
plt.xlim(0, 0.5)
plt.ylim(-30, 20)
plt.xlabel("Frequency $f$ (day${}^{-1}$)")
plt.ylabel("Spectral density (dB)")
tweak(legend_loc="upper right")

# For the decomposition, first subtract the spectrum of the excitation. In dB
# this leaves the filter's component. Note that these are augmented
# assignments and may modify the arrays in place.
mean -= spec_x
lower -= spec_x
upper -= spec_x

# Since the spectra of the excitation and filter multiply, we can exchange a
# multiplicative constant between them. In dB, that constant is an additive
# offset: move 7.5 dB from the filter component to the excitation so the two
# components of the decomposition sit at visually convenient heights.
mean -= 7.5
lower -= 7.5
upper -= 7.5
spec_x += 7.5

# Plot the decomposition of the PSD.
Esempio n. 10
0
t_s, elbo_s = tracker_s.get_xy(start=0)

# Double check that there isn't a huge delay between the first two times, e.g.
# due to JIT compilation.
assert t_mf[1] < 1
assert t_cmf[1] < 1
assert t_ca[1] < 1
assert t_s[1] < 1

# The structured and CA times should line up exactly, but they might not due to
# natural variability in runtimes. Force them to line up by rescaling the
# structured times so that they end at max(t_ca).
t_s_plot = t_s / max(t_s) * max(t_ca)

plt.figure(figsize=(5, 4))
plt.axhline(y=gp_logpdf, ls="--", c="black", lw=1, label="GP")
plt.plot(t_s_plot, elbo_s, label="Structured")
plt.plot(t_ca, elbo_ca, label="CA")
# For the MF and collapsed-MF runs, show the best ELBO found so far.
plt.plot(t_cmf, np.maximum.accumulate(elbo_cmf), label="Collapsed MF")
plt.plot(t_mf, np.maximum.accumulate(elbo_mf), label="MF")
plt.xlabel("Time (s)")
plt.ylabel("ELBO")
plt.ylim(-900, -550)
# Round the largest plotted time up to the next multiple of five seconds above
# it. Use the times that are actually plotted: after rescaling, the structured
# curve ends at max(t_ca), not at max(t_s). Using max(t_s) and ignoring t_ca
# would clip the CA and structured curves whenever max(t_ca) is the largest.
t_max = max(max(t_mf), max(t_cmf), max(t_ca))
plt.xlim(0, 5 * (t_max // 5 + 1))
tweak(legend_loc="lower right")
plt.savefig(wd.file("elbos.pdf"))
pdfcrop(wd.file("elbos.pdf"))
plt.show()
Esempio n. 11
0
File: smk.py Progetto: wesselb/gpcm
# Shade the band between err_95_lower and err_95_upper (by name, a 95% error
# band) for the kernel panel, then outline its edges.
plt.fill_between(
    t,
    err_95_lower,
    err_95_upper,
    style="pred2",
)
plt.plot(t, err_95_upper, style="pred2", lw=1)
plt.plot(t, err_95_lower, style="pred2", lw=1)

plt.xlabel("Time (s)")
plt.ylabel("Covariance")
plt.title("Kernel")
plt.xlim(0, 4)
plt.ylim(-0.75, 1.25)
plt.yticks([-0.5, 0, 0.5, 1])
tweak(legend=False)

# Plot prediction for PSD.

plt.subplot(1, 2, 2)
# NOTE(review): `t_k` is computed but never used. The true PSD below is
# estimated from the kernel evaluated at `t` (the kernel panel's grid).
# Confirm whether `t_k` was intended here.
t_k = B.linspace(-8, 8, 1001)
freqs, psd = estimate_psd(t, kernel(t, 0).flatten(), db=True)
# Keep only frequencies up to 1.
inds = freqs <= 1
freqs = freqs[inds]
psd = psd[inds]
plt.plot(freqs, psd, label="Truth", style="train")

# Structured-model PSD prediction. `t` is rebound here and is cut at the same
# threshold as `freqs`, so it presumably holds frequencies -- confirm.
t, mean, var, err_95_lower, err_95_upper = psd_pred_struc
inds = t <= 1
t = t[inds]
mean = mean[inds]