def _forward(self, x, t, temp, mask):
    """Run one training forward pass of the clustered (y, z) latent model.

    Embeds markers and times and samples the cluster latent ``y`` and the
    latent ``z`` from ``self.encoder``. It then runs ``self.rnn`` over the
    embedded sequence and scores markers and times with a generative head.
    That head is built from the shifted hidden states ``h_0 .. h_{T-1}``
    together with ``z`` and ``y``.

    Args:
        x: marker tensor; ``self.embed_x(x)`` must be 3-D with leading
            dims (T, BS).
        t: time tensor, embedded by ``self.embed_t`` and passed to the time
            likelihood / metric helpers.
        temp: temperature forwarded to ``self.encoder`` (presumably for a
            relaxed categorical sample of ``y`` -- TODO confirm).
        mask: presence mask multiplied elementwise with the per-step KL
            terms; assumed TxBS -- TODO confirm.

    Returns:
        Tuple ``(time_log_likelihood, marker_log_likelihood, KL, metric_dict)``
        where ``KL`` is the scalar sum of the masked cluster and ``z`` KL terms.

    Raises:
        RuntimeError: if the summed KL is negative or NaN.
    """
    # Transform markers and timesteps into the embedding spaces.
    phi_x, phi_t = self.embed_x(x), self.embed_t(t)
    phi_xt = torch.cat([phi_x, phi_t], dim=-1)
    T, BS, _ = phi_x.shape

    ## Inference
    # Sample y and z and get their posterior parameters from the embedded
    # (marker, time) sequence.
    (posterior_sample_y, posterior_sample_z, posterior_logits_y,
     (posterior_mu_z, posterior_logvar_z)) = self.encoder(phi_xt, temp)
    # Broadcast the cluster logits across all T steps so the cluster KL can
    # be masked per step below.
    posterior_logits_y = posterior_logits_y.expand(T, -1, -1)

    # Posterior distributions.
    posterior_dist_z = Normal(posterior_mu_z, torch.exp(posterior_logvar_z * 0.5))
    posterior_dist_y = Categorical(logits=posterior_logits_y)
    # Priors: standard Normal for z and a uniform Categorical for y. The
    # *_like constructors match the posterior's shape, dtype and device
    # without tying the prior into the autograd graph.
    prior_dist_z = Normal(torch.zeros_like(posterior_mu_z),
                          torch.ones_like(posterior_mu_z))
    prior_dist_y = Categorical(
        probs=torch.full_like(posterior_logits_y, 1.0 / self.cluster_dim))

    ## Generative part
    # Run the RNN over the embedded sequence from a zero initial state.
    h_0 = torch.zeros(1, BS, self.hidden_dim).to(device)
    hidden_seq, _ = self.rnn(phi_xt, h_0)
    # Prepend h_0 so hidden_seq holds h_0 .. h_T (length T + 1).
    hidden_seq = torch.cat([h_0, hidden_seq], dim=0)

    # Combine (h, z, y) to form the input for the generative part; the
    # hidden states are shifted by one step (h_0 .. h_{T-1}).
    concat_hzy = torch.cat(
        [hidden_seq[:-1], posterior_sample_z, posterior_sample_y], dim=-1)
    phi_hzy = self.gen_pre_module(concat_hzy)
    mu_marker, logvar_marker = generate_marker(self, phi_hzy, None)
    time_log_likelihood, mu_time = compute_point_log_likelihood(self, phi_hzy, t)
    marker_log_likelihood = compute_marker_log_likelihood(
        self, x, mu_marker, logvar_marker)

    KL_cluster = kl_divergence(posterior_dist_y, prior_dist_y) * mask
    KL_z = kl_divergence(posterior_dist_z, prior_dist_z).sum(-1) * mask
    KL = KL_cluster.sum() + KL_z.sum()
    # Both KL terms are non-negative in exact arithmetic, so a negative or
    # NaN total signals a numerical problem upstream. Fail loudly instead of
    # opening an interactive debugger; an explicit check (unlike `assert`)
    # also survives `python -O`.
    if not bool(KL >= 0):
        raise RuntimeError(
            "KL divergence is negative or NaN: {}".format(KL.item()))

    metric_dict = {}
    with torch.no_grad():
        if self.time_loss == 'intensity':
            # NOTE(review): this passes the raw RNN states `hidden_seq`
            # (length T + 1), not `phi_hzy`, which produced
            # time_log_likelihood above. Confirm this is intended.
            mu_time = compute_time_expectation(self, hidden_seq, t, mask)[:, :, None]
        get_marker_metric(self.marker_type, mu_marker, x, mask, metric_dict)
        get_time_metric(mu_time, t, mask, metric_dict)
    return time_log_likelihood, marker_log_likelihood, KL, metric_dict
