Beispiel #1
0
    def __call__(self, Obs, acitivity_freq, debug=False):
        """Estimate class affiliations with one cACGMM fit per frequency bin.

        Args:
            Obs: Observation array. Its last axis is iterated as the
                frequency axis and ``Obs.T[f, ...]`` is handed to the trainer
                (presumably ``Obs`` is shaped (D, T, F) -- TODO confirm
                against the caller).
            acitivity_freq: Per-class activity (presumably shaped (K, T')).
                It is used twice: normalized over axis 0 as the initial
                affiliation, and cast to bool as the source activity mask.
                Only the first ``T`` entries of its last axis are used.
            debug: If true, the fitted per-frequency models are collected,
                combined with ``stack_parameters`` and all local variables
                are stored in ``self.locals``.

        Returns:
            The per-frequency affiliations stacked into one array and
            transposed with ``(1, 2, 0)``, i.e. the frequency axis is last.
        """
        F = Obs.shape[-1]
        T = Obs.T.shape[-2]

        initialization = np.asarray(acitivity_freq, dtype=np.float64)
        # Replace exact zeros by a tiny value before normalizing over axis 0
        # (presumably so that no class starts with a zero affiliation).
        initialization = np.where(initialization == 0, 1e-10, initialization)
        initialization = initialization / np.sum(
            initialization, keepdims=True, axis=0)
        # One copy per frequency bin. The count follows the observation
        # instead of a fixed 513 so that any STFT size works.
        initialization = np.repeat(initialization[None, ...], F, axis=0)

        # Builtin ``bool``: the ``np.bool`` alias is not available in every
        # NumPy release (it was removed in 1.24).
        source_active_mask = np.asarray(acitivity_freq, dtype=bool)
        source_active_mask = np.repeat(source_active_mask[None, ...],
                                       F,
                                       axis=0)

        cacGMM = CACGMMTrainer()

        if debug:
            learned = []
        all_affiliations = []
        for f in range(F):
            if self.verbose:
                if f % 50 == 0:
                    print(f'{f}/{F}')

            # T: Consider end of signal.
            # This should not be necessary, but the activity is given for
            # the in-ear signal and not for the array.
            cur = cacGMM.fit(
                y=Obs.T[f, ...],
                initialization=initialization[f, ..., :T],
                iterations=self.iterations,
                source_activity_mask=source_active_mask[f, ..., :T],
                # return_affiliation=True,
            )

            if self.iterations_post != 0:
                # Post iterations run without the activity mask. The final
                # predict below counts as the last post iteration, hence
                # ``iterations_post - 1`` additional fit iterations.
                if self.iterations_post != 1:
                    cur = cacGMM.fit(
                        y=Obs.T[f, ...],
                        initialization=cur,
                        iterations=self.iterations_post - 1,
                    )
                affiliation = cur.predict(Obs.T[f, ...], )
            else:
                affiliation = cur.predict(
                    Obs.T[f, ...],
                    source_activity_mask=source_active_mask[f, ..., :T])

            if debug:
                learned.append(cur)
            all_affiliations.append(affiliation)

        posterior = np.array(all_affiliations).transpose(1, 2, 0)

        if debug:
            learned = stack_parameters(learned)
            # Keep every local around for interactive inspection.
            self.locals = locals()

        return posterior
