def build_segmentation_model_3d(
    input_specs: tf.keras.layers.InputSpec,
    model_config: hyperparams.Config,
    l2_regularizer: tf.keras.regularizers.Regularizer = None) -> tf.keras.Model:
    """Assembles a 3D segmentation model (backbone + decoder + head).

    The three components are created in this order:
      1. a backbone from `model_config.backbone`;
      2. a decoder whose inputs are the backbone's `output_specs`;
      3. a `SegmentationHead3D` configured from `model_config.head`.

    Args:
      input_specs: Input spec handed to the backbone factory.
      model_config: Config supplying `norm_activation`, `backbone`, `head`
        and `num_classes`; the decoder factory receives the whole config.
      l2_regularizer: Regularizer shared by the backbone, the decoder and
        the head (as its `kernel_regularizer`). Defaults to None.

    Returns:
      A `segmentation_model.SegmentationModel` wrapping the three parts.
    """
    norm_cfg = model_config.norm_activation

    backbone = backbone_factory.build_backbone(
        input_specs=input_specs,
        backbone_config=model_config.backbone,
        norm_activation_config=norm_cfg,
        l2_regularizer=l2_regularizer)

    # The decoder is wired to whatever feature specs the backbone exposes.
    decoder = decoder_factory.build_decoder(
        input_specs=backbone.output_specs,
        model_config=model_config,
        l2_regularizer=l2_regularizer)

    head_cfg = model_config.head
    head = segmentation_heads_3d.SegmentationHead3D(
        num_classes=model_config.num_classes,
        level=head_cfg.level,
        num_convs=head_cfg.num_convs,
        num_filters=head_cfg.num_filters,
        upsample_factor=head_cfg.upsample_factor,
        # Activation / normalization settings are the same ones given to
        # the backbone via `norm_activation_config`.
        activation=norm_cfg.activation,
        use_sync_bn=norm_cfg.use_sync_bn,
        norm_momentum=norm_cfg.norm_momentum,
        norm_epsilon=norm_cfg.norm_epsilon,
        use_batch_normalization=head_cfg.use_batch_normalization,
        kernel_regularizer=l2_regularizer,
        output_logits=head_cfg.output_logits)

    return segmentation_model.SegmentationModel(
        backbone=backbone, decoder=decoder, head=head)
def test_forward(self, level, num_convs):
    """Checks that the head's logits keep the selected decoder feature's shape.

    A `SegmentationHead3D` with 10 classes is built for `level` and
    `num_convs` and run on dicts of backbone and decoder features keyed by
    level string. When `level` names one of the decoder features, the logits
    must have shape [batch, d0, d1, d2, num_classes], where the first four
    entries come from that decoder feature.

    Spatial sizes are deliberately small. Only the output shape is checked,
    and the previous 128^3 x 16 float64 inputs allocated roughly 1.2 GB.
    """
    num_classes = 10
    head = segmentation_heads_3d.SegmentationHead3D(
        num_classes=num_classes, level=level, num_convs=num_convs)

    # Level '2' is half the spatial size of level '1' in every dimension,
    # the same relationship as the original (128 vs 64) fixture.
    backbone_features = {
        '1': np.random.rand(2, 32, 32, 32, 16),
        '2': np.random.rand(2, 16, 16, 16, 16),
    }
    decoder_features = {
        '1': np.random.rand(2, 32, 32, 32, 16),
        '2': np.random.rand(2, 16, 16, 16, 16),
    }
    logits = head(backbone_features, decoder_features)

    # The shape can only be checked when the head's level has a matching
    # decoder feature; otherwise the forward pass itself is what is tested.
    key = str(level)
    if key in decoder_features:
        batch, dim0, dim1, dim2 = decoder_features[key].shape[:4]
        self.assertAllEqual(logits.numpy().shape,
                            [batch, dim0, dim1, dim2, num_classes])
def test_serialize_deserialize(self):
    """Round-trips a UNet3D segmentation model through its Keras config.

    The model is rebuilt with `from_config`. The test then checks that the
    rebuilt model can be converted to JSON and that its config matches the
    original's.
    """
    n_classes = 3
    unet_backbone = backbones.UNet3D(model_id=4)
    unet_decoder = decoders.UNet3DDecoder(
        model_id=4, input_specs=unet_backbone.output_specs)
    seg_head = segmentation_heads_3d.SegmentationHead3D(
        n_classes, level=1, num_convs=0)
    original = segmentation_model.SegmentationModel(
        backbone=unet_backbone, decoder=unet_decoder, head=seg_head)

    restored = segmentation_model.SegmentationModel.from_config(
        original.get_config())

    # Converting to JSON validates that the config is JSON-serializable;
    # the resulting string itself is not needed.
    restored.to_json()

    # A successful round trip must yield an identical config.
    self.assertAllEqual(original.get_config(), restored.get_config())
def test_segmentation_network_unet3d_creation(self, input_size, depth):
    """Builds a UNet3D segmentation network and checks the logits' shape.

    `input_size[0]` is used for the first two spatial dims and
    `input_size[1]` for the third. `depth` is the `model_id` given to both
    the UNet3D backbone and decoder. The logits must keep the input's batch
    and spatial dims, with one channel per class.
    """
    num_classes = 2
    size_0, size_1 = input_size[0], input_size[1]
    inputs = np.random.rand(2, size_0, size_0, size_1, 3)
    # Sets the global Keras image data format; this test does not restore it.
    tf.keras.backend.set_image_data_format('channels_last')

    backbone = backbones.UNet3D(model_id=depth)
    decoder = decoders.UNet3DDecoder(
        model_id=depth, input_specs=backbone.output_specs)
    head = segmentation_heads_3d.SegmentationHead3D(
        num_classes, level=1, num_convs=0)
    model = segmentation_model.SegmentationModel(
        backbone=backbone, decoder=decoder, head=head)

    expected_shape = [2, size_0, size_0, size_1, num_classes]
    logits = model(inputs)
    self.assertAllEqual(expected_shape, logits.numpy().shape)
def test_serialize_deserialize(self):
    """Checks that SegmentationHead3D survives a get_config/from_config round trip."""
    head_cls = segmentation_heads_3d.SegmentationHead3D
    original = head_cls(num_classes=10, level=3)
    restored = head_cls.from_config(original.get_config())
    # The rebuilt head must report exactly the same configuration.
    self.assertAllEqual(original.get_config(), restored.get_config())