Example #1
0
 def forward(self, x):
     """Encode ``x``, estimate a time-frequency mask and apply it.

     Args:
         x (torch.Tensor): Input signal. A 2D input gets a singleton axis
             inserted at dim 1 before encoding (presumably
             (batch, time) -> (batch, 1, time); confirm against callers).

     Returns:
         The encoder output with the estimated mask applied, through
         ``apply_real_mask`` when ``self.is_complex`` is set and through
         ``apply_mag_mask`` otherwise.
     """
     signal = x.unsqueeze(1) if x.ndim == 2 else x
     # Time-frequency representation (an STFT, per the original comments).
     spec = self.encoder(signal)

     # The same flag decides both the masker's input features and how the
     # resulting mask is applied, so select them together.
     if self.is_complex:
         featurize, mask_fn = magreim, apply_real_mask
     else:
         featurize, mask_fn = mag, apply_mag_mask

     # The masker takes the feature axis last (an LSTM-style masker, per the
     # original note), so swap axes 1 and 2 going in and swap them back out.
     masks = self.masker(featurize(spec).transpose(1, 2)).transpose(1, 2)
     return mask_fn(spec, masks)
Example #2
0
    def forward_masker(self, tf_rep):
        """Run the masker on (optionally transformed) time-frequency features.

        Args:
            tf_rep (torch.Tensor): Time-frequency representation in
                (batch, freq, seq).

        Returns:
            torch.Tensor: Estimated masks in (batch, freq, seq). When
            ``self.output_type == "mag"`` the masker output is tiled twice
            along dim 1.
        """
        # "mag" -> magnitude features, "cat" -> magreim features; any other
        # input_type feeds the representation to the masker unchanged.
        if self.input_type == "mag":
            features = mag(tf_rep)
        elif self.input_type == "cat":
            features = magreim(tf_rep)
        else:
            features = tf_rep

        masks = self.masker(features)
        # A magnitude-only mask is duplicated along dim 1, presumably to cover
        # stacked real/imaginary halves -- confirm against the mask consumer.
        return masks.repeat(1, 2, 1) if self.output_type == "mag" else masks
Example #3
0
def test_cat(encoder_list):
    """Check the frequency size of ``transforms.magreim`` output.

    ``encoder_list`` yields ``(encoder, fb_dim)`` pairs; only the encoder is
    needed here. For an encoder output of shape (batch, freq, time), the
    concatenated features (presumably [re, im, mag]) are expected to have
    ``3 * (freq // 2)`` entries along dim 1.
    """
    for enc, _ in encoder_list:
        tf_rep = enc(torch.randn(2, 1, 16000))  # [batch, freq, time]
        batch, freq, time = tf_rep.shape
        # Not named ``mag``: this holds the concatenated features rather than
        # a magnitude, and the name would be confused with the ``mag`` transform.
        cat_rep = transforms.magreim(tf_rep, dim=1)
        # Report the offending encoder when the shape check fails.
        assert cat_rep.shape == (batch, 3 * (freq // 2), time), enc