def build_segmentation_model(
    input_specs: tf.keras.layers.InputSpec,
    model_config: segmentation_cfg.SemanticSegmentationModel,
    l2_regularizer: tf.keras.regularizers.Regularizer = None):
  """Assembles a segmentation model from a backbone, a decoder and a head.

  Args:
    input_specs: Input specification handed to the backbone factory.
    model_config: Model configuration. It is forwarded whole to the backbone
      and decoder factories; its `head`, `norm_activation` and `num_classes`
      fields parameterize the segmentation head.
    l2_regularizer: Regularizer (or None) forwarded to the backbone and
      decoder factories and used as the head's `kernel_regularizer`.

  Returns:
    A `segmentation_model.SegmentationModel` wrapping the three components.
  """
  # The decoder is configured from the backbone's output specs, so the
  # backbone has to exist first.
  feature_extractor = backbones.factory.build_backbone(
      input_specs=input_specs,
      model_config=model_config,
      l2_regularizer=l2_regularizer)
  feature_decoder = decoder_factory.build_decoder(
      input_specs=feature_extractor.output_specs,
      model_config=model_config,
      l2_regularizer=l2_regularizer)

  head_cfg = model_config.head
  norm_cfg = model_config.norm_activation
  head_kwargs = dict(
      num_classes=model_config.num_classes,
      level=head_cfg.level,
      num_convs=head_cfg.num_convs,
      num_filters=head_cfg.num_filters,
      upsample_factor=head_cfg.upsample_factor,
      feature_fusion=head_cfg.feature_fusion,
      low_level=head_cfg.low_level,
      low_level_num_filters=head_cfg.low_level_num_filters,
      activation=norm_cfg.activation,
      use_sync_bn=norm_cfg.use_sync_bn,
      norm_momentum=norm_cfg.norm_momentum,
      norm_epsilon=norm_cfg.norm_epsilon,
      kernel_regularizer=l2_regularizer)
  seg_head = segmentation_heads.SegmentationHead(**head_kwargs)

  return segmentation_model.SegmentationModel(
      feature_extractor, feature_decoder, seg_head)
def build_segmentation_model_3d(
    input_specs: tf.keras.layers.InputSpec,
    model_config: hyperparams.Config,
    l2_regularizer: tf.keras.regularizers.Regularizer = None) -> tf.keras.Model:  # pytype: disable=annotation-type-mismatch  # typed-keras
  """Assembles a 3D segmentation model from a backbone, a decoder and a head.

  Args:
    input_specs: Input specification handed to the backbone factory.
    model_config: Model configuration. Its `backbone` and `norm_activation`
      fields configure the backbone, the whole config is forwarded to the
      decoder factory, and `head`, `norm_activation` and `num_classes`
      parameterize the `SegmentationHead3D`.
    l2_regularizer: Regularizer (or None) forwarded to the backbone and
      decoder factories and used as the head's `kernel_regularizer`.

  Returns:
    A `segmentation_model.SegmentationModel` wrapping the three components.
  """
  norm_cfg = model_config.norm_activation

  # The decoder is configured from the backbone's output specs, so the
  # backbone has to exist first.
  feature_extractor = backbone_factory.build_backbone(
      input_specs=input_specs,
      backbone_config=model_config.backbone,
      norm_activation_config=norm_cfg,
      l2_regularizer=l2_regularizer)
  feature_decoder = decoder_factory.build_decoder(
      input_specs=feature_extractor.output_specs,
      model_config=model_config,
      l2_regularizer=l2_regularizer)

  head_cfg = model_config.head
  head_kwargs = dict(
      num_classes=model_config.num_classes,
      level=head_cfg.level,
      num_convs=head_cfg.num_convs,
      num_filters=head_cfg.num_filters,
      upsample_factor=head_cfg.upsample_factor,
      activation=norm_cfg.activation,
      use_sync_bn=norm_cfg.use_sync_bn,
      norm_momentum=norm_cfg.norm_momentum,
      norm_epsilon=norm_cfg.norm_epsilon,
      use_batch_normalization=head_cfg.use_batch_normalization,
      kernel_regularizer=l2_regularizer,
      output_logits=head_cfg.output_logits)
  seg_head = segmentation_heads_3d.SegmentationHead3D(**head_kwargs)

  return segmentation_model.SegmentationModel(
      feature_extractor, feature_decoder, seg_head)
