def test_mobilenet_creation(self, model_id, filter_size_scale):
  """Checks that a directly-built MobileNet matches the factory-built one."""
  direct = backbones.MobileNet(
      model_id=model_id,
      filter_size_scale=filter_size_scale,
      norm_momentum=0.99,
      norm_epsilon=1e-5)

  # Build the same backbone through the config-driven factory path.
  backbone_cfg = backbones_cfg.Backbone(
      type='mobilenet',
      mobilenet=backbones_cfg.MobileNet(
          model_id=model_id, filter_size_scale=filter_size_scale))
  norm_cfg = common_cfg.NormActivation(
      norm_momentum=0.99, norm_epsilon=1e-5, use_sync_bn=False)
  from_factory = factory.build_backbone(
      input_specs=tf.keras.layers.InputSpec(shape=[None, None, None, 3]),
      backbone_config=backbone_cfg,
      norm_activation_config=norm_cfg)

  # Both construction paths must produce identical Keras configs.
  self.assertEqual(direct.get_config(), from_factory.get_config())
def test_spinenet_creation(self, model_id):
  """Checks that a directly-built SpineNet matches the factory-built one."""
  input_size = 128
  min_level = 3
  max_level = 7
  # NOTE(review): `model_id` is only used for the factory config below; the
  # directly-built SpineNet relies on its defaults, so the configs can only
  # agree when `model_id` matches those defaults — confirm the parameterized
  # values account for this.
  direct = backbones.SpineNet(
      input_specs=tf.keras.layers.InputSpec(
          shape=[None, input_size, input_size, 3]),
      min_level=min_level,
      max_level=max_level,
      norm_momentum=0.99,
      norm_epsilon=1e-5)

  backbone_cfg = backbones_cfg.Backbone(
      type='spinenet', spinenet=backbones_cfg.SpineNet(model_id=model_id))
  norm_cfg = common_cfg.NormActivation(
      norm_momentum=0.99, norm_epsilon=1e-5, use_sync_bn=False)
  from_factory = factory.build_backbone(
      input_specs=tf.keras.layers.InputSpec(
          shape=[None, input_size, input_size, 3]),
      backbone_config=backbone_cfg,
      norm_activation_config=norm_cfg)

  # Both construction paths must produce identical Keras configs.
  self.assertEqual(direct.get_config(), from_factory.get_config())
def test_efficientnet_creation(self, model_id, se_ratio):
  """Checks that a directly-built EfficientNet matches the factory-built one."""
  direct = backbones.EfficientNet(
      model_id=model_id,
      se_ratio=se_ratio,
      norm_momentum=0.99,
      norm_epsilon=1e-5)

  # Build the same backbone through the config-driven factory path.
  backbone_cfg = backbones_cfg.Backbone(
      type='efficientnet',
      efficientnet=backbones_cfg.EfficientNet(
          model_id=model_id, se_ratio=se_ratio))
  norm_cfg = common_cfg.NormActivation(
      norm_momentum=0.99, norm_epsilon=1e-5, use_sync_bn=False)
  from_factory = factory.build_backbone(
      input_specs=tf.keras.layers.InputSpec(shape=[None, None, None, 3]),
      backbone_config=backbone_cfg,
      norm_activation_config=norm_cfg)

  # Both construction paths must produce identical Keras configs.
  self.assertEqual(direct.get_config(), from_factory.get_config())
def test_resnet_creation(self, model_id):
  """Checks that a directly-built ResNet matches the factory-built one."""
  direct = backbones.ResNet(
      model_id=model_id,
      se_ratio=0.0,
      norm_momentum=0.99,
      norm_epsilon=1e-5)

  # This factory variant takes a full model config rather than separate
  # backbone/norm-activation configs.
  backbone_cfg = backbones_cfg.Backbone(
      type='resnet',
      resnet=backbones_cfg.ResNet(model_id=model_id, se_ratio=0.0))
  norm_cfg = common_cfg.NormActivation(
      norm_momentum=0.99, norm_epsilon=1e-5, use_sync_bn=False)
  model_cfg = retinanet_cfg.RetinaNet(
      backbone=backbone_cfg, norm_activation=norm_cfg)
  from_factory = factory.build_backbone(
      input_specs=tf.keras.layers.InputSpec(shape=[None, None, None, 3]),
      model_config=model_cfg)

  # Both construction paths must produce identical Keras configs.
  self.assertEqual(direct.get_config(), from_factory.get_config())
