Ejemplo n.º 1
0
def save_results(mask, X, fname, t0, save_signal=True, save_noise=True, result_dir="results"):
    """Invert masked spectrograms to waveforms and write one ``.npz`` per item.

    The complex spectrogram is rebuilt as ``X[..., 0] + 1j * X[..., 1]``,
    multiplied by ``mask[..., 0]`` (signal) and/or ``mask[..., 1]`` (noise),
    and inverted with ``scipy.signal.istft`` using ``fs``, ``nperseg`` and
    ``nfft`` taken from ``Config()``.

    Args:
        mask: Array whose last axis holds the signal mask (index 0) and the
            noise mask (index 1).
        X: Array whose last axis holds the real (index 0) and imaginary
            (index 1) parts of the spectrogram. ``len(X)`` files are written.
        fname: Sequence of output names, one per item, joined onto
            ``result_dir``. A name may contain sub-directories.
        t0: Sequence of per-item values stored under the ``t0`` key.
        save_signal: If true, store the reconstructed signal under ``data``.
        save_noise: If true, store the reconstructed noise under ``noise``.
        result_dir: Output directory; created when missing.

    Returns:
        None. The results are written to disk only.
    """
    config = Config()
    istft_kwargs = dict(fs=config.fs, nperseg=config.nperseg, nfft=config.nfft, boundary='zeros')
    spectrum = X[..., 0] + X[..., 1] * 1j

    denoised_signal = None
    denoised_noise = None
    if save_signal:
        _, denoised_signal = scipy.signal.istft(spectrum * mask[..., 0], **istft_kwargs)
        # Original author's shape note: (nbt, nch, nst, nt) -> (nbt, nt, nst, nch).
        denoised_signal = np.transpose(denoised_signal, [0, 3, 2, 1])
    if save_noise:
        _, denoised_noise = scipy.signal.istft(spectrum * mask[..., 1], **istft_kwargs)
        denoised_noise = np.transpose(denoised_noise, [0, 3, 2, 1])

    # exist_ok avoids the check-then-create race when several workers start at once.
    os.makedirs(result_dir, exist_ok=True)

    for i in range(len(X)):
        out_path = os.path.join(result_dir, fname[i])
        # fname[i] may point into a sub-directory that does not exist yet.
        parent = os.path.dirname(out_path)
        if parent:
            os.makedirs(parent, exist_ok=True)
        # A disabled output is still stored (as None) so the key set stays the same.
        np.savez(
            out_path,
            data=denoised_signal[i] if save_signal else None,
            noise=denoised_noise[i] if save_noise else None,
            t0=t0[i],
        )
Ejemplo n.º 2
0
def correct_picks(picks, true_p, true_s, tol):
    """Compare predicted P/S picks against the true picks and tally the matches.

    For every item, the difference between each predicted pick index and each
    true pick index is formed. A difference whose magnitude is below
    ``tol / dt`` counts as a true positive, and differences below
    ``0.5 / dt`` are collected as residuals. ``dt`` comes from ``Config()``.

    Args:
        picks: Per-item predictions; ``picks[i][0][0]`` holds the P pick
            indices and ``picks[i][1][0]`` the S pick indices.
        true_p: Per-item sequences of true P pick indices.
        true_s: Per-item sequences of true S pick indices.
        tol: Matching tolerance, divided by ``dt`` before comparison
            (presumably seconds -- confirm against the caller).

    Returns:
        list: ``[TP_p, TP_s, nP_p, nP_s, nT_p, nT_s, diff_p, diff_s]`` where
        ``TP_*`` are true-positive counts, ``nP_*`` the numbers of predicted
        picks, ``nT_*`` the numbers of true picks and ``diff_*`` lists with
        one residual array per item.
    """
    dt = Config().dt
    if len(true_p) != len(true_s):
        print("The length of true P and S pickers are not the same")

    TP_p = TP_s = 0
    nP_p = nP_s = 0
    nT_p = nT_s = 0
    diff_p, diff_s = [], []

    for idx in range(len(true_p)):
        ref_p = true_p[idx]
        ref_s = true_s[idx]
        nT_p += len(ref_p)
        nT_s += len(ref_s)

        found_p = picks[idx][0][0]
        found_s = picks[idx][1][0]
        nP_p += len(found_p)
        nP_s += len(found_s)

        # Report items that carry more than one true pick of either phase.
        if len(ref_p) > 1 or len(ref_s) > 1:
            print(idx, picks[idx], ref_p, ref_s)

        # Pairwise differences: one row per true pick, one column per prediction.
        resid_p = np.array(found_p) - np.array(ref_p)[:, np.newaxis]
        resid_s = np.array(found_s) - np.array(ref_s)[:, np.newaxis]
        abs_p = np.abs(resid_p)
        abs_s = np.abs(resid_s)

        TP_p += np.sum(abs_p < tol / dt)
        TP_s += np.sum(abs_s < tol / dt)
        diff_p.append(resid_p[abs_p < 0.5 / dt])
        diff_s.append(resid_s[abs_s < 0.5 / dt])

    return [TP_p, TP_s, nP_p, nP_s, nT_p, nT_s, diff_p, diff_s]
