Beispiel #1
0
    def test_unet_3d_decoder_creation(self, model_id):
        """Checks the factory-built UNet 3D decoder matches a directly built one.

        One decoder is constructed directly with `decoders.UNet3DDecoder`. A
        second is built through `factory.build_decoder` from a
        `SemanticSegmentationModel3D` config. Their `get_config()` dicts must
        be equal.

        Args:
          model_id: UNet 3D model id. It is also used here as the number of
            input levels to create specs for (keys '1'..str(model_id)).
        """
        # One spec per level: a batch-1, single-channel cube whose side starts
        # at 128 and halves at each successive level.
        input_specs = {
            str(level + 1): tf.TensorShape(
                [1] + [128 // (2**level)] * 3 + [1])
            for level in range(model_id)
        }

        # NOTE(review): these flags presumably mirror what the factory derives
        # from the config below; the equality assertion depends on that.
        network = decoders.UNet3DDecoder(model_id=model_id,
                                         input_specs=input_specs,
                                         use_sync_bn=True,
                                         use_batch_normalization=True,
                                         use_deconvolution=True)

        model_config = (
            semantic_segmentation_3d_exp.SemanticSegmentationModel3D())
        model_config.num_classes = 2
        model_config.num_channels = 1
        model_config.input_size = [None, None, None]
        model_config.decoder = decoders_cfg.Decoder(
            type='unet_3d_decoder',
            unet_3d_decoder=decoders_cfg.UNet3DDecoder(model_id=model_id))

        factory_network = factory.build_decoder(input_specs=input_specs,
                                                model_config=model_config)

        # assertEqual on dicts reports a per-key diff on failure, so no debug
        # printing is needed here.
        self.assertEqual(network.get_config(), factory_network.get_config())
def seg_unet3d_test() -> cfg.ExperimentConfig:
  """Returns a tiny 3D UNet segmentation experiment intended for tests.

  The model uses 2 classes, 2 input channels and 32x32x32 inputs, with a
  `model_id=2` UNet3D backbone and decoder. Data is read from
  `train.tfrecord` / `val.tfrecord` with batch size 2. Training runs for 10
  steps with SGD at a constant learning rate of 1e-6.
  """
  num_classes = 2
  num_channels = 2
  volume = (32, 32, 32)
  train_batch_size, eval_batch_size = 2, 2
  steps_per_epoch = 10

  def _data_config(path, training, batch_size):
    # The train and validation readers differ only in these three fields;
    # `list(volume)` gives every config its own input_size list.
    return DataConfig(
        input_path=path,
        num_classes=num_classes,
        input_size=list(volume),
        num_channels=num_channels,
        is_training=training,
        global_batch_size=batch_size)

  model = SemanticSegmentationModel3D(
      num_classes=num_classes,
      input_size=list(volume),
      num_channels=num_channels,
      backbone=backbones.Backbone(
          type='unet_3d', unet_3d=backbones.UNet3D(model_id=2)),
      decoder=decoders.Decoder(
          type='unet_3d_decoder',
          unet_3d_decoder=decoders.UNet3DDecoder(model_id=2)),
      head=SegmentationHead3D(num_convs=0, num_classes=num_classes),
      norm_activation=common.NormActivation(
          activation='relu', use_sync_bn=False))

  task = SemanticSegmentation3DTask(
      model=model,
      train_data=_data_config('train.tfrecord', True, train_batch_size),
      validation_data=_data_config('val.tfrecord', False, eval_batch_size),
      losses=Losses(loss_type='adaptive'))

  optimizer = optimization.OptimizationConfig({
      'optimizer': {'type': 'sgd'},
      'learning_rate': {
          'type': 'constant',
          'constant': {'learning_rate': 1e-6},
      },
  })

  trainer = cfg.TrainerConfig(
      steps_per_loop=steps_per_epoch,
      summary_interval=steps_per_epoch,
      checkpoint_interval=steps_per_epoch,
      train_steps=10,
      validation_steps=10,
      validation_interval=steps_per_epoch,
      optimizer_config=optimizer)

  return cfg.ExperimentConfig(
      task=task,
      trainer=trainer,
      restrictions=[
          'task.train_data.is_training != None',
          'task.validation_data.is_training != None',
      ])
@dataclasses.dataclass
class SemanticSegmentationModel3D(hyperparams.Config):
  """Semantic segmentation model config for 3D inputs.

  Attributes:
    num_classes: Number of output classes.
    num_channels: Number of channels in the input volume.
    input_size: Input volume size, e.g. [32, 32, 32]; empty by default.
    min_level: Lower bound of the level range (how it is consumed is defined
      outside this config).
    max_level: Upper bound of the level range.
    head: Segmentation head config.
    backbone: Backbone config; defaults to a 'unet_3d' backbone.
    decoder: Decoder config; defaults to a 'unet_3d_decoder' decoder.
    norm_activation: Normalization / activation config.
  """
  num_classes: int = 0
  num_channels: int = 1
  input_size: List[int] = dataclasses.field(default_factory=list)
  min_level: int = 3
  max_level: int = 6
  # Sub-config defaults use default_factory so that each instance gets its own
  # object. Shared instances would leak mutations across configs, and Python
  # 3.11+ dataclasses reject unhashable (dataclass-instance) defaults outright.
  head: SegmentationHead3D = dataclasses.field(
      default_factory=SegmentationHead3D)
  backbone: backbones.Backbone = dataclasses.field(
      default_factory=lambda: backbones.Backbone(
          type='unet_3d', unet_3d=backbones.UNet3D()))
  decoder: decoders.Decoder = dataclasses.field(
      default_factory=lambda: decoders.Decoder(
          type='unet_3d_decoder', unet_3d_decoder=decoders.UNet3DDecoder()))
  norm_activation: common.NormActivation = dataclasses.field(
      default_factory=common.NormActivation)