def build_model(
    params=None,
    override_params=None,
    model_arch=None,
):
  """Assembles the full multimodal (MM) model.

  Args:
    params: model config object. When None, a config is produced with
      `configs_factory.build_model_configs(model_arch)`.
    override_params: optional value forwarded to `params.override(...)`
      before any sub-module is chosen.
    model_arch: architecture name; only consulted when `params` is None.

  Returns:
    An `MMRLModel` combining the selected base module with a head module
    built from `params.head_config`.

  Raises:
    AssertionError: if both `params` and `model_arch` are None.
  """
  if params is None:
    # NOTE(review): `assert` is stripped under `python -O`; kept to preserve
    # the exception type callers currently see.
    assert model_arch is not None, ('either params or model_arch should be '
                                    'specified')
    params = configs_factory.build_model_configs(model_arch)

  if override_params is not None:
    params.override(override_params)

  # Backbone names beginning with 'unified_backbone' get the unified fusion
  # module; every other backbone uses the audio/text/video fusion module.
  is_unified = params.backbone_config.name.startswith('unified_backbone')
  base_module = (vatt_models.UnifiedFusion if is_unified
                 else vatt_models.AudioTextVideoFusion)

  model = MMRLModel(
      base_module=base_module,
      head_module=head_factory.build_model(params=params.head_config),
      params=params,
  )
  logging.info('Entire MM model %s created successfully.', params.model_name)
  return model
def build_model(params=None,
                override_params=None,
                backbone=None,
                mode='embedding'):
  """Creates an audio model from a config or a backbone name.

  Args:
    params: model config. If None, one is built with
      `configs_factory.build_model_configs(backbone)`.
    override_params: optional value passed to `params.override(...)` before
      the backbone class is chosen.
    backbone: backbone name; required when `params` is None.
    mode: when equal to 'predict', a `PredictionAggregator` is created with
      `num_test_clips=params.num_test_samples`; any other value leaves the
      aggregator as None.

  Returns:
    An `AudioModel` instance.

  Raises:
    AssertionError: if both `params` and `backbone` are None.
    ValueError: if the lower-cased `params.name` starts with none of
      'resnet', 'wat' or 'spt'.
  """
  if params is None:
    assert backbone is not None, 'either params or backbone should be specified'
    params = configs_factory.build_model_configs(backbone)
  if override_params is not None:
    params.override(override_params)

  # (prefix, loader) pairs checked in order against the lower-cased name.
  # Loaders are lambdas so only the matched module attribute is accessed.
  prefix_loaders = (
      ('resnet', lambda: resnet2d.Resnet2dBase),
      ('wat', lambda: autx1d.AuTx1D),
      ('spt', lambda: autx2d.AuTx2D),
  )
  lowered_name = params.name.lower()
  for prefix, load_base in prefix_loaders:
    if lowered_name.startswith(prefix):
      base_model = load_base()
      break
  else:
    raise ValueError('Unknown model name {!r}'.format(params.name))

  pred_aggregator = (
      PredictionAggregator(num_test_clips=params.num_test_samples)
      if mode == 'predict' else None)

  model = AudioModel(base_model=base_model,
                     params=params,
                     pred_aggregator=pred_aggregator)
  logging.info('Audio model %s created successfully.', params.name)
  return model
def build_model(
    backbone=None,
    params=None,
    override_params=None,
):
  """Builds a unified backbone model.

  Args:
    backbone: backbone name. Used to build the config when `params` is None
      and to select the backbone class. May be None when `params` is given;
      `params.name` is then used for the selection instead.
    params: model config. If None, one is built with
      `configs_factory.build_model_configs(backbone)`.
    override_params: optional value passed to `params.override(...)` before
      the backbone class is chosen.

  Returns:
    A `UnifiedModule` wrapping the selected base model.

  Raises:
    AssertionError: if both `params` and `backbone` are None.
    ValueError: if the backbone name does not start with "ut".
  """
  if params is None:
    assert backbone is not None, "either params or backbone should be specified"
    params = configs_factory.build_model_configs(backbone)
  if override_params is not None:
    params.override(override_params)

  # `backbone` may legitimately be None when `params` is supplied (the check
  # above only requires one of them). Fall back to the config's own name so
  # the dispatch below does not call `.startswith` on None.
  backbone_name = backbone if backbone is not None else params.name
  if backbone_name.startswith("ut"):
    base_model = uvatt.UniversalVATT
  else:
    raise ValueError("Unknown backbone {!r}".format(backbone_name))

  model = UnifiedModule(
      base_model=base_model,
      params=params,
  )
  logging.info("Unified backbone %s created successfully.", params.name)
  return model
def build_model(
    params=None,
    override_params=None,
    backbone=None,
):
  """Creates a language model whose head is chosen from the config name.

  Args:
    params: model config. If None, one is built with
      `configs_factory.build_model_configs(backbone)`.
    override_params: optional value passed to `params.override(...)` before
      the head is chosen.
    backbone: backbone name; required when `params` is None.

  Returns:
    A `LanguageModel` built around the selected `LANGUAGE_MODEL_HEADS` entry.

  Raises:
    AssertionError: if both `params` and `backbone` are None.
    ValueError: if the lower-cased `params.name` starts with none of
      "linear", "t5" or "bert".
  """
  if params is None:
    assert backbone is not None, (
        "either params or backbone should be specified")
    params = configs_factory.build_model_configs(backbone)
  if override_params is not None:
    params.override(override_params)

  # Each key doubles as the name prefix that selects it; checked in order.
  lowered_name = params.name.lower()
  for head_key in ("linear", "t5", "bert"):
    if lowered_name.startswith(head_key):
      base_lm_head = LANGUAGE_MODEL_HEADS[head_key]
      break
  else:
    raise ValueError("Unknown model name {!r}".format(params.name))

  model = LanguageModel(base_lm_head=base_lm_head, params=params)
  logging.info("Text model %s created successfully.", params.name)
  return model
def build_model(
    params=None,
    override_params=None,
    backbone=None,
    mode='embedding',
):
  """Creates a video model from a config or a backbone name.

  Args:
    params: model config. If None, one is built with
      `configs_factory.build_model_configs(backbone)`.
    override_params: optional value passed to `params.override(...)` before
      the backbone class is chosen.
    backbone: backbone name; required when `params` is None.
    mode: when equal to 'predict', a `PredictionAggregator` is created with
      `num_test_clips=params.num_test_samples`; any other value leaves the
      aggregator as None.

  Returns:
    A `VideoModel` instance.

  Raises:
    AssertionError: if both `params` and `backbone` are None.
    ValueError: if `params.name` starts with neither 'i3d' nor 'vit'
      (the match is case-sensitive).
  """
  if params is None:
    assert backbone is not None, 'either params or backbone should be specified'
    params = configs_factory.build_model_configs(backbone)
  if override_params is not None:
    params.override(override_params)

  name = params.name
  # Reject unsupported names up front; the match is case-sensitive.
  if not name.startswith(('i3d', 'vit')):
    raise ValueError('Unknown backbone {!r}'.format(name))
  base_model = i3d.InceptionI3D if name.startswith('i3d') else vitx3d.ViTx3D

  pred_aggregator = (
      PredictionAggregator(num_test_clips=params.num_test_samples)
      if mode == 'predict' else None)

  model = VideoModel(
      base_model=base_model,
      params=params,
      pred_aggregator=pred_aggregator,
  )
  logging.info('Video model %s created successfully.', params.name)
  return model