示例#1
0
def build_segmentation_model(
        input_specs: tf.keras.layers.InputSpec,
        model_config: segmentation_cfg.SemanticSegmentationModel,
        l2_regularizer: tf.keras.regularizers.Regularizer = None):
    """Builds a semantic segmentation model from its config.

    The backbone and decoder come from their respective factories; the
    segmentation head is configured from `model_config.head` and
    `model_config.norm_activation`.

    Args:
      input_specs: Input spec handed to the backbone factory.
      model_config: Segmentation model config. Both factories receive it
        whole; the head reads `num_classes`, `head` and `norm_activation`.
      l2_regularizer: Regularizer forwarded to both factories and used as
        the head's `kernel_regularizer`. Defaults to None.

    Returns:
      A `segmentation_model.SegmentationModel` wrapping backbone, decoder
      and head.
    """
    # Both factories take the same config/regularizer pair.
    factory_kwargs = {
        'model_config': model_config,
        'l2_regularizer': l2_regularizer,
    }
    backbone = backbones.factory.build_backbone(
        input_specs=input_specs, **factory_kwargs)
    decoder = decoder_factory.build_decoder(
        input_specs=backbone.output_specs, **factory_kwargs)

    head_cfg = model_config.head
    norm_cfg = model_config.norm_activation
    head = segmentation_heads.SegmentationHead(
        num_classes=model_config.num_classes,
        level=head_cfg.level,
        num_convs=head_cfg.num_convs,
        num_filters=head_cfg.num_filters,
        upsample_factor=head_cfg.upsample_factor,
        feature_fusion=head_cfg.feature_fusion,
        low_level=head_cfg.low_level,
        low_level_num_filters=head_cfg.low_level_num_filters,
        activation=norm_cfg.activation,
        use_sync_bn=norm_cfg.use_sync_bn,
        norm_momentum=norm_cfg.norm_momentum,
        norm_epsilon=norm_cfg.norm_epsilon,
        kernel_regularizer=l2_regularizer)

    return segmentation_model.SegmentationModel(backbone, decoder, head)
示例#2
0
def build_segmentation_model_3d(
    input_specs: tf.keras.layers.InputSpec,
    model_config: hyperparams.Config,
    l2_regularizer: tf.keras.regularizers.Regularizer = None) -> tf.keras.Model:  # pytype: disable=annotation-type-mismatch  # typed-keras
  """Builds a 3D segmentation model from its config.

  Args:
    input_specs: Input spec handed to the backbone factory.
    model_config: Model config supplying `backbone`, `norm_activation`,
      `num_classes` and `head` settings; also passed whole to the decoder
      factory.
    l2_regularizer: Regularizer forwarded to the backbone and decoder
      factories and used as the head's `kernel_regularizer`.

  Returns:
    A `segmentation_model.SegmentationModel` combining the backbone, the
    decoder and a `SegmentationHead3D`.
  """
  norm_cfg = model_config.norm_activation
  backbone = backbone_factory.build_backbone(
      input_specs=input_specs,
      backbone_config=model_config.backbone,
      norm_activation_config=norm_cfg,
      l2_regularizer=l2_regularizer)
  decoder = decoder_factory.build_decoder(
      input_specs=backbone.output_specs,
      model_config=model_config,
      l2_regularizer=l2_regularizer)

  head_cfg = model_config.head
  # The head reuses the same norm/activation config given to the backbone.
  norm_kwargs = dict(
      activation=norm_cfg.activation,
      use_sync_bn=norm_cfg.use_sync_bn,
      norm_momentum=norm_cfg.norm_momentum,
      norm_epsilon=norm_cfg.norm_epsilon)
  head = segmentation_heads_3d.SegmentationHead3D(
      num_classes=model_config.num_classes,
      level=head_cfg.level,
      num_convs=head_cfg.num_convs,
      num_filters=head_cfg.num_filters,
      upsample_factor=head_cfg.upsample_factor,
      use_batch_normalization=head_cfg.use_batch_normalization,
      kernel_regularizer=l2_regularizer,
      output_logits=head_cfg.output_logits,
      **norm_kwargs)

  return segmentation_model.SegmentationModel(backbone, decoder, head)
