Beispiel #1
0
    def test_resnet_creation(self, model_id):
        """Check that the backbone factory reproduces a directly built ResNet.

        One ResNet is instantiated explicitly and a second one is produced by
        `factory.build_backbone` from equivalent config objects. The test
        passes when both networks report the same `get_config()`.

        Args:
          model_id: Model id forwarded to both construction paths.
        """
        # Hyperparameters shared by the two construction paths.
        squeeze_excite_ratio = 0.0
        momentum = 0.99
        epsilon = 1e-5

        # Path 1: instantiate the backbone class directly.
        direct_network = backbones.ResNet(
            model_id=model_id,
            se_ratio=squeeze_excite_ratio,
            norm_momentum=momentum,
            norm_epsilon=epsilon)

        # Path 2: describe the same network with config objects and let the
        # factory build it.
        resnet_config = backbones_cfg.ResNet(
            model_id=model_id, se_ratio=squeeze_excite_ratio)
        backbone_config = backbones_cfg.Backbone(
            type='resnet', resnet=resnet_config)
        norm_activation_config = common_cfg.NormActivation(
            norm_momentum=momentum,
            norm_epsilon=epsilon,
            use_sync_bn=False)
        image_spec = tf.keras.layers.InputSpec(shape=[None, None, None, 3])
        built_network = factory.build_backbone(
            input_specs=image_spec,
            backbone_config=backbone_config,
            norm_activation_config=norm_activation_config)

        self.assertEqual(direct_network.get_config(),
                         built_network.get_config())
Beispiel #2
0
@dataclasses.dataclass
class MultiHeadModel(hyperparams.Config):
    """Multi-head multi-task model config.

    Similar to the other model configs, but with the input, backbone,
    activation and weight decay shared.
    """
    input_size: List[int] = dataclasses.field(default_factory=list)
    # Config-object defaults are created per instance through default_factory:
    # a single shared default object would leak edits between configs and is
    # rejected as a mutable default by dataclasses on Python 3.11+.
    backbone: backbones.Backbone = dataclasses.field(
        default_factory=lambda: backbones.Backbone(
            type='resnet', resnet=backbones.ResNet()))
    norm_activation: common.NormActivation = dataclasses.field(
        default_factory=common.NormActivation)
    heads: List[Submodel] = dataclasses.field(default_factory=list)
    l2_weight_decay: float = 0.0
@dataclasses.dataclass
class ImageClassificationModel(hyperparams.Config):
    """Image classification model config."""
    num_classes: int = 0
    input_size: List[int] = dataclasses.field(default_factory=list)
    # Config-object defaults are created per instance through default_factory:
    # a single shared default object would leak edits between configs and is
    # rejected as a mutable default by dataclasses on Python 3.11+.
    backbone: backbones.Backbone = dataclasses.field(
        default_factory=lambda: backbones.Backbone(
            type='resnet', resnet=backbones.ResNet()))
    dropout_rate: float = 0.0
    norm_activation: common.NormActivation = dataclasses.field(
        default_factory=common.NormActivation)
    # Adds a BatchNormalization layer pre-GlobalAveragePooling in classification
    add_head_batch_norm: bool = False
Beispiel #4
0
@dataclasses.dataclass
class YoloModel(hyperparams.Config):
  """YOLO model config."""
  num_classes: int = 0
  input_size: List[int] = dataclasses.field(default_factory=list)
  min_level: int = 3  # only for FPN or NASFPN
  max_level: int = 6  # only for FPN or NASFPN
  # Config-object defaults are created per instance through default_factory:
  # a single shared default object would leak edits between configs and is
  # rejected as a mutable default by dataclasses on Python 3.11+.
  head: hyperparams.Config = dataclasses.field(default_factory=YoloHead)
  backbone: backbones.Backbone = dataclasses.field(
      default_factory=lambda: backbones.Backbone(
          type='resnet', resnet=backbones.ResNet()))
  decoder: decoders.Decoder = dataclasses.field(
      default_factory=lambda: decoders.Decoder(type='identity'))
  norm_activation: common.NormActivation = dataclasses.field(
      default_factory=common.NormActivation)
