Esempio n. 1
0
    def get_gaussian_likelihood(self, truth, pred_y, mask=None):
        """Return the Gaussian log-likelihood of ``truth`` under ``pred_y``.

        Args:
            truth: observed targets, expected as [n_traj, n_tp, n_dim]. A
                target that already carries the sample axis
                ([n_traj_samples, n_traj, n_tp, n_dim]) is used as-is.
            pred_y: predictions, [n_traj_samples, n_traj, n_tp, n_dim].
            mask: optional observation mask, presumably shaped like the 3-D
                ``truth`` (TODO confirm); ``None`` means no masking.

        Returns:
            The log-density averaged over the ``n_traj_samples`` axis; per the
            original shape note this is one value per trajectory, [n_traj].
        """
        n_traj_samples = pred_y.size(0)

        # Tile targets and mask along the sample axis so predictions, data and
        # mask all agree on the leading dimension (the mask alone used to be
        # tiled, leaving the data with a leading dimension of 1).
        if truth.dim() == pred_y.dim() - 1:
            truth = truth.repeat(n_traj_samples, 1, 1, 1)
        if mask is not None:
            mask = mask.repeat(n_traj_samples, 1, 1, 1)

        # Compute likelihood of the data under the predictions
        log_density_data = masked_gaussian_log_density(
            pred_y, truth, obsrv_std=self.obsrv_std, mask=mask)
        # NOTE(review): assumes masked_gaussian_log_density returns
        # [n_traj, n_traj_samples], so this yields [n_traj_samples, n_traj].
        log_density_data = log_density_data.permute(1, 0)

        # Take mean over n_traj_samples -> shape: [n_traj]
        return torch.mean(log_density_data, 0)
