def get_gaussian_likelihood(self, truth, pred_y, mask=None):
    """Average Gaussian log-likelihood of ``truth`` under the predictions ``pred_y``.

    Args:
        truth: observed values; the original comments give its shape as
            [n_traj, n_tp, n_dim]. It is passed to
            ``masked_gaussian_log_density`` without tiling.
            NOTE(review): the mask is tiled per sample but truth is not --
            confirm the helper broadcasts it.
        pred_y: predictions, shape [n_traj_samples, n_traj, n_tp, n_dim]
            (per the original comments).
        mask: optional observation mask, presumably shaped like ``truth``.
            When given, it is repeated ``pred_y.size(0)`` times along a new
            leading dimension.

    Returns:
        The result of ``masked_gaussian_log_density`` (computed with
        ``self.obsrv_std``), with its two dimensions swapped and then
        averaged over dimension 0. The original comments describe this as
        a mean over n_traj_samples giving shape [n_traj].
        NOTE(review): that only holds if the helper returns
        [n_traj, n_traj_samples] -- confirm against its definition.
    """
    tiled_mask = mask
    if tiled_mask is not None:
        n_traj_samples = pred_y.size(0)
        tiled_mask = tiled_mask.repeat(n_traj_samples, 1, 1, 1)

    log_density_per_pair = masked_gaussian_log_density(
        pred_y, truth, obsrv_std=self.obsrv_std, mask=tiled_mask)

    # Swap the two dimensions, then average along the new leading one.
    return log_density_per_pair.permute(1, 0).mean(dim=0)
