Exemplo n.º 1
0
def parse_kernel(
    kernel_type,
    x_dim,
    rkhs_dim,
    length_scale,
    sigma_var,
    learnable_length_scale=False,
):
    if kernel_type == "rbf":
        kernel = SeparableKernel(
            x_dim,
            rkhs_dim,
            RBFKernelReparametrised(
                x_dim,
                log_length_scale=nn.Parameter(
                    torch.tensor(length_scale).log()) if learnable_length_scale
                else torch.tensor(length_scale).log(),
                sigma_var=sigma_var,
            ),
        )
    elif kernel_type == "divfree":
        if (rkhs_dim != x_dim) & rkhs_dim != 2:
            raise ValueError(
                f"RKHS and X dim for {kernel_type} must be 2. Given {rkhs_dim=}, {x_dim=}."
            )
        kernel = RBFDivergenceFreeKernelReparametrised(
            x_dim,
            log_length_scale=nn.Parameter(torch.tensor(length_scale).log())
            if learnable_length_scale else torch.tensor(length_scale).log(),
            sigma_var=sigma_var,
        )
    elif kernel_type == "curlfree":
        if (rkhs_dim != x_dim) & rkhs_dim != 2:
            raise ValueError(
                f"RKHS and X dim for {kernel_type} must be 2. Given {rkhs_dim=}, {x_dim=}."
            )
        kernel = RBFCurlFreeKernelReparametrised(
            x_dim,
            log_length_scale=nn.Parameter(torch.tensor(length_scale).log())
            if learnable_length_scale else torch.tensor(length_scale).log(),
            sigma_var=sigma_var,
        )
    else:
        raise ValueError(
            f"{kernel_type} is not a recognised kernel type to use.")

    return kernel
Exemplo n.º 2
0
    def get_kernel(self):
        """Instantiate the fixed 2-D kernel selected by ``self.kernel_type``.

        Reads ``self.kernel_type`` and, for a recognised type, also
        ``self.length_scale`` and ``self.sigma_var``. The dimension
        arguments of every kernel are hard-coded to 2.

        Returns:
            A separable RBF, divergence-free RBF or curl-free RBF kernel.

        Raises:
            ValueError: If ``self.kernel_type`` is not ``"rbf"``,
                ``"divfree"`` or ``"curlfree"``.
        """
        kind = self.kernel_type

        if kind == "rbf":
            scalar_rbf = RBFKernel(2,
                                   length_scale=self.length_scale,
                                   sigma_var=self.sigma_var)
            return SeparableKernel(2, 2, scalar_rbf)

        if kind == "divfree":
            return RBFDivergenceFreeKernel(2,
                                           length_scale=self.length_scale,
                                           sigma_var=self.sigma_var)

        if kind == "curlfree":
            return RBFCurlFreeKernel(2,
                                     length_scale=self.length_scale,
                                     sigma_var=self.sigma_var)

        raise ValueError(f"{kind} is not a recognised kernel type to use.")
Exemplo n.º 3
0
import torch

from steer_cnp.kernel import (
    RBFKernel,
    SeparableKernel,
    DotProductKernel,
    RBFDivergenceFreeKernel,
    RBFCurlFreeKernel,
    kernel_smooth,
)
from steer_cnp.gp import sample_gp_prior, conditional_gp_posterior
from steer_cnp.utils import sample_gp_grid_2d, plot_inference
import matplotlib.pyplot as plt

# %%
# Separable RBF kernel with 3 as both dimension arguments and 3.0 as the
# scalar kernel's second argument (presumably its length scale).
rbf = SeparableKernel(3, 3, RBFKernel(3, 3.0))
# Divergence-/curl-free variants are left disabled here
# (NOTE(review): presumably because they only support 2-D — confirm).
# div = RBFDivergenceFreeKernel(3, 3.0)
# curl = RBFCurlFreeKernel(3, 3.0)

# %%
# Regular grid: 16 points per axis on [-4, 4) with step 0.5, so 16**3 = 4096
# points, flattened and stacked into an (4096, 3) tensor.
x = torch.arange(-4, 4, step=0.5)
x1, x2, x3 = (axis.flatten() for axis in torch.meshgrid(x, x, x))
X_grid = torch.stack([x1, x2, x3], dim=-1)

# %%
# Sample the GP prior under the separable RBF kernel at every grid point.
Y_rbf = sample_gp_prior(X_grid, rbf)
# plt.quiver(
Exemplo n.º 4
0
import torch

from steer_cnp.kernel import (
    RBFKernel,
    SeparableKernel,
    DotProductKernel,
    RBFDivergenceFreeKernel,
    RBFCurlFreeKernel,
    kernel_smooth,
)
from steer_cnp.gp import sample_gp_prior, conditional_gp_posterior
from steer_cnp.utils import sample_gp_grid_2d, plot_inference
import matplotlib.pyplot as plt

