def impute_using_input_decay(data, delta_ts, mask, w_input_decay, b_input_decay):
    """Replace unobserved entries of ``data`` with a decayed blend of values.

    Entries where ``mask`` is non-zero are returned unchanged.  Where ``mask``
    is zero the value becomes ``decay * last + (1 - decay) * mean`` with

    * ``last``  -- the value carried forward from the previous time index,
    * ``mean``  -- the mean of ``data`` over the time axis (dimension 1),
    * ``decay`` -- ``exp(-clip(w_input_decay * t + b_input_decay, 0, 1000))``
      where ``t`` is the accumulated time gap divided by its global maximum.

    Args:
        data: 3-D tensor, unpacked as (n_traj, n_tp, n_dims).
        delta_ts: time gaps, tiled ``n_dims`` times along the last axis
            (presumably shaped (n_traj, n_tp, 1) -- TODO confirm with caller).
        mask: tensor indexed like ``data``; zero marks a missing entry.
        w_input_decay: multiplicative decay parameter.
        b_input_decay: additive decay parameter.

    Returns:
        Tensor with the same layout as ``data`` with missing entries imputed.
    """
    n_traj, n_tp, n_dims = data.size()
    device = get_device(data)

    elapsed = delta_ts.repeat(1, 1, n_dims)
    last_seen = np.copy(data.cpu().numpy())

    # np.where reports indices in row-major order, so for a fixed trajectory
    # the time indices are visited in increasing order; a run of consecutive
    # missing steps therefore keeps accumulating / carrying forward.
    gap_traj, gap_tp, gap_dim = np.where(mask.cpu().numpy() == 0)
    for i, j, k in zip(gap_traj, gap_tp, gap_dim):
        if j == 0:
            # Nothing precedes the first time point.
            continue
        if j != (n_tp - 1):
            elapsed[i, j + 1, k] = elapsed[i, j + 1, k] + elapsed[i, j, k]
        last_seen[i, j, k] = last_seen[i, j - 1, k]

    # Normalise the accumulated gaps by their maximum.
    elapsed = elapsed / elapsed.max()

    last_seen = torch.Tensor(last_seen).to(device)
    zeros = torch.zeros([n_traj, n_tp, n_dims]).to(device)

    # Clip the exponent to [0, 1000] before exponentiating.
    exponent = torch.min(
        torch.max(zeros, w_input_decay * elapsed + b_input_decay),
        zeros + 1000)
    decay = torch.exp(-exponent)

    data_means = torch.mean(data, 1).unsqueeze(1)
    fallback = decay * last_seen + (1 - decay) * data_means
    return data * mask + (1 - mask) * fallback
def compute_binary_CE_loss(label_predictions, mortality_label):
    """Binary cross-entropy (with logits) over the labels that are not NaN.

    Args:
        label_predictions: logits.  A 1-D tensor is treated as a single
            trajectory sample; otherwise the first dimension is the number of
            trajectory samples and the rest is flattened.
        mortality_label: labels, flattened to 1-D.  NaN entries mark missing
            labels and are excluded together with the matching predictions.

    Returns:
        Scalar tensor: ``BCEWithLogitsLoss`` (mean reduction) divided by the
        number of trajectory samples.  If every label is NaN a zero scalar is
        returned instead of NaN.
    """
    mortality_label = mortality_label.reshape(-1)

    if len(label_predictions.size()) == 1:
        label_predictions = label_predictions.unsqueeze(0)

    n_traj_samples = label_predictions.size(0)
    label_predictions = label_predictions.reshape(n_traj_samples, -1)

    # torch.isnan returns a bool tensor; `1 - bool_tensor` is rejected by
    # PyTorch, so the mask is inverted with `~`.
    idx_not_nan = ~torch.isnan(mortality_label)
    if not idx_not_nan.any():
        # Count the *valid* labels (not the mask length) and stop here: a loss
        # over an empty selection would be NaN.
        print("All are labels are NaNs!")
        return label_predictions.new_zeros(())

    label_predictions = label_predictions[:, idx_not_nan]
    mortality_label = mortality_label[idx_not_nan]

    if torch.sum(mortality_label == 0.) == 0 or torch.sum(
            mortality_label == 1.) == 0:
        print(
            "Warning: all examples in a batch belong to the same class -- please increase the batch size."
        )

    assert (not torch.isnan(label_predictions).any())
    assert (not torch.isnan(mortality_label).any())

    # Every trajectory sample is scored against the same labels.
    mortality_label = mortality_label.repeat(n_traj_samples, 1)
    ce_loss = nn.BCEWithLogitsLoss()(label_predictions, mortality_label)

    # Divide by the number of trajectory samples.
    ce_loss = ce_loss / n_traj_samples
    return ce_loss