@dataclasses.dataclass
class SemanticSegmentationModel(hyperparams.Config):
    """Semantic segmentation model config."""
    num_classes: int = 0
    input_size: List[int] = dataclasses.field(default_factory=list)
    min_level: int = 3
    max_level: int = 6
    # Config-object defaults are created per instance through default_factory:
    # a single shared default object would leak edits between configs and is
    # rejected as a mutable default by dataclasses on Python 3.11+.
    head: SegmentationHead = dataclasses.field(
        default_factory=SegmentationHead)
    backbone: backbones.Backbone = dataclasses.field(
        default_factory=lambda: backbones.Backbone(
            type='resnet', resnet=backbones.ResNet()))
    decoder: decoders.Decoder = dataclasses.field(
        default_factory=lambda: decoders.Decoder(type='identity'))
    norm_activation: common.NormActivation = dataclasses.field(
        default_factory=common.NormActivation)
@dataclasses.dataclass
class ImageClassificationModel(hyperparams.Config):
  """The model config."""
  num_classes: int = 0
  input_size: List[int] = dataclasses.field(default_factory=list)
  # Config-object defaults are created per instance through default_factory:
  # a single shared default object would leak edits between configs and is
  # rejected as a mutable default by dataclasses on Python 3.11+.
  backbone: backbones.Backbone = dataclasses.field(
      default_factory=lambda: backbones.Backbone(
          type='resnet', resnet=backbones.ResNet()))
  dropout_rate: float = 0.0
  norm_activation: common.NormActivation = dataclasses.field(
      default_factory=lambda: common.NormActivation(use_sync_bn=False))
  # Adds a BatchNormalization layer pre-GlobalAveragePooling in classification
  add_head_batch_norm: bool = False
  kernel_initializer: str = 'random_uniform'
Beispiel #7
0
@dataclasses.dataclass
class RetinaNet(hyperparams.Config):
    """RetinaNet model config."""
    num_classes: int = 0
    input_size: List[int] = dataclasses.field(default_factory=list)
    min_level: int = 3
    max_level: int = 7
    # Config-object defaults are created per instance through default_factory:
    # a single shared default object would leak edits between configs and is
    # rejected as a mutable default by dataclasses on Python 3.11+.
    anchor: Anchor = dataclasses.field(default_factory=Anchor)
    backbone: backbones.Backbone = dataclasses.field(
        default_factory=lambda: backbones.Backbone(
            type='resnet', resnet=backbones.ResNet()))
    decoder: decoders.Decoder = dataclasses.field(
        default_factory=lambda: decoders.Decoder(
            type='fpn', fpn=decoders.FPN()))
    head: RetinaNetHead = dataclasses.field(default_factory=RetinaNetHead)
    detection_generator: DetectionGenerator = dataclasses.field(
        default_factory=DetectionGenerator)
    norm_activation: common.NormActivation = dataclasses.field(
        default_factory=common.NormActivation)
Beispiel #8
0
@dataclasses.dataclass
class SimCLRModel(hyperparams.Config):
    """SimCLR model config."""
    input_size: List[int] = dataclasses.field(default_factory=list)
    # Config-object defaults are created per instance through default_factory:
    # a single shared default object would leak edits between configs and is
    # rejected as a mutable default by dataclasses on Python 3.11+.
    backbone: backbones.Backbone = dataclasses.field(
        default_factory=lambda: backbones.Backbone(
            type='resnet', resnet=backbones.ResNet()))
    projection_head: ProjectionHead = dataclasses.field(
        default_factory=lambda: ProjectionHead(
            proj_output_dim=128, num_proj_layers=3, ft_proj_idx=1))
    supervised_head: SupervisedHead = dataclasses.field(
        default_factory=lambda: SupervisedHead(num_classes=1001))
    norm_activation: common.NormActivation = dataclasses.field(
        default_factory=lambda: common.NormActivation(
            norm_momentum=0.9, norm_epsilon=1e-5, use_sync_bn=False))
    mode: str = simclr_model.PRETRAIN
    backbone_trainable: bool = True
