예제 #1
0
def impute_using_input_decay(data, delta_ts, mask, w_input_decay,
                             b_input_decay):
    """Impute missing values by decaying from the last observation to the mean.

    Each missing entry (``mask == 0``) becomes
    ``decay * last_observed + (1 - decay) * time_mean``, where
    ``decay = exp(-min(max(0, w * cum_delta_t + b), 1000))`` and
    ``cum_delta_t`` is the accumulated (normalized) time gap for that entry.
    Observed entries are returned unchanged.

    Args:
        data: tensor of shape [n_traj, n_tp, n_dims].
        delta_ts: time gaps, repeated ``n_dims`` times along the last axis;
            presumably [n_traj, n_tp, 1] -- TODO confirm with callers.
        mask: same shape as ``data``; 0 marks a missing value.
        w_input_decay, b_input_decay: decay weight and bias (broadcastable).

    Returns:
        Tensor shaped like ``data`` with the missing entries imputed.
    """
    n_traj, n_tp, n_dims = data.size()
    device = data.device

    cum_delta_ts = delta_ts.repeat(1, 1, n_dims)
    # detach() so this also works when data / mask require grad.
    missing_index = np.where(mask.detach().cpu().numpy() == 0)

    data_last_obsv = np.copy(data.detach().cpu().numpy())
    # np.where yields indices in row-major order, so for a fixed (i, k) the
    # time index j increases.  Both updates below rely on that ordering: gaps
    # accumulate over consecutive missing steps, and values are forward-filled.
    for i, j, k in zip(*missing_index):
        if j != 0 and j != (n_tp - 1):
            cum_delta_ts[i, j + 1, k] = cum_delta_ts[i, j + 1, k] + cum_delta_ts[i, j, k]
        if j != 0:
            data_last_obsv[i, j, k] = data_last_obsv[i, j - 1, k]  # last observation

    # Normalize to [0, 1].  When every gap is zero there is nothing to scale,
    # and dividing would produce 0/0 = NaN that poisons the whole output.
    max_delta = cum_delta_ts.max()
    if max_delta > 0:
        cum_delta_ts = cum_delta_ts / max_delta

    data_last_obsv = torch.as_tensor(data_last_obsv, dtype=data.dtype, device=device)

    zeros = torch.zeros([n_traj, n_tp, n_dims], device=device)
    # Clamp the decay exponent to [0, 1000] before exponentiating.
    decay = torch.exp(-torch.min(
        torch.max(zeros, w_input_decay * cum_delta_ts + b_input_decay),
        zeros + 1000))

    # Mean over the time axis (missing entries included, as stored in data).
    data_means = torch.mean(data, 1).unsqueeze(1)

    data_imputed = data * mask + (1 - mask) * (decay * data_last_obsv +
                                               (1 - decay) * data_means)
    return data_imputed
예제 #2
0
def compute_binary_CE_loss(label_predictions, mortality_label):
    """Binary cross-entropy between predicted logits and labels, ignoring NaN labels.

    Args:
        label_predictions: logits; a 1-D tensor is treated as a single sample,
            otherwise dim 0 indexes the n_traj_samples samples of z0 and the
            rest is flattened to one logit per label.
        mortality_label: labels (flattened); NaN entries are ignored.

    Returns:
        Scalar tensor: ``BCEWithLogitsLoss`` (mean) divided by n_traj_samples,
        or 0 when every label is NaN.
    """
    mortality_label = mortality_label.reshape(-1)

    if len(label_predictions.size()) == 1:
        label_predictions = label_predictions.unsqueeze(0)

    n_traj_samples = label_predictions.size(0)
    label_predictions = label_predictions.reshape(n_traj_samples, -1)

    # Boolean mask of labelled examples.  "~" rather than "1 -": subtracting a
    # bool tensor is an error in current PyTorch.
    idx_not_nan = ~torch.isnan(mortality_label)
    if not idx_not_nan.any():
        print("All are labels are NaNs!")
        return torch.zeros((), dtype=label_predictions.dtype,
                           device=mortality_label.device)

    label_predictions = label_predictions[:, idx_not_nan]
    mortality_label = mortality_label[idx_not_nan]

    if torch.sum(mortality_label == 0.) == 0 or torch.sum(
            mortality_label == 1.) == 0:
        print(
            "Warning: all examples in a batch belong to the same class -- please increase the batch size."
        )

    assert (not torch.isnan(label_predictions).any())
    assert (not torch.isnan(mortality_label).any())

    # For each trajectory, we get n_traj_samples samples from z0 -- compute loss on all of them
    mortality_label = mortality_label.repeat(n_traj_samples, 1)
    ce_loss = nn.BCEWithLogitsLoss()(label_predictions, mortality_label)

    # Extra division by n_traj_samples on top of the mean taken by
    # BCEWithLogitsLoss (kept for compatibility with existing training runs).
    ce_loss = ce_loss / n_traj_samples
    return ce_loss