def compute_masked_likelihood(mu, data, mask, likelihood_func):
    """Evaluate ``likelihood_func`` per (sample, trajectory, feature) on observed points.

    The likelihood is computed separately for every trajectory and feature so
    that trajectories with more measurements are not weighted more heavily.

    Args:
        mu: predictions, indexed like ``data``.
        data: 4-D tensor unpacked as
            (n_traj_samples, n_traj, n_timepoints, n_dims).
        mask: tensor indexed like ``data``; non-zero entries mark observed
            values (expected to be a 0/1 mask).
        likelihood_func: callable ``f(mu_masked, data_masked, indices=(i, k, j))``
            returning a scalar tensor.

    Returns:
        Tensor of shape (n_traj, n_traj_samples): the per-feature values
        averaged over the feature dimension.
    """
    n_traj_samples, n_traj, n_timepoints, n_dims = data.size()

    # masked_select expects a boolean mask (uint8 masks are deprecated);
    # convert once instead of once per slice.
    observed = mask.bool()

    res = []
    for i in range(n_traj_samples):
        for k in range(n_traj):
            for j in range(n_dims):
                selector = observed[i, k, :, j]
                data_masked = torch.masked_select(data[i, k, :, j], selector)
                mu_masked = torch.masked_select(mu[i, k, :, j], selector)
                log_prob = likelihood_func(mu_masked, data_masked,
                                           indices=(i, k, j))
                res.append(log_prob)

    res = torch.stack(res, 0).to(data.device)
    res = res.reshape((n_traj_samples, n_traj, n_dims))
    # Mean (not sum) over the feature dimension.
    res = torch.mean(res, -1)
    res = res.transpose(0, 1)
    return res
def mse(mu, data, indices=None):
    """Mean squared error between ``mu`` and ``data``.

    Returns a zero scalar when the last dimension of ``mu`` is empty.
    ``indices`` is accepted for signature compatibility and is not used.
    """
    if mu.size()[-1] > 0:
        return nn.MSELoss()(mu, data)
    # Nothing to compare: report zero error on the device of ``data``.
    return torch.zeros([1]).to(get_device(data)).squeeze()
def compute_multiclass_CE_loss(label_predictions, true_label, mask):
    """Cross-entropy over the time points that have at least one measurement.

    Args:
        label_predictions: class logits, unpacked as
            (n_traj_samples, n_traj, n_tp, n_dims).  A 3-D tensor is treated
            as a single trajectory sample.
        true_label: targets that reshape to (n_traj, n_tp, n_dims); when
            ``n_dims > 1`` they are read as one-hot and converted to class
            indices with ``max``.
        mask: observation mask whose last axis is summed; a time point is
            scored when that sum is positive.

    Returns:
        Scalar tensor: the mean cross-entropy over all scored
        (sample, trajectory, time point) rows, or a zero scalar when no time
        point is observed.
    """
    if len(label_predictions.size()) == 3:
        label_predictions = label_predictions.unsqueeze(0)

    n_traj_samples, n_traj, n_tp, n_dims = label_predictions.size()
    n_rows = n_traj_samples * n_traj * n_tp

    # Every trajectory sample is scored against the same labels.
    true_label = true_label.repeat(n_traj_samples, 1, 1)

    label_predictions = label_predictions.reshape(n_rows, n_dims)
    true_label = true_label.reshape(n_rows, n_dims)

    # Choose time points with at least one measurement and tile that choice
    # over the trajectory samples so it lines up with the flattened rows.
    observed_tp = torch.sum(mask, -1) > 0
    row_mask = observed_tp.repeat(n_traj_samples, 1, 1).reshape(n_rows)

    if (label_predictions.size(-1) > 1) and (true_label.size(-1) > 1):
        assert (label_predictions.size(-1) == true_label.size(-1))
        # Targets are in one-hot encoding -- convert to indices.
        _, true_label = true_label.max(-1)
    true_label = true_label.reshape(n_rows)

    if not row_mask.any():
        # No observed time point: there is nothing to average.
        return label_predictions.new_zeros(())

    # One call with mean reduction equals the mean of per-row losses, without
    # a Python loop or uint8 masks.
    return nn.CrossEntropyLoss()(label_predictions[row_mask],
                                 true_label[row_mask].long())
