def plot_noise(figure_id_start, X, noise, title, plot_shape=None):
    """Scatter-plot consecutive column pairs of ``noise`` in one figure.

    Pairs of columns ``(0, 1)``, ``(2, 3)``, ... of ``noise`` are drawn as
    x/y scatter plots in separate subplots, with every point coloured by
    ``X`` using the ``'magma'`` colormap.  If ``noise`` has an odd number of
    columns, the last column is not plotted (``shape[1] // 2`` pairs).

    Args:
        figure_id_start: Identifier passed to ``plt.figure``.
        X: Values used to colour the scatter points (``c=X``).
        noise: Array indexed by column along axis 1; columns are extracted
            with ``slice_column``.
        title: Figure-level title, set via ``plt.suptitle``.
        plot_shape: Optional ``(rows, cols)`` subplot grid.  When omitted,
            a single row with ``max(3, n_plots)`` columns is used.  This keeps
            the historical 1x3 layout for up to three pairs, while more
            pairs no longer overflow the grid.

    Raises:
        ValueError: If ``plot_shape`` has fewer cells than there are pairs.
    """
    n_plots = noise.shape[1] // 2
    if plot_shape is None:
        # Previously hard-coded to 1x3, which crashed for more than 3 pairs.
        plot_shape = (1, max(3, n_plots))
    n_rows, n_cols = plot_shape
    if n_rows * n_cols < n_plots:
        raise ValueError(
            'plot_shape {} has room for {} subplots but {} are needed'.format(
                plot_shape, n_rows * n_cols, n_plots))

    plt.figure(figure_id_start)
    plt.suptitle(title)
    for idx in range(n_plots):
        plt.subplot(n_rows, n_cols, idx + 1)
        plt.scatter(slice_column(noise, 2 * idx),
                    slice_column(noise, 2 * idx + 1),
                    s=1, c=X, cmap='magma')
def plot_mse_values(model_means, igp_means, Y_true, figure_id_start=0,
                    initial_labels=None, plot_shape=(1, 3)):
    """Draw one bar chart per output comparing GPAR and IGP SMSE scores.

    For every output column of ``Y_true`` a subplot is positioned via
    ``specify_plot_location``.  The standardised MSE (``smse``) of the
    ``model_means`` and ``igp_means`` predictions against the truth is then
    drawn as a two-bar plot titled ``'<label> SMSE'``.

    Args:
        model_means: Stacked GPAR predictive means, one column per output.
        igp_means: Stacked IGP predictive means, one column per output.
        Y_true: Ground-truth values, one column per output.
        figure_id_start: Forwarded to ``specify_plot_location``.
        initial_labels: Optional labels passed to ``initialize_labels``.
        plot_shape: Subplot grid forwarded to ``specify_plot_location``.
    """
    n_outputs = Y_true.shape[1]
    for column, name in enumerate(initialize_labels(n_outputs, initial_labels)):
        specify_plot_location(column, figure_id_start, plot_shape)
        truth = slice_column(Y_true, column)
        scores = [
            smse(truth, slice_column(model_means, column)),
            smse(truth, slice_column(igp_means, column)),
        ]
        plot_bar_plot(scores, ['GPAR', 'IGP'])
        plt.title('{} SMSE'.format(name))
def plot_single_output(X, stacked_means, stacked_vars, out_id, label,
                       display_var=False):
    """Plot the predictive mean for one output, optionally with a 2-sigma band.

    Args:
        X: Inputs used as the x-axis.
        stacked_means: Predictive means; column ``out_id`` is plotted.
        stacked_vars: Predictive variances; only read when ``display_var``.
        out_id: Index of the output column to plot.
        label: Legend label for the mean curve.
        display_var: When true, shade ``mean +/- 2 * sqrt(variance)``.
    """
    mu = slice_column(stacked_means, out_id)
    plt.plot(X, mu, label=label)
    if not display_var:
        return
    # Two standard deviations on either side of the mean.
    band = 2 * np.sqrt(slice_column(stacked_vars, out_id))
    upper = mu + band
    lower = mu - band
    plt.fill_between(X.flatten(), lower.flatten(), upper.flatten(),
                     alpha=0.2, edgecolor='b')
def _get_trained_gp_model(self, current_X, out_id):
    """Build, configure and optimise a GP model for output column ``out_id``.

    The targets are column ``out_id`` of ``self.Y_obs``.  The kernel comes
    from ``self._get_kernel(self.X_obs, current_X)`` and the model from
    ``self._get_model``.  The likelihood variance is first set to
    ``self.init_likelihood_var``.  When ``self.is_zero_noise`` is set, it is
    then overridden with ``1e-5`` and marked non-trainable before
    ``self._optimize_model`` runs.

    Returns:
        The optimised model object.
    """
    targets = slice_column(self.Y_obs, out_id)
    kern = self._get_kernel(self.X_obs, current_X)
    model = self._get_model(current_X, targets, kern)

    model.likelihood.variance = self.init_likelihood_var
    if self.is_zero_noise:
        # Pin the noise near zero and exclude it from optimisation.
        model.likelihood.variance = 1e-5
        model.likelihood.variance.trainable = False

    self._optimize_model(model)
    return model
def augment_X(self, current_X, out_id):
    """Return the next input matrix in the output chain.

    With no current inputs (``None``), the observed inputs ``self.X_obs``
    are returned unchanged.  Otherwise column ``out_id`` of ``self.Y_obs``
    is appended on the right of ``current_X`` via ``concat_right_column``.
    """
    if current_X is not None:
        new_column = slice_column(self.Y_obs, out_id)
        return concat_right_column(current_X, new_column)
    return self.X_obs
def plot_truth(X_new, Y_true, out_id):
    """Draw column ``out_id`` of ``Y_true`` against ``X_new``, labelled 'Truth'."""
    plt.plot(X_new, slice_column(Y_true, out_id), label='Truth')
def plot_observations(X_obs, Y_obs, out_id):
    """Scatter column ``out_id`` of ``Y_obs`` against ``X_obs`` as blue crosses."""
    observed = slice_column(Y_obs, out_id)
    plt.scatter(X_obs, observed, label='Observations', marker='x', color='b')