コード例 #1
0
ファイル: factory.py プロジェクト: EricL132/Object-Rec
def build_segmentation_model(
        input_specs: tf.keras.layers.InputSpec,
        model_config: segmentation_cfg.SemanticSegmentationModel,
        l2_regularizer: tf.keras.regularizers.Regularizer = None):
    """Assembles a semantic segmentation model from `model_config`.

    The backbone is built from `input_specs`, the decoder from the backbone's
    `output_specs`, and the head from `model_config.head` together with
    `model_config.norm_activation`.

    Args:
      input_specs: Input spec handed to the backbone factory.
      model_config: Segmentation config; supplies `num_classes`, `head` and
        `norm_activation`, and is forwarded to both factories.
      l2_regularizer: Optional regularizer forwarded to the backbone and
        decoder factories and used as the head's `kernel_regularizer`.

    Returns:
      A `segmentation_model.SegmentationModel` built from the backbone,
      decoder and head.
    """
    # Both factories receive the same config and regularizer.
    factory_kwargs = dict(model_config=model_config,
                          l2_regularizer=l2_regularizer)
    backbone = backbones.factory.build_backbone(input_specs=input_specs,
                                                **factory_kwargs)
    decoder = decoder_factory.build_decoder(input_specs=backbone.output_specs,
                                            **factory_kwargs)

    head_cfg = model_config.head
    norm_cfg = model_config.norm_activation
    head_kwargs = {
        'num_classes': model_config.num_classes,
        'level': head_cfg.level,
        'num_convs': head_cfg.num_convs,
        'num_filters': head_cfg.num_filters,
        'upsample_factor': head_cfg.upsample_factor,
        'feature_fusion': head_cfg.feature_fusion,
        'low_level': head_cfg.low_level,
        'low_level_num_filters': head_cfg.low_level_num_filters,
        'activation': norm_cfg.activation,
        'use_sync_bn': norm_cfg.use_sync_bn,
        'norm_momentum': norm_cfg.norm_momentum,
        'norm_epsilon': norm_cfg.norm_epsilon,
        'kernel_regularizer': l2_regularizer,
    }
    seg_head = segmentation_heads.SegmentationHead(**head_kwargs)

    return segmentation_model.SegmentationModel(backbone, decoder, seg_head)
コード例 #2
0
ファイル: factory.py プロジェクト: vishalbelsare/models
def build_yolo(input_specs, model_config, l2_regularization):
    """Creates a YOLO model and returns it together with its losses.

    Args:
      input_specs: Input spec; forwarded to the backbone factory, and its
        `shape` is used to build the assembled model.
      model_config: Config providing `backbone`, `anchor_boxes` and
        `norm_activation`; also forwarded to the decoder, head and detection
        generator builders.
      l2_regularization: Regularization argument forwarded to the backbone,
        decoder and head builders.

    Returns:
      A `(model, losses)` tuple: the built `yolo_model.Yolo` instance and the
      result of `get_losses()` on its detection generator.
    """
    # The resolved backbone config is only needed for its level range.
    backbone_cfg = model_config.backbone.get()
    anchor_dict, _unused = model_config.anchor_boxes.get(
        backbone_cfg.min_level, backbone_cfg.max_level)

    backbone = backbone_factory.build_backbone(
        input_specs, model_config.backbone, model_config.norm_activation,
        l2_regularization)
    decoder = decoder_factory.build_decoder(
        backbone.output_specs, model_config, l2_regularization)
    head = build_yolo_head(
        decoder.output_specs, model_config, l2_regularization)
    detection_generator_obj = build_yolo_detection_generator(
        model_config, anchor_dict)

    components = dict(backbone=backbone,
                      decoder=decoder,
                      head=head,
                      detection_generator=detection_generator_obj)
    model = yolo_model.Yolo(**components)
    model.build(input_specs.shape)

    # Route the layer summary to the log instead of stdout.
    model.summary(print_fn=logging.info)

    return model, detection_generator_obj.get_losses()
コード例 #3
0
def build_basnet_model(
    input_specs: tf.keras.layers.InputSpec,
    model_config: basnet_cfg.BASNetModel,
    l2_regularizer: tf.keras.regularizers.Regularizer = None):
  """Builds BASNet model.

  The backbone is built from `input_specs`, the decoder from the backbone's
  `output_specs`, and a `refunet.RefUnet` refinement module is created with
  its default arguments. No segmentation head is constructed here.

  Args:
    input_specs: Input spec handed to the backbone factory.
    model_config: BASNet config forwarded to the backbone and decoder
      factories.
    l2_regularizer: Optional regularizer forwarded to the backbone and decoder
      factories.

  Returns:
    A `basnet_model.BASNetModel` built from the backbone, decoder and
    refinement module.
  """
  backbone = backbones.factory.build_backbone(
      input_specs=input_specs,
      model_config=model_config,
      l2_regularizer=l2_regularizer)

  decoder = decoder_factory.build_decoder(
      input_specs=backbone.output_specs,
      model_config=model_config,
      l2_regularizer=l2_regularizer)

  refinement = refunet.RefUnet()

  model = basnet_model.BASNetModel(backbone, decoder, refinement)
  return model