Ejemplo n.º 3
0
def detect_peaks_thread(i, pred, fname=None, result_dir=None, args=None):
    """Pick P and S peaks for item ``i`` and optionally save them to disk.

    Peaks are searched with ``detect_peaks`` in ``pred[i, :, 0, 1]`` (P) and
    ``pred[i, :, 0, 2]`` (S), using a minimum peak distance of
    ``0.5 / Config().dt``.

    Args:
        i: Index of the item inside ``pred`` (and ``fname``).
        pred: Prediction array indexed as ``pred[item, sample, 0, class]``.
        fname: Optional sequence of ``bytes`` file names; ``fname[i]`` is
            decoded and joined onto ``result_dir``.
        result_dir: Optional output directory. Nothing is written unless both
            ``fname`` and ``result_dir`` are given.
        args: Optional object providing ``tp_prob`` and ``ts_prob`` as minimum
            peak heights. Without it both thresholds are 0.5.

    Returns:
        list: ``[(itp, prob_p), (its, prob_s)]`` -- the peak indices and the
        prediction values at those indices, for P and S.
    """
    if args is None:
        p_threshold, s_threshold = 0.5, 0.5
    else:
        p_threshold, s_threshold = args.tp_prob, args.ts_prob

    def _pick(channel, threshold):
        # Run the peak detector on one class channel of item i.
        return detect_peaks(pred[i, :, 0, channel],
                            mph=threshold,
                            mpd=0.5 / Config().dt,
                            show=False)

    itp = _pick(1, p_threshold)
    its = _pick(2, s_threshold)
    prob_p = pred[i, itp, 0, 1]
    prob_s = pred[i, its, 0, 2]
    picks = [(itp, prob_p), (its, prob_s)]

    if fname is None or result_dir is None:
        return picks

    def _write():
        # Store the raw prediction together with the picks.
        np.savez(os.path.join(result_dir, fname[i].decode()),
                 pred=pred[i],
                 itp=itp,
                 its=its,
                 prob_p=prob_p,
                 prob_s=prob_s)

    try:
        _write()
    except FileNotFoundError:
        # The target directory is missing: create it, then retry once.
        target_dir = os.path.dirname(os.path.join(result_dir, fname[i].decode()))
        os.makedirs(target_dir, exist_ok=True)
        _write()
    return picks
Ejemplo n.º 4
0
def set_config(args, data_reader):
    """Build a ``Config`` from command-line arguments and the data reader.

    Args:
        args: Namespace-like object providing the model and training options
            read below (``depth``, ``filters_root``, ``batch_size``, ...).
        data_reader: Object providing ``X_shape``, ``Y_shape`` and
            ``num_data``.

    Returns:
        The populated ``Config`` instance.
    """
    config = Config()

    # Shapes come from the reader; the last axis gives the channel/class count.
    config.X_shape = data_reader.X_shape
    config.n_channel = config.X_shape[-1]
    config.Y_shape = data_reader.Y_shape
    config.n_class = config.Y_shape[-1]

    # Note the rename: ``args.depth`` is stored as ``config.depths``.
    config.depths = args.depth
    for option in (
        "filters_root",
        "kernel_size",
        "pool_size",
        "dilation_rate",
        "batch_size",
        "class_weights",
        "loss_type",
        "weight_decay",
        "optimizer",
        "learning_rate",
    ):
        setattr(config, option, getattr(args, option))

    # A decay step of -1 in training mode means "one decay per epoch".
    use_epoch_length = (args.decay_step == -1) and (args.mode == 'train')
    if use_epoch_length:
        config.decay_step = data_reader.num_data // args.batch_size
    else:
        config.decay_step = args.decay_step

    for option in ("decay_rate", "momentum", "summary", "drop_rate"):
        setattr(config, option, getattr(args, option))

    return config