class VideoClassificationModel(hyperparams.Config):
  """Video classification model config."""
  # 3D backbone selection; defaults to ResNet-3D-50.
  backbone: backbones_3d.Backbone3D = backbones_3d.Backbone3D(
      type='resnet_3d', resnet_3d=backbones_3d.ResNet3D50())
  # Normalization/activation settings shared across the model.
  norm_activation: common.NormActivation = common.NormActivation()
  # Dropout rate applied before the classification head.
  dropout_rate: float = 0.2
  # Whether to insert a batch-norm layer in the classification head.
  add_head_batch_norm: bool = False
class AssembleNetModel(video_classification.VideoClassificationModel):
  """The AssembleNet model config."""
  model_type: str = 'assemblenet'
  # Overrides the parent's backbone with the AssembleNet-specific Backbone3D.
  backbone: Backbone3D = Backbone3D()
  norm_activation: common.NormActivation = common.NormActivation(
      norm_momentum=0.99, norm_epsilon=1e-5, use_sync_bn=True)
  # NOTE(review): field name is misspelled ("preditions"); kept as-is because
  # renaming would break existing configs/overrides that reference it.
  max_pool_preditions: bool = False
def testBuildCenterNet(self):
  """Builds a CenterNet model end-to-end and checks raw output shapes."""
  backbone = hourglass.build_hourglass(
      input_specs=tf.keras.layers.InputSpec(shape=[None, 512, 512, 3]),
      backbone_config=backbones.Backbone(type='hourglass'),
      norm_activation_config=common.NormActivation(use_sync_bn=True))

  # Heads: class-center heatmaps plus per-center offset and size regressions.
  head_channels = {'ct_heatmaps': 90, 'ct_offset': 2, 'ct_size': 2}
  head = centernet_head.CenterNetHead(
      task_outputs=head_channels,
      input_specs=backbone.output_specs,
      input_levels=['2_0', '2'])

  model = centernet_model.CenterNetModel(
      backbone=backbone,
      head=head,
      detection_generator=detection_generator.CenterNetDetectionGenerator())

  outputs = model(tf.zeros((5, 512, 512, 3)))
  raw = outputs['raw_output']

  # One output list per task, one entry per input level.
  self.assertLen(raw, 3)
  for task_name in ('ct_heatmaps', 'ct_offset', 'ct_size'):
    self.assertLen(raw[task_name], 2)
  self.assertEqual(raw['ct_heatmaps'][0].shape, (5, 128, 128, 90))
  self.assertEqual(raw['ct_offset'][0].shape, (5, 128, 128, 2))
  self.assertEqual(raw['ct_size'][0].shape, (5, 128, 128, 2))
class VideoClassificationModel(hyperparams.Config):
  """Video classification model config."""
  model_type: str = 'video_classification'
  # 3D backbone selection; defaults to ResNet-3D-50.
  backbone: backbones_3d.Backbone3D = backbones_3d.Backbone3D(
      type='resnet_3d', resnet_3d=backbones_3d.ResNet3D50())
  # Normalization/activation settings shared across the model.
  norm_activation: common.NormActivation = common.NormActivation()
  # Dropout rate applied before the classification head.
  dropout_rate: float = 0.2
  # Whether the head consumes aggregated multi-endpoint features.
  aggregate_endpoints: bool = False
class VideoSSLModel(VideoClassificationModel):
  """Self-supervised video model config (extends video classification)."""
  # Whether to L2-normalize features before the projection head.
  normalize_feature: bool = False
  # Projection MLP: hidden width, number of hidden layers, output dim.
  hidden_dim: int = 2048
  hidden_layer_num: int = 3
  projection_dim: int = 128
  # Norm/activation settings for the projection head's hidden layers.
  hidden_norm_activation: common.NormActivation = common.NormActivation(
      use_sync_bn=False, norm_momentum=0.997, norm_epsilon=1.0e-05)
