Example #1
0
    def plot_marginals(self, selector=None, axes=None):
        """Plot the marginal posterior distribution of each selected parameter.

        Every parameter gets its own subplot: the marginal is drawn as a line
        with a lightly shaded area underneath, and the x-axis is labelled with
        the parameter key.

        Parameters
        ----------
        selector : iterable of ints or strings, optional
            Indices or keys of the marginals to plot. Defaults to all.
        axes: one or an iterable of plt.Axes, optional
            Existing axes to draw on; handed to ``viz._create_axes``.

        Returns
        -------
        axes: np.ndarray of plt.Axes
            Flattened array of the axes that were drawn on.

        """
        # TODO: allow kwargs
        all_marginals = self.marginals
        param_values = self._get_unique_parameter_values()
        # The column limit comes from the *unselected* marginals and is never
        # below 5.
        n_all = len(all_marginals.keys())
        col_limit = max(n_all, 5)
        chosen = viz._limit_params(all_marginals, selector)
        n_chosen = len(chosen)
        # NOTE(review): assuming _limit_params only removes entries,
        # col_limit >= n_chosen, so every plot lands in a single row. If the
        # intent was to wrap at 5 columns, this would need min(...) plus
        # ceiling row division. Confirm before changing the layout.
        n_rows = max(1, n_chosen // col_limit)
        n_cols = min(n_chosen, col_limit)
        grid, _ = viz._create_axes(axes, (n_rows, n_cols))
        flat_axes = grid.ravel()

        for position, name in enumerate(chosen):
            ax = flat_axes[position]
            xs = param_values[name]
            density = chosen[name]
            ax.plot(xs, density)
            ax.fill_between(xs, density, alpha=0.1)
            ax.set_xlabel(name)

        return flat_axes
Example #2
0
    def plot_pairs(self, selector=None, axes=None):
        """Visualize pairwise relationships as a matrix with marginals on the diagonal.

        Diagonal cell (i, i) shows the 1d marginal of parameter i. Off-diagonal
        cell (row, col) shows a filled contour of the 2d marginal of the row
        and column parameters.

        Parameters
        ----------
        selector: iterable of ints or strings, optional
            Indices or keys to use from marginals and posterior. Default to all.
        axes: one or an iterable of plt.Axes, optional
            Existing axes to draw on; handed to ``viz._create_axes``.

        Returns
        -------
        axes: np.ndarray of plt.Axes
            Square grid of axes, one row/column per selected parameter.

        """
        # TODO: allow kwargs
        posterior_shape = self._get_number_of_unique_parameter_values()
        # Normalize so the gridded posterior sums to one.
        posterior = self._posterior.reshape(posterior_shape) / np.sum(
            self._posterior)
        all_marginals = self.marginals
        # The posterior keeps every parameter dimension regardless of
        # `selector`. 2d marginals must therefore be requested with each
        # parameter's position in the full marginals, not its position within
        # the selected subset. This assumes the posterior's axes follow the
        # key order of self.marginals, as the selector=None path already did.
        param_index = {key: idx for idx, key in enumerate(all_marginals)}
        unique_param_vals = self._get_unique_parameter_values()
        marginals = viz._limit_params(all_marginals, selector)
        shape = (len(marginals), len(marginals))
        axes, _ = viz._create_axes(axes, shape)

        for idx_row, key_row in enumerate(marginals):
            for idx_col, key_col in enumerate(marginals):
                ax = axes[idx_row, idx_col]
                if idx_row == idx_col:
                    # plot 1d marginals
                    ax.plot(unique_param_vals[key_row], marginals[key_row])
                    ax.fill_between(unique_param_vals[key_row],
                                    marginals[key_row],
                                    alpha=0.1)
                else:
                    # plot 2d marginals
                    xx, yy = np.meshgrid(unique_param_vals[key_col],
                                         unique_param_vals[key_row],
                                         indexing='ij')
                    marginal_2d = self._get_2d_marginal(
                        param_index[key_row], param_index[key_col], posterior)
                    ax.contourf(xx, yy, marginal_2d, cmap='Blues')
            axes[idx_row, 0].set_ylabel(key_row)
            axes[-1, idx_row].set_xlabel(key_row)

        return axes
    def plot_discrepancy(self, axes=None, **kwargs):
        """Plot acquired parameters vs. resulting discrepancy.

        One scatter subplot is drawn per input dimension of
        ``self.target_model``, laid out in a grid with at most ``ncols``
        columns.

        Parameters
        ----------
        axes : one or an iterable of plt.Axes, optional
            Existing axes to draw on; handed to ``vis._create_axes``.
        **kwargs
            ``ncols`` (int, default 5): maximum number of grid columns.
            ``sharey`` (default True). All remaining keyword arguments are
            forwarded to ``vis._create_axes``.

        Returns
        -------
        axes : np.ndarray of plt.Axes
            Flattened array of axes; it may hold more axes than plots when the
            last grid row is only partly filled.

        TODO: refactor
        """
        n_plots = self.target_model.input_dim
        ncols = kwargs.pop('ncols', 5)
        kwargs['sharey'] = kwargs.get('sharey', True)
        # Ceiling division. The previous floor division dropped the last
        # partial row: e.g. 7 plots with ncols=5 created only 5 axes, and the
        # loop below raised IndexError at axes[5].
        nrows = max(1, -(-n_plots // ncols))
        shape = (nrows, min(n_plots, ncols))
        # The leftover kwargs returned here are currently not used further.
        axes, kwargs = vis._create_axes(axes, shape, **kwargs)
        axes = axes.ravel()

        # x: GP training inputs (acquired parameters), one column per plot;
        # y: first column of the GP training outputs (the discrepancy).
        gp = self.target_model._gp
        for ii in range(n_plots):
            axes[ii].scatter(gp.X[:, ii], gp.Y[:, 0])
            axes[ii].set_xlabel(self.parameter_names[ii])

        axes[0].set_ylabel('Discrepancy')

        return axes