def poisson_log_likelihood(masked_log_lambdas, masked_data, indices,
                           int_lambdas):
    """Poisson-process log-likelihood term for one masked slice.

    Args:
        masked_log_lambdas: log intensities at the observed points.
        masked_data: observed values; only the size of the last dimension is
            inspected.
        indices: index used to pick the matching entry of ``int_lambdas``.
        int_lambdas: indexable collection of integrated intensities.

    Returns:
        ``sum(masked_log_lambdas) - int_lambdas[indices]``, or a zero scalar
        when ``masked_data`` has no points.
    """
    if masked_data.size()[-1] > 0:
        # Not normalised by the number of data points.
        return torch.sum(masked_log_lambdas) - int_lambdas[indices]
    return torch.zeros([1]).to(get_device(masked_data)).squeeze()
def gaussian_log_likelihood(mu_2d, data_2d, obsrv_std, indices=None):
    """Per-point Gaussian log-likelihood of ``data_2d`` around ``mu_2d``.

    The log-probability of an independent Normal with mean ``mu_2d`` and
    standard deviation ``obsrv_std`` (tiled once per data point) is divided by
    the number of data points, i.e. the size of the last dimension of
    ``mu_2d``.  A zero scalar is returned when there are no data points.
    ``indices`` is accepted for signature compatibility and is not used.
    """
    n_data_points = mu_2d.size()[-1]
    if n_data_points <= 0:
        return torch.zeros([1]).to(get_device(data_2d)).squeeze()

    scale = obsrv_std.repeat(n_data_points)
    gaussian = Independent(Normal(loc=mu_2d, scale=scale), 1)
    return gaussian.log_prob(data_2d) / n_data_points
