def test_forward(self, level, feature_fusion, decoder_min_level,
                     decoder_max_level):
        """Verifies the shape of the logits returned by `SegmentationHead`.

        Random backbone and decoder feature maps are stored in dicts keyed by
        level string. A level '2' entry is added to both dicts when
        `feature_fusion` is 'panoptic_fpn_fusion'. The head is called on the
        `(backbone_features, decoder_features)` tuple.

        Args:
          level: Level handed to the head; looked up as `str(level)` in the
            decoder features to find the expected spatial size.
          feature_fusion: Forwarded to the head's `feature_fusion` argument.
          decoder_min_level: Forwarded to the head's `decoder_min_level`.
          decoder_max_level: Forwarded to the head's `decoder_max_level`.
        """
        backbone_features = {
            '3': np.random.rand(2, 128, 128, 16),
            '4': np.random.rand(2, 64, 64, 16),
            '5': np.random.rand(2, 32, 32, 16),
        }
        decoder_features = {
            '3': np.random.rand(2, 128, 128, 64),
            '4': np.random.rand(2, 64, 64, 64),
            '5': np.random.rand(2, 32, 32, 64),
            '6': np.random.rand(2, 16, 16, 64),
        }

        if feature_fusion == 'panoptic_fpn_fusion':
            backbone_features['2'] = np.random.rand(2, 256, 256, 16)
            decoder_features['2'] = np.random.rand(2, 256, 256, 64)

        head = segmentation_heads.SegmentationHead(
            num_classes=10,
            level=level,
            feature_fusion=feature_fusion,
            decoder_min_level=decoder_min_level,
            decoder_max_level=decoder_max_level,
            num_decoder_filters=64)

        logits = head((backbone_features, decoder_features))

        # The feature dicts are keyed by strings, so the membership test must
        # use `str(level)`. Testing the raw `level` would silently skip the
        # shape assertion whenever `level` is an int.
        level_key = str(level)
        if level_key in decoder_features:
            reference = decoder_features[level_key]
            self.assertAllEqual(
                logits.numpy().shape,
                [2, reference.shape[1], reference.shape[2], 10])
Exemple #2
0
def build_segmentation_model(
        input_specs: tf.keras.layers.InputSpec,
        model_config: segmentation_cfg.SemanticSegmentationModel,
        l2_regularizer: tf.keras.regularizers.Regularizer = None):
    """Assembles a `SegmentationModel` from a model config.

    A backbone is built from `input_specs`, a decoder is built on the
    backbone's `output_specs`, and a `SegmentationHead` is configured from
    `model_config.head` and `model_config.norm_activation`.

    Args:
      input_specs: Input spec forwarded to the backbone factory.
      model_config: Segmentation model config; `num_classes`, `head` and
        `norm_activation` are read here, and the whole config is forwarded
        to the backbone and decoder factories.
      l2_regularizer: Regularizer forwarded to backbone, decoder and head
        (as the head's `kernel_regularizer`). Defaults to None.

    Returns:
      The `segmentation_model.SegmentationModel` wrapping the three parts.
    """
    # Both factories receive the same config/regularizer pair.
    factory_kwargs = dict(model_config=model_config,
                          l2_regularizer=l2_regularizer)
    backbone = backbones.factory.build_backbone(input_specs=input_specs,
                                                **factory_kwargs)
    decoder = decoder_factory.build_decoder(input_specs=backbone.output_specs,
                                            **factory_kwargs)

    head_cfg = model_config.head
    norm_cfg = model_config.norm_activation

    head_kwargs = dict(
        num_classes=model_config.num_classes,
        level=head_cfg.level,
        num_convs=head_cfg.num_convs,
        num_filters=head_cfg.num_filters,
        upsample_factor=head_cfg.upsample_factor,
        feature_fusion=head_cfg.feature_fusion,
        low_level=head_cfg.low_level,
        low_level_num_filters=head_cfg.low_level_num_filters,
        activation=norm_cfg.activation,
        use_sync_bn=norm_cfg.use_sync_bn,
        norm_momentum=norm_cfg.norm_momentum,
        norm_epsilon=norm_cfg.norm_epsilon,
        kernel_regularizer=l2_regularizer,
    )
    head = segmentation_heads.SegmentationHead(**head_kwargs)

    return segmentation_model.SegmentationModel(backbone, decoder, head)
 def test_forward(self, level):
     """Checks that the head's logits match the spatial size of `level`.

     The head is called on a dict of random feature maps keyed by level
     string; the logits are expected to have batch size 2, the height and
     width of the feature at `str(level)`, and 10 channels.
     """
     head = segmentation_heads.SegmentationHead(num_classes=10, level=level)
     features = {
         '3': np.random.rand(2, 128, 128, 16),
         '4': np.random.rand(2, 64, 64, 16),
     }

     actual_shape = head(features).numpy().shape

     reference = features[str(level)]
     expected_shape = [2, reference.shape[1], reference.shape[2], 10]
     self.assertAllEqual(actual_shape, expected_shape)
Exemple #4
0
def build_segmentation_model(
    input_specs: tf.keras.layers.InputSpec,
    model_config: segmentation_cfg.SemanticSegmentationModel,
    l2_regularizer: Optional[tf.keras.regularizers.Regularizer] = None,
    backbone: Optional[tf.keras.Model] = None,
    decoder: Optional[tf.keras.Model] = None
) -> tf.keras.Model:
    """Builds Segmentation model.

    Args:
      input_specs: Input spec forwarded to the backbone factory when no
        `backbone` is supplied.
      model_config: Segmentation model config. `backbone`, `head`,
        `norm_activation`, `num_classes` and `mask_scoring_head` are read
        here; the whole config is forwarded to the decoder factory.
      l2_regularizer: Optional regularizer forwarded to every component that
        is built here.
      backbone: Optional pre-built backbone. When None, one is built from
        `model_config.backbone`.
      decoder: Optional pre-built decoder. When None, one is built on the
        backbone's `output_specs`.

    Returns:
      A `segmentation_model.SegmentationModel` made of the backbone, decoder,
      segmentation head and, when `model_config.mask_scoring_head` is set, a
      `MaskScoring` head.
    """
    norm_activation_config = model_config.norm_activation

    # Compare against None explicitly so a supplied component is never
    # rebuilt merely because it happens to evaluate as falsy.
    if backbone is None:
        backbone = backbones.factory.build_backbone(
            input_specs=input_specs,
            backbone_config=model_config.backbone,
            norm_activation_config=norm_activation_config,
            l2_regularizer=l2_regularizer)

    if decoder is None:
        decoder = decoders.factory.build_decoder(
            input_specs=backbone.output_specs,
            model_config=model_config,
            l2_regularizer=l2_regularizer)

    head_config = model_config.head

    head = segmentation_heads.SegmentationHead(
        num_classes=model_config.num_classes,
        level=head_config.level,
        num_convs=head_config.num_convs,
        prediction_kernel_size=head_config.prediction_kernel_size,
        num_filters=head_config.num_filters,
        use_depthwise_convolution=head_config.use_depthwise_convolution,
        upsample_factor=head_config.upsample_factor,
        feature_fusion=head_config.feature_fusion,
        low_level=head_config.low_level,
        low_level_num_filters=head_config.low_level_num_filters,
        activation=norm_activation_config.activation,
        use_sync_bn=norm_activation_config.use_sync_bn,
        norm_momentum=norm_activation_config.norm_momentum,
        norm_epsilon=norm_activation_config.norm_epsilon,
        kernel_regularizer=l2_regularizer)

    # The mask scoring head is optional and only built when configured.
    mask_scoring_head = None
    if model_config.mask_scoring_head:
        mask_scoring_head = segmentation_heads.MaskScoring(
            num_classes=model_config.num_classes,
            **model_config.mask_scoring_head.as_dict(),
            activation=norm_activation_config.activation,
            use_sync_bn=norm_activation_config.use_sync_bn,
            norm_momentum=norm_activation_config.norm_momentum,
            norm_epsilon=norm_activation_config.norm_epsilon,
            kernel_regularizer=l2_regularizer)

    model = segmentation_model.SegmentationModel(
        backbone, decoder, head, mask_scoring_head=mask_scoring_head)
    return model