예제 #3
0
def compute_masked_likelihood(mu, data, mask, likelihood_func):
    """Evaluate ``likelihood_func`` separately for every (sample, trajectory, dim).

    Computing the likelihood per patient and per attribute keeps patients with
    more measurements from dominating the objective.

    Args:
        mu, data, mask: tensors of shape [n_traj_samples, n_traj, n_tp, n_dims];
            non-zero mask entries mark observed values.
        likelihood_func: callable ``f(mu_masked, data_masked, indices=(i, k, j))``
            returning a scalar tensor for the observed values of one series.

    Returns:
        Tensor of shape [n_traj, n_traj_samples]: the per-dimension results
        averaged over n_dims.
    """
    n_traj_samples, n_traj, n_timepoints, n_dims = data.size()
    # masked_select expects a bool mask; uint8 masks are deprecated.
    mask = mask.bool()

    res = []
    for i in range(n_traj_samples):
        for k in range(n_traj):
            for j in range(n_dims):
                observed = mask[i, k, :, j]
                data_masked = torch.masked_select(data[i, k, :, j], observed)
                mu_masked = torch.masked_select(mu[i, k, :, j], observed)
                res.append(likelihood_func(mu_masked, data_masked, indices=(i, k, j)))

    # [n_traj_samples * n_traj * n_dims] -> [n_traj_samples, n_traj, n_dims]
    res = torch.stack(res, 0).to(data.device)
    res = res.reshape((n_traj_samples, n_traj, n_dims))
    # Mean (not sum) over the dimensions.
    res = torch.mean(res, -1)
    return res.transpose(0, 1)
예제 #4
0
def mse(mu, data, indices=None):
    """Mean squared error between ``mu`` and ``data``.

    ``indices`` is accepted so the signature matches the other likelihood
    functions; it is not used.  If the last axis of ``mu`` is empty, a scalar
    zero on ``data``'s device is returned instead of calling MSELoss.
    """
    if mu.size()[-1] <= 0:
        return torch.zeros([1]).to(get_device(data)).squeeze()
    return nn.MSELoss()(mu, data)
예제 #5
0
def compute_multiclass_CE_loss(label_predictions, true_label, mask):
    """Cross-entropy over the time points that have at least one observation.

    Args:
        label_predictions: logits of shape [n_traj_samples, n_traj, n_tp, n_dims]
            (a 3-D tensor is treated as a single sample).
        true_label: labels repeated once per sample along dim 0; presumably
            [n_traj, n_tp, n_dims] -- TODO confirm.  When both tensors have more
            than one column the labels are one-hot and converted with argmax.
        mask: observation mask summed over its last axis; a time point counts
            only if that sum is positive.  Presumably [n_traj, n_tp, n_features].

    Returns:
        Scalar tensor: mean cross-entropy over all labelled (sample, trajectory,
        time point) rows, or 0 when no time point is labelled.
    """
    if len(label_predictions.size()) == 3:
        label_predictions = label_predictions.unsqueeze(0)

    n_traj_samples, n_traj, n_tp, n_dims = label_predictions.size()
    n_rows = n_traj_samples * n_traj * n_tp

    # For each trajectory, we get n_traj_samples samples from z0 -- compute loss on all of them
    true_label = true_label.repeat(n_traj_samples, 1, 1)

    label_predictions = label_predictions.reshape(n_rows, n_dims)
    true_label = true_label.reshape(n_rows, n_dims)

    # Choose time points with at least one measurement, then repeat that
    # choice for every sample so it lines up with the flattened rows above
    # (sample-major, then trajectory, then time point).
    tp_mask = torch.sum(mask, -1) > 0
    row_mask = tp_mask.repeat(n_traj_samples, 1, 1).reshape(-1)

    if label_predictions.size(-1) > 1 and true_label.size(-1) > 1:
        assert label_predictions.size(-1) == true_label.size(-1)
        # targets are in one-hot encoding -- convert to indices
        _, true_label = true_label.max(-1)
    else:
        # A single column already holds the label.
        true_label = true_label.reshape(-1)

    pred_masked = label_predictions[row_mask]
    labels = true_label[row_mask]

    if labels.numel() == 0:
        # No labelled time point at all: nothing to average over.
        return torch.zeros((), dtype=label_predictions.dtype,
                           device=label_predictions.device)

    # Every row carries exactly one label, so the batch mean equals the mean of
    # the per-row losses.
    return nn.CrossEntropyLoss()(pred_masked, labels.long())
