def _calculate_vmf_loss(embedding,
                        semantic_labels,
                        unique_instance_labels,
                        prototype_labels,
                        concentration,
                        memory=None,
                        memory_labels=None,
                        split=1):
  """Calculates the von-Mises Fisher loss given semantic and cluster labels.

  Args:
    embedding: A 2-D float tensor with shape `[num_pixels, embedding_dim]`.
    semantic_labels: A 1-D integer tensor with length `[num_pixels]`. It
      contains the semantic label for each pixel.
    unique_instance_labels: A 1-D integer tensor with length `[num_pixels]`.
      It contains the unique instance label for each pixel.
    prototype_labels: A 1-D integer tensor with length `[num_prototypes]`.
      It contains the semantic label for each prototype.
    concentration: A float that controls the sharpness of cosine similarities.
    memory: A 2-D float tensor for memory prototypes with shape
      `[num_prototypes, embedding_dim]`.
    memory_labels: A 1-D integer tensor for labels of memory prototypes with
      length `[num_prototypes]`.
    split: An integer for number of splits of matrix multiplication.

  Returns:
    loss: A float for the von-Mises Fisher loss.
    new_memory: A 2-D float tensor for the memory prototypes to update with
      shape `[num_prototypes, embedding_dim]`.
    new_memory_labels: A 1-D integer tensor for labels of memory prototypes
      to update with length `[num_prototypes]`.
  """
  prototypes = common_utils.calculate_prototypes_from_labels(
      embedding, unique_instance_labels, tf.size(prototype_labels))

  if memory is not None:
    memory = common_utils.normalize_embedding(memory)
    # Shuffle the current prototypes to refresh the memory bank.
    rand_index = tf.random_shuffle(tf.range(tf.shape(prototype_labels)[0]))
    new_memory = tf.squeeze(tf.gather(prototypes, rand_index))
    new_memory_labels = tf.squeeze(tf.gather(prototype_labels, rand_index))
    prototypes = tf.concat([prototypes, memory], 0)
    prototype_labels = tf.concat([prototype_labels, memory_labels], 0)
  else:
    new_memory = new_memory_labels = None

  similarities = _calculate_similarities(
      embedding, prototypes, concentration, split)
  log_likelihood = _calculate_log_likelihood(
      similarities, unique_instance_labels, semantic_labels, prototype_labels)
  loss = tf.reduce_mean(log_likelihood)

  return loss, new_memory, new_memory_labels
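# The function below is a minimal NumPy sketch (not part of the TF graph
# above) of the vMF likelihood that _calculate_vmf_loss optimizes: each pixel
# is attracted to prototypes sharing its semantic label and repelled from all
# others. The exact grouping of positives is an assumption for illustration,
# not the internals of _calculate_log_likelihood.
def _vmf_loss_numpy_sketch(embedding, instance_labels, semantic_labels,
                           prototype_labels, concentration):
  """embedding: [num_pixels, dim] with L2-normalized rows."""
  import numpy as np
  num_prototypes = len(prototype_labels)
  # Mean embedding per instance, renormalized onto the unit sphere.
  prototypes = np.zeros((num_prototypes, embedding.shape[1]))
  for k in range(num_prototypes):
    mean = embedding[instance_labels == k].mean(axis=0)
    prototypes[k] = mean / np.linalg.norm(mean)
  # exp(kappa * cosine similarity) against every prototype.
  similarities = np.exp(concentration * embedding.dot(prototypes.T))
  # Positives: prototypes with the same semantic label as the pixel.
  positives = (semantic_labels[:, None] ==
               np.asarray(prototype_labels)[None, :])
  numerator = (similarities * positives).sum(axis=1)
  denominator = similarities.sum(axis=1)
  return float(-np.log(numerator / denominator).mean())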
def add_unsupervised_segsort_loss(embedding,
                                  concentration,
                                  cluster_labels,
                                  num_banks=0,
                                  loss_scope=None):
  """Adds the unsupervised SegSort loss.

  Args:
    embedding: A 4-D float tensor with shape
      `[batch, height, width, embedding_dim]`.
    concentration: A float that controls the sharpness of cosine similarities.
    cluster_labels: An integer tensor for per-pixel cluster labels,
      broadcastable to `[batch, height, width, 1]`.
    num_banks: An integer scalar for the number of memory banks.
    loss_scope: String, the name scope of the loss.

  Returns:
    A scalar float tensor for the unsupervised SegSort loss.
  """
  with tf.name_scope(loss_scope,
                     'unsupervised_segsort_loss',
                     (embedding, concentration, cluster_labels, num_banks)):
    # Normalize embedding.
    embedding = common_utils.normalize_embedding(embedding)
    shape = embedding.get_shape().as_list()
    batch_size = shape[0]
    embedding_dim = shape[3]

    # Add offsets to cluster labels so that clusters from different images
    # in the batch do not share labels.
    max_clusters = 256
    offset = tf.range(0, max_clusters * batch_size, max_clusters)
    cluster_labels += tf.reshape(offset, [-1, 1, 1, 1])
    _, cluster_labels = tf.unique(tf.reshape(cluster_labels, [-1]))

    # Calculate prototypes.
    embedding = tf.reshape(embedding, [-1, embedding_dim])
    prototypes = common_utils.calculate_prototypes_from_labels(
        embedding, cluster_labels)
    similarities = _calculate_similarities(
        embedding, prototypes, concentration, batch_size)

    # Calculate the unsupervised loss: the likelihood of each pixel
    # belonging to its own cluster prototype.
    self_indices = tf.concat(
        [tf.expand_dims(tf.range(tf.shape(similarities)[0]), 1),
         tf.expand_dims(cluster_labels, 1)], axis=1)
    numerator = tf.reshape(tf.gather_nd(similarities, self_indices), [-1])
    denominator = tf.reduce_sum(similarities, axis=1)
    probabilities = tf.divide(numerator, denominator)

    return tf.reduce_mean(-tf.log(probabilities))
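# A hedged NumPy illustration of the per-image offset trick used above:
# adding image_index * max_clusters keeps cluster ids from different images
# in a batch disjoint before tf.unique compacts them into consecutive ids.
# The concrete values below are made up for demonstration.
def _offset_cluster_labels_numpy_sketch():
  import numpy as np
  cluster_labels = np.array([[0, 1, 1], [0, 0, 2]])  # [batch=2, pixels=3]
  offset = np.arange(2)[:, None] * 256               # Per-image offset.
  flat = (cluster_labels + offset).reshape(-1)       # [0, 1, 1, 256, 256, 258]
  _, compact = np.unique(flat, return_inverse=True)  # [0, 1, 1, 2, 2, 3]
  return compact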
def main(): """Creates the model and start the inference process.""" args = get_arguments() # Parse image processing arguments. input_size = parse_commastr(args.input_size) strides = parse_commastr(args.strides) assert (input_size is not None and strides is not None) h, w = input_size innet_size = (int(math.ceil(h / 8)), int(math.ceil(w / 8))) # Create queue coordinator. coord = tf.train.Coordinator() # Load the data reader. with tf.name_scope('create_inputs'): reader = SegSortImageReader( args.data_dir, args.data_list, None, False, # No random scale False, # No random mirror False, # No random crop, center crop instead args.ignore_label, IMG_MEAN) image = reader.image[:reader.height, :reader.width] image_list = reader.image_list image_batch = tf.expand_dims(image, dim=0) # Create multi-scale augmented datas. rescale_image_batches = [] is_flipped = [] scales = [0.5, 0.75, 1, 1.25, 1.5, 1.75] if args.scale_aug else [1] for scale in scales: h_new = tf.to_int32( tf.multiply(tf.to_float(tf.shape(image_batch)[1]), scale)) w_new = tf.to_int32( tf.multiply(tf.to_float(tf.shape(image_batch)[2]), scale)) new_shape = tf.stack([h_new, w_new]) new_image_batch = tf.image.resize_images(image_batch, new_shape) rescale_image_batches.append(new_image_batch) is_flipped.append(False) # Create horizontally flipped augmented datas. if args.flip_aug: for i in range(len(scales)): img = rescale_image_batches[i] is_flip = is_flipped[i] img = tf.squeeze(img, axis=0) flip_img = tf.image.flip_left_right(img) flip_img = tf.expand_dims(flip_img, axis=0) rescale_image_batches.append(flip_img) is_flipped.append(True) # Create input tensor to the Network. crop_image_batch = tf.placeholder( name='crop_image_batch', shape=[1, input_size[0], input_size[1], 3], dtype=tf.float32) # Create network. outputs = model(crop_image_batch, args.embedding_dim, False, True) # Grab variable names which should be restored from checkpoints. restore_var = [ v for v in tf.global_variables() if 'crop_image_batch' not in v.name ] # Output predictions. output = outputs[0] output = tf.image.resize_bilinear(output, [input_size[0], input_size[1]]) embedding = common_utils.normalize_embedding(output) # Prototype placeholders. prototype_features = tf.placeholder(tf.float32, shape=[None, args.embedding_dim]) prototype_labels = tf.placeholder(tf.int32) # Combine embedding with location features and kmeans shape = embedding.get_shape().as_list() loc_feature = tf.expand_dims( common_utils.generate_location_features([shape[1], shape[2]], 'float'), 0) embedding_with_location = tf.concat([embedding, loc_feature], 3) embedding_with_location = common_utils.normalize_embedding( embedding_with_location) # Perform Kmeans clustering and extract prototypes. cluster_labels = common_utils.kmeans( embedding_with_location, [args.num_clusters, args.num_clusters], args.kmeans_iterations) _, cluster_labels = tf.unique(tf.reshape(cluster_labels, [-1])) test_prototypes = common_utils.calculate_prototypes_from_labels( embedding, cluster_labels) # Predict semantic labels. similarities = tf.matmul(test_prototypes, prototype_features, transpose_b=True) _, k_predictions = tf.nn.top_k(similarities, k=args.k_in_nearest_neighbors, sorted=True) k_predictions = tf.gather(prototype_labels, k_predictions) k_predictions = tf.gather(k_predictions, cluster_labels) k_predictions = tf.reshape( k_predictions, [shape[0], shape[1], shape[2], args.k_in_nearest_neighbors]) # Set up tf session and initialize variables. 
  config = tf.ConfigProto()
  config.gpu_options.allow_growth = True
  sess = tf.Session(config=config)
  init = tf.global_variables_initializer()

  sess.run(init)
  sess.run(tf.local_variables_initializer())

  # Load weights.
  loader = tf.train.Saver(var_list=restore_var)
  if args.restore_from is not None:
    load(loader, sess, args.restore_from)

  # Start queue threads.
  threads = tf.train.start_queue_runners(coord=coord, sess=sess)

  # Get colormap.
  map_data = scipy.io.loadmat(args.colormap)
  key = os.path.basename(args.colormap).replace('.mat', '')
  colormap = map_data[key]
  colormap *= 255
  colormap = colormap.astype(np.uint8)

  # Create directories for saving predictions.
  pred_dir = os.path.join(args.save_dir, 'gray')
  color_dir = os.path.join(args.save_dir, 'color')
  if not os.path.isdir(pred_dir):
    os.makedirs(pred_dir)
  if not os.path.isdir(color_dir):
    os.makedirs(color_dir)

  # Iterate over testing steps.
  with open(args.data_list, 'r') as listf:
    num_steps = len(listf.read().split('\n')) - 1

  # Load prototype features and labels.
  prototype_features_np = np.load(
      os.path.join(args.prototype_dir, 'prototype_features.npy'))
  prototype_labels_np = np.load(
      os.path.join(args.prototype_dir, 'prototype_labels.npy'))

  feed_dict = {prototype_features: prototype_features_np,
               prototype_labels: prototype_labels_np}

  pbar = tqdm(range(num_steps))
  for step in pbar:
    rescale_img_batches = sess.run(rescale_image_batches)
    # Final segmentation results (averaged across multiple scales).
    scale_ind = 2 if args.scale_aug else 0
    final_lab_size = list(rescale_img_batches[scale_ind].shape[1:])
    final_lab_size[-1] = args.num_classes
    final_lab_batch = np.zeros(final_lab_size)

    # Iterate over multiple scales.
    for img_batch, is_flip in zip(rescale_img_batches, is_flipped):
      img_size = img_batch.shape
      padimg_size = list(img_size)  # Copy of img_size.
      padimg_h, padimg_w = padimg_size[1:3]
      input_h, input_w = input_size

      if input_h > padimg_h:
        padimg_h = input_h
      if input_w > padimg_w:
        padimg_w = input_w

      # Update padded image size.
      padimg_size[1] = padimg_h
      padimg_size[2] = padimg_w
      padimg_batch = np.zeros(padimg_size, dtype=np.float32)
      img_h, img_w = img_size[1:3]
      padimg_batch[:, :img_h, :img_w, :] = img_batch

      stride_h, stride_w = strides
      npatches_h = math.ceil(1.0 * (padimg_h - input_h) / stride_h) + 1
      npatches_w = math.ceil(1.0 * (padimg_w - input_w) / stride_w) + 1

      # Create padded prediction array.
      pred_size = list(padimg_size)
      pred_size[-1] = args.num_classes
      predictions_np = np.zeros(pred_size, dtype=np.int32)

      # Create the ending index of each patch.
      patch_indh = np.linspace(
          input_h, padimg_h, npatches_h, dtype=np.int32)
      patch_indw = np.linspace(
          input_w, padimg_w, npatches_w, dtype=np.int32)

      pred_size[-1] = args.embedding_dim
      for indh in patch_indh:
        for indw in patch_indw:
          sh, eh = indh - input_h, indh  # Start & end indices of H.
          sw, ew = indw - input_w, indw  # Start & end indices of W.
          cropimg_batch = padimg_batch[:, sh:eh, sw:ew, :]
          feed_dict[crop_image_batch] = cropimg_batch

          k_predictions_np = sess.run(k_predictions, feed_dict=feed_dict)
          # Sum up KNN votes. This is the speed bottleneck for multi-scale
          # inference; use single-scale inference for fast results.
          # TODO: Either compute on GPU or change the implementation.
          for c in range(args.num_classes):
            predictions_np[:, sh:eh, sw:ew, c] += np.sum(
                (k_predictions_np == c).astype(np.int32), axis=3)

      predictions_np = predictions_np[0, :img_h, :img_w, :]
      lab_batch = predictions_np.astype(np.float32)
      # Rescale prediction back to original resolution.
      lab_batch = cv2.resize(
          lab_batch,
          (final_lab_size[1], final_lab_size[0]),
          interpolation=cv2.INTER_LINEAR)
      if is_flip:
        # Flip prediction back to original orientation.
        lab_batch = lab_batch[:, ::-1, :]
      final_lab_batch += lab_batch

    final_lab_ind = np.argmax(final_lab_batch, axis=-1)
    final_lab_ind = final_lab_ind.astype(np.uint8)

    basename = os.path.basename(image_list[step])
    basename = basename.replace('jpg', 'png')

    predname = os.path.join(pred_dir, basename)
    Image.fromarray(final_lab_ind, mode='L').save(predname)

    colorname = os.path.join(color_dir, basename)
    color = colormap[final_lab_ind]
    Image.fromarray(color, mode='RGB').save(colorname)

  coord.request_stop()
  coord.join(threads)
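# A hedged standalone sketch of the patch tiling logic in the loop above:
# ending indices are evenly spaced with np.linspace so the last patch always
# touches the padded border, and consecutive patches overlap whenever the
# stride is smaller than the input size. The default numbers are
# illustrative only.
def _patch_tiling_sketch(padimg_h=700, input_h=480, stride_h=320):
  import math
  import numpy as np
  npatches_h = int(math.ceil(1.0 * (padimg_h - input_h) / stride_h)) + 1
  patch_indh = np.linspace(input_h, padimg_h, npatches_h, dtype=np.int32)
  # For the defaults: npatches_h == 2 and patch_indh == [480, 700];
  # each patch covers rows [end - input_h, end).
  return patch_indh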
def extract_trained_prototypes(embedding,
                               location_features,
                               cluster_labels,
                               num_clusters,
                               kmeans_iterations,
                               panoptic_labels,
                               panoptic_label_divisor,
                               ignore_label,
                               evaluate_semantic_or_panoptic):
  """Extracts the trained prototypes in an image.

  Args:
    embedding: A 2-D float tensor with shape `[pixels, embedding_dim]`.
    location_features: A 2-D float tensor for location features with shape
      `[pixels, 2]`.
    cluster_labels: A 1-D integer tensor for cluster labels for all pixels.
    num_clusters: An integer scalar for total number of clusters.
    kmeans_iterations: Number of iterations for the k-means clustering.
    panoptic_labels: A 1-D integer tensor for panoptic labels for all pixels.
    panoptic_label_divisor: An integer constant to separate semantic and
      instance labels from panoptic labels.
    ignore_label: The semantic label to ignore.
    evaluate_semantic_or_panoptic: A boolean that specifies whether to
      evaluate semantic or panoptic segmentation.

  Returns:
    prototype_features: A 2-D float tensor for prototype features with shape
      `[num_prototypes, embedding_dim]`.
    prototype_semantic_labels: A 1-D integer tensor for prototype semantic
      labels with length `[num_prototypes]`.
    prototype_instance_labels: A 1-D integer tensor for prototype instance
      labels with length `[num_prototypes]`.
  """
  # Collect pixels of valid semantic classes.
  valid_pixels = tf.where(
      tf.not_equal(panoptic_labels // panoptic_label_divisor, ignore_label))
  panoptic_labels = tf.squeeze(
      tf.gather(panoptic_labels, valid_pixels), axis=1)
  cluster_labels = tf.squeeze(tf.gather(cluster_labels, valid_pixels), axis=1)
  embedding = tf.squeeze(tf.gather(embedding, valid_pixels), axis=1)
  location_features = tf.squeeze(
      tf.gather(location_features, valid_pixels), axis=1)

  # Generate cluster labels via k-means clustering.
  embedding_with_location = tf.concat([embedding, location_features], 1)
  embedding_with_location = common_utils.normalize_embedding(
      embedding_with_location)
  cluster_labels = common_utils.kmeans_with_initial_labels(
      embedding_with_location,
      cluster_labels,
      num_clusters,
      kmeans_iterations)
  _, cluster_labels = tf.unique(cluster_labels)

  if evaluate_semantic_or_panoptic == 'panoptic':
    # Calculate semantic and unique instance labels for all pixels.
    label_mapping, unique_panoptic_labels = tf.unique(panoptic_labels)

    # Find pixels of majority classes.
    select_pixels, majority_labels = find_majority_label_index(
        unique_panoptic_labels, cluster_labels)
  else:
    # Find pixels of majority semantic classes.
    semantic_labels = panoptic_labels // panoptic_label_divisor
    select_pixels, majority_labels = find_majority_label_index(
        semantic_labels, cluster_labels)

  cluster_labels = tf.squeeze(tf.gather(cluster_labels, select_pixels), axis=1)
  embedding = tf.squeeze(tf.gather(embedding, select_pixels), axis=1)

  # Calculate the majority semantic and instance label for each prototype.
  if evaluate_semantic_or_panoptic == 'panoptic':
    prototype_panoptic_labels = tf.gather(label_mapping, majority_labels)
    prototype_semantic_labels = (prototype_panoptic_labels //
                                 panoptic_label_divisor)
    prototype_instance_labels = majority_labels
  else:
    prototype_semantic_labels = majority_labels
    prototype_instance_labels = tf.zeros_like(majority_labels)

  # Calculate the prototype features.
  prototype_features = common_utils.calculate_prototypes_from_labels(
      embedding, cluster_labels)

  return (prototype_features, prototype_semantic_labels,
          prototype_instance_labels)
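# A hedged NumPy sketch of the majority-vote step: each prototype takes the
# most frequent pixel label inside its cluster. find_majority_label_index in
# the function above presumably performs the TF analog and additionally
# returns the indices of the majority pixels; this simplified version only
# returns the labels.
def _majority_label_per_cluster_sketch(pixel_labels, cluster_labels,
                                       num_clusters):
  import numpy as np
  majority = np.zeros(num_clusters, dtype=pixel_labels.dtype)
  for k in range(num_clusters):
    values, counts = np.unique(pixel_labels[cluster_labels == k],
                               return_counts=True)
    if values.size:  # Skip empty clusters.
      majority[k] = values[np.argmax(counts)]
  return majority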
def predict_all_labels(embedding,
                       num_clusters,
                       kmeans_iterations,
                       prototype_features,
                       prototype_semantic_labels,
                       prototype_instance_labels,
                       k_in_nearest_neighbors,
                       panoptic_label_divisor,
                       class_has_instances_list):
  """Predicts panoptic, semantic, and instance labels using the vMF embedding.

  Args:
    embedding: A 4-D float tensor with shape
      `[batch, height, width, embedding_dim]`.
    num_clusters: A list of 2 integers for number of clusters in y and x axes.
    kmeans_iterations: Number of iterations for the k-means clustering.
    prototype_features: A 2-D float tensor for trained prototype features
      with shape `[num_prototypes, embedding_dim]`.
    prototype_semantic_labels: A 1-D integer tensor for trained prototype
      semantic labels with length `[num_prototypes]`.
    prototype_instance_labels: A 1-D integer tensor for trained prototype
      instance labels with length `[num_prototypes]`.
    k_in_nearest_neighbors: The number of nearest neighbors to search,
      or k in k-nearest neighbors.
    panoptic_label_divisor: An integer constant to separate semantic and
      instance labels from panoptic labels.
    class_has_instances_list: A list of thing classes, which have instances.

  Returns:
    panoptic_predictions: An integer tensor for per-pixel panoptic
      predictions.
    semantic_predictions: An integer tensor for per-pixel semantic
      predictions.
    instance_predictions: An integer tensor for per-pixel instance
      predictions.
    cluster_labels: An integer tensor for the k-means cluster labels of
      all pixels.
  """
  # Generate location features and combine them with embedding features.
  shape = embedding.get_shape().as_list()
  location_features = common_utils.generate_location_features(
      [shape[1], shape[2]], 'float')
  location_features = tf.expand_dims(location_features, 0)
  embedding_with_location = tf.concat([embedding, location_features], 3)
  embedding_with_location = common_utils.normalize_embedding(
      embedding_with_location)

  # K-means clustering.
  cluster_labels = common_utils.kmeans(
      embedding_with_location, num_clusters, kmeans_iterations)
  test_prototypes = common_utils.calculate_prototypes_from_labels(
      embedding, cluster_labels)

  # Predict semantic and instance labels.
  semantic_predictions, instance_predictions = (
      predict_semantic_instance_labels(
          cluster_labels,
          test_prototypes,
          prototype_features,
          prototype_semantic_labels,
          prototype_instance_labels,
          k_in_nearest_neighbors))

  # Refine instance labels: zero out instances for stuff classes.
  class_has_instances_list = tf.reshape(
      class_has_instances_list, [1, 1, 1, -1])
  instance_predictions = tf.where(
      tf.reduce_all(
          tf.not_equal(tf.expand_dims(semantic_predictions, 3),
                       class_has_instances_list),
          axis=3),
      tf.zeros_like(instance_predictions),
      instance_predictions)

  # Combine semantic and instance predictions into panoptic predictions.
  panoptic_predictions = (semantic_predictions * panoptic_label_divisor +
                          instance_predictions)

  return (panoptic_predictions, semantic_predictions, instance_predictions,
          cluster_labels)
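# A hedged illustration of the panoptic encoding used above: semantic and
# instance ids are packed into a single integer per pixel and can be
# unpacked again with integer division and modulo by the same divisor.
# The concrete divisor and ids below are made up.
def _panoptic_encoding_sketch():
  panoptic_label_divisor = 256
  semantic, instance = 7, 3
  panoptic = semantic * panoptic_label_divisor + instance  # 1795
  assert panoptic // panoptic_label_divisor == semantic
  assert panoptic % panoptic_label_divisor == instance
  return panoptic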
def main(): """Create the model and start the Inference process.""" args = get_arguments() # Create queue coordinator. coord = tf.train.Coordinator() # Load the data reader. with tf.name_scope('create_inputs'): reader = SegSortImageReader( args.data_dir, args.data_list, parse_commastr(args.input_size), False, # No random scale False, # No random mirror False, # No random crop, center crop instead args.ignore_label, IMG_MEAN) image_list = reader.image_list image_batch = tf.expand_dims(reader.image, dim=0) label_batch = tf.expand_dims(reader.label, dim=0) cluster_label_batch = tf.expand_dims(reader.cluster_label, dim=0) loc_feature_batch = tf.expand_dims(reader.loc_feature, dim=0) height = reader.height width = reader.width # Create network and output prediction. outputs = model(image_batch, args.embedding_dim, False, True) # Grab variable names which should be restored from checkpoints. restore_var = [ v for v in tf.global_variables() if 'crop_image_batch' not in v.name ] # Output predictions. output = outputs[0] output = tf.image.resize_bilinear(output, tf.shape(image_batch)[1:3, ]) embedding = common_utils.normalize_embedding(output) # Prototype placeholders. prototype_features = tf.placeholder(tf.float32, shape=[None, args.embedding_dim]) prototype_labels = tf.placeholder(tf.int32) # Combine embedding with location features. embedding_with_location = tf.concat([embedding, loc_feature_batch], 3) embedding_with_location = common_utils.normalize_embedding( embedding_with_location) # Kmeans clustering. cluster_labels = common_utils.kmeans( embedding_with_location, [args.num_clusters, args.num_clusters], args.kmeans_iterations) test_prototypes = common_utils.calculate_prototypes_from_labels( embedding, cluster_labels) # Predict semantic labels. semantic_predictions, _ = eval_utils.predict_semantic_instance_labels( cluster_labels, test_prototypes, prototype_features, prototype_labels, None, args.k_in_nearest_neighbors) semantic_predictions = tf.cast(semantic_predictions, tf.uint8) semantic_predictions = tf.squeeze(semantic_predictions) # Set up tf session and initialize variables. config = tf.ConfigProto() config.gpu_options.allow_growth = True sess = tf.Session(config=config) init = tf.global_variables_initializer() sess.run(init) sess.run(tf.local_variables_initializer()) # Load weights. loader = tf.train.Saver(var_list=restore_var) if args.restore_from is not None: load(loader, sess, args.restore_from) # Start queue threads. threads = tf.train.start_queue_runners(coord=coord, sess=sess) # Get colormap. map_data = scipy.io.loadmat(args.colormap) key = os.path.basename(args.colormap).replace('.mat', '') colormap = map_data[key] colormap *= 255 colormap = colormap.astype(np.uint8) # Create directory for saving predictions. pred_dir = os.path.join(args.save_dir, 'gray') color_dir = os.path.join(args.save_dir, 'color') if not os.path.isdir(pred_dir): os.makedirs(pred_dir) if not os.path.isdir(color_dir): os.makedirs(color_dir) # Iterate over testing steps. with open(args.data_list, 'r') as listf: num_steps = len(listf.read().split('\n')) - 1 # Load prototype features and labels. 
  prototype_features_np = np.load(
      os.path.join(args.prototype_dir, 'prototype_features.npy'))
  prototype_labels_np = np.load(
      os.path.join(args.prototype_dir, 'prototype_labels.npy'))

  feed_dict = {prototype_features: prototype_features_np,
               prototype_labels: prototype_labels_np}

  for step in tqdm(range(num_steps)):
    semantic_predictions_np, height_np, width_np = sess.run(
        [semantic_predictions, height, width], feed_dict=feed_dict)
    semantic_predictions_np = semantic_predictions_np[:height_np, :width_np]

    basename = os.path.basename(image_list[step])
    basename = basename.replace('jpg', 'png')

    predname = os.path.join(pred_dir, basename)
    Image.fromarray(semantic_predictions_np, mode='L').save(predname)

    colorname = os.path.join(color_dir, basename)
    color = colormap[semantic_predictions_np]
    Image.fromarray(color, mode='RGB').save(colorname)

  coord.request_stop()
  coord.join(threads)
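# A hedged NumPy sketch of the k-nearest-neighbor voting performed inside
# eval_utils.predict_semantic_instance_labels: each test prototype takes the
# majority semantic label among its k most similar trained prototypes
# (cosine similarity, assuming unit-norm features). Names here are
# illustrative, not the actual implementation.
def _knn_semantic_vote_sketch(test_prototypes, train_features, train_labels,
                              k):
  import numpy as np
  similarities = test_prototypes.dot(train_features.T)
  topk = np.argsort(-similarities, axis=1)[:, :k]  # k most similar per row.
  votes = np.asarray(train_labels)[topk]           # [num_test, k]
  return np.array([np.bincount(v).argmax() for v in votes])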