Beispiel #2
0
    def test_cacgmm(self):
        """Fit a two-class cACGMM on sampled data and compare with the truth.

        The estimated covariances and weights are compared with the sampling
        parameters after resolving the class permutation.
        """
        num_samples = 10000
        true_weight = np.array([0.3, 0.7])
        true_covariance = np.array([
            [[10, 1 + 1j, 1 + 1j], [1 - 1j, 5, 1], [1 - 1j, 1, 2]],
            [[2, 0, 0], [0, 3, 0], [0, 0, 2]],
        ])
        # Scale every covariance matrix to unit trace.
        trace = np.trace(true_covariance, axis1=-2, axis2=-1)
        true_covariance /= trace[..., None, None]
        observation = sample_cacgmm(num_samples, true_weight, true_covariance)

        trainer = CACGMMTrainer()
        model = trainer.fit(
            observation,
            num_classes=2,
            covariance_norm='trace',
        )

        # The class order of the estimate is arbitrary: align it first.
        best_permutation = solve_permutation(
            model.cacg.covariance[:, :, :],
            true_covariance,
        )

        assert_allclose(
            model.cacg.covariance[best_permutation, :],
            true_covariance,
            atol=0.1,
        )

        model.weight = model.weight[best_permutation, ]
        assert model.weight[0] < model.weight[1], model.weight
        assert_allclose(model.weight, true_weight[:, None], atol=0.15)
Beispiel #3
0
    def test_cacgmm_independent_dimension(self):
        """Check fitting with one and with two leading independent entries.

        The same sampled data is fitted once with a singleton leading axis
        and once stacked twice; every independent entry has to recover the
        sampling parameters up to a permutation of the classes.
        """
        num_samples = 10000
        true_weight = np.array([0.3, 0.7])
        true_covariance = np.array([
            [[10, 1 + 1j, 1 + 1j], [1 - 1j, 5, 1], [1 - 1j, 1, 2]],
            [[2, 0, 0], [0, 3, 0], [0, 0, 2]],
        ])
        # Scale every covariance matrix to unit trace.
        trace = np.trace(true_covariance, axis1=-2, axis2=-1)
        true_covariance /= trace[..., None, None]
        observation = sample_cacgmm(num_samples, true_weight, true_covariance)

        # Case 1: a single independent entry.
        model = CACGMMTrainer().fit(
            observation[None, ...],
            num_classes=2,
            covariance_norm='trace',
        )

        # Permutation invariant testing
        order = solve_permutation(
            model.cacg.covariance[0, :, :, :], true_covariance)

        estimated_weight = np.squeeze(model.weight, axis=(0, -1))
        assert_allclose(estimated_weight[order,], true_weight, atol=0.15)
        assert_allclose(
            model.cacg.covariance[0, order, :], true_covariance, atol=0.1)

        # Case 2: the same data stacked twice along the independent axis.
        model = CACGMMTrainer().fit(
            np.array([observation, observation]),
            num_classes=2,
            covariance_norm='trace',
        )

        num_independent = model.weight.shape[0]
        for index in range(num_independent):
            # Permutation invariant testing
            order = solve_permutation(
                model.cacg.covariance[index, :, :, :], true_covariance)

            estimated_weight = np.squeeze(model.weight, axis=-1)
            assert_allclose(
                estimated_weight[index, order,],
                true_weight,
                atol=0.15,
            )
            assert_allclose(
                model.cacg.covariance[index, order, :],
                true_covariance,
                atol=0.1,
            )
Beispiel #4
0
    def initialize_spatial(self, spatial_feats):
        '''Creates the cacGMM trainer and fits the initial spatial model.

        The trainer is kept in ``self.cacGMM`` and the model obtained from a
        single fit iteration is kept in ``self.spatial``.

        Args:
            spatial_feats (np.array): Spatial features of shape (C,T,F)
        '''
        trainer = CACGMMTrainer()
        self.cacGMM = trainer

        # The trainer receives the features with reversed axes: (F,T,C).
        observation = spatial_feats.transpose(2, 1, 0)
        self.spatial = trainer.fit(
            observation,
            num_classes=self.n_classes,
            iterations=1,
            weight_constant_axis=self.wca,
            inline_permutation_aligner=self.inline_permutation_aligner,
        )
