def test_resnet_creation(self, model_id):
  """Checks that the factory-built ResNet matches a directly built one.

  Args:
    model_id: ResNet model id passed both to `backbones.ResNet` and to the
      backbone config given to `factory.build_backbone`.
  """
  momentum, epsilon = 0.99, 1e-5

  # Reference network constructed straight from the backbone class.
  direct_network = backbones.ResNet(
      model_id=model_id,
      se_ratio=0.0,
      norm_momentum=momentum,
      norm_epsilon=epsilon)

  # The same settings expressed as configs and routed through the factory.
  resnet_cfg = backbones_cfg.ResNet(model_id=model_id, se_ratio=0.0)
  norm_cfg = common_cfg.NormActivation(
      norm_momentum=momentum, norm_epsilon=epsilon, use_sync_bn=False)
  built_network = factory.build_backbone(
      input_specs=tf.keras.layers.InputSpec(shape=[None, None, None, 3]),
      backbone_config=backbones_cfg.Backbone(type='resnet', resnet=resnet_cfg),
      norm_activation_config=norm_cfg)

  self.assertEqual(direct_network.get_config(), built_network.get_config())
def test_validation_step(self):
  """Runs one validation step on mocked TFDS data and reduces its logs."""
  small_resnet = backbones.ResNet(model_id=10, bn_trainable=False)
  model_cfg = detr_cfg.Detr(
      input_size=[1333, 1333, 3],
      num_encoder_layers=1,
      num_decoder_layers=1,
      backbone=backbones.Backbone(type='resnet', resnet=small_resnet))
  eval_data_cfg = detr_cfg.DataConfig(
      tfds_name='coco/2017',
      tfds_split='validation',
      is_training=False,
      global_batch_size=2,
  )
  config = detr_cfg.DetrTask(
      model=model_cfg,
      losses=detr_cfg.Losses(class_offset=1),
      validation_data=eval_data_cfg)

  # Everything that touches the dataset runs while the TFDS mock is active.
  with tfds.testing.mock_data(as_dataset_fn=_as_dataset):
    task = detection.DetectionTask(config)
    model = task.build_model()
    metrics = task.build_metrics(training=False)
    batches = iter(task.build_inputs(config.validation_data))
    step_logs = task.validation_step(next(batches), model, metrics)
    aggregated = task.aggregate_logs(step_outputs=step_logs)
    task.reduce_aggregated_logs(aggregated)
class SimCLRModel(hyperparams.Config):
  """SimCLR model config.

  Sub-config defaults are produced through `default_factory` so that every
  instance owns fresh sub-config objects instead of sharing one object that
  was created when the class body was evaluated.
  """
  # NOTE(review): like `input_size`, the factories below rely on this class
  # being processed by `dataclasses.dataclass` (decorator not visible here).
  input_size: List[int] = dataclasses.field(default_factory=list)
  backbone: backbones.Backbone = dataclasses.field(
      default_factory=lambda: backbones.Backbone(
          type='resnet', resnet=backbones.ResNet()))
  projection_head: ProjectionHead = dataclasses.field(
      default_factory=lambda: ProjectionHead(
          proj_output_dim=128, num_proj_layers=3, ft_proj_idx=1))
  supervised_head: SupervisedHead = dataclasses.field(
      default_factory=lambda: SupervisedHead(num_classes=1001))
  norm_activation: common.NormActivation = dataclasses.field(
      default_factory=lambda: common.NormActivation(
          norm_momentum=0.9, norm_epsilon=1e-5, use_sync_bn=False))
  mode: str = simclr_model.PRETRAIN
  backbone_trainable: bool = True
class ImageClassificationModel(hyperparams.Config):
  """The model config.

  Sub-config defaults are produced through `default_factory` so that every
  instance owns fresh sub-config objects instead of sharing one object that
  was created when the class body was evaluated.
  """
  # NOTE(review): like `input_size`, the factories below rely on this class
  # being processed by `dataclasses.dataclass` (decorator not visible here).
  num_classes: int = 0
  input_size: List[int] = dataclasses.field(default_factory=list)
  backbone: backbones.Backbone = dataclasses.field(
      default_factory=lambda: backbones.Backbone(
          type='resnet', resnet=backbones.ResNet()))
  dropout_rate: float = 0.0
  norm_activation: common.NormActivation = dataclasses.field(
      default_factory=lambda: common.NormActivation(use_sync_bn=False))
  # Adds a BatchNormalization layer pre-GlobalAveragePooling in classification
  add_head_batch_norm: bool = False
  kernel_initializer: str = 'random_uniform'