def draw_all_plots_one_dim(self, data_dict, model, plot_name="", save=False, experimentID=0.):
    """Redraw the reconstruction, prior-sample, vector-field and latent panels.

    Reads the observed batch and the prediction time grid from ``data_dict``,
    asks ``model`` for reconstructions and draws onto the axes stored on
    ``self`` (``ax_traj``, ``ax_traj_from_prior``, ``ax_vector_field`` and
    ``ax_latent_traj``).  Only feature dimension 0 is shown in data space.

    Args:
        data_dict: mapping providing ``"tp_to_predict"``, ``"observed_data"``,
            ``"observed_tp"`` and ``"observed_mask"``.
        model: object exposing ``get_reconstruction``,
            ``diffeq_solver.ode_func`` and ``latent_dim``.  ``LatentODE``
            instances are additionally evaluated on a 100-point time grid and
            sampled from the prior.
        plot_name: file name used when ``save`` is true.
        save: when true, the figure is written to
            ``plots/<experimentID>/<plot_name>``.
        experimentID: identifier used as the output sub-directory name.
    """
    time_steps = data_dict["tp_to_predict"]
    observed_data = data_dict["observed_data"]
    observed_time_steps = data_dict["observed_tp"]
    observed_mask = data_dict["observed_mask"]

    device = get_device(time_steps)

    time_steps_to_predict = time_steps
    if isinstance(model, LatentODE):
        # Use a dense, evenly spaced grid over the prediction interval.
        time_steps_to_predict = utils.linspace_vector(
            time_steps[0], time_steps[-1], 100).to(device)

    reconstructions, info = model.get_reconstruction(time_steps_to_predict,
                                                     observed_data,
                                                     observed_time_steps,
                                                     mask=observed_mask,
                                                     n_traj_samples=10)

    # Plot at most this many trajectories.
    n_traj_to_show = 3

    data_for_plotting = observed_data[:n_traj_to_show]
    mask_for_plotting = observed_mask[:n_traj_to_show]
    # Mean / spread over the trajectory samples (leading dimension).
    reconstructions_for_plotting = reconstructions.mean(
        dim=0)[:n_traj_to_show]
    reconstr_std = reconstructions.std(dim=0)[:n_traj_to_show]

    dim_to_show = 0
    # Take the y-range from the curves that are actually drawn.  Indexing the
    # un-averaged `reconstructions` with [:, :, dim_to_show] would select a
    # time index, because that tensor has an extra leading sample dimension.
    shown_data = data_for_plotting[:, :, dim_to_show].cpu().numpy()
    shown_reconstr = reconstructions_for_plotting[:, :, dim_to_show].cpu().numpy()
    max_y = max(shown_data.max(), shown_reconstr.max())
    min_y = min(shown_data.min(), shown_reconstr.min())

    ############################################
    # Plot observations, reconstructions and their spread
    cmap = plt.get_cmap('Set1')
    n_shown = min(n_traj_to_show, data_for_plotting.size(0))
    for traj_id in range(n_shown):
        # Plot observations
        plot_trajectories(
            self.ax_traj[traj_id],
            data_for_plotting[traj_id].unsqueeze(0),
            observed_time_steps,
            mask=mask_for_plotting[traj_id].unsqueeze(0),
            min_y=min_y,
            max_y=max_y,
            marker='o',
            linestyle='',
            dim_to_show=dim_to_show,
            color=cmap(2))

        # Plot reconstructions
        plot_trajectories(
            self.ax_traj[traj_id],
            reconstructions_for_plotting[traj_id].unsqueeze(0),
            time_steps_to_predict,
            min_y=min_y,
            max_y=max_y,
            title="Sample {} (data space)".format(traj_id),
            dim_to_show=dim_to_show,
            add_to_plot=True,
            marker='',
            color=cmap(3),
            linewidth=3)

        # Plot variance estimated over multiple samples from approx posterior
        plot_std(self.ax_traj[traj_id],
                 reconstructions_for_plotting[traj_id].unsqueeze(0),
                 reconstr_std[traj_id].unsqueeze(0),
                 time_steps_to_predict,
                 alpha=0.5,
                 color=cmap(3))
        self.set_plot_lims(self.ax_traj[traj_id], "traj_" + str(traj_id))

    ############################################
    # Plot trajectories from prior
    if isinstance(model, LatentODE):
        # Fixed seeds keep the prior samples identical between redraws.
        torch.manual_seed(1991)
        np.random.seed(1991)

        traj_from_prior = model.sample_traj_from_prior(
            time_steps_to_predict, n_traj_samples=3)
        # Here n_traj = 1 and n_traj_samples is the requested number of prior
        # samples, so the n_traj dimension is squeezed away.
        traj_from_prior = traj_from_prior.squeeze(1)

        plot_trajectories(self.ax_traj_from_prior,
                          traj_from_prior,
                          time_steps_to_predict,
                          marker='',
                          linewidth=3)
        self.ax_traj_from_prior.set_title(
            "Samples from prior (data space)", pad=20)

    ################################################
    # Show vector field
    self.ax_vector_field.cla()
    plot_vector_field(self.ax_vector_field, model.diffeq_solver.ode_func,
                      model.latent_dim, device)
    self.ax_vector_field.set_title("Slice of vector field (latent space)",
                                   pad=20)
    self.set_plot_lims(self.ax_vector_field, "vector_field")

    ################################################
    # Plot trajectories in the latent space.
    # Take only the first sample from the approximate posterior.
    latent_traj = info["latent_traj"][0, :n_traj_to_show]

    self.ax_latent_traj.cla()
    cmap = plt.get_cmap('Accent')
    n_latent_dims = latent_traj.size(-1)

    custom_labels = {}
    for i in range(n_latent_dims):
        col = cmap(i)
        plot_trajectories(self.ax_latent_traj,
                          latent_traj,
                          time_steps_to_predict,
                          title="Latent trajectories z(t) (latent space)",
                          dim_to_show=i,
                          color=col,
                          marker='',
                          add_to_plot=True,
                          linewidth=3)
        # One legend proxy per latent dimension.
        custom_labels['dim ' + str(i)] = Line2D([0], [0], color=col)

    self.ax_latent_traj.set_ylabel("z")
    self.ax_latent_traj.set_title(
        "Latent trajectories z(t) (latent space)", pad=20)
    self.ax_latent_traj.legend(custom_labels.values(),
                               custom_labels.keys(),
                               loc='lower left')
    self.set_plot_lims(self.ax_latent_traj, "latent_traj")

    ################################################
    self.fig.tight_layout()
    plt.draw()

    if save:
        dirname = "plots/" + str(experimentID) + "/"
        os.makedirs(dirname, exist_ok=True)
        self.fig.savefig(dirname + plot_name)