예제 #6
0
def poisson_log_likelihood(masked_log_lambdas, masked_data, indices,
                           int_lambdas):
    """Poisson-process log-likelihood of the observed points of one series.

    Returns the sum of the log-intensities at the observed points minus the
    integrated intensity ``int_lambdas[indices]``.  The result is not divided
    by the number of points.  With no observed points, a scalar zero on
    ``masked_data``'s device is returned.
    """
    has_points = masked_data.size()[-1] > 0
    if not has_points:
        return torch.zeros([1]).to(get_device(masked_data)).squeeze()
    total_log_rate = torch.sum(masked_log_lambdas)
    return total_log_rate - int_lambdas[indices]
예제 #7
0
def gaussian_log_likelihood(mu_2d, data_2d, obsrv_std, indices=None):
    """Gaussian log-likelihood of ``data_2d`` around ``mu_2d``, averaged per point.

    The last axis is treated as one event: the per-point log-densities are
    summed by ``Independent(..., 1)`` and then divided by the number of
    points.  ``obsrv_std`` is tiled along that axis.  ``indices`` is unused.
    With an empty last axis, a scalar zero on ``data_2d``'s device is returned.
    """
    n_points = mu_2d.size()[-1]
    if n_points <= 0:
        return torch.zeros([1]).to(get_device(data_2d)).squeeze()
    scale = obsrv_std.repeat(n_points)
    joint = Independent(Normal(loc=mu_2d, scale=scale), 1)
    return joint.log_prob(data_2d) / n_points
