예제 #1
0
    def preprocess_latent_states(self, augmented_hidden_seq,
                                 posterior_sample_y, posterior_sample_z, T,
                                 BS):
        """
        Transforms [(h_prime, z_0, y), (h_0, z_1, y), ..., (h_{T-2}, z_{T-1}, y)] to ---> [phi_0, .., phi_{T-1}]

        Each phi_j is self.preprocessing_module applied to the last-axis
        concatenation [h_{j-1}, z_j, y], with h_{-1} := h_prime.

        :param augmented_hidden_seq: (T+1) x BS x h_dim
        [h_prime, h0, ..., h_{T-1}]; the final entry h_{T-1} is not used here.
        :param posterior_sample_y: Ny x T x BS x y_dim
        [y, ..., y]
        :param posterior_sample_z: Nz x Ny x T x BS x latent_dim
        [z_0, ..., z_{T-1}]
        :param T: int, sequence length (unused; kept for interface compatibility)
        :param BS: int, batch size (unused; kept for interface compatibility)
        :return: phi_hzy_seq: Nz x Ny x T x BS x shared_output_dims[-1]
        [phi_0, ..., phi_{T-1}]
        """
        Nz, Ny = posterior_sample_z.shape[0], posterior_sample_y.shape[0]
        # Drop h_{T-1} so that hidden state h_{j-1} lines up with z_j along the T axis.
        past_hidden = prepend_dims_to_tensor(augmented_hidden_seq[:-1], Nz, Ny)
        # posterior_sample_y already carries the Ny axis, so only Nz is prepended.
        expanded_y = prepend_dims_to_tensor(posterior_sample_y, Nz)
        concat_hzy = torch.cat([past_hidden, posterior_sample_z, expanded_y],
                               dim=-1)  # (Nz, Ny, T, BS, (h+z+y)_dims)
        phi_hzy_seq = self.preprocessing_module(
            concat_hzy)  # (Nz, Ny, T, BS, shared_output_dims[-1])
        return phi_hzy_seq
예제 #2
0
    def preprocess_latent_states(self, augmented_hidden_seq,
                                 posterior_sample_y, posterior_sample_z, T,
                                 BS):
        """
        Transforms latent states to a representation which is then used as a hidden state to compute
        densities and likelihoods of time and marker.

        - Specifically, the latent state transformation is divided into 'past latent states' and 'future latent states'.
        'Past latent states' are defined as ancestors of data_j.
        'Future latent states' are defined as descendants of data_j.
        - The decoder transforms:
        x_j <- mlp( mlp(past latent states), mlp(future latent states) )

        For this decoder, the dependence is (for x_j := data_j):
        x_j <- f(h_{j-1}, y, z_j, h_j) where [h_{j-1}, y, z_j] are 'past' and [h_j] is 'future' as per the model.
        meaning x_j <- f( g1(h_{j-1}, y, z_j), g2(h_j) )

        Boundary conditions:
        x_0 = f( g1( h_prime, y, z_0 ), g2(h_0) )
        x_1 = f( g1( h_0, y, z_1 ), g2(h_1) )
        x_{T-1} = f( g1( h_{T-2}, y, z_{T-1} ), g2(h_{T-1}) )

        g1 is self.preprocessing_module_past. g2 is self.preprocessing_module_future and is
        only applied when self.is_smoothing is True; otherwise phi_j = g1(...) alone.

        The job of this function is to return [phi_0, .., phi_{T-1}] so that they can be used to generate
        [data_0, .., data_{T-1}]

        :param augmented_hidden_seq: (T+1) x BS x h_dim
        [h_prime, h0, ..., h_{T-1}]
        :param posterior_sample_y: Ny x T x BS x y_dim
        [y, ..., y]
        :param posterior_sample_z: Nz x Ny x T x BS x latent_dim
        [z_0, ..., z_{T-1}]
        :param T: int, sequence length (unused; kept for interface compatibility)
        :param BS: int, batch size (unused; kept for interface compatibility)
        :return: phi_hzy_seq: Nz x Ny x T x BS x (filtering_out_dim [+ smoothing_out_dim if self.is_smoothing])
        [phi_0, ..., phi_{T-1}]
        """
        Nz, Ny = posterior_sample_z.shape[0], posterior_sample_y.shape[0]
        # (Nz, Ny, T+1, BS, h_dim): [h_prime, h_0, ..., h_{T-1}] for every posterior sample
        expanded_augmented_hidden_seq = prepend_dims_to_tensor(
            augmented_hidden_seq, Nz, Ny)
        # 'Past' input for step j: h_{j-1} (h_{T-1} dropped), z_j and y.
        concat_hzy = torch.cat([
            expanded_augmented_hidden_seq[:, :, :-1], posterior_sample_z,
            prepend_dims_to_tensor(posterior_sample_y, Nz)
        ],
                               dim=-1)  # (Nz, Ny, T, BS, (h+z+y)_dims)
        combined_phi = [
            self.preprocessing_module_past(concat_hzy)
        ]  # (Nz, Ny, T, BS, filtering_out_dim)
        if self.is_smoothing:
            # 'Future' input for step j: h_j (h_prime dropped).
            # (Nz, Ny, T, BS, smoothing_out_dim)
            future_influence = self.preprocessing_module_future(
                expanded_augmented_hidden_seq[:, :, 1:])
            combined_phi.append(future_influence)
        phi_hzy_seq = torch.cat(combined_phi, dim=-1)
        return phi_hzy_seq