def draw_one_density_plot(self, ax, model, data_dict, traj_id, multiply_by_poisson=False):
    """Draw joint and approximate-posterior densities of z0 for one trajectory.

    A 50 x 50 grid over the first two latent coordinates (range [-5, 5]) is
    pushed through ``model.diffeq_solver`` and ``model.decoder``.  Two filled
    contour plots are drawn on ``ax``:

    * blue -- ``exp(log p(z0) + log p(x | z0))`` for trajectory ``traj_id``,
      optionally with the Poisson-process term added;
    * red (translucent) -- the encoder's Normal ``q(z0 | x)``.

    Args:
        ax: Matplotlib axes to draw on (cleared first).
        model: object exposing ``latent_dim``, ``use_poisson_proc``,
            ``diffeq_solver``, ``decoder``, ``z0_prior``, ``obsrv_std`` and
            ``encoder_z0`` (plus ``input_dim`` when ``use_poisson_proc``).
        data_dict: mapping providing ``"data_to_predict"``,
            ``"tp_to_predict"``, ``"mask_predicted_data"``,
            ``"observed_data"``, ``"observed_tp"`` and ``"observed_mask"``.
        traj_id: index of the trajectory to visualise.
        multiply_by_poisson: add the Poisson-process log-likelihood to the
            joint density.

    Raises:
        ValueError: if ``multiply_by_poisson`` is set but the model has no
            Poisson process, so there is no Poisson term to add.
    """
    if multiply_by_poisson and not model.use_poisson_proc:
        # Fail up front instead of hitting an unbound local after drawing.
        raise ValueError(
            "multiply_by_poisson=True requires a model with use_poisson_proc enabled"
        )

    scale = 5
    cmap = add_white(plt.get_cmap('Blues', 9))
    cmap2 = add_white(plt.get_cmap('Reds', 9))

    data = data_dict["data_to_predict"]
    time_steps = data_dict["tp_to_predict"]
    mask = data_dict["mask_predicted_data"]
    observed_data = data_dict["observed_data"]
    observed_time_steps = data_dict["observed_tp"]
    observed_mask = data_dict["observed_mask"]

    device = get_device(data)

    npts = 50
    xx, yy, z0_grid = get_meshgrid(npts=npts,
                                   int_y1=(-scale, scale),
                                   int_y2=(-scale, scale))
    z0_grid = z0_grid.to(device)

    if model.latent_dim > 2:
        # Pad the remaining latent coordinates with zeros, on the same device
        # as the grid so the concatenation also works off-CPU.
        padding = torch.zeros(z0_grid.size(0),
                              model.latent_dim - 2).to(device)
        z0_grid = torch.cat((z0_grid, padding), 1)

    if model.use_poisson_proc:
        n_traj = z0_grid.size(0)
        # Append zeros for the integral of lambda and for the first point of
        # lambda.
        zeros = torch.zeros([n_traj, model.input_dim + model.latent_dim
                             ]).to(device)
        z0_grid_aug = torch.cat((z0_grid, zeros), -1)
    else:
        z0_grid_aug = z0_grid

    # The grid points play the role of trajectories for the solver.
    sol_y = model.diffeq_solver(z0_grid_aug.unsqueeze(0), time_steps)

    if model.use_poisson_proc:
        sol_y, log_lambda_y, int_lambda, _ = model.diffeq_solver.ode_func.extract_poisson_rate(
            sol_y)

        assert (torch.sum(int_lambda[:, :, 0, :]) == 0.)
        assert (torch.sum(int_lambda[0, 0, -1, :] <= 0) == 0.)

    pred_x = model.decoder(sol_y)

    # Compare every grid point against the same single trajectory.
    one_traj = data[traj_id]
    mask_one_traj = None
    if mask is not None:
        mask_one_traj = mask[traj_id].unsqueeze(0)
        mask_one_traj = mask_one_traj.repeat(npts**2, 1, 1).unsqueeze(0)

    ax.cla()

    # Prior: p(z0), summed over the latent dimensions in log space.
    prior_density_grid = model.z0_prior.log_prob(
        z0_grid.unsqueeze(0)).squeeze(0)
    prior_density_grid = torch.sum(prior_density_grid, -1)

    # =================================================
    # p(x | y(t0))
    masked_gaussian_log_density_grid = masked_gaussian_log_density(
        pred_x,
        one_traj.repeat(npts**2, 1, 1).unsqueeze(0),
        mask=mask_one_traj,
        obsrv_std=model.obsrv_std).squeeze(-1)

    # p(t | y(t0))
    if model.use_poisson_proc:
        poisson_info = {}
        poisson_info["int_lambda"] = int_lambda[:, :, -1, :]
        poisson_info["log_lambda_y"] = log_lambda_y

        poisson_log_density_grid = compute_poisson_proc_likelihood(
            one_traj.repeat(npts**2, 1, 1).unsqueeze(0),
            pred_x,
            poisson_info,
            mask=mask_one_traj)
        poisson_log_density_grid = poisson_log_density_grid.squeeze(0)

    # =================================================
    # p(x, y(t0))
    log_joint_density = prior_density_grid + masked_gaussian_log_density_grid
    if multiply_by_poisson:
        log_joint_density = log_joint_density + poisson_log_density_grid

    density_grid = torch.exp(log_joint_density)
    density_grid = torch.reshape(density_grid, (xx.shape[0], xx.shape[1]))
    density_grid = density_grid.cpu().numpy()

    ax.contourf(xx, yy, density_grid, cmap=cmap, alpha=1)

    # =================================================
    # q(y(t0) | x)
    ax.set_xlabel('z1(t0)')
    ax.set_ylabel('z2(t0)')

    data_w_mask = observed_data[traj_id].unsqueeze(0)
    if observed_mask is not None:
        data_w_mask = torch.cat(
            (data_w_mask, observed_mask[traj_id].unsqueeze(0)), -1)
    z0_mu, z0_std = model.encoder_z0(data_w_mask, observed_time_steps)

    if model.use_poisson_proc:
        # Keep only the latent part of the encoder output.
        z0_mu = z0_mu[:, :, :model.latent_dim]
        z0_std = z0_std[:, :, :model.latent_dim]

    q_z0 = Normal(z0_mu, z0_std)

    q_density_grid = q_z0.log_prob(z0_grid)
    # Sum the log-density over the latent dimensions.
    q_density_grid = torch.sum(q_density_grid, -1)
    density_grid = torch.exp(q_density_grid)

    density_grid = torch.reshape(density_grid, (xx.shape[0], xx.shape[1]))
    density_grid = density_grid.cpu().numpy()

    ax.contourf(xx, yy, density_grid, cmap=cmap2, alpha=0.3)