Exemple #5
0
    def test_forward(self, level, feature_fusion):
        """Verifies the shape of the logits returned by `SegmentationHead`.

        The head is called with random backbone and decoder feature dicts
        keyed by level string. When the decoder features contain an entry
        for `level`, the logits must have batch size 2, that entry's height
        and width, and 10 channels.

        Args:
          level: Level handed to the head; looked up as `str(level)` in the
            decoder features.
          feature_fusion: Forwarded to the head's `feature_fusion` argument.
        """
        head = segmentation_heads.SegmentationHead(
            num_classes=10, level=level, feature_fusion=feature_fusion)
        backbone_features = {
            '3': np.random.rand(2, 128, 128, 16),
            '4': np.random.rand(2, 64, 64, 16),
        }
        decoder_features = {
            '3': np.random.rand(2, 128, 128, 16),
            '4': np.random.rand(2, 64, 64, 16),
        }
        logits = head(backbone_features, decoder_features)

        # The feature dicts are keyed by strings, so the membership test must
        # use `str(level)`. Testing the raw `level` would silently skip the
        # shape assertion whenever `level` is an int.
        level_key = str(level)
        if level_key in decoder_features:
            reference = decoder_features[level_key]
            self.assertAllEqual(
                logits.numpy().shape,
                [2, reference.shape[1], reference.shape[2], 10])
Exemple #6
0
    def test_serialize_deserialize(self):
        """Round-trips a `SegmentationModel` through its config.

        A model is built from a ResNet backbone, an FPN decoder (levels 3-7)
        and a segmentation head, rebuilt with `from_config`, dumped to JSON,
        and the configs of the original and the rebuilt model are compared.
        """
        num_classes = 3
        backbone = backbones.ResNet(model_id=50)
        decoder = fpn.FPN(
            input_specs=backbone.output_specs, min_level=3, max_level=7)
        head = segmentation_heads.SegmentationHead(num_classes, level=3)
        original = segmentation_model.SegmentationModel(
            backbone=backbone, decoder=decoder, head=head)

        restored = segmentation_model.SegmentationModel.from_config(
            original.get_config())

        # The rebuilt model must be convertible to JSON; the result itself is
        # not inspected.
        restored.to_json()

        # A successful round trip leaves the config unchanged.
        self.assertAllEqual(original.get_config(), restored.get_config())
Exemple #7
0
    def test_segmentation_network_creation(self, input_size, level):
        """Builds a segmentation network and checks its output shape.

        A ResNet backbone, an FPN decoder (levels 2-7) and a segmentation
        head at `level` are combined into a `SegmentationModel`. It is run on
        a random batch of two square images of side `input_size`. The logits
        are expected to be downsampled by `2**level` in both spatial
        dimensions and to carry `num_classes` channels.
        """
        num_classes = 10
        inputs = np.random.rand(2, input_size, input_size, 3)
        tf.keras.backend.set_image_data_format('channels_last')

        backbone = backbones.ResNet(model_id=50)
        decoder = fpn.FPN(
            input_specs=backbone.output_specs, min_level=2, max_level=7)
        head = segmentation_heads.SegmentationHead(num_classes, level=level)
        model = segmentation_model.SegmentationModel(
            backbone=backbone, decoder=decoder, head=head)

        logits = model(inputs)

        output_size = input_size // (2**level)
        expected_shape = [2, output_size, output_size, num_classes]
        self.assertAllEqual(expected_shape, logits.numpy().shape)
Exemple #8
0
def build_segmentation_model(
    input_specs: tf.keras.layers.InputSpec,
    model_config: segmentation_cfg.SemanticSegmentationModel,
    l2_regularizer: tf.keras.regularizers.Regularizer = None
) -> tf.keras.Model:  # pytype: disable=annotation-type-mismatch  # typed-keras
    """Assembles a `SegmentationModel` from a model config.

    Args:
      input_specs: Input spec forwarded to the backbone factory.
      model_config: Segmentation model config. `backbone`, `head`,
        `norm_activation` and `num_classes` are read here; the whole config
        is forwarded to the decoder factory.
      l2_regularizer: Regularizer forwarded to backbone, decoder and head
        (as the head's `kernel_regularizer`). Defaults to None.

    Returns:
      The `segmentation_model.SegmentationModel` made of backbone, decoder
      and segmentation head.
    """
    norm_cfg = model_config.norm_activation

    backbone = backbones.factory.build_backbone(
        input_specs=input_specs,
        backbone_config=model_config.backbone,
        norm_activation_config=norm_cfg,
        l2_regularizer=l2_regularizer)

    # The decoder consumes whatever feature specs the backbone exposes.
    decoder = decoders.factory.build_decoder(
        input_specs=backbone.output_specs,
        model_config=model_config,
        l2_regularizer=l2_regularizer)

    head_cfg = model_config.head
    head_kwargs = dict(
        num_classes=model_config.num_classes,
        level=head_cfg.level,
        num_convs=head_cfg.num_convs,
        prediction_kernel_size=head_cfg.prediction_kernel_size,
        num_filters=head_cfg.num_filters,
        use_depthwise_convolution=head_cfg.use_depthwise_convolution,
        upsample_factor=head_cfg.upsample_factor,
        feature_fusion=head_cfg.feature_fusion,
        low_level=head_cfg.low_level,
        low_level_num_filters=head_cfg.low_level_num_filters,
        activation=norm_cfg.activation,
        use_sync_bn=norm_cfg.use_sync_bn,
        norm_momentum=norm_cfg.norm_momentum,
        norm_epsilon=norm_cfg.norm_epsilon,
        kernel_regularizer=l2_regularizer,
    )
    head = segmentation_heads.SegmentationHead(**head_kwargs)

    return segmentation_model.SegmentationModel(backbone, decoder, head)