class SemanticSegmentationModel(hyperparams.Config):
  """Semantic segmentation model config.

  Sub-config defaults are produced through `default_factory` so that every
  instance owns fresh sub-config objects instead of sharing one object that
  was created when the class body was evaluated.
  """
  # NOTE(review): like `input_size`, the factories below rely on this class
  # being processed by `dataclasses.dataclass` (decorator not visible here).
  num_classes: int = 0
  input_size: List[int] = dataclasses.field(default_factory=list)
  min_level: int = 3
  max_level: int = 6
  head: SegmentationHead = dataclasses.field(default_factory=SegmentationHead)
  backbone: backbones.Backbone = dataclasses.field(
      default_factory=lambda: backbones.Backbone(
          type='resnet', resnet=backbones.ResNet()))
  decoder: decoders.Decoder = dataclasses.field(
      default_factory=lambda: decoders.Decoder(type='identity'))
  mask_scoring_head: Optional[MaskScoringHead] = None
  norm_activation: common.NormActivation = dataclasses.field(
      default_factory=common.NormActivation)
class Detr(hyperparams.Config):
  """DETR model config.

  The default backbone is a 'resnet' backbone with `model_id=50` and
  `bn_trainable=False`. Sub-config defaults are produced through
  `default_factory` so that every instance owns fresh sub-config objects
  instead of sharing one object created when the class body was evaluated.
  """
  # NOTE(review): like `input_size`, the factories below rely on this class
  # being processed by `dataclasses.dataclass` (decorator not visible here).
  num_queries: int = 100
  hidden_size: int = 256
  num_classes: int = 91  # 0: background
  num_encoder_layers: int = 6
  num_decoder_layers: int = 6
  input_size: List[int] = dataclasses.field(default_factory=list)
  backbone: backbones.Backbone = dataclasses.field(
      default_factory=lambda: backbones.Backbone(
          type='resnet',
          resnet=backbones.ResNet(model_id=50, bn_trainable=False)))
  norm_activation: common.NormActivation = dataclasses.field(
      default_factory=common.NormActivation)
class RetinaNet(hyperparams.Config):
  """RetinaNet model config.

  Sub-config defaults are produced through `default_factory` so that every
  instance owns fresh sub-config objects instead of sharing one object that
  was created when the class body was evaluated.
  """
  # NOTE(review): like `input_size`, the factories below rely on this class
  # being processed by `dataclasses.dataclass` (decorator not visible here).
  num_classes: int = 0
  input_size: List[int] = dataclasses.field(default_factory=list)
  min_level: int = 3
  max_level: int = 7
  anchor: Anchor = dataclasses.field(default_factory=Anchor)
  backbone: backbones.Backbone = dataclasses.field(
      default_factory=lambda: backbones.Backbone(
          type='resnet', resnet=backbones.ResNet()))
  decoder: decoders.Decoder = dataclasses.field(
      default_factory=lambda: decoders.Decoder(
          type='fpn', fpn=decoders.FPN()))
  head: RetinaNetHead = dataclasses.field(default_factory=RetinaNetHead)
  detection_generator: DetectionGenerator = dataclasses.field(
      default_factory=DetectionGenerator)
  norm_activation: common.NormActivation = dataclasses.field(
      default_factory=common.NormActivation)
class PanopticDeeplab(hyperparams.Config):
  """Panoptic Deeplab model config.

  Sub-config defaults are produced through `default_factory` so that every
  instance owns fresh sub-config objects instead of sharing one object that
  was created when the class body was evaluated.
  """
  # NOTE(review): like `input_size`, the factories below rely on this class
  # being processed by `dataclasses.dataclass` (decorator not visible here).
  num_classes: int = 2
  input_size: List[int] = dataclasses.field(default_factory=list)
  min_level: int = 3
  max_level: int = 6
  norm_activation: common.NormActivation = dataclasses.field(
      default_factory=common.NormActivation)
  backbone: backbones.Backbone = dataclasses.field(
      default_factory=lambda: backbones.Backbone(
          type='resnet', resnet=backbones.ResNet()))
  decoder: decoders.Decoder = dataclasses.field(
      default_factory=lambda: decoders.Decoder(type='aspp'))
  semantic_head: SemanticHead = dataclasses.field(default_factory=SemanticHead)
  instance_head: InstanceHead = dataclasses.field(default_factory=InstanceHead)
  shared_decoder: bool = False
  generate_panoptic_masks: bool = True
  post_processor: PanopticDeeplabPostProcessor = dataclasses.field(
      default_factory=PanopticDeeplabPostProcessor)
