def test_update_index_embeddings(self, update_index_embeddings_mock):
    """Verify ``update_embeddings`` issues a metadata-only ``update_index`` call.

    The request must contain an Index named ``_TEST_INDEX_NAME`` whose
    ``metadata`` holds the new contents delta URI and overwrite flag, with an
    update mask of ``["metadata"]``. The returned object's ``gca_resource``
    must equal an Index carrying only the resource name.
    """
    aiplatform.init(project=_TEST_PROJECT)
    target_index = aiplatform.MatchingEngineIndex(index_name=_TEST_INDEX_ID)

    result = target_index.update_embeddings(
        is_complete_overwrite=_TEST_IS_COMPLETE_OVERWRITE_UPDATE,
        contents_delta_uri=_TEST_CONTENTS_DELTA_URI_UPDATE,
    )

    wanted_metadata = {
        "contentsDeltaUri": _TEST_CONTENTS_DELTA_URI_UPDATE,
        "isCompleteOverwrite": _TEST_IS_COMPLETE_OVERWRITE_UPDATE,
    }
    wanted_request_index = gca_index.Index(
        name=_TEST_INDEX_NAME,
        metadata=wanted_metadata,
    )
    update_index_embeddings_mock.assert_called_once_with(
        metadata=_TEST_REQUEST_METADATA,
        update_mask=field_mask_pb2.FieldMask(paths=["metadata"]),
        index=wanted_request_index,
    )

    # The mocked service responds with nothing but the Index resource name.
    name_only = gca_index.Index(name=_TEST_INDEX_NAME)
    assert result.gca_resource == name_only
def test_update_index_metadata(self, update_index_metadata_mock):
    """Verify ``update_metadata`` sends display name, description and labels.

    The ``update_index`` call must carry an Index with the updated fields and
    an update mask listing ``labels``, ``display_name`` and ``description``
    (in that order). The returned object's ``gca_resource`` must equal that
    same expected Index.
    """
    aiplatform.init(project=_TEST_PROJECT)
    target_index = aiplatform.MatchingEngineIndex(index_name=_TEST_INDEX_ID)

    result = target_index.update_metadata(
        labels=_TEST_LABELS_UPDATE,
        description=_TEST_DESCRIPTION_UPDATE,
        display_name=_TEST_DISPLAY_NAME_UPDATE,
    )

    wanted_index = gca_index.Index(
        name=_TEST_INDEX_NAME,
        display_name=_TEST_DISPLAY_NAME_UPDATE,
        description=_TEST_DESCRIPTION_UPDATE,
        labels=_TEST_LABELS_UPDATE,
    )
    # FieldMask equality is order-sensitive, so the path order is fixed here.
    wanted_mask = field_mask_pb2.FieldMask(
        paths=["labels", "display_name", "description"]
    )
    update_index_metadata_mock.assert_called_once_with(
        metadata=_TEST_REQUEST_METADATA,
        update_mask=wanted_mask,
        index=wanted_index,
    )

    assert result.gca_resource == wanted_index
def update_index_embeddings_mock():
    """Patch ``IndexServiceClient.update_index`` for embedding-update tests.

    Generator (presumably registered as a pytest fixture outside this view):
    yields the patched ``update_index`` mock. Calling it returns a mocked
    long-running operation whose ``result()`` is an Index holding only
    ``_TEST_INDEX_NAME``. The patch is undone when the generator resumes.
    """
    name_only_index = gca_index.Index(name=_TEST_INDEX_NAME)
    with patch.object(
        index_service_client.IndexServiceClient, "update_index"
    ) as patched_update:
        fake_lro = mock.Mock(spec=operation.Operation)
        fake_lro.result.return_value = name_only_index
        patched_update.return_value = fake_lro
        yield patched_update
def get_index_mock():
    """Patch ``IndexServiceClient.get_index`` to return a canned Index.

    Generator (presumably registered as a pytest fixture outside this view):
    yields the patched ``get_index`` mock. Its return value is an Index with
    the test name, display name and description. The patch is undone when
    the generator resumes.
    """
    canned_index = gca_index.Index(
        name=_TEST_INDEX_NAME,
        display_name=_TEST_INDEX_DISPLAY_NAME,
        description=_TEST_INDEX_DESCRIPTION,
    )
    with patch.object(
        index_service_client.IndexServiceClient, "get_index"
    ) as patched_get:
        patched_get.return_value = canned_index
        yield patched_get
