def load_model_from_vertex(
    project: str,
    region: str,
    endpoint_id: str,
    credentials: Optional[google.auth.credentials.Credentials] = None,
    input_modalities: Optional[Dict[str, List[str]]] = None
) -> model_lib.Model:
  """Loads a model from Unified Cloud AI Platform (Vertex AI).

  Args:
    project: AI Platform project name.
    region: GCP Region for the deployed model.
    endpoint_id: Version of the given model. If not given, it will load the
      default version for the model.
    credentials: The OAuth2.0 credentials to use for GCP services.
    input_modalities: Dictionary mapping from modalities to input names in the
      explain metadata. For example {'numeric': ['input1', 'input2'], 'all':
      ['input1', 'input2']}. All inputs must be collected under the 'all' key.
      If None, a default metadata of {'all': [...]} will be created.

  Returns:
    A model object.

  Raises:
    NotImplementedError: If there are no registered remote models.
  """
  if _VERTEX_MODEL_KEY not in _MODEL_REGISTRY:
    available_models = ', '.join(_MODEL_REGISTRY)
    raise NotImplementedError(
        'There are no implementations for Vertex models. '
        # Fixed typo: "Avilable" -> "Available".
        f'Available models are: {{{available_models}}}.')
  # Vertex endpoints are addressed as
  # projects/{project}/locations/{region}/endpoints/{endpoint_id}.
  resource_path = os.path.join('projects', project, 'locations', region,
                               'endpoints', endpoint_id)
  # The trailing True flags the URI helper to build a Vertex-style endpoint.
  return _MODEL_REGISTRY[_VERTEX_MODEL_KEY](
      utils.get_endpoint_uri(resource_path, region, True), credentials,
      input_modalities)
def load_model_from_ai_platform(
    project: str,
    model: str,
    version: Optional[str] = None,
    credentials: Optional[google.auth.credentials.Credentials] = None,
    region: Optional[str] = None,
    input_modalities: Optional[Dict[str, List[str]]] = None
) -> model_lib.Model:
  """Loads a model from Cloud AI Platform.

  Args:
    project: AI Platform project name.
    model: AI Platform Prediction model name.
    version: Version of the given model. If not given, it will load the
      default version for the model.
    credentials: The OAuth2.0 credentials to use for GCP services.
    region: GCP Region for the deployed model.
    input_modalities: Dictionary mapping from modalities to input names in the
      explain metadata. For example {'numeric': ['input1', 'input2'], 'all':
      ['input1', 'input2']}. All inputs must be collected under the 'all' key.
      If None, modalities will be inferred from the model metadata.

  Returns:
    A model object.

  Raises:
    NotImplementedError: If there are no registered remote models.
  """
  if _CAIP_MODEL_KEY not in _MODEL_REGISTRY:
    available_models = ', '.join(_MODEL_REGISTRY)
    raise NotImplementedError(
        'There are no implementations for CAIP models. '
        # Fixed typo: "Avilable" -> "Available".
        f'Available models are: {{{available_models}}}.')
  # CAIP models are addressed as projects/{project}/models/{model}
  # with an optional /versions/{version} suffix.
  # (Renamed misspelled local "resouce_path" -> "resource_path".)
  resource_path = os.path.join('projects', project, 'models', model)
  if version:
    resource_path = os.path.join(resource_path, 'versions', version)
  # The trailing False flags the URI helper to build a legacy CAIP endpoint.
  model_endpoint_uri = utils.get_endpoint_uri(resource_path, region, False)
  if not input_modalities:
    # Fall back to inferring modalities from the deployed model's
    # explanation metadata.
    input_modalities = utils.create_modality_inputs_map_from_metadata(
        utils.fetch_explanation_metadata(model_endpoint_uri, credentials))
  return _MODEL_REGISTRY[_CAIP_MODEL_KEY](model_endpoint_uri, credentials,
                                          input_modalities)
def test_get_endpoint_uri_env_variable(self):
  """Checks that the environment-variable override appears in the URI."""
  actual_uri = utils.get_endpoint_uri('m/1/e/4')
  self.assertEqual('https://overriden/v1/m/1/e/4', actual_uri)
def test_get_endpoint_uri(self, region, is_vertex, expected_uri):
  """Verifies URI construction for each parameterized region/platform case."""
  actual_uri = utils.get_endpoint_uri('m/1/e/4', region, is_vertex)
  self.assertEqual(expected_uri, actual_uri)