Example #1
0
 def test_forward_class_agnostic(self):
     """Runs a class-agnostic DeepMaskHead and checks the output mask shape."""
     head = deep_instance_heads.DeepMaskHead(
         num_classes=3, class_agnostic=True)
     batch_size, num_rois = 2, 10
     features = np.random.rand(batch_size, num_rois, 14, 14, 16)
     classes = np.zeros((batch_size, num_rois))
     predicted = head([features, classes])
     # 14x14 ROI features come back as 28x28 masks; presumably the head's
     # default upsample factor is 2 (TODO confirm against DeepMaskHead).
     self.assertAllEqual(predicted.numpy().shape,
                         [batch_size, num_rois, 28, 28])
Example #2
0
 def test_instance_head_hourglass(self):
     """Checks the mask shape produced by the 'hourglass20' convnet variant."""
     upsample = 2
     head = deep_instance_heads.DeepMaskHead(
         num_classes=3,
         class_agnostic=True,
         convnet_variant='hourglass20',
         num_filters=32,
         upsample_factor=upsample)
     crop = 16
     features = np.random.rand(2, 10, crop, crop, 16)
     classes = np.zeros((2, 10))
     predicted = head([features, classes])
     # The head upsamples each crop x crop ROI feature map by `upsample`.
     side = crop * upsample
     self.assertAllEqual(predicted.numpy().shape, [2, 10, side, side])
def construct_model_and_anchors(image_size, use_gt_boxes_for_masks):
  """Builds a small DeepMaskRCNNModel for tests, together with its anchors.

  Args:
    image_size: image size forwarded to `anchor.Anchor` when generating the
      multilevel anchor boxes.
    use_gt_boxes_for_masks: forwarded unchanged to `DeepMaskRCNNModel`.

  Returns:
    A `(model, anchor_boxes)` tuple, where `anchor_boxes` is the
    `multilevel_boxes` attribute of the constructed `anchor.Anchor`.
  """
  num_classes = 3
  min_level, max_level = 3, 4
  num_scales = 3
  aspect_ratios = [1.0]
  # Mask-branch geometry: the mask head upsamples crops of `mask_crop_size`
  # by `mask_upsample_factor`, so the sampled mask targets must be
  # mask_crop_size * mask_upsample_factor (= 28) pixels on a side.
  mask_crop_size = 14
  mask_upsample_factor = 2

  anchor_generator = anchor.Anchor(
      min_level=min_level,
      max_level=max_level,
      num_scales=num_scales,
      aspect_ratios=aspect_ratios,
      anchor_size=3,
      image_size=image_size)
  anchor_boxes = anchor_generator.multilevel_boxes

  input_specs = tf.keras.layers.InputSpec(shape=[None, None, None, 3])
  backbone = resnet.ResNet(model_id=50, input_specs=input_specs)
  decoder = fpn.FPN(
      min_level=min_level,
      max_level=max_level,
      input_specs=backbone.output_specs)
  # One anchor per (scale, aspect ratio) pair at every feature location.
  rpn_head = dense_prediction_heads.RPNHead(
      min_level=min_level,
      max_level=max_level,
      num_anchors_per_location=num_scales * len(aspect_ratios))
  detection_head = instance_heads.DetectionHead(num_classes=num_classes)
  roi_generator_obj = roi_generator.MultilevelROIGenerator()
  roi_sampler_obj = roi_sampler.ROISampler()
  roi_aligner_obj = roi_aligner.MultilevelROIAligner()
  detection_generator_obj = detection_generator.DetectionGenerator()
  mask_head = deep_instance_heads.DeepMaskHead(
      num_classes=num_classes, upsample_factor=mask_upsample_factor)
  mask_sampler_obj = mask_sampler.MaskSampler(
      mask_target_size=mask_crop_size * mask_upsample_factor,
      num_sampled_masks=1)
  mask_roi_aligner_obj = roi_aligner.MultilevelROIAligner(
      crop_size=mask_crop_size)

  # Components are passed positionally, in the constructor's declared order.
  model = maskrcnn_model.DeepMaskRCNNModel(
      backbone, decoder, rpn_head, detection_head,
      roi_generator_obj, roi_sampler_obj, roi_aligner_obj,
      detection_generator_obj,
      mask_head, mask_sampler_obj, mask_roi_aligner_obj,
      use_gt_boxes_for_masks=use_gt_boxes_for_masks)
  return model, anchor_boxes
