Example #1
0
def test_sisdr_and_mse(n_src, loss):
    """Check that the three PIT modes agree for one loss family.

    ``loss`` is unpacked into a pairwise, a single-source and a multi-source
    loss function (a fourth element is ignored). Each one is wrapped in a
    ``PITLossWrapper`` with the matching ``pit_from`` mode; consecutive
    wrappers must give close loss values and close reordered estimates.
    """
    pairwise_fn, single_fn, multi_fn, _ = loss
    shape = (2, n_src, 10000)
    targets = torch.randn(*shape)
    est_targets = torch.randn(*shape)

    # Ordered so that each wrapper is compared with the next one.
    wrappers = [
        PITLossWrapper(pairwise_fn, pit_from="pw_mtx"),
        PITLossWrapper(single_fn, pit_from="pw_pt"),
        PITLossWrapper(multi_fn, pit_from="perm_avg"),
    ]
    neighbours = list(zip(wrappers[:-1], wrappers[1:]))

    # Loss values must agree pairwise.
    for left, right in neighbours:
        assert_allclose(left(est_targets, targets),
                        right(est_targets, targets))

    # Reordered estimates (second element of the returned pair) must agree too.
    for left, right in neighbours:
        assert_allclose(
            left(est_targets, targets, return_est=True)[1],
            right(est_targets, targets, return_est=True)[1],
        )
def test_sisdr(n_src, function_triplet):
    """The pairwise, no-source and non-PIT variants must agree under PIT.

    ``function_triplet`` provides the three variants of one loss; each is
    wrapped with its matching ``pit_from`` mode and the wrappers are compared
    on both the loss value and the reordered estimates.
    """
    pairwise_fn, nosrc_fn, nonpit_fn = function_triplet
    targets = torch.randn(2, n_src, 10000)
    est_targets = torch.randn(2, n_src, 10000)

    pw_wrapper, wo_src_wrapper, w_src_wrapper = (
        PITLossWrapper(fn, pit_from=mode)
        for fn, mode in ((pairwise_fn, "pw_mtx"),
                         (nosrc_fn, "pw_pt"),
                         (nonpit_fn, "perm_avg"))
    )
    chain = ((pw_wrapper, wo_src_wrapper), (wo_src_wrapper, w_src_wrapper))

    # Loss values, compared along the chain pw -> wo_src -> w_src.
    for first, second in chain:
        assert_allclose(first(est_targets, targets),
                        second(est_targets, targets))

    # Reordered estimates, compared along the same chain.
    for first, second in chain:
        assert_allclose(
            first(est_targets, targets, return_est=True)[1],
            second(est_targets, targets, return_est=True)[1],
        )
Example #3
0
 def __init__(self, alpha=0.1):
     """Validate ``alpha`` and set up a PIT-wrapped pairwise MSE.

     Args:
         alpha (float): weight stored on the instance; must lie in [0, 1].

     Raises:
         AssertionError: if ``alpha`` is outside [0, 1].
     """
     super().__init__()
     # Two separate checks so each bound reports its own message.
     assert alpha >= 0, "Negative alpha values don't make sense."
     assert alpha <= 1, "Alpha values above 1 don't make sense."
     pit_mse = PITLossWrapper(pairwise_mse, pit_from='pw_mtx')
     self.src_mse = pit_mse
     self.alpha = alpha
Example #4
0
def test_proximity_sinkhorn_hungarian(batch_size, n_src, beta, n_iter,
                                      function_triplet):
    """SinkPIT and Hungarian PIT must give close losses on near-permuted data.

    The estimates are the targets with their source channels shuffled plus
    small Gaussian noise, so both wrappers should recover the same assignment.
    """
    n_samples = 16000
    noise_std = 0.1
    pairwise, _, _ = function_triplet

    shape = (batch_size, n_src, n_samples)
    targets = 10 * torch.randn(*shape)  # ground truth
    noise = noise_std * torch.randn(*shape)
    shuffled = targets[:, torch.randperm(n_src), :]
    est_targets = shuffled + noise

    sinkhorn_wrapper = SinkPITLossWrapper(pairwise, n_iter=n_iter)
    hungarian_wrapper = PITLossWrapper(pairwise, pit_from="pw_mtx")
    sinkhorn_wrapper.beta = beta

    loss_sinkhorn = sinkhorn_wrapper(est_targets, targets, return_est=False)
    loss_hungarian = hungarian_wrapper(est_targets, targets, return_est=False)
    assert_allclose(loss_sinkhorn, loss_hungarian)
Example #5
0
def test_system():
    """Smoke-test a ``fast_dev_run`` fit of ``TrainGAN`` on a dummy dataset."""
    discriminator = Discriminator()
    generator = nn.Sequential(nn.Linear(10, 10), nn.ReLU())

    # One Adam optimizer and one plateau scheduler per network (D first, then G).
    opt_d, opt_g = (optim.Adam(net.parameters(), lr=1e-3)
                    for net in (discriminator, generator))
    scheduler_d, scheduler_g = (
        ReduceLROnPlateau(optimizer=opt, factor=0.5, patience=5)
        for opt in (opt_d, opt_g)
    )

    g_loss = GeneratorLoss()
    d_loss = DiscriminatorLoss()
    validation_loss = PITLossWrapper(pairwise_neg_sisdr, pit_from='pw_mtx')
    loader = data.DataLoader(DummyDataset(), batch_size=4, num_workers=4)

    gan_kwargs = dict(
        discriminator=discriminator,
        generator=generator,
        opt_d=opt_d,
        opt_g=opt_g,
        discriminator_loss=d_loss,
        generator_loss=g_loss,
        validation_loss=validation_loss,
        train_loader=loader,
        val_loader=loader,
        scheduler_d=scheduler_d,
        scheduler_g=scheduler_g,
    )
    Trainer(max_epochs=1, fast_dev_run=True).fit(TrainGAN(**gan_kwargs))
Example #6
0
def main(conf):
    """Evaluate a WHAM separation model one test utterance at a time.

    For each example, the mixture is separated and the estimates are reordered
    with a PIT wrapper around ``pairwise_neg_sisdr``. Input/output metric
    objects are then built from numpy copies of the signals.

    Args:
        conf (dict): reads ``test_dir``, ``task``, ``sample_rate`` and
            ``nondefault_nsrc``; also passed whole to ``get_model``.
    """
    model = get_model(conf)
    test_set = WhamDataset(conf['test_dir'],
                           conf['task'],
                           sample_rate=conf['sample_rate'],
                           nondefault_nsrc=conf['nondefault_nsrc'],
                           segment=None)
    loss_func = PITLossWrapper(pairwise_neg_sisdr, mode='pairwise')
    model_device = next(model.parameters()).device
    for idx in range(len(test_set)):
        mix, sources, _ = tensors_to_device(test_set[idx], device=model_device)
        est_sources = model(mix)
        # NOTE(review): sources are passed before estimates here, while several
        # other PITLossWrapper call sites pass (est_targets, targets); confirm
        # the order expected by the installed wrapper version.
        loss, reordered_sources = loss_func(sources,
                                            est_sources,
                                            return_est=True)
        # The tensors live on ``model_device``, which may be a GPU, and
        # ``Tensor.numpy()`` only accepts CPU tensors: copy to host first.
        mix_np = mix.data.cpu().numpy()[0]
        sources_np = sources.data.cpu().numpy()[0]
        est_sources_np = reordered_sources.data.cpu().numpy()[0]
        # Waiting for pb_bss support to compute subset of metrics.
        # We will probably want SI-SDR, + add option for mir_eval SDR, stoi,
        # pesq
        # NOTE(review): the metric objects below are not used in the visible
        # code; this example may be truncated.
        input_metrics = InputMetrics(observation=mix_np,
                                     speech_source=sources_np,
                                     enable_si_sdr=True,
                                     sample_rate=conf["sample_rate"])
        output_metrics = OutputMetrics(speech_prediction=est_sources_np,
                                       speech_source=sources_np,
                                       enable_si_sdr=True,
                                       sample_rate=conf["sample_rate"])