Exemple #9
0
    def test_forward(self, strategy, training, shared_backbone,
                     shared_decoder):
        """Runs a forward pass of `PanopticMaskRCNNModel` and checks outputs.

        The model is built and called inside `strategy.scope()`. The
        segmentation branch gets its own ResNet backbone unless
        `shared_backbone` is set, and its own ASPP decoder unless the decoder
        is shared. The decoder can only be shared when the backbone is
        shared as well.

        Args:
          strategy: Object providing the `scope()` context manager in which
            the model is created and called.
          training: Forwarded to the model call; selects which output keys
            are asserted.
          shared_backbone: Whether the segmentation branch reuses the
            detection backbone.
          shared_decoder: Whether the segmentation branch reuses the
            detection decoder (only honoured with a shared backbone).
        """
        num_classes = 3
        min_level = 3
        max_level = 4
        num_scales = 3
        aspect_ratios = [1.0]
        anchor_size = 3
        segmentation_resnet_model_id = 101
        segmentation_output_stride = 16
        aspp_dilation_rates = [6, 12, 18]
        # `np.math` is a deprecated alias that is removed in NumPy 2.0;
        # `np.log2` yields the same value for the power-of-two stride.
        aspp_decoder_level = int(np.log2(segmentation_output_stride))
        fpn_decoder_level = 3

        class_agnostic_bbox_pred = False
        cascade_class_ensemble = False

        image_size = (256, 256)
        images = np.random.rand(2, image_size[0], image_size[1], 3)
        image_shape = np.array([[224, 100], [100, 224]])
        # A decoder can only be shared on top of a shared backbone.
        shared_decoder = shared_decoder and shared_backbone
        with strategy.scope():

            anchor_boxes = anchor.Anchor(
                min_level=min_level,
                max_level=max_level,
                num_scales=num_scales,
                aspect_ratios=aspect_ratios,
                anchor_size=anchor_size,
                image_size=image_size).multilevel_boxes

            num_anchors_per_location = len(aspect_ratios) * num_scales

            input_specs = tf.keras.layers.InputSpec(
                shape=[None, None, None, 3])
            backbone = resnet.ResNet(model_id=50, input_specs=input_specs)
            decoder = fpn.FPN(min_level=min_level,
                              max_level=max_level,
                              input_specs=backbone.output_specs)
            rpn_head = dense_prediction_heads.RPNHead(
                min_level=min_level,
                max_level=max_level,
                num_anchors_per_location=num_anchors_per_location)
            detection_head = instance_heads.DetectionHead(
                num_classes=num_classes,
                class_agnostic_bbox_pred=class_agnostic_bbox_pred)
            roi_generator_obj = roi_generator.MultilevelROIGenerator()

            roi_sampler_obj = roi_sampler.ROISampler()
            roi_aligner_obj = roi_aligner.MultilevelROIAligner()
            detection_generator_obj = detection_generator.DetectionGenerator()
            mask_head = instance_heads.MaskHead(num_classes=num_classes,
                                                upsample_factor=2)
            mask_sampler_obj = mask_sampler.MaskSampler(mask_target_size=28,
                                                        num_sampled_masks=1)
            mask_roi_aligner_obj = roi_aligner.MultilevelROIAligner(
                crop_size=14)

            if shared_backbone:
                segmentation_backbone = None
            else:
                segmentation_backbone = resnet.ResNet(
                    model_id=segmentation_resnet_model_id)
            if not shared_decoder:
                level = aspp_decoder_level
                segmentation_decoder = aspp.ASPP(
                    level=level, dilation_rates=aspp_dilation_rates)
            else:
                level = fpn_decoder_level
                segmentation_decoder = None
            segmentation_head = segmentation_heads.SegmentationHead(
                num_classes=2,  # stuff and common class for things,
                level=level,
                num_convs=2)

            model = panoptic_maskrcnn_model.PanopticMaskRCNNModel(
                backbone,
                decoder,
                rpn_head,
                detection_head,
                roi_generator_obj,
                roi_sampler_obj,
                roi_aligner_obj,
                detection_generator_obj,
                mask_head,
                mask_sampler_obj,
                mask_roi_aligner_obj,
                segmentation_backbone=segmentation_backbone,
                segmentation_decoder=segmentation_decoder,
                segmentation_head=segmentation_head,
                class_agnostic_bbox_pred=class_agnostic_bbox_pred,
                cascade_class_ensemble=cascade_class_ensemble,
                min_level=min_level,
                max_level=max_level,
                num_scales=num_scales,
                aspect_ratios=aspect_ratios,
                anchor_size=anchor_size)

            gt_boxes = np.array(
                [[[10, 10, 15, 15], [2.5, 2.5, 7.5, 7.5], [-1, -1, -1, -1]],
                 [[100, 100, 150, 150], [-1, -1, -1, -1], [-1, -1, -1, -1]]],
                dtype=np.float32)
            gt_classes = np.array([[2, 1, -1], [1, -1, -1]], dtype=np.int32)
            gt_masks = np.ones((2, 3, 100, 100))

            results = model(images,
                            image_shape,
                            anchor_boxes,
                            gt_boxes,
                            gt_classes,
                            gt_masks,
                            training=training)

        self.assertIn('rpn_boxes', results)
        self.assertIn('rpn_scores', results)
        if training:
            self.assertIn('class_targets', results)
            self.assertIn('box_targets', results)
            self.assertIn('class_outputs', results)
            self.assertIn('box_outputs', results)
            self.assertIn('mask_outputs', results)
        else:
            self.assertIn('detection_boxes', results)
            self.assertIn('detection_scores', results)
            self.assertIn('detection_classes', results)
            self.assertIn('num_detections', results)
            self.assertIn('detection_masks', results)
            self.assertIn('segmentation_outputs', results)
            # The segmentation logits are expected at 1/2**level of the
            # input resolution, with the 2 classes given to the head.
            self.assertAllEqual([
                2, image_size[0] // (2**level), image_size[1] // (2**level), 2
            ], results['segmentation_outputs'].numpy().shape)
