def _build_pseudo_seg(iterator_seg, iterator, outputs_to_num_classes,
                      ignore_label, batch_size=8):
  """Builds a clone of PseudoSeg.

  Args:
    iterator_seg: An iterator of type tf.data.Iterator for images and labels
      of the pixel-level (seg) data.
    iterator: An iterator of type tf.data.Iterator for images and labels of
      the image-level/unlabeled data.
    outputs_to_num_classes: A map from output type to the number of classes.
      For example, for the task of semantic segmentation with 21 semantic
      classes, we would have outputs_to_num_classes['semantic'] = 21.
    ignore_label: Ignore label.
    batch_size: Training batch size for each clone.
  """
  samples_cls = iterator.get_next()
  samples_cls[common.IMAGE] = tf.identity(
      samples_cls[common.IMAGE], name='weak')
  samples_cls['strong'] = tf.identity(samples_cls['strong'], name='strong')
  samples_cls[common.LABEL] = tf.identity(
      samples_cls[common.LABEL], name='unlabeled')

  samples_seg = iterator_seg.get_next()
  samples_seg[common.IMAGE] = tf.identity(
      samples_seg[common.IMAGE], name=common.IMAGE + '_seg')
  samples_seg[common.LABEL] = tf.identity(
      samples_seg[common.LABEL], name=common.LABEL + '_seg')

  model_options = common.ModelOptions(
      outputs_to_num_classes=outputs_to_num_classes,
      crop_size=[int(sz) for sz in FLAGS.train_crop_size],
      atrous_rates=FLAGS.atrous_rates,
      output_stride=FLAGS.output_stride)

  if FLAGS.att_v2:
    cam_func = train_utils_core.compute_cam_v2
  else:
    cam_func = train_utils_core.compute_cam

  ### Cls/unlabeled data
  ## 1) If we have image-level labels, train the classifier here
  if FLAGS.weakly:
    if FLAGS.cls_with_cls:
      curr_samples = samples_cls
    else:
      curr_samples = samples_seg
    _, end_points_cls = feature_extractor.extract_features(
        curr_samples[common.IMAGE],
        output_stride=model_options.output_stride,
        multi_grid=model_options.multi_grid,
        model_variant=model_options.model_variant,
        depth_multiplier=model_options.depth_multiplier,
        divisible_by=model_options.divisible_by,
        weight_decay=FLAGS.weight_decay,
        reuse=tf.AUTO_REUSE,
        is_training=True,
        preprocessed_images_dtype=model_options.preprocessed_images_dtype,
        fine_tune_batch_norm=FLAGS.fine_tune_batch_norm,
        global_pool=True,
        num_classes=outputs_to_num_classes[common.OUTPUT_TYPE] - 1)
    # ResNet beta version has an additional suffix in FLAGS.model_variant,
    # but it shares the same variable names with the original version. Add a
    # special handling here for the beta version ResNet.
    logits_cls = end_points_cls['{}/logits'.format(
        FLAGS.model_variant).replace('_beta', '')]
    logits_cls = tf.reshape(
        logits_cls, [-1, outputs_to_num_classes[common.OUTPUT_TYPE] - 1])
    # Multi-label classification loss (cf. the multi-label soft margin loss
    # commonly used in PyTorch).
    loss_cls = tf.nn.sigmoid_cross_entropy_with_logits(
        labels=curr_samples['cls_label'], logits=logits_cls)
    loss_cls = tf.reduce_mean(loss_cls)
    loss_cls = tf.identity(loss_cls, name='loss_cls')
    tf.compat.v1.losses.add_loss(loss_cls)

  ## 2) Consistency
  with tf.name_scope('cls_weak'):
    outputs_to_scales_to_logits, _ = model.multi_scale_logits(
        samples_cls[common.IMAGE],
        model_options=model_options,
        image_pyramid=FLAGS.image_pyramid,
        weight_decay=FLAGS.weight_decay,
        is_training=True,
        fine_tune_batch_norm=FLAGS.fine_tune_batch_norm,
        nas_training_hyper_parameters={
            'drop_path_keep_prob': FLAGS.drop_path_keep_prob,
            'total_training_steps': FLAGS.training_number_of_steps,
        },
        output_end_points=True)
    logits_weak = outputs_to_scales_to_logits[common.OUTPUT_TYPE][
        model.MERGED_LOGITS_SCOPE]
    prob_weak = tf.nn.softmax(logits_weak, axis=-1)
    logits_weak = tf.identity(logits_weak, name='logits_weak')
    # Monitor max score
    max_prob_weak = tf.reduce_max(prob_weak, axis=-1)
    max_prob_weak = tf.identity(max_prob_weak, name='max_prob_weak')

  valid_mask_pad = samples_cls['valid']
  valid_mask_pad = tf.compat.v1.image.resize_nearest_neighbor(
      valid_mask_pad,
      preprocess_utils.resolve_shape(logits_weak, 4)[1:3])
  valid_mask_pad = tf.cast(valid_mask_pad, tf.float32)

  if FLAGS.use_attention:
    # Use inference mode to generate Grad-CAM.
    with tf.name_scope('cls_data_cls_inference'):
      _, end_points_cls = feature_extractor.extract_features(
          samples_cls[common.IMAGE],
          output_stride=model_options.output_stride,
          multi_grid=model_options.multi_grid,
          model_variant=model_options.model_variant,
          depth_multiplier=model_options.depth_multiplier,
          divisible_by=model_options.divisible_by,
          weight_decay=FLAGS.weight_decay,
          reuse=tf.AUTO_REUSE,
          is_training=False,
          preprocessed_images_dtype=model_options.preprocessed_images_dtype,
          fine_tune_batch_norm=FLAGS.fine_tune_batch_norm,
          global_pool=True,
          num_classes=outputs_to_num_classes[common.OUTPUT_TYPE] - 1)
      logits_cls = end_points_cls['{}/logits'.format(
          FLAGS.model_variant).replace('_beta', '')]
      logits_cls = tf.reshape(
          logits_cls, [-1, outputs_to_num_classes[common.OUTPUT_TYPE] - 1])

    # Ground-truth image-level labels are only available in the weakly+semi
    # setting; otherwise derive them from the classifier predictions.
    if FLAGS.weakly:
      image_level_label = samples_cls['cls_label']
    else:
      prob_cls = tf.sigmoid(logits_cls)
      # TODO(ylzou): Might use a variable threshold for different classes
      pred_cls = tf.greater_equal(prob_cls, 0.5)
      image_level_label = tf.stop_gradient(tf.cast(pred_cls, tf.float32))

    cam_weak, att_cam_weak = cam_func(
        end_points_cls,
        logits_cls,
        image_level_label,
        num_class=outputs_to_num_classes[common.OUTPUT_TYPE],
        use_attention=True,
        attention_dim=FLAGS.attention_dim,
        strides=[int(st) for st in FLAGS.att_strides],
        is_training=True,
        valid_mask=valid_mask_pad,
        net=FLAGS.model_variant.replace('_beta', ''))
    att_logits_weak = att_cam_weak
    # Upsample att-cam
    att_logits_weak = tf.compat.v1.image.resize_bilinear(
        att_logits_weak,
        preprocess_utils.resolve_shape(logits_weak, 4)[1:3],
        align_corners=True)
    # Monitor vanilla cam
    cam_weak = tf.compat.v1.image.resize_bilinear(
        cam_weak,
        preprocess_utils.resolve_shape(logits_weak, 4)[1:3],
        align_corners=True)
    cam_weak = tf.identity(cam_weak, name='cam_weak')
    att_prob_weak = tf.nn.softmax(att_logits_weak, axis=-1)
    att_logits_weak = tf.identity(att_logits_weak, name='att_logits_weak')
    # Monitor max score
    max_att_prob_weak = tf.reduce_max(att_prob_weak, axis=-1)
    max_att_prob_weak = tf.identity(
        max_att_prob_weak, name='max_att_prob_weak')

    # Ensemble
    if FLAGS.pseudo_src == 'att':
      prob_weak = att_prob_weak
    else:
      if FLAGS.logit_norm:
        v = tf.concat([logits_weak, att_logits_weak], axis=0)
        all_logits_weak = v * tf.rsqrt(tf.reduce_mean(tf.square(v)) + 1e-8)
        scaled_logits_weak = all_logits_weak[:batch_size]
        prob_weak = tf.nn.softmax(scaled_logits_weak, axis=-1)
        scaled_att_logits_weak = all_logits_weak[batch_size:]
        att_prob_weak = tf.nn.softmax(scaled_att_logits_weak, axis=-1)
      prob_weak = (prob_weak + att_prob_weak) / 2.
      # Monitor max score
      max_prob_avg = tf.reduce_max(prob_weak, axis=-1)
      max_prob_avg = tf.identity(max_prob_avg, name='max_prob_avg')

  # Temperature
  if FLAGS.soft_pseudo_label and FLAGS.temperature != 1.0:
    prob_weak = tf.pow(prob_weak, 1. / FLAGS.temperature)
    prob_weak /= tf.reduce_sum(prob_weak, axis=-1, keepdims=True)
    # Monitor max score
    max_prob_avg_t = tf.reduce_max(prob_weak, axis=-1)
    max_prob_avg_t = tf.identity(max_prob_avg_t, name='max_prob_avg_t')

  # Monitor merged logits
  prob_weak = tf.identity(prob_weak, name='merged_logits')

  with tf.name_scope('cls_strong'):
    outputs_to_scales_to_logits, _ = model.multi_scale_logits(
        samples_cls['strong'],
        model_options=model_options,
        image_pyramid=FLAGS.image_pyramid,
        weight_decay=FLAGS.weight_decay,
        is_training=True,
        fine_tune_batch_norm=FLAGS.fine_tune_batch_norm,
        nas_training_hyper_parameters={
            'drop_path_keep_prob': FLAGS.drop_path_keep_prob,
            'total_training_steps': FLAGS.training_number_of_steps,
        },
        output_end_points=True)
    logits_strong = outputs_to_scales_to_logits[common.OUTPUT_TYPE][
        model.MERGED_LOGITS_SCOPE]
    logits_strong = tf.identity(logits_strong, name='logits_strong')

  if FLAGS.pseudo_label_threshold > 0:
    confidence_weak = tf.expand_dims(
        tf.reduce_max(prob_weak, axis=-1), axis=-1)
    valid_mask_score = tf.greater_equal(
        confidence_weak, FLAGS.pseudo_label_threshold)
    valid_mask_score = tf.cast(valid_mask_score, tf.float32)
    valid_mask = valid_mask_score * valid_mask_pad
  else:
    valid_mask_score = None
    valid_mask = valid_mask_pad
  # Save for visualization
  valid_mask = tf.identity(valid_mask, name='valid_mask')

  logits_strong = tf.reshape(
      logits_strong, [-1, outputs_to_num_classes[common.OUTPUT_TYPE]])
  if not FLAGS.soft_pseudo_label:
    pseudo_label = tf.argmax(prob_weak, axis=-1)
    pseudo_label = tf.reshape(pseudo_label, [-1])
    pseudo_label = tf.stop_gradient(pseudo_label)
    loss_consistency = tf.compat.v1.nn.sparse_softmax_cross_entropy_with_logits(
        labels=pseudo_label, logits=logits_strong, name='consistency_losses')
    loss_consistency = loss_consistency * tf.reshape(valid_mask, [-1])
    pred_pseudo = pseudo_label
  else:
    pseudo_label = prob_weak
    pseudo_label = tf.reshape(
        pseudo_label, [-1, outputs_to_num_classes[common.OUTPUT_TYPE]])
    pseudo_label = tf.stop_gradient(pseudo_label)
    loss_consistency = tf.compat.v1.nn.softmax_cross_entropy_with_logits_v2(
        labels=pseudo_label, logits=logits_strong, name='consistency_losses')
    loss_consistency = loss_consistency * tf.reshape(valid_mask, [-1])
    pred_pseudo = tf.argmax(pseudo_label, axis=-1)

  # NOTE: When averaging, we divide by the number of pixels excluding padding.
  loss_consistency = tf.reduce_sum(loss_consistency)
  loss_consistency = train_utils._div_maybe_zero(
      loss_consistency, tf.reduce_sum(valid_mask_pad))
  loss_consistency *= FLAGS.unlabeled_weight
  loss_consistency = tf.identity(loss_consistency, 'loss_consistency')
  tf.compat.v1.losses.add_loss(loss_consistency)

  ## 3) Monitor prediction quality
  temp_label = tf.compat.v1.image.resize_nearest_neighbor(
      samples_cls[common.LABEL],
      preprocess_utils.resolve_shape(logits_weak, 4)[1:3])
  temp_label = tf.reshape(temp_label, [-1])
  # Get the pixel count of each class, so that we can re-weight classes for
  # pixel accuracy.
  dump = tf.concat(
      [tf.range(outputs_to_num_classes[common.OUTPUT_TYPE]), temp_label],
      axis=-1)
  _, _, count = tf.unique_with_counts(dump)
  num_pixel_list = count - 1
  # Exclude the ignore region
  num_pixel_list = num_pixel_list[
      :outputs_to_num_classes[common.OUTPUT_TYPE]]
  num_pixel_list = tf.cast(num_pixel_list, tf.float32)
  inverse_ratio = train_utils._div_maybe_zero(1, num_pixel_list)
  inverse_ratio = inverse_ratio / tf.reduce_sum(inverse_ratio)
  # Since tf.metrics.mean_per_class_accuracy does not directly support a
  # weighted average over classes, convert the class weights into a
  # pixel-wise weight mask to compute weighted average pixel accuracy.
  weight_mask = tf.einsum(
      '...y,y->...',
      tf.one_hot(
          temp_label,
          outputs_to_num_classes[common.OUTPUT_TYPE],
          dtype=tf.float32),
      inverse_ratio)
  temp_valid = tf.not_equal(temp_label, ignore_label)
  if valid_mask_score is not None:
    temp_valid_confident = tf.cast(temp_valid, tf.float32) * tf.reshape(
        valid_mask_score, [-1])
    temp_valid_confident = tf.cast(temp_valid_confident, tf.bool)
  else:
    temp_valid_confident = temp_valid

  temp_label_confident = tf.boolean_mask(temp_label, temp_valid_confident)
  temp_label_valid = tf.boolean_mask(temp_label, temp_valid)
  weight_mask_confident = tf.boolean_mask(weight_mask, temp_valid_confident)
  weight_mask_valid = tf.boolean_mask(weight_mask, temp_valid)

  if FLAGS.pseudo_label_threshold > 0:
    acc_pseudo, acc_pseudo_op = tf.metrics.mean_per_class_accuracy(
        temp_label_confident,
        tf.boolean_mask(pred_pseudo, temp_valid_confident),
        outputs_to_num_classes[common.OUTPUT_TYPE],
        weights=weight_mask_confident)
    with tf.control_dependencies([acc_pseudo_op]):
      acc_pseudo = tf.identity(acc_pseudo, name='acc_pseudo')

  pred_weak = tf.cast(tf.argmax(prob_weak, axis=-1), tf.int32)
  pred_weak = tf.reshape(pred_weak, [-1])
  acc_weak, acc_weak_op = tf.metrics.mean_per_class_accuracy(
      temp_label_valid,
      tf.boolean_mask(pred_weak, temp_valid),
      outputs_to_num_classes[common.OUTPUT_TYPE],
      weights=weight_mask_valid)
  with tf.control_dependencies([acc_weak_op]):
    acc_weak = tf.identity(acc_weak, name='acc_weak')

  pred_strong = tf.cast(tf.argmax(logits_strong, axis=-1), tf.int32)
  pred_strong = tf.reshape(pred_strong, [-1])
  # For all pixels
  acc_strong, acc_strong_op = tf.metrics.mean_per_class_accuracy(
      temp_label_valid,
      tf.boolean_mask(pred_strong, temp_valid),
      outputs_to_num_classes[common.OUTPUT_TYPE],
      weights=weight_mask_valid)
  with tf.control_dependencies([acc_strong_op]):
    acc_strong = tf.identity(acc_strong, name='acc_strong')
  # For confident pixels
  if FLAGS.pseudo_label_threshold > 0:
    acc_strong_confident, acc_strong_confident_op = (
        tf.metrics.mean_per_class_accuracy(
            temp_label_confident,
            tf.boolean_mask(pred_strong, temp_valid_confident),
            outputs_to_num_classes[common.OUTPUT_TYPE],
            weights=weight_mask_confident))
    with tf.control_dependencies([acc_strong_confident_op]):
      acc_strong_confident = tf.identity(
          acc_strong_confident, name='acc_strong_confident')

    valid_ratio = tf.reduce_sum(valid_mask) / tf.reduce_sum(valid_mask_pad)
    valid_ratio = tf.identity(valid_ratio, name='valid_ratio')

  ### Pixel-level data
  ## 1) Segmentation
  outputs_to_scales_to_logits = model.multi_scale_logits(
      samples_seg[common.IMAGE],
      model_options=model_options,
      image_pyramid=FLAGS.image_pyramid,
      weight_decay=FLAGS.weight_decay,
      is_training=True,
      fine_tune_batch_norm=FLAGS.fine_tune_batch_norm,
      nas_training_hyper_parameters={
          'drop_path_keep_prob': FLAGS.drop_path_keep_prob,
          'total_training_steps': FLAGS.training_number_of_steps,
      })
  # Add name to graph node so we can add to summary.
  output_type_dict = outputs_to_scales_to_logits[common.OUTPUT_TYPE]
  output_type_dict[model.MERGED_LOGITS_SCOPE] = tf.identity(
      output_type_dict[model.MERGED_LOGITS_SCOPE],
      name=common.OUTPUT_TYPE + '_seg')

  for output, num_classes in six.iteritems(outputs_to_num_classes):
    train_utils.add_softmax_cross_entropy_loss_for_each_scale(
        outputs_to_scales_to_logits[output],
        samples_seg[common.LABEL],
        num_classes,
        ignore_label,
        loss_weight=model_options.label_weights,
        upsample_logits=FLAGS.upsample_logits,
        hard_example_mining_step=FLAGS.hard_example_mining_step,
        top_k_percent_pixels=FLAGS.top_k_percent_pixels,
        scope=output)

  ## 2) Train the self-attention module
  if FLAGS.use_attention:
    valid_mask_pad = samples_seg['valid']
    valid_mask_pad = tf.compat.v1.image.resize_nearest_neighbor(
        valid_mask_pad,
        preprocess_utils.resolve_shape(logits_weak, 4)[1:3])
    valid_mask_pad = tf.cast(valid_mask_pad, tf.float32)

    with tf.name_scope('seg_data_cls'):
      _, end_points_cls = feature_extractor.extract_features(
          samples_seg[common.IMAGE],
          output_stride=model_options.output_stride,
          multi_grid=model_options.multi_grid,
          model_variant=model_options.model_variant,
          depth_multiplier=model_options.depth_multiplier,
          divisible_by=model_options.divisible_by,
          weight_decay=FLAGS.weight_decay,
          reuse=tf.AUTO_REUSE,
          is_training=True,
          preprocessed_images_dtype=model_options.preprocessed_images_dtype,
          fine_tune_batch_norm=FLAGS.fine_tune_batch_norm,
          global_pool=True,
          num_classes=outputs_to_num_classes[common.OUTPUT_TYPE] - 1)
      logits_cls = end_points_cls['{}/logits'.format(
          FLAGS.model_variant).replace('_beta', '')]
      logits_cls = tf.reshape(
          logits_cls, [-1, outputs_to_num_classes[common.OUTPUT_TYPE] - 1])

    _, att_cam_labeled = cam_func(
        end_points_cls,
        logits_cls,
        samples_seg['cls_label'],
        num_class=outputs_to_num_classes[common.OUTPUT_TYPE],
        use_attention=True,
        attention_dim=FLAGS.attention_dim,
        strides=[int(st) for st in FLAGS.att_strides],
        is_training=True,
        valid_mask=valid_mask_pad,
        net=FLAGS.model_variant.replace('_beta', ''))
    att_logits_labeled = att_cam_labeled
    # Loss
    train_utils.add_softmax_cross_entropy_loss_for_each_scale(
        {'self-attention_logits': att_logits_labeled},
        samples_seg[common.LABEL],
        outputs_to_num_classes[common.OUTPUT_TYPE],
        ignore_label,
        loss_weight=model_options.label_weights,
        upsample_logits=FLAGS.upsample_logits,
        hard_example_mining_step=FLAGS.hard_example_mining_step,
        top_k_percent_pixels=FLAGS.top_k_percent_pixels,
        scope='self-attention')
    att_logits_labeled = tf.identity(
        att_logits_labeled, name='att_logits_labeled')

  ## 3) Without image-level labels, use the pixel-level labels to train the
  ## classifier
  if not FLAGS.weakly:
    # Multi-label classification loss (cf. the multi-label soft margin loss
    # commonly used in PyTorch).
    loss_cls = tf.nn.sigmoid_cross_entropy_with_logits(
        labels=samples_seg['cls_label'], logits=logits_cls)
    loss_cls = tf.reduce_mean(loss_cls)
    loss_cls = tf.identity(loss_cls, name='loss_cls')
    tf.compat.v1.losses.add_loss(loss_cls)

  ## 4) Sanity check: monitor pixel accuracy
  logits_seg = outputs_to_scales_to_logits[common.OUTPUT_TYPE][
      model.MERGED_LOGITS_SCOPE]
  temp_label = tf.compat.v1.image.resize_nearest_neighbor(
      samples_seg[common.LABEL],
      preprocess_utils.resolve_shape(logits_seg, 4)[1:3])
  temp_label = tf.reshape(temp_label, [-1])
  dump = tf.concat(
      [tf.range(outputs_to_num_classes[common.OUTPUT_TYPE]), temp_label],
      axis=-1)
  _, _, count = tf.unique_with_counts(dump)
  num_pixel_list = count - 1
  # Exclude the ignore region
  num_pixel_list = num_pixel_list[
      :outputs_to_num_classes[common.OUTPUT_TYPE]]
  num_pixel_list = tf.cast(num_pixel_list, tf.float32)
  inverse_ratio = train_utils._div_maybe_zero(1, num_pixel_list)
  inverse_ratio = inverse_ratio / tf.reduce_sum(inverse_ratio)
  # Create a weight mask to balance each class
  weight_mask = tf.einsum(
      '...y,y->...',
      tf.one_hot(
          temp_label,
          outputs_to_num_classes[common.OUTPUT_TYPE],
          dtype=tf.float32),
      inverse_ratio)
  temp_valid = tf.not_equal(temp_label, ignore_label)
  temp_label_valid = tf.boolean_mask(temp_label, temp_valid)
  weight_mask_valid = tf.boolean_mask(weight_mask, temp_valid)
  pred_seg = tf.argmax(logits_seg, axis=-1)
  pred_seg = tf.reshape(pred_seg, [-1])
  acc_seg, acc_seg_op = tf.metrics.mean_per_class_accuracy(
      temp_label_valid,
      tf.boolean_mask(pred_seg, temp_valid),
      outputs_to_num_classes[common.OUTPUT_TYPE],
      weights=weight_mask_valid)
  with tf.control_dependencies([acc_seg_op]):
    acc_seg = tf.identity(acc_seg, name='acc_seg')
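
# Illustrative sketch, not part of the original file: the weak-branch
# pseudo-label recipe above (ensemble the decoder and attention softmax
# outputs, optionally sharpen with a temperature, then keep only pixels whose
# top score clears a threshold) reduces to a few array ops. The names
# `decoder_logits` and `att_logits` are hypothetical stand-ins for
# `logits_weak` and `att_logits_weak`; the logit_norm rescaling is omitted.
def _sketch_make_pseudo_labels(decoder_logits, att_logits,
                               temperature=0.5, threshold=0.9):
  """Minimal NumPy sketch of the pseudo-label ensemble, for reference only."""
  import numpy as np

  def softmax(x):
    e = np.exp(x - x.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

  # Average the two probability sources (the non-'att' FLAGS.pseudo_src path).
  prob = (softmax(decoder_logits) + softmax(att_logits)) / 2.
  # Temperature sharpening for soft pseudo labels.
  if temperature != 1.0:
    prob = prob ** (1. / temperature)
    prob /= prob.sum(axis=-1, keepdims=True)
  # Confidence mask: 1.0 where the max class probability clears the threshold.
  valid = (prob.max(axis=-1) >= threshold).astype(np.float32)
  return prob, valid
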
def main(unused_argv):
  tf.logging.set_verbosity(tf.logging.INFO)

  dataset = data_generator.Dataset(
      dataset_name=FLAGS.dataset,
      split_name=FLAGS.eval_split,
      dataset_dir=FLAGS.dataset_dir,
      batch_size=FLAGS.eval_batch_size,
      crop_size=[int(sz) for sz in FLAGS.eval_crop_size],
      min_resize_value=FLAGS.min_resize_value,
      max_resize_value=FLAGS.max_resize_value,
      resize_factor=FLAGS.resize_factor,
      model_variant=FLAGS.model_variant,
      num_readers=2,
      is_training=False,
      should_shuffle=False,
      should_repeat=False,
      with_cls=True,
      cls_only=False,
      output_valid=True)

  tf.gfile.MakeDirs(FLAGS.eval_logdir)
  tf.logging.info('Evaluating on %s set', FLAGS.eval_split)

  with tf.Graph().as_default():
    samples = dataset.get_one_shot_iterator().get_next()

    model_options = common.ModelOptions(
        outputs_to_num_classes={common.OUTPUT_TYPE: dataset.num_of_classes},
        crop_size=[int(sz) for sz in FLAGS.eval_crop_size],
        atrous_rates=FLAGS.atrous_rates,
        output_stride=FLAGS.output_stride)

    # Set shape in order for tf.contrib.tfprof.model_analyzer to work
    # properly.
    samples[common.IMAGE].set_shape([
        FLAGS.eval_batch_size,
        int(FLAGS.eval_crop_size[0]),
        int(FLAGS.eval_crop_size[1]),
        3
    ])

    if tuple(FLAGS.eval_scales) == (1.0,):
      tf.logging.info('Performing single-scale test.')
      predictions = model.predict_labels(
          samples[common.IMAGE],
          model_options,
          image_pyramid=FLAGS.image_pyramid)
    else:
      tf.logging.info('Performing multi-scale test.')
      raise NotImplementedError('Multi-scale is not supported yet!')

    metric_map = {}

    ## Extract cls logits
    if FLAGS.weakly:
      _, end_points = feature_extractor.extract_features(
          samples[common.IMAGE],
          output_stride=model_options.output_stride,
          multi_grid=model_options.multi_grid,
          model_variant=model_options.model_variant,
          depth_multiplier=model_options.depth_multiplier,
          divisible_by=model_options.divisible_by,
          reuse=tf.AUTO_REUSE,
          is_training=False,
          preprocessed_images_dtype=model_options.preprocessed_images_dtype,
          global_pool=True,
          num_classes=dataset.num_of_classes - 1)
      # ResNet beta version has an additional suffix in FLAGS.model_variant,
      # but it shares the same variable names with the original version. Add
      # a special handling here for the beta version ResNet.
      logits = end_points['{}/logits'.format(
          FLAGS.model_variant).replace('_beta', '')]
      logits = tf.reshape(logits, [-1, dataset.num_of_classes - 1])
      cls_pred = tf.sigmoid(logits)

      # Multi-label classification evaluation
      cls_label = samples['cls_label']
      cls_pred = tf.cast(tf.greater_equal(cls_pred, 0.5), tf.int32)

      ## For classification
      metric_map['eval/cls_overall'] = tf.metrics.accuracy(
          labels=cls_label, predictions=cls_pred)
      metric_map['eval/cls_precision'] = tf.metrics.precision(
          labels=cls_label, predictions=cls_pred)
      metric_map['eval/cls_recall'] = tf.metrics.recall(
          labels=cls_label, predictions=cls_pred)

    ## For segmentation branch eval
    predictions = predictions[common.OUTPUT_TYPE]
    predictions = tf.reshape(predictions, shape=[-1])
    labels = tf.reshape(samples[common.LABEL], shape=[-1])
    weights = tf.to_float(tf.not_equal(labels, dataset.ignore_label))

    # Set ignore_label regions to label 0, because metrics.mean_iou requires
    # a label range of [0, dataset.num_classes). Note the ignore_label
    # regions are not evaluated since the corresponding regions have
    # weights = 0.
    labels = tf.where(tf.equal(labels, dataset.ignore_label),
                      tf.zeros_like(labels), labels)

    predictions_tag = 'miou'

    # Define the evaluation metric.
    num_classes = dataset.num_of_classes

    ## For segmentation
    metric_map['eval/%s_overall' % predictions_tag] = tf.metrics.mean_iou(
        labels=labels,
        predictions=predictions,
        num_classes=num_classes,
        weights=weights)
    # IoU for each class.
    one_hot_predictions = tf.one_hot(predictions, num_classes)
    one_hot_predictions = tf.reshape(one_hot_predictions, [-1, num_classes])
    one_hot_labels = tf.one_hot(labels, num_classes)
    one_hot_labels = tf.reshape(one_hot_labels, [-1, num_classes])
    for c in range(num_classes):
      predictions_tag_c = '%s_class_%d' % (predictions_tag, c)
      tp, tp_op = tf.metrics.true_positives(
          labels=one_hot_labels[:, c],
          predictions=one_hot_predictions[:, c],
          weights=weights)
      fp, fp_op = tf.metrics.false_positives(
          labels=one_hot_labels[:, c],
          predictions=one_hot_predictions[:, c],
          weights=weights)
      fn, fn_op = tf.metrics.false_negatives(
          labels=one_hot_labels[:, c],
          predictions=one_hot_predictions[:, c],
          weights=weights)
      tp_fp_fn_op = tf.group(tp_op, fp_op, fn_op)
      iou = tf.where(tf.greater(tp + fn, 0.0),
                     tp / (tp + fn + fp),
                     tf.constant(np.NaN))
      metric_map['eval/%s' % predictions_tag_c] = (iou, tp_fp_fn_op)

    (metrics_to_values,
     metrics_to_updates) = contrib_metrics.aggregate_metric_map(metric_map)

    summary_ops = []
    for metric_name, metric_value in six.iteritems(metrics_to_values):
      op = tf.summary.scalar(metric_name, metric_value)
      op = tf.Print(op, [metric_value], metric_name)
      summary_ops.append(op)

    summary_op = tf.summary.merge(summary_ops)
    summary_hook = contrib_training.SummaryAtEndHook(
        log_dir=FLAGS.eval_logdir, summary_op=summary_op)
    hooks = [summary_hook]

    num_eval_iters = None
    if FLAGS.max_number_of_evaluations > 0:
      num_eval_iters = FLAGS.max_number_of_evaluations

    if FLAGS.quantize_delay_step >= 0:
      contrib_quantize.create_eval_graph()

    contrib_tfprof.model_analyzer.print_model_analysis(
        tf.get_default_graph(),
        tfprof_options=contrib_tfprof.model_analyzer
        .TRAINABLE_VARS_PARAMS_STAT_OPTIONS)
    contrib_tfprof.model_analyzer.print_model_analysis(
        tf.get_default_graph(),
        tfprof_options=contrib_tfprof.model_analyzer.FLOAT_OPS_OPTIONS)
    contrib_training.evaluate_repeatedly(
        checkpoint_dir=FLAGS.checkpoint_dir,
        master=FLAGS.master,
        eval_ops=list(metrics_to_updates.values()),
        max_number_of_evaluations=num_eval_iters,
        hooks=hooks,
        eval_interval_secs=FLAGS.eval_interval_secs)
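
# Illustrative sketch, not part of the original file: the per-class IoU
# assembled above from streaming TP/FP/FN counters is just
# iou_c = tp_c / (tp_c + fp_c + fn_c), reported as NaN when class c never
# appears in the labels. A minimal NumPy version over flat label/prediction
# arrays, with hypothetical argument names:
def _sketch_per_class_iou(labels, predictions, num_classes, ignore_label=255):
  """Minimal NumPy sketch of the per-class IoU metric, for reference only."""
  import numpy as np
  valid = labels != ignore_label
  labels, predictions = labels[valid], predictions[valid]
  ious = []
  for c in range(num_classes):
    tp = np.sum((labels == c) & (predictions == c))
    fp = np.sum((labels != c) & (predictions == c))
    fn = np.sum((labels == c) & (predictions != c))
    # Mirror the tf.where above: NaN when the class is absent from the labels.
    ious.append(tp / float(tp + fp + fn) if (tp + fn) > 0 else np.nan)
  return np.array(ious)
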
def _build_deeplab(iterator_seg, iterator, outputs_to_num_classes,
                   ignore_label):
  """Builds a clone of supervised DeepLab.

  Args:
    iterator_seg: An iterator of type tf.data.Iterator for images and labels
      of the pixel-level (seg) data.
    iterator: An iterator of type tf.data.Iterator for images and labels of
      the image-level (cls) data.
    outputs_to_num_classes: A map from output type to the number of classes.
      For example, for the task of semantic segmentation with 21 semantic
      classes, we would have outputs_to_num_classes['semantic'] = 21.
    ignore_label: Ignore label.
  """
  if FLAGS.weakly:
    samples = iterator.get_next()
    samples[common.IMAGE] = tf.identity(
        samples[common.IMAGE], name=common.IMAGE)
    samples[common.LABEL] = tf.identity(
        samples[common.LABEL], name=common.LABEL)

  samples_seg = iterator_seg.get_next()
  samples_seg[common.IMAGE] = tf.identity(
      samples_seg[common.IMAGE], name=common.IMAGE + '_seg')
  samples_seg[common.LABEL] = tf.identity(
      samples_seg[common.LABEL], name=common.LABEL + '_seg')

  model_options = common.ModelOptions(
      outputs_to_num_classes=outputs_to_num_classes,
      crop_size=[int(sz) for sz in FLAGS.train_crop_size],
      atrous_rates=FLAGS.atrous_rates,
      output_stride=FLAGS.output_stride)

  ### Cls data
  if FLAGS.weakly:
    _, end_points_cls = feature_extractor.extract_features(
        samples[common.IMAGE],
        output_stride=model_options.output_stride,
        multi_grid=model_options.multi_grid,
        model_variant=model_options.model_variant,
        depth_multiplier=model_options.depth_multiplier,
        divisible_by=model_options.divisible_by,
        weight_decay=FLAGS.weight_decay,
        reuse=tf.AUTO_REUSE,
        is_training=True,
        preprocessed_images_dtype=model_options.preprocessed_images_dtype,
        fine_tune_batch_norm=FLAGS.fine_tune_batch_norm,
        global_pool=True,
        num_classes=outputs_to_num_classes[common.OUTPUT_TYPE] - 1)
    # ResNet beta version has an additional suffix in FLAGS.model_variant,
    # but it shares the same variable names with the original version. Add a
    # special handling here for the beta version ResNet.
    logits_cls = end_points_cls['{}/logits'.format(
        FLAGS.model_variant).replace('_beta', '')]
    logits_cls = tf.reshape(
        logits_cls, [-1, outputs_to_num_classes[common.OUTPUT_TYPE] - 1])
    # Multi-label classification loss (cf. multi-label soft margin loss).
    loss_cls = tf.nn.sigmoid_cross_entropy_with_logits(
        labels=samples['cls_label'], logits=logits_cls)
    loss_cls = tf.reduce_mean(loss_cls)
    loss_cls = tf.identity(loss_cls, name='loss_cls')
    tf.compat.v1.losses.add_loss(loss_cls)

  ### Seg data
  outputs_to_scales_to_logits = model.multi_scale_logits(
      samples_seg[common.IMAGE],
      model_options=model_options,
      image_pyramid=FLAGS.image_pyramid,
      weight_decay=FLAGS.weight_decay,
      is_training=True,
      fine_tune_batch_norm=FLAGS.fine_tune_batch_norm,
      nas_training_hyper_parameters={
          'drop_path_keep_prob': FLAGS.drop_path_keep_prob,
          'total_training_steps': FLAGS.training_number_of_steps,
      })
  # Add name to graph node so we can add to summary.
  output_type_dict = outputs_to_scales_to_logits[common.OUTPUT_TYPE]
  output_type_dict[model.MERGED_LOGITS_SCOPE] = tf.identity(
      output_type_dict[model.MERGED_LOGITS_SCOPE],
      name=common.OUTPUT_TYPE + '_seg')

  for output, num_classes in six.iteritems(outputs_to_num_classes):
    train_utils.add_softmax_cross_entropy_loss_for_each_scale(
        outputs_to_scales_to_logits[output],
        samples_seg[common.LABEL],
        num_classes,
        ignore_label,
        loss_weight=model_options.label_weights,
        upsample_logits=FLAGS.upsample_logits,
        hard_example_mining_step=FLAGS.hard_example_mining_step,
        top_k_percent_pixels=FLAGS.top_k_percent_pixels,
        scope=output)

  ## Sanity check: monitor pixel accuracy
  logits_seg = outputs_to_scales_to_logits[common.OUTPUT_TYPE][
      model.MERGED_LOGITS_SCOPE]
  temp_label = tf.compat.v1.image.resize_nearest_neighbor(
      samples_seg[common.LABEL],
      preprocess_utils.resolve_shape(logits_seg, 4)[1:3])
  temp_label = tf.reshape(temp_label, [-1])
  dump = tf.concat(
      [tf.range(outputs_to_num_classes[common.OUTPUT_TYPE]), temp_label],
      axis=-1)
  _, _, count = tf.unique_with_counts(dump)
  num_pixel_list = count - 1
  # Exclude the ignore region
  num_pixel_list = num_pixel_list[
      :outputs_to_num_classes[common.OUTPUT_TYPE]]
  num_pixel_list = tf.cast(num_pixel_list, tf.float32)
  inverse_ratio = train_utils._div_maybe_zero(1, num_pixel_list)
  inverse_ratio = inverse_ratio / tf.reduce_sum(inverse_ratio)
  # Create a weight mask to balance each class
  weight_mask = tf.einsum(
      '...y,y->...',
      tf.one_hot(
          temp_label,
          outputs_to_num_classes[common.OUTPUT_TYPE],
          dtype=tf.float32),
      inverse_ratio)
  temp_valid = tf.not_equal(temp_label, ignore_label)
  temp_label_valid = tf.boolean_mask(temp_label, temp_valid)
  weight_mask_valid = tf.boolean_mask(weight_mask, temp_valid)
  pred_seg = tf.argmax(logits_seg, axis=-1)
  pred_seg = tf.reshape(pred_seg, [-1])
  acc_seg, acc_seg_op = tf.metrics.mean_per_class_accuracy(
      temp_label_valid,
      tf.boolean_mask(pred_seg, temp_valid),
      outputs_to_num_classes[common.OUTPUT_TYPE],
      weights=weight_mask_valid)
  with tf.control_dependencies([acc_seg_op]):
    acc_seg = tf.identity(acc_seg, name='acc_seg')
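
# Illustrative sketch, not part of the original file: the class-balancing
# trick above counts pixels per class (prepending one dummy pixel of every
# class so absent classes still get a count slot), inverts and normalizes the
# counts, then scatters the result back onto pixels through a one-hot matrix.
# A minimal NumPy version with hypothetical argument names:
def _sketch_balanced_weight_mask(flat_labels, num_classes, ignore_label=255):
  """Minimal NumPy sketch of the inverse-frequency pixel weights above."""
  import numpy as np
  # Same padding trick as the `dump` tensor: prepend tf.range(num_classes).
  padded = np.concatenate([np.arange(num_classes), flat_labels])
  counts = np.bincount(padded)[:num_classes] - 1  # Remove the dummy pixels.
  inverse = np.where(counts > 0, 1. / np.maximum(counts, 1), 0.)
  inverse /= inverse.sum()
  # One weight per pixel; ignore_label pixels get weight 0, matching the
  # all-zero rows tf.one_hot emits for out-of-range indices.
  return np.where(flat_labels == ignore_label, 0.,
                  inverse[np.clip(flat_labels, 0, num_classes - 1)])
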
def main(unused_argv):
  tf.logging.set_verbosity(tf.logging.INFO)

  # Get dataset-dependent information.
  dataset = data_generator.Dataset(
      dataset_name=FLAGS.dataset,
      split_name=FLAGS.vis_split,
      dataset_dir=FLAGS.dataset_dir,
      batch_size=FLAGS.vis_batch_size,
      crop_size=[int(sz) for sz in FLAGS.vis_crop_size],
      min_resize_value=FLAGS.min_resize_value,
      max_resize_value=FLAGS.max_resize_value,
      resize_factor=FLAGS.resize_factor,
      model_variant=FLAGS.model_variant,
      is_training=False,
      should_shuffle=False,
      should_repeat=False)

  train_id_to_eval_id = None
  if dataset.dataset_name == data_generator.get_cityscapes_dataset_name():
    tf.logging.info('Cityscapes requires converting train_id to eval_id.')
    train_id_to_eval_id = _CITYSCAPES_TRAIN_ID_TO_EVAL_ID

  # Prepare for visualization.
  tf.gfile.MakeDirs(FLAGS.vis_logdir)
  save_dir = os.path.join(FLAGS.vis_logdir, _SEMANTIC_PREDICTION_SAVE_FOLDER)
  tf.gfile.MakeDirs(save_dir)
  raw_save_dir = os.path.join(FLAGS.vis_logdir,
                              _RAW_SEMANTIC_PREDICTION_SAVE_FOLDER)
  tf.gfile.MakeDirs(raw_save_dir)

  tf.logging.info('Visualizing on %s set', FLAGS.vis_split)

  with tf.Graph().as_default():
    samples = dataset.get_one_shot_iterator().get_next()

    model_options = common.ModelOptions(
        outputs_to_num_classes={common.OUTPUT_TYPE: dataset.num_of_classes},
        crop_size=[int(sz) for sz in FLAGS.vis_crop_size],
        atrous_rates=FLAGS.atrous_rates,
        output_stride=FLAGS.output_stride)

    if tuple(FLAGS.eval_scales) == (1.0,):
      tf.logging.info('Performing single-scale test.')
      predictions = model.predict_labels(
          samples[common.IMAGE],
          model_options=model_options,
          image_pyramid=FLAGS.image_pyramid)
    else:
      tf.logging.info('Performing multi-scale test.')
      if FLAGS.quantize_delay_step >= 0:
        raise ValueError(
            'Quantize mode is not supported with multi-scale test.')
      predictions = model.predict_labels_multi_scale(
          samples[common.IMAGE],
          model_options=model_options,
          eval_scales=FLAGS.eval_scales,
          add_flipped_images=FLAGS.add_flipped_images)
    predictions = predictions[common.OUTPUT_TYPE]

    if FLAGS.min_resize_value and FLAGS.max_resize_value:
      # Only support batch_size = 1, since we assume the dimensions of the
      # original image after tf.squeeze are [height, width, 3].
      assert FLAGS.vis_batch_size == 1

      # Reverse the resizing and padding operations performed in
      # preprocessing. First, slice the valid regions (i.e., remove the
      # padded region), then resize the predictions back.
      original_image = tf.squeeze(samples[common.ORIGINAL_IMAGE])
      original_image_shape = tf.shape(original_image)
      predictions = tf.slice(
          predictions,
          [0, 0, 0],
          [1, original_image_shape[0], original_image_shape[1]])
      resized_shape = tf.to_int32([
          tf.squeeze(samples[common.HEIGHT]),
          tf.squeeze(samples[common.WIDTH])
      ])
      predictions = tf.squeeze(
          tf.image.resize_images(
              tf.expand_dims(predictions, 3),
              resized_shape,
              method=tf.image.ResizeMethod.NEAREST_NEIGHBOR,
              align_corners=True), 3)

    tf.train.get_or_create_global_step()
    if FLAGS.quantize_delay_step >= 0:
      contrib_quantize.create_eval_graph()

    # Visualize a single checkpoint directly.
    checkpoint_path = FLAGS.checkpoint_dir
    tf.logging.info('Starting visualization at ' +
                    time.strftime('%Y-%m-%d-%H:%M:%S', time.gmtime()))
    tf.logging.info('Visualizing with model %s', checkpoint_path)

    scaffold = tf.train.Scaffold(init_op=tf.global_variables_initializer())
    session_creator = tf.train.ChiefSessionCreator(
        scaffold=scaffold,
        master=FLAGS.master,
        checkpoint_filename_with_path=checkpoint_path)
    with tf.train.MonitoredSession(
        session_creator=session_creator, hooks=None) as sess:
      batch = 0
      image_id_offset = 0
      while not sess.should_stop():
        tf.logging.info('Visualizing batch %d', batch + 1)
        _process_batch(
            sess=sess,
            original_images=samples[common.ORIGINAL_IMAGE],
            semantic_predictions=predictions,
            image_names=samples[common.IMAGE_NAME],
            image_heights=samples[common.HEIGHT],
            image_widths=samples[common.WIDTH],
            image_id_offset=image_id_offset,
            save_dir=save_dir,
            raw_save_dir=raw_save_dir,
            train_id_to_eval_id=train_id_to_eval_id)
        image_id_offset += FLAGS.vis_batch_size
        batch += 1

    tf.logging.info('Finished visualization at ' +
                    time.strftime('%Y-%m-%d-%H:%M:%S', time.gmtime()))
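
# Illustrative sketch, not part of the original file: for one [H, W]
# prediction map, reversing the eval-time preprocessing amounts to cropping
# away the padded border and nearest-neighbor resizing back to the original
# image size (the graph above uses align_corners=True, which this simple
# index map only approximates). Argument names are hypothetical.
def _sketch_unpad_and_resize(prediction, valid_shape, original_shape):
  """Minimal NumPy sketch of the reverse-preprocessing step, for reference."""
  import numpy as np
  valid_h, valid_w = valid_shape    # Size before padding to the crop size.
  orig_h, orig_w = original_shape   # Size before the resize.
  cropped = prediction[:valid_h, :valid_w]
  rows = np.arange(orig_h) * valid_h // orig_h  # Nearest-neighbor index maps.
  cols = np.arange(orig_w) * valid_w // orig_w
  return cropped[rows][:, cols]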