def test_update_index_embeddings(self, update_index_embeddings_mock):
        """Verify the request issued by ``MatchingEngineIndex.update_embeddings``.

        The call must reach the patched ``update_index`` exactly once, carrying
        an Index whose metadata holds the new contents URI and overwrite flag,
        with an update mask restricted to ``metadata``.

        Args:
            update_index_embeddings_mock: Mock standing in for the service
                client's ``update_index`` method.
        """
        aiplatform.init(project=_TEST_PROJECT)

        index = aiplatform.MatchingEngineIndex(index_name=_TEST_INDEX_ID)
        result = index.update_embeddings(
            is_complete_overwrite=_TEST_IS_COMPLETE_OVERWRITE_UPDATE,
            contents_delta_uri=_TEST_CONTENTS_DELTA_URI_UPDATE,
        )

        # Index payload the SDK is expected to send to the service.
        embeddings_metadata = {
            "contentsDeltaUri": _TEST_CONTENTS_DELTA_URI_UPDATE,
            "isCompleteOverwrite": _TEST_IS_COMPLETE_OVERWRITE_UPDATE,
        }
        expected_request_index = gca_index.Index(
            name=_TEST_INDEX_NAME,
            metadata=embeddings_metadata,
        )
        metadata_only_mask = field_mask_pb2.FieldMask(paths=["metadata"])

        update_index_embeddings_mock.assert_called_once_with(
            index=expected_request_index,
            update_mask=metadata_only_mask,
            metadata=_TEST_REQUEST_METADATA,
        )

        # The service only returns the name of the Index, so the refreshed
        # resource is compared against a name-only Index.
        returned_resource = result.gca_resource
        assert returned_resource == gca_index.Index(name=_TEST_INDEX_NAME)
    def test_update_index_metadata(self, update_index_metadata_mock):
        """Verify the request issued by ``MatchingEngineIndex.update_metadata``.

        Updating display name, description and labels must produce a single
        ``update_index`` call whose mask lists exactly those three fields, and
        the returned object must expose the updated resource.

        Args:
            update_index_metadata_mock: Mock standing in for the service
                client's ``update_index`` method.
        """
        aiplatform.init(project=_TEST_PROJECT)

        index = aiplatform.MatchingEngineIndex(index_name=_TEST_INDEX_ID)
        result = index.update_metadata(
            labels=_TEST_LABELS_UPDATE,
            description=_TEST_DESCRIPTION_UPDATE,
            display_name=_TEST_DISPLAY_NAME_UPDATE,
        )

        # The same Index serves as the expected request payload and as the
        # expected resource on the returned object.
        expected_index = gca_index.Index(
            name=_TEST_INDEX_NAME,
            display_name=_TEST_DISPLAY_NAME_UPDATE,
            description=_TEST_DESCRIPTION_UPDATE,
            labels=_TEST_LABELS_UPDATE,
        )
        changed_fields_mask = field_mask_pb2.FieldMask(
            paths=["labels", "display_name", "description"],
        )

        update_index_metadata_mock.assert_called_once_with(
            index=expected_index,
            update_mask=changed_fields_mask,
            metadata=_TEST_REQUEST_METADATA,
        )

        returned_resource = result.gca_resource
        assert returned_resource == expected_index
def update_index_embeddings_mock():
    """Patch ``IndexServiceClient.update_index`` for embeddings updates.

    The patched method returns a mocked long-running operation whose
    ``result()`` is an Index carrying only its resource name.

    NOTE(review): presumably registered as a pytest fixture by a decorator
    that is not visible in this excerpt -- confirm.

    Yields:
        The mock that replaces ``update_index`` while the patch is active.
    """
    # Prepare the fake long-running operation before installing the patch.
    name_only_index = gca_index.Index(name=_TEST_INDEX_NAME)
    fake_operation = mock.Mock(spec=operation.Operation)
    fake_operation.result.return_value = name_only_index

    with patch.object(
        index_service_client.IndexServiceClient, "update_index"
    ) as patched_update_index:
        patched_update_index.return_value = fake_operation
        yield patched_update_index
def get_index_mock():
    """Patch ``IndexServiceClient.get_index`` to return a canned Index.

    The canned Index carries the test resource name, display name and
    description.

    NOTE(review): presumably registered as a pytest fixture by a decorator
    that is not visible in this excerpt -- confirm.

    Yields:
        The mock that replaces ``get_index`` while the patch is active.
    """
    canned_index = gca_index.Index(
        name=_TEST_INDEX_NAME,
        display_name=_TEST_INDEX_DISPLAY_NAME,
        description=_TEST_INDEX_DESCRIPTION,
    )

    # Extra keyword arguments to patch.object configure the created mock, so
    # the return value is set at patch time.
    with patch.object(
        index_service_client.IndexServiceClient,
        "get_index",
        return_value=canned_index,
    ) as patched_get_index:
        yield patched_get_index