Ejemplo n.º 5
0
def plot_result_thread(i, epoch, preds, X, Y, figure_dir, mode="valid"):
    """Save a 4x2 diagnostic figure for item ``i`` of a batch.

    The figure shows the noisy spectrogram and waveform, the target mask
    ``Y[i, :, :, 0]`` and the predicted mask ``preds[i, :, :, 0]``, the
    spectrograms obtained with each mask, and the waveforms recovered from
    them with ``scipy.signal.istft`` (STFT settings from ``Config()``).

    Args:
        i: Item index; also used as the matplotlib figure number.
        epoch: Epoch number, used in the output file name.
        preds: Predicted masks, indexed as ``preds[i, :, :, 0]``.
        X: Spectrograms with real/imaginary parts on the last axis.
        Y: Target masks, indexed as ``Y[i, :, :, 0]``.
        figure_dir: Existing directory that receives the PNG file.
        mode: Label appended to the file name.

    Returns:
        int: Always 0.
    """
    config = Config()
    istft_kwargs = dict(fs=config.fs, nperseg=config.nperseg, nfft=config.nfft, boundary='zeros')

    spectrum = X[i, :, :, 0] + X[i, :, :, 1] * 1j
    amplitude = np.abs(spectrum)
    target_mask = Y[i, :, :, 0]
    predicted_mask = preds[i, :, :, 0]

    t, noisy_signal = scipy.signal.istft(spectrum, **istft_kwargs)
    _, ideal_denoised_signal = scipy.signal.istft(spectrum * target_mask, **istft_kwargs)
    _, denoised_signal = scipy.signal.istft(spectrum * predicted_mask, **istft_kwargs)

    plt.figure(i)
    fig_size = plt.gcf().get_size_inches()
    plt.gcf().set_size_inches(fig_size * [1.5, 1.5])

    # Row 1: noisy spectrogram and waveform.
    plt.subplot(4, 2, 1)
    plt.pcolormesh(amplitude, vmin=0, vmax=2)
    plt.title("Noisy signal")
    plt.gca().set_xticklabels([])
    plt.subplot(4, 2, 2)
    plt.plot(t, noisy_signal, 'k', label='Noisy signal', linewidth=0.5)
    # The waveform panels below reuse this range so amplitudes are comparable.
    signal_ylim = plt.gca().get_ylim()
    plt.gca().set_xticklabels([])
    plt.legend(loc='lower left')
    plt.margins(x=0)

    # Row 2: target mask and predicted mask.
    plt.subplot(4, 2, 3)
    plt.pcolormesh(target_mask, vmin=0, vmax=1)
    plt.gca().set_xticklabels([])
    plt.title("Y")
    plt.subplot(4, 2, 4)
    plt.pcolormesh(predicted_mask, vmin=0, vmax=1)
    # Raw string: "\h" is not a valid escape sequence in a normal literal.
    plt.title(r"$\hat{Y}$")
    plt.gca().set_xticklabels([])

    # Row 3: masked spectrograms.
    plt.subplot(4, 2, 5)
    plt.pcolormesh(amplitude * target_mask, vmin=0, vmax=2)
    plt.title("Ideal denoised signal")
    plt.gca().set_xticklabels([])
    plt.subplot(4, 2, 6)
    plt.pcolormesh(amplitude * predicted_mask, vmin=0, vmax=2)
    plt.title("Denoised signal")
    plt.gca().set_xticklabels([])

    # Row 4: recovered waveforms.
    plt.subplot(4, 2, 7)
    plt.plot(t, ideal_denoised_signal, 'k', label='Ideal denoised signal', linewidth=0.5)
    plt.ylim(signal_ylim)
    plt.xlabel("Time (s)")
    plt.legend(loc='lower left')
    plt.margins(x=0)
    plt.subplot(4, 2, 8)
    plt.plot(t, denoised_signal, 'k', label='Denoised signal', linewidth=0.5)
    plt.ylim(signal_ylim)
    plt.xlabel("Time (s)")
    plt.legend(loc='lower left')
    plt.margins(x=0)

    plt.tight_layout()
    plt.gcf().align_labels()
    plt.savefig(os.path.join(figure_dir, "epoch{:03d}_{:03d}_{:}.png".format(epoch, i, mode)), bbox_inches='tight')
    plt.close(i)
    return 0
