def build_segmentation_model(
    input_specs: tf.keras.layers.InputSpec,
    model_config: segmentation_cfg.SemanticSegmentationModel,
    l2_regularizer: tf.keras.regularizers.Regularizer = None):
  """Builds a semantic segmentation model from its config.

  Assembles the backbone, decoder and segmentation head described by
  `model_config` and wraps them in a `SegmentationModel`.
  """
  backbone = backbones.factory.build_backbone(
      input_specs=input_specs,
      model_config=model_config,
      l2_regularizer=l2_regularizer)
  decoder = decoder_factory.build_decoder(
      input_specs=backbone.output_specs,
      model_config=model_config,
      l2_regularizer=l2_regularizer)

  head_cfg = model_config.head
  norm_act_cfg = model_config.norm_activation
  head = segmentation_heads.SegmentationHead(
      num_classes=model_config.num_classes,
      level=head_cfg.level,
      num_convs=head_cfg.num_convs,
      num_filters=head_cfg.num_filters,
      upsample_factor=head_cfg.upsample_factor,
      feature_fusion=head_cfg.feature_fusion,
      low_level=head_cfg.low_level,
      low_level_num_filters=head_cfg.low_level_num_filters,
      activation=norm_act_cfg.activation,
      use_sync_bn=norm_act_cfg.use_sync_bn,
      norm_momentum=norm_act_cfg.norm_momentum,
      norm_epsilon=norm_act_cfg.norm_epsilon,
      kernel_regularizer=l2_regularizer)

  return segmentation_model.SegmentationModel(backbone, decoder, head)
def build_yolo(input_specs, model_config, l2_regularization):
  """Builds the YOLO detection model and its detection-generator losses."""
  # Anchor boxes are looked up for the level range of the configured backbone.
  backbone_cfg = model_config.backbone.get()
  anchor_dict, _ = model_config.anchor_boxes.get(backbone_cfg.min_level,
                                                 backbone_cfg.max_level)

  backbone = backbone_factory.build_backbone(input_specs,
                                             model_config.backbone,
                                             model_config.norm_activation,
                                             l2_regularization)
  decoder = decoder_factory.build_decoder(backbone.output_specs, model_config,
                                          l2_regularization)
  head = build_yolo_head(decoder.output_specs, model_config, l2_regularization)
  detection_generator_obj = build_yolo_detection_generator(model_config,
                                                           anchor_dict)

  model = yolo_model.Yolo(
      backbone=backbone,
      decoder=decoder,
      head=head,
      detection_generator=detection_generator_obj)
  model.build(input_specs.shape)
  model.summary(print_fn=logging.info)

  return model, detection_generator_obj.get_losses()
def build_basnet_model(
    input_specs: tf.keras.layers.InputSpec,
    model_config: basnet_cfg.BASNetModel,
    l2_regularizer: tf.keras.regularizers.Regularizer = None):
  """Builds BASNet model.

  Args:
    input_specs: `tf.keras.layers.InputSpec` specs of the input tensor.
    model_config: Config instance for the BASNet model.
    l2_regularizer: Optional `tf.keras.regularizers.Regularizer` passed to the
      backbone and decoder builders.

  Returns:
    A `basnet_model.BASNetModel` combining the backbone, the decoder and the
    RefUnet refinement module.
  """
  backbone = backbones.factory.build_backbone(
      input_specs=input_specs,
      model_config=model_config,
      l2_regularizer=l2_regularizer)
  decoder = decoder_factory.build_decoder(
      input_specs=backbone.output_specs,
      model_config=model_config,
      l2_regularizer=l2_regularizer)
  # BASNet refines the coarse decoder output with a residual refinement
  # module (RefUnet) rather than a generic segmentation head.
  refinement = refunet.RefUnet()
  return basnet_model.BASNetModel(backbone, decoder, refinement)
