def get_reconstruction(self, batch_en, batch_de, batch_g, n_traj_samples=1,
                       run_backwards=True):
    """Sample initial latent states from a graph encoder, integrate, and decode.

    Args:
        batch_en: encoder batch. Its ``x``, ``edge_attr``, ``edge_index``,
            ``pos``, ``edge_same``, ``batch`` and ``y`` attributes are passed
            positionally, in that order, to ``self.encoder_z0``.
        batch_de: decoder batch. Only ``batch_de["time_steps"]`` is read; it is
            used as the time points at which the latent ODE is solved.
        batch_g: extra argument forwarded unchanged to ``self.diffeq_solver``.
        n_traj_samples: number of z0 samples drawn (leading dimension of the
            sampled initial state).
        run_backwards: accepted for interface compatibility; not used here.

    Returns:
        A 3-tuple ``(pred_x, all_extra_info, None)``. ``pred_x`` is the decoder
        output. ``all_extra_info["first_point"]`` is
        ``(mu.unsqueeze(0), |std|.unsqueeze(0), sampled_z0)`` and
        ``all_extra_info["latent_traj"]`` is the detached solver output.
    """
    encoder_inputs = (batch_en.x, batch_en.edge_attr, batch_en.edge_index,
                      batch_en.pos, batch_en.edge_same, batch_en.batch,
                      batch_en.y)
    # Original shape note: mu / std are [num_ball, 10].
    z0_mu, z0_std = self.encoder_z0(*encoder_inputs)

    # repeat() with one extra leading dim -> [n_traj_samples, num_ball, 10].
    z0_sample = utils.sample_standard_gaussian(
        z0_mu.repeat(n_traj_samples, 1, 1),
        z0_std.repeat(n_traj_samples, 1, 1))
    # NOTE(review): abs() is applied only after sampling, so the sampler sees
    # the raw std; presumably sample_standard_gaussian is sign-agnostic -- confirm.
    z0_std = z0_std.abs()

    tp_to_predict = batch_de["time_steps"]
    assert (torch.sum(z0_std < 0) == 0.)
    assert (not torch.isnan(tp_to_predict).any())
    assert (not torch.isnan(z0_sample).any())

    # Original note on solver output layout:
    # [n_traj_samples, n_samples, n_timepoints, n_latents]
    latent_traj = self.diffeq_solver(z0_sample, tp_to_predict, batch_g)
    pred_x = self.decoder(latent_traj)

    extra_info = {
        "first_point": (z0_mu.unsqueeze(0), z0_std.unsqueeze(0), z0_sample),
        "latent_traj": latent_traj.detach(),
    }
    return pred_x, extra_info, None
