Ejemplo n.º 1
0
def image_classification_imagenet_vit_pretrain() -> cfg.ExperimentConfig:
  """Builds the ImageNet pretraining experiment for a `vit-b16` backbone.

  Training runs for 300 epochs at a global batch size of 4096 using the
  `adamw` optimizer, a cosine learning-rate schedule and a linear warmup.

  Returns:
    The assembled `cfg.ExperimentConfig`.
  """
  global_train_batch = 4096
  global_eval_batch = 4096
  epoch_steps = IMAGENET_TRAIN_EXAMPLES // global_train_batch
  # The cosine schedule decays over exactly the full training run.
  total_steps = 300 * epoch_steps

  vit_backbone = backbones.Backbone(
      type='vit',
      vit=backbones.VisionTransformer(
          model_name='vit-b16', representation_size=768))
  model = ImageClassificationModel(
      num_classes=1001,
      input_size=[224, 224, 3],
      backbone=vit_backbone)

  train_input = DataConfig(
      input_path=os.path.join(IMAGENET_INPUT_PATH_BASE, 'train*'),
      is_training=True,
      global_batch_size=global_train_batch)
  eval_input = DataConfig(
      input_path=os.path.join(IMAGENET_INPUT_PATH_BASE, 'valid*'),
      is_training=False,
      global_batch_size=global_eval_batch)

  task = ImageClassificationTask(
      model=model,
      losses=Losses(l2_weight_decay=0.0),
      train_data=train_input,
      validation_data=eval_input)

  optimizer_section = {
      'type': 'adamw',
      'adamw': {
          'weight_decay_rate': 0.3,
          'include_in_weight_decay': r'.*(kernel|weight):0$',
      },
  }
  learning_rate_section = {
      'type': 'cosine',
      'cosine': {
          'initial_learning_rate': 0.003,
          'decay_steps': total_steps,
      },
  }
  warmup_section = {
      'type': 'linear',
      'linear': {
          'warmup_steps': 10000,
          'warmup_learning_rate': 0,
      },
  }

  # Loop, summary, checkpoint and validation cadence are all one epoch.
  trainer = cfg.TrainerConfig(
      steps_per_loop=epoch_steps,
      summary_interval=epoch_steps,
      checkpoint_interval=epoch_steps,
      train_steps=total_steps,
      validation_steps=IMAGENET_VAL_EXAMPLES // global_eval_batch,
      validation_interval=epoch_steps,
      optimizer_config=optimization.OptimizationConfig({
          'optimizer': optimizer_section,
          'learning_rate': learning_rate_section,
          'warmup': warmup_section,
      }))

  return cfg.ExperimentConfig(
      task=task,
      trainer=trainer,
      restrictions=[
          'task.train_data.is_training != None',
          'task.validation_data.is_training != None',
      ])
Ejemplo n.º 2
0
@dataclasses.dataclass
class ImageClassificationModel(hyperparams.Config):
  """Model configuration for image classification.

  The class is declared as a dataclass so that the annotated attributes below
  become constructor keyword arguments and `dataclasses.field(...)` defaults
  are honoured.

  Nested config defaults are created through `default_factory` so each
  instance gets its own `Backbone` / `NormActivation` object instead of
  sharing one module-level instance; this also avoids the `ValueError` that
  Python 3.11+ dataclasses raise for unhashable (mutable) default values.
  """
  # Number of output classes of the classifier.
  num_classes: int = 0
  # Input size as a list of ints; empty by default.
  input_size: List[int] = dataclasses.field(default_factory=list)
  # Backbone selection; defaults to a 'vit' backbone with default settings.
  backbone: backbones.Backbone = dataclasses.field(
      default_factory=lambda: backbones.Backbone(
          type='vit', vit=backbones.VisionTransformer()))
  dropout_rate: float = 0.0
  # Normalization/activation settings; sync batch norm is off by default.
  norm_activation: common.NormActivation = dataclasses.field(
      default_factory=lambda: common.NormActivation(use_sync_bn=False))
  # Adds a BatchNormalization layer pre-GlobalAveragePooling in classification
  add_head_batch_norm: bool = False