def get_mask_from_cacgmm(
    ex,
    weight_constant_axis=-1,
    num_classes=3,
    stft_size=512,
):  # (K, T, F)
    """Estimate permutation-aligned cACGMM class affiliations (masks).

    Args:
        ex: Nested mapping. The STFT observation is read from
            ``ex['audio_data']['Observation']`` and rearranged with
            ``'d t f -> f t d'``, i.e. it is expected with shape (D, T, F).
        weight_constant_axis: Forwarded to ``CACGMMTrainer.fit_predict``.
            For every value other than ``-1`` the permutation aligner is
            additionally passed as ``inline_permutation_aligner``.
        num_classes: Number of mixture classes K used for the random
            initialization. Defaults to the previously hard-coded 3.
        stft_size: STFT size passed to
            ``DHTVPermutationAlignment.from_stft_size``. Defaults to the
            previously hard-coded 512.

    Returns:
        The affiliations after permutation alignment, rearranged to
        (K, T, F).

    >>> from nara_wpe.utils import stft
    >>> y = get_dataset('cv_dev93')[0]['audio_data']['observation']
    >>> Y = stft(y, size=512, shift=128)
    >>> get_mask_from_cacgmm({'audio_data': {'Observation': Y}}).shape
    (3, 813, 257)

    """
    Observation = ex['audio_data']['Observation']
    Observation = rearrange(Observation, 'd t f -> f t d')

    trainer = CACGMMTrainer()

    # Shape (F, K, T)
    initialization = initializer.iid.dirichlet_uniform(
        Observation,
        num_classes=num_classes,
        permutation_free=False,
    )

    pa = DHTVPermutationAlignment.from_stft_size(stft_size)

    affiliation = trainer.fit_predict(
        Observation,
        initialization=initialization,
        weight_constant_axis=weight_constant_axis,
        inline_permutation_aligner=pa if weight_constant_axis != -1 else None)

    # The aligner works on (K, F, T); rearrange once and reuse the result
    # for both estimating and applying the mapping.
    affiliation_kft = rearrange(affiliation, 'f k t ->k f t')
    mapping = pa.calculate_mapping(affiliation_kft)

    return rearrange(
        pa.apply_mapping(affiliation_kft, mapping),
        'k f t -> k t f')
Beispiel #6
0
    def test_cacgmm_sad_init(self):
        """Smoke test: fitting accepts one-hot labels as initialization.

        Four layouts are exercised: no independent axis, a singleton
        independent axis, a broadcast initialization for three stacked
        observations and a fully stacked initialization. Nothing is
        asserted; the fits only have to run.
        """
        num_samples = 10000
        true_weight = np.array([0.3, 0.7])
        (num_classes,) = true_weight.shape
        true_covariance = np.array([
            [[10, 1 + 1j, 1 + 1j], [1 - 1j, 5, 1], [1 - 1j, 1, 2]],
            [[2, 0, 0], [0, 3, 0], [0, 0, 2]],
        ])
        # Scale every covariance matrix to unit trace.
        trace = np.trace(true_covariance, axis1=-2, axis2=-1)
        true_covariance /= trace[..., None, None]
        x, labels = sample_cacgmm(
            num_samples, true_weight, true_covariance, return_label=True)

        affiliations = labels_to_one_hot(labels, num_classes, axis=-2)

        def stacked(array, copies):
            # ``None`` keeps the array as is, otherwise stack ``copies``
            # copies along a new leading (independent) axis.
            if copies is None:
                return array
            return np.array([array] * copies)

        # (copies of the observation, copies of the initialization)
        layouts = (
            (None, None),  # initialization without independent axis
            (1, 1),  # initialization with independent axis
            (3, 1),  # independent axis and broadcasted initialization
            (2, 2),  # independent axis, initialization stacked as well
        )
        for observation_copies, initialization_copies in layouts:
            model = CACGMMTrainer().fit(
                stacked(x, observation_copies),
                initialization=stacked(affiliations, initialization_copies),
                covariance_norm='trace',
            )
