Example #1
def impute_using_input_decay(data, delta_ts, mask, w_input_decay,
                             b_input_decay):
    n_traj, n_tp, n_dims = data.size()
    cum_delta_ts = delta_ts.repeat(1, 1, n_dims)
    missing_index = np.where(mask.cpu().numpy() == 0)

    cum_delta_ts, missing_index, data_last_obsv = loop(
        cum_delta_ts.cpu().numpy(), missing_index, np.copy(data.cpu().numpy()),
        n_tp)

    cum_delta_ts = torch.tensor(cum_delta_ts).to(get_device(data))
    data_last_obsv = torch.Tensor(data_last_obsv).to(get_device(data))

    cum_delta_ts = cum_delta_ts / cum_delta_ts.max()  # normalize

    zeros = torch.zeros([n_traj, n_tp, n_dims]).to(get_device(data))
    decay = torch.exp(-torch.min(
        torch.max(zeros, w_input_decay * cum_delta_ts + b_input_decay), zeros +
        1000))

    data_means = torch.mean(data, 1).unsqueeze(1)

    data_imputed = data * mask + (1 - mask) * (decay * data_last_obsv +
                                               (1 - decay) * data_means)
    return data_imputed
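
The snippet above relies on a helper named loop that is not shown on this page. Judging from the inline version in the next listing, it most likely accumulates the time deltas across missing entries and carries the last observation forward; in the original it may be compiled (e.g. with numba) for speed. A minimal sketch under that assumption -- signature and behaviour are inferred, not taken from the source:

def loop(cum_delta_ts, missing_index, data_last_obsv, n_tp):
    # inferred behaviour: accumulate deltas over gaps and forward-fill the last observation
    rows, cols, dims = missing_index
    for idx in range(rows.shape[0]):
        i, j, k = rows[idx], cols[idx], dims[idx]
        if j != 0 and j != (n_tp - 1):
            cum_delta_ts[i, j + 1, k] += cum_delta_ts[i, j, k]
        if j != 0:
            data_last_obsv[i, j, k] = data_last_obsv[i, j - 1, k]  # last observation
    return cum_delta_ts, missing_index, data_last_obsv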
def impute_using_input_decay(data, delta_ts, mask, w_input_decay,
                             b_input_decay):
    n_traj, n_tp, n_dims = data.size()

    cum_delta_ts = delta_ts.repeat(1, 1, n_dims)
    missing_index = np.where(mask.cpu().numpy() == 0)

    data_last_obsv = np.copy(data.cpu().numpy())
    for idx in range(missing_index[0].shape[0]):
        i = missing_index[0][idx]
        j = missing_index[1][idx]
        k = missing_index[2][idx]

        if j != 0 and j != (n_tp - 1):
            cum_delta_ts[i, j + 1,
                         k] = cum_delta_ts[i, j + 1, k] + cum_delta_ts[i, j, k]
        if j != 0:
            data_last_obsv[i, j, k] = data_last_obsv[i, j - 1,
                                                     k]  # last observation
    cum_delta_ts = cum_delta_ts / cum_delta_ts.max()  # normalize

    data_last_obsv = torch.Tensor(data_last_obsv).to(get_device(data))

    zeros = torch.zeros([n_traj, n_tp, n_dims]).to(get_device(data))
    decay = torch.exp(-torch.min(
        torch.max(zeros, w_input_decay * cum_delta_ts + b_input_decay), zeros +
        1000))

    data_means = torch.mean(data, 1).unsqueeze(1)

    data_imputed = data * mask + (1 - mask) * (decay * data_last_obsv +
                                               (1 - decay) * data_means)
    return data_imputed
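
A hypothetical smoke test of impute_using_input_decay with random tensors, just to illustrate the expected shapes (the decay weights would normally be learned parameters, and the module's get_device helper is assumed to be importable alongside the function):

import torch

n_traj, n_tp, n_dims = 4, 10, 3
data = torch.randn(n_traj, n_tp, n_dims)
mask = (torch.rand(n_traj, n_tp, n_dims) > 0.3).float()   # 1 = observed, 0 = missing
delta_ts = torch.rand(n_traj, n_tp, 1)                    # time since the previous step
w_input_decay = torch.ones(n_dims)                        # learnable in the real model
b_input_decay = torch.zeros(n_dims)

imputed = impute_using_input_decay(data, delta_ts, mask, w_input_decay, b_input_decay)
print(imputed.shape)  # torch.Size([4, 10, 3])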
Example #3
def compute_masked_likelihood(mu, data, mask, likelihood_func):
    # Compute the likelihood per patient and per attribute so that we don't prioritize patients with more measurements
    n_traj_samples, n_traj, n_timepoints, n_dims = data.size()

    res = []
    for i in range(n_traj_samples):
        for k in range(n_traj):
            for j in range(n_dims):
                data_masked = torch.masked_select(data[i, k, :, j],
                                                  mask[i, k, :, j].bool())

                #assert(torch.sum(data_masked == 0.) < 10)

                mu_masked = torch.masked_select(mu[i, k, :, j], mask[i, k, :,
                                                                     j].bool())
                log_prob = likelihood_func(mu_masked,
                                           data_masked,
                                           indices=(i, k, j))
                res.append(log_prob)
    # stacked shape: [n_traj_samples * n_traj * n_dims]

    res = torch.stack(res, 0).to(get_device(data))
    res = res.reshape((n_traj_samples, n_traj, n_dims))
    # Take mean over the number of dimensions
    res = torch.mean(res, -1)  # !!!!!!!!!!! changed from sum to mean
    res = res.transpose(0, 1)
    return res
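
For illustration, a hypothetical call that plugs a toy per-slice likelihood into compute_masked_likelihood (the real code passes gaussian_log_likelihood or mse from the later examples). The only requirements on likelihood_func are the (mu, data, indices) signature and a scalar return value; the module's get_device helper is assumed to be available:

import torch

def toy_log_likelihood(mu_masked, data_masked, indices=None):
    # stand-in for a log-likelihood: higher is better; guard the all-missing case
    if mu_masked.numel() == 0:
        return torch.zeros(())
    return -torch.mean((mu_masked - data_masked) ** 2)

n_traj_samples, n_traj, n_tp, n_dims = 2, 3, 5, 4
mu = torch.randn(n_traj_samples, n_traj, n_tp, n_dims)
data = torch.randn(n_traj_samples, n_traj, n_tp, n_dims)
mask = (torch.rand(n_traj_samples, n_traj, n_tp, n_dims) > 0.3).float()

res = compute_masked_likelihood(mu, data, mask, toy_log_likelihood)
print(res.shape)  # torch.Size([3, 2]) -- [n_traj, n_traj_samples] after the transpose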
Example #4
def compute_binary_CE_loss(label_predictions, mortality_label):
    #print("Computing binary classification loss: compute_CE_loss")

    mortality_label = mortality_label.reshape(-1)

    if len(label_predictions.size()) == 1:
        label_predictions = label_predictions.unsqueeze(0)

    n_traj_samples = label_predictions.size(0)
    label_predictions = label_predictions.reshape(n_traj_samples, -1)

    idx_not_nan = ~torch.isnan(mortality_label)
    if torch.sum(idx_not_nan) == 0:
        print("All labels are NaNs!")
        ce_loss = torch.tensor(0.).to(get_device(mortality_label))
        return ce_loss

    label_predictions = label_predictions[:, idx_not_nan]
    mortality_label = mortality_label[idx_not_nan]

    if torch.sum(mortality_label == 0.) == 0 or torch.sum(
            mortality_label == 1.) == 0:
        print(
            "Warning: all examples in a batch belong to the same class -- please increase the batch size."
        )

    assert (not torch.isnan(label_predictions).any())
    assert (not torch.isnan(mortality_label).any())

    # For each trajectory, we get n_traj_samples samples from z0 -- compute loss on all of them
    mortality_label = mortality_label.repeat(n_traj_samples, 1)
    ce_loss = nn.BCEWithLogitsLoss()(label_predictions, mortality_label)

    # divide by number of patients in a batch
    ce_loss = ce_loss / n_traj_samples
    return ce_loss
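
A hypothetical smoke test with random logits and one NaN label, illustrating the shapes the function expects (n_traj_samples prediction sets, one binary label per trajectory):

import torch

n_traj_samples, n_traj = 3, 8
label_predictions = torch.randn(n_traj_samples, n_traj)   # raw logits
mortality_label = torch.randint(0, 2, (n_traj,)).float()
mortality_label[0] = float('nan')                         # one missing label is tolerated

loss = compute_binary_CE_loss(label_predictions, mortality_label)
print(loss.item())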
Example #5
def mse(mu, data, indices=None):
    n_data_points = mu.size()[-1]

    if n_data_points > 0:
        mse = nn.MSELoss()(mu, data)
    else:
        mse = torch.zeros([1]).to(get_device(data)).squeeze()
    return mse
def compute_multiclass_CE_loss(label_predictions, true_label, mask):
    #print("Computing multi-class classification loss: compute_multiclass_CE_loss")

    if (len(label_predictions.size()) == 3):
        label_predictions = label_predictions.unsqueeze(0)

    n_traj_samples, n_traj, n_tp, n_dims = label_predictions.size()

    # assert(not torch.isnan(label_predictions).any())
    # assert(not torch.isnan(true_label).any())

    # For each trajectory, we get n_traj_samples samples from z0 -- compute loss on all of them
    #true_label = true_label[:,-1,:]
    true_label = true_label.repeat(n_traj_samples, 1, 1)

    label_predictions = label_predictions.reshape(
        n_traj_samples * n_traj * n_tp, n_dims)
    true_label = true_label.reshape(n_traj_samples * n_traj * n_tp,
                                    n_dims)  # 1 for meld, n_dims for others

    # choose time points with at least one measurement
    mask = torch.sum(mask, -1) > 0

    # repeat the mask for each label to mark that the label for this time point is present
    #mask = mask[:,-1].reshape(-1,1)
    pred_mask = mask.repeat(n_dims, 1, 1).permute(1, 2, 0)

    label_mask = mask
    pred_mask = pred_mask.repeat(n_traj_samples, 1, 1, 1)
    label_mask = label_mask.repeat(n_traj_samples, 1, 1, 1)

    pred_mask = pred_mask.reshape(n_traj_samples * n_traj * n_tp, n_dims)
    label_mask = label_mask.reshape(n_traj_samples * n_traj * n_tp, 1)

    if (label_predictions.size(-1) > 1) and (true_label.size(-1) > 1):
        assert (label_predictions.size(-1) == true_label.size(-1))
        # targets are in one-hot encoding -- convert to indices
        _, true_label = true_label.max(-1)

    res = []
    for i in range(true_label.size(0)):
        pred_masked = torch.masked_select(label_predictions[i],
                                          pred_mask[i].bool())
        labels = torch.masked_select(true_label[i], label_mask[i].bool())

        pred_masked = pred_masked.reshape(-1, n_dims)

        if (len(labels) == 0):
            continue

        ce_loss = nn.CrossEntropyLoss()(pred_masked, labels.long())
        res.append(ce_loss)

    ce_loss = torch.stack(res, 0).to(get_device(label_predictions))
    ce_loss = torch.mean(ce_loss)
    # # divide by number of patients in a batch
    # ce_loss = ce_loss / n_traj_samples
    return ce_loss
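
A hypothetical call of the multi-class version with per-time-point logits, one-hot labels, and a random observation mask (the module's get_device helper is assumed to be importable):

import torch
import torch.nn as nn

n_traj, n_tp, n_classes, n_obs_dims = 4, 6, 3, 5
label_predictions = torch.randn(n_traj, n_tp, n_classes)     # per-time-point logits
true_label = nn.functional.one_hot(
    torch.randint(0, n_classes, (n_traj, n_tp)), n_classes).float()
mask = (torch.rand(n_traj, n_tp, n_obs_dims) > 0.5).float()  # data observation mask

loss = compute_multiclass_CE_loss(label_predictions, true_label, mask)
print(loss.item())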
Example #7
def gaussian_log_likelihood(mu_2d, data_2d, obsrv_std, indices = None):
	n_data_points = mu_2d.size()[-1]

	if n_data_points > 0:
		gaussian = Independent(Normal(loc = mu_2d, scale = obsrv_std.repeat(n_data_points)), 1)
		log_prob = gaussian.log_prob(data_2d) 
		log_prob = log_prob / n_data_points 
	else:
		log_prob = torch.zeros([1]).to(get_device(data_2d)).squeeze()
	return log_prob
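
A small, self-contained illustration of the distribution object used above: Independent(Normal(mu, std), 1) treats the last dimension as an event dimension, so log_prob returns one value per row rather than per element:

import torch
from torch.distributions import Independent, Normal

mu_2d = torch.randn(5, 7)                  # [n_rows, n_data_points]
data_2d = mu_2d + 0.1 * torch.randn(5, 7)
obsrv_std = torch.tensor([0.1])

gaussian = Independent(Normal(loc=mu_2d, scale=obsrv_std.repeat(7)), 1)
print(gaussian.log_prob(data_2d).shape)    # torch.Size([5]) -- one log-density per row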
Example #8
def poisson_log_likelihood(masked_log_lambdas, masked_data, indices, int_lambdas):
	# masked_log_lambdas and masked_data 
	n_data_points = masked_data.size()[-1]

	if n_data_points > 0:
		log_prob = torch.sum(masked_log_lambdas) - int_lambdas[indices]
		#log_prob = log_prob / n_data_points
	else:
		log_prob = torch.zeros([1]).to(get_device(masked_data)).squeeze()
	return log_prob
Example #9
def get_cum_delta_ts(data, delta_ts, mask):
    n_traj, n_tp, n_dims = data.size()

    cum_delta_ts = delta_ts.repeat(1, 1, n_dims).cpu().numpy()
    missing_index = np.where(mask.cpu().numpy() == 0)

    cum_delta_ts = loop_delta(missing_index, cum_delta_ts, n_tp)
    cum_delta_ts = torch.tensor(cum_delta_ts).to(get_device(data))
    cum_delta_ts = cum_delta_ts / cum_delta_ts.max()  # normalize

    return cum_delta_ts
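
loop_delta is another helper that is not listed on this page; based on the inline accumulation loop in the second impute_using_input_decay listing above, it presumably adds the delta of each missing step to the following one. A rough sketch under that assumption:

def loop_delta(missing_index, cum_delta_ts, n_tp):
    # inferred behaviour: propagate time deltas across missing measurements
    rows, cols, dims = missing_index
    for idx in range(rows.shape[0]):
        i, j, k = rows[idx], cols[idx], dims[idx]
        if j != 0 and j != (n_tp - 1):
            cum_delta_ts[i, j + 1, k] += cum_delta_ts[i, j, k]
    return cum_delta_ts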
Example #10
    def get_ode_gradient_nn(self, t_local, y):
        if self.interpolation is None:
            raise Exception(
                'Derivative of spline interpolation should be specified before evaluating the CDE func'
            )

        # region Following the third method in "Neural Controlled Differential Equations"
        f_theta = self.cde_func(y).reshape(y.shape[:-1] + (self.latent_dim,
                                                           self.input_dim + 1))
        dx_dt = self.interpolation.derivative(t_local)
        dxt_dt = torch.cat((dx_dt, torch.tensor(1.0).to(
            utils.get_device(y)).repeat(dx_dt.shape[:-1] + (1, ))),
                           dim=-1).unsqueeze(-1)
        if len(f_theta.shape) == 4:
            dxt_dt = dxt_dt.repeat((f_theta.shape[0], 1, 1, 1))
        return (f_theta @ dxt_dt).squeeze(-1)
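
get_ode_gradient_nn assumes that self.interpolation has already been set; Example #12 below does this through a splines_setup call on the ODE function. That setter is not listed here, but from the way it is used it is presumably just a stored reference, roughly (hypothetical sketch):

    def splines_setup(self, interpolation):
        # keep a handle on the NaturalCubicSpline so that its derivative can be
        # queried inside the CDE vector field at every solver step
        self.interpolation = interpolation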
Example #11
def compute_binary_CE_loss(label_predictions, mortality_label):
	#print("Computing binary classification loss: compute_CE_loss")

	mortality_label = mortality_label.reshape(-1)

	if len(label_predictions.size()) == 1:
		label_predictions = label_predictions.unsqueeze(0)
 
	n_traj_samples = label_predictions.size(0)
	label_predictions = label_predictions.reshape(n_traj_samples, -1)
	
	""" Nando's comment:
	The following line caused a RunTime error: Subtraction, the `-` operator, with a bool tensor
	is not supported. If you are trying to invert a mask, use the `~` or `logical_not()` operator instead.
	idx_not_nan = 1 - torch.isnan(mortality_label) but I just used ~
	
	"""
	idx_not_nan = ~torch.isnan(mortality_label)
	if torch.sum(idx_not_nan) == 0:
		print("All labels are NaNs!")
		ce_loss = torch.tensor(0.).to(get_device(mortality_label))
		return ce_loss

	label_predictions = label_predictions[:,idx_not_nan]
	mortality_label = mortality_label[idx_not_nan]

	if torch.sum(mortality_label == 0.) == 0 or torch.sum(mortality_label == 1.) == 0:
		print("Warning: all examples in a batch belong to the same class -- please increase the batch size.")

	assert(not torch.isnan(label_predictions).any())
	assert(not torch.isnan(mortality_label).any())

	# For each trajectory, we get n_traj_samples samples from z0 -- compute loss on all of them
	mortality_label = mortality_label.repeat(n_traj_samples, 1)
	ce_loss = nn.BCEWithLogitsLoss()(label_predictions, mortality_label)

	# divide by number of patients in a batch
	ce_loss = ce_loss / n_traj_samples
	return ce_loss
Example #12
    def get_reconstruction(self,
                           time_steps_to_predict,
                           truth,
                           truth_time_steps,
                           mask=None,
                           n_traj_samples=1,
                           run_backwards=True,
                           mode='interp'):
        """

        :param time_steps_to_predict:
        :param truth:
        :param truth_time_steps:
        :param mask:
        :param n_traj_samples:
        :param run_backwards:
        :param mode: extrap or interp
        :return:
        """

        if isinstance(self.encoder_z0, Encoder_z0_ODE_RNN) or \
                isinstance(self.encoder_z0, Encoder_z0_RNN):

            truth_w_mask = truth
            if mask is not None:
                truth_w_mask = torch.cat((truth, mask), -1)
            first_point_mu, first_point_std = self.encoder_z0(
                truth_w_mask, truth_time_steps, run_backwards=run_backwards)

            means_z0 = first_point_mu.repeat(n_traj_samples, 1, 1)
            sigma_z0 = first_point_std.repeat(n_traj_samples, 1, 1)
            first_point_enc = utils.sample_standard_gaussian(
                means_z0, sigma_z0)

        else:
            raise Exception("Unknown encoder type {}".format(
                type(self.encoder_z0).__name__))

        first_point_std = first_point_std.abs()
        assert (torch.sum(first_point_std < 0) == 0.)

        n_traj_samples, n_traj, n_dims = first_point_enc.size()
        if self.use_poisson_proc:
            # append a vector of zeros to compute the integral of lambda
            zeros = torch.zeros([n_traj_samples, n_traj,
                                 self.input_dim]).to(utils.get_device(truth))
            first_point_enc_aug = torch.cat((first_point_enc, zeros), -1)
            means_z0_aug = torch.cat((means_z0, zeros), -1)
        else:
            first_point_enc_aug = first_point_enc
            means_z0_aug = means_z0

        assert (not torch.isnan(time_steps_to_predict).any())
        assert (not torch.isnan(first_point_enc).any())
        assert (not torch.isnan(first_point_enc_aug).any())

        # Add log prob part of prior in initial state in posterior ODE
        cum_log_prob_prior_0 = torch.zeros([n_traj_samples, n_traj,
                                            1]).to(utils.get_device(truth))
        first_point_enc_aug = torch.cat(
            (first_point_enc_aug, cum_log_prob_prior_0), -1)

        # region make cubic spline interpolation
        truth_for_interpolate = truth.clone()
        if mask is not None:  # NaturalCubicSpline requires NaNs to tag unknown positions in the data series.
            truth_for_interpolate[mask == 0] = torch.tensor(float('nan')).to(
                utils.get_device(truth))

        coeffs = natural_cubic_spline_coeffs(truth_time_steps,
                                             truth_for_interpolate)
        interpolation = NaturalCubicSpline(truth_time_steps, coeffs)
        self.diffeq_solver.ode_func.splines_setup(interpolation)
        # endregion

        if mode == 'interp':
            sol_y_with_logprob = self.diffeq_solver(first_point_enc_aug,
                                                    time_steps_to_predict)
            sol_y = sol_y_with_logprob[..., :-1]

            # TODO: we multiply by const_for_prior_logp here because the density values from the
            # stochastic-process prior are numerically unstable; the ODE is integrated with the
            # log-prob divided by const_for_prior_logp, so it has to be multiplied back afterwards.
            prior_log_prob = sol_y_with_logprob[
                ..., -1] * self.diffeq_solver.ode_func.const_for_prior_logp

        elif mode == 'extrap':
            sol_y_with_logprob = self.diffeq_solver(first_point_enc_aug,
                                                    truth_time_steps)
            sol_y = sol_y_with_logprob[..., :-1]
            # TODO: use the Gaussian prior to make autoregressive predictions over
            # time_steps_to_predict, starting from sol_y[:, :, -1]:
            # TODO: Option 1: solve the ODE directly from the mean -- but then it is no longer sampling.
            # TODO: Option 2: local linearization (first-order Taylor expansion): at time t, compute
            # the predictive distribution at t + dt and sample from it. For the exact formulas see
            # Chen, F., Agüero, J. C., Gilson, M., Garnier, H., & Liu, T. (2017).
            # EM-based identification of continuous-time ARMA models from irregularly sampled data.
            # Automatica, 77, 293–301. https://doi.org/10.1016/j.automatica.2016.11.020, Eqs. (6)-(8).
            raise NotImplementedError

        if self.use_poisson_proc:
            sol_y, log_lambda_y, int_lambda, _ = self.diffeq_solver.ode_func.extract_poisson_rate(
                sol_y)

            assert (torch.sum(int_lambda[:, :, 0, :]) == 0.)
            assert (torch.sum(int_lambda[0, 0, -1, :] <= 0) == 0.)

        pred_x = self.decoder(sol_y)

        all_extra_info = {
            "first_point": (first_point_mu, first_point_std, first_point_enc),
            "latent_traj": sol_y.detach(),
            'prior_log_prob': prior_log_prob
        }

        if self.use_poisson_proc:
            # integral of lambda from the last step of the ODE solver
            all_extra_info["int_lambda"] = int_lambda[:, :, -1, :]
            all_extra_info["log_lambda_y"] = log_lambda_y

        if self.use_binary_classif:
            if self.classif_per_tp:
                all_extra_info["label_predictions"] = self.classifier(sol_y)
            else:
                all_extra_info["label_predictions"] = self.classifier(
                    first_point_enc).squeeze(-1)

        return pred_x, all_extra_info
Example #13
    def compute_all_losses(self,
                           batch_dict,
                           n_tp_to_sample=None,
                           n_traj_samples=1,
                           kl_coef=1.):

        # Condition on subsampled points
        # Make predictions for all the points
        pred_x, info = self.get_reconstruction(
            batch_dict["tp_to_predict"],
            batch_dict["observed_data"],
            batch_dict["observed_tp"],
            mask=batch_dict["observed_mask"],
            n_traj_samples=n_traj_samples,
            mode=batch_dict["mode"])

        # Compute likelihood of all the points
        likelihood = self.get_gaussian_likelihood(
            batch_dict["data_to_predict"],
            pred_x,
            mask=batch_dict["mask_predicted_data"])

        mse = self.get_mse(batch_dict["data_to_predict"],
                           pred_x,
                           mask=batch_dict["mask_predicted_data"])

        ################################
        # Compute CE loss for binary classification on Physionet
        # Use only last attribute -- mortality in the hospital
        device = utils.get_device(batch_dict["data_to_predict"])
        ce_loss = torch.Tensor([0.]).to(device)

        if (batch_dict["labels"] is not None) and self.use_binary_classif:
            if (batch_dict["labels"].size(-1) == 1) or (len(
                    batch_dict["labels"].size()) == 1):
                ce_loss = compute_binary_CE_loss(info["label_predictions"],
                                                 batch_dict["labels"])
            else:
                ce_loss = compute_multiclass_CE_loss(
                    info["label_predictions"],
                    batch_dict["labels"],
                    mask=batch_dict["mask_predicted_data"])

            if torch.isnan(ce_loss):
                print("label pred")
                print(info["label_predictions"])
                print("labels")
                print(batch_dict["labels"])
                raise Exception("CE loss is Nan!")

        pois_log_likelihood = torch.Tensor([0.]).to(
            utils.get_device(batch_dict["data_to_predict"]))
        if self.use_poisson_proc:
            pois_log_likelihood = compute_poisson_proc_likelihood(
                batch_dict["data_to_predict"],
                pred_x,
                info,
                mask=batch_dict["mask_predicted_data"])
            # Take mean over n_traj
            pois_log_likelihood = torch.mean(pois_log_likelihood, 1)

        loss = -torch.mean(likelihood)

        if self.use_poisson_proc:
            loss = loss - 0.1 * pois_log_likelihood

        if self.use_binary_classif:
            if self.train_classif_w_reconstr:
                loss = loss + ce_loss * 100
            else:
                loss = ce_loss

        # Take mean over the number of samples in a batch
        results = {}
        results["loss"] = torch.mean(loss)
        results["likelihood"] = torch.mean(likelihood).detach()
        results["mse"] = torch.mean(mse).detach()
        results["pois_likelihood"] = torch.mean(pois_log_likelihood).detach()
        results["ce_loss"] = torch.mean(ce_loss).detach()
        results["kl"] = 0.
        results["kl_first_p"] = 0.
        results["std_first_p"] = 0.

        if batch_dict["labels"] is not None and self.use_binary_classif:
            results["label_predictions"] = info["label_predictions"].detach()
        return results
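
compute_all_losses also calls self.get_gaussian_likelihood and self.get_mse, which are not listed on this page. In this code base they typically wrap the masked-likelihood helper from Example #3 around gaussian_log_likelihood (Example #7) or mse (Example #5). A rough sketch of such wrappers, under that assumption rather than the repo's exact code:

    def get_gaussian_likelihood(self, truth, pred_x, mask=None):
        # truth: [n_traj, n_tp, n_dims]; pred_x: [n_traj_samples, n_traj, n_tp, n_dims]
        truth_repeated = truth.repeat(pred_x.size(0), 1, 1, 1)
        mask = torch.ones_like(truth_repeated) if mask is None \
            else mask.repeat(pred_x.size(0), 1, 1, 1)
        func = lambda mu, data, indices=None: gaussian_log_likelihood(
            mu, data, obsrv_std=self.obsrv_std)
        return compute_masked_likelihood(pred_x, truth_repeated, mask, func)

    def get_mse(self, truth, pred_x, mask=None):
        # same masking logic, but with the per-dimension MSE from Example #5
        truth_repeated = truth.repeat(pred_x.size(0), 1, 1, 1)
        mask = torch.ones_like(truth_repeated) if mask is None \
            else mask.repeat(pred_x.size(0), 1, 1, 1)
        return compute_masked_likelihood(pred_x, truth_repeated, mask, mse)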
Example #14
    def run_odernn(self,
                   data,
                   time_steps,
                   run_backwards=True,
                   save_info=False):
        # IMPORTANT: assumes that 'data' already has mask concatenated to it

        n_traj, n_tp, n_dims = data.size()
        extra_info = []

        t0 = time_steps[-1]
        if run_backwards:
            t0 = time_steps[0]

        device = get_device(data)

        prev_y = torch.zeros((1, n_traj, self.latent_dim)).to(device)
        prev_std = torch.zeros((1, n_traj, self.latent_dim)).to(device)

        prev_t, t_i = time_steps[-1] + 0.01, time_steps[-1]

        interval_length = time_steps[-1] - time_steps[0]
        minimum_step = interval_length / 50

        # print("minimum step: {}".format(minimum_step))

        assert (not torch.isnan(data).any())
        assert (not torch.isnan(time_steps).any())

        latent_ys = []
        # Run ODE backwards and combine the y(t) estimates using gating
        time_points_iter = range(0, len(time_steps))
        if run_backwards:
            time_points_iter = reversed(time_points_iter)

        for i in time_points_iter:
            if (prev_t - t_i) < minimum_step:
                time_points = torch.stack((prev_t, t_i))
                inc = self.z0_diffeq_solver.ode_func(prev_t,
                                                     prev_y) * (t_i - prev_t)

                assert (not torch.isnan(inc).any())

                ode_sol = prev_y + inc
                ode_sol = torch.stack((prev_y, ode_sol), 2).to(device)

                assert (not torch.isnan(ode_sol).any())
            else:
                n_intermediate_tp = max(2,
                                        ((prev_t - t_i) / minimum_step).int())

                time_points = utils.linspace_vector(prev_t, t_i,
                                                    n_intermediate_tp)
                ode_sol = self.z0_diffeq_solver(prev_y, time_points)

                assert (not torch.isnan(ode_sol).any())

            if torch.mean(ode_sol[:, :, 0, :] - prev_y) >= 0.001:
                print(
                    "Error: first point of the ODE is not equal to initial value"
                )
                print(torch.mean(ode_sol[:, :, 0, :] - prev_y))
                exit()
            # assert(torch.mean(ode_sol[:, :, 0, :]  - prev_y) < 0.001)

            yi_ode = ode_sol[:, :, -1, :]
            xi = data[:, i, :].unsqueeze(0)

            yi, yi_std = self.GRU_update(yi_ode, prev_std, xi)

            prev_y, prev_std = yi, yi_std
            prev_t, t_i = time_steps[i], time_steps[i - 1]

            latent_ys.append(yi)

            if save_info:
                d = {
                    "yi_ode": yi_ode.detach(),  # "yi_from_data": yi_from_data,
                    "yi": yi.detach(),
                    "yi_std": yi_std.detach(),
                    "time_points": time_points.detach(),
                    "ode_sol": ode_sol.detach()
                }
                extra_info.append(d)

        latent_ys = torch.stack(latent_ys, 1)

        assert (not torch.isnan(yi).any())
        assert (not torch.isnan(yi_std).any())

        return yi, yi_std, latent_ys, extra_info
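
utils.linspace_vector belongs to the surrounding code base and is not reproduced here; for scalar start/end tensors it behaves like torch.linspace. A minimal stand-in under that assumption (the repo's version also supports vector-valued endpoints):

import torch

def linspace_vector(start, end, n_points):
    # scalar-only sketch of the helper used by run_odernn
    return torch.linspace(start.item(), end.item(), int(n_points))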
Example #15
    def get_reconstruction(self,
                           time_steps_to_predict,
                           truth,
                           truth_time_steps,
                           mask=None,
                           n_traj_samples=1,
                           run_backwards=True,
                           mode=None):

        if isinstance(self.encoder_z0, Encoder_z0_ODE_RNN) or \
         isinstance(self.encoder_z0, Encoder_z0_RNN) or \
         isinstance(self.encoder_z0, TransformerLayer):

            truth_w_mask = truth
            if mask is not None:
                truth_w_mask = torch.cat((truth, mask), -1)
            first_point_mu, first_point_std = self.encoder_z0(
                truth_w_mask, truth_time_steps, run_backwards=run_backwards)

            means_z0 = first_point_mu.repeat(n_traj_samples, 1, 1)
            sigma_z0 = first_point_std.repeat(n_traj_samples, 1, 1)
            first_point_enc = utils.sample_standard_gaussian(
                means_z0, sigma_z0)

        else:
            raise Exception("Unknown encoder type {}".format(
                type(self.encoder_z0).__name__))

        first_point_std = first_point_std.abs()
        assert (torch.sum(first_point_std < 0) == 0.)

        if self.use_poisson_proc:
            n_traj_samples, n_traj, n_dims = first_point_enc.size()
            # append a vector of zeros to compute the integral of lambda
            zeros = torch.zeros([n_traj_samples, n_traj,
                                 self.input_dim]).to(get_device(truth))
            first_point_enc_aug = torch.cat((first_point_enc, zeros), -1)
            means_z0_aug = torch.cat((means_z0, zeros), -1)
        else:
            first_point_enc_aug = first_point_enc
            means_z0_aug = means_z0

        assert (not torch.isnan(time_steps_to_predict).any())
        assert (not torch.isnan(first_point_enc).any())
        assert (not torch.isnan(first_point_enc_aug).any())

        # Shape of sol_y [n_traj_samples, n_samples, n_timepoints, n_latents]
        sol_y = self.diffeq_solver(first_point_enc_aug, time_steps_to_predict)

        if self.use_poisson_proc:
            sol_y, log_lambda_y, int_lambda, _ = self.diffeq_solver.ode_func.extract_poisson_rate(
                sol_y)

            assert (torch.sum(int_lambda[:, :, 0, :]) == 0.)
            assert (torch.sum(int_lambda[0, 0, -1, :] <= 0) == 0.)

        pred_x = self.decoder(sol_y)

        all_extra_info = {
            "first_point": (first_point_mu, first_point_std, first_point_enc),
            "latent_traj": sol_y.detach()
        }

        if self.use_poisson_proc:
            # integral of lambda from the last step of the ODE solver
            all_extra_info["int_lambda"] = int_lambda[:, :, -1, :]
            all_extra_info["log_lambda_y"] = log_lambda_y

        if self.use_binary_classif:
            if self.classif_per_tp:
                all_extra_info["label_predictions"] = self.classifier(sol_y)
            else:
                all_extra_info["label_predictions"] = self.classifier(
                    first_point_enc).squeeze(-1)

        return pred_x, all_extra_info
Example #16
    def draw_all_plots_one_dim(self,
                               data_dict,
                               model,
                               plot_name="",
                               save=False,
                               experimentID=0.):

        data = data_dict["data_to_predict"]
        time_steps = data_dict["tp_to_predict"]
        mask = data_dict["mask_predicted_data"]

        observed_data = data_dict["observed_data"]
        observed_time_steps = data_dict["observed_tp"]
        observed_mask = data_dict["observed_mask"]

        device = get_device(time_steps)

        time_steps_to_predict = time_steps
        if isinstance(model, LatentODE):
            # sample at the original time points
            time_steps_to_predict = utils.linspace_vector(
                time_steps[0], time_steps[-1], 100).to(device)

        reconstructions, info = model.get_reconstruction(time_steps_to_predict,
                                                         observed_data,
                                                         observed_time_steps,
                                                         mask=observed_mask,
                                                         n_traj_samples=10)

        n_traj_to_show = 3
        # plot only the first n_traj_to_show trajectories
        data_for_plotting = observed_data[:n_traj_to_show]
        mask_for_plotting = observed_mask[:n_traj_to_show]
        reconstructions_for_plotting = reconstructions.mean(
            dim=0)[:n_traj_to_show]
        reconstr_std = reconstructions.std(dim=0)[:n_traj_to_show]

        dim_to_show = 0
        max_y = max(data_for_plotting[:, :, dim_to_show].cpu().numpy().max(),
                    reconstructions[:, :, dim_to_show].cpu().numpy().max())
        min_y = min(data_for_plotting[:, :, dim_to_show].cpu().numpy().min(),
                    reconstructions[:, :, dim_to_show].cpu().numpy().min())

        ############################################
        # Plot reconstructions, true posterior and approximate posterior

        cmap = plt.cm.get_cmap('Set1')
        for traj_id in range(3):
            # Plot observations
            plot_trajectories(
                self.ax_traj[traj_id],
                data_for_plotting[traj_id].unsqueeze(0),
                observed_time_steps,
                mask=mask_for_plotting[traj_id].unsqueeze(0),
                min_y=min_y,
                max_y=max_y,  #title="True trajectories", 
                marker='o',
                linestyle='',
                dim_to_show=dim_to_show,
                color=cmap(2))
            # Plot reconstructions
            plot_trajectories(
                self.ax_traj[traj_id],
                reconstructions_for_plotting[traj_id].unsqueeze(0),
                time_steps_to_predict,
                min_y=min_y,
                max_y=max_y,
                title="Sample {} (data space)".format(traj_id),
                dim_to_show=dim_to_show,
                add_to_plot=True,
                marker='',
                color=cmap(3),
                linewidth=3)
            # Plot variance estimated over multiple samples from approx posterior
            plot_std(self.ax_traj[traj_id],
                     reconstructions_for_plotting[traj_id].unsqueeze(0),
                     reconstr_std[traj_id].unsqueeze(0),
                     time_steps_to_predict,
                     alpha=0.5,
                     color=cmap(3))
            self.set_plot_lims(self.ax_traj[traj_id], "traj_" + str(traj_id))

            # Plot true posterior and approximate posterior
            # self.draw_one_density_plot(self.ax_density[traj_id],
            # 	model, data_dict, traj_id = traj_id,
            # 	multiply_by_poisson = False)
            # self.set_plot_lims(self.ax_density[traj_id], "density_" + str(traj_id))
            # self.ax_density[traj_id].set_title("Sample {}: p(z0) and q(z0 | x)".format(traj_id))
        ############################################
        # Get several samples for the same trajectory
        # one_traj = data_for_plotting[:1]
        # first_point = one_traj[:,0]

        # samples_same_traj, _ = model.get_reconstruction(time_steps_to_predict,
        # 	observed_data[:1], observed_time_steps, mask = observed_mask[:1], n_traj_samples = 5)
        # samples_same_traj = samples_same_traj.squeeze(1)

        # plot_trajectories(self.ax_samples_same_traj, samples_same_traj, time_steps_to_predict, marker = '')
        # plot_trajectories(self.ax_samples_same_traj, one_traj, time_steps, linestyle = "",
        # 	label = "True traj", add_to_plot = True, title="Reconstructions for the same trajectory (data space)")

        ############################################
        # Plot trajectories from prior

        if isinstance(model, LatentODE):
            torch.manual_seed(1991)
            np.random.seed(1991)

            traj_from_prior = model.sample_traj_from_prior(
                time_steps_to_predict, n_traj_samples=3)
            # In this case n_traj = 1 and n_traj_samples is the requested number of samples from the prior, so squeeze the n_traj dimension
            traj_from_prior = traj_from_prior.squeeze(1)

            plot_trajectories(self.ax_traj_from_prior,
                              traj_from_prior,
                              time_steps_to_predict,
                              marker='',
                              linewidth=3)
            self.ax_traj_from_prior.set_title(
                "Samples from prior (data space)", pad=20)
            #self.set_plot_lims(self.ax_traj_from_prior, "traj_from_prior")
        ################################################

        # Plot z0
        # first_point_mu, first_point_std, first_point_enc = info["first_point"]

        # dim1 = 0
        # dim2 = 1
        # self.ax_z0.cla()
        # # first_point_enc shape: [1, n_traj, n_dims]
        # self.ax_z0.scatter(first_point_enc.cpu()[0,:,dim1], first_point_enc.cpu()[0,:,dim2])
        # self.ax_z0.set_title("Encodings z0 of all test trajectories (latent space)")
        # self.ax_z0.set_xlabel('dim {}'.format(dim1))
        # self.ax_z0.set_ylabel('dim {}'.format(dim2))

        ################################################
        # Show vector field
        self.ax_vector_field.cla()
        plot_vector_field(self.ax_vector_field, model.diffeq_solver.ode_func,
                          model.latent_dim, device)
        self.ax_vector_field.set_title("Slice of vector field (latent space)",
                                       pad=20)
        self.set_plot_lims(self.ax_vector_field, "vector_field")
        #self.ax_vector_field.set_ylim((-0.5, 1.5))

        ################################################
        # Plot trajectories in the latent space

        # shape before [1, n_traj, n_tp, n_latent_dims]
        # Take only the first sample from approx posterior
        latent_traj = info["latent_traj"][0, :n_traj_to_show]
        # shape before permute: [1, n_tp, n_latent_dims]

        self.ax_latent_traj.cla()
        cmap = plt.cm.get_cmap('Accent')
        n_latent_dims = latent_traj.size(-1)

        custom_labels = {}
        for i in range(n_latent_dims):
            col = cmap(i)
            plot_trajectories(self.ax_latent_traj,
                              latent_traj,
                              time_steps_to_predict,
                              title="Latent trajectories z(t) (latent space)",
                              dim_to_show=i,
                              color=col,
                              marker='',
                              add_to_plot=True,
                              linewidth=3)
            custom_labels['dim ' + str(i)] = Line2D([0], [0], color=col)

        self.ax_latent_traj.set_ylabel("z")
        self.ax_latent_traj.set_title(
            "Latent trajectories z(t) (latent space)", pad=20)
        self.ax_latent_traj.legend(custom_labels.values(),
                                   custom_labels.keys(),
                                   loc='lower left')
        self.set_plot_lims(self.ax_latent_traj, "latent_traj")

        ################################################

        self.fig.tight_layout()
        plt.draw()

        if save:
            dirname = "plots/" + str(experimentID) + "/"
            os.makedirs(dirname, exist_ok=True)
            self.fig.savefig(dirname + plot_name)
Example #17
    def draw_one_density_plot(self,
                              ax,
                              model,
                              data_dict,
                              traj_id,
                              multiply_by_poisson=False):

        scale = 5
        cmap = add_white(plt.cm.get_cmap('Blues', 9))  # plt.cm.BuGn_r
        cmap2 = add_white(plt.cm.get_cmap('Reds', 9))  # plt.cm.BuGn_r
        #cmap = plt.cm.get_cmap('viridis')

        data = data_dict["data_to_predict"]
        time_steps = data_dict["tp_to_predict"]
        mask = data_dict["mask_predicted_data"]

        observed_data = data_dict["observed_data"]
        observed_time_steps = data_dict["observed_tp"]
        observed_mask = data_dict["observed_mask"]

        npts = 50
        xx, yy, z0_grid = get_meshgrid(npts=npts,
                                       int_y1=(-scale, scale),
                                       int_y2=(-scale, scale))
        z0_grid = z0_grid.to(get_device(data))

        if model.latent_dim > 2:
            z0_grid = torch.cat(
                (z0_grid, torch.zeros(z0_grid.size(0), model.latent_dim - 2)),
                1)

        if model.use_poisson_proc:
            n_traj, n_dims = z0_grid.size()
            # append a vector of zeros to compute the integral of lambda and also zeros for the first point of lambda
            zeros = torch.zeros([n_traj, model.input_dim + model.latent_dim
                                 ]).to(get_device(data))
            z0_grid_aug = torch.cat((z0_grid, zeros), -1)
        else:
            z0_grid_aug = z0_grid

        # Shape of sol_y [n_traj_samples, n_samples, n_timepoints, n_latents]
        sol_y = model.diffeq_solver(z0_grid_aug.unsqueeze(0), time_steps)

        if model.use_poisson_proc:
            sol_y, log_lambda_y, int_lambda, _ = model.diffeq_solver.ode_func.extract_poisson_rate(
                sol_y)

            assert (torch.sum(int_lambda[:, :, 0, :]) == 0.)
            assert (torch.sum(int_lambda[0, 0, -1, :] <= 0) == 0.)

        pred_x = model.decoder(sol_y)

        # Plot density for one trajectory
        one_traj = data[traj_id]
        mask_one_traj = None
        if mask is not None:
            mask_one_traj = mask[traj_id].unsqueeze(0)
            mask_one_traj = mask_one_traj.repeat(npts**2, 1, 1).unsqueeze(0)

        ax.cla()

        # Plot: prior
        prior_density_grid = model.z0_prior.log_prob(
            z0_grid.unsqueeze(0)).squeeze(0)
        # Sum the density over two dimensions
        prior_density_grid = torch.sum(prior_density_grid, -1)

        # =================================================
        # Plot: p(x | y(t0))

        masked_gaussian_log_density_grid = masked_gaussian_log_density(
            pred_x,
            one_traj.repeat(npts**2, 1, 1).unsqueeze(0),
            mask=mask_one_traj,
            obsrv_std=model.obsrv_std).squeeze(-1)

        # Plot p(t | y(t0))
        if model.use_poisson_proc:
            poisson_info = {}
            poisson_info["int_lambda"] = int_lambda[:, :, -1, :]
            poisson_info["log_lambda_y"] = log_lambda_y

            poisson_log_density_grid = compute_poisson_proc_likelihood(
                one_traj.repeat(npts**2, 1, 1).unsqueeze(0),
                pred_x,
                poisson_info,
                mask=mask_one_traj)
            poisson_log_density_grid = poisson_log_density_grid.squeeze(0)

        # =================================================
        # Plot: p(x , y(t0))

        log_joint_density = prior_density_grid + masked_gaussian_log_density_grid
        if multiply_by_poisson:
            log_joint_density = log_joint_density + poisson_log_density_grid

        density_grid = torch.exp(log_joint_density)

        density_grid = torch.reshape(density_grid, (xx.shape[0], xx.shape[1]))
        density_grid = density_grid.cpu().numpy()

        ax.contourf(xx, yy, density_grid, cmap=cmap, alpha=1)

        # =================================================
        # Plot: q(y(t0)| x)
        #self.ax_density.set_title("Red: q(y(t0) | x)    Blue: p(x, y(t0))")
        ax.set_xlabel('z1(t0)')
        ax.set_ylabel('z2(t0)')

        data_w_mask = observed_data[traj_id].unsqueeze(0)
        if observed_mask is not None:
            data_w_mask = torch.cat(
                (data_w_mask, observed_mask[traj_id].unsqueeze(0)), -1)
        z0_mu, z0_std = model.encoder_z0(data_w_mask, observed_time_steps)

        if model.use_poisson_proc:
            z0_mu = z0_mu[:, :, :model.latent_dim]
            z0_std = z0_std[:, :, :model.latent_dim]

        q_z0 = Normal(z0_mu, z0_std)

        q_density_grid = q_z0.log_prob(z0_grid)
        # Sum the density over two dimensions
        q_density_grid = torch.sum(q_density_grid, -1)
        density_grid = torch.exp(q_density_grid)

        density_grid = torch.reshape(density_grid, (xx.shape[0], xx.shape[1]))
        density_grid = density_grid.cpu().numpy()

        ax.contourf(xx, yy, density_grid, cmap=cmap2, alpha=0.3)
Example #18
def compute_multiclass_CE_loss(label_predictions, true_label, mask):
	#print("Computing multi-class classification loss: compute_multiclass_CE_loss")
	def CXE(predicted, target):
		focal = False
		if focal:
			return -(1-predicted)**2*(target * torch.log(predicted)).sum(dim=1).mean()
		else:
			return -(target * torch.log(predicted)).sum(dim=1).mean()
	
	n_tp = 1
	n_traj_samples = 1
	
	crop_set = False
	RNN = False
	
	if (len(label_predictions.size()) == 3):
		label_predictions = label_predictions.unsqueeze(0)
	
	if (len(true_label.size()) == 2) and (len(label_predictions.size()) == 2):
		n_traj, n_dims = true_label.size()
		RNN = True
		crop_set = True
		
	elif (len(true_label.size()) == 2):
		n_traj_samples, _, n_traj, n_dims = label_predictions.size()
		crop_set = True
	else:
		n_traj_samples, n_traj, n_tp, n_dims = label_predictions.size()
		crop_set = False

	# assert(not torch.isnan(label_predictions).any())
	# assert(not torch.isnan(true_label).any())

	# For each trajectory, we get n_traj_samples samples from z0 -- compute loss on all of them
	true_label = true_label.repeat(n_traj_samples, 1, 1)

	label_predictions = label_predictions.reshape(n_traj_samples * n_traj * n_tp, n_dims)
	true_label = true_label.reshape(n_traj_samples * n_traj * n_tp, n_dims)

	# choose time points with at least one measurement
	mask = torch.sum(mask, -1) > 0

	# repeat the mask for each label to mark that the label for this time point is present
	if crop_set or RNN:
		mask[:,:] = True
		
		pred_mask = mask[:,0]
		label_mask = mask[:,0]
		
	else:
		pred_mask = mask.repeat(n_dims, 1,1).permute(1,2,0)

		label_mask = mask
		pred_mask = pred_mask.repeat(n_traj_samples,1,1,1)
		label_mask = label_mask.repeat(n_traj_samples,1,1,1)
	
		pred_mask = pred_mask.reshape(n_traj_samples * n_traj * n_tp,  n_dims)
		label_mask = label_mask.reshape(n_traj_samples * n_traj * n_tp, 1)

	if (label_predictions.size(-1) > 1) and (true_label.size(-1) > 1):
		assert(label_predictions.size(-1) == true_label.size(-1))
		# targets are in one-hot encoding -- convert to indices
		_, true_label_hard = true_label.max(-1)
		
		# Nando's comment: but what if I want soft labels for my cross entropy? Use the labels_soft variable => solved!
		
	
	vectorized = True
	if not vectorized:

		res = []
		for i in range(true_label_hard.size(0)):
			pred_masked = torch.masked_select(label_predictions[i], pred_mask[i].bool()) #byte()
			labels_hard = torch.masked_select(true_label_hard[i], label_mask[i].bool()) #byte()
			labels_soft = torch.masked_select(true_label[i], label_mask[i].bool()) #byte()
			
			
			pred_masked = pred_masked.reshape(-1, n_dims)

			if (not crop_set):
				if (len(labels_hard) == 0):
					continue

			#ce_loss = nn.CrossEntropyLoss()(pred_masked, labels_hard.long())
			ce_loss = CXE(pred_masked, labels_soft)
			res.append(ce_loss)

		# pdb.set_trace()  # leftover debugging breakpoint -- commented out so the branch can run through
		
		ce_loss = torch.stack(res, 0).to(get_device(label_predictions))
		ce_loss = torch.mean(ce_loss)
		# # divide by number of patients in a batch
		# ce_loss = ce_loss / n_traj_samples

	else: #Nando's alternative:

		# use a very small number to avoid numerical problems
		eps = 1e-10
		focal=False
		if focal:
			ce_loss = -( (1-label_predictions)**0.5 * true_label * torch.log(label_predictions + eps)).sum(dim=1).mean()
		else: 
			ce_loss = -( true_label * torch.log(label_predictions + eps)).sum(dim=1).mean()

	return ce_loss
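
The vectorized branch at the end is a plain soft-label cross entropy. A tiny numerical illustration of that formula, assuming label_predictions are already probabilities (e.g. after a softmax):

import torch

eps = 1e-10
pred = torch.tensor([[0.7, 0.2, 0.1],
                     [0.1, 0.8, 0.1]])    # predicted class probabilities
target = torch.tensor([[1.0, 0.0, 0.0],
                       [0.0, 0.5, 0.5]])  # soft labels
ce = -(target * torch.log(pred + eps)).sum(dim=1).mean()
print(ce.item())  # ~0.81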
Example #19
    def compute_all_losses(self, batch_dict, n_traj_samples=1, kl_coef=1.):
        # Condition on subsampled points
        # Make predictions for all the points
        pred_y, info = self.get_reconstruction(
            batch_dict["tp_to_predict"],
            batch_dict["observed_data"],
            batch_dict["observed_tp"],
            mask=batch_dict["observed_mask"],
            n_traj_samples=n_traj_samples,
            mode=batch_dict["mode"])

        #print("get_reconstruction done -- computing likelihood")
        fp_mu, fp_std, fp_enc = info["first_point"]
        fp_std = fp_std.abs()
        fp_distr = Normal(fp_mu, fp_std)

        assert (torch.sum(fp_std < 0) == 0.)

        kldiv_z0 = kl_divergence(fp_distr, self.z0_prior)

        if torch.isnan(kldiv_z0).any():
            print(fp_mu)
            print(fp_std)
            raise Exception("kldiv_z0 is Nan!")

        # Mean over number of latent dimensions
        # kldiv_z0 shape: [n_traj_samples, n_traj, n_latent_dims] if prior is a mixture of gaussians (KL is estimated)
        # kldiv_z0 shape: [1, n_traj, n_latent_dims] if prior is a standard gaussian (KL is computed exactly)
        # shape after: [n_traj_samples]
        kldiv_z0 = torch.mean(kldiv_z0, (1, 2))

        # Compute likelihood of all the points
        rec_likelihood = self.get_gaussian_likelihood(
            batch_dict["data_to_predict"],
            pred_y,
            mask=batch_dict["mask_predicted_data"])

        mse = self.get_mse(batch_dict["data_to_predict"],
                           pred_y,
                           mask=batch_dict["mask_predicted_data"])

        pois_log_likelihood = torch.Tensor([0.]).to(
            utils.get_device(batch_dict["data_to_predict"]))
        if self.use_poisson_proc:
            pois_log_likelihood = compute_poisson_proc_likelihood(
                batch_dict["data_to_predict"],
                pred_y,
                info,
                mask=batch_dict["mask_predicted_data"])
            # Take mean over n_traj
            pois_log_likelihood = torch.mean(pois_log_likelihood, 1)

        ################################
        # Compute CE loss for binary classification on Physionet
        device = utils.get_device(batch_dict["data_to_predict"])
        ce_loss = torch.Tensor([0.]).to(device)
        if (batch_dict["labels"] is not None) and self.use_binary_classif:

            if (batch_dict["labels"].size(-1) == 1) or (len(
                    batch_dict["labels"].size()) == 1):
                ce_loss = compute_binary_CE_loss(info["label_predictions"],
                                                 batch_dict["labels"])
            else:
                ce_loss = compute_multiclass_CE_loss(
                    info["label_predictions"],
                    batch_dict["labels"],
                    mask=batch_dict["mask_predicted_data"])

        # IWAE loss
        loss = -torch.logsumexp(rec_likelihood - kl_coef * kldiv_z0, 0)
        if torch.isnan(loss):
            loss = -torch.mean(rec_likelihood - kl_coef * kldiv_z0, 0)

        if self.use_poisson_proc:
            loss = loss - 0.1 * pois_log_likelihood

        if self.use_binary_classif:
            if self.train_classif_w_reconstr:
                loss = loss + ce_loss * 100
            else:
                loss = ce_loss

        results = {}
        results["loss"] = torch.mean(loss)
        results["likelihood"] = torch.mean(rec_likelihood).detach()
        results["mse"] = torch.mean(mse).detach()
        results["pois_likelihood"] = torch.mean(pois_log_likelihood).detach()
        results["ce_loss"] = torch.mean(ce_loss).detach()
        results["kl_first_p"] = torch.mean(kldiv_z0).detach()
        results["std_first_p"] = torch.mean(fp_std).detach()

        if batch_dict["labels"] is not None and self.use_binary_classif:
            results["label_predictions"] = info["label_predictions"].detach()

        return results
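
The loss above takes the negative log-sum-exp of the per-sample terms (rec_likelihood - kl_coef * kldiv_z0) over the n_traj_samples dimension, in the spirit of the IWAE objective, and falls back to a plain mean if the result is NaN. A small sketch of the two reductions with toy numbers (not from the model):

import torch

# per-sample values of (rec_likelihood - kl_coef * kldiv_z0), shape [n_traj_samples]
elbo_samples = torch.tensor([-1.2, -0.7, -3.0])

iwae_style_loss = -torch.logsumexp(elbo_samples, 0)  # dominated by the best sample
mean_loss = -torch.mean(elbo_samples, 0)             # plain averaged objective
print(iwae_style_loss.item(), mean_loss.item())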
Example #20
    def run_odernn(self,
                   data,
                   time_steps,
                   run_backwards=True,
                   save_info=False,
                   testing=False):
        # IMPORTANT: assumes that 'data' already has mask concatenated to it

        n_traj, n_tp, n_dims = data.size()
        extra_info = []
        save_latents = 10 if testing else 0

        device = get_device(data)

        # Initialize the hidden state with noise
        if self.RNNcell == 'lstm':
            # make some noise
            prev_h = torch.zeros(
                (1, n_traj,
                 self.latent_dim // 2)).data.normal_(0, 0.0001).to(device)
            prev_h_std = torch.zeros(
                (1, n_traj,
                 self.latent_dim // 2)).data.normal_(0, 0.0001).to(device)

            ci = torch.zeros(
                (1, n_traj,
                 self.latent_dim // 2)).data.normal_(0, 0.0001).to(device)
            ci_std = torch.zeros(
                (1, n_traj,
                 self.latent_dim // 2)).data.normal_(0, 0.0001).to(device)

            # concatenate cell state and hidden state
            prev_y = torch.cat([prev_h, ci], -1)
            prev_std = torch.cat([prev_h_std, ci_std], -1)
        else:
            # make some noise
            prev_y = torch.zeros(
                (1, n_traj, self.latent_dim)).data.normal_(0,
                                                           0.0001).to(device)
            prev_std = torch.zeros(
                (1, n_traj, self.latent_dim)).data.normal_(0,
                                                           0.0001).to(device)

        #prev_t, t_i = time_steps[-1] + 0.01,  time_steps[-1] # original
        #prev_t = time_steps[0] - 0.00001 # new2
        t_i = time_steps[0] - 0.00001  # new

        interval_length = time_steps[-1] - time_steps[0]
        minimum_step = interval_length / 200  # maybe have to modify minimum time step # original
        #minimum_step = interval_length / 100 # maybe have to modify minimum time step # new

        #print("minimum step: {}".format(minimum_step))

        assert (not torch.isnan(data).any())
        assert (not torch.isnan(time_steps).any())

        latent_ys = []
        firststep = True

        # Run ODE backwards and combine the y(t) estimates using gating
        time_points_iter = range(0, len(time_steps))
        if run_backwards:
            time_points_iter = reversed(time_points_iter)

        # Get positional encoding
        # position_encodings = utils.get_sinusoid_encoding_table(time_steps*2*math.pi, d_hid=2) # note: 2*pi is already included? [0,1] is enough
        position_encodings = utils.get_sinusoid_encoding_table(
            time_steps.cpu().numpy(), d_hid=2)

        ### Check if experimental is on/off ###
        experimental = False

        for i in time_points_iter:

            # move time step to the next interval

            #t_i = time_steps[i]							# new2
            prev_t = time_steps[i]  # new

            n_intermediate_tp = self.n_intermediate_tp  # get steps in between, minimum is 2
            if save_latents != 0:
                n_intermediate_tp = max(
                    2, ((prev_t - t_i) / minimum_step
                        ).int())  # get more steps in between for testing

            time_points = utils.linspace_vector(prev_t, t_i, n_intermediate_tp)

            if experimental:
                time_points = time_points.flip(0)

            # Include implementation for the case of no ODE function
            if self.use_ODE:

                if abs(prev_t - t_i) < minimum_step:
                    #short integration, linear approximation with the gradient
                    time_points = torch.stack((prev_t, t_i))
                    inc = self.z0_diffeq_solver.ode_func(
                        prev_t, prev_y) * (t_i - prev_t)

                    assert (not torch.isnan(inc).any())

                    ode_sol = prev_y + inc
                    ode_sol = torch.stack((prev_y, ode_sol), 2).to(device)

                    assert (not torch.isnan(ode_sol).any())

                else:
                    #complete Integration using differential equation solver
                    ode_sol = self.z0_diffeq_solver(prev_y, time_points)

                    assert (not torch.isnan(ode_sol).any())

                if torch.mean(ode_sol[:, :, 0, :] - prev_y) >= 0.001:
                    print(
                        "Error: first point of the ODE is not equal to initial value"
                    )
                    print(torch.mean(ode_sol[:, :, 0, :] - prev_y))
                    exit()

                yi_ode = ode_sol[:, :, -1, :]

                xi = data[:, i, :].unsqueeze(0)

            else:

                # skipping ODE function and assign directly
                yi_ode = prev_y
                time_points = time_points[-1]

                # extract the mask for the current (single) time step
                single_mask = data[:, i, self.input_dim // 2]
                delta_ts = (prev_t - t_i).repeat(1, n_traj, 1).float()
                delta_ts[:, ~single_mask.bool(), :] = 0

                if self.nornnimputation:
                    delta_ts[:, :, :] = 0

                features = data[:, i, :self.input_dim // 2].unsqueeze(0)

                if not self.use_pos_encod:
                    new_mask = single_mask.unsqueeze(0).unsqueeze(2).repeat(
                        1, 1, self.input_dim // 2 + 1)
                    xi = torch.cat([features, delta_ts, new_mask], -1)
                else:
                    pos_encod = position_encodings[i].repeat(1, n_traj,
                                                             1).float()
                    pos_encod[:, ~single_mask.bool(), :] = 0
                    new_mask = single_mask.unsqueeze(0).unsqueeze(2).repeat(
                        1, 1, self.input_dim // 2 + 2)
                    xi = torch.cat([features, pos_encod, new_mask], -1)
                # create new input including the delta ts plus the mask; concatenate the delta t for the pure RNN

            if self.RNNcell == 'lstm':
                # In case of LSTM update, we have to take special care of the variables for the hidden and cell state
                h_i_ode = yi_ode[:, :, :self.latent_dim // 2]
                c_i_ode = yi_ode[:, :, self.latent_dim // 2:]
                h_c_lstm = (h_i_ode, c_i_ode)

                # actually this is a LSTM update here:
                outi, yi_std = self.RNN_update(h_c_lstm, prev_std, xi)
                # the RNN cell is an LSTM and outi := (h_i, c_i); we only need the hidden state h as the latent dim
                h_i_, c_i_ = outi[0], outi[1]
                yi = torch.cat([h_i_, c_i_], -1)
                yi_out = h_i_

                if not self.use_ODE:
                    ode_sol = yi_out.unsqueeze(2)
                    time_points = time_points.unsqueeze(0)
            else:
                # GRU-unit or any other RNN cell: the output is directly the hidden state
                yi_ode, prev_std = self.RNN_update(yi_ode, prev_std, xi)
                yi, yi_std = yi_ode, prev_std
                yi_out = yi

                if not self.use_ODE:
                    ode_sol = yi_ode.unsqueeze(2)
                    time_points = time_points.unsqueeze(0)

            prev_y, prev_std = yi, yi_std
            #prev_t, t_i = time_steps[i],  time_steps[i-1]	# original
            #prev_t = time_steps[i]								# new2
            t_i = time_steps[i]  # new

            latent_ys.append(yi_out)
            if save_info or save_latents:
                if self.use_ODE:
                    #ODE-RNN case
                    ODE_flags = (xi[:, :, self.latent_dim:].sum(
                        (0, 2)) == 0).cpu().detach().int().numpy(
                        )  # zero: RNN-update, one: ODE-update
                    marker = np.ones((n_traj, n_intermediate_tp))
                    marker[:, -1] = ODE_flags

                    if not firststep:
                        #marker[:,0] = old_ODE_flags
                        pass
                    else:
                        firststep = False
                    old_ODE_flags = ODE_flags
                else:
                    #RNN case
                    marker = (
                        (xi[:, :, (self.latent_dim + 1):].sum(
                            (0, 2)) == 0).cpu().detach().int().numpy() * 2
                    )[:, np.newaxis]  # zero: RNN-update, two: No update at all

                d = {
                    "yi_ode": yi_ode[:, :save_latents].cpu().detach(
                    ),  #"yi_from_data": yi_from_data,
                    "yi":
                    yi_out[:, :save_latents].cpu().detach()[:, :save_latents],
                    "yi_std": yi_std[:, :save_latents].cpu().detach(),
                    "time_points": time_points.cpu().detach().double(),
                    "ode_sol":
                    ode_sol[:, :save_latents].cpu().detach().double(),
                    "marker": marker[:save_latents]
                }
            """
			if save_info or testing:
				d = {"yi_ode": yi_ode.detach()[:,:20], #"yi_from_data": yi_from_data,
					 "yi": yi_out.detach()[:,:20], "yi_std": yi_std.detach()[:,:20], 
					 "time_points": time_points.detach(),
					 "ode_sol": ode_sol.detach()[:,:20]
				}
				extra_info.append(d)
			"""

        latent_ys = torch.stack(latent_ys, 1)

        #BatchNormalization for the outputs
        if self.use_BN:

            # only apply BN to the RNN-updated outputs (observed times), the ones that are used further on...
            # Experimental: for selective BN of outputs
            fancy_BN = False
            if fancy_BN:
                # not faster due to non-contiguous data
                obs_mask = data[:, :, self.input_dim // 2].permute(1, 0)
                latent_ys[:, obs_mask.bool()] = self.output_bn(
                    latent_ys[:, obs_mask.bool()].permute(0, 2,
                                                          1)).permute(0, 2, 1)
            else:
                latent_ys = self.output_bn(latent_ys.squeeze().permute(
                    0, 2, 1)).permute(0, 2, 1).unsqueeze(0)  #orig

        assert (not torch.isnan(yi).any())
        assert (not torch.isnan(yi_std).any())

        return yi, yi_std, latent_ys, extra_info
def run_rnn(inputs,
            delta_ts,
            cell,
            first_hidden=None,
            mask=None,
            feed_previous=False,
            n_steps=0,
            decoder=None,
            input_decay_params=None,
            feed_previous_w_prob=0.,
            masked_update=True):
    if (feed_previous or feed_previous_w_prob) and decoder is None:
        raise Exception(
            "feed_previous is set to True -- please specify RNN decoder")

    if n_steps == 0:
        n_steps = inputs.size(1)

    if (feed_previous or feed_previous_w_prob) and mask is None:
        mask = torch.ones(
            (inputs.size(0), n_steps, inputs.size(-1))).to(get_device(inputs))

    if isinstance(cell, GRUCellExpDecay):
        cum_delta_ts = get_cum_delta_ts(inputs, delta_ts, mask)

    if input_decay_params is not None:
        w_input_decay, b_input_decay = input_decay_params
        inputs = impute_using_input_decay(inputs, delta_ts, mask,
                                          w_input_decay, b_input_decay)

    all_hiddens = []
    hidden = first_hidden

    if hidden is not None:
        all_hiddens.append(hidden)
        n_steps -= 1

    for i in range(n_steps):
        delta_t = delta_ts[:, i]
        if i == 0:
            rnn_input = inputs[:, i]
        elif feed_previous:
            rnn_input = decoder(hidden)
        elif feed_previous_w_prob > 0:
            feed_prev = np.random.uniform() > feed_previous_w_prob
            if feed_prev:
                rnn_input = decoder(hidden)
            else:
                rnn_input = inputs[:, i]
        else:
            rnn_input = inputs[:, i]

        if mask is not None:
            mask_i = mask[:, i, :]
            rnn_input = torch.cat((rnn_input, mask_i), -1)

        if isinstance(cell, GRUCellExpDecay):
            cum_delta_t = cum_delta_ts[:, i]
            input_w_t = torch.cat((rnn_input, cum_delta_t), -1).squeeze(1)
        else:
            input_w_t = torch.cat((rnn_input, delta_t), -1).squeeze(1)

        prev_hidden = hidden
        hidden = cell(input_w_t, hidden)

        if masked_update and (mask is not None) and (prev_hidden is not None):
            # update the hidden state only if at least one feature is present for the current time point
            summed_mask = (torch.sum(mask_i, -1, keepdim=True) > 0).float()
            assert (not torch.isnan(summed_mask).any())
            hidden = summed_mask * hidden + (1 - summed_mask) * prev_hidden

        all_hiddens.append(hidden)

    all_hiddens = torch.stack(all_hiddens, 0)
    all_hiddens = all_hiddens.permute(1, 0, 2).unsqueeze(0)
    return hidden, all_hiddens
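
A hypothetical end-to-end call of run_rnn with a plain nn.GRUCell (the surrounding code also supports a GRUCellExpDecay cell, which triggers the cumulative-delta branch above); shapes only, no training:

import torch
import torch.nn as nn

batch, n_steps, n_dims, hidden_size = 4, 10, 3, 16
inputs = torch.randn(batch, n_steps, n_dims)
mask = (torch.rand(batch, n_steps, n_dims) > 0.3).float()
delta_ts = torch.rand(batch, n_steps, 1)

# the cell sees (features, mask, delta_t) concatenated at every step
cell = nn.GRUCell(input_size=2 * n_dims + 1, hidden_size=hidden_size)

last_hidden, all_hiddens = run_rnn(inputs, delta_ts, cell, mask=mask)
print(last_hidden.shape, all_hiddens.shape)
# torch.Size([4, 16]) torch.Size([1, 4, 10, 16])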