示例#1
0
    def _forward(self, x, t, temp, mask):
        """Run one inference + generation pass and return the loss terms.

        Args:
            x: Marker sequence; fed to ``self.embed_x`` and to the marker
                likelihood / metric helpers.
            t: Time sequence; fed to ``self.embed_t`` and to the time
                likelihood / metric helpers.
            temp: Forwarded unchanged to ``self.encoder`` (presumably a
                sampling temperature -- confirm against the encoder).
            mask: Multiplied elementwise into the per-step KL terms and
                forwarded to the metric helpers.

        Returns:
            Tuple ``(time_log_likelihood, marker_log_likelihood, KL,
            metric_dict)``. ``KL`` is the sum of the masked cluster KL and
            the masked latent KL.

        Raises:
            RuntimeError: If the summed KL term is negative or NaN.
        """
        # Embed markers and timestamps, then join them on the feature axis.
        phi_x, phi_t = self.embed_x(x), self.embed_t(t)
        phi_xt = torch.cat([phi_x, phi_t], dim=-1)
        T, BS, _ = phi_x.shape

        ## Inference
        # The encoder returns samples for y and z, the logits of y and the
        # (mean, log-variance) of z.
        (posterior_sample_y, posterior_sample_z, posterior_logits_y,
         (posterior_mu_z, posterior_logvar_z)) = self.encoder(phi_xt, temp)

        # Broadcast the cluster logits along the leading (time) axis so they
        # line up with the per-step mask.
        posterior_logits_y = posterior_logits_y.expand(T, -1, -1)

        # Posterior distributions.
        posterior_dist_z = Normal(posterior_mu_z,
                                  torch.exp(posterior_logvar_z * 0.5))
        posterior_dist_y = Categorical(logits=posterior_logits_y)

        # Priors: standard Normal for z, uniform Categorical for y.
        # Built with *_like so that non-finite posterior parameters cannot
        # leak NaNs into the prior (0 * inf is NaN).
        prior_dist_z = Normal(torch.zeros_like(posterior_mu_z),
                              torch.ones_like(posterior_mu_z))
        prior_dist_y = Categorical(
            probs=torch.full_like(posterior_logits_y,
                                  1.0 / self.cluster_dim))

        ## Generative part
        # Run the RNN over the embedded sequence from a zero initial state,
        # allocated directly on the device the inputs live on.
        h_0 = torch.zeros(1, BS, self.hidden_dim, device=phi_xt.device)
        hidden_seq, _ = self.rnn(phi_xt, h_0)
        # Prepend h_0 so that hidden_seq[:-1] holds the state *before* each
        # step.
        hidden_seq = torch.cat([h_0, hidden_seq], dim=0)

        # (h_t, z_t, y) is the input of the generative networks.
        concat_hzy = torch.cat(
            [hidden_seq[:-1], posterior_sample_z, posterior_sample_y], dim=-1)
        phi_hzy = self.gen_pre_module(concat_hzy)
        mu_marker, logvar_marker = generate_marker(self, phi_hzy, None)
        time_log_likelihood, mu_time = compute_point_log_likelihood(
            self, phi_hzy, t)
        marker_log_likelihood = compute_marker_log_likelihood(
            self, x, mu_marker, logvar_marker)

        KL_cluster = kl_divergence(posterior_dist_y, prior_dist_y) * mask
        KL_z = kl_divergence(posterior_dist_z, prior_dist_z).sum(-1) * mask
        KL = KL_cluster.sum() + KL_z.sum()
        # A KL divergence cannot be negative; a negative or NaN value means
        # the distribution parameters are broken. Fail loudly instead of
        # dropping into an interactive debugger, which stalls batch jobs.
        if not bool(KL >= 0):
            raise RuntimeError(
                "KL divergence must be non-negative, got {}".format(
                    KL.item()))

        metric_dict = {}
        with torch.no_grad():
            if self.time_loss == 'intensity':
                # NOTE(review): this passes the raw hidden_seq (h_0 included)
                # rather than phi_hzy, unlike the likelihood call above --
                # confirm this is what compute_time_expectation expects.
                mu_time = compute_time_expectation(
                    self, hidden_seq, t, mask)[:, :, None]
            get_marker_metric(self.marker_type, mu_marker, x, mask,
                              metric_dict)
            get_time_metric(mu_time, t, mask, metric_dict)

        return time_log_likelihood, marker_log_likelihood, KL, metric_dict
