Example #1
0
def test_shapes(ndim):
    """Loss output shape must equal the input shape with the time axis removed."""
    # Random leading (loss) dimensions, fixed waveform length of 10000 samples.
    leading_dims = [random.randint(1, 4) for _ in range(ndim - 1)]
    wave_shape = leading_dims + [10000]
    clean = torch.randn(wave_shape)
    estimate = torch.randn(wave_shape)
    loss_values = NegSTOILoss()(estimate, clean)
    assert loss_values.shape == clean.shape[:-1]
def test_more_than_minusone(use_vad, extended, iteration):
    """Negative STOI must stay strictly above -1 (since STOI itself is <= 1)."""
    loss_func = NegSTOILoss(
        sample_rate=10000, use_vad=use_vad, extended=extended
    )
    clean = torch.randn(1, 16000)
    estimate = torch.randn(1, 16000)
    assert (loss_func(estimate, clean) > -1).all()
Example #3
0
def test_batchonly_equal(use_vad, extended):
    """Slicing away batch dimensions must commute with computing the loss."""
    loss_func = NegSTOILoss(use_vad=use_vad, extended=extended)
    clean = torch.randn(3, 2, 16000)
    estimate = torch.randn(3, 2, 16000)
    # Same loss computed on 3-D, 2-D and 1-D views of the same data.
    full_batch = loss_func(estimate, clean)
    sub_batch = loss_func(estimate[:, 0], clean[:, 0])
    single = loss_func(estimate[0, 0], clean[0, 0])
    assert_allclose(single, sub_batch[0])
    assert_allclose(sub_batch, full_batch[:, 0])
Example #4
0
def test_forward(sample_rate, use_vad, extended):
    """A forward pass yields exactly one loss value per batch item."""
    loss_func = NegSTOILoss(
        sample_rate=sample_rate, use_vad=use_vad, extended=extended
    )
    n_batch = 3
    n_samples = 2 * sample_rate  # two seconds of audio
    estimate = torch.randn(n_batch, n_samples)
    clean = torch.randn(n_batch, n_samples)
    loss_values = loss_func(estimate, clean)
    assert loss_values.ndim == 1
    assert loss_values.shape[0] == n_batch
Example #5
0
def test_getbetter(use_vad, extended):
    """Loss must decrease monotonically as the estimate approaches the target.

    Bug fix: the original loop did ``if old_val is None: continue`` BEFORE
    updating ``old_val``, so ``old_val`` stayed ``None`` on every iteration
    and the ``assert new_val < old_val`` line was never executed.  The loop
    is restructured so the comparison runs from the second iteration on and
    ``old_val`` is updated every pass.
    """
    loss_func = NegSTOILoss(use_vad=use_vad, extended=extended)
    targets = torch.randn(1, 16000)
    old_val = None
    # Decreasing noise levels: each estimate is closer to the target.
    for eps in [5, 2, 1, 0.5, 0.1, 0.01]:
        est_targets = targets + eps * torch.randn(1, 16000)
        new_val = loss_func(est_targets, targets).mean()
        # No previous value to compare against on the first iteration.
        if old_val is not None:
            assert new_val < old_val
        old_val = new_val
Example #6
0
def make_line(clean, enh, fs):
    """Build one comparison row: NumPy STOI/ESTOI plus all four PyTorch variants."""
    # Reference values from the NumPy implementation.
    row = [stoi(clean, enh, fs), stoi(clean, enh, fs, extended=True)]
    # Convert once; the same tensors are reused for every configuration.
    enh_t = torch.from_numpy(enh)
    clean_t = torch.from_numpy(clean)
    # PyTorch values, negated because NegSTOILoss returns minus STOI.
    for use_vad in [True, False]:
        for extended in [True, False]:
            loss = NegSTOILoss(
                sample_rate=fs, use_vad=use_vad, extended=extended
            )
            row.append(-loss(enh_t, clean_t).item())
    return row
Example #7
0
def test_backward(sample_rate, use_vad, extended):
    """Gradients must flow from the loss back to the network parameters."""
    net = TestNet()
    loss_func = NegSTOILoss(
        sample_rate=sample_rate, use_vad=use_vad, extended=extended
    )
    n_batch = 3
    mix = torch.randn(n_batch, 2 * sample_rate, requires_grad=True)
    targets = torch.randn(n_batch, 2 * sample_rate)
    loss_func(net(mix), targets).mean().backward()
    # Backward should have populated gradients on the conv layer.
    assert net.conv.weight.grad is not None
    assert net.conv.bias.grad is not None