class SimCLRMTModelConfig(hyperparams.Config):
  """Model config for multi-task SimCLR model.

  Sub-config defaults are produced through `default_factory` so that every
  instance owns fresh sub-config objects instead of sharing one object that
  was created when the class body was evaluated.
  """
  # NOTE(review): like `input_size`, the factories below rely on this class
  # being processed by `dataclasses.dataclass` (decorator not visible here).
  input_size: List[int] = dataclasses.field(default_factory=list)
  backbone: backbones.Backbone = dataclasses.field(
      default_factory=lambda: backbones.Backbone(
          type='resnet', resnet=backbones.ResNet()))
  backbone_trainable: bool = True
  projection_head: simclr_configs.ProjectionHead = dataclasses.field(
      default_factory=lambda: simclr_configs.ProjectionHead(
          proj_output_dim=128, num_proj_layers=3, ft_proj_idx=1))
  norm_activation: common.NormActivation = dataclasses.field(
      default_factory=lambda: common.NormActivation(
          norm_momentum=0.9, norm_epsilon=1e-5, use_sync_bn=False))
  # An immutable empty tuple is safe to share, so it stays a plain default.
  heads: Tuple[SimCLRMTHeadConfig, ...] = ()
  # L2 weight decay is used in the model, not in task.
  # Note that this can not be used together with lars optimizer.
  l2_weight_decay: float = 0.0
  init_checkpoint: str = ''
  # backbone_projection or backbone
  init_checkpoint_modules: str = 'backbone_projection'
def test_train_step(self):
  """Runs a single training step on mocked TFDS data."""
  small_resnet = backbones.ResNet(model_id=10, bn_trainable=False)
  model_cfg = detr_cfg.Detr(
      input_size=[1333, 1333, 3],
      num_encoder_layers=1,
      num_decoder_layers=1,
      num_classes=81,
      backbone=backbones.Backbone(type='resnet', resnet=small_resnet))
  train_data_cfg = coco.COCODataConfig(
      tfds_name='coco/2017',
      tfds_split='validation',
      is_training=True,
      global_batch_size=2,
  )
  config = detr_cfg.DetrTask(model=model_cfg, train_data=train_data_cfg)

  # Plain-dict optimizer settings; wrapped into a config object further down.
  optimizer_settings = {
      'type': 'detr_adamw',
      'detr_adamw': {
          'weight_decay_rate': 1e-4,
          'global_clipnorm': 0.1,
      },
  }
  learning_rate_settings = {
      'type': 'stepwise',
      'stepwise': {
          'boundaries': [120000],
          'values': [0.0001, 1.0e-05],
      },
  }

  # Everything that touches the dataset runs while the TFDS mock is active.
  with tfds.testing.mock_data(as_dataset_fn=_as_dataset):
    task = detection.DetectionTask(config)
    model = task.build_model()
    batches = iter(task.build_inputs(config.train_data))
    opt_cfg = optimization.OptimizationConfig({
        'optimizer': optimizer_settings,
        'learning_rate': learning_rate_settings,
    })
    optimizer = detection.DetectionTask.create_optimizer(opt_cfg)
    task.train_step(next(batches), model, optimizer)