Beispiel #9
0
@dataclasses.dataclass
class SimCLRMTModelConfig(hyperparams.Config):
    """Model config for multi-task SimCLR model."""
    input_size: List[int] = dataclasses.field(default_factory=list)
    # Config-object defaults are created per instance through default_factory:
    # a single shared default object would leak edits between configs and is
    # rejected as a mutable default by dataclasses on Python 3.11+.
    backbone: backbones.Backbone = dataclasses.field(
        default_factory=lambda: backbones.Backbone(
            type='resnet', resnet=backbones.ResNet()))
    backbone_trainable: bool = True
    projection_head: simclr_configs.ProjectionHead = dataclasses.field(
        default_factory=lambda: simclr_configs.ProjectionHead(
            proj_output_dim=128, num_proj_layers=3, ft_proj_idx=1))
    norm_activation: common.NormActivation = dataclasses.field(
        default_factory=lambda: common.NormActivation(
            norm_momentum=0.9, norm_epsilon=1e-5, use_sync_bn=False))
    # An empty tuple is immutable, so it is safe as a plain default.
    heads: Tuple[SimCLRMTHeadConfig, ...] = ()
    # L2 weight decay is used in the model, not in task.
    # Note that this can not be used together with lars optimizer.
    l2_weight_decay: float = 0.0
Beispiel #10
0
@dataclasses.dataclass
class MaskRCNN(hyperparams.Config):
    """Mask R-CNN model config."""
    num_classes: int = 0
    input_size: List[int] = dataclasses.field(default_factory=list)
    min_level: int = 2
    max_level: int = 6
    # Config-object defaults are created per instance through default_factory:
    # a single shared default object would leak edits between configs and is
    # rejected as a mutable default by dataclasses on Python 3.11+.
    anchor: Anchor = dataclasses.field(default_factory=Anchor)
    include_mask: bool = True
    backbone: backbones.Backbone = dataclasses.field(
        default_factory=lambda: backbones.Backbone(
            type='resnet', resnet=backbones.ResNet()))
    decoder: decoders.Decoder = dataclasses.field(
        default_factory=lambda: decoders.Decoder(
            type='fpn', fpn=decoders.FPN()))
    rpn_head: RPNHead = dataclasses.field(default_factory=RPNHead)
    detection_head: DetectionHead = dataclasses.field(
        default_factory=DetectionHead)
    roi_generator: ROIGenerator = dataclasses.field(
        default_factory=ROIGenerator)
    roi_sampler: ROISampler = dataclasses.field(default_factory=ROISampler)
    roi_aligner: ROIAligner = dataclasses.field(default_factory=ROIAligner)
    detection_generator: DetectionGenerator = dataclasses.field(
        default_factory=DetectionGenerator)
    mask_head: Optional[MaskHead] = dataclasses.field(
        default_factory=MaskHead)
    mask_sampler: Optional[MaskSampler] = dataclasses.field(
        default_factory=MaskSampler)
    mask_roi_aligner: Optional[MaskROIAligner] = dataclasses.field(
        default_factory=MaskROIAligner)
    norm_activation: common.NormActivation = dataclasses.field(
        default_factory=lambda: common.NormActivation(
            norm_momentum=0.997, norm_epsilon=0.0001, use_sync_bn=True))
