def _build_pseudo_seg(iterator_seg,
                      iterator,
                      outputs_to_num_classes,
                      ignore_label,
                      batch_size=8):
  """Builds a clone of PseudoSeg.

  The graph has two halves:
    * Cls/unlabeled data (`iterator`): an optional multi-label classifier loss
      (weakly setting), a weak-view forward pass whose (optionally
      attention/Grad-CAM fused) predictions become pseudo labels, a
      strong-view forward pass trained against them (consistency loss), and
      class-balanced accuracy monitors.
    * Pixel-level data (`iterator_seg`): the supervised segmentation loss, an
      optional self-attention loss, an optional classifier loss and a pixel
      accuracy monitor.

  Losses are registered via `tf.compat.v1.losses.add_loss` or
  `train_utils.add_softmax_cross_entropy_loss_for_each_scale`; nothing is
  returned.

  Args:
    iterator_seg: An iterator of type tf.data.Iterator for images and labels.
      (seg)
    iterator: An iterator of type tf.data.Iterator for images and labels.
    outputs_to_num_classes: A map from output type to the number of classes.
      For example, for the task of semantic segmentation with 21 semantic
      classes, we would have outputs_to_num_classes['semantic'] = 21.
    ignore_label: Ignore label.
    batch_size: Training batch size for each clone. Kept for backward
      compatibility only: the weak/attention logit split under
      FLAGS.logit_norm now uses the dynamic batch dimension of the weak
      logits, so a mismatched value can no longer mix the two logit sets.
  """
  samples_cls = iterator.get_next()
  samples_cls[common.IMAGE] = tf.identity(samples_cls[common.IMAGE],
                                          name='weak')
  samples_cls['strong'] = tf.identity(samples_cls['strong'], name='strong')
  samples_cls[common.LABEL] = tf.identity(samples_cls[common.LABEL],
                                          name='unlabeled')

  samples_seg = iterator_seg.get_next()
  samples_seg[common.IMAGE] = tf.identity(samples_seg[common.IMAGE],
                                          name=common.IMAGE + '_seg')
  samples_seg[common.LABEL] = tf.identity(samples_seg[common.LABEL],
                                          name=common.LABEL + '_seg')

  model_options = common.ModelOptions(
      outputs_to_num_classes=outputs_to_num_classes,
      crop_size=[int(sz) for sz in FLAGS.train_crop_size],
      atrous_rates=FLAGS.atrous_rates,
      output_stride=FLAGS.output_stride)
  num_out_classes = outputs_to_num_classes[common.OUTPUT_TYPE]

  if FLAGS.att_v2:
    cam_func = train_utils_core.compute_cam_v2
  else:
    cam_func = train_utils_core.compute_cam

  def _classifier_logits(images, is_training):
    """Runs the backbone with a global-pool head; returns (end_points, logits)."""
    _, end_points = feature_extractor.extract_features(
        images,
        output_stride=model_options.output_stride,
        multi_grid=model_options.multi_grid,
        model_variant=model_options.model_variant,
        depth_multiplier=model_options.depth_multiplier,
        divisible_by=model_options.divisible_by,
        weight_decay=FLAGS.weight_decay,
        reuse=tf.AUTO_REUSE,
        is_training=is_training,
        preprocessed_images_dtype=model_options.preprocessed_images_dtype,
        fine_tune_batch_norm=FLAGS.fine_tune_batch_norm,
        global_pool=True,
        num_classes=num_out_classes - 1)
    # ResNet beta version has an additional suffix in FLAGS.model_variant, but
    # it shares the same variable names with original version. Add a special
    # handling here for beta version ResNet.
    logits = end_points['{}/logits'.format(
        FLAGS.model_variant).replace('_beta', '')]
    logits = tf.reshape(logits, [-1, num_out_classes - 1])
    return end_points, logits

  def _add_classifier_loss(labels, logits):
    """Registers the mean multi-label sigmoid cross-entropy as 'loss_cls'."""
    # Seems that people usually use multi-label soft margin loss in PyTorch
    loss_cls = tf.nn.sigmoid_cross_entropy_with_logits(
        labels=labels, logits=logits)
    loss_cls = tf.reduce_mean(loss_cls)
    loss_cls = tf.identity(loss_cls, name='loss_cls')
    tf.compat.v1.losses.add_loss(loss_cls)

  def _merged_logits(images):
    """Returns the merged multi-scale segmentation logits for `images`."""
    outputs_to_scales_to_logits, _ = model.multi_scale_logits(
        images,
        model_options=model_options,
        image_pyramid=FLAGS.image_pyramid,
        weight_decay=FLAGS.weight_decay,
        is_training=True,
        fine_tune_batch_norm=FLAGS.fine_tune_batch_norm,
        nas_training_hyper_parameters={
            'drop_path_keep_prob': FLAGS.drop_path_keep_prob,
            'total_training_steps': FLAGS.training_number_of_steps,
        },
        output_end_points=True)
    return outputs_to_scales_to_logits[common.OUTPUT_TYPE][
        model.MERGED_LOGITS_SCOPE]

  def _class_balanced_weights(flat_label):
    """Per-pixel weights giving every class equal total weight."""
    # Seed with one occurrence of every class id so the first
    # `num_out_classes` counts (unique_with_counts orders by first appearance)
    # line up with ids 0..C-1; subtract that seed occurrence afterwards.
    dump = tf.concat([tf.range(num_out_classes), flat_label], axis=-1)
    _, _, count = tf.unique_with_counts(dump)
    num_pixel_list = count - 1
    # Exclude the ignore region
    num_pixel_list = num_pixel_list[:num_out_classes]
    num_pixel_list = tf.cast(num_pixel_list, tf.float32)
    inverse_ratio = train_utils._div_maybe_zero(1, num_pixel_list)
    inverse_ratio = inverse_ratio / tf.reduce_sum(inverse_ratio)
    # Since tf.metrics.mean_per_class_accuracy does not support weighted
    # average for each class directly, we here convert it to pixel-wise
    # weighted mask to compute weighted average pixel accuracy.
    return tf.einsum(
        '...y,y->...',
        tf.one_hot(flat_label, num_out_classes, dtype=tf.float32),
        inverse_ratio)

  def _balanced_accuracy(labels, predictions, weights, name):
    """Streaming mean per-class accuracy, exposed as a tensor named `name`."""
    acc, acc_op = tf.metrics.mean_per_class_accuracy(
        labels, predictions, num_out_classes, weights=weights)
    with tf.control_dependencies([acc_op]):
      return tf.identity(acc, name=name)

  ### Cls/unlabeled data
  ## 1) If we have image-level label, we train the classifier here
  if FLAGS.weakly:
    curr_samples = samples_cls if FLAGS.cls_with_cls else samples_seg
    _, logits_cls = _classifier_logits(
        curr_samples[common.IMAGE], is_training=True)
    _add_classifier_loss(curr_samples['cls_label'], logits_cls)

  ## 2) Consistency
  with tf.name_scope('cls_weak'):
    logits_weak = _merged_logits(samples_cls[common.IMAGE])
    prob_weak = tf.nn.softmax(logits_weak, axis=-1)
    logits_weak = tf.identity(logits_weak, name='logits_weak')
    # Monitor max score
    tf.identity(tf.reduce_max(prob_weak, axis=-1), name='max_prob_weak')

    # Spatial size of the weak-view logits; masks, CAMs and labels are
    # resized to it.
    weak_hw = preprocess_utils.resolve_shape(logits_weak, 4)[1:3]

    valid_mask_pad = tf.compat.v1.image.resize_nearest_neighbor(
        samples_cls['valid'], weak_hw)
    valid_mask_pad = tf.cast(valid_mask_pad, tf.float32)

    if FLAGS.use_attention:
      # Using inference mode to generate Grad-CAM
      with tf.name_scope('cls_data_cls_inference'):
        end_points_cls, logits_cls = _classifier_logits(
            samples_cls[common.IMAGE], is_training=False)

      # We can only get ground truth image-level label in weakly+semi setting
      if FLAGS.weakly:
        image_level_label = samples_cls['cls_label']
      else:
        prob_cls = tf.sigmoid(logits_cls)
        # TODO(ylzou): Might use a variable threshold for different classes
        pred_cls = tf.greater_equal(prob_cls, 0.5)
        image_level_label = tf.stop_gradient(tf.cast(pred_cls, tf.float32))

      cam_weak, att_logits_weak = cam_func(
          end_points_cls,
          logits_cls,
          image_level_label,
          num_class=num_out_classes,
          use_attention=True,
          attention_dim=FLAGS.attention_dim,
          strides=[int(st) for st in FLAGS.att_strides],
          is_training=True,
          valid_mask=valid_mask_pad,
          net=FLAGS.model_variant.replace('_beta', ''))
      # Upsample att-cam
      att_logits_weak = tf.compat.v1.image.resize_bilinear(
          att_logits_weak, weak_hw, align_corners=True)
      # Monitor vanilla cam
      cam_weak = tf.compat.v1.image.resize_bilinear(
          cam_weak, weak_hw, align_corners=True)
      tf.identity(cam_weak, name='cam_weak')

      att_prob_weak = tf.nn.softmax(att_logits_weak, axis=-1)
      att_logits_weak = tf.identity(att_logits_weak, name='att_logits_weak')
      # Monitor max score
      tf.identity(tf.reduce_max(att_prob_weak, axis=-1),
                  name='max_att_prob_weak')

      # Ensemble
      if FLAGS.pseudo_src == 'att':
        prob_weak = att_prob_weak
      else:
        if FLAGS.logit_norm:
          # Scale both logit sets by one shared RMS so they are comparable
          # before softmax.
          v = tf.concat([logits_weak, att_logits_weak], axis=0)
          all_logits_weak = v * tf.rsqrt(tf.reduce_mean(tf.square(v)) + 1e-8)
          # Split at the real batch size: the first half of the concat is
          # `logits_weak`. The static `batch_size` argument can disagree
          # with the actual batch and would mix the two halves.
          num_weak = tf.shape(logits_weak)[0]
          prob_weak = tf.nn.softmax(all_logits_weak[:num_weak], axis=-1)
          att_prob_weak = tf.nn.softmax(all_logits_weak[num_weak:], axis=-1)
        prob_weak = (prob_weak + att_prob_weak) / 2.
      # Monitor max score
      tf.identity(tf.reduce_max(prob_weak, axis=-1), name='max_prob_avg')

    # Temperature
    if FLAGS.soft_pseudo_label and FLAGS.temperature != 1.0:
      prob_weak = tf.pow(prob_weak, 1. / FLAGS.temperature)
      prob_weak /= tf.reduce_sum(prob_weak, axis=-1, keepdims=True)
      # Monitor max score
      tf.identity(tf.reduce_max(prob_weak, axis=-1), name='max_prob_avg_t')

    # Monitor merged logits
    prob_weak = tf.identity(prob_weak, name='merged_logits')

  with tf.name_scope('cls_strong'):
    logits_strong = _merged_logits(samples_cls['strong'])
    logits_strong = tf.identity(logits_strong, name='logits_strong')

  if FLAGS.pseudo_label_threshold > 0:
    confidence_weak = tf.expand_dims(tf.reduce_max(prob_weak, axis=-1),
                                     axis=-1)
    valid_mask_score = tf.cast(
        tf.greater_equal(confidence_weak, FLAGS.pseudo_label_threshold),
        tf.float32)
    valid_mask = valid_mask_score * valid_mask_pad
  else:
    valid_mask_score = None
    valid_mask = valid_mask_pad
  # Save for visualization
  valid_mask = tf.identity(valid_mask, name='valid_mask')

  logits_strong = tf.reshape(logits_strong, [-1, num_out_classes])
  if not FLAGS.soft_pseudo_label:
    # Hard pseudo labels: argmax of the weak-view probabilities.
    pseudo_label = tf.reshape(tf.argmax(prob_weak, axis=-1), [-1])
    pseudo_label = tf.stop_gradient(pseudo_label)
    loss_consistency = tf.compat.v1.nn.sparse_softmax_cross_entropy_with_logits(
        labels=pseudo_label,
        logits=logits_strong,
        name='consistency_losses')
    pred_pseudo = pseudo_label
  else:
    # Soft pseudo labels: the weak-view distribution itself.
    pseudo_label = tf.reshape(prob_weak, [-1, num_out_classes])
    pseudo_label = tf.stop_gradient(pseudo_label)
    loss_consistency = tf.compat.v1.nn.softmax_cross_entropy_with_logits_v2(
        labels=pseudo_label,
        logits=logits_strong,
        name='consistency_losses')
    pred_pseudo = tf.argmax(pseudo_label, axis=-1)
  loss_consistency = loss_consistency * tf.reshape(valid_mask, [-1])

  # NOTE: When average, we divide by the number of pixels excluding padding
  loss_consistency = tf.reduce_sum(loss_consistency)
  loss_consistency = train_utils._div_maybe_zero(loss_consistency,
                                                 tf.reduce_sum(valid_mask_pad))
  loss_consistency *= FLAGS.unlabeled_weight
  loss_consistency = tf.identity(loss_consistency, 'loss_consistency')
  tf.compat.v1.losses.add_loss(loss_consistency)

  ## 3) Monitor prediction quality
  temp_label = tf.compat.v1.image.resize_nearest_neighbor(
      samples_cls[common.LABEL], weak_hw)
  temp_label = tf.reshape(temp_label, [-1])

  # Get #pixel of each class, so that we can re-weight them for pixel acc.
  weight_mask = _class_balanced_weights(temp_label)

  temp_valid = tf.not_equal(temp_label, ignore_label)
  if valid_mask_score is not None:
    temp_valid_confident = tf.cast(temp_valid, tf.float32) * tf.reshape(
        valid_mask_score, [-1])
    temp_valid_confident = tf.cast(temp_valid_confident, tf.bool)
  else:
    temp_valid_confident = temp_valid

  temp_label_confident = tf.boolean_mask(temp_label, temp_valid_confident)
  temp_label_valid = tf.boolean_mask(temp_label, temp_valid)
  weight_mask_confident = tf.boolean_mask(weight_mask, temp_valid_confident)
  weight_mask_valid = tf.boolean_mask(weight_mask, temp_valid)

  if FLAGS.pseudo_label_threshold > 0:
    _balanced_accuracy(temp_label_confident,
                       tf.boolean_mask(pred_pseudo, temp_valid_confident),
                       weight_mask_confident, 'acc_pseudo')

  pred_weak = tf.cast(tf.argmax(prob_weak, axis=-1), tf.int32)
  pred_weak = tf.reshape(pred_weak, [-1])
  _balanced_accuracy(temp_label_valid, tf.boolean_mask(pred_weak, temp_valid),
                     weight_mask_valid, 'acc_weak')

  pred_strong = tf.cast(tf.argmax(logits_strong, axis=-1), tf.int32)
  pred_strong = tf.reshape(pred_strong, [-1])
  # For all pixels
  _balanced_accuracy(temp_label_valid,
                     tf.boolean_mask(pred_strong, temp_valid),
                     weight_mask_valid, 'acc_strong')
  # For confident pixels
  if FLAGS.pseudo_label_threshold > 0:
    _balanced_accuracy(temp_label_confident,
                       tf.boolean_mask(pred_strong, temp_valid_confident),
                       weight_mask_confident, 'acc_strong_confident')

  # Fraction of non-padding pixels kept by the confidence mask. Guarded
  # against an all-padding batch, which previously produced 0/0 = NaN.
  valid_ratio = train_utils._div_maybe_zero(tf.reduce_sum(valid_mask),
                                            tf.reduce_sum(valid_mask_pad))
  tf.identity(valid_ratio, name='valid_ratio')

  ### Pixel-level data
  ## 1) Segmentation
  outputs_to_scales_to_logits = model.multi_scale_logits(
      samples_seg[common.IMAGE],
      model_options=model_options,
      image_pyramid=FLAGS.image_pyramid,
      weight_decay=FLAGS.weight_decay,
      is_training=True,
      fine_tune_batch_norm=FLAGS.fine_tune_batch_norm,
      nas_training_hyper_parameters={
          'drop_path_keep_prob': FLAGS.drop_path_keep_prob,
          'total_training_steps': FLAGS.training_number_of_steps,
      })

  # Add name to graph node so we can add to summary.
  output_type_dict = outputs_to_scales_to_logits[common.OUTPUT_TYPE]
  output_type_dict[model.MERGED_LOGITS_SCOPE] = tf.identity(
      output_type_dict[model.MERGED_LOGITS_SCOPE],
      name=common.OUTPUT_TYPE + '_seg')

  for output, num_classes in six.iteritems(outputs_to_num_classes):
    train_utils.add_softmax_cross_entropy_loss_for_each_scale(
        outputs_to_scales_to_logits[output],
        samples_seg[common.LABEL],
        num_classes,
        ignore_label,
        loss_weight=model_options.label_weights,
        upsample_logits=FLAGS.upsample_logits,
        hard_example_mining_step=FLAGS.hard_example_mining_step,
        top_k_percent_pixels=FLAGS.top_k_percent_pixels,
        scope=output)

  logits_seg = output_type_dict[model.MERGED_LOGITS_SCOPE]
  # Spatial size of the seg logits. Seg-side masks and labels are resized to
  # this rather than to the cls logits' size, which only matched when both
  # datasets share a crop size.
  seg_hw = preprocess_utils.resolve_shape(logits_seg, 4)[1:3]

  ## 2) Train self-attention module
  if FLAGS.use_attention:
    seg_valid_mask_pad = tf.compat.v1.image.resize_nearest_neighbor(
        samples_seg['valid'], seg_hw)
    seg_valid_mask_pad = tf.cast(seg_valid_mask_pad, tf.float32)

    with tf.name_scope('seg_data_cls'):
      end_points_cls, logits_cls = _classifier_logits(
          samples_seg[common.IMAGE], is_training=True)

    _, att_logits_labeled = cam_func(
        end_points_cls,
        logits_cls,
        samples_seg['cls_label'],
        num_class=num_out_classes,
        use_attention=True,
        attention_dim=FLAGS.attention_dim,
        strides=[int(st) for st in FLAGS.att_strides],
        is_training=True,
        valid_mask=seg_valid_mask_pad,
        net=FLAGS.model_variant.replace('_beta', ''))

    # Loss
    train_utils.add_softmax_cross_entropy_loss_for_each_scale(
        {'self-attention_logits': att_logits_labeled},
        samples_seg[common.LABEL],
        num_out_classes,
        ignore_label,
        loss_weight=model_options.label_weights,
        upsample_logits=FLAGS.upsample_logits,
        hard_example_mining_step=FLAGS.hard_example_mining_step,
        top_k_percent_pixels=FLAGS.top_k_percent_pixels,
        scope='self-attention')
    tf.identity(att_logits_labeled, name='att_logits_labeled')

    ## 3) If no image-level label, convert pixel-level label to train
    ## classifier. This needs the seg-batch classifier logits built just
    ## above. Outside this branch, `logits_cls` is undefined when
    ## use_attention=False and weakly=False, and the build raised NameError.
    if not FLAGS.weakly:
      _add_classifier_loss(samples_seg['cls_label'], logits_cls)

  ## 4) Sanity check. Monitor pixel accuracy
  temp_label = tf.compat.v1.image.resize_nearest_neighbor(
      samples_seg[common.LABEL], seg_hw)
  temp_label = tf.reshape(temp_label, [-1])
  # Create weight mask to balance each class
  weight_mask = _class_balanced_weights(temp_label)

  temp_valid = tf.not_equal(temp_label, ignore_label)
  temp_label_valid = tf.boolean_mask(temp_label, temp_valid)
  weight_mask_valid = tf.boolean_mask(weight_mask, temp_valid)

  pred_seg = tf.reshape(tf.argmax(logits_seg, axis=-1), [-1])
  _balanced_accuracy(temp_label_valid, tf.boolean_mask(pred_seg, temp_valid),
                     weight_mask_valid, 'acc_seg')
