def train(): """ Train unet using specified args: """ data_files, data_size = load_datafiles(FLAGS.tfrecords_prefix) print data_files, data_size # images, labels, filenames = dataset_loader.inputs( # data_files = data_files, # image_size = FLAGS.image_size, # batch_size = FLAGS.batch_size, # num_epochs = FLAGS.num_epochs, # train = True) setproctitle.setproctitle('quakenet') tf.set_random_seed(1234) cfg = config.Config() cfg.batch_size = FLAGS.batch_size cfg.add = 1 cfg.n_clusters = FLAGS.num_classes cfg.n_clusters += 1 # data pipeline for positive and negative examples pos_pipeline = dp.DataPipeline(FLAGS.tfrecords_dir, cfg, True) # images:[batch_size, n_channels, n_points] images = pos_pipeline.samples labels = pos_pipeline.labels logits = unet.build_30s(images, FLAGS.num_classes, True) accuarcy = unet.accuracy(logits, labels) print "accuarcy,recall,f1", accuarcy #load class weights if available if FLAGS.class_weights is not None: weights = np.load(FLAGS.class_weights) class_weight_tensor = tf.constant(weights, dtype=tf.float32, shape=[FLAGS.num_classes, 1]) else: class_weight_tensor = None loss = unet.loss(logits, labels, FLAGS.weight_decay_rate) global_step = tf.Variable(0, name='global_step', trainable=False) train_op = unet.train(loss, FLAGS.learning_rate, FLAGS.learning_rate_decay_steps, FLAGS.learning_rate_decay_rate, global_step) #print "train_op",train_op init_op = tf.group(tf.global_variables_initializer(), tf.local_variables_initializer()) saver = tf.train.Saver() session_manager = tf.train.SessionManager( local_init_op=tf.local_variables_initializer()) sess = session_manager.prepare_session("", init_op=init_op, saver=saver, checkpoint_dir=FLAGS.checkpoint_dir) writer = tf.summary.FileWriter(FLAGS.checkpoint_dir + "/train_logs", sess.graph) merged = tf.summary.merge_all() coord = tf.train.Coordinator() threads = tf.train.start_queue_runners(sess=sess, coord=coord) start_time = time.time() try: while not coord.should_stop(): step = tf.train.global_step(sess, global_step) _, loss_value, summary = sess.run([train_op, loss, merged]) #print loss_value writer.add_summary(summary, step) if step % 1000 == 0: acc_seg_value = sess.run([accuarcy]) #print "acc_seg_value:",acc_seg_value,acc_seg_value[0],acc_seg_value[0][1],acc_seg_value[0][1][0] epoch = step * FLAGS.batch_size / data_size #print epoch duration = time.time() - start_time #print step,duration start_time = time.time() #print('[PROGRESS]\tEpoch %d | Step %d | loss = %.2f | total. acc. = %.2f | P. acc. = %.3f \ # | S. acc. = %.3f | N. acc. = %.3f | dur. = (%.3f sec)'\ # % (epoch, step, loss_value, acc_seg_value[0][1][0],acc_seg_value[0][1][1], acc_seg_value[0][1][2],\ # acc_seg_value[0][3],duration)) print('[PROGRESS]\tEpoch %d | Step %d | loss = %.2f | P. acc. = %.3f \ | S. acc. = %.3f | N. acc. = %.3f | dur. = (%.3f sec)'\ % (epoch, step, loss_value, acc_seg_value[0][1][1],acc_seg_value[0][1][2], acc_seg_value[0][1][0],\ duration)) if step % 5000 == 0: print('[PROGRESS]\tSaving checkpoint') checkpoint_path = os.path.join(FLAGS.checkpoint_dir, 'unet.ckpt') saver.save(sess, checkpoint_path, global_step=step) except tf.errors.OutOfRangeError: print('[INFO ]\tDone training for %d epochs, %d steps.' % (FLAGS.num_epochs, step)) finally: # When done, ask the threads to stop. coord.request_stop() # Wait for threads to finish. coord.join(threads) writer.close() sess.close()
def evaluate(): """ Eval unet using specified args: Note: restore the pretrained model from checkpoint!! """ data_files, data_size = load_datafiles(FLAGS.tfrecords_prefix) images, labels, filenames = dataset_loader.inputs( data_files=data_files, image_size=FLAGS.image_size, batch_size=FLAGS.batch_size, num_epochs=1, train=False) logits = unet.build(images, FLAGS.num_classes, False) predicted_images = unet.predict(logits, FLAGS.batch_size, FLAGS.image_size) accuracy = unet.accuracy(logits, labels) init_op = tf.group(tf.global_variables_initializer(), tf.local_variables_initializer()) sess = tf.Session() sess.run(init_op) saver = tf.train.Saver() if not tf.gfile.Exists(FLAGS.checkpoint_path + '.meta'): raise ValueError("Can't find checkpoint file") else: print('[INFO ]\tFound checkpoint file, restoring model.') saver.restore(sess, FLAGS.checkpoint_path) coord = tf.train.Coordinator() threads = tf.train.start_queue_runners(sess=sess, coord=coord) global_accuracy = 0.0 step = 0 try: while not coord.should_stop(): acc_seg_value, predicted_images_value, filenames_value = sess.run( [accuracy, predicted_images, filenames]) global_accuracy += acc_seg_value maybe_save_images(predicted_images_value, filenames_value) print('[PROGRESS]\tAccuracy for current batch: %.5f' % (acc_seg_value)) step += 1 except tf.errors.OutOfRangeError: print('[INFO ]\tDone evaluating in %d steps.' % step) finally: # When done, ask the threads to stop. coord.request_stop() global_accuracy = global_accuracy / step print('[RESULT ]\tGlobal accuracy = %.5f' % (global_accuracy)) # Wait for threads to finish. coord.join(threads) sess.close()
def train(): """ Train unet using specified args: """ data_files, data_size = load_datafiles(FLAGS.tfrecords_prefix) images, labels, filenames = dataset_loader.inputs( data_files = data_files, image_size = FLAGS.image_size, batch_size = FLAGS.batch_size, num_epochs = FLAGS.num_epochs, train = True) logits = unet.build(images, FLAGS.num_classes, True) accuarcy = unet.accuracy(logits, labels) #load class weights if available if FLAGS.class_weights is not None: weights = np.load(FLAGS.class_weights) class_weight_tensor = tf.constant(weights, dtype=tf.float32, shape=[FLAGS.num_classes, 1]) else: class_weight_tensor = None loss = unet.loss(logits, labels, FLAGS.weight_decay_rate, class_weight_tensor) global_step = tf.Variable(0, name = 'global_step', trainable = False) train_op = unet.train(loss, FLAGS.learning_rate, FLAGS.learning_rate_decay_steps, FLAGS.learning_rate_decay_rate, global_step) init_op = tf.group(tf.global_variables_initializer(), tf.local_variables_initializer()) saver = tf.train.Saver() session_manager = tf.train.SessionManager(local_init_op = tf.local_variables_initializer()) sess = session_manager.prepare_session("", init_op = init_op, saver = saver, checkpoint_dir = FLAGS.checkpoint_dir) writer = tf.summary.FileWriter(FLAGS.checkpoint_dir + "/train_logs", sess.graph) merged = tf.summary.merge_all() coord = tf.train.Coordinator() threads = tf.train.start_queue_runners(sess = sess, coord = coord) start_time = time.time() try: while not coord.should_stop(): step = tf.train.global_step(sess, global_step) _, loss_value, summary = sess.run([train_op, loss, merged]) writer.add_summary(summary, step) if step % 1000 == 0: acc_seg_value = sess.run([accuarcy]) epoch = step * FLAGS.batch_size / data_size duration = time.time() - start_time start_time = time.time() print('[PROGRESS]\tEpoch %d, Step %d: loss = %.2f, accuarcy = %.2f (%.3f sec)' % (epoch, step, loss_value, acc_seg_value, duration)) if step % 5000 == 0: print('[PROGRESS]\tSaving checkpoint') checkpoint_path = os.path.join(FLAGS.checkpoint_dir, 'unet.ckpt') saver.save(sess, checkpoint_path, global_step = step) except tf.errors.OutOfRangeError: print('[INFO ]\tDone training for %d epochs, %d steps.' % (FLAGS.num_epochs, step)) finally: # When done, ask the threads to stop. coord.request_stop() # Wait for threads to finish. coord.join(threads) writer.close() sess.close()
def evaluate(): """ Eval unet using specified args: """ if FLAGS.events: summary_dir = os.path.join(FLAGS.checkpoint_path, "events") while True: ckpt = tf.train.get_checkpoint_state(FLAGS.checkpoint_path) if FLAGS.eval_interval < 0 or ckpt: print('Evaluating model') break print('Waiting for training job to save a checkpoint') time.sleep(FLAGS.eval_interval) #data_files, data_size = load_datafiles(FLAGS.tfrecords_prefix) setproctitle.setproctitle('quakenet') tf.set_random_seed(1234) cfg = config.Config() cfg.batch_size = FLAGS.batch_size cfg.add = 1 cfg.n_clusters = FLAGS.num_classes cfg.n_clusters += 1 cfg.n_epochs = 1 model_files = [ file for file in os.listdir(FLAGS.checkpoint_path) if fnmatch.fnmatch(file, '*.meta') ] for model_file in sorted(model_files): step = model_file.split(".meta")[0].split("-")[1] print(step) try: model_file = os.path.join(FLAGS.checkpoint_path, model_file) # data pipeline for positive and negative examples pos_pipeline = dp.DataPipeline(FLAGS.tfrecords_dir, cfg, True) # images:[batch_size, n_channels, n_points] images = pos_pipeline.samples labels = pos_pipeline.labels logits = unet.build_30s(images, FLAGS.num_classes, False) predicted_images = unet.predict(logits, FLAGS.batch_size, FLAGS.image_size) accuracy = unet.accuracy(logits, labels) loss = unet.loss(logits, labels, FLAGS.weight_decay_rate) summary_writer = tf.summary.FileWriter(summary_dir, None) init_op = tf.group(tf.global_variables_initializer(), tf.local_variables_initializer()) sess = tf.Session() sess.run(init_op) saver = tf.train.Saver() #if not tf.gfile.Exists(FLAGS.checkpoint_path + '.meta'): if not tf.gfile.Exists(model_file): raise ValueError("Can't find checkpoint file") else: print('[INFO ]\tFound checkpoint file, restoring model.') saver.restore(sess, model_file.split(".meta")[0]) coord = tf.train.Coordinator() threads = tf.train.start_queue_runners(sess=sess, coord=coord) #metrics = validation_metrics() global_accuracy = 0.0 global_p_accuracy = 0.0 global_s_accuracy = 0.0 global_n_accuracy = 0.0 global_loss = 0.0 n = 0 #mean_metrics = {} #for key in metrics: # mean_metrics[key] = 0 #pred_labels = np.empty(1) #true_labels = np.empty(1) try: while not coord.should_stop(): acc_seg_value, loss_value, predicted_images_value, images_value = sess.run( [accuracy, loss, predicted_images, images]) accuracy_p_value = acc_seg_value[1][1] accuracy_s_value = acc_seg_value[1][2] accuracy_n_value = acc_seg_value[1][0] #pred_labels = np.append(pred_labels, predicted_images_value) #true_labels = np.append(true_labels, images_value) global_accuracy += acc_seg_value global_p_accuracy += accuracy_p_value global_s_accuracy += accuracy_s_value global_n_accuracy += accuracy_n_value global_loss += loss_value # print true_labels #for key in metrics: # mean_metrics[key] += cfg.batch_size * metrics_[key] filenames_value = [] # for i in range(FLAGS.batch_size): # filenames_value.append(str(step)+"_"+str(i)+".png") #print (predicted_images_value[:,100:200]) if (FLAGS.plot): maybe_save_images(predicted_images_value, images_value, filenames_value) #s='loss = {:.5f} | det. acc. = {:.1f}% | loc. acc. = {:.1f}%'.format(metrics['loss'] print( '[PROGRESS]\tAccuracy for current batch: | P. acc. =%.5f| S. acc. =%.5f| ' 'noise. acc. =%.5f.' 
% (accuracy_p_value, accuracy_s_value, accuracy_n_value)) n += cfg.batch_size # step += 1 print(n) except KeyboardInterrupt: print('stopping evaluation') except tf.errors.OutOfRangeError: print('Evaluation completed ({} epochs).'.format(cfg.n_epochs)) print("{} windows seen".format(n)) #print('[INFO ]\tDone evaluating in %d steps.' % step) if n > 0: loss_value /= n summary = tf.Summary(value=[ tf.Summary.Value(tag='loss/val', simple_value=loss_value) ]) if FLAGS.save_summary: summary_writer.add_summary(summary, global_step=step) global_accuracy /= n global_p_accuracy /= n global_s_accuracy /= n global_n_accuracy /= n summary = tf.Summary(value=[ tf.Summary.Value(tag='accuracy/val', simple_value=global_accuracy) ]) if FLAGS.save_summary: summary_writer.add_summary(summary, global_step=step) summary = tf.Summary(value=[ tf.Summary.Value(tag='accuracy/val_p', simple_value=global_p_accuracy) ]) if FLAGS.save_summary: summary_writer.add_summary(summary, global_step=step) summary = tf.Summary(value=[ tf.Summary.Value(tag='accuracy/val_s', simple_value=global_s_accuracy) ]) if FLAGS.save_summary: summary_writer.add_summary(summary, global_step=step) summary = tf.Summary(value=[ tf.Summary.Value(tag='accuracy/val_noise', simple_value=global_n_accuracy) ]) if FLAGS.save_summary: summary_writer.add_summary(summary, global_step=step) print( '[End of evaluation for current epoch]\n\nAccuracy for current epoch:%s | total. acc. =%.5f| P. acc. =%.5f| S. acc. =%.5f| ' 'noise. acc. =%.5f.' % (step, global_accuracy, global_p_accuracy, global_s_accuracy, global_n_accuracy)) print('Sleeping for {}s'.format(FLAGS.eval_interval)) time.sleep(FLAGS.eval_interval) summary_writer.flush() finally: # When done, ask the threads to stop. coord.request_stop() tf.reset_default_graph() #print('Sleeping for {}s'.format(FLAGS.eval_interval)) #time.sleep(FLAGS.eval_interval) finally: print('joining data threads') coord = tf.train.Coordinator() coord.request_stop() #pred_labels = pred_labels[1::] #true_labels = true_labels[1::] #print ("---Confusion Matrix----") #print (confusion_matrix(true_labels, pred_labels)) # Wait for threads to finish. coord.join(threads) sess.close()
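# A minimal, assumed command-line entry point; it is not part of the original
# section. The 'mode' flag defined here is hypothetical and only used to choose
# between the train() and evaluate() routines above.
tf.app.flags.DEFINE_string('mode', 'train', "Either 'train' or 'eval' (hypothetical flag).")


def main(_):
    if FLAGS.mode == 'train':
        train()
    else:
        evaluate()


if __name__ == '__main__':
    tf.app.run()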