def load_model_from_vertex(
        project: str,
        region: str,
        endpoint_id: str,
        credentials: Optional[google.auth.credentials.Credentials] = None,
        input_modalities: Optional[Dict[str, str]] = None) -> model_lib.Model:
    """Loads a model deployed to a Vertex AI endpoint.

    Vertex AI is also known as the Unified Cloud AI Platform.

    Args:
      project: GCP project that owns the endpoint.
      region: GCP region in which the endpoint is deployed.
      endpoint_id: ID of the endpoint serving the model. Required; there is no
        default endpoint to fall back to.
      credentials: The OAuth2.0 credentials to use for GCP services.
      input_modalities: Dictionary mapping from modalities to input names in the
        explain metadata. For example {'numeric': ['input1', 'input2'], 'all':
        ['input1', 'input2']}. All inputs must be collected under the 'all' key.
        Forwarded unchanged to the registered Vertex model class; if None, that
        class is expected to create a default metadata of {'all': [...]}.

    Returns:
      A model object constructed by the implementation registered under
      `_VERTEX_MODEL_KEY`.

    Raises:
      NotImplementedError: If no Vertex model implementation is registered.
    """
    if _VERTEX_MODEL_KEY not in _MODEL_REGISTRY:
        available_models = ', '.join(_MODEL_REGISTRY)
        raise NotImplementedError(
            'There are no implementations for Vertex models. '
            f'Available models are: {{{available_models}}}.')
    # GCP resource names are URL path segments, so always join them with '/'.
    # os.path.join would produce '\\' separators on Windows.
    resource_path = '/'.join(
        ('projects', project, 'locations', region, 'endpoints', endpoint_id))
    # The third argument marks this as a Vertex endpoint (the CAIP loader
    # passes False here).
    endpoint_uri = utils.get_endpoint_uri(resource_path, region, True)
    return _MODEL_REGISTRY[_VERTEX_MODEL_KEY](endpoint_uri, credentials,
                                              input_modalities)
def load_model_from_ai_platform(
        project: str,
        model: str,
        version: Optional[str] = None,
        credentials: Optional[google.auth.credentials.Credentials] = None,
        region: Optional[str] = None,
        input_modalities: Optional[Dict[str, str]] = None) -> model_lib.Model:
    """Loads a model from Cloud AI Platform (CAIP) Prediction.

    Args:
      project: AI Platform project name.
      model: AI Platform Prediction model name.
      version: Version of the given model. If not given (None or empty), no
        version is added to the resource path, so the model's default version
        is used.
      credentials: The OAuth2.0 credentials to use for GCP services. Also used
        to fetch explanation metadata when modalities must be inferred.
      region: GCP Region for the deployed model. Optional; forwarded as-is to
        `utils.get_endpoint_uri`.
      input_modalities: Dictionary mapping from modalities to input names in the
        explain metadata. For example {'numeric': ['input1', 'input2'], 'all':
        ['input1', 'input2']}. All inputs must be collected under the 'all' key.
        If None (or empty), modalities are inferred from the explanation
        metadata fetched from the model endpoint.

    Returns:
      A model object constructed by the implementation registered under
      `_CAIP_MODEL_KEY`.

    Raises:
      NotImplementedError: If no CAIP model implementation is registered.
    """
    if _CAIP_MODEL_KEY not in _MODEL_REGISTRY:
        available_models = ', '.join(_MODEL_REGISTRY)
        raise NotImplementedError(
            'There are no implementations for CAIP models. '
            f'Available models are: {{{available_models}}}.')
    # GCP resource names are URL path segments, so always join them with '/'.
    # os.path.join would produce '\\' separators on Windows.
    resource_path = '/'.join(('projects', project, 'models', model))
    if version:
        resource_path = '/'.join((resource_path, 'versions', version))

    # The third argument is False because this is not a Vertex endpoint.
    model_endpoint_uri = utils.get_endpoint_uri(resource_path, region, False)
    if not input_modalities:
        # Remote call: derive the modality map from the deployed model's
        # explanation metadata.
        input_modalities = utils.create_modality_inputs_map_from_metadata(
            utils.fetch_explanation_metadata(model_endpoint_uri, credentials))

    return _MODEL_REGISTRY[_CAIP_MODEL_KEY](model_endpoint_uri, credentials,
                                            input_modalities)
 def test_get_endpoint_uri_env_variable(self):
     """get_endpoint_uri should honor an overridden base endpoint.

     NOTE(review): the 'https://overriden' base is presumably injected through
     an environment variable by a decorator or setUp not visible here.
     """
     actual_uri = utils.get_endpoint_uri('m/1/e/4')
     self.assertEqual('https://overriden/v1/m/1/e/4', actual_uri)
 def test_get_endpoint_uri(self, region, is_vertex, expected_uri):
     """get_endpoint_uri should build the expected URI for each case.

     NOTE(review): region/is_vertex/expected_uri are presumably supplied by a
     parameterization decorator not visible here.
     """
     actual_uri = utils.get_endpoint_uri('m/1/e/4', region, is_vertex)
     self.assertEqual(expected_uri, actual_uri)