def _forward(self, x, t, mask):
    """Score a batch under the plain recurrent point-process model.

    Args:
        x: markers, TxBSxmarker_dim when real-valued or TxBSx1 when
            categorical.
        t: TxBSxtime_dim times; ``t[i, :, 0]`` is the absolute time of step
            i and ``t[i, :, 1]`` the gap ``d_i = t_i - t_{i-1}``.
        mask: presence mask. It is sliced as ``mask[1:, :]`` and multiplied
            with (T-1)xBS error terms below, so a TxBS tensor is assumed.

    Returns:
        ``(time_log_likelihood, marker_log_likelihood, metric_dict)``. Both
        log-likelihoods are TxBS; the first step of ``time_log_likelihood``
        is replaced by 0.
    """
    # Hidden states, (T)xBSxself.shared_output_layers[-1]; the first value
    # returned by run_forward_rnn is not needed here.
    _, states = self.run_forward_rnn(x, t)
    batch_size = x.size(1)

    # Marker head (does not yet condition on the time gap).
    marker_mu, marker_logvar = generate_marker(self, states, t)
    metrics = {}
    time_ll, point_time = compute_point_log_likelihood(self, states, t)

    with torch.no_grad():
        get_marker_metric(self.marker_type, marker_mu, x, mask, metrics)
        if self.time_loss == 'intensity':
            predicted = compute_time_expectation(self, states, t, mask)
        else:
            predicted = point_time[:, :, 0]
        observed = mask[1:, :]
        # Despite the 'time_mse' key, this sums absolute (not squared)
        # errors; step 0 is excluded.
        abs_err = torch.abs(predicted - t[:, :, 0])[1:, :] * observed
        metrics['time_mse'] = abs_err.sum().detach().cpu().numpy()
        metrics['time_mse_count'] = observed.sum().detach().cpu().numpy()

    # Replace the first step's time log-likelihood with 0.
    first_step = torch.zeros(1, batch_size).to(device)
    time_ll = torch.cat([first_step, time_ll[1:, :]], dim=0)
    marker_ll = compute_marker_log_likelihood(self, x, marker_mu, marker_logvar)
    return time_ll, marker_ll, metrics