예제 #8
0
    def draw_all_plots_one_dim(self,
                               data_dict,
                               model,
                               plot_name="",
                               save=False,
                               experimentID=0.):
        """Redraw every panel of the 1-D visualisation figure.

        Panels: observations vs. mean reconstruction (with a std band over
        posterior samples) for the first few trajectories, samples from the
        prior (LatentODE only), a slice of the latent vector field, and the
        latent trajectories of the first posterior sample.

        Args:
            data_dict: batch dict; reads "tp_to_predict", "observed_data",
                "observed_tp" and "observed_mask".
            model: model exposing ``get_reconstruction``, ``diffeq_solver``
                and ``latent_dim``.
            plot_name: file name used when ``save`` is True.
            save: if True, write the figure to ``plots/<experimentID>/<plot_name>``.
            experimentID: sub-directory name under ``plots/``.
        """
        time_steps = data_dict["tp_to_predict"]

        observed_data = data_dict["observed_data"]
        observed_time_steps = data_dict["observed_tp"]
        observed_mask = data_dict["observed_mask"]

        device = get_device(time_steps)

        time_steps_to_predict = time_steps
        if isinstance(model, LatentODE):
            # Evaluate on a dense 100-point grid spanning the prediction window.
            time_steps_to_predict = utils.linspace_vector(
                time_steps[0], time_steps[-1], 100).to(device)

        reconstructions, info = model.get_reconstruction(time_steps_to_predict,
                                                         observed_data,
                                                         observed_time_steps,
                                                         mask=observed_mask,
                                                         n_traj_samples=10)

        # Plot at most 3 trajectories (one per trajectory axis), fewer if the
        # batch is smaller.
        n_traj_to_show = min(3, observed_data.size(0))
        data_for_plotting = observed_data[:n_traj_to_show]
        mask_for_plotting = observed_mask[:n_traj_to_show]
        # Mean and spread over the posterior samples (dim 0).
        reconstructions_for_plotting = reconstructions.mean(
            dim=0)[:n_traj_to_show]
        reconstr_std = reconstructions.std(dim=0)[:n_traj_to_show]

        dim_to_show = 0
        # The feature axis is the last one.  reconstructions has an extra
        # leading samples axis (presumably [n_traj_samples, n_traj, n_tp,
        # n_dims]), so [:, :, dim] would pick a time point -- index with "...".
        observed_vals = data_for_plotting[..., dim_to_show].detach().cpu().numpy()
        reconstr_vals = reconstructions[:, :n_traj_to_show, ...,
                                        dim_to_show].detach().cpu().numpy()
        max_y = max(observed_vals.max(), reconstr_vals.max())
        min_y = min(observed_vals.min(), reconstr_vals.min())

        ############################################
        # Observations and reconstructions for each shown trajectory
        cmap = plt.get_cmap('Set1')
        for traj_id in range(n_traj_to_show):
            ax = self.ax_traj[traj_id]
            # Observed points
            plot_trajectories(ax,
                              data_for_plotting[traj_id].unsqueeze(0),
                              observed_time_steps,
                              mask=mask_for_plotting[traj_id].unsqueeze(0),
                              min_y=min_y,
                              max_y=max_y,
                              marker='o',
                              linestyle='',
                              dim_to_show=dim_to_show,
                              color=cmap(2))
            # Mean reconstruction
            plot_trajectories(ax,
                              reconstructions_for_plotting[traj_id].unsqueeze(0),
                              time_steps_to_predict,
                              min_y=min_y,
                              max_y=max_y,
                              title="Sample {} (data space)".format(traj_id),
                              dim_to_show=dim_to_show,
                              add_to_plot=True,
                              marker='',
                              color=cmap(3),
                              linewidth=3)
            # Spread estimated over multiple samples from the approx. posterior
            plot_std(ax,
                     reconstructions_for_plotting[traj_id].unsqueeze(0),
                     reconstr_std[traj_id].unsqueeze(0),
                     time_steps_to_predict,
                     alpha=0.5,
                     color=cmap(3))
            self.set_plot_lims(ax, "traj_" + str(traj_id))

        ############################################
        # Trajectories sampled from the prior (LatentODE only)
        if isinstance(model, LatentODE):
            # Fixed seeds, presumably so the prior samples look the same on
            # every redraw.
            torch.manual_seed(1991)
            np.random.seed(1991)

            traj_from_prior = model.sample_traj_from_prior(
                time_steps_to_predict, n_traj_samples=3)
            # Since in this case n_traj = 1, n_traj_samples -- requested number of samples from the prior, squeeze n_traj dimension
            traj_from_prior = traj_from_prior.squeeze(1)

            plot_trajectories(self.ax_traj_from_prior,
                              traj_from_prior,
                              time_steps_to_predict,
                              marker='',
                              linewidth=3)
            self.ax_traj_from_prior.set_title(
                "Samples from prior (data space)", pad=20)

        ################################################
        # Vector field
        self.ax_vector_field.cla()
        plot_vector_field(self.ax_vector_field, model.diffeq_solver.ode_func,
                          model.latent_dim, device)
        self.ax_vector_field.set_title("Slice of vector field (latent space)",
                                       pad=20)
        self.set_plot_lims(self.ax_vector_field, "vector_field")

        ################################################
        # Latent trajectories of the first sample from the approx. posterior
        latent_traj = info["latent_traj"][0, :n_traj_to_show]

        self.ax_latent_traj.cla()
        cmap = plt.get_cmap('Accent')
        n_latent_dims = latent_traj.size(-1)

        custom_labels = {}
        for i in range(n_latent_dims):
            col = cmap(i)
            plot_trajectories(self.ax_latent_traj,
                              latent_traj,
                              time_steps_to_predict,
                              title="Latent trajectories z(t) (latent space)",
                              dim_to_show=i,
                              color=col,
                              marker='',
                              add_to_plot=True,
                              linewidth=3)
            custom_labels['dim ' + str(i)] = Line2D([0], [0], color=col)

        self.ax_latent_traj.set_ylabel("z")
        self.ax_latent_traj.set_title(
            "Latent trajectories z(t) (latent space)", pad=20)
        self.ax_latent_traj.legend(list(custom_labels.values()),
                                   list(custom_labels.keys()),
                                   loc='lower left')
        self.set_plot_lims(self.ax_latent_traj, "latent_traj")

        ################################################

        self.fig.tight_layout()
        plt.draw()

        if save:
            dirname = "plots/" + str(experimentID) + "/"
            os.makedirs(dirname, exist_ok=True)
            self.fig.savefig(dirname + plot_name)