Ejemplo n.º 6
0
def plot_figures(mask, X, fname, figure_dir="figures"):
    """Plot spectrograms and waveforms of the noisy input and the recovered parts.

    Only the last entry along the first two axes of ``mask`` and ``X`` is
    plotted. Two PNG files are written below ``figure_dir``:
    ``<name>_FT.png`` (three spectrogram panels) and ``<name>_wave.png``
    (three waveform panels), where ``<name>`` is ``fname`` without a trailing
    ``.npz`` extension.

    Args:
        mask: Masks; the last axis holds the signal (0) and noise (1) mask.
        X: Spectrograms; the last axis holds the real (0) and imaginary (1)
            part.
        fname: Output base name (``str``); may contain sub-directories.
        figure_dir: Output directory; created when missing.

    Returns:
        None.
    """
    config = Config()

    # plot the last channel
    mask = mask[-1, -1, ...]  # original shape note: nch, nst, nf, nt, 2 => nf, nt, 2
    X = X[-1, -1, ...]

    spectrum = X[..., 0] + X[..., 1] * 1j
    amplitude = np.abs(spectrum)
    istft_kwargs = dict(fs=config.fs, nperseg=config.nperseg, nfft=config.nfft, boundary='zeros')
    t1, noisy_signal = scipy.signal.istft(spectrum, **istft_kwargs)
    _, denoised_signal = scipy.signal.istft(spectrum * mask[..., 0], **istft_kwargs)
    _, denoised_noise = scipy.signal.istft(spectrum * mask[..., 1], **istft_kwargs)

    # Remove the extension as a suffix. str.rstrip('.npz') would strip any
    # trailing '.', 'n', 'p' or 'z' characters and mangle e.g. "station.npz".
    stem = fname[:-len('.npz')] if fname.endswith('.npz') else fname
    ft_path = os.path.join(figure_dir, stem + '_FT.png')
    wave_path = os.path.join(figure_dir, stem + '_wave.png')

    # Create the output directory up front, including any sub-directory in fname.
    if figure_dir:
        os.makedirs(figure_dir, exist_ok=True)
    out_parent = os.path.dirname(ft_path)
    if out_parent:
        os.makedirs(out_parent, exist_ok=True)

    t_FT = np.linspace(config.time_range[0], config.time_range[1], X.shape[1])
    f_FT = np.linspace(config.freq_range[0], config.freq_range[1], X.shape[0])

    box = dict(boxstyle='round', facecolor='white', alpha=1)
    text_loc = [0.05, 0.77]

    def _tag_panel(tag):
        # Write the "(i)", "(ii)", ... marker on the current axes.
        plt.text(
            text_loc[0],
            text_loc[1],
            tag,
            horizontalalignment='center',
            transform=plt.gca().transAxes,
            fontsize="medium",
            fontweight="bold",
            bbox=box,
        )

    # ---- Figure 1: spectrograms -------------------------------------------
    plt.figure()
    fig_size = plt.gcf().get_size_inches()
    plt.gcf().set_size_inches(fig_size * [1, 1.2])
    vmax = np.std(amplitude) * 1.8
    mesh_kwargs = dict(vmin=0, vmax=vmax, shading='auto')

    plt.subplot(311)
    plt.pcolormesh(t_FT, f_FT, amplitude, label='Noisy signal', **mesh_kwargs)
    plt.gca().set_xticklabels([])
    _tag_panel('(i)')

    plt.subplot(312)
    plt.pcolormesh(t_FT, f_FT, amplitude * mask[:, :, 0], label='Recovered signal', **mesh_kwargs)
    plt.gca().set_xticklabels([])
    plt.ylabel("Frequency (Hz)", fontsize='large')
    _tag_panel('(ii)')

    plt.subplot(313)
    plt.pcolormesh(t_FT, f_FT, amplitude * mask[:, :, 1], label='Recovered noise', **mesh_kwargs)
    plt.xlabel("Time (s)", fontsize='large')
    _tag_panel('(iii)')

    plt.savefig(ft_path, bbox_inches='tight')
    plt.close()

    # ---- Figure 2: waveforms ----------------------------------------------
    plt.figure()
    fig_size = plt.gcf().get_size_inches()
    plt.gcf().set_size_inches(fig_size * [1, 1.2])

    peak = np.max(np.abs(noisy_signal))
    signal_ylim = [-peak, peak]
    x_limits = [np.around(t1[0]), np.around(t1[-1])]

    def _draw_trace(position, trace, label):
        # Plot one waveform; all panels share the y-range of the noisy signal.
        plt.subplot(position)
        plt.plot(t1, trace, 'k', label=label, linewidth=0.5)
        plt.xlim(x_limits)
        # An all-zero signal gives identical limits; keep matplotlib's default then.
        if signal_ylim[0] != signal_ylim[1]:
            plt.ylim(signal_ylim)

    _draw_trace(311, noisy_signal, 'Noisy signal')
    plt.gca().set_xticklabels([])
    plt.legend(loc='lower left', fontsize='medium')
    _tag_panel('(i)')

    _draw_trace(312, denoised_signal, 'Recovered signal')
    plt.gca().set_xticklabels([])
    plt.ylabel("Amplitude", fontsize='large')
    plt.legend(loc='lower left', fontsize='medium')
    _tag_panel('(ii)')

    _draw_trace(313, denoised_noise, 'Recovered noise')
    plt.xlabel("Time (s)", fontsize='large')
    plt.legend(loc='lower left', fontsize='medium')
    _tag_panel('(iii)')

    plt.savefig(wave_path, bbox_inches='tight')
    plt.close()

    return