def create_index_mock():
    """Patch ``IndexServiceClient.create_index`` with a fake operation.

    The patched method returns a mocked long-running operation whose
    ``result()`` is an Index with the test name, display name and
    description.

    NOTE(review): presumably registered as a pytest fixture by a decorator
    that is not visible in this excerpt -- confirm.

    Yields:
        The mock that replaces ``create_index`` while the patch is active.
    """
    created_index = gca_index.Index(
        name=_TEST_INDEX_NAME,
        display_name=_TEST_INDEX_DISPLAY_NAME,
        description=_TEST_INDEX_DESCRIPTION,
    )
    fake_operation = mock.Mock(spec=operation.Operation)
    fake_operation.result.return_value = created_index

    with patch.object(
        index_service_client.IndexServiceClient, "create_index"
    ) as patched_create_index:
        patched_create_index.return_value = fake_operation
        yield patched_create_index
def update_index_metadata_mock():
    """Patch ``IndexServiceClient.update_index`` for metadata updates.

    The patched method returns a mocked long-running operation whose
    ``result()`` is an Index carrying the updated display name, description
    and labels.

    NOTE(review): presumably registered as a pytest fixture by a decorator
    that is not visible in this excerpt -- confirm.

    Yields:
        The mock that replaces ``update_index`` while the patch is active.
    """
    updated_index = gca_index.Index(
        name=_TEST_INDEX_NAME,
        display_name=_TEST_DISPLAY_NAME_UPDATE,
        description=_TEST_DESCRIPTION_UPDATE,
        labels=_TEST_LABELS_UPDATE,
    )
    fake_operation = mock.Mock(spec=operation.Operation)
    fake_operation.result.return_value = updated_index

    with patch.object(
        index_service_client.IndexServiceClient, "update_index"
    ) as patched_update_index:
        patched_update_index.return_value = fake_operation
        yield patched_update_index
    def test_create_tree_ah_index(self, create_index_mock, sync):
        """Verify the request issued by ``create_tree_ah_index``.

        The SDK call must translate its arguments into a single
        ``create_index`` request whose metadata nests the tree-AH algorithm
        settings under ``config.algorithmConfig.treeAhConfig``.

        Args:
            create_index_mock: Mock standing in for the service client's
                ``create_index`` method.
            sync: Forwarded to the SDK call; when falsy the test waits on the
                returned index before asserting.
        """
        aiplatform.init(project=_TEST_PROJECT)

        created = aiplatform.MatchingEngineIndex.create_tree_ah_index(
            display_name=_TEST_INDEX_DISPLAY_NAME,
            description=_TEST_INDEX_DESCRIPTION,
            labels=_TEST_LABELS,
            contents_delta_uri=_TEST_CONTENTS_DELTA_URI,
            dimensions=_TEST_INDEX_CONFIG_DIMENSIONS,
            distance_measure_type=_TEST_INDEX_DISTANCE_MEASURE_TYPE,
            approximate_neighbors_count=_TEST_INDEX_APPROXIMATE_NEIGHBORS_COUNT,
            leaf_node_embedding_count=_TEST_LEAF_NODE_EMBEDDING_COUNT,
            leaf_nodes_to_search_percent=_TEST_LEAF_NODES_TO_SEARCH_PERCENT,
            sync=sync,
        )

        # Wait on the returned index before inspecting the mock when the
        # call was made with a falsy ``sync``.
        if not sync:
            created.wait()

        # Build the expected metadata from the inside out.
        tree_ah_settings = {
            "leafNodeEmbeddingCount": _TEST_LEAF_NODE_EMBEDDING_COUNT,
            "leafNodesToSearchPercent": _TEST_LEAF_NODES_TO_SEARCH_PERCENT,
        }
        index_config = {
            "algorithmConfig": {"treeAhConfig": tree_ah_settings},
            "dimensions": _TEST_INDEX_CONFIG_DIMENSIONS,
            "approximateNeighborsCount": _TEST_INDEX_APPROXIMATE_NEIGHBORS_COUNT,
            "distanceMeasureType": _TEST_INDEX_DISTANCE_MEASURE_TYPE,
        }
        expected_index = gca_index.Index(
            display_name=_TEST_INDEX_DISPLAY_NAME,
            metadata={
                "config": index_config,
                "contentsDeltaUri": _TEST_CONTENTS_DELTA_URI,
            },
            description=_TEST_INDEX_DESCRIPTION,
            labels=_TEST_LABELS,
        )

        create_index_mock.assert_called_once_with(
            parent=_TEST_PARENT,
            index=expected_index,
            metadata=_TEST_REQUEST_METADATA,
        )
