plt.xlim(-pi, pi)

from neural_process import NeuralProcess

x_dim = 1  # Dimension of x values
y_dim = 1  # Dimension of y values
r_dim = 50  # Dimension of representation of context points
z_dim = 50  # Dimension of sampled latent variable
h_dim = 50  # Dimension of hidden layers in encoder and decoder

neuralprocess = NeuralProcess(x_dim, y_dim, r_dim, z_dim, h_dim)
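# The model maps a context set (x_context, y_context) and target inputs x_target
# to a predictive distribution over y_target; each latent sample corresponds to a
# different function, which is what the plotting loop further below visualises.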

import json  # Used below to save the loss history

from torch.utils.data import DataLoader
from training import NeuralProcessTrainer
from utils import context_target_split  # Used below to build the plotting context set

batch_size = 2
num_context = 4  # Number of context points sampled per function
num_target = 4   # Number of extra target points sampled per function

data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
optimizer = torch.optim.Adam(neuralprocess.parameters(), lr=3e-4)
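# For every batch, the trainer samples a number of context points and extra target
# points within the ranges given below, splits the batch accordingly, and optimises
# the neural process objective.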
np_trainer = NeuralProcessTrainer(device,
                                  neuralprocess,
                                  optimizer,
                                  num_context_range=(num_context, num_context),
                                  num_extra_target_range=(num_target,
                                                          num_target),
                                  print_freq=200)
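# print_freq controls how often the trainer prints the running loss during training.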

neuralprocess.training = True

# Train with per-epoch checkpointing and occasional plotting (`directory` is assumed
# to point at an output folder created earlier; the loss file name is illustrative).
for epoch in range(30):
    np_trainer.train(data_loader, 1)

    # Save loss history at every epoch
    with open(directory + '/losses.json', 'w') as f:
        json.dump(np_trainer.epoch_loss_history, f)

    # Save model at every epoch
    torch.save(np_trainer.neural_process.state_dict(), directory + '/model.pt')

    if epoch % 50 == 0:

        if epoch == 0:
            # Grab a single batch of functions and build a fixed context set
            # to condition on when plotting
            x, y = next(iter(data_loader))
            x_context, y_context, _, _ = context_target_split(
                x[0:1], y[0:1], num_context, num_target)

            # Create target points covering the whole [-pi, pi] range
            x_target = torch.Tensor(np.linspace(-pi, pi, 100))
            x_target = x_target.unsqueeze(1).unsqueeze(0)

        # Use evaluation mode so the forward pass returns the predictive distribution
        neuralprocess.training = False

        for i in range(64):
            # Neural process returns distribution over y_target
            p_y_pred = neuralprocess(x_context, y_context, x_target)
            # Extract mean of distribution
            mu = p_y_pred.loc.detach()
            plt.plot(x_target.numpy()[0], mu.numpy()[0], alpha=0.05, c='b')

        # Switch back to training mode for the next epoch
        neuralprocess.training = True

        plt.scatter(x_context[0].numpy(), y_context[0].numpy(), c='k')
        plt.show()
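
# A minimal sketch of reloading the checkpoint saved above, assuming `directory`
# still points at the same output folder used during training:
loaded_np = NeuralProcess(x_dim, y_dim, r_dim, z_dim, h_dim)
loaded_np.load_state_dict(torch.load(directory + '/model.pt'))
loaded_np.training = False  # Evaluation mode for making predictions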