# imports assumed by this excerpt (the source snippet omits them);
# NeuralProcess, NeuralProcessTrainer, and DKMTrainer come from the
# surrounding project
import time

import matplotlib.pyplot as plt
import numpy as np
import torch

# one optimizer over every DKL component: the deep feature extractor plus
# the GP kernel, mean, and likelihood parameters
optimizer_dkl = torch.optim.Adam([
    {'params': model_dkl.feature_extractor.parameters()},
    {'params': model_dkl.covar_module.parameters()},
    {'params': model_dkl.mean_module.parameters()},
    {'params': model_dkl.likelihood.parameters()}], lr=0.01)
trainer_dkl = DKMTrainer(device, model_dkl, optimizer_dkl, args, print_freq=args.print_freq)

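# the NP trainer samples context/target splits per batch; passing a
# degenerate range (n, n) pins the split sizes to args.num_context
# and args.num_target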
optimizer_np = torch.optim.Adam(model_np.parameters(), lr=learning_rate)
np_trainer = NeuralProcessTrainer(device, model_np, optimizer_np,
                                  num_context_range=(args.num_context,args.num_context),
                                  num_extra_target_range=(args.num_target,args.num_target),
                                  print_freq=args.print_freq)
# train the neural process first, then the DKL model
print('start np training')
t_np_t0 = time.time()
model_np.training = True  # training-mode flag used by the NP forward pass
np_trainer.train(data_loader, args.epochs, early_stopping=args.early_stopping)
t_np_t1 = time.time()

print('start dkl training')
t_dkl_t0 = time.time()
trainer_dkl.train_dkl(data_loader, args.epochs, early_stopping=args.early_stopping)
t_dkl_t1 = time.time()
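# report wall-clock training times; the timestamps are otherwise unused in
# this excerpt, so the print below is an assumption about their intended use
print(f'np training took {t_np_t1 - t_np_t0:.1f}s, '
      f'dkl training took {t_dkl_t1 - t_dkl_t0:.1f}s')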



# Visualize data samples
plt.figure(1)
plt.title('Samples from GP with kernels: ' + ' '.join(kernel))
for i in range(args.num_tot_samples):
    x, y = dataset[i]
    plt.plot(x.cpu().numpy(), y.cpu().numpy(), c='b', alpha=0.5)
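# axis labels and an explicit draw are assumptions; the original excerpt
# ends the figure without them
plt.xlabel('x')
plt.ylabel('y')
plt.show()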
Example #2
            a_dim,
            use_self_att=use_self_att).to(device)
        first = False
    else:
        neuralprocess = NeuralProcess(x_dim, y_dim, r_dim, z_dim,
                                      h_dim).to(device)

    t0 = time.time()
    optimizer = torch.optim.Adam(neuralprocess.parameters(), lr=learning_rate)
    np_trainer = NeuralProcessTrainer(device,
                                      neuralprocess,
                                      optimizer,
                                      num_context_range=num_context,
                                      num_extra_target_range=num_target,
                                      print_freq=50000)
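    # setting the attribute directly mirrors the source; calling
    # neuralprocess.train() would also propagate the flag to submodules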
    neuralprocess.training = True
    np_trainer.train(data_loader, epochs, early_stopping=0)
    # plot per-epoch training loss
    n_ep = len(np_trainer.epoch_loss_history)
    ax_epoch.plot(np.arange(n_ep),
                  np_trainer.epoch_loss_history,
                  c=color,
                  label=mdl)

    x_target = torch.linspace(x_range[0], x_range[1], 100)
    x_target = x_target.unsqueeze(1).unsqueeze(0)
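    # resulting shape (1, 100, 1): (batch, num_target_points, x_dim), the
    # layout the neural process forward pass expects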

    # plot prior samples (disabled in the source; flip the guard to True to enable)
    if False:
        fig_prior, ax_prior = plt.subplots(1, 1)
        ax_prior.set_ylabel('Means of y distribution')