Example #1
def compute_jacobian_realization(fr_module, num_samples, signal_dim, num_freq,
                                 min_sep, fixed_freq, snr):
    clean_signals, signal_frs, signal_num_freqs = gen_signal(
        num_samples=num_samples,
        signal_dim=signal_dim,
        num_freq=num_freq,
        min_sep=min_sep,
        fixed_freq=fixed_freq)
    noisy_signals = noise_torch(torch.as_tensor(clean_signals), snr,
                                'gaussian')
    jacobian_realization = []
    jacobian_realization_clean = []
    for idx in range(clean_signals.shape[0]):
        clean_signal, signal_fr, signal_num_freq = (
            clean_signals[idx], signal_frs[idx], signal_num_freqs[idx])
        noisy_signal = noisy_signals[idx].cpu().numpy()
        clean_signal_t = clean_signal[0] + clean_signal[1] * 1j
        clean_signal_fft = np.fft.fft(clean_signal_t, n=1000)
        clean_signal_fft = np.fft.fftshift(clean_signal_fft)
        noisy_signal_t = noisy_signal[0] + 1j * noisy_signal[1]
        noisy_signal_fft = np.fft.fft(noisy_signal_t, n=1000)
        noisy_signal_fft = np.fft.fftshift(noisy_signal_fft)
        noisy_signal = torch.as_tensor(noisy_signal).unsqueeze(dim=0)
        jacobian, inputs, outputs = compute_jacobian_and_bias(
            noisy_signal, fr_module)
        clean_signal = torch.as_tensor(clean_signal).unsqueeze(dim=0)
        jacobian_clean, inputs, outputs = compute_jacobian_and_bias(
            clean_signal, fr_module)
        jacobian_realization.append(jacobian)
        jacobian_realization_clean.append(jacobian_clean)
    return jacobian_realization, jacobian_realization_clean
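compute_jacobian_and_bias and fr_module are defined outside this snippet. Below is a minimal sketch of the kind of input-Jacobian computation it suggests, using torch.autograd.functional.jacobian and a toy stand-in module; both are illustrative assumptions, not the repository's implementation.

import torch
import torch.nn as nn

def jacobian_wrt_input(module, signal):
    # Differentiate the module output with respect to a single input signal.
    module.eval()
    return torch.autograd.functional.jacobian(module, signal)

# Toy stand-in for fr_module: a 2 x 50 real/imag input mapped to 1000 output bins.
toy_module = nn.Sequential(nn.Flatten(start_dim=0), nn.Linear(2 * 50, 1000))
jac = jacobian_wrt_input(toy_module, torch.randn(2, 50))
print(jac.shape)  # torch.Size([1000, 2, 50])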
Example #2
def train_frequency_representation(args, fr_module, fr_optimizer, fr_criterion,
                                   fr_scheduler, train_loader, val_loader,
                                   xgrid, epoch, tb_writer):
    """
    Train the frequency-representation module for one epoch
    """
    epoch_start_time = time.time()
    fr_module.train()
    loss_train_fr = 0
    for batch_idx, (clean_signal, target_fr, freq) in enumerate(train_loader):
        if args.use_cuda:
            clean_signal, target_fr = clean_signal.cuda(), target_fr.cuda()
        snr = np.random.uniform(args.snrl, args.snrh + 1)
        noisy_signal = noise_torch(clean_signal, snr, args.noise)
        fr_optimizer.zero_grad()
        output_fr = fr_module(noisy_signal)
        loss_fr = fr_criterion(output_fr, target_fr)
        loss_fr.backward()
        fr_optimizer.step()
        loss_train_fr += loss_fr.data.item()

    fr_module.eval()
    loss_val_fr, fnr_val = 0, 0
    for batch_idx, (noisy_signal, _, target_fr, freq) in enumerate(val_loader):
        if args.use_cuda:
            noisy_signal, target_fr = noisy_signal.cuda(), target_fr.cuda()
        with torch.no_grad():
            output_fr = fr_module(noisy_signal)
        loss_fr = fr_criterion(output_fr, target_fr)
        loss_val_fr += loss_fr.data.item()
        nfreq = (freq >= -0.5).sum(dim=1)
        f_hat = fr.find_freq(output_fr.cpu().detach().numpy(), nfreq, xgrid)
        fnr_val += fnr(f_hat, freq.cpu().numpy(), args.signal_dim)

    loss_train_fr /= (args.n_training * (args.snrh - args.snrl + 1))
    loss_val_fr /= (args.n_validation * (args.snrh - args.snrl + 1))
    fnr_val *= 100 / (args.n_validation * (args.snrh - args.snrl + 1))

    tb_writer.add_scalar('fr_l2_training', loss_train_fr, epoch)
    tb_writer.add_scalar('fr_l2_validation', loss_val_fr, epoch)
    tb_writer.add_scalar('fr_FNR', fnr_val, epoch)

    fr_scheduler.step(loss_val_fr)
    logger.info(
        "Epochs: %d / %d, Time: %.1f, FR training L2 loss %.2f, FR validation L2 loss %.2f, FNR %.2f %%",
        epoch, args.n_epochs_fr + args.n_epochs_fc,
        time.time() - epoch_start_time, loss_train_fr, loss_val_fr, fnr_val)
Example #3
def train_frequency_counting(args, fr_module, fc_module, fc_optimizer,
                             fc_criterion, fc_scheduler, train_loader,
                             val_loader, epoch, tb_writer):
    """
    Train the frequency-counting module for one epoch
    """
    epoch_start_time = time.time()
    fr_module.eval()
    fc_module.train()
    loss_train_fc, acc_train_fc = 0, 0
    for batch_idx, (clean_signal, target_fr, freq) in enumerate(train_loader):
        if args.use_cuda:
            clean_signal, target_fr, freq = (
                clean_signal.cuda(), target_fr.cuda(), freq.cuda())
        snr = np.random.uniform(args.snrl, args.snrh + 1)

        noisy_signal = noise_torch(clean_signal, snr, args.noise)
        nfreq = (freq >= -0.5).sum(dim=1)
        if args.use_cuda:
            nfreq = nfreq.cuda()
        if args.fc_module_type == 'classification':
            nfreq = nfreq - 1
        with torch.no_grad():
            output_fr = fr_module(noisy_signal)
            output_fr = output_fr.detach()
        output_fc = fc_module(output_fr)
        if args.fc_module_type == 'regression':
            output_fc = output_fc.view(output_fc.size(0))
            nfreq = nfreq.float()
        loss_fc = fc_criterion(output_fc, nfreq)
        if args.fc_module_type == 'classification':
            estimate = output_fc.max(1)[1]
        else:
            estimate = torch.round(output_fc)
        acc_train_fc += estimate.eq(nfreq).sum().cpu().item()
        loss_train_fc += loss_fc.data.item()

        fc_optimizer.zero_grad()
        loss_fc.backward()
        fc_optimizer.step()

    loss_train_fc /= args.n_training
    acc_train_fc *= 100 / args.n_training

    fc_module.eval()

    loss_val_fc = 0
    acc_val_fc = 0
    for batch_idx, (noisy_signal, _, target_fr, freq) in enumerate(val_loader):
        if args.use_cuda:
            noisy_signal, target_fr = noisy_signal.cuda(), target_fr.cuda()
        nfreq = (freq >= -0.5).sum(dim=1)
        if args.use_cuda:
            nfreq = nfreq.cuda()
        if args.fc_module_type == 'classification':
            nfreq = nfreq - 1
        with torch.no_grad():
            output_fr = fr_module(noisy_signal)
            output_fc = fc_module(output_fr)
        if args.fc_module_type == 'regression':
            output_fc = output_fc.view(output_fc.size(0))
            nfreq = nfreq.float()
        loss_fc = fc_criterion(output_fc, nfreq)
        if args.fc_module_type == 'regression':
            estimate = torch.round(output_fc)
        elif args.fc_module_type == 'classification':
            estimate = torch.argmax(output_fc, dim=1)

        acc_val_fc += estimate.eq(nfreq).sum().item()
        loss_val_fc += loss_fc.data.item()

    loss_val_fc /= args.n_validation
    acc_val_fc *= 100 / args.n_validation

    fc_scheduler.step(acc_val_fc)

    tb_writer.add_scalar('fc_loss_training', loss_train_fc,
                         epoch - args.n_epochs_fr)
    tb_writer.add_scalar('fc_loss_validation', loss_val_fc,
                         epoch - args.n_epochs_fr)
    tb_writer.add_scalar('fc_accuracy_training', acc_train_fc,
                         epoch - args.n_epochs_fr)
    tb_writer.add_scalar('fc_accuracy_validation', acc_val_fc,
                         epoch - args.n_epochs_fr)

    logger.info(
        "Epochs: %d / %d, Time: %.1f, Training fc loss: %.2f, Vadidation fc loss: %.2f, "
        "Training accuracy: %.2f %%, Validation accuracy: %.2f %%", epoch,
        args.n_epochs_fr + args.n_epochs_fc,
        time.time() - epoch_start_time, loss_train_fc, loss_val_fc,
        acc_train_fc, acc_val_fc)
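The two fc_module_type branches pair a different loss with a different estimate rule. A small self-contained sketch of those pairings; the concrete loss functions are assumptions, while the estimate rules (argmax vs. round) mirror the code above.

import torch
import torch.nn as nn

# Classification head: one logit per count class, counts shifted by -1 as above.
logits = torch.randn(4, 10)
nfreq_cls = torch.tensor([2, 5, 1, 9])
cls_loss = nn.CrossEntropyLoss()(logits, nfreq_cls)
cls_estimate = logits.max(1)[1]          # argmax over classes

# Regression head: a single scalar per signal, rounded to the nearest integer count.
preds = torch.randn(4)
nfreq_reg = torch.tensor([3., 6., 2., 10.])
reg_loss = nn.MSELoss()(preds, nfreq_reg)
reg_estimate = torch.round(preds)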
Example #4
def generate_report_plots(num_samples=2,
                          signal_dim=50,
                          min_sep=1.,
                          snr=30,
                          fixed_freq=[0.4],
                          save=False):
    if not os.path.exists('./plots'):
        os.mkdir('./plots')
    save_dir = './plots'
    num_freq = len(fixed_freq)
    for signal_idx in range(num_samples):
        clean_signals, signal_frs, signal_num_freqs = gen_signal(
            num_samples=num_samples,
            signal_dim=signal_dim,
            num_freq=num_freq,
            min_sep=min_sep,
            fixed_freq=fixed_freq)
        noisy_signals = noise_torch(torch.as_tensor(clean_signals), snr,
                                    'gaussian')

        clean_signal, signal_fr, signal_num_freq = (
            clean_signals[signal_idx], signal_frs[signal_idx],
            signal_num_freqs[signal_idx])
        noisy_signal = noisy_signals[signal_idx].cpu().numpy()
        clean_signal_t = clean_signal[0] + clean_signal[1] * 1j
        clean_signal_fft = np.fft.fft(clean_signal_t, n=1000)
        clean_signal_fft = np.fft.fftshift(clean_signal_fft)
        noisy_signal_t = noisy_signal[0] + 1j * noisy_signal[1]
        noisy_signal_fft = np.fft.fft(noisy_signal_t, n=1000)
        noisy_signal_fft = np.fft.fftshift(noisy_signal_fft)
        if signal_idx == 0:  # plot clean first
            noisy_signal = torch.as_tensor(clean_signal).unsqueeze(dim=0)
        else:
            noisy_signal = torch.as_tensor(noisy_signal).unsqueeze(dim=0)

        # fr_module is not defined in this function; it is assumed to be
        # available in the enclosing scope.
        jacobian, inputs, outputs = compute_jacobian_and_bias(
            noisy_signal, fr_module)
        fft_filter = jacobian[0] - 1j * jacobian[1]

        # plot 1
        fig1, ax = plt.subplots(3, 1, figsize=(15, 6))
        xgrid = np.linspace(-0.5, 0.5, fr_module.fr_size, endpoint=False)
        ax[0].plot(xgrid, outputs[0])
        ax[0].set_xticks(np.arange(-0.5, 0.5, 0.2))
        ylim = ax[0].get_ylim()
        for i in range(signal_fr.shape[0]):
            ax[0].vlines(signal_fr[i],
                         ymin=ylim[0],
                         ymax=ylim[1],
                         label='target:{:4.2f}'.format(signal_fr[i]))
        ax[0].legend()
        ax[0].set_xlim(-0.5, 0.5)
        ax[1].plot(xgrid, np.abs(clean_signal_fft), label='clean fft')
        ax[1].plot(xgrid, np.abs(noisy_signal_fft), '--', label='noisy fft')
        ax[1].set_xlim(-0.5, 0.5)
        ylim = ax[1].get_ylim()
        for i in range(signal_fr.shape[0]):
            ax[1].vlines(signal_fr[i],
                         ymin=ylim[0],
                         ymax=ylim[1],
                         label='target:{:4.2f}'.format(signal_fr[i]))
        ax[1].legend()
        im = ax[2].imshow(np.abs(fft_filter))
        ax[2].set_aspect(2.2)
        ax[2].set_ylabel('jacobian', fontsize=13)

        fig1.subplots_adjust(right=0.9)
        cbar_ax = fig1.add_axes([0.91, 0.15, 0.01, 0.7])
        fig1.colorbar(im, cax=cbar_ax)
        plt.show()
        # plot 2
        fig2, ax = plt.subplots(3, 1, figsize=(15, 6))
        xgrid = np.linspace(-0.5, 0.5, fr_module.fr_size, endpoint=False)
        ax[0].plot(xgrid, outputs[0])
        ax[0].set_xticks(np.arange(-0.5, 0.5, 0.2))
        ylim = ax[0].get_ylim()
        for i in range(signal_fr.shape[0]):
            ax[0].vlines(signal_fr[i],
                         ymin=ylim[0],
                         ymax=ylim[1],
                         label='target:{:4.2f}'.format(signal_fr[i]))
        ax[0].legend()

        ax[1].plot(xgrid, np.abs(clean_signal_fft), label='clean fft')
        ax[1].plot(xgrid, np.abs(noisy_signal_fft), '--', label='noisy fft')
        ylim = ax[1].get_ylim()
        for i in range(signal_fr.shape[0]):
            ax[1].vlines(signal_fr[i],
                         ymin=ylim[0],
                         ymax=ylim[1],
                         label='target:{:4.2f}'.format(signal_fr[i]))
        ax[1].legend()
        fft_filter_norm = fft_filter * np.conjugate(fft_filter)
        fft_filter_norm = fft_filter_norm.T.sum(axis=1).real
        ax[2].plot(xgrid, fft_filter_norm)
        ax[0].set_xlim(-0.5, 0.5)
        ax[1].set_xlim(-0.5, 0.5)
        ax[2].set_xlim(-0.5, 0.5)

        fig2.subplots_adjust(right=0.9)
        plt.show()

        # plot 3
        indices = find_neariest_idx(signal_fr, xgrid)
        target_filter = fft_filter.T[indices]
        # time domain
        fig3, ax = plt.subplots(target_filter.shape[0], 2)
        for idx, filt in enumerate(target_filter):
            if target_filter.shape[0] > 1:
                ax[idx, 0].plot(filt.real)
                ax[idx, 1].plot(filt.imag)
                ax[idx, 0].set_title('Real Part of Freq:{:4.2f};'.format(
                    signal_fr[idx]))
                ax[idx, 1].set_title('Imaginary Part of Freq:{:4.2f};'.format(
                    signal_fr[idx]))
            else:
                ax[0].plot(filt.real)
                ax[1].plot(filt.imag)
                ax[0].set_title('Real Part of Freq:{:4.2f};'.format(
                    signal_fr[idx]))
                ax[1].set_title('Imaginary Part of Freq:{:4.2f};'.format(
                    signal_fr[idx]))
        plt.tight_layout()
        plt.show()

        # FFT domain of the target filters
        fig4, ax = plt.subplots(target_filter.shape[0], 1, dpi=300)
        for idx, filt in enumerate(target_filter):
            filt_fft = np.fft.fft(filt, n=1000)
            filt_fft = np.fft.fftshift(filt_fft)
            magnitude = np.abs(filt_fft)
            if target_filter.shape[0] > 1:
                ax[idx].plot(xgrid, magnitude)
                ax[idx].plot(signal_fr[idx], magnitude[indices[idx]], '*')
                ax[idx].plot(-signal_fr[idx], magnitude[999 - indices[idx]],
                             '*')
                ax[idx].set_title('Freq:{:4.2f}'.format(signal_fr[idx]))
            else:
                ax.plot(xgrid, magnitude)
                ax.plot(signal_fr[idx], magnitude[indices[idx]], '*')
                ax.plot(-signal_fr[idx], magnitude[999 - indices[idx]], '*')
                ax.set_title('Freq:{:4.2f}'.format(signal_fr[idx]))
        plt.tight_layout()
        plt.show()

        if save:
            file_name = 'output_clean_{}.pdf' if signal_idx == 0 else 'output_{}.pdf'
            pdf = matplotlib.backends.backend_pdf.PdfPages(
                "./plots/" + file_name.format(signal_idx))
            pdf.savefig(fig1)
            pdf.savefig(fig2)
            pdf.savefig(fig3)
            pdf.savefig(fig4)
            pdf.close()
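The spectra plotted above follow one convention: zero-pad the length-50 signal to 1000 bins, fftshift, and read magnitudes against a grid on [-0.5, 0.5) (sharing xgrid with the network output also requires fr_module.fr_size to equal the FFT length). A short self-contained sketch of that convention with an assumed single-tone signal:

import numpy as np

signal_dim, n_fft = 50, 1000
t = np.arange(signal_dim)
x = np.exp(2j * np.pi * 0.4 * t)                       # single tone at f = 0.4 (assumed)
spectrum = np.fft.fftshift(np.fft.fft(x, n=n_fft))     # zero-padded, centered spectrum
xgrid = np.linspace(-0.5, 0.5, n_fft, endpoint=False)  # same as np.fft.fftshift(np.fft.fftfreq(n_fft))
peak = xgrid[np.argmax(np.abs(spectrum))]              # ~0.4
print(peak)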
Example #5
            .format(args.output_dir))
    elif not os.path.exists(args.output_dir):
        os.makedirs(args.output_dir)

    with open(os.path.join(args.output_dir, 'data.args'), 'w') as f:
        json.dump(args.__dict__, f, indent=2)

    np.random.seed(args.numpy_seed)
    torch.manual_seed(args.torch_seed)

    s, f, nfreq = gen_signal(num_samples=args.n_test,
                             signal_dim=args.signal_dimension,
                             num_freq=args.max_freq,
                             min_sep=args.minimum_separation,
                             distance=args.distance,
                             amplitude=args.amplitude,
                             floor_amplitude=args.floor_amplitude,
                             variable_num_freq=True)

    np.save(os.path.join(args.output_dir, 'infdB'), s)
    np.save(os.path.join(args.output_dir, 'f'), f)

    eval_snrs = [np.exp(np.log(10) * float(x) / 10) for x in args.dB]

    for k, snr in enumerate(eval_snrs):
        noisy_signals = noise.noise_torch(torch.tensor(s), snr,
                                          'gaussian').cpu()
        np.save(
            os.path.join(args.output_dir, '{}dB'.format(float(args.dB[k]))),
            noisy_signals)
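The eval_snrs comprehension above converts dB values to linear SNRs before they are passed to noise_torch; np.exp(np.log(10) * x / 10) is just 10 ** (x / 10). A quick self-contained check with assumed dB values:

import numpy as np

for db in [0.0, 10.0, 20.0, 30.0]:
    linear = np.exp(np.log(10) * db / 10)
    assert np.isclose(linear, 10 ** (db / 10))
    print('{:5.1f} dB -> linear SNR {:g}'.format(db, linear))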