예제 #1
0
    def __call__(self, input, iteration=100, **kwargs):
        """Run the iterative updates on a mixture and return the scaled estimate.

        Args:
            input: Observed mixture, (n_channels, n_bins, n_frames).
            iteration (int): Number of times ``update_once`` is invoked.
            **kwargs: Passed through unchanged to ``self._reset``.
        Returns:
            output (n_channels, n_bins, n_frames): Separated signals multiplied
                by the factors returned from ``projection_back``. The same
                array is stored in ``self.estimation``.
        """
        self.input = input
        self._reset(**kwargs)

        # One loss value is recorded before any update, then one per update.
        initial_loss = self.compute_negative_loglikelihood()
        self.loss.append(initial_loss)

        for _ in range(iteration):
            self.update_once()
            current_loss = self.compute_negative_loglikelihood()
            self.loss.append(current_loss)

            if self.callback is not None:
                # The callback is handed this object itself after each update.
                self.callback(self)

        self.solve_permutation()

        reference_id = self.reference_id
        separated = self.separate(input, demix_filter=self.demix_filter)

        # Rescale every separated signal against the reference channel of the
        # observed mixture.
        scale = projection_back(separated, reference=input[reference_id])
        output = separated * scale[..., np.newaxis]  # (n_sources, n_bins, n_frames)
        self.estimation = output

        return output
예제 #2
0
    def update_once(self):
        """Run one update of the source and spatial models, then normalize.

        After both models are updated, the input is separated with the current
        demixing filter and stored in ``self.estimation``. If ``self.normalize``
        is falsy, nothing else happens. Otherwise the demixing filter and the
        estimate are rescaled according to ``self.normalize``:

        * ``'power'``: each source is divided by the square root of its mean
          power, and the bases (plus the latent variables when
          ``self.partitioning`` is set) are rescaled to match.
        * ``'projection-back'``: the estimate is scaled by ``projection_back``
          against the reference channel and the demixing filter is recomputed
          from the scaled estimate.

        Raises:
            NotImplementedError: ``'projection-back'`` is requested while
                ``self.partitioning`` is set.
            ValueError: ``self.normalize`` is truthy but is neither ``'power'``
                nor ``'projection-back'``.
        """
        eps = self.eps

        self.update_source_model()
        self.update_space_model()

        mixture, demix = self.input, self.demix_filter
        separated = self.separate(mixture, demix_filter=demix)
        self.estimation = separated

        mode = self.normalize
        if not mode:
            # Normalization disabled: keep the filter and estimate as they are.
            return

        if mode == 'power':
            power = np.abs(separated)**2
            gain = np.sqrt(power.mean(axis=(1, 2)))  # (n_sources,)
            # Floor the gains so the divisions below never hit (near-)zero.
            gain[gain < eps] = eps

            demix = demix / gain[np.newaxis, :, np.newaxis]
            separated = separated / gain[:, np.newaxis, np.newaxis]

            if self.partitioning:
                latent, base = self.latent, self.base
                scaled_latent = latent / (gain[:, np.newaxis]**2)  # (n_sources, n_bases)
                per_base_total = np.sum(scaled_latent, axis=0)  # (n_bases,)
                # Move the per-basis total into the bases so the latent
                # variables sum to one over sources again.
                new_base = base * per_base_total  # (n_bins, n_bases)
                new_latent = scaled_latent / per_base_total  # (n_sources, n_bases)
                self.latent = new_latent
                self.base = new_base
            else:
                self.base = self.base / (gain[:, np.newaxis, np.newaxis]**2)
        elif mode == 'projection-back':
            if self.partitioning:
                raise NotImplementedError(
                    "Not support 'projection-back' based normalization for partitioninig function. Choose 'power' based normalization."
                )
            scale = projection_back(
                separated, reference=mixture[self.reference_id])
            separated = separated * scale[..., np.newaxis]  # (n_sources, n_bins, n_frames)

            mixture_t = mixture.transpose(1, 0, 2)  # (n_bins, n_channels, n_frames)
            mixture_h = mixture_t.transpose(
                0, 2, 1).conj()  # (n_bins, n_frames, n_channels)
            # Per bin, W = Y X^H (X X^H)^-1 is the least-squares filter that
            # maps the mixture onto the rescaled estimate.
            cross = separated.transpose(1, 0, 2) @ mixture_h
            gram_inverse = np.linalg.inv(mixture_t @ mixture_h)
            demix = cross @ gram_inverse  # (n_bins, n_sources, n_channels)
        else:
            raise ValueError(
                "Not support normalization based on {}. Choose 'power' or 'projection-back'"
                .format(mode))

        self.demix_filter = demix
        self.estimation = separated