Example #1
    def hybrid_forward(
        self,
        F,
        past_target: Tensor,
        past_valid_length: Tensor,
    ) -> Tuple[Tensor, Tensor]:
        """
        Draw forward samples from the model. At each step, we sample an
        inter-event time and feed it into the RNN to obtain the parameters for
        the next distribution over the inter-event time.

        Parameters
        ----------
        F
            MXNet backend.
        past_target
            Tensor with past observations.
            Shape: (batch_size, context_length, target_dim). Has to comply
            with :code:`self.context_interval_length`.
        past_valid_length
            The `valid_length` or number of valid entries in the past_target
            Tensor. Shape: (batch_size,)

        Returns
        -------
        sampled_target: Tensor
            Predicted inter-event times and marks.
            Shape: (samples, batch_size, max_prediction_length, target_dim).
        sampled_valid_length: Tensor
            The number of valid entries in the time axis of each sample.
            Shape: (samples, batch_size).
        """
        # Variable-length generation (the `while t < t_max` loop below) depends
        # on array values, so it cannot be expressed symbolically.
        if F is mx.sym:
            raise ValueError(
                "The DeepTPP model currently doesn't support hybridization.")

        assert (
            past_target.shape[-1] == 2
        ), "TPP data should have two target_dim, interarrival times and marks"

        batch_size = past_target.shape[0]

        # condition the prediction network on the past events
        past_ia_times, past_marks = F.split(past_target,
                                            num_outputs=2,
                                            axis=-1)
        past_valid_length = past_valid_length.reshape(-1).astype(
            past_ia_times.dtype)

        if self.apply_log_to_rnn_inputs:
            past_ia_times_input = past_ia_times.clip(1e-8, np.inf).log()
        else:
            past_ia_times_input = past_ia_times
        rnn_input = F.concat(
            past_ia_times_input,
            self.embedding(past_marks.squeeze(axis=-1)),
            dim=-1,
        )
        rnn_output = self.rnn(rnn_input)  # (N, T, H)
        rnn_init_state = F.zeros([batch_size, 1, self.rnn_hidden_size])

        past_history_emb = F.concat(rnn_init_state, rnn_output,
                                    dim=1)  # (N, T + 1, H)

        # Select the history embedding after the last event in the past
        indices = F.stack(F.arange(batch_size), past_valid_length)
        history_emb = F.gather_nd(past_history_emb, indices)  # (N, H)

        num_total_samples = self.num_parallel_samples * batch_size
        history_emb = history_emb.expand_dims(0).repeat(
            self.num_parallel_samples, axis=0)  # (S, N, H)
        history_emb = history_emb.reshape(
            [num_total_samples, self.rnn_hidden_size])  # (S * N, H)

        sampled_ia_times_list: List[nd.NDArray] = []
        sampled_marks_list: List[nd.NDArray] = []
        arrival_times = F.zeros([num_total_samples])

        # Time from the last observed event until the past interval end
        past_time_elapsed = past_ia_times.squeeze(axis=-1).sum(-1)
        past_time_remaining = self.interval_length - past_time_elapsed  # (N)
        past_time_remaining_repeat = (
            past_time_remaining.expand_dims(0).repeat(
                self.num_parallel_samples,
                axis=0).reshape([num_total_samples]))  # (S * N)

        first_step = True
        while F.sum(arrival_times < self.prediction_interval_length) > 0:
            # Sample the next inter-arrival time
            time_distr_args = self.time_distr_args_proj(history_emb)
            time_distr = self.time_distr_output.distribution(
                time_distr_args,
                scale=self.output_scale,
            )
            if first_step:
                # Time from the last event until the next event
                next_ia_times = time_distr.sample(
                    lower_bound=past_time_remaining_repeat)
                # Time from the prediction interval start until the next event
                clipped_ia_times = next_ia_times - past_time_remaining_repeat
                sampled_ia_times_list.append(clipped_ia_times)
                arrival_times = arrival_times + clipped_ia_times
                first_step = False
            else:
                next_ia_times = time_distr.sample()
                sampled_ia_times_list.append(next_ia_times)
                arrival_times = arrival_times + next_ia_times

            # Sample the next marks
            if self.num_marks > 1:
                mark_distr_args = self.mark_distr_args_proj(history_emb)
                next_marks = self.mark_distr_output.distribution(
                    mark_distr_args).sample()
            else:
                next_marks = F.zeros([num_total_samples])

            sampled_marks_list.append(next_marks)

            # Pass the generated ia_times & marks into the RNN to obtain
            # the next history embedding
            if self.apply_log_to_rnn_inputs:
                next_ia_times_input = next_ia_times.clip(1e-8, np.inf).log()
            else:
                next_ia_times_input = next_ia_times
            rnn_input = F.concat(
                next_ia_times_input.expand_dims(-1),
                self.embedding(next_marks),
                dim=-1,
            ).expand_dims(1)

            history_emb = self.rnn(rnn_input).squeeze(axis=1)  # (S * N, H)

        sampled_ia_times = F.stack(*sampled_ia_times_list, axis=-1)
        sampled_marks = F.stack(*sampled_marks_list, axis=-1).astype("float32")

        sampled_valid_length = F.sum(
            F.cumsum(sampled_ia_times, axis=1) <
            self.prediction_interval_length,
            axis=-1,
        )

        def _mask(x, sequence_length):
            return F.SequenceMask(
                data=x,
                sequence_length=sequence_length,
                axis=1,
                use_sequence_length=True,
            )

        sampled_ia_times = _mask(sampled_ia_times, sampled_valid_length)
        sampled_marks = _mask(sampled_marks, sampled_valid_length)

        sampled_ia_times = sampled_ia_times.reshape(
            [self.num_parallel_samples, batch_size, -1])
        sampled_marks = sampled_marks.reshape(
            [self.num_parallel_samples, batch_size, -1])
        sampled_valid_length = sampled_valid_length.reshape(
            [self.num_parallel_samples, batch_size])
        sampled_target = F.stack(sampled_ia_times, sampled_marks, axis=-1)
        return sampled_target, sampled_valid_length
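
The conditioning step above selects, for each batch element, the history embedding that sits right after its last valid past event by passing a stacked (batch index, time index) tensor to F.gather_nd. The following is a minimal, standalone sketch of that indexing trick in the NDArray API; the tensor names, sizes, and values are made up for illustration.

    import mxnet as mx
    from mxnet import nd

    batch_size, max_len, hidden = 3, 5, 4

    # Fake history embeddings, one row per time step, with the initial
    # zero state at position 0. Shape: (N, T + 1, H).
    past_history_emb = nd.arange(batch_size * (max_len + 1) * hidden).reshape(
        (batch_size, max_len + 1, hidden))

    # Number of valid past events per batch element.
    past_valid_length = nd.array([2, 5, 0])

    # Stack (batch index, time index) pairs into a (2, N) index tensor and
    # gather one (H,)-vector per element: past_history_emb[i, valid_length[i]].
    indices = nd.stack(nd.arange(batch_size), past_valid_length)
    history_emb = nd.gather_nd(past_history_emb, indices)  # (N, H)

    print(history_emb.shape)  # (3, 4)

The training network in Example #2 uses the inverse operation, F.scatter_nd, with the same stacked indices to write the time remaining until the interval end into the slot right after each sequence's last valid event.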
Example #2
    def hybrid_forward(
        self,
        F,
        target: Tensor,
        valid_length: Tensor,
        **kwargs,
    ) -> Tensor:
        """
        Computes the negative log likelihood loss for the given sequences.

        As the model is trained on past/context (resp. future/prediction)
        "intervals" rather than on fixed-length "sequences", the number of
        data points available varies across observations. To account for
        this, the data is passed to the training network as a "ragged"
        tensor, and the number of valid entries in each sequence is provided
        in a separate variable, :code:`xxx_valid_length`.

        Parameters
        ----------
        F
            MXNet backend.
        target
            Tensor with observations.
            Shape: (batch_size, past_max_sequence_length, target_dim).
        valid_length
            The `valid_length` or number of valid entries in the target
            Tensor. Shape: (batch_size,)

        Returns
        -------
        Tensor
            Loss tensor. Shape: (batch_size,).
        """
        if F is mx.sym:
            raise ValueError(
                "The DeepTPP model currently doesn't support hybridization.")

        batch_size = target.shape[0]
        # IMPORTANT: We add an additional zero at the end of each sequence
        # It will be used to store the time until the end of the interval
        target = F.concat(target, F.zeros((batch_size, 1, 2)), dim=1)
        # (N, T + 1, 2)

        ia_times, marks = F.split(target, num_outputs=2,
                                  axis=-1)  # inter-arrival times, marks
        marks = marks.squeeze(axis=-1)  # (N, T + 1)

        valid_length = valid_length.reshape(-1).astype(
            ia_times.dtype)  # make sure shape is (batch_size,)

        if self.apply_log_to_rnn_inputs:
            ia_times_input = ia_times.clip(1e-8, np.inf).log()
        else:
            ia_times_input = ia_times
        rnn_input = F.concat(ia_times_input, self.embedding(marks), dim=-1)
        rnn_output = self.rnn(rnn_input)  # (N, T + 1, H)

        rnn_init_state = F.zeros([batch_size, 1, self.rnn_hidden_size])
        history_emb = F.slice_axis(
            F.concat(rnn_init_state, rnn_output, dim=1),
            axis=1,
            begin=0,
            end=-1,
        )  # (N, T + 1, H)

        # Augment ia_times by adding the time remaining until interval_length
        # Afterwards, each row of ia_times will sum up to interval_length
        ia_times = ia_times.squeeze(axis=-1)  # (N, T + 1)
        time_remaining = self.interval_length - ia_times.sum(-1)  # (N)
        # Equivalent to ia_times[F.arange(N), valid_length] = time_remaining
        indices = F.stack(F.arange(batch_size), valid_length)
        time_remaining_tensor = F.scatter_nd(time_remaining, indices,
                                             ia_times.shape)
        ia_times_aug = ia_times + time_remaining_tensor

        time_distr_args = self.time_distr_args_proj(history_emb)
        time_distr = self.time_distr_output.distribution(
            time_distr_args, scale=self.output_scale)
        log_intensity = time_distr.log_intensity(ia_times_aug)  # (N, T + 1)
        log_survival = time_distr.log_survival(ia_times_aug)  # (N, T + 1)

        if self.num_marks > 1:
            mark_distr_args = self.mark_distr_args_proj(history_emb)
            mark_distr = self.mark_distr_output.distribution(mark_distr_args)
            log_intensity = log_intensity + mark_distr.log_prob(marks)

        def _mask(x, sequence_length):
            return F.SequenceMask(
                data=x,
                sequence_length=sequence_length,
                axis=1,
                use_sequence_length=True,
            )

        log_likelihood = F.sum(
            (_mask(log_intensity, valid_length) +
             _mask(log_survival, valid_length + 1)),
            axis=-1,
        )  # (N)

        return -log_likelihood
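
The masked sum at the end is the standard TPP log-likelihood of an interval: the sum of the log intensities at the observed events, plus the log survival of the time from the last event until the end of the interval, which lives in the extra slot appended at the top of the method and filled via scatter_nd. That is why the intensity terms are masked to valid_length entries while the survival terms are masked to valid_length + 1. A small sketch of that masking, with made-up numbers:

    import mxnet as mx
    from mxnet import nd

    # Toy per-step terms for a batch of 2 sequences padded to length T + 1 = 4.
    log_intensity = nd.ones((2, 4))
    log_survival = nd.ones((2, 4)) * 0.5
    valid_length = nd.array([2.0, 3.0])  # observed events per sequence

    def _mask(x, sequence_length):
        # Zero out entries beyond `sequence_length` along the time axis.
        return nd.SequenceMask(
            data=x,
            sequence_length=sequence_length,
            axis=1,
            use_sequence_length=True,
        )

    # Intensity terms count only the observed events; survival terms also
    # include the appended slot holding the time until the interval end.
    log_likelihood = nd.sum(
        _mask(log_intensity, valid_length)
        + _mask(log_survival, valid_length + 1),
        axis=-1,
    )
    print(log_likelihood.asnumpy())  # [3.5, 5.0] = [2 + 1.5, 3 + 2]

The method then returns the negative of this per-sequence quantity as the loss, shape (batch_size,).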