예제 #1
0
def build_model(params=None, override_params=None, model_arch=None):
  """Create the multimodal `MMRLModel` from a config or an architecture name.

  Args:
    params: model config object. When None, one is created from
      `model_arch` with `configs_factory.build_model_configs`.
    override_params: optional overrides; when given they are applied by
      calling `params.override(override_params)` on the config object itself
      (the return value of that call is ignored).
    model_arch: architecture name, consulted only when `params` is None.

  Returns:
    An `MMRLModel` built from the fusion base module chosen by
    `params.backbone_config.name`, the head produced by
    `head_factory.build_model(params=params.head_config)`, and the config.

  Raises:
    AssertionError: if both `params` and `model_arch` are None.
      NOTE(review): `assert` is stripped under `python -O`.
  """
  config = params
  if config is None:
    assert model_arch is not None, (
        'either params or model_arch should be specified')
    config = configs_factory.build_model_configs(model_arch)

  if override_params is not None:
    config.override(override_params)

  # Backbone names beginning with 'unified_backbone' use UnifiedFusion;
  # every other name falls back to AudioTextVideoFusion.
  is_unified = config.backbone_config.name.startswith('unified_backbone')
  base_module = (vatt_models.UnifiedFusion if is_unified
                 else vatt_models.AudioTextVideoFusion)

  mm_model = MMRLModel(
      base_module=base_module,
      head_module=head_factory.build_model(params=config.head_config),
      params=config,
  )

  logging.info('Entire MM model %s created successfully.', config.model_name)
  return mm_model
예제 #2
0
def build_model(params=None,
                override_params=None,
                backbone=None,
                mode='embedding'):
    """Create an `AudioModel` from a config or a backbone name.

    Args:
        params: model config object. When None, one is created from
            `backbone` with `configs_factory.build_model_configs`.
        override_params: optional overrides; when given they are applied by
            calling `params.override(override_params)` on the config object
            itself (the return value of that call is ignored).
        backbone: backbone name, consulted only when `params` is None.
        mode: when equal to 'predict', a `PredictionAggregator` with
            `num_test_clips=params.num_test_samples` is attached; for any
            other value no aggregator is used.

    Returns:
        An `AudioModel` whose base model is picked by the lower-cased prefix
        of `params.name`: 'resnet', 'wat' or 'spt' (checked in that order).

    Raises:
        AssertionError: if both `params` and `backbone` are None.
            NOTE(review): `assert` is stripped under `python -O`.
        ValueError: if `params.name` starts with none of the known prefixes.
    """
    config = params
    if config is None:
        assert backbone is not None, 'either params or backbone should be specified'
        config = configs_factory.build_model_configs(backbone)

    if override_params is not None:
        config.override(override_params)

    lowered_name = config.name.lower()
    # Checked in order; the first matching prefix wins. Each entry is a thunk
    # so that only the selected backbone class is actually looked up.
    prefix_to_base = (
        ('resnet', lambda: resnet2d.Resnet2dBase),
        ('wat', lambda: autx1d.AuTx1D),
        ('spt', lambda: autx2d.AuTx2D),
    )
    for prefix, get_base in prefix_to_base:
        if lowered_name.startswith(prefix):
            base_model = get_base()
            break
    else:
        raise ValueError('Unknown model name {!r}'.format(config.name))

    pred_aggregator = (
        PredictionAggregator(num_test_clips=config.num_test_samples)
        if mode == 'predict' else None)

    audio_model = AudioModel(base_model=base_model,
                             params=config,
                             pred_aggregator=pred_aggregator)

    logging.info('Audio model %s created successfully.', config.name)
    return audio_model
