def forward(self, x):
    """Encode the input, estimate a time-frequency mask and apply it.

    Args:
        x: Input tensor. A 2-D input gets a singleton dimension inserted
            at index 1 before being passed to ``self.encoder``.

    Returns:
        The encoder output with the estimated mask applied
        (``apply_real_mask`` when ``self.is_complex`` is true,
        ``apply_mag_mask`` otherwise).
    """
    if x.ndim == 2:
        x = x.unsqueeze(1)

    # Compute STFT
    spec = self.encoder(x)

    # The complex/magnitude switch drives both the masker features
    # (cat([re, im, mag]) vs. magnitude only) and how the mask is applied.
    if self.is_complex:
        features = magreim(spec)
        mask_fn = apply_real_mask
    else:
        features = mag(spec)
        mask_fn = apply_mag_mask

    # LSTM masker expects the feature dimension last (unlike 1D conv), so
    # swap dims 1 and 2 on the way in and swap them back on the way out.
    features_last = features.transpose(1, 2)
    masks_last = self.masker(features_last)
    est_masks = masks_last.transpose(1, 2)

    # Apply TF mask
    return mask_fn(spec, est_masks)
def forward_masker(self, tf_rep):
    """Estimate masks from a time-frequency representation.

    The masker input depends on ``self.input_type``: ``"mag"`` feeds
    ``mag(tf_rep)``, ``"cat"`` feeds ``magreim(tf_rep)``, and any other
    value feeds ``tf_rep`` unchanged.

    Args:
        tf_rep (torch.Tensor): Time-frequency representation in
            (batch, freq, seq).

    Returns:
        torch.Tensor: Estimated masks in (batch, freq, seq).
    """
    input_kind = self.input_type
    if input_kind == "mag":
        features = mag(tf_rep)
    elif input_kind == "cat":
        features = magreim(tf_rep)
    else:
        features = tf_rep

    masks = self.masker(features)

    # A magnitude-type output is tiled twice along dim 1.
    return masks.repeat(1, 2, 1) if self.output_type == "mag" else masks
def test_cat(encoder_list):
    """Check the output shape of ``transforms.magreim`` for each encoder.

    For every ``(encoder, fb_dim)`` pair, dim 1 of the result must have
    size ``3 * (freq // 2)`` while the other two dims are unchanged.
    """
    for enc, _ in encoder_list:
        tf_rep = enc(torch.randn(2, 1, 16000))  # [batch, freq, time]
        n_batch, n_freq, n_frames = tf_rep.shape
        stacked = transforms.magreim(tf_rep, dim=1)
        expected_shape = (n_batch, 3 * (n_freq // 2), n_frames)
        assert stacked.shape == expected_shape