Example #1
0
    def test_unet_3d_decoder_creation(self, model_id):
        """Test creation of UNet 3D decoder.

        Builds ``decoders.UNet3DDecoder`` directly and via
        ``factory.build_decoder`` from an equivalent model config, then asserts
        that both networks report the same ``get_config()``.

        Args:
          model_id: Passed as ``model_id`` to both decoders; one input spec is
            created for each of the ``model_id`` levels.
        """
        # One input spec per level, keyed '1'..str(model_id). Each successive
        # level halves the 128-sized spatial dimensions of the previous one.
        input_specs = {
            str(level + 1): tf.TensorShape([
                1, 128 // (2**level), 128 // (2**level), 128 // (2**level), 1
            ])
            for level in range(model_id)
        }

        network = decoders.UNet3DDecoder(model_id=model_id,
                                         input_specs=input_specs,
                                         use_sync_bn=True,
                                         use_batch_normalization=True,
                                         use_deconvolution=True)

        model_config = semantic_segmentation_3d_exp.SemanticSegmentationModel3D(
        )
        model_config.num_classes = 2
        model_config.num_channels = 1
        model_config.input_size = [None, None, None]
        # NOTE(review): only model_id is set here. The equality check below
        # therefore relies on the config's defaults matching the explicit
        # use_sync_bn / use_batch_normalization / use_deconvolution flags
        # used above. Confirm against decoders_cfg.UNet3DDecoder.
        model_config.decoder = decoders_cfg.Decoder(
            type='unet_3d_decoder',
            unet_3d_decoder=decoders_cfg.UNet3DDecoder(model_id=model_id))

        factory_network = factory.build_decoder(input_specs=input_specs,
                                                model_config=model_config)

        # Compare serialized configs; assertEqual reports a diff on failure,
        # so no debug printing is needed.
        self.assertEqual(network.get_config(), factory_network.get_config())
 def test_unet3d_builder(self, input_size, weight_decay, use_bn):
   """Checks that build_segmentation_model_3d yields a tf.keras.Model.

   Args:
     input_size: Sequence whose first three entries are used as the spatial
       dimensions of the input spec (the channel dimension is fixed at 3).
     weight_decay: L2 factor; when falsy, no regularizer is passed.
     use_bn: Value assigned to ``model_config.head.use_batch_normalization``.
   """
   spatial_dims = [input_size[0], input_size[1], input_size[2]]
   specs = tf.keras.layers.InputSpec(shape=[None] + spatial_dims + [3])

   config = exp_cfg.SemanticSegmentationModel3D(num_classes=3)
   config.head.use_batch_normalization = use_bn

   regularizer = None
   if weight_decay:
     regularizer = tf.keras.regularizers.l2(weight_decay)

   built = factory.build_segmentation_model_3d(
       input_specs=specs, model_config=config, l2_regularizer=regularizer)

   failure_msg = (
       'Output should be a tf.keras.Model instance but got %s' % type(built))
   self.assertIsInstance(built, tf.keras.Model, failure_msg)
Example #3
0
    def test_identity_creation(self):
        """Check that an 'identity' decoder config makes build_decoder return None.

        The model config uses 2 classes, 3 channels and unspecified spatial
        input size; no input specs are supplied to the factory.
        """
        identity_decoder = decoders_cfg.Decoder(
            type='identity', identity=decoders_cfg.Identity())

        config = semantic_segmentation_3d_exp.SemanticSegmentationModel3D()
        config.num_classes = 2
        config.num_channels = 3
        config.input_size = [None, None, None]
        config.decoder = identity_decoder

        built = factory.build_decoder(model_config=config, input_specs=None)
        self.assertIsNone(built)