Beispiel #1
0
def main(argv=None):
    """Build the multi-scale attention graph and run the mode in ``FLAGS.mode``.

    ``unetinference`` is applied to the input image at scales 1.0, 0.75, 0.5
    and 1.25; the first call creates the variables and the other three reuse
    them.  The per-scale ``att*`` outputs are concatenated and passed to
    ``attention``; a softmax over its output gives one weight map per scale,
    and the per-scale logits are multiplied by their weight map and summed.

    * ``FLAGS.mode == "train"``: scales 1.0 / 0.75 / 0.5 are fused, the
      per-scale and fused cross-entropy losses are summed and a training op
      is built, then ``fd.mode_train`` is called.
    * any other mode: scales 1.0 / 0.75 / 1.25 are fused for prediction;
      ``"visualize"`` calls ``fd.mode_visualize_attention`` and ``"test"``
      calls ``fd.mode_new_test``.  Other mode strings only build the graph
      and restore parameters.

    Parameters are restored from the latest checkpoint in ``FLAGS.logs_dir``
    if there is one, otherwise from ``FLAGS.model_dir``; if neither directory
    holds a checkpoint the freshly initialised values are kept.

    Args:
        argv: Unused by this function.
    """
    # 1. input placeholders
    keep_probability = tf.placeholder(tf.float32, name="keep_probabilty")
    image = tf.placeholder(tf.float32,
                           shape=(None, IMAGE_SIZE, IMAGE_SIZE, 3),
                           name="input_image")
    annotation = tf.placeholder(tf.int32,
                                shape=(None, IMAGE_SIZE, IMAGE_SIZE, 1),
                                name="annotation")

    def scaled_size(factor):
        """Return the [height, width] of the input rescaled by ``factor``."""
        side = int(IMAGE_SIZE * factor)
        return [side, side]

    def resize_labels(factor):
        """Rescale ``annotation`` by ``factor`` without inventing class ids."""
        # Nearest-neighbour keeps every output pixel equal to an existing
        # label.  The default bilinear filter would blend neighbouring class
        # ids into values that belong to unrelated classes.
        return tf.cast(
            tf.image.resize_images(
                annotation,
                scaled_size(factor),
                method=tf.image.ResizeMethod.NEAREST_NEIGHBOR), tf.int32)

    image075 = tf.image.resize_images(image, scaled_size(0.75))
    image050 = tf.image.resize_images(image, scaled_size(0.5))
    image125 = tf.image.resize_images(image, scaled_size(1.25))

    annotation075 = resize_labels(0.75)
    annotation050 = resize_labels(0.5)

    # 2. construct inference network: the first call creates the variables,
    # the remaining scales reuse them.
    with tf.variable_scope('', reuse=False):
        pred_annotation100, logits100, net100, att100 = unetinference(
            image, keep_probability)
    with tf.variable_scope('', reuse=True):
        pred_annotation075, logits075, _, att075 = unetinference(
            image075, keep_probability)
    with tf.variable_scope('', reuse=True):
        _, logits050, _, att050 = unetinference(image050, keep_probability)
    with tf.variable_scope('', reuse=True):
        pred_annotation125, logits125, _, att125 = unetinference(
            image125, keep_probability)

    def fuse_with_attention(att_maps, scale_logits, is_training):
        """Weight each scale's logits by its softmax attention map.

        The first entry of both lists is the reference scale; the other
        entries are resized to its spatial size.  Returns the softmax
        attention weights and the list of weighted per-scale logits.
        """
        att_hw = tf.shape(att_maps[0])[1:3]
        attn_input = tf.concat(
            [att_maps[0]] +
            [tf.image.resize_images(att, att_hw) for att in att_maps[1:]],
            axis=3)

        with tf.variable_scope('attention'):
            weights = tf.nn.softmax(
                attention(attn_input, is_training=is_training))

        logits_hw = tf.shape(scale_logits[0])[1:3]
        weighted = []
        for idx, current in enumerate(scale_logits):
            if idx > 0:
                current = tf.image.resize_images(current, logits_hw)
            # Channel ``idx`` of the softmax output is the weight of scale idx.
            weight_map = tf.image.resize_images(
                tf.expand_dims(weights[:, :, :, idx], axis=3), logits_hw)
            weighted.append(tf.multiply(current, weight_map))
        return weights, weighted

    def mean_cross_entropy(scale_logits, labels):
        """Return the mean sparse softmax cross-entropy and log it."""
        value = tf.reduce_mean(
            tf.nn.sparse_softmax_cross_entropy_with_logits(
                logits=scale_logits,
                labels=tf.squeeze(labels, axis=[3]),
                name="entropy"))
        tf.summary.scalar("entropy", value)
        return value

    # Attention model
    msc_trainable_var = None

    if FLAGS.mode == "train":
        # apply attention model - train (scales 1.0, 0.75, 0.5)
        attention_scales_weights, scores = fuse_with_attention(
            [att100, att075, att050], [logits100, logits075, logits050],
            is_training=True)
        logits_train_attention = scores[0] + scores[1] + scores[2]

        pred_annotation_train_attention = tf.expand_dims(
            tf.argmax(logits_train_attention, axis=3,
                      name="final_prediction"),
            axis=3)

        # 3. loss measure: one loss per scale plus one for the fused logits
        loss100 = mean_cross_entropy(logits100, annotation)
        loss075 = mean_cross_entropy(logits075, annotation075)
        loss050 = mean_cross_entropy(logits050, annotation050)
        with_attention_loss = mean_cross_entropy(logits_train_attention,
                                                 annotation)

        attention_combined_loss = (with_attention_loss + loss100 + loss075 +
                                   loss050)

        # 4. optimizing
        # Only the 'inference' variables go into the multi-scale saver, but
        # the optimizer updates every trainable variable.
        msc_trainable_var = tf.trainable_variables('inference')
        attention_trainable_var = tf.trainable_variables()
        if FLAGS.debug:
            for var in msc_trainable_var:
                Utils.add_to_regularization_and_summary(var)

        attention_train_op = train(attention_combined_loss,
                                   attention_trainable_var,
                                   net100['global_step'])

    else:
        # apply attention model - test (scales 1.0, 0.75, 1.25)
        attention_scales_weights, scores = fuse_with_attention(
            [att100, att075, att125], [logits100, logits075, logits125],
            is_training=False)
        score_att_x_100, score_att_x_075, score_att_x_125 = scores

        logits_test_attention = (score_att_x_100 + score_att_x_075 +
                                 score_att_x_125)

        pred_annotation_test_attention = tf.expand_dims(
            tf.argmax(logits_test_attention, axis=3,
                      name="final_prediction"),
            axis=3)

    tf.summary.image("input_image", image, max_outputs=3)
    tf.summary.image("ground_truth",
                     tf.cast(annotation, tf.uint8),
                     max_outputs=3)
    tf.summary.image("pred_annotation",
                     tf.cast(pred_annotation100, tf.uint8),
                     max_outputs=3)

    print("Setting up summary op...")
    summary_op = tf.summary.merge_all()

    print("Setting up image reader from ", FLAGS.data_dir, "...")
    print("data dir:", FLAGS.data_dir)

    train_records, valid_records = fashion_parsing.read_dataset(FLAGS.data_dir)
    test_records = None
    if DATA_SET == "CFPD":
        train_records, valid_records, test_records = ClothingParsing.read_dataset(
            FLAGS.data_dir)
        print("test_records length :", len(test_records))
    if DATA_SET == "LIP":
        train_records, valid_records = HumanParsing.read_dataset(
            FLAGS.data_dir)

    print("train_records length :", len(train_records))
    print("valid_records length :", len(valid_records))

    print("Setting up dataset reader")
    train_dataset_reader = None
    validation_dataset_reader = None
    test_dataset_reader = None
    image_options = {'resize': True, 'resize_size': IMAGE_SIZE}

    if FLAGS.mode == 'train':
        train_dataset_reader = DataSetReader.BatchDatset(
            train_records, image_options)
        validation_dataset_reader = DataSetReader.BatchDatset(
            valid_records, image_options)
        if DATA_SET == "CFPD":
            test_dataset_reader = DataSetReader.BatchDatset(
                test_records, image_options)
    if FLAGS.mode == 'visualize':
        validation_dataset_reader = DataSetReader.BatchDatset(
            valid_records, image_options)
    if FLAGS.mode in ('test', 'crftest', 'predonly', 'fulltest'):
        if DATA_SET == "CFPD":
            test_dataset_reader = DataSetReader.BatchDatset(
                test_records, image_options)
        else:
            # Without a dedicated test split the validation split is used.
            test_dataset_reader = DataSetReader.BatchDatset(
                valid_records, image_options)
            test_records = valid_records

    sess = tf.Session()
    # try/finally so the session is released even when a mode handler raises.
    try:
        print("Setting up Saver...")
        saver_attn = tf.train.Saver()
        # NOTE(review): outside train mode msc_trainable_var is None, so this
        # saver covers every saveable variable, not only the 'inference' ones.
        saver_msc = tf.train.Saver(msc_trainable_var)
        summary_writer = tf.summary.FileWriter(FLAGS.logs_dir, sess.graph)

        # 5. parameter setup
        # 5.1 init params
        sess.run(tf.global_variables_initializer())

        # 5.2 restore params if possible
        ckpt_attn = tf.train.get_checkpoint_state(FLAGS.logs_dir)
        ckpt_msc = tf.train.get_checkpoint_state(FLAGS.model_dir)

        if ckpt_attn and ckpt_attn.model_checkpoint_path:
            saver_attn.restore(sess, ckpt_attn.model_checkpoint_path)
            print("Attention Model restored...")
        elif ckpt_msc and ckpt_msc.model_checkpoint_path:
            saver_msc.restore(sess, ckpt_msc.model_checkpoint_path)
            print("Multi-scales Model restored...")
        else:
            # get_checkpoint_state returns None when a directory holds no
            # checkpoint; keep the initialised values instead of crashing.
            print("No checkpoint found in", FLAGS.logs_dir, "or",
                  FLAGS.model_dir, "- using freshly initialized parameters")

        # 6. train-mode
        if FLAGS.mode == "train":
            fd.mode_train(sess, FLAGS, net100, train_dataset_reader,
                          validation_dataset_reader, train_records,
                          pred_annotation_train_attention, image, annotation,
                          keep_probability, logits_train_attention,
                          attention_train_op, attention_combined_loss,
                          summary_op, summary_writer, saver_attn,
                          DISPLAY_STEP)

        # test-random-validation-data mode
        elif FLAGS.mode == "visualize":
            fd.mode_visualize_attention(
                sess, FLAGS, VIS_DIR, validation_dataset_reader,
                pred_annotation100, pred_annotation075, pred_annotation125,
                score_att_x_100, score_att_x_075, score_att_x_125,
                attention_scales_weights, pred_annotation_test_attention,
                image, annotation, keep_probability, NUM_OF_CLASSES)

        # test-full-validation-dataset mode
        elif FLAGS.mode == "test":
            fd.mode_new_test(sess, FLAGS, TEST_DIR, test_dataset_reader,
                             test_records, pred_annotation_test_attention,
                             image, annotation, keep_probability,
                             logits_test_attention, NUM_OF_CLASSES)
    finally:
        sess.close()