def _forward(self, x, t, mask):
    """Score a batch under the sequence-level latent-variable model.

    Args:
        x: markers, TxBSxmarker_dim when real-valued or TxBSx1 when
            categorical.
        t: TxBSx2 times; ``t[i, :, 0]`` is the absolute time of step i and
            ``t[i, :, 1]`` the gap ``d_i = t_i - t_{i-1}``.
        mask: TxBS presence mask; ``mask[i, b] == 1`` marks an observed step.

    Returns:
        ``(time_log_likelihood, marker_log_likelihood, kld_loss, metric_dict)``.
        The two log-likelihoods are TxBS. ``kld_loss`` is the unreduced
        (BSxlatent_dim) KL between q(z) and a standard Normal prior.
    """
    # Forward and backward RNN passes; (T)xBSxhidden_dim each.
    fwd_states, bwd_states = self.run_forward_backward_rnn(x, t, mask)
    q_mu, q_logvar = self.encoder(fwd_states, bwd_states)  # TxBSxlatent_dim

    # Only the first step of the encoder output parameterises z.
    q_mu_first = q_mu[0, :, :]
    q_logvar_first = q_logvar[0, :, :]
    latent = self.reparameterize(q_mu_first, q_logvar_first)[None, :, :]  # 1xBSxlatent_dim
    decoder_input = self.preprocess_hidden_latent_state(fwd_states, latent)

    # Marker head (does not yet condition on the time gap); TxBSxmarker_dim.
    marker_mu, marker_logvar = generate_marker(self, decoder_input, t)
    marker_ll = compute_marker_log_likelihood(self, x, marker_mu, marker_logvar)
    time_ll, time_pred = compute_point_log_likelihood(self, decoder_input, t)

    metrics = {}
    with torch.no_grad():
        if self.time_loss == 'intensity':
            time_pred = compute_time_expectation(self, decoder_input, t, mask).unsqueeze(2)
        get_marker_metric(self.marker_type, marker_mu, x, mask, metrics)
        get_time_metric(time_pred, t, mask, metrics)

    q_z = Normal(q_mu_first, q_logvar_first.exp().sqrt())
    p_z = Normal(0, 1)
    kld = kl_divergence(q_z, p_z)  # BSxlatent_dim
    return time_ll, marker_ll, kld, metrics
