def fit(
        self,
        y,
        saliency=None,
        min_concentration=1e-10,
        max_concentration=500,
) -> VonMisesFisher:
    """Fits a von Mises Fisher distribution.

    Broadcasting (for sources) has to be done outside this function.

    Args:
        y: Real-valued observations with shape (..., N, D). Each vector
            along the last axis is rescaled to unit length before fitting.
            Integer arrays are accepted (the division yields floats).
        saliency: Either None or weights with shape (..., N); must be
            broadcast compatible with ``y.shape[:-1]``.
        min_concentration: Forwarded to ``self._fit``; presumably the lower
            bound for the estimated concentration -- confirm in ``_fit``.
        max_concentration: Forwarded to ``self._fit``; presumably the upper
            bound for the estimated concentration -- confirm in ``_fit``.

    Returns:
        The result of ``self._fit`` (annotated as VonMisesFisher).
    """
    assert np.isrealobj(y), y.dtype

    norm = np.linalg.norm(y, axis=-1, keepdims=True)
    # Take the floor from the norm's dtype: np.linalg.norm always returns a
    # floating dtype (it promotes integer input), whereas np.finfo(y.dtype)
    # raises ValueError for integer y. For float y both dtypes coincide.
    # The floor avoids 0/0 for all-zero observation vectors.
    y = y / np.maximum(norm, np.finfo(norm.dtype).tiny)

    if saliency is not None:
        assert is_broadcast_compatible(y.shape[:-1], saliency.shape), (
            y.shape,
            saliency.shape,
        )

    return self._fit(
        y,
        saliency=saliency,
        min_concentration=min_concentration,
        max_concentration=max_concentration,
    )
def _fit(
        self,
        y,
        saliency,
        quadratic_form,
        hermitize=True,
        covariance_norm='eigenvalue',
        eigenvalue_floor=1e-10,
) -> ComplexAngularCentralGaussian:
    """Single step of the fit function. In general, needs iterations.

    Builds a weighted covariance estimate in which each observation's
    outer product is divided by its quadratic form. It then constructs
    the distribution via ``ComplexAngularCentralGaussian.from_covariance``.

    Note: y shape is (..., D, N) and not (..., N, D) like in fit

    Args:
        y: Assumed to have unit length. Shape (..., D, N), e.g. (1, D, N)
            for mixture models.
        saliency: None or weights with shape (..., N), e.g. (K, N) for
            mixture models. None weights every observation by 1 and
            normalizes by N.
        quadratic_form: Shape (..., N), e.g. (K, N) for mixture models.
            Each observation's contribution is divided by its entry.
        hermitize: If True, pass the covariance through ``force_hermitian``
            before building the distribution.
        covariance_norm: Forwarded to ``from_covariance``.
        eigenvalue_floor: Forwarded to ``from_covariance``.

    Returns:
        The ComplexAngularCentralGaussian returned by ``from_covariance``.
    """
    assert np.iscomplexobj(y), y.dtype
    assert is_broadcast_compatible(
        y.shape[:-2], quadratic_form.shape[:-1]
    ), (y.shape, quadratic_form.shape)

    D = y.shape[-2]
    *independent, N = quadratic_form.shape

    if saliency is None:
        saliency = 1
        denominator = N
    else:
        # Report both shapes (not just saliency's rank) so a mismatch is
        # diagnosable from the assertion message alone.
        assert y.ndim == saliency.ndim + 1, (y.shape, saliency.shape)
        # Total weight per independent index, shaped (..., 1, 1) so that it
        # broadcasts against the (..., D, D) covariance.
        denominator = np.einsum('...n->...', saliency)[..., None, None]

    # covariance[..., d, e] =
    #     D * sum_n saliency[n] / quadratic_form[n] * y[d, n] * conj(y[e, n])
    covariance = D * np.einsum(
        '...n,...dn,...Dn->...dD',
        (saliency / quadratic_form),
        y,
        y.conj(),
        optimize='greedy',
    )
    covariance /= denominator
    assert covariance.shape == (*independent, D, D), covariance.shape

    if hermitize:
        covariance = force_hermitian(covariance)

    return ComplexAngularCentralGaussian.from_covariance(
        covariance,
        eigenvalue_floor=eigenvalue_floor,
        covariance_norm=covariance_norm,
    )
