def main(args):
    """Build the train/valid inputs for the selected network and train it.

    Args:
        args: parsed command-line arguments. This function reads
            ``args.model`` ('DetNet' or 'PpkNet'), ``args.ckpt_dir`` and
            ``args.resume``.

    Raises:
        ValueError: if ``args.model`` is neither 'DetNet' nor 'PpkNet'.
    """
    # Validate before doing any work so an unknown model name fails
    # immediately with a clear message instead of a NameError further down.
    if args.model not in ('DetNet', 'PpkNet'):
        raise ValueError('false model name! got {!r}'.format(args.model))

    # setup training
    cfg = config.Config()
    # NOTE(review): win_len is truncated to an int *before* scaling by 100,
    # while step_len/step_stride below are scaled first; the factor 100 is
    # presumably the sampling rate in Hz -- confirm.
    win_points_len = 100 * int(cfg.win_len)
    if args.model == 'DetNet':
        num_steps = 1
        data_shape = [cfg.cnn_bsize, num_steps, win_points_len, cfg.num_chns]
        get_samples = get_det_samples
        model_cls = models.DetNet
    else:  # 'PpkNet'
        step_len = int(100 * cfg.step_len)
        step_stride = int(100 * cfg.step_stride)
        # Floor division: identical to Python 2 int `/`, and keeps num_steps
        # an int (usable as a shape dimension) under Python 3 as well.
        num_steps = (-(step_len // step_stride - 1)
                     + win_points_len // step_stride)
        data_shape = [cfg.rnn_bsize, num_steps, step_len, cfg.num_chns]
        get_samples = get_ppk_samples
        model_cls = models.PpkNet

    # get training and validation set
    train_samples = get_samples('train', data_shape)
    valid_samples = get_samples('valid', data_shape)
    inputs = [train_samples, valid_samples]

    # get model; checkpoints live in a per-model sub-directory
    ckpt_dir = os.path.join(args.ckpt_dir, args.model)
    if not os.path.exists(ckpt_dir):
        os.makedirs(ckpt_dir)
    model = model_cls(inputs, ckpt_dir)

    # train
    BaseModel(model).train(args.resume)
def pick(self, streams):
    """Run PpkNet on ``streams`` and decode the first P and S onsets.

    PpkNet is restored from ``self.ckpt_dir`` at step ``self.ckpt_step``.
    For every row of the predicted class batch, the first step labelled 1
    (P) and the first step labelled 2 (S) are converted into a time by
    dividing a point offset by ``self.samp_rate``.

    Args:
        streams: waveform input, passed unchanged to ``self.fetch_data``.

    Returns:
        list of ``[tp, ts]`` pairs, one per prediction row; a phase that was
        not predicted is reported as -1.
    """
    def _first_onset(indices):
        # Map the index of the first step carrying a phase label to a time.
        # The offset is presumably in samples (step_len/step_stride are fed
        # to fetch_data as-is) -- hence the division by samp_rate.
        # Float division avoids Python 2 integer truncation when step_len
        # and samp_rate are ints.
        first = indices[0]
        if first == 0:
            onset = self.step_len / 2.0
        else:
            onset = self.step_len + self.step_stride * (first - 0.5)
        return onset / float(self.samp_rate)

    data_batch = self.fetch_data(streams, self.num_steps, self.step_len,
                                 self.step_stride)
    data_holder = tf.placeholder(tf.float32, shape=data_batch.shape)
    # The model is built from a [train, valid] input pair; for inference
    # both entries share the same placeholder and the valid slot is fed.
    inputs = [{'data': data_holder}, {'data': data_holder}]
    try:
        with tf.Session() as sess:
            # set up PpkNet model
            model = models.PpkNet(inputs, self.ckpt_dir)
            BaseModel(model).load(sess, self.ckpt_step)
            # run PpkNet
            feed_dict = {
                inputs[1]['data']: data_batch,
                model.is_training: False,
            }
            pred_classes = sess.run(model.layers['pred_class'], feed_dict)
    finally:
        # Reset the default graph even when loading or inference raises.
        tf.reset_default_graph()

    # decode to sec
    picks = []
    for pred_class in pred_classes:
        pred_p = np.where(pred_class == 1)[0]
        if len(pred_p) > 0:
            tp = _first_onset(pred_p)
            # Clear everything before the first P so that an S can only be
            # picked after it.
            pred_class[0:pred_p[0]] = 0
        else:
            tp = -1
        pred_s = np.where(pred_class == 2)[0]
        ts = _first_onset(pred_s) if len(pred_s) > 0 else -1
        picks.append([tp, ts])
    return picks
def run_ppk(self, stream, det_list):
    """Run PpkNet on detected windows and write P/S picks to ``self.out_file``.

    Args:
        stream: waveform stream. Windows are cut with
            ``stream.slice(t0, t1)``, and ``stream[0].stats.station`` is
            written with each pick.
        det_list: sequence of ``(t0, t1, det_prob)`` detections, presumably
            ordered by start time (TODO confirm against the caller).

    For each kept window, one line ``station,tp,ts`` is written. ``tp`` and
    ``ts`` are ``t0`` plus an onset offset in the units of
    ``self.step_len``/``self.step_stride`` (presumably seconds, since they
    are scaled by 100 to points), or -1 when the phase was not predicted.
    An empty ``det_list`` writes nothing and returns immediately.
    """
    if not det_list:
        # Nothing to pick: avoid indexing det_list[0] and building a graph.
        print('Picked 0 events')
        return

    step_point_len = int(100 * self.step_len)
    step_point_stride = int(100 * self.step_stride)

    def _onset(t0, indices):
        # Onset estimate from the first step carrying a phase label; float
        # division avoids Python 2 integer truncation of step_len / 2.
        first = indices[0]
        if first == 0:
            return t0 + self.step_len / 2.0
        return t0 + self.step_len + self.step_stride * (first - 0.5)

    try:
        with tf.Session() as sess:
            # set up PpkNet model; both [train, valid] slots share one
            # placeholder and the valid slot is fed below.
            holder = {
                'data': tf.placeholder(
                    tf.float32,
                    shape=(1, self.num_steps, step_point_len, 3)),
            }
            inputs = [holder, holder]
            model = models.PpkNet(inputs, self.rnn_ckpt_dir)
            # NOTE(review): the checkpoint directory is rnn_ckpt_dir but the
            # step is cnn_ckpt_step -- confirm this pairing is intended.
            BaseModel(model).load(sess, self.cnn_ckpt_step)
            to_fetch = model.layers['pred_class']

            run_time_start = time.time()
            num_events = 0
            old_t1 = det_list[0][0]
            last_idx = len(det_list) - 1
            for idx, det in enumerate(det_list):
                t0, t1, det_prob = det[0], det[1], det[2]
                nxt = det_list[min(idx + 1, last_idx)]
                # Skip the window if (1) it starts before the end of the
                # last picked window, or (2) it overlaps the next window and
                # the next window has a higher detection probability.
                if t0 < old_t1 or (t1 > nxt[0] and det_prob < nxt[2]):
                    continue

                # run PpkNet
                st = self.preprocess(stream.slice(t0, t1))
                feed_dict = {
                    inputs[1]['data']: self.fetch_data(
                        st, self.num_steps, step_point_len,
                        step_point_stride),
                    model.is_training: False,
                }
                pred_class = sess.run(to_fetch, feed_dict)[0]

                # decode: class 1 = P, class 2 = S; offsets are added to t0
                pred_p = np.where(pred_class == 1)[0]
                pred_s = np.where(pred_class == 2)[0]
                tp = _onset(t0, pred_p) if len(pred_p) > 0 else -1
                ts = _onset(t0, pred_s) if len(pred_s) > 0 else -1
                print('picked phase time: tp={}, ts={}'.format(tp, ts))
                self.out_file.write(u'{},{},{}\n'.format(
                    stream[0].stats.station, tp, ts))
                num_events += 1
                old_t1 = t1
            print('Picked {} events'.format(num_events))
            print('PpkNet Run time:  {}'.format(time.time() - run_time_start))
    finally:
        # Reset the default graph even when loading or inference raises.
        tf.reset_default_graph()