Example #4
0
 def test_serialize_deserialize(self):
     """Round-trips a DeepMaskHead through get_config()/from_config()."""
     original_head = deep_instance_heads.DeepMaskHead(
         num_classes=3, upsample_factor=2, num_convs=1, num_filters=256,
         use_separable_conv=False, activation='relu', use_sync_bn=False,
         norm_momentum=0.99, norm_epsilon=0.001,
         kernel_regularizer=None, bias_regularizer=None)
     restored_head = deep_instance_heads.DeepMaskHead.from_config(
         original_head.get_config())
     # The rebuilt head must report exactly the same configuration.
     self.assertAllEqual(original_head.get_config(),
                         restored_head.get_config())
Example #5
0
 def test_forward(self, upsample_factor, num_convs, use_sync_bn):
     """Checks DeepMaskHead output shapes for one head configuration.

     The arguments are presumably supplied by a parameterized-test decorator
     that is not visible here (TODO confirm).

     Args:
       upsample_factor: forwarded to DeepMaskHead; the expected mask side is
         14 * upsample_factor.
       num_convs: forwarded to DeepMaskHead.
       use_sync_bn: forwarded to DeepMaskHead.
     """
     head = deep_instance_heads.DeepMaskHead(
         num_classes=3, upsample_factor=upsample_factor, num_convs=num_convs,
         num_filters=16, use_separable_conv=False, activation='relu',
         use_sync_bn=use_sync_bn, norm_momentum=0.99, norm_epsilon=0.001,
         kernel_regularizer=None, bias_regularizer=None)
     crop = 14
     features = np.random.rand(2, 10, crop, crop, 16)
     classes = np.zeros((2, 10))
     side = crop * upsample_factor
     predicted = head([features, classes])
     self.assertAllEqual(predicted.numpy().shape, [2, 10, side, side])