Example #7
0
def main(conf):
    """Train a WHAM separation model with a PIT SI-SDR loss.

    Builds the train/validation ``WhamDataset``s and loaders from
    ``conf['data']`` and ``conf['training']``. The model and optimizer come
    from ``make_model_and_optimizer``, and ``conf`` is written to
    ``<exp_dir>/conf.yml``. A ``System`` is then fitted with a ``pl.Trainer``
    that checkpoints on ``val_loss``.
    """
    data_conf = conf['data']
    train_conf = conf['training']

    # Both splits share every dataset option except the directory.
    wham_kwargs = dict(sample_rate=data_conf['sample_rate'],
                       nondefault_nsrc=data_conf['nondefault_nsrc'])
    train_set = WhamDataset(data_conf['train_dir'], data_conf['task'],
                            **wham_kwargs)
    val_set = WhamDataset(data_conf['valid_dir'], data_conf['task'],
                          **wham_kwargs)

    loader_kwargs = dict(batch_size=train_conf['batch_size'],
                         num_workers=train_conf['num_workers'])
    train_loader = DataLoader(train_set, shuffle=True, **loader_kwargs)
    val_loader = DataLoader(val_set, shuffle=False, **loader_kwargs)
    # The mask network is configured with the dataset's number of sources.
    conf['masknet']['n_src'] = train_set.n_src

    # The recipe-local factory makes re-instantiating the model and optimizer
    # for retraining and evaluation straightforward.
    model, optimizer = make_model_and_optimizer(conf)

    # Persist the configuration right after instantiation for easy reloading.
    exp_dir = conf['main_args']['exp_dir']
    os.makedirs(exp_dir, exist_ok=True)
    with open(os.path.join(exp_dir, 'conf.yml'), 'w') as outfile:
        yaml.safe_dump(conf, outfile)

    loss_func = PITLossWrapper(pairwise_neg_sisdr, pit_from='pw_mtx')
    # The checkpoint callback monitors 'val_loss' from the validation step.
    checkpoint = ModelCheckpoint(os.path.join(exp_dir, 'checkpoints/'),
                                 monitor='val_loss',
                                 mode='min',
                                 save_best_only=False)
    system = System(model=model, loss_func=loss_func, optimizer=optimizer,
                    train_loader=train_loader, val_loader=val_loader,
                    config=conf)
    trainer = pl.Trainer(max_nb_epochs=train_conf['epochs'],
                         checkpoint_callback=checkpoint,
                         default_save_path=exp_dir,
                         gpus=conf['main_args']['gpus'],
                         distributed_backend='dp')
    trainer.fit(system)
Example #8
0
def test_negstoi_pit(n_src, sample_rate, use_vad, extended):
    """Smoke-test ``SingleSrcNegSTOI`` inside ``PITLossWrapper`` ('pw_pt').

    Only checks that a forward pass on random 16000-sample signals completes;
    the returned loss value is not inspected.
    """
    ref, est = torch.randn(2, n_src, 16000), torch.randn(2, n_src, 16000)
    singlesrc_negstoi = SingleSrcNegSTOI(sample_rate=sample_rate,
                                         use_vad=use_vad,
                                         extended=extended)
    loss_func = PITLossWrapper(singlesrc_negstoi, pit_from='pw_pt')
    # Assert forward ok.
    loss_func(est, ref)
def main(conf):
    """Find, for each test utterance, the source permutation that best
    matches the model's estimates, and save them to ``best_perms.csv``.

    The model is loaded from ``conf["exp_dir"]/conf["ckpt_path"]`` and run on
    every test mixture. The permutation that minimizes the pairwise negative
    SI-SDR against the references is recorded as space-separated source
    indices. Results go to ``<exp_dir>/<out_dir>/best_perms.csv``.

    Inference runs inside a scoped ``torch.no_grad()`` block, so autograd is
    re-enabled when the function returns.
    """
    n_src = conf["train_conf"]["data"]["n_src"]
    perms = list(permutations(range(n_src)))

    model_path = os.path.join(conf["exp_dir"], conf["ckpt_path"])
    if conf["ckpt_path"] == "best_model.pth":
        # serialized checkpoint
        model = getattr(asteroid, conf["model"]).from_pretrained(model_path)
    else:
        # Non-serialized checkpoint (_ckpt_epoch_{i}.ckpt): keys carry a
        # leading prefix such as "model.", which is stripped before loading.
        model = getattr(asteroid, conf["model"])(
            **conf["train_conf"]["filterbank"], **conf["train_conf"]["masknet"]
        )
        all_states = torch.load(model_path, map_location="cpu")
        state_dict = {
            key.split('.', 1)[1]: value
            for key, value in all_states["state_dict"].items()
        }
        model.load_state_dict(state_dict)

    # Handle device placement
    if conf["use_gpu"]:
        model.cuda()
    model_device = next(model.parameters()).device
    test_set = make_test_dataset(
        corpus=conf["corpus"],
        test_dir=conf["test_dir"],
        task=conf["task"],
        sample_rate=conf["sample_rate"],
        n_src=n_src,
    )
    # Used to reorder sources only
    loss_func = PITLossWrapper(pairwise_neg_sisdr, pit_from="pw_mtx")

    # all resulting files would be saved in eval_save_dir
    eval_save_dir = os.path.join(conf["exp_dir"], conf["out_dir"])
    os.makedirs(eval_save_dir, exist_ok=True)

    series_list = []
    # Scoped instead of ``torch.no_grad().__enter__()``, which disabled
    # autograd for the rest of the process and was never exited.
    with torch.no_grad():
        for idx in tqdm(range(len(test_set))):
            # Forward the network on the mixture.
            mix, sources = tensors_to_device(test_set[idx], device=model_device)
            est_sources = model(mix.unsqueeze(0))

            # When inferencing separation for multi-task training,
            # exclude the last channel. Does not affect single-task training
            # models (from_scratch, pre+FT).
            est_sources = est_sources[:, :sources.shape[0]]
            _, best_perm_idx = loss_func.find_best_perm(
                pairwise_neg_sisdr(est_sources, sources[None]), n_src
            )

            utt_metrics = {}
            # NOTE(review): presumably the dataset updates ``mixture_path`` on
            # each ``__getitem__``; verify it refers to the current utterance.
            if hasattr(test_set, "mixture_path"):
                utt_metrics["mix_path"] = test_set.mixture_path
            utt_metrics["best_perm_idx"] = ' '.join(
                str(pidx) for pidx in perms[best_perm_idx[0]]
            )
            series_list.append(pd.Series(utt_metrics))

    # Save all metrics to the experiment folder.
    all_metrics_df = pd.DataFrame(series_list)
    all_metrics_df.to_csv(os.path.join(eval_save_dir, "best_perms.csv"))
Example #10
0
def test_negstoi_pit(n_src, sample_rate, use_vad, extended):
    """Smoke-test ``SingleSrcNegSTOI`` wrapped in ``PITLossWrapper`` ('pw_pt').

    Runs a forward pass on random 8000-sample signals with all warnings
    silenced; only successful completion is checked.
    """
    shape = (2, n_src, 8000)
    ref = torch.randn(*shape)
    est = torch.randn(*shape)
    negstoi = SingleSrcNegSTOI(sample_rate=sample_rate, use_vad=use_vad,
                               extended=extended)
    wrapped = PITLossWrapper(negstoi, pit_from="pw_pt")
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        wrapped(est, ref)
Example #11
0
 def __init__(self, num_srcs, n_fft, hop_length, win_length, window,
              center):
     """Store the framing settings, build the window and a PIT SI-SDR loss.

     Args:
         num_srcs: number of sources, stored as ``self.num_srcs``.
         n_fft, hop_length, win_length, center: stored as attributes of the
             same name (presumably STFT settings used elsewhere in the class;
             not visible here).
         window (str): window name; only ``'hann'`` is supported.

     Raises:
         ValueError: if ``window`` is not ``'hann'``. Otherwise
             ``self.window`` would never be set and the error would only
             surface later as an ``AttributeError``.
     """
     # NOTE(review): no super().__init__() call; if the (not visible) base
     # class is an nn.Module this must be added.
     self.num_srcs = num_srcs
     self.n_fft = n_fft
     self.hop_length = hop_length
     self.win_length = win_length
     if window != 'hann':
         raise ValueError(
             "Unsupported window {!r}; only 'hann' is supported.".format(window))
     # NOTE(review): the window is moved to CUDA unconditionally, so this
     # fails on hosts without a GPU; confirm that is intended.
     self.window = torch.hann_window(win_length).cuda()
     self.center = center
     self.loss = PITLossWrapper(PairwiseNegSDR("sisdr"), pit_from="pw_mtx")
