Example #1
    def test_mobilenet_creation(self, model_id, filter_size_scale):
        """Test creation of Mobilenet models."""

        network = backbones.MobileNet(model_id=model_id,
                                      filter_size_scale=filter_size_scale,
                                      norm_momentum=0.99,
                                      norm_epsilon=1e-5)

        backbone_config = backbones_cfg.Backbone(
            type='mobilenet',
            mobilenet=backbones_cfg.MobileNet(
                model_id=model_id, filter_size_scale=filter_size_scale))
        norm_activation_config = common_cfg.NormActivation(norm_momentum=0.99,
                                                           norm_epsilon=1e-5,
                                                           use_sync_bn=False)

        factory_network = factory.build_backbone(
            input_specs=tf.keras.layers.InputSpec(shape=[None, None, None, 3]),
            backbone_config=backbone_config,
            norm_activation_config=norm_activation_config)

        network_config = network.get_config()
        factory_network_config = factory_network.get_config()

        self.assertEqual(network_config, factory_network_config)
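A minimal usage sketch of the same factory path outside the test harness, assuming the `factory`, `backbones_cfg`, `common_cfg`, and `tf` imports from the snippet above; the `model_id` string is a placeholder for whichever MobileNet variant the config registry accepts.

# Hedged sketch: build a MobileNet backbone via the factory and run a dummy batch.
backbone_config = backbones_cfg.Backbone(
    type='mobilenet',
    mobilenet=backbones_cfg.MobileNet(model_id='MobileNetV2',  # placeholder id
                                      filter_size_scale=1.0))
norm_activation_config = common_cfg.NormActivation(use_sync_bn=False)
backbone = factory.build_backbone(
    input_specs=tf.keras.layers.InputSpec(shape=[None, None, None, 3]),
    backbone_config=backbone_config,
    norm_activation_config=norm_activation_config)
endpoints = backbone(tf.zeros((1, 224, 224, 3)))  # dict of multi-scale feature maps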
Example #2
    def test_spinenet_creation(self, model_id):
        """Test creation of SpineNet models."""
        input_size = 128
        min_level = 3
        max_level = 7

        input_specs = tf.keras.layers.InputSpec(
            shape=[None, input_size, input_size, 3])
        network = backbones.SpineNet(input_specs=input_specs,
                                     min_level=min_level,
                                     max_level=max_level,
                                     norm_momentum=0.99,
                                     norm_epsilon=1e-5)

        backbone_config = backbones_cfg.Backbone(
            type='spinenet',
            spinenet=backbones_cfg.SpineNet(model_id=model_id))
        norm_activation_config = common_cfg.NormActivation(norm_momentum=0.99,
                                                           norm_epsilon=1e-5,
                                                           use_sync_bn=False)

        factory_network = factory.build_backbone(
            input_specs=tf.keras.layers.InputSpec(
                shape=[None, input_size, input_size, 3]),
            backbone_config=backbone_config,
            norm_activation_config=norm_activation_config)

        network_config = network.get_config()
        factory_network_config = factory_network.get_config()

        self.assertEqual(network_config, factory_network_config)
Example #3
    def test_efficientnet_creation(self, model_id, se_ratio):
        """Test creation of EfficientNet models."""

        network = backbones.EfficientNet(model_id=model_id,
                                         se_ratio=se_ratio,
                                         norm_momentum=0.99,
                                         norm_epsilon=1e-5)

        backbone_config = backbones_cfg.Backbone(
            type='efficientnet',
            efficientnet=backbones_cfg.EfficientNet(model_id=model_id,
                                                    se_ratio=se_ratio))
        norm_activation_config = common_cfg.NormActivation(norm_momentum=0.99,
                                                           norm_epsilon=1e-5,
                                                           use_sync_bn=False)

        factory_network = factory.build_backbone(
            input_specs=tf.keras.layers.InputSpec(shape=[None, None, None, 3]),
            backbone_config=backbone_config,
            norm_activation_config=norm_activation_config)

        network_config = network.get_config()
        factory_network_config = factory_network.get_config()

        self.assertEqual(network_config, factory_network_config)
Example #4
    def test_resnet_creation(self, model_id):
        """Test creation of ResNet models."""

        network = backbones.ResNet(model_id=model_id,
                                   se_ratio=0.0,
                                   norm_momentum=0.99,
                                   norm_epsilon=1e-5)

        backbone_config = backbones_cfg.Backbone(type='resnet',
                                                 resnet=backbones_cfg.ResNet(
                                                     model_id=model_id,
                                                     se_ratio=0.0))
        norm_activation_config = common_cfg.NormActivation(norm_momentum=0.99,
                                                           norm_epsilon=1e-5,
                                                           use_sync_bn=False)
        model_config = retinanet_cfg.RetinaNet(
            backbone=backbone_config, norm_activation=norm_activation_config)

        factory_network = factory.build_backbone(
            input_specs=tf.keras.layers.InputSpec(shape=[None, None, None, 3]),
            model_config=model_config)

        network_config = network.get_config()
        factory_network_config = factory_network.get_config()

        self.assertEqual(network_config, factory_network_config)
Example #5
class VideoClassificationModel(hyperparams.Config):
    """The model config."""
    backbone: backbones_3d.Backbone3D = backbones_3d.Backbone3D(
        type='resnet_3d', resnet_3d=backbones_3d.ResNet3D50())
    norm_activation: common.NormActivation = common.NormActivation()
    dropout_rate: float = 0.2
    add_head_batch_norm: bool = False
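Because this is a `hyperparams.Config` dataclass, fields can be overridden at construction time; a small sketch, assuming the class is imported exactly as defined above.

# Hedged sketch: override selected fields of the default video classification config.
model_cfg = VideoClassificationModel(dropout_rate=0.5, add_head_batch_norm=True)
assert model_cfg.backbone.type == 'resnet_3d'  # the default backbone is kept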
Example #6
class AssembleNetModel(video_classification.VideoClassificationModel):
    """The AssembleNet model config."""
    model_type: str = 'assemblenet'
    backbone: Backbone3D = Backbone3D()
    norm_activation: common.NormActivation = common.NormActivation(
        norm_momentum=0.99, norm_epsilon=1e-5, use_sync_bn=True)
    max_pool_predictions: bool = False
Example #7
    def testBuildCenterNet(self):
        backbone = hourglass.build_hourglass(
            input_specs=tf.keras.layers.InputSpec(shape=[None, 512, 512, 3]),
            backbone_config=backbones.Backbone(type='hourglass'),
            norm_activation_config=common.NormActivation(use_sync_bn=True))

        task_config = {
            'ct_heatmaps': 90,
            'ct_offset': 2,
            'ct_size': 2,
        }

        input_levels = ['2_0', '2']

        head = centernet_head.CenterNetHead(task_outputs=task_config,
                                            input_specs=backbone.output_specs,
                                            input_levels=input_levels)

        detection_ge = detection_generator.CenterNetDetectionGenerator()

        model = centernet_model.CenterNetModel(
            backbone=backbone, head=head, detection_generator=detection_ge)

        outputs = model(tf.zeros((5, 512, 512, 3)))
        self.assertLen(outputs['raw_output'], 3)
        self.assertLen(outputs['raw_output']['ct_heatmaps'], 2)
        self.assertLen(outputs['raw_output']['ct_offset'], 2)
        self.assertLen(outputs['raw_output']['ct_size'], 2)
        self.assertEqual(outputs['raw_output']['ct_heatmaps'][0].shape,
                         (5, 128, 128, 90))
        self.assertEqual(outputs['raw_output']['ct_offset'][0].shape,
                         (5, 128, 128, 2))
        self.assertEqual(outputs['raw_output']['ct_size'][0].shape,
                         (5, 128, 128, 2))
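The head is wired to the backbone through `backbone.output_specs`; a short sketch of inspecting those specs, reusing the `backbone` built above and assuming `output_specs` is a dict keyed by the same level names passed in `input_levels`.

# Hedged sketch: list the feature levels the hourglass backbone exposes.
for level, spec in backbone.output_specs.items():
    print(level, spec)  # e.g. '2_0' and '2' at 128x128 resolution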
Example #8
class VideoClassificationModel(hyperparams.Config):
  """The model config."""
  model_type: str = 'video_classification'
  backbone: backbones_3d.Backbone3D = backbones_3d.Backbone3D(
      type='resnet_3d', resnet_3d=backbones_3d.ResNet3D50())
  norm_activation: common.NormActivation = common.NormActivation()
  dropout_rate: float = 0.2
  aggregate_endpoints: bool = False
Example #9
class VideoSSLModel(VideoClassificationModel):
    """The model config."""
    normalize_feature: bool = False
    hidden_dim: int = 2048
    hidden_layer_num: int = 3
    projection_dim: int = 128
    hidden_norm_activation: common.NormActivation = common.NormActivation(
        use_sync_bn=False, norm_momentum=0.997, norm_epsilon=1.0e-05)
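A quick sketch of serializing the config, assuming `hyperparams.Config` exposes the `as_dict()` helper used throughout TF Model Garden configs.

# Hedged sketch: dump the SSL model config to a plain dict (e.g. for logging).
ssl_cfg = VideoSSLModel(projection_dim=256)
print(ssl_cfg.as_dict()['hidden_dim'])  # 2048 by default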
Example #10
def seg_unet3d_test() -> cfg.ExperimentConfig:
  """Image segmentation on a dummy dataset with 3D UNet for testing purposes."""
  train_batch_size = 2
  eval_batch_size = 2
  steps_per_epoch = 10
  config = cfg.ExperimentConfig(
      task=SemanticSegmentation3DTask(
          model=SemanticSegmentationModel3D(
              num_classes=2,
              input_size=[32, 32, 32],
              num_channels=2,
              backbone=backbones.Backbone(
                  type='unet_3d', unet_3d=backbones.UNet3D(model_id=2)),
              decoder=decoders.Decoder(
                  type='unet_3d_decoder',
                  unet_3d_decoder=decoders.UNet3DDecoder(model_id=2)),
              head=SegmentationHead3D(num_convs=0, num_classes=2),
              norm_activation=common.NormActivation(
                  activation='relu', use_sync_bn=False)),
          train_data=DataConfig(
              input_path='train.tfrecord',
              num_classes=2,
              input_size=[32, 32, 32],
              num_channels=2,
              is_training=True,
              global_batch_size=train_batch_size),
          validation_data=DataConfig(
              input_path='val.tfrecord',
              num_classes=2,
              input_size=[32, 32, 32],
              num_channels=2,
              is_training=False,
              global_batch_size=eval_batch_size),
          losses=Losses(loss_type='adaptive')),
      trainer=cfg.TrainerConfig(
          steps_per_loop=steps_per_epoch,
          summary_interval=steps_per_epoch,
          checkpoint_interval=steps_per_epoch,
          train_steps=10,
          validation_steps=10,
          validation_interval=steps_per_epoch,
          optimizer_config=optimization.OptimizationConfig({
              'optimizer': {
                  'type': 'sgd',
              },
              'learning_rate': {
                  'type': 'constant',
                  'constant': {
                      'learning_rate': 0.000001
                  }
              }
          })),
      restrictions=[
          'task.train_data.is_training != None',
          'task.validation_data.is_training != None'
      ])

  return config
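A sketch of consuming the generated experiment config, assuming `cfg.ExperimentConfig` supports the standard `override()` call for nested fields.

# Hedged sketch: take the test experiment and bump the number of train steps.
experiment = seg_unet3d_test()
experiment.override({'trainer': {'train_steps': 100}})
print(experiment.task.model.num_classes)  # 2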
Example #11
    def test_hourglass(self):
        backbone = hourglass.build_hourglass(
            input_specs=tf.keras.layers.InputSpec(shape=[None, 512, 512, 3]),
            backbone_config=backbones.Backbone(type='hourglass'),
            norm_activation_config=common.NormActivation(use_sync_bn=True))
        inputs = np.zeros((2, 512, 512, 3), dtype=np.float32)
        outputs = backbone(inputs)
        self.assertEqual(outputs['2_0'].shape, (2, 128, 128, 256))
        self.assertEqual(outputs['2'].shape, (2, 128, 128, 256))
Example #12
class MultiHeadModel(hyperparams.Config):
    """Multi head multi task model config, similar to other models but 
  with input, backbone, activation and weight decay shared."""
    input_size: List[int] = dataclasses.field(default_factory=list)
    backbone: backbones.Backbone = backbones.Backbone(
        type='resnet', resnet=backbones.ResNet())
    norm_activation: common.NormActivation = common.NormActivation()
    heads: List[Submodel] = dataclasses.field(default_factory=list)
    l2_weight_decay: float = 0.0
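A construction sketch, assuming `Submodel` (referenced in the `heads` field above) can be instantiated with its defaults.

# Hedged sketch: a two-head multi-task config sharing one ResNet backbone.
multi_head_cfg = MultiHeadModel(
    input_size=[224, 224, 3],
    heads=[Submodel(), Submodel()],
    l2_weight_decay=1e-4)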
Example #13
class ImageClassificationModel(hyperparams.Config):
    num_classes: int = 0
    input_size: List[int] = dataclasses.field(default_factory=list)
    backbone: backbones.Backbone = backbones.Backbone(
        type='darknet', darknet=backbones.Darknet())
    dropout_rate: float = 0.0
    norm_activation: common.NormActivation = common.NormActivation()
    # Adds a Batch Normalization layer pre-GlobalAveragePooling in classification.
    add_head_batch_norm: bool = False
Example #14
class YoloModel(hyperparams.Config):
  num_classes: int = 0
  input_size: List[int] = dataclasses.field(default_factory=list)
  min_level: int = 3  # Only for FPN or NASFPN.
  max_level: int = 6  # Only for FPN or NASFPN.
  head: hyperparams.Config = YoloHead()
  backbone: backbones.Backbone = backbones.Backbone(
      type='resnet', resnet=backbones.ResNet())
  decoder: decoders.Decoder = decoders.Decoder(type='identity')
  norm_activation: common.NormActivation = common.NormActivation()
Example #15
class MovinetModel(video_classification.VideoClassificationModel):
    """The MoViNet model config."""
    model_type: str = 'movinet'
    backbone: Backbone3D = Backbone3D()
    norm_activation: common.NormActivation = common.NormActivation(
        activation='swish',
        norm_momentum=0.99,
        norm_epsilon=1e-3,
        use_sync_bn=True)
    output_states: bool = False
Example #16
class ImageClassificationModel(hyperparams.Config):
  """The model config."""
  num_classes: int = 0
  input_size: List[int] = dataclasses.field(default_factory=list)
  backbone: backbones.Backbone = backbones.Backbone(
      type='vit', vit=backbones.VisionTransformer())
  dropout_rate: float = 0.0
  norm_activation: common.NormActivation = common.NormActivation(
      use_sync_bn=False)
  # Adds a BatchNormalization layer pre-GlobalAveragePooling in classification
  add_head_batch_norm: bool = False
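A minimal sketch of instantiating this ViT classification config with overridden fields; the `num_classes` and `input_size` values are illustrative.

# Hedged sketch: an ImageNet-sized ViT classification config.
vit_cfg = ImageClassificationModel(num_classes=1000, input_size=[224, 224, 3])
assert vit_cfg.backbone.type == 'vit'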
Example #17
class SemanticSegmentationModel(hyperparams.Config):
    """Semantic segmentation model config."""
    num_classes: int = 0
    input_size: List[int] = dataclasses.field(default_factory=list)
    min_level: int = 3
    max_level: int = 6
    head: SegmentationHead = SegmentationHead()
    backbone: backbones.Backbone = backbones.Backbone(
        type='resnet', resnet=backbones.ResNet())
    decoder: decoders.Decoder = decoders.Decoder(type='identity')
    norm_activation: common.NormActivation = common.NormActivation()
Example #18
class BASNetModel(hyperparams.Config):
  """BASNet model config."""
  num_classes: int = 0
  input_size: List[int] = dataclasses.field(default_factory=list)
  #min_level: int = 3
  #max_level: int = 6
  #head: BASNetHead = BASNetHead()
  backbone: backbones.Backbone = backbones.Backbone(
      type='basnet_en', basnet_en=backbones.BASNet_En())
  decoder: decoders.Decoder = decoders.Decoder(type='basnet_de')
  norm_activation: common.NormActivation = common.NormActivation()
Example #19
class ImageClassificationModel(hyperparams.Config):
    """Image classification model config."""
    num_classes: int = 0
    input_size: List[int] = dataclasses.field(
        default_factory=lambda: [224, 224])
    backbone: backbones.Backbone = backbones.Backbone(
        type='darknet', darknet=backbones.Darknet())
    dropout_rate: float = 0.0
    norm_activation: common.NormActivation = common.NormActivation()
    # Adds a Batch Normalization layer pre-GlobalAveragePooling in classification.
    add_head_batch_norm: bool = False
    kernel_initializer: str = 'VarianceScaling'
Example #20
class CenterNetModel(hyperparams.Config):
  """Config for centernet model."""
  num_classes: int = 90
  max_num_instances: int = 128
  input_size: List[int] = dataclasses.field(default_factory=list)
  backbone: backbones.Backbone = backbones.Backbone(
      type='hourglass', hourglass=backbones.Hourglass(model_id=52))
  head: CenterNetHead = CenterNetHead()
  # pylint: disable=line-too-long
  detection_generator: CenterNetDetectionGenerator = CenterNetDetectionGenerator()
  norm_activation: common.NormActivation = common.NormActivation(
      norm_momentum=0.1, norm_epsilon=1e-5, use_sync_bn=True)
Example #21
class DbofModel(hyperparams.Config):
    """The model config."""
    cluster_size: int = 3000
    hidden_size: int = 2000
    add_batch_norm: bool = True
    sample_random_frames: bool = True
    use_context_gate_cluster_layer: bool = False
    context_gate_cluster_bottleneck_size: int = 0
    pooling_method: str = 'average'
    yt8m_agg_classifier_model: str = 'MoeModel'
    agg_model: hyperparams.Config = MoeModel()
    norm_activation: common.NormActivation = common.NormActivation(
        activation='relu', use_sync_bn=False)
Example #22
class SimCLRModel(hyperparams.Config):
    """SimCLR model config."""
    input_size: List[int] = dataclasses.field(default_factory=list)
    backbone: backbones.Backbone = backbones.Backbone(
        type='resnet', resnet=backbones.ResNet())
    projection_head: ProjectionHead = ProjectionHead(proj_output_dim=128,
                                                     num_proj_layers=3,
                                                     ft_proj_idx=1)
    supervised_head: SupervisedHead = SupervisedHead(num_classes=1001)
    norm_activation: common.NormActivation = common.NormActivation(
        norm_momentum=0.9, norm_epsilon=1e-5, use_sync_bn=False)
    mode: str = simclr_model.PRETRAIN
    backbone_trainable: bool = True
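A sketch of switching the same config to the fine-tuning stage; `simclr_model.FINETUNE` is assumed to exist alongside the `PRETRAIN` constant referenced above.

# Hedged sketch: reuse the pretraining config for supervised fine-tuning.
finetune_cfg = SimCLRModel(
    mode=simclr_model.FINETUNE,  # assumed counterpart of simclr_model.PRETRAIN
    backbone_trainable=True)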
Example #23
class RetinaNet(hyperparams.Config):
    num_classes: int = 0
    input_size: List[int] = dataclasses.field(default_factory=list)
    min_level: int = 3
    max_level: int = 7
    anchor: Anchor = Anchor()
    backbone: backbones.Backbone = backbones.Backbone(
        type='resnet', resnet=backbones.ResNet())
    decoder: decoders.Decoder = decoders.Decoder(type='fpn',
                                                 fpn=decoders.FPN())
    head: RetinaNetHead = RetinaNetHead()
    detection_generator: DetectionGenerator = DetectionGenerator()
    norm_activation: common.NormActivation = common.NormActivation()
Example #24
class SemanticSegmentationModel3D(hyperparams.Config):
  """Semantic segmentation model config."""
  num_classes: int = 0
  num_channels: int = 1
  input_size: List[int] = dataclasses.field(default_factory=list)
  min_level: int = 3
  max_level: int = 6
  head: SegmentationHead3D = SegmentationHead3D()
  backbone: backbones.Backbone = backbones.Backbone(
      type='unet_3d', unet_3d=backbones.UNet3D())
  decoder: decoders.Decoder = decoders.Decoder(
      type='unet_3d_decoder', unet_3d_decoder=decoders.UNet3DDecoder())
  norm_activation: common.NormActivation = common.NormActivation()
Example #25
class SimCLRMTModelConfig(hyperparams.Config):
    """Model config for multi-task SimCLR model."""
    input_size: List[int] = dataclasses.field(default_factory=list)
    backbone: backbones.Backbone = backbones.Backbone(
        type='resnet', resnet=backbones.ResNet())
    backbone_trainable: bool = True
    projection_head: simclr_configs.ProjectionHead = simclr_configs.ProjectionHead(
        proj_output_dim=128, num_proj_layers=3, ft_proj_idx=1)
    norm_activation: common.NormActivation = common.NormActivation(
        norm_momentum=0.9, norm_epsilon=1e-5, use_sync_bn=False)
    heads: Tuple[SimCLRMTHeadConfig, ...] = ()
    # L2 weight decay is used in the model, not in task.
    # Note that this can not be used together with lars optimizer.
    l2_weight_decay: float = 0.0
Example #26
def video_classification_ucf101() -> cfg.ExperimentConfig:
  """Video classification on UCF-101 with resnet."""
  train_dataset = DataConfig(
      name='ucf101',
      num_classes=101,
      is_training=True,
      split='train',
      drop_remainder=True,
      num_examples=9537,
      temporal_stride=2,
      feature_shape=(32, 224, 224, 3))
  train_dataset.tfds_name = 'ucf101'
  train_dataset.tfds_split = 'train'
  validation_dataset = DataConfig(
      name='ucf101',
      num_classes=101,
      is_training=True,
      split='test',
      drop_remainder=False,
      num_examples=3783,
      temporal_stride=2,
      feature_shape=(32, 224, 224, 3))
  validation_dataset.tfds_name = 'ucf101'
  validation_dataset.tfds_split = 'test'
  task = VideoClassificationTask(
      model=VideoClassificationModel(
          backbone=backbones_3d.Backbone3D(
              type='resnet_3d', resnet_3d=backbones_3d.ResNet3D50()),
          norm_activation=common.NormActivation(
              norm_momentum=0.9, norm_epsilon=1e-5, use_sync_bn=False)),
      losses=Losses(l2_weight_decay=1e-4),
      train_data=train_dataset,
      validation_data=validation_dataset)
  config = cfg.ExperimentConfig(
      runtime=cfg.RuntimeConfig(mixed_precision_dtype='bfloat16'),
      task=task,
      restrictions=[
          'task.train_data.is_training != None',
          'task.validation_data.is_training != None',
          'task.train_data.num_classes == task.validation_data.num_classes',
      ])
  add_trainer(
      config,
      train_batch_size=64,
      eval_batch_size=16,
      learning_rate=0.8,
      train_epochs=100)
  return config
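A sketch of reading fields off the returned experiment, relying only on plain dataclass attribute access.

# Hedged sketch: inspect the UCF-101 experiment produced above.
experiment = video_classification_ucf101()
print(experiment.task.train_data.num_examples)        # 9537
print(experiment.task.validation_data.feature_shape)  # (32, 224, 224, 3)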
Example #27
class ImageClassificationModel(hyperparams.Config):
    num_classes: int = 1000
    input_size: List[int] = dataclasses.field(
        default_factory=lambda: [256, 256, 3])
    backbone: backbones.Backbone = backbones.Backbone(
        type='darknet', darknet=backbones.DarkNet(model_id='cspdarknet'))
    dropout_rate: float = 0.0
    norm_activation: common.NormActivation = common.NormActivation(
        activation='leaky', use_sync_bn=False)
    # Adds a BatchNormalization layer pre-GlobalAveragePooling in classification
    add_head_batch_norm: bool = False
    min_level: Optional[int] = None
    max_level: int = 5
    dilate: bool = False
    darknet_weights_file: str = 'cache://csdarknet53.weights'
    darknet_weights_cfg: str = 'cache://csdarknet53.cfg'
Example #28
class Yolo(ModelConfig):
  num_classes: int = 80
  _input_size: Optional[List[int]] = None
  min_level: int = 3
  max_level: int = 5
  boxes_per_scale: int = 3
  base: Union[str, YoloBase] = YoloBase()
  dilate: bool = False
  filter: YoloLossLayer = YoloLossLayer()
  norm_activation: common.NormActivation = common.NormActivation(
      activation='leaky',
      use_sync_bn=False,
      norm_momentum=0.99,
      norm_epsilon=0.001)
  decoder_activation: str = 'leaky'
  _boxes: Optional[List[str]] = dataclasses.field(default_factory=lambda: [
      '(12, 16)', '(19, 36)', '(40, 28)', '(36, 75)', '(76, 55)', '(72, 146)',
      '(142, 110)', '(192, 243)', '(459, 401)'
  ])
Example #29
class Yolo(hyperparams.Config):
    input_size: Optional[List[int]] = dataclasses.field(
        default_factory=lambda: [512, 512, 3])
    backbone: backbones.Backbone = backbones.Backbone(
        type='darknet', darknet=backbones.Darknet(model_id='cspdarknet53'))
    decoder: decoders.Decoder = decoders.Decoder(
        type='yolo_decoder',
        yolo_decoder=decoders.YoloDecoder(version='v4', type='regular'))
    head: YoloHead = YoloHead()
    detection_generator: YoloDetectionGenerator = YoloDetectionGenerator()
    loss: YoloLoss = YoloLoss()
    norm_activation: common.NormActivation = common.NormActivation(
        activation='mish',
        use_sync_bn=True,
        norm_momentum=0.99,
        norm_epsilon=0.001)
    num_classes: int = 80
    anchor_boxes: AnchorBoxes = AnchorBoxes()
    darknet_based_model: bool = False
Example #30
class MaskRCNN(hyperparams.Config):
    num_classes: int = 0
    input_size: List[int] = dataclasses.field(default_factory=list)
    min_level: int = 2
    max_level: int = 6
    anchor: Anchor = Anchor()
    include_mask: bool = True
    backbone: backbones.Backbone = backbones.Backbone(
        type='resnet', resnet=backbones.ResNet())
    decoder: decoders.Decoder = decoders.Decoder(type='fpn',
                                                 fpn=decoders.FPN())
    rpn_head: RPNHead = RPNHead()
    detection_head: DetectionHead = DetectionHead()
    roi_generator: ROIGenerator = ROIGenerator()
    roi_sampler: ROISampler = ROISampler()
    roi_aligner: ROIAligner = ROIAligner()
    detection_generator: DetectionGenerator = DetectionGenerator()
    mask_head: Optional[MaskHead] = MaskHead()
    mask_sampler: Optional[MaskSampler] = MaskSampler()
    mask_roi_aligner: Optional[MaskROIAligner] = MaskROIAligner()
    norm_activation: common.NormActivation = common.NormActivation(
        norm_momentum=0.997, norm_epsilon=0.0001, use_sync_bn=True)
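Since the mask components are typed as `Optional`, a box-only variant can be expressed by nulling them out; a sketch under that assumption.

# Hedged sketch: a Faster R-CNN style (box-only) variant of the config above.
box_only_cfg = MaskRCNN(
    include_mask=False,
    mask_head=None,
    mask_sampler=None,
    mask_roi_aligner=None)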