コード例 #4
0
ファイル: factory_test.py プロジェクト: milovanarsul/models
    def test_nasfpn_decoder_creation(self, num_filters, num_repeats,
                                     use_separable_conv):
        """Checks the factory-built NASFPN decoder against a direct build.

        One NASFPN decoder is constructed directly and another through
        `factory.build_decoder` from a RetinaNet config carrying the same
        NASFPN settings; their `get_config()` results must be equal.
        """
        min_level, max_level = 3, 7
        # One spec per level in [min_level, max_level); the spatial size is
        # 128 divided by 2**level.
        input_specs = {
            str(lvl): tf.TensorShape(
                [1, 128 // (2**lvl), 128 // (2**lvl), 3])
            for lvl in range(min_level, max_level)
        }

        # Settings shared by the direct decoder and the decoder config.
        nasfpn_kwargs = dict(num_filters=num_filters,
                             num_repeats=num_repeats,
                             use_separable_conv=use_separable_conv)

        network = decoders.NASFPN(input_specs=input_specs,
                                  use_sync_bn=True,
                                  **nasfpn_kwargs)

        model_config = configs.retinanet.RetinaNet()
        model_config.min_level = min_level
        model_config.max_level = max_level
        model_config.num_classes = 10
        model_config.input_size = [None, None, 3]
        model_config.decoder = decoders_cfg.Decoder(
            type='nasfpn', nasfpn=decoders_cfg.NASFPN(**nasfpn_kwargs))

        factory_network = factory.build_decoder(input_specs=input_specs,
                                                model_config=model_config)

        expected_config = network.get_config()
        actual_config = factory_network.get_config()
        self.assertEqual(expected_config, actual_config)
コード例 #5
0
ファイル: factory_test.py プロジェクト: milovanarsul/models
    def test_aspp_decoder_creation(self, level, dilation_rates, num_filters):
        """Checks the factory-built ASPP decoder against a direct build.

        One ASPP decoder is constructed directly and another through
        `factory.build_decoder` from a semantic segmentation config carrying
        the same ASPP settings; their configs must match once the layer name
        is aligned.
        """
        input_specs = {'1': tf.TensorShape([1, 128, 128, 3])}

        # Settings shared by the direct decoder and the decoder config.
        aspp_kwargs = dict(level=level,
                           dilation_rates=dilation_rates,
                           num_filters=num_filters)

        network = decoders.ASPP(use_sync_bn=True, **aspp_kwargs)

        model_config = (
            configs.semantic_segmentation.SemanticSegmentationModel())
        model_config.num_classes = 10
        model_config.input_size = [None, None, 3]
        model_config.decoder = decoders_cfg.Decoder(
            type='aspp', aspp=decoders_cfg.ASPP(**aspp_kwargs))

        factory_network = factory.build_decoder(input_specs=input_specs,
                                                model_config=model_config)

        expected_config = network.get_config()
        actual_config = factory_network.get_config()
        # The ASPP layer's config includes what `super().get_config()`
        # returns, so the two instances differ only in their layer name.
        # Copy the name over so that difference does not fail the comparison.
        actual_config['name'] = expected_config['name']

        self.assertEqual(expected_config, actual_config)
コード例 #6
0
def build_yolo_model(
    input_specs: tf.keras.layers.InputSpec,
    model_config: yolo_cfg.YoloModel,
    l2_regularizer: tf.keras.regularizers.Regularizer = None
) -> tf.keras.Model:
    """Assembles a YOLO model from `model_config`.

    The backbone, decoder and a `YOLOv3Head` are built and then wrapped in a
    `segmentation_model.SegmentationModel`, which is used here as the
    backbone/decoder/head container.

    Args:
      input_specs: Input spec handed to the backbone factory.
      model_config: YOLO config supplying `backbone`, `norm_activation`,
        `head` and `num_classes`.
      l2_regularizer: Optional regularizer forwarded to the backbone and
        decoder factories and used as the head's `kernel_regularizer`.

    Returns:
      The assembled model.
    """
    backbone = backbones.factory.build_backbone(
        input_specs=input_specs,
        backbone_config=model_config.backbone,
        norm_activation_config=model_config.norm_activation,
        l2_regularizer=l2_regularizer)

    decoder = decoder_factory.build_decoder(
        input_specs=backbone.output_specs,
        model_config=model_config,
        l2_regularizer=l2_regularizer)

    head_cfg = model_config.head
    head_kwargs = {
        # One head level per decoder output.
        'levels': len(decoder.output_specs),
        'num_classes': model_config.num_classes,
        'strides': head_cfg.strides,
        'anchor_per_scale': head_cfg.anchor_per_scale,
        'anchors': head_cfg.anchors,
        'xy_scale': head_cfg.xy_scale,
        'kernel_regularizer': l2_regularizer,
    }
    yolo_head = instance_heads.YOLOv3Head(**head_kwargs)

    return segmentation_model.SegmentationModel(backbone, decoder, yolo_head)
コード例 #7
0
ファイル: factory_test.py プロジェクト: ykate1998/models
  def test_aspp_decoder_creation(self, level, dilation_rates, num_filters):
    """Checks the factory-built ASPP decoder against a direct build.

    One ASPP decoder is constructed directly and another through
    `factory.build_decoder` from a semantic segmentation config carrying the
    same ASPP settings; their `get_config()` results must be equal.
    """
    input_specs = {'1': tf.TensorShape([1, 128, 128, 3])}

    # Settings shared by the direct decoder and the decoder config.
    aspp_kwargs = dict(
        level=level, dilation_rates=dilation_rates, num_filters=num_filters)

    network = decoders.ASPP(use_sync_bn=True, **aspp_kwargs)

    model_config = configs.semantic_segmentation.SemanticSegmentationModel()
    model_config.num_classes = 10
    model_config.input_size = [None, None, 3]
    model_config.decoder = decoders_cfg.Decoder(
        type='aspp', aspp=decoders_cfg.ASPP(**aspp_kwargs))

    factory_network = factory.build_decoder(
        input_specs=input_specs, model_config=model_config)

    expected_config = network.get_config()
    actual_config = factory_network.get_config()
    self.assertEqual(expected_config, actual_config)
コード例 #8
0
ファイル: factory.py プロジェクト: HabanaAI/Model-References
def build_retinanet(
    input_specs: tf.keras.layers.InputSpec,
    model_config: retinanet_cfg.RetinaNet,
    l2_regularizer: tf.keras.regularizers.Regularizer = None
) -> tf.keras.Model:
    """Assembles a RetinaNet model from `model_config`.

    Args:
      input_specs: Input spec handed to the backbone factory; its shape without
        the leading dimension is used to call the backbone once on a symbolic
        input.
      model_config: RetinaNet config supplying the backbone, decoder, head,
        anchor and detection generator settings.
      l2_regularizer: Optional regularizer forwarded to the backbone and
        decoder factories and used as the head's `kernel_regularizer`.

    Returns:
      A `retinanet_model.RetinaNetModel`.
    """
    norm_cfg = model_config.norm_activation
    backbone = backbones.factory.build_backbone(
        input_specs=input_specs,
        backbone_config=model_config.backbone,
        norm_activation_config=norm_cfg,
        l2_regularizer=l2_regularizer)
    # Call the backbone on a symbolic input before the decoder is built from
    # its output specs.
    backbone(tf.keras.Input(input_specs.shape[1:]))

    decoder = decoder_factory.build_decoder(
        input_specs=backbone.output_specs,
        model_config=model_config,
        l2_regularizer=l2_regularizer)

    head_cfg = model_config.head
    generator_cfg = model_config.detection_generator
    anchor_cfg = model_config.anchor
    # Each location carries one anchor per (aspect ratio, scale) pair.
    anchors_per_location = (
        len(anchor_cfg.aspect_ratios) * anchor_cfg.num_scales)

    # A missing/empty attribute head list becomes an empty list.
    attribute_heads = [
        attribute_cfg.as_dict()
        for attribute_cfg in (head_cfg.attribute_heads or [])
    ]

    head_kwargs = {
        'min_level': model_config.min_level,
        'max_level': model_config.max_level,
        'num_classes': model_config.num_classes,
        'num_anchors_per_location': anchors_per_location,
        'num_convs': head_cfg.num_convs,
        'num_filters': head_cfg.num_filters,
        'attribute_heads': attribute_heads,
        'use_separable_conv': head_cfg.use_separable_conv,
        'activation': norm_cfg.activation,
        'use_sync_bn': norm_cfg.use_sync_bn,
        'norm_momentum': norm_cfg.norm_momentum,
        'norm_epsilon': norm_cfg.norm_epsilon,
        'kernel_regularizer': l2_regularizer,
    }
    head = dense_prediction_heads.RetinaNetHead(**head_kwargs)

    generator_kwargs = {
        'apply_nms': generator_cfg.apply_nms,
        'pre_nms_top_k': generator_cfg.pre_nms_top_k,
        'pre_nms_score_threshold': generator_cfg.pre_nms_score_threshold,
        'nms_iou_threshold': generator_cfg.nms_iou_threshold,
        'max_num_detections': generator_cfg.max_num_detections,
        'use_batched_nms': generator_cfg.use_batched_nms,
    }
    detection_generator_obj = detection_generator.MultilevelDetectionGenerator(
        **generator_kwargs)

    return retinanet_model.RetinaNetModel(
        backbone,
        decoder,
        head,
        detection_generator_obj,
        min_level=model_config.min_level,
        max_level=model_config.max_level,
        num_scales=anchor_cfg.num_scales,
        aspect_ratios=anchor_cfg.aspect_ratios,
        anchor_size=anchor_cfg.anchor_size)
コード例 #9
0
ファイル: factory_test.py プロジェクト: milovanarsul/models
    def test_identity_decoder_creation(self):
        """Checks that an identity decoder config makes the factory return None."""
        model_config = configs.retinanet.RetinaNet()
        model_config.num_classes = 2
        model_config.input_size = [None, None, 3]

        identity_decoder_cfg = decoders_cfg.Decoder(
            type='identity', identity=decoders_cfg.Identity())
        model_config.decoder = identity_decoder_cfg

        # No input specs are supplied for this decoder type.
        built = factory.build_decoder(input_specs=None,
                                      model_config=model_config)

        self.assertIsNone(built)
コード例 #10
0
def build_maskrcnn(
    input_specs: tf.keras.layers.InputSpec,
    model_config: maskrcnn_cfg.MaskRCNN,
    l2_regularizer: tf.keras.regularizers.Regularizer = None
) -> tf.keras.Model:
    """Assembles a Mask R-CNN model, optionally with cascade heads.

    When `model_config.roi_sampler.cascade_iou_thresholds` is non-empty, one
    extra detection head and one extra ROI sampler are created per threshold.
    The model then receives a list of detection heads instead of a single
    head. The ROI samplers are always passed as a list.

    Args:
      input_specs: Input spec handed to the backbone factory.
      model_config: Mask R-CNN config supplying every component's settings.
      l2_regularizer: Optional regularizer forwarded to the backbone and
        decoder factories and used as the heads' `kernel_regularizer`.

    Returns:
      A `maskrcnn_model.MaskRCNNModel`.
    """
    norm_cfg = model_config.norm_activation
    backbone = backbones.factory.build_backbone(
        input_specs=input_specs,
        backbone_config=model_config.backbone,
        norm_activation_config=norm_cfg,
        l2_regularizer=l2_regularizer)

    decoder = decoder_factory.build_decoder(
        input_specs=backbone.output_specs,
        model_config=model_config,
        l2_regularizer=l2_regularizer)

    rpn_cfg = model_config.rpn_head
    proposal_cfg = model_config.roi_generator
    sampler_cfg = model_config.roi_sampler
    aligner_cfg = model_config.roi_aligner
    det_cfg = model_config.detection_head
    generator_cfg = model_config.detection_generator
    anchor_cfg = model_config.anchor
    cascade_ious = sampler_cfg.cascade_iou_thresholds
    # Each location carries one anchor per (aspect ratio, scale) pair.
    anchors_per_location = (
        len(anchor_cfg.aspect_ratios) * anchor_cfg.num_scales)

    rpn_head = dense_prediction_heads.RPNHead(
        min_level=model_config.min_level,
        max_level=model_config.max_level,
        num_anchors_per_location=anchors_per_location,
        num_convs=rpn_cfg.num_convs,
        num_filters=rpn_cfg.num_filters,
        use_separable_conv=rpn_cfg.use_separable_conv,
        activation=norm_cfg.activation,
        use_sync_bn=norm_cfg.use_sync_bn,
        norm_momentum=norm_cfg.norm_momentum,
        norm_epsilon=norm_cfg.norm_epsilon,
        kernel_regularizer=l2_regularizer)

    def make_detection_head(name):
        """Builds one detection head; heads differ only in their name."""
        return instance_heads.DetectionHead(
            num_classes=model_config.num_classes,
            num_convs=det_cfg.num_convs,
            num_filters=det_cfg.num_filters,
            use_separable_conv=det_cfg.use_separable_conv,
            num_fcs=det_cfg.num_fcs,
            fc_dims=det_cfg.fc_dims,
            class_agnostic_bbox_pred=det_cfg.class_agnostic_bbox_pred,
            activation=norm_cfg.activation,
            use_sync_bn=norm_cfg.use_sync_bn,
            norm_momentum=norm_cfg.norm_momentum,
            norm_epsilon=norm_cfg.norm_epsilon,
            kernel_regularizer=l2_regularizer,
            name=name)

    detection_head = make_detection_head('detection_head')
    if cascade_ious:
        # Cascade mode: the base head is followed by one numbered head per
        # threshold, and the model receives the whole list.
        detection_head = [detection_head]
        for stage in range(1, len(cascade_ious) + 1):
            detection_head.append(
                make_detection_head('detection_head_{}'.format(stage)))

    roi_generator_obj = roi_generator.MultilevelROIGenerator(
        pre_nms_top_k=proposal_cfg.pre_nms_top_k,
        pre_nms_score_threshold=proposal_cfg.pre_nms_score_threshold,
        pre_nms_min_size_threshold=proposal_cfg.pre_nms_min_size_threshold,
        nms_iou_threshold=proposal_cfg.nms_iou_threshold,
        num_proposals=proposal_cfg.num_proposals,
        test_pre_nms_top_k=proposal_cfg.test_pre_nms_top_k,
        test_pre_nms_score_threshold=(
            proposal_cfg.test_pre_nms_score_threshold),
        test_pre_nms_min_size_threshold=(
            proposal_cfg.test_pre_nms_min_size_threshold),
        test_nms_iou_threshold=proposal_cfg.test_nms_iou_threshold,
        test_num_proposals=proposal_cfg.test_num_proposals,
        use_batched_nms=proposal_cfg.use_batched_nms)

    roi_samplers = [
        roi_sampler.ROISampler(
            mix_gt_boxes=sampler_cfg.mix_gt_boxes,
            num_sampled_rois=sampler_cfg.num_sampled_rois,
            foreground_fraction=sampler_cfg.foreground_fraction,
            foreground_iou_threshold=sampler_cfg.foreground_iou_threshold,
            background_iou_high_threshold=(
                sampler_cfg.background_iou_high_threshold),
            background_iou_low_threshold=(
                sampler_cfg.background_iou_low_threshold))
    ]
    # Additional ROI samplers for the cascade heads: each uses its threshold
    # as both the foreground and the background-high IoU bound.
    if cascade_ious:
        roi_samplers.extend(
            roi_sampler.ROISampler(
                mix_gt_boxes=False,
                num_sampled_rois=sampler_cfg.num_sampled_rois,
                foreground_iou_threshold=iou,
                background_iou_high_threshold=iou,
                background_iou_low_threshold=0.0,
                skip_subsampling=True)
            for iou in cascade_ious)

    roi_aligner_obj = roi_aligner.MultilevelROIAligner(
        crop_size=aligner_cfg.crop_size,
        sample_offset=aligner_cfg.sample_offset)

    detection_generator_obj = detection_generator.DetectionGenerator(
        apply_nms=generator_cfg.apply_nms,
        pre_nms_top_k=generator_cfg.pre_nms_top_k,
        pre_nms_score_threshold=generator_cfg.pre_nms_score_threshold,
        nms_iou_threshold=generator_cfg.nms_iou_threshold,
        max_num_detections=generator_cfg.max_num_detections,
        use_batched_nms=generator_cfg.use_batched_nms)

    # The mask branch is optional; all three mask components stay None when
    # masks are disabled.
    mask_head = None
    mask_sampler_obj = None
    mask_roi_aligner_obj = None
    if model_config.include_mask:
        mask_cfg = model_config.mask_head
        mask_aligner_cfg = model_config.mask_roi_aligner

        mask_head = instance_heads.MaskHead(
            num_classes=model_config.num_classes,
            upsample_factor=mask_cfg.upsample_factor,
            num_convs=mask_cfg.num_convs,
            num_filters=mask_cfg.num_filters,
            use_separable_conv=mask_cfg.use_separable_conv,
            activation=norm_cfg.activation,
            norm_momentum=norm_cfg.norm_momentum,
            norm_epsilon=norm_cfg.norm_epsilon,
            kernel_regularizer=l2_regularizer,
            class_agnostic=mask_cfg.class_agnostic)

        # Mask targets are the aligner crop size scaled by the head's
        # upsample factor.
        mask_sampler_obj = mask_sampler.MaskSampler(
            mask_target_size=(
                mask_aligner_cfg.crop_size * mask_cfg.upsample_factor),
            num_sampled_masks=model_config.mask_sampler.num_sampled_masks)

        mask_roi_aligner_obj = roi_aligner.MultilevelROIAligner(
            crop_size=mask_aligner_cfg.crop_size,
            sample_offset=mask_aligner_cfg.sample_offset)

    return maskrcnn_model.MaskRCNNModel(
        backbone=backbone,
        decoder=decoder,
        rpn_head=rpn_head,
        detection_head=detection_head,
        roi_generator=roi_generator_obj,
        roi_sampler=roi_samplers,
        roi_aligner=roi_aligner_obj,
        detection_generator=detection_generator_obj,
        mask_head=mask_head,
        mask_sampler=mask_sampler_obj,
        mask_roi_aligner=mask_roi_aligner_obj,
        class_agnostic_bbox_pred=det_cfg.class_agnostic_bbox_pred,
        cascade_class_ensemble=det_cfg.cascade_class_ensemble,
        min_level=model_config.min_level,
        max_level=model_config.max_level,
        num_scales=anchor_cfg.num_scales,
        aspect_ratios=anchor_cfg.aspect_ratios,
        anchor_size=anchor_cfg.anchor_size)
コード例 #11
0
def build_submodel(
    norm_activation_config: hyperparams.Config,
    backbone: tf.keras.Model,
    input_specs: tf.keras.layers.InputSpec,
    submodel_config: multitask_config.Submodel,
    l2_regularizer: tf.keras.regularizers.Regularizer = None
) -> tf.keras.Model:
    """Builds the submodel for one subtask on top of an existing backbone.

    `SegmentationModel` is used as the container because it accepts an
    arbitrary backbone, decoder and head. The head type is chosen from the
    type of `submodel_config.head`.

    Args:
      norm_activation_config: Normalization/activation settings forwarded to
        the decoder factory and used by the classification and segmentation
        heads.
      backbone: Already-built backbone; its `output_specs` feed the decoder.
      input_specs: Not read by this function.
      submodel_config: Subtask config supplying `decoder`, `head` and
        `num_classes`.
      l2_regularizer: Optional regularizer forwarded to the decoder factory
        and used as the head's `kernel_regularizer`.

    Returns:
      A `SegmentationModel` wrapping the backbone, decoder and head.

    Raises:
      NotImplementedError: If the head config is not one of the supported
        classification, segmentation or YOLO head types.
    """
    decoder = decoder_factory.build_decoder(
        input_specs=backbone.output_specs,
        model_config=submodel_config,
        norm_activation_config=norm_activation_config,
        l2_regularizer=l2_regularizer)
    if submodel_config.decoder.freeze:
        decoder.trainable = False

    def norm_head_kwargs():
        """Returns the norm/activation kwargs of the non-YOLO heads."""
        return dict(
            activation=norm_activation_config.activation,
            use_sync_bn=norm_activation_config.use_sync_bn,
            norm_momentum=norm_activation_config.norm_momentum,
            norm_epsilon=norm_activation_config.norm_epsilon,
            kernel_regularizer=l2_regularizer)

    head_cfg = submodel_config.head
    task_classes = submodel_config.num_classes

    if isinstance(head_cfg, multitask_config.ImageClassificationHead):
        head = classification_heads.ClassificationHead(
            num_classes=task_classes,
            level=head_cfg.level,
            num_convs=head_cfg.num_convs,
            num_filters=head_cfg.num_filters,
            add_head_batch_norm=head_cfg.add_head_batch_norm,
            dropout_rate=head_cfg.dropout_rate,
            **norm_head_kwargs())
    elif isinstance(head_cfg, multitask_config.SegmentationHead):
        head = segmentation_heads.SegmentationHead(
            num_classes=task_classes,
            level=head_cfg.level,
            num_convs=head_cfg.num_convs,
            prediction_kernel_size=head_cfg.prediction_kernel_size,
            num_filters=head_cfg.num_filters,
            upsample_factor=head_cfg.upsample_factor,
            feature_fusion=head_cfg.feature_fusion,
            low_level=head_cfg.low_level,
            low_level_num_filters=head_cfg.low_level_num_filters,
            **norm_head_kwargs())
    elif isinstance(head_cfg, multitask_config.YoloHead):
        head = instance_heads.YOLOv3Head(
            # One head level per decoder output.
            levels=len(decoder.output_specs),
            num_classes=task_classes,
            strides=head_cfg.strides,
            anchor_per_scale=head_cfg.anchor_per_scale,
            anchors=head_cfg.anchors,
            xy_scale=head_cfg.xy_scale,
            kernel_regularizer=l2_regularizer)
    else:
        raise NotImplementedError(
            '%s head is not implemented yet.' % type(head_cfg))

    if submodel_config.head.freeze:
        head.trainable = False

    return SegmentationModel(backbone, decoder, head)
コード例 #12
0
def build_maskrcnn(input_specs: tf.keras.layers.InputSpec,
                   model_config: deep_mask_head_rcnn_config.DeepMaskHeadRCNN,
                   l2_regularizer: tf.keras.regularizers.Regularizer = None):
    """Assembles a Mask R-CNN model that uses a deep mask head.

    Args:
      input_specs: Input spec handed to the backbone factory.
      model_config: Deep mask head R-CNN config supplying every component's
        settings.
      l2_regularizer: Optional regularizer forwarded to the backbone and
        decoder factories and used as the heads' `kernel_regularizer`.

    Returns:
      A `deep_maskrcnn_model.DeepMaskRCNNModel`.
    """
    norm_cfg = model_config.norm_activation
    backbone = backbones.factory.build_backbone(
        input_specs=input_specs,
        backbone_config=model_config.backbone,
        norm_activation_config=norm_cfg,
        l2_regularizer=l2_regularizer)

    decoder = decoder_factory.build_decoder(
        input_specs=backbone.output_specs,
        model_config=model_config,
        l2_regularizer=l2_regularizer)

    rpn_cfg = model_config.rpn_head
    proposal_cfg = model_config.roi_generator
    sampler_cfg = model_config.roi_sampler
    aligner_cfg = model_config.roi_aligner
    det_cfg = model_config.detection_head
    generator_cfg = model_config.detection_generator
    anchor_cfg = model_config.anchor
    # Each location carries one anchor per (aspect ratio, scale) pair.
    anchors_per_location = (
        len(anchor_cfg.aspect_ratios) * anchor_cfg.num_scales)

    # Norm/activation settings common to the RPN and detection heads.
    shared_head_kwargs = dict(
        activation=norm_cfg.activation,
        use_sync_bn=norm_cfg.use_sync_bn,
        norm_momentum=norm_cfg.norm_momentum,
        norm_epsilon=norm_cfg.norm_epsilon,
        kernel_regularizer=l2_regularizer)

    rpn_head = dense_prediction_heads.RPNHead(
        min_level=model_config.min_level,
        max_level=model_config.max_level,
        num_anchors_per_location=anchors_per_location,
        num_convs=rpn_cfg.num_convs,
        num_filters=rpn_cfg.num_filters,
        use_separable_conv=rpn_cfg.use_separable_conv,
        **shared_head_kwargs)

    detection_head = instance_heads.DetectionHead(
        num_classes=model_config.num_classes,
        num_convs=det_cfg.num_convs,
        num_filters=det_cfg.num_filters,
        use_separable_conv=det_cfg.use_separable_conv,
        num_fcs=det_cfg.num_fcs,
        fc_dims=det_cfg.fc_dims,
        **shared_head_kwargs)

    roi_generator_obj = roi_generator.MultilevelROIGenerator(
        pre_nms_top_k=proposal_cfg.pre_nms_top_k,
        pre_nms_score_threshold=proposal_cfg.pre_nms_score_threshold,
        pre_nms_min_size_threshold=proposal_cfg.pre_nms_min_size_threshold,
        nms_iou_threshold=proposal_cfg.nms_iou_threshold,
        num_proposals=proposal_cfg.num_proposals,
        test_pre_nms_top_k=proposal_cfg.test_pre_nms_top_k,
        test_pre_nms_score_threshold=(
            proposal_cfg.test_pre_nms_score_threshold),
        test_pre_nms_min_size_threshold=(
            proposal_cfg.test_pre_nms_min_size_threshold),
        test_nms_iou_threshold=proposal_cfg.test_nms_iou_threshold,
        test_num_proposals=proposal_cfg.test_num_proposals,
        use_batched_nms=proposal_cfg.use_batched_nms)

    roi_sampler_obj = roi_sampler.ROISampler(
        mix_gt_boxes=sampler_cfg.mix_gt_boxes,
        num_sampled_rois=sampler_cfg.num_sampled_rois,
        foreground_fraction=sampler_cfg.foreground_fraction,
        foreground_iou_threshold=sampler_cfg.foreground_iou_threshold,
        background_iou_high_threshold=(
            sampler_cfg.background_iou_high_threshold),
        background_iou_low_threshold=sampler_cfg.background_iou_low_threshold)

    roi_aligner_obj = roi_aligner.MultilevelROIAligner(
        crop_size=aligner_cfg.crop_size,
        sample_offset=aligner_cfg.sample_offset)

    # NMS is always enabled here; it is not read from the generator config.
    detection_generator_obj = detection_generator.DetectionGenerator(
        apply_nms=True,
        pre_nms_top_k=generator_cfg.pre_nms_top_k,
        pre_nms_score_threshold=generator_cfg.pre_nms_score_threshold,
        nms_iou_threshold=generator_cfg.nms_iou_threshold,
        max_num_detections=generator_cfg.max_num_detections,
        use_batched_nms=generator_cfg.use_batched_nms)

    # The mask branch is optional; all three mask components stay None when
    # masks are disabled.
    mask_head = None
    mask_sampler_obj = None
    mask_roi_aligner_obj = None
    if model_config.include_mask:
        mask_cfg = model_config.mask_head
        mask_aligner_cfg = model_config.mask_roi_aligner

        mask_head = deep_instance_heads.DeepMaskHead(
            num_classes=model_config.num_classes,
            upsample_factor=mask_cfg.upsample_factor,
            num_convs=mask_cfg.num_convs,
            num_filters=mask_cfg.num_filters,
            use_separable_conv=mask_cfg.use_separable_conv,
            activation=norm_cfg.activation,
            norm_momentum=norm_cfg.norm_momentum,
            norm_epsilon=norm_cfg.norm_epsilon,
            kernel_regularizer=l2_regularizer,
            class_agnostic=mask_cfg.class_agnostic,
            convnet_variant=mask_cfg.convnet_variant)

        # Mask targets are the aligner crop size scaled by the head's
        # upsample factor.
        mask_sampler_obj = mask_sampler.MaskSampler(
            mask_target_size=(
                mask_aligner_cfg.crop_size * mask_cfg.upsample_factor),
            num_sampled_masks=model_config.mask_sampler.num_sampled_masks)

        mask_roi_aligner_obj = roi_aligner.MultilevelROIAligner(
            crop_size=mask_aligner_cfg.crop_size,
            sample_offset=mask_aligner_cfg.sample_offset)

    return deep_maskrcnn_model.DeepMaskRCNNModel(
        backbone=backbone,
        decoder=decoder,
        rpn_head=rpn_head,
        detection_head=detection_head,
        roi_generator=roi_generator_obj,
        roi_sampler=roi_sampler_obj,
        roi_aligner=roi_aligner_obj,
        detection_generator=detection_generator_obj,
        mask_head=mask_head,
        mask_sampler=mask_sampler_obj,
        mask_roi_aligner=mask_roi_aligner_obj,
        use_gt_boxes_for_masks=model_config.use_gt_boxes_for_masks)
コード例 #13
0
def build_panoptic_maskrcnn(
    input_specs: tf.keras.layers.InputSpec,
    model_config: panoptic_maskrcnn_cfg.PanopticMaskRCNN,
    l2_regularizer: tf.keras.regularizers.Regularizer = None
) -> tf.keras.Model:  # pytype: disable=annotation-type-mismatch  # typed-keras
    """Assembles a Panoptic Mask R-CNN model from its config.

    Construction happens in a fixed order: the Mask R-CNN model is built
    first, then whichever semantic segmentation pieces are not shared with it
    (backbone, decoder), then the segmentation head, then the optional
    panoptic post-processing generator. All of them are finally wired into a
    single `PanopticMaskRCNNModel`.

    Args:
      input_specs: `tf.keras.layers.InputSpec` specs of the input tensor.
      model_config: Config instance for the panoptic maskrcnn model.
      l2_regularizer: Optional `tf.keras.regularizers.Regularizer`. It is
        forwarded to every sub-builder and used as the segmentation head's
        kernel regularizer.

    Returns:
      The `PanopticMaskRCNNModel` combining both branches.
    """
    norm_act = model_config.norm_activation
    seg_cfg = model_config.segmentation_model

    # Instance branch: its sub-layers are reused by the panoptic model below.
    instance_model = deep_mask_head_rcnn.build_maskrcnn(
        input_specs=input_specs,
        model_config=model_config,
        l2_regularizer=l2_regularizer)

    # Semantic branch backbone: None means "reuse the Mask R-CNN backbone",
    # in which case the decoder consumes that backbone's output specs.
    if model_config.shared_backbone:
        segmentation_backbone = None
        decoder_input_specs = instance_model.backbone.output_specs
    else:
        segmentation_backbone = backbones.factory.build_backbone(
            input_specs=input_specs,
            backbone_config=seg_cfg.backbone,
            norm_activation_config=norm_act,
            l2_regularizer=l2_regularizer)
        decoder_input_specs = segmentation_backbone.output_specs

    # Semantic branch decoder: None means "reuse the Mask R-CNN decoder".
    # Either way, the config of the decoder that will actually feed the head
    # is needed to size the head (`num_filters`).
    if model_config.shared_decoder:
        segmentation_decoder = None
        effective_decoder = instance_model.decoder
    else:
        segmentation_decoder = decoder_factory.build_decoder(
            input_specs=decoder_input_specs,
            model_config=seg_cfg,
            l2_regularizer=l2_regularizer)
        effective_decoder = segmentation_decoder
    decoder_config = effective_decoder.get_config()

    seg_head_cfg = seg_cfg.head
    det_head_cfg = model_config.detection_head
    postprocess_cfg = model_config.panoptic_segmentation_generator

    # Normalization/activation settings shared with the rest of the model.
    norm_kwargs = dict(
        activation=norm_act.activation,
        use_sync_bn=norm_act.use_sync_bn,
        norm_momentum=norm_act.norm_momentum,
        norm_epsilon=norm_act.norm_epsilon)

    segmentation_head = segmentation_heads.SegmentationHead(
        num_classes=seg_cfg.num_classes,
        level=seg_head_cfg.level,
        num_convs=seg_head_cfg.num_convs,
        prediction_kernel_size=seg_head_cfg.prediction_kernel_size,
        num_filters=seg_head_cfg.num_filters,
        upsample_factor=seg_head_cfg.upsample_factor,
        feature_fusion=seg_head_cfg.feature_fusion,
        decoder_min_level=seg_head_cfg.decoder_min_level,
        decoder_max_level=seg_head_cfg.decoder_max_level,
        low_level=seg_head_cfg.low_level,
        low_level_num_filters=seg_head_cfg.low_level_num_filters,
        **norm_kwargs,
        num_decoder_filters=decoder_config['num_filters'],
        kernel_regularizer=l2_regularizer)

    # The post-processing generator is only built when panoptic masks are
    # requested; otherwise the model receives None for it.
    panoptic_generator = None
    if model_config.generate_panoptic_masks:
        detections_cap = model_config.detection_generator.max_num_detections
        binarize_threshold = postprocess_cfg.mask_binarize_threshold
        generator_cls = (
            panoptic_segmentation_generator.PanopticSegmentationGenerator)
        panoptic_generator = generator_cls(
            output_size=postprocess_cfg.output_size,
            max_num_detections=detections_cap,
            stuff_classes_offset=model_config.stuff_classes_offset,
            mask_binarize_threshold=binarize_threshold,
            score_threshold=postprocess_cfg.score_threshold,
            things_overlap_threshold=postprocess_cfg.things_overlap_threshold,
            things_class_label=postprocess_cfg.things_class_label,
            stuff_area_threshold=postprocess_cfg.stuff_area_threshold,
            void_class_label=postprocess_cfg.void_class_label,
            void_instance_id=postprocess_cfg.void_instance_id,
            rescale_predictions=postprocess_cfg.rescale_predictions)

    anchor_cfg = model_config.anchor

    # Wire the Mask R-CNN layers and the segmentation branch together.
    return panoptic_maskrcnn_model.PanopticMaskRCNNModel(
        backbone=instance_model.backbone,
        decoder=instance_model.decoder,
        rpn_head=instance_model.rpn_head,
        detection_head=instance_model.detection_head,
        roi_generator=instance_model.roi_generator,
        roi_sampler=instance_model.roi_sampler,
        roi_aligner=instance_model.roi_aligner,
        detection_generator=instance_model.detection_generator,
        panoptic_segmentation_generator=panoptic_generator,
        mask_head=instance_model.mask_head,
        mask_sampler=instance_model.mask_sampler,
        mask_roi_aligner=instance_model.mask_roi_aligner,
        segmentation_backbone=segmentation_backbone,
        segmentation_decoder=segmentation_decoder,
        segmentation_head=segmentation_head,
        class_agnostic_bbox_pred=det_head_cfg.class_agnostic_bbox_pred,
        cascade_class_ensemble=det_head_cfg.cascade_class_ensemble,
        min_level=model_config.min_level,
        max_level=model_config.max_level,
        num_scales=anchor_cfg.num_scales,
        aspect_ratios=anchor_cfg.aspect_ratios,
        anchor_size=anchor_cfg.anchor_size)
コード例 #14
0
ファイル: factory.py プロジェクト: ykate1998/models
def build_panoptic_maskrcnn(
    input_specs: tf.keras.layers.InputSpec,
    model_config: panoptic_maskrcnn_cfg.PanopticMaskRCNN,
    l2_regularizer: tf.keras.regularizers.Regularizer = None
) -> tf.keras.Model:
    """Creates the Panoptic Mask R-CNN model described by `model_config`.

    The Mask R-CNN model is built first. A dedicated segmentation backbone
    and/or decoder is then built only when the config does not share them
    with Mask R-CNN, followed by the segmentation head. The pieces are
    combined into one `PanopticMaskRCNNModel`.

    Args:
      input_specs: `tf.keras.layers.InputSpec` specs of the input tensor.
      model_config: Config instance for the panoptic maskrcnn model.
      l2_regularizer: Optional `tf.keras.regularizers.Regularizer`, passed on
        to each sub-builder and used as the segmentation head's kernel
        regularizer.

    Returns:
      The `PanopticMaskRCNNModel` combining both branches.
    """
    norm_cfg = model_config.norm_activation
    semantic_cfg = model_config.segmentation_model

    # Detection / instance-mask branch.
    detector = models_factory.build_maskrcnn(
        input_specs=input_specs,
        model_config=model_config,
        l2_regularizer=l2_regularizer)

    # A None backbone/decoder below tells the combined model to reuse the
    # corresponding Mask R-CNN layer for the segmentation branch.
    if model_config.shared_backbone:
        segmentation_backbone = None
        specs_for_decoder = detector.backbone.output_specs
    else:
        segmentation_backbone = backbones.factory.build_backbone(
            input_specs=input_specs,
            backbone_config=semantic_cfg.backbone,
            norm_activation_config=norm_cfg,
            l2_regularizer=l2_regularizer)
        specs_for_decoder = segmentation_backbone.output_specs

    if model_config.shared_decoder:
        segmentation_decoder = None
    else:
        segmentation_decoder = decoder_factory.build_decoder(
            input_specs=specs_for_decoder,
            model_config=semantic_cfg,
            l2_regularizer=l2_regularizer)

    head_cfg = semantic_cfg.head
    box_head_cfg = model_config.detection_head

    segmentation_head = segmentation_heads.SegmentationHead(
        num_classes=semantic_cfg.num_classes,
        level=head_cfg.level,
        num_convs=head_cfg.num_convs,
        prediction_kernel_size=head_cfg.prediction_kernel_size,
        num_filters=head_cfg.num_filters,
        upsample_factor=head_cfg.upsample_factor,
        feature_fusion=head_cfg.feature_fusion,
        low_level=head_cfg.low_level,
        low_level_num_filters=head_cfg.low_level_num_filters,
        activation=norm_cfg.activation,
        use_sync_bn=norm_cfg.use_sync_bn,
        norm_momentum=norm_cfg.norm_momentum,
        norm_epsilon=norm_cfg.norm_epsilon,
        kernel_regularizer=l2_regularizer)

    anchors = model_config.anchor

    # Hand the Mask R-CNN layers plus the segmentation pieces to the combined
    # panoptic model.
    return panoptic_maskrcnn_model.PanopticMaskRCNNModel(
        backbone=detector.backbone,
        decoder=detector.decoder,
        rpn_head=detector.rpn_head,
        detection_head=detector.detection_head,
        roi_generator=detector.roi_generator,
        roi_sampler=detector.roi_sampler,
        roi_aligner=detector.roi_aligner,
        detection_generator=detector.detection_generator,
        mask_head=detector.mask_head,
        mask_sampler=detector.mask_sampler,
        mask_roi_aligner=detector.mask_roi_aligner,
        segmentation_backbone=segmentation_backbone,
        segmentation_decoder=segmentation_decoder,
        segmentation_head=segmentation_head,
        class_agnostic_bbox_pred=box_head_cfg.class_agnostic_bbox_pred,
        cascade_class_ensemble=box_head_cfg.cascade_class_ensemble,
        min_level=model_config.min_level,
        max_level=model_config.max_level,
        num_scales=anchors.num_scales,
        aspect_ratios=anchors.aspect_ratios,
        anchor_size=anchors.anchor_size)