def seg_unet3d_test() -> cfg.ExperimentConfig:
  """Image segmentation on a dummy dataset with 3D UNet for testing purpose."""
  train_batch_size = 2
  eval_batch_size = 2
  steps_per_epoch = 10

  # Tiny volumes and a 2-level UNet keep the test fast.
  model = SemanticSegmentationModel3D(
      num_classes=2,
      input_size=[32, 32, 32],
      num_channels=2,
      backbone=backbones.Backbone(
          type='unet_3d', unet_3d=backbones.UNet3D(model_id=2)),
      decoder=decoders.Decoder(
          type='unet_3d_decoder',
          unet_3d_decoder=decoders.UNet3DDecoder(model_id=2)),
      head=SegmentationHead3D(num_convs=0, num_classes=2),
      norm_activation=common.NormActivation(
          activation='relu', use_sync_bn=False))

  task = SemanticSegmentation3DTask(
      model=model,
      train_data=DataConfig(
          input_path='train.tfrecord',
          num_classes=2,
          input_size=[32, 32, 32],
          num_channels=2,
          is_training=True,
          global_batch_size=train_batch_size),
      validation_data=DataConfig(
          input_path='val.tfrecord',
          num_classes=2,
          input_size=[32, 32, 32],
          num_channels=2,
          is_training=False,
          global_batch_size=eval_batch_size),
      losses=Losses(loss_type='adaptive'))

  trainer = cfg.TrainerConfig(
      steps_per_loop=steps_per_epoch,
      summary_interval=steps_per_epoch,
      checkpoint_interval=steps_per_epoch,
      train_steps=10,
      validation_steps=10,
      validation_interval=steps_per_epoch,
      optimizer_config=optimization.OptimizationConfig({
          'optimizer': {
              'type': 'sgd',
          },
          'learning_rate': {
              'type': 'constant',
              'constant': {
                  'learning_rate': 0.000001
              }
          }
      }))

  return cfg.ExperimentConfig(
      task=task,
      trainer=trainer,
      restrictions=[
          'task.train_data.is_training != None',
          'task.validation_data.is_training != None'
      ])
def test_hourglass(self):
  """Builds an hourglass backbone and checks its two output endpoints."""
  net = hourglass.build_hourglass(
      input_specs=tf.keras.layers.InputSpec(shape=[None, 512, 512, 3]),
      backbone_config=backbones.Backbone(type='hourglass'),
      norm_activation_config=common.NormActivation(use_sync_bn=True))

  batch = np.zeros((2, 512, 512, 3), dtype=np.float32)
  endpoints = net(batch)

  # Both endpoints are at 1/4 input resolution with 256 channels.
  expected_shape = (2, 128, 128, 256)
  self.assertEqual(endpoints['2_0'].shape, expected_shape)
  self.assertEqual(endpoints['2'].shape, expected_shape)
class MultiHeadModel(hyperparams.Config):
  """Multi head multi task model config, similar to other models but with
  input, backbone, activation and weight decay shared."""
  # Input image size, e.g. [height, width]; empty means unspecified.
  input_size: List[int] = dataclasses.field(default_factory=list)
  # Shared backbone; defaults to ResNet.
  backbone: backbones.Backbone = backbones.Backbone(
      type='resnet', resnet=backbones.ResNet())
  norm_activation: common.NormActivation = common.NormActivation()
  # One sub-model config per task head.
  heads: List[Submodel] = dataclasses.field(default_factory=list)
  # Shared L2 weight decay applied at the model level.
  l2_weight_decay: float = 0.0
class ImageClassificationModel(hyperparams.Config):
  """Image classification model config with a Darknet backbone."""
  num_classes: int = 0
  # Input image size, e.g. [height, width]; empty means unspecified.
  input_size: List[int] = dataclasses.field(default_factory=list)
  backbone: backbones.Backbone = backbones.Backbone(
      type='darknet', darknet=backbones.Darknet())
  dropout_rate: float = 0.0
  norm_activation: common.NormActivation = common.NormActivation()
  # Adds a Batch Normalization layer pre-GlobalAveragePooling in classification.
  add_head_batch_norm: bool = False