def get_reconstruction(self, time_steps_to_predict, truth, truth_time_steps,
                       mask=None, n_traj_samples=1, run_backwards=True,
                       mode=None):
    """Encode observations into z0, integrate the latent ODE, and decode.

    Args:
        time_steps_to_predict: time points at which the latent ODE is solved
            and decoded.
        truth: observed values fed to the encoder. When ``mask`` is given, it
            is concatenated to ``truth`` along the last dimension first.
        truth_time_steps: time points of the observations, passed to the
            encoder.
        mask: optional observation mask (see ``truth``).
        n_traj_samples: number of z0 samples drawn per trajectory.
        run_backwards: forwarded to ``self.encoder_z0``.
        mode: unused; kept for interface compatibility.

    Returns:
        ``(pred_x, all_extra_info)``. ``pred_x`` is the decoder output.
        ``all_extra_info`` contains:
            * "first_point": ``(mu, |std|, sampled z0)``
            * "latent_traj": the detached latent trajectory
            * "int_lambda" and "log_lambda_y", when ``self.use_poisson_proc``
            * "label_predictions", when ``self.use_binary_classif``

    Raises:
        Exception: if ``self.encoder_z0`` is not a supported encoder type.
    """
    supported_encoders = (Encoder_z0_ODE_RNN, Encoder_z0_RNN, TransformerLayer)
    if not isinstance(self.encoder_z0, supported_encoders):
        raise Exception("Unknown encoder type {}".format(
            type(self.encoder_z0).__name__))

    truth_w_mask = truth
    if mask is not None:
        truth_w_mask = torch.cat((truth, mask), -1)
    first_point_mu, first_point_std = self.encoder_z0(
        truth_w_mask, truth_time_steps, run_backwards=run_backwards)

    means_z0 = first_point_mu.repeat(n_traj_samples, 1, 1)
    sigma_z0 = first_point_std.repeat(n_traj_samples, 1, 1)
    first_point_enc = utils.sample_standard_gaussian(means_z0, sigma_z0)
    # NOTE(review): abs() is applied after sampling, so the sampler sees the
    # raw std; presumably sample_standard_gaussian is sign-agnostic -- confirm.
    first_point_std = first_point_std.abs()

    if self.use_poisson_proc:
        # Unpacking also enforces that the sampled state is 3-D.
        n_traj_samples, n_traj, _ = first_point_enc.size()
        # Append a vector of zeros to compute the integral of lambda.
        zeros = torch.zeros(
            [n_traj_samples, n_traj, self.input_dim]).to(get_device(truth))
        first_point_enc_aug = torch.cat((first_point_enc, zeros), -1)
    else:
        first_point_enc_aug = first_point_enc

    assert (not torch.isnan(time_steps_to_predict).any())
    assert (not torch.isnan(first_point_enc).any())
    assert (not torch.isnan(first_point_enc_aug).any())

    # Shape of sol_y: [n_traj_samples, n_samples, n_timepoints, n_latents]
    sol_y = self.diffeq_solver(first_point_enc_aug, time_steps_to_predict)

    if self.use_poisson_proc:
        sol_y, log_lambda_y, int_lambda, _ = \
            self.diffeq_solver.ode_func.extract_poisson_rate(sol_y)
        assert (torch.sum(int_lambda[:, :, 0, :]) == 0.)
        assert (torch.sum(int_lambda[0, 0, -1, :] <= 0) == 0.)

    pred_x = self.decoder(sol_y)

    all_extra_info = {
        "first_point": (first_point_mu, first_point_std, first_point_enc),
        "latent_traj": sol_y.detach(),
    }

    if self.use_poisson_proc:
        # Integral of lambda taken from the last step of the ODE solver.
        all_extra_info["int_lambda"] = int_lambda[:, :, -1, :]
        all_extra_info["log_lambda_y"] = log_lambda_y

    if self.use_binary_classif:
        if self.classif_per_tp:
            all_extra_info["label_predictions"] = self.classifier(sol_y)
        else:
            all_extra_info["label_predictions"] = self.classifier(
                first_point_enc).squeeze(-1)

    return pred_x, all_extra_info
def get_reconstruction(self, time_steps_to_predict, data, truth_time_steps,
                       mask=None, n_traj_samples=1, mode=None):
    """Encode observations with an RNN, then decode with a second RNN from z0.

    Args:
        time_steps_to_predict: 1-D tensor of output time points. Its first
            entry decides whether the encoder reads the data backwards.
        data: observations; ``data.size(0)`` is used as the batch size and
            ``data[:, 0, :]`` as the reference first point.
        truth_time_steps: time points of the observations in ``data``.
        mask: observation mask (required; asserted to be non-None).
        n_traj_samples: unused; kept for interface compatibility.
        mode: unused; kept for interface compatibility.

    Returns:
        ``(outputs, extra_info)``. ``outputs`` are the decoded (shifted)
        predictions. ``extra_info["first_point"]`` holds z0 mean, |std| and
        sample, each with a leading singleton dim. When
        ``self.use_binary_classif`` is set, ``extra_info`` also has
        "label_predictions".
    """
    assert (mask is not None)

    n_batch = data.size(0)
    zero_dt = torch.Tensor([0.]).to(self.device)

    def _tile_over_batch(dts):
        # A 1-D vector of deltas is shared by every trajectory in the batch.
        return dts.unsqueeze(-1).repeat((n_batch, 1, 1))

    # Run the encoder backwards (later points first) when the first requested
    # time precedes the last observed time.
    backwards = bool(time_steps_to_predict[0] < truth_time_steps[-1])

    enc_dts = truth_time_steps[1:] - truth_time_steps[:-1]
    if backwards:
        data, mask, enc_dts = (utils.reverse(data), utils.reverse(mask),
                               utils.reverse(enc_dts))

    enc_dts = torch.cat((enc_dts, zero_dt))
    if len(enc_dts.size()) == 1:
        assert (data.size(1) == enc_dts.size(0))
        enc_dts = _tile_over_batch(enc_dts)

    decay = None
    if self.input_space_decay:
        decay = (self.w_input_decay, self.b_input_decay)

    enc_hidden, _ = run_rnn(data, enc_dts, cell=self.rnn_cell_enc, mask=mask,
                            input_decay_params=decay)
    z0_mean, z0_std = utils.split_last_dim(self.z0_net(enc_hidden))
    z0_std = z0_std.abs()
    z0_sample = utils.sample_standard_gaussian(z0_mean, z0_std)

    # Decoder: deltas between consecutive prediction times, starting with 0.
    dec_dts = torch.cat(
        (zero_dt, time_steps_to_predict[1:] - time_steps_to_predict[:-1]))
    if len(dec_dts.size()) == 1:
        dec_dts = _tile_over_batch(dec_dts)

    _, all_hiddens = run_rnn(data, dec_dts,
                             cell=self.rnn_cell_dec,
                             first_hidden=z0_sample,
                             feed_previous=True,
                             n_steps=time_steps_to_predict.size(0),
                             decoder=self.decoder,
                             input_decay_params=decay)

    # Shift outputs for computing the loss: the first output is compared to
    # the second data point, and so on.
    # NOTE(review): when `backwards` is True, `data` was reversed above, so
    # data[:, 0, :] (and the decoder RNN input) is the LAST observation in
    # time -- confirm this is intended.
    outputs = utils.shift_outputs(self.decoder(all_hiddens), data[:, 0, :])

    extra_info = {
        "first_point": (z0_mean.unsqueeze(0), z0_std.unsqueeze(0),
                        z0_sample.unsqueeze(0))
    }
    if self.use_binary_classif:
        if self.classif_per_tp:
            label_predictions = self.classifier(all_hiddens)
        else:
            label_predictions = self.classifier(z0_mean).reshape(1, -1)
        extra_info["label_predictions"] = label_predictions

    # Original note -- outputs shape: [n_traj_samples, n_traj, n_tp, n_dims]
    return outputs, extra_info