class MaskRCNN(hyperparams.Config):
  """Mask R-CNN model config.

  Sub-config defaults are produced through `default_factory` so that every
  instance owns fresh sub-config objects instead of sharing one object that
  was created when the class body was evaluated.
  """
  # NOTE(review): like `input_size`, the factories below rely on this class
  # being processed by `dataclasses.dataclass` (decorator not visible here).
  num_classes: int = 0
  input_size: List[int] = dataclasses.field(default_factory=list)
  min_level: int = 2
  max_level: int = 6
  anchor: Anchor = dataclasses.field(default_factory=Anchor)
  include_mask: bool = True
  backbone: backbones.Backbone = dataclasses.field(
      default_factory=lambda: backbones.Backbone(
          type='resnet', resnet=backbones.ResNet()))
  decoder: decoders.Decoder = dataclasses.field(
      default_factory=lambda: decoders.Decoder(
          type='fpn', fpn=decoders.FPN()))
  rpn_head: RPNHead = dataclasses.field(default_factory=RPNHead)
  detection_head: DetectionHead = dataclasses.field(
      default_factory=DetectionHead)
  roi_generator: ROIGenerator = dataclasses.field(default_factory=ROIGenerator)
  roi_sampler: ROISampler = dataclasses.field(default_factory=ROISampler)
  roi_aligner: ROIAligner = dataclasses.field(default_factory=ROIAligner)
  detection_generator: DetectionGenerator = dataclasses.field(
      default_factory=DetectionGenerator)
  # The mask-related sub-configs are Optional but still default to instances.
  mask_head: Optional[MaskHead] = dataclasses.field(default_factory=MaskHead)
  mask_sampler: Optional[MaskSampler] = dataclasses.field(
      default_factory=MaskSampler)
  mask_roi_aligner: Optional[MaskROIAligner] = dataclasses.field(
      default_factory=MaskROIAligner)
  norm_activation: common.NormActivation = dataclasses.field(
      default_factory=lambda: common.NormActivation(
          norm_momentum=0.997, norm_epsilon=0.0001, use_sync_bn=True))
def simclr_finetuning_imagenet() -> cfg.ExperimentConfig:
  """Builds the experiment config for SimCLR fine-tuning on ImageNet.

  Returns:
    A `cfg.ExperimentConfig` combining a `SimCLRFinetuneTask` (resnet backbone
    with `model_id=50`, 224x224x3 inputs) with a trainer that runs for 60
    epochs' worth of steps using LARS and a cosine learning-rate schedule.
  """
  train_batch_size = 1024
  eval_batch_size = 1024
  steps_per_epoch = IMAGENET_TRAIN_EXAMPLES // train_batch_size
  total_train_steps = 60 * steps_per_epoch
  # Empty by default; the checkpoint to start from is supplied by overrides.
  pretrain_model_base = ''

  model = SimCLRModel(
      mode=simclr_model.FINETUNE,
      backbone_trainable=True,
      input_size=[224, 224, 3],
      backbone=backbones.Backbone(
          type='resnet', resnet=backbones.ResNet(model_id=50)),
      projection_head=ProjectionHead(
          proj_output_dim=128, num_proj_layers=3, ft_proj_idx=1),
      supervised_head=SupervisedHead(num_classes=1001, zero_init=True),
      norm_activation=common.NormActivation(
          norm_momentum=0.9, norm_epsilon=1e-5, use_sync_bn=False))

  task = SimCLRFinetuneTask(
      model=model,
      loss=ClassificationLosses(),
      evaluation=Evaluation(),
      train_data=DataConfig(
          parser=Parser(mode=simclr_model.FINETUNE),
          input_path=os.path.join(IMAGENET_INPUT_PATH_BASE, 'train*'),
          is_training=True,
          global_batch_size=train_batch_size),
      validation_data=DataConfig(
          parser=Parser(mode=simclr_model.FINETUNE),
          input_path=os.path.join(IMAGENET_INPUT_PATH_BASE, 'valid*'),
          is_training=False,
          global_batch_size=eval_batch_size),
      init_checkpoint=pretrain_model_base,
      # all, backbone_projection or backbone
      init_checkpoint_modules='backbone_projection')

  lars_settings = {
      'momentum': 0.9,
      'weight_decay_rate': 0.0,
      'exclude_from_weight_decay': ['batch_normalization', 'bias'],
  }
  cosine_settings = {
      # Initial rate scales with batch size: 0.01 * BatchSize / 512.
      'initial_learning_rate': 0.01 * train_batch_size / 512,
      'decay_steps': total_train_steps,
  }
  optimizer_config = optimization.OptimizationConfig({
      'optimizer': {'type': 'lars', 'lars': lars_settings},
      'learning_rate': {'type': 'cosine', 'cosine': cosine_settings},
  })

  trainer = cfg.TrainerConfig(
      steps_per_loop=steps_per_epoch,
      summary_interval=steps_per_epoch,
      checkpoint_interval=steps_per_epoch,
      train_steps=total_train_steps,
      validation_steps=IMAGENET_VAL_EXAMPLES // eval_batch_size,
      validation_interval=steps_per_epoch,
      optimizer_config=optimizer_config)

  return cfg.ExperimentConfig(
      task=task,
      trainer=trainer,
      restrictions=[
          'task.train_data.is_training != None',
          'task.validation_data.is_training != None',
      ])
