def test_unet_3d_decoder_creation(self, model_id):
  """Test creation of UNet 3D decoder.

  Builds a `UNet3DDecoder` directly and via `factory.build_decoder` from a
  `SemanticSegmentationModel3D` config, then checks that both report the same
  config from `get_config()`.

  Args:
    model_id: Passed as `model_id` to both decoders; also the number of
      entries created in `input_specs`.
  """
  # Create test input for decoders based on input model_id: keys '1'..'N',
  # and each successive level halves dimensions 1-3 starting from 128.
  input_specs = {
      str(level + 1): tf.TensorShape([
          1, 128 // (2**level), 128 // (2**level), 128 // (2**level), 1
      ]) for level in range(model_id)
  }

  network = decoders.UNet3DDecoder(
      model_id=model_id,
      input_specs=input_specs,
      use_sync_bn=True,
      use_batch_normalization=True,
      use_deconvolution=True)

  model_config = semantic_segmentation_3d_exp.SemanticSegmentationModel3D()
  model_config.num_classes = 2
  model_config.num_channels = 1
  model_config.input_size = [None, None, None]
  model_config.decoder = decoders_cfg.Decoder(
      type='unet_3d_decoder',
      unet_3d_decoder=decoders_cfg.UNet3DDecoder(model_id=model_id))

  factory_network = factory.build_decoder(
      input_specs=input_specs, model_config=model_config)

  # No debug printing: assertEqual already shows both values when it fails.
  self.assertEqual(network.get_config(), factory_network.get_config())
def seg_unet3d_test() -> cfg.ExperimentConfig:
  """Returns a small 3D UNet segmentation experiment config for tests.

  The model, training data and validation data all use 2 classes, 2 channels
  and a 32x32x32 input size. Data is read from 'train.tfrecord' and
  'val.tfrecord'. The trainer runs 10 train and 10 validation steps with SGD
  at a constant learning rate of 1e-6.
  """
  train_batch_size = 2
  eval_batch_size = 2
  steps_per_epoch = 10
  num_classes = 2
  num_channels = 2
  unet_model_id = 2
  volume_shape = (32, 32, 32)

  # Each config receives its own fresh list, as the original literals did.
  model = SemanticSegmentationModel3D(
      num_classes=num_classes,
      input_size=list(volume_shape),
      num_channels=num_channels,
      backbone=backbones.Backbone(
          type='unet_3d', unet_3d=backbones.UNet3D(model_id=unet_model_id)),
      decoder=decoders.Decoder(
          type='unet_3d_decoder',
          unet_3d_decoder=decoders.UNet3DDecoder(model_id=unet_model_id)),
      head=SegmentationHead3D(num_convs=0, num_classes=num_classes),
      norm_activation=common.NormActivation(
          activation='relu', use_sync_bn=False))
  train_data = DataConfig(
      input_path='train.tfrecord',
      num_classes=num_classes,
      input_size=list(volume_shape),
      num_channels=num_channels,
      is_training=True,
      global_batch_size=train_batch_size)
  validation_data = DataConfig(
      input_path='val.tfrecord',
      num_classes=num_classes,
      input_size=list(volume_shape),
      num_channels=num_channels,
      is_training=False,
      global_batch_size=eval_batch_size)
  task = SemanticSegmentation3DTask(
      model=model,
      train_data=train_data,
      validation_data=validation_data,
      losses=Losses(loss_type='adaptive'))

  optimizer_config = optimization.OptimizationConfig({
      'optimizer': {
          'type': 'sgd',
      },
      'learning_rate': {
          'type': 'constant',
          'constant': {
              'learning_rate': 0.000001
          }
      }
  })
  trainer = cfg.TrainerConfig(
      steps_per_loop=steps_per_epoch,
      summary_interval=steps_per_epoch,
      checkpoint_interval=steps_per_epoch,
      train_steps=10,
      validation_steps=10,
      validation_interval=steps_per_epoch,
      optimizer_config=optimizer_config)

  restrictions = [
      'task.train_data.is_training != None',
      'task.validation_data.is_training != None'
  ]
  return cfg.ExperimentConfig(
      task=task, trainer=trainer, restrictions=restrictions)
class SemanticSegmentationModel3D(hyperparams.Config):
  """Semantic segmentation model config.

  By default the backbone is 'unet_3d' and the decoder is 'unet_3d_decoder'.
  The head and norm/activation settings default to their own configs'
  defaults.
  """
  num_classes: int = 0
  num_channels: int = 1
  input_size: List[int] = dataclasses.field(default_factory=list)
  min_level: int = 3
  max_level: int = 6
  # Nested configs use default_factory, like `input_size` above. A config
  # instance as a plain default would be shared by every model config, and
  # Python 3.11+ dataclasses reject such unhashable defaults with ValueError.
  head: SegmentationHead3D = dataclasses.field(
      default_factory=SegmentationHead3D)
  backbone: backbones.Backbone = dataclasses.field(
      default_factory=lambda: backbones.Backbone(
          type='unet_3d', unet_3d=backbones.UNet3D()))
  decoder: decoders.Decoder = dataclasses.field(
      default_factory=lambda: decoders.Decoder(
          type='unet_3d_decoder', unet_3d_decoder=decoders.UNet3DDecoder()))
  norm_activation: common.NormActivation = dataclasses.field(
      default_factory=common.NormActivation)
def test_identity_creation(self):
  """Checks that the factory returns None for an 'identity' decoder config."""
  identity_decoder_cfg = decoders_cfg.Decoder(
      type='identity', identity=decoders_cfg.Identity())

  config = semantic_segmentation_3d_exp.SemanticSegmentationModel3D()
  config.num_classes = 2
  config.num_channels = 3
  config.input_size = [None, None, None]
  config.decoder = identity_decoder_cfg

  built_decoder = factory.build_decoder(
      input_specs=None, model_config=config)
  self.assertIsNone(built_decoder)