Ejemplo n.º 7
0
def postprocessing_pred(i, preds, X, fname, figure_dir=None, result_dir=None):
    """Save and/or plot the denoising result for item ``i`` of a batch.

    The complex spectrogram ``X[i, :, :, 0] + 1j * X[i, :, :, 1]`` is
    inverted with ``scipy.signal.istft`` as is, masked with
    ``preds[i, :, :, 0]`` (signal) and masked with ``preds[i, :, :, 1]``
    (noise). STFT settings and axis ranges come from ``Config()``.

    Args:
        i: Item index; also used as the matplotlib figure number.
        preds: Predicted masks, last axis = (signal, noise).
        X: Spectrograms, last axis = (real, imaginary).
        fname: Sequence of ``str`` names; ``fname[i]`` may contain
            sub-directories.
        figure_dir: If given, ``<name>_FT.png`` and ``<name>_wave.png`` are
            written below it (``<name>`` is ``fname[i]`` without ``.npz``).
        result_dir: If given, the waveforms and time axis are stored with
            ``np.savez`` under ``result_dir/fname[i]``.

    Returns:
        None. Nothing is computed when both directories are ``None``.
    """
    if result_dir is None and figure_dir is None:
        return

    config = Config()
    istft_kwargs = dict(fs=config.fs, nperseg=config.nperseg, nfft=config.nfft, boundary='zeros')
    spectrum = X[i, :, :, 0] + X[i, :, :, 1] * 1j
    signal_mask = preds[i, :, :, 0]
    noise_mask = preds[i, :, :, 1]

    t1, noisy_signal = scipy.signal.istft(spectrum, **istft_kwargs)
    _, denoised_signal = scipy.signal.istft(spectrum * signal_mask, **istft_kwargs)
    _, denoised_noise = scipy.signal.istft(spectrum * noise_mask, **istft_kwargs)

    if result_dir is not None:
        result_path = os.path.join(result_dir, fname[i])
        # exist_ok: another worker may create the same directory concurrently.
        result_parent = os.path.dirname(result_path)
        if result_parent:
            os.makedirs(result_parent, exist_ok=True)
        np.savez(
            result_path,
            noisy_signal=noisy_signal,
            denoised_signal=denoised_signal,
            denoised_noise=denoised_noise,
            t=t1,
        )

    if figure_dir is None:
        return

    # Remove the extension as a suffix. str.rstrip('.npz') would strip any
    # trailing '.', 'n', 'p' or 'z' characters and mangle e.g. "station.npz".
    name = fname[i]
    stem = name[:-len('.npz')] if name.endswith('.npz') else name
    ft_path = os.path.join(figure_dir, stem + '_FT.png')
    wave_path = os.path.join(figure_dir, stem + '_wave.png')
    figure_parent = os.path.dirname(ft_path)
    if figure_parent:
        os.makedirs(figure_parent, exist_ok=True)

    t_FT = np.linspace(config.time_range[0], config.time_range[1], X.shape[2])
    f_FT = np.linspace(config.freq_range[0], config.freq_range[1], X.shape[1])

    box = dict(boxstyle='round', facecolor='white', alpha=1)
    text_loc = [0.05, 0.77]

    def _tag_panel(tag):
        # Write the "(i)", "(ii)", ... marker on the current axes.
        plt.text(
            text_loc[0],
            text_loc[1],
            tag,
            horizontalalignment='center',
            transform=plt.gca().transAxes,
            fontsize="medium",
            fontweight="bold",
            bbox=box,
        )

    # ---- Figure 1: spectrograms -------------------------------------------
    plt.figure(i)
    fig_size = plt.gcf().get_size_inches()
    plt.gcf().set_size_inches(fig_size * [1, 1.2])
    amplitude = np.abs(spectrum)
    vmax = np.std(amplitude) * 1.8
    mesh_kwargs = dict(vmin=0, vmax=vmax, shading='auto')

    plt.subplot(311)
    plt.pcolormesh(t_FT, f_FT, amplitude, label='Noisy signal', **mesh_kwargs)
    plt.gca().set_xticklabels([])
    _tag_panel('(i)')

    plt.subplot(312)
    plt.pcolormesh(t_FT, f_FT, amplitude * signal_mask, label='Recovered signal', **mesh_kwargs)
    plt.gca().set_xticklabels([])
    plt.ylabel("Frequency (Hz)", fontsize='large')
    _tag_panel('(ii)')

    plt.subplot(313)
    plt.pcolormesh(t_FT, f_FT, amplitude * noise_mask, label='Recovered noise', **mesh_kwargs)
    plt.xlabel("Time (s)", fontsize='large')
    _tag_panel('(iii)')

    plt.savefig(ft_path, bbox_inches='tight')
    plt.close(i)

    # ---- Figure 2: waveforms ----------------------------------------------
    plt.figure(i)
    fig_size = plt.gcf().get_size_inches()
    plt.gcf().set_size_inches(fig_size * [1, 1.2])

    # The y-range ignores the first and last 100 samples of the noisy trace.
    trimmed_peak = np.max(np.abs(noisy_signal[100:-100]))
    signal_ylim = [-trimmed_peak, trimmed_peak]
    x_limits = [np.around(t1[0]), np.around(t1[-1])]

    def _draw_trace(position, trace, label):
        # Plot one waveform; all panels share the same axis limits.
        plt.subplot(position)
        plt.plot(t1, trace, 'k', label=label, linewidth=0.5)
        plt.xlim(x_limits)
        plt.ylim(signal_ylim)

    _draw_trace(311, noisy_signal, 'Noisy signal')
    plt.gca().set_xticklabels([])
    plt.legend(loc='lower left', fontsize='medium')
    _tag_panel('(i)')

    _draw_trace(312, denoised_signal, 'Recovered signal')
    plt.gca().set_xticklabels([])
    plt.ylabel("Amplitude", fontsize='large')
    plt.legend(loc='lower left', fontsize='medium')
    _tag_panel('(ii)')

    _draw_trace(313, denoised_noise, 'Recovered noise')
    plt.xlabel("Time (s)", fontsize='large')
    plt.legend(loc='lower left', fontsize='medium')
    _tag_panel('(iii)')

    plt.savefig(wave_path, bbox_inches='tight')
    plt.close(i)

    return
