def parse_kernel(
    kernel_type,
    x_dim,
    rkhs_dim,
    length_scale,
    sigma_var,
    learnable_length_scale=False,
):
    """Build a reparametrised kernel from its short string name.

    Args:
        kernel_type: One of ``"rbf"``, ``"divfree"`` or ``"curlfree"``.
        x_dim: Dimension of the input space.
        rkhs_dim: Dimension of the kernel's output (RKHS) space. Passed to
            ``SeparableKernel`` for ``"rbf"``; the divergence-/curl-free
            kernels require it (and ``x_dim``) to equal 2.
        length_scale: Kernel length scale. Its logarithm is what the kernel
            stores, so it should be positive.
        sigma_var: Forwarded unchanged to the kernel constructor.
        learnable_length_scale: If true, the log length scale is wrapped in
            ``nn.Parameter`` so it can be optimised; otherwise it is a plain
            tensor.

    Returns:
        The constructed kernel object.

    Raises:
        ValueError: If ``kernel_type`` is not recognised, or if a
            divergence-/curl-free kernel is requested with ``x_dim`` or
            ``rkhs_dim`` different from 2.
    """
    if kernel_type not in ("rbf", "divfree", "curlfree"):
        raise ValueError(f"{kernel_type} is not a recognised kernel type to use.")

    # Divergence-/curl-free kernels are only supported for 2-D inputs and
    # outputs. Use `or`, not `&`: `&` binds tighter than `!=`, so
    # `(a != b) & a != 2` would parse as `((a != b) & a) != 2`.
    if kernel_type in ("divfree", "curlfree") and (rkhs_dim != 2 or x_dim != 2):
        raise ValueError(
            f"RKHS and X dim for {kernel_type} must be 2. Given {rkhs_dim=}, {x_dim=}."
        )

    # All three kernels take the length scale in log space. nn.Parameter
    # marks it as trainable (it is registered as a parameter when the kernel
    # is part of an nn.Module — assumed here, confirm in steer_cnp.kernel).
    log_length_scale = torch.tensor(length_scale).log()
    if learnable_length_scale:
        log_length_scale = nn.Parameter(log_length_scale)

    if kernel_type == "rbf":
        # Lift the scalar RBF kernel to rkhs_dim-dimensional outputs.
        return SeparableKernel(
            x_dim,
            rkhs_dim,
            RBFKernelReparametrised(
                x_dim,
                log_length_scale=log_length_scale,
                sigma_var=sigma_var,
            ),
        )
    if kernel_type == "divfree":
        return RBFDivergenceFreeKernelReparametrised(
            x_dim,
            log_length_scale=log_length_scale,
            sigma_var=sigma_var,
        )
    return RBFCurlFreeKernelReparametrised(
        x_dim,
        log_length_scale=log_length_scale,
        sigma_var=sigma_var,
    )
def get_kernel(self):
    """Construct the 2-D kernel selected by ``self.kernel_type``.

    Both the input and the output dimension are fixed to 2. The length
    scale and variance come from ``self.length_scale`` and
    ``self.sigma_var``.

    Raises:
        ValueError: if ``self.kernel_type`` is not one of ``"rbf"``,
            ``"divfree"`` or ``"curlfree"``.
    """
    requested = self.kernel_type

    if requested == "rbf":
        scalar_rbf = RBFKernel(2, length_scale=self.length_scale, sigma_var=self.sigma_var)
        return SeparableKernel(2, 2, scalar_rbf)

    if requested == "divfree":
        return RBFDivergenceFreeKernel(
            2, length_scale=self.length_scale, sigma_var=self.sigma_var
        )

    if requested == "curlfree":
        return RBFCurlFreeKernel(
            2, length_scale=self.length_scale, sigma_var=self.sigma_var
        )

    raise ValueError(f"{requested} is not a recognised kernel type to use.")
# Exploratory notebook-style script: build a separable RBF kernel on 3-D
# inputs, lay out a regular 3-D grid of points and draw a GP prior sample on it.
import torch
from steer_cnp.kernel import (
    RBFKernel,
    SeparableKernel,
    DotProductKernel,
    RBFDivergenceFreeKernel,
    RBFCurlFreeKernel,
    kernel_smooth,
)
from steer_cnp.gp import sample_gp_prior, conditional_gp_posterior
from steer_cnp.utils import sample_gp_grid_2d, plot_inference
import matplotlib.pyplot as plt

# %%
# 3-D inputs, 3-D outputs; RBFKernel's second positional argument is
# presumably the length scale (3.0) — confirm against steer_cnp.kernel.
rbf = SeparableKernel(3, 3, RBFKernel(3, 3.0))
# NOTE(review): the 3-D divergence-/curl-free kernels are disabled here —
# confirm whether those kernels support 3-D inputs.
# div = RBFDivergenceFreeKernel(3, 3.0)
# curl = RBFCurlFreeKernel(3, 3.0)

# %%
# 16 points per axis over [-4, 4) with spacing 0.5 -> 16**3 = 4096 grid points.
x = torch.arange(-4, 4, step=0.5)
# NOTE: without `indexing=`, torch.meshgrid uses "ij" ordering and newer
# torch versions emit a UserWarning about the missing argument.
x1, x2, x3 = torch.meshgrid(x, x, x)
x1 = x1.flatten()
x2 = x2.flatten()
x3 = x3.flatten()
# Stack the flattened coordinates into a (4096, 3) tensor of grid points.
X_grid = torch.stack([x1, x2, x3], dim=-1)