Example #12
0
def test_sisdr(n_src):
    """The pairwise, wo_src and w_src SI-SDR wrappers must agree.

    Each loss value is computed once and reused in the comparisons. The
    original computed ``pw``/``wo_src``/``w_src``, never used them, and then
    recomputed every value inside the asserts.
    """
    targets = torch.randn(2, n_src, 32000)
    est_targets = torch.randn(2, n_src, 32000)
    pw_wrapper = PITLossWrapper(pairwise_neg_sisdr, mode='pairwise')
    wo_src_wrapper = PITLossWrapper(nosrc_neg_sisdr, mode='wo_src')
    w_src_wrapper = PITLossWrapper(nonpit_neg_sisdr, mode='w_src')

    # Loss values.
    pw = pw_wrapper(targets, est_targets)
    wo_src = wo_src_wrapper(targets, est_targets)
    w_src = w_src_wrapper(targets, est_targets)
    assert_allclose(pw, wo_src)
    assert_allclose(w_src, wo_src)

    # Reordered estimates (second element of the returned pair).
    pw_est = pw_wrapper(targets, est_targets, return_est=True)[1]
    wo_src_est = wo_src_wrapper(targets, est_targets, return_est=True)[1]
    w_src_est = w_src_wrapper(targets, est_targets, return_est=True)[1]
    assert_allclose(pw_est, wo_src_est)
    assert_allclose(w_src_est, wo_src_est)
Example #13
0
def train_model_part(conf, train_part='filterbank', pretrained_filterbank=None):
    """Train one part of a two-step model and record its best checkpoints.

    Per-part settings are read from ``conf[train_part + '_training']``, with
    keys prefixed by the first letter of ``train_part``. ``conf`` is dumped to
    ``<exp_dir>/conf.yml``. After training, the best checkpoint scores are
    written to ``<checkpoint_dir>/best_k_models.json``.

    Args:
        conf (dict): full experiment configuration.
        train_part (str): which part to train; forwarded to the data, model
            and system helpers.
        pretrained_filterbank: forwarded to ``make_model_and_optimizer``.
    """
    train_loader, val_loader = get_data_loaders(conf, train_part=train_part)
    part_conf = conf[train_part + '_training']
    prefix = train_part[0]

    # Define model and optimizer in a local function (defined in the recipe).
    # Two advantages to this : re-instantiating the model and optimizer
    # for retraining and evaluating is straight-forward.
    model, optimizer = make_model_and_optimizer(
        conf, model_part=train_part, pretrained_filterbank=pretrained_filterbank
    )
    # Define scheduler
    scheduler = None
    if part_conf[prefix + '_half_lr']:
        scheduler = ReduceLROnPlateau(optimizer=optimizer, factor=0.5,
                                      patience=5)
    # Just after instantiating, save the args. Easy loading in the future.
    exp_dir, checkpoint_dir = get_encoded_paths(conf, train_part)
    os.makedirs(exp_dir, exist_ok=True)
    conf_path = os.path.join(exp_dir, 'conf.yml')
    with open(conf_path, 'w') as outfile:
        yaml.safe_dump(conf, outfile)

    # Define Loss function.
    loss_func = PITLossWrapper(PairwiseNegSDR('sisdr', zero_mean=False),
                               pit_from='pw_mtx')
    system = SystemTwoStep(model=model, loss_func=loss_func,
                           optimizer=optimizer, train_loader=train_loader,
                           val_loader=val_loader, scheduler=scheduler,
                           config=conf, module=train_part)

    # Define callbacks
    checkpoint = ModelCheckpoint(checkpoint_dir, monitor='val_loss',
                                 mode='min', save_top_k=1, verbose=1)
    early_stopping = False
    if part_conf[prefix + '_early_stop']:
        early_stopping = EarlyStopping(monitor='val_loss', patience=10,
                                       verbose=1)
    # Don't ask GPU if they are not available.
    if not torch.cuda.is_available():
        print('No available GPU were found, set gpus to None')
        conf['main_args']['gpus'] = None

    trainer = pl.Trainer(
        max_nb_epochs=part_conf[prefix + '_epochs'],
        checkpoint_callback=checkpoint,
        early_stop_callback=early_stopping,
        default_save_path=exp_dir,
        gpus=conf['main_args']['gpus'],
        distributed_backend='dp',
        train_percent_check=1.0,  # Useful for fast experiment
        gradient_clip_val=5.)
    trainer.fit(system)

    # Scores in ``best_k_models`` may be 0-dim tensors, which json cannot
    # serialize; float() accepts both tensors and plain numbers.
    best_k = {path: float(score)
              for path, score in checkpoint.best_k_models.items()}
    with open(os.path.join(checkpoint_dir, "best_k_models.json"), "w") as file:
        json.dump(best_k, file, indent=0)
Example #14
0
def test_permreduce():
    """Custom ``perm_reduce`` callables must match the default reduction.

    The default wrapper is expected to equal an explicit mean over the last
    axis, and a sum over that axis divided by ``n_src``.
    """
    from functools import partial
    n_src = 3
    sources = torch.randn(10, n_src, 8000)
    est_sources = torch.randn(10, n_src, 8000)

    default_wrapper = PITLossWrapper(pairwise_mse, pit_from='pw_mtx')
    mean_wrapper = PITLossWrapper(pairwise_mse, pit_from='pw_mtx',
                                  perm_reduce=lambda x: torch.mean(x, dim=-1))
    sum_wrapper = PITLossWrapper(pairwise_mse, pit_from='pw_mtx',
                                 perm_reduce=partial(torch.sum, dim=-1))

    default_value, mean_value, sum_value = (
        wrapper(est_sources, sources)
        for wrapper in (default_wrapper, mean_wrapper, sum_wrapper)
    )
    assert_allclose(default_value, mean_value)
    assert_allclose(default_value, sum_value / n_src)
Example #15
0
def test_multi_scale_spectral_PIT(n_src):
    """Smoke-test ``SingleSrcMultiScaleSpectral`` under ``PITLossWrapper``.

    Uses a reduced set of three STFT scales. Only checks that the forward pass
    runs. Estimates are passed first, like the other ``pit_from`` smoke tests
    (the original passed targets first and kept an unused result).
    """
    # Test in with reduced number of STFT scales.
    filt_list = [512, 256, 32]
    # Fake targets and estimates
    targets = torch.randn(2, n_src, 8000)
    est_targets = torch.randn(2, n_src, 8000)
    # Create PITLossWrapper in 'pw_pt' mode
    pt_loss = SingleSrcMultiScaleSpectral(windows_size=filt_list,
                                          n_filters=filt_list,
                                          hops_size=filt_list)
    loss_func = PITLossWrapper(pt_loss, pit_from='pw_pt')
    # Compute the loss (forward must succeed).
    loss_func(est_targets, targets)
Example #16
0
def test_permutation(perm):
    """Construct fake target/estimates pair. Check the value and reordering.

    Source ``k`` is the constant signal ``k``. The estimates are the same
    constants arranged according to ``perm``, so PIT must reach zero loss and
    restore the original source order.
    """
    n_src = len(perm)
    ones = torch.ones(1, n_src, 10)
    sources = ones * torch.arange(n_src).unsqueeze(-1)
    est_sources = ones * torch.Tensor(perm).unsqueeze(-1)

    wrapper = PITLossWrapper(pairwise_mse)
    result = wrapper(est_sources, sources, return_est=True)
    loss_value, reordered = result[0], result[1]

    assert loss_value.item() == 0
    assert_allclose(sources, reordered)