Exemple #10
0
def build_panoptic_maskrcnn(
    input_specs: tf.keras.layers.InputSpec,
    model_config: panoptic_maskrcnn_cfg.PanopticMaskRCNN,
    l2_regularizer: tf.keras.regularizers.Regularizer = None
) -> tf.keras.Model:  # pytype: disable=annotation-type-mismatch  # typed-keras
    """Assembles the Panoptic Mask R-CNN model from its config.

    The Mask R-CNN model is built first. Then the parts of the semantic
    segmentation branch that are not shared with it (backbone and/or decoder)
    are built, followed by the segmentation head and the optional panoptic
    post-processing generator. Finally everything is combined into one
    `PanopticMaskRCNNModel`.

    Args:
      input_specs: `tf.keras.layers.InputSpec` specs of the input tensor.
      model_config: Config instance for the panoptic maskrcnn model.
      l2_regularizer: Optional `tf.keras.regularizers.Regularizer`, forwarded
        to every component built here.

    Returns:
      tf.keras.Model for the panoptic segmentation model.
    """
    norm_cfg = model_config.norm_activation
    seg_cfg = model_config.segmentation_model

    # Instance segmentation part.
    maskrcnn = deep_mask_head_rcnn.build_maskrcnn(
        input_specs=input_specs,
        model_config=model_config,
        l2_regularizer=l2_regularizer)

    # Semantic segmentation backbone: reuse the Mask R-CNN one when shared.
    if model_config.shared_backbone:
        segmentation_backbone = None
        seg_decoder_input_specs = maskrcnn.backbone.output_specs
    else:
        segmentation_backbone = backbones.factory.build_backbone(
            input_specs=input_specs,
            backbone_config=seg_cfg.backbone,
            norm_activation_config=norm_cfg,
            l2_regularizer=l2_regularizer)
        seg_decoder_input_specs = segmentation_backbone.output_specs

    # Semantic segmentation decoder: reuse the Mask R-CNN one when shared.
    # Either way, keep the config of the decoder that feeds the head.
    if model_config.shared_decoder:
        segmentation_decoder = None
        decoder_config = maskrcnn.decoder.get_config()
    else:
        segmentation_decoder = decoder_factory.build_decoder(
            input_specs=seg_decoder_input_specs,
            model_config=seg_cfg,
            l2_regularizer=l2_regularizer)
        decoder_config = segmentation_decoder.get_config()

    seg_head_cfg = seg_cfg.head
    det_head_cfg = model_config.detection_head
    post_cfg = model_config.panoptic_segmentation_generator

    seg_head_kwargs = dict(
        num_classes=seg_cfg.num_classes,
        level=seg_head_cfg.level,
        num_convs=seg_head_cfg.num_convs,
        prediction_kernel_size=seg_head_cfg.prediction_kernel_size,
        num_filters=seg_head_cfg.num_filters,
        upsample_factor=seg_head_cfg.upsample_factor,
        feature_fusion=seg_head_cfg.feature_fusion,
        decoder_min_level=seg_head_cfg.decoder_min_level,
        decoder_max_level=seg_head_cfg.decoder_max_level,
        low_level=seg_head_cfg.low_level,
        low_level_num_filters=seg_head_cfg.low_level_num_filters,
        activation=norm_cfg.activation,
        use_sync_bn=norm_cfg.use_sync_bn,
        norm_momentum=norm_cfg.norm_momentum,
        norm_epsilon=norm_cfg.norm_epsilon,
        num_decoder_filters=decoder_config['num_filters'],
        kernel_regularizer=l2_regularizer,
    )
    segmentation_head = segmentation_heads.SegmentationHead(**seg_head_kwargs)

    # The post-processing generator is only created on request.
    panoptic_segmentation_generator_obj = None
    if model_config.generate_panoptic_masks:
        generator_kwargs = dict(
            max_num_detections=(
                model_config.detection_generator.max_num_detections),
            mask_binarize_threshold=post_cfg.mask_binarize_threshold,
            output_size=post_cfg.output_size,
            stuff_classes_offset=model_config.stuff_classes_offset,
            score_threshold=post_cfg.score_threshold,
            things_overlap_threshold=post_cfg.things_overlap_threshold,
            things_class_label=post_cfg.things_class_label,
            stuff_area_threshold=post_cfg.stuff_area_threshold,
            void_class_label=post_cfg.void_class_label,
            void_instance_id=post_cfg.void_instance_id,
            rescale_predictions=post_cfg.rescale_predictions,
        )
        panoptic_segmentation_generator_obj = (
            panoptic_segmentation_generator.PanopticSegmentationGenerator(
                **generator_kwargs))

    # Combine the Mask R-CNN parts with the segmentation branch.
    anchor_cfg = model_config.anchor
    return panoptic_maskrcnn_model.PanopticMaskRCNNModel(
        backbone=maskrcnn.backbone,
        decoder=maskrcnn.decoder,
        rpn_head=maskrcnn.rpn_head,
        detection_head=maskrcnn.detection_head,
        roi_generator=maskrcnn.roi_generator,
        roi_sampler=maskrcnn.roi_sampler,
        roi_aligner=maskrcnn.roi_aligner,
        detection_generator=maskrcnn.detection_generator,
        panoptic_segmentation_generator=panoptic_segmentation_generator_obj,
        mask_head=maskrcnn.mask_head,
        mask_sampler=maskrcnn.mask_sampler,
        mask_roi_aligner=maskrcnn.mask_roi_aligner,
        segmentation_backbone=segmentation_backbone,
        segmentation_decoder=segmentation_decoder,
        segmentation_head=segmentation_head,
        class_agnostic_bbox_pred=det_head_cfg.class_agnostic_bbox_pred,
        cascade_class_ensemble=det_head_cfg.cascade_class_ensemble,
        min_level=model_config.min_level,
        max_level=model_config.max_level,
        num_scales=anchor_cfg.num_scales,
        aspect_ratios=anchor_cfg.aspect_ratios,
        anchor_size=anchor_cfg.anchor_size)
    def test_build_model(self,
                         use_separable_conv,
                         build_anchor_boxes,
                         shared_backbone,
                         shared_decoder,
                         is_training=True):
        """Builds a `PanopticMaskRCNNModel` and runs it once.

        Only construction and a single call are exercised here; the outputs
        of the call are not inspected.

        Args:
          use_separable_conv: Forwarded to the FPN decoder.
          build_anchor_boxes: Whether anchor boxes are precomputed and passed
            to the model. They are always built when not training.
          shared_backbone: Whether the segmentation branch reuses the
            detection backbone.
          shared_decoder: Whether the segmentation branch reuses the
            detection decoder (only honoured with a shared backbone).
          is_training: Forwarded to the model call as `training`.
        """
        num_classes = 3
        min_level, max_level = 2, 6
        num_scales = 3
        aspect_ratios = [1.0]
        anchor_size = 3
        resnet_model_id = 50
        segmentation_resnet_model_id = 50
        aspp_dilation_rates = [6, 12, 18]
        aspp_decoder_level = 2
        fpn_decoder_level = 2
        num_anchors_per_location = num_scales * len(aspect_ratios)
        image_size = 128

        images = tf.random.normal([2, image_size, image_size, 3])
        # Both images of the batch carry the same info rows.
        single_image_info = [[image_size, image_size],
                             [image_size, image_size], [1, 1], [0, 0]]
        image_info = tf.convert_to_tensor(
            [single_image_info, single_image_info])

        # A decoder can only be shared on top of a shared backbone.
        shared_decoder = shared_decoder and shared_backbone

        anchor_boxes = None
        if build_anchor_boxes or not is_training:
            anchor_boxes = anchor.Anchor(
                min_level=min_level,
                max_level=max_level,
                num_scales=num_scales,
                aspect_ratios=aspect_ratios,
                anchor_size=anchor_size,
                image_size=(image_size, image_size)).multilevel_boxes
            # Add a leading batch axis and repeat each level for both images.
            for level_key in anchor_boxes:
                batched = tf.expand_dims(anchor_boxes[level_key], axis=0)
                anchor_boxes[level_key] = tf.tile(batched, [2, 1, 1, 1])

        backbone = resnet.ResNet(model_id=resnet_model_id)
        decoder = fpn.FPN(
            input_specs=backbone.output_specs,
            min_level=min_level,
            max_level=max_level,
            use_separable_conv=use_separable_conv)
        rpn_head = dense_prediction_heads.RPNHead(
            min_level=min_level,
            max_level=max_level,
            num_anchors_per_location=num_anchors_per_location,
            num_convs=1)
        detection_head = instance_heads.DetectionHead(num_classes=num_classes)
        roi_generator_obj = roi_generator.MultilevelROIGenerator()
        roi_sampler_obj = roi_sampler.ROISampler()
        roi_aligner_obj = roi_aligner.MultilevelROIAligner()
        detection_generator_obj = detection_generator.DetectionGenerator()
        panoptic_segmentation_generator_obj = (
            panoptic_segmentation_generator.PanopticSegmentationGenerator(
                output_size=[image_size, image_size],
                max_num_detections=100,
                stuff_classes_offset=90))
        mask_head = instance_heads.MaskHead(
            num_classes=num_classes, upsample_factor=2)
        mask_sampler_obj = mask_sampler.MaskSampler(
            mask_target_size=28, num_sampled_masks=1)
        mask_roi_aligner_obj = roi_aligner.MultilevelROIAligner(crop_size=14)

        segmentation_backbone = None
        if not shared_backbone:
            segmentation_backbone = resnet.ResNet(
                model_id=segmentation_resnet_model_id)

        if shared_decoder:
            feature_fusion = 'panoptic_fpn_fusion'
            level = fpn_decoder_level
            segmentation_decoder = None
        else:
            feature_fusion = 'deeplabv3plus'
            level = aspp_decoder_level
            segmentation_decoder = aspp.ASPP(
                level=level, dilation_rates=aspp_dilation_rates)

        segmentation_head = segmentation_heads.SegmentationHead(
            num_classes=2,  # stuff and common class for things,
            level=level,
            feature_fusion=feature_fusion,
            decoder_min_level=min_level,
            decoder_max_level=max_level,
            num_convs=2)

        model = panoptic_maskrcnn_model.PanopticMaskRCNNModel(
            backbone,
            decoder,
            rpn_head,
            detection_head,
            roi_generator_obj,
            roi_sampler_obj,
            roi_aligner_obj,
            detection_generator_obj,
            panoptic_segmentation_generator_obj,
            mask_head,
            mask_sampler_obj,
            mask_roi_aligner_obj,
            segmentation_backbone=segmentation_backbone,
            segmentation_decoder=segmentation_decoder,
            segmentation_head=segmentation_head,
            min_level=min_level,
            max_level=max_level,
            num_scales=num_scales,
            aspect_ratios=aspect_ratios,
            anchor_size=anchor_size)

        gt_boxes = tf.convert_to_tensor(
            [[[10, 10, 15, 15], [2.5, 2.5, 7.5, 7.5], [-1, -1, -1, -1]],
             [[100, 100, 150, 150], [-1, -1, -1, -1], [-1, -1, -1, -1]]],
            dtype=tf.float32)
        gt_classes = tf.convert_to_tensor(
            [[2, 1, -1], [1, -1, -1]], dtype=tf.int32)
        gt_masks = tf.ones((2, 3, 100, 100))

        # Results will be checked in test_forward.
        model(images,
              image_info,
              anchor_boxes,
              gt_boxes,
              gt_classes,
              gt_masks,
              training=is_training)
    def test_checkpoint(self, shared_backbone, shared_decoder):
        """Checks `checkpoint_items` and partial checkpoint restoration.

        A `PanopticMaskRCNNModel` is built, its `checkpoint_items` are
        compared with the expected components, a checkpoint is saved to a
        temporary directory, and several partial checkpoints (backbone only,
        backbone + mask head, segmentation parts) are loaded from it.

        Args:
          shared_backbone: Whether the segmentation branch reuses the
            detection backbone.
          shared_decoder: Whether the segmentation branch reuses the
            detection decoder (only honoured with a shared backbone).
        """
        input_specs = tf.keras.layers.InputSpec(shape=[None, None, None, 3])
        backbone = resnet.ResNet(model_id=50, input_specs=input_specs)
        decoder = fpn.FPN(min_level=3,
                          max_level=7,
                          input_specs=backbone.output_specs)
        rpn_head = dense_prediction_heads.RPNHead(min_level=3,
                                                  max_level=7,
                                                  num_anchors_per_location=3)
        detection_head = instance_heads.DetectionHead(num_classes=2)
        roi_generator_obj = roi_generator.MultilevelROIGenerator()
        roi_sampler_obj = roi_sampler.ROISampler()
        roi_aligner_obj = roi_aligner.MultilevelROIAligner()
        detection_generator_obj = detection_generator.DetectionGenerator()
        panoptic_segmentation_generator_obj = (
            panoptic_segmentation_generator.PanopticSegmentationGenerator(
                output_size=[None, None],
                max_num_detections=100,
                stuff_classes_offset=90))
        segmentation_resnet_model_id = 101
        aspp_dilation_rates = [6, 12, 18]
        min_level = 2
        max_level = 6
        aspp_decoder_level = 2
        fpn_decoder_level = 2
        # A decoder can only be shared on top of a shared backbone.
        shared_decoder = shared_decoder and shared_backbone
        mask_head = instance_heads.MaskHead(num_classes=2, upsample_factor=2)
        mask_sampler_obj = mask_sampler.MaskSampler(mask_target_size=28,
                                                    num_sampled_masks=1)
        mask_roi_aligner_obj = roi_aligner.MultilevelROIAligner(crop_size=14)

        if shared_backbone:
            segmentation_backbone = None
        else:
            segmentation_backbone = resnet.ResNet(
                model_id=segmentation_resnet_model_id)
        if not shared_decoder:
            feature_fusion = 'deeplabv3plus'
            level = aspp_decoder_level
            segmentation_decoder = aspp.ASPP(
                level=level, dilation_rates=aspp_dilation_rates)
        else:
            feature_fusion = 'panoptic_fpn_fusion'
            level = fpn_decoder_level
            segmentation_decoder = None
        segmentation_head = segmentation_heads.SegmentationHead(
            num_classes=2,  # stuff and common class for things,
            level=level,
            feature_fusion=feature_fusion,
            decoder_min_level=min_level,
            decoder_max_level=max_level,
            num_convs=2)

        # `min_level` previously received `max_level` by mistake; pass the
        # matching local so the model gets the intended level range.
        model = panoptic_maskrcnn_model.PanopticMaskRCNNModel(
            backbone,
            decoder,
            rpn_head,
            detection_head,
            roi_generator_obj,
            roi_sampler_obj,
            roi_aligner_obj,
            detection_generator_obj,
            panoptic_segmentation_generator_obj,
            mask_head,
            mask_sampler_obj,
            mask_roi_aligner_obj,
            segmentation_backbone=segmentation_backbone,
            segmentation_decoder=segmentation_decoder,
            segmentation_head=segmentation_head,
            min_level=min_level,
            max_level=max_level,
            num_scales=3,
            aspect_ratios=[1.0],
            anchor_size=3)
        expect_checkpoint_items = dict(backbone=backbone,
                                       decoder=decoder,
                                       rpn_head=rpn_head,
                                       detection_head=[detection_head])
        expect_checkpoint_items['mask_head'] = mask_head
        # Segmentation backbone/decoder are only expected when not shared.
        if not shared_backbone:
            expect_checkpoint_items[
                'segmentation_backbone'] = segmentation_backbone
        if not shared_decoder:
            expect_checkpoint_items[
                'segmentation_decoder'] = segmentation_decoder
        expect_checkpoint_items['segmentation_head'] = segmentation_head
        self.assertAllEqual(expect_checkpoint_items, model.checkpoint_items)

        # Test save and load checkpoints.
        ckpt = tf.train.Checkpoint(model=model, **model.checkpoint_items)
        save_dir = self.create_tempdir().full_path
        ckpt.save(os.path.join(save_dir, 'ckpt'))

        partial_ckpt = tf.train.Checkpoint(backbone=backbone)
        partial_ckpt.read(tf.train.latest_checkpoint(
            save_dir)).expect_partial().assert_existing_objects_matched()

        partial_ckpt_mask = tf.train.Checkpoint(backbone=backbone,
                                                mask_head=mask_head)
        partial_ckpt_mask.restore(tf.train.latest_checkpoint(
            save_dir)).expect_partial().assert_existing_objects_matched()

        # Restore only those segmentation parts that exist as separate
        # objects for this sharing configuration.
        if not shared_backbone:
            partial_ckpt_segmentation = tf.train.Checkpoint(
                segmentation_backbone=segmentation_backbone,
                segmentation_decoder=segmentation_decoder,
                segmentation_head=segmentation_head)
        elif not shared_decoder:
            partial_ckpt_segmentation = tf.train.Checkpoint(
                segmentation_decoder=segmentation_decoder,
                segmentation_head=segmentation_head)
        else:
            partial_ckpt_segmentation = tf.train.Checkpoint(
                segmentation_head=segmentation_head)

        partial_ckpt_segmentation.restore(tf.train.latest_checkpoint(
            save_dir)).expect_partial().assert_existing_objects_matched()
    def test_serialize_deserialize(self, shared_backbone, shared_decoder):
        """Checks that the model config survives a get_config/from_config trip.

        Args:
          shared_backbone: When true, no dedicated segmentation backbone is
            built and `None` is handed to the model instead.
          shared_decoder: When true (and `shared_backbone` is true as well), no
            dedicated segmentation decoder is built.
        """
        # Settings of the segmentation branch and of the combined model.
        min_level, max_level = 2, 6
        aspp_decoder_level = 2
        fpn_decoder_level = 2
        aspp_dilation_rates = [6, 12, 18]
        segmentation_resnet_model_id = 101
        # Decoder sharing is only kept when the backbone is shared too.
        shared_decoder = shared_decoder and shared_backbone

        image_specs = tf.keras.layers.InputSpec(shape=[None, None, None, 3])
        backbone = resnet.ResNet(model_id=50, input_specs=image_specs)
        decoder = fpn.FPN(
            min_level=3, max_level=7, input_specs=backbone.output_specs)
        rpn_head = dense_prediction_heads.RPNHead(
            min_level=3, max_level=7, num_anchors_per_location=3)

        # These are handed to the model positionally, so the tuple order is
        # the order in which the constructor receives them.
        panoptic_generator_cls = (
            panoptic_segmentation_generator.PanopticSegmentationGenerator)
        positional_components = (
            backbone,
            decoder,
            rpn_head,
            instance_heads.DetectionHead(num_classes=2),
            roi_generator.MultilevelROIGenerator(),
            roi_sampler.ROISampler(),
            roi_aligner.MultilevelROIAligner(),
            detection_generator.DetectionGenerator(),
            panoptic_generator_cls(output_size=[None, None],
                                   max_num_detections=100,
                                   stuff_classes_offset=90),
            instance_heads.MaskHead(num_classes=2, upsample_factor=2),
            mask_sampler.MaskSampler(mask_target_size=28,
                                     num_sampled_masks=1),
            roi_aligner.MultilevelROIAligner(crop_size=14),
        )

        segmentation_backbone = (
            None if shared_backbone else
            resnet.ResNet(model_id=segmentation_resnet_model_id))
        if shared_decoder:
            fusion = 'panoptic_fpn_fusion'
            head_level = fpn_decoder_level
            segmentation_decoder = None
        else:
            fusion = 'deeplabv3plus'
            head_level = aspp_decoder_level
            segmentation_decoder = aspp.ASPP(
                level=head_level, dilation_rates=aspp_dilation_rates)
        segmentation_head = segmentation_heads.SegmentationHead(
            num_classes=2,  # stuff and common class for things,
            level=head_level,
            feature_fusion=fusion,
            decoder_min_level=min_level,
            decoder_max_level=max_level,
            num_convs=2)

        model = panoptic_maskrcnn_model.PanopticMaskRCNNModel(
            *positional_components,
            segmentation_backbone=segmentation_backbone,
            segmentation_decoder=segmentation_decoder,
            segmentation_head=segmentation_head,
            min_level=min_level,
            max_level=max_level,
            num_scales=3,
            aspect_ratios=[1.0],
            anchor_size=3)

        restored_model = (
            panoptic_maskrcnn_model.PanopticMaskRCNNModel.from_config(
                model.get_config()))

        # The restored model must also be convertible to JSON.
        _ = restored_model.to_json()

        # A faithful round trip leaves the config unchanged.
        self.assertAllEqual(model.get_config(), restored_model.get_config())
    def test_forward(self, strategy, training, shared_backbone, shared_decoder,
                     generate_panoptic_masks):
        """Runs one forward pass and checks the keys and shapes of the outputs.

        Args:
          strategy: Distribution strategy; the model is built and called
            inside `strategy.scope()`.
          training: Forwarded to the model call as `training`; it also selects
            which group of output keys is asserted.
          shared_backbone: When true, no dedicated segmentation backbone is
            built and `None` is handed to the model instead.
          shared_decoder: When true (and `shared_backbone` is true as well), no
            dedicated segmentation decoder is built.
          generate_panoptic_masks: When true, a panoptic segmentation generator
            is passed to the model and `panoptic_outputs` is expected in the
            inference results; otherwise `None` is passed and the key must be
            absent.
        """
        num_classes = 3
        min_level = 2
        max_level = 6
        num_scales = 3
        aspect_ratios = [1.0]
        anchor_size = 3
        segmentation_resnet_model_id = 101
        aspp_dilation_rates = [6, 12, 18]
        aspp_decoder_level = 2
        fpn_decoder_level = 2

        class_agnostic_bbox_pred = False
        cascade_class_ensemble = False

        image_size = (256, 256)
        images = tf.random.normal([2, image_size[0], image_size[1], 3])
        image_info = tf.convert_to_tensor([[[224, 100], [224, 100], [1, 1],
                                            [0, 0]],
                                           [[224, 100], [224, 100], [1, 1],
                                            [0, 0]]])
        # Decoder sharing is only kept when the backbone is shared too.
        shared_decoder = shared_decoder and shared_backbone
        with strategy.scope():

            anchor_boxes = anchor.Anchor(
                min_level=min_level,
                max_level=max_level,
                num_scales=num_scales,
                aspect_ratios=aspect_ratios,
                anchor_size=anchor_size,
                image_size=image_size).multilevel_boxes

            num_anchors_per_location = len(aspect_ratios) * num_scales

            input_specs = tf.keras.layers.InputSpec(
                shape=[None, None, None, 3])
            backbone = resnet.ResNet(model_id=50, input_specs=input_specs)
            decoder = fpn.FPN(min_level=min_level,
                              max_level=max_level,
                              input_specs=backbone.output_specs)
            rpn_head = dense_prediction_heads.RPNHead(
                min_level=min_level,
                max_level=max_level,
                num_anchors_per_location=num_anchors_per_location)
            detection_head = instance_heads.DetectionHead(
                num_classes=num_classes,
                class_agnostic_bbox_pred=class_agnostic_bbox_pred)
            roi_generator_obj = roi_generator.MultilevelROIGenerator()
            roi_sampler_obj = roi_sampler.ROISampler()
            roi_aligner_obj = roi_aligner.MultilevelROIAligner()
            detection_generator_obj = detection_generator.DetectionGenerator()

            if generate_panoptic_masks:
                panoptic_segmentation_generator_obj = panoptic_segmentation_generator.PanopticSegmentationGenerator(
                    output_size=list(image_size),
                    max_num_detections=100,
                    stuff_classes_offset=90)
            else:
                panoptic_segmentation_generator_obj = None

            mask_head = instance_heads.MaskHead(num_classes=num_classes,
                                                upsample_factor=2)
            mask_sampler_obj = mask_sampler.MaskSampler(mask_target_size=28,
                                                        num_sampled_masks=1)
            mask_roi_aligner_obj = roi_aligner.MultilevelROIAligner(
                crop_size=14)

            if shared_backbone:
                segmentation_backbone = None
            else:
                segmentation_backbone = resnet.ResNet(
                    model_id=segmentation_resnet_model_id)
            if not shared_decoder:
                feature_fusion = 'deeplabv3plus'
                level = aspp_decoder_level
                segmentation_decoder = aspp.ASPP(
                    level=level, dilation_rates=aspp_dilation_rates)
            else:
                feature_fusion = 'panoptic_fpn_fusion'
                level = fpn_decoder_level
                segmentation_decoder = None
            segmentation_head = segmentation_heads.SegmentationHead(
                num_classes=2,  # stuff and common class for things,
                level=level,
                feature_fusion=feature_fusion,
                decoder_min_level=min_level,
                decoder_max_level=max_level,
                num_convs=2)

            model = panoptic_maskrcnn_model.PanopticMaskRCNNModel(
                backbone,
                decoder,
                rpn_head,
                detection_head,
                roi_generator_obj,
                roi_sampler_obj,
                roi_aligner_obj,
                detection_generator_obj,
                panoptic_segmentation_generator_obj,
                mask_head,
                mask_sampler_obj,
                mask_roi_aligner_obj,
                segmentation_backbone=segmentation_backbone,
                segmentation_decoder=segmentation_decoder,
                segmentation_head=segmentation_head,
                class_agnostic_bbox_pred=class_agnostic_bbox_pred,
                cascade_class_ensemble=cascade_class_ensemble,
                min_level=min_level,
                max_level=max_level,
                num_scales=num_scales,
                aspect_ratios=aspect_ratios,
                anchor_size=anchor_size)

            gt_boxes = tf.convert_to_tensor(
                [[[10, 10, 15, 15], [2.5, 2.5, 7.5, 7.5], [-1, -1, -1, -1]],
                 [[100, 100, 150, 150], [-1, -1, -1, -1], [-1, -1, -1, -1]]],
                dtype=tf.float32)
            gt_classes = tf.convert_to_tensor([[2, 1, -1], [1, -1, -1]],
                                              dtype=tf.int32)
            gt_masks = tf.ones((2, 3, 100, 100))

            results = model(images,
                            image_info,
                            anchor_boxes,
                            gt_boxes,
                            gt_classes,
                            gt_masks,
                            training=training)

        self.assertIn('rpn_boxes', results)
        self.assertIn('rpn_scores', results)
        if training:
            self.assertIn('class_targets', results)
            self.assertIn('box_targets', results)
            self.assertIn('class_outputs', results)
            self.assertIn('box_outputs', results)
            self.assertIn('mask_outputs', results)
        else:
            self.assertIn('detection_boxes', results)
            self.assertIn('detection_scores', results)
            self.assertIn('detection_classes', results)
            self.assertIn('num_detections', results)
            self.assertIn('detection_masks', results)
            self.assertIn('segmentation_outputs', results)

            # The expected spatial size is the image size divided by
            # 2**level; the last dimension is the head's `num_classes`.
            self.assertAllEqual([
                2, image_size[0] // (2**level), image_size[1] // (2**level), 2
            ], results['segmentation_outputs'].numpy().shape)

            if generate_panoptic_masks:
                self.assertIn('panoptic_outputs', results)
                panoptic_outputs = results['panoptic_outputs']
                mask_names = ('category_mask', 'instance_mask')
                for mask_name in mask_names:
                    self.assertIn(mask_name, panoptic_outputs)
                # Both masks are expected at the generator's `output_size`.
                for mask_name in mask_names:
                    self.assertAllEqual(
                        [2, image_size[0], image_size[1]],
                        panoptic_outputs[mask_name].numpy().shape)
            else:
                self.assertNotIn('panoptic_outputs', results)
