Example #1
0
 def forward(self, x):
     """Encode the input, estimate a TF mask and return the masked encoding.

     Args:
         x: Input passed to ``self.encoder``. When it has exactly two
             dimensions, a singleton dimension is inserted at index 1
             first (presumably the channel axis -- confirm with callers).

     Returns:
         The encoder output with the estimated mask applied to it.
     """
     # Promote 2-D inputs to 3-D before encoding.
     x = x.unsqueeze(1) if len(x.shape) == 2 else x

     # Compute STFT
     encoded = self.encoder(x)

     # Choose how masker features are built and how the mask is applied:
     # cat([re, im, mag]) + real mask in the complex case, magnitude +
     # magnitude mask otherwise.
     if self.is_complex:
         featurize, mask_fn = take_cat, apply_real_mask
     else:
         featurize, mask_fn = take_mag, apply_mag_mask

     # LSTM masker expects a feature dimension last (not like 1D conv),
     # so swap dims 1 and 2 on the way in and swap them back on the way out.
     features = featurize(encoded).transpose(1, 2)
     masks = self.masker(features).transpose(1, 2)

     # Apply TF mask
     return mask_fn(encoded, masks)
Example #2
0
def test_cat(encoder_list):
    """Check the output shape of ``transforms.take_cat`` for each encoder.

    ``encoder_list`` yields ``(encoder, fb_dim)`` pairs; only the encoder is
    used here. For an encoder output of shape ``(batch, freq, time)``, the
    result of ``take_cat(..., dim=1)`` must have ``3 * (freq // 2)`` entries
    along dim 1 and leave the other dimensions untouched.
    """
    for encoder, _fb_dim in encoder_list:
        # Encoder output is laid out as [batch, freq, time].
        encoded = encoder(torch.randn(2, 1, 16000))
        n_batch, n_freq, n_frames = encoded.shape
        expected_shape = (n_batch, 3 * (n_freq // 2), n_frames)

        concatenated = transforms.take_cat(encoded, dim=1)
        assert concatenated.shape == expected_shape