def _index_tensor(x: Tensor, item: Any) -> Tensor:
    """
    Emulate numpy-style basic indexing (``x[item]``) on an MXNet tensor
    using only ``slice_axis`` and ``squeeze``.

    Supported forms of ``item``: a single entry or a tuple of entries,
    where each entry is an ``int``, a ``slice`` with ``step is None``,
    ``slice(None)`` (keep the axis untouched), or ``Ellipsis``.  As in
    numpy, every axis indexed by a plain ``int`` is removed from the
    result via a final ``squeeze``.

    Parameters
    ----------
    x
        Tensor to index.
    item
        Indexing item (or tuple of items), as described above.

    Returns
    -------
    Tensor
        The sliced (and squeezed) tensor.

    Raises
    ------
    RuntimeError
        If an entry of ``item`` is not one of the supported forms.
    """
    # Axes to drop at the end (those selected by a plain integer).
    squeeze: List[int] = []
    if not isinstance(item, tuple):
        item = (item, )
    saw_ellipsis = False
    for i, item_i in enumerate(item):
        # Once an Ellipsis has been seen, the remaining entries address
        # axes counted from the END of the shape, hence the negative
        # axis index (i - len(item)).
        axis = i - len(item) if saw_ellipsis else i
        if isinstance(item_i, int):
            if item_i != -1:
                x = x.slice_axis(axis=axis, begin=item_i, end=item_i + 1)
            else:
                # -1 is special-cased: end=item_i + 1 would be 0, which
                # slice_axis would read as an empty range; end=None means
                # "through the last element".
                x = x.slice_axis(axis=axis, begin=-1, end=None)
            squeeze.append(axis)
        elif item_i == slice(None):
            # Full slice: keep this axis as-is.
            continue
        elif item_i == Ellipsis:
            saw_ellipsis = True
            continue
        elif isinstance(item_i, slice):
            # Only contiguous slices are supported (no step).
            assert item_i.step is None
            start = item_i.start if item_i.start is not None else 0
            # NOTE: slice_axis accepts end=None when item_i.stop is None.
            x = x.slice_axis(axis=axis, begin=start, end=item_i.stop)
        else:
            raise RuntimeError(f"invalid indexing item: {item}")
    if len(squeeze):
        x = x.squeeze(axis=tuple(squeeze))
    return x
def get_gp_params(
    self,
    F,
    past_target: Tensor,
    past_time_feat: Tensor,
    feat_static_cat: Tensor,
) -> Tuple:
    """
    This function returns the GP hyper-parameters for the model.

    Parameters
    ----------
    F
        A module that can either refer to the Symbol API or the NDArray
        API in MXNet.
    past_target
        Training time series values of shape (batch_size, context_length).
    past_time_feat
        Training features of shape (batch_size, context_length,
        num_features).
    feat_static_cat
        Time series indices of shape (batch_size, 1).

    Returns
    -------
    Tuple
        Tuple of kernel hyper-parameters of length num_hyperparams.
        Each is a Tensor of shape (batch_size, 1, 1).
        Model noise sigma.  Tensor of shape (batch_size, 1, 1).
    """
    # Shape (batch_size, num_hyperparams + 1).
    # NOTE(review): squeeze() with batch_size == 1 collapses to a scalar
    # index — presumably the embedding accepts that; confirm upstream.
    output = self.embedding(feat_static_cat.squeeze())
    kernel_args = self.proj_kernel_args(output)
    # sigma (observation noise) is the last hyper-parameter column.
    sigma = softplus(
        F,
        output.slice_axis(
            axis=1,
            begin=self.num_hyperparams,
            end=self.num_hyperparams + 1,
        ),
    )
    if self.params_scaling:
        scalings = self.kernel_output.gp_params_scaling(
            F, past_target, past_time_feat
        )
        sigma = F.broadcast_mul(sigma, scalings[self.num_hyperparams])
        # Materialize eagerly: the original generator expression deferred
        # the scaling to first consumption and could be iterated only once.
        kernel_args = tuple(
            F.broadcast_mul(kernel_arg, scaling)
            for kernel_arg, scaling in zip(
                kernel_args, scalings[: self.num_hyperparams]
            )
        )
    # Clip hyper-parameters for numerical stability before use.
    min_value = 1e-5
    max_value = 1e8
    # Return a real tuple (not a one-shot generator) so the result matches
    # the documented return type and survives repeated iteration.
    kernel_args = tuple(
        kernel_arg.clip(min_value, max_value).expand_dims(axis=2)
        for kernel_arg in kernel_args
    )
    sigma = sigma.clip(min_value, max_value).expand_dims(axis=2)
    return kernel_args, sigma
def emission_coeff(
    self, feature: Tensor  # (batch_size, time_length, 1)
) -> Tensor:
    """
    Return the (constant) emission matrix, broadcast against the batch
    and time dimensions of ``feature``.

    The result has shape (batch_size, time_length, obs_dim, latent_dim),
    where obs_dim is 1 and every entry is 1.
    """
    F = getF(feature)
    coeff = F.ones(shape=(1, 1, 1, self.latent_dim()))

    # Dummy tensor carrying the desired output shape
    # (batch_size, time_length, 1, latent_dim); used purely as a
    # broadcasting template.
    shape_template = _broadcast_param(
        feature.squeeze(axis=2),
        axes=[2, 3],
        sizes=[1, self.latent_dim()],
    )
    return coeff.broadcast_like(shape_template)
def transition_coeff(
    self, feature: Tensor  # (batch_size, time_length, 1)
) -> Tensor:
    """
    Return the (constant) identity transition matrix, broadcast against
    the batch and time dimensions of ``feature``.

    The result has shape (batch_size, time_length, latent_dim, latent_dim)
    and equals the latent_dim x latent_dim identity at every (batch, time)
    position.
    """
    F = getF(feature)
    identity = F.eye(self.latent_dim())
    coeff = identity.expand_dims(axis=0).expand_dims(axis=0)

    # Dummy tensor carrying the desired output shape
    # (batch_size, time_length, latent_dim, latent_dim); used purely as
    # a broadcasting template.
    shape_template = _broadcast_param(
        feature.squeeze(axis=2),
        axes=[2, 3],
        sizes=[self.latent_dim(), self.latent_dim()],
    )
    return coeff.broadcast_like(shape_template)