def main(unused_argv):
  """Repeatedly evaluates checkpoints from FLAGS.checkpoint_dir.

  Builds a single-scale evaluation graph. It reports segmentation mIoU overall
  and per class. When FLAGS.weakly is set, it also reports multi-label
  classification accuracy, precision and recall (sigmoid threshold 0.5).
  Summaries are written to FLAGS.eval_logdir.

  Raises:
    NotImplementedError: If FLAGS.eval_scales requests multi-scale testing.
  """
  tf.logging.set_verbosity(tf.logging.INFO)

  dataset = data_generator.Dataset(
      dataset_name=FLAGS.dataset,
      split_name=FLAGS.eval_split,
      dataset_dir=FLAGS.dataset_dir,
      batch_size=FLAGS.eval_batch_size,
      crop_size=[int(sz) for sz in FLAGS.eval_crop_size],
      min_resize_value=FLAGS.min_resize_value,
      max_resize_value=FLAGS.max_resize_value,
      resize_factor=FLAGS.resize_factor,
      model_variant=FLAGS.model_variant,
      num_readers=2,
      is_training=False,
      should_shuffle=False,
      should_repeat=False,
      with_cls=True,
      cls_only=False,
      output_valid=True)

  tf.gfile.MakeDirs(FLAGS.eval_logdir)
  tf.logging.info('Evaluating on %s set', FLAGS.eval_split)

  with tf.Graph().as_default():
    samples = dataset.get_one_shot_iterator().get_next()

    model_options = common.ModelOptions(
        outputs_to_num_classes={common.OUTPUT_TYPE: dataset.num_of_classes},
        crop_size=[int(sz) for sz in FLAGS.eval_crop_size],
        atrous_rates=FLAGS.atrous_rates,
        output_stride=FLAGS.output_stride)

    # Set shape in order for tf.contrib.tfprof.model_analyzer to work properly.
    samples[common.IMAGE].set_shape([
        FLAGS.eval_batch_size,
        int(FLAGS.eval_crop_size[0]),
        int(FLAGS.eval_crop_size[1]), 3
    ])
    if tuple(FLAGS.eval_scales) == (1.0,):
      tf.logging.info('Performing single-scale test.')
      predictions = model.predict_labels(
          samples[common.IMAGE],
          model_options,
          image_pyramid=FLAGS.image_pyramid)
    else:
      tf.logging.info('Performing multi-scale test.')
      raise NotImplementedError('Multi-scale is not supported yet!')

    metric_map = {}
    ## Extract cls logits
    if FLAGS.weakly:
      _, end_points = feature_extractor.extract_features(
          samples[common.IMAGE],
          output_stride=model_options.output_stride,
          multi_grid=model_options.multi_grid,
          model_variant=model_options.model_variant,
          depth_multiplier=model_options.depth_multiplier,
          divisible_by=model_options.divisible_by,
          reuse=tf.AUTO_REUSE,
          is_training=False,
          preprocessed_images_dtype=model_options.preprocessed_images_dtype,
          global_pool=True,
          num_classes=dataset.num_of_classes - 1)
      # ResNet beta version has an additional suffix in FLAGS.model_variant,
      # but it shares the same variable names with original version. Add a
      # special handling here for beta version ResNet.
      logits = end_points['{}/logits'.format(
          FLAGS.model_variant).replace('_beta', '')]
      logits = tf.reshape(logits, [-1, dataset.num_of_classes - 1])
      cls_pred = tf.sigmoid(logits)

      # Multi-label classification evaluation
      cls_label = samples['cls_label']
      cls_pred = tf.cast(tf.greater_equal(cls_pred, 0.5), tf.int32)

      ## For classification
      metric_map['eval/cls_overall'] = tf.metrics.accuracy(
          labels=cls_label, predictions=cls_pred)
      metric_map['eval/cls_precision'] = tf.metrics.precision(
          labels=cls_label, predictions=cls_pred)
      metric_map['eval/cls_recall'] = tf.metrics.recall(
          labels=cls_label, predictions=cls_pred)

    ## For segmentation branch eval
    predictions = predictions[common.OUTPUT_TYPE]
    predictions = tf.reshape(predictions, shape=[-1])
    labels = tf.reshape(samples[common.LABEL], shape=[-1])
    weights = tf.cast(tf.not_equal(labels, dataset.ignore_label), tf.float32)

    # Set ignore_label regions to label 0, because metrics.mean_iou requires
    # range of labels = [0, dataset.num_classes). Note the ignore_label regions
    # are not evaluated since the corresponding regions contain weights = 0.
    labels = tf.where(
        tf.equal(labels, dataset.ignore_label), tf.zeros_like(labels), labels)

    predictions_tag = 'miou'
    # Define the evaluation metric.
    num_classes = dataset.num_of_classes

    ## For segmentation
    metric_map['eval/%s_overall' % predictions_tag] = tf.metrics.mean_iou(
        labels=labels,
        predictions=predictions,
        num_classes=num_classes,
        weights=weights)
    # IoU for each class.
    one_hot_predictions = tf.one_hot(predictions, num_classes)
    one_hot_predictions = tf.reshape(one_hot_predictions, [-1, num_classes])
    one_hot_labels = tf.one_hot(labels, num_classes)
    one_hot_labels = tf.reshape(one_hot_labels, [-1, num_classes])
    for c in range(num_classes):
      predictions_tag_c = '%s_class_%d' % (predictions_tag, c)
      tp, tp_op = tf.metrics.true_positives(
          labels=one_hot_labels[:, c],
          predictions=one_hot_predictions[:, c],
          weights=weights)
      fp, fp_op = tf.metrics.false_positives(
          labels=one_hot_labels[:, c],
          predictions=one_hot_predictions[:, c],
          weights=weights)
      fn, fn_op = tf.metrics.false_negatives(
          labels=one_hot_labels[:, c],
          predictions=one_hot_predictions[:, c],
          weights=weights)
      tp_fp_fn_op = tf.group(tp_op, fp_op, fn_op)
      # Classes absent from the ground truth get NaN. `np.nan` is used
      # because the `np.NaN` alias was removed in NumPy 2.0.
      iou = tf.where(
          tf.greater(tp + fn, 0.0), tp / (tp + fn + fp),
          tf.constant(np.nan))
      metric_map['eval/%s' % predictions_tag_c] = (iou, tp_fp_fn_op)

    (metrics_to_values,
     metrics_to_updates) = contrib_metrics.aggregate_metric_map(metric_map)

    summary_ops = []
    for metric_name, metric_value in six.iteritems(metrics_to_values):
      op = tf.summary.scalar(metric_name, metric_value)
      op = tf.Print(op, [metric_value], metric_name)
      summary_ops.append(op)

    summary_op = tf.summary.merge(summary_ops)
    summary_hook = contrib_training.SummaryAtEndHook(
        log_dir=FLAGS.eval_logdir, summary_op=summary_op)
    hooks = [summary_hook]

    num_eval_iters = None
    if FLAGS.max_number_of_evaluations > 0:
      num_eval_iters = FLAGS.max_number_of_evaluations

    if FLAGS.quantize_delay_step >= 0:
      contrib_quantize.create_eval_graph()

    contrib_tfprof.model_analyzer.print_model_analysis(
        tf.get_default_graph(),
        tfprof_options=contrib_tfprof.model_analyzer
        .TRAINABLE_VARS_PARAMS_STAT_OPTIONS)
    contrib_tfprof.model_analyzer.print_model_analysis(
        tf.get_default_graph(),
        tfprof_options=contrib_tfprof.model_analyzer.FLOAT_OPS_OPTIONS)
    contrib_training.evaluate_repeatedly(
        checkpoint_dir=FLAGS.checkpoint_dir,
        master=FLAGS.master,
        eval_ops=list(metrics_to_updates.values()),
        max_number_of_evaluations=num_eval_iters,
        hooks=hooks,
        eval_interval_secs=FLAGS.eval_interval_secs)
