Example #1
    def sampling_decoder(
        self,
        F,
        past_target: Tensor,
        interval_alpha_bias: Optional[Tensor] = None,
        size_alpha_bias: Optional[Tensor] = None,
        time_remaining: Optional[Tensor] = None,
    ) -> Tensor:
        """
        Sample a trajectory of interval-size pairs from the model given past target.

        Parameters
        ----------
        past_target
            (batch_size, history_length, 2) shaped tensor containing the past time series
            in an interval-size representation.
        time_remaining
            (batch_size, 1) shaped tensor containing the number of time steps that were zero
            before the start of the forecast horizon.

        Returns
        -------
        samples
            samples of shape (batch_size, num_parallel_samples, 2, sequence_length)
        """
        if time_remaining is None:
            time_remaining = F.broadcast_like(
                F.zeros((1,)), past_target, lhs_axes=(0,), rhs_axes=(0,)
            )

        # if there are no valid points in past_target, override time_remaining with zero
        batch_sum = F.sum(past_target, axis=1).sum(-1).expand_dims(-1)
        time_remaining = F.where(
            batch_sum > 0, time_remaining, F.zeros_like(time_remaining)
        )

        repeated_past_target = past_target.repeat(
            repeats=self.num_parallel_samples, axis=0
        )  # (N * samples, T, 2)
        repeated_time_remaining = time_remaining.repeat(
            repeats=self.num_parallel_samples, axis=0
        )  # (N * samples, 1)

        interval_samples = []
        size_samples = []

        for t in range(self.prediction_length):
            cond_mean = self.mu_map(repeated_past_target, shift=False)
            cond_mean_last = F.slice_axis(
                cond_mean, axis=1, begin=-1, end=None
            ).squeeze(1)

            dist_interval, dist_size = self.distribution(
                cond_mean_last, interval_alpha_bias, size_alpha_bias
            )

            # initial samples for interval should be taken conditionally
            # we achieve this via (leaky) rejection sampling
            if t == 0:
                interval_sample = F.zeros_like(repeated_time_remaining)
                for _ in range(50):
                    interval_sample = F.where(
                        interval_sample > repeated_time_remaining,
                        interval_sample,
                        dist_interval.sample() + 1,
                    )
                interval_sample = F.where(
                    interval_sample == 0,
                    repeated_time_remaining + 1,
                    interval_sample,
                )
            else:
                interval_sample = dist_interval.sample() + 1
            size_sample = dist_size.sample() + 1

            interval_samples.append(interval_sample)
            size_samples.append(size_sample)

            repeated_past_target = F.concat(
                repeated_past_target,
                F.concat(interval_sample, size_sample, dim=-1).expand_dims(1),
                dim=1,
            )

        interval_samples[0] = interval_samples[0] - repeated_time_remaining
        samples = F.concat(
            *[
                F.concat(x, y, dim=-1).expand_dims(1)
                for x, y in zip(interval_samples, size_samples)
            ],
            dim=1,
        )

        return samples.reshape(
            shape=(-1, self.num_parallel_samples) + samples.shape[1:]
        ).swapaxes(2, 3)
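
The conditional draw of the first interval is the least obvious step in this decoder: the first waiting time has to be consistent with the time_remaining steps that have already passed without an event. Below is a minimal NumPy sketch of the same (leaky) rejection-sampling scheme, detached from MXNet; sample_fn, time_remaining and max_tries are illustrative names, not part of the original API.

import numpy as np

def sample_first_interval(sample_fn, time_remaining, max_tries=50):
    # Start from zero so that every entry is redrawn on the first pass.
    interval = np.zeros_like(time_remaining)
    for _ in range(max_tries):
        # Keep entries that already exceed time_remaining, redraw the rest.
        interval = np.where(interval > time_remaining, interval, sample_fn() + 1)
    # "Leaky": after max_tries the last draw stands even if it did not clear
    # the threshold; only entries still at zero are forced just past it.
    return np.where(interval == 0, time_remaining + 1, interval)

rng = np.random.default_rng(0)
time_remaining = np.array([[3.0], [0.0]])
draw = lambda: rng.geometric(p=0.4, size=time_remaining.shape).astype(float)
print(sample_first_interval(draw, time_remaining))  # one conditional draw per series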
Example #2
    def sampling_decoder(
        self,
        F,
        static_feat: Tensor,
        past_target: Tensor,
        time_feat: Tensor,
        scale: Tensor,
        enc_out: Tensor,
    ) -> Tensor:
        """
        Computes sample paths by unrolling the decoder one step at a time,
        starting from the encoder output and the target history.

        Parameters
        ----------
        static_feat : Tensor
            static features. Shape: (batch_size, num_static_features).
        past_target : Tensor
            target history. Shape: (batch_size, history_length, 1).
        time_feat : Tensor
            time features. Shape:
            (batch_size, prediction_length, num_time_features).
        scale : Tensor
            tensor containing the scale of each element in the batch.
            Shape: (batch_size, ).
        enc_out: Tensor
            output of the encoder. Shape: (batch_size, num_cells)

        Returns
        -------
        sample_paths : Tensor
            a tensor containing sampled paths.
            Shape: (batch_size, num_sample_paths, prediction_length).
        """

        # repeat each input along the batch axis num_parallel_samples times
        # so that all sample paths can be drawn in parallel
        repeated_past_target = past_target.repeat(
            repeats=self.num_parallel_samples, axis=0)
        repeated_time_feat = time_feat.repeat(
            repeats=self.num_parallel_samples, axis=0)
        repeated_static_feat = static_feat.repeat(
            repeats=self.num_parallel_samples, axis=0).expand_dims(axis=1)
        repeated_enc_out = enc_out.repeat(repeats=self.num_parallel_samples,
                                          axis=0).expand_dims(axis=1)
        repeated_scale = scale.repeat(repeats=self.num_parallel_samples,
                                      axis=0)

        future_samples = []

        # for each future time step, draw new samples and append them to the
        # target history
        for k in range(self.prediction_length):
            lags = self.get_lagged_subsequences(
                F=F,
                sequence=repeated_past_target,
                sequence_length=self.history_length + k,
                indices=self.shifted_lags,
                subsequences_length=1,
            )

            # (batch_size * num_samples, 1, *target_shape, num_lags)
            lags_scaled = F.broadcast_div(lags,
                                          repeated_scale.expand_dims(axis=-1))

            # from (batch_size * num_samples, 1, *target_shape, num_lags)
            # to (batch_size * num_samples, 1, prod(target_shape) * num_lags)
            input_lags = F.reshape(
                data=lags_scaled,
                shape=(-1, 1, prod(self.target_shape) * len(self.lags_seq)),
            )

            # (batch_size * num_samples, 1, prod(target_shape) * num_lags +
            # num_time_features + num_static_features)
            dec_input = F.concat(
                input_lags,
                repeated_time_feat.slice_axis(axis=1, begin=k, end=k + 1),
                repeated_static_feat,
                dim=-1,
            )

            dec_output = self.decoder(dec_input, repeated_enc_out, None, False)

            distr_args = self.proj_dist_args(dec_output)

            # construct the predictive distribution from the projected parameters
            distr = self.distr_output.distribution(distr_args,
                                                   scale=repeated_scale)

            # (batch_size * num_samples, 1, *target_shape)
            new_samples = distr.sample()

            # (batch_size * num_samples, seq_len, *target_shape)
            repeated_past_target = F.concat(repeated_past_target,
                                            new_samples,
                                            dim=1)
            future_samples.append(new_samples)

        # reset cache of the decoder
        self.decoder.cache_reset()

        # (batch_size * num_samples, prediction_length, *target_shape)
        samples = F.concat(*future_samples, dim=1)

        # (batch_size, num_samples, *target_shape, prediction_length)
        return samples.reshape(shape=((-1, self.num_parallel_samples) +
                                      self.target_shape +
                                      (self.prediction_length, )))
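
At every decoding step the input is assembled from lagged values of the growing target history. The following NumPy sketch shows that gather for a single step, assuming lag offsets counted back from the most recent entry (offset 0 being the value just appended), which is what the shifted_lags convention amounts to here; lagged_values_last_step is an illustrative helper, not a GluonTS function.

import numpy as np

def lagged_values_last_step(sequence, lag_offsets):
    # sequence: (batch, length); lag_offsets are counted back from the most
    # recent entry, with offset 0 meaning the last value itself.
    last = sequence.shape[1] - 1
    cols = [sequence[:, last - off : last - off + 1] for off in lag_offsets]
    # (batch, 1, num_lags): one decoding step, one column per lag
    return np.stack(cols, axis=-1)

history = np.arange(12.0).reshape(2, 6)              # two series, six steps each
print(lagged_values_last_step(history, [0, 1, 4]))   # shape (2, 1, 3)

After this gather, the decoder above divides the lags by repeated_scale and flattens them into a single feature vector per step before concatenating the time and static features.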
Example #3
    def sampling_decoder(
        self,
        F,
        static_feat: Tensor,
        past_target: Tensor,
        time_feat: Tensor,
        scale: Tensor,
        begin_states: List,
    ) -> Tensor:
        """
        Computes sample paths by unrolling the LSTM one step at a time,
        starting from an initial input and state.

        Parameters
        ----------
        static_feat : Tensor
            static features. Shape: (batch_size, num_static_features).
        past_target : Tensor
            target history. Shape: (batch_size, history_length).
        time_feat : Tensor
            time features. Shape: (batch_size, prediction_length, num_time_features).
        scale : Tensor
            tensor containing the scale of each element in the batch. Shape: (batch_size, 1, 1).
        begin_states : List
            list of initial states for the LSTM layers.
            The shape of each tensor in the list should be (batch_size, num_cells).

        Returns
        -------
        Tensor
            A tensor containing sampled paths.
            Shape: (batch_size, num_sample_paths, prediction_length).
        """

        # repeat each input along the batch axis num_parallel_samples times
        # so that all sample paths can be drawn in parallel
        repeated_past_target = past_target.repeat(
            repeats=self.num_parallel_samples, axis=0)
        repeated_time_feat = time_feat.repeat(
            repeats=self.num_parallel_samples, axis=0)
        repeated_static_feat = static_feat.repeat(
            repeats=self.num_parallel_samples, axis=0).expand_dims(axis=1)
        repeated_scale = scale.repeat(repeats=self.num_parallel_samples,
                                      axis=0)
        repeated_states = [
            s.repeat(repeats=self.num_parallel_samples, axis=0)
            for s in begin_states
        ]

        future_samples = []

        # for each future time step, draw new samples and update the LSTM state
        for k in range(self.prediction_length):
            # (batch_size * num_samples, 1, *target_shape, num_lags)
            lags = self.get_lagged_subsequences(
                F=F,
                sequence=repeated_past_target,
                sequence_length=self.history_length + k,
                indices=self.shifted_lags,
                subsequences_length=1,
            )

            # (batch_size * num_samples, 1, *target_shape, num_lags)
            lags_scaled = F.broadcast_div(lags,
                                          repeated_scale.expand_dims(axis=-1))

            # from (batch_size * num_samples, 1, *target_shape, num_lags)
            # to (batch_size * num_samples, 1, prod(target_shape) * num_lags)
            input_lags = F.reshape(
                data=lags_scaled,
                shape=(-1, 1, prod(self.target_shape) * len(self.lags_seq)),
            )

            # (batch_size * num_samples, 1, prod(target_shape) * num_lags + num_time_features + num_static_features)
            decoder_input = F.concat(
                input_lags,
                repeated_time_feat.slice_axis(axis=1, begin=k, end=k + 1),
                # observed_values.expand_dims(axis=1),
                repeated_static_feat,
                dim=-1,
            )

            # output shape: (batch_size * num_samples, 1, num_cells)
            # state shape: (batch_size * num_samples, num_cells)
            rnn_outputs, repeated_states = self.rnn.unroll(
                inputs=decoder_input,
                length=1,
                begin_state=repeated_states,
                layout="NTC",
                merge_outputs=True,
            )

            distr_args = self.proj_distr_args(rnn_outputs)

            # construct the predictive distribution from the projected parameters
            distr = self.distr_output.distribution(distr_args,
                                                   scale=repeated_scale)

            # (batch_size * num_samples, 1, *target_shape)
            new_samples = distr.sample(dtype=self.dtype)

            # (batch_size * num_samples, seq_len, *target_shape)
            repeated_past_target = F.concat(repeated_past_target,
                                            new_samples,
                                            dim=1)

            future_samples.append(new_samples)

        # (batch_size * num_samples, prediction_length, *target_shape)
        samples = F.concat(*future_samples, dim=1)

        # (batch_size, num_samples, prediction_length, *target_shape)
        return samples.reshape(shape=((-1, self.num_parallel_samples) +
                                      (self.prediction_length, ) +
                                      self.target_shape))
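
All three decoders follow the same skeleton: repeat every input num_parallel_samples times along the batch axis, draw one step at a time while appending each sample to the history, and finally fold the sample axis back out of the batch dimension. Below is a minimal NumPy sketch of that pattern for a univariate target; step_sampler stands in for the RNN (or decoder) plus distribution head and is an illustrative callable, not a GluonTS API.

import numpy as np

def sample_paths(past_target, step_sampler, prediction_length, num_samples):
    # (batch * num_samples, history_length): each series is repeated so that
    # all of its sample paths are drawn in a single batched pass
    history = np.repeat(past_target, num_samples, axis=0)
    draws = []
    for _ in range(prediction_length):
        new = step_sampler(history)              # (batch * num_samples, 1)
        history = np.concatenate([history, new], axis=1)
        draws.append(new)
    samples = np.concatenate(draws, axis=1)      # (batch * num_samples, prediction_length)
    return samples.reshape(-1, num_samples, prediction_length)

rng = np.random.default_rng(0)
paths = sample_paths(
    past_target=np.ones((4, 10)),
    step_sampler=lambda h: h[:, -1:] + rng.normal(size=(h.shape[0], 1)),
    prediction_length=3,
    num_samples=5,
)
print(paths.shape)  # (4, 5, 3) = (batch_size, num_sample_paths, prediction_length)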