예제 #3
0
    def _get_encoded_z(self, sample_y, reverse_hidden_seq, T, BS, Nz):
        """
        Ancestrally sample the continuous latent states z_0, ..., z_{T-1}.

        The conditional independencies imply the following relation for z:
        z_j = f( z_{j-1}, y, a_j )
        where a_j is defined as the first reverse rnn hidden state that has seen x_j

        Boundary conditions:
        z_0 = f(z_prime, y, a_0)
        z_1 = f(z_0, y, a_1)
        z_{T-1} = f( z_{T-2}, y, a_{T-1} )
        where
        z_prime is a special latent state made of only zeros, and not learned
        a_0 is g(h_prime, x_0, a_1) and h_prime is also only zeros, and not learned

        Because z_j depends on z_{j-1}, sampling runs forward in time
        (j = 0, ..., T-1), so the outputs are stored in time order along dim 2.

        :param sample_y: Ny x T x BS x y_dim
        :param reverse_hidden_seq: T x BS x h_dim
        [a_0, ..., a_{T-1}]
        :param T: int, sequence length
        :param BS: int, batch size
        :param Nz: int, number of z samples per y sample
        :return: dist_z_seq, sample_z_seq
        dist_z_seq: per-step z distributions joined along dim 2 by self._concat_z_distributions
        (Nz x Ny x T x BS x z_dim)
        sample_z_seq: Nz x Ny x T x BS x z_dim
        [z_0, ..., z_{T-1}]
        """
        Ny = self.num_posterior_samples
        # Broadcast the conditioning inputs to (Nz, Ny, T, BS, *) once, outside the loop.
        expanded_reverse_hidden = prepend_dims_to_tensor(
            reverse_hidden_seq, Nz, Ny)
        # sample_y already has the Ny axis; only Nz is prepended so every
        # piece fed to torch.cat below is 5-D.
        expanded_y = prepend_dims_to_tensor(sample_y, Nz)
        # z_prime: all-zero, non-learned initial state, on the inputs' device/dtype.
        _prev_z = torch.zeros(Nz, Ny, 1, BS, self.latent_dim,
                              device=reverse_hidden_seq.device,
                              dtype=reverse_hidden_seq.dtype)
        # Ancestral sampling
        _dists_z, _samples_z = [], []
        for time_idx in range(T):
            step = slice(time_idx, time_idx + 1)
            concat_ayz = torch.cat(
                [
                    expanded_reverse_hidden[:, :, step],
                    _prev_z,
                    expanded_y[:, :, step],
                ],
                dim=-1)  # (Nz, Ny, 1, BS, (latent+cluster+hidden_dim))
            _dist_z = self.z_module(concat_ayz)  # (Nz, Ny, 1, BS, z_dim)
            _dists_z.append(_dist_z)
            sample_z_ = _dist_z.rsample()  # (Nz, Ny, 1, BS, z_dim)
            _samples_z.append(sample_z_)
            _prev_z = sample_z_

        dist_z_seq = self._concat_z_distributions(_dists_z, dim=2)
        sample_z_seq = torch.cat(_samples_z, dim=2)  # (Nz, Ny, T, BS, z_dim)

        assert_shape("Z samples", sample_z_seq.shape,
                     (Nz, Ny, T, BS, self.latent_dim))
        return dist_z_seq, sample_z_seq