Beispiel #2
0
def main(argv=None):
    """Build the multi-scale averaging graph and run the mode in ``FLAGS.mode``.

    ``unetinference`` is applied to the input image at scales 1.0, 0.75, 0.5
    and 1.25; the first call creates the variables and the other three reuse
    them.  Two fused outputs are built by averaging logits that have been
    resized to the full-scale resolution:

    * ``logits`` / ``pred_annotation``: scales 1.0, 0.75 and 0.5, used for
      training;
    * ``logits_pred`` / ``pred_annotation_pred``: scales 1.0, 0.75 and 1.25,
      used by the visualize and test modes.

    The training loss is the sum of the cross-entropy of the fused logits
    and of each of the 1.0 / 0.75 / 0.5 scales.  Parameters are restored from
    the latest checkpoint in ``FLAGS.logs_dir`` when one exists.

    Args:
        argv: Unused by this function.
    """
    # 1. input placeholders
    keep_probability = tf.placeholder(tf.float32, name="keep_probabilty")
    image = tf.placeholder(tf.float32,
                           shape=(None, IMAGE_SIZE, IMAGE_SIZE, 3),
                           name="input_image")
    annotation = tf.placeholder(tf.int32,
                                shape=(None, IMAGE_SIZE, IMAGE_SIZE, 1),
                                name="annotation")

    def scaled_size(factor):
        """Return the [height, width] of the input rescaled by ``factor``."""
        side = int(IMAGE_SIZE * factor)
        return [side, side]

    def resize_labels(factor):
        """Rescale ``annotation`` by ``factor`` without inventing class ids."""
        # Nearest-neighbour keeps every output pixel equal to an existing
        # label.  The default bilinear filter would blend neighbouring class
        # ids into values that belong to unrelated classes.
        return tf.cast(
            tf.image.resize_images(
                annotation,
                scaled_size(factor),
                method=tf.image.ResizeMethod.NEAREST_NEIGHBOR), tf.int32)

    image075 = tf.image.resize_images(image, scaled_size(0.75))
    image050 = tf.image.resize_images(image, scaled_size(0.5))
    image125 = tf.image.resize_images(image, scaled_size(1.25))

    annotation075 = resize_labels(0.75)
    annotation050 = resize_labels(0.5)

    # 2. construct inference network: the first call creates the variables,
    # the remaining scales reuse them.
    with tf.variable_scope('', reuse=False):
        pred_annotation100, logits100, net100 = unetinference(
            image, keep_probability)
    with tf.variable_scope('', reuse=True):
        _, logits075, _ = unetinference(image075, keep_probability)
    with tf.variable_scope('', reuse=True):
        _, logits050, _ = unetinference(image050, keep_probability)
    with tf.variable_scope('', reuse=True):
        _, logits125, _ = unetinference(image125, keep_probability)

    def average_logits(scale_logits):
        """Average logits after resizing them to the first entry's size."""
        full_hw = tf.shape(scale_logits[0])[1:3]
        resized = [scale_logits[0]] + [
            tf.image.resize_images(current, full_hw)
            for current in scale_logits[1:]
        ]
        return tf.reduce_mean(tf.stack(resized), axis=0)

    def predict_from(fused_logits):
        """Return the per-pixel argmax class, keeping a trailing channel axis."""
        return tf.expand_dims(tf.argmax(fused_logits, axis=3), axis=3)

    logits = average_logits([logits100, logits075, logits050])
    logits_pred = average_logits([logits100, logits075, logits125])

    # The fused predictions are taken from the fused logits.  Averaging the
    # per-scale argmax maps (as was done before) averages class *indices*,
    # which produces ids unrelated to any scale's actual prediction.
    pred_annotation = predict_from(logits)
    pred_annotation_pred = predict_from(logits_pred)

    tf.summary.image("input_image", image, max_outputs=3)
    tf.summary.image("ground_truth",
                     tf.cast(annotation, tf.uint8),
                     max_outputs=3)
    tf.summary.image("pred_annotation",
                     tf.cast(pred_annotation100, tf.uint8),
                     max_outputs=3)

    def mean_cross_entropy(scale_logits, labels):
        """Return the mean sparse softmax cross-entropy and log it."""
        value = tf.reduce_mean(
            tf.nn.sparse_softmax_cross_entropy_with_logits(
                logits=scale_logits,
                labels=tf.squeeze(labels, axis=[3]),
                name="entropy"))
        tf.summary.scalar("entropy", value)
        return value

    # 3. loss measure: fused logits plus each individual training scale
    loss = mean_cross_entropy(logits, annotation)
    loss100 = mean_cross_entropy(logits100, annotation)
    loss075 = mean_cross_entropy(logits075, annotation075)
    loss050 = mean_cross_entropy(logits050, annotation050)

    reduced_loss = loss + loss100 + loss075 + loss050

    # 4. optimizing
    trainable_var = tf.trainable_variables()
    if FLAGS.debug:
        for var in trainable_var:
            Utils.add_to_regularization_and_summary(var)

    train_op = train(reduced_loss, trainable_var, net100['global_step'])

    print("Setting up summary op...")
    summary_op = tf.summary.merge_all()

    print("Setting up image reader from ", FLAGS.data_dir, "...")
    print("data dir:", FLAGS.data_dir)

    train_records, valid_records = fashion_parsing.read_dataset(FLAGS.data_dir)
    test_records = None
    if DATA_SET == "CFPD":
        train_records, valid_records, test_records = ClothingParsing.read_dataset(
            FLAGS.data_dir)
        print("test_records length :", len(test_records))
    if DATA_SET == "LIP":
        train_records, valid_records = HumanParsing.read_dataset(
            FLAGS.data_dir)

    print("train_records length :", len(train_records))
    print("valid_records length :", len(valid_records))

    print("Setting up dataset reader")
    train_dataset_reader = None
    validation_dataset_reader = None
    test_dataset_reader = None
    image_options = {'resize': True, 'resize_size': IMAGE_SIZE}

    if FLAGS.mode == 'train':
        train_dataset_reader = DataSetReader.BatchDatset(
            train_records, image_options)
        validation_dataset_reader = DataSetReader.BatchDatset(
            valid_records, image_options)
        if DATA_SET == "CFPD":
            test_dataset_reader = DataSetReader.BatchDatset(
                test_records, image_options)
    if FLAGS.mode == 'visualize':
        validation_dataset_reader = DataSetReader.BatchDatset(
            valid_records, image_options)
    if FLAGS.mode in ('test', 'crftest', 'predonly', 'fulltest'):
        if DATA_SET == "CFPD":
            test_dataset_reader = DataSetReader.BatchDatset(
                test_records, image_options)
        else:
            # Without a dedicated test split the validation split is used.
            test_dataset_reader = DataSetReader.BatchDatset(
                valid_records, image_options)
            test_records = valid_records

    sess = tf.Session()
    # try/finally so the session is released even when a mode handler raises.
    try:
        print("Setting up Saver...")
        saver = tf.train.Saver()
        summary_writer = tf.summary.FileWriter(FLAGS.logs_dir, sess.graph)

        # 5. parameter setup
        # 5.1 init params
        sess.run(tf.global_variables_initializer())
        # 5.2 restore params if possible
        ckpt = tf.train.get_checkpoint_state(FLAGS.logs_dir)
        if ckpt and ckpt.model_checkpoint_path:
            saver.restore(sess, ckpt.model_checkpoint_path)
            print("Model restored...")

        # 6. train-mode
        if FLAGS.mode == "train":
            fd.mode_train(sess, FLAGS, net100, train_dataset_reader,
                          validation_dataset_reader, train_records,
                          pred_annotation, image, annotation,
                          keep_probability, logits, train_op, reduced_loss,
                          summary_op, summary_writer, saver, DISPLAY_STEP)

        # test-random-validation-data mode
        elif FLAGS.mode == "visualize":
            fd.mode_visualize(sess, FLAGS, VIS_DIR, validation_dataset_reader,
                              pred_annotation_pred, image, annotation,
                              keep_probability, NUM_OF_CLASSES)

        # test-full-validation-dataset mode
        elif FLAGS.mode == "test":
            fd.mode_new_test(sess, FLAGS, TEST_DIR, test_dataset_reader,
                             test_records, pred_annotation_pred, image,
                             annotation, keep_probability, logits_pred,
                             NUM_OF_CLASSES)

        # "crftest" and "predonly" both run the prediction-only handler.
        elif FLAGS.mode in ("crftest", "predonly"):
            fd.mode_predonly(sess, FLAGS, TEST_DIR, test_dataset_reader,
                             test_records, pred_annotation_pred, image,
                             annotation, keep_probability, logits_pred,
                             NUM_OF_CLASSES)

        elif FLAGS.mode == "fulltest":
            fd.mode_full_test(sess, FLAGS, TEST_DIR, test_dataset_reader,
                              test_records, pred_annotation_pred, image,
                              annotation, keep_probability, logits_pred,
                              NUM_OF_CLASSES)
    finally:
        sess.close()
