Пример #1
0
def test_sisnr_implementations(n_batches=1,
                               bs=5,
                               n_sources=1,
                               length=16000,
                               improvement=True):
    """Compare the naive SI-SNR with ``sisdr_l.PermInvariantSISDR``.

    Builds ``n_batches`` batches with ``random_batch_creator``. It evaluates
    both implementations on items ``1:`` of every batch and prints both result
    arrays and their element-wise absolute difference. It then prints the
    accumulated wall-clock time spent in each implementation.

    Args:
        n_batches: number of batches to build and evaluate.
        bs: batch size passed to ``random_batch_creator`` and to
            ``PermInvariantSISDR``. Results are stored for ``bs - 1`` items
            per batch, because index 0 of each batch is sliced away.
        n_sources: number of sources, forwarded to both implementations.
        length: signal length passed to ``random_batch_creator``.
        improvement: forwarded unchanged to both implementations.
    """
    cpu_timer, gpu_timer = 0., 0.
    gpu_results = torch.zeros((n_batches, bs - 1, n_sources)).float()
    cpu_results = np.zeros((n_batches, bs - 1, n_sources))
    cpu_batches = random_batch_creator(n_batches=n_batches,
                                       bs=bs,
                                       n_sources=n_sources,
                                       length=length)
    # NOTE(review): torch.from_numpy keeps the data on the CPU and nothing
    # below moves it to a CUDA device. Despite its name, the "GPU" path
    # therefore runs on the CPU.
    gpu_batches = torch.from_numpy(cpu_batches)
    torch.set_printoptions(precision=8)

    # NOTE(review): both arguments are the same slice (items 1:), so every
    # signal is scored against itself. Confirm this is intended.
    for b_ind in range(cpu_batches.shape[0]):
        before = time()
        cpu_results[b_ind, :, :] = \
            naive_sisnr(cpu_batches[b_ind, 1:, :, :],
                        cpu_batches[b_ind, 1:, :, :],
                        n_sources,
                        improvement=improvement)
        cpu_timer += time() - before

    gpu_sisnr = sisdr_l.PermInvariantSISDR(batch_size=bs,
                                           n_sources=n_sources,
                                           zero_mean=True,
                                           backward_loss=False,
                                           improvement=improvement)

    for b_ind in range(gpu_batches.shape[0]):
        before = time()
        gpu_results[b_ind, :, :] = \
            gpu_sisnr(gpu_batches[b_ind, 1:, :, :],
                      gpu_batches[b_ind, 1:, :, :])
        gpu_timer += time() - before

    # Convert once. detach() is the supported replacement for `.data`.
    gpu_results_np = gpu_results.detach().cpu().numpy()

    print("CPU")
    print(cpu_results)
    print("GPU")
    print(gpu_results_np)
    print("DIFF")
    print(np.abs(cpu_results - gpu_results_np))
    # Report the wall-clock time accumulated by each loop above.
    print("CPU time: {:.6f} s".format(cpu_timer))
    print("GPU time: {:.6f} s".format(gpu_timer))
                                         hparams["fs"],
                                         hparams["bs"],
                                         hparams["n_sources"])
else:
    audio_logger = None


# define data loaders
train_gen, val_gen, tr_val_gen = dataset_specific_params.get_data_loaders(
    hparams)

# define the losses that are going to be used
# Training loss, paired with its display name. It is the only loss here
# constructed with backward_loss=True.
back_loss_tr_loss_name, back_loss_tr_loss = (
    'tr_back_loss_mask_SISDR',
    sisdr_lib.PermInvariantSISDR(batch_size=hparams['bs'],
                                 n_sources=hparams['n_sources'],
                                 zero_mean=True,
                                 backward_loss=True,
                                 improvement=True))

# val_loss_name is the key of the single entry in val_losses. Defining it once
# keeps the two from drifting apart (presumably it is used to look up that
# entry elsewhere; confirm against callers).
val_loss_name = 'val_SISDRi'
val_losses = {
    val_loss_name: sisdr_lib.PermInvariantSISDR(
        batch_size=hparams['bs'],
        n_sources=hparams['n_sources'],
        zero_mean=True,
        backward_loss=False,
        improvement=True,
        return_individual_results=True),
}

tr_val_losses = dict([
    ('tr_SISDRi', sisdr_lib.PermInvariantSISDR(batch_size=hparams['bs'],