# Example #1 (rating: 0)
# Model directory: <directory of this file>/data/model. Created on import if
# it does not exist yet.
MODEL_DIR = os.path.join(os.path.dirname(__file__), 'data', 'model')
# os.makedirs (unlike os.mkdir) also creates the intermediate 'data' directory
# when it is missing. Attempting the creation first and tolerating only the
# "already exists as a directory" case avoids the race between a separate
# existence check and the mkdir call, while still surfacing real failures
# (e.g. a regular file occupying the path, or a permission error).
try:
    os.makedirs(MODEL_DIR)
except OSError:
    if not os.path.isdir(MODEL_DIR):
        raise

# Input placeholder (fed under the name 'images'): a batch of frame sequences
# shaped [batch, NUM_REF + NUM_TARGET] + IMAGE_SIZE + [3]. IMAGE_SIZE must be a
# Python list (it is list-concatenated); the 5-D slicing below implies it has
# two entries (presumably [height, width]).
# NOTE(review): channel 0 is treated as the L (lightness) channel and channels
# 1-2 as the colour channels, which suggests Lab input -- confirm with the data
# pipeline that feeds this placeholder.
image_batch = tf.placeholder(tf.float32,
                             shape=[None, NUM_REF + NUM_TARGET] + IMAGE_SIZE +
                             [3],
                             name='images')

##### color clustering
# Fit the colour clustering on every pixel's two colour channels (channels
# 1-2), flattened to an [N, 2] matrix. Clustering's implementation is outside
# this view; given the variable name it is presumably mini-batch k-means.
kmeans = Clustering(tf.reshape(image_batch[:, :, :, :, 1:], [-1, 2]),
                    NUM_CLUSTERS,
                    mini_batch_steps_per_iteration=KMEANS_STEPS_PER_ITERATION)
# Merge the batch and frame axes: [batch * frames] + IMAGE_SIZE + [3].
image_batch_flat = tf.reshape(image_batch, [-1] + IMAGE_SIZE + [3])
# Downsample the frames to the feature-map resolution (tf.image.resize_images
# default method, bilinear in TF1), then map each pixel to a cluster label via
# the clustering helper.
labels = tf.image.resize_images(image_batch_flat, FEATURE_MAP_SIZE)
labels = kmeans.lab_to_labels(labels)
# Restore the frame axis: [batch, NUM_REF + NUM_TARGET] + FEATURE_MAP_SIZE.
labels = tf.reshape(labels, [-1, NUM_REF + NUM_TARGET] + FEATURE_MAP_SIZE)
# Exposed as 'labels': computed from the images by default, but a caller may
# feed label maps of the same shape directly and bypass the clustering path.
labels = tf.placeholder_with_default(labels, [None, NUM_REF + NUM_TARGET] +
                                     FEATURE_MAP_SIZE,
                                     name='labels')

##### extract features from gray scale image (only L channel) using CNN
# Slicing with 0:1 (rather than 0) keeps a trailing channel axis of size 1.
# The 3-D-conv path keeps the frame axis (5-D input); otherwise frames are
# processed as independent images from the flattened batch (4-D input).
if USE_CONV3D:
    inputs = image_batch[:, :, :, :, 0:1]
else:
    inputs = image_batch_flat[:, :, :, 0:1]