def create_index_mock():
    """Patch ``IndexServiceClient.create_index`` for index-creation tests.

    Generator (presumably registered as a pytest fixture outside this view):
    yields the patched ``create_index`` mock. Calling it returns a mocked
    long-running operation whose ``result()`` is an Index with the test name,
    display name and description. The patch is undone when the generator
    resumes.
    """
    created_index = gca_index.Index(
        name=_TEST_INDEX_NAME,
        display_name=_TEST_INDEX_DISPLAY_NAME,
        description=_TEST_INDEX_DESCRIPTION,
    )
    with patch.object(
        index_service_client.IndexServiceClient, "create_index"
    ) as patched_create:
        fake_lro = mock.Mock(spec=operation.Operation)
        fake_lro.result.return_value = created_index
        patched_create.return_value = fake_lro
        yield patched_create
def update_index_metadata_mock():
    """Patch ``IndexServiceClient.update_index`` for metadata-update tests.

    Generator (presumably registered as a pytest fixture outside this view):
    yields the patched ``update_index`` mock. Calling it returns a mocked
    long-running operation whose ``result()`` is an Index carrying the
    ``*_UPDATE`` display name, description and labels. The patch is undone
    when the generator resumes.
    """
    updated_index = gca_index.Index(
        name=_TEST_INDEX_NAME,
        display_name=_TEST_DISPLAY_NAME_UPDATE,
        description=_TEST_DESCRIPTION_UPDATE,
        labels=_TEST_LABELS_UPDATE,
    )
    with patch.object(
        index_service_client.IndexServiceClient, "update_index"
    ) as patched_update:
        fake_lro = mock.Mock(spec=operation.Operation)
        fake_lro.result.return_value = updated_index
        patched_update.return_value = fake_lro
        yield patched_update
def test_create_tree_ah_index(self, create_index_mock, sync):
    """Verify ``create_tree_ah_index`` builds the expected ``create_index`` request.

    The request Index must carry the Tree-AH leaf settings under
    ``config.algorithmConfig.treeAhConfig``, alongside dimensions, neighbor
    count, distance measure and contents delta URI. ``sync`` is presumably
    parametrized outside this view to cover both call modes.
    """
    aiplatform.init(project=_TEST_PROJECT)

    created = aiplatform.MatchingEngineIndex.create_tree_ah_index(
        display_name=_TEST_INDEX_DISPLAY_NAME,
        contents_delta_uri=_TEST_CONTENTS_DELTA_URI,
        dimensions=_TEST_INDEX_CONFIG_DIMENSIONS,
        approximate_neighbors_count=_TEST_INDEX_APPROXIMATE_NEIGHBORS_COUNT,
        distance_measure_type=_TEST_INDEX_DISTANCE_MEASURE_TYPE,
        leaf_node_embedding_count=_TEST_LEAF_NODE_EMBEDDING_COUNT,
        leaf_nodes_to_search_percent=_TEST_LEAF_NODES_TO_SEARCH_PERCENT,
        description=_TEST_INDEX_DESCRIPTION,
        labels=_TEST_LABELS,
        sync=sync,
    )
    if not sync:
        # In non-sync mode, wait for creation before inspecting the mock.
        created.wait()

    tree_ah_settings = {
        "leafNodeEmbeddingCount": _TEST_LEAF_NODE_EMBEDDING_COUNT,
        "leafNodesToSearchPercent": _TEST_LEAF_NODES_TO_SEARCH_PERCENT,
    }
    index_config = {
        "algorithmConfig": {"treeAhConfig": tree_ah_settings},
        "dimensions": _TEST_INDEX_CONFIG_DIMENSIONS,
        "approximateNeighborsCount": _TEST_INDEX_APPROXIMATE_NEIGHBORS_COUNT,
        "distanceMeasureType": _TEST_INDEX_DISTANCE_MEASURE_TYPE,
    }
    wanted_index = gca_index.Index(
        display_name=_TEST_INDEX_DISPLAY_NAME,
        metadata={
            "config": index_config,
            "contentsDeltaUri": _TEST_CONTENTS_DELTA_URI,
        },
        description=_TEST_INDEX_DESCRIPTION,
        labels=_TEST_LABELS,
    )

    create_index_mock.assert_called_once_with(
        metadata=_TEST_REQUEST_METADATA,
        parent=_TEST_PARENT,
        index=wanted_index,
    )