예제 #4
0
    def compute_metrics(self, predicted_times: torch.Tensor,
                        event_times: torch.Tensor, marker_logits: torch.Tensor,
                        marker: torch.Tensor, mask: torch.Tensor):
        """
        Accumulate un-normalised evaluation statistics for one batch.

        Input:
            :param predicted_times: Nz x Ny x T x BS x 1  or T x BS x 1 ;
            pt_j : predicted time of event j
            :param event_times : T x BS x 1
            t_j : actual time of event j
            :param marker_logits : * x T x BS x marker_dim;
            :param marker : T x BS
            :param mask: T x BS
        Output:
            dict of numpy scalars:
            'time_mse'         : summed masked squared time error over all samples and steps
            'time_mse_count'   : Nz * Ny * (number of unmasked events from step 1 onwards)
            'marker_acc'       : number of correct argmax marker predictions from step 1 onwards
            'marker_acc_count' : same denominator as 'time_mse_count'
        Raises:
            NotImplementedError: when self.marker_type is not "categorical"
        """
        if predicted_times.dim() != 5:
            # No latent-sample axes: add singleton Nz and Ny axes so the code below is uniform.
            predicted_times = prepend_dims_to_tensor(predicted_times, 1, 1)
            marker_logits = prepend_dims_to_tensor(marker_logits, 1, 1)
        n_z, n_y = predicted_times.shape[:2]
        # Step 0 is left out of the denominator: t[0] and pt[0] are both 0 by convention.
        event_count = n_z * n_y * mask[1:, :].sum().detach().cpu().numpy()

        metrics = {}
        with torch.no_grad():
            masked_diff = (predicted_times - event_times) * mask.unsqueeze(-1)
            metrics['time_mse'] = torch.pow(masked_diff, 2.).sum().cpu().numpy()
            metrics['time_mse_count'] = event_count

            if self.marker_type != "categorical":
                raise NotImplementedError
            hits = (marker_logits.argmax(dim=-1) == marker) * mask  # (Nz, Ny, T, BS)
            # Only predictions from the second timestamp onwards are counted.
            metrics['marker_acc'] = hits[:, :, 1:].sum().cpu().numpy()
            metrics['marker_acc_count'] = event_count
        return metrics
예제 #5
0
    def _get_predicted_times(self, phi_hzy_seq, event_times, BS):
        """
        Predicts the next event for each time step using Numerical Integration

        :param phi_hzy_seq: Nz x Ny x T x BS x phi_dim
        [phi_0, .., phi_{T-1}] = f[(h_prime, z_0, y), (h_0, z_1, y), ..., (h_{T-2}, z_{T-1}, y)]
        :param event_times: T x BS x 1
        [t_0, ..., t_{T-1}]
        :param BS: int, Batch size
        :return: predicted_times = Nz x Ny x T x BS x 1
        [t'_0, ..., t'_{T-1}], where t'_0 is fixed to 0
        """
        Nz, Ny = phi_hzy_seq.shape[:2]
        # Previous-event times [t_0, ..., t_{T-2}] for the steps being predicted.
        expanded_event_times = prepend_dims_to_tensor(event_times[:-1], Nz,
                                                      Ny)  # Nz, Ny, T-1, BS, 1
        with torch.no_grad():
            # phi_j (built from h_{j-1}, which has seen t_{j-1}) is paired with t_{j-1}
            # to predict t_j for j = 1..T-1. No prediction is made past the last
            # observed event, and phi_0 is not needed here.
            next_event_times = self.marked_point_process_net.get_next_event_times(
                phi_hzy_seq[:, :, 1:], expanded_event_times)  # (*, T-1, BS, 1)
            # t'_0 is 0 by convention; new_zeros keeps device/dtype consistent with
            # next_event_times so the concatenation below cannot mix devices.
            first_prediction = next_event_times.new_zeros((Nz, Ny, 1, BS, 1))
            predicted_times = torch.cat([first_prediction, next_event_times],
                                        dim=2)
        return predicted_times
