def sampling_decoder(
    self,
    F,
    past_target: Tensor,
    interval_alpha_bias: Optional[Tensor] = None,
    size_alpha_bias: Optional[Tensor] = None,
    time_remaining: Optional[Tensor] = None,
) -> Tensor:
    """
    Sample a trajectory of interval-size pairs from the model, given the
    past target.

    Parameters
    ----------
    past_target
        (batch_size, history_length, 2) shaped tensor containing the past
        time series in an interval-size representation.
    interval_alpha_bias
        optional bias tensor forwarded to ``self.distribution`` when
        constructing the interval distribution.
    size_alpha_bias
        optional bias tensor forwarded to ``self.distribution`` when
        constructing the size distribution.
    time_remaining
        (batch_size, 1) shaped tensor containing the number of time steps
        that were zero before the start of the forecast horizon.

    Returns
    -------
    samples
        Samples of shape (batch_size, num_parallel_samples, 2,
        sequence_length).
    """
    if time_remaining is None:
        time_remaining = F.broadcast_like(
            F.zeros((1,)), past_target, lhs_axes=(0,), rhs_axes=(0,)
        )

    # if there are no valid points in past_target, override time remaining
    # with zero
    batch_sum = F.sum(past_target, axis=1).sum(-1).expand_dims(-1)
    time_remaining = F.where(
        batch_sum > 0, time_remaining, F.zeros_like(time_remaining)
    )

    repeated_past_target = past_target.repeat(
        repeats=self.num_parallel_samples, axis=0
    )  # (N * samples, T, 2)
    repeated_time_remaining = time_remaining.repeat(
        repeats=self.num_parallel_samples, axis=0
    )  # (N * samples, 1)

    interval_samples = []
    size_samples = []

    for t in range(self.prediction_length):
        cond_mean = self.mu_map(repeated_past_target, shift=False)
        cond_mean_last = F.slice_axis(
            cond_mean, axis=1, begin=-1, end=None
        ).squeeze(1)

        dist_interval, dist_size = self.distribution(
            cond_mean_last, interval_alpha_bias, size_alpha_bias
        )

        # the initial interval sample must be drawn conditional on the time
        # already elapsed since the last nonzero observation; we achieve
        # this via (leaky) rejection sampling
        if t == 0:
            interval_sample = F.zeros_like(repeated_time_remaining)
            for _ in range(50):
                interval_sample = F.where(
                    interval_sample > repeated_time_remaining,
                    interval_sample,
                    dist_interval.sample() + 1,
                )
            # fallback: replace any zero draw with the smallest interval
            # consistent with time_remaining
            interval_sample = F.where(
                interval_sample == 0,
                repeated_time_remaining + 1,
                interval_sample,
            )
        else:
            interval_sample = dist_interval.sample() + 1
        size_sample = dist_size.sample() + 1

        interval_samples.append(interval_sample)
        size_samples.append(size_sample)

        repeated_past_target = F.concat(
            repeated_past_target,
            F.concat(interval_sample, size_sample, dim=-1).expand_dims(1),
            dim=1,
        )

    # re-express the first interval relative to the forecast start
    interval_samples[0] = interval_samples[0] - repeated_time_remaining

    samples = F.concat(
        *[
            F.concat(x, y, dim=-1).expand_dims(1)
            for x, y in zip(interval_samples, size_samples)
        ],
        dim=1,
    )

    return samples.reshape(
        shape=(-1, self.num_parallel_samples) + samples.shape[1:]
    ).swapaxes(2, 3)
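# A minimal, self-contained numpy sketch of the "leaky" rejection sampling
# used for the first interval above. It is illustrative only: the geometric
# proposal stands in for `dist_interval.sample() + 1`, and the function name
# and fixed seed are assumptions, not part of the model.
import numpy as np


def leaky_rejection_sample(
    time_remaining: np.ndarray, rounds: int = 50
) -> np.ndarray:
    rng = np.random.default_rng(0)
    sample = np.zeros_like(time_remaining)
    for _ in range(rounds):
        # stand-in for `dist_interval.sample() + 1`
        proposal = rng.geometric(p=0.3, size=time_remaining.shape)
        # keep a draw once it clears the constraint, else try again;
        # "leaky" because the last proposal survives even when it fails
        sample = np.where(sample > time_remaining, sample, proposal)
    # defensive guard mirroring the model's final F.where: a zero draw
    # falls back to the smallest interval consistent with time_remaining
    return np.where(sample == 0, time_remaining + 1, sample)


# note: with a low-rate proposal the constraint may still leak for large
# values of time_remaining
print(leaky_rejection_sample(np.array([[0], [5], [12]])))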
def sampling_decoder(
    self,
    F,
    static_feat: Tensor,
    past_target: Tensor,
    time_feat: Tensor,
    scale: Tensor,
    enc_out: Tensor,
) -> Tensor:
    """
    Computes sample paths by autoregressively unrolling the decoder one
    step at a time, conditioned on the encoder output.

    Parameters
    ----------
    static_feat : Tensor
        static features. Shape: (batch_size, num_static_features).
    past_target : Tensor
        target history. Shape: (batch_size, history_length, 1).
    time_feat : Tensor
        time features. Shape: (batch_size, prediction_length, num_time_features).
    scale : Tensor
        tensor containing the scale of each element in the batch.
        Shape: (batch_size, ).
    enc_out : Tensor
        output of the encoder. Shape: (batch_size, num_cells).

    Returns
    -------
    sample_paths : Tensor
        a tensor containing sampled paths.
        Shape: (batch_size, num_sample_paths, prediction_length).
    """
    # blow up the batch dimension of each tensor to
    # batch_size * self.num_parallel_samples, so that all sample paths
    # are drawn in parallel
    repeated_past_target = past_target.repeat(
        repeats=self.num_parallel_samples, axis=0
    )
    repeated_time_feat = time_feat.repeat(
        repeats=self.num_parallel_samples, axis=0
    )
    repeated_static_feat = static_feat.repeat(
        repeats=self.num_parallel_samples, axis=0
    ).expand_dims(axis=1)
    repeated_enc_out = enc_out.repeat(
        repeats=self.num_parallel_samples, axis=0
    ).expand_dims(axis=1)
    repeated_scale = scale.repeat(repeats=self.num_parallel_samples, axis=0)

    future_samples = []

    # for each future time step, draw new samples and feed them back as
    # inputs for the next step
    for k in range(self.prediction_length):
        # (batch_size * num_samples, 1, *target_shape, num_lags)
        lags = self.get_lagged_subsequences(
            F=F,
            sequence=repeated_past_target,
            sequence_length=self.history_length + k,
            indices=self.shifted_lags,
            subsequences_length=1,
        )

        lags_scaled = F.broadcast_div(
            lags, repeated_scale.expand_dims(axis=-1)
        )

        # from (batch_size * num_samples, 1, *target_shape, num_lags)
        # to (batch_size * num_samples, 1, prod(target_shape) * num_lags)
        input_lags = F.reshape(
            data=lags_scaled,
            shape=(-1, 1, prod(self.target_shape) * len(self.lags_seq)),
        )

        # (batch_size * num_samples, 1,
        #  prod(target_shape) * num_lags + num_time_features + num_static_features)
        dec_input = F.concat(
            input_lags,
            repeated_time_feat.slice_axis(axis=1, begin=k, end=k + 1),
            repeated_static_feat,
            dim=-1,
        )

        dec_output = self.decoder(dec_input, repeated_enc_out, None, False)

        distr_args = self.proj_dist_args(dec_output)

        # parametrize the output distribution with the projected parameters
        distr = self.distr_output.distribution(distr_args, scale=repeated_scale)

        # (batch_size * num_samples, 1, *target_shape)
        new_samples = distr.sample()

        # (batch_size * num_samples, seq_len, *target_shape)
        repeated_past_target = F.concat(
            repeated_past_target, new_samples, dim=1
        )
        future_samples.append(new_samples)

    # reset the cache of the decoder
    self.decoder.cache_reset()

    # (batch_size * num_samples, prediction_length, *target_shape)
    samples = F.concat(*future_samples, dim=1)

    # (batch_size, num_samples, *target_shape, prediction_length)
    return samples.reshape(
        shape=(
            (-1, self.num_parallel_samples)
            + self.target_shape
            + (self.prediction_length,)
        )
    )
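# A small numpy sketch (not model code) of the batch "blow-up" both decoders
# above rely on: each series is repeated num_parallel_samples times along
# axis 0, decoded once, and the result is folded back into a
# (batch, num_samples, ...) layout. The toy values are assumptions.
import numpy as np

batch_size, num_parallel_samples, pred_len = 2, 4, 3
past = np.array([[10.0], [20.0]])  # (batch, 1) toy histories
# np.repeat keeps the copies of each series contiguous, matching
# Tensor.repeat(..., axis=0); this is what lets the final reshape group
# sample paths by series
repeated = np.repeat(past, num_parallel_samples, axis=0)  # (batch * samples, 1)
fake_paths = np.tile(repeated, (1, pred_len))  # stand-in for the decoder loop
paths = fake_paths.reshape(batch_size, num_parallel_samples, pred_len)
assert (paths[0] == 10.0).all() and (paths[1] == 20.0).all()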
def sampling_decoder(
    self,
    F,
    static_feat: Tensor,
    past_target: Tensor,
    time_feat: Tensor,
    scale: Tensor,
    begin_states: List,
) -> Tensor:
    """
    Computes sample paths by unrolling the LSTM, starting with an initial
    input and state.

    Parameters
    ----------
    static_feat : Tensor
        static features. Shape: (batch_size, num_static_features).
    past_target : Tensor
        target history. Shape: (batch_size, history_length).
    time_feat : Tensor
        time features. Shape: (batch_size, prediction_length, num_time_features).
    scale : Tensor
        tensor containing the scale of each element in the batch.
        Shape: (batch_size, 1, 1).
    begin_states : List
        list of initial states for the LSTM layers. The shape of each
        tensor of the list should be (batch_size, num_cells).

    Returns
    -------
    Tensor
        A tensor containing sampled paths.
        Shape: (batch_size, num_sample_paths, prediction_length).
    """
    # blow up the batch dimension of each tensor to
    # batch_size * self.num_parallel_samples, so that all sample paths
    # are drawn in parallel
    repeated_past_target = past_target.repeat(
        repeats=self.num_parallel_samples, axis=0
    )
    repeated_time_feat = time_feat.repeat(
        repeats=self.num_parallel_samples, axis=0
    )
    repeated_static_feat = static_feat.repeat(
        repeats=self.num_parallel_samples, axis=0
    ).expand_dims(axis=1)
    repeated_scale = scale.repeat(repeats=self.num_parallel_samples, axis=0)
    repeated_states = [
        s.repeat(repeats=self.num_parallel_samples, axis=0)
        for s in begin_states
    ]

    future_samples = []

    # for each future time step, draw new samples and update the state
    for k in range(self.prediction_length):
        # (batch_size * num_samples, 1, *target_shape, num_lags)
        lags = self.get_lagged_subsequences(
            F=F,
            sequence=repeated_past_target,
            sequence_length=self.history_length + k,
            indices=self.shifted_lags,
            subsequences_length=1,
        )

        lags_scaled = F.broadcast_div(
            lags, repeated_scale.expand_dims(axis=-1)
        )

        # from (batch_size * num_samples, 1, *target_shape, num_lags)
        # to (batch_size * num_samples, 1, prod(target_shape) * num_lags)
        input_lags = F.reshape(
            data=lags_scaled,
            shape=(-1, 1, prod(self.target_shape) * len(self.lags_seq)),
        )

        # (batch_size * num_samples, 1,
        #  prod(target_shape) * num_lags + num_time_features + num_static_features)
        decoder_input = F.concat(
            input_lags,
            repeated_time_feat.slice_axis(axis=1, begin=k, end=k + 1),
            repeated_static_feat,
            dim=-1,
        )

        # output shape: (batch_size * num_samples, 1, num_cells)
        # state shape: (batch_size * num_samples, num_cells)
        rnn_outputs, repeated_states = self.rnn.unroll(
            inputs=decoder_input,
            length=1,
            begin_state=repeated_states,
            layout="NTC",
            merge_outputs=True,
        )

        distr_args = self.proj_distr_args(rnn_outputs)

        # parametrize the output distribution with the projected parameters
        distr = self.distr_output.distribution(distr_args, scale=repeated_scale)

        # (batch_size * num_samples, 1, *target_shape)
        new_samples = distr.sample(dtype=self.dtype)

        # (batch_size * num_samples, seq_len, *target_shape)
        repeated_past_target = F.concat(
            repeated_past_target, new_samples, dim=1
        )
        future_samples.append(new_samples)

    # (batch_size * num_samples, prediction_length, *target_shape)
    samples = F.concat(*future_samples, dim=1)

    # (batch_size, num_samples, prediction_length, *target_shape)
    return samples.reshape(
        shape=(
            (-1, self.num_parallel_samples)
            + (self.prediction_length,)
            + self.target_shape
        )
    )
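# An illustrative numpy sketch of the lagged-feature gather that
# `get_lagged_subsequences` performs at each decoding step. The exact
# GluonTS semantics (including the `shifted_lags` indexing, where lag 1
# refers to the most recently observed or sampled value) are assumptions
# here; only the single-step case used in the loops above is shown.
import numpy as np


def lagged_values_last_step(sequence: np.ndarray, lags: list) -> np.ndarray:
    # sequence: (batch, t) history, including samples appended so far
    # returns:  (batch, 1, num_lags) lag features for the next step
    t = sequence.shape[1]
    cols = [sequence[:, t - lag] for lag in lags]  # value `lag` steps back
    return np.stack(cols, axis=-1)[:, None, :]


history = np.arange(10, dtype=float)[None, :]  # a batch of one series
print(lagged_values_last_step(history, lags=[1, 2, 7]))  # [[[9. 8. 3.]]]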