def run_rnn(inputs,
            delta_ts,
            cell,
            first_hidden=None,
            mask=None,
            feed_previous=False,
            n_steps=0,
            decoder=None,
            input_decay_params=None,
            feed_previous_w_prob=0.,
            masked_update=True):
    """Unroll ``cell`` over the time axis of ``inputs``.

    At every step the cell receives the step input, the mask for that step
    (when a mask is available) and a time feature: the cumulative gap for
    ``GRUCellExpDecay`` cells, the raw ``delta_ts`` entry otherwise.

    Args:
        inputs: tensor indexed as ``inputs[:, step]``.
        delta_ts: time gaps indexed as ``delta_ts[:, step]``.
        cell: recurrent cell called as ``cell(step_input, hidden)``.
        first_hidden: optional initial hidden state.  When given it is stored
            as the first returned state and one fewer step is run.
        mask: optional observation mask indexed as ``mask[:, step, :]``.  An
            all-ones mask is created when decoder feedback is requested
            without one.
        feed_previous: after the first step, always feed ``decoder(hidden)``
            instead of the recorded input.
        n_steps: number of steps; ``0`` means ``inputs.size(1)``.
        decoder: maps a hidden state to an input; required for feedback.
        input_decay_params: optional ``(w, b)`` pair; when given the inputs
            are first passed through ``impute_using_input_decay``.
        feed_previous_w_prob: when positive (and ``feed_previous`` is false)
            the decoder output is fed whenever a uniform draw exceeds it.
        masked_update: keep the previous hidden state for rows whose mask is
            all zero at the current step.

    Returns:
        ``(hidden, all_hiddens)``: the last hidden state, and the per-step
        states stacked, permuted with ``(1, 0, 2)`` and given a new leading
        axis.

    Raises:
        Exception: if decoder feedback is requested but ``decoder`` is None.
    """
    wants_feedback = feed_previous or feed_previous_w_prob

    if wants_feedback and decoder is None:
        raise Exception(
            "feed_previous is set to True -- please specify RNN decoder")

    if n_steps == 0:
        n_steps = inputs.size(1)

    if wants_feedback and mask is None:
        mask = torch.ones(
            (inputs.size(0), n_steps, inputs.size(-1))).to(get_device(inputs))

    exp_decay_cell = isinstance(cell, GRUCellExpDecay)
    if exp_decay_cell:
        cum_delta_ts = get_cum_delta_ts(inputs, delta_ts, mask)

    if input_decay_params is not None:
        w_input_decay, b_input_decay = input_decay_params
        inputs = impute_using_input_decay(inputs, delta_ts, mask,
                                          w_input_decay, b_input_decay)

    hidden = first_hidden
    all_hiddens = []
    if hidden is not None:
        all_hiddens.append(hidden)
        n_steps -= 1

    for step in range(n_steps):
        delta_t = delta_ts[:, step]

        # The first step always consumes the recorded input; afterwards the
        # decoder output may be fed back.
        use_decoder_output = False
        if step != 0:
            if feed_previous:
                use_decoder_output = True
            elif feed_previous_w_prob > 0:
                use_decoder_output = np.random.uniform() > feed_previous_w_prob

        if use_decoder_output:
            rnn_input = decoder(hidden)
        else:
            rnn_input = inputs[:, step]

        if mask is not None:
            mask_i = mask[:, step, :]
            rnn_input = torch.cat((rnn_input, mask_i), -1)

        time_feature = cum_delta_ts[:, step] if exp_decay_cell else delta_t
        input_w_t = torch.cat((rnn_input, time_feature), -1).squeeze(1)

        prev_hidden = hidden
        hidden = cell(input_w_t, hidden)

        if masked_update and (mask is not None) and (prev_hidden is not None):
            # Update the hidden state only where at least one feature is
            # present at the current time point.
            summed_mask = (torch.sum(mask_i, -1, keepdim=True) > 0).float()
            assert (not torch.isnan(summed_mask).any())
            hidden = summed_mask * hidden + (1 - summed_mask) * prev_hidden

        all_hiddens.append(hidden)

    all_hiddens = torch.stack(all_hiddens, 0)
    all_hiddens = all_hiddens.permute(1, 0, 2).unsqueeze(0)
    return hidden, all_hiddens