예제 #9
0
    def draw_one_density_plot(self,
                              ax,
                              model,
                              data_dict,
                              traj_id,
                              multiply_by_poisson=False):
        """Draw p(x, z0) (blue) and q(z0 | x) (red) on a 2-D grid of z0 values.

        The first two latent dimensions are swept over [-5, 5]^2; any further
        latent dimensions are fixed at zero.  For trajectory ``traj_id`` the
        log joint is the prior log-density plus the masked Gaussian
        log-likelihood of "data_to_predict" (optionally plus the Poisson-process
        log-likelihood).  The encoder's approximate posterior is overlaid.

        Args:
            ax: matplotlib axes to draw on (cleared first).
            model: model exposing latent_dim, use_poisson_proc, diffeq_solver,
                decoder, z0_prior, obsrv_std and encoder_z0.
            data_dict: batch dict with "data_to_predict", "tp_to_predict",
                "mask_predicted_data", "observed_data", "observed_tp",
                "observed_mask".
            traj_id: index of the trajectory to visualise.
            multiply_by_poisson: add the Poisson-process term to the joint.

        Raises:
            ValueError: if ``multiply_by_poisson`` is set for a model without a
                Poisson process (the Poisson term would be undefined).
        """
        if multiply_by_poisson and not model.use_poisson_proc:
            raise ValueError(
                "multiply_by_poisson=True requires a model with use_poisson_proc")

        scale = 5
        cmap = add_white(plt.get_cmap('Blues', 9))
        cmap2 = add_white(plt.get_cmap('Reds', 9))

        data = data_dict["data_to_predict"]
        time_steps = data_dict["tp_to_predict"]
        mask = data_dict["mask_predicted_data"]

        observed_data = data_dict["observed_data"]
        observed_time_steps = data_dict["observed_tp"]
        observed_mask = data_dict["observed_mask"]

        npts = 50
        xx, yy, z0_grid = get_meshgrid(npts=npts,
                                       int_y1=(-scale, scale),
                                       int_y2=(-scale, scale))
        z0_grid = z0_grid.to(get_device(data))

        if model.latent_dim > 2:
            # Pad the remaining latent dims with zeros on the grid's device/dtype.
            z0_grid = torch.cat(
                (z0_grid,
                 z0_grid.new_zeros(z0_grid.size(0), model.latent_dim - 2)), 1)

        if model.use_poisson_proc:
            n_traj, n_dims = z0_grid.size()
            # append a vector of zeros to compute the integral of lambda and also zeros for the first point of lambda
            zeros = torch.zeros([n_traj, model.input_dim + model.latent_dim
                                 ]).to(get_device(data))
            z0_grid_aug = torch.cat((z0_grid, zeros), -1)
        else:
            z0_grid_aug = z0_grid

        # Shape of sol_y [n_traj_samples, n_samples, n_timepoints, n_latents]
        sol_y = model.diffeq_solver(z0_grid_aug.unsqueeze(0), time_steps)

        if model.use_poisson_proc:
            sol_y, log_lambda_y, int_lambda, _ = model.diffeq_solver.ode_func.extract_poisson_rate(
                sol_y)

            assert (torch.sum(int_lambda[:, :, 0, :]) == 0.)
            assert (torch.sum(int_lambda[0, 0, -1, :] <= 0) == 0.)

        pred_x = model.decoder(sol_y)

        # Density for one trajectory, replicated once per grid point
        one_traj = data[traj_id]
        mask_one_traj = None
        if mask is not None:
            mask_one_traj = mask[traj_id].unsqueeze(0)
            mask_one_traj = mask_one_traj.repeat(npts**2, 1, 1).unsqueeze(0)

        ax.cla()

        # Prior log-density, summed over the latent dimensions
        prior_density_grid = model.z0_prior.log_prob(
            z0_grid.unsqueeze(0)).squeeze(0)
        prior_density_grid = torch.sum(prior_density_grid, -1)

        # =================================================
        # log p(x | y(t0))
        masked_gaussian_log_density_grid = masked_gaussian_log_density(
            pred_x,
            one_traj.repeat(npts**2, 1, 1).unsqueeze(0),
            mask=mask_one_traj,
            obsrv_std=model.obsrv_std).squeeze(-1)

        # log p(t | y(t0))
        if model.use_poisson_proc:
            poisson_info = {}
            poisson_info["int_lambda"] = int_lambda[:, :, -1, :]
            poisson_info["log_lambda_y"] = log_lambda_y

            poisson_log_density_grid = compute_poisson_proc_likelihood(
                one_traj.repeat(npts**2, 1, 1).unsqueeze(0),
                pred_x,
                poisson_info,
                mask=mask_one_traj)
            poisson_log_density_grid = poisson_log_density_grid.squeeze(0)

        # =================================================
        # p(x, y(t0))
        log_joint_density = prior_density_grid + masked_gaussian_log_density_grid
        if multiply_by_poisson:
            log_joint_density = log_joint_density + poisson_log_density_grid

        density_grid = torch.exp(log_joint_density)
        density_grid = torch.reshape(density_grid, (xx.shape[0], xx.shape[1]))
        # detach(): the densities come from the model and may track gradients.
        density_grid = density_grid.detach().cpu().numpy()

        ax.contourf(xx, yy, density_grid, cmap=cmap, alpha=1)

        # =================================================
        # q(y(t0) | x)
        ax.set_xlabel('z1(t0)')
        ax.set_ylabel('z2(t0)')

        data_w_mask = observed_data[traj_id].unsqueeze(0)
        if observed_mask is not None:
            data_w_mask = torch.cat(
                (data_w_mask, observed_mask[traj_id].unsqueeze(0)), -1)
        z0_mu, z0_std = model.encoder_z0(data_w_mask, observed_time_steps)

        if model.use_poisson_proc:
            z0_mu = z0_mu[:, :, :model.latent_dim]
            z0_std = z0_std[:, :, :model.latent_dim]

        q_z0 = Normal(z0_mu, z0_std)

        q_density_grid = q_z0.log_prob(z0_grid)
        # Sum the log-density over the latent dimensions
        q_density_grid = torch.sum(q_density_grid, -1)
        density_grid = torch.exp(q_density_grid)

        density_grid = torch.reshape(density_grid, (xx.shape[0], xx.shape[1]))
        density_grid = density_grid.detach().cpu().numpy()

        ax.contourf(xx, yy, density_grid, cmap=cmap2, alpha=0.3)
