コード例 #1
0
ファイル: movinet.py プロジェクト: shen-zc/models
def movinet_kinetics600() -> cfg.ExperimentConfig:
    """Builds the Kinetics-600 video classification experiment for MoViNet.

    Starts from the base `video_classification_kinetics600` experiment,
    switches both the train and validation inputs to bfloat16, and installs
    a default-constructed `MovinetModel` as the task model.

    Returns:
      The modified experiment config.
    """
    experiment = video_classification.video_classification_kinetics600()
    task = experiment.task

    # Train and validation inputs use the same dtype.
    for data_config in (task.train_data, task.validation_data):
        data_config.dtype = 'bfloat16'

    task.model = MovinetModel()
    return experiment
コード例 #2
0
ファイル: video_ssl.py プロジェクト: vishalbelsare/models
def video_ssl_pretrain_kinetics600() -> cfg.ExperimentConfig:
  """Builds the self-supervised (SSL) video pretraining experiment.

  Starts from the base `video_classification_kinetics600` experiment and
  re-wraps its task, training data, model and losses in their SSL-specific
  config types, carrying over the values of the base configs.

  Returns:
    The modified experiment config.
  """
  experiment = video_classification.video_classification_kinetics600()

  # Re-create the task as an SSL pretrain task from the base task's fields.
  base_task_fields = experiment.task.as_dict()
  experiment.task = VideoSSLPretrainTask(**base_task_fields)
  task = experiment.task

  # Training input: same fields as the base config, flagged as SSL.
  base_train_fields = task.train_data.as_dict()
  task.train_data = DataConfig(is_ssl=True, **base_train_fields)
  train_data = task.train_data
  train_data.feature_shape = (16, 224, 224, 3)
  train_data.temporal_stride = 2

  # Wrap the base model and losses in their SSL counterparts.
  task.model = VideoSSLModel(task.model)
  task.model.model_type = 'video_ssl_model'
  task.losses = SSLLosses(task.losses)

  return experiment
コード例 #3
0
ファイル: video_ssl.py プロジェクト: vishalbelsare/models
def video_ssl_linear_eval_kinetics600() -> cfg.ExperimentConfig:
  """Builds the linear-evaluation experiment for the SSL video model.

  Starts from the base `video_classification_kinetics600` experiment,
  re-wraps the task as a `VideoSSLEvalTask`, rebuilds the train and
  validation data configs with `is_ssl=False`, and configures the
  `VideoSSLModel` head (normalized features, no hidden layers,
  projection dimension 600).

  Returns:
    The modified experiment config.
  """
  experiment = video_classification.video_classification_kinetics600()
  experiment.task = VideoSSLEvalTask(**experiment.task.as_dict())
  task = experiment.task

  # Training input.
  base_train_fields = task.train_data.as_dict()
  task.train_data = DataConfig(is_ssl=False, **base_train_fields)
  train_data = task.train_data
  train_data.feature_shape = (32, 224, 224, 3)
  train_data.temporal_stride = 2

  # Validation input: larger frames, multi-clip / multi-crop testing.
  base_eval_fields = task.validation_data.as_dict()
  task.validation_data = DataConfig(is_ssl=False, **base_eval_fields)
  eval_data = task.validation_data
  eval_data.feature_shape = (32, 256, 256, 3)
  eval_data.temporal_stride = 2
  eval_data.min_image_size = 256
  eval_data.num_test_clips = 10
  eval_data.num_test_crops = 3

  # Model: wrap the base model config and set the evaluation head options.
  task.model = VideoSSLModel(task.model)
  ssl_model = task.model
  ssl_model.model_type = 'video_ssl_model'
  ssl_model.normalize_feature = True
  ssl_model.hidden_layer_num = 0
  ssl_model.projection_dim = 600

  return experiment
コード例 #4
0
def get_dataset() -> tf.data.Dataset:
    """Builds the validation dataset from the Kinetics-600 config and flags.

    The frame count, image size and temporal stride are read from `FLAGS`
    and applied to the validation data config, which is then passed to
    `VideoClassificationTask.build_inputs`.

    Returns:
      The dataset built from the validation config, mapped so each element
      is `(x['image'], y)`, with a prefetch buffer of 32.
    """
    config = video_classification_configs.video_classification_kinetics600()

    temporal_stride = FLAGS.temporal_stride
    num_frames = FLAGS.num_frames
    image_size = FLAGS.image_size
    feature_shape = (num_frames, image_size, image_size, 3)
    min_image_size = int(1.125 * image_size)

    validation_data = config.task.validation_data
    validation_data.global_batch_size = 1
    validation_data.feature_shape = feature_shape
    validation_data.temporal_stride = temporal_stride
    # The dataset below is built from `validation_data`, so the resize
    # setting must be applied there. Previously it was only written to
    # `train_data`, where it had no effect on the returned dataset.
    validation_data.min_image_size = min_image_size
    validation_data.dtype = 'float32'
    validation_data.drop_remainder = False

    # Kept from the original code so the train config stays in sync.
    config.task.train_data.min_image_size = min_image_size

    task = video_classification.VideoClassificationTask(config.task)

    valid_dataset = task.build_inputs(validation_data)
    valid_dataset = valid_dataset.map(lambda x, y: (x['image'], y))
    valid_dataset = valid_dataset.prefetch(32)
    return valid_dataset
コード例 #5
0
def assemblenet_kinetics600() -> cfg.ExperimentConfig:
    """Builds the Kinetics-600 video classification experiment for AssembleNet.

    Starts from the base `video_classification_kinetics600` experiment,
    overrides batch sizes, feature shapes and dtype for both inputs, and
    installs an `AssembleNetModel` whose backbone uses model id '50' with
    blocks built from `asn50_structure` and `asn_structure_weights`.

    Returns:
      The modified experiment config.
    """
    exp = video_classification.video_classification_kinetics600()
    train_data = exp.task.train_data
    validation_data = exp.task.validation_data

    train_shape = (32, 224, 224, 3)

    # Training input.
    train_data.global_batch_size = 1024
    train_data.feature_shape = train_shape
    train_data.dtype = 'bfloat16'

    # Validation input uses longer clips (120 frames) and a smaller batch.
    validation_data.global_batch_size = 32
    validation_data.feature_shape = (120, 224, 224, 3)
    validation_data.dtype = 'bfloat16'

    assemblenet_model = AssembleNetModel()
    backbone = assemblenet_model.backbone.assemblenet
    backbone.model_id = '50'
    backbone.blocks = flat_lists_to_blocks(
        asn50_structure, asn_structure_weights)
    # The backbone frame count follows the training clip length.
    backbone.num_frames = train_shape[0]
    exp.task.model = assemblenet_model

    assert exp.task.model.backbone.assemblenet.num_frames > 0, (
        f'backbone num_frames '
        f'{exp.task.model.backbone.assemblenet}')

    return exp