class YoloModel(hyperparams.Config):
  """YOLO model config."""
  num_classes: int = 0
  # Input image size, e.g. [height, width]; empty means unspecified.
  input_size: List[int] = dataclasses.field(default_factory=list)
  min_level: int = 3  # only for FPN or NASFPN
  max_level: int = 6  # only for FPN or NASFPN
  head: hyperparams.Config = YoloHead()
  backbone: backbones.Backbone = backbones.Backbone(
      type='resnet', resnet=backbones.ResNet())
  # 'identity' passes backbone features through unchanged.
  decoder: decoders.Decoder = decoders.Decoder(type='identity')
  norm_activation: common.NormActivation = common.NormActivation()
class MovinetModel(video_classification.VideoClassificationModel):
  """The MoViNet model config."""
  model_type: str = 'movinet'
  # Overrides the parent's backbone with the MoViNet-specific Backbone3D.
  backbone: Backbone3D = Backbone3D()
  norm_activation: common.NormActivation = common.NormActivation(
      activation='swish', norm_momentum=0.99, norm_epsilon=1e-3,
      use_sync_bn=True)
  # Whether the model also returns its internal states (for streaming mode —
  # TODO confirm against the MoViNet model implementation).
  output_states: bool = False
class ImageClassificationModel(hyperparams.Config):
  """Image classification model config with a Vision Transformer backbone."""
  num_classes: int = 0
  # Input image size, e.g. [height, width]; empty means unspecified.
  input_size: List[int] = dataclasses.field(default_factory=list)
  backbone: backbones.Backbone = backbones.Backbone(
      type='vit', vit=backbones.VisionTransformer())
  dropout_rate: float = 0.0
  norm_activation: common.NormActivation = common.NormActivation(
      use_sync_bn=False)
  # Adds a BatchNormalization layer pre-GlobalAveragePooling in classification
  add_head_batch_norm: bool = False
class SemanticSegmentationModel(hyperparams.Config):
  """Semantic segmentation model config."""
  num_classes: int = 0
  # Input image size, e.g. [height, width]; empty means unspecified.
  input_size: List[int] = dataclasses.field(default_factory=list)
  # Feature pyramid level range consumed by the head.
  min_level: int = 3
  max_level: int = 6
  head: SegmentationHead = SegmentationHead()
  backbone: backbones.Backbone = backbones.Backbone(
      type='resnet', resnet=backbones.ResNet())
  # 'identity' passes backbone features through unchanged.
  decoder: decoders.Decoder = decoders.Decoder(type='identity')
  norm_activation: common.NormActivation = common.NormActivation()
class BASNetModel(hyperparams.Config):
  """BASNet model config."""
  num_classes: int = 0
  # Input image size, e.g. [height, width]; empty means unspecified.
  input_size: List[int] = dataclasses.field(default_factory=list)
  # Unused fields kept from the template this config was derived from:
  #min_level: int = 3
  #max_level: int = 6
  #head: BASNetHead = BASNetHead()
  # BASNet encoder/decoder pair.
  backbone: backbones.Backbone = backbones.Backbone(
      type='basnet_en', basnet_en=backbones.BASNet_En())
  decoder: decoders.Decoder = decoders.Decoder(type='basnet_de')
  norm_activation: common.NormActivation = common.NormActivation()
class ImageClassificationModel(hyperparams.Config):
  """Image classification model config."""
  num_classes: int = 0
  # Default input resolution is 224x224.
  input_size: List[int] = dataclasses.field(
      default_factory=lambda: [224, 224])
  backbone: backbones.Backbone = backbones.Backbone(
      type='darknet', darknet=backbones.Darknet())
  dropout_rate: float = 0.0
  norm_activation: common.NormActivation = common.NormActivation()
  # Adds a Batch Normalization layer pre-GlobalAveragePooling in classification.
  add_head_batch_norm: bool = False
  # Keras initializer name used for the head kernel.
  kernel_initializer: str = 'VarianceScaling'
