Exemplo n.º 1
0
 def forward_wav(self, wav, slice_size=32000, *args, **kwargs):
     """Separate a full-length waveform at inference time (sliced overlap-add).

     The waveform is unfolded into 50%-overlapping slices of ``slice_size``
     samples, each slice is separated independently, and the per-slice
     outputs are stitched back together with overlap-add. Because the model
     may order the estimated speakers differently in each slice, each new
     slice is permutation-aligned (via SI-SDR PIT matching) to the already
     assembled output on their overlapping half before being added.

     Args:
         wav (torch.Tensor): waveform array/tensor.
             Shape: 1D, 2D or 3D tensor, time last.
             NOTE(review): the assert below only accepts 1D input (reshaped
             to [1, T]) or an already-2D input; a 3D input would fail it.
         slice_size (int): length in samples of each analysis slice; slices
             overlap by ``slice_size // 2``.

     Return:
         output_cat (torch.Tensor): concatenated output tensor.
             [num_spks, T]
     """
     assert not self.training, "forward_wav is only used for test mode"
     T = wav.size(-1)
     if wav.ndim == 1:
         wav = wav.reshape(1, wav.size(0))
     assert wav.ndim == 2  # [1, T]
     slice_stride = slice_size // 2
     # pad wav to integer multiple of slice_stride; the max(..., 2) keeps at
     # least 2 strides (= 1 full slice) so unfold always yields >= 1 slice
     T_padded = max(int(np.ceil(T / slice_stride)), 2) * slice_stride
     wav = F.pad(wav, (0, T_padded - T))
     slices = wav.unfold(
         dimension=-1, size=slice_size, step=slice_stride
     )  # [1, slice_nb, slice_size]
     slice_nb = slices.size(1)
     # fold slices into the batch dim, add a channel dim: [slice_nb, 1, slice_size]
     slices = slices.squeeze(0).unsqueeze(1)
     tf_rep = self.enc_activation(self.encoder(slices))
     est_masks_list = self.masker(tf_rep)
     selector_input = est_masks_list[-1]  # [slice_nb, bn_chan, chunk_size, n_chunks]
     selector_output = self.decoder_select.selector(selector_input).reshape(
         slice_nb, -1
     )  # [slice_nb, num_decs]
     # per-slice argmax picks a decoder; .mode() takes the majority vote
     # across slices so one speaker count is used for the whole utterance
     est_idx, _ = selector_output.argmax(-1).mode()
     est_spks = self.decoder_select.n_srcs[est_idx]
     output_wavs, _ = self.decoder_select(
         est_masks_list, tf_rep, ground_truth=[est_spks] * slice_nb
     )  # [slice_nb, 1, n_spks, slice_size]
     output_wavs = output_wavs.squeeze(1)[:, :est_spks, :]
     # TODO: overlap and add (with division)
     output_cat = output_wavs.new_zeros(est_spks, slice_nb * slice_size)
     # seed the output with the first slice, then fold in the rest one by one
     output_cat[:, :slice_size] = output_wavs[0]
     start = slice_stride
     for i in range(1, slice_nb):
         end = start + slice_size
         # overlap between assembled output and slice i: its first half vs.
         # the tail already written by slice i-1
         overlap_prev = output_cat[:, start : start + slice_stride].unsqueeze(0)
         overlap_next = output_wavs[i : i + 1, :, :slice_stride]
         # PIT on the overlap: find the speaker permutation of slice i that
         # best matches (neg. SI-SDR) what is already in the output
         pw_losses = pairwise_neg_sisdr(overlap_next, overlap_prev)
         _, best_indices = PITLossWrapper.find_best_perm(pw_losses)
         reordered = PITLossWrapper.reorder_source(output_wavs[i : i + 1, :, :], best_indices)
         # add the aligned slice, then average the doubly-written overlap half
         output_cat[:, start : start + slice_size] += reordered.squeeze(0)
         output_cat[:, start : start + slice_stride] /= 2
         start += slice_stride
     return output_cat[:, :T]
Exemplo n.º 2
0
def _reorder_sources(
    current: torch.FloatTensor,
    previous: torch.FloatTensor,
    n_src: int,
    window_size: int,
    hop_size: int,
):
    """Reorder the sources in ``current`` to best match ``previous``.

    Used for continuous (chunk-wise) source separation: the separator may
    permute the sources differently in each chunk, so consecutive chunks are
    aligned by choosing the source permutation of the current chunk whose
    overlapping samples correlate best with the previous chunk.

    Args:
        current (torch.Tensor): current chunk, shape
            (batch * n_src, window_size) — sources folded into the batch dim.
        previous (torch.Tensor): previous chunk, same shape as ``current``.
        n_src (int): number of sources.
        window_size (int): chunk length, last dimension of both tensors.
        hop_size (int): hop between ``previous`` and ``current``; the chunks
            overlap on ``window_size - hop_size`` samples.

    Returns:
        torch.Tensor: ``current`` with its sources permuted to match
        ``previous``, shape (batch * n_src, window_size).
    """
    batch, frames = current.size()
    current = current.reshape(-1, n_src, frames)
    previous = previous.reshape(-1, n_src, frames)

    overlap_f = window_size - hop_size

    def _negative_correlation(x, y):
        # Mean-center so the score is a proper correlation, and sum over the
        # time axis ONLY, keeping one score per batch element.  (The previous
        # lambda summed over every dim, collapsing the batch into a scalar.)
        x = x - x.mean(-1, keepdim=True)
        y = y - y.mean(-1, keepdim=True)
        # find_best_perm MINIMIZES the pairwise loss, so to MAXIMIZE the
        # correlation we must return its negative; the previous positive sign
        # selected the WORST-aligned permutation.
        return -torch.sum(x * y, dim=-1)

    pw_losses = PITLossWrapper.get_pw_losses(
        _negative_correlation,
        current[..., :overlap_f],
        previous[..., -overlap_f:],
    )
    _, perms = PITLossWrapper.find_best_perm(pw_losses, n_src)
    current = PITLossWrapper.reorder_source(current, n_src, perms)
    return current.reshape(batch, frames)