def test_deploy_with_explanations(self, deploy_model_with_explanations_mock, sync):
        """Check that ``Endpoint.deploy`` forwards explanation settings.

        Deploys a model with a machine type, accelerator settings and
        explanation metadata/parameters, waits on the endpoint when ``sync``
        is false, and then asserts that the supplied mock was called exactly
        once with a ``DeployedModel`` carrying the matching
        ``ExplanationSpec`` and a traffic split of ``{"0": 100}``.
        """
        endpoint = models.Endpoint(_TEST_ENDPOINT_NAME)
        model = models.Model(_TEST_ID)

        endpoint.deploy(
            model=model,
            machine_type=_TEST_MACHINE_TYPE,
            accelerator_type=_TEST_ACCELERATOR_TYPE,
            accelerator_count=_TEST_ACCELERATOR_COUNT,
            explanation_metadata=_TEST_EXPLANATION_METADATA,
            explanation_parameters=_TEST_EXPLANATION_PARAMETERS,
            sync=sync,
        )

        # When not running synchronously, block on the endpoint before the
        # mock is inspected.
        if not sync:
            endpoint.wait()

        # Expected request payload, assembled piece by piece.
        explanation_spec = gca_endpoint.explanation.ExplanationSpec(
            metadata=_TEST_EXPLANATION_METADATA,
            parameters=_TEST_EXPLANATION_PARAMETERS,
        )
        dedicated_resources = gca_machine_resources.DedicatedResources(
            machine_spec=gca_machine_resources.MachineSpec(
                machine_type=_TEST_MACHINE_TYPE,
                accelerator_type=_TEST_ACCELERATOR_TYPE,
                accelerator_count=_TEST_ACCELERATOR_COUNT,
            ),
            min_replica_count=1,
            max_replica_count=1,
        )
        deployed_model = gca_endpoint.DeployedModel(
            dedicated_resources=dedicated_resources,
            model=model.resource_name,
            display_name=None,
            explanation_spec=explanation_spec,
        )

        deploy_model_with_explanations_mock.assert_called_once_with(
            endpoint=endpoint.resource_name,
            deployed_model=deployed_model,
            traffic_split={"0": 100},
            metadata=(),
        )
    def test_deploy_with_dedicated_resources(self, deploy_model_mock, sync):
        """Check that ``Endpoint.deploy`` sends dedicated resources and service account.

        Deploys a model with a machine type, accelerator settings and a
        service account, waits on the endpoint when ``sync`` is false, and
        then asserts that ``deploy_model_mock`` was called exactly once with
        a ``DeployedModel`` holding a single-replica ``DedicatedResources``
        and the same service account.
        """
        endpoint = models.Endpoint(_TEST_ENDPOINT_NAME)
        model = models.Model(_TEST_ID)

        endpoint.deploy(
            model=model,
            machine_type=_TEST_MACHINE_TYPE,
            accelerator_type=_TEST_ACCELERATOR_TYPE,
            accelerator_count=_TEST_ACCELERATOR_COUNT,
            service_account=_TEST_SERVICE_ACCOUNT,
            sync=sync,
        )

        # When not running synchronously, block on the endpoint before the
        # mock is inspected.
        if not sync:
            endpoint.wait()

        # Expected request payload: min and max replica counts are both 1.
        dedicated_resources = gca_machine_resources.DedicatedResources(
            machine_spec=gca_machine_resources.MachineSpec(
                machine_type=_TEST_MACHINE_TYPE,
                accelerator_type=_TEST_ACCELERATOR_TYPE,
                accelerator_count=_TEST_ACCELERATOR_COUNT,
            ),
            min_replica_count=1,
            max_replica_count=1,
        )
        deployed_model = gca_endpoint.DeployedModel(
            dedicated_resources=dedicated_resources,
            model=model.resource_name,
            display_name=None,
            service_account=_TEST_SERVICE_ACCOUNT,
        )

        deploy_model_mock.assert_called_once_with(
            endpoint=endpoint.resource_name,
            deployed_model=deployed_model,
            traffic_split={"0": 100},
            metadata=(),
        )
    def test_batch_predict_with_all_args(
        self, create_batch_prediction_job_with_explanations_mock, sync
    ):
        """Check the request built by ``BatchPredictionJob.create`` with every option set.

        Creates a batch prediction job with GCS input/output, machine and
        replica settings, explanation settings, labels and explicit
        credentials, waits for it, and then asserts that the supplied mock
        was called exactly once with the equivalent ``BatchPredictionJob``
        message, the project/location parent path and ``timeout=None``.
        """
        aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION)
        anonymous_credentials = auth_credentials.AnonymousCredentials()

        job = jobs.BatchPredictionJob.create(
            model_name=_TEST_MODEL_NAME,
            job_display_name=_TEST_BATCH_PREDICTION_JOB_DISPLAY_NAME,
            gcs_source=_TEST_BATCH_PREDICTION_GCS_SOURCE,
            gcs_destination_prefix=_TEST_BATCH_PREDICTION_GCS_DEST_PREFIX,
            predictions_format="csv",
            model_parameters={},
            machine_type=_TEST_MACHINE_TYPE,
            accelerator_type=_TEST_ACCELERATOR_TYPE,
            accelerator_count=_TEST_ACCELERATOR_COUNT,
            starting_replica_count=_TEST_STARTING_REPLICA_COUNT,
            max_replica_count=_TEST_MAX_REPLICA_COUNT,
            generate_explanation=True,
            explanation_metadata=_TEST_EXPLANATION_METADATA,
            explanation_parameters=_TEST_EXPLANATION_PARAMETERS,
            labels=_TEST_LABEL,
            credentials=anonymous_credentials,
            sync=sync,
            create_request_timeout=None,
        )

        job.wait_for_resource_creation()
        job.wait()

        # Build each section of the expected request separately, then
        # combine them into the full job message.
        job_type = gca_batch_prediction_job_compat.BatchPredictionJob

        expected_input_config = job_type.InputConfig(
            instances_format="jsonl",
            gcs_source=gca_io_compat.GcsSource(
                uris=[_TEST_BATCH_PREDICTION_GCS_SOURCE]
            ),
        )
        expected_output_config = job_type.OutputConfig(
            gcs_destination=gca_io_compat.GcsDestination(
                output_uri_prefix=_TEST_BATCH_PREDICTION_GCS_DEST_PREFIX
            ),
            predictions_format="csv",
        )
        expected_resources = gca_machine_resources_compat.BatchDedicatedResources(
            machine_spec=gca_machine_resources_compat.MachineSpec(
                machine_type=_TEST_MACHINE_TYPE,
                accelerator_type=_TEST_ACCELERATOR_TYPE,
                accelerator_count=_TEST_ACCELERATOR_COUNT,
            ),
            starting_replica_count=_TEST_STARTING_REPLICA_COUNT,
            max_replica_count=_TEST_MAX_REPLICA_COUNT,
        )
        expected_explanation_spec = gca_explanation_compat.ExplanationSpec(
            metadata=_TEST_EXPLANATION_METADATA,
            parameters=_TEST_EXPLANATION_PARAMETERS,
        )

        expected_job = job_type(
            display_name=_TEST_BATCH_PREDICTION_JOB_DISPLAY_NAME,
            model=_TEST_MODEL_NAME,
            input_config=expected_input_config,
            output_config=expected_output_config,
            dedicated_resources=expected_resources,
            generate_explanation=True,
            explanation_spec=expected_explanation_spec,
            labels=_TEST_LABEL,
        )
        expected_parent = f"projects/{_TEST_PROJECT}/locations/{_TEST_LOCATION}"

        create_batch_prediction_job_with_explanations_mock.assert_called_once_with(
            parent=expected_parent,
            batch_prediction_job=expected_job,
            timeout=None,
        )
    def _build_deployed_index(
        deployed_index_id: str,
        index_resource_name: Optional[str] = None,
        display_name: Optional[str] = None,
        machine_type: Optional[str] = None,
        min_replica_count: Optional[int] = None,
        max_replica_count: Optional[int] = None,
        enable_access_logging: Optional[bool] = None,
        reserved_ip_ranges: Optional[Sequence[str]] = None,
        deployment_group: Optional[str] = None,
        auth_config_audiences: Optional[Sequence[str]] = None,
        auth_config_allowed_issuers: Optional[Sequence[str]] = None,
    ) -> gca_matching_engine_index_endpoint.DeployedIndex:
        """Assemble a ``DeployedIndex`` message from plain argument values.

        Args:
            deployed_index_id (str):
                Required. The user specified ID of the DeployedIndex. The ID
                can be up to 128 characters long and must start with a
                letter and only contain letters, numbers, and underscores.
                The ID must be unique within the project it is created in.
            index_resource_name (str):
                Optional. Fully-qualified resource name or ID identifying
                the index; stored in the message's ``index`` field.
            display_name (str):
                Optional. The display name of the DeployedIndex. If not
                provided upon creation, the Index's display_name is used.
            machine_type (str):
                Optional. The type of machine. When given, the index is
                configured with dedicated resources on that machine type;
                otherwise automatic resources are used.
            min_replica_count (int):
                Optional. The minimum number of machine replicas the
                deployment will always run on. If traffic against it
                increases, it may dynamically be deployed onto more
                replicas, and as traffic decreases, some of these extra
                replicas may be freed.

                If this value is not provided, the value of 2 will be used.
            max_replica_count (int):
                Optional. The maximum number of replicas the deployment may
                scale to when traffic against it increases. If the requested
                value is too large, the deployment will error, but if
                deployment succeeds then the ability to scale to that many
                replicas is guaranteed (barring service outages). If traffic
                increases beyond what the maximum number of replicas can
                handle, a portion of the traffic will be dropped. If this
                value is not provided, the larger value of min_replica_count
                or 2 will be used. If the value provided is smaller than
                min_replica_count, it will automatically be increased to be
                min_replica_count.
            enable_access_logging (bool):
                Optional. If true, private endpoint's access logs are sent
                to StackDriver Logging. These logs are like standard server
                access logs, containing information like timestamp and
                latency for each MatchRequest. Note that Stackdriver logs
                may incur a cost, especially if the deployed index receives
                a high queries per second rate (QPS). Estimate your costs
                before enabling this option.
            reserved_ip_ranges (Sequence[str]):
                Optional. A list of reserved ip ranges under the VPC network
                that can be used for this DeployedIndex. If set, we will
                deploy the index within the provided ip ranges. Otherwise,
                the index might be deployed to any ip ranges under the
                provided VPC network.

                The value should be the name of the address
                (https://cloud.google.com/compute/docs/reference/rest/v1/addresses)
                Example: 'vertex-ai-ip-range'.
            deployment_group (str):
                Optional. The deployment group can be no longer than 64
                characters (eg: 'test', 'prod'). If not set, we will use the
                'default' deployment group.

                Creating ``deployment_groups`` with ``reserved_ip_ranges``
                is a recommended practice when the peered network has
                multiple peering ranges. This creates your deployments from
                predictable IP spaces for easier traffic administration.
                Also, one deployment_group (except 'default') can only be
                used with the same reserved_ip_ranges which means if the
                deployment_group has been used with reserved_ip_ranges: [a,
                b, c], using it with [a, b] or [d, e] is disallowed.

                Note: we only support up to 5 deployment groups (not
                including 'default').
            auth_config_audiences (Sequence[str]):
                Optional. The list of JWT
                `audiences <https://tools.ietf.org/html/draft-ietf-oauth-json-web-token-32#section-4.1.3>`__
                that are allowed to access. A JWT containing any of these
                audiences will be accepted. Only applied when
                ``auth_config_allowed_issuers`` is also provided.
            auth_config_allowed_issuers (Sequence[str]):
                Optional. A list of allowed JWT issuers. Each entry must be
                a valid Google service account, in the following format:

                ``service-account-name@project-id.iam.gserviceaccount.com``

                Only applied when ``auth_config_audiences`` is also provided.

        Returns:
            gca_matching_engine_index_endpoint.DeployedIndex:
                The populated message, with either ``dedicated_resources``
                or ``automatic_resources`` set.
        """
        deployed_index = gca_matching_engine_index_endpoint.DeployedIndex(
            id=deployed_index_id,
            index=index_resource_name,
            display_name=display_name,
            enable_access_logging=enable_access_logging,
            reserved_ip_ranges=reserved_ip_ranges,
            deployment_group=deployment_group,
        )

        # Authentication is attached only when both lists are non-empty;
        # supplying just one of them leaves the auth config unset.
        if auth_config_audiences and auth_config_allowed_issuers:
            auth_config_type = (
                gca_matching_engine_index_endpoint.DeployedIndexAuthConfig
            )
            auth_provider = auth_config_type.AuthProvider(
                audiences=auth_config_audiences,
                allowed_issuers=auth_config_allowed_issuers,
            )
            deployed_index.deployed_index_auth_config = auth_config_type(
                auth_provider=auth_provider
            )

        # Both resource flavours take the same replica bounds.
        replica_bounds = {
            "min_replica_count": min_replica_count,
            "max_replica_count": max_replica_count,
        }

        if not machine_type:
            # No machine type requested: use automatic resources.
            deployed_index.automatic_resources = (
                gca_machine_resources_compat.AutomaticResources(**replica_bounds)
            )
            return deployed_index

        deployed_index.dedicated_resources = (
            gca_machine_resources_compat.DedicatedResources(
                machine_spec=gca_machine_resources_compat.MachineSpec(
                    machine_type=machine_type
                ),
                **replica_bounds,
            )
        )
        return deployed_index