def test_pmsqe_pit(n_src, sample_rate):
    """Smoke-test ``SingleSrcPMSQE`` under ``PITLossWrapper`` on magnitudes."""
    # Supported STFT: 512-point frames at 16 kHz, 256-point otherwise,
    # always with 50% hop.
    frame = 512 if sample_rate == 16000 else 256
    stft = Encoder(STFTFB(kernel_size=frame, n_filters=frame,
                          stride=frame // 2))

    ref = torch.randn(2, n_src, 16000)
    est = torch.randn(2, n_src, 16000)
    ref_spec, est_spec = (transforms.mag(stft(sig)) for sig in (ref, est))

    pmsqe_pit = PITLossWrapper(SingleSrcPMSQE(sample_rate=sample_rate),
                               pit_from="pw_pt")
    # Assert forward ok.
    pmsqe_pit(est_spec, ref_spec)
Example #18
0
def test_wrapper(batch_size, n_src, time):
    """Bad loss functions must be rejected; good ones must keep shapes.

    Loss functions with the wrong output rank must trigger an
    ``AssertionError``. Each valid (loss, ``pit_from``) combination must run
    with and without ``return_est`` and return estimates of the input shape.
    """
    targets = torch.randn(batch_size, n_src, time)
    est_targets = torch.randn(batch_size, n_src, time)

    for bad_fn in (bad_loss_func_ndim0, bad_loss_func_ndim1):
        bad_wrapper = PITLossWrapper(bad_fn)
        with pytest.raises(AssertionError):
            bad_wrapper(est_targets, targets)

    cases = (
        (good_batch_loss_func, "pw_pt"),     # wo_src loss function
        (good_pairwise_loss_func, "pw_mtx"),  # pairwise loss function
        (good_batch_loss_func, "perm_avg"),   # w_src loss function
    )
    for loss_fn, mode in cases:
        wrapper = PITLossWrapper(loss_fn, pit_from=mode)
        wrapper(est_targets, targets)
        _, reordered_est = wrapper(est_targets, targets, return_est=True)
        assert reordered_est.shape == est_targets.shape
Example #19
0
def test_permreduce_args():
    """Exercise a ``perm_reduce`` that takes an extra keyword argument.

    The argument (``class_weights``) is supplied through ``reduce_kwargs``.
    """
    def reduce_func(perm_losses, class_weights=None):
        # perm_losses is (batch , n_perms, n_src) for now
        if class_weights is None:
            return torch.mean(perm_losses, dim=-1)
        # 2D weights get a singleton axis inserted at dim 1 before weighting.
        weights = (class_weights.unsqueeze(1) if class_weights.ndim == 2
                   else class_weights)
        return torch.mean(perm_losses * weights, -1)

    n_src = 3
    sources = torch.randn(10, n_src, 8000)
    est_sources = torch.randn(10, n_src, 8000)
    wrapper = PITLossWrapper(pairwise_mse, pit_from="pw_mtx",
                             perm_reduce=reduce_func)
    weights = torch.softmax(torch.randn(10, n_src), dim=-1)
    wrapper(est_sources, sources, reduce_kwargs={"class_weights": weights})
Example #20
0
def test_wrapper(batch_size, n_src, time):
    """Same checks as the ``pit_from`` variant, using the ``mode`` keyword.

    Bad loss functions must raise ``AssertionError``. Each (loss, mode) pair
    must run with and without ``return_est`` and preserve the estimate shape.
    """
    targets = torch.randn(batch_size, n_src, time)
    est_targets = torch.randn(batch_size, n_src, time)

    for bad_fn in (bad_loss_func_ndim0, bad_loss_func_ndim1):
        with pytest.raises(AssertionError):
            PITLossWrapper(bad_fn)(targets, est_targets)

    for loss_fn, mode in ((good_batch_loss_func, 'wo_src'),
                          (good_pairwise_loss_func, 'pairwise'),
                          (good_batch_loss_func, 'w_src')):
        wrapper = PITLossWrapper(loss_fn, mode=mode)
        wrapper(targets, est_targets)
        _, reordered_est = wrapper(targets, est_targets, return_est=True)
        assert reordered_est.shape == est_targets.shape
def test_sisdr(n_src, function_triplet):
    """Pairwise / wo_src / w_src wrappers (``mode`` API) must agree.

    The values and the reordered estimates are compared along the chain
    pairwise -> wo_src -> w_src.
    """
    pairwise_fn, nosrc_fn, nonpit_fn = function_triplet
    targets = torch.randn(2, n_src, 32000)
    est_targets = torch.randn(2, n_src, 32000)

    pw = PITLossWrapper(pairwise_fn, mode='pairwise')
    wo_src = PITLossWrapper(nosrc_fn, mode='wo_src')
    w_src = PITLossWrapper(nonpit_fn, mode='w_src')
    pairs = [(pw, wo_src), (wo_src, w_src)]

    for lhs, rhs in pairs:
        assert_allclose(lhs(est_targets, targets), rhs(est_targets, targets))

    for lhs, rhs in pairs:
        assert_allclose(lhs(est_targets, targets, return_est=True)[1],
                        rhs(est_targets, targets, return_est=True)[1])
Example #22
0
def _reorder_sources(
    current: torch.FloatTensor,
    previous: torch.FloatTensor,
    n_src: int,
    window_size: int,
    hop_size: int,
):
    """
     Reorder sources in current chunk to maximize correlation with previous chunk.
     Used for Continuous Source Separation. Standard dsp correlation is used
     for reordering.


    Args:
        current (:class:`torch.Tensor`): current chunk, tensor
                                        of shape (batch, n_src, window_size)
        previous (:class:`torch.Tensor`): previous chunk, tensor
                                        of shape (batch, n_src, window_size)
        n_src (:class:`int`): number of sources.
        window_size (:class:`int`): window_size, equal to last dimension of
                                    both current and previous.
        hop_size (:class:`int`): hop_size between current and previous tensors.

    Returns:
        current:

    """
    batch, frames = current.size()
    current = current.reshape(-1, n_src, frames)
    previous = previous.reshape(-1, n_src, frames)

    overlap_f = window_size - hop_size

    def reorder_func(x, y):
        x = x[..., :overlap_f]
        y = y[..., -overlap_f:]
        # Mean normalization
        x = x - x.mean(-1, keepdim=True)
        y = y - y.mean(-1, keepdim=True)
        # Negative mean Correlation
        return -torch.sum(x.unsqueeze(1) * y.unsqueeze(2), dim=-1)

    # We maximize correlation-like between previous and current.
    pit = PITLossWrapper(reorder_func)
    current = pit(current, previous, return_est=True)[1]
    return current.reshape(batch, frames)
Example #23
0
def main(conf):
    """Run a trained model (wrapped in ``DcSystem``) over the test set.

    For each test batch, the model output is used to pick a permutation with
    ``pit_loss.best_perm_from_perm_avg_loss``. Each selected mask is applied
    to the mixture STFT and inverted back to a waveform.

    A leftover ``set_trace()`` breakpoint was removed, and autograd is disabled
    only for the evaluation loop rather than globally.
    """
    test_set = WSJ2mixDataset(conf['data']['tt_wav_len_list'],
                              conf['data']['wav_base_path'] + '/tt',
                              sample_rate=conf['data']['sample_rate'])
    test_loader = DataLoader(test_set,
                             shuffle=True,
                             batch_size=1,
                             num_workers=conf['data']['num_workers'],
                             drop_last=False)
    istft = fb.Decoder(fb.STFTFB(**conf['filterbank']))
    exp_dir = conf['main_args']['exp_dir']
    model_path = os.path.join(exp_dir, 'checkpoints/_ckpt_epoch_0.ckpt')
    model = load_best_model(conf, model_path)
    pit_loss = PITLossWrapper(pairwise_mse, mode='pairwise')

    system = DcSystem(model, None, None, None, config=conf)

    # Randomly choose the indexes of sentences to save.
    # NOTE(review): exp_save_dir, save_idx and series_list are unused in the
    # visible code; this example may be truncated.
    exp_save_dir = os.path.join(exp_dir, 'examples/')
    n_save = conf['main_args']['n_save_ex']
    if n_save == -1:
        n_save = len(test_set)
    save_idx = random.sample(range(len(test_set)), n_save)
    series_list = []

    with torch.no_grad():
        for batch in test_loader:
            batch = [ele.type(torch.float32) for ele in batch]
            inputs, targets, masks = system.unpack_data(batch)
            est_targets = system(inputs)
            mix_stft = system.enc(inputs.unsqueeze(1))
            min_loss, min_idx = pit_loss.best_perm_from_perm_avg_loss(
                pairwise_mse, est_targets[1], masks)
            for sidx in min_idx:
                src_stft = mix_stft * est_targets[1][sidx]
                src_sig = istft(src_stft)
Example #24
0
def main(conf):
    """Train DPTNet on WHAM and export the best checkpoint.

    Builds the WHAM train/validation sets and loaders, then trains a
    ``System`` with a per-step ``DPTNetScheduler`` and a PIT SI-SDR loss. The
    following files are written under ``exp_dir``:
    ``conf.yml``, ``best_k_models.json`` and ``best_model.pth`` (serialized
    model plus dataset infos).
    """
    train_set = WhamDataset(
        conf["data"]["train_dir"],
        conf["data"]["task"],
        sample_rate=conf["data"]["sample_rate"],
        segment=conf["data"]["segment"],
        nondefault_nsrc=conf["data"]["nondefault_nsrc"],
    )
    val_set = WhamDataset(
        conf["data"]["valid_dir"],
        conf["data"]["task"],
        sample_rate=conf["data"]["sample_rate"],
        nondefault_nsrc=conf["data"]["nondefault_nsrc"],
    )

    train_loader = DataLoader(
        train_set,
        shuffle=True,
        batch_size=conf["training"]["batch_size"],
        num_workers=conf["training"]["num_workers"],
        drop_last=True,
    )
    val_loader = DataLoader(
        val_set,
        shuffle=False,
        batch_size=conf["training"]["batch_size"],
        num_workers=conf["training"]["num_workers"],
        drop_last=True,
    )
    # Update number of source values (It depends on the task)
    conf["masknet"].update({"n_src": train_set.n_src})

    model = DPTNet(**conf["filterbank"], **conf["masknet"])
    optimizer = make_optimizer(model.parameters(), **conf["optim"])
    from asteroid.engine.schedulers import DPTNetScheduler

    # NOTE(review): len(train_loader) is already a number of batches; dividing
    # it by batch_size again looks suspicious. Confirm what the second
    # DPTNetScheduler argument expects.
    schedulers = {
        "scheduler": DPTNetScheduler(
            optimizer, len(train_loader) // conf["training"]["batch_size"], 64
        ),
        "interval": "step",
    }

    # Just after instantiating, save the args. Easy loading in the future.
    exp_dir = conf["main_args"]["exp_dir"]
    os.makedirs(exp_dir, exist_ok=True)
    conf_path = os.path.join(exp_dir, "conf.yml")
    with open(conf_path, "w") as outfile:
        yaml.safe_dump(conf, outfile)

    # Define Loss function.
    loss_func = PITLossWrapper(pairwise_neg_sisdr, pit_from="pw_mtx")
    system = System(
        model=model,
        loss_func=loss_func,
        optimizer=optimizer,
        scheduler=schedulers,
        train_loader=train_loader,
        val_loader=val_loader,
        config=conf,
    )

    # Define callbacks
    checkpoint_dir = os.path.join(exp_dir, "checkpoints/")
    checkpoint = ModelCheckpoint(checkpoint_dir,
                                 monitor="val_loss",
                                 mode="min",
                                 save_top_k=5,
                                 verbose=True)
    early_stopping = False
    if conf["training"]["early_stop"]:
        early_stopping = EarlyStopping(monitor="val_loss",
                                       patience=30,
                                       verbose=True)

    # Don't ask GPU if they are not available.
    gpus = -1 if torch.cuda.is_available() else None
    trainer = pl.Trainer(
        max_epochs=conf["training"]["epochs"],
        checkpoint_callback=checkpoint,
        early_stop_callback=early_stopping,
        default_root_dir=exp_dir,
        gpus=gpus,
        distributed_backend="ddp",
        gradient_clip_val=conf["training"]["gradient_clipping"],
    )
    trainer.fit(system)

    best_k = {k: v.item() for k, v in checkpoint.best_k_models.items()}
    with open(os.path.join(exp_dir, "best_k_models.json"), "w") as f:
        json.dump(best_k, f, indent=0)

    # Load onto the CPU: the checkpoint may contain CUDA tensors, and the
    # system is moved to the CPU right after anyway.
    state_dict = torch.load(checkpoint.best_model_path, map_location="cpu")
    system.load_state_dict(state_dict=state_dict["state_dict"])
    system.cpu()

    to_save = system.model.serialize()
    to_save.update(train_set.get_infos())
    torch.save(to_save, os.path.join(exp_dir, "best_model.pth"))
Example #25
0
def main(conf):
    """Train a source-separation model on LibriMix.

    Builds the train/validation ``LibriMix`` datasets and loaders, creates
    the model and optimizer with ``make_model_and_optimizer``, dumps ``conf``
    to ``<exp_dir>/conf.yml``, trains with a permutation-invariant negative
    SI-SDR loss and finally writes the monitored value of the top-k
    checkpoints to ``<exp_dir>/best_k_models.json``.

    Args:
        conf (dict): Nested configuration with the ``data``, ``training``,
            ``masknet`` and ``main_args`` sections read below. Note that
            ``conf['masknet']`` (and ``conf['main_args']['gpus']`` when no
            GPU is available) are modified in place.
    """
    train_set = LibriMix(csv_dir=conf['data']['train_dir'],
                         task=conf['data']['task'],
                         sample_rate=conf['data']['sample_rate'],
                         n_src=conf['data']['n_src'],
                         segment=conf['data']['segment'])

    val_set = LibriMix(csv_dir=conf['data']['valid_dir'],
                       task=conf['data']['task'],
                       sample_rate=conf['data']['sample_rate'],
                       n_src=conf['data']['n_src'],
                       segment=conf['data']['segment'])

    train_loader = DataLoader(train_set, shuffle=True,
                              batch_size=conf['training']['batch_size'],
                              num_workers=conf['training']['num_workers'],
                              drop_last=True)

    # Validation data is not shuffled: combined with drop_last=True,
    # shuffling would drop a different subset every epoch and make
    # val_loss (used for checkpointing and early stopping) noisier.
    val_loader = DataLoader(val_set, shuffle=False,
                            batch_size=conf['training']['batch_size'],
                            num_workers=conf['training']['num_workers'],
                            drop_last=True)
    conf['masknet'].update({'n_src': conf['data']['n_src']})

    # Define model and optimizer in a local function (defined in the recipe).
    # Two advantages to this : re-instantiating the model and optimizer
    # for retraining and evaluating is straight-forward.
    model, optimizer = make_model_and_optimizer(conf)
    # Optionally halve the learning rate on plateau (factor=0.5, patience=5).
    scheduler = None
    if conf['training']['half_lr']:
        scheduler = ReduceLROnPlateau(optimizer=optimizer, factor=0.5,
                                      patience=5)
    # Just after instantiating, save the args. Easy loading in the future.
    exp_dir = conf['main_args']['exp_dir']
    os.makedirs(exp_dir, exist_ok=True)
    conf_path = os.path.join(exp_dir, 'conf.yml')
    with open(conf_path, 'w') as outfile:
        yaml.safe_dump(conf, outfile)

    # Define Loss function. ``pit_from='pw_mtx'`` declares that
    # ``pairwise_neg_sisdr`` returns a pairwise loss matrix; this is the
    # same keyword the other recipes use.
    loss_func = PITLossWrapper(pairwise_neg_sisdr, pit_from='pw_mtx')
    system = System(model=model, loss_func=loss_func, optimizer=optimizer,
                    train_loader=train_loader, val_loader=val_loader,
                    scheduler=scheduler, config=conf)

    # Define callbacks: keep the 5 checkpoints with the lowest val_loss.
    checkpoint_dir = os.path.join(exp_dir, 'checkpoints/')
    checkpoint = ModelCheckpoint(checkpoint_dir, monitor='val_loss',
                                 mode='min', save_top_k=5, verbose=1)
    early_stopping = False
    if conf['training']['early_stop']:
        early_stopping = EarlyStopping(monitor='val_loss', patience=10,
                                       verbose=1)

    # Don't ask GPU if they are not available.
    if not torch.cuda.is_available():
        print('No available GPU were found, set gpus to None')
        conf['main_args']['gpus'] = None
    trainer = pl.Trainer(max_epochs=conf['training']['epochs'],
                         checkpoint_callback=checkpoint,
                         early_stop_callback=early_stopping,
                         default_save_path=exp_dir,
                         gpus=conf['main_args']['gpus'],
                         distributed_backend='dp',
                         train_percent_check=1.0,  # Useful for fast experiment
                         gradient_clip_val=5.)
    trainer.fit(system)

    # best_k_models maps checkpoint path -> monitored value, which may be a
    # torch tensor; convert to plain floats so json can serialize it.
    best_k = {k: float(v) for k, v in checkpoint.best_k_models.items()}
    with open(os.path.join(exp_dir, "best_k_models.json"), "w") as f:
        json.dump(best_k, f, indent=0)
Example #26
0
def main(conf):
    """Evaluate a pretrained DPRNNTasNet on the WHAM test set.

    Loads ``<exp_dir>/best_model.pth``, separates every test mixture,
    reorders the estimates with PIT (negative SI-SDR), computes
    per-utterance metrics with ``get_metrics`` and saves wav files plus
    metrics for a random subset of utterances under ``<exp_dir>/examples``.
    Writes ``all_metrics.csv`` and ``final_metrics.json`` to ``exp_dir`` and
    packages the model into ``<exp_dir>/publish_dir`` via ``save_publishable``.

    Args:
        conf (dict): Evaluation configuration; reads ``exp_dir``, ``use_gpu``,
            ``test_dir``, ``task``, ``sample_rate`` and ``n_save_ex``
            (``-1`` means save every utterance; the value is then
            overwritten in place).

    Note:
        ``compute_metrics`` and ``train_conf`` are not defined in this
        function and are expected to exist at module level.
    """
    model_path = os.path.join(conf['exp_dir'], 'best_model.pth')
    model = DPRNNTasNet.from_pretrained(model_path)
    # Handle device placement
    if conf['use_gpu']:
        model.cuda()
    model_device = next(model.parameters()).device
    test_set = WhamDataset(conf['test_dir'], conf['task'],
                           sample_rate=conf['sample_rate'],
                           nondefault_nsrc=model.masker.n_src,
                           segment=None)  # Uses all segment length
    # Used to reorder sources only
    loss_func = PITLossWrapper(pairwise_neg_sisdr, pit_from='pw_mtx')

    # Randomly choose the indexes of sentences to save.
    ex_save_dir = os.path.join(conf['exp_dir'], 'examples/')
    if conf['n_save_ex'] == -1:
        conf['n_save_ex'] = len(test_set)
    save_idx = random.sample(range(len(test_set)), conf['n_save_ex'])
    series_list = []
    # Inference only: autograd is disabled for the loop and re-enabled on
    # exit instead of being left off for the rest of the process.
    with torch.no_grad():
        for idx in tqdm(range(len(test_set))):
            # Forward the network on the mixture.
            mix, sources = tensors_to_device(test_set[idx],
                                             device=model_device)
            est_sources = model(mix[None, None])
            _, reordered_sources = loss_func(est_sources, sources[None],
                                             return_est=True)
            mix_np = mix[None].cpu().data.numpy()
            # Keep the source axis even with a single source: a bare
            # squeeze() would turn (1, time) into (time,) and the loops
            # below would then iterate over samples instead of sources.
            sources_np = sources.cpu().data.numpy()
            est_sources_np = reordered_sources.squeeze(0).cpu().data.numpy()
            utt_metrics = get_metrics(mix_np, sources_np, est_sources_np,
                                      sample_rate=conf['sample_rate'])
            utt_metrics['mix_path'] = test_set.mix[idx][0]
            series_list.append(pd.Series(utt_metrics))

            # Save some examples in a folder. Wav files and metrics as text.
            if idx in save_idx:
                local_save_dir = os.path.join(ex_save_dir,
                                              'ex_{}/'.format(idx))
                os.makedirs(local_save_dir, exist_ok=True)
                sf.write(local_save_dir + "mixture.wav", mix_np[0],
                         conf['sample_rate'])
                # Loop over the sources and estimates
                for src_idx, src in enumerate(sources_np):
                    sf.write(local_save_dir + "s{}.wav".format(src_idx+1),
                             src, conf['sample_rate'])
                for src_idx, est_src in enumerate(est_sources_np):
                    sf.write(local_save_dir
                             + "s{}_estimate.wav".format(src_idx+1),
                             est_src, conf['sample_rate'])
                # Write local metrics to the example folder.
                with open(local_save_dir + 'metrics.json', 'w') as f:
                    json.dump(utt_metrics, f, indent=0)

    # Save all metrics to the experiment folder.
    all_metrics_df = pd.DataFrame(series_list)
    all_metrics_df.to_csv(os.path.join(conf['exp_dir'], 'all_metrics.csv'))

    # Print and save summary metrics: mean of each metric and mean
    # improvement over the unprocessed mixture ("input_" columns).
    final_results = {}
    for metric_name in compute_metrics:
        input_metric_name = 'input_' + metric_name
        ldf = all_metrics_df[metric_name] - all_metrics_df[input_metric_name]
        final_results[metric_name] = all_metrics_df[metric_name].mean()
        final_results[metric_name + '_imp'] = ldf.mean()
    print('Overall metrics :')
    pprint(final_results)
    with open(os.path.join(conf['exp_dir'], 'final_metrics.json'), 'w') as f:
        json.dump(final_results, f, indent=0)
    model_dict = torch.load(model_path, map_location='cpu')

    save_publishable(
        os.path.join(conf['exp_dir'], 'publish_dir'), model_dict,
        metrics=final_results, train_conf=train_conf
    )
Example #27
0
def main(conf):
    """Train a ConvTasNet on WHAM and export the best checkpoint.

    Builds the train/validation ``WhamDataset`` loaders, instantiates
    ``ConvTasNet`` and its optimizer from ``conf``, dumps ``conf`` to
    ``<exp_dir>/conf.yml`` and trains with a permutation-invariant negative
    SI-SDR loss. Afterwards it writes the top-k checkpoint scores to
    ``best_k_models.json``. It then reloads the checkpoint with the lowest
    ``val_loss`` and saves the serialized model (plus dataset infos) to
    ``<exp_dir>/best_model.pth``.

    Args:
        conf (dict): Nested configuration with the ``data``, ``training``,
            ``filterbank``, ``masknet``, ``optim`` and ``main_args``
            sections read below. ``conf['masknet']`` is updated in place.
    """
    train_set = WhamDataset(conf['data']['train_dir'],
                            conf['data']['task'],
                            sample_rate=conf['data']['sample_rate'],
                            nondefault_nsrc=conf['data']['nondefault_nsrc'])
    val_set = WhamDataset(conf['data']['valid_dir'],
                          conf['data']['task'],
                          sample_rate=conf['data']['sample_rate'],
                          nondefault_nsrc=conf['data']['nondefault_nsrc'])

    train_loader = DataLoader(train_set,
                              shuffle=True,
                              batch_size=conf['training']['batch_size'],
                              num_workers=conf['training']['num_workers'],
                              drop_last=True)
    val_loader = DataLoader(val_set,
                            shuffle=False,
                            batch_size=conf['training']['batch_size'],
                            num_workers=conf['training']['num_workers'],
                            drop_last=True)
    # Update number of source values (It depends on the task)
    conf['masknet'].update({'n_src': train_set.n_src})

    # Define model and optimizer
    model = ConvTasNet(**conf['filterbank'], **conf['masknet'])
    optimizer = make_optimizer(model.parameters(), **conf['optim'])
    # Optionally halve the learning rate on plateau (factor=0.5, patience=5).
    scheduler = None
    if conf['training']['half_lr']:
        scheduler = ReduceLROnPlateau(optimizer=optimizer,
                                      factor=0.5,
                                      patience=5)
    # Just after instantiating, save the args. Easy loading in the future.
    exp_dir = conf['main_args']['exp_dir']
    os.makedirs(exp_dir, exist_ok=True)
    conf_path = os.path.join(exp_dir, 'conf.yml')
    with open(conf_path, 'w') as outfile:
        yaml.safe_dump(conf, outfile)

    # Define Loss function.
    loss_func = PITLossWrapper(pairwise_neg_sisdr, pit_from='pw_mtx')
    system = System(model=model,
                    loss_func=loss_func,
                    optimizer=optimizer,
                    train_loader=train_loader,
                    val_loader=val_loader,
                    scheduler=scheduler,
                    config=conf)

    # Define callbacks: keep the 5 checkpoints with the lowest val_loss.
    checkpoint_dir = os.path.join(exp_dir, 'checkpoints/')
    checkpoint = ModelCheckpoint(checkpoint_dir,
                                 monitor='val_loss',
                                 mode='min',
                                 save_top_k=5,
                                 verbose=1)
    early_stopping = False
    if conf['training']['early_stop']:
        early_stopping = EarlyStopping(monitor='val_loss',
                                       patience=10,
                                       verbose=1)

    # Don't ask GPU if they are not available.
    gpus = -1 if torch.cuda.is_available() else None
    trainer = pl.Trainer(
        max_epochs=conf['training']['epochs'],
        checkpoint_callback=checkpoint,
        early_stop_callback=early_stopping,
        default_save_path=exp_dir,
        gpus=gpus,
        distributed_backend='dp',
        train_percent_check=1.0,  # Useful for fast experiment
        gradient_clip_val=5.)
    trainer.fit(system)

    # checkpoint path -> monitored val_loss, as plain floats for json.
    best_k = {k: v.item() for k, v in checkpoint.best_k_models.items()}
    with open(os.path.join(exp_dir, "best_k_models.json"), "w") as f:
        json.dump(best_k, f, indent=0)

    # Save best model (next PL version will make this easier).
    # The checkpoint with the lowest val_loss wins (mode='min' above).
    best_path = min(best_k, key=best_k.get)
    # The weights are copied into ``system`` and then moved to CPU anyway,
    # so load them directly on CPU.
    state_dict = torch.load(best_path, map_location='cpu')
    system.load_state_dict(state_dict=state_dict['state_dict'])
    system.cpu()

    to_save = system.model.serialize()
    to_save.update(train_set.get_infos())
    torch.save(to_save, os.path.join(exp_dir, 'best_model.pth'))
Example #28
0
def main(conf):
    """Evaluate a pretrained DCUNet on the BBCSO test set.

    Loads ``<exp_dir>/best_model.pth``, separates every test mixture,
    reorders the estimates with PIT (negative SI-SDR) and computes
    per-utterance metrics for ``compute_metrics``. For a random subset of
    utterances it writes the mixture, the references and the
    peak-normalised estimates as wav files. Writes ``all_metrics.csv`` and
    ``final_metrics.json`` to ``exp_dir`` and packages the model into
    ``<exp_dir>/publish_dir`` via ``save_publishable``.

    Args:
        conf (dict): Evaluation configuration; reads ``exp_dir``, ``use_gpu``,
            ``json_dir``, ``n_src``, ``sample_rate``, ``batch_size`` and
            ``n_save_ex`` (``-1`` means save every utterance; the value is
            then overwritten in place).

    Note:
        ``compute_metrics`` and ``train_conf`` are not defined in this
        function and are expected to exist at module level.
    """
    model_path = os.path.join(conf["exp_dir"], "best_model.pth")
    model = DCUNet.from_pretrained(model_path)
    # Handle device placement
    if conf["use_gpu"]:
        model.cuda()
    model_device = next(model.parameters()).device

    # NOTE(review): 220500 is presumably a fixed segment length in samples
    # (5 s at 44.1 kHz) -- confirm against BBCSODataset's signature.
    test_set = BBCSODataset(
        conf["json_dir"],
        conf["n_src"],
        conf["sample_rate"],
        conf["batch_size"],
        220500,
        train=False,
    )
    # Used to reorder sources only
    loss_func = PITLossWrapper(pairwise_neg_sisdr, pit_from="pw_mtx")

    # Randomly choose the indexes of sentences to save.
    ex_save_dir = os.path.join(conf["exp_dir"], "examples/")
    if conf["n_save_ex"] == -1:
        conf["n_save_ex"] = len(test_set)
    save_idx = random.sample(range(len(test_set)), conf["n_save_ex"])
    series_list = []
    # Inference only: autograd is disabled for the loop and re-enabled on
    # exit instead of being left off for the rest of the process.
    with torch.no_grad():
        for idx in tqdm(range(len(test_set))):
            # Forward the network on the mixture (with a batch dimension).
            mix, sources = tensors_to_device(test_set[idx], device=model_device)
            mix = mix.unsqueeze(0)
            sources = sources.unsqueeze(0)
            est_sources = model(mix)
            _, reordered_sources = loss_func(est_sources, sources, return_est=True)
            mix_np = mix.cpu().data.numpy()
            sources_np = sources.squeeze(0).cpu().data.numpy()
            est_sources_np = reordered_sources.squeeze(0).cpu().data.numpy()
            utt_metrics = get_metrics(
                mix_np,
                sources_np,
                est_sources_np,
                sample_rate=conf["sample_rate"],
                metrics_list=compute_metrics,
            )
            series_list.append(pd.Series(utt_metrics))

            # Save some examples in a folder. Wav files and metrics as text.
            if idx in save_idx:
                local_save_dir = os.path.join(ex_save_dir, "ex_{}/".format(idx))
                os.makedirs(local_save_dir, exist_ok=True)
                sf.write(local_save_dir + "mixture.wav", np.swapaxes(mix_np, 0, 1), conf["sample_rate"])
                # Loop over the sources and estimates
                for src_idx, src in enumerate(sources_np):
                    sf.write(local_save_dir + "s{}.wav".format(src_idx + 1), src, conf["sample_rate"])
                # Rescale each estimate to the mixture's peak amplitude
                # (after metrics were computed). A silent estimate is left
                # as is to avoid a 0/0 division that would write NaNs.
                mix_peak = np.max(np.abs(mix_np))
                for src_idx, est_src in enumerate(est_sources_np):
                    est_peak = np.max(np.abs(est_src))
                    if est_peak > 0:
                        est_src *= mix_peak / est_peak
                    sf.write(
                        local_save_dir + "s{}_estimate.wav".format(src_idx + 1),
                        est_src,
                        conf["sample_rate"],
                    )
                # Write local metrics to the example folder.
                with open(local_save_dir + "metrics.json", "w") as f:
                    json.dump(utt_metrics, f, indent=0)

    # Save all metrics to the experiment folder.
    all_metrics_df = pd.DataFrame(series_list)
    all_metrics_df.to_csv(os.path.join(conf["exp_dir"], "all_metrics.csv"))

    # Print and save summary metrics: mean of each metric and mean
    # improvement over the unprocessed mixture ("input_" columns).
    final_results = {}
    for metric_name in compute_metrics:
        input_metric_name = "input_" + metric_name
        ldf = all_metrics_df[metric_name] - all_metrics_df[input_metric_name]
        final_results[metric_name] = all_metrics_df[metric_name].mean()
        final_results[metric_name + "_imp"] = ldf.mean()
    print("Overall metrics :")
    pprint(final_results)
    with open(os.path.join(conf["exp_dir"], "final_metrics.json"), "w") as f:
        json.dump(final_results, f, indent=0)

    model_dict = torch.load(model_path, map_location="cpu")
    os.makedirs(os.path.join(conf["exp_dir"], "publish_dir"), exist_ok=True)
    save_publishable(
        os.path.join(conf["exp_dir"], "publish_dir"),
        model_dict,
        metrics=final_results,
        train_conf=train_conf,
    )
Example #29
0
def main(conf):
    """Evaluate the best LibriMix checkpoint of a training run.

    Rebuilds the model with ``make_model_and_optimizer``, loads the
    checkpoint with the lowest score listed in ``best_k_models.json`` (minus
    any ``loss*`` entries) and separates every test mixture. Estimates are
    reordered with PIT (negative SI-SDR) and scored with ``get_metrics``.
    Wav examples are saved for a random subset. All outputs go to
    ``<exp_dir>/<out_dir>``.

    Args:
        conf (dict): Evaluation configuration; reads ``train_conf``,
            ``exp_dir``, ``out_dir``, ``use_gpu``, ``test_dir``, ``task``,
            ``sample_rate`` and ``n_save_ex`` (``-1`` means save every
            utterance; the value is then overwritten in place).

    Note:
        ``compute_metrics`` is not defined in this function and is expected
        to exist at module level.
    """
    # Make the model
    model, _ = make_model_and_optimizer(conf['train_conf'])
    # Load best model: best_k_models.json maps checkpoint path -> score.
    with open(os.path.join(conf['exp_dir'], 'best_k_models.json'), "r") as f:
        best_k = json.load(f)
    best_model_path = min(best_k, key=best_k.get)
    # Load checkpoint
    checkpoint = torch.load(best_model_path, map_location='cpu')
    state = checkpoint['state_dict']
    # Remove keys starting with 'loss' (presumably the training System's
    # loss module rather than model weights) and print each removed key.
    state_copy = state.copy()
    for key in state:
        if key.startswith('loss'):
            del state_copy[key]
            print(key)
    model = torch_utils.load_state_dict_in(state_copy, model)
    # Handle device placement
    if conf['use_gpu']:
        model.cuda()
    model_device = next(model.parameters()).device
    test_set = LibriMix(csv_dir=conf['test_dir'],
                        task=conf['task'],
                        sample_rate=conf['sample_rate'],
                        n_src=conf['train_conf']['data']['n_src'],
                        segment=None)  # Uses all segment length
    # Used to reorder sources only
    loss_func = PITLossWrapper(pairwise_neg_sisdr, pit_from='pw_mtx')

    # Randomly choose the indexes of sentences to save.
    eval_save_dir = os.path.join(conf['exp_dir'], conf['out_dir'])
    # Create it up front: it is otherwise only created as a side effect of
    # saving examples, and the CSV/JSON writes below need it to exist.
    os.makedirs(eval_save_dir, exist_ok=True)
    ex_save_dir = os.path.join(eval_save_dir, 'examples/')
    if conf['n_save_ex'] == -1:
        conf['n_save_ex'] = len(test_set)
    save_idx = random.sample(range(len(test_set)), conf['n_save_ex'])
    series_list = []
    # Inference only: autograd is disabled for the loop and re-enabled on
    # exit instead of being left off for the rest of the process.
    with torch.no_grad():
        for idx in tqdm(range(len(test_set))):
            # Forward the network on the mixture.
            mix, sources = tensors_to_device(test_set[idx],
                                             device=model_device)
            est_sources = model(mix.unsqueeze(0))
            _, reordered_sources = loss_func(est_sources,
                                             sources[None],
                                             return_est=True)
            mix_np = mix.cpu().data.numpy()
            # Keep the source axis even with a single source: a bare
            # squeeze() would turn (1, time) into (time,) and the loops
            # below would then iterate over samples instead of sources.
            sources_np = sources.cpu().data.numpy()
            est_sources_np = reordered_sources.squeeze(0).cpu().data.numpy()
            # For each utterance, we get a dictionary with the mixture path,
            # the input and output metrics
            utt_metrics = get_metrics(mix_np,
                                      sources_np,
                                      est_sources_np,
                                      sample_rate=conf['sample_rate'])
            utt_metrics['mix_path'] = test_set.mixture_path
            series_list.append(pd.Series(utt_metrics))

            # Save some examples in a folder. Wav files and metrics as text.
            if idx in save_idx:
                local_save_dir = os.path.join(ex_save_dir,
                                              'ex_{}/'.format(idx))
                os.makedirs(local_save_dir, exist_ok=True)
                sf.write(local_save_dir + "mixture.wav", mix_np,
                         conf['sample_rate'])
                # Loop over the sources and estimates
                for src_idx, src in enumerate(sources_np):
                    sf.write(local_save_dir + "s{}.wav".format(src_idx), src,
                             conf['sample_rate'])
                for src_idx, est_src in enumerate(est_sources_np):
                    sf.write(local_save_dir
                             + "s{}_estimate.wav".format(src_idx),
                             est_src, conf['sample_rate'])
                # Write local metrics to the example folder.
                with open(local_save_dir + 'metrics.json', 'w') as f:
                    json.dump(utt_metrics, f, indent=0)

    # Save all metrics to the experiment folder.
    all_metrics_df = pd.DataFrame(series_list)
    all_metrics_df.to_csv(os.path.join(eval_save_dir, 'all_metrics.csv'))

    # Print and save summary metrics: mean of each metric and mean
    # improvement over the unprocessed mixture ("input_" columns).
    final_results = {}
    for metric_name in compute_metrics:
        input_metric_name = 'input_' + metric_name
        ldf = all_metrics_df[metric_name] - all_metrics_df[input_metric_name]
        final_results[metric_name] = all_metrics_df[metric_name].mean()
        final_results[metric_name + '_imp'] = ldf.mean()
    print('Overall metrics :')
    pprint(final_results)
    with open(os.path.join(eval_save_dir, 'final_metrics.json'), 'w') as f:
        json.dump(final_results, f, indent=0)
Example #30
0
def main(conf):
    """Evaluate the best WHAMR model of a training run on the test set.

    Loads the model with ``load_best_model`` and separates every test
    mixture. Estimates are reordered with PIT (negative SI-SDR) and scored
    with ``get_metrics`` for ``compute_metrics``. Wav examples plus metrics
    are saved for a random subset under ``<exp_dir>/examples``.
    ``all_metrics.csv`` and ``final_metrics.json`` are written to
    ``exp_dir``.

    Args:
        conf (dict): Evaluation configuration; reads ``train_conf``,
            ``exp_dir``, ``use_gpu``, ``test_dir``, ``task``,
            ``sample_rate`` and ``n_save_ex`` (``-1`` means save every
            utterance; the value is then overwritten in place).

    Note:
        ``compute_metrics`` is not defined in this function and is expected
        to exist at module level.
    """
    model = load_best_model(conf["train_conf"], conf["exp_dir"])
    # Handle device placement
    if conf["use_gpu"]:
        model.cuda()
    model_device = next(model.parameters()).device
    test_set = WhamRDataset(
        conf["test_dir"],
        conf["task"],
        sample_rate=conf["sample_rate"],
        nondefault_nsrc=model.n_src,
        segment=None,
    )  # Uses all segment length
    # Used to reorder sources only
    loss_func = PITLossWrapper(pairwise_neg_sisdr, pit_from="pw_mtx")

    # Randomly choose the indexes of sentences to save.
    ex_save_dir = os.path.join(conf["exp_dir"], "examples/")
    if conf["n_save_ex"] == -1:
        conf["n_save_ex"] = len(test_set)
    save_idx = random.sample(range(len(test_set)), conf["n_save_ex"])
    series_list = []
    # Inference only: autograd is disabled for the loop and re-enabled on
    # exit instead of being left off for the rest of the process.
    with torch.no_grad():
        for idx in tqdm(range(len(test_set))):
            # Forward the network on the mixture.
            mix, sources = tensors_to_device(test_set[idx], device=model_device)
            est_sources = model(mix[None, None])
            _, reordered_sources = loss_func(est_sources, sources[None], return_est=True)
            mix_np = mix[None].cpu().data.numpy()
            sources_np = sources.cpu().data.numpy()
            est_sources_np = reordered_sources.squeeze(0).cpu().data.numpy()
            utt_metrics = get_metrics(
                mix_np,
                sources_np,
                est_sources_np,
                sample_rate=conf["sample_rate"],
                metrics_list=compute_metrics,
            )
            utt_metrics["mix_path"] = test_set.mix[idx][0]
            series_list.append(pd.Series(utt_metrics))

            # Save some examples in a folder. Wav files and metrics as text.
            if idx in save_idx:
                local_save_dir = os.path.join(ex_save_dir, "ex_{}/".format(idx))
                os.makedirs(local_save_dir, exist_ok=True)
                sf.write(local_save_dir + "mixture.wav", mix_np[0], conf["sample_rate"])
                # Loop over the sources and estimates
                for src_idx, src in enumerate(sources_np):
                    sf.write(local_save_dir + "s{}.wav".format(src_idx + 1), src, conf["sample_rate"])
                for src_idx, est_src in enumerate(est_sources_np):
                    sf.write(
                        local_save_dir + "s{}_estimate.wav".format(src_idx + 1),
                        est_src,
                        conf["sample_rate"],
                    )
                # Write local metrics to the example folder.
                with open(local_save_dir + "metrics.json", "w") as f:
                    json.dump(utt_metrics, f, indent=0)

    # Save all metrics to the experiment folder.
    all_metrics_df = pd.DataFrame(series_list)
    all_metrics_df.to_csv(os.path.join(conf["exp_dir"], "all_metrics.csv"))

    # Print and save summary metrics: mean of each metric and mean
    # improvement over the unprocessed mixture ("input_" columns).
    final_results = {}
    for metric_name in compute_metrics:
        input_metric_name = "input_" + metric_name
        ldf = all_metrics_df[metric_name] - all_metrics_df[input_metric_name]
        final_results[metric_name] = all_metrics_df[metric_name].mean()
        final_results[metric_name + "_imp"] = ldf.mean()
    print("Overall metrics :")
    pprint(final_results)
    with open(os.path.join(conf["exp_dir"], "final_metrics.json"), "w") as f:
        json.dump(final_results, f, indent=0)