示例#1
0
def generate_training_data(pick_list, dataset, database, chunk_size):
    """
    Generate TFRecords from database.

    ``pick_list`` is split into chunks of ``chunk_size`` picks with
    ``utils.batch``. Each chunk is turned into examples via
    ``utils.parallel(..., func=get_example_list)`` and written to its own
    numbered file (``00000.tfrecord``, ``00001.tfrecord``, ...) under
    ``DATASET_ROOT/<dataset>``.

    :param pick_list: List of picks from Pick SQL query.
    :param str dataset: Output directory name.
    :param str database: SQL database.
    :param int chunk_size: Number of picks per TFRecord file.
    """
    config = utils.get_config()
    dataset_dir = os.path.join(config['DATASET_ROOT'], dataset)
    utils.make_dirs(dataset_dir)

    # Ceiling division: int(a / b) under-reports by one whenever the final
    # chunk is partial. (Assumes utils.batch yields that trailing short chunk.)
    total_batch = -(-len(pick_list) // chunk_size)
    flatten = itertools.chain.from_iterable
    batch_picks = utils.batch(pick_list, size=chunk_size)
    for index, picks in enumerate(batch_picks):
        example_list = utils.parallel(picks,
                                      func=get_example_list,
                                      database=database)
        # The parallel results are nested two levels deep; flatten both.
        flat_list = list(flatten(flatten(example_list)))

        file_name = f'{index:0>5}.tfrecord'
        save_file = os.path.join(dataset_dir, file_name)
        io.write_tfrecord(flat_list, save_file)
        print(f'output {file_name} / {total_batch}')
示例#2
0
文件: logger.py 项目: jyyjqq/SeisNN
def save_loss(loss_buffer, title, save_dir):
    """Append ``loss_buffer`` as text rows to ``<save_dir>/<title>.log``.

    ``save_dir`` is created first. The file is opened in binary append
    mode, so successive calls accumulate rows in the same log.
    """
    make_dirs(save_dir)
    losses = np.asarray(loss_buffer)
    log_path = os.path.join(save_dir, f'{title}.log')
    with open(log_path, 'ab') as log_file:
        np.savetxt(log_file, losses)
示例#3
0
def plot_loss(log_file, save_dir=None):
    """
    Plot the train/validation loss history stored in a text log file.

    Each line of the log holds whitespace-separated numbers. The first
    column is plotted as ``train``; the second, when present, is plotted as
    ``validation``.

    :param str log_file: Path to the log file.
    :param str save_dir: If given, save ``<log file stem>.png`` there and
        close the figure; otherwise show it.
    """
    # loadtxt splits on any whitespace and skips blank lines. ndmin=2 keeps
    # single-row and single-column logs two-dimensional for the indexing below.
    loss = np.loadtxt(log_file, dtype=np.float32, ndmin=2)
    # splitext strips only the final extension, so stems that contain dots
    # (e.g. "lr0.001.log") are kept whole.
    name = os.path.splitext(os.path.basename(log_file))[0]

    fig = plt.figure(figsize=(8, 4))
    ax = fig.add_subplot(111)

    ax.plot(loss[:, 0], label='train')
    if loss.shape[1] > 1:
        ax.plot(loss[:, 1], label='validation')
    plt.xlabel('Steps')
    plt.ylabel('Loss')
    ax.legend()
    plt.title(f'{name} loss')

    if save_dir:
        make_dirs(save_dir)
        plt.savefig(os.path.join(save_dir, f'{name}.png'))
        plt.close()
    else:
        plt.show()
示例#4
0
文件: plot.py 项目: zhandyg/SeisNN
def plot_dataset(instance, title=None, save_dir=None):
    """
    Plot trace and label.

    The bottom panel shows the ``label`` and ``predict`` curves for up to
    the first two phases. One panel per channel shows the trace, with a
    vertical line at every peak found in those curves: dotted for
    ``label``, solid for ``predict``.

    :param instance: Object exposing ``starttime``, ``id``, ``channel``,
        ``phase``, ``trace``, ``label`` and ``predict``. The array
        attributes are indexed here as ``[-1, :, k]``.
    :param str title: Figure title and output file stem. Defaults to
        ``f'{starttime}_{id[:-3]}'``.
    :param str save_dir: If given, save ``<title>.png`` there and close the
        figure; otherwise show it.
    """
    if title is None:
        title = f'{instance.starttime}_{instance.id[:-3]}'

    label_types = ['label', 'predict']
    lines_shape = [':', '-']
    phases = instance.phase[0:2]
    threshold = 0.5
    # Every curve is drawn against this axis; peak markers reuse it below.
    time_axis = np.asarray(get_time_array(instance))

    subplot = len(instance.channel) + 1
    fig = plt.figure(figsize=(8, subplot * 2))

    # plot label
    ax = fig.add_subplot(subplot, 1, subplot)
    ax.set_ylim([-0.05, 1.05])
    ax.hlines(threshold, 0, 30, lw=1, linestyles='--')
    # peak_flag[i][j] holds the peak sample indices of curve type i
    # (label/predict) for phase j. Nested lists are used because each curve
    # can have any number of peaks; a fixed-shape array cannot hold that.
    peak_flag = []
    for i, label in enumerate(label_types):
        curves = getattr(instance, label)
        phase_peaks = []
        for j, phase in enumerate(phases):
            curve = curves[-1, :, j]
            ax.plot(time_axis,
                    curve,
                    color=color_palette(j, i),
                    label=f'{phase} {label}')
            peaks, _ = find_peaks(curve, distance=100, height=threshold)
            phase_peaks.append(peaks)
        peak_flag.append(phase_peaks)
    if len(phases) > 0:
        ax.legend()

    # plot trace
    for i, chan in enumerate(instance.channel):
        ax = fig.add_subplot(subplot, 1, i + 1)
        ax.set_ylim([-1.05, 1.05])
        if i == 0:
            plt.title(title[0:-2])
        trace = instance.trace[-1, :, i]
        ax.plot(time_axis, trace, "k-", label=chan)
        for j, phase_peaks in enumerate(peak_flag):
            for k, peaks in enumerate(phase_peaks):
                if len(peaks) == 0:
                    continue
                # Map sample indices onto the plotted time axis instead of
                # assuming a fixed 100 Hz sampling rate.
                ax.vlines(time_axis[peaks], -1.05, 1.05,
                          colors=color_palette(k, j),
                          linestyles=lines_shape[j])
        ax.legend(loc=1)

    if save_dir:
        utils.make_dirs(save_dir)
        plt.savefig(os.path.join(save_dir, f'{title}.png'))
        plt.close()
    else:
        plt.show()
示例#5
0
    def get_model_dir(model_instance, remove=False):
        """
        Resolve and create the directories used to store a model.

        :param str model_instance: Model name, joined onto ``MODELS_ROOT``.
        :param bool remove: If True, delete any existing model directory
            first (errors ignored), then recreate it.
        :return: ``(model_dir, history_dir)``, where ``history_dir`` is the
            ``history`` sub-directory of ``model_dir``.
        """
        model_dir = os.path.join(utils.get_config()['MODELS_ROOT'],
                                 model_instance)
        if remove:
            shutil.rmtree(model_dir, ignore_errors=True)

        history_dir = os.path.join(model_dir, "history")
        for directory in (model_dir, history_dir):
            utils.make_dirs(directory)
        return model_dir, history_dir
示例#6
0
def save_loss(loss_buffer, title, save_dir):
    """
    Append a loss history to ``<save_dir>/<title>.log``.

    :param loss_buffer: Loss values, converted with ``np.asarray`` and
        written with ``np.savetxt``.
    :param str title: Log file stem (``.log`` is appended).
    :param str save_dir: Output directory, created if missing.
    """
    utils.make_dirs(save_dir)
    history = np.asarray(loss_buffer)
    log_path = os.path.join(save_dir, f'{title}.log')
    # Binary append: earlier rows are kept and new rows are added at the end.
    with open(log_path, 'ab') as log_file:
        np.savetxt(log_file, history)
示例#7
0
def plot_error_distribution(time_residuals, save_dir=None):
    """
    Plot a histogram of pick time residuals.

    The histogram uses 100 evenly spaced edges (99 bins) over [-0.5, 0.5] s.
    Residuals outside that range are not counted.

    :param time_residuals: Sequence of time residuals in seconds.
    :param str save_dir: If given, save the figure as
        ``error_distribution.png`` in this directory and close it; otherwise
        show it.
    """
    bins = np.linspace(-0.5, 0.5, 100)
    plt.hist(time_residuals, bins=bins)
    plt.xticks(np.arange(-0.5, 0.51, step=0.1))
    plt.xlabel("Time residuals (sec)")
    plt.ylabel("Counts")
    plt.title("Error Distribution")

    if save_dir:
        make_dirs(save_dir)
        plt.savefig(os.path.join(save_dir, 'error_distribution.png'))
        plt.close()
    else:
        plt.show()
示例#8
0
def plot_snr_distribution(pick_snr, save_dir=None):
    """
    Plot a histogram of pick signal-to-noise ratios.

    :param pick_snr: Sequence of SNR values (log10 scale, per the axis
        label). Values outside [-1, 10] fall outside the bins.
    :param str save_dir: If given, save the figure as
        ``snr_distribution.png`` in this directory and close it; otherwise
        show it.
    """
    # sns.set() applies the seaborn theme globally, so it also affects
    # figures drawn after this call.
    sns.set()
    bins = np.linspace(-1, 10, 55)
    plt.hist(pick_snr, bins=bins)
    plt.xticks(np.arange(-1, 11, step=1))
    plt.xlabel("Signal to Noise Ratio (log10)")
    plt.ylabel("Counts")
    plt.title("SNR Distribution")

    if save_dir:
        make_dirs(save_dir)
        # Use a distinct file name; saving as error_distribution.png would
        # overwrite the residual histogram stored in the same directory.
        plt.savefig(os.path.join(save_dir, 'snr_distribution.png'))
        plt.close()
    else:
        plt.show()
示例#9
0
文件: io.py 项目: jyyjqq/SeisNN
def write_training_dataset(pick_list, geom, dataset, pickset):
    """
    Write the examples built from ``pick_list`` into a single TFRecord file.

    The file is named ``<station>.tfrecord`` after the station code of the
    first pick, and is placed in ``DATASET_ROOT/<dataset>``. When
    ``pick_list`` is empty, only the directory is created and nothing is
    written.

    :param pick_list: Picks exposing ``time`` and
        ``waveform_id.station_code``.
    :param geom: Forwarded unchanged to ``_write_picked_stream``.
    :param str dataset: Output dataset directory name.
    :param pickset: Forwarded unchanged to ``_write_picked_stream``.
    """
    config = get_config()
    dataset_dir = os.path.join(config['DATASET_ROOT'], dataset)
    make_dirs(dataset_dir)

    if len(pick_list) == 0:
        # No examples to write, and no station code to name the file after.
        return

    pick_time_key = [pick.time for pick in pick_list]

    par = partial(_write_picked_stream,
                  pick_list=pick_list,
                  pick_time_key=pick_time_key,
                  geom=geom,
                  pickset=pickset)

    example_list = parallel(par, pick_list)

    station = pick_list[0].waveform_id.station_code
    file_name = f'{station}.tfrecord'
    save_file = os.path.join(dataset_dir, file_name)

    write_tfrecord(example_list, save_file)
示例#10
0
文件: predict.py 项目: zhandyg/SeisNN
from seisnn.model.settings import model, optimizer

# Command-line interface: input dataset, output dataset and model name.
ap = argparse.ArgumentParser()
ap.add_argument('-i', '--input', required=True, help='input dataset', type=str)
ap.add_argument('-o',
                '--output',
                required=True,
                help='output dataset',
                type=str)
ap.add_argument('-m', '--model', required=True, help='model', type=str)
args = ap.parse_args()

config = get_config()

# Checkpoint directory of the model to restore (created if missing).
MODEL_PATH = os.path.join(config['MODELS_ROOT'], args.model)
make_dirs(MODEL_PATH)

OUTPUT_DATASET = os.path.join(config['DATASET_ROOT'], args.output)
make_dirs(OUTPUT_DATASET)

INPUT_DATASET = os.path.join(config['DATASET_ROOT'], args.input)
dataset = read_dataset(INPUT_DATASET)

ckpt = tf.train.Checkpoint(model=model, optimizer=optimizer)
ckpt_manager = tf.train.CheckpointManager(ckpt, MODEL_PATH, max_to_keep=100)

if ckpt_manager.latest_checkpoint:
    ckpt.restore(ckpt_manager.latest_checkpoint)
    last_epoch = len(ckpt_manager.checkpoints)
    print(f'Latest checkpoint epoch {last_epoch} restored!!')
else:
    # Without a checkpoint the model keeps its initial weights, so any
    # prediction made with it is meaningless. Make that visible instead of
    # continuing silently.
    print(f'WARNING: no checkpoint found in {MODEL_PATH}; '
          'the model is untrained.')
示例#11
0
文件: training.py 项目: jyyjqq/SeisNN
from seisnn.example_proto import batch_iterator

# Command-line interface: dataset name, pre-trained model and output model.
ap = argparse.ArgumentParser()
ap.add_argument('-d', '--dataset', required=True, help='dataset', type=str)
# NOTE(review): --pre_train is required but is not read in the lines shown
# here; presumably it is used further down the script. Confirm, or make it
# optional.
ap.add_argument('-p',
                '--pre_train',
                required=True,
                help='pre-train model',
                type=str)
ap.add_argument('-m', '--model', required=True, help='save model', type=str)
args = ap.parse_args()

config = get_config()

# Output locations: MODELS_ROOT/<model> for checkpoints and its "history"
# sub-directory. Both are created if missing.
SAVE_MODEL_PATH = os.path.join(config['MODELS_ROOT'], args.model)
make_dirs(SAVE_MODEL_PATH)
SAVE_HISTORY_PATH = os.path.join(SAVE_MODEL_PATH, "history")
make_dirs(SAVE_HISTORY_PATH)

# Read the dataset and drop its first 1000 elements.
# NOTE(review): presumably a tf.data.Dataset (it has .skip/.batch); confirm
# against read_dataset.
dataset_dir = os.path.join(config['DATASET_ROOT'], args.dataset)
dataset = read_dataset(dataset_dir).skip(1000)

# Take one batch of size 1 as a fixed validation sample. The indexing keeps
# only index 0 along axis 3 and re-inserts a size-1 axis there, so the rank
# is unchanged.
# NOTE(review): this sample comes from the same dataset that is presumably
# used for training below; confirm it is not also trained on.
val = next(iter(dataset.batch(1)))
val_trace = val['trace'][:, :, :, 0, tf.newaxis]
val_pdf = val['pdf'][:, :, :, 0, tf.newaxis]

# Checkpoint the model/optimizer pair, keeping at most 100 checkpoints.
ckpt = tf.train.Checkpoint(model=model, optimizer=optimizer)
ckpt_manager = tf.train.CheckpointManager(ckpt,
                                          SAVE_MODEL_PATH,
                                          max_to_keep=100)
示例#12
0
文件: config.py 项目: zhandyg/SeisNN
GEOM_ROOT = os.path.join(WORKSPACE, 'geom')
MODELS_ROOT = os.path.join(WORKSPACE, 'models')

# Workspace layout. The __main__ block below dumps it, in this key order,
# to ~/config.yaml.
config = dict(
    WORKSPACE=WORKSPACE,
    SDS_ROOT=SDS_ROOT,
    SFILE_ROOT=SFILE_ROOT,
    DATASET_ROOT=DATASET_ROOT,
    SQL_ROOT=SQL_ROOT,
    CATALOG_ROOT=CATALOG_ROOT,
    GEOM_ROOT=GEOM_ROOT,
    MODELS_ROOT=MODELS_ROOT,
)

if __name__ == '__main__':
    # Create every data directory of the workspace.
    for directory in (DATASET_ROOT, SQL_ROOT, CATALOG_ROOT, MODELS_ROOT,
                      GEOM_ROOT):
        utils.make_dirs(directory)

    # Write the layout to ~/config.yaml, keeping the insertion order above.
    config_path = os.path.join(os.path.expanduser("~"), 'config.yaml')
    with open(config_path, 'w') as config_file:
        yaml.dump(config, config_file, sort_keys=False)

    print('SeisNN initialized.')
示例#13
0
def plot_dataset(feature,
                 snr=False,
                 enlarge=False,
                 xlim=None,
                 title=None,
                 save_dir=None):
    """
    Plot waveform channels, picks and phase PDFs of a feature dict.

    One panel per channel shows the trace, with a vertical line for every
    pick. A final panel shows the ``pdf`` curve of each phase.

    :param dict feature: Keys read here: ``starttime``, ``id``, ``channel``,
        ``trace``, ``pdf``, ``phase``, ``pick_time``, ``pick_set``,
        ``pick_phase`` and, when ``snr`` is set, ``delta``.
    :param bool snr: Annotate manual picks with a signal-to-noise ratio
        computed from up to 100 samples on each side of the pick.
    :param bool enlarge: Restrict the x-axis to (t - 1, t + 2), where t is
        the last ``pick_time`` relative to ``starttime`` (1 if there are no
        picks).
    :param xlim: Explicit x-axis limits. ``enlarge`` overrides them.
    :param str title: Title prefix and output file stem. Defaults to
        ``f'{starttime}_{id[:-3]}'``.
    :param str save_dir: If given, save ``<title>.png`` there and close the
        figure; otherwise show it.
    """
    if title is None:
        title = f'{feature["starttime"]}_{feature["id"][:-3]}'
    starttime = UTCDateTime(feature['starttime'])
    pick_times = feature['pick_time'] or []
    if pick_times:
        first_pick_time = UTCDateTime(pick_times[-1]) - starttime
    else:
        first_pick_time = 1

    time_array = get_time_array(feature)
    pick_type = ['manual', 'predict']
    window = 100  # samples on each side of a pick used for the SNR estimate

    subplot = len(feature['channel']) + 1
    fig = plt.figure(figsize=(8, subplot * 2))
    for chan_index, chan in enumerate(feature['channel']):
        ax = fig.add_subplot(subplot, 1, chan_index + 1)
        plt.title(title + chan)

        if xlim:
            plt.xlim(xlim)
        if enlarge:
            plt.xlim((first_pick_time - 1, first_pick_time + 2))
        trace = feature['trace'][-1, :, chan_index]
        ax.plot(time_array, trace, "k-", label=chan)
        y_min, y_max = ax.get_ylim()

        label_set = set()
        for pick_index, raw_pick_time in enumerate(pick_times):
            pick_set = feature['pick_set'][pick_index]
            pick_phase = feature['pick_phase'][pick_index]
            type_color = pick_type.index(pick_set)

            color = color_palette(type_color, 1)
            label = pick_set + " " + pick_phase
            pick_time = UTCDateTime(raw_pick_time) - starttime

            # Label only the first line of each kind so the legend lists it
            # once (an empty label is left out of the legend).
            first_of_kind = label not in label_set
            label_set.add(label)
            ax.vlines(pick_time,
                      y_min,
                      y_max,
                      color=color,
                      lw=1,
                      label=label if first_of_kind else '')

            if snr and pick_set == 'manual':
                index = int(pick_time / feature['delta'])
                # Slicing never raises IndexError. A negative start would
                # count from the end of the trace, so skip picks outside the
                # trace and clamp the noise window at sample 0.
                if not 0 <= index < len(trace):
                    continue
                noise = trace[max(index - window, 0):index]
                signal = trace[index:index + window]
                try:
                    pick_snr = signal_to_noise_ratio(signal, noise)
                except IndexError:
                    continue
                # Use a separate name: rebinding `snr` would overwrite the
                # flag and could disable later annotations.
                if np.isfinite(pick_snr):
                    ax.text(pick_time, y_max - 0.1, f'SNR: {pick_snr:.2f}')
        ax.legend(loc=1)

    ax = fig.add_subplot(subplot, 1, subplot)
    ax.set_ylim([-0.05, 1.05])

    for phase_index in range(feature['pdf'].shape[2]):
        if feature['phase'][phase_index]:
            ax.plot(time_array,
                    feature['pdf'][-1, :, phase_index],
                    color=color_palette(phase_index, 1),
                    label=feature['phase'][phase_index])
            ax.legend()

        else:
            label_only = [Line2D([0], [0], color="#AAAAAA", lw=2)]
            ax.legend(label_only, ['No phase data'])

    threshold = 0.5
    ax.hlines(threshold, 0, 30, lw=1, linestyles='--')

    if xlim:
        plt.xlim(xlim)
    if enlarge:
        plt.xlim((first_pick_time - 1, first_pick_time + 2))

    if save_dir:
        make_dirs(save_dir)
        plt.savefig(os.path.join(save_dir, f'{title}.png'))
        plt.close()
    else:
        plt.show()
示例#14
0
文件: config.py 项目: jyyjqq/SeisNN
CATALOG_ROOT = os.path.join(WORKSPACE, 'catalog')
GEOM_ROOT = os.path.join(WORKSPACE, 'geom')
MODELS_ROOT = os.path.join(WORKSPACE, 'models')

config = {
    'WORKSPACE': WORKSPACE,
    'SDS_ROOT': SDS_ROOT,
    'SFILE_ROOT': SFILE_ROOT,
    'TFRECORD_ROOT': TFRECORD_ROOT,
    'DATABASE_ROOT': DATABASE_ROOT,
    'CATALOG_ROOT': CATALOG_ROOT,
    'GEOM_ROOT': GEOM_ROOT,
    'MODELS_ROOT': MODELS_ROOT,
}

# mkdir for all folders and store into config.yaml
if __name__ == '__main__':
    for d in [
            TFRECORD_ROOT, DATABASE_ROOT, CATALOG_ROOT, MODELS_ROOT, GEOM_ROOT
    ]:
        make_dirs(d)

    # Append the PYTHONPATH export to ~/.bashrc rather than opening it with
    # 'w', which would erase the user's existing shell configuration.
    bashrc_path = os.path.join(os.path.expanduser("~"), '.bashrc')
    export_line = 'export PYTHONPATH=/SeisNN:$PYTHONPATH'
    bashrc_text = ''
    if os.path.exists(bashrc_path):
        with open(bashrc_path, 'r', errors='replace') as file:
            bashrc_text = file.read()
    # Only add the line once, so re-running the initializer is harmless.
    if export_line not in bashrc_text.splitlines():
        with open(bashrc_path, 'a') as file:
            if bashrc_text and not bashrc_text.endswith('\n'):
                file.write('\n')
            file.write(export_line + '\n')

    with open(os.path.join(os.path.expanduser("~"), 'config.yaml'),
              'w') as file:
        yaml.dump(config, file, sort_keys=False)

    print('SeisNN initialized.')