Beispiel #3
0
def main(argv=None):
    """Build the single-scale ``unetinference`` graph and run ``FLAGS.mode``.

    Creates the input placeholders, the network, a mean sparse softmax
    cross-entropy loss and the training op, restores parameters from the
    latest checkpoint in ``FLAGS.logs_dir`` when one exists, then dispatches:

    * ``"train"``: ``fd.mode_train`` followed by ``fd.mode_test``;
    * ``"visualize"``: ``fd.mode_visualize`` on the validation reader;
    * ``"test"``: ``fd.mode_test`` on the test reader.

    Any other mode only builds the graph and restores parameters.

    Args:
        argv: Unused by this function.
    """
    # 1. input placeholders
    keep_probability = tf.placeholder(tf.float32, name="keep_probabilty")
    image = tf.placeholder(tf.float32,
                           shape=(None, IMAGE_SIZE, IMAGE_SIZE, 3),
                           name="input_image")
    annotation = tf.placeholder(tf.int32,
                                shape=(None, IMAGE_SIZE, IMAGE_SIZE, 1),
                                name="annotation")

    # 2. construct inference network
    pred_annotation, logits, net = unetinference(image, keep_probability)
    tf.summary.image("input_image", image, max_outputs=3)
    tf.summary.image("ground_truth",
                     tf.cast(annotation, tf.uint8),
                     max_outputs=3)
    tf.summary.image("pred_annotation",
                     tf.cast(pred_annotation, tf.uint8),
                     max_outputs=3)

    # 3. loss measure
    loss = tf.reduce_mean(
        tf.nn.sparse_softmax_cross_entropy_with_logits(
            logits=logits,
            labels=tf.squeeze(annotation, axis=[3]),
            name="entropy"))
    tf.summary.scalar("entropy", loss)

    # 4. optimizing
    trainable_var = tf.trainable_variables()
    if FLAGS.debug:
        for var in trainable_var:
            utils.add_to_regularization_and_summary(var)

    train_op = train(loss, trainable_var, net['global_step'])

    print("Setting up summary op...")
    summary_op = tf.summary.merge_all()

    print("Setting up image reader from ", FLAGS.data_dir, "...")
    train_records, valid_records, test_records = fashion_parsing.read_dataset(
        FLAGS.data_dir)
    print("data dir:", FLAGS.data_dir)
    print("train_records length :", len(train_records))
    print("valid_records length :", len(valid_records))
    print("test_records length :", len(test_records))

    print("Setting up dataset reader")
    # Only the readers needed by the selected mode are created.
    train_dataset_reader = None
    validation_dataset_reader = None
    test_dataset_reader = None
    image_options = {'resize': True, 'resize_size': IMAGE_SIZE}
    if FLAGS.mode == 'train':
        train_dataset_reader = dataset.BatchDatset(train_records,
                                                   image_options)
        validation_dataset_reader = dataset.BatchDatset(
            valid_records, image_options)
        test_dataset_reader = dataset.BatchDatset(test_records, image_options)
    elif FLAGS.mode == 'visualize':
        validation_dataset_reader = dataset.BatchDatset(
            valid_records, image_options)
    elif FLAGS.mode == 'test':
        test_dataset_reader = dataset.BatchDatset(test_records, image_options)

    sess = tf.Session()
    # try/finally so the session is released even when a mode handler raises.
    try:
        print("Setting up Saver...")
        saver = tf.train.Saver()
        summary_writer = tf.summary.FileWriter(FLAGS.logs_dir, sess.graph)

        # 5. parameter setup
        # 5.1 init params
        sess.run(tf.global_variables_initializer())
        # 5.2 restore params if possible
        ckpt = tf.train.get_checkpoint_state(FLAGS.logs_dir)
        if ckpt and ckpt.model_checkpoint_path:
            saver.restore(sess, ckpt.model_checkpoint_path)
            print("Model restored...")

        # 6. train-mode: train, then evaluate on the test split
        if FLAGS.mode == "train":
            fd.mode_train(sess, FLAGS, net, train_dataset_reader,
                          validation_dataset_reader, train_records,
                          pred_annotation, image, annotation,
                          keep_probability, logits, train_op, loss,
                          summary_op, summary_writer, saver, DISPLAY_STEP)

            fd.mode_test(sess, FLAGS, TEST_DIR, test_dataset_reader,
                         test_records, pred_annotation, image, annotation,
                         keep_probability, logits, NUM_OF_CLASSES)

        # test-random-validation-data mode
        elif FLAGS.mode == "visualize":
            fd.mode_visualize(sess, FLAGS, VIS_DIR, validation_dataset_reader,
                              pred_annotation, image, annotation,
                              keep_probability, NUM_OF_CLASSES)

        # test-full-validation-dataset mode
        elif FLAGS.mode == "test":  # heejune added
            fd.mode_test(sess, FLAGS, TEST_DIR, test_dataset_reader,
                         test_records, pred_annotation, image, annotation,
                         keep_probability, logits, NUM_OF_CLASSES)
    finally:
        sess.close()
