Example #1
    def argmax_posterior_mean(cands: to.Tensor, cands_values: to.Tensor,
                              ddp_space: BoxSpace, num_restarts: int,
                              num_samples: int) -> to.Tensor:
        """
        Compute the GP input with the maximal posterior mean.

        :param cands: candidates a.k.a. x
        :param cands_values: observed values a.k.a. y
        :param ddp_space: space of the domain distribution parameters, defining the lower and upper bounds
        :param num_restarts: number of restarts for the optimization of the acquisition function
        :param num_samples: number of samples for the optimization of the acquisition function
        :return: un-normalized candidate with maximum posterior value a.k.a. x
        """
        if not isinstance(cands, to.Tensor):
            raise pyrado.TypeErr(given=cands, expected_type=to.Tensor)
        if not isinstance(cands_values, to.Tensor):
            raise pyrado.TypeErr(given=cands_values, expected_type=to.Tensor)
        if not isinstance(ddp_space, BoxSpace):
            raise pyrado.TypeErr(given=ddp_space, expected_type=BoxSpace)

        # Normalize the input data and standardize the output data
        uc_projector = UnitCubeProjector(
            to.from_numpy(ddp_space.bound_lo).to(dtype=to.get_default_dtype()),
            to.from_numpy(ddp_space.bound_up).to(dtype=to.get_default_dtype()),
        )
        cands_norm = uc_projector.project_to(cands)
        cands_values_stdized = standardize(cands_values)

        if cands_norm.shape[0] > cands_values.shape[0]:
            print_cbt(
                f"There are {cands.shape[0]} candidates but only {cands_values.shape[0]} evaluations. Ignoring "
                f"the candidates without evaluation for computing the argmax.",
                "y",
            )
            cands_norm = cands_norm[:cands_values.shape[0], :]

        # Create and fit the GP model
        gp = SingleTaskGP(cands_norm, cands_values_stdized)
        gp.likelihood.noise_covar.register_constraint("raw_noise",
                                                      GreaterThan(1e-5))
        mll = ExactMarginalLogLikelihood(gp.likelihood, gp)
        fit_gpytorch_model(mll)

        # Find position with maximal posterior mean
        cand_norm, _ = optimize_acqf(
            acq_function=PosteriorMean(gp),
            bounds=to.stack(
                [to.zeros(ddp_space.flat_dim),
                 to.ones(ddp_space.flat_dim)]).to(dtype=to.float32),
            q=1,
            num_restarts=num_restarts,
            raw_samples=num_samples,
        )

        cand_norm = cand_norm.to(dtype=to.get_default_dtype())
        cand = uc_projector.project_back(cand_norm.detach())
        print_cbt(f"Converged to argmax of the posterior mean: {cand.numpy()}",
                  "g",
                  bright=True)
        return cand
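A usage sketch for the function above, assuming it is in scope (in pyrado it is a static method of an algorithm class) and that `BoxSpace` is importable from pyrado; the candidate data is synthetic:

import torch as to
from pyrado.spaces import BoxSpace  # assumed import path

# Hypothetical 2-dim domain-parameter space with 8 evaluated candidates
ddp_space = BoxSpace(bound_lo=[0.5, 0.1], bound_up=[1.5, 0.9])
cands = to.tensor([0.5, 0.1]) + to.rand(8, 2) * to.tensor([1.0, 0.8])
cands_values = to.rand(8, 1)

cand = argmax_posterior_mean(cands, cands_values, ddp_space,
                             num_restarts=8, num_samples=64)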
Example #2
 def current_best(self, past_only: bool = False, **kwargs) -> Tuple[Tensor, Tensor]:
     """
     Get the current best solution and value.

     :param past_only: If True, optimization is over previously evaluated points only.
     :param kwargs: ignored
     :return: Current best solution and value
     """
     inner = PosteriorMean(self.model)
     if past_only:
         with torch.no_grad():
             values = inner(self.X.reshape(-1, 1, self.dim_x))
         best = torch.argmax(values)
         current_best_sol = self.X[best]
         current_best_value = -values[best]
     else:
         current_best_sol, current_best_value = optimize_acqf(
             acq_function=inner,
             bounds=self.bounds,
             q=1,
             num_restarts=self.num_restarts,
             raw_samples=self.num_restarts * self.raw_multiplier,
         )
     # negated again to report the correct value
     if self.verbose:
         print(
             "Current best solution, value: ", current_best_sol, -current_best_value
         )
     return current_best_sol, -current_best_value
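Evaluating an analytic acquisition function such as `PosteriorMean` on existing points requires reshaping them into `n x q=1 x d` q-batches, as the `past_only` branch does. A minimal, self-contained sketch of that pattern on synthetic data:

import torch
from botorch.acquisition import PosteriorMean
from botorch.fit import fit_gpytorch_model
from botorch.models import SingleTaskGP
from gpytorch.mlls import ExactMarginalLogLikelihood

train_X = torch.rand(10, 2, dtype=torch.double)
train_Y = ((train_X - 0.5)**2).sum(dim=-1, keepdim=True)
model = SingleTaskGP(train_X, train_Y)
fit_gpytorch_model(ExactMarginalLogLikelihood(model.likelihood, model))

pm = PosteriorMean(model)
with torch.no_grad():
    values = pm(train_X.reshape(-1, 1, 2))  # n x q=1 x d -> n values
best = torch.argmax(values)
print(train_X[best], values[best])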
Example #3
 def _update_current_best(self) -> None:
     """
     Updates the current best solution and corresponding value
     """
     pm = PosteriorMean(self.model)
     self.current_best_sol, self.current_best_val = optimize_acqf(
         pm,
         Tensor([[0], [1]]).repeat(1, self.dim),  # 2 x dim unit-cube bounds
         q=1,
         num_restarts=self.num_restarts,
         raw_samples=self.raw_samples,
     )
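The `Tensor([[0], [1]]).repeat(1, self.dim)` idiom builds the `2 x dim` bounds tensor that `optimize_acqf` expects, here the unit cube. A quick check for `dim = 3`:

from torch import Tensor

dim = 3
bounds = Tensor([[0], [1]]).repeat(1, dim)
print(bounds)        # tensor([[0., 0., 0.], [1., 1., 1.]])
print(bounds.shape)  # torch.Size([2, 3])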
Example #4
    def argmax_posterior_mean(cands: to.Tensor, cands_values: to.Tensor,
                              uc_normalizer: UnitCubeProjector,
                              num_restarts: int,
                              num_samples: int) -> to.Tensor:
        """
        Compute the GP input with the maximal posterior mean.

        :param cands: candidates a.k.a. x
        :param cands_values: observed values a.k.a. y
        :param uc_normalizer: unit cube normalizer used during the experiments (can be recovered from the bounds)
        :param num_restarts: number of restarts for the optimization of the acquisition function
        :param num_samples: number of samples for the optimization of the acquisition function
        :return: un-normalized candidate with maximum posterior value a.k.a. x
        """
        # Normalize the input data and standardize the output data
        cands_norm = uc_normalizer.project_to(cands)
        cands_values_stdized = standardize(cands_values)

        # Create and fit the GP model
        gp = SingleTaskGP(cands_norm, cands_values_stdized)
        gp.likelihood.noise_covar.register_constraint('raw_noise',
                                                      GreaterThan(1e-5))
        mll = ExactMarginalLogLikelihood(gp.likelihood, gp)
        fit_gpytorch_model(mll)

        # Find position with maximal posterior mean
        cand_norm, acq_value = optimize_acqf(
            acq_function=PosteriorMean(gp),
            bounds=to.stack([
                to.zeros_like(uc_normalizer.bound_lo),
                to.ones_like(uc_normalizer.bound_up)
            ]),
            q=1,
            num_restarts=num_restarts,
            raw_samples=num_samples)

        cand = uc_normalizer.project_back(cand_norm.detach())
        print_cbt(f'Converged to argmax of the posterior mean\n{cand.numpy()}',
                  'g',
                  bright=True)
        return cand
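For readers without pyrado, `UnitCubeProjector` is essentially an affine map between a box and the unit cube; a minimal stand-in (a sketch, not pyrado's actual implementation):

import torch as to

class UnitCubeProjector:
    """Affine map between a box [bound_lo, bound_up] and the unit cube [0, 1]^d."""

    def __init__(self, bound_lo: to.Tensor, bound_up: to.Tensor):
        self.bound_lo, self.bound_up = bound_lo, bound_up

    def project_to(self, x: to.Tensor) -> to.Tensor:
        # box -> unit cube
        return (x - self.bound_lo) / (self.bound_up - self.bound_lo)

    def project_back(self, x: to.Tensor) -> to.Tensor:
        # unit cube -> box
        return x * (self.bound_up - self.bound_lo) + self.bound_lo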
Example #5
def main(argv):
    dataset = 1

    try:
        opts, args = getopt.getopt(argv, "hd:", ["dataset="])
    except getopt.GetoptError:
        print('random parallel with input dataset')
        sys.exit(2)
    for opt, arg in opts:
        if opt == '-h':
            print('random parallel with input dataset')
            sys.exit()
        elif opt in ("-d", "--dataset"):
            dataset = int(arg)


    # average over multiple trials
    for trial in range(1, N_TRIALS + 1):

        print(f"\nTrial {trial:>2} of {N_TRIALS} ", end="")
        best_observed_ei, best_observed_nei = [], []

        # call helper functions to generate initial training data and initialize model
        train_x_ei, train_obj_ei, best_observed_value_ei, current_best_config = generate_initial_data(
            dataset)

        train_x_nei, train_obj_nei = train_x_ei, train_obj_ei
        best_observed_value_nei = best_observed_value_ei
        mll_nei, model_nei = initialize_model(train_x_nei, train_obj_nei)

        best_observed_nei.append(best_observed_value_nei)

        # run N_BATCH rounds of BayesOpt after the initial random batch
        for iteration in range(1, N_BATCH + 1):

            # fit the models
            fit_gpytorch_model(mll_nei)

            # define the qNEI acquisition module using a QMC sampler
            qmc_sampler = SobolQMCNormalSampler(num_samples=MC_SAMPLES)

            # qNEI conditions on the noisy baseline observations X_baseline rather than a best_f value

            qNEI = qNoisyExpectedImprovement(
                model=model_nei,
                X_baseline=train_x_nei,
                sampler=qmc_sampler,
            )

            # optimize and get new observation
            new_x_nei, new_obj_nei = optimize_acqf_and_get_observation(
                qNEI, dataset)

            # update training points

            train_x_nei = torch.cat([train_x_nei, new_x_nei])
            train_obj_nei = torch.cat([train_obj_nei, new_obj_nei])

            # update progress

            best_value_nei = train_obj_nei.max().item()
            best_observed_nei.append(best_value_nei)

            # reinitialize the models so they are ready for fitting on next iteration
            # use the current state dict to speed up fitting
            mll_nei, model_nei = initialize_model(
                train_x_nei,
                train_obj_nei,
                model_nei.state_dict(),
            )

        # record the best configuration found in this trial

        best_tensor_nei, indices_nei = torch.max(train_obj_nei, 0)
        train_best_x_nei = train_x_nei[indices_nei].cpu().numpy()

        from botorch.acquisition import PosteriorMean

        argmax_pmean_nei, max_pmean_nei = optimize_acqf(
            acq_function=PosteriorMean(model_nei),
            bounds=bounds,
            q=1,
            num_restarts=20,
            raw_samples=2048,
        )

        csv_file_name = (f'/home/junjie/modes/botorch/{folder_name}/modes-i/'
                         f'hp-ngp-qnei-dataset-{dataset}-trail{trial}.csv')

        with open(csv_file_name, 'w') as csvFile:
            writer = csv.writer(csvFile)
            writer.writerow([
                str(argmax_pmean_nei.cpu().numpy()),
                str(max_pmean_nei.cpu().numpy())
            ])  # nei prediction
            writer.writerow(
                [str(train_best_x_nei),
                 str(best_tensor_nei.cpu().numpy())])  # nei observation

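The helpers `generate_initial_data`, `initialize_model`, and `optimize_acqf_and_get_observation` are defined elsewhere in this script. A plausible sketch of `initialize_model`, consistent with how it is called above (the warm-start via `state_dict` is what makes the per-iteration refits fast):

from botorch.models import SingleTaskGP
from gpytorch.mlls import ExactMarginalLogLikelihood

def initialize_model(train_x, train_obj, state_dict=None):
    # Fit a single-output GP to the observed data
    model = SingleTaskGP(train_x, train_obj)
    mll = ExactMarginalLogLikelihood(model.likelihood, model)
    # Warm-start from the previous hyperparameters to speed up refitting
    if state_dict is not None:
        model.load_state_dict(state_dict)
    return mll, model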
Example #6
from botorch.test_functions import Ackley, Hartmann
from parametric_bandit.arm import ParametricArm
import torch
from torch import Tensor
from botorch.acquisition import qKnowledgeGradient, PosteriorMean
from botorch.optim import optimize_acqf
import matplotlib.pyplot as plt
from mpl_toolkits import mplot3d

# test function
ack = Ackley()
# construct the arm
arm1 = ParametricArm(ack)

pm = PosteriorMean(arm1.model)

cand, val = optimize_acqf(pm,
                          Tensor([[0], [1]]).repeat(1, 2),
                          q=1,
                          num_restarts=10,
                          raw_samples=100)

plt.figure()
ax = plt.axes(projection="3d")
k = 40  # number of points in x and y
x = torch.linspace(0, 1, k)
xx = x.view(-1, 1).repeat(1, k)
yy = x.repeat(k, 1)
xy = torch.cat([xx.unsqueeze(2), yy.unsqueeze(2)], 2)
means = arm1.model.posterior(xy).mean
# The original snippet breaks off here; a plausible completion that plots
# the posterior mean surface (detach() is needed before converting to NumPy):
ax.scatter3D(xx.numpy(), yy.numpy(), means.squeeze(-1).detach().numpy())
plt.show()
Example #7
def bayes_opt(x0, y0):
    """
    Main Bayesian optimization loop. Begins by initializing model, then for each
    iteration, it fits the GP to the data, gets a new point with the acquisition
    function, adds it to the dataset, and exits if it's a successful attack
    """

    best_observed = []
    query_count, success = 0, 0

    # call helper function to initialize model
    train_x, train_obj, mll, model, best_value, mean, std = initialize_model(
        x0, y0, n=args.initial_samples)
    if args.standardize_every_iter:
        train_obj = (train_obj - train_obj.mean()) / train_obj.std()
    best_observed.append(best_value)
    query_count += args.initial_samples

    # run args.iter rounds of BayesOpt after the initial random batch
    for _ in range(args.iter):

        # fit the model
        fit_gpytorch_model(mll)

        # define the acquisition function: qEI with a QMC sampler if q > 1, else an analytic one
        if args.q != 1:
            qmc_sampler = SobolQMCNormalSampler(num_samples=2000, seed=seed)
            qEI = qExpectedImprovement(model=model,
                                       sampler=qmc_sampler,
                                       best_f=best_value)
        else:
            if args.acqf == 'EI':
                qEI = ExpectedImprovement(model=model, best_f=best_value)
            elif args.acqf == 'PM':
                qEI = PosteriorMean(model)
            elif args.acqf == 'POI':
                qEI = ProbabilityOfImprovement(model, best_f=best_value)
            elif args.acqf == 'UCB':
                qEI = UpperConfidenceBound(model, beta=args.beta)
            else:
                # avoid a NameError below if an unknown name is passed
                raise ValueError(f'Unknown acquisition function: {args.acqf}')

        # optimize and get new observation
        new_x, new_obj = optimize_acqf_and_get_observation(qEI, x0, y0)

        if args.standardize:
            new_obj = (new_obj - mean) / std

        # update training points
        train_x = torch.cat((train_x, new_x))
        train_obj = torch.cat((train_obj, new_obj))
        if args.standardize_every_iter:
            train_obj = (train_obj - train_obj.mean()) / train_obj.std()

        # update progress
        best_value, best_index = train_obj.max(0)
        best_observed.append(best_value.item())
        best_candidate = train_x[best_index]

        # reinitialize the model so it is ready for fitting on next iteration
        torch.cuda.empty_cache()
        model.set_train_data(train_x, train_obj, strict=False)

        # get objective value of best candidate; if we found an adversary, exit
        best_candidate = best_candidate.view(1, -1)
        best_candidate = transform(best_candidate, args.dset, args.arch,
                                   args.cos, args.sin).to(device)
        best_candidate = proj(best_candidate, args.eps, args.inf_norm,
                              args.discrete)
        with torch.no_grad():
            adv_label = torch.argmax(
                cnn_model.predict_scores(best_candidate + x0))
        if adv_label != y0:
            success = 1
            if args.inf_norm:
                print('Adversarial Label', adv_label.item(), 'Norm:',
                      best_candidate.abs().max().item())
            else:
                print('Adversarial Label', adv_label.item(), 'Norm:',
                      best_candidate.norm().item())
            return query_count, success
        query_count += args.q
    # not successful (ran out of query budget)
    return query_count, success
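`optimize_acqf_and_get_observation` is also defined elsewhere in the attack script; a hedged sketch of its core logic, with the module-level state the real helper closes over (`bounds`, `args.q`, the attack objective) passed in explicitly as hypothetical parameters:

from botorch.optim import optimize_acqf

def optimize_acqf_and_get_observation(acq_func, x0, y0, bounds, obj_func, q=1):
    """Hypothetical sketch: optimize the acquisition function, then evaluate
    the attack objective at the chosen candidate(s)."""
    candidates, _ = optimize_acqf(
        acq_function=acq_func,
        bounds=bounds,        # 2 x d tensor of search-box bounds
        q=q,
        num_restarts=10,
        raw_samples=200,
    )
    new_x = candidates.detach()
    new_obj = obj_func(new_x, x0, y0)  # hypothetical attack-objective callable
    return new_x, new_obj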
def render_singletask_gp(
    ax: Union[plt.Axes, Axes3D, Sequence[plt.Axes]],
    data_x: to.Tensor,
    data_y: to.Tensor,
    idcs_sel: list,
    data_x_min: to.Tensor = None,
    data_x_max: to.Tensor = None,
    x_label: str = '',
    y_label: str = '',
    z_label: str = '',
    min_gp_obsnoise: float = None,
    resolution: int = 201,
    num_stds: int = 2,
    alpha: float = 0.3,
    color: str = None,
    curve_label: str = 'mean',
    heatmap_cmap: colors.Colormap = None,
    show_legend_posterior: bool = True,
    show_legend_std: bool = False,
    show_legend_data: bool = True,
    legend_data_cmap: colors.Colormap = None,
    colorbar_label: str = None,
    title: str = None,
    render3D: bool = True,
) -> plt.Figure:
    """
    Fit the GP posterior to the input data and plot the mean and std as well as the data points.
    There are 3 options, inferred from the data dimensions and the `render3D` flag: a 1D curve plot,
    a 3D surface plot, and a pair of 2D heat maps.

    .. note::
        If you want to have a tight layout, it is best to pass axes of a figure with `tight_layout=True` or
        `constrained_layout=True`.

    :param ax: axes of the figure to plot on; in case of a 2-dim heat map plot, provide 4 axes (2 heat maps plus 2 color bars)
    :param data_x: data to plot on the x-axis
    :param data_y: data to process and plot on the y-axis
    :param idcs_sel: selected indices of the input data
    :param data_x_min: explicit minimum value for the evaluation grid, by default this value is extracted from `data_x`
    :param data_x_max: explicit maximum value for the evaluation grid, by default this value is extracted from `data_x`
    :param x_label: label for x-axis
    :param y_label: label for y-axis
    :param z_label: label for z-axis (3D plot only)
    :param min_gp_obsnoise: minimal noise value (normalized) enforced on the GP's likelihood, if `None` no lower bound is enforced
    :param resolution: number of samples for the input (corresponds to x-axis resolution of the plot)
    :param num_stds: number of standard deviations to plot around the mean
    :param alpha: transparency (alpha-value) for the std area
    :param color: color (e.g. 'k' for black), `None` invokes the default behavior
    :param curve_label: label for the mean curve (1D plot only)
    :param heatmap_cmap: color map forwarded to `render_heatmap()` (2D plot only), `None` to use Pyrado's default
    :param show_legend_posterior: flag if the legend entry for the posterior should be printed (affects mean and std)
    :param show_legend_std: flag if a legend entry for the std area should be printed
    :param show_legend_data: flag if a legend entry for the individual data points should be printed
    :param legend_data_cmap: color map for the sampled points, default is 'binary'
    :param colorbar_label: label for the color bar (2D plot only)
    :param title: title displayed above the figure, set to `None` to suppress the title
    :param render3D: use 3D rendering if possible
    :return: handle to the resulting figure
    """
    if data_x.ndim != 2:
        raise pyrado.ShapeErr(msg="The GP's input data needs to be of shape num_samples x dim_input!")
    data_x = data_x[:, idcs_sel]  # forget the rest
    dim_x = data_x.shape[1]  # samples are along axis 0

    if data_y.ndim != 2:
        raise pyrado.ShapeErr(given=data_y,
                              expected_match=to.Size([data_x.shape[0], 1]))

    if legend_data_cmap is None:
        legend_data_cmap = plt.get_cmap('binary')

    # Project to normalized input and standardized output
    if data_x_min is None or data_x_max is None:
        data_x_min, data_x_max = to.min(data_x, dim=0)[0], to.max(data_x, dim=0)[0]
    data_y_mean, data_y_std = to.mean(data_y, dim=0), to.std(data_y, dim=0)
    data_x = (data_x - data_x_min) / (data_x_max - data_x_min)
    data_y = (data_y - data_y_mean) / data_y_std

    # Create and fit the GP model
    gp = SingleTaskGP(data_x, data_y)
    if min_gp_obsnoise is not None:
        gp.likelihood.noise_covar.register_constraint(
            'raw_noise', GreaterThan(min_gp_obsnoise))
    mll = ExactMarginalLogLikelihood(gp.likelihood, gp)
    mll.train()
    fit_gpytorch_model(mll)
    print_cbt('Fitted the SingleTaskGP.', 'g')

    argmax_pmean_norm, argmax_pmean_val_stdzed = optimize_acqf(
        acq_function=PosteriorMean(gp),
        bounds=to.stack([to.zeros(dim_x), to.ones(dim_x)]),
        q=1,
        num_restarts=500,
        raw_samples=1000)
    # Project back
    argmax_posterior = argmax_pmean_norm * (data_x_max - data_x_min) + data_x_min
    argmax_pmean_val = argmax_pmean_val_stdzed * data_y_std + data_y_mean
    print_cbt(
        f'Converged to argmax of the posterior mean: {argmax_posterior.numpy()}',
        'g')

    mll.eval()
    gp.eval()

    if dim_x == 1:
        # Evaluation grid
        x_grid = np.linspace(min(data_x),
                             max(data_x),
                             resolution,
                             endpoint=True).flatten()
        x_grid = to.from_numpy(x_grid)

        # Mean and standard deviation of the surrogate model
        posterior = gp.posterior(x_grid)
        mean = posterior.mean.detach().flatten()
        std = to.sqrt(posterior.variance.detach()).flatten()

        # Project back from normalized input and standardized output
        x_grid = x_grid * (data_x_max - data_x_min) + data_x_min
        data_x = data_x * (data_x_max - data_x_min) + data_x_min
        data_y = data_y * data_y_std + data_y_mean
        mean = mean * data_y_std + data_y_mean
        std *= data_y_std  # double-checked with posterior.mvn.confidence_region()

        # Plot the curve
        plt.fill_between(x_grid.numpy(),
                         mean.numpy() - num_stds * std.numpy(),
                         mean.numpy() + num_stds * std.numpy(),
                         alpha=alpha,
                         color=color)
        ax.plot(x_grid.numpy(), mean.numpy(), color=color)

        # Plot the queried data points
        scat_plot = ax.scatter(data_x.numpy().flatten(),
                               data_y.numpy().flatten(),
                               marker='o',
                               c=np.arange(data_x.shape[0], dtype=int),  # np.int was removed in NumPy >= 1.24
                               cmap=legend_data_cmap)

        if show_legend_data:
            scat_legend = ax.legend(
                *scat_plot.legend_elements(fmt='{x:.0f}'),  # integer formatter
                bbox_to_anchor=(0., 1.1, 1., -0.1),
                title='query points',
                ncol=data_x.shape[0],
                loc='upper center',
                mode='expand',
                borderaxespad=0.,
                handletextpad=-0.5)
            ax.add_artist(scat_legend)
            # Increase vertical space between subplots when printing the data labels
            # plt.tight_layout(pad=2.)  # ignore argument
            # plt.subplots_adjust(hspace=0.6)

        # Plot the argmax of the posterior mean
        # ax.scatter(argmax_posterior.item(), argmax_pmean_val, c='darkorange', marker='o', s=60, label='argmax')
        ax.axvline(argmax_posterior.item(),
                   c='darkorange',
                   lw=1.5,
                   label='argmax')

        if show_legend_posterior:
            ax.add_artist(ax.legend(loc='lower right'))

    elif dim_x == 2:
        # Create mesh grid matrices from x and y vectors
        # x0_grid = to.linspace(min(data_x[:, 0]), max(data_x[:, 0]), resolution)
        # x1_grid = to.linspace(min(data_x[:, 1]), max(data_x[:, 1]), resolution)
        x0_grid = to.linspace(0, 1, resolution)
        x1_grid = to.linspace(0, 1, resolution)
        x0_mesh, x1_mesh = to.meshgrid([x0_grid, x1_grid])
        x0_mesh, x1_mesh = x0_mesh.t(), x1_mesh.t()  # transpose not necessary but makes identical mesh as np.meshgrid

        # Mean and standard deviation of the surrogate model
        x_test = to.stack([
            x0_mesh.reshape(resolution**2, 1),
            x1_mesh.reshape(resolution**2, 1)
        ], -1).squeeze(1)
        posterior = gp.posterior(x_test)  # the posterior over the latent function, i.e. without observation noise
        mean = posterior.mean.detach().reshape(resolution, resolution)
        std = to.sqrt(posterior.variance.detach()).reshape(
            resolution, resolution)

        # Project back from normalized input and standardized output
        data_x = data_x * (data_x_max - data_x_min) + data_x_min
        data_y = data_y * data_y_std + data_y_mean
        mean_raw = mean * data_y_std + data_y_mean
        std_raw = std * data_y_std

        if render3D:
            # Project back from normalized input and standardized output (custom for 3D)
            x0_mesh = x0_mesh * (data_x_max[0] - data_x_min[0]) + data_x_min[0]
            x1_mesh = x1_mesh * (data_x_max[1] - data_x_min[1]) + data_x_min[1]
            lower = mean_raw - num_stds * std_raw
            upper = mean_raw + num_stds * std_raw

            # Plot a 2D surface in 3D
            ax.plot_surface(x0_mesh.numpy(), x1_mesh.numpy(), mean_raw.numpy())
            ax.plot_surface(x0_mesh.numpy(),
                            x1_mesh.numpy(),
                            lower.numpy(),
                            color='r',
                            alpha=alpha)
            ax.plot_surface(x0_mesh.numpy(),
                            x1_mesh.numpy(),
                            upper.numpy(),
                            color='r',
                            alpha=alpha)
            ax.set_xlabel(x_label)
            ax.set_ylabel(y_label)
            ax.set_zlabel(z_label)

            # Plot the queried data points
            scat_plot = ax.scatter(data_x[:, 0].numpy(),
                                   data_x[:, 1].numpy(),
                                   data_y.numpy(),
                                   marker='o',
                                   c=np.arange(data_x.shape[0], dtype=int),
                                   cmap=legend_data_cmap)

            if show_legend_data:
                scat_legend = ax.legend(
                    *scat_plot.legend_elements(
                        fmt='{x:.0f}'),  # integer formatter
                    bbox_to_anchor=(0.05, 1.1, 0.95, -0.1),
                    loc='upper center',
                    ncol=data_x.shape[0],
                    mode='expand',
                    borderaxespad=0.,
                    handletextpad=-0.5)
                ax.add_artist(scat_legend)

            # Plot the argmax of the posterior mean
            x, y = argmax_posterior[0, 0], argmax_posterior[0, 1]
            ax.scatter(x,
                       y,
                       argmax_pmean_val,
                       c='darkorange',
                       marker='*',
                       s=60)
            # ax.plot((x, x), (y, y), (data_y.min(), data_y.max()), c='k', ls='--', lw=1.5)

        else:
            if len(ax) != 4:
                raise pyrado.ShapeErr(
                    msg='Provide 4 axes! 2 heat maps and 2 color bars.')

            # Project back normalized input and standardized output (custom for 2D)
            x0_grid_raw = x0_grid * (data_x_max[0] -
                                     data_x_min[0]) + data_x_min[0]
            x1_grid_raw = x1_grid * (data_x_max[1] -
                                     data_x_min[1]) + data_x_min[1]

            # Plot a 2D image
            df_mean = pd.DataFrame(mean_raw.numpy(),
                                   columns=x0_grid_raw.numpy(),
                                   index=x1_grid_raw.numpy())
            render_heatmap(df_mean,
                           ax_hm=ax[0],
                           ax_cb=ax[1],
                           x_label=x_label,
                           y_label=y_label,
                           annotate=False,
                           fig_canvas_title='Returns',
                           tick_label_prec=2,
                           add_sep_colorbar=True,
                           cmap=heatmap_cmap,
                           colorbar_label=colorbar_label,
                           num_major_ticks_hm=3,
                           num_major_ticks_cb=2,
                           colorbar_orientation='horizontal')

            df_std = pd.DataFrame(std_raw.numpy(),
                                  columns=x0_grid_raw.numpy(),
                                  index=x1_grid_raw.numpy())
            render_heatmap(
                df_std,
                ax_hm=ax[2],
                ax_cb=ax[3],
                x_label=x_label,
                y_label=y_label,
                annotate=False,
                fig_canvas_title='Standard Deviations',
                tick_label_prec=2,
                add_sep_colorbar=True,
                cmap=heatmap_cmap,
                colorbar_label=colorbar_label,
                num_major_ticks_hm=3,
                num_major_ticks_cb=2,
                colorbar_orientation='horizontal',
                norm=colors.Normalize())  # explicitly instantiate a new norm

            # Plot the queried data points
            for i in [0, 2]:
                scat_plot = ax[i].scatter(data_x[:, 0].numpy(),
                                          data_x[:, 1].numpy(),
                                          marker='o',
                                          s=15,
                                          c=np.arange(data_x.shape[0], dtype=int),
                                          cmap=legend_data_cmap)

                if show_legend_data:
                    scat_legend = ax[i].legend(
                        *scat_plot.legend_elements(
                            fmt='{x:.0f}'),  # integer formatter
                        bbox_to_anchor=(0., 1.1, 1., 0.05),
                        loc='upper center',
                        ncol=data_x.shape[0],
                        mode='expand',
                        borderaxespad=0.,
                        handletextpad=-0.5)
                    ax[i].add_artist(scat_legend)

            # Plot the argmax of the posterior mean
            ax[0].scatter(argmax_posterior[0, 0],
                          argmax_posterior[0, 1],
                          c='darkorange',
                          marker='*',
                          s=60)  # steelblue
            ax[2].scatter(argmax_posterior[0, 0],
                          argmax_posterior[0, 1],
                          c='darkorange',
                          marker='*',
                          s=60)  # steelblue
            # ax[0].axvline(argmax_posterior[0, 0], c='w', ls='--', lw=1.5)
            # ax[0].axhline(argmax_posterior[0, 1], c='w', ls='--', lw=1.5)
            # ax[2].axvline(argmax_posterior[0, 0], c='w', ls='--', lw=1.5)
            # ax[2].axhline(argmax_posterior[0, 1], c='w', ls='--', lw=1.5)

    else:
        raise pyrado.ValueErr(msg='Can only plot 1-dim or 2-dim data!')

    return plt.gcf()
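A minimal 1D usage sketch with synthetic data (assumes `render_singletask_gp` and its pyrado dependencies are importable):

import matplotlib.pyplot as plt
import torch as to

# Synthetic 1D data: a noisy quadratic sampled at 12 points
data_x = to.linspace(-2.0, 2.0, 12).unsqueeze(1)
data_y = (data_x - 0.3)**2 + 0.1 * to.randn_like(data_x)

fig, ax = plt.subplots(1, 1, tight_layout=True)
render_singletask_gp(ax, data_x, data_y, idcs_sel=[0],
                     x_label='x', y_label='y', curve_label='posterior mean')
plt.show()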