示例#3
0
def build_yolo_model(
    input_specs: tf.keras.layers.InputSpec,
    model_config: yolo_cfg.YoloModel,
    l2_regularizer: tf.keras.regularizers.Regularizer = None
) -> tf.keras.Model:
    """Builds a YOLO model from its config.

    Args:
      input_specs: Input spec handed to the backbone factory.
      model_config: YOLO model config; supplies `backbone`,
        `norm_activation`, `num_classes` and the `head` settings (strides,
        anchors per scale, anchors, xy scale). Also passed whole to the
        decoder factory.
      l2_regularizer: Regularizer forwarded to the backbone and decoder
        factories and used as the head's `kernel_regularizer`.

    Returns:
      The backbone, decoder and `YOLOv3Head`, wrapped in
      `segmentation_model.SegmentationModel`.
      NOTE(review): a detection head is wrapped in the segmentation model
      class here; confirm that wrapper's forward pass matches how
      `YOLOv3Head` expects to be called.
    """
    norm_cfg = model_config.norm_activation
    backbone = backbones.factory.build_backbone(
        input_specs=input_specs,
        backbone_config=model_config.backbone,
        norm_activation_config=norm_cfg,
        l2_regularizer=l2_regularizer)
    decoder = decoder_factory.build_decoder(
        input_specs=backbone.output_specs,
        model_config=model_config,
        l2_regularizer=l2_regularizer)

    head_cfg = model_config.head
    # The head gets one level per decoder output spec.
    num_levels = len(decoder.output_specs)
    head = instance_heads.YOLOv3Head(
        levels=num_levels,
        num_classes=model_config.num_classes,
        strides=head_cfg.strides,
        anchor_per_scale=head_cfg.anchor_per_scale,
        anchors=head_cfg.anchors,
        xy_scale=head_cfg.xy_scale,
        kernel_regularizer=l2_regularizer)

    return segmentation_model.SegmentationModel(backbone, decoder, head)
示例#4
0
def build_segmentation_model(
    input_specs: tf.keras.layers.InputSpec,
    model_config: segmentation_cfg.SemanticSegmentationModel,
    l2_regularizer: Optional[tf.keras.regularizers.Regularizer] = None,
    backbone: Optional[tf.keras.Model] = None,
    decoder: Optional[tf.keras.Model] = None
) -> tf.keras.Model:
    """Builds a semantic segmentation model, optionally from prebuilt parts.

    Args:
      input_specs: Input spec handed to the backbone factory (only used when
        `backbone` is not supplied).
      model_config: Segmentation model config; supplies `backbone`,
        `norm_activation`, `num_classes`, `head` and the optional
        `mask_scoring_head` settings. Also passed whole to the decoder
        factory.
      l2_regularizer: Regularizer forwarded to the factories and used as the
        `kernel_regularizer` of the segmentation and mask-scoring heads.
      backbone: Optional prebuilt backbone model. When None, one is built
        from `model_config.backbone` via the backbone factory.
      decoder: Optional prebuilt decoder model. When None, one is built via
        the decoder factory on top of `backbone.output_specs`.

    Returns:
      A `segmentation_model.SegmentationModel` combining backbone, decoder,
      segmentation head and, when `model_config.mask_scoring_head` is set,
      a `MaskScoring` head.
    """
    norm_activation_config = model_config.norm_activation
    # Explicit None checks: "was it provided?" should not depend on the
    # truthiness of a Keras model object.
    if backbone is None:
        backbone = backbones.factory.build_backbone(
            input_specs=input_specs,
            backbone_config=model_config.backbone,
            norm_activation_config=norm_activation_config,
            l2_regularizer=l2_regularizer)

    if decoder is None:
        decoder = decoders.factory.build_decoder(
            input_specs=backbone.output_specs,
            model_config=model_config,
            l2_regularizer=l2_regularizer)

    head_config = model_config.head

    head = segmentation_heads.SegmentationHead(
        num_classes=model_config.num_classes,
        level=head_config.level,
        num_convs=head_config.num_convs,
        prediction_kernel_size=head_config.prediction_kernel_size,
        num_filters=head_config.num_filters,
        use_depthwise_convolution=head_config.use_depthwise_convolution,
        upsample_factor=head_config.upsample_factor,
        feature_fusion=head_config.feature_fusion,
        low_level=head_config.low_level,
        low_level_num_filters=head_config.low_level_num_filters,
        activation=norm_activation_config.activation,
        use_sync_bn=norm_activation_config.use_sync_bn,
        norm_momentum=norm_activation_config.norm_momentum,
        norm_epsilon=norm_activation_config.norm_epsilon,
        kernel_regularizer=l2_regularizer)

    mask_scoring_head = None
    if model_config.mask_scoring_head:
        # The mask-scoring config's fields are splatted in as kwargs; the
        # norm/activation settings are supplied separately.
        # NOTE(review): this raises TypeError if as_dict() also contains any
        # of the explicitly passed keys — confirm the config schema.
        mask_scoring_head = segmentation_heads.MaskScoring(
            num_classes=model_config.num_classes,
            **model_config.mask_scoring_head.as_dict(),
            activation=norm_activation_config.activation,
            use_sync_bn=norm_activation_config.use_sync_bn,
            norm_momentum=norm_activation_config.norm_momentum,
            norm_epsilon=norm_activation_config.norm_epsilon,
            kernel_regularizer=l2_regularizer)

    model = segmentation_model.SegmentationModel(
        backbone, decoder, head, mask_scoring_head=mask_scoring_head)
    return model
示例#5
0
  def test_serialize_deserialize(self):
    """Round-trips a UNet3D segmentation model through its Keras config."""
    backbone = backbones.UNet3D(model_id=4)
    model = segmentation_model.SegmentationModel(
        backbone=backbone,
        decoder=decoders.UNet3DDecoder(
            model_id=4, input_specs=backbone.output_specs),
        head=segmentation_heads_3d.SegmentationHead3D(
            3, level=1, num_convs=0))

    restored = segmentation_model.SegmentationModel.from_config(
        model.get_config())

    # The restored model must be convertible to JSON.
    restored.to_json()

    # A lossless round trip reproduces the original config exactly.
    self.assertAllEqual(model.get_config(), restored.get_config())
