Beispiel #1
0
 def dc_head_separate(self, x):
     """Separate with the deep-clustering head and output waveforms.

     The embeddings of the bins flagged active by ``ebased_vad`` are
     clustered with K-means into ``self.masker.n_src`` clusters. Each
     cluster becomes one binary mask; bins that are not active are set to
     1 in every mask. The masks are applied to the encoder output, which
     is then decoded back to the time domain.

     Args:
         x: Input mixture. A 2D input gets a singleton axis inserted at
             dim 1 before encoding.

     Returns:
         tuple: ``(wavs, dic_out)``. ``wavs`` is the decoder output padded
         with ``pad_x_to_y`` to match ``x``. ``dic_out`` holds ``tfrep``
         (encoder output), ``mask`` (mask-inference head output),
         ``masked_tfrep`` (masked representation fed to the decoder) and
         ``proj`` (embeddings from the masker).

     Note:
         ``active_bins.view(1, -1)`` flattens everything into a single row
         before indexing ``proj``. This presumably only works for a batch
         size of 1 -- TODO confirm with the masker's ``proj`` layout.
     """
     kmeans = KMeans(n_clusters=self.masker.n_src)
     if len(x.shape) == 2:
         x = x.unsqueeze(1)
     tf_rep = self.encoder(x)
     mag_spec = take_mag(tf_rep)
     proj, mask_out = self.masker(mag_spec)
     # Boolean selection of the bins the energy-based VAD considers active.
     active_bins = ebased_vad(mag_spec)
     active_proj = proj[active_bins.view(1, -1)]
     # Cluster only the active bins' embeddings; K-means works on a CPU
     # numpy array. `.detach()` replaces the deprecated `.data` access.
     bin_clusters = kmeans.fit_predict(active_proj.detach().cpu().numpy())
     # Create binary masks
     est_mask_list = []
     for i in range(self.masker.n_src):
         # Inactive bins start at 1 in each mask; active bins get their
         # membership of cluster i.
         mask = ~active_bins
         mask[active_bins] = torch.from_numpy(bin_clusters == i).to(mask.device)
         est_mask_list.append(mask.float())  # Need float, not bool
     # The masks are stacked on a new dim 1. Insert the same axis in
     # tf_rep so broadcasting cannot pair tf_rep's batch axis with the
     # stacked mask axis.
     est_masks = torch.stack(est_mask_list, dim=1)
     masked = apply_mag_mask(tf_rep.unsqueeze(1), est_masks)
     # Go back to time domain
     wavs = pad_x_to_y(self.decoder(masked), x)
     dic_out = dict(tfrep=tf_rep, mask=mask_out, masked_tfrep=masked,
                    proj=proj)
     return wavs, dic_out
Beispiel #2
0
 def forward(self, x):
     """Mask the encoded input and decode the masked representation.

     Args:
         x: Input waveform batch. A 2D input gets a singleton axis inserted
             at dim 1 before encoding.

     Returns:
         The decoder output, padded via ``torch_utils.pad_x_to_y`` to match
         the (possibly unsqueezed) input.
     """
     wav = x.unsqueeze(1) if len(x.shape) == 2 else x
     encoded = self.encoder(wav)
     masks = self.masker(encoded)
     # Insert an axis at dim 1 so the representation broadcasts against the
     # masks (presumably one mask per source along dim 1 -- confirm).
     per_source = masks * encoded.unsqueeze(1)
     decoded = self.decoder(per_source)
     return torch_utils.pad_x_to_y(decoded, wav)