Exemple #15
0
    def test_build_model(self,
                         use_separable_conv,
                         build_anchor_boxes,
                         shared_backbone,
                         shared_decoder,
                         is_training=True):
        """Builds the model and runs one forward pass without checking outputs.

        Args:
          use_separable_conv: Forwarded to the FPN decoder.
          build_anchor_boxes: When true, multilevel anchor boxes are generated
            and tiled for the batch of two images; otherwise `None` is passed
            to the model.
          shared_backbone: When true, no dedicated segmentation backbone is
            built and `None` is handed to the model instead.
          shared_decoder: When true (and `shared_backbone` is true as well), no
            dedicated segmentation decoder is built.
          is_training: Forwarded to the model call as `training`.
        """
        num_classes = 3
        min_level = 3
        max_level = 7
        num_scales = 3
        aspect_ratios = [1.0]
        anchor_size = 3
        resnet_model_id = 50
        segmentation_resnet_model_id = 50
        segmentation_output_stride = 16
        aspp_dilation_rates = [6, 12, 18]
        # `np.math` was only an alias of the stdlib `math` module and has been
        # removed from recent NumPy releases; `np.log2` gives the same value.
        aspp_decoder_level = int(np.log2(segmentation_output_stride))
        fpn_decoder_level = 3
        num_anchors_per_location = num_scales * len(aspect_ratios)
        image_size = 128
        images = np.random.rand(2, image_size, image_size, 3)
        image_shape = np.array([[image_size, image_size],
                                [image_size, image_size]])
        # Decoder sharing is only kept when the backbone is shared too.
        shared_decoder = shared_decoder and shared_backbone
        if build_anchor_boxes:
            anchor_boxes = anchor.Anchor(
                min_level=min_level,
                max_level=max_level,
                num_scales=num_scales,
                aspect_ratios=aspect_ratios,
                anchor_size=anchor_size,
                image_size=(image_size, image_size)).multilevel_boxes
            # Add a leading axis and repeat the boxes once per image in the
            # batch of two.
            for level_key in anchor_boxes:
                anchor_boxes[level_key] = tf.tile(
                    tf.expand_dims(anchor_boxes[level_key], axis=0),
                    [2, 1, 1, 1])
        else:
            anchor_boxes = None

        backbone = resnet.ResNet(model_id=resnet_model_id)
        decoder = fpn.FPN(input_specs=backbone.output_specs,
                          min_level=min_level,
                          max_level=max_level,
                          use_separable_conv=use_separable_conv)
        rpn_head = dense_prediction_heads.RPNHead(
            min_level=min_level,
            max_level=max_level,
            num_anchors_per_location=num_anchors_per_location,
            num_convs=1)
        detection_head = instance_heads.DetectionHead(num_classes=num_classes)
        roi_generator_obj = roi_generator.MultilevelROIGenerator()
        roi_sampler_obj = roi_sampler.ROISampler()
        roi_aligner_obj = roi_aligner.MultilevelROIAligner()
        detection_generator_obj = detection_generator.DetectionGenerator()
        mask_head = instance_heads.MaskHead(num_classes=num_classes,
                                            upsample_factor=2)
        mask_sampler_obj = mask_sampler.MaskSampler(mask_target_size=28,
                                                    num_sampled_masks=1)
        mask_roi_aligner_obj = roi_aligner.MultilevelROIAligner(crop_size=14)

        if shared_backbone:
            segmentation_backbone = None
        else:
            segmentation_backbone = resnet.ResNet(
                model_id=segmentation_resnet_model_id)
        if not shared_decoder:
            level = aspp_decoder_level
            segmentation_decoder = aspp.ASPP(
                level=level, dilation_rates=aspp_dilation_rates)
        else:
            level = fpn_decoder_level
            segmentation_decoder = None
        segmentation_head = segmentation_heads.SegmentationHead(
            num_classes=2,  # stuff and common class for things,
            level=level,
            num_convs=2)

        model = panoptic_maskrcnn_model.PanopticMaskRCNNModel(
            backbone,
            decoder,
            rpn_head,
            detection_head,
            roi_generator_obj,
            roi_sampler_obj,
            roi_aligner_obj,
            detection_generator_obj,
            mask_head,
            mask_sampler_obj,
            mask_roi_aligner_obj,
            segmentation_backbone=segmentation_backbone,
            segmentation_decoder=segmentation_decoder,
            segmentation_head=segmentation_head,
            min_level=min_level,
            max_level=max_level,
            num_scales=num_scales,
            aspect_ratios=aspect_ratios,
            anchor_size=anchor_size)

        gt_boxes = np.array(
            [[[10, 10, 15, 15], [2.5, 2.5, 7.5, 7.5], [-1, -1, -1, -1]],
             [[100, 100, 150, 150], [-1, -1, -1, -1], [-1, -1, -1, -1]]],
            dtype=np.float32)
        gt_classes = np.array([[2, 1, -1], [1, -1, -1]], dtype=np.int32)
        gt_masks = np.ones((2, 3, 100, 100))

        # Results will be checked in test_forward.
        _ = model(images,
                  image_shape,
                  anchor_boxes,
                  gt_boxes,
                  gt_classes,
                  gt_masks,
                  training=is_training)
 def test_serialize_deserialize(self):
     """Checks that a head rebuilt from its config reports the same config."""
     original = segmentation_heads.SegmentationHead(num_classes=10, level=3)
     restored = segmentation_heads.SegmentationHead.from_config(
         original.get_config())
     self.assertAllEqual(original.get_config(), restored.get_config())