def test_nasfpn_decoder_creation(self, num_filters, num_repeats, use_separable_conv):
  """Test creation of NASFPN decoder."""
  min_level, max_level = 3, 7
  # NOTE(review): specs cover levels [min_level, max_level) — confirm the
  # exclusive upper bound is intended.
  input_specs = {
      str(level): tf.TensorShape([1, 128 // (2**level), 128 // (2**level), 3])
      for level in range(min_level, max_level)
  }

  network = decoders.NASFPN(
      input_specs=input_specs,
      num_filters=num_filters,
      num_repeats=num_repeats,
      use_separable_conv=use_separable_conv,
      use_sync_bn=True)

  model_config = configs.retinanet.RetinaNet()
  model_config.min_level = min_level
  model_config.max_level = max_level
  model_config.num_classes = 10
  model_config.input_size = [None, None, 3]
  model_config.decoder = decoders_cfg.Decoder(
      type='nasfpn',
      nasfpn=decoders_cfg.NASFPN(
          num_filters=num_filters,
          num_repeats=num_repeats,
          use_separable_conv=use_separable_conv))
  factory_network = factory.build_decoder(
      input_specs=input_specs, model_config=model_config)

  # The factory must produce a decoder configured identically to the one
  # constructed directly.
  self.assertEqual(network.get_config(), factory_network.get_config())
def test_aspp_decoder_creation(self, level, dilation_rates, num_filters):
  """Test creation of ASPP decoder."""
  input_specs = {'1': tf.TensorShape([1, 128, 128, 3])}

  network = decoders.ASPP(
      level=level,
      dilation_rates=dilation_rates,
      num_filters=num_filters,
      use_sync_bn=True)

  model_config = configs.semantic_segmentation.SemanticSegmentationModel()
  model_config.num_classes = 10
  model_config.input_size = [None, None, 3]
  model_config.decoder = decoders_cfg.Decoder(
      type='aspp',
      aspp=decoders_cfg.ASPP(
          level=level, dilation_rates=dilation_rates, num_filters=num_filters))
  factory_network = factory.build_decoder(
      input_specs=input_specs, model_config=model_config)

  expected = network.get_config()
  actual = factory_network.get_config()
  # Due to calling `super().get_config()` in the aspp layer, everything but
  # the name of the two layer instances is the same, so we force equal names
  # so it will not give a false alarm.
  actual['name'] = expected['name']
  self.assertEqual(expected, actual)
def build_yolo_model(
    input_specs: tf.keras.layers.InputSpec,
    model_config: yolo_cfg.YoloModel,
    l2_regularizer: tf.keras.regularizers.Regularizer = None
) -> tf.keras.Model:
  """Builds YOLO model."""
  norm_act_cfg = model_config.norm_activation
  backbone = backbones.factory.build_backbone(
      input_specs=input_specs,
      backbone_config=model_config.backbone,
      norm_activation_config=norm_act_cfg,
      l2_regularizer=l2_regularizer)
  decoder = decoder_factory.build_decoder(
      input_specs=backbone.output_specs,
      model_config=model_config,
      l2_regularizer=l2_regularizer)

  head_cfg = model_config.head
  head = instance_heads.YOLOv3Head(
      levels=len(decoder.output_specs),
      num_classes=model_config.num_classes,
      strides=head_cfg.strides,
      anchor_per_scale=head_cfg.anchor_per_scale,
      anchors=head_cfg.anchors,
      xy_scale=head_cfg.xy_scale,
      kernel_regularizer=l2_regularizer)

  # NOTE(review): SegmentationModel is used here as a generic
  # backbone/decoder/head container for the YOLO head — confirm intended.
  return segmentation_model.SegmentationModel(backbone, decoder, head)
def test_aspp_decoder_creation(self, level, dilation_rates, num_filters):
  """Test creation of ASPP decoder."""
  input_specs = {'1': tf.TensorShape([1, 128, 128, 3])}
  network = decoders.ASPP(
      level=level,
      dilation_rates=dilation_rates,
      num_filters=num_filters,
      use_sync_bn=True)

  model_config = configs.semantic_segmentation.SemanticSegmentationModel()
  model_config.num_classes = 10
  model_config.input_size = [None, None, 3]
  model_config.decoder = decoders_cfg.Decoder(
      type='aspp',
      aspp=decoders_cfg.ASPP(
          level=level, dilation_rates=dilation_rates, num_filters=num_filters))
  factory_network = factory.build_decoder(
      input_specs=input_specs, model_config=model_config)

  # The factory-built decoder must be configured identically to the direct one.
  self.assertEqual(network.get_config(), factory_network.get_config())
def build_retinanet(
    input_specs: tf.keras.layers.InputSpec,
    model_config: retinanet_cfg.RetinaNet,
    l2_regularizer: tf.keras.regularizers.Regularizer = None
) -> tf.keras.Model:
  """Builds RetinaNet model.

  Args:
    input_specs: `tf.keras.layers.InputSpec` specs of the input tensor.
    model_config: Config instance for the RetinaNet model.
    l2_regularizer: Optional `tf.keras.regularizers.Regularizer` passed to the
      backbone, decoder and head builders.

  Returns:
    A `retinanet_model.RetinaNetModel` instance.
  """
  norm_activation_config = model_config.norm_activation
  backbone = backbones.factory.build_backbone(
      input_specs=input_specs,
      backbone_config=model_config.backbone,
      norm_activation_config=norm_activation_config,
      l2_regularizer=l2_regularizer)
  # Call the backbone on a symbolic input before reading `output_specs`
  # below; presumably this forces the backbone to be built — TODO confirm.
  backbone(tf.keras.Input(input_specs.shape[1:]))
  decoder = decoder_factory.build_decoder(
      input_specs=backbone.output_specs,
      model_config=model_config,
      l2_regularizer=l2_regularizer)

  head_config = model_config.head
  generator_config = model_config.detection_generator
  # One anchor per (aspect ratio, scale) combination at every location.
  num_anchors_per_location = (len(model_config.anchor.aspect_ratios) *
                              model_config.anchor.num_scales)

  head = dense_prediction_heads.RetinaNetHead(
      min_level=model_config.min_level,
      max_level=model_config.max_level,
      num_classes=model_config.num_classes,
      num_anchors_per_location=num_anchors_per_location,
      num_convs=head_config.num_convs,
      num_filters=head_config.num_filters,
      # `attribute_heads` may be unset; default to no attribute heads.
      attribute_heads=[
          cfg.as_dict() for cfg in (head_config.attribute_heads or [])
      ],
      use_separable_conv=head_config.use_separable_conv,
      activation=norm_activation_config.activation,
      use_sync_bn=norm_activation_config.use_sync_bn,
      norm_momentum=norm_activation_config.norm_momentum,
      norm_epsilon=norm_activation_config.norm_epsilon,
      kernel_regularizer=l2_regularizer)

  detection_generator_obj = detection_generator.MultilevelDetectionGenerator(
      apply_nms=generator_config.apply_nms,
      pre_nms_top_k=generator_config.pre_nms_top_k,
      pre_nms_score_threshold=generator_config.pre_nms_score_threshold,
      nms_iou_threshold=generator_config.nms_iou_threshold,
      max_num_detections=generator_config.max_num_detections,
      use_batched_nms=generator_config.use_batched_nms)

  model = retinanet_model.RetinaNetModel(
      backbone,
      decoder,
      head,
      detection_generator_obj,
      min_level=model_config.min_level,
      max_level=model_config.max_level,
      num_scales=model_config.anchor.num_scales,
      aspect_ratios=model_config.anchor.aspect_ratios,
      anchor_size=model_config.anchor.anchor_size)
  return model
def test_identity_decoder_creation(self):
  """Test creation of identity decoder."""
  model_config = configs.retinanet.RetinaNet()
  model_config.num_classes = 2
  model_config.input_size = [None, None, 3]
  model_config.decoder = decoders_cfg.Decoder(
      type='identity', identity=decoders_cfg.Identity())

  # The factory returns no network for the 'identity' decoder type.
  self.assertIsNone(
      factory.build_decoder(input_specs=None, model_config=model_config))
def build_maskrcnn(
    input_specs: tf.keras.layers.InputSpec,
    model_config: maskrcnn_cfg.MaskRCNN,
    l2_regularizer: tf.keras.regularizers.Regularizer = None
) -> tf.keras.Model:
  """Builds Mask R-CNN model.

  Assembles backbone, decoder, RPN head, (optionally cascaded) detection
  heads, ROI generator/samplers/aligner, detection generator and the optional
  mask branch, then wraps them in a `maskrcnn_model.MaskRCNNModel`.

  Args:
    input_specs: `tf.keras.layers.InputSpec` specs of the input tensor.
    model_config: Config instance for the Mask R-CNN model.
    l2_regularizer: Optional `tf.keras.regularizers.Regularizer` passed to the
      sub-network builders.

  Returns:
    A `maskrcnn_model.MaskRCNNModel` instance.
  """
  norm_activation_config = model_config.norm_activation
  backbone = backbones.factory.build_backbone(
      input_specs=input_specs,
      backbone_config=model_config.backbone,
      norm_activation_config=norm_activation_config,
      l2_regularizer=l2_regularizer)
  decoder = decoder_factory.build_decoder(
      input_specs=backbone.output_specs,
      model_config=model_config,
      l2_regularizer=l2_regularizer)

  rpn_head_config = model_config.rpn_head
  roi_generator_config = model_config.roi_generator
  roi_sampler_config = model_config.roi_sampler
  roi_aligner_config = model_config.roi_aligner
  detection_head_config = model_config.detection_head
  generator_config = model_config.detection_generator
  # One anchor per (aspect ratio, scale) combination at every location.
  num_anchors_per_location = (len(model_config.anchor.aspect_ratios) *
                              model_config.anchor.num_scales)

  rpn_head = dense_prediction_heads.RPNHead(
      min_level=model_config.min_level,
      max_level=model_config.max_level,
      num_anchors_per_location=num_anchors_per_location,
      num_convs=rpn_head_config.num_convs,
      num_filters=rpn_head_config.num_filters,
      use_separable_conv=rpn_head_config.use_separable_conv,
      activation=norm_activation_config.activation,
      use_sync_bn=norm_activation_config.use_sync_bn,
      norm_momentum=norm_activation_config.norm_momentum,
      norm_epsilon=norm_activation_config.norm_epsilon,
      kernel_regularizer=l2_regularizer)

  # First-stage detection head; additional heads are appended below when
  # cascade IoU thresholds are configured.
  detection_head = instance_heads.DetectionHead(
      num_classes=model_config.num_classes,
      num_convs=detection_head_config.num_convs,
      num_filters=detection_head_config.num_filters,
      use_separable_conv=detection_head_config.use_separable_conv,
      num_fcs=detection_head_config.num_fcs,
      fc_dims=detection_head_config.fc_dims,
      class_agnostic_bbox_pred=detection_head_config.class_agnostic_bbox_pred,
      activation=norm_activation_config.activation,
      use_sync_bn=norm_activation_config.use_sync_bn,
      norm_momentum=norm_activation_config.norm_momentum,
      norm_epsilon=norm_activation_config.norm_epsilon,
      kernel_regularizer=l2_regularizer,
      name='detection_head')
  if roi_sampler_config.cascade_iou_thresholds:
    # Cascade R-CNN style: one extra detection head per cascade IoU
    # threshold; `detection_head` becomes a list of heads.
    detection_head_cascade = [detection_head]
    for cascade_num in range(len(roi_sampler_config.cascade_iou_thresholds)):
      detection_head = instance_heads.DetectionHead(
          num_classes=model_config.num_classes,
          num_convs=detection_head_config.num_convs,
          num_filters=detection_head_config.num_filters,
          use_separable_conv=detection_head_config.use_separable_conv,
          num_fcs=detection_head_config.num_fcs,
          fc_dims=detection_head_config.fc_dims,
          class_agnostic_bbox_pred=detection_head_config.class_agnostic_bbox_pred,
          activation=norm_activation_config.activation,
          use_sync_bn=norm_activation_config.use_sync_bn,
          norm_momentum=norm_activation_config.norm_momentum,
          norm_epsilon=norm_activation_config.norm_epsilon,
          kernel_regularizer=l2_regularizer,
          name='detection_head_{}'.format(cascade_num + 1))
      detection_head_cascade.append(detection_head)
    detection_head = detection_head_cascade

  roi_generator_obj = roi_generator.MultilevelROIGenerator(
      pre_nms_top_k=roi_generator_config.pre_nms_top_k,
      pre_nms_score_threshold=roi_generator_config.pre_nms_score_threshold,
      pre_nms_min_size_threshold=(
          roi_generator_config.pre_nms_min_size_threshold),
      nms_iou_threshold=roi_generator_config.nms_iou_threshold,
      num_proposals=roi_generator_config.num_proposals,
      test_pre_nms_top_k=roi_generator_config.test_pre_nms_top_k,
      test_pre_nms_score_threshold=(
          roi_generator_config.test_pre_nms_score_threshold),
      test_pre_nms_min_size_threshold=(
          roi_generator_config.test_pre_nms_min_size_threshold),
      test_nms_iou_threshold=roi_generator_config.test_nms_iou_threshold,
      test_num_proposals=roi_generator_config.test_num_proposals,
      use_batched_nms=roi_generator_config.use_batched_nms)

  roi_sampler_cascade = []
  roi_sampler_obj = roi_sampler.ROISampler(
      mix_gt_boxes=roi_sampler_config.mix_gt_boxes,
      num_sampled_rois=roi_sampler_config.num_sampled_rois,
      foreground_fraction=roi_sampler_config.foreground_fraction,
      foreground_iou_threshold=roi_sampler_config.foreground_iou_threshold,
      background_iou_high_threshold=(
          roi_sampler_config.background_iou_high_threshold),
      background_iou_low_threshold=(
          roi_sampler_config.background_iou_low_threshold))
  roi_sampler_cascade.append(roi_sampler_obj)
  # Initialize additional roi samplers for cascade heads. These use the
  # cascade IoU as both foreground and background-high threshold and skip
  # subsampling.
  if roi_sampler_config.cascade_iou_thresholds:
    for iou in roi_sampler_config.cascade_iou_thresholds:
      roi_sampler_obj = roi_sampler.ROISampler(
          mix_gt_boxes=False,
          num_sampled_rois=roi_sampler_config.num_sampled_rois,
          foreground_iou_threshold=iou,
          background_iou_high_threshold=iou,
          background_iou_low_threshold=0.0,
          skip_subsampling=True)
      roi_sampler_cascade.append(roi_sampler_obj)

  roi_aligner_obj = roi_aligner.MultilevelROIAligner(
      crop_size=roi_aligner_config.crop_size,
      sample_offset=roi_aligner_config.sample_offset)
  detection_generator_obj = detection_generator.DetectionGenerator(
      apply_nms=generator_config.apply_nms,
      pre_nms_top_k=generator_config.pre_nms_top_k,
      pre_nms_score_threshold=generator_config.pre_nms_score_threshold,
      nms_iou_threshold=generator_config.nms_iou_threshold,
      max_num_detections=generator_config.max_num_detections,
      use_batched_nms=generator_config.use_batched_nms)

  # The mask branch is optional; when disabled the model is a plain
  # Faster R-CNN style detector.
  if model_config.include_mask:
    mask_head = instance_heads.MaskHead(
        num_classes=model_config.num_classes,
        upsample_factor=model_config.mask_head.upsample_factor,
        num_convs=model_config.mask_head.num_convs,
        num_filters=model_config.mask_head.num_filters,
        use_separable_conv=model_config.mask_head.use_separable_conv,
        activation=model_config.norm_activation.activation,
        norm_momentum=model_config.norm_activation.norm_momentum,
        norm_epsilon=model_config.norm_activation.norm_epsilon,
        kernel_regularizer=l2_regularizer,
        class_agnostic=model_config.mask_head.class_agnostic)
    mask_sampler_obj = mask_sampler.MaskSampler(
        # Mask targets match the upsampled head output resolution.
        mask_target_size=(model_config.mask_roi_aligner.crop_size *
                          model_config.mask_head.upsample_factor),
        num_sampled_masks=model_config.mask_sampler.num_sampled_masks)
    mask_roi_aligner_obj = roi_aligner.MultilevelROIAligner(
        crop_size=model_config.mask_roi_aligner.crop_size,
        sample_offset=model_config.mask_roi_aligner.sample_offset)
  else:
    mask_head = None
    mask_sampler_obj = None
    mask_roi_aligner_obj = None

  model = maskrcnn_model.MaskRCNNModel(
      backbone=backbone,
      decoder=decoder,
      rpn_head=rpn_head,
      detection_head=detection_head,
      roi_generator=roi_generator_obj,
      roi_sampler=roi_sampler_cascade,
      roi_aligner=roi_aligner_obj,
      detection_generator=detection_generator_obj,
      mask_head=mask_head,
      mask_sampler=mask_sampler_obj,
      mask_roi_aligner=mask_roi_aligner_obj,
      class_agnostic_bbox_pred=detection_head_config.class_agnostic_bbox_pred,
      cascade_class_ensemble=detection_head_config.cascade_class_ensemble,
      min_level=model_config.min_level,
      max_level=model_config.max_level,
      num_scales=model_config.anchor.num_scales,
      aspect_ratios=model_config.anchor.aspect_ratios,
      anchor_size=model_config.anchor.anchor_size)
  return model
def build_submodel(
    norm_activation_config: hyperparams.Config,
    backbone: tf.keras.Model,
    input_specs: tf.keras.layers.InputSpec,
    submodel_config: multitask_config.Submodel,
    l2_regularizer: tf.keras.regularizers.Regularizer = None
) -> tf.keras.Model:
  """Builds submodel for a subtask.

  Leverages on SegmentationModel's structure that takes any arbitrary
  backbone, decoder and head.
  """
  decoder = decoder_factory.build_decoder(
      input_specs=backbone.output_specs,
      model_config=submodel_config,
      norm_activation_config=norm_activation_config,
      l2_regularizer=l2_regularizer)
  # Optionally freeze decoder weights per config.
  if submodel_config.decoder.freeze:
    decoder.trainable = False

  # Dispatch on the head config type: classification, segmentation or YOLO.
  head_config = submodel_config.head
  if isinstance(head_config, multitask_config.ImageClassificationHead):
    head = classification_heads.ClassificationHead(
        num_classes=submodel_config.num_classes,
        level=head_config.level,
        num_convs=head_config.num_convs,
        num_filters=head_config.num_filters,
        add_head_batch_norm=head_config.add_head_batch_norm,
        activation=norm_activation_config.activation,
        use_sync_bn=norm_activation_config.use_sync_bn,
        dropout_rate=head_config.dropout_rate,
        norm_momentum=norm_activation_config.norm_momentum,
        norm_epsilon=norm_activation_config.norm_epsilon,
        kernel_regularizer=l2_regularizer)
  elif isinstance(head_config, multitask_config.SegmentationHead):
    head = segmentation_heads.SegmentationHead(
        num_classes=submodel_config.num_classes,
        level=head_config.level,
        num_convs=head_config.num_convs,
        prediction_kernel_size=head_config.prediction_kernel_size,
        num_filters=head_config.num_filters,
        upsample_factor=head_config.upsample_factor,
        feature_fusion=head_config.feature_fusion,
        low_level=head_config.low_level,
        low_level_num_filters=head_config.low_level_num_filters,
        activation=norm_activation_config.activation,
        use_sync_bn=norm_activation_config.use_sync_bn,
        norm_momentum=norm_activation_config.norm_momentum,
        norm_epsilon=norm_activation_config.norm_epsilon,
        kernel_regularizer=l2_regularizer)
  elif isinstance(head_config, multitask_config.YoloHead):
    head = instance_heads.YOLOv3Head(
        # One head output per decoder level.
        levels=len(decoder.output_specs),
        num_classes=submodel_config.num_classes,
        strides=head_config.strides,
        anchor_per_scale=head_config.anchor_per_scale,
        anchors=head_config.anchors,
        xy_scale=head_config.xy_scale,
        kernel_regularizer=l2_regularizer)
  else:
    raise NotImplementedError('%s head is not implemented yet.' %
                              (type(head_config)))
  # Optionally freeze head weights per config.
  if submodel_config.head.freeze:
    head.trainable = False
  return SegmentationModel(backbone, decoder, head)
def build_maskrcnn(input_specs: tf.keras.layers.InputSpec,
                   model_config: deep_mask_head_rcnn_config.DeepMaskHeadRCNN,
                   l2_regularizer: tf.keras.regularizers.Regularizer = None):
  """Builds Mask R-CNN model.

  Variant that uses `deep_instance_heads.DeepMaskHead` for the mask branch
  and wraps everything in a `deep_maskrcnn_model.DeepMaskRCNNModel`.

  Args:
    input_specs: `tf.keras.layers.InputSpec` specs of the input tensor.
    model_config: Config instance for the deep-mask-head Mask R-CNN model.
    l2_regularizer: Optional `tf.keras.regularizers.Regularizer` passed to the
      sub-network builders.

  Returns:
    A `deep_maskrcnn_model.DeepMaskRCNNModel` instance.
  """
  norm_activation_config = model_config.norm_activation
  backbone = backbones.factory.build_backbone(
      input_specs=input_specs,
      backbone_config=model_config.backbone,
      norm_activation_config=norm_activation_config,
      l2_regularizer=l2_regularizer)
  decoder = decoder_factory.build_decoder(
      input_specs=backbone.output_specs,
      model_config=model_config,
      l2_regularizer=l2_regularizer)

  rpn_head_config = model_config.rpn_head
  roi_generator_config = model_config.roi_generator
  roi_sampler_config = model_config.roi_sampler
  roi_aligner_config = model_config.roi_aligner
  detection_head_config = model_config.detection_head
  generator_config = model_config.detection_generator
  # One anchor per (aspect ratio, scale) combination at every location.
  num_anchors_per_location = (len(model_config.anchor.aspect_ratios) *
                              model_config.anchor.num_scales)

  rpn_head = dense_prediction_heads.RPNHead(
      min_level=model_config.min_level,
      max_level=model_config.max_level,
      num_anchors_per_location=num_anchors_per_location,
      num_convs=rpn_head_config.num_convs,
      num_filters=rpn_head_config.num_filters,
      use_separable_conv=rpn_head_config.use_separable_conv,
      activation=norm_activation_config.activation,
      use_sync_bn=norm_activation_config.use_sync_bn,
      norm_momentum=norm_activation_config.norm_momentum,
      norm_epsilon=norm_activation_config.norm_epsilon,
      kernel_regularizer=l2_regularizer)

  detection_head = instance_heads.DetectionHead(
      num_classes=model_config.num_classes,
      num_convs=detection_head_config.num_convs,
      num_filters=detection_head_config.num_filters,
      use_separable_conv=detection_head_config.use_separable_conv,
      num_fcs=detection_head_config.num_fcs,
      fc_dims=detection_head_config.fc_dims,
      activation=norm_activation_config.activation,
      use_sync_bn=norm_activation_config.use_sync_bn,
      norm_momentum=norm_activation_config.norm_momentum,
      norm_epsilon=norm_activation_config.norm_epsilon,
      kernel_regularizer=l2_regularizer)

  roi_generator_obj = roi_generator.MultilevelROIGenerator(
      pre_nms_top_k=roi_generator_config.pre_nms_top_k,
      pre_nms_score_threshold=roi_generator_config.pre_nms_score_threshold,
      pre_nms_min_size_threshold=(
          roi_generator_config.pre_nms_min_size_threshold),
      nms_iou_threshold=roi_generator_config.nms_iou_threshold,
      num_proposals=roi_generator_config.num_proposals,
      test_pre_nms_top_k=roi_generator_config.test_pre_nms_top_k,
      test_pre_nms_score_threshold=(
          roi_generator_config.test_pre_nms_score_threshold),
      test_pre_nms_min_size_threshold=(
          roi_generator_config.test_pre_nms_min_size_threshold),
      test_nms_iou_threshold=roi_generator_config.test_nms_iou_threshold,
      test_num_proposals=roi_generator_config.test_num_proposals,
      use_batched_nms=roi_generator_config.use_batched_nms)

  roi_sampler_obj = roi_sampler.ROISampler(
      mix_gt_boxes=roi_sampler_config.mix_gt_boxes,
      num_sampled_rois=roi_sampler_config.num_sampled_rois,
      foreground_fraction=roi_sampler_config.foreground_fraction,
      foreground_iou_threshold=roi_sampler_config.foreground_iou_threshold,
      background_iou_high_threshold=(
          roi_sampler_config.background_iou_high_threshold),
      background_iou_low_threshold=(
          roi_sampler_config.background_iou_low_threshold))

  roi_aligner_obj = roi_aligner.MultilevelROIAligner(
      crop_size=roi_aligner_config.crop_size,
      sample_offset=roi_aligner_config.sample_offset)

  detection_generator_obj = detection_generator.DetectionGenerator(
      # NOTE(review): NMS is hard-coded on here instead of reading
      # `generator_config.apply_nms` — confirm this is intended.
      apply_nms=True,
      pre_nms_top_k=generator_config.pre_nms_top_k,
      pre_nms_score_threshold=generator_config.pre_nms_score_threshold,
      nms_iou_threshold=generator_config.nms_iou_threshold,
      max_num_detections=generator_config.max_num_detections,
      use_batched_nms=generator_config.use_batched_nms)

  # The mask branch is optional; when disabled the model is detection-only.
  if model_config.include_mask:
    mask_head = deep_instance_heads.DeepMaskHead(
        num_classes=model_config.num_classes,
        upsample_factor=model_config.mask_head.upsample_factor,
        num_convs=model_config.mask_head.num_convs,
        num_filters=model_config.mask_head.num_filters,
        use_separable_conv=model_config.mask_head.use_separable_conv,
        activation=model_config.norm_activation.activation,
        norm_momentum=model_config.norm_activation.norm_momentum,
        norm_epsilon=model_config.norm_activation.norm_epsilon,
        kernel_regularizer=l2_regularizer,
        class_agnostic=model_config.mask_head.class_agnostic,
        convnet_variant=model_config.mask_head.convnet_variant)
    mask_sampler_obj = mask_sampler.MaskSampler(
        # Mask targets match the upsampled head output resolution.
        mask_target_size=(model_config.mask_roi_aligner.crop_size *
                          model_config.mask_head.upsample_factor),
        num_sampled_masks=model_config.mask_sampler.num_sampled_masks)
    mask_roi_aligner_obj = roi_aligner.MultilevelROIAligner(
        crop_size=model_config.mask_roi_aligner.crop_size,
        sample_offset=model_config.mask_roi_aligner.sample_offset)
  else:
    mask_head = None
    mask_sampler_obj = None
    mask_roi_aligner_obj = None

  model = deep_maskrcnn_model.DeepMaskRCNNModel(
      backbone=backbone,
      decoder=decoder,
      rpn_head=rpn_head,
      detection_head=detection_head,
      roi_generator=roi_generator_obj,
      roi_sampler=roi_sampler_obj,
      roi_aligner=roi_aligner_obj,
      detection_generator=detection_generator_obj,
      mask_head=mask_head,
      mask_sampler=mask_sampler_obj,
      mask_roi_aligner=mask_roi_aligner_obj,
      use_gt_boxes_for_masks=model_config.use_gt_boxes_for_masks)
  return model
def build_panoptic_maskrcnn(
    input_specs: tf.keras.layers.InputSpec,
    model_config: panoptic_maskrcnn_cfg.PanopticMaskRCNN,
    l2_regularizer: tf.keras.regularizers.Regularizer = None
) -> tf.keras.Model:  # pytype: disable=annotation-type-mismatch  # typed-keras
  """Builds Panoptic Mask R-CNN model.

  This factory function builds the mask rcnn first, builds the non-shared
  semantic segmentation layers, and finally combines the two models to form
  the panoptic segmentation model.

  Args:
    input_specs: `tf.keras.layers.InputSpec` specs of the input tensor.
    model_config: Config instance for the panoptic maskrcnn model.
    l2_regularizer: Optional `tf.keras.regularizers.Regularizer`, if specified,
      the model is built with the provided regularization layer.

  Returns:
    tf.keras.Model for the panoptic segmentation model.
  """
  norm_activation_config = model_config.norm_activation
  segmentation_config = model_config.segmentation_model

  # Builds the maskrcnn model.
  maskrcnn_model = deep_mask_head_rcnn.build_maskrcnn(
      input_specs=input_specs,
      model_config=model_config,
      l2_regularizer=l2_regularizer)

  # Builds the semantic segmentation branch. Backbone and/or decoder may be
  # shared with the mask rcnn model depending on the config.
  if not model_config.shared_backbone:
    segmentation_backbone = backbones.factory.build_backbone(
        input_specs=input_specs,
        backbone_config=segmentation_config.backbone,
        norm_activation_config=norm_activation_config,
        l2_regularizer=l2_regularizer)
    segmentation_decoder_input_specs = segmentation_backbone.output_specs
  else:
    segmentation_backbone = None
    segmentation_decoder_input_specs = maskrcnn_model.backbone.output_specs
  if not model_config.shared_decoder:
    segmentation_decoder = decoder_factory.build_decoder(
        input_specs=segmentation_decoder_input_specs,
        model_config=segmentation_config,
        l2_regularizer=l2_regularizer)
    decoder_config = segmentation_decoder.get_config()
  else:
    segmentation_decoder = None
    # When the decoder is shared, read its config from the maskrcnn decoder.
    decoder_config = maskrcnn_model.decoder.get_config()

  segmentation_head_config = segmentation_config.head
  detection_head_config = model_config.detection_head
  postprocessing_config = model_config.panoptic_segmentation_generator

  segmentation_head = segmentation_heads.SegmentationHead(
      num_classes=segmentation_config.num_classes,
      level=segmentation_head_config.level,
      num_convs=segmentation_head_config.num_convs,
      prediction_kernel_size=segmentation_head_config.prediction_kernel_size,
      num_filters=segmentation_head_config.num_filters,
      upsample_factor=segmentation_head_config.upsample_factor,
      feature_fusion=segmentation_head_config.feature_fusion,
      decoder_min_level=segmentation_head_config.decoder_min_level,
      decoder_max_level=segmentation_head_config.decoder_max_level,
      low_level=segmentation_head_config.low_level,
      low_level_num_filters=segmentation_head_config.low_level_num_filters,
      activation=norm_activation_config.activation,
      use_sync_bn=norm_activation_config.use_sync_bn,
      norm_momentum=norm_activation_config.norm_momentum,
      norm_epsilon=norm_activation_config.norm_epsilon,
      num_decoder_filters=decoder_config['num_filters'],
      kernel_regularizer=l2_regularizer)

  # Optional post-processor that merges instance and semantic predictions
  # into panoptic masks.
  if model_config.generate_panoptic_masks:
    max_num_detections = model_config.detection_generator.max_num_detections
    mask_binarize_threshold = postprocessing_config.mask_binarize_threshold
    panoptic_segmentation_generator_obj = panoptic_segmentation_generator.PanopticSegmentationGenerator(
        output_size=postprocessing_config.output_size,
        max_num_detections=max_num_detections,
        stuff_classes_offset=model_config.stuff_classes_offset,
        mask_binarize_threshold=mask_binarize_threshold,
        score_threshold=postprocessing_config.score_threshold,
        things_overlap_threshold=postprocessing_config.things_overlap_threshold,
        things_class_label=postprocessing_config.things_class_label,
        stuff_area_threshold=postprocessing_config.stuff_area_threshold,
        void_class_label=postprocessing_config.void_class_label,
        void_instance_id=postprocessing_config.void_instance_id,
        rescale_predictions=postprocessing_config.rescale_predictions)
  else:
    panoptic_segmentation_generator_obj = None

  # Combines maskrcnn, and segmentation models to build panoptic segmentation
  # model.
  model = panoptic_maskrcnn_model.PanopticMaskRCNNModel(
      backbone=maskrcnn_model.backbone,
      decoder=maskrcnn_model.decoder,
      rpn_head=maskrcnn_model.rpn_head,
      detection_head=maskrcnn_model.detection_head,
      roi_generator=maskrcnn_model.roi_generator,
      roi_sampler=maskrcnn_model.roi_sampler,
      roi_aligner=maskrcnn_model.roi_aligner,
      detection_generator=maskrcnn_model.detection_generator,
      panoptic_segmentation_generator=panoptic_segmentation_generator_obj,
      mask_head=maskrcnn_model.mask_head,
      mask_sampler=maskrcnn_model.mask_sampler,
      mask_roi_aligner=maskrcnn_model.mask_roi_aligner,
      segmentation_backbone=segmentation_backbone,
      segmentation_decoder=segmentation_decoder,
      segmentation_head=segmentation_head,
      class_agnostic_bbox_pred=detection_head_config.class_agnostic_bbox_pred,
      cascade_class_ensemble=detection_head_config.cascade_class_ensemble,
      min_level=model_config.min_level,
      max_level=model_config.max_level,
      num_scales=model_config.anchor.num_scales,
      aspect_ratios=model_config.anchor.aspect_ratios,
      anchor_size=model_config.anchor.anchor_size)
  return model
def build_panoptic_maskrcnn(
    input_specs: tf.keras.layers.InputSpec,
    model_config: panoptic_maskrcnn_cfg.PanopticMaskRCNN,
    l2_regularizer: tf.keras.regularizers.Regularizer = None
) -> tf.keras.Model:
  """Builds Panoptic Mask R-CNN model.

  This factory function builds the mask rcnn first, builds the non-shared
  semantic segmentation layers, and finally combines the two models to form
  the panoptic segmentation model.

  Args:
    input_specs: `tf.keras.layers.InputSpec` specs of the input tensor.
    model_config: Config instance for the panoptic maskrcnn model.
    l2_regularizer: Optional `tf.keras.regularizers.Regularizer`, if specified,
      the model is built with the provided regularization layer.

  Returns:
    tf.keras.Model for the panoptic segmentation model.
  """
  norm_activation_config = model_config.norm_activation
  segmentation_config = model_config.segmentation_model

  # Builds the maskrcnn model.
  maskrcnn_model = models_factory.build_maskrcnn(
      input_specs=input_specs,
      model_config=model_config,
      l2_regularizer=l2_regularizer)

  # Builds the semantic segmentation branch. Backbone and/or decoder may be
  # shared with the mask rcnn model depending on the config.
  if not model_config.shared_backbone:
    segmentation_backbone = backbones.factory.build_backbone(
        input_specs=input_specs,
        backbone_config=segmentation_config.backbone,
        norm_activation_config=norm_activation_config,
        l2_regularizer=l2_regularizer)
    segmentation_decoder_input_specs = segmentation_backbone.output_specs
  else:
    segmentation_backbone = None
    segmentation_decoder_input_specs = maskrcnn_model.backbone.output_specs
  if not model_config.shared_decoder:
    segmentation_decoder = decoder_factory.build_decoder(
        input_specs=segmentation_decoder_input_specs,
        model_config=segmentation_config,
        l2_regularizer=l2_regularizer)
  else:
    segmentation_decoder = None

  segmentation_head_config = segmentation_config.head
  detection_head_config = model_config.detection_head

  segmentation_head = segmentation_heads.SegmentationHead(
      num_classes=segmentation_config.num_classes,
      level=segmentation_head_config.level,
      num_convs=segmentation_head_config.num_convs,
      prediction_kernel_size=segmentation_head_config.prediction_kernel_size,
      num_filters=segmentation_head_config.num_filters,
      upsample_factor=segmentation_head_config.upsample_factor,
      feature_fusion=segmentation_head_config.feature_fusion,
      low_level=segmentation_head_config.low_level,
      low_level_num_filters=segmentation_head_config.low_level_num_filters,
      activation=norm_activation_config.activation,
      use_sync_bn=norm_activation_config.use_sync_bn,
      norm_momentum=norm_activation_config.norm_momentum,
      norm_epsilon=norm_activation_config.norm_epsilon,
      kernel_regularizer=l2_regularizer)

  # Combines maskrcnn, and segmentation models to build panoptic segmentation
  # model.
  model = panoptic_maskrcnn_model.PanopticMaskRCNNModel(
      backbone=maskrcnn_model.backbone,
      decoder=maskrcnn_model.decoder,
      rpn_head=maskrcnn_model.rpn_head,
      detection_head=maskrcnn_model.detection_head,
      roi_generator=maskrcnn_model.roi_generator,
      roi_sampler=maskrcnn_model.roi_sampler,
      roi_aligner=maskrcnn_model.roi_aligner,
      detection_generator=maskrcnn_model.detection_generator,
      mask_head=maskrcnn_model.mask_head,
      mask_sampler=maskrcnn_model.mask_sampler,
      mask_roi_aligner=maskrcnn_model.mask_roi_aligner,
      segmentation_backbone=segmentation_backbone,
      segmentation_decoder=segmentation_decoder,
      segmentation_head=segmentation_head,
      class_agnostic_bbox_pred=detection_head_config.class_agnostic_bbox_pred,
      cascade_class_ensemble=detection_head_config.cascade_class_ensemble,
      min_level=model_config.min_level,
      max_level=model_config.max_level,
      num_scales=model_config.anchor.num_scales,
      aspect_ratios=model_config.anchor.aspect_ratios,
      anchor_size=model_config.anchor.anchor_size)
  return model