def build_yolo_model(
    input_specs: tf.keras.layers.InputSpec,
    model_config: yolo_cfg.YoloModel,
    l2_regularizer: tf.keras.regularizers.Regularizer = None
) -> tf.keras.Model:
  """Assembles a YOLO model from a backbone, a decoder and a YOLOv3 head.

  Args:
    input_specs: Input specification handed to the backbone factory.
    model_config: Model configuration. Its `backbone` and `norm_activation`
      fields configure the backbone, the whole config is forwarded to the
      decoder factory, and `head` and `num_classes` parameterize the head.
    l2_regularizer: Regularizer (or None) forwarded to the backbone and
      decoder factories and used as the head's `kernel_regularizer`.

  Returns:
    The backbone, decoder and head wrapped in a
    `segmentation_model.SegmentationModel`.
  """
  norm_cfg = model_config.norm_activation

  # The decoder is configured from the backbone's output specs, so the
  # backbone has to exist first.
  feature_extractor = backbones.factory.build_backbone(
      input_specs=input_specs,
      backbone_config=model_config.backbone,
      norm_activation_config=norm_cfg,
      l2_regularizer=l2_regularizer)
  feature_decoder = decoder_factory.build_decoder(
      input_specs=feature_extractor.output_specs,
      model_config=model_config,
      l2_regularizer=l2_regularizer)

  head_cfg = model_config.head
  # The head gets one level per entry in the decoder's output specs.
  head_kwargs = dict(
      levels=len(feature_decoder.output_specs),
      num_classes=model_config.num_classes,
      strides=head_cfg.strides,
      anchor_per_scale=head_cfg.anchor_per_scale,
      anchors=head_cfg.anchors,
      xy_scale=head_cfg.xy_scale,
      kernel_regularizer=l2_regularizer)
  yolo_head = instance_heads.YOLOv3Head(**head_kwargs)

  return segmentation_model.SegmentationModel(
      feature_extractor, feature_decoder, yolo_head)
def build_segmentation_model(
    input_specs: tf.keras.layers.InputSpec,
    model_config: segmentation_cfg.SemanticSegmentationModel,
    l2_regularizer: Optional[tf.keras.regularizers.Regularizer] = None,
    backbone: Optional[tf.keras.Model] = None,
    decoder: Optional[tf.keras.Model] = None
) -> tf.keras.Model:
  """Builds a segmentation model, optionally reusing a backbone and decoder.

  Args:
    input_specs: Input specification handed to the backbone factory when a
      backbone has to be built.
    model_config: Model configuration. Its `backbone` and `norm_activation`
      fields configure the backbone, the whole config is forwarded to the
      decoder factory, `head`, `norm_activation` and `num_classes`
      parameterize the segmentation head, and `mask_scoring_head` (when set)
      parameterizes the mask scoring head.
    l2_regularizer: Regularizer (or None) forwarded to the backbone and
      decoder factories and used as the heads' `kernel_regularizer`.
    backbone: Pre-built backbone to reuse. When None, one is built with
      `backbones.factory.build_backbone`.
    decoder: Pre-built decoder to reuse. When None, one is built with
      `decoders.factory.build_decoder` from the backbone's output specs.

  Returns:
    A `segmentation_model.SegmentationModel` wrapping the backbone, decoder,
    segmentation head and, if configured, the mask scoring head.
  """
  norm_activation_config = model_config.norm_activation

  # Compare against None explicitly so that a caller-supplied component is
  # never discarded because of how its truth value happens to evaluate.
  if backbone is None:
    backbone = backbones.factory.build_backbone(
        input_specs=input_specs,
        backbone_config=model_config.backbone,
        norm_activation_config=norm_activation_config,
        l2_regularizer=l2_regularizer)

  if decoder is None:
    decoder = decoders.factory.build_decoder(
        input_specs=backbone.output_specs,
        model_config=model_config,
        l2_regularizer=l2_regularizer)

  head_config = model_config.head
  head = segmentation_heads.SegmentationHead(
      num_classes=model_config.num_classes,
      level=head_config.level,
      num_convs=head_config.num_convs,
      prediction_kernel_size=head_config.prediction_kernel_size,
      num_filters=head_config.num_filters,
      use_depthwise_convolution=head_config.use_depthwise_convolution,
      upsample_factor=head_config.upsample_factor,
      feature_fusion=head_config.feature_fusion,
      low_level=head_config.low_level,
      low_level_num_filters=head_config.low_level_num_filters,
      activation=norm_activation_config.activation,
      use_sync_bn=norm_activation_config.use_sync_bn,
      norm_momentum=norm_activation_config.norm_momentum,
      norm_epsilon=norm_activation_config.norm_epsilon,
      kernel_regularizer=l2_regularizer)

  # The mask scoring head is only built when the config provides one.
  mask_scoring_head = None
  if model_config.mask_scoring_head:
    mask_scoring_head = segmentation_heads.MaskScoring(
        num_classes=model_config.num_classes,
        **model_config.mask_scoring_head.as_dict(),
        activation=norm_activation_config.activation,
        use_sync_bn=norm_activation_config.use_sync_bn,
        norm_momentum=norm_activation_config.norm_momentum,
        norm_epsilon=norm_activation_config.norm_epsilon,
        kernel_regularizer=l2_regularizer)

  return segmentation_model.SegmentationModel(
      backbone, decoder, head, mask_scoring_head=mask_scoring_head)
def test_serialize_deserialize(self):
  """Checks that a UNet3D segmentation model survives a config round trip."""
  num_classes = 3
  unet_backbone = backbones.UNet3D(model_id=4)
  unet_decoder = decoders.UNet3DDecoder(
      model_id=4, input_specs=unet_backbone.output_specs)
  seg_head = segmentation_heads_3d.SegmentationHead3D(
      num_classes, level=1, num_convs=0)
  original = segmentation_model.SegmentationModel(
      backbone=unet_backbone, decoder=unet_decoder, head=seg_head)

  # Rebuild a second model purely from the first model's config.
  restored = segmentation_model.SegmentationModel.from_config(
      original.get_config())

  # Exercise JSON export of the restored model; the output itself is unused.
  restored.to_json()

  # A faithful round trip leaves both models with the same config.
  self.assertAllEqual(original.get_config(), restored.get_config())
def test_segmentation_network_unet3d_creation(self, input_size, depth):
  """Builds a UNet3D segmentation model and checks its logits shape."""
  num_classes = 2
  batch_size = 2
  # The first entry of `input_size` is used for two axes, the second for one.
  spatial_dims = [input_size[0], input_size[0], input_size[1]]
  inputs = np.random.rand(batch_size, *spatial_dims, 3)
  tf.keras.backend.set_image_data_format('channels_last')

  unet_backbone = backbones.UNet3D(model_id=depth)
  unet_decoder = decoders.UNet3DDecoder(
      model_id=depth, input_specs=unet_backbone.output_specs)
  seg_head = segmentation_heads_3d.SegmentationHead3D(
      num_classes, level=1, num_convs=0)
  network = segmentation_model.SegmentationModel(
      backbone=unet_backbone, decoder=unet_decoder, head=seg_head)

  outputs = network(inputs)

  # Expect the input's leading dimensions with `num_classes` as the last one.
  expected_shape = [batch_size, *spatial_dims, num_classes]
  self.assertAllEqual(expected_shape, outputs['logits'].numpy().shape)
def test_serialize_deserialize(self):
  """Checks that a ResNet/FPN segmentation model survives a config round trip."""
  num_classes = 3
  resnet_backbone = backbones.ResNet(model_id=50)
  fpn_decoder = fpn.FPN(
      input_specs=resnet_backbone.output_specs, min_level=3, max_level=7)
  seg_head = segmentation_heads.SegmentationHead(num_classes, level=3)
  original = segmentation_model.SegmentationModel(
      backbone=resnet_backbone, decoder=fpn_decoder, head=seg_head)

  # Rebuild a second model purely from the first model's config.
  restored = segmentation_model.SegmentationModel.from_config(
      original.get_config())

  # Exercise JSON export of the restored model; the output itself is unused.
  restored.to_json()

  # A faithful round trip leaves both models with the same config.
  self.assertAllEqual(original.get_config(), restored.get_config())
def test_segmentation_network_creation(self, input_size, level):
  """Builds a ResNet/FPN segmentation model and checks its output shape."""
  num_classes = 10
  batch_size = 2
  inputs = np.random.rand(batch_size, input_size, input_size, 3)
  tf.keras.backend.set_image_data_format('channels_last')

  resnet_backbone = backbones.ResNet(model_id=50)
  fpn_decoder = fpn.FPN(
      input_specs=resnet_backbone.output_specs, min_level=2, max_level=7)
  seg_head = segmentation_heads.SegmentationHead(num_classes, level=level)
  network = segmentation_model.SegmentationModel(
      backbone=resnet_backbone, decoder=fpn_decoder, head=seg_head)

  logits = network(inputs)

  # Each spatial side is expected to be the input side divided by 2**level.
  expected_side = input_size // (2**level)
  expected_shape = [batch_size, expected_side, expected_side, num_classes]
  self.assertAllEqual(expected_shape, logits.numpy().shape)
def build_segmentation_model(
    input_specs: tf.keras.layers.InputSpec,
    model_config: segmentation_cfg.SemanticSegmentationModel,
    l2_regularizer: tf.keras.regularizers.Regularizer = None
) -> tf.keras.Model:  # pytype: disable=annotation-type-mismatch  # typed-keras
  """Assembles a segmentation model from a backbone, a decoder and a head.

  Args:
    input_specs: Input specification handed to the backbone factory.
    model_config: Model configuration. Its `backbone` and `norm_activation`
      fields configure the backbone, the whole config is forwarded to the
      decoder factory, and `head`, `norm_activation` and `num_classes`
      parameterize the segmentation head.
    l2_regularizer: Regularizer (or None) forwarded to the backbone and
      decoder factories and used as the head's `kernel_regularizer`.

  Returns:
    A `segmentation_model.SegmentationModel` wrapping the three components.
  """
  norm_cfg = model_config.norm_activation

  # The decoder is configured from the backbone's output specs, so the
  # backbone has to exist first.
  feature_extractor = backbones.factory.build_backbone(
      input_specs=input_specs,
      backbone_config=model_config.backbone,
      norm_activation_config=norm_cfg,
      l2_regularizer=l2_regularizer)
  feature_decoder = decoders.factory.build_decoder(
      input_specs=feature_extractor.output_specs,
      model_config=model_config,
      l2_regularizer=l2_regularizer)

  head_cfg = model_config.head
  head_kwargs = dict(
      num_classes=model_config.num_classes,
      level=head_cfg.level,
      num_convs=head_cfg.num_convs,
      prediction_kernel_size=head_cfg.prediction_kernel_size,
      num_filters=head_cfg.num_filters,
      use_depthwise_convolution=head_cfg.use_depthwise_convolution,
      upsample_factor=head_cfg.upsample_factor,
      feature_fusion=head_cfg.feature_fusion,
      low_level=head_cfg.low_level,
      low_level_num_filters=head_cfg.low_level_num_filters,
      activation=norm_cfg.activation,
      use_sync_bn=norm_cfg.use_sync_bn,
      norm_momentum=norm_cfg.norm_momentum,
      norm_epsilon=norm_cfg.norm_epsilon,
      kernel_regularizer=l2_regularizer)
  seg_head = segmentation_heads.SegmentationHead(**head_kwargs)

  return segmentation_model.SegmentationModel(
      feature_extractor, feature_decoder, seg_head)