예제 #6
0
    def _compute_log_likelihoods(
            self, phi_hzy_seq: torch.Tensor, time_intervals: torch.Tensor,
            marker_seq: torch.Tensor,
            predicted_marker_dist: torch.distributions.Distribution, T: int,
            BS: int):
        """
        Computes the marker and time log likelihoods of the observed data,
        averaged over the posterior samples.

        :param phi_hzy_seq: Nz x Ny x T x BS x shared_output_dims[-1]
        [phi_0, ..., phi_{T-1}]
        :param time_intervals: T x BS x 1
        [i_0, ..., i_{T-1}]
        :param marker_seq: T x BS
        [x_0, ..., x_{T-1}]
        :param predicted_marker_dist: distribution with logits of shape Nz x Ny x T x BS x x_dim
        [fx_0, ..., fx_{T-1}]
        :param T: int, sequence length (used for the output shape check)
        :param BS: int, batch size (used for the output shape check)

        :return marker_log_likelihood: T x BS
        :return time_log_likelihood: T x BS
        (returned in this order: marker first, then time)


        Computing Time Log Likelihood:
        phi_j summarises (h_{j-1}, z_j, y), with h_{-1} := h_prime. phi_hzy_seq and
        time_intervals are passed together along their shared T axis, so phi_j is
        paired with the interval i_j:
        `logf*(t_j) = g(h_{j-1}, z_j, y, i_j) = g(phi_j, i_j)`
        which implies
        (first timestep) `logf*(t_0) = g(h', z_0, y, i_0) := g(phi_0, i_0)`
        (last timestep) `logf*(t_{T-1}) = g(phi_{T-1}, i_{T-1})`

        Boundary Conditions:
        logf*(t0) is the likelihood of the first event but it's not based on past
        information, so we don't use it in likelihood computation (forward function)

        Finally, Expectation is taken wrt posterior samples by taking the mean along first two dimensions

        """
        Nz, Ny = phi_hzy_seq.shape[:2]
        expanded_time_intervals = prepend_dims_to_tensor(
            time_intervals, Nz, Ny)  # (Nz, Ny, T, BS, 1)
        time_log_likelihood = self.marked_point_process_net.get_point_log_density(
            phi_hzy_seq, expanded_time_intervals)
        marker_log_likelihood = self.marked_point_process_net.get_marker_log_prob(
            marker_seq, predicted_marker_dist)
        # Monte-Carlo expectation over the (Nz, Ny) posterior-sample axes.
        time_log_likelihood_expectation = time_log_likelihood.mean((0, 1))
        marker_log_likelihood_expectation = marker_log_likelihood.mean((0, 1))
        assert_shape("time log likelihood",
                     time_log_likelihood_expectation.shape, (T, BS))
        assert_shape("marker log likelihood",
                     marker_log_likelihood_expectation.shape, (T, BS))
        return marker_log_likelihood_expectation, time_log_likelihood_expectation
예제 #7
0
    def _get_encoded_z(self, augmented_hidden_seq, data, sample_y, Nz):
        """Encoder for z - continuous latent state
        Computes z_t <- f(y, h_{t-1}, data_t)

        :param augmented_hidden_seq: (T+1) x BS x h_dim
        [h_prime, h_0, ..., h_{T-1}]; the final entry is dropped so that
        h_{t-1} lines up with data_t.
        :param data: T x BS x embed_dim
        [data_0, ..., data_{T-1}]
        This is embedded data sequence which has information for both marker and time.
        :param sample_y: Ny x T x BS x y_dim
        :param Nz: int, number of z samples drawn per y sample
        :return: (dist_z, sample_z)
        dist_z: the distribution returned by self.z_module, over Ny x T x BS x latent_dim
        sample_z: Nz x Ny x T x BS x latent_dim
        Reparameterised samples [z_0, ..., z_{T-1}] for each batch
        """
        Ny = self.num_posterior_samples
        concat_hxty = torch.cat([
            prepend_dims_to_tensor(augmented_hidden_seq[:-1], Ny),
            prepend_dims_to_tensor(data, Ny), sample_y
        ],
                                dim=-1)  # (Ny, T, BS, (h+embed+y)_dims)
        dist_z = self.z_module(concat_hxty)  # (Ny,T,BS,latent_dim)
        # rsample keeps the samples differentiable; the (Nz,) sample shape is prepended.
        sample_z = dist_z.rsample((Nz, ))  # (Nz, Ny, T,BS,latent_dim)
        return dist_z, sample_z