def _forward(self, x, t, temp, mask):
    """Run one training forward pass of the clustered (y, z) model.

    This variant uses a learned prior over ``z`` and adds an auxiliary loss.
    That loss reconstructs times and markers from the cluster logits alone.

    Args:
        x: marker tensor; ``self.embed_x(x)`` must be 3-D with leading dims
            (T, BS). The auxiliary branch feeds ``x.view(-1)`` to
            cross-entropy, so integer class indices are assumed.
        t: time tensor; ``t[:, :, 0]`` is used as the reconstruction target.
        temp: temperature forwarded to ``self.encoder`` (presumably for a
            relaxed categorical sample of ``y`` -- TODO confirm).
        mask: presence mask multiplied with TxBS per-step terms; assumed
            TxBS -- TODO confirm.

    Returns:
        Tuple ``(time_log_likelihood, marker_log_likelihood, KL, metric_dict,
        aug_loss)``.
        - ``KL`` is a scalar.
        - ``aug_loss`` is the scalar negative auxiliary log-likelihood over
          steps 1..T-1.
        - ``metric_dict`` also holds ``"z_cluster"``, a detached CPU copy of
          the cluster logits.

    Raises:
        RuntimeError: if the summed KL is negative or NaN.
    """
    # Transform markers and timesteps into the embedding spaces.
    phi_x, phi_t = self.embed_x(x), self.embed_t(t)
    phi_xt = torch.cat([phi_x, phi_t], dim=-1)
    T, BS, _ = phi_x.shape

    ## Compute h_t, shape (T + 1) x BS x hidden_dim
    # Run the RNN over the embedded sequence from a zero initial state.
    h_0 = torch.zeros(1, BS, self.hidden_dim).to(device)
    hidden_seq, _ = self.rnn(phi_xt, h_0)
    # Prepend h_0 so hidden_seq holds h_0 .. h_T.
    hidden_seq = torch.cat([h_0, hidden_seq], dim=0)

    ## Inference a_t = q([x_t, h_t], a_{t+1})
    # Sample y and z and get their posterior parameters from the embedded
    # sequence and the shifted hidden states (h_0 .. h_{T-1}).
    (posterior_sample_y, posterior_sample_z, posterior_logits_y,
     (posterior_mu_z, posterior_logvar_z)) = self.encoder(
        phi_xt, hidden_seq[:-1, :, :], temp, mask)

    # Posterior distributions.
    posterior_dist_z = Normal(posterior_mu_z, torch.exp(posterior_logvar_z * 0.5))
    posterior_dist_y = Categorical(logits=posterior_logits_y)
    # Priors: z's prior comes from self.prior applied to the sampled z
    # (TxBSxlatent_dim, presumably Normal(0, 1) at the first step -- TODO
    # confirm). y's prior is uniform over self.cluster_dim clusters.
    prior_mu, prior_logvar = self.prior(posterior_sample_z)
    prior_dist_z = Normal(prior_mu, (prior_logvar * 0.5).exp())
    prior_dist_y = Categorical(
        probs=1. / self.cluster_dim * torch.ones(1, BS, self.cluster_dim).to(device))

    ## Generative part
    # Combine (h, z, y) to form the input for the generative part.
    concat_hzy = torch.cat(
        [hidden_seq[:-1], posterior_sample_z, posterior_sample_y], dim=-1)
    phi_hzy = self.gen_pre_module(concat_hzy)
    mu_marker, logvar_marker = generate_marker(self, phi_hzy, None)
    time_log_likelihood, mu_time = compute_point_log_likelihood(self, phi_hzy, t)
    marker_log_likelihood = compute_marker_log_likelihood(
        self, x, mu_marker, logvar_marker)

    # The cluster KL is left unmasked. The uniform prior has a leading dim
    # of 1, so this is presumably a per-sequence (1xBS) term -- TODO confirm.
    KL_cluster = kl_divergence(posterior_dist_y, prior_dist_y)
    KL_z = kl_divergence(posterior_dist_z, prior_dist_z).sum(-1) * mask
    KL = KL_cluster.sum() + KL_z.sum()
    # Both KL terms are non-negative in exact arithmetic, so a negative or
    # NaN total signals a numerical problem upstream. Fail loudly instead of
    # opening an interactive debugger; an explicit check (unlike `assert`)
    # also survives `python -O`.
    if not bool(KL >= 0):
        raise RuntimeError(
            "KL divergence is negative or NaN: {}".format(KL.item()))

    ## Augmented loss: reconstruct times and markers from the cluster logits.
    sl, bs = x.size(0), x.size(1)
    logits_per_step = posterior_logits_y.expand(sl, -1, -1)
    aug_layer = self.aug_output_layer(logits_per_step)  # sl x bs x ...
    # sl x bs x marker_dim (the .view(sl, bs) below relies on this).
    aug_x_mu = self.aug_output_x_mu(aug_layer)
    aug_t_mu = self.aug_output_t_mu(aug_layer)
    aug_t_logvar = self.aug_output_t_logvar(aug_layer)
    aug_t_sigma = (aug_t_logvar * 0.5).exp() + self.sigma_min
    aug_time_recon = Normal(aug_t_mu, aug_t_sigma)
    aug_ll = aug_time_recon.log_prob(t[:, :, 0][:, :, None]).sum(dim=-1) * mask  # TxBS
    # NOTE(review): markers are scored with cross-entropy, so this branch
    # assumes categorical markers regardless of self.marker_type.
    aug_mu_ = aug_x_mu.view(-1, self.marker_dim)
    x_ = x.view(-1)
    aug_ml = -1 * F.cross_entropy(aug_mu_, x_, reduction='none').view(sl, bs) * mask  # TxBS
    # Step 0 is excluded from the auxiliary loss.
    aug_loss = -1. * (aug_ll[1:, :] + aug_ml[1:, :]).sum()

    metric_dict = {"z_cluster": posterior_logits_y.detach().cpu()}
    with torch.no_grad():
        if self.time_loss == 'intensity':
            # NOTE(review): this passes the raw RNN states `hidden_seq`
            # (length T + 1), not `phi_hzy`, which produced
            # time_log_likelihood above. Confirm this is intended.
            mu_time = compute_time_expectation(self, hidden_seq, t, mask)[:, :, None]
        get_marker_metric(self.marker_type, mu_marker, x, mask, metric_dict)
        get_time_metric(mu_time, t, mask, metric_dict)
    return time_log_likelihood, marker_log_likelihood, KL, metric_dict, aug_loss