def test_shapes(ndim):
    """Check that the loss reduces only the last (time) axis of its inputs.

    Builds ``ndim``-dimensional inputs whose leading sizes are random in
    [1, 4] and whose last axis has 10000 samples. The loss output must have
    exactly the leading shape.
    """
    leading_sizes = [random.randint(1, 4) for _ in range(ndim - 1)]
    full_shape = [*leading_sizes, 10000]
    reference = torch.randn(full_shape)
    estimate = torch.randn(full_shape)
    criterion = NegSTOILoss()
    per_item = criterion(estimate, reference)
    assert per_item.shape == reference.shape[:-1]
def test_more_than_minusone(use_vad, extended, iteration):
    """Check that every loss value on random inputs is strictly above -1.

    ``iteration`` is not used in the body; presumably it exists only so the
    test is repeated via parametrization (confirm against the decorator).
    """
    criterion = NegSTOILoss(sample_rate=10000, use_vad=use_vad, extended=extended)
    # Draw the reference first, then the estimate, matching the original RNG order.
    reference, estimate = torch.randn(1, 16000), torch.randn(1, 16000)
    values = criterion(estimate, reference)
    assert (values > -1).all()
def test_batchonly_equal(use_vad, extended):
    """Check that leading batch axes do not change the per-signal loss.

    The loss is evaluated on a (3, 2, T) batch, on its (3, T) slice, and on a
    single (T,) signal. The overlapping entries must agree.
    """
    criterion = NegSTOILoss(use_vad=use_vad, extended=extended)
    reference = torch.randn(3, 2, 16000)
    estimate = torch.randn(3, 2, 16000)

    loss_3d = criterion(estimate, reference)
    loss_2d = criterion(estimate[:, 0], reference[:, 0])
    loss_1d = criterion(estimate[0, 0], reference[0, 0])

    assert_allclose(loss_1d, loss_2d[0])
    assert_allclose(loss_2d, loss_3d[:, 0])
def test_forward(sample_rate, use_vad, extended):
    """Check that a (batch, time) input yields a 1-D loss with one value per item."""
    criterion = NegSTOILoss(sample_rate=sample_rate, use_vad=use_vad, extended=extended)
    n_items = 3
    n_samples = 2 * sample_rate  # two seconds of signal at ``sample_rate``
    estimate = torch.randn(n_items, n_samples)
    reference = torch.randn(n_items, n_samples)

    result = criterion(estimate, reference)

    assert result.ndim == 1
    assert result.shape[0] == n_items
def test_getbetter(use_vad, extended):
    """Check that the loss decreases as the estimate gets closer to the target.

    The estimate is the target plus Gaussian noise scaled by a shrinking
    ``eps``. Each successive (less noisy) estimate must score a strictly lower
    mean loss than the previous one.
    """
    loss_func = NegSTOILoss(use_vad=use_vad, extended=extended)
    targets = torch.randn(1, 16000)
    old_val = None
    for eps in [5, 2, 1, 0.5, 0.1, 0.01]:
        est_targets = targets + eps * torch.randn(1, 16000)
        new_val = loss_func(est_targets, targets).mean()
        # There is nothing to compare against on the first iteration. The
        # previous value must still be recorded below; the old `continue`
        # skipped that assignment, so no comparison ever ran.
        if old_val is not None:
            assert new_val < old_val
        old_val = new_val
def make_line(clean, enh, fs):
    """Collect STOI scores for one clean/enhanced signal pair.

    Args:
        clean: reference signal, passed to ``stoi`` and ``torch.from_numpy``
            (so presumably a NumPy array).
        enh: enhanced/estimated signal, same handling as ``clean``.
        fs: sampling rate forwarded to both implementations.

    Returns:
        A list containing ``stoi(clean, enh, fs)``, then
        ``stoi(..., extended=True)``, then the negated ``NegSTOILoss`` value
        for each (use_vad, extended) pair, in the order (True, True),
        (True, False), (False, True), (False, False).
    """
    row = [stoi(clean, enh, fs), stoi(clean, enh, fs, extended=True)]
    clean_t = torch.from_numpy(clean)
    enh_t = torch.from_numpy(enh)
    # The loss is a negated score, so flip its sign to compare with stoi.
    row.extend(
        -NegSTOILoss(sample_rate=fs, use_vad=use_vad, extended=extended)(enh_t, clean_t).item()
        for use_vad in (True, False)
        for extended in (True, False)
    )
    return row
def test_backward(sample_rate, use_vad, extended):
    """Check that the loss back-propagates into a network's parameters.

    After ``backward()``, both ``conv.weight`` and ``conv.bias`` of the
    ``TestNet`` instance must hold gradients.
    """
    model = TestNet()
    criterion = NegSTOILoss(sample_rate=sample_rate, use_vad=use_vad, extended=extended)
    n_items = 3
    n_samples = 2 * sample_rate
    mixture = torch.randn(n_items, n_samples, requires_grad=True)
    reference = torch.randn(n_items, n_samples)

    prediction = model(mixture)
    criterion(prediction, reference).mean().backward()

    for param in (model.conv.weight, model.conv.bias):
        assert param.grad is not None