def simclr_pretraining_imagenet() -> cfg.ExperimentConfig:
  """Builds the experiment config for SimCLR pre-training on ImageNet.

  Returns:
    A `cfg.ExperimentConfig` combining a `SimCLRPretrainTask` (resnet backbone
    with `model_id=50`, 224x224x3 inputs, contrastive loss) with a trainer
    that runs for 500 epochs' worth of steps using LARS, a linear warmup and
    a cosine learning-rate schedule.
  """
  train_batch_size = 4096
  eval_batch_size = 4096
  steps_per_epoch = IMAGENET_TRAIN_EXAMPLES // train_batch_size

  model = SimCLRModel(
      mode=simclr_model.PRETRAIN,
      backbone_trainable=True,
      input_size=[224, 224, 3],
      backbone=backbones.Backbone(
          type='resnet', resnet=backbones.ResNet(model_id=50)),
      projection_head=ProjectionHead(
          proj_output_dim=128, num_proj_layers=3, ft_proj_idx=1),
      supervised_head=SupervisedHead(num_classes=1001),
      norm_activation=common.NormActivation(
          norm_momentum=0.9, norm_epsilon=1e-5, use_sync_bn=True))

  task = SimCLRPretrainTask(
      model=model,
      loss=ContrastiveLoss(),
      evaluation=Evaluation(),
      train_data=DataConfig(
          parser=Parser(mode=simclr_model.PRETRAIN),
          decoder=Decoder(decode_label=True),
          input_path=os.path.join(IMAGENET_INPUT_PATH_BASE, 'train*'),
          is_training=True,
          global_batch_size=train_batch_size),
      validation_data=DataConfig(
          parser=Parser(mode=simclr_model.PRETRAIN),
          decoder=Decoder(decode_label=True),
          input_path=os.path.join(IMAGENET_INPUT_PATH_BASE, 'valid*'),
          is_training=False,
          global_batch_size=eval_batch_size),
  )

  lars_settings = {
      'momentum': 0.9,
      'weight_decay_rate': 0.000001,
      'exclude_from_weight_decay': ['batch_normalization', 'bias'],
  }
  cosine_settings = {
      # Initial rate scales with batch size: 0.2 * BatchSize / 256.
      'initial_learning_rate': 0.2 * train_batch_size / 256,
      # train_steps - warmup_steps
      'decay_steps': 475 * steps_per_epoch,
  }
  warmup_settings = {
      # 5% of total epochs
      'warmup_steps': 25 * steps_per_epoch,
  }
  optimizer_config = optimization.OptimizationConfig({
      'optimizer': {'type': 'lars', 'lars': lars_settings},
      'learning_rate': {'type': 'cosine', 'cosine': cosine_settings},
      'warmup': {'type': 'linear', 'linear': warmup_settings},
  })

  trainer = cfg.TrainerConfig(
      steps_per_loop=steps_per_epoch,
      summary_interval=steps_per_epoch,
      checkpoint_interval=steps_per_epoch,
      train_steps=500 * steps_per_epoch,
      validation_steps=IMAGENET_VAL_EXAMPLES // eval_batch_size,
      validation_interval=steps_per_epoch,
      optimizer_config=optimizer_config)

  return cfg.ExperimentConfig(
      task=task,
      trainer=trainer,
      restrictions=[
          'task.train_data.is_training != None',
          'task.validation_data.is_training != None',
      ])