Beispiel #3
0
 def forward(self, x):
     """Estimate masks from a bottleneck of the encoding and decode.

     Args:
         x: Input waveform batch. A 2D input gets a singleton axis inserted
             at dim 1 before encoding.

     Returns:
         The decoder output, padded via ``torch_utils.pad_x_to_y`` to match
         the (possibly unsqueezed) input.
     """
     if len(x.shape) == 2:
         x = x.unsqueeze(1)
     # unsqueeze(1) leaves dim 0 untouched, so the batch size is the same
     # whether it is read before or after the reshape above.
     n_batch = x.shape[0]
     encoded = self.encode(x)
     bottleneck = self.bn_layer(encoded)
     # The masker is applied with the last two axes swapped; swap them back.
     raw_masks = self.masker(bottleneck.transpose(-1, -2)).transpose(-1, -2)
     masks = raw_masks.view(n_batch, self.n_src, self.n_filters, -1)
     masked = encoded.unsqueeze(1) * masks
     decoded = self.decoder(masked)
     return torch_utils.pad_x_to_y(decoded, x)
Beispiel #4
0
 def separate(self, x):
     """Separate with the mask-inference head and return waveforms.

     Args:
         x: Input mixture. A 2D input gets a singleton axis inserted at
             dim 1 before encoding.

     Returns:
         tuple: ``(wavs, dic_out)``. ``wavs`` is the decoder output padded
         to match the (possibly unsqueezed) input. ``dic_out`` maps
         ``tfrep``, ``mask``, ``masked_tfrep`` and ``proj`` to the encoder
         output, the masker's mask output, the masked representation and
         the masker's projection.
     """
     mix = x if len(x.shape) != 2 else x.unsqueeze(1)
     tf_rep = self.encoder(mix)
     magnitude = mag(tf_rep)
     proj, mask_out = self.masker(magnitude)
     # The extra axis at dim 1 lines the representation up with the masks
     # (presumably stacked per source on dim 1 -- confirm).
     masked = apply_mag_mask(tf_rep.unsqueeze(1), mask_out)
     wavs = torch_utils.pad_x_to_y(self.decoder(masked), mix)
     dic_out = {
         "tfrep": tf_rep,
         "mask": mask_out,
         "masked_tfrep": masked,
         "proj": proj,
     }
     return wavs, dic_out
Beispiel #5
0
    def forward(self, x):
        """Forward pass of generator: mask the encoder output and decode.

        Args:
            x: input batch (signal)

        Returns:
            The decoded signal, padded via ``torch_utils.pad_x_to_y`` to
            match ``x``.
        """
        spec = self.encoder(x)
        # Swap axes 1 and 2 of the magnitude before the LSTM, and swap them
        # back on the mask afterwards (the LSTM's expected layout is assumed
        # from this swap -- confirm against its batch_first setting).
        lstm_in = torch.transpose(take_mag(spec), 1, 2)
        # Re-pack the RNN weights so PyTorch can use its fast code path.
        self.LSTM.flatten_parameters()
        hidden, _ = self.LSTM(lstm_in)
        mask = torch.transpose(self.model(hidden), 1, 2)
        masked_spec = apply_mag_mask(spec, mask)
        return torch_utils.pad_x_to_y(self.decoder(masked_spec), x)
Beispiel #6
0
def test_pad():
    """Padding a (10, 1, 16000) tensor to a (10, 1, 16234) one yields the target shape."""
    shorter = torch.randn(10, 1, 16000)
    target = torch.randn(10, 1, 16234)
    result = torch_utils.pad_x_to_y(shorter, target)
    assert result.shape == target.shape
Beispiel #7
0
def test_pad_fail():
    """Padding along axis=1 (not the last axis of a 3D tensor) raises NotImplementedError."""
    shorter = torch.randn(10, 16000, 1)
    longer = torch.randn(10, 16234, 1)
    with pytest.raises(NotImplementedError):
        torch_utils.pad_x_to_y(shorter, longer, axis=1)
Beispiel #8
0
 def denoise(self, x):
     """Run the model on ``x`` and decode its output to a waveform.

     Args:
         x: Input passed straight to ``self(x)``.

     Returns:
         The decoder output, padded via ``torch_utils.pad_x_to_y`` to match
         ``x``.
     """
     # self(x) yields what the decoder consumes (named as an STFT estimate
     # in the original -- confirm against forward()).
     decoded = self.decoder(self(x))
     return torch_utils.pad_x_to_y(decoded, x)