def train(self, data_loader, epochs, early_stopping=None): """Training procedure for DKL on arbitrary function""" print('std') for epoch in range(epochs): epoch_loss = 0. for i, data in enumerate(data_loader): t = time.time() # Zero backprop gradients self.optimizer.zero_grad() # Get output from model x, y = data # add , num_points x_context, y_context, x_target, y_target = context_target_split(x[0:1], y[0:1], self.num_context, self.num_target) self.model.set_train_data(inputs=x_context, targets=y_context.view(-1), strict=False) #with gpytorch.settings.use_toeplitz(False), gpytorch.settings.fast_pred_var(False): output = self.model(x_context) if any(torch.isnan(output.stddev.view(-1))): print('nan at epoch ', epoch) continue # Calc loss and backprop derivatives loss = -self.mll(output, y_context.view(-1)) # .sum() loss.backward() self.optimizer.step() epoch_loss += loss.item() avg_loss = epoch_loss / len(data_loader) if epoch % self.print_freq == 0 or epoch == epochs - 1: print("Epoch: {}, Avg_loss: {}".format(epoch, avg_loss)) self.epoch_loss_history.append(avg_loss) if early_stopping is not None: if avg_loss < early_stopping: break
def train_rl(self, data_loader, epochs, early_stopping=None): """ Train MKI model as a part if IMeL """ for epoch in range(epochs): epoch_loss = 0. for i, data in enumerate(data_loader): # Zero backprop gradients self.optimizer.zero_grad() # Get output from model x, y, num_points = data # add , num_points # divide context (N-1) and target (1) x_context, y_context, x_target, y_target = context_target_split(x[:,:num_points,:], y[:,:num_points,:], num_points.item()-1, 1) all_x_context, all_y_context = merge_context(context_points_list) prediction = self.model(x_context, y_context, x_target) if torch.isnan(prediction): prediction = self.model(x_context, y_context, x_target) loss = self._loss(y_target, prediction) loss.backward() self.optimizer.step() epoch_loss += loss.item() avg_loss = epoch_loss / len(data_loader) if epoch % self.print_freq == 0 or epoch == epochs-1 : print("Epoch: {}, Avg_loss: {}, W_sum: {}".format(epoch, avg_loss, self.model.interpolator.W.sum().item())) self.epoch_loss_history.append(avg_loss) if early_stopping is not None: if avg_loss < early_stopping: break
def train_rl_ctx(self, data_loader, epochs, early_stopping=None):
    """Training module for DKL as part of IMeL. Conditions the GP on the context set."""
    for epoch in range(epochs):
        epoch_loss = 0.
        for i, data in enumerate(data_loader):
            # Zero backprop gradients
            self.optimizer.zero_grad()
            # Get output from model
            x, y, num_points = data
            # Divide context (N-1) and target (1)
            x_context, y_context, _, _ = context_target_split(
                x[:, :num_points, :], y[:, :num_points, :], num_points.item() - 1, 1)
            self.model.set_train_data(inputs=x_context, targets=y_context.view(-1), strict=False)
            predictions = self.model(x_context)
            loss = -self.mll(predictions, y_context.view(-1))
            loss.backward()
            self.optimizer.step()
            epoch_loss += loss.item()
        avg_loss = epoch_loss / len(data_loader)
        print('epoch %d - Loss: %.3f  lengthscale: %.9f  outputscale: %.3f  noise: %.3f' % (
            epoch, loss.item(),
            self.model.covar_module.base_kernel.base_kernel.lengthscale.item(),
            self.model.covar_module.base_kernel.outputscale.item(),
            self.model.likelihood.noise.item()))
        if epoch % self.print_freq == 0 or epoch == epochs - 1:
            print("Epoch: {}, Avg_loss: {}".format(epoch, avg_loss))
        self.epoch_loss_history.append(avg_loss)
        if early_stopping is not None:
            if avg_loss < early_stopping:
                break
def train(self, data_loader, epochs, early_stopping=None): """ Trains Neural Process. Parameters ---------- dataloader : torch.utils.DataLoader instance epochs : int Number of epochs to train for. """ # compute the episode-specific context sets one_out_list = [] episode_fixed_list = [ep for _, ep in enumerate(data_loader)] for i in range(len(episode_fixed_list)): context_list = [] if len(episode_fixed_list) == 1: context_list = [ep for ep in episode_fixed_list] else: for j, ep in enumerate(episode_fixed_list): if j != i: context_list.append(ep) #context_list = [ep for j, ep in enumerate(data_loader) if j != i] all_context_points = merge_context(context_list) one_out_list.append(all_context_points) for epoch in range(epochs): epoch_loss = 0. for i in range(len(data_loader)): self.optimizer.zero_grad() all_context_points = one_out_list[i] data = episode_fixed_list[i] x, y, num_points = data num_target = min(num_points.item(), self.num_target) x_context, y_context = all_context_points _, _, x_target, y_target = context_target_split(x[:, :num_points, :], y[:, :num_points, :], 0, num_target) p_y_pred, q_target, q_context = \ self.neural_process(x_context, y_context, x_target, y_target) loss = self._loss(p_y_pred, y_target, q_target, q_context) loss.backward() self.optimizer.step() epoch_loss += loss.item() self.steps += 1 avg_loss = epoch_loss / len(data_loader) if epoch % self.print_freq == 0 or epoch == epochs-1: print("Epoch loo: {}, Avg_loss: {}".format(epoch, avg_loss)) self.epoch_loss_history.append(avg_loss) if early_stopping is not None: if avg_loss < early_stopping: break
def plot_posterior(data_loader, model, id, args, title='Posterior', num_func=4):
    colors = ['r', 'b', 'g', 'y', 'b', 'g', 'y']
    for j in range(num_func):
        plt.figure(j)
        plt.xlabel('x')
        plt.ylabel('Predicted y distribution')
        x, y = data_loader.dataset.data[j]
        x = x.unsqueeze(0)
        y = y.unsqueeze(0)
        x_context, y_context, x_target, y_target = context_target_split(
            x[0:1], y[0:1], args.num_context, args.num_target)
        # plt.title(title)
        # Condition the GP on the context points and predict over the full range
        model.set_train_data(x_context.squeeze(0), y_context.squeeze(0).squeeze(-1), strict=False)
        model.eval()
        with torch.no_grad(), gpytorch.settings.use_toeplitz(False), gpytorch.settings.fast_pred_var():
            p_y_pred = model(x[0:1].squeeze(0))
        model.train()
        # Extract mean and standard deviation of the predictive distribution
        mu = p_y_pred.loc.detach().cpu().numpy()
        stdv = p_y_pred.stddev.detach().cpu().numpy()
        plt.plot(x[0:1].cpu().numpy()[0].squeeze(-1), mu, alpha=0.9, c=colors[j], label='Mean')
        plt.fill_between(x[0:1].cpu().numpy()[0].squeeze(-1), mu - stdv, mu + stdv,
                         color=colors[j], alpha=0.1, label='stdev')
        plt.plot(x[0].cpu().numpy(), y[0].cpu().numpy(), alpha=0.5, c='k', label='Real function')
        plt.scatter(x_context[0].cpu().numpy(), y_context[0].cpu().numpy(),
                    c=colors[j], label='Context points')
        plt.legend()
        plt.savefig(args.directory_path + '/' + id + str(j))
        plt.close()
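# Example invocation, assuming a trained GP model and an args namespace like
# the ones used later in this section (directory_path, num_context, num_target
# are fields of args); the variable names are illustrative only:
# plot_posterior(test_data_loader, model_dkl, 'dkl_test', args)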
def plot_posterior_2d(data_loader, model, id, args):
    # Use the first batch to create a random set of context points
    x, y = next(iter(data_loader))
    x_context, y_context, _, _ = context_target_split(
        x[0:1], y[0:1], args.num_context // 2, args.num_target)
    x, X1, X2, x1, x2 = create_plot_grid(args.extent, args, size=args.grid_size)

    fig = plt.figure(figsize=(20, 8))
    ax_real = fig.add_subplot(131, projection='3d')
    ax_real.plot_surface(X1, X2, y.reshape(X1.shape).cpu().numpy(), cmap='viridis')
    ax_real.set_title('Real function')

    ax_context = fig.add_subplot(132, projection='3d')
    ax_context.scatter(x_context[0, :, 0].detach().cpu().numpy(),
                       x_context[0, :, 1].detach().cpu().numpy(),
                       y_context[0, :, 0].detach().cpu().numpy(),
                       c=y_context[0, :, 0].detach().cpu().numpy(),
                       cmap='viridis', vmin=-1., vmax=1., s=8)
    ax_context.set_title('Context points')

    with torch.no_grad():
        mu = model(x_context, y_context, x[0:1])
    ax_mean = fig.add_subplot(133, projection='3d')
    # Plot the mean of the predictive distribution
    ax_mean.plot_surface(X1, X2, mu.cpu().view(X1.shape).numpy(), cmap='viridis')
    ax_mean.set_title('Posterior estimate')

    for ax in [ax_mean, ax_context, ax_real]:
        ax.set_xlabel('x1')
        ax.set_ylabel('x2')
        ax.set_zlabel('y')
    plt.savefig(args.directory_path + '/posterior' + id, dpi=350)
    plt.close(fig)
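# create_plot_grid is used here and in the GP variant below but is defined
# elsewhere. A minimal sketch of its assumed contract, inferred from the
# plotting code: extent is (x1_min, x1_max, x2_min, x2_max), and it returns
# the flattened (1, size*size, 2) input tensor plus the meshgrid and axis
# arrays. The body, including the meshgrid indexing convention, is an
# assumption, not the repository's implementation.
import torch

def create_plot_grid_sketch(extent, args, size=50):
    """Hypothetical stand-in for create_plot_grid (assumed behavior)."""
    x1 = torch.linspace(extent[0], extent[1], size)
    x2 = torch.linspace(extent[2], extent[3], size)
    X1, X2 = torch.meshgrid(x1, x2, indexing='ij')
    # Flatten the grid into a (1, size*size, 2) batch of 2-D inputs
    x = torch.stack([X1.reshape(-1), X2.reshape(-1)], dim=-1).unsqueeze(0)
    return x, X1, X2, x1, x2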
def train_rl_loo(self, data_loader, epochs, early_stopping=None):
    """Train MKI as part of IMeL with a leave-one-out procedure."""
    # Precompute the episode-specific (leave-one-out) context sets
    one_out_list = []
    episode_fixed_list = list(data_loader)
    for i in range(len(episode_fixed_list)):
        if len(episode_fixed_list) == 1:
            context_list = list(episode_fixed_list)
        else:
            context_list = [ep for j, ep in enumerate(episode_fixed_list) if j != i]
        one_out_list.append(context_list)

    loader_len = len(data_loader)
    for epoch in range(epochs):
        epoch_loss = 0.
        for i, data in enumerate(data_loader):
            # Zero backprop gradients
            self.optimizer.zero_grad()
            # Get output from model
            all_context_points = one_out_list[i]
            x, y, num_points = episode_fixed_list[i]
            x_context, y_context = get_random_context(all_context_points, self.num_context)
            num_target = min(self.num_target, num_points.item())
            _, _, x_target, y_target = context_target_split(
                x[:, :num_points, :], y[:, :num_points, :], 0, num_target)
            prediction = self.model(x_context, y_context, x_target)
            loss = self._loss(y_target, prediction)
            loss.backward()
            self.optimizer.step()
            epoch_loss += loss.item()
        avg_loss = epoch_loss / loader_len
        if epoch % self.print_freq == 0 or epoch == epochs - 1:
            print("Epoch: {}, Avg_loss: {}".format(epoch, avg_loss))
        self.epoch_loss_history.append(avg_loss)
        if early_stopping is not None:
            if avg_loss < early_stopping:
                break
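# get_random_context is external to this section as well. A minimal sketch of
# its assumed behavior, inferred from train_rl_loo() above: pool the valid
# points of the leave-one-out episodes and subsample num_context of them as
# the conditioning set. Name and body are assumptions.
import torch

def get_random_context_sketch(episodes, num_context):
    """Hypothetical stand-in for get_random_context (assumed behavior)."""
    xs = torch.cat([x[:, :int(n), :] for x, y, n in episodes], dim=1)
    ys = torch.cat([y[:, :int(n), :] for x, y, n in episodes], dim=1)
    # Subsample a fixed-size random conditioning set from the pooled points
    idx = torch.randperm(xs.size(1))[:num_context]
    return xs[:, idx, :], ys[:, idx, :]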
def train_rl(self, data_loader, epochs, early_stopping=None): """Training module for DKL as part of IMeL """ for epoch in range(epochs): epoch_loss = 0. for i, data in enumerate(data_loader): t = time.time() # Zero backprop gradients self.optimizer.zero_grad() # Get output from model x, y, num_points = data # add , num_points # divide context (N-1) and target (1) x_context, y_context, x_target, y_target = context_target_split(x[:,:num_points,:], y[:,:num_points,:], num_points.item()-1, 1) self.model.set_train_data(inputs=x_context, targets=y_context.view(-1), strict=False) self.model.eval() self.model.likelihood.eval() with gpytorch.settings.use_toeplitz(False): predictions = self.model(x_target) self.model.train() self.model.likelihood.train() loss = -self.mll(predictions, y_target.view(-1)) if torch.isnan(loss): print(loss) self.model.eval() self.model.likelihood.eval() s = self.model(x_target) loss.backward() self.optimizer.step() epoch_loss += loss.item() avg_loss = epoch_loss / len(data_loader) print('epoch %d - Loss: %.3f lengthscale: %.3f noise: %.3f' % (epoch, loss.item(), self.model.covar_module.base_kernel.lengthscale.item(), self.model.likelihood.noise.item())) if epoch % self.print_freq == 0 or epoch == epochs-1: print("Epoch: {}, Avg_loss: {}".format(epoch, avg_loss)) self.epoch_loss_history.append(avg_loss) if early_stopping is not None: if avg_loss < early_stopping: break
anp = False
args.directory_path += id
os.mkdir(args.directory_path)

# Create dataset
dataset = MultiGPData(mean, kernel, num_samples=args.num_tot_samples,
                      amplitude_range=x_range, num_points=200)
data_loader = DataLoader(dataset, batch_size=args.batch_size, shuffle=True)
test_dataset = MultiGPData(mean, kernel, num_samples=args.batch_size,
                           amplitude_range=x_range, num_points=200)
test_data_loader = DataLoader(test_dataset, batch_size=1, shuffle=True)

# Use the first batch to initialize the GP training data
x_init, y_init = next(iter(data_loader))
x_init, y_init, _, _ = context_target_split(x_init[0:1], y_init[0:1],
                                            args.num_context, args.num_target)
print('dataset created', x_init.size())

# Create models
likelihood = gpytorch.likelihoods.GaussianLikelihood().to(device)
model_dkl = GPRegressionModel(x_init, y_init.squeeze(0).squeeze(-1), likelihood,
                              args.h_dim_dkl, args.z_dim_dkl, name_id='DKL').to(device)
if anp:
    model_np = AttentiveNeuralProcess(args.x_dim, args.y_dim, args.r_dim_np,
                                      args.z_dim_np, args.h_dim_np, args.a_dim_np,
                                      use_self_att=True, fixed_sigma=None).to(device)
else:
    model_np = NeuralProcess(args.x_dim, args.y_dim, args.r_dim_np, args.z_dim_np,
                             args.h_dim_np, fixed_sigma=None).to(device)

optimizer_dkl = torch.optim.Adam([
    {'params': model_dkl.feature_extractor.parameters()},
    {'params': model_dkl.covar_module.parameters()},
amplitude_range=x_range)
data_loader = DataLoader(dataset, batch_size=args.batch_size, shuffle=True)
test_dataset = MultiGPData(mean, kernel, num_samples=10, amplitude_range=x_range)
test_data_loader = DataLoader(test_dataset, batch_size=1, shuffle=True)

# Use the first batch to initialize the GP training data
x_init, y_init = next(iter(data_loader))
x_init, y_init, _, _ = context_target_split(x_init[0:1], y_init[0:1],
                                            args.num_context, args.num_target)
print('dataset created')

# Create model
likelihood = gpytorch.likelihoods.GaussianLikelihood().to(device)
# Optionally: noise_constraint=gpytorch.constraints.GreaterThan(1e-3)
model = GPRegressionModel(x_init, y_init.squeeze(0).squeeze(-1), likelihood,
                          args.h_dim, args.z_dim, name_id=id,
                          scaling=args.scaling, grid_size=100).to(device)

optimizer = torch.optim.Adam([{
amplitude_range=x_range)
data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

# Extract a fixed set of test context points from each batch
test_x_context_l = []
test_y_context_l = []
test_x_l = []
test_y_l = []
for e, batch in enumerate(data_loader):
    # Use batch to create a random set of context points
    test_x, test_y = batch
    test_x_l.append(test_x)
    test_y_l.append(test_y)
    num_context_t = 21
    num_target_t = 21
    test_x_context, test_y_context, _, _ = context_target_split(
        test_x[0:1], test_y[0:1], num_context_t, num_target_t)
    test_x_context_l.append(test_x_context)
    test_y_context_l.append(test_y_context)

# Visualize data samples
fig_data, ax_data = plt.subplots(1, 1)
for i in range(64):
    x, y = dataset[i * (num_tot_samples // 64)]
    ax_data.plot(x.cpu().numpy(), y.cpu().numpy(), c='k', alpha=0.5)
ax_data.set_xlabel('x')
ax_data.set_ylabel('y')
ax_data.set_xlim(x_range[0], x_range[1])
fig_data.savefig(plots_path + '-'.join(kernel) + '_data', dpi=250)
plt.close(fig_data)
plt.savefig(plots_path + '-'.join(kernel) + '_prior_' + id)
plt.close()

# Extract a batch from data_loader and plot posterior samples
plt.figure(4)
colors = ['r', 'b', 'g', 'y']
for j in range(2):
    # Use a fresh batch to create a random set of context and target points
    x, y = next(iter(data_loader))
    num_context_t = randint(*num_context)
    num_target_t = randint(*num_target)
    x_context, y_context, x_target, _ = context_target_split(
        x[0:1], y[0:1], num_context_t, num_target_t)
    neuralprocess.eval()
    plt.title(mdl + ' Posterior')
    for i in range(4):
        # Neural process returns a distribution over y_target
        p_y_pred = neuralprocess(x_context, y_context, x_target)
        # Extract mean of the distribution
        plt.xlabel('x')
        plt.ylabel('means of y distribution')
        mu = p_y_pred.loc.detach()
        plt.plot(x_target.cpu().numpy()[0], mu.cpu().numpy()[0], alpha=0.3,
def plot_posterior_2d(data_loader, model, id, args):
    for n, batch in enumerate(data_loader):
        # Use batch to create a random set of context points
        x, y = batch
        x_context, y_context, _, _ = context_target_split(
            x[0:1], y[0:1], args.num_context, args.num_target)
        x, X1, X2, x1, x2 = create_plot_grid(args.extent, args, size=args.grid_size)

        fig = plt.figure(figsize=(20, 8))
        fig.suptitle(id, fontsize=20)
        ax_real = fig.add_subplot(131, projection='3d')
        ax_real.plot_surface(X1, X2, y.reshape(X1.shape).cpu().numpy(), cmap='viridis')
        ax_real.set_title('Real function')

        ax_context = fig.add_subplot(132, projection='3d')
        ax_context.scatter(x_context[0, :, 0].detach().cpu().numpy(),
                           x_context[0, :, 1].detach().cpu().numpy(),
                           y_context[0, :, 0].detach().cpu().numpy(),
                           c=y_context[0, :, 0].detach().cpu().numpy(),
                           cmap='viridis', vmin=-1., vmax=1., s=8)
        ax_context.set_title('Context points')

        # Condition the GP on the context points and predict over the grid
        model.set_train_data(x_context.squeeze(0), y_context.squeeze(0).squeeze(-1), strict=False)
        model.eval()
        with torch.no_grad():
            p_y_pred = model(x[0:1])
            mu = p_y_pred.mean.reshape(X1.shape).cpu()
            mu[torch.isnan(mu)] = 0.
            mu = mu.numpy()
            sigma = p_y_pred.stddev.reshape(X1.shape).cpu()
            sigma[torch.isnan(sigma)] = 0.
            sigma = sigma.detach().numpy()
        std_h = mu + sigma
        std_l = mu - sigma
        model.train()
        max_mu = std_h.max()
        min_mu = std_l.min()

        ax_mean = fig.add_subplot(133, projection='3d')
        # Draw one stdev band per x2 slice, then the mean surface
        for i, y_slice in enumerate(x2):
            ax_mean.add_collection3d(
                plt.fill_between(x1, std_l[i, :], std_h[i, :],
                                 color='lightseagreen', alpha=0.04, label='stdev'),
                zs=y_slice, zdir='y')
        ax_mean.plot_surface(X1, X2, mu, cmap='viridis', label='mean')

        for ax in [ax_mean, ax_context, ax_real]:
            ax.set_zlim(min_mu, max_mu)
            ax.set_xlabel('x1')
            ax.set_ylabel('x2')
            ax.set_zlabel('y')
        ax_mean.set_title('Posterior estimate')
        ax_mean.set_xlim(args.extent[0], args.extent[1])
        ax_mean.set_ylim(args.extent[2], args.extent[3])
        plt.savefig(args.directory_path + '/posterior' + id + str(n), dpi=250)
        plt.close(fig)
def img_plot():
    # Build a regular grid slightly larger than the plotting bounds
    grid = torch.zeros(grid_size, len(grid_bounds))
    for i in range(len(grid_bounds)):
        grid_diff = float(grid_bounds[i][1] - grid_bounds[i][0]) / (grid_size - 2)
        grid[:, i] = torch.linspace(grid_bounds[i][0] - grid_diff,
                                    grid_bounds[i][1] + grid_diff, grid_size)
    x = gpytorch.utils.grid.create_data_from_grid(grid)
    x = x.unsqueeze(0)

    if not use_attention:
        # Plot four functions sampled from the trained prior
        mu_list = []
        for i in range(4):
            z_sample = torch.randn((1, z_dim))
            z_sample = z_sample.unsqueeze(1).repeat(1, x.size()[1], 1)
            mu, _ = neuralprocess.xz_to_y(x, z_sample)
            mu_list.append(mu)
        f2, axarr2 = plt.subplots(2, 2)
        axarr2[0, 0].imshow(mu_list[0].view(-1, grid_size).detach().cpu().numpy(), extent=extent)
        axarr2[0, 1].imshow(mu_list[1].view(-1, grid_size).detach().cpu().numpy(), extent=extent)
        axarr2[1, 0].imshow(mu_list[2].view(-1, grid_size).detach().cpu().numpy(), extent=extent)
        axarr2[1, 1].imshow(mu_list[3].view(-1, grid_size).detach().cpu().numpy(), extent=extent)
        f2.suptitle('Samples from trained prior')
        for ax in axarr2.flat:
            ax.set(xlabel='x1', ylabel='x2')
            ax.label_outer()
        plt.savefig(plots_path + kernel + 'prior' + id)
        plt.show()
        plt.close(f2)

    # Extract a batch from data_loader
    for n, batch in enumerate(data_loader):
        # Use batch to create a random set of context points
        x, y = batch
        num_context_t = randint(*num_context)
        num_target_t = randint(*num_target)
        x_context, y_context, _, _ = context_target_split(
            x[0:1], y[0:1], num_context_t, num_target_t)
        neuralprocess.eval()
        f3, axarr3 = plt.subplots(2, 2)
        axarr3[0, 1].scatter(x_context[0].detach().cpu().numpy()[:, 0],
                             x_context[0].detach().cpu().numpy()[:, 1],
                             cmap='viridis',
                             c=y_context[0].detach().cpu().numpy()[:, 0], s=1)
        axarr3[0, 0].imshow(y[0].view(-1, grid_size).detach().cpu().numpy()[::-1], extent=extent)
        axarr3[0, 0].set_title('Real function')
        axarr3[0, 1].set_xlim(extent[0], extent[1])
        axarr3[0, 1].set_ylim(extent[2], extent[3])
        axarr3[0, 1].set_aspect('equal')
        axarr3[0, 1].set_title('Context points')
        mu_list = []
        for i in range(2):
            # Neural process returns a distribution over y_target
            p_y_pred = neuralprocess(x_context, y_context, x[0].unsqueeze(0))
            # Extract mean of the distribution
            mu_list.append(p_y_pred.loc.detach())
        axarr3[1, 0].imshow(mu_list[0].view(-1, grid_size).detach().cpu().numpy()[::-1], extent=extent)
        axarr3[1, 0].set_title('Posterior estimate_1')
        axarr3[1, 1].imshow(mu_list[1].view(-1, grid_size).detach().cpu().numpy()[::-1], extent=extent)
        axarr3[1, 1].set_title('Posterior estimate_2')
        plt.savefig(plots_path + kernel + ' posterior' + id)
        plt.close(f3)