# %%
# Three kernels on 2-D inputs, each built with 3.0 as the second positional
# argument (presumably the length scale).
rbf = SeparableKernel(2, 2, RBFKernel(2, 3.0))
div = RBFDivergenceFreeKernel(2, 3.0)
curl = RBFCurlFreeKernel(2, 3.0)

# %%
# Regular grid: 16 points per axis on [-4, 4) with step 0.5, flattened and
# stacked into a (256, 2) tensor of grid coordinates.
x = torch.arange(-4, 4, step=0.5)
x1, x2 = (axis.flatten() for axis in torch.meshgrid(x, x))
X_grid = torch.stack([x1, x2], dim=-1)

# %%
# Sample the GP prior under the separable RBF kernel at every grid point.
Y_rbf = sample_gp_prior(X_grid, rbf)
plt.quiver(
    X_grid[:, 0],
Exemplo n.º 5
0
from steer_cnp.equiv_deepsets import EquivDeepSet
from steer_cnp.utils import (
    get_e2_decoder,
    grid_2d,
    plot_vector_field,
    plot_inference,
    plot_embedding,
)

import e2cnn.nn as gnn

import matplotlib.pyplot as plt

# %%

# Plain configuration values; grid_ranges, n_axes and normalise are passed to
# EquivDeepSet in the next statement.
grid_ranges = [-4.0, 4.0]
n_axes = 20
normalise = True
dim = 2

# Model components, constructed in the same relative order as before.
# NOTE(review): SeparableKernel / RBFKernel are not in the visible import
# block of this snippet — confirm they come from steer_cnp.kernel.
embedding_kernel = SeparableKernel(2, 3, RBFKernel(2, 1.0))
cnn = get_e2_decoder(
    4, False, "regular_small", [1], [1], activation="normrelu"
)
output_kernel = SeparableKernel(2, 6, RBFKernel(2, 1.0))

deepset = EquivDeepSet(
    grid_ranges=grid_ranges,
    n_axes=n_axes,
    embedding_kernel=embedding_kernel,
    normalise_embedding=normalise,
Exemplo n.º 6
0
# Pick one random item from `m` (defined outside this snippet); each item
# unpacks into inputs, targets and a context-set size.
# NOTE(review): np.random.randint(1000) draws from [0, 1000), so this assumes
# `m` has at least 1000 items — confirm.
x, y, n_context = m[np.random.randint(1000)]

# %%

# plt.scatter(m0[0][:, 0], m0[0][:, 1], c=m0[1])

# Render only the first n_context points onto a 28 x 28 image and show it.
img = points_to_partial(28, x[:n_context].numpy(), y[:n_context].numpy())

plt.imshow(img)
# %%

# Separable RBF kernel with both dimension arguments set to 2. The length
# scale 3.0 is passed as a plain log-space tensor (not an nn.Parameter), so
# it is fixed rather than learnable.
kernel = SeparableKernel(
    2,
    2,
    RBFKernelReparametrised(
        2,
        log_length_scale=torch.tensor(3.0).log(),
        sigma_var=1.0,
    ),
)
# Discretised embedding over range [0, 27] with 28 as the axis argument
# (presumably one grid point per pixel of a 28 x 28 image — confirm).
embedder = DiscretisedRKHSEmbedding([0, 27], 28, dim=2, kernel=kernel, normalise=True)

# %%

# Embed the context set; unsqueeze(0) adds a leading batch dimension of 1.
grid, Y_target = embedder(x[:n_context].unsqueeze(0), y[:n_context].unsqueeze(0))
# Drop that batch dimension again and cast grid coordinates to ints so they
# can be used as pixel positions.
grid = grid.squeeze(0).numpy().astype(int)
Y_target = Y_target.squeeze(0).numpy()
# Plot only the first output channel of the embedding.
img = points_to_img(28, grid, Y_target[:, 0])
plt.imshow(img)
plt.colorbar()
# %%
Exemplo n.º 7
0
    RBFCurlFreeKernel,
    kernel_smooth,
)

# %%

# Scalar RBF kernel. The first argument 3 matches the last dimension of X and
# Y below (presumably the input dimension); 1.0 is presumably the length scale.
rbf = RBFKernel(3, 1.0)

# %%
# Random inputs with shapes (10, 5, 3) and (10, 4, 3); presumably a batch of
# 10 sets of 5 and 4 points in 3-D — confirm against the kernel's contract.
X = torch.randn((10, 5, 3))
Y = torch.randn((10, 4, 3))

# Bare expression: the shape is only displayed when run as an interactive cell.
rbf(X, Y).shape
# %%

# Separable kernel wrapping the scalar RBF; arguments are 3 (input dim) and
# 2 (output dim), following the (x_dim, rkhs_dim, scalar_kernel) order used
# elsewhere with SeparableKernel.
sep = SeparableKernel(3, 2, rbf)

# %%

# flatten=False presumably keeps the output-dimension axes separate rather
# than folding them into the point axes — confirm.
sep(X, Y, flatten=False).shape

# %%

dp = DotProductKernel(3)

# %%

dp(X, Y).shape
# %%

# %%