Esempio n. 1
0
def test_ebased_vad():
    # ebased_vad is fed magnitudes, so make the random input non-negative.
    spec = torch.randn(10, 2, 65, 16).abs()
    full_mask = transforms.ebased_vad(spec)
    assert isinstance(full_mask, torch.BoolTensor)

    # Running VAD on the slice at index 0 of dim 1 on its own must give the
    # same result as the matching slice of the batched output, i.e. entries
    # along dim 1 are processed independently.
    first_mask = transforms.ebased_vad(spec[:, 0])
    assert (full_mask[:, 0] == first_mask).all()
Esempio n. 2
0
    def forward(self, est_embeddings, target_indices, est_src=None, target_src=None, mix_spec=None):
        """Compute the Chimera loss: a weighted sum of DC and PIT losses.

        The total loss is ``alpha * dc_loss + (1 - alpha) * pit_loss``
        (Equation (4) from the Chimera paper). The PIT term therefore only
        contributes when ``alpha != 1``, and the spectrograms are only
        required in that case.

        Args:
            est_embeddings (torch.Tensor): Estimated embedding from the DC head.
            target_indices (torch.Tensor): Target indices that'll be passed to
                the DC loss.
            est_src (torch.Tensor, optional): Estimated magnitude spectrograms
                (or masks). Required unless ``alpha == 1``.
            target_src (torch.Tensor, optional): Target magnitude spectrograms
                (or masks). Required unless ``alpha == 1``.
            mix_spec (torch.Tensor, optional): The magnitude spectrogram of the
                mixture from which VAD will be computed. If None, no VAD is used.

        Returns:
            torch.Tensor, the total loss, averaged over the batch.
            dict with `dc_loss` and `pit_loss` keys, unweighted losses.
            `pit_loss` is a zero tensor when the PIT term is skipped
            (``alpha == 1`` and no spectrograms were given).

        Raises:
            ValueError: if ``alpha != 1`` and `est_src` or `target_src` is None.
        """
        have_src = est_src is not None and target_src is not None
        # The PIT term is weighted by (1 - alpha), so it can only be omitted
        # when alpha == 1.
        if self.alpha != 1 and not have_src:
            raise ValueError(
                "Expected target and estimated spectrograms to compute the PIT loss, found None."
            )
        binary_mask = None
        if mix_spec is not None:
            binary_mask = ebased_vad(mix_spec)
        # Dc loss is already divided by VAD in the loss function.
        dc_loss = deep_clustering_loss(
            embedding=est_embeddings, tgt_index=target_indices, binary_mask=binary_mask
        )
        dc_loss_mean = dc_loss.mean()
        if have_src:
            src_pit_loss = self.src_mse(est_src, target_src)
        else:
            # alpha == 1: the PIT term has zero weight; report zero for logging.
            src_pit_loss = dc_loss_mean.new_zeros(())
        # Equation (4) from Chimera paper.
        tot = self.alpha * dc_loss_mean + (1 - self.alpha) * src_pit_loss
        # Return unweighted losses as well for logging.
        loss_dict = dict(dc_loss=dc_loss_mean, pit_loss=src_pit_loss)
        return tot, loss_dict
Esempio n. 3
0
 def dc_head_separate(self, x):
     """Cluster embeddings to produce binary masks, output waveforms.

     The embeddings of the bins selected by ``ebased_vad`` are clustered with
     k-means into ``self.masker.n_src`` groups. Each cluster becomes one
     binary mask; bins not selected by the VAD are set to one in every mask.
     The masks are applied to the encoder representation and decoded.

     Args:
         x (torch.Tensor): Input mixture. A 2D input gets a singleton
             dimension inserted at dim 1 before encoding.

     Returns:
         tuple: ``(wavs, dic_out)``, where ``wavs`` is ``self.decoder(masked)``
         passed through ``pad_x_to_y(..., x)``, and ``dic_out`` is a dict with
         keys ``tfrep``, ``mask``, ``masked_tfrep`` and ``proj``.
     """
     n_src = self.masker.n_src
     kmeans = KMeans(n_clusters=n_src)
     if len(x.shape) == 2:
         x = x.unsqueeze(1)
     tf_rep = self.encoder(x)
     mag_spec = take_mag(tf_rep)
     proj, mask_out = self.masker(mag_spec)
     active_bins = ebased_vad(mag_spec)
     # NOTE(review): view(1, -1) flattens all leading dims into a single row,
     # which looks like it assumes a batch size of 1 — confirm against callers.
     active_proj = proj[active_bins.view(1, -1)]
     # Cluster only the embeddings of active bins. detach() (rather than the
     # discouraged .data) drops the autograd graph before converting to numpy.
     bin_clusters = kmeans.fit_predict(active_proj.detach().cpu().numpy())
     # Convert the cluster labels once instead of once per source.
     bin_clusters = torch.from_numpy(bin_clusters).to(active_bins.device)
     # Create binary masks
     est_mask_list = []
     for i in range(n_src):
         # Start with ones on all inactive bins (zeros on active ones), then
         # mark the active bins assigned to cluster ``i``.
         mask = ~active_bins
         mask[active_bins] = bin_clusters == i
         est_mask_list.append(mask.float())  # Need float, not bool
     # Go back to time domain
     est_masks = torch.stack(est_mask_list, dim=1)
     masked = apply_mag_mask(tf_rep, est_masks)
     wavs = pad_x_to_y(self.decoder(masked), x)
     dic_out = dict(tfrep=tf_rep,
                    mask=mask_out,
                    masked_tfrep=masked,
                    proj=proj)
     return wavs, dic_out