class CenterNetModel(hyperparams.Config):
  """Config for centernet model."""
  num_classes: int = 90
  # Maximum number of detected instances per image.
  max_num_instances: int = 128
  # Input image size, e.g. [height, width]; empty means unspecified.
  input_size: List[int] = dataclasses.field(default_factory=list)
  # Hourglass-52 is the default CenterNet backbone.
  backbone: backbones.Backbone = backbones.Backbone(
      type='hourglass', hourglass=backbones.Hourglass(model_id=52))
  head: CenterNetHead = CenterNetHead()
  # pylint: disable=line-too-long
  detection_generator: CenterNetDetectionGenerator = CenterNetDetectionGenerator()
  # NOTE(review): norm_momentum=0.1 is unusually low for TF batch norm
  # (looks like a torch-style momentum) — confirm this is intentional.
  norm_activation: common.NormActivation = common.NormActivation(
      norm_momentum=0.1, norm_epsilon=1e-5, use_sync_bn=True)
class DbofModel(hyperparams.Config):
  """Deep-bag-of-frames (DBoF) model config."""
  # Sizes of the clustering and hidden projection layers.
  cluster_size: int = 3000
  hidden_size: int = 2000
  add_batch_norm: bool = True
  # Whether frames are sampled randomly (vs. uniformly).
  sample_random_frames: bool = True
  use_context_gate_cluster_layer: bool = False
  # 0 means no bottleneck in the context-gate cluster layer.
  context_gate_cluster_bottleneck_size: int = 0
  # How frame features are pooled over time.
  pooling_method: str = 'average'
  # Name of the aggregation classifier model.
  yt8m_agg_classifier_model: str = 'MoeModel'
  agg_model: hyperparams.Config = MoeModel()
  norm_activation: common.NormActivation = common.NormActivation(
      activation='relu', use_sync_bn=False)
class SimCLRModel(hyperparams.Config):
  """SimCLR model config."""
  # Input image size, e.g. [height, width]; empty means unspecified.
  input_size: List[int] = dataclasses.field(default_factory=list)
  backbone: backbones.Backbone = backbones.Backbone(
      type='resnet', resnet=backbones.ResNet())
  # Projection MLP used for the contrastive loss.
  projection_head: ProjectionHead = ProjectionHead(
      proj_output_dim=128, num_proj_layers=3, ft_proj_idx=1)
  # Linear classifier head for supervised evaluation/fine-tuning.
  supervised_head: SupervisedHead = SupervisedHead(num_classes=1001)
  norm_activation: common.NormActivation = common.NormActivation(
      norm_momentum=0.9, norm_epsilon=1e-5, use_sync_bn=False)
  # Either pretrain or fine-tune mode, per simclr_model constants.
  mode: str = simclr_model.PRETRAIN
  # Whether the backbone weights are trainable.
  backbone_trainable: bool = True
class RetinaNet(hyperparams.Config):
  """RetinaNet model config."""
  num_classes: int = 0
  # Input image size, e.g. [height, width]; empty means unspecified.
  input_size: List[int] = dataclasses.field(default_factory=list)
  # Feature pyramid level range.
  min_level: int = 3
  max_level: int = 7
  anchor: Anchor = Anchor()
  backbone: backbones.Backbone = backbones.Backbone(
      type='resnet', resnet=backbones.ResNet())
  decoder: decoders.Decoder = decoders.Decoder(type='fpn', fpn=decoders.FPN())
  head: RetinaNetHead = RetinaNetHead()
  detection_generator: DetectionGenerator = DetectionGenerator()
  norm_activation: common.NormActivation = common.NormActivation()
class SemanticSegmentationModel3D(hyperparams.Config):
  """Semantic segmentation model config."""
  num_classes: int = 0
  # Number of channels in the input volume.
  num_channels: int = 1
  # Input volume size, e.g. [depth, height, width]; empty means unspecified.
  input_size: List[int] = dataclasses.field(default_factory=list)
  min_level: int = 3
  max_level: int = 6
  head: SegmentationHead3D = SegmentationHead3D()
  # 3D UNet encoder/decoder pair.
  backbone: backbones.Backbone = backbones.Backbone(
      type='unet_3d', unet_3d=backbones.UNet3D())
  decoder: decoders.Decoder = decoders.Decoder(
      type='unet_3d_decoder', unet_3d_decoder=decoders.UNet3DDecoder())
  norm_activation: common.NormActivation = common.NormActivation()
