Ejemplo n.º 1
0
def train_frequency_representation(args, fr_module, fr_optimizer, fr_criterion,
                                   fr_scheduler, train_loader, val_loader,
                                   xgrid, epoch, tb_writer):
    """
    Run one training epoch of the frequency-representation module, then
    score it on the validation set.

    Training pass: each batch of clean signals is passed through
    `noise_torch` with an SNR value drawn uniformly from
    [args.snrl, args.snrh + 1) (one draw per batch), and a single optimizer
    step is taken on `fr_criterion(output, target)`.

    Validation pass: the module is put in eval mode, the criterion is summed
    over the validation batches (which already provide noisy signals), and
    `fnr` is accumulated on the frequencies returned by `fr.find_freq`.

    Both summed losses and the FNR are divided by the dataset size times
    (args.snrh - args.snrl + 1); the FNR is additionally scaled by 100.
    The three values are written to `tb_writer`, the validation loss is
    handed to `fr_scheduler.step`, and a summary line is logged.

    Returns nothing.
    """
    tic = time.time()

    # ---- optimisation pass -------------------------------------------------
    fr_module.train()
    train_loss_sum = 0
    for clean, fr_target, _ in train_loader:
        if args.use_cuda:
            clean = clean.cuda()
            fr_target = fr_target.cuda()
        # A fresh SNR value is sampled for every batch.
        batch_snr = np.random.uniform(args.snrl, args.snrh + 1)
        corrupted = noise_torch(clean, batch_snr, args.noise)
        fr_optimizer.zero_grad()
        batch_loss = fr_criterion(fr_module(corrupted), fr_target)
        batch_loss.backward()
        fr_optimizer.step()
        train_loss_sum += batch_loss.data.item()

    # ---- validation pass ---------------------------------------------------
    fr_module.eval()
    val_loss_sum = 0
    fnr_sum = 0
    for noisy, _, fr_target, true_freq in val_loader:
        if args.use_cuda:
            noisy = noisy.cuda()
            fr_target = fr_target.cuda()
        with torch.no_grad():
            fr_pred = fr_module(noisy)
        val_loss_sum += fr_criterion(fr_pred, fr_target).data.item()
        # Entries of `true_freq` that are >= -0.5 are counted per row.
        n_true = (true_freq >= -0.5).sum(dim=1)
        fr_pred_np = fr_pred.cpu().detach().numpy()
        est_freq = fr.find_freq(fr_pred_np, n_true, xgrid)
        fnr_sum += fnr(est_freq, true_freq.cpu().numpy(), args.signal_dim)

    # ---- normalisation -----------------------------------------------------
    snr_span = args.snrh - args.snrl + 1
    train_loss = train_loss_sum / (args.n_training * snr_span)
    val_norm = args.n_validation * snr_span
    val_loss = val_loss_sum / val_norm
    fnr_percent = fnr_sum * (100 / val_norm)

    # ---- reporting ---------------------------------------------------------
    for tag, value in (('fr_l2_training', train_loss),
                       ('fr_l2_validation', val_loss),
                       ('fr_FNR', fnr_percent)):
        tb_writer.add_scalar(tag, value, epoch)

    fr_scheduler.step(val_loss)

    total_epochs = args.n_epochs_fr + args.n_epochs_fc
    elapsed = time.time() - tic
    logger.info(
        "Epochs: %d / %d, Time: %.1f, FR training L2 loss %.2f, FR validation L2 loss %.2f, FNR %.2f %%",
        epoch, total_epochs, elapsed, train_loss, val_loss, fnr_percent)
Ejemplo n.º 2
0
    for k in range(len(dB)):

        data_path = os.path.join(args.data_dir, str(dB[k]) + 'dB.npy')
        if not os.path.exists(data_path):
            warnings.warn('{:.1f}dB data not in data directory.'.format(dB[k]))

        noisy_signals = np.load(data_path)
        noisy_signals = torch.tensor(noisy_signals)

        with torch.no_grad():

            # Evaluate FNR of the frequency-representation module
            model_fr_torch = fr_module(noisy_signals)
            model_fr = model_fr_torch.cpu().numpy()
            f_model = fr.find_freq(model_fr, nfreq, xgrid)
            model_fnr_arr.append(100 * loss.fnr(f_model, f, signal_dim) / num_test)

            # Evaluate accuracy of the frequency-counting module
            if args.fc_path is not None:
                model_fc = fc_module(model_fr_torch)
                model_fc = model_fc.view(model_fc.size(0))
                model_estimate = torch.round(model_fc).cpu().numpy()
                model_err = 1 - (model_estimate == nfreq).sum() / num_test
                fc_acc.append(100 * model_err)
                f_model_fc = fr.find_freq(model_fr, model_estimate, xgrid, 50)
                model_chamfer.append(loss.chamfer(f_model_fc, f) / num_test)

            # Evaluate FNR of the PSnet
            if psnet is not None:
                psnet_fr_torch = psnet(noisy_signals)
Ejemplo n.º 3
0
    nfreq = np.sum(f >= -0.5, axis=1)

    for k in range(len(dB)):
        data_path = os.path.join(args.data_dir, str(dB[k]) + 'dB.npy')
        if not os.path.exists(data_path):
            warnings.warn('{:.1f}dB data not in data directory.'.format(dB[k]))

        noisy_signals = np.load(data_path)
        noisy_signals = torch.tensor(noisy_signals).to(device)

        with torch.no_grad():
            # Evaluate FNR of the frequency-representation module
            model_fr_torch = fr_module1(noisy_signals)
            model_fr = model_fr_torch.cpu().numpy()
            f_model = fr.find_freq(model_fr, nfreq, xgrid)
            model1_fnr_arr.append(100 * loss.fnr(f_model, f, signal_dim) / num_test)
            # get fr for the second model [bias free]
            model_fr_torch = fr_module2(noisy_signals)
            model_fr = model_fr_torch.cpu().numpy()
            f_model = fr.find_freq(model_fr, nfreq, xgrid)
            model2_fnr_arr.append(100 * loss.fnr(f_model, f, signal_dim) / num_test)

    fig, ax = plt.subplots()
    ax.grid(linestyle='--', linewidth=0.5)
    ax.plot(dB, model1_fnr_arr, label='DF', marker='d', c=palette[3])
    ax.plot(dB, model2_fnr_arr, label='DF_NB', marker='d', c=palette[4])
    ax.set_xlabel('SNR (dB)')
    ax.set_ylabel('FNR (\%)')
    ax.legend()
    plt.savefig(os.path.join(args.output_dir, 'fnr.png'), bbox_inches='tight', pad_inches=0.0)