from botorch.fit import fit_gpytorch_model
from botorch.models import SingleTaskGP
from gpytorch.mlls import ExactMarginalLogLikelihood


def get_gpr_model(X, y, model=None):
    """
    Fit a GPR model to the data, or update an existing model with new data
    Params ::
    X: (sx1) Tensor: Covariates
    y: (sx1) Tensor: Observations
    model: BoTorch SingleTaskGP model: If a model is passed, X and y are used to
        update it. If None, a new model is trained on X and y. Default is None
    Return ::
    model: BoTorch SingleTaskGP model: Trained or updated model.
        Returned in train mode
    mll: GPyTorch MarginalLogLikelihood object: Returned in train mode
    """

    if model is None:
        # set up model
        model = SingleTaskGP(X, y)
    else:
        # update model with new observations
        model = model.condition_on_observations(X, y)
    mll = ExactMarginalLogLikelihood(model.likelihood, model).to(X)
    # begin training
    model.train()
    mll.train()
    fit_gpytorch_model(mll)
    return model, mll
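
A minimal usage sketch with hypothetical toy data (relying on the imports above):

import torch

# fit an initial model on ten noisy sine observations
train_X = torch.linspace(0, 1, 10).unsqueeze(-1)                   # (10 x 1)
train_y = torch.sin(6 * train_X) + 0.1 * torch.randn_like(train_X)
model, mll = get_gpr_model(train_X, train_y)

# later, update the fitted model with five new observations
new_X = torch.rand(5, 1)
new_y = torch.sin(6 * new_X) + 0.1 * torch.randn_like(new_X)
model, mll = get_gpr_model(new_X, new_y, model=model)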
Example 2
def get_gpr_model(X, y, model=None):
    """Fit a gpr model to the data or update the model to new data.


    Parameters
    ----------


    X: (sx1) Tensor
        Covariates
    y: (sx1) Tensor
        Observations
    model: PyTorch SingleTaskGP model
        If model is passed, X and y are used to update it. 
        If None then model is trained on X and y. Default is None.

    
    Returns
    -------


    model: PyTorch SingleTaskGP model
        Trained or updated model. Returned in train mode.
    mll: PyTorch MarginalLogLikelihood object
        This is the loss used to train hyperparameters. Returned in train mode.


    """
    if model is None:
        # set up model
        print('X', X.shape)  # debug: log the training data shapes
        print('y', y.shape)
        model = SingleTaskGP(X, y)
    else:
        # update model with new observations
        model = model.condition_on_observations(X, y)
    mll = ExactMarginalLogLikelihood(model.likelihood, model).to(X)
    # begin training
    model.train()
    mll.train()
    fit_gpytorch_model(mll)
    return model, mll
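
The model comes back in train mode, so prediction requires switching to eval mode first; a short sketch using the standard BoTorch posterior API:

model.eval()
mll.eval()
with torch.no_grad():
    test_X = torch.linspace(0, 1, 50).unsqueeze(-1)
    posterior = model.posterior(test_X)
    pred_mean = posterior.mean        # (50 x 1) posterior mean
    pred_var = posterior.variance     # (50 x 1) posterior variance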
Example 3
    def test_lanczos_fantasy_model(self):
        lanczos_thresh = 10
        n = lanczos_thresh + 1
        n_dims = 2
        with settings.max_cholesky_size(lanczos_thresh):
            x = torch.ones((n, n_dims))
            y = torch.randn(n)
            likelihood = GaussianLikelihood()
            model = ExactGPModel(x, y, likelihood=likelihood)
            mll = ExactMarginalLogLikelihood(likelihood, model)
            mll.train()
            mll.eval()

            # get a posterior to fill in caches
            model(torch.randn((1, n_dims)))

            new_n = 2
            new_x = torch.randn((new_n, n_dims))
            new_y = torch.randn(new_n)
            # just check that this can run without error
            model.get_fantasy_model(new_x, new_y)
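
The test assumes an `ExactGPModel` that is not shown in the snippet; a minimal sketch along the lines of the standard GPyTorch exact-GP example model:

import gpytorch


class ExactGPModel(gpytorch.models.ExactGP):
    def __init__(self, train_x, train_y, likelihood):
        super().__init__(train_x, train_y, likelihood)
        self.mean_module = gpytorch.means.ConstantMean()
        self.covar_module = gpytorch.kernels.ScaleKernel(gpytorch.kernels.RBFKernel())

    def forward(self, x):
        # return the (prior) distribution at the query points
        mean_x = self.mean_module(x)
        covar_x = self.covar_module(x)
        return gpytorch.distributions.MultivariateNormal(mean_x, covar_x)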
Example 4

from typing import Sequence, Union

import numpy as np
import pandas as pd
import torch as to
from botorch.acquisition import PosteriorMean
from botorch.fit import fit_gpytorch_model
from botorch.models import SingleTaskGP
from botorch.optim import optimize_acqf
from gpytorch.constraints import GreaterThan
from gpytorch.mlls import ExactMarginalLogLikelihood
from matplotlib import colors
from matplotlib import pyplot as plt
from mpl_toolkits.mplot3d import Axes3D

# Pyrado helpers (import paths as in the SimuRLacra/Pyrado code base)
import pyrado
from pyrado.plotting.heatmap import render_heatmap
from pyrado.utils.input_output import print_cbt


def render_singletask_gp(
    ax: Union[plt.Axes, Axes3D, Sequence[plt.Axes]],
    data_x: to.Tensor,
    data_y: to.Tensor,
    idcs_sel: list,
    data_x_min: to.Tensor = None,
    data_x_max: to.Tensor = None,
    x_label: str = '',
    y_label: str = '',
    z_label: str = '',
    min_gp_obsnoise: float = None,
    resolution: int = 201,
    num_stds: int = 2,
    alpha: float = 0.3,
    color: str = None,
    curve_label: str = 'mean',
    heatmap_cmap: colors.Colormap = None,
    show_legend_posterior: bool = True,
    show_legend_std: bool = False,
    show_legend_data: bool = True,
    legend_data_cmap: colors.Colormap = None,
    colorbar_label: str = None,
    title: str = None,
    render3D: bool = True,
) -> plt.Figure:
    """
    Fit the GP posterior to the input data and plot the mean and std as well as the data points.
    There are three options, inferred from the data dimensions: a 1D plot, a 2D surface plot, and a 2D heat map.

    .. note::
        If you want to have a tight layout, it is best to pass axes of a figure with `tight_layout=True` or
        `constrained_layout=True`.

    :param ax: axis of the figure to plot on; only in case of a 2-dim heat map plot, provide four axes (two heat maps and two color bars)
    :param data_x: data to plot on the x-axis
    :param data_y: data to process and plot on the y-axis
    :param idcs_sel: selected indices of the input data
    :param data_x_min: explicit minimum value for the evaluation grid, by default this value is extracted from `data_x`
    :param data_x_max: explicit maximum value for the evaluation grid, by default this value is extracted from `data_x`
    :param x_label: label for x-axis
    :param y_label: label for y-axis
    :param z_label: label for z-axis (3D plot only)
    :param min_gp_obsnoise: set a minimal noise value (normalized) for the GP, if `None` no lower bound is enforced on the measurement noise
    :param resolution: number of samples for the input (corresponds to x-axis resolution of the plot)
    :param num_stds: number of standard deviations to plot around the mean
    :param alpha: transparency (alpha-value) for the std area
    :param color: color (e.g. 'k' for black), `None` invokes the default behavior
    :param curve_label: label for the mean curve (1D plot only)
    :param heatmap_cmap: color map forwarded to `render_heatmap()` (2D plot only), `None` to use Pyrado's default
    :param show_legend_posterior: flag if the legend entry for the posterior should be printed (affects mean and std)
    :param show_legend_std: flag if a legend entry for the std area should be printed
    :param show_legend_data: flag if a legend entry for the individual data points should be printed
    :param legend_data_cmap: color map for the sampled points, default is 'binary'
    :param colorbar_label: label for the color bar (2D plot only)
    :param title: title displayed above the figure, set to `None` to suppress the title
    :param render3D: use 3D rendering if possible
    :return: handle to the resulting figure
    """
    if data_x.ndim != 2:
        raise pyrado.ShapeErr(
            msg="The GP's input data needs to be of shape num_samples x dim_input!")
    data_x = data_x[:, idcs_sel]  # forget the rest
    dim_x = data_x.shape[1]  # samples are along axis 0

    if data_y.ndim != 2:
        raise pyrado.ShapeErr(given=data_y,
                              expected_match=to.Size([data_x.shape[0], 1]))

    if legend_data_cmap is None:
        legend_data_cmap = plt.get_cmap('binary')

    # Project to normalized input and standardized output
    if data_x_min is None or data_x_max is None:
        data_x_min, data_x_max = to.min(data_x, dim=0)[0], to.max(data_x, dim=0)[0]
    data_y_mean, data_y_std = to.mean(data_y, dim=0), to.std(data_y, dim=0)
    data_x = (data_x - data_x_min) / (data_x_max - data_x_min)
    data_y = (data_y - data_y_mean) / data_y_std

    # Create and fit the GP model
    gp = SingleTaskGP(data_x, data_y)
    if min_gp_obsnoise is not None:
        gp.likelihood.noise_covar.register_constraint(
            'raw_noise', GreaterThan(min_gp_obsnoise))
    mll = ExactMarginalLogLikelihood(gp.likelihood, gp)
    mll.train()
    fit_gpytorch_model(mll)
    print_cbt('Fitted the SingleTaskGP.', 'g')

    argmax_pmean_norm, argmax_pmean_val_stdzed = optimize_acqf(
        acq_function=PosteriorMean(gp),
        bounds=to.stack([to.zeros(dim_x), to.ones(dim_x)]),
        q=1,
        num_restarts=500,
        raw_samples=1000)
    # Project back
    argmax_posterior = argmax_pmean_norm * (data_x_max - data_x_min) + data_x_min
    argmax_pmean_val = argmax_pmean_val_stdzed * data_y_std + data_y_mean
    print_cbt(
        f'Converged to argmax of the posterior mean: {argmax_posterior.numpy()}',
        'g')

    mll.eval()
    gp.eval()

    if dim_x == 1:
        # Evaluation grid
        x_grid = np.linspace(min(data_x),
                             max(data_x),
                             resolution,
                             endpoint=True).flatten()
        x_grid = to.from_numpy(x_grid)

        # Mean and standard deviation of the surrogate model
        posterior = gp.posterior(x_grid)
        mean = posterior.mean.detach().flatten()
        std = to.sqrt(posterior.variance.detach()).flatten()

        # Project back from normalized input and standardized output
        x_grid = x_grid * (data_x_max - data_x_min) + data_x_min
        data_x = data_x * (data_x_max - data_x_min) + data_x_min
        data_y = data_y * data_y_std + data_y_mean
        mean = mean * data_y_std + data_y_mean
        std *= data_y_std  # double-checked with posterior.mvn.confidence_region()

        # Plot the curve
        ax.fill_between(x_grid.numpy(),
                        mean.numpy() - num_stds * std.numpy(),
                        mean.numpy() + num_stds * std.numpy(),
                        alpha=alpha,
                        color=color)
        ax.plot(x_grid.numpy(), mean.numpy(), color=color)

        # Plot the queried data points
        scat_plot = ax.scatter(data_x.numpy().flatten(),
                               data_y.numpy().flatten(),
                               marker='o',
                               c=np.arange(data_x.shape[0], dtype=int),
                               cmap=legend_data_cmap)

        if show_legend_data:
            scat_legend = ax.legend(
                *scat_plot.legend_elements(fmt='{x:.0f}'),  # integer formatter
                bbox_to_anchor=(0., 1.1, 1., -0.1),
                title='query points',
                ncol=data_x.shape[0],
                loc='upper center',
                mode='expand',
                borderaxespad=0.,
                handletextpad=-0.5)
            ax.add_artist(scat_legend)
            # Increase vertical space between subplots when printing the data labels
            # plt.tight_layout(pad=2.)  # ignore argument
            # plt.subplots_adjust(hspace=0.6)

        # Plot the argmax of the posterior mean
        # ax.scatter(argmax_posterior.item(), argmax_pmean_val, c='darkorange', marker='o', s=60, label='argmax')
        ax.axvline(argmax_posterior.item(),
                   c='darkorange',
                   lw=1.5,
                   label='argmax')

        if show_legend_posterior:
            ax.add_artist(ax.legend(loc='lower right'))

    elif dim_x == 2:
        # Create mesh grid matrices from x and y vectors
        # x0_grid = to.linspace(min(data_x[:, 0]), max(data_x[:, 0]), resolution)
        # x1_grid = to.linspace(min(data_x[:, 1]), max(data_x[:, 1]), resolution)
        x0_grid = to.linspace(0, 1, resolution)
        x1_grid = to.linspace(0, 1, resolution)
        x0_mesh, x1_mesh = to.meshgrid([x0_grid, x1_grid])
        x0_mesh, x1_mesh = x0_mesh.t(), x1_mesh.t()  # transpose not necessary but makes the mesh identical to np.meshgrid

        # Mean and standard deviation of the surrogate model
        x_test = to.stack([
            x0_mesh.reshape(resolution**2, 1),
            x1_mesh.reshape(resolution**2, 1)
        ], -1).squeeze(1)
        posterior = gp.posterior(x_test)  # the posterior over f, without observation noise
        mean = posterior.mean.detach().reshape(resolution, resolution)
        std = to.sqrt(posterior.variance.detach()).reshape(resolution, resolution)

        # Project back from normalized input and standardized output
        data_x = data_x * (data_x_max - data_x_min) + data_x_min
        data_y = data_y * data_y_std + data_y_mean
        mean_raw = mean * data_y_std + data_y_mean
        std_raw = std * data_y_std

        if render3D:
            # Project back from normalized input and standardized output (custom for 3D)
            x0_mesh = x0_mesh * (data_x_max[0] - data_x_min[0]) + data_x_min[0]
            x1_mesh = x1_mesh * (data_x_max[1] - data_x_min[1]) + data_x_min[1]
            lower = mean_raw - num_stds * std_raw
            upper = mean_raw + num_stds * std_raw

            # Plot a 2D surface in 3D
            ax.plot_surface(x0_mesh.numpy(), x1_mesh.numpy(), mean_raw.numpy())
            ax.plot_surface(x0_mesh.numpy(),
                            x1_mesh.numpy(),
                            lower.numpy(),
                            color='r',
                            alpha=alpha)
            ax.plot_surface(x0_mesh.numpy(),
                            x1_mesh.numpy(),
                            upper.numpy(),
                            color='r',
                            alpha=alpha)
            ax.set_xlabel(x_label)
            ax.set_ylabel(y_label)
            ax.set_zlabel(z_label)

            # Plot the queried data points
            scat_plot = ax.scatter(data_x[:, 0].numpy(),
                                   data_x[:, 1].numpy(),
                                   data_y.numpy(),
                                   marker='o',
                                   c=np.arange(data_x.shape[0], dtype=int),
                                   cmap=legend_data_cmap)

            if show_legend_data:
                scat_legend = ax.legend(
                    *scat_plot.legend_elements(
                        fmt='{x:.0f}'),  # integer formatter
                    bbox_to_anchor=(0.05, 1.1, 0.95, -0.1),
                    loc='upper center',
                    ncol=data_x.shape[0],
                    mode='expand',
                    borderaxespad=0.,
                    handletextpad=-0.5)
                ax.add_artist(scat_legend)

            # Plot the argmax of the posterior mean
            x, y = argmax_posterior[0, 0], argmax_posterior[0, 1]
            ax.scatter(x,
                       y,
                       argmax_pmean_val,
                       c='darkorange',
                       marker='*',
                       s=60)
            # ax.plot((x, x), (y, y), (data_y.min(), data_y.max()), c='k', ls='--', lw=1.5)

        else:
            if len(ax) != 4:
                raise pyrado.ShapeErr(
                    msg='Provide 4 axes! 2 heat maps and 2 color bars.')

            # Project back normalized input and standardized output (custom for 2D)
            x0_grid_raw = x0_grid * (data_x_max[0] - data_x_min[0]) + data_x_min[0]
            x1_grid_raw = x1_grid * (data_x_max[1] - data_x_min[1]) + data_x_min[1]

            # Plot a 2D image
            df_mean = pd.DataFrame(mean_raw.numpy(),
                                   columns=x0_grid_raw.numpy(),
                                   index=x1_grid_raw.numpy())
            render_heatmap(df_mean,
                           ax_hm=ax[0],
                           ax_cb=ax[1],
                           x_label=x_label,
                           y_label=y_label,
                           annotate=False,
                           fig_canvas_title='Returns',
                           tick_label_prec=2,
                           add_sep_colorbar=True,
                           cmap=heatmap_cmap,
                           colorbar_label=colorbar_label,
                           num_major_ticks_hm=3,
                           num_major_ticks_cb=2,
                           colorbar_orientation='horizontal')

            df_std = pd.DataFrame(std_raw.numpy(),
                                  columns=x0_grid_raw.numpy(),
                                  index=x1_grid_raw.numpy())
            render_heatmap(
                df_std,
                ax_hm=ax[2],
                ax_cb=ax[3],
                x_label=x_label,
                y_label=y_label,
                annotate=False,
                fig_canvas_title='Standard Deviations',
                tick_label_prec=2,
                add_sep_colorbar=True,
                cmap=heatmap_cmap,
                colorbar_label=colorbar_label,
                num_major_ticks_hm=3,
                num_major_ticks_cb=2,
                colorbar_orientation='horizontal',
                norm=colors.Normalize())  # explicitly instantiate a new norm

            # Plot the queried data points
            for i in [0, 2]:
                scat_plot = ax[i].scatter(data_x[:, 0].numpy(),
                                          data_x[:, 1].numpy(),
                                          marker='o',
                                          s=15,
                                          c=np.arange(data_x.shape[0], dtype=int),
                                          cmap=legend_data_cmap)

                if show_legend_data:
                    scat_legend = ax[i].legend(
                        *scat_plot.legend_elements(
                            fmt='{x:.0f}'),  # integer formatter
                        bbox_to_anchor=(0., 1.1, 1., 0.05),
                        loc='upper center',
                        ncol=data_x.shape[0],
                        mode='expand',
                        borderaxespad=0.,
                        handletextpad=-0.5)
                    ax[i].add_artist(scat_legend)

            # Plot the argmax of the posterior mean
            ax[0].scatter(argmax_posterior[0, 0],
                          argmax_posterior[0, 1],
                          c='darkorange',
                          marker='*',
                          s=60)  # steelblue
            ax[2].scatter(argmax_posterior[0, 0],
                          argmax_posterior[0, 1],
                          c='darkorange',
                          marker='*',
                          s=60)  # steelblue
            # ax[0].axvline(argmax_posterior[0, 0], c='w', ls='--', lw=1.5)
            # ax[0].axhline(argmax_posterior[0, 1], c='w', ls='--', lw=1.5)
            # ax[2].axvline(argmax_posterior[0, 0], c='w', ls='--', lw=1.5)
            # ax[2].axhline(argmax_posterior[0, 1], c='w', ls='--', lw=1.5)

    else:
        raise pyrado.ValueErr(msg='Can only plot 1-dim or 2-dim data!')

    return plt.gcf()
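
A minimal usage sketch for the 1D case, with hypothetical toy data (relying on the imports above the function):

data_x = to.linspace(0., 5., 15).unsqueeze(-1)             # (15 x 1) inputs
data_y = to.sin(data_x) + 0.1 * to.randn_like(data_x)      # (15 x 1) noisy targets

fig, ax = plt.subplots(1, 1, tight_layout=True)
render_singletask_gp(ax, data_x, data_y, idcs_sel=[0], x_label='x', y_label='y')
plt.show()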
Example 5

import gpytorch
import torch
from botorch.models import SingleTaskGP
from gpytorch.kernels import RBFKernel, ScaleKernel
from gpytorch.mlls import ExactMarginalLogLikelihood
# `regression` is assumed to be a project-local module providing `evaluate()`


class OnlineExactRegression(torch.nn.Module):
    """Exact GP regression on features produced by a trainable stem network,
    supporting online updates as new observations arrive."""

    def __init__(self, stem, init_x, init_y, lr, **kwargs):
        super().__init__()
        self.stem = stem.to(init_x.device)
        if init_y.t().shape[0] != 1:
            _batch_shape = init_y.t().shape[:-1]
        else:
            _batch_shape = torch.Size()
        features = self.stem(init_x)
        self.gp = SingleTaskGP(features,
                               init_y,
                               covar_module=ScaleKernel(
                                   RBFKernel(batch_shape=_batch_shape,
                                             ard_num_dims=stem.output_dim),
                                   batch_shape=_batch_shape))
        self.mll = ExactMarginalLogLikelihood(self.gp.likelihood, self.gp)
        self.optimizer = torch.optim.Adam(self.parameters(), lr=lr)
        self._raw_inputs = [init_x]
        self._target_batch_shape = _batch_shape
        self.target_dim = init_y.size(-1)

    def update(self, inputs, targets, update_stem=True, update_gp=True):
        inputs = inputs.view(-1, self.stem.input_dim)
        targets = targets.view(-1, self.target_dim)

        # add observation
        self.train()
        self._raw_inputs = [torch.cat([*self._raw_inputs, inputs])]
        self.gp.train_targets = torch.cat(
            [self.gp.train_targets,
             self._reshape_targets(targets)], dim=-1)

        if update_stem:
            self._refresh_features(*self._raw_inputs, strict=False)
        else:
            with torch.no_grad():
                self._refresh_features(*self._raw_inputs, strict=False)

        self.mll = ExactMarginalLogLikelihood(self.gp.likelihood, self.gp)
        # update stem and GP
        if update_gp:
            self.optimizer.zero_grad()
            with gpytorch.settings.skip_logdet_forward(True):
                train_dist = self.gp(*self.gp.train_inputs)
                loss = -self.mll(train_dist, self.gp.train_targets).sum()
            loss.backward()
            self.optimizer.step()
            self.gp.zero_grad()

        # update GP training data again
        if update_stem:
            with torch.no_grad():
                self._refresh_features(*self._raw_inputs)

        self.eval()
        stem_loss = gp_loss = loss.item() if update_gp else 0.
        return stem_loss, gp_loss

    def fit(self, inputs, targets, num_epochs, test_dataset=None):
        records = []
        self.gp.train_targets = self._reshape_targets(targets)
        lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            self.optimizer, num_epochs, 1e-4)
        for epoch in range(num_epochs):
            self.train()
            self.mll.train()
            self.optimizer.zero_grad()
            self._refresh_features(inputs)
            train_dist = self.gp(*self.gp.train_inputs)
            with gpytorch.settings.skip_logdet_forward(False):
                loss = -self.mll(train_dist, self.gp.train_targets).sum()
            loss.backward()
            self.optimizer.step()
            lr_scheduler.step()
            self.gp.zero_grad()

            rmse = nll = float('NaN')
            if test_dataset is not None:
                test_x, test_y = test_dataset[:]
                rmse, nll = self.evaluate(test_x, test_y)
            records.append({
                'train_loss': loss.item(),
                'test_rmse': rmse,
                'test_nll': nll,
                'noise': self.gp.likelihood.noise.mean().item(),
                'epoch': epoch + 1
            })

        with torch.no_grad():
            self._refresh_features(inputs)

        self.eval()
        return records

    def forward(self, inputs):
        inputs = inputs.view(-1, self.stem.input_dim)
        features = self.stem(inputs)
        return self.gp(features)

    def predict(self, inputs):
        self.eval()
        pred_dist = self(inputs)
        pred_dist = self.gp.likelihood(pred_dist)
        return pred_dist.mean, pred_dist.variance

    def evaluate(self, inputs, targets):
        inputs = inputs.view(-1, self.stem.input_dim)
        targets = targets.view(-1, self.target_dim)
        with torch.no_grad():
            return regression.evaluate(self, inputs, targets)

    def set_train_data(self, inputs, targets, strict):
        inputs = inputs.expand(*self._target_batch_shape, -1, -1)
        if self.target_dim == 1:
            targets = targets.squeeze(0)
        self.gp.set_train_data(inputs, targets, strict)

    def _reshape_targets(self, targets):
        targets = targets.view(-1, self.target_dim)
        if targets.size(-1) == 1:
            targets = targets.squeeze(-1)
        else:
            targets = targets.t()
        return targets

    def _refresh_features(self, inputs, strict=True):
        features = self.stem(inputs)
        self.set_train_data(features, self.gp.train_targets, strict)
        return features

    def set_lr(self, gp_lr, stem_lr=None, bn_mom=None):
        stem_lr = gp_lr if stem_lr is None else stem_lr
        self.optimizer = torch.optim.Adam([
            dict(params=self.gp.parameters(), lr=gp_lr),
            dict(params=self.stem.parameters(), lr=stem_lr)
        ])
        if bn_mom is not None:
            for m in self.stem.modules():
                if isinstance(m, torch.nn.BatchNorm1d):
                    m.momentum = bn_mom

    @property
    def noise(self):
        return self.gp.likelihood.noise
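
A usage sketch with a hypothetical `MLPStem`; the class expects the stem to expose `input_dim` and `output_dim` attributes, so the sketch wraps a small MLP accordingly (toy data only):

class MLPStem(torch.nn.Module):
    """Hypothetical feature extractor exposing the attributes the GP head expects."""

    def __init__(self, input_dim, output_dim):
        super().__init__()
        self.input_dim, self.output_dim = input_dim, output_dim
        self.net = torch.nn.Sequential(
            torch.nn.Linear(input_dim, 32), torch.nn.ReLU(),
            torch.nn.Linear(32, output_dim))

    def forward(self, x):
        return self.net(x)


init_x, init_y = torch.randn(20, 4), torch.randn(20, 1)
model = OnlineExactRegression(MLPStem(4, 8), init_x, init_y, lr=1e-2)
model.fit(init_x, init_y, num_epochs=50)

# online update with a fresh batch of observations
new_x, new_y = torch.randn(5, 4), torch.randn(5, 1)
model.update(new_x, new_y)
mean, var = model.predict(torch.randn(3, 4))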