def image_classification_imagenet_resnetrs() -> cfg.ExperimentConfig:
  """Builds the ImageNet classification experiment config for ResNet-RS.

  The schedule spans 350 epochs' worth of steps at batch size 4096, using SGD
  with momentum, EMA of the weights, a cosine learning rate and a 5-epoch
  linear warmup.

  Returns:
    The assembled `cfg.ExperimentConfig`.
  """
  train_batch_size = 4096
  eval_batch_size = 4096
  steps_per_epoch = IMAGENET_TRAIN_EXAMPLES // train_batch_size
  total_train_steps = 350 * steps_per_epoch

  # Model: ResNet-50 backbone with the ResNet-RS specific switches.
  resnet_rs = backbones.ResNet(
      model_id=50,
      stem_type='v1',
      resnetd_shortcut=True,
      replace_stem_max_pool=True,
      se_ratio=0.25,
      stochastic_depth_drop_rate=0.0)
  model = ImageClassificationModel(
      num_classes=1001,
      input_size=[160, 160, 3],
      backbone=backbones.Backbone(type='resnet', resnet=resnet_rs),
      dropout_rate=0.25,
      norm_activation=common.NormActivation(
          norm_momentum=0.0,
          norm_epsilon=1e-5,
          use_sync_bn=False,
          activation='swish'))

  # Task: losses plus the training / validation input pipelines.
  losses = Losses(l2_weight_decay=4e-5, label_smoothing=0.1)
  train_data = DataConfig(
      input_path=os.path.join(IMAGENET_INPUT_PATH_BASE, 'train*'),
      is_training=True,
      global_batch_size=train_batch_size,
      aug_type=common.Augmentation(
          type='randaug', randaug=common.RandAugment(magnitude=10)))
  validation_data = DataConfig(
      input_path=os.path.join(IMAGENET_INPUT_PATH_BASE, 'valid*'),
      is_training=False,
      global_batch_size=eval_batch_size)
  task = ImageClassificationTask(
      model=model,
      losses=losses,
      train_data=train_data,
      validation_data=validation_data)

  # Trainer: optimizer, EMA, learning-rate schedule and warmup.
  optimizer_config = optimization.OptimizationConfig({
      'optimizer': {
          'type': 'sgd',
          'sgd': {'momentum': 0.9},
      },
      'ema': {
          'average_decay': 0.9999,
          'trainable_weights_only': False,
      },
      'learning_rate': {
          'type': 'cosine',
          'cosine': {
              'initial_learning_rate': 1.6,
              'decay_steps': total_train_steps,
          },
      },
      'warmup': {
          'type': 'linear',
          'linear': {
              'warmup_steps': 5 * steps_per_epoch,
              'warmup_learning_rate': 0,
          },
      },
  })
  trainer = cfg.TrainerConfig(
      steps_per_loop=steps_per_epoch,
      summary_interval=steps_per_epoch,
      checkpoint_interval=steps_per_epoch,
      train_steps=total_train_steps,
      validation_steps=IMAGENET_VAL_EXAMPLES // eval_batch_size,
      validation_interval=steps_per_epoch,
      optimizer_config=optimizer_config)

  return cfg.ExperimentConfig(
      task=task,
      trainer=trainer,
      restrictions=[
          'task.train_data.is_training != None',
          'task.validation_data.is_training != None',
      ])