def draw_one_density_plot(self, ax, model, data_dict, traj_id,
                          multiply_by_poisson=False):
    """Draw latent-space densities of y(t0) for one trajectory onto ``ax``.

    A grid of initial latent states is built with ``get_meshgrid`` over
    (-5, 5) in the first two latent coordinates (``npts`` = 50 points per
    axis). Any further latent coordinates are set to zero. Two filled
    contour layers are drawn:

    * Blues: exp(log p(y(t0)) + log p(x | y(t0))). This is the prior
      ``model.z0_prior`` plus ``masked_gaussian_log_density`` of trajectory
      ``traj_id`` under the decoded ODE solution from each grid point.
      Optionally the Poisson-process log-likelihood p(t | y(t0)) is added.
    * Reds (alpha 0.3): the encoder distribution q(y(t0) | x), a ``Normal``
      built from ``model.encoder_z0`` on the observed part of the same
      trajectory.

    Args:
        ax: matplotlib axes; it is cleared before drawing.
        model: model exposing ``latent_dim``, ``use_poisson_proc``,
            ``diffeq_solver``, ``decoder``, ``z0_prior``, ``obsrv_std`` and
            ``encoder_z0`` (plus ``input_dim`` when ``use_poisson_proc``).
        data_dict: batch dict with keys "data_to_predict", "tp_to_predict",
            "mask_predicted_data", "observed_data", "observed_tp" and
            "observed_mask" (the two mask entries may be None).
        traj_id: index of the trajectory within the batch.
        multiply_by_poisson: if True, also include the Poisson-process
            likelihood in the blue density. Requires
            ``model.use_poisson_proc``.

    Raises:
        ValueError: if ``multiply_by_poisson`` is True but
            ``model.use_poisson_proc`` is False. Before this check, the
            function failed later with an UnboundLocalError.
    """
    if multiply_by_poisson and not model.use_poisson_proc:
        # The Poisson term only exists for Poisson-process models; fail
        # early with a clear message.
        raise ValueError(
            "multiply_by_poisson=True requires a model with "
            "use_poisson_proc enabled")

    scale = 5
    npts = 50
    # plt.get_cmap works across Matplotlib versions, whereas
    # plt.cm.get_cmap was deprecated in 3.7 and removed in 3.9.
    cmap = add_white(plt.get_cmap('Blues', 9))
    cmap2 = add_white(plt.get_cmap('Reds', 9))

    data = data_dict["data_to_predict"]
    time_steps = data_dict["tp_to_predict"]
    mask = data_dict["mask_predicted_data"]

    observed_data = data_dict["observed_data"]
    observed_time_steps = data_dict["observed_tp"]
    observed_mask = data_dict["observed_mask"]

    xx, yy, z0_grid = get_meshgrid(
        npts=npts, int_y1=(-scale, scale), int_y2=(-scale, scale))

    def log_density_to_plot_grid(log_density):
        # Exponentiate, reshape to the meshgrid layout and move to numpy
        # for contourf.
        density = torch.exp(log_density)
        density = torch.reshape(density, (xx.shape[0], xx.shape[1]))
        return density.detach().cpu().numpy()

    # Nothing here is backpropagated. no_grad avoids building an autograd
    # graph over npts**2 ODE solutions and allows .numpy() on the results.
    with torch.no_grad():
        z0_grid = z0_grid.to(get_device(data))
        if model.latent_dim > 2:
            # Zero-pad the latent dims that are not plotted. new_zeros
            # matches the grid's device/dtype, so torch.cat never mixes
            # CPU and GPU tensors.
            padding = z0_grid.new_zeros(
                (z0_grid.size(0), model.latent_dim - 2))
            z0_grid = torch.cat((z0_grid, padding), 1)

        if model.use_poisson_proc:
            # Append zeros used for the integral of lambda and for the
            # first point of lambda.
            poisson_zeros = z0_grid.new_zeros(
                (z0_grid.size(0), model.input_dim + model.latent_dim))
            z0_grid_aug = torch.cat((z0_grid, poisson_zeros), -1)
        else:
            z0_grid_aug = z0_grid

        # Shape of sol_y [n_traj_samples, n_samples, n_timepoints, n_latents]
        sol_y = model.diffeq_solver(z0_grid_aug.unsqueeze(0), time_steps)

        if model.use_poisson_proc:
            (sol_y, log_lambda_y, int_lambda,
             _) = model.diffeq_solver.ode_func.extract_poisson_rate(sol_y)
            # Sanity checks: int_lambda sums to zero at the first time step,
            # and at the last time step no entry of the first grid point's
            # int_lambda is <= 0.
            assert torch.sum(int_lambda[:, :, 0, :]) == 0.
            assert torch.sum(int_lambda[0, 0, -1, :] <= 0) == 0.

        pred_x = model.decoder(sol_y)

        # Tile trajectory traj_id (and its mask) npts**2 times, one copy
        # per grid point.
        one_traj = data[traj_id]
        one_traj_grid = one_traj.repeat(npts ** 2, 1, 1).unsqueeze(0)
        mask_one_traj = None
        if mask is not None:
            mask_one_traj = mask[traj_id].unsqueeze(0)
            mask_one_traj = mask_one_traj.repeat(npts ** 2, 1, 1).unsqueeze(0)

        # log p(y(t0)), summed over the latent dimensions.
        prior_density_grid = model.z0_prior.log_prob(
            z0_grid.unsqueeze(0)).squeeze(0)
        prior_density_grid = torch.sum(prior_density_grid, -1)

        # log p(x | y(t0))
        gaussian_density_grid = masked_gaussian_log_density(
            pred_x, one_traj_grid, mask=mask_one_traj,
            obsrv_std=model.obsrv_std).squeeze(-1)

        # log p(x, y(t0)), optionally plus log p(t | y(t0)).
        log_joint_density = prior_density_grid + gaussian_density_grid
        if multiply_by_poisson:
            poisson_info = {
                "int_lambda": int_lambda[:, :, -1, :],
                "log_lambda_y": log_lambda_y,
            }
            poisson_log_density_grid = compute_poisson_proc_likelihood(
                one_traj_grid, pred_x, poisson_info, mask=mask_one_traj)
            log_joint_density = (
                log_joint_density + poisson_log_density_grid.squeeze(0))
        joint_grid = log_density_to_plot_grid(log_joint_density)

        # log q(y(t0) | x) from the encoder on the observed data.
        data_w_mask = observed_data[traj_id].unsqueeze(0)
        if observed_mask is not None:
            data_w_mask = torch.cat(
                (data_w_mask, observed_mask[traj_id].unsqueeze(0)), -1)
        z0_mu, z0_std = model.encoder_z0(data_w_mask, observed_time_steps)
        if model.use_poisson_proc:
            z0_mu = z0_mu[:, :, :model.latent_dim]
            z0_std = z0_std[:, :, :model.latent_dim]
        q_z0 = Normal(z0_mu, z0_std)
        # Sum the log density over the latent dimensions.
        q_density_grid = torch.sum(q_z0.log_prob(z0_grid), -1)
        posterior_grid = log_density_to_plot_grid(q_density_grid)

    ax.cla()
    ax.contourf(xx, yy, joint_grid, cmap=cmap, alpha=1)
    ax.set_xlabel('z1(t0)')
    ax.set_ylabel('z2(t0)')
    ax.contourf(xx, yy, posterior_grid, cmap=cmap2, alpha=0.3)