示例#2
0
    def _forward(self, x, t, mask):
        """Compute time/marker log-likelihoods and evaluation metrics.

        Args:
            x: Marker tensor; its first two dimensions are read as
                (sequence length, batch size).
            t: Time tensor; ``t[:, :, 0]`` is used as the target time when
                computing the time error metric.
            mask: Multiplied elementwise with the per-step time error, so it
                is presumably TxBS (the earlier documentation said TxBSx1 --
                confirm against the caller).

        Returns:
            Tuple ``(time_log_likelihood, marker_log_likelihood,
            metric_dict)``. The first row of ``time_log_likelihood`` is
            replaced by zeros.
        """
        # Only the second value returned by the forward RNN is used.
        _, hidden_states = self.run_forward_rnn(x, t)
        batch_size = x.size(1)

        # Marker parameters and time likelihood from the same hidden states.
        marker_out_mu, marker_out_logvar = generate_marker(
            self, hidden_states, t)
        time_log_likelihood, mu_time = compute_point_log_likelihood(
            self, hidden_states, t)

        metric_dict = {}
        with torch.no_grad():
            get_marker_metric(self.marker_type, marker_out_mu, x, mask,
                              metric_dict)

            # Pick the time prediction that matches the configured loss.
            if self.time_loss == 'intensity':
                predicted_time = compute_time_expectation(
                    self, hidden_states, t, mask)
            else:
                predicted_time = mu_time[:, :, 0]

            # Absolute error, skipping the first step and masked-out steps.
            valid_steps = mask[1:, :]
            abs_error = torch.abs(predicted_time -
                                  t[:, :, 0])[1:, :] * valid_steps
            metric_dict['time_mse'] = abs_error.sum().detach().cpu().numpy()
            metric_dict['time_mse_count'] = (
                valid_steps.sum().detach().cpu().numpy())

        # The first time point contributes zero log-likelihood.
        first_step = torch.zeros(1, batch_size).to(device)
        time_log_likelihood = torch.cat(
            [first_step, time_log_likelihood[1:, :]], dim=0)

        marker_log_likelihood = compute_marker_log_likelihood(
            self, x, marker_out_mu, marker_out_logvar)

        return time_log_likelihood, marker_log_likelihood, metric_dict
示例#3
0
    def _forward(self, x, t, mask):
        """Compute likelihoods, the latent KL term and evaluation metrics.

        Args:
            x: Marker tensor, passed to the RNN and to the marker
                likelihood / metric helpers.
            t: Time tensor, passed to the RNN and to the time helpers.
            mask: Passed to the RNN and to the metric helpers.

        Returns:
            Tuple ``(time_log_likelihood, marker_log_likelihood, kld_loss,
            metric_dict)``. ``kld_loss`` is the elementwise KL divergence of
            the first-step posterior from a standard Normal.
        """
        # Forward and backward hidden states feed the encoder.
        hs, back_hs = self.run_forward_backward_rnn(x, t, mask)
        mu, logvar = self.encoder(hs, back_hs)

        # Only the first time step parametrizes the latent variable; the
        # sample gets a leading singleton axis before it is combined with hs.
        first_mu = mu[0, :, :]
        first_logvar = logvar[0, :, :]
        z = self.reparameterize(first_mu, first_logvar).unsqueeze(0)

        hz_embedded = self.preprocess_hidden_latent_state(hs, z)

        # Marker parameters and their log-likelihood.
        marker_out_mu, marker_out_logvar = generate_marker(
            self, hz_embedded, t)
        marker_log_likelihood = compute_marker_log_likelihood(
            self, x, marker_out_mu, marker_out_logvar)

        # Time log-likelihood and the predicted time.
        time_log_likelihood, mu_time = compute_point_log_likelihood(
            self, hz_embedded, t)

        metric_dict = {}
        with torch.no_grad():
            if self.time_loss == 'intensity':
                # Replace the prediction by the expectation helper's output,
                # with a trailing singleton axis added.
                expected_time = compute_time_expectation(
                    self, hz_embedded, t, mask)
                mu_time = expected_time[:, :, None]
            get_marker_metric(self.marker_type, marker_out_mu, x, mask,
                              metric_dict)
            get_time_metric(mu_time, t, mask, metric_dict)

        # KL(q(z | .) || N(0, 1)), one value per latent element.
        posterior_std = torch.sqrt(torch.exp(first_logvar))
        kld_loss = kl_divergence(Normal(first_mu, posterior_std),
                                 Normal(0, 1))

        return (time_log_likelihood, marker_log_likelihood, kld_loss,
                metric_dict)