def image_classification_imagenet() -> cfg.ExperimentConfig:
  """Builds the ImageNet classification experiment config for ResNet-50.

  The schedule spans 90 epochs' worth of steps at batch size 4096 with XLA
  enabled, using SGD with momentum, a stepwise learning rate that drops at
  epochs 30, 60 and 80, and a 5-epoch linear warmup.

  Returns:
    The assembled `cfg.ExperimentConfig`.
  """
  train_batch_size = 4096
  eval_batch_size = 4096
  steps_per_epoch = IMAGENET_TRAIN_EXAMPLES // train_batch_size

  runtime = cfg.RuntimeConfig(enable_xla=True)

  # Task: model, losses and the training / validation input pipelines.
  model = ImageClassificationModel(
      num_classes=1001,
      input_size=[224, 224, 3],
      backbone=backbones.Backbone(
          type='resnet', resnet=backbones.ResNet(model_id=50)),
      norm_activation=common.NormActivation(
          norm_momentum=0.9, norm_epsilon=1e-5, use_sync_bn=False))
  losses = Losses(l2_weight_decay=1e-4)
  train_data = DataConfig(
      input_path=os.path.join(IMAGENET_INPUT_PATH_BASE, 'train*'),
      is_training=True,
      global_batch_size=train_batch_size)
  validation_data = DataConfig(
      input_path=os.path.join(IMAGENET_INPUT_PATH_BASE, 'valid*'),
      is_training=False,
      global_batch_size=eval_batch_size)
  task = ImageClassificationTask(
      model=model,
      losses=losses,
      train_data=train_data,
      validation_data=validation_data)

  # Stepwise schedule: the rate drops tenfold at each boundary epoch, and each
  # base rate is scaled linearly with the batch size relative to 256.
  lr_boundaries = [epoch * steps_per_epoch for epoch in (30, 60, 80)]
  lr_values = [
      base_rate * train_batch_size / 256
      for base_rate in (0.1, 0.01, 0.001, 0.0001)
  ]
  optimizer_config = optimization.OptimizationConfig({
      'optimizer': {
          'type': 'sgd',
          'sgd': {'momentum': 0.9},
      },
      'learning_rate': {
          'type': 'stepwise',
          'stepwise': {
              'boundaries': lr_boundaries,
              'values': lr_values,
          },
      },
      'warmup': {
          'type': 'linear',
          'linear': {
              'warmup_steps': 5 * steps_per_epoch,
              'warmup_learning_rate': 0,
          },
      },
  })
  trainer = cfg.TrainerConfig(
      steps_per_loop=steps_per_epoch,
      summary_interval=steps_per_epoch,
      checkpoint_interval=steps_per_epoch,
      train_steps=90 * steps_per_epoch,
      validation_steps=IMAGENET_VAL_EXAMPLES // eval_batch_size,
      validation_interval=steps_per_epoch,
      optimizer_config=optimizer_config)

  return cfg.ExperimentConfig(
      runtime=runtime,
      task=task,
      trainer=trainer,
      restrictions=[
          'task.train_data.is_training != None',
          'task.validation_data.is_training != None',
      ])
def seg_resnetfpn_pascal() -> cfg.ExperimentConfig:
    """Builds the Pascal semantic segmentation experiment config (ResNet-FPN).

    The schedule spans 450 epochs' worth of steps at train batch size 256,
    using SGD with momentum, a polynomial learning-rate decay and a 5-epoch
    linear warmup.

    Returns:
      The assembled `cfg.ExperimentConfig`.
    """
    train_batch_size = 256
    eval_batch_size = 32
    steps_per_epoch = PASCAL_TRAIN_EXAMPLES // train_batch_size
    total_train_steps = 450 * steps_per_epoch

    # Model: ResNet-50 backbone, FPN decoder and a segmentation head.
    model = SemanticSegmentationModel(
        num_classes=21,
        input_size=[512, 512, 3],
        min_level=3,
        max_level=7,
        backbone=backbones.Backbone(
            type='resnet', resnet=backbones.ResNet(model_id=50)),
        decoder=decoders.Decoder(type='fpn', fpn=decoders.FPN()),
        head=SegmentationHead(level=3, num_convs=3),
        norm_activation=common.NormActivation(
            activation='swish', use_sync_bn=True))

    # Task: losses plus the training / validation input pipelines.
    losses = Losses(l2_weight_decay=1e-4)
    train_data = DataConfig(
        input_path=os.path.join(PASCAL_INPUT_PATH_BASE, 'train_aug*'),
        is_training=True,
        global_batch_size=train_batch_size,
        aug_scale_min=0.2,
        aug_scale_max=1.5)
    validation_data = DataConfig(
        input_path=os.path.join(PASCAL_INPUT_PATH_BASE, 'val*'),
        is_training=False,
        global_batch_size=eval_batch_size,
        resize_eval_groundtruth=False,
        groundtruth_padded_size=[512, 512],
        drop_remainder=False)
    task = SemanticSegmentationTask(
        model=model,
        losses=losses,
        train_data=train_data,
        validation_data=validation_data)

    # Trainer: optimizer, learning-rate schedule and warmup.
    optimizer_config = optimization.OptimizationConfig({
        'optimizer': {
            'type': 'sgd',
            'sgd': {'momentum': 0.9},
        },
        'learning_rate': {
            'type': 'polynomial',
            'polynomial': {
                'initial_learning_rate': 0.007,
                'decay_steps': total_train_steps,
                'end_learning_rate': 0.0,
                'power': 0.9,
            },
        },
        'warmup': {
            'type': 'linear',
            'linear': {
                'warmup_steps': 5 * steps_per_epoch,
                'warmup_learning_rate': 0,
            },
        },
    })
    trainer = cfg.TrainerConfig(
        steps_per_loop=steps_per_epoch,
        summary_interval=steps_per_epoch,
        checkpoint_interval=steps_per_epoch,
        train_steps=total_train_steps,
        validation_steps=PASCAL_VAL_EXAMPLES // eval_batch_size,
        validation_interval=steps_per_epoch,
        optimizer_config=optimizer_config)

    return cfg.ExperimentConfig(
        task=task,
        trainer=trainer,
        restrictions=[
            'task.train_data.is_training != None',
            'task.validation_data.is_training != None',
        ])