Ejemplo n.º 8
0
def postprocessing_test(
    i, preds, X, fname, figure_dir=None, result_dir=None, signal_FT=None, noise_FT=None, data_dir=None
):
    """Save and/or plot the test-time denoising result for item ``i``.

    Waveforms are reconstructed with ``scipy.signal.istft`` from the noisy
    spectrogram, from the spectrogram masked with ``preds[i, :, :, 0]``
    (signal) and with ``1 - preds[i, :, :, 0]`` (noise), and from the
    reference spectrograms ``signal_FT[i]`` and ``noise_FT[i]``.

    Args:
        i: Item index; also used as the matplotlib figure number.
        preds: Predicted masks, last axis = (signal, noise).
        X: Noisy spectrograms, last axis = (real, imaginary).
        fname: Sequence of ``bytes`` names; ``fname[i]`` is decoded.
        figure_dir: If given, ``<name>_FT.png`` and ``<name>_wave.png`` are
            written below it (``<name>`` is the decoded name without
            ``.npz``).
        result_dir: If given, all arrays are stored with ``np.savez`` under
            ``result_dir/<decoded name>``.
        signal_FT: Reference signal spectrograms. Required whenever
            ``figure_dir`` or ``result_dir`` is given.
        noise_FT: Reference noise spectrograms. Required as above.
        data_dir: Optional directory holding the raw ``.npz`` file with
            ``itp``/``its`` entries; enables zoomed insets in the waveform
            figure.

    Returns:
        None. Nothing is computed when both directories are ``None``.

    Raises:
        TypeError: If output is requested but ``signal_FT`` or ``noise_FT``
            is ``None``.
    """
    if figure_dir is None and result_dir is None:
        return
    if signal_FT is None or noise_FT is None:
        # Same exception type the unguarded indexing raised, with a usable message.
        raise TypeError("signal_FT and noise_FT are required when figure_dir or result_dir is given")

    config = Config()
    istft_kwargs = dict(fs=config.fs, nperseg=config.nperseg, nfft=config.nfft, boundary='zeros')
    spectrum = X[i, :, :, 0] + X[i, :, :, 1] * 1j
    signal_mask = preds[i, :, :, 0]

    t1, noisy_signal = scipy.signal.istft(spectrum, **istft_kwargs)
    _, denoised_signal = scipy.signal.istft(spectrum * signal_mask, **istft_kwargs)
    _, denoised_noise = scipy.signal.istft(spectrum * (1 - signal_mask), **istft_kwargs)
    _, signal = scipy.signal.istft(signal_FT[i, :, :], **istft_kwargs)
    _, noise = scipy.signal.istft(noise_FT[i, :, :], **istft_kwargs)

    name = fname[i].decode()

    if result_dir is not None:
        result_path = os.path.join(result_dir, name)
        result_parent = os.path.dirname(result_path)
        if result_parent:
            os.makedirs(result_parent, exist_ok=True)
        np.savez(
            result_path,
            preds=preds[i],
            X=X[i],
            signal_FT=signal_FT[i],
            noise_FT=noise_FT[i],
            noisy_signal=noisy_signal,
            denoised_signal=denoised_signal,
            denoised_noise=denoised_noise,
            signal=signal,
            noise=noise,
        )

    if figure_dir is None:
        return

    # Remove the extension as a suffix. str.rstrip('.npz') would strip any
    # trailing '.', 'n', 'p' or 'z' characters and mangle e.g. "station.npz".
    stem = name[:-len('.npz')] if name.endswith('.npz') else name
    ft_path = os.path.join(figure_dir, stem + '_FT.png')
    wave_path = os.path.join(figure_dir, stem + '_wave.png')
    figure_parent = os.path.dirname(ft_path)
    if figure_parent:
        os.makedirs(figure_parent, exist_ok=True)

    t_FT = np.linspace(config.time_range[0], config.time_range[1], X.shape[2])
    f_FT = np.linspace(config.freq_range[0], config.freq_range[1], X.shape[1])

    if data_dir is not None:
        # Close the archive once the two entries have been read.
        with np.load(os.path.join(data_dir, name.split('/')[-1])) as raw_data:
            itp = raw_data['itp']
            its = raw_data['its']
        # Inset window in seconds, capped at 3 s.
        # NOTE(review): presumably 100 Hz data with P at sample 750 -- confirm.
        ix1 = (750 - 50) / 100
        ix2 = (750 + (its - itp) + 50) / 100
        if ix2 - ix1 > 3:
            ix2 = ix1 + 3

    box = dict(boxstyle='round', facecolor='white', alpha=1)
    text_loc = [0.05, 0.8]

    def _tag_panel(tag):
        # Write the "(i)", "(ii)", ... marker on the current axes.
        plt.text(
            text_loc[0],
            text_loc[1],
            tag,
            horizontalalignment='center',
            transform=plt.gca().transAxes,
            fontsize="medium",
            fontweight="bold",
            bbox=box,
        )

    # ---- Figure 1: spectrograms -------------------------------------------
    plt.figure(i)
    fig_size = plt.gcf().get_size_inches()
    plt.gcf().set_size_inches(fig_size * [1, 2])
    amplitude = np.abs(spectrum)

    plt.subplot(511)
    plt.pcolormesh(t_FT, f_FT, np.abs(signal_FT[i, :, :]), vmin=0, vmax=1)
    plt.gca().set_xticklabels([])
    _tag_panel('(i)')

    plt.subplot(512)
    plt.pcolormesh(t_FT, f_FT, np.abs(noise_FT[i, :, :]), vmin=0, vmax=1)
    plt.gca().set_xticklabels([])
    _tag_panel('(ii)')

    plt.subplot(513)
    plt.pcolormesh(t_FT, f_FT, amplitude, vmin=0, vmax=1)
    plt.ylabel("Frequency (Hz)", fontsize='large')
    plt.gca().set_xticklabels([])
    _tag_panel('(iii)')

    plt.subplot(514)
    plt.pcolormesh(t_FT, f_FT, amplitude * signal_mask, vmin=0, vmax=1)
    plt.gca().set_xticklabels([])
    _tag_panel('(iv)')

    plt.subplot(515)
    # NOTE(review): this panel uses preds[..., 1] while the noise waveform
    # uses 1 - preds[..., 0]; kept as in the original.
    plt.pcolormesh(t_FT, f_FT, amplitude * preds[i, :, :, 1], vmin=0, vmax=1)
    plt.xlabel("Time (s)", fontsize='large')
    _tag_panel('(v)')

    plt.savefig(ft_path, bbox_inches='tight')
    plt.close(i)

    # ---- Figure 2: waveforms ----------------------------------------------
    plt.figure(i)
    fig_size = plt.gcf().get_size_inches()
    plt.gcf().set_size_inches(fig_size * [1, 2])

    # The y-ranges ignore the first and last 100 samples of each trace.
    signal_peak = np.max(np.abs(noisy_signal[100:-100]))
    signal_ylim = [-signal_peak, signal_peak]
    noise_peak = np.max(np.abs(noise[100:-100]))
    noise_ylim = [-noise_peak, noise_peak]
    x_limits = [np.around(t1[0]), np.around(t1[-1])]

    def _draw_trace(position, trace, label, y_limits):
        # Plot one waveform panel and return its axes.
        ax = plt.subplot(position)
        plt.plot(t1, trace, 'k', linewidth=0.5, label=label)
        plt.legend(loc='lower left', fontsize='medium')
        plt.xlim(x_limits)
        plt.ylim(y_limits)
        return ax

    ax3 = _draw_trace(513, noisy_signal, 'Noisy signal', signal_ylim)
    plt.ylabel("Amplitude", fontsize='large')
    plt.gca().set_xticklabels([])
    _tag_panel('(iii)')

    ax1 = _draw_trace(511, signal, 'Signal', signal_ylim)
    plt.gca().set_xticklabels([])
    _tag_panel('(i)')

    _draw_trace(512, noise, 'Noise', noise_ylim)
    plt.gca().set_xticklabels([])
    _tag_panel('(ii)')

    ax4 = _draw_trace(514, denoised_signal, 'Recovered signal', signal_ylim)
    plt.gca().set_xticklabels([])
    _tag_panel('(iv)')

    _draw_trace(515, denoised_noise, 'Recovered noise', noise_ylim)
    plt.xlabel("Time (s)", fontsize='large')
    _tag_panel('(v)')

    if data_dir is not None:
        # All insets share the y-range of the clean signal inside the window.
        y1 = -np.max(np.abs(signal[(t1 > ix1) & (t1 < ix2)]))
        y2 = -y1
        inset_specs = (
            (ax1, signal, 0.5),
            (ax3, noisy_signal, 0.3),
            (ax4, denoised_signal, 0.5),
        )
        for parent_ax, trace, anchor_y in inset_specs:
            axins = inset_axes(
                parent_ax,
                width=2.0,
                height=1.0,
                loc='upper right',
                bbox_to_anchor=(1, anchor_y),
                bbox_transform=parent_ax.transAxes,
            )
            axins.plot(t1, trace, 'k', linewidth=0.5)
            axins.set_xlim(ix1, ix2)
            axins.set_ylim(y1, y2)
            plt.xticks(visible=False)
            plt.yticks(visible=False)
            mark_inset(parent_ax, axins, loc1=1, loc2=3, fc="none", ec="0.5")

    plt.savefig(wave_path, bbox_inches='tight')
    plt.close(i)

    return
