Example 1
    def train(self, data_loader, epochs, early_stopping=None):
        """Training procedure for DKL on arbitrary function"""
        print('std')
        for epoch in range(epochs):
            epoch_loss = 0.
            for i, data in enumerate(data_loader):
                t = time.time()
                # Zero backprop gradients
                self.optimizer.zero_grad()
                # Get output from model
                x, y = data  # add , num_points

                x_context, y_context, x_target, y_target = context_target_split(x[0:1], y[0:1],
                                                                                self.num_context,
                                                                                self.num_target)
                self.model.set_train_data(inputs=x_context, targets=y_context.view(-1), strict=False)
                #with gpytorch.settings.use_toeplitz(False), gpytorch.settings.fast_pred_var(False):
                output = self.model(x_context)
                if torch.isnan(output.stddev).any():
                    print('nan at epoch ', epoch)
                    continue
                # Calc loss and backprop derivatives
                loss = -self.mll(output, y_context.view(-1))  # .sum()
                loss.backward()
                self.optimizer.step()
                epoch_loss += loss.item()
                avg_loss = epoch_loss / len(data_loader)
            if epoch % self.print_freq == 0 or epoch == epochs - 1:
                print("Epoch: {}, Avg_loss: {}".format(epoch, avg_loss))

            self.epoch_loss_history.append(avg_loss)

            if early_stopping is not None:
                if avg_loss < early_stopping:
                    break
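The helper context_target_split is not shown in these examples; a minimal sketch consistent with how it is called here (batched tensors of shape (batch, num_points, dim), disjoint random context and target subsets, as the "divide context (N-1) and target (1)" comments suggest) could look like the following. This is an assumption, not the repository's actual implementation.

import torch

def context_target_split(x, y, num_context, num_target):
    """Randomly split points into disjoint context and target sets (sketch).

    x, y: tensors of shape (batch, num_points, dim).
    Returns x_context, y_context, x_target, y_target.
    """
    num_points = x.shape[1]
    perm = torch.randperm(num_points)
    context_idx = perm[:num_context]
    target_idx = perm[num_context:num_context + num_target]
    return (x[:, context_idx, :], y[:, context_idx, :],
            x[:, target_idx, :], y[:, target_idx, :])

With num_context=0 this simply returns empty context tensors and num_target target points, which is how the leave-one-out trainers further below call it.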
Example 2
    def train_rl(self, data_loader, epochs, early_stopping=None):
        """
        Train the MKI model as part of IMeL.
        """
        for epoch in range(epochs):
            epoch_loss = 0.
            for i, data in enumerate(data_loader):
                # Zero backprop gradients
                self.optimizer.zero_grad()
                # Get output from model
                x, y, num_points = data
                # divide context (N-1) and target (1)
                x_context, y_context, x_target, y_target = context_target_split(x[:,:num_points,:], y[:,:num_points,:],
                                                                  num_points.item()-1, 1)
                prediction = self.model(x_context, y_context, x_target)
                if torch.isnan(prediction).any():
                    prediction = self.model(x_context, y_context, x_target)
                loss = self._loss(y_target, prediction)
                loss.backward()
                self.optimizer.step()
                epoch_loss += loss.item()
                avg_loss = epoch_loss / len(data_loader)

            if epoch % self.print_freq == 0 or epoch == epochs-1:
                print("Epoch: {}, Avg_loss: {}, W_sum: {}".format(epoch, avg_loss, self.model.interpolator.W.sum().item()))

            self.epoch_loss_history.append(avg_loss)
            if early_stopping is not None:
                if avg_loss < early_stopping:
                    break
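The train_rl* methods expect each batch to unpack into (x, y, num_points), where x and y are padded episode tensors and num_points counts the valid steps. A hypothetical toy loader with that layout; all names, shapes and the padding rule are illustrative only.

import torch
from torch.utils.data import DataLoader, Dataset

class EpisodeDataset(Dataset):
    """Toy stand-in for the episodic buffer the trainers iterate over."""
    def __init__(self, num_episodes=8, max_len=100, x_dim=3, y_dim=1):
        self.x = torch.randn(num_episodes, max_len, x_dim)
        self.y = torch.randn(num_episodes, max_len, y_dim)
        # hypothetical per-episode length, at most max_len
        self.lengths = torch.randint(10, max_len + 1, (num_episodes,))

    def __len__(self):
        return self.x.shape[0]

    def __getitem__(self, idx):
        return self.x[idx], self.y[idx], self.lengths[idx]

loader = DataLoader(EpisodeDataset(), batch_size=1, shuffle=True)
x, y, num_points = next(iter(loader))  # x: (1, 100, 3), num_points: (1,)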
Example 3
    def train_rl_ctx(self, data_loader, epochs, early_stopping=None):
        """Training module for DKL as part of IMeL. Condition GP on context set """

        for epoch in range(epochs):
            epoch_loss = 0.
            for i, data in enumerate(data_loader):
                t = time.time()
                # Zero backprop gradients
                self.optimizer.zero_grad()
                # Get output from model
                x, y, num_points = data
                # divide context (N-1) and target (1)
                x_context, y_context, _, _ = context_target_split(x[:,:num_points,:], y[:,:num_points,:],
                                                                  num_points.item()-1, 1)

                self.model.set_train_data(inputs=x_context, targets=y_context.view(-1), strict=False)
                predictions = self.model(x_context)
                loss = -self.mll(predictions, y_context.view(-1))
                loss.backward()
                self.optimizer.step()
                epoch_loss += loss.item()
                avg_loss = epoch_loss / len(data_loader)
            print('epoch %d - Loss: %.3f   lengthscale: %.9f  outputscale: %.3f   noise: %.3f' % (epoch, loss.item(),
                self.model.covar_module.base_kernel.base_kernel.lengthscale.item(),
                self.model.covar_module.base_kernel.outputscale.item(),
                self.model.likelihood.noise.item()))
            if epoch % self.print_freq == 0 or epoch == epochs-1:
                print("Epoch: {}, Avg_loss: {}".format(epoch, avg_loss))
                #plot_posterior_2d(data_loader, self.model, 'training '+str(epoch), self.args)

            self.epoch_loss_history.append(avg_loss)
            if early_stopping is not None:
                if avg_loss < early_stopping:
                    break
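The trainers above rely on self.mll and self.optimizer being set up elsewhere. For reference, here is a self-contained GPyTorch sketch of the same pattern (an exact GP trained by minimizing the negative marginal log likelihood); the kernel, toy data and learning rate are placeholders, not the repository's actual configuration.

import math
import torch
import gpytorch

class ToyExactGP(gpytorch.models.ExactGP):
    def __init__(self, train_x, train_y, likelihood):
        super().__init__(train_x, train_y, likelihood)
        self.mean_module = gpytorch.means.ConstantMean()
        self.covar_module = gpytorch.kernels.ScaleKernel(gpytorch.kernels.RBFKernel())

    def forward(self, x):
        return gpytorch.distributions.MultivariateNormal(
            self.mean_module(x), self.covar_module(x))

train_x = torch.linspace(0, 1, 50)
train_y = torch.sin(2 * math.pi * train_x) + 0.1 * torch.randn(50)

likelihood = gpytorch.likelihoods.GaussianLikelihood()
model = ToyExactGP(train_x, train_y, likelihood)
mll = gpytorch.mlls.ExactMarginalLogLikelihood(likelihood, model)
optimizer = torch.optim.Adam(model.parameters(), lr=0.1)

model.train()
likelihood.train()
for _ in range(100):
    optimizer.zero_grad()
    output = model(train_x)          # latent GP over the training inputs
    loss = -mll(output, train_y)     # negative marginal log likelihood
    loss.backward()
    optimizer.step()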
    def train(self, data_loader, epochs, early_stopping=None):
        """
        Trains Neural Process.

        Parameters
        ----------
        data_loader : torch.utils.data.DataLoader instance

        epochs : int
            Number of epochs to train for.
        """

        # compute the episode-specific context sets
        one_out_list = []
        episode_fixed_list = [ep for _, ep in enumerate(data_loader)]
        for i in range(len(episode_fixed_list)):
            context_list = []
            if len(episode_fixed_list) == 1:
                context_list = [ep for ep in episode_fixed_list]
            else:
                for j, ep in enumerate(episode_fixed_list):
                    if j != i:
                        context_list.append(ep)
                #context_list = [ep for j, ep in enumerate(data_loader) if j != i]
            all_context_points = merge_context(context_list)
            one_out_list.append(all_context_points)

        for epoch in range(epochs):
            epoch_loss = 0.
            for i in range(len(data_loader)):
                self.optimizer.zero_grad()

                all_context_points = one_out_list[i]
                data = episode_fixed_list[i]
                x, y, num_points = data
                num_target = min(num_points.item(), self.num_target)
                x_context, y_context = all_context_points
                _, _, x_target, y_target = context_target_split(x[:, :num_points, :], y[:, :num_points, :],
                                                                0, num_target)
                p_y_pred, q_target, q_context = \
                    self.neural_process(x_context, y_context, x_target, y_target)
                loss = self._loss(p_y_pred, y_target, q_target, q_context)
                loss.backward()
                self.optimizer.step()

                epoch_loss += loss.item()
                self.steps += 1
            avg_loss = epoch_loss / len(data_loader)
            if epoch % self.print_freq == 0 or epoch == epochs-1:
                print("Epoch loo: {}, Avg_loss: {}".format(epoch, avg_loss))
            self.epoch_loss_history.append(avg_loss)

            if early_stopping is not None:
                if avg_loss < early_stopping:
                    break
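merge_context is assumed to concatenate the valid points of several episodes into one big context set, matching how its result is unpacked above (x_context, y_context = all_context_points). A sketch under that assumption:

import torch

def merge_context(episode_list):
    """Concatenate episodes into a single (x_context, y_context) pair (sketch).

    Each element is assumed to be an (x, y, num_points) batch from the
    episodic data loader, with x of shape (1, max_points, x_dim).
    """
    xs, ys = [], []
    for x, y, num_points in episode_list:
        n = int(num_points)
        xs.append(x[:, :n, :])
        ys.append(y[:, :n, :])
    return torch.cat(xs, dim=1), torch.cat(ys, dim=1)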
Example 5
def plot_posterior(data_loader,
                   model,
                   id,
                   args,
                   title='Posterior',
                   num_func=4):
    colors = ['r', 'b', 'g', 'y', 'b', 'g', 'y']
    for j in range(num_func):
        plt.figure(j)
        plt.xlabel('x')
        plt.ylabel('Predicted y distribution')

        x, y = data_loader.dataset.data[j]
        x = x.unsqueeze(0)
        y = y.unsqueeze(0)
        x_context, y_context, x_target, y_target = context_target_split(
            x[0:1], y[0:1], args.num_context, args.num_target)

        #plt.title(title)
        model.set_train_data(x_context.squeeze(0),
                             y_context.squeeze(0).squeeze(-1),
                             strict=False)
        model.eval()
        with torch.no_grad(), gpytorch.settings.use_toeplitz(
                False), gpytorch.settings.fast_pred_var():
            p_y_pred = model(x[0:1].squeeze(0))
        model.train()
        # Extract mean of distribution
        mu = p_y_pred.loc.detach().cpu().numpy()
        stdv = p_y_pred.stddev.detach().cpu().numpy()

        plt.plot(x[0:1].cpu().numpy()[0].squeeze(-1),
                 mu,
                 alpha=0.9,
                 c=colors[j],
                 label='Mean')
        plt.fill_between(x[0:1].cpu().numpy()[0].squeeze(-1),
                         mu - stdv,
                         mu + stdv,
                         color=colors[j],
                         alpha=0.1,
                         label='stdev')

        plt.plot(x[0].cpu().numpy(),
                 y[0].cpu().numpy(),
                 alpha=0.5,
                 c='k',
                 label='Real function')
        plt.scatter(x_context[0].cpu().numpy(),
                    y_context[0].cpu().numpy(),
                    c=colors[j],
                    label='Context points')
        plt.legend()
        plt.savefig(args.directory_path + '/' + id + str(j))
    plt.close()
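Note that model(x) inside the no_grad block returns the posterior over the latent function only. If the shaded band should also include observation noise, the usual GPyTorch idiom is to pass the output through the likelihood, e.g. replacing the prediction lines above with:

        with torch.no_grad(), gpytorch.settings.use_toeplitz(
                False), gpytorch.settings.fast_pred_var():
            p_y_pred = model.likelihood(model(x[0:1].squeeze(0)))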
Example 6
def plot_posterior_2d(data_loader, model, id, args):
    for batch in data_loader:
        break

    # Use batch to create random set of context points
    x, y = batch  # , real_len

    x_context, y_context, _, _ = context_target_split(x[0:1], y[0:1],
                                                      args.num_context // 2,
                                                      args.num_target)
    x, X1, X2, x1, x2 = create_plot_grid(args.extent,
                                         args,
                                         size=args.grid_size)

    fig = plt.figure(figsize=(20, 8))  # figsize=plt.figaspect(1.5)
    #fig.suptitle(id, fontsize=20)
    #fig.tight_layout()

    ax_real = fig.add_subplot(131, projection='3d')
    ax_real.plot_surface(X1,
                         X2,
                         y.reshape(X1.shape).cpu().numpy(),
                         cmap='viridis')
    ax_real.set_title('Real function')

    ax_context = fig.add_subplot(132, projection='3d')
    ax_context.scatter(x_context[0, :, 0].detach().cpu().numpy(),
                       x_context[0, :, 1].detach().cpu().numpy(),
                       y_context[0, :, 0].detach().cpu().numpy(),
                       c=y_context[0, :, 0].detach().cpu().numpy(),
                       cmap='viridis',
                       vmin=-1.,
                       vmax=1.,
                       s=8)

    ax_context.set_title('Context points')
    with torch.no_grad():
        mu = model(x_context, y_context, x[0:1])

    ax_mean = fig.add_subplot(133, projection='3d')
    # Extract mean of distribution
    ax_mean.plot_surface(X1,
                         X2,
                         mu.cpu().view(X1.shape).numpy(),
                         cmap='viridis')
    ax_mean.set_title('Posterior estimate')

    for ax in [ax_mean, ax_context, ax_real]:
        ax.set_xlabel('x1')
        ax.set_ylabel('x2')
        ax.set_zlabel('y')
    plt.savefig(args.directory_path + '/posterior' + id, dpi=350)
    #plt.show()
    plt.close(fig)
    return
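create_plot_grid is another helper that is not shown; a sketch matching how its outputs are used above (a flattened (1, size*size, 2) batch of grid inputs plus the meshgrid matrices X1, X2 and the 1-D axes x1, x2) might be:

import torch

def create_plot_grid(extent, args, size=30):
    """Build a regular 2-D evaluation grid over extent = (x1_min, x1_max, x2_min, x2_max)."""
    # args is accepted for signature compatibility; the real helper may read more options from it
    x1 = torch.linspace(extent[0], extent[1], size)
    x2 = torch.linspace(extent[2], extent[3], size)
    X1, X2 = torch.meshgrid(x1, x2, indexing='xy')
    x = torch.stack([X1.reshape(-1), X2.reshape(-1)], dim=-1).unsqueeze(0)
    return x, X1, X2, x1, x2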
Example 7
    def train_rl_loo(self, data_loader, epochs, early_stopping=None):
        """
        Train MKI as part of IMeL with a leave-one-out procedure.
        """
        one_out_list = []
        episode_fixed_list = [ep for _, ep in enumerate(data_loader)]
        for i in range(len(episode_fixed_list)):
            context_list = []
            if len(episode_fixed_list) == 1:
                context_list = [ep for ep in episode_fixed_list]
            else:
                for j, ep in enumerate(episode_fixed_list):
                    if j != i:
                        context_list.append(ep)
                # context_list = [ep for j, ep in enumerate(data_loader) if j != i]
            #all_context_points = merge_context(context_list)
            one_out_list.append(context_list)
        loader_len = len(data_loader)
        for epoch in range(epochs):
            epoch_loss = 0.
            for i, data in enumerate(data_loader):
                tt0 = time.time()
                # Zero backprop gradients
                self.optimizer.zero_grad()
                # Get output from model
                all_context_points = one_out_list[i]
                data = episode_fixed_list[i]
                x, y, num_points = data
                x_context, y_context = get_random_context(all_context_points, self.num_context)
                num_target = min(self.num_target, num_points.item())
                _, _, x_target, y_target = context_target_split(x[:, :num_points, :], y[:, :num_points, :], 0, num_target)
                tp0 = time.time()
                prediction = self.model(x_context, y_context, x_target)
                tp1 = time.time()
                loss = self._loss(y_target, prediction)
                loss.backward()
                tb0 = time.time()
                self.optimizer.step()
                tb1 = time.time()
                epoch_loss += loss.item()
                avg_loss = epoch_loss / loader_len
                tt1 = time.time()
            if epoch % self.print_freq == 0 or epoch == epochs-1:
                print("Epoch: {}, Avg_loss: {}".format(epoch, avg_loss))
                #print('tot-forw: ', tp0-tt1, '  forw-back:', tb0-tp1, '  back-nw:', tb1-tb0)
            self.epoch_loss_history.append(avg_loss)
            if early_stopping is not None:
                if avg_loss < early_stopping:
                    break
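get_random_context is also assumed rather than shown; given that it receives the list of leave-one-out episodes and a context size, a plausible sketch (reusing the merge_context sketch above) is:

import torch

def get_random_context(episode_list, num_context):
    """Sample a random context set from the merged points of the given episodes (sketch)."""
    x_all, y_all = merge_context(episode_list)   # see the merge_context sketch earlier
    total = x_all.shape[1]
    idx = torch.randperm(total)[:min(num_context, total)]
    return x_all[:, idx, :], y_all[:, idx, :]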
Example 8
    def train_rl(self, data_loader, epochs, early_stopping=None):
        """Training module for DKL as part of IMeL """

        for epoch in range(epochs):
            epoch_loss = 0.
            for i, data in enumerate(data_loader):
                t = time.time()
                # Zero backprop gradients
                self.optimizer.zero_grad()
                # Get output from model
                x, y, num_points = data
                # divide context (N-1) and target (1)
                x_context, y_context, x_target, y_target = context_target_split(x[:,:num_points,:], y[:,:num_points,:],
                                                                  num_points.item()-1, 1)

                self.model.set_train_data(inputs=x_context, targets=y_context.view(-1), strict=False)
                self.model.eval()
                self.model.likelihood.eval()
                with gpytorch.settings.use_toeplitz(False):
                    predictions = self.model(x_target)
                self.model.train()
                self.model.likelihood.train()

                loss = -self.mll(predictions, y_target.view(-1))
                if torch.isnan(loss):
                    print('nan loss at epoch ', epoch)
                    continue
                loss.backward()
                self.optimizer.step()
                epoch_loss += loss.item()
                avg_loss = epoch_loss / len(data_loader)
            print('epoch %d - Loss: %.3f   lengthscale: %.3f   noise: %.3f' % (epoch, loss.item(),
                  self.model.covar_module.base_kernel.lengthscale.item(), self.model.likelihood.noise.item()))
            if epoch % self.print_freq == 0 or epoch == epochs-1:
                print("Epoch: {}, Avg_loss: {}".format(epoch, avg_loss))

            self.epoch_loss_history.append(avg_loss)
            if early_stopping is not None:
                if avg_loss < early_stopping:
                    break
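The hyperparameter printouts here and in train_rl_ctx above resolve different attribute paths (covar_module.base_kernel.lengthscale versus covar_module.base_kernel.outputscale and covar_module.base_kernel.base_kernel.lengthscale), which suggests two different kernel wrappings. One composition that would make each path resolve, an inference from the snippets rather than a confirmed detail, is:

import gpytorch

# resolves covar_module.base_kernel.lengthscale (as printed in train_rl above)
plain = gpytorch.kernels.ScaleKernel(gpytorch.kernels.RBFKernel())

# resolves covar_module.base_kernel.outputscale and
# covar_module.base_kernel.base_kernel.lengthscale (as printed in train_rl_ctx)
gridded = gpytorch.kernels.GridInterpolationKernel(
    gpytorch.kernels.ScaleKernel(gpytorch.kernels.RBFKernel()),
    grid_size=100, num_dims=1)  # num_dims would match the feature extractor's output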
anp = False

args.directory_path += id
os.mkdir(args.directory_path)

# Create dataset
dataset = MultiGPData(mean, kernel, num_samples=args.num_tot_samples, amplitude_range=x_range, num_points=200)
data_loader = DataLoader(dataset, batch_size=args.batch_size, shuffle=True)

test_dataset = MultiGPData(mean, kernel, num_samples=args.batch_size, amplitude_range=x_range, num_points=200)
test_data_loader = DataLoader(test_dataset, batch_size=1, shuffle=True)

for data_init in data_loader:
    break
x_init, y_init = data_init
x_init, y_init, _, _ = context_target_split(x_init[0:1], y_init[0:1], args.num_context, args.num_target)
print('dataset created', x_init.size())

# create model
likelihood = gpytorch.likelihoods.GaussianLikelihood().to(device)
model_dkl = GPRegressionModel(x_init, y_init.squeeze(0).squeeze(-1), likelihood,
                          args.h_dim_dkl, args.z_dim_dkl, name_id='DKL').to(device)
if anp:
    model_np = AttentiveNeuralProcess(args.x_dim, args.y_dim, args.r_dim_np, args.z_dim_np, args.h_dim_np, args.a_dim_np,
                                      use_self_att=True, fixed_sigma=None).to(device)
else:
    model_np = NeuralProcess(args.x_dim, args.y_dim, args.r_dim_np, args.z_dim_np, args.h_dim_np, fixed_sigma=None).to(device)

optimizer_dkl = torch.optim.Adam([
    {'params': model_dkl.feature_extractor.parameters()},
    {'params': model_dkl.covar_module.parameters()},
                      amplitude_range=x_range)
data_loader = DataLoader(dataset, batch_size=args.batch_size, shuffle=True)
#all_dataset = [dataset.data[0][0], dataset.data[0][1]]
#for func in dataset.data[1:]:
#    all_train_dataset = [torch.cat([all_dataset[0], func[0]], dim=0), torch.cat([all_dataset[1], func[1]], dim=0)]
#train_x, train_y = all_train_dataset
test_dataset = MultiGPData(mean,
                           kernel,
                           num_samples=10,
                           amplitude_range=x_range)
test_data_loader = DataLoader(test_dataset, batch_size=1, shuffle=True)

for data_init in data_loader:
    break
x_init, y_init = data_init
x_init, y_init, _, _ = context_target_split(x_init[0:1], y_init[0:1],
                                            args.num_context, args.num_target)
#x_init, y_init = dataset.data[0]
print('dataset created')
# create model
likelihood = gpytorch.likelihoods.GaussianLikelihood().to(
    device)  # noise_constraint=gpytorch.constraints.GreaterThan(1e-3)
model = GPRegressionModel(x_init,
                          y_init.squeeze(0).squeeze(-1),
                          likelihood,
                          args.h_dim,
                          args.z_dim,
                          name_id=id,
                          scaling=args.scaling,
                          grid_size=100).to(device)

optimizer = torch.optim.Adam([{
Example 11
                          amplitude_range=x_range)
    data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

# Extract a batch from data_loader
test_x_context_l = []
test_y_context_l = []
test_x_l = []
test_y_l = []
for e, batch in enumerate(data_loader):
    # Use batch to create random set of context points
    test_x, test_y = batch
    test_x_l.append(test_x)
    test_y_l.append(test_y)
    num_context_t = 21
    num_target_t = 21
    test_x_context, test_y_context, _, _ = context_target_split(
        test_x[0:1], test_y[0:1], num_context_t, num_target_t)
    test_x_context_l.append(test_x_context)
    test_y_context_l.append(test_y_context)

# Visualize data samples
fig_data, ax_data = plt.subplots(1, 1)

#ax_data.set_title('Multiple samples from the GP with ' + kernel[0] + ' kernel')
for i in range(64):
    x, y = dataset[i * (num_tot_samples // 64)]
    ax_data.plot(x.cpu().numpy(), y.cpu().numpy(), c='k', alpha=0.5)
    ax_data.set_xlabel('x')
    ax_data.set_ylabel('y')
    ax_data.set_xlim(x_range[0], x_range[1])
fig_data.savefig(plots_path + '-'.join(kernel) + '_data', dpi=250)
plt.close(fig_data)
plt.savefig(plots_path + '-'.join(kernel) + '_prior_' + id)
plt.close()
# Extract a batch from data_loader
plt.figure(4)
colors = ['r', 'b', 'g', 'y']
for j in range(2):
    for batch in data_loader:
        break

    # Use batch to create random set of context points
    x, y = batch
    num_context_t = randint(*num_context)
    num_target_t = randint(*num_target)
    x_context, y_context, x_target, _ = context_target_split(x[0:1], y[0:1],
                                                              num_context_t,
                                                              num_target_t)

    neuralprocess.training = False

    plt.title(mdl + ' Posterior')
    for i in range(4):
        # Neural process returns distribution over y_target
        p_y_pred = neuralprocess(x_context, y_context, x_target)
        # Extract mean of distribution
        plt.xlabel('x')
        plt.ylabel('means of y distribution')
        mu = p_y_pred.loc.detach()
        plt.plot(x_target.cpu().numpy()[0],
                 mu.cpu().numpy()[0],
                 alpha=0.3,
Example 13
def plot_posterior_2d(data_loader, model, id, args):
    for n, batch in enumerate(data_loader):

        # Use batch to create random set of context points
        x, y = batch  # , real_len

        x_context, y_context, _, _ = context_target_split(
            x[0:1], y[0:1], args.num_context, args.num_target)
        x, X1, X2, x1, x2 = create_plot_grid(args.extent,
                                             args,
                                             size=args.grid_size)

        fig = plt.figure(figsize=(20, 8))  # figsize=plt.figaspect(1.5)
        fig.suptitle(id, fontsize=20)
        #fig.tight_layout()

        ax_real = fig.add_subplot(131, projection='3d')
        ax_real.plot_surface(X1,
                             X2,
                             y.reshape(X1.shape).cpu().numpy(),
                             cmap='viridis')
        ax_real.set_title('Real function')

        ax_context = fig.add_subplot(132, projection='3d')
        ax_context.scatter(x_context[0, :, 0].detach().cpu().numpy(),
                           x_context[0, :, 1].detach().cpu().numpy(),
                           y_context[0, :, 0].detach().cpu().numpy(),
                           c=y_context[0, :, 0].detach().cpu().numpy(),
                           cmap='viridis',
                           vmin=-1.,
                           vmax=1.,
                           s=8)

        ax_context.set_title('Context points')
        model.set_train_data(x_context.squeeze(0),
                             y_context.squeeze(0).squeeze(-1),
                             strict=False)
        model.eval()
        with torch.no_grad():
            p_y_pred = model(x[0:1])
        mu = p_y_pred.mean.reshape(X1.shape).cpu()
        mu[torch.isnan(mu)] = 0.
        mu = mu.numpy()
        sigma = p_y_pred.stddev.reshape(X1.shape).cpu()
        sigma[torch.isnan(sigma)] = 0.
        sigma = sigma.detach().numpy()
        std_h = mu + sigma
        std_l = mu - sigma
        model.train()
        max_mu = std_h.max()
        min_mu = std_l.min()
        ax_mean = fig.add_subplot(133, projection='3d')
        i = 0
        for y_slice in x2:
            ax_mean.add_collection3d(plt.fill_between(x1,
                                                      std_l[i, :],
                                                      std_h[i, :],
                                                      color='lightseagreen',
                                                      alpha=0.04,
                                                      label='stdev'),
                                     zs=y_slice,
                                     zdir='y')
            i += 1
        # Extract mean of distribution
        ax_mean.plot_surface(X1, X2, mu, cmap='viridis', label='mean')
        for ax in [ax_mean, ax_context, ax_real]:
            ax.set_zlim(min_mu, max_mu)
            ax.set_xlabel('x1')
            ax.set_ylabel('x2')
            ax.set_zlabel('y')
        ax_mean.set_title('Posterior estimate')
        ax_mean.set_xlim(args.extent[0], args.extent[1])
        ax_mean.set_ylim(args.extent[2], args.extent[3])
        plt.savefig(args.directory_path + '/posterior' + id + str(n), dpi=250)
        #plt.show()
        plt.close(fig)
    return
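plot_posterior_2d (and the 1-D variant earlier) only needs a handful of attributes from args; a hypothetical invocation with placeholder values:

from argparse import Namespace

args = Namespace(num_context=50, num_target=50,
                 extent=(-1., 1., -1., 1.), grid_size=30,
                 directory_path='plots')
# plot_posterior_2d(test_data_loader, model, 'epoch_100', args)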
def img_plot():
    grid = torch.zeros(grid_size, len(grid_bounds))
    for i in range(len(grid_bounds)):
        grid_diff = float(grid_bounds[i][1] - grid_bounds[i][0]) / (grid_size -
                                                                    2)
        grid[:, i] = torch.linspace(grid_bounds[i][0] - grid_diff,
                                    grid_bounds[i][1] + grid_diff, grid_size)

    x = gpytorch.utils.grid.create_data_from_grid(grid)
    x = x.unsqueeze(0)

    if not use_attention:
        mu_list = []
        for i in range(4):
            z_sample = torch.randn((1, z_dim))
            z_sample = z_sample.unsqueeze(1).repeat(1, x.size()[1], 1)
            mu, _ = neuralprocess.xz_to_y(x, z_sample)
            mu_list.append(mu)

        f2, axarr2 = plt.subplots(2, 2)

        axarr2[0, 0].imshow(mu_list[0].view(-1,
                                            grid_size).detach().cpu().numpy(),
                            extent=extent)
        axarr2[0, 1].imshow(mu_list[1].view(-1,
                                            grid_size).detach().cpu().numpy(),
                            extent=extent)
        axarr2[1, 0].imshow(mu_list[2].view(-1,
                                            grid_size).detach().cpu().numpy(),
                            extent=extent)
        axarr2[1, 1].imshow(mu_list[3].view(-1,
                                            grid_size).detach().cpu().numpy(),
                            extent=extent)
        f2.suptitle('Samples from trained prior')
        for ax in axarr2.flat:
            ax.set(xlabel='x1', ylabel='x2')
            ax.label_outer()
        plt.savefig(plots_path + kernel + 'prior' + id)
        plt.show()
        plt.close(f2)

    # Extract a batch from data_loader
    for n, batch in enumerate(data_loader):

        # Use batch to create random set of context points
        x, y = batch
        num_context_t = randint(*num_context)
        num_target_t = randint(*num_target)
        x_context, y_context, _, _ = context_target_split(
            x[0:1], y[0:1], num_context_t, num_target_t)

        neuralprocess.training = False

        f3, axarr3 = plt.subplots(2, 2)
        axarr3[0, 1].scatter(x_context[0].detach().cpu().numpy()[:, 0],
                             x_context[0].detach().cpu().numpy()[:, 1],
                             cmap='viridis',
                             c=y_context[0].detach().cpu().numpy()[:, 0],
                             s=1)
        axarr3[0, 0].imshow(y[0].view(-1,
                                      grid_size).detach().cpu().numpy()[::-1],
                            extent=extent)
        axarr3[0, 0].set_title('Real function')
        axarr3[0, 1].set_xlim(extent[0], extent[1])
        axarr3[0, 1].set_ylim(extent[2], extent[3])
        axarr3[0, 1].set_aspect('equal')
        axarr3[0, 1].set_title('Context points')
        mu_list = []
        for i in range(2):
            # Neural process returns distribution over y_target
            p_y_pred = neuralprocess(x_context, y_context, x[0].unsqueeze(0))

            # Extract mean of distribution
            mu_list.append(p_y_pred.loc.detach())

        axarr3[1, 0].imshow(mu_list[0].view(
            -1, grid_size).detach().cpu().numpy()[::-1],
                            extent=extent)
        axarr3[1, 0].set_title('Posterior estimate_1')

        axarr3[1, 1].imshow(mu_list[1].view(
            -1, grid_size).detach().cpu().numpy()[::-1],
                            extent=extent)
        axarr3[1, 1].set_title('Posterior estimate_2')
        plt.savefig(plots_path + kernel + ' posterior' + id)
    #plt.show()
    plt.close(f3)