def image_classification_imagenet_vit_finetune() -> cfg.ExperimentConfig:
    """Builds the ImageNet fine-tuning experiment for a `vit-b16` backbone.

    Fine-tuning uses 384x384x3 inputs, a global batch size of 512 and runs
    for 20000 steps with the `sgd` optimizer (momentum 0.9) under a cosine
    learning-rate schedule.

    Returns:
        The assembled `cfg.ExperimentConfig`.
    """
    global_train_batch = 512
    global_eval_batch = 512
    epoch_steps = IMAGENET_TRAIN_EXAMPLES // global_train_batch
    # The cosine schedule decays over exactly the full training run.
    total_steps = 20000

    model = ImageClassificationModel(
        num_classes=1001,
        input_size=[384, 384, 3],
        backbone=backbones.Backbone(
            type='vit',
            vit=backbones.VisionTransformer(model_name='vit-b16')))

    train_input = DataConfig(
        input_path=os.path.join(IMAGENET_INPUT_PATH_BASE, 'train*'),
        is_training=True,
        global_batch_size=global_train_batch)
    eval_input = DataConfig(
        input_path=os.path.join(IMAGENET_INPUT_PATH_BASE, 'valid*'),
        is_training=False,
        global_batch_size=global_eval_batch)

    task = ImageClassificationTask(
        model=model,
        losses=Losses(l2_weight_decay=0.0),
        train_data=train_input,
        validation_data=eval_input)

    optimizer_section = {
        'type': 'sgd',
        'sgd': {
            'momentum': 0.9,
            'global_clipnorm': 1.0,
        },
    }
    learning_rate_section = {
        'type': 'cosine',
        'cosine': {
            'initial_learning_rate': 0.003,
            'decay_steps': total_steps,
        },
    }

    # Loop, summary, checkpoint and validation cadence are all one epoch.
    trainer = cfg.TrainerConfig(
        steps_per_loop=epoch_steps,
        summary_interval=epoch_steps,
        checkpoint_interval=epoch_steps,
        train_steps=total_steps,
        validation_steps=IMAGENET_VAL_EXAMPLES // global_eval_batch,
        validation_interval=epoch_steps,
        optimizer_config=optimization.OptimizationConfig({
            'optimizer': optimizer_section,
            'learning_rate': learning_rate_section,
        }))

    return cfg.ExperimentConfig(
        task=task,
        trainer=trainer,
        restrictions=[
            'task.train_data.is_training != None',
            'task.validation_data.is_training != None',
        ])
def image_classification_imagenet_deit_pretrain() -> cfg.ExperimentConfig:
    """Builds the DeiT-style ImageNet pretraining experiment (`vit-b16`).

    On top of the `vit-b16` backbone this config enables an initial
    stochastic depth rate of 0.1, `randaug` augmentation (excluding
    'Cutout'), mixup/cutmix and label smoothing of 0.1. Training runs for
    300 epochs with the `adamw` optimizer, a cosine learning-rate schedule
    and a 5-epoch linear warmup.

    Returns:
        The assembled `cfg.ExperimentConfig`.
    """
    # Originally 1024, but 4096 is better for TPU v3-32.
    global_train_batch = 4096
    global_eval_batch = 4096
    class_count = 1001
    smoothing = 0.1
    epoch_steps = IMAGENET_TRAIN_EXAMPLES // global_train_batch
    # The cosine schedule decays over exactly the full training run.
    total_steps = 300 * epoch_steps

    encoder = backbones.Transformer(
        dropout_rate=0.0, attention_dropout_rate=0.0)
    vit_settings = backbones.VisionTransformer(
        model_name='vit-b16',
        representation_size=768,
        init_stochastic_depth_rate=0.1,
        original_init=False,
        transformer=encoder)
    model = ImageClassificationModel(
        num_classes=class_count,
        input_size=[224, 224, 3],
        kernel_initializer='zeros',
        backbone=backbones.Backbone(type='vit', vit=vit_settings))

    losses = Losses(
        l2_weight_decay=0.0,
        label_smoothing=smoothing,
        one_hot=False,
        soft_labels=True)

    rand_augment = common.RandAugment(magnitude=9, exclude_ops=['Cutout'])
    train_input = DataConfig(
        input_path=os.path.join(IMAGENET_INPUT_PATH_BASE, 'train*'),
        is_training=True,
        global_batch_size=global_train_batch,
        aug_type=common.Augmentation(type='randaug', randaug=rand_augment),
        mixup_and_cutmix=common.MixupAndCutmix(label_smoothing=smoothing))
    eval_input = DataConfig(
        input_path=os.path.join(IMAGENET_INPUT_PATH_BASE, 'valid*'),
        is_training=False,
        global_batch_size=global_eval_batch)

    task = ImageClassificationTask(
        model=model,
        losses=losses,
        train_data=train_input,
        validation_data=eval_input)

    optimizer_section = {
        'type': 'adamw',
        'adamw': {
            'weight_decay_rate': 0.05,
            'include_in_weight_decay': r'.*(kernel|weight):0$',
            'gradient_clip_norm': 0.0,
        },
    }
    learning_rate_section = {
        'type': 'cosine',
        'cosine': {
            # Base rate of 0.0005 per 512 examples, scaled by batch size.
            'initial_learning_rate': 0.0005 * global_train_batch / 512,
            'decay_steps': total_steps,
        },
    }
    warmup_section = {
        'type': 'linear',
        'linear': {
            'warmup_steps': 5 * epoch_steps,
            'warmup_learning_rate': 0,
        },
    }

    # Loop, summary, checkpoint and validation cadence are all one epoch.
    trainer = cfg.TrainerConfig(
        steps_per_loop=epoch_steps,
        summary_interval=epoch_steps,
        checkpoint_interval=epoch_steps,
        train_steps=total_steps,
        validation_steps=IMAGENET_VAL_EXAMPLES // global_eval_batch,
        validation_interval=epoch_steps,
        optimizer_config=optimization.OptimizationConfig({
            'optimizer': optimizer_section,
            'learning_rate': learning_rate_section,
            'warmup': warmup_section,
        }))

    return cfg.ExperimentConfig(
        task=task,
        trainer=trainer,
        restrictions=[
            'task.train_data.is_training != None',
            'task.validation_data.is_training != None',
        ])