class SimCLRMTModelConfig(hyperparams.Config):
  """Model config for multi-task SimCLR model."""
  # Input image size, e.g. [height, width]; empty means unspecified.
  input_size: List[int] = dataclasses.field(default_factory=list)
  backbone: backbones.Backbone = backbones.Backbone(
      type='resnet', resnet=backbones.ResNet())
  # Whether the shared backbone weights are trainable.
  backbone_trainable: bool = True
  # Projection MLP used for the contrastive loss.
  projection_head: simclr_configs.ProjectionHead = simclr_configs.ProjectionHead(
      proj_output_dim=128, num_proj_layers=3, ft_proj_idx=1)
  norm_activation: common.NormActivation = common.NormActivation(
      norm_momentum=0.9, norm_epsilon=1e-5, use_sync_bn=False)
  # One head config per task.
  heads: Tuple[SimCLRMTHeadConfig, ...] = ()
  # L2 weight decay is used in the model, not in task.
  # Note that this can not be used together with lars optimizer.
  l2_weight_decay: float = 0.0
def video_classification_ucf101() -> cfg.ExperimentConfig:
  """Video classification on UCF-101 with resnet.

  Returns:
    An ExperimentConfig wiring the UCF-101 datasets, a ResNet-3D-50 model,
    losses and trainer settings together.
  """
  train_dataset = DataConfig(
      name='ucf101',
      num_classes=101,
      is_training=True,
      split='train',
      drop_remainder=True,
      num_examples=9537,
      temporal_stride=2,
      feature_shape=(32, 224, 224, 3))
  train_dataset.tfds_name = 'ucf101'
  train_dataset.tfds_split = 'train'

  # Fixed: the validation dataset was built with is_training=True even though
  # it reads the 'test' split; evaluation data must not use the training
  # pipeline (shuffling/augmentation).
  validation_dataset = DataConfig(
      name='ucf101',
      num_classes=101,
      is_training=False,
      split='test',
      drop_remainder=False,
      num_examples=3783,
      temporal_stride=2,
      feature_shape=(32, 224, 224, 3))
  validation_dataset.tfds_name = 'ucf101'
  validation_dataset.tfds_split = 'test'

  task = VideoClassificationTask(
      model=VideoClassificationModel(
          backbone=backbones_3d.Backbone3D(
              type='resnet_3d', resnet_3d=backbones_3d.ResNet3D50()),
          norm_activation=common.NormActivation(
              norm_momentum=0.9, norm_epsilon=1e-5, use_sync_bn=False)),
      losses=Losses(l2_weight_decay=1e-4),
      train_data=train_dataset,
      validation_data=validation_dataset)

  config = cfg.ExperimentConfig(
      runtime=cfg.RuntimeConfig(mixed_precision_dtype='bfloat16'),
      task=task,
      restrictions=[
          'task.train_data.is_training != None',
          'task.validation_data.is_training != None',
          'task.train_data.num_classes == task.validation_data.num_classes',
      ])
  # NOTE(review): the return value of add_trainer is discarded here, which is
  # only correct if it mutates `config` in place — confirm against its
  # definition.
  add_trainer(
      config,
      train_batch_size=64,
      eval_batch_size=16,
      learning_rate=0.8,
      train_epochs=100)
  return config
class ImageClassificationModel(hyperparams.Config):
  """Image classification model config with a CSP-DarkNet backbone."""
  num_classes: int = 1000
  # Default input resolution is 256x256 RGB.
  input_size: List[int] = dataclasses.field(
      default_factory=lambda: [256, 256, 3])
  # Fixed: type='darknet' selects the `darknet` oneof field, but the DarkNet
  # sub-config was passed under the `resnet=` keyword, so the configured
  # model_id was silently ignored. Pass it under `darknet=` to match the
  # selected type (consistent with the other darknet-backbone configs).
  backbone: backbones.Backbone = backbones.Backbone(
      type='darknet', darknet=backbones.DarkNet(model_id='cspdarknet'))
  dropout_rate: float = 0.0
  norm_activation: common.NormActivation = common.NormActivation(
      activation='leaky', use_sync_bn=False)
  # Adds a BatchNormalization layer pre-GlobalAveragePooling in classification
  add_head_batch_norm: bool = False
  min_level: Optional[int] = None
  max_level: int = 5
  dilate: bool = False
  # Locations of the original darknet weights/config used for conversion.
  darknet_weights_file: str = 'cache://csdarknet53.weights'
  darknet_weights_cfg: str = 'cache://csdarknet53.cfg'
