def generate_training_data(pick_list, dataset, database, chunk_size):
    """
    Generate TFRecords from database.

    :param pick_list: List of picks from Pick SQL query.
    :param str dataset: Output directory name.
    :param str database: SQL database.
    :param int chunk_size: Number of examples per TFRecord file.
    """
    config = utils.get_config()
    dataset_dir = os.path.join(config['DATASET_ROOT'], dataset)
    utils.make_dirs(dataset_dir)

    total_batch = int(len(pick_list) / chunk_size)
    batch_picks = utils.batch(pick_list, size=chunk_size)
    for index, picks in enumerate(batch_picks):
        example_list = utils.parallel(picks,
                                      func=get_example_list,
                                      database=database)

        # get_example_list returns nested lists; flatten twice to get
        # a flat list of examples.
        flatten = itertools.chain.from_iterable
        flat_list = list(flatten(flatten(example_list)))

        file_name = f'{index:0>5}.tfrecord'
        save_file = os.path.join(dataset_dir, file_name)
        io.write_tfrecord(flat_list, save_file)
        print(f'output {file_name} / {total_batch}')
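# Usage sketch for generate_training_data(). The database name and the
# pick query are illustrative: `get_picks_from_db` is a hypothetical
# helper standing in for the package's actual Pick SQL query and may
# not match its real API.
#
# pick_list = get_picks_from_db('HL2017.db')  # hypothetical helper
# generate_training_data(pick_list, dataset='HL2017',
#                        database='HL2017.db', chunk_size=64)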
def plot_loss(log_file, save_dir=None):
    """
    Plot train / validation loss curves from a loss log file.

    :param str log_file: Loss log file, as written by save_loss().
    :param str save_dir: Output directory; show the figure if None.
    """
    loss = []
    with open(log_file, 'r') as f:
        for line in f.readlines():
            # split() without arguments also strips the trailing newline.
            loss.append(line.split())

    file_name = os.path.basename(log_file).split('.')
    loss = np.asarray(loss).astype(np.float32)

    fig = plt.figure(figsize=(8, 4))
    ax = fig.add_subplot(111)
    ax.plot(loss[:, 0], label='train')
    ax.plot(loss[:, 1], label='validation')
    plt.xlabel('Steps')
    plt.ylabel('Loss')
    ax.legend()
    plt.title(f'{file_name[0]} loss')

    if save_dir:
        make_dirs(save_dir)
        plt.savefig(os.path.join(save_dir, f'{file_name[0]}.png'))
        plt.close()
    else:
        plt.show()
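# Round-trip sketch: write a few (train, validation) loss pairs with
# save_loss() (defined below), then render them with plot_loss().
# Values and paths are illustrative only.
loss_history = [[0.42, 0.48], [0.31, 0.40], [0.25, 0.37]]
save_loss(loss_history, title='unet_demo', save_dir='/tmp/logs')
plot_loss('/tmp/logs/unet_demo.log', save_dir='/tmp/figs')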
def plot_dataset(instance, title=None, save_dir=None):
    """
    Plot trace and label.

    :param instance: Data instance with trace, label and predict arrays.
    :param str title: Figure title; defaults to starttime and id.
    :param str save_dir: Output directory; show the figure if None.
    """
    if title is None:
        title = f'{instance.starttime}_{instance.id[:-3]}'
    subplot = len(instance.channel) + 1
    fig = plt.figure(figsize=(8, subplot * 2))

    # Plot label and prediction probability curves in the bottom panel.
    ax = fig.add_subplot(subplot, 1, subplot)
    ax.set_ylim([-0.05, 1.05])
    threshold = 0.5
    ax.hlines(threshold, 0, 30, lw=1, linestyles='--')
    peak_flag = []
    for i, label in enumerate(['label', 'predict']):
        for j, phase in enumerate(instance.phase[0:2]):
            color = color_palette(j, i)
            ax.plot(get_time_array(instance),
                    getattr(instance, label)[-1, :, j],
                    color=color, label=f'{phase} {label}')

            peaks, _ = find_peaks(getattr(instance, label)[-1, :, j],
                                  distance=100, height=threshold)
            peak_flag.append(peaks)
    ax.legend()

    # Group peaks as [label, predict] x [phase 0, phase 1]; a plain
    # nested list avoids reshaping ragged peak arrays with numpy.
    peak_flag = [peak_flag[0:2], peak_flag[2:4]]

    # Plot traces with picked peaks as vertical lines,
    # dotted for labels and solid for predictions.
    lines_shape = [':', '-']
    for i, chan in enumerate(instance.channel):
        ax = fig.add_subplot(subplot, 1, i + 1)
        ax.set_ylim([-1.05, 1.05])
        if i == 0:
            plt.title(title[0:-2])
        trace = instance.trace[-1, :, i]
        ax.plot(get_time_array(instance), trace, "k-", label=chan)
        for j, label in enumerate(['label', 'predict']):
            for k, peaks in enumerate(peak_flag[j]):
                color = color_palette(k, j)
                # Convert peak sample indices to seconds
                # (100 samples per second).
                ax.vlines(peaks / 100, -1.05, 1.05, color, lines_shape[j])
        ax.legend(loc=1)

    if save_dir:
        utils.make_dirs(save_dir)
        plt.savefig(os.path.join(save_dir, f'{title}.png'))
        plt.close()
    else:
        plt.show()
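# Sketch: plot the first instance of a dataset. The `Instance` wrapper
# is an assumption about how examples read from a TFRecord become
# objects with .trace / .label / .predict attributes, and
# io.read_dataset is assumed by analogy with io.write_tfrecord above;
# the package's actual reader may differ.
#
# dataset = io.read_dataset('HL2017')
# for example in dataset.take(1):
#     plot_dataset(Instance(example), save_dir='/tmp/figs')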
def get_model_dir(model_instance, remove=False):
    """
    Get the model directory and its history subdirectory,
    creating them if needed.

    :param str model_instance: Model directory name.
    :param bool remove: If True, remove any existing model directory first.
    :rtype: (str, str)
    """
    config = utils.get_config()
    save_model_path = os.path.join(config['MODELS_ROOT'], model_instance)
    if remove:
        shutil.rmtree(save_model_path, ignore_errors=True)
    utils.make_dirs(save_model_path)

    save_history_path = os.path.join(save_model_path, "history")
    utils.make_dirs(save_history_path)
    return save_model_path, save_history_path
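# Usage sketch: create a fresh model directory (remove=True wipes any
# previous run) plus its history/ subdirectory under MODELS_ROOT.
model_path, history_path = get_model_dir('unet_demo', remove=True)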
def save_loss(loss_buffer, title, save_dir):
    """
    Append history loss to a log file.

    :param loss_buffer: Loss history.
    :param str title: Log file name.
    :param str save_dir: Output directory.
    """
    utils.make_dirs(save_dir)
    file_path = os.path.join(save_dir, f'{title}.log')
    loss_buffer = np.asarray(loss_buffer)
    with open(file_path, 'ab') as f:
        np.savetxt(f, loss_buffer)
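# Because the log is opened in append mode ('ab'), save_loss() can be
# called once per epoch to accumulate history in a single file.
# Values are illustrative only.
for epoch_loss in ([0.42, 0.48], [0.31, 0.40]):
    save_loss([epoch_loss], title='unet_demo', save_dir='/tmp/logs')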
def plot_error_distribution(time_residuals, save_dir=None):
    """
    Plot the distribution of pick time residuals.

    :param time_residuals: Pick time residuals in seconds.
    :param str save_dir: Output directory; show the figure if None.
    """
    bins = np.linspace(-0.5, 0.5, 100)
    plt.hist(time_residuals, bins=bins)
    plt.xticks(np.arange(-0.5, 0.51, step=0.1))
    plt.xlabel("Time residuals (sec)")
    plt.ylabel("Counts")
    plt.title("Error Distribution")

    if save_dir:
        make_dirs(save_dir)
        plt.savefig(os.path.join(save_dir, 'error_distribution.png'))
        plt.close()
    else:
        plt.show()
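# Sketch with synthetic residuals (illustrative values only): the
# histogram covers -0.5 s to +0.5 s in 100 bins, so residuals outside
# that window fall off the plot.
time_residuals = np.random.normal(loc=0.0, scale=0.1, size=1000)
plot_error_distribution(time_residuals, save_dir='/tmp/figs')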
def plot_snr_distribution(pick_snr, save_dir=None):
    """
    Plot the distribution of pick signal-to-noise ratios.

    :param pick_snr: Pick SNR values (log10).
    :param str save_dir: Output directory; show the figure if None.
    """
    sns.set()
    bins = np.linspace(-1, 10, 55)
    plt.hist(pick_snr, bins=bins)
    plt.xticks(np.arange(-1, 11, step=1))
    plt.xlabel("Signal to Noise Ratio (log10)")
    plt.ylabel("Counts")
    plt.title("SNR Distribution")

    if save_dir:
        make_dirs(save_dir)
        plt.savefig(os.path.join(save_dir, 'snr_distribution.png'))
        plt.close()
    else:
        plt.show()
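# Sketch with synthetic SNR values (illustrative only); the function
# expects log10 SNR, matching its x-axis label.
pick_snr = np.random.normal(loc=2.0, scale=1.5, size=1000)
plot_snr_distribution(pick_snr, save_dir='/tmp/figs')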
def write_training_dataset(pick_list, geom, dataset, pickset):
    """
    Write a per-station training TFRecord from a pick list.

    :param pick_list: List of picks for a single station.
    :param geom: Station geometry.
    :param str dataset: Output directory name.
    :param pickset: Pick set name.
    """
    config = get_config()
    dataset_dir = os.path.join(config['DATASET_ROOT'], dataset)
    make_dirs(dataset_dir)

    pick_time_key = []
    for pick in pick_list:
        pick_time_key.append(pick.time)

    par = partial(_write_picked_stream,
                  pick_list=pick_list,
                  pick_time_key=pick_time_key,
                  geom=geom,
                  pickset=pickset)
    example_list = parallel(par, pick_list)

    station = pick_list[0].waveform_id.station_code
    file_name = f'{station}.tfrecord'
    save_file = os.path.join(dataset_dir, file_name)
    write_tfrecord(example_list, save_file)
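# Usage sketch. The pick and geometry loaders are hypothetical
# stand-ins for the package's actual readers; only the call signature
# of write_training_dataset() is taken from the function above.
#
# pick_list = read_picks('HL2017.catalog')  # hypothetical helper
# geom = read_geom('HL2017.HYP')            # hypothetical helper
# write_training_dataset(pick_list, geom, dataset='HL2017',
#                        pickset='manual')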
from seisnn.model.settings import model, optimizer

ap = argparse.ArgumentParser()
ap.add_argument('-i', '--input', required=True, help='input dataset', type=str)
ap.add_argument('-o', '--output', required=True, help='output dataset', type=str)
ap.add_argument('-m', '--model', required=True, help='model', type=str)
args = ap.parse_args()

config = get_config()
MODEL_PATH = os.path.join(config['MODELS_ROOT'], args.model)
make_dirs(MODEL_PATH)
OUTPUT_DATASET = os.path.join(config['DATASET_ROOT'], args.output)
make_dirs(OUTPUT_DATASET)
INPUT_DATASET = os.path.join(config['DATASET_ROOT'], args.input)
dataset = read_dataset(INPUT_DATASET)

# Restore the latest checkpoint, if any.
ckpt = tf.train.Checkpoint(model=model, optimizer=optimizer)
ckpt_manager = tf.train.CheckpointManager(ckpt, MODEL_PATH, max_to_keep=100)
if ckpt_manager.latest_checkpoint:
    ckpt.restore(ckpt_manager.latest_checkpoint)
    last_epoch = len(ckpt_manager.checkpoints)
    print(f'Latest checkpoint epoch {last_epoch} restored!!')
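# A minimal inference sketch following the restore above. Slicing the
# 'trace' feature mirrors the training script below; treating `model`
# as a Keras model with .predict() is an assumption.
#
# for example in dataset.batch(1):
#     trace = example['trace'][:, :, :, 0, tf.newaxis]
#     predict = model.predict(trace)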
from seisnn.example_proto import batch_iterator
from seisnn.model.settings import model, optimizer

ap = argparse.ArgumentParser()
ap.add_argument('-d', '--dataset', required=True, help='dataset', type=str)
ap.add_argument('-p', '--pre_train', required=True, help='pre-train model', type=str)
ap.add_argument('-m', '--model', required=True, help='save model', type=str)
args = ap.parse_args()

config = get_config()
SAVE_MODEL_PATH = os.path.join(config['MODELS_ROOT'], args.model)
make_dirs(SAVE_MODEL_PATH)
SAVE_HISTORY_PATH = os.path.join(SAVE_MODEL_PATH, "history")
make_dirs(SAVE_HISTORY_PATH)

# Hold out a fixed validation example after skipping the first
# 1000 examples.
dataset_dir = os.path.join(config['DATASET_ROOT'], args.dataset)
dataset = read_dataset(dataset_dir).skip(1000)
val = next(iter(dataset.batch(1)))
val_trace = val['trace'][:, :, :, 0, tf.newaxis]
val_pdf = val['pdf'][:, :, :, 0, tf.newaxis]

ckpt = tf.train.Checkpoint(model=model, optimizer=optimizer)
ckpt_manager = tf.train.CheckpointManager(ckpt, SAVE_MODEL_PATH, max_to_keep=100)
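# The --pre_train argument above names an existing checkpoint
# directory; a restore along these lines presumably follows (a sketch,
# not the confirmed script body):
PRE_TRAIN_PATH = os.path.join(config['MODELS_ROOT'], args.pre_train)
ckpt.restore(tf.train.latest_checkpoint(PRE_TRAIN_PATH))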
GEOM_ROOT = os.path.join(WORKSPACE, 'geom')
MODELS_ROOT = os.path.join(WORKSPACE, 'models')

config = {
    'WORKSPACE': WORKSPACE,
    'SDS_ROOT': SDS_ROOT,
    'SFILE_ROOT': SFILE_ROOT,
    'DATASET_ROOT': DATASET_ROOT,
    'SQL_ROOT': SQL_ROOT,
    'CATALOG_ROOT': CATALOG_ROOT,
    'GEOM_ROOT': GEOM_ROOT,
    'MODELS_ROOT': MODELS_ROOT,
}

if __name__ == '__main__':
    path_list = [
        DATASET_ROOT,
        SQL_ROOT,
        CATALOG_ROOT,
        MODELS_ROOT,
        GEOM_ROOT,
    ]
    for d in path_list:
        utils.make_dirs(d)

    path = os.path.join(os.path.expanduser("~"), 'config.yaml')
    with open(path, 'w') as file:
        yaml.dump(config, file, sort_keys=False)
    print('SeisNN initialized.')
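# Other modules read the generated config.yaml back through
# utils.get_config(), as used throughout this section, e.g.:
#
# config = utils.get_config()
# print(config['DATASET_ROOT'])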
def plot_dataset(feature, snr=False, enlarge=False, xlim=None,
                 title=None, save_dir=None):
    """
    Plot trace, picks and pdf of a feature dict.

    :param feature: Feature dict parsed from a TFRecord example.
    :param bool snr: If True, annotate manual picks with their SNR.
    :param bool enlarge: If True, zoom in around the first pick.
    :param xlim: X axis limits in seconds.
    :param str title: Figure title; defaults to starttime and id.
    :param str save_dir: Output directory; show the figure if None.
    """
    if title is None:
        title = f'{feature["starttime"]}_{feature["id"][:-3]}'
    if feature['pick_time']:
        first_pick_time = UTCDateTime(feature['pick_time'][-1]) \
                          - UTCDateTime(feature['starttime'])
    else:
        first_pick_time = 1

    subplot = len(feature['channel']) + 1
    fig = plt.figure(figsize=(8, subplot * 2))
    for i, chan in enumerate(feature['channel']):
        ax = fig.add_subplot(subplot, 1, i + 1)
        plt.title(title + chan)

        if xlim:
            plt.xlim(xlim)
        if enlarge:
            plt.xlim((first_pick_time - 1, first_pick_time + 2))
        trace = feature['trace'][-1, :, i]
        ax.plot(get_time_array(feature), trace, "k-", label=chan)
        y_min, y_max = ax.get_ylim()

        if feature['pick_time']:
            label_set = set()
            pick_type = ['manual', 'predict']
            for j in range(len(feature['pick_time'])):
                pick_set = feature['pick_set'][j]
                pick_phase = feature['pick_phase'][j]

                # Color by phase, shaded by pick set, matching the
                # color_palette(phase, type) convention used above.
                phase_color = feature['phase'].index(pick_phase)
                type_color = pick_type.index(pick_set)
                color = color_palette(phase_color, type_color)
                label = pick_set + " " + pick_phase

                pick_time = UTCDateTime(feature['pick_time'][j]) \
                            - UTCDateTime(feature['starttime'])
                if label not in label_set:
                    ax.vlines(pick_time, y_min, y_max, color=color, lw=1,
                              label=label)
                    label_set.add(label)
                else:
                    ax.vlines(pick_time, y_min, y_max, color=color, lw=1)

                if snr and pick_set == 'manual':
                    try:
                        # SNR from a 100-sample signal window after the
                        # pick over a 100-sample noise window before it.
                        index = int(pick_time / feature['delta'])
                        noise = trace[index - 100:index]
                        signal = trace[index:index + 100]
                        pick_snr = signal_to_noise_ratio(signal, noise)
                        if not pick_snr == float("inf"):
                            ax.text(pick_time, y_max - 0.1,
                                    f'SNR: {pick_snr:.2f}')
                    except IndexError:
                        pass
        ax.legend(loc=1)

    # Plot pdf (label probability curves) in the bottom panel.
    ax = fig.add_subplot(subplot, 1, subplot)
    ax.set_ylim([-0.05, 1.05])
    for i in range(feature['pdf'].shape[2]):
        if feature['phase'][i]:
            color = color_palette(i, 1)
            ax.plot(get_time_array(feature), feature['pdf'][-1, :, i],
                    color=color, label=feature['phase'][i])
            ax.legend()
        else:
            label_only = [Line2D([0], [0], color="#AAAAAA", lw=2)]
            ax.legend(label_only, ['No phase data'])

    threshold = 0.5
    ax.hlines(threshold, 0, 30, lw=1, linestyles='--')
    if xlim:
        plt.xlim(xlim)
    if enlarge:
        plt.xlim((first_pick_time - 1, first_pick_time + 2))

    if save_dir:
        make_dirs(save_dir)
        plt.savefig(os.path.join(save_dir, f'{title}.png'))
        plt.close()
    else:
        plt.show()
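# Sketch: zoom in around the first pick of a parsed feature and
# annotate SNR. `parse_example` is a hypothetical name for the
# package's TFRecord-to-feature-dict parser.
#
# example = next(iter(read_dataset(dataset_dir)))
# feature = parse_example(example)  # hypothetical parser
# plot_dataset(feature, snr=True, enlarge=True, save_dir='/tmp/figs')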
CATALOG_ROOT = os.path.join(WORKSPACE, 'catalog')
GEOM_ROOT = os.path.join(WORKSPACE, 'geom')
MODELS_ROOT = os.path.join(WORKSPACE, 'models')

config = {
    'WORKSPACE': WORKSPACE,
    'SDS_ROOT': SDS_ROOT,
    'SFILE_ROOT': SFILE_ROOT,
    'TFRECORD_ROOT': TFRECORD_ROOT,
    'DATABASE_ROOT': DATABASE_ROOT,
    'CATALOG_ROOT': CATALOG_ROOT,
    'GEOM_ROOT': GEOM_ROOT,
    'MODELS_ROOT': MODELS_ROOT,
}

# mkdir for all folders and store into config.yaml
if __name__ == '__main__':
    for d in [
        TFRECORD_ROOT,
        DATABASE_ROOT,
        CATALOG_ROOT,
        MODELS_ROOT,
        GEOM_ROOT,
    ]:
        make_dirs(d)

    # Append (do not overwrite) the PYTHONPATH export to .bashrc;
    # opening in 'w' mode would truncate the user's shell config.
    with open(os.path.join(os.path.expanduser("~"), '.bashrc'), 'a') as file:
        file.write('export PYTHONPATH=/SeisNN:$PYTHONPATH\n')

    with open(os.path.join(os.path.expanduser("~"), 'config.yaml'), 'w') as file:
        yaml.dump(config, file, sort_keys=False)
    print('SeisNN initialized.')