Пример #1
0
def test_mag_mask(encoder_list):
    """Check that an all-ones magnitude mask leaves the encoder output unchanged.

    Args:
        encoder_list: iterable of ``(encoder, fb_dim)`` pairs. Each encoder
            is called on a random tensor of shape ``(2, 1, 8000)``.
    """
    for encoder, n_filters in encoder_list:
        mixture = torch.randn(2, 1, 8000)
        encoded = encoder(mixture)  # [batch, freq, time]
        # Identity mask: half as many channels as fb_dim, with singleton
        # batch and time dims (presumably broadcast by apply_mag_mask).
        ones_mask = torch.ones((1, n_filters // 2, 1))
        result = transforms.apply_mag_mask(encoded, ones_mask, dim=1)
        assert_allclose(result, encoded)
Пример #2
0
 def dc_head_separate(self, x):
     """Cluster the masker's embeddings into binary masks and decode waveforms.

     Args:
         x: input tensor. A 2-D input gets a dimension inserted at
             position 1 before it is encoded.

     Returns:
         A ``(wavs, dic_out)`` tuple. ``wavs`` is the decoder output passed
         through ``pad_x_to_y`` together with ``x``. ``dic_out`` maps
         ``tfrep``, ``mask`` (the masker's own mask output, not the
         clustered binary masks), ``masked_tfrep`` and ``proj``.
     """
     clusterer = KMeans(n_clusters=self.masker.n_src)
     if x.dim() == 2:
         x = x.unsqueeze(1)
     tf_rep = self.encoder(x)
     mag_spec = take_mag(tf_rep)
     proj, mask_out = self.masker(mag_spec)
     # Only the bins selected by ebased_vad take part in the clustering.
     active_bins = ebased_vad(mag_spec)
     # NOTE(review): the view(1, -1) indexing looks like it assumes a batch
     # of one -- confirm against callers.
     active_proj = proj[active_bins.view(1, -1)]
     embeddings = active_proj.cpu().data.numpy()
     bin_clusters = clusterer.fit_predict(embeddings)
     # Build one binary mask per source.
     est_mask_list = []
     for src_idx in range(self.masker.n_src):
         # Start with ones in every inactive bin, then fill the active bins
         # with the cluster membership for this source.
         src_mask = ~active_bins
         in_cluster = torch.from_numpy((bin_clusters == src_idx))
         src_mask[active_bins] = in_cluster.to(src_mask.device)
         est_mask_list.append(src_mask.float())  # float is needed, not bool
     # Mask the representation and go back to the time domain.
     est_masks = torch.stack(est_mask_list, dim=1)
     masked = apply_mag_mask(tf_rep, est_masks)
     decoded = self.decoder(masked)
     wavs = pad_x_to_y(decoded, x)
     dic_out = {
         "tfrep": tf_rep,
         "mask": mask_out,
         "masked_tfrep": masked,
         "proj": proj,
     }
     return wavs, dic_out
Пример #3
0
 def separate(self, x):
     """Separate with the mask-inference head and return waveforms.

     Args:
         x: input tensor. A 2-D input gets a dimension inserted at
             position 1 before it is encoded.

     Returns:
         A ``(wavs, dic_out)`` tuple. ``wavs`` is the decoder output passed
         through ``torch_utils.pad_x_to_y`` together with ``x``.
         ``dic_out`` maps ``tfrep``, ``mask``, ``masked_tfrep`` and ``proj``.
     """
     if x.dim() == 2:
         x = x.unsqueeze(1)
     encoded = self.encoder(x)
     magnitude = mag(encoded)
     proj, mask_out = self.masker(magnitude)
     # The representation gets an extra dim at position 1 before masking
     # (presumably a source axis matching mask_out -- TODO confirm).
     masked = apply_mag_mask(encoded.unsqueeze(1), mask_out)
     decoded = self.decoder(masked)
     wavs = torch_utils.pad_x_to_y(decoded, x)
     dic_out = {
         "tfrep": encoded,
         "mask": mask_out,
         "masked_tfrep": masked,
         "proj": proj,
     }
     return wavs, dic_out
Пример #4
0
 def forward(self, x):
     """Encode ``x``, estimate a TF mask and return the masked representation.

     Args:
         x: input tensor. A 2-D input gets a dimension inserted at
             position 1 before it is encoded.

     Returns:
         The encoder output masked with ``apply_real_mask`` when
         ``self.is_complex`` is truthy, otherwise with ``apply_mag_mask``.
     """
     if x.dim() == 2:
         x = x.unsqueeze(1)
     # Compute the STFT.
     stft = self.encoder(x)
     # Masker features: take_cat (cat([re, im, mag])) in complex mode,
     # magnitude only otherwise.
     features = take_cat(stft) if self.is_complex else take_mag(stft)
     # The LSTM masker expects the feature dimension last (unlike 1D conv),
     # so dims 1 and 2 are swapped going in and swapped back coming out.
     masker_out = self.masker(features.transpose(1, 2))
     est_masks = masker_out.transpose(1, 2)
     # Apply the TF mask with the function matching the mode.
     if self.is_complex:
         return apply_real_mask(stft, est_masks)
     return apply_mag_mask(stft, est_masks)
Пример #5
0
    def forward(self, x):
        """Forward pass of generator.

        Args:
            x: input batch (signal).

        Returns:
            The decoder output passed through ``torch_utils.pad_x_to_y``
            together with ``x``.
        """
        # Encode, then take the magnitude with dims 1 and 2 swapped for the LSTM.
        spec = self.encoder(x)
        lstm_in = take_mag(spec).transpose(1, 2)
        # Compute the mask: LSTM, then self.model, then swap dims 1 and 2 back.
        self.LSTM.flatten_parameters()
        lstm_out, _ = self.LSTM(lstm_in)
        mask = self.model(lstm_out).transpose(1, 2)
        masked_spec = apply_mag_mask(spec, mask)
        # Decode.
        decoded = self.decoder(masked_spec)
        return torch_utils.pad_x_to_y(decoded, x)