def __call__(self, input, iteration=100, **kwargs):
    """Run the iterative updates on ``input`` and return the rescaled estimate.

    The loss returned by ``compute_negative_loglikelihood`` is appended to
    ``self.loss`` once before the first update and once after every update.

    Args:
        input: Observed signal; annotated as (n_channels, n_bins, n_frames).
        iteration: Number of times ``update_once`` is called.
        **kwargs: Forwarded unchanged to ``self._reset``.

    Returns:
        The separated signal multiplied by the scale obtained from
        ``projection_back``. The original annotations label it both
        (n_channels, n_bins, n_frames) and (n_sources, n_bins, n_frames)
        -- presumably the two counts are equal here; TODO confirm.
    """
    self.input = input
    self._reset(**kwargs)

    # Loss of the state produced by _reset, before any update is applied.
    initial_loss = self.compute_negative_loglikelihood()
    self.loss.append(initial_loss)

    for _ in range(iteration):
        self.update_once()

        current_loss = self.compute_negative_loglikelihood()
        self.loss.append(current_loss)

        # Optional hook, called with this object after each update.
        if self.callback is not None:
            self.callback(self)

    self.solve_permutation()

    ref_channel = self.reference_id
    demix = self.demix_filter
    separated = self.separate(input, demix_filter=demix)

    # Rescale the estimate using the reference channel of the observation.
    scale = projection_back(separated, reference=input[ref_channel])
    output = separated * scale[..., None]  # (n_sources, n_bins, n_frames)
    self.estimation = output

    return output
def update_once(self):
    """Apply one source-model and space-model update, then renormalize.

    After both model updates, the observation ``self.input`` is separated
    with the current ``self.demix_filter`` and stored in ``self.estimation``.
    If ``self.normalize`` is truthy, the demixing filter and the estimate
    are rescaled and written back:

    * ``'power'``: each source is divided by the square root of its mean
      power (floored at ``self.eps``); ``self.base`` (and ``self.latent``
      when ``self.partitioning`` is set) is rescaled by the squared factor.
    * ``'projection-back'``: the estimate is scaled by ``projection_back``
      and the demixing filter is recomputed from the scaled estimate and
      the observation.

    Raises:
        NotImplementedError: ``'projection-back'`` is requested together
            with ``self.partitioning``.
        ValueError: ``self.normalize`` is truthy but neither ``'power'``
            nor ``'projection-back'``.
    """
    eps = self.eps

    self.update_source_model()
    self.update_space_model()

    observed = self.input
    demix = self.demix_filter
    separated = self.separate(observed, demix_filter=demix)
    self.estimation = separated

    mode = self.normalize
    if not mode:
        # Normalization disabled: filter and estimate stay as they are.
        return

    if mode == 'power':
        power = np.abs(separated) ** 2
        gain = np.sqrt(power.mean(axis=(1, 2)))  # (n_sources,)
        # Floor the per-source gain so the divisions below stay finite.
        gain[gain < eps] = eps

        demix = demix / gain[None, :, None]
        separated = separated / gain[:, None, None]

        if self.partitioning:
            latent = self.latent
            base = self.base

            scaled_latent = latent / (gain[:, None] ** 2)  # (n_sources, n_bases)
            column_sum = scaled_latent.sum(axis=0)  # (n_bases,)
            # Move the per-basis sum into the bases; each column of the
            # rescaled latent then sums to one over the source axis.
            self.latent = scaled_latent / column_sum  # (n_sources, n_bases)
            self.base = base * column_sum  # (n_bins, n_bases)
        else:
            base = self.base
            self.base = base / (gain[:, None, None] ** 2)
    elif mode == 'projection-back':
        if self.partitioning:
            raise NotImplementedError(
                "Not support 'projection-back' based normalization for partitioninig function. Choose 'power' based normalization."
            )

        scale = projection_back(separated, reference=observed[self.reference_id])
        separated = separated * scale[..., None]  # (n_sources, n_bins, n_frames)

        obs = observed.transpose(1, 0, 2)  # (n_bins, n_channels, n_frames)
        obs_hermite = obs.transpose(0, 2, 1).conj()  # (n_bins, n_frames, n_channels)
        covariance_inv = np.linalg.inv(obs @ obs_hermite)
        # Recompute the filter from the rescaled estimate and the
        # observation: W = Y X^H (X X^H)^{-1}, evaluated left to right.
        demix = separated.transpose(1, 0, 2) @ obs_hermite @ covariance_inv  # (n_bins, n_sources, n_channels)
    else:
        raise ValueError(
            "Not support normalization based on {}. Choose 'power' or 'projection-back'"
            .format(mode))

    self.demix_filter = demix
    self.estimation = separated