コード例 #1
0
ファイル: gcacgmm.py プロジェクト: shaigue/DL_project_2019
    def _predict(
        self,
        observation,
        embedding,
        affiliation_eps=0.,
        inline_permutation_alignment=False,
    ):
        """Compute affiliations from spatial (cACG) and spectral (Gaussian) scores.

        Args:
            observation: Shape (F, T, D)
            embedding: Shape (F, T, E)
            affiliation_eps: Passed through unchanged to the affiliation helper.
            inline_permutation_alignment: When true, the weighted spatial and
                spectral log-pdfs are handed over separately to
                ``log_pdf_to_affiliation_for_integration_models_with_inline_pa``;
                when false, their sum is handed to ``log_pdf_to_affiliation``.

        Returns:
            affiliation: Shape (F, K, T)
            quadratic_form: Shape (F, K, T)

        """
        num_freqs, num_frames, _ = observation.shape
        _, _, embedding_dim = embedding.shape

        # Insert a singleton axis and move D in front of T:
        # (F, T, D) -> (F, 1, T, D) -> (F, 1, D, T).
        cacg_input = np.swapaxes(observation[..., None, :, :], -1, -2)
        cacg_log_pdf, quadratic_form = self.cacg._log_pdf(cacg_input)

        # Score all F*T embeddings as one flat batch, then reshape the
        # (num_classes, F*T) result back and move the class axis to the middle.
        flat_embedding = np.reshape(
            embedding, (1, num_freqs * num_frames, embedding_dim))
        flat_log_pdf = self.gaussian.log_pdf(flat_embedding)
        gaussian_log_pdf = np.transpose(
            np.reshape(
                flat_log_pdf,
                (flat_log_pdf.shape[0], num_freqs, num_frames),
            ),
            (1, 0, 2),
        )

        weight = unsqueeze(self.weight, self.weight_constant_axis)
        spatial_log_pdf = self.spatial_weight * cacg_log_pdf
        spectral_log_pdf = self.spectral_weight * gaussian_log_pdf

        if not inline_permutation_alignment:
            return log_pdf_to_affiliation(
                weight=weight,
                log_pdf=spatial_log_pdf + spectral_log_pdf,
                affiliation_eps=affiliation_eps,
            ), quadratic_form

        return log_pdf_to_affiliation_for_integration_models_with_inline_pa(
            weight=weight,
            spatial_log_pdf=spatial_log_pdf,
            spectral_log_pdf=spectral_log_pdf,
            affiliation_eps=affiliation_eps,
        ), quadratic_form
コード例 #2
0
    def _predict(self, observation, embedding):
        """Compute affiliations from spatial (cACG) and spectral (vMF) scores.

        The per-class log-pdfs of both models are scaled by
        ``self.spatial_weight`` / ``self.spectral_weight``. They are added to
        the log mixture weights and normalized over the class axis (a softmax
        over K).

        Args:
            observation: Shape (F, T, D)
            embedding: Shape (F, T, E)

        Returns:
            affiliation: Shape (F, K, T). Sums to one over K, except in bins
                where every class has a log-score of -inf; those bins are
                all zeros instead of NaN.
            quadratic_form: Second value returned by ``self.cacg._log_pdf``.

        """
        F, T, _ = observation.shape
        _, _, E = embedding.shape

        # (F, T, D) -> (F, 1, D, T) for the cACG model.
        observation_ = observation[..., None, :, :]
        cacg_log_pdf, quadratic_form = self.cacg._log_pdf(
            np.swapaxes(observation_, -1, -2))

        # Score all F*T embeddings at once, then restore (F, K, T).
        embedding_ = np.reshape(embedding, (1, F * T, E))
        vmf_log_pdf = self.vmf.log_pdf(embedding_)
        num_classes = vmf_log_pdf.shape[0]
        vmf_log_pdf = np.transpose(
            np.reshape(vmf_log_pdf, (num_classes, F, T)), (1, 0, 2))

        # A mixture weight of exactly zero gives log(0) = -inf. That value is
        # intended (the class then receives zero affiliation), so suppress
        # numpy's divide-by-zero RuntimeWarning instead of emitting it.
        with np.errstate(divide='ignore'):
            log_weight = np.log(self.weight)

        affiliation = (
            unsqueeze(log_weight, self.weight_constant_axis) +
            self.spatial_weight * cacg_log_pdf +
            self.spectral_weight * vmf_log_pdf)

        # Subtract the per-bin maximum over classes (axis -2) before exp to
        # avoid overflow; this does not change the normalized result. If all
        # classes are -inf in a bin, that maximum is -inf and -inf - (-inf)
        # would be NaN, so shift by 0 there (the bin then ends up all zeros).
        shift = np.max(affiliation, axis=-2, keepdims=True)
        shift = np.where(np.isneginf(shift), 0, shift)
        affiliation -= shift

        np.exp(affiliation, out=affiliation)
        # Clamp the normalizer to the smallest positive normal float so bins
        # whose affiliations are all zero do not divide by zero.
        denominator = np.maximum(
            np.einsum("...kt->...t", affiliation)[..., None, :],
            np.finfo(affiliation.dtype).tiny,
        )
        affiliation /= denominator
        return affiliation, quadratic_form