Beispiel #14
0
def simclr_finetuning_imagenet() -> cfg.ExperimentConfig:
    """Builds the SimCLR fine-tuning experiment config on ImageNet.

    The schedule spans 60 epochs' worth of steps at batch size 1024, using the
    LARS optimizer with a cosine learning rate. The initial checkpoint path is
    left empty here.

    Returns:
      The assembled `cfg.ExperimentConfig`.
    """
    train_batch_size = 1024
    eval_batch_size = 1024
    steps_per_epoch = IMAGENET_TRAIN_EXAMPLES // train_batch_size
    total_train_steps = 60 * steps_per_epoch
    pretrain_model_base = ''

    # Model: ResNet-50 backbone with projection and supervised heads.
    model = SimCLRModel(
        mode=simclr_model.FINETUNE,
        backbone_trainable=True,
        input_size=[224, 224, 3],
        backbone=backbones.Backbone(
            type='resnet', resnet=backbones.ResNet(model_id=50)),
        projection_head=ProjectionHead(
            proj_output_dim=128, num_proj_layers=3, ft_proj_idx=1),
        supervised_head=SupervisedHead(num_classes=1001, zero_init=True),
        norm_activation=common.NormActivation(
            norm_momentum=0.9, norm_epsilon=1e-5, use_sync_bn=False))

    # Task: loss, evaluation and the training / validation input pipelines.
    loss = ClassificationLosses()
    evaluation = Evaluation()
    train_data = DataConfig(
        parser=Parser(mode=simclr_model.FINETUNE),
        input_path=os.path.join(IMAGENET_INPUT_PATH_BASE, 'train*'),
        is_training=True,
        global_batch_size=train_batch_size)
    validation_data = DataConfig(
        parser=Parser(mode=simclr_model.FINETUNE),
        input_path=os.path.join(IMAGENET_INPUT_PATH_BASE, 'valid*'),
        is_training=False,
        global_batch_size=eval_batch_size)
    task = SimCLRFinetuneTask(
        model=model,
        loss=loss,
        evaluation=evaluation,
        train_data=train_data,
        validation_data=validation_data,
        init_checkpoint=pretrain_model_base,
        # all, backbone_projection or backbone
        init_checkpoint_modules='backbone_projection')

    # Trainer: LARS without weight decay and a cosine learning rate.
    optimizer_config = optimization.OptimizationConfig({
        'optimizer': {
            'type': 'lars',
            'lars': {
                'momentum': 0.9,
                'weight_decay_rate': 0.0,
                'exclude_from_weight_decay': [
                    'batch_normalization', 'bias'
                ],
            },
        },
        'learning_rate': {
            'type': 'cosine',
            'cosine': {
                # 0.01 * BatchSize / 512
                'initial_learning_rate': 0.01 * train_batch_size / 512,
                'decay_steps': total_train_steps,
            },
        },
    })
    trainer = cfg.TrainerConfig(
        steps_per_loop=steps_per_epoch,
        summary_interval=steps_per_epoch,
        checkpoint_interval=steps_per_epoch,
        train_steps=total_train_steps,
        validation_steps=IMAGENET_VAL_EXAMPLES // eval_batch_size,
        validation_interval=steps_per_epoch,
        optimizer_config=optimizer_config)

    return cfg.ExperimentConfig(
        task=task,
        trainer=trainer,
        restrictions=[
            'task.train_data.is_training != None',
            'task.validation_data.is_training != None',
        ])
