Example 1
    def plot_real_fake(self,
                       title,
                       real,
                       fake,
                       show=True,
                       filename=None,
                       interactive=False,
                       figsize=(10, 6)):
        """Useful method when using this dataset in conjunction with GAN
        training. Plots the given real and fake input samples in a 2D plane.

        Args:
            (....): See docstring of method
                :meth:`data.dataset.Dataset.plot_samples`.
            real (numpy.ndarray): A 2D numpy array, where each row is an input
                sample. These samples correspond to actual input samples drawn
                from the dataset.
            fake (numpy.ndarray): A 2D numpy array, where each row is an input
                sample. These samples correspond to generated samples.
        """
        if real.shape[1] != 2 or fake.shape[1] != 2:
            raise ValueError(
                'This method is only applicable to 2D input data!')

        plt.figure(figsize=figsize)
        plt.title(title, size=20)
        if interactive:
            plt.ion()

        colors = np.asarray(misc.get_colorbrewer2_colors(family='Dark2'))

        plt.scatter(real[:, 0], real[:, 1], color=colors[0], label='real')
        plt.scatter(fake[:, 0], fake[:, 1], color=colors[1], label='fake')

        plt.legend()
        plt.xlabel('x1')
        plt.ylabel('x2')

        if filename is not None:
            plt.savefig(filename, bbox_inches='tight')

        if show:
            plt.show()
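A minimal call sketch for the method above. Here `dset` is assumed to be an instance of the 2D dataset handler defining this method, and the generator output is replaced by random noise (`get_train_inputs` is the accessor used in the later examples):

# Hedged usage sketch: `dset` is an assumed 2D dataset handler; the "fake"
# batch is stand-in noise rather than actual generator output.
import numpy as np

real = dset.get_train_inputs()[:100]   # (100, 2) samples drawn from the dataset
fake = np.random.randn(100, 2)         # placeholder for generator samples
dset.plot_real_fake('Real vs. fake (epoch 10)', real, fake,
                    show=False, filename='real_fake_epoch10.png')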
Example 2
    def plot_optimal_classification(self,
                                    title='Classification Map',
                                    input_mesh=None,
                                    mesh_modes=None,
                                    sample_inputs=None,
                                    sample_modes=None,
                                    sample_label=None,
                                    sketch_components=False,
                                    show=True,
                                    filename=None,
                                    figsize=(10, 6)):
        r"""Plot a color-coded grid on how to optimally classify for each input
        value.

        Note:
            Since the training data is drawn randomly, it might be that some
            training samples have a label that doesn't correspond to the optimal
            label.

        Args:
            (....): See arguments of method :meth:`plot_uncertainty_map`.
            mesh_modes (numpy.ndarray, optional): If not provided, then the
                color of each grid position :math:`x` is determined based on
                :math:`\arg\max_k \pi_k \mathcal{N}(x; \mu_k, \Sigma_k)`.
                Otherwise, the labeling provided here will determine the
                coloring.
        """
        if input_mesh is None:
            input_mesh = self.get_input_mesh()
        else:
            assert len(input_mesh) == 3
        _, _, X = input_mesh

        if mesh_modes is None:
            responsibilities = self._compute_responsibilities(X)
            optimal_labels = responsibilities.argmax(axis=1)
        else:
            assert np.all(np.equal(mesh_modes.shape, [X.shape[0], 1]))
            optimal_labels = mesh_modes

        n = self.num_modes
        colors = np.asarray(misc.get_colorbrewer2_colors(family='Dark2'))
        if n > len(colors):
            colors = cm.rainbow(np.linspace(0, 1, n))

        fig, ax = plt.subplots(figsize=figsize)
        plt.title(title, size=20)

        plt.scatter(X[:, 0],
                    X[:, 1],
                    s=1,
                    facecolor=colors[optimal_labels.squeeze().astype(int)])

        if sample_inputs is not None:
            plt.scatter(sample_inputs[:, 0], sample_inputs[:, 1],
                        color='b' if sample_modes is None else None,
                        label=sample_label,
                        edgecolor='k' if sample_modes is not None else None,
                        facecolor=colors[sample_modes.squeeze().astype(int)] \
                        if sample_modes is not None else None)

        if sketch_components:
            self._draw_components('Means')

        if sample_inputs is not None and sample_label is not None or \
                sketch_components:
            plt.legend()

        plt.xlabel('x1')
        plt.ylabel('x2')

        if filename is not None:
            plt.savefig(filename, bbox_inches='tight')

        if show:
            plt.show()
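The `mesh_modes` argument allows coloring the grid by an arbitrary labeling instead of the analytically optimal one. A hedged sketch, assuming the same handler `dset` and a hypothetical `classifier_predict` function that returns one label per grid point:

# Sketch: color the grid with a trained classifier's predictions rather than
# the argmax of the GMM responsibilities. `classifier_predict` is hypothetical.
input_mesh = dset.get_input_mesh()
_, _, X = input_mesh                               # flattened grid, shape (N, 2)
pred_modes = classifier_predict(X).reshape(-1, 1)  # labels of shape (N, 1)
dset.plot_optimal_classification(title='Predicted Classification Map',
                                 input_mesh=input_mesh,
                                 mesh_modes=pred_modes,
                                 sample_inputs=dset.get_train_inputs(),
                                 sample_modes=dset.get_train_outputs(),
                                 sample_label='Train samples',
                                 sketch_components=True)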
Example 3
    def plot_uncertainty_map(self,
                             title='Uncertainty Map',
                             input_mesh=None,
                             uncertainties=None,
                             use_generative_uncertainty=False,
                             use_ent_joint_uncertainty=False,
                             sample_inputs=None,
                             sample_modes=None,
                             sample_label=None,
                             sketch_components=False,
                             norm_eps=None,
                             show=True,
                             filename=None,
                             figsize=(10, 6)):
        r"""Draw an uncertainty heatmap.

        Args:
            title (str): Title of plots.
            input_mesh (tuple, optional): The input mesh of the heatmap (see
                return value of method :meth:`get_input_mesh`). If not
                specified, the default return value of method
                :meth:`get_input_mesh` is used.
            uncertainties (numpy.ndarray, optional): The uncertainties
                corresponding to ``input_mesh``. If not specified, then the
                uncertainties will be computed based on the entropy across
                :math:`k=1..K` for
                
                .. math::

                    p(y_k = 1 \mid x) = \frac{ \
                     \pi_k \mathcal{N}(x; \mu_k, \Sigma_k)}{\
                     \sum_{l=1}^K \pi_l \mathcal{N}(x; \mu_l, \Sigma_l)}

                Note:
                    The entropies will be normalized by the maximum uncertainty
                    ``-np.log(1.0 / self.num_modes)``.
            use_generative_uncertainty (bool): If ``True``, the uncertainties
                plotted by default (if ``uncertainties`` is left unspecified)
                are not based on the entropy of the responsibilities
                :math:`p(y_k = 1 \mid x)`, but are the densities of the
                underlying GMM :math:`p(x)`.
            use_ent_joint_uncertainty (bool): If ``True``, the uncertainties
                plotted by default (if ``uncertainties`` is left unspecified)
                are based on the entropy of :math:`p(y, x)` at location
                :math:`x`:

                .. math::

                    & - \sum_k p(x) p(y_k=1 \mid x) \log p(x) p(y_k=1 \mid x)\\\
                    =& -p(x) \sum_k p(y_k=1 \mid x) \log p(y_k=1 \mid x) - \
                        p(x) \log p(x)

                Note, we normalize :math:`p(x)` by its maximum inside the chosen
                grid. Hence, the plot depends on the chosen ``input_mesh``. In
                this way, :math:`p(x) \in [0, 1]` and the second term
                :math:`-p(x) \log p(x) \in [0, \exp(-1)]` (note,
                :math:`-p(x) \log p(x)` would be negative for :math:`p(x) > 1`).

                The first term is simply the entropy of :math:`p(y \mid x)`
                scaled by :math:`p(x)`. Hence, it highlights the regions of the
                input space where Gaussian bumps overlap (regions in which data
                exists but multiple labels :math:`y` are possible).

                The second term shows the boundaries of the data manifold. Note,
                :math:`-1 \log 1 = 0` and
                :math:`-\lim_{p(x) \rightarrow 0} p(x) \log p(x) = 0`.

                Note:
                    This option is mutually exclusive with option
                    ``use_generative_uncertainty``.

                Note:
                    Entropies of :math:`p(y \mid x)` won't be normalized in this
                    case.
            sample_inputs (numpy.ndarray, optional): Sample inputs. Can be
                specified if a scatter plot of samples (e.g., train samples)
                should be laid above the heatmap.
            sample_modes (numpy.ndarray, optional): To which mode do the samples
                in ``sample_inputs`` belong? If provided, then for each
                sample in ``sample_inputs`` a number smaller than
                :attr:`num_modes` is expected. All samples with the same mode
                identifier are colored with the same color.
            sample_label (str, optional): A label to be shown in the legend for
                the inputs ``sample_inputs``.
            sketch_components (bool): Sketch the mean and variance of each
                component.
            norm_eps (float, optional): If uncertainties are computed by this
                method, then (normalized) densities for each x-value in the
                input mesh have to be computed. To avoid division by zero,
                a positive number ``norm_eps`` can be specified.
            (....): See docstring of method
                :meth:`data.dataset.Dataset.plot_samples`.
        """
        assert not use_generative_uncertainty or not use_ent_joint_uncertainty
        if input_mesh is None:
            input_mesh = self.get_input_mesh()
        else:
            assert len(input_mesh) == 3
        X1, X2, X = input_mesh

        if uncertainties is None:
            responsibilities = self._compute_responsibilities(
                X, normalize=not use_generative_uncertainty, eps=norm_eps)
            if use_generative_uncertainty:
                uncertainties = responsibilities.sum(axis=1)
            else:
                # Compute entropy.
                uncertainties = - np.sum(responsibilities * \
                    np.log(np.maximum(responsibilities, 1e-5)), axis=1)

                if use_ent_joint_uncertainty:
                    cond_entropies = np.copy(uncertainties)

                    # FIXME Instead of computing responsibilities again, we
                    # should let `_compute_responsibilities` return both.
                    unnormed_resps = self._compute_responsibilities(
                        X, normalize=False)
                    loc_densities = unnormed_resps.sum(axis=1)
                    # Make sure that p(x) is between 0 and 1.
                    loc_densities /= loc_densities.max()

                    uncertainties = loc_densities * cond_entropies - \
                        loc_densities * np.log(np.maximum(loc_densities, 1e-5))

                    # Look at individual terms instead (by uncommenting).
                    # Areas where data is still likely but uncertainty is high
                    # (e.g., overlapping Gaussian bumps)
                    #uncertainties = loc_densities * cond_entropies
                    # Areas where data is still somewhat likely (not totally
                    # OOD) but also not very common -> boundary of the data
                    # manifold.
                    #uncertainties = -loc_densities * \
                    #    np.log(np.maximum(loc_densities, 1e-5))
                else:
                    # Normalize conditional entropies.
                    max_entropy = -np.log(1.0 / self.num_modes)
                    uncertainties /= max_entropy
        else:
            assert np.all(np.equal(uncertainties.shape, [X.shape[0], 1]))

        if np.any(np.isnan(uncertainties)):
            warn(
                'NaN detected in uncertainties to be drawn. Set to 0 instead!')
        uncertainties[np.isnan(uncertainties)] = 0.

        uncertainties = uncertainties.reshape(X1.shape)

        fig, ax = plt.subplots(figsize=figsize)
        plt.title(title, size=20)

        f = plt.contourf(X1, X2, uncertainties)
        plt.colorbar(f)

        if sample_inputs is not None:
            n = self.num_modes
            colors = np.asarray(misc.get_colorbrewer2_colors(family='Dark2'))
            if n > len(colors):
                colors = cm.rainbow(np.linspace(0, 1, n))

            plt.scatter(sample_inputs[:, 0], sample_inputs[:, 1],
                        color='b' if sample_modes is None else None,
                        label=sample_label,
                        facecolor=colors[sample_modes.squeeze().astype(int)] \
                        if sample_modes is not None else None)

        if sketch_components:
            self._draw_components('Means')

        if sample_inputs is not None and sample_label is not None or \
                sketch_components:
            plt.legend()

        plt.xlabel('x1')
        plt.ylabel('x2')

        if filename is not None:
            plt.savefig(filename, bbox_inches='tight')

        if show:
            plt.show()
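For reference, the default uncertainty computed above is the entropy of the responsibilities p(y_k = 1 | x), normalized by the maximum entropy -log(1/K). A standalone numeric sketch with made-up responsibilities:

# Minimal sketch of the normalized-entropy uncertainty used by default.
import numpy as np

resps = np.array([[0.98, 0.01, 0.01],   # confident -> low uncertainty
                  [0.34, 0.33, 0.33]])  # ambiguous -> close to 1.0
entropy = -np.sum(resps * np.log(np.maximum(resps, 1e-5)), axis=1)
entropy /= -np.log(1.0 / resps.shape[1])  # normalize by the maximum entropy
print(entropy)  # approx. [0.10, 1.00]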
Example 4
    def plot_samples(self,
                     title,
                     inputs,
                     outputs=None,
                     predictions=None,
                     show=True,
                     filename=None,
                     interactive=False,
                     figsize=(10, 6)):
        """Plot samples belonging to this dataset.

        Args:
            (....): See docstring of method
                :meth:`data.dataset.Dataset.plot_samples`.
        """
        if inputs.shape[1] != 2:
            raise ValueError(
                'This method is only applicable to 2D input data!')

        plt.figure(figsize=figsize)
        plt.title(title, size=20)
        if interactive:
            plt.ion()

        if self.classification:
            n = self.num_classes
            colors = np.asarray(misc.get_colorbrewer2_colors(family='Dark2'))
            if n > len(colors):
                #warn('Changing to automatic color scheme as we don\'t have ' +
                #     'as many manual colors as tasks.')
                colors = cm.rainbow(np.linspace(0, 1, n))
        else:
            norm = Normalize(vmin=self._data['out_data'].min(),
                             vmax=self._data['out_data'].max())
            cmap = cm.get_cmap(name='viridis')

            sm = cm.ScalarMappable(norm=norm, cmap=cmap)
            sm.set_array(np.asarray([norm.vmin, norm.vmax]))
            if outputs is not None or predictions is not None:
                plt.colorbar(sm)

        if outputs is not None and predictions is None:
            plt.scatter(inputs[:, 0], inputs[:, 1], #edgecolors='b',
                label='Targets',
                facecolor=colors[outputs.squeeze().astype(int)] \
                    if self.classification else \
                        cmap(norm(outputs.squeeze()))
                )
        elif predictions is not None and outputs is None:
            plt.scatter(inputs[:, 0], inputs[:, 1], #edgecolors='r',
                label='Predictions',
                facecolor=colors[predictions.squeeze().astype(int)] \
                    if self.classification else \
                        cmap(norm(predictions.squeeze()))
                )
        elif predictions is not None and outputs is not None:
            plt.scatter(inputs[:, 0], inputs[:, 1], label='Targets+Predictions',
                edgecolors=colors[outputs.squeeze().astype(int)] \
                    if self.classification else \
                        cmap(norm(outputs.squeeze())),
                facecolor=colors[predictions.squeeze().astype(int)] \
                    if self.classification else \
                        cmap(norm(predictions.squeeze()))
                )
        else:
            assert predictions is None and outputs is None
            plt.scatter(inputs[:, 0], inputs[:, 1], color='k', label='Inputs')

        #plt.legend()
        plt.xlabel('x1')
        plt.ylabel('x2')

        if filename is not None:
            plt.savefig(filename, bbox_inches='tight')

        if show:
            plt.show()
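A short call sketch for the classification case; `dset` and `my_model_predict` are assumptions standing in for the data handler and a trained model:

# Sketch: overlay targets (marker edges) and predictions (marker faces).
X = dset.get_train_inputs()    # (N, 2) inputs
T = dset.get_train_outputs()   # (N, 1) target labels
Y = my_model_predict(X)        # (N, 1) predicted labels (hypothetical model)
dset.plot_samples('Targets vs. predictions', X, outputs=T, predictions=Y,
                  show=False, filename='train_preds.png')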
Example 5
    def plot_datasets(data_handlers, inputs=None, predictions=None, labels=None,
                      fun_xranges=None, show=True, filename=None,
                      figsize=(10, 6), publication_style=False):
        """Plot several datasets of this class in one plot.

        Args:
            data_handlers: A list of ToyRegression objects.
            inputs (optional): A list of numpy arrays representing inputs for
                each dataset.
            predictions (optional): A list of numpy arrays containing the
                predicted output values for the given input values.
            labels (optional): A label for each dataset.
            fun_xranges (optional): List of x ranges in which the true
                underlying function per dataset should be sketched.
            show: Whether the plot should be shown.
            filename (optional): If provided, the figure will be stored under
                this filename.
            figsize: A tuple, determining the size of the figure in inches.
            publication_style: Whether the plots should be in publication style.
        """
        n = len(data_handlers)
        assert ((inputs is None and predictions is None) or \
                (inputs is not None and predictions is not None))
        assert ((inputs is None or len(inputs) == n) and \
                (predictions is None or len(predictions) == n) and \
                (labels is None or len(labels) == n))
        assert (fun_xranges is None or len(fun_xranges) == n)

        # Set-up matplotlib to adhere to our graphical conventions.
        # misc.configure_matplotlib_params(fig_size=1.2*np.array([1.6, 1]),
        #                                 font_size=8)

        # Get a colorscheme from colorbrewer2.org.
        colors = misc.get_colorbrewer2_colors(family='Dark2')
        if n > len(colors):
            warn('Changing to automatic color scheme as we don\'t have ' +
                 'as many manual colors as tasks.')
            colors = cm.rainbow(np.linspace(0, 1, n))

        if publication_style:
            ts, lw, ms = 60, 15, 140  # text fontsize, line width, marker size
            figsize = (12, 6)
        else:
            ts, lw, ms = 12, 2, 15

        fig, axes = plt.subplots(figsize=figsize)
        plt.title('1D regression', size=ts, pad=ts)

        phandlers = []
        plabels = []

        for i, data in enumerate(data_handlers):
            if labels is not None:
                lbl = labels[i]
            else:
                lbl = 'Function %d' % i

            fun_xrange = None
            if fun_xranges is not None:
                fun_xrange = fun_xranges[i]
            sample_x, sample_y = data._get_function_vals(x_range=fun_xrange)
            p, = plt.plot(sample_x, sample_y, color=colors[i],
                          linestyle='dashed', linewidth=lw / 3)

            phandlers.append(p)
            plabels.append(lbl)
            if inputs is not None:
                p = plt.scatter(inputs[i], predictions[i], color=colors[i],
                                s=ms)
                phandlers.append(p)
                plabels.append('Predictions')

        if publication_style:
            axes.grid(False)
            axes.set_facecolor('w')
            axes.axhline(y=axes.get_ylim()[0], color='k', lw=lw)
            axes.axvline(x=axes.get_xlim()[0], color='k', lw=lw)
            if len(data_handlers) == 3:
                plt.yticks([-1, 0, 1], fontsize=ts)
                plt.xticks([-2.5, 0, 2.5], fontsize=ts)
            else:
                for tick in axes.yaxis.get_major_ticks():
                    tick.label.set_fontsize(ts)
                for tick in axes.xaxis.get_major_ticks():
                    tick.label.set_fontsize(ts)
            axes.tick_params(axis='both', length=lw, direction='out', width=lw / 2.)
        else:
            plt.legend(phandlers, plabels)

        plt.xlabel('$x$', fontsize=ts)
        plt.ylabel('$y$', fontsize=ts)
        plt.tight_layout()

        if filename is not None:
            # plt.savefig(filename + '.pdf', bbox_inches='tight')
            plt.savefig(filename, bbox_inches='tight')

        if show:
            plt.show()
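Since `plot_datasets` takes no `self` argument, it is presumably meant to be called as a static method. A sketch with three hypothetical ToyRegression handlers (only the call signature is taken from the code above):

# Sketch: overlay the ground-truth functions of three regression tasks.
handlers = [make_toy_regression(task_id=i) for i in range(3)]  # hypothetical
ToyRegression.plot_datasets(handlers,
                            labels=['Task %d' % (i + 1) for i in range(3)],
                            fun_xranges=[(-2.5, 2.5)] * 3,
                            show=False, filename='tasks_overview.png')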
Example 6
def plot_predictive_distributions(config,
                                  writer,
                                  data_handlers,
                                  inputs,
                                  preds_mean,
                                  preds_std,
                                  save_fig=True,
                                  publication_style=False):
    """Plot the predictive distribution of several tasks into one plot.

    Args:
        config: Command-line arguments.
        writer: Tensorboard summary writer.
        data_handlers: A set of data loaders.
        inputs: A list of arrays containing the x values per task.
        preds_mean: The mean predictions corresponding to each task in `inputs`.
        preds_std: The std of all predictions in `preds_mean`.
        save_fig: Whether the figure should be saved in the output folder.
        publication_style: Whether plots should be made in publication style.
    """
    num_tasks = len(data_handlers)
    assert (len(inputs) == num_tasks)
    assert (len(preds_mean) == num_tasks)
    assert (len(preds_std) == num_tasks)

    colors = utils.get_colorbrewer2_colors(family='Dark2')
    if num_tasks > len(colors):
        warn('Changing to automatic color scheme as we don\'t have ' +
             'as many manual colors as tasks.')
        colors = cm.rainbow(np.linspace(0, 1, num_tasks))

    fig, axes = plt.subplots(figsize=(12, 6))

    if publication_style:
        ts, lw, ms = 60, 15, 10  # text fontsize, line width, marker size
    else:
        ts, lw, ms = 12, 5, 8

    # The default matplotlib setting is usually too high for most plots.
    plt.locator_params(axis='y', nbins=2)
    plt.locator_params(axis='x', nbins=6)

    for i, data in enumerate(data_handlers):
        assert (isinstance(data, ToyRegression))

        # In what range to plot the real function values?
        train_range = data.train_x_range
        range_offset = (train_range[1] - train_range[0]) * 0.05
        sample_x, sample_y = data._get_function_vals( \
            x_range=[train_range[0]-range_offset, train_range[1]+range_offset])
        #sample_x, sample_y = data._get_function_vals()

        #plt.plot(sample_x, sample_y, color='k', label='f(x)',
        #         linestyle='dashed', linewidth=.5)
        plt.plot(sample_x,
                 sample_y,
                 color='k',
                 linestyle='dashed',
                 linewidth=lw / 7.)

        train_x = data.get_train_inputs().squeeze()
        train_y = data.get_train_outputs().squeeze()

        if i == 0:
            plt.plot(train_x,
                     train_y,
                     'o',
                     color='k',
                     label='Training Data',
                     markersize=ms)
        else:
            plt.plot(train_x, train_y, 'o', color='k', markersize=ms)

        plt.plot(inputs[i],
                 preds_mean[i],
                 color=colors[i],
                 label='Task %d' % (i + 1),
                 lw=lw / 3.)

        plt.fill_between(inputs[i],
                         preds_mean[i] + preds_std[i],
                         preds_mean[i] - preds_std[i],
                         color=colors[i],
                         alpha=0.3)
        plt.fill_between(inputs[i],
                         preds_mean[i] + 2. * preds_std[i],
                         preds_mean[i] - 2. * preds_std[i],
                         color=colors[i],
                         alpha=0.2)
        plt.fill_between(inputs[i],
                         preds_mean[i] + 3. * preds_std[i],
                         preds_mean[i] - 3. * preds_std[i],
                         color=colors[i],
                         alpha=0.1)

    if publication_style:
        axes.grid(False)
        axes.set_facecolor('w')
        axes.set_ylim([-2.5, 3])
        axes.axhline(y=axes.get_ylim()[0], color='k', lw=lw)
        axes.axvline(x=axes.get_xlim()[0], color='k', lw=lw)
        if len(data_handlers) == 3:
            plt.yticks([-2, 0, 2], fontsize=ts)
            plt.xticks([-3, 0, 3], fontsize=ts)
        else:
            for tick in axes.yaxis.get_major_ticks():
                tick.label.set_fontsize(ts)
            for tick in axes.xaxis.get_major_ticks():
                tick.label.set_fontsize(ts)
        axes.tick_params(axis='both',
                         length=lw,
                         direction='out',
                         width=lw / 2.)
        if config.train_from_scratch:
            plt.title('training from scratch', fontsize=ts, pad=ts)
        elif config.beta == 0:
            plt.title('fine-tuning', fontsize=ts, pad=ts)
        else:
            plt.title('CL with prob. hnet reg.', fontsize=ts, pad=ts)
    else:
        plt.legend()
        plt.title('Predictive distributions', fontsize=ts, pad=ts)

    plt.xlabel('$x$', fontsize=ts)
    plt.ylabel('$y$', fontsize=ts)

    if save_fig:
        plt.savefig(os.path.join(config.out_dir, 'pred_dists_%d' % num_tasks),
                    bbox_inches='tight')

    writer.add_figure('test/pred_dists',
                      plt.gcf(),
                      num_tasks,
                      close=not config.show_plots)
    if config.show_plots:
        utils.repair_canvas_and_show_fig(plt.gcf())
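A call sketch for the function above. `config` must provide at least `out_dir`, `show_plots`, `train_from_scratch` and `beta`; the per-task Monte-Carlo prediction matrices `mc_preds` (shape `(num_inputs, num_samples)` each) are assumed to come from repeated stochastic forward passes of the model:

# Sketch: summarize MC predictions and hand them to the plotting routine.
from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter(log_dir=config.out_dir)
inputs = [d.get_test_inputs().squeeze() for d in data_handlers]
preds_mean = [p.mean(axis=1) for p in mc_preds]  # per-task mean prediction
preds_std = [p.std(axis=1) for p in mc_preds]    # per-task predictive std
plot_predictive_distributions(config, writer, data_handlers, inputs,
                              preds_mean, preds_std, save_fig=True)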
Example 7
def plot_mse(config,
             writer,
             num_tasks,
             current_mse,
             during_mse=None,
             baselines=None,
             save_fig=True,
             summary_label='test/mse'):
    """Produce a scatter plot that shows the current and immediate mse values
    (and maybe a set of baselines) of each task. This visualization helps to
    understand the impact of forgetting.

    Args:
        config: Command-line arguments.
        writer: Tensorboard summary writer.
        num_tasks: Number of tasks.
        current_mse: Array of MSE values currently achieved.
        during_mse (optional): Array of MSE values achieved right after
            training on the corresponding task.
        baselines (optional): A dictionary of label names mapping onto arrays.
            Can be used to plot additional baselines.
        save_fig: Whether the figure should be saved in the output folder.
        summary_label: Label used for the figure when writing it to tensorboard.
    """
    x_vals = np.arange(1, num_tasks + 1)

    num_plots = 1
    if during_mse is not None:
        num_plots += 1
    if baselines is not None:
        num_plots += len(baselines.keys())

    colors = utils.get_colorbrewer2_colors(family='Dark2')
    if num_plots > len(colors):
        warn('Changing to automatic color scheme as we don\'t have ' +
             'as many manual colors as tasks.')
        colors = cm.rainbow(np.linspace(0, 1, num_plots))

    fig, axes = plt.subplots(figsize=(10, 6))
    plt.title('Current MSE on each task')

    plt.scatter(x_vals, current_mse, color=colors[0], label='Current Val MSE')

    if during_mse is not None:
        plt.scatter(x_vals,
                    during_mse,
                    label='During MSE',
                    color=colors[1],
                    marker='*')

    if baselines is not None:
        for i, (label, vals) in enumerate(baselines.items()):
            plt.scatter(x_vals,
                        vals,
                        label=label,
                        color=colors[2 + i],
                        marker='x')

    plt.ylabel('MSE')
    plt.xlabel('Task')
    plt.xticks(x_vals)
    plt.legend()

    if save_fig:
        plt.savefig(os.path.join(config.out_dir, 'mse_%d' % num_tasks),
                    bbox_inches='tight')

    writer.add_figure(summary_label,
                      plt.gcf(),
                      num_tasks,
                      close=not config.show_plots)
    if config.show_plots:
        utils.repair_canvas_and_show_fig(plt.gcf())
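A sketch with made-up MSE values for five tasks; `config` and `writer` are as in the previous example:

# Sketch: final vs. immediate MSE per task, plus one fabricated baseline.
import numpy as np

current_mse = np.array([0.021, 0.018, 0.015, 0.012, 0.009])
during_mse = np.array([0.009, 0.010, 0.011, 0.011, 0.009])
plot_mse(config, writer, 5, current_mse, during_mse=during_mse,
         baselines={'Fine-tuning': np.array([0.31, 0.25, 0.17, 0.05, 0.01])},
         save_fig=False)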
Example 8
def plot_predictive_distribution(data,
                                 inputs,
                                 predictions,
                                 show_raw_pred=False,
                                 figsize=(10, 6),
                                 show=True):
    """Plot the predictive distribution of a single regression task.

    Args:
        data: The dataset handler (class `ToyRegression`).
        inputs: A 2D numpy array, denoting the inputs used to generate the
            `predictions`.
        predictions: A 2D numpy array with dimensions (batch_size x sample_size)
            where the sample size refers to the number of weight samples that
            have been used to produce an ensemble of predictions.
        show_raw_pred: Whether a second subplot should be shown, in which the
            standard deviations of the predictions are ignored and only the mean
            is shown.
        figsize: A tuple, determining the size of the figure in inches.
        show: Whether the plot should be shown.
    """
    assert (isinstance(data, ToyRegression))
    colors = utils.get_colorbrewer2_colors(family='Dark2')

    num_plots = 2 if show_raw_pred else 1
    fig, axes = plt.subplots(nrows=1, ncols=num_plots, figsize=figsize)
    # `plt.subplots` returns a single `Axes` object (not an array) when only
    # one subplot is requested, so wrap it to keep the loop below generic.
    if num_plots == 1:
        axes = [axes]

    train_x = data.get_train_inputs().squeeze()
    train_y = data.get_train_outputs().squeeze()

    #test_x = data.get_test_inputs().squeeze()
    #test_y = data.get_test_outputs().squeeze()

    sample_x, sample_y = data._get_function_vals()

    for i, ax in enumerate(axes):
        # The default matplotlib setting is usually too high for most plots.
        ax.locator_params(axis='y', nbins=2)
        ax.locator_params(axis='x', nbins=6)

        ax.plot(sample_x,
                sample_y,
                color='k',
                label='f(x)',
                linestyle='dashed',
                linewidth=.5)
        ax.plot(train_x, train_y, 'o', color='k', label='Train')
        #plt.plot(test_x, test_y, 'o', color=colors[1], label='Test')

        inputs = inputs.squeeze()
        mean_pred = predictions.mean(axis=1)
        std_pred = predictions.std(axis=1)

        c = colors[2]
        ax.plot(inputs, mean_pred, color=c, label='Pred')
        if i == 0:
            ax.fill_between(inputs,
                            mean_pred + std_pred,
                            mean_pred - std_pred,
                            color=c,
                            alpha=0.3)
            ax.fill_between(inputs,
                            mean_pred + 2. * std_pred,
                            mean_pred - 2. * std_pred,
                            color=c,
                            alpha=0.2)
            ax.fill_between(inputs,
                            mean_pred + 3. * std_pred,
                            mean_pred - 3. * std_pred,
                            color=c,
                            alpha=0.1)

        ax.legend()
        ax.set_xlabel('$x$')
        ax.set_ylabel('$y$')

        if i == 1:
            ax.set_title('Mean Predictions')
        else:
            ax.set_title('Predictive Distribution')

    if show:
        plt.show()
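The `predictions` argument is a (batch_size x sample_size) matrix, i.e., one column per weight sample. A sketch of how such a matrix might be assembled, with `sample_model_forward` as a hypothetical stochastic forward pass returning an (N, 1) array:

# Sketch: build the prediction ensemble from repeated stochastic forwards.
import numpy as np

inputs = data.get_test_inputs()                         # (N, 1)
predictions = np.stack(
    [sample_model_forward(inputs) for _ in range(20)],  # 20 weight samples
    axis=1).squeeze(-1)                                 # -> (N, 20)
plot_predictive_distribution(data, inputs, predictions, show_raw_pred=True)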