示例#6
0
  def test_segmentation_network_unet3d_creation(self, input_size, depth):
    """Checks the logits shape produced by a UNet3D segmentation network."""
    batch_size = 2
    num_classes = 2
    size_a = input_size[0]
    size_b = input_size[1]
    inputs = np.random.rand(batch_size, size_a, size_a, size_b, 3)
    tf.keras.backend.set_image_data_format('channels_last')

    backbone = backbones.UNet3D(model_id=depth)
    decoder = decoders.UNet3DDecoder(
        model_id=depth, input_specs=backbone.output_specs)
    head = segmentation_heads_3d.SegmentationHead3D(
        num_classes, level=1, num_convs=0)
    model = segmentation_model.SegmentationModel(
        backbone=backbone, decoder=decoder, head=head)

    logits = model(inputs)['logits']
    expected_shape = [batch_size, size_a, size_a, size_b, num_classes]
    self.assertAllEqual(expected_shape, logits.numpy().shape)
示例#7
0
    def test_serialize_deserialize(self):
        """Round-trips a ResNet-50 + FPN segmentation model through its config."""
        backbone = backbones.ResNet(model_id=50)
        model = segmentation_model.SegmentationModel(
            backbone=backbone,
            decoder=fpn.FPN(input_specs=backbone.output_specs,
                            min_level=3,
                            max_level=7),
            head=segmentation_heads.SegmentationHead(3, level=3))

        restored = segmentation_model.SegmentationModel.from_config(
            model.get_config())

        # The restored model must be convertible to JSON.
        restored.to_json()

        # A lossless round trip reproduces the original config exactly.
        self.assertAllEqual(model.get_config(), restored.get_config())
示例#8
0
    def test_segmentation_network_creation(self, input_size, level):
        """Checks the logits shape of a ResNet-50 + FPN segmentation network."""
        batch_size = 2
        num_classes = 10
        inputs = np.random.rand(batch_size, input_size, input_size, 3)
        tf.keras.backend.set_image_data_format('channels_last')

        backbone = backbones.ResNet(model_id=50)
        decoder = fpn.FPN(
            input_specs=backbone.output_specs, min_level=2, max_level=7)
        head = segmentation_heads.SegmentationHead(num_classes, level=level)
        model = segmentation_model.SegmentationModel(
            backbone=backbone, decoder=decoder, head=head)

        # The test expects each spatial dimension to shrink by 2**level.
        output_size = input_size // (2**level)
        logits = model(inputs)
        self.assertAllEqual(
            [batch_size, output_size, output_size, num_classes],
            logits.numpy().shape)
示例#9
0
def build_segmentation_model(
    input_specs: tf.keras.layers.InputSpec,
    model_config: segmentation_cfg.SemanticSegmentationModel,
    l2_regularizer: tf.keras.regularizers.Regularizer = None
) -> tf.keras.Model:  # pytype: disable=annotation-type-mismatch  # typed-keras
    """Builds a semantic segmentation model from its config.

    Args:
      input_specs: Input spec handed to the backbone factory.
      model_config: Segmentation model config; supplies `backbone`,
        `norm_activation`, `num_classes` and `head` settings. Also passed
        whole to the decoder factory.
      l2_regularizer: Regularizer forwarded to the backbone and decoder
        factories and used as the head's `kernel_regularizer`.

    Returns:
      A `segmentation_model.SegmentationModel` combining backbone, decoder
      and a `SegmentationHead`.
    """
    norm_cfg = model_config.norm_activation
    backbone = backbones.factory.build_backbone(
        input_specs=input_specs,
        backbone_config=model_config.backbone,
        norm_activation_config=norm_cfg,
        l2_regularizer=l2_regularizer)
    decoder = decoders.factory.build_decoder(
        input_specs=backbone.output_specs,
        model_config=model_config,
        l2_regularizer=l2_regularizer)

    head_cfg = model_config.head
    # The head reuses the same norm/activation config given to the backbone.
    norm_kwargs = dict(
        activation=norm_cfg.activation,
        use_sync_bn=norm_cfg.use_sync_bn,
        norm_momentum=norm_cfg.norm_momentum,
        norm_epsilon=norm_cfg.norm_epsilon)
    head = segmentation_heads.SegmentationHead(
        num_classes=model_config.num_classes,
        level=head_cfg.level,
        num_convs=head_cfg.num_convs,
        prediction_kernel_size=head_cfg.prediction_kernel_size,
        num_filters=head_cfg.num_filters,
        use_depthwise_convolution=head_cfg.use_depthwise_convolution,
        upsample_factor=head_cfg.upsample_factor,
        feature_fusion=head_cfg.feature_fusion,
        low_level=head_cfg.low_level,
        low_level_num_filters=head_cfg.low_level_num_filters,
        kernel_regularizer=l2_regularizer,
        **norm_kwargs)

    return segmentation_model.SegmentationModel(backbone, decoder, head)