Beispiel #7
0
    def test_cacgmm(self):
        """Fit a two-class cACGMM on hand-sampled data and check the result.

        The samples are drawn class by class from complex angular central
        Gaussians. The fitted covariances and weights are compared with the
        sampling parameters up to a class permutation. Finally a very large
        (or infinite) Dirichlet prior concentration has to yield uniform
        weights.
        """
        num_samples = 10000
        true_weight = np.array([0.3, 0.7])
        num_classes = true_weight.shape[0]
        labels = np.random.choice(range(num_classes),
                                  size=(num_samples, ),
                                  p=true_weight)
        true_covariance = np.array([
            [[10, 1 + 1j, 1 + 1j], [1 - 1j, 5, 1], [1 - 1j, 1, 2]],
            [[2, 0, 0], [0, 3, 0], [0, 0, 2]],
        ])
        # Scale every covariance matrix to unit trace.
        trace = np.trace(true_covariance, axis1=-2, axis2=-1)
        true_covariance /= trace[..., None, None]
        dimension = true_covariance.shape[-1]
        x = np.zeros((num_samples, dimension), dtype=np.complex128)

        # Fill the rows of every class with samples of its distribution.
        for class_index in range(num_classes):
            distribution = ComplexAngularCentralGaussian.from_covariance(
                covariance=true_covariance[class_index, :, :])
            x[labels == class_index, :] = distribution.sample(
                size=(np.sum(labels == class_index), ))

        model = CACGMMTrainer().fit(
            x,
            num_classes=2,
            covariance_norm='trace',
        )

        # Permutation invariant testing: keep the class order with the
        # smallest distance between estimated and true covariances.
        best_permutation = None
        best_cost = np.inf
        for candidate in list(itertools.permutations(range(2))):
            candidate_cost = np.linalg.norm(
                model.cacg.covariance[candidate, :] - true_covariance)
            if candidate_cost < best_cost:
                best_permutation = candidate
                best_cost = candidate_cost

        assert_allclose(
            model.cacg.covariance[best_permutation, :],
            true_covariance,
            atol=0.1,
        )

        model.weight = model.weight[best_permutation, ]
        assert model.weight[0] < model.weight[1], model.weight
        assert_allclose(model.weight, true_weight, atol=0.15)

        # With such a prior concentration the weights must be uniform.
        for concentration in (np.inf, 1_000_000_000):
            model = CACGMMTrainer().fit(
                x,
                num_classes=2,
                covariance_norm='trace',
                dirichlet_prior_concentration=concentration)
            assert_allclose(model.weight, [0.5, 0.5])
