Code example #1

# This script uses the TensorFlow 1.x API.
import os
import time

import numpy as np
import pandas as pd
import tensorflow as tf

# Project-specific helpers; the module paths below are assumptions and should
# be adjusted to match the repository layout.
# from image_reader import ImageReader
# from unet import make_unet, make_train_op
# from seg_metric import SegMetric
# from image_utils import IOU_, myIOU, image_summary

def main(flags):
    IMG_MEAN = np.zeros(3)
    image_std = [1.0, 1.0, 1.0]  # currently unused; kept for the commented-out per-city std option below
    # parameters of building data set
    citylist = [
        'Norfolk', 'Arlington', 'Atlanta', 'Austin', 'Seekonk', 'NewHaven'
    ]
    image_mean_list = {
        'Norfolk': [127.07435926, 129.40160709, 128.28713284],
        'Arlington': [88.30304996, 94.97338776, 93.21268212],
        'Atlanta': [101.997014375, 108.42171833, 110.044871],
        'Austin': [97.0896012682, 102.94697026, 100.7540157],
        'Seekonk': [86.67800904, 93.31221168, 92.1328146],
        'NewHaven': [106.7092798, 111.4314, 110.74903832]
    }  # BGR mean for the training data for each city

    # set training data
    if flags.training_data == 'SP':
        IMG_MEAN = np.array(
            (121.68045527, 132.14961763, 129.30317439),
            dtype=np.float32)  # mean of solar panel data in BGR order

    elif flags.training_data in citylist:
        print("Training on {} data".format(flags.training_data))
        IMG_MEAN = np.array(image_mean_list[flags.training_data],
                            dtype=np.float32)
        # if flags.unit_std:
        #     image_std = image_std_list[flags.training_data]
    elif 'all_but' in flags.training_data:
        except_city_name = flags.training_data.split('_')[2]
        print("Training on all cities (excluding Seekonk) except {}".format(
            except_city_name))
        for cityname in citylist:
            if cityname != except_city_name and cityname != 'Seekonk':
                IMG_MEAN = IMG_MEAN + np.array(image_mean_list[cityname])
        IMG_MEAN = IMG_MEAN / 4  # average over the 4 remaining cities

    elif flags.training_data == 'all':
        print("Training on data of all cities (excludes Seekonk)")
        for cityname in citylist:
            if cityname != 'Seekonk':
                IMG_MEAN = IMG_MEAN + np.array(image_mean_list[cityname])
        IMG_MEAN = IMG_MEAN / 5  # average over the 5 cities other than Seekonk
    else:
        print("Wrong data option: {}".format(flags.training_data))

    # setup used GPU
    os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
    os.environ["CUDA_VISIBLE_DEVICES"] = flags.GPU

    # presetting
    input_size = (128, 128)
    tf.set_random_seed(1234)
    coord = tf.train.Coordinator()
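    # Build the input pipelines: geometric augmentation (random scale, mirror,
    # rotate) is enabled only for the training reader.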
    with tf.name_scope("training_inputs"):
        training_reader = ImageReader(flags.training_data_list,
                                      input_size,
                                      random_scale=True,
                                      random_mirror=True,
                                      random_rotate=True,
                                      ignore_label=255,
                                      img_mean=IMG_MEAN,
                                      coord=coord)
    with tf.name_scope("validation_inputs"):
        validation_reader = ImageReader(
            flags.validation_data_list,
            input_size,
            random_scale=False,
            random_mirror=False,
            random_rotate=False,
            ignore_label=255,
            img_mean=IMG_MEAN,
            coord=coord,
        )
    X_batch_op, y_batch_op = training_reader.shuffle_dequeue(flags.batch_size)
    X_test_op, y_test_op = validation_reader.shuffle_dequeue(flags.batch_size *
                                                             2)

    train = pd.read_csv(flags.training_data_list, header=0)
    n_train = train.shape[0] + 1  # +1 presumably restores the row consumed as the header

    test = pd.read_csv(flags.validation_data_list, header=0)
    n_test = test.shape[0] + 1

    current_time = time.strftime("%m_%d/%H_%M")

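    # Network inputs: RGB patches, binary masks, and a bool `mode` flag that is
    # fed True for training passes and False for evaluation passes.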
    X = tf.placeholder(tf.float32, shape=[None, 128, 128, 3], name="X")
    y = tf.placeholder(tf.float32, shape=[None, 128, 128, 1], name="y")
    mode = tf.placeholder(tf.bool, name="mode")

    pred_raw = make_unet(X, mode)
    pred = tf.nn.sigmoid(pred_raw)
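    # Register inputs/outputs in collections so the trained graph can be
    # reloaded for inference without rebuilding the network.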
    tf.add_to_collection("inputs", X)
    tf.add_to_collection("inputs", mode)
    tf.add_to_collection("outputs", pred)

    tf.summary.histogram("Predicted Mask", pred)
    # tf.summary.image("Predicted Mask", pred)

    global_step = tf.Variable(0,
                              dtype=tf.int64,
                              trainable=False,
                              name='global_step')

    # Batch-norm moving-average updates; attached to the train op below via
    # control_dependencies.
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)

    # Decay the learning rate by `decay_rate` every `decay_step` epochs,
    # converted here to a global-step count.
    learning_rate = tf.train.exponential_decay(
        flags.learning_rate,
        global_step,
        tf.cast(n_train / flags.batch_size * flags.decay_step, tf.int32),
        flags.decay_rate,
        staircase=True)

    IOU_op = IOU_(pred, y)
    # sigmoid_cross_entropy_with_logits expects raw logits, so pass pred_raw
    # rather than the sigmoid-activated pred.
    cross_entropy = tf.reduce_mean(
        tf.nn.sigmoid_cross_entropy_with_logits(labels=y, logits=pred_raw))
    tf.summary.scalar("loss/IOU_training", IOU_op)
    tf.summary.scalar("loss/cross_entropy_training", cross_entropy)

    learning_rate_summary = tf.summary.scalar(
        "learning_rate", learning_rate)  # summary recording learning rate

    # Train on cross entropy, or on negative IoU to maximize IoU directly.
    if flags.is_loss_entropy:
        loss = cross_entropy
    else:
        loss = -IOU_op

    with tf.control_dependencies(update_ops):
        train_op = make_train_op(loss, global_step, learning_rate)

    summary_op = tf.summary.merge_all()

    # Validation metrics are computed in Python and fed back through
    # placeholders so they land in the same TensorBoard run.
    valid_IoU = tf.placeholder(tf.float32, [])
    valid_IoU_summary_op = tf.summary.scalar("loss/IoU_validation", valid_IoU)
    valid_cross_entropy = tf.placeholder(tf.float32, [])
    valid_cross_entropy_summary_op = tf.summary.scalar(
        "loss/cross_entropy_validation", valid_cross_entropy)

    # Image summaries: each row is an (input | label | prediction) panel,
    # hence the width of 128 * 3.
    train_images = tf.placeholder(tf.uint8,
                                  shape=[None, 128, 128 * 3, 3],
                                  name="training_images")
    train_image_summary_op = tf.summary.image("Training_images_summary",
                                              train_images,
                                              max_outputs=10)
    valid_images = tf.placeholder(tf.uint8,
                                  shape=[None, 128, 128 * 3, 3],
                                  name="validation_images")
    valid_image_summary_op = tf.summary.image("Validation_images_summary",
                                              valid_images,
                                              max_outputs=10)

    # Set up TF session and initialize variables.
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True

    with tf.Session(config=config) as sess:

        init = tf.global_variables_initializer()
        sess.run(init)

        saver = tf.train.Saver(var_list=tf.global_variables(), max_to_keep=1)

        # Resume from the latest checkpoint if one exists.
        if os.path.exists(flags.ckdir) and tf.train.get_checkpoint_state(
                flags.ckdir):
            latest_check_point = tf.train.latest_checkpoint(flags.ckdir)
            saver.restore(sess, latest_check_point)

        try:
            train_summary_writer = tf.summary.FileWriter(
                flags.ckdir, sess.graph)

            # Start the ImageReader queue threads that feed the input pipeline.
            threads = tf.train.start_queue_runners(coord=coord, sess=sess)

            for epoch in range(flags.epochs):

                for step in range(0, n_train, flags.batch_size):

                    start_time = time.time()
                    X_batch, y_batch = sess.run([X_batch_op, y_batch_op])

                    _, global_step_value = sess.run([train_op, global_step],
                                                    feed_dict={
                                                        X: X_batch,
                                                        y: y_batch,
                                                        mode: True
                                                    })
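                    # Every 100 steps: re-run the current batch in inference
                    # mode to log training loss/IoU and merged summaries.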
                    if global_step_value % 100 == 0:
                        duration = time.time() - start_time
                        pred_train, step_iou, step_cross_entropy, step_summary = sess.run(
                            [pred, IOU_op, cross_entropy, summary_op],
                            feed_dict={
                                X: X_batch,
                                y: y_batch,
                                mode: False
                            })
                        train_summary_writer.add_summary(
                            step_summary, global_step_value)

                        print(
                            'Epoch {:d} step {:d} \t cross entropy = {:.3f}, IOU = {:.3f} ({:.3f} sec/step)'
                            .format(epoch, global_step_value,
                                    step_cross_entropy, step_iou, duration))

                    # Run validation every 1000 global steps.
                    if global_step_value % 1000 == 0:
                        segmetric = SegMetric(1)
                        # Validate on a single (double-sized) batch rather than
                        # the full validation set.
                        X_test, y_test = sess.run([X_test_op, y_test_op])
                        pred_valid, valid_cross_entropy_value = sess.run(
                            [pred, cross_entropy],
                            feed_dict={
                                X: X_test,
                                y: y_test,
                                mode: False
                            })
                        iou_temp = myIOU(y_pred=pred_valid > 0.5,
                                         y_true=y_test,
                                         segmetric=segmetric)
                        print("Test IoU: {}  Cross_Entropy: {}".format(
                            segmetric.mean_IU(), valid_cross_entropy_value))

                        valid_IoU_summary = sess.run(
                            valid_IoU_summary_op,
                            feed_dict={valid_IoU: iou_temp})
                        train_summary_writer.add_summary(
                            valid_IoU_summary, global_step_value)
                        valid_cross_entropy_summary = sess.run(
                            valid_cross_entropy_summary_op,
                            feed_dict={
                                valid_cross_entropy: valid_cross_entropy_value
                            })
                        train_summary_writer.add_summary(
                            valid_cross_entropy_summary, global_step_value)

                        train_image_summary = sess.run(
                            train_image_summary_op,
                            feed_dict={
                                train_images:
                                image_summary(X_batch,
                                              y_batch,
                                              pred_train > 0.5,
                                              IMG_MEAN,
                                              num_classes=flags.num_classes)
                            })
                        train_summary_writer.add_summary(
                            train_image_summary, global_step_value)
                        valid_image_summary = sess.run(
                            valid_image_summary_op,
                            feed_dict={
                                valid_images:
                                image_summary(X_test,
                                              y_test,
                                              pred_valid > 0.5,
                                              IMG_MEAN,
                                              num_classes=flags.num_classes)
                            })
                        train_summary_writer.add_summary(
                            valid_image_summary, global_step_value)

                # Checkpoint at the end of every epoch.
                saver.save(sess,
                           "{}/model.ckpt".format(flags.ckdir),
                           global_step=global_step)

        finally:
            coord.request_stop()
            coord.join(threads)
            # Save a final checkpoint even if training is interrupted.
            saver.save(sess,
                       "{}/model.ckpt".format(flags.ckdir),
                       global_step=global_step)