import torch
from neural_process import NeuralProcess

# Define a neural process for functions with 1D inputs and 1D outputs
x_dim = 1
y_dim = 1
r_dim = 50  # Dimension of representation of context points
z_dim = 50  # Dimension of sampled latent variable
h_dim = 50  # Dimension of hidden layers in encoder and decoder

neuralprocess = NeuralProcess(x_dim, y_dim, r_dim, z_dim, h_dim)

from torch.utils.data import DataLoader
from training import NeuralProcessTrainer

batch_size = 2
num_context = 4
num_target = 4

# `dataset` and `device` must already be defined; a toy example follows below
data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
optimizer = torch.optim.Adam(neuralprocess.parameters(), lr=3e-4)
np_trainer = NeuralProcessTrainer(device, neuralprocess, optimizer,
                                  num_context_range=(num_context, num_context),
                                  num_extra_target_range=(num_target, num_target),
                                  print_freq=200)

neuralprocess.training = True
np_trainer.train(data_loader, 30)
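The training code above assumes a `dataset` of 1D functions and a `device` are already defined. The sketch below is one minimal, illustrative way to build such a dataset from random sine curves; the `SineData` class, its amplitude/shift ranges, and the number of points per curve are assumptions for this example, not part of the library.

import numpy as np
import torch
from math import pi
from torch.utils.data import Dataset

class SineData(Dataset):
    # Toy dataset of random sine curves a * sin(x - b) sampled on [-pi, pi].
    # Illustrative only: ranges and number of points per curve are arbitrary.
    def __init__(self, num_samples=1000, num_points=100,
                 amplitude_range=(-1., 1.), shift_range=(-0.5, 0.5)):
        self.data = []
        x = torch.linspace(-pi, pi, num_points).unsqueeze(1)  # (num_points, 1)
        for _ in range(num_samples):
            a = float(np.random.uniform(*amplitude_range))
            b = float(np.random.uniform(*shift_range))
            y = a * torch.sin(x - b)  # (num_points, 1)
            self.data.append((x, y))

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return len(self.data)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
dataset = SineData()

Each item is a full (x, y) curve, so a batch gives the trainer whole functions from which it can sample context and target points.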
import json
import numpy as np
import torch
import matplotlib.pyplot as plt
from math import pi
from utils import context_target_split

# Inside the per-epoch training loop: log the loss history
# (`f` is an open file handle for the loss history, opened earlier)
json.dump(np_trainer.epoch_loss_history, f)

# Save model at every epoch
torch.save(np_trainer.neural_process.state_dict(), directory + '/model.pt')

# Every 50 epochs, visualise predictions conditioned on a fixed context set
if epoch % 50 == 0:
    if epoch == 0:
        # Grab a single batch once and reuse it for all later plots
        for batch in data_loader:
            break
    x, y = batch
    x_context, y_context, _, _ = context_target_split(x[0:1], y[0:1], 4, 4)

    # 100 evenly spaced target points with shape (1, 100, 1)
    x_target = torch.Tensor(np.linspace(-pi, pi, 100))
    x_target = x_target.unsqueeze(1).unsqueeze(0)

    neuralprocess.training = False
    for i in range(64):
        # Neural process returns distribution over y_target
        p_y_pred = neuralprocess(x_context, y_context, x_target)
        # Extract mean of distribution
        mu = p_y_pred.loc.detach()
        plt.plot(x_target.numpy()[0], mu.numpy()[0], alpha=0.05, c='b')
    neuralprocess.training = True

    plt.scatter(x_context[0].numpy(), y_context[0].numpy(), c='k')
    plt.show()
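After training, the checkpoint saved above can be restored into a fresh model and used for prediction. This is a minimal sketch, assuming the same hyperparameters used when the model was constructed and the `directory` variable from the saving code; the context points are arbitrary illustrative values.

import numpy as np
import torch
from math import pi
from neural_process import NeuralProcess

# Rebuild the model with the same dimensions it was trained with, then
# restore the weights written by torch.save above.
model = NeuralProcess(x_dim, y_dim, r_dim, z_dim, h_dim)
model.load_state_dict(torch.load(directory + '/model.pt'))
model.training = False  # forward pass now returns the predictive distribution

# Condition on a few observed points (illustrative values) and predict over
# a dense grid of target locations.
x_context = torch.Tensor([-2.0, 0.0, 1.5]).reshape(1, 3, 1)
y_context = torch.sin(x_context)
x_target = torch.Tensor(np.linspace(-pi, pi, 100)).unsqueeze(1).unsqueeze(0)

p_y_pred = model(x_context, y_context, x_target)
mean = p_y_pred.loc.detach()      # predictive mean, shape (1, 100, 1)
stddev = p_y_pred.scale.detach()  # predictive standard deviation

Repeating the forward pass several times, as in the plotting loop above, draws different latent samples and so traces out different plausible functions consistent with the same context.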