def get_reconstruction(self,
                       time_steps_to_predict,
                       truth,
                       truth_time_steps,
                       mask=None,
                       n_traj_samples=1,
                       run_backwards=True,
                       mode=None):
    """Encode ``truth`` into z0, solve the latent ODE and decode predictions.

    Args:
        time_steps_to_predict: time points passed to ``self.diffeq_solver``.
        truth: observed data handed to the encoder (concatenated with
            ``mask`` along the last axis when a mask is given).
        truth_time_steps: time points of the observations.
        mask: optional observation mask.
        n_traj_samples: number of z0 samples drawn per trajectory.
        run_backwards: forwarded to the encoder.
        mode: accepted for interface compatibility; not used here.

    Returns:
        ``(pred_x, all_extra_info)``.  ``all_extra_info`` holds
        ``"first_point"`` (mean, absolute std, sampled z0) and
        ``"latent_traj"`` (detached ODE solution), plus ``"int_lambda"`` /
        ``"log_lambda_y"`` when the Poisson process is enabled and
        ``"label_predictions"`` when classification is enabled.

    Raises:
        Exception: if ``self.encoder_z0`` is neither an
            ``Encoder_z0_ODE_RNN`` nor an ``Encoder_z0_RNN``.
    """
    encoder = self.encoder_z0
    if not isinstance(encoder, (Encoder_z0_ODE_RNN, Encoder_z0_RNN)):
        raise Exception("Unknown encoder type {}".format(
            type(self.encoder_z0).__name__))

    truth_w_mask = truth if mask is None else torch.cat((truth, mask), -1)
    first_point_mu, first_point_std = encoder(truth_w_mask,
                                              truth_time_steps,
                                              run_backwards=run_backwards)

    # Draw n_traj_samples initial states per trajectory.
    means_z0 = first_point_mu.repeat(n_traj_samples, 1, 1)
    sigma_z0 = first_point_std.repeat(n_traj_samples, 1, 1)
    first_point_enc = utils.sample_standard_gaussian(means_z0, sigma_z0)

    first_point_std = first_point_std.abs()
    assert (torch.sum(first_point_std < 0) == 0.)

    if self.use_poisson_proc:
        n_traj_samples, n_traj, n_dims = first_point_enc.size()
        # Append a vector of zeros to compute the integral of lambda.
        zeros = torch.zeros([n_traj_samples, n_traj,
                             self.input_dim]).to(get_device(truth))
        first_point_enc_aug = torch.cat((first_point_enc, zeros), -1)
        # Built alongside the augmented sample; not read again below.
        means_z0_aug = torch.cat((means_z0, zeros), -1)
    else:
        first_point_enc_aug = first_point_enc
        means_z0_aug = means_z0

    assert (not torch.isnan(time_steps_to_predict).any())
    assert (not torch.isnan(first_point_enc).any())
    assert (not torch.isnan(first_point_enc_aug).any())

    # Solve the latent ODE from the (possibly augmented) initial states.
    sol_y = self.diffeq_solver(first_point_enc_aug, time_steps_to_predict)

    if self.use_poisson_proc:
        sol_y, log_lambda_y, int_lambda, _ = self.diffeq_solver.ode_func.extract_poisson_rate(
            sol_y)

        assert (torch.sum(int_lambda[:, :, 0, :]) == 0.)
        assert (torch.sum(int_lambda[0, 0, -1, :] <= 0) == 0.)

    pred_x = self.decoder(sol_y)

    all_extra_info = {
        "first_point": (first_point_mu, first_point_std, first_point_enc),
        "latent_traj": sol_y.detach(),
    }

    if self.use_poisson_proc:
        # Integral of lambda from the last step of the ODE solver.
        all_extra_info["int_lambda"] = int_lambda[:, :, -1, :]
        all_extra_info["log_lambda_y"] = log_lambda_y

    if self.use_binary_classif:
        if self.classif_per_tp:
            label_predictions = self.classifier(sol_y)
        else:
            label_predictions = self.classifier(first_point_enc).squeeze(-1)
        all_extra_info["label_predictions"] = label_predictions

    return pred_x, all_extra_info