# Scalar boolean, False unless explicitly fed as 'is_training'.
is_training = tf.placeholder_with_default(False, [], name='is_training')
feature_map = feature_extractor(inputs,
                                dim=FEATURE_DIM,
                                weight_decay=WEIGHT_DECAY,
                                batch_norm_decay=BATCH_NORM_DECAY,
def _build_graph(image_batch):
    global_step = tf.Variable(0, trainable=False, name='global_step')
    t = tf.cast(global_step, tf.float32)
    
    ##### color clustering
    kmeans = Clustering(tf.reshape(image_batch[:,:,:,:,1:], [-1,2]), NUM_CLUSTERS,
                        mini_batch_steps_per_iteration=KMEANS_STEPS_PER_ITERATION)
    image_batch_flat = tf.reshape(image_batch, [-1]+IMAGE_SIZE+[3])
    labels = tf.image.resize_images(image_batch_flat, FEATURE_MAP_SIZE)
    labels = kmeans.lab_to_labels(labels)
    labels = tf.reshape(labels, [-1,NUM_REF+NUM_TARGET]+FEATURE_MAP_SIZE, name='labels')

    ##### extract features from gray scale image (only L channel) using CNN
    if USE_CONV3D:
        inputs = image_batch[:,:,:,:,0:1]
    else:
        inputs = image_batch_flat[:,:,:,0:1]
    is_training = tf.placeholder_with_default(False, [], name='is_training')
    feature_map = feature_extractor(inputs,
                                    dim = FEATURE_DIM,
                                    weight_decay = WEIGHT_DECAY,
                                    batch_norm_decay = BATCH_NORM_DECAY,
                                    batch_renorm_decay = BATCH_RENORM_DECAY,
                                    batch_renorm_rmax = BATCH_RENORM_RMAX(t),
                                    batch_renorm_dmax = BATCH_RENORM_DMAX(t),
                                    is_training = is_training,
                                    use_conv3d = USE_CONV3D)
    if not USE_CONV3D:
        feature_map = tf.reshape(
            feature_map, [-1,NUM_REF+NUM_TARGET]+FEATURE_MAP_SIZE+[FEATURE_DIM])
    # rename with tf.identity so that it can be easily fetched/fed at sess.run
    feature_map = tf.identity(feature_map, name='features')

    ##### predict the color (or other category) on the basis of the features
    def loop_body(i, losses, predictions, predictions_lab):
        f = feature_map[i]
        l = labels[i]
        end_points = colorizer(f[:NUM_REF], tf.one_hot(l[:NUM_REF], NUM_CLUSTERS),
                               f[NUM_REF:], l[NUM_REF:])
        mean_losses = tf.reduce_mean(tf.reduce_mean(end_points['losses'], 2), 1)
        losses = tf.concat([losses, tf.expand_dims(mean_losses, 0)], 0)
        pred = end_points['predictions']
        predictions = tf.concat([predictions, tf.expand_dims(pred, 0)], 0)
        predictions_lab = tf.concat([predictions_lab, tf.expand_dims(kmeans.labels_to_lab(pred), 0)], 0)
        return i+1, losses, predictions, predictions_lab
    loop_cond = lambda i, _1, _2, _3: tf.less(i, BATCH_SIZE)
    loop_vars = [tf.constant(0),
                 tf.zeros([0,NUM_TARGET], dtype=tf.float32),
                 tf.zeros([0,NUM_TARGET]+FEATURE_MAP_SIZE+[NUM_CLUSTERS]),
                 tf.zeros([0,NUM_TARGET]+FEATURE_MAP_SIZE+[3])]
    shape_invariants = [tf.TensorShape([]),
                        tf.TensorShape([None,NUM_TARGET]),
                        tf.TensorShape([None,NUM_TARGET]+FEATURE_MAP_SIZE+[NUM_CLUSTERS]),
                        tf.TensorShape([None,NUM_TARGET]+FEATURE_MAP_SIZE+[3])]
    _, losses, predictions, predictions_lab = tf.while_loop(loop_cond, loop_body, loop_vars,
                                                            shape_invariants=shape_invariants)
    predictions = tf.identity(predictions, name='predictions')
    predictions_lab = tf.identity(predictions_lab, name='predictions_lab')
    losses = tf.identity(losses, name='losses')

    ##### calculate differences between reference and target images
    #[BATCH_SIZE,NUM_REF+NUM_TARGET,NUM_CLUSTERS]
    pq = tf.reduce_mean(tf.reduce_mean(tf.one_hot(labels, NUM_CLUSTERS), 2), 2)