def _predict(
        self,
        observation,
        embedding,
        affiliation_eps=0.,
        inline_permutation_alignment=False,
):
    """Combine the spatial and spectral log-pdfs into class affiliations.

    The spatial term comes from ``self.cacg._log_pdf`` and the spectral
    term from ``self.gaussian.log_pdf``; each is scaled by
    ``self.spatial_weight`` / ``self.spectral_weight`` before being handed
    to the affiliation helper.

    Args:
        observation: Shape (F, T, D)
        embedding: Shape (F, T, E)
        affiliation_eps: Forwarded unchanged to the affiliation helper.
        inline_permutation_alignment: If true, the two scaled log-pdfs are
            passed separately to
            ``log_pdf_to_affiliation_for_integration_models_with_inline_pa``;
            otherwise their sum is passed to ``log_pdf_to_affiliation``.

    Returns:
        affiliation: Shape (F, K, T)
        quadratic_form: Shape (F, K, T)
    """
    num_bins, num_frames, _ = observation.shape
    _, _, emb_dim = embedding.shape

    # Insert a singleton axis in front of the last two axes and swap those
    # two axes before evaluating the spatial model.
    spatial_input = np.swapaxes(observation[..., None, :, :], -1, -2)
    cacg_log_pdf, quadratic_form = self.cacg._log_pdf(spatial_input)

    # The spectral model sees all F * T embeddings as one flat batch.
    flat_embedding = np.reshape(
        embedding, (1, num_bins * num_frames, emb_dim)
    )
    gaussian_log_pdf = self.gaussian.log_pdf(flat_embedding)
    num_classes = gaussian_log_pdf.shape[0]
    # (K, F * T) -> (K, F, T) -> (F, K, T)
    gaussian_log_pdf = np.swapaxes(
        np.reshape(gaussian_log_pdf, (num_classes, num_bins, num_frames)),
        0,
        1,
    )

    weight = unsqueeze(self.weight, self.weight_constant_axis)
    spatial_term = self.spatial_weight * cacg_log_pdf
    spectral_term = self.spectral_weight * gaussian_log_pdf

    if inline_permutation_alignment:
        affiliation = (
            log_pdf_to_affiliation_for_integration_models_with_inline_pa(
                weight=weight,
                spatial_log_pdf=spatial_term,
                spectral_log_pdf=spectral_term,
                affiliation_eps=affiliation_eps,
            )
        )
        return affiliation, quadratic_form

    affiliation = log_pdf_to_affiliation(
        weight=weight,
        log_pdf=spatial_term + spectral_term,
        affiliation_eps=affiliation_eps,
    )
    return affiliation, quadratic_form
def _predict(self, observation, embedding):
    """Compute class affiliations from the spatial and spectral models.

    The log of ``self.weight``, the spatial log-pdf from
    ``self.cacg._log_pdf`` (scaled by ``self.spatial_weight``) and the
    spectral log-pdf from ``self.vmf.log_pdf`` (scaled by
    ``self.spectral_weight``) are summed and then normalised over the
    class axis (axis ``-2``).

    Args:
        observation: Three-dimensional array, unpacked as (F, T, D).
        embedding: Three-dimensional array, unpacked as (F, T, E).

    Returns:
        affiliation: Normalised affiliations; the spectral term
            contributes shape (F, K, T), where K is the leading dimension
            of the ``self.vmf.log_pdf`` output.
        quadratic_form: Second return value of ``self.cacg._log_pdf``,
            passed through unchanged.
    """
    num_bins, num_frames, _ = observation.shape
    _, _, emb_dim = embedding.shape

    # Insert a singleton axis in front of the last two axes and swap those
    # two axes before evaluating the spatial model.
    spatial_input = np.swapaxes(observation[..., None, :, :], -1, -2)
    cacg_log_pdf, quadratic_form = self.cacg._log_pdf(spatial_input)

    # The spectral model sees all F * T embeddings as one flat batch.
    flat_embedding = np.reshape(
        embedding, (1, num_bins * num_frames, emb_dim)
    )
    vmf_log_pdf = self.vmf.log_pdf(flat_embedding)
    num_classes = vmf_log_pdf.shape[0]
    # (K, F * T) -> (K, F, T) -> (F, K, T)
    vmf_log_pdf = np.swapaxes(
        np.reshape(vmf_log_pdf, (num_classes, num_bins, num_frames)),
        0,
        1,
    )

    log_weight = unsqueeze(np.log(self.weight), self.weight_constant_axis)
    spatial_term = self.spatial_weight * cacg_log_pdf
    spectral_term = self.spectral_weight * vmf_log_pdf
    affiliation = log_weight + spatial_term + spectral_term

    # Shift by the per-class-axis maximum so the largest exponent is zero;
    # the shift cancels in the normalisation below.
    class_peak = np.max(affiliation, axis=-2, keepdims=True)
    affiliation -= class_peak
    np.exp(affiliation, out=affiliation)

    # Sum over the class axis; the floor keeps the division away from zero.
    class_sum = np.einsum("...kt->...t", affiliation)[..., None, :]
    smallest = np.finfo(affiliation.dtype).tiny
    denominator = np.maximum(class_sum, smallest)
    affiliation /= denominator

    return affiliation, quadratic_form