Пример #1
0
def split_pair_to_vars(sample_batch_pair, emph_coeff=0.95):
    """
    Split a batch of [clean_signal, noisy_signal] pairs into CUDA Variables.

    The whole batch is first passed through ``emph.pre_emphasis`` with
    coefficient ``emph_coeff``. Three float32 Variables are then built and
    moved to the current CUDA device:
    - the full pre-emphasized batch of pairs,
    - the clean signals only,
    - the noisy signals only.

    Args:
        sample_batch_pair(torch.Tensor): CPU tensor (``.numpy()`` is called
            on it) holding a batch of [clean_signal, noisy_signal] pairs;
            ``pair[0]`` is the clean signal and ``pair[1]`` the noisy one.
            Presumably shaped (batch, 2, length), e.g. [40 x 2 x 16384]
            -- TODO confirm against the data loader.
        emph_coeff(float): pre-emphasis coefficient forwarded to
            ``emph.pre_emphasis``. Defaults to 0.95, the value this function
            previously hard-coded.

    Returns:
        batch_pairs_var(Variable): pre-emphasized batch of pairs
        clean_batch_var(Variable): clean signals; each one is flattened to a
            (1, -1) row and the rows are stacked -> (batch, 1, N)
        noisy_batch_var(Variable): noisy signals, same layout as the clean
            batch
    """

    def _to_cuda_var(array):
        # Wrap a numpy array as a float32 Variable on the current GPU.
        return Variable(
            torch.from_numpy(array).type(torch.FloatTensor)).cuda()

    def _stack_component(pairs, index):
        # Stack element ``index`` of every pair, flattening each to (1, -1).
        return np.stack([pair[index].reshape(1, -1) for pair in pairs])

    # pre-emphasis is applied to the whole batch before it is split
    emphasized = emph.pre_emphasis(sample_batch_pair.numpy(),
                                   emph_coeff=emph_coeff)
    batch_pairs_var = _to_cuda_var(emphasized)
    clean_batch_var = _to_cuda_var(_stack_component(emphasized, 0))
    noisy_batch_var = _to_cuda_var(_stack_component(emphasized, 1))
    return batch_pairs_var, clean_batch_var, noisy_batch_var
Пример #2
0
    def test_pre_emphasis(self):
        """
        Tests equality after de-emphasizing pre-emphasized signal.

        A batch of random integer signals is pre-emphasized and then
        de-emphasized, using the default coefficients of both functions.
        The round trip must reproduce the original batch in both shape and
        values.
        """
        # Fixed seed so that any failure can be reproduced exactly.
        rng = np.random.RandomState(0)
        rand_signal_batch = rng.randint(low=1, high=10, size=(10, 1, 400))
        reconst_batch = emph.de_emphasis(emph.pre_emphasis(rand_signal_batch))

        # after de-emphasis, the signal must have been restored
        self.assertEqual(rand_signal_batch.shape, reconst_batch.shape)
        # assert_allclose reports the mismatching elements on failure.
        # The argument order and rtol/atol reproduce the previous
        # np.allclose(rand_signal_batch, reconst_batch) check.
        np.testing.assert_allclose(rand_signal_batch, reconst_batch,
                                   rtol=1e-5, atol=1e-8)