示例#4
0
    def _forward(self, x, t, temp, mask):
        """Run inference + generation and return losses, metrics and the
        auxiliary (augmented) loss.

        Args:
            x: Marker sequence; fed to ``self.embed_x``, to the marker
                likelihood / metric helpers and, flattened, used as class
                targets of the augmented cross-entropy term.
            t: Time sequence; fed to ``self.embed_t`` and the time helpers.
                ``t[:, :, 0]`` is the target of the augmented time term.
            temp: Forwarded unchanged to ``self.encoder`` (presumably a
                sampling temperature -- confirm against the encoder).
            mask: Forwarded to the encoder and metric helpers, and multiplied
                elementwise into the latent KL and the augmented terms.

        Returns:
            Tuple ``(time_log_likelihood, marker_log_likelihood, KL,
            metric_dict, aug_loss)``.

        Raises:
            RuntimeError: If the summed KL term is negative or NaN.
        """
        # Embed markers and timestamps, then join them on the feature axis.
        phi_x, phi_t = self.embed_x(x), self.embed_t(t)
        phi_xt = torch.cat([phi_x, phi_t], dim=-1)
        BS = phi_x.shape[1]
        dev = phi_xt.device

        ## Hidden states
        # Run the RNN from a zero initial state, allocated directly on the
        # device the inputs live on.
        h_0 = torch.zeros(1, BS, self.hidden_dim, device=dev)
        hidden_seq, _ = self.rnn(phi_xt, h_0)
        # Prepend h_0 so that hidden_seq[:-1] holds the state *before* each
        # step.
        hidden_seq = torch.cat([h_0, hidden_seq], dim=0)

        ## Inference
        # The encoder returns samples for y and z, the logits of y and the
        # (mean, log-variance) of z.
        (posterior_sample_y, posterior_sample_z, posterior_logits_y,
         (posterior_mu_z, posterior_logvar_z)) = self.encoder(
             phi_xt, hidden_seq[:-1, :, :], temp, mask)

        # Posterior distributions.
        posterior_dist_z = Normal(posterior_mu_z,
                                  torch.exp(posterior_logvar_z * 0.5))
        posterior_dist_y = Categorical(logits=posterior_logits_y)

        # The z prior is produced by self.prior from the posterior z samples;
        # the y prior is a uniform Categorical.
        prior_mu, prior_logvar = self.prior(posterior_sample_z)
        prior_dist_z = Normal(prior_mu, (prior_logvar * 0.5).exp())
        prior_dist_y = Categorical(
            probs=1. / self.cluster_dim *
            torch.ones(1, BS, self.cluster_dim, device=dev))

        ## Generative part
        # (h_t, z_t, y) is the input of the generative networks.
        concat_hzy = torch.cat(
            [hidden_seq[:-1], posterior_sample_z, posterior_sample_y], dim=-1)
        phi_hzy = self.gen_pre_module(concat_hzy)
        mu_marker, logvar_marker = generate_marker(self, phi_hzy, None)
        time_log_likelihood, mu_time = compute_point_log_likelihood(
            self, phi_hzy, t)
        marker_log_likelihood = compute_marker_log_likelihood(
            self, x, mu_marker, logvar_marker)

        # Only the latent KL is masked here; the cluster KL is summed as-is.
        KL_cluster = kl_divergence(posterior_dist_y, prior_dist_y)
        KL_z = kl_divergence(posterior_dist_z, prior_dist_z).sum(-1) * mask
        KL = KL_cluster.sum() + KL_z.sum()
        # A KL divergence cannot be negative; a negative or NaN value means
        # the distribution parameters are broken. Fail loudly instead of
        # dropping into an interactive debugger, which stalls batch jobs.
        if not bool(KL >= 0):
            raise RuntimeError(
                "KL divergence must be non-negative, got {}".format(
                    KL.item()))

        ## Augmented loss
        # Predict time and marker directly from the cluster logits,
        # broadcast along the leading (time) axis.
        sl, bs = x.size(0), x.size(1)
        expanded_logits = posterior_logits_y.expand(sl, -1, -1)
        aug_layer = self.aug_output_layer(expanded_logits)
        aug_x_mu = self.aug_output_x_mu(aug_layer)
        aug_t_mu = self.aug_output_t_mu(aug_layer)
        aug_t_logvar = self.aug_output_t_logvar(aug_layer)
        # sigma_min keeps the scale strictly above exp(.) alone.
        aug_t_sigma = (aug_t_logvar * 0.5).exp() + self.sigma_min
        aug_time_recon = Normal(aug_t_mu, aug_t_sigma)
        aug_ll = (aug_time_recon.log_prob(
            t[:, :, 0][:, :, None])).sum(dim=-1) * mask
        # Cross-entropy treats x as class indices over marker_dim classes.
        aug_mu_ = aug_x_mu.view(-1, self.marker_dim)
        x_ = x.view(-1)
        aug_ml = -1 * F.cross_entropy(aug_mu_, x_, reduction='none').view(
            sl, bs) * mask
        # The first step is excluded from the augmented loss.
        aug_loss = -1. * (aug_ll[1:, :] + aug_ml[1:, :]).sum()

        metric_dict = {"z_cluster": posterior_logits_y.detach().cpu()}
        with torch.no_grad():
            if self.time_loss == 'intensity':
                # NOTE(review): this passes the raw hidden_seq (h_0 included)
                # rather than phi_hzy, unlike the likelihood call above --
                # confirm this is what compute_time_expectation expects.
                mu_time = compute_time_expectation(self, hidden_seq, t,
                                                   mask)[:, :, None]
            get_marker_metric(self.marker_type, mu_marker, x, mask,
                              metric_dict)
            get_time_metric(mu_time, t, mask, metric_dict)

        return (time_log_likelihood, marker_log_likelihood, KL, metric_dict,
                aug_loss)