def seg_unet3d_test() -> cfg.ExperimentConfig:
  """Builds a tiny 3D UNet segmentation experiment intended for testing.

  The model is a two-class 3D UNet (backbone and decoder both use
  `model_id=2`) fed with 32x32x32 volumes that have 2 channels. Training data
  is read from 'train.tfrecord' and validation data from 'val.tfrecord', both
  with a global batch size of 2. The trainer runs 10 train steps with SGD at a
  constant learning rate of 1e-6.

  Returns:
    The assembled `cfg.ExperimentConfig`.
  """
  batch_size_train = 2
  batch_size_eval = 2
  # Shared by the loop length and by the summary/checkpoint/validation
  # intervals below.
  loop_steps = 10

  # Model: backbone -> decoder -> head, plus normalization/activation settings.
  unet_backbone = backbones.Backbone(
      type='unet_3d', unet_3d=backbones.UNet3D(model_id=2))
  unet_decoder = decoders.Decoder(
      type='unet_3d_decoder',
      unet_3d_decoder=decoders.UNet3DDecoder(model_id=2))
  seg_head = SegmentationHead3D(num_convs=0, num_classes=2)
  norm_act = common.NormActivation(activation='relu', use_sync_bn=False)
  model_config = SemanticSegmentationModel3D(
      num_classes=2,
      input_size=[32, 32, 32],
      num_channels=2,
      backbone=unet_backbone,
      decoder=unet_decoder,
      head=seg_head,
      norm_activation=norm_act)

  # Input pipelines; each gets its own `input_size` list.
  train_input = DataConfig(
      input_path='train.tfrecord',
      num_classes=2,
      input_size=[32, 32, 32],
      num_channels=2,
      is_training=True,
      global_batch_size=batch_size_train)
  eval_input = DataConfig(
      input_path='val.tfrecord',
      num_classes=2,
      input_size=[32, 32, 32],
      num_channels=2,
      is_training=False,
      global_batch_size=batch_size_eval)

  task_config = SemanticSegmentation3DTask(
      model=model_config,
      train_data=train_input,
      validation_data=eval_input,
      losses=Losses(loss_type='adaptive'))

  # Plain SGD with a constant learning rate.
  optimizer_spec = {
      'optimizer': {'type': 'sgd'},
      'learning_rate': {
          'type': 'constant',
          'constant': {'learning_rate': 1e-6},
      },
  }
  trainer_config = cfg.TrainerConfig(
      steps_per_loop=loop_steps,
      summary_interval=loop_steps,
      checkpoint_interval=loop_steps,
      train_steps=10,
      validation_steps=10,
      validation_interval=loop_steps,
      optimizer_config=optimization.OptimizationConfig(optimizer_spec))

  return cfg.ExperimentConfig(
      task=task_config,
      trainer=trainer_config,
      restrictions=[
          'task.train_data.is_training != None',
          'task.validation_data.is_training != None',
      ])
class SemanticSegmentationModel3D(hyperparams.Config):
  """Semantic segmentation model config.

  Sub-config defaults are supplied through `default_factory`, so each instance
  builds its own `head`, `backbone`, `decoder` and `norm_activation` objects.
  A config instance written directly as a field default is created once at
  class-definition time and shared by every instance, and Python 3.11+
  dataclasses raise `ValueError` for such defaults when they are unhashable.

  NOTE(review): `default_factory` only takes effect when this class is
  processed by `dataclasses.dataclass`, which the `input_size` field already
  relies on.

  Attributes:
    num_classes: Number of classes; defaults to 0.
    num_channels: Number of channels; defaults to 1. Presumably the channels
      of the input volume -- confirm against the model builder.
    input_size: List of ints, empty by default. Presumably the spatial size
      of the input volume -- confirm against the model builder.
    min_level: Defaults to 3.
    max_level: Defaults to 6.
    head: Segmentation head config; defaults to `SegmentationHead3D()`.
    backbone: Backbone config; defaults to the 'unet_3d' backbone with a
      default `backbones.UNet3D()`.
    decoder: Decoder config; defaults to the 'unet_3d_decoder' decoder with a
      default `decoders.UNet3DDecoder()`.
    norm_activation: Normalization/activation config; defaults to
      `common.NormActivation()`.
  """
  num_classes: int = 0
  num_channels: int = 1
  input_size: List[int] = dataclasses.field(default_factory=list)
  min_level: int = 3
  max_level: int = 6
  # The factories below run once per instance, so nested configs are never
  # shared between two model configs.
  head: SegmentationHead3D = dataclasses.field(
      default_factory=SegmentationHead3D)
  backbone: backbones.Backbone = dataclasses.field(
      default_factory=lambda: backbones.Backbone(
          type='unet_3d', unet_3d=backbones.UNet3D()))
  decoder: decoders.Decoder = dataclasses.field(
      default_factory=lambda: decoders.Decoder(
          type='unet_3d_decoder', unet_3d_decoder=decoders.UNet3DDecoder()))
  norm_activation: common.NormActivation = dataclasses.field(
      default_factory=common.NormActivation)