Example #1
    def test_identity_mapping(self):
        estimated_mapping = FPA.get_identity_permutation((2, 3), axis=1)
        reference_mapping = np.asarray(
            [
                [0, 1, 2],
                [0, 1, 2]
            ]
        )
        np.testing.assert_equal(estimated_mapping, reference_mapping)
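For reference, the identity mapping can be built with a broadcast np.arange. A minimal sketch of what FPA.get_identity_permutation plausibly computes (the implementation is an assumption; only the signature and the expected output come from the test above):

import numpy as np

def get_identity_permutation(shape, axis=-1):
    # Sketch: broadcast [0, 1, ..., shape[axis] - 1] along all other axes,
    # matching the reference_mapping in the test above.
    axis = axis % len(shape)
    identity = np.arange(shape[axis])
    expanded = np.expand_dims(
        identity, tuple(i for i in range(len(shape)) if i != axis)
    )
    return np.broadcast_to(expanded, shape).copy()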
Example #2
    def test_toy_example_embedding_based_alignment(self):
        data_dir = Path(__file__).parent
        embedding = np.load(data_dir / 'embedding.npy')
        mask = np.load(data_dir / 'ideal_binary_mask.npy')
        _, E, _ = embedding.shape
        T, K, F = mask.shape

        random_permutation = FPA.random_permutation((F, K))
        permuted_mask = FPA.apply_mapping_to_mask(mask, random_permutation)

        estimated_mask = permuted_mask
        features = FPA.extract_features(estimated_mask, embedding)
        estimated_mask, mapping = FPA.align(estimated_mask, features)

        # Allow global permutation
        mismatch = np.inf
        for global_permutation in itertools.permutations(range(K)):
            mismatch = np.minimum(np.sum(np.abs(
                estimated_mask[:, global_permutation, :] - mask
            )) / mask.size, mismatch)

        np.testing.assert_array_less(mismatch, 0.1)
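The frequency-wise mapping can be applied with plain fancy indexing. A minimal sketch of what FPA.apply_mapping_to_mask plausibly does (assumed, not taken from the library; in particular, whether the mapping acts as a gather or a scatter over the class axis is an assumption):

import numpy as np

def apply_mapping_to_mask(mask, mapping):
    # mask: (T, K, F), mapping: (F, K).
    # Reorder the class axis independently per frequency bin.
    T, K, F = mask.shape
    assert mapping.shape == (F, K), mapping.shape
    permuted = np.empty_like(mask)
    for f in range(F):
        permuted[:, :, f] = mask[:, mapping[f], f]
    return permuted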
Example #3
def get_mask_from_cacgmm(
    ex,  # (D, T, F)
    weight_constant_axis=-1,
):  # (K, T, F)
    """

    Args:
        observation:

    Returns:

    >>> from nara_wpe.utils import stft
    >>> y = get_dataset('cv_dev93')[0]['audio_data']['observation']
    >>> Y = stft(y, size=512, shift=128)
    >>> get_mask_from_cacgmm(Y).shape
    (3, 813, 257)

    """
    Observation = ex['audio_data']['Observation']
    Observation = rearrange(Observation, 'd t f -> f t d')

    trainer = CACGMMTrainer()

    initialization: 'F, K, T' = initializer.iid.dirichlet_uniform(
        Observation,
        num_classes=3,
        permutation_free=False,
    )

    pa = DHTVPermutationAlignment.from_stft_size(512)

    affiliation = trainer.fit_predict(
        Observation,
        initialization=initialization,
        weight_constant_axis=weight_constant_axis,
        inline_permutation_aligner=pa if weight_constant_axis != -1 else None)

    mapping = pa.calculate_mapping(rearrange(affiliation, 'f k t -> k f t'))

    affiliation = rearrange(
        pa.apply_mapping(rearrange(affiliation, 'f k t -> k f t'), mapping),
        'k f t -> k t f')

    return affiliation
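A hedged usage sketch for get_mask_from_cacgmm with synthetic input, so it runs without the project-specific get_dataset (the channel count and signal length are arbitrary placeholders):

import numpy as np
from nara_wpe.utils import stft

y = np.random.normal(size=(6, 8 * 8000))   # (D, num_samples), placeholder data
Y = stft(y, size=512, shift=128)           # (D, T, F) with F = 257
mask = get_mask_from_cacgmm({'audio_data': {'Observation': Y}})
print(mask.shape)                          # (3, T, 257)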
Example #4
    def test_local_mapping_on_toy_example(self):
        features = np.asarray(
            [
                [1, 0, 0, 0],
                [0, 0, 2, 0],
                [1, 2, 0, 0],
            ]
        )
        centroids = np.asarray(
            [
                [1, 0, 0, 0],
                [0, 1, 0, 0],
                [0, 0, 1, 0],
            ]
        )
        estimated_mapping = FPA.get_local_mapping(
            features[:, :, None], centroids
        )[0][0, :]
        reference_mapping = [0, 2, 1]
        np.testing.assert_equal(estimated_mapping, reference_mapping)
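The expected mapping [0, 2, 1] follows directly from the inner products between features and centroids; the following check reproduces the score table the assignment is based on (plain NumPy, no library assumptions):

import numpy as np

features = np.asarray([[1, 0, 0, 0], [0, 0, 2, 0], [1, 2, 0, 0]])
centroids = np.asarray([[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 0]])

similarity = features @ centroids.T   # similarity[k, l] = <feature k, centroid l>
print(similarity)
# [[1 0 0]   feature 0 is closest to centroid 0
#  [0 0 2]   feature 1 is closest to centroid 2
#  [1 2 0]]  feature 2 is closest to centroid 1  ->  mapping [0, 2, 1]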
Example #5
    def test_parallel_get_local_mapping(self):

        def reference(mapping, features, centroids):
            F = features.shape[2]

            def get_local_mapping(x, mu):
                K = x.shape[0]
                assert K < 10, f'K = {K} seems to be too much.'
                similarity_matrix = np.einsum('ke,le->kl', x, mu)
                best_permutation = None
                best_score = -np.inf
                for permutation in list(itertools.permutations(range(K))):
                    score = np.sum(similarity_matrix[permutation, range(K)])
                    if score > best_score:
                        best_permutation = permutation
                        best_score = score
                return best_permutation, best_score

            total_score = 0
            for f in range(F):
                mapping[f, :], best_score = get_local_mapping(
                    features[:, :, f], centroids
                )
                total_score += best_score
            return mapping, total_score

        K, E, F = 2, 20, 100
        mapping = np.repeat(np.arange(K)[None, :], F, axis=0)
        features = np.random.uniform(size=(K, E, F))
        centroids = np.random.uniform(size=(K, E))

        ref_mapping, ref_total_score = reference(mapping, features, centroids)
        mapping, total_score = FPA.get_local_mapping(features, centroids)

        np.testing.assert_equal(mapping, ref_mapping)
        np.testing.assert_allclose(total_score, ref_total_score, rtol=1e-6)
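The brute-force search in the reference enumerates all K! permutations and is only feasible for small K (hence the assert). The same maximization is a linear assignment problem, so scipy.optimize.linear_sum_assignment solves it in polynomial time; a sketch of this alternative (not necessarily what FPA.get_local_mapping uses internally):

import numpy as np
from scipy.optimize import linear_sum_assignment

def get_local_mapping_lsa(x, mu):
    # x: (K, E) features of one frequency bin, mu: (K, E) centroids.
    similarity_matrix = np.einsum('ke,le->kl', x, mu)
    # Maximize sum(similarity_matrix[permutation, range(K)]).
    row, col = linear_sum_assignment(similarity_matrix, maximize=True)
    permutation = row[np.argsort(col)]
    score = similarity_matrix[permutation, np.arange(len(x))].sum()
    return tuple(permutation), score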
Example #6
def trainer_on_simulated_speech_data(
        Trainer=CACGMMTrainer,
        iterations=40,
        reverberation=False,
):
    reference_channel = 0
    sample_rate = 8000

    if reverberation:
        ex = reverberation_data()
    else:
        ex = low_reverberation_data()
    observation = ex['audio_data']['observation']
    Observation = stft(observation)
    num_samples = observation.shape[-1]

    Y_mm = rearrange(Observation, 'd t f -> f t d')

    t = Trainer()
    affiliation = t.fit(
        Y_mm,
        num_classes=3,
        iterations=iterations * 2,
        weight_constant_axis=-1,
    ).predict(Y_mm)
    
    pa = DHTVPermutationAlignment.from_stft_size(512)
    affiliation_pa = pa(rearrange(affiliation, 'f k t -> k f t'))
    affiliation_pa = rearrange(affiliation_pa, 'k f t -> k t f')

    Speech_image_0_est, Speech_image_1_est, Noise_image_est = (
        Observation[reference_channel, :, :] * affiliation_pa
    )

    speech_image_0_est = istft(Speech_image_0_est, num_samples=num_samples)
    speech_image_1_est = istft(Speech_image_1_est, num_samples=num_samples)
    noise_image_est = istft(Noise_image_est, num_samples=num_samples)

    ###########################################################################
    # Calculate the metrics

    speech_image = ex['audio_data']['speech_image']
    noise_image = ex['audio_data']['noise_image']
    speech_source = ex['audio_data']['speech_source']

    Speech_image = stft(speech_image)
    Noise_image = stft(noise_image)

    Speech_contribution = (
        Speech_image[:, reference_channel, None, :, :] * affiliation_pa
    )
    Noise_contribution = Noise_image[reference_channel, :, :] * affiliation_pa

    speech_contribution = istft(Speech_contribution, num_samples=num_samples)
    noise_contribution = istft(Noise_contribution, num_samples=num_samples)

    input_metric = InputMetrics(
        observation=observation,
        speech_source=speech_source,
        speech_image=speech_image,
        noise_image=noise_image,
        sample_rate=sample_rate,
    )

    output_metric = OutputMetrics(
        speech_prediction=np.array(
            [speech_image_0_est, speech_image_1_est, noise_image_est]),
        speech_source=speech_source,
        speech_contribution=speech_contribution,
        noise_contribution=noise_contribution,
        sample_rate=sample_rate,
    )

    return {
        'invasive_sxr_sdr': (
            output_metric.invasive_sxr['sdr']
            - input_metric.invasive_sxr['sdr'][:, reference_channel]
        ),
        'mir_eval_sxr_sdr': (
            output_metric.mir_eval['sdr']
            - input_metric.mir_eval['sdr'][:, reference_channel]
        ),
    }
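A hedged usage sketch, assuming the imports of the snippet above; both returned entries are SDR gains (output metric minus input metric), so positive values indicate successful separation:

gains = trainer_on_simulated_speech_data(
    Trainer=CACGMMTrainer,
    iterations=40,
    reverberation=False,
)
print(gains['invasive_sxr_sdr'])   # per-source invasive SDR gain in dB
print(gains['mir_eval_sxr_sdr'])   # per-source mir_eval SDR gain in dB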
Example #7
def run_inference(config):
    config['outdir'].mkdir(parents=True, exist_ok=True)

    # get device
    if config['use_gpu']:
        device = torch.device('cuda')
        # moving a tensor to GPU
        # useful at BUT cluster to prevent someone from getting the same GPU
        fake = torch.Tensor([1]).to(device)
    else:
        device = torch.device('cpu')

    # load dataset and GMM
    trans = lambda x: spec(x, **config['spectrum_conf'])
    dataset = JSONAudioMultichannelDataset(config['dataset'],
                                           transform=trans,
                                           i_split=config['i_split'],
                                           n_split=config['n_split'])
    model = read_gmm_from_h5(config['model'], device)

    # load noise model
    assert config['noise_model'].exists(), \
        f"Path {config['noise_model']} should exist."
    noise_model = get_noise_stats(
        config['noise_model'], None, None, None, None)

    for utt in tqdm(dataset.utts):
        # load data
        y = dataset[utt]
        logspec0 = logspec_from_spec(y[0])
        logspec0 = torch.tensor(logspec0.astype('float32')).to(device)

        # initialize permutation alignment
        pa = DHTVPermutationAlignment(stft_size=(logspec0.shape[-1] - 1) * 2,
                                      segment_start=70,
                                      segment_width=100,
                                      segment_shift=20,
                                      main_iterations=20,
                                      sub_iterations=2,
                                      similarity_metric='cos')

        # run the inference
        gmm_dolphin = GMMDolphin(model,
                                 noise_model=noise_model,
                                 device=device,
                                 inline_permutation_aligner=pa,
                                 **config['inference_conf'])
        gmm_dolphin.run(logspec0, y)

        # do final permutation alignment
        mask_perm = pa(gmm_dolphin.qd.detach().cpu().numpy().transpose(
            0, 2, 1))
        mask_perm = mask_perm.transpose(0, 2, 1)
        mask_perm = np.clip(mask_perm, 1e-6, 1 - 1e-6)

        # dump masks and q(Z): mostly for debug
        if config['dump_masks']:
            (config['outdir'] / 'masks').mkdir(exist_ok=True)
            with open(config['outdir'] / 'masks' / utt, 'wb') as f:
                pickle.dump({'mask': mask_perm}, f)
        if config['dump_qz']:
            (config['outdir'] / 'qz').mkdir(exist_ok=True)
            qz = [x.detach().cpu().numpy() for x in gmm_dolphin.qz]
            with open(config['outdir'] / 'qz' / utt, 'wb') as f:
                pickle.dump(qz, f)

        # beamforming and saving the audio
        for tgt_speaker in range(len(mask_perm)):
            target_mask = mask_perm[tgt_speaker]
            interf_mask = 1 - target_mask
            enh = Beamformer(**config['beamforming_conf'])(y, target_mask,
                                                           interf_mask)

            length = y[0].size
            s = inverse_spec(np.abs(enh), np.angle(enh), length,
                             **config['spectrum_conf'])
            s = s / np.max(np.abs(s) + 1e-6)
            sf.write(str(config['outdir'] / f'{utt}.{tgt_speaker}.wav'), s,
                     config['spectrum_conf']['sampling_freq'])
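The keys run_inference reads can be collected from its body. A hypothetical config sketch with placeholder values (the contents of spectrum_conf, inference_conf and beamforming_conf depend on spec, GMMDolphin and Beamformer and are not shown here):

from pathlib import Path

config = {
    'outdir': Path('exp/enhanced'),          # output directory
    'use_gpu': True,
    'dataset': 'data/dev.json',              # passed to JSONAudioMultichannelDataset
    'i_split': 0,
    'n_split': 1,
    'model': 'exp/gmm.h5',                   # read by read_gmm_from_h5
    'noise_model': Path('exp/noise_stats'),  # must exist, see assert above
    'dump_masks': False,
    'dump_qz': False,
    'spectrum_conf': {},      # kwargs of spec()/inverse_spec(); must include 'sampling_freq'
    'inference_conf': {},     # extra kwargs for GMMDolphin
    'beamforming_conf': {},   # kwargs for Beamformer
}
run_inference(config)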
Example #8
    def test_inverse_permutation(self):
        permutation = np.asarray([[3, 0, 2, 1]])
        inverse = np.asarray([[1, 3, 2, 0]])
        estimated_inverse = FPA.get_inverse_permutation(permutation)
        np.testing.assert_equal(estimated_inverse, inverse)
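For a single permutation the inverse is just np.argsort; a minimal sketch of what FPA.get_inverse_permutation plausibly computes, batched over the leading axis (assumed, not taken from the library):

import numpy as np

def get_inverse_permutation(permutation):
    # argsort along the last axis inverts each permutation:
    # inverse[..., permutation[..., i]] == i for all i.
    return np.argsort(permutation, axis=-1)

print(get_inverse_permutation(np.asarray([[3, 0, 2, 1]])))   # [[1 3 2 0]]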
Example #9
def run_inference(config):
    config['outdir'].mkdir(parents=True, exist_ok=True)

    # get device
    if config['use_gpu']:
        device = torch.device('cuda')
        # moving a tensor to GPU
        # useful at BUT cluster to prevent someone from getting the same GPU
        fake = torch.Tensor([1]).to(device)
    else:
        device = torch.device('cpu')

    # load dataset
    trans = lambda x: spec(x, **config['spectrum_conf'])
    dataset = JSONAudioMultichannelDataset(config['dataset'],
                                           transform=trans,
                                           i_split=config['i_split'],
                                           n_split=config['n_split'])

    for utt in tqdm(dataset.utts):
        # load data
        y = dataset[utt]

        # initialize permutation alignment
        pa = DHTVPermutationAlignment(stft_size=(y.shape[-1] - 1) * 2,
                                      segment_start=70,
                                      segment_width=100,
                                      segment_shift=20,
                                      main_iterations=20,
                                      sub_iterations=2,
                                      similarity_metric='cos')

        # load masks for the utterance
        try:
            m1 = sio.loadmat(config['maskdir'] / f'{utt}.spk1.mat')['mask']
            m2 = sio.loadmat(config['maskdir'] / f'{utt}.spk2.mat')['mask']
            m3 = sio.loadmat(config['maskdir'] / f'{utt}.spk3.mat')['mask']
        except sio.matlab.miobase.MatReadError:
            raise KeyError(f'Empty mask for utt {utt}')
        mask = np.stack((m1, m2, m3))

        # run the inference
        pit_dolphin = PITDolphin(mask,
                                 device=device,
                                 inline_permutation_aligner=pa,
                                 **config['inference_conf'])
        pit_dolphin.run(None, y)

        # do final permutation alignment
        mask_perm = pa(pit_dolphin.qd.detach().cpu().numpy().transpose(
            0, 2, 1))
        mask_perm = mask_perm.transpose(0, 2, 1)
        mask_perm = np.clip(mask_perm, 1e-6, 1 - 1e-6)

        # dump masks: mostly for debug
        if config['dump_masks']:
            (config['outdir'] / 'masks').mkdir(exist_ok=True)
            with open(config['outdir'] / 'masks' / utt, 'wb') as f:
                pickle.dump({'mask': mask_perm}, f)

        # beamforming and saving the audio
        for tgt_speaker in range(len(mask_perm)):
            target_mask = mask_perm[tgt_speaker]
            interf_mask = 1 - target_mask
            enh = Beamformer(**config['beamforming_conf'])(y, target_mask,
                                                           interf_mask)

            length = y[0].size
            s = inverse_spec(np.abs(enh), np.angle(enh), length,
                             **config['spectrum_conf'])
            s = s / np.max(np.abs(s) + 1e-6)
            sf.write(str(config['outdir'] / f'{utt}.{tgt_speaker}.wav'), s,
                     config['spectrum_conf']['sampling_freq'])
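Unlike Example #7, this variant starts from pre-computed per-speaker masks. The expected layout on disk, inferred from the loading code above, is one .mat file per speaker containing a 'mask' variable; a hypothetical writer that produces matching files (shapes and names are placeholders):

import numpy as np
import scipy.io as sio

T, F = 813, 257        # assumed (frames, frequency bins) of the spectrogram
utt = 'utt_0001'       # hypothetical utterance id
for spk in range(1, 4):
    mask = np.random.uniform(size=(T, F))
    sio.savemat(f'{utt}.spk{spk}.mat', {'mask': mask})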