예제 #3
0
def build_model(
    backbone=None,
    params=None,
    override_params=None,
):
    """Build a unified-backbone model from a config or a backbone name.

    Args:
        backbone: backbone name. Required when `params` is None, because it
            is used to create the config. Otherwise it is optional; when
            omitted, the backbone is taken from `params.name`.
        params: model config object, or None to create one from `backbone`
            with `configs_factory.build_model_configs`.
        override_params: optional overrides; when given they are applied by
            calling `params.override(override_params)` on the config object
            itself (the return value of that call is ignored).

    Returns:
        A `UnifiedModule` wrapping the selected base model and the config.

    Raises:
        AssertionError: if both `backbone` and `params` are None.
            NOTE(review): `assert` is stripped under `python -O`.
        ValueError: if the backbone name does not start with "ut".
    """
    if params is None:
        assert backbone is not None, "either params or backbone should be specified"
        params = configs_factory.build_model_configs(backbone)

    if override_params is not None:
        params.override(override_params)

    # Dispatch on `backbone` when given; otherwise fall back to the config's
    # own name so callers that pass only `params` do not hit None.startswith.
    backbone_name = backbone if backbone is not None else params.name

    if backbone_name.startswith("ut"):
        base_model = uvatt.UniversalVATT
    else:
        raise ValueError("Unknown backbone {!r}".format(backbone_name))

    model = UnifiedModule(
        base_model=base_model,
        params=params,
    )

    logging.info("Unified backbone %s created successfully.", params.name)

    return model
예제 #4
0
def build_model(params=None, override_params=None, backbone=None):
  """Create a `LanguageModel` from a config or a backbone name.

  Args:
    params: model config object. When None, one is created from `backbone`
      with `configs_factory.build_model_configs`.
    override_params: optional overrides; when given they are applied by
      calling `params.override(override_params)` on the config object itself
      (the return value of that call is ignored).
    backbone: backbone name, consulted only when `params` is None.

  Returns:
    A `LanguageModel` whose head is `LANGUAGE_MODEL_HEADS[key]`, where `key`
    is the first of "linear", "t5", "bert" that prefixes the lower-cased
    `params.name`.

  Raises:
    AssertionError: if both `params` and `backbone` are None.
      NOTE(review): `assert` is stripped under `python -O`.
    ValueError: if `params.name` starts with none of the known prefixes.
  """
  config = params
  if config is None:
    assert backbone is not None, "either params or backbone should be specified"
    config = configs_factory.build_model_configs(backbone)

  if override_params is not None:
    config.override(override_params)

  lm_name = config.name.lower()
  # The matched prefix doubles as the registry key, so look it up directly.
  head_key = next(
      (key for key in ("linear", "t5", "bert") if lm_name.startswith(key)),
      None,
  )
  if head_key is None:
    raise ValueError("Unknown model name {!r}".format(config.name))

  language_model = LanguageModel(
      base_lm_head=LANGUAGE_MODEL_HEADS[head_key],
      params=config,
  )

  logging.info("Text model %s created successfully.", config.name)
  return language_model
예제 #5
0
def build_model(params=None, override_params=None, backbone=None,
                mode='embedding'):
  """Create a `VideoModel` from a config or a backbone name.

  Args:
    params: model config object. When None, one is created from `backbone`
      with `configs_factory.build_model_configs`.
    override_params: optional overrides; when given they are applied by
      calling `params.override(override_params)` on the config object itself
      (the return value of that call is ignored).
    backbone: backbone name, consulted only when `params` is None.
    mode: when equal to 'predict', a `PredictionAggregator` with
      `num_test_clips=params.num_test_samples` is attached; for any other
      value no aggregator is used.

  Returns:
    A `VideoModel` whose base model is `i3d.InceptionI3D` for names starting
    with 'i3d' or `vitx3d.ViTx3D` for names starting with 'vit'. The prefix
    match on `params.name` is case-sensitive.

  Raises:
    AssertionError: if both `params` and `backbone` are None.
      NOTE(review): `assert` is stripped under `python -O`.
    ValueError: if `params.name` starts with neither 'i3d' nor 'vit'.
  """
  config = params
  if config is None:
    assert backbone is not None, 'either params or backbone should be specified'
    config = configs_factory.build_model_configs(backbone)

  if override_params is not None:
    config.override(override_params)

  video_name = config.name
  # Reject unknown names up front; past this guard the name is known to
  # start with either 'i3d' or 'vit'.
  if not video_name.startswith(('i3d', 'vit')):
    raise ValueError('Unknown backbone {!r}'.format(video_name))
  base_model = (i3d.InceptionI3D if video_name.startswith('i3d')
                else vitx3d.ViTx3D)

  pred_aggregator = None
  if mode == 'predict':
    pred_aggregator = PredictionAggregator(
        num_test_clips=config.num_test_samples)

  video_model = VideoModel(
      base_model=base_model,
      params=config,
      pred_aggregator=pred_aggregator,
  )

  logging.info('Video model %s created successfully.', config.name)
  return video_model