def run_odernn(self, data, time_steps, run_backwards=True, save_info=False):
    """Run the ODE-RNN encoder over ``data`` and return the final state.

    Between consecutive time points the hidden state is evolved with
    ``self.z0_diffeq_solver`` (or a single explicit Euler step when the gap
    is shorter than 1/50 of the whole interval) and then combined with the
    observation at that time point through ``self.GRU_update``.

    IMPORTANT: assumes that ``data`` already has the mask concatenated to it.

    Args:
        data: 3-D tensor indexed as ``data[:, i, :]`` per time point.
        time_steps: 1-D tensor of observation times.
        run_backwards: iterate over the time points from last to first.
        save_info: also collect per-step diagnostics.

    Returns:
        ``(yi, yi_std, latent_ys, extra_info)``: the state and its std after
        the last processed time point, the per-step states stacked along
        dimension 1, and a list of per-step dicts (empty unless
        ``save_info``).

    Raises:
        RuntimeError: if the first point of an ODE solution differs from the
            state it was started from.
    """
    n_traj, _, _ = data.size()
    extra_info = []

    device = get_device(data)

    prev_y = torch.zeros((1, n_traj, self.latent_dim)).to(device)
    prev_std = torch.zeros((1, n_traj, self.latent_dim)).to(device)

    # NOTE(review): this start value and the `time_steps[i - 1]` update below
    # look written for backward iteration -- confirm the intended behaviour
    # for run_backwards=False.
    prev_t, t_i = time_steps[-1] + 0.01, time_steps[-1]

    interval_length = time_steps[-1] - time_steps[0]
    minimum_step = interval_length / 50

    assert (not torch.isnan(data).any())
    assert (not torch.isnan(time_steps).any())

    latent_ys = []
    # Run the ODE and combine the y(t) estimates using gating.
    time_points_iter = range(0, len(time_steps))
    if run_backwards:
        time_points_iter = reversed(time_points_iter)

    for i in time_points_iter:
        if (prev_t - t_i) < minimum_step:
            # Short gap: one explicit Euler step instead of a solver call.
            time_points = torch.stack((prev_t, t_i))
            inc = self.z0_diffeq_solver.ode_func(prev_t,
                                                 prev_y) * (t_i - prev_t)

            assert (not torch.isnan(inc).any())

            ode_sol = prev_y + inc
            ode_sol = torch.stack((prev_y, ode_sol), 2).to(device)

            assert (not torch.isnan(ode_sol).any())
        else:
            n_intermediate_tp = max(2,
                                    ((prev_t - t_i) / minimum_step).int())

            time_points = utils.linspace_vector(prev_t, t_i,
                                                n_intermediate_tp)
            ode_sol = self.z0_diffeq_solver(prev_y, time_points)

            assert (not torch.isnan(ode_sol).any())

        # The solution must start at the state it was given.  The absolute
        # difference is used so errors of opposite sign cannot cancel, and a
        # failure raises instead of terminating the interpreter.
        mismatch = torch.mean(torch.abs(ode_sol[:, :, 0, :] - prev_y))
        if mismatch >= 0.001:
            raise RuntimeError(
                "Error: first point of the ODE is not equal to initial value "
                "(mean abs difference: {})".format(mismatch.item()))

        yi_ode = ode_sol[:, :, -1, :]
        xi = data[:, i, :].unsqueeze(0)

        yi, yi_std = self.GRU_update(yi_ode, prev_std, xi)

        prev_y, prev_std = yi, yi_std
        prev_t, t_i = time_steps[i], time_steps[i - 1]

        latent_ys.append(yi)

        if save_info:
            d = {
                "yi_ode": yi_ode.detach(),
                "yi": yi.detach(),
                "yi_std": yi_std.detach(),
                "time_points": time_points.detach(),
                "ode_sol": ode_sol.detach()
            }
            extra_info.append(d)

    latent_ys = torch.stack(latent_ys, 1)

    assert (not torch.isnan(yi).any())
    assert (not torch.isnan(yi_std).any())

    return yi, yi_std, latent_ys, extra_info