예제 #1
0
파일: video_ssl.py 프로젝트: zss1980/models
def video_ssl_pretrain_kinetics400() -> cfg.ExperimentConfig:
    """Build the SSL pretraining experiment for Kinetics-400 with ResNet.

    Starts from the stock Kinetics-400 video classification experiment and
    swaps its training data config for a ``DataConfig`` with ``is_ssl=True``.
    The new config has ``feature_shape`` (16, 224, 224, 3) and a temporal
    stride of 2.

    Returns:
      The modified experiment config.
    """
    config = video_classification.video_classification_kinetics400()
    task = config.task
    task.train_data = DataConfig(is_ssl=True, **task.train_data.as_dict())
    # Re-read the attribute after assignment instead of keeping a handle on
    # the freshly built DataConfig, in case the container copies on assignment.
    train_data = task.train_data
    train_data.feature_shape = (16, 224, 224, 3)
    train_data.temporal_stride = 2
    return config
예제 #2
0
def video_ssl_pretrain_kinetics400() -> cfg.ExperimentConfig:
    """Build the SSL pretraining experiment for Kinetics-400 with ResNet.

    Starts from the stock Kinetics-400 video classification experiment and
    re-wraps it for self-supervised pretraining:

    * the task becomes a ``VideoSSLPretrainTask``;
    * training data becomes a ``DataConfig`` with ``is_ssl=True``,
      ``feature_shape`` (16, 224, 224, 3) and temporal stride 2;
    * the model becomes a ``VideoSSLModel`` with ``model_type`` set to
      ``'video_ssl_model'``;
    * the losses become ``SSLLosses``.

    Returns:
      The modified experiment config.
    """
    config = video_classification.video_classification_kinetics400()
    config.task = VideoSSLPretrainTask(**config.task.as_dict())

    # Fetch the task only after it has been replaced above.
    task = config.task

    task.train_data = DataConfig(is_ssl=True, **task.train_data.as_dict())
    train_data = task.train_data
    train_data.feature_shape = (16, 224, 224, 3)
    train_data.temporal_stride = 2

    task.model = VideoSSLModel(task.model)
    task.model.model_type = 'video_ssl_model'

    task.losses = SSLLosses(task.losses)
    return config
예제 #3
0
파일: video_ssl.py 프로젝트: zss1980/models
def video_ssl_linear_eval_kinetics400() -> cfg.ExperimentConfig:
    """Linear-evaluation experiment for SSL video models on Kinetics-400.

    Starts from the stock Kinetics-400 video classification experiment and
    replaces both the training and validation data configs with
    ``DataConfig`` instances that have ``is_ssl=False``. It then overrides
    clip sampling on each:

    * training: ``feature_shape`` (32, 224, 224, 3), temporal stride 2;
    * validation: ``feature_shape`` (32, 256, 256, 3), temporal stride 2,
      ``min_image_size`` 256 and 10 test clips.

    The shape tuples are presumably (frames, height, width, channels).

    Returns:
      The modified experiment config.
    """
    exp = video_classification.video_classification_kinetics400()

    # Training data: convert to DataConfig first, then apply overrides.
    exp.task.train_data = DataConfig(is_ssl=False,
                                     **exp.task.train_data.as_dict())
    exp.task.train_data.feature_shape = (32, 224, 224, 3)
    exp.task.train_data.temporal_stride = 2

    # Validation data: same convert-then-override order as training. The
    # overrides then land on the final DataConfig directly, without relying
    # on as_dict() to carry them through the conversion.
    exp.task.validation_data = DataConfig(is_ssl=False,
                                          **exp.task.validation_data.as_dict())
    exp.task.validation_data.feature_shape = (32, 256, 256, 3)
    exp.task.validation_data.temporal_stride = 2
    exp.task.validation_data.min_image_size = 256
    exp.task.validation_data.num_test_clips = 10
    return exp
예제 #4
0
def video_ssl_linear_eval_kinetics400() -> cfg.ExperimentConfig:
    """Linear-evaluation experiment for SSL video models on Kinetics-400.

    Starts from the stock Kinetics-400 video classification experiment and
    re-wraps it for evaluating SSL features:

    * the task becomes a ``VideoSSLEvalTask``;
    * training data becomes a ``DataConfig`` with ``is_ssl=False``,
      ``feature_shape`` (32, 224, 224, 3) and temporal stride 2;
    * validation data becomes a ``DataConfig`` with ``is_ssl=False``,
      ``feature_shape`` (32, 256, 256, 3), temporal stride 2,
      ``min_image_size`` 256, 10 test clips and 3 test crops;
    * the model becomes a ``VideoSSLModel`` with ``model_type``
      ``'video_ssl_model'``, ``normalize_feature=True``,
      ``hidden_layer_num=0`` and ``projection_dim=400``.

    The shape tuples are presumably (frames, height, width, channels).

    Returns:
      The modified experiment config.
    """
    exp = video_classification.video_classification_kinetics400()
    exp.task = VideoSSLEvalTask(**exp.task.as_dict())

    # Training data: convert to DataConfig first, then apply overrides.
    exp.task.train_data = DataConfig(is_ssl=False,
                                     **exp.task.train_data.as_dict())
    exp.task.train_data.feature_shape = (32, 224, 224, 3)
    exp.task.train_data.temporal_stride = 2

    # Validation data: same convert-then-override order as training. The
    # overrides then land on the final DataConfig directly, without relying
    # on as_dict() to carry them through the conversion.
    exp.task.validation_data = DataConfig(is_ssl=False,
                                          **exp.task.validation_data.as_dict())
    exp.task.validation_data.feature_shape = (32, 256, 256, 3)
    exp.task.validation_data.temporal_stride = 2
    exp.task.validation_data.min_image_size = 256
    exp.task.validation_data.num_test_clips = 10
    exp.task.validation_data.num_test_crops = 3

    exp.task.model = VideoSSLModel(exp.task.model)
    exp.task.model.model_type = 'video_ssl_model'
    exp.task.model.normalize_feature = True
    # Zero hidden layers: presumably a purely linear head on top of the
    # features — confirm against VideoSSLModel's head construction.
    exp.task.model.hidden_layer_num = 0
    # NOTE(review): 400 presumably matches the number of Kinetics-400 classes.
    exp.task.model.projection_dim = 400
    return exp