def image_classification_imagenet_resnetrs() -> cfg.ExperimentConfig:
  """Builds the ImageNet classification experiment config for ResNet-RS.

  Returns:
    A `cfg.ExperimentConfig` combining an `ImageClassificationTask`
    (160x160x3 inputs, RandAugment on the training data) with a trainer that
    runs for 350 epochs' worth of steps using SGD with EMA, a linear warmup
    and a cosine learning-rate schedule.
  """
  train_batch_size = 4096
  eval_batch_size = 4096
  steps_per_epoch = IMAGENET_TRAIN_EXAMPLES // train_batch_size
  total_train_steps = 350 * steps_per_epoch

  resnet_rs = backbones.ResNet(
      model_id=50,
      stem_type='v1',
      resnetd_shortcut=True,
      replace_stem_max_pool=True,
      se_ratio=0.25,
      stochastic_depth_drop_rate=0.0)
  model = ImageClassificationModel(
      num_classes=1001,
      input_size=[160, 160, 3],
      backbone=backbones.Backbone(type='resnet', resnet=resnet_rs),
      dropout_rate=0.25,
      norm_activation=common.NormActivation(
          norm_momentum=0.0,
          norm_epsilon=1e-5,
          use_sync_bn=False,
          activation='swish'))

  task = ImageClassificationTask(
      model=model,
      losses=Losses(l2_weight_decay=4e-5, label_smoothing=0.1),
      train_data=DataConfig(
          input_path=os.path.join(IMAGENET_INPUT_PATH_BASE, 'train*'),
          is_training=True,
          global_batch_size=train_batch_size,
          aug_type=common.Augmentation(
              type='randaug', randaug=common.RandAugment(magnitude=10))),
      validation_data=DataConfig(
          input_path=os.path.join(IMAGENET_INPUT_PATH_BASE, 'valid*'),
          is_training=False,
          global_batch_size=eval_batch_size))

  optimizer_config = optimization.OptimizationConfig({
      'optimizer': {'type': 'sgd', 'sgd': {'momentum': 0.9}},
      'ema': {
          'average_decay': 0.9999,
          'trainable_weights_only': False,
      },
      'learning_rate': {
          'type': 'cosine',
          'cosine': {
              'initial_learning_rate': 1.6,
              'decay_steps': total_train_steps,
          },
      },
      'warmup': {
          'type': 'linear',
          'linear': {
              'warmup_steps': 5 * steps_per_epoch,
              'warmup_learning_rate': 0,
          },
      },
  })

  trainer = cfg.TrainerConfig(
      steps_per_loop=steps_per_epoch,
      summary_interval=steps_per_epoch,
      checkpoint_interval=steps_per_epoch,
      train_steps=total_train_steps,
      validation_steps=IMAGENET_VAL_EXAMPLES // eval_batch_size,
      validation_interval=steps_per_epoch,
      optimizer_config=optimizer_config)

  return cfg.ExperimentConfig(
      task=task,
      trainer=trainer,
      restrictions=[
          'task.train_data.is_training != None',
          'task.validation_data.is_training != None',
      ])
def image_classification_imagenet() -> cfg.ExperimentConfig:
  """Builds the ImageNet classification experiment config for ResNet.

  Returns:
    A `cfg.ExperimentConfig` with XLA enabled, combining an
    `ImageClassificationTask` (resnet backbone with `model_id=50`, 224x224x3
    inputs) with a trainer that runs for 90 epochs' worth of steps using SGD,
    a linear warmup and a stepwise learning-rate schedule.
  """
  train_batch_size = 4096
  eval_batch_size = 4096
  steps_per_epoch = IMAGENET_TRAIN_EXAMPLES // train_batch_size

  # The learning rate changes at these epochs; every rate is scaled by
  # train_batch_size / 256.
  lr_boundaries = [epoch * steps_per_epoch for epoch in (30, 60, 80)]
  lr_values = [
      base_rate * train_batch_size / 256
      for base_rate in (0.1, 0.01, 0.001, 0.0001)
  ]

  model = ImageClassificationModel(
      num_classes=1001,
      input_size=[224, 224, 3],
      backbone=backbones.Backbone(
          type='resnet', resnet=backbones.ResNet(model_id=50)),
      norm_activation=common.NormActivation(
          norm_momentum=0.9, norm_epsilon=1e-5, use_sync_bn=False))

  task = ImageClassificationTask(
      model=model,
      losses=Losses(l2_weight_decay=1e-4),
      train_data=DataConfig(
          input_path=os.path.join(IMAGENET_INPUT_PATH_BASE, 'train*'),
          is_training=True,
          global_batch_size=train_batch_size),
      validation_data=DataConfig(
          input_path=os.path.join(IMAGENET_INPUT_PATH_BASE, 'valid*'),
          is_training=False,
          global_batch_size=eval_batch_size))

  optimizer_config = optimization.OptimizationConfig({
      'optimizer': {'type': 'sgd', 'sgd': {'momentum': 0.9}},
      'learning_rate': {
          'type': 'stepwise',
          'stepwise': {'boundaries': lr_boundaries, 'values': lr_values},
      },
      'warmup': {
          'type': 'linear',
          'linear': {
              'warmup_steps': 5 * steps_per_epoch,
              'warmup_learning_rate': 0,
          },
      },
  })

  trainer = cfg.TrainerConfig(
      steps_per_loop=steps_per_epoch,
      summary_interval=steps_per_epoch,
      checkpoint_interval=steps_per_epoch,
      train_steps=90 * steps_per_epoch,
      validation_steps=IMAGENET_VAL_EXAMPLES // eval_batch_size,
      validation_interval=steps_per_epoch,
      optimizer_config=optimizer_config)

  return cfg.ExperimentConfig(
      runtime=cfg.RuntimeConfig(enable_xla=True),
      task=task,
      trainer=trainer,
      restrictions=[
          'task.train_data.is_training != None',
          'task.validation_data.is_training != None',
      ])