# Example #8
# 0
def get_index_mock():
    """Patch ``IndexServiceClient.get_index`` to return a deployed Index.

    The canned Index carries the test name, display name and description,
    plus one ``DeployedIndexRef`` whose ``index_endpoint`` is set to the
    Index's own name and whose id is ``_TEST_DEPLOYED_INDEX_ID``.

    NOTE(review): presumably registered as a pytest fixture by a decorator
    that is not visible in this excerpt -- confirm.

    Yields:
        The mock that replaces ``get_index`` while the patch is active.
    """
    canned_index = gca_index.Index(
        name=_TEST_INDEX_NAME,
        display_name=_TEST_INDEX_DISPLAY_NAME,
        description=_TEST_INDEX_DESCRIPTION,
    )

    # Attach a single deployment reference after construction, as the
    # reference reuses the Index's own name.
    deployment_ref = gca_matching_engine_deployed_index_ref.DeployedIndexRef(
        index_endpoint=canned_index.name,
        deployed_index_id=_TEST_DEPLOYED_INDEX_ID,
    )
    canned_index.deployed_indexes = [deployment_ref]

    with patch.object(
        index_service_client.IndexServiceClient, "get_index"
    ) as patched_get_index:
        patched_get_index.return_value = canned_index
        yield patched_get_index
    def test_create_brute_force_index(self, create_index_mock, sync):
        """Verify the request issued by ``create_brute_force_index``.

        The SDK call must produce a single ``create_index`` request whose
        metadata selects an empty ``bruteForceConfig`` and leaves
        ``approximateNeighborsCount`` as ``None``.

        Args:
            create_index_mock: Mock standing in for the service client's
                ``create_index`` method.
            sync: Forwarded to the SDK call; when falsy the test waits on the
                returned index before asserting.
        """
        aiplatform.init(project=_TEST_PROJECT)

        created = aiplatform.MatchingEngineIndex.create_brute_force_index(
            display_name=_TEST_INDEX_DISPLAY_NAME,
            description=_TEST_INDEX_DESCRIPTION,
            labels=_TEST_LABELS,
            contents_delta_uri=_TEST_CONTENTS_DELTA_URI,
            dimensions=_TEST_INDEX_CONFIG_DIMENSIONS,
            distance_measure_type=_TEST_INDEX_DISTANCE_MEASURE_TYPE,
            sync=sync,
        )

        # Wait on the returned index before inspecting the mock when the
        # call was made with a falsy ``sync``.
        if not sync:
            created.wait()

        # Brute force takes no algorithm parameters, hence the empty mapping.
        index_config = {
            "algorithmConfig": {"bruteForceConfig": {}},
            "dimensions": _TEST_INDEX_CONFIG_DIMENSIONS,
            "approximateNeighborsCount": None,
            "distanceMeasureType": _TEST_INDEX_DISTANCE_MEASURE_TYPE,
        }
        expected_index = gca_index.Index(
            display_name=_TEST_INDEX_DISPLAY_NAME,
            metadata={
                "config": index_config,
                "contentsDeltaUri": _TEST_CONTENTS_DELTA_URI,
            },
            description=_TEST_INDEX_DESCRIPTION,
            labels=_TEST_LABELS,
        )

        create_index_mock.assert_called_once_with(
            parent=_TEST_PARENT,
            index=expected_index,
            metadata=_TEST_REQUEST_METADATA,
        )
# Description attached to the Index resources built below.
_TEST_INDEX_DESCRIPTION = "index_description"

# Initial labels, followed by the replacement values for update calls.
_TEST_LABELS = {"my_key": "my_value"}
_TEST_DISPLAY_NAME_UPDATE = "my new display name"
_TEST_DESCRIPTION_UPDATE = "my description update"
_TEST_LABELS_UPDATE = {"my_key_update": "my_value_update"}

# Request metadata: an empty tuple.
_TEST_REQUEST_METADATA = tuple()

# Lists: three separate but equal Index resources.
_TEST_INDEX_LIST = [
    gca_index.Index(
        name=_TEST_INDEX_NAME,
        display_name=_TEST_INDEX_DISPLAY_NAME,
        description=_TEST_INDEX_DESCRIPTION,
    )
    for _ in range(3)
]