# Example No. 1 (score: 0)
import stheno
import torch
import wbml.plot

from convcnp.architectures import SimpleConv
from convcnp.set_conv import ConvCNP

# Run on the GPU when one is available, otherwise fall back to the CPU.
device = (torch.device('cuda') if torch.cuda.is_available()
          else torch.device('cpu'))


def to_numpy(x):
    """Return tensor ``x`` as a NumPy array with all size-1 dimensions dropped.

    The tensor is detached from the autograd graph and copied to host memory
    before conversion, because ``Tensor.numpy()`` only accepts CPU tensors
    that do not require grad.
    """
    squeezed = x.squeeze()
    on_host = squeezed.detach().cpu()
    return on_host.numpy()


# Build the ConvCNP model and restore the best checkpoint from training.
convcnp = ConvCNP(learn_length_scale=True,
                  points_per_unit=64,
                  architecture=SimpleConv())
convcnp.to(device)
# ``map_location`` remaps tensors that were saved from a CUDA device onto
# ``device``; without it, a GPU-saved checkpoint cannot be loaded on the CPU
# fallback selected above.
load_dict = torch.load('./saved_models/convcnp-matern/model_best.pth.tar',
                       map_location=device)
convcnp.load_state_dict(load_dict['state_dict'])
convcnp.eval()  # Switch the model to evaluation mode for inference.

# Construct a GP prior with a Matern-5/2 kernel stretched by 0.25.
kernel = stheno.Matern52().stretch(0.25)
gp = stheno.GP(kernel)

# Sample a function from the GP on a uniform grid of ``num_points`` inputs in
# [-2, 2], and draw a random permutation of the grid indices.
num_points = 200
rand_indices = torch.randperm(num_points)
x_all = torch.linspace(-2., 2., num_points)
y_all = gp(x_all).sample()
# Example No. 2 (score: 0)
        kernel = stheno.EQ().stretch(1.) + \
                 stheno.EQ().stretch(.25) + \
                 0.001 * stheno.Delta()
    elif args.data == 'weakly-periodic':
        kernel = stheno.EQ().stretch(0.5) * stheno.EQ().periodic(period=0.25)
    else:
        raise ValueError(f'Unknown data "{args.data}".')

    gen = convcnp.data.GPGenerator(kernel=kernel)
    gen_val = convcnp.data.GPGenerator(kernel=kernel, num_tasks=60)
    gen_test = convcnp.data.GPGenerator(kernel=kernel, num_tasks=2048)

# Load model.
# The two ConvCNP variants share every hyperparameter and differ only in the
# convolutional architecture; CNP and ANP are both built with latent_dim=128.
if args.model == 'convcnp' or args.model == 'convcnpxl':
    model = ConvCNP(learn_length_scale=True,
                    points_per_unit=64,
                    architecture=(SimpleConv() if args.model == 'convcnp'
                                  else UNet()))
elif args.model == 'cnp' or args.model == 'anp':
    model = (CNP if args.model == 'cnp' else ANP)(latent_dim=128)
else:
    raise ValueError(f'Unknown model {args.model}.')

# Move the selected model onto the chosen device.
model.to(device)

# Perform training.