コード例 #1
0
ファイル: train.py プロジェクト: zhulf0804/PSPNet-Tensorflow
with tf.name_scope('input'):
    # Image batch fed at run time, laid out as
    # (batch_size, crop_height, crop_width, channels) per the FLAGS used.
    x = tf.placeholder(
        tf.float32,
        [FLAGS.batch_size, FLAGS.crop_height, FLAGS.crop_width, FLAGS.channels],
        name='x_input')
    # Ground-truth labels: one int32 value per spatial position of `x`.
    y = tf.placeholder(
        tf.int32,
        [FLAGS.batch_size, FLAGS.crop_height, FLAGS.crop_width],
        name='ground_truth')

# Training-mode forward pass. PSPNet returns two tensors: the first is kept as
# `auxi_logits` (presumably the auxiliary-branch output; confirm in pspnet.py)
# and the second as the main `logits`.
auxi_logits, logits = pspnet.PSPNet(x,
                                    is_training=True,
                                    output_stride=FLAGS.output_stride,
                                    pre_trained_model=FLAGS.pretrained_model_path,
                                    classes=FLAGS.classes)

with tf.name_scope('regularization'):
    # Keep every trainable variable whose name mentions neither 'beta' nor
    # 'gamma' (presumably excluding batch-norm offset/scale from weight decay).
    train_var_list = [
        v for v in tf.trainable_variables()
        if not any(tag in v.name for tag in ('beta', 'gamma'))
    ]
    # Weight-decay term: FLAGS.weight_decay times the summed tf.nn.l2_loss of
    # the kept variables. Only the term is built here; presumably it is added
    # to the main loss further down the file.
    with tf.variable_scope("total_loss"):
        l2_loss = FLAGS.weight_decay * tf.add_n(
            list(map(tf.nn.l2_loss, train_var_list)))

with tf.name_scope('loss'):
    #reshaped_logits = tf.reshape(logits, [BATCH_SIZE, -1])
コード例 #2
0
                                      crop_width=FLAGS.crop_width,
                                      classes=FLAGS.classes,
                                      ignore_label=FLAGS.ignore_label,
                                      scales=FLAGS.scales)

with tf.name_scope("input"):
    # Image batch laid out as (batch_size, height, width, 3): the channel
    # count is fixed at 3 here.
    # NOTE(review): the call just above this block passes
    # crop_width=FLAGS.crop_width, while these shapes use FLAGS.height /
    # FLAGS.width; confirm the sizes agree for whatever is fed into `x`.
    x = tf.placeholder(dtype=tf.float32,
                       shape=[FLAGS.batch_size, FLAGS.height, FLAGS.width, 3],
                       name='x_input')
    # Ground-truth labels: one int32 value per spatial position of `x`.
    y = tf.placeholder(dtype=tf.int32,
                       shape=[FLAGS.batch_size, FLAGS.height, FLAGS.width],
                       name='ground_truth')

# Inference-mode forward pass. Only the second returned tensor is used; the
# first (presumably the auxiliary output) is discarded.
_, logits = PSPNet.PSPNet(
    x, is_training=False, output_stride=FLAGS.output_stride,
    pre_trained_model=FLAGS.pretrained_model_path, classes=FLAGS.classes)

with tf.name_scope('prediction_and_miou'):
    # Predicted label per position: index of the largest logit along the last
    # axis (presumably the class axis, sized by FLAGS.classes).
    prediction = tf.argmax(logits, axis=-1, name='predictions')


def get_val_predictions():

    with tf.Session() as sess:
        sess.run(tf.local_variables_initializer())
        sess.run(tf.global_variables_initializer())
        saver = tf.train.Saver()
コード例 #3
0
    for i in range(height):
        for j in range(width):
            return_img[i, j, :] = cmap[image[i, j]]

    return return_img


# Fetch one batch from the project's input pipeline; `prediction_on` is passed
# as read_batch's `type` argument (presumably selecting the data split; confirm
# in input_data).
image_batch_0, image_batch, anno_batch, filename = input_data.read_batch(
    BATCH_SIZE, type=prediction_on)


with tf.name_scope("input"):
    # Placeholders sized by the BATCH_SIZE / HEIGHT / WIDTH constants defined
    # elsewhere in the file; images have a fixed 3 channels.
    x = tf.placeholder(dtype=tf.float32,
                       shape=[BATCH_SIZE, HEIGHT, WIDTH, 3],
                       name='x_input')
    # Ground-truth labels: one int32 value per spatial position of `x`.
    y = tf.placeholder(dtype=tf.int32,
                       shape=[BATCH_SIZE, HEIGHT, WIDTH],
                       name='ground_truth')

# Inference-mode forward pass with a hard-coded output stride of 8; the first
# of the two returned tensors is discarded.
# NOTE(review): `classes` is not passed here, so PSPNet's default is used;
# confirm it matches the dataset the checkpoint was trained on.
_, logits = PSPNet.PSPNet(x,
                          is_training=False,
                          output_stride=8,
                          pre_trained_model=PRETRAINED_MODEL_PATH)


with tf.name_scope('prediction_and_miou'):
    # Predicted label per position: index of the largest logit on the last axis.
    prediction = tf.argmax(logits,
                           axis=-1,
                           name='predictions')



with tf.Session() as sess:
    sess.run(tf.local_variables_initializer())
    sess.run(tf.global_variables_initializer())
    saver = tf.train.Saver()

    #saver.restore(sess, './checkpoint/pspnet.model-2000')