def run_latent_ctfp_model(args, encoder, aug_model, values, times, vars, masks,
                          evaluation=False):
    """
    Run the latent CTFP model and compute its IWAE objective.

    Parameters:
        args: arguments returned from parse_arguments
        encoder: ODE-RNN model used as the encoder
        aug_model: CTFP model used as the decoder
        values: observations, a 3-D tensor of shape
            batch_size x max_length x input_size
        times: observation timestamps, a 3-D tensor of shape
            batch_size x max_length x 1
        vars: differences between consecutive observation timestamps, a 2-D
            tensor of shape batch_size x max_length (a tensor that already has
            a trailing singleton dimension is also accepted)
        masks: a 2-D binary tensor of shape batch_size x max_length showing
            whether each position is an observation or padding
        evaluation (bool): whether to run in evaluation mode. Evaluation uses
            ``args.niwae_test`` importance samples; training uses
            ``args.num_iwae_samples``.

    Returns:
        If ``evaluation`` is true: the IWAE loss summed over the batch and
        divided by the number of observed positions (``masks.sum()``).
        Otherwise: a tuple ``(loss, loss_training)``. ``loss`` is the IWAE
        loss. ``loss_training`` is a surrogate objective that weights each
        sample by a detached softmax of the importance weights. Both are
        divided by ``batch_size * max_length``.

    Raises:
        NotImplementedError: if ``args.activation`` is not "exp", "softplus"
            or "identity".
    """
    num_iwae_samples = (args.niwae_test if evaluation
                        else args.num_iwae_samples)
    # Take the batch size from the data instead of args.batch_size /
    # args.test_batch_size. A final, smaller batch would otherwise make the
    # reshapes below fail or produce mis-shaped tensors.
    batch_size = values.shape[0]
    max_length = times.shape[1]

    # Encode each sub-batch and concatenate the posterior parameters along
    # dim 1.
    mean_list, stdv_list = [], []
    for item in create_separate_batches(values, times, masks):
        z_mean, z_stdv = encoder(item[0], item[1])
        mean_list.append(z_mean)
        stdv_list.append(z_stdv)
    means = torch.cat(mean_list, dim=1)
    stdvs = torch.cat(stdv_list, dim=1)

    # Sample latent variables: tile the posterior along dim 0, once per
    # importance sample.
    repeat_times = [num_iwae_samples] + [1] * (len(means.shape) - 1)
    means = means.repeat(*repeat_times)
    stdvs = stdvs.repeat(*repeat_times)
    latent = sample_standard_gaussian(means, stdvs)

    ## Decode latent: append the latent code to the timestamp at every step.
    latent_sequence = latent.view(-1, args.latent_size).unsqueeze(1)
    latent_sequence = latent_sequence.repeat(1, max_length, 1)
    time_to_cat = times.repeat(num_iwae_samples, 1, 1)
    times = torch.cat([latent_sequence, time_to_cat], -1)

    ## Run aug_model with reverse=True on zero-valued observations to get the
    ## augmented (auxiliary) dimensions.
    values = values.repeat(num_iwae_samples, 1, 1)
    aux = torch.cat([torch.zeros_like(values), times], dim=2)
    aux = aux.view(-1, aux.shape[2])
    aux, _ = aug_model(aux, torch.zeros(aux.shape[0], 1).to(aux), reverse=True)
    aux = aux[:, args.effective_shape:]

    ## Transform observations according to args.activation (together with the
    ## log-Jacobian of that transform), then run aug_model on
    ## [transformed values, aux] to reach the base space.
    if args.activation == "exp":
        transform_values, transform_logdet = log_jaco(values)
    elif args.activation == "softplus":
        transform_values, transform_logdet = inversoft_jaco(values)
    elif args.activation == "identity":
        transform_values = values
        transform_logdet = torch.sum(torch.zeros_like(values), dim=2)
    else:
        raise NotImplementedError
    aug_values = torch.cat(
        [transform_values.view(-1, transform_values.shape[2]), aux], dim=1)
    base_values, flow_logdet = aug_model(
        aug_values, torch.zeros(aug_values.shape[0], 1).to(aug_values))
    base_values = base_values[:, :args.effective_shape]
    base_values = base_values.view(values.shape[0], -1, args.effective_shape)

    ## Reshape both log-determinants to one row per (sample, sequence) pair.
    flow_logdet = flow_logdet.sum(-1).view(num_iwae_samples * batch_size, -1)
    transform_logdet = transform_logdet.view(num_iwae_samples * batch_size, -1)

    vars_unsqueed = vars.unsqueeze(-1) if len(vars.shape) == 2 else vars

    ## Reconstruction log likelihood
    ll = compute_ll(
        flow_logdet + transform_logdet,
        base_values,
        vars_unsqueed.repeat(num_iwae_samples, 1, 1),
        masks.repeat(num_iwae_samples, 1),
    )
    ll = ll.view(num_iwae_samples, batch_size)

    ## IWAE: importance weights log p(x|z) + log p(z) - log q(z|x).
    posterior = torch.distributions.Normal(means[:1], stdvs[:1])
    prior = torch.distributions.Normal(torch.zeros_like(means[:1]),
                                       torch.ones_like(stdvs[:1]))
    prior_z = prior.log_prob(latent).sum(-1)
    posterior_z = posterior.log_prob(latent).sum(-1)
    weights = ll + prior_z - posterior_z
    loss = -torch.logsumexp(weights, 0) + np.log(num_iwae_samples)
    if evaluation:
        return torch.sum(loss) / torch.sum(masks)

    loss = torch.sum(loss) / (batch_size * max_length)
    loss_training = -torch.sum(F.softmax(weights, 0).detach() * weights) / (
        batch_size * max_length)
    return loss, loss_training