예제 #10
0
def run_rnn(inputs,
            delta_ts,
            cell,
            first_hidden=None,
            mask=None,
            feed_previous=False,
            n_steps=0,
            decoder=None,
            input_decay_params=None,
            feed_previous_w_prob=0.,
            masked_update=True):
    """Unroll an RNN cell over a sequence, appending masks and time gaps to each input.

    Args:
        inputs: sequence tensor indexed per step as ``inputs[:, i]``; presumably
            [batch, n_tp, n_dims] -- TODO confirm.
        delta_ts: time gaps, indexed as ``delta_ts[:, i]`` and concatenated to
            each step's input.
        cell: RNN cell called as ``cell(input, hidden)``.  A GRUCellExpDecay
            receives cumulative gaps (from ``get_cum_delta_ts``) instead of
            ``delta_ts``.
        first_hidden: optional initial hidden state.  When given it is stored
            as the first returned state and one fewer step is run.
        mask: optional observation mask concatenated to each step's input.
            It is created as all ones when feeding back predictions without one.
        feed_previous: after step 0, feed ``decoder(hidden)`` instead of the data.
        n_steps: number of steps; 0 means ``inputs.size(1)``.
        decoder: maps a hidden state to an input.  Required when feeding back.
        input_decay_params: optional ``(w, b)``.  When given, inputs are first
            imputed with ``impute_using_input_decay``.
        feed_previous_w_prob: if > 0 (and ``feed_previous`` is False), each step
            after the first randomly picks decoder output or data.
        masked_update: keep the previous hidden state for rows whose mask is
            all zero at the current step.

    Returns:
        ``(hidden, all_hiddens)``: the last hidden state, and all states stacked
        along dim 0, permuted with ``(1, 0, 2)`` and given a leading axis of
        size 1 -- presumably [1, batch, n_states, hidden_dim].

    Raises:
        Exception: if feeding back predictions is requested without a decoder.
    """
    if (feed_previous or feed_previous_w_prob) and decoder is None:
        raise Exception(
            "feed_previous is set to True -- please specify RNN decoder")

    if n_steps == 0:
        n_steps = inputs.size(1)

    # Decoder outputs have no mask of their own: treat every feature as present.
    if (feed_previous or feed_previous_w_prob) and mask is None:
        mask = torch.ones(
            (inputs.size(0), n_steps, inputs.size(-1))).to(get_device(inputs))

    if isinstance(cell, GRUCellExpDecay):
        cum_delta_ts = get_cum_delta_ts(inputs, delta_ts, mask)

    if input_decay_params is not None:
        w_input_decay, b_input_decay = input_decay_params
        inputs = impute_using_input_decay(inputs, delta_ts, mask,
                                          w_input_decay, b_input_decay)

    all_hiddens = []
    hidden = first_hidden

    if hidden is not None:
        all_hiddens.append(hidden)
        # NOTE(review): the loop below still starts at index 0, so with a
        # first_hidden the last input step is never processed -- confirm intent.
        n_steps -= 1

    for i in range(n_steps):
        delta_t = delta_ts[:, i]
        # The first step always consumes the real input.
        if i == 0:
            rnn_input = inputs[:, i]
        elif feed_previous:
            rnn_input = decoder(hidden)
        elif feed_previous_w_prob > 0:
            # NOTE(review): the decoder output is used when uniform() > p, i.e.
            # with probability 1 - p -- confirm this matches the parameter name.
            feed_prev = np.random.uniform() > feed_previous_w_prob
            if feed_prev:
                rnn_input = decoder(hidden)
            else:
                rnn_input = inputs[:, i]
        else:
            rnn_input = inputs[:, i]

        if mask is not None:
            mask_i = mask[:, i, :]
            rnn_input = torch.cat((rnn_input, mask_i), -1)

        # Append the time information: cumulative gaps for the exp-decay cell,
        # raw gaps otherwise.
        if isinstance(cell, GRUCellExpDecay):
            cum_delta_t = cum_delta_ts[:, i]
            input_w_t = torch.cat((rnn_input, cum_delta_t), -1).squeeze(1)
        else:
            input_w_t = torch.cat((rnn_input, delta_t), -1).squeeze(1)

        prev_hidden = hidden
        hidden = cell(input_w_t, hidden)

        if masked_update and (mask is not None) and (prev_hidden is not None):
            # update only the hidden states for hidden state only if at least one feature is present for the current time point
            summed_mask = (torch.sum(mask_i, -1, keepdim=True) > 0).float()
            assert (not torch.isnan(summed_mask).any())
            hidden = summed_mask * hidden + (1 - summed_mask) * prev_hidden

        all_hiddens.append(hidden)

    all_hiddens = torch.stack(all_hiddens, 0)
    all_hiddens = all_hiddens.permute(1, 0, 2).unsqueeze(0)
    return hidden, all_hiddens