def fit(self, y, saliency=None, covariance_type="full"):
    """Validate real-valued observations and delegate to ``self._fit``.

    Args:
        y: Real-valued observations with shape (..., N, D).
        saliency: Optional importance weight per observation, shape
            (..., N); must be broadcast compatible with ``y.shape[:-1]``.
        covariance_type: One of 'full', 'diagonal', or 'spherical'.

    Returns:
        Whatever ``self._fit`` returns for the validated arguments.
    """
    assert np.isrealobj(y), y.dtype

    if saliency is not None:
        assert is_broadcast_compatible(
            y.shape[:-1], saliency.shape
        ), (y.shape, saliency.shape)

    fit_options = dict(saliency=saliency, covariance_type=covariance_type)
    return self._fit(y, **fit_options)
def fit(self, y, saliency=None) -> ComplexWatson:
    """Normalize complex observations and delegate to ``self._fit``.

    Each vector along the last axis of ``y`` is rescaled to unit length.
    Zero vectors are guarded by a tiny floor on the divisor. The first
    call records the feature dimension on the trainer, and later calls
    must use the same dimension.

    Args:
        y: Complex observations whose last axis is the feature dimension;
            that axis must have more than one entry.
        saliency: Optional weights, broadcast compatible with
            ``y.shape[:-1]``.

    Returns:
        The result of ``self._fit`` (annotated as ComplexWatson).
    """
    assert np.iscomplexobj(y), y.dtype
    assert y.shape[-1] > 1

    lengths = np.linalg.norm(y, axis=-1, keepdims=True)
    floor = np.finfo(y.dtype).tiny
    y = y / np.maximum(lengths, floor)

    if saliency is not None:
        assert is_broadcast_compatible(y.shape[:-1], saliency.shape), (
            y.shape,
            saliency.shape,
        )

    feature_dim = y.shape[-1]
    if self.dimension is not None:
        assert self.dimension == feature_dim, (
            "You initialized the trainer with a different dimension than "
            "you are using to fit a model. Use a new trainer, when you "
            "change the dimension.")
    else:
        self.dimension = feature_dim

    return self._fit(y, saliency=saliency)
def _log_pdf(self, y):
    """Evaluate the log-density and the quadratic form for observations y.

    Gets used by, e.g., the cACGMM.
    TODO: quadratic_form might be useful by itself

    Note: y shape is (..., D, N) and not (..., N, D) like in log_pdf

    Args:
        y: Observations, assumed to have unit norm, with shape (..., D, N).

    Returns:
        log_pdf: ``-D * log(quadratic_form) - log_determinant``. The shape
            is the broadcast of ``y.shape[:-2]`` with the batch shape of
            the covariance eigendecomposition, followed by N. For example,
            it is (K, N) for y of shape (1, D, N) and K components.
        quadratic_form: Same shape as log_pdf, floored at the smallest
            positive normal number of y's real dtype.
    """
    *independent, D, _ = y.shape
    assert is_broadcast_compatible(
        [*independent, D, D], self.covariance_eigenvectors.shape
    ), (y.shape, self.covariance_eigenvectors.shape)

    # y^H V diag(1 / lambda) V^H y for every observation, where V are the
    # covariance eigenvectors and lambda the eigenvalues. This is the
    # inverse covariance when V is unitary.
    quadratic_form = np.maximum(
        np.abs(
            np.einsum(
                '...dt,...de,...e,...ge,...gt->...t',
                y.conj(),
                self.covariance_eigenvectors,
                1 / self.covariance_eigenvalues,
                self.covariance_eigenvectors.conj(),
                y,
                optimize='optimal',
            )
        ),
        # Floor keeps np.log finite if the form underflows to zero.
        np.finfo(y.dtype).tiny,
    )
    # NOTE(review): no normalization-constant term is added here; presumably
    # omitted on purpose (e.g. it cancels in mixture posteriors) -- confirm.
    log_pdf = -D * np.log(quadratic_form)
    log_pdf -= self.log_determinant[..., None]
    return log_pdf, quadratic_form
def check_false(self, *shapes):
    """Assert that ``is_broadcast_compatible`` rejects the given shapes."""
    compatible = is_broadcast_compatible(*shapes)
    self.assertFalse(compatible, msg=shapes)
def check_true(self, *shapes):
    """Assert that ``is_broadcast_compatible`` accepts the given shapes."""
    compatible = is_broadcast_compatible(*shapes)
    self.assertTrue(compatible, msg=shapes)