Exemple #17
0
def build_submodel(
    norm_activation_config: hyperparams.Config,
    backbone: tf.keras.Model,
    input_specs: tf.keras.layers.InputSpec,
    submodel_config: multitask_config.Submodel,
    l2_regularizer: tf.keras.regularizers.Regularizer = None
) -> tf.keras.Model:
    """Builds the model for one subtask on top of the given backbone.

    A decoder and a head are created from `submodel_config` and wrapped,
    together with `backbone`, in a `SegmentationModel`, which is used here as
    a generic backbone/decoder/head container.

    Args:
      norm_activation_config: Passed to the decoder factory; classification
        and segmentation heads also read their activation, sync batch norm,
        momentum and epsilon settings from it.
      backbone: Backbone model; its `output_specs` are the decoder's input
        specs.
      input_specs: Not used by this function.
      submodel_config: Subtask config providing `decoder`, `head` and
        `num_classes`.
      l2_regularizer: Optional regularizer passed to the decoder factory and
        to the head as `kernel_regularizer`.

    Returns:
      A `SegmentationModel` built from the backbone, the decoder and the head.

    Raises:
      NotImplementedError: If the head config is neither an image
        classification, a segmentation nor a YOLO head config.
    """
    decoder = decoder_factory.build_decoder(
        input_specs=backbone.output_specs,
        model_config=submodel_config,
        norm_activation_config=norm_activation_config,
        l2_regularizer=l2_regularizer)
    if submodel_config.decoder.freeze:
        decoder.trainable = False

    head_cfg = submodel_config.head

    # Select the head class and its constructor arguments; the head itself is
    # created once, after the type dispatch.
    if isinstance(head_cfg, multitask_config.ImageClassificationHead):
        head_cls = classification_heads.ClassificationHead
        head_kwargs = dict(
            num_classes=submodel_config.num_classes,
            level=head_cfg.level,
            num_convs=head_cfg.num_convs,
            num_filters=head_cfg.num_filters,
            add_head_batch_norm=head_cfg.add_head_batch_norm,
            activation=norm_activation_config.activation,
            use_sync_bn=norm_activation_config.use_sync_bn,
            dropout_rate=head_cfg.dropout_rate,
            norm_momentum=norm_activation_config.norm_momentum,
            norm_epsilon=norm_activation_config.norm_epsilon,
            kernel_regularizer=l2_regularizer)
    elif isinstance(head_cfg, multitask_config.SegmentationHead):
        head_cls = segmentation_heads.SegmentationHead
        head_kwargs = dict(
            num_classes=submodel_config.num_classes,
            level=head_cfg.level,
            num_convs=head_cfg.num_convs,
            prediction_kernel_size=head_cfg.prediction_kernel_size,
            num_filters=head_cfg.num_filters,
            upsample_factor=head_cfg.upsample_factor,
            feature_fusion=head_cfg.feature_fusion,
            low_level=head_cfg.low_level,
            low_level_num_filters=head_cfg.low_level_num_filters,
            activation=norm_activation_config.activation,
            use_sync_bn=norm_activation_config.use_sync_bn,
            norm_momentum=norm_activation_config.norm_momentum,
            norm_epsilon=norm_activation_config.norm_epsilon,
            kernel_regularizer=l2_regularizer)
    elif isinstance(head_cfg, multitask_config.YoloHead):
        head_cls = instance_heads.YOLOv3Head
        head_kwargs = dict(
            levels=len(decoder.output_specs),
            num_classes=submodel_config.num_classes,
            strides=head_cfg.strides,
            anchor_per_scale=head_cfg.anchor_per_scale,
            anchors=head_cfg.anchors,
            xy_scale=head_cfg.xy_scale,
            kernel_regularizer=l2_regularizer)
    else:
        raise NotImplementedError('%s head is not implemented yet.' %
                                  (type(head_cfg)))

    head = head_cls(**head_kwargs)
    if head_cfg.freeze:
        head.trainable = False

    return SegmentationModel(backbone, decoder, head)