def _build_deeplab(iterator_seg, iterator, outputs_to_num_classes,
                   ignore_label):
  """Builds a clone of supervised DeepLab.

  Registers up to two training losses and one monitor:
    * 'loss_cls' — mean multi-label sigmoid cross-entropy of a global-pool
      classifier head on `iterator` samples (only when FLAGS.weakly).
    * per-scale softmax cross-entropy segmentation loss on `iterator_seg`.
    * 'acc_seg' — class-balanced streaming pixel accuracy on the seg batch.

  Args:
    iterator_seg: A tf.data.Iterator yielding pixel-labelled images (seg).
    iterator: A tf.data.Iterator yielding image-level labelled images; only
      read when FLAGS.weakly is set.
    outputs_to_num_classes: A map from output type to the number of classes.
      For example, for the task of semantic segmentation with 21 semantic
      classes, we would have outputs_to_num_classes['semantic'] = 21.
    ignore_label: Label value excluded from the loss and the accuracy.
  """
  semantic_classes = outputs_to_num_classes[common.OUTPUT_TYPE]

  if FLAGS.weakly:
    cls_batch = iterator.get_next()
    for key in (common.IMAGE, common.LABEL):
      cls_batch[key] = tf.identity(cls_batch[key], name=key)

  seg_batch = iterator_seg.get_next()
  for key in (common.IMAGE, common.LABEL):
    seg_batch[key] = tf.identity(seg_batch[key], name=key + '_seg')

  model_options = common.ModelOptions(
      outputs_to_num_classes=outputs_to_num_classes,
      crop_size=[int(sz) for sz in FLAGS.train_crop_size],
      atrous_rates=FLAGS.atrous_rates,
      output_stride=FLAGS.output_stride)

  ### Cls data
  if FLAGS.weakly:
    _, cls_end_points = feature_extractor.extract_features(
        cls_batch[common.IMAGE],
        output_stride=model_options.output_stride,
        multi_grid=model_options.multi_grid,
        model_variant=model_options.model_variant,
        depth_multiplier=model_options.depth_multiplier,
        divisible_by=model_options.divisible_by,
        weight_decay=FLAGS.weight_decay,
        reuse=tf.AUTO_REUSE,
        is_training=True,
        preprocessed_images_dtype=model_options.preprocessed_images_dtype,
        fine_tune_batch_norm=FLAGS.fine_tune_batch_norm,
        global_pool=True,
        num_classes=semantic_classes - 1)
    # The '_beta' ResNet variants reuse the plain variant's variable names, so
    # the suffix is stripped when looking up the logits end point.
    logits_key = '{}/logits'.format(FLAGS.model_variant).replace('_beta', '')
    cls_logits = tf.reshape(cls_end_points[logits_key],
                            [-1, semantic_classes - 1])
    # Multi-label soft-margin style loss (sigmoid cross-entropy per class).
    per_class_loss = tf.nn.sigmoid_cross_entropy_with_logits(
        labels=cls_batch['cls_label'], logits=cls_logits)
    cls_loss = tf.identity(tf.reduce_mean(per_class_loss), name='loss_cls')
    tf.compat.v1.losses.add_loss(cls_loss)

  ### Seg data
  scale_logits = model.multi_scale_logits(
      seg_batch[common.IMAGE],
      model_options=model_options,
      image_pyramid=FLAGS.image_pyramid,
      weight_decay=FLAGS.weight_decay,
      is_training=True,
      fine_tune_batch_norm=FLAGS.fine_tune_batch_norm,
      nas_training_hyper_parameters={
          'drop_path_keep_prob': FLAGS.drop_path_keep_prob,
          'total_training_steps': FLAGS.training_number_of_steps,
      })

  # Name the merged logits so summaries can find them.
  semantic_logits = scale_logits[common.OUTPUT_TYPE]
  semantic_logits[model.MERGED_LOGITS_SCOPE] = tf.identity(
      semantic_logits[model.MERGED_LOGITS_SCOPE],
      name=common.OUTPUT_TYPE + '_seg')

  for output, num_classes in outputs_to_num_classes.items():
    train_utils.add_softmax_cross_entropy_loss_for_each_scale(
        scale_logits[output],
        seg_batch[common.LABEL],
        num_classes,
        ignore_label,
        loss_weight=model_options.label_weights,
        upsample_logits=FLAGS.upsample_logits,
        hard_example_mining_step=FLAGS.hard_example_mining_step,
        top_k_percent_pixels=FLAGS.top_k_percent_pixels,
        scope=output)

  ## Sanity check: class-balanced pixel accuracy on the seg batch.
  merged_logits = semantic_logits[model.MERGED_LOGITS_SCOPE]
  flat_label = tf.reshape(
      tf.compat.v1.image.resize_nearest_neighbor(
          seg_batch[common.LABEL],
          preprocess_utils.resolve_shape(merged_logits, 4)[1:3]), [-1])

  # Seed with one occurrence of each class id so the first `semantic_classes`
  # counts (ordered by first appearance) line up with ids 0..C-1; the extra
  # occurrence is subtracted and anything beyond C (the ignore label) dropped.
  seeded = tf.concat([tf.range(semantic_classes), flat_label], axis=-1)
  _, _, counts = tf.unique_with_counts(seeded)
  pixels_per_class = tf.cast((counts - 1)[:semantic_classes], tf.float32)
  class_weights = train_utils._div_maybe_zero(1, pixels_per_class)
  class_weights /= tf.reduce_sum(class_weights)
  # Broadcast per-class weights to a per-pixel weight mask.
  pixel_weights = tf.einsum(
      '...y,y->...',
      tf.one_hot(flat_label, semantic_classes, dtype=tf.float32),
      class_weights)

  keep = tf.not_equal(flat_label, ignore_label)
  kept_labels = tf.boolean_mask(flat_label, keep)
  kept_weights = tf.boolean_mask(pixel_weights, keep)

  flat_pred = tf.reshape(tf.argmax(merged_logits, axis=-1), [-1])
  acc_seg, acc_seg_op = tf.metrics.mean_per_class_accuracy(
      kept_labels,
      tf.boolean_mask(flat_pred, keep),
      semantic_classes,
      weights=kept_weights)
  with tf.control_dependencies([acc_seg_op]):
    tf.identity(acc_seg, name='acc_seg')