Esempio n. 2
0
    @torch.no_grad()
    def draw_one_density_plot(self,
                              ax,
                              model,
                              data_dict,
                              traj_id,
                              multiply_by_poisson=False):
        """Plot p(x, z0) and q(z0 | x) for one trajectory on a 2-D latent grid.

        A grid of ``npts**2`` points spanning [-5, 5] in the first two latent
        coordinates at t0 is decoded by ``model``. The rest of the latent
        vector is zero-padded. For trajectory ``traj_id``, the joint density
        is drawn as filled blue contours: exp(prior + reconstruction
        log-likelihood, optionally + Poisson-process log-likelihood). The
        encoder's approximate posterior q(z0 | x) is then overlaid in red.

        Runs under ``torch.no_grad()``: nothing here needs gradients, and the
        density tensors must not require grad to be converted to NumPy.

        Args:
            ax: matplotlib Axes to draw on; it is cleared first.
            model: model providing latent_dim, input_dim, use_poisson_proc,
                diffeq_solver, decoder, z0_prior, obsrv_std and encoder_z0.
            data_dict: batch dict with keys "data_to_predict",
                "tp_to_predict", "mask_predicted_data", "observed_data",
                "observed_tp" and "observed_mask".
            traj_id: index of the trajectory within the batch to plot.
            multiply_by_poisson: also add the Poisson-process log-likelihood
                to the joint density; requires ``model.use_poisson_proc``.

        Raises:
            ValueError: if ``multiply_by_poisson`` is set but the model does
                not use a Poisson process (no Poisson likelihood exists).
        """
        if multiply_by_poisson and not model.use_poisson_proc:
            raise ValueError(
                "multiply_by_poisson=True requires model.use_poisson_proc")

        scale = 5
        npts = 50
        # plt.get_cmap replaces plt.cm.get_cmap (removed in matplotlib 3.9).
        cmap = add_white(plt.get_cmap('Blues', 9))
        cmap2 = add_white(plt.get_cmap('Reds', 9))

        data = data_dict["data_to_predict"]
        time_steps = data_dict["tp_to_predict"]
        mask = data_dict["mask_predicted_data"]

        observed_data = data_dict["observed_data"]
        observed_time_steps = data_dict["observed_tp"]
        observed_mask = data_dict["observed_mask"]

        xx, yy, z0_grid = get_meshgrid(npts=npts,
                                       int_y1=(-scale, scale),
                                       int_y2=(-scale, scale))
        z0_grid = z0_grid.to(get_device(data))

        if model.latent_dim > 2:
            # Zero-pad the remaining latent dims; new_zeros keeps the padding
            # on the grid's device/dtype (torch.zeros would land on the CPU).
            padding = z0_grid.new_zeros((z0_grid.size(0), model.latent_dim - 2))
            z0_grid = torch.cat((z0_grid, padding), 1)

        if model.use_poisson_proc:
            n_traj = z0_grid.size(0)
            # Append zeros used to compute the integral of lambda and also
            # for the first point of lambda.
            zeros = z0_grid.new_zeros((n_traj, model.input_dim + model.latent_dim))
            z0_grid_aug = torch.cat((z0_grid, zeros), -1)
        else:
            z0_grid_aug = z0_grid

        # Shape of sol_y [n_traj_samples, n_samples, n_timepoints, n_latents]
        sol_y = model.diffeq_solver(z0_grid_aug.unsqueeze(0), time_steps)

        if model.use_poisson_proc:
            sol_y, log_lambda_y, int_lambda, _ = model.diffeq_solver.ode_func.extract_poisson_rate(
                sol_y)

            # Sanity checks: the integral starts at zero and ends positive.
            assert (torch.sum(int_lambda[:, :, 0, :]) == 0.)
            assert (torch.sum(int_lambda[0, 0, -1, :] <= 0) == 0.)

        pred_x = model.decoder(sol_y)

        # The chosen trajectory, tiled once per grid point (leading axis of 1
        # for the sample dimension).
        one_traj = data[traj_id]
        one_traj_grid = one_traj.repeat(npts**2, 1, 1).unsqueeze(0)
        mask_one_traj = None
        if mask is not None:
            mask_one_traj = mask[traj_id].unsqueeze(0)
            mask_one_traj = mask_one_traj.repeat(npts**2, 1, 1).unsqueeze(0)

        ax.cla()

        # Prior log p(z0) over the grid, summed over latent dimensions.
        prior_density_grid = model.z0_prior.log_prob(
            z0_grid.unsqueeze(0)).squeeze(0)
        prior_density_grid = torch.sum(prior_density_grid, -1)

        # Reconstruction log-likelihood log p(x | z0).
        masked_gaussian_log_density_grid = masked_gaussian_log_density(
            pred_x,
            one_traj_grid,
            mask=mask_one_traj,
            obsrv_std=model.obsrv_std).squeeze(-1)

        # Event-time log-likelihood log p(t | z0).
        if model.use_poisson_proc:
            poisson_info = {}
            poisson_info["int_lambda"] = int_lambda[:, :, -1, :]
            poisson_info["log_lambda_y"] = log_lambda_y

            poisson_log_density_grid = compute_poisson_proc_likelihood(
                one_traj_grid,
                pred_x,
                poisson_info,
                mask=mask_one_traj)
            poisson_log_density_grid = poisson_log_density_grid.squeeze(0)

        # Joint density p(x, z0), optionally including the Poisson term.
        log_joint_density = prior_density_grid + masked_gaussian_log_density_grid
        if multiply_by_poisson:
            log_joint_density = log_joint_density + poisson_log_density_grid

        density_grid = torch.exp(log_joint_density)
        density_grid = torch.reshape(density_grid, (xx.shape[0], xx.shape[1]))
        density_grid = density_grid.cpu().numpy()

        ax.contourf(xx, yy, density_grid, cmap=cmap, alpha=1)

        # Approximate posterior q(z0 | x) from the encoder, overlaid in red.
        ax.set_xlabel('z1(t0)')
        ax.set_ylabel('z2(t0)')

        data_w_mask = observed_data[traj_id].unsqueeze(0)
        if observed_mask is not None:
            data_w_mask = torch.cat(
                (data_w_mask, observed_mask[traj_id].unsqueeze(0)), -1)
        z0_mu, z0_std = model.encoder_z0(data_w_mask, observed_time_steps)

        if model.use_poisson_proc:
            # Keep only the latent part of the encoder output.
            z0_mu = z0_mu[:, :, :model.latent_dim]
            z0_std = z0_std[:, :, :model.latent_dim]

        q_z0 = Normal(z0_mu, z0_std)

        # Sum the log-density over latent dimensions.
        q_density_grid = q_z0.log_prob(z0_grid)
        q_density_grid = torch.sum(q_density_grid, -1)
        density_grid = torch.exp(q_density_grid)

        density_grid = torch.reshape(density_grid, (xx.shape[0], xx.shape[1]))
        density_grid = density_grid.cpu().numpy()

        ax.contourf(xx, yy, density_grid, cmap=cmap2, alpha=0.3)