def forward_wav(self, wav, slice_size=32000, *args, **kwargs):
    """Separation method for full waveforms.

    Unfolds the input audio into overlapping slices (50% overlap),
    separates each slice independently, then overlap-adds the per-slice
    estimates, using PIT on the overlapping halves to keep the source
    order consistent from one slice to the next.

    Args:
        wav (torch.Tensor): waveform tensor, time last.
            1D ``[T]`` or 2D ``[1, T]`` (a 1D input is reshaped to 2D;
            anything else fails the ``ndim == 2`` assertion below).
        slice_size (int): length of each analysis slice, in samples.
            The hop between slices is ``slice_size // 2``.

    Returns:
        torch.Tensor: concatenated output tensor of shape ``[num_spks, T]``.
    """
    assert not self.training, "forward_wav is only used for test mode"
    T = wav.size(-1)
    if wav.ndim == 1:
        wav = wav.reshape(1, wav.size(0))
    assert wav.ndim == 2  # [1, T]
    slice_stride = slice_size // 2
    # Pad wav to an integer multiple of slice_stride; at least 2 strides
    # so that unfold yields at least one full slice.
    T_padded = max(int(np.ceil(T / slice_stride)), 2) * slice_stride
    wav = F.pad(wav, (0, T_padded - T))
    slices = wav.unfold(
        dimension=-1, size=slice_size, step=slice_stride
    )  # [1, slice_nb, slice_size]
    slice_nb = slices.size(1)
    slices = slices.squeeze(0).unsqueeze(1)
    tf_rep = self.enc_activation(self.encoder(slices))
    est_masks_list = self.masker(tf_rep)
    # Estimate the number of speakers once for the whole signal: take the
    # mode (majority vote) of the per-slice selector decisions.
    selector_input = est_masks_list[-1]  # [slice_nb, bn_chan, chunk_size, n_chunks]
    selector_output = self.decoder_select.selector(selector_input).reshape(
        slice_nb, -1
    )  # [slice_nb, num_decs]
    est_idx, _ = selector_output.argmax(-1).mode()
    est_spks = self.decoder_select.n_srcs[est_idx]
    output_wavs, _ = self.decoder_select(
        est_masks_list, tf_rep, ground_truth=[est_spks] * slice_nb
    )  # [slice_nb, 1, n_spks, slice_size]
    output_wavs = output_wavs.squeeze(1)[:, :est_spks, :]
    # Overlap-and-add (with division): each new slice is permuted to best
    # match the already-assembled overlap region (SI-SDR-based PIT), added
    # in, then the doubled overlap is halved to average the two estimates.
    output_cat = output_wavs.new_zeros(est_spks, slice_nb * slice_size)
    output_cat[:, :slice_size] = output_wavs[0]
    start = slice_stride
    for i in range(1, slice_nb):
        overlap_prev = output_cat[:, start : start + slice_stride].unsqueeze(0)
        overlap_next = output_wavs[i : i + 1, :, :slice_stride]
        pw_losses = pairwise_neg_sisdr(overlap_next, overlap_prev)
        _, best_indices = PITLossWrapper.find_best_perm(pw_losses)
        reordered = PITLossWrapper.reorder_source(output_wavs[i : i + 1, :, :], best_indices)
        output_cat[:, start : start + slice_size] += reordered.squeeze(0)
        output_cat[:, start : start + slice_stride] /= 2
        start += slice_stride
    return output_cat[:, :T]
def _reorder_sources(
    current: torch.FloatTensor,
    previous: torch.FloatTensor,
    n_src: int,
    window_size: int,
    hop_size: int,
):
    """Reorder sources in the current chunk to maximize correlation with the
    previous chunk. Used for Continuous Source Separation. Standard dsp
    correlation is used for reordering.

    Args:
        current (torch.Tensor): current chunk, tensor of shape
            ``(batch * n_src, window_size)`` (sources flattened into the
            batch dimension, as done by the 2D ``.size()`` unpacking below).
        previous (torch.Tensor): previous chunk, same shape as ``current``.
        n_src (int): number of sources.
        window_size (int): window size, equal to the last dimension of both
            ``current`` and ``previous``.
        hop_size (int): hop size between ``current`` and ``previous`` chunks.

    Returns:
        torch.Tensor: ``current`` with its sources permuted to best match
        ``previous``, reshaped back to ``(batch * n_src, window_size)``.
    """
    batch, frames = current.size()
    current = current.reshape(-1, n_src, frames)
    previous = previous.reshape(-1, n_src, frames)
    overlap_f = window_size - hop_size
    # Use the NEGATIVE inner product over the overlapping region as the
    # pairwise "loss": find_best_perm picks the minimum-loss permutation,
    # so negating the correlation makes it maximize correlation. (The
    # previous form, torch.sum(x.unsqueeze(1) * y.unsqueeze(2)), reduced to
    # a positive product-of-sums — neither a correlation nor minimizable.)
    pw_losses = PITLossWrapper.get_pw_losses(
        lambda x, y: -torch.sum(x * y, dim=-1),
        current[..., :overlap_f],
        previous[..., -overlap_f:],
    )
    _, perms = PITLossWrapper.find_best_perm(pw_losses, n_src)
    current = PITLossWrapper.reorder_source(current, n_src, perms)
    return current.reshape(batch, frames)