class Yolo(ModelConfig):
  """YOLO model config."""
  num_classes: int = 80
  # Underscore-prefixed fields are resolved via ModelConfig properties —
  # TODO confirm against the ModelConfig base class.
  _input_size: Optional[List[int]] = None
  min_level: int = 3
  max_level: int = 5
  # Number of anchor boxes predicted per output scale.
  boxes_per_scale: int = 3
  # Either a named preset string or an explicit YoloBase config.
  base: Union[str, YoloBase] = YoloBase()
  dilate: bool = False
  # NOTE(review): field name shadows the `filter` builtin; kept as-is because
  # renaming would break existing configs/overrides.
  filter: YoloLossLayer = YoloLossLayer()
  norm_activation: common.NormActivation = common.NormActivation(
      activation='leaky',
      use_sync_bn=False,
      norm_momentum=0.99,
      norm_epsilon=0.001)
  decoder_activation: str = 'leaky'
  # Default COCO anchor boxes as '(width, height)' strings.
  _boxes: Optional[List[str]] = dataclasses.field(default_factory=lambda: [
      '(12, 16)', '(19, 36)', '(40, 28)', '(36, 75)', '(76, 55)', '(72, 146)',
      '(142, 110)', '(192, 243)', '(459, 401)'
  ])
class Yolo(hyperparams.Config):
  """YOLO model config (CSP-DarkNet-53 backbone, v4 decoder)."""
  # Default input resolution is 512x512 RGB.
  input_size: Optional[List[int]] = dataclasses.field(
      default_factory=lambda: [512, 512, 3])
  backbone: backbones.Backbone = backbones.Backbone(
      type='darknet', darknet=backbones.Darknet(model_id='cspdarknet53'))
  decoder: decoders.Decoder = decoders.Decoder(
      type='yolo_decoder',
      yolo_decoder=decoders.YoloDecoder(version='v4', type='regular'))
  head: YoloHead = YoloHead()
  detection_generator: YoloDetectionGenerator = YoloDetectionGenerator()
  loss: YoloLoss = YoloLoss()
  norm_activation: common.NormActivation = common.NormActivation(
      activation='mish',
      use_sync_bn=True,
      norm_momentum=0.99,
      norm_epsilon=0.001)
  num_classes: int = 80
  anchor_boxes: AnchorBoxes = AnchorBoxes()
  # Whether weights follow the original darknet layout — TODO confirm
  # against the model builder.
  darknet_based_model: bool = False
class MaskRCNN(hyperparams.Config):
  """Mask R-CNN model config."""
  num_classes: int = 0
  # Input image size, e.g. [height, width]; empty means unspecified.
  input_size: List[int] = dataclasses.field(default_factory=list)
  # Feature pyramid level range.
  min_level: int = 2
  max_level: int = 6
  anchor: Anchor = Anchor()
  # Whether the mask branch (head/sampler/aligner below) is used.
  include_mask: bool = True
  backbone: backbones.Backbone = backbones.Backbone(
      type='resnet', resnet=backbones.ResNet())
  decoder: decoders.Decoder = decoders.Decoder(type='fpn', fpn=decoders.FPN())
  # Two-stage detection components: RPN, ROI processing, detection heads.
  rpn_head: RPNHead = RPNHead()
  detection_head: DetectionHead = DetectionHead()
  roi_generator: ROIGenerator = ROIGenerator()
  roi_sampler: ROISampler = ROISampler()
  roi_aligner: ROIAligner = ROIAligner()
  detection_generator: DetectionGenerator = DetectionGenerator()
  mask_head: Optional[MaskHead] = MaskHead()
  mask_sampler: Optional[MaskSampler] = MaskSampler()
  mask_roi_aligner: Optional[MaskROIAligner] = MaskROIAligner()
  norm_activation: common.NormActivation = common.NormActivation(
      norm_momentum=0.997, norm_epsilon=0.0001, use_sync_bn=True)