def extract_features(images,
                     model_options,
                     weight_decay=0.0001,
                     reuse=None,
                     is_training=False,
                     fine_tune_batch_norm=False,
                     nas_training_hyper_parameters=None):
  """Extracts backbone features and, optionally, the ASPP / DPC head.

  Args:
    images: A tensor of size [batch, height, width, channels].
    model_options: A ModelOptions instance to configure models.
    weight_decay: The weight decay for model variables.
    reuse: Reuse the model variables or not.
    is_training: Is training or not.
    fine_tune_batch_norm: Fine-tune the batch norm parameters or not.
    nas_training_hyper_parameters: A dictionary storing hyper-parameters for
      training nas models. Its keys are:
      - `drop_path_keep_prob`: Probability to keep each path in the cell when
        training.
      - `total_training_steps`: Total training steps to help drop path
        probability calculation.

  Returns:
    concat_logits: A tensor of size [batch, feature_height, feature_width,
      feature_channels], where feature_height/feature_width are determined by
      the images height/width and output_stride. These are the raw backbone
      features when model_options.aspp_with_batch_norm is False.
    end_points: A dictionary from components of the network to the
      corresponding activation.
  """
  features, end_points = feature_extractor.extract_features(
      images,
      output_stride=model_options.output_stride,
      multi_grid=model_options.multi_grid,
      model_variant=model_options.model_variant,
      depth_multiplier=model_options.depth_multiplier,
      divisible_by=model_options.divisible_by,
      weight_decay=weight_decay,
      reuse=reuse,
      is_training=is_training,
      preprocessed_images_dtype=model_options.preprocessed_images_dtype,
      fine_tune_batch_norm=fine_tune_batch_norm,
      nas_architecture_options=model_options.nas_architecture_options,
      nas_training_hyper_parameters=nas_training_hyper_parameters,
      use_bounded_activation=model_options.use_bounded_activation)

  # No ASPP batch norm: hand back the backbone output untouched.
  if not model_options.aspp_with_batch_norm:
    return features, end_points

  # Dense prediction cell head, when configured.
  if model_options.dense_prediction_cell_config is not None:
    tf.logging.info('Using dense prediction cell config.')
    dpc = dense_prediction_cell.DensePredictionCell(
        config=model_options.dense_prediction_cell_config,
        hparams={'conv_rate_multiplier': 16 // model_options.output_stride})
    dpc_logits = dpc.build_cell(
        features,
        output_stride=model_options.output_stride,
        crop_size=model_options.crop_size,
        image_pooling_crop_size=model_options.image_pooling_crop_size,
        weight_decay=weight_decay,
        reuse=reuse,
        is_training=is_training,
        fine_tune_batch_norm=fine_tune_batch_norm)
    return dpc_logits, end_points

  # DeepLabv3 ASPP head. It could be expressed as a particular dense
  # prediction cell, but is kept explicit for backward compatibility.
  batch_norm_params = utils.get_batch_norm_params(
      decay=0.9997,
      epsilon=1e-5,
      scale=True,
      is_training=(is_training and fine_tune_batch_norm),
      sync_batch_norm_method=model_options.sync_batch_norm_method)
  batch_norm = utils.get_batch_norm_fn(model_options.sync_batch_norm_method)
  if model_options.use_bounded_activation:
    activation_fn = tf.nn.relu6
  else:
    activation_fn = tf.nn.relu
  depth = model_options.aspp_convs_filters
  use_se = model_options.aspp_with_squeeze_and_excitation
  inv_stride = 1. / model_options.output_stride

  def _image_pooling_branch():
    """Pools image-level context, projects it and resizes it to feature size."""
    if model_options.crop_size is not None:
      pool_crop = model_options.image_pooling_crop_size
      # If image_pooling_crop_size is not specified, use crop_size.
      if pool_crop is None:
        pool_crop = model_options.crop_size
      pool_height = scale_dimension(pool_crop[0], inv_stride)
      pool_width = scale_dimension(pool_crop[1], inv_stride)
      pooled = slim.avg_pool2d(
          features, [pool_height, pool_width],
          model_options.image_pooling_stride,
          padding='VALID')
      out_height = scale_dimension(model_options.crop_size[0], inv_stride)
      out_width = scale_dimension(model_options.crop_size[1], inv_stride)
    else:
      # Without a crop size, fall back to global average pooling.
      out_height = tf.shape(features)[1]
      out_width = tf.shape(features)[2]
      pooled = tf.reduce_mean(features, axis=[1, 2], keepdims=True)

    if use_se:
      # Squeeze-and-excitation gate: sigmoid-like activation, no batch norm.
      if model_options.image_se_uses_qsigmoid:
        pool_act_fn = utils.q_sigmoid
      else:
        pool_act_fn = tf.nn.sigmoid
      pool_norm_fn = None
    else:
      pool_act_fn = tf.nn.relu
      pool_norm_fn = batch_norm

    pooled = slim.conv2d(
        pooled,
        depth,
        1,
        activation_fn=pool_act_fn,
        normalizer_fn=pool_norm_fn,
        scope=IMAGE_POOLING_SCOPE)
    pooled = _resize_bilinear(pooled, [out_height, out_width], pooled.dtype)
    # Only Python ints can become static shape entries; dynamic ones stay None.
    static_height = None if isinstance(out_height, tf.Tensor) else out_height
    static_width = None if isinstance(out_width, tf.Tensor) else out_width
    pooled.set_shape([None, static_height, static_width, depth])
    return pooled

  with slim.arg_scope(
      [slim.conv2d, slim.separable_conv2d],
      weights_regularizer=slim.l2_regularizer(weight_decay),
      activation_fn=activation_fn,
      normalizer_fn=batch_norm,
      padding='SAME',
      stride=1,
      reuse=reuse):
    with slim.arg_scope([batch_norm], **batch_norm_params):
      branches = []
      image_feature = None
      if model_options.add_image_level_feature:
        image_feature = _image_pooling_branch()
        # Under squeeze-and-excitation the pooled feature gates the output
        # instead of being concatenated as a branch.
        if not use_se:
          branches.append(image_feature)

      # Employ a 1x1 convolution.
      branches.append(
          slim.conv2d(features, depth, 1, scope=ASPP_SCOPE + str(0)))

      # Employ 3x3 convolutions with different atrous rates.
      for branch_idx, rate in enumerate(model_options.atrous_rates or [], 1):
        branch_scope = ASPP_SCOPE + str(branch_idx)
        if model_options.aspp_with_separable_conv:
          atrous_branch = split_separable_conv2d(
              features,
              filters=depth,
              rate=rate,
              weight_decay=weight_decay,
              scope=branch_scope)
        else:
          atrous_branch = slim.conv2d(
              features, depth, 3, rate=rate, scope=branch_scope)
        branches.append(atrous_branch)

      # Merge branch logits.
      concat_logits = tf.concat(branches, 3)
      if model_options.aspp_with_concat_projection:
        concat_logits = slim.conv2d(
            concat_logits, depth, 1, scope=CONCAT_PROJECTION_SCOPE)
        concat_logits = slim.dropout(
            concat_logits,
            keep_prob=0.9,
            is_training=is_training,
            scope=CONCAT_PROJECTION_SCOPE + '_dropout')
      if model_options.add_image_level_feature and use_se:
        concat_logits = tf.math.multiply(
            concat_logits, image_feature, name='aspp_multiply_image_feature')

  return concat_logits, end_points