Example #1
def main(argv=None):
    # 1. input placeholders
    keep_probability = tf.placeholder(tf.float32, name="keep_probability")
    image = tf.placeholder(tf.float32,
                           shape=(None, IMAGE_SIZE, IMAGE_SIZE, 3),
                           name="input_image")
    annotation = tf.placeholder(tf.int32,
                                shape=(None, IMAGE_SIZE, IMAGE_SIZE, 1),
                                name="annotation")
    # global_step = tf.Variable(0, trainable=False, name='global_step')

    # 2. construct inference network
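    # unetinference() is assumed (defined elsewhere in this project) to return
    # the per-pixel prediction, the raw logits, and a dict of endpoints that
    # carries the 'global_step' variable used by train() below.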
    pred_annotation, logits, net = unetinference(image, keep_probability)
    tf.summary.image("input_image", image, max_outputs=3)
    tf.summary.image("ground_truth",
                     tf.cast(annotation, tf.uint8),
                     max_outputs=3)

    tf.summary.image("pred_annotation",
                     tf.cast(pred_annotation, tf.uint8),
                     max_outputs=3)

    # 3. loss measure
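    # sparse_softmax_cross_entropy_with_logits expects integer labels without a
    # channel axis, so the (N, H, W, 1) annotation is squeezed to (N, H, W).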
    loss = tf.reduce_mean((tf.nn.sparse_softmax_cross_entropy_with_logits(
        logits=logits,
        labels=tf.squeeze(annotation, axis=[3]),
        name="entropy")))
    tf.summary.scalar("entropy", loss)

    # 4. optimizing
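    # train() (a helper in this file) is expected to build the optimizer update
    # op from the loss, the variable list and the global step.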
    trainable_var = tf.trainable_variables()
    if FLAGS.debug:
        for var in trainable_var:
            Utils.add_to_regularization_and_summary(var)

    train_op = train(loss, trainable_var, net['global_step'])

    print("Setting up summary op...")
    summary_op = tf.summary.merge_all()

    print("Setting up image reader from ", FLAGS.data_dir, "...")
    print("data dir:", FLAGS.data_dir)

    train_records, valid_records = fashion_parsing.read_dataset(FLAGS.data_dir)
    test_records = None
    if DATA_SET == "CFPD":
        train_records, valid_records, test_records = ClothingParsing.read_dataset(
            FLAGS.data_dir)
        print("test_records length :", len(test_records))
    if DATA_SET == "LIP":
        train_records, valid_records = HumanParsing.read_dataset(
            FLAGS.data_dir)

    print("train_records length :", len(train_records))
    print("valid_records length :", len(valid_records))

    print("Setting up dataset reader")
    train_dataset_reader = None
    validation_dataset_reader = None
    test_dataset_reader = None
    image_options = {'resize': True, 'resize_size': IMAGE_SIZE}

    if FLAGS.mode == 'train':
        train_dataset_reader = DataSetReader.BatchDatset(
            train_records, image_options)
        validation_dataset_reader = DataSetReader.BatchDatset(
            valid_records, image_options)
        if DATA_SET == "CFPD":
            test_dataset_reader = DataSetReader.BatchDatset(
                test_records, image_options)
    if FLAGS.mode == 'visualize':
        validation_dataset_reader = DataSetReader.BatchDatset(
            valid_records, image_options)
    if FLAGS.mode in ('test', 'crftest', 'predonly', 'fulltest'):
        if DATA_SET == "CFPD":
            test_dataset_reader = DataSetReader.BatchDatset(
                test_records, image_options)
        else:
            test_dataset_reader = DataSetReader.BatchDatset(
                valid_records, image_options)
            test_records = valid_records

    sess = tf.Session()

    print("Setting up Saver...")
    saver = tf.train.Saver()
    summary_writer = tf.summary.FileWriter(FLAGS.logs_dir, sess.graph)

    # 5. parameter setup
    # 5.1 init params
    sess.run(tf.global_variables_initializer())
    # 5.2 restore params if possible
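    # Restoring the latest checkpoint from logs_dir lets training resume or
    # lets the test/visualize modes reuse previously trained weights.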
    ckpt = tf.train.get_checkpoint_state(FLAGS.logs_dir)
    if ckpt and ckpt.model_checkpoint_path:
        saver.restore(sess, ckpt.model_checkpoint_path)
        print("Model restored...")

    # 6. train-mode
    if FLAGS.mode == "train":

        fd.mode_train(sess, FLAGS, net, train_dataset_reader, validation_dataset_reader, train_records,
                      pred_annotation,
                      image, annotation, keep_probability, logits, train_op, loss, summary_op, summary_writer,
                      saver, DISPLAY_STEP)

    # test-random-validation-data mode
    elif FLAGS.mode == "visualize":

        fd.mode_visualize(sess, FLAGS, VIS_DIR, validation_dataset_reader,
                          pred_annotation, image, annotation, keep_probability, NUM_OF_CLASSES)

    # test-full-validation-dataset mode
    elif FLAGS.mode == "test":

        fd.mode_new_test(sess, FLAGS, TEST_DIR, test_dataset_reader, test_records,
                         pred_annotation, image, annotation, keep_probability, logits, NUM_OF_CLASSES)

        # fd.mode_test(sess, FLAGS, TEST_DIR, test_dataset_reader, test_records,
        # pred_annotation, image, annotation, keep_probability, logits, NUM_OF_CLASSES)

    elif FLAGS.mode == "crftest":

        fd.mode_predonly(sess, FLAGS, TEST_DIR, test_dataset_reader, test_records,
                         pred_annotation, image, annotation, keep_probability, logits, NUM_OF_CLASSES)

    elif FLAGS.mode == "predonly":

        fd.mode_predonly(sess, FLAGS, TEST_DIR, test_dataset_reader, test_records,
                         pred_annotation, image, annotation, keep_probability, logits, NUM_OF_CLASSES)

    elif FLAGS.mode == "fulltest":

        fd.mode_full_test(sess, FLAGS, TEST_DIR, test_dataset_reader, test_records,
                          pred_annotation, image, annotation, keep_probability, logits, NUM_OF_CLASSES)

    sess.close()
Example #2
def main(argv=None):
    # 1. input placeholders
    keep_probability = tf.placeholder(tf.float32, name="keep_probability")
    image = tf.placeholder(tf.float32,
                           shape=(None, IMAGE_SIZE, IMAGE_SIZE, 3),
                           name="input_image")
    annotation = tf.placeholder(tf.int32,
                                shape=(None, IMAGE_SIZE, IMAGE_SIZE, 1),
                                name="annotation")
    # global_step = tf.Variable(0, trainable=False, name='global_step')
    is_training = False
    if FLAGS.mode == "train":
        is_training = True

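    # Multi-scale inputs: the image and annotation are resized to 0.75x, 0.5x
    # and 1.25x of IMAGE_SIZE for the scale-attention model below.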
    image075 = tf.image.resize_images(
        image, [int(IMAGE_SIZE * 0.75), int(IMAGE_SIZE * 0.75)])
    image050 = tf.image.resize_images(
        image, [int(IMAGE_SIZE * 0.5), int(IMAGE_SIZE * 0.5)])
    image125 = tf.image.resize_images(
        image, [int(IMAGE_SIZE * 1.25), int(IMAGE_SIZE * 1.25)])

    annotation075 = tf.cast(tf.image.resize_images(
        annotation, [int(IMAGE_SIZE * 0.75), int(IMAGE_SIZE * 0.75)]), tf.int32)
    annotation050 = tf.cast(tf.image.resize_images(
        annotation, [int(IMAGE_SIZE * 0.5), int(IMAGE_SIZE * 0.5)]), tf.int32)
    annotation125 = tf.cast(tf.image.resize_images(
        annotation, [int(IMAGE_SIZE * 1.25), int(IMAGE_SIZE * 1.25)]), tf.int32)

    # 2. construct inference network
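    # All scales share one set of U-Net weights: the first call creates the
    # variables (reuse=False) and the remaining calls reuse them (reuse=True).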
    reuse1 = False
    reuse2 = True

    with tf.variable_scope('', reuse=reuse1):
        pred_annotation100, logits100, net100, att100 = unetinference(
            image, keep_probability, is_training=is_training)
    with tf.variable_scope('', reuse=reuse2):
        pred_annotation075, logits075, net075, att075 = unetinference(
            image075, keep_probability, is_training=is_training)
    with tf.variable_scope('', reuse=reuse2):
        pred_annotation050, logits050, net050, att050 = unetinference(
            image050, keep_probability, is_training=is_training)
    with tf.variable_scope('', reuse=reuse2):
        pred_annotation125, logits125, net125, att125 = unetinference(
            image125, keep_probability, is_training=is_training)

    # apply attention model - train
    score_final_train = None
    score_final_test = None
    final_annotation_pred_train = None
    final_annotation_pred_test = None

    train_op = None
    reduced_loss = None

    if FLAGS.mode == "train":
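        # Training-time fusion: attention() predicts a per-pixel weight map for
        # the 1.0x, 0.75x and 0.5x branches; after a softmax over scales, each
        # branch's logits are resized to the 1.0x resolution, weighted and summed.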
        attn_input = []
        attn_input.append(att100)
        attn_input.append(
            tf.image.resize_images(att075,
                                   tf.shape(att100)[1:3, ]))
        attn_input.append(
            tf.image.resize_images(att050,
                                   tf.shape(att100)[1:3, ]))
        attn_input_train = tf.concat(attn_input, axis=3)
        attn_output_train = attention(attn_input_train, is_training)
        scale_att_mask = tf.nn.softmax(attn_output_train)

        score_att_x = tf.multiply(
            logits100,
            tf.image.resize_images(
                tf.expand_dims(scale_att_mask[:, :, :, 0], axis=3),
                tf.shape(logits100)[1:3, ]))
        score_att_x_075 = tf.multiply(
            tf.image.resize_images(logits075,
                                   tf.shape(logits100)[1:3, ]),
            tf.image.resize_images(
                tf.expand_dims(scale_att_mask[:, :, :, 1], axis=3),
                tf.shape(logits100)[1:3, ]))
        score_att_x_050 = tf.multiply(
            tf.image.resize_images(logits050,
                                   tf.shape(logits100)[1:3, ]),
            tf.image.resize_images(
                tf.expand_dims(scale_att_mask[:, :, :, 2], axis=3),
                tf.shape(logits100)[1:3, ]))
        score_final_train = score_att_x + score_att_x_075 + score_att_x_050

        final_annotation_pred = tf.expand_dims(
            tf.argmax(score_final_train, axis=3, name="final_prediction"),
            axis=3)
        final_annotation_pred_train = tf.reduce_mean(tf.stack([
            tf.cast(final_annotation_pred, tf.float32),
            tf.cast(pred_annotation100, tf.float32),
            tf.image.resize_images(pred_annotation075,
                                   tf.shape(pred_annotation100)[1:3, ]),
            tf.image.resize_images(pred_annotation050,
                                   tf.shape(pred_annotation100)[1:3, ])
        ]), axis=0)

        # 3. loss measure
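        # The total loss adds the cross-entropy of the fused score map to the
        # per-scale losses, each computed against the annotation resized to the
        # matching scale.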
        loss = tf.reduce_mean((tf.nn.sparse_softmax_cross_entropy_with_logits(
            logits=score_final_train,
            labels=tf.squeeze(annotation, axis=[3]),
            name="entropy")))
        tf.summary.scalar("entropy", loss)

        loss100 = tf.reduce_mean(
            (tf.nn.sparse_softmax_cross_entropy_with_logits(
                logits=logits100,
                labels=tf.squeeze(annotation, axis=[3]),
                name="entropy100")))
        tf.summary.scalar("entropy100", loss100)

        loss075 = tf.reduce_mean(
            (tf.nn.sparse_softmax_cross_entropy_with_logits(
                logits=logits075,
                labels=tf.squeeze(annotation075, axis=[3]),
                name="entropy075")))
        tf.summary.scalar("entropy075", loss075)

        loss050 = tf.reduce_mean(
            (tf.nn.sparse_softmax_cross_entropy_with_logits(
                logits=logits050,
                labels=tf.squeeze(annotation050, axis=[3]),
                name="entropy050")))
        tf.summary.scalar("entropy050", loss050)

        reduced_loss = loss + loss100 + loss075 + loss050

        # 4. optimizing
        trainable_var = tf.trainable_variables()
        if FLAGS.debug:
            for var in trainable_var:
                Utils.add_to_regularization_and_summary(var)

        train_op = train(reduced_loss, trainable_var, net100['global_step'])

    else:
        # apply attention model - test
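        # Test-time fusion mirrors the training path, except the 1.25x branch
        # replaces the 0.5x branch.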
        attn_input = []
        attn_input.append(att100)
        attn_input.append(
            tf.image.resize_images(att075,
                                   tf.shape(att100)[1:3, ]))
        attn_input.append(
            tf.image.resize_images(att125,
                                   tf.shape(att100)[1:3, ]))
        attn_input_test = tf.concat(attn_input, axis=3)
        attn_output_test = attention(attn_input_test, is_training)
        scale_att_mask = tf.nn.softmax(attn_output_test)

        score_att_x = tf.multiply(
            logits100,
            tf.image.resize_images(
                tf.expand_dims(scale_att_mask[:, :, :, 0], axis=3),
                tf.shape(logits100)[1:3, ]))
        score_att_x_075 = tf.multiply(
            tf.image.resize_images(logits075,
                                   tf.shape(logits100)[1:3, ]),
            tf.image.resize_images(
                tf.expand_dims(scale_att_mask[:, :, :, 1], axis=3),
                tf.shape(logits100)[1:3, ]))
        score_att_x_125 = tf.multiply(
            tf.image.resize_images(logits125,
                                   tf.shape(logits100)[1:3, ]),
            tf.image.resize_images(
                tf.expand_dims(scale_att_mask[:, :, :, 2], axis=3),
                tf.shape(logits100)[1:3, ]))

        score_final_test = score_att_x + score_att_x_075 + score_att_x_125

        final_annotation_pred = tf.expand_dims(
            tf.argmax(score_final_test, axis=3, name="final_prediction"),
            axis=3)
        final_annotation_pred_test = tf.reduce_mean(tf.stack([
            tf.cast(final_annotation_pred, tf.float32),
            tf.cast(pred_annotation100, tf.float32),
            tf.image.resize_images(pred_annotation075,
                                   tf.shape(pred_annotation100)[1:3, ]),
            tf.image.resize_images(pred_annotation125,
                                   tf.shape(pred_annotation100)[1:3, ])
        ]), axis=0)

    tf.summary.image("input_image", image, max_outputs=3)
    tf.summary.image("ground_truth",
                     tf.cast(annotation, tf.uint8),
                     max_outputs=3)

    tf.summary.image("pred_annotation",
                     tf.cast(pred_annotation100, tf.uint8),
                     max_outputs=3)

    print("Setting up summary op...")
    summary_op = tf.summary.merge_all()

    print("Setting up image reader from ", FLAGS.data_dir, "...")
    print("data dir:", FLAGS.data_dir)

    train_records, valid_records = fashion_parsing.read_dataset(FLAGS.data_dir)
    test_records = None
    if DATA_SET == "CFPD":
        train_records, valid_records, test_records = ClothingParsing.read_dataset(
            FLAGS.data_dir)
        print("test_records length :", len(test_records))
    if DATA_SET == "LIP":
        train_records, valid_records = HumanParsing.read_dataset(
            FLAGS.data_dir)

    print("train_records length :", len(train_records))
    print("valid_records length :", len(valid_records))

    print("Setting up dataset reader")
    train_dataset_reader = None
    validation_dataset_reader = None
    test_dataset_reader = None
    image_options = {'resize': True, 'resize_size': IMAGE_SIZE}

    if FLAGS.mode == 'train':
        train_dataset_reader = DataSetReader.BatchDatset(
            train_records, image_options)
        validation_dataset_reader = DataSetReader.BatchDatset(
            valid_records, image_options)
        if DATA_SET == "CFPD":
            test_dataset_reader = DataSetReader.BatchDatset(
                test_records, image_options)
    if FLAGS.mode == 'visualize':
        validation_dataset_reader = DataSetReader.BatchDatset(
            valid_records, image_options)
    if FLAGS.mode in ('test', 'crftest', 'predonly', 'fulltest'):
        if DATA_SET == "CFPD":
            test_dataset_reader = DataSetReader.BatchDatset(
                test_records, image_options)
        else:
            test_dataset_reader = DataSetReader.BatchDatset(
                valid_records, image_options)
            test_records = valid_records

    sess = tf.Session()

    print("Setting up Saver...")
    saver = tf.train.Saver()
    summary_writer = tf.summary.FileWriter(FLAGS.logs_dir, sess.graph)

    # 5. parameter setup
    # 5.1 init params
    sess.run(tf.global_variables_initializer())
    # 5.2 restore params if possible
    ckpt = tf.train.get_checkpoint_state(FLAGS.logs_dir)
    if ckpt and ckpt.model_checkpoint_path:
        saver.restore(sess, ckpt.model_checkpoint_path)
        print("Model restored...")

    # 6. train-mode
    if FLAGS.mode == "train":

        fd.mode_train(sess, FLAGS, net100, train_dataset_reader,
                      validation_dataset_reader, train_records,
                      final_annotation_pred_train, image, annotation,
                      keep_probability, score_final_train, train_op,
                      reduced_loss, summary_op, summary_writer, saver,
                      DISPLAY_STEP)

    # test-random-validation-data mode
    elif FLAGS.mode == "visualize":

        fd.mode_visualize(sess, FLAGS, VIS_DIR, validation_dataset_reader,
                          final_annotation_pred_test, image, annotation,
                          keep_probability, NUM_OF_CLASSES)

    # test-full-validation-dataset mode
    elif FLAGS.mode == "test":

        fd.mode_new_test(sess, FLAGS, TEST_DIR, test_dataset_reader,
                         test_records, final_annotation_pred_test, image,
                         annotation, keep_probability, score_final_test,
                         NUM_OF_CLASSES)

        # fd.mode_test(sess, FLAGS, TEST_DIR, test_dataset_reader, test_records,
        # pred_annotation, image, annotation, keep_probability, logits, NUM_OF_CLASSES)

    elif FLAGS.mode == "crftest":

        fd.mode_predonly(sess, FLAGS, TEST_DIR, test_dataset_reader,
                         test_records, final_annotation_pred_test, image,
                         annotation, keep_probability, score_final_test,
                         NUM_OF_CLASSES)

    elif FLAGS.mode == "predonly":

        fd.mode_predonly(sess, FLAGS, TEST_DIR, test_dataset_reader,
                         test_records, final_annotation_pred_test, image,
                         annotation, keep_probability, score_final_test,
                         NUM_OF_CLASSES)

    elif FLAGS.mode == "fulltest":

        fd.mode_full_test(sess, FLAGS, TEST_DIR, test_dataset_reader,
                          test_records, final_annotation_pred_test, image,
                          annotation, keep_probability, score_final_test,
                          NUM_OF_CLASSES)

    sess.close()