示例#1
0
def build_segmentation_model_3d(
    input_specs: tf.keras.layers.InputSpec,
    model_config: hyperparams.Config,
    l2_regularizer: tf.keras.regularizers.Regularizer = None) -> tf.keras.Model:
  """Assembles a 3D segmentation model (backbone + decoder + head) from config.

  The backbone is built first. The decoder is then wired to the backbone's
  `output_specs`, and a `SegmentationHead3D` is placed on top. The same
  `l2_regularizer` is handed to all three components.

  Args:
    input_specs: Input spec forwarded to the backbone factory.
    model_config: Model config. Its `backbone`, `head`, `norm_activation` and
      `num_classes` fields are read here, and the whole config is passed to
      the decoder factory.
    l2_regularizer: Regularizer passed to the backbone, to the decoder, and to
      the head as `kernel_regularizer`. Defaults to None.

  Returns:
    A `segmentation_model.SegmentationModel` combining backbone, decoder and
    head.
  """
  norm_cfg = model_config.norm_activation

  backbone = backbone_factory.build_backbone(
      input_specs=input_specs,
      backbone_config=model_config.backbone,
      norm_activation_config=norm_cfg,
      l2_regularizer=l2_regularizer)

  # The decoder consumes the backbone's output specs, so it must come second.
  decoder = decoder_factory.build_decoder(
      input_specs=backbone.output_specs,
      model_config=model_config,
      l2_regularizer=l2_regularizer)

  head_cfg = model_config.head
  # Activation / batch-norm settings come from the shared norm_activation
  # config; the remaining options are head-specific.
  head_kwargs = dict(
      num_classes=model_config.num_classes,
      level=head_cfg.level,
      num_convs=head_cfg.num_convs,
      num_filters=head_cfg.num_filters,
      upsample_factor=head_cfg.upsample_factor,
      activation=norm_cfg.activation,
      use_sync_bn=norm_cfg.use_sync_bn,
      norm_momentum=norm_cfg.norm_momentum,
      norm_epsilon=norm_cfg.norm_epsilon,
      use_batch_normalization=head_cfg.use_batch_normalization,
      kernel_regularizer=l2_regularizer,
      output_logits=head_cfg.output_logits,
  )
  head = segmentation_heads_3d.SegmentationHead3D(**head_kwargs)

  return segmentation_model.SegmentationModel(backbone, decoder, head)
    def test_forward(self, level, num_convs):
        """Checks the head's logits shape against the requested decoder level.

        Builds a 10-class head and feeds it random backbone and decoder
        feature dicts keyed '1' and '2'. When `str(level)` is one of those
        keys, it asserts the logits have shape [2, D, H, W, 10], where D, H, W
        are dims 1..3 of the decoder feature at that key. Otherwise nothing is
        asserted.
        """
        num_classes = 10
        head = segmentation_heads_3d.SegmentationHead3D(
            num_classes=num_classes, level=level, num_convs=num_convs)

        def make_features():
            # Two entries with shapes (2, 128, 128, 128, 16) and
            # (2, 64, 64, 64, 16).
            return {
                '1': np.random.rand(2, 128, 128, 128, 16),
                '2': np.random.rand(2, 64, 64, 64, 16),
            }

        backbone_features = make_features()
        decoder_features = make_features()
        logits = head(backbone_features, decoder_features)

        level_key = str(level)
        if level_key not in decoder_features:
            # No decoder feature for this level: no shape to compare against.
            return
        spatial_dims = list(decoder_features[level_key].shape[1:4])
        self.assertAllEqual(logits.numpy().shape,
                            [2] + spatial_dims + [num_classes])
    def test_serialize_deserialize(self):
        """Round-trips a UNet3D segmentation model through its config.

        The model rebuilt via `from_config` must convert to JSON and must
        report the same config as the original model.
        """
        backbone = backbones.UNet3D(model_id=4)
        model = segmentation_model.SegmentationModel(
            backbone=backbone,
            decoder=decoders.UNet3DDecoder(
                model_id=4, input_specs=backbone.output_specs),
            head=segmentation_heads_3d.SegmentationHead3D(
                3, level=1, num_convs=0))

        restored = segmentation_model.SegmentationModel.from_config(
            model.get_config())

        # Only checks that JSON conversion succeeds; the string is discarded.
        restored.to_json()

        # A faithful round trip reproduces the original config exactly.
        self.assertAllEqual(model.get_config(), restored.get_config())
    def test_segmentation_network_unet3d_creation(self, input_size, depth):
        """Builds a UNet3D segmentation network and checks its output shape.

        Args:
          input_size: Indexable. Element 0 is used for the first two spatial
            dims of the input and element 1 for the third.
          depth: `model_id` shared by the UNet3D backbone and decoder.
        """
        num_classes = 2
        dim_a, dim_b = input_size[0], input_size[1]
        spatial = [dim_a, dim_a, dim_b]
        # Batch of 2, trailing dim 3, laid out channels-last.
        inputs = np.random.rand(2, *spatial, 3)
        tf.keras.backend.set_image_data_format('channels_last')

        backbone = backbones.UNet3D(model_id=depth)
        decoder = decoders.UNet3DDecoder(
            model_id=depth, input_specs=backbone.output_specs)
        head = segmentation_heads_3d.SegmentationHead3D(
            num_classes, level=1, num_convs=0)
        model = segmentation_model.SegmentationModel(
            backbone=backbone, decoder=decoder, head=head)

        logits = model(inputs)
        expected_shape = [2] + spatial + [num_classes]
        self.assertAllEqual(expected_shape, logits.numpy().shape)
 def test_serialize_deserialize(self):
     """Checks that a SegmentationHead3D config survives a from_config round trip."""
     original = segmentation_heads_3d.SegmentationHead3D(
         num_classes=10, level=3)
     rebuilt = segmentation_heads_3d.SegmentationHead3D.from_config(
         original.get_config())
     # The rebuilt head must describe itself exactly like the original.
     self.assertAllEqual(original.get_config(), rebuilt.get_config())