Example #1
    def forward(self, est_src, logits, src):
        """Forward
        Args:
            est_src: $(num_stages, n_src, T)
            logits: $(num_stages, num_decoders)
            src: $(n_src, T)
        """
        assert est_src.size()[1:] == src.size()
        num_stages, n_src, T = est_src.size()
        target_src = src.unsqueeze(0).repeat(num_stages, 1, 1)
        target_idx = self.n_src2idx[n_src]

        pw_losses = pairwise_neg_sisdr(est_src, target_src)
        sdr_loss, _ = PITLossWrapper.find_best_perm(pw_losses)
        pos_sdr = -sdr_loss[-1]

        cls_target = torch.LongTensor([target_idx] * num_stages).to(
            logits.device)
        cls_loss = self.cce(logits, cls_target)
        correctness = logits[-1].argmax().item() == target_idx

        coeffs = torch.Tensor([
            (c_idx + 1) * (1 / num_stages) for c_idx in range(num_stages)
        ]).to(logits.device)
        assert coeffs.size() == sdr_loss.size() == cls_loss.size()
        # use sum of SDR for each channel, not mean
        loss = torch.sum(coeffs * (sdr_loss * n_src + cls_loss * self.lamb))

        return loss, pos_sdr, correctness
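
The stage coefficients above weight later stages more heavily, so the final stage dominates the multi-stage objective. A minimal standalone sketch of that weighting (the numbers are illustrative, not from the original recipe):

import torch

# Stage c_idx gets coefficient (c_idx + 1) / num_stages; the last stage weighs 1.
num_stages = 4
coeffs = torch.tensor([(c_idx + 1) / num_stages for c_idx in range(num_stages)])
print(coeffs)  # tensor([0.2500, 0.5000, 0.7500, 1.0000])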
Example #2
def main(conf):
    perms = list(permutations(range(conf["train_conf"]["data"]["n_src"])))

    model_path = os.path.join(conf["exp_dir"], conf["ckpt_path"])
    if conf["ckpt_path"] == "best_model.pth":
        # serialized checkpoint
        model = getattr(asteroid, conf["model"]).from_pretrained(model_path)
    else:
        # Non-serialized checkpoint (_ckpt_epoch_{i}.ckpt): state-dict keys start
        # with "model.", which needs to be stripped before loading.
        model = getattr(asteroid, conf["model"])(
            **conf["train_conf"]["filterbank"], **conf["train_conf"]["masknet"]
        )
        all_states = torch.load(model_path, map_location="cpu")
        state_dict = {
            k.split(".", 1)[1]: v for k, v in all_states["state_dict"].items()
        }
        model.load_state_dict(state_dict)

    # Handle device placement
    if conf["use_gpu"]:
        model.cuda()
    model_device = next(model.parameters()).device
    test_set = make_test_dataset(
        corpus=conf["corpus"],
        test_dir=conf["test_dir"],
        task=conf["task"],
        sample_rate=conf["sample_rate"],
        n_src=conf["train_conf"]["data"]["n_src"],
    )
    # Used to reorder sources only
    loss_func = PITLossWrapper(pairwise_neg_sisdr, pit_from="pw_mtx")

    # All result files will be saved under eval_save_dir.
    eval_save_dir = os.path.join(conf["exp_dir"], conf["out_dir"])
    os.makedirs(eval_save_dir, exist_ok=True)

    series_list = []
    with torch.no_grad():
        for idx in tqdm(range(len(test_set))):
            # Forward the network on the mixture.
            mix, sources = tensors_to_device(test_set[idx], device=model_device)
            est_sources = model(mix.unsqueeze(0))

            # When running separation inference for multi-task training, exclude
            # the last channel. This does not affect single-task training models
            # (from_scratch, pre+FT).
            est_sources = est_sources[:, :sources.shape[0]]
            _, best_perm_idx = loss_func.find_best_perm(
                pairwise_neg_sisdr(est_sources, sources[None]),
                conf["train_conf"]["data"]["n_src"],
            )

            utt_metrics = {}
            if hasattr(test_set, "mixture_path"):
                utt_metrics["mix_path"] = test_set.mixture_path
            utt_metrics["best_perm_idx"] = " ".join(
                str(pidx) for pidx in perms[best_perm_idx[0]]
            )
            series_list.append(pd.Series(utt_metrics))

    # Save all metrics to the experiment folder.
    all_metrics_df = pd.DataFrame(series_list)
    all_metrics_df.to_csv(os.path.join(eval_save_dir, "best_perms.csv"))
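
For reference, a hypothetical layout of the conf dictionary that main() reads. The keys are inferred from the accesses above; every value is a placeholder, not the original recipe's setting:

conf = {
    "exp_dir": "exp/tmp",               # experiment folder
    "ckpt_path": "best_model.pth",      # or a raw "_ckpt_epoch_{i}.ckpt"
    "model": "ConvTasNet",              # any model class exported by asteroid
    "use_gpu": True,
    "corpus": "wsj0-mix",               # placeholder
    "test_dir": "data/tt",
    "task": "sep_clean",
    "sample_rate": 8000,
    "out_dir": "eval",
    "train_conf": {
        "data": {"n_src": 2},
        "filterbank": {},
        "masknet": {},
    },
}
main(conf)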
Example #3
    def forward_wav(self, wav, slice_size=32000, *args, **kwargs):
        """Separation method for waveforms.

        Unfolds the full audio into half-overlapping slices, separates each
        slice, then stitches the slice estimates back together with
        overlap-add.

        Args:
            wav (torch.Tensor): waveform array/tensor.
                Shape: 1D, 2D or 3D tensor, time last.

        Returns:
            output_cat (torch.Tensor): concatenated output tensor.
                [num_spks, T]
        """
        assert not self.training, "forward_wav is only used in eval mode"
        T = wav.size(-1)
        if wav.ndim == 1:
            wav = wav.reshape(1, wav.size(0))
        assert wav.ndim == 2  # [1, T]
        slice_stride = slice_size // 2
        # Pad wav to an integer multiple of slice_stride.
        T_padded = max(int(np.ceil(T / slice_stride)), 2) * slice_stride
        wav = F.pad(wav, (0, T_padded - T))
        slices = wav.unfold(
            dimension=-1, size=slice_size, step=slice_stride
        )  # [1, slice_nb, slice_size]
        slice_nb = slices.size(1)
        slices = slices.squeeze(0).unsqueeze(1)
        tf_rep = self.enc_activation(self.encoder(slices))
        est_masks_list = self.masker(tf_rep)
        selector_input = est_masks_list[-1]  # [slice_nb, bn_chan, chunk_size, n_chunks]
        selector_output = self.decoder_select.selector(selector_input).reshape(
            slice_nb, -1
        )  # [slice_nb, num_decs]
        # Majority vote across slices for the estimated number of speakers.
        est_idx, _ = selector_output.argmax(-1).mode()
        est_spks = self.decoder_select.n_srcs[est_idx]
        output_wavs, _ = self.decoder_select(
            est_masks_list, tf_rep, ground_truth=[est_spks] * slice_nb
        )  # [slice_nb, 1, n_spks, slice_size]
        output_wavs = output_wavs.squeeze(1)[:, :est_spks, :]
        # Overlap-add: PIT-align each slice with what is already written in the
        # overlap region, then average the overlapping halves.
        output_cat = output_wavs.new_zeros(est_spks, slice_nb * slice_size)
        output_cat[:, :slice_size] = output_wavs[0]
        start = slice_stride
        for i in range(1, slice_nb):
            end = start + slice_size
            overlap_prev = output_cat[:, start : start + slice_stride].unsqueeze(0)
            overlap_next = output_wavs[i : i + 1, :, :slice_stride]
            pw_losses = pairwise_neg_sisdr(overlap_next, overlap_prev)
            _, best_indices = PITLossWrapper.find_best_perm(pw_losses)
            reordered = PITLossWrapper.reorder_source(
                output_wavs[i : i + 1, :, :], best_indices
            )
            output_cat[:, start:end] += reordered.squeeze(0)
            output_cat[:, start : start + slice_stride] /= 2
            start += slice_stride
        return output_cat[:, :T]
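
The slicing above relies on torch.Tensor.unfold producing half-overlapping windows. A tiny self-contained check of that behavior:

import torch

# A [1, T] "waveform" cut into windows of 4 samples with a stride of 2.
wav = torch.arange(8.0).reshape(1, 8)
slices = wav.unfold(dimension=-1, size=4, step=2)
print(slices.shape)   # torch.Size([1, 3, 4])
print(slices[0, 1])   # tensor([2., 3., 4., 5.]) -- one stride into the signal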
Example #4
def _reorder_sources(
    current: torch.FloatTensor,
    previous: torch.FloatTensor,
    n_src: int,
    window_size: int,
    hop_size: int,
):
    """
     Reorder sources in current chunk to maximize correlation with previous chunk.
     Used for Continuous Source Separation. Standard dsp correlation is used
     for reordering.


    Args:
        current (:class:`torch.Tensor`): current chunk, tensor
                                        of shape (batch, n_src, window_size)
        previous (:class:`torch.Tensor`): previous chunk, tensor
                                        of shape (batch, n_src, window_size)
        n_src (:class:`int`): number of sources.
        window_size (:class:`int`): window_size, equal to last dimension of
                                    both current and previous.
        hop_size (:class:`int`): hop_size between current and previous tensors.

    Returns:
        current:

    """
    batch, frames = current.size()
    current = current.reshape(-1, n_src, frames)
    previous = previous.reshape(-1, n_src, frames)

    overlap_f = window_size - hop_size
    # Per-pair correlation over the overlap, negated so that find_best_perm
    # (which minimizes) returns the permutation maximizing correlation.
    pw_losses = PITLossWrapper.get_pw_losses(
        lambda x, y: -torch.sum(x * y, dim=-1),
        current[..., :overlap_f],
        previous[..., -overlap_f:],
    )
    _, perms = PITLossWrapper.find_best_perm(pw_losses, n_src)
    current = PITLossWrapper.reorder_source(current, n_src, perms)
    return current.reshape(batch, frames)
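
A toy check of the PIT utilities these snippets lean on. This assumes a recent asteroid in which PITLossWrapper.find_best_perm takes only the pairwise loss matrix (older versions, as in the snippet above, also take n_src):

import torch
from asteroid.losses import PITLossWrapper, pairwise_neg_sisdr

ref = torch.randn(1, 2, 8000)
# Swap the two sources and add a little noise so SI-SDR stays finite.
swapped = ref[:, [1, 0], :] + 1e-3 * torch.randn(1, 2, 8000)
pw = pairwise_neg_sisdr(swapped, ref)        # [1, 2, 2] pairwise losses
_, perm = PITLossWrapper.find_best_perm(pw)  # minimizes the total loss
print(perm)  # tensor([[1, 0]]) -- the swap is recovered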