Code Example #1
def _build_pseudo_seg(iterator_seg,
                      iterator,
                      outputs_to_num_classes,
                      ignore_label,
                      batch_size=8):
    """Builds a clone of PseudoSeg.

  Args:
    iterator_seg: An iterator of type tf.data.Iterator for images and labels
      (seg).
    iterator: An iterator of type tf.data.Iterator for images and labels.
    outputs_to_num_classes: A map from output type to the number of classes. For
      example, for the task of semantic segmentation with 21 semantic classes,
      we would have outputs_to_num_classes['semantic'] = 21.
    ignore_label: Ignore label.
    batch_size: Training batch size for each clone.
  """
    samples_cls = iterator.get_next()
    samples_cls[common.IMAGE] = tf.identity(samples_cls[common.IMAGE],
                                            name='weak')
    samples_cls['strong'] = tf.identity(samples_cls['strong'], name='strong')
    samples_cls[common.LABEL] = tf.identity(samples_cls[common.LABEL],
                                            name='unlabeled')

    samples_seg = iterator_seg.get_next()
    samples_seg[common.IMAGE] = tf.identity(samples_seg[common.IMAGE],
                                            name=common.IMAGE + '_seg')
    samples_seg[common.LABEL] = tf.identity(samples_seg[common.LABEL],
                                            name=common.LABEL + '_seg')

    model_options = common.ModelOptions(
        outputs_to_num_classes=outputs_to_num_classes,
        crop_size=[int(sz) for sz in FLAGS.train_crop_size],
        atrous_rates=FLAGS.atrous_rates,
        output_stride=FLAGS.output_stride)

    if FLAGS.att_v2:
        cam_func = train_utils_core.compute_cam_v2
    else:
        cam_func = train_utils_core.compute_cam

    ### Cls/unlabeled data
    ## 1) If we have image-level label, we train the classifier here
    if FLAGS.weakly:
        if FLAGS.cls_with_cls:
            curr_samples = samples_cls
        else:
            curr_samples = samples_seg

        _, end_points_cls = feature_extractor.extract_features(
            curr_samples[common.IMAGE],
            output_stride=model_options.output_stride,
            multi_grid=model_options.multi_grid,
            model_variant=model_options.model_variant,
            depth_multiplier=model_options.depth_multiplier,
            divisible_by=model_options.divisible_by,
            weight_decay=FLAGS.weight_decay,
            reuse=tf.AUTO_REUSE,
            is_training=True,
            preprocessed_images_dtype=model_options.preprocessed_images_dtype,
            fine_tune_batch_norm=FLAGS.fine_tune_batch_norm,
            global_pool=True,
            num_classes=outputs_to_num_classes[common.OUTPUT_TYPE] - 1)

        # The beta version of ResNet has an additional suffix in
        # FLAGS.model_variant, but it shares the same variable names with the
        # original version. Add special handling here for the beta version.
        logits_cls = end_points_cls['{}/logits'.format(
            FLAGS.model_variant).replace('_beta', '')]
        logits_cls = tf.reshape(
            logits_cls, [-1, outputs_to_num_classes[common.OUTPUT_TYPE] - 1])
        # Seems that people usually use multi-label soft margin loss in PyTorch
        loss_cls = tf.nn.sigmoid_cross_entropy_with_logits(
            labels=curr_samples['cls_label'], logits=logits_cls)
        loss_cls = tf.reduce_mean(loss_cls)
        loss_cls = tf.identity(loss_cls, name='loss_cls')
        tf.compat.v1.losses.add_loss(loss_cls)

    ## 2) Consistency
    with tf.name_scope('cls_weak'):
        outputs_to_scales_to_logits, _ = model.multi_scale_logits(
            samples_cls[common.IMAGE],
            model_options=model_options,
            image_pyramid=FLAGS.image_pyramid,
            weight_decay=FLAGS.weight_decay,
            is_training=True,
            fine_tune_batch_norm=FLAGS.fine_tune_batch_norm,
            nas_training_hyper_parameters={
                'drop_path_keep_prob': FLAGS.drop_path_keep_prob,
                'total_training_steps': FLAGS.training_number_of_steps,
            },
            output_end_points=True)
        logits_weak = outputs_to_scales_to_logits[common.OUTPUT_TYPE][
            model.MERGED_LOGITS_SCOPE]

    prob_weak = tf.nn.softmax(logits_weak, axis=-1)
    logits_weak = tf.identity(logits_weak, name='logits_weak')
    # Monitor max score
    max_prob_weak = tf.reduce_max(prob_weak, axis=-1)
    max_prob_weak = tf.identity(max_prob_weak, name='max_prob_weak')

    valid_mask_pad = samples_cls['valid']
    valid_mask_pad = tf.compat.v1.image.resize_nearest_neighbor(
        valid_mask_pad,
        preprocess_utils.resolve_shape(logits_weak, 4)[1:3])
    valid_mask_pad = tf.cast(valid_mask_pad, tf.float32)

    if FLAGS.use_attention:
        # Using inference mode to generate Grad-CAM
        with tf.name_scope('cls_data_cls_inference'):
            _, end_points_cls = feature_extractor.extract_features(
                samples_cls[common.IMAGE],
                output_stride=model_options.output_stride,
                multi_grid=model_options.multi_grid,
                model_variant=model_options.model_variant,
                depth_multiplier=model_options.depth_multiplier,
                divisible_by=model_options.divisible_by,
                weight_decay=FLAGS.weight_decay,
                reuse=tf.AUTO_REUSE,
                is_training=False,
                preprocessed_images_dtype=model_options.
                preprocessed_images_dtype,
                fine_tune_batch_norm=FLAGS.fine_tune_batch_norm,
                global_pool=True,
                num_classes=outputs_to_num_classes[common.OUTPUT_TYPE] - 1)

        logits_cls = end_points_cls['{}/logits'.format(
            FLAGS.model_variant).replace('_beta', '')]
        logits_cls = tf.reshape(
            logits_cls, [-1, outputs_to_num_classes[common.OUTPUT_TYPE] - 1])

        # We can only get a ground-truth image-level label in the weakly + semi setting
        if FLAGS.weakly:
            image_level_label = samples_cls['cls_label']
        else:
            prob_cls = tf.sigmoid(logits_cls)
            # TODO(ylzou): Might use a variable threshold for different classes
            pred_cls = tf.greater_equal(prob_cls, 0.5)
            image_level_label = tf.stop_gradient(tf.cast(pred_cls, tf.float32))

        cam_weak, att_cam_weak = cam_func(
            end_points_cls,
            logits_cls,
            image_level_label,
            num_class=outputs_to_num_classes[common.OUTPUT_TYPE],
            use_attention=True,
            attention_dim=FLAGS.attention_dim,
            strides=[int(st) for st in FLAGS.att_strides],
            is_training=True,
            valid_mask=valid_mask_pad,
            net=FLAGS.model_variant.replace('_beta', ''))
        att_logits_weak = att_cam_weak
        # Upsample att-cam
        att_logits_weak = tf.compat.v1.image.resize_bilinear(
            att_logits_weak,
            preprocess_utils.resolve_shape(logits_weak, 4)[1:3],
            align_corners=True)
        # Monitor vanilla cam
        cam_weak = tf.compat.v1.image.resize_bilinear(
            cam_weak,
            preprocess_utils.resolve_shape(logits_weak, 4)[1:3],
            align_corners=True)
        cam_weak = tf.identity(cam_weak, name='cam_weak')

        att_prob_weak = tf.nn.softmax(att_logits_weak, axis=-1)
        att_logits_weak = tf.identity(att_logits_weak, name='att_logits_weak')
        # Monitor max score
        max_att_prob_weak = tf.reduce_max(att_prob_weak, axis=-1)
        max_att_prob_weak = tf.identity(max_att_prob_weak,
                                        name='max_att_prob_weak')

        # Ensemble
        if FLAGS.pseudo_src == 'att':
            prob_weak = att_prob_weak
        else:
            if FLAGS.logit_norm:
                v = tf.concat([logits_weak, att_logits_weak], axis=0)
                all_logits_weak = v * tf.rsqrt(
                    tf.reduce_mean(tf.square(v)) + 1e-8)
                scaled_logits_weak = all_logits_weak[:batch_size]
                prob_weak = tf.nn.softmax(scaled_logits_weak, axis=-1)
                scaled_att_logits_weak = all_logits_weak[batch_size:]
                att_prob_weak = tf.nn.softmax(scaled_att_logits_weak, axis=-1)
            prob_weak = (prob_weak + att_prob_weak) / 2.

        # Monitor max score
        max_prob_avg = tf.reduce_max(prob_weak, axis=-1)
        max_prob_avg = tf.identity(max_prob_avg, name='max_prob_avg')

    # Temperature
    if FLAGS.soft_pseudo_label and FLAGS.temperature != 1.0:
        prob_weak = tf.pow(prob_weak, 1. / FLAGS.temperature)
        prob_weak /= tf.reduce_sum(prob_weak, axis=-1, keepdims=True)
        # Monitor max score
        max_prob_avg_t = tf.reduce_max(prob_weak, axis=-1)
        max_prob_avg_t = tf.identity(max_prob_avg_t, name='max_prob_avg_t')
        # Monitor merged logits
        prob_weak = tf.identity(prob_weak, name='merged_logits')

    with tf.name_scope('cls_strong'):
        outputs_to_scales_to_logits, _ = model.multi_scale_logits(
            samples_cls['strong'],
            model_options=model_options,
            image_pyramid=FLAGS.image_pyramid,
            weight_decay=FLAGS.weight_decay,
            is_training=True,
            fine_tune_batch_norm=FLAGS.fine_tune_batch_norm,
            nas_training_hyper_parameters={
                'drop_path_keep_prob': FLAGS.drop_path_keep_prob,
                'total_training_steps': FLAGS.training_number_of_steps,
            },
            output_end_points=True)
        logits_strong = outputs_to_scales_to_logits[common.OUTPUT_TYPE][
            model.MERGED_LOGITS_SCOPE]
    logits_strong = tf.identity(logits_strong, name='logits_strong')

    if FLAGS.pseudo_label_threshold > 0:
        confidence_weak = tf.expand_dims(tf.reduce_max(prob_weak, axis=-1),
                                         axis=-1)
        valid_mask_score = tf.greater_equal(confidence_weak,
                                            FLAGS.pseudo_label_threshold)
        valid_mask_score = tf.cast(valid_mask_score, tf.float32)
        valid_mask = valid_mask_score * valid_mask_pad
    else:
        valid_mask_score = None
        valid_mask = valid_mask_pad
    # Save for visualization
    valid_mask = tf.identity(valid_mask, name='valid_mask')

    logits_strong = tf.reshape(
        logits_strong, [-1, outputs_to_num_classes[common.OUTPUT_TYPE]])

    if not FLAGS.soft_pseudo_label:
        pseudo_label = tf.argmax(prob_weak, axis=-1)
        pseudo_label = tf.reshape(pseudo_label, [-1])
        pseudo_label = tf.stop_gradient(pseudo_label)
        loss_consistency = tf.compat.v1.nn.sparse_softmax_cross_entropy_with_logits(
            labels=pseudo_label,
            logits=logits_strong,
            name='consistency_losses')
        loss_consistency = loss_consistency * tf.reshape(valid_mask, [-1])
        pred_pseudo = pseudo_label
    else:
        pseudo_label = prob_weak
        pseudo_label = tf.reshape(
            pseudo_label, [-1, outputs_to_num_classes[common.OUTPUT_TYPE]])
        pseudo_label = tf.stop_gradient(pseudo_label)
        loss_consistency = tf.compat.v1.nn.softmax_cross_entropy_with_logits_v2(
            labels=pseudo_label,
            logits=logits_strong,
            name='consistency_losses')
        loss_consistency = loss_consistency * tf.reshape(valid_mask, [-1])
        pred_pseudo = tf.argmax(pseudo_label, axis=-1)

    # NOTE: When averaging, we divide by the number of pixels excluding padding
    loss_consistency = tf.reduce_sum(loss_consistency)
    loss_consistency = train_utils._div_maybe_zero(
        loss_consistency, tf.reduce_sum(valid_mask_pad))
    loss_consistency *= FLAGS.unlabeled_weight
    loss_consistency = tf.identity(loss_consistency, 'loss_consistency')
    tf.compat.v1.losses.add_loss(loss_consistency)

    ## 3) Monitor prediction quality
    temp_label = tf.compat.v1.image.resize_nearest_neighbor(
        samples_cls[common.LABEL],
        preprocess_utils.resolve_shape(logits_weak, 4)[1:3])
    temp_label = tf.reshape(temp_label, [-1])

    # Get the pixel count of each class so that we can re-weight them for pixel accuracy.
    dump = tf.concat(
        [tf.range(outputs_to_num_classes[common.OUTPUT_TYPE]), temp_label],
        axis=-1)
    _, _, count = tf.unique_with_counts(dump)
    num_pixel_list = count - 1
    # Exclude the ignore region
    num_pixel_list = num_pixel_list[:outputs_to_num_classes[common.
                                                            OUTPUT_TYPE]]
    num_pixel_list = tf.cast(num_pixel_list, tf.float32)
    inverse_ratio = train_utils._div_maybe_zero(1, num_pixel_list)
    inverse_ratio = inverse_ratio / tf.reduce_sum(inverse_ratio)

    # Since tf.metrics.mean_per_class_accuracy does not support weighted average
    # for each class directly, we here convert it to pixel-wise weighted mask to
    # compute weighted average pixel accuracy.
    weight_mask = tf.einsum(
        '...y,y->...',
        tf.one_hot(temp_label,
                   outputs_to_num_classes[common.OUTPUT_TYPE],
                   dtype=tf.float32), inverse_ratio)
    temp_valid = tf.not_equal(temp_label, ignore_label)
    if valid_mask_score is not None:
        temp_valid_confident = tf.cast(temp_valid, tf.float32) * tf.reshape(
            valid_mask_score, [-1])
        temp_valid_confident = tf.cast(temp_valid_confident, tf.bool)
    else:
        temp_valid_confident = temp_valid

    temp_label_confident = tf.boolean_mask(temp_label, temp_valid_confident)
    temp_label_valid = tf.boolean_mask(temp_label, temp_valid)
    weight_mask_confident = tf.boolean_mask(weight_mask, temp_valid_confident)
    weight_mask_valid = tf.boolean_mask(weight_mask, temp_valid)

    if FLAGS.pseudo_label_threshold > 0:
        acc_pseudo, acc_pseudo_op = tf.metrics.mean_per_class_accuracy(
            temp_label_confident,
            tf.boolean_mask(pred_pseudo, temp_valid_confident),
            outputs_to_num_classes[common.OUTPUT_TYPE],
            weights=weight_mask_confident)
        with tf.control_dependencies([acc_pseudo_op]):
            acc_pseudo = tf.identity(acc_pseudo, name='acc_pseudo')

    pred_weak = tf.cast(tf.argmax(prob_weak, axis=-1), tf.int32)
    pred_weak = tf.reshape(pred_weak, [-1])
    acc_weak, acc_weak_op = tf.metrics.mean_per_class_accuracy(
        temp_label_valid,
        tf.boolean_mask(pred_weak, temp_valid),
        outputs_to_num_classes[common.OUTPUT_TYPE],
        weights=weight_mask_valid)
    with tf.control_dependencies([acc_weak_op]):
        acc_weak = tf.identity(acc_weak, name='acc_weak')

    pred_strong = tf.cast(tf.argmax(logits_strong, axis=-1), tf.int32)
    pred_strong = tf.reshape(pred_strong, [-1])
    # For all pixels
    acc_strong, acc_strong_op = tf.metrics.mean_per_class_accuracy(
        temp_label_valid,
        tf.boolean_mask(pred_strong, temp_valid),
        outputs_to_num_classes[common.OUTPUT_TYPE],
        weights=weight_mask_valid)
    with tf.control_dependencies([acc_strong_op]):
        acc_strong = tf.identity(acc_strong, name='acc_strong')

    # For confident pixels
    if FLAGS.pseudo_label_threshold > 0:
        acc_strong_confident, acc_strong_confident_op = tf.metrics.mean_per_class_accuracy(
            temp_label_confident,
            tf.boolean_mask(pred_strong, temp_valid_confident),
            outputs_to_num_classes[common.OUTPUT_TYPE],
            weights=weight_mask_confident)
        with tf.control_dependencies([acc_strong_confident_op]):
            acc_strong_confident = tf.identity(acc_strong_confident,
                                               name='acc_strong_confident')

        valid_ratio = tf.reduce_sum(valid_mask) / tf.reduce_sum(valid_mask_pad)
        valid_ratio = tf.identity(valid_ratio, name='valid_ratio')

    ### Pixel-level data
    ## 1) Segmentation
    outputs_to_scales_to_logits = model.multi_scale_logits(
        samples_seg[common.IMAGE],
        model_options=model_options,
        image_pyramid=FLAGS.image_pyramid,
        weight_decay=FLAGS.weight_decay,
        is_training=True,
        fine_tune_batch_norm=FLAGS.fine_tune_batch_norm,
        nas_training_hyper_parameters={
            'drop_path_keep_prob': FLAGS.drop_path_keep_prob,
            'total_training_steps': FLAGS.training_number_of_steps,
        })

    # Add name to graph node so we can add to summary.
    output_type_dict = outputs_to_scales_to_logits[common.OUTPUT_TYPE]
    output_type_dict[model.MERGED_LOGITS_SCOPE] = tf.identity(
        output_type_dict[model.MERGED_LOGITS_SCOPE],
        name=common.OUTPUT_TYPE + '_seg')

    for output, num_classes in six.iteritems(outputs_to_num_classes):
        train_utils.add_softmax_cross_entropy_loss_for_each_scale(
            outputs_to_scales_to_logits[output],
            samples_seg[common.LABEL],
            num_classes,
            ignore_label,
            loss_weight=model_options.label_weights,
            upsample_logits=FLAGS.upsample_logits,
            hard_example_mining_step=FLAGS.hard_example_mining_step,
            top_k_percent_pixels=FLAGS.top_k_percent_pixels,
            scope=output)

    ## 2) Train self-attention module
    if FLAGS.use_attention:
        valid_mask_pad = samples_seg['valid']
        valid_mask_pad = tf.compat.v1.image.resize_nearest_neighbor(
            valid_mask_pad,
            preprocess_utils.resolve_shape(logits_weak, 4)[1:3])
        valid_mask_pad = tf.cast(valid_mask_pad, tf.float32)

        with tf.name_scope('seg_data_cls'):
            _, end_points_cls = feature_extractor.extract_features(
                samples_seg[common.IMAGE],
                output_stride=model_options.output_stride,
                multi_grid=model_options.multi_grid,
                model_variant=model_options.model_variant,
                depth_multiplier=model_options.depth_multiplier,
                divisible_by=model_options.divisible_by,
                weight_decay=FLAGS.weight_decay,
                reuse=tf.AUTO_REUSE,
                is_training=True,
                preprocessed_images_dtype=model_options.
                preprocessed_images_dtype,
                fine_tune_batch_norm=FLAGS.fine_tune_batch_norm,
                global_pool=True,
                num_classes=outputs_to_num_classes[common.OUTPUT_TYPE] - 1)

        logits_cls = end_points_cls['{}/logits'.format(
            FLAGS.model_variant).replace('_beta', '')]
        logits_cls = tf.reshape(
            logits_cls, [-1, outputs_to_num_classes[common.OUTPUT_TYPE] - 1])

        _, att_cam_labeled = cam_func(
            end_points_cls,
            logits_cls,
            samples_seg['cls_label'],
            num_class=outputs_to_num_classes[common.OUTPUT_TYPE],
            use_attention=True,
            attention_dim=FLAGS.attention_dim,
            strides=[int(st) for st in FLAGS.att_strides],
            is_training=True,
            valid_mask=valid_mask_pad,
            net=FLAGS.model_variant.replace('_beta', ''))
        att_logits_labeled = att_cam_labeled

        # Loss
        train_utils.add_softmax_cross_entropy_loss_for_each_scale(
            {'self-attention_logits': att_logits_labeled},
            samples_seg[common.LABEL],
            outputs_to_num_classes[common.OUTPUT_TYPE],
            ignore_label,
            loss_weight=model_options.label_weights,
            upsample_logits=FLAGS.upsample_logits,
            hard_example_mining_step=FLAGS.hard_example_mining_step,
            top_k_percent_pixels=FLAGS.top_k_percent_pixels,
            scope='self-attention')

        att_logits_labeled = tf.identity(att_logits_labeled,
                                         name='att_logits_labeled')

        ## 3) If no image-level label, convert pixel-level label to train classifier
        if not FLAGS.weakly:
            # Seems that people usually use multi-label soft margin loss in PyTorch
            loss_cls = tf.nn.sigmoid_cross_entropy_with_logits(
                labels=samples_seg['cls_label'], logits=logits_cls)
            loss_cls = tf.reduce_mean(loss_cls)
            loss_cls = tf.identity(loss_cls, name='loss_cls')
            tf.compat.v1.losses.add_loss(loss_cls)

    ## 4) Sanity check. Monitor pixel accuracy
    logits_seg = outputs_to_scales_to_logits[common.OUTPUT_TYPE][
        model.MERGED_LOGITS_SCOPE]
    temp_label = tf.compat.v1.image.resize_nearest_neighbor(
        samples_seg[common.LABEL],
        preprocess_utils.resolve_shape(logits_seg, 4)[1:3])
    temp_label = tf.reshape(temp_label, [-1])

    dump = tf.concat(
        [tf.range(outputs_to_num_classes[common.OUTPUT_TYPE]), temp_label],
        axis=-1)
    _, _, count = tf.unique_with_counts(dump)
    num_pixel_list = count - 1
    # Exclude the ignore region
    num_pixel_list = num_pixel_list[:outputs_to_num_classes[common.
                                                            OUTPUT_TYPE]]
    num_pixel_list = tf.cast(num_pixel_list, tf.float32)
    inverse_ratio = train_utils._div_maybe_zero(1, num_pixel_list)
    inverse_ratio = inverse_ratio / tf.reduce_sum(inverse_ratio)

    # Create weight mask to balance each class
    weight_mask = tf.einsum(
        '...y,y->...',
        tf.one_hot(temp_label,
                   outputs_to_num_classes[common.OUTPUT_TYPE],
                   dtype=tf.float32), inverse_ratio)
    temp_valid = tf.not_equal(temp_label, ignore_label)
    temp_label_valid = tf.boolean_mask(temp_label, temp_valid)
    weight_mask_valid = tf.boolean_mask(weight_mask, temp_valid)

    pred_seg = tf.argmax(logits_seg, axis=-1)
    pred_seg = tf.reshape(pred_seg, [-1])
    acc_seg, acc_seg_op = tf.metrics.mean_per_class_accuracy(
        temp_label_valid,
        tf.boolean_mask(pred_seg, temp_valid),
        outputs_to_num_classes[common.OUTPUT_TYPE],
        weights=weight_mask_valid)
    with tf.control_dependencies([acc_seg_op]):
        acc_seg = tf.identity(acc_seg, name='acc_seg')
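
The block above ensembles the decoder probabilities with the attention (CAM) branch, optionally sharpens the result with FLAGS.temperature, and keeps only confident pixels via FLAGS.pseudo_label_threshold. Below is a minimal NumPy sketch of those three steps; the function name, shapes, and default values are illustrative only and are not part of the project code.

import numpy as np

def make_pseudo_labels(prob_decoder, prob_att, temperature=0.5, threshold=0.7):
    """prob_decoder, prob_att: [H, W, C] softmax outputs of the two branches."""
    # Ensemble the two pseudo-label sources by simple averaging.
    prob = (prob_decoder + prob_att) / 2.
    # Temperature sharpening for soft pseudo labels: prob ** (1/T), renormalized.
    prob = np.power(prob, 1. / temperature)
    prob /= prob.sum(axis=-1, keepdims=True)
    # Hard pseudo label plus a confidence mask, mirroring the threshold branch above.
    pseudo_label = prob.argmax(axis=-1)
    valid_mask = (prob.max(axis=-1) >= threshold).astype(np.float32)
    return prob, pseudo_label, valid_mask

# Illustrative usage with random probabilities for a 4x4 crop and 21 classes.
rng = np.random.default_rng(0)
prob_decoder = rng.dirichlet(np.ones(21), size=(4, 4))
prob_att = rng.dirichlet(np.ones(21), size=(4, 4))
soft, hard, mask = make_pseudo_labels(prob_decoder, prob_att)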
Code Example #2
File: eval.py  Project: templeblock/wss
def main(unused_argv):
    tf.logging.set_verbosity(tf.logging.INFO)

    dataset = data_generator.Dataset(
        dataset_name=FLAGS.dataset,
        split_name=FLAGS.eval_split,
        dataset_dir=FLAGS.dataset_dir,
        batch_size=FLAGS.eval_batch_size,
        crop_size=[int(sz) for sz in FLAGS.eval_crop_size],
        min_resize_value=FLAGS.min_resize_value,
        max_resize_value=FLAGS.max_resize_value,
        resize_factor=FLAGS.resize_factor,
        model_variant=FLAGS.model_variant,
        num_readers=2,
        is_training=False,
        should_shuffle=False,
        should_repeat=False,
        with_cls=True,
        cls_only=False,
        output_valid=True)

    tf.gfile.MakeDirs(FLAGS.eval_logdir)
    tf.logging.info('Evaluating on %s set', FLAGS.eval_split)

    with tf.Graph().as_default():
        samples = dataset.get_one_shot_iterator().get_next()

        model_options = common.ModelOptions(
            outputs_to_num_classes={
                common.OUTPUT_TYPE: dataset.num_of_classes
            },
            crop_size=[int(sz) for sz in FLAGS.eval_crop_size],
            atrous_rates=FLAGS.atrous_rates,
            output_stride=FLAGS.output_stride)

        # Set shape in order for tf.contrib.tfprof.model_analyzer to work properly.
        samples[common.IMAGE].set_shape([
            FLAGS.eval_batch_size,
            int(FLAGS.eval_crop_size[0]),
            int(FLAGS.eval_crop_size[1]), 3
        ])
        if tuple(FLAGS.eval_scales) == (1.0, ):
            tf.logging.info('Performing single-scale test.')
            predictions = model.predict_labels(
                samples[common.IMAGE],
                model_options,
                image_pyramid=FLAGS.image_pyramid)
        else:
            tf.logging.info('Performing multi-scale test.')
            raise NotImplementedError('Multi-scale is not supported yet!')

        metric_map = {}
        ## Extract cls logits
        if FLAGS.weakly:
            _, end_points = feature_extractor.extract_features(
                samples[common.IMAGE],
                output_stride=model_options.output_stride,
                multi_grid=model_options.multi_grid,
                model_variant=model_options.model_variant,
                depth_multiplier=model_options.depth_multiplier,
                divisible_by=model_options.divisible_by,
                reuse=tf.AUTO_REUSE,
                is_training=False,
                preprocessed_images_dtype=model_options.
                preprocessed_images_dtype,
                global_pool=True,
                num_classes=dataset.num_of_classes - 1)
            # The beta version of ResNet has an additional suffix in
            # FLAGS.model_variant, but it shares the same variable names with
            # the original version. Add special handling here for the beta
            # version.
            logits = end_points['{}/logits'.format(
                FLAGS.model_variant).replace('_beta', '')]
            logits = tf.reshape(logits, [-1, dataset.num_of_classes - 1])
            cls_pred = tf.sigmoid(logits)

            # Multi-label classification evaluation
            cls_label = samples['cls_label']
            cls_pred = tf.cast(tf.greater_equal(cls_pred, 0.5), tf.int32)

            ## For classification
            metric_map['eval/cls_overall'] = tf.metrics.accuracy(
                labels=cls_label, predictions=cls_pred)
            metric_map['eval/cls_precision'] = tf.metrics.precision(
                labels=cls_label, predictions=cls_pred)
            metric_map['eval/cls_recall'] = tf.metrics.recall(
                labels=cls_label, predictions=cls_pred)

        ## For segmentation branch eval
        predictions = predictions[common.OUTPUT_TYPE]
        predictions = tf.reshape(predictions, shape=[-1])
        labels = tf.reshape(samples[common.LABEL], shape=[-1])
        weights = tf.to_float(tf.not_equal(labels, dataset.ignore_label))

        # Set ignore_label regions to label 0, because metrics.mean_iou requires
        # range of labels = [0, dataset.num_classes). Note the ignore_label regions
        # are not evaluated since the corresponding regions contain weights = 0.
        labels = tf.where(tf.equal(labels, dataset.ignore_label),
                          tf.zeros_like(labels), labels)

        predictions_tag = 'miou'
        # Define the evaluation metric.
        num_classes = dataset.num_of_classes

        ## For segmentation
        metric_map['eval/%s_overall' % predictions_tag] = tf.metrics.mean_iou(
            labels=labels,
            predictions=predictions,
            num_classes=num_classes,
            weights=weights)
        # IoU for each class.
        one_hot_predictions = tf.one_hot(predictions, num_classes)
        one_hot_predictions = tf.reshape(one_hot_predictions,
                                         [-1, num_classes])
        one_hot_labels = tf.one_hot(labels, num_classes)
        one_hot_labels = tf.reshape(one_hot_labels, [-1, num_classes])
        for c in range(num_classes):
            predictions_tag_c = '%s_class_%d' % (predictions_tag, c)
            tp, tp_op = tf.metrics.true_positives(
                labels=one_hot_labels[:, c],
                predictions=one_hot_predictions[:, c],
                weights=weights)
            fp, fp_op = tf.metrics.false_positives(
                labels=one_hot_labels[:, c],
                predictions=one_hot_predictions[:, c],
                weights=weights)
            fn, fn_op = tf.metrics.false_negatives(
                labels=one_hot_labels[:, c],
                predictions=one_hot_predictions[:, c],
                weights=weights)
            tp_fp_fn_op = tf.group(tp_op, fp_op, fn_op)
            iou = tf.where(tf.greater(tp + fn, 0.0), tp / (tp + fn + fp),
                           tf.constant(np.NaN))
            metric_map['eval/%s' % predictions_tag_c] = (iou, tp_fp_fn_op)

        (metrics_to_values,
         metrics_to_updates) = contrib_metrics.aggregate_metric_map(metric_map)

        summary_ops = []
        for metric_name, metric_value in six.iteritems(metrics_to_values):
            op = tf.summary.scalar(metric_name, metric_value)
            op = tf.Print(op, [metric_value], metric_name)
            summary_ops.append(op)

        summary_op = tf.summary.merge(summary_ops)
        summary_hook = contrib_training.SummaryAtEndHook(
            log_dir=FLAGS.eval_logdir, summary_op=summary_op)
        hooks = [summary_hook]

        num_eval_iters = None
        if FLAGS.max_number_of_evaluations > 0:
            num_eval_iters = FLAGS.max_number_of_evaluations

        if FLAGS.quantize_delay_step >= 0:
            contrib_quantize.create_eval_graph()

        contrib_tfprof.model_analyzer.print_model_analysis(
            tf.get_default_graph(),
            tfprof_options=contrib_tfprof.model_analyzer.
            TRAINABLE_VARS_PARAMS_STAT_OPTIONS)
        contrib_tfprof.model_analyzer.print_model_analysis(
            tf.get_default_graph(),
            tfprof_options=contrib_tfprof.model_analyzer.FLOAT_OPS_OPTIONS)
        contrib_training.evaluate_repeatedly(
            checkpoint_dir=FLAGS.checkpoint_dir,
            master=FLAGS.master,
            eval_ops=list(metrics_to_updates.values()),
            max_number_of_evaluations=num_eval_iters,
            hooks=hooks,
            eval_interval_secs=FLAGS.eval_interval_secs)
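
The per-class loop above accumulates true positives, false positives, and false negatives and reports IoU_c = TP / (TP + FP + FN), returning NaN for classes absent from the ground truth. A standalone NumPy sketch of the same computation follows; the names are illustrative and not part of eval.py.

import numpy as np

def per_class_iou(labels, predictions, num_classes, ignore_label=255):
    """Per-class IoU with ignore_label pixels excluded from the counts."""
    labels = labels.reshape(-1)
    predictions = predictions.reshape(-1)
    valid = labels != ignore_label
    labels, predictions = labels[valid], predictions[valid]
    ious = np.full(num_classes, np.nan)
    for c in range(num_classes):
        tp = np.sum((labels == c) & (predictions == c))
        fp = np.sum((labels != c) & (predictions == c))
        fn = np.sum((labels == c) & (predictions != c))
        if tp + fn > 0:  # mirror the tf.where(tf.greater(tp + fn, 0.0), ...) guard
            ious[c] = tp / float(tp + fp + fn)
    return ious

# Example: mean IoU is then the average over classes present in the ground truth.
labels = np.array([[0, 1, 1], [2, 2, 255]])
preds = np.array([[0, 1, 2], [2, 2, 0]])
print(per_class_iou(labels, preds, num_classes=3))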
Code Example #3
def _build_deeplab(iterator_seg, iterator, outputs_to_num_classes,
                   ignore_label):
    """Builds a clone of Supervised DeepLab.

  Args:
    iterator_seg: An iterator of type tf.data.Iterator for images and labels
      (seg).
    iterator: An iterator of type tf.data.Iterator for images and labels.
    outputs_to_num_classes: A map from output type to the number of classes. For
      example, for the task of semantic segmentation with 21 semantic classes,
      we would have outputs_to_num_classes['semantic'] = 21.
    ignore_label: Ignore label.
  """
    if FLAGS.weakly:
        samples = iterator.get_next()
        samples[common.IMAGE] = tf.identity(samples[common.IMAGE],
                                            name=common.IMAGE)
        samples[common.LABEL] = tf.identity(samples[common.LABEL],
                                            name=common.LABEL)

    samples_seg = iterator_seg.get_next()
    samples_seg[common.IMAGE] = tf.identity(samples_seg[common.IMAGE],
                                            name=common.IMAGE + '_seg')
    samples_seg[common.LABEL] = tf.identity(samples_seg[common.LABEL],
                                            name=common.LABEL + '_seg')

    model_options = common.ModelOptions(
        outputs_to_num_classes=outputs_to_num_classes,
        crop_size=[int(sz) for sz in FLAGS.train_crop_size],
        atrous_rates=FLAGS.atrous_rates,
        output_stride=FLAGS.output_stride)

    ### Cls data
    if FLAGS.weakly:
        _, end_points_cls = feature_extractor.extract_features(
            samples[common.IMAGE],
            output_stride=model_options.output_stride,
            multi_grid=model_options.multi_grid,
            model_variant=model_options.model_variant,
            depth_multiplier=model_options.depth_multiplier,
            divisible_by=model_options.divisible_by,
            weight_decay=FLAGS.weight_decay,
            reuse=tf.AUTO_REUSE,
            is_training=True,
            preprocessed_images_dtype=model_options.preprocessed_images_dtype,
            fine_tune_batch_norm=FLAGS.fine_tune_batch_norm,
            global_pool=True,
            num_classes=outputs_to_num_classes[common.OUTPUT_TYPE] - 1)

        # The beta version of ResNet has an additional suffix in
        # FLAGS.model_variant, but it shares the same variable names with the
        # original version. Add special handling here for the beta version.
        logits_cls = end_points_cls['{}/logits'.format(
            FLAGS.model_variant).replace('_beta', '')]
        logits_cls = tf.reshape(
            logits_cls, [-1, outputs_to_num_classes[common.OUTPUT_TYPE] - 1])
        # Seems that people usually use multi-label soft margin loss
        loss_cls = tf.nn.sigmoid_cross_entropy_with_logits(
            labels=samples['cls_label'], logits=logits_cls)
        loss_cls = tf.reduce_mean(loss_cls)
        loss_cls = tf.identity(loss_cls, name='loss_cls')
        tf.compat.v1.losses.add_loss(loss_cls)

    ### Seg data
    outputs_to_scales_to_logits = model.multi_scale_logits(
        samples_seg[common.IMAGE],
        model_options=model_options,
        image_pyramid=FLAGS.image_pyramid,
        weight_decay=FLAGS.weight_decay,
        is_training=True,
        fine_tune_batch_norm=FLAGS.fine_tune_batch_norm,
        nas_training_hyper_parameters={
            'drop_path_keep_prob': FLAGS.drop_path_keep_prob,
            'total_training_steps': FLAGS.training_number_of_steps,
        })

    # Add name to graph node so we can add to summary.
    output_type_dict = outputs_to_scales_to_logits[common.OUTPUT_TYPE]
    output_type_dict[model.MERGED_LOGITS_SCOPE] = tf.identity(
        output_type_dict[model.MERGED_LOGITS_SCOPE],
        name=common.OUTPUT_TYPE + '_seg')

    for output, num_classes in six.iteritems(outputs_to_num_classes):
        train_utils.add_softmax_cross_entropy_loss_for_each_scale(
            outputs_to_scales_to_logits[output],
            samples_seg[common.LABEL],
            num_classes,
            ignore_label,
            loss_weight=model_options.label_weights,
            upsample_logits=FLAGS.upsample_logits,
            hard_example_mining_step=FLAGS.hard_example_mining_step,
            top_k_percent_pixels=FLAGS.top_k_percent_pixels,
            scope=output)

    ## Sanity check. Monitor pixel accuracy
    logits_seg = outputs_to_scales_to_logits[common.OUTPUT_TYPE][
        model.MERGED_LOGITS_SCOPE]
    temp_label = tf.compat.v1.image.resize_nearest_neighbor(
        samples_seg[common.LABEL],
        preprocess_utils.resolve_shape(logits_seg, 4)[1:3])
    temp_label = tf.reshape(temp_label, [-1])

    dump = tf.concat(
        [tf.range(outputs_to_num_classes[common.OUTPUT_TYPE]), temp_label],
        axis=-1)
    _, _, count = tf.unique_with_counts(dump)
    num_pixel_list = count - 1
    # Exclude the ignore region
    num_pixel_list = num_pixel_list[:outputs_to_num_classes[common.
                                                            OUTPUT_TYPE]]
    num_pixel_list = tf.cast(num_pixel_list, tf.float32)
    inverse_ratio = train_utils._div_maybe_zero(1, num_pixel_list)
    inverse_ratio = inverse_ratio / tf.reduce_sum(inverse_ratio)

    # Create weight mask to balance each class
    weight_mask = tf.einsum(
        '...y,y->...',
        tf.one_hot(temp_label,
                   outputs_to_num_classes[common.OUTPUT_TYPE],
                   dtype=tf.float32), inverse_ratio)
    temp_valid = tf.not_equal(temp_label, ignore_label)
    temp_label_valid = tf.boolean_mask(temp_label, temp_valid)
    weight_mask_valid = tf.boolean_mask(weight_mask, temp_valid)

    pred_seg = tf.argmax(logits_seg, axis=-1)
    pred_seg = tf.reshape(pred_seg, [-1])
    acc_seg, acc_seg_op = tf.metrics.mean_per_class_accuracy(
        temp_label_valid,
        tf.boolean_mask(pred_seg, temp_valid),
        outputs_to_num_classes[common.OUTPUT_TYPE],
        weights=weight_mask_valid)
    with tf.control_dependencies([acc_seg_op]):
        acc_seg = tf.identity(acc_seg, name='acc_seg')
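
The sanity-check block above re-weights pixel accuracy by inverse class frequency: it pads the flattened label with one pixel of every class so tf.unique_with_counts always returns num_classes entries, subtracts that extra pixel, normalizes the inverse counts, and spreads them onto a per-pixel weight mask with tf.einsum. A small NumPy sketch of the same idea; it is illustrative only and not project code.

import numpy as np

def balanced_pixel_weights(labels, num_classes, ignore_label=255):
    """Per-pixel weights proportional to the inverse frequency of each class."""
    flat = labels.reshape(-1)
    # Count pixels per class; minlength plays the role of the tf.range padding,
    # so classes that never appear still get a (zero) bin.
    counts = np.bincount(flat[flat != ignore_label],
                         minlength=num_classes)[:num_classes].astype(np.float32)
    inverse_ratio = np.where(counts > 0, 1. / np.maximum(counts, 1.), 0.)
    inverse_ratio /= inverse_ratio.sum()
    # Look up each pixel's class weight; the einsum('...y,y->...') over one-hot
    # labels in the TF code performs the same gather. Ignore pixels get weight 0.
    safe = np.minimum(flat, num_classes - 1)
    weights = np.where(flat == ignore_label, 0., inverse_ratio[safe])
    return weights.reshape(labels.shape)

# Example: a 2x3 label map dominated by class 0; class 1 pixels get a larger weight.
labels = np.array([[0, 0, 0], [0, 1, 255]])
print(balanced_pixel_weights(labels, num_classes=2))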
Code Example #4
File: model.py  Project: templeblock/wss
def extract_features(images,
                     model_options,
                     weight_decay=0.0001,
                     reuse=None,
                     is_training=False,
                     fine_tune_batch_norm=False,
                     nas_training_hyper_parameters=None):
    """Extracts features by the particular model_variant.

  Args:
    images: A tensor of size [batch, height, width, channels].
    model_options: A ModelOptions instance to configure models.
    weight_decay: The weight decay for model variables.
    reuse: Reuse the model variables or not.
    is_training: Is training or not.
    fine_tune_batch_norm: Fine-tune the batch norm parameters or not.
    nas_training_hyper_parameters: A dictionary storing hyper-parameters for
      training nas models. Its keys are:
      - `drop_path_keep_prob`: Probability to keep each path in the cell when
        training.
      - `total_training_steps`: Total training steps to help drop path
        probability calculation.

  Returns:
    concat_logits: A tensor of size [batch, feature_height, feature_width,
      feature_channels], where feature_height/feature_width are determined by
      the images height/width and output_stride.
    end_points: A dictionary from components of the network to the corresponding
      activation.
  """
    features, end_points = feature_extractor.extract_features(
        images,
        output_stride=model_options.output_stride,
        multi_grid=model_options.multi_grid,
        model_variant=model_options.model_variant,
        depth_multiplier=model_options.depth_multiplier,
        divisible_by=model_options.divisible_by,
        weight_decay=weight_decay,
        reuse=reuse,
        is_training=is_training,
        preprocessed_images_dtype=model_options.preprocessed_images_dtype,
        fine_tune_batch_norm=fine_tune_batch_norm,
        nas_architecture_options=model_options.nas_architecture_options,
        nas_training_hyper_parameters=nas_training_hyper_parameters,
        use_bounded_activation=model_options.use_bounded_activation)

    if not model_options.aspp_with_batch_norm:
        return features, end_points
    else:
        if model_options.dense_prediction_cell_config is not None:
            tf.logging.info('Using dense prediction cell config.')
            dense_prediction_layer = dense_prediction_cell.DensePredictionCell(
                config=model_options.dense_prediction_cell_config,
                hparams={
                    'conv_rate_multiplier': 16 // model_options.output_stride,
                })
            concat_logits = dense_prediction_layer.build_cell(
                features,
                output_stride=model_options.output_stride,
                crop_size=model_options.crop_size,
                image_pooling_crop_size=model_options.image_pooling_crop_size,
                weight_decay=weight_decay,
                reuse=reuse,
                is_training=is_training,
                fine_tune_batch_norm=fine_tune_batch_norm)
            return concat_logits, end_points
        else:
            # The following code employs the DeepLab v3 ASPP module. Note that
            # we could express the ASPP module as one particular dense
            # prediction cell architecture. We do not do so, but keep the code
            # below for backward compatibility.
            batch_norm_params = utils.get_batch_norm_params(
                decay=0.9997,
                epsilon=1e-5,
                scale=True,
                is_training=(is_training and fine_tune_batch_norm),
                sync_batch_norm_method=model_options.sync_batch_norm_method)
            batch_norm = utils.get_batch_norm_fn(
                model_options.sync_batch_norm_method)
            activation_fn = (tf.nn.relu6
                             if model_options.use_bounded_activation else
                             tf.nn.relu)
            with slim.arg_scope(
                [slim.conv2d, slim.separable_conv2d],
                    weights_regularizer=slim.l2_regularizer(weight_decay),
                    activation_fn=activation_fn,
                    normalizer_fn=batch_norm,
                    padding='SAME',
                    stride=1,
                    reuse=reuse):
                with slim.arg_scope([batch_norm], **batch_norm_params):
                    depth = model_options.aspp_convs_filters
                    branch_logits = []

                    if model_options.add_image_level_feature:
                        if model_options.crop_size is not None:
                            image_pooling_crop_size = model_options.image_pooling_crop_size
                            # If image_pooling_crop_size is not specified, use crop_size.
                            if image_pooling_crop_size is None:
                                image_pooling_crop_size = model_options.crop_size
                            pool_height = scale_dimension(
                                image_pooling_crop_size[0],
                                1. / model_options.output_stride)
                            pool_width = scale_dimension(
                                image_pooling_crop_size[1],
                                1. / model_options.output_stride)
                            image_feature = slim.avg_pool2d(
                                features, [pool_height, pool_width],
                                model_options.image_pooling_stride,
                                padding='VALID')
                            resize_height = scale_dimension(
                                model_options.crop_size[0],
                                1. / model_options.output_stride)
                            resize_width = scale_dimension(
                                model_options.crop_size[1],
                                1. / model_options.output_stride)
                        else:
                            # If crop_size is None, we simply do global pooling.
                            pool_height = tf.shape(features)[1]
                            pool_width = tf.shape(features)[2]
                            image_feature = tf.reduce_mean(features,
                                                           axis=[1, 2],
                                                           keepdims=True)
                            resize_height = pool_height
                            resize_width = pool_width
                        image_feature_activation_fn = tf.nn.relu
                        image_feature_normalizer_fn = batch_norm
                        if model_options.aspp_with_squeeze_and_excitation:
                            image_feature_activation_fn = tf.nn.sigmoid
                            if model_options.image_se_uses_qsigmoid:
                                image_feature_activation_fn = utils.q_sigmoid
                            image_feature_normalizer_fn = None
                        image_feature = slim.conv2d(
                            image_feature,
                            depth,
                            1,
                            activation_fn=image_feature_activation_fn,
                            normalizer_fn=image_feature_normalizer_fn,
                            scope=IMAGE_POOLING_SCOPE)
                        image_feature = _resize_bilinear(
                            image_feature, [resize_height, resize_width],
                            image_feature.dtype)
                        # Set shape for resize_height/resize_width if they are not Tensor.
                        if isinstance(resize_height, tf.Tensor):
                            resize_height = None
                        if isinstance(resize_width, tf.Tensor):
                            resize_width = None
                        image_feature.set_shape(
                            [None, resize_height, resize_width, depth])
                        if not model_options.aspp_with_squeeze_and_excitation:
                            branch_logits.append(image_feature)

                    # Employ a 1x1 convolution.
                    branch_logits.append(
                        slim.conv2d(features,
                                    depth,
                                    1,
                                    scope=ASPP_SCOPE + str(0)))

                    if model_options.atrous_rates:
                        # Employ 3x3 convolutions with different atrous rates.
                        for i, rate in enumerate(model_options.atrous_rates,
                                                 1):
                            scope = ASPP_SCOPE + str(i)
                            if model_options.aspp_with_separable_conv:
                                aspp_features = split_separable_conv2d(
                                    features,
                                    filters=depth,
                                    rate=rate,
                                    weight_decay=weight_decay,
                                    scope=scope)
                            else:
                                aspp_features = slim.conv2d(features,
                                                            depth,
                                                            3,
                                                            rate=rate,
                                                            scope=scope)
                            branch_logits.append(aspp_features)

                    # Merge branch logits.
                    concat_logits = tf.concat(branch_logits, 3)
                    if model_options.aspp_with_concat_projection:
                        concat_logits = slim.conv2d(
                            concat_logits,
                            depth,
                            1,
                            scope=CONCAT_PROJECTION_SCOPE)
                        concat_logits = slim.dropout(
                            concat_logits,
                            keep_prob=0.9,
                            is_training=is_training,
                            scope=CONCAT_PROJECTION_SCOPE + '_dropout')
                    if (model_options.add_image_level_feature and
                            model_options.aspp_with_squeeze_and_excitation):
                        concat_logits = tf.math.multiply(
                            concat_logits,
                            image_feature,
                            name='aspp_multiply_image_feature')

                    return concat_logits, end_points
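
The image-pooling branch above sizes its average pool with scale_dimension, which maps a crop dimension to the corresponding feature-map dimension for a given output_stride. A minimal standalone sketch, assuming the usual DeepLab convention of (dim - 1) * scale + 1 for static (non-Tensor) dimensions; the values are illustrative only.

def scale_dimension(dim, scale):
    """Scale a spatial dimension, e.g. a 513 crop at output_stride 16 -> 33."""
    return int((float(dim) - 1.0) * scale + 1.0)

crop_size = [513, 513]
output_stride = 16
pool_height = scale_dimension(crop_size[0], 1.0 / output_stride)  # 33
pool_width = scale_dimension(crop_size[1], 1.0 / output_stride)   # 33
print(pool_height, pool_width)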