def test_ebased_vad():
    """Check that ``ebased_vad`` yields a boolean mask that is independent across dim 1."""
    # The VAD expects magnitudes, so make the random input non-negative.
    spectrograms = torch.randn(10, 2, 65, 16).abs()

    full_mask = ebased_vad(spectrograms)
    assert isinstance(full_mask, torch.BoolTensor)

    # Running the VAD on the first slice alone must give the same result as
    # slicing the output obtained from the full input.
    first_slice_mask = ebased_vad(spectrograms[:, 0])
    assert (full_mask[:, 0] == first_slice_mask).all()
def forward(self, est_embeddings, target_indices, est_src=None, target_src=None, mix_spec=None):
    """Compute the weighted sum of the deep clustering loss and the PIT loss.

    Args:
        est_embeddings (torch.Tensor): Estimated embedding from the DC head.
        target_indices (torch.Tensor): Target indices that'll be passed to
            the DC loss.
        est_src (torch.Tensor): Estimated magnitude spectrograms (or masks).
        target_src (torch.Tensor): Target magnitude spectrograms (or masks).
        mix_spec (torch.Tensor): The magnitude spectrogram of the mixture
            from which VAD will be computed. If None, no VAD is used.

    Returns:
        torch.Tensor, the total loss, averaged over the batch.

        dict with `dc_loss` and `pit_loss` keys, unweighted losses.

    Raises:
        ValueError: If `est_src` or `target_src` is None.
    """
    # The PIT loss is always computed and reported in the returned dict,
    # whatever the value of `alpha`, so both spectrograms are always needed.
    # Checking up front gives a clear error instead of an obscure failure
    # inside `self.src_mse` (previously the case when `alpha == 0`).
    if est_src is None or target_src is None:
        raise ValueError(
            "Expected target and estimated spectrograms to "
            "compute the PIT loss, found None."
        )

    # Only build a VAD mask when the mixture spectrogram is provided.
    binary_mask = None
    if mix_spec is not None:
        binary_mask = ebased_vad(mix_spec)

    # Dc loss is already divided by VAD in the loss function.
    dc_loss = deep_clustering_loss(
        embedding=est_embeddings, tgt_index=target_indices, binary_mask=binary_mask
    ).mean()
    src_pit_loss = self.src_mse(est_src, target_src)

    # Equation (4) from Chimera paper.
    tot = self.alpha * dc_loss + (1 - self.alpha) * src_pit_loss
    # Return unweighted losses as well for logging.
    loss_dict = dict(dc_loss=dc_loss, pit_loss=src_pit_loss)
    return tot, loss_dict
def dc_head_separate(self, x):
    """Cluster embeddings to produce binary masks, output waveforms.

    The input is encoded, the masker produces embeddings from the magnitude
    of the encoded representation, and the embeddings of the active bins
    (as selected by `ebased_vad`) are clustered with K-means. Each cluster
    gives one binary mask, which is applied to the encoded representation
    before decoding.

    Args:
        x: Input passed to the encoder. A 2D input gets a singleton
            dimension inserted at position 1 first.

    Returns:
        The decoded outputs padded to match `x`, and a dict with the
        `tfrep`, `mask`, `masked_tfrep` and `proj` entries.
    """
    n_src = self.masker.n_src
    clusterer = KMeans(n_clusters=n_src)

    if x.dim() == 2:
        x = x.unsqueeze(1)

    tf_rep = self.encoder(x)
    mag_spec = mag(tf_rep)
    proj, mask_out = self.masker(mag_spec)

    # Only the embeddings of active bins take part in the clustering.
    active_bins = ebased_vad(mag_spec)
    active_proj = proj[active_bins.view(1, -1)]
    cluster_ids = clusterer.fit_predict(active_proj.detach().cpu().numpy())

    def _binary_mask(cluster_id):
        # Inactive bins are set to one in every mask; active bins are one
        # only where they were assigned to `cluster_id`.
        src_mask = ~active_bins
        membership = torch.from_numpy(cluster_ids == cluster_id)
        src_mask[active_bins] = membership.to(src_mask.device)
        return src_mask.float()  # Need float, not bool

    est_masks = torch.stack([_binary_mask(k) for k in range(n_src)], dim=1)

    # Go back to time domain
    masked = apply_mag_mask(tf_rep, est_masks)
    wavs = pad_x_to_y(self.decoder(masked), x)
    dic_out = {
        "tfrep": tf_rep,
        "mask": mask_out,
        "masked_tfrep": masked,
        "proj": proj,
    }
    return wavs, dic_out