Beispiel #15
0
def simclr_pretraining_imagenet() -> cfg.ExperimentConfig:
    """Builds the SimCLR pre-training experiment config on ImageNet.

    The schedule spans 500 epochs' worth of steps at batch size 4096, using
    the LARS optimizer: a 25-epoch linear warmup followed by cosine decay over
    the remaining 475 epochs.

    Returns:
      The assembled `cfg.ExperimentConfig`.
    """
    train_batch_size = 4096
    eval_batch_size = 4096
    steps_per_epoch = IMAGENET_TRAIN_EXAMPLES // train_batch_size

    # Model: ResNet-50 backbone with projection and supervised heads.
    model = SimCLRModel(
        mode=simclr_model.PRETRAIN,
        backbone_trainable=True,
        input_size=[224, 224, 3],
        backbone=backbones.Backbone(
            type='resnet', resnet=backbones.ResNet(model_id=50)),
        projection_head=ProjectionHead(
            proj_output_dim=128, num_proj_layers=3, ft_proj_idx=1),
        supervised_head=SupervisedHead(num_classes=1001),
        norm_activation=common.NormActivation(
            norm_momentum=0.9, norm_epsilon=1e-5, use_sync_bn=True))

    # Task: loss, evaluation and the training / validation input pipelines.
    loss = ContrastiveLoss()
    evaluation = Evaluation()
    train_data = DataConfig(
        parser=Parser(mode=simclr_model.PRETRAIN),
        decoder=Decoder(decode_label=True),
        input_path=os.path.join(IMAGENET_INPUT_PATH_BASE, 'train*'),
        is_training=True,
        global_batch_size=train_batch_size)
    validation_data = DataConfig(
        parser=Parser(mode=simclr_model.PRETRAIN),
        decoder=Decoder(decode_label=True),
        input_path=os.path.join(IMAGENET_INPUT_PATH_BASE, 'valid*'),
        is_training=False,
        global_batch_size=eval_batch_size)
    task = SimCLRPretrainTask(
        model=model,
        loss=loss,
        evaluation=evaluation,
        train_data=train_data,
        validation_data=validation_data)

    # Trainer: LARS with weight decay, cosine learning rate and warmup.
    optimizer_config = optimization.OptimizationConfig({
        'optimizer': {
            'type': 'lars',
            'lars': {
                'momentum': 0.9,
                'weight_decay_rate': 0.000001,
                'exclude_from_weight_decay': [
                    'batch_normalization', 'bias'
                ],
            },
        },
        'learning_rate': {
            'type': 'cosine',
            'cosine': {
                # 0.2 * BatchSize / 256
                'initial_learning_rate': 0.2 * train_batch_size / 256,
                # train_steps - warmup_steps
                'decay_steps': 475 * steps_per_epoch,
            },
        },
        'warmup': {
            'type': 'linear',
            'linear': {
                # 5% of total epochs
                'warmup_steps': 25 * steps_per_epoch,
            },
        },
    })
    trainer = cfg.TrainerConfig(
        steps_per_loop=steps_per_epoch,
        summary_interval=steps_per_epoch,
        checkpoint_interval=steps_per_epoch,
        train_steps=500 * steps_per_epoch,
        validation_steps=IMAGENET_VAL_EXAMPLES // eval_batch_size,
        validation_interval=steps_per_epoch,
        optimizer_config=optimizer_config)

    return cfg.ExperimentConfig(
        task=task,
        trainer=trainer,
        restrictions=[
            'task.train_data.is_training != None',
            'task.validation_data.is_training != None',
        ])