Beispiel #8
0
class Dolphin:
    '''Base class for different implementations of Dolphin method.

    Implements functionality independent of chosen spectral model:
    * spatial model initialization and update
    * permutations needed during the inference
    * update of q(D) - this however depends on implementation of spectral model
    * overall inference algorithm

    Subclasses have to implement `initialize_q`, `update_qZ` and
    `_spectral_log_p`. The methods of this class read the attributes
    `n_classes`, `wca`, `inline_permutation_aligner`, `device`,
    `n_iterations` and `qd`, none of which is set here.

    We use the following notation in documentation:
        C: number of channels
        S: number of speakers
        N: number of classes (N = S+1)
        T: number of frames
        F: number of frequency bins
    '''
    def initialize_spatial(self, spatial_feats):
        '''Initializes spatial model by fitting cacGMM to spatial features.

        Stores the trainer in `self.cacGMM` and the fitted model (one fit
        iteration) in `self.spatial`.

        Args:
            spatial_feats (np.array): Spatial features of shape (C,T,F)
        '''
        self.cacGMM = CACGMMTrainer()
        spatial_model = self.cacGMM.fit(
            spatial_feats.transpose(2, 1, 0),
            num_classes=self.n_classes,
            iterations=1,
            weight_constant_axis=self.wca,
            inline_permutation_aligner=self.inline_permutation_aligner,
        )
        self.spatial = spatial_model

    def _spatial_log_likelihood(self, spatial_feats):
        '''Evaluates the current spatial model on normalized spatial features.

        Args:
            spatial_feats (np.array): Spatial features of shape (C,T,F)

        Returns:
            np.array: First output of `self.spatial.cacg._log_pdf`, with the
                axes reordered by `transpose(1, 2, 0)`. It is passed to
                `_find_best_permutation`, which documents shape (N,T,F).
        '''
        spatial_norm = normalize_observation(spatial_feats.transpose(2, 1, 0))
        spat_log_p, _ = self.spatial.cacg._log_pdf(spatial_norm[:, None, :, :])
        return spat_log_p.transpose(1, 2, 0)

    def permute_global(self, logspec0, spatial_feats):
        '''Flips the order of classes in q(D) according to spectral and spatial likelihoods.

        The goal is to align components (corresponding to speakers and noise)
        between spectral and spatial model. We choose to keep the order
        in spectral model and change the order in spatial model. Due to the way
        this function is called (after initializing spatial models and
        before update of q(Z) and q(D)), this is essentially the same as
        changing the order in q(D).

        Args:
            logspec0 (torch.Tensor): Spectral features of shape (T,F)
            spatial_feats (np.array): Spatial features of shape (C,T,F)
        '''
        spat_log_p = self._spatial_log_likelihood(spatial_feats)
        spectral_log_p = self._spectral_log_p(logspec0)

        # Move to the CPU before converting: Tensor.numpy() only works for
        # CPU tensors, while the spectral model may live on self.device.
        perm = self._find_best_permutation(
            spectral_log_p.detach().cpu().numpy(),
            spat_log_p,
            idx_constant=(1, 2))
        # One global permutation: it is stored under the key (0, 0).
        self.qd = self.qd[perm[0, 0]]

    def _find_best_permutation(self, spectral, spatial, idx_constant):
        '''Finds the best permutation of classes in spectral and spatial model.

        Args:
            spectral (np.array): Conditional log-likelihood of spectral model
                                 (as in Eq.19) in [1], shape (N,T,F)
            spatial (np.array): Conditional log-likelihood of spatial model
                                (as in Eq.19) in [1], shape (N,T,F)
            idx_constant (tuple or int) indices of axis which have constant permutation
                Examples:
                    idx_constant = (1,2)
                        -> for all time frames and frequency bins
                           the permutation is constant
                           (finding 1 global permutation)
                    idx_constant = 1
                        -> for all time frames the permutation is constant
                           (finding permutation for each frequency)

        Returns:
            permutations (dict): mapping tuples of time and frequency indices
                                 to the best permutation. For constant indices,
                                 the map contains only index 0.
                Examples:
                    permutations = {(0,0) : [2, 0, 1]}
                        -> one global permutation (idx_constant = (1,2))
                        -> spectral comp. 0 corresponds to spatial comp. 2
                        -> spectral comp. 1 corresponds to spatial comp. 0
                        -> spectral comp. 2 corresponds to spatial comp. 1

        [1] Integration of variational autoencoder and spatial clustering
            for adaptive multi-channel neural speech separation;
            K. Zmolikova, M. Delcroix, L. Burget, T. Nakatani, J. Cernocky
        '''
        if isinstance(idx_constant, int):
            idx_constant = (idx_constant, )
        # The pairwise score below has one more leading class axis than the
        # inputs, so the axes to reduce move one position to the right.
        idx_constant = tuple(i + 1 for i in idx_constant)

        # perm_scores[i, j]: score of pairing spectral class i with spatial
        # class j, accumulated over the axes with constant permutation.
        perm_scores = logsumexp(spectral[:, None, :, :] +
                                spatial[None, :, :, :],
                                axis=idx_constant)
        # Restore the reduced axes with length 1, so that the array is
        # always indexed as (class, class, time, frequency).
        perm_scores = np.expand_dims(perm_scores, idx_constant)

        permutations = {}
        for i1, i2 in np.ndindex(perm_scores.shape[-2:]):
            idx_perm = Munkres().compute(
                make_cost_matrix(perm_scores[:, :, i1, i2]))
            # Order the assignment by the first (spectral) index.
            idx_perm.sort(key=lambda x: x[0])
            permutations[i1, i2] = [i[1] for i in idx_perm]

        return permutations

    def update_qD(self, logspec0, spatial_feats):
        '''Updates q(D) approximate posterior.

        Follows Eq. (19) from [1]. Additionally permutes the spatial components
        to best fit the spectral components at each frequency.

        [1] Integration of variational autoencoder and spatial clustering
            for adaptive multi-channel neural speech separation;
            K. Zmolikova, M. Delcroix, L. Burget, T. Nakatani, J. Cernocky

        Args:
            logspec0 (torch.Tensor): Spectral features of shape (T,F)
            spatial_feats (np.array): Spatial features of shape (C,T,F)
        '''
        _, _, n_f = spatial_feats.shape
        spat_log_p = self._spatial_log_likelihood(spatial_feats)
        spectral_log_p = self._spectral_log_p(logspec0)

        # Move to the CPU before converting: Tensor.numpy() only works for
        # CPU tensors, while the spectral model may live on self.device.
        perm = self._find_best_permutation(
            spectral_log_p.detach().cpu().numpy(),
            spat_log_p,
            idx_constant=1)
        # Reorder the spatial classes separately for every frequency bin.
        for f in range(n_f):
            spat_log_p[..., f] = spat_log_p[perm[0, f], :, f]

        ln_qds_unnorm = []
        for i in range(self.n_classes):
            qd1_unnorm = torch.tensor(spat_log_p[i]).to(self.device).float()
            qd1_unnorm += spectral_log_p[i]
            ln_qds_unnorm.append(qd1_unnorm)

        ln_qds_unnorm = torch.stack(ln_qds_unnorm)
        # subtract max for stability of exp (the constant does not matter)
        ln_qds_unnorm = (ln_qds_unnorm -
                         ln_qds_unnorm.max(dim=0, keepdim=True)[0])
        qds_unnorm = ln_qds_unnorm.exp()

        # Normalize over the class axis and keep the values off 0 and 1.
        qd = qds_unnorm / (qds_unnorm.sum(axis=0) + 1e-6)
        qd = qd.clamp(1e-6, 1 - 1e-6)
        self.qd = qd.detach()

    def update_spatial(self, spatial_feats):
        '''Updates parameters of spatial model.

        Runs the M-step of the cacGMM trainer with the current q(D) in the
        role of the class affiliations and stores the result in
        `self.spatial`.

        Args:
            spatial_feats (np.array): Spatial features of shape (C,T,F)
        '''
        observation = spatial_feats.transpose(2, 1, 0)
        spatial_norm = normalize_observation(observation)
        _, quadratic_form = self.spatial.predict(observation,
                                                 return_quadratic_form=True)

        self.spatial = self.cacGMM._m_step(
            spatial_norm,
            quadratic_form,
            # q(D) with the last axis moved to the front.
            self.qd.permute(2, 0, 1).detach().cpu().numpy(),
            saliency=None,
            hermitize=True,
            covariance_norm='eigenvalue',
            eigenvalue_floor=1e-10,
            weight_constant_axis=self.wca)

    def run(self, logspec0, spatial_feats):
        '''Runs the overall inference algorithm.

        Args:
            logspec0 (torch.Tensor): Spectral features of shape (T,F)
            spatial_feats (np.array): Spatial features of shape (C,T,F)
        '''
        self.initialize_q(logspec0)
        self.initialize_spatial(spatial_feats)
        self.permute_global(logspec0, spatial_feats)

        for _ in tqdm(range(self.n_iterations)):
            self.update_qZ(logspec0)
            self.update_qD(logspec0, spatial_feats)
            self.update_spatial(spatial_feats)

    def initialize_q(self, logspec0):
        '''Hook for subclasses; called first in `run`.'''
        raise NotImplementedError

    def update_qZ(self, logspec0):
        '''Hook for subclasses; called at the start of every iteration.'''
        raise NotImplementedError

    def _spectral_log_p(self, logspec0):
        '''Hook for subclasses; must return a torch.Tensor indexable by class.

        The result is used as the spectral log-likelihood in
        `permute_global` and `update_qD`.
        '''
        raise NotImplementedError