예제 #11
0
    def get_reconstruction(self,
                           time_steps_to_predict,
                           truth,
                           truth_time_steps,
                           mask=None,
                           n_traj_samples=1,
                           run_backwards=True,
                           mode=None):
        """Encode the observations, sample z0, solve the latent ODE and decode it.

        Args:
            time_steps_to_predict: time points at which the latent ODE is solved.
            truth: observed data.  ``mask`` (if given) is concatenated to it on
                the last axis before encoding.
            truth_time_steps: time points of ``truth``.
            mask: optional observation mask.
            n_traj_samples: number of z0 samples drawn per trajectory.
            run_backwards: forwarded to the encoder.
            mode: unused; kept for interface compatibility.

        Returns:
            ``(pred_x, info)``.  ``info`` holds:
            - "first_point": ``(mu, |std|, sampled z0)``
            - "latent_traj": the detached ODE solution
            - depending on the model flags, "int_lambda", "log_lambda_y" and
              "label_predictions".

        Raises:
            Exception: if the encoder is neither Encoder_z0_ODE_RNN nor
                Encoder_z0_RNN.
        """
        if not isinstance(self.encoder_z0, (Encoder_z0_ODE_RNN, Encoder_z0_RNN)):
            raise Exception("Unknown encoder type {}".format(
                type(self.encoder_z0).__name__))

        truth_w_mask = truth
        if mask is not None:
            truth_w_mask = torch.cat((truth, mask), -1)
        first_point_mu, first_point_std = self.encoder_z0(
            truth_w_mask, truth_time_steps, run_backwards=run_backwards)

        means_z0 = first_point_mu.repeat(n_traj_samples, 1, 1)
        sigma_z0 = first_point_std.repeat(n_traj_samples, 1, 1)
        first_point_enc = utils.sample_standard_gaussian(means_z0, sigma_z0)

        # Only the std reported in the returned info is made non-negative; the
        # sample above was drawn with the raw encoder output.
        first_point_std = first_point_std.abs()

        if self.use_poisson_proc:
            n_traj_samples, n_traj, n_dims = first_point_enc.size()
            # append a vector of zeros to compute the integral of lambda
            zeros = torch.zeros([n_traj_samples, n_traj,
                                 self.input_dim]).to(get_device(truth))
            first_point_enc_aug = torch.cat((first_point_enc, zeros), -1)
        else:
            first_point_enc_aug = first_point_enc

        assert (not torch.isnan(time_steps_to_predict).any())
        assert (not torch.isnan(first_point_enc).any())
        assert (not torch.isnan(first_point_enc_aug).any())

        # Shape of sol_y [n_traj_samples, n_samples, n_timepoints, n_latents]
        sol_y = self.diffeq_solver(first_point_enc_aug, time_steps_to_predict)

        if self.use_poisson_proc:
            sol_y, log_lambda_y, int_lambda, _ = self.diffeq_solver.ode_func.extract_poisson_rate(
                sol_y)

            assert (torch.sum(int_lambda[:, :, 0, :]) == 0.)
            assert (torch.sum(int_lambda[0, 0, -1, :] <= 0) == 0.)

        pred_x = self.decoder(sol_y)

        all_extra_info = {
            "first_point": (first_point_mu, first_point_std, first_point_enc),
            "latent_traj": sol_y.detach()
        }

        if self.use_poisson_proc:
            # integral of lambda from the last step of the ODE solver
            all_extra_info["int_lambda"] = int_lambda[:, :, -1, :]
            all_extra_info["log_lambda_y"] = log_lambda_y

        if self.use_binary_classif:
            if self.classif_per_tp:
                all_extra_info["label_predictions"] = self.classifier(sol_y)
            else:
                all_extra_info["label_predictions"] = self.classifier(
                    first_point_enc).squeeze(-1)

        return pred_x, all_extra_info
