Example #1
import os
import time

import numpy as np
import tensorflow as tf

import dataset_loader  # project-local input pipeline used by train() below
import read             # project-local .npy data-loading helpers
import unet             # project-local U-Net model definition

BATCH_SIZE = 8
NUM_EPOCHS = 100
class_weights = [0.6, 0.4]

print("Start loading...")
x_train = read.load_data_npy("/cmach-data/segthor/Train/")
y_train = read.load_label_npy("/cmach-data/segthor/Train/")
print("Finish loading.")

X = tf.placeholder(tf.float32, [BATCH_SIZE, 512, 512, 1], name='images')
Y = tf.placeholder(tf.int32, [BATCH_SIZE, 512, 512, 1], name='labels')

logits = unet.inference(X)

loss = unet.loss(logits, Y)  #, class_weights=class_weights)
tf.summary.scalar('total_loss', loss)

global_step = tf.train.get_or_create_global_step()
# lr = tf.train.exponential_decay(INITIAL_LEARNING_RATE,
#                                 global_step,
#                                 NUM_EPOCHS_PER_DECAY,
#                                 LEARNING_RATE_DECAY_FACTOR,
#                                 staircase=True)
# lr = tf.train.piecewise_constant(global_step,[500,700,1200,1600,2000,2500,3200,4000],[0.1,0.05,0.02,0.01,0.004,0.001,0.0005,0.0001,0.00005])
# tf.summary.scalar('learning_rate', lr)

# optimizer = tf.train.GradientDescentOptimizer(learning_rate=lr).minimize(loss, global_step=global_step)
optimizer = tf.train.AdamOptimizer(learning_rate=1e-4).minimize(
    loss, global_step=global_step)
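
# The placeholder graph defined above (X, Y, loss, optimizer) is never actually
# run in this example.  The function below is a minimal, hypothetical sketch of
# a feed_dict-based training loop for it; it assumes x_train and y_train are
# NumPy arrays already shaped [N, 512, 512, 1] and is not part of the original
# script.
def train_feed_dict():
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        num_batches = len(x_train) // BATCH_SIZE
        for epoch in range(NUM_EPOCHS):
            for i in range(num_batches):
                batch_x = x_train[i * BATCH_SIZE:(i + 1) * BATCH_SIZE]
                batch_y = y_train[i * BATCH_SIZE:(i + 1) * BATCH_SIZE]
                # One optimizer step on the current batch.
                _, loss_value = sess.run(
                    [optimizer, loss],
                    feed_dict={X: batch_x, Y: batch_y})
            print('Epoch %d: loss = %.4f' % (epoch, loss_value))
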
def train():
    """
    Train unet using the specified args.
    """

    data_files, data_size = load_datafiles(FLAGS.tfrecords_prefix)
    images, labels, filenames = dataset_loader.inputs(
        data_files=data_files,
        image_size=FLAGS.image_size,
        batch_size=FLAGS.batch_size,
        num_epochs=FLAGS.num_epochs,
        train=True)

    logits = unet.build(images, FLAGS.num_classes, True)

    accuracy = unet.accuracy(logits, labels)

    #load class weights if available
    if FLAGS.class_weights is not None:
        weights = np.load(FLAGS.class_weights)
        class_weight_tensor = tf.constant(weights, dtype=tf.float32, shape=[FLAGS.num_classes, 1])
    else:
        class_weight_tensor = None

    loss = unet.loss(logits, labels, FLAGS.weight_decay_rate, class_weight_tensor)

    global_step = tf.Variable(0, name='global_step', trainable=False)
    train_op = unet.train(loss, FLAGS.learning_rate,
                          FLAGS.learning_rate_decay_steps,
                          FLAGS.learning_rate_decay_rate, global_step)

    init_op = tf.group(tf.global_variables_initializer(),
                       tf.local_variables_initializer())

    saver = tf.train.Saver()
    session_manager = tf.train.SessionManager(
        local_init_op=tf.local_variables_initializer())
    sess = session_manager.prepare_session(
        "", init_op=init_op, saver=saver, checkpoint_dir=FLAGS.checkpoint_dir)

    writer = tf.summary.FileWriter(FLAGS.checkpoint_dir + "/train_logs", sess.graph)

    merged = tf.summary.merge_all()

    coord = tf.train.Coordinator()

    threads = tf.train.start_queue_runners(sess = sess, coord = coord)

    start_time = time.time()

    try:
        while not coord.should_stop():
            step = tf.train.global_step(sess, global_step)
            
            _, loss_value, summary = sess.run([train_op, loss, merged])
            writer.add_summary(summary, step)

            if step % 1000 == 0:
                acc_seg_value = sess.run(accuracy)

                epoch = step * FLAGS.batch_size / data_size
                duration = time.time() - start_time
                start_time = time.time()

                print('[PROGRESS]\tEpoch %d, Step %d: loss = %.2f, accuracy = %.2f (%.3f sec)'
                      % (epoch, step, loss_value, acc_seg_value, duration))

            if step % 5000 == 0:
                print('[PROGRESS]\tSaving checkpoint')
                checkpoint_path = os.path.join(FLAGS.checkpoint_dir, 'unet.ckpt')
                saver.save(sess, checkpoint_path, global_step=step)

    except tf.errors.OutOfRangeError:
        print('[INFO    ]\tDone training for %d epochs, %d steps.' % (FLAGS.num_epochs, step))

    finally:
        # When done, ask the threads to stop.
        coord.request_stop()

    # Wait for threads to finish.
    coord.join(threads)
    writer.close()
    sess.close()
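
# train() above relies on FLAGS and on a load_datafiles() helper that this
# snippet never defines.  The block below is only a sketch of how those flags
# could be declared with tf.app.flags; every default value is an assumption,
# not something taken from the original example, and load_datafiles() is
# assumed to be provided elsewhere in the project.
FLAGS = tf.app.flags.FLAGS
tf.app.flags.DEFINE_string('tfrecords_prefix', 'train', 'Prefix of the TFRecords files.')
tf.app.flags.DEFINE_integer('image_size', 512, 'Height/width of the input images.')
tf.app.flags.DEFINE_integer('batch_size', BATCH_SIZE, 'Training batch size.')
tf.app.flags.DEFINE_integer('num_epochs', NUM_EPOCHS, 'Number of training epochs.')
tf.app.flags.DEFINE_integer('num_classes', 2, 'Number of segmentation classes.')
tf.app.flags.DEFINE_string('class_weights', None, 'Optional .npy file with per-class weights.')
tf.app.flags.DEFINE_float('weight_decay_rate', 0.0005, 'L2 weight decay rate.')
tf.app.flags.DEFINE_float('learning_rate', 1e-4, 'Initial learning rate.')
tf.app.flags.DEFINE_integer('learning_rate_decay_steps', 10000, 'Steps between learning-rate decays.')
tf.app.flags.DEFINE_float('learning_rate_decay_rate', 0.9, 'Multiplicative learning-rate decay factor.')
tf.app.flags.DEFINE_string('checkpoint_dir', './checkpoints', 'Directory for checkpoints and summaries.')


def main(_):
    train()


if __name__ == '__main__':
    tf.app.run()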
Example #3
import os
import time

import numpy as np
import setproctitle
import tensorflow as tf

import config               # project-local configuration module
import unet                 # project-local U-Net model definition
import data_pipeline as dp  # assumed name of the project's data-pipeline module

# FLAGS and load_datafiles() are assumed to be defined elsewhere in the project.


def train():
    """
    Train unet using the specified args.
    """

    data_files, data_size = load_datafiles(FLAGS.tfrecords_prefix)
    print(data_files, data_size)
    #  images, labels, filenames = dataset_loader.inputs(
    #                                 data_files = data_files,
    #                                 image_size = FLAGS.image_size,
    #                             batch_size = FLAGS.batch_size,
    #                                 num_epochs = FLAGS.num_epochs,
    #                                 train = True)
    setproctitle.setproctitle('quakenet')

    tf.set_random_seed(1234)

    cfg = config.Config()
    cfg.batch_size = FLAGS.batch_size
    cfg.add = 1
    cfg.n_clusters = FLAGS.num_classes
    cfg.n_clusters += 1

    # data pipeline for positive and negative examples
    pos_pipeline = dp.DataPipeline(FLAGS.tfrecords_dir, cfg, True)
    #  images:[batch_size, n_channels, n_points]
    images = pos_pipeline.samples
    labels = pos_pipeline.labels
    logits = unet.build_30s(images, FLAGS.num_classes, True)
    accuracy = unet.accuracy(logits, labels)
    print("accuracy, recall, f1:", accuracy)
    #load class weights if available
    if FLAGS.class_weights is not None:
        weights = np.load(FLAGS.class_weights)
        class_weight_tensor = tf.constant(weights,
                                          dtype=tf.float32,
                                          shape=[FLAGS.num_classes, 1])
    else:
        class_weight_tensor = None
    loss = unet.loss(logits, labels, FLAGS.weight_decay_rate)
    global_step = tf.Variable(0, name='global_step', trainable=False)
    train_op = unet.train(loss, FLAGS.learning_rate,
                          FLAGS.learning_rate_decay_steps,
                          FLAGS.learning_rate_decay_rate, global_step)
    #print "train_op",train_op

    init_op = tf.group(tf.global_variables_initializer(),
                       tf.local_variables_initializer())

    saver = tf.train.Saver()
    session_manager = tf.train.SessionManager(
        local_init_op=tf.local_variables_initializer())
    sess = session_manager.prepare_session("",
                                           init_op=init_op,
                                           saver=saver,
                                           checkpoint_dir=FLAGS.checkpoint_dir)

    writer = tf.summary.FileWriter(FLAGS.checkpoint_dir + "/train_logs",
                                   sess.graph)

    merged = tf.summary.merge_all()

    coord = tf.train.Coordinator()

    threads = tf.train.start_queue_runners(sess=sess, coord=coord)

    start_time = time.time()

    try:
        while not coord.should_stop():

            step = tf.train.global_step(sess, global_step)
            _, loss_value, summary = sess.run([train_op, loss, merged])
            #print loss_value
            writer.add_summary(summary, step)
            if step % 1000 == 0:
                acc_seg_value = sess.run([accuracy])
                #print "acc_seg_value:",acc_seg_value,acc_seg_value[0],acc_seg_value[0][1],acc_seg_value[0][1][0]
                epoch = step * FLAGS.batch_size / data_size
                #print epoch
                duration = time.time() - start_time
                #print step,duration
                start_time = time.time()
                #print('[PROGRESS]\tEpoch %d | Step %d | loss = %.2f | total. acc. = %.2f | P. acc. =  %.3f \
                #      | S. acc. =  %.3f | N. acc. =  %.3f | dur. = (%.3f sec)'\
                #      % (epoch, step, loss_value, acc_seg_value[0][1][0],acc_seg_value[0][1][1], acc_seg_value[0][1][2],\
                #         acc_seg_value[0][3],duration))

                print('[PROGRESS]\tEpoch %d | Step %d | loss = %.2f | P. acc. = %.3f '
                      '| S. acc. = %.3f | N. acc. = %.3f | dur. = (%.3f sec)'
                      % (epoch, step, loss_value, acc_seg_value[0][1][1],
                         acc_seg_value[0][1][2], acc_seg_value[0][1][0], duration))
            if step % 5000 == 0:
                print('[PROGRESS]\tSaving checkpoint')
                checkpoint_path = os.path.join(FLAGS.checkpoint_dir,
                                               'unet.ckpt')
                saver.save(sess, checkpoint_path, global_step=step)

    except tf.errors.OutOfRangeError:
        print('[INFO    ]\tDone training for %d epochs, %d steps.' %
              (FLAGS.num_epochs, step))

    finally:
        # When done, ask the threads to stop.
        coord.request_stop()

    # Wait for threads to finish.
    coord.join(threads)
    writer.close()
    sess.close()
Example #4
import fnmatch
import os
import time

import setproctitle
import tensorflow as tf

import config               # project-local configuration module
import unet                 # project-local U-Net model definition
import data_pipeline as dp  # assumed name of the project's data-pipeline module


def evaluate():
    """
    Eval unet using the specified args.
    """
    if FLAGS.events:
        summary_dir = os.path.join(FLAGS.checkpoint_path, "events")
    else:
        # Fall back to the checkpoint directory so the FileWriter below always
        # has a valid log directory (the original left summary_dir undefined
        # in this branch).
        summary_dir = FLAGS.checkpoint_path
    while True:
        ckpt = tf.train.get_checkpoint_state(FLAGS.checkpoint_path)
        if FLAGS.eval_interval < 0 or ckpt:
            print('Evaluating model')
            break
        print('Waiting for training job to save a checkpoint')
        time.sleep(FLAGS.eval_interval)

    #data_files, data_size = load_datafiles(FLAGS.tfrecords_prefix)

    setproctitle.setproctitle('quakenet')

    tf.set_random_seed(1234)

    cfg = config.Config()
    cfg.batch_size = FLAGS.batch_size
    cfg.add = 1
    cfg.n_clusters = FLAGS.num_classes
    cfg.n_clusters += 1
    cfg.n_epochs = 1
    model_files = [
        file for file in os.listdir(FLAGS.checkpoint_path)
        if fnmatch.fnmatch(file, '*.meta')
    ]
    for model_file in sorted(model_files):
        step = int(model_file.split(".meta")[0].split("-")[1])
        print(step)
        try:
            model_file = os.path.join(FLAGS.checkpoint_path, model_file)
            # data pipeline for positive and negative examples
            pos_pipeline = dp.DataPipeline(FLAGS.tfrecords_dir, cfg, True)
            #  images:[batch_size, n_channels, n_points]
            images = pos_pipeline.samples
            labels = pos_pipeline.labels
            logits = unet.build_30s(images, FLAGS.num_classes, False)

            predicted_images = unet.predict(logits, FLAGS.batch_size,
                                            FLAGS.image_size)

            accuracy = unet.accuracy(logits, labels)
            loss = unet.loss(logits, labels, FLAGS.weight_decay_rate)
            summary_writer = tf.summary.FileWriter(summary_dir, None)

            init_op = tf.group(tf.global_variables_initializer(),
                               tf.local_variables_initializer())

            sess = tf.Session()

            sess.run(init_op)

            saver = tf.train.Saver()

            #if not tf.gfile.Exists(FLAGS.checkpoint_path + '.meta'):
            if not tf.gfile.Exists(model_file):
                raise ValueError("Can't find checkpoint file")
            else:
                print('[INFO    ]\tFound checkpoint file, restoring model.')
                saver.restore(sess, model_file.split(".meta")[0])

            coord = tf.train.Coordinator()

            threads = tf.train.start_queue_runners(sess=sess, coord=coord)
            #metrics = validation_metrics()
            global_accuracy = 0.0
            global_p_accuracy = 0.0
            global_s_accuracy = 0.0
            global_n_accuracy = 0.0
            global_loss = 0.0

            n = 0
            #mean_metrics = {}
            #for key in metrics:
            #    mean_metrics[key] = 0
            #pred_labels = np.empty(1)
            #true_labels = np.empty(1)

            try:
                while not coord.should_stop():
                    acc_seg_value, loss_value, predicted_images_value, images_value = sess.run(
                        [accuracy, loss, predicted_images, images])
                    accuracy_p_value = acc_seg_value[1][1]
                    accuracy_s_value = acc_seg_value[1][2]
                    accuracy_n_value = acc_seg_value[1][0]
                    #pred_labels = np.append(pred_labels, predicted_images_value)
                    #true_labels = np.append(true_labels, images_value)
                    # unet.accuracy() returns a nested structure (see the
                    # per-class indexing above); the overall accuracy is
                    # assumed to be its first element.
                    global_accuracy += acc_seg_value[0]
                    global_p_accuracy += accuracy_p_value
                    global_s_accuracy += accuracy_s_value
                    global_n_accuracy += accuracy_n_value
                    global_loss += loss_value
                    # print  true_labels
                    #for key in metrics:
                    #    mean_metrics[key] += cfg.batch_size * metrics_[key]
                    filenames_value = []
                    # for i in range(FLAGS.batch_size):
                    #     filenames_value.append(str(step)+"_"+str(i)+".png")
                    #print (predicted_images_value[:,100:200])
                    if FLAGS.plot:
                        maybe_save_images(predicted_images_value, images_value,
                                          filenames_value)
                    #s='loss = {:.5f} | det. acc. = {:.1f}% | loc. acc. = {:.1f}%'.format(metrics['loss']
                    print(
                        '[PROGRESS]\tAccuracy for current batch: |  P. acc. =%.5f| S. acc. =%.5f| '
                        'noise. acc. =%.5f.' %
                        (accuracy_p_value, accuracy_s_value, accuracy_n_value))
                    n += cfg.batch_size
                    #  step += 1
                    print(n)
            except KeyboardInterrupt:
                print('stopping evaluation')
            except tf.errors.OutOfRangeError:
                print('Evaluation completed ({} epochs).'.format(cfg.n_epochs))
                print("{} windows seen".format(n))
                #print('[INFO    ]\tDone evaluating in %d steps.' % step)
                if n > 0:
                    global_loss /= n
                    summary = tf.Summary(value=[
                        tf.Summary.Value(tag='loss/val',
                                         simple_value=global_loss)
                    ])
                    if FLAGS.save_summary:
                        summary_writer.add_summary(summary, global_step=step)
                    global_accuracy /= n
                    global_p_accuracy /= n
                    global_s_accuracy /= n
                    global_n_accuracy /= n
                    summary = tf.Summary(value=[
                        tf.Summary.Value(tag='accuracy/val',
                                         simple_value=global_accuracy)
                    ])
                    if FLAGS.save_summary:
                        summary_writer.add_summary(summary, global_step=step)
                    summary = tf.Summary(value=[
                        tf.Summary.Value(tag='accuracy/val_p',
                                         simple_value=global_p_accuracy)
                    ])
                    if FLAGS.save_summary:
                        summary_writer.add_summary(summary, global_step=step)
                    summary = tf.Summary(value=[
                        tf.Summary.Value(tag='accuracy/val_s',
                                         simple_value=global_s_accuracy)
                    ])
                    if FLAGS.save_summary:
                        summary_writer.add_summary(summary, global_step=step)
                    summary = tf.Summary(value=[
                        tf.Summary.Value(tag='accuracy/val_noise',
                                         simple_value=global_n_accuracy)
                    ])
                    if FLAGS.save_summary:
                        summary_writer.add_summary(summary, global_step=step)
                    print(
                        '[End of evaluation for current epoch]\n\nAccuracy for current epoch:%s | total. acc. =%.5f| P. acc. =%.5f| S. acc. =%.5f| '
                        'noise. acc. =%.5f.' %
                        (step, global_accuracy, global_p_accuracy,
                         global_s_accuracy, global_n_accuracy))
                    print('Sleeping for {}s'.format(FLAGS.eval_interval))
                    time.sleep(FLAGS.eval_interval)
                summary_writer.flush()
            finally:
                # When done, ask the threads to stop.
                coord.request_stop()
            tf.reset_default_graph()
            #print('Sleeping for {}s'.format(FLAGS.eval_interval))
            #time.sleep(FLAGS.eval_interval)
        finally:
            print('joining data threads')

            coord = tf.train.Coordinator()
            coord.request_stop()

    #pred_labels = pred_labels[1::]
    #true_labels = true_labels[1::]
    #print  ("---Confusion Matrix----")
    #print (confusion_matrix(true_labels, pred_labels))
    # Wait for threads to finish.
    coord.join(threads)
    sess.close()
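
# evaluate() above also reads its configuration from FLAGS.  A sketch of the
# evaluation-specific flags, declared with tf.app.flags, is shown below; the
# default values are assumptions, and the model/data flags (batch_size,
# num_classes, image_size, weight_decay_rate, tfrecords_dir, ...) are assumed
# to be declared the same way as in the training script.
FLAGS = tf.app.flags.FLAGS
tf.app.flags.DEFINE_string('checkpoint_path', './checkpoints', 'Directory containing the checkpoints to evaluate.')
tf.app.flags.DEFINE_integer('eval_interval', 60, 'Seconds to wait between evaluation passes (< 0 evaluates once).')
tf.app.flags.DEFINE_boolean('events', True, 'Write evaluation summaries to an "events" sub-directory.')
tf.app.flags.DEFINE_boolean('save_summary', True, 'Whether to write TensorBoard summaries.')
tf.app.flags.DEFINE_boolean('plot', False, 'Whether to save predicted images with maybe_save_images().')


def main(_):
    evaluate()


if __name__ == '__main__':
    tf.app.run()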