Ejemplo n.º 9
0
def _strip_npz_extension(name):
    """Return *name* as text with a single trailing ``.npz`` removed.

    ``str.rstrip('.npz')`` must not be used for this: it strips any run of the
    characters '.', 'n', 'p', 'z' from the end, so ``'station_zn.npz'`` would
    become ``'station_'``.
    """
    if isinstance(name, bytes):
        name = name.decode()
    suffix = '.npz'
    if name.endswith(suffix):
        return name[:-len(suffix)]
    return name


def _draw_waveform_panel(subplot_code,
                         t,
                         trace,
                         label,
                         tag,
                         dt,
                         box,
                         text_loc,
                         p_picks=None,
                         s_picks=None,
                         label_picks=False):
    """Draw one waveform subplot with optional P (blue) / S (red) pick lines.

    Pick lines span the trace's min..max amplitude. Picks are only drawn when
    both ``p_picks`` and ``s_picks`` are given. When ``label_picks`` is true,
    the first line of each phase carries a legend label ('P' / 'S').
    """
    plt.subplot(subplot_code)
    plt.plot(t, trace, 'k', label=label, linewidth=0.5)
    plt.autoscale(enable=True, axis='x', tight=True)
    y_span = [np.min(trace), np.max(trace)]
    if (p_picks is not None) and (s_picks is not None):
        for picks, color, phase in ((p_picks, 'b', 'P'), (s_picks, 'r', 'S')):
            for j, pick in enumerate(picks):
                # Label only the first line so the legend has one entry per phase.
                extra = {'label': phase} if (label_picks and j == 0) else {}
                plt.plot([pick * dt, pick * dt],
                         y_span,
                         color,
                         linewidth=0.5,
                         **extra)
    plt.ylabel('Amplitude')
    plt.legend(loc='upper right', fontsize='small')
    plt.gca().set_xticklabels([])
    plt.text(text_loc[0],
             text_loc[1],
             tag,
             horizontalalignment='center',
             transform=plt.gca().transAxes,
             fontsize="small",
             fontweight="normal",
             bbox=box)


