Code Example #1
File: predict.py  Project: wayneweiqiang/PhaseNet
def main(args):

    logging.basicConfig(format='%(asctime)s %(message)s', level=logging.INFO)

    with tf.compat.v1.name_scope('create_inputs'):

        # Pick a data reader based on the input format: "mseed_array" reads a
        # directory of miniSEED files along with station metadata, while other
        # formats go through the generic prediction reader.
        if args.format == "mseed_array":
            data_reader = DataReader_mseed_array(
                data_dir=args.data_dir,
                data_list=args.data_list,
                stations=args.stations,
                amplitude=args.amplitude,
                highpass_filter=args.highpass_filter,
            )
        else:
            data_reader = DataReader_pred(
                format=args.format,
                data_dir=args.data_dir,
                data_list=args.data_list,
                hdf5_file=args.hdf5_file,
                hdf5_group=args.hdf5_group,
                amplitude=args.amplitude,
                highpass_filter=args.highpass_filter,
            )

        # Run prediction with the selected reader; outputs go under args.result_dir.
        pred_fn(args, data_reader, log_dir=args.result_dir)

    return
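A minimal sketch of how this entry point could be driven from the command line is shown below, assuming it lives in the same module as main() above. The flag names mirror the attributes main() reads (format, data_dir, data_list, hdf5_file, hdf5_group, stations, amplitude, highpass_filter, result_dir); the read_args helper, the defaults, and the types are illustrative assumptions, not the project's actual CLI.

import argparse

def read_args():
    # Hypothetical argument parser; names follow the attributes main() accesses,
    # while defaults and help strings are assumed for illustration.
    parser = argparse.ArgumentParser(description="Run prediction on continuous waveforms")
    parser.add_argument("--format", default="numpy", help="input format, e.g. numpy, mseed, hdf5, mseed_array")
    parser.add_argument("--data_dir", default="./dataset/waveform")
    parser.add_argument("--data_list", default="./dataset/waveform.csv")
    parser.add_argument("--hdf5_file", default=None)
    parser.add_argument("--hdf5_group", default="data")
    parser.add_argument("--stations", default=None, help="station metadata, used with mseed_array")
    parser.add_argument("--amplitude", action="store_true")
    parser.add_argument("--highpass_filter", type=float, default=0.0)
    parser.add_argument("--result_dir", default="results")
    return parser.parse_args()

if __name__ == "__main__":
    main(read_args())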
Code Example #2
def main(args):
    logging.basicConfig(format='%(asctime)s %(message)s', level=logging.INFO)
    # Coordinator shared by the queue-based data readers below.
    coord = tf.train.Coordinator()

    # Training: build the training reader and, if a validation list is given,
    # a separate validation reader.
    if args.mode == "train":
        with tf.compat.v1.name_scope('create_inputs'):
            data_reader = DataReader(
                data_dir=args.train_dir,
                data_list=args.train_list,
                mask_window=0.4,
                queue_size=args.batch_size * 3,
                coord=coord)
            if args.valid_list is not None:
                data_reader_valid = DataReader(
                    data_dir=args.valid_dir,
                    data_list=args.valid_list,
                    mask_window=0.4,
                    queue_size=args.batch_size * 2,
                    coord=coord)
                logging.info(
                    "Dataset size: train {}, valid {}".format(data_reader.num_data, data_reader_valid.num_data))
            else:
                data_reader_valid = None
                logging.info("Dataset size: train {}".format(data_reader.num_data))
        train_fn(args, data_reader, data_reader_valid)

    # Validation / testing on labeled data.
    elif args.mode == "valid" or args.mode == "test":
        with tf.compat.v1.name_scope('create_inputs'):
            data_reader = DataReader_test(
                data_dir=args.data_dir,
                data_list=args.data_list,
                mask_window=0.4,
                queue_size=args.batch_size * 10,
                coord=coord)
        valid_fn(args, data_reader)

    # Prediction: read either raw miniSEED or the generic prediction format.
    elif args.mode == "pred":
        with tf.compat.v1.name_scope('create_inputs'):
            if args.input_mseed:
                data_reader = DataReader_mseed(
                    data_dir=args.data_dir,
                    data_list=args.data_list,
                    queue_size=args.batch_size * 10,
                    coord=coord,
                    input_length=args.input_length)
            else:
                data_reader = DataReader_pred(
                    data_dir=args.data_dir,
                    data_list=args.data_list,
                    queue_size=args.batch_size * 10,
                    coord=coord,
                    input_length=args.input_length)
        pred_fn(args, data_reader, log_dir=args.output_dir)

    else:
        print("mode should be: train, valid, test, pred or debug")

    return
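The same pattern applies to this multi-mode main(); a sketch of a possible driver follows, again assuming it sits in the same module. The flag names mirror the attributes read above (mode, batch_size, train_dir, train_list, valid_dir, valid_list, data_dir, data_list, input_mseed, input_length, output_dir); the defaults and types are assumptions.

import argparse

def read_args():
    # Hypothetical argument parser; names follow the attributes main() accesses,
    # defaults and types are assumed for illustration.
    parser = argparse.ArgumentParser()
    parser.add_argument("--mode", default="train", help="train, valid, test, or pred")
    parser.add_argument("--batch_size", type=int, default=20)
    parser.add_argument("--train_dir", default="./dataset/train")
    parser.add_argument("--train_list", default="./dataset/train.csv")
    parser.add_argument("--valid_dir", default=None)
    parser.add_argument("--valid_list", default=None)
    parser.add_argument("--data_dir", default="./dataset/test")
    parser.add_argument("--data_list", default="./dataset/test.csv")
    parser.add_argument("--input_mseed", action="store_true")
    parser.add_argument("--input_length", type=int, default=None)
    parser.add_argument("--output_dir", default="output")
    return parser.parse_args()

if __name__ == "__main__":
    main(read_args())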
Code Example #3
File: predict.py  Project: wayneweiqiang/DeepDenoiser
def main(args):

    logging.basicConfig(format='%(asctime)s %(message)s', level=logging.INFO)

    with tf.compat.v1.name_scope('create_inputs'):
        # Reader over the input signals listed in args.data_list.
        data_reader = DataReader_pred(format=args.format,
                                      signal_dir=args.data_dir,
                                      signal_list=args.data_list,
                                      sampling_rate=args.sampling_rate)
    logging.info("Dataset Size: {}".format(data_reader.n_signal))
    pred_fn(args, data_reader, log_dir=args.output_dir)

    return 0
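As with the previous examples, a hedged sketch of a command-line driver is given below, assuming it is placed in the same module as this main(). The flag names follow the attributes the function reads (format, data_dir, data_list, sampling_rate, output_dir); the defaults are assumptions, not DeepDenoiser's actual CLI.

import argparse

def read_args():
    # Hypothetical argument parser; names follow the attributes main() accesses,
    # defaults are assumed for illustration.
    parser = argparse.ArgumentParser(description="Denoise seismic waveforms")
    parser.add_argument("--format", default="numpy")
    parser.add_argument("--data_dir", default="./dataset/pred")
    parser.add_argument("--data_list", default="./dataset/pred.csv")
    parser.add_argument("--sampling_rate", type=int, default=100)
    parser.add_argument("--output_dir", default="output")
    return parser.parse_args()

if __name__ == "__main__":
    main(read_args())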