def get_reconstruction(self, time_steps_to_predict, truth, truth_time_steps,
                       mask=None, n_traj_samples=1, run_backwards=True,
                       mode='interp'):
    """Reconstruct trajectories with a spline-driven latent ODE.

    The initial latent state is sampled from the encoder posterior. A natural
    cubic spline of the observations is installed on the ODE function. The
    ODE state carries one extra trailing channel that accumulates the prior
    log-probability of the latent path.

    :param time_steps_to_predict: time points at which the ODE is solved.
    :param truth: observed values. Positions where ``mask == 0`` are marked
        NaN for the spline.
    :param truth_time_steps: time points of the observations.
    :param mask: optional observation mask. When given, it is concatenated to
        ``truth`` for the encoder.
    :param n_traj_samples: number of z0 samples drawn per trajectory.
    :param run_backwards: forwarded to the encoder.
    :param mode: 'interp' (supported) or 'extrap' (not implemented yet).
    :return: ``(pred_x, all_extra_info)``. ``all_extra_info`` holds
        "first_point", "latent_traj" and "prior_log_prob", plus the Poisson
        and classifier entries when those options are enabled.
    :raises NotImplementedError: if ``mode == 'extrap'``.
    :raises ValueError: if ``mode`` is neither 'interp' nor 'extrap'.
    :raises Exception: if ``self.encoder_z0`` is not a supported encoder type.
    """
    # Validate the mode before doing any expensive work. Previously an unknown
    # mode left sol_y / prior_log_prob unbound and crashed later.
    if mode == 'extrap':
        # TODO: extrapolation. Starting from the last latent state of a solve
        # over truth_time_steps (sol_y[:, :, -1]), use the GaussPrior to
        # predict autoregressively over time_steps_to_predict. Options:
        #   1. Solve the ODE directly from the mean -- but then it is no longer
        #      sampling.
        #   2. Local linearisation (first-order Taylor expansion): at time t,
        #      compute the predictive distribution for t + dt and sample from
        #      it. See Eq. (6)-(8) of Chen, F., Agüero, J. C., Gilson, M.,
        #      Garnier, H., & Liu, T. (2017). EM-based identification of
        #      continuous-time ARMA models from irregularly sampled data.
        #      Automatica, 77, 293-301.
        #      https://doi.org/10.1016/j.automatica.2016.11.020
        raise NotImplementedError("mode='extrap' is not implemented yet")
    if mode != 'interp':
        raise ValueError(
            "mode must be 'interp' or 'extrap', got {!r}".format(mode))

    if not isinstance(self.encoder_z0, (Encoder_z0_ODE_RNN, Encoder_z0_RNN)):
        raise Exception("Unknown encoder type {}".format(
            type(self.encoder_z0).__name__))

    truth_w_mask = truth
    if mask is not None:
        truth_w_mask = torch.cat((truth, mask), -1)
    first_point_mu, first_point_std = self.encoder_z0(
        truth_w_mask, truth_time_steps, run_backwards=run_backwards)
    means_z0 = first_point_mu.repeat(n_traj_samples, 1, 1)
    sigma_z0 = first_point_std.repeat(n_traj_samples, 1, 1)
    first_point_enc = utils.sample_standard_gaussian(means_z0, sigma_z0)
    # NOTE(review): abs() is applied after sampling, so the sampler sees the
    # raw std; presumably sample_standard_gaussian is sign-agnostic -- confirm.
    first_point_std = first_point_std.abs()

    # Unpacking also enforces that the sampled state is 3-D.
    n_traj_samples, n_traj, _ = first_point_enc.size()
    device = utils.get_device(truth)

    if self.use_poisson_proc:
        # Append a vector of zeros to compute the integral of lambda.
        zeros = torch.zeros(
            [n_traj_samples, n_traj, self.input_dim]).to(device)
        first_point_enc_aug = torch.cat((first_point_enc, zeros), -1)
    else:
        first_point_enc_aug = first_point_enc

    assert (not torch.isnan(time_steps_to_predict).any())
    assert (not torch.isnan(first_point_enc).any())
    assert (not torch.isnan(first_point_enc_aug).any())

    # Add a zero-initialised trailing channel in which the ODE accumulates the
    # log-probability of the prior along the posterior path.
    cum_log_prob_prior_0 = torch.zeros([n_traj_samples, n_traj, 1]).to(device)
    first_point_enc_aug = torch.cat(
        (first_point_enc_aug, cum_log_prob_prior_0), -1)

    # Build a natural cubic spline of the observations and hand it to the
    # ODE function.
    truth_for_interpolate = truth.clone()
    if mask is not None:
        # NaturalCubicSpline requires unknown positions to be tagged with NaN.
        truth_for_interpolate[mask == 0] = torch.tensor(
            float('nan')).to(device)
    coeffs = natural_cubic_spline_coeffs(truth_time_steps,
                                         truth_for_interpolate)
    interpolation = NaturalCubicSpline(truth_time_steps, coeffs)
    self.diffeq_solver.ode_func.splines_setup(interpolation)

    sol_y_with_logprob = self.diffeq_solver(first_point_enc_aug,
                                            time_steps_to_predict)
    sol_y = sol_y_with_logprob[..., :-1]
    # Per the original author's note: the density from the stochastic
    # process's distribution is numerically unstable, so the ODE integrates it
    # divided by const_for_prior_logp. Multiply it back here.
    prior_log_prob = (sol_y_with_logprob[..., -1]
                      * self.diffeq_solver.ode_func.const_for_prior_logp)

    if self.use_poisson_proc:
        sol_y, log_lambda_y, int_lambda, _ = \
            self.diffeq_solver.ode_func.extract_poisson_rate(sol_y)
        assert (torch.sum(int_lambda[:, :, 0, :]) == 0.)
        assert (torch.sum(int_lambda[0, 0, -1, :] <= 0) == 0.)

    pred_x = self.decoder(sol_y)

    all_extra_info = {
        "first_point": (first_point_mu, first_point_std, first_point_enc),
        "latent_traj": sol_y.detach(),
        'prior_log_prob': prior_log_prob
    }

    if self.use_poisson_proc:
        # Integral of lambda taken from the last step of the ODE solver.
        all_extra_info["int_lambda"] = int_lambda[:, :, -1, :]
        all_extra_info["log_lambda_y"] = log_lambda_y

    if self.use_binary_classif:
        if self.classif_per_tp:
            all_extra_info["label_predictions"] = self.classifier(sol_y)
        else:
            all_extra_info["label_predictions"] = self.classifier(
                first_point_enc).squeeze(-1)

    return pred_x, all_extra_info