def plot_result_thread(i,
                       pred,
                       X,
                       Y=None,
                       itp=None,
                       its=None,
                       itp_pred=None,
                       its_pred=None,
                       fname=None,
                       figure_dir=None):
    """Plot sample ``i`` as a four-panel figure and save it as a PNG.

    Panels (i)-(iii) show ``X[i, :, 0, 0..2]`` labelled E, N and Z; panel (iv)
    shows ``pred[i, :, 0, 1]`` / ``pred[i, :, 0, 2]`` (and the same channels of
    ``Y`` when given). The time axis is ``arange(pred.shape[1]) * Config().dt``.

    Args:
        i: Sample index into ``X``, ``pred``, ``Y``, ``itp``, ``its`` and
            ``fname``; also used as the matplotlib figure number.
        pred: Predicted curves, indexed as ``pred[i, :, 0, c]``.
        X: Waveforms, indexed as ``X[i, :, 0, c]``.
        Y: Optional target curves, indexed like ``pred``.
        itp, its: Optional per-sample pick indices (``itp[i]``, ``its[i]``),
            multiplied by ``dt`` for plotting. Drawn only if both are given.
        itp_pred, its_pred: Optional predicted pick indices for this sample
            (NOT indexed by ``i``). Drawn only if both are given.
        fname: Sequence of file names (bytes or str); ``fname[i]`` with a
            trailing ``.npz`` removed is the output file stem.
        figure_dir: Directory the PNG is written under; missing
            sub-directories are created.

    Returns:
        0, always.
    """
    dt = Config().dt
    t = np.arange(0, pred.shape[1]) * dt
    box = dict(boxstyle='round', facecolor='white', alpha=1)
    text_loc = [0.05, 0.77]

    # True picks are only drawn when both phases are supplied.
    if (itp is not None) and (its is not None):
        p_picks, s_picks = itp[i], its[i]
    else:
        p_picks = s_picks = None

    plt.figure(i)
    panels = (
        (411, 0, 'E', '(i)'),
        (412, 1, 'N', '(ii)'),
        (413, 2, 'Z', '(iii)'),
    )
    for subplot_code, channel, label, tag in panels:
        _draw_waveform_panel(subplot_code,
                             t,
                             X[i, :, 0, channel],
                             label,
                             tag,
                             dt,
                             box,
                             text_loc,
                             p_picks=p_picks,
                             s_picks=s_picks,
                             label_picks=(channel == 0))

    plt.subplot(414)
    if Y is not None:
        plt.plot(t, Y[i, :, 0, 1], 'b', label='P', linewidth=0.5)
        plt.plot(t, Y[i, :, 0, 2], 'r', label='S', linewidth=0.5)
    # Raw strings: '\h' is not a valid escape sequence in a normal literal.
    plt.plot(t, pred[i, :, 0, 1], '--g', label=r'$\hat{P}$', linewidth=0.5)
    plt.plot(t, pred[i, :, 0, 2], '-.m', label=r'$\hat{S}$', linewidth=0.5)
    plt.autoscale(enable=True, axis='x', tight=True)
    if (itp_pred is not None) and (its_pred is not None):
        for pick in itp_pred:
            plt.plot([pick * dt, pick * dt], [-0.1, 1.1],
                     '--g',
                     linewidth=0.5)
        for pick in its_pred:
            plt.plot([pick * dt, pick * dt], [-0.1, 1.1],
                     '-.m',
                     linewidth=0.5)
    plt.ylim([-0.05, 1.05])
    plt.text(text_loc[0],
             text_loc[1],
             '(iv)',
             horizontalalignment='center',
             transform=plt.gca().transAxes,
             fontsize="small",
             fontweight="normal",
             bbox=box)
    plt.legend(loc='upper right', fontsize='small')
    plt.xlabel('Time (s)')
    plt.ylabel('Probability')

    plt.tight_layout()
    plt.gcf().align_labels()

    out_path = os.path.join(figure_dir,
                            _strip_npz_extension(fname[i]) + '.png')
    # fname may contain sub-directories; create them before saving.
    # exist_ok=True keeps this safe when several workers share a directory.
    out_dir = os.path.dirname(out_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    plt.savefig(out_path, bbox_inches='tight')
    plt.close(i)
    return 0