# %%
# Draw a GP prior sample at the grid points using the separable RBF kernel.
Y_rbf = sample_gp_prior(X_grid, rbf)
# plt.quiver(
# Exploratory notebook-style script: build 2-D separable RBF,
# divergence-free and curl-free kernels, lay out a 2-D grid and plot a GP
# prior sample as a vector field. (This chunk ends inside the quiver call.)
import torch
from steer_cnp.kernel import (
    RBFKernel,
    SeparableKernel,
    DotProductKernel,
    RBFDivergenceFreeKernel,
    RBFCurlFreeKernel,
    kernel_smooth,
)
from steer_cnp.gp import sample_gp_prior, conditional_gp_posterior
from steer_cnp.utils import sample_gp_grid_2d, plot_inference
import matplotlib.pyplot as plt

# %%
# All kernels act on 2-D inputs; the second positional argument is
# presumably the length scale (3.0) — confirm against steer_cnp.kernel.
rbf = SeparableKernel(2, 2, RBFKernel(2, 3.0))
div = RBFDivergenceFreeKernel(2, 3.0)
curl = RBFCurlFreeKernel(2, 3.0)

# %%
# 16 points per axis over [-4, 4) with spacing 0.5 -> 16**2 = 256 grid points.
x = torch.arange(-4, 4, step=0.5)
# NOTE: without `indexing=`, torch.meshgrid uses "ij" ordering and newer
# torch versions emit a UserWarning about the missing argument.
x1, x2 = torch.meshgrid(x, x)
x1 = x1.flatten()
x2 = x2.flatten()
# (256, 2) tensor of grid points.
X_grid = torch.stack([x1, x2], dim=-1)

# %%
# Draw a GP prior sample at the grid points and draw it as arrows
# (assumes Y_rbf holds one 2-D vector per grid point — confirm).
Y_rbf = sample_gp_prior(X_grid, rbf)
plt.quiver(
    X_grid[:, 0],
# Notebook-style setup of an EquivDeepSet model (this chunk ends inside the
# EquivDeepSet(...) call). SeparableKernel and RBFKernel are used below but
# imported outside this chunk.
from steer_cnp.equiv_deepsets import EquivDeepSet
from steer_cnp.utils import (
    get_e2_decoder,
    grid_2d,
    plot_vector_field,
    plot_inference,
    plot_embedding,
)
import e2cnn.nn as gnn
import matplotlib.pyplot as plt

# %%
# Embedding kernel: 2-D inputs, 3-D output space (argument order presumably
# (input dim, output dim, base kernel) — confirm against SeparableKernel).
embedding_kernel = SeparableKernel(2, 3, RBFKernel(2, 1.0))
# Discretisation grid passed to EquivDeepSet: coordinate range and
# presumably the number of points per axis — confirm.
grid_ranges = [-4.0, 4.0]
n_axes = 20
normalise = True
# Decoder CNN; the meaning of the positional arguments is defined by
# steer_cnp.utils.get_e2_decoder (presumably an E(2)-equivariant CNN).
cnn = get_e2_decoder(4, False, "regular_small", [1], [1], activation="normrelu")
# NOTE(review): output space of dimension 6 — confirm what the six output
# channels represent.
output_kernel = SeparableKernel(2, 6, RBFKernel(2, 1.0))
# NOTE(review): `dim` is not used in the visible part of this script.
dim = 2
deepset = EquivDeepSet(
    grid_ranges=grid_ranges,
    n_axes=n_axes,
    embedding_kernel=embedding_kernel,
    normalise_embedding=normalise,
x, y, n_context = m[np.random.randint(1000)] # %% # plt.scatter(m0[0][:, 0], m0[0][:, 1], c=m0[1]) img = points_to_partial(28, x[:n_context].numpy(), y[:n_context].numpy()) plt.imshow(img) # %% kernel = SeparableKernel( 2, 2, RBFKernelReparametrised( 2, log_length_scale=torch.tensor(3.0).log(), sigma_var=1.0, ), ) embedder = DiscretisedRKHSEmbedding([0, 27], 28, dim=2, kernel=kernel, normalise=True) # %% grid, Y_target = embedder(x[:n_context].unsqueeze(0), y[:n_context].unsqueeze(0)) grid = grid.squeeze(0).numpy().astype(int) Y_target = Y_target.squeeze(0).numpy() img = points_to_img(28, grid, Y_target[:, 0]) plt.imshow(img) plt.colorbar() # %%
# (Tail of a steer_cnp.kernel import list that begins before this chunk.)
    RBFCurlFreeKernel,
    kernel_smooth,
)

# Notebook-style shape checks for the kernels on batched inputs. The bare
# `.shape` expressions only display something in an interactive cell; run as
# a plain script they have no effect.

# %%
rbf = RBFKernel(3, 1.0)

# %%
# X has shape (10, 5, 3) and Y has shape (10, 4, 3); the leading axis is
# presumably a batch axis and the last the 3-D input dimension — confirm.
X = torch.randn((10, 5, 3))
Y = torch.randn((10, 4, 3))
rbf(X, Y).shape

# %%
# Separable kernel: 3-D inputs, 2-D output space, built on the RBF above.
sep = SeparableKernel(3, 2, rbf)

# %%
sep(X, Y, flatten=False).shape

# %%
dp = DotProductKernel(3)

# %%
dp(X, Y).shape

# %%

# %%