Example #1
def test_PITSolver_tf_ce_forward(num_spk, flexible_numspk):

    batch = 2
    ncls = 100
    ref = [torch.randint(0, ncls, (batch, 10)) for spk in range(num_spk)]
    # Bias each inference toward its reference labels (one-hot, padded to ncls)
    # so that the identity permutation gives the best match.
    bias = [F.one_hot(y) for y in ref]
    bias = [F.pad(y, (0, ncls - y.size(-1))) for y in bias]
    inf = [torch.rand(batch, 10, ncls) + bias[spk] for spk in range(num_spk)]
    solver = PITSolver(
        FrequencyDomainCrossEntropy(),
        independent_perm=True,
        flexible_numspk=flexible_numspk,
    )

    loss, stats, others = solver(ref, inf)
    perm = others["perm"]
    correct_perm = list(range(num_spk))
    assert perm[0].equal(torch.tensor(correct_perm)), (perm, correct_perm)

    # test with independent_perm=False

    solver = PITSolver(
        FrequencyDomainCrossEntropy(),
        independent_perm=False,
        flexible_numspk=flexible_numspk,
    )
    loss, stats, others = solver(ref, inf, {"perm": perm})
Example #2
def test_criterion_behavior_dereverb(loss_type, num_spk):
    inputs = torch.randn(2, 300)
    ilens = torch.LongTensor([300, 200])
    speech_refs = [torch.randn(2, 300).float() for _ in range(num_spk)]
    dereverb_ref = [torch.randn(2, 300).float() for _ in range(num_spk)]
    beamformer = NeuralBeamformer(
        input_dim=17,
        loss_type=loss_type,
        num_spk=num_spk,
        use_wpe=True,
        wlayers=2,
        wunits=2,
        wprojs=2,
        use_dnn_mask_for_wpe=True,
        multi_source_wpe=True,
        use_beamformer=True,
        blayers=2,
        bunits=2,
        bprojs=2,
        badim=2,
        ref_channel=0,
        use_noise_mask=False,
    )
    if loss_type == "mask_mse":
        loss_wrapper = PITSolver(
            criterion=FrequencyDomainMSE(
                compute_on_mask=True, mask_type="PSM", is_dereverb_loss=True
            )
        )
    else:
        loss_wrapper = PITSolver(criterion=SISNRLoss(is_dereverb_loss=True))
    enh_model = ESPnetEnhancementModel(
        encoder=stft_encoder,
        separator=beamformer,
        decoder=stft_decoder,
        mask_module=None,
        loss_wrappers=[loss_wrapper],
    )

    enh_model.train()

    kwargs = {
        "speech_mix": inputs,
        "speech_mix_lengths": ilens,
        **{"speech_ref{}".format(i + 1): speech_refs[i] for i in range(num_spk)},
        "dereverb_ref1": dereverb_ref[0],
    }
    loss, stats, weight = enh_model(**kwargs)
Example #3
    def __init__(
        self,
        criterion: AbsEnhLoss,
        weight=1.0,
        independent_perm=True,
    ):
        """Multi-Layer Permutation Invariant Training Solver.

        Compute the PIT loss given inferences of multiple layers and a single reference.
        It also supports a single inference and a single reference at the evaluation stage.

        Args:
            criterion (AbsEnhLoss): an instance of AbsEnhLoss
            weight (float): weight (between 0 and 1) of the current loss
                for multi-task learning.
            independent_perm (bool):
                If True, PIT is performed in forward() to find the best permutation;
                if False, the permutation from the last LossWrapper output is
                inherited.
                Note: be careful about the ordering of loss wrappers defined in
                the YAML config when this argument is False.
        """
        super().__init__()
        self.criterion = criterion
        self.weight = weight
        self.independent_perm = independent_perm
        self.solver = PITSolver(criterion, weight, independent_perm)
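A minimal usage sketch of the behavior described in the docstring above: a second wrapper with independent_perm=False inherits the permutation found by a first wrapper, following the solver(ref, inf, {"perm": perm}) call convention shown in Examples #1 and #6. The import paths are assumed from ESPnet2's layout and are not part of this excerpt.

# Sketch only (not part of the original file); import paths are assumptions.
import torch
from espnet2.enh.loss.criterions.tf_domain import FrequencyDomainL1
from espnet2.enh.loss.wrappers.pit_solver import PITSolver

num_spk, batch = 2, 2
inf = [torch.rand(batch, 10, 100) for _ in range(num_spk)]
ref = list(reversed(inf))  # references in reversed order

# The first wrapper searches for the best permutation itself.
first = PITSolver(FrequencyDomainL1(), independent_perm=True)
loss1, stats1, others = first(ref, inf)

# The second wrapper reuses that permutation instead of re-searching.
second = PITSolver(FrequencyDomainL1(), independent_perm=False)
loss2, stats2, _ = second(ref, inf, {"perm": others["perm"]})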
Example #4
def test_criterion_behavior(training):
    inputs = torch.randn(2, 300)
    ilens = torch.LongTensor([300, 200])
    speech_refs = [torch.randn(2, 300).float(), torch.randn(2, 300).float()]
    enh_model = ESPnetEnhancementModel(
        encoder=stft_encoder,
        separator=rnn_separator,
        decoder=stft_decoder,
        mask_module=None,
        loss_wrappers=[PITSolver(criterion=SISNRLoss(only_for_test=True))],
    )

    if training:
        enh_model.train()
    else:
        enh_model.eval()

    kwargs = {
        "speech_mix": inputs,
        "speech_mix_lengths": ilens,
        **{"speech_ref{}".format(i + 1): speech_refs[i]
           for i in range(2)},
    }

    if training:
        with pytest.raises(AttributeError):
            loss, stats, weight = enh_model(**kwargs)
    else:
        loss, stats, weight = enh_model(**kwargs)
Example #5
def test_criterion_behavior_noise(encoder, decoder, separator):
    if not isinstance(encoder, STFTEncoder) and isinstance(
        separator, (DCCRNSeparator, DC_CRNSeparator)
    ):
        # skip because DCCRNSeparator and DC_CRNSeparator only work
        # for complex spectrum features
        return
    inputs = torch.randn(2, 300)
    ilens = torch.LongTensor([300, 200])
    speech_refs = [torch.randn(2, 300).float(), torch.randn(2, 300).float()]
    noise_ref = torch.randn(2, 300)
    enh_model = ESPnetEnhancementModel(
        encoder=encoder,
        separator=separator,
        decoder=decoder,
        mask_module=None,
        loss_wrappers=[PITSolver(criterion=SISNRLoss(is_noise_loss=True))],
    )

    enh_model.train()

    kwargs = {
        "speech_mix": inputs,
        "speech_mix_lengths": ilens,
        **{"speech_ref{}".format(i + 1): speech_refs[i] for i in range(2)},
        "noise_ref1": noise_ref,
    }
    loss, stats, weight = enh_model(**kwargs)
Example #6
def test_PITSolver_forward(num_spk):

    batch = 2
    inf = [torch.rand(batch, 10, 100) for spk in range(num_spk)]
    ref = [inf[num_spk - spk - 1]
           for spk in range(num_spk)]  # reverse inf as ref
    solver = PITSolver(FrequencyDomainL1(), independent_perm=True)

    loss, stats, others = solver(ref, inf)
    perm = others["perm"]
    correct_perm = list(range(num_spk))
    correct_perm.reverse()
    assert perm[0].equal(torch.tensor(correct_perm))

    # test with independent_perm=False

    solver = PITSolver(FrequencyDomainL1(), independent_perm=False)
    loss, stats, others = solver(ref, inf, {"perm": perm})
Example #7
    def cal_permumation(self, ref_wavs, enh_wavs, criterion="si_snr"):
        """Calculate the permutation between seaprated streams in two adjacent segments.

        Args:
            ref_wavs (List[torch.Tensor]): [(Batch, Nsamples)]
            enh_wavs (List[torch.Tensor]): [(Batch, Nsamples)]
            criterion (str): one of ("si_snr", "mse", "corr")
        Returns:
            perm (torch.Tensor): permutation for enh_wavs (Batch, num_spk)
        """

        criterion_class = {"si_snr": SISNRLoss, "mse": FrequencyDomainMSE}[criterion]

        pit_solver = PITSolver(criterion=criterion_class())

        _, _, others = pit_solver(ref_wavs, enh_wavs)
        perm = others["perm"]
        return perm
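The permutation returned above indicates, per batch element, which enhanced stream corresponds to each reference. Below is a minimal sketch of applying a (Batch, num_spk) permutation tensor to reorder the enhanced streams; the helper reorder_by_perm is hypothetical and not part of the original class.

# Hypothetical helper (not in the original class): reorder enhanced streams
# with the (Batch, num_spk) permutation returned by cal_permumation.
import torch

def reorder_by_perm(enh_wavs, perm):
    # enh_wavs: list of num_spk tensors, each of shape (Batch, Nsamples)
    # perm: LongTensor of shape (Batch, num_spk)
    stacked = torch.stack(enh_wavs, dim=1)          # (Batch, num_spk, Nsamples)
    index = perm.unsqueeze(-1).expand_as(stacked)   # broadcast perm over samples
    reordered = torch.gather(stacked, dim=1, index=index)
    return [reordered[:, spk] for spk in range(reordered.size(1))]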
Example #8
    bidirectional=False,
    use_cbn=False,
    kernel_size=5,
    kernel_num=[
        32,
        64,
        128,
    ],
    use_builtin_complex=True,
    use_noise_mask=False,
)
si_snr_loss = SISNRLoss()
tf_mse_loss = FrequencyDomainMSE()
tf_l1_loss = FrequencyDomainL1()

pit_wrapper = PITSolver(criterion=si_snr_loss)
fix_order_solver = FixedOrderSolver(criterion=tf_mse_loss)


@pytest.mark.parametrize(
    "encoder, decoder",
    [
        (stft_encoder, stft_decoder),
        (stft_encoder_bultin_complex, stft_decoder),
        (conv_encoder, conv_decoder),
    ],
)
@pytest.mark.parametrize(
    "separator",
    [
        rnn_separator,