Esempio n. 1
0
def test_griffinlim(fb_config, feed_istft, feed_angle):
    """Smoke-test ``griffin_lim`` on a randomly masked STFT representation.

    ``fb_config`` is forwarded as keyword arguments to ``STFTFB``.
    ``feed_istft`` toggles passing an explicit ``Decoder`` as ``istft_dec``.
    ``feed_angle`` toggles passing initial phases as ``angles``.
    The call result is not inspected; the test only checks that it runs.
    (Presumably the three arguments are pytest fixtures/parametrizations —
    confirm against the test module.)
    """
    encoder = Encoder(STFTFB(**fb_config))
    decoder = Decoder(STFTFB(**fb_config)) if feed_istft else None
    signal = torch.randn(2, 1, 8000)
    tf_rep = encoder(signal)
    # Elementwise random mask with values in (0, 1) via sigmoid.
    soft_mask = torch.sigmoid(torch.randn_like(tf_rep))
    masked = tf_rep * soft_mask
    # Axis -2 is passed as the complex axis to mag/angle (assumed layout — TODO confirm).
    magnitude = transforms.mag(masked, -2)
    phase = transforms.angle(masked, -2) if feed_angle else None
    griffin_lim(magnitude, encoder, angles=phase, istft_dec=decoder, n_iter=3)
Esempio n. 2
0
    def apply_masks(self, tf_rep, est_masks):
        """Combine ``est_masks`` with ``tf_rep`` according to ``self.target``.

        - ``"TMS"``: ``est_masks`` is combined with the phase of ``tf_rep``
          through ``from_magphase``.
        - ``"TCS"``: ``est_masks`` is returned unchanged.
        - any other value: ``apply_complex_mask(tf_rep, est_masks)``.
        """
        target = self.target
        if target == "TMS":
            phase = angle(tf_rep)
            return from_magphase(est_masks, phase)
        if target == "TCS":
            return est_masks
        return apply_complex_mask(tf_rep, est_masks)
Esempio n. 3
0
def test_angle_mag_recompostion(dim):
    """Check the complex -> (mag, angle) -> complex round trip along ``dim``.

    A random 4-D tensor is built, with its size along ``dim`` doubled so that
    the complex dimension is even. It is split with ``transforms.angle`` and
    ``transforms.mag``, then rebuilt with ``transforms.from_magphase``, and the
    result is compared with the input.
    """
    n_dims = 4
    shape = [random.randint(1, 10) for _ in range(n_dims)]
    # Doubling guarantees an even size on the complex dimension.
    shape[dim] *= 2
    original = torch.randn(shape)
    phase = transforms.angle(original, dim=dim)
    magnitude = transforms.mag(original, dim=dim)
    rebuilt = transforms.from_magphase(magnitude, phase, dim=dim)
    assert_allclose(original, rebuilt)
Esempio n. 4
0
def test_misi(fb_config, feed_istft, feed_angle):
    """Run ``misi`` on masked copies of one mixture and check the output shape.

    ``fb_config`` is forwarded as keyword arguments to ``STFTFB``.
    ``feed_istft`` toggles passing an explicit ``Decoder`` as ``istft_dec``.
    ``feed_angle`` toggles passing initial phases as ``angles``.
    The leading dimensions of the estimate must be ``(batch, n_src)``.
    """
    encoder = Encoder(STFTFB(**fb_config))
    decoder = Decoder(STFTFB(**fb_config)) if feed_istft else None
    n_src = 3
    batch = 2
    # Build the mixture and its spectrogram with an extra singleton axis 1.
    mixture = torch.randn(batch, 1, 8000)
    mix_spec = encoder(mixture).unsqueeze(1)
    # One random sigmoid mask per source, stacked along axis 1.
    mask_shape = list(mix_spec.shape)
    mask_shape[1] = mask_shape[1] * n_src
    masks = torch.sigmoid(torch.randn(*mask_shape))
    src_specs = mix_spec * masks
    # Split into magnitude and (optionally) phase along axis -2.
    magnitudes = transforms.mag(src_specs, -2)
    phases = transforms.angle(src_specs, -2) if feed_angle else None
    est_wavs = misi(
        mixture, magnitudes, encoder, angles=phases, istft_dec=decoder, n_iter=2
    )
    # The last dim is unknown because ISTFT(STFT()) cuts the end.
    assert est_wavs.shape[:-1] == (batch, n_src)
Esempio n. 5
0
 def apply_masks(self, tf_rep, est_masks):
     """Rebuild a representation from ``est_masks`` and the phase of ``tf_rep``.

     ``est_masks`` is passed as the magnitude argument of ``from_magphase``.
     The phase comes from ``angle(tf_rep)``.
     """
     phase = angle(tf_rep)
     rebuilt = from_magphase(est_masks, phase)
     return rebuilt
Esempio n. 6
0
 def forward_masker(self, tf_rep):
     """Pass the squared magnitude through ``forward_vae_mu_logvar`` and re-attach phase.

     Only the first of the three returned values is used. Its square root is
     taken as the magnitude and combined with the phase of ``tf_rep`` via
     ``from_magphase``.
     """
     power = torch.pow(mag(tf_rep), 2)
     # Exactly three values are expected; only the first is kept.
     estimate, _unused_a, _unused_b = self.forward_vae_mu_logvar(power)
     est_mag = torch.sqrt(estimate)
     return from_magphase(est_mag, angle(tf_rep))