Ejemplo n.º 1
0
 def __init__(self, kernel=stheno.EQ(), **kw_args):
     """Create the generator's GP, then run the base-class initialiser.

     Args:
         kernel: Object handed to `stheno.GP` as its kernel. The default
             `stheno.EQ()` is evaluated once, when this method is defined,
             so every instance relying on the default gets that same object.
         **kw_args: Forwarded unchanged to `DataGenerator.__init__`.
     """
     # The GP gets its own freshly created graph.
     gp_graph = stheno.Graph()
     # `self.gp` is assigned before the base initialiser is invoked.
     self.gp = stheno.GP(kernel, graph=gp_graph)
     DataGenerator.__init__(self, **kw_args)
Ejemplo n.º 2
0
    Returns:
        number: `x` rounded to the nearest above multiple of `multiple`.
    """
    if x % multiple == 0:
        return x
    else:
        return x + multiple - x % multiple


# Script entry point: builds one example task from a multi-class GP generator
# and defines a set of configuration values.
if __name__ == "__main__":

    # Example data. The imports are local to the script entry point, so they
    # are only resolved when this file is executed directly.
    from data.GP.GP_data_generator import MultiClassGPGenerator
    import stheno
    # Two kernels (an EQ kernel and an EQ kernel made periodic via
    # `.periodic(1)`) with one display name each.
    # NOTE(review): the meaning of the positional `0.5` is not visible here —
    # confirm against MultiClassGPGenerator's signature.
    train_data = MultiClassGPGenerator(
        [stheno.EQ(), stheno.EQ().periodic(1)],
        0.5,
        kernel_names=["EQ", "Periodic"],
        batch_size=64,
        num_tasks=10)
    # `generate_task` returns a pair; the first element is indexed by the
    # string keys used below.
    task, label = train_data.generate_task()
    x_context = task["x_context"]
    y_context = task["y_context"]
    x_target = task["x"]
    y_target = task["y"]

    # Configuration values. They are not consumed in the visible part of this
    # script; presumably they configure a model built further down — TODO
    # confirm.
    learn_length_scale = True
    points_per_unit = 10
    type_CNN = "UNet"
    num_input_channels = 1
    num_output_channels = 2
Ejemplo n.º 3
0
import stheno
import numpy as np
import matplotlib.pyplot as plt

# Script entry point: draw a few samples from each GP and save one figure
# per GP next to `dir_plot`, with the GP's index inserted before the extension.
if __name__ == "__main__":
    # Local import: only this script entry point needs `os`.
    import os

    kernels = [stheno.EQ().stretch(1) * 0.1, stheno.EQ().stretch(0.1) * 1]
    gps = [stheno.GP(kernel) for kernel in kernels]
    # Not referenced in the visible part of this script; kept in case code
    # outside this view relies on it.
    labels = [""]
    num_samples = 5
    step = 0.01
    x_min, x_max = -2, 2
    # `np.arange` with a float step and a `stop` of `x_max + step` has a
    # length that depends on floating-point rounding. `np.linspace` with an
    # explicit point count always includes both end points.
    num_points = int(round((x_max - x_min) / step)) + 1
    x = np.linspace(x_min, x_max, num_points)

    dir_plot = "../figures/write_up/gps/gp.svg"
    # Split once so the index can be inserted before the extension, whatever
    # the extension's length.
    plot_stem, plot_ext = os.path.splitext(dir_plot)
    for i, gp in enumerate(gps):
        for _ in range(num_samples):
            y = gp(x).sample()
            plt.plot(x, y, alpha=0.5, color="b")
        # Axis settings do not depend on the sample, so apply them once per
        # figure instead of once per plotted line.
        plt.ylim([-4, 4])
        plt.xlabel("x", fontsize=15)
        plt.ylabel("y", fontsize=15)
        dir_plot_local = plot_stem + str(i) + plot_ext
        plt.savefig(dir_plot_local)
        # Close the figure so the next GP starts from a clean canvas.
        plt.close()