def build_maskrcnn(input_specs: tf.keras.layers.InputSpec,
                   model_config: deep_mask_head_rcnn_config.DeepMaskHeadRCNN,
                   l2_regularizer: tf.keras.regularizers.Regularizer = None):  # pytype: disable=annotation-type-mismatch  # typed-keras
  """Builds a Mask R-CNN model whose mask branch uses `DeepMaskHead`.

  Args:
    input_specs: input spec passed to the backbone factory.
    model_config: model configuration. It supplies the backbone, decoder,
      anchor, RPN, detection, ROI and detection-generator settings, plus the
      mask-branch settings when `include_mask` is set.
    l2_regularizer: optional regularizer. It is passed to the backbone and
      decoder factories and used as `kernel_regularizer` for the RPN,
      detection and mask heads.

  Returns:
    A `DeepMaskRCNNModel`. When `model_config.include_mask` is false, its mask
    head, mask sampler and mask ROI aligner are `None`.
  """
  norm_activation_config = model_config.norm_activation
  backbone = backbones.factory.build_backbone(
      input_specs=input_specs,
      backbone_config=model_config.backbone,
      norm_activation_config=norm_activation_config,
      l2_regularizer=l2_regularizer)

  decoder = decoder_factory.build_decoder(
      input_specs=backbone.output_specs,
      model_config=model_config,
      l2_regularizer=l2_regularizer)

  rpn_head_config = model_config.rpn_head
  roi_generator_config = model_config.roi_generator
  roi_sampler_config = model_config.roi_sampler
  roi_aligner_config = model_config.roi_aligner
  detection_head_config = model_config.detection_head
  generator_config = model_config.detection_generator
  # One anchor per (aspect ratio, scale) pair at every feature location.
  num_anchors_per_location = (
      len(model_config.anchor.aspect_ratios) * model_config.anchor.num_scales)

  rpn_head = dense_prediction_heads.RPNHead(
      min_level=model_config.min_level,
      max_level=model_config.max_level,
      num_anchors_per_location=num_anchors_per_location,
      num_convs=rpn_head_config.num_convs,
      num_filters=rpn_head_config.num_filters,
      use_separable_conv=rpn_head_config.use_separable_conv,
      activation=norm_activation_config.activation,
      use_sync_bn=norm_activation_config.use_sync_bn,
      norm_momentum=norm_activation_config.norm_momentum,
      norm_epsilon=norm_activation_config.norm_epsilon,
      kernel_regularizer=l2_regularizer)

  detection_head = instance_heads.DetectionHead(
      num_classes=model_config.num_classes,
      num_convs=detection_head_config.num_convs,
      num_filters=detection_head_config.num_filters,
      use_separable_conv=detection_head_config.use_separable_conv,
      num_fcs=detection_head_config.num_fcs,
      fc_dims=detection_head_config.fc_dims,
      activation=norm_activation_config.activation,
      use_sync_bn=norm_activation_config.use_sync_bn,
      norm_momentum=norm_activation_config.norm_momentum,
      norm_epsilon=norm_activation_config.norm_epsilon,
      kernel_regularizer=l2_regularizer)

  roi_generator_obj = roi_generator.MultilevelROIGenerator(
      pre_nms_top_k=roi_generator_config.pre_nms_top_k,
      pre_nms_score_threshold=roi_generator_config.pre_nms_score_threshold,
      pre_nms_min_size_threshold=(
          roi_generator_config.pre_nms_min_size_threshold),
      nms_iou_threshold=roi_generator_config.nms_iou_threshold,
      num_proposals=roi_generator_config.num_proposals,
      test_pre_nms_top_k=roi_generator_config.test_pre_nms_top_k,
      test_pre_nms_score_threshold=(
          roi_generator_config.test_pre_nms_score_threshold),
      test_pre_nms_min_size_threshold=(
          roi_generator_config.test_pre_nms_min_size_threshold),
      test_nms_iou_threshold=roi_generator_config.test_nms_iou_threshold,
      test_num_proposals=roi_generator_config.test_num_proposals,
      use_batched_nms=roi_generator_config.use_batched_nms)

  roi_sampler_obj = roi_sampler.ROISampler(
      mix_gt_boxes=roi_sampler_config.mix_gt_boxes,
      num_sampled_rois=roi_sampler_config.num_sampled_rois,
      foreground_fraction=roi_sampler_config.foreground_fraction,
      foreground_iou_threshold=roi_sampler_config.foreground_iou_threshold,
      background_iou_high_threshold=(
          roi_sampler_config.background_iou_high_threshold),
      background_iou_low_threshold=(
          roi_sampler_config.background_iou_low_threshold))

  roi_aligner_obj = roi_aligner.MultilevelROIAligner(
      crop_size=roi_aligner_config.crop_size,
      sample_offset=roi_aligner_config.sample_offset)

  detection_generator_obj = detection_generator.DetectionGenerator(
      apply_nms=True,
      pre_nms_top_k=generator_config.pre_nms_top_k,
      pre_nms_score_threshold=generator_config.pre_nms_score_threshold,
      nms_iou_threshold=generator_config.nms_iou_threshold,
      max_num_detections=generator_config.max_num_detections,
      nms_version=generator_config.nms_version)

  if model_config.include_mask:
    mask_head_config = model_config.mask_head
    mask_roi_aligner_config = model_config.mask_roi_aligner
    mask_head = deep_instance_heads.DeepMaskHead(
        num_classes=model_config.num_classes,
        upsample_factor=mask_head_config.upsample_factor,
        num_convs=mask_head_config.num_convs,
        num_filters=mask_head_config.num_filters,
        use_separable_conv=mask_head_config.use_separable_conv,
        activation=norm_activation_config.activation,
        # Forward the shared sync-BN setting, as the RPN and detection heads
        # do, so the mask head's normalization matches the rest of the model.
        use_sync_bn=norm_activation_config.use_sync_bn,
        norm_momentum=norm_activation_config.norm_momentum,
        norm_epsilon=norm_activation_config.norm_epsilon,
        kernel_regularizer=l2_regularizer,
        class_agnostic=mask_head_config.class_agnostic,
        convnet_variant=mask_head_config.convnet_variant)

    mask_sampler_obj = mask_sampler.MaskSampler(
        # Targets must match the mask head's output resolution: the head
        # upsamples each aligned crop by `upsample_factor`.
        mask_target_size=(
            mask_roi_aligner_config.crop_size *
            mask_head_config.upsample_factor),
        num_sampled_masks=model_config.mask_sampler.num_sampled_masks)

    mask_roi_aligner_obj = roi_aligner.MultilevelROIAligner(
        crop_size=mask_roi_aligner_config.crop_size,
        sample_offset=mask_roi_aligner_config.sample_offset)
  else:
    mask_head = None
    mask_sampler_obj = None
    mask_roi_aligner_obj = None

  model = deep_maskrcnn_model.DeepMaskRCNNModel(
      backbone=backbone,
      decoder=decoder,
      rpn_head=rpn_head,
      detection_head=detection_head,
      roi_generator=roi_generator_obj,
      roi_sampler=roi_sampler_obj,
      roi_aligner=roi_aligner_obj,
      detection_generator=detection_generator_obj,
      mask_head=mask_head,
      mask_sampler=mask_sampler_obj,
      mask_roi_aligner=mask_roi_aligner_obj,
      use_gt_boxes_for_masks=model_config.use_gt_boxes_for_masks)
  return model