예제 #8
0
    def _compute_log_likelihoods(
            self, phi_hzy_seq: torch.Tensor, time_intervals: torch.Tensor,
            marker_seq: torch.Tensor,
            predicted_marker_dist: torch.distributions.Distribution, T: int,
            BS: int):
        """
        Computes Marker and Time Log Likelihoods of the observed data:

        # Relationship between ll and (phi, i):
        phi_hzy_seq and time_intervals are passed together along their shared T axis,
        so each phi_j is paired with the interval i_j:
        `logf*(t_j) = g(phi_j, i_j)`

        For this decoder, the dependence is (for t_j := data_j):
        t_j <- f(h_{j-1}, y, z_j, h_j)
        where [h_{j-1}, y, z_j] are 'past' and [h_j] is 'future' as per the model.

        Therefore the general expression for time LL is:
         `logf*(t_j) <- f( g1(h_{j-1}, y, z_j), g2(h_j) ) `

        ## Boundary conditions:
        (first timestep) `logf*(t_0) = f( g1( h_prime, y, z_0 ), g2(h_0) ) `
        (second timestep) `logf*(t_1) = f( g1( h_0, y, z_1 ), g2(h_1) ) `
        (last timestep) `logf*(t_{T-1}) = f( g1( h_{T-2}, y, z_{T-1} ), g2(h_{T-1}) ) `

        ## Filtering and Smoothing:
        When we're in smoothing mode, the likelihood computation has access to information
        about the future (h_j). This can be deactivated by setting self.is_smoothing = False,
        in which case there is no function `g2` in the above computation.

        Finally, Expectation is taken wrt posterior samples by taking the mean along first two dimensions

        :param phi_hzy_seq: Nz x Ny x T x BS x shared_output_dims[-1]
        [phi_0, ..., phi_{T-1}]
        :param time_intervals: T x BS x 1
        [i_0, ..., i_{T-1}]
        :param marker_seq: T x BS
        [x_0, ..., x_{T-1}]
        :param predicted_marker_dist: distribution with logits of shape Nz x Ny x T x BS x x_dim
        [fx_0, ..., fx_{T-1}]
        :param T: int, sequence length (used for the output shape check)
        :param BS: int, batch size (used for the output shape check)

        :return marker_log_likelihood: T x BS
        :return time_log_likelihood: T x BS
        (returned in this order: marker first, then time)
        """
        Nz, Ny = phi_hzy_seq.shape[:2]
        expanded_time_intervals = prepend_dims_to_tensor(
            time_intervals, Nz, Ny)  # (Nz, Ny, T, BS, 1)
        time_log_likelihood = self.marked_point_process_net.get_point_log_density(
            phi_hzy_seq, expanded_time_intervals)
        marker_log_likelihood = self.marked_point_process_net.get_marker_log_prob(
            marker_seq, predicted_marker_dist)
        # Monte-Carlo expectation over the (Nz, Ny) posterior-sample axes.
        time_log_likelihood_expectation = time_log_likelihood.mean((0, 1))
        marker_log_likelihood_expectation = marker_log_likelihood.mean((0, 1))
        assert_shape("time log likelihood",
                     time_log_likelihood_expectation.shape, (T, BS))
        assert_shape("marker log likelihood",
                     marker_log_likelihood_expectation.shape, (T, BS))
        return marker_log_likelihood_expectation, time_log_likelihood_expectation