Exemple #18
0
def build_panoptic_maskrcnn(
    input_specs: tf.keras.layers.InputSpec,
    model_config: panoptic_maskrcnn_cfg.PanopticMaskRCNN,
    l2_regularizer: tf.keras.regularizers.Regularizer = None
) -> tf.keras.Model:
    """Builds the Panoptic Mask R-CNN model.

    The Mask R-CNN model is built first. The parts of the semantic
    segmentation branch that are not shared with it (backbone and/or decoder)
    are built next, followed by the segmentation head. Finally everything is
    combined into the panoptic segmentation model.

    Args:
      input_specs: `tf.keras.layers.InputSpec` specs of the input tensor.
      model_config: Config instance for the panoptic maskrcnn model.
      l2_regularizer: Optional `tf.keras.regularizers.Regularizer`; when
        given, it is passed to every sub-model builder and to the
        segmentation head.

    Returns:
      tf.keras.Model for the panoptic segmentation model.
    """
    norm_cfg = model_config.norm_activation
    seg_cfg = model_config.segmentation_model

    # Mask R-CNN part of the model.
    maskrcnn = models_factory.build_maskrcnn(
        input_specs=input_specs,
        model_config=model_config,
        l2_regularizer=l2_regularizer)

    # Semantic segmentation branch: a component stays `None` when it is shared
    # with the Mask R-CNN model.
    if model_config.shared_backbone:
        segmentation_backbone = None
        decoder_input_specs = maskrcnn.backbone.output_specs
    else:
        segmentation_backbone = backbones.factory.build_backbone(
            input_specs=input_specs,
            backbone_config=seg_cfg.backbone,
            norm_activation_config=norm_cfg,
            l2_regularizer=l2_regularizer)
        decoder_input_specs = segmentation_backbone.output_specs

    if model_config.shared_decoder:
        segmentation_decoder = None
    else:
        segmentation_decoder = decoder_factory.build_decoder(
            input_specs=decoder_input_specs,
            model_config=seg_cfg,
            l2_regularizer=l2_regularizer)

    seg_head_cfg = seg_cfg.head
    det_head_cfg = model_config.detection_head

    segmentation_head = segmentation_heads.SegmentationHead(
        num_classes=seg_cfg.num_classes,
        level=seg_head_cfg.level,
        num_convs=seg_head_cfg.num_convs,
        prediction_kernel_size=seg_head_cfg.prediction_kernel_size,
        num_filters=seg_head_cfg.num_filters,
        upsample_factor=seg_head_cfg.upsample_factor,
        feature_fusion=seg_head_cfg.feature_fusion,
        low_level=seg_head_cfg.low_level,
        low_level_num_filters=seg_head_cfg.low_level_num_filters,
        activation=norm_cfg.activation,
        use_sync_bn=norm_cfg.use_sync_bn,
        norm_momentum=norm_cfg.norm_momentum,
        norm_epsilon=norm_cfg.norm_epsilon,
        kernel_regularizer=l2_regularizer)

    # Components taken over from the Mask R-CNN model; each is passed to the
    # panoptic model under a keyword equal to its attribute name.
    reused_names = (
        'backbone',
        'decoder',
        'rpn_head',
        'detection_head',
        'roi_generator',
        'roi_sampler',
        'roi_aligner',
        'detection_generator',
        'mask_head',
        'mask_sampler',
        'mask_roi_aligner',
    )
    reused_components = {
        name: getattr(maskrcnn, name) for name in reused_names
    }

    # Combines the maskrcnn and segmentation parts into the panoptic model.
    return panoptic_maskrcnn_model.PanopticMaskRCNNModel(
        **reused_components,
        segmentation_backbone=segmentation_backbone,
        segmentation_decoder=segmentation_decoder,
        segmentation_head=segmentation_head,
        class_agnostic_bbox_pred=det_head_cfg.class_agnostic_bbox_pred,
        cascade_class_ensemble=det_head_cfg.cascade_class_ensemble,
        min_level=model_config.min_level,
        max_level=model_config.max_level,
        num_scales=model_config.anchor.num_scales,
        aspect_ratios=model_config.anchor.aspect_ratios,
        anchor_size=model_config.anchor.anchor_size)