Beispiel #1
0
    def test_unet_3d_decoder_creation(self, model_id):
        """Check that a hand-built UNet 3D decoder matches the factory's.

        One `decoders.UNet3DDecoder` is instantiated directly and a second
        decoder is obtained from `factory.build_decoder`; the test passes when
        both report equal `get_config()` results.

        Args:
            model_id: Number of input levels to generate; also forwarded as
                the `model_id` of both decoders.
        """
        # One 5-element shape per level, keyed '1'..str(model_id). The three
        # middle dimensions start at 128 and halve with every level.
        input_specs = {
            str(level + 1): tf.TensorShape(
                [1] + [128 // (2**level)] * 3 + [1])
            for level in range(model_id)
        }

        # Reference decoder, constructed explicitly.
        direct_decoder = decoders.UNet3DDecoder(
            model_id=model_id,
            input_specs=input_specs,
            use_sync_bn=True,
            use_batch_normalization=True,
            use_deconvolution=True,
        )

        # Experiment config that selects the same decoder type and model_id.
        config = semantic_segmentation_3d_exp.SemanticSegmentationModel3D()
        config.num_classes = 2
        config.num_channels = 1
        config.input_size = [None, None, None]
        config.decoder = decoders_cfg.Decoder(
            type='unet_3d_decoder',
            unet_3d_decoder=decoders_cfg.UNet3DDecoder(model_id=model_id),
        )

        factory_decoder = factory.build_decoder(
            input_specs=input_specs, model_config=config)

        direct_config = direct_decoder.get_config()
        factory_config = factory_decoder.get_config()
        # Both configs are echoed to stdout before being compared.
        print(direct_config)
        print(factory_config)

        self.assertEqual(direct_config, factory_config)
Beispiel #2
0
def build_segmentation_model_3d(
    input_specs: tf.keras.layers.InputSpec,
    model_config: hyperparams.Config,
    l2_regularizer: tf.keras.regularizers.Regularizer = None) -> tf.keras.Model:  # pytype: disable=annotation-type-mismatch  # typed-keras
  """Assembles a 3D segmentation model from a model config.

  The backbone and decoder come from their respective factories, the head is
  a `SegmentationHead3D`, and the three parts are wrapped in a
  `SegmentationModel`.

  Args:
    input_specs: Input specification handed to the backbone factory.
    model_config: Config providing the `backbone`, `norm_activation`, `head`
      and `num_classes` settings read here; it is also forwarded whole to the
      decoder factory.
    l2_regularizer: Optional regularizer passed to the backbone, the decoder
      and (as `kernel_regularizer`) the head.

  Returns:
    The `SegmentationModel` built from the backbone, decoder and head.
  """
  norm_cfg = model_config.norm_activation

  backbone = backbone_factory.build_backbone(
      input_specs=input_specs,
      backbone_config=model_config.backbone,
      norm_activation_config=norm_cfg,
      l2_regularizer=l2_regularizer,
  )

  # The decoder consumes whatever the backbone reports as its outputs.
  decoder = decoder_factory.build_decoder(
      input_specs=backbone.output_specs,
      model_config=model_config,
      l2_regularizer=l2_regularizer,
  )

  head_cfg = model_config.head
  # Head settings come partly from the head config and partly from the shared
  # normalization/activation config.
  head_kwargs = {
      'num_classes': model_config.num_classes,
      'level': head_cfg.level,
      'num_convs': head_cfg.num_convs,
      'num_filters': head_cfg.num_filters,
      'upsample_factor': head_cfg.upsample_factor,
      'activation': norm_cfg.activation,
      'use_sync_bn': norm_cfg.use_sync_bn,
      'norm_momentum': norm_cfg.norm_momentum,
      'norm_epsilon': norm_cfg.norm_epsilon,
      'use_batch_normalization': head_cfg.use_batch_normalization,
      'kernel_regularizer': l2_regularizer,
      'output_logits': head_cfg.output_logits,
  }
  head = segmentation_heads_3d.SegmentationHead3D(**head_kwargs)

  return segmentation_model.SegmentationModel(backbone, decoder, head)
Beispiel #3
0
    def test_identity_creation(self):
        """Check that the decoder factory yields None for an identity decoder."""
        config = semantic_segmentation_3d_exp.SemanticSegmentationModel3D()
        config.num_classes = 2
        config.num_channels = 3
        config.input_size = [None, None, None]
        config.decoder = decoders_cfg.Decoder(
            type='identity',
            identity=decoders_cfg.Identity(),
        )

        # No input specs are supplied for the identity case.
        built_decoder = factory.build_decoder(
            input_specs=None, model_config=config)

        self.assertIsNone(built_decoder)