示例#1
0
@dataclasses.dataclass
class AutosegEdgeTPUTaskConfig(base_cfg.SemanticSegmentationTask):
    """The task config inherited from the base segmentation task.

    The fields below are declared on top of whatever
    `base_cfg.SemanticSegmentationTask` already provides.

    Attributes:
      model: Model configuration (`AutosegEdgeTPUModelConfig`).
      train_data: Data configuration created with `is_training=True`.
      validation_data: Data configuration created with `is_training=False`.
      losses: Loss configuration (`Losses`).
      init_checkpoint: Optional checkpoint path; `None` by default.
      init_checkpoint_modules: Either 'all' or 'backbone' (per the inline
        note on the field); defaults to 'backbone'.
      model_output_keys: Optional list of ints; defaults to a new empty list.
    """

    # Config-object defaults are produced by factories so that every task
    # config instance owns its own sub-configs instead of sharing one object
    # created at class-definition time. Dataclasses on Python 3.11+ also
    # reject unhashable instances used directly as field defaults.
    model: AutosegEdgeTPUModelConfig = dataclasses.field(
        default_factory=AutosegEdgeTPUModelConfig)
    train_data: base_cfg.DataConfig = dataclasses.field(
        default_factory=lambda: base_cfg.DataConfig(is_training=True))
    validation_data: base_cfg.DataConfig = dataclasses.field(
        default_factory=lambda: base_cfg.DataConfig(is_training=False))
    losses: Losses = dataclasses.field(default_factory=Losses)
    init_checkpoint: Optional[str] = None
    init_checkpoint_modules: str = 'backbone'  # all or backbone
    model_output_keys: Optional[List[int]] = dataclasses.field(
        default_factory=list)
示例#2
0
def autoseg_edgetpu_experiment_config(
        backbone_name: str,
        init_backbone: bool = True) -> cfg.ExperimentConfig:
    """Builds the experiment config for the searched semantic segmentation model.

    Args:
      backbone_name: Name of the backbone used for this model. It is also the
        key used to look up the pretrained checkpoint when `init_backbone`
        is set.
      init_backbone: Whether to initialize the backbone from a pretrained
        checkpoint.

    Returns:
      The assembled ExperimentConfig.
    """
    num_epochs = 300
    batch_train, batch_eval = 64, 32
    side = 512
    epoch_steps = ADE20K_TRAIN_EXAMPLES // batch_train
    total_steps = num_epochs * epoch_steps

    # Model: the backbone name, optional pretrained weights and the input
    # resolution are all set on the model params after construction.
    seg_model = AutosegEdgeTPUModelConfig(
        num_classes=32, input_size=[side, side, 3])
    seg_model.model_params.model_name = backbone_name
    if init_backbone:
        seg_model.model_params.model_weights_path = (
            BACKBONE_PRETRAINED_CHECKPOINT[backbone_name])
    seg_model.model_params.overrides.resolution = side

    # Input pipelines for the train and validation splits.
    train_split = base_cfg.DataConfig(
        input_path=os.path.join(ADE20K_INPUT_PATH_BASE, 'train-*'),
        output_size=[side, side],
        is_training=True,
        global_batch_size=batch_train,
        aug_scale_min=0.5,
        aug_scale_max=2.0)
    val_split = base_cfg.DataConfig(
        input_path=os.path.join(ADE20K_INPUT_PATH_BASE, 'val-*'),
        output_size=[side, side],
        is_training=False,
        resize_eval_groundtruth=True,
        drop_remainder=True,
        global_batch_size=batch_eval)
    task_config = AutosegEdgeTPUTaskConfig(
        model=seg_model,
        train_data=train_split,
        validation_data=val_split,
        evaluation=base_cfg.Evaluation(report_train_mean_iou=False))

    # Optimization: Nesterov SGD with EMA, a cosine learning-rate schedule
    # over the whole run and a linear warmup lasting five epochs.
    sgd_settings = {'nesterov': True, 'momentum': 0.9}
    ema_settings = {
        'average_decay': 0.9998,
        'trainable_weights_only': False,
    }
    cosine_settings = {
        'initial_learning_rate': 0.12,
        'decay_steps': total_steps,
    }
    warmup_settings = {
        'warmup_steps': 5 * epoch_steps,
        'warmup_learning_rate': 0,
    }
    optimizer = optimization.OptimizationConfig({
        'optimizer': {'type': 'sgd', 'sgd': sgd_settings},
        'ema': ema_settings,
        'learning_rate': {'type': 'cosine', 'cosine': cosine_settings},
        'warmup': {'type': 'linear', 'linear': warmup_settings},
    })

    # Trainer: loop, summary and validation cadence are one epoch each;
    # checkpoints are written every five epochs.
    trainer_config = cfg.TrainerConfig(
        steps_per_loop=epoch_steps,
        summary_interval=epoch_steps,
        checkpoint_interval=epoch_steps * 5,
        max_to_keep=10,
        train_steps=total_steps,
        validation_steps=ADE20K_VAL_EXAMPLES // batch_eval,
        validation_interval=epoch_steps,
        optimizer_config=optimizer)

    return cfg.ExperimentConfig(
        task=task_config,
        trainer=trainer_config,
        restrictions=[
            'task.train_data.is_training != None',
            'task.validation_data.is_training != None'
        ])
