Example #1
    def train(self, passes, new_training=True):
        with tf.Session() as sess:
            # Note: this graph variable is unused; the integer step returned by
            # start_new_session / continue_previous_session below shadows it.
            global_step = tf.Variable(0, trainable=False)
            # learning_rate = tf.train.exponential_decay(0.001, global_step, 200, 0.8, staircase=True)
            training = tf.train.AdamOptimizer(self.config.learning_rate).minimize(self.loss)

            if new_training:
                saver, global_step = Model.start_new_session(sess)
            else:
                saver, global_step = Model.continue_previous_session(sess,
                                                                     model_file='cnn',
                                                                     ckpt_file=self.config.root + '/event_detect/saver/cnn/checkpoint')

            sess.run(tf.local_variables_initializer())
            self.train_writer.add_graph(sess.graph, global_step=global_step)

            test_result = []

            for step in range(1 + global_step, 1 + passes + global_step):
                input, target = self.reader.get_cnn_batch_data('train')
                #print(input.shape)
                summary, _, acc = sess.run([self.merged, training, self.metrics['accuracy']],
                                           feed_dict={self.layer['input']: input,
                                                      self.layer['target']: target})
                self.train_writer.add_summary(summary, step)

                if step % 10 == 0:
                    loss = sess.run(self.loss,
                                    feed_dict={self.layer['input']: input,
                                               self.layer['target']: target})
                    test_result.append(loss)

                    print("gobal_step {}, training_loss {}, accuracy {}".format(step, loss, acc))

                if step % 100 == 0:
                    test_x, test_y = self.reader.get_cnn_batch_data('test')
                    acc, recall, precision = sess.run([self.metrics['accuracy'],
                                                       self.metrics['recall'],
                                                       self.metrics['precision']],
                                                      feed_dict={self.layer['input']: test_x,
                                                                 self.layer['target']: test_y})

                    print("test: accuracy {}, recall {}, precision {}".format(acc, recall, precision))
                    saver.save(sess, self.config.root + '/event_detect/saver/cnn/cnn', global_step=step)
                    print('checkpoint saved')
                    # print(sess.run([self.layer['class_prob']], feed_dict={self.layer['input']: input}))

            print(test_result)
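
Note: Model.start_new_session and Model.continue_previous_session are not shown
on this page. Below is a minimal sketch of what they plausibly do under the TF1
checkpoint conventions used above; the checkpoint-file parsing and path layout
are assumptions, not the project's actual code:

import os
import tensorflow as tf

def start_new_session(sess):
    # Plausible sketch: fresh variables, step counter starts at 0.
    saver = tf.train.Saver()
    sess.run(tf.global_variables_initializer())
    return saver, 0

def continue_previous_session(sess, model_file, ckpt_file):
    # Plausible sketch: read the checkpoint index file, restore the latest
    # snapshot, and recover the step from the "-<step>" filename suffix.
    # model_file is kept only for signature parity with the calls above.
    saver = tf.train.Saver()
    with open(ckpt_file) as f:
        # first line looks like: model_checkpoint_path: "cnn-1200"
        ckpt_name = f.readline().split('"')[1]
    saver.restore(sess, os.path.join(os.path.dirname(ckpt_file), ckpt_name))
    return saver, int(ckpt_name.split('-')[-1])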
Example #2
    def train(self, passes, new_training=True):
        with tf.Session() as sess:
            training = tf.train.AdamOptimizer(1e-3).minimize(self.loss)
            if new_training:
                saver, global_step = Model.start_new_session(sess)
            else:
                saver, global_step = Model.continue_previous_session(
                    sess, model_file='cnn', ckpt_file='saver/cnn/checkpoint')
            sess.run(tf.local_variables_initializer())
            self.train_writer.add_graph(sess.graph, global_step=global_step)

            for step in range(1 + global_step, 1 + passes + global_step):
                input, target = self.reader.get_cnn_batch_data('train')

                summary, _, acc = sess.run(
                    [self.merged, training, self.metrics['accuracy']],
                    feed_dict={
                        self.layer['input']: input,
                        self.layer['target']: target
                    })
                self.train_writer.add_summary(summary, step)

                if step % 10 == 0:
                    loss = sess.run(self.loss,
                                    feed_dict={
                                        self.layer['input']: input,
                                        self.layer['target']: target
                                    })
                    print(
                        "gobal_step {}, training_loss {}, accuracy {}".format(
                            step, loss, acc))

                if step % 100 == 0:
                    test_x, test_y = self.reader.get_cnn_batch_data('test')
                    acc, recall, precision = sess.run(
                        [self.metrics['accuracy'],
                         self.metrics['recall'],
                         self.metrics['precision']],
                        feed_dict={self.layer['input']: test_x,
                                   self.layer['target']: test_y})
                    print("test: accuracy {}, recall {}, precision {}".format(
                        acc, recall, precision))
                    saver.save(sess, 'saver/cnn/cnn', global_step=step)
                    print('checkpoint saved')
Example #3
    def get_emb(self):
        sess_config = tf.ConfigProto()
        sess_config.gpu_options.allow_growth = True
        with tf.Session(config=sess_config) as sess:
            saver, global_step = Model.continue_previous_session(sess,
                                                                 model_file='model/saver/{}'.format(self.saveFile),
                                                                 ckpt_file='model/saver/{}/checkpoint'.format(self.saveFile))
            ids_set = np.array(range(self.g.nodes_num))
            emb_set = tf.nn.embedding_lookup(self.layer['emb'], ids_set)
            sup_emb_set = tf.nn.embedding_lookup(self.layer['sup_emb'], ids_set)
            [emb, sup_emb] = sess.run([emb_set, sup_emb_set])
            emb = np.concatenate([emb, sup_emb], axis=1)
        return emb
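
Since layer['emb'] and layer['sup_emb'] are plain embedding matrices, the same
gather can be done once in NumPy after fetching the variables; a sketch, where
emb_matrix, sup_matrix, and nodes_num stand in for the fetched values:

import numpy as np

ids_set = np.arange(nodes_num)
emb = np.concatenate([emb_matrix[ids_set],   # [nodes_num, d1]
                      sup_matrix[ids_set]],  # [nodes_num, d2]
                     axis=1)                 # [nodes_num, d1 + d2]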
Example #4
def main(model_name, new_scan=False, preprocess=True):
    config = Config()
    plot = config.plot
    cut = config.cut
    bandpass = config.bandpass
    resample = config.resample

    # read data folders
    file_list = os.listdir(config.root + '/data/after')
    file_list.sort()
    if new_scan:
        print('start new scan!')
        start_point = 0
        event_num = 0

        try:
            os.system('rm -rf %s/event_detect/detect_result/cut/*' %
                      config.root)
            os.system('rm -rf %s/event_detect/detect_result/png/*' %
                      config.root)
            os.system('rm -rf %s/event_detect/detect_result/png2/*' %
                      config.root)
            os.system('rm -rf %s/event_detect/detect_result/cnn/*.csv' %
                      config.root)
        except:
            pass
        # file_list_len = len(file_list)
    else:
        with open(config.root + '/event_detect/detect_result/' + model_name +
                  '/checkpoint') as file:
            start_point = int(file.readline())
            event_num = int(file.readline())
            file_list = file_list[start_point:]
            # file_list_len = len(file_list)
            print('restart from {}'.format(file_list[0]))

    # load CNN model
    if model_name == 'cnn':
        from cnn import CNN
        import tensorflow as tf
        from tflib.models import Model

        model = CNN()
        # sess = tf.Session(config=tf.ConfigProto(device_count={"CPU":20},inter_op_parallelism_threads=0,intra_op_parallelism_threads=0))
        sess = tf.Session()
        saver, global_step = Model.continue_previous_session(
            sess,
            model_file='cnn',
            ckpt_file=config.root + '/event_detect/saver/cnn/checkpoint')

    # read group info
    group = []
    with open(config.root + '/config/group_info', 'r') as f:
        for line in f.readlines():
            if line != '\n':
                if line[0] == '#':
                    group.append([])
                else:
                    group[-1].append(line.split()[0])
    # read data & detect eq
    for file in file_list:
        sac_file_name = [[], [], []]
        all_group_sta_num = [0] * len(group)
        path = os.path.join(config.root + '/data/after', file)
        begin = datetime.datetime.now()
        group_E = [[] for _ in range(len(group))]
        group_N = [[] for _ in range(len(group))]
        group_Z = [[] for _ in range(len(group))]
        print('Start reading data: %s.' % file)
        for i in range(len(group)):
            for sta in group[i]:
                if len(glob.glob(path + '/' + '*' + sta + '.*')) == 3:
                    all_group_sta_num[i] += 1
                    sacfile_E = glob.glob(path + '/' + '*' + sta + '.*' +
                                          'E')[0]
                    sacfile_N = glob.glob(path + '/' + '*' + sta + '.*' +
                                          'N')[0]
                    sacfile_Z = glob.glob(path + '/' + '*' + sta + '.*' +
                                          'Z')[0]
                    sac_file_name[0].append(sacfile_E.split('/')[-1])
                    sac_file_name[1].append(sacfile_N.split('/')[-1])
                    sac_file_name[2].append(sacfile_Z.split('/')[-1])
                    group_E[i].append(obspy.read(sacfile_E))
                    group_N[i].append(obspy.read(sacfile_N))
                    group_Z[i].append(obspy.read(sacfile_Z))
        flatten_group_E = [st for each_group in group_E for st in each_group]
        flatten_group_N = [st for each_group in group_N for st in each_group]
        flatten_group_Z = [st for each_group in group_Z for st in each_group]
        st_E = reduce(lambda st1, st2: st1 + st2, flatten_group_E)
        st_N = reduce(lambda st1, st2: st1 + st2, flatten_group_N)
        st_Z = reduce(lambda st1, st2: st1 + st2, flatten_group_Z)
        st_all = st_E + st_N + st_Z
        all_sta_num = len(flatten_group_Z)
        if resample:
            st_all = st_all.resample(sampling_rate=resample)
        if bandpass:
            st_all = st_all.filter('bandpass',
                                   freqmin=bandpass[0],
                                   freqmax=bandpass[1],
                                   corners=4,
                                   zerophase=True)
        endtime = st_all[0].stats.endtime

        start_flag = -1
        end_flag = -1
        event_list = []
        confidence_total = {}
        start_total = []
        end_total = []
        pos_num_total = []
        samples = 1.0 / st_all[0].stats.delta
        # npts = st_all[0].stats.npts
        print('Finish reading data.')

        print('Start detection.')
        for windowed_st in st_all.slide(window_length=(config.winsize - 1) / samples,
                                        step=config.winlag / samples):
            cur_sta = 0
            len_group_conf = 0
            group_class, group_conf = [], []
            # windowed_E = windowed_st[:all_sta_num]
            # windowed_N = windowed_st[all_sta_num:2*all_sta_num]
            # windowed_Z = windowed_st[2*all_sta_num:]
            start = len(windowed_st) // 3 * 2  # integer index: Z components start here
            end = len(windowed_st)
            group_max_conf = 0
            for i in range(len(group)):
                data_input = [[], [], []]
                group_sta_num = all_group_sta_num[i]
                if group_sta_num > 0:
                    for j in range(cur_sta, cur_sta + group_sta_num):
                        if len(windowed_st[j].data) < config.winsize:
                            windowed_st[j].data = np.concatenate([
                                windowed_st[j].data,
                                np.zeros(config.winsize -
                                         len(windowed_st[j].data))
                            ])
                        data_input[0].append(
                            windowed_st[j].data[:config.winsize])
                        # print(j, windowed_st[j])
                    for j in range(all_sta_num + cur_sta,
                                   all_sta_num + cur_sta + group_sta_num):
                        if len(windowed_st[j].data) < config.winsize:
                            windowed_st[j].data = np.concatenate([
                                windowed_st[j].data,
                                np.zeros(config.winsize -
                                         len(windowed_st[j].data))
                            ])
                        data_input[1].append(
                            windowed_st[j].data[:config.winsize])
                        # print(j, windowed_st[j])
                    for j in range(2 * all_sta_num + cur_sta,
                                   2 * all_sta_num + cur_sta + group_sta_num):
                        if len(windowed_st[j].data) < config.winsize:
                            windowed_st[j].data = np.concatenate([
                                windowed_st[j].data,
                                np.zeros(config.winsize -
                                         len(windowed_st[j].data))
                            ])
                        data_input[2].append(
                            windowed_st[j].data[:config.winsize])
                        # print(j, windowed_st[j])
                    plot_b = 2 * all_sta_num + cur_sta
                    plot_e = 2 * all_sta_num + cur_sta + group_sta_num
                    cur_sta += group_sta_num

                    if preprocess:
                        # distinct loop variable: the outer group loop binds `i`
                        for c in range(3):
                            for j in range(group_sta_num):
                                data_input[c][j] = data_preprocess(
                                    data_input[c][j])
                    data_input = np.array(data_input)

                    if len(data_input[0][0]) < config.winsize:
                        concat = np.zeros([
                            3, group_sta_num,
                            config.winsize - len(data_input[0][0])
                        ])
                        data_input = np.concatenate([data_input, concat],
                                                    axis=2)
                    else:
                        data_input = data_input[:, :, :config.winsize]
                    data_input = data_input.transpose((1, 2, 0))

                    j = 0
                    while j < len(data_input):
                        if np.max(data_input[j]) == 0 or np.isnan(
                                np.max(data_input[j])):
                            data_input = np.delete(data_input, j, axis=0)
                        else:
                            j += 1

                    if len(data_input) >= 3:
                        len_group_conf += 1
                        class_pred, confidence = model.classify(
                            sess=sess, input_=[data_input])
                        group_class.append(class_pred)
                        group_conf.append(confidence[0])
                        if confidence[0] > group_max_conf:
                            start = plot_b
                            end = plot_e
                            group_max_conf = confidence[0]
                    else:
                        group_class.append(0)
                        group_conf.append(0)
                else:
                    group_class.append(0)
                    group_conf.append(0)

            # consider the result of multiple groups
            pos_num = 0
            for each in group_class:
                if each == 1:
                    pos_num += 1
            if pos_num >= config.group_num_thrd:
                class_pred = 1
            else:
                class_pred = 0

            confidence = sum(group_conf) / len_group_conf if len_group_conf else 0

            # calculate the window range
            if class_pred == 1:
                confidence_total[confidence] = [group_max_conf, start, end]
                start_total.append(windowed_st[0].stats.starttime)
                end_total.append(windowed_st[0].stats.endtime)
                pos_num_total.append(pos_num)

                if start_flag == -1:
                    start_flag = windowed_st[0].stats.starttime
                    end_flag = windowed_st[0].stats.endtime
                else:
                    end_flag = windowed_st[0].stats.endtime
            print("{} {} {} {} {:.8f} {:.8f}".format(class_pred,start_flag,end_flag, \
                windowed_st[0].stats.starttime,confidence, group_max_conf))

            if class_pred == 0 and start_flag != -1:  #end_flag < windowed_st[0].stats.starttime:
                confidence = max(list(confidence_total.keys()))
                # for j in range(len(confidence_total)):
                #     if confidence == confidence_total[j]:
                #         break
                # start_local = start_total[j]
                # end_local = end_total[j]
                # event = [file, start_flag, end_flag,
                #          confidence, start_local, end_local]
                event_num += 1
                group_max_conf = confidence_total[confidence][0]
                start = confidence_total[confidence][1]
                end = confidence_total[confidence][2]
                event = [event_num, file, start_flag, end_flag, confidence, \
                    max(pos_num_total), start, end, group_max_conf]

                confidence_total = {}
                start_total = []
                end_total = []
                pos_num_total = []

                event_list.append(event)
                #print(event_list)

                start_flag = -1
                end_flag = -1

            if class_pred == 1 and end_flag + config.winlag / samples >= endtime:
                confidence = max(list(confidence_total.keys()))
                # for j in range(len(confidence_total)):
                #     if confidence == confidence_total[j]:
                #         break
                # start_local = start_total[j]
                # end_local = end_total[j]
                # event = [file.split('/')[-2], start_flag, endtime,
                #          confidence, start_total, end_total]
                event_num += 1
                group_max_conf = confidence_total[confidence][0]
                start = confidence_total[confidence][1]
                end = confidence_total[confidence][2]
                event = [event_num, file, start_flag, endtime, confidence, \
                    max(pos_num_total), start, end, group_max_conf]

                event_list.append(event)
                start_flag = -1
                end_flag = -1

        if event_list:
            new_event_list = [event_list[0]]
            for i in range(1, len(event_list)):
                if event_list[i][2] > new_event_list[-1][2] and \
                event_list[i][2] < new_event_list[-1][2] + 1000 / (config.resample if config.resample else 200):
                    # event layout is [event_num, file, start, end, ...]; a start
                    # within the gap extends the previous event's end time
                    new_event_list[-1][3] = event_list[i][3]
                else:
                    new_event_list.append(event_list[i])
        else:
            new_event_list = []

        # write event list
        if event_list:
            with open(config.root + '/event_detect/detect_result/' +
                      model_name + '/events_list.csv',
                      mode='a',
                      newline='') as f:
                csvwriter = csv.writer(f)
                for event in event_list:
                    csvwriter.writerow(event)

        if plot:
            print('Plot detected events.')
            for event in new_event_list:
                plot_traces = st_Z
                event_num, _, start_flag, end_flag, confidence, pos_num, start, end, group_max_conf = event
                name = config.root + '/event_detect/detect_result/png/' \
                        + str(int(event_num)) + '_' + str(confidence)[:4] + '.png'
                plot_traces.plot(starttime=start_flag,
                                 endtime=end_flag,
                                 size=(800, 800),
                                 automerge=False,
                                 equal_scale=False,
                                 linewidth=0.8,
                                 outfile=name)

                plot_traces2 = st_all[start:end]
                name2 = config.root + '/event_detect/detect_result/png2/' \
                        + str(int(event_num)) + '_' + str(group_max_conf)[:4] + '.png'
                plot_traces2.plot(starttime=start_flag,
                                  endtime=end_flag,
                                  size=(800, 800),
                                  automerge=False,
                                  equal_scale=False,
                                  linewidth=0.8,
                                  outfile=name2)

        ## cut use Obspy, processed data
        # if cut:
        #     print('Cut detected events.')
        #     for event in new_event_list:
        #         event_num, _, start_flag, end_flag, confidence, pos_num, start, end, group_max_conf = event
        #         slice_E = st_E.slice(start_flag, end_flag)
        #         slice_N = st_N.slice(start_flag, end_flag)
        #         slice_Z = st_Z.slice(start_flag, end_flag)
        #         save_path = config.root + '/event_detect/detect_result/cut/' \
        #                 + str(int(event_num)) + '_' + str(confidence)[:4]
        #         os.system('mkdir %s'%save_path)
        #         for i in range(len(slice_E)):
        #             slice_E[i].write(save_path+'/'+sac_file_name[0][i], format='SAC')
        #             slice_N[i].write(save_path+'/'+sac_file_name[1][i], format='SAC')
        #             slice_Z[i].write(save_path+'/'+sac_file_name[2][i], format='SAC')

        ## cut use SAC, raw data
        if cut:
            print('Cut detected events.')
            for event in new_event_list:
                event_num, _, start_flag, end_flag, confidence, pos_num, start, end, group_max_conf = event
                save_path = config.root + '/event_detect/detect_result/cut/' \
                    + str(int(event_num)) + '_' + str(confidence)[:4] + '/'
                os.system('mkdir %s' % save_path)
                cut_b = 60 * 60 * int(start_flag.hour) + 60 * int(
                    start_flag.minute) + float(start_flag.second)
                cut_e = 60 * 60 * int(end_flag.hour) + 60 * int(
                    end_flag.minute) + float(end_flag.second)
                ## SAC
                os.putenv("SAC_DISPLAY_COPYRIGHT", "0")
                p = subprocess.Popen(['sac'], stdin=subprocess.PIPE)

                s = ''

                s += "cut %s %s \n" % (cut_b, cut_e)
                s += "r %s/* \n" % (config.root + '/data/after/' + file)
                s += "w dir %s over \n" % (save_path)
                s += "quit \n"

                p.communicate(s.encode())

        start_point += 1
        with open(config.root + '/event_detect/detect_result/' + model_name +
                  '/checkpoint',
                  mode='w') as f:
            f.write(str(start_point) + '\n')
            f.write(str(event_num))
            end = datetime.datetime.now()
            print('{} completed, num {}, time {}.'.format(
                file, start_point, end - begin))
            print('Checkpoint saved.')
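
The de-duplication pass above folds detections whose start times fall within a
fixed gap into a single event. The same logic as a standalone helper; the
event layout [event_num, file, start, end, ...] follows the code above, and
gap_seconds corresponds to 1000 samples at the configured rate:

def merge_events(event_list, gap_seconds):
    # Merge detections whose start (index 2) lies within gap_seconds of the
    # previous kept event's start; extend that event's end (index 3) instead.
    if not event_list:
        return []
    merged = [event_list[0]]
    for ev in event_list[1:]:
        if merged[-1][2] < ev[2] < merged[-1][2] + gap_seconds:
            merged[-1][3] = ev[3]
        else:
            merged.append(ev)
    return merged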
Example #5
def main(model_name, new_scan=False, preprocess=True):
    reader = Reader()
    config = Config()

    reader.aftername.sort()

    if new_scan:
        print('start new scan!')
        file_list = reader.aftername

        start_point = 0
    else:
        with open('detect_result/' + model_name + '/checkpoint') as file:
            start_point = int(file.readline())
            file_list = reader.aftername[start_point:]
            print('restart from {}'.format(file_list[0]))

    if model_name == 'cnn':
        from event_detect.cnn import CNN
        import tensorflow as tf
        from tflib.models import Model

        model = CNN()
        sess = tf.Session()
        saver, global_step = Model.continue_previous_session(
            sess, model_file='cnn', ckpt_file='saver/cnn/checkpoint')

    if model_name == 'cldnn':
        from event_detect.cldnn import CLDNN
        import tensorflow as tf
        from tflib.models import Model

        model = CLDNN()
        sess = tf.Session()
        saver, global_step = Model.continue_previous_session(
            sess, model_file='cldnn', ckpt_file='saver/cldnn/checkpoint')

    for file in file_list:
        begin = datetime.datetime.now()
        traces = obspy.read(file[0])
        traces = traces + obspy.read(file[1])
        traces = traces + obspy.read(file[2])

        if not (traces[0].stats.starttime == traces[1].stats.starttime
                and traces[0].stats.starttime == traces[2].stats.starttime):
            starttime = max([
                traces[0].stats.starttime, traces[1].stats.starttime,
                traces[2].stats.starttime
            ])
            for j in range(3):
                traces[j] = traces[j].slice(starttime=starttime)

        if not (traces[0].stats.endtime == traces[1].stats.endtime
                and traces[0].stats.endtime == traces[2].stats.endtime):
            endtime = min([
                traces[0].stats.endtime, traces[1].stats.endtime,
                traces[2].stats.endtime
            ])
            for j in range(3):
                traces[j] = traces[j].slice(endtime=endtime)

        start_flag = -1
        end_flag = -1
        event_list = []

        for windowed_st in traces.slide(window_length=(config.winsize - 1) / 100.0,
                                        step=config.winlag / 100.0):
            data_input = []
            for j in range(3):
                data_input.append(windowed_st[j].data)

            if model_name == 'cnn':
                # raw_data = [data_preprocess(d, 'bandpass', False) for d in data_input]
                data_input = np.array(data_input).T
                if preprocess:
                    # data_input = sklearn.preprocessing.minmax_scale(data_input)

                    data_mean = np.mean(data_input, axis=0)
                    data_input = np.absolute(data_input - data_mean)
                    data_input = data_input / (np.max(data_input, axis=0) +
                                               np.array([1, 1, 1]))
                data_input = np.array([np.array([data_input])])
            elif model_name == 'cldnn':
                # raw_data = [data_preprocess(d, 'bandpass', False) for d in data_input]

                data_input = [data_preprocess(d) for d in data_input]
                data_input = np.array(data_input).T
                data_input = np.array([data_input])

            class_pred, confidence = model.classify(sess=sess,
                                                    input_=data_input)
            if class_pred == 1:

                # plt.subplot(3, 1, 1)
                # plt.plot(raw_data[0])
                # plt.subplot(3, 1, 2)
                # plt.plot(raw_data[1])
                # plt.subplot(3, 1, 3)
                # plt.plot(raw_data[2])
                # plt.show()

                if start_flag == -1:
                    start_flag = windowed_st[0].stats.starttime
                    end_flag = windowed_st[0].stats.endtime
                else:
                    end_flag = windowed_st[0].stats.endtime

            if class_pred == 0 and start_flag != -1 and \
                    end_flag < windowed_st[0].stats.starttime:
                event = [
                    file[0].split('\\')[-1][:-4], start_flag, end_flag,
                    confidence
                ]

                # print(event)

                event_list.append(event)
                start_flag = -1
                end_flag = -1

        if event_list:
            with open('detect_result/' + model_name + '/events_test.csv',
                      mode='a',
                      newline='') as f:
                csvwriter = csv.writer(f)
                for event in event_list:
                    csvwriter.writerow(event)

        start_point += 1
        with open('detect_result/' + model_name + '/checkpoint',
                  mode='w') as f:
            f.write(str(start_point))
            end = datetime.datetime.now()
            print('{} scanned, num {}, time {}.'.format(
                file[0].split('\\')[-1][:-4], start_point, end - begin))
            print('checkpoint saved.')
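
Both main variants persist scan progress as a tiny plain-text checkpoint so an
interrupted scan can resume. The pattern in isolation (function names and the
optional second line are illustrative, not from the project):

def save_progress(path, start_point, event_num=None):
    # One integer per line: files scanned so far, optionally events found.
    with open(path, mode='w') as f:
        f.write(str(start_point) + '\n')
        if event_num is not None:
            f.write(str(event_num))

def load_progress(path):
    with open(path) as f:
        return int(f.readline())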
Example #6
    def start_sess(self, sess):
        saver, global_step = Model.continue_previous_session(sess,
                                                             model_file='cblstm',
                                                             ckpt_file='eqpickup/saver/cblstm/checkpoint')
Example #7
    def train(self, passes, new_training=True):
        sess_config = tf.ConfigProto()
        sess_config.gpu_options.allow_growth = True
        with tf.Session(config=sess_config) as sess:
            if new_training:
                saver, global_step = Model.start_new_session(sess)
            else:
                saver, global_step = Model.continue_previous_session(sess,
                                                                     model_file='cblstm',
                                                                     ckpt_file='eqpickup/saver/cblstm/checkpoint')

            self.train_writer.add_graph(sess.graph, global_step=global_step)

            for step in range(1 + global_step, 1 + passes + global_step):
                with tf.variable_scope('Train'):
                    input_, targets = self.reader.get_cblstm_batch_data('train',
                                                                        self.reader.pre_data_generator,
                                                                        self.reader.pre_validation_data)
                    input_, seq_len = self.data_padding_preprocess(input_, 'input')
                    targets, _ = self.data_padding_preprocess(targets, 'targets')
                    _, train_summary, loss, pred_seq = sess.run(
                        [self.train_op, self.train_merged, self.loss, self.layer['pred_seq']],
                        feed_dict={self.layer['input']: input_,
                                   self.layer['targets']: targets,
                                   self.layer['seq_len']: seq_len,
                                   self.layer['keep_prob']: self.config.dl_tradition_model_config.cblstm_keep_prob})
                    self.train_writer.add_summary(train_summary, step)

                    train_p_err, train_p_err_max, train_s_err, train_s_err_max = get_p_s_error(pred_seq, targets,
                                                                                               seq_len)
                    train_acc = get_acc(pred_seq, targets, seq_len)

                    [train_metrics_summary] = sess.run(
                        [self.train_metrics_merged],
                        feed_dict={self.train_metrics['acc']: train_acc,
                                   self.train_metrics['p_error']: train_p_err,
                                   self.train_metrics['p_error_max']: train_p_err_max,
                                   self.train_metrics['s_error']: train_s_err,
                                   self.train_metrics['s_error_max']: train_s_err_max})
                    self.train_writer.add_summary(train_metrics_summary, step)
                    print("gobal_step {},"
                          " training_loss {},"
                          " accuracy {},"
                          " p_error {},"
                          " p_err_max {},"
                          " s_error {},"
                          " s_err_max {}.".format(step, loss, train_acc, train_p_err, train_p_err_max, train_s_err,
                                                  train_s_err_max))

                if step % 50 == 0:
                    with tf.variable_scope('Test', reuse=True):
                        test_input, test_targets = self.reader.get_cblstm_batch_data('test',
                                                                                     self.reader.pre_data_generator,
                                                                                     self.reader.pre_validation_data)
                        test_input, test_seq_len = self.data_padding_preprocess(test_input, 'input')
                        test_targets, _ = self.data_padding_preprocess(test_targets, 'targets')
                        [test_pred_seq] = sess.run([self.layer['pred_seq']],
                                                   feed_dict={self.layer['input']: test_input,
                                                              self.layer['seq_len']: test_seq_len,
                                                              self.layer['keep_prob']: 1.0})
                        test_p_err, test_p_err_max, test_s_err, test_s_err_max = get_p_s_error(test_pred_seq,
                                                                                               test_targets,
                                                                                               test_seq_len)
                        test_acc = get_acc(test_pred_seq,
                                           test_targets,
                                           test_seq_len)
                        [test_metrics_summary] = sess.run(
                            [self.test_metrics_merged],
                            feed_dict={self.test_metrics['acc']: test_acc,
                                       self.test_metrics['p_error']: test_p_err,
                                       self.test_metrics['p_error_max']: test_p_err_max,
                                       self.test_metrics['s_error']: test_s_err,
                                       self.test_metrics['s_error_max']: test_s_err_max})
                        self.train_writer.add_summary(test_metrics_summary, step)
                        print("test_acc {}, "
                              "test_p_err {},"
                              "test_p_err_max {},"
                              "test_s_err {},"
                              "test_s_err_max {}.".format(test_acc, test_p_err, test_p_err_max, test_s_err,
                                                          test_s_err_max))

                if step % 50 == 0:
                    saver.save(sess, 'eqpickup/saver/cblstm/cblstm', global_step=step)
                    print('checkpoint saved')
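
data_padding_preprocess is not shown on this page. For variable-length batches
like these, a plausible right-padding helper that also returns the lengths fed
to layer['seq_len'] might look as follows; shapes and the zero pad value are
assumptions:

import numpy as np

def pad_batch(seqs, pad_value=0.0):
    # Right-pad each [time, features] array to the batch maximum length and
    # return the original lengths for the dynamic-RNN seq_len feed.
    seq_len = np.array([len(s) for s in seqs])
    max_len = int(seq_len.max())
    feat_dim = np.asarray(seqs[0]).shape[-1]
    out = np.full((len(seqs), max_len, feat_dim), pad_value, dtype=np.float32)
    for k, s in enumerate(seqs):
        out[k, :len(s)] = s
    return out, seq_len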
Example #8
        if get_pred_seq:
            return p_index, s_index, class_prob, pred_seq
        else:
            return p_index, s_index, class_prob


if __name__ == '__main__':
    # cblstm = CBLSTM()
    # cblstm.train(10000, False)

    ### test acc
    cblstm = CBLSTM()
    sess_config = tf.ConfigProto()
    sess_config.gpu_options.allow_growth = True
    sess = tf.Session(config=sess_config)
    saver, global_step = Model.continue_previous_session(
        sess, model_file='cblstm', ckpt_file='saver/cblstm/checkpoint')

    acc_ave = 0
    p_err_ave = 0
    s_err_ave = 0
    p_err_max = []
    s_err_max = []
    for i in range(20):
        test_input, test_targets = cblstm.reader.get_birnn_batch_data('test')
        test_targets, test_seq_len = cblstm.data_padding_preprocess(
            test_targets, 'targets')
        events_p_index, events_s_index, _, pred_seq = cblstm.pickup_p_s(
            sess, test_input, get_pred_seq=True)
        tmp_acc = get_acc(pred_seq, test_targets, test_seq_len)
        acc_ave += tmp_acc
        tmp_p_error, tmp_p_err_max, tmp_s_error, tmp_s_err_max = get_p_s_error(
            pred_seq, test_targets, test_seq_len)  # arguments inferred from the call pattern above
Example #9
    def start_sess(self, sess):
        saver, global_step = Model.continue_previous_session(sess,
                                                             model_file='cnn',
                                                             ckpt_file='eqdetector/saver/cnn/checkpoint')
Example #10
def main(model_name, new_scan=False, preprocess=True):
    reader = Reader()
    config = Config()
    plot = config.plot
    bandpass = config.bandpass
    resample = config.resample

    confidence0 = []
    plot_num = 0
    reader.aftername.sort()

    if new_scan:
        print('start new scan!')
        file_list = reader.aftername

        start_point = 0
    else:
        with open(config.root + '/event_detect/detect_result/' + model_name +
                  '/checkpoint') as file:
            start_point = int(file.readline())
            file_list = reader.aftername[start_point:]
            print('restart from {}'.format(file_list[0]))

    if model_name == 'cnn':
        from cnn import CNN
        import tensorflow as tf
        from tflib.models import Model

        model = CNN()
        sess = tf.Session(config=tf.ConfigProto(device_count={"CPU": 20},
                                                inter_op_parallelism_threads=0,
                                                intra_op_parallelism_threads=0))
        saver, global_step = Model.continue_previous_session(sess,
                                                             model_file='cnn',
                                                             ckpt_file=config.root + '/event_detect/saver/cnn/checkpoint')
    file_list_len = len(file_list)
    # print(file_list_len)

    try:
        os.system('rm -rf %s/event_detect/detect_result/png/*'%config.root)
        os.system('rm -rf %s/event_detect/detect_result/cnn/*.csv'%config.root)
    except:
        pass

    for file in file_list:
        file = np.array(file)
        #print(file)
        #file=file.T
        #np.random.shuffle(file)  #random
        #file=file.T

        #print(file,'\n')
        begin = datetime.datetime.now()
        if plot:
            plot_traces = obspy.read(file[2][0]) #Z component
        sta_num = len(file[0])
        trace_len = []

        for i in range(3):
            for j in range(sta_num):
                trace_len.append(obspy.read(file[i][j])[0].stats.npts)
        max_len = max(trace_len)

        for i in range(3):
            for j in range(sta_num):        # station number
                each_tr = obspy.read(file[i][j])
                if each_tr[0].stats.npts < max_len:
                    zero = np.zeros(max_len - each_tr[0].stats.npts)
                    each_tr[0].data = np.concatenate([each_tr[0].data, zero])
                if i == j == 0:
                    traces = each_tr
                else:
                    traces = traces + each_tr
                # the first Z trace is already in plot_traces (read above)
                if i == 2 and j != 0:
                    plot_traces = plot_traces + each_tr

        if plot:
            if resample:
                plot_traces = plot_traces.resample(sampling_rate=resample)
            plot_traces = plot_traces.filter('bandpass', freqmin=bandpass[0],
                                             freqmax=bandpass[1], corners=4,
                                             zerophase=True)

        if resample:
            traces = traces.resample(sampling_rate=resample)
        traces = traces.filter('bandpass', freqmin=bandpass[0],
                               freqmax=bandpass[1], corners=4, zerophase=True)
        starttime = traces[0].stats.starttime
        endtime = traces[0].stats.endtime
        #print(traces)

        start_flag = -1
        end_flag = -1
        event_list = []
        confidence_total = []
        start_total = []
        end_total = []
        samples_trace = 1.0 / traces[0].stats.delta
        npts = traces[0].stats.npts


        for windowed_st in traces.slide(window_length=(config.winsize - 1) / samples_trace,
                                        step=config.winlag / samples_trace):
            data_input = [[], [], []]

            for j in range(sta_num):
                if len(windowed_st[j].data) < config.winsize:
                    windowed_st[j].data = np.concatenate(
                        [windowed_st[j].data,
                         np.zeros(config.winsize - len(windowed_st[j].data))])
                data_input[0].append(windowed_st[j].data[:config.winsize])
            for j in range(sta_num, 2 * sta_num):
                if len(windowed_st[j].data) < config.winsize:
                    windowed_st[j].data = np.concatenate(
                        [windowed_st[j].data,
                         np.zeros(config.winsize - len(windowed_st[j].data))])
                data_input[1].append(windowed_st[j].data[:config.winsize])
            for j in range(2 * sta_num, 3 * sta_num):
                if len(windowed_st[j].data) < config.winsize:
                    windowed_st[j].data = np.concatenate(
                        [windowed_st[j].data,
                         np.zeros(config.winsize - len(windowed_st[j].data))])
                data_input[2].append(windowed_st[j].data[:config.winsize])

            if model_name == 'cnn':

                if preprocess:
                    for c in range(3):
                        for j in range(sta_num):
                            data_input[c][j] = data_preprocess(data_input[c][j])

                data_input = np.array(data_input)

                if len(data_input[0][0]) < config.winsize:
                    concat = np.zeros([3, sta_num, config.winsize - len(data_input[0][0])])
                    data_input = np.concatenate([data_input, concat], axis=2)

                if len(data_input[0][0]) > config.winsize:
                    data_input = data_input[:, :, :config.winsize]

                data_input = data_input.transpose((1, 2, 0))
                data_input = np.array([data_input])
                #print(event_list)

            class_pred, confidence = model.classify(sess=sess, input_=data_input)
            confidence0.append(confidence)

            print(class_pred, confidence)
            if class_pred == 1:
                confidence_total.append(confidence)
                start_total.append(windowed_st[0].stats.starttime)
                end_total.append(windowed_st[0].stats.endtime)

                if start_flag == -1:
                    start_flag = windowed_st[0].stats.starttime
                    end_flag = windowed_st[0].stats.endtime
                else:
                    end_flag = windowed_st[0].stats.endtime
            print(class_pred, start_flag, end_flag, windowed_st[0].stats.starttime)

            if class_pred == 0 and start_flag != -1:  # end_flag < windowed_st[0].stats.starttime:
                confidence = np.max(confidence_total)
                for j in range(len(confidence_total)):
                    if confidence == confidence_total[j]:
                        break
                start_local = start_total[j]
                end_local = end_total[j]


                # event = [file[0][0].split('/')[-2], start_flag, end_flag,
                #          confidence, start_local, end_local]
                event = [file[0][0].split('/')[-2], start_flag, end_flag, confidence]

                confidence_total = []
                start_total = []
                end_total = []

                if plot:
                    plot_num = int(plot_num + 1)
                    name = config.root + '/event_detect/detect_result/png/' \
                           + str(plot_num) + '_' + str(confidence) + '.png'
                    plot_traces.plot(starttime=start_flag, endtime=end_flag, size=(800, 800),
                                     automerge=False, equal_scale=False, linewidth=0.8, outfile=name)

                # print(event)

                event_list.append(event)
                #print(event_list)

                start_flag = -1
                end_flag = -1

            if class_pred == 1 and end_flag + config.winlag / samples_trace >= endtime:
                confidence = np.max(confidence_total)
                for j in range(len(confidence_total)):
                    if confidence == confidence_total[j]:
                        break
                start_local = start_total[j]
                end_local = end_total[j]

                if plot:
                    plot_num = int(plot_num + 1)
                    name = config.root + '/event_detect/detect_result/png/' \
                           + str(plot_num) + '_' + str(confidence) + '.png'
                    plot_traces.plot(starttime=start_flag, endtime=endtime, size=(800, 800),
                                     automerge=False, equal_scale=False, linewidth=0.8, outfile=name)

                # event = [file[0][0].split('/')[-2], start_flag, endtime,
                #          confidence, start_total, end_total]
                event = [file[0][0].split('/')[-2], start_flag, endtime, confidence]

                event_list.append(event)
                start_flag = -1
                end_flag = -1

        if event_list:
            with open(config.root + '/event_detect/detect_result/' +
                      model_name + '/events_list.csv',
                      mode='a', newline='') as f:
                csvwriter = csv.writer(f)
                for event in event_list:
                    csvwriter.writerow(event)

        start_point += 1
        with open(config.root + '/event_detect/detect_result/' + model_name +
                  '/checkpoint', mode='w') as f:
            f.write(str(start_point))
            end = datetime.datetime.now()
            print('{} scanned, num {}, time {}.'.format(
                file[0][0].split('/')[-2], start_point, end - begin))
            print('checkpoint saved.')
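
data_preprocess is likewise external to this page. A sketch consistent with
the inline CNN normalization in the second main above (demean, rectify, scale
by the maximum plus one so all-zero padding stays zero); this is an assumption
about its behavior, not the project's definition:

import numpy as np

def data_preprocess(trace):
    trace = np.asarray(trace, dtype=np.float64)
    trace = np.absolute(trace - trace.mean())  # demean, then magnitudes
    return trace / (trace.max() + 1.0)         # scale into [0, 1)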
Example #11
    def train(self, passes, new_training=True):
        sess_config = tf.ConfigProto()
        sess_config.gpu_options.allow_growth = True
        sess_config.allow_soft_placement = True
        with tf.Session(config=sess_config) as sess:
            if new_training:
                saver, global_step = Model.start_new_session(sess)
            else:
                saver, global_step = Model.continue_previous_session(sess,
                                                                     model_file='model/saver/{}'.format(self.saveFile),
                                                                     ckpt_file='model/saver/{}/checkpoint'.format(
                                                                         self.saveFile))

            self.train_writer.add_graph(sess.graph, global_step=global_step)

            walk_times = 1
            for step in range(1 + global_step, 1 + passes + global_step):
                with tf.variable_scope('Train'):
                    walk_nodes = self.reader.nodes_walk_reader()
                    neg_walk_nodes = [self.g.negative_sample(walk_nodes[i],
                                                             self.config.loss1_neg_sample_num,
                                                             self.g.nodes_degree_table)
                                      for i in range(len(walk_nodes))]
                    neg_walk_nodes = np.array(neg_walk_nodes)
                    walk_nodes_labels = list()
                    for node_list in walk_nodes:
                        nodes_label_tmp = self.g.get_train_node_label(node_list)
                        walk_nodes_labels.append(nodes_label_tmp)
                    walk_nodes_labels = np.array(walk_nodes_labels)

                    # if (step - 1) % int(self.g.nodes_num / self.config.nodes_seq_batch_num) == 0:
                    #     print(walk_times)
                    #     walk_times += 1

                    if step < 200 and self.init_emb_file is not None:
                        train_op = self.train_op[1]
                    else:
                        train_op = self.train_op[0]
                    _, train_summary, loss = sess.run(
                        [train_op,
                         self.loss_train_merged,
                         self.layer['loss']],
                        feed_dict={self.layer['walk_nodes']: walk_nodes,
                                   self.layer['walk_nodes_labels']: walk_nodes_labels,
                                   self.layer['neg_walk_nodes']: neg_walk_nodes})

                    self.train_writer.add_summary(train_summary, step)

                    if step % 500 == 0 or step == 1:
                        [node_emb, sup_emb] = sess.run([self.layer['emb'],
                                                        self.layer['sup_emb']])
                        node_emb = np.concatenate((node_emb, sup_emb), axis=1)
                        print("gobal_step {},loss {}".format(step, loss))

                    if step % 1000 == 0 or step == 1:
                        micro_f1, macro_f1 = self.multi_label_node_classification(node_emb)
                        [test_summary] = sess.run([self.test_merged],
                                                  feed_dict={self.test_metrics['micro_f1']: micro_f1,
                                                             self.test_metrics['macro_f1']: macro_f1})
                        print("micro_f1 {},macro_f1 {}".format(micro_f1, macro_f1))
                        self.train_writer.add_summary(test_summary, step)
                        saver.save(sess, 'model/saver/{}/MPRSNE'.format(self.saveFile), global_step=step)
                        print('checkpoint saved')
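
self.g.negative_sample is assumed in Example #11. A common word2vec-style
implementation draws candidates from a degree-weighted unigram table and
rejects nodes already in the walk; the table layout here is an assumption:

import numpy as np

def negative_sample(walk_nodes, num_neg, nodes_degree_table, rng=np.random):
    # nodes_degree_table: node ids repeated in proportion to degree ** 0.75,
    # as in word2vec-style unigram tables.
    walk_set = set(walk_nodes)
    negs = []
    while len(negs) < num_neg:
        cand = nodes_degree_table[rng.randint(len(nodes_degree_table))]
        if cand not in walk_set:
            negs.append(cand)
    return negs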