def fit(cls, F, samples: Tensor, rank: int = 0) -> Distribution:
    """
    Returns an instance of `LowrankMultivariateGaussian` after fitting
    parameters to the given data. Only the special case of `rank` = 0 is
    supported at the moment.

    Parameters
    ----------
    F
        The function space to use (must provide `square`).
    samples
        Tensor of shape (num_samples, batch_size, seq_len, target_dim)
    rank
        Rank of W. Only 0 is accepted.

    Returns
    -------
    Distribution
        instance of type `LowrankMultivariateGaussian`, constructed as
        ``cls(dim=samples.shape[-1], rank=rank, mu=mean, D=variance)``.
    """
    # TODO: Implement it for the general case: `rank` > 0
    # NOTE: this guard is an `assert`, so it is skipped under `python -O`.
    assert rank == 0, "Fit is only implemented for the case rank = 0!"

    # Mean over the sample axis (axis 0).
    mu = samples.mean(axis=0)
    # Variance as the mean squared deviation from `mu` over axis 0
    # (divides by the number of samples, i.e. no Bessel correction).
    # Reuses `mu` instead of recomputing the mean of `samples`.
    var = F.square(samples - mu).mean(axis=0)

    return cls(dim=samples.shape[-1], rank=rank, mu=mu, D=var)
def weighted_average(
    F,
    x: Tensor,
    weights: Optional[Tensor] = None,
    axis: Optional[int] = None,
    include_zeros_in_denominator=False,
) -> Tensor:
    """
    Computes the weighted average of a given tensor across a given axis,
    masking values associated with weight zero, meaning instead of
    `nan * 0 = nan` you will get `0 * 0 = 0`.

    Parameters
    ----------
    F
        The function space to use.
    x
        Input tensor, of which the average must be computed.
    weights
        Weights tensor, of the same shape as `x`. If None, the plain mean
        of `x` along `axis` is returned.
    axis
        The axis along which to average `x`.
    include_zeros_in_denominator
        If True, the denominator counts every position (weights replaced by
        ones) instead of summing `weights`. Can be useful for sparse time
        series because the loss can be dominated by few observed examples.

    Returns
    -------
    Tensor:
        The tensor with values averaged along the specified `axis`.

    Notes
    -----
    The denominator is clamped to be at least 1.0. This avoids a division
    by zero when all weights are zero (the result is then 0), but it also
    means that, when summing `weights`, totals below 1 are replaced by 1
    rather than used to renormalise.
    """
    if weights is not None:
        # Entries with zero weight are replaced by 0 so that e.g. a `nan`
        # in `x` cannot propagate through `nan * 0`.
        weighted_tensor = F.where(
            condition=weights, x=x * weights, y=F.zeros_like(x)
        )
        if include_zeros_in_denominator:
            sum_weights = F.maximum(1.0, F.ones_like(weights).sum(axis=axis))
        else:
            sum_weights = F.maximum(1.0, weights.sum(axis=axis))
        return weighted_tensor.sum(axis=axis) / sum_weights
    else:
        return x.mean(axis=axis)
def weighted_average(
    F,
    x: Tensor,
    weights: Optional[Tensor] = None,
    axis: Optional[int] = None,
    include_zeros_in_denominator=False,
) -> Tensor:
    """
    Average `x` along `axis`, optionally weighted, ignoring zero-weight entries.

    Positions whose weight is zero are forced to contribute exactly 0 to the
    numerator, so a `nan` there yields `0 * 0 = 0` rather than
    `nan * 0 = nan`.

    Parameters
    ----------
    F
        The function space to use.
    x
        Input tensor, of which the average must be computed.
    weights
        Weights tensor, of the same shape as `x`. When None, the plain
        mean of `x` along `axis` is returned.
    axis
        The axis along which to average `x`.
    include_zeros_in_denominator
        When True, divide by the number of positions (every weight counted
        as one) instead of by the sum of `weights`. Can be useful for sparse
        time series because the loss can be dominated by few observed
        examples.

    Returns
    -------
    Tensor:
        The tensor with values averaged along the specified `axis`.
    """
    # Unweighted case: ordinary mean.
    if weights is None:
        return x.mean(axis=axis)

    # Zero-weight positions become exactly 0 in the numerator.
    masked = F.where(condition=weights, x=x * weights, y=F.zeros_like(x))
    numerator = masked.sum(axis=axis)

    # Denominator source: a count of all positions, or the weights themselves.
    counted = F.ones_like(weights) if include_zeros_in_denominator else weights
    # Clamp to at least 1.0 so all-zero weights do not divide by zero.
    denominator = F.maximum(1.0, counted.sum(axis=axis))

    return numerator / denominator