def seg_resnetfpn_pascal() -> cfg.ExperimentConfig:
  """Builds the Pascal VOC segmentation experiment config for ResNet-FPN.

  Returns:
    A `cfg.ExperimentConfig` combining a `SemanticSegmentationTask` (resnet
    backbone with `model_id=50`, FPN decoder, 512x512x3 inputs, 21 classes)
    with a trainer that runs for 450 epochs' worth of steps using SGD, a
    linear warmup and a polynomial learning-rate decay.
  """
  train_batch_size = 256
  eval_batch_size = 32
  steps_per_epoch = PASCAL_TRAIN_EXAMPLES // train_batch_size
  total_train_steps = 450 * steps_per_epoch

  model = SemanticSegmentationModel(
      num_classes=21,
      input_size=[512, 512, 3],
      min_level=3,
      max_level=7,
      backbone=backbones.Backbone(
          type='resnet', resnet=backbones.ResNet(model_id=50)),
      decoder=decoders.Decoder(type='fpn', fpn=decoders.FPN()),
      head=SegmentationHead(level=3, num_convs=3),
      norm_activation=common.NormActivation(
          activation='swish', use_sync_bn=True))

  task = SemanticSegmentationTask(
      model=model,
      losses=Losses(l2_weight_decay=1e-4),
      train_data=DataConfig(
          input_path=os.path.join(PASCAL_INPUT_PATH_BASE, 'train_aug*'),
          is_training=True,
          global_batch_size=train_batch_size,
          aug_scale_min=0.2,
          aug_scale_max=1.5),
      validation_data=DataConfig(
          input_path=os.path.join(PASCAL_INPUT_PATH_BASE, 'val*'),
          is_training=False,
          global_batch_size=eval_batch_size,
          resize_eval_groundtruth=False,
          groundtruth_padded_size=[512, 512],
          drop_remainder=False),
  )

  optimizer_config = optimization.OptimizationConfig({
      'optimizer': {'type': 'sgd', 'sgd': {'momentum': 0.9}},
      'learning_rate': {
          'type': 'polynomial',
          'polynomial': {
              'initial_learning_rate': 0.007,
              'decay_steps': total_train_steps,
              'end_learning_rate': 0.0,
              'power': 0.9,
          },
      },
      'warmup': {
          'type': 'linear',
          'linear': {
              'warmup_steps': 5 * steps_per_epoch,
              'warmup_learning_rate': 0,
          },
      },
  })

  trainer = cfg.TrainerConfig(
      steps_per_loop=steps_per_epoch,
      summary_interval=steps_per_epoch,
      checkpoint_interval=steps_per_epoch,
      train_steps=total_train_steps,
      validation_steps=PASCAL_VAL_EXAMPLES // eval_batch_size,
      validation_interval=steps_per_epoch,
      optimizer_config=optimizer_config)

  return cfg.ExperimentConfig(
      task=task,
      trainer=trainer,
      restrictions=[
          'task.train_data.is_training != None',
          'task.validation_data.is_training != None',
      ])