MODEL_DIR = os.path.join(os.path.dirname(__file__), 'data', 'model') if not os.path.exists(MODEL_DIR): os.mkdir(MODEL_DIR) image_batch = tf.placeholder(tf.float32, shape=[None, NUM_REF + NUM_TARGET] + IMAGE_SIZE + [3], name='images') ##### color clustering kmeans = Clustering(tf.reshape(image_batch[:, :, :, :, 1:], [-1, 2]), NUM_CLUSTERS, mini_batch_steps_per_iteration=KMEANS_STEPS_PER_ITERATION) image_batch_flat = tf.reshape(image_batch, [-1] + IMAGE_SIZE + [3]) labels = tf.image.resize_images(image_batch_flat, FEATURE_MAP_SIZE) labels = kmeans.lab_to_labels(labels) labels = tf.reshape(labels, [-1, NUM_REF + NUM_TARGET] + FEATURE_MAP_SIZE) labels = tf.placeholder_with_default(labels, [None, NUM_REF + NUM_TARGET] + FEATURE_MAP_SIZE, name='labels') ##### extract features from gray scale image (only L channel) using CNN if USE_CONV3D: inputs = image_batch[:, :, :, :, 0:1] else: inputs = image_batch_flat[:, :, :, 0:1] is_training = tf.placeholder_with_default(False, [], name='is_training') feature_map = feature_extractor(inputs, dim=FEATURE_DIM, weight_decay=WEIGHT_DECAY, batch_norm_decay=BATCH_NORM_DECAY,
def _build_graph(image_batch):
    """Build the colorization graph for ``image_batch``.

    Steps performed by the code below, in graph-construction order:

    1. Create a non-trainable ``global_step`` variable and a float32 cast of
       it, which is passed to the batch-renorm schedule callables.
    2. Feed channels ``1:`` of every pixel (flattened to ``[-1, 2]``) to
       ``Clustering``, then convert the images, resized to
       ``FEATURE_MAP_SIZE``, into per-pixel labels (tensor named ``labels``).
    3. Run ``feature_extractor`` on channel ``0`` only and expose the result
       as the tensor named ``features``.
    4. For each of ``BATCH_SIZE`` samples, call ``colorizer`` with the first
       ``NUM_REF`` frames as references and the remaining frames as targets,
       stacking the per-sample results into the tensors named ``losses``,
       ``predictions`` and ``predictions_lab``.
    5. Average the one-hot labels over axis 2 twice (``pq``).

    Args:
        image_batch: tensor indexed with five axes and reshaped to
            ``[-1] + IMAGE_SIZE + [3]`` below; presumably laid out as
            ``[batch, NUM_REF + NUM_TARGET] + IMAGE_SIZE + [3]`` -- confirm
            against the caller.
    """
    global_step = tf.Variable(0, trainable=False, name='global_step')
    t = tf.cast(global_step, tf.float32)
    num_frames = NUM_REF + NUM_TARGET

    ##### color clustering
    color_channels = tf.reshape(image_batch[:, :, :, :, 1:], [-1, 2])
    kmeans = Clustering(
        color_channels,
        NUM_CLUSTERS,
        mini_batch_steps_per_iteration=KMEANS_STEPS_PER_ITERATION)
    image_batch_flat = tf.reshape(image_batch, [-1] + IMAGE_SIZE + [3])
    resized_images = tf.image.resize_images(image_batch_flat, FEATURE_MAP_SIZE)
    cluster_ids = kmeans.lab_to_labels(resized_images)
    labels = tf.reshape(
        cluster_ids, [-1, num_frames] + FEATURE_MAP_SIZE, name='labels')

    ##### extract features from gray scale image (only L channel) using CNN
    inputs = (image_batch[:, :, :, :, 0:1] if USE_CONV3D
              else image_batch_flat[:, :, :, 0:1])
    is_training = tf.placeholder_with_default(False, [], name='is_training')
    extractor_options = dict(
        dim=FEATURE_DIM,
        weight_decay=WEIGHT_DECAY,
        batch_norm_decay=BATCH_NORM_DECAY,
        batch_renorm_decay=BATCH_RENORM_DECAY,
        batch_renorm_rmax=BATCH_RENORM_RMAX(t),
        batch_renorm_dmax=BATCH_RENORM_DMAX(t),
        is_training=is_training,
        use_conv3d=USE_CONV3D)
    feature_map = feature_extractor(inputs, **extractor_options)
    if not USE_CONV3D:
        # The 2D path worked on flattened frames; restore the frame axis.
        feature_map = tf.reshape(
            feature_map,
            [-1, num_frames] + FEATURE_MAP_SIZE + [FEATURE_DIM])
    # rename with tf.identity so that it can be easily fetched/fed at sess.run
    feature_map = tf.identity(feature_map, name='features')

    ##### predict the color (or other category) on the basis of the features
    def loop_body(i, losses, predictions, predictions_lab):
        """Process sample ``i`` and append its results to the accumulators."""

        def append_row(stack, row):
            # Grow ``stack`` by one entry along axis 0.
            return tf.concat([stack, tf.expand_dims(row, 0)], 0)

        sample_features = feature_map[i]
        sample_labels = labels[i]
        ref_features = sample_features[:NUM_REF]
        ref_one_hot = tf.one_hot(sample_labels[:NUM_REF], NUM_CLUSTERS)
        end_points = colorizer(ref_features,
                               ref_one_hot,
                               sample_features[NUM_REF:],
                               sample_labels[NUM_REF:])

        # Collapse axis 2 and then axis 1 of the loss tensor.
        mean_losses = tf.reduce_mean(
            tf.reduce_mean(end_points['losses'], 2), 1)
        losses = append_row(losses, mean_losses)

        pred = end_points['predictions']
        predictions = append_row(predictions, pred)
        predictions_lab = append_row(predictions_lab,
                                     kmeans.labels_to_lab(pred))
        return i + 1, losses, predictions, predictions_lab

    def loop_cond(i, _1, _2, _3):
        """Keep iterating until ``BATCH_SIZE`` samples have been handled."""
        return tf.less(i, BATCH_SIZE)

    # Per-sample shapes of the three accumulators; each starts with zero rows
    # and gains one row per iteration, hence the ``None`` leading dimension in
    # the shape invariants.
    loss_shape = [NUM_TARGET]
    prediction_shape = [NUM_TARGET] + FEATURE_MAP_SIZE + [NUM_CLUSTERS]
    prediction_lab_shape = [NUM_TARGET] + FEATURE_MAP_SIZE + [3]

    loop_vars = [
        tf.constant(0),
        tf.zeros([0] + loss_shape, dtype=tf.float32),
        tf.zeros([0] + prediction_shape),
        tf.zeros([0] + prediction_lab_shape),
    ]
    shape_invariants = [
        tf.TensorShape([]),
        tf.TensorShape([None] + loss_shape),
        tf.TensorShape([None] + prediction_shape),
        tf.TensorShape([None] + prediction_lab_shape),
    ]
    _, losses, predictions, predictions_lab = tf.while_loop(
        loop_cond, loop_body, loop_vars, shape_invariants=shape_invariants)

    # Give the loop outputs stable names for fetching.
    predictions = tf.identity(predictions, name='predictions')
    predictions_lab = tf.identity(predictions_lab, name='predictions_lab')
    losses = tf.identity(losses, name='losses')

    ##### calculate differences between reference and target images
    #[BATCH_SIZE,NUM_REF+NUM_TARGET,NUM_CLUSTERS]
    one_hot_labels = tf.one_hot(labels, NUM_CLUSTERS)
    pq = tf.reduce_mean(tf.reduce_mean(one_hot_labels, 2), 2)