def seg_deeplabv3plus_ade20k_32(backbone: str,
                                init_backbone: bool = True
                                ) -> cfg.ExperimentConfig:
    """Semantic segmentation on ADE20K dataset with deeplabv3+.

    Args:
      backbone: Backbone identifier. It is passed as the MobileNetEdgeTPU
        `model_id` and used as the key into the per-backbone lookup tables
        (pretrained checkpoint, head level, low-level feature).
      init_backbone: Whether to point the backbone at its pretrained
        checkpoint; when False the checkpoint path is None.

    Returns:
      The assembled ExperimentConfig.
    """
    num_epochs = 200
    batch_train, batch_eval = 128, 32
    side = 512
    epoch_steps = ADE20K_TRAIN_EXAMPLES // batch_train
    total_steps = num_epochs * epoch_steps
    dilations = [5, 10, 15]

    if init_backbone:
        checkpoint = BACKBONE_PRETRAINED_CHECKPOINT[backbone]
    else:
        checkpoint = None

    # Model pieces: MobileNet-EdgeTPU backbone, ASPP decoder and a
    # deeplabv3plus feature-fusion head.
    backbone_config = Backbone(
        type='mobilenet_edgetpu',
        mobilenet_edgetpu=MobileNetEdgeTPU(
            model_id=backbone,
            pretrained_checkpoint_path=checkpoint,
            freeze_large_filters=500,
        ))
    head_level = BACKBONE_HEADPOINT[backbone]
    decoder_config = decoders.Decoder(
        type='aspp',
        aspp=decoders.ASPP(
            level=head_level,
            use_depthwise_convolution=True,
            dilation_rates=dilations,
            pool_kernel_size=[256, 256],
            num_filters=128,
            dropout_rate=0.3,
        ))
    head_config = base_cfg.SegmentationHead(
        level=head_level,
        num_convs=2,
        num_filters=256,
        use_depthwise_convolution=True,
        feature_fusion='deeplabv3plus',
        low_level=BACKBONE_LOWER_FEATURES[backbone],
        low_level_num_filters=48)
    norm_config = common.NormActivation(
        activation='relu',
        norm_momentum=0.99,
        norm_epsilon=2e-3,
        use_sync_bn=False)
    seg_model = base_cfg.SemanticSegmentationModel(
        # Only 32 semantic classes of ADE20K are used for train/evaluation;
        # the void (background) class is ignored in both.
        num_classes=32,
        input_size=[None, None, 3],
        backbone=backbone_config,
        decoder=decoder_config,
        head=head_config,
        norm_activation=norm_config)

    # Input pipelines for the train and validation splits.
    train_split = base_cfg.DataConfig(
        input_path=os.path.join(ADE20K_INPUT_PATH_BASE, 'train-*'),
        output_size=[side, side],
        is_training=True,
        global_batch_size=batch_train)
    val_split = base_cfg.DataConfig(
        input_path=os.path.join(ADE20K_INPUT_PATH_BASE, 'val-*'),
        output_size=[side, side],
        is_training=False,
        global_batch_size=batch_eval,
        resize_eval_groundtruth=True,
        drop_remainder=False)
    task_config = CustomSemanticSegmentationTaskConfig(
        model=seg_model,
        train_data=train_split,
        validation_data=val_split,
        evaluation=base_cfg.Evaluation(report_train_mean_iou=False),
    )

    # Optimization: Adam with polynomial decay to zero over the whole run
    # and a linear warmup lasting four epochs.
    polynomial_settings = {
        'initial_learning_rate': 0.0001,
        'decay_steps': total_steps,
        'end_learning_rate': 0.0,
        'power': 0.9,
    }
    warmup_settings = {
        'warmup_steps': 4 * epoch_steps,
        'warmup_learning_rate': 0,
    }
    optimizer = optimization.OptimizationConfig({
        'optimizer': {
            'type': 'adam',
        },
        'learning_rate': {
            'type': 'polynomial',
            'polynomial': polynomial_settings,
        },
        'warmup': {
            'type': 'linear',
            'linear': warmup_settings,
        },
    })

    # Trainer: loop, summary, checkpoint and validation cadence are all one
    # epoch.
    trainer_config = cfg.TrainerConfig(
        steps_per_loop=epoch_steps,
        summary_interval=epoch_steps,
        checkpoint_interval=epoch_steps,
        train_steps=total_steps,
        validation_steps=ADE20K_VAL_EXAMPLES // batch_eval,
        validation_interval=epoch_steps,
        optimizer_config=optimizer)

    return cfg.ExperimentConfig(
        task=task_config,
        trainer=trainer_config,
        restrictions=[
            'task.train_data.is_training != None',
            'task.validation_data.is_training != None'
        ])