class VideoClassificationModel(hyperparams.Config):
    """Hyperparameters describing the video classification model.

    Attributes:
      backbone: 3D backbone selection. Defaults to the `resnet_3d` type
        configured with `backbones_3d.ResNet3D50()`.
      norm_activation: Normalization/activation settings. Defaults to a
        `common.NormActivation()` built with its own defaults.
      dropout_rate: Float, defaults to 0.2. Presumably the dropout applied
        in the classification head -- confirm against the model builder.
      add_head_batch_norm: Bool, defaults to False. Presumably toggles a
        batch-norm layer in the head -- confirm against the model builder.
    """

    # NOTE(review): the two config-object defaults below are created once,
    # when the class body executes. If this class is decorated as a
    # dataclass outside this view, newer Python versions may reject such
    # instance defaults in favour of `default_factory` -- confirm.
    backbone: backbones_3d.Backbone3D = backbones_3d.Backbone3D(
        type='resnet_3d',
        resnet_3d=backbones_3d.ResNet3D50(),
    )
    norm_activation: common.NormActivation = common.NormActivation()
    dropout_rate: float = 0.2
    add_head_batch_norm: bool = False
def video_classification_ucf101() -> cfg.ExperimentConfig:
    """Video classification on UCF-101 with resnet.

    Builds the train and validation data configs (both read from the
    `ucf101` TFDS dataset), a `resnet_3d` / `ResNet3D50` model config, and
    wraps them in an experiment config that runs in bfloat16 mixed
    precision. Trainer settings are attached through `add_trainer`.

    Returns:
      The fully populated `cfg.ExperimentConfig`.
    """
    train_dataset = DataConfig(
        name='ucf101',
        num_classes=101,
        is_training=True,
        split='train',
        drop_remainder=True,
        num_examples=9537,
        temporal_stride=2,
        feature_shape=(32, 224, 224, 3))
    train_dataset.tfds_name = 'ucf101'
    train_dataset.tfds_split = 'train'

    # The validation config reads the 'test' split and must not be flagged
    # as training data. It previously carried `is_training=True` (copied
    # from the train config), unlike the Kinetics-600 factory, which builds
    # its validation data with `is_training=False`.
    validation_dataset = DataConfig(
        name='ucf101',
        num_classes=101,
        is_training=False,
        split='test',
        drop_remainder=False,
        num_examples=3783,
        temporal_stride=2,
        feature_shape=(32, 224, 224, 3))
    validation_dataset.tfds_name = 'ucf101'
    validation_dataset.tfds_split = 'test'

    task = VideoClassificationTask(
        model=VideoClassificationModel(
            backbone=backbones_3d.Backbone3D(
                type='resnet_3d', resnet_3d=backbones_3d.ResNet3D50()),
            norm_activation=common.NormActivation(
                norm_momentum=0.9, norm_epsilon=1e-5, use_sync_bn=False)),
        losses=Losses(l2_weight_decay=1e-4),
        train_data=train_dataset,
        validation_data=validation_dataset)

    config = cfg.ExperimentConfig(
        runtime=cfg.RuntimeConfig(mixed_precision_dtype='bfloat16'),
        task=task,
        # Restriction strings are interpreted by the config framework, so
        # they are kept verbatim (including the `!= None` spelling).
        restrictions=[
            'task.train_data.is_training != None',
            'task.validation_data.is_training != None',
            'task.train_data.num_classes == task.validation_data.num_classes',
        ])

    # `add_trainer` receives the config object plus batch sizes, learning
    # rate and epoch count; its return value is not used here.
    add_trainer(
        config,
        train_batch_size=64,
        eval_batch_size=16,
        learning_rate=0.8,
        train_epochs=100)
    return config
def video_classification_kinetics600() -> cfg.ExperimentConfig:
    """Video classification on Kinetics-600 with resnet.

    Assembles the train/validation data configs via `kinetics600`, a
    `resnet_3d` / `ResNet3D50` model config, and a bfloat16 mixed-precision
    experiment config, then attaches trainer settings with `add_trainer`.

    Returns:
      The fully populated `cfg.ExperimentConfig`.
    """
    # Data configs: same helper, differing only in the `is_training` flag.
    train_cfg = kinetics600(is_training=True)
    eval_cfg = kinetics600(is_training=False)

    backbone_cfg = backbones_3d.Backbone3D(
        type='resnet_3d', resnet_3d=backbones_3d.ResNet3D50())
    norm_cfg = common.NormActivation(norm_momentum=0.9, norm_epsilon=1e-5)
    model_cfg = VideoClassificationModel(
        backbone=backbone_cfg, norm_activation=norm_cfg)

    task_cfg = VideoClassificationTask(
        model=model_cfg,
        losses=Losses(l2_weight_decay=1e-4),
        train_data=train_cfg,
        validation_data=eval_cfg)

    # Restriction strings are consumed by the config framework; keep them
    # exactly as written.
    consistency_rules = [
        'task.train_data.is_training != None',
        'task.validation_data.is_training != None',
        'task.train_data.num_classes == task.validation_data.num_classes',
    ]
    experiment = cfg.ExperimentConfig(
        runtime=cfg.RuntimeConfig(mixed_precision_dtype='bfloat16'),
        task=task_cfg,
        restrictions=consistency_rules)

    # Only batch sizes are overridden; other trainer arguments fall back to
    # whatever defaults `add_trainer` declares.
    add_trainer(experiment, train_batch_size=1024, eval_batch_size=64)
    return experiment