def main(argv=None):
    """Build the single-scale ``inference`` graph and run ``FLAGS.mode``.

    Creates the input placeholders, the network, a mean sparse softmax
    cross-entropy loss and the training op, restores parameters from the
    latest checkpoint in ``FLAGS.logs_dir`` when one exists, then dispatches:

    * ``"train"``: ``fd.mode_train`` with train, validation and logs readers;
    * ``"test"``: ``fd.mode_new_test`` on the test reader;
    * ``"view"``: ``fd.mode_view`` with freshly built train and logs readers.

    Any other mode only builds the graph and restores parameters.

    Args:
        argv: Unused by this function.
    """
    # 1. input placeholders
    keep_probability = tf.placeholder(tf.float32, name="keep_probability")
    image = tf.placeholder(tf.float32,
                           shape=(None, IMAGE_SIZE, IMAGE_SIZE, 3),
                           name="input_image")
    annotation = tf.placeholder(tf.int32,
                                shape=(None, IMAGE_SIZE, IMAGE_SIZE, 1),
                                name="annotation")

    # 2. construct inference network
    # Run the network on a batch of images.  Per the original author's note
    # it yields the prediction map [b, h, w, c=1] and the resulting feature
    # map (logits) [b, h, w, c=151].
    pred_annotation, logits, net = inference(image, keep_probability)
    tf.summary.image("input_image", image, max_outputs=3)
    tf.summary.image("ground_truth",
                     tf.cast(annotation, tf.uint8),
                     max_outputs=3)
    tf.summary.image("pred_annotation",
                     tf.cast(pred_annotation, tf.uint8),
                     max_outputs=3)

    # 3. loss measure
    # Spatial cross-entropy between the logits [b, h, w, c=151] and the
    # labels [b, h, w], compared image by image (original author's note).
    loss = tf.reduce_mean(
        tf.nn.sparse_softmax_cross_entropy_with_logits(
            logits=logits,
            labels=tf.squeeze(annotation, axis=[3]),
            name="entropy"))
    tf.summary.scalar("entropy", loss)

    # 4. optimizing
    trainable_var = tf.trainable_variables()
    if FLAGS.debug:
        for var in trainable_var:
            Utils.add_to_regularization_and_summary(var)
    train_op = train(loss, trainable_var, net['global_step'])

    print("Setting up summary op...")
    summary_op = tf.summary.merge_all()

    print("Setting up image reader from ", FLAGS.data_dir, "...")
    print("data dir:", FLAGS.data_dir)

    train_records, valid_records, test_records, logs_records = fashion_parsing.read_dataset(
        FLAGS.data_dir)

    if DATA_SET == "CFPD":
        train_records, valid_records, test_records, logs_records = ClothingParsing.read_dataset(
            FLAGS.data_dir)
    if DATA_SET == "LIP":
        # NOTE(review): only two splits are returned here, so test_records
        # and logs_records keep the values loaded by fashion_parsing above.
        train_records, valid_records = HumanParsing.read_dataset(
            FLAGS.data_dir)
    print("test_records length :", len(test_records))
    print("train_records length :", len(train_records))
    print("valid_records length :", len(valid_records))
    print("logs_records length :", len(logs_records))
    print("Setting up dataset reader")
    train_dataset_reader = None
    validation_dataset_reader = None
    test_dataset_reader = None
    logs_dataset_reader = None
    image_options = {'resize': True, 'resize_size': IMAGE_SIZE}

    if FLAGS.mode == 'train':
        train_dataset_reader = DataSetReader.BatchDatset(
            "train", train_records, image_options)
        validation_dataset_reader = DataSetReader.BatchDatset(
            "val", valid_records, image_options)
        logs_dataset_reader = DataSetReader.BatchDatset(
            "logs", logs_records, image_options)

    if FLAGS.mode == 'visualize':
        # Pass the split name first, like every other reader built in this
        # function; the call previously omitted it.
        validation_dataset_reader = DataSetReader.BatchDatset(
            "val", valid_records, image_options)
    if FLAGS.mode in ('test', 'crftest', 'predonly', 'fulltest'):
        if DATA_SET == "CFPD":
            test_dataset_reader = DataSetReader.BatchDatset(
                "test", test_records, image_options)
        else:
            # NOTE(review): the reader iterates valid_records while the
            # dispatch below still passes test_records; the original
            # `test_records = valid_records` assignment was commented out.
            test_dataset_reader = DataSetReader.BatchDatset(
                "test", valid_records, image_options)

    sess = tf.Session()
    # try/finally so the session is released even when a mode handler raises.
    try:
        print("Setting up Saver...")
        saver = tf.train.Saver(max_to_keep=4)
        summary_writer = tf.summary.FileWriter(FLAGS.logs_dir, sess.graph)

        # 5. parameter setup
        # 5.1 init params
        sess.run(tf.global_variables_initializer())

        # 5.2 restore params if possible
        ckpt = tf.train.get_checkpoint_state(FLAGS.logs_dir)
        if ckpt and ckpt.model_checkpoint_path:
            saver.restore(sess, ckpt.model_checkpoint_path)
            print("Model restored...")

        # 6. train-mode
        if FLAGS.mode == "train":
            fd.mode_train(sess, FLAGS, net, train_dataset_reader,
                          validation_dataset_reader, logs_dataset_reader,
                          train_records, logs_records, pred_annotation, image,
                          annotation, keep_probability, logits, train_op,
                          loss, summary_op, summary_writer, saver,
                          DISPLAY_STEP)

        elif FLAGS.mode == "test":  # heejune added
            fd.mode_new_test(sess, FLAGS, TEST_DIR, test_dataset_reader,
                             test_records, pred_annotation, image, annotation,
                             keep_probability, logits, NUM_OF_CLASSES)

        elif FLAGS.mode == "view":
            train_dataset_reader = DataSetReader.BatchDatset(
                "train", train_records, image_options)
            logs_dataset_reader = DataSetReader.BatchDatset(
                "logs", logs_records, image_options)
            fd.mode_view(sess, FLAGS, "./VisImage/", train_dataset_reader,
                         logs_dataset_reader, logs_records, pred_annotation,
                         image, annotation, keep_probability, logits, 23)

        # Dispatch for the "crftest", "predonly" and "fulltest" modes is
        # disabled in this variant (the original kept it inside a string
        # literal).  It called fd.mode_predonly for "crftest" and
        # "predonly", and fd.mode_full_test for "fulltest", with the same
        # argument list as the fd.mode_new_test call above.  The test reader
        # is still built for those modes, but nothing consumes it.
    finally:
        sess.close()