def get_index_mock():
    """Patch ``IndexServiceClient.get_index`` to return an Index with one deployment.

    Generator (presumably registered as a pytest fixture outside this view):
    yields the patched ``get_index`` mock. Its return value is an Index with
    the test name, display name and description. That Index has a single
    ``DeployedIndexRef`` whose ``deployed_index_id`` is
    ``_TEST_DEPLOYED_INDEX_ID``.

    NOTE(review): another function named ``get_index_mock`` is also defined
    in this file; if both end up in the same module, this later definition
    replaces the earlier one. Confirm they come from separate test modules.
    """
    returned_index = gca_index.Index(
        name=_TEST_INDEX_NAME,
        display_name=_TEST_INDEX_DISPLAY_NAME,
        description=_TEST_INDEX_DESCRIPTION,
    )
    # NOTE(review): ``index_endpoint`` is filled with the Index's own resource
    # name rather than an IndexEndpoint name. Confirm no test relies on it
    # being a real endpoint name before changing it.
    deployment_ref = gca_matching_engine_deployed_index_ref.DeployedIndexRef(
        index_endpoint=returned_index.name,
        deployed_index_id=_TEST_DEPLOYED_INDEX_ID,
    )
    returned_index.deployed_indexes = [deployment_ref]

    with patch.object(
        index_service_client.IndexServiceClient, "get_index"
    ) as patched_get:
        patched_get.return_value = returned_index
        yield patched_get
def test_create_brute_force_index(self, create_index_mock, sync):
    """Verify ``create_brute_force_index`` builds the expected ``create_index`` request.

    The request Index must carry an empty ``bruteForceConfig``, the test
    dimensions and distance measure, ``approximateNeighborsCount`` set to
    ``None``, and the contents delta URI. ``sync`` is presumably parametrized
    outside this view to cover both call modes.
    """
    aiplatform.init(project=_TEST_PROJECT)

    created = aiplatform.MatchingEngineIndex.create_brute_force_index(
        display_name=_TEST_INDEX_DISPLAY_NAME,
        contents_delta_uri=_TEST_CONTENTS_DELTA_URI,
        dimensions=_TEST_INDEX_CONFIG_DIMENSIONS,
        distance_measure_type=_TEST_INDEX_DISTANCE_MEASURE_TYPE,
        description=_TEST_INDEX_DESCRIPTION,
        labels=_TEST_LABELS,
        sync=sync,
    )
    if not sync:
        # In non-sync mode, wait for creation before inspecting the mock.
        created.wait()

    index_config = {
        "algorithmConfig": {"bruteForceConfig": {}},
        "dimensions": _TEST_INDEX_CONFIG_DIMENSIONS,
        # Brute force takes no neighbor-count argument, so it is sent as None.
        "approximateNeighborsCount": None,
        "distanceMeasureType": _TEST_INDEX_DISTANCE_MEASURE_TYPE,
    }
    wanted_index = gca_index.Index(
        display_name=_TEST_INDEX_DISPLAY_NAME,
        metadata={
            "config": index_config,
            "contentsDeltaUri": _TEST_CONTENTS_DELTA_URI,
        },
        description=_TEST_INDEX_DESCRIPTION,
        labels=_TEST_LABELS,
    )

    create_index_mock.assert_called_once_with(
        metadata=_TEST_REQUEST_METADATA,
        parent=_TEST_PARENT,
        index=wanted_index,
    )
_TEST_INDEX_DESCRIPTION = "index_description"
_TEST_LABELS = {"my_key": "my_value"}

# Replacement values used by the metadata-update tests and their mocks.
_TEST_DISPLAY_NAME_UPDATE = "my new display name"
_TEST_DESCRIPTION_UPDATE = "my description update"
_TEST_LABELS_UPDATE = {"my_key_update": "my_value_update"}

# request_metadata: the `metadata=` value expected on every mocked client call.
_TEST_REQUEST_METADATA = ()

# Lists: three distinct but field-for-field identical Index messages.
_TEST_INDEX_LIST = [
    gca_index.Index(
        name=_TEST_INDEX_NAME,
        display_name=_TEST_INDEX_DISPLAY_NAME,
        description=_TEST_INDEX_DESCRIPTION,
    )
    for _ in range(3)
]