예제 #12
0
    def run_odernn(self,
                   data,
                   time_steps,
                   run_backwards=True,
                   save_info=False):
        """Run the ODE-RNN over ``data``.

        Between observation times the hidden state is evolved with
        ``z0_diffeq_solver``.  At each observation it is updated with
        ``GRU_update``.

        IMPORTANT: assumes that ``data`` already has the mask concatenated to it.

        Args:
            data: tensor of shape [n_traj, n_tp, n_dims].
            time_steps: 1-D tensor of the n_tp observation times (assumed
                increasing).
            run_backwards: if True, process time points from last to first;
                otherwise from first to last.
            save_info: if True, collect per-step diagnostics in ``extra_info``.

        Returns:
            ``(yi, yi_std, latent_ys, extra_info)``:
            - the final latent mean and std;
            - the per-step latent means, stacked along dim 1 in processing order;
            - the list of diagnostic dicts (empty unless ``save_info``).

        Raises:
            RuntimeError: if an ODE solution does not start at the previous
                hidden state.
        """
        n_traj, n_tp, n_dims = data.size()
        extra_info = []

        device = get_device(data)

        prev_y = torch.zeros((1, n_traj, self.latent_dim)).to(device)
        prev_std = torch.zeros((1, n_traj, self.latent_dim)).to(device)

        # Start a small step before the first processed time point, on the side
        # the iteration comes from.
        if run_backwards:
            prev_t = time_steps[-1] + 0.01
        else:
            prev_t = time_steps[0] - 0.01

        interval_length = time_steps[-1] - time_steps[0]
        minimum_step = interval_length / 50

        assert (not torch.isnan(data).any())
        assert (not torch.isnan(time_steps).any())

        latent_ys = []
        # Run the ODE between observations and combine the y(t) estimates using gating
        time_points_iter = range(0, len(time_steps))
        if run_backwards:
            time_points_iter = reversed(time_points_iter)

        for i in time_points_iter:
            t_i = time_steps[i]
            step = torch.abs(t_i - prev_t)
            # A zero-length interval gives minimum_step == 0; fall back to a
            # single Euler step instead of dividing by zero below.
            if minimum_step <= 0 or step < minimum_step:
                time_points = torch.stack((prev_t, t_i))
                inc = self.z0_diffeq_solver.ode_func(prev_t,
                                                     prev_y) * (t_i - prev_t)

                assert (not torch.isnan(inc).any())

                ode_sol = prev_y + inc
                ode_sol = torch.stack((prev_y, ode_sol), 2).to(device)

                assert (not torch.isnan(ode_sol).any())
            else:
                n_intermediate_tp = max(2, (step / minimum_step).int())

                time_points = utils.linspace_vector(prev_t, t_i,
                                                    n_intermediate_tp)
                ode_sol = self.z0_diffeq_solver(prev_y, time_points)

                assert (not torch.isnan(ode_sol).any())

            # Sanity check: the solution must start at the previous state.
            start_err = torch.mean(torch.abs(ode_sol[:, :, 0, :] - prev_y))
            if start_err >= 0.001:
                raise RuntimeError(
                    "Error: first point of the ODE is not equal to initial value "
                    "(mean abs difference {})".format(start_err.item()))

            yi_ode = ode_sol[:, :, -1, :]
            xi = data[:, i, :].unsqueeze(0)

            yi, yi_std = self.GRU_update(yi_ode, prev_std, xi)

            prev_y, prev_std = yi, yi_std
            prev_t = t_i

            latent_ys.append(yi)

            if save_info:
                d = {
                    "yi_ode": yi_ode.detach(),
                    "yi": yi.detach(),
                    "yi_std": yi_std.detach(),
                    "time_points": time_points.detach(),
                    "ode_sol": ode_sol.detach()
                }
                extra_info.append(d)

        latent_ys = torch.stack(latent_ys, 1)

        assert (not torch.isnan(yi).any())
        assert (not torch.isnan(yi_std).any())

        return yi, yi_std, latent_ys, extra_info