def test_serialize_deserialize(self):
    """Round-trips a MaskScoring head through get_config / from_config.

    A head rebuilt from the original head's config must report the same
    config as the original.
    """
    original_head = segmentation_heads.MaskScoring(
        num_classes=2,
        fc_input_size=[4, 4],
        fc_dims=128,
    )
    serialized = original_head.get_config()
    rebuilt_head = segmentation_heads.MaskScoring.from_config(serialized)
    expected, actual = original_head.get_config(), rebuilt_head.get_config()
    self.assertAllEqual(expected, actual)
Example no. 2
0
def build_segmentation_model(
    input_specs: tf.keras.layers.InputSpec,
    model_config: segmentation_cfg.SemanticSegmentationModel,
    l2_regularizer: Optional[tf.keras.regularizers.Regularizer] = None,
    backbone: Optional[tf.keras.Model] = None,
    decoder: Optional[tf.keras.Model] = None
) -> tf.keras.Model:
    """Builds a semantic segmentation model.

    Args:
      input_specs: Input spec forwarded to the backbone factory. Only used
        when `backbone` is not supplied.
      model_config: Segmentation model config. Its `norm_activation`,
        `backbone`, `head`, `num_classes` and `mask_scoring_head` fields are
        read here; the whole config is also passed to the decoder factory.
      l2_regularizer: Optional regularizer passed to the backbone and decoder
        factories and used as `kernel_regularizer` for both heads.
      backbone: Optional pre-built backbone model. When None, one is built
        from `model_config.backbone`.
      decoder: Optional pre-built decoder model. When None, one is built
        from `model_config`, using the backbone's `output_specs` as inputs.

    Returns:
      The `segmentation_model.SegmentationModel` assembled from the backbone,
      decoder, segmentation head and (when configured) mask scoring head.
    """
    norm_activation_config = model_config.norm_activation

    # Explicit None checks: only build a component when the caller did not
    # supply one.
    if backbone is None:
        backbone = backbones.factory.build_backbone(
            input_specs=input_specs,
            backbone_config=model_config.backbone,
            norm_activation_config=norm_activation_config,
            l2_regularizer=l2_regularizer)

    if decoder is None:
        # The decoder consumes the backbone outputs. This applies whether the
        # backbone was supplied by the caller or built above.
        decoder = decoders.factory.build_decoder(
            input_specs=backbone.output_specs,
            model_config=model_config,
            l2_regularizer=l2_regularizer)

    head_config = model_config.head

    head = segmentation_heads.SegmentationHead(
        num_classes=model_config.num_classes,
        level=head_config.level,
        num_convs=head_config.num_convs,
        prediction_kernel_size=head_config.prediction_kernel_size,
        num_filters=head_config.num_filters,
        use_depthwise_convolution=head_config.use_depthwise_convolution,
        upsample_factor=head_config.upsample_factor,
        feature_fusion=head_config.feature_fusion,
        low_level=head_config.low_level,
        low_level_num_filters=head_config.low_level_num_filters,
        activation=norm_activation_config.activation,
        use_sync_bn=norm_activation_config.use_sync_bn,
        norm_momentum=norm_activation_config.norm_momentum,
        norm_epsilon=norm_activation_config.norm_epsilon,
        kernel_regularizer=l2_regularizer)

    mask_scoring_head = None
    if model_config.mask_scoring_head:
        # NOTE(review): if the mask_scoring_head config dict ever contains any
        # of the explicitly passed keys below (e.g. `activation`), Python
        # raises TypeError for the duplicate keyword argument.
        mask_scoring_head = segmentation_heads.MaskScoring(
            num_classes=model_config.num_classes,
            **model_config.mask_scoring_head.as_dict(),
            activation=norm_activation_config.activation,
            use_sync_bn=norm_activation_config.use_sync_bn,
            norm_momentum=norm_activation_config.norm_momentum,
            norm_epsilon=norm_activation_config.norm_epsilon,
            kernel_regularizer=l2_regularizer)

    model = segmentation_model.SegmentationModel(
        backbone, decoder, head, mask_scoring_head=mask_scoring_head)
    return model
    def test_forward(self, num_convs, num_fcs, num_filters, fc_input_size):
        """Checks the MaskScoring output shape for one parameter combination.

        Random features of shape (2, 64, 64, 16) are scored by a 2-class head,
        and the result must have shape [2, 2] (presumably
        (batch, num_classes); confirm against MaskScoring).
        """
        features = np.random.rand(2, 64, 64, 16)

        # Forward num_fcs so that each parameterized case actually varies the
        # head's num_fcs setting, rather than silently ignoring it.
        head = segmentation_heads.MaskScoring(num_classes=2,
                                              num_convs=num_convs,
                                              num_filters=num_filters,
                                              num_fcs=num_fcs,
                                              fc_dims=128,
                                              